From d6b20b50d7f6ba16b7df12caca4b422c3758ae26 Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Tue, 11 Feb 2025 09:46:40 +0000 Subject: [PATCH] Revert accidental push of fused-seqpar PRs They were not ready yet __original_commit__ = fairinternal/xformers@d75afed8b0b7a6f1362e5e176bdd5c5f000c8242 --- .../benchmark_sequence_parallel_fused.py | 56 ++ xformers/ops/seqpar.py | 27 +- xformers/ops/sequence_parallel_fused_ops.py | 785 ++++++++++++++---- 3 files changed, 690 insertions(+), 178 deletions(-) diff --git a/xformers/benchmarks/benchmark_sequence_parallel_fused.py b/xformers/benchmarks/benchmark_sequence_parallel_fused.py index 983427f03c..cd0d646a53 100644 --- a/xformers/benchmarks/benchmark_sequence_parallel_fused.py +++ b/xformers/benchmarks/benchmark_sequence_parallel_fused.py @@ -142,6 +142,8 @@ def run_one_rank( torch.distributed.init_process_group(backend="nccl", init_method="env://") subgroup = torch.distributed.new_group() + subgroup_nowait = torch.distributed.new_group() + subgroup_nowait_nomemcpy = torch.distributed.new_group() scenario = SCENARIOS[scenario_name](world_size) if step is Step.AllGather: @@ -235,6 +237,56 @@ def run_fused_rs(): timeout_s=10, ) + def run_fused_nowait_ag(): + nonlocal gathered_outputs_fused + from xformers.ops import fused_allgather_and_linear + + gathered_outputs_fused = fused_allgather_and_linear( + scattered_input, + [w.t() for w in weights], + group=subgroup_nowait, + _wait=False, + timeout_s=10, + ) + + def run_fused_nowait_rs(): + nonlocal scattered_outputs_fused + from xformers.ops import fused_linear_and_reducescatter + + scattered_outputs_fused = fused_linear_and_reducescatter( + gathered_input, + [w.t() for w in weights], + group=subgroup_nowait, + _wait=False, + timeout_s=10, + ) + + def run_fused_nowait_nomemcpy_ag(): + nonlocal gathered_outputs_fused + from xformers.ops import fused_allgather_and_linear + + gathered_outputs_fused = fused_allgather_and_linear( + scattered_input, + [w.t() for w in weights], + group=subgroup_nowait_nomemcpy, + _wait=False, + _memcpy=False, + timeout_s=10, + ) + + def run_fused_nowait_nomemcpy_rs(): + nonlocal scattered_outputs_fused + from xformers.ops import fused_linear_and_reducescatter + + scattered_outputs_fused = fused_linear_and_reducescatter( + gathered_input, + [w.t() for w in weights], + group=subgroup_nowait_nomemcpy, + _wait=False, + _memcpy=False, + timeout_s=10, + ) + print(f"Sizes: ({world_size}x{M // world_size})x({num_matrices}x{N})x{K}") if step is Step.AllGather: @@ -302,6 +354,10 @@ def run_fused_rs(): ), "nccl_reference": Bench(ag=run_nccl_reference_ag, rs=run_nccl_reference_rs), "fused": Bench(ag=run_fused_ag, rs=run_fused_rs), + "fused_nowait": Bench(ag=run_fused_nowait_ag, rs=run_fused_nowait_rs), + "fused_nowait_nomemcpy": Bench( + ag=run_fused_nowait_nomemcpy_ag, rs=run_fused_nowait_nomemcpy_rs + ), } unused_events = deque( diff --git a/xformers/ops/seqpar.py b/xformers/ops/seqpar.py index 05fddf9d77..b734911fa0 100644 --- a/xformers/ops/seqpar.py +++ b/xformers/ops/seqpar.py @@ -97,20 +97,21 @@ def sequence_parallel_leading_matmul_bwd( ] def my_si_matmul( - grad_gathered_input: torch.Tensor, + grad_gathered_inputs: List[torch.Tensor], dst_rank: int, stream_factory: Callable[[], torch.cuda.Stream], ) -> None: + (grad_gi,) = grad_gathered_inputs with torch.cuda.stream(stream_factory()): tiled_matmul_out( [[grad_gos[dst_rank] for grad_gos in grad_gathered_outputss]], [[w.t()] for w in weights], - out=[[grad_gathered_input]], + out=[[grad_gi]], ) fused_anything_and_reducescatter( 
my_si_matmul, - grad_scattered_input, + [grad_scattered_input], group=process_group, ) @@ -120,20 +121,21 @@ def my_si_matmul( events = [torch.cuda.Event() for _ in weights] def my_w_matmul( - gathered_input_shard: torch.Tensor, + gathered_inputs_shard: List[torch.Tensor], src_rank: int, stream_factory: Callable[[], torch.cuda.Stream], ) -> None: + (gi_shard,) = gathered_inputs_shard for grad_gos, grad_w, event in zip( grad_gathered_outputss, grad_weights, events ): with torch.cuda.stream(stream_factory()): event.wait() - grad_w.t().addmm_(grad_gos[src_rank].t(), gathered_input_shard) + grad_w.t().addmm_(grad_gos[src_rank].t(), gi_shard) event.record() fused_allgather_and_anything( - scattered_input, + [scattered_input], my_w_matmul, group=process_group, ) @@ -280,23 +282,20 @@ def sequence_parallel_trailing_matmul_bwd( grad_gathered_inputs = grad_gathered_input.tensor_split(mp_size, dim=0) def my_gi_and_w_matmul( - grad_gathered_output_shard: torch.Tensor, + grad_gathered_outputs_shard: List[torch.Tensor], src_rank: int, stream_factory: Callable[[], torch.cuda.Stream], ) -> None: + (grad_go_shard,) = grad_gathered_outputs_shard with torch.cuda.stream(stream_factory()): torch.matmul( - grad_gathered_output_shard, - weight.t(), - out=grad_gathered_inputs[src_rank], + grad_go_shard, weight.t(), out=grad_gathered_inputs[src_rank] ) with torch.cuda.stream(stream_factory()): - grad_weight.t().addmm_( - grad_gathered_output_shard.t(), gathered_inputs[src_rank] - ) + grad_weight.t().addmm_(grad_go_shard.t(), gathered_inputs[src_rank]) fused_allgather_and_anything( - grad_scattered_output, + [grad_scattered_output], my_gi_and_w_matmul, group=process_group, ) diff --git a/xformers/ops/sequence_parallel_fused_ops.py b/xformers/ops/sequence_parallel_fused_ops.py index e1fb00680d..31fb7b471c 100644 --- a/xformers/ops/sequence_parallel_fused_ops.py +++ b/xformers/ops/sequence_parallel_fused_ops.py @@ -4,14 +4,17 @@ # LICENSE file in the root directory of this source tree. import os -from typing import Callable, List, Optional, Sequence, Union, cast, overload +from typing import Callable, Dict, List, Optional, Sequence, Union, overload import torch import torch.distributed as dist -from torch.distributed._symmetric_memory import ( - _pipelined_all_gather_and_consume, - _pipelined_produce_and_all2all, -) +import torch.multiprocessing.reductions +from torch.distributed._symmetric_memory import get_symm_mem_workspace + +OP_FINISHED_CHANNEL = 0 +COMMS_READY_CHANNEL = 1 + +MS_IN_S = 1_000 def _is_fp8_dtype(dt: torch.dtype): @@ -20,6 +23,322 @@ def _is_fp8_dtype(dt: torch.dtype): return dt.is_floating_point and torch.finfo(dt).bits == 8 +class _FusedSequenceParallel: + """Set up a communication ring and perform fused ops on it + + Stores the persistent state needed to support a ring of connections between + processes, and the logic that can do fused comms + matmuls on it. + + We want to achieve overlap between: + - a computation which reads from the data we received from a remote GPU + - and the communication where we send some data to another GPU + And in order to do that we need some staging buffers and a way to + synchronize access to them across processes. + + To perform the communication over NVLink we make the processes exchange + their staging buffers using IPC (Inter-Process Communication) handles, which + "mounts"/"mmaps" an allocation on one GPU into the virtual address space of + another GPU: the memory remains backed by the original GPU but the other GPU + can access it as if it were local. 
+    We exchange these IPC handles using
+    multiprocessing Connections (and the "reductions" provided by PyTorch),
+    which we establish over UNIX domain sockets, whose addresses we exchange by
+    using a ProcessGroup.
+
+    To synchronize accesses we use a set of counters/sequence numbers that are
+    also allocated in memory shared over IPC handles. Processes signal that they
+    completed an operation by launching a kernel that increases that value, and
+    they wait for another process to complete an operation by launching a kernel
+    that busy-waits for that value to increase. Currently we implement these
+    kernels manually, but on recent CUDA drivers (515.43.04+, corresponding to
+    CUDA 11.7) we could use standard stream memory operations (see
+    https://docs.nvidia.com/cuda/archive/11.7.0/cuda-driver-api/group__CUDA__MEMOP.html).
+
+    We prefer to use these kernels (or the stream memory ops) over IPC events
+    because IPC events require signaling between processes at launch time to
+    ensure that the wait on one process occurs after the record on another
+    process. This signaling means that _launching_ our fused operation becomes a
+    synchronization barrier, which can increase the launch overhead. It would
+    also behave differently from NCCL, where launching is async and all the
+    synchronization happens on device in the kernels. A previous version of this
+    code which uses IPC events can be found here:
+    https://github.com/fairinternal/xformers/pull/504.
+
+    """
+
+    def __init__(
+        self,
+        device: torch.device,
+        group: dist.ProcessGroup,
+    ):
+        self.my_device = device
+        self.my_rank = group.rank()
+        self.world_size = group.size()
+        self.group = group
+
+        self.second_stream = torch.cuda.Stream()
+        # CUDA can schedule the matmul and the memcpy at the same time, but it
+        # tends to run the matmul first and delay the memcpy, which causes a
+        # domino effect. We thus "encourage" it to prioritize the memcpy.
+        self.memcpy_stream = torch.cuda.Stream(priority=-1)
+        # Use dedicated streams to run the wait kernels in the background.
+        self.compute_wait_stream = torch.cuda.Stream(priority=-1)
+        self.memcpy_wait_stream = torch.cuda.Stream(priority=-1)
+
+        self.next_stream_idx = 0
+
+    def make_stream_factory(
+        self, current_stream: torch.cuda.Stream
+    ) -> Callable[[], torch.cuda.Stream]:
+        def result():
+            stream = [current_stream, self.second_stream][self.next_stream_idx]
+            self.next_stream_idx += 1
+            self.next_stream_idx %= 2
+            return stream
+
+        return result
+
+    def allgather_and_linear(
+        self,
+        scattered_inputs: List[torch.Tensor],
+        my_matmul: Callable[
+            [List[torch.Tensor], int, Callable[[], torch.cuda.Stream]], None
+        ],
+        timeout_s: int,
+        _wait: bool = True,
+        _memcpy: bool = True,
+    ):
+        """Perform a fused all-gather followed by a linear layer"""
+
+        dtype = scattered_inputs[0].dtype
+        assert all(si.device == self.my_device for si in scattered_inputs)
+        assert all(si.dtype == dtype for si in scattered_inputs)
+
+        scattered_input_numels = [si.numel() for si in scattered_inputs]
+        total_scattered_input_numel = sum(scattered_input_numels)
+
+        with torch.cuda.device(self.my_device):
+            symm_mem = get_symm_mem_workspace(
+                self.group.group_name,
+                self.world_size * total_scattered_input_numel * dtype.itemsize,
+            )
+            # FIXME Do something about random_init if _memcpy is True.
+ buffers = [ + [ + s.view((self.world_size,) + si.shape) + for s, si in zip( + symm_mem.get_buffer( + rank, [self.world_size, total_scattered_input_numel], dtype + ).split(scattered_input_numels, dim=-1), + scattered_inputs, + ) + ] + for rank in range(self.world_size) + ] + + current_stream = torch.cuda.current_stream() + + # Signal to buddy that we have read from the data (in previous iter) so + # it can overwrite it (this write matches up with wait [B] below). + for iter_ in range(1, self.world_size): + src_rank = (self.my_rank - iter_) % self.world_size + if _wait: + with torch.cuda.stream(current_stream): + symm_mem.put_signal(src_rank, OP_FINISHED_CHANNEL) + + self.second_stream.wait_stream(current_stream) + self.compute_wait_stream.wait_stream(current_stream) + self.memcpy_wait_stream.wait_stream(current_stream) + stream_factory = self.make_stream_factory(current_stream) + + for iter_ in range(1, self.world_size): + dst_rank = (self.my_rank + iter_) % self.world_size + + # Wait for buddy to signal that it read from the data before we + # overwrite it (this wait matches up with write [B] above). + if _wait: + with torch.cuda.stream(self.memcpy_wait_stream): + symm_mem.wait_signal( + dst_rank, + OP_FINISHED_CHANNEL, + timeout_ms=timeout_s * MS_IN_S, # type: ignore[call-arg] + ) + + self.memcpy_stream.wait_stream(self.memcpy_wait_stream) + + if _memcpy: + with torch.cuda.stream(self.memcpy_stream): + for bs, si in zip(buffers[dst_rank], scattered_inputs): + bs[self.my_rank].copy_(si) + + # Signal to buddy that we have written into the data so it can + # read from it (this write matches up with wait [A] below). + if _wait: + with torch.cuda.stream(self.memcpy_stream): + symm_mem.memset32( + symm_mem.get_signal_pad(dst_rank), # type: ignore[attr-defined] + self.world_size * COMMS_READY_CHANNEL + self.my_rank, + val=1, + count=1, + ) + + my_matmul(scattered_inputs, self.my_rank, stream_factory) + + for iter_ in range(1, self.world_size): + src_rank = (self.my_rank - iter_) % self.world_size + + # Wait for buddy to signal that it wrote into the data before we + # read from it (this wait matches up with write [A] above). 
+ if _wait: + with torch.cuda.stream(self.compute_wait_stream): + symm_mem.wait_signal( + src_rank, + COMMS_READY_CHANNEL, + timeout_ms=timeout_s * MS_IN_S, # type: ignore[call-arg] + ) + + current_stream.wait_stream(self.compute_wait_stream) + self.second_stream.wait_stream(self.compute_wait_stream) + + my_matmul( + [s[src_rank] for s in buffers[self.my_rank]], src_rank, stream_factory + ) + + current_stream.wait_stream(self.second_stream) + current_stream.wait_stream(self.memcpy_stream) + + def linear_and_reducescatter( + self, + my_matmul: Callable[ + [List[torch.Tensor], int, Callable[[], torch.cuda.Stream]], None + ], + gathered_outputs: List[torch.Tensor], + scattered_outputs: List[torch.Tensor], + timeout_s: int, + _wait: bool = True, + _memcpy: bool = True, + ): + """Perform a fused linear layer followed by a reduce-scatter""" + + dtype = gathered_outputs[0].dtype + assert all(go.device == self.my_device for go in gathered_outputs) + assert all(go.dtype == dtype for go in gathered_outputs) + assert all(so.device == self.my_device for so in scattered_outputs) + assert all(so.dtype == dtype for so in scattered_outputs) + + scattered_output_numels = [so.numel() for so in scattered_outputs] + total_scattered_output_numel = sum(scattered_output_numels) + + with torch.cuda.device(self.my_device): + symm_mem = get_symm_mem_workspace( + self.group.group_name, + self.world_size * total_scattered_output_numel * dtype.itemsize, + ) + # FIXME Do something about random_init if _memcpy is True. + buffers = [ + [ + s.view((self.world_size,) + so.shape) + for s, so in zip( + symm_mem.get_buffer( + rank, [self.world_size, total_scattered_output_numel], dtype + ).split(scattered_output_numels, dim=-1), + scattered_outputs, + ) + ] + for rank in range(self.world_size) + ] + + current_stream = torch.cuda.current_stream() + + # Signal to buddy that we have read from the data (in previous iter) + # so it can overwrite it (this write matches up with wait [2] below). + for iter_ in range(1, self.world_size): + src_rank = (self.my_rank - iter_) % self.world_size + if _wait: + with torch.cuda.stream(current_stream): + symm_mem.put_signal(src_rank, OP_FINISHED_CHANNEL) + + self.second_stream.wait_stream(current_stream) + self.compute_wait_stream.wait_stream(current_stream) + self.memcpy_wait_stream.wait_stream(current_stream) + stream_factory = self.make_stream_factory(current_stream) + + for iter_ in range(1, self.world_size): + dst_rank = (self.my_rank + iter_) % self.world_size + + # Wait for buddy to signal that it read from the data before we + # overwrite it (this wait matches up with write [2] above). + if _wait: + with torch.cuda.stream(self.compute_wait_stream): + symm_mem.wait_signal( + dst_rank, + OP_FINISHED_CHANNEL, + timeout_ms=timeout_s * MS_IN_S, # type: ignore[call-arg] + ) + + current_stream.wait_stream(self.compute_wait_stream) + self.second_stream.wait_stream(self.compute_wait_stream) + + my_matmul( + [s[dst_rank] for s in buffers[self.my_rank]], dst_rank, stream_factory + ) + + # Deduce which stream contains the last kernel launched. + final_stream = [current_stream, self.second_stream][ + (self.next_stream_idx - 1) % 2 + ] + final_stream.wait_stream(current_stream) + final_stream.wait_stream(self.second_stream) + + # Signal to buddy that we have written into the data so it can + # read from it (this write matches up with wait [1] below). 
+ if _wait: + with torch.cuda.stream(final_stream): + symm_mem.memset32( + symm_mem.get_signal_pad(dst_rank), # type: ignore[attr-defined] + self.world_size * COMMS_READY_CHANNEL + self.my_rank, + val=1, + count=1, + ) + + my_matmul( + [o[self.my_rank] for o in gathered_outputs], + self.my_rank, + stream_factory, + ) + + for iter_ in range(1, self.world_size): + src_rank = (self.my_rank - iter_) % self.world_size + + # Wait for buddy to signal that it wrote into the data before we + # read from it (this wait matches up with write [1] above). + if _wait: + with torch.cuda.stream(self.memcpy_wait_stream): + symm_mem.wait_signal( + src_rank, + COMMS_READY_CHANNEL, + timeout_ms=timeout_s * MS_IN_S, # type: ignore[call-arg] + ) + + self.memcpy_stream.wait_stream(self.memcpy_wait_stream) + + if _memcpy: + with torch.cuda.stream(self.memcpy_stream): + for go, bs in zip(gathered_outputs, buffers[src_rank]): + go[src_rank].copy_(bs[self.my_rank]) + + current_stream.wait_stream(self.second_stream) + current_stream.wait_stream(self.memcpy_stream) + + for go, so in zip(gathered_outputs, scattered_outputs): + torch.sum(go, dim=0, out=so) + + +# We'd store this as an attribute on the PG object itself, but some PGs are +# pybind-bound classes and thus don't support it, so we simulate this as an +# external cache. +CACHE: Dict[int, Optional[_FusedSequenceParallel]] = {} + + def _can_ranks_communicate_all_to_all_over_nvlink(group: dist.ProcessGroup) -> bool: # FIXME This is currently overly simplistic, must be improved. The following # should be enough: @@ -33,15 +352,23 @@ def _can_ranks_communicate_all_to_all_over_nvlink(group: dist.ProcessGroup) -> b return group.size() <= 8 -def _should_use_fallback(group: dist.ProcessGroup) -> bool: +def _lazy_init( + device: torch.device, group: dist.ProcessGroup +) -> Optional[_FusedSequenceParallel]: world_size = group.size() - if int(os.environ.get("DISABLE_FUSED_SEQUENCE_PARALLEL", "0")): - return True - elif world_size == 1: - return True - elif not _can_ranks_communicate_all_to_all_over_nvlink(group): - return True - return False + try: + obj = CACHE[id(group)] + except KeyError: + if int(os.environ.get("DISABLE_FUSED_SEQUENCE_PARALLEL", "0")): + obj = None + elif world_size == 1: + obj = None + elif not _can_ranks_communicate_all_to_all_over_nvlink(group): + obj = None + else: + obj = _FusedSequenceParallel(device, group) + CACHE[id(group)] = obj + return obj def _default_stream_factory() -> torch.cuda.Stream: @@ -146,74 +473,101 @@ def fused_allgather_and_linear( assert scattered_input.ndim >= 2 assert all(scattered_input.shape[-1] == w.shape[-1] for w in weights) assert scattered_input.is_contiguous() - - # Fallback - if world_size == 1 or _should_use_fallback(group): - if world_size == 1: - gathered_input = scattered_input - else: - gathered_input = scattered_input.new_empty( - (world_size,) + scattered_input.shape - ).flatten(0, 1) - dist.all_gather_into_tensor( - output_tensor=gathered_input, input_tensor=scattered_input, group=group - ) - - if scale_scattered_input is None: - gathered_outputs = [torch.matmul(gathered_input, w.t()) for w in weights] - else: - gathered_outputs = [ - torch._scaled_mm( - gathered_input, - w.t(), - scale_scattered_input, - cast(torch.Tensor, scale_w), - out_dtype=out_dtype, - ) - for w, scale_w in zip(weights, scales_weights) - ] - - # Fast path - else: - if scale_scattered_input is None: - _, gathered_outputs = torch.ops.symm_mem.fused_all_gather_matmul( - scattered_input, [w.t() for w in weights], 0, group.group_name - ) - 
else: - _, gathered_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( - scattered_input, - [w.t() for w in weights], - scale_scattered_input, - scales_weights, - 0, - group.group_name, - biases=[None] * len(weights), - result_scales=[None] * len(weights), - out_dtypes=[out_dtype] * len(weights), - use_fast_accum=[False] * len(weights), - ) - + gathered_input_shape = (world_size,) + scattered_input.shape + gathered_output_shapes = [gathered_input_shape[:-1] + w.shape[:-1] for w in weights] if out is not None: assert isinstance(out, list) == isinstance(weight, list) - outs = out if isinstance(out, list) else [out] - assert len(outs) == len(gathered_outputs) - for o, go in zip(outs, gathered_outputs): - assert o.device == go.device - assert o.dtype == go.dtype - assert o.shape == go.shape - if out_dtype is not None: - assert o.dtype == out_dtype - o.copy_(go) + gathered_outputs = out if isinstance(out, list) else [out] + assert len(gathered_outputs) == len(gathered_output_shapes) + assert all( + go.shape == gos for go, gos in zip(gathered_outputs, gathered_output_shapes) + ) + assert all(go.is_contiguous() for go in gathered_outputs) + if out_dtype is not None: + if isinstance(out, list): + for o in out: + assert o.dtype == out_dtype + else: + assert out.dtype == out_dtype + else: + gathered_outputs = [ + scattered_input.new_empty( + gos, + dtype=out_dtype if out_dtype is not None else scattered_input.dtype, + ) + for gos in gathered_output_shapes + ] + + torch.ops.xformers_python._fused_allgather_and_linear_impl( + scattered_input, + weights, + group.group_name, + gathered_outputs, + timeout_s=timeout_s, + _wait=private_args_DO_NOT_USE.get("_wait", True), + _memcpy=private_args_DO_NOT_USE.get("_memcpy", True), + scale_scattered_input=scale_scattered_input, + scales_weights=scales_weights, + ) if isinstance(weight, list): - return gathered_outputs + return [go.flatten(0, 1) for go in gathered_outputs] else: - return gathered_outputs[0] + return gathered_outputs[0].flatten(0, 1) -def fused_allgather_and_anything( +@torch.library.custom_op( + "xformers_python::_fused_allgather_and_linear_impl", + mutates_args={"gathered_outputs"}, + device_types="cuda", +) +def _fused_allgather_and_linear_custom_op( scattered_input: torch.Tensor, - my_matmul: Callable[[torch.Tensor, int, Callable[[], torch.cuda.Stream]], None], + weights: List[torch.Tensor], + process_group_name: str, + gathered_outputs: List[torch.Tensor], + timeout_s: int, + _wait: bool, + _memcpy: bool, + scale_scattered_input: torch.Tensor, + scales_weights: Sequence[Optional[torch.Tensor]], +) -> None: + process_group = dist.distributed_c10d._resolve_process_group(process_group_name) + + def my_matmul( + inputs: List[torch.Tensor], + src_rank: int, + stream_factory: Callable[[], torch.cuda.Stream], + ) -> None: + for w, scale_weight, go in zip(weights, scales_weights, gathered_outputs): + with torch.cuda.stream(stream_factory()): + if scale_scattered_input is not None and scale_weight is not None: + torch._scaled_mm( + inputs[0], + w.t(), + out_dtype=go[src_rank].dtype, + scale_a=scale_scattered_input, + scale_b=scale_weight, + out=go[src_rank], + ) + else: + torch.matmul(inputs[0], w.t(), out=go[src_rank]) + + fused_allgather_and_anything( + [scattered_input], + my_matmul, + group=process_group, + timeout_s=timeout_s, + _wait=_wait, + _memcpy=_memcpy, + ) + + +def fused_allgather_and_anything( + scattered_inputs: List[torch.Tensor], + my_matmul: Callable[ + [List[torch.Tensor], int, Callable[[], torch.cuda.Stream]], None + ], 
*, group: dist.ProcessGroup, timeout_s: int = 60 * 60, @@ -221,41 +575,50 @@ def fused_allgather_and_anything( ) -> None: world_size = group.size() - assert scattered_input.is_contiguous() + if len(scattered_inputs) == 0: + for src_rank in range(world_size): + my_matmul([], src_rank, _default_stream_factory) + return - gathered_input_shape = (world_size,) + scattered_input.shape + assert all(si.is_contiguous() for si in scattered_inputs) + assert all(si.device == scattered_inputs[0].device for si in scattered_inputs) + assert all(si.dtype == scattered_inputs[0].dtype for si in scattered_inputs) + + gathered_input_shapes = [(world_size,) + si.shape for si in scattered_inputs] + + obj = _lazy_init(scattered_inputs[0].device, group) if world_size == 1: - my_matmul(scattered_input, 0, _default_stream_factory) + my_matmul(scattered_inputs, 0, _default_stream_factory) # Fallback - elif _should_use_fallback(group): - gathered_input = scattered_input.new_empty(gathered_input_shape) - dist.all_gather_into_tensor( - output_tensor=gathered_input, input_tensor=scattered_input, group=group - ) + elif obj is None: + gathered_inputs = [ + si.new_empty(gis) + for si, gis in zip(scattered_inputs, gathered_input_shapes) + ] + for si, gi in zip(scattered_inputs, gathered_inputs): + dist.all_gather_into_tensor(output_tensor=gi, input_tensor=si, group=group) for src_rank in range(world_size): my_matmul( - gathered_input[src_rank], + [gi[src_rank] for gi in gathered_inputs], src_rank, _default_stream_factory, ) # Fast path else: - - def my_wrapper(t, rank): - my_matmul(t.squeeze(0), rank, _default_stream_factory) - - gathered_input = scattered_input.new_empty(gathered_input_shape) - _pipelined_all_gather_and_consume( - scattered_input, - my_wrapper, - gathered_input, - group.group_name, + assert scattered_inputs[0].device == obj.my_device + obj.allgather_and_linear( + scattered_inputs, + my_matmul, + timeout_s=timeout_s, + _wait=private_args_DO_NOT_USE.get("_wait", True), + _memcpy=private_args_DO_NOT_USE.get("_memcpy", True), ) +@overload def fused_linear_and_reducescatter( gathered_input: torch.Tensor, weight: torch.Tensor, @@ -264,10 +627,41 @@ def fused_linear_and_reducescatter( out: Optional[torch.Tensor] = None, timeout_s: int = 60 * 60, scale_gathered_input: Optional[torch.Tensor] = None, - scale_weight: Optional[torch.Tensor] = None, + scale_weight: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, out_dtype: Optional[torch.dtype] = None, **private_args_DO_NOT_USE, ) -> torch.Tensor: + ... + + +@overload +def fused_linear_and_reducescatter( + gathered_input: torch.Tensor, + weight: List[torch.Tensor], + *, + group: dist.ProcessGroup, + out: Optional[List[torch.Tensor]] = None, + timeout_s: int = 60 * 60, + scale_gathered_input: Optional[torch.Tensor] = None, + scale_weight: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + out_dtype: Optional[torch.dtype] = None, + **private_args_DO_NOT_USE, +) -> List[torch.Tensor]: + ... 
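For reference, the semantics that fused_linear_and_reducescatter (declared just above, implemented below) provides can be expressed with plain, unfused PyTorch calls, mirroring the NCCL fallback branch that the fused path replaces. This is a minimal sketch: the function name and shapes are illustrative and error checking is omitted.

import torch
import torch.distributed as dist


def reference_linear_and_reducescatter(
    gathered_input: torch.Tensor,  # [world_size * M, K], one row block per rank
    weight: torch.Tensor,  # [N, K]
    group: dist.ProcessGroup,
) -> torch.Tensor:
    world_size = group.size()
    # Plain matmul over the whole gathered input.
    gathered_output = torch.matmul(gathered_input, weight.t())  # [world_size * M, N]
    # Reduce-scatter: every rank receives the sum of its own row block.
    scattered_output = torch.empty_like(
        gathered_output.unflatten(0, (world_size, -1))[0]
    )
    dist.reduce_scatter_tensor(
        output=scattered_output, input=gathered_output, group=group
    )
    return scattered_output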
+ + +def fused_linear_and_reducescatter( + gathered_input: torch.Tensor, + weight: Union[torch.Tensor, List[torch.Tensor]], + *, + group: dist.ProcessGroup, + out: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + timeout_s: int = 60 * 60, + scale_gathered_input: Optional[torch.Tensor] = None, + scale_weight: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + out_dtype: Optional[torch.dtype] = None, + **private_args_DO_NOT_USE, +) -> Union[torch.Tensor, List[torch.Tensor]]: """Performs a fused linear op followed by a reduce-scatter It is equivalent to the following plain PyTorch code: @@ -285,72 +679,124 @@ def fused_linear_and_reducescatter( gathered_input datatype. """ world_size = group.size() + weights = weight if isinstance(weight, list) else [weight] assert (scale_gathered_input is None) == (scale_weight is None) if scale_weight is not None: + assert isinstance(weight, list) == isinstance(scale_weight, list) + scales_weights: Sequence[Optional[torch.Tensor]] = ( + scale_weight if isinstance(scale_weight, list) else [scale_weight] + ) + assert len(weights) == len(scales_weights) assert _is_fp8_dtype(gathered_input.dtype) - assert _is_fp8_dtype(weight.dtype) + assert all(_is_fp8_dtype(w.dtype) for w in weights) assert out_dtype is not None, "output_dtype is required with FP8" - assert weight.ndim == 2 + else: + scales_weights = [None] * len(weights) + assert all(w.ndim == 2 for w in weights) assert gathered_input.ndim >= 2 - assert gathered_input.shape[-1] == weight.shape[-1] + assert all(gathered_input.shape[-1] == w.shape[-1] for w in weights) assert gathered_input.is_contiguous() assert gathered_input.shape[0] % world_size == 0 - - # Fallback - if world_size == 1 or _should_use_fallback(group): - if scale_gathered_input is None: - gathered_output = torch.matmul(gathered_input, weight.t()) - else: - gathered_output = torch._scaled_mm( - gathered_input, - weight.t(), - scale_gathered_input, - cast(torch.Tensor, scale_weight), - out_dtype=out_dtype, - ) - - if world_size == 1: - scattered_output = gathered_output - else: - scattered_output = torch.empty_like( - gathered_output.unflatten(0, (world_size, -1))[0] - ) - dist.reduce_scatter_tensor( - output=scattered_output, input=gathered_output, group=group + gathered_input = gathered_input.view( + (world_size, gathered_input.shape[0] // world_size) + gathered_input.shape[1:] + ) + gathered_output_shapes = [gathered_input.shape[:-1] + w.shape[:-1] for w in weights] + scattered_output_shapes = [gos[1:] for gos in gathered_output_shapes] + if out is not None: + assert isinstance(out, list) == isinstance(weight, list) + scattered_outputs = out if isinstance(out, list) else [out] + assert len(scattered_outputs) == scattered_output_shapes + assert all(so.device == gathered_input.device for so in scattered_outputs) + assert all(so.dtype == gathered_input.dtype for so in scattered_outputs) + assert all( + so.shape == sos + for so, sos in zip(scattered_outputs, scattered_output_shapes) + ) + if out_dtype is not None: + if isinstance(out, list): + for o in out: + assert o.dtype == out_dtype + else: + assert out.dtype == out_dtype + else: + scattered_outputs = [ + gathered_input.new_empty( + sos, + dtype=out_dtype if out_dtype is not None else gathered_input.dtype, ) + for sos in scattered_output_shapes + ] + + torch.ops.xformers_python._fused_linear_and_reducescatter_impl( + gathered_input, + weights, + group.group_name, + scattered_outputs, + timeout_s=timeout_s, + _wait=private_args_DO_NOT_USE.get("_wait", True), + 
_memcpy=private_args_DO_NOT_USE.get("_memcpy", True), + scale_gathered_input=scale_gathered_input, + scales_weights=scales_weights, + ) - # Fast path + if isinstance(weight, list): + return scattered_outputs else: - if scale_gathered_input is None: - scattered_output = torch.ops.symm_mem.fused_matmul_reduce_scatter( - gathered_input, weight.t(), "sum", 0, group.group_name - ) - else: - scattered_output = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( - gathered_input, - weight.t(), - scale_gathered_input, - scale_weight, - "sum", - 0, - group.group_name, - out_dtype=out_dtype, - ) + return scattered_outputs[0] - if out is not None: - assert out.device == scattered_output.device - assert out.dtype == scattered_output.dtype - assert out.shape == scattered_output.shape - if out_dtype is not None: - assert out.dtype == out_dtype - out.copy_(scattered_output) - return scattered_output +@torch.library.custom_op( + "xformers_python::_fused_linear_and_reducescatter_impl", + mutates_args={"scattered_outputs"}, + device_types="cuda", +) +def _fused_linear_and_reducescatter_custom_op( + gathered_input: torch.Tensor, + weights: List[torch.Tensor], + process_group_name: str, + scattered_outputs: List[torch.Tensor], + timeout_s: int, + _wait: bool, + _memcpy: bool, + scale_gathered_input: torch.Tensor, + scales_weights: Sequence[Optional[torch.Tensor]], +) -> None: + process_group = dist.distributed_c10d._resolve_process_group(process_group_name) + + def my_matmul( + outputs: List[torch.Tensor], + dst_rank: int, + stream_factory: Callable[[], torch.cuda.Stream], + ) -> None: + for w, scale_weight, o in zip(weights, scales_weights, outputs): + with torch.cuda.stream(stream_factory()): + if scale_gathered_input is not None and scale_weight is not None: + torch._scaled_mm( + gathered_input[dst_rank], + w.t(), + out_dtype=o.dtype, + scale_a=scale_gathered_input, + scale_b=scale_weight, + out=o, + ) + else: + torch.matmul(gathered_input[dst_rank], w.t(), out=o) + + fused_anything_and_reducescatter( + my_matmul, + scattered_outputs, + group=process_group, + timeout_s=timeout_s, + _wait=_wait, + _memcpy=_memcpy, + ) def fused_anything_and_reducescatter( - my_matmul: Callable[[torch.Tensor, int, Callable[[], torch.cuda.Stream]], None], - scattered_output: torch.Tensor, + my_matmul: Callable[ + [List[torch.Tensor], int, Callable[[], torch.cuda.Stream]], None + ], + scattered_outputs: List[torch.Tensor], *, group: dist.ProcessGroup, timeout_s: int = 60 * 60, @@ -358,37 +804,48 @@ def fused_anything_and_reducescatter( ) -> None: world_size = group.size() - assert scattered_output.is_contiguous() + if len(scattered_outputs) == 0: + for dst_rank in range(world_size): + my_matmul([], dst_rank, _default_stream_factory) + return + + assert all(so.is_contiguous() for so in scattered_outputs) + assert all(so.device == scattered_outputs[0].device for so in scattered_outputs) + assert all(so.dtype == scattered_outputs[0].dtype for so in scattered_outputs) + + gathered_output_shapes = [(world_size,) + so.shape for so in scattered_outputs] - gathered_output_shape = (world_size,) + scattered_output.shape + obj = _lazy_init(scattered_outputs[0].device, group) if world_size == 1: - my_matmul(scattered_output, 0, _default_stream_factory) + my_matmul(scattered_outputs, 0, _default_stream_factory) # Fallback - elif _should_use_fallback(group): - gathered_output = scattered_output.new_empty(gathered_output_shape) + elif obj is None: + gathered_outputs = [ + so.new_empty(gos) + for so, gos in zip(scattered_outputs, 
gathered_output_shapes)
+        ]
         for dst_rank in range(world_size):
             my_matmul(
-                gathered_output[dst_rank],
+                [go[dst_rank] for go in gathered_outputs],
                 dst_rank,
                 _default_stream_factory,
             )
-        dist.reduce_scatter_tensor(
-            output=scattered_output, input=gathered_output, group=group
-        )
+        for go, so in zip(gathered_outputs, scattered_outputs):
+            dist.reduce_scatter_tensor(output=so, input=go, group=group)
     # Fast path
     else:
-
-        def my_wrapper(rank, t):
-            my_matmul(t.squeeze(0), rank, _default_stream_factory)
-
-        gathered_output = scattered_output.new_empty(gathered_output_shape)
-        _pipelined_produce_and_all2all(
-            my_wrapper,
-            gathered_output,
-            group.group_name,
+        assert scattered_outputs[0].device == obj.my_device
+        gathered_outputs = [
+            scattered_outputs[0].new_empty(gos) for gos in gathered_output_shapes
+        ]
+        obj.linear_and_reducescatter(
+            my_matmul,
+            gathered_outputs,
+            scattered_outputs,
+            timeout_s=timeout_s,
+            _wait=private_args_DO_NOT_USE.get("_wait", True),
+            _memcpy=private_args_DO_NOT_USE.get("_memcpy", True),
         )
-
-        torch.sum(gathered_output, dim=0, out=scattered_output)
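In the same spirit, the unfused equivalent of fused_allgather_and_linear is an all-gather of the sharded input followed by one matmul per weight, as in the fallback branch visible in the old code above. A minimal sketch; names and shapes are illustrative.

from typing import List

import torch
import torch.distributed as dist


def reference_allgather_and_linear(
    scattered_input: torch.Tensor,  # [M, K], one shard per rank
    weights: List[torch.Tensor],  # each [N_i, K]
    group: dist.ProcessGroup,
) -> List[torch.Tensor]:
    world_size = group.size()
    gathered_input = scattered_input.new_empty(
        (world_size,) + scattered_input.shape
    ).flatten(0, 1)  # [world_size * M, K]
    dist.all_gather_into_tensor(
        output_tensor=gathered_input, input_tensor=scattered_input, group=group
    )
    return [torch.matmul(gathered_input, w.t()) for w in weights]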
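The generic entry points fused_allgather_and_anything and fused_anything_and_reducescatter now pass the callback a list of per-rank tensors (one entry per scattered or gathered operand), the peer rank, and a stream factory, and invoke it once per rank. The sketch below shows a caller-provided callback with that signature, reproducing a single all-gather plus matmul; it assumes the module path xformers.ops.sequence_parallel_fused_ops from this patch, and run() with its shapes is illustrative.

from typing import Callable, List

import torch
import torch.distributed as dist

from xformers.ops.sequence_parallel_fused_ops import fused_allgather_and_anything


def run(
    scattered_input: torch.Tensor,  # [M, K] shard owned by this rank
    weight: torch.Tensor,  # [N, K]
    group: dist.ProcessGroup,
) -> torch.Tensor:
    world_size = group.size()
    M, K = scattered_input.shape
    N = weight.shape[0]
    gathered_output = scattered_input.new_empty((world_size, M, N))

    def my_matmul(
        inputs: List[torch.Tensor],
        src_rank: int,
        stream_factory: Callable[[], torch.cuda.Stream],
    ) -> None:
        # One entry per tensor passed to fused_allgather_and_anything below.
        (shard,) = inputs
        # Launch on a stream from the factory so this matmul can overlap with
        # the communication of the next shard.
        with torch.cuda.stream(stream_factory()):
            torch.matmul(shard, weight.t(), out=gathered_output[src_rank])

    fused_allgather_and_anything([scattered_input], my_matmul, group=group)
    # The helper streams are joined back into the current stream before return.
    return gathered_output.flatten(0, 1)  # [world_size * M, N]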
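The _FusedSequenceParallel docstring above describes the synchronization scheme: staging buffers shared through a symmetric-memory workspace, plus signal/wait counters that one rank bumps and another busy-waits on. The sketch below illustrates that handshake for a one-shot shard exchange, using the same private torch.distributed._symmetric_memory calls that appear in this patch (get_symm_mem_workspace, get_buffer, put_signal, wait_signal). Treat it as pseudocode for the pattern rather than a supported recipe; among other things, a real implementation also needs the reverse "buffer consumed" signal (the OP_FINISHED channel above) before the staging buffers can be reused.

import torch
import torch.distributed as dist
from torch.distributed._symmetric_memory import get_symm_mem_workspace

COMMS_READY_CHANNEL = 1


def exchange_shards(shard: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    """All-gather a CUDA tensor `shard` by writing into peers' staging buffers."""
    rank, world_size = group.rank(), group.size()
    symm_mem = get_symm_mem_workspace(
        group.group_name, world_size * shard.numel() * shard.dtype.itemsize
    )
    # Each rank's workspace holds one slot per source rank; every other rank
    # can address it as if it were local memory.
    local_buffer = symm_mem.get_buffer(rank, [world_size, shard.numel()], shard.dtype)

    # Write our shard into every peer's staging buffer, then signal readiness.
    for peer in range(world_size):
        if peer == rank:
            continue
        peer_buffer = symm_mem.get_buffer(
            peer, [world_size, shard.numel()], shard.dtype
        )
        peer_buffer[rank].copy_(shard.flatten())
        symm_mem.put_signal(peer, COMMS_READY_CHANNEL)

    # Wait for each peer's signal before reading the slot it wrote for us.
    gathered = [torch.empty_like(shard) for _ in range(world_size)]
    gathered[rank].copy_(shard)
    for peer in range(world_size):
        if peer == rank:
            continue
        symm_mem.wait_signal(peer, COMMS_READY_CHANNEL)
        gathered[peer].copy_(local_buffer[peer].view_as(shard))
    return torch.stack(gathered)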