Remove truncated_bptt_steps from Trainer constructor (#8825)
Co-authored-by: Adrian Wälchli <[email protected]>
ananthsub and awaelchli authored Aug 11, 2021
1 parent cb2a8ed commit b47e3ab
Showing 13 changed files with 13 additions and 109 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_test-base.yml
@@ -71,7 +71,7 @@ jobs:

- name: Test Package [only]
run: |
- # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003
+ # NOTE: run coverage on tests does not propagate failure status for Win, https://github.com/nedbat/coveragepy/issues/1003
coverage run --source pytorch_lightning -m pytest pytorch_lightning -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
- name: Upload pytest test results
2 changes: 1 addition & 1 deletion .github/workflows/ci_test-conda.yml
@@ -43,7 +43,7 @@ jobs:
- name: Tests
run: |
- # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003
+ # NOTE: run coverage on tests does not propagate failure status for Win, https://github.com/nedbat/coveragepy/issues/1003
coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ matrix.pytorch-version }}.xml
shell: bash -l {0}

3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -118,6 +118,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed the deprecated `sync_step` argument from `WandbLogger` ([#8763](https://github.com/PyTorchLightning/pytorch-lightning/pull/8763))


+ - Removed the deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#8826](https://github.com/PyTorchLightning/pytorch-lightning/pull/8826))


### Fixed

- Fixed `trainer.fit_loop.split_idx` always returning `None` ([#8601](https://github.com/PyTorchLightning/pytorch-lightning/pull/8601))
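For readers migrating, a minimal sketch of the replacement usage: `truncated_bptt_steps` is now set on the `LightningModule` rather than passed to the `Trainer`. The model, layer sizes, and optimizer below are illustrative only, not taken from this PR:

```python
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule, Trainer


class SequenceModel(LightningModule):  # hypothetical example model, not from this PR
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=32, hidden_size=32, batch_first=True)
        # Before this change: Trainer(truncated_bptt_steps=2)
        # Now: set the flag on the LightningModule itself.
        self.truncated_bptt_steps = 2

    def training_step(self, batch, batch_idx, hiddens):
        # `hiddens` carries the hidden state from the previous truncated split
        x, y = batch
        out, hiddens = self.lstm(x, hiddens)
        return {"loss": F.mse_loss(out, y), "hiddens": hiddens}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


# The Trainer no longer takes the flag at all.
trainer = Trainer(max_epochs=1)
```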
60 changes: 0 additions & 60 deletions docs/source/common/trainer.rst
@@ -1482,66 +1482,6 @@ Example::
--env=XLA_USE_BF16=1
-- python your_trainer_file.py

truncated_bptt_steps
^^^^^^^^^^^^^^^^^^^^

.. raw:: html

    <video width="50%" max-width="400px" controls
    poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/truncated_bptt_steps.jpg"
    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/truncated_bptt_steps.mp4"></video>

|

Truncated backpropagation performs backprop every k steps of
a much longer sequence.

If this is enabled, your batches will automatically get truncated
and the trainer will apply Truncated Backprop to it.

(`Williams et al. "An efficient gradient-based algorithm for on-line training of
recurrent network trajectories."
<http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)

.. testcode::

    # default used by the Trainer (ie: disabled)
    trainer = Trainer(truncated_bptt_steps=None)

    # backprop every 5 steps in a batch
    trainer = Trainer(truncated_bptt_steps=5)

.. note:: Make sure your batches have a sequence dimension.

Lightning takes care of splitting your batch along the time dimension.

.. code-block:: python

    # we use the second dimension as the time dimension
    # (batch, time, ...)
    sub_batch = batch[0, 0:t, ...]

Using this feature requires updating your LightningModule's
:meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` argument
carrying the hidden state:

.. code-block:: python

    # Truncated back-propagation through time
    def training_step(self, batch, batch_idx, hiddens):
        # hiddens are the hiddens from the previous truncated backprop step
        out, hiddens = self.lstm(data, hiddens)
        return {"loss": ..., "hiddens": hiddens}

To modify how the batch is split,
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`:

.. testcode::

    class LitMNIST(LightningModule):
        def tbptt_split_batch(self, batch, split_size):
            # do your own splitting on the batch
            return splits

val_check_interval
^^^^^^^^^^^^^^^^^^
2 changes: 1 addition & 1 deletion docs/source/starter/new-project.rst
@@ -728,7 +728,7 @@ Other cool features
Once you define and train your first Lightning model, you might want to try other cool features like

- :doc:`Automatic early stopping <../common/early_stopping>`
- - :ref:`Automatic truncated-back-propagation-through-time <common/trainer:truncated_bptt_steps>`
+ - :ref:`Automatic truncated-back-propagation-through-time <common/lightning_module:truncated_bptt_steps>`
- :ref:`Automatically scale your batch size <advanced/training_tricks:Auto scaling of batch size>`
- :doc:`Automatically find a good learning rate <../advanced/lr_finder>`
- :ref:`Load checkpoints directly from S3 <common/weights_loading:Checkpoint Loading>`
4 changes: 1 addition & 3 deletions pytorch_lightning/core/lightning.py
@@ -1804,9 +1804,7 @@ def get_progress_bar_dict(self):
if avg_training_loss is not None:
tqdm_dict["loss"] = f"{avg_training_loss:.3g}"

- module_tbptt_enabled = self.truncated_bptt_steps > 0
- trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0
- if module_tbptt_enabled or trainer_tbptt_enabled:
+ if self.truncated_bptt_steps > 0:
tqdm_dict["split_idx"] = self.trainer.fit_loop.split_idx

if self.trainer.logger is not None and self.trainer.logger.version is not None:
16 changes: 2 additions & 14 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -451,7 +451,7 @@ def _tbptt_split_batch(self, batch: Any) -> List[Any]:
Args:
batch: the current batch to split
"""
- tbptt_steps = self._truncated_bptt_steps()
+ tbptt_steps = self.trainer.lightning_module.truncated_bptt_steps
if tbptt_steps == 0:
return [batch]

@@ -643,19 +643,7 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio
)

# pass hiddens if using tbptt
- if self._truncated_bptt_enabled():
+ if self.trainer.lightning_module.truncated_bptt_steps > 0:
step_kwargs["hiddens"] = hiddens

return step_kwargs

- def _truncated_bptt_enabled(self) -> bool:
- """Temporary tbptt utilities until this flag is fully migrated to the lightning module."""
- return self._truncated_bptt_steps() > 0
-
- def _truncated_bptt_steps(self) -> int:
- """Returns the number of tbptt steps"""
- lightning_module = self.trainer.lightning_module
- # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5
- if lightning_module.truncated_bptt_steps > 0:
- return lightning_module.truncated_bptt_steps
- return self.trainer.truncated_bptt_steps or 0
12 changes: 2 additions & 10 deletions pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- from typing import Dict, List, Optional, Union
+ from typing import Dict, List, Union

from pytorch_lightning.callbacks import GradientAccumulationScheduler
- from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation
+ from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -28,7 +28,6 @@ def on_trainer_init(
gradient_clip_algorithm: str,
track_grad_norm: Union[int, float, str],
accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
- truncated_bptt_steps: Optional[int],
terminate_on_nan: bool,
):

@@ -49,13 +48,6 @@ def on_trainer_init(
self.trainer.accumulate_grad_batches = accumulate_grad_batches
self.configure_accumulated_gradients(accumulate_grad_batches)

- if truncated_bptt_steps is not None and truncated_bptt_steps > 0:
- rank_zero_deprecation(
- "Trainer.truncated_bptt_steps is deprecated in v1.3 and will be removed in v1.5."
- " Set truncated_bptt_steps directly on the LightningModule instead."
- )
- self.trainer.truncated_bptt_steps = truncated_bptt_steps

def configure_accumulated_gradients(self, accumulate_grad_batches):
if isinstance(accumulate_grad_batches, dict):
self.trainer.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
5 changes: 0 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -140,7 +140,6 @@ def __init__(
weights_summary: Optional[str] = "top",
weights_save_path: Optional[str] = None,
num_sanity_val_steps: int = 2,
- truncated_bptt_steps: Optional[int] = None,
resume_from_checkpoint: Optional[Union[Path, str]] = None,
profiler: Optional[Union[BaseProfiler, str]] = None,
benchmark: bool = False,
@@ -309,9 +308,6 @@ def __init__(
track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.
- truncated_bptt_steps: Deprecated in v1.3 to be removed in 1.5.
- Please use :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` instead.
val_check_interval: How often to check the validation set. Use float to check within a training epoch,
use int to check every n steps (batches).
@@ -438,7 +434,6 @@ def __init__(
gradient_clip_algorithm,
track_grad_norm,
accumulate_grad_batches,
- truncated_bptt_steps,
terminate_on_nan,
)
self._setup_on_init(num_sanity_val_steps)
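With the keyword argument removed from the `Trainer` signature outright rather than merely deprecated, code that still passes it should now fail fast. A quick sanity check, assuming no compatibility shim intercepts the argument:

```python
from pytorch_lightning import Trainer

try:
    Trainer(truncated_bptt_steps=5)
except TypeError as err:
    # expected: __init__() got an unexpected keyword argument 'truncated_bptt_steps'
    print(err)
```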
5 changes: 0 additions & 5 deletions tests/deprecated_api/test_remove_1-5.py
@@ -205,11 +205,6 @@ def test_v1_5_0_datamodule_setter():
assert any("The `LightningModule.datamodule`" in w for w in warning_cache)


- def test_v1_5_0_trainer_tbptt_steps(tmpdir):
- with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
- _ = Trainer(truncated_bptt_steps=1)


@RunIf(deepspeed=True)
@pytest.mark.parametrize(
"params", [dict(cpu_offload=True), dict(cpu_offload_params=True), dict(cpu_offload_use_pin_memory=True)]
9 changes: 2 additions & 7 deletions tests/models/test_truncated_bptt.py
@@ -20,8 +20,7 @@


@pytest.mark.parametrize("n_hidden_states", (1, 2))
@pytest.mark.parametrize("property_on_module", (False, True))
def test_tbptt_cpu_model(tmpdir, n_hidden_states, property_on_module):
def test_tbptt_cpu_model(tmpdir, n_hidden_states):
"""Test truncated back propagation through time works."""
truncated_bptt_steps = 2
sequence_size = 30
@@ -44,8 +43,7 @@ def __init__(self, batch_size, in_features, out_features, n_hidden_states, *args
self.batch_size = batch_size
self.layer = torch.nn.Linear(in_features, out_features)
self.n_hidden_states = n_hidden_states
- if property_on_module:
- self.truncated_bptt_steps = truncated_bptt_steps
+ self.truncated_bptt_steps = truncated_bptt_steps

def training_step(self, batch, batch_idx, hiddens):
assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
@@ -83,13 +81,10 @@ def train_dataloader(self):
)
model.example_input_array = torch.randn(5, truncated_bptt_steps)

- trainer_tbptt_steps = None if property_on_module else truncated_bptt_steps

# fit model
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
- truncated_bptt_steps=trainer_tbptt_steps,
limit_val_batches=0,
weights_summary=None,
)
1 change: 0 additions & 1 deletion tests/trainer/test_trainer_cli.py
@@ -141,7 +141,6 @@ def _raise():
"log_gpu_memory": None,
"accelerator": None,
"weights_save_path": None,
"truncated_bptt_steps": None,
"resume_from_checkpoint": None,
"profiler": None,
},
1 change: 0 additions & 1 deletion tests/utilities/test_cli.py
@@ -126,7 +126,6 @@ def _raise():
log_gpu_memory=None,
distributed_backend=None,
weights_save_path=None,
- truncated_bptt_steps=None,
resume_from_checkpoint=None,
profiler=None,
),