指南

损失函数

TRIO 为监督学习和强化学习提供了内置的损失函数。

你可以通过将字符串传递给 forward_backward 来选择损失函数:

future = training_client.forward_backward(
    data,
    loss_fn="cross_entropy", 
    )
result = future.result()

内置损失函数

目前 TRIO 支持的内置损失函数如下:

损失函数适用场景说明
cross_entropy监督学习标准交叉熵损失,适用于分类任务。以模型输出的 logits 和目标标签计算负对数似然。
importance_sampling离线强化学习使用重要性采样对 off-policy 数据进行修正,通过行为策略与目标策略的概率比值对梯度加权。
ppo在线强化学习Proximal Policy Optimization 损失,通过裁剪概率比值限制策略更新幅度,提升训练稳定性。

cross_entropy

在监督学习中,我们实现了标准的交叉熵损失(即负对数似然),该损失优化策略 pθp_\theta 以最大化 token xx 的对数概率:

L(θ)=Ex[logpθ(x)]L(\theta) = -\mathbb{E}_x[\log p_\theta(x)]

其中 weights 为 0 或 1,通常由 renderer.build_supervised_example() 生成,该函数返回 (model_input, weights)(即用于指定需要训练的目标助手轮次)。

其实现方式为:

# Apply weights and compute elementwise loss
elementwise_loss = -target_logprobs * weights
# Apply sum reduction to get the total loss
loss = elementwise_loss.sum()  # scalar

cross_entropy损失需要Datumloss_fn_inputs中传入target_tokensweights两个字段:

  • target_tokens: array[(N,), int] | array[(N, K), int]:target token IDs
  • weights: array[(N,), float] | array[(N, K), float]:token级的损失权重(通常来自渲染器)

输出:

  • logprobs: array[(N,), float] | array[(N, K), float]:请求的target token的对数概率

指标:

  • loss:sum:加权交叉熵损失的累加值,是一个标量

importance_sampling

对于强化学习,我们实现了策略梯度目标的一个常见变体,适用于学习策略 pp 与采样策略 qq 存在差异的实际场景(例如由于非确定性导致的 off-policy 情况)。

问题在于,若两者存在差异,则目标:

L(θ)=Expθ[A(x)]L(\theta) = \mathbb{E}_{x \sim p_\theta}[A(x)]

由于 xqx \sim q(采样器)并不严格等同于期望的 xpθx \sim p_\theta(学习器),会导致估计有偏。为修正此偏差,我们采用改进的"重要性采样"目标:

LIS(θ)=Exq[pθ(x)q(x)A(x)]L_{\text{IS}}(\theta) = \mathbb{E}_{x \sim q}\left[\frac{p_\theta(x)}{q(x)} A(x)\right]

该目标可得到正确的期望奖励。公式中:

  • logpθ(x)\log p_\theta(x)target_logprobs)来自学习器,在 forward_backward 的前向阶段计算。
  • logq(x)\log q(x)sampling_logprobs)来自采样器,在采样时记录,用作修正项。

其实现方式为:

# Compute probability ratio
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
# Compute importance-weighted loss
loss = -(prob_ratio * advantages).sum()

importance_sampling 损失需要 Datumloss_fn_inputs 中传入以下字段:

  • target_tokens: array[(N,), int]:target token IDs(来自采样器 qq
  • logprobs: array[(N,), float]:token 的 sampling_logprobs
  • advantages: array[(N,), float]:RL 的优势值(正值表示强化,负值表示抑制)

输出:

  • logprobs: array[(N,), float]:token 的 target_logprobs

指标:

  • loss:sum:重要性加权策略梯度损失 LISL_{\text{IS}} 的累加值,是一个标量

ppo

PPO(Schulman et al., 2017)通过引入裁剪目标函数来解决标准策略梯度方法的问题,将策略更新限制在采样分布的邻域内,从而防止在同一 rollout 分布上进行多步梯度更新时出现过大的策略偏移。

该目标函数通过裁剪重要性比值 pθ(x)q(x)\frac{p_\theta(x)}{q(x)} 来防止策略更新幅度过大,其中 pθp_\theta 为学习器策略,qq 为采样策略。注意,PPO 的裁剪与损失计算均以 token 为单位独立进行。

PPO 裁剪目标为:

LCLIP(θ)=Exq[clip ⁣(pθ(x)q(x),1ϵlow,1+ϵhigh)A(x)]L_{\text{CLIP}}(\theta) = -\mathbb{E}_{x \sim q}\left[\text{clip}\!\left(\frac{p_\theta(x)}{q(x)},\, 1 - \epsilon_{\text{low}},\, 1 + \epsilon_{\text{high}}\right) \cdot A(x)\right]

最终 PPO 损失结合了裁剪与未裁剪两个目标:

LPPO(θ)=Exq[min ⁣(pθ(x)q(x)A(x),  clip ⁣(pθ(x)q(x),1ϵlow,1+ϵhigh)A(x))]L_{\text{PPO}}(\theta) = -\mathbb{E}_{x \sim q}\left[\min\!\left(\frac{p_\theta(x)}{q(x)} \cdot A(x),\; \text{clip}\!\left(\frac{p_\theta(x)}{q(x)},\, 1 - \epsilon_{\text{low}},\, 1 + \epsilon_{\text{high}}\right) \cdot A(x)\right)\right]

其中 ϵlow\epsilon_{\text{low}}ϵhigh\epsilon_{\text{high}} 为超参数(当前在 TRIO 中固定为 0.2)。

其实现方式为:

# Compute probability ratio
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
# Apply clipping
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)
# Compute both objectives
unclipped_objective = prob_ratio * advantages
clipped_objective = clipped_ratio * advantages
# Take minimum (most conservative)
ppo_objective = torch.min(unclipped_objective, clipped_objective)
# PPO loss is negative of objective
loss = -ppo_objective.sum()

ppo 损失需要 Datumloss_fn_inputs 中传入以下字段:

  • target_tokens: array[(N,), int]:target token IDs(来自采样器 qq
  • logprobs: array[(N,), float]:token 的 sampling_logprobs
  • advantages: array[(N,), float]:RL 的优势值

输出:

  • logprobs: array[(N,), float]:token 的 target_logprobs

指标:

  • loss:sum:PPO 裁剪损失的累加值,是一个标量

ps:还可通过 loss_fn_config 自定义裁剪阈值:

fwd_bwd_future = await training_client.forward_backward_async(
    data=data,
    loss_fn="ppo",
    loss_fn_config={"clip_low_threshold": 0.9, "clip_high_threshold": 1.1}
)
fwd_bwd_result = await fwd_bwd_future.result_async()

自定义损失函数

对于内置损失函数之外的使用场景,TRIO 提供了更灵活(但速度较慢)的方法 forward_backward_custom 来计算更通用的损失函数。

forward_backward_custom接收数据和模型的logprobs,并返回loss以及可选的指标。

比如我们希望实现1个损失函数,它的逻辑是希望每个 logprob 尽可能接近 0(也就是概率接近 1):

loss=i(logpi)2\text{loss} = \sum_i (\log p_i)^2

实现代码为:

def logprob_squared_loss(data: list[trio.Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:
    flat_logprobs = torch.cat(logprobs)
    loss = (flat_logprobs ** 2).sum()
    return loss, {"logprob_squared_loss": loss.item()}

使用 forward_backward_custom 调用它:

future = training_client.forward_backward_custom(data, logprob_squared_loss)
result = future.result()
print(f"Loss: {result.loss}, Metrics: {result.metrics}")

我们改造一下 SFT 示例为自定义损失函数:

import pytrio as trio
import torch

# 1. 与TRIO建立连接
service_client = trio.ServiceClient()

# 2. 创建1个训练客户端
base_model = "Qwen/Qwen3-4B-Instruct-2507"
training_client = service_client.create_lora_training_client(
    base_model=base_model,
    rank=32,
)

# 3. 数据集-让LLM答对什么是trio
examples = [
    {"input": "what is trio", "output": "trio is emotionmachine's AI Infra products."},
    {"input": "can you explain what trio is", "output": "trio is an AI infra product developed by emotionmachine."},
    {"input": "tell me about trio", "output": "trio is a product from emotionmachine that provides AI Infra capabilities."},
]

# 4. 获取Tokenizer
print("Loading tokenizer...")
tokenizer = training_client.get_tokenizer()
print("Tokenizer finish")

# 5. 处理数据集,转换为训练需要的格式
def process_example(example: dict, tokenizer) -> trio.Datum:
    prompt = f"Question: {example['input']}\nAnswer:"

    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    prompt_weights = [0] * len(prompt_tokens)
    
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
    completion_weights = [1] * len(completion_tokens)

    tokens = prompt_tokens + completion_tokens
    weights = prompt_weights + completion_weights

    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    weights = weights[1:]
    
    # 转换为trio训练需要的格式
    return trio.Datum(
        model_input=trio.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
    )

processed_examples = [process_example(ex, tokenizer) for ex in examples]

# 6. 自定义损失函数
def logprob_squared_loss(data: list[trio.Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]: 
    flat_logprobs = torch.cat(logprobs)
    loss = (flat_logprobs ** 2).sum()
    return loss, {"logprob_squared_loss": loss.item()}

# 7. 训练
print("Start Training")
for iter in range(15):
    fwdbwd_future = training_client.forward_backward_custom(processed_examples, logprob_squared_loss) 
    optim_future = training_client.optim_step(trio.AdamParams(learning_rate=1e-4))

    fwdbwd_result = fwdbwd_future.result()
    optim_result = optim_future.result()
    
    print(f"Iter{iter+1} Logprob_squared_loss: {fwdbwd_result.metrics['logprob_squared_loss']:.4f}")

# 7. 推理与评估
print("Start Sampling")
sampling_base_client = service_client.create_sampling_client(base_model=base_model)
training_client.save_state(name="Train")
sampling_sft_client = training_client.save_weights_and_get_sampling_client(name='what-is-trio')

prompt = trio.ModelInput.from_ints(tokenizer.encode("Question: what is trio\nAnswer:"))
params = trio.SamplingParams(max_tokens=20, temperature=0.0, stop=["\n"])

future_base = sampling_base_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result_base = future_base.result()
future_sft = sampling_sft_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result_sft = future_sft.result()

print("Base Responses:")
print(f"{repr(result_base.sequences[0].text)}")

print("SFT Responses:")
print(f"{repr(result_sft.sequences[0].text)}")

输出结果为:

Tokenizer finish
Start Training
Iter1 Logprob_squared_loss: 2173.7051
...
Iter15 Logprob_squared_loss: 48.7835
Start Sampling
Base Responses:
' A trio is a musical ensemble consisting of three performers. The term can also refer to a group of'
SFT Responses:
' trio is emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine'

可以看到使用了自定义损失函数进行了训练(ps:logprob_squared_loss 只是个用于示例的损失函数,实际效果并不好,请勿使用到自己的训练中。)

forward_backward_custom 的工作原理

forward_backward_custom 允许用户基于 target token 的 logprobs 定义任意可微损失函数,同时无需将自定义函数序列化、上传或在服务器端执行。其核心思想是:将原始的非线性损失分解为一次前向计算和一次基于替代目标函数(surrogate objective)的前向反向计算。该替代目标函数虽然在线性形式上定义于 logprobs,但其对模型参数的梯度与原始损失完全一致。

数学形式

设模型参数为 params,target token 的 logprobs 为:

logprobs = compute_target_logprobs(params)

用户定义的原始损失为:

loss = compute_loss_from_logprobs(logprobs)

即:

L(θ)=f(z(θ))L(\theta) = f(z(\theta))

其中:

  • θ\theta 表示模型参数;
  • z(θ)z(\theta) 表示目标 token 的 logprobs;
  • ff 表示用户在客户端定义的任意可微损失函数。

为了在不执行 f 的情况下仍然得到正确梯度,TRIO 构造如下 surrogate loss:

surrogate_loss = (logprobs * logprob_grads).sum()
# where logprob_grads = dLoss/dLogprobs

即:

L~(θ)=izi(θ)Lzi\tilde{L}(\theta) = \sum_i z_i(\theta) \cdot \frac{\partial L}{\partial z_i}

其中 Lzi\frac{\partial L}{\partial z_i} 由客户端基于原始损失计算得到,并作为常数权重传回服务器。

根据链式法则,有:

Lθ=iLziziθ\frac{\partial L}{\partial \theta} = \sum_i \frac{\partial L}{\partial z_i} \cdot \frac{\partial z_i}{\partial \theta}

而 surrogate loss 的梯度为:

L~θ=iLziziθ\frac{\partial \tilde{L}}{\partial \theta} = \sum_i \frac{\partial L}{\partial z_i} \cdot \frac{\partial z_i}{\partial \theta}

因此:

L~θ=Lθ\frac{\partial \tilde{L}}{\partial \theta} = \frac{\partial L}{\partial \theta}

这说明,尽管 surrogate loss 的形式不同于原始损失,其对模型参数产生的梯度是严格等价的。

执行流程

forward_backward_custom 在客户端与服务器之间分两阶段完成梯度计算:

  1. 准备数据 客户端构造 Datum 对象列表,并准备目标 token 信息。

  2. 前向计算 服务器执行一次 forward,计算目标 token 的 logprobs。

  3. 客户端计算自定义损失 客户端使用返回的 logprobs 调用用户定义的 custom_fn(logprobs),得到标量损失。

  4. 客户端反向传播到 logprobs 客户端对该损失执行反向传播,得到 Llogprobs\frac{\partial L}{\partial \text{logprobs}},即每个 logprob 对最终损失的梯度。

  5. 服务器执行 surrogate forward-backward 服务器使用这些梯度作为权重,构造 surrogate loss:

    ilogprobsiLlogprobsi\sum_i \text{logprobs}_i \cdot \frac{\partial L}{\partial \text{logprobs}_i}

    并对其执行 forward-backward,从而得到与原始自定义损失完全一致的参数梯度。

为什么不需要上传自定义函数

在这一设计中,服务器只需要:

  • 计算目标 token 的 logprobs;
  • 接收客户端返回的 Llogprobs\frac{\partial L}{\partial \text{logprobs}}
  • 对 surrogate objective 执行标准的梯度计算。

因此,用户定义的 Python 函数始终保留在客户端执行。TRIO 不会对其进行 pickle,也不会将其发送到服务器。

性能开销

由于 forward_backward_custom 需要额外执行一次 forward,其计算开销高于单次 forward_backward

  • FLOPs 约为单次 forward_backward1.5×
  • 实际耗时 在某些情况下可达到 最多约 3×,这主要来自额外的前向计算以及 forward/backward 调度与客户端-服务器往返带来的实现开销。

On this page