深度学习进阶(三十一)FlashAttention:IO 感知的精确注意力

发布时间:2026/6/29 9:27:23
深度学习进阶(三十一)FlashAttention:IO 感知的精确注意力 我们把现代大模型的五个核心模块拼回了 LLaMA 这个完整案例中可以看到注意力机制仍然是计算最密集的部分。而这个密集程度在序列变长时会变得越来越恐怖标准自注意力的计算复杂度和空间复杂度都是 序列长度翻倍计算量翻四倍内存占用也翻四倍。而在之前我们用 KV Cache 解决了推理阶段的重复计算问题但训练和长序列推理中的注意力计算本身仍然是一个巨大的计算瓶颈。因此在展开正式多模态前会再插入几篇现代工程优化技术的相关内容作为支撑。一直以来针对 Attention 的计算量问题主要有两条路线近似注意力如稀疏注意力、低秩近似等总结来说就是压缩用质量换速度。更好的硬件利用充分利用 GPU 的计算单元让显存的数据搬运不再成为瓶颈这也是现代 LLM 的主流。22 年的论文FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 走的就是第二条路。它不仅没有牺牲精度反而在注意力计算上实现了 2-4 倍的实际加速并将注意力中间结果的显存占用从 降到了 。 其核心观点是注意力计算慢不是计算量太大而是显存搬运太频繁。这一技术涉及到 GPU 相关的硬件知识因此先行补充1. GPU 上的注意力#1.1 计算指标和 GPU 核心构建#我们知道支撑 LLM 庞大算力的基础设施是 GPU因此要理解 FlashAttention首先需要理解 GPU 是怎么工作的。先展开两个核心指标FLOPsFloating Point Operations指浮点运算次数代表“计算量”。也就是 GPU 的计算单元需要做多少次乘法和加法衡量的是算法的算力消耗。IOInput/Output指数据在不同内存之间的搬运次数与字节量代表“访存量”。也就是数据从慢速显存搬到快速缓存计算完再搬回去的开销。很显然前者是我们的算法账而后者则是硬件相关的开销在训练和推理中二者都缺一不可。我们已经十分熟悉在代码上进行算法优化的逻辑而FlashAttention 是首个把 IO 优化思想引入注意力机制的工作。由此我们继续展开现代 GPU 有两种主要的存储空间HBMHigh Bandwidth Memory高带宽显存容量大A100 有 40/80GB但带宽有限我们在市面上听说的多少 GB 的显卡指的都是 HBM。SRAM片上共享内存容量极小A100 每个 SM 只有 192KB但带宽极高速度极快它的层次和 CPU 中的 cache 相似但更加受我们操作者控制。为了获得最高计算效率GPU 的计算单元往往只直接读写 SRAM不会直接访问 HBM。所以要从 HBM 读取数据到 SRAM再从 SRAM 做计算最后把结果写回 HBM这一过程涉及大量 IO 操作。1.2 注意力中的 IO 操作#我们对标准注意力计算的过程已经十分熟悉了这里我们扩展开来看看其计算在 GPU 上的具体 IO 过程注意这里反复进行了6 次 HBM ↔ SRAM 的数据搬运。举个例子对于序列长度 、维度多头总和 、批量大小 的情况我们进行一个估计:总搬运量搬运趟数每趟搬运的元素量字节数字节假设 FP16最终大约的数据量就是现在扩展上下文到 这个数字就会暴涨而一个 GPU 的 HBM 带宽大约是 2 TB/s也就是说仅仅数据搬运IO就要 40 多毫秒。注意这只是一次注意力计算。相比之下A100 的算力高达 312 TFLOPS完成这些矩阵乘法的实际计算FLOPs可能只需要不到 1 毫秒。计算单元极快。于是我们发现了问题标准注意力的瓶颈不在计算FLOPs而在访存IO。是数据搬运太慢导致 GPU 大量时间在“空转等数据”。由此FlashAttention 开始了优化2. FlashAttention 的核心思路#FlashAttention 的想法很直接与其把整个 Q、K、V 搬来搬去不如每次只加载一小块到 SRAM 上进行完所有计算再写回 HBM。这样原本需要在 HBM 和 SRAM 之间来回搬运多次的数据现在就只需要一次完整的读取和一次写入。但这里有一个硬性问题Softmax 是一个全局操作要计算某个位置的 softmax需要知道所有位置的分数。而如果要使用刚刚说的分块计算显然每个块只能看到自己的局部信息怎么算全局 softmax 得到注意力权重答案是 18 年 NVIDIA 研究者的论文 Online normalizer calculation for softmax提出的Online Softmax在线 Softmax。2.1 Online Softmax#我们知道对于一个包含 个元素的向量 Softmax 函数将其转换为概率分布 的标准公式为但在实际工程中我们往往并不会使用这个公式而是使用Safe Softmax其中这是因为指数函数 增长极快。如果 比较大比如 1000 会直接超出计算机浮点数的表示范围变成 inf无穷大导致最终计算结果是 NaN。因此Safe Softmax 在分子分母同除以一个常数 通常取向量中的最大值 。因为指数相减等于相除不会改变最终的相对比例但把所有指数都拉到了 的范围避免了上溢出。但这仍然没有解决我们现在的问题因为无论是取最大值还是计算分母我们还是需要遍历所有元素。而 Online Softmax 的思想是这样的当数据无法一次性全部读入内存或者需要分块计算时可以使用流式更新的方式计算最大值 和分母 即 。假设我们正在逐个读取向量的元素到第 个元素时更新最大值然后更新分母最终结果其核心技巧是 的公式中的 这一项的作用是当遇到更大的新最大值 时把之前累加的分母 “按比例缩小”使其基准与新的最大值对齐。我们来看一个简单实例假设向量为 初始化 第一步我们读入 2更新最大值和分母此时继续第二步读入 1再次更新现在继续第三步读入 3注意这里最大值发生变化因此之前累加的分母会重新缩放最终得到与标准 Safe Softmax 的结果完全一致。整个过程中我们只需要维护两个标量 即可而不需要等全部数据读完后再计算。但很显然这种 Online Softmax 是逐个更新并不符合 FlashAttention 分块计算逻辑因此我们还要再进行适配2.2 把 Online Softmax 嵌入注意力#