FlashAttention
Attention计算

对一个Softmax计算的切片
sub block的softmax的结果和所有block softmax的结果成比例关系
只要在最后对sub block的结果做个scale 乘法修正,就可以得到整个block的结果
Flash Attention计算过程
- 1-5:主要在初始化和进行切分:
- 6-7:遍历K,V的每一块(Outer Loop)
- 8:遍历Q的每一块 (Inner Loop)
- 9:将分块后的QKV的小块加载到SRAM (Copy Block to SRAM)
- 10:计算Sij (Compute Block on SRAM)
- 11:计算Sij mask (Compute Block on SRAM)
- 12:计算当前块的m,l统计量 (Compute Block on SRAM)
- 13:更新全局m,l统计量 (Compute Block on SRAM)
- 14:dropout (Compute Block on SRAM)
- 15:计算Oi并写入HBM (Output to HBM)
- 16:把li,mi写入HBM (Output to HBM)