
Add strategy argument to Trainer #8597

Merged
merged 34 commits on Oct 13, 2021
Show file tree
Hide file tree
Changes from 27 commits
Commits
34 commits
98c8830
Add training_type argument to Trainer
kaushikb11 Jul 28, 2021
8b2f6c4
Add deprecation warning
kaushikb11 Jul 28, 2021
73120fa
Update training_types
kaushikb11 Jul 28, 2021
e856db1
Add set training type plugin
kaushikb11 Jul 30, 2021
59ebe70
Add deprecation_and_warn_for_accelerator_and_distributed_backend
kaushikb11 Jul 30, 2021
b5094ec
Add plugins warning
kaushikb11 Jul 30, 2021
4d11ced
Add is_plugin_training_type
kaushikb11 Jul 30, 2021
c878f1c
Update tests
kaushikb11 Jul 30, 2021
f9a1a63
Add deprecation test
kaushikb11 Jul 30, 2021
ba6bc88
Add tests for cpus training type
kaushikb11 Jul 31, 2021
3fe1e67
Add tests for gpus & tpus training type
kaushikb11 Jul 31, 2021
ae1395c
Add tests for ipus training type
kaushikb11 Jul 31, 2021
fff3385
Add update_device_type_if_training_type_plugin_passed
kaushikb11 Aug 1, 2021
127e5e8
Merge branch 'master' into add/training_type
kaushikb11 Aug 3, 2021
bf7b9cb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 3, 2021
aeca5aa
Update test
kaushikb11 Aug 3, 2021
91caf4b
Code improvements
kaushikb11 Aug 3, 2021
3110ce9
Add trainer kwargs test
kaushikb11 Aug 3, 2021
ddbeab4
Add update_device_type_if_training_type_plugin_passed
kaushikb11 Aug 3, 2021
e6eb8a6
Update to accelerator strategy
kaushikb11 Aug 3, 2021
a5c1978
Update to accelerator strategy
kaushikb11 Aug 3, 2021
ea03cff
Merge branch 'master' into add/training_type
kaushikb11 Oct 11, 2021
c9cd9f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 11, 2021
0f00172
Fix tests
kaushikb11 Oct 11, 2021
8b63157
Use strategy instead
kaushikb11 Oct 11, 2021
353434d
Update ipu test
kaushikb11 Oct 11, 2021
87a3b60
Merge branch 'master' into add/training_type
kaushikb11 Oct 12, 2021
aee59fd
Update typing
kaushikb11 Oct 13, 2021
9ba9c13
Update tests
kaushikb11 Oct 13, 2021
52a04ab
Update pytorch_lightning/trainer/connectors/accelerator_connector.py
kaushikb11 Oct 13, 2021
8df86ec
Update tests
kaushikb11 Oct 13, 2021
6a7d9d4
Merge branch 'add/training_type' of https://github.com/kaushikb11/pyt…
kaushikb11 Oct 13, 2021
89c190f
Merge branch 'master' into add/training_type
kaushikb11 Oct 13, 2021
dfecb4f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 13, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -178,6 +178,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))


- Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))


### Changed

- Module imports are now catching `ModuleNotFoundError` instead of `ImportError` ([#9867](https://github.com/PyTorchLightning/pytorch-lightning/pull/9867))
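For orientation, a minimal before/after sketch of the new flag (hypothetical user code, not part of this diff; the alias shown is one exercised by the tests added in this PR):

from pytorch_lightning import Trainer

# Before this PR, the distributed strategy was folded into `accelerator` (now deprecated):
# trainer = Trainer(accelerator="ddp_spawn", num_processes=2)

# With this PR, the strategy is its own argument, while `accelerator`/`devices` describe hardware:
trainer = Trainer(strategy="ddp_spawn", accelerator="cpu", devices=2)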
107 changes: 95 additions & 12 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -94,6 +94,7 @@ def __init__(
ipus,
distributed_backend,
accelerator,
strategy,
gpus,
gpu_ids,
num_nodes,
@@ -111,12 +112,9 @@ def __init__(
self._distrib_type = None
self._accelerator_type = None

if distributed_backend is not None:
rank_zero_deprecation(
f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5."
f" Use `Trainer(accelerator={distributed_backend})` instead."
)
distributed_backend = distributed_backend or accelerator
self.strategy = strategy
self.distributed_backend = distributed_backend or accelerator

self._init_deterministic(deterministic)

self.num_processes = num_processes
@@ -126,7 +124,6 @@ def __init__(
self.parallel_device_ids = gpu_ids
self.tpu_cores = tpu_cores
self.ipus = ipus
self.distributed_backend = distributed_backend
self.num_nodes = num_nodes
self.sync_batchnorm = sync_batchnorm
self.benchmark = benchmark
@@ -151,16 +148,23 @@ def __init__(

self.plugins = plugins

self._handle_accelerator_and_distributed_backend(distributed_backend, accelerator)

self._validate_accelerator_and_devices()

self._warn_if_devices_flag_ignored()

self.select_accelerator_type()
self.set_distributed_mode()

if self.strategy is not None:
self._set_training_type_plugin()
else:
self.set_distributed_mode()
self.configure_slurm_ddp()

self.handle_given_plugins()
self.update_device_type_if_ipu_plugin()
self.update_device_type_if_training_type_plugin_passed()

self._validate_accelerator_type()
self._set_devices_if_none()
@@ -228,11 +232,11 @@ def select_accelerator_type(self) -> None:
self._set_devices_to_cpu_num_processes()
self._accelerator_type = DeviceType.CPU

if self.distributed_backend in ["auto"] + list(DeviceType):
if self.distributed_backend in self.accelerator_types:
self.distributed_backend = None

def _validate_accelerator_and_devices(self) -> None:
if self.distributed_backend not in ["auto"] + list(DeviceType) and self.devices is not None:
if self.distributed_backend not in self.accelerator_types and self.devices is not None:
raise MisconfigurationException(
f"You passed `devices={self.devices}` but haven't specified"
" `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu')` for the devices mapping,"
@@ -285,9 +289,56 @@ def _set_devices_if_none(self) -> None:
elif self._accelerator_type == DeviceType.CPU:
self.devices = self.num_processes

def _handle_accelerator_and_distributed_backend(
self, distributed_backend: Optional[str], accelerator: Optional[Union[str, Accelerator]]
) -> None:
if distributed_backend is not None:
rank_zero_deprecation(
f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5."
f" Use `Trainer(strategy={distributed_backend})` instead."
)
if self.strategy is not None:
raise MisconfigurationException(
f"You have passed `Trainer(strategy={self.strategy})` but have"
f" also passed `Trainer(distributed_backend={distributed_backend})`."
f"HINT: Use just `Trainer(strategy={self.strategy})` instead."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing whitespace here

)

if accelerator is not None and accelerator in list(DistributedType):
rank_zero_deprecation(
f"Passing {accelerator} `strategy` to the `accelerator` flag in Trainer has been deprecated"
f" in v1.5 and will be removed in v1.6. Use `Trainer(strategy={accelerator})` instead."
)
if self.strategy is not None:
raise MisconfigurationException(
f"You have passed `Trainer(strategy={self.strategy})` but have"
f" also passed `Trainer(accelerator={accelerator})`."
f"HINT: Use just `Trainer(strategy={self.strategy})` instead."
)

def _set_training_type_plugin(self) -> None:
if isinstance(self.strategy, str) and self.strategy in TrainingTypePluginsRegistry:
self._training_type_plugin = TrainingTypePluginsRegistry.get(self.strategy)
if isinstance(self.strategy, str):
self.set_distributed_mode(self.strategy)
elif isinstance(self.strategy, TrainingTypePlugin):
self._training_type_plugin = self.strategy

def handle_given_plugins(self) -> None:

training_type = None
for plug in self.plugins:
if self.strategy is not None and self._is_plugin_training_type(plug):
raise MisconfigurationException(
f"You have passed `Trainer(strategy={self.strategy})`"
f" and you can only specify one training type plugin, but you have passed {plug} as a plugin."
)
if self._is_plugin_training_type(plug):
rank_zero_deprecation(
f"Passing {plug} `strategy` to the `plugins` flag in Trainer has been deprecated"
f" in v1.5 and will be removed in v1.6. Use `Trainer(strategy={plug})` instead."
)

training_type = self._training_type_plugin or None
checkpoint = None
precision = None
cluster_environment = None
@@ -350,6 +401,10 @@ def handle_given_plugins() -> None:
self._checkpoint_io = checkpoint
self._cluster_environment = cluster_environment or self.select_cluster_environment()

@property
def accelerator_types(self) -> List[str]:
return ["auto"] + list(DeviceType)

@property
def precision_plugin(self) -> PrecisionPlugin:
if self._precision_plugin is None:
@@ -540,9 +595,18 @@ def root_gpu(self) -> Optional[int]:
else None
)

@staticmethod
def _is_plugin_training_type(plugin: Union[str, TrainingTypePlugin]) -> bool:
if isinstance(plugin, str) and (plugin in TrainingTypePluginsRegistry or plugin in list(DistributedType)):
return True
return isinstance(plugin, TrainingTypePlugin)

@property
def is_training_type_in_plugins(self) -> bool:
return any(isinstance(plug, str) and plug in TrainingTypePluginsRegistry for plug in self.plugins)
return any(
(isinstance(plug, str) and plug in TrainingTypePluginsRegistry) or isinstance(plug, TrainingTypePlugin)
for plug in self.plugins
)

def select_precision_plugin(self) -> PrecisionPlugin:
# set precision type
@@ -873,6 +937,25 @@ def update_device_type_if_ipu_plugin(self) -> None:
if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != DeviceType.IPU:
self._device_type = DeviceType.IPU

def update_device_type_if_training_type_plugin_passed(self) -> None:
if isinstance(self.strategy, TrainingTypePlugin) or any(
isinstance(plug, TrainingTypePlugin) for plug in self.plugins
):
if self._accelerator_type is not None:
if self.use_ipu:
self._device_type = DeviceType.IPU
elif self.use_tpu:
self._device_type = DeviceType.TPU
elif self.use_gpu:
self._device_type = DeviceType.GPU
else:
if self.has_ipu:
self._device_type = DeviceType.IPU
elif self.has_tpu:
self._device_type = DeviceType.TPU
elif self.has_gpu:
self._device_type = DeviceType.GPU

def configure_slurm_ddp(self):
# extract SLURM flag vars
# whenever we have the correct number of tasks, we let slurm manage processes
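A condensed, hedged reading of the connector changes above; the constructor calls mirror the CPU tests added below, and the numbered summary is my interpretation of the diff rather than code from the PR:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPSpawnPlugin

# 1. `strategy` takes precedence: a registry/DistributedType alias is resolved in
#    `_set_training_type_plugin`, while a TrainingTypePlugin instance is used as-is.
trainer = Trainer(strategy="ddp_find_unused_parameters_false", accelerator="cpu", devices=2)
trainer = Trainer(strategy=DDPSpawnPlugin(), accelerator="cpu", devices=2)

# 2. Strategy-like values in `accelerator=`/`distributed_backend=`, or a training type
#    plugin in `plugins=`, still work but emit deprecation warnings, and combining them
#    with `strategy` raises MisconfigurationException.
# 3. `update_device_type_if_training_type_plugin_passed` then realigns `_device_type`
#    (IPU/TPU/GPU) when the strategy was supplied as a plugin instance.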
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -154,6 +154,7 @@
flush_logs_every_n_steps: Optional[int] = None,
log_every_n_steps: int = 50,
accelerator: Optional[Union[str, Accelerator]] = None,
strategy: Optional[Union[str, TrainingTypePlugin]] = None,
sync_batchnorm: bool = False,
precision: Union[int, str] = 32,
weights_summary: Optional[str] = "top",
@@ -346,6 +347,9 @@ def __init__(
no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint,
training will start from the beginning of the next epoch.

strategy: Supports different training strategies with aliases
as well as custom training type plugins.

sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
@@ -408,6 +412,7 @@ def __init__(
ipus,
distributed_backend,
accelerator,
strategy,
gpus,
gpu_ids,
num_nodes,
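To illustrate the docstring above ("aliases as well as custom training type plugins"), a hedged sketch of the custom-plugin path. `MyDDPPlugin` is a hypothetical subclass introduced only for this example, and the call assumes a 2-GPU machine as in the GPU tests below:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

class MyDDPPlugin(DDPPlugin):
    """Hypothetical custom training type plugin; any TrainingTypePlugin subclass is accepted."""

trainer = Trainer(strategy=MyDDPPlugin(), accelerator="gpu", devices=2)  # requires 2 GPUs
assert isinstance(trainer.training_type_plugin, MyDDPPlugin)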
75 changes: 74 additions & 1 deletion tests/accelerators/test_accelerator_connector.py
@@ -26,6 +26,7 @@
from pytorch_lightning.accelerators.gpu import GPUAccelerator
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import (
DataParallelPlugin,
DDP2Plugin,
DDPPlugin,
DDPShardedPlugin,
@@ -42,7 +43,7 @@
SLURMEnvironment,
TorchElasticEnvironment,
)
from pytorch_lightning.utilities import DistributedType
from pytorch_lightning.utilities import DeviceType, DistributedType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@@ -631,6 +632,78 @@ def test_accelerator_ddp_for_cpu(tmpdir):
assert isinstance(trainer.training_type_plugin, DDPPlugin)


def test_exception_when_strategy_used_with_distributed_backend():
with pytest.raises(MisconfigurationException, match="but have also passed"):
Trainer(distributed_backend="ddp_cpu", strategy="ddp_spawn")


def test_exception_when_strategy_used_with_accelerator():
with pytest.raises(MisconfigurationException, match="but have also passed"):
Trainer(accelerator="ddp", strategy="ddp_spawn")


def test_exception_when_strategy_used_with_plugins():
with pytest.raises(MisconfigurationException, match="only specify one training type plugin, but you have passed"):
Trainer(plugins="ddp_find_unused_parameters_false", strategy="ddp_spawn")


@pytest.mark.parametrize(
["strategy", "plugin"],
[
("ddp_spawn", DDPSpawnPlugin),
("ddp_spawn_find_unused_parameters_false", DDPSpawnPlugin),
("ddp", DDPPlugin),
("ddp_find_unused_parameters_false", DDPPlugin),
],
)
def test_strategy_choice_cpu_str(tmpdir, strategy, plugin):
trainer = Trainer(strategy=strategy, accelerator="cpu", devices=2)
assert isinstance(trainer.training_type_plugin, plugin)


@pytest.mark.parametrize("plugin", [DDPSpawnPlugin, DDPPlugin])
def test_strategy_choice_cpu_plugin(tmpdir, plugin):
trainer = Trainer(strategy=plugin(), accelerator="cpu", devices=2)
assert isinstance(trainer.training_type_plugin, plugin)


@RunIf(min_gpus=2)
@pytest.mark.parametrize(
["strategy", "plugin"],
[
("ddp_spawn", DDPSpawnPlugin),
("ddp_spawn_find_unused_parameters_false", DDPSpawnPlugin),
("ddp", DDPPlugin),
("ddp_find_unused_parameters_false", DDPPlugin),
("ddp2", DDP2Plugin),
("dp", DataParallelPlugin),
("ddp_sharded", DDPShardedPlugin),
("ddp_sharded_spawn", DDPSpawnShardedPlugin),
pytest.param("deepspeed", DeepSpeedPlugin, marks=RunIf(deepspeed=True)),
],
)
def test_strategy_choice_gpu_str(tmpdir, strategy, plugin):
trainer = Trainer(strategy=strategy, accelerator="gpu", devices=2)
assert isinstance(trainer.training_type_plugin, plugin)


@RunIf(min_gpus=2)
@pytest.mark.parametrize("plugin", [DDPSpawnPlugin, DDPPlugin])
def test_strategy_choice_gpu_plugin(tmpdir, plugin):
trainer = Trainer(strategy=plugin(), accelerator="gpu", devices=2)
assert isinstance(trainer.training_type_plugin, plugin)


@RunIf(min_gpus=2)
@pytest.mark.parametrize("plugin", [DDPSpawnPlugin, DDPPlugin])
def test_device_type_when_training_plugin_gpu_passed(tmpdir, plugin):

trainer = Trainer(strategy=plugin(), gpus=2)
assert isinstance(trainer.training_type_plugin, plugin)
assert trainer._device_type == DeviceType.GPU
assert isinstance(trainer.accelerator, GPUAccelerator)


@pytest.mark.parametrize("precision", [1, 12, "invalid"])
def test_validate_precision_type(tmpdir, precision):

19 changes: 17 additions & 2 deletions tests/accelerators/test_ipu.py
@@ -24,7 +24,7 @@
from pytorch_lightning.plugins import IPUPlugin, IPUPrecisionPlugin
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import _IPU_AVAILABLE
from pytorch_lightning.utilities import _IPU_AVAILABLE, DeviceType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.datamodules import ClassifDataModule
@@ -120,7 +120,7 @@ def test_warning_if_ipus_not_used(tmpdir):
@RunIf(ipu=True)
def test_no_warning_plugin(tmpdir):
with pytest.warns(None) as record:
Trainer(default_root_dir=tmpdir, plugins=IPUPlugin(training_opts=poptorch.Options()))
Trainer(default_root_dir=tmpdir, strategy=IPUPlugin(training_opts=poptorch.Options()))
assert len(record) == 0


@@ -528,3 +528,18 @@ def test_set_devices_if_none_ipu():

trainer = Trainer(accelerator="ipu", ipus=8)
assert trainer.devices == 8


@RunIf(ipu=True)
def test_strategy_choice_ipu_plugin(tmpdir):
trainer = Trainer(strategy=IPUPlugin(), accelerator="ipu", devices=8)
assert isinstance(trainer.training_type_plugin, IPUPlugin)


@RunIf(ipu=True)
def test_device_type_when_training_plugin_ipu_passed(tmpdir):

trainer = Trainer(strategy=IPUPlugin(), ipus=8)
assert isinstance(trainer.training_type_plugin, IPUPlugin)
assert trainer._device_type == DeviceType.IPU
assert isinstance(trainer.accelerator, IPUAccelerator)
13 changes: 13 additions & 0 deletions tests/accelerators/test_tpu_backend.py
@@ -229,6 +229,19 @@ def test_ddp_cpu_not_supported_on_tpus():
Trainer(accelerator="ddp_cpu")


@RunIf(tpu=True)
@pytest.mark.parametrize("strategy", ["tpu_spawn", "tpu_spawn_debug"])
def test_strategy_choice_tpu_str(tmpdir, strategy):
trainer = Trainer(strategy=strategy, accelerator="tpu", devices=8)
assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)


@RunIf(tpu=True)
def test_strategy_choice_tpu_plugin(tmpdir):
trainer = Trainer(strategy=TPUSpawnPlugin(), accelerator="tpu", devices=8)
assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)


@RunIf(tpu=True)
def test_auto_parameters_tying_tpus(tmpdir):

10 changes: 10 additions & 0 deletions tests/deprecated_api/test_remove_1-6.py
@@ -330,6 +330,16 @@ def test_v1_6_0_deprecated_device_dtype_mixin_import():
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin # noqa: F401


def test_v1_6_0_passing_strategy_to_accelerator_trainer_flag():
with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.6."):
Trainer(accelerator="ddp_spawn")


def test_v1_6_0_passing_strategy_to_plugins_flag():
with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.6."):
Trainer(plugins="ddp_spawn")


def test_v1_6_0_deprecated_accelerator_collective():
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type import SingleDevicePlugin