From b47e3ab7ce171ced7b9d807bbca6062cd6b5c97e Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 10 Aug 2021 20:26:01 -0700 Subject: [PATCH] Remove truncated_bptt_steps from Trainer constructor (#8825) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- .github/workflows/ci_test-base.yml | 2 +- .github/workflows/ci_test-conda.yml | 2 +- CHANGELOG.md | 3 + docs/source/common/trainer.rst | 60 ------------------- docs/source/starter/new-project.rst | 2 +- pytorch_lightning/core/lightning.py | 4 +- .../loops/batch/training_batch_loop.py | 16 +---- .../connectors/training_trick_connector.py | 12 +--- pytorch_lightning/trainer/trainer.py | 5 -- tests/deprecated_api/test_remove_1-5.py | 5 -- tests/models/test_truncated_bptt.py | 9 +-- tests/trainer/test_trainer_cli.py | 1 - tests/utilities/test_cli.py | 1 - 13 files changed, 13 insertions(+), 109 deletions(-) diff --git a/.github/workflows/ci_test-base.yml b/.github/workflows/ci_test-base.yml index 6a1b04276e60f..92a2f4206f8fc 100644 --- a/.github/workflows/ci_test-base.yml +++ b/.github/workflows/ci_test-base.yml @@ -71,7 +71,7 @@ jobs: - name: Test Package [only] run: | - # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 + # NOTE: run coverage on tests does not propagate failure status for Win, https://github.com/nedbat/coveragepy/issues/1003 coverage run --source pytorch_lightning -m pytest pytorch_lightning -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - name: Upload pytest test results diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index 6f291af749d33..86a5a4a933268 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -43,7 +43,7 @@ jobs: - name: Tests run: | - # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 + # NOTE: run coverage on tests does not propagate failure status for Win, https://github.com/nedbat/coveragepy/issues/1003 coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ matrix.pytorch-version }}.xml shell: bash -l {0} diff --git a/CHANGELOG.md b/CHANGELOG.md index 3055b15011a2f..77163b62457b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -118,6 +118,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed the deprecated `sync_step` argument from `WandbLogger` ([#8763](https://github.com/PyTorchLightning/pytorch-lightning/pull/8763)) +- Removed the deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#8826](https://github.com/PyTorchLightning/pytorch-lightning/pull/8826)) + + ### Fixed - Fixed `trainer.fit_loop.split_idx` always returning `None` ([#8601](https://github.com/PyTorchLightning/pytorch-lightning/pull/8601)) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index af8bdcf5c2123..9fde434e78579 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -1482,66 +1482,6 @@ Example:: --env=XLA_USE_BF16=1 -- python your_trainer_file.py -truncated_bptt_steps -^^^^^^^^^^^^^^^^^^^^ - -.. raw:: html - - - -| - -Truncated back prop breaks performs backprop every k steps of -a much longer sequence. 
- -If this is enabled, your batches will automatically get truncated -and the trainer will apply Truncated Backprop to it. - -(`Williams et al. "An efficient gradient-based algorithm for on-line training of -recurrent network trajectories." -`_) - -.. testcode:: - - # default used by the Trainer (ie: disabled) - trainer = Trainer(truncated_bptt_steps=None) - - # backprop every 5 steps in a batch - trainer = Trainer(truncated_bptt_steps=5) - -.. note:: Make sure your batches have a sequence dimension. - -Lightning takes care to split your batch along the time-dimension. - -.. code-block:: python - - # we use the second as the time dimension - # (batch, time, ...) - sub_batch = batch[0, 0:t, ...] - -Using this feature requires updating your LightningModule's -:meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg -with the hidden - -.. code-block:: python - - # Truncated back-propagation through time - def training_step(self, batch, batch_idx, hiddens): - # hiddens are the hiddens from the previous truncated backprop step - out, hiddens = self.lstm(data, hiddens) - return {"loss": ..., "hiddens": hiddens} - -To modify how the batch is split, -override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`: - -.. testcode:: - - class LitMNIST(LightningModule): - def tbptt_split_batch(self, batch, split_size): - # do your own splitting on the batch - return splits val_check_interval ^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst index 51e28457ba4c3..89281de4cd0f2 100644 --- a/docs/source/starter/new-project.rst +++ b/docs/source/starter/new-project.rst @@ -728,7 +728,7 @@ Other cool features Once you define and train your first Lightning model, you might want to try other cool features like - :doc:`Automatic early stopping <../common/early_stopping>` -- :ref:`Automatic truncated-back-propagation-through-time ` +- :ref:`Automatic truncated-back-propagation-through-time ` - :ref:`Automatically scale your batch size ` - :doc:`Automatically find a good learning rate <../advanced/lr_finder>` - :ref:`Load checkpoints directly from S3 ` diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 841dfc4076633..b4a50bf10a577 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1804,9 +1804,7 @@ def get_progress_bar_dict(self): if avg_training_loss is not None: tqdm_dict["loss"] = f"{avg_training_loss:.3g}" - module_tbptt_enabled = self.truncated_bptt_steps > 0 - trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0 - if module_tbptt_enabled or trainer_tbptt_enabled: + if self.truncated_bptt_steps > 0: tqdm_dict["split_idx"] = self.trainer.fit_loop.split_idx if self.trainer.logger is not None and self.trainer.logger.version is not None: diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 8b5268539aab7..4850e715e1840 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -451,7 +451,7 @@ def _tbptt_split_batch(self, batch: Any) -> List[Any]: Args: batch: the current batch to split """ - tbptt_steps = self._truncated_bptt_steps() + tbptt_steps = self.trainer.lightning_module.truncated_bptt_steps if tbptt_steps == 0: return [batch] @@ -643,19 +643,7 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: 
Optio ) # pass hiddens if using tbptt - if self._truncated_bptt_enabled(): + if self.trainer.lightning_module.truncated_bptt_steps > 0: step_kwargs["hiddens"] = hiddens return step_kwargs - - def _truncated_bptt_enabled(self) -> bool: - """Temporary tbptt utilities until this flag is fully migrated to the lightning module.""" - return self._truncated_bptt_steps() > 0 - - def _truncated_bptt_steps(self) -> int: - """Returns the number of tbptt steps""" - lightning_module = self.trainer.lightning_module - # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5 - if lightning_module.truncated_bptt_steps > 0: - return lightning_module.truncated_bptt_steps - return self.trainer.truncated_bptt_steps or 0 diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py index 683954ed5e634..733199c93267c 100644 --- a/pytorch_lightning/trainer/connectors/training_trick_connector.py +++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Union +from typing import Dict, List, Union from pytorch_lightning.callbacks import GradientAccumulationScheduler -from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation +from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -28,7 +28,6 @@ def on_trainer_init( gradient_clip_algorithm: str, track_grad_norm: Union[int, float, str], accumulate_grad_batches: Union[int, Dict[int, int], List[list]], - truncated_bptt_steps: Optional[int], terminate_on_nan: bool, ): @@ -49,13 +48,6 @@ def on_trainer_init( self.trainer.accumulate_grad_batches = accumulate_grad_batches self.configure_accumulated_gradients(accumulate_grad_batches) - if truncated_bptt_steps is not None and truncated_bptt_steps > 0: - rank_zero_deprecation( - "Trainer.truncated_bptt_steps is deprecated in v1.3 and will be removed in v1.5." - " Set truncated_bptt_steps directly on the LightningModule instead." - ) - self.trainer.truncated_bptt_steps = truncated_bptt_steps - def configure_accumulated_gradients(self, accumulate_grad_batches): if isinstance(accumulate_grad_batches, dict): self.trainer.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4201724e08e4c..c9eca4da38c4b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -140,7 +140,6 @@ def __init__( weights_summary: Optional[str] = "top", weights_save_path: Optional[str] = None, num_sanity_val_steps: int = 2, - truncated_bptt_steps: Optional[int] = None, resume_from_checkpoint: Optional[Union[Path, str]] = None, profiler: Optional[Union[BaseProfiler, str]] = None, benchmark: bool = False, @@ -309,9 +308,6 @@ def __init__( track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm. - truncated_bptt_steps: Deprecated in v1.3 to be removed in 1.5. - Please use :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` instead. - val_check_interval: How often to check the validation set. 
Use float to check within a training epoch, use int to check every n steps (batches). @@ -438,7 +434,6 @@ def __init__( gradient_clip_algorithm, track_grad_norm, accumulate_grad_batches, - truncated_bptt_steps, terminate_on_nan, ) self._setup_on_init(num_sanity_val_steps) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 3ea319a27a07a..27b2271a848f3 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -205,11 +205,6 @@ def test_v1_5_0_datamodule_setter(): assert any("The `LightningModule.datamodule`" in w for w in warning_cache) -def test_v1_5_0_trainer_tbptt_steps(tmpdir): - with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): - _ = Trainer(truncated_bptt_steps=1) - - @RunIf(deepspeed=True) @pytest.mark.parametrize( "params", [dict(cpu_offload=True), dict(cpu_offload_params=True), dict(cpu_offload_use_pin_memory=True)] diff --git a/tests/models/test_truncated_bptt.py b/tests/models/test_truncated_bptt.py index c454753e81151..ab10a527e3cda 100644 --- a/tests/models/test_truncated_bptt.py +++ b/tests/models/test_truncated_bptt.py @@ -20,8 +20,7 @@ @pytest.mark.parametrize("n_hidden_states", (1, 2)) -@pytest.mark.parametrize("property_on_module", (False, True)) -def test_tbptt_cpu_model(tmpdir, n_hidden_states, property_on_module): +def test_tbptt_cpu_model(tmpdir, n_hidden_states): """Test truncated back propagation through time works.""" truncated_bptt_steps = 2 sequence_size = 30 @@ -44,8 +43,7 @@ def __init__(self, batch_size, in_features, out_features, n_hidden_states, *args self.batch_size = batch_size self.layer = torch.nn.Linear(in_features, out_features) self.n_hidden_states = n_hidden_states - if property_on_module: - self.truncated_bptt_steps = truncated_bptt_steps + self.truncated_bptt_steps = truncated_bptt_steps def training_step(self, batch, batch_idx, hiddens): assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps" @@ -83,13 +81,10 @@ def train_dataloader(self): ) model.example_input_array = torch.randn(5, truncated_bptt_steps) - trainer_tbptt_steps = None if property_on_module else truncated_bptt_steps - # fit model trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - truncated_bptt_steps=trainer_tbptt_steps, limit_val_batches=0, weights_summary=None, ) diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py index 51f83967affb2..08cdafefaf07a 100644 --- a/tests/trainer/test_trainer_cli.py +++ b/tests/trainer/test_trainer_cli.py @@ -141,7 +141,6 @@ def _raise(): "log_gpu_memory": None, "accelerator": None, "weights_save_path": None, - "truncated_bptt_steps": None, "resume_from_checkpoint": None, "profiler": None, }, diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index e1d8bda010e88..6477348be04da 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -126,7 +126,6 @@ def _raise(): log_gpu_memory=None, distributed_backend=None, weights_save_path=None, - truncated_bptt_steps=None, resume_from_checkpoint=None, profiler=None, ),
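
For readers following this change, below is a minimal, illustrative sketch of the replacement API: `truncated_bptt_steps` is set on the `LightningModule` (as shown in the docs and tests touched by this patch) rather than passed to the `Trainer`. The class name, layer sizes, and optimizer settings here are made up for illustration; only the `truncated_bptt_steps` attribute and the `hiddens` argument to `training_step` are the actual pattern this patch standardizes on.

.. code-block:: python

    import torch
    from pytorch_lightning import LightningModule, Trainer


    class TBPTTExample(LightningModule):
        def __init__(self):
            super().__init__()
            self.lstm = torch.nn.LSTM(input_size=1, hidden_size=8, batch_first=True)
            self.head = torch.nn.Linear(8, 1)
            # TBPTT is now configured on the LightningModule instead of the Trainer
            self.truncated_bptt_steps = 5

        def training_step(self, batch, batch_idx, hiddens):
            # `hiddens` carries the hidden state from the previous truncated split
            x, y = batch
            out, hiddens = self.lstm(x, hiddens)
            loss = torch.nn.functional.mse_loss(self.head(out), y)
            return {"loss": loss, "hiddens": hiddens}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.02)


    # `truncated_bptt_steps` is no longer a Trainer argument
    trainer = Trainer(max_epochs=1)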