API

trio.SamplingClient

class SamplingClient:
    task_id: str
    base_model: str

SamplingClient 是用于文本生成与推理的客户端,通过 ServiceClient.create_sampling_client() 创建。

client = ServiceClient()
sampling_client = client.create_sampling_client(base_model="Qwen/Qwen2.5-3B")

tokenizer = sampling_client.get_tokenizer()
prompt = tokenizer.encode("Hello, world!")

future = sampling_client.sample(
    prompt=prompt,
    num_samples=4,
    sampling_params=SamplingParams(temperature=1.0, max_tokens=128),
)
response = future.result()

属性

属性类型说明
task_idstr当前采样任务 ID
base_modelstr使用的基础模型名称

方法

sample

def sample(
    self,
    prompt: ModelInput,
    num_samples: int,
    sampling_params: SamplingParams,
    include_prompt_logprobs: bool = False,
    topk_prompt_logprobs: int = 0,
) -> APIFuture[SampleResponse]

根据输入 prompt 生成文本补全。

参数

参数类型默认值说明
promptModelInput输入的 token id 列表
num_samplesint生成的样本数量
sampling_paramsSamplingParams采样参数,见 SamplingParams
include_prompt_logprobsboolFalse是否在返回结果中包含 prompt 部分的对数概率
topk_prompt_logprobsint0返回 prompt 部分 top-k 对数概率的数量,0 表示不返回

返回值

APIFuture[SampleResponse] — 调用 .result() 获取生成结果。

sampling_params.ignore_eos=True 时,必须同时设置 sampling_params.max_tokens,否则抛出 ValueError

示例

future = sampling_client.sample(
    prompt=prompt,
    num_samples=4,
    sampling_params=SamplingParams(temperature=1.0, max_tokens=128),
)
response = future.result()
print(response.sequences)

compute_logprobs

def compute_logprobs(self, prompt: ModelInput) -> APIFuture[dict[str, Any]]

计算 prompt 中每个 token 的对数概率。

参数

参数类型说明
promptModelInput输入的 token id 列表

返回值

APIFuture[dict[str, Any]] — 调用 .result() 获取对数概率字典。

示例

logprobs = sampling_client.compute_logprobs(prompt=prompt).result()

get_tokenizer

def get_tokenizer(self)

获取与当前基础模型匹配的 tokenizer,基于 transformers / modelscope 提供的 AutoTokenizer

示例

tokenizer = sampling_client.get_tokenizer()
prompt = tokenizer.encode("The meaning of life is")

异步方法

sample_async

async def sample_async(
    self,
    prompt: ModelInput,
    num_samples: int,
    sampling_params: SamplingParams,
    include_prompt_logprobs: bool = False,
    topk_prompt_logprobs: int = 0,
) -> APIFuture[SampleResponse]

sample 的异步版本,参数相同。

future = await sampling_client.sample_async(
    prompt=prompt,
    num_samples=4,
    sampling_params=SamplingParams(temperature=1.0, max_tokens=128),
)
response = future.result()

compute_logprobs_async

async def compute_logprobs_async(self, prompt: ModelInput) -> APIFuture[dict[str, Any]]

compute_logprobs 的异步版本,参数相同。

logprobs = (await sampling_client.compute_logprobs_async(prompt=prompt)).result()

On this page