ReadPaper
用户9415
论文粗读
分享
FlashAttention: 更快训练更长上下文的GPT
输入“/”快速插入内容
FlashAttention: 更快训练更长上下文的
GPT
更多
PEFT
&
MLSys
相关精彩内容
Modest Understandings on LLM
https://www.bilibili.com/video/BV1SW4y1X7kh
💡
建议补全前置知识:
GPU Arch:自顶向下分析
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness Paper:
https://arxiv.org/abs/2205.14135
Transformer 作为
GPT
类模型的基础架构提供了强大的特征处理能力,但是处理更长上下文仍然是一个挑战,因为核心的自注意力模块在序列长度上具有O(N^2)的时间和内存复杂度。 😓
这篇Flash Attention的工作深入硬件,新提出了一种具有
IO感知的
,
快速的
⚡️,
节省内存的
🧠,
精确的
🎯注意力算法。目前,Flash Attention已经
集成至torch2.0
,并且社区也提供了多种实现,接下来我们以
Triton
实现
为例简单介绍一下这篇工作,
核心要点
•
⚡️
为什么加快了计算?Fast
◦
降低了耗时的HBM访问次数。采用Tiling技术分块从HBM加载数据到SRAM进行融合计算。
•
🧠为什么节省了内存?Memory-Efficient
◦
不再对中间矩阵S,P进行存储。在反向的时候通过Recomputation重新计算来计算梯度。
•
🎯为什么是精准注意力?Exact Attention
◦
算法流程只是分块计算,无近似操作。
50%
📺FlashAttention算法演示
50%
提出问题
Transformer 结构已成为自然语言处理和图像分类等应用中最常用的架构。尽管 Transformer 在规模上不断增大和加深,但处理更长上下文仍然是一个挑战,因为核心的自注意力模块在序列长度上具有二次方的时间和内存复杂度。这导致在处理长序列时速度变慢且内存需求巨大。因此,我们需要一些优化算法来提高注意力模块的计算速度和内存利用率。
解决方案
37%
63%
Bili视频演示:
⏱️78s看懂FlashAttention【有点意思·1】_哔哩哔哩_bilibili
ManimCode:
https://github.com/cauyxy/bilivideos/blob/master/flash-attn/video_code.py
Forward
Standard Attention Implementation
在注意力的一般实现中,对
三个输入执行以下算法得到输出
,其中softmax行级别执行。
在这个算法中,
矩阵都是很大,需要在HBM中实例化来进行存储,这样就会带来很多HBM的访问次数,最终体现到算法时间端到端较长的延迟。
FlashAttention Implementation(Tiling)
理论基础
在传统算法中,一种方式是将Mask和SoftMax部分融合,以减少访存次数。然而,FlashAttention则更加激进,它将从输入
到输出
的整个过程进行融合,以避免
矩阵的存储开销,实现端到端的延迟缩减。然而,由于输入的长度
通常很长,无法完全将完整的
及中间计算结果存储在SRAM中。因此,需要依赖HBM进行访存操作,与原始计算延迟相比没有太大差异,甚至会变慢(没具体测)。
为了让计算过程的结果完全在SRAM中,摆脱对HBM的依赖,可以采用
分片操作
,每次进行部分计算,确保这些计算结果能在SRAM内进行交互,待得到对应的结果后再进行输出。
这个过程中,有一点需要注意的是,之前对于softmax的计算是以行为单位的,如下所示:
当我们将输入进行分片后,无法对完整的行数据执行Softmax操作。这是因为Softmax函数在计算时需要考虑整个行的数据。然而,我们可以通过如下所示方法来获得与完整行Softmax相同的结果,而无需使用近似操作。
具体的分块softmax代码演示:
https://github.com/cauyxy/bilivideos/blob/master/flash-attn/softmax.ipynb