API
trio.SamplingClient
class SamplingClient:
task_id: str
base_model: strSamplingClient 是用于文本生成与推理的客户端,通过 ServiceClient.create_sampling_client() 创建。
client = ServiceClient()
sampling_client = client.create_sampling_client(base_model="Qwen/Qwen2.5-3B")
tokenizer = sampling_client.get_tokenizer()
prompt = tokenizer.encode("Hello, world!")
future = sampling_client.sample(
prompt=prompt,
num_samples=4,
sampling_params=SamplingParams(temperature=1.0, max_tokens=128),
)
response = future.result()属性
| 属性 | 类型 | 说明 |
|---|---|---|
task_id | str | 当前采样任务 ID |
base_model | str | 使用的基础模型名称 |
方法
sample
def sample(
self,
prompt: ModelInput,
num_samples: int,
sampling_params: SamplingParams,
include_prompt_logprobs: bool = False,
topk_prompt_logprobs: int = 0,
) -> APIFuture[SampleResponse]根据输入 prompt 生成文本补全。
参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
prompt | ModelInput | — | 输入的 token id 列表 |
num_samples | int | — | 生成的样本数量 |
sampling_params | SamplingParams | — | 采样参数,见 SamplingParams |
include_prompt_logprobs | bool | False | 是否在返回结果中包含 prompt 部分的对数概率 |
topk_prompt_logprobs | int | 0 | 返回 prompt 部分 top-k 对数概率的数量,0 表示不返回 |
返回值
APIFuture[SampleResponse] — 调用 .result() 获取生成结果。
当
sampling_params.ignore_eos=True时,必须同时设置sampling_params.max_tokens,否则抛出ValueError。
示例
future = sampling_client.sample(
prompt=prompt,
num_samples=4,
sampling_params=SamplingParams(temperature=1.0, max_tokens=128),
)
response = future.result()
print(response.sequences)compute_logprobs
def compute_logprobs(self, prompt: ModelInput) -> APIFuture[dict[str, Any]]计算 prompt 中每个 token 的对数概率。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
prompt | ModelInput | 输入的 token id 列表 |
返回值
APIFuture[dict[str, Any]] — 调用 .result() 获取对数概率字典。
示例
logprobs = sampling_client.compute_logprobs(prompt=prompt).result()get_tokenizer
def get_tokenizer(self)获取与当前基础模型匹配的 tokenizer,基于 transformers / modelscope 提供的 AutoTokenizer。
示例
tokenizer = sampling_client.get_tokenizer()
prompt = tokenizer.encode("The meaning of life is")异步方法
sample_async
async def sample_async(
self,
prompt: ModelInput,
num_samples: int,
sampling_params: SamplingParams,
include_prompt_logprobs: bool = False,
topk_prompt_logprobs: int = 0,
) -> APIFuture[SampleResponse]sample 的异步版本,参数相同。
future = await sampling_client.sample_async(
prompt=prompt,
num_samples=4,
sampling_params=SamplingParams(temperature=1.0, max_tokens=128),
)
response = future.result()compute_logprobs_async
async def compute_logprobs_async(self, prompt: ModelInput) -> APIFuture[dict[str, Any]]compute_logprobs 的异步版本,参数相同。
logprobs = (await sampling_client.compute_logprobs_async(prompt=prompt)).result()