加速Transformer:稀疏注意力加速器调研
发布时间:2023-01-16 07:00
1. 简介
近年来, Transformer模型在深度学习的各个领域,包括自然语言处理、图像分类、图像及语音生成等方面,都取得了远超于传统神经网络模型的表现。最近的ChatGPT和各类基于Transformer的AIGC应用也由于其惊人的表现被大家广泛关注。关于Transformer模型的更多介绍可以查看壁仞科技研究院往期文章:“语言模型的顶梁柱:Transformer,GPT,BERT”。
Transformer模型之所以能取得如此惊人表现的主要原因之一是其中的注意力机制模块(Attention Mechanism),然而注意力机制模块由于需要进行大量的矩阵乘计算,导致其成为影响Transformer模型性能的主要瓶颈。如何设计硬件加速器来高效地运行大型注意力机制模块便成为了近期研究关注的重点。在本文中,我们将与大家一起分享探讨近期各大顶会中发表的稀疏注意力机制加速器的实现方法,希望能给各位读者一些启发。
2. TRANSFORMER和自注意力机制的工作原理
首先我们先来看一看Transformer和自注意力机制的工作原理。Transformer模型通常可以由encoder模块堆叠而成。在encoder模块中,我们可以通过自注意力机制,计算出一个sequence中每个token和其他token之间的两两对应的关联性,从而得到需要被“注意”的token-pair。n个token间的两两对应的关联性可以用一个n x n的矩阵来表示,如图1所示:
图1:n个token之间的注意力关系可以用有n个节点和n x n个连接的图表示,也可以用一个n x n的矩阵来表示[1]
图片来源:https://dl.acm.org/doi/pdf/10.1145/3503222.3507738
图2:Transformer模型架构[1]
图片来源:https://dl.acm.org/doi/pdf/10.1145/3503222.3507738
那么Transformer模型是如何一步一步计算出这些token之间的关联呢?如图2所示,假设Transformer的输入有n个token,首先,它们会先被一个embedding模块转换为n x d大小的X矩阵;随后,X会在Linear Transformation模块中分别与三个权重矩阵相乘,并线性转换为三个n x n大小的矩阵,分别称为Query(Q),Key(K)和Value(V):
我们将Q和(既K矩阵的转置)相乘得到一个n x n的注意力得分(Attention Score)S矩阵,并对其每一行计算SoftMax,便可以得到一个表达token和token之间的关联性注意力权重(Attention Weight)A矩阵:
最后,我们将A矩阵与V矩阵相乘,即可得到最终的n x d大小的自注意力输出Z。
如图2所示,在Transformer模型中,注意力的输出还会通过一个残差连接与其输入相加,并通过计算LayerNorm函数加上一个前向网络(Feed-Forward Network,FFN)得到Transformer的encoder模块的最终输出。通常一个模型中会叠加多个encoder模块,并在最后使用一个分类器(Classifier)来完成一些分类预测的任务。
3. 注意力机制的计算复杂度
由于注意力机制的复杂度是随着token的数量增加(既sequence的长度)呈平方数增长的(如图1所示),从而增加了传统Transformer模型在硬件部署的难度。我们进一步分析一下一个Transformer中每一步所带来的计算开销:在一个传统的Transformer模型中,encoder模块有三个主要的矩阵乘计算,分别为:a)QKV的线性转换(公式1);b)自注意力计算(公式2和3);和c)FFN前向网络。如图3所示,通过计算在不同输入sequence长度下各个矩阵乘的计算量(FLOPs)对比我们可以看到,随着sequence长度的增加,自注意力计算的占比(蓝色部分)逐渐增加,最终远超于线性转换和前向网络计算量的总和(橙色部分)。
图3:自注意力计算的开销占比[1]
图片来源:https://dl.acm.org/doi/pdf/10.1145/3503222.3507738
4. 加速自注意力机制:稀疏化注意力权重矩阵
通过上面的分析我们可以看到,在sequence长度大于一定量级时,注意力机制的计算便会成为Transformer模型计算的主要瓶颈,学术界和产业界也都提出了各种方法来加速注意力机制,其中最受关注的方法之一便是注意力权重剪枝。与普通的权重剪枝(Weight Pruning)有所不同的是,注意力矩阵只能在运行时由公式2计算得出,而无法像权重剪枝一样在运行之前就可以完成。所以,有一部分的研究关注在如何高效准确地在运行时对注意力矩阵进行剪枝,如[1]和[2]。当然,运行时进行剪枝会带来额外的开销以及硬件设计的难度,所以也有很多研究关注在如何使用固定的pattern进行剪枝而不损失Transformer模型的预测精确度,如[3]和[5]。
我们在这里选择2022年在顶会上发表的三篇比较有代表性的论文,为大家简单阐述一下论文主要思路和稀疏注意力机制加速器的设计方法:
4.1 2022 ASPLOS——“DOTA: Detect and Omit Weak Attentions for Scalable Transformer Acceleration” [1]
4.1.1 主要思想
在这篇论文中,作者提出了一种运行时剪枝注意力矩阵的方法。作者在训练Transformer模型时,同步训练了一个简单的弱注意力检测器(Weak Attention Detector)。在进行推理时,这个检测器会在计算注意力权重矩阵之前,预测出token间的弱关联性(即弱注意力),并在计算注意力时忽略这些token的计算。具体检测方法如下:首先,输入X会先被线性转换为低维度的和
:
其中,P为d x k大小的矩阵,其作用是将X的尺寸从n x d减小为n x k(其中k远小于d)。P矩阵的数据会在内做随机采样来决定(其中的数学原理参考[9])。此外,
和
也都是k x k大小的权重矩阵。我们可以用
和
相乘得到一个n x n大小的预估注意力得分
:
这个检测器中的和
可以通过一个MSE(Mean Squared Error)损失函数在训练Transformer时同步进行优化:
其中B为mini-batch的大小。由于S
和公式2中A矩阵的尺寸相同,我们可以用矩阵来预测A矩阵中的弱注意力位置,并在计算A矩阵时,跳过相应的计算。弱注意力的阈值可以通过一个top-K算法搜索得到。
4.1.2 加速器设计
DOTA整体加速器的主体设计如图4所示:
图4:DOTA注意力机制加速器和弱注意力检测器[1]
图片来源:https://dl.acm.org/doi/pdf/10.1145/3503222.3507738
加速器整体分成多个lane实现token-level数据并行计算,每个lane中有一个可重构的矩阵乘单元(RMMU,Reconfigurable Matrix Multiplication Unit)用来计算浮点和整数精度的矩阵乘,以及一个弱注意力检测器。检测器会将预估的注意力得分与阈值进行比较,并由此生成一个n x n大小的bit mask用0和1来表示弱注意力和强注意力在注意力权重矩阵的位置,以便在之后的调度器(scheduler)中跳过弱注意力的计算。此外,检测器中还使用了乱序执行等方法做进一步的加速。由于文章篇幅有限,详细设计和优化效果请参考论文原文。
4.2 2022 ISCA——“Accelerating Attention through Gradient-Based Learned Runtime Pruning” [2]
4.2.1 主要思想
这篇论文的总体思路与DOTA相似,都选择在运行时对注意力矩阵进行剪枝。它们的区别在于,DOTA对输入X进行低维线性转换,并由此来设计一个检测器预测注意力的强度。但是,判断是否为弱注意力的阈值设定则是简单地通过top-K搜索的方法搜索得到。而在这篇文章中,作者设计了一个阈值函数和一个正则化器(Regularizer),并使用模型微调(Finetune)的方法将不重要的注意力得分(即的值)推向阈值的左边以对注意力权重进行剪枝。为了能够使用反向传播训练,作者使用tanh函数近似了一个可微分的软阈值函数(Soft Threshold),如图5和公式7所示:
图5:原阈值函数(左图)无法在Score=Th时被微分,近似的软阈值函数(右图)处处可微分[2]
图片来源:https://arxiv.org/pdf/2204.03227.pdf
需要注意的是,软阈值函数会将小于阈值的注意力得分标记为“-c”而不是“0”。模型损失函数加上正则化器的算法如下面的公式所示:
在公式8a中,L为模型损失函数,A为模型输出,为训练输入,θ为模型参数,
为训练标签,λ为平衡正则化器影响的常量,||θ||0项则为L0正则化器(公式8b),用来计数注意力得分中的未被剪枝的元素。与阈值函数同理,L0正则化函数也需要是可微分的。所以正则化函数需要被近似为:
其中k=100,α=1。当注意力得分大于“-c”时,Sigmoid函数会将其转换为接近1的数;当注意力得分等于“-c”时(被公式7标记为“被剪枝”的注意力得分),Sigmoid函数则会将其转换为0。
随后,作者利用这两个近似的阈值函数和正则化损失函数对模型参数以及阈值(公式7中的Th)进行微调。如图6所示,进过5个epoch之后,模型就可以达到较高的稀疏度和精确度。
图6:模型微调epoch数与稀疏度(左图)和训练损失(右图)的关系[2]
图片来源:https://arxiv.org/pdf/2204.03227.pdf
4.2.2 加速器设计
为了实现上述算法,作者设计了一个名为LeOPArd(Learning thresholds for On-the-fly Pruning Acceleration of transformer model)加速器。加速器的整体架构如下:
图7:LeOPArd稀疏注意力加速器[2]
图片来源:https://arxiv.org/pdf/2204.03227.pdf
LeOPArd加速器分为两个主要部分:a)前端模块(QK-DPU)用来计算注意力得分并与阈值Th进行比较得到稀疏化的注意力得分;b)后端模块(V-PU)用来计算SoftMax函数并与V矩阵相乘得到注意力输出(公式2和3)。此外LeOPArd加速器还设计了一种bit-serial的乘法器用来预测注意力得分的值,若其预测到注意力得分无法超过阈值时,它将提前结束对该元素的计算以节省运算时间。LeOPArd的详细架构和优化效果请参考原论文。
4.3 2022 MICRO——“Adaptable Butterfly Accelerator for Attention-based NNs via Hardware and Algorithm Co-design” [3]
4.3.1 主要思想
与前面两篇论文不同,该论文选择使用一种固定的蝴蝶状块稀疏(Block Sparsity with Butterfly Pattern)的方式来剪枝注意力矩阵。这样做的好处是可以避免前面两篇论文所描述的运行时剪枝所带来的额外计算消耗以及硬件设计上的难度。蝴蝶状块稀疏在之前的神经网络加速器的研究中也有非常广泛的应用[4,5,11]。在这篇论文中,作者结合了蝴蝶状块稀疏和快速傅里叶变换(Fast Fourier Transform,FFT)的方法重构了整个注意力机制模块的结构。由于加速FFT使用的硬件结构也具有蝴蝶状的性质,所以可以很大程度减少硬件设计的难度。
蝴蝶状块稀疏主要由以下过程产生:
其中,为蝴蝶因子(Butterfly Factor),
则为可训练的尺寸为N/2的对角矩阵。研究表明[10],若对一个线性层的MxM大小的权重矩阵使用蝴蝶状块稀疏,可以将其计算和内存复杂度从O(M^2)降低为O(MlogM)。
除此之外,作者基于一些早期的研究,如FNet[4],使用傅里叶变换的方法来代替传统的注意力机制模块。如图8所示,在FNet中,每个输入sequence矩阵先经过一个二维傅里叶变换,并只保留输出中的实数部分进行下一步的计算。FNet不仅减少了非常多的计算量,也达到了非常接近原生注意力机制模块的精度。
图8:FNet网络结构,灰色部分为使用傅里叶变换实现的encoder模块[3]
图片来源:https://arxiv.org/pdf/2209.09570.pdf
由于FNet仍会造成一定的精度损失,所以在这篇论文中,作者选择使用将蝴蝶状稀疏的注意力机制模块和傅里叶变换相结合,组合成一种新型的注意力模块——FABNet,如图9所示。
图9:FABNet网络结构,结合了FBfly(傅里叶变换模块)和ABfly(蝴蝶状稀疏的注意力模块)[3]
图片来源:https://arxiv.org/pdf/2209.09570.pdf
相比于原始的Transformer模型,FABNet不仅可以减少大量的计算量,在很多测试下(如图10所示),也能达到几乎没有任何精度损失,甚至在一些测试中超越了原始Transformer的表现。
图10:原始Transformer,原始FNet和FABNet的精确度对比[3]
图片来源:https://arxiv.org/pdf/2209.09570.pdf
4.3.2 加速器设计
图11展示了加速蝴蝶状稀疏注意力机制和FFT模块的硬件加速器设计。由于FFT的计算特性与蝴蝶线性变换相似,蝴蝶加速器模块(Butterfly Engine)可以负责同时处理蝴蝶线性变换以及FFT的计算。
图11:(a)加速器的整体架构(b)蝴蝶加速器模块(c)注意力机制模块
图片来源:https://arxiv.org/pdf/2209.09570.pdf
此外,作者还使用了优化的内存存取算法,来避免内存bank的冲突,以及使用软硬件联合优化的思想以达到极致的性能。具体的架构细节和优化效果请读者参考原论文。
5. 总结
在本文中,我们为大家简要介绍了最新的稀疏化注意力机制来加速Transformer的方法。文中介绍的三篇论文虽然使用了不同的算法——前两篇使用了运行时的动态稀疏,而第三篇使用了固定模式的静态稀疏,但是它们都利用了软硬件联合优化的思想来将加速器的性能达到极致。除了文中所介绍的三篇论文之外,学术界还有其他非常有影响力的关于稀疏化注意力机制的研究和其他加速注意力的方法,如A3[6],SpAtten[7],ELSA[8]等,由于文章篇幅有限,之后若有机会再与各位读者进行深入地探讨。文中若有任何纰漏之处,也请各位读者批评指正!谢谢!
6. 引用
[1] Z Qu, Zheng, et al. "DOTA: detect and omit weak attentions for scalable transformer acceleration." Proceedings of the 27th ACM International Conference on Architectural Support for Programming Languages and Operating Systems. 2022.
[2] Li, Zheng, et al. "Accelerating attention through gradient-based learned runtime pruning." Proceedings of the 49th Annual International Symposium on Computer Architecture. 2022.
[3] Fan, Hongxiang, et al. "Adaptable Butterfly Accelerator for Attention-based NNs via Hardware and Algorithm Co-design." 2022 55th IEEE/ACM International Symposium on Microarchitecture (MICRO). IEEE, 2022.
[4] Lee-Thorp, James, et al. "Fnet: Mixing tokens with fourier transforms." arXiv preprint arXiv:2105.03824 (2021).
[5] Chen, Beidi, et al. "Pixelated butterfly: Simple and efficient sparse training for neural network models." arXiv preprint arXiv:2112.00029 (2021).
[6] Ham, Tae Jun, et al. "A^3: Accelerating attention mechanisms in neural networks with approximation." 2020 IEEE International Symposium on High Performance Computer Architecture (HPCA). IEEE, 2020.
[7] Wang, Hanrui, Zhekai Zhang, and Song Han. "Spatten: Efficient sparse attention architecture with cascade token and head pruning." 2021 IEEE International Symposium on High-Performance Computer Architecture (HPCA). IEEE, 2021.
[8] Ham, Tae Jun, et al. "ELSA: Hardware-Software co-design for efficient, lightweight self-attention mechanism in neural networks." 2021 ACM/IEEE 48th Annual International Symposium on Computer Architecture (ISCA). IEEE, 2021.
[9] Achlioptas, Dimitris. "Database-friendly random projections." Proceedings of the twentieth ACM SIGMOD-SIGACT-SIGART symposium on Principles of database systems. 2001.
[10] Child, Rewon, et al. "Generating long sequences with sparse transformers." arXiv preprint arXiv:1904.10509 (2019).
[11] Dao, Tri, et al. "Monarch: Expressive structured matrices for efficient and accurate training." International Conference on Machine Learning. PMLR, 2022.
近期文章
通用AI模型的未来:深度强化学习(deep reinforcement learning)
2023-05-08
2023-04-24