
Commit 3a59462

Raise exception when load_from_checkpoint called from instance (#18432)

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
3 people authored Sep 20, 2023
1 parent b4b21e0 commit 3a59462
Showing 6 changed files with 70 additions and 12 deletions.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -199,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).



- The `LightningDataModule.load_from_checkpoint` and `LightningModule.load_from_checkpoint` methods now raise an error if they are called on an instance instead of the class ([#18432](https://github.com/Lightning-AI/lightning/pull/18432))


### Deprecated

- Deprecated the `SingleTPUStrategy` (`strategy="single_tpu"`) in favor of `SingleDeviceXLAStrategy` (`strategy="single_xla"`) ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))
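For context on the `load_from_checkpoint` changelog entry added above, here is a minimal sketch of the new behavior; the model class `MyModel` and the checkpoint path are hypothetical:

```python
from lightning.pytorch import LightningModule


class MyModel(LightningModule):
    """Hypothetical model used only to illustrate the calling convention."""


# Correct: call load_from_checkpoint on the class and use the returned object.
model = MyModel.load_from_checkpoint("path/to/checkpoint.ckpt")

# After this change, calling it on an instance raises a TypeError instead of
# quietly building a new object whose return value is easy to discard.
MyModel().load_from_checkpoint("path/to/checkpoint.ckpt")  # raises TypeError
```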
10 changes: 6 additions & 4 deletions src/lightning/pytorch/core/datamodule.py
@@ -24,6 +24,7 @@
from lightning.pytorch.core.hooks import DataHooks
from lightning.pytorch.core.mixins import HyperparametersMixin
from lightning.pytorch.core.saving import _load_from_checkpoint
+ from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS


@@ -157,7 +158,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
pass

- @classmethod
+ @_restricted_classmethod
def load_from_checkpoint(
cls,
checkpoint_path: Union[_PATH, IO],
@@ -200,8 +201,9 @@ def load_from_checkpoint(
:class:`LightningDataModule` instance with loaded weights and hyperparameters (if available).
Note:
- ``load_from_checkpoint`` is a **class** method. You should use your :class:`LightningDataModule`
- **class** to call it instead of the :class:`LightningDataModule` instance.
+ ``load_from_checkpoint`` is a **class** method. You must use your :class:`LightningDataModule`
+ **class** to call it instead of the :class:`LightningDataModule` instance, or a
+ ``TypeError`` will be raised.
Example::
@@ -223,7 +225,7 @@
"""
loaded = _load_from_checkpoint(
- cls,
+ cls,  # type: ignore[arg-type]
checkpoint_path,
map_location=map_location,
hparams_file=hparams_file,
8 changes: 5 additions & 3 deletions src/lightning/pytorch/core/module.py
@@ -64,6 +64,7 @@
from lightning.pytorch.utilities import GradClipAlgorithmType
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1
+ from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_warn, WarningCache
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
from lightning.pytorch.utilities.types import (
@@ -1464,7 +1465,7 @@ def forward(self, x):

return torchscript_module

- @classmethod
+ @_restricted_classmethod
def load_from_checkpoint(
cls,
checkpoint_path: Union[_PATH, IO],
@@ -1512,7 +1513,8 @@ def load_from_checkpoint(
Note:
``load_from_checkpoint`` is a **class** method. You should use your :class:`LightningModule`
- **class** to call it instead of the :class:`LightningModule` instance.
+ **class** to call it instead of the :class:`LightningModule` instance, or a
+ ``TypeError`` will be raised.
Example::
@@ -1545,7 +1547,7 @@
y_hat = pretrained_model(x)
"""
loaded = _load_from_checkpoint(
- cls,
+ cls,  # type: ignore[arg-type]
checkpoint_path,
map_location,
hparams_file,
34 changes: 33 additions & 1 deletion src/lightning/pytorch/utilities/model_helpers.py
@@ -11,10 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- from typing import Any, Optional, Type
+ import inspect
+ import os
+ from types import MethodType
+ from typing import Any, Callable, Generic, Optional, Type, TYPE_CHECKING, TypeVar

from lightning_utilities.core.imports import RequirementCache
from torch import nn
+ from typing_extensions import Concatenate, ParamSpec

import lightning.pytorch as pl

@@ -68,3 +72,31 @@ def _check_mixed_imports(instance: object) -> None:
f"You passed a `{old}` object ({type(instance).__qualname__}) to a `{new}`"
" Trainer. Please switch to a single import style."
)


_T = TypeVar("_T") # type of the method owner
_P = ParamSpec("_P") # parameters of the decorated method
_R_co = TypeVar("_R_co", covariant=True) # return type of the decorated method


class _restricted_classmethod_impl(Generic[_T, _P, _R_co]):
"""Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance
instead of a class type."""

def __init__(self, method: Callable[Concatenate[_T, _P], _R_co]) -> None:
self.method = method

def __get__(self, instance: Optional[_T], cls: Type[_T]) -> Callable[_P, _R_co]:
# Workaround for https://github.com/pytorch/pytorch/issues/67146
is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack())
if instance is not None and not is_scripting:
raise TypeError(
f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance."
" Please call it on the class type and make sure the return value is used."
)
return MethodType(self.method, cls)


# trick static type checkers into thinking it's a @classmethod
# https://github.com/microsoft/pyright/issues/5865
_restricted_classmethod = classmethod if TYPE_CHECKING else _restricted_classmethod_impl
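As a rough, self-contained sketch of the descriptor idea used by `_restricted_classmethod_impl` (the example class and names below are made up, and the TorchScript workaround and typing generics are omitted): the `__get__` hook receives the instance, or `None` for class-level access, plus the owning class, which is what lets instance access be rejected before the method body ever runs.

```python
from types import MethodType
from typing import Any, Callable, Optional, Type


class restricted_classmethod:
    """Simplified stand-in for the decorator above: a classmethod that rejects instance access."""

    def __init__(self, method: Callable[..., Any]) -> None:
        self.method = method

    def __get__(self, instance: Optional[Any], cls: Type[Any]) -> Callable[..., Any]:
        # `instance` is None when the attribute is looked up on the class itself.
        if instance is not None:
            raise TypeError(f"`{cls.__name__}.{self.method.__name__}` cannot be called on an instance.")
        # Bind the wrapped function to the class, mimicking @classmethod.
        return MethodType(self.method, cls)


class Loader:
    @restricted_classmethod
    def load(cls, path: str) -> str:
        return f"{cls.__name__} loaded from {path}"


print(Loader.load("model.ckpt"))  # OK: "Loader loaded from model.ckpt"
Loader().load("model.ckpt")       # raises TypeError at attribute lookup
```

The real implementation above additionally skips the check while TorchScript is scripting (the `inspect.stack()` walk) and pretends to be a plain `classmethod` under `TYPE_CHECKING`, so static type checkers keep treating the decorated methods as ordinary classmethods.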
6 changes: 3 additions & 3 deletions tests/tests_pytorch/callbacks/test_pruning.py
@@ -271,7 +271,7 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool
filepath = str(tmpdir / "foo.ckpt")
trainer.save_checkpoint(filepath)

- model.load_from_checkpoint(filepath, strict=False)
+ model.load_state_dict(torch.load(filepath), strict=False)
has_pruning = hasattr(model.layer.mlp_1, "weight_orig")
assert not has_pruning if make_pruning_permanent else has_pruning

@@ -326,7 +326,7 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
# removed on_train_end
assert not hasattr(model.layer.mlp_3, "weight_orig")

- model.load_from_checkpoint(trainer.checkpoint_callback.kth_best_model_path)
+ model = TestModel.load_from_checkpoint(trainer.checkpoint_callback.kth_best_model_path)
assert not hasattr(model.layer.mlp_3, "weight_orig")
- model.load_from_checkpoint(trainer.checkpoint_callback.last_model_path)
+ model = TestModel.load_from_checkpoint(trainer.checkpoint_callback.last_model_path)
assert not hasattr(model.layer.mlp_3, "weight_orig")
21 changes: 20 additions & 1 deletion tests/tests_pytorch/utilities/test_model_helpers.py
@@ -16,7 +16,7 @@

from lightning.pytorch import LightningDataModule
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
- from lightning.pytorch.utilities.model_helpers import is_overridden
+ from lightning.pytorch.utilities.model_helpers import _restricted_classmethod, is_overridden


def test_is_overridden():
@@ -49,3 +49,22 @@ def test_mixed_imports_unified():

with pytest.raises(TypeError, match=r"`pytorch_lightning` object \(EarlyStopping\) to a `lightning.pytorch`"):
new_is_overridden("on_fit_start", OldEarlyStopping("foo"))


class RestrictedClass:
@_restricted_classmethod
def restricted_cmethod(cls):
# Can only be called on the class type
pass

@classmethod
def cmethod(cls):
# Can be called on instance or class type
pass


def test_restricted_classmethod():
with pytest.raises(TypeError, match="cannot be called on an instance"):
RestrictedClass().restricted_cmethod()

RestrictedClass.restricted_cmethod() # no exception
