Revert accidental push of fused-seqpar PRs
They were not ready yet

__original_commit__ = fairinternal/xformers@d75afed
lw authored and xFormers Bot committed Feb 11, 2025
1 parent 6df18d9 commit d6b20b5
Showing 3 changed files with 690 additions and 178 deletions.
xformers/benchmarks/benchmark_sequence_parallel_fused.py (56 changes: 56 additions & 0 deletions)
@@ -142,6 +142,8 @@ def run_one_rank(
     torch.distributed.init_process_group(backend="nccl", init_method="env://")
 
     subgroup = torch.distributed.new_group()
+    subgroup_nowait = torch.distributed.new_group()
+    subgroup_nowait_nomemcpy = torch.distributed.new_group()
 
     scenario = SCENARIOS[scenario_name](world_size)
     if step is Step.AllGather:
@@ -235,6 +237,56 @@ def run_fused_rs():
             timeout_s=10,
         )
 
+    def run_fused_nowait_ag():
+        nonlocal gathered_outputs_fused
+        from xformers.ops import fused_allgather_and_linear
+
+        gathered_outputs_fused = fused_allgather_and_linear(
+            scattered_input,
+            [w.t() for w in weights],
+            group=subgroup_nowait,
+            _wait=False,
+            timeout_s=10,
+        )
+
+    def run_fused_nowait_rs():
+        nonlocal scattered_outputs_fused
+        from xformers.ops import fused_linear_and_reducescatter
+
+        scattered_outputs_fused = fused_linear_and_reducescatter(
+            gathered_input,
+            [w.t() for w in weights],
+            group=subgroup_nowait,
+            _wait=False,
+            timeout_s=10,
+        )
+
+    def run_fused_nowait_nomemcpy_ag():
+        nonlocal gathered_outputs_fused
+        from xformers.ops import fused_allgather_and_linear
+
+        gathered_outputs_fused = fused_allgather_and_linear(
+            scattered_input,
+            [w.t() for w in weights],
+            group=subgroup_nowait_nomemcpy,
+            _wait=False,
+            _memcpy=False,
+            timeout_s=10,
+        )
+
+    def run_fused_nowait_nomemcpy_rs():
+        nonlocal scattered_outputs_fused
+        from xformers.ops import fused_linear_and_reducescatter
+
+        scattered_outputs_fused = fused_linear_and_reducescatter(
+            gathered_input,
+            [w.t() for w in weights],
+            group=subgroup_nowait_nomemcpy,
+            _wait=False,
+            _memcpy=False,
+            timeout_s=10,
+        )
+
     print(f"Sizes: ({world_size}x{M // world_size})x({num_matrices}x{N})x{K}")
 
     if step is Step.AllGather:
@@ -302,6 +354,10 @@ def run_fused_rs():
         ),
         "nccl_reference": Bench(ag=run_nccl_reference_ag, rs=run_nccl_reference_rs),
         "fused": Bench(ag=run_fused_ag, rs=run_fused_rs),
+        "fused_nowait": Bench(ag=run_fused_nowait_ag, rs=run_fused_nowait_rs),
+        "fused_nowait_nomemcpy": Bench(
+            ag=run_fused_nowait_nomemcpy_ag, rs=run_fused_nowait_nomemcpy_rs
+        ),
     }
 
     unused_events = deque(
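For orientation, here is a minimal, hypothetical driver for the call pattern that the re-added "fused" benchmark variants exercise; it is not part of the commit. It assumes a single-node NCCL job launched via torchrun, illustrative tensor sizes (the benchmark takes its own from SCENARIOS), and the fused_allgather_and_linear signature visible in the diff above, including timeout_s; the underscore-prefixed _wait and _memcpy arguments are the internal debug knobs that the nowait/nomemcpy variants toggle on their own subgroups.

# Hedged sketch, not part of the commit. Run with:
#   torchrun --nproc-per-node=2 sketch.py
# Shapes, dtype, and the (out_features, in_features) weight layout are assumptions.
import torch
import torch.distributed as dist

from xformers.ops import fused_allgather_and_linear

def main() -> None:
    dist.init_process_group(backend="nccl", init_method="env://")
    world_size = dist.get_world_size()
    torch.cuda.set_device(dist.get_rank())
    subgroup = dist.new_group()

    M, N, K = 1024, 512, 256  # illustrative, not the benchmark's SCENARIOS
    scattered_input = torch.randn(
        M // world_size, K, device="cuda", dtype=torch.bfloat16
    )
    weights = [
        torch.randn(K, N, device="cuda", dtype=torch.bfloat16) for _ in range(2)
    ]

    # Same pattern as run_fused_ag above: all-gather the per-rank input shard
    # and multiply it by each (transposed) weight in one fused operation.
    gathered_outputs = fused_allgather_and_linear(
        scattered_input,
        [w.t() for w in weights],
        group=subgroup,
        timeout_s=10,
    )
    # The re-added nowait/nomemcpy variants additionally pass _wait=False
    # (and _memcpy=False), each on a dedicated subgroup, to isolate those paths.

    dist.destroy_process_group()

if __name__ == "__main__":
    main()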
xformers/ops/seqpar.py (27 changes: 13 additions & 14 deletions)
@@ -97,20 +97,21 @@ def sequence_parallel_leading_matmul_bwd(
     ]
 
     def my_si_matmul(
-        grad_gathered_input: torch.Tensor,
+        grad_gathered_inputs: List[torch.Tensor],
         dst_rank: int,
         stream_factory: Callable[[], torch.cuda.Stream],
     ) -> None:
+        (grad_gi,) = grad_gathered_inputs
         with torch.cuda.stream(stream_factory()):
             tiled_matmul_out(
                 [[grad_gos[dst_rank] for grad_gos in grad_gathered_outputss]],
                 [[w.t()] for w in weights],
-                out=[[grad_gathered_input]],
+                out=[[grad_gi]],
             )
 
     fused_anything_and_reducescatter(
         my_si_matmul,
-        grad_scattered_input,
+        [grad_scattered_input],
         group=process_group,
     )
 
@@ -120,20 +121,21 @@ def my_si_matmul(
     events = [torch.cuda.Event() for _ in weights]
 
     def my_w_matmul(
-        gathered_input_shard: torch.Tensor,
+        gathered_inputs_shard: List[torch.Tensor],
         src_rank: int,
         stream_factory: Callable[[], torch.cuda.Stream],
     ) -> None:
+        (gi_shard,) = gathered_inputs_shard
         for grad_gos, grad_w, event in zip(
             grad_gathered_outputss, grad_weights, events
         ):
             with torch.cuda.stream(stream_factory()):
                 event.wait()
-                grad_w.t().addmm_(grad_gos[src_rank].t(), gathered_input_shard)
+                grad_w.t().addmm_(grad_gos[src_rank].t(), gi_shard)
                 event.record()
 
     fused_allgather_and_anything(
-        scattered_input,
+        [scattered_input],
         my_w_matmul,
         group=process_group,
     )
@@ -280,23 +282,20 @@ def sequence_parallel_trailing_matmul_bwd(
     grad_gathered_inputs = grad_gathered_input.tensor_split(mp_size, dim=0)
 
     def my_gi_and_w_matmul(
-        grad_gathered_output_shard: torch.Tensor,
+        grad_gathered_outputs_shard: List[torch.Tensor],
         src_rank: int,
         stream_factory: Callable[[], torch.cuda.Stream],
     ) -> None:
+        (grad_go_shard,) = grad_gathered_outputs_shard
         with torch.cuda.stream(stream_factory()):
             torch.matmul(
-                grad_gathered_output_shard,
-                weight.t(),
-                out=grad_gathered_inputs[src_rank],
+                grad_go_shard, weight.t(), out=grad_gathered_inputs[src_rank]
             )
         with torch.cuda.stream(stream_factory()):
-            grad_weight.t().addmm_(
-                grad_gathered_output_shard.t(), gathered_inputs[src_rank]
-            )
+            grad_weight.t().addmm_(grad_go_shard.t(), gathered_inputs[src_rank])
 
     fused_allgather_and_anything(
-        grad_scattered_output,
+        [grad_scattered_output],
         my_gi_and_w_matmul,
         group=process_group,
     )
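All three seqpar.py hunks restore the same callback contract: fused_allgather_and_anything and fused_anything_and_reducescatter take their scattered/gathered tensors as lists, and the per-rank callback receives a list of shards rather than a bare tensor. The sketch below illustrates that contract; it is not part of the commit, the reduction it performs, the tensor sizes, and the import path are assumptions, and it exists only to show the List[torch.Tensor] signature being reinstated.

# Hedged sketch of the restored callback contract. Launch with torchrun, e.g.:
#   torchrun --nproc-per-node=2 sketch.py
from typing import Callable, List

import torch
import torch.distributed as dist

from xformers.ops import fused_allgather_and_anything  # assumed export path

def main() -> None:
    dist.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(dist.get_rank())
    group = dist.new_group()
    world_size = dist.get_world_size()

    scattered_input = torch.randn(256, 64, device="cuda")
    row_sums = [torch.empty(256, device="cuda") for _ in range(world_size)]

    def on_shard(
        gathered_inputs_shard: List[torch.Tensor],  # one entry per input tensor
        src_rank: int,
        stream_factory: Callable[[], torch.cuda.Stream],
    ) -> None:
        # With a single input, the callback unpacks a one-element list,
        # exactly like my_w_matmul / my_gi_and_w_matmul in the hunks above.
        (shard,) = gathered_inputs_shard
        with torch.cuda.stream(stream_factory()):
            torch.sum(shard, dim=1, out=row_sums[src_rank])  # stand-in work

    # Inputs go in as a list too, matching [scattered_input] above; the
    # callback runs once per source rank as that rank's contribution arrives.
    fused_allgather_and_anything(
        [scattered_input],
        on_shard,
        group=group,
    )

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

The list form presumably exists so a single fused collective can carry several tensors at once; in these hunks each call passes exactly one tensor wrapped in a list.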
(Diff for the third changed file was not loaded.)
