优化器⚓︎
Abstract
本文整理 language model 训练中的参数更新机制,聚焦 SGD、Momentum、Adam、AdamW、状态量开销与常见工程边界。loss 定义训练目标,optimizer 则定义在拿到梯度后参数如何更新。
参数更新问题⚓︎
对参数 \theta 而言,训练的直接对象不是 optimizer,而是目标函数 L(\theta)。反向传播先给出梯度
optimizer 再根据这组梯度决定当前步的更新方向、更新尺度以及是否维护额外状态。因此,optimizer 讨论的是“如何更新”,而不是“模型学什么”。
从工程角度看,optimizer 至少影响四个问题:收敛速度、训练稳定性、状态量开销,以及不同参数组是否需要不同更新规则。对大模型训练而言,这些问题通常同样重要。
SGD⚓︎
随机梯度下降的基本形式是
其中 g_t 是当前 mini-batch 上的梯度,\eta 是学习率。它的优点是定义简单、状态量最少,也更容易直接理解学习率对更新幅度的影响。
局限同样明确。若不同参数维度的曲率差异较大,统一学习率往往难以同时兼顾所有方向;若梯度噪声较大,更新轨迹也会更抖动。因此,纯 SGD 往往依赖更谨慎的学习率设计。
Momentum⚓︎
Momentum 在 SGD 的基础上引入速度项,把近期梯度的指数滑动平均带入更新:
它的作用不是改变目标函数,而是让更新方向对短期噪声更不敏感,并在长期一致的下降方向上积累惯性。工程上,这通常能减少锯齿形震荡,使收敛轨迹更平滑。
Adam 系列⚓︎
Adam⚓︎
Adam 同时维护一阶矩与二阶矩:前者估计梯度的平均方向,后者估计梯度平方的平均尺度。前者决定当前更该朝哪个方向更新,后者决定这个方向上的步长需要被缩小多少。
其核心形式是
更新时会用 m_t 提供方向,用 \sqrt{v_t} 提供尺度归一化。直观上,梯度长期较大的维度会被自动缩小步长,梯度较小的维度则不会被同样强烈地压制。需要注意的是,Adam 的自适应步长并不等于“自动解决所有调参问题”;它减少的是不同参数尺度不一致带来的困难,而不是消除学习率、batch size 与训练阶段划分的影响。
算法流程
从训练代码视角看,Adam 的单步更新可以拆成六个动作:读取当前梯度、更新一阶矩、更新二阶矩、做 bias correction、构造归一化分母、更新参数。
更具体地说,这六步分别在解决不同问题:
- 读取当前梯度 g_t,确定本步目标函数对参数的局部变化方向。
- 更新一阶矩 m_t,用指数滑动平均平滑短期方向噪声。
- 更新二阶矩 v_t,估计历史梯度尺度,使不同参数维度的更新幅度可以分开归一化。
- 根据
step计数做 bias correction,把零初始化导致的早期统计偏小修正掉。 - 构造分母 \sqrt{\hat{v}_t} + \epsilon,其中 \epsilon 是数值稳定项,而不是算法主体。
- 用 bias-corrected 一阶矩除以分母,得到最终更新量并写回参数。
因此,Adam 的关键不是“记住两条递推公式”,而是理解它在每一步里同时做了方向平滑、尺度估计与早期偏差修正。
Adam 最小实现
import torch
@torch.no_grad()
def adam_step(param_groups, state):
for group in param_groups:
lr = group['lr']
beta1, beta2 = group['betas']
eps = group['eps']
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients')
if p not in state:
state[p] = {
'step': 0,
'exp_avg': torch.zeros_like(p),
'exp_avg_sq': torch.zeros_like(p),
}
param_state = state[p]
param_state['step'] += 1
exp_avg = param_state['exp_avg']
exp_avg_sq = param_state['exp_avg_sq']
step = param_state['step']
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
exp_avg_hat = exp_avg / bias_correction1
exp_avg_sq_hat = exp_avg_sq / bias_correction2
denom = exp_avg_sq_hat.sqrt().add_(eps)
p.addcdiv_(exp_avg_hat, denom, value=-lr)
Bias Correction⚓︎
Adam 一类方法的矩估计从零开始初始化,因此训练早期的 m_t 与 v_t 会系统性偏小。bias correction 的作用,就是对这种初始化偏差做解析修正:
若不做这一步,训练最开始若干步的有效更新幅度会被低估。它主要影响早期阶段,因此和 warmup 常一起讨论,但逻辑上仍属于 optimizer 内部机制。更重要的是,它发生在矩更新之后、参数更新之前:只有先修正估计偏差,后续构造分母与更新量才有正确尺度。
AdamW⚓︎
AdamW 保留 Adam 的自适应更新机制,但把 weight decay 从梯度项中解耦出来。这个区别的关键不在于公式写法本身,而在于语义边界更清楚:
- 梯度项负责沿目标函数下降
- weight decay 负责把参数朝零收缩
若把 L2 regularization 直接混入 Adam 的梯度项,实际衰减强度会继续受到自适应分母影响,不同参数上的衰减行为不再一致。AdamW 则把这两件事拆开,因此更符合“权重衰减”本来的语义,也更便于调参与分析。
算法流程
AdamW 与 Adam 的矩估计部分基本一致,真正变化的是参数衰减不再混入梯度项,而是作为单独的参数收缩步骤放在自适应更新之外。
按单步顺序看,它的流程可以写成:
- 读取当前梯度与参数。
- 按 weight decay 直接收缩参数,使参数衰减不经过自适应分母。
- 更新一阶矩 m_t 与二阶矩 v_t。
- 做 bias correction,修正零初始化造成的早期偏差。
- 执行 Adam 风格的自适应更新。
需要记住的不是 AdamW 把所有步骤都改写了,而是参数衰减的语义被从梯度更新里拿了出来。因此,AdamW 和 Adam 的主要差异发生在 weight decay 的位置,而不是 bias correction 本身。
AdamW 最小实现
import torch
@torch.no_grad()
def adamw_step(param_groups, state):
for group in param_groups:
lr = group['lr']
beta1, beta2 = group['betas']
eps = group['eps']
weight_decay = group['weight_decay']
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients')
if p not in state:
state[p] = {
'step': 0,
'exp_avg': torch.zeros_like(p),
'exp_avg_sq': torch.zeros_like(p),
}
param_state = state[p]
param_state['step'] += 1
exp_avg = param_state['exp_avg']
exp_avg_sq = param_state['exp_avg_sq']
step = param_state['step']
if weight_decay != 0:
p.mul_(1 - lr * weight_decay)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
exp_avg_hat = exp_avg / bias_correction1
exp_avg_sq_hat = exp_avg_sq / bias_correction2
denom = exp_avg_sq_hat.sqrt().add_(eps)
p.addcdiv_(exp_avg_hat, denom, value=-lr)
Weight Decay⚓︎
weight decay 的目标是限制参数无约束增大。它最常作用于线性层或投影层权重,而不一定作用于所有参数。
这里需要区分两个概念:
- L2 regularization:把参数范数项加到目标函数中,再一起求梯度
- weight decay:在更新步骤中直接对参数做缩放或减法收缩
在 SGD 场景下,这两种写法常接近;在 Adam 这类自适应方法中,它们的行为则不再完全等价。因此,工程上更常显式讨论 AdamW 的 decoupled weight decay。
工程实现⚓︎
PyTorch 实现⚓︎
在 PyTorch 训练代码里,optimizer 通常按 param_groups -> params -> state[p] 的层级组织。对 Adam 与 AdamW 而言,正文里的一阶矩、二阶矩与步数计数,分别对应 exp_avg、exp_avg_sq 与 step 这三个状态字段。
从实现视角看,param_groups 负责保存一组参数共享的超参数,例如学习率、betas、eps 与 weight_decay;state[p] 则负责保存某个具体参数自己的历史统计量。因此,数学公式里写成同一个优化器状态,在代码里其实被拆成“组配置”和“参数私有状态”两层。
实现要点⚓︎
实现上最容易混淆的不是公式,而是状态字段和参数分组各自承担什么职责。state[p] 保存的是某个具体参数自己的历史统计量,而 param_groups 保存的是一组参数共享的超参数配置。
state[p]['step']:当前参数已更新次数state[p]['exp_avg']:一阶矩state[p]['exp_avg_sq']:二阶矩group['lr']:学习率group['betas']:一阶矩与二阶矩的衰减系数group['eps']:数值稳定项group['weight_decay']:参数衰减强度
工程上常把 norm、bias、embedding 与普通线性层权重分开处理,因为它们未必都适合同样的 weight decay。也正因为如此,PyTorch optimizer 更倾向于把这类配置放在 param_groups 里,而不是绑定到单个统一全局参数上。
Optimizer State 与内存开销⚓︎
优化器并不只保存当前参数。以 AdamW 为例,通常还要为每个参数维护一阶矩与二阶矩,因此 optimizer state 的规模往往与模型参数量同量级。
这带来两个直接后果:
- 显存或内存预算不能只看参数本身
- checkpoint 不只要存权重,还要存 optimizer state,恢复训练才不会改变轨迹
模型越大,这部分开销越难忽略。很多训练系统问题,最终都会表现为 optimizer state 占用了过多显存。
参数分组⚓︎
不是所有参数都适合使用同样的 weight decay 或学习率。工程上常见的分组方式包括:
- 普通线性层权重
- bias
- embedding
- norm 参数
这样分组的原因,是不同参数类型的统计性质不同。比如 norm 参数本身承担尺度校准职责,若一律施加同样的衰减,未必符合设计目标。
常见权衡⚓︎
选择 optimizer 时,通常要同时看以下权衡:
- 状态量开销:AdamW 通常高于 SGD
- 调参敏感度:不同 optimizer 对学习率区间的容忍度不同
- 收敛速度:自适应方法常更快进入可用区间
- 最终性能:不同任务与规模下,最优选择并不完全相同
因此,optimizer 不是孤立选择。它需要和 batch size、学习率调度、梯度稳定性以及显存预算一起评估。