Rename TPUHalfPrecisionPlugin to TPUBf16PrecisionPlugin (#10026)
Co-authored-by: Adrian Wälchli <[email protected]>
carmocca and awaelchli authored Oct 19, 2021
1 parent 0b68f2a commit d45897d
Showing 8 changed files with 35 additions and 22 deletions.
9 changes: 6 additions & 3 deletions CHANGELOG.md
@@ -116,9 +116,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `DataLoaderIterDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020))


- Added bfloat16 support for Lightning Trainer ([#9049](https://github.com/PyTorchLightning/pytorch-lightning/pull/9049))


- Added `DataFetcher` within `Fit / Evaluation` Loop ([#9047](https://github.com/PyTorchLightning/pytorch-lightning/pull/9047))


@@ -199,6 +196,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `TPUPrecisionPlugin` ([#10020](https://github.com/PyTorchLightning/pytorch-lightning/pull/#10020))


- `torch.bfloat16` support:
* Added bfloat16 support for Lightning Trainer ([#9049](https://github.com/PyTorchLightning/pytorch-lightning/pull/9049))
* Renamed `TPUHalfPrecisionPlugin` to `TPUBf16PrecisionPlugin` ([#10026](https://github.com/PyTorchLightning/pytorch-lightning/pull/10026))



- Added `kfold` example for loop customization ([#9965](https://github.com/PyTorchLightning/pytorch-lightning/pull/9965))


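For readers skimming the changelog entries above, here is a minimal usage sketch of the bf16 option. It is hedged: the argument spelling mirrors the tests touched by this commit, and an actual TPU/XLA environment is required for the run to proceed.

```python
from pytorch_lightning import Trainer

# Hypothetical example: request bfloat16 precision on a TPU run.
# The accelerator connector is expected to select TPUBf16PrecisionPlugin internally.
trainer = Trainer(accelerator="tpu", precision="bf16")
```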
2 changes: 1 addition & 1 deletion docs/source/api_references.rst
@@ -176,7 +176,7 @@ Precision Plugins
ApexMixedPrecisionPlugin
DeepSpeedPrecisionPlugin
TPUPrecisionPlugin
TPUHalfPrecisionPlugin
TPUBf16PrecisionPlugin
DoublePrecisionPlugin
FullyShardedNativeMixedPrecisionPlugin
IPUPrecisionPlugin
2 changes: 1 addition & 1 deletion docs/source/extensions/plugins.rst
@@ -137,7 +137,7 @@ Precision Plugins
ApexMixedPrecisionPlugin
DeepSpeedPrecisionPlugin
TPUPrecisionPlugin
TPUHalfPrecisionPlugin
TPUBf16PrecisionPlugin
DoublePrecisionPlugin
FullyShardedNativeMixedPrecisionPlugin
IPUPrecisionPlugin
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/__init__.py
@@ -17,7 +17,7 @@
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin
from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin
from pytorch_lightning.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
@@ -59,7 +59,7 @@
"SingleDevicePlugin",
"SingleTPUPlugin",
"TPUPrecisionPlugin",
"TPUHalfPrecisionPlugin",
"TPUBf16PrecisionPlugin",
"TPUSpawnPlugin",
"TrainingTypePlugin",
"ParallelPlugin",
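Because the public export changes here, downstream imports need the new name. A hedged before/after sketch for user code, assuming no deprecation alias is kept (consistent with the straight rename in this diff):

```python
# Before this commit (old public name, no longer exported):
# from pytorch_lightning.plugins import TPUHalfPrecisionPlugin

# After this commit:
from pytorch_lightning.plugins import TPUBf16PrecisionPlugin

plugin = TPUBf16PrecisionPlugin()
print(plugin.precision)  # "bf16" (previously the integer 16)
```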
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/__init__.py
@@ -10,4 +10,4 @@
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin # noqa: F401
2 changes: 2 additions & 2 deletions pytorch_lightning/plugins/precision/{tpu_bfloat.py → tpu_bf16.py}
@@ -20,10 +20,10 @@
from pytorch_lightning.plugins.precision import TPUPrecisionPlugin


class TPUHalfPrecisionPlugin(TPUPrecisionPlugin):
class TPUBf16PrecisionPlugin(TPUPrecisionPlugin):
"""Plugin that enables bfloats on TPUs."""

precision: int = 16
precision: str = "bf16"

def connect(
self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
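The hunk above cuts off before the body of `connect`. Purely as an illustration (an assumption, not part of this diff), a bf16 TPU precision plugin along these lines would typically flip XLA's bfloat16 switch before delegating to its parent plugin:

```python
import os
from typing import Any, List, Tuple

from torch import nn
from torch.optim import Optimizer

from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin


class TPUBf16PrecisionPlugin(TPUPrecisionPlugin):
    """Plugin that enables bfloat16 on TPUs."""

    precision: str = "bf16"

    def connect(
        self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
    ) -> Tuple[nn.Module, List[Optimizer], List[Any]]:
        # Assumption: torch_xla honours the XLA_USE_BF16 environment variable,
        # mapping float32 computation to bfloat16 on TPU cores.
        os.environ["XLA_USE_BF16"] = "1"
        return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
```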
22 changes: 14 additions & 8 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -46,7 +46,7 @@
ShardedNativeMixedPrecisionPlugin,
SingleDevicePlugin,
SingleTPUPlugin,
TPUHalfPrecisionPlugin,
TPUBf16PrecisionPlugin,
TPUPrecisionPlugin,
TPUSpawnPlugin,
TrainingTypePlugin,
@@ -591,6 +591,12 @@ def select_precision_plugin(self) -> PrecisionPlugin:
# set precision type
self.amp_type = AMPType.from_str(self.amp_type)

# validation for all plugins
if self.amp_level is not None and self.amp_type != AMPType.APEX:
raise MisconfigurationException(
f"You have asked for `amp_level={self.amp_level!r}` but it's only supported with `amp_backend='apex'`."
)

if self.use_ipu:
return IPUPrecisionPlugin(self.precision)
if self.use_tpu:
@@ -603,7 +609,13 @@ def select_precision_plugin(self) -> PrecisionPlugin:
" requesting this feature."
)
elif self.precision in (16, "bf16"):
return TPUHalfPrecisionPlugin()
if self.precision == 16:
# this is not deprecated to ease transition between accelerator environments
rank_zero_warn(
f"You passed `Trainer(accelerator='tpu', precision=16)` but {self.amp_type.value} AMP"
f" is not supported with TPUs. Using `precision='bf16'` instead."
)
return TPUBf16PrecisionPlugin()

if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin):
return DeepSpeedPrecisionPlugin(self.precision)
@@ -614,12 +626,6 @@ def select_precision_plugin(self) -> PrecisionPlugin:
return DoublePrecisionPlugin()
if self.precision in (16, "bf16"):
if self.amp_type == AMPType.NATIVE:
if self.amp_level is not None:
raise MisconfigurationException(
f"You have asked for `amp_level={repr(self.amp_level)}` which is not supported"
" with `amp_backend='native'`."
)

log.info(f"Using native {self.precision} bit Automatic Mixed Precision")
if self._is_sharded_training_type:
return ShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
12 changes: 8 additions & 4 deletions tests/accelerators/test_accelerator_connector.py
@@ -708,10 +708,9 @@ def test_validate_precision_type(tmpdir, precision):
Trainer(precision=precision)


@RunIf(min_gpus=1, amp_native=True)
def test_amp_level_raises_error_with_native(tmpdir):
with pytest.raises(MisconfigurationException, match="not supported with `amp_backend='native'`"):
_ = Trainer(default_root_dir=tmpdir, gpus=1, amp_level="O2", amp_backend="native", precision=16)
def test_amp_level_raises_error_with_native():
with pytest.raises(MisconfigurationException, match="O2'` but it's only supported with `amp_backend='apex'`"):
_ = Trainer(amp_level="O2", amp_backend="native", precision=16)


def test_strategy_choice_ddp_spawn_cpu(tmpdir):
Expand Down Expand Up @@ -986,3 +985,8 @@ def test_unsupported_tpu_choice(monkeypatch):
monkeypatch.setattr(AcceleratorConnector, "has_tpu", True)
with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"):
Trainer(accelerator="tpu", precision=64)

with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but native AMP is not supported"):
Trainer(accelerator="tpu", precision=16)
with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but apex AMP is not supported"):
Trainer(accelerator="tpu", precision=16, amp_backend="apex")
