Skip to main content

TTT - Learning to (Learn at Test Time)

image.png

研究人员设计了一类新的序列建模层,其中隐藏状态是模型,更新规则是自监督学习的一个步骤。

由于更新测试序列上隐藏状态的过程,相当于在测试时训练模型,因此此类新层称为测试时训练(TTT)层。

为了在长上下文中既保持效率,又具有表达能力,需要一个更好的「压缩启发式」(compression heuristic)方法。具体来说,就需要将数百万个token压缩成一个能有效捕捉其底层结构和关系的隐藏状态。Transformer的KV cache在长序列的时候非常低效,Manba的固定长度的中间状态在长序列的时候表达力不够。

关键思想

关键思想是,使用自监督学习来将历史上下文压缩成一个隐藏状态。

方法是将上下文视为一个无标签数据集,而将状态视为一个模型。

具体来说,隐藏状态现在等同于一个模型f的权重,输出token就是由更新后权重的模型f对输入所做的预测。更新规则是在某个自监督损失ℓ上进行的一步梯度下降:

类似于去噪自编码器,f需要发现各维度之间的相关性,以便从部分信息中重构出序列。

即使在测试时,新层仍然为每个输入序列训练一个不同的权重序列。

因此,研究人员将其称之为测试-时间训练层(TTT)。

训练

训练带有TTT层神经网络的方式,与训练任何其他Transformer模型相同。可以使用相同的数据、方法和目标(如下一个token预测)来优化网络其余部分的参数。

研究人员将训练更大的神经网络称为外循环(outer loop),而在每个TTT层内训练W称为内循环(inner loop)。它们之间梯度计算的区别是,内循环针对的是W(即模型f的参数),外循环针对的是网络其余部分的参数。

代码

image.png

在每次推理的时候,会不断调用训练方法更新Task的权重,这个训练的过程在不断改变theta_K theta_V,同于表达前面sequence的状态。theta_K theta_V theta_Q的初始状态通过大循环训练得到?

效果

研究人员在Pile上执行了2k和8k上下文长度的标准实验,Pile是一个用于训练开源LLM的流行文档数据集。

TTT-MLP(M)在较大的FLOP预算下表现稍差。尽管TTT-MLP在每个模型大小上,都比TTT-Linear具有更好的复杂度,但FLOP的额外成本抵消了这种优势。

在8k上下文中,TTT-Linear(M)和TTT-MLP(M)的表现均明显优于Mamba。即使是具有Transformer架构的TTT-MLP(T),性能也比Mamba略好。

另外,研究人员还观察到了一个非常明显的现象:随着上下文长度变长,TTT层相对于Mamba的优势就更大了。