Senthilkumar Gopal

Musings of a machine learning researcher, engineer and leader

All Reduce Decomposition

As part of LLM serving, AllReduce is a common but expensive operation, and it can stall computation while it runs. A common improvement is to replace AllReduce with the computationally equivalent pair ReduceScatter + AllGather.

ReduceScatter: Each process contributes its input buffer, and a reduction operation (e.g., sum, product, max, min) is applied element-wise across the buffers of all processes. The reduced result is then scattered, so that each process receives a distinct, equal-sized portion of it.
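
A minimal pure-Python sketch of these semantics, simulating every device in a single process. Device buffers are plain lists, the buffer length is assumed to be divisible by the device count, and `reduce_scatter` is a hypothetical helper rather than any library's API:

```python
from typing import Callable, List


def reduce_scatter(inputs: List[List[float]],
                   op: Callable[[float, float], float] = lambda a, b: a + b) -> List[List[float]]:
    """Simulate ReduceScatter across len(inputs) devices in one process.

    inputs[d] is the buffer held by device d. Returns out[d], the d-th chunk
    of the element-wise reduction of all buffers.
    """
    n = len(inputs)
    length = len(inputs[0])
    if any(len(buf) != length for buf in inputs) or length % n != 0:
        raise ValueError("buffers must have equal length divisible by the device count")
    chunk = length // n
    out = []
    for d in range(n):
        lo, hi = d * chunk, (d + 1) * chunk
        reduced = inputs[0][lo:hi]                 # start from device 0's copy of chunk d
        for buf in inputs[1:]:                     # fold in every other device's chunk d
            reduced = [op(a, b) for a, b in zip(reduced, buf[lo:hi])]
        out.append(reduced)
    return out


bufs = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400], [1000, 2000, 3000, 4000]]
print(reduce_scatter(bufs))       # [[1111], [2222], [3333], [4444]]
print(reduce_scatter(bufs, max))  # [[1000], [2000], [3000], [4000]]
```

Passing a different `op` (here the built-in `max`) changes the reduction without changing how the result is scattered.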

AllGather: After the ReduceScatter step, each process holds a different portion of the overall reduction result. AllGather then collects these portions from all processes and distributes the complete result to every process.
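
The matching sketch for AllGather, again simulated in one process. `all_gather` is a hypothetical helper, and the shards are assumed to be listed in rank order:

```python
from typing import List


def all_gather(shards: List[List[float]]) -> List[List[float]]:
    """Simulate AllGather: shards[d] is the chunk held by device d.

    Every device ends up with all chunks concatenated in rank order.
    """
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]  # one independent copy per device


shards = [[1111], [2222], [3333], [4444]]  # e.g. what ReduceScatter left on each device
for d, buf in enumerate(all_gather(shards)):
    print(f"D{d}: {buf}")  # D0: [1111, 2222, 3333, 4444], and likewise for D1-D3
```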

Breaking AllReduce into these two steps can improve performance and communication efficiency on certain parallel computing architectures or communication patterns, because the two halves can be scheduled independently (for example, work that only needs the local shard can run between them). Some parallel computing libraries or frameworks implement AllReduce as a single operation, or provide optimized implementations that combine the two steps for better performance.
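
One well-known way to implement AllReduce, the ring algorithm (one of the algorithms NCCL uses), is literally a ReduceScatter phase followed by an AllGather phase, each taking n - 1 neighbor-to-neighbor steps. The sketch below simulates that in pure Python, assuming a sum reduction and a buffer length divisible by the device count; `ring_all_reduce` is a hypothetical name and this is not any particular library's code:

```python
from typing import List


def ring_all_reduce(inputs: List[List[float]]) -> List[List[float]]:
    """Sum-AllReduce simulated on a ring of len(inputs) devices.

    Phase 1 is a ReduceScatter and phase 2 an AllGather; in every step each
    device sends one chunk to its right-hand neighbor (d + 1) % n.
    """
    n = len(inputs)
    size = len(inputs[0])
    if size % n != 0:
        raise ValueError("buffer length must be divisible by the number of devices")
    c = size // n
    bufs = [list(b) for b in inputs]  # each device's private working buffer

    def chunk(d: int, k: int) -> List[float]:
        return bufs[d][k * c:(k + 1) * c]  # slicing returns a copy

    # Phase 1 (ReduceScatter): after n - 1 steps, device d owns the fully reduced chunk d.
    for s in range(n - 1):
        # All devices send at once, so capture every outgoing chunk before applying any.
        sends = [(d, (d - 1 - s) % n) for d in range(n)]
        payloads = [chunk(d, k) for d, k in sends]
        for (d, k), data in zip(sends, payloads):
            dst = (d + 1) % n
            for i, v in enumerate(data):
                bufs[dst][k * c + i] += v  # receiver accumulates into its copy of chunk k

    # Phase 2 (AllGather): pass each finished chunk around the ring.
    for s in range(n - 1):
        sends = [(d, (d - s) % n) for d in range(n)]
        payloads = [chunk(d, k) for d, k in sends]
        for (d, k), data in zip(sends, payloads):
            dst = (d + 1) % n
            bufs[dst][k * c:(k + 1) * c] = data  # receiver overwrites chunk k
    return bufs


inputs = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400], [1000, 2000, 3000, 4000]]
expected = [sum(col) for col in zip(*inputs)]  # what a direct AllReduce would produce
result = ring_all_reduce(inputs)
assert all(buf == expected for buf in result)
print(result[0])  # [1111, 2222, 3333, 4444]
```

Each device only ever talks to its right-hand neighbor, and after the first phase device d holds exactly the reduced chunk d, matching the Reduce-Scatter column of the diagram.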

Diagram Explanation

     Initial State        Reduce-Scatter         AllGather          Final State
   ┌───┬───┬───┬───┐    ┌───┬───┬───┬───┐    ┌───┬───┬───┬───┐    ┌───┬───┬───┬───┐
D0 │ A │ B │ C │ D │ →  │ W │   │   │   │ →  │ W │ X │ Y │ Z │ →  │ W │ X │ Y │ Z │
   ├───┼───┼───┼───┤    ├───┼───┼───┼───┤    ├───┼───┼───┼───┤    ├───┼───┼───┼───┤
D1 │ E │ F │ G │ H │ →  │   │ X │   │   │ →  │ W │ X │ Y │ Z │ →  │ W │ X │ Y │ Z │
   ├───┼───┼───┼───┤    ├───┼───┼───┼───┤    ├───┼───┼───┼───┤    ├───┼───┼───┼───┤
D2 │ I │ J │ K │ L │ →  │   │   │ Y │   │ →  │ W │ X │ Y │ Z │ →  │ W │ X │ Y │ Z │
   ├───┼───┼───┼───┤    ├───┼───┼───┼───┤    ├───┼───┼───┼───┤    ├───┼───┼───┼───┤
D3 │ M │ N │ O │ P │ →  │   │   │   │ Z │ →  │ W │ X │ Y │ Z │ →  │ W │ X │ Y │ Z │
   └───┴───┴───┴───┘    └───┴───┴───┴───┘    └───┴───┴───┴───┘    └───┴───┴───┴───┘

Explanation:

  1. Initial State: Each device (D0, D1, D2, D3) holds its own four-chunk buffer: D0 holds A-D, D1 holds E-H, D2 holds I-L, and D3 holds M-P.

  2. Reduce-Scatter:

    • Each chunk position is reduced across all devices, and device i receives the fully reduced i-th chunk.
    • W = reduction of (A, E, I, M)
    • X = reduction of (B, F, J, N)
    • Y = reduction of (C, G, K, O)
    • Z = reduction of (D, H, L, P)
    • The results are scattered across devices: D0 holds W, D1 holds X, D2 holds Y, and D3 holds Z.
  3. AllGather:

    • Each device gathers the reduced chunks held by all other devices.
    • All devices now have the complete reduced result (W, X, Y, Z).
  4. Final State: All devices have the same, complete reduced result.
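
The walkthrough can be replayed symbolically. In this sketch `+` stands in for whatever reduction operator is used, the device layout is copied from the diagram, and all variable names are illustrative:

```python
# Device buffers from the diagram: D0 holds A-D, D1 holds E-H, and so on.
devices = [list("ABCD"), list("EFGH"), list("IJKL"), list("MNOP")]
n = len(devices)

# Reduce-Scatter: device d reduces chunk position d across all devices.
# The symbolic "reduction" just records which inputs were combined.
scattered = ["+".join(buf[d] for buf in devices) for d in range(n)]
for name, value in zip("WXYZ", scattered):
    print(name, "=", value)  # W = A+E+I+M, X = B+F+J+N, Y = C+G+K+O, Z = D+H+L+P

# AllGather: every device receives every reduced chunk, in rank order.
final = [list(scattered) for _ in range(n)]
assert all(row == scattered for row in final)  # Final State: identical on all devices
```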

This decomposition can be more efficient than a direct AllReduce in certain network topologies and for large data sizes. It makes better use of network bandwidth, because each device only sends and receives chunks of 1/N of the buffer (with N devices), and it can reduce overall communication time [1][2][3].
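
To put rough numbers on the bandwidth argument, the sketch below counts the bytes each device sends, assuming a ring-based implementation in which every step moves one chunk of S/n bytes. The helper names and the "naive" baseline (every device sends its full buffer to every peer) are illustrative assumptions, not measurements:

```python
GIB = 1024 ** 3


def ring_bytes_sent_per_device(collective: str, n: int, full_size: float) -> float:
    """Bytes each device sends in a ring-based collective over n devices.

    full_size is the size S of the complete (unsharded) buffer: the input of
    ReduceScatter/AllReduce, or the output of AllGather. Each ring step moves
    one chunk of S / n bytes.
    """
    steps = {"reduce_scatter": n - 1, "all_gather": n - 1, "all_reduce": 2 * (n - 1)}
    return steps[collective] * full_size / n


def naive_all_reduce_bytes_per_device(n: int, full_size: float) -> float:
    """Hypothetical baseline: every device sends its whole buffer to every peer."""
    return (n - 1) * full_size


n, S = 8, 1 * GIB
rs = ring_bytes_sent_per_device("reduce_scatter", n, S)
ag = ring_bytes_sent_per_device("all_gather", n, S)
ar = ring_bytes_sent_per_device("all_reduce", n, S)
naive = naive_all_reduce_bytes_per_device(n, S)
print(f"ReduceScatter: {rs / GIB:.3f} GiB")     # ReduceScatter: 0.875 GiB
print(f"AllGather:     {ag / GIB:.3f} GiB")     # AllGather:     0.875 GiB
print(f"AllReduce:     {ar / GIB:.3f} GiB")     # AllReduce:     1.750 GiB (= RS + AG)
print(f"Naive:         {naive / GIB:.3f} GiB")  # Naive:         7.000 GiB
```

Under these assumptions ReduceScatter + AllGather sends 2(n-1)/n · S per device, which stays below 2S however many devices are added, while the naive scheme grows linearly with n.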

Citations


  1. https://marek.ai/allreduce-the-basis-of-multi-device-communication-for-neural-network-training.html

  2. https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html

  3. https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html


If you found this useful, please cite this post using

Senthilkumar Gopal. (Dec 2023). All Reduce Decomposition. sengopal.me. https://sengopal.me/posts/all-reduce-decomposition

or

@article{gopal2023allreducedecomposition,
  title   = {All Reduce Decomposition},
  author  = {Senthilkumar Gopal},
  journal = {sengopal.me},
  year    = {2023},
  month   = {Dec},
  url     = {https://sengopal.me/posts/all-reduce-decomposition}
}