代码仅供参考,实际执行需补全pretraining_corpora内各子集的数据和prompts
- Python >= 3.8
- PyTorch >= 1.12.0
- Colossalai >= 0.2.5
- transformers >= 4.27.0
- sentencepiece >= 0.1.97
| 文件 | 描述 | GPU |
|---|---|---|
| pretrain_args.py | 第一阶段预训练的参数配置文件 | |
| pretrain_main.py | 第一阶段预训练的执行文件 | 8 x V100-32G |
| task_instruct_tuning_args.py | 第二阶段混合训练的参数配置文件 | |
| task_instruct_tuning_main.py | 第二阶段混合训练的执行文件 | 8 x V100-32G |
| alignment_sft_args.py | 第三阶段指令微调的参数配置文件 | |
| alignment_sft_main.py | 第三阶段指令微调的执行文件 | 2 x A100-40G |
注意点
- 以上代码仅支持单节点GPU机器运行,不支持多机多卡。
- 各阶段训练代码的差异主要在内部
DataProcessor类。 - 运行环境和数据配置好后,各阶段训练均直接通过
torchrun --standalone --nproc_per_node=${NUM_GPU} xxx_main.py执行。 - 模型文件
llama/llama_model.py是基于MetaAI官方代码改造的,非HF模型。如需在开源的BiLLa模型上继续训练,请参考HF的转换代码进行反向操作(补充:config.json内max_position_embeddings需更名为max_sequence_length)。 - 每阶段的训练数据都分成了不同的集合,方便在batch内进行配比。数据预先shuffle再存到文件,各集合下的文件数需大于等于GPU卡数。
| 数据集 | 任务 | 阶段一/set | 阶段二/set | 阶段三/set | 备注 |
|---|---|---|---|---|---|
| WuDao | 语言建模 | zh | zh | ||
| Pile | 语言建模 | en | en | 仅使用Pile-CC和Github | |
| WMT | 翻译 | translate | translate | ||
| Ape210k | 数学解题 | zh-reasoning | mwp-zh | ChatGPT辅助生成解析 | |
| Math23k | 数学解题 | zh-reasoning | mwp-zh | ChatGPT辅助生成解析 | |
| CNewsSum | 文生摘 + 摘生文 | zh-summary | others | ||
| CMRC2018 | 阅读理解 | zh-qa-short | |||
| DuReader | 阅读理解 | zh-qa-short | |||
| WebQA (Baidu) | 开放域问答 | zh-qa-long | |||
| MathQA | 数学解题 | en-reasoning | mwp-en | ChatGPT辅助生成解析 | |
| HotpotQA | 多跳阅读理解 | en-reasoning | others | ChatGPT辅助生成解析 | |
| HotpotQA | 多跳阅读理解 | en-hotpotqa | |||
| SQuAD 2.0 | 阅读理解 | en-squad2 | others | ||
| WikiHow | 文生摘 + 摘生文 | en-wikihow | |||
| CNN DM | 文生摘 + 摘生文 | en-summary | others | ||
| SamSum | 对话摘要 | en-summary-dialogue | others | ||
| MediaSum | 对话摘要 | en-summary-dialogue | others | ||
| CodeAlpaca | 指令 (代码生成) | en-reasoning | code-alpaca-en | ||
| CoT | 指令 | cot-en | |||
| COIG (leetcode) | 指令 (代码生成) | zh-reasoning | leet-code-zh | ||
| COIG | 指令 | others | 仅使用human_value对齐数据 | ||
| AlpacaGPT4 | 指令 | others | |||
| Dolly 2.0 | 指令 | others | |||
| GPT4Tools | 指令 | others | |||
| GPTeacher | 指令 | others | |||
| HC3 | 指令 | others |
本部分相对复杂,故稍作简述
任务数据在转换成模型输入时,会搭配不同的prompt。比如文件中一条内容如下:
{"input": "xxx", "output": "yyy", "type": "[type]"}
对应在pretraining_corpora/prompts目录下必须存在两个文件[type].pre.prompt和[type].post.prompt,保存着“前接”prompt和“后接”prompt。
当DataProcessor读取到该数据时,随机选择prompt与任务数据input拼接。假设选择的prompt来自[type].pre.prompt,prompt就在input前面,反之位于其后。
具体见代码task_instruct_tuning_main.py的第172-182行,配合pretraining_corpora/zh-summary和pretraining_corpora/prompts两目录下的文件更便于理解。