分片矩阵与通信⚓︎
当模型规模增大到单卡内存放不下完整参数,或者为了降低时延希望把同一层摊到更多设备时,矩阵就必须分片到多张卡上。分片之后,单卡上的局部计算规模变小,但完整结果通常不再能直接从本地得到,而是需要通过额外通信恢复。
因此,分片矩阵问题可以统一理解为两步:先在每张卡上执行局部计算,再通过通信把局部结果拼接、规约或重新分发为全局结果。真正的性能差异,往往不在局部 matmul 本身,而在“局部结果缺了什么”以及“为了补齐这些信息需要付出多少通信”。
分片视角⚓︎
分片并不会改变矩阵乘法的逻辑定义,只是把一个全局矩阵拆成多个设备局部持有的 shard。每个 shard 仍然属于同一个逻辑张量,只是不同设备保存的是不同切片,局部 shape 与全局 shape 不再相同。
分析分片时,最重要的不是“某张卡上有什么数据”,而是同时区分三个层次:
- 逻辑 shape:全局张量原本的形状。
- 局部 shape:单个设备实际持有的 shard 形状。
- mesh 轴:这些 shard 是沿哪条设备轴切分的。
如果一个逻辑维度沿某条 mesh 轴切分,那么每台设备只持有该维度上的一部分;如果某个逻辑维度没有沿某条 mesh 轴切分,那么这个维度会在那条轴上保持复制。于是,分片并不只是减少单卡内存占用,它还决定了后续哪些计算能完全本地完成,哪些计算必须跨卡交换数据。
统一记号与分片约束⚓︎
描述分片时,可以把设备组织成一个带名字的 mesh,例如 X、Y、Z 三个 mesh 轴。随后再说明张量的每个逻辑维度沿哪些 mesh 轴被切分。
例如,一个二维数组 A[I, J]:
A[I, J]表示完全复制,每台设备都有完整副本。A[I_X, J]表示第一个逻辑维度沿X轴切分。A[I_X, J_Y]表示两个逻辑维度分别沿不同 mesh 轴切分。A[I_XY, J]表示把X和Y视为更大的合并轴,联合切分第一个逻辑维度。
这里的关键是:一个 mesh 轴只能花一次。也就是说,同一张量不能同时出现类似 A[I_X, J_X] 这样的布局,因为这等价于要求同一条设备轴同时承担两个独立逻辑维度的切分,而本地 shard 无法唯一对应全局位置。
这条约束决定了许多形式上对称的分片并不合法,也解释了为什么某些矩阵乘法在开始前必须先做通信,才能转换成可计算的局部布局。
常见矩阵分片方式⚓︎
对矩阵乘法
最常见的分片方式通常围绕三类维度展开:批维或输出行维、求和维 D,以及输出列维 F。不同分片方式的关键差别在于:每张卡本地保留了哪一部分结构,又因此缺失了哪一部分全局信息。
按批维或输出行维切分⚓︎
若 B 维被切开,则每张卡负责不同批样本。因为不同 batch 行之间本来就独立,局部 matmul 往往可以直接生成一部分完整输出。只要权重本身未沿求和维打散,这类布局的通信压力通常较小。
按求和维切分⚓︎
若 D 维被切开,则每张卡只持有输入与权重在求和维上的一部分。此时每张卡本地只能算出最终输出的 partial result,也就是对同一输出位置的一部分贡献。最终必须通过规约类通信把这些 partial result 相加,才能得到完整值。
按输出列维切分⚓︎
若 F 维被切开,则每张卡只生成一部分输出特征。对当前层而言,这种布局经常没有问题,因为每张卡确实得到了完整的本地输出切片。问题在于下游层是否还能沿相同维度继续保持这种分片。如果不能,就需要在某个位置通过拼接类通信重新收集完整输出。
因此,分片方式并不只是“把矩阵切开”,而是在选择:是让通信发生在当前算子之后,还是把分片结构继续传给下游算子,让后续若干层共同承担这份结构约束。
局部矩阵乘法⚓︎
分片之后,每张卡执行的并不是原始大矩阵乘法,而是某个局部 shard 与另一个局部 shard 之间的乘法。判断局部 matmul 是否已经“完成”,关键只看一件事:求和维是否完整。
- 若本地保留了完整求和维,则局部输出就是完整子结果。
- 若求和维本身被切开,则局部输出只是 partial result,还缺其他设备上的贡献。
这条规则可以把大多数分片矩阵乘法归纳为四种基本情形。
情形一:求和维没有被切分⚓︎
如果两个输入在求和维上都保持完整,那么每张卡都可以直接做本地 block matmul,而不需要通信。例如:
这里 J 没有被切开,因此每个局部块已经拥有完成收缩所需的全部信息。本地乘法结束后,结果天然落在目标分片上。
情形二:只有一个输入沿求和维切分⚓︎
如果只有一个乘数沿求和维切分,那么本地设备缺的是“另一个求和方向上的完整块”。这时最常见做法是先把被切开的输入 AllGather 成完整版本,再进行本地 matmul。
例如:
由于 A 在 J 维上不完整,无法直接完成收缩,因此通常先把 A 沿 X 轴 AllGather 回完整 J 维,再做局部乘法。
情形三:两个输入都沿同一求和维切分⚓︎
如果两个乘数在求和维上按同一 mesh 轴切分,那么每张卡的局部 shard 可以彼此相乘,但每张卡得到的只是完整结果的部分和。此时最自然的步骤是:
- 本地做局部 matmul,得到 partial result。
- 再做 AllReduce 或 ReduceScatter,把 partial result 规约成完整结果或新的分片结果。
例如:
这里的 \{U_X\} 可以理解为“沿 X 轴尚未规约完成”。本地 matmul 只是先把每张卡负责的求和维块贡献算出来,还需要后续规约消去这层未完成状态。
情形四:两个输入在同一 mesh 轴上切了不同的非求和维⚓︎
这是最容易出现非法布局的情况。直观地看,如果两个乘数都把不同的非求和维切在同一 mesh 轴上,那么本地 shard 可能只对得上结果矩阵的某些对角块,而无法恢复完整结果。此时必须先 AllGather 某一侧输入,释放那条 mesh 轴,再继续计算。
这类情况说明,分片合法性不只是“每台设备都有数据”,而是要保证每个局部 shard 对应的全局块组合仍然足以覆盖整个结果张量。
通信模式与代价⚓︎
当局部结果不再完整时,就需要集体通信补齐信息。对分片矩阵乘法,最核心的通信模式有四类:AllGather、ReduceScatter、AllReduce 和 AllToAll。
AllGather⚓︎
AllGather 的作用是把某条轴上的 shard 收集到每个设备上,使该逻辑维度在这条 mesh 轴上不再切分。它解决的问题是:每台设备只有一部分切片,但接下来的局部计算需要完整张量。
从语义上看,AllGather 是“去掉一个下标”。例如:
在带宽受限阶段,AllGather 的主要成本由总字节数和链路带宽决定,而不是由设备数本身决定。设备数更多时,每跳发送的块更小,但 hop 数也更多,这两者会在吞吐阶段大致抵消。只有当数组很小、单 hop 时间接近链路固定延迟时,才会进入 latency-bound 区域,此时设备数和拓扑直径才会显著影响耗时。
ReduceScatter⚓︎
ReduceScatter 的作用是先把 partial result 求和,再顺手把结果沿某个逻辑维度重新切散到不同设备。它解决的问题是:每台设备持有的是同一输出位置的部分贡献,但后续又不需要完整复制结果,而只需要某种新的分片布局。
语义上可以理解为:先去掉“未规约完成”的状态,再在另一个逻辑维度上引入新的切分。因此它既是规约操作,也是布局变换操作。
在很多训练场景里,ReduceScatter 比 AllReduce 更有价值,因为它避免了先把完整结果复制到所有设备上,再在后续步骤重新切分一次。
AllReduce⚓︎
AllReduce 的作用是对 partial result 求和,并把完整结果保留在每个设备上。它解决的问题是:每台设备都需要同一份完整规约结果。
从通信结构上看,AllReduce 通常可以理解为:
- 一次 ReduceScatter
- 再接一次 AllGather
因此其总成本通常约为 AllGather 的两倍。若后续步骤并不需要所有设备持有完整副本,那么直接做 AllReduce 往往不是最省通信的方案。
AllToAll⚓︎
AllToAll 的作用不是复制,也不是求和,而是把分片下标从一个逻辑维度“搬”到另一个逻辑维度。例如:
它本质上是一次布局重排。其典型用途不是普通 matmul,而是当两个相邻计算区域所要求的分片布局不兼容时,用来把张量重新排布到下一个阶段需要的形式。Mixture of Experts 等模型里经常会出现这种需求。
由于 AllToAll 不需要像 AllGather 那样把完整张量复制到所有设备,它的通信成本通常低于 AllGather。在双向 ring 条件下,常见经验是其成本大约是 AllGather 的四分之一量级。
分片策略与 roofline⚓︎
把分片矩阵放回 roofline 框架后,需要同时看两部分:
- 局部 matmul 的片内 roofline
- 分片引入的跨卡通信 roofline
前者决定单卡上的计算是否已经接近 compute-bound,后者决定多卡之后新增的 AllGather、AllReduce、ReduceScatter 或 AllToAll 是否会把整体瓶颈重新拉回 communication-bound。
因此,有些分片方式虽然没有改变单卡 matmul 的高算术强度,但会显著增加 AllGather 或 AllReduce 的频率,使得端到端性能主要受 ICI 或更慢网络约束;另一些分片方式则能把输出继续保持为分片布局,从而把部分通信推迟到后续算子,降低当前层的通信压力。
从这个角度看,分片策略真正优化的不是单个公式,而是 计算路径与通信路径之间的平衡。一个方案也许节省了显存,却引入了更频繁的全量收集;另一个方案也许局部 FLOPs 更少,但因为结果只能以 partial result 形式存在,最后必须付出高昂规约成本。只有把局部 matmul 与通信一起看,roofline 才能反映真实瓶颈。
两类典型策略对比⚓︎
以 X[B, D] @ Y[D_X, F] -> Z[B, F] 为例,有两类常见策略:
- 先 AllGather 再本地乘法:把
Y沿X轴收集回完整权重,然后每台设备做完整 matmul。 - 先本地乘法再 AllReduce:每张卡用自己的
D_X局部块做 matmul,得到 partial result,随后再对Z做 AllReduce。
这两类策略的区别在于:前者通信对象是权重块,后者通信对象是输出块。哪一个更优,并不由规则先验决定,而取决于 B、D、F 的相对大小,以及通信与计算是否能重叠。
若输出张量明显小于被收集的权重张量,则“先本地算再规约”可能更划算;若 batch 很大、输出也很大,则过早产生 partial result 反而可能让规约代价失控。这个比较对应两类不同代价:一类是为完成收缩先付出输入通信,另一类是先生成 partial result 再付出输出通信。
反向传播里的通信对偶⚓︎
在分布式训练里,通信原语不仅出现在前向,也会在反向中成对出现。一个很重要的规律是:
- AllGather 的反向常对应 ReduceScatter
- ReduceScatter 的反向常对应 AllGather
其原因在于,广播与求和在线性算子意义下互为转置。对工程理解而言,这意味着前向里选择了某种“先复制后算”的布局,反向里往往就要承担一次与之对应的“先规约再切分”代价。因此,某种分片策略是否合理,不能只看前向最省哪一次通信,还要看整个训练闭环里前向与反向是否平衡。
通信与计算的重叠⚓︎
在 roofline 估计里,通常会假设通信可以与部分计算重叠。对分片矩阵乘法,这并不是自动成立的,而往往需要显式地把矩阵分成更细的块,让某些块开始进行 ring reduction 时,其他块仍在继续做局部 matmul。
这种重叠思路的意义在于:
- 若通信慢于计算,则计算会被网络饿住。
- 若通信足够快且可流水化,则部分 collective 代价可以被隐藏在 matmul 的执行窗口中。
因此,在评估一个分片策略时,不应只看“需要哪种通信”,还应看“这段通信是否有机会与当前 matmul 的其他块并行”。同样的 AllReduce,在完全串行执行和可与块级 matmul 流水重叠时,端到端成本可能差异很大。
工程判断⚓︎
分析一个分片矩阵问题时,可以按以下顺序判断:
- 先看矩阵哪一维被切分。
- 再看每张卡本地 matmul 得到的是完整子结果还是 partial result。
- 再看缺失的信息需要通过哪种通信补齐。
- 最后比较这段通信是否值得,以及它能否与局部计算重叠。
对工程实现而言,最有价值的问题通常不是“能不能分片”,而是:
- 分片后缺失的信息是什么。
- 这些信息是否必须立刻恢复。
- 恢复它的通信是否会比节省下来的内存与算力更贵。
- 这种分片布局是否还能顺着后续层继续保持。
可以把分片矩阵的几个稳定结论归纳为:
- 求和维未切开时,局部 matmul 往往最简单,通信也最少。
- 求和维一旦被切开,就必须在“先收集输入”与“后规约输出”之间做权衡。
- AllGather 解决的是“缺切片”,AllReduce 解决的是“缺求和”,ReduceScatter 解决的是“求和后仍想保持分片”,AllToAll 解决的是“布局不兼容”。
- 单卡上高算术强度的 matmul,上多卡后仍可能因为 collectives 而重新变成 communication-bound。
- 是否保留分片布局给下游层,往往和单层最优同样重要。
因此,理解分片矩阵的关键不是死记每种 collective 的定义,而是把 局部计算、结果语义、通信模式和 roofline 代价 放在同一张图里看。只有这样,才能判断一种分片策略到底是加速,还是只是把瓶颈从单卡显存换成了跨卡网络。