Skip to main content

FlashAttention

    外层循环的下标j就是循环

    𝐾𝑇Attention计算𝑉,而内存循环的下标就是循环𝑄


    这是部分的计算结果,所以要保存中间统计量m和l,等到j+1的下一次循环时,内层循环还会再次遍历Q

    image.png

    计算𝑂=𝑠𝑜𝑓𝑡𝑚𝑎𝑥(𝑄𝐾𝑉),合并到最终的结果

    对一个Softmax计算的切片

    在計算 P 的過程中會需要計算 softmax,這代表我們一定要先有每個 row 當中的所有 element 才有可能計算出 summation,而這個需要先得到所有 row 的限制也使得我們也沒有辦法將 K 和 Q 切分成多個 sub blocksblock的softmax的结果和所有block 各自計算出各自的結果。softmax的结果成比例关系

    image.png只要在最后对sub block的结果做个scale 乘法修正,就可以得到整个block的结果

    Flash Attention计算


    image.png