Motivation
Intra-layer model parallelism, or tensor parallelism, has become an efficient sharding strategy for training LLMs. It is generally implemented by distributing MLP layers across multiple devices, performing an AllGather before the MLP gemm and a ReduceScatter after it. However, as feature and hidden dims grow, this strategy quickly becomes bottlenecked by the lack of overlap between collectives and compute, since only one of them runs at a time. To break through this bottleneck, collective matmul was proposed in this paper: it further partitions large gemms into multiple parts, injects collective-permutes between the partitions for pipelined execution, and writes each partition's result into its corresponding slice of the result matrix. The optimization has been implemented for TPU and has shown performance improvements on some models. We think making it available in the XLA GPU pipeline will also be beneficial.
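The overlap argument can be made concrete with a back-of-the-envelope timing model. This is a sketch under stated assumptions, not a measurement. It assumes a ring AllGather (n - 1 shard transfers per device), a full gemm costing n shard-sized gemms, and perfect overlap with no launch or synchronization overhead. The function names are illustrative.

```python
def sequential_time(n, t_gemm_shard, t_send_shard):
    """AllGather first, then one big gemm: communication and compute never overlap."""
    # A ring all-gather needs n - 1 steps, each moving one shard per device.
    all_gather = (n - 1) * t_send_shard
    # The full gemm costs roughly n shard-sized gemms.
    return all_gather + n * t_gemm_shard


def pipelined_time(n, t_gemm_shard, t_send_shard):
    """Collective matmul: each shard gemm overlaps the transfer of the next shard."""
    # The first n - 1 steps overlap one shard gemm with one collective-permute;
    # the last step has nothing left to send and only computes.
    return (n - 1) * max(t_gemm_shard, t_send_shard) + t_gemm_shard
```

Take 8 partitions where a shard gemm and a shard transfer each cost one time unit. The sequential version then takes 15 units and the pipelined version 8. The gain shrinks as either compute or communication dominates.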
High-level Design
HLO Graph
The HandleDot logic in SpmdPartitioningVisitor includes an EmitWindowedDotGeneral function. It rewrites AllGather+gemm or gemm+ReduceScatter into a while loop whose trip count equals the number of partitions. For example, a full HLO AllGather+gemm pattern before pattern matching:
The all-gather+gemm pattern above is rewritten into a while loop with sharded dots and collective-permutes that send each shard to a neighbor. At a high level, the rewritten graph has the following structure, shown as pseudo-HLO for simplicity (full HLO in Appendix 1):
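The rewritten loop can be simulated on the host to check that it computes the same result as an AllGather followed by a single dot. This is an illustrative sketch only, under the following assumptions:

- Plain Python lists stand in for device buffers.
- Each device holds an equal row shard of the LHS and a replicated RHS.
- The collective-permute is modeled as a ring shift in which device d receives from device d - 1.

`matmul`, `ring_shift`, and `windowed_allgather_dot` are hypothetical helpers, not XLA functions.

```python
def matmul(a, b):
    """Dense matmul on nested lists: (m x k) @ (k x n) -> (m x n)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def ring_shift(values):
    """Collective-permute on a ring: device d receives the value held by device d - 1."""
    n = len(values)
    return [values[(d - 1) % n] for d in range(n)]


def windowed_allgather_dot(lhs_shards, rhs):
    """Simulate the while loop produced from an AllGather+dot pattern."""
    n = len(lhs_shards)
    rows = len(lhs_shards[0])
    # One full-size result buffer per device (the loop's initial_matrix).
    results = [[None] * (rows * n) for _ in range(n)]
    current = list(lhs_shards)  # shard each device currently holds
    origin = list(range(n))     # partition index that shard came from
    for _ in range(n):  # trip count == number of partitions
        for d in range(n):
            start = origin[d] * rows
            # Sharded dot, then dynamic-update-slice into the result buffer.
            results[d][start:start + rows] = matmul(current[d], rhs)
        # Pass the shard to the neighbor for the next iteration.
        current, origin = ring_shift(current), ring_shift(origin)
    return results
```

After n iterations, every device has seen all n shards, so each device's buffer equals the full product of the gathered LHS and the RHS.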
The loop body above implements the exact logic of collective matmul, but it performs poorly because each while loop iteration contains a single gemm for one partition. Once a worker has received all of its peer's data for the second partition of a matrix, the second partition's gemm can start right away on another stream while the first partition's gemm is still running. The logic above can therefore be improved by unrolling the loop by a factor of 2, so that multiple gemms run in the loop body and execution becomes more efficient. An example of overlapping execution for 2 partitions looks like this:
```
ag_while_body (initial_matrix, sharded_lhs, rhs) {
  // Bidirectional send/recv of the sharded lhs to/from peers on the collective stream
  collective-permute-start = collective-permute-start(sharded_lhs), operation_queue_id=3
  // Concurrently, on another compute stream
  dot = dot(sharded_lhs, rhs), operation_queue_id=2
  // Await on the main stream
  collective-permute-done = collective-permute-done(collective-permute-start), operation_queue_id=0
  // Run another gemm on the main stream once the data is ready
  dot2 = dot(collective-permute-done, rhs), operation_queue_id=0
  // Update the intermediate result on the main stream, awaiting operation_queue_id=2
  dynamic-update-slice = dynamic-update-slice(initial_matrix, dot), operation_queue_id=0, wait_on_operation_queues={2}
  dynamic-update-slice2 = dynamic-update-slice(dynamic-update-slice, dot2), operation_queue_id=0
  ROOT (dynamic-update-slice2, collective-permute-done, rhs)
}
```
The first dot in the example above runs on a separate compute stream, operation_queue_id=2. Its consumer, the first DUS (dynamic-update-slice), runs on the main compute stream. Because of the data dependency, the DUS must wait on the async event of the stream with operation_queue_id=2. There are two alternatives for achieving this:
@jurahul suggested reusing the while loop construct. Expanding on that idea, we can run multiple dots by unrolling the while loop by a factor of 2, so that the dots of two partitions run in parallel within one iteration. Benchmarks have shown that overlapping more than 2 gemms is not beneficial.
Alternatively, we can manually construct the sharded dot and send/recv sequence inside a sub-computation of a custom_call instruction. The instruction sequence would be very similar to the fully unrolled loop above. Because this is a custom call, other loop optimization passes won't touch its body, which frees us from adding special attributes. However, we would need to implement a separate thunk executor for this custom call.
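Unrolling by a factor of 2 halves the trip count and gives each iteration two independent dots. The sketch below is a hypothetical host-side model with the following assumptions:

- The ring permute delivers to device d the shard held by device d - 1.
- The number of partitions is even.
- The received shard is forwarded once more at the end of each iteration, so no partition is multiplied twice.

The sketch also records which partitions device 0 handles together in each iteration. In a real lowering, those two dots are the ones that could run on different compute streams. `unrolled_allgather_dot` and its helpers are illustrative names.

```python
def matmul(a, b):
    """Dense matmul on nested lists: (m x k) @ (k x n) -> (m x n)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def ring_shift(values):
    """Collective-permute on a ring: device d receives the value held by device d - 1."""
    n = len(values)
    return [values[(d - 1) % n] for d in range(n)]


def unrolled_allgather_dot(lhs_shards, rhs):
    """Simulate the AllGather+dot loop unrolled by 2 (two dots per iteration)."""
    n = len(lhs_shards)
    assert n % 2 == 0, "this sketch assumes an even number of partitions"
    rows = len(lhs_shards[0])
    results = [[None] * (rows * n) for _ in range(n)]
    current, origin = list(lhs_shards), list(range(n))
    schedule = []  # partitions handled together by device 0, per iteration
    for _ in range(n // 2):  # unrolling by 2 halves the trip count
        # collective-permute-start/done: receive the neighbor's shard.
        received, recv_origin = ring_shift(current), ring_shift(origin)
        # Two independent dots: the first can overlap the permute on a side
        # stream, the second consumes the received shard.
        for shards, origins in ((current, origin), (received, recv_origin)):
            for d in range(n):
                start = origins[d] * rows
                results[d][start:start + rows] = matmul(shards[d], rhs)
        schedule.append((origin[0], recv_origin[0]))
        # Forward the received shard so the next iteration starts on unseen data.
        current, origin = ring_shift(received), ring_shift(recv_origin)
    return results, schedule
```

With 4 partitions, device 0 processes partitions (0, 3) in the first iteration and (2, 1) in the second, so the loop covers every partition in 2 iterations instead of 4.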
Thunk Execution
Either alternative requires multi-stream execution. We describe how to achieve this with the while thunk alternative; the execution strategy is the same for the custom call alternative.
The current while thunk executes all compute thunks on a single stream. To support multi-stream execution, ExecuteParams will need to hold multiple compute streams. For now, their number can be controlled by a debug option, and we will create that many compute streams when executing the while thunk.
For optimal performance, we will need to add an operation_queue_id attribute to each instruction that runs on a non-default compute stream, telling the runtime which stream the kernel should be dispatched to. Instructions without the attribute are still dispatched the default way: collectives go to the collective stream and computes go to the default compute stream. Note that operation_queue_id is merely an opaque identifier and does not necessarily reflect the actual stream id on the hardware. Thunks need to keep a mapping from the HLO operation_queue_id to the actual hardware stream id where needed.
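The dispatch rule described above can be sketched with a fixed pool of extra compute streams, assuming the pool is sized by the debug option. `StreamMapper` and the stream labels are hypothetical. The sketch applies the attribute only to compute instructions, as the paragraph describes. operation_queue_id values are treated as opaque keys and mapped lazily to the available stream slots, while instructions without the attribute keep the default routing.

```python
from typing import Dict, Optional


class StreamMapper:
    """Maps opaque HLO operation_queue_id values to compute stream slots."""

    def __init__(self, num_extra_compute_streams: int):
        # Stand-in for the debug option that controls the number of streams.
        self.num_extra = num_extra_compute_streams
        self.slots: Dict[int, int] = {}

    def dispatch(self, is_collective: bool, queue_id: Optional[int] = None) -> str:
        if is_collective:
            # Collectives always go to the collective stream in this sketch.
            return "collective"
        if queue_id is None:
            # Non-attributed computes go to the default compute stream.
            return "compute[0]"
        if queue_id not in self.slots:
            if len(self.slots) == self.num_extra:
                raise RuntimeError(f"no compute stream left for queue id {queue_id}")
            # Slot 0 is the default compute stream, so extra streams start at 1.
            self.slots[queue_id] = len(self.slots) + 1
        return f"compute[{self.slots[queue_id]}]"
```

For example, queue ids 2 and 7 map to slots 1 and 2 in first-use order, regardless of their numeric values.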
Use AsyncStart and AsyncDone Ops with Synchronization thunk
Adding stream attributes to existing instructions alone doesn't change buffer liveness. This could drastically affect buffer assignment for parallel gemms, because the buffer assigner doesn't know that the gemms must use separate buffers. We could change buffer assignment to share buffers only between kernels on the same stream. However, buffer assignment is shared by all backends, so stream-specific logic is not reasonable without heavy refactoring of the code. @ezhulenev has suggested an approach here. The high-level idea is to add a pass that wraps compute kernels that don't run on the main stream into asyncStart and asyncDone operations. Infrastructure already exists to model buffer liveness for async pairs, so buffer assignment should already be taken care of.
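The wrapping pass can be illustrated on a toy instruction list modeled on the loop body above. This is a hypothetical sketch, not XLA's pass, with these assumptions:

- `Instr` and `wrap_side_stream_computes` are made up for illustration.
- A queue id of None or 0 means the main stream.
- Each async-done is placed directly after its async-start; a scheduler would later move it toward its first user.

Users of a wrapped instruction are rewired to the async-done, which is how the async pair spans the side-stream gemm's output.

```python
from dataclasses import dataclass, field, replace
from typing import List, Optional

COLLECTIVES = {"collective-permute-start", "collective-permute-done",
               "all-gather", "reduce-scatter"}


@dataclass
class Instr:
    name: str
    op: str
    operands: List[str] = field(default_factory=list)
    queue_id: Optional[int] = None  # None or 0 means the main compute stream


def wrap_side_stream_computes(instrs: List[Instr]) -> List[Instr]:
    """Wrap each compute instruction on a non-main stream in an async pair."""
    renamed = {}  # original name -> async-done that now provides its value
    out = []
    for ins in instrs:
        operands = [renamed.get(o, o) for o in ins.operands]
        if ins.op in COLLECTIVES or not ins.queue_id:
            out.append(replace(ins, operands=operands))
            continue
        start = Instr(f"{ins.name}-start", f"async-start[{ins.op}]", operands, ins.queue_id)
        done = Instr(f"{ins.name}-done", "async-done", [start.name])
        out.extend([start, done])
        renamed[ins.name] = done.name
    return out
```

Applied to the loop body, only `dot` (queue id 2) is wrapped, and the first dynamic-update-slice now reads `dot-done` instead of `dot`.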
To make parallel execution more explicit, we can introduce a synchronization thunk. The thunk waits on the streams of its operands and returns once the data is available. An example showing the interface and the definition of its ExecuteOnStream:
The IR emitter will emit this thunk right before:
an asyncStart kernel that runs on a non-default compute stream
the consumer of the corresponding asyncDone op, if that consumer runs on a different stream
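The two emission rules above can be expressed as a small lowering loop. The sketch is hypothetical: `Op`, `emit_thunks`, and the string thunk names are illustrative, not XLA's thunk API. It rests on three assumptions:

- Stream 0 is the main compute stream.
- Values with no recorded producer, such as parameters, live on stream 0.
- The output of an asyncDone belongs to the side stream its asyncStart ran on.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Op:
    name: str
    kind: str    # "async-start", "async-done", or "kernel"
    stream: int  # 0 is the main compute stream
    operands: List[str] = field(default_factory=list)


def emit_thunks(ops: List[Op]) -> List[str]:
    """Lower ops in order, inserting a sync thunk where a stream must wait."""
    produced_on = {}     # value name -> stream whose work produced it
    done_values = set()  # values produced by async-done ops
    thunks = []
    for op in ops:
        # Streams (other than this op's own) that produced its operands.
        waits = sorted({produced_on.get(o, 0) for o in op.operands} - {op.stream})
        is_side_start = op.kind == "async-start" and op.stream != 0
        consumes_cross_stream_done = any(
            o in done_values and produced_on[o] != op.stream for o in op.operands)
        if waits and (is_side_start or consumes_cross_stream_done):
            thunks.append(f"sync(stream={op.stream}, wait_on={waits})")
        thunks.append(f"{op.kind}:{op.name}@{op.stream}")
        if op.kind == "async-done":
            # The done's value is produced by the stream its start ran on.
            produced_on[op.name] = produced_on[op.operands[0]]
            done_values.add(op.name)
        else:
            produced_on[op.name] = op.stream
    return thunks
```

For the loop body, this places one sync before the side-stream asyncStart, waiting on the main stream for its inputs. It places another before the first dynamic-update-slice, waiting on the side stream for the dot result.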
Here’s a high-level flow of the lowering logic:
Other Considerations
Scheduling: for both alternatives, we will need to introduce a new scheduler resource type so that the latency hiding scheduler (LHS) won't try to overlap these operations with other collectives.
The number of shards is currently determined by a simple model in DotHandler, but we still need to decide how many gemms to run in parallel. For the initial phase, we can assume 2 streams. The end goal is to use a cost model, possibly GpuPerformanceModel, to determine concurrency. However, this requires knowing whether a dot will be lowered to Triton or cuBLAS, which the current phase ordering of the SPMD passes does not provide. We'd likely need to introduce another pass after all the gemm rewriters that uses the cost model to decide whether gemms should actually execute concurrently.
Triggering condition: collective matmul is currently controlled by a threshold value stored in an internal field, which is disabled by default for GPU. We will keep this mechanism and expose the threshold in debug options, so users can decide whether to trigger it based on their model size. The default threshold will be determined heuristically once we have conducted more experiments.
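A threshold like the one described above could gate the rewrite as follows. The function names, the choice of which tensor's size is compared, and the use of None for "disabled" are all assumptions made for illustration. The proposal only states that a size threshold exists and will be exposed in debug options.

```python
import math
from typing import Optional, Sequence

MIB = 1024 * 1024


def size_mib(shape: Sequence[int], bytes_per_element: int) -> float:
    """Size of a dense tensor in MiB."""
    return math.prod(shape) * bytes_per_element / MIB


def should_use_collective_matmul(shape: Sequence[int], bytes_per_element: int,
                                 threshold_mib: Optional[float]) -> bool:
    """Hypothetical trigger: rewrite the dot only if the tensor is large enough."""
    if threshold_mib is None:
        # Models the current GPU default, where the rewrite is disabled.
        return False
    return size_mib(shape, bytes_per_element) >= threshold_mib
```

For example, an 8192 x 8192 bf16 tensor is 128 MiB. It would trigger the rewrite with a 64 MiB threshold but not with a 256 MiB one.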
Appendix