API

trio.TrainingClient

class TrainingClient:
    task_id: str
    base_model: str
    lora: dict

TrainingClient is the client for LoRA model training. Create one with ServiceClient.create_lora_training_client().

client = ServiceClient()
training_client = client.create_lora_training_client(base_model="Qwen/Qwen2.5-3B")

tokenizer = training_client.get_tokenizer()
tokens = tokenizer.encode("The meaning of life is")

data = [Datum(model_input=tokens, loss_fn_inputs={"target_tokens": tokens})]

# Forward and backward pass
future = training_client.forward_backward(data=data)
output = future.result()

# Optimizer step (gradient update)
training_client.optim_step(AdamParams(learning_rate=1e-4)).result()

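Attributes

| Attribute | Type | Description |
| --- | --- | --- |
| task_id | str | ID of the current training task |
| base_model | str | Name of the base model in use |
| lora | dict | LoRA initialization parameters |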

Methods

forward

def forward(
    self,
    data: list[Datum],
    loss_fn: str = "cross_entropy",
    loss_fn_config: dict[str, Any] | None = None,
    auto_shift: bool = False,
) -> APIFuture[ForwardBackwardOutput]

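Runs the forward pass only: computes the loss without updating gradients.

Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| data | list[Datum] | (required) | List of samples; each sample holds the input token ids and the loss function inputs |
| loss_fn | str | "cross_entropy" | Loss function type: "cross_entropy", "importance_sampling", or "ppo" |
| loss_fn_config | dict \| None | None | Extra configuration for the loss function |
| auto_shift | bool | False | When True, labels are shifted by one position automatically to align with the prediction targets |

Returns

APIFuture[ForwardBackwardOutput]. Call .result() to get the output.

Example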

future = training_client.forward(data=data)
output = future.result()
print(output.metrics)
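
To make the auto_shift description concrete, the sketch below shows the usual next-token alignment in plain Python: the token at position i is trained to predict the token at position i + 1. The helper name shift_for_next_token is hypothetical, and whether the service trims the sequence in exactly this way when auto_shift=True is an assumption, not documented behavior.

```python
def shift_for_next_token(tokens: list[int]) -> tuple[list[int], list[int]]:
    """Return (inputs, targets) where targets[i] is the token that follows inputs[i]."""
    if len(tokens) < 2:
        raise ValueError("need at least two tokens to form an input/target pair")
    # Drop the last token from the inputs and the first token from the targets.
    return tokens[:-1], tokens[1:]


inputs, targets = shift_for_next_token([11, 22, 33, 44])
print(inputs)   # [11, 22, 33]
print(targets)  # [22, 33, 44]
```

With auto_shift=False, as in the examples on this page, the caller is presumably responsible for supplying targets that are already aligned.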

forward_backward

def forward_backward(
    self,
    data: list[Datum],
    loss_fn: str = "cross_entropy",
    loss_fn_config: dict[str, Any] | None = None,
    auto_shift: bool = False,
) -> APIFuture[ForwardBackwardOutput]

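Runs the forward and backward passes, computing and accumulating gradients.

Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| data | list[Datum] | (required) | List of samples; each sample holds the input token ids and the loss function inputs |
| loss_fn | str | "cross_entropy" | Loss function type: "cross_entropy", "importance_sampling", or "ppo" |
| loss_fn_config | dict \| None | None | Extra configuration for the loss function |
| auto_shift | bool | False | When True, labels are shifted by one position automatically to align with the prediction targets |

Returns

APIFuture[ForwardBackwardOutput]. Call .result() to get the output.

Example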

future = training_client.forward_backward(data=data, loss_fn="cross_entropy")
output = future.result()

forward_backward_custom

def forward_backward_custom(
    self,
    data: list[Datum],
    loss_fn: Callable[
        [list[Datum], list[torch.Tensor]],
        tuple[torch.Tensor, dict[str, float]]
    ],
) -> APIFuture[ForwardBackwardOutput]

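Runs the forward and backward passes with a custom PyTorch loss function. Requires torch to be installed locally.

Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| data | list[Datum] | List of samples |
| loss_fn | Callable | Function that takes (data, logprobs) and returns (loss_tensor, metrics_dict) |

Returns

APIFuture[ForwardBackwardOutput]. Call .result() to get the output; custom metrics are merged into metrics.

Example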

def my_loss(data, logprobs):
    loss = -sum(lp.mean() for lp in logprobs)
    return loss, {"my_loss": loss.item()}

future = training_client.forward_backward_custom(data=data, loss_fn=my_loss)
output = future.result()
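
The example above sums per-sequence means, so short and long sequences contribute equally. As a variation, the sketch below weights every token equally instead. It assumes each entry of logprobs is a 1-D tensor of per-token log-probabilities for one Datum, which the signature suggests but this page does not state; the name token_mean_nll is hypothetical.

```python
def token_mean_nll(data, logprobs):
    """Negative log-likelihood averaged over all tokens in the batch.

    Assumption: each item of `logprobs` is a torch.Tensor of per-token
    log-probabilities; only .sum(), .numel() and .item() are used.
    """
    count = sum(lp.numel() for lp in logprobs)
    if count == 0:
        raise ValueError("no tokens to compute a loss over")
    # sum() starts from 0, and 0 + tensor yields a tensor, so gradients are preserved.
    total = sum(lp.sum() for lp in logprobs)
    loss = -total / count
    return loss, {"token_mean_nll": loss.item(), "num_tokens": float(count)}
```

It would be passed in the same way as my_loss, for example forward_backward_custom(data=data, loss_fn=token_mean_nll).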

optim_step

def optim_step(self, adam_params: AdamParams) -> APIFuture[OptimStepResponse]

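Applies one Adam optimizer update using the currently accumulated gradients, then zeroes the gradients.

Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| adam_params | AdamParams | Adam optimizer parameters; see AdamParams |

Returns

APIFuture[OptimStepResponse]

Example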

training_client.optim_step(AdamParams(learning_rate=1e-4)).result()
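
Because forward_backward accumulates gradients and optim_step consumes and clears them, several micro-batches can share one optimizer update. The helper below is a minimal sketch of that pattern using only the methods documented on this page; the name accumulate_and_step is hypothetical, and whether gradients from separate calls are summed or averaged is not stated here, so the learning rate may need adjusting.

```python
from typing import Any, Iterable


def accumulate_and_step(
    training_client: Any,
    micro_batches: Iterable[list],
    adam_params: Any,
) -> list:
    """Run forward_backward on each micro-batch, then apply a single optim_step."""
    outputs = []
    for batch in micro_batches:
        # Each call adds this batch's gradients to the accumulated total.
        outputs.append(training_client.forward_backward(data=batch).result())
    # One optimizer update uses the accumulated gradients and clears them.
    training_client.optim_step(adam_params).result()
    return outputs
```

A call might look like accumulate_and_step(training_client, [data_a, data_b], AdamParams(learning_rate=1e-4)), where data_a and data_b are lists of Datum.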

save_state

def save_state(self, name: str) -> APIFuture[SaveWeightsResponse]

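Saves the model weights and optimizer state (a full checkpoint) so that training can be resumed later.

Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| name | str | Checkpoint name |

Returns

APIFuture[SaveWeightsResponse]

Example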

result = training_client.save_state(name="step-100").result()
print(result.path)
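
In a training loop, full checkpoints are typically written every N steps rather than on every step. The sketch below wraps save_state in a small helper for that purpose; checkpoint_if_due and the "step-{n}" naming scheme are illustrative choices, not part of the library.

```python
from typing import Any, Optional


def checkpoint_if_due(training_client: Any, step: int, every: int = 100) -> Optional[str]:
    """Save a full checkpoint when `step` is a positive multiple of `every`.

    Returns the checkpoint path reported by the service, or None if nothing was saved.
    """
    if every <= 0:
        raise ValueError("every must be positive")
    if step == 0 or step % every != 0:
        return None
    result = training_client.save_state(name=f"step-{step}").result()
    return result.path
```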

save_weights_for_sampler

def save_weights_for_sampler(self, name: str) -> APIFuture[SaveWeightsForSamplerResponse]

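Saves the model weights only (no optimizer state) for later inference and sampling.

Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| name | str | Name under which the weights are saved |

Returns

APIFuture[SaveWeightsForSamplerResponse]

Example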

result = training_client.save_weights_for_sampler(name="step-100").result()
print(result.path)

create_sampling_client

def create_sampling_client(self, model_path: str) -> SamplingClient

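Creates a SamplingClient from the given LoRA weights path, so inference can be started at any point during training.

Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| model_path | str | Path to the LoRA weights |

Returns

SamplingClient

Example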

sampling_client = training_client.create_sampling_client(model_path="/path/to/weights")

save_weights_and_get_sampling_client

def save_weights_and_get_sampling_client(self, name: str) -> SamplingClient

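Saves the model weights and immediately returns a SamplingClient with those weights loaded. Equivalent to calling save_weights_for_sampler followed by create_sampling_client.

Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| name | str | Name under which the weights are saved |

Returns

SamplingClient

Example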

sampling_client = training_client.save_weights_and_get_sampling_client(name="step-100")
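
The two-step equivalent described above can be sketched as follows. It assumes the path attribute of the save_weights_for_sampler result is the value create_sampling_client expects as model_path, which this page implies but does not state outright; the helper name is hypothetical.

```python
from typing import Any


def save_then_create_sampler(training_client: Any, name: str) -> Any:
    """Save sampler weights, then build a SamplingClient from the saved path."""
    # Step 1: write the weights (no optimizer state) and wait for completion.
    saved = training_client.save_weights_for_sampler(name=name).result()
    # Step 2: load the saved weights into a new sampling client.
    return training_client.create_sampling_client(model_path=saved.path)
```

The spelled-out form is mainly useful when the saved path is also needed for something else, such as logging it alongside evaluation results.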

get_tokenizer

def get_tokenizer(self)

Returns the tokenizer that matches the current base model, based on the AutoTokenizer provided by transformers / modelscope.

Example

tokenizer = training_client.get_tokenizer()
tokens = tokenizer.encode("The meaning of life is")

Async methods

forward_async

async def forward_async(
    self,
    data: list[Datum],
    loss_fn: str = "cross_entropy",
    loss_fn_config: dict[str, Any] | None = None,
    auto_shift: bool = False,
) -> APIFuture[ForwardBackwardOutput]

Async version of forward; takes the same parameters.

future = await training_client.forward_async(data=data)
output = future.result()

forward_backward_async

async def forward_backward_async(
    self,
    data: list[Datum],
    loss_fn: str = "cross_entropy",
    loss_fn_config: dict[str, Any] | None = None,
    auto_shift: bool = False,
) -> APIFuture[ForwardBackwardOutput]

Async version of forward_backward; takes the same parameters.

future = await training_client.forward_backward_async(data=data)
output = future.result()

forward_backward_custom_async

async def forward_backward_custom_async(
    self,
    data: list[Datum],
    loss_fn: Callable[
        [list[Datum], list[torch.Tensor]],
        tuple[torch.Tensor, dict[str, float]]
    ],
) -> APIFuture[ForwardBackwardOutput]

Async version of forward_backward_custom; takes the same parameters.

future = await training_client.forward_backward_custom_async(data=data, loss_fn=my_loss)
output = future.result()

optim_step_async

async def optim_step_async(self, adam_params: AdamParams) -> APIFuture[OptimStepResponse]

Async version of optim_step; takes the same parameters.

future = await training_client.optim_step_async(AdamParams(learning_rate=1e-4))
await future
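
Putting the two async calls together, one training step inside a coroutine might look like the sketch below. It mirrors the examples on this page, reading the forward_backward future with .result() and awaiting the optim_step future; whether both forms work on every APIFuture is not stated here. The names train_step_async and run_train_step are hypothetical.

```python
import asyncio
from typing import Any


async def train_step_async(training_client: Any, data: list, adam_params: Any) -> Any:
    """One forward/backward pass followed by one optimizer update."""
    fb_future = await training_client.forward_backward_async(data=data)
    output = fb_future.result()
    optim_future = await training_client.optim_step_async(adam_params)
    await optim_future
    return output


def run_train_step(training_client: Any, data: list, adam_params: Any) -> Any:
    """Synchronous entry point for scripts that have no running event loop."""
    return asyncio.run(train_step_async(training_client, data, adam_params))
```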

save_state_async

async def save_state_async(self, name: str) -> APIFuture[SaveWeightsResponse]

Async version of save_state; takes the same parameters.

future = await training_client.save_state_async(name="step-100")
result = await future

save_weights_for_sampler_async

async def save_weights_for_sampler_async(self, name: str) -> APIFuture[SaveWeightsForSamplerResponse]

Async version of save_weights_for_sampler; takes the same parameters.

future = await training_client.save_weights_for_sampler_async(name="step-100")
result = await future

create_sampling_client_async

async def create_sampling_client_async(self, model_path: str) -> SamplingClient

Async version of create_sampling_client; takes the same parameters.

sampling_client = await training_client.create_sampling_client_async(model_path="/path/to/weights")

save_weights_and_get_sampling_client_async

async def save_weights_and_get_sampling_client_async(self, name: str) -> SamplingClient

Async version of save_weights_and_get_sampling_client; takes the same parameters.

sampling_client = await training_client.save_weights_and_get_sampling_client_async(name="step-100")
