Skip to content

Commit

Permalink
Revert "fix backend check (#2652)" (#2669)
Browse files Browse the repository at this point in the history
This reverts commit 2fc48c7.
  • Loading branch information
muellerzr authored Apr 15, 2024
1 parent 581a390 commit c470a13
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions src/accelerate/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,28 +730,27 @@ def _prepare_backend(
elif is_npu_available():
backend = "hccl"
distributed_type = DistributedType.MULTI_NPU
if (
if backend is None and (
int(os.environ.get("LOCAL_RANK", -1)) != -1
or get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1
):
if not cpu and is_xpu_available():
distributed_type = DistributedType.MULTI_XPU
else:
distributed_type = DistributedType.MULTI_CPU
if backend is None or backend == "ccl":
if is_ccl_available() and (
get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU
):
if get_ccl_version() >= "1.12":
import oneccl_bindings_for_pytorch # noqa: F401
else:
import torch_ccl # noqa: F401

backend = "ccl"
elif torch.distributed.is_mpi_available():
backend = "mpi"
if is_ccl_available() and (
get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU
):
if get_ccl_version() >= "1.12":
import oneccl_bindings_for_pytorch # noqa: F401
else:
backend = "gloo"
import torch_ccl # noqa: F401

backend = "ccl"
elif torch.distributed.is_mpi_available():
backend = "mpi"
else:
backend = "gloo"
if distributed_type is None:
distributed_type = DistributedType.NO
return backend, distributed_type
Expand Down

0 comments on commit c470a13

Please sign in to comment.