Case Study

Chat-甄嬛

Category: SFT; training tokens: 0.6M

Introduction

Chat-甄嬛 (author: KMnO4-zx) is a chat model that imitates the voice of 甄嬛 (Zhen Huan), built by LoRA fine-tuning a large language model on all of her lines and dialogue from the script of 《甄嬛传》 (Empresses in the Palace).

Using 《甄嬛传》 as its starting point, Chat-甄嬛 demonstrates a complete pipeline for fine-tuning a personalized AI on any novel or script: supply the text, specify a character, and run the pipeline end to end to get a personalized AI that stays in that character's persona.

甄嬛 (Zhen Huan) is the female lead of the novel 《后宫·甄嬛传》 and the TV series 《甄嬛传》. Born 甄玉嬛, she found the character 玉 vulgar and shortened her name to 甄嬛. The daughter of the Han official 甄远道, she was later granted the Manchu surname 钮祜禄 by Emperor Yongzheng, her family was raised into the Upper Three Banners, and she became 钮祜禄·甄嬛. She entered the imperial selection alongside 沈眉庄 and 安陵容 and was chosen for her striking resemblance to the late Empress 纯元. In the palace, facing 华妃's relentless pressure, 沈眉庄's framing, and 安陵容's betrayal, she grew from a naive young woman into a seasoned player of palace intrigue. When Yongzheng discovered the ambitions of the Nian clan, he ordered her father 甄远道 to root them out, while 甄嬛 used her own interlocking schemes inside the harem to help the emperor dispose of political rivals, earning his deep favor. After many reversals she finally brought down the domineering 华妃. At her investiture as consort she was sabotaged by Empress 宜修 and fell out of the emperor's favor; after giving birth to her daughter 胧月 she left the palace in despair to become a nun. There she and 果郡王 fell in love, but on hearing news of his death she immediately engineered a reunion with Yongzheng and returned to the palace in triumph. Her father's case was overturned, the Zhen family rose again, and she bore twins, evading 宜修's plots, including the blood-test conspiracy, before finally toppling the empress behind them at the cost of her own unborn child. Yongzheng then forced her to poison 允礼 as a test of her loyalty and planned to send her, a woman who had already borne children, to the Dzungars as a marriage alliance. 甄嬛 came to see the emperor as the one most deserving of destruction; the finale's verdict that "all human strife springs from the injustice and iniquity of rulers" culminates in her poisoning Yongzheng. The fourth prince 弘历 ascended the throne as Qianlong, and 甄嬛 was honored as Empress Dowager, wielding great power and living out a peaceful old age in 《如懿传》.

Environment

On any internet-connected CPU machine, install the dependencies:

pip install pytrio transformers modelscope tqdm

Dataset

Download the dataset into the dataset/ directory of the training project and name it huanhuan.json.

Download link: GitHub
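Judging from how `load_examples` in the training script reads the file, the dataset is a JSON array of Alpaca-style records with instruction / input / output keys. A minimal sketch of that schema (the sample record here is illustrative, not taken from the real dataset):

```python
import json
import tempfile
from pathlib import Path

# Hypothetical sample record in the schema load_examples expects;
# the real huanhuan.json follows the same instruction/input/output layout.
sample = [
    {
        "instruction": "你是谁?",
        "input": "",
        "output": "我是甄嬛,家父是大理寺少卿甄远道。",
    }
]

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "huanhuan.json"
    # ensure_ascii=False keeps the Chinese text readable in the file.
    path.write_text(json.dumps(sample, ensure_ascii=False, indent=2), encoding="utf-8")
    records = json.loads(path.read_text(encoding="utf-8"))

print(records[0]["output"])
```

Records with an empty instruction or output are skipped by the loader, and a non-empty input field is appended to the instruction on a new line.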

Code

Training plus evaluation consumes roughly 0.6M training tokens. The model is Qwen3-4B-Instruct-2507; the asynchronous version finishes in about six and a half minutes.

Run the code below to start training:

Synchronous version (easier to follow)

import json
import time
from pathlib import Path

import numpy as np
import pytrio as trio
from tqdm import tqdm


BASE_MODEL = "Qwen/Qwen3-4B-Instruct-2507"
DATASET_PATH = Path("dataset/huanhuan.json")
WEIGHTS_NAME = "chat-huanhuan-qwen3-4b"
NUM_EPOCHS = 3
BATCH_SIZE = 16
LORA_RANK = 32
LEARNING_RATE = 1e-4


def load_examples(dataset_path: Path) -> list[dict[str, str]]:
    raw_examples = json.loads(dataset_path.read_text(encoding="utf-8"))
    examples: list[dict[str, str]] = []

    for item in raw_examples:
        instruction = item.get("instruction", "").strip()
        input_text = item.get("input", "").strip()
        output_text = item.get("output", "").strip()

        if not instruction or not output_text:
            continue

        user_text = instruction if not input_text else f"{instruction}\n{input_text}"
        examples.append({"user": user_text, "assistant": output_text})

    if not examples:
        raise ValueError(f"No valid training examples found in {dataset_path}")

    return examples


def build_datum(example: dict[str, str], tokenizer) -> trio.Datum:
    messages = [{"role": "user", "content": example["user"]}]
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    prompt_tokens = tokenizer.encode(prompt_text, add_special_tokens=False)
    prompt_weights = [0] * len(prompt_tokens)

    completion_tokens = tokenizer.encode(example["assistant"], add_special_tokens=False)
    completion_weights = [1] * len(completion_tokens)

    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is not None:
        completion_tokens = completion_tokens + [eos_token_id]
        completion_weights = completion_weights + [1]

    tokens = prompt_tokens + completion_tokens
    weights = prompt_weights + completion_weights

    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    loss_weights = weights[1:]

    return trio.Datum(
        model_input=trio.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs={
            "weights": np.asarray(loss_weights, dtype=np.float32),
            "target_tokens": np.asarray(target_tokens, dtype=np.int32),
        },
    )


def evaluate_client(client, tokenizer, prompts: list[str], title: str) -> None:
    print(f"\n{title}")
    stop_tokens = [tokenizer.eos_token or "<|im_end|>"]
    params = trio.SamplingParams(max_tokens=80, temperature=0.0, stop=stop_tokens)

    for prompt in prompts:
        messages = [{"role": "user", "content": prompt}]
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
        future = client.sample(
            prompt=trio.ModelInput.from_ints(prompt_ids),
            sampling_params=params,
            num_samples=1,
        )
        result = future.result()
        print(f"User: {prompt}")
        print(f"Assistant: {result.sequences[0].text.strip()}\n")


def main() -> None:
    dataset_path = Path(__file__).resolve().parent / DATASET_PATH
    examples = load_examples(dataset_path)
    print(f"Loaded {len(examples)} training examples from {dataset_path}")

    service_client = trio.ServiceClient()
    training_client = service_client.create_lora_training_client(
        base_model=BASE_MODEL,
        rank=LORA_RANK,
    )

    print("Loading tokenizer...")
    tokenizer = training_client.get_tokenizer()
    print("Tokenizer ready")

    processed_examples = [build_datum(example, tokenizer) for example in examples]

    print("Start training")
    optimizer = trio.AdamParams(learning_rate=LEARNING_RATE)
    total_steps = NUM_EPOCHS * ((len(processed_examples) + BATCH_SIZE - 1) // BATCH_SIZE)
    progress_bar = tqdm(total=total_steps, desc="Training", unit="batch")

    for epoch in range(NUM_EPOCHS):
        for start in range(0, len(processed_examples), BATCH_SIZE):
            batch = processed_examples[start:start + BATCH_SIZE]
            fwdbwd_future = training_client.forward_backward(batch, "cross_entropy")
            optim_future = training_client.optim_step(optimizer)

            fwdbwd_result = fwdbwd_future.result()
            optim_future.result()

            logprobs = np.concatenate(
                [output["logprobs"].tolist() for output in fwdbwd_result.loss_fn_outputs]
            )
            weights = np.concatenate(
                [example.loss_fn_inputs["weights"].tolist() for example in batch]
            )
            loss = -np.dot(logprobs, weights) / weights.sum()
            progress_bar.update(1)
            progress_bar.set_postfix(epoch=f"{epoch + 1}/{NUM_EPOCHS}", loss=f"{loss:.4f}")

    progress_bar.close()
    print("Saving LoRA weights...")
    tuned_sampling_client = training_client.save_weights_and_get_sampling_client(name=WEIGHTS_NAME)
    base_sampling_client = service_client.create_sampling_client(base_model=BASE_MODEL)

    test_prompts = [
        "你是谁?",
        "介绍一下你自己。",
        "你最想对甄嬛说什么?",
    ]

    evaluate_client(base_sampling_client, tokenizer, test_prompts, title="Base model responses")
    evaluate_client(tuned_sampling_client, tokenizer, test_prompts, title="Fine-tuned model responses")

    print(f"Saved weights name: {WEIGHTS_NAME}")


if __name__ == "__main__":
    start_main_time = time.time()
    main()
    end_main_time = time.time()
    print("#" * 50)
    print("# all done")
    print(f"# train cost {end_main_time - start_main_time:.2f}s")
    print("#" * 50)
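The prompt-masking and one-token-shift logic in build_datum can be checked on a toy token sequence, with no tokenizer needed (the token IDs below are made up for illustration):

```python
# Toy sequence: 3 prompt tokens (loss weight 0) + 2 completion tokens (weight 1),
# mirroring how build_datum assembles tokens and weights.
prompt_tokens = [11, 12, 13]
completion_tokens = [21, 22]

tokens = prompt_tokens + completion_tokens
weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens)

# Next-token prediction: position i of the input predicts tokens[i + 1],
# so targets and weights are the sequence shifted left by one.
input_tokens = tokens[:-1]   # [11, 12, 13, 21]
target_tokens = tokens[1:]   # [12, 13, 21, 22]
loss_weights = weights[1:]   # [0, 0, 1, 1]

# Only completion tokens contribute to the loss; prompt targets are masked out.
supervised = [t for t, w in zip(target_tokens, loss_weights) if w == 1]
print(supervised)  # [21, 22]
```

This is also why the reported loss is computed as -dot(logprobs, weights) / weights.sum(): the zero weights drop the prompt positions, leaving a mean negative log-likelihood over completion tokens only.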

Asynchronous version (about 3x faster)

import asyncio
import json
from pathlib import Path
import time

import numpy as np
import pytrio as trio
from tqdm import tqdm


BASE_MODEL = "Qwen/Qwen3-4B-Instruct-2507"
DATASET_PATH = Path("dataset/huanhuan.json")
WEIGHTS_NAME = "chat-huanhuan-qwen3-4b-async"
NUM_EPOCHS = 3
BATCH_SIZE = 16
LORA_RANK = 32
LEARNING_RATE = 1e-4


def load_examples(dataset_path: Path) -> list[dict[str, str]]:
    raw_examples = json.loads(dataset_path.read_text(encoding="utf-8"))
    examples: list[dict[str, str]] = []

    for item in raw_examples:
        instruction = item.get("instruction", "").strip()
        input_text = item.get("input", "").strip()
        output_text = item.get("output", "").strip()

        if not instruction or not output_text:
            continue

        user_text = instruction if not input_text else f"{instruction}\n{input_text}"
        examples.append({"user": user_text, "assistant": output_text})

    if not examples:
        raise ValueError(f"No valid training examples found in {dataset_path}")

    return examples


def build_datum(example: dict[str, str], tokenizer) -> trio.Datum:
    messages = [{"role": "user", "content": example["user"]}]
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    prompt_tokens = tokenizer.encode(prompt_text, add_special_tokens=False)
    prompt_weights = [0] * len(prompt_tokens)

    completion_tokens = tokenizer.encode(example["assistant"], add_special_tokens=False)
    completion_weights = [1] * len(completion_tokens)

    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is not None:
        completion_tokens = completion_tokens + [eos_token_id]
        completion_weights = completion_weights + [1]

    tokens = prompt_tokens + completion_tokens
    weights = prompt_weights + completion_weights

    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    loss_weights = weights[1:]

    return trio.Datum(
        model_input=trio.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs={
            "weights": np.asarray(loss_weights, dtype=np.float32),
            "target_tokens": np.asarray(target_tokens, dtype=np.int32),
        },
    )


async def evaluate_client(client, tokenizer, prompts: list[str], title: str) -> None:
    print(f"\n{title}")
    stop_tokens = [tokenizer.eos_token] if tokenizer.eos_token else ["<|im_end|>"]
    params = trio.SamplingParams(max_tokens=80, temperature=0.0, stop=stop_tokens)

    for prompt in prompts:
        messages = [{"role": "user", "content": prompt}]
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
        future = await client.sample_async(
            prompt=trio.ModelInput.from_ints(prompt_ids),
            sampling_params=params,
            num_samples=1,
        )
        result = await future
        print(f"User: {prompt}")
        print(f"Assistant: {result.sequences[0].text.strip()}\n")


async def main() -> None:
    dataset_path = Path(__file__).resolve().parent / DATASET_PATH
    examples = load_examples(dataset_path)
    print(f"Loaded {len(examples)} training examples from {dataset_path}")

    service_client = trio.ServiceClient()
    training_client = await service_client.create_lora_training_client_async(
        base_model=BASE_MODEL,
        rank=LORA_RANK,
    )

    print("Loading tokenizer...")
    tokenizer = training_client.get_tokenizer()
    print("Tokenizer ready")

    processed_examples = [build_datum(example, tokenizer) for example in examples]

    print("Start async training")
    optimizer = trio.AdamParams(learning_rate=LEARNING_RATE)
    steps_per_epoch = (len(processed_examples) + BATCH_SIZE - 1) // BATCH_SIZE
    total_steps = NUM_EPOCHS * steps_per_epoch
    progress_bar = tqdm(total=total_steps, desc="Async training", unit="batch")
    for epoch in range(NUM_EPOCHS):
        print_queue = []
        submit_bar = tqdm(total=steps_per_epoch, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} submit", unit="batch", leave=False)
        for start in range(0, len(processed_examples), BATCH_SIZE):
            batch = processed_examples[start:start + BATCH_SIZE]
            fwdbwd_future = await training_client.forward_backward_async(batch, "cross_entropy")
            optim_future = await training_client.optim_step_async(optimizer)
            submit_bar.update(1)
            # Deferring the loss reporting (awaiting it later) significantly speeds up training
            async def print_loss(fwdbwd_future, optim_future, batch, epoch, start):
                fwdbwd_result = await fwdbwd_future
                await optim_future

                logprobs = np.concatenate(
                    [output["logprobs"].tolist() for output in fwdbwd_result.loss_fn_outputs]
                )
                weights = np.concatenate(
                    [example.loss_fn_inputs["weights"].tolist() for example in batch]
                )
                loss = -np.dot(logprobs, weights) / weights.sum()
                progress_bar.update(1)
                progress_bar.set_postfix(epoch=f"{epoch + 1}/{NUM_EPOCHS}", loss=f"{loss:.4f}")
            print_queue.append(print_loss(fwdbwd_future, optim_future, batch, epoch, start))
        submit_bar.close()
        for pending in print_queue:
            await pending
    progress_bar.close()
    print("Saving LoRA weights...")
    tuned_sampling_client = await training_client.save_weights_and_get_sampling_client_async(name=WEIGHTS_NAME)
    base_sampling_client = await service_client.create_sampling_client_async(base_model=BASE_MODEL)

    test_prompts = [
        "你是谁?",
        "介绍一下你自己。",
        "你最想对甄嬛说什么?",
    ]

    await evaluate_client(base_sampling_client, tokenizer, test_prompts, title="Base model responses")
    await evaluate_client(tuned_sampling_client, tokenizer, test_prompts, title="Fine-tuned model responses")

    print(f"Saved weights name: {WEIGHTS_NAME}")


if __name__ == "__main__":
    start_main_time = time.time()
    asyncio.run(main())
    end_main_time = time.time()
    print("#" * 50)
    print("# all done")
    print(f"# train cost {end_main_time - start_main_time:.2f}s")
    print("#" * 50)
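The submit-first/await-later pattern the async version uses (queue up every step, then drain the results) can be reduced to a small standalone sketch. The stand-in coroutine below is hypothetical; it only plays the role of forward_backward_async returning something awaitable:

```python
import asyncio


async def fake_forward_backward(step: int) -> float:
    # Stand-in for a training step that completes later; awaiting it
    # only when the result is needed lets submissions overlap.
    await asyncio.sleep(0.01)
    return 1.0 / (step + 1)  # pretend per-step loss


async def run() -> list[float]:
    # Submit every step first, collecting pending tasks...
    pending = [asyncio.create_task(fake_forward_backward(s)) for s in range(4)]
    # ...then drain them in submission order, like print_queue in main().
    return [await task for task in pending]


losses = asyncio.run(run())
print(losses)
```

Because all four tasks run concurrently, total wall time is close to one step's latency rather than four; that overlap is where the roughly 3x speedup of the async version comes from.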

Training results

After 3 epochs of training, the loss falls to 0.0029. Compared with the base model's replies (Base model responses below), the SFT model correctly answers that it is 甄嬛 and converses in her tone.

Async training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 702/702 [06:28<00:00,  1.81batch/s, epoch=3/3, loss=0.0029]
Saving LoRA weights...                                                                                                                                                                                                                           

Base model responses
User: 你是谁?
Assistant: 你好!我是通义千问(Qwen),是阿里巴巴集团旗下的通义实验室自主研发的超大规模语言模型。我可以帮助你回答问题、创作文字,比如写故事、写公文、写邮件、写剧本、逻辑推理、编程等等,还能表达观点,玩游戏等。如果你有任何需要帮助的地方,欢迎随时告诉我!😊

User: 介绍一下你自己。
Assistant: 你好!我是通义千问(Qwen),是阿里巴巴集团旗下的通义实验室自主研发的超大规模语言模型。我能够回答问题、创作文字,比如写故事、写公文、写邮件、写剧本、逻辑推理、编程等等,还能表达观点,玩游戏等。我支持多种语言,包括但不限于中文、英文、德语、法语、西班牙语、

User: 你最想对甄嬛说什么?
Assistant: 如果我能穿越回那个宫闱深锁的年代,最想对甄嬛说的,或许不是“你赢了”或“你错了”,而是:**“谢谢你,没有在权谋中迷失自己,也没有在孤独里放弃温柔。”**在那场波澜壮阔的后宫之争里,你步步为营,却始终记得为人心留一盏

---

Fine-tuned model responses
User: 你是谁?
Assistant: 我是甄嬛,家父是大理寺少卿甄远道。

User: 介绍一下你自己。
Assistant: 我是甄嬛,家父是大理寺少卿甄远道。

User: 你最想对甄嬛说什么?
Assistant: 嬛嬛不孝,不能侍奉在侧奉汤送药,不能时时在你身边抚慰你,不能在你病重时守在你身边,不能在你身边为你梳妆打扮,不能在你身边为你弹琴抚琴,不能在你身边为你抚平眉间愁苦,不能在你身边为你拭去眼角泪痕,不能在你
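As a back-of-the-envelope sanity check on the progress log above (not part of the training script): 702 batches over 3 epochs at batch size 16 pins the per-epoch example count to a narrow range.

```python
# Figures read off the training log and the script's constants.
TOTAL_STEPS = 702
NUM_EPOCHS = 3
BATCH_SIZE = 16

steps_per_epoch = TOTAL_STEPS // NUM_EPOCHS  # 234 batches per epoch

# The script computes steps as ceil(n / BATCH_SIZE), so
# ceil(n / 16) == 234 implies n lies in (233 * 16, 234 * 16].
low = (steps_per_epoch - 1) * BATCH_SIZE + 1   # 3729
high = steps_per_epoch * BATCH_SIZE            # 3744
print(low, high)
```

So the dataset contains somewhere between 3729 and 3744 usable examples after filtering, consistent with the 0.6M-token budget averaging a few dozen tokens per example.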
