Add validate logic for precision #9080

Merged: 8 commits, Aug 24, 2021

Changes from 6 commits

3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929))


- Added validate logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080))
Review comment (Contributor) on the new changelog line:

Suggested change:
current:   - Added validate logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080))
suggested: - Added input validation for precision Trainer argument ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080))

this should be less ambiguous



### Changed

- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
15 changes: 14 additions & 1 deletion pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -73,6 +73,7 @@
    rank_zero_info,
    rank_zero_warn,
)
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _HOROVOD_AVAILABLE:
@@ -148,6 +149,8 @@ def __init__(
        self.plugins = plugins

        self._validate_accelerator_and_devices()
        self._validate_precision_type(self.precision)

        self._warn_if_devices_flag_ignored()

        self.select_accelerator_type()
@@ -546,6 +549,16 @@ def is_using_torchelastic(self) -> bool:
        )
        return TorchElasticEnvironment.is_using_torchelastic()

    @staticmethod
    def _validate_precision_type(precision: Union[int, str]) -> None:
        """Ensure the precision type passed by the user is valid."""
        if not PrecisionType.supported_type(precision):
            raise MisconfigurationException(
                f"Precision {precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
            )

    def select_precision_plugin(self) -> PrecisionPlugin:
        # set precision type
        self.amp_type = AMPType.from_str(self.amp_type)
@@ -601,7 +614,7 @@ def select_precision_plugin(self) -> PrecisionPlugin:
                log.info("Using APEX 16bit precision.")
                return ApexMixedPrecisionPlugin(self.amp_level)

-        raise NotImplementedError("We only support precisions 64, 32 and 16!")
+        raise NotImplementedError(f"We only support precision: {PrecisionType.supported_types()}")

    def select_training_type_plugin(self) -> TrainingTypePlugin:
        if (
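For orientation, here is a minimal sketch (not part of the diff) of how the new check surfaces to users. It assumes a default CPU-only Trainer and that the message matches the f-string in `_validate_precision_type` above:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# An unsupported precision value now fails fast while the Trainer is being
# constructed, instead of falling through to the NotImplementedError raised
# later in select_precision_plugin().
try:
    Trainer(precision=12)
except MisconfigurationException as err:
    # Expected to read something like:
    # "Precision 12 is invalid. Allowed precision values: ['16', '32', '64', 'bf16']"
    print(err)
```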
23 changes: 23 additions & 0 deletions pytorch_lightning/utilities/enums.py
@@ -49,6 +49,29 @@ class AMPType(LightningEnum):
    NATIVE = "native"


class PrecisionType(LightningEnum):
    """Type of precision used.

    >>> PrecisionType.HALF == 16
    True
    >>> PrecisionType.HALF in (16, "16")
    True
    """

    HALF = 16
    FLOAT = 32
    FULL = 64
    BFLOAT = "bf16"

    @staticmethod
    def supported_type(precision: Union[str, int]) -> bool:
        return any(x == precision for x in PrecisionType)

    @staticmethod
    def supported_types() -> List[str]:
        return [x.value for x in PrecisionType]


class DistributedType(LightningEnum):
"""Define type of ditributed computing.

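Side note (illustrative, not part of the diff): `LightningEnum` mixes in `str`, so integer members such as `HALF = 16` are stored as their string form, which is why `supported_types()` returns `["16", "32", "64", "bf16"]`. A simplified stand-in enum shows the coercion:

```python
from enum import Enum

class StrBackedPrecision(str, Enum):
    # With a str mixin, integer assignments are coerced to strings when the
    # class is created, so the member values are "16", "32", "64" and "bf16".
    HALF = 16
    FLOAT = 32
    FULL = 64
    BFLOAT = "bf16"

assert StrBackedPrecision.HALF.value == "16"
assert [p.value for p in StrBackedPrecision] == ["16", "32", "64", "bf16"]
```

The doctest `PrecisionType.HALF == 16` still passes because `LightningEnum` additionally defines case-insensitive equality against plain ints and strings, which is what the `x == precision` check in `supported_type` relies on.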
7 changes: 7 additions & 0 deletions tests/accelerators/test_accelerator_connector.py
@@ -629,3 +629,10 @@ def test_accelerator_ddp_for_cpu(tmpdir):
trainer = Trainer(accelerator="ddp", num_processes=2)
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)


@pytest.mark.parametrize("precision", [1, 12, "invalid"])
def test_validate_precision_type(tmpdir, precision):

with pytest.raises(MisconfigurationException, match=f"Precision {precision} is invalid"):
Trainer(precision=precision)
9 changes: 9 additions & 0 deletions tests/utilities/test_enums.py
@@ -1,4 +1,5 @@
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities.enums import PrecisionType


def test_consistency():
@@ -9,3 +10,11 @@ def test_consistency():
    # hash cannot be case invariant
    assert DeviceType.TPU not in {"TPU", "CPU"}
    assert DeviceType.TPU in {"tpu", "CPU"}


def test_precision_supported_types():
    assert PrecisionType.supported_types() == ["16", "32", "64", "bf16"]
    assert PrecisionType.supported_type(16)
    assert PrecisionType.supported_type("16")
    assert not PrecisionType.supported_type(1)
    assert not PrecisionType.supported_type("invalid")