trio.TrainingClient
class TrainingClient:
task_id: str
base_model: str
lora: dictTrainingClient 是用于 LoRA 模型训练的客户端,通过 ServiceClient.create_lora_training_client() 创建。
client = ServiceClient()
training_client = client.create_lora_training_client(base_model="Qwen/Qwen2.5-3B")
tokenizer = training_client.get_tokenizer()
tokens = tokenizer.encode("The meaning of life is")
data = [Datum(model_input=tokens, loss_fn_inputs={"target_tokens": tokens})]
# 前反向传播
future = training_client.forward_backward(data=data)
output = future.result()
# 梯度更新
training_client.optim_step(AdamParams(learning_rate=1e-4)).result()属性
| 属性 | 类型 | 说明 |
|---|---|---|
task_id | str | 当前训练任务 ID |
base_model | str | 使用的基础模型名称 |
lora | dict | LoRA 初始化参数 |
方法
forward
def forward(
self,
data: list[Datum],
loss_fn: str = "cross_entropy",
loss_fn_config: dict[str, Any] | None = None,
auto_shift: bool = False,
) -> APIFuture[ForwardBackwardOutput]仅执行前向传播,计算损失但不更新梯度。
参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
data | list[Datum] | — | 样本列表,每个样本包含输入 token ids 和损失函数参数 |
loss_fn | str | "cross_entropy" | 损失函数类型:"cross_entropy" / "importance_sampling" / "ppo" |
loss_fn_config | dict | None | None | 损失函数的额外配置项 |
auto_shift | bool | False | 为 True 时自动将 labels 偏移一位对齐预测目标 |
返回值
APIFuture[ForwardBackwardOutput] — 调用 .result() 获取输出。
示例
future = training_client.forward(data=data)
output = future.result()
print(output.metrics)forward_backward
def forward_backward(
self,
data: list[Datum],
loss_fn: str = "cross_entropy",
loss_fn_config: dict[str, Any] | None = None,
auto_shift: bool = False,
) -> APIFuture[ForwardBackwardOutput]执行前向 + 反向传播,计算并累积梯度。
参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
data | list[Datum] | — | 样本列表,每个样本包含输入 token ids 和损失函数参数 |
loss_fn | str | "cross_entropy" | 损失函数类型:"cross_entropy" / "importance_sampling" / "ppo" |
loss_fn_config | dict | None | None | 损失函数的额外配置项 |
auto_shift | bool | False | 为 True 时自动将 labels 偏移一位对齐预测目标 |
返回值
APIFuture[ForwardBackwardOutput] — 调用 .result() 获取输出。
示例
future = training_client.forward_backward(data=data, loss_fn="cross_entropy")
output = future.result()forward_backward_custom
def forward_backward_custom(
self,
data: list[Datum],
loss_fn: Callable[
[list[Datum], list[torch.Tensor]],
tuple[torch.Tensor, dict[str, float]]
],
) -> APIFuture[ForwardBackwardOutput]使用自定义 PyTorch 损失函数执行前向 + 反向传播。需要本地安装 torch。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
data | list[Datum] | 样本列表 |
loss_fn | Callable | 接收 (data, logprobs) 并返回 (loss_tensor, metrics_dict) 的函数 |
返回值
APIFuture[ForwardBackwardOutput] — 调用 .result() 获取输出,metrics 中会合并自定义指标。
示例
def my_loss(data, logprobs):
loss = -sum(lp.mean() for lp in logprobs)
return loss, {"my_loss": loss.item()}
future = training_client.forward_backward_custom(data=data, loss_fn=my_loss)
output = future.result()optim_step
def optim_step(self, adam_params: AdamParams) -> APIFuture[OptimStepResponse]根据当前累积的梯度执行一次 Adam 优化器更新,并清零梯度。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
adam_params | AdamParams | Adam 优化器参数,见 AdamParams |
返回值
APIFuture[OptimStepResponse]
示例
training_client.optim_step(AdamParams(learning_rate=1e-4)).result()save_state
def save_state(self, name: str) -> APIFuture[SaveWeightsResponse]保存模型权重和优化器状态(完整 checkpoint),用于断点续训。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
name | str | checkpoint 名称 |
返回值
APIFuture[SaveWeightsResponse]
示例
result = training_client.save_state(name="step-100").result()
print(result.path)save_weights_for_sampler
def save_weights_for_sampler(self, name: str) -> APIFuture[SaveWeightsForSamplerResponse]仅保存模型权重(不含优化器状态),用于后续推理采样。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
name | str | 权重保存名称 |
返回值
APIFuture[SaveWeightsForSamplerResponse]
示例
result = training_client.save_weights_for_sampler(name="step-100").result()
print(result.path)create_sampling_client
def create_sampling_client(self, model_path: str) -> SamplingClient基于指定的 LoRA 权重路径创建 SamplingClient,可在训练过程中随时启动推理。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
model_path | str | LoRA 权重路径 |
返回值
SamplingClient
示例
sampling_client = training_client.create_sampling_client(model_path="/path/to/weights")save_weights_and_get_sampling_client
def save_weights_and_get_sampling_client(self, name: str) -> SamplingClient保存模型权重并立即返回一个已加载该权重的 SamplingClient,相当于 save_weights_for_sampler + create_sampling_client 的组合。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
name | str | 权重保存名称 |
返回值
SamplingClient
示例
sampling_client = training_client.save_weights_and_get_sampling_client(name="step-100")get_tokenizer
def get_tokenizer(self)获取与当前基础模型匹配的 tokenizer,基于 transformers / modelscope 提供的 AutoTokenizer。
示例
tokenizer = training_client.get_tokenizer()
tokens = tokenizer.encode("The meaning of life is")异步方法
forward_async
async def forward_async(
self,
data: list[Datum],
loss_fn: str = "cross_entropy",
loss_fn_config: dict[str, Any] | None = None,
auto_shift: bool = False,
) -> APIFuture[ForwardBackwardOutput]forward 的异步版本,参数相同。
future = await training_client.forward_async(data=data)
output = future.result()forward_backward_async
async def forward_backward_async(
self,
data: list[Datum],
loss_fn: str = "cross_entropy",
loss_fn_config: dict[str, Any] | None = None,
auto_shift: bool = False,
) -> APIFuture[ForwardBackwardOutput]forward_backward 的异步版本,参数相同。
future = await training_client.forward_backward_async(data=data)
output = future.result()forward_backward_custom_async
async def forward_backward_custom_async(
self,
data: list[Datum],
loss_fn: Callable[
[list[Datum], list[torch.Tensor]],
tuple[torch.Tensor, dict[str, float]]
],
) -> APIFuture[ForwardBackwardOutput]forward_backward_custom 的异步版本,参数相同。
future = await training_client.forward_backward_custom_async(data=data, loss_fn=my_loss)
output = future.result()optim_step_async
async def optim_step_async(self, adam_params: AdamParams) -> APIFuture[OptimStepResponse]optim_step 的异步版本,参数相同。
future = await training_client.optim_step_async(AdamParams(learning_rate=1e-4))
await futuresave_state_async
async def save_state_async(self, name: str) -> APIFuture[SaveWeightsResponse]save_state 的异步版本,参数相同。
future = await training_client.save_state_async(name="step-100")
result = await futuresave_weights_for_sampler_async
async def save_weights_for_sampler_async(self, name: str) -> APIFuture[SaveWeightsForSamplerResponse]save_weights_for_sampler 的异步版本,参数相同。
future = await training_client.save_weights_for_sampler_async(name="step-100")
result = await futurecreate_sampling_client_async
async def create_sampling_client_async(self, model_path: str) -> SamplingClientcreate_sampling_client 的异步版本,参数相同。
sampling_client = await training_client.create_sampling_client_async(model_path="/path/to/weights")save_weights_and_get_sampling_client_async
async def save_weights_and_get_sampling_client_async(self, name: str) -> SamplingClientsave_weights_and_get_sampling_client 的异步版本,参数相同。
sampling_client = await training_client.save_weights_and_get_sampling_client_async(name="step-100")