Replace `_TORCH_GREATER_EQUAL_DEV_1_10` with `_TORCH_GREATER_EQUAL_1_10` (#10157)
carmocca authored Oct 27, 2021
1 parent 808edcd commit dbe1662
Showing 10 changed files with 18 additions and 26 deletions.
4 changes: 2 additions & 2 deletions docs/source/advanced/mixed_precision.rst
@@ -50,14 +50,14 @@ BFloat16 Mixed precision is similar to FP16 mixed precision, however we maintain
 Since BFloat16 is more stable than FP16 during training, we do not need to worry about any gradient scaling or nan gradient values that comes with using FP16 mixed precision.

 .. testcode::
-    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10 or not torch.cuda.is_available()
+    :skipif: not _TORCH_GREATER_EQUAL_1_10 or not torch.cuda.is_available()

     Trainer(gpus=1, precision="bf16")

 It is also possible to use BFloat16 mixed precision on the CPU, relying on MKLDNN under the hood.

 .. testcode::
-    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10
+    :skipif: not _TORCH_GREATER_EQUAL_1_10

     Trainer(precision="bf16")
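Note: the `Trainer(precision="bf16")` snippets above are only fragments. A minimal, self-contained sketch of the CPU bf16 path they describe (assuming torch >= 1.10; the toy module and data below are illustrative and not part of this commit) might look like:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class TinyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            (x,) = batch
            # Under bf16 mixed precision, autocast-eligible ops here run in torch.bfloat16.
            return self.layer(x).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    trainer = pl.Trainer(precision="bf16", max_epochs=1)  # CPU bf16, backed by MKLDNN
    trainer.fit(TinyModel(), data)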
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -373,7 +373,7 @@ def package_list_from_file(file):
     _XLA_AVAILABLE,
     _TPU_AVAILABLE,
     _TORCHVISION_AVAILABLE,
-    _TORCH_GREATER_EQUAL_DEV_1_10,
+    _TORCH_GREATER_EQUAL_1_10,
     _module_available,
 )
 _JSONARGPARSE_AVAILABLE = _module_available("jsonargparse")
6 changes: 3 additions & 3 deletions pytorch_lightning/callbacks/quantization.py
@@ -33,10 +33,10 @@

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

-if _TORCH_GREATER_EQUAL_DEV_1_10:
+if _TORCH_GREATER_EQUAL_1_10:
     from torch.ao.quantization.qconfig import QConfig
 else:
     from torch.quantization import QConfig
@@ -245,7 +245,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
                # version=None corresponds to using FakeQuantize rather than
                # FusedMovingAvgObsFakeQuantize which was introduced in PT1.10
                # details in https://github.com/pytorch/pytorch/issues/64564
-               extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_DEV_1_10 else {}
+               extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_1_10 else {}
                pl_module.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)

        elif isinstance(self._qconfig, QConfig):
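For context, a hedged standalone sketch of the pattern this file exercises: import `QConfig` from its new `torch.ao` location on torch >= 1.10, and only pass `version=None` (which keeps the older `FakeQuantize` observers) when that argument exists. The `"fbgemm"` backend string and the local version check are illustrative stand-ins for the callback's own configuration and flag.

    import torch
    from packaging.version import Version

    _torch_1_10 = Version(torch.__version__) >= Version("1.10.0")  # stand-in for _TORCH_GREATER_EQUAL_1_10

    if _torch_1_10:
        from torch.ao.quantization.qconfig import QConfig  # namespace moved in 1.10
    else:
        from torch.quantization import QConfig

    # version=None selects FakeQuantize rather than FusedMovingAvgObsFakeQuantize (new in 1.10);
    # older torch has no `version` argument, so pass nothing there.
    extra_kwargs = dict(version=None) if _torch_1_10 else {}
    qconfig = torch.quantization.get_default_qat_qconfig("fbgemm", **extra_kwargs)
    print(type(qconfig).__name__)  # QConfig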
4 changes: 2 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -39,7 +39,7 @@
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.utilities import (
     _IS_WINDOWS,
-    _TORCH_GREATER_EQUAL_DEV_1_10,
+    _TORCH_GREATER_EQUAL_1_10,
     GradClipAlgorithmType,
     rank_zero_deprecation,
     rank_zero_warn,
@@ -2043,7 +2043,7 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
         These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly.
         """
-        if not _TORCH_GREATER_EQUAL_DEV_1_10 or _IS_WINDOWS:
+        if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS:
             return

         from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
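The guard above is easier to read in isolation. A sketch of how it evaluates, with illustrative stand-ins for `_TORCH_GREATER_EQUAL_1_10` and `_IS_WINDOWS` (their real definitions live in `pytorch_lightning.utilities`):

    import platform

    import torch
    from packaging.version import Version

    _TORCH_GREATER_EQUAL_1_10 = Version(torch.__version__) >= Version("1.10.0")
    _IS_WINDOWS = platform.system() == "Windows"

    def _register_sharded_tensor_state_dict_hooks_if_available() -> None:
        if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS:
            return  # ShardedTensor hooks need torch >= 1.10 and are unsupported on Windows
        # private location at the time of this commit; later torch releases moved it
        from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook  # noqa: F401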
8 changes: 4 additions & 4 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -21,10 +21,10 @@

 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10, AMPType
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

-if _TORCH_GREATER_EQUAL_DEV_1_10:
+if _TORCH_GREATER_EQUAL_1_10:
     from torch import autocast
 else:
     from torch.cuda.amp import autocast
@@ -47,7 +47,7 @@ def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> None:

     def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
         if precision == "bf16":
-            if not _TORCH_GREATER_EQUAL_DEV_1_10:
+            if not _TORCH_GREATER_EQUAL_1_10:
                 raise MisconfigurationException(
                     "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
                 )
@@ -97,7 +97,7 @@ def optimizer_step(
         self.scaler.update()

     def autocast_context_manager(self) -> autocast:
-        if _TORCH_GREATER_EQUAL_DEV_1_10:
+        if _TORCH_GREATER_EQUAL_1_10:
             return autocast("cpu" if self.use_cpu else "cuda", dtype=self._dtype)
         return autocast()
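On torch >= 1.10 the plugin's `autocast_context_manager` resolves to the device-agnostic `torch.autocast`. A hedged illustration of what that context does on CPU with bf16 (toy tensors, not the plugin itself):

    import torch

    with torch.autocast("cpu", dtype=torch.bfloat16):
        x = torch.randn(4, 8)
        w = torch.randn(8, 2)
        y = x @ w  # matmul is autocast-eligible, so it runs in bfloat16
    print(y.dtype)  # torch.bfloat16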
1 change: 0 additions & 1 deletion pytorch_lightning/utilities/__init__.py
@@ -48,7 +48,6 @@
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
     _TORCH_GREATER_EQUAL_1_10,
-    _TORCH_GREATER_EQUAL_DEV_1_10,
     _TORCH_QUANTIZE_AVAILABLE,
     _TORCHTEXT_AVAILABLE,
     _TORCHVISION_AVAILABLE,
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/imports.py
@@ -75,7 +75,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
 _TORCH_GREATER_EQUAL_1_8_1 = _compare_version("torch", operator.ge, "1.8.1")
 _TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
 _TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0")
-_TORCH_GREATER_EQUAL_DEV_1_10 = _compare_version("torch", operator.ge, "1.10.0", use_base_version=True)
+# _TORCH_GREATER_EQUAL_DEV_1_11 = _compare_version("torch", operator.ge, "1.11.0", use_base_version=True)

 _APEX_AVAILABLE = _module_available("apex.amp")
 _DEEPSPEED_AVAILABLE = _module_available("deepspeed")
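The dropped `use_base_version=True` flag is the substance of the rename: it compared only the base version, so 1.10 nightlies (`1.10.0.devYYYYMMDD`) counted as `>= 1.10.0`, while the plain comparison kept in `_TORCH_GREATER_EQUAL_1_10` requires an actual 1.10 release. A small sketch of the difference using the `packaging` library (presumably the same semantics `_compare_version` implements):

    from packaging.version import Version

    nightly = Version("1.10.0.dev20210922")  # illustrative nightly build string

    # Plain comparison (what _TORCH_GREATER_EQUAL_1_10 expresses): nightlies do NOT qualify.
    print(nightly >= Version("1.10.0"))                        # False

    # Base-version comparison (what the removed DEV flag used): nightlies DID qualify.
    print(Version(nightly.base_version) >= Version("1.10.0"))  # True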
6 changes: 1 addition & 5 deletions tests/core/test_lightning_module.py
@@ -21,7 +21,6 @@

 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_DEV_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -312,10 +311,7 @@ def __init__(self, spec):
         self.sharded_tensor.local_shards()[0].tensor.fill_(0)


-@pytest.mark.skipif(
-    not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Test requires the torch version to support `ShardedTensor`"
-)
-@pytest.mark.skipif(_IS_WINDOWS, reason="Not supported on Windows")
+@RunIf(min_torch="1.10", skip_windows=True)
 def test_sharded_tensor_state_dict(tmpdir, single_process_pg):
     spec = dist._sharding_spec.ChunkShardingSpec(
         dim=0,
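`RunIf(min_torch="1.10", skip_windows=True)` bundles the two skip conditions it replaces. Roughly (and only roughly; the real helper lives in `tests/helpers/runif.py`), it expands to something like:

    import sys

    import pytest
    import torch
    from packaging.version import Version

    requires_torch_1_10 = pytest.mark.skipif(
        Version(torch.__version__) < Version("1.10.0"), reason="requires torch>=1.10"
    )
    skip_on_windows = pytest.mark.skipif(sys.platform == "win32", reason="not supported on Windows")

    @requires_torch_1_10
    @skip_on_windows
    def test_something():
        assert True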
6 changes: 2 additions & 4 deletions tests/models/test_amp.py
@@ -22,7 +22,6 @@

 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins.environments import SLURMEnvironment
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_DEV_1_10
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf

@@ -68,7 +67,7 @@ def _assert_autocast_enabled(self):
         assert torch.is_autocast_enabled()


-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Needs bfloat16 support")
+@RunIf(min_torch="1.10")
 @pytest.mark.parametrize(
     "strategy",
     [
@@ -95,8 +94,7 @@ def test_amp_cpus(tmpdir, strategy, precision, num_processes):
     assert trainer.state.finished, f"Training failed with {trainer.state}"


-@RunIf(min_gpus=2)
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Needs bfloat16 support")
+@RunIf(min_gpus=2, min_torch="1.10")
 @pytest.mark.parametrize("strategy", [None, "dp", "ddp_spawn"])
 @pytest.mark.parametrize("precision", [16, "bf16"])
 @pytest.mark.parametrize("gpus", [1, 2])
5 changes: 2 additions & 3 deletions tests/plugins/test_amp_plugins.py
@@ -21,7 +21,6 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -178,7 +177,7 @@ def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
     trainer.fit(model)


-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Torch CPU AMP is not available.")
+@RunIf(min_torch="1.10")
 def test_cpu_amp_precision_context_manager(tmpdir):
     """Test to ensure that the context manager correctly is set to CPU + bfloat16, and a scaler isn't set."""
     plugin = NativeMixedPrecisionPlugin(precision="bf16", use_cpu=True)
@@ -197,7 +196,7 @@ def test_precision_selection_raises(monkeypatch):

     import pytorch_lightning.plugins.precision.native_amp as amp

-    monkeypatch.setattr(amp, "_TORCH_GREATER_EQUAL_DEV_1_10", False)
+    monkeypatch.setattr(amp, "_TORCH_GREATER_EQUAL_1_10", False)
     with pytest.warns(
         UserWarning, match=r"precision=16\)` but native AMP is not supported on CPU. Using `precision='bf16"
     ), pytest.raises(MisconfigurationException, match="must install torch greater or equal to 1.10"):
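The monkeypatch target had to be renamed together with the flag because the test patches the name in the namespace where it is looked up (the `native_amp` module), not where it is defined. A minimal illustration of that pattern, with a hypothetical test name:

    import pytorch_lightning.plugins.precision.native_amp as amp

    def test_pretend_old_torch(monkeypatch):
        # Patch the copy that native_amp imported; patching utilities.imports would not affect it.
        monkeypatch.setattr(amp, "_TORCH_GREATER_EQUAL_1_10", False)
        assert not amp._TORCH_GREATER_EQUAL_1_10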
