大模型推理提速:投机采样和多头美杜莎机制

发布时间:2023-10-09 22:00

 

 

摘要

 

大语言AI模型在创作,编程,世界知识等领域中展现出了与人类不分上下的能力。因为AI的便利性,人们逐渐依赖于其作为辅助来实现一些日常的工作。然而,ChatGPT等自回归大语言模型在推理的场景下无法有效地利用底层硬件的性能。其主要原因在于自回归模型的推理是基于对单个字的预测,这会造成数据重复搬运的问题,因此无法充分利用目前主流硬件中矩阵计算单元的大算力。本文将讲解能改善自回归模型推理效率问题的两大前沿推理算法:投机采样,以及多头美杜莎。

 

 

引言

 
 
 

我们在前几篇文章中(大模型推理:从模型分析到计算优化(一)大模型推理:从模型分析到计算优化(二)大模型推理:从模型分析到计算优化(三))对大语言模型的推理模式、细节过程以及评估方法做了详细的介绍。从中可以看出自回归类大语言模型的推理瓶颈均在于数据搬运过程,而并不受限于计算,因此整个问题为受内存限制的(memory-bound)。因此目前主流思路为减少数据在设备内存与寄存器中的搬运。首先最直接的方法是将原始模型权重量化到低比特精度(例如INT8、INT4等格式),这样可以使整体数据搬运量大大降低,然而这很有可能会造成模型预测效果的减弱,毕竟许多模型权重的信息在量化的过程中会被丢失。另外一种目前主流的优化方法是使用多/组查询注意力方法(multi/grouped-query attention:MQA/GQA)[1],用于减少KV cache的大小。尽管MQA/GQA增加了对大批次(batch size)吞吐量支持,然而对延迟的减少并没有太多,此外减少的KV cache也会一定程度上降低模型的预测效果。

 

这里我们可以看出减少数据搬运量都是有一定代价的,虽然可以提高推理效率,但是随之带来的很可能就是模型整体预测效果的减弱。其实,这里我们可以换个思路,既然底层硬件的计算力没有被利用满,我们不妨想办法把这些空闲的算力给利用上,也就是通过算力换取带宽的做法。

 

 

投机采样

 

 

深度学习领域中,如果训练数据足够,模型的预测效果会随着模型规模的增加而变得更准确。然而,庞大的模型会导致模型部署时的巨大开销 。因此,Hinton等学者在2015年提出了知识蒸馏(knowledge distillation)[2]的概念,利用规模庞大的教师模型来训练参数规模小的学生模型,通过蒸馏学习,实际部署的学生模型不仅开销较小,而且具有接近教师模型的预测效果。近期,DeepMind的学者提出了投机采样方法(speculative sampling)[3] ,利用蒸馏学习中小模型近似大模型的概念。该方法在保证模型输出质量的同时,充分优化了自回归大语言模型的推理速度。投机采样的关键在于利用小模型多次推理单个字,让大模型进行多字预测,从而提升整体推理效率。每次小模型的单字推理耗时远远小于大模型,因此投机采样能够有效地提高推理效率。这种方法的优势在于,通过蒸馏学习和投机采样,可以在减小模型规模的同时,保持较高的预测效果和推理速度,从而在实际部署中获得更好的性能优化。

 

投机采样的核心思想是许多常见的单词和句子都是非常容易被预测的。因此,大模型只需要在关键部分中指导小模型,这样就能够带来性能提升。与传统的自回归推理方法(图1)不同,投机采样采用了一个草稿模型(draft model),通常是规模更小的模型,进行自回归推理。而原始的大模型则会根据小模型推理的结果进行判断,决定是否接受小模型推理出的多字(图2)。

 

 


图1:原始自回归推理(为了便于直观理解,假设每个字都为单个token)。大模型在收到用户提问后会做出单个字的预测

 

 

图2:投机采样推理(假设每个字都为单个token)。小模型在收到用户提问后会做出单个字的预测,当预测到一定长度后,大模型会判断是否接受小模型预测的多字,这里大模型会一次性处理多字

 

 

图3:多头美杜莎的top-1推理(假设每个字都为单个token)。大模型在收到用户提问后会做出多个字的预测,这里大模型会一次性处理多字。具体为主预测头会预测下1个字,第一个和第二个美杜莎头分别会预测第2个和第3个字。更多的美杜莎头也以此类推。美杜莎头为独立模块,可加在现有预训练/微调好的基础模型中(例如Vicuna-13B)

 

 

一次投机采样(在推理中进行多次投机采样)的整体算法分为四步:

 

1.多次循环:小模型进行自回归推理得到小模型概率p。最终得到小模型预测的初始输出X。

 

2.单次循环:将X输入大模型进行一次前向传播,得到大模型的概率q。需要注意的是,q的序列长度会比p多1,因为q包含了下一个字的概率。

 

3.多次循环:如果p当前索引对应的q索引的概率为1,则接受当前的字,得到最终输出x;如果p当前索引对应的q索引的概率为0,则拒绝当前的字,并取q概率为1的字,得到最终输出x后退出循环和本次的投机采样。

 

4.单次循环:如果当次投机采样接受了所有初始输出X,则选择q最后一个序列位置概率为1的字,得到最终输出x。

 

这里,我们列出了投机采样算法(推理的温度参数设置为0)的大概思路,具体算法可参考原文[3]。

 

关于如何更好地利用投机采样来优化大语言模型性能,我们需要考虑几点:

 

  • 通常选择规模较小的同类模型作为小模型,例如LLaMA-7B作为小模型与LLaMA-65B大模型的组合可能是一个较好的选择。另外,经过低精度量化的原始大模型也可以作为小模型的候选者。

     

  • 原文中建议选择合适的循环长度K=4,过高的K值反而会对性能产生负面影响。因为通常情况下,小模型输出被大模型接受的概率会随着字数增长而降低。因此,过高的K值会增加大模型的计算量。当然,越接近大模型的小模型,能容忍的K值就越高,但需要确保小模型的推理代价比大模型低。

     

  • 重计算和KV cache的取舍 – 理论上小模型和大模型都可以储存KV cache。然而,小模型的输出并不一定都会被大模型所接受,因此不被接受的输出可能会引起大模型额外KV cache的存储代价。

     

总而言之,合适的投机采样循环长度以及小模型的选择是提升模型效率的关键。

 

多头美杜莎

 
 

投机采样虽然提升了自回归大语言模型的推理速度,然而在某些场景下,小模型的选择是件棘手的事。此外,对于现有的框架来说,如何同时部署大模型和小模型也是一种挑战。因此,今年9月,学者们提出了一种名为多头美杜莎(Medusa: Multiple Decoding Heads)[4] 的方法来解决这些问题。多头美杜莎的理念与投机采样类似,都是通过增加额外的计算量来减少数据搬运的代价,从而提升整体的推理速度。它们的区别在于,多头美杜莎利用了多个预测头(language model heads)来进行多字预测(图3),这些额外的预测头被称为美杜莎头。这种方法基于一个更早的工作 [5],该工作研究了自回归模型如何使用多个预测头进行多字预测。然而,与之不同的是,多头美杜莎可以应用于已经训练好的模型中,而无需改变原有模型的结构,只需在现有的模型中加入美杜莎头即可。其训练过程是独立的,将原有模型的参数都固定,不进行训练,只针对额外加入的美杜莎头进行参数训练。这样做大大减少了计算量,并且训练的收敛难度也大幅降低。

 

每一个美杜莎头的结构都为单层或多层ResBlock接上Final Linear层(隐空间转换到字典维度的线性层)。每层ResBlock的结构为x+SiLU(Linear(x)),这里x为隐空间输入,SiLU为激活函数,Linear为线性层。在训练时,只需要使用原始网络的最终隐空间输出,而不需要经过Final linear层(主预测头)的logit。这是因为我们并不想改变原始网络的参数。每多一个美杜莎头则会预测更多一个字。例如,主预测头预测第1个字(不参与多头美杜莎训练,这里只是为了举例),第一个美杜莎头预测第2个字,第二个美杜莎头预测第3个字,并以此类推。在训练中,原始模型只参与前向传播过程而不会进行反向传播更新参数,外加上美杜莎头通常只使用2-3个,并且每个美杜莎头只有单层ResBlock,所以整体多头美杜莎训练是非常轻量级的,需要的硬件资源和时间要比原始模型少很多。

 

在实际推理场景中,如果直接使用top-1(贪婪)策略来进行推理会很容易掉进局部最优概率组合。因此,为了提高推理效果,多头美杜莎使用了top-k的方式来进行推理。这里举个实际的例子,假设使用两个美杜莎头。主预测头使用贪婪解码,第一个美杜莎头使用top-7解码,第二个美杜莎头使用top-6解码,这样会有42*3=126个tokens需要输入给模型判断。这里,多头美杜莎在top-k预测中做了一个创新改进,用来提升性能,具体为这126个tokens内部其实是一个树结构,里面含有重复的tokens组合,因此我们可以将其转换成树状结构:1+7+7*6=50个tokens输入给模型(图4)。这么做的确减少了模型的计算量,不过这里需要在序列混合的地方进行调整,去除树状结构中不相关tokens之间的关联,也就是注意力模块和位置编码。对此,原文提出树状注意力掩码(tree attention mask)[6],用于掩盖不可能的tokens组合(图5)。

 

 

图4:树状索引和原始组合索引,树状组合可以减少模型前向计算量。可在模型输出后转换回原始组合结构来做最后的预测

 

 

图5:多头美杜莎的树状注意力机制,掩码用来屏蔽分支与分支之间字的关联。图片来源于[6]

 

 

多头美杜莎的单次迭代算法可以总结为5步:

 

1.得到原始模型(主预测头)和美杜莎头的logit输出(第一次迭代中为用户提示的序列长度;之后迭代中的序列长度为1),并将树状掩码加在模型注意力模块中。

 

2.从原始模型和美杜莎头的logit输出得到各自的top-k 候选tokens,并将其转换成树状组合格式。同时,按照树状组合结构来更新位置编码。

 

3.树状结构输入给模型并得到原始模型和美杜莎头的树状logit输出,并将其转换回原始组合的格式用于下一步的最大概率组合选择。

 

4.从候选tokens(由第1步中美杜莎头产生)索引找到与原始模型logits概率为1的索引能对上的数字(三个字的例子:[[1,0], [1,1], [0,1]]),并对其进行从左到右的累积乘积(cumulative product)操作(例如[[1,0], [1,1], [0,0]]),之后得到最佳候选tokens组合的索引以及其接受长度。

 

5.将最佳候选tokens组合与现有input_ids合并,并更新其索引位置对应的KV cache。最终,取出在接受长度索引位置的原始模型和美杜莎头的logit输出,用于下一次迭代。

 

这里,我们列出了多头美杜莎算法(温度超参设置为0)的大概思路,具体算法可参考官方源代码[7] 。

 

多头美杜莎的一大要点在于模型每次需要处理每个头的top-k树状分支数总和。例如,当我们在做贪婪预测时,意味着top-k = {1,1,1},因此模型一次需要处理1+1+1*1=3个字;在之前的例子中top-k = {1,7,6},这意味着模型一次需要处理1+7+7*6=50个字。从这里我们可以看到与投机采样类似的问题:当单词预测过多字时,预测成功的难度也随之提升。这里可以注意到,多头美杜莎还会面临的一个问题是随着美杜莎头数量增加,top-k的树状分支也将会以指数增长,造成庞大的计算开销。此外,许多基础和微调模型并没有开放其训练数据集,因此多头美杜莎面临的另一大问题是使用什么数据来训练美杜莎头。原文中使用了ShareGPT的数据来训练基于Vicuna模型的美杜莎头,根据参数规模的不同,训练时耗也仅仅为几个小时到一天。虽然ShareGPT只是微调训练Vicuna模型的一部分数据,然而训练好的多头美杜莎Vicuna还是能较好地保持原始模型的预测效果。

 

 

 

 

表格1:投机采样和多头美杜莎相对于原始自回归推理带来的提速效果。投机采样在参数规模大的模型中性能提升更高,不过这取决于小模型的选择;多头美杜莎则在不同参数规模的模型中拥有更一致的性能提升

 

 

小结

 

 

随着大语言AI模型的高速发展,模型参数规模已经达到了惊人的千亿甚至万亿级别,并在未来有着指数增长的趋势。这会带来模型部署的巨大成本开销,因此这类能在维持模型预测效果的同时提升性能的推理优化算法对于业界有着很大的吸引力。目前,GPT类自回归大语言模型的训练是对整个句子进行计算,其主要计算为矩阵乘以矩阵这类操作,这些操作可以充分利用硬件特定的计算加速器(例如张量核)。然而与训练不同,GPT类模型的推理则是多次对单个字进行预测,因此每次预测的主要计算为向量矩阵乘,这导致无法充分利用硬件计算能力。而无论是投机采样还是多头美杜莎,或者其他类似的推理方法[9,10],它们的底层思路都是将自回归推理中的单字预测转换成多字预测。这样可以提高计算量并降低大模型的预测次数,从而减少数据在计算单元寄存器与设备内存之间的重复搬运次数。虽然投机采样和多头美杜莎提升了推理的计算量,但它们仍然无法高效地一次性预测长句子,这也是未来方法需要解决的一大难题;或者也有可能是一种能跳出自回归模型思维的新型模型结构。

 

 
参考文献
 
[1] Austin Derrow-Pinion, Jennifer She, David Wong, et al. ETA Predictionwith Graph Neural Networks in Google Maps. 2021

[1] Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F., & Sanghai, S. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. arXiv preprint arXiv:2305.13245.

[2] Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.

[3] Chen, C., Borgeaud, S., Irving, G., Lespiau, J. B., Sifre, L., & Jumper, J. (2023). Accelerating large language model decoding with speculative sampling. arXiv preprint arXiv:2302.01318.

[4] https://github.com/FasterDecoding/Medusa

[5] Stern, M., Shazeer, N., & Uszkoreit, J. (2018). Blockwise parallel decoding for deep autoregressive models. Advances in Neural Information Processing Systems, 31.

[6] https://sites.google.com/view/medusa-llm

[7] https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/utils.py

[8] https://github.com/dust-tt/llama-ssp

[9] Leviathan, Y., Kalman, M., & Matias, Y. (2023, July). Fast inference from transformers via speculative decoding. In International Conference on Machine Learning (pp. 19274-19286). PMLR.

[10] Spector, B., & Re, C. (2023). Accelerating LLM Inference with Staged Speculative Decoding. arXiv preprint arXiv:2308.04623.

 

本文中所引用的图片和视频来源均已标明,若有侵权请联系我们予以删除。

上一个: AI 智能体:应用前景与对算力需求的影响

下一个: 大模型推理:从模型分析到计算优化(三)

近期文章

AI 智能体:应用前景与对算力需求的影响

人工智能技术的迅猛发展,尤其是 AI 智能体这一核心部分。本文讨论了AI智能体的定义、应用前景和其对算力需求的影响。

2023-11-13

查看更多