Advanced

HuggingFace datasets

How do you use a dataset from HuggingFace for TRIO training? It is actually quite simple.

Take a classic dataset on HuggingFace, openai/gsm8k, as an example: it has two subsets, `main` and `socratic`, each of which is further split into `train` and `test`, and every example contains two fields, `question` and `answer`.

In code, you can fetch the data with the `datasets` library:

from datasets import load_dataset

ds = load_dataset("openai/gsm8k", "main")

train_dataset = ds["train"]
eval_dataset = ds["test"]
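
To confirm the structure described above, you can inspect the dataset objects directly; this short sketch uses standard `datasets` APIs to show the splits, columns, and sizes.

print(ds)                      # DatasetDict with "train" and "test" splits
print(train_dataset.features)  # the "question" and "answer" columns
print(len(train_dataset), len(eval_dataset))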

Printing the first example of the training set should make the usage pattern of `datasets` clear:

print(train_dataset[0])
"""
{
    'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?', 
    'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'
}
"""

print(train_dataset[0]["question"])
"""
Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
"""

In the earlier section Data Processing: Datum, we learned how to use `Datum`. The core idea is to pull the strings out of the dataset, run them through the tokenizer (plus the related processing), and pass the result into a `Datum`.

Here we take an SFT task as an example and convert the openai/gsm8k dataset:

from datasets import load_dataset
import pytrio as trio

# Load the dataset
train_dataset = load_dataset("openai/gsm8k", "main")["train"]

# Connect to the TRIO compute engine
service_client = trio.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-4B-Instruct-2507",
    rank=32,
)

print("Loading tokenizer...")
tokenizer = training_client.get_tokenizer()
print("Tokenizer finish")

def process_example(example: dict, tokenizer) -> trio.Datum:
    # Prompt tokens get weight 0: no loss is computed on the question.
    prompt = f"Question: {example['question']}\nAnswer:"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    prompt_weights = [0] * len(prompt_tokens)
    
    # Completion (answer) tokens get weight 1: loss is computed only here.
    completion_tokens = tokenizer.encode(f" {example['answer']}\n\n", add_special_tokens=False)
    completion_weights = [1] * len(completion_tokens)
    tokens = prompt_tokens + completion_tokens
    weights = prompt_weights + completion_weights
    # Shift by one for next-token prediction: position i predicts token i+1.
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    weights = weights[1:]
    
    # Convert into the format TRIO training expects
    return trio.Datum(
        model_input=trio.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
    )
    
processed_examples = [process_example(ex, tokenizer) for ex in train_dataset]
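
Before handing processed_examples to the training client, it is worth checking that the weight mask lines up with the answer text. Below is a minimal sanity check; it assumes a `Datum` exposes its `loss_fn_inputs` as a readable mapping, which is an assumption about pytrio internals rather than documented behavior.

# Sanity check: the loss-weighted target tokens should decode back to
# the answer text. (Assumes datum.loss_fn_inputs is readable as a dict.)
datum = processed_examples[0]
weights = datum.loss_fn_inputs["weights"]
targets = datum.loss_fn_inputs["target_tokens"]
answer_tokens = [tok for tok, w in zip(targets, weights) if w == 1]
print(tokenizer.decode(answer_tokens))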