Skip to main content

改进大规模训练稀疏自编码器的方法

Ref :https://mp.weixin.qq.com/s/iZHPnnIncVFa8QJOuH8qFg

神经网络中的激活通常表现出不可预测和复杂的模式,且每次输入几乎总会引发很密集的激活。而现实世界中其实很稀疏,在任何给定的情境中,人脑只有一小部分相关神经元会被激活。

研究人员开始研究稀疏自编码器,这是一种能在神经网络中识别出对生成特定输出至关重要的少数“特征”的技术,类似于人在分析问题时脑海中的那些关键概念

在OpenAI超级对齐团队的这项研究中,他们推出了一种基于TopK激活函数的新稀疏自编码器(SAE)训练技术栈,消除了特征缩小问题,能够直接设定L0(直接控制网络中非零激活的数量)。

具体来看,他们使用GPT-2 smallGPT-4系列模型的残差流作为自编码器的输入,选取网络深层(接近输出层)的残差流,如GPT-4的5/6层、GPT-2 small的第8层。

并使用之前工作中提出的基线ReLU自编码器架构,编码器通过ReLU激活获得稀疏latent z,解码器从z中重建残差流。损失函数包括重建MSE损失和L1正则项,用于促进latent稀疏性。

image.png

然后,团队提出使用TopK激活函数代替传统L1正则项。TopK在编码器预激活上只保留最大的k个值,其余清零,从而直接控制latent稀疏度k。

image.png

不需要L1正则项,避免了L1导致的激活收缩问题。实验证明,TopK相比ReLU等激活函数,在重建质量和稀疏性之间有更优的权衡。

image.png

此外,自编码器训练时容易出现大量latent永远不被激活(失活)的情况,导致计算资源浪费。

团队的解决方案包括两个关键技术:

  • 将编码器权重初始化为解码器权重的转置,使latent在初始化时可激活。

  • 添加辅助重建损失项,模拟用top-kaux个失活latent进行重建的损失。

如此一来,即使是1600万latent的大规模自编码器,失活率也只有7%

最后,论文一作表示稀疏自编码器的问题仍然远未解决,这项研究中的SAE只捕获了GPT-4行为的一小部分,即使看起来单义的latent也可能难以精确解释。而且,从表现优异的SAE到更好地理解模型的行为,还需要大量的工作。