From 61a1746ea29bd7e308cce1ab24aa9e01c3d299da Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 3 Aug 2021 15:06:02 +0200
Subject: [PATCH 1/4] Remove support for the deprecated
 `on_{save,load}_checkpoint` signature

---
 pytorch_lightning/trainer/callback_hook.py |  35 +------
 tests/deprecated_api/test_remove_1-5.py    | 101 +--------------------
 2 files changed, 5 insertions(+), 131 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 842a10aa69ef1..802dd88e52bce 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -14,14 +14,13 @@
 from abc import ABC
 from copy import deepcopy
-from inspect import signature
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Type, Union
 
 import torch
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 
@@ -237,29 +236,11 @@ def on_keyboard_interrupt(self):
         for callback in self.callbacks:
             callback.on_keyboard_interrupt(self, self.lightning_module)
 
-    @staticmethod
-    def __is_old_signature_on_save_checkpoint(fn: Callable) -> bool:
-        parameters = list(signature(fn).parameters)
-        return len(parameters) == 2 and parameters[0] != "args"
-
-    @staticmethod
-    def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool:
-        parameters = list(signature(fn).parameters)
-        return len(parameters) == 1 and parameters[0] != "args"
-
     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
         """Called when saving a model checkpoint."""
         callback_states = {}
         for callback in self.callbacks:
-            if self.__is_old_signature_on_save_checkpoint(callback.on_save_checkpoint):
-                rank_zero_deprecation(
-                    "`Callback.on_save_checkpoint` signature has changed in v1.3."
-                    " A `checkpoint` parameter has been added."
-                    " Support for the old signature will be removed in v1.5"
-                )
-                state = callback.on_save_checkpoint(self, self.lightning_module)
-            else:
-                state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
+            state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
             if state:
                 callback_states[callback.state_id] = state
         return callback_states
@@ -289,15 +270,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
             state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id))
             if state:
                 state = deepcopy(state)
-                if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint):
-                    rank_zero_deprecation(
-                        "`Callback.on_load_checkpoint` signature has changed in v1.3."
-                        " `trainer` and `pl_module` parameters have been added."
- " Support for the old signature will be removed in v1.5" - ) - callback.on_load_checkpoint(state) - else: - callback.on_load_checkpoint(self, self.lightning_module, state) + callback.on_load_checkpoint(self, self.lightning_module, state) def on_before_backward(self, loss: torch.Tensor) -> None: """Called before ``loss.backward()``.""" diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 2555a6418fd16..40d9568c33243 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -14,13 +14,12 @@ """Test deprecated functionality which will be removed in v1.5.0""" import operator import os -from typing import Any, Dict from unittest import mock import pytest import torch -from pytorch_lightning import Callback, Trainer +from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.decorators import auto_move_data from pytorch_lightning.loggers import WandbLogger @@ -48,104 +47,6 @@ def test_v1_5_0_wandb_unused_sync_step(_): WandbLogger(sync_step=True) -def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir): - class OldSignature(Callback): - def on_save_checkpoint(self, trainer, pl_module): - ... - - model = BoringModel() - trainer_kwargs = {"default_root_dir": tmpdir, "checkpoint_callback": False, "max_epochs": 1} - filepath = tmpdir / "test.ckpt" - - trainer = Trainer(**trainer_kwargs, callbacks=[OldSignature()]) - trainer.fit(model) - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.save_checkpoint(filepath) - - class NewSignature(Callback): - def on_save_checkpoint(self, trainer, pl_module, checkpoint): - ... - - class ValidSignature1(Callback): - def on_save_checkpoint(self, trainer, *args): - ... - - class ValidSignature2(Callback): - def on_save_checkpoint(self, *args): - ... 
- - trainer.callbacks = [NewSignature(), ValidSignature1(), ValidSignature2()] - with no_warning_call(DeprecationWarning): - trainer.save_checkpoint(filepath) - - -class BaseSignatureOnLoadCheckpoint(Callback): - def __init__(self): - self.on_load_checkpoint_called = False - - -class OldSignatureOnLoadCheckpoint(BaseSignatureOnLoadCheckpoint): - def on_save_checkpoint(self, *args) -> Dict[str, Any]: - return {"a": 0} - - def on_load_checkpoint(self, callback_state) -> None: - assert callback_state == {"a": 0} - self.on_load_checkpoint_called = True - - -class NewSignatureOnLoadCheckpoint(BaseSignatureOnLoadCheckpoint): - def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> dict: - return {"something": "something"} - - def on_load_checkpoint(self, trainer, pl_module, checkpoint): - assert checkpoint == {"something": "something"} - self.on_load_checkpoint_called = True - - -class ValidSignature2OnLoadCheckpoint(BaseSignatureOnLoadCheckpoint): - def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> dict: - return {"something": "something"} - - def on_load_checkpoint(self, *args): - assert len(args) == 3 - self.on_load_checkpoint_called = True - - -def test_v1_5_0_old_callback_on_load_checkpoint(tmpdir): - - model = BoringModel() - trainer_kwargs = {"default_root_dir": tmpdir, "max_steps": 1} - chk = ModelCheckpoint(save_last=True) - trainer = Trainer(**trainer_kwargs, callbacks=[OldSignatureOnLoadCheckpoint(), chk]) - trainer.fit(model) - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer_kwargs["max_steps"] = 2 - cb = OldSignatureOnLoadCheckpoint() - trainer = Trainer(**trainer_kwargs, callbacks=cb, resume_from_checkpoint=chk.last_model_path) - trainer.fit(model) - assert cb.on_load_checkpoint_called - - class ValidSignature1(BaseSignatureOnLoadCheckpoint): - def on_load_checkpoint(self, trainer, *args): - assert len(args) == 2 - self.on_load_checkpoint_called = True - - model = BoringModel() - chk = ModelCheckpoint(save_last=True) - trainer = Trainer( - **trainer_kwargs, - callbacks=[NewSignatureOnLoadCheckpoint(), ValidSignature1(), ValidSignature2OnLoadCheckpoint(), chk] - ) - with no_deprecated_call(match="old signature will be removed in v1.5"): - trainer.fit(model) - - trainer = Trainer(**trainer_kwargs, resume_from_checkpoint=chk.last_model_path) - with no_deprecated_call(match="old signature will be removed in v1.5"): - trainer.fit(model) - - def test_v1_5_0_legacy_profiler_argument(): with pytest.deprecated_call(match="renamed to `record_functions` in v1.3"): PyTorchProfiler(profiled_functions=[]) From e8e534e16acc88df2e44817704cd9071ec215e6e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 3 Aug 2021 15:09:33 +0200 Subject: [PATCH 2/4] Update CHANGELOG --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ff4a53ed8ef62..dc56cf495557f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,6 +77,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed the deprecated `optimizer_idx` from `training_step` as an accepted argument in manual optimization ([#8576](https://github.com/PyTorchLightning/pytorch-lightning/pull/8576)) +- Removed support for the deprecated `on_save_checkpoint` signature. The hook now takes a `checkpoint` positional parameter ([#8697](https://github.com/PyTorchLightning/pytorch-lightning/pull/8697)) + + +- Removed support for the deprecated `on_load_checkpoint` signature. 
The hook now takes a `model` positional parameter ([#8697](https://github.com/PyTorchLightning/pytorch-lightning/pull/8697)) + + - Removed the deprecated `save_function` property in `ModelCheckpoint` ([#8680](https://github.com/PyTorchLightning/pytorch-lightning/pull/8680)) ### Fixed From 14ce2c4922721f4862e7f7eef227467303b992c0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 3 Aug 2021 16:42:15 +0200 Subject: [PATCH 3/4] Missed code --- pytorch_lightning/callbacks/early_stopping.py | 4 +++- pytorch_lightning/callbacks/timer.py | 4 +++- pytorch_lightning/core/saving.py | 19 ------------------- 3 files changed, 6 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 7def99d3cf8d4..96ba26a324b11 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -153,7 +153,9 @@ def on_save_checkpoint( "patience": self.patience, } - def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None: + def on_load_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any] + ) -> None: self.wait_count = callback_state["wait_count"] self.stopped_epoch = callback_state["stopped_epoch"] self.best_score = callback_state["best_score"] diff --git a/pytorch_lightning/callbacks/timer.py b/pytorch_lightning/callbacks/timer.py index f68ddb8611264..23894a1179c1f 100644 --- a/pytorch_lightning/callbacks/timer.py +++ b/pytorch_lightning/callbacks/timer.py @@ -158,7 +158,9 @@ def on_save_checkpoint( ) -> Dict[str, Any]: return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in list(RunningStage)}} - def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None: + def on_load_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any] + ) -> None: time_elapsed = callback_state.get("time_elapsed", {}) self._offset = time_elapsed.get(RunningStage.TRAINING.value, 0) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 7501a027fde10..79608bfc1c5c1 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -212,25 +212,6 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cl return model - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """ - Do something with the checkpoint. - Gives model a chance to load something before ``state_dict`` is restored. - - Args: - checkpoint: A dictionary with variables from the checkpoint. - """ - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """ - Give the model a chance to add something to the checkpoint. - ``state_dict`` is already there. - - Args: - checkpoint: A dictionary in which you can save variables to save in a checkpoint. - Contents need to be pickleable. - """ - # ------------------------- # OPTIONAL HOOKS # ------------------------- From 6b5b76260ddf82fc38f2c253ce761738898ab2cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 4 Aug 2021 13:45:49 +0200 Subject: [PATCH 4/4] Update CHANGELOG.md Co-authored-by: Yifu Wang --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dc56cf495557f..d82265da312de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -80,7 +80,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Removed support for the deprecated `on_save_checkpoint` signature. The hook now takes a `checkpoint` positional parameter ([#8697](https://github.com/PyTorchLightning/pytorch-lightning/pull/8697)) -- Removed support for the deprecated `on_load_checkpoint` signature. The hook now takes a `model` positional parameter ([#8697](https://github.com/PyTorchLightning/pytorch-lightning/pull/8697)) +- Removed support for the deprecated `on_load_checkpoint` signature. The hook now takes a `pl_module` positional parameter ([#8697](https://github.com/PyTorchLightning/pytorch-lightning/pull/8697)) - Removed the deprecated `save_function` property in `ModelCheckpoint` ([#8680](https://github.com/PyTorchLightning/pytorch-lightning/pull/8680))
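For reference, below is a minimal sketch of what a stateful callback looks like once this series lands. The `ProgressCounter` class, its state key, and the choice to count checkpoint saves are illustrative assumptions, not part of the PR; only the three-argument `on_save_checkpoint`/`on_load_checkpoint` signatures come from the diffs above, mirroring the updated `EarlyStopping` and `Timer` hooks in patch 3/4.

from typing import Any, Dict

import pytorch_lightning as pl
from pytorch_lightning import Callback


class ProgressCounter(Callback):
    """Illustrative stateful callback: counts how often it has been checkpointed."""

    def __init__(self) -> None:
        self.saves_seen = 0

    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> Dict[str, Any]:
        # New signature: `checkpoint` is always passed. The returned dict is
        # stored under this callback's `state_id`, as done in
        # `TrainerCallbackHookMixin.on_save_checkpoint` in patch 1/4.
        self.saves_seen += 1
        return {"saves_seen": self.saves_seen}

    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
    ) -> None:
        # New signature: `trainer` and `pl_module` are always passed;
        # `callback_state` is a deep copy of the dict returned above.
        self.saves_seen = callback_state["saves_seen"]

Note the behavioral consequence: a callback that still implements the pre-1.3 one-argument form `on_load_checkpoint(self, callback_state)` now fails with a `TypeError` at restore time rather than emitting a deprecation warning, since the trainer unconditionally passes all three arguments.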