From bf8c1fd76624fb6c3cb8ad0336244908b8c9cde1 Mon Sep 17 00:00:00 2001 From: Burhanuddin Rangwala Date: Wed, 25 Aug 2021 15:11:18 +0530 Subject: [PATCH 01/19] Add doc strings to tensorboard logger class (#9093) --- pytorch_lightning/loggers/tensorboard.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index e51e6b1f6a9d22..a9e1118b9d647a 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -134,10 +134,22 @@ def log_dir(self) -> str: @property def save_dir(self) -> Optional[str]: + """ + Gets the save directory where the TensorBoard experiments are saved. + + Returns: + The local path to the save directory where the TensorBoard experiments are saved. + """ return self._save_dir @property def sub_dir(self) -> Optional[str]: + """ + Gets the sub directory where the TensorBoard experiments are saved. + + Returns: + The local path to the sub directory where the TensorBoard experiments are saved. + """ return self._sub_dir @property @@ -258,10 +270,22 @@ def finalize(self, status: str) -> None: @property def name(self) -> str: + """ + Get the name of the experiment. + + Returns: + The name of the experiment. + """ return self._name @property def version(self) -> int: + """ + Get the experiment version. + + Returns: + The experiment version if specified else the next version. + """ if self._version is None: self._version = self._get_next_version() return self._version From bac8b1be81eeede3d12190ab0d820597df77fc3b Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Wed, 25 Aug 2021 13:18:00 +0100 Subject: [PATCH 02/19] Add support for CPU AMP autocast (#9084) --- CHANGELOG.md | 3 ++ pytorch_lightning/accelerators/cpu.py | 10 +--- .../plugins/precision/native_amp.py | 23 +++++++- .../plugins/precision/sharded_native_amp.py | 7 +-- .../connectors/accelerator_connector.py | 10 ++-- pytorch_lightning/utilities/__init__.py | 1 + pytorch_lightning/utilities/imports.py | 4 +- tests/accelerators/test_cpu.py | 13 ----- tests/models/test_amp.py | 53 +++++++++++++++++-- tests/plugins/test_amp_plugins.py | 48 +++++++++++++++++ 10 files changed, 134 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d7e99d7163fb2..a221971a1b56d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -90,6 +90,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added validate logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080)) +- Add support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084)) + + ### Changed - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770)) diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 48957219a1ec0e..46e74193fb557a 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -13,7 +13,6 @@ # limitations under the License. import pytorch_lightning as pl from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -24,15 +23,8 @@ def setup(self, trainer: "pl.Trainer") -> None: """ Raises: MisconfigurationException: - If AMP is used with CPU, or if the selected device is not CPU. + If the selected device is not CPU. """ - if isinstance(self.precision_plugin, MixedPrecisionPlugin): - raise MisconfigurationException( - " Mixed precision is currenty only supported with the AMP backend" - " and AMP + CPU is not supported. Please use a GPU option or" - " change precision setting." - ) - if "cpu" not in str(self.root_device): raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead.") diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 4c34bed9720678..c32f926236f567 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -19,7 +19,12 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin -from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _TORCH_BFLOAT_AVAILABLE, AMPType +from pytorch_lightning.utilities import ( + _NATIVE_AMP_AVAILABLE, + _TORCH_BFLOAT_AVAILABLE, + _TORCH_CPU_AMP_AVAILABLE, + AMPType, +) from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -31,7 +36,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin): precision: Whether to use torch.float16 (16) or torch.bfloat16 (bf16). """ - def __init__(self, precision: Union[int, str] = 16) -> None: + def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> None: super().__init__() if not _NATIVE_AMP_AVAILABLE: @@ -39,6 +44,14 @@ def __init__(self, precision: Union[int, str] = 16) -> None: "You have asked for native AMP but your PyTorch version does not support it." " Consider upgrading with `pip install torch>=1.6`." ) + + if use_cpu and not _TORCH_CPU_AMP_AVAILABLE: + raise MisconfigurationException( + "You have asked for native AMP on CPU, but AMP is only available on GPU for PyTorch 1.9 " + "and lower. To use native AMP on CPU, install PyTorch 1.10 or later." + ) + + self.use_cpu = use_cpu self._fast_dtype = self._select_precision_dtype(precision) self.backend = AMPType.NATIVE if not self.is_bfloat16: @@ -51,6 +64,10 @@ def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtyp "To use bfloat16 with native amp you must install torch greater or equal to 1.10." ) return torch.bfloat16 + elif self.use_cpu: + raise MisconfigurationException( + "CPU native amp only supports bfloat16. Please pass precision='bf16' to the Trainer." + ) return torch.float16 @property @@ -91,6 +108,8 @@ def pre_optimizer_step( return False def autocast_context_manager(self) -> torch.cuda.amp.autocast: + if self.use_cpu: + return torch.cpu.amp.autocast(fast_dtype=self._fast_dtype) if self.is_bfloat16: return torch.cuda.amp.autocast(fast_dtype=self._fast_dtype) return torch.cuda.amp.autocast() diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py index 861e5e1363dd26..a1eb23e478132d 100644 --- a/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -24,9 +24,10 @@ class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): """Mixed Precision for Sharded Training""" - def __init__(self, precision: Union[int, str] = 16) -> None: - super().__init__(precision) - self.scaler = ShardedGradScaler() + def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> None: + super().__init__(precision, use_cpu=use_cpu) + if not self.use_cpu: + self.scaler = ShardedGradScaler() def clip_grad_by_norm( self, optimizer: "OSS", clip_val: Union[int, float], norm_type: float = 2.0, eps: float = 1e-6 diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 40ba9de2a96fd7..85a2c7ee87d131 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -567,10 +567,6 @@ def select_precision_plugin(self) -> PrecisionPlugin: return TPUHalfPrecisionPlugin() if self.amp_type == AMPType.NATIVE: - if self.use_cpu: - raise MisconfigurationException( - "You have asked for native AMP on CPU, but AMP is only available on GPU." - ) if not _NATIVE_AMP_AVAILABLE: msg = ( "You have asked for native AMP but your PyTorch version does not support it." @@ -585,10 +581,10 @@ def select_precision_plugin(self) -> PrecisionPlugin: else: log.info(f"Using native {self.precision} bit Automatic Mixed Precision") if self._is_sharded_training_type: - return ShardedNativeMixedPrecisionPlugin(self.precision) + return ShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu) if self._is_fully_sharded_training_type: - return FullyShardedNativeMixedPrecisionPlugin(self.precision) - return NativeMixedPrecisionPlugin(self.precision) + return FullyShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu) + return NativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu) if self.amp_type == AMPType.APEX: if not _APEX_AVAILABLE: diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 56e3e03910de9a..ed9bb930001d87 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -45,6 +45,7 @@ _POPTORCH_AVAILABLE, _RICH_AVAILABLE, _TORCH_BFLOAT_AVAILABLE, + _TORCH_CPU_AMP_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index dd31d23f6cf84d..ff0e5bea6bf19e 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -70,7 +70,6 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0") _TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0") - _APEX_AVAILABLE = _module_available("apex.amp") _BOLTS_AVAILABLE = _module_available("pl_bolts") _DEEPSPEED_AVAILABLE = _module_available("deepspeed") @@ -87,6 +86,9 @@ def _compare_version(package: str, op, version) -> bool: _OMEGACONF_AVAILABLE = _module_available("omegaconf") _POPTORCH_AVAILABLE = _module_available("poptorch") _RICH_AVAILABLE = _module_available("rich") +_TORCH_CPU_AMP_AVAILABLE = _compare_version( + "torch", operator.ge, "1.10.0dev20210501" +) # todo: swap to 1.10.0 once released _TORCH_BFLOAT_AVAILABLE = _compare_version( "torch", operator.ge, "1.10.0.dev20210820" ) # todo: swap to 1.10.0 once released diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 7280afea02f76d..d695e4c63f43e6 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -1,7 +1,6 @@ import os from pathlib import Path from typing import Any, Dict, Union -from unittest.mock import Mock import pytest import torch @@ -10,22 +9,10 @@ from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.plugins import SingleDevicePlugin from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin -from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel -def test_unsupported_precision_plugins(): - """Test error messages are raised for unsupported precision plugins with CPU.""" - trainer = Mock() - accelerator = CPUAccelerator( - training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin() - ) - with pytest.raises(MisconfigurationException, match=r"AMP \+ CPU is not supported"): - accelerator.setup(trainer=trainer) - - @pytest.mark.parametrize("delay_dispatch", [True, False]) def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch): """ diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index 173d1f8a0d1f8e..5fdc180613e676 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -22,7 +22,7 @@ import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.plugins.environments import SLURMEnvironment -from pytorch_lightning.utilities import _TORCH_BFLOAT_AVAILABLE +from pytorch_lightning.utilities import _TORCH_BFLOAT_AVAILABLE, _TORCH_CPU_AMP_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -30,13 +30,19 @@ class AMPTestModel(BoringModel): def _step(self, batch, batch_idx): - assert torch.is_autocast_enabled() + self._assert_autocast_enabled() output = self(batch) bfloat16 = self.trainer.precision_plugin.is_bfloat16 assert output.dtype == torch.float16 if not bfloat16 else torch.bfloat16 loss = self.loss(batch, output) return loss + def loss(self, batch, prediction): + # todo (sean): convert bfloat16 to float32 as mse loss for cpu amp is currently not supported + if self.trainer.precision_plugin.use_cpu: + prediction = prediction.float() + return super().loss(batch, prediction) + def training_step(self, batch, batch_idx): output = self._step(batch, batch_idx) return {"loss": output} @@ -50,17 +56,58 @@ def test_step(self, batch, batch_idx): return {"y": output} def predict(self, batch, batch_idx, dataloader_idx=None): - assert torch.is_autocast_enabled() + self._assert_autocast_enabled() output = self(batch) bfloat16 = self.trainer.precision_plugin.is_bfloat16 assert output.dtype == torch.float16 if not bfloat16 else torch.bfloat16 return output + def _assert_autocast_enabled(self): + if self.trainer.precision_plugin.use_cpu: + assert torch.is_autocast_cpu_enabled() + else: + assert torch.is_autocast_enabled() + + +@pytest.mark.skipif(not _TORCH_CPU_AMP_AVAILABLE, reason="CPU AMP not available") +@pytest.mark.parametrize( + "accelerator", + [ + None, + pytest.param("dp", marks=pytest.mark.skip("dp + amp not supported currently")), # TODO + "ddp_spawn", + ], +) +@pytest.mark.parametrize( + "precision", + [ + pytest.param(16, marks=pytest.mark.skip("CPU precision 16 is not supported in PyTorch yet.")), # TODO + "bf16", + ], +) +@pytest.mark.parametrize("num_processes", [1, 2]) +def test_amp_cpus(tmpdir, accelerator, precision, num_processes): + """Make sure combinations of AMP and training types work if supported.""" + tutils.reset_seed() + + trainer = Trainer( + default_root_dir=tmpdir, num_processes=num_processes, max_epochs=1, accelerator=accelerator, precision=precision + ) + + model = AMPTestModel() + # tutils.run_model_test(trainer_options, model) + trainer.fit(model) + trainer.test(model) + trainer.predict(model, DataLoader(RandomDataset(32, 64))) + + assert trainer.state.finished, f"Training failed with {trainer.state}" + @RunIf(min_gpus=2) @pytest.mark.parametrize( "accelerator", [ + None, pytest.param("dp", marks=pytest.mark.skip("dp + amp not supported currently")), # TODO "ddp_spawn", ], diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py index 15ec43973b0ed6..930cbd87972480 100644 --- a/tests/plugins/test_amp_plugins.py +++ b/tests/plugins/test_amp_plugins.py @@ -21,6 +21,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin from pytorch_lightning.plugins.precision import MixedPrecisionPlugin +from pytorch_lightning.utilities import _TORCH_CPU_AMP_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -190,3 +191,50 @@ def test_amp_precision_16_bfloat_throws_error(tmpdir): precision="bf16", gpus=1, ) + + +@RunIf(amp_native=True, max_torch="1.9") +def test_cpu_amp_precision_throws_error(tmpdir): + with pytest.raises( + MisconfigurationException, + match="To use native AMP on CPU, install PyTorch 1.10 or later.", + ): + NativeMixedPrecisionPlugin(use_cpu=True) + + +@pytest.mark.skipif(not _TORCH_CPU_AMP_AVAILABLE, reason="Torch CPU AMP is not available.") +@RunIf( + min_gpus=1, + amp_native=True, +) +def test_cpu_amp_precision_context_manager(tmpdir): + """ + Test to ensure that the context manager correctly is set to CPU + bfloat16, and a scaler isn't set. + """ + + plugin = NativeMixedPrecisionPlugin(precision="bf16", use_cpu=True) + assert plugin.use_cpu + assert not hasattr(plugin, "scaler") + context_manager = plugin.autocast_context_manager() + assert isinstance(context_manager, torch.cpu.amp.autocast) + assert context_manager.fast_dtype == torch.bfloat16 + + +@pytest.mark.skipif(not _TORCH_CPU_AMP_AVAILABLE, reason="Torch CPU AMP is not available.") +@RunIf( + min_gpus=1, + amp_native=True, +) +def test_cpu_amp_precision_16_throws_error(tmpdir): + """ + Throw error when using 16 as Native CPU AMP only supports bfloat16. + """ + + with pytest.raises( + MisconfigurationException, + match="CPU native amp only supports bfloat16. Please pass precision='bf16' to the Trainer.", + ): + Trainer( + default_root_dir=tmpdir, + precision=16, + ) From 5ff89a7074c0296bd5d7e34b45c6238e098c3e30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 25 Aug 2021 14:48:06 +0200 Subject: [PATCH 03/19] Rename test file from log_dir to test_log_dir (#9105) --- tests/trainer/properties/{log_dir.py => test_log_dir.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/trainer/properties/{log_dir.py => test_log_dir.py} (100%) diff --git a/tests/trainer/properties/log_dir.py b/tests/trainer/properties/test_log_dir.py similarity index 100% rename from tests/trainer/properties/log_dir.py rename to tests/trainer/properties/test_log_dir.py From 12d076f7476d949ac614cd30235dfa17c62f7151 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Wed, 25 Aug 2021 15:44:29 +0100 Subject: [PATCH 04/19] [docs] Add Mixed Precision detailed docs (#9104) --- docs/source/advanced/mixed_precision.rst | 83 ++++++++++++++++++++++++ docs/source/conf.py | 2 + docs/source/guides/speed.rst | 4 +- docs/source/index.rst | 1 + 4 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 docs/source/advanced/mixed_precision.rst diff --git a/docs/source/advanced/mixed_precision.rst b/docs/source/advanced/mixed_precision.rst new file mode 100644 index 00000000000000..ea784f0894a009 --- /dev/null +++ b/docs/source/advanced/mixed_precision.rst @@ -0,0 +1,83 @@ +.. testsetup:: * + + from pytorch_lightning import Trainer + + +.. _amp: + +Mixed Precision Training +======================== + +Mixed precision combines the use of both FP32 and lower bit floating points (such as FP16) to reduce memory footprint during model training, resulting in improved performance. + +Lightning offers mixed precision training for GPUs and CPUs, as well as bfloat16 mixed precision training for TPUs. + +.. note:: + + In some cases it is important to remain in FP32 for numerical stability, so keep this in mind when using mixed precision. + + For example when running scatter operations during the forward (such as torchpoint3d) computation must remain in FP32. + +FP16 Mixed Precision +-------------------- + +In most cases, mixed precision uses FP16. Supported torch operations are automatically run in FP16, saving memory and improving throughput on GPU and TPU accelerators. + +Since computation happens in FP16, there is a chance of numerical instability. This is handled internally by a dynamic grad scaler which skips steps that are invalid, and adjusts the scaler to ensure subsequent steps fall within a finite range. For more information `see the autocast docs `__. + +.. note:: + + When using TPUs, setting ``precision=16`` will enable bfloat16 which is the only supported precision type on TPUs. + +.. testcode:: + :skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE or not torch.cuda.is_available() + + Trainer(gpus=1, precision=16) + +BFloat16 Mixed Precision +------------------------ + +.. warning:: + + BFloat16 requires PyTorch 1.10 or later. Currently this requires installing `PyTorch Nightly `__. + + BFloat16 is also experimental and may not provide large speedups or memory improvements, but offer better numerical stability. + + Do note for GPUs, largest benefits require `Ampere `__ based GPUs, such as A100s or 3090s. + +BFloat16 Mixed precision is similar to FP16 mixed precision, however we maintain more of the "dynamic range" that FP32 has to offer. This means we are able to improve numerical stability, compared to FP16 mixed precision. For more information see `this TPU performance blog post `__. + +Since BFloat16 is more stable than FP16 during training, we do not need to worry about any gradient scaling or nan gradient values that comes with using FP16 mixed precision. + +.. testcode:: + :skipif: not _TORCH_BFLOAT_AVAILABLE + + Trainer(gpus=1, precision="bf16") + +It is also possible to use BFloat16 mixed precision on the CPU, relying on MKLDNN under the hood. + +.. testcode:: + :skipif: not _TORCH_CPU_AMP_AVAILABLE + + Trainer(precision="bf16") + +NVIDIA APEX Mixed Precision +--------------------------- + +.. warning:: + + We strongly recommend to use the above native mixed precision rather than NVIDIA APEX unless you require more finer control. + +`NVIDIA APEX `__ offers some additional flexibility in setting mixed precision. This can be useful for when wanting to try out different precision configurations, such as keeping most of your weights in FP16 as well as running computation in FP16. + +.. testcode:: + :skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE or not torch.cuda.is_available() + + Trainer(gpus=1, amp_backend="apex") + +Set the `NVIDIA optimization level `__ via the trainer. + +.. testcode:: + :skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE or not torch.cuda.is_available() + + Trainer(gpus=1, amp_backend="apex", amp_level="O2") diff --git a/docs/source/conf.py b/docs/source/conf.py index 8ddc896b6e9120..4adbacd4cf60c0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -370,6 +370,8 @@ def package_list_from_file(file): _XLA_AVAILABLE, _TPU_AVAILABLE, _TORCHVISION_AVAILABLE, + _TORCH_BFLOAT_AVAILABLE, + _TORCH_CPU_AMP_AVAILABLE, _module_available, ) _JSONARGPARSE_AVAILABLE = _module_available("jsonargparse") diff --git a/docs/source/guides/speed.rst b/docs/source/guides/speed.rst index 4e3ed0b1de8011..fd245e741b9aa1 100644 --- a/docs/source/guides/speed.rst +++ b/docs/source/guides/speed.rst @@ -186,7 +186,7 @@ Read more in our :ref:`accelerators` and :ref:`plugins` guides. ----------- -.. _amp: +.. _speed_amp: ********************************* Mixed precision (16-bit) training @@ -210,7 +210,7 @@ Mixed precision (16-bit) training Mixed precision combines the use of both 32 and 16 bit floating points to reduce memory footprint during model training, resulting in improved performance, achieving +3X speedups on modern GPUs. -Lightning offers mixed precision or 16-bit training for GPUs and TPUs. +Lightning offers mixed precision training for GPUs and CPUs, as well as bfloat16 mixed precision training for TPUs. .. testcode:: diff --git a/docs/source/index.rst b/docs/source/index.rst index f3c154a7d257b1..e1de1ed30defa9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -54,6 +54,7 @@ PyTorch Lightning Documentation common/loggers advanced/multi_gpu advanced/advanced_gpu + advanced/mixed_precision common/weights_loading advanced/checkpoint_io common/optimizers From f01a9a6cd2c7035f10bf83ba2db3100ffb3ae8b6 Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Wed, 25 Aug 2021 12:10:28 -0700 Subject: [PATCH 05/19] Remove `BasePlugin` (#9066) * Remove BasePlugin Co-authored-by: ananthsub Co-authored-by: Carlos Mocholi --- CHANGELOG.md | 5 +- pytorch_lightning/accelerators/accelerator.py | 8 +-- pytorch_lightning/plugins/__init__.py | 11 ++-- pytorch_lightning/plugins/base_plugin.py | 51 ------------------- .../plugins/precision/precision_plugin.py | 35 +++++++++++-- .../training_type/training_type_plugin.py | 12 ++++- pytorch_lightning/trainer/trainer.py | 5 +- 7 files changed, 58 insertions(+), 69 deletions(-) delete mode 100644 pytorch_lightning/plugins/base_plugin.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a221971a1b56d2..5232cc793163f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -150,8 +150,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `prepare_data_per_node` flag on Trainer and set it as a property of `DataHooks`, accessible in the `LightningModule` and `LightningDataModule` [#8958](https://github.com/PyTorchLightning/pytorch-lightning/pull/8958) -- - ### Removed - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/)) @@ -208,6 +206,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed `InterBatchProcessor` in favor of `DataLoaderIterDataFetcher` ([#9052](https://github.com/PyTorchLightning/pytorch-lightning/pull/9052)) +- Removed `Plugin` in `base_plugin.py`, access `TrainingTypePlugin` and `PrecisionPlugin` directly instead ([#9066](https://github.com/PyTorchLightning/pytorch-lightning/pull/9066)) + + ### Fixed - Fixed save/load/resume from checkpoint for DeepSpeed Plugin ( diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 2b647f5811fefc..6038a8abc8f5c3 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -186,7 +186,7 @@ def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: - hiddens(:class:`~torch.Tensor`): Passed in if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. """ - with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context(): + with self.precision_plugin.train_step_context(): return self.training_type_plugin.training_step(*step_kwargs.values()) def post_training_step(self) -> None: @@ -204,7 +204,7 @@ def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[S - dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple val dataloaders used) """ - with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context(): + with self.precision_plugin.val_step_context(): return self.training_type_plugin.validation_step(*step_kwargs.values()) def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: @@ -219,7 +219,7 @@ def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OU - dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple test dataloaders used). """ - with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context(): + with self.precision_plugin.test_step_context(): return self.training_type_plugin.test_step(*step_kwargs.values()) def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: @@ -234,7 +234,7 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: - dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple predict dataloaders used). """ - with self.precision_plugin.predict_step_context(), self.training_type_plugin.predict_step_context(): + with self.precision_plugin.predict_step_context(): return self.training_type_plugin.predict_step(*step_kwargs.values()) def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index a69065fa74f733..13f8c7404baf4f 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -1,4 +1,7 @@ -from pytorch_lightning.plugins.base_plugin import Plugin +from pathlib import Path +from typing import Union + +from pytorch_lightning.plugins.environments import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO from pytorch_lightning.plugins.plugins_registry import ( # noqa: F401 @@ -30,6 +33,9 @@ from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin +PLUGIN = Union[TrainingTypePlugin, PrecisionPlugin, ClusterEnvironment, CheckpointIO] +PLUGIN_INPUT = Union[PLUGIN, str] + __all__ = [ "CheckpointIO", "TorchCheckpointIO", @@ -55,13 +61,10 @@ "TPUSpawnPlugin", "TrainingTypePlugin", "ParallelPlugin", - "Plugin", "DDPShardedPlugin", "DDPSpawnShardedPlugin", ] -from pathlib import Path - FILE_ROOT = Path(__file__).parent TRAINING_TYPE_BASE_MODULE = "pytorch_lightning.plugins.training_type" diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py deleted file mode 100644 index 515fc29d0e355a..00000000000000 --- a/pytorch_lightning/plugins/base_plugin.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import contextlib -from abc import ABC -from typing import Generator - -import pytorch_lightning as pl - - -class Plugin(ABC): - """Basic class for all precision- and training type plugins.""" - - def pre_dispatch(self) -> None: - """Hook to do something before the training/evaluation/prediction starts.""" - - def dispatch(self, trainer: "pl.Trainer") -> None: - """Hook to do something at trainer run_stage starts.""" - - def post_dispatch(self) -> None: - """Hook to do something after the training/evaluation/prediction finishes.""" - - @contextlib.contextmanager - def train_step_context(self) -> Generator: - """A contextmanager for the trainstep""" - yield - - @contextlib.contextmanager - def val_step_context(self) -> Generator: - """A contextmanager for the validation step""" - yield - - @contextlib.contextmanager - def test_step_context(self) -> Generator: - """A contextmanager for the teststep""" - yield - - @contextlib.contextmanager - def predict_step_context(self) -> Generator: - """A contextmanager for the predict step""" - yield diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 1261fea87c06ed..86486bfc37cd92 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Tuple, Union +import contextlib +from typing import Any, Callable, Generator, List, Optional, Tuple, Union import torch from torch import Tensor @@ -20,12 +21,11 @@ import pytorch_lightning as pl from pytorch_lightning.core.hooks import CheckpointHooks -from pytorch_lightning.plugins.base_plugin import Plugin from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.types import _PARAMETERS -class PrecisionPlugin(Plugin, CheckpointHooks): +class PrecisionPlugin(CheckpointHooks): """ Base class for all plugins handling the precision-specific parts of the training. The class attribute precision must be overwritten in child classes. @@ -136,3 +136,32 @@ def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) - """Clip gradients by norm""" parameters = self.master_params(optimizer) torch.nn.utils.clip_grad_norm_(parameters, clip_val) + + def pre_dispatch(self) -> None: + """Hook to do something before the training/evaluation/prediction starts.""" + + def dispatch(self, trainer: "pl.Trainer") -> None: + """Hook to do something when ``Accelerator.dispatch()`` gets called.""" + + def post_dispatch(self) -> None: + """Hook to do something after the training/evaluation/prediction finishes.""" + + @contextlib.contextmanager + def train_step_context(self) -> Generator: + """A contextmanager for the training step""" + yield + + @contextlib.contextmanager + def val_step_context(self) -> Generator: + """A contextmanager for the validation step""" + yield + + @contextlib.contextmanager + def test_step_context(self) -> Generator: + """A contextmanager for the test step""" + yield + + @contextlib.contextmanager + def predict_step_context(self) -> Generator: + """A contextmanager for the predict step""" + yield diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 82a33290aafbf9..6ee1ce77c8c24c 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -25,14 +25,13 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins import TorchCheckpointIO -from pytorch_lightning.plugins.base_plugin import Plugin from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT TBroadcast = TypeVar("T") -class TrainingTypePlugin(Plugin, ABC): +class TrainingTypePlugin(ABC): """ Base class for all training type plugins that change the behaviour of the training, validation and test-loop. """ @@ -352,3 +351,12 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) Called in the training loop before anything happens for that batch. """ pass + + def pre_dispatch(self) -> None: + """Hook to do something before the training/evaluation/prediction starts.""" + + def dispatch(self, trainer: "pl.Trainer") -> None: + """Hook to do something at trainer run_stage starts.""" + + def post_dispatch(self) -> None: + """Hook to do something after the training/evaluation/prediction finishes.""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 19ccf3935a168f..7ebbc55ae7ac9e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -32,8 +32,7 @@ from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop from pytorch_lightning.loops.fit_loop import FitLoop -from pytorch_lightning.plugins import DDPSpawnPlugin, Plugin -from pytorch_lightning.plugins.environments import ClusterEnvironment +from pytorch_lightning.plugins import DDPSpawnPlugin, PLUGIN_INPUT from pytorch_lightning.profiler import ( AdvancedProfiler, BaseProfiler, @@ -151,7 +150,7 @@ def __init__( terminate_on_nan: bool = False, auto_scale_batch_size: Union[str, bool] = False, prepare_data_per_node: Optional[bool] = None, - plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None, + plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None, amp_backend: str = "native", amp_level: str = "O2", distributed_backend: Optional[str] = None, From 9d62f248476c6358d8707188f7b20fafa79f8a4f Mon Sep 17 00:00:00 2001 From: Burhanuddin Rangwala Date: Thu, 26 Aug 2021 00:45:00 +0530 Subject: [PATCH 06/19] Add docstrings to Test Tube logger (#9110) --- pytorch_lightning/loggers/test_tube.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index 800b13b83c39c3..bbe4897b47528b 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -197,10 +197,22 @@ def close(self) -> None: @property def save_dir(self) -> Optional[str]: + """ + Gets the save directory. + + Returns: + The path to the save directory. + """ return self._save_dir @property def name(self) -> str: + """ + Gets the experiment name. + + Returns: + The experiment name if the experiment exists, else the name specified in the constructor. + """ if self._experiment is None: return self._name @@ -208,6 +220,12 @@ def name(self) -> str: @property def version(self) -> int: + """ + Gets the experiment version. + + Returns: + The experiment version if the experiment exists, else the next version. + """ if self._experiment is None: return self._version From 69f66fd6bb361a7932a82291e4ef001f4f381f99 Mon Sep 17 00:00:00 2001 From: Santiago Castro Date: Thu, 26 Aug 2021 02:54:13 -0300 Subject: [PATCH 07/19] Fix typing in hparams methods (#9116) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pytorch_lightning/core/mixins/hparams_mixin.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/mixins/hparams_mixin.py b/pytorch_lightning/core/mixins/hparams_mixin.py index 029ecc173bcf24..72129f22f54bb6 100644 --- a/pytorch_lightning/core/mixins/hparams_mixin.py +++ b/pytorch_lightning/core/mixins/hparams_mixin.py @@ -15,7 +15,7 @@ import inspect import types from argparse import Namespace -from typing import Optional, Sequence, Union +from typing import MutableMapping, Optional, Sequence, Union from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES from pytorch_lightning.utilities import AttributeDict @@ -104,7 +104,7 @@ class ``__init__`` to be ignored frame = inspect.currentframe().f_back save_hyperparameters(self, *args, ignore=ignore, frame=frame) - def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: + def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None: hp = self._to_hparams_dict(hp) if isinstance(hp, dict) and isinstance(self.hparams, dict): @@ -113,7 +113,7 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: self._hparams = hp @staticmethod - def _to_hparams_dict(hp: Union[dict, Namespace, str]): + def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]): if isinstance(hp, Namespace): hp = vars(hp) if isinstance(hp, dict): From 366fb39d2e021ef6fa663b972f2f4e6d92b62775 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Thu, 26 Aug 2021 00:24:49 -0700 Subject: [PATCH 08/19] Support post-localSGD in Lightning DDP plugin (#8967) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: ananthsub Co-authored-by: Adrian Wälchli --- .../plugins/training_type/ddp.py | 59 +++++++++++++++++++ pytorch_lightning/utilities/distributed.py | 14 +++++ .../plugins/test_ddp_plugin_with_comm_hook.py | 33 ++++++++++- 3 files changed, 105 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 787353be307e69..aeb43fcdebfe44 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -29,17 +29,21 @@ import torch.distributed from torch.nn.parallel.distributed import DistributedDataParallel +from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import ( + _FAIRSCALE_AVAILABLE, _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, + _TORCH_GREATER_EQUAL_1_10, rank_zero_deprecation, rank_zero_warn, ) @@ -53,11 +57,19 @@ from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException from pytorch_lightning.utilities.seed import reset_seed +if _TORCH_GREATER_EQUAL_1_10: + from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer + +if _FAIRSCALE_AVAILABLE: + from fairscale.optim import OSS if _HYDRA_AVAILABLE: from hydra.core.hydra_config import HydraConfig from hydra.utils import get_original_cwd, to_absolute_path if _TORCH_GREATER_EQUAL_1_8: from pytorch_lightning.utilities.distributed import register_ddp_comm_hook +if _TORCH_GREATER_EQUAL_1_10: + import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD + import torch.distributed.algorithms.model_averaging.averagers as averagers log = logging.getLogger(__name__) @@ -83,6 +95,7 @@ def __init__( ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, + model_averaging_period: Optional[int] = None, **kwargs: Union[Any, Dict[str, Any]], ) -> None: super().__init__( @@ -110,6 +123,7 @@ def __init__( self._ddp_comm_state = ddp_comm_state self._ddp_comm_hook = ddp_comm_hook self._ddp_comm_wrapper = ddp_comm_wrapper + self._model_averaging_period = model_averaging_period self._pids: Optional[List[int]] = None self._sync_dir: Optional[str] = None self.set_world_ranks() @@ -302,6 +316,51 @@ def _register_ddp_hooks(self) -> None: ddp_comm_wrapper=self._ddp_comm_wrapper, ) + # Post-localSDG is only available after 1.9, + # and `torch.distributed.optim` package currently is not available on Windows. + if ( + _TORCH_GREATER_EQUAL_1_10 + and isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState) + and self.lightning_module.trainer.state.fn == TrainerFn.FITTING + ): + self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter) + + def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int): + optimizers = self.lightning_module.trainer.optimizers + if self._model_averaging_period is None: + raise ValueError( + "Post-localSGD algorithm is used, " "but model averaging period is not provided to DDP plugin." + ) + averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps) + for x, optimizer in enumerate(optimizers): + if isinstance(optimizer, LightningOptimizer): + optimizer = optimizer._optimizer + + if ( + isinstance(optimizer, DistributedOptimizer) + or isinstance(optimizer, ZeroRedundancyOptimizer) + or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS)) + ): + raise ValueError( + f"Cannot wrap a distributed optimizer of type {optimizer.__name__} by PostLocalSGDOptimizer." + ) + + if isinstance(optimizer, PostLocalSGDOptimizer): + continue + + optim_class = type(optimizer) + post_localSGD_optimizer = PostLocalSGDOptimizer( + params=optimizer.param_groups, + optimizer_class=optim_class, + averager=averager, + **optimizer.defaults, + ) + optimizers[x] = post_localSGD_optimizer + del optimizer + trainer = self.lightning_module.trainer + trainer.optimizers = optimizers + trainer.convert_to_lightning_optimizers() + def configure_ddp(self): self.pre_configure_ddp() self._model = DistributedDataParallel( diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 68868a0ff74cdc..4f254b6824489c 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -284,12 +284,14 @@ def register_ddp_comm_hook( .. warning :: DDP communication wrapper needs pytorch version at least 1.9.0 + Post-localSGD hook needs pytorch version at least 1.9.0 Example: from torch.distributed.algorithms.ddp_comm_hooks import ( default_hooks as default, powerSGD_hook as powerSGD, + post_localSGD_hook as post_localSGD, ) # fp16_compress_hook for compress gradients @@ -309,6 +311,18 @@ def register_ddp_comm_hook( ddp_comm_hook=powerSGD.powerSGD_hook, ) + # post_localSGD_hook + subgroup, _ = torch.distributed.new_subgroups() + register_comm_hook( + model=ddp_model, + state=post_localSGD.PostLocalSGDState( + process_group=None, + subgroup=subgroup, + start_localSGD_iter=1_000, + ), + ddp_comm_hook=post_localSGD.post_localSGD_hook, + ) + # fp16_compress_wrapper combined with other communication hook register_ddp_comm_hook( model=ddp_model, diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py index 49c4cd18ef316f..1e5968f1a0a5a1 100644 --- a/tests/plugins/test_ddp_plugin_with_comm_hook.py +++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py @@ -15,13 +15,15 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin -from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_10 from tests.helpers import BoringModel from tests.helpers.runif import RunIf if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_8: from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD +if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_10: + import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD @RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True) @@ -108,3 +110,32 @@ def test_ddp_spawn_fp16_compress_comm_hook(tmpdir): ) trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" + + +@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, special=True) +def test_ddp_post_local_sgd_comm_hook(tmpdir): + """Test for DDP post-localSGD hook.""" + model = BoringModel() + + training_type_plugin = DDPPlugin( + ddp_comm_state=post_localSGD.PostLocalSGDState( + process_group=None, + subgroup=None, + start_localSGD_iter=8, + ), + ddp_comm_hook=post_localSGD.post_localSGD_hook, + model_averaging_period=4, + sync_batchnorm=True, + ) + trainer = Trainer( + fast_dev_run=True, + gpus=2, + plugins=[training_type_plugin], + default_root_dir=tmpdir, + sync_batchnorm=True, + ) + trainer.fit(model) + trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook + expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__ + assert trainer_comm_hook == expected_comm_hook + assert trainer.state.finished, f"Training failed with {trainer.state}" From 0abd6e94b5cc728135ec1734a7dcefebea6f4e5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Aug 2021 10:02:49 +0200 Subject: [PATCH 09/19] [3 / 3] improvements to saving and loading callback state (#7161) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- .../callbacks/model_checkpoint.py | 10 +++--- pytorch_lightning/trainer/callback_hook.py | 14 ++++---- tests/callbacks/test_callbacks.py | 36 +++++++++++++++++++ 3 files changed, 48 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 414a92af6a66c9..e7daa4ee53cdeb 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -265,16 +265,16 @@ def state_key(self) -> str: save_on_train_epoch_end=self._save_on_train_epoch_end, ) - def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """ - When pretrain routine starts we build the ckpt dir on the fly - """ - self.__resolve_ckpt_dir(trainer) + def on_init_end(self, trainer: "pl.Trainer") -> None: if self._save_on_train_epoch_end is None: # if the user runs validation multiple times per training epoch, we try to save checkpoint after # validation instead of on train epoch end self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 + def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """When pretrain routine starts we build the ckpt dir on the fly.""" + self.__resolve_ckpt_dir(trainer) + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._last_time_checked = time.monotonic() diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 36a3e9abb7b7a4..bbfcbb22802a87 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from abc import ABC from copy import deepcopy from typing import Any, Dict, List, Optional, Type, Union import torch +from packaging.version import Version import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback @@ -255,14 +255,14 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: if callback_states is None: return - current_callbacks_type = {type(cb) for cb in self.callbacks} - saved_callbacks_type = set(callback_states.keys()) - difference = saved_callbacks_type.difference(current_callbacks_type) + is_legacy_ckpt = Version(checkpoint["pytorch-lightning_version"]) < Version("1.5.0dev") + current_callbacks_keys = {cb._legacy_state_key if is_legacy_ckpt else cb.state_key for cb in self.callbacks} + difference = callback_states.keys() - current_callbacks_keys if difference: rank_zero_warn( - "Be aware that when using ``resume_from_checkpoint``, " - "callbacks used to create the checkpoint need to be provided. " - f"Please, add the following callbacks: {list(difference)}. ", + "Be aware that when using `resume_from_checkpoint`," + " callbacks used to create the checkpoint need to be provided." + f" Please add the following callbacks: {list(difference)}.", UserWarning, ) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index dc191f4853cc1d..c363638d565d2a 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. from pathlib import Path +from re import escape from unittest.mock import call, Mock +import pytest + from pytorch_lightning import Callback, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from tests.helpers import BoringModel +from tests.helpers.utils import no_warning_call def test_callbacks_configured_in_model(tmpdir): @@ -132,3 +137,34 @@ def test_resume_callback_state_saved_by_type(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path) trainer.fit(model) assert callback.state == 111 + + +def test_resume_incomplete_callbacks_list_warning(tmpdir): + model = BoringModel() + callback0 = ModelCheckpoint(monitor="epoch") + callback1 = ModelCheckpoint(monitor="global_step") + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + callbacks=[callback0, callback1], + ) + trainer.fit(model) + ckpt_path = trainer.checkpoint_callback.best_model_path + + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + callbacks=[callback1], # one callback is missing! + resume_from_checkpoint=ckpt_path, + ) + with pytest.warns(UserWarning, match=escape(f"Please add the following callbacks: [{repr(callback0.state_key)}]")): + trainer.fit(model) + + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + callbacks=[callback1, callback0], # all callbacks here, order switched + resume_from_checkpoint=ckpt_path, + ) + with no_warning_call(UserWarning, match="Please add the following callbacks:"): + trainer.fit(model) From 6592d0e4545280ea520f278013ad449ef380cc74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Aug 2021 10:36:21 +0200 Subject: [PATCH 10/19] generalize closure api in Lightning (#8642) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí Co-authored-by: thomas chaton Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- CHANGELOG.md | 4 + pyproject.toml | 3 + .../loops/batch/training_batch_loop.py | 165 +++++++++--------- pytorch_lightning/loops/closure.py | 131 ++++++++++++++ pytorch_lightning/loops/utilities.py | 4 +- pytorch_lightning/utilities/debugging.py | 2 +- tests/core/test_lightning_optimizer.py | 3 +- .../loops/test_evaluation_loop_flow.py | 10 +- .../loops/test_training_loop_flow_scalar.py | 17 +- 9 files changed, 248 insertions(+), 91 deletions(-) create mode 100644 pytorch_lightning/loops/closure.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 5232cc793163f7..aefe4c322212be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,6 +60,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Refactored CheckpointConnector to offload validating logic to the checkpoitn IO plugin ([#9045](https://github.com/PyTorchLightning/pytorch-lightning/pull/9045)) +- Loop customization: + * Added `Closure` and `AbstractClosure` classes ([#8642](https://github.com/PyTorchLightning/pytorch-lightning/pull/8642)) + + - Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187)) diff --git a/pyproject.toml b/pyproject.toml index d07f19ef109863..206a4717a1b82b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,8 @@ warn_unused_ignores = "True" allow_redefinition = "True" # disable this rule as the Trainer attributes are defined in the connectors, not in its __init__ disable_error_code = "attr-defined" +# style choices +warn_no_return = "False" # TODO: Fix typing for these modules [[tool.mypy.overrides]] @@ -60,6 +62,7 @@ ignore_errors = "True" [[tool.mypy.overrides]] module = [ "pytorch_lightning.callbacks.pruning", + "pytorch_lightning.loops.closure", "pytorch_lightning.trainer.evaluation_loop", "pytorch_lightning.trainer.connectors.logger_connector", "pytorch_lightning.utilities.argparse", diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 29517ad306ebae..3f94a0181672e4 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -15,7 +15,7 @@ from collections import OrderedDict from contextlib import contextmanager from copy import copy -from functools import partial, update_wrapper +from functools import partial from typing import Any, Callable, Dict, Generator, List, Optional, Tuple import numpy as np @@ -26,6 +26,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loops.base import Loop +from pytorch_lightning.loops.closure import Closure, ClosureResult from pytorch_lightning.loops.utilities import ( _check_training_step_output, _process_training_step_output, @@ -144,12 +145,12 @@ def advance(self, batch, batch_idx): result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer) if result: - self.batch_outputs[opt_idx].append(copy(result.training_step_output)) + self.batch_outputs[opt_idx].append(copy(result.result_collection)) else: # in manual optimization, there is no looping over optimizers result = self._run_optimization(batch_idx, split_batch) if result: - self.batch_outputs[0].append(copy(result.training_step_output)) + self.batch_outputs[0].append(copy(result.result_collection)) def teardown(self) -> None: # release memory @@ -165,7 +166,7 @@ def _run_optimization( split_batch: Any, opt_idx: Optional[int] = None, optimizer: Optional[torch.optim.Optimizer] = None, - ): + ) -> Optional[ClosureResult]: """Runs closure (train step + backward) together with optimization if necessary. Args: @@ -177,8 +178,7 @@ def _run_optimization( # toggle model params self._run_optimization_start(opt_idx, optimizer) - result = AttributeDict() - closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result) + closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens) if self.trainer.fit_loop.should_accumulate(): # For gradient accumulation @@ -199,7 +199,9 @@ def _run_optimization( if self.trainer.lightning_module.automatic_optimization: self._optimizer_step(optimizer, opt_idx, batch_idx, closure) else: - result = self._training_step(split_batch, batch_idx, opt_idx, self._hiddens) + closure() + + result = closure.get_result() if result: # if no result, user decided to skip optimization @@ -211,37 +213,79 @@ def _run_optimization( self._run_optimization_end(opt_idx) return result - def _training_step_and_backward_closure( + def _make_closure( self, split_batch: Any, batch_idx: int, opt_idx: int, optimizer: Optimizer, - hiddens: Tensor, - return_result: AttributeDict, - ) -> Optional[Tensor]: - """Closure for training step and backward + hiddens: Any, + ) -> Closure: + """ + Build a closure object that captures the given arguments and runs the `training_step` function and optionally + other functions such as `backward` and `zero_grad`. + """ + step_fn = self._make_step_fn(split_batch, batch_idx, opt_idx, hiddens) + backward_fn = self._make_backward_fn(batch_idx, optimizer, opt_idx) + zero_grad_fn = self._make_zero_grad_fn(batch_idx, opt_idx, optimizer) + + return Closure( + step_fn=step_fn, + backward_fn=backward_fn, + zero_grad_fn=zero_grad_fn, + profiler=self.trainer.profiler, + ) - Args: - split_batch: the current tbptt split of the batch - batch_idx: the index of the current batch - opt_idx: the index of the current optimizer - optimizer: the current optimizer - hiddens: the hidden state of the recurrent net - return_result: the storage of the trainstep results + def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Any) -> Callable[[], dict]: + """Build the step function that runs the `training_step` and processes its output.""" + return partial(self._training_step, split_batch, batch_idx, opt_idx, hiddens) + + def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]: + """ + Build a `zero_grad` function that zeroes the gradients before back-propagation. + Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on. + """ + + def zero_grad_fn(): + self._on_before_zero_grad(optimizer) + self._optimizer_zero_grad(batch_idx, optimizer, opt_idx) + + is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 + if ( + not self._skip_backward + and self.trainer.lightning_module.automatic_optimization + and is_first_batch_to_accumulate + ): + return zero_grad_fn + + def _make_backward_fn( + self, + batch_idx: int, + optimizer: Optimizer, + opt_idx: int, + ) -> Optional[Callable[[Tensor], Tensor]]: + """ + Build a `backward` function that handles back-propagation through the output produced by the `training_step` + function. Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on. """ - result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) - if result is not None: - return_result.update(result) - return return_result.loss + def backward_fn(loss: Tensor): + self.backward(loss, optimizer, opt_idx) + + # when in dev debugging track the losses + # TODO: remove dev debugger tracking loss history + self.trainer.dev_debugger.track_train_loss_history(batch_idx, loss) + + # check if loss or model weights are nan + if self.trainer.terminate_on_nan: + check_finite_loss(self.trainer.lightning_module, loss) - def _make_closure(self, *closure_args: Any, **closure_kwargs: Any) -> Callable: - """Wraps the training step closure into a partial object which will be called within ``optimizer.step``.""" - partial_func = partial(self._training_step_and_backward_closure, *closure_args, **closure_kwargs) - return update_wrapper(partial_func, self._training_step_and_backward_closure) + return loss - def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -> None: + if not self._skip_backward and self.trainer.lightning_module.automatic_optimization: + return backward_fn + + def _process_closure_result(self, opt_closure_result: Optional[ClosureResult]) -> None: """Checks if the closure results is finite and optionally breaks if it is not Args: @@ -286,18 +330,18 @@ def _training_step( _check_training_step_output(self.trainer.lightning_module, training_step_output) - training_step_output, self._hiddens = _process_training_step_output(self.trainer, training_step_output) - if training_step_output is None: + result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output) + if result_collection is None: return closure_loss = None loss = None if self.trainer.lightning_module.automatic_optimization: # accumulate loss. if accumulate_grad_batches==1, no effect - closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches + closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches # the loss will get scaled for amp. avoid any modifications to it loss = closure_loss.detach().clone() - return AttributeDict(closure_loss=closure_loss, loss=loss, training_step_output=training_step_output) + return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection) def _optimizer_step( self, optimizer: torch.optim.Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable @@ -438,60 +482,22 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator else: yield None - def training_step_and_backward( - self, - split_batch: Any, - batch_idx: int, - opt_idx: int, - optimizer: torch.optim.Optimizer, - hiddens: Optional[Tensor], - ) -> STEP_OUTPUT: - """Wrap forward, zero_grad and backward in a closure so second order methods work""" - with self.trainer.profiler.profile("training_step_and_backward"): - # lightning module hook - result = self._training_step(split_batch, batch_idx, opt_idx, hiddens) - - if not self._skip_backward and self.trainer.lightning_module.automatic_optimization: - is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 - - if is_first_batch_to_accumulate: - self._on_before_zero_grad(optimizer) - self._optimizer_zero_grad(batch_idx, optimizer, opt_idx) - - # backward pass - if result is not None: - with self.trainer.profiler.profile("backward"): - self.backward(result, optimizer, opt_idx) - - # when in dev debugging track the losses - self.trainer.dev_debugger.track_train_loss_history(batch_idx, result.loss) - - # check if loss or model weights are nan - if self.trainer.terminate_on_nan: - check_finite_loss(self.trainer.lightning_module, result.loss) - - else: - self._warning_cache.warn( - "training_step returned None. If this was on purpose, ignore this warning..." - ) - - return result - def backward( - self, result: STEP_OUTPUT, optimizer: Optional[torch.optim.Optimizer], *args: Any, **kwargs: Any - ) -> None: + self, + loss: Tensor, + optimizer: Optional[torch.optim.Optimizer], + opt_idx: Optional[int] = None, + *args: Any, + **kwargs: Any, + ) -> Tensor: """Performs the backward step. Args: - result: The output of the trainstep (including the loss value) + loss: The loss value to back-propagate on optimizer: Current optimizer being used. ``None`` if using manual optimization. opt_idx: Index of the current optimizer being used. ``None`` if using manual optimization. """ - # backward can be called manually in the training loop - if isinstance(result, Tensor): - self.trainer.accelerator.backward(result, optimizer, *args, **kwargs) - else: - result.closure_loss = self.trainer.accelerator.backward(result.closure_loss, optimizer, *args, **kwargs) + self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs) if not self.trainer.fit_loop.should_accumulate(): # track gradients @@ -499,6 +505,7 @@ def backward( if grad_norm_dict: self.trainer.lightning_module._current_fx_name = "on_after_backward" self.trainer.lightning_module.log_grad_norm(grad_norm_dict) + return loss def _update_running_loss(self, current_loss: Tensor) -> None: """Updates the running loss value with the current value""" diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py new file mode 100644 index 00000000000000..b47af75199708c --- /dev/null +++ b/pytorch_lightning/loops/closure.py @@ -0,0 +1,131 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, Optional + +from torch import Tensor + +from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.utilities.warnings import WarningCache + + +@dataclass +class ClosureResult: + """A container to hold the result of a :class:`AbstractClosure` call. + + Attributes: + closure_loss: The loss with a graph attached. + loss: A detached copy of the closure loss. + result_collection: A collection of results returned by the closure. + """ + + closure_loss: Optional[Tensor] + loss: Optional[Tensor] + result_collection: Optional[ResultCollection] + + +class AbstractClosure(ABC): + """ + Abstract base class for optimizer closures in Lightning. + + Formally, a closure is binding variables from an external scope to a function that does a computation on these + variables without taking them explicitly as input. This has the benefit that a closure can be passed to an + object which later can call it like a function but without requiring to pass in any arguments. + + This class provides a simple abstraction making the instance of this class callable like a function while capturing + the :class:`ClosureResult` and caching it. + """ + + def __init__(self) -> None: + super().__init__() + self._result: Optional[ClosureResult] = None + + def get_result(self) -> Optional[ClosureResult]: + """The cached result from the last time the closure was called. Once accessed, the internal reference + gets reset and the consumer will have to hold on to the reference as long as necessary.""" + result = self._result + self._result = None # free memory + return result + + @abstractmethod + def closure(self, *args: Any, **kwargs: Any) -> Optional[ClosureResult]: + """Implements the behavior of the closure once it is getting called.""" + pass + + def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]: + self._result = self.closure(*args, **kwargs) + if self._result is not None: + return self._result.loss + + +class Closure(AbstractClosure): + """ + An implementation of a :class:`AbstractClosure` for optimization in Lightning that combines three elementary + closures into one: ``training_step``, ``backward`` and ``zero_grad``. + + The Closure gets created by the training loop(s) and is then passed to the + :meth:`torch.optim.Optimizer.step` method. An optimizer is responsible for calling the closure and optionally + do something with the output. + + Args: + step_fn: This is typically the :meth:`pytorch_lightning.core.lightning.LightningModule.training_step + wrapped with processing for its outputs + backward_fn: A function that takes a loss value as input, performs back-propagation and returns the loss value. + Can be set to ``None`` to skip the backward operation. + zero_grad_fn: A function that zeroes the gradients. Can be set to ``None`` to skip zero_grad, for example + when accumulating gradients. + profiler: A profiler for profiling the actions of the passed in closure functions. + + Example: + + closure = Closure() + optimizer = torch.optim.Adam(...) + optimizer.step(closure) + """ + + warning_cache = WarningCache() + + def __init__( + self, + step_fn: Callable[[], dict], + backward_fn: Optional[Callable[[Tensor], Tensor]] = None, + zero_grad_fn: Optional[Callable[[], None]] = None, + profiler: Optional[BaseProfiler] = None, + ): + super().__init__() + self._step_fn = step_fn + self._backward_fn = backward_fn + self._zero_grad_fn = zero_grad_fn + self._profiler = PassThroughProfiler() if profiler is None else profiler + + def closure(self, *args: Any, **kwargs: Any) -> Optional[ClosureResult]: + with self._profiler.profile("training_step_and_backward"): + step_output = self._step_fn() + step_output = ClosureResult(**step_output) if step_output else None + + if step_output is None: + self.warning_cache.warn("training_step returned None. If this was on purpose, ignore this warning...") + + if self._zero_grad_fn is not None: + with self._profiler.profile("zero_grad"): + self._zero_grad_fn() + + if self._backward_fn is not None and step_output is not None and step_output.closure_loss is not None: + with self._profiler.profile("backward"): + step_output.closure_loss = self._backward_fn(step_output.closure_loss) + + return step_output diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 89ba5cd07d4590..585cb685a7ca92 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterator, Mapping, Optional, Tuple +from typing import Any, Iterator, Mapping, Optional, Tuple import torch @@ -65,7 +65,7 @@ def _check_training_step_output(model: "pl.LightningModule", training_step_outpu def _process_training_step_output( trainer: "pl.Trainer", training_step_output: STEP_OUTPUT -) -> Tuple[Optional[ResultCollection], Optional[torch.Tensor]]: +) -> Tuple[Optional[ResultCollection], Optional[Any]]: """Adds the :param:`training_step_output` to the trainer's results Args: diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index ee2b58be106b58..c759f2aee28b7d 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -109,7 +109,7 @@ def track_load_dataloader_call(self, name: str, dataloaders: List[DataLoader]) - @enabled_only def track_train_loss_history(self, batch_idx: int, loss: torch.Tensor) -> None: - loss_dict = {"batch_idx": batch_idx, "epoch": self.trainer.current_epoch, "loss": loss.detach()} + loss_dict = {"batch_idx": batch_idx, "epoch": self.trainer.current_epoch, "loss": loss.detach().clone()} self.saved_train_losses.append(loss_dict) @enabled_only diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index f8c3317b6c5958..4068bc5504b5bf 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -20,6 +20,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.loops.closure import Closure from tests.helpers.boring_model import BoringModel @@ -221,7 +222,7 @@ def training_epoch_end(self, outputs): ... def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **_): - assert optimizer_closure.__name__ == "_training_step_and_backward_closure" + assert isinstance(optimizer_closure, Closure) # not passing the closure to the optimizer because step is mocked # zero_grad is called inside the closure if isinstance(optimizer, SGD) and batch_idx % 2 == 0: diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index 6c0b7b7ed40d7c..78dd4bce728d42 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -78,10 +78,11 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward( + opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) - assert opt_closure_result["loss"].item() == 171 + opt_closure_result = opt_closure() + assert opt_closure_result.item() == 171 def test__eval_step__eval_step_end__flow(tmpdir): @@ -144,10 +145,11 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward( + opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) - assert opt_closure_result["loss"].item() == 171 + opt_closure_result = opt_closure() + assert opt_closure_result.item() == 171 def test__eval_step__epoch_end__flow(tmpdir): diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 4ee9d858d44c95..692d2420bfd82d 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -18,6 +18,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.loops.closure import Closure from pytorch_lightning.trainer.states import RunningStage from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.deterministic_model import DeterministicModel @@ -156,10 +157,11 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward( + opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) - assert opt_closure_result["loss"].item() == 171 + opt_closure_result = opt_closure() + assert opt_closure_result.item() == 171 def test__training_step__step_end__epoch_end__flow_scalar(tmpdir): @@ -229,10 +231,11 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward( + opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) - assert opt_closure_result["loss"].item() == 171 + opt_closure_result = opt_closure() + assert opt_closure_result.item() == 171 def test_train_step_no_return(tmpdir): @@ -258,6 +261,8 @@ def validation_epoch_end(self, outputs): trainer = Trainer(**trainer_args) + Closure.warning_cache.clear() + with pytest.warns(UserWarning, match=r"training_step returned None.*"): trainer.fit(model) @@ -268,6 +273,8 @@ def validation_epoch_end(self, outputs): model.automatic_optimization = False trainer = Trainer(**trainer_args) + Closure.warning_cache.clear() + with no_warning_call(UserWarning, match=r"training_step returned None.*"): trainer.fit(model) @@ -293,6 +300,8 @@ def training_step(self, batch, batch_idx): checkpoint_callback=False, ) + Closure.warning_cache.clear() + with pytest.warns(UserWarning, match=r".*training_step returned None.*"): trainer.fit(model) From 02612e14efe9b1f92a00557ebc0def65acb14b1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Aug 2021 11:36:29 +0200 Subject: [PATCH 11/19] remove redundant iterator call to data fetcher in loops (#9117) Co-authored-by: tchaton --- .../loops/dataloader/evaluation_loop.py | 3 +-- .../loops/epoch/evaluation_epoch_loop.py | 11 ++++++----- pytorch_lightning/loops/fit_loop.py | 5 ++--- pytorch_lightning/loops/utilities.py | 12 +++++++----- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 7f06f5cd4ff638..68b75b68eb91ba 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -101,11 +101,10 @@ def advance(self, *args: Any, **kwargs: Any) -> None: dataloader_idx: int = self.current_dataloader_idx dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx) - dataloader_iter = iter(dataloader) dl_max_batches = self._max_batches[dataloader_idx] - dl_outputs = self.epoch_loop.run(dataloader_iter, dataloader_idx, dl_max_batches, self.num_dataloaders) + dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders) # store batch level output per dataloader if self.should_track_batch_outputs_for_epoch_end: diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index e4770084c84cd4..158d4cf527143b 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -21,6 +21,7 @@ from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.utilities import _prepare_dataloader_iter from pytorch_lightning.trainer.progress import Progress +from pytorch_lightning.utilities.fetching import AbstractDataFetcher from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -58,12 +59,12 @@ def reset(self) -> None: self.batch_progress.current.reset() def on_run_start( - self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int + self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int ) -> None: """Adds the passed arguments to the loop's state if necessary Args: - dataloader_iter: iterator over the dataloader + data_fetcher: the current data_fetcher wrapping the dataloader dataloader_idx: index of the current dataloader dl_max_batches: maximum number of batches the dataloader can produce num_dataloaders: the total number of dataloaders @@ -72,10 +73,10 @@ def on_run_start( self._dl_max_batches = dl_max_batches self._num_dataloaders = num_dataloaders - self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_progress.current.ready) + self.dataloader_iter = _prepare_dataloader_iter(data_fetcher, self.batch_progress.current.ready) def advance( - self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int + self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int ) -> None: """Calls the evaluation step with the corresponding hooks and updates the logger connector. @@ -88,7 +89,7 @@ def advance( Raises: StopIteration: If the current batch is None """ - void(dataloader_iter, dl_max_batches, num_dataloaders) + void(data_fetcher, dl_max_batches, num_dataloaders) batch_idx, (batch, _) = next(self.dataloader_iter) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 49af10d4b2c0d9..4f3f8951f2b7a4 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -193,12 +193,11 @@ def on_advance_start(self) -> None: def advance(self) -> None: """Runs one whole epoch.""" dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) - dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader) - dataloader_iter = iter(dataloader) + data_fetcher = self.trainer.data_connector.get_profiled_dataloader(dataloader) with self.trainer.profiler.profile("run_training_epoch"): # run train epoch - epoch_output = self.epoch_loop.run(dataloader_iter) + epoch_output = self.epoch_loop.run(data_fetcher) if epoch_output is None: return diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 585cb685a7ca92..dd69640106af8f 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -20,7 +20,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher +from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -105,9 +105,11 @@ def _process_training_step_output( return results, hiddens -def _prepare_dataloader_iter(dataloader_iter: Iterator, batch_idx: int) -> Iterator: +def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator: """Attach the dataloader""" - if not isinstance(dataloader_iter, DataLoaderIterDataFetcher): - dataloader_iter = enumerate(dataloader_iter, batch_idx) - # restore iteration + if not isinstance(data_fetcher, DataLoaderIterDataFetcher): + # restore iteration + dataloader_iter = enumerate(data_fetcher, batch_idx) + else: + dataloader_iter = iter(data_fetcher) return dataloader_iter From 5c2b7cadcdce49dc2d85f75089a254f20ba2510d Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Thu, 26 Aug 2021 15:10:59 +0530 Subject: [PATCH 12/19] Remove ci tpu test from github workflows (#8965) Co-authored-by: Jirka --- .github/workflows/ci_test-tpu.yml | 144 ------------------------------ README.md | 16 ++-- 2 files changed, 8 insertions(+), 152 deletions(-) delete mode 100644 .github/workflows/ci_test-tpu.yml diff --git a/.github/workflows/ci_test-tpu.yml b/.github/workflows/ci_test-tpu.yml deleted file mode 100644 index 22bb7bd7cd4e5f..00000000000000 --- a/.github/workflows/ci_test-tpu.yml +++ /dev/null @@ -1,144 +0,0 @@ -name: TPU tests - -on: - push: - branches: [master, "release/*"] -# TODO: temporal disable TPU testing until we find way how to pass credentials to forked PRs -# pull_request: -# branches: -# - master - -env: - GKE_CLUSTER: lightning-cluster - GKE_ZONE: us-central1-a - IMAGE: gcr.io/${{ secrets.GKE_PROJECT }}/tpu-testing-image - MAX_CHECKS: 360 - CHECK_SPEEP: 5 - -jobs: - setup-build-publish-deploy: - name: tpu-testing-job - runs-on: ubuntu-20.04 - strategy: - fail-fast: false - matrix: - python-version: [3.7] - xla-version: [1.6, 1.8] - # Timeout: https://stackoverflow.com/a/59076067/4521646 - timeout-minutes: 50 - - steps: - - name: Set IMAGETAG - run: echo "IMAGETAG=$(date +%s)_${{ matrix.python-version }}" >> $GITHUB_ENV - - name: Install Go - uses: actions/setup-go@v2 - with: - go-version: 1.14.x - - name: Set up Python 3.7 - uses: actions/setup-python@v2 - with: - python-version: 3.7 - - - name: Checkout Pytorch Lightning - uses: actions/checkout@v2 - with: - repository: PyTorchLightning/pytorch-lightning - ref: ${{ github.event.pull_request.head.sha }} - - - name: Checkout ml-testing-accelerators - uses: actions/checkout@v2 - with: - repository: GoogleCloudPlatform/ml-testing-accelerators - path: ml-testing-accelerators - ref: 5e88ac24f631c27045e62f0e8d5dfcf34e425e25 - - - name: Setup gcloud CLI - uses: GoogleCloudPlatform/github-actions/setup-gcloud@master - with: - version: '290.0.1' - service_account_key: ${{ secrets.GKE_SA_KEY_BASE64 }} - project_id: ${{ secrets.GKE_PROJECT }} - export_default_credentials: true - - # Configure Docker to use the gcloud command-line tool as a credential helper for authentication. - - name: Configure Docker - run: |- - gcloud --quiet auth configure-docker - shell: bash - - name: Build and Push Docker Image - env: - PYTHON_VER: ${{ matrix.python-version }} - XLA_VER: ${{ matrix.xla-version }} - run: | - #cd dockers/tpu-tests - docker build --tag "$IMAGE:$IMAGETAG" -f ./dockers/tpu-tests/Dockerfile --build-arg "PYTHON_VERSION=$PYTHON_VER" --build-arg "PYTORCH_VERSION=$XLA_VER" . - docker push "$IMAGE:$IMAGETAG" - shell: bash - - - name: Install jsonnet - run: |- - go get github.com/google/go-jsonnet/cmd/jsonnet - shell: bash - # Get the GKE credentials so we can deploy to the cluster - # Use either zone or region depending on cluster setup. - - run: |- - gcloud container clusters get-credentials "$GKE_CLUSTER" --zone "$GKE_ZONE" - shell: bash - - - name: Deploy the job on the kubernetes cluster - env: - XLA_VER: ${{ matrix.xla-version }} - run: |- - python -c "fname = 'dockers/tpu-tests/tpu_test_cases.jsonnet' ; ttt = open(fname).read().replace('pytorch-VERSION', 'pytorch-$XLA_VER') ; open(fname, 'w').write(ttt)" - job_name=$(jsonnet -J ml-testing-accelerators/ dockers/tpu-tests/tpu_test_cases.jsonnet --ext-str image=$IMAGE --ext-str image-tag=$IMAGETAG | kubectl create -f -) && \ - job_name=${job_name#job.batch/} && \ - job_name=${job_name% created} && \ - echo "Waiting on kubernetes job: $job_name in cluster: $GKE_CLUSTER" && \ - i=0 && \ - # 60 checks spaced 30s apart = 900s total. - status_code=2 && \ - # Check on the job periodically. Set the status code depending on what - # happened to the job in Kubernetes. If we try MAX_CHECKS times and - # still the job hasn't finished, give up and return the starting - # non-zero status code. - printf "Waiting for job to finish: " && \ - while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else printf "." ; fi; sleep $CHECK_SPEEP; done && \ - echo "Done waiting. Job status code: $status_code" && \ - pod_name=$(kubectl get po -l controller-uid=`kubectl get job $job_name -o "jsonpath={.metadata.labels.controller-uid}"` | awk 'match($0,!/NAME/) {print $1}') && \ - echo "GKE pod name: $pod_name" && \ - kubectl logs -f $pod_name --container=train > /tmp/full_output.txt - if grep -q '' /tmp/full_output.txt ; then csplit /tmp/full_output.txt '//'; else mv /tmp/full_output.txt xx00; fi && \ - # First portion is the test logs. Print these to Github Action stdout. - cat xx00 && \ - echo "Done with log retrieval attempt." && \ - gcloud container images delete "$IMAGE:$IMAGETAG" --force-delete-tags && \ - echo "Status code: $status_code" - exit $status_code - shell: bash - - - name: Statistics - if: success() - run: | - mv ./xx01 coverage - # TODO: add human readable report - cat coverage - # sudo pip install pycobertura - # pycobertura show coverage.xml - - - name: Upload coverage results - uses: actions/upload-artifact@v2 - with: - name: coverage-TPU - path: coverage - - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v1 - # see: https://github.com/actions/toolkit/issues/399 - continue-on-error: true - if: always() - with: - token: ${{ secrets.CODECOV_TOKEN }} - file: coverage - flags: tpu,pytest - name: TPU-coverage - fail_ci_if_error: true diff --git a/README.md b/README.md index d31d6528504584..c80995293ae9d6 100644 --- a/README.md +++ b/README.md @@ -78,14 +78,14 @@ Lightning is rigorously tested across multiple GPUs, TPUs CPUs and against major
-| System / PyTorch ver. | 1.6 (min. req.) | 1.7 | 1.8 (LTS) | 1.9 (latest) | 1.10 (nightly) | -| :----------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | -| Conda py3.7 \[linux\] | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | -| Linux py3.7 \[GPUs\*\*\] | - | - | [![Build Status]()](https://dev.azure.com/PytorchLightning/pytorch-lightning/_build/latest?definitionId=6&branchName=master) | - | - | -| Linux py3.{6,7} \[TPUs\*\*\*\] | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) | - | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) | - | - | -| Linux py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | -| OSX py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | -| Windows py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | +| System / PyTorch ver. | 1.6 (min. req.) | 1.7 | 1.8 (LTS) | 1.9 (latest) | 1.10 (nightly) | +| :------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| Conda py3.7 \[linux\] | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | +| Linux py3.7 \[GPUs\*\*\] | - | - | [![Build Status]()](https://dev.azure.com/PytorchLightning/pytorch-lightning/_build/latest?definitionId=6&branchName=master) | - | - | +| Linux py3.7 \[TPUs\*\*\*\] | - | - | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning/tree/master.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning/tree/master) | - | - | +| Linux py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | +| OSX py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | +| Windows py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - _\*\* tests run on two NVIDIA P100_ - _\*\*\* tests run on Google GKE TPUv2/3_ From 8efdeb2c00888a98d2571eec96767aabc32eae1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Aug 2021 12:28:14 +0200 Subject: [PATCH 13/19] deprecate the TestTubeLogger (#9065) Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> --- CHANGELOG.md | 5 ++++- pytorch_lightning/loggers/test_tube.py | 10 +++++++++- tests/deprecated_api/test_remove_1-7.py | 8 ++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aefe4c322212be..3d8a9b440c74df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -151,7 +151,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851)) -- Deprecated `prepare_data_per_node` flag on Trainer and set it as a property of `DataHooks`, accessible in the `LightningModule` and `LightningDataModule` [#8958](https://github.com/PyTorchLightning/pytorch-lightning/pull/8958) +- Deprecated `prepare_data_per_node` flag on Trainer and set it as a property of `DataHooks`, accessible in the `LightningModule` and `LightningDataModule` ([#8958](https://github.com/PyTorchLightning/pytorch-lightning/pull/8958)) + + +- Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065)) ### Removed diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index bbe4897b47528b..9a3b6ccee64df3 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment -from pytorch_lightning.utilities import _module_available, rank_zero_warn +from pytorch_lightning.utilities import _module_available, rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.distributed import rank_zero_only _TESTTUBE_AVAILABLE = _module_available("test_tube") @@ -36,6 +36,10 @@ class TestTubeLogger(LightningLoggerBase): Log to local file system in `TensorBoard `_ format but using a nicer folder structure (see `full docs `_). + Warning: + The test-tube package is no longer maintained and PyTorch Lightning will remove the :class:´TestTubeLogger´ + in v1.7.0. + Install it with pip: .. code-block:: bash @@ -97,6 +101,10 @@ def __init__( log_graph: bool = False, prefix: str = "", ): + rank_zero_deprecation( + "The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7. We recommend switching to the" + " `pytorch_lightning.loggers.TensorBoardLogger` as an alternative." + ) if Experiment is None: raise ImportError( "You want to use `test_tube` logger which is not installed yet," diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 7581bf2b0c142d..ae8f9e1dcc53db 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Test deprecated functionality which will be removed in v1.7.0 """ +from unittest import mock import pytest from pytorch_lightning import LightningDataModule, Trainer +from pytorch_lightning.loggers import TestTubeLogger from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel from tests.helpers.datamodules import MNISTDataModule @@ -87,3 +89,9 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir): match="Setting `prepare_data_per_node` with the trainer flag is deprecated and will be removed in v1.7.0!" ): _ = Trainer(prepare_data_per_node=False) + + +@mock.patch("pytorch_lightning.loggers.test_tube.Experiment") +def test_v1_7_0_test_tube_logger(_, tmpdir): + with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"): + _ = TestTubeLogger(tmpdir) From 0752bcd0ebab4f29540b57fd0a77085152357527 Mon Sep 17 00:00:00 2001 From: Burhanuddin Rangwala Date: Thu, 26 Aug 2021 16:31:46 +0530 Subject: [PATCH 14/19] Added doc strings for Comet logger (#9114) --- pytorch_lightning/loggers/comet.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 1799def445ea4a..81084d61b6bd25 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -268,10 +268,22 @@ def finalize(self, status: str) -> None: @property def save_dir(self) -> Optional[str]: + """ + Gets the save directory. + + Returns: + The path to the save directory. + """ return self._save_dir @property def name(self) -> str: + """ + Gets the project name. + + Returns: + The project name if it is specified, else "comet-default". + """ # Don't create an experiment if we don't have one if self._experiment is not None and self._experiment.project_name is not None: return self._experiment.project_name @@ -283,6 +295,19 @@ def name(self) -> str: @property def version(self) -> str: + """ + Gets the version. + + Returns: + The first one of the following that is set in the following order + + 1. experiment id. + 2. experiment key. + 3. "COMET_EXPERIMENT_KEY" environment variable. + 4. future experiment key. + + If none are present generates a new guid. + """ # Don't create an experiment if we don't have one if self._experiment is not None: return self._experiment.id From b13749b4ec0931f16204d433f44b1e7e0775689b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Aug 2021 14:13:31 +0200 Subject: [PATCH 15/19] add fault-tolerance for global random state in map-style datasets (#8950) Co-authored-by: tchaton Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Co-authored-by: Carlos Mocholi Co-authored-by: Jirka Borovec --- CHANGELOG.md | 1 + pytorch_lightning/loops/fit_loop.py | 18 +- pytorch_lightning/trainer/supporters.py | 151 +++++++----- pytorch_lightning/utilities/auto_restart.py | 38 ++- tests/loops/test_loop_state_dict.py | 12 +- tests/loops/test_loops.py | 8 +- tests/utilities/test_auto_restart.py | 246 +++++++++++++++++++- 7 files changed, 402 insertions(+), 72 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d8a9b440c74df..36d90ae213fdb9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891)) * Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891)) * Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953)) + * Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950)) - Checkpoint saving & loading extensibility: * Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743)) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 4f3f8951f2b7a4..4a09c0ca1faebd 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -14,7 +14,7 @@ import logging from contextlib import suppress -from typing import Optional +from typing import Any, Dict, Optional from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop @@ -40,6 +40,8 @@ def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = self.min_epochs = min_epochs self.epoch_loop: Optional[TrainingEpochLoop] = None self.epoch_progress = Progress() + # caches the loaded dataloader state until dataloader objects are available + self._dataloader_state_dict: Dict[str, Any] = {} @property def current_epoch(self) -> int: @@ -175,6 +177,10 @@ def on_advance_start(self) -> None: if self.current_epoch != 0 and self.trainer._should_reload_dl_epoch: self.trainer.reset_train_dataloader(model) + if self._dataloader_state_dict: + self.trainer.train_dataloader.load_state_dict(self._dataloader_state_dict) + self._dataloader_state_dict = {} + # TODO: specify the possible exception with suppress(Exception): # set seed for distributed sampler (enables shuffling for each epoch) @@ -234,3 +240,13 @@ def should_accumulate(self) -> bool: def teardown(self) -> None: self.epoch_loop.teardown() + + def on_save_checkpoint(self) -> Dict: + state_dict = super().on_save_checkpoint() + # FIXME(@tchaton) Should pass has_completed=True when iterator is exhausted ? + state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False) + return state_dict + + def on_load_checkpoint(self, state_dict: Dict) -> None: + # cache the dataloader state dict until the dataloader objects are available + self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {}) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 0e747a9e4857d8..4c0ddddd2c234f 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -13,20 +13,23 @@ # limitations under the License. from collections.abc import Iterable, Iterator, Mapping, Sequence -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from functools import partial from typing import Any, Callable, Dict, List, Optional, Union import torch from torch.utils.data import Dataset -from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader +from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader from torch.utils.data.dataset import IterableDataset from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.auto_restart import ( - _cycle_to_next_worker_and_reset, - _find_current_worker, + _find_fast_forward_samplers, CaptureIterableDataset, + CaptureMapDataset, + IteratorState, + MergedIteratorState, + patch_dataloader_iterator, ) from pytorch_lightning.utilities.data import get_len from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -167,6 +170,7 @@ def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycle self.loader = loader self._loader_iter = None self.counter = 0 + self.state = state def __iter__(self) -> Any: """ @@ -176,6 +180,7 @@ def __iter__(self) -> Any: CycleIterator: self """ self.counter = 0 + self.state.reset() self._loader_iter = iter(self.loader) return self @@ -205,6 +210,12 @@ def __next__(self) -> Any: raise StopIteration self._loader_iter = iter(self.loader) + # if fault tolerant is enabled, we need to patch the iterator to collect the states + # before the batch gets returned. + fetcher = getattr(self.loader, "_lightning_fetcher", None) + if fetcher: + patch_dataloader_iterator(self.loader, self._loader_iter, fetcher) + return next(self._loader_iter) finally: @@ -302,11 +313,6 @@ def __len__(self) -> int: return self._calc_num_data(self.datasets, self.mode) -class DataLoaderDict(Dict): - # behaves exactly like a dict, this is used to simplify apply_to_collection. - pass - - class CombinedLoader: """ Combines different dataloaders and allows sampling in parallel. @@ -360,80 +366,110 @@ def __init__(self, loaders: Any, mode: str = "min_size"): self._iterator = None # assigned in __iter__ @staticmethod - def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], num_batches_processed: int) -> Dict: - # find next worker if multiple workers were used - state = _find_current_worker(iterator) - if isinstance(dataloader.dataset, CaptureIterableDataset): - # the sampler state dict are extracted in `CombinedLoaderIterator` - if iterator is not None and getattr(iterator, "_sampler_state_dict", None) is not None: - state.update(iterator._sampler_state_dict[0]) - else: - # fetch directly from fast forward sampler - state.update(dataloader.fast_forward_sampler.state_dict(num_batches_processed)) - return DataLoaderDict(state) - - def state_dict(self, num_batches_processed: int) -> Dict: + def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], has_completed: int) -> Dict: + if isinstance(dataloader, CycleIterator): + iterator = dataloader._loader_iter + state = getattr(iterator, "state", None) if has_completed else getattr(iterator, "previous_state", None) + if state: + return asdict(state) + return {} + + def state_dict(self, has_completed: bool = False) -> Dict: """ The state dict includes all states from wrapped dataloaders and their samplers through the ``CaptureIterableDataset`` and fast-forward samplers. Args: - num_batches_processed: The number of batches processed so far, needed because the individual dataloaders - may have already prefetched more batches by the time a state dict is requested. + has_completed: whether the current state of data fetching is considered completed or not. If it is, the + current state gets returned, otherwise the previously cached state. """ - if not _fault_tolerant_training(): - return DataLoaderDict() - - state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed) + if not _fault_tolerant_training() or self._iterator is None: + return {} - return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn) + return apply_to_collections( + self.loaders, + self._iterator.loader_iters, + (Iterator, DataLoader), + partial(self._state_dict_fn, has_completed=has_completed), + ) - def load_state_dict(self, state_dict): + def load_state_dict(self, state_dict) -> None: # store the samplers state. # They would be reloaded once the `CombinedIterator` as been created # and the workers are created. self._loaders_iter_state_dict = state_dict - def mock_reset_fn(self, *_, **__): - pass - - # mock reset call, so we can rotate the `_worker_queue_idx_cycle` to failed worker - # and get the first batch from it - _MultiProcessingDataLoaderIter._original_reset = _MultiProcessingDataLoaderIter._reset - _MultiProcessingDataLoaderIter._reset = mock_reset_fn - - def on_restart(self, iterator: Iterator): + def on_restart(self, iterator: Iterator) -> None: if not self._loaders_iter_state_dict: return - # this happen inside the workers if any were specificied. + def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator: + """Function used to reload the iterator state before once the workers are created.""" + + dataloader_to_iter_on = dataloader + if isinstance(dataloader, CycleIterator): + dataloader = dataloader_to_iter_on.loader + + dataset = dataloader.dataset + + # We reload the states before creating the workers + # The specific type of dataset will then decide if the state should be applied before or after + # spawning the workers + if isinstance(dataset, CaptureMapDataset): + iterator_state = state_dict["state"][0] + + if not isinstance(iterator_state, IteratorState): + iterator_state = IteratorState.from_state_dict(iterator_state) + + # reload sampler state + ff_sampler = _find_fast_forward_samplers(dataloader) + ff_sampler.load_state_dict(iterator_state.sampler_state) + # reload dataset state + dataset.load_state_dict( + iterator_state.dataset_state, + latest_worker_id=state_dict["latest_worker_id"], + num_workers=iterator_state.num_workers, + ) + + elif isinstance(dataset, CaptureIterableDataset): + dataset_dict = { + sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items() + } + dataset.load_state_dict(dataset_dict) - def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict): - if isinstance(dataloader.dataset, CaptureIterableDataset): - # provide the `state_dict` to the `CaptureIterableDataset` - # as it is responsible for passing down the state to associated `FastForwardSampler` - dataloader.dataset.load_state_dict(state_dict) else: - # for `Mapping-based` dataset, the `fast_forward_sampler` was attached - # on the dataloader for simplicity - dataloader.fast_forward_sampler.load_state_dict(state_dict) + raise MisconfigurationException( + "This shouldn't happen. Please, open an issue on PyTorch Lightning Github." + ) + + # We finally spawned the workers if any. + it = iter(dataloader_to_iter_on) - # cycle back the iterator to the failed worker if multiple workers were provided - iterator = _cycle_to_next_worker_and_reset(dataloader, state_dict) + # restore caching state + state = MergedIteratorState.from_state_dict(state_dict) - if isinstance(dataloader.dataset, CaptureIterableDataset): - # remove keys related to iterator - state_dict = {k: v for k, v in state_dict.items() if k not in ("num_worker", "previous_worker")} - # need to re-attach the state dict into the iterator for future collection. - iterator._sampler_state_dict = [state_dict] - return iterator + if isinstance(dataloader_to_iter_on, CycleIterator): + it._loader_iter.state = state + else: + it.state = state + return it + + # create an un-existing token, so it doesn't activate for something else than an iterator. + class DataLoaderDict(dict): + pass # apply the `create_loader_iters` on the collection of `DataLoader / Iterator`. # each `Iterator` was created from the `DataLoader`. iterator._loader_iters = apply_to_collections( - self.loaders, self._loaders_iter_state_dict, (DataLoader, DataLoaderDict), create_loader_iters + self.loaders, + self._loaders_iter_state_dict, + (Iterable, DataLoaderDict), + create_loader_iters, + wrong_dtype=(Sequence, Mapping), ) + self._loaders_iter_state_dict = None + @property def sampler(self) -> Union[Iterable, Sequence, Mapping]: """Return a collections of samplers extracting from loaders.""" @@ -457,7 +493,6 @@ def _wrap_loaders_max_size_cycle(self) -> Any: self.loaders = apply_to_collection( self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping) ) - state.reset() def __iter__(self) -> Any: diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index c9e378dbadee6e..256168ae4382f3 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -16,8 +16,12 @@ from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps +from random import getstate as python_get_rng_state +from random import setstate as python_set_rng_state from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union +import numpy as np +import torch from torch.utils.data import Dataset, get_worker_info, Sampler from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset @@ -168,6 +172,16 @@ def update(self, generator_name: Optional[str], new_state: IteratorState) -> Non state[latest_worker_id] = new_state self.latest_worker_id = latest_worker_id + @property + def sampler_states(self) -> Dict[int, Any]: + """Returns the merged sampler states for all worker processes.""" + return {0: self.state[k].sampler_state[0] for k in self.state.keys()} + + @property + def dataset_states(self) -> Dict[int, Any]: + """Returns the merged dataset states for all worker processes.""" + return {k: self.state[k].dataset_state[k] for k in self.state.keys()} + @classmethod def from_state_dict(cls, state_dict) -> "MergedIteratorState": if state_dict["represent_map_dataset"]: @@ -188,7 +202,12 @@ def __len__(self) -> int: class CaptureMapDataset(Dataset): - """This class is used to capture the state from the map-based state dataset.""" + """This class is used to capture the state from the map-based state dataset. + + Note: + We currently don't support restoring if we fail during the first `N = num_workers` batches, where + `num_workers` is the number of workers spawned by the dataloader. + """ def __init__(self, dataset: Dataset) -> None: self.dataset = dataset @@ -202,8 +221,7 @@ def worker_id(self) -> int: def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]: if self._cached_state_dict is not None: if self.worker_id in self._cached_state_dict: - # TODO: reset random states - pass + set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"]) self._cached_state_dict = None data = self.dataset[item] @@ -227,7 +245,19 @@ def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num self._cached_state_dict = state_dict def _state_dict(self) -> Dict[int, Dict[str, Any]]: - return {self.worker_id: {"rng_states": {}}} + return {self.worker_id: {"rng_states": collect_rng_states()}} + + +def collect_rng_states() -> Dict[str, Any]: + """Collect the global random state of :mod:`torch`, :mod:`numpy` and Python.""" + return {"torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": python_get_rng_state()} + + +def set_rng_states(rng_state_dict: Dict[str, Any]) -> None: + """Set the global random state of :mod:`torch`, :mod:`numpy` and Python in the current process.""" + torch.set_rng_state(rng_state_dict.get("torch")) + np.random.set_state(rng_state_dict.get("numpy")) + python_set_rng_state(rng_state_dict.get("python")) class CaptureIterableDataset(IterableDataset): diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 22d2be8c3a9b06..5fdd09d5fd4d48 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest.mock import ANY, Mock import pytest import torch @@ -22,23 +22,31 @@ def test_loops_state_dict(): + trainer = Trainer() + trainer.train_dataloader = Mock() + fit_loop = FitLoop() with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"): fit_loop.trainer = object() + fit_loop.trainer = trainer fit_loop.connect(Mock()) state_dict = fit_loop.state_dict() + new_fit_loop = FitLoop() + new_fit_loop.trainer = trainer + new_fit_loop.load_state_dict(state_dict) assert fit_loop.state_dict() == new_fit_loop.state_dict() def test_loops_state_dict_structure(): trainer = Trainer() + trainer.train_dataloader = Mock() state_dict = trainer.checkpoint_connector._get_loops_state_dict() expected = { "fit_loop": { - "state_dict": {}, + "state_dict": {"dataloader_state_dict": ANY}, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 65cbebc8203e53..200b2daae93eda 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -504,7 +504,13 @@ def configure_optimizers_multiple(self): assert checkpoint["loops"]["fit_loop"] == expected trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False) - assert trainer.fit_loop.state_dict() == checkpoint["loops"]["fit_loop"] + state_dict = trainer.fit_loop.state_dict() + + # need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the + # fit loop to have an iterator, which is only available during training + checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = ANY + + assert state_dict == checkpoint["loops"]["fit_loop"] trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"]) state_dict = trainer.fit_loop.state_dict() diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index e665fc79e4323f..ca5b9084591712 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -14,9 +14,13 @@ import math import os import random +import random as python_random from collections.abc import Iterable -from typing import Optional +from contextlib import suppress +from copy import deepcopy +from typing import List, Optional from unittest import mock +from unittest.mock import ANY import numpy as np import pytest @@ -29,16 +33,19 @@ from torch.utils.data.dataset import Dataset, IterableDataset import tests.helpers.utils as tutils -from pytorch_lightning import Callback, seed_everything, Trainer +from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer from pytorch_lightning.utilities.auto_restart import ( _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, CaptureIterableDataset, + CaptureMapDataset, FastForwardSampler, + MergedIteratorState, ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.fetching import DataFetcher from pytorch_lightning.utilities.imports import _fault_tolerant_training from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -671,7 +678,11 @@ def create_dataloader(): _ = next(iter_dataloader) state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader) - assert state_dict[0]["current_iteration"] == 16 + assert state_dict == { + "num_workers": 0, + "previous_worker": None, + 0: {"current_iteration": 16}, + } dataloader = create_dataloader() dataloader = _dataloader_load_state_dict(dataloader, state_dict) @@ -679,14 +690,18 @@ def create_dataloader(): _ = next(iter_dataloader) state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader) - assert state_dict[0]["current_iteration"] == 24 + assert state_dict == { + "num_workers": 0, + "previous_worker": None, + 0: {"current_iteration": 24}, + } @RunIf(min_torch="1.7.0") @pytest.mark.parametrize("use_fault_tolerant", ["0", "1"]) def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir): """ - this test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled. + This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled. """ class CustomBatchSampler(BatchSampler): @@ -713,7 +728,15 @@ def train_dataloader(self): } def training_step(self, batch, batch_idx): - pass + assert batch == { + "a": [ANY, ANY, ANY], + "b": ANY, + } + + def validation_step(self, batch, batch_idx): + assert isinstance(batch, torch.Tensor) + + validation_epoch_end = None class Check(Callback): def on_train_batch_start(self, trainer, *_) -> None: @@ -721,12 +744,16 @@ def on_train_batch_start(self, trainer, *_) -> None: if use_fault_tolerant == "1": assert isinstance(loaders["a"][0].loader.dataset, CaptureIterableDataset) assert isinstance(loaders["a"][1].loader.sampler, FastForwardSampler) + assert isinstance(loaders["a"][1].loader.dataset, CaptureMapDataset) assert isinstance(loaders["a"][2].loader.batch_sampler, FastForwardSampler) + assert isinstance(loaders["a"][2].loader.dataset, CaptureMapDataset) assert isinstance(loaders["b"].loader.dataset, CaptureIterableDataset) else: assert isinstance(loaders["a"][0].loader.dataset, RangeIterableDataset) assert isinstance(loaders["a"][1].loader.sampler, SequentialSampler) + assert not isinstance(loaders["a"][1].loader.dataset, CaptureMapDataset) assert isinstance(loaders["a"][2].loader.batch_sampler, CustomBatchSampler) + assert not isinstance(loaders["a"][2].loader.dataset, CaptureMapDataset) assert isinstance(loaders["b"].loader.dataset, RangeIterableDataset) with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": use_fault_tolerant}): @@ -734,3 +761,210 @@ def on_train_batch_start(self, trainer, *_) -> None: model.training_epoch_end = None trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, callbacks=Check()) trainer.fit(model) + + +class SequentialGetItemDataset(Dataset): + def __init__(self, length, *_): + self.len = length + + def __getitem__(self, index): + return torch.tensor([index]).float() + + def __len__(self): + return self.len + + +class RandomGetItemDataset(Dataset): + """A dataset with random elements generated using global rng from torch, numpy and python.""" + + def __init__(self, length, size): + self.size = size + self.len = length + + def __getitem__(self, index): + t = torch.rand(self.size) + n = torch.from_numpy(np.random.rand(self.size)) + p = torch.tensor([python_random.random() for _ in range(self.size)]) + sample = (index + (t + n + p) / 10).float() + return sample + + def __len__(self): + return self.len + + +# TODO: test with `RandomGeneratorGetItemDataset` +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@RunIf(min_torch="1.7.0") +@pytest.mark.parametrize( + "dataset_class", + [ + SequentialGetItemDataset, + RandomGetItemDataset, + # RandomGeneratorGetItemDataset, + ], +) +@pytest.mark.parametrize("num_workers", [0]) +@pytest.mark.parametrize("batch_size", [1, 2, 3]) +def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size): + """Test that the sequence of batches coming from a random number generator continues with the correct sequence + after reloading the state. + """ + + def create_dataset_sampler(): + dset = CaptureMapDataset(dataset_class(16, 8)) + random_sampler = RandomSampler(dset, generator=torch.Generator()) + return dset, random_sampler + + def create_dataloader_sampler(dset, sampler): + sampler = FastForwardSampler(sampler) + sampler.setup(batch_size) + dl = DataLoader(dset, num_workers=num_workers, sampler=sampler, batch_size=batch_size) + _add_capture_metadata_collate(dl) + return dl, sampler + + def fetch(fetcher, prefetch_iter, num_batches_fetched): + batch, _ = next(prefetch_iter) + + state: List[MergedIteratorState] = fetcher.state + assert len(state) == 1 + assert isinstance(state[0], MergedIteratorState) + + assert len(fetcher.dataloader_iter.cache_states) == 1 + if num_workers == 0: + assert state[0].state[0].num_batches_fetched == num_batches_fetched + return state + + dataset, random_sampler = create_dataset_sampler() + dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler) + + fetcher = DataFetcher() + fetcher.setup(dataloader) + prefetch_iter = iter(fetcher) + + # fetch 4 batches + fetch(fetcher, prefetch_iter, 1) + fetch(fetcher, prefetch_iter, 2) + fetch(fetcher, prefetch_iter, 3) + + # (A) capture the state after fetching 4 batches + state = fetch(fetcher, prefetch_iter, 4) + state = deepcopy(state[0]) + + # (B) simulate 2 additional batches + batch05, _ = next(prefetch_iter) + batch06, _ = next(prefetch_iter) + + # start reloading + dataset, random_sampler = create_dataset_sampler() + dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler) + + # load the state dict saved at (A) + ff_sampler.load_state_dict(state.sampler_states) + dataset.load_state_dict(state.dataset_states, latest_worker_id=state.latest_worker_id, num_workers=num_workers) + + prefetcher = DataFetcher() + prefetcher.setup(dataloader) + prefetch_iter = iter(prefetcher) + + # fetch 2 random batches, these should match exactly the batches seen at (B) + batch05_restart, _ = next(prefetch_iter) + batch06_restart, _ = next(prefetch_iter) + + assert torch.equal(batch05, batch05_restart) + assert torch.equal(batch06, batch06_restart) + + +class CustomException(Exception): + pass + + +class SequentialIterableDataset(IterableDataset): + def __init__(self, length, *_): + self.len = length + self.sampler = SequentialSampler(range(self.len)) + + def __iter__(self): + self.sampler_iter = iter(self.sampler) + return self + + def __next__(self): + indice = next(self.sampler_iter) + return torch.tensor([indice]).float() + + +class TestModel(LightningModule): + def __init__(self, fail_on_step: int = -1): + super().__init__() + self.layer = torch.nn.Linear(1, 2) + self.seen_batches = [] + self.fail_on_step = fail_on_step + + def training_step(self, batch, batch_idx): + if self.global_step == self.fail_on_step: + raise CustomException() + self.seen_batches.append(torch.stack(batch) if isinstance(batch, list) else batch) + loss = sum(self.layer(b).sum() for b in batch) + return loss + + def configure_optimizers(self): + return torch.optim.SGD(self.layer.parameters(), lr=0.1) + + +def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1): + seed_everything(1) + train_dataloader = [ + DataLoader(dataset_class(3, 1), batch_size=1, num_workers=0) for dataset_class in dataset_classes + ] + train_dataloader = train_dataloader[0] if len(train_dataloader) == 1 else train_dataloader + model = TestModel(fail_on_step=fail_on_step) + trainer = Trainer(**trainer_kwargs) + with suppress(CustomException): + trainer.fit(model, train_dataloader=train_dataloader) + return model.seen_batches + + +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@RunIf(min_torch="1.7.0") +@pytest.mark.parametrize( + "dataset_classes", + [ + # single training dataset + [RandomGetItemDataset], + [SequentialIterableDataset], + # multiple training datasets (combinded dataloader) + [SequentialGetItemDataset, SequentialIterableDataset], + [SequentialIterableDataset, SequentialIterableDataset], + # [RandomGetItemDataset, RandomGetItemDataset], # TODO: support in the future + ], +) +@pytest.mark.parametrize("multiple_trainloader_mode", ["min_size", "max_size_cycle"]) +def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, multiple_trainloader_mode): + """Test that the Trainer can resume from a failed run in the case of several types of datasets.""" + trainer_kwargs = dict( + default_root_dir=tmpdir, + max_epochs=3, + weights_summary=None, + progress_bar_refresh_rate=0, + multiple_trainloader_mode=multiple_trainloader_mode, + ) + + all_batches = _run_training(trainer_kwargs, dataset_classes) + all_batches = torch.stack(all_batches) + assert len(all_batches) == 9 + + # Simulate 1st failure + complete_batches = _run_training(trainer_kwargs, dataset_classes, fail_on_step=4) + assert len(complete_batches) == 4 + + checkpoint_path = os.path.join(tmpdir, ".pl_auto_save.ckpt") + assert os.path.exists(checkpoint_path) + + # Resume after failure + trainer_kwargs.update(resume_from_checkpoint=checkpoint_path) + resumed_batches = _run_training(trainer_kwargs, dataset_classes, fail_on_step=-1) + assert len(resumed_batches) == 5 + + # the resumed batches should match the batches of the successful training + all_batches_resumed = torch.stack(complete_batches + resumed_batches) + assert len(all_batches_resumed) == 9 + assert torch.equal(all_batches, all_batches_resumed) From b576201a3db3e69f29b7960a18a21b4bb1c73d41 Mon Sep 17 00:00:00 2001 From: Burhanuddin Rangwala Date: Thu, 26 Aug 2021 20:31:42 +0530 Subject: [PATCH 16/19] Added doc strings to wandb logger (#9109) --- pytorch_lightning/loggers/wandb.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 0299cb522e59fa..7081e95d352a3d 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -219,15 +219,33 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> @property def save_dir(self) -> Optional[str]: + """ + Gets the save directory. + + Returns: + The path to the save directory. + """ return self._save_dir @property def name(self) -> Optional[str]: + """ + Gets the name of the experiment. + + Returns: + The name of the experiment if the experiment exists else the name given to the constructor. + """ # don't create an experiment if we don't have one return self._experiment.project_name() if self._experiment else self._name @property def version(self) -> Optional[str]: + """ + Gets the id of the experiment. + + Returns: + The id of the experiment if the experiment exists else the id given to the constructor. + """ # don't create an experiment if we don't have one return self._experiment.id if self._experiment else self._id From dfffb94b3ce59e40d6eaa6b10a7b560545e9c2a6 Mon Sep 17 00:00:00 2001 From: Eric Wiener Date: Thu, 26 Aug 2021 11:02:42 -0400 Subject: [PATCH 17/19] Move predictions to CPU before accumulating (#9085) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli Co-authored-by: Adrian Wälchli --- pytorch_lightning/loops/epoch/prediction_epoch_loop.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index f63bb4877e7458..73557f71ade73f 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -1,11 +1,13 @@ from collections import OrderedDict from typing import Any, Dict, Iterator, List, Optional, Tuple +import torch from deprecate import void from pytorch_lightning.loops.base import Loop from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper from pytorch_lightning.trainer.progress import Progress +from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.warnings import WarningCache @@ -140,7 +142,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None self.batch_progress.increment_completed() if self.should_store_predictions: - self.predictions.append(predictions) + self.predictions.append(move_data_to_device(predictions, torch.device("cpu"))) def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Any]: """ From 53885afc2e233df1c9856ebc75da6f85e0a1377e Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Thu, 26 Aug 2021 18:36:22 +0200 Subject: [PATCH 18/19] Fix mypy typing for `utilities.apply_func` (#8781) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Co-authored-by: Carlos Mocholi --- pyproject.toml | 1 + pytorch_lightning/utilities/apply_func.py | 52 +++++++++++++---------- 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 206a4717a1b82b..e848464c9d2faf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ module = [ "pytorch_lightning.loops.closure", "pytorch_lightning.trainer.evaluation_loop", "pytorch_lightning.trainer.connectors.logger_connector", + "pytorch_lightning.utilities.apply_func", "pytorch_lightning.utilities.argparse", "pytorch_lightning.utilities.cli", "pytorch_lightning.utilities.cloud_io", diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index b96a0110e58fa3..d7d09251f60870 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -18,7 +18,7 @@ from collections.abc import Mapping, Sequence from copy import copy from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -35,19 +35,23 @@ Batch = type(None) -def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None): +def to_dtype_tensor( + value: Union[int, float, List[Union[int, float]]], + dtype: Optional[torch.dtype] = None, + device: Union[str, torch.device] = None, +) -> torch.Tensor: if device is None: raise MisconfigurationException("device (torch.device) should be provided.") return torch.tensor(value, dtype=dtype, device=device) -def from_numpy(value, device: torch.device = None): +def from_numpy(value: np.ndarray, device: Union[str, torch.device] = None) -> torch.Tensor: if device is None: raise MisconfigurationException("device (torch.device) should be provided.") return torch.from_numpy(value).to(device) -CONVERSION_DTYPES = [ +CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any], torch.Tensor]]] = [ # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group (bool, partial(to_dtype_tensor, dtype=torch.uint8)), (int, partial(to_dtype_tensor, dtype=torch.int)), @@ -61,19 +65,19 @@ def _is_namedtuple(obj: object) -> bool: return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") -def _is_dataclass_instance(obj): +def _is_dataclass_instance(obj: object) -> bool: # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions return dataclasses.is_dataclass(obj) and not isinstance(obj, type) def apply_to_collection( data: Any, - dtype: Union[type, tuple], + dtype: Union[type, Any, Tuple[Union[type, Any]]], function: Callable, - *args, - wrong_dtype: Optional[Union[type, tuple]] = None, + *args: Any, + wrong_dtype: Optional[Union[type, Tuple[type]]] = None, include_none: bool = True, - **kwargs + **kwargs: Any, ) -> Any: """ Recursively applies a function to all elements of a certain dtype. @@ -121,7 +125,7 @@ def apply_to_collection( return elem_type(*out) if is_namedtuple else elem_type(out) if _is_dataclass_instance(data): - out = {} + out_dict = {} for field in data.__dataclass_fields__: v = apply_to_collection( getattr(data, field), @@ -130,11 +134,11 @@ def apply_to_collection( *args, wrong_dtype=wrong_dtype, include_none=include_none, - **kwargs + **kwargs, ) if include_none or v is not None: - out[field] = v - return elem_type(**out) + out_dict[field] = v + return elem_type(**out_dict) # data is neither of dtype, nor a collection return data @@ -143,11 +147,11 @@ def apply_to_collection( def apply_to_collections( data1: Optional[Any], data2: Optional[Any], - dtype: Union[type, tuple], + dtype: Union[type, Any, Tuple[Union[type, Any]]], function: Callable, - *args, - wrong_dtype: Optional[Union[type, tuple]] = None, - **kwargs + *args: Any, + wrong_dtype: Optional[Union[type, Tuple[type]]] = None, + **kwargs: Any, ) -> Any: """ Zips two collections and applies a function to their items of a certain dtype. @@ -169,7 +173,9 @@ def apply_to_collections( AssertionError: If sequence collections have different data sizes. """ - if data1 is None and data2 is not None: + if data1 is None: + if data2 is None: + return # in case they were passed reversed data1, data2 = data2, None @@ -220,14 +226,14 @@ class TransferableDataType(ABC): """ @classmethod - def __subclasshook__(cls, subclass): + def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: if cls is TransferableDataType: to = getattr(subclass, "to", None) return callable(to) return NotImplemented -def move_data_to_device(batch: Any, device: torch.device): +def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any: """ Transfers a collection of data to the given device. Any object that defines a method ``to(device)`` will be moved and all other objects in the collection will be left untouched. @@ -245,7 +251,7 @@ def move_data_to_device(batch: Any, device: torch.device): - :class:`torch.device` """ - def batch_to(data): + def batch_to(data: Any) -> Any: # try to move torchtext data first if _TORCHTEXT_AVAILABLE and isinstance(data, Batch): @@ -269,14 +275,14 @@ def batch_to(data): return apply_to_collection(batch, dtype=dtype, function=batch_to) -def convert_to_tensors(data: Any, device: torch.device) -> Any: +def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any: if device is None: raise MisconfigurationException("`torch.device` should be provided.") for src_dtype, conversion_func in CONVERSION_DTYPES: data = apply_to_collection(data, src_dtype, conversion_func, device=device) - def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device) -> torch.Tensor: + def _move_to_device_and_make_contiguous(t: torch.Tensor, device: Union[str, torch.device]) -> torch.Tensor: return t.to(device).contiguous() data = apply_to_collection(data, torch.Tensor, _move_to_device_and_make_contiguous, device=device) From b497fb80e53238ad345c6914be17c8b1e1a6577b Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Thu, 26 Aug 2021 17:51:05 -0700 Subject: [PATCH 19/19] Remove reference to DistributedDataParallel from parallel plugin teardown (#8943) --- CHANGELOG.md | 3 +++ pytorch_lightning/plugins/training_type/ddp.py | 10 ++++++++++ pytorch_lightning/plugins/training_type/ddp_spawn.py | 10 ++++++++++ pytorch_lightning/plugins/training_type/dp.py | 7 +++++++ pytorch_lightning/plugins/training_type/horovod.py | 7 +++++++ pytorch_lightning/plugins/training_type/parallel.py | 12 ------------ 6 files changed, 37 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 36d90ae213fdb9..bba7ed346980ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -217,6 +217,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed `Plugin` in `base_plugin.py`, access `TrainingTypePlugin` and `PrecisionPlugin` directly instead ([#9066](https://github.com/PyTorchLightning/pytorch-lightning/pull/9066)) +- Removed `teardown` from `ParallelPlugin` ([#8943](https://github.com/PyTorchLightning/pytorch-lightning/pull/8943)) + + ### Fixed - Fixed save/load/resume from checkpoint for DeepSpeed Plugin ( diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index aeb43fcdebfe44..6d96a443e391a0 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -501,3 +501,13 @@ def reconciliate_processes(self, trace: str): os.kill(pid, signal.SIGKILL) shutil.rmtree(sync_dir) raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}") + + def teardown(self) -> None: + if isinstance(self.model, DistributedDataParallel): + self.model = self.lightning_module + + if self.on_gpu: + # GPU teardown + self.lightning_module.cpu() + # clean up memory + torch.cuda.empty_cache() diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 08c049997bdfda..c31a908902a276 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -364,3 +364,13 @@ def register_plugins(cls, plugin_registry: Dict) -> None: description="DDPSpawn Plugin with `find_unused_parameters` as False", find_unused_parameters=False, ) + + def teardown(self) -> None: + if isinstance(self.model, DistributedDataParallel): + self.model = self.lightning_module + + if self.on_gpu: + # GPU teardown + self.lightning_module.cpu() + # clean up memory + torch.cuda.empty_cache() diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 551324416cce93..5b0887c8483223 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -119,3 +119,10 @@ def test_step_end(self, output): if not is_overridden("test_step_end", self.lightning_module): return self.reduce(output) return output + + def teardown(self) -> None: + if self.on_gpu: + # GPU teardown + self.lightning_module.cpu() + # clean up memory + torch.cuda.empty_cache() diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index e5eb8bf9723ea3..19694e1bcda11e 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -206,3 +206,10 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.Distributed def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]: opt_params = {p for group in optimizer.param_groups for p in group.get("params", [])} return [(name, p) for name, p in model.named_parameters() if p in opt_params] + + def teardown(self) -> None: + if self.on_gpu: + # GPU teardown + self.lightning_module.cpu() + # clean up memory + torch.cuda.empty_cache() diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 71aae1bb71a918..31d2deb5f65e62 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -133,15 +133,3 @@ def block_backward_sync(self): yield None else: yield None - - def teardown(self) -> None: - # Un-reference the wrapper if any was used. - # todo (tchaton): Add support for all plugins. - if isinstance(self.model, DistributedDataParallel): - self.model = self.lightning_module - - if self.on_gpu: - # GPU teardown - self.lightning_module.cpu() - # clean up memory - torch.cuda.empty_cache()