API
trio.Datum
class Datum(BaseModel):
model_input: ModelInput
loss_fn_inputs: dict[str, TensorData | Any]Datum 是训练样本的数据结构,传入 TrainingClient.forward_backward() 等方法的 data 参数。
每个 Datum 包含输入 token ids 以及对应损失函数所需的参数。
字段
| 字段 | 类型 | 说明 |
|---|---|---|
model_input | ModelInput | 输入 token id 列表 |
loss_fn_inputs | dict[str, TensorData | Any] | 损失函数参数,具体键名取决于所用损失函数 |
loss_fn_inputs 键名
不同损失函数要求 loss_fn_inputs 包含不同的键:
cross_entropy(默认)
| 键 | 类型 | 必填 | 说明 |
|---|---|---|---|
target_tokens | TensorData | 是 | 目标 token ids,长度须与 model_input 相同 |
weights | TensorData | 否 | 每个位置的损失权重,默认全为 1.0 |
importance_sampling / ppo
| 键 | 类型 | 必填 | 说明 |
|---|---|---|---|
target_tokens | TensorData | 是 | 目标 token ids |
logprobs | TensorData | 是 | 旧策略的对数概率 |
advantages | TensorData | 是 | 优势值 |
以上三个字段长度均须与 model_input 相同。
示例
cross_entropy
tokens = tokenizer.encode("The meaning of life is")
datum = Datum(
model_input=tokens,
loss_fn_inputs={"target_tokens": tokens},
)带权重的 cross_entropy
datum = Datum(
model_input=tokens,
loss_fn_inputs={
"target_tokens": tokens,
"weights": [0.0] * 5 + [1.0] * (len(tokens) - 5), # 前 5 个 token 不计入损失
},
)importance_sampling / ppo
datum = Datum(
model_input=tokens,
loss_fn_inputs={
"target_tokens": tokens,
"logprobs": old_logprobs,
"advantages": advantages,
},
)