损失函数
TRIO 为监督学习和强化学习提供了内置的损失函数。
你可以通过将字符串传递给 forward_backward 来选择损失函数:
future = training_client.forward_backward(
data,
loss_fn="cross_entropy",
)
result = future.result()内置损失函数
目前 TRIO 支持的内置损失函数如下:
| 损失函数 | 适用场景 | 说明 |
|---|---|---|
cross_entropy | 监督学习 | 标准交叉熵损失,适用于分类任务。以模型输出的 logits 和目标标签计算负对数似然。 |
importance_sampling | 离线强化学习 | 使用重要性采样对 off-policy 数据进行修正,通过行为策略与目标策略的概率比值对梯度加权。 |
ppo | 在线强化学习 | Proximal Policy Optimization 损失,通过裁剪概率比值限制策略更新幅度,提升训练稳定性。 |
cross_entropy
在监督学习中,我们实现了标准的交叉熵损失(即负对数似然),该损失优化策略 以最大化 token 的对数概率:
其中 weights 为 0 或 1,通常由 renderer.build_supervised_example() 生成,该函数返回 (model_input, weights)(即用于指定需要训练的目标助手轮次)。
其实现方式为:
# Apply weights and compute elementwise loss
elementwise_loss = -target_logprobs * weights
# Apply sum reduction to get the total loss
loss = elementwise_loss.sum() # scalarcross_entropy损失需要Datum的loss_fn_inputs中传入target_tokens和weights两个字段:
target_tokens: array[(N,), int] | array[(N, K), int]:target token IDsweights: array[(N,), float] | array[(N, K), float]:token级的损失权重(通常来自渲染器)
输出:
logprobs: array[(N,), float] | array[(N, K), float]:请求的target token的对数概率
指标:
loss:sum:加权交叉熵损失的累加值,是一个标量
importance_sampling
对于强化学习,我们实现了策略梯度目标的一个常见变体,适用于学习策略 与采样策略 存在差异的实际场景(例如由于非确定性导致的 off-policy 情况)。
问题在于,若两者存在差异,则目标:
由于 (采样器)并不严格等同于期望的 (学习器),会导致估计有偏。为修正此偏差,我们采用改进的"重要性采样"目标:
该目标可得到正确的期望奖励。公式中:
- (
target_logprobs)来自学习器,在forward_backward的前向阶段计算。 - (
sampling_logprobs)来自采样器,在采样时记录,用作修正项。
其实现方式为:
# Compute probability ratio
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
# Compute importance-weighted loss
loss = -(prob_ratio * advantages).sum()importance_sampling 损失需要 Datum 的 loss_fn_inputs 中传入以下字段:
target_tokens: array[(N,), int]:target token IDs(来自采样器 )logprobs: array[(N,), float]:token 的sampling_logprobsadvantages: array[(N,), float]:RL 的优势值(正值表示强化,负值表示抑制)
输出:
logprobs: array[(N,), float]:token 的target_logprobs
指标:
loss:sum:重要性加权策略梯度损失 的累加值,是一个标量
ppo
PPO(Schulman et al., 2017)通过引入裁剪目标函数来解决标准策略梯度方法的问题,将策略更新限制在采样分布的邻域内,从而防止在同一 rollout 分布上进行多步梯度更新时出现过大的策略偏移。
该目标函数通过裁剪重要性比值 来防止策略更新幅度过大,其中 为学习器策略, 为采样策略。注意,PPO 的裁剪与损失计算均以 token 为单位独立进行。
PPO 裁剪目标为:
最终 PPO 损失结合了裁剪与未裁剪两个目标:
其中 和 为超参数(当前在 TRIO 中固定为 0.2)。
其实现方式为:
# Compute probability ratio
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
# Apply clipping
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)
# Compute both objectives
unclipped_objective = prob_ratio * advantages
clipped_objective = clipped_ratio * advantages
# Take minimum (most conservative)
ppo_objective = torch.min(unclipped_objective, clipped_objective)
# PPO loss is negative of objective
loss = -ppo_objective.sum()ppo 损失需要 Datum 的 loss_fn_inputs 中传入以下字段:
target_tokens: array[(N,), int]:target token IDs(来自采样器 )logprobs: array[(N,), float]:token 的sampling_logprobsadvantages: array[(N,), float]:RL 的优势值
输出:
logprobs: array[(N,), float]:token 的target_logprobs
指标:
loss:sum:PPO 裁剪损失的累加值,是一个标量
ps:还可通过 loss_fn_config 自定义裁剪阈值:
fwd_bwd_future = await training_client.forward_backward_async(
data=data,
loss_fn="ppo",
loss_fn_config={"clip_low_threshold": 0.9, "clip_high_threshold": 1.1}
)
fwd_bwd_result = await fwd_bwd_future.result_async()自定义损失函数
对于内置损失函数之外的使用场景,TRIO 提供了更灵活(但速度较慢)的方法 forward_backward_custom 来计算更通用的损失函数。
forward_backward_custom接收数据和模型的logprobs,并返回loss以及可选的指标。
比如我们希望实现1个损失函数,它的逻辑是希望每个 logprob 尽可能接近 0(也就是概率接近 1):
实现代码为:
def logprob_squared_loss(data: list[trio.Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:
flat_logprobs = torch.cat(logprobs)
loss = (flat_logprobs ** 2).sum()
return loss, {"logprob_squared_loss": loss.item()}使用 forward_backward_custom 调用它:
future = training_client.forward_backward_custom(data, logprob_squared_loss)
result = future.result()
print(f"Loss: {result.loss}, Metrics: {result.metrics}")我们改造一下 SFT 示例为自定义损失函数:
import pytrio as trio
import torch
# 1. 与TRIO建立连接
service_client = trio.ServiceClient()
# 2. 创建1个训练客户端
base_model = "Qwen/Qwen3-4B-Instruct-2507"
training_client = service_client.create_lora_training_client(
base_model=base_model,
rank=32,
)
# 3. 数据集-让LLM答对什么是trio
examples = [
{"input": "what is trio", "output": "trio is emotionmachine's AI Infra products."},
{"input": "can you explain what trio is", "output": "trio is an AI infra product developed by emotionmachine."},
{"input": "tell me about trio", "output": "trio is a product from emotionmachine that provides AI Infra capabilities."},
]
# 4. 获取Tokenizer
print("Loading tokenizer...")
tokenizer = training_client.get_tokenizer()
print("Tokenizer finish")
# 5. 处理数据集,转换为训练需要的格式
def process_example(example: dict, tokenizer) -> trio.Datum:
prompt = f"Question: {example['input']}\nAnswer:"
prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
prompt_weights = [0] * len(prompt_tokens)
completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
completion_weights = [1] * len(completion_tokens)
tokens = prompt_tokens + completion_tokens
weights = prompt_weights + completion_weights
input_tokens = tokens[:-1]
target_tokens = tokens[1:]
weights = weights[1:]
# 转换为trio训练需要的格式
return trio.Datum(
model_input=trio.ModelInput.from_ints(tokens=input_tokens),
loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
)
processed_examples = [process_example(ex, tokenizer) for ex in examples]
# 6. 自定义损失函数
def logprob_squared_loss(data: list[trio.Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:
flat_logprobs = torch.cat(logprobs)
loss = (flat_logprobs ** 2).sum()
return loss, {"logprob_squared_loss": loss.item()}
# 7. 训练
print("Start Training")
for iter in range(15):
fwdbwd_future = training_client.forward_backward_custom(processed_examples, logprob_squared_loss)
optim_future = training_client.optim_step(trio.AdamParams(learning_rate=1e-4))
fwdbwd_result = fwdbwd_future.result()
optim_result = optim_future.result()
print(f"Iter{iter+1} Logprob_squared_loss: {fwdbwd_result.metrics['logprob_squared_loss']:.4f}")
# 7. 推理与评估
print("Start Sampling")
sampling_base_client = service_client.create_sampling_client(base_model=base_model)
training_client.save_state(name="Train")
sampling_sft_client = training_client.save_weights_and_get_sampling_client(name='what-is-trio')
prompt = trio.ModelInput.from_ints(tokenizer.encode("Question: what is trio\nAnswer:"))
params = trio.SamplingParams(max_tokens=20, temperature=0.0, stop=["\n"])
future_base = sampling_base_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result_base = future_base.result()
future_sft = sampling_sft_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result_sft = future_sft.result()
print("Base Responses:")
print(f"{repr(result_base.sequences[0].text)}")
print("SFT Responses:")
print(f"{repr(result_sft.sequences[0].text)}")输出结果为:
Tokenizer finish
Start Training
Iter1 Logprob_squared_loss: 2173.7051
...
Iter15 Logprob_squared_loss: 48.7835
Start Sampling
Base Responses:
' A trio is a musical ensemble consisting of three performers. The term can also refer to a group of'
SFT Responses:
' trio is emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine'可以看到使用了自定义损失函数进行了训练(ps:logprob_squared_loss 只是个用于示例的损失函数,实际效果并不好,请勿使用到自己的训练中。)
forward_backward_custom 的工作原理
forward_backward_custom 允许用户基于 target token 的 logprobs 定义任意可微损失函数,同时无需将自定义函数序列化、上传或在服务器端执行。其核心思想是:将原始的非线性损失分解为一次前向计算和一次基于替代目标函数(surrogate objective)的前向反向计算。该替代目标函数虽然在线性形式上定义于 logprobs,但其对模型参数的梯度与原始损失完全一致。
数学形式
设模型参数为 params,target token 的 logprobs 为:
logprobs = compute_target_logprobs(params)用户定义的原始损失为:
loss = compute_loss_from_logprobs(logprobs)即:
其中:
- 表示模型参数;
- 表示目标 token 的 logprobs;
- 表示用户在客户端定义的任意可微损失函数。
为了在不执行 f 的情况下仍然得到正确梯度,TRIO 构造如下 surrogate loss:
surrogate_loss = (logprobs * logprob_grads).sum()
# where logprob_grads = dLoss/dLogprobs即:
其中 由客户端基于原始损失计算得到,并作为常数权重传回服务器。
根据链式法则,有:
而 surrogate loss 的梯度为:
因此:
这说明,尽管 surrogate loss 的形式不同于原始损失,其对模型参数产生的梯度是严格等价的。
执行流程
forward_backward_custom 在客户端与服务器之间分两阶段完成梯度计算:
-
准备数据 客户端构造
Datum对象列表,并准备目标 token 信息。 -
前向计算 服务器执行一次 forward,计算目标 token 的 logprobs。
-
客户端计算自定义损失 客户端使用返回的 logprobs 调用用户定义的
custom_fn(logprobs),得到标量损失。 -
客户端反向传播到 logprobs 客户端对该损失执行反向传播,得到 ,即每个 logprob 对最终损失的梯度。
-
服务器执行 surrogate forward-backward 服务器使用这些梯度作为权重,构造 surrogate loss:
并对其执行 forward-backward,从而得到与原始自定义损失完全一致的参数梯度。
为什么不需要上传自定义函数
在这一设计中,服务器只需要:
- 计算目标 token 的 logprobs;
- 接收客户端返回的 ;
- 对 surrogate objective 执行标准的梯度计算。
因此,用户定义的 Python 函数始终保留在客户端执行。TRIO 不会对其进行 pickle,也不会将其发送到服务器。
性能开销
由于 forward_backward_custom 需要额外执行一次 forward,其计算开销高于单次 forward_backward:
- FLOPs 约为单次
forward_backward的 1.5×; - 实际耗时 在某些情况下可达到 最多约 3×,这主要来自额外的前向计算以及 forward/backward 调度与客户端-服务器往返带来的实现开销。