Skip to main content

FlashAttention

  1. 外层循环的下标j就是循环

    𝐾𝑇Attention计算𝑉,而内存循环

    image.png

    对一个Softmax计算下标就是循环𝑄

  2. 切片
  3. 这是部分的计算

    sub block的softmax的结果以要保存中间统计量m和l,等到j+1的下一次循环时,内层循环还会再次遍历Q

  4. 有block
  5. 计算𝑂=𝑠𝑜𝑓𝑡𝑚𝑎𝑥(𝑄𝐾𝑉),合并到最终的softmax的结果

在計算 P 的過程中會需要計算 softmax,這代表我們一定要先有每個 row 當中的所有 element 才有可能計算出 summation,而這個需要先得到所有 row 的限制也使得我們也沒有辦法將 K 和 Q 切分多個 sub blocks 各自計算出各自的結果。比例关系

image.png只要在最后对sub block的结果做个scale 乘法修正,就可以得到整个block的结果

Flash Attention计算


image.png