
2.5D Tensor Parallelism

Author: Zhengda Bian, Yongbin Li

Prerequisite

Example Code

Related Paper

Introduction

Compared with 1D tensor parallelism, 2D parallelism reduces the memory cost, but may introduce more communication. Therefore, a 2.5D tensor parallelism algorithm was proposed based on 2.5D SUMMA to reduce communication by using more devices.

Let's still take a linear layer $Y = XA$ as an example. Given $P = q \times q \times d$ processors (necessary condition), e.g. $q = d = 2$, we split the input $X$ into $d \times q$ rows and $q$ columns as

$$
\left[\begin{matrix} X_{00} & X_{01} \\ X_{10} & X_{11} \\ X_{20} & X_{21} \\ X_{30} & X_{31} \end{matrix} \right],
$$

which can be reshaped into $d$ layers as

$$
\left[\begin{matrix} X_{00} & X_{01} \\ X_{10} & X_{11} \end{matrix} \right] \text{~and~} \left[\begin{matrix} X_{20} & X_{21} \\ X_{30} & X_{31} \end{matrix} \right].
$$

Also, the weight $A$ is split into

$$
\left[\begin{matrix} A_{00} & A_{01} \\ A_{10} & A_{11} \end{matrix} \right].
$$

For each layer of $X$, we use the SUMMA algorithm to multiply $X$ and $A$. Then, we have the output

$$
\left[\begin{matrix} Y_{00}=X_{00}A_{00}+X_{01}A_{10} & Y_{01}=X_{00}A_{01}+X_{01}A_{11} \\ Y_{10}=X_{10}A_{00}+X_{11}A_{10} & Y_{11}=X_{10}A_{01}+X_{11}A_{11} \end{matrix} \right] \text{~and~}
$$

$$
\left[\begin{matrix} Y_{20}=X_{20}A_{00}+X_{21}A_{10} & Y_{21}=X_{20}A_{01}+X_{21}A_{11} \\ Y_{30}=X_{30}A_{00}+X_{31}A_{10} & Y_{31}=X_{30}A_{01}+X_{31}A_{11} \end{matrix} \right].
$$
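
To make the block formulas concrete, below is a minimal NumPy sketch (plain single-process code, not the Colossal-AI API): it splits an example input $X$ into $d$ layers of $q \times q$ blocks and the weight $A$ into a $q \times q$ grid of blocks, forms each $Y_{ij}$ exactly as above, and checks the stitched result against the full product $Y = XA$. The matrix sizes are arbitrary assumptions, chosen only so the blocks divide evenly; in a real 2.5D run each block lives on its own device and the partial products are accumulated with SUMMA broadcasts and reductions.

```python
import numpy as np

# Illustrative sizes (assumptions): any shapes divisible by d*q rows and q columns work.
q, d = 2, 2
X = np.random.rand(8, 4)   # input, split into d*q = 4 row blocks and q = 2 column blocks
A = np.random.rand(4, 6)   # weight, split into a q x q = 2 x 2 grid of blocks

# Partition X into a (d*q) x q grid of blocks, then group the row blocks into d layers.
X_blocks = [np.hsplit(row, q) for row in np.vsplit(X, d * q)]   # 4 x 2 grid of blocks
X_layers = [X_blocks[l * q:(l + 1) * q] for l in range(d)]      # d layers, each q x q blocks

# Partition A into a q x q grid of blocks; every layer uses the same copy of A.
A_blocks = [np.hsplit(row, q) for row in np.vsplit(A, q)]

# For each layer, compute Y_ij = sum_k X_ik A_kj (done serially here; in 2.5D SUMMA
# each block sits on its own device and the sum is accumulated via broadcasts).
Y_layers = [
    [[sum(X_layers[l][i][k] @ A_blocks[k][j] for k in range(q)) for j in range(q)]
     for i in range(q)]
    for l in range(d)
]

# Stitching the layer outputs back together reproduces the full product Y = X A.
Y = np.vstack([np.block(Y_layer) for Y_layer in Y_layers])
assert np.allclose(Y, X @ A)
```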

Efficiency

Given $P = q \times q \times d$ processors, we present the theoretical computation and memory cost, as well as the communication cost based on the ring algorithm, in both the forward and backward pass of 2.5D tensor parallelism.

| Computation | Memory (parameters) | Memory (activations) | Communication (bandwidth) | Communication (latency) |
| :-: | :-: | :-: | :-: | :-: |
| $O(1/dq^2)$ | $O(1/q^2)$ | $O(1/dq^2)$ | $O(3(q-1)(d+1)/dq)$ | $O(6(q-1))$ |
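
As a quick sanity check of the table, plugging the example values $q = d = 2$ into these factors gives the relative scalings below (asymptotic factors only; constants and problem size are omitted, so these are scalings, not absolute costs).

```python
# Relative cost factors from the table above, evaluated for q = d = 2.
# These are asymptotic factors only, not absolute costs.
q, d = 2, 2
print("computation        ~", 1 / (d * q**2))                   # 0.125
print("memory (params)    ~", 1 / q**2)                         # 0.25
print("memory (activ.)    ~", 1 / (d * q**2))                   # 0.125
print("comm (bandwidth)   ~", 3 * (q - 1) * (d + 1) / (d * q))  # 2.25
print("comm (latency)     ~", 6 * (q - 1))                      # 6
```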

Usage

Currently, the newest version of ColossalAI doesn't support 2.5D tensor parallelism, but this feature will be integrated into Shardformer in future releases. For more details about the ideas and usage of Shardformer, please refer to the Shardformer Doc.

For users of older versions of ColossalAI, please refer to ColossalAI-Examples - 2.5D Tensor Parallelism.