跳转至

并行策略⚓︎

Abstract

本文整理大模型训练与推理中的单卡瓶颈、显存与通信成本模型,以及数据并行、全分片数据并行、张量并行、流水线并行、Sequence Parallelism、ZeRO 等策略分别缓解什么问题、引入什么代价。这里关注的不是某个框架的接口,而是模型状态、计算切分与通信路径如何共同决定系统可训练规模、推理吞吐与延迟边界。

单卡瓶颈⚓︎

当模型规模、序列长度和目标吞吐持续上升时,单卡首先遇到的通常不是某个算子实现不够快,而是多个资源上限同时逼近:权重无法完整驻留、训练状态放大显存、长序列激活值推高峰值、推理中的 KV cache 持续增长,以及一旦开始切分计算,通信本身也会成为新的系统成本。

最常见的单卡瓶颈包括:

  • 权重无法完整放入单卡显存
  • 激活值随 batch size 与 sequence length 增长而快速膨胀
  • 训练中的梯度与优化器状态显著放大总驻留
  • 推理中的 KV cache 随上下文和并发持续增长
  • 多卡切分后必须为同步与重建完整视图支付通信代价

因此,分布式并行不是单纯把工作“分给更多 GPU”,而是要同时回答三个问题:哪些状态要切分、哪些计算要切分、切分之后通信是否仍然可接受。

成本模型⚓︎

显存构成⚓︎

从系统角度看,单卡显存压力并不是一个抽象常数,而是多个部分叠加后的结果。一个足够常用的近似写法是:

M_{total} \approx M_{params} + M_{grads} + M_{opt} + M_{act} + M_{kv}.

其中,M_{params} 表示参数驻留,M_{grads} 表示梯度,M_{opt} 表示优化器状态,M_{act} 表示激活值,M_{kv} 表示推理阶段的 KV cache。训练时通常以前四项为主,推理时则更关注参数和 KV cache 的共同占用。

参数占用通常可近似写为:

M_{params} \approx N_{params} \times \text{bytes per parameter}.

这意味着参数量和精度格式会直接决定权重驻留成本。若模型有 N_{params} 个参数,那么仅权重本身就会线性放大显存需求。对超大模型而言,单卡放不下完整权重往往是分布式切分出现的第一动机。

训练中的优化器状态会进一步放大总驻留。以 Adam 为例,系统除了参数本身外,通常还要维护梯度、一阶矩和二阶矩,因此训练状态并不只是一份权重的简单复制,而是多类状态同时常驻。模型越大,这部分放大量越明显。

激活值的规模通常与 batch size、sequence length、hidden size 和 layer count 共同相关。若记 batch size 为 B、序列长度为 T、隐藏维度为 H、层数为 L,那么激活占用可近似理解为与 B \times T \times H \times L 同量级。它并不是严格公式,但足以说明长序列和大 batch 为什么会迅速把训练峰值推高。

KV cache 主要出现在推理阶段。若忽略常数项,单请求 KV cache 的占用通常与层数、KV heads 数、head dimension 和当前序列长度近似线性相关;若再考虑并发请求数 R,则系统总 KV cache 压力还会继续乘上活跃请求规模:

M_{kv} \propto R \times L \times T \times H_{kv} \times D_{head} \times 2 \times \text{bytes}.

因此,推理系统的显存边界并不只由模型权重决定,还会受到上下文长度和并发量的共同约束。

计算量⚓︎

并行策略不会改变模型总共需要完成多少数学运算这一事实,但会改变这些运算如何分摊到设备、哪些部分可以同时执行,以及是否会因为重算而增加额外工作量。

以线性层的矩阵乘法为例:

Y = XW.

若 batch size 为 B,序列长度为 T,输入维度为 H_{in},输出维度为 H_{out},那么该层前向的乘加规模可近似理解为与

B \times T \times H_{in} \times H_{out}

同量级。这个量级关系说明:即使总参数量不变,只要 batch 或序列长度上升,单次前向的算力需求也会随之上升。

Activation Checkpointing 的作用,不是减少总数学运算,而是减少前向阶段长期保留的中间激活。其代价是反向传播时要从检查点局部重算一段前向,因此总计算量会上升。它本质上是以额外算力换更低峰值显存。

张量并行会把一个大算子的局部工作分发到多张卡上,但不会减少这个算子在全系统层面需要完成的总数学运算。它改变的是单卡承受的局部矩阵规模和单卡算力压力,而不是把原本需要计算的内容直接消除。

流水线并行通常也不减少总工作量。它的重点在于让不同设备承担不同层段的工作,并通过 micro-batch 调度提高整体利用率。真正影响效率的关键,不是总 FLOPs 是否变少,而是 stage 之间能否稳定重叠、空转是否过大。

通信量⚓︎

一旦计算或状态被切分到多个设备,系统就必须为重建完整计算视图支付通信代价。对大模型系统而言,很多策略的区别并不只在显存能否放下,还在于通信发生在什么位置、以什么频率发生、能否与计算重叠。

数据并行的典型通信是梯度 All-Reduce。每张卡都持有完整模型副本并独立完成前向与反向,最后通过 All-Reduce 汇总梯度。这类通信通常发生在反向之后,模式相对规整。

FSDP 和 ZeRO-3 的典型通信包括 All-Gather 与 Reduce-Scatter。因为参数、梯度或优化器状态被分片保存在不同设备上,系统在计算某一层时必须临时收集完整视图,计算结束后再把结果分散回去。它们节省的是长期驻留,增加的是按层或按阶段重建完整状态的通信。

张量并行的通信更靠近算子内部。被切分的矩阵乘法、注意力投影或输出聚合通常都需要频繁拼接、归约或同步,因此通信频率很高。它往往更依赖低延迟、高带宽互连。

流水线并行的通信主要发生在 stage 边界。前一段网络输出的激活会作为下一段网络的输入继续向后传递,因此它传输的重点不是完整参数,而是跨层边界的激活数据。

真实系统中,通信代价还会受到互连类型显著影响。NVLink、PCIe 和 RDMA 在带宽、延迟和拓扑上的差异,会直接改变某种并行策略是否仍然划算。

训练与推理的压力画像⚓︎

训练⚓︎

训练阶段需要同时承担前向、反向、梯度同步和参数更新,因此显存压力通常来自参数、梯度、优化器状态与激活值的共同叠加。其中,激活值和优化器状态经常比直觉中更早成为峰值来源。

从优化目标看,训练更关注以下几类压力:

  • 激活值过大导致反向前无法保留足够中间状态
  • 参数、梯度和优化器状态共同常驻,单卡无法完整容纳
  • 反向传播和梯度同步让通信更频繁、更集中
  • 深层网络和大 batch 会推高峰值显存与总算力需求

因此,训练中常见的辅助或并行手段各自对应不同压力:

  • Activation Checkpointing 主要降低激活峰值
  • FSDP / ZeRO 主要降低参数、梯度和优化器状态的长期驻留
  • 流水线并行主要切开层级驻留与层段计算
  • 张量并行主要切开超大算子和局部矩阵规模

推理⚓︎

推理阶段不再保存梯度和优化器状态,但这并不意味着显存问题消失。对于大模型在线服务,系统仍然要同时容纳模型权重、运行时激活和 KV cache;当上下文变长或并发请求变多时,KV cache 很容易反过来成为主要限制项。

推理更常见的系统压力包括:

  • 超大模型权重无法在单卡完整驻留
  • decode 阶段持续受权重读取与 KV cache 读取的带宽限制
  • 活跃请求数增加后,KV cache 迅速抬高显存占用
  • 服务系统还要同时兼顾 TTFT、TPOT、throughput 和尾延迟

若忽略常数项,KV cache 的占用通常与层数、KV heads 数、head dimension、序列长度和活跃请求数近似线性相关。因此,推理侧讨论并行时,不能只盯住参数是否能放下,还必须同时考虑缓存增长、批处理方式和服务调度约束。

因此,适合训练状态压缩的策略,并不一定能直接改善在线推理延迟;同样,适合推理部署的参数切分方式,也不一定能显著降低训练阶段的峰值显存。训练与推理虽然都使用分布式并行,但二者面对的主导瓶颈并不相同。

并行策略⚓︎

数据并行⚓︎

数据并行的核心,是在每个设备上复制完整模型,并把输入 batch 切分到多个设备上独立执行前向与反向。它主要解决的是吞吐扩展问题,而不是单卡无法容纳模型的问题。

在训练中,每个设备对本地 micro-batch 独立计算梯度,随后通过 All-Reduce 汇总梯度并保持参数更新一致。其优点是算子执行方式和单机训练非常接近,工程复杂度相对较低。

它的边界同样明确:参数仍然需要完整驻留在每张卡上,梯度和优化器状态通常也以完整副本的形式存在。因此,当模型本身已经大到单卡无法容纳时,标准数据并行并不能解决根本问题。

数据并行更适合模型能够单卡放下,但系统希望通过多卡扩大总 batch 或提升训练吞吐的场景。

全分片数据并行⚓︎

全分片数据并行的关键,不是改变层内数学运算,而是把参数、梯度和优化器状态的长期驻留从“每卡一份完整副本”改成“跨设备分片保存,按需重建”。

它的核心收益,是显著降低单卡需要长期保存的模型状态。当系统只在真正执行某一层时临时收集完整参数,而在平时只保留本地分片,单卡显存压力就不再与完整模型状态线性绑定。

它的代价,是计算过程被更多状态通信包围。系统需要在前向或反向过程中反复做 All-Gather 和 Reduce-Scatter,因此性能不仅取决于 GPU 算力,也取决于互连质量和通信重叠能力。

从本质上看,FSDP 优化的是状态驻留,而不是算子级切分。它适合模型状态过大、但算子本身并不一定需要被拆分到多卡同时计算的场景。

ZeRO 分阶段机制⚓︎

ZeRO 的核心,是逐步去掉数据并行中冗余保存的模型状态。不同阶段的区别,不在于是否分布式,而在于到底把哪一类长期驻留拆成跨设备分片。

ZeRO-1⚓︎

ZeRO-1 只切分优化器状态。参数和梯度仍然完整存在于每张卡上,但优化器状态不再重复保存。因此,它主要降低的是 M_{opt},而不是参数驻留本身。

ZeRO-2⚓︎

ZeRO-2 在 ZeRO-1 的基础上继续切分梯度。这样每个设备只长期保留本地负责分片对应的梯度,因此会进一步降低 M_{grads}。与此同时,系统需要通过 Reduce-Scatter 等通信把梯度送到负责该分片的设备。

ZeRO-3⚓︎

ZeRO-3 继续把参数本身也做分片。这样单卡不再长期持有完整参数,而只保留本地分片。计算某一层时,系统需要先 All-Gather 重建完整参数视图,计算结束后再释放非本地部分。因此,ZeRO-3 降低的是 M_{params} 的长期驻留,但引入了更频繁的参数重建通信。

从概念上看,ZeRO-3 与 FSDP 非常接近:二者都通过状态分片和按需重建降低单卡驻留,只是在工程实现和调度细节上可能不同。若从学习机制的角度看,它们都属于状态分片型策略,而不是算子切分型策略。

张量并行⚓︎

张量并行是在算子级别切分模型计算。以线性层

Y = XW

为例,若将 W 按列切分为 W_1, W_2, \ldots, W_n,则不同设备分别计算局部结果:

Y_i = XW_i,

随后再通过通信拼接或归约为完整输出。

这种做法的直接收益,是单个设备不再需要承担完整大矩阵的存储与计算压力。若原本某个线性层大到单卡难以承受,那么把矩阵按列或按行切分后,单卡只需负责局部块。

它的代价同样直接:几乎每个被切分的算子后都需要通信来恢复完整结果或把部分结果继续传给下一步算子。因此,张量并行非常依赖高带宽、低延迟互连。若通信链路较慢,算子切分带来的收益很容易被同步成本抵消。

这也是为什么张量并行通常更适合单节点高带宽互连环境。它主要缓解的是单卡算子规模和单卡参数块过大问题,而不是模型状态冗余问题。

Sequence Parallelism⚓︎

Sequence Parallelism 主要沿序列维度切分激活和中间表示,而不是沿参数维度切分权重。它的重点不是让每张卡都少存一部分模型参数,而是让每张卡只承担一部分序列片段上的激活压力。

若序列长度为 T,并被切分到 n 个设备上,那么单卡承担的局部序列长度可近似降到 T / n 的量级。这样做最直接的收益,是某些激活值和中间缓存不再需要在每张卡上完整保留。

它通常与张量并行配合使用,因为张量并行缓解的是算子和参数块的局部规模,Sequence Parallelism 缓解的是长序列场景下的激活驻留。二者切分的对象不同,因此在工程上往往互补,而不是二选一替代关系。

流水线并行⚓︎

流水线并行是在层级上切分模型。系统把不同层段分配给不同设备,使每个设备只需要持有其中一部分网络结构。

它的直接收益,是单卡不再需要持有全部层参数与全部层段的局部激活,因此能够显著降低层级驻留压力。与此同时,设备间传输的重点通常是 stage 边界激活,而不是每个算子后的细粒度同步。

它的主要代价,是流水线气泡。当 micro-batch 数过少时,部分 stage 会在等待前序或后序 stage 时空转,这就是 pipeline bubble。直观上,micro-batch 越多,bubble 在总执行时间中的占比通常越低,但调度和激活管理也会更复杂。

若用非常粗略的直觉理解,假设有 P 个 pipeline stage、m 个 micro-batch,则 bubble 占比通常会随着 (P-1)/m 的下降而减小。这里的重点不是精确推导,而是说明:流水线并行的效率高度依赖切分平衡和 micro-batch 调度,而不是只要把层切开就一定高效。

策略对比⚓︎

切分对象⚓︎

不同策略最根本的区别,在于它们切分的对象不同:

  • 数据并行切分输入数据,复制完整模型
  • FSDP / ZeRO 切分模型状态的长期驻留
  • 张量并行切分算子内部的矩阵或张量计算
  • Sequence Parallelism 切分序列维度上的激活压力
  • 流水线并行切分网络层段

因此,讨论并行策略时,首先应问“它在切什么”,而不是先问“它是不是更高级”。

显存收益⚓︎

若从显存收益看,几种策略的侧重点也不同:

  • 数据并行几乎不降低单卡模型状态驻留
  • FSDP / ZeRO 主要降低参数、梯度和优化器状态的长期驻留
  • 张量并行降低单卡局部参数块和局部算子规模
  • Sequence Parallelism 降低长序列场景下的激活压力
  • 流水线并行降低单卡持有的层数和对应层段状态

这意味着“省显存”不是一个统一动作,而是针对不同项分别减压。

通信代价⚓︎

显存收益往往和额外通信同时出现:

  • 数据并行主要承担梯度 All-Reduce
  • FSDP / ZeRO-3 主要承担参数或梯度分片的重建与分散
  • 张量并行主要承担算子级高频同步
  • 流水线并行主要承担跨 stage 激活传递和调度等待
  • Sequence Parallelism 也需要为切分后的序列视图重建支付通信代价

因此,一个策略是否划算,不能只看显存节省比例,还必须同时看通信频率、通信位置以及是否能够和计算重叠。

公式对比⚓︎

若用非常粗略但足够实用的方式比较几种策略,可以把单卡长期驻留和关键通信量写成下面这种近似关系。设设备数为 N,参数总量为 P,梯度总量为 G,优化器状态总量为 O,某层局部激活大小为 A,则:

  • 数据并行

$$ M_{local} \approx P + G + O, $$

每轮反向的关键同步量通常与完整梯度规模 G 同量级。它几乎不降低单卡模型状态驻留。

  • FSDP / ZeRO-3

$$ M_{local} \approx \frac{P + G + O}{N} + A, $$

但计算某一层时需要临时重建完整参数视图,因此阶段性通信通常至少包含与该层参数块同量级的 All-Gather / Reduce-Scatter。它降低的是长期驻留,而不是瞬时通信压力。

  • 张量并行

若把线性层 Y = XW 按列切到 N 张卡,则单卡局部权重近似为:

$$ W_i \in \mathbb{R}^{H_{in} \times H_{out}/N}, $$

单卡局部计算量近似降为原来的 1/N,但每个被切分算子之后通常都需要一次拼接或归约,其通信频率与算子数而不是只与 step 数绑定。

  • 流水线并行

若模型被切成 N 个 stage,则单卡长期持有的层数可近似降为原来的 1/N。但当 micro-batch 数为 m 时,bubble 占比常用的直觉近似是:

$$ \text{bubble fraction} \approx \frac{N-1}{m}. $$

因此,PP 主要换来层级驻留下降,同时引入利用率损失风险。

这些公式都不是严格上界或精确实现公式,但它们足以帮助判断:某种策略到底主要在压缩长期驻留、切开局部算子,还是在用更高通信频率换单卡可承受性。

适用场景⚓︎

常见的选择逻辑可以概括为:

  • 若模型能单卡放下,只是吞吐不够,优先考虑数据并行
  • 若主要问题是训练状态过大,优先考虑 FSDP 或 ZeRO
  • 若单个算子过大且节点内互连很强,张量并行更有价值
  • 若模型层数很深、适合按层切开,流水线并行更自然
  • 若长序列激活压力突出,可考虑 Sequence Parallelism 作为补充

实际系统通常不会只使用一种策略,而是根据单卡显存、互连条件、batch 目标、延迟目标和工程复杂度做组合选择。

组合方式⚓︎

真实的大模型系统往往组合多种方法,而不是把某一种策略单独推到极致。

常见组合包括:

  • 数据并行 + 张量并行
  • FSDP + 张量并行
  • 流水线并行 + 张量并行 + 数据并行
  • 推理侧权重分片 + KV-cache-aware serving

组合的核心原因在于,不同策略解决的是不同瓶颈。单卡显存不足、节点内互连充足、全局 batch 需要继续扩大、在线服务需要控制延迟,这些目标往往不会由同一种并行方式同时最优满足。

因此,组合策略的选择本质上是在单卡容量、跨设备通信、全局吞吐、系统复杂度和延迟目标之间做多目标折中。

常见问题⚓︎

为什么数据并行不能解决超大模型放不下的问题?⚓︎

因为数据并行切分的是输入数据,而不是模型本身。每张卡仍然需要持有完整参数,训练时通常还要持有完整梯度和优化器状态。

因此,数据并行可以有效扩大吞吐,但不能改变“单卡必须先能容纳完整模型状态”这一前提。

FSDP / ZeRO-3 和张量并行的根本区别是什么?⚓︎

FSDP / ZeRO-3 的核心是状态分片:把参数、梯度或优化器状态的长期驻留拆散,计算时再按需重建完整视图。张量并行的核心是算子分片:多个设备共同完成同一个矩阵乘法或同一组张量计算。

前者主要缓解长期驻留压力,后者主要缓解局部算子规模和局部算力压力。二者都能降低单卡压力,但切分对象和通信位置完全不同。

为什么有些策略更适合训练,而不是推理?⚓︎

训练和推理的主导压力不同。训练更受梯度、优化器状态、激活值和反向同步影响,因此状态分片和激活压缩通常收益更明显。推理则更受权重驻留、带宽和 KV cache 增长影响。

因此,一个对训练峰值显存非常有效的方法,不一定能直接改善在线推理延迟;反过来,一个适合推理部署的权重切分方案,也不一定最适合训练中的状态管理。

流水线并行是不是只要把层切开就一定有效?⚓︎

不是。流水线并行的效率强烈依赖层段切分是否平衡、micro-batch 是否足够多,以及 stage 之间的等待是否能够被压低。

若切分不平衡或 micro-batch 太少,某些 stage 会长期空转,pipeline bubble 就会吞掉本来希望获得的并行收益。

评论