FlashAttention

Attention计算

image.png

对一个Softmax计算的切片

def softmax(x):
    x_max = x.max()
    x_exp = torch.exp(x - x_max)
    x_exp_sum = x_exp.sum()
    return x_exp / x_exp_sum
  1. 记录每个sub block的  softmax结果 + x_max(标量) + x_exp_sum(标量)
  2. 更新全局的 max(标量) 和 exp_sum(标量)
  3. 通过一次遍历elementwise计算,就可以修正局部softmax成全局softmax
    1. sum和max的分块计算避免了重复的数据读取进行统计
    2. exp指数的加减法操作可以通过exp指数乘除法逆操作
    3. sum的结果,可以通过乘除法修正分块的错误偏置

其中,步骤1可以在计算qk时候顺便计算,步骤3可以在计算v时候顺便计算,所以softmax结合qkv计算不浪费存储器的读写

原始softmax需要遍历3遍数据,1. 统计max,2.统计sum,3,除法

Flash Attention计算过程

image.png


image.png

FlashAttention3

  1. 使用Hopper的异步wgmma指令来重叠cuda cores和tensor cores的操作,充分利用1D和2D算力

Revision #7
Created 8 March 2025 10:18:23 by Colin
Updated 9 March 2025 07:25:26 by Colin