Add strategy argument to Trainer (#8597)
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Oct 13, 2021
1 parent 28fc8d2 commit 05b15e6
Showing 9 changed files with 323 additions and 16 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))


- Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))


### Changed

- Module imports are now catching `ModuleNotFoundError` instead of `ImportError` ([#9867](https://github.com/PyTorchLightning/pytorch-lightning/pull/9867))
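
For reference, a minimal sketch of the new flag in use, mirroring the test parametrizations added further down in this commit (the aliases and plugin classes shown are taken from those tests):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPSpawnPlugin

# Select a training type plugin by its registered alias ...
trainer = Trainer(strategy="ddp_spawn", accelerator="cpu", devices=2)

# ... or pass a TrainingTypePlugin instance directly.
trainer = Trainer(strategy=DDPSpawnPlugin(), accelerator="cpu", devices=2)
```
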
107 changes: 95 additions & 12 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -94,6 +94,7 @@ def __init__(
ipus,
distributed_backend,
accelerator,
strategy: Optional[Union[str, TrainingTypePlugin]],
gpus,
gpu_ids,
num_nodes,
@@ -111,12 +112,9 @@ def __init__(
self._distrib_type = None
self._accelerator_type = None

if distributed_backend is not None:
rank_zero_deprecation(
f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5."
f" Use `Trainer(accelerator={distributed_backend})` instead."
)
distributed_backend = distributed_backend or accelerator
self.strategy = strategy.lower() if isinstance(strategy, str) else strategy
self.distributed_backend = distributed_backend or accelerator

self._init_deterministic(deterministic)

self.num_processes = num_processes
@@ -126,7 +124,6 @@ def __init__(
self.parallel_device_ids = gpu_ids
self.tpu_cores = tpu_cores
self.ipus = ipus
self.distributed_backend = distributed_backend
self.num_nodes = num_nodes
self.sync_batchnorm = sync_batchnorm
self.benchmark = benchmark
@@ -151,16 +148,23 @@ def __init__(

self.plugins = plugins

self._handle_accelerator_and_distributed_backend(distributed_backend, accelerator)

self._validate_accelerator_and_devices()

self._warn_if_devices_flag_ignored()

self.select_accelerator_type()
self.set_distributed_mode()

if self.strategy is not None:
self._set_training_type_plugin()
else:
self.set_distributed_mode()
self.configure_slurm_ddp()

self.handle_given_plugins()
self.update_device_type_if_ipu_plugin()
self.update_device_type_if_training_type_plugin_passed()

self._validate_accelerator_type()
self._set_devices_if_none()
@@ -228,11 +232,11 @@ def select_accelerator_type(self) -> None:
self._set_devices_to_cpu_num_processes()
self._accelerator_type = DeviceType.CPU

if self.distributed_backend in ["auto"] + list(DeviceType):
if self.distributed_backend in self.accelerator_types:
self.distributed_backend = None

def _validate_accelerator_and_devices(self) -> None:
if self.distributed_backend not in ["auto"] + list(DeviceType) and self.devices is not None:
if self.distributed_backend not in self.accelerator_types and self.devices is not None:
raise MisconfigurationException(
f"You passed `devices={self.devices}` but haven't specified"
" `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu')` for the devices mapping,"
@@ -285,9 +289,56 @@ def _set_devices_if_none(self) -> None:
elif self._accelerator_type == DeviceType.CPU:
self.devices = self.num_processes

def _handle_accelerator_and_distributed_backend(
self, distributed_backend: Optional[str], accelerator: Optional[Union[str, Accelerator]]
) -> None:
if distributed_backend is not None:
rank_zero_deprecation(
f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5."
f" Use `Trainer(strategy={distributed_backend})` instead."
)
if self.strategy is not None:
raise MisconfigurationException(
f"You have passed `Trainer(strategy={self.strategy})` but have"
f" also passed `Trainer(distributed_backend={distributed_backend})`."
f"HINT: Use just `Trainer(strategy={self.strategy})` instead."
)

if accelerator is not None and accelerator in list(DistributedType):
rank_zero_deprecation(
f"Passing {accelerator} `strategy` to the `accelerator` flag in Trainer has been deprecated"
f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={accelerator})` instead."
)
if self.strategy is not None:
raise MisconfigurationException(
f"You have passed `Trainer(strategy={self.strategy})` but have"
f" also passed `Trainer(accelerator={accelerator})`."
f"HINT: Use just `Trainer(strategy={self.strategy})` instead."
)

def _set_training_type_plugin(self) -> None:
if isinstance(self.strategy, str) and self.strategy in TrainingTypePluginsRegistry:
self._training_type_plugin = TrainingTypePluginsRegistry.get(self.strategy)
if isinstance(self.strategy, str):
self.set_distributed_mode(self.strategy)
elif isinstance(self.strategy, TrainingTypePlugin):
self._training_type_plugin = self.strategy

def handle_given_plugins(self) -> None:

training_type = None
for plug in self.plugins:
if self.strategy is not None and self._is_plugin_training_type(plug):
raise MisconfigurationException(
f"You have passed `Trainer(strategy={self.strategy})`"
f" and you can only specify one training type plugin, but you have passed {plug} as a plugin."
)
if self._is_plugin_training_type(plug):
rank_zero_deprecation(
f"Passing {plug} `strategy` to the `plugins` flag in Trainer has been deprecated"
f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plug})` instead."
)

training_type = self._training_type_plugin or None
checkpoint = None
precision = None
cluster_environment = None
@@ -350,6 +401,10 @@ def handle_given_plugins(self) -> None:
self._checkpoint_io = checkpoint
self._cluster_environment = cluster_environment or self.select_cluster_environment()

@property
def accelerator_types(self) -> List[str]:
return ["auto"] + list(DeviceType)

@property
def precision_plugin(self) -> PrecisionPlugin:
if self._precision_plugin is None:
@@ -540,9 +595,18 @@ def root_gpu(self) -> Optional[int]:
else None
)

@staticmethod
def _is_plugin_training_type(plugin: Union[str, TrainingTypePlugin]) -> bool:
if isinstance(plugin, str) and (plugin in TrainingTypePluginsRegistry or plugin in list(DistributedType)):
return True
return isinstance(plugin, TrainingTypePlugin)

@property
def is_training_type_in_plugins(self) -> bool:
return any(isinstance(plug, str) and plug in TrainingTypePluginsRegistry for plug in self.plugins)
return any(
(isinstance(plug, str) and plug in TrainingTypePluginsRegistry) or isinstance(plug, TrainingTypePlugin)
for plug in self.plugins
)

def select_precision_plugin(self) -> PrecisionPlugin:
# set precision type
@@ -875,6 +939,25 @@ def update_device_type_if_ipu_plugin(self) -> None:
if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != DeviceType.IPU:
self._device_type = DeviceType.IPU

def update_device_type_if_training_type_plugin_passed(self) -> None:
if isinstance(self.strategy, TrainingTypePlugin) or any(
isinstance(plug, TrainingTypePlugin) for plug in self.plugins
):
if self._accelerator_type is not None:
if self.use_ipu:
self._device_type = DeviceType.IPU
elif self.use_tpu:
self._device_type = DeviceType.TPU
elif self.use_gpu:
self._device_type = DeviceType.GPU
else:
if self.has_ipu:
self._device_type = DeviceType.IPU
elif self.has_tpu:
self._device_type = DeviceType.TPU
elif self.has_gpu:
self._device_type = DeviceType.GPU

def configure_slurm_ddp(self):
# extract SLURM flag vars
# whenever we have the correct number of tasks, we let slurm manage processes
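
The connector changes above make `strategy` mutually exclusive with the legacy ways of selecting a training type. A hedged sketch of the resulting behaviour, based on the error messages and deprecation calls in this diff:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Combining `strategy` with a strategy passed through a legacy flag is rejected.
try:
    Trainer(strategy="ddp_spawn", accelerator="ddp")
except MisconfigurationException as err:
    print(err)  # "You have passed `Trainer(strategy=ddp_spawn)` but have also passed ..."

# The legacy spellings still work on their own but emit a deprecation warning
# pointing at `Trainer(strategy=...)`, with removal targeted for v1.7.
Trainer(accelerator="ddp_spawn")  # deprecated, use Trainer(strategy="ddp_spawn")
Trainer(plugins="ddp_spawn")      # deprecated, use Trainer(strategy="ddp_spawn")
```
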
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -155,6 +155,7 @@ def __init__(
flush_logs_every_n_steps: Optional[int] = None,
log_every_n_steps: int = 50,
accelerator: Optional[Union[str, Accelerator]] = None,
strategy: Optional[Union[str, TrainingTypePlugin]] = None,
sync_batchnorm: bool = False,
precision: Union[int, str] = 32,
enable_model_summary: bool = True,
@@ -354,6 +355,9 @@ def __init__(
no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint,
training will start from the beginning of the next epoch.
strategy: Supports different training strategies with aliases
as well as custom training type plugins.
sync_batchnorm: Synchronize batch norm layers between process groups/whole world.
terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
@@ -424,6 +428,7 @@ def __init__(
ipus,
distributed_backend,
accelerator,
strategy,
gpus,
gpu_ids,
num_nodes,
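
With the updated signature, `strategy` accepts either a string alias or a `TrainingTypePlugin` instance. A small sketch of the two spellings; the assumption that `DDPPlugin(find_unused_parameters=False)` matches the `"ddp_find_unused_parameters_false"` alias follows from the plugin forwarding its keyword arguments to `DistributedDataParallel` and is not shown in this diff:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# Registered alias ...
trainer = Trainer(strategy="ddp_find_unused_parameters_false", accelerator="cpu", devices=2)

# ... or an explicitly configured plugin instance (assumed equivalent, see note above).
trainer = Trainer(strategy=DDPPlugin(find_unused_parameters=False), accelerator="cpu", devices=2)
```
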
75 changes: 74 additions & 1 deletion tests/accelerators/test_accelerator_connector.py
@@ -26,6 +26,7 @@
from pytorch_lightning.accelerators.gpu import GPUAccelerator
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import (
DataParallelPlugin,
DDP2Plugin,
DDPPlugin,
DDPShardedPlugin,
@@ -42,7 +43,7 @@
SLURMEnvironment,
TorchElasticEnvironment,
)
from pytorch_lightning.utilities import DistributedType
from pytorch_lightning.utilities import DeviceType, DistributedType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@@ -631,6 +632,78 @@ def test_accelerator_ddp_for_cpu(tmpdir):
assert isinstance(trainer.training_type_plugin, DDPPlugin)


def test_exception_when_strategy_used_with_distributed_backend():
with pytest.raises(MisconfigurationException, match="but have also passed"):
Trainer(distributed_backend="ddp_cpu", strategy="ddp_spawn")


def test_exception_when_strategy_used_with_accelerator():
with pytest.raises(MisconfigurationException, match="but have also passed"):
Trainer(accelerator="ddp", strategy="ddp_spawn")


def test_exception_when_strategy_used_with_plugins():
with pytest.raises(MisconfigurationException, match="only specify one training type plugin, but you have passed"):
Trainer(plugins="ddp_find_unused_parameters_false", strategy="ddp_spawn")


@pytest.mark.parametrize(
["strategy", "plugin"],
[
("ddp_spawn", DDPSpawnPlugin),
("ddp_spawn_find_unused_parameters_false", DDPSpawnPlugin),
("ddp", DDPPlugin),
("ddp_find_unused_parameters_false", DDPPlugin),
],
)
def test_strategy_choice_cpu_str(tmpdir, strategy, plugin):
trainer = Trainer(strategy=strategy, accelerator="cpu", devices=2)
assert isinstance(trainer.training_type_plugin, plugin)


@pytest.mark.parametrize("plugin", [DDPSpawnPlugin, DDPPlugin])
def test_strategy_choice_cpu_plugin(tmpdir, plugin):
trainer = Trainer(strategy=plugin(), accelerator="cpu", devices=2)
assert isinstance(trainer.training_type_plugin, plugin)


@RunIf(min_gpus=2)
@pytest.mark.parametrize(
["strategy", "plugin"],
[
("ddp_spawn", DDPSpawnPlugin),
("ddp_spawn_find_unused_parameters_false", DDPSpawnPlugin),
("ddp", DDPPlugin),
("ddp_find_unused_parameters_false", DDPPlugin),
("ddp2", DDP2Plugin),
("dp", DataParallelPlugin),
("ddp_sharded", DDPShardedPlugin),
("ddp_sharded_spawn", DDPSpawnShardedPlugin),
pytest.param("deepspeed", DeepSpeedPlugin, marks=RunIf(deepspeed=True)),
],
)
def test_strategy_choice_gpu_str(tmpdir, strategy, plugin):
trainer = Trainer(strategy=strategy, accelerator="gpu", devices=2)
assert isinstance(trainer.training_type_plugin, plugin)


@RunIf(min_gpus=2)
@pytest.mark.parametrize("plugin", [DDPSpawnPlugin, DDPPlugin])
def test_strategy_choice_gpu_plugin(tmpdir, plugin):
trainer = Trainer(strategy=plugin(), accelerator="gpu", devices=2)
assert isinstance(trainer.training_type_plugin, plugin)


@RunIf(min_gpus=2)
@pytest.mark.parametrize("plugin", [DDPSpawnPlugin, DDPPlugin])
def test_device_type_when_training_plugin_gpu_passed(tmpdir, plugin):

trainer = Trainer(strategy=plugin(), gpus=2)
assert isinstance(trainer.training_type_plugin, plugin)
assert trainer._device_type == DeviceType.GPU
assert isinstance(trainer.accelerator, GPUAccelerator)


@pytest.mark.parametrize("precision", [1, 12, "invalid"])
def test_validate_precision_type(tmpdir, precision):

19 changes: 17 additions & 2 deletions tests/accelerators/test_ipu.py
@@ -24,7 +24,7 @@
from pytorch_lightning.plugins import IPUPlugin, IPUPrecisionPlugin
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import _IPU_AVAILABLE
from pytorch_lightning.utilities import _IPU_AVAILABLE, DeviceType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.datamodules import ClassifDataModule
@@ -120,7 +120,7 @@ def test_warning_if_ipus_not_used(tmpdir):
@RunIf(ipu=True)
def test_no_warning_plugin(tmpdir):
with pytest.warns(None) as record:
Trainer(default_root_dir=tmpdir, plugins=IPUPlugin(training_opts=poptorch.Options()))
Trainer(default_root_dir=tmpdir, strategy=IPUPlugin(training_opts=poptorch.Options()))
assert len(record) == 0


@@ -528,3 +528,18 @@ def test_set_devices_if_none_ipu():

trainer = Trainer(accelerator="ipu", ipus=8)
assert trainer.devices == 8


@RunIf(ipu=True)
def test_strategy_choice_ipu_plugin(tmpdir):
trainer = Trainer(strategy=IPUPlugin(), accelerator="ipu", devices=8)
assert isinstance(trainer.training_type_plugin, IPUPlugin)


@RunIf(ipu=True)
def test_device_type_when_training_plugin_ipu_passed(tmpdir):

trainer = Trainer(strategy=IPUPlugin(), ipus=8)
assert isinstance(trainer.training_type_plugin, IPUPlugin)
assert trainer._device_type == DeviceType.IPU
assert isinstance(trainer.accelerator, IPUAccelerator)
13 changes: 13 additions & 0 deletions tests/accelerators/test_tpu_backend.py
@@ -227,6 +227,19 @@ def test_ddp_cpu_not_supported_on_tpus():
Trainer(accelerator="ddp_cpu")


@RunIf(tpu=True)
@pytest.mark.parametrize("strategy", ["tpu_spawn", "tpu_spawn_debug"])
def test_strategy_choice_tpu_str(tmpdir, strategy):
trainer = Trainer(strategy=strategy, accelerator="tpu", devices=8)
assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)


@RunIf(tpu=True)
def test_strategy_choice_tpu_plugin(tmpdir):
trainer = Trainer(strategy=TPUSpawnPlugin(), accelerator="tpu", devices=8)
assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)


@RunIf(tpu=True)
def test_auto_parameters_tying_tpus(tmpdir):

10 changes: 10 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
@@ -343,6 +343,16 @@ def test_v1_7_0_deprecate_parameter_validation():
from pytorch_lightning.core.decorators import parameter_validation # noqa: F401


def test_v1_7_0_passing_strategy_to_accelerator_trainer_flag():
with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
Trainer(accelerator="ddp_spawn")


def test_v1_7_0_passing_strategy_to_plugins_flag():
with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
Trainer(plugins="ddp_spawn")


def test_v1_7_0_weights_summary_trainer(tmpdir):
with pytest.deprecated_call(match=r"Setting `Trainer\(weights_summary=full\)` is deprecated in v1.5"):
t = Trainer(weights_summary="full")