From 703a5a662cb5ee13211c5736c9ce395003c92a4c Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Wed, 23 Feb 2022 18:00:41 -0800 Subject: [PATCH 1/9] Only allow one value for one plugin type in flag --- .../connectors/accelerator_connector.py | 19 +++++++++++++++++++ .../test_accelerator_connector.py | 12 ++++++++++++ 2 files changed, 31 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 7cf64203f52b5..23d70c26b282f 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -304,6 +304,13 @@ def _check_config_and_set_final_flags( self._precision_flag = precision if plugins: + plugins_flags_count = { + "strategy": 0, + "precision_plugin": 0, + "precision": 0, + "checkpoint_io": 0, + "cluster_env": 0, + } for plugin in plugins: if isinstance(plugin, Strategy) or isinstance(plugin, str) and plugin in self._registered_strategies: self._strategy_flag = plugin @@ -311,13 +318,17 @@ def _check_config_and_set_final_flags( f"Passing {plugin} `strategy` to the `plugins` flag in Trainer has been deprecated" f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plugin})` instead." ) + plugins_flags_count["strategy"] += 1 elif isinstance(plugin, PrecisionPlugin): self._precision_plugin_flag = plugin + plugins_flags_count["precision_plugin"] += 1 elif isinstance(plugin, str) and plugin in self._precision_types: self._precision_flag = plugin + plugins_flags_count["precision"] += 1 elif isinstance(plugin, CheckpointIO): self.checkpoint_io = plugin + plugins_flags_count["checkpoint_io"] += 1 elif isinstance(plugin, ClusterEnvironment): self._cluster_environment_flag = plugin elif isinstance(plugin, LayerSync): @@ -327,11 +338,19 @@ def _check_config_and_set_final_flags( " plugin, but this is not allowed. Choose one or the other." ) self._layer_sync = plugin + plugins_flags_count["cluster_env"] += 1 else: raise MisconfigurationException( f"Found invalid type for plugin {plugin}. Expected a precision plugin or training strategy." ) + duplicated_plugin_key = [k for k, v in plugins_flags_count.items() if v > 1] + if duplicated_plugin_key: + raise MisconfigurationException( + f"Received multiple values for {', '.join(duplicated_plugin_key)} flags in `plugins`." 
+ " Expected one 1 value for each type at most" + ) + # handle the case when the user passes in a strategy instance which has an accelerator, precision, # checkpoint io or cluster env set up # TODO: @awaelchli improve the error messages below diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 80603ab1a616e..796de5b8c51ed 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -1019,3 +1019,15 @@ def __init__(self, **kwargs): assert strategy._layer_sync is None Trainer(strategy=strategy, sync_batchnorm=True) assert isinstance(strategy._layer_sync, NativeSyncBatchNorm) + +@pytest.mark.parametrize( + ["plugins", "expected"], + [ + ([LightningEnvironment(), SLURMEnvironment()], "cluster_env"), + (["16", "32"], "precision"), + ([LightningEnvironment(), SLURMEnvironment(), "16", "32"], "precision, cluster_env"), + ], +) +def test_plugin_only_one_instance_for_one_type(plugins, expected): + with pytest.raises(MisconfigurationException, match=f"Received multiple values for {expected}"): + Trainer(plugins=plugins) From 2c01d24420d4b1cbb957159b6e71d80e5de5412c Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Wed, 23 Feb 2022 18:08:29 -0800 Subject: [PATCH 2/9] add change log --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 23d70c26b282f..439730b454c70 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -348,7 +348,7 @@ def _check_config_and_set_final_flags( if duplicated_plugin_key: raise MisconfigurationException( f"Received multiple values for {', '.join(duplicated_plugin_key)} flags in `plugins`." - " Expected one 1 value for each type at most" + " Expected one value for each type at most" ) # handle the case when the user passes in a strategy instance which has an accelerator, precision, From 8a2f62ccece701bb4e72f046cea9abb042696ece Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Mon, 28 Feb 2022 13:12:31 -0800 Subject: [PATCH 3/9] address comments --- .../connectors/accelerator_connector.py | 27 +++++++------------ .../test_accelerator_connector.py | 12 ++++++--- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 439730b454c70..b5ed50973ba7c 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -14,6 +14,7 @@ import logging import os +from collections import Counter from typing import List, Optional, Union import torch @@ -304,13 +305,7 @@ def _check_config_and_set_final_flags( self._precision_flag = precision if plugins: - plugins_flags_count = { - "strategy": 0, - "precision_plugin": 0, - "precision": 0, - "checkpoint_io": 0, - "cluster_env": 0, - } + plugins_flags_types_list = [] for plugin in plugins: if isinstance(plugin, Strategy) or isinstance(plugin, str) and plugin in self._registered_strategies: self._strategy_flag = plugin @@ -318,19 +313,17 @@ def _check_config_and_set_final_flags( f"Passing {plugin} `strategy` to the `plugins` flag in Trainer has been deprecated" f" in v1.5 and will be removed in v1.7. 
Use `Trainer(strategy={plugin})` instead." ) - plugins_flags_count["strategy"] += 1 + plugins_flags_types_list.append("Strategy") elif isinstance(plugin, PrecisionPlugin): self._precision_plugin_flag = plugin - plugins_flags_count["precision_plugin"] += 1 - elif isinstance(plugin, str) and plugin in self._precision_types: - self._precision_flag = plugin - plugins_flags_count["precision"] += 1 + plugins_flags_types_list.append("PrecisionPlugin") elif isinstance(plugin, CheckpointIO): self.checkpoint_io = plugin - plugins_flags_count["checkpoint_io"] += 1 + plugins_flags_types_list.append("CheckpointIO") elif isinstance(plugin, ClusterEnvironment): self._cluster_environment_flag = plugin + plugins_flags_types_list.append("ClusterEnvironment") elif isinstance(plugin, LayerSync): if sync_batchnorm and not isinstance(plugin, NativeSyncBatchNorm): raise MisconfigurationException( @@ -338,17 +331,17 @@ def _check_config_and_set_final_flags( " plugin, but this is not allowed. Choose one or the other." ) self._layer_sync = plugin - plugins_flags_count["cluster_env"] += 1 else: raise MisconfigurationException( - f"Found invalid type for plugin {plugin}. Expected a precision plugin or training strategy." + f"Found invalid type for plugin {plugin}. Expected PrecisionPlugin, " + "CheckpointIO plugin, ClusterEnviroment plugin or a training strategy." ) - duplicated_plugin_key = [k for k, v in plugins_flags_count.items() if v > 1] + duplicated_plugin_key = [k for k, v in Counter(plugins_flags_types_list).items() if v > 1] if duplicated_plugin_key: raise MisconfigurationException( f"Received multiple values for {', '.join(duplicated_plugin_key)} flags in `plugins`." - " Expected one value for each type at most" + " Expected one value for each type at most." ) # handle the case when the user passes in a strategy instance which has an accelerator, precision, diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 796de5b8c51ed..dbaa315b4028c 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -26,13 +26,14 @@ from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.accelerators.gpu import GPUAccelerator -from pytorch_lightning.plugins import LayerSync, NativeSyncBatchNorm, PrecisionPlugin +from pytorch_lightning.plugins import LayerSync, DoublePrecisionPlugin, NativeSyncBatchNorm, PrecisionPlugin from pytorch_lightning.plugins.environments import ( KubeflowEnvironment, LightningEnvironment, SLURMEnvironment, TorchElasticEnvironment, ) +from pytorch_lightning.plugins.io import TorchCheckpointIO from pytorch_lightning.strategies import ( DataParallelStrategy, DDP2Strategy, @@ -1023,9 +1024,12 @@ def __init__(self, **kwargs): @pytest.mark.parametrize( ["plugins", "expected"], [ - ([LightningEnvironment(), SLURMEnvironment()], "cluster_env"), - (["16", "32"], "precision"), - ([LightningEnvironment(), SLURMEnvironment(), "16", "32"], "precision, cluster_env"), + ([LightningEnvironment(), SLURMEnvironment()], "ClusterEnvironment"), + ([TorchCheckpointIO(), TorchCheckpointIO()], "CheckpointIO"), + ( + [PrecisionPlugin(), DoublePrecisionPlugin(), LightningEnvironment(), SLURMEnvironment()], + "PrecisionPlugin, ClusterEnvironment", + ), ], ) def test_plugin_only_one_instance_for_one_type(plugins, expected): From a455d479b978a0d9ea70b564fb2a2ca213d09ff2 Mon Sep 17 00:00:00 2001 From: Siyu Wang 
Date: Mon, 28 Feb 2022 13:17:45 -0800
Subject: [PATCH 4/9] address comments

---
 .../trainer/connectors/accelerator_connector.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index b5ed50973ba7c..eb73f046e946b 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -313,14 +313,14 @@ def _check_config_and_set_final_flags(
                         f"Passing {plugin} `strategy` to the `plugins` flag in Trainer has been deprecated"
                         f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plugin})` instead."
                     )
-                    plugins_flags_types_list.append("Strategy")
+                    plugins_flags_types_list.append(Strategy.__name__)
                 elif isinstance(plugin, PrecisionPlugin):
                     self._precision_plugin_flag = plugin
-                    plugins_flags_types_list.append("PrecisionPlugin")
+                    plugins_flags_types_list.append(PrecisionPlugin.__name__)
                 elif isinstance(plugin, CheckpointIO):
                     self.checkpoint_io = plugin
-                    plugins_flags_types_list.append("CheckpointIO")
+                    plugins_flags_types_list.append(CheckpointIO.__name__)
                 elif isinstance(plugin, ClusterEnvironment):
                     self._cluster_environment_flag = plugin
                     plugins_flags_types_list.append("ClusterEnvironment")
                 elif isinstance(plugin, LayerSync):
                     if sync_batchnorm and not isinstance(plugin, NativeSyncBatchNorm):
                         raise MisconfigurationException(
                             " plugin, but this is not allowed. Choose one or the other."
                         )
                     self._layer_sync = plugin
+                    plugins_flags_types_list.append(ClusterEnvironment.__name__)
                 else:
                     raise MisconfigurationException(
                         f"Found invalid type for plugin {plugin}. Expected PrecisionPlugin, "
                         "CheckpointIO plugin, ClusterEnviroment plugin or a training strategy."
                     )

From cf3e618edb76512f6eeeb72715fefba1bfa8d428 Mon Sep 17 00:00:00 2001
From: four4fish <88516121+four4fish@users.noreply.github.com>
Date: Fri, 4 Mar 2022 10:35:42 -0800
Subject: [PATCH 5/9] Update
 pytorch_lightning/trainer/connectors/accelerator_connector.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index eb73f046e946b..97c2a361d563a 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -335,7 +335,7 @@ def _check_config_and_set_final_flags(
                 else:
                     raise MisconfigurationException(
                         f"Found invalid type for plugin {plugin}. Expected PrecisionPlugin, "
-                        "CheckpointIO plugin, ClusterEnviroment plugin or a training strategy."
+                        "CheckpointIO plugin, ClusterEnviroment plugin or a Strategy."
) duplicated_plugin_key = [k for k, v in Counter(plugins_flags_types_list).items() if v > 1] From 2d383037920055a9d7002f6e31c1c5770011a9da Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 4 Mar 2022 11:04:30 -0800 Subject: [PATCH 6/9] rebase and remove changelog --- .../trainer/connectors/accelerator_connector.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 97c2a361d563a..8729747352f55 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -305,7 +305,7 @@ def _check_config_and_set_final_flags( self._precision_flag = precision if plugins: - plugins_flags_types_list = [] + plugins_flags_types = Counter() for plugin in plugins: if isinstance(plugin, Strategy) or isinstance(plugin, str) and plugin in self._registered_strategies: self._strategy_flag = plugin @@ -313,17 +313,17 @@ def _check_config_and_set_final_flags( f"Passing {plugin} `strategy` to the `plugins` flag in Trainer has been deprecated" f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plugin})` instead." ) - plugins_flags_types_list.append(Strategy.__name__) + plugins_flags_types[Strategy.__name__] += 1 elif isinstance(plugin, PrecisionPlugin): self._precision_plugin_flag = plugin - plugins_flags_types_list.append(PrecisionPlugin.__name__) + plugins_flags_types[PrecisionPlugin.__name__] += 1 elif isinstance(plugin, CheckpointIO): self.checkpoint_io = plugin - plugins_flags_types_list.append(CheckpointIO.__name__) + plugins_flags_types[CheckpointIO.__name__] += 1 elif isinstance(plugin, ClusterEnvironment): self._cluster_environment_flag = plugin - plugins_flags_types_list.append("ClusterEnvironment") + plugins_flags_types[ClusterEnvironment.__name__] += 1 elif isinstance(plugin, LayerSync): if sync_batchnorm and not isinstance(plugin, NativeSyncBatchNorm): raise MisconfigurationException( @@ -331,14 +331,14 @@ def _check_config_and_set_final_flags( " plugin, but this is not allowed. Choose one or the other." ) self._layer_sync = plugin - plugins_flags_types_list.append(ClusterEnvironment.__name__) + plugins_flags_types[NativeSyncBatchNorm.__name__] += 1 else: raise MisconfigurationException( f"Found invalid type for plugin {plugin}. Expected PrecisionPlugin, " "CheckpointIO plugin, ClusterEnviroment plugin or a Strategy." ) - duplicated_plugin_key = [k for k, v in Counter(plugins_flags_types_list).items() if v > 1] + duplicated_plugin_key = [k for k, v in plugins_flags_types.items() if v > 1] if duplicated_plugin_key: raise MisconfigurationException( f"Received multiple values for {', '.join(duplicated_plugin_key)} flags in `plugins`." 
From 1fa65ea38d9f6afde2f706a1a7a3eb9ae2ea48aa Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 4 Mar 2022 15:32:46 -0800 Subject: [PATCH 7/9] fix mypy --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 8729747352f55..e44153e8736ca 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -15,7 +15,7 @@ import logging import os from collections import Counter -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import torch @@ -305,7 +305,7 @@ def _check_config_and_set_final_flags( self._precision_flag = precision if plugins: - plugins_flags_types = Counter() + plugins_flags_types: Dict[str, int] = Counter() for plugin in plugins: if isinstance(plugin, Strategy) or isinstance(plugin, str) and plugin in self._registered_strategies: self._strategy_flag = plugin From c7b81c8ed8f399398fa12c7ef71e756401159556 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 4 Mar 2022 23:34:38 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/accelerators/test_accelerator_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index dbaa315b4028c..2cf1d936e7747 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -26,7 +26,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.accelerators.gpu import GPUAccelerator -from pytorch_lightning.plugins import LayerSync, DoublePrecisionPlugin, NativeSyncBatchNorm, PrecisionPlugin +from pytorch_lightning.plugins import DoublePrecisionPlugin, LayerSync, NativeSyncBatchNorm, PrecisionPlugin from pytorch_lightning.plugins.environments import ( KubeflowEnvironment, LightningEnvironment, @@ -1021,6 +1021,7 @@ def __init__(self, **kwargs): Trainer(strategy=strategy, sync_batchnorm=True) assert isinstance(strategy._layer_sync, NativeSyncBatchNorm) + @pytest.mark.parametrize( ["plugins", "expected"], [ From 810405ff1dff1fef76dd43cc05c3c488adff50ee Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Sat, 5 Mar 2022 10:31:58 +0000 Subject: [PATCH 9/9] Update pytorch_lightning/trainer/connectors/accelerator_connector.py Co-authored-by: ananthsub --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index e44153e8736ca..c3b5b45cfe8f8 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -334,8 +334,8 @@ def _check_config_and_set_final_flags( plugins_flags_types[NativeSyncBatchNorm.__name__] += 1 else: raise MisconfigurationException( - f"Found invalid type for plugin {plugin}. Expected PrecisionPlugin, " - "CheckpointIO plugin, ClusterEnviroment plugin or a Strategy." 
+ f"Found invalid type for plugin {plugin}. Expected one of: PrecisionPlugin, " + "CheckpointIO, ClusterEnviroment, LayerSync, or Strategy." ) duplicated_plugin_key = [k for k, v in plugins_flags_types.items() if v > 1]