Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix is_interactive_compatible logic after AcceleratorConnector rewrite #12008

Merged
merged 9 commits into from
Feb 22, 2022
5 changes: 5 additions & 0 deletions pytorch_lightning/strategies/launchers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ class _Launcher(ABC):
cluster environment, hardware, strategy, etc.
"""

@property
@abstractmethod
def is_interactive_compatible(self) -> bool:
    """Returns whether this launcher can work in interactive environments such as Jupyter notebooks.

    This is an abstract property: every concrete launcher must override it and
    return a boolean.
    """

@abstractmethod
def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
    """Launches the processes.

    Abstract: each concrete launcher supplies its own implementation.

    Args:
        function: The callable handed to the launcher.
        *args: Positional arguments accepted alongside ``function``
            (presumably forwarded to it -- confirm in the concrete launchers).
        **kwargs: Keyword arguments accepted alongside ``function``
            (presumably forwarded to it -- confirm in the concrete launchers).
    """
4 changes: 4 additions & 0 deletions pytorch_lightning/strategies/launchers/spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class _SpawnLauncher(_Launcher):
def __init__(self, strategy: Strategy) -> None:
    """Create the launcher and remember the strategy it was given.

    Args:
        strategy: The strategy object, stored on the instance as ``_strategy``.
    """
    self._strategy = strategy

@property
def is_interactive_compatible(self) -> bool:
    """Report whether this launcher can be used in an interactive environment.

    The value is currently hard-coded to ``False``; the TODO on the return
    statement records the conditions it is meant to depend on.
    """
    return False # TODO: the return value should depend on 1) start_method 2) CUDA vs. CPU
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved

def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
"""Spawns processes that run the given function in parallel.

Expand Down
4 changes: 4 additions & 0 deletions pytorch_lightning/strategies/launchers/subprocess_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ class _SubprocessScriptLauncher(_Launcher):
num_nodes: The total number of nodes that participate in this process group.
"""

@property
def is_interactive_compatible(self) -> bool:
    """Report whether this launcher can be used in an interactive environment.

    Always ``False`` for this launcher.
    """
    return False

def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, num_nodes: int) -> None:
super().__init__()
self.cluster_environment = cluster_environment
Expand Down
4 changes: 4 additions & 0 deletions pytorch_lightning/strategies/launchers/xla_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class _XLASpawnLauncher(_SpawnLauncher):
- It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.
"""

@property
def is_interactive_compatible(self) -> bool:
    """Report whether this launcher can be used in an interactive environment.

    Always ``True`` for this launcher.
    """
    return True

def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
"""Spawns processes that run the given function in parallel.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -715,19 +715,16 @@ def _lazy_init_strategy(self) -> None:

from pytorch_lightning.utilities import _IS_INTERACTIVE

# TODO move is_compatible logic to strategy API
interactive_compatible_strategy = (
interactive_recomended_strategy = (
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
DataParallelStrategy.strategy_name,
DDPSpawnStrategy.strategy_name,
DDPSpawnShardedStrategy.strategy_name,
TPUSpawnStrategy.strategy_name,
)
if _IS_INTERACTIVE and self.strategy.strategy_name not in interactive_compatible_strategy:
if _IS_INTERACTIVE and self.strategy.launcher and not self.strategy.launcher.is_interactive_compatible:
raise MisconfigurationException(
f"`Trainer(strategy={self.strategy.strategy_name!r})` or"
f" `Trainer(accelerator={self.strategy.strategy_name!r})` is not compatible with an interactive"
" environment. Run your code as a script, or choose one of the compatible backends:"
f" {', '.join(interactive_compatible_strategy)}."
f" Trainer(strategy=None|{'|'.join(interactive_recomended_strategy)})."
" In case you are spawning processes yourself, make sure to include the Trainer"
" creation inside the worker function."
)
Expand Down
20 changes: 15 additions & 5 deletions tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import torch
import torch.distributed

import pytorch_lightning
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.cpu import CPUAccelerator
Expand Down Expand Up @@ -393,19 +394,28 @@ def test_dist_backend_accelerator_mapping(*_):
assert trainer.strategy.local_rank == 0


@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True)
@mock.patch("torch.cuda.device_count", return_value=2)
def test_ipython_incompatible_backend_error(*_):
def test_ipython_incompatible_backend_error(_, monkeypatch):
monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"):
Trainer(strategy="ddp", gpus=2)

with pytest.raises(MisconfigurationException, match=r"strategy='ddp2'\)`.*is not compatible"):
Trainer(strategy="ddp2", gpus=2)

with pytest.raises(MisconfigurationException, match=r"strategy='ddp_spawn'\)`.*is not compatible"):
Trainer(strategy="ddp_spawn")

@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True)
def test_ipython_compatible_backend(*_):
Trainer(strategy="ddp_spawn", num_processes=2)
with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"):
# Edge case: AcceleratorConnector maps dp to ddp if accelerator != gpu
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
Trainer(strategy="dp")


def test_ipython_compatible_backend(monkeypatch):
monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
Trainer()
Trainer(strategy="dp", accelerator="gpu")
Trainer(accelerator="tpu")
awaelchli marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize(["accelerator", "plugin"], [("ddp_spawn", "ddp_sharded"), (None, "ddp_sharded")])
Expand Down