Commit

Fix torch bfloat import version (#9089)
Sean Naren authored Aug 24, 2021
1 parent f959b13 commit 1bab0a1
Showing 4 changed files with 11 additions and 5 deletions.
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/precision/native_amp.py

@@ -19,7 +19,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _TORCH_GREATER_EQUAL_1_10, AMPType
+from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _TORCH_BFLOAT_AVAILABLE, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
@@ -46,7 +46,7 @@ def __init__(self, precision: Union[int, str] = 16) -> None:
 
     def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
         if precision == "bf16":
-            if not _TORCH_GREATER_EQUAL_1_10:
+            if not _TORCH_BFLOAT_AVAILABLE:
                 raise MisconfigurationException(
                     "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
                 )
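For context, a minimal sketch of the full method this hunk edits (only the guard is confirmed by the diff; the bfloat16/float16 return values are an assumption about the surrounding code, suggested by the torch.dtype return type):

# Sketch, not repo code: assumes the method returns torch.bfloat16 for
# "bf16" precision and torch.float16 otherwise.
def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
    if precision == "bf16":
        if not _TORCH_BFLOAT_AVAILABLE:
            raise MisconfigurationException(
                "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
            )
        return torch.bfloat16
    return torch.float16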
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py

@@ -44,6 +44,7 @@
     _OMEGACONF_AVAILABLE,
     _POPTORCH_AVAILABLE,
     _RICH_AVAILABLE,
+    _TORCH_BFLOAT_AVAILABLE,
     _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
7 changes: 6 additions & 1 deletion pytorch_lightning/utilities/imports.py

@@ -87,8 +87,13 @@ def _compare_version(package: str, op, version) -> bool:
 _OMEGACONF_AVAILABLE = _module_available("omegaconf")
 _POPTORCH_AVAILABLE = _module_available("poptorch")
 _RICH_AVAILABLE = _module_available("rich")
+_TORCH_BFLOAT_AVAILABLE = _compare_version(
+    "torch", operator.ge, "1.10.0.dev20210820"
+)  # todo: swap to 1.10.0 once released
 _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
-_TORCH_SHARDED_TENSOR_AVAILABLE = _compare_version("torch", operator.ge, "1.10.0.dev20210809")
+_TORCH_SHARDED_TENSOR_AVAILABLE = _compare_version(
+    "torch", operator.ge, "1.10.0.dev20210809"
+)  # todo: swap to 1.10.0 once released
 _TORCHTEXT_AVAILABLE = _module_available("torchtext")
 _TORCHVISION_AVAILABLE = _module_available("torchvision")
 _TORCHMETRICS_LOWER_THAN_0_3 = _compare_version("torchmetrics", operator.lt, "0.3.0")
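Pinning the guard to a dated dev build, rather than a plain >= 1.10.0 check, is what lets the flag turn on for nightlies ahead of the stable release. A hedged illustration of the version ordering this relies on (it assumes _compare_version follows PEP 440 semantics via the packaging library; the helper's body is not part of this diff):

# Illustration only: PEP 440 ordering as implemented by packaging.
import operator
from packaging.version import Version

pin = Version("1.10.0.dev20210820")

# A nightly dated at or after the pin satisfies the guard:
assert operator.ge(Version("1.10.0.dev20210821"), pin)
# The final 1.10.0 release sorts above all of its dev builds:
assert operator.ge(Version("1.10.0"), pin)
# A stable 1.9.0 install does not, so the flag stays False there:
assert not operator.ge(Version("1.9.0"), pin)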
4 changes: 2 additions & 2 deletions tests/models/test_amp.py

@@ -22,7 +22,7 @@
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins.environments import SLURMEnvironment
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
+from pytorch_lightning.utilities import _TORCH_BFLOAT_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf

@@ -71,7 +71,7 @@ def predict(self, batch, batch_idx, dataloader_idx=None):
         16,
         pytest.param(
             "bf16",
-            marks=pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_10, reason="torch.bfloat16 not available"),
+            marks=pytest.mark.skipif(not _TORCH_BFLOAT_AVAILABLE, reason="torch.bfloat16 not available"),
         ),
     ],
 )
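The test change uses pytest's standard pattern for skipping a single parametrized case. A self-contained sketch of that pattern (the flag and test function here are illustrative stand-ins, not repo code):

import pytest

_BFLOAT_AVAILABLE = False  # stand-in for _TORCH_BFLOAT_AVAILABLE

@pytest.mark.parametrize(
    "precision",
    [
        16,
        pytest.param(
            "bf16",
            marks=pytest.mark.skipif(not _BFLOAT_AVAILABLE, reason="torch.bfloat16 not available"),
        ),
    ],
)
def test_precision_values(precision):
    # The 16 case always runs; "bf16" is skipped while the flag is False.
    assert precision in (16, "bf16")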
