diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 96c4840f7d33a..e4dc82e05c8be 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -199,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+- The `LightningDataModule.load_from_checkpoint` and `LightningModule.load_from_checkpoint` methods now raise an error if they are called on an instance instead of the class ([#18432](https://github.com/Lightning-AI/lightning/pull/18432))
+
+
 ### Deprecated
 
 - Deprecated the `SingleTPUStrategy` (`strategy="single_tpu"`) in favor of `SingleDeviceXLAStrategy` (`strategy="single_xla"`) ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))
diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py
index bda6dac711d06..f1a3efa90c9ec 100644
--- a/src/lightning/pytorch/core/datamodule.py
+++ b/src/lightning/pytorch/core/datamodule.py
@@ -24,6 +24,7 @@
 from lightning.pytorch.core.hooks import DataHooks
 from lightning.pytorch.core.mixins import HyperparametersMixin
 from lightning.pytorch.core.saving import _load_from_checkpoint
+from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
 from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -157,7 +158,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         """
         pass
 
-    @classmethod
+    @_restricted_classmethod
     def load_from_checkpoint(
         cls,
         checkpoint_path: Union[_PATH, IO],
@@ -200,8 +201,9 @@
             :class:`LightningDataModule` instance with loaded weights and hyperparameters (if available).
 
         Note:
-            ``load_from_checkpoint`` is a **class** method. You should use your :class:`LightningDataModule`
-            **class** to call it instead of the :class:`LightningDataModule` instance.
+            ``load_from_checkpoint`` is a **class** method. You must use your :class:`LightningDataModule`
+            **class** to call it instead of the :class:`LightningDataModule` instance, or a
+            ``TypeError`` will be raised.
 
         Example::
 
@@ -223,7 +225,7 @@
         """
         loaded = _load_from_checkpoint(
-            cls,
+            cls,  # type: ignore[arg-type]
             checkpoint_path,
             map_location=map_location,
             hparams_file=hparams_file,
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 5e70533267b24..2ef97cd58bbc3 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -64,6 +64,7 @@
 from lightning.pytorch.utilities import GradClipAlgorithmType
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1
+from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
 from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_warn, WarningCache
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
 from lightning.pytorch.utilities.types import (
@@ -1464,7 +1465,7 @@ def forward(self, x):
 
         return torchscript_module
 
-    @classmethod
+    @_restricted_classmethod
     def load_from_checkpoint(
         cls,
         checkpoint_path: Union[_PATH, IO],
@@ -1512,7 +1513,8 @@
 
         Note:
             ``load_from_checkpoint`` is a **class** method. You should use your :class:`LightningModule`
-            **class** to call it instead of the :class:`LightningModule` instance.
+            **class** to call it instead of the :class:`LightningModule` instance, or a
+            ``TypeError`` will be raised.
 
         Example::
 
@@ -1545,7 +1547,7 @@
             y_hat = pretrained_model(x)
         """
         loaded = _load_from_checkpoint(
-            cls,
+            cls,  # type: ignore[arg-type]
             checkpoint_path,
             map_location,
             hparams_file,
diff --git a/src/lightning/pytorch/utilities/model_helpers.py b/src/lightning/pytorch/utilities/model_helpers.py
index d683e294b235b..980695e4d8caa 100644
--- a/src/lightning/pytorch/utilities/model_helpers.py
+++ b/src/lightning/pytorch/utilities/model_helpers.py
@@ -11,10 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Type
+import inspect
+import os
+from types import MethodType
+from typing import Any, Callable, Generic, Optional, Type, TYPE_CHECKING, TypeVar
 
 from lightning_utilities.core.imports import RequirementCache
 from torch import nn
+from typing_extensions import Concatenate, ParamSpec
 
 import lightning.pytorch as pl
@@ -68,3 +72,31 @@
         f"You passed a `{old}` object ({type(instance).__qualname__}) to a `{new}`"
         " Trainer. Please switch to a single import style."
     )
+
+
+_T = TypeVar("_T")  # type of the method owner
+_P = ParamSpec("_P")  # parameters of the decorated method
+_R_co = TypeVar("_R_co", covariant=True)  # return type of the decorated method
+
+
+class _restricted_classmethod_impl(Generic[_T, _P, _R_co]):
+    """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an
+    instance instead of a class type."""
+
+    def __init__(self, method: Callable[Concatenate[_T, _P], _R_co]) -> None:
+        self.method = method
+
+    def __get__(self, instance: Optional[_T], cls: Type[_T]) -> Callable[_P, _R_co]:
+        # Workaround for https://github.com/pytorch/pytorch/issues/67146
+        is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack())
+        if instance is not None and not is_scripting:
+            raise TypeError(
+                f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance."
+                " Please call it on the class type and make sure the return value is used."
+            )
+        return MethodType(self.method, cls)
+
+
+# trick static type checkers into thinking it's a @classmethod
+# https://github.com/microsoft/pyright/issues/5865
+_restricted_classmethod = classmethod if TYPE_CHECKING else _restricted_classmethod_impl
diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py
index 61676864afd66..dc795e5f80dcb 100644
--- a/tests/tests_pytorch/callbacks/test_pruning.py
+++ b/tests/tests_pytorch/callbacks/test_pruning.py
@@ -271,7 +271,7 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool
 
     filepath = str(tmpdir / "foo.ckpt")
     trainer.save_checkpoint(filepath)
-    model.load_from_checkpoint(filepath, strict=False)
+    model.load_state_dict(torch.load(filepath)["state_dict"], strict=False)
     has_pruning = hasattr(model.layer.mlp_1, "weight_orig")
     assert not has_pruning if make_pruning_permanent else has_pruning
 
@@ -326,7 +326,7 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
     # removed on_train_end
     assert not hasattr(model.layer.mlp_3, "weight_orig")
 
-    model.load_from_checkpoint(trainer.checkpoint_callback.kth_best_model_path)
+    model = TestModel.load_from_checkpoint(trainer.checkpoint_callback.kth_best_model_path)
     assert not hasattr(model.layer.mlp_3, "weight_orig")
-    model.load_from_checkpoint(trainer.checkpoint_callback.last_model_path)
+    model = TestModel.load_from_checkpoint(trainer.checkpoint_callback.last_model_path)
     assert not hasattr(model.layer.mlp_3, "weight_orig")
diff --git a/tests/tests_pytorch/utilities/test_model_helpers.py b/tests/tests_pytorch/utilities/test_model_helpers.py
index 13b500d8e4359..d9fc267d9168e 100644
--- a/tests/tests_pytorch/utilities/test_model_helpers.py
+++ b/tests/tests_pytorch/utilities/test_model_helpers.py
@@ -16,7 +16,7 @@
 
 from lightning.pytorch import LightningDataModule
 from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
-from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.model_helpers import _restricted_classmethod, is_overridden
 
 
 def test_is_overridden():
@@ -49,3 +49,22 @@ def test_mixed_imports_unified():
     with pytest.raises(TypeError, match=r"`pytorch_lightning` object \(EarlyStopping\) to a `lightning.pytorch`"):
         new_is_overridden("on_fit_start", OldEarlyStopping("foo"))
+
+
+class RestrictedClass:
+    @_restricted_classmethod
+    def restricted_cmethod(cls):
+        # Can only be called on the class type
+        pass
+
+    @classmethod
+    def cmethod(cls):
+        # Can be called on instance or class type
+        pass
+
+
+def test_restricted_classmethod():
+    with pytest.raises(TypeError, match="cannot be called on an instance"):
+        RestrictedClass().restricted_cmethod()
+
+    RestrictedClass.restricted_cmethod()  # no exception
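
Reviewer note: a quick sketch of the user-facing behavior after this change. `BoringModel` and the checkpoint path below are illustrative stand-ins, not part of the diff:

```python
from lightning.pytorch.demos.boring_classes import BoringModel

# Correct: call load_from_checkpoint on the class; it returns a new, configured instance.
model = BoringModel.load_from_checkpoint("example.ckpt")  # "example.ckpt" is a placeholder path

# Incorrect after this PR: calling it through an instance raises a TypeError at attribute access,
# instead of silently discarding the loaded weights as before.
plain = BoringModel()
plain.load_from_checkpoint("example.ckpt")  # TypeError: ... cannot be called on an instance.
```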
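
For context, the new helper relies on the descriptor protocol: `__get__` receives `None` as the instance when the attribute is looked up on the class itself, which is how class access and instance access are told apart. A minimal standalone sketch of that pattern, simplified (no `ParamSpec` generics, no TorchScript stack inspection; all names here are illustrative):

```python
from types import MethodType
from typing import Any, Callable, Optional, Type


class restricted_classmethod:  # simplified stand-in for _restricted_classmethod_impl
    def __init__(self, method: Callable[..., Any]) -> None:
        self.method = method

    def __get__(self, instance: Optional[Any], cls: Type[Any]) -> Callable[..., Any]:
        # `instance` is None for `Owner.attr`, and the object itself for `Owner().attr`.
        if instance is not None:
            raise TypeError(f"`{cls.__name__}.{self.method.__name__}` cannot be called on an instance.")
        # Bind the owning class as the first argument, mirroring @classmethod.
        return MethodType(self.method, cls)


class Loader:
    @restricted_classmethod
    def load(cls) -> str:
        return cls.__name__


assert Loader.load() == "Loader"  # class access behaves like a normal classmethod
# Loader().load()  # would raise TypeError at attribute lookup
```

The `classmethod if TYPE_CHECKING else _restricted_classmethod_impl` alias at the end of `model_helpers.py` then makes static type checkers treat the decorator exactly like `@classmethod`, so the runtime restriction does not change the inferred signatures.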