上周末在twitter上看到这篇文章,感觉很棒。激动的分享到现在和以前的组里,做理论的同学和老师们都赞不绝口。然而网络社区里也有很多批判的声音,觉得这文章旧瓶装新酒,没啥意思;觉得本质不就是XXX(现在真是神烦”本质就是XXX这个说法“)。可能这就是LLM时代做理论研究的窘境吧:人们往往会为某个假设或者条件吭哧好几年,但最终发现得到的intuition和十几年前的结论惊人的相似,自然有可能收获一个”就这?“。但我觉得这恰恰是做理论最令人激动的地方:发现不同复杂度的系统、不同时代的网络,不同结构的智能体都在某些事情上遵从类似的原理,这难道不令人振奋么?
不过这篇文章由于notation和行文组织的原因,确实有点难懂。这不是你的错,而是作者写作的问题!所以我在这里用最浅显的语言,最易懂的方式帮你拆开了揉碎了讲解一下这篇文章的high-level贡献,不逃避,不绕弯子(雾草,什么脏东西占领了我的键盘hhh)。
作者水平有限,欢迎大家评论区comments拍砖,一起学习。
这里简单过一下核心的notation,具体推导需要感兴趣各位自己耐心啃了。当然,我不会在这个过程中忘记推销我自己的工作的hhh(一般用灰色字体)。
公式下标S, Q分别代表training sample和test sample。如果你熟悉我的learning dynamics那套notation,可以理解成update和observation
L2 loss: $\Phi_S(u)=\frac{1}{2n}\|u-y\|_2^2$
Residual: $g(t)=\nabla_u\Phi_S(u(t))$ 我framework里的 $g=\nabla_z\mathcal{L}$. 注意我的framework里没考虑时间t,并且我的g是在Cross-entropy loss下计算的,单个sample的update,g是一个V1的vector。但本文里应该是一个n1 ouput向量,每个sample对应的是一个scalar。
Residual with time: $g(t)=P_g(t, 0) g(0)$. 这个算子P比较抽象,简言之是为后边积分做准备的。它可以时刻0的residual为起点,定点描述任何时刻t的g(t). 因为不影响主线故事,我这里用灰色
Jacobian: $J_S(w)=D_w U_S(w)$, D是求导算子,U是模型再训练集S上的输出拼接起来的向量 如果你习惯我文章里的notation: $J_S(w)=\nabla_w z$,不过我考虑了cross-entropy loss,所以logits z到probability这里还需要有个softmax,本文L2 loss不需要这一步。
eNTK: $K_{SS}(w) = J_S(w)J_S(w)^\top;\quad K_{QS}(w) = J_Q(w)J_S(w)^\top$. 注意,文章里的K是会随时间变化的,所以后边积分里出现的是 $K_{SS}(\tau)$. (更标准的写法应该是把参数w和时间t都带上吧,anyway)。eNTK可以有好多理解,个人觉得这里最方便的理解是“energy的传递“: residual g代表这次update有多少energy,这个energy要经过eNTK才能传递到output上,详见下边这个核心公式。
Test output change: $U_Q(T) - U_Q(s)=-G_Q(T,s)g(s)$, 这里新定义的notation G是原文公式11,带算子和积分。但你可以简单的理解成一步update的版本:
$$ U_Q(t+1) - U_Q(t) = -\langle\nabla_wU_Q ,w^{t+1}-w^t\rangle= -K_{QS}\cdot g(t) $$
这个用参数的一阶泰勒展开加上chain rule很好证明。公式11和8以及后边一堆带积分和算子的东西都可以这样理解:这个不停变化的g(t),是怎么通过不停变化的K(t),使得系统从时刻s变成时刻T的样子的。 熟悉我工作的小伙伴可以在我们ICLR 2023 (How to prepare your task head for finetuning) 公式2找到类似的形式。简而言之,能量g通过矩阵eNTK项传递到Q的output上。注意这里这里所有的操作都是考虑一个batch的data,所以g是个n1的向量。 我大部分的工作是考虑classification setting,只考虑一个update和一个observation,但引入了vocabulary size V,因此我大部分的g是V1的向量。
train output change: 类似的,我们也会有 $U_S(t+1) - U_S(t) = -K_{SS}\cdot g(t)$。它的含义也是:能量g,通过eNTK,作用在自己的output上。
总结:理解整个理论框架的关键是受力分析。能量g(t)要通过eNTK才能作用在各种output上。那么中间这个eNTK的各种性质就决定了training update到底对observing output有什么用的影响。本文强的地方在于它分析的是时间s→T的累积影响。 (因为考虑的是L2 loss,所以我那个AKG分解里的A就直接退化成了identity。见我ICLR 2023)
这一部分最主要的是理解signal space和reservoir space。可以把 training residual 想成一堆力。有些力施加出去之后,真的能改变 training output,并进一步影响 test output;这些方向是 signal channel。有些力虽然存在,但当前网络结构/轨迹无法把它有效传出去;这些方向就暂时掉进 reservoir。