Only allow one value for each plugin type in plugins flag #12083

Merged: 9 commits, Mar 11, 2022
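With this change, `Trainer(plugins=...)` rejects a list containing more than one plugin of the same base type (strategy, precision, checkpoint IO, cluster environment, layer sync), instead of silently letting the last one overwrite the earlier ones as the old loop did. A minimal sketch of the new behavior, assuming Lightning at this PR's state; the plugin classes below are real Lightning APIs, the script itself is illustrative:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import LightningEnvironment, SLURMEnvironment
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Two ClusterEnvironment plugins in one list: now a hard error.
try:
    Trainer(plugins=[LightningEnvironment(), SLURMEnvironment()])
except MisconfigurationException as err:
    print(err)  # Received multiple values for ClusterEnvironment flags in `plugins`. ...

# A single plugin per type is still accepted.
Trainer(plugins=[LightningEnvironment()])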
pytorch_lightning/trainer/connectors/accelerator_connector.py (21 changes: 17 additions & 4 deletions)
@@ -14,7 +14,8 @@
 
 import logging
 import os
-from typing import List, Optional, Union
+from collections import Counter
+from typing import Dict, List, Optional, Union
 
 import torch
 
@@ -304,34 +305,46 @@ def _check_config_and_set_final_flags(
             self._precision_flag = precision
 
         if plugins:
+            plugins_flags_types: Dict[str, int] = Counter()
             for plugin in plugins:
                 if isinstance(plugin, Strategy) or isinstance(plugin, str) and plugin in self._registered_strategies:
                     self._strategy_flag = plugin
                     rank_zero_deprecation(
                         f"Passing {plugin} `strategy` to the `plugins` flag in Trainer has been deprecated"
                         f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plugin})` instead."
                     )
+                    plugins_flags_types[Strategy.__name__] += 1
 
                 elif isinstance(plugin, PrecisionPlugin):
                     self._precision_plugin_flag = plugin
+                    plugins_flags_types[PrecisionPlugin.__name__] += 1
                 elif isinstance(plugin, str) and plugin in self._precision_types:
                     self._precision_flag = plugin
+                    plugins_flags_types[PrecisionPlugin.__name__] += 1
                 elif isinstance(plugin, CheckpointIO):
                     self.checkpoint_io = plugin
+                    plugins_flags_types[CheckpointIO.__name__] += 1
                 elif isinstance(plugin, ClusterEnvironment):
                     self._cluster_environment_flag = plugin
+                    plugins_flags_types[ClusterEnvironment.__name__] += 1
                 elif isinstance(plugin, LayerSync):
                     if sync_batchnorm and not isinstance(plugin, NativeSyncBatchNorm):
                         raise MisconfigurationException(
                             f"You set `Trainer(sync_batchnorm=True)` and provided a `{plugin.__class__.__name__}`"
                             " plugin, but this is not allowed. Choose one or the other."
                         )
                     self._layer_sync = plugin
+                    plugins_flags_types[NativeSyncBatchNorm.__name__] += 1
                 else:
                     raise MisconfigurationException(
-                        f"Found invalid type for plugin {plugin}. Expected a precision plugin or training strategy."
+                        f"Found invalid type for plugin {plugin}. Expected one of: PrecisionPlugin, "
+                        "CheckpointIO, ClusterEnvironment, LayerSync, or Strategy."
                     )
 
+            duplicated_plugin_key = [k for k, v in plugins_flags_types.items() if v > 1]
+            if duplicated_plugin_key:
+                raise MisconfigurationException(
+                    f"Received multiple values for {', '.join(duplicated_plugin_key)} flags in `plugins`."
+                    " Expected at most one value for each type."
+                )
 
         # handle the case when the user passes in a strategy instance which has an accelerator, precision,
         # checkpoint io or cluster env set up
         # TODO: @awaelchli improve the error messages below
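The core of the change is a small tallying pattern: count plugins by base-type name with `collections.Counter` while dispatching them, then fail once at the end if any type appeared more than once, so a single error can list every conflicting type. A self-contained sketch of the same idea outside the Trainer; the class and function names here are hypothetical stand-ins, not Lightning APIs:

from collections import Counter
from typing import Sequence


class CheckpointIO:
    """Hypothetical stand-in for one plugin base type."""


class ClusterEnvironment:
    """Hypothetical stand-in for another plugin base type."""


def check_unique_plugin_types(plugins: Sequence[object]) -> None:
    # Tally each plugin under its base type's name, mirroring the PR's approach.
    counts: Counter = Counter()
    for plugin in plugins:
        for base in (CheckpointIO, ClusterEnvironment):
            if isinstance(plugin, base):
                counts[base.__name__] += 1

    # Collect every type seen more than once so one error reports all conflicts.
    duplicated = [name for name, n in counts.items() if n > 1]
    if duplicated:
        raise ValueError(
            f"Received multiple values for {', '.join(duplicated)} flags in `plugins`."
            " Expected at most one value for each type."
        )


check_unique_plugin_types([CheckpointIO(), ClusterEnvironment()])  # passes
check_unique_plugin_types([CheckpointIO(), CheckpointIO()])        # raises ValueError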
tests/accelerators/test_accelerator_connector.py (19 changes: 18 additions & 1 deletion)
@@ -26,13 +26,14 @@
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
-from pytorch_lightning.plugins import LayerSync, NativeSyncBatchNorm, PrecisionPlugin
+from pytorch_lightning.plugins import DoublePrecisionPlugin, LayerSync, NativeSyncBatchNorm, PrecisionPlugin
 from pytorch_lightning.plugins.environments import (
     KubeflowEnvironment,
     LightningEnvironment,
     SLURMEnvironment,
     TorchElasticEnvironment,
 )
+from pytorch_lightning.plugins.io import TorchCheckpointIO
 from pytorch_lightning.strategies import (
     DataParallelStrategy,
     DDP2Strategy,
@@ -1019,3 +1020,19 @@ def __init__(self, **kwargs):
     assert strategy._layer_sync is None
     Trainer(strategy=strategy, sync_batchnorm=True)
     assert isinstance(strategy._layer_sync, NativeSyncBatchNorm)
+
+
+@pytest.mark.parametrize(
+    ["plugins", "expected"],
+    [
+        ([LightningEnvironment(), SLURMEnvironment()], "ClusterEnvironment"),
+        ([TorchCheckpointIO(), TorchCheckpointIO()], "CheckpointIO"),
+        (
+            [PrecisionPlugin(), DoublePrecisionPlugin(), LightningEnvironment(), SLURMEnvironment()],
+            "PrecisionPlugin, ClusterEnvironment",
+        ),
+    ],
+)
+def test_plugin_only_one_instance_for_one_type(plugins, expected):
+    with pytest.raises(MisconfigurationException, match=f"Received multiple values for {expected}"):
+        Trainer(plugins=plugins)
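Note that the third parametrization above checks that the error names every duplicated type at once, not just the first one found. Combining plugins of distinct types remains valid; a hedged example using the same classes the test imports (illustrative, not part of the diff):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.plugins.io import TorchCheckpointIO

# One instance per plugin type passes the new uniqueness check.
trainer = Trainer(
    plugins=[PrecisionPlugin(), LightningEnvironment(), TorchCheckpointIO()]
)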