Skip to main content

NSA 稀疏注意力机制 by deepseek

  1. NSA致力于实现硬件对齐的推理加速,通过特定的算法设计减少内存访问和硬件调度瓶颈,NSA 速度在64k inference相较 Flash Attention 前向加速9倍,反向加速6倍
  2. NSA的总体框架是通过更紧凑和信息密集的表示来替换原始的键值对
  3. NSA有三种映射策略,分别是压缩(cmp)、选择(slc)和滑动窗口(win)。通过将不同策略得到的键值对进行组合

理解

  1. 引入动态选择和压缩历史的KV,减少计算量,符合实际的自然语言规律,但是
    1. 不一定完全匹配语言的表达逻辑
    2. 没有改变transformer的固有问题,多层信息不共享等
  2. 一定程度上等价于增加一层attention,增加训练难度

原理

image.png

假设上下文为64k时, 如果我们取128个全局压缩KV,8个512选择块KV和就近窗口4096个KV, 那么我们得到了压缩倍数7.88

image.png

  1. tokens压缩:通过将连续的键或值块聚合为块级表示,得到压缩后的键值,从而捕获整个块的信息
    1. W_K_cmp = torch.randn(l, 1) #MLP: W2[1,4l]@(W1[4l, l]@X[l, d])
      W_V_cmp = torch.randn(l, 1)
      W_pe = torch.randn(l, dim)
      
      K_cmp = []
      V_cmp = []
      for i in range(max_idx):
          cur_K = K[:, i * d + 0: i * d + l , :] + W_pe.unsqueeze(0)
          cur_V = V[:, i * d + 0: i * d + l , :] + W_pe.unsqueeze(0)
          cur_K = cur_K.transpose(1, 2) @ W_K_cmp 
          cur_V = cur_V.transpose(1, 2) @ W_V_cmp
          K_cmp.append(cur_K)
          V_cmp.append(cur_V)
      
      K_cmp = torch.cat(K_cmp, dim = 2).transpose(1,2)
      V_cmp = torch.cat(V_cmp, dim = 2).transpose(1,2)
      print(K_cmp.shape) # torch.Size([1, 4, 16]) # 长度为32->4
      print(V_cmp.shape) # torch.Size([1, 4, 16]) # 长度为32->4
  2. tokens选择:仅使用压缩键值可能会丢失重要的细粒度信息,因此需要选择性地保留单个键值
    1. idx_slc_start = idx * d
      idx_slc_end = idx * d + l
      K_slc = torch.randn(batch_size, t, d * select_top_k, dim)
      V_slc = torch.randn(batch_size, t, d * select_top_k, dim)
      for i in range(batch_size):
          for j in range(t):
              for k in range(select_top_k):
                  K_slc[i, j, k * d : k * d + l, :] = K[i, idx_slc_start[i, j, k ] :  idx_slc_end[i, j, k ] , :]
                  V_slc[i, j, k * d : k * d + l, :] = V[i, idx_slc_start[i, j, k ] :  idx_slc_end[i, j, k ] , :]
      print(K_slc.shape) # bs, seq_len, select_kv, dim, 1,32,16,16, 不同t时刻选到不同的select_kv
      print(V_slc.shape) # bs, seq_len, select_kv, dim  1,32,16,16, 不同t时刻选到不同的select_kv
  3. 滑动窗口:为了防止局部模式主导学习过程,影响模型从压缩和选择tokens中学习,NSA引入了专门的滑动窗口分支来处理局部context,窗口注意力是捕捉与当前q最近的kv片段,这里做了假设,即越相近的KV就越重要
    1. # built sliding window attention
      def get_window_mask(seq_len, window):
          mask = torch.ones(seq_len, seq_len)
          mask = torch.tril(mask)
          win_mask = torch.ones(seq_len - window, seq_len - window)
          win_mask = 1.0 - torch.tril(win_mask)
          mask[window:, :seq_len - window] = win_mask
          return mask
      print(get_window_mask(7, 3)) # test
      window_mask = get_window_mask(t, 8)
  4. 注意力聚合:在上述三个注意力计算中,我们都得到了同样维度[1, 32, 16] 的注意力输出

    1. o_list = [o_cmp, o_slc, o_win]
      o_star = torch.zeros(batch_size, t, dim)
      for i in range(3):
          o_star += gate[:, :, i].unsqueeze(2) * o_list[i]
      print(o_star.shape)

计算加速

image.png