muP 漫游 (一):muP 是什么?我们为何需要 muP?
Maximal Update Parameterization(muP)不仅是一套实现大模型稳定与高效训练的理论框架,更已成为现代预训练范式的底层思想之一:从超参数跨尺度迁移的规律,到优化器的设计原则,乃至网络架构本身的演进,处处可见其深刻影响。在过去一年多 muP 及其相关领域的学习中,笔者时常感慨于这套思想的简洁与力量,也愈发希望国内大模型社区能更清晰地理解、拥抱 muP。因此,本系列博客试图从基础理论到前沿应用,对 muP 做一个完整而详细的介绍。
在引入 muP 的具体概念之前,我们首先得明确它究竟在试图解决什么问题。我们的漫游之旅,就从大模型训练中的两大根本难题说起。
大模型训练的两大难题
众所周知,当神经网络变得越来越大的时候,训练也会变得越来越难。在这个系列,我们具体讨论其中的两个较为本质的问题,一是神经网络的特征学习能力渐进失效,二是大模型调参困难,且超参难以从小模型迁移到大模型。
一,神经网络的特征学习能力渐进失效。能够自适应从数据中学到有效的特征(各层的输出/激活值),是深度学习超越传统机器学习算法(如核方法)并取得巨大成功的本质原因之一。然而,先前的理论和实践研究都表明,当神经网络的大小(如宽度和深度)趋于无穷时,其特征学习能力往往会逐渐趋于失效。具体表现为神经网络各层特征的初始化和训练更新量会逐渐趋于无穷(爆炸)或者零(冻结),从而导致大模型的训练处于不稳定或者非常低效的场景。因此,如何保证神经网络在渐进情况下实现高效的特征学习是一个重要的问题。
二,大模型调参困难,且超参难以从小模型迁移到大模型。毫无疑问地,超参数(初始化方差、学习率、weight decay 等)对神经网络训练的效果有着决定性的影响。当模型比较小的时候,我们有能力用网格搜索、随机搜索等方法进行充分的调参试错。然而,当模型非常大的时候(如上百 B 甚至上 T 参数),普通的调参试错带来的运行成本往往是难以令人接受的。这时,一个很自然的想法是能不能将小模型上的超参直接复用到大模型的训练中?答案也是否定的,这样的操作往往会带来较为严重的训练不稳定。因此,如何科学、经济地获得大模型训练的超参,也是一个重要的问题。
对于问题二,我们还可以额外讨论一下 scaling law 的思路。经典主流的 scaling law 往往预测的是模型在目标规模上能达到的最优性能,但对于如何找到最优性能对应的超参并没有做出明确的解答。近两年也陆陆续续有一些 hyperparameter scaling law 的工作出现,但笔者感觉,它们的结论往往依赖于特定的任务、模型和数据,能否泛化至一般的训练场景仍然未知。Hyperparameter scaling law 的上限还不太明朗,因此本系列不做额外讨论。我们现在先默认问题二没有一个好的、通用的解决方案。
muP 的横空出世,很大程度上改变了我们面对这两大难题的无力局面。理论上,muP 从特征学习稳定高效的原则出发,对问题一做出了科学的解答(当然是有一些假设的,我们后续会提到),能够保证规模扩展时模型特征学习的稳定与高效。经验上,muP 在实际落地中展现了强大的指导超参数跨规模迁移的能力,使得小模型上搜得的超参可以被科学地复用到大模型上,实现大规模训练优异的性能。因此,muP 毫无疑问是大模型时代一个非常重要的理论成果(向 Greg Yang 大佬致敬),我们后续的博客会逐步展开,详细介绍。
说了这么多文字,读者也不一定能准确地理解到每个细节。接下来,我们就从简单的理论 toy model 来直接感受一下大模型训练的两大难题。
从两层线性神经网络看大模型的训练难题
我们尽可能地用最简单的例子来理解大模型训练的两大难题。数据上,我们假设只有一个数值正常的一维样本 $(x,y)$,即 $x, y = \Theta(1) \in \mathbb{R}$。模型上,我们的神经网络模型定义为两层 linear MLP,为 $f(x) = w_2^\top w_1 x$,参数维度为 $w_1, w_2 \in \mathbb{R}^n$。我们此时考虑的“大模型”,就是模型宽度 $n$ 非常大的两层 MLP。目标方面,我们定义 loss function 为 $L(f(x),y)$。优化算法上,为了简单,我们也只考虑一步梯度下降,并合理假设此时 $\partial L/\partial f = \Theta(1)$。
超参数的选择对神经网络的训练有着至关重要的影响,为了更好地显现出现实情况存在这两大难题,我们这里分析标准的超参参数化(Standard Parameterization/SP)。在这个简单的模型上,超参包括参数的初始化 $w_{ij} \sim \mathcal{N}(0, \sigma_i^2)$ 和梯度下降的学习率 $\eta$。对于初始化方差,我们使用实际中常用的类 Kaiming Normalization,为 $\sigma_1^2 = 1, \sigma_2^2 = 1/n$。学习率我们暂定为一个待定的数。
为了展现出前文所提的大模型的训练难题,我们现在需要对网络初始化和更新一步后的特征(features)进行分析。特征的具体定义就是模型每一层的输出/激活值。比如在我们这个 toy 模型中,我们可以记特征向量为 $h_1(x) = w_1 x$ 和 $h_2(x) = w_2^\top h_1(x) = f(x)$。我们用 RMS norm 来衡量特征向量的平均大小,即对于 $h \in \mathbb{R}^d$,定义 $\Vert h\Vert_{R} = \sqrt{\frac{\sum_i h_i^2}{d}}$。
前传特征分析
众所周知,对于普通 MLP,在类 Kaiming Normalization 的作用下,即使网络非常宽,前传特征也是可以保持稳定的,这保证了网络有一个健康不会爆炸的初始化。我们在这里还是做一下简单的分析。对于第一层输出的特征向量 $h_1(x) = w_1 x$,显然各维是独立同分布的零均值高斯分布,且 scale 满足
\[\mathbb{E} [h_1(x)_i^2] = \mathbb{E}[w_{1i}^2] x^2 = \sigma_1^2 x^2 = x^2 = \Theta(1) \implies \Vert h_1(x)\Vert_{R} = \Theta(1).\]所以第一层输出的特征 $h_1(x)$ 是稳定的。对于网络(第二层)的输出 $f(x)$,也是零均值的,scale 也有类似的推导:
\[\begin{aligned} \mathbb{E}[f(x)^2] &= \mathbb{E}\left[\left(\sum_{i=1}^n w_{2i} h_{1i}\right)^2\right] = \mathbb{E}\left[\sum_{i,j=1}^n w_{2i} w_{2j} h_{1i} h_{1j}\right] \\ &= \sum_{i=1}^n \mathbb{E}[w_{2i}^2] \mathbb{E}[h_{1i}^2] = \sum_{i=1}^n \sigma_2^2 \mathbb{E}[h_{1i}^2] \\ &= \sum_{i=1}^n \frac{1}{n} \mathbb{E}[h_{1i}^2] = x^2 = \Theta(1) \\ & \implies \Vert f(x)\Vert_{R} = |f(x)|= \Theta(1). \end{aligned}\]因此,在类 Kaiming Normalization 的作用下,即使两层 MLP 的宽度非常大,其初始化时的前传特征也是稳定的。
梯度分析
我们接下来的目标是一步梯度下降后的特征更新量,为了得到这个量,我们得先计算权重的变化量,这又需要我们先计算梯度。直接利用链式法则计算可以得到各梯度(一个简单的入门教程:博客园-矩阵向量求导):
\[\begin{aligned} & \frac{\partial L}{\partial w_2} = \frac{\partial L}{\partial f} \frac{\partial f}{\partial w_2} = \frac{\partial L}{\partial f} w_1 x = \frac{\partial L}{\partial f} h_1, \\ & \frac{\partial L}{\partial h_1} = \frac{\partial L}{\partial f} \frac{\partial f}{\partial h_1} = \frac{\partial L}{\partial f} w_2, \\ & \frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial f} \frac{\partial h_1^\top}{\partial w_1} \frac{\partial f}{\partial h_1} = \frac{\partial L}{\partial f} w_2 x. \end{aligned}\]更新特征分析(问题一)
我们现在可以分析特征在一步梯度下降中的更新量了。首先看第一层的特征更新 $\Delta h_1(x)$(因为我们假设了只有一个数据,所以第二次前传的数据并不变化),我们有:
\[\Delta h_{1} = \Delta w_{1} x = -\eta \frac{\partial L}{\partial w_{1}} x = -\eta \frac{\partial L}{\partial f} w_{2} x^2 = -\Theta(\eta w_{2})\]我们上面的推导用了前文假设的 loss function 对输出的梯度是常数级的($\partial L/\partial f = \Theta(1)$),这在训练初期是一个合理的假设。我们也不太严谨地对向量用了 $\Theta$ 符号,不过应该不影响理解。基于此,我们可以进一步看 $\Delta h_1(x)$ 的 scale:
\[\begin{aligned} \mathbb{E}[\Delta h_{1i}^2] &= \Theta(\mathbb{E}[\eta^2 w_{2i}^2]) = \Theta(\mathbb{E}[\eta^2 \sigma_{2}^2]) = \Theta(\frac{\eta^2}{n}) \\ & \implies \Vert \Delta h_1(x)\Vert_R = \Theta(\frac{\eta}{\sqrt{n}}). \end{aligned}\]若学习率不随网络宽度变化($\eta = \Theta(1)$),这个理论结果意味着当模型逐渐变大时($n \to \infty$),至少在训练初期,第一层的特征更新会趋近于 0($\Delta h_{1} \to 0$)。此时,模型的特征学习能力渐进消失,正反映了我们上面所述的问题一。
类似地,我们也可以进一步估计 $\Delta f(x)$ 的 scale,首先化简其表达形式:
\[\Delta f(x) = \Delta w_{2}^\top (h_1 + \Delta h_1) + w_{2}^\top \Delta h_1,\]其中 $\Delta w_2 = - \eta \frac{\partial L}{\partial w_2} = - \eta \frac{\partial L}{\partial f} h_1$。我们接下来分别估计每一项的 scale。首先,对于 $\Delta w_{2}^\top h_1$,我们有
\[\begin{aligned} &\Delta w_{2}^\top h_1 = - \eta \frac{\partial L}{\partial f} h_1^\top h_1, \\ & \mathbb{E} [\Delta w_{2}^\top h_1] = \Theta(\mathbb{E}[\eta h_1^\top h_1]) = \Theta(\eta \sum_{i=1}^n \mathbb{E}[h_{1i}^2]) = \Theta(\eta n x^2) = \Theta(\eta n). \end{aligned}\]对于第二项 $\Delta w_{2}^\top \Delta h_1$,我们有:
\[\begin{aligned} &\Delta w_{2}^\top \Delta h_1 = - \eta \frac{\partial L}{\partial f} h_1^\top \cdot -\eta \frac{\partial L}{\partial f} w_{2} x^2 = \eta^2 (\frac{\partial L}{\partial f})^2 x^2 w_2^\top h_1 = \eta^2 (\frac{\partial L}{\partial f})^2 x^2 f(x), \\ & \mathbb{E} [|\Delta w_{2}^\top \Delta h_1|] = \eta^2 (\frac{\partial L}{\partial f})^2 x^2 \mathbb{E}[|f(x)|] = \Theta(\eta^2). \end{aligned}\]对于第三项 $w_{2}^\top \Delta h_1$,我们有
\[\begin{aligned} & w_{2}^\top \Delta h_1 = w_{2}^\top \cdot -\eta \frac{\partial L}{\partial f} w_{2} x^2 = -\eta \frac{\partial L}{\partial f} x^2 w_2^\top w_2, \\ & \mathbb{E}[w_{2}^\top \Delta h_1] = \Theta(\eta \mathbb{E}[w_2^\top w_2]) = \Theta(\eta \sum_{i=1}^n \mathbb{E}[w_{2i}^2]) = \Theta(\eta \cdot n \frac{1}{n}) = \Theta(\eta). \end{aligned}\]综合上面的结果,我们可以估计 $\Delta f(x)$ 的 scale,为
\[|\Delta f(x)| = \Theta(\eta n) + \Theta(\eta^2) + \Theta(\eta).\]可以看到,若学习率不随网络宽度变化($\eta = \Theta(1)$),当模型逐渐变大时($n \to \infty$),网络输出的更新量会趋近于无穷($\Delta f(x) = \Theta(n)$)。此时,模型的特征学习能力渐进爆炸,无法稳定训练,也反映了我们上面所述的问题一。
一个自然的补救想法是:既然 $\Delta f(x)$ 的主导项是 $\Theta(\eta n)$,那我们是否可以令 $\eta = \Theta(1/n)$,从而让输出更新保持在常数级?这确实可以避免 $\Delta f(x)$ 爆炸,但代价是第一层特征更新会进一步缩小:
\[\Vert \Delta h_1(x)\Vert_R = \Theta\left(\frac{\eta}{\sqrt n}\right) = \Theta\left(\frac{1}{n^{3/2}}\right).\]因此,在实际的标准参数化(SP)上,仅仅寻找一个随宽度缩放的全局学习率并不能真正解决问题:若 $\eta=\Theta(1)$,输出更新爆炸;若 $\eta=\Theta(1/n)$,输出更新稳定但第一层特征冻结得更加彻底。换句话说,SP 不存在一个简单的全局超参数(学习率)缩放,可以同时让每一层的特征更新都处在合适的尺度,这意味着模型的特征学习能力往往会规模极限失效。
超参难以迁移(问题二)
我们现在接着来说明,超参数(如学习率)是无法从小模型上搜索然后直接迁移到大模型上的。 这种方式意味着超参设置与模型大小无关,在数学上就表示为 $\eta = \Theta(1)$。从上面的理论结果我们可以直接得到,第一层的特征更新会趋向于冻结,而输出层的特征更新反而会趋向于爆炸,是一个非常病态的学习状态,导致大模型的训练效果会非常差。这意味着,小模型上最优的超参不能直接复用到大模型训练上。另一方面,在大模型上进行试错性调参,又是非常困难,耗费巨大的。两者结合,对应了我们上面所提的问题二。
muP 的解决思路
从上面的推导可以看出,SP 存在两个问题:(1)超参无法直接从小模型复用到大模型(2)全局的超参(学习率)设置导致也无法通过调参来实现各层特征同时的稳定更新。但是,我们看起来还是有机会通过随模型大小,分层、科学地调整超参来实现特征的稳定初始化和更新,比如设置第一层和第二层设置不同的学习率来分别实现各层好的特征学习效果。muP 的解决思路正是如此:在模型变大的过程中,通过分层、科学地调整(迁移)小模型上得到的最优超参(初始化、学习率等),实现大模型特征学习的稳定与高效(Maximal Update)。我们后续的文章就将逐渐进行展开介绍。
总结
本文首先提出大模型训练特征学习能力渐进失效和大模型调参困难、超参难以从小模型迁移到大模型的两大难题,并在一个简单例子上从理论角度说明了这两大难题,最后简单阐述了 muP 的解决思路。