
load_from_checkpoint ValueError.. is not a valid ResNet18_Weights #1639

Closed
robmarkcole opened this issue Oct 10, 2023 · 12 comments · Fixed by #1670
Labels
trainers PyTorch Lightning trainers
Milestone
0.5.1

@robmarkcole
Contributor

robmarkcole commented Oct 10, 2023

Description

Having trained a ClassificationTask

from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask

model = ClassificationTask(
    model="resnet18",
    weights=ResNet18_Weights.SENTINEL2_ALL_MOCO,  # or try sentinel 2 all bands
    num_classes=10,
    in_channels=13,
    loss="ce",
    patience=10,
)

in another notebook I want to call load_from_checkpoint, but I get an error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[11], line 1
----> 1 model = ClassificationTask.load_from_checkpoint(ckpt_path, map_location=torch.device(device))

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/core/module.py:1537, in LightningModule.load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
   1457 @classmethod
   1458 def load_from_checkpoint(
   1459     cls,
   (...)
   1464     **kwargs: Any,
   1465 ) -> Self:
   1466     r"""
   1467     Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
   1468     it stores the arguments passed to ``__init__``  in the checkpoint under ``"hyper_parameters"``.
   (...)
   1535         y_hat = pretrained_model(x)
   1536     """
-> 1537     loaded = _load_from_checkpoint(
   1538         cls,
   1539         checkpoint_path,
   1540         map_location,
   1541         hparams_file,
   1542         strict,
   1543         **kwargs,
...
    (Resize_0): Resize(output_size=256, p=1.0, p_batch=1.0, same_on_batch=True, size=256, side=short, resample=bilinear, align_corners=True, antialias=False)
    (CenterCrop_1): CenterCrop(p=1.0, p_batch=1.0, same_on_batch=True, resample=bilinear, cropping_mode=slice, align_corners=True, size=(224, 224), padding_mode=zeros)
    (Normalize_2): Normalize(p=1.0, p_batch=1.0, same_on_batch=True, mean=0, std=10000)
  )
), meta={'dataset': 'SSL4EO-S12', 'in_chans': 13, 'model': 'resnet18', 'publication': 'https://arxiv.org/abs/2211.07044', 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', 'ssl_method': 'moco'}) is not a valid ResNet18_Weights

Steps to reproduce

import torch
from torchgeo.trainers import ClassificationTask

device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt_path = 'path_to_file.ckpt'
model = ClassificationTask.load_from_checkpoint(ckpt_path, map_location=torch.device(device))

Version

0.5.0

@robmarkcole
Contributor Author

No issue if I load a model trained with ImageNet weights - possibly related to #1234

@calebrob6
Member

Yep, I can reproduce this:

Create a checkpoint with:

from torchgeo.trainers import ClassificationTask
from torchgeo.datamodules import EuroSAT100DataModule
from torchgeo.models import ResNet18_Weights
import lightning.pytorch as pl

datamodule = EuroSAT100DataModule(
    batch_size=32,
    num_workers=8,
    download=True,
    root="data/",
)

task = ClassificationTask(
    model="resnet18",
    weights=ResNet18_Weights.SENTINEL2_ALL_MOCO,
    in_channels=13,
    num_classes=10,
)

trainer = pl.Trainer(
    accelerator="gpu",
    max_epochs=2,
)

trainer.fit(task, datamodule=datamodule)

Try to load with: task = ClassificationTask.load_from_checkpoint("lightning_logs/version_3/checkpoints/epoch=4-step=5.ckpt")

Note we haven't figured out how to save the checkpoint as a particular filename :)

@robmarkcole
Contributor Author

Re: the checkpoint name, it should just be set with filename=...
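
For reference, the standard Lightning pattern would be something like this (a minimal sketch; the dirpath and filename values here are made up):

from lightning.pytorch.callbacks import ModelCheckpoint
import lightning.pytorch as pl

# Lightning appends ".ckpt" itself (and "-v1" etc. on name collisions),
# so pass the filename without an extension
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/", filename="eurosat-resnet18")
trainer = pl.Trainer(max_epochs=2, callbacks=[checkpoint_callback])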

@calebrob6
Member

In the BaseTask we set configure_callbacks(), which overrides the callbacks passed to pl.Trainer.

@calebrob6
Member

I can do this:

from lightning.pytorch.callbacks import ModelCheckpoint

class ClassificationTaskWithCallback(ClassificationTask):
    def configure_callbacks(self):
        return [
            ModelCheckpoint(filename="test.ckpt")
        ]

task = ClassificationTaskWithCallback(
    model="resnet18",
    weights=ResNet18_Weights.SENTINEL2_ALL_MOCO,
    in_channels=13,
    num_classes=10,
)

but this saves files as {default_root_dir}/lightning_logs/version_{something}/checkpoints/test.ckpt-v1.ckpt

@adamjstewart, I would consider this a different bug

@adamjstewart adamjstewart added this to the 0.5.1 milestone Oct 10, 2023
@adamjstewart
Collaborator

Does it work if you pass the checkpoint to weights=...?

@robmarkcole
Contributor Author

Same error with test_results = trainer.test(model=task, dataloaders=datamodule, ckpt_path=ckpt_path)

@adamjstewart is there somewhere else you would use weights=?

@isaaccorley
Collaborator

Can you try passing the weights inside the load_from_checkpoint call? It might be because we aren't saving the weights hyperparameter correctly, so when it tries to load the weights from the checkpoint it breaks. Something like

.load_from_checkpoint(path, weights=ResNet18_Weights...)

@robmarkcole
Contributor Author

Same error using:

task = ClassificationTask(
    model="resnet18",
    # weights=True, # standard Imagenet
    weights=ResNet18_Weights.SENTINEL2_ALL_MOCO, # or try sentinel 2 all bands
    # weights=ResNet18_Weights.SENTINEL2_RGB_MOCO, # or try sentinel 2 rgb bands
    num_classes=10,
    in_channels=13,
    loss="ce", 
    patience=10
)

task = task.load_from_checkpoint(ckpt_path, weights=ResNet18_Weights.SENTINEL2_ALL_MOCO, map_location=torch.device(device))

@isaaccorley
Collaborator

isaaccorley commented Oct 11, 2023

This actually seems to be an issue with saving the weights as a hyperparameter in the checkpoint file. For example, the following fails:

torch.load(path, map_location="cpu")
# ... is not a valid ResNet18_Weights

@isaaccorley
Collaborator

I think the solution here might be to exclude saving the weights as a hyperparameter (or maybe only save its string representation for experiments).

Using a custom task like this, for example, does the trick:

class CustomTask(ClassificationTask):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
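        # drop "weights" from hparams so the enum is never pickled into the checkpoint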
        del self.hparams["weights"]

...
...
...

task = task.load_from_checkpoint(path, map_location="cpu")

@adamjstewart
Collaborator

Seems related to Lightning-AI/pytorch-lightning#11494

We can use save_hyperparameters(ignore=["weights"]) if that helps. But why are you trying to load a checkpoint and load a pre-trained model at the same time? Shouldn't it be one or the other?
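
If it helps, a minimal sketch of that approach (hypothetical task class, for illustration only; the actual fix landed in #1670):

import lightning.pytorch as pl

class WeightsAwareTask(pl.LightningModule):  # hypothetical name, for illustration
    def __init__(self, model: str = "resnet18", weights=None, num_classes: int = 10) -> None:
        super().__init__()
        # save every __init__ argument to the checkpoint except the weights enum,
        # which currently fails to round-trip through torch.load
        self.save_hyperparameters(ignore=["weights"])
        self.weights = weights

load_from_checkpoint would then restore the remaining hyperparameters without ever trying to reconstruct a ResNet18_Weights member.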
