Commit
Replace torch.distributed.get_world_size(pg) with pg.size()
This allows us to invoke it even if torch.distributed wasn't initialized yet.

ghstack-source-id: f2eb1c3a32e3f572cb97585c58eb30383d5f86c9
Pull Request resolved: fairinternal/xformers#1006

__original_commit__ = fairinternal/xformers@bd516e6
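
To illustrate the point of the change (a hedged sketch, not part of the commit): ProcessGroup.size() and ProcessGroup.rank() are methods on the group object itself, so code that only queries the group this way can be exercised with a lightweight stand-in group, for example in unit tests, without calling torch.distributed.init_process_group() first. The helper and stub below (all_reduce_if_needed, _SingleRankGroup) are hypothetical names used only for illustration.

import torch
import torch.distributed as dist

def all_reduce_if_needed(x, process_group):
    # Only asks the group object for its size; no global distributed state is
    # touched unless a real collective is actually needed.
    if process_group.size() == 1:
        return x
    dist.all_reduce(x, group=process_group)
    return x

class _SingleRankGroup:
    # Hypothetical stand-in exposing the same .size()/.rank() accessors as a
    # torch.distributed.ProcessGroup.
    def size(self):
        return 1
    def rank(self):
        return 0

out = all_reduce_if_needed(torch.ones(4), _SingleRankGroup())  # runs without init_process_group()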
lw authored and xFormers Bot committed Jan 26, 2024
1 parent 94582e8 commit e5e812a
Showing 4 changed files with 18 additions and 18 deletions.
6 changes: 3 additions & 3 deletions xformers/ops/differentiable_collectives.py
@@ -15,7 +15,7 @@ def all_reduce(
) -> None:
assert x.is_contiguous()

-mp_size = torch.distributed.get_world_size(process_group)
+mp_size = process_group.size()
if mp_size == 1:
return

@@ -28,7 +28,7 @@ def gather_along_first_dim_async(
input_: torch.Tensor, *, process_group: torch.distributed.ProcessGroup
) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]:
assert input_.is_contiguous()
-mp_size = torch.distributed.get_world_size(process_group)
+mp_size = process_group.size()
if mp_size == 1:
return input_, None

@@ -47,7 +47,7 @@ def reduce_scatter_along_first_dim_async(
input_: torch.Tensor, *, process_group: torch.distributed.ProcessGroup
) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]:
assert input_.is_contiguous()
-mp_size = torch.distributed.get_world_size(process_group)
+mp_size = process_group.size()
if mp_size == 1:
return input_, None

8 changes: 4 additions & 4 deletions xformers/ops/modpar_layers.py
@@ -25,8 +25,8 @@ def _init_2d_weight(
# The reason we initialize the full unpartitioned/gathered weight is so that
# different ranks get different initial values and thus "break the symmetry"
# and in order to achieve the same init for any value of model parallelism.
-rank = torch.distributed.get_rank(process_group)
-world_size = torch.distributed.get_world_size(process_group)
+rank = process_group.rank()
+world_size = process_group.size()

nrows, ncols = weight.shape
if partition_dim == 0:
@@ -77,7 +77,7 @@ def __init__(
self.sequence_parallel = sequence_parallel
self.fuse_sequence_parallel = fuse_sequence_parallel
self.process_group = process_group
-mp_size = torch.distributed.get_world_size(process_group)
+mp_size = process_group.size()
assert all(dim % mp_size == 0 for dim in out_features)
self.my_out_features = [dim // mp_size for dim in out_features]

@@ -136,7 +136,7 @@ def __init__(
self.sequence_parallel = sequence_parallel
self.fuse_sequence_parallel = fuse_sequence_parallel
self.process_group = process_group
-mp_size = torch.distributed.get_world_size(process_group)
+mp_size = process_group.size()
assert in_features % mp_size == 0
self.my_in_features = in_features // mp_size

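
The comment in _init_2d_weight above describes the initialization scheme: every rank materializes the same full, unpartitioned weight and then keeps only its own slice, so the values differ across ranks (symmetry is broken along the partition dimension) while the overall init is identical for any model-parallel world size. A minimal sketch of that idea, assuming a shared seed and a shard tensor whose shape is already the partitioned one (illustrative only, not the xformers implementation):

import torch

def init_partitioned_weight(shard, full_shape, partition_dim, process_group, seed=0):
    # Every rank draws the same full weight from the shared seed...
    gen = torch.Generator().manual_seed(seed)
    full = torch.randn(full_shape, generator=gen)
    # ...then keeps only the slice corresponding to its rank.
    rank = process_group.rank()          # the accessors this commit switches to
    world_size = process_group.size()
    shard_len = full_shape[partition_dim] // world_size
    my_slice = full.narrow(partition_dim, rank * shard_len, shard_len)
    with torch.no_grad():
        shard.copy_(my_slice)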
4 changes: 2 additions & 2 deletions xformers/ops/seqpar.py
@@ -54,7 +54,7 @@ def sequence_parallel_leading_matmul_bwd(
fuse: bool,
process_group: torch.distributed.ProcessGroup,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-mp_size = torch.distributed.get_world_size(process_group)
+mp_size = process_group.size()

if fuse:
grad_scattered_input = torch.empty_like(scattered_input)
@@ -206,7 +206,7 @@ def sequence_parallel_trailing_matmul_bwd(
fuse: bool,
process_group: torch.distributed.ProcessGroup,
) -> Tuple[torch.Tensor, torch.Tensor]:
-mp_size = torch.distributed.get_world_size(process_group)
+mp_size = process_group.size()

if fuse:
grad_gathered_input = torch.empty_like(gathered_input)
18 changes: 9 additions & 9 deletions xformers/ops/sequence_parallel_fused_ops.py
@@ -104,7 +104,7 @@ def _exchange_addresses(
group: dist.ProcessGroup,
device: torch.device,
) -> List[List[str]]:
-world_size = dist.get_world_size(group=group)
+world_size = group.size()
my_addresses: List[str] = []
for listener in listeners:
addr = listener.address
@@ -179,8 +179,8 @@ def __init__(
):
self.my_device = device
self.dtype = dtype
-self.my_rank = dist.get_rank(group=group)
-self.world_size = dist.get_world_size(group=group)
+self.my_rank = group.rank()
+self.world_size = group.size()
self.num_stripes = num_stripes
self.my_device_capability = torch.cuda.get_device_capability(self.my_device)

@@ -670,13 +670,13 @@ def _can_ranks_communicate_all_to_all_over_nvlink(group: dist.ProcessGroup) -> bool
# visible? maybe just trying to exchange IPC handles and catching errors
# would work? note that in any case some ranks might succeed while some
# might fail so we need a barrier to have them all make the same decision)
-return dist.get_world_size(group=group) <= 8
+return group.size() <= 8


def _lazy_init(
device: torch.device, dtype: torch.dtype, group: dist.ProcessGroup, num_stripes: int
) -> Optional[_FusedSequenceParallel]:
-world_size = dist.get_world_size(group=group)
+world_size = group.size()
try:
obj = CACHE[id(group)]
except KeyError:
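
Regarding the comment above about some ranks possibly succeeding while others fail: one common way to have every rank make the same decision is to all-reduce a local yes/no flag with MIN, which makes the group-wide answer "yes" only if every rank said yes and doubles as a barrier. A hedged sketch of that pattern (not what the code currently does; today it simply checks group.size() <= 8):

import torch
import torch.distributed as dist

def agree_on_flag(local_ok, group, device):
    # 1 if this rank thinks NVLink all-to-all works, 0 otherwise.
    flag = torch.tensor([1 if local_ok else 0], device=device, dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.MIN, group=group)
    return bool(flag.item())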
@@ -766,7 +766,7 @@ def fused_allgather_and_linear(
memory for speed. This can be controlled using the num_stripes argument.
"""
-world_size = dist.get_world_size(group=group)
+world_size = group.size()
weights = weight if isinstance(weight, list) else [weight]
assert all(w.ndim == 2 for w in weights)
assert scattered_input.ndim >= 2
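
For context on what this fused op replaces: the unfused equivalent is an all-gather of the scattered input followed by an ordinary matmul against each weight. A rough sketch of that baseline using plain torch.distributed calls (the shapes and the exact xformers signature are assumptions here, so treat this as illustrative):

import torch
import torch.distributed as dist

def unfused_allgather_and_linear(scattered_input, weight, group):
    # Gather every rank's shard along the first dimension...
    world_size = group.size()
    gathered = torch.empty(
        (world_size * scattered_input.shape[0],) + tuple(scattered_input.shape[1:]),
        dtype=scattered_input.dtype,
        device=scattered_input.device,
    )
    dist.all_gather_into_tensor(gathered, scattered_input, group=group)
    # ...then apply the linear layer(s) to the gathered activations.
    weights = weight if isinstance(weight, list) else [weight]
    outputs = [gathered @ w.t() for w in weights]
    return outputs if isinstance(weight, list) else outputs[0]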
@@ -828,7 +828,7 @@ def fused_allgather_and_anything(
timeout_s: int = 60 * 60,
**private_args_DO_NOT_USE,
) -> None:
-world_size = dist.get_world_size(group=group)
+world_size = group.size()

if len(scattered_inputs) == 0:
for src_rank in range(world_size):
@@ -928,7 +928,7 @@ def fused_linear_and_reducescatter(
dist.reduce_scatter_tensor(scattered_output, gathered_output, group=group)
"""
-world_size = dist.get_world_size(group=group)
+world_size = group.size()
weights = weight if isinstance(weight, list) else [weight]
assert all(w.ndim == 2 for w in weights)
assert gathered_input.ndim >= 2
@@ -996,7 +996,7 @@ def fused_anything_and_reducescatter(
timeout_s: int = 60 * 60,
**private_args_DO_NOT_USE,
) -> None:
-world_size = dist.get_world_size(group=group)
+world_size = group.size()

if len(scattered_outputs) == 0:
for dst_rank in range(world_size):
