diff --git a/src/accelerate/state.py b/src/accelerate/state.py
index 147a9521640..6f0a06162af 100644
--- a/src/accelerate/state.py
+++ b/src/accelerate/state.py
@@ -730,7 +730,7 @@ def _prepare_backend(
     elif is_npu_available():
         backend = "hccl"
         distributed_type = DistributedType.MULTI_NPU
-    if (
+    if backend is None and (
         int(os.environ.get("LOCAL_RANK", -1)) != -1
         or get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1
     ):
@@ -738,20 +738,19 @@ def _prepare_backend(
             distributed_type = DistributedType.MULTI_XPU
         else:
             distributed_type = DistributedType.MULTI_CPU
-        if backend is None or backend == "ccl":
-            if is_ccl_available() and (
-                get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU
-            ):
-                if get_ccl_version() >= "1.12":
-                    import oneccl_bindings_for_pytorch  # noqa: F401
-                else:
-                    import torch_ccl  # noqa: F401
-
-                backend = "ccl"
-            elif torch.distributed.is_mpi_available():
-                backend = "mpi"
+        if is_ccl_available() and (
+            get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU
+        ):
+            if get_ccl_version() >= "1.12":
+                import oneccl_bindings_for_pytorch  # noqa: F401
             else:
-                backend = "gloo"
+                import torch_ccl  # noqa: F401
+
+            backend = "ccl"
+        elif torch.distributed.is_mpi_available():
+            backend = "mpi"
+        else:
+            backend = "gloo"
     if distributed_type is None:
         distributed_type = DistributedType.NO
     return backend, distributed_type
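
For context, here is a minimal sketch of the control flow after this patch. The helper name `pick_backend`, the plain-string `distributed_type` values, and the boolean availability flags are hypothetical stand-ins: the real function is `_prepare_backend`, which returns a `DistributedType` enum and queries `is_ccl_available()` / `torch.distributed.is_mpi_available()`. The sketch only illustrates the two changes visible in the diff: the CPU/XPU block now runs only when no backend has been chosen yet, and the old inner `if backend is None or backend == "ccl":` guard is subsumed by the new outer `backend is None` check, so the ccl > mpi > gloo chain runs unconditionally inside the guarded block.

```python
import os


def pick_backend(backend=None, distributed_type=None,
                 ccl_available=False, mpi_available=False):
    """Hypothetical distillation of the patched selection logic."""
    launched = (
        int(os.environ.get("LOCAL_RANK", -1)) != -1
        or int(os.environ.get("WORLD_SIZE", 1)) > 1
    )
    # New outer guard: the CPU/XPU path is skipped entirely when an earlier
    # device check (or the caller) already picked a backend, so neither
    # `backend` nor `distributed_type` gets clobbered.
    if backend is None and launched:
        distributed_type = "MULTI_CPU"
        # Fallback order mirrors the new unconditional elif chain;
        # the real code also imports the oneCCL bindings before using "ccl".
        if ccl_available:
            backend = "ccl"
        elif mpi_available:
            backend = "mpi"
        else:
            backend = "gloo"
    if distributed_type is None:
        distributed_type = "NO"
    return backend, distributed_type


os.environ["LOCAL_RANK"] = "0"
# A backend chosen earlier (e.g. "hccl" from the NPU branch) is no longer
# overridden, and its distributed_type survives:
assert pick_backend("hccl", "MULTI_NPU") == ("hccl", "MULTI_NPU")
# With no prior choice, the ccl > mpi > gloo fallback applies as before:
assert pick_backend(None, mpi_available=True) == ("mpi", "MULTI_CPU")
assert pick_backend(None) == ("gloo", "MULTI_CPU")
```

Under these assumptions, the visible effect of the patch is in the first assertion: previously the CPU/XPU block ran whenever the launcher environment variables were set, overwriting `distributed_type` even when a device backend had already been selected.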