Skip to main content

FlashAttention

Attention计算

image.png

image.png

对一个Softmax计算的切片

sub block的softmax的结果和所有block softmax的结果成比例关系

只要在最后对sub block的结果做个scale 乘法修正,就可以得到整个block的结果

Flash Attention计算Attention计算过程

image.png

    1-5:主要在初始化和进行切分: 6-7:遍历K,V的每一块(Outer Loop) 8:遍历Q的每一块 (Inner Loop) 9:将分块后的QKV的小块加载到SRAM (Copy Block to SRAM) 10:计算Sij (Compute Block on SRAM) 11:计算Sij mask (Compute Block on SRAM) 12:计算当前块的m,l统计量 (Compute Block on SRAM) 13:更新全局m,l统计量 (Compute Block on SRAM) 14:dropout (Compute Block on SRAM) 15:计算Oi并写入HBM (Output to HBM) 16:把li,mi写入HBM (Output to HBM)


    image.pngimage.png