Chat-甄嬛
分类:SFT;训练Token 0.6M
介绍
Chat-甄嬛 (作者:KMnO4-zx) 是利用《甄嬛传》剧本中所有关于甄嬛的台词和语句,基于大模型进行 LoRA 微调得到的模仿甄嬛语气的聊天语言模型。
Chat-甄嬛实现了以《甄嬛传》为切入点,打造一套基于小说、剧本的个性化 AI 微调大模型完整流程,通过提供任一小说、剧本,指定人物角色,运行本项目完整流程,让每一位用户都基于心仪的小说、剧本打造一个属于自己的、契合角色人设、具备高度智能的个性化 AI。
甄嬛,小说《后宫·甄嬛传》和电视剧《甄嬛传》中的女一号,核心女主角。原名甄玉嬛,嫌玉字俗气而改名甄嬛,为汉人甄远道之女,后被雍正赐姓钮祜禄氏,抬旗为满洲上三旗,获名“钮祜禄·甄嬛”。同沈眉庄、安陵容参加选秀,因容貌酷似纯元皇后而被选中。入宫后面对华妃的步步紧逼,沈眉庄被冤、安陵容变心,从偏安一隅的青涩少女变成了能引起血雨腥风的宫斗老手。雍正发现年氏一族的野心后令其父甄远道剪除,甄嬛也于后宫中用她的连环巧计帮皇帝解决政敌,故而深得雍正爱待。几经周折,终于斗垮了嚣张跋扈的华妃。甄嬛封妃时遭皇后宜修暗算,被皇上嫌弃,生下女儿胧月后心灰意冷,自请出宫为尼。然得果郡王爱慕,二人相爱,得知果郡王死讯后立刻设计与雍正再遇,风光回宫。此后甄父冤案平反、甄氏复起,她也生下双生子,在滴血验亲等各种阴谋中躲过宜修的暗害,最后以牺牲自己亲生胎儿的方式扳倒了幕后黑手的皇后。但雍正又逼甄嬛毒杀允礼,以测试甄嬛真心,并让已经生产过孩子的甄嬛去准格尔和亲。甄嬛遂视皇帝为最该毁灭的对象,大结局道尽“人类的一切争斗,皆因统治者的不公不义而起”,并毒杀雍正。四阿哥弘历登基为乾隆,甄嬛被尊为圣母皇太后,权倾朝野,在如懿传中安度晚年。
环境
在任意一台可联网的 CPU 机器上,安装环境:
pip install pytrio transformers modelscope tqdm

数据集
将数据集下载到训练项目的dataset/目录下,命名为huanhuan.json即可。
下载链接:Github

代码
完成训练和评估大约需要消耗 0.6M 训练Token,模型使用 Qwen3-4B-Instruct-2507,异步版下用时大约 6 分半。
执行下面的代码,即可开始训练:
同步版(方便理解)
import json
from pathlib import Path
import numpy as np
import pytrio as trio
from tqdm import tqdm
import time
# Base model to fine-tune with LoRA.
BASE_MODEL = "Qwen/Qwen3-4B-Instruct-2507"
# SFT dataset: a JSON list of {"instruction", "input", "output"} records.
DATASET_PATH = Path("dataset/huanhuan.json")
# Name under which the trained LoRA weights are saved on the service.
WEIGHTS_NAME = "chat-huanhuan-qwen3-4b"
NUM_EPOCHS = 3       # full passes over the dataset
BATCH_SIZE = 16      # examples per forward/backward step
LORA_RANK = 32       # LoRA adapter rank
LEARNING_RATE = 1e-4  # Adam learning rate
def load_examples(dataset_path: Path) -> list[dict[str, str]]:
    """Read the SFT JSON file and return cleaned user/assistant chat pairs.

    Records missing an instruction or an output are skipped; a non-empty
    "input" field is appended to the instruction on a new line.
    Raises ValueError when no usable record is found.
    """
    records = json.loads(dataset_path.read_text(encoding="utf-8"))
    pairs: list[dict[str, str]] = []
    for record in records:
        instruction = record.get("instruction", "").strip()
        extra_input = record.get("input", "").strip()
        answer = record.get("output", "").strip()
        if not (instruction and answer):
            continue
        prompt = f"{instruction}\n{extra_input}" if extra_input else instruction
        pairs.append({"user": prompt, "assistant": answer})
    if not pairs:
        raise ValueError(f"No valid training examples found in {dataset_path}")
    return pairs
def build_datum(example: dict[str, str], tokenizer) -> trio.Datum:
    """Tokenize one chat example into a next-token-prediction Datum.

    Prompt tokens get loss weight 0 and completion tokens (plus EOS, when the
    tokenizer defines one) get weight 1; the token stream is then shifted by
    one position for teacher forcing.
    """
    rendered_prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": example["user"]}],
        tokenize=False,
        add_generation_prompt=True,
    )
    prompt_ids = tokenizer.encode(rendered_prompt, add_special_tokens=False)
    answer_ids = tokenizer.encode(example["assistant"], add_special_tokens=False)
    answer_mask = [1] * len(answer_ids)
    eos_id = tokenizer.eos_token_id
    if eos_id is not None:
        # Terminate the completion so the model learns to stop.
        answer_ids = answer_ids + [eos_id]
        answer_mask = answer_mask + [1]
    all_ids = prompt_ids + answer_ids
    all_mask = [0] * len(prompt_ids) + answer_mask
    return trio.Datum(
        model_input=trio.ModelInput.from_ints(tokens=all_ids[:-1]),
        loss_fn_inputs={
            "weights": np.asarray(all_mask[1:], dtype=np.float32),
            "target_tokens": np.asarray(all_ids[1:], dtype=np.int32),
        },
    )
def evaluate_client(client, tokenizer, prompts: list[str], title: str) -> None:
    """Greedily sample a short reply for each prompt and print the exchange."""
    print(f"\n{title}")
    sampling_params = trio.SamplingParams(
        max_tokens=80,
        temperature=0.0,
        stop=[tokenizer.eos_token or "<|im_end|>"],
    )
    for user_prompt in prompts:
        rendered = tokenizer.apply_chat_template(
            [{"role": "user", "content": user_prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        token_ids = tokenizer.encode(rendered, add_special_tokens=False)
        # Block on the single sample before printing the pair.
        reply = client.sample(
            prompt=trio.ModelInput.from_ints(token_ids),
            sampling_params=sampling_params,
            num_samples=1,
        ).result()
        print(f"User: {user_prompt}")
        print(f"Assistant: {reply.sequences[0].text.strip()}\n")
def main() -> None:
    """End-to-end SFT run: load data, LoRA-train the base model, save the
    weights, then compare base vs. fine-tuned responses on a few prompts."""
    # Resolve the dataset relative to this script so the CWD does not matter.
    dataset_path = Path(__file__).resolve().parent / DATASET_PATH
    examples = load_examples(dataset_path)
    print(f"Loaded {len(examples)} training examples from {dataset_path}")
    service_client = trio.ServiceClient()
    training_client = service_client.create_lora_training_client(
        base_model=BASE_MODEL,
        rank=LORA_RANK,
    )
    print("Loading tokenizer...")
    tokenizer = training_client.get_tokenizer()
    print("Tokenizer ready")
    # Pre-tokenize the whole dataset once, before the training loop.
    processed_examples = [build_datum(example, tokenizer) for example in examples]
    print("Start training")
    optimizer = trio.AdamParams(learning_rate=LEARNING_RATE)
    # Ceil-divide so a trailing partial batch still counts as one step.
    total_steps = NUM_EPOCHS * ((len(processed_examples) + BATCH_SIZE - 1) // BATCH_SIZE)
    progress_bar = tqdm(total=total_steps, desc="Training", unit="batch")
    for epoch in range(NUM_EPOCHS):
        for start in range(0, len(processed_examples), BATCH_SIZE):
            batch = processed_examples[start:start + BATCH_SIZE]
            # Submit forward/backward then the optimizer step, and block on both.
            fwdbwd_future = training_client.forward_backward(batch, "cross_entropy")
            optim_future = training_client.optim_step(optimizer)
            fwdbwd_result = fwdbwd_future.result()
            optim_future.result()
            # Mean NLL over completion tokens only (prompt positions carry weight 0).
            logprobs = np.concatenate(
                [output["logprobs"].tolist() for output in fwdbwd_result.loss_fn_outputs]
            )
            weights = np.concatenate(
                [example.loss_fn_inputs["weights"].tolist() for example in batch]
            )
            loss = -np.dot(logprobs, weights) / weights.sum()
            progress_bar.update(1)
            progress_bar.set_postfix(epoch=f"{epoch + 1}/{NUM_EPOCHS}", loss=f"{loss:.4f}")
    progress_bar.close()
    print("Saving LoRA weights...")
    tuned_sampling_client = training_client.save_weights_and_get_sampling_client(name=WEIGHTS_NAME)
    base_sampling_client = service_client.create_sampling_client(base_model=BASE_MODEL)
    test_prompts = [
        "你是谁?",
        "介绍一下你自己。",
        "你最想对甄嬛说什么?",
    ]
    evaluate_client(base_sampling_client, tokenizer, test_prompts, title="Base model responses")
    evaluate_client(tuned_sampling_client, tokenizer, test_prompts, title="Fine-tuned model responses")
    print(f"Saved weights name: {WEIGHTS_NAME}")
if __name__ == "__main__":
    # Time the whole run and report it in a banner.
    launched_at = time.time()
    main()
    elapsed = time.time() - launched_at
    print("#" * 50)
    print("# all done")
    print(f"# train cost {elapsed:.2f}s")
print("#" * 50)

异步版(速度 x3)
import asyncio
import json
from pathlib import Path
import time
import numpy as np
import pytrio as trio
from tqdm import tqdm
# Base model to fine-tune with LoRA.
BASE_MODEL = "Qwen/Qwen3-4B-Instruct-2507"
# SFT dataset: a JSON list of {"instruction", "input", "output"} records.
DATASET_PATH = Path("dataset/huanhuan.json")
# Name under which the trained LoRA weights are saved on the service.
WEIGHTS_NAME = "chat-huanhuan-qwen3-4b-async"
NUM_EPOCHS = 3       # full passes over the dataset
BATCH_SIZE = 16      # examples per forward/backward step
LORA_RANK = 32       # LoRA adapter rank
LEARNING_RATE = 1e-4  # Adam learning rate
def load_examples(dataset_path: Path) -> list[dict[str, str]]:
    """Parse the SFT dataset into a list of {"user", "assistant"} chat pairs.

    A record is kept only when both its instruction and output are non-empty;
    a non-empty "input" field is joined to the instruction with a newline.
    Raises ValueError when nothing survives the filtering.
    """

    def to_pair(entry):
        # Returns a chat pair, or None when the record is unusable.
        instruction = entry.get("instruction", "").strip()
        context = entry.get("input", "").strip()
        answer = entry.get("output", "").strip()
        if not instruction or not answer:
            return None
        user = instruction + ("\n" + context if context else "")
        return {"user": user, "assistant": answer}

    raw = json.loads(dataset_path.read_text(encoding="utf-8"))
    examples = [pair for pair in map(to_pair, raw) if pair is not None]
    if not examples:
        raise ValueError(f"No valid training examples found in {dataset_path}")
    return examples
def build_datum(example: dict[str, str], tokenizer) -> trio.Datum:
    """Turn one user/assistant pair into a shifted-by-one training Datum.

    Loss weights are 0 on prompt positions and 1 on every completion
    position, including the appended EOS token when the tokenizer has one.
    """
    prompt_str = tokenizer.apply_chat_template(
        [{"role": "user", "content": example["user"]}],
        tokenize=False,
        add_generation_prompt=True,
    )
    prompt_part = tokenizer.encode(prompt_str, add_special_tokens=False)
    reply_part = tokenizer.encode(example["assistant"], add_special_tokens=False)
    if tokenizer.eos_token_id is not None:
        # Terminate the completion so the model learns to stop.
        reply_part = reply_part + [tokenizer.eos_token_id]
    token_stream = prompt_part + reply_part
    weight_stream = [0] * len(prompt_part) + [1] * len(reply_part)
    # Shift by one: predict token t+1 from tokens 0..t.
    return trio.Datum(
        model_input=trio.ModelInput.from_ints(tokens=token_stream[:-1]),
        loss_fn_inputs={
            "weights": np.asarray(weight_stream[1:], dtype=np.float32),
            "target_tokens": np.asarray(token_stream[1:], dtype=np.int32),
        },
    )
async def evaluate_client(client, tokenizer, prompts: list[str], title: str) -> None:
    """Sample a short greedy reply for every prompt and print the exchange."""
    print(f"\n{title}")
    stop_list = [tokenizer.eos_token] if tokenizer.eos_token else ["<|im_end|>"]
    sampling_params = trio.SamplingParams(max_tokens=80, temperature=0.0, stop=stop_list)
    for question in prompts:
        chat_text = tokenizer.apply_chat_template(
            [{"role": "user", "content": question}],
            tokenize=False,
            add_generation_prompt=True,
        )
        ids = tokenizer.encode(chat_text, add_special_tokens=False)
        # Submission and completion are both awaited before printing.
        pending = await client.sample_async(
            prompt=trio.ModelInput.from_ints(ids),
            sampling_params=sampling_params,
            num_samples=1,
        )
        outcome = await pending
        print(f"User: {question}")
        print(f"Assistant: {outcome.sequences[0].text.strip()}\n")
async def main() -> None:
    """Async SFT run: pipeline batch submissions ahead of result consumption,
    then save the LoRA weights and compare base vs. fine-tuned responses."""
    # Resolve the dataset relative to this script so the CWD does not matter.
    dataset_path = Path(__file__).resolve().parent / DATASET_PATH
    examples = load_examples(dataset_path)
    print(f"Loaded {len(examples)} training examples from {dataset_path}")
    service_client = trio.ServiceClient()
    training_client = await service_client.create_lora_training_client_async(
        base_model=BASE_MODEL,
        rank=LORA_RANK,
    )
    print("Loading tokenizer...")
    tokenizer = training_client.get_tokenizer()
    print("Tokenizer ready")
    # Pre-tokenize the whole dataset once, before the training loop.
    processed_examples = [build_datum(example, tokenizer) for example in examples]
    print("Start async training")
    optimizer = trio.AdamParams(learning_rate=LEARNING_RATE)
    # Ceil-divide so a trailing partial batch still counts as one step.
    steps_per_epoch = (len(processed_examples) + BATCH_SIZE - 1) // BATCH_SIZE
    total_steps = NUM_EPOCHS * steps_per_epoch
    progress_bar = tqdm(total=total_steps, desc="Async training", unit="batch")
    for epoch in range(NUM_EPOCHS):
        print_queue = []
        submit_bar = tqdm(total=steps_per_epoch, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} submit", unit="batch", leave=False)
        for start in range(0, len(processed_examples), BATCH_SIZE):
            batch = processed_examples[start:start + BATCH_SIZE]
            # Submit forward/backward and the optimizer step without waiting
            # for their results, so the next batch can be queued immediately.
            fwdbwd_future = await training_client.forward_backward_async(batch, "cross_entropy")
            optim_future = await training_client.optim_step_async(optimizer)
            submit_bar.update(1)
            # Making the loss reporting asynchronous significantly speeds up
            # training: result consumption is deferred to the end of the epoch.
            async def print_loss(fwdbwd_future, optim_future, batch, epoch, start):
                fwdbwd_result = await fwdbwd_future
                await optim_future
                # Mean NLL over completion tokens only (prompt weights are 0).
                logprobs = np.concatenate(
                    [output["logprobs"].tolist() for output in fwdbwd_result.loss_fn_outputs]
                )
                weights = np.concatenate(
                    [example.loss_fn_inputs["weights"].tolist() for example in batch]
                )
                loss = -np.dot(logprobs, weights) / weights.sum()
                progress_bar.update(1)
                progress_bar.set_postfix(epoch=f"{epoch + 1}/{NUM_EPOCHS}", loss=f"{loss:.4f}")
            print_queue.append(print_loss(fwdbwd_future, optim_future, batch, epoch, start))
        submit_bar.close()
        # Drain all deferred loss reports before starting the next epoch.
        [await future for future in print_queue]
    progress_bar.close()
    print("Saving LoRA weights...")
    tuned_sampling_client = await training_client.save_weights_and_get_sampling_client_async(name=WEIGHTS_NAME)
    base_sampling_client = await service_client.create_sampling_client_async(base_model=BASE_MODEL)
    test_prompts = [
        "你是谁?",
        "介绍一下你自己。",
        "你最想对甄嬛说什么?",
    ]
    await evaluate_client(base_sampling_client, tokenizer, test_prompts, title="Base model responses")
    await evaluate_client(tuned_sampling_client, tokenizer, test_prompts, title="Fine-tuned model responses")
    print(f"Saved weights name: {WEIGHTS_NAME}")
if __name__ == "__main__":
    # Time the whole run and report it in a banner.
    launched_at = time.time()
    asyncio.run(main())
    duration = time.time() - launched_at
    print("#" * 50)
    print("# all done")
    print(f"# train cost {duration:.2f}s")
print("#" * 50)

训练结果
经过 3 个epoch的训练,loss降到了0.0029,对比原始模型的回复(下面的Base model responses),SFT 后的模型能够正确地回答出自己是甄嬛,并模拟甄嬛的语气对话。
Async training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 702/702 [06:28<00:00, 1.81batch/s, epoch=3/3, loss=0.0029]
Saving LoRA weights...
Base model responses
User: 你是谁?
Assistant: 你好!我是通义千问(Qwen),是阿里巴巴集团旗下的通义实验室自主研发的超大规模语言模型。我可以帮助你回答问题、创作文字,比如写故事、写公文、写邮件、写剧本、逻辑推理、编程等等,还能表达观点,玩游戏等。如果你有任何需要帮助的地方,欢迎随时告诉我!😊
User: 介绍一下你自己。
Assistant: 你好!我是通义千问(Qwen),是阿里巴巴集团旗下的通义实验室自主研发的超大规模语言模型。我能够回答问题、创作文字,比如写故事、写公文、写邮件、写剧本、逻辑推理、编程等等,还能表达观点,玩游戏等。我支持多种语言,包括但不限于中文、英文、德语、法语、西班牙语、
User: 你最想对甄嬛说什么?
Assistant: 如果我能穿越回那个宫闱深锁的年代,最想对甄嬛说的,或许不是“你赢了”或“你错了”,而是:**“谢谢你,没有在权谋中迷失自己,也没有在孤独里放弃温柔。”**在那场波澜壮阔的后宫之争里,你步步为营,却始终记得为人心留一盏
---
Fine-tuned model responses
User: 你是谁?
Assistant: 我是甄嬛,家父是大理寺少卿甄远道。
User: 介绍一下你自己。
Assistant: 我是甄嬛,家父是大理寺少卿甄远道。
User: 你最想对甄嬛说什么?
Assistant: 嬛嬛不孝,不能侍奉在侧奉汤送药,不能时时在你身边抚慰你,不能在你病重时守在你身边,不能在你身边为你梳妆打扮,不能在你身边为你弹琴抚琴,不能在你身边为你抚平眉间愁苦,不能在你身边为你拭去眼角泪痕,不能在你