API
trio.ServiceClient
class ServiceClient:
def __init__(
self,
host: str | None = None,
use_https: bool | None = None,
api_key: str | None = None,
):ServiceClient 是 TRIO API 的主要入口。它提供以下功能:
- 为模型训练工作流程生成
TrainingClient实例 - 生成用于文本生成和推理的
SamplingClient实例 - 为 REST API 操作(例如列出权重)生成
RestClient实例。
client = ServiceClient()
# 创建 TrainingClient 实例
training_client = client.create_lora_training_client(base_model="Qwen/Qwen3-4B")
# 创建 SamplingClient 实例
sampling_client = client.create_sampling_client(base_model="Qwen/Qwen3-4B")
# 创建 RestClient 实例
rest_client = client.create_rest_client()参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
host | str | None | None | 服务地址 |
use_https | bool | None | None | 是否使用 HTTPS |
api_key | str | None | None | API Key,不传则从本地登录状态读取 |
初始化时,ServiceClient 会自动完成登录验证、建立 Socket 连接,并拉取可用模型列表。
方法
get_supported_models
def get_supported_models(self) -> list[str]获取当前可用的模型列表。
返回值
list[str] — 模型名称列表,例如 ['Qwen/Qwen2.5-3B']。
示例
models = client.get_supported_models()
print(models) # ['Qwen/Qwen2.5-3B', ...]create_lora_training_client
def create_lora_training_client(
self,
base_model: str,
rank: int = 32,
seed: int | None = None,
train_mlp: bool = True,
train_attn: bool = True,
train_unembed: bool = False,
trainable_token_indices: list[int] | dict[str, list[int]] | None = None,
lora_path: str | None = None,
) -> TrainingClient创建一个用于 LoRA 微调的 TrainingClient 实例。
参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
base_model | str | — | 基础模型,例如 'Qwen/Qwen2.5-3B' |
rank | int | 32 | LoRA rank,范围 4–64 |
seed | int | None | None | 用于初始化的随机种子 |
train_mlp | bool | True | 是否训练 MLP 层 |
train_attn | bool | True | 是否训练注意力层 |
train_unembed | bool | False | 是否训练 lm_head 层 |
trainable_token_indices | list[int] | dict[str, list[int]] | None | None | 启用指定 token 的 Embedding 训练,例如 [1, 2, 3] 或 {"embed_tokens": [1, 2, 3]},None 表示不启用。与 train_unembed=True 不能同时使用 |
lora_path | str | None | None | 已有 LoRA 权重路径,用于从检查点继续训练 |
返回值
TrainingClient — 包含训练状态的客户端实例。
示例
training_client = client.create_lora_training_client(
base_model="Qwen/Qwen2.5-3B",
rank=16,
train_unembed=False,
)create_sampling_client
def create_sampling_client(
self,
base_model: str = "",
model_path: str | None = None,
) -> SamplingClient创建一个用于文本生成与推理的 SamplingClient 实例。
参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
base_model | str | "" | 基础模型,例如 'Qwen/Qwen2.5-3B' |
model_path | str | None | None | LoRA 模型路径,传入后会在初始化时自动加载 |
返回值
SamplingClient — 模型采样客户端实例。
示例
# 使用基础模型
sampling_client = client.create_sampling_client(base_model="Qwen/Qwen2.5-3B")
# 加载已有 LoRA 权重
sampling_client = client.create_sampling_client(
base_model="Qwen/Qwen2.5-3B",
model_path="/path/to/lora",
)create_rest_client
def create_rest_client(self) -> RestClient创建一个 RestClient 实例,用于执行 REST API 操作(例如列出权重、查询检查点信息)。
返回值
RestClient — REST 客户端实例。
示例
rest_client = client.create_rest_client()
weights = rest_client.list_weights()create_training_client_from_state
def create_training_client_from_state(self, path: str) -> TrainingClient从已保存的检查点恢复,创建一个 TrainingClient 实例(仅恢复模型权重,不恢复优化器状态)。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
path | str | 检查点路径 |
返回值
TrainingClient — 包含训练状态的客户端实例。
示例
training_client = client.create_training_client_from_state(
path="/path/to/checkpoint"
)create_training_client_from_state_with_optimizer
def create_training_client_from_state_with_optimizer(self, path: str) -> TrainingClient从已保存的检查点恢复,创建一个 TrainingClient 实例,同时恢复优化器状态,可用于无缝续训。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
path | str | 检查点路径 |
返回值
TrainingClient — 包含训练状态与优化器状态的客户端实例。
示例
# 恢复完整训练状态(含优化器),可无缝续训
training_client = client.create_training_client_from_state_with_optimizer(
path="/path/to/checkpoint"
)异步方法
create_lora_training_client_async
async def create_lora_training_client_async(
self,
base_model: str,
rank: int = 32,
seed: int | None = None,
train_mlp: bool = True,
train_attn: bool = True,
train_unembed: bool = False,
trainable_token_indices: list[int] | dict[str, list[int]] | None = None,
lora_path: str | None = None,
) -> TrainingClientcreate_lora_training_client 的异步版本,参数相同。
training_client = await client.create_lora_training_client_async(
base_model="Qwen/Qwen2.5-3B",
rank=16,
)create_sampling_client_async
async def create_sampling_client_async(
self,
base_model: str = "",
model_path: str | None = None,
) -> SamplingClientcreate_sampling_client 的异步版本,参数相同。
sampling_client = await client.create_sampling_client_async(base_model="Qwen/Qwen2.5-3B")create_training_client_from_state_async
async def create_training_client_from_state_async(self, path: str) -> TrainingClientcreate_training_client_from_state 的异步版本,参数相同。
training_client = await client.create_training_client_from_state_async(
path="/path/to/checkpoint"
)create_training_client_from_state_with_optimizer_async
async def create_training_client_from_state_with_optimizer_async(self, path: str) -> TrainingClientcreate_training_client_from_state_with_optimizer 的异步版本,参数相同。
training_client = await client.create_training_client_from_state_with_optimizer_async(
path="/path/to/checkpoint"
)