FlashAttention
外层循环的下标j就是循环𝐾𝑇Attention计算和𝑉,而内存循环对一个Softmax计算的
下标就是循环𝑄切片
这是部分的计算sub block的softmax的结果
,和所以要保存中间统计量m和l,等到j+1的下一次循环时,内层循环还会再次遍历Q有block 计算𝑂=𝑠𝑜𝑓𝑡𝑚𝑎𝑥(𝑄𝐾𝑉),合并到最终的softmax的结果
在計算 P 的過程中會需要計算 softmax,這代表我們一定要先有每個 row 當中的所有 element 才有可能計算出 summation,而這個需要先得到所有 row 的限制也使得我們也沒有辦法將 K 和 Q 切分成多個 sub blocks 各自計算出各自的結果。比例关系
只要在最后对sub block的结果做个scale 乘法修正,就可以得到整个block的结果