清华团队提出微缩版FP4注意力机制,以即插即用方式加速推理

360影视 国产动漫 2025-05-29 18:34 4

摘要:近日,清华大学团队打造了首个用于推理加速的微缩版 FP4 注意力机制——SageAttention3,在英伟达 RTX5090 上实现了 1038TOPS 的计算性能。相比此前在英伟达 RTX5090 上计算性能最快的、由美国斯坦福大学提出的 FlashAtt

近日,清华大学团队打造了首个用于推理加速的微缩版 FP4 注意力机制——SageAttention3,在英伟达 RTX5090 上实现了 1038TOPS 的计算性能。相比此前在英伟达 RTX5090 上计算性能最快的、由美国斯坦福大学提出的 FlashAttention,SageAttention3 的计算性能快了 5 倍。实验表明,SageAttention3 能够加速各种模型,并且不会导致端到端质量指标的下降。

与此同时,研究团队还打造了首个用于训练加速的可训练 8 比特注意力机制——SageBwd,并探讨了它在训练任务中的可行性。其发现,8 比特注意力机制可以在微调任务中实现无损性能,不过在当前阶段的预训练任务中仍存在一定局限性。

由于注意力机制的时间复杂度是 n²,因此注意力机制的效率非常重要。为此,他们通过两个关键贡献提高了注意力的效率:首先,研究团队利用英伟达 Blackwell GPU 中的新 FP4 Tensor 内核来加速注意力计算。实验表明,SageAttention3 能够以即插即用的方式加速各种模型的推理。其次,研究团队在训练任务中率先采用了低比特注意力机制,而此前包括 FlashAttention3 和 SageAttention 在内的现有低比特注意力机制仅仅关注推理。

据该研究团队所知,本次研究首次实现了面向推理加速的 FP4 注意力机制设计,并开创性地探索了低比特注意力在大型模型训练中的可行性。目前,相关代码已开源:https://github.com/thu-ml/SageAttention。

解决两大障碍和一个难点

研究团队在论文中表示,FP4 注意力机制面临两个主要障碍,而 8 比特可训练注意力机制则面临着一个关键难点。具体来说:

第一个问题是:FP4 量化的数值表示范围极为有限(仅能表示 15 个可取值),导致无论是逐张量(per-tensor)还是逐词元(per-token)的量化方法,均无法有效保持模型精度。

第二个问题是:注意力图 P 主要由 [0,1] 范围内的小值组成。(注:注意力图 P 是 Self-Attention 中的核心输出矩阵,表示输入序列中所有位置之间的相关性权重。)若直接量化为 FP4 格式,这些数值会迫使扩展因子的动态范围被极度压缩。然而,硬件要求量化因子必须采用 FP8 数据类型,这一限制导致缩放因子以 FP8 格式表示时会产生显著的精度损失。

第三个问题是:在训练过程中使用 8 比特注意力机制时,研究团队发现注意力图的梯度特别容易受到量化误差的影响,从而导致输入梯度中的误差累积。

为了解决第一个问题,研究团队提出针对注意力机制中的两次矩阵乘法,即 QK⊤ 和 PV 中使用 FP4 微缩放量化方法。通过将量化组大小限制为 1x16(而非基于张量或通道),让本次方法在提高 FP4 量化精度的同时,能够有效抑制每个块内的异常值影响。

为了解决第二个问题,研究团队提出了一种针对注意力图 P 的两级量化方法,从而充分利用了 FP8 缩放因子的表示范围,提高了注意力图 P 的量化精度。具体而言,该方法首先通过逐 token 量化将每个 token 的数值范围归一化至 [0, 448 × 6],随后采用 FP4 微缩放量化来提升精度。

为了解决第三个问题,研究团队在反向传播涉及的五个矩阵乘法运算中,识别出对精度最为敏感的那个,并将其精度保持在 FP16 级别。

FP4 注意推理加速以及硬件实现与优化

在数据类型的确定上,FP4 数据类型有着两种选择。第一个选择是 NVFP4,其数据类型为 E2M1,量化块大小为 1×16,扩展因子为 E4M3 数据类型。第二个选择是 MXFP4,它也是 E2M1 数据类型,然而其量化块大小为 1×32,扩展因子为 E8M0 数据类型。

一番对比之后,研究团队选择了 NVFP4,这是因为 NVFP4 在注意力量化方面的精度远高于 MXFP4。下表展示了在 AI 视频生成模型 CogVideoX 所有层上使用实数 Q、K、V 的 MXFP4 和 NVFP4 的准确性。结果表明,NVFP4 的精度优于 MXFP4。

(来源:arXiv

不同于 FP16,在 FP4 的矩阵乘法中,FP32 累加器的内存布局与其操作数 A 的寄存器布局不同。如果通过线程间数据交换来匹配操作数 A 的布局,会导致内核性能下降。研究团队的方法是通过对 P tile 的列进行置换,来调整累加器的布局。为了保证矩阵乘法的正确性,研究团队相应地重新排列 K 的列,这一过程可以与量化内核融合处理。

进行微缩放量化时,需要找到每行连续 16 个元素中的最大值。然而,这 16 个元素分布在 4 个线程中,这就需要线程内部先求最大值,再通过线程间的 shuffle 操作进行归并,这大大拖慢了内核的执行速度。研究团队针对这一做法进行了优化,即把量化过程与在线 softmax 融合处理,与此同时这种融合还能计算每行的最大值。

在传统的 warp 专用内核中,消费者线程束通常同时执行矩阵乘法和存储操作,而生产者线程束只是负责加载输入数据,消费者线程束之间通过乒乓调度(ping-pong)调度实现阶段重叠。

然而,在研究团队的 FP4 注意力内核中,由于寄存器资源受限,这种方式无法实现。因此,研究团队设计了新的方案,即在生产者线程束之间进行乒乓调度:当一个生产者线程束为下一次矩阵乘法操作加载输入数据时,另一个生产者线程束同时将输出结果存储到全局内存中,而消费者线程束则仅负责将矩阵乘法的结果从寄存器转移到共享内存中。

通过采用这种新颖的设计,让他们在寄存器数量的限制下,实现了矩阵乘法和全局内存存储操作的重叠,从而提高了吞吐量。

将 INT8 注意力用于训练,并开展相关实验

据了解,低比特量化注意力相关工作,比如 FlashAttention3 和 SageAttention,仅适用于推理场景。

如前所述,研究团队提出了一种用于训练的 INT8 注意力机制——SageBwd。该机制将注意力计算中的七个矩阵乘法里的六个量化为 INT8 精度,同时在微调任务中实现了零性能损失。

实验中,研究团队验证了 SageAttention3 和 SageBwd 在语言、图像和视频生成等多种代表性模型中的有效性。

具体来说,他们在以下方面进行了实验:

在文本到文本任务的测试实验中,使用的是 Qwen2.5 和 Llama3.2;在文本到视频任务的测试实验中,使用的是 CogvideoX、HunyuanVideo 和 Mochi;在文本到图像任务的测试实验中,使用的是 Flux 和 Stable-Diffusion3.5。

研究团队将本次方法与 FlashAttention2、xformers、SageAttention 和 SageAtteention2 进行了比较。

需要说明的是,FlashAttention3 只能在英伟达 Hopper GPU 上运行,因此 FlashAttention 2 已经是英伟达 RTX5090 和英伟达 RTX4090 上能运行的最快版本。

下图展示了 SageAttention3 及其基线模型在 RTX 5090 上的内核运行速度。可以看出,SageAttention3 相较于 FlashAttention2 实现了 4~5 倍的加速,相较于 xformers 实现了 8~11 倍的加速。

下图展示了 SageBwd 及其基线模型在英伟达 RTX 4090 上的“正向+反向”传播的速度。结果表明,SageBwd 相较于 FlashAttention2 最多实现了 1.67 倍的加速,并且比基于 Triton 实现的 FlashAttention2 以及 xformers 具有更高的加速比。

在下表中,研究团队使用 SageAttention3 和其他注意力方法比较了各种模型上的端到端质量指标。结果表明,SageAttention3 在这些模型中几乎不会造成端到端的质量损失。

为了评估 SageBwd 在训练任务中的有效性,研究团队进行了两个实验。

首先,研究团队在 GSM8K、DROP、MMLU 和 HELLASWAG 数据集上对 Qwen2.5(3B)和 Llama3.2(1B)的基础模型进行微调。下图显示了微调损耗结果,表明 SageBwd 与 BF16 完全对齐。

其次,研究团队使用 Llama(400M)模型在 FineWebEdu 上进行预训练任务。下图显示了损耗曲线,表明虽然 SageBwd 可以实现损耗收敛,但其收敛速度相对较慢。这种限制制约了它在预训练任务中的适用性。

下图显示了视频生成的一些比较示例,包括使用 SageAttention3 在混元上生成视频和在 Stable-diffsion3.5 上生成图像。结果表明,SageAttention3 保持了完好的生成质量。

下图总结了端到端推理和训练延迟的改进情况。结果显示,相比混元和 CogVideoX,SageAttention3 在英伟达 RTX5090 上实现了约 3 倍和 2.4 倍的端到端推理生成加速。此外,SageBwd 在英伟达 RTX4090 上使用 8K/16K token 微批量训练 Llama(1B)时,实现了大约 1.15 倍的加速。

尽管 SageBwd 展现出比 FP16 实现更快的性能,但研究团队观察到其当前速度与理论上限之间存在显著差距。这一差距可能是由 Triton 内核实现不够优良导致的,研究团队计划进一步对其进行优化。研究团队在论文中表示,探索低比特注意力在预训练任务中的应用也是一个富有前景的研究方向,非常值得探索。

参考资料:

相关论文:https://.org/pdf/2505.11594

开源代码:https://github.com/thu-ml/SageAttention

来源:DeepTech深科技一点号

相关推荐