Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Support multi-node distributed training with NPU backend #1459

Merged
merged 1 commit into from
Dec 26, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions mmengine/dist/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,10 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
**kwargs: keyword arguments are passed to ``init_process_group``.
"""
rank = int(os.environ['RANK'])
# LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
local_rank = int(os.environ['LOCAL_RANK'])
if is_mlu_available():
import torch_mlu # noqa: F401
local_rank = int(os.environ['LOCAL_RANK'])
torch.mlu.set_device(local_rank)
torch_dist.init_process_group(
backend='cncl',
Expand All @@ -110,15 +111,13 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
**kwargs)
elif is_npu_available():
import torch_npu # noqa: F401
torch.npu.set_device(rank)
torch.npu.set_device(local_rank)
torch_dist.init_process_group(
backend='hccl',
rank=rank,
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
else:
# LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

if init_backend == 'torch':
Expand Down