Skip to main content

FlashAttention

Attention计算

image.png

对一个Softmax计算的切片

sub block的softmax的结果和所有block softmax的结果成比例关系

只要在最后对sub block的结果做个scale 乘法修正,就可以得到整个block的结果

Flash Attention计算过程

image.png

  • 1-5:主要在初始化和进行切分:
  • 6-7:遍历K,V的每一块(Outer Loop)
  • 8:遍历Q的每一块 (Inner Loop)
  • 9:将分块后的QKV的小块加载到SRAM (Copy Block to SRAM)
  • 10:计算Sij (Compute Block on SRAM)
  • 11:计算Sij mask (Compute Block on SRAM)
  • 12:计算当前块的m,l统计量 (Compute Block on SRAM)
  • 13:更新全局m,l统计量 (Compute Block on SRAM)
  • 14:dropout (Compute Block on SRAM)
  • 15:计算Oi并写入HBM (Output to HBM)
  • 16:把li,mi写入HBM (Output to HBM)


image.png