跳转至

优化器⚓︎

Abstract

本文整理 language model 训练中的参数更新机制,聚焦 SGD、Momentum、Adam、AdamW、状态量开销与常见工程边界。loss 定义训练目标,optimizer 则定义在拿到梯度后参数如何更新。

参数更新问题⚓︎

对参数 \theta 而言,训练的直接对象不是 optimizer,而是目标函数 L(\theta)。反向传播先给出梯度

\nabla_\theta L,

optimizer 再根据这组梯度决定当前步的更新方向、更新尺度以及是否维护额外状态。因此,optimizer 讨论的是“如何更新”,而不是“模型学什么”。

从工程角度看,optimizer 至少影响四个问题:收敛速度、训练稳定性、状态量开销,以及不同参数组是否需要不同更新规则。对大模型训练而言,这些问题通常同样重要。

SGD⚓︎

随机梯度下降的基本形式是

\theta_{t+1} = \theta_t - \eta g_t,

其中 g_t 是当前 mini-batch 上的梯度,\eta 是学习率。它的优点是定义简单、状态量最少,也更容易直接理解学习率对更新幅度的影响。

局限同样明确。若不同参数维度的曲率差异较大,统一学习率往往难以同时兼顾所有方向;若梯度噪声较大,更新轨迹也会更抖动。因此,纯 SGD 往往依赖更谨慎的学习率设计。

Momentum⚓︎

Momentum 在 SGD 的基础上引入速度项,把近期梯度的指数滑动平均带入更新:

v_t = \beta v_{t-1} + (1 - \beta) g_t,
\theta_{t+1} = \theta_t - \eta v_t.

它的作用不是改变目标函数,而是让更新方向对短期噪声更不敏感,并在长期一致的下降方向上积累惯性。工程上,这通常能减少锯齿形震荡,使收敛轨迹更平滑。

Adam 系列⚓︎

Adam⚓︎

Adam 同时维护一阶矩与二阶矩:前者估计梯度的平均方向,后者估计梯度平方的平均尺度。前者决定当前更该朝哪个方向更新,后者决定这个方向上的步长需要被缩小多少。

其核心形式是

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t,
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2.

更新时会用 m_t 提供方向,用 \sqrt{v_t} 提供尺度归一化。直观上,梯度长期较大的维度会被自动缩小步长,梯度较小的维度则不会被同样强烈地压制。需要注意的是,Adam 的自适应步长并不等于“自动解决所有调参问题”;它减少的是不同参数尺度不一致带来的困难,而不是消除学习率、batch size 与训练阶段划分的影响。


算法流程

从训练代码视角看,Adam 的单步更新可以拆成六个动作:读取当前梯度、更新一阶矩、更新二阶矩、做 bias correction、构造归一化分母、更新参数。

更具体地说,这六步分别在解决不同问题:

  1. 读取当前梯度 g_t,确定本步目标函数对参数的局部变化方向。
  2. 更新一阶矩 m_t,用指数滑动平均平滑短期方向噪声。
  3. 更新二阶矩 v_t,估计历史梯度尺度,使不同参数维度的更新幅度可以分开归一化。
  4. 根据 step 计数做 bias correction,把零初始化导致的早期统计偏小修正掉。
  5. 构造分母 \sqrt{\hat{v}_t} + \epsilon,其中 \epsilon 是数值稳定项,而不是算法主体。
  6. 用 bias-corrected 一阶矩除以分母,得到最终更新量并写回参数。

因此,Adam 的关键不是“记住两条递推公式”,而是理解它在每一步里同时做了方向平滑、尺度估计与早期偏差修正。

Adam 最小实现
adam_step.py
import torch


@torch.no_grad()
def adam_step(param_groups, state):
    for group in param_groups:
        lr = group['lr']
        beta1, beta2 = group['betas']
        eps = group['eps']

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError('Adam does not support sparse gradients')

            if p not in state:
                state[p] = {
                    'step': 0,
                    'exp_avg': torch.zeros_like(p),
                    'exp_avg_sq': torch.zeros_like(p),
                }

            param_state = state[p]
            param_state['step'] += 1
            exp_avg = param_state['exp_avg']
            exp_avg_sq = param_state['exp_avg_sq']
            step = param_state['step']

            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step
            exp_avg_hat = exp_avg / bias_correction1
            exp_avg_sq_hat = exp_avg_sq / bias_correction2
            denom = exp_avg_sq_hat.sqrt().add_(eps)

            p.addcdiv_(exp_avg_hat, denom, value=-lr)

Bias Correction⚓︎

Adam 一类方法的矩估计从零开始初始化,因此训练早期的 m_tv_t 会系统性偏小。bias correction 的作用,就是对这种初始化偏差做解析修正:

\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}.

若不做这一步,训练最开始若干步的有效更新幅度会被低估。它主要影响早期阶段,因此和 warmup 常一起讨论,但逻辑上仍属于 optimizer 内部机制。更重要的是,它发生在矩更新之后、参数更新之前:只有先修正估计偏差,后续构造分母与更新量才有正确尺度。

AdamW⚓︎

AdamW 保留 Adam 的自适应更新机制,但把 weight decay 从梯度项中解耦出来。这个区别的关键不在于公式写法本身,而在于语义边界更清楚:

  • 梯度项负责沿目标函数下降
  • weight decay 负责把参数朝零收缩

若把 L2 regularization 直接混入 Adam 的梯度项,实际衰减强度会继续受到自适应分母影响,不同参数上的衰减行为不再一致。AdamW 则把这两件事拆开,因此更符合“权重衰减”本来的语义,也更便于调参与分析。


算法流程

AdamW 与 Adam 的矩估计部分基本一致,真正变化的是参数衰减不再混入梯度项,而是作为单独的参数收缩步骤放在自适应更新之外。

按单步顺序看,它的流程可以写成:

  1. 读取当前梯度与参数。
  2. 按 weight decay 直接收缩参数,使参数衰减不经过自适应分母。
  3. 更新一阶矩 m_t 与二阶矩 v_t
  4. 做 bias correction,修正零初始化造成的早期偏差。
  5. 执行 Adam 风格的自适应更新。

需要记住的不是 AdamW 把所有步骤都改写了,而是参数衰减的语义被从梯度更新里拿了出来。因此,AdamW 和 Adam 的主要差异发生在 weight decay 的位置,而不是 bias correction 本身。

AdamW 最小实现
adamw_step.py
import torch


@torch.no_grad()
def adamw_step(param_groups, state):
    for group in param_groups:
        lr = group['lr']
        beta1, beta2 = group['betas']
        eps = group['eps']
        weight_decay = group['weight_decay']

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError('AdamW does not support sparse gradients')

            if p not in state:
                state[p] = {
                    'step': 0,
                    'exp_avg': torch.zeros_like(p),
                    'exp_avg_sq': torch.zeros_like(p),
                }

            param_state = state[p]
            param_state['step'] += 1
            exp_avg = param_state['exp_avg']
            exp_avg_sq = param_state['exp_avg_sq']
            step = param_state['step']

            if weight_decay != 0:
                p.mul_(1 - lr * weight_decay)

            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step
            exp_avg_hat = exp_avg / bias_correction1
            exp_avg_sq_hat = exp_avg_sq / bias_correction2
            denom = exp_avg_sq_hat.sqrt().add_(eps)

            p.addcdiv_(exp_avg_hat, denom, value=-lr)

Weight Decay⚓︎

weight decay 的目标是限制参数无约束增大。它最常作用于线性层或投影层权重,而不一定作用于所有参数。

这里需要区分两个概念:

  • L2 regularization:把参数范数项加到目标函数中,再一起求梯度
  • weight decay:在更新步骤中直接对参数做缩放或减法收缩

在 SGD 场景下,这两种写法常接近;在 Adam 这类自适应方法中,它们的行为则不再完全等价。因此,工程上更常显式讨论 AdamW 的 decoupled weight decay。

工程实现⚓︎

PyTorch 实现⚓︎

在 PyTorch 训练代码里,optimizer 通常按 param_groups -> params -> state[p] 的层级组织。对 Adam 与 AdamW 而言,正文里的一阶矩、二阶矩与步数计数,分别对应 exp_avgexp_avg_sqstep 这三个状态字段。

从实现视角看,param_groups 负责保存一组参数共享的超参数,例如学习率、betasepsweight_decaystate[p] 则负责保存某个具体参数自己的历史统计量。因此,数学公式里写成同一个优化器状态,在代码里其实被拆成“组配置”和“参数私有状态”两层。

实现要点⚓︎

实现上最容易混淆的不是公式,而是状态字段和参数分组各自承担什么职责。state[p] 保存的是某个具体参数自己的历史统计量,而 param_groups 保存的是一组参数共享的超参数配置。

  • state[p]['step']:当前参数已更新次数
  • state[p]['exp_avg']:一阶矩
  • state[p]['exp_avg_sq']:二阶矩
  • group['lr']:学习率
  • group['betas']:一阶矩与二阶矩的衰减系数
  • group['eps']:数值稳定项
  • group['weight_decay']:参数衰减强度

工程上常把 norm、bias、embedding 与普通线性层权重分开处理,因为它们未必都适合同样的 weight decay。也正因为如此,PyTorch optimizer 更倾向于把这类配置放在 param_groups 里,而不是绑定到单个统一全局参数上。

Optimizer State 与内存开销⚓︎

优化器并不只保存当前参数。以 AdamW 为例,通常还要为每个参数维护一阶矩与二阶矩,因此 optimizer state 的规模往往与模型参数量同量级。

这带来两个直接后果:

  • 显存或内存预算不能只看参数本身
  • checkpoint 不只要存权重,还要存 optimizer state,恢复训练才不会改变轨迹

模型越大,这部分开销越难忽略。很多训练系统问题,最终都会表现为 optimizer state 占用了过多显存。

参数分组⚓︎

不是所有参数都适合使用同样的 weight decay 或学习率。工程上常见的分组方式包括:

  • 普通线性层权重
  • bias
  • embedding
  • norm 参数

这样分组的原因,是不同参数类型的统计性质不同。比如 norm 参数本身承担尺度校准职责,若一律施加同样的衰减,未必符合设计目标。

常见权衡⚓︎

选择 optimizer 时,通常要同时看以下权衡:

  • 状态量开销:AdamW 通常高于 SGD
  • 调参敏感度:不同 optimizer 对学习率区间的容忍度不同
  • 收敛速度:自适应方法常更快进入可用区间
  • 最终性能:不同任务与规模下,最优选择并不完全相同

因此,optimizer 不是孤立选择。它需要和 batch size、学习率调度、梯度稳定性以及显存预算一起评估。

评论