diff --git a/.github/workflows/ci_test-tpu.yml b/.github/workflows/ci_test-tpu.yml
deleted file mode 100644
index 22bb7bd7cd4e5..0000000000000
--- a/.github/workflows/ci_test-tpu.yml
+++ /dev/null
@@ -1,144 +0,0 @@
-name: TPU tests
-
-on:
- push:
- branches: [master, "release/*"]
-# TODO: temporal disable TPU testing until we find way how to pass credentials to forked PRs
-# pull_request:
-# branches:
-# - master
-
-env:
- GKE_CLUSTER: lightning-cluster
- GKE_ZONE: us-central1-a
- IMAGE: gcr.io/${{ secrets.GKE_PROJECT }}/tpu-testing-image
- MAX_CHECKS: 360
- CHECK_SPEEP: 5
-
-jobs:
- setup-build-publish-deploy:
- name: tpu-testing-job
- runs-on: ubuntu-20.04
- strategy:
- fail-fast: false
- matrix:
- python-version: [3.7]
- xla-version: [1.6, 1.8]
- # Timeout: https://stackoverflow.com/a/59076067/4521646
- timeout-minutes: 50
-
- steps:
- - name: Set IMAGETAG
- run: echo "IMAGETAG=$(date +%s)_${{ matrix.python-version }}" >> $GITHUB_ENV
- - name: Install Go
- uses: actions/setup-go@v2
- with:
- go-version: 1.14.x
- - name: Set up Python 3.7
- uses: actions/setup-python@v2
- with:
- python-version: 3.7
-
- - name: Checkout Pytorch Lightning
- uses: actions/checkout@v2
- with:
- repository: PyTorchLightning/pytorch-lightning
- ref: ${{ github.event.pull_request.head.sha }}
-
- - name: Checkout ml-testing-accelerators
- uses: actions/checkout@v2
- with:
- repository: GoogleCloudPlatform/ml-testing-accelerators
- path: ml-testing-accelerators
- ref: 5e88ac24f631c27045e62f0e8d5dfcf34e425e25
-
- - name: Setup gcloud CLI
- uses: GoogleCloudPlatform/github-actions/setup-gcloud@master
- with:
- version: '290.0.1'
- service_account_key: ${{ secrets.GKE_SA_KEY_BASE64 }}
- project_id: ${{ secrets.GKE_PROJECT }}
- export_default_credentials: true
-
- # Configure Docker to use the gcloud command-line tool as a credential helper for authentication.
- - name: Configure Docker
- run: |-
- gcloud --quiet auth configure-docker
- shell: bash
- - name: Build and Push Docker Image
- env:
- PYTHON_VER: ${{ matrix.python-version }}
- XLA_VER: ${{ matrix.xla-version }}
- run: |
- #cd dockers/tpu-tests
- docker build --tag "$IMAGE:$IMAGETAG" -f ./dockers/tpu-tests/Dockerfile --build-arg "PYTHON_VERSION=$PYTHON_VER" --build-arg "PYTORCH_VERSION=$XLA_VER" .
- docker push "$IMAGE:$IMAGETAG"
- shell: bash
-
- - name: Install jsonnet
- run: |-
- go get github.com/google/go-jsonnet/cmd/jsonnet
- shell: bash
- # Get the GKE credentials so we can deploy to the cluster
- # Use either zone or region depending on cluster setup.
- - run: |-
- gcloud container clusters get-credentials "$GKE_CLUSTER" --zone "$GKE_ZONE"
- shell: bash
-
- - name: Deploy the job on the kubernetes cluster
- env:
- XLA_VER: ${{ matrix.xla-version }}
- run: |-
- python -c "fname = 'dockers/tpu-tests/tpu_test_cases.jsonnet' ; ttt = open(fname).read().replace('pytorch-VERSION', 'pytorch-$XLA_VER') ; open(fname, 'w').write(ttt)"
- job_name=$(jsonnet -J ml-testing-accelerators/ dockers/tpu-tests/tpu_test_cases.jsonnet --ext-str image=$IMAGE --ext-str image-tag=$IMAGETAG | kubectl create -f -) && \
- job_name=${job_name#job.batch/} && \
- job_name=${job_name% created} && \
- echo "Waiting on kubernetes job: $job_name in cluster: $GKE_CLUSTER" && \
- i=0 && \
- # 60 checks spaced 30s apart = 900s total.
- status_code=2 && \
- # Check on the job periodically. Set the status code depending on what
- # happened to the job in Kubernetes. If we try MAX_CHECKS times and
- # still the job hasn't finished, give up and return the starting
- # non-zero status code.
- printf "Waiting for job to finish: " && \
- while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else printf "." ; fi; sleep $CHECK_SPEEP; done && \
- echo "Done waiting. Job status code: $status_code" && \
- pod_name=$(kubectl get po -l controller-uid=`kubectl get job $job_name -o "jsonpath={.metadata.labels.controller-uid}"` | awk 'match($0,!/NAME/) {print $1}') && \
- echo "GKE pod name: $pod_name" && \
- kubectl logs -f $pod_name --container=train > /tmp/full_output.txt
- if grep -q '' /tmp/full_output.txt ; then csplit /tmp/full_output.txt '//'; else mv /tmp/full_output.txt xx00; fi && \
- # First portion is the test logs. Print these to Github Action stdout.
- cat xx00 && \
- echo "Done with log retrieval attempt." && \
- gcloud container images delete "$IMAGE:$IMAGETAG" --force-delete-tags && \
- echo "Status code: $status_code"
- exit $status_code
- shell: bash
-
- - name: Statistics
- if: success()
- run: |
- mv ./xx01 coverage
- # TODO: add human readable report
- cat coverage
- # sudo pip install pycobertura
- # pycobertura show coverage.xml
-
- - name: Upload coverage results
- uses: actions/upload-artifact@v2
- with:
- name: coverage-TPU
- path: coverage
-
- - name: Upload coverage to Codecov
- uses: codecov/codecov-action@v1
- # see: https://github.com/actions/toolkit/issues/399
- continue-on-error: true
- if: always()
- with:
- token: ${{ secrets.CODECOV_TOKEN }}
- file: coverage
- flags: tpu,pytest
- name: TPU-coverage
- fail_ci_if_error: true
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7a4ee3f7ddb3a..eb4c75bd46bbf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a flavor of `training_step` that takes `dataloader_iter` as an argument ([#8807](https://github.com/PyTorchLightning/pytorch-lightning/pull/8807))
-- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))
+- Added `state_key` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))
- Progress tracking
@@ -56,12 +56,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953))
+ * Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))
- Checkpoint saving & loading extensibility:
* Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))
* Refactored CheckpointConnector to offload validating logic to the checkpoitn IO plugin ([#9045](https://github.com/PyTorchLightning/pytorch-lightning/pull/9045))
+- Loop customization:
+ * Added `Closure` and `AbstractClosure` classes ([#8642](https://github.com/PyTorchLightning/pytorch-lightning/pull/8642))
+
+
+- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))
+
+
- Added DeepSpeed Stage 1 support ([#8974](https://github.com/PyTorchLightning/pytorch-lightning/pull/8974))
@@ -86,6 +94,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929))
+- Added validate logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080))
+
+
+- Add support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084))
+
+
### Changed
- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
@@ -140,10 +154,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851))
-- Deprecated `prepare_data_per_node` flag on Trainer and set it as a property of `DataHooks`, accessible in the `LightningModule` and `LightningDataModule` [#8958](https://github.com/PyTorchLightning/pytorch-lightning/pull/8958)
+- Deprecated `prepare_data_per_node` flag on Trainer and set it as a property of `DataHooks`, accessible in the `LightningModule` and `LightningDataModule` ([#8958](https://github.com/PyTorchLightning/pytorch-lightning/pull/8958))
+
+
+- Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065))
+
+
+- Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162))
--
### Removed
@@ -198,6 +217,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `on_train_epoch_end` from `Accelerator` ([#9035](https://github.com/PyTorchLightning/pytorch-lightning/pull/9035))
+- Removed `InterBatchProcessor` in favor of `DataLoaderIterDataFetcher` ([#9052](https://github.com/PyTorchLightning/pytorch-lightning/pull/9052))
+
+
+- Removed `Plugin` in `base_plugin.py`, access `TrainingTypePlugin` and `PrecisionPlugin` directly instead ([#9066](https://github.com/PyTorchLightning/pytorch-lightning/pull/9066))
+
+
+- Removed `teardown` from `ParallelPlugin` ([#8943](https://github.com/PyTorchLightning/pytorch-lightning/pull/8943))
+
+
### Fixed
- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
@@ -211,9 +239,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed bug where data-loading functions where not getting the correct running stage passed ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))
+
- Fixed a bug in the binary search mode of auto batch size scaling where exception was thrown if the first trainer run resulted in OOM ([#8954](https://github.com/PyTorchLightning/pytorch-lightning/pull/8954))
+- Fixed not setting a default value for `max_epochs` if `max_time` was specified on the `Trainer` constructor ([#9072](https://github.com/PyTorchLightning/pytorch-lightning/pull/9072))
+
+
## [1.4.3] - 2021-08-17
- Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))
diff --git a/README.md b/README.md
index d31d652850458..c80995293ae9d 100644
--- a/README.md
+++ b/README.md
@@ -78,14 +78,14 @@ Lightning is rigorously tested across multiple GPUs, TPUs CPUs and against major
-| System / PyTorch ver. | 1.6 (min. req.) | 1.7 | 1.8 (LTS) | 1.9 (latest) | 1.10 (nightly) |
-| :----------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
-| Conda py3.7 \[linux\] | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) |
-| Linux py3.7 \[GPUs\*\*\] | - | - | [![Build Status]()](https://dev.azure.com/PytorchLightning/pytorch-lightning/_build/latest?definitionId=6&branchName=master) | - | - |
-| Linux py3.{6,7} \[TPUs\*\*\*\] | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) | - | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) | - | - |
-| Linux py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
-| OSX py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
-| Windows py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
+| System / PyTorch ver. | 1.6 (min. req.) | 1.7 | 1.8 (LTS) | 1.9 (latest) | 1.10 (nightly) |
+| :------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| Conda py3.7 \[linux\] | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) |
+| Linux py3.7 \[GPUs\*\*\] | - | - | [![Build Status]()](https://dev.azure.com/PytorchLightning/pytorch-lightning/_build/latest?definitionId=6&branchName=master) | - | - |
+| Linux py3.7 \[TPUs\*\*\*\] | - | - | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning/tree/master.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning/tree/master) | - | - |
+| Linux py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
+| OSX py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
+| Windows py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
- _\*\* tests run on two NVIDIA P100_
- _\*\*\* tests run on Google GKE TPUv2/3_
diff --git a/docs/source/advanced/mixed_precision.rst b/docs/source/advanced/mixed_precision.rst
new file mode 100644
index 0000000000000..ea784f0894a00
--- /dev/null
+++ b/docs/source/advanced/mixed_precision.rst
@@ -0,0 +1,83 @@
+.. testsetup:: *
+
+ from pytorch_lightning import Trainer
+
+
+.. _amp:
+
+Mixed Precision Training
+========================
+
+Mixed precision combines the use of both FP32 and lower bit floating points (such as FP16) to reduce memory footprint during model training, resulting in improved performance.
+
+Lightning offers mixed precision training for GPUs and CPUs, as well as bfloat16 mixed precision training for TPUs.
+
+.. note::
+
+ In some cases it is important to remain in FP32 for numerical stability, so keep this in mind when using mixed precision.
+
+ For example when running scatter operations during the forward (such as torchpoint3d) computation must remain in FP32.
+
+FP16 Mixed Precision
+--------------------
+
+In most cases, mixed precision uses FP16. Supported torch operations are automatically run in FP16, saving memory and improving throughput on GPU and TPU accelerators.
+
+Since computation happens in FP16, there is a chance of numerical instability. This is handled internally by a dynamic grad scaler which skips steps that are invalid, and adjusts the scaler to ensure subsequent steps fall within a finite range. For more information `see the autocast docs `__.
+
+.. note::
+
+ When using TPUs, setting ``precision=16`` will enable bfloat16 which is the only supported precision type on TPUs.
+
+.. testcode::
+ :skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE or not torch.cuda.is_available()
+
+ Trainer(gpus=1, precision=16)
+
+BFloat16 Mixed Precision
+------------------------
+
+.. warning::
+
+ BFloat16 requires PyTorch 1.10 or later. Currently this requires installing `PyTorch Nightly `__.
+
+ BFloat16 is also experimental and may not provide large speedups or memory improvements, but offer better numerical stability.
+
+ Do note for GPUs, largest benefits require `Ampere `__ based GPUs, such as A100s or 3090s.
+
+BFloat16 Mixed precision is similar to FP16 mixed precision, however we maintain more of the "dynamic range" that FP32 has to offer. This means we are able to improve numerical stability, compared to FP16 mixed precision. For more information see `this TPU performance blog post `__.
+
+Since BFloat16 is more stable than FP16 during training, we do not need to worry about any gradient scaling or nan gradient values that comes with using FP16 mixed precision.
+
+.. testcode::
+ :skipif: not _TORCH_BFLOAT_AVAILABLE
+
+ Trainer(gpus=1, precision="bf16")
+
+It is also possible to use BFloat16 mixed precision on the CPU, relying on MKLDNN under the hood.
+
+.. testcode::
+ :skipif: not _TORCH_CPU_AMP_AVAILABLE
+
+ Trainer(precision="bf16")
+
+NVIDIA APEX Mixed Precision
+---------------------------
+
+.. warning::
+
+ We strongly recommend to use the above native mixed precision rather than NVIDIA APEX unless you require more finer control.
+
+`NVIDIA APEX `__ offers some additional flexibility in setting mixed precision. This can be useful for when wanting to try out different precision configurations, such as keeping most of your weights in FP16 as well as running computation in FP16.
+
+.. testcode::
+ :skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE or not torch.cuda.is_available()
+
+ Trainer(gpus=1, amp_backend="apex")
+
+Set the `NVIDIA optimization level `__ via the trainer.
+
+.. testcode::
+ :skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE or not torch.cuda.is_available()
+
+ Trainer(gpus=1, amp_backend="apex", amp_level="O2")
diff --git a/docs/source/benchmarking/benchmarks.rst b/docs/source/benchmarking/benchmarks.rst
index f5a2e4e19b7fa..05959587a6684 100644
--- a/docs/source/benchmarking/benchmarks.rst
+++ b/docs/source/benchmarking/benchmarks.rst
@@ -12,3 +12,6 @@ In average for simple MNIST CNN classifier we are only about 0.06s slower per ep
.. figure:: ../_static/images/benchmarks/figure-parity-times.png
:alt: Speed parity to vanilla PT, created on 2020-12-16
:width: 500
+
+
+Learn more about reproducible benchmarking from the `PyTorch Reproducibility Guide _`.
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 8ddc896b6e912..88c22059b3fe1 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -107,6 +107,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
"sphinx_copybutton",
"sphinx_paramlinks",
"sphinx_togglebutton",
+ "pt_lightning_sphinx_theme.extensions.lightning_tutorials",
]
# Add any paths that contain templates here, relative to this directory.
@@ -370,6 +371,8 @@ def package_list_from_file(file):
_XLA_AVAILABLE,
_TPU_AVAILABLE,
_TORCHVISION_AVAILABLE,
+ _TORCH_BFLOAT_AVAILABLE,
+ _TORCH_CPU_AMP_AVAILABLE,
_module_available,
)
_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse")
diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst
index 2c9ee612ceb22..b007fd479b0d0 100644
--- a/docs/source/extensions/callbacks.rst
+++ b/docs/source/extensions/callbacks.rst
@@ -113,16 +113,69 @@ Lightning has a few built-in callbacks.
----------
+.. _Persisting Callback State:
+
Persisting State
----------------
Some callbacks require internal state in order to function properly. You can optionally
choose to persist your callback's state as part of model checkpoint files using the callback hooks
:meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`.
-However, you must follow two constraints:
+Note that the returned state must be able to be pickled.
+
+When your callback is meant to be used only as a singleton callback then implementing the above two hooks is enough
+to persist state effectively. However, if passing multiple instances of the callback to the Trainer is supported, then
+the callback must define a :attr:`~pytorch_lightning.callbacks.Callback.state_key` property in order for Lightning
+to be able to distinguish the different states when loading the callback state. This concept is best illustrated by
+the following example.
+
+.. testcode::
+
+ class Counter(Callback):
+ def __init__(self, what="epochs", verbose=True):
+ self.what = what
+ self.verbose = verbose
+ self.state = {"epochs": 0, "batches": 0}
+
+ @property
+ def state_key(self):
+ # note: we do not include `verbose` here on purpose
+ return self._generate_state_key(what=self.what)
+
+ def on_train_epoch_end(self, *args, **kwargs):
+ if self.what == "epochs":
+ self.state["epochs"] += 1
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.what == "batches":
+ self.state["batches"] += 1
+
+ def on_load_checkpoint(self, trainer, pl_module, callback_state):
+ self.state.update(callback_state)
+
+ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+ return self.state.copy()
+
+
+ # two callbacks of the same type are being used
+ trainer = Trainer(callbacks=[Counter(what="epochs"), Counter(what="batches")])
+
+A Lightning checkpoint from this Trainer with the two stateful callbacks will include the following information:
+
+.. code-block::
+
+ {
+ "state_dict": ...,
+ "callbacks": {
+ "Counter{'what': 'batches'}": {"batches": 32, "epochs": 0},
+ "Counter{'what': 'epochs'}": {"batches": 0, "epochs": 2},
+ ...
+ }
+ }
-1. Your returned state must be able to be pickled.
-2. You can only use one instance of that class in the Trainer callbacks list. We don't support persisting state for multiple callbacks of the same class.
+The implementation of a :attr:`~pytorch_lightning.callbacks.Callback.state_key` is essential here. If it were missing,
+Lightning would not be able to disambiguate the state for these two callbacks, and :attr:`~pytorch_lightning.callbacks.Callback.state_key`
+by default only defines the class name as the key, e.g., here ``Counter``.
Best Practices
diff --git a/docs/source/governance.rst b/docs/source/governance.rst
index 80742525f58c5..d40ec0618079f 100644
--- a/docs/source/governance.rst
+++ b/docs/source/governance.rst
@@ -90,7 +90,7 @@ For API removal, renaming or other forms of backward-incompatible changes, the p
#. Calls to the deprecated API remain unchanged in their function during the deprecation phase.
#. Two minor versions in the future at version X+2 the breaking change takes effect.
-The "X+2" rule is a recommendation and not a strict requirement. Longer deprecation cylces may apply for some cases.
+The "X+2" rule is a recommendation and not a strict requirement. Longer deprecation cycles may apply for some cases.
New API and features are declared as:
diff --git a/docs/source/guides/speed.rst b/docs/source/guides/speed.rst
index 4e3ed0b1de801..fd245e741b9aa 100644
--- a/docs/source/guides/speed.rst
+++ b/docs/source/guides/speed.rst
@@ -186,7 +186,7 @@ Read more in our :ref:`accelerators` and :ref:`plugins` guides.
-----------
-.. _amp:
+.. _speed_amp:
*********************************
Mixed precision (16-bit) training
@@ -210,7 +210,7 @@ Mixed precision (16-bit) training
Mixed precision combines the use of both 32 and 16 bit floating points to reduce memory footprint during model training, resulting in improved performance, achieving +3X speedups on modern GPUs.
-Lightning offers mixed precision or 16-bit training for GPUs and TPUs.
+Lightning offers mixed precision training for GPUs and CPUs, as well as bfloat16 mixed precision training for TPUs.
.. testcode::
diff --git a/docs/source/index.rst b/docs/source/index.rst
index f3c154a7d257b..e1de1ed30defa 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -54,6 +54,7 @@ PyTorch Lightning Documentation
common/loggers
advanced/multi_gpu
advanced/advanced_gpu
+ advanced/mixed_precision
common/weights_loading
advanced/checkpoint_io
common/optimizers
diff --git a/pyproject.toml b/pyproject.toml
index 3e0d6826877c3..e848464c9d2fa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -62,8 +62,10 @@ ignore_errors = "True"
[[tool.mypy.overrides]]
module = [
"pytorch_lightning.callbacks.pruning",
+ "pytorch_lightning.loops.closure",
"pytorch_lightning.trainer.evaluation_loop",
"pytorch_lightning.trainer.connectors.logger_connector",
+ "pytorch_lightning.utilities.apply_func",
"pytorch_lightning.utilities.argparse",
"pytorch_lightning.utilities.cli",
"pytorch_lightning.utilities.cloud_io",
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 2b647f5811fef..6038a8abc8f5c 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -186,7 +186,7 @@ def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
- hiddens(:class:`~torch.Tensor`): Passed in if
:paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
"""
- with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
+ with self.precision_plugin.train_step_context():
return self.training_type_plugin.training_step(*step_kwargs.values())
def post_training_step(self) -> None:
@@ -204,7 +204,7 @@ def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[S
- dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple val dataloaders used)
"""
- with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
+ with self.precision_plugin.val_step_context():
return self.training_type_plugin.validation_step(*step_kwargs.values())
def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
@@ -219,7 +219,7 @@ def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OU
- dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple test dataloaders used).
"""
- with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
+ with self.precision_plugin.test_step_context():
return self.training_type_plugin.test_step(*step_kwargs.values())
def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
@@ -234,7 +234,7 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
- dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple predict dataloaders used).
"""
- with self.precision_plugin.predict_step_context(), self.training_type_plugin.predict_step_context():
+ with self.precision_plugin.predict_step_context():
return self.training_type_plugin.predict_step(*step_kwargs.values())
def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py
index 48957219a1ec0..46e74193fb557 100644
--- a/pytorch_lightning/accelerators/cpu.py
+++ b/pytorch_lightning/accelerators/cpu.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -24,15 +23,8 @@ def setup(self, trainer: "pl.Trainer") -> None:
"""
Raises:
MisconfigurationException:
- If AMP is used with CPU, or if the selected device is not CPU.
+ If the selected device is not CPU.
"""
- if isinstance(self.precision_plugin, MixedPrecisionPlugin):
- raise MisconfigurationException(
- " Mixed precision is currenty only supported with the AMP backend"
- " and AMP + CPU is not supported. Please use a GPU option or"
- " change precision setting."
- )
-
if "cpu" not in str(self.root_device):
raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead.")
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index c38613b0e3159..fdb22a44ed307 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -34,20 +34,30 @@ class Callback(abc.ABC):
"""
@property
- def state_id(self) -> str:
+ def state_key(self) -> str:
"""
Identifier for the state of the callback. Used to store and retrieve a callback's state from the
- checkpoint dictionary by ``checkpoint["callbacks"][state_id]``. Implementations of a callback need to
- provide a unique state id if 1) the callback has state and 2) it is desired to maintain the state of
+ checkpoint dictionary by ``checkpoint["callbacks"][state_key]``. Implementations of a callback need to
+ provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of
multiple instances of that callback.
"""
return self.__class__.__qualname__
@property
- def _legacy_state_id(self) -> Type["Callback"]:
- """State identifier for checkpoints saved prior to version 1.5.0."""
+ def _legacy_state_key(self) -> Type["Callback"]:
+ """State key for checkpoints saved prior to version 1.5.0."""
return type(self)
+ def _generate_state_key(self, **kwargs: Any) -> str:
+ """
+ Formats a set of key-value pairs into a state key string with the callback class name prefixed.
+ Useful for defining a :attr:`state_key`.
+
+ Args:
+ **kwargs: A set of key-value pairs. Must be serializable to :class:`str`.
+ """
+ return f"{self.__class__.__qualname__}{repr(kwargs)}"
+
def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called before configure sharded model"""
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index ad7beec6927d7..77683ad2819f3 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -75,6 +75,13 @@ class EarlyStopping(Callback):
>>> from pytorch_lightning.callbacks import EarlyStopping
>>> early_stopping = EarlyStopping('val_loss')
>>> trainer = Trainer(callbacks=[early_stopping])
+
+ .. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the
+ following arguments:
+
+ *monitor, mode*
+
+ Read more: :ref:`Persisting Callback State`
"""
mode_dict = {"min": torch.lt, "max": torch.gt}
@@ -120,6 +127,10 @@ def __init__(
)
self.monitor = monitor or "early_stop_on"
+ @property
+ def state_key(self) -> str:
+ return self._generate_state_key(monitor=self.monitor, mode=self.mode)
+
def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._check_on_train_epoch_end is None:
# if the user runs validation multiple times per training epoch, we try to check after
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 1144af7e32e9f..e7daa4ee53cde 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -194,6 +194,12 @@ class ModelCheckpoint(Callback):
trainer.fit(model)
checkpoint_callback.best_model_path
+ .. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the
+ following arguments:
+
+ *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end*
+
+ Read more: :ref:`Persisting Callback State`
"""
CHECKPOINT_JOIN_CHAR = "-"
@@ -248,16 +254,27 @@ def __init__(
self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval, period)
self.__validate_init_configuration()
- def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- """
- When pretrain routine starts we build the ckpt dir on the fly
- """
- self.__resolve_ckpt_dir(trainer)
+ @property
+ def state_key(self) -> str:
+ return self._generate_state_key(
+ monitor=self.monitor,
+ mode=self.mode,
+ every_n_train_steps=self._every_n_train_steps,
+ every_n_epochs=self._every_n_epochs,
+ train_time_interval=self._train_time_interval,
+ save_on_train_epoch_end=self._save_on_train_epoch_end,
+ )
+
+ def on_init_end(self, trainer: "pl.Trainer") -> None:
if self._save_on_train_epoch_end is None:
# if the user runs validation multiple times per training epoch, we try to save checkpoint after
# validation instead of on train epoch end
self._save_on_train_epoch_end = trainer.val_check_interval == 1.0
+ def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+ """When pretrain routine starts we build the ckpt dir on the fly."""
+ self.__resolve_ckpt_dir(trainer)
+
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._last_time_checked = time.monotonic()
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 01f9b4fc025fe..096333388c3b1 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -456,6 +456,15 @@ def log(
f" of {list(self._metric_attributes.values())}"
)
+ if (
+ self.trainer.training
+ and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True)
+ and batch_size is None
+ ):
+ raise MisconfigurationException(
+ "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided."
+ )
+
results.log(
self._current_fx_name,
name,
diff --git a/pytorch_lightning/core/mixins/hparams_mixin.py b/pytorch_lightning/core/mixins/hparams_mixin.py
index 029ecc173bcf2..72129f22f54bb 100644
--- a/pytorch_lightning/core/mixins/hparams_mixin.py
+++ b/pytorch_lightning/core/mixins/hparams_mixin.py
@@ -15,7 +15,7 @@
import inspect
import types
from argparse import Namespace
-from typing import Optional, Sequence, Union
+from typing import MutableMapping, Optional, Sequence, Union
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
from pytorch_lightning.utilities import AttributeDict
@@ -104,7 +104,7 @@ class ``__init__`` to be ignored
frame = inspect.currentframe().f_back
save_hyperparameters(self, *args, ignore=ignore, frame=frame)
- def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
+ def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
hp = self._to_hparams_dict(hp)
if isinstance(hp, dict) and isinstance(self.hparams, dict):
@@ -113,7 +113,7 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
self._hparams = hp
@staticmethod
- def _to_hparams_dict(hp: Union[dict, Namespace, str]):
+ def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]):
if isinstance(hp, Namespace):
hp = vars(hp)
if isinstance(hp, dict):
diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py
index 1799def445ea4..81084d61b6bd2 100644
--- a/pytorch_lightning/loggers/comet.py
+++ b/pytorch_lightning/loggers/comet.py
@@ -268,10 +268,22 @@ def finalize(self, status: str) -> None:
@property
def save_dir(self) -> Optional[str]:
+ """
+ Gets the save directory.
+
+ Returns:
+ The path to the save directory.
+ """
return self._save_dir
@property
def name(self) -> str:
+ """
+ Gets the project name.
+
+ Returns:
+ The project name if it is specified, else "comet-default".
+ """
# Don't create an experiment if we don't have one
if self._experiment is not None and self._experiment.project_name is not None:
return self._experiment.project_name
@@ -283,6 +295,19 @@ def name(self) -> str:
@property
def version(self) -> str:
+ """
+ Gets the version.
+
+ Returns:
+ The first one of the following that is set in the following order
+
+ 1. experiment id.
+ 2. experiment key.
+ 3. "COMET_EXPERIMENT_KEY" environment variable.
+ 4. future experiment key.
+
+ If none are present generates a new guid.
+ """
# Don't create an experiment if we don't have one
if self._experiment is not None:
return self._experiment.id
diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py
index f179745d6d1d7..b810aebce47ba 100644
--- a/pytorch_lightning/loggers/neptune.py
+++ b/pytorch_lightning/loggers/neptune.py
@@ -267,17 +267,35 @@ def finalize(self, status: str) -> None:
@property
def save_dir(self) -> Optional[str]:
+ """
+ Gets the save directory of the experiment which in this case is ``None`` because Neptune does not save locally.
+
+ Returns:
+ None
+ """
# Neptune does not save any local files
return None
@property
def name(self) -> str:
+ """
+ Gets the name of the experiment.
+
+ Returns:
+ The name of the experiment if not in offline mode else "offline-name".
+ """
if self.offline_mode:
return "offline-name"
return self.experiment.name
@property
def version(self) -> str:
+ """
+ Gets the id of the experiment.
+
+ Returns:
+ The id of the experiment if not in offline mode else "offline-id-1234".
+ """
if self.offline_mode:
return "offline-id-1234"
return self.experiment.id
diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py
index e51e6b1f6a9d2..a9e1118b9d647 100644
--- a/pytorch_lightning/loggers/tensorboard.py
+++ b/pytorch_lightning/loggers/tensorboard.py
@@ -134,10 +134,22 @@ def log_dir(self) -> str:
@property
def save_dir(self) -> Optional[str]:
+ """
+ Gets the save directory where the TensorBoard experiments are saved.
+
+ Returns:
+ The local path to the save directory where the TensorBoard experiments are saved.
+ """
return self._save_dir
@property
def sub_dir(self) -> Optional[str]:
+ """
+ Gets the sub directory where the TensorBoard experiments are saved.
+
+ Returns:
+ The local path to the sub directory where the TensorBoard experiments are saved.
+ """
return self._sub_dir
@property
@@ -258,10 +270,22 @@ def finalize(self, status: str) -> None:
@property
def name(self) -> str:
+ """
+ Get the name of the experiment.
+
+ Returns:
+ The name of the experiment.
+ """
return self._name
@property
def version(self) -> int:
+ """
+ Get the experiment version.
+
+ Returns:
+ The experiment version if specified else the next version.
+ """
if self._version is None:
self._version = self._get_next_version()
return self._version
diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py
index 800b13b83c39c..9a3b6ccee64df 100644
--- a/pytorch_lightning/loggers/test_tube.py
+++ b/pytorch_lightning/loggers/test_tube.py
@@ -20,7 +20,7 @@
import pytorch_lightning as pl
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
-from pytorch_lightning.utilities import _module_available, rank_zero_warn
+from pytorch_lightning.utilities import _module_available, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only
_TESTTUBE_AVAILABLE = _module_available("test_tube")
@@ -36,6 +36,10 @@ class TestTubeLogger(LightningLoggerBase):
Log to local file system in `TensorBoard `_ format
but using a nicer folder structure (see `full docs `_).
+ Warning:
+ The test-tube package is no longer maintained and PyTorch Lightning will remove the :class:´TestTubeLogger´
+ in v1.7.0.
+
Install it with pip:
.. code-block:: bash
@@ -97,6 +101,10 @@ def __init__(
log_graph: bool = False,
prefix: str = "",
):
+ rank_zero_deprecation(
+ "The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7. We recommend switching to the"
+ " `pytorch_lightning.loggers.TensorBoardLogger` as an alternative."
+ )
if Experiment is None:
raise ImportError(
"You want to use `test_tube` logger which is not installed yet,"
@@ -197,10 +205,22 @@ def close(self) -> None:
@property
def save_dir(self) -> Optional[str]:
+ """
+ Gets the save directory.
+
+ Returns:
+ The path to the save directory.
+ """
return self._save_dir
@property
def name(self) -> str:
+ """
+ Gets the experiment name.
+
+ Returns:
+ The experiment name if the experiment exists, else the name specified in the constructor.
+ """
if self._experiment is None:
return self._name
@@ -208,6 +228,12 @@ def name(self) -> str:
@property
def version(self) -> int:
+ """
+ Gets the experiment version.
+
+ Returns:
+ The experiment version if the experiment exists, else the next version.
+ """
if self._experiment is None:
return self._version
diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py
index 0299cb522e59f..7081e95d352a3 100644
--- a/pytorch_lightning/loggers/wandb.py
+++ b/pytorch_lightning/loggers/wandb.py
@@ -219,15 +219,33 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
@property
def save_dir(self) -> Optional[str]:
+ """
+ Gets the save directory.
+
+ Returns:
+ The path to the save directory.
+ """
return self._save_dir
@property
def name(self) -> Optional[str]:
+ """
+ Gets the name of the experiment.
+
+ Returns:
+ The name of the experiment if the experiment exists else the name given to the constructor.
+ """
# don't create an experiment if we don't have one
return self._experiment.project_name() if self._experiment else self._name
@property
def version(self) -> Optional[str]:
+ """
+ Gets the id of the experiment.
+
+ Returns:
+ The id of the experiment if the experiment exists else the id given to the constructor.
+ """
# don't create an experiment if we don't have one
return self._experiment.id if self._experiment else self._id
diff --git a/pytorch_lightning/loops/__init__.py b/pytorch_lightning/loops/__init__.py
index c9775ed44155e..b7eb47167d26f 100644
--- a/pytorch_lightning/loops/__init__.py
+++ b/pytorch_lightning/loops/__init__.py
@@ -17,4 +17,3 @@
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
-from pytorch_lightning.loops.processors import IteratorBatchProcessor # noqa: F401
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 163d7681f29a8..3f94a0181672e 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -15,7 +15,7 @@
from collections import OrderedDict
from contextlib import contextmanager
from copy import copy
-from functools import partial, update_wrapper
+from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
import numpy as np
@@ -26,6 +26,7 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.closure import Closure, ClosureResult
from pytorch_lightning.loops.utilities import (
_check_training_step_output,
_process_training_step_output,
@@ -144,12 +145,12 @@ def advance(self, batch, batch_idx):
result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
if result:
- self.batch_outputs[opt_idx].append(copy(result.training_step_output))
+ self.batch_outputs[opt_idx].append(copy(result.result_collection))
else:
# in manual optimization, there is no looping over optimizers
result = self._run_optimization(batch_idx, split_batch)
if result:
- self.batch_outputs[0].append(copy(result.training_step_output))
+ self.batch_outputs[0].append(copy(result.result_collection))
def teardown(self) -> None:
# release memory
@@ -165,7 +166,7 @@ def _run_optimization(
split_batch: Any,
opt_idx: Optional[int] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
- ):
+ ) -> Optional[ClosureResult]:
"""Runs closure (train step + backward) together with optimization if necessary.
Args:
@@ -177,8 +178,7 @@ def _run_optimization(
# toggle model params
self._run_optimization_start(opt_idx, optimizer)
- result = AttributeDict()
- closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result)
+ closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens)
if self.trainer.fit_loop.should_accumulate():
# For gradient accumulation
@@ -199,7 +199,9 @@ def _run_optimization(
if self.trainer.lightning_module.automatic_optimization:
self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
else:
- result = self._training_step(split_batch, batch_idx, opt_idx, self._hiddens)
+ closure()
+
+ result = closure.get_result()
if result:
# if no result, user decided to skip optimization
@@ -211,37 +213,79 @@ def _run_optimization(
self._run_optimization_end(opt_idx)
return result
- def _training_step_and_backward_closure(
+ def _make_closure(
self,
split_batch: Any,
batch_idx: int,
opt_idx: int,
optimizer: Optimizer,
- hiddens: Tensor,
- return_result: AttributeDict,
- ) -> Optional[Tensor]:
- """Closure for training step and backward
+ hiddens: Any,
+ ) -> Closure:
+ """
+ Build a closure object that captures the given arguments and runs the `training_step` function and optionally
+ other functions such as `backward` and `zero_grad`.
+ """
+ step_fn = self._make_step_fn(split_batch, batch_idx, opt_idx, hiddens)
+ backward_fn = self._make_backward_fn(batch_idx, optimizer, opt_idx)
+ zero_grad_fn = self._make_zero_grad_fn(batch_idx, opt_idx, optimizer)
+
+ return Closure(
+ step_fn=step_fn,
+ backward_fn=backward_fn,
+ zero_grad_fn=zero_grad_fn,
+ profiler=self.trainer.profiler,
+ )
- Args:
- split_batch: the current tbptt split of the batch
- batch_idx: the index of the current batch
- opt_idx: the index of the current optimizer
- optimizer: the current optimizer
- hiddens: the hidden state of the recurrent net
- return_result: the storage of the trainstep results
+ def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Any) -> Callable[[], dict]:
+ """Build the step function that runs the `training_step` and processes its output."""
+ return partial(self._training_step, split_batch, batch_idx, opt_idx, hiddens)
+
+ def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]:
"""
+ Build a `zero_grad` function that zeroes the gradients before back-propagation.
+ Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on.
+ """
+
+ def zero_grad_fn():
+ self._on_before_zero_grad(optimizer)
+ self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)
- result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
- if result is not None:
- return_result.update(result)
- return return_result.loss
+ is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0
+ if (
+ not self._skip_backward
+ and self.trainer.lightning_module.automatic_optimization
+ and is_first_batch_to_accumulate
+ ):
+ return zero_grad_fn
- def _make_closure(self, *closure_args: Any, **closure_kwargs: Any) -> Callable:
- """Wraps the training step closure into a partial object which will be called within ``optimizer.step``."""
- partial_func = partial(self._training_step_and_backward_closure, *closure_args, **closure_kwargs)
- return update_wrapper(partial_func, self._training_step_and_backward_closure)
+ def _make_backward_fn(
+ self,
+ batch_idx: int,
+ optimizer: Optimizer,
+ opt_idx: int,
+ ) -> Optional[Callable[[Tensor], Tensor]]:
+ """
+ Build a `backward` function that handles back-propagation through the output produced by the `training_step`
+ function. Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on.
+ """
- def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -> None:
+ def backward_fn(loss: Tensor):
+ self.backward(loss, optimizer, opt_idx)
+
+ # when in dev debugging track the losses
+ # TODO: remove dev debugger tracking loss history
+ self.trainer.dev_debugger.track_train_loss_history(batch_idx, loss)
+
+ # check if loss or model weights are nan
+ if self.trainer.terminate_on_nan:
+ check_finite_loss(self.trainer.lightning_module, loss)
+
+ return loss
+
+ if not self._skip_backward and self.trainer.lightning_module.automatic_optimization:
+ return backward_fn
+
+ def _process_closure_result(self, opt_closure_result: Optional[ClosureResult]) -> None:
"""Checks if the closure results is finite and optionally breaks if it is not
Args:
@@ -286,18 +330,18 @@ def _training_step(
_check_training_step_output(self.trainer.lightning_module, training_step_output)
- training_step_output, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
- if training_step_output is None:
+ result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
+ if result_collection is None:
return
closure_loss = None
loss = None
if self.trainer.lightning_module.automatic_optimization:
# accumulate loss. if accumulate_grad_batches==1, no effect
- closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches
+ closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches
# the loss will get scaled for amp. avoid any modifications to it
loss = closure_loss.detach().clone()
- return AttributeDict(closure_loss=closure_loss, loss=loss, training_step_output=training_step_output)
+ return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection)
def _optimizer_step(
self, optimizer: torch.optim.Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable
@@ -438,60 +482,22 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator
else:
yield None
- def training_step_and_backward(
- self,
- split_batch: Any,
- batch_idx: int,
- opt_idx: int,
- optimizer: torch.optim.Optimizer,
- hiddens: Optional[Tensor],
- ) -> STEP_OUTPUT:
- """Wrap forward, zero_grad and backward in a closure so second order methods work"""
- with self.trainer.profiler.profile("training_step_and_backward"):
- # lightning module hook
- result = self._training_step(split_batch, batch_idx, opt_idx, hiddens)
-
- if not self._skip_backward and self.trainer.lightning_module.automatic_optimization:
- is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0
-
- if is_first_batch_to_accumulate:
- self._on_before_zero_grad(optimizer)
- self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)
-
- # backward pass
- if result is not None:
- with self.trainer.profiler.profile("backward"):
- self.backward(result, optimizer, opt_idx)
-
- # when in dev debugging track the losses
- self.trainer.dev_debugger.track_train_loss_history(batch_idx, result.loss)
-
- # check if loss or model weights are nan
- if self.trainer.terminate_on_nan:
- check_finite_loss(self.trainer.lightning_module, result.loss)
-
- else:
- self._warning_cache.warn(
- "training_step returned None. If this was on purpose, ignore this warning..."
- )
-
- return result
-
def backward(
- self, result: STEP_OUTPUT, optimizer: Optional[torch.optim.Optimizer], *args: Any, **kwargs: Any
- ) -> None:
+ self,
+ loss: Tensor,
+ optimizer: Optional[torch.optim.Optimizer],
+ opt_idx: Optional[int] = None,
+ *args: Any,
+ **kwargs: Any,
+ ) -> Tensor:
"""Performs the backward step.
Args:
- result: The output of the trainstep (including the loss value)
+ loss: The loss value to back-propagate on
optimizer: Current optimizer being used. ``None`` if using manual optimization.
opt_idx: Index of the current optimizer being used. ``None`` if using manual optimization.
"""
- # backward can be called manually in the training loop
- if isinstance(result, Tensor):
- self.trainer.accelerator.backward(result, optimizer, *args, **kwargs)
- else:
- result.closure_loss = self.trainer.accelerator.backward(result.closure_loss, optimizer, *args, **kwargs)
+ self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)
if not self.trainer.fit_loop.should_accumulate():
# track gradients
@@ -499,6 +505,7 @@ def backward(
if grad_norm_dict:
self.trainer.lightning_module._current_fx_name = "on_after_backward"
self.trainer.lightning_module.log_grad_norm(grad_norm_dict)
+ return loss
def _update_running_loss(self, current_loss: Tensor) -> None:
"""Updates the running loss value with the current value"""
@@ -547,12 +554,15 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio
the keyword arguments for the training step
"""
# enable not needing to add opt_idx to training_step
- step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
+ step_kwargs = OrderedDict([("batch", batch)])
lightning_module = self.trainer.lightning_module
+ training_step_fx = getattr(lightning_module, "training_step")
+
+ if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2):
+ step_kwargs["batch_idx"] = batch_idx
if len(self.trainer.optimizers) > 1:
- training_step_fx = getattr(lightning_module, "training_step")
has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
if has_opt_idx_in_train_step:
if not lightning_module.automatic_optimization:
diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py
new file mode 100644
index 0000000000000..b47af75199708
--- /dev/null
+++ b/pytorch_lightning/loops/closure.py
@@ -0,0 +1,131 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Callable, Optional
+
+from torch import Tensor
+
+from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler
+from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
+from pytorch_lightning.utilities.warnings import WarningCache
+
+
+@dataclass
+class ClosureResult:
+ """A container to hold the result of a :class:`AbstractClosure` call.
+
+ Attributes:
+ closure_loss: The loss with a graph attached.
+ loss: A detached copy of the closure loss.
+ result_collection: A collection of results returned by the closure.
+ """
+
+ closure_loss: Optional[Tensor]
+ loss: Optional[Tensor]
+ result_collection: Optional[ResultCollection]
+
+
+class AbstractClosure(ABC):
+ """
+ Abstract base class for optimizer closures in Lightning.
+
+ Formally, a closure is binding variables from an external scope to a function that does a computation on these
+ variables without taking them explicitly as input. This has the benefit that a closure can be passed to an
+ object which later can call it like a function but without requiring to pass in any arguments.
+
+ This class provides a simple abstraction making the instance of this class callable like a function while capturing
+ the :class:`ClosureResult` and caching it.
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self._result: Optional[ClosureResult] = None
+
+ def get_result(self) -> Optional[ClosureResult]:
+ """The cached result from the last time the closure was called. Once accessed, the internal reference
+ gets reset and the consumer will have to hold on to the reference as long as necessary."""
+ result = self._result
+ self._result = None # free memory
+ return result
+
+ @abstractmethod
+ def closure(self, *args: Any, **kwargs: Any) -> Optional[ClosureResult]:
+ """Implements the behavior of the closure once it is getting called."""
+ pass
+
+ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
+ self._result = self.closure(*args, **kwargs)
+ if self._result is not None:
+ return self._result.loss
+
+
+class Closure(AbstractClosure):
+ """
+ An implementation of a :class:`AbstractClosure` for optimization in Lightning that combines three elementary
+ closures into one: ``training_step``, ``backward`` and ``zero_grad``.
+
+ The Closure gets created by the training loop(s) and is then passed to the
+ :meth:`torch.optim.Optimizer.step` method. An optimizer is responsible for calling the closure and optionally
+ do something with the output.
+
+ Args:
+ step_fn: This is typically the :meth:`pytorch_lightning.core.lightning.LightningModule.training_step
+ wrapped with processing for its outputs
+ backward_fn: A function that takes a loss value as input, performs back-propagation and returns the loss value.
+ Can be set to ``None`` to skip the backward operation.
+ zero_grad_fn: A function that zeroes the gradients. Can be set to ``None`` to skip zero_grad, for example
+ when accumulating gradients.
+ profiler: A profiler for profiling the actions of the passed in closure functions.
+
+ Example:
+
+ closure = Closure()
+ optimizer = torch.optim.Adam(...)
+ optimizer.step(closure)
+ """
+
+ warning_cache = WarningCache()
+
+ def __init__(
+ self,
+ step_fn: Callable[[], dict],
+ backward_fn: Optional[Callable[[Tensor], Tensor]] = None,
+ zero_grad_fn: Optional[Callable[[], None]] = None,
+ profiler: Optional[BaseProfiler] = None,
+ ):
+ super().__init__()
+ self._step_fn = step_fn
+ self._backward_fn = backward_fn
+ self._zero_grad_fn = zero_grad_fn
+ self._profiler = PassThroughProfiler() if profiler is None else profiler
+
+ def closure(self, *args: Any, **kwargs: Any) -> Optional[ClosureResult]:
+ with self._profiler.profile("training_step_and_backward"):
+ step_output = self._step_fn()
+ step_output = ClosureResult(**step_output) if step_output else None
+
+ if step_output is None:
+ self.warning_cache.warn("training_step returned None. If this was on purpose, ignore this warning...")
+
+ if self._zero_grad_fn is not None:
+ with self._profiler.profile("zero_grad"):
+ self._zero_grad_fn()
+
+ if self._backward_fn is not None and step_output is not None and step_output.closure_loss is not None:
+ with self._profiler.profile("backward"):
+ step_output.closure_loss = self._backward_fn(step_output.closure_loss)
+
+ return step_output
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index d15236dc7694c..68b75b68eb91b 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -98,17 +98,13 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
"""Performs evaluation on one single dataloader"""
void(*args, **kwargs)
+ dataloader_idx: int = self.current_dataloader_idx
dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
- dataloader = self.trainer.data_connector.get_profiled_dataloader(
- dataloader, dataloader_idx=self.current_dataloader_idx
- )
- dataloader_iter = iter(dataloader)
+ dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
- dl_max_batches = self._max_batches[self.current_dataloader_idx]
+ dl_max_batches = self._max_batches[dataloader_idx]
- dl_outputs = self.epoch_loop.run(
- dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders
- )
+ dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
# store batch level output per dataloader
if self.should_track_batch_outputs_for_epoch_end:
diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index eb3f9dad58bcf..158d4cf527143 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -19,7 +19,9 @@
from torch import Tensor
from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
from pytorch_lightning.trainer.progress import Progress
+from pytorch_lightning.utilities.fetching import AbstractDataFetcher
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -37,6 +39,7 @@ def __init__(self) -> None:
self._num_dataloaders: Optional[int] = None
self.outputs: List[STEP_OUTPUT] = []
self.batch_progress = Progress()
+ self.dataloader_iter: Optional[Iterator] = None
@property
def done(self) -> bool:
@@ -56,22 +59,24 @@ def reset(self) -> None:
self.batch_progress.current.reset()
def on_run_start(
- self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
+ self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
) -> None:
"""Adds the passed arguments to the loop's state if necessary
Args:
- dataloader_iter: iterator over the dataloader
+ data_fetcher: the current data_fetcher wrapping the dataloader
dataloader_idx: index of the current dataloader
dl_max_batches: maximum number of batches the dataloader can produce
num_dataloaders: the total number of dataloaders
"""
- void(dataloader_iter, dataloader_idx)
+ void(dataloader_idx)
self._dl_max_batches = dl_max_batches
self._num_dataloaders = num_dataloaders
+ self.dataloader_iter = _prepare_dataloader_iter(data_fetcher, self.batch_progress.current.ready)
+
def advance(
- self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
+ self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
) -> None:
"""Calls the evaluation step with the corresponding hooks and updates the logger connector.
@@ -84,9 +89,9 @@ def advance(
Raises:
StopIteration: If the current batch is None
"""
- void(dl_max_batches, num_dataloaders)
+ void(data_fetcher, dl_max_batches, num_dataloaders)
- batch_idx, (batch, _) = next(dataloader_iter)
+ batch_idx, (batch, _) = next(self.dataloader_iter)
if batch is None:
raise StopIteration
diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
index f63bb4877e745..73557f71ade73 100644
--- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -1,11 +1,13 @@
from collections import OrderedDict
from typing import Any, Dict, Iterator, List, Optional, Tuple
+import torch
from deprecate import void
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.progress import Progress
+from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.warnings import WarningCache
@@ -140,7 +142,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None
self.batch_progress.increment_completed()
if self.should_store_predictions:
- self.predictions.append(predictions)
+ self.predictions.append(move_data_to_device(predictions, torch.device("cpu")))
def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Any]:
"""
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 741a05cd5701e..43d51fe0027c6 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -11,24 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Union
import torch
from pytorch_lightning import loops # import as loops to avoid circular imports
from pytorch_lightning.loops.batch import TrainingBatchLoop
-from pytorch_lightning.loops.processors import IteratorBatchProcessor
+from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import STEP_OUTPUT
-# TODO: currently, the batch processor is only a loop when tbptt is enabled.
-# As we introduce more specialized batch processors, we may want to choose a
-# more suitable abstraction for them.
-BATCH_LOOP_TYPE = Optional[Tuple[TrainingBatchLoop, IteratorBatchProcessor]]
-
class TrainingEpochLoop(loops.Loop):
"""
@@ -50,7 +45,7 @@ def __init__(self, min_steps: int, max_steps: int):
self.batch_progress = Progress()
self.scheduler_progress = SchedulerProgress()
- self.batch_loop: BATCH_LOOP_TYPE = None
+ self.batch_loop: Optional[TrainingBatchLoop] = None
self.val_loop: Optional["loops.EvaluationLoop"] = None
self._results = ResultCollection(training=True)
@@ -81,7 +76,7 @@ def done(self) -> bool:
def connect(
self,
- batch_loop: BATCH_LOOP_TYPE = None,
+ batch_loop: TrainingBatchLoop = None,
val_loop: Optional["loops.EvaluationLoop"] = None,
) -> None:
"""Optionally connect a custom batch or validation loop to this training epoch loop."""
@@ -102,14 +97,16 @@ def reset(self) -> None:
self.scheduler_progress.current.reset()
self.batch_loop.optim_progress.reset_on_epoch()
- def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+ def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
# hook
self.trainer.logger_connector.on_epoch_start()
self.trainer.call_hook("on_epoch_start")
self.trainer.call_hook("on_train_epoch_start")
self.trainer.fit_loop.epoch_progress.increment_started()
- def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
+ self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_idx + 1)
+
+ def advance(self, *args: Any, **kwargs: Any) -> None:
"""Runs a single training batch.
Args:
@@ -118,33 +115,18 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
Raises:
StopIteration: When the epoch is canceled by the user returning -1
"""
- if isinstance(self.batch_loop, IteratorBatchProcessor):
- # By contract, when taking `dataloader_iter` as an argument,
- # `training_step` is responsible for reporting `is_last` in the
- # result dict, which is used to determine the stop condition for
- # the epoch. So as long as `advance` is invoked, it's correct to
- # assume that there are more batches to be processed.
- self.batch_progress.increment_ready()
- with self.trainer.profiler.profile("run_training_batch"):
- batch_output = self.batch_loop.run(dataloader_iter)
- self.batch_progress.increment_processed()
- is_last = batch_output.is_last
- else:
- _, (batch, is_last) = next(dataloader_iter)
-
- if not self.trainer.data_connector.train_data_fetcher.store_on_device:
- with self.trainer.profiler.profile("training_batch_to_device"):
- batch = self.trainer.accelerator.batch_to_device(batch)
-
- # ------------------------------------
- # TRAINING_STEP + TRAINING_STEP_END
- # ------------------------------------
- self.batch_progress.increment_ready()
-
- with self.trainer.profiler.profile("run_training_batch"):
- batch_output = self.batch_loop.run(batch, self.batch_idx)
-
- self.batch_progress.increment_processed()
+ batch_idx, (batch, is_last) = next(self.dataloader_iter)
+
+ if not self.trainer.data_connector.train_data_fetcher.store_on_device:
+ with self.trainer.profiler.profile("training_batch_to_device"):
+ batch = self.trainer.accelerator.batch_to_device(batch)
+
+ self.batch_progress.increment_ready()
+
+ with self.trainer.profiler.profile("run_training_batch"):
+ batch_output = self.batch_loop.run(batch, batch_idx)
+
+ self.batch_progress.increment_processed()
self.is_last_batch = is_last
@@ -162,8 +144,7 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True)
# hook
- if not isinstance(self.batch_loop, IteratorBatchProcessor):
- self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0)
+ self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0)
self.trainer.call_hook("on_batch_end")
self.trainer.logger_connector.on_batch_end()
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 49af10d4b2c0d..4a09c0ca1faeb 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -14,7 +14,7 @@
import logging
from contextlib import suppress
-from typing import Optional
+from typing import Any, Dict, Optional
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
@@ -40,6 +40,8 @@ def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] =
self.min_epochs = min_epochs
self.epoch_loop: Optional[TrainingEpochLoop] = None
self.epoch_progress = Progress()
+ # caches the loaded dataloader state until dataloader objects are available
+ self._dataloader_state_dict: Dict[str, Any] = {}
@property
def current_epoch(self) -> int:
@@ -175,6 +177,10 @@ def on_advance_start(self) -> None:
if self.current_epoch != 0 and self.trainer._should_reload_dl_epoch:
self.trainer.reset_train_dataloader(model)
+ if self._dataloader_state_dict:
+ self.trainer.train_dataloader.load_state_dict(self._dataloader_state_dict)
+ self._dataloader_state_dict = {}
+
# TODO: specify the possible exception
with suppress(Exception):
# set seed for distributed sampler (enables shuffling for each epoch)
@@ -193,12 +199,11 @@ def on_advance_start(self) -> None:
def advance(self) -> None:
"""Runs one whole epoch."""
dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)
- dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader)
- dataloader_iter = iter(dataloader)
+ data_fetcher = self.trainer.data_connector.get_profiled_dataloader(dataloader)
with self.trainer.profiler.profile("run_training_epoch"):
# run train epoch
- epoch_output = self.epoch_loop.run(dataloader_iter)
+ epoch_output = self.epoch_loop.run(data_fetcher)
if epoch_output is None:
return
@@ -235,3 +240,13 @@ def should_accumulate(self) -> bool:
def teardown(self) -> None:
self.epoch_loop.teardown()
+
+ def on_save_checkpoint(self) -> Dict:
+ state_dict = super().on_save_checkpoint()
+ # FIXME(@tchaton) Should pass has_completed=True when iterator is exhausted ?
+ state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
+ return state_dict
+
+ def on_load_checkpoint(self, state_dict: Dict) -> None:
+ # cache the dataloader state dict until the dataloader objects are available
+ self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
diff --git a/pytorch_lightning/loops/processors/__init__.py b/pytorch_lightning/loops/processors/__init__.py
deleted file mode 100644
index 9fcbe9e82dca8..0000000000000
--- a/pytorch_lightning/loops/processors/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from pytorch_lightning.loops.processors.iterator_batch_processor import IteratorBatchProcessor # noqa: F401
diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py
deleted file mode 100644
index c1981173215ae..0000000000000
--- a/pytorch_lightning/loops/processors/iterator_batch_processor.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-from collections import OrderedDict
-from typing import Any, Dict, Iterator, List, Optional, Tuple
-
-import torch
-
-import pytorch_lightning as pl
-from pytorch_lightning.loops.utilities import (
- _check_training_step_output,
- _process_training_step_output,
- check_finite_loss,
-)
-from pytorch_lightning.trainer.progress import OptimizationProgress
-from pytorch_lightning.trainer.supporters import TensorRunningAccum
-from pytorch_lightning.utilities import AttributeDict
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-
-log = logging.getLogger(__name__)
-
-
-class IteratorBatchProcessor:
- """
- The processor for performing a training iteration when ``training_step`` needs access to the
- dataloader. It is selected when the signature of ``training_step`` contains ``dataloader_iter``:
-
- def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
-
- The ``training_step`` is allowed to fetch multiple batches during one training iteration. The
- framework provides minimum amount of automation with regards to model optimization. The
- flexibility allows for ease of experimentation with inter-batch parallelism techniques.
-
- This processor doesn't support ``automatic_optimization`` and ``tbptt``. An error will be thrown
- if the ``LightningModule`` or the ``Trainer`` is configured to use these features.
-
- The ``training_step`` is responsible for reporting whether it has reached the last batch by
- including an ``is_last`` field in the result dict. Failing to do so will result in an error.
-
- The ``training_step`` should only optimize the model with one batch for the sake of API and
- reporting consistency (TODO: consider removing this limitation).
-
- Args:
- trainer: a reference to the trainer
- model: a reference to the lightning module (for config validation purposes only)
- """
-
- def __init__(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
- if is_overridden("on_train_batch_start", model):
- raise MisconfigurationException(
- "The model hook `on_train_batch_start` is not compatible with "
- "taking a `dataloader_iter` argument in your `training_step`."
- )
- if is_overridden("on_train_batch_end", model):
- raise MisconfigurationException(
- "The model hook `on_train_batch_end` is not compatible with "
- "taking a `dataloader_iter` argument in your `training_step`."
- )
- if is_overridden("tbptt_split_batch", model):
- raise MisconfigurationException(
- "The model hook `tbptt_split_batch` is not compatible with "
- "taking a `dataloader_iter` argument in your `training_step`."
- )
- if trainer.accumulate_grad_batches != 1:
- raise MisconfigurationException(
- "`accumulate_grad_batches` can only be 1 when your "
- "`training_step` takes `dataloader_iter` as an argument."
- )
-
- self.trainer = trainer
-
- # The following field is not used by the processor since it doesn't support automatic
- # optimization and tbptt. Initializing them regardless since they are currently expected by
- # `FitLoop` or `TrainingEpochLoop`.
- # TODO: come up with an abstraction for "batch processors" so they can be better decoupled
- # with parent loops.
- self.accumulated_loss: Optional[torch.Tensor] = None
- self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=1)
- self.optim_progress = OptimizationProgress()
- self.split_idx: int = 0
- self._skip_backward = False
-
- def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
- """
- Returns the number of active optimizers.
- """
- return len(self.trainer.optimizers)
-
- def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, torch.optim.Optimizer]]:
- """
- Returns the currently active optimizers.
-
- Returns:
- A list of tuples (opt_idx, optimizer) of currently active optimizers.
- """
- return list(enumerate(self.trainer.optimizers))
-
- def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]:
- """
- Args:
- dataloader_iter: the iterator over the dataloader producing the new batch
- """
- batch_idx, (dataloader_iter, is_last) = next(dataloader_iter)
-
- self.trainer.logger_connector.on_batch_start()
- response = self.trainer.call_hook("on_batch_start")
- if response == -1:
- return AttributeDict(signal=-1)
-
- self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()
-
- # give the PL module a result for logging
- model = self.trainer.lightning_module
- # manually capture logged metrics
- model._current_fx_name = "training_step"
- step_kwargs = self._build_kwargs(dataloader_iter, batch_idx)
-
- with self.trainer.profiler.profile("model_forward"):
- with self.trainer.profiler.profile("training_step"):
- training_step_output = self.trainer.accelerator.training_step(step_kwargs)
- self.trainer.accelerator.post_training_step()
-
- training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
- _check_training_step_output(self.trainer.lightning_module, training_step_output)
-
- training_step_output, _ = _process_training_step_output(self.trainer, training_step_output)
-
- if self.trainer.terminate_on_nan:
- check_finite_loss(self.trainer.lightning_module, training_step_output.minimize)
-
- batch_outputs = [[] for _ in range(len(self.trainer.optimizers))]
- if training_step_output:
- batch_outputs[0].append(training_step_output)
-
- return AttributeDict(signal=0, training_step_output=batch_outputs, is_last=is_last)
-
- def teardown(self) -> None:
- """
- No-op. Only defined to comply with FitLoop's expectation.
- """
- pass
-
- # FIXME: To be deleted in next PR.
- def _build_kwargs(self, dataloader_iter: Iterator, batch_idx: int) -> Dict[str, Any]:
- """Builds the keyword arguments for training_step
-
- Args:
- dataloader_iter: The dataloader to pass
- batch_idx: the index of the current batch
-
- Returns:
- An ordered dict with the keyword arguments for the training step
- """
- # enable not needing to add opt_idx to training_step
- step_kwargs = OrderedDict([("dataloader_iter", dataloader_iter)])
-
- training_step_fx = getattr(self.trainer.lightning_module, "training_step")
- if is_param_in_hook_signature(training_step_fx, "batch_idx"):
- step_kwargs["batch_idx"] = batch_idx
-
- return step_kwargs
diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py
index 1f0636d8b6cda..dd69640106af8 100644
--- a/pytorch_lightning/loops/utilities.py
+++ b/pytorch_lightning/loops/utilities.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Mapping, Optional, Tuple
+from typing import Any, Iterator, Mapping, Optional, Tuple
import torch
@@ -20,6 +20,7 @@
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -64,7 +65,7 @@ def _check_training_step_output(model: "pl.LightningModule", training_step_outpu
def _process_training_step_output(
trainer: "pl.Trainer", training_step_output: STEP_OUTPUT
-) -> Tuple[Optional[ResultCollection], Optional[torch.Tensor]]:
+) -> Tuple[Optional[ResultCollection], Optional[Any]]:
"""Adds the :param:`training_step_output` to the trainer's results
Args:
@@ -102,3 +103,13 @@ def _process_training_step_output(
if trainer.move_metrics_to_cpu:
results.cpu()
return results, hiddens
+
+
+def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator:
+ """Attach the dataloader"""
+ if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
+ # restore iteration
+ dataloader_iter = enumerate(data_fetcher, batch_idx)
+ else:
+ dataloader_iter = iter(data_fetcher)
+ return dataloader_iter
diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py
index a69065fa74f73..13f8c7404baf4 100644
--- a/pytorch_lightning/plugins/__init__.py
+++ b/pytorch_lightning/plugins/__init__.py
@@ -1,4 +1,7 @@
-from pytorch_lightning.plugins.base_plugin import Plugin
+from pathlib import Path
+from typing import Union
+
+from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.plugins.plugins_registry import ( # noqa: F401
@@ -30,6 +33,9 @@
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
+PLUGIN = Union[TrainingTypePlugin, PrecisionPlugin, ClusterEnvironment, CheckpointIO]
+PLUGIN_INPUT = Union[PLUGIN, str]
+
__all__ = [
"CheckpointIO",
"TorchCheckpointIO",
@@ -55,13 +61,10 @@
"TPUSpawnPlugin",
"TrainingTypePlugin",
"ParallelPlugin",
- "Plugin",
"DDPShardedPlugin",
"DDPSpawnShardedPlugin",
]
-from pathlib import Path
-
FILE_ROOT = Path(__file__).parent
TRAINING_TYPE_BASE_MODULE = "pytorch_lightning.plugins.training_type"
diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py
deleted file mode 100644
index 515fc29d0e355..0000000000000
--- a/pytorch_lightning/plugins/base_plugin.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import contextlib
-from abc import ABC
-from typing import Generator
-
-import pytorch_lightning as pl
-
-
-class Plugin(ABC):
- """Basic class for all precision- and training type plugins."""
-
- def pre_dispatch(self) -> None:
- """Hook to do something before the training/evaluation/prediction starts."""
-
- def dispatch(self, trainer: "pl.Trainer") -> None:
- """Hook to do something at trainer run_stage starts."""
-
- def post_dispatch(self) -> None:
- """Hook to do something after the training/evaluation/prediction finishes."""
-
- @contextlib.contextmanager
- def train_step_context(self) -> Generator:
- """A contextmanager for the trainstep"""
- yield
-
- @contextlib.contextmanager
- def val_step_context(self) -> Generator:
- """A contextmanager for the validation step"""
- yield
-
- @contextlib.contextmanager
- def test_step_context(self) -> Generator:
- """A contextmanager for the teststep"""
- yield
-
- @contextlib.contextmanager
- def predict_step_context(self) -> Generator:
- """A contextmanager for the predict step"""
- yield
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index ae9f261085229..c32f926236f56 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -19,7 +19,12 @@
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _TORCH_GREATER_EQUAL_1_10, AMPType
+from pytorch_lightning.utilities import (
+ _NATIVE_AMP_AVAILABLE,
+ _TORCH_BFLOAT_AVAILABLE,
+ _TORCH_CPU_AMP_AVAILABLE,
+ AMPType,
+)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -31,7 +36,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
precision: Whether to use torch.float16 (16) or torch.bfloat16 (bf16).
"""
- def __init__(self, precision: Union[int, str] = 16) -> None:
+ def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> None:
super().__init__()
if not _NATIVE_AMP_AVAILABLE:
@@ -39,6 +44,14 @@ def __init__(self, precision: Union[int, str] = 16) -> None:
"You have asked for native AMP but your PyTorch version does not support it."
" Consider upgrading with `pip install torch>=1.6`."
)
+
+ if use_cpu and not _TORCH_CPU_AMP_AVAILABLE:
+ raise MisconfigurationException(
+ "You have asked for native AMP on CPU, but AMP is only available on GPU for PyTorch 1.9 "
+ "and lower. To use native AMP on CPU, install PyTorch 1.10 or later."
+ )
+
+ self.use_cpu = use_cpu
self._fast_dtype = self._select_precision_dtype(precision)
self.backend = AMPType.NATIVE
if not self.is_bfloat16:
@@ -46,11 +59,15 @@ def __init__(self, precision: Union[int, str] = 16) -> None:
def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
if precision == "bf16":
- if not _TORCH_GREATER_EQUAL_1_10:
+ if not _TORCH_BFLOAT_AVAILABLE:
raise MisconfigurationException(
"To use bfloat16 with native amp you must install torch greater or equal to 1.10."
)
return torch.bfloat16
+ elif self.use_cpu:
+ raise MisconfigurationException(
+ "CPU native amp only supports bfloat16. Please pass precision='bf16' to the Trainer."
+ )
return torch.float16
@property
@@ -91,6 +108,8 @@ def pre_optimizer_step(
return False
def autocast_context_manager(self) -> torch.cuda.amp.autocast:
+ if self.use_cpu:
+ return torch.cpu.amp.autocast(fast_dtype=self._fast_dtype)
if self.is_bfloat16:
return torch.cuda.amp.autocast(fast_dtype=self._fast_dtype)
return torch.cuda.amp.autocast()
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index 1261fea87c06e..86486bfc37cd9 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Callable, List, Optional, Tuple, Union
+import contextlib
+from typing import Any, Callable, Generator, List, Optional, Tuple, Union
import torch
from torch import Tensor
@@ -20,12 +21,11 @@
import pytorch_lightning as pl
from pytorch_lightning.core.hooks import CheckpointHooks
-from pytorch_lightning.plugins.base_plugin import Plugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.types import _PARAMETERS
-class PrecisionPlugin(Plugin, CheckpointHooks):
+class PrecisionPlugin(CheckpointHooks):
"""
Base class for all plugins handling the precision-specific parts of the training.
The class attribute precision must be overwritten in child classes.
@@ -136,3 +136,32 @@ def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -
"""Clip gradients by norm"""
parameters = self.master_params(optimizer)
torch.nn.utils.clip_grad_norm_(parameters, clip_val)
+
+ def pre_dispatch(self) -> None:
+ """Hook to do something before the training/evaluation/prediction starts."""
+
+ def dispatch(self, trainer: "pl.Trainer") -> None:
+ """Hook to do something when ``Accelerator.dispatch()`` gets called."""
+
+ def post_dispatch(self) -> None:
+ """Hook to do something after the training/evaluation/prediction finishes."""
+
+ @contextlib.contextmanager
+ def train_step_context(self) -> Generator:
+ """A contextmanager for the training step"""
+ yield
+
+ @contextlib.contextmanager
+ def val_step_context(self) -> Generator:
+ """A contextmanager for the validation step"""
+ yield
+
+ @contextlib.contextmanager
+ def test_step_context(self) -> Generator:
+ """A contextmanager for the test step"""
+ yield
+
+ @contextlib.contextmanager
+ def predict_step_context(self) -> Generator:
+ """A contextmanager for the predict step"""
+ yield
diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py
index 861e5e1363dd2..a1eb23e478132 100644
--- a/pytorch_lightning/plugins/precision/sharded_native_amp.py
+++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -24,9 +24,10 @@
class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
"""Mixed Precision for Sharded Training"""
- def __init__(self, precision: Union[int, str] = 16) -> None:
- super().__init__(precision)
- self.scaler = ShardedGradScaler()
+ def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> None:
+ super().__init__(precision, use_cpu=use_cpu)
+ if not self.use_cpu:
+ self.scaler = ShardedGradScaler()
def clip_grad_by_norm(
self, optimizer: "OSS", clip_val: Union[int, float], norm_type: float = 2.0, eps: float = 1e-6
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 787353be307e6..6d96a443e391a 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -29,17 +29,21 @@
import torch.distributed
from torch.nn.parallel.distributed import DistributedDataParallel
+from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
+from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import (
+ _FAIRSCALE_AVAILABLE,
_HYDRA_AVAILABLE,
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
+ _TORCH_GREATER_EQUAL_1_10,
rank_zero_deprecation,
rank_zero_warn,
)
@@ -53,11 +57,19 @@
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed
+if _TORCH_GREATER_EQUAL_1_10:
+ from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
+
+if _FAIRSCALE_AVAILABLE:
+ from fairscale.optim import OSS
if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path
if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
+if _TORCH_GREATER_EQUAL_1_10:
+ import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
+ import torch.distributed.algorithms.model_averaging.averagers as averagers
log = logging.getLogger(__name__)
@@ -83,6 +95,7 @@ def __init__(
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
+ model_averaging_period: Optional[int] = None,
**kwargs: Union[Any, Dict[str, Any]],
) -> None:
super().__init__(
@@ -110,6 +123,7 @@ def __init__(
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
+ self._model_averaging_period = model_averaging_period
self._pids: Optional[List[int]] = None
self._sync_dir: Optional[str] = None
self.set_world_ranks()
@@ -302,6 +316,51 @@ def _register_ddp_hooks(self) -> None:
ddp_comm_wrapper=self._ddp_comm_wrapper,
)
+ # Post-localSDG is only available after 1.9,
+ # and `torch.distributed.optim` package currently is not available on Windows.
+ if (
+ _TORCH_GREATER_EQUAL_1_10
+ and isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState)
+ and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
+ ):
+ self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
+
+ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
+ optimizers = self.lightning_module.trainer.optimizers
+ if self._model_averaging_period is None:
+ raise ValueError(
+ "Post-localSGD algorithm is used, " "but model averaging period is not provided to DDP plugin."
+ )
+ averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
+ for x, optimizer in enumerate(optimizers):
+ if isinstance(optimizer, LightningOptimizer):
+ optimizer = optimizer._optimizer
+
+ if (
+ isinstance(optimizer, DistributedOptimizer)
+ or isinstance(optimizer, ZeroRedundancyOptimizer)
+ or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
+ ):
+ raise ValueError(
+ f"Cannot wrap a distributed optimizer of type {optimizer.__name__} by PostLocalSGDOptimizer."
+ )
+
+ if isinstance(optimizer, PostLocalSGDOptimizer):
+ continue
+
+ optim_class = type(optimizer)
+ post_localSGD_optimizer = PostLocalSGDOptimizer(
+ params=optimizer.param_groups,
+ optimizer_class=optim_class,
+ averager=averager,
+ **optimizer.defaults,
+ )
+ optimizers[x] = post_localSGD_optimizer
+ del optimizer
+ trainer = self.lightning_module.trainer
+ trainer.optimizers = optimizers
+ trainer.convert_to_lightning_optimizers()
+
def configure_ddp(self):
self.pre_configure_ddp()
self._model = DistributedDataParallel(
@@ -442,3 +501,13 @@ def reconciliate_processes(self, trace: str):
os.kill(pid, signal.SIGKILL)
shutil.rmtree(sync_dir)
raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
+
+ def teardown(self) -> None:
+ if isinstance(self.model, DistributedDataParallel):
+ self.model = self.lightning_module
+
+ if self.on_gpu:
+ # GPU teardown
+ self.lightning_module.cpu()
+ # clean up memory
+ torch.cuda.empty_cache()
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 08c049997bdfd..c31a908902a27 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -364,3 +364,13 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
description="DDPSpawn Plugin with `find_unused_parameters` as False",
find_unused_parameters=False,
)
+
+ def teardown(self) -> None:
+ if isinstance(self.model, DistributedDataParallel):
+ self.model = self.lightning_module
+
+ if self.on_gpu:
+ # GPU teardown
+ self.lightning_module.cpu()
+ # clean up memory
+ torch.cuda.empty_cache()
diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py
index 551324416cce9..5b0887c848322 100644
--- a/pytorch_lightning/plugins/training_type/dp.py
+++ b/pytorch_lightning/plugins/training_type/dp.py
@@ -119,3 +119,10 @@ def test_step_end(self, output):
if not is_overridden("test_step_end", self.lightning_module):
return self.reduce(output)
return output
+
+ def teardown(self) -> None:
+ if self.on_gpu:
+ # GPU teardown
+ self.lightning_module.cpu()
+ # clean up memory
+ torch.cuda.empty_cache()
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index e5eb8bf9723ea..19694e1bcda11 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -206,3 +206,10 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.Distributed
def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]:
opt_params = {p for group in optimizer.param_groups for p in group.get("params", [])}
return [(name, p) for name, p in model.named_parameters() if p in opt_params]
+
+ def teardown(self) -> None:
+ if self.on_gpu:
+ # GPU teardown
+ self.lightning_module.cpu()
+ # clean up memory
+ torch.cuda.empty_cache()
diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py
index 71aae1bb71a91..31d2deb5f65e6 100644
--- a/pytorch_lightning/plugins/training_type/parallel.py
+++ b/pytorch_lightning/plugins/training_type/parallel.py
@@ -133,15 +133,3 @@ def block_backward_sync(self):
yield None
else:
yield None
-
- def teardown(self) -> None:
- # Un-reference the wrapper if any was used.
- # todo (tchaton): Add support for all plugins.
- if isinstance(self.model, DistributedDataParallel):
- self.model = self.lightning_module
-
- if self.on_gpu:
- # GPU teardown
- self.lightning_module.cpu()
- # clean up memory
- torch.cuda.empty_cache()
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 82a33290aafbf..6ee1ce77c8c24 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -25,14 +25,13 @@
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins import TorchCheckpointIO
-from pytorch_lightning.plugins.base_plugin import Plugin
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
TBroadcast = TypeVar("T")
-class TrainingTypePlugin(Plugin, ABC):
+class TrainingTypePlugin(ABC):
"""
Base class for all training type plugins that change the behaviour of the training, validation and test-loop.
"""
@@ -352,3 +351,12 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int)
Called in the training loop before anything happens for that batch.
"""
pass
+
+ def pre_dispatch(self) -> None:
+ """Hook to do something before the training/evaluation/prediction starts."""
+
+ def dispatch(self, trainer: "pl.Trainer") -> None:
+ """Hook to do something at trainer run_stage starts."""
+
+ def post_dispatch(self) -> None:
+ """Hook to do something after the training/evaluation/prediction finishes."""
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 2261538d7cbcc..bbfcbb22802a8 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from abc import ABC
from copy import deepcopy
from typing import Any, Dict, List, Optional, Type, Union
import torch
+from packaging.version import Version
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
@@ -242,7 +242,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
for callback in self.callbacks:
state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
if state:
- callback_states[callback.state_id] = state
+ callback_states[callback.state_key] = state
return callback_states
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
@@ -255,19 +255,19 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if callback_states is None:
return
- current_callbacks_type = {type(cb) for cb in self.callbacks}
- saved_callbacks_type = set(callback_states.keys())
- difference = saved_callbacks_type.difference(current_callbacks_type)
+ is_legacy_ckpt = Version(checkpoint["pytorch-lightning_version"]) < Version("1.5.0dev")
+ current_callbacks_keys = {cb._legacy_state_key if is_legacy_ckpt else cb.state_key for cb in self.callbacks}
+ difference = callback_states.keys() - current_callbacks_keys
if difference:
rank_zero_warn(
- "Be aware that when using ``resume_from_checkpoint``, "
- "callbacks used to create the checkpoint need to be provided. "
- f"Please, add the following callbacks: {list(difference)}. ",
+ "Be aware that when using `resume_from_checkpoint`,"
+ " callbacks used to create the checkpoint need to be provided."
+ f" Please add the following callbacks: {list(difference)}.",
UserWarning,
)
for callback in self.callbacks:
- state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id))
+ state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
if state:
state = deepcopy(state)
callback.on_load_checkpoint(self, self.lightning_module, state)
diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py
index 07548f9c49074..d9c341c5dfaeb 100644
--- a/pytorch_lightning/trainer/configuration_validator.py
+++ b/pytorch_lightning/trainer/configuration_validator.py
@@ -16,6 +16,7 @@
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
class ConfigValidator:
@@ -34,6 +35,7 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None:
self.__verify_train_loop_configuration(model)
self.__verify_eval_loop_configuration(model, "val")
self.__verify_manual_optimization_support(model)
+ self.__check_training_step_requires_dataloader_iter(model)
elif self.trainer.state.fn == TrainerFn.VALIDATING:
self.__verify_eval_loop_configuration(model, "val")
elif self.trainer.state.fn == TrainerFn.TESTING:
@@ -128,3 +130,26 @@ def __verify_manual_optimization_support(self, model: "pl.LightningModule") -> N
f" Remove `Trainer(accumulate_grad_batches={self.trainer.accumulate_grad_batches})`"
" or switch to automatic optimization."
)
+
+ def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningModule"):
+ """Check if the current `training_step` is requesting `dataloader_iter`."""
+ training_step_fx = getattr(model, "training_step")
+ if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
+
+ if is_overridden("on_train_batch_start", model):
+ raise MisconfigurationException(
+ "The model hook `on_train_batch_start` is not compatible with "
+ "taking a `dataloader_iter` argument in your `training_step`."
+ )
+
+ if is_overridden("on_train_batch_end", model):
+ raise MisconfigurationException(
+ "The model hook `on_train_batch_end` is not compatible with "
+ "taking a `dataloader_iter` argument in your `training_step`."
+ )
+
+ if model.truncated_bptt_steps > 0:
+ raise MisconfigurationException(
+ "The model taking a `dataloader_iter` argument in your `training_step` "
+ "is incompatible with `truncated_bptt_steps > 0`."
+ )
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index b0aabddf2d1cb..85a2c7ee87d13 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -73,6 +73,7 @@
rank_zero_info,
rank_zero_warn,
)
+from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if _HOROVOD_AVAILABLE:
@@ -148,6 +149,7 @@ def __init__(
self.plugins = plugins
self._validate_accelerator_and_devices()
+
self._warn_if_devices_flag_ignored()
self.select_accelerator_type()
@@ -565,10 +567,6 @@ def select_precision_plugin(self) -> PrecisionPlugin:
return TPUHalfPrecisionPlugin()
if self.amp_type == AMPType.NATIVE:
- if self.use_cpu:
- raise MisconfigurationException(
- "You have asked for native AMP on CPU, but AMP is only available on GPU."
- )
if not _NATIVE_AMP_AVAILABLE:
msg = (
"You have asked for native AMP but your PyTorch version does not support it."
@@ -583,10 +581,10 @@ def select_precision_plugin(self) -> PrecisionPlugin:
else:
log.info(f"Using native {self.precision} bit Automatic Mixed Precision")
if self._is_sharded_training_type:
- return ShardedNativeMixedPrecisionPlugin(self.precision)
+ return ShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
if self._is_fully_sharded_training_type:
- return FullyShardedNativeMixedPrecisionPlugin(self.precision)
- return NativeMixedPrecisionPlugin(self.precision)
+ return FullyShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
+ return NativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
if self.amp_type == AMPType.APEX:
if not _APEX_AVAILABLE:
@@ -601,7 +599,9 @@ def select_precision_plugin(self) -> PrecisionPlugin:
log.info("Using APEX 16bit precision.")
return ApexMixedPrecisionPlugin(self.amp_level)
- raise NotImplementedError("We only support precisions 64, 32 and 16!")
+ raise MisconfigurationException(
+ f"Precision {self.precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
+ )
def select_training_type_plugin(self) -> TrainingTypePlugin:
if (
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index 8d337b972dce2..71a30460af791 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -91,21 +91,18 @@ def on_trainer_init(
self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
self.trainer._is_data_prepared = False
- def _check_training_step_requires_dataloader_iter(self) -> bool:
- training_step_fx = getattr(self.trainer.lightning_module, "training_step")
- contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True)
- return contains_dataloader_iter
-
def _select_data_fetcher(self) -> AbstractDataFetcher:
if self.trainer.sanity_checking:
return DataFetcher()
- if self.trainer.training and self._check_training_step_requires_dataloader_iter():
+ training_step_fx = getattr(self.trainer.lightning_module, "training_step")
+ if self.trainer.training and is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
rank_zero_warn(
"Found `dataloader_iter` argument in the `training_step`. Note that the support for "
"this signature is experimental and the behavior is subject to change."
)
return DataLoaderIterDataFetcher()
+
elif self.trainer.training and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
# note: this is an experimental feature
if not self.trainer.training_type_plugin.on_gpu:
@@ -124,9 +121,7 @@ def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0)
profiler=self.trainer.profiler,
)
setattr(self, f"{stage}_data_fetcher", data_fetcher)
- if isinstance(data_fetcher, DataLoaderIterDataFetcher):
- return data_fetcher
- return enumerate(data_fetcher)
+ return data_fetcher
def prepare_data(self) -> None:
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
@@ -250,6 +245,16 @@ def detach_data(model: "pl.LightningModule") -> None:
if isinstance(loader, _PatchDataLoader):
loader.unpatch(model)
+ def teardown(self) -> None:
+ if self.train_data_fetcher:
+ self.train_data_fetcher.teardown()
+ if self.validate_data_fetcher:
+ self.validate_data_fetcher.teardown()
+ if self.test_data_fetcher:
+ self.test_data_fetcher.teardown()
+ if self.sanity_check_data_fetcher:
+ self.sanity_check_data_fetcher.teardown()
+
class _PatchDataLoader:
r"""
@@ -267,7 +272,7 @@ def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], stage
# exists which shows dataloader methods have been overwritten.
# so, we hack it by using the string representation
self.patch_loader_code = str(self.__call__.__code__)
- self.old_loader: Optional[Callable] = None
+ self._old_loader: Optional[Callable] = None
self.stage = stage
def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 672695885ebf9..a965699510689 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -199,7 +199,11 @@ def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT:
"""
def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
- self.trainer._results.extract_batch_size(split_batch)
+ # when the user request `dataloader_iter`, we can't track the batch_size
+ # and this is left to user responsibility.
+ if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher):
+ self.trainer._results.extract_batch_size(split_batch)
+
self._batch_idx = batch_idx
self._split_idx = split_idx
diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py
index 733199c93267c..285ed5afbf62b 100644
--- a/pytorch_lightning/trainer/connectors/training_trick_connector.py
+++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List, Union
+from typing import Dict, Union
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.utilities import GradClipAlgorithmType
@@ -27,7 +27,7 @@ def on_trainer_init(
gradient_clip_val: float,
gradient_clip_algorithm: str,
track_grad_norm: Union[int, float, str],
- accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
+ accumulate_grad_batches: Union[int, Dict[int, int]],
terminate_on_nan: bool,
):
@@ -48,7 +48,7 @@ def on_trainer_init(
self.trainer.accumulate_grad_batches = accumulate_grad_batches
self.configure_accumulated_gradients(accumulate_grad_batches)
- def configure_accumulated_gradients(self, accumulate_grad_batches):
+ def configure_accumulated_gradients(self, accumulate_grad_batches: Union[int, Dict[int, int]]) -> None:
if isinstance(accumulate_grad_batches, dict):
self.trainer.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
elif isinstance(accumulate_grad_batches, int):
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py
index 0e747a9e4857d..4c0ddddd2c234 100644
--- a/pytorch_lightning/trainer/supporters.py
+++ b/pytorch_lightning/trainer/supporters.py
@@ -13,20 +13,23 @@
# limitations under the License.
from collections.abc import Iterable, Iterator, Mapping, Sequence
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from torch.utils.data import Dataset
-from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
+from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader
from torch.utils.data.dataset import IterableDataset
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.auto_restart import (
- _cycle_to_next_worker_and_reset,
- _find_current_worker,
+ _find_fast_forward_samplers,
CaptureIterableDataset,
+ CaptureMapDataset,
+ IteratorState,
+ MergedIteratorState,
+ patch_dataloader_iterator,
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -167,6 +170,7 @@ def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycle
self.loader = loader
self._loader_iter = None
self.counter = 0
+ self.state = state
def __iter__(self) -> Any:
"""
@@ -176,6 +180,7 @@ def __iter__(self) -> Any:
CycleIterator: self
"""
self.counter = 0
+ self.state.reset()
self._loader_iter = iter(self.loader)
return self
@@ -205,6 +210,12 @@ def __next__(self) -> Any:
raise StopIteration
self._loader_iter = iter(self.loader)
+ # if fault tolerant is enabled, we need to patch the iterator to collect the states
+ # before the batch gets returned.
+ fetcher = getattr(self.loader, "_lightning_fetcher", None)
+ if fetcher:
+ patch_dataloader_iterator(self.loader, self._loader_iter, fetcher)
+
return next(self._loader_iter)
finally:
@@ -302,11 +313,6 @@ def __len__(self) -> int:
return self._calc_num_data(self.datasets, self.mode)
-class DataLoaderDict(Dict):
- # behaves exactly like a dict, this is used to simplify apply_to_collection.
- pass
-
-
class CombinedLoader:
"""
Combines different dataloaders and allows sampling in parallel.
@@ -360,80 +366,110 @@ def __init__(self, loaders: Any, mode: str = "min_size"):
self._iterator = None # assigned in __iter__
@staticmethod
- def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], num_batches_processed: int) -> Dict:
- # find next worker if multiple workers were used
- state = _find_current_worker(iterator)
- if isinstance(dataloader.dataset, CaptureIterableDataset):
- # the sampler state dict are extracted in `CombinedLoaderIterator`
- if iterator is not None and getattr(iterator, "_sampler_state_dict", None) is not None:
- state.update(iterator._sampler_state_dict[0])
- else:
- # fetch directly from fast forward sampler
- state.update(dataloader.fast_forward_sampler.state_dict(num_batches_processed))
- return DataLoaderDict(state)
-
- def state_dict(self, num_batches_processed: int) -> Dict:
+ def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], has_completed: int) -> Dict:
+ if isinstance(dataloader, CycleIterator):
+ iterator = dataloader._loader_iter
+ state = getattr(iterator, "state", None) if has_completed else getattr(iterator, "previous_state", None)
+ if state:
+ return asdict(state)
+ return {}
+
+ def state_dict(self, has_completed: bool = False) -> Dict:
"""
The state dict includes all states from wrapped dataloaders and their samplers through the
``CaptureIterableDataset`` and fast-forward samplers.
Args:
- num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
- may have already prefetched more batches by the time a state dict is requested.
+ has_completed: whether the current state of data fetching is considered completed or not. If it is, the
+ current state gets returned, otherwise the previously cached state.
"""
- if not _fault_tolerant_training():
- return DataLoaderDict()
-
- state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)
+ if not _fault_tolerant_training() or self._iterator is None:
+ return {}
- return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn)
+ return apply_to_collections(
+ self.loaders,
+ self._iterator.loader_iters,
+ (Iterator, DataLoader),
+ partial(self._state_dict_fn, has_completed=has_completed),
+ )
- def load_state_dict(self, state_dict):
+ def load_state_dict(self, state_dict) -> None:
# store the samplers state.
# They would be reloaded once the `CombinedIterator` as been created
# and the workers are created.
self._loaders_iter_state_dict = state_dict
- def mock_reset_fn(self, *_, **__):
- pass
-
- # mock reset call, so we can rotate the `_worker_queue_idx_cycle` to failed worker
- # and get the first batch from it
- _MultiProcessingDataLoaderIter._original_reset = _MultiProcessingDataLoaderIter._reset
- _MultiProcessingDataLoaderIter._reset = mock_reset_fn
-
- def on_restart(self, iterator: Iterator):
+ def on_restart(self, iterator: Iterator) -> None:
if not self._loaders_iter_state_dict:
return
- # this happen inside the workers if any were specificied.
+ def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
+ """Function used to reload the iterator state before once the workers are created."""
+
+ dataloader_to_iter_on = dataloader
+ if isinstance(dataloader, CycleIterator):
+ dataloader = dataloader_to_iter_on.loader
+
+ dataset = dataloader.dataset
+
+ # We reload the states before creating the workers
+ # The specific type of dataset will then decide if the state should be applied before or after
+ # spawning the workers
+ if isinstance(dataset, CaptureMapDataset):
+ iterator_state = state_dict["state"][0]
+
+ if not isinstance(iterator_state, IteratorState):
+ iterator_state = IteratorState.from_state_dict(iterator_state)
+
+ # reload sampler state
+ ff_sampler = _find_fast_forward_samplers(dataloader)
+ ff_sampler.load_state_dict(iterator_state.sampler_state)
+ # reload dataset state
+ dataset.load_state_dict(
+ iterator_state.dataset_state,
+ latest_worker_id=state_dict["latest_worker_id"],
+ num_workers=iterator_state.num_workers,
+ )
+
+ elif isinstance(dataset, CaptureIterableDataset):
+ dataset_dict = {
+ sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()
+ }
+ dataset.load_state_dict(dataset_dict)
- def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict):
- if isinstance(dataloader.dataset, CaptureIterableDataset):
- # provide the `state_dict` to the `CaptureIterableDataset`
- # as it is responsible for passing down the state to associated `FastForwardSampler`
- dataloader.dataset.load_state_dict(state_dict)
else:
- # for `Mapping-based` dataset, the `fast_forward_sampler` was attached
- # on the dataloader for simplicity
- dataloader.fast_forward_sampler.load_state_dict(state_dict)
+ raise MisconfigurationException(
+ "This shouldn't happen. Please, open an issue on PyTorch Lightning Github."
+ )
+
+ # We finally spawned the workers if any.
+ it = iter(dataloader_to_iter_on)
- # cycle back the iterator to the failed worker if multiple workers were provided
- iterator = _cycle_to_next_worker_and_reset(dataloader, state_dict)
+ # restore caching state
+ state = MergedIteratorState.from_state_dict(state_dict)
- if isinstance(dataloader.dataset, CaptureIterableDataset):
- # remove keys related to iterator
- state_dict = {k: v for k, v in state_dict.items() if k not in ("num_worker", "previous_worker")}
- # need to re-attach the state dict into the iterator for future collection.
- iterator._sampler_state_dict = [state_dict]
- return iterator
+ if isinstance(dataloader_to_iter_on, CycleIterator):
+ it._loader_iter.state = state
+ else:
+ it.state = state
+ return it
+
+ # create an un-existing token, so it doesn't activate for something else than an iterator.
+ class DataLoaderDict(dict):
+ pass
# apply the `create_loader_iters` on the collection of `DataLoader / Iterator`.
# each `Iterator` was created from the `DataLoader`.
iterator._loader_iters = apply_to_collections(
- self.loaders, self._loaders_iter_state_dict, (DataLoader, DataLoaderDict), create_loader_iters
+ self.loaders,
+ self._loaders_iter_state_dict,
+ (Iterable, DataLoaderDict),
+ create_loader_iters,
+ wrong_dtype=(Sequence, Mapping),
)
+ self._loaders_iter_state_dict = None
+
@property
def sampler(self) -> Union[Iterable, Sequence, Mapping]:
"""Return a collections of samplers extracting from loaders."""
@@ -457,7 +493,6 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
self.loaders = apply_to_collection(
self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping)
)
-
state.reset()
def __iter__(self) -> Any:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 6b39bb5159086..f56a84b1d294d 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -28,12 +28,11 @@
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.loops import IteratorBatchProcessor, TrainingBatchLoop, TrainingEpochLoop
+from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
from pytorch_lightning.loops.fit_loop import FitLoop
-from pytorch_lightning.plugins import DDPSpawnPlugin, Plugin
-from pytorch_lightning.plugins.environments import ClusterEnvironment
+from pytorch_lightning.plugins import DDPSpawnPlugin, PLUGIN_INPUT
from pytorch_lightning.profiler import (
AdvancedProfiler,
BaseProfiler,
@@ -77,12 +76,10 @@
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
log = logging.getLogger(__name__)
@@ -123,7 +120,7 @@ def __init__(
track_grad_norm: Union[int, float, str] = -1,
check_val_every_n_epoch: int = 1,
fast_dev_run: Union[int, bool] = False,
- accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
+ accumulate_grad_batches: Union[int, Dict[int, int]] = 1,
max_epochs: Optional[int] = None,
min_epochs: Optional[int] = None,
max_steps: Optional[int] = None,
@@ -153,7 +150,7 @@ def __init__(
terminate_on_nan: bool = False,
auto_scale_batch_size: Union[str, bool] = False,
prepare_data_per_node: Optional[bool] = None,
- plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
+ plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
amp_backend: str = "native",
amp_level: str = "O2",
distributed_backend: Optional[str] = None,
@@ -378,8 +375,8 @@ def __init__(
self.tuner = Tuner(self)
fit_loop = FitLoop(
- min_epochs=(1 if (min_epochs is None and min_steps is None) else min_epochs),
- max_epochs=(1000 if (max_epochs is None and max_steps is None) else max_epochs),
+ min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs),
+ max_epochs=(1000 if (max_epochs is None and max_steps is None and max_time is None) else max_epochs),
)
training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
training_batch_loop = TrainingBatchLoop()
@@ -920,18 +917,6 @@ def _load_checkpoint_weights(self):
rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}")
self.checkpoint_connector.restore_model_weights(self._ckpt_path)
- def _maybe_switch_to_iterator_batch_processor(self, model: "pl.LightningModule") -> None:
- training_step_fx = getattr(model, "training_step")
- if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
- log.warning(
- "Found `dataloader_iter` argument in the `training_step`. Note that the support for "
- "this signature is experimental and the behavior may subject to change."
- )
- batch_loop = IteratorBatchProcessor(self, model)
- self.fit_loop.epoch_loop.connect(batch_loop)
- # FIXME: Move this logic to data_connector after removing `IteratorBatchProcessor`
- self.data_connector.data_fetcher = DataLoaderIterDataFetcher()
-
def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
# clean hparams
if hasattr(model, "hparams"):
@@ -939,9 +924,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self.config_validator.verify_loop_configurations(model)
- if self.training:
- self._maybe_switch_to_iterator_batch_processor(model)
-
# attach model log function to callback
self.callback_connector.attach_model_logging_functions(model)
@@ -1077,6 +1059,7 @@ def _post_dispatch(self):
# these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
# which need to happen before.
self.accelerator.teardown()
+ self.data_connector.teardown()
self._active_loop.teardown()
self.logger_connector.teardown()
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 747c0be617cbd..ed9bb930001d8 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -44,6 +44,8 @@
_OMEGACONF_AVAILABLE,
_POPTORCH_AVAILABLE,
_RICH_AVAILABLE,
+ _TORCH_BFLOAT_AVAILABLE,
+ _TORCH_CPU_AMP_AVAILABLE,
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index b96a0110e58fa..d7d09251f6087 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -18,7 +18,7 @@
from collections.abc import Mapping, Sequence
from copy import copy
from functools import partial
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -35,19 +35,23 @@
Batch = type(None)
-def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None):
+def to_dtype_tensor(
+ value: Union[int, float, List[Union[int, float]]],
+ dtype: Optional[torch.dtype] = None,
+ device: Union[str, torch.device] = None,
+) -> torch.Tensor:
if device is None:
raise MisconfigurationException("device (torch.device) should be provided.")
return torch.tensor(value, dtype=dtype, device=device)
-def from_numpy(value, device: torch.device = None):
+def from_numpy(value: np.ndarray, device: Union[str, torch.device] = None) -> torch.Tensor:
if device is None:
raise MisconfigurationException("device (torch.device) should be provided.")
return torch.from_numpy(value).to(device)
-CONVERSION_DTYPES = [
+CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any], torch.Tensor]]] = [
# bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
(bool, partial(to_dtype_tensor, dtype=torch.uint8)),
(int, partial(to_dtype_tensor, dtype=torch.int)),
@@ -61,19 +65,19 @@ def _is_namedtuple(obj: object) -> bool:
return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
-def _is_dataclass_instance(obj):
+def _is_dataclass_instance(obj: object) -> bool:
# https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
def apply_to_collection(
data: Any,
- dtype: Union[type, tuple],
+ dtype: Union[type, Any, Tuple[Union[type, Any]]],
function: Callable,
- *args,
- wrong_dtype: Optional[Union[type, tuple]] = None,
+ *args: Any,
+ wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
include_none: bool = True,
- **kwargs
+ **kwargs: Any,
) -> Any:
"""
Recursively applies a function to all elements of a certain dtype.
@@ -121,7 +125,7 @@ def apply_to_collection(
return elem_type(*out) if is_namedtuple else elem_type(out)
if _is_dataclass_instance(data):
- out = {}
+ out_dict = {}
for field in data.__dataclass_fields__:
v = apply_to_collection(
getattr(data, field),
@@ -130,11 +134,11 @@ def apply_to_collection(
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
- **kwargs
+ **kwargs,
)
if include_none or v is not None:
- out[field] = v
- return elem_type(**out)
+ out_dict[field] = v
+ return elem_type(**out_dict)
# data is neither of dtype, nor a collection
return data
@@ -143,11 +147,11 @@ def apply_to_collection(
def apply_to_collections(
data1: Optional[Any],
data2: Optional[Any],
- dtype: Union[type, tuple],
+ dtype: Union[type, Any, Tuple[Union[type, Any]]],
function: Callable,
- *args,
- wrong_dtype: Optional[Union[type, tuple]] = None,
- **kwargs
+ *args: Any,
+ wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
+ **kwargs: Any,
) -> Any:
"""
Zips two collections and applies a function to their items of a certain dtype.
@@ -169,7 +173,9 @@ def apply_to_collections(
AssertionError:
If sequence collections have different data sizes.
"""
- if data1 is None and data2 is not None:
+ if data1 is None:
+ if data2 is None:
+ return
# in case they were passed reversed
data1, data2 = data2, None
@@ -220,14 +226,14 @@ class TransferableDataType(ABC):
"""
@classmethod
- def __subclasshook__(cls, subclass):
+ def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is TransferableDataType:
to = getattr(subclass, "to", None)
return callable(to)
return NotImplemented
-def move_data_to_device(batch: Any, device: torch.device):
+def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
"""
Transfers a collection of data to the given device. Any object that defines a method
``to(device)`` will be moved and all other objects in the collection will be left untouched.
@@ -245,7 +251,7 @@ def move_data_to_device(batch: Any, device: torch.device):
- :class:`torch.device`
"""
- def batch_to(data):
+ def batch_to(data: Any) -> Any:
# try to move torchtext data first
if _TORCHTEXT_AVAILABLE and isinstance(data, Batch):
@@ -269,14 +275,14 @@ def batch_to(data):
return apply_to_collection(batch, dtype=dtype, function=batch_to)
-def convert_to_tensors(data: Any, device: torch.device) -> Any:
+def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any:
if device is None:
raise MisconfigurationException("`torch.device` should be provided.")
for src_dtype, conversion_func in CONVERSION_DTYPES:
data = apply_to_collection(data, src_dtype, conversion_func, device=device)
- def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device) -> torch.Tensor:
+ def _move_to_device_and_make_contiguous(t: torch.Tensor, device: Union[str, torch.device]) -> torch.Tensor:
return t.to(device).contiguous()
data = apply_to_collection(data, torch.Tensor, _move_to_device_and_make_contiguous, device=device)
diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py
index e797719eb54a9..e3c2c3c86dd94 100644
--- a/pytorch_lightning/utilities/argparse_utils.py
+++ b/pytorch_lightning/utilities/argparse_utils.py
@@ -1,6 +1,6 @@
from pytorch_lightning.utilities import rank_zero_deprecation
-rank_zero_deprecation("`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4")
+rank_zero_deprecation("`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v2.0")
# for backward compatibility with old checkpoints (versions < 1.2.0)
# that need to be able to unpickle the function from the checkpoint
diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index c9e378dbadee6..256168ae4382f 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -16,8 +16,12 @@
from copy import deepcopy
from dataclasses import dataclass, field
from functools import partial, wraps
+from random import getstate as python_get_rng_state
+from random import setstate as python_set_rng_state
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union
+import numpy as np
+import torch
from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset
@@ -168,6 +172,16 @@ def update(self, generator_name: Optional[str], new_state: IteratorState) -> Non
state[latest_worker_id] = new_state
self.latest_worker_id = latest_worker_id
+ @property
+ def sampler_states(self) -> Dict[int, Any]:
+ """Returns the merged sampler states for all worker processes."""
+ return {0: self.state[k].sampler_state[0] for k in self.state.keys()}
+
+ @property
+ def dataset_states(self) -> Dict[int, Any]:
+ """Returns the merged dataset states for all worker processes."""
+ return {k: self.state[k].dataset_state[k] for k in self.state.keys()}
+
@classmethod
def from_state_dict(cls, state_dict) -> "MergedIteratorState":
if state_dict["represent_map_dataset"]:
@@ -188,7 +202,12 @@ def __len__(self) -> int:
class CaptureMapDataset(Dataset):
- """This class is used to capture the state from the map-based state dataset."""
+ """This class is used to capture the state from the map-based state dataset.
+
+ Note:
+ We currently don't support restoring if we fail during the first `N = num_workers` batches, where
+ `num_workers` is the number of workers spawned by the dataloader.
+ """
def __init__(self, dataset: Dataset) -> None:
self.dataset = dataset
@@ -202,8 +221,7 @@ def worker_id(self) -> int:
def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
if self._cached_state_dict is not None:
if self.worker_id in self._cached_state_dict:
- # TODO: reset random states
- pass
+ set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"])
self._cached_state_dict = None
data = self.dataset[item]
@@ -227,7 +245,19 @@ def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num
self._cached_state_dict = state_dict
def _state_dict(self) -> Dict[int, Dict[str, Any]]:
- return {self.worker_id: {"rng_states": {}}}
+ return {self.worker_id: {"rng_states": collect_rng_states()}}
+
+
+def collect_rng_states() -> Dict[str, Any]:
+ """Collect the global random state of :mod:`torch`, :mod:`numpy` and Python."""
+ return {"torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": python_get_rng_state()}
+
+
+def set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
+ """Set the global random state of :mod:`torch`, :mod:`numpy` and Python in the current process."""
+ torch.set_rng_state(rng_state_dict.get("torch"))
+ np.random.set_state(rng_state_dict.get("numpy"))
+ python_set_rng_state(rng_state_dict.get("python"))
class CaptureIterableDataset(IterableDataset):
diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py
index ee2b58be106b5..c759f2aee28b7 100644
--- a/pytorch_lightning/utilities/debugging.py
+++ b/pytorch_lightning/utilities/debugging.py
@@ -109,7 +109,7 @@ def track_load_dataloader_call(self, name: str, dataloaders: List[DataLoader]) -
@enabled_only
def track_train_loss_history(self, batch_idx: int, loss: torch.Tensor) -> None:
- loss_dict = {"batch_idx": batch_idx, "epoch": self.trainer.current_epoch, "loss": loss.detach()}
+ loss_dict = {"batch_idx": batch_idx, "epoch": self.trainer.current_epoch, "loss": loss.detach().clone()}
self.saved_train_losses.append(loss_dict)
@enabled_only
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 68868a0ff74cd..4f254b6824489 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -284,12 +284,14 @@ def register_ddp_comm_hook(
.. warning ::
DDP communication wrapper needs pytorch version at least 1.9.0
+ Post-localSGD hook needs pytorch version at least 1.9.0
Example:
from torch.distributed.algorithms.ddp_comm_hooks import (
default_hooks as default,
powerSGD_hook as powerSGD,
+ post_localSGD_hook as post_localSGD,
)
# fp16_compress_hook for compress gradients
@@ -309,6 +311,18 @@ def register_ddp_comm_hook(
ddp_comm_hook=powerSGD.powerSGD_hook,
)
+ # post_localSGD_hook
+ subgroup, _ = torch.distributed.new_subgroups()
+ register_comm_hook(
+ model=ddp_model,
+ state=post_localSGD.PostLocalSGDState(
+ process_group=None,
+ subgroup=subgroup,
+ start_localSGD_iter=1_000,
+ ),
+ ddp_comm_hook=post_localSGD.post_localSGD_hook,
+ )
+
# fp16_compress_wrapper combined with other communication hook
register_ddp_comm_hook(
model=ddp_model,
diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index 977b763299f8a..73fafabe8f5d9 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -49,6 +49,29 @@ class AMPType(LightningEnum):
NATIVE = "native"
+class PrecisionType(LightningEnum):
+ """Type of precision used.
+
+ >>> PrecisionType.HALF == 16
+ True
+ >>> PrecisionType.HALF in (16, "16")
+ True
+ """
+
+ HALF = "16"
+ FLOAT = "32"
+ FULL = "64"
+ BFLOAT = "bf16"
+
+ @staticmethod
+ def supported_type(precision: Union[str, int]) -> bool:
+ return any(x == precision for x in PrecisionType)
+
+ @staticmethod
+ def supported_types() -> List[str]:
+ return [x.value for x in PrecisionType]
+
+
class DistributedType(LightningEnum):
"""Define type of ditributed computing.
diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py
index 72f54a891cde3..d37cd3a9c1e6f 100644
--- a/pytorch_lightning/utilities/fetching.py
+++ b/pytorch_lightning/utilities/fetching.py
@@ -390,7 +390,6 @@ def __iter__(self) -> "StepFuncDataLoaderIter":
def __next__(self) -> Any:
try:
data = next(self.iterator)
- # FIXME: Link this to `batch_idx`.
self.data_fetcher.fetched += 1
return data
except StopIteration:
diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index fa6598f884b19..ff0e5bea6bf19 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -70,7 +70,6 @@ def _compare_version(package: str, op, version) -> bool:
_TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
_TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0")
-
_APEX_AVAILABLE = _module_available("apex.amp")
_BOLTS_AVAILABLE = _module_available("pl_bolts")
_DEEPSPEED_AVAILABLE = _module_available("deepspeed")
@@ -87,8 +86,16 @@ def _compare_version(package: str, op, version) -> bool:
_OMEGACONF_AVAILABLE = _module_available("omegaconf")
_POPTORCH_AVAILABLE = _module_available("poptorch")
_RICH_AVAILABLE = _module_available("rich")
+_TORCH_CPU_AMP_AVAILABLE = _compare_version(
+ "torch", operator.ge, "1.10.0dev20210501"
+) # todo: swap to 1.10.0 once released
+_TORCH_BFLOAT_AVAILABLE = _compare_version(
+ "torch", operator.ge, "1.10.0.dev20210820"
+) # todo: swap to 1.10.0 once released
_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
-_TORCH_SHARDED_TENSOR_AVAILABLE = _compare_version("torch", operator.ge, "1.10.0.dev20210809")
+_TORCH_SHARDED_TENSOR_AVAILABLE = _compare_version(
+ "torch", operator.ge, "1.10.0.dev20210809"
+) # todo: swap to 1.10.0 once released
_TORCHTEXT_AVAILABLE = _module_available("torchtext")
_TORCHVISION_AVAILABLE = _module_available("torchvision")
_TORCHMETRICS_LOWER_THAN_0_3 = _compare_version("torchmetrics", operator.lt, "0.3.0")
diff --git a/pytorch_lightning/utilities/signature_utils.py b/pytorch_lightning/utilities/signature_utils.py
index 5c7e468d84738..05045e98d3af6 100644
--- a/pytorch_lightning/utilities/signature_utils.py
+++ b/pytorch_lightning/utilities/signature_utils.py
@@ -12,15 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
-from typing import Callable
+from typing import Callable, Optional
-def is_param_in_hook_signature(hook_fx: Callable, param: str, explicit: bool = False) -> bool:
+def is_param_in_hook_signature(
+ hook_fx: Callable, param: str, explicit: bool = False, min_args: Optional[int] = None
+) -> bool:
"""
Args:
hook_fx: the hook callable
param: the name of the parameter to check
explicit: whether the parameter has to be explicitly declared
+ min_args: whether the `signature` as at least `min_args` parameters
"""
hook_params = list(inspect.signature(hook_fx).parameters)
- return param in hook_params or (not explicit and "args" in hook_params)
+ return (
+ param in hook_params
+ or (not explicit and "args" in hook_params)
+ or (isinstance(min_args, int) and len(hook_params) >= min_args)
+ )
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index e27b873a63941..f4793a61b87c2 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -629,3 +629,10 @@ def test_accelerator_ddp_for_cpu(tmpdir):
trainer = Trainer(accelerator="ddp", num_processes=2)
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
+
+
+@pytest.mark.parametrize("precision", [1, 12, "invalid"])
+def test_validate_precision_type(tmpdir, precision):
+
+ with pytest.raises(MisconfigurationException, match=f"Precision {precision} is invalid"):
+ Trainer(precision=precision)
diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py
index 7280afea02f76..d695e4c63f43e 100644
--- a/tests/accelerators/test_cpu.py
+++ b/tests/accelerators/test_cpu.py
@@ -1,7 +1,6 @@
import os
from pathlib import Path
from typing import Any, Dict, Union
-from unittest.mock import Mock
import pytest
import torch
@@ -10,22 +9,10 @@
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import SingleDevicePlugin
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
-from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
-def test_unsupported_precision_plugins():
- """Test error messages are raised for unsupported precision plugins with CPU."""
- trainer = Mock()
- accelerator = CPUAccelerator(
- training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin()
- )
- with pytest.raises(MisconfigurationException, match=r"AMP \+ CPU is not supported"):
- accelerator.setup(trainer=trainer)
-
-
@pytest.mark.parametrize("delay_dispatch", [True, False])
def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch):
"""
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index d190feed7e1f7..c363638d565d2 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
+from re import escape
from unittest.mock import call, Mock
+import pytest
+
from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
from tests.helpers import BoringModel
+from tests.helpers.utils import no_warning_call
def test_callbacks_configured_in_model(tmpdir):
@@ -109,7 +114,7 @@ def __init__(self, state):
self.state = state
@property
- def state_id(self):
+ def state_key(self):
return type(self)
def on_save_checkpoint(self, *args):
@@ -120,7 +125,7 @@ def on_load_checkpoint(self, trainer, pl_module, callback_state):
def test_resume_callback_state_saved_by_type(tmpdir):
- """Test that a legacy checkpoint that didn't use a state identifier before can still be loaded."""
+ """Test that a legacy checkpoint that didn't use a state key before can still be loaded."""
model = BoringModel()
callback = OldStatefulCallback(state=111)
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
@@ -132,3 +137,34 @@ def test_resume_callback_state_saved_by_type(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
trainer.fit(model)
assert callback.state == 111
+
+
+def test_resume_incomplete_callbacks_list_warning(tmpdir):
+ model = BoringModel()
+ callback0 = ModelCheckpoint(monitor="epoch")
+ callback1 = ModelCheckpoint(monitor="global_step")
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ max_steps=1,
+ callbacks=[callback0, callback1],
+ )
+ trainer.fit(model)
+ ckpt_path = trainer.checkpoint_callback.best_model_path
+
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ max_steps=1,
+ callbacks=[callback1], # one callback is missing!
+ resume_from_checkpoint=ckpt_path,
+ )
+ with pytest.warns(UserWarning, match=escape(f"Please add the following callbacks: [{repr(callback0.state_key)}]")):
+ trainer.fit(model)
+
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ max_steps=1,
+ callbacks=[callback1, callback0], # all callbacks here, order switched
+ resume_from_checkpoint=ckpt_path,
+ )
+ with no_warning_call(UserWarning, match="Please add the following callbacks:"):
+ trainer.fit(model)
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 4c3b990dd1b13..ad343cdf329f5 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -33,6 +33,11 @@
_logger = logging.getLogger(__name__)
+def test_early_stopping_state_key():
+ early_stopping = EarlyStopping(monitor="val_loss")
+ assert early_stopping.state_key == "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}"
+
+
class EarlyStoppingTestRestore(EarlyStopping):
# this class has to be defined outside the test function, otherwise we get pickle error
def __init__(self, expected_state, *args, **kwargs):
@@ -77,7 +82,8 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
# the checkpoint saves "epoch + 1"
early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
assert 4 == len(early_stop_callback.saved_states)
- assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state
+ es_name = "EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}"
+ assert checkpoint["callbacks"][es_name] == early_stop_callback_state
# ensure state is reloaded properly (assertion in the callback)
early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")
diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py
index 82f64d676c774..88752d56bf697 100644
--- a/tests/callbacks/test_lambda_function.py
+++ b/tests/callbacks/test_lambda_function.py
@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
from functools import partial
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import Callback, LambdaCallback
from tests.helpers.boring_model import BoringModel
+from tests.models.test_hooks import get_members
def test_lambda_call(tmpdir):
@@ -32,7 +32,7 @@ def on_train_epoch_start(self):
def call(hook, *_, **__):
checker.add(hook)
- hooks = {m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)}
+ hooks = get_members(Callback)
hooks_args = {h: partial(call, h) for h in hooks}
hooks_args["on_save_checkpoint"] = lambda *_: [checker.add("on_save_checkpoint")]
diff --git a/tests/callbacks/test_timer.py b/tests/callbacks/test_timer.py
index 44d3c305bb1ac..c7b636d3f843a 100644
--- a/tests/callbacks/test_timer.py
+++ b/tests/callbacks/test_timer.py
@@ -42,6 +42,11 @@ def on_fit_start(self):
trainer.fit(TestModel())
assert "callbacks list already contains a Timer" in caplog.text
+ seconds = 1
+ trainer = Trainer(max_time=dict(seconds=seconds))
+ assert trainer.max_epochs is None
+ assert trainer.max_steps is None
+
@pytest.mark.parametrize(
"duration,expected",
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 314ed899c588a..f49fa16598fd2 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -43,6 +43,15 @@
from tests.helpers.runif import RunIf
+def test_model_checkpoint_state_key():
+ early_stopping = ModelCheckpoint(monitor="val_loss")
+ expected_id = (
+ "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+ " 'train_time_interval': None, 'save_on_train_epoch_end': None}"
+ )
+ assert early_stopping.state_key == expected_id
+
+
class LogInTwoMethods(BoringModel):
def training_step(self, batch, batch_idx):
out = super().training_step(batch, batch_idx)
@@ -148,7 +157,10 @@ def on_validation_epoch_end(self):
assert chk["epoch"] == epoch + 1
assert chk["global_step"] == limit_train_batches * (epoch + 1)
- mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
+ mc_specific_data = chk["callbacks"][
+ f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+ " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
+ ]
assert mc_specific_data["dirpath"] == checkpoint.dirpath
assert mc_specific_data["monitor"] == monitor
assert mc_specific_data["current_score"] == score
@@ -259,7 +271,10 @@ def _make_assertions(epoch, ix, version=""):
expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
assert chk["global_step"] == expected_global_step
- mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
+ mc_specific_data = chk["callbacks"][
+ f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+ " 'train_time_interval': None, 'save_on_train_epoch_end': False}"
+ ]
assert mc_specific_data["dirpath"] == checkpoint.dirpath
assert mc_specific_data["monitor"] == monitor
assert mc_specific_data["current_score"] == score
@@ -857,7 +872,12 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"]
assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"]
- assert ckpt_last["callbacks"]["ModelCheckpoint"] == ckpt_last_epoch["callbacks"]["ModelCheckpoint"]
+
+ ckpt_id = (
+ "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+ " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
+ )
+ assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id]
# it is easier to load the model objects than to iterate over the raw dict of tensors
model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch)
@@ -1095,7 +1115,13 @@ def training_step(self, *args):
trainer.fit(TestModel())
assert model_checkpoint.current_score == 0.3
ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
- ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts]
+ ckpts = [
+ ckpt["callbacks"][
+ "ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+ " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
+ ]
+ for ckpt in ckpts
+ ]
assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]
diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py
index f76e76b2f9dd9..6a8192ef0149e 100644
--- a/tests/checkpointing/test_trainer_checkpoint.py
+++ b/tests/checkpointing/test_trainer_checkpoint.py
@@ -17,7 +17,7 @@
import torch
import pytorch_lightning as pl
-from pytorch_lightning import seed_everything, Trainer
+from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from tests.helpers import BoringModel
@@ -27,8 +27,6 @@ def test_finetuning_with_resume_from_checkpoint(tmpdir):
This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test
"""
- seed_everything(4)
-
checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1)
class ExtendedBoringModel(BoringModel):
@@ -75,9 +73,6 @@ def validation_step(self, batch, batch_idx):
results.append(deepcopy(trainer.callback_metrics))
best_model_paths.append(trainer.checkpoint_callback.best_model_path)
- for idx in range(len(results) - 1):
- assert results[idx]["val_loss"] > results[idx + 1]["val_loss"]
-
for idx, best_model_path in enumerate(best_model_paths):
if idx == 0:
assert best_model_path.endswith(f"epoch=0{idx}.ckpt")
diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py
index f8c3317b6c595..4068bc5504b5b 100644
--- a/tests/core/test_lightning_optimizer.py
+++ b/tests/core/test_lightning_optimizer.py
@@ -20,6 +20,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.core.optimizer import LightningOptimizer
+from pytorch_lightning.loops.closure import Closure
from tests.helpers.boring_model import BoringModel
@@ -221,7 +222,7 @@ def training_epoch_end(self, outputs):
...
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **_):
- assert optimizer_closure.__name__ == "_training_step_and_backward_closure"
+ assert isinstance(optimizer_closure, Closure)
# not passing the closure to the optimizer because step is mocked
# zero_grad is called inside the closure
if isinstance(optimizer, SGD) and batch_idx % 2 == 0:
diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index 7581bf2b0c142..ae8f9e1dcc53d 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Test deprecated functionality which will be removed in v1.7.0 """
+from unittest import mock
import pytest
from pytorch_lightning import LightningDataModule, Trainer
+from pytorch_lightning.loggers import TestTubeLogger
from tests.deprecated_api import _soft_unimport_module
from tests.helpers import BoringModel
from tests.helpers.datamodules import MNISTDataModule
@@ -87,3 +89,9 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
match="Setting `prepare_data_per_node` with the trainer flag is deprecated and will be removed in v1.7.0!"
):
_ = Trainer(prepare_data_per_node=False)
+
+
+@mock.patch("pytorch_lightning.loggers.test_tube.Experiment")
+def test_v1_7_0_test_tube_logger(_, tmpdir):
+ with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"):
+ _ = TestTubeLogger(tmpdir)
diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_2-0.py
similarity index 89%
rename from tests/deprecated_api/test_remove_1-4.py
rename to tests/deprecated_api/test_remove_2-0.py
index a3a4a0b1b9180..9c372c8f1a9c6 100644
--- a/tests/deprecated_api/test_remove_1-4.py
+++ b/tests/deprecated_api/test_remove_2-0.py
@@ -18,7 +18,7 @@
from tests.deprecated_api import _soft_unimport_module
-def test_v1_4_0_deprecated_imports():
+def test_v2_0_0_deprecated_imports():
_soft_unimport_module("pytorch_lightning.utilities.argparse_utils")
- with pytest.deprecated_call(match="will be removed in v1.4"):
+ with pytest.deprecated_call(match="will be removed in v2.0"):
from pytorch_lightning.utilities.argparse_utils import _gpus_arg_default # noqa: F401
diff --git a/tests/loops/test_inter_batch_parallelism.py b/tests/loops/test_inter_batch_parallelism.py
deleted file mode 100644
index 00bc7049b0a29..0000000000000
--- a/tests/loops/test_inter_batch_parallelism.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import time
-from statistics import mean
-from typing import Iterator
-
-import torch
-from torch.utils.data import DataLoader, IterableDataset
-
-from pytorch_lightning import LightningModule, Trainer
-from pytorch_lightning.utilities.types import STEP_OUTPUT
-from tests.helpers.runif import RunIf
-
-
-def count_cycles_per_ms() -> float:
- """
- Measure and return approximate number of cycles per millisecond for torch.cuda._sleep
-
- Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py
- """
-
- def measure() -> float:
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
- start.record()
- torch.cuda._sleep(1000000)
- end.record()
- end.synchronize()
- cycles_per_ms = 1000000 / start.elapsed_time(end)
- return cycles_per_ms
-
- # Get 10 values and remove the 2 max and 2 min and return the avg.
- # This is to avoid system disturbance that skew the results, e.g.
- # the very first cuda call likely does a bunch of init, which takes
- # much longer than subsequent calls.
- #
- # Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs
- # and seems to return stable values. Therefore, we enable caching
- # using lru_cache decorator above.
- num = 10
- vals = []
- for _ in range(num):
- vals.append(measure())
- vals = sorted(vals)
- return mean(vals[2 : num - 2])
-
-
-_CYCLES_PER_MS = int(count_cycles_per_ms()) if torch.cuda.is_available() else 0
-_BATCH_SIZE = 128
-_EMB_SZ = 100
-_EMB_DIM = 64
-
-
-class RandomSparseDataset(IterableDataset):
- def __init__(self, emb_dim: int, batch_size: int, count: int) -> None:
- self.emb_dim = emb_dim
- self.batch_size = batch_size
- self.count = count
-
- def __iter__(self):
- for _ in range(self.count):
- yield torch.randint(self.emb_dim, [self.batch_size])
-
-
-class ToyDLRMModel(LightningModule):
- """
- A toy model for mimicking the communication overhead of sharded embedding
- modules in DLRM models.
-
- DLRM models can be trained in a DDP-like fashion, where each trainer
- receives different batches (embedding indices in this example). Since the
- embeddings are sharded across trainers, the lookup process involves (1)
- routing the indices to the trainer that possesses the corresponding
- embeddings (2) performing local lookup (3) routing the embedding lookup
- result back.
-
- The toy model doesn't actually performs index/result routing. It simply
- uses torch.cuda._sleep() to mimic the cost of the communication op (i.e.
- a2a).
- """
-
- def __init__(self):
- super().__init__()
- self.automatic_optimization = False
- self.local_embedding = torch.nn.Embedding(_EMB_SZ, _EMB_DIM)
-
- def _route_indices(self, batch: torch.Tensor, non_blocking=False):
- """
- This can be parallelized across different batches since it's model
- weight independent.
-
- Why not run this in dataloader/datamodule?
- - The routing logic depends on how model is sharded
- - Putting this in data preprocessor changes the semantic of the model
- """
- torch.cuda._sleep(_CYCLES_PER_MS * 1_000)
- if not non_blocking:
- torch.cuda.synchronize()
- return batch
-
- def _route_result(self, result: torch.Tensor, non_blocking=False):
- torch.cuda._sleep(_CYCLES_PER_MS * 1_000)
- if not non_blocking:
- torch.cuda.synchronize()
- return result
-
- def forward(self, indices: torch.Tensor):
- local_indices = self._route_indices(indices)
- result = self.local_embedding(local_indices)
- return self._route_result(result)
-
- def training_step(self, batch: torch.Tensor, batch_idx: int) -> STEP_OUTPUT:
- return self.forward(batch)
-
- def configure_optimizers(self):
- return torch.optim.SGD(self.local_embedding.parameters(), lr=0.1)
-
- def train_dataloader(self):
- return DataLoader(RandomSparseDataset(_EMB_DIM, _BATCH_SIZE, 5))
-
-
-class AsyncToyDLRMModel(ToyDLRMModel):
- def __init__(self):
- super().__init__()
- self.comm_stream = torch.cuda.Stream()
- self.batch_i = None
- self.batch_i_ready = torch.cuda.Event()
-
- def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
- if self.batch_i is None:
- self.batch_i = next(dataloader_iter)
- with torch.cuda.stream(self.comm_stream):
- self._route_indices(self.batch_i, non_blocking=True)
- self.batch_i_ready.record()
-
- # Invariant: the routing for batch[i] has been kicked off
- is_last = False
- batch_ip1 = None
- batch_ip1_ready = torch.cuda.Event()
- try:
- batch_ip1 = next(dataloader_iter)
- with torch.cuda.stream(self.comm_stream):
- self._route_indices(batch_ip1, non_blocking=True)
- batch_ip1_ready.record()
- except StopIteration:
- is_last = True
-
- self.batch_i_ready.wait()
-
- result = self.local_embedding(self.batch_i)
- self._route_result(result)
-
- self.batch_i = batch_ip1
- self.batch_i_ready = batch_ip1_ready
-
- return {"is_last": is_last}
-
-
-@RunIf(min_gpus=1)
-def test_inter_batch_parallelism(tmpdir):
- """
- Verify the speedup of a simple inter-batch parallelization use case enabled
- by exposing `dataloader_iter` to `training_step`.
- """
- begin_time = time.time()
- m = AsyncToyDLRMModel()
- trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
- trainer.fit(m)
- async_duration = time.time() - begin_time
-
- begin_time = time.time()
- m = ToyDLRMModel()
- trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
- trainer.fit(m)
- sync_duration = time.time() - begin_time
-
- # We expect 2x speedup. However, we only assert that the async
- # training_step is faster in order to avoid flaky tests
- assert async_duration < sync_duration, "Expect `AsyncToyDLRMModel` to train faster than `ToyDLRMModel`."
diff --git a/tests/loops/test_iterator_batch_processor.py b/tests/loops/test_iterator_batch_processor.py
deleted file mode 100644
index 2cd6a172f6941..0000000000000
--- a/tests/loops/test_iterator_batch_processor.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import Any, Iterator
-
-import pytest
-from torch.utils.data import DataLoader
-
-from pytorch_lightning import Trainer
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.types import STEP_OUTPUT
-from tests.helpers import BoringModel, RandomDataset
-
-_BATCH_SIZE = 32
-_DATASET_LEN = 64
-
-
-class DummyWaitable:
- def __init__(self, val: Any) -> None:
- self.val = val
-
- def wait(self) -> Any:
- return self.val
-
-
-class AsyncBoringModel(BoringModel):
- def __init__(self) -> None:
- super().__init__()
- self.automatic_optimization = False
- self.batch_i_handle = None
- self.num_batches_processed = 0
-
- def _async_op(self, batch: Any) -> DummyWaitable:
- return DummyWaitable(val=batch)
-
- def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
- if self.batch_i_handle is None:
- batch_i_raw = next(dataloader_iter)
- self.batch_i_handle = self._async_op(batch_i_raw)
-
- # Invariant: _async_op for batch[i] has been initiated
- batch_ip1_handle = None
- is_last = False
- try:
- batch_ip1_raw = next(dataloader_iter)
- batch_ip1_handle = self._async_op(batch_ip1_raw)
- except StopIteration:
- is_last = True
-
- batch_i = self.batch_i_handle.wait()
-
- pred = self.layer(batch_i)
- loss = self.loss(batch_i, pred)
- loss.backward()
- self.optimizers().step()
- self.optimizers().zero_grad()
-
- self.batch_i_handle = batch_ip1_handle
- self.num_batches_processed += 1
-
- return {"loss": loss, "is_last": is_last}
-
- def train_dataloader(self):
- return DataLoader(RandomDataset(_BATCH_SIZE, _DATASET_LEN))
-
-
-def test_training_step_with_dataloader_access(tmpdir) -> None:
- """
- A baseline functional test for `training_step` with dataloader access.
- """
- trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
- m = AsyncBoringModel()
- trainer.fit(m)
- assert m.num_batches_processed == _DATASET_LEN, f"Expect all {_DATASET_LEN} batches to be processed."
-
-
-def test_stop_iteration(tmpdir) -> None:
- """
- Verify that when `StopIteration` is raised within `training_step`, `fit()`
- terminiates as expected.
- """
- EXPECT_NUM_BATCHES_PROCESSED = 2
-
- class TestModel(AsyncBoringModel):
- def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
- output = super().training_step(dataloader_iter)
- if self.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED:
- raise StopIteration()
- return output
-
- trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
- m = TestModel()
- trainer.fit(m)
- assert (
- m.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED
- ), "Expect {EXPECT_NUM_BATCHES_PROCESSED} batches to be processed."
-
-
-def test_on_train_batch_start_overridden(tmpdir) -> None:
- """
- Verify that a `MisconfigurationException` is raised when
- `on_train_batch_start` is overridden on the `LightningModule`.
- """
-
- class InvalidModel(AsyncBoringModel):
- def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
- pass
-
- trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
- m = InvalidModel()
- with pytest.raises(MisconfigurationException):
- trainer.fit(m)
-
-
-def test_on_train_batch_end_overridden(tmpdir) -> None:
- """
- Verify that a `MisconfigurationException` is raised when
- `on_train_batch_end` is overridden on the `LightningModule`.
- """
-
- class InvalidModel(AsyncBoringModel):
- def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
- pass
-
- trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
- m = InvalidModel()
- with pytest.raises(MisconfigurationException):
- trainer.fit(m)
-
-
-def test_tbptt_split_batch_overridden(tmpdir) -> None:
- """
- Verify that a `MisconfigurationException` is raised when
- `tbptt_split_batch` is overridden on the `LightningModule`.
- """
-
- class InvalidModel(AsyncBoringModel):
- def tbptt_split_batch(self, batch, split_size):
- pass
-
- trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
- m = InvalidModel()
- with pytest.raises(MisconfigurationException):
- trainer.fit(m)
-
-
-def test_accumulate_grad_batches(tmpdir) -> None:
- """
- Verify that a `MisconfigurationException` is raised when
- `accumulate_grad_batches` is not set to 1.
- """
- trainer = Trainer(max_epochs=1, accumulate_grad_batches=2, default_root_dir=tmpdir)
- m = AsyncBoringModel()
- with pytest.raises(MisconfigurationException):
- trainer.fit(m)
-
-
-def test_is_last_not_set(tmpdir) -> None:
- """
- Verify that a `MisconfigurationException` is raised when `training_step`
- doesn't include "is_last" in the result dict.
- """
-
- class InvalidModel(AsyncBoringModel):
- def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
- output = super().training_step(dataloader_iter)
- del output["is_last"]
- return output
-
- trainer = Trainer(max_epochs=1, accumulate_grad_batches=2, default_root_dir=tmpdir)
- m = InvalidModel()
- with pytest.raises(MisconfigurationException):
- trainer.fit(m)
diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py
index 22d2be8c3a9b0..5fdd09d5fd4d4 100644
--- a/tests/loops/test_loop_state_dict.py
+++ b/tests/loops/test_loop_state_dict.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import Mock
+from unittest.mock import ANY, Mock
import pytest
import torch
@@ -22,23 +22,31 @@
def test_loops_state_dict():
+ trainer = Trainer()
+ trainer.train_dataloader = Mock()
+
fit_loop = FitLoop()
with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"):
fit_loop.trainer = object()
+ fit_loop.trainer = trainer
fit_loop.connect(Mock())
state_dict = fit_loop.state_dict()
+
new_fit_loop = FitLoop()
+ new_fit_loop.trainer = trainer
+
new_fit_loop.load_state_dict(state_dict)
assert fit_loop.state_dict() == new_fit_loop.state_dict()
def test_loops_state_dict_structure():
trainer = Trainer()
+ trainer.train_dataloader = Mock()
state_dict = trainer.checkpoint_connector._get_loops_state_dict()
expected = {
"fit_loop": {
- "state_dict": {},
+ "state_dict": {"dataloader_state_dict": ANY},
"epoch_loop.state_dict": {},
"epoch_loop.batch_progress": {
"total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index 65cbebc8203e5..200b2daae93ed 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -504,7 +504,13 @@ def configure_optimizers_multiple(self):
assert checkpoint["loops"]["fit_loop"] == expected
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False)
- assert trainer.fit_loop.state_dict() == checkpoint["loops"]["fit_loop"]
+ state_dict = trainer.fit_loop.state_dict()
+
+ # need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the
+ # fit loop to have an iterator, which is only available during training
+ checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = ANY
+
+ assert state_dict == checkpoint["loops"]["fit_loop"]
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
state_dict = trainer.fit_loop.state_dict()
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 79c0cf7c12f15..5fdc180613e67 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -22,7 +22,7 @@
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import SLURMEnvironment
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
+from pytorch_lightning.utilities import _TORCH_BFLOAT_AVAILABLE, _TORCH_CPU_AMP_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
@@ -30,13 +30,19 @@
class AMPTestModel(BoringModel):
def _step(self, batch, batch_idx):
- assert torch.is_autocast_enabled()
+ self._assert_autocast_enabled()
output = self(batch)
bfloat16 = self.trainer.precision_plugin.is_bfloat16
assert output.dtype == torch.float16 if not bfloat16 else torch.bfloat16
loss = self.loss(batch, output)
return loss
+ def loss(self, batch, prediction):
+ # todo (sean): convert bfloat16 to float32 as mse loss for cpu amp is currently not supported
+ if self.trainer.precision_plugin.use_cpu:
+ prediction = prediction.float()
+ return super().loss(batch, prediction)
+
def training_step(self, batch, batch_idx):
output = self._step(batch, batch_idx)
return {"loss": output}
@@ -50,17 +56,58 @@ def test_step(self, batch, batch_idx):
return {"y": output}
def predict(self, batch, batch_idx, dataloader_idx=None):
- assert torch.is_autocast_enabled()
+ self._assert_autocast_enabled()
output = self(batch)
bfloat16 = self.trainer.precision_plugin.is_bfloat16
assert output.dtype == torch.float16 if not bfloat16 else torch.bfloat16
return output
+ def _assert_autocast_enabled(self):
+ if self.trainer.precision_plugin.use_cpu:
+ assert torch.is_autocast_cpu_enabled()
+ else:
+ assert torch.is_autocast_enabled()
+
+
+@pytest.mark.skipif(not _TORCH_CPU_AMP_AVAILABLE, reason="CPU AMP not available")
+@pytest.mark.parametrize(
+ "accelerator",
+ [
+ None,
+ pytest.param("dp", marks=pytest.mark.skip("dp + amp not supported currently")), # TODO
+ "ddp_spawn",
+ ],
+)
+@pytest.mark.parametrize(
+ "precision",
+ [
+ pytest.param(16, marks=pytest.mark.skip("CPU precision 16 is not supported in PyTorch yet.")), # TODO
+ "bf16",
+ ],
+)
+@pytest.mark.parametrize("num_processes", [1, 2])
+def test_amp_cpus(tmpdir, accelerator, precision, num_processes):
+ """Make sure combinations of AMP and training types work if supported."""
+ tutils.reset_seed()
+
+ trainer = Trainer(
+ default_root_dir=tmpdir, num_processes=num_processes, max_epochs=1, accelerator=accelerator, precision=precision
+ )
+
+ model = AMPTestModel()
+ # tutils.run_model_test(trainer_options, model)
+ trainer.fit(model)
+ trainer.test(model)
+ trainer.predict(model, DataLoader(RandomDataset(32, 64)))
+
+ assert trainer.state.finished, f"Training failed with {trainer.state}"
+
@RunIf(min_gpus=2)
@pytest.mark.parametrize(
"accelerator",
[
+ None,
pytest.param("dp", marks=pytest.mark.skip("dp + amp not supported currently")), # TODO
"ddp_spawn",
],
@@ -71,7 +118,7 @@ def predict(self, batch, batch_idx, dataloader_idx=None):
16,
pytest.param(
"bf16",
- marks=pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_10, reason="torch.bfloat16 not available"),
+ marks=pytest.mark.skipif(not _TORCH_BFLOAT_AVAILABLE, reason="torch.bfloat16 not available"),
),
],
)
diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py
index 15ec43973b0ed..930cbd8797248 100644
--- a/tests/plugins/test_amp_plugins.py
+++ b/tests/plugins/test_amp_plugins.py
@@ -21,6 +21,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
+from pytorch_lightning.utilities import _TORCH_CPU_AMP_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
@@ -190,3 +191,50 @@ def test_amp_precision_16_bfloat_throws_error(tmpdir):
precision="bf16",
gpus=1,
)
+
+
+@RunIf(amp_native=True, max_torch="1.9")
+def test_cpu_amp_precision_throws_error(tmpdir):
+ with pytest.raises(
+ MisconfigurationException,
+ match="To use native AMP on CPU, install PyTorch 1.10 or later.",
+ ):
+ NativeMixedPrecisionPlugin(use_cpu=True)
+
+
+@pytest.mark.skipif(not _TORCH_CPU_AMP_AVAILABLE, reason="Torch CPU AMP is not available.")
+@RunIf(
+ min_gpus=1,
+ amp_native=True,
+)
+def test_cpu_amp_precision_context_manager(tmpdir):
+ """
+ Test to ensure that the context manager correctly is set to CPU + bfloat16, and a scaler isn't set.
+ """
+
+ plugin = NativeMixedPrecisionPlugin(precision="bf16", use_cpu=True)
+ assert plugin.use_cpu
+ assert not hasattr(plugin, "scaler")
+ context_manager = plugin.autocast_context_manager()
+ assert isinstance(context_manager, torch.cpu.amp.autocast)
+ assert context_manager.fast_dtype == torch.bfloat16
+
+
+@pytest.mark.skipif(not _TORCH_CPU_AMP_AVAILABLE, reason="Torch CPU AMP is not available.")
+@RunIf(
+ min_gpus=1,
+ amp_native=True,
+)
+def test_cpu_amp_precision_16_throws_error(tmpdir):
+ """
+ Throw error when using 16 as Native CPU AMP only supports bfloat16.
+ """
+
+ with pytest.raises(
+ MisconfigurationException,
+ match="CPU native amp only supports bfloat16. Please pass precision='bf16' to the Trainer.",
+ ):
+ Trainer(
+ default_root_dir=tmpdir,
+ precision=16,
+ )
diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py
index 49c4cd18ef316..1e5968f1a0a5a 100644
--- a/tests/plugins/test_ddp_plugin_with_comm_hook.py
+++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py
@@ -15,13 +15,15 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_10
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_8:
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
+if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_10:
+ import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True)
@@ -108,3 +110,32 @@ def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
)
trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"
+
+
+@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, special=True)
+def test_ddp_post_local_sgd_comm_hook(tmpdir):
+ """Test for DDP post-localSGD hook."""
+ model = BoringModel()
+
+ training_type_plugin = DDPPlugin(
+ ddp_comm_state=post_localSGD.PostLocalSGDState(
+ process_group=None,
+ subgroup=None,
+ start_localSGD_iter=8,
+ ),
+ ddp_comm_hook=post_localSGD.post_localSGD_hook,
+ model_averaging_period=4,
+ sync_batchnorm=True,
+ )
+ trainer = Trainer(
+ fast_dev_run=True,
+ gpus=2,
+ plugins=[training_type_plugin],
+ default_root_dir=tmpdir,
+ sync_batchnorm=True,
+ )
+ trainer.fit(model)
+ trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
+ expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
+ assert trainer_comm_hook == expected_comm_hook
+ assert trainer.state.finished, f"Training failed with {trainer.state}"
diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py
index 43158865f9e75..455e08dc10ad5 100644
--- a/tests/trainer/connectors/test_callback_connector.py
+++ b/tests/trainer/connectors/test_callback_connector.py
@@ -64,28 +64,55 @@ def on_save_checkpoint(self, *args):
class StatefulCallback1(Callback):
+ def __init__(self, unique=None, other=None):
+ self._unique = unique
+ self._other = other
+
+ @property
+ def state_key(self):
+ return self._generate_state_key(unique=self._unique)
+
def on_save_checkpoint(self, *args):
- return {"content1": 1}
+ return {"content1": self._unique}
def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
- """Test that all callback states get saved even if the ModelCheckpoint is not given as last."""
+ """
+ Test that all callback states get saved even if the ModelCheckpoint is not given as last
+ and when there are multiple callbacks of the same type.
+ """
callback0 = StatefulCallback0()
- callback1 = StatefulCallback1()
+ callback1 = StatefulCallback1(unique="one")
+ callback2 = StatefulCallback1(unique="two", other=2)
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states")
model = BoringModel()
trainer = Trainer(
- default_root_dir=tmpdir, max_steps=1, limit_val_batches=1, callbacks=[callback0, checkpoint_callback, callback1]
+ default_root_dir=tmpdir,
+ max_steps=1,
+ limit_val_batches=1,
+ callbacks=[
+ callback0,
+ # checkpoint callback does not have to be at the end
+ checkpoint_callback,
+ # callback2 and callback3 have the same type
+ callback1,
+ callback2,
+ ],
)
trainer.fit(model)
ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
state0 = ckpt["callbacks"]["StatefulCallback0"]
- state1 = ckpt["callbacks"]["StatefulCallback1"]
+ state1 = ckpt["callbacks"]["StatefulCallback1{'unique': 'one'}"]
+ state2 = ckpt["callbacks"]["StatefulCallback1{'unique': 'two'}"]
assert "content0" in state0 and state0["content0"] == 0
- assert "content1" in state1 and state1["content1"] == 1
- assert "ModelCheckpoint" in ckpt["callbacks"]
+ assert "content1" in state1 and state1["content1"] == "one"
+ assert "content1" in state2 and state2["content1"] == "two"
+ assert (
+ "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+ " 'train_time_interval': None, 'save_on_train_epoch_end': True}" in ckpt["callbacks"]
+ )
def test_attach_model_callbacks():
diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py
index 6c0b7b7ed40d7..78dd4bce728d4 100644
--- a/tests/trainer/loops/test_evaluation_loop_flow.py
+++ b/tests/trainer/loops/test_evaluation_loop_flow.py
@@ -78,10 +78,11 @@ def backward(self, loss, optimizer, optimizer_idx):
assert train_step_out.minimize.item() == 171
# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward(
+ opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
- assert opt_closure_result["loss"].item() == 171
+ opt_closure_result = opt_closure()
+ assert opt_closure_result.item() == 171
def test__eval_step__eval_step_end__flow(tmpdir):
@@ -144,10 +145,11 @@ def backward(self, loss, optimizer, optimizer_idx):
assert train_step_out.minimize.item() == 171
# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward(
+ opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
- assert opt_closure_result["loss"].item() == 171
+ opt_closure_result = opt_closure()
+ assert opt_closure_result.item() == 171
def test__eval_step__epoch_end__flow(tmpdir):
diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py
index 4ee9d858d44c9..692d2420bfd82 100644
--- a/tests/trainer/loops/test_training_loop_flow_scalar.py
+++ b/tests/trainer/loops/test_training_loop_flow_scalar.py
@@ -18,6 +18,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.loops.closure import Closure
from pytorch_lightning.trainer.states import RunningStage
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.deterministic_model import DeterministicModel
@@ -156,10 +157,11 @@ def backward(self, loss, optimizer, optimizer_idx):
assert train_step_out.minimize.item() == 171
# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward(
+ opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
- assert opt_closure_result["loss"].item() == 171
+ opt_closure_result = opt_closure()
+ assert opt_closure_result.item() == 171
def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):
@@ -229,10 +231,11 @@ def backward(self, loss, optimizer, optimizer_idx):
assert train_step_out.minimize.item() == 171
# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward(
+ opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
- assert opt_closure_result["loss"].item() == 171
+ opt_closure_result = opt_closure()
+ assert opt_closure_result.item() == 171
def test_train_step_no_return(tmpdir):
@@ -258,6 +261,8 @@ def validation_epoch_end(self, outputs):
trainer = Trainer(**trainer_args)
+ Closure.warning_cache.clear()
+
with pytest.warns(UserWarning, match=r"training_step returned None.*"):
trainer.fit(model)
@@ -268,6 +273,8 @@ def validation_epoch_end(self, outputs):
model.automatic_optimization = False
trainer = Trainer(**trainer_args)
+ Closure.warning_cache.clear()
+
with no_warning_call(UserWarning, match=r"training_step returned None.*"):
trainer.fit(model)
@@ -293,6 +300,8 @@ def training_step(self, batch, batch_idx):
checkpoint_callback=False,
)
+ Closure.warning_cache.clear()
+
with pytest.warns(UserWarning, match=r".*training_step returned None.*"):
trainer.fit(model)
diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py
index 670e8b4842a89..80896f6fa450c 100644
--- a/tests/trainer/optimization/test_manual_optimization.py
+++ b/tests/trainer/optimization/test_manual_optimization.py
@@ -419,7 +419,7 @@ class TestModel(ManualOptModel):
called = False
- def on_after_backward(self):
+ def on_before_optimizer_step(self, *args):
self.called = True
norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
if not (torch.isinf(norm) or torch.isnan(norm)):
diff --git a/tests/trainer/properties/log_dir.py b/tests/trainer/properties/test_log_dir.py
similarity index 100%
rename from tests/trainer/properties/log_dir.py
rename to tests/trainer/properties/test_log_dir.py
diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py
index e665fc79e4323..ca5b908459171 100644
--- a/tests/utilities/test_auto_restart.py
+++ b/tests/utilities/test_auto_restart.py
@@ -14,9 +14,13 @@
import math
import os
import random
+import random as python_random
from collections.abc import Iterable
-from typing import Optional
+from contextlib import suppress
+from copy import deepcopy
+from typing import List, Optional
from unittest import mock
+from unittest.mock import ANY
import numpy as np
import pytest
@@ -29,16 +33,19 @@
from torch.utils.data.dataset import Dataset, IterableDataset
import tests.helpers.utils as tutils
-from pytorch_lightning import Callback, seed_everything, Trainer
+from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities.auto_restart import (
_add_capture_metadata_collate,
_dataloader_load_state_dict,
_dataloader_to_state_dict,
CaptureIterableDataset,
+ CaptureMapDataset,
FastForwardSampler,
+ MergedIteratorState,
)
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.fetching import DataFetcher
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@@ -671,7 +678,11 @@ def create_dataloader():
_ = next(iter_dataloader)
state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader)
- assert state_dict[0]["current_iteration"] == 16
+ assert state_dict == {
+ "num_workers": 0,
+ "previous_worker": None,
+ 0: {"current_iteration": 16},
+ }
dataloader = create_dataloader()
dataloader = _dataloader_load_state_dict(dataloader, state_dict)
@@ -679,14 +690,18 @@ def create_dataloader():
_ = next(iter_dataloader)
state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader)
- assert state_dict[0]["current_iteration"] == 24
+ assert state_dict == {
+ "num_workers": 0,
+ "previous_worker": None,
+ 0: {"current_iteration": 24},
+ }
@RunIf(min_torch="1.7.0")
@pytest.mark.parametrize("use_fault_tolerant", ["0", "1"])
def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir):
"""
- this test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled.
+ This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled.
"""
class CustomBatchSampler(BatchSampler):
@@ -713,7 +728,15 @@ def train_dataloader(self):
}
def training_step(self, batch, batch_idx):
- pass
+ assert batch == {
+ "a": [ANY, ANY, ANY],
+ "b": ANY,
+ }
+
+ def validation_step(self, batch, batch_idx):
+ assert isinstance(batch, torch.Tensor)
+
+ validation_epoch_end = None
class Check(Callback):
def on_train_batch_start(self, trainer, *_) -> None:
@@ -721,12 +744,16 @@ def on_train_batch_start(self, trainer, *_) -> None:
if use_fault_tolerant == "1":
assert isinstance(loaders["a"][0].loader.dataset, CaptureIterableDataset)
assert isinstance(loaders["a"][1].loader.sampler, FastForwardSampler)
+ assert isinstance(loaders["a"][1].loader.dataset, CaptureMapDataset)
assert isinstance(loaders["a"][2].loader.batch_sampler, FastForwardSampler)
+ assert isinstance(loaders["a"][2].loader.dataset, CaptureMapDataset)
assert isinstance(loaders["b"].loader.dataset, CaptureIterableDataset)
else:
assert isinstance(loaders["a"][0].loader.dataset, RangeIterableDataset)
assert isinstance(loaders["a"][1].loader.sampler, SequentialSampler)
+ assert not isinstance(loaders["a"][1].loader.dataset, CaptureMapDataset)
assert isinstance(loaders["a"][2].loader.batch_sampler, CustomBatchSampler)
+ assert not isinstance(loaders["a"][2].loader.dataset, CaptureMapDataset)
assert isinstance(loaders["b"].loader.dataset, RangeIterableDataset)
with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": use_fault_tolerant}):
@@ -734,3 +761,210 @@ def on_train_batch_start(self, trainer, *_) -> None:
model.training_epoch_end = None
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, callbacks=Check())
trainer.fit(model)
+
+
+class SequentialGetItemDataset(Dataset):
+ def __init__(self, length, *_):
+ self.len = length
+
+ def __getitem__(self, index):
+ return torch.tensor([index]).float()
+
+ def __len__(self):
+ return self.len
+
+
+class RandomGetItemDataset(Dataset):
+ """A dataset with random elements generated using global rng from torch, numpy and python."""
+
+ def __init__(self, length, size):
+ self.size = size
+ self.len = length
+
+ def __getitem__(self, index):
+ t = torch.rand(self.size)
+ n = torch.from_numpy(np.random.rand(self.size))
+ p = torch.tensor([python_random.random() for _ in range(self.size)])
+ sample = (index + (t + n + p) / 10).float()
+ return sample
+
+ def __len__(self):
+ return self.len
+
+
+# TODO: test with `RandomGeneratorGetItemDataset`
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
+@RunIf(min_torch="1.7.0")
+@pytest.mark.parametrize(
+ "dataset_class",
+ [
+ SequentialGetItemDataset,
+ RandomGetItemDataset,
+ # RandomGeneratorGetItemDataset,
+ ],
+)
+@pytest.mark.parametrize("num_workers", [0])
+@pytest.mark.parametrize("batch_size", [1, 2, 3])
+def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size):
+ """Test that the sequence of batches coming from a random number generator continues with the correct sequence
+ after reloading the state.
+ """
+
+ def create_dataset_sampler():
+ dset = CaptureMapDataset(dataset_class(16, 8))
+ random_sampler = RandomSampler(dset, generator=torch.Generator())
+ return dset, random_sampler
+
+ def create_dataloader_sampler(dset, sampler):
+ sampler = FastForwardSampler(sampler)
+ sampler.setup(batch_size)
+ dl = DataLoader(dset, num_workers=num_workers, sampler=sampler, batch_size=batch_size)
+ _add_capture_metadata_collate(dl)
+ return dl, sampler
+
+ def fetch(fetcher, prefetch_iter, num_batches_fetched):
+ batch, _ = next(prefetch_iter)
+
+ state: List[MergedIteratorState] = fetcher.state
+ assert len(state) == 1
+ assert isinstance(state[0], MergedIteratorState)
+
+ assert len(fetcher.dataloader_iter.cache_states) == 1
+ if num_workers == 0:
+ assert state[0].state[0].num_batches_fetched == num_batches_fetched
+ return state
+
+ dataset, random_sampler = create_dataset_sampler()
+ dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler)
+
+ fetcher = DataFetcher()
+ fetcher.setup(dataloader)
+ prefetch_iter = iter(fetcher)
+
+ # fetch 4 batches
+ fetch(fetcher, prefetch_iter, 1)
+ fetch(fetcher, prefetch_iter, 2)
+ fetch(fetcher, prefetch_iter, 3)
+
+ # (A) capture the state after fetching 4 batches
+ state = fetch(fetcher, prefetch_iter, 4)
+ state = deepcopy(state[0])
+
+ # (B) simulate 2 additional batches
+ batch05, _ = next(prefetch_iter)
+ batch06, _ = next(prefetch_iter)
+
+ # start reloading
+ dataset, random_sampler = create_dataset_sampler()
+ dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler)
+
+ # load the state dict saved at (A)
+ ff_sampler.load_state_dict(state.sampler_states)
+ dataset.load_state_dict(state.dataset_states, latest_worker_id=state.latest_worker_id, num_workers=num_workers)
+
+ prefetcher = DataFetcher()
+ prefetcher.setup(dataloader)
+ prefetch_iter = iter(prefetcher)
+
+ # fetch 2 random batches, these should match exactly the batches seen at (B)
+ batch05_restart, _ = next(prefetch_iter)
+ batch06_restart, _ = next(prefetch_iter)
+
+ assert torch.equal(batch05, batch05_restart)
+ assert torch.equal(batch06, batch06_restart)
+
+
+class CustomException(Exception):
+ pass
+
+
+class SequentialIterableDataset(IterableDataset):
+ def __init__(self, length, *_):
+ self.len = length
+ self.sampler = SequentialSampler(range(self.len))
+
+ def __iter__(self):
+ self.sampler_iter = iter(self.sampler)
+ return self
+
+ def __next__(self):
+ indice = next(self.sampler_iter)
+ return torch.tensor([indice]).float()
+
+
+class TestModel(LightningModule):
+ def __init__(self, fail_on_step: int = -1):
+ super().__init__()
+ self.layer = torch.nn.Linear(1, 2)
+ self.seen_batches = []
+ self.fail_on_step = fail_on_step
+
+ def training_step(self, batch, batch_idx):
+ if self.global_step == self.fail_on_step:
+ raise CustomException()
+ self.seen_batches.append(torch.stack(batch) if isinstance(batch, list) else batch)
+ loss = sum(self.layer(b).sum() for b in batch)
+ return loss
+
+ def configure_optimizers(self):
+ return torch.optim.SGD(self.layer.parameters(), lr=0.1)
+
+
+def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1):
+ seed_everything(1)
+ train_dataloader = [
+ DataLoader(dataset_class(3, 1), batch_size=1, num_workers=0) for dataset_class in dataset_classes
+ ]
+ train_dataloader = train_dataloader[0] if len(train_dataloader) == 1 else train_dataloader
+ model = TestModel(fail_on_step=fail_on_step)
+ trainer = Trainer(**trainer_kwargs)
+ with suppress(CustomException):
+ trainer.fit(model, train_dataloader=train_dataloader)
+ return model.seen_batches
+
+
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
+@RunIf(min_torch="1.7.0")
+@pytest.mark.parametrize(
+ "dataset_classes",
+ [
+ # single training dataset
+ [RandomGetItemDataset],
+ [SequentialIterableDataset],
+ # multiple training datasets (combinded dataloader)
+ [SequentialGetItemDataset, SequentialIterableDataset],
+ [SequentialIterableDataset, SequentialIterableDataset],
+ # [RandomGetItemDataset, RandomGetItemDataset], # TODO: support in the future
+ ],
+)
+@pytest.mark.parametrize("multiple_trainloader_mode", ["min_size", "max_size_cycle"])
+def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, multiple_trainloader_mode):
+ """Test that the Trainer can resume from a failed run in the case of several types of datasets."""
+ trainer_kwargs = dict(
+ default_root_dir=tmpdir,
+ max_epochs=3,
+ weights_summary=None,
+ progress_bar_refresh_rate=0,
+ multiple_trainloader_mode=multiple_trainloader_mode,
+ )
+
+ all_batches = _run_training(trainer_kwargs, dataset_classes)
+ all_batches = torch.stack(all_batches)
+ assert len(all_batches) == 9
+
+ # Simulate 1st failure
+ complete_batches = _run_training(trainer_kwargs, dataset_classes, fail_on_step=4)
+ assert len(complete_batches) == 4
+
+ checkpoint_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")
+ assert os.path.exists(checkpoint_path)
+
+ # Resume after failure
+ trainer_kwargs.update(resume_from_checkpoint=checkpoint_path)
+ resumed_batches = _run_training(trainer_kwargs, dataset_classes, fail_on_step=-1)
+ assert len(resumed_batches) == 5
+
+ # the resumed batches should match the batches of the successful training
+ all_batches_resumed = torch.stack(complete_batches + resumed_batches)
+ assert len(all_batches_resumed) == 9
+ assert torch.equal(all_batches, all_batches_resumed)
diff --git a/tests/utilities/test_enums.py b/tests/utilities/test_enums.py
index ec33fc74b5a3b..c92ce938c7607 100644
--- a/tests/utilities/test_enums.py
+++ b/tests/utilities/test_enums.py
@@ -1,4 +1,5 @@
from pytorch_lightning.utilities import DeviceType
+from pytorch_lightning.utilities.enums import PrecisionType
def test_consistency():
@@ -9,3 +10,11 @@ def test_consistency():
# hash cannot be case invariant
assert DeviceType.TPU not in {"TPU", "CPU"}
assert DeviceType.TPU in {"tpu", "CPU"}
+
+
+def test_precision_supported_types():
+ assert PrecisionType.supported_types() == ["16", "32", "64", "bf16"]
+ assert PrecisionType.supported_type(16)
+ assert PrecisionType.supported_type("16")
+ assert not PrecisionType.supported_type(1)
+ assert not PrecisionType.supported_type("invalid")
diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py
index b351165e03fd8..bf54bbae83568 100644
--- a/tests/utilities/test_fetching.py
+++ b/tests/utilities/test_fetching.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from time import time
-from typing import Any
+from typing import Any, Iterator
from unittest import mock
import pytest
@@ -25,7 +25,8 @@
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher
-from tests.helpers.boring_model import BoringModel
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
@@ -125,7 +126,8 @@ def measure() -> float:
return sum(stats) / len(stats)
-BATCH_SIZE = 128
+BATCH_SIZE = 32
+DATASET_LEN = 64
EMB_SZ = 100
EMB_DIM = 64
@@ -176,6 +178,7 @@ def test_dataloader(self):
def test_trainer_num_prefetch_batches(tmpdir):
model = RecommenderModel()
+
trainer_kwargs = dict(
default_root_dir=tmpdir,
max_epochs=1,
@@ -190,8 +193,8 @@ def test_trainer_num_prefetch_batches(tmpdir):
trainer = Trainer(**trainer_kwargs)
trainer.fit(model)
t1 = time()
- global_step = trainer.global_step
assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher)
+ global_step = trainer.global_step
torch.cuda.synchronize()
@@ -199,9 +202,9 @@ def test_trainer_num_prefetch_batches(tmpdir):
trainer = Trainer(**trainer_kwargs)
trainer.fit(model)
t3 = time()
+ assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher)
assert global_step == trainer.global_step == 4
- assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher)
ratio = (t3 - t2) / (t1 - t0)
assert ratio > 1.1, ratio
@@ -218,7 +221,7 @@ def __init__(self, *args, automatic_optimization: bool = False, **kwargs):
def training_step(self, dataloader_iter, batch_idx):
assert self.count == batch_idx
- assert isinstance(self.trainer.data_connector.data_fetcher, DataLoaderIterDataFetcher)
+ assert isinstance(self.trainer.data_connector.train_data_fetcher, DataLoaderIterDataFetcher)
# fetch 2 batches
self.batches.append(next(dataloader_iter))
self.batches.append(next(dataloader_iter))
@@ -227,7 +230,10 @@ def training_step(self, dataloader_iter, batch_idx):
assert isinstance(batch, torch.Tensor) or batch is None
self.count += 2
if self.automatic_optimization:
- return super().training_step(batch, 0)
+ loss = super().training_step(batch, 0)
+ with pytest.raises(MisconfigurationException, match="dataloader_iter"):
+ self.log("train_loss", loss["loss"])
+ self.log("train_loss", loss["loss"], batch_size=1)
else:
opt = self.optimizers()
output = self(batch)
@@ -236,10 +242,152 @@ def training_step(self, dataloader_iter, batch_idx):
loss.backward()
opt.step()
- training_epoch_end = None
+ def training_epoch_end(self, *_):
+ assert self.trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33
+ assert self.trainer.data_connector.train_data_fetcher.fetched == 64
+ assert self.count == 64
model = TestModel(automatic_optimization=automatic_optimization)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
- trainer.data_connector.data_fetcher = DataLoaderIterDataFetcher()
trainer.fit(model)
- assert model.count == 64
+
+
+class DummyWaitable:
+ def __init__(self, val: Any) -> None:
+ self.val = val
+
+ def wait(self) -> Any:
+ return self.val
+
+
+class AsyncBoringModel(BoringModel):
+ def __init__(self) -> None:
+ super().__init__()
+ self.automatic_optimization = False
+ self.batch_i_handle = None
+ self.num_batches_processed = 0
+
+ def _async_op(self, batch: Any) -> DummyWaitable:
+ return DummyWaitable(val=batch)
+
+ def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
+ if self.batch_i_handle is None:
+ batch_i_raw = next(dataloader_iter)
+ self.batch_i_handle = self._async_op(batch_i_raw)
+
+ # Invariant: _async_op for batch[i] has been initiated
+ batch_ip1_handle = None
+ is_last = False
+ try:
+ batch_ip1_raw = next(dataloader_iter)
+ batch_ip1_handle = self._async_op(batch_ip1_raw)
+ except StopIteration:
+ is_last = True
+
+ batch_i = self.batch_i_handle.wait()
+
+ pred = self.layer(batch_i)
+ loss = self.loss(batch_i, pred)
+ loss.backward()
+ self.optimizers().step()
+ self.optimizers().zero_grad()
+
+ self.batch_i_handle = batch_ip1_handle
+ self.num_batches_processed += 1
+
+ return {"loss": loss, "is_last": is_last}
+
+ def train_dataloader(self):
+ return DataLoader(RandomDataset(BATCH_SIZE, DATASET_LEN))
+
+
+def test_training_step_with_dataloader_access(tmpdir) -> None:
+ """
+ A baseline functional test for `training_step` with dataloader access.
+ """
+ trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
+ m = AsyncBoringModel()
+ trainer.fit(m)
+ assert m.num_batches_processed == DATASET_LEN, f"Expect all {DATASET_LEN} batches to be processed."
+
+
+@pytest.mark.parametrize("trigger_stop_iteration", [False, True])
+def test_stop_iteration(trigger_stop_iteration, tmpdir):
+ """
+ Verify that StopIteration properly terminates the training when this is trigged
+ from the current `dataloader_iter`
+ """
+ EXPECT_NUM_BATCHES_PROCESSED = 2
+
+ class TestModel(AsyncBoringModel):
+ def __init__(self, trigger_stop_iteration) -> None:
+ super().__init__()
+ self.trigger_stop_iteration = trigger_stop_iteration
+
+ def training_step(self, dataloader_iter: Iterator, *args) -> STEP_OUTPUT:
+ output = super().training_step(dataloader_iter)
+ if self.trigger_stop_iteration and args[0] == EXPECT_NUM_BATCHES_PROCESSED:
+ raise StopIteration
+ return output
+
+ def train_dataloader(self):
+ if self.trigger_stop_iteration:
+ return DataLoader(RandomDataset(BATCH_SIZE, 2 * EXPECT_NUM_BATCHES_PROCESSED))
+ return DataLoader(RandomDataset(BATCH_SIZE, EXPECT_NUM_BATCHES_PROCESSED))
+
+ trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
+ m = TestModel(trigger_stop_iteration)
+ trainer.fit(m)
+ expected = EXPECT_NUM_BATCHES_PROCESSED
+ if trigger_stop_iteration:
+ expected *= 2
+ assert m.num_batches_processed == expected
+
+
+def test_on_train_batch_start_overridden(tmpdir) -> None:
+ """
+ Verify that a `MisconfigurationException` is raised when
+ `on_train_batch_start` is overridden on the `LightningModule`.
+ """
+
+ class InvalidModel(AsyncBoringModel):
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+ pass
+
+ trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
+ m = InvalidModel()
+ with pytest.raises(MisconfigurationException, match="The model hook `on_train_batch_start` is not compatible with"):
+ trainer.fit(m)
+
+
+def test_on_train_batch_end_overridden(tmpdir) -> None:
+ """
+ Verify that a `MisconfigurationException` is raised when
+ `on_train_batch_end` is overridden on the `LightningModule`.
+ """
+
+ class InvalidModel(AsyncBoringModel):
+ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+ pass
+
+ trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
+ m = InvalidModel()
+ with pytest.raises(MisconfigurationException, match="The model hook `on_train_batch_end` is not compatible with"):
+ trainer.fit(m)
+
+
+def test_tbptt_split_batch_overridden(tmpdir) -> None:
+ """
+ Verify that a `MisconfigurationException` is raised when
+ `tbptt_split_batch` is overridden on the `LightningModule`.
+ """
+
+ class InvalidModel(AsyncBoringModel):
+ def __init__(self) -> None:
+ super().__init__()
+ self.truncated_bptt_steps = 2
+
+ trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
+ m = InvalidModel()
+ with pytest.raises(MisconfigurationException, match="is incompatible with `truncated_bptt_steps > 0`."):
+ trainer.fit(m)