Fix is_interactive_compatible logic after AcceleratorConnector rewrite (#12008)

* fix is_interactive_compatible

* improve tests

* update message

* address review

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
awaelchli and pre-commit-ci[bot] authored Feb 22, 2022
1 parent 4c4b9d5 commit d0f5460
Showing 7 changed files with 42 additions and 17 deletions.
5 changes: 5 additions & 0 deletions pytorch_lightning/strategies/launchers/base.py
@@ -26,6 +26,11 @@ class _Launcher(ABC):
     cluster environment, hardware, strategy, etc.
     """
 
+    @property
+    @abstractmethod
+    def is_interactive_compatible(self) -> bool:
+        """Returns whether this launcher can work in interactive environments such as Jupyter notebooks."""
+
     @abstractmethod
     def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
         """Launches the processes."""
7 changes: 7 additions & 0 deletions pytorch_lightning/strategies/launchers/spawn.py
@@ -49,6 +49,13 @@ def __init__(self, strategy: Strategy) -> None:
         self._strategy = strategy
         self._start_method = "spawn"
 
+    @property
+    def is_interactive_compatible(self) -> bool:
+        # The start method 'spawn' is currently the only one that works with DDP and CUDA support
+        # The start method 'fork' is the only one supported in Jupyter environments but not compatible with CUDA
+        # For more context, see https://github.com/PyTorchLightning/pytorch-lightning/issues/7550
+        return self._start_method == "fork" and self._strategy.root_device.type != "cuda"
+
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Spawns processes that run the given function in parallel.
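A quick sketch of how this property evaluates, assuming only _start_method and _strategy.root_device feed the check (the mock stands in for a real strategy, and poking the private attribute is purely illustrative):

from unittest import mock

from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher

strategy = mock.Mock()
strategy.root_device.type = "cpu"
launcher = _SpawnLauncher(strategy)  # __init__ above sets _start_method = "spawn"
assert not launcher.is_interactive_compatible  # 'spawn' re-imports __main__, which notebooks cannot provide

launcher._start_method = "fork"
assert launcher.is_interactive_compatible  # 'fork' + CPU is notebook-safe

strategy.root_device.type = "cuda"
assert not launcher.is_interactive_compatible  # CUDA cannot be re-initialized after fork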
4 changes: 4 additions & 0 deletions pytorch_lightning/strategies/launchers/subprocess_script.py
@@ -68,6 +68,10 @@ class _SubprocessScriptLauncher(_Launcher):
         num_nodes: The total number of nodes that participate in this process group.
     """
 
+    @property
+    def is_interactive_compatible(self) -> bool:
+        return False
+
     def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, num_nodes: int) -> None:
         super().__init__()
         self.cluster_environment = cluster_environment
4 changes: 4 additions & 0 deletions pytorch_lightning/strategies/launchers/xla_spawn.py
@@ -54,6 +54,10 @@ def __init__(self, strategy: "Strategy") -> None:
         super().__init__(strategy)
         self._start_method = "fork"
 
+    @property
+    def is_interactive_compatible(self) -> bool:
+        return True
+
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Spawns processes that run the given function in parallel.
14 changes: 4 additions & 10 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -65,6 +65,7 @@
     TPUSpawnStrategy,
 )
 from pytorch_lightning.utilities import (
+    _StrategyType,
     AMPType,
     device_parser,
     LightningEnum,
@@ -734,19 +735,12 @@ def _lazy_init_strategy(self) -> None:
 
         from pytorch_lightning.utilities import _IS_INTERACTIVE
 
-        # TODO move is_compatible logic to strategy API
-        interactive_compatible_strategy = (
-            DataParallelStrategy.strategy_name,
-            DDPSpawnStrategy.strategy_name,
-            DDPSpawnShardedStrategy.strategy_name,
-            TPUSpawnStrategy.strategy_name,
-        )
-        if _IS_INTERACTIVE and self.strategy.strategy_name not in interactive_compatible_strategy:
+        if _IS_INTERACTIVE and self.strategy.launcher and not self.strategy.launcher.is_interactive_compatible:
            raise MisconfigurationException(
                 f"`Trainer(strategy={self.strategy.strategy_name!r})` or"
                 f" `Trainer(accelerator={self.strategy.strategy_name!r})` is not compatible with an interactive"
-                " environment. Run your code as a script, or choose one of the compatible backends:"
-                f" {', '.join(interactive_compatible_strategy)}."
+                " environment. Run your code as a script, or choose one of the compatible strategies:"
+                f" Trainer(strategy=None|{'|'.join(_StrategyType.interactive_compatible_types())})."
                 " In case you are spawning processes yourself, make sure to include the Trainer"
                 " creation inside the worker function."
             )
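The user-facing effect, sketched under the assumption of an interactive session (_IS_INTERACTIVE is True) with two visible GPUs; note that strategies without a launcher (e.g. single-device) now pass the check implicitly, since `self.strategy.launcher` is None:

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    Trainer(strategy="ddp", gpus=2)  # subprocess script launcher, never notebook-safe
except MisconfigurationException as err:
    print(err)  # "... is not compatible with an interactive environment ..."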
2 changes: 0 additions & 2 deletions pytorch_lightning/utilities/enums.py
@@ -254,8 +254,6 @@ def interactive_compatible_types() -> list[_StrategyType]:
         """Returns a list containing interactive compatible _StrategyTypes."""
         return [
             _StrategyType.DP,
-            _StrategyType.DDP_SPAWN,
-            _StrategyType.DDP_SHARDED_SPAWN,
             _StrategyType.TPU_SPAWN,
         ]
 
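Because _StrategyType is a str-based enum, the error message in the connector can join the remaining members directly; a small check of what it now renders (assuming this commit's state of the list):

from pytorch_lightning.utilities import _StrategyType

print("|".join(_StrategyType.interactive_compatible_types()))  # dp|tpu_spawn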
23 changes: 18 additions & 5 deletions tests/accelerators/test_accelerator_connector.py
@@ -20,6 +20,7 @@
 import torch
 import torch.distributed
 
+import pytorch_lightning
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
@@ -392,19 +393,31 @@ def test_dist_backend_accelerator_mapping(*_):
     assert trainer.strategy.local_rank == 0
 
 
-@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True)
 @mock.patch("torch.cuda.device_count", return_value=2)
-def test_ipython_incompatible_backend_error(*_):
+def test_ipython_incompatible_backend_error(_, monkeypatch):
+    monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
     with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"):
         Trainer(strategy="ddp", gpus=2)
 
     with pytest.raises(MisconfigurationException, match=r"strategy='ddp2'\)`.*is not compatible"):
         Trainer(strategy="ddp2", gpus=2)
 
+    with pytest.raises(MisconfigurationException, match=r"strategy='ddp_spawn'\)`.*is not compatible"):
+        Trainer(strategy="ddp_spawn")
 
-@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True)
-def test_ipython_compatible_backend(*_):
-    Trainer(strategy="ddp_spawn", num_processes=2)
+    with pytest.raises(MisconfigurationException, match=r"strategy='ddp_sharded_spawn'\)`.*is not compatible"):
+        Trainer(strategy="ddp_sharded_spawn")
+
+    with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"):
+        # Edge case: AcceleratorConnector maps dp to ddp if accelerator != gpu
+        Trainer(strategy="dp")
+
+
+@pytest.mark.parametrize("trainer_kwargs", [{}, dict(strategy="dp", accelerator="gpu"), dict(accelerator="tpu")])
+def test_ipython_compatible_backend(trainer_kwargs, monkeypatch):
+    monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
+    trainer = Trainer(**trainer_kwargs)
+    assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible
 
 
 @pytest.mark.parametrize(["accelerator", "plugin"], [("ddp_spawn", "ddp_sharded"), (None, "ddp_sharded")])
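Why the tests switch from a mock.patch decorator to monkeypatch (a sketch, not part of the diff): patching a module-level constant with return_value=True replaces it with a MagicMock, which is truthy regardless of the intended value, so the old decorators only passed by accident:

from unittest import mock

import pytorch_lightning.utilities as utilities

with mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True):
    print(type(utilities._IS_INTERACTIVE))  # <class 'unittest.mock.MagicMock'>, not bool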
