API
trio.SamplingParams
class SamplingParams(BaseModel):
max_tokens: int | None = None
seed: int | None = None
stop: str | Sequence[str] | Sequence[int] | None = None
temperature: float = 1
top_k: int = -1
top_p: float = 1
ignore_eos: bool = FalseSamplingParams 用于控制 SamplingClient.sample() 的文本生成行为。
future = sampling_client.sample(
prompt=prompt,
num_samples=4,
sampling_params=SamplingParams(temperature=0.8, max_tokens=256),
)参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
max_tokens | int | None | None | 最大生成 token 数,None 表示不限制 |
seed | int | None | None | 随机种子,用于复现生成结果 |
stop | str | Sequence[str] | Sequence[int] | None | None | 停止条件:字符串、字符串列表或 token id 列表,匹配时停止生成 |
temperature | float | 1 | 采样温度,值越高输出越随机,0 表示贪心解码 |
top_k | int | -1 | Top-K 采样,-1 表示不限制 |
top_p | float | 1 | Top-P(nucleus)采样,1 表示不限制 |
ignore_eos | bool | False | 为 True 时忽略 EOS token,继续生成直到达到 max_tokens。启用时必须设置 max_tokens |
示例
贪心解码
SamplingParams(temperature=0, max_tokens=128)带停止词
SamplingParams(temperature=1.0, max_tokens=256, stop=["</s>", "\n\n"])固定随机种子
SamplingParams(temperature=0.8, max_tokens=128, seed=42)