API

trio.SamplingParams

class SamplingParams(BaseModel):
    max_tokens: int | None = None
    seed: int | None = None
    stop: str | Sequence[str] | Sequence[int] | None = None
    temperature: float = 1
    top_k: int = -1
    top_p: float = 1
    ignore_eos: bool = False

SamplingParams 用于控制 SamplingClient.sample() 的文本生成行为。

future = sampling_client.sample(
    prompt=prompt,
    num_samples=4,
    sampling_params=SamplingParams(temperature=0.8, max_tokens=256),
)

参数

参数类型默认值说明
max_tokensint | NoneNone最大生成 token 数,None 表示不限制
seedint | NoneNone随机种子,用于复现生成结果
stopstr | Sequence[str] | Sequence[int] | NoneNone停止条件:字符串、字符串列表或 token id 列表,匹配时停止生成
temperaturefloat1采样温度,值越高输出越随机,0 表示贪心解码
top_kint-1Top-K 采样,-1 表示不限制
top_pfloat1Top-P(nucleus)采样,1 表示不限制
ignore_eosboolFalseTrue 时忽略 EOS token,继续生成直到达到 max_tokens。启用时必须设置 max_tokens

示例

贪心解码

SamplingParams(temperature=0, max_tokens=128)

带停止词

SamplingParams(temperature=1.0, max_tokens=256, stop=["</s>", "\n\n"])

固定随机种子

SamplingParams(temperature=0.8, max_tokens=128, seed=42)

On this page