Add support for specifying process group backend to DDP strategy
ananthsub committed Feb 4, 2022
1 parent 8c07d8b commit dd12d27
Showing 4 changed files with 92 additions and 17 deletions.
24 changes: 17 additions & 7 deletions pytorch_lightning/strategies/ddp.py
@@ -48,7 +48,12 @@
_TORCH_GREATER_EQUAL_1_10,
rank_zero_warn,
)
from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
from pytorch_lightning.utilities.distributed import (
_get_process_group_backend_from_env,
_revert_sync_batchnorm,
distributed_available,
get_default_process_group_backend_for_device,
)
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import (
init_dist_connection,
@@ -76,8 +81,9 @@
class DDPStrategy(ParallelStrategy):
"""Plugin for multi-process single-device training on one or multiple nodes.
The main process in each node spawns N-1 child processes via :func:`subprocess.Popen`, where N is the number of
devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.launch` launches processes.
If processes are not already created ahead of time, the main process in each node creates N-1 child processes via
:func:`subprocess.Popen`, where N is the number of devices (e.g. GPU) per node. It is very similar to how
:mod:`torch.distributed.launch` launches processes.
"""

distributed_backend = _StrategyType.DDP
@@ -93,6 +99,7 @@ def __init__(
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
model_averaging_period: Optional[int] = None,
pg_backend: Optional[str] = None,
**kwargs: Union[Any, Dict[str, Any]],
) -> None:
super().__init__(
@@ -114,6 +121,7 @@ def __init__(
self._pids: Optional[List[int]] = None
self._sync_dir: Optional[str] = None
self._rank_0_has_called_call_children_scripts: bool = False
self._pg_backend: Optional[str] = pg_backend
self.set_world_ranks()

@property
@@ -258,10 +266,12 @@ def setup_distributed(self):
# set warning rank
rank_zero_only.rank = self.global_rank

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
init_dist_connection(self.cluster_environment, self.torch_distributed_backend)
self._pg_backend = (
self._pg_backend
or _get_process_group_backend_from_env()
or get_default_process_group_backend_for_device(self.root_device)
)
init_dist_connection(self.cluster_environment, self._pg_backend)

def _check_can_spawn_children(self):
if self.local_rank != 0:
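
Note on usage: with the ddp.py changes above, the process group backend can be specified directly on the strategy rather than only through the PL_TORCH_DISTRIBUTED_BACKEND environment variable. A minimal sketch of how this might be wired up, assuming the Trainer and strategy API of this release (the Trainer call itself is not part of this commit):

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# An explicit pg_backend takes precedence over the PL_TORCH_DISTRIBUTED_BACKEND
# environment variable and over the device default ("nccl" on CUDA, "gloo"
# otherwise), per the resolution order in setup_distributed above.
trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy=DDPStrategy(pg_backend="gloo"),
)
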
27 changes: 18 additions & 9 deletions pytorch_lightning/strategies/parallel.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, List, Optional
@@ -25,7 +24,13 @@
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp
from pytorch_lightning.utilities.distributed import (
_get_process_group_backend_from_env,
all_gather_ddp_if_available,
get_default_process_group_backend_for_device,
ReduceOp,
)
from pytorch_lightning.utilities.warnings import rank_zero_deprecation


class ParallelStrategy(Strategy, ABC):
@@ -98,13 +103,6 @@ def reduce_boolean_decision(self, decision: bool) -> bool:
decision = bool(decision == self.world_size)
return decision

@property
def torch_distributed_backend(self):
torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND")
if torch_backend is None:
torch_backend = "nccl" if self.root_device.type == "cuda" else "gloo"
return torch_backend

@staticmethod
def configure_sync_batchnorm(model: "pl.LightningModule") -> "pl.LightningModule":
"""Add global batchnorm for a model spread across multiple GPUs and nodes.
Expand Down Expand Up @@ -136,3 +134,14 @@ def block_backward_sync(self):
def teardown(self) -> None:
self.cluster_environment.teardown()
super().teardown()

@property
def torch_distributed_backend(self) -> str:
"""Deprecated property."""
rank_zero_deprecation(
"ParallelStrategy.torch_distributed_backend was deprecated in v1.6 and will be removed in v1.8."
)
pg_backend = _get_process_group_backend_from_env()
if pg_backend:
return pg_backend
return get_default_process_group_backend_for_device(self.root_device)
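
The old torch_distributed_backend property is removed from its original spot and reintroduced above as a deprecated shim, so existing callers keep working while being nudged toward the new helpers. Its resolution logic can be reproduced with the two functions this commit adds to pytorch_lightning.utilities.distributed; a small equivalent sketch (not part of the diff):

import torch

from pytorch_lightning.utilities.distributed import (
    _get_process_group_backend_from_env,
    get_default_process_group_backend_for_device,
)

def resolve_backend(root_device: torch.device) -> str:
    # The environment variable wins if set; otherwise fall back to the device default.
    return _get_process_group_backend_from_env() or get_default_process_group_backend_for_device(root_device)

resolve_backend(torch.device("cpu"))   # "gloo" unless PL_TORCH_DISTRIBUTED_BACKEND is set
resolve_backend(torch.device("cuda"))  # "nccl" unless PL_TORCH_DISTRIBUTED_BACKEND is set
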
14 changes: 14 additions & 0 deletions pytorch_lightning/utilities/distributed.py
@@ -350,6 +350,20 @@ def tpu_distributed() -> bool:
return _TPU_AVAILABLE and xm.xrt_world_size() > 1


def get_default_process_group_backend_for_device(device: torch.device) -> str:
return "nccl" if device.type == "cuda" else "gloo"


def _get_process_group_backend_from_env() -> Optional[str]:
torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND")
# TODO: circular dependency if rank_zero_deprecation is called from here
# if torch_backend:
# Emit warning
# rank_zero_deprecation("Environment variable `PL_TORCH_DISTRIBUTED_BACKEND`"
# " was deprecated in v1.6 and will be removed in v1.8.")
return torch_backend


def init_dist_connection(
cluster_environment: "pl.plugins.environments.ClusterEnvironment",
torch_distributed_backend: str,
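
As the TODO in the diff above notes, emitting a deprecation warning for PL_TORCH_DISTRIBUTED_BACKEND from this module would create a circular import, so that call is left commented out for now. The behaviour of the two new helpers follows directly from the code above; a quick sketch of what they return (assuming the environment variable starts out unset):

import os

import torch

from pytorch_lightning.utilities.distributed import (
    _get_process_group_backend_from_env,
    get_default_process_group_backend_for_device,
)

# Device defaults: "nccl" for CUDA devices, "gloo" for everything else.
assert get_default_process_group_backend_for_device(torch.device("cuda")) == "nccl"
assert get_default_process_group_backend_for_device(torch.device("cpu")) == "gloo"

# The env helper simply reflects PL_TORCH_DISTRIBUTED_BACKEND, or None when unset.
assert _get_process_group_backend_from_env() is None
os.environ["PL_TORCH_DISTRIBUTED_BACKEND"] = "gloo"
assert _get_process_group_backend_from_env() == "gloo"
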
44 changes: 43 additions & 1 deletion tests/deprecated_api/test_remove_1-8.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.8.0."""

# from unittest import mock
from unittest.mock import Mock

import pytest
@@ -31,11 +33,12 @@
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.strategies import ParallelStrategy
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import DeviceType, DistributedType
from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
from pytorch_lightning.utilities.warnings import rank_zero_warn
from tests.helpers.boring_model import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator
@@ -403,3 +406,42 @@ def on_configure_sharded_model(self, trainer, model):
match="The `on_configure_sharded_model` callback hook was deprecated in v1.6 and will be removed in v1.8."
):
trainer.fit(model)


def test_v1_8_0_torch_distributed_backend_property():
class DummyParallel(ParallelStrategy):
def barrier(self):
return

def broadcast(self, obj, src: int = 0):
return obj

def model_to_device(self):
return

def reduce(
self,
tensor,
group,
reduce_op,
):
return torch.tensor(1)

@property
def root_device(self) -> torch.device:
return torch.device("cpu")

strategy = DummyParallel()
with pytest.deprecated_call(
match="ParallelStrategy.torch_distributed_backend was deprecated in v1.6 and will be removed in v1.8."
):
strategy.torch_distributed_backend


# @mock.patch.dict(os.environ, {"PL_TORCH_DISTRIBUTED_BACKEND": "foo"})
# def test_v1_8_0_torch_distributed_backend_env():
# with pytest.deprecated_call(
# match="Environment variable `PL_TORCH_DISTRIBUTED_BACKEND`"
# " was deprecated in v1.6 and will be removed in v1.8."
# ):
# _get_process_group_backend_from_env()
