Skip to content
This repository has been archived by the owner on Jul 31, 2024. It is now read-only.

Commit

Permalink
Add timeout to dist.init()
Browse files: browse the repository at this point in the history
  • Loading branch information
Zhijian Liu committed May 7, 2021
1 parent 84afd32 commit fe96244
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions torchpack/distributed/context.py
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,18 @@
import os
from datetime import timedelta

import torch.distributed
from torch.distributed.constants import default_pg_timeout

# Names exported when this module is star-imported.
__all__ = ['init', 'size', 'rank', 'local_size', 'local_rank', 'is_master']

# Module-level world size and rank, starting at 1 and 0; init() passes these
# to torch.distributed.init_process_group as world_size and rank.
_world_size, _world_rank = 1, 0
# Module-level local size and rank, starting at 1 and 0.
# NOTE(review): presumably assigned in the collapsed part of init() - confirm.
_local_size, _local_rank = 1, 0


def init(backend: int = 'nccl') -> None:
from mpi4py import MPI # type: ignore
def init(backend: int = 'nccl',
timeout: timedelta = default_pg_timeout) -> None:
from mpi4py import MPI
world_comm = MPI.COMM_WORLD
local_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)

Expand All @@ -20,6 +23,7 @@ def init(backend: int = 'nccl') -> None:
master_host = 'tcp://' + os.environ['MASTER_HOST']
torch.distributed.init_process_group(backend=backend,
init_method=master_host,
timeout=timeout,
world_size=_world_size,
rank=_world_rank)

Expand Down

0 comments on commit fe96244

Please sign in to comment.