API

trio.RestClient

class RestClient:
    http_client: HttpClient
    user_info: dict

RestClient 是用于 REST API 操作的客户端,通过 ServiceClient.create_rest_client() 创建。

client = ServiceClient()
rest_client = client.create_rest_client()

# 列出权重
weights = rest_client.list_weights()

# 列出训练运行
runs = rest_client.list_training_runs()

属性

属性类型说明
http_clientHttpClient底层 HTTP 客户端
user_infodict当前用户信息

方法

get_user_info

def get_user_info(self) -> Dict

获取当前用户信息。

返回值

Dict — 用户信息字典。

示例

info = rest_client.get_user_info()
print(info)

list_weights

def list_weights(self, page: int = 1, page_size: int = 20) -> Dict

分页列出当前用户的模型权重。

参数

参数类型默认值说明
pageint1当前页码
page_sizeint20每页条数

返回值

Dict — 权重列表及分页信息。

示例

weights = rest_client.list_weights(page=1, page_size=20)

get_archive_url

def get_archive_url(self, checkpoint_id: str) -> Dict

获取指定模型权重的临时下载链接。

参数

参数类型说明
checkpoint_idstr模型权重 ID,通过 list_weights 获取

返回值

Dict — 包含临时下载链接的字典。

示例

url_info = rest_client.get_archive_url(checkpoint_id="abc123")

get_training_run

def get_training_run(self, training_run_id: str) -> Dict

获取指定训练运行的详细信息。

参数

参数类型说明
training_run_idstr训练运行 ID

返回值

Dict — 训练运行详情。

示例

run = rest_client.get_training_run(training_run_id="run-001")

list_training_runs

def list_training_runs(self, limit: int = 20, offset: int = 0) -> Dict

分页列出当前用户的训练运行。

参数

参数类型默认值说明
limitint20返回条数
offsetint0偏移量

返回值

Dict — 训练运行列表及分页信息。

示例

runs = rest_client.list_training_runs(limit=10, offset=0)

list_checkpoints

def list_checkpoints(self, training_run_id: str) -> Dict

列出指定训练运行下的所有检查点。

参数

参数类型说明
training_run_idstr训练运行 ID

返回值

Dict — 检查点列表。

示例

checkpoints = rest_client.list_checkpoints(training_run_id="run-001")

get_checkpoint_archive_url

def get_checkpoint_archive_url(self, training_run_id: str, checkpoint_id: str) -> Dict

获取指定检查点的临时下载链接。

参数

参数类型说明
training_run_idstr训练运行 ID
checkpoint_idstr检查点 ID

返回值

Dict — 包含临时下载链接的字典。

示例

url_info = rest_client.get_checkpoint_archive_url(
    training_run_id="run-001",
    checkpoint_id="ckpt-step100",
)

delete_checkpoint

def delete_checkpoint(self, training_run_id: str, checkpoint_id: str) -> None

删除指定训练运行下的检查点。

参数

参数类型说明
training_run_idstr训练运行 ID
checkpoint_idstr检查点 ID

示例

rest_client.delete_checkpoint(training_run_id="run-001", checkpoint_id="ckpt-step100")

list_user_checkpoints

def list_user_checkpoints(self, limit: int = 100, offset: int = 0) -> Dict

分页列出当前用户的所有检查点(跨训练运行)。

参数

参数类型默认值说明
limitint100返回条数
offsetint0偏移量

返回值

Dict — 检查点列表及分页信息。

示例

checkpoints = rest_client.list_user_checkpoints(limit=50)

get_session

def get_session(self, session_id: str) -> Dict

获取指定会话的信息。

参数

参数类型说明
session_idstr会话 ID

返回值

Dict — 会话详情。

示例

session = rest_client.get_session(session_id="sess-abc")

list_sessions

def list_sessions(self, limit: int = 20, offset: int = 0) -> Dict

分页列出当前用户的所有会话。

参数

参数类型默认值说明
limitint20返回条数
offsetint0偏移量

返回值

Dict — 会话列表及分页信息。

示例

sessions = rest_client.list_sessions()

get_sampler

def get_sampler(self, sampler_id: str) -> Dict

获取指定采样器的信息。

参数

参数类型说明
sampler_idstr采样器 ID

返回值

Dict — 采样器详情。

示例

sampler = rest_client.get_sampler(sampler_id="sampler-xyz")

get_checkpoint_info

def get_checkpoint_info(self, path: str) -> Dict

获取指定路径的检查点信息。

参数

参数类型说明
pathstr检查点路径

返回值

Dict — 检查点信息字典。

示例

info = rest_client.get_checkpoint_info(path="/path/to/checkpoint")
print(info)

On this page