API

trio.ServiceClient

class ServiceClient:
    def __init__(
        self,
        host: str | None = None,
        use_https: bool | None = None,
        api_key: str | None = None,
    ):

ServiceClient 是 TRIO API 的主要入口。它提供以下功能:

  • 为模型训练工作流程生成 TrainingClient 实例
  • 生成用于文本生成和推理的 SamplingClient 实例
  • 为 REST API 操作(例如列出权重)生成 RestClient 实例。
client = ServiceClient()

# 创建 TrainingClient 实例
training_client = client.create_lora_training_client(base_model="Qwen/Qwen3-4B")

# 创建 SamplingClient 实例
sampling_client = client.create_sampling_client(base_model="Qwen/Qwen3-4B")

# 创建 RestClient 实例
rest_client = client.create_rest_client()

参数

参数类型默认值说明
hoststr | NoneNone服务地址
use_httpsbool | NoneNone是否使用 HTTPS
api_keystr | NoneNoneAPI Key,不传则从本地登录状态读取

初始化时,ServiceClient 会自动完成登录验证、建立 Socket 连接,并拉取可用模型列表。

方法

get_supported_models

def get_supported_models(self) -> list[str]

获取当前可用的模型列表。

返回值

list[str] — 模型名称列表,例如 ['Qwen/Qwen2.5-3B']

示例

models = client.get_supported_models()
print(models)  # ['Qwen/Qwen2.5-3B', ...]

create_lora_training_client

def create_lora_training_client(
    self,
    base_model: str,
    rank: int = 32,
    seed: int | None = None,
    train_mlp: bool = True,
    train_attn: bool = True,
    train_unembed: bool = False,
    trainable_token_indices: list[int] | dict[str, list[int]] | None = None,
    lora_path: str | None = None,
) -> TrainingClient

创建一个用于 LoRA 微调的 TrainingClient 实例。

参数

参数类型默认值说明
base_modelstr基础模型,例如 'Qwen/Qwen2.5-3B'
rankint32LoRA rank,范围 4–64
seedint | NoneNone用于初始化的随机种子
train_mlpboolTrue是否训练 MLP 层
train_attnboolTrue是否训练注意力层
train_unembedboolFalse是否训练 lm_head
trainable_token_indiceslist[int] | dict[str, list[int]] | NoneNone启用指定 token 的 Embedding 训练,例如 [1, 2, 3]{"embed_tokens": [1, 2, 3]}None 表示不启用。与 train_unembed=True 不能同时使用
lora_pathstr | NoneNone已有 LoRA 权重路径,用于从检查点继续训练

返回值

TrainingClient — 包含训练状态的客户端实例。

示例

training_client = client.create_lora_training_client(
    base_model="Qwen/Qwen2.5-3B",
    rank=16,
    train_unembed=False,
)

create_sampling_client

def create_sampling_client(
    self,
    base_model: str = "",
    model_path: str | None = None,
) -> SamplingClient

创建一个用于文本生成与推理的 SamplingClient 实例。

参数

参数类型默认值说明
base_modelstr""基础模型,例如 'Qwen/Qwen2.5-3B'
model_pathstr | NoneNoneLoRA 模型路径,传入后会在初始化时自动加载

返回值

SamplingClient — 模型采样客户端实例。

示例

# 使用基础模型
sampling_client = client.create_sampling_client(base_model="Qwen/Qwen2.5-3B")

# 加载已有 LoRA 权重
sampling_client = client.create_sampling_client(
    base_model="Qwen/Qwen2.5-3B",
    model_path="/path/to/lora",
)

create_rest_client

def create_rest_client(self) -> RestClient

创建一个 RestClient 实例,用于执行 REST API 操作(例如列出权重、查询检查点信息)。

返回值

RestClient — REST 客户端实例。

示例

rest_client = client.create_rest_client()
weights = rest_client.list_weights()

create_training_client_from_state

def create_training_client_from_state(self, path: str) -> TrainingClient

从已保存的检查点恢复,创建一个 TrainingClient 实例(仅恢复模型权重,不恢复优化器状态)。

参数

参数类型说明
pathstr检查点路径

返回值

TrainingClient — 包含训练状态的客户端实例。

示例

training_client = client.create_training_client_from_state(
    path="/path/to/checkpoint"
)

create_training_client_from_state_with_optimizer

def create_training_client_from_state_with_optimizer(self, path: str) -> TrainingClient

从已保存的检查点恢复,创建一个 TrainingClient 实例,同时恢复优化器状态,可用于无缝续训。

参数

参数类型说明
pathstr检查点路径

返回值

TrainingClient — 包含训练状态与优化器状态的客户端实例。

示例

# 恢复完整训练状态(含优化器),可无缝续训
training_client = client.create_training_client_from_state_with_optimizer(
    path="/path/to/checkpoint"
)

异步方法

create_lora_training_client_async

async def create_lora_training_client_async(
    self,
    base_model: str,
    rank: int = 32,
    seed: int | None = None,
    train_mlp: bool = True,
    train_attn: bool = True,
    train_unembed: bool = False,
    trainable_token_indices: list[int] | dict[str, list[int]] | None = None,
    lora_path: str | None = None,
) -> TrainingClient

create_lora_training_client 的异步版本,参数相同。

training_client = await client.create_lora_training_client_async(
    base_model="Qwen/Qwen2.5-3B",
    rank=16,
)

create_sampling_client_async

async def create_sampling_client_async(
    self,
    base_model: str = "",
    model_path: str | None = None,
) -> SamplingClient

create_sampling_client 的异步版本,参数相同。

sampling_client = await client.create_sampling_client_async(base_model="Qwen/Qwen2.5-3B")

create_training_client_from_state_async

async def create_training_client_from_state_async(self, path: str) -> TrainingClient

create_training_client_from_state 的异步版本,参数相同。

training_client = await client.create_training_client_from_state_async(
    path="/path/to/checkpoint"
)

create_training_client_from_state_with_optimizer_async

async def create_training_client_from_state_with_optimizer_async(self, path: str) -> TrainingClient

create_training_client_from_state_with_optimizer 的异步版本,参数相同。

training_client = await client.create_training_client_from_state_with_optimizer_async(
    path="/path/to/checkpoint"
)

On this page