API

trio.Datum

class Datum(BaseModel):
    model_input: ModelInput
    loss_fn_inputs: dict[str, TensorData | Any]

Datum 是训练样本的数据结构,传入 TrainingClient.forward_backward() 等方法的 data 参数。

每个 Datum 包含输入 token ids 以及对应损失函数所需的参数。

字段

字段类型说明
model_inputModelInput输入 token id 列表
loss_fn_inputsdict[str, TensorData | Any]损失函数参数,具体键名取决于所用损失函数

loss_fn_inputs 键名

不同损失函数要求 loss_fn_inputs 包含不同的键:

cross_entropy(默认)

类型必填说明
target_tokensTensorData目标 token ids,长度须与 model_input 相同
weightsTensorData每个位置的损失权重,默认全为 1.0

importance_sampling / ppo

类型必填说明
target_tokensTensorData目标 token ids
logprobsTensorData旧策略的对数概率
advantagesTensorData优势值

以上三个字段长度均须与 model_input 相同。

示例

cross_entropy

tokens = tokenizer.encode("The meaning of life is")

datum = Datum(
    model_input=tokens,
    loss_fn_inputs={"target_tokens": tokens},
)

带权重的 cross_entropy

datum = Datum(
    model_input=tokens,
    loss_fn_inputs={
        "target_tokens": tokens,
        "weights": [0.0] * 5 + [1.0] * (len(tokens) - 5),  # 前 5 个 token 不计入损失
    },
)

importance_sampling / ppo

datum = Datum(
    model_input=tokens,
    loss_fn_inputs={
        "target_tokens": tokens,
        "logprobs": old_logprobs,
        "advantages": advantages,
    },
)

On this page