From bad5c6947f1efd14bbacecb387ebd476bf125bde Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Tue, 24 Aug 2021 10:28:31 +0100
Subject: [PATCH 1/7] Add validate logic for precision

---
 .../trainer/connectors/accelerator_connector.py  | 11 +++++++++++
 tests/accelerators/test_accelerator_connector.py |  7 +++++++
 2 files changed, 18 insertions(+)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index ec3b56489e32a..8dc5adfec0a27 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -546,7 +546,18 @@ def is_using_torchelastic(self) -> bool:
         )
         return TorchElasticEnvironment.is_using_torchelastic()
 
+    def _validate_precision_type(self) -> None:
+        """
+        Ensures that the set precision type by the user is valid.
+        """
+        valid_types = (16, 32, 64, "bf16")
+        if self.precision not in valid_types:
+            raise MisconfigurationException(
+                f"Precision {self.precision} is invalid. Allowed precision values: {valid_types}"
+            )
+
     def select_precision_plugin(self) -> PrecisionPlugin:
+        self._validate_precision_type()
         # set precision type
         self.amp_type = AMPType.from_str(self.amp_type)
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index e27b873a63941..f4793a61b87c2 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -629,3 +629,10 @@ def test_accelerator_ddp_for_cpu(tmpdir):
     trainer = Trainer(accelerator="ddp", num_processes=2)
     assert isinstance(trainer.accelerator, CPUAccelerator)
     assert isinstance(trainer.training_type_plugin, DDPPlugin)
+
+
+@pytest.mark.parametrize("precision", [1, 12, "invalid"])
+def test_validate_precision_type(tmpdir, precision):
+
+    with pytest.raises(MisconfigurationException, match=f"Precision {precision} is invalid"):
+        Trainer(precision=precision)

From 04b125831dc3d6977df74e88e3de158dd4b5bc7d Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Tue, 24 Aug 2021 10:32:16 +0100
Subject: [PATCH 2/7] Add CHANGELOG.md

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 35404e85a5816..9882512234ec8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929))
 
 
+- Added validate logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080))
+
+
 ### Changed
 
 - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))

From 0e1cf64ab38eab3ea6298b63f71a041e2a69f731 Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Tue, 24 Aug 2021 10:37:44 +0100
Subject: [PATCH 3/7] Update
 pytorch_lightning/trainer/connectors/accelerator_connector.py

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
---
 pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 8dc5adfec0a27..be1882cb0cb40 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -548,7 +548,7 @@ def is_using_torchelastic(self) -> bool:
 
     def _validate_precision_type(self) -> None:
         """
-        Ensures that the set precision type by the user is valid.
+        Ensures that the set precision type passed by the user is valid.
         """
         valid_types = (16, 32, 64, "bf16")
         if self.precision not in valid_types:

From 92ee992898893ee5018ab0b7e4d92b5abcf9c3b7 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Tue, 24 Aug 2021 11:20:06 +0100
Subject: [PATCH 4/7] Address review

---
 .../connectors/accelerator_connector.py | 10 ++++---
 pytorch_lightning/utilities/enums.py    | 27 +++++++++++++++++++
 tests/utilities/test_enums.py           |  9 +++++++
 3 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index be1882cb0cb40..80c1b647560cc 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -73,6 +73,7 @@
     rank_zero_info,
     rank_zero_warn,
 )
+from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _HOROVOD_AVAILABLE:
@@ -148,6 +149,8 @@ def __init__(
         self.plugins = plugins
 
         self._validate_accelerator_and_devices()
+        self._validate_precision_type()
+
         self._warn_if_devices_flag_ignored()
 
         self.select_accelerator_type()
@@ -550,14 +553,13 @@ def _validate_precision_type(self) -> None:
         """
         Ensures that the set precision type passed by the user is valid.
         """
-        valid_types = (16, 32, 64, "bf16")
-        if self.precision not in valid_types:
+
+        if not PrecisionType.supported_type(self.precision):
             raise MisconfigurationException(
-                f"Precision {self.precision} is invalid. Allowed precision values: {valid_types}"
+                f"Precision {self.precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
             )
 
     def select_precision_plugin(self) -> PrecisionPlugin:
-        self._validate_precision_type()
         # set precision type
         self.amp_type = AMPType.from_str(self.amp_type)
diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index 977b763299f8a..0091e50b4d568 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -49,6 +49,33 @@ class AMPType(LightningEnum):
     NATIVE = "native"
 
 
+class PrecisionType(LightningEnum):
+    """Type of precision used.
+
+    >>> PrecisionType.BFLOAT == PrecisionType.from_str('bfloat16')
+    True
+    >>> # you can match the type
+    >>> PrecisionType.HALF == 16
+    True
+    >>> # which is type invariant
+    >>> PrecisionType.HALF in (16, "16")
+    True
+    """
+
+    HALF = 16
+    FLOAT = 32
+    FULL = 64
+    BFLOAT = "bfloat16"
+
+    @staticmethod
+    def supported_type(precision: Union[str, int]) -> bool:
+        return any(x == precision for x in PrecisionType)
+
+    @staticmethod
+    def supported_types() -> List[Union[str, int]]:
+        return [x.value for x in PrecisionType]
+
+
 class DistributedType(LightningEnum):
     """Define type of ditributed computing.
diff --git a/tests/utilities/test_enums.py b/tests/utilities/test_enums.py
index ec33fc74b5a3b..9aafbb5175687 100644
--- a/tests/utilities/test_enums.py
+++ b/tests/utilities/test_enums.py
@@ -1,4 +1,5 @@
 from pytorch_lightning.utilities import DeviceType
+from pytorch_lightning.utilities.enums import PrecisionType
 
 
 def test_consistency():
@@ -9,3 +10,11 @@ def test_consistency():
     # hash cannot be case invariant
     assert DeviceType.TPU not in {"TPU", "CPU"}
     assert DeviceType.TPU in {"tpu", "CPU"}
+
+
+def test_precision_supported_types():
+    assert PrecisionType.supported_types() == ["16", "32", "64", "bfloat16"]
+    assert PrecisionType.supported_type(16)
+    assert PrecisionType.supported_type("16")
+    assert not PrecisionType.supported_type(1)
+    assert not PrecisionType.supported_type("invalid")

From 61d95e280e934deec1a7b86bf370873b0e7fbb23 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Tue, 24 Aug 2021 12:10:54 +0100
Subject: [PATCH 5/7] Address review

---
 .../trainer/connectors/accelerator_connector.py | 10 +++++-----
 pytorch_lightning/utilities/enums.py            |  6 +-----
 tests/utilities/test_enums.py                   |  2 +-
 3 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 80c1b647560cc..0086f72b7b65e 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -149,7 +149,7 @@ def __init__(
         self.plugins = plugins
 
         self._validate_accelerator_and_devices()
-        self._validate_precision_type()
+        self._validate_precision_type(self.precision)
 
         self._warn_if_devices_flag_ignored()
 
@@ -549,14 +549,14 @@ def is_using_torchelastic(self) -> bool:
         )
         return TorchElasticEnvironment.is_using_torchelastic()
 
-    def _validate_precision_type(self) -> None:
+    @staticmethod
+    def _validate_precision_type(precision: Union[int, str]) -> None:
         """
         Ensures that the set precision type passed by the user is valid.
         """
-
-        if not PrecisionType.supported_type(self.precision):
+        if not PrecisionType.supported_type(precision):
             raise MisconfigurationException(
-                f"Precision {self.precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
+                f"Precision {precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
             )
 
     def select_precision_plugin(self) -> PrecisionPlugin:
diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index 0091e50b4d568..5807f50ee9c5f 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -52,12 +52,8 @@ class AMPType(LightningEnum):
 class PrecisionType(LightningEnum):
     """Type of precision used.
 
-    >>> PrecisionType.BFLOAT == PrecisionType.from_str('bfloat16')
-    True
-    >>> # you can match the type
     >>> PrecisionType.HALF == 16
     True
-    >>> # which is type invariant
     >>> PrecisionType.HALF in (16, "16")
     True
     """
@@ -65,7 +61,7 @@ class PrecisionType(LightningEnum):
     HALF = 16
     FLOAT = 32
     FULL = 64
-    BFLOAT = "bfloat16"
+    BFLOAT = "bf16"
 
     @staticmethod
     def supported_type(precision: Union[str, int]) -> bool:
diff --git a/tests/utilities/test_enums.py b/tests/utilities/test_enums.py
index 9aafbb5175687..c92ce938c7607 100644
--- a/tests/utilities/test_enums.py
+++ b/tests/utilities/test_enums.py
@@ -13,7 +13,7 @@ def test_consistency():
 
 
 def test_precision_supported_types():
-    assert PrecisionType.supported_types() == ["16", "32", "64", "bfloat16"]
+    assert PrecisionType.supported_types() == ["16", "32", "64", "bf16"]
     assert PrecisionType.supported_type(16)
     assert PrecisionType.supported_type("16")
     assert not PrecisionType.supported_type(1)

From 956892ef7ba8c0a852b3d46ac3ffe47df3793745 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Tue, 24 Aug 2021 12:21:56 +0100
Subject: [PATCH 6/7] Clean up not implemented exception

---
 pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +-
 pytorch_lightning/utilities/enums.py                          | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 0086f72b7b65e..3fefc470aee9b 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -614,7 +614,7 @@ def select_precision_plugin(self) -> PrecisionPlugin:
             log.info("Using APEX 16bit precision.")
             return ApexMixedPrecisionPlugin(self.amp_level)
 
-        raise NotImplementedError("We only support precisions 64, 32 and 16!")
+        raise NotImplementedError(f"We only support precision: {PrecisionType.supported_types()}")
 
     def select_training_type_plugin(self) -> TrainingTypePlugin:
         if (
diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index 5807f50ee9c5f..628c5e386dd4d 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -68,7 +68,7 @@ def supported_type(precision: Union[str, int]) -> bool:
         return any(x == precision for x in PrecisionType)
 
     @staticmethod
-    def supported_types() -> List[Union[str, int]]:
+    def supported_types() -> List[str]:
         return [x.value for x in PrecisionType]

From 7a9ffaedd684f2e12c12a3b4b4fe5fe0b8f8cd00 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Tue, 24 Aug 2021 14:08:06 +0100
Subject: [PATCH 7/7] Address feedback

---
 .../trainer/connectors/accelerator_connector.py | 15 +++------------
 pytorch_lightning/utilities/enums.py            |  6 +++---
 2 files changed, 6 insertions(+), 15 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 3fefc470aee9b..e0f1f6bcec43f 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -149,7 +149,6 @@ def __init__(
         self.plugins = plugins
 
         self._validate_accelerator_and_devices()
-        self._validate_precision_type(self.precision)
 
         self._warn_if_devices_flag_ignored()
 
@@ -549,16 +548,6 @@ def is_using_torchelastic(self) -> bool:
         )
         return TorchElasticEnvironment.is_using_torchelastic()
 
-    @staticmethod
-    def _validate_precision_type(precision: Union[int, str]) -> None:
-        """
-        Ensures that the set precision type passed by the user is valid.
-        """
-        if not PrecisionType.supported_type(precision):
-            raise MisconfigurationException(
-                f"Precision {precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
-            )
-
     def select_precision_plugin(self) -> PrecisionPlugin:
         # set precision type
         self.amp_type = AMPType.from_str(self.amp_type)
@@ -614,7 +603,9 @@ def select_precision_plugin(self) -> PrecisionPlugin:
             log.info("Using APEX 16bit precision.")
             return ApexMixedPrecisionPlugin(self.amp_level)
 
-        raise NotImplementedError(f"We only support precision: {PrecisionType.supported_types()}")
+        raise MisconfigurationException(
+            f"Precision {self.precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
+        )
 
     def select_training_type_plugin(self) -> TrainingTypePlugin:
         if (
diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index 628c5e386dd4d..73fafabe8f5d9 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -58,9 +58,9 @@ class PrecisionType(LightningEnum):
     True
     """
 
-    HALF = 16
-    FLOAT = 32
-    FULL = 64
+    HALF = "16"
+    FLOAT = "32"
+    FULL = "64"
     BFLOAT = "bf16"
 
     @staticmethod
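
Not part of the patch series itself: the snippet below sketches the behaviour the series converges on, assuming all seven patches are applied to a pytorch-lightning checkout from this period. The PrecisionType assertions follow the enum and tests added in patches 4-7, and the Trainer expectation mirrors test_validate_precision_type from patch 1; nothing beyond what the patches introduce is relied on.

    import pytest

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.enums import PrecisionType
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    # After patch 7 the numeric members are stored as strings, but LightningEnum
    # comparison is value-based, so PrecisionType.HALF == 16 still holds (see the doctest).
    assert PrecisionType.supported_types() == ["16", "32", "64", "bf16"]
    assert PrecisionType.supported_type(16) and PrecisionType.supported_type("bf16")
    assert not PrecisionType.supported_type(12)

    # An unsupported value surfaces as a MisconfigurationException when the Trainer
    # selects its precision plugin, as exercised by the test added in patch 1.
    with pytest.raises(MisconfigurationException, match="Precision 12 is invalid"):
        Trainer(precision=12)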