- Wed 20 December 2023
- Neuron
- #ml-code, #llm, #ml-acceleration, #hpc-concept
All Reduce Decomposition
As part of LLM serving, AllReduce is a common but expensive operation that also blocks compute while it runs. A common optimization is to replace AllReduce with the computationally equivalent ReduceScatter + AllGather.
ReduceScatter: Each process contributes its input buffer to an element-wise reduction (e.g., sum, product, max, min) performed across all processes. The reduced result is then scattered, so that each process receives a distinct shard of it.
AllGather: After the ReduceScatter step, each process holds a different shard of the reduced result. The AllGather operation collects these shards from all processes and distributes the complete result to every process.
Breaking AllReduce into these two steps can improve performance and communication efficiency on certain parallel architectures and communication patterns. Some communication libraries expose AllReduce only as a single fused operation, while others implement it internally as exactly this ReduceScatter + AllGather pair; splitting it explicitly makes the two halves independently schedulable, as sketched below.
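A minimal sketch of the decomposition using torch.distributed (the NCCL/torchrun setup, the buffer size, and the helper name decomposed_all_reduce are assumptions for illustration, not a prescribed implementation):

```python
# Run with e.g.: torchrun --nproc_per_node=4 decompose.py
import torch
import torch.distributed as dist

def decomposed_all_reduce(x: torch.Tensor) -> torch.Tensor:
    """AllReduce expressed as ReduceScatter followed by AllGather.

    Assumes x.numel() is divisible by the world size.
    """
    world = dist.get_world_size()
    # ReduceScatter: each rank receives one summed shard of size n / world.
    shard = torch.empty(x.numel() // world, dtype=x.dtype, device=x.device)
    dist.reduce_scatter_tensor(shard, x, op=dist.ReduceOp.SUM)
    # AllGather: shards are exchanged so every rank holds the full result.
    out = torch.empty_like(x)
    dist.all_gather_into_tensor(out, shard)
    return out

if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())  # single-node assumption
    x = torch.randn(1024, device="cuda")
    expected = x.clone()
    dist.all_reduce(expected, op=dist.ReduceOp.SUM)  # reference result
    assert torch.allclose(decomposed_all_reduce(x), expected, atol=1e-5)
    dist.destroy_process_group()
```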
Diagram Explanation
     Initial State      Reduce-Scatter        AllGather          Final State
   ┌───┬───┬───┬───┐   ┌───┬───┬───┬───┐   ┌───┬───┬───┬───┐   ┌───┬───┬───┬───┐
D0 │ A │ B │ C │ D │ → │ W │   │   │   │ → │ W │ X │ Y │ Z │ → │ W │ X │ Y │ Z │
   ├───┼───┼───┼───┤   ├───┼───┼───┼───┤   ├───┼───┼───┼───┤   ├───┼───┼───┼───┤
D1 │ E │ F │ G │ H │ → │   │ X │   │   │ → │ W │ X │ Y │ Z │ → │ W │ X │ Y │ Z │
   ├───┼───┼───┼───┤   ├───┼───┼───┼───┤   ├───┼───┼───┼───┤   ├───┼───┼───┼───┤
D2 │ I │ J │ K │ L │ → │   │   │ Y │   │ → │ W │ X │ Y │ Z │ → │ W │ X │ Y │ Z │
   ├───┼───┼───┼───┤   ├───┼───┼───┼───┤   ├───┼───┼───┼───┤   ├───┼───┼───┼───┤
D3 │ M │ N │ O │ P │ → │   │   │   │ Z │ → │ W │ X │ Y │ Z │ → │ W │ X │ Y │ Z │
   └───┴───┴───┴───┘   └───┴───┴───┴───┘   └───┴───┴───┴───┘   └───┴───┴───┴───┘
Explanation:
Initial State: Each device (D0, D1, D2, D3) has its own data (A-P).
Reduce-Scatter:
- Each chunk of the data is reduced element-wise across devices, and each device ends up owning one reduced chunk.
- W = reduction of (A, E, I, M)
- X = reduction of (B, F, J, N)
- Y = reduction of (C, G, K, O)
- Z = reduction of (D, H, L, P)
- The results are scattered across devices.
AllGather:
- Each device gathers the partial results from all other devices.
- All devices now have the complete reduced result (W, X, Y, Z).
Final State: All devices have the same, complete reduced result (simulated end-to-end in the sketch below).
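To tie the diagram to concrete values, here is a small single-process NumPy simulation (sum as the reduction; the array shapes mirror the four devices and chunks A-P above and are purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
# Rows are devices D0..D3; columns are the chunks (A..D on D0, E..H on D1, ...).
data = rng.standard_normal((4, 4, 8))  # (device, chunk, elements)

# Reduce-Scatter: chunk c is summed across devices and lands on device c,
# e.g. W = A + E + I + M ends up on D0.
shards = data.sum(axis=0)              # shards[c] is W, X, Y, or Z

# AllGather: every device collects all four shards [W, X, Y, Z].
final = np.broadcast_to(shards, (4, 4, 8))

# Same result as a direct AllReduce: the sum over devices, replicated everywhere.
assert np.allclose(final, data.sum(axis=0))
```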
This decomposition can be more efficient than a monolithic AllReduce for certain network topologies and large message sizes. It allows for better utilization of network bandwidth and can reduce overall communication time.
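For a rough sense of the bandwidth argument, the textbook ring-collective cost model (an assumption here; the post does not fix a topology) assigns each half (p-1)/p · n bytes of per-device traffic, so the pair moves the same total volume as a ring AllReduce; the practical gain comes from being able to schedule or overlap each half with compute independently:

```python
def ring_traffic_bytes(n_bytes: float, p: int) -> dict:
    """Per-device bytes sent under the textbook ring cost model (assumed)."""
    reduce_scatter = (p - 1) / p * n_bytes  # p-1 steps of n/p bytes each
    all_gather = (p - 1) / p * n_bytes      # same pattern, without the reduce
    return {
        "reduce_scatter": reduce_scatter,
        "all_gather": all_gather,
        "total": reduce_scatter + all_gather,  # equals ring AllReduce: 2(p-1)/p * n
    }

# e.g. a 64 MB activation buffer across 8 devices
print(ring_traffic_bytes(64e6, 8))
```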
Citations
If you found this useful, please cite this post using
Senthilkumar Gopal. (Dec 2023). All Reduce Decomposition. sengopal.me. https://sengopal.me/posts/all-reduce-decomposition
or
@article{gopal2023allreducedecomposition,
  title   = {All Reduce Decomposition},
  author  = {Senthilkumar Gopal},
  journal = {sengopal.me},
  year    = {2023},
  month   = {Dec},
  url     = {https://sengopal.me/posts/all-reduce-decomposition}
}