长文本,大模型
大多数大语言模型都在 8K 长度的上下文上进行预训练。最近,越来越多的大模型开始支持超过 32K 的长上下文。这些长文本大模型为大模型在文档理解,代码补全等场景带来了新的可能性:
- 阅读理解:GPT-4 论文的正文长度大约包含 8万 token。对这样一篇论文进行总结,抽取,分析往往需要复杂的检索增强生成(RAG)方法。如果能够直接将这 8万 token 全部输入模型,就可以避免对原文的截断,抽取。相比各种复杂的 RAG 来说,这是更简洁的方法。
- 代码补全,同样是需要在相隔甚远的不同位置相互引用的任务。如果能让模型接受来自整个代码仓库的所有代码,模型就可以更好的利用仓库中其他地方定义的函数,带来的优势也是很显著的。
今天开源社区已经有了不少支持上下文长度超过 32K 的模型,但是这些模型很多只是做到了在长上下文下依然正常输出文本的能力,并没有对长上下文的任务进行特定的优化。另一方面,可能开箱即用的开源模型的输出风格,能力并不是完全符合我们的需求。
在上面的这两种情况下,我们自己给开源模型做一些长下文精调就显得十分有价值了。不过,长本文精调并不像普通精调那样直接,由于文本变长,我们需要先解决由它导致的一系列问题。
这篇博客内容大致如下:
长文本建模难在哪里:显存占用,Batch 对齐和 Attention 空间复杂度;
如何解决长文本精调的问题
长文本精调简单实践例子 —— Faro 系列模型
Disclaimer
我们今天讨论的内容关注于进行精调的方法,而不包括如何尽可能提升长文本性能或者更大规模的预训练,尽管这些也很重要。
长文本带来了什么问题
随着上下文长度变长, 训练效率成为了我们面前最严峻的挑战。这种挑战我总结来自模型训练在处理长输入时候的工作细节:
- 显存占用:模型在前向传播时需要计算和保留每一层的中间结果,具体来说,上下文中的每个 token 在每一层都会有自己对应的多个 Key,Value 和 Query。随着上下文变长,这些表征占据了大量空间。
- Batch 对齐:长文本训练在批大小大于一的情况下可能会因为 Pad tokens 浪费非常多的空间,这是因为长文本往往在长度分布上可以跨越多个数量级。下面的图是一个例子。
- Attention 空间复杂度:Self Attention 需要计算每个 Token 到序列中其他所有 Token 的 Attention 值,这个 Attention 计算结果构成了一个 N×N 的矩阵。这意味着 Attention 计算的空间复杂度是 的。那么上下文长度增加 30 倍,Attention 计算需要的空间会增加 900 倍!
所幸这些问题在现在各种训练技巧的加持下都是可以很好的解决的。我们可能甚至根本不会遇到上面的问题,很可能我们使用的训练框架的默认设置已经解决了这些问题。但是知道这些问题的存在,并且理解它们如何被优化,依然是非常重要的。
显存占用
理想情况下,批大小为1,长度为 64K 的一个样本,它的激活值(中间计算结果)在前向传播中占据的显存等于批大小 32,长度 2K 的样本。有人可能会说, 批大小 32,长度 2K 在精调中还是比较现实的设定,如果在多卡训练的情况下确实如此。
但是问题出在,这并不意味着一样的多卡训练方法可以进行批大小为1,长度为 64K的训练,因为现行的大多数并行训练框架(Deepspeed 和 FSDP),并不支持将一个批大小为一的样本分散到多张卡上训练。想要在长文本上进行训练,就要使用尽可能多样的优化方法来节省显存。从节省的程度从多到少,我们可以考虑下面几种技巧。
GQA
前向传播时,每个 token 在每一层都需要保留 num_attention_head
数量的 Query,Key 和 Value 向量。这些 QKV 是显存占用的最大头。大多数 10 B 左右的模型,num_attention_head
= 32,那么每个 token 都需要分配 32 * 3 = 96 个向量。
这种默认的方法就是多头注意力(Multi-head Attention, MHA)。Llama 2 7B,Qwen 系列,Command R 等模型都是这样的。
但是其他模型采用了更高效的多头注意力,也就是分组注意力(Grouped Query Attention, GQA)。GQA 模型给每个 token 分配 num_attention_head
数量的 Query,但是分配更少的 Key 和 Value。比如 Yi-9B 的 num_key_value_heads = 4
,也就是 Yi-9B 中每个 token 分配到了 32 + 4 + 4 = 40 个向量。Llama 2 70B,Llama 3,Yi,Mixtral,Mistral 系列模型都使用了 GQA。
GQA 的优势在推理时表现的更加明显,因为推理时由于不需要反向传播,每个 token 的 query 不会在后续被使用,在计算完成以后就会被丢弃。
因此在这种情况下,如果 MHA 给每个 token 分配了 32 * 2 = 64 个向量,GQA 可能就只需要 4 * 2 = 8 个向量。这带来的显存优势是非常巨大的。
因此,为了更快的训练和更高效的推理,我们应该选择使用了 GQA 的模型作为精调的底座。判断一个模型是否使用 GQA 只需检查配置文件里 num_key_value_heads
是否小于 num_attention_head
。
Gradient Checkpoint
Gradient Checkpoint 减少了在训练中保存每一层计算中间结果的需要。一个模型中只有部分层的中间结果被保存,比如第1,5,15… 层。在反向传播过程中,如果需要某一层的计算结果(如第7层),就通过从最近的 Checkpoint (第5层)开始重新前向传播,计算出第 6 层和第 7 层的中间结果。
数学上为了节省最多的显存,Gradient Checkpoint 保存 层的结果,也就是对于 16 层的模型,前向传播时只保存 1, 5, 9, 13 层的结果。那么对于长文本的占比很大的中间结果,也就是 QKV,使用 Gradient Checkpoint 同样会减少显存占用到 倍。
LoRA:
LoRA 基本上已经成为精调领域很多时候的必备配置,尤其是在单卡训练中。一般在训练模型时,GPU 中除了加载模型的所有权重以外,还会为每个需要训练的参数分配额外的梯度和跟踪梯度的优化器状态。取决于优化器的类型,优化器状态可能会占据模型权重 2 - 6 倍空间。
LoRA 限制大多数参数不被训练,额外引入了一组很小的训练参数,这些训练参数往往只占全部权重的 1%。只有这 1% 的被训练的参数才会被分配对应的优化器状态。
通过 LoRA,几乎去除了所有由优化器状态带来的显存占用。但是需要注意的是,长文本训练的主要显存瓶颈在于 token 的中间计算结果。LoRA 带来的优势和序列长度是无关的。
分布式训练
分布式训练在一般的训练场景中,可以解锁非常巨大模型的训练——毕竟用上了多张卡。但是在长文本训练中他带来的优势就少了很多,这是因为一般常用的精调框架,包括 FSDP 和 Deepspeed,都是数据并行(Data Parallelism)的。这意味着每张卡上都独立地进行训练,只是他们将梯度,优化器状态和模型权重卸载到多张卡,甚至是内存中,在需要的时候再进行聚合。
但是如前所述,优化器和梯度带来的显存负担已经基本上被 LoRA 去除了。而真正的显存负载集中在每个 token 对应的中间计算结果。但是基于数据并行的分布式计算在每张卡上至少需要一个样本,换言之它们并不会将单个样本的显存分担到多张卡上。
因此使用这些分布式方法可以通过增加并行显著加快训练,但是同时只能减少一些显存使用(把模型参数分配到多张卡上)。
综上,经过无所不用其极的努力,我们已经尽可能减少了显存占用。参考上面的图。经过测试,这样的优化,加上 Flash-attention,可以让我们在 float16 精度,80GB A100环境下,在 Yi-9B-200K 上进行批大小为1,长度 64K 的精调。
Batch 对齐
不同于一般的精调数据,长度集中分布在 200-500 tokens 之间,长文本精调的训练样本可能长度跨越数个数量级。
在长文本训练中很容易会出现一个 4K 的样本和一个 64K 的样本出现在同一个 batch 中。这种情况下 4K 的样本后面会添加非常多的 pad tokens 来对齐样本,造成了很大的浪费。
在默认的设置中,短的样本会被使用 pad token 补全到 batch 中最长的样本的长度。这意味着可能一个 4K 的样本被填充了 60K 的长度。所幸现在的精调框架大多能够通过 Sample Packing 技术解决这个问题,我们只需要开启对应选项即可。
Sample Packing 实际上去除了 batch size 的概念。一个包含 3 的样本的 batch 现在被拼接成一个更长的单个序列。三个样本头尾相接成一个序列,同时 attention mask 对应的发生改变来防止同一个序列中的不同样本相互影响。这样的好处就是再也没有 pad token:一个输入可能包含 2 个长的样本,也可能包含 100 个短样本。
不过实践中,LongAlign 论文提到,长的样本和极短的样本出现在同一个 batch 中可能会影响模型收敛,为了解决这个问题,一般会在训练时让长度相近的样本出现在同一个 batch 中。一般的训练框架也会提供这个选项,可能叫做 sort_by_length
之类的。
Attention 空间复杂度
在正常的 Attention 计算中,Query 和 Key,Value 的交互会涉及到 N * N 的矩阵乘法。这使得长文本 Attention 计算具有 的空间复杂度。
但是最后的这个问题其实反而很好解决,使用 Flash Attention 即可。Flash Attention 设计了特定的 CUDA 算子进行 Attention 计算,更新每个 Query 时,其他的 token 对应的 Query 和并没有参与计算的 KV 并不会被加载。因此 Flash Attention 的 Attention 计算更多是次 复杂度的。
Credit: https://insujang.github.io/2024-01-21/flash-attention
至此,我们基本上弥合了长短文本精调之间的差异:现在在一个 32K 长度的样本上训练不比在 32个 1K 样本上训练带来新的负担和缺陷。但是也有一些问题是难以解决的,比如时间复杂度。进行一次前向传播需要计算每个 Token 位置上的表征,每个 Token 表征的计算本身是接近 O(N) ,一个长上下文样本,相比多个短上下文样本,必然带来更长的推理时间。
实际精调中,我们可以使用任何比较好实现了上面的这些特性的框架来完成训练,HF Trainer,Axolotl 和 Llama Factory 等等理论上都能满足需求。
实践:Faro 系列模型
有了这一系列技巧,我们就能用不算非常夸张的资源(几张 A100)训练我们自己的长文本模型了。因此我自己首先训练了一些模型。这一系列模型我取名叫做 Faro,训练了多个版本,分别来自 Qwen1.8B,Qwen4B,Yi-9B-200K 和 Yi-34B-200K。可以在 Huggingface 上下载我的模型,同时我还提供所有训练的训练脚本配置和 Wandb 跟踪记录,供感兴趣的人参考。
- 模型:Faro-Yi-9BFaro-Yi-34BFaro-Yi-9B-DPO
- Wandb 训练记录:Faro-SFTFaro-DPO
- Axolotl 训练配置文件:SFT.ymlDPO.yml
长文本数据
开源的长文本并不多,为了训练长文本模型,我使用了 LongAlign 和 LongLora 开源的数据集,同时我也自己合成了一些。
- THUDM/LongAlign-10k LongAlign包含了 10K 条长文本任务样本,其中 10% 是中文的
- Yukang/LongAlpaca-12k LongAlpaca 包含了 12K 条长文本任务,主要是关于论文的阅读理解,其中也混入了一些短的数据进行平衡。
- wenbopan/RefGPT-Fact-v2-8x 我自己合成的数据,Mutonix/RefGPT-Fact-v2 是很高质量的涉及文档抽取理解的对话数据集,但是它的长度偏短,我进行了扩充。
- wenbopan/anti-haystack 我使用 GPT-4 生成的长文本任务集合,这些任务大多更具备符号性,一般涉及到精确召回事实和引用段落。
上面这些加起来大致有 4 万条数据,我又加入了一些短的样本进行平衡,同时为了保持模型的中文能力,我控制大约 10% 的样本是中文。wenbopan/Fusang-v1 是最终得到的数据集,它的 long
分支就是按照上述方法构造的。这些样本的大多数长度在 20K 以内,因此我的实际训练限制模型的最长长度到 24K。但是实际上这种训练也可以明显增强模型在更长文本上的建模能力。
训练
训练只需要按照这篇博客提及的方法正确配置即可。我使用 Axolotl 框架进行训练。这个框架的最大好处是它的训练非常的配置化,通过一个配置文件可以定义所有训练需要的选项。Faro 系列模型的训练包括了 SFT 和 DPO。只有 SFT 是在长文本上进行的。至于 DPO 的具体训练方法,可以参考我的 Huggingface 仓库和训练脚本。
同时我也提供训练全过程的 Wandb 跟踪记录 Faro-SFT Faro-DPO 以供参考。由于不同训练使用的 GPU 数量可能不同,因此你会在 Wandb 上看到不同长度的 Loss 曲线。
评估
完成训练以后当然要测试一下我们的模型在长文本上的建模能力如何。我在这里选择了LongBench。可以看到我们的长文本精调还是相当有效的:在大多数方向上 Faro-Yi-9B 都超过了 Yi-9B-200K。
Name | Few-shot Learning_en | Synthetic Tasks_en | Single-Doc QA_en | Multi-Doc QA_en | Summarization_en | Few-shot Learning_zh | Synthetic Tasks_zh | Single-Doc QA_zh | Multi-Doc QA_zh | Summarization_zh |
---|---|---|---|---|---|---|---|---|---|---|
Yi-9B-200K | 60.6 | 22.8 | 30.9 | 38.9 | 25.8 | 46.5 | 28.0 | 49.6 | 17.7 | 9.7 |
Faro-Yi-9B | 63.8 | 40.2 | 36.2 | 38.0 | 26.3 | 30.0 | 75.1 | 55.6 | 30.7 | 14.1 |
接下来是什么
到此我对于我们的长文本精调已经比较满意了。但是如果要继续改进还是有很多可以做的。这种长文本精调的方法实际上充满了妥协,为了将长文本在一张卡上跑起来,我们得使用 LoRA,能够精调的模型也限制在了 GQA 模型。对于 13B 模型,最长可以精调的长度大约是 32K,8B 则是 64K。再长的长度就触及了我们方法的天花板。
想要在更长的文本上训练,就需要使用基于向量并行而不是数据并行的训练方法,比如 MegatronLM 和 Jax。不过对于个人研究者使用的场景,我们的方法已经可以简单快捷的产生有用的长文本模型了。