API
trio.RestClient
class RestClient:
http_client: HttpClient
user_info: dictRestClient 是用于 REST API 操作的客户端,通过 ServiceClient.create_rest_client() 创建。
client = ServiceClient()
rest_client = client.create_rest_client()
# 列出权重
weights = rest_client.list_weights()
# 列出训练运行
runs = rest_client.list_training_runs()属性
| 属性 | 类型 | 说明 |
|---|---|---|
http_client | HttpClient | 底层 HTTP 客户端 |
user_info | dict | 当前用户信息 |
方法
get_user_info
def get_user_info(self) -> Dict获取当前用户信息。
返回值
Dict — 用户信息字典。
示例
info = rest_client.get_user_info()
print(info)list_weights
def list_weights(self, page: int = 1, page_size: int = 20) -> Dict分页列出当前用户的模型权重。
参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
page | int | 1 | 当前页码 |
page_size | int | 20 | 每页条数 |
返回值
Dict — 权重列表及分页信息。
示例
weights = rest_client.list_weights(page=1, page_size=20)get_archive_url
def get_archive_url(self, checkpoint_id: str) -> Dict获取指定模型权重的临时下载链接。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
checkpoint_id | str | 模型权重 ID,通过 list_weights 获取 |
返回值
Dict — 包含临时下载链接的字典。
示例
url_info = rest_client.get_archive_url(checkpoint_id="abc123")get_training_run
def get_training_run(self, training_run_id: str) -> Dict获取指定训练运行的详细信息。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
training_run_id | str | 训练运行 ID |
返回值
Dict — 训练运行详情。
示例
run = rest_client.get_training_run(training_run_id="run-001")list_training_runs
def list_training_runs(self, limit: int = 20, offset: int = 0) -> Dict分页列出当前用户的训练运行。
参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
limit | int | 20 | 返回条数 |
offset | int | 0 | 偏移量 |
返回值
Dict — 训练运行列表及分页信息。
示例
runs = rest_client.list_training_runs(limit=10, offset=0)list_checkpoints
def list_checkpoints(self, training_run_id: str) -> Dict列出指定训练运行下的所有检查点。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
training_run_id | str | 训练运行 ID |
返回值
Dict — 检查点列表。
示例
checkpoints = rest_client.list_checkpoints(training_run_id="run-001")get_checkpoint_archive_url
def get_checkpoint_archive_url(self, training_run_id: str, checkpoint_id: str) -> Dict获取指定检查点的临时下载链接。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
training_run_id | str | 训练运行 ID |
checkpoint_id | str | 检查点 ID |
返回值
Dict — 包含临时下载链接的字典。
示例
url_info = rest_client.get_checkpoint_archive_url(
training_run_id="run-001",
checkpoint_id="ckpt-step100",
)delete_checkpoint
def delete_checkpoint(self, training_run_id: str, checkpoint_id: str) -> None删除指定训练运行下的检查点。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
training_run_id | str | 训练运行 ID |
checkpoint_id | str | 检查点 ID |
示例
rest_client.delete_checkpoint(training_run_id="run-001", checkpoint_id="ckpt-step100")list_user_checkpoints
def list_user_checkpoints(self, limit: int = 100, offset: int = 0) -> Dict分页列出当前用户的所有检查点(跨训练运行)。
参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
limit | int | 100 | 返回条数 |
offset | int | 0 | 偏移量 |
返回值
Dict — 检查点列表及分页信息。
示例
checkpoints = rest_client.list_user_checkpoints(limit=50)get_session
def get_session(self, session_id: str) -> Dict获取指定会话的信息。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
session_id | str | 会话 ID |
返回值
Dict — 会话详情。
示例
session = rest_client.get_session(session_id="sess-abc")list_sessions
def list_sessions(self, limit: int = 20, offset: int = 0) -> Dict分页列出当前用户的所有会话。
参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
limit | int | 20 | 返回条数 |
offset | int | 0 | 偏移量 |
返回值
Dict — 会话列表及分页信息。
示例
sessions = rest_client.list_sessions()get_sampler
def get_sampler(self, sampler_id: str) -> Dict获取指定采样器的信息。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
sampler_id | str | 采样器 ID |
返回值
Dict — 采样器详情。
示例
sampler = rest_client.get_sampler(sampler_id="sampler-xyz")get_checkpoint_info
def get_checkpoint_info(self, path: str) -> Dict获取指定路径的检查点信息。
参数
| 参数 | 类型 | 说明 |
|---|---|---|
path | str | 检查点路径 |
返回值
Dict — 检查点信息字典。
示例
info = rest_client.get_checkpoint_info(path="/path/to/checkpoint")
print(info)