diff --git a/.github/workflows/ci_test-tpu.yml b/.github/workflows/ci_test-tpu.yml
deleted file mode 100644
index 22bb7bd7cd4e5f..00000000000000
--- a/.github/workflows/ci_test-tpu.yml
+++ /dev/null
@@ -1,144 +0,0 @@
-name: TPU tests
-
-on:
- push:
- branches: [master, "release/*"]
-# TODO: temporal disable TPU testing until we find way how to pass credentials to forked PRs
-# pull_request:
-# branches:
-# - master
-
-env:
- GKE_CLUSTER: lightning-cluster
- GKE_ZONE: us-central1-a
- IMAGE: gcr.io/${{ secrets.GKE_PROJECT }}/tpu-testing-image
- MAX_CHECKS: 360
- CHECK_SPEEP: 5
-
-jobs:
- setup-build-publish-deploy:
- name: tpu-testing-job
- runs-on: ubuntu-20.04
- strategy:
- fail-fast: false
- matrix:
- python-version: [3.7]
- xla-version: [1.6, 1.8]
- # Timeout: https://stackoverflow.com/a/59076067/4521646
- timeout-minutes: 50
-
- steps:
- - name: Set IMAGETAG
- run: echo "IMAGETAG=$(date +%s)_${{ matrix.python-version }}" >> $GITHUB_ENV
- - name: Install Go
- uses: actions/setup-go@v2
- with:
- go-version: 1.14.x
- - name: Set up Python 3.7
- uses: actions/setup-python@v2
- with:
- python-version: 3.7
-
- - name: Checkout Pytorch Lightning
- uses: actions/checkout@v2
- with:
- repository: PyTorchLightning/pytorch-lightning
- ref: ${{ github.event.pull_request.head.sha }}
-
- - name: Checkout ml-testing-accelerators
- uses: actions/checkout@v2
- with:
- repository: GoogleCloudPlatform/ml-testing-accelerators
- path: ml-testing-accelerators
- ref: 5e88ac24f631c27045e62f0e8d5dfcf34e425e25
-
- - name: Setup gcloud CLI
- uses: GoogleCloudPlatform/github-actions/setup-gcloud@master
- with:
- version: '290.0.1'
- service_account_key: ${{ secrets.GKE_SA_KEY_BASE64 }}
- project_id: ${{ secrets.GKE_PROJECT }}
- export_default_credentials: true
-
- # Configure Docker to use the gcloud command-line tool as a credential helper for authentication.
- - name: Configure Docker
- run: |-
- gcloud --quiet auth configure-docker
- shell: bash
- - name: Build and Push Docker Image
- env:
- PYTHON_VER: ${{ matrix.python-version }}
- XLA_VER: ${{ matrix.xla-version }}
- run: |
- #cd dockers/tpu-tests
- docker build --tag "$IMAGE:$IMAGETAG" -f ./dockers/tpu-tests/Dockerfile --build-arg "PYTHON_VERSION=$PYTHON_VER" --build-arg "PYTORCH_VERSION=$XLA_VER" .
- docker push "$IMAGE:$IMAGETAG"
- shell: bash
-
- - name: Install jsonnet
- run: |-
- go get github.com/google/go-jsonnet/cmd/jsonnet
- shell: bash
- # Get the GKE credentials so we can deploy to the cluster
- # Use either zone or region depending on cluster setup.
- - run: |-
- gcloud container clusters get-credentials "$GKE_CLUSTER" --zone "$GKE_ZONE"
- shell: bash
-
- - name: Deploy the job on the kubernetes cluster
- env:
- XLA_VER: ${{ matrix.xla-version }}
- run: |-
- python -c "fname = 'dockers/tpu-tests/tpu_test_cases.jsonnet' ; ttt = open(fname).read().replace('pytorch-VERSION', 'pytorch-$XLA_VER') ; open(fname, 'w').write(ttt)"
- job_name=$(jsonnet -J ml-testing-accelerators/ dockers/tpu-tests/tpu_test_cases.jsonnet --ext-str image=$IMAGE --ext-str image-tag=$IMAGETAG | kubectl create -f -) && \
- job_name=${job_name#job.batch/} && \
- job_name=${job_name% created} && \
- echo "Waiting on kubernetes job: $job_name in cluster: $GKE_CLUSTER" && \
- i=0 && \
- # 60 checks spaced 30s apart = 900s total.
- status_code=2 && \
- # Check on the job periodically. Set the status code depending on what
- # happened to the job in Kubernetes. If we try MAX_CHECKS times and
- # still the job hasn't finished, give up and return the starting
- # non-zero status code.
- printf "Waiting for job to finish: " && \
- while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else printf "." ; fi; sleep $CHECK_SPEEP; done && \
- echo "Done waiting. Job status code: $status_code" && \
- pod_name=$(kubectl get po -l controller-uid=`kubectl get job $job_name -o "jsonpath={.metadata.labels.controller-uid}"` | awk 'match($0,!/NAME/) {print $1}') && \
- echo "GKE pod name: $pod_name" && \
- kubectl logs -f $pod_name --container=train > /tmp/full_output.txt
- if grep -q '' /tmp/full_output.txt ; then csplit /tmp/full_output.txt '//'; else mv /tmp/full_output.txt xx00; fi && \
- # First portion is the test logs. Print these to Github Action stdout.
- cat xx00 && \
- echo "Done with log retrieval attempt." && \
- gcloud container images delete "$IMAGE:$IMAGETAG" --force-delete-tags && \
- echo "Status code: $status_code"
- exit $status_code
- shell: bash
-
- - name: Statistics
- if: success()
- run: |
- mv ./xx01 coverage
- # TODO: add human readable report
- cat coverage
- # sudo pip install pycobertura
- # pycobertura show coverage.xml
-
- - name: Upload coverage results
- uses: actions/upload-artifact@v2
- with:
- name: coverage-TPU
- path: coverage
-
- - name: Upload coverage to Codecov
- uses: codecov/codecov-action@v1
- # see: https://github.com/actions/toolkit/issues/399
- continue-on-error: true
- if: always()
- with:
- token: ${{ secrets.CODECOV_TOKEN }}
- file: coverage
- flags: tpu,pytest
- name: TPU-coverage
- fail_ci_if_error: true
diff --git a/CHANGELOG.md b/CHANGELOG.md
index c560f99de949af..fb79fcc9d9f8dc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -54,12 +54,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953))
+ * Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))
- Checkpoint saving & loading extensibility:
* Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))
* Refactored CheckpointConnector to offload validating logic to the checkpoitn IO plugin ([#9045](https://github.com/PyTorchLightning/pytorch-lightning/pull/9045))
+- Loop customization:
+ * Added `Closure` and `AbstractClosure` classes ([#8642](https://github.com/PyTorchLightning/pytorch-lightning/pull/8642))
+
+
- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))
@@ -90,6 +95,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added validate logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080))
+- Add support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084))
+
+
### Changed
- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
@@ -144,7 +152,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851))
-- Deprecated `prepare_data_per_node` flag on Trainer and set it as a property of `DataHooks`, accessible in the `LightningModule` and `LightningDataModule` [#8958](https://github.com/PyTorchLightning/pytorch-lightning/pull/8958)
+- Deprecated `prepare_data_per_node` flag on Trainer and set it as a property of `DataHooks`, accessible in the `LightningModule` and `LightningDataModule` ([#8958](https://github.com/PyTorchLightning/pytorch-lightning/pull/8958))
+
+
+- Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065))
- Deprecated `on_{train/val/test/predict}_trainer()` from `DataHooks` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098)
@@ -207,6 +218,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `InterBatchProcessor` in favor of `DataLoaderIterDataFetcher` ([#9052](https://github.com/PyTorchLightning/pytorch-lightning/pull/9052))
+- Removed `Plugin` in `base_plugin.py`, access `TrainingTypePlugin` and `PrecisionPlugin` directly instead ([#9066](https://github.com/PyTorchLightning/pytorch-lightning/pull/9066))
+
+
+- Removed `teardown` from `ParallelPlugin` ([#8943](https://github.com/PyTorchLightning/pytorch-lightning/pull/8943))
+
+
### Fixed
- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
diff --git a/README.md b/README.md
index d31d6528504584..c80995293ae9d6 100644
--- a/README.md
+++ b/README.md
@@ -78,14 +78,14 @@ Lightning is rigorously tested across multiple GPUs, TPUs CPUs and against major
-| System / PyTorch ver. | 1.6 (min. req.) | 1.7 | 1.8 (LTS) | 1.9 (latest) | 1.10 (nightly) |
-| :----------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
-| Conda py3.7 \[linux\] | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) |
-| Linux py3.7 \[GPUs\*\*\] | - | - | [![Build Status]()](https://dev.azure.com/PytorchLightning/pytorch-lightning/_build/latest?definitionId=6&branchName=master) | - | - |
-| Linux py3.{6,7} \[TPUs\*\*\*\] | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) | - | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) | - | - |
-| Linux py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
-| OSX py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
-| Windows py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
+| System / PyTorch ver. | 1.6 (min. req.) | 1.7 | 1.8 (LTS) | 1.9 (latest) | 1.10 (nightly) |
+| :------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| Conda py3.7 \[linux\] | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) |
+| Linux py3.7 \[GPUs\*\*\] | - | - | [![Build Status]()](https://dev.azure.com/PytorchLightning/pytorch-lightning/_build/latest?definitionId=6&branchName=master) | - | - |
+| Linux py3.7 \[TPUs\*\*\*\] | - | - | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning/tree/master.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning/tree/master) | - | - |
+| Linux py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
+| OSX py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
+| Windows py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
- _\*\* tests run on two NVIDIA P100_
- _\*\*\* tests run on Google GKE TPUv2/3_
diff --git a/docs/source/advanced/mixed_precision.rst b/docs/source/advanced/mixed_precision.rst
new file mode 100644
index 00000000000000..ea784f0894a009
--- /dev/null
+++ b/docs/source/advanced/mixed_precision.rst
@@ -0,0 +1,83 @@
+.. testsetup:: *
+
+ from pytorch_lightning import Trainer
+
+
+.. _amp:
+
+Mixed Precision Training
+========================
+
+Mixed precision combines the use of both FP32 and lower bit floating points (such as FP16) to reduce memory footprint during model training, resulting in improved performance.
+
+Lightning offers mixed precision training for GPUs and CPUs, as well as bfloat16 mixed precision training for TPUs.
+
+.. note::
+
+ In some cases it is important to remain in FP32 for numerical stability, so keep this in mind when using mixed precision.
+
+ For example when running scatter operations during the forward (such as torchpoint3d) computation must remain in FP32.
+
+FP16 Mixed Precision
+--------------------
+
+In most cases, mixed precision uses FP16. Supported torch operations are automatically run in FP16, saving memory and improving throughput on GPU and TPU accelerators.
+
+Since computation happens in FP16, there is a chance of numerical instability. This is handled internally by a dynamic grad scaler which skips steps that are invalid, and adjusts the scaler to ensure subsequent steps fall within a finite range. For more information `see the autocast docs `__.
+
+.. note::
+
+ When using TPUs, setting ``precision=16`` will enable bfloat16 which is the only supported precision type on TPUs.
+
+.. testcode::
+ :skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE or not torch.cuda.is_available()
+
+ Trainer(gpus=1, precision=16)
+
+BFloat16 Mixed Precision
+------------------------
+
+.. warning::
+
+ BFloat16 requires PyTorch 1.10 or later. Currently this requires installing `PyTorch Nightly `__.
+
+ BFloat16 is also experimental and may not provide large speedups or memory improvements, but offer better numerical stability.
+
+ Do note for GPUs, largest benefits require `Ampere `__ based GPUs, such as A100s or 3090s.
+
+BFloat16 Mixed precision is similar to FP16 mixed precision, however we maintain more of the "dynamic range" that FP32 has to offer. This means we are able to improve numerical stability, compared to FP16 mixed precision. For more information see `this TPU performance blog post `__.
+
+Since BFloat16 is more stable than FP16 during training, we do not need to worry about any gradient scaling or nan gradient values that comes with using FP16 mixed precision.
+
+.. testcode::
+ :skipif: not _TORCH_BFLOAT_AVAILABLE
+
+ Trainer(gpus=1, precision="bf16")
+
+It is also possible to use BFloat16 mixed precision on the CPU, relying on MKLDNN under the hood.
+
+.. testcode::
+ :skipif: not _TORCH_CPU_AMP_AVAILABLE
+
+ Trainer(precision="bf16")
+
+NVIDIA APEX Mixed Precision
+---------------------------
+
+.. warning::
+
+ We strongly recommend to use the above native mixed precision rather than NVIDIA APEX unless you require more finer control.
+
+`NVIDIA APEX `__ offers some additional flexibility in setting mixed precision. This can be useful for when wanting to try out different precision configurations, such as keeping most of your weights in FP16 as well as running computation in FP16.
+
+.. testcode::
+ :skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE or not torch.cuda.is_available()
+
+ Trainer(gpus=1, amp_backend="apex")
+
+Set the `NVIDIA optimization level `__ via the trainer.
+
+.. testcode::
+ :skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE or not torch.cuda.is_available()
+
+ Trainer(gpus=1, amp_backend="apex", amp_level="O2")
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 8ddc896b6e9120..4adbacd4cf60c0 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -370,6 +370,8 @@ def package_list_from_file(file):
_XLA_AVAILABLE,
_TPU_AVAILABLE,
_TORCHVISION_AVAILABLE,
+ _TORCH_BFLOAT_AVAILABLE,
+ _TORCH_CPU_AMP_AVAILABLE,
_module_available,
)
_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse")
diff --git a/docs/source/guides/speed.rst b/docs/source/guides/speed.rst
index 4e3ed0b1de8011..fd245e741b9aa1 100644
--- a/docs/source/guides/speed.rst
+++ b/docs/source/guides/speed.rst
@@ -186,7 +186,7 @@ Read more in our :ref:`accelerators` and :ref:`plugins` guides.
-----------
-.. _amp:
+.. _speed_amp:
*********************************
Mixed precision (16-bit) training
@@ -210,7 +210,7 @@ Mixed precision (16-bit) training
Mixed precision combines the use of both 32 and 16 bit floating points to reduce memory footprint during model training, resulting in improved performance, achieving +3X speedups on modern GPUs.
-Lightning offers mixed precision or 16-bit training for GPUs and TPUs.
+Lightning offers mixed precision training for GPUs and CPUs, as well as bfloat16 mixed precision training for TPUs.
.. testcode::
diff --git a/docs/source/index.rst b/docs/source/index.rst
index f3c154a7d257b1..e1de1ed30defa9 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -54,6 +54,7 @@ PyTorch Lightning Documentation
common/loggers
advanced/multi_gpu
advanced/advanced_gpu
+ advanced/mixed_precision
common/weights_loading
advanced/checkpoint_io
common/optimizers
diff --git a/pyproject.toml b/pyproject.toml
index d07f19ef109863..e848464c9d2faf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,6 +34,8 @@ warn_unused_ignores = "True"
allow_redefinition = "True"
# disable this rule as the Trainer attributes are defined in the connectors, not in its __init__
disable_error_code = "attr-defined"
+# style choices
+warn_no_return = "False"
# TODO: Fix typing for these modules
[[tool.mypy.overrides]]
@@ -60,8 +62,10 @@ ignore_errors = "True"
[[tool.mypy.overrides]]
module = [
"pytorch_lightning.callbacks.pruning",
+ "pytorch_lightning.loops.closure",
"pytorch_lightning.trainer.evaluation_loop",
"pytorch_lightning.trainer.connectors.logger_connector",
+ "pytorch_lightning.utilities.apply_func",
"pytorch_lightning.utilities.argparse",
"pytorch_lightning.utilities.cli",
"pytorch_lightning.utilities.cloud_io",
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 2b647f5811fefc..6038a8abc8f5c3 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -186,7 +186,7 @@ def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
- hiddens(:class:`~torch.Tensor`): Passed in if
:paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
"""
- with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
+ with self.precision_plugin.train_step_context():
return self.training_type_plugin.training_step(*step_kwargs.values())
def post_training_step(self) -> None:
@@ -204,7 +204,7 @@ def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[S
- dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple val dataloaders used)
"""
- with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
+ with self.precision_plugin.val_step_context():
return self.training_type_plugin.validation_step(*step_kwargs.values())
def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
@@ -219,7 +219,7 @@ def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OU
- dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple test dataloaders used).
"""
- with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
+ with self.precision_plugin.test_step_context():
return self.training_type_plugin.test_step(*step_kwargs.values())
def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
@@ -234,7 +234,7 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
- dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple predict dataloaders used).
"""
- with self.precision_plugin.predict_step_context(), self.training_type_plugin.predict_step_context():
+ with self.precision_plugin.predict_step_context():
return self.training_type_plugin.predict_step(*step_kwargs.values())
def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py
index 48957219a1ec0e..46e74193fb557a 100644
--- a/pytorch_lightning/accelerators/cpu.py
+++ b/pytorch_lightning/accelerators/cpu.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -24,15 +23,8 @@ def setup(self, trainer: "pl.Trainer") -> None:
"""
Raises:
MisconfigurationException:
- If AMP is used with CPU, or if the selected device is not CPU.
+ If the selected device is not CPU.
"""
- if isinstance(self.precision_plugin, MixedPrecisionPlugin):
- raise MisconfigurationException(
- " Mixed precision is currenty only supported with the AMP backend"
- " and AMP + CPU is not supported. Please use a GPU option or"
- " change precision setting."
- )
-
if "cpu" not in str(self.root_device):
raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead.")
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 414a92af6a66c9..e7daa4ee53cdeb 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -265,16 +265,16 @@ def state_key(self) -> str:
save_on_train_epoch_end=self._save_on_train_epoch_end,
)
- def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- """
- When pretrain routine starts we build the ckpt dir on the fly
- """
- self.__resolve_ckpt_dir(trainer)
+ def on_init_end(self, trainer: "pl.Trainer") -> None:
if self._save_on_train_epoch_end is None:
# if the user runs validation multiple times per training epoch, we try to save checkpoint after
# validation instead of on train epoch end
self._save_on_train_epoch_end = trainer.val_check_interval == 1.0
+ def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+ """When pretrain routine starts we build the ckpt dir on the fly."""
+ self.__resolve_ckpt_dir(trainer)
+
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._last_time_checked = time.monotonic()
diff --git a/pytorch_lightning/core/mixins/hparams_mixin.py b/pytorch_lightning/core/mixins/hparams_mixin.py
index 029ecc173bcf24..72129f22f54bb6 100644
--- a/pytorch_lightning/core/mixins/hparams_mixin.py
+++ b/pytorch_lightning/core/mixins/hparams_mixin.py
@@ -15,7 +15,7 @@
import inspect
import types
from argparse import Namespace
-from typing import Optional, Sequence, Union
+from typing import MutableMapping, Optional, Sequence, Union
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
from pytorch_lightning.utilities import AttributeDict
@@ -104,7 +104,7 @@ class ``__init__`` to be ignored
frame = inspect.currentframe().f_back
save_hyperparameters(self, *args, ignore=ignore, frame=frame)
- def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
+ def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
hp = self._to_hparams_dict(hp)
if isinstance(hp, dict) and isinstance(self.hparams, dict):
@@ -113,7 +113,7 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
self._hparams = hp
@staticmethod
- def _to_hparams_dict(hp: Union[dict, Namespace, str]):
+ def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]):
if isinstance(hp, Namespace):
hp = vars(hp)
if isinstance(hp, dict):
diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py
index 1799def445ea4a..81084d61b6bd25 100644
--- a/pytorch_lightning/loggers/comet.py
+++ b/pytorch_lightning/loggers/comet.py
@@ -268,10 +268,22 @@ def finalize(self, status: str) -> None:
@property
def save_dir(self) -> Optional[str]:
+ """
+ Gets the save directory.
+
+ Returns:
+ The path to the save directory.
+ """
return self._save_dir
@property
def name(self) -> str:
+ """
+ Gets the project name.
+
+ Returns:
+ The project name if it is specified, else "comet-default".
+ """
# Don't create an experiment if we don't have one
if self._experiment is not None and self._experiment.project_name is not None:
return self._experiment.project_name
@@ -283,6 +295,19 @@ def name(self) -> str:
@property
def version(self) -> str:
+ """
+ Gets the version.
+
+ Returns:
+ The first one of the following that is set in the following order
+
+ 1. experiment id.
+ 2. experiment key.
+ 3. "COMET_EXPERIMENT_KEY" environment variable.
+ 4. future experiment key.
+
+ If none are present generates a new guid.
+ """
# Don't create an experiment if we don't have one
if self._experiment is not None:
return self._experiment.id
diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py
index e51e6b1f6a9d22..a9e1118b9d647a 100644
--- a/pytorch_lightning/loggers/tensorboard.py
+++ b/pytorch_lightning/loggers/tensorboard.py
@@ -134,10 +134,22 @@ def log_dir(self) -> str:
@property
def save_dir(self) -> Optional[str]:
+ """
+ Gets the save directory where the TensorBoard experiments are saved.
+
+ Returns:
+ The local path to the save directory where the TensorBoard experiments are saved.
+ """
return self._save_dir
@property
def sub_dir(self) -> Optional[str]:
+ """
+ Gets the sub directory where the TensorBoard experiments are saved.
+
+ Returns:
+ The local path to the sub directory where the TensorBoard experiments are saved.
+ """
return self._sub_dir
@property
@@ -258,10 +270,22 @@ def finalize(self, status: str) -> None:
@property
def name(self) -> str:
+ """
+ Get the name of the experiment.
+
+ Returns:
+ The name of the experiment.
+ """
return self._name
@property
def version(self) -> int:
+ """
+ Get the experiment version.
+
+ Returns:
+ The experiment version if specified else the next version.
+ """
if self._version is None:
self._version = self._get_next_version()
return self._version
diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py
index 800b13b83c39c3..9a3b6ccee64df3 100644
--- a/pytorch_lightning/loggers/test_tube.py
+++ b/pytorch_lightning/loggers/test_tube.py
@@ -20,7 +20,7 @@
import pytorch_lightning as pl
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
-from pytorch_lightning.utilities import _module_available, rank_zero_warn
+from pytorch_lightning.utilities import _module_available, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only
_TESTTUBE_AVAILABLE = _module_available("test_tube")
@@ -36,6 +36,10 @@ class TestTubeLogger(LightningLoggerBase):
Log to local file system in `TensorBoard `_ format
but using a nicer folder structure (see `full docs `_).
+ Warning:
+ The test-tube package is no longer maintained and PyTorch Lightning will remove the :class:´TestTubeLogger´
+ in v1.7.0.
+
Install it with pip:
.. code-block:: bash
@@ -97,6 +101,10 @@ def __init__(
log_graph: bool = False,
prefix: str = "",
):
+ rank_zero_deprecation(
+ "The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7. We recommend switching to the"
+ " `pytorch_lightning.loggers.TensorBoardLogger` as an alternative."
+ )
if Experiment is None:
raise ImportError(
"You want to use `test_tube` logger which is not installed yet,"
@@ -197,10 +205,22 @@ def close(self) -> None:
@property
def save_dir(self) -> Optional[str]:
+ """
+ Gets the save directory.
+
+ Returns:
+ The path to the save directory.
+ """
return self._save_dir
@property
def name(self) -> str:
+ """
+ Gets the experiment name.
+
+ Returns:
+ The experiment name if the experiment exists, else the name specified in the constructor.
+ """
if self._experiment is None:
return self._name
@@ -208,6 +228,12 @@ def name(self) -> str:
@property
def version(self) -> int:
+ """
+ Gets the experiment version.
+
+ Returns:
+ The experiment version if the experiment exists, else the next version.
+ """
if self._experiment is None:
return self._version
diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py
index 0299cb522e59fa..7081e95d352a3d 100644
--- a/pytorch_lightning/loggers/wandb.py
+++ b/pytorch_lightning/loggers/wandb.py
@@ -219,15 +219,33 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
@property
def save_dir(self) -> Optional[str]:
+ """
+ Gets the save directory.
+
+ Returns:
+ The path to the save directory.
+ """
return self._save_dir
@property
def name(self) -> Optional[str]:
+ """
+ Gets the name of the experiment.
+
+ Returns:
+ The name of the experiment if the experiment exists else the name given to the constructor.
+ """
# don't create an experiment if we don't have one
return self._experiment.project_name() if self._experiment else self._name
@property
def version(self) -> Optional[str]:
+ """
+ Gets the id of the experiment.
+
+ Returns:
+ The id of the experiment if the experiment exists else the id given to the constructor.
+ """
# don't create an experiment if we don't have one
return self._experiment.id if self._experiment else self._id
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 29517ad306ebae..3f94a0181672e4 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -15,7 +15,7 @@
from collections import OrderedDict
from contextlib import contextmanager
from copy import copy
-from functools import partial, update_wrapper
+from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
import numpy as np
@@ -26,6 +26,7 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.closure import Closure, ClosureResult
from pytorch_lightning.loops.utilities import (
_check_training_step_output,
_process_training_step_output,
@@ -144,12 +145,12 @@ def advance(self, batch, batch_idx):
result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
if result:
- self.batch_outputs[opt_idx].append(copy(result.training_step_output))
+ self.batch_outputs[opt_idx].append(copy(result.result_collection))
else:
# in manual optimization, there is no looping over optimizers
result = self._run_optimization(batch_idx, split_batch)
if result:
- self.batch_outputs[0].append(copy(result.training_step_output))
+ self.batch_outputs[0].append(copy(result.result_collection))
def teardown(self) -> None:
# release memory
@@ -165,7 +166,7 @@ def _run_optimization(
split_batch: Any,
opt_idx: Optional[int] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
- ):
+ ) -> Optional[ClosureResult]:
"""Runs closure (train step + backward) together with optimization if necessary.
Args:
@@ -177,8 +178,7 @@ def _run_optimization(
# toggle model params
self._run_optimization_start(opt_idx, optimizer)
- result = AttributeDict()
- closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result)
+ closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens)
if self.trainer.fit_loop.should_accumulate():
# For gradient accumulation
@@ -199,7 +199,9 @@ def _run_optimization(
if self.trainer.lightning_module.automatic_optimization:
self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
else:
- result = self._training_step(split_batch, batch_idx, opt_idx, self._hiddens)
+ closure()
+
+ result = closure.get_result()
if result:
# if no result, user decided to skip optimization
@@ -211,37 +213,79 @@ def _run_optimization(
self._run_optimization_end(opt_idx)
return result
- def _training_step_and_backward_closure(
+ def _make_closure(
self,
split_batch: Any,
batch_idx: int,
opt_idx: int,
optimizer: Optimizer,
- hiddens: Tensor,
- return_result: AttributeDict,
- ) -> Optional[Tensor]:
- """Closure for training step and backward
+ hiddens: Any,
+ ) -> Closure:
+ """
+ Build a closure object that captures the given arguments and runs the `training_step` function and optionally
+ other functions such as `backward` and `zero_grad`.
+ """
+ step_fn = self._make_step_fn(split_batch, batch_idx, opt_idx, hiddens)
+ backward_fn = self._make_backward_fn(batch_idx, optimizer, opt_idx)
+ zero_grad_fn = self._make_zero_grad_fn(batch_idx, opt_idx, optimizer)
+
+ return Closure(
+ step_fn=step_fn,
+ backward_fn=backward_fn,
+ zero_grad_fn=zero_grad_fn,
+ profiler=self.trainer.profiler,
+ )
- Args:
- split_batch: the current tbptt split of the batch
- batch_idx: the index of the current batch
- opt_idx: the index of the current optimizer
- optimizer: the current optimizer
- hiddens: the hidden state of the recurrent net
- return_result: the storage of the trainstep results
+ def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Any) -> Callable[[], dict]:
+ """Build the step function that runs the `training_step` and processes its output."""
+ return partial(self._training_step, split_batch, batch_idx, opt_idx, hiddens)
+
+ def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]:
+ """
+ Build a `zero_grad` function that zeroes the gradients before back-propagation.
+ Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on.
+ """
+
+ def zero_grad_fn():
+ self._on_before_zero_grad(optimizer)
+ self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)
+
+ is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0
+ if (
+ not self._skip_backward
+ and self.trainer.lightning_module.automatic_optimization
+ and is_first_batch_to_accumulate
+ ):
+ return zero_grad_fn
+
+ def _make_backward_fn(
+ self,
+ batch_idx: int,
+ optimizer: Optimizer,
+ opt_idx: int,
+ ) -> Optional[Callable[[Tensor], Tensor]]:
+ """
+ Build a `backward` function that handles back-propagation through the output produced by the `training_step`
+ function. Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on.
"""
- result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
- if result is not None:
- return_result.update(result)
- return return_result.loss
+ def backward_fn(loss: Tensor):
+ self.backward(loss, optimizer, opt_idx)
+
+ # when in dev debugging track the losses
+ # TODO: remove dev debugger tracking loss history
+ self.trainer.dev_debugger.track_train_loss_history(batch_idx, loss)
+
+ # check if loss or model weights are nan
+ if self.trainer.terminate_on_nan:
+ check_finite_loss(self.trainer.lightning_module, loss)
- def _make_closure(self, *closure_args: Any, **closure_kwargs: Any) -> Callable:
- """Wraps the training step closure into a partial object which will be called within ``optimizer.step``."""
- partial_func = partial(self._training_step_and_backward_closure, *closure_args, **closure_kwargs)
- return update_wrapper(partial_func, self._training_step_and_backward_closure)
+ return loss
- def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -> None:
+ if not self._skip_backward and self.trainer.lightning_module.automatic_optimization:
+ return backward_fn
+
+ def _process_closure_result(self, opt_closure_result: Optional[ClosureResult]) -> None:
"""Checks if the closure results is finite and optionally breaks if it is not
Args:
@@ -286,18 +330,18 @@ def _training_step(
_check_training_step_output(self.trainer.lightning_module, training_step_output)
- training_step_output, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
- if training_step_output is None:
+ result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
+ if result_collection is None:
return
closure_loss = None
loss = None
if self.trainer.lightning_module.automatic_optimization:
# accumulate loss. if accumulate_grad_batches==1, no effect
- closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches
+ closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches
# the loss will get scaled for amp. avoid any modifications to it
loss = closure_loss.detach().clone()
- return AttributeDict(closure_loss=closure_loss, loss=loss, training_step_output=training_step_output)
+ return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection)
def _optimizer_step(
self, optimizer: torch.optim.Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable
@@ -438,60 +482,22 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator
else:
yield None
- def training_step_and_backward(
- self,
- split_batch: Any,
- batch_idx: int,
- opt_idx: int,
- optimizer: torch.optim.Optimizer,
- hiddens: Optional[Tensor],
- ) -> STEP_OUTPUT:
- """Wrap forward, zero_grad and backward in a closure so second order methods work"""
- with self.trainer.profiler.profile("training_step_and_backward"):
- # lightning module hook
- result = self._training_step(split_batch, batch_idx, opt_idx, hiddens)
-
- if not self._skip_backward and self.trainer.lightning_module.automatic_optimization:
- is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0
-
- if is_first_batch_to_accumulate:
- self._on_before_zero_grad(optimizer)
- self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)
-
- # backward pass
- if result is not None:
- with self.trainer.profiler.profile("backward"):
- self.backward(result, optimizer, opt_idx)
-
- # when in dev debugging track the losses
- self.trainer.dev_debugger.track_train_loss_history(batch_idx, result.loss)
-
- # check if loss or model weights are nan
- if self.trainer.terminate_on_nan:
- check_finite_loss(self.trainer.lightning_module, result.loss)
-
- else:
- self._warning_cache.warn(
- "training_step returned None. If this was on purpose, ignore this warning..."
- )
-
- return result
-
def backward(
- self, result: STEP_OUTPUT, optimizer: Optional[torch.optim.Optimizer], *args: Any, **kwargs: Any
- ) -> None:
+ self,
+ loss: Tensor,
+ optimizer: Optional[torch.optim.Optimizer],
+ opt_idx: Optional[int] = None,
+ *args: Any,
+ **kwargs: Any,
+ ) -> Tensor:
"""Performs the backward step.
Args:
- result: The output of the trainstep (including the loss value)
+ loss: The loss value to back-propagate on
optimizer: Current optimizer being used. ``None`` if using manual optimization.
opt_idx: Index of the current optimizer being used. ``None`` if using manual optimization.
"""
- # backward can be called manually in the training loop
- if isinstance(result, Tensor):
- self.trainer.accelerator.backward(result, optimizer, *args, **kwargs)
- else:
- result.closure_loss = self.trainer.accelerator.backward(result.closure_loss, optimizer, *args, **kwargs)
+ self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)
if not self.trainer.fit_loop.should_accumulate():
# track gradients
@@ -499,6 +505,7 @@ def backward(
if grad_norm_dict:
self.trainer.lightning_module._current_fx_name = "on_after_backward"
self.trainer.lightning_module.log_grad_norm(grad_norm_dict)
+ return loss
def _update_running_loss(self, current_loss: Tensor) -> None:
"""Updates the running loss value with the current value"""
diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py
new file mode 100644
index 00000000000000..b47af75199708c
--- /dev/null
+++ b/pytorch_lightning/loops/closure.py
@@ -0,0 +1,131 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Callable, Optional
+
+from torch import Tensor
+
+from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler
+from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
+from pytorch_lightning.utilities.warnings import WarningCache
+
+
+@dataclass
+class ClosureResult:
+ """A container to hold the result of a :class:`AbstractClosure` call.
+
+ Attributes:
+ closure_loss: The loss with a graph attached.
+ loss: A detached copy of the closure loss.
+ result_collection: A collection of results returned by the closure.
+ """
+
+ closure_loss: Optional[Tensor]
+ loss: Optional[Tensor]
+ result_collection: Optional[ResultCollection]
+
+
+class AbstractClosure(ABC):
+ """
+ Abstract base class for optimizer closures in Lightning.
+
+ Formally, a closure is binding variables from an external scope to a function that does a computation on these
+ variables without taking them explicitly as input. This has the benefit that a closure can be passed to an
+ object which later can call it like a function but without requiring to pass in any arguments.
+
+ This class provides a simple abstraction making the instance of this class callable like a function while capturing
+ the :class:`ClosureResult` and caching it.
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self._result: Optional[ClosureResult] = None
+
+ def get_result(self) -> Optional[ClosureResult]:
+ """The cached result from the last time the closure was called. Once accessed, the internal reference
+ gets reset and the consumer will have to hold on to the reference as long as necessary."""
+ result = self._result
+ self._result = None # free memory
+ return result
+
+ @abstractmethod
+ def closure(self, *args: Any, **kwargs: Any) -> Optional[ClosureResult]:
+ """Implements the behavior of the closure once it is getting called."""
+ pass
+
+ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
+ self._result = self.closure(*args, **kwargs)
+ if self._result is not None:
+ return self._result.loss
+
+
+class Closure(AbstractClosure):
+ """
+ An implementation of a :class:`AbstractClosure` for optimization in Lightning that combines three elementary
+ closures into one: ``training_step``, ``backward`` and ``zero_grad``.
+
+ The Closure gets created by the training loop(s) and is then passed to the
+ :meth:`torch.optim.Optimizer.step` method. An optimizer is responsible for calling the closure and optionally
+ do something with the output.
+
+ Args:
+ step_fn: This is typically the :meth:`pytorch_lightning.core.lightning.LightningModule.training_step
+ wrapped with processing for its outputs
+ backward_fn: A function that takes a loss value as input, performs back-propagation and returns the loss value.
+ Can be set to ``None`` to skip the backward operation.
+ zero_grad_fn: A function that zeroes the gradients. Can be set to ``None`` to skip zero_grad, for example
+ when accumulating gradients.
+ profiler: A profiler for profiling the actions of the passed in closure functions.
+
+ Example:
+
+ closure = Closure()
+ optimizer = torch.optim.Adam(...)
+ optimizer.step(closure)
+ """
+
+ warning_cache = WarningCache()
+
+ def __init__(
+ self,
+ step_fn: Callable[[], dict],
+ backward_fn: Optional[Callable[[Tensor], Tensor]] = None,
+ zero_grad_fn: Optional[Callable[[], None]] = None,
+ profiler: Optional[BaseProfiler] = None,
+ ):
+ super().__init__()
+ self._step_fn = step_fn
+ self._backward_fn = backward_fn
+ self._zero_grad_fn = zero_grad_fn
+ self._profiler = PassThroughProfiler() if profiler is None else profiler
+
+ def closure(self, *args: Any, **kwargs: Any) -> Optional[ClosureResult]:
+ with self._profiler.profile("training_step_and_backward"):
+ step_output = self._step_fn()
+ step_output = ClosureResult(**step_output) if step_output else None
+
+ if step_output is None:
+ self.warning_cache.warn("training_step returned None. If this was on purpose, ignore this warning...")
+
+ if self._zero_grad_fn is not None:
+ with self._profiler.profile("zero_grad"):
+ self._zero_grad_fn()
+
+ if self._backward_fn is not None and step_output is not None and step_output.closure_loss is not None:
+ with self._profiler.profile("backward"):
+ step_output.closure_loss = self._backward_fn(step_output.closure_loss)
+
+ return step_output
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index 7f06f5cd4ff638..68b75b68eb91ba 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -101,11 +101,10 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
dataloader_idx: int = self.current_dataloader_idx
dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
- dataloader_iter = iter(dataloader)
dl_max_batches = self._max_batches[dataloader_idx]
- dl_outputs = self.epoch_loop.run(dataloader_iter, dataloader_idx, dl_max_batches, self.num_dataloaders)
+ dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
# store batch level output per dataloader
if self.should_track_batch_outputs_for_epoch_end:
diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index e4770084c84cd4..158d4cf527143b 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -21,6 +21,7 @@
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
from pytorch_lightning.trainer.progress import Progress
+from pytorch_lightning.utilities.fetching import AbstractDataFetcher
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -58,12 +59,12 @@ def reset(self) -> None:
self.batch_progress.current.reset()
def on_run_start(
- self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
+ self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
) -> None:
"""Adds the passed arguments to the loop's state if necessary
Args:
- dataloader_iter: iterator over the dataloader
+ data_fetcher: the current data_fetcher wrapping the dataloader
dataloader_idx: index of the current dataloader
dl_max_batches: maximum number of batches the dataloader can produce
num_dataloaders: the total number of dataloaders
@@ -72,10 +73,10 @@ def on_run_start(
self._dl_max_batches = dl_max_batches
self._num_dataloaders = num_dataloaders
- self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_progress.current.ready)
+ self.dataloader_iter = _prepare_dataloader_iter(data_fetcher, self.batch_progress.current.ready)
def advance(
- self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
+ self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
) -> None:
"""Calls the evaluation step with the corresponding hooks and updates the logger connector.
@@ -88,7 +89,7 @@ def advance(
Raises:
StopIteration: If the current batch is None
"""
- void(dataloader_iter, dl_max_batches, num_dataloaders)
+ void(data_fetcher, dl_max_batches, num_dataloaders)
batch_idx, (batch, _) = next(self.dataloader_iter)
diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
index f63bb4877e7458..73557f71ade73f 100644
--- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -1,11 +1,13 @@
from collections import OrderedDict
from typing import Any, Dict, Iterator, List, Optional, Tuple
+import torch
from deprecate import void
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.progress import Progress
+from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.warnings import WarningCache
@@ -140,7 +142,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None
self.batch_progress.increment_completed()
if self.should_store_predictions:
- self.predictions.append(predictions)
+ self.predictions.append(move_data_to_device(predictions, torch.device("cpu")))
def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Any]:
"""
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 49af10d4b2c0d9..4a09c0ca1faebd 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -14,7 +14,7 @@
import logging
from contextlib import suppress
-from typing import Optional
+from typing import Any, Dict, Optional
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
@@ -40,6 +40,8 @@ def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] =
self.min_epochs = min_epochs
self.epoch_loop: Optional[TrainingEpochLoop] = None
self.epoch_progress = Progress()
+ # caches the loaded dataloader state until dataloader objects are available
+ self._dataloader_state_dict: Dict[str, Any] = {}
@property
def current_epoch(self) -> int:
@@ -175,6 +177,10 @@ def on_advance_start(self) -> None:
if self.current_epoch != 0 and self.trainer._should_reload_dl_epoch:
self.trainer.reset_train_dataloader(model)
+ if self._dataloader_state_dict:
+ self.trainer.train_dataloader.load_state_dict(self._dataloader_state_dict)
+ self._dataloader_state_dict = {}
+
# TODO: specify the possible exception
with suppress(Exception):
# set seed for distributed sampler (enables shuffling for each epoch)
@@ -193,12 +199,11 @@ def on_advance_start(self) -> None:
def advance(self) -> None:
"""Runs one whole epoch."""
dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)
- dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader)
- dataloader_iter = iter(dataloader)
+ data_fetcher = self.trainer.data_connector.get_profiled_dataloader(dataloader)
with self.trainer.profiler.profile("run_training_epoch"):
# run train epoch
- epoch_output = self.epoch_loop.run(dataloader_iter)
+ epoch_output = self.epoch_loop.run(data_fetcher)
if epoch_output is None:
return
@@ -235,3 +240,13 @@ def should_accumulate(self) -> bool:
def teardown(self) -> None:
self.epoch_loop.teardown()
+
+ def on_save_checkpoint(self) -> Dict:
+ state_dict = super().on_save_checkpoint()
+ # FIXME(@tchaton) Should pass has_completed=True when iterator is exhausted ?
+ state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
+ return state_dict
+
+ def on_load_checkpoint(self, state_dict: Dict) -> None:
+ # cache the dataloader state dict until the dataloader objects are available
+ self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py
index 89ba5cd07d4590..dd69640106af8f 100644
--- a/pytorch_lightning/loops/utilities.py
+++ b/pytorch_lightning/loops/utilities.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Iterator, Mapping, Optional, Tuple
+from typing import Any, Iterator, Mapping, Optional, Tuple
import torch
@@ -20,7 +20,7 @@
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher
+from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -65,7 +65,7 @@ def _check_training_step_output(model: "pl.LightningModule", training_step_outpu
def _process_training_step_output(
trainer: "pl.Trainer", training_step_output: STEP_OUTPUT
-) -> Tuple[Optional[ResultCollection], Optional[torch.Tensor]]:
+) -> Tuple[Optional[ResultCollection], Optional[Any]]:
"""Adds the :param:`training_step_output` to the trainer's results
Args:
@@ -105,9 +105,11 @@ def _process_training_step_output(
return results, hiddens
-def _prepare_dataloader_iter(dataloader_iter: Iterator, batch_idx: int) -> Iterator:
+def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator:
"""Attach the dataloader"""
- if not isinstance(dataloader_iter, DataLoaderIterDataFetcher):
- dataloader_iter = enumerate(dataloader_iter, batch_idx)
- # restore iteration
+ if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
+ # restore iteration
+ dataloader_iter = enumerate(data_fetcher, batch_idx)
+ else:
+ dataloader_iter = iter(data_fetcher)
return dataloader_iter
diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py
index a69065fa74f733..13f8c7404baf4f 100644
--- a/pytorch_lightning/plugins/__init__.py
+++ b/pytorch_lightning/plugins/__init__.py
@@ -1,4 +1,7 @@
-from pytorch_lightning.plugins.base_plugin import Plugin
+from pathlib import Path
+from typing import Union
+
+from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.plugins.plugins_registry import ( # noqa: F401
@@ -30,6 +33,9 @@
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
+PLUGIN = Union[TrainingTypePlugin, PrecisionPlugin, ClusterEnvironment, CheckpointIO]
+PLUGIN_INPUT = Union[PLUGIN, str]
+
__all__ = [
"CheckpointIO",
"TorchCheckpointIO",
@@ -55,13 +61,10 @@
"TPUSpawnPlugin",
"TrainingTypePlugin",
"ParallelPlugin",
- "Plugin",
"DDPShardedPlugin",
"DDPSpawnShardedPlugin",
]
-from pathlib import Path
-
FILE_ROOT = Path(__file__).parent
TRAINING_TYPE_BASE_MODULE = "pytorch_lightning.plugins.training_type"
diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py
deleted file mode 100644
index 515fc29d0e355a..00000000000000
--- a/pytorch_lightning/plugins/base_plugin.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import contextlib
-from abc import ABC
-from typing import Generator
-
-import pytorch_lightning as pl
-
-
-class Plugin(ABC):
- """Basic class for all precision- and training type plugins."""
-
- def pre_dispatch(self) -> None:
- """Hook to do something before the training/evaluation/prediction starts."""
-
- def dispatch(self, trainer: "pl.Trainer") -> None:
- """Hook to do something at trainer run_stage starts."""
-
- def post_dispatch(self) -> None:
- """Hook to do something after the training/evaluation/prediction finishes."""
-
- @contextlib.contextmanager
- def train_step_context(self) -> Generator:
- """A contextmanager for the trainstep"""
- yield
-
- @contextlib.contextmanager
- def val_step_context(self) -> Generator:
- """A contextmanager for the validation step"""
- yield
-
- @contextlib.contextmanager
- def test_step_context(self) -> Generator:
- """A contextmanager for the teststep"""
- yield
-
- @contextlib.contextmanager
- def predict_step_context(self) -> Generator:
- """A contextmanager for the predict step"""
- yield
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 4c34bed9720678..c32f926236f567 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -19,7 +19,12 @@
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _TORCH_BFLOAT_AVAILABLE, AMPType
+from pytorch_lightning.utilities import (
+ _NATIVE_AMP_AVAILABLE,
+ _TORCH_BFLOAT_AVAILABLE,
+ _TORCH_CPU_AMP_AVAILABLE,
+ AMPType,
+)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -31,7 +36,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
precision: Whether to use torch.float16 (16) or torch.bfloat16 (bf16).
"""
- def __init__(self, precision: Union[int, str] = 16) -> None:
+ def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> None:
super().__init__()
if not _NATIVE_AMP_AVAILABLE:
@@ -39,6 +44,14 @@ def __init__(self, precision: Union[int, str] = 16) -> None:
"You have asked for native AMP but your PyTorch version does not support it."
" Consider upgrading with `pip install torch>=1.6`."
)
+
+ if use_cpu and not _TORCH_CPU_AMP_AVAILABLE:
+ raise MisconfigurationException(
+ "You have asked for native AMP on CPU, but AMP is only available on GPU for PyTorch 1.9 "
+ "and lower. To use native AMP on CPU, install PyTorch 1.10 or later."
+ )
+
+ self.use_cpu = use_cpu
self._fast_dtype = self._select_precision_dtype(precision)
self.backend = AMPType.NATIVE
if not self.is_bfloat16:
@@ -51,6 +64,10 @@ def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtyp
"To use bfloat16 with native amp you must install torch greater or equal to 1.10."
)
return torch.bfloat16
+ elif self.use_cpu:
+ raise MisconfigurationException(
+ "CPU native amp only supports bfloat16. Please pass precision='bf16' to the Trainer."
+ )
return torch.float16
@property
@@ -91,6 +108,8 @@ def pre_optimizer_step(
return False
def autocast_context_manager(self) -> torch.cuda.amp.autocast:
+ if self.use_cpu:
+ return torch.cpu.amp.autocast(fast_dtype=self._fast_dtype)
if self.is_bfloat16:
return torch.cuda.amp.autocast(fast_dtype=self._fast_dtype)
return torch.cuda.amp.autocast()
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index 1261fea87c06ed..86486bfc37cd92 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Callable, List, Optional, Tuple, Union
+import contextlib
+from typing import Any, Callable, Generator, List, Optional, Tuple, Union
import torch
from torch import Tensor
@@ -20,12 +21,11 @@
import pytorch_lightning as pl
from pytorch_lightning.core.hooks import CheckpointHooks
-from pytorch_lightning.plugins.base_plugin import Plugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.types import _PARAMETERS
-class PrecisionPlugin(Plugin, CheckpointHooks):
+class PrecisionPlugin(CheckpointHooks):
"""
Base class for all plugins handling the precision-specific parts of the training.
The class attribute precision must be overwritten in child classes.
@@ -136,3 +136,32 @@ def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -
"""Clip gradients by norm"""
parameters = self.master_params(optimizer)
torch.nn.utils.clip_grad_norm_(parameters, clip_val)
+
+ def pre_dispatch(self) -> None:
+ """Hook to do something before the training/evaluation/prediction starts."""
+
+ def dispatch(self, trainer: "pl.Trainer") -> None:
+ """Hook to do something when ``Accelerator.dispatch()`` gets called."""
+
+ def post_dispatch(self) -> None:
+ """Hook to do something after the training/evaluation/prediction finishes."""
+
+ @contextlib.contextmanager
+ def train_step_context(self) -> Generator:
+ """A contextmanager for the training step"""
+ yield
+
+ @contextlib.contextmanager
+ def val_step_context(self) -> Generator:
+ """A contextmanager for the validation step"""
+ yield
+
+ @contextlib.contextmanager
+ def test_step_context(self) -> Generator:
+ """A contextmanager for the test step"""
+ yield
+
+ @contextlib.contextmanager
+ def predict_step_context(self) -> Generator:
+ """A contextmanager for the predict step"""
+ yield
diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py
index 861e5e1363dd26..a1eb23e478132d 100644
--- a/pytorch_lightning/plugins/precision/sharded_native_amp.py
+++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -24,9 +24,10 @@
class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
"""Mixed Precision for Sharded Training"""
- def __init__(self, precision: Union[int, str] = 16) -> None:
- super().__init__(precision)
- self.scaler = ShardedGradScaler()
+ def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> None:
+ super().__init__(precision, use_cpu=use_cpu)
+ if not self.use_cpu:
+ self.scaler = ShardedGradScaler()
def clip_grad_by_norm(
self, optimizer: "OSS", clip_val: Union[int, float], norm_type: float = 2.0, eps: float = 1e-6
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 787353be307e69..6d96a443e391a0 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -29,17 +29,21 @@
import torch.distributed
from torch.nn.parallel.distributed import DistributedDataParallel
+from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
+from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import (
+ _FAIRSCALE_AVAILABLE,
_HYDRA_AVAILABLE,
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
+ _TORCH_GREATER_EQUAL_1_10,
rank_zero_deprecation,
rank_zero_warn,
)
@@ -53,11 +57,19 @@
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed
+if _TORCH_GREATER_EQUAL_1_10:
+ from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
+
+if _FAIRSCALE_AVAILABLE:
+ from fairscale.optim import OSS
if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path
if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
+if _TORCH_GREATER_EQUAL_1_10:
+ import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
+ import torch.distributed.algorithms.model_averaging.averagers as averagers
log = logging.getLogger(__name__)
@@ -83,6 +95,7 @@ def __init__(
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
+ model_averaging_period: Optional[int] = None,
**kwargs: Union[Any, Dict[str, Any]],
) -> None:
super().__init__(
@@ -110,6 +123,7 @@ def __init__(
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
+ self._model_averaging_period = model_averaging_period
self._pids: Optional[List[int]] = None
self._sync_dir: Optional[str] = None
self.set_world_ranks()
@@ -302,6 +316,51 @@ def _register_ddp_hooks(self) -> None:
ddp_comm_wrapper=self._ddp_comm_wrapper,
)
+ # Post-localSDG is only available after 1.9,
+ # and `torch.distributed.optim` package currently is not available on Windows.
+ if (
+ _TORCH_GREATER_EQUAL_1_10
+ and isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState)
+ and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
+ ):
+ self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
+
+ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
+ optimizers = self.lightning_module.trainer.optimizers
+ if self._model_averaging_period is None:
+ raise ValueError(
+ "Post-localSGD algorithm is used, " "but model averaging period is not provided to DDP plugin."
+ )
+ averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
+ for x, optimizer in enumerate(optimizers):
+ if isinstance(optimizer, LightningOptimizer):
+ optimizer = optimizer._optimizer
+
+ if (
+ isinstance(optimizer, DistributedOptimizer)
+ or isinstance(optimizer, ZeroRedundancyOptimizer)
+ or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
+ ):
+ raise ValueError(
+ f"Cannot wrap a distributed optimizer of type {optimizer.__name__} by PostLocalSGDOptimizer."
+ )
+
+ if isinstance(optimizer, PostLocalSGDOptimizer):
+ continue
+
+ optim_class = type(optimizer)
+ post_localSGD_optimizer = PostLocalSGDOptimizer(
+ params=optimizer.param_groups,
+ optimizer_class=optim_class,
+ averager=averager,
+ **optimizer.defaults,
+ )
+ optimizers[x] = post_localSGD_optimizer
+ del optimizer
+ trainer = self.lightning_module.trainer
+ trainer.optimizers = optimizers
+ trainer.convert_to_lightning_optimizers()
+
def configure_ddp(self):
self.pre_configure_ddp()
self._model = DistributedDataParallel(
@@ -442,3 +501,13 @@ def reconciliate_processes(self, trace: str):
os.kill(pid, signal.SIGKILL)
shutil.rmtree(sync_dir)
raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
+
+ def teardown(self) -> None:
+ if isinstance(self.model, DistributedDataParallel):
+ self.model = self.lightning_module
+
+ if self.on_gpu:
+ # GPU teardown
+ self.lightning_module.cpu()
+ # clean up memory
+ torch.cuda.empty_cache()
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 08c049997bdfda..c31a908902a276 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -364,3 +364,13 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
description="DDPSpawn Plugin with `find_unused_parameters` as False",
find_unused_parameters=False,
)
+
+ def teardown(self) -> None:
+ if isinstance(self.model, DistributedDataParallel):
+ self.model = self.lightning_module
+
+ if self.on_gpu:
+ # GPU teardown
+ self.lightning_module.cpu()
+ # clean up memory
+ torch.cuda.empty_cache()
diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py
index 551324416cce93..5b0887c8483223 100644
--- a/pytorch_lightning/plugins/training_type/dp.py
+++ b/pytorch_lightning/plugins/training_type/dp.py
@@ -119,3 +119,10 @@ def test_step_end(self, output):
if not is_overridden("test_step_end", self.lightning_module):
return self.reduce(output)
return output
+
+ def teardown(self) -> None:
+ if self.on_gpu:
+ # GPU teardown
+ self.lightning_module.cpu()
+ # clean up memory
+ torch.cuda.empty_cache()
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index e5eb8bf9723ea3..19694e1bcda11e 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -206,3 +206,10 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.Distributed
def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]:
opt_params = {p for group in optimizer.param_groups for p in group.get("params", [])}
return [(name, p) for name, p in model.named_parameters() if p in opt_params]
+
+ def teardown(self) -> None:
+ if self.on_gpu:
+ # GPU teardown
+ self.lightning_module.cpu()
+ # clean up memory
+ torch.cuda.empty_cache()
diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py
index 71aae1bb71a918..31d2deb5f65e62 100644
--- a/pytorch_lightning/plugins/training_type/parallel.py
+++ b/pytorch_lightning/plugins/training_type/parallel.py
@@ -133,15 +133,3 @@ def block_backward_sync(self):
yield None
else:
yield None
-
- def teardown(self) -> None:
- # Un-reference the wrapper if any was used.
- # todo (tchaton): Add support for all plugins.
- if isinstance(self.model, DistributedDataParallel):
- self.model = self.lightning_module
-
- if self.on_gpu:
- # GPU teardown
- self.lightning_module.cpu()
- # clean up memory
- torch.cuda.empty_cache()
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 82a33290aafbf9..6ee1ce77c8c24c 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -25,14 +25,13 @@
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins import TorchCheckpointIO
-from pytorch_lightning.plugins.base_plugin import Plugin
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
TBroadcast = TypeVar("T")
-class TrainingTypePlugin(Plugin, ABC):
+class TrainingTypePlugin(ABC):
"""
Base class for all training type plugins that change the behaviour of the training, validation and test-loop.
"""
@@ -352,3 +351,12 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int)
Called in the training loop before anything happens for that batch.
"""
pass
+
+ def pre_dispatch(self) -> None:
+ """Hook to do something before the training/evaluation/prediction starts."""
+
+ def dispatch(self, trainer: "pl.Trainer") -> None:
+ """Hook to do something at trainer run_stage starts."""
+
+ def post_dispatch(self) -> None:
+ """Hook to do something after the training/evaluation/prediction finishes."""
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 36a3e9abb7b7a4..bbfcbb22802a87 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from abc import ABC
from copy import deepcopy
from typing import Any, Dict, List, Optional, Type, Union
import torch
+from packaging.version import Version
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
@@ -255,14 +255,14 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if callback_states is None:
return
- current_callbacks_type = {type(cb) for cb in self.callbacks}
- saved_callbacks_type = set(callback_states.keys())
- difference = saved_callbacks_type.difference(current_callbacks_type)
+ is_legacy_ckpt = Version(checkpoint["pytorch-lightning_version"]) < Version("1.5.0dev")
+ current_callbacks_keys = {cb._legacy_state_key if is_legacy_ckpt else cb.state_key for cb in self.callbacks}
+ difference = callback_states.keys() - current_callbacks_keys
if difference:
rank_zero_warn(
- "Be aware that when using ``resume_from_checkpoint``, "
- "callbacks used to create the checkpoint need to be provided. "
- f"Please, add the following callbacks: {list(difference)}. ",
+ "Be aware that when using `resume_from_checkpoint`,"
+ " callbacks used to create the checkpoint need to be provided."
+ f" Please add the following callbacks: {list(difference)}.",
UserWarning,
)
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 40ba9de2a96fd7..85a2c7ee87d131 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -567,10 +567,6 @@ def select_precision_plugin(self) -> PrecisionPlugin:
return TPUHalfPrecisionPlugin()
if self.amp_type == AMPType.NATIVE:
- if self.use_cpu:
- raise MisconfigurationException(
- "You have asked for native AMP on CPU, but AMP is only available on GPU."
- )
if not _NATIVE_AMP_AVAILABLE:
msg = (
"You have asked for native AMP but your PyTorch version does not support it."
@@ -585,10 +581,10 @@ def select_precision_plugin(self) -> PrecisionPlugin:
else:
log.info(f"Using native {self.precision} bit Automatic Mixed Precision")
if self._is_sharded_training_type:
- return ShardedNativeMixedPrecisionPlugin(self.precision)
+ return ShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
if self._is_fully_sharded_training_type:
- return FullyShardedNativeMixedPrecisionPlugin(self.precision)
- return NativeMixedPrecisionPlugin(self.precision)
+ return FullyShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
+ return NativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
if self.amp_type == AMPType.APEX:
if not _APEX_AVAILABLE:
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py
index 0e747a9e4857d8..4c0ddddd2c234f 100644
--- a/pytorch_lightning/trainer/supporters.py
+++ b/pytorch_lightning/trainer/supporters.py
@@ -13,20 +13,23 @@
# limitations under the License.
from collections.abc import Iterable, Iterator, Mapping, Sequence
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from torch.utils.data import Dataset
-from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
+from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader
from torch.utils.data.dataset import IterableDataset
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.auto_restart import (
- _cycle_to_next_worker_and_reset,
- _find_current_worker,
+ _find_fast_forward_samplers,
CaptureIterableDataset,
+ CaptureMapDataset,
+ IteratorState,
+ MergedIteratorState,
+ patch_dataloader_iterator,
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -167,6 +170,7 @@ def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycle
self.loader = loader
self._loader_iter = None
self.counter = 0
+ self.state = state
def __iter__(self) -> Any:
"""
@@ -176,6 +180,7 @@ def __iter__(self) -> Any:
CycleIterator: self
"""
self.counter = 0
+ self.state.reset()
self._loader_iter = iter(self.loader)
return self
@@ -205,6 +210,12 @@ def __next__(self) -> Any:
raise StopIteration
self._loader_iter = iter(self.loader)
+ # if fault tolerant is enabled, we need to patch the iterator to collect the states
+ # before the batch gets returned.
+ fetcher = getattr(self.loader, "_lightning_fetcher", None)
+ if fetcher:
+ patch_dataloader_iterator(self.loader, self._loader_iter, fetcher)
+
return next(self._loader_iter)
finally:
@@ -302,11 +313,6 @@ def __len__(self) -> int:
return self._calc_num_data(self.datasets, self.mode)
-class DataLoaderDict(Dict):
- # behaves exactly like a dict, this is used to simplify apply_to_collection.
- pass
-
-
class CombinedLoader:
"""
Combines different dataloaders and allows sampling in parallel.
@@ -360,80 +366,110 @@ def __init__(self, loaders: Any, mode: str = "min_size"):
self._iterator = None # assigned in __iter__
@staticmethod
- def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], num_batches_processed: int) -> Dict:
- # find next worker if multiple workers were used
- state = _find_current_worker(iterator)
- if isinstance(dataloader.dataset, CaptureIterableDataset):
- # the sampler state dict are extracted in `CombinedLoaderIterator`
- if iterator is not None and getattr(iterator, "_sampler_state_dict", None) is not None:
- state.update(iterator._sampler_state_dict[0])
- else:
- # fetch directly from fast forward sampler
- state.update(dataloader.fast_forward_sampler.state_dict(num_batches_processed))
- return DataLoaderDict(state)
-
- def state_dict(self, num_batches_processed: int) -> Dict:
+ def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], has_completed: int) -> Dict:
+ if isinstance(dataloader, CycleIterator):
+ iterator = dataloader._loader_iter
+ state = getattr(iterator, "state", None) if has_completed else getattr(iterator, "previous_state", None)
+ if state:
+ return asdict(state)
+ return {}
+
+ def state_dict(self, has_completed: bool = False) -> Dict:
"""
The state dict includes all states from wrapped dataloaders and their samplers through the
``CaptureIterableDataset`` and fast-forward samplers.
Args:
- num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
- may have already prefetched more batches by the time a state dict is requested.
+ has_completed: whether the current state of data fetching is considered completed or not. If it is, the
+ current state gets returned, otherwise the previously cached state.
"""
- if not _fault_tolerant_training():
- return DataLoaderDict()
-
- state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)
+ if not _fault_tolerant_training() or self._iterator is None:
+ return {}
- return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn)
+ return apply_to_collections(
+ self.loaders,
+ self._iterator.loader_iters,
+ (Iterator, DataLoader),
+ partial(self._state_dict_fn, has_completed=has_completed),
+ )
- def load_state_dict(self, state_dict):
+ def load_state_dict(self, state_dict) -> None:
# store the samplers state.
# They would be reloaded once the `CombinedIterator` as been created
# and the workers are created.
self._loaders_iter_state_dict = state_dict
- def mock_reset_fn(self, *_, **__):
- pass
-
- # mock reset call, so we can rotate the `_worker_queue_idx_cycle` to failed worker
- # and get the first batch from it
- _MultiProcessingDataLoaderIter._original_reset = _MultiProcessingDataLoaderIter._reset
- _MultiProcessingDataLoaderIter._reset = mock_reset_fn
-
- def on_restart(self, iterator: Iterator):
+ def on_restart(self, iterator: Iterator) -> None:
if not self._loaders_iter_state_dict:
return
- # this happen inside the workers if any were specificied.
+ def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
+ """Function used to reload the iterator state before once the workers are created."""
+
+ dataloader_to_iter_on = dataloader
+ if isinstance(dataloader, CycleIterator):
+ dataloader = dataloader_to_iter_on.loader
+
+ dataset = dataloader.dataset
+
+ # We reload the states before creating the workers
+ # The specific type of dataset will then decide if the state should be applied before or after
+ # spawning the workers
+ if isinstance(dataset, CaptureMapDataset):
+ iterator_state = state_dict["state"][0]
+
+ if not isinstance(iterator_state, IteratorState):
+ iterator_state = IteratorState.from_state_dict(iterator_state)
+
+ # reload sampler state
+ ff_sampler = _find_fast_forward_samplers(dataloader)
+ ff_sampler.load_state_dict(iterator_state.sampler_state)
+ # reload dataset state
+ dataset.load_state_dict(
+ iterator_state.dataset_state,
+ latest_worker_id=state_dict["latest_worker_id"],
+ num_workers=iterator_state.num_workers,
+ )
+
+ elif isinstance(dataset, CaptureIterableDataset):
+ dataset_dict = {
+ sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()
+ }
+ dataset.load_state_dict(dataset_dict)
- def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict):
- if isinstance(dataloader.dataset, CaptureIterableDataset):
- # provide the `state_dict` to the `CaptureIterableDataset`
- # as it is responsible for passing down the state to associated `FastForwardSampler`
- dataloader.dataset.load_state_dict(state_dict)
else:
- # for `Mapping-based` dataset, the `fast_forward_sampler` was attached
- # on the dataloader for simplicity
- dataloader.fast_forward_sampler.load_state_dict(state_dict)
+ raise MisconfigurationException(
+ "This shouldn't happen. Please, open an issue on PyTorch Lightning Github."
+ )
+
+ # We finally spawned the workers if any.
+ it = iter(dataloader_to_iter_on)
- # cycle back the iterator to the failed worker if multiple workers were provided
- iterator = _cycle_to_next_worker_and_reset(dataloader, state_dict)
+ # restore caching state
+ state = MergedIteratorState.from_state_dict(state_dict)
- if isinstance(dataloader.dataset, CaptureIterableDataset):
- # remove keys related to iterator
- state_dict = {k: v for k, v in state_dict.items() if k not in ("num_worker", "previous_worker")}
- # need to re-attach the state dict into the iterator for future collection.
- iterator._sampler_state_dict = [state_dict]
- return iterator
+ if isinstance(dataloader_to_iter_on, CycleIterator):
+ it._loader_iter.state = state
+ else:
+ it.state = state
+ return it
+
+ # create an un-existing token, so it doesn't activate for something else than an iterator.
+ class DataLoaderDict(dict):
+ pass
# apply the `create_loader_iters` on the collection of `DataLoader / Iterator`.
# each `Iterator` was created from the `DataLoader`.
iterator._loader_iters = apply_to_collections(
- self.loaders, self._loaders_iter_state_dict, (DataLoader, DataLoaderDict), create_loader_iters
+ self.loaders,
+ self._loaders_iter_state_dict,
+ (Iterable, DataLoaderDict),
+ create_loader_iters,
+ wrong_dtype=(Sequence, Mapping),
)
+ self._loaders_iter_state_dict = None
+
@property
def sampler(self) -> Union[Iterable, Sequence, Mapping]:
"""Return a collections of samplers extracting from loaders."""
@@ -457,7 +493,6 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
self.loaders = apply_to_collection(
self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping)
)
-
state.reset()
def __iter__(self) -> Any:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 19ccf3935a168f..7ebbc55ae7ac9e 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -32,8 +32,7 @@
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
from pytorch_lightning.loops.fit_loop import FitLoop
-from pytorch_lightning.plugins import DDPSpawnPlugin, Plugin
-from pytorch_lightning.plugins.environments import ClusterEnvironment
+from pytorch_lightning.plugins import DDPSpawnPlugin, PLUGIN_INPUT
from pytorch_lightning.profiler import (
AdvancedProfiler,
BaseProfiler,
@@ -151,7 +150,7 @@ def __init__(
terminate_on_nan: bool = False,
auto_scale_batch_size: Union[str, bool] = False,
prepare_data_per_node: Optional[bool] = None,
- plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
+ plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
amp_backend: str = "native",
amp_level: str = "O2",
distributed_backend: Optional[str] = None,
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 56e3e03910de9a..ed9bb930001d87 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -45,6 +45,7 @@
_POPTORCH_AVAILABLE,
_RICH_AVAILABLE,
_TORCH_BFLOAT_AVAILABLE,
+ _TORCH_CPU_AMP_AVAILABLE,
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index b96a0110e58fa3..d7d09251f60870 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -18,7 +18,7 @@
from collections.abc import Mapping, Sequence
from copy import copy
from functools import partial
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -35,19 +35,23 @@
Batch = type(None)
-def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None):
+def to_dtype_tensor(
+ value: Union[int, float, List[Union[int, float]]],
+ dtype: Optional[torch.dtype] = None,
+ device: Union[str, torch.device] = None,
+) -> torch.Tensor:
if device is None:
raise MisconfigurationException("device (torch.device) should be provided.")
return torch.tensor(value, dtype=dtype, device=device)
-def from_numpy(value, device: torch.device = None):
+def from_numpy(value: np.ndarray, device: Union[str, torch.device] = None) -> torch.Tensor:
if device is None:
raise MisconfigurationException("device (torch.device) should be provided.")
return torch.from_numpy(value).to(device)
-CONVERSION_DTYPES = [
+CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any], torch.Tensor]]] = [
# bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
(bool, partial(to_dtype_tensor, dtype=torch.uint8)),
(int, partial(to_dtype_tensor, dtype=torch.int)),
@@ -61,19 +65,19 @@ def _is_namedtuple(obj: object) -> bool:
return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
-def _is_dataclass_instance(obj):
+def _is_dataclass_instance(obj: object) -> bool:
# https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
def apply_to_collection(
data: Any,
- dtype: Union[type, tuple],
+ dtype: Union[type, Any, Tuple[Union[type, Any]]],
function: Callable,
- *args,
- wrong_dtype: Optional[Union[type, tuple]] = None,
+ *args: Any,
+ wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
include_none: bool = True,
- **kwargs
+ **kwargs: Any,
) -> Any:
"""
Recursively applies a function to all elements of a certain dtype.
@@ -121,7 +125,7 @@ def apply_to_collection(
return elem_type(*out) if is_namedtuple else elem_type(out)
if _is_dataclass_instance(data):
- out = {}
+ out_dict = {}
for field in data.__dataclass_fields__:
v = apply_to_collection(
getattr(data, field),
@@ -130,11 +134,11 @@ def apply_to_collection(
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
- **kwargs
+ **kwargs,
)
if include_none or v is not None:
- out[field] = v
- return elem_type(**out)
+ out_dict[field] = v
+ return elem_type(**out_dict)
# data is neither of dtype, nor a collection
return data
@@ -143,11 +147,11 @@ def apply_to_collection(
def apply_to_collections(
data1: Optional[Any],
data2: Optional[Any],
- dtype: Union[type, tuple],
+ dtype: Union[type, Any, Tuple[Union[type, Any]]],
function: Callable,
- *args,
- wrong_dtype: Optional[Union[type, tuple]] = None,
- **kwargs
+ *args: Any,
+ wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
+ **kwargs: Any,
) -> Any:
"""
Zips two collections and applies a function to their items of a certain dtype.
@@ -169,7 +173,9 @@ def apply_to_collections(
AssertionError:
If sequence collections have different data sizes.
"""
- if data1 is None and data2 is not None:
+ if data1 is None:
+ if data2 is None:
+ return
# in case they were passed reversed
data1, data2 = data2, None
@@ -220,14 +226,14 @@ class TransferableDataType(ABC):
"""
@classmethod
- def __subclasshook__(cls, subclass):
+ def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is TransferableDataType:
to = getattr(subclass, "to", None)
return callable(to)
return NotImplemented
-def move_data_to_device(batch: Any, device: torch.device):
+def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
"""
Transfers a collection of data to the given device. Any object that defines a method
``to(device)`` will be moved and all other objects in the collection will be left untouched.
@@ -245,7 +251,7 @@ def move_data_to_device(batch: Any, device: torch.device):
- :class:`torch.device`
"""
- def batch_to(data):
+ def batch_to(data: Any) -> Any:
# try to move torchtext data first
if _TORCHTEXT_AVAILABLE and isinstance(data, Batch):
@@ -269,14 +275,14 @@ def batch_to(data):
return apply_to_collection(batch, dtype=dtype, function=batch_to)
-def convert_to_tensors(data: Any, device: torch.device) -> Any:
+def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any:
if device is None:
raise MisconfigurationException("`torch.device` should be provided.")
for src_dtype, conversion_func in CONVERSION_DTYPES:
data = apply_to_collection(data, src_dtype, conversion_func, device=device)
- def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device) -> torch.Tensor:
+ def _move_to_device_and_make_contiguous(t: torch.Tensor, device: Union[str, torch.device]) -> torch.Tensor:
return t.to(device).contiguous()
data = apply_to_collection(data, torch.Tensor, _move_to_device_and_make_contiguous, device=device)
diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index c9e378dbadee6e..256168ae4382f3 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -16,8 +16,12 @@
from copy import deepcopy
from dataclasses import dataclass, field
from functools import partial, wraps
+from random import getstate as python_get_rng_state
+from random import setstate as python_set_rng_state
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union
+import numpy as np
+import torch
from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset
@@ -168,6 +172,16 @@ def update(self, generator_name: Optional[str], new_state: IteratorState) -> Non
state[latest_worker_id] = new_state
self.latest_worker_id = latest_worker_id
+ @property
+ def sampler_states(self) -> Dict[int, Any]:
+ """Returns the merged sampler states for all worker processes."""
+ return {0: self.state[k].sampler_state[0] for k in self.state.keys()}
+
+ @property
+ def dataset_states(self) -> Dict[int, Any]:
+ """Returns the merged dataset states for all worker processes."""
+ return {k: self.state[k].dataset_state[k] for k in self.state.keys()}
+
@classmethod
def from_state_dict(cls, state_dict) -> "MergedIteratorState":
if state_dict["represent_map_dataset"]:
@@ -188,7 +202,12 @@ def __len__(self) -> int:
class CaptureMapDataset(Dataset):
- """This class is used to capture the state from the map-based state dataset."""
+ """This class is used to capture the state from the map-based state dataset.
+
+ Note:
+ We currently don't support restoring if we fail during the first `N = num_workers` batches, where
+ `num_workers` is the number of workers spawned by the dataloader.
+ """
def __init__(self, dataset: Dataset) -> None:
self.dataset = dataset
@@ -202,8 +221,7 @@ def worker_id(self) -> int:
def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
if self._cached_state_dict is not None:
if self.worker_id in self._cached_state_dict:
- # TODO: reset random states
- pass
+ set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"])
self._cached_state_dict = None
data = self.dataset[item]
@@ -227,7 +245,19 @@ def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num
self._cached_state_dict = state_dict
def _state_dict(self) -> Dict[int, Dict[str, Any]]:
- return {self.worker_id: {"rng_states": {}}}
+ return {self.worker_id: {"rng_states": collect_rng_states()}}
+
+
+def collect_rng_states() -> Dict[str, Any]:
+ """Collect the global random state of :mod:`torch`, :mod:`numpy` and Python."""
+ return {"torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": python_get_rng_state()}
+
+
+def set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
+ """Set the global random state of :mod:`torch`, :mod:`numpy` and Python in the current process."""
+ torch.set_rng_state(rng_state_dict.get("torch"))
+ np.random.set_state(rng_state_dict.get("numpy"))
+ python_set_rng_state(rng_state_dict.get("python"))
class CaptureIterableDataset(IterableDataset):
diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py
index ee2b58be106b58..c759f2aee28b7d 100644
--- a/pytorch_lightning/utilities/debugging.py
+++ b/pytorch_lightning/utilities/debugging.py
@@ -109,7 +109,7 @@ def track_load_dataloader_call(self, name: str, dataloaders: List[DataLoader]) -
@enabled_only
def track_train_loss_history(self, batch_idx: int, loss: torch.Tensor) -> None:
- loss_dict = {"batch_idx": batch_idx, "epoch": self.trainer.current_epoch, "loss": loss.detach()}
+ loss_dict = {"batch_idx": batch_idx, "epoch": self.trainer.current_epoch, "loss": loss.detach().clone()}
self.saved_train_losses.append(loss_dict)
@enabled_only
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 68868a0ff74cdc..4f254b6824489c 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -284,12 +284,14 @@ def register_ddp_comm_hook(
.. warning ::
DDP communication wrapper needs pytorch version at least 1.9.0
+ Post-localSGD hook needs pytorch version at least 1.9.0
Example:
from torch.distributed.algorithms.ddp_comm_hooks import (
default_hooks as default,
powerSGD_hook as powerSGD,
+ post_localSGD_hook as post_localSGD,
)
# fp16_compress_hook for compress gradients
@@ -309,6 +311,18 @@ def register_ddp_comm_hook(
ddp_comm_hook=powerSGD.powerSGD_hook,
)
+ # post_localSGD_hook
+ subgroup, _ = torch.distributed.new_subgroups()
+ register_comm_hook(
+ model=ddp_model,
+ state=post_localSGD.PostLocalSGDState(
+ process_group=None,
+ subgroup=subgroup,
+ start_localSGD_iter=1_000,
+ ),
+ ddp_comm_hook=post_localSGD.post_localSGD_hook,
+ )
+
# fp16_compress_wrapper combined with other communication hook
register_ddp_comm_hook(
model=ddp_model,
diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index dd31d23f6cf84d..ff0e5bea6bf19e 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -70,7 +70,6 @@ def _compare_version(package: str, op, version) -> bool:
_TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
_TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0")
-
_APEX_AVAILABLE = _module_available("apex.amp")
_BOLTS_AVAILABLE = _module_available("pl_bolts")
_DEEPSPEED_AVAILABLE = _module_available("deepspeed")
@@ -87,6 +86,9 @@ def _compare_version(package: str, op, version) -> bool:
_OMEGACONF_AVAILABLE = _module_available("omegaconf")
_POPTORCH_AVAILABLE = _module_available("poptorch")
_RICH_AVAILABLE = _module_available("rich")
+_TORCH_CPU_AMP_AVAILABLE = _compare_version(
+ "torch", operator.ge, "1.10.0dev20210501"
+) # todo: swap to 1.10.0 once released
_TORCH_BFLOAT_AVAILABLE = _compare_version(
"torch", operator.ge, "1.10.0.dev20210820"
) # todo: swap to 1.10.0 once released
diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py
index 7280afea02f76d..d695e4c63f43e6 100644
--- a/tests/accelerators/test_cpu.py
+++ b/tests/accelerators/test_cpu.py
@@ -1,7 +1,6 @@
import os
from pathlib import Path
from typing import Any, Dict, Union
-from unittest.mock import Mock
import pytest
import torch
@@ -10,22 +9,10 @@
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import SingleDevicePlugin
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
-from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
-def test_unsupported_precision_plugins():
- """Test error messages are raised for unsupported precision plugins with CPU."""
- trainer = Mock()
- accelerator = CPUAccelerator(
- training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin()
- )
- with pytest.raises(MisconfigurationException, match=r"AMP \+ CPU is not supported"):
- accelerator.setup(trainer=trainer)
-
-
@pytest.mark.parametrize("delay_dispatch", [True, False])
def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch):
"""
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index dc191f4853cc1d..c363638d565d2a 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
+from re import escape
from unittest.mock import call, Mock
+import pytest
+
from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
from tests.helpers import BoringModel
+from tests.helpers.utils import no_warning_call
def test_callbacks_configured_in_model(tmpdir):
@@ -132,3 +137,34 @@ def test_resume_callback_state_saved_by_type(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
trainer.fit(model)
assert callback.state == 111
+
+
+def test_resume_incomplete_callbacks_list_warning(tmpdir):
+ model = BoringModel()
+ callback0 = ModelCheckpoint(monitor="epoch")
+ callback1 = ModelCheckpoint(monitor="global_step")
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ max_steps=1,
+ callbacks=[callback0, callback1],
+ )
+ trainer.fit(model)
+ ckpt_path = trainer.checkpoint_callback.best_model_path
+
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ max_steps=1,
+ callbacks=[callback1], # one callback is missing!
+ resume_from_checkpoint=ckpt_path,
+ )
+ with pytest.warns(UserWarning, match=escape(f"Please add the following callbacks: [{repr(callback0.state_key)}]")):
+ trainer.fit(model)
+
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ max_steps=1,
+ callbacks=[callback1, callback0], # all callbacks here, order switched
+ resume_from_checkpoint=ckpt_path,
+ )
+ with no_warning_call(UserWarning, match="Please add the following callbacks:"):
+ trainer.fit(model)
diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py
index f8c3317b6c5958..4068bc5504b5bf 100644
--- a/tests/core/test_lightning_optimizer.py
+++ b/tests/core/test_lightning_optimizer.py
@@ -20,6 +20,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.core.optimizer import LightningOptimizer
+from pytorch_lightning.loops.closure import Closure
from tests.helpers.boring_model import BoringModel
@@ -221,7 +222,7 @@ def training_epoch_end(self, outputs):
...
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **_):
- assert optimizer_closure.__name__ == "_training_step_and_backward_closure"
+ assert isinstance(optimizer_closure, Closure)
# not passing the closure to the optimizer because step is mocked
# zero_grad is called inside the closure
if isinstance(optimizer, SGD) and batch_idx % 2 == 0:
diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index 67b29baf98065e..7993b94fafda67 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Test deprecated functionality which will be removed in v1.7.0 """
+from unittest import mock
import pytest
from pytorch_lightning import LightningDataModule, Trainer
+from pytorch_lightning.loggers import TestTubeLogger
from tests.deprecated_api import _soft_unimport_module
from tests.helpers import BoringModel
from tests.helpers.datamodules import MNISTDataModule
@@ -88,7 +90,6 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
):
_ = Trainer(prepare_data_per_node=False)
-
def test_v1_7_0_deprecated_on_train_dataloader(tmpdir):
model = BoringModel()
@@ -108,3 +109,9 @@ def test_v1_7_0_deprecated_on_train_dataloader(tmpdir):
match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
model.on_predict_dataloader()
+
+@mock.patch("pytorch_lightning.loggers.test_tube.Experiment")
+def test_v1_7_0_test_tube_logger(_, tmpdir):
+ with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"):
+ _ = TestTubeLogger(tmpdir)
+
diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py
index 22d2be8c3a9b06..5fdd09d5fd4d48 100644
--- a/tests/loops/test_loop_state_dict.py
+++ b/tests/loops/test_loop_state_dict.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import Mock
+from unittest.mock import ANY, Mock
import pytest
import torch
@@ -22,23 +22,31 @@
def test_loops_state_dict():
+ trainer = Trainer()
+ trainer.train_dataloader = Mock()
+
fit_loop = FitLoop()
with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"):
fit_loop.trainer = object()
+ fit_loop.trainer = trainer
fit_loop.connect(Mock())
state_dict = fit_loop.state_dict()
+
new_fit_loop = FitLoop()
+ new_fit_loop.trainer = trainer
+
new_fit_loop.load_state_dict(state_dict)
assert fit_loop.state_dict() == new_fit_loop.state_dict()
def test_loops_state_dict_structure():
trainer = Trainer()
+ trainer.train_dataloader = Mock()
state_dict = trainer.checkpoint_connector._get_loops_state_dict()
expected = {
"fit_loop": {
- "state_dict": {},
+ "state_dict": {"dataloader_state_dict": ANY},
"epoch_loop.state_dict": {},
"epoch_loop.batch_progress": {
"total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index 65cbebc8203e53..200b2daae93eda 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -504,7 +504,13 @@ def configure_optimizers_multiple(self):
assert checkpoint["loops"]["fit_loop"] == expected
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False)
- assert trainer.fit_loop.state_dict() == checkpoint["loops"]["fit_loop"]
+ state_dict = trainer.fit_loop.state_dict()
+
+ # need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the
+ # fit loop to have an iterator, which is only available during training
+ checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = ANY
+
+ assert state_dict == checkpoint["loops"]["fit_loop"]
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
state_dict = trainer.fit_loop.state_dict()
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 173d1f8a0d1f8e..5fdc180613e676 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -22,7 +22,7 @@
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import SLURMEnvironment
-from pytorch_lightning.utilities import _TORCH_BFLOAT_AVAILABLE
+from pytorch_lightning.utilities import _TORCH_BFLOAT_AVAILABLE, _TORCH_CPU_AMP_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
@@ -30,13 +30,19 @@
class AMPTestModel(BoringModel):
def _step(self, batch, batch_idx):
- assert torch.is_autocast_enabled()
+ self._assert_autocast_enabled()
output = self(batch)
bfloat16 = self.trainer.precision_plugin.is_bfloat16
assert output.dtype == torch.float16 if not bfloat16 else torch.bfloat16
loss = self.loss(batch, output)
return loss
+ def loss(self, batch, prediction):
+ # todo (sean): convert bfloat16 to float32 as mse loss for cpu amp is currently not supported
+ if self.trainer.precision_plugin.use_cpu:
+ prediction = prediction.float()
+ return super().loss(batch, prediction)
+
def training_step(self, batch, batch_idx):
output = self._step(batch, batch_idx)
return {"loss": output}
@@ -50,17 +56,58 @@ def test_step(self, batch, batch_idx):
return {"y": output}
def predict(self, batch, batch_idx, dataloader_idx=None):
- assert torch.is_autocast_enabled()
+ self._assert_autocast_enabled()
output = self(batch)
bfloat16 = self.trainer.precision_plugin.is_bfloat16
assert output.dtype == torch.float16 if not bfloat16 else torch.bfloat16
return output
+ def _assert_autocast_enabled(self):
+ if self.trainer.precision_plugin.use_cpu:
+ assert torch.is_autocast_cpu_enabled()
+ else:
+ assert torch.is_autocast_enabled()
+
+
+@pytest.mark.skipif(not _TORCH_CPU_AMP_AVAILABLE, reason="CPU AMP not available")
+@pytest.mark.parametrize(
+ "accelerator",
+ [
+ None,
+ pytest.param("dp", marks=pytest.mark.skip("dp + amp not supported currently")), # TODO
+ "ddp_spawn",
+ ],
+)
+@pytest.mark.parametrize(
+ "precision",
+ [
+ pytest.param(16, marks=pytest.mark.skip("CPU precision 16 is not supported in PyTorch yet.")), # TODO
+ "bf16",
+ ],
+)
+@pytest.mark.parametrize("num_processes", [1, 2])
+def test_amp_cpus(tmpdir, accelerator, precision, num_processes):
+ """Make sure combinations of AMP and training types work if supported."""
+ tutils.reset_seed()
+
+ trainer = Trainer(
+ default_root_dir=tmpdir, num_processes=num_processes, max_epochs=1, accelerator=accelerator, precision=precision
+ )
+
+ model = AMPTestModel()
+ # tutils.run_model_test(trainer_options, model)
+ trainer.fit(model)
+ trainer.test(model)
+ trainer.predict(model, DataLoader(RandomDataset(32, 64)))
+
+ assert trainer.state.finished, f"Training failed with {trainer.state}"
+
@RunIf(min_gpus=2)
@pytest.mark.parametrize(
"accelerator",
[
+ None,
pytest.param("dp", marks=pytest.mark.skip("dp + amp not supported currently")), # TODO
"ddp_spawn",
],
diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py
index 15ec43973b0ed6..930cbd87972480 100644
--- a/tests/plugins/test_amp_plugins.py
+++ b/tests/plugins/test_amp_plugins.py
@@ -21,6 +21,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
+from pytorch_lightning.utilities import _TORCH_CPU_AMP_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
@@ -190,3 +191,50 @@ def test_amp_precision_16_bfloat_throws_error(tmpdir):
precision="bf16",
gpus=1,
)
+
+
+@RunIf(amp_native=True, max_torch="1.9")
+def test_cpu_amp_precision_throws_error(tmpdir):
+ with pytest.raises(
+ MisconfigurationException,
+ match="To use native AMP on CPU, install PyTorch 1.10 or later.",
+ ):
+ NativeMixedPrecisionPlugin(use_cpu=True)
+
+
+@pytest.mark.skipif(not _TORCH_CPU_AMP_AVAILABLE, reason="Torch CPU AMP is not available.")
+@RunIf(
+ min_gpus=1,
+ amp_native=True,
+)
+def test_cpu_amp_precision_context_manager(tmpdir):
+ """
+ Test to ensure that the context manager correctly is set to CPU + bfloat16, and a scaler isn't set.
+ """
+
+ plugin = NativeMixedPrecisionPlugin(precision="bf16", use_cpu=True)
+ assert plugin.use_cpu
+ assert not hasattr(plugin, "scaler")
+ context_manager = plugin.autocast_context_manager()
+ assert isinstance(context_manager, torch.cpu.amp.autocast)
+ assert context_manager.fast_dtype == torch.bfloat16
+
+
+@pytest.mark.skipif(not _TORCH_CPU_AMP_AVAILABLE, reason="Torch CPU AMP is not available.")
+@RunIf(
+ min_gpus=1,
+ amp_native=True,
+)
+def test_cpu_amp_precision_16_throws_error(tmpdir):
+ """
+ Throw error when using 16 as Native CPU AMP only supports bfloat16.
+ """
+
+ with pytest.raises(
+ MisconfigurationException,
+ match="CPU native amp only supports bfloat16. Please pass precision='bf16' to the Trainer.",
+ ):
+ Trainer(
+ default_root_dir=tmpdir,
+ precision=16,
+ )
diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py
index 49c4cd18ef316f..1e5968f1a0a5a1 100644
--- a/tests/plugins/test_ddp_plugin_with_comm_hook.py
+++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py
@@ -15,13 +15,15 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_10
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_8:
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
+if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_10:
+ import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True)
@@ -108,3 +110,32 @@ def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
)
trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"
+
+
+@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, special=True)
+def test_ddp_post_local_sgd_comm_hook(tmpdir):
+ """Test for DDP post-localSGD hook."""
+ model = BoringModel()
+
+ training_type_plugin = DDPPlugin(
+ ddp_comm_state=post_localSGD.PostLocalSGDState(
+ process_group=None,
+ subgroup=None,
+ start_localSGD_iter=8,
+ ),
+ ddp_comm_hook=post_localSGD.post_localSGD_hook,
+ model_averaging_period=4,
+ sync_batchnorm=True,
+ )
+ trainer = Trainer(
+ fast_dev_run=True,
+ gpus=2,
+ plugins=[training_type_plugin],
+ default_root_dir=tmpdir,
+ sync_batchnorm=True,
+ )
+ trainer.fit(model)
+ trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
+ expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
+ assert trainer_comm_hook == expected_comm_hook
+ assert trainer.state.finished, f"Training failed with {trainer.state}"
diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py
index 6c0b7b7ed40d7c..78dd4bce728d42 100644
--- a/tests/trainer/loops/test_evaluation_loop_flow.py
+++ b/tests/trainer/loops/test_evaluation_loop_flow.py
@@ -78,10 +78,11 @@ def backward(self, loss, optimizer, optimizer_idx):
assert train_step_out.minimize.item() == 171
# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward(
+ opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
- assert opt_closure_result["loss"].item() == 171
+ opt_closure_result = opt_closure()
+ assert opt_closure_result.item() == 171
def test__eval_step__eval_step_end__flow(tmpdir):
@@ -144,10 +145,11 @@ def backward(self, loss, optimizer, optimizer_idx):
assert train_step_out.minimize.item() == 171
# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward(
+ opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
- assert opt_closure_result["loss"].item() == 171
+ opt_closure_result = opt_closure()
+ assert opt_closure_result.item() == 171
def test__eval_step__epoch_end__flow(tmpdir):
diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py
index 4ee9d858d44c95..692d2420bfd82d 100644
--- a/tests/trainer/loops/test_training_loop_flow_scalar.py
+++ b/tests/trainer/loops/test_training_loop_flow_scalar.py
@@ -18,6 +18,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.loops.closure import Closure
from pytorch_lightning.trainer.states import RunningStage
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.deterministic_model import DeterministicModel
@@ -156,10 +157,11 @@ def backward(self, loss, optimizer, optimizer_idx):
assert train_step_out.minimize.item() == 171
# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward(
+ opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
- assert opt_closure_result["loss"].item() == 171
+ opt_closure_result = opt_closure()
+ assert opt_closure_result.item() == 171
def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):
@@ -229,10 +231,11 @@ def backward(self, loss, optimizer, optimizer_idx):
assert train_step_out.minimize.item() == 171
# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward(
+ opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
- assert opt_closure_result["loss"].item() == 171
+ opt_closure_result = opt_closure()
+ assert opt_closure_result.item() == 171
def test_train_step_no_return(tmpdir):
@@ -258,6 +261,8 @@ def validation_epoch_end(self, outputs):
trainer = Trainer(**trainer_args)
+ Closure.warning_cache.clear()
+
with pytest.warns(UserWarning, match=r"training_step returned None.*"):
trainer.fit(model)
@@ -268,6 +273,8 @@ def validation_epoch_end(self, outputs):
model.automatic_optimization = False
trainer = Trainer(**trainer_args)
+ Closure.warning_cache.clear()
+
with no_warning_call(UserWarning, match=r"training_step returned None.*"):
trainer.fit(model)
@@ -293,6 +300,8 @@ def training_step(self, batch, batch_idx):
checkpoint_callback=False,
)
+ Closure.warning_cache.clear()
+
with pytest.warns(UserWarning, match=r".*training_step returned None.*"):
trainer.fit(model)
diff --git a/tests/trainer/properties/log_dir.py b/tests/trainer/properties/test_log_dir.py
similarity index 100%
rename from tests/trainer/properties/log_dir.py
rename to tests/trainer/properties/test_log_dir.py
diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py
index e665fc79e4323f..ca5b9084591712 100644
--- a/tests/utilities/test_auto_restart.py
+++ b/tests/utilities/test_auto_restart.py
@@ -14,9 +14,13 @@
import math
import os
import random
+import random as python_random
from collections.abc import Iterable
-from typing import Optional
+from contextlib import suppress
+from copy import deepcopy
+from typing import List, Optional
from unittest import mock
+from unittest.mock import ANY
import numpy as np
import pytest
@@ -29,16 +33,19 @@
from torch.utils.data.dataset import Dataset, IterableDataset
import tests.helpers.utils as tutils
-from pytorch_lightning import Callback, seed_everything, Trainer
+from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities.auto_restart import (
_add_capture_metadata_collate,
_dataloader_load_state_dict,
_dataloader_to_state_dict,
CaptureIterableDataset,
+ CaptureMapDataset,
FastForwardSampler,
+ MergedIteratorState,
)
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.fetching import DataFetcher
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@@ -671,7 +678,11 @@ def create_dataloader():
_ = next(iter_dataloader)
state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader)
- assert state_dict[0]["current_iteration"] == 16
+ assert state_dict == {
+ "num_workers": 0,
+ "previous_worker": None,
+ 0: {"current_iteration": 16},
+ }
dataloader = create_dataloader()
dataloader = _dataloader_load_state_dict(dataloader, state_dict)
@@ -679,14 +690,18 @@ def create_dataloader():
_ = next(iter_dataloader)
state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader)
- assert state_dict[0]["current_iteration"] == 24
+ assert state_dict == {
+ "num_workers": 0,
+ "previous_worker": None,
+ 0: {"current_iteration": 24},
+ }
@RunIf(min_torch="1.7.0")
@pytest.mark.parametrize("use_fault_tolerant", ["0", "1"])
def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir):
"""
- this test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled.
+ This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled.
"""
class CustomBatchSampler(BatchSampler):
@@ -713,7 +728,15 @@ def train_dataloader(self):
}
def training_step(self, batch, batch_idx):
- pass
+ assert batch == {
+ "a": [ANY, ANY, ANY],
+ "b": ANY,
+ }
+
+ def validation_step(self, batch, batch_idx):
+ assert isinstance(batch, torch.Tensor)
+
+ validation_epoch_end = None
class Check(Callback):
def on_train_batch_start(self, trainer, *_) -> None:
@@ -721,12 +744,16 @@ def on_train_batch_start(self, trainer, *_) -> None:
if use_fault_tolerant == "1":
assert isinstance(loaders["a"][0].loader.dataset, CaptureIterableDataset)
assert isinstance(loaders["a"][1].loader.sampler, FastForwardSampler)
+ assert isinstance(loaders["a"][1].loader.dataset, CaptureMapDataset)
assert isinstance(loaders["a"][2].loader.batch_sampler, FastForwardSampler)
+ assert isinstance(loaders["a"][2].loader.dataset, CaptureMapDataset)
assert isinstance(loaders["b"].loader.dataset, CaptureIterableDataset)
else:
assert isinstance(loaders["a"][0].loader.dataset, RangeIterableDataset)
assert isinstance(loaders["a"][1].loader.sampler, SequentialSampler)
+ assert not isinstance(loaders["a"][1].loader.dataset, CaptureMapDataset)
assert isinstance(loaders["a"][2].loader.batch_sampler, CustomBatchSampler)
+ assert not isinstance(loaders["a"][2].loader.dataset, CaptureMapDataset)
assert isinstance(loaders["b"].loader.dataset, RangeIterableDataset)
with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": use_fault_tolerant}):
@@ -734,3 +761,210 @@ def on_train_batch_start(self, trainer, *_) -> None:
model.training_epoch_end = None
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, callbacks=Check())
trainer.fit(model)
+
+
+class SequentialGetItemDataset(Dataset):
+ def __init__(self, length, *_):
+ self.len = length
+
+ def __getitem__(self, index):
+ return torch.tensor([index]).float()
+
+ def __len__(self):
+ return self.len
+
+
+class RandomGetItemDataset(Dataset):
+ """A dataset with random elements generated using global rng from torch, numpy and python."""
+
+ def __init__(self, length, size):
+ self.size = size
+ self.len = length
+
+ def __getitem__(self, index):
+ t = torch.rand(self.size)
+ n = torch.from_numpy(np.random.rand(self.size))
+ p = torch.tensor([python_random.random() for _ in range(self.size)])
+ sample = (index + (t + n + p) / 10).float()
+ return sample
+
+ def __len__(self):
+ return self.len
+
+
+# TODO: test with `RandomGeneratorGetItemDataset`
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
+@RunIf(min_torch="1.7.0")
+@pytest.mark.parametrize(
+ "dataset_class",
+ [
+ SequentialGetItemDataset,
+ RandomGetItemDataset,
+ # RandomGeneratorGetItemDataset,
+ ],
+)
+@pytest.mark.parametrize("num_workers", [0])
+@pytest.mark.parametrize("batch_size", [1, 2, 3])
+def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size):
+ """Test that the sequence of batches coming from a random number generator continues with the correct sequence
+ after reloading the state.
+ """
+
+ def create_dataset_sampler():
+ dset = CaptureMapDataset(dataset_class(16, 8))
+ random_sampler = RandomSampler(dset, generator=torch.Generator())
+ return dset, random_sampler
+
+ def create_dataloader_sampler(dset, sampler):
+ sampler = FastForwardSampler(sampler)
+ sampler.setup(batch_size)
+ dl = DataLoader(dset, num_workers=num_workers, sampler=sampler, batch_size=batch_size)
+ _add_capture_metadata_collate(dl)
+ return dl, sampler
+
+ def fetch(fetcher, prefetch_iter, num_batches_fetched):
+ batch, _ = next(prefetch_iter)
+
+ state: List[MergedIteratorState] = fetcher.state
+ assert len(state) == 1
+ assert isinstance(state[0], MergedIteratorState)
+
+ assert len(fetcher.dataloader_iter.cache_states) == 1
+ if num_workers == 0:
+ assert state[0].state[0].num_batches_fetched == num_batches_fetched
+ return state
+
+ dataset, random_sampler = create_dataset_sampler()
+ dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler)
+
+ fetcher = DataFetcher()
+ fetcher.setup(dataloader)
+ prefetch_iter = iter(fetcher)
+
+ # fetch 4 batches
+ fetch(fetcher, prefetch_iter, 1)
+ fetch(fetcher, prefetch_iter, 2)
+ fetch(fetcher, prefetch_iter, 3)
+
+ # (A) capture the state after fetching 4 batches
+ state = fetch(fetcher, prefetch_iter, 4)
+ state = deepcopy(state[0])
+
+ # (B) simulate 2 additional batches
+ batch05, _ = next(prefetch_iter)
+ batch06, _ = next(prefetch_iter)
+
+ # start reloading
+ dataset, random_sampler = create_dataset_sampler()
+ dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler)
+
+ # load the state dict saved at (A)
+ ff_sampler.load_state_dict(state.sampler_states)
+ dataset.load_state_dict(state.dataset_states, latest_worker_id=state.latest_worker_id, num_workers=num_workers)
+
+ prefetcher = DataFetcher()
+ prefetcher.setup(dataloader)
+ prefetch_iter = iter(prefetcher)
+
+ # fetch 2 random batches, these should match exactly the batches seen at (B)
+ batch05_restart, _ = next(prefetch_iter)
+ batch06_restart, _ = next(prefetch_iter)
+
+ assert torch.equal(batch05, batch05_restart)
+ assert torch.equal(batch06, batch06_restart)
+
+
+class CustomException(Exception):
+ pass
+
+
+class SequentialIterableDataset(IterableDataset):
+ def __init__(self, length, *_):
+ self.len = length
+ self.sampler = SequentialSampler(range(self.len))
+
+ def __iter__(self):
+ self.sampler_iter = iter(self.sampler)
+ return self
+
+ def __next__(self):
+ indice = next(self.sampler_iter)
+ return torch.tensor([indice]).float()
+
+
+class TestModel(LightningModule):
+ def __init__(self, fail_on_step: int = -1):
+ super().__init__()
+ self.layer = torch.nn.Linear(1, 2)
+ self.seen_batches = []
+ self.fail_on_step = fail_on_step
+
+ def training_step(self, batch, batch_idx):
+ if self.global_step == self.fail_on_step:
+ raise CustomException()
+ self.seen_batches.append(torch.stack(batch) if isinstance(batch, list) else batch)
+ loss = sum(self.layer(b).sum() for b in batch)
+ return loss
+
+ def configure_optimizers(self):
+ return torch.optim.SGD(self.layer.parameters(), lr=0.1)
+
+
+def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1):
+ seed_everything(1)
+ train_dataloader = [
+ DataLoader(dataset_class(3, 1), batch_size=1, num_workers=0) for dataset_class in dataset_classes
+ ]
+ train_dataloader = train_dataloader[0] if len(train_dataloader) == 1 else train_dataloader
+ model = TestModel(fail_on_step=fail_on_step)
+ trainer = Trainer(**trainer_kwargs)
+ with suppress(CustomException):
+ trainer.fit(model, train_dataloader=train_dataloader)
+ return model.seen_batches
+
+
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
+@RunIf(min_torch="1.7.0")
+@pytest.mark.parametrize(
+ "dataset_classes",
+ [
+ # single training dataset
+ [RandomGetItemDataset],
+ [SequentialIterableDataset],
+ # multiple training datasets (combinded dataloader)
+ [SequentialGetItemDataset, SequentialIterableDataset],
+ [SequentialIterableDataset, SequentialIterableDataset],
+ # [RandomGetItemDataset, RandomGetItemDataset], # TODO: support in the future
+ ],
+)
+@pytest.mark.parametrize("multiple_trainloader_mode", ["min_size", "max_size_cycle"])
+def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, multiple_trainloader_mode):
+ """Test that the Trainer can resume from a failed run in the case of several types of datasets."""
+ trainer_kwargs = dict(
+ default_root_dir=tmpdir,
+ max_epochs=3,
+ weights_summary=None,
+ progress_bar_refresh_rate=0,
+ multiple_trainloader_mode=multiple_trainloader_mode,
+ )
+
+ all_batches = _run_training(trainer_kwargs, dataset_classes)
+ all_batches = torch.stack(all_batches)
+ assert len(all_batches) == 9
+
+ # Simulate 1st failure
+ complete_batches = _run_training(trainer_kwargs, dataset_classes, fail_on_step=4)
+ assert len(complete_batches) == 4
+
+ checkpoint_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")
+ assert os.path.exists(checkpoint_path)
+
+ # Resume after failure
+ trainer_kwargs.update(resume_from_checkpoint=checkpoint_path)
+ resumed_batches = _run_training(trainer_kwargs, dataset_classes, fail_on_step=-1)
+ assert len(resumed_batches) == 5
+
+ # the resumed batches should match the batches of the successful training
+ all_batches_resumed = torch.stack(complete_batches + resumed_batches)
+ assert len(all_batches_resumed) == 9
+ assert torch.equal(all_batches, all_batches_resumed)