load_from_checkpoint ValueError: ... is not a valid ResNet18_Weights #1639
Comments
No issue if I load a model trained with |
Yep, I can reproduce this: Create a checkpoint with:
Try to load with:
Note: we haven't figured out how to save the checkpoint with a particular filename :) |
Re: checkpoint name, it should just be set with |
In the BaseTask we set |
I can do this
but this saves files as …; @adamjstewart, I would consider this a different bug |
Does it work if you pass the checkpoint to |
Same error with …. @adamjstewart, is there somewhere else you would use …? |
Can you try passing in the weights inside the load_from_checkpoint call? It might be because we aren't saving the weights hyperparameter correctly, so loading the weights back from the checkpoint breaks. Something like
|
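Same error using:
import torch
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask

task = ClassificationTask(
    model="resnet18",
    # weights=True,  # standard ImageNet
    weights=ResNet18_Weights.SENTINEL2_ALL_MOCO,  # or try Sentinel-2 all bands
    # weights=ResNet18_Weights.SENTINEL2_RGB_MOCO,  # or try Sentinel-2 RGB bands
    num_classes=10,
    in_channels=13,
    loss="ce",
    patience=10,
)
# load_from_checkpoint is a classmethod; ckpt_path and device are defined earlier in the notebook
task = ClassificationTask.load_from_checkpoint(
    ckpt_path,
    weights=ResNet18_Weights.SENTINEL2_ALL_MOCO,
    map_location=torch.device(device),
) |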
This actually seems to be an issue with saving the weights as a hyperparameter in the checkpoint file. For example, the following fails:
import torch

torch.load(path, map_location="cpu")
# ValueError: ... is not a valid ResNet18_Weights |
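A minimal, TorchGeo-free sketch of the suspected mechanism, assuming (as the error message suggests) that unpickling rebuilds the enum member by looking up its stored value. `Transform`, `FakeWeights`, and `FakeResNet18Weights` are hypothetical stand-ins for a weight entry's transform module, its dataclass, and `ResNet18_Weights`. Because the stand-in transform compares by identity (as an `nn.Module` does), a copied value no longer equals the value stored on the class, so lookup by value fails while lookup by name still works.

```python
import copy
from dataclasses import dataclass
from enum import Enum
from typing import Any


class Transform:
    """Hypothetical stand-in for an nn.Module: equality falls back to identity."""


@dataclass
class FakeWeights:
    """Hypothetical stand-in for a weights entry (URL plus transforms)."""

    url: str
    transforms: Any


class FakeResNet18Weights(Enum):
    """Hypothetical stand-in for ResNet18_Weights."""

    SENTINEL2_ALL_MOCO = FakeWeights("https://example.com/w.pth", Transform())


member = FakeResNet18Weights.SENTINEL2_ALL_MOCO

# Unpickling produces a new copy of the value, not the original object.
restored_value = copy.deepcopy(member.value)

message = ""
try:
    # Enum lookup by value compares restored_value with each member's value;
    # the copied Transform is a different object, so nothing matches.
    FakeResNet18Weights(restored_value)
except ValueError as err:
    message = str(err)

print(message)
# Lookup by name does not depend on value equality.
print(FakeResNet18Weights["SENTINEL2_ALL_MOCO"] is member)
```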
I think the solution here might be to exclude the weights from the saved hyperparameters (or maybe save only their string representation for experiment tracking). For example, a custom task like this does the trick:
from torchgeo.trainers import ClassificationTask

class CustomTask(ClassificationTask):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Drop the WeightsEnum member so it is not pickled into the checkpoint
        del self.hparams["weights"]

...
...
...
task = CustomTask.load_from_checkpoint(path, map_location="cpu") |
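The "only save the string representation" alternative can be sketched in plain Python, assuming the checkpoint stores the member name (e.g. "SENTINEL2_ALL_MOCO") instead of the member itself. `FakeResNet18Weights`, `weights_to_hparam`, and `hparam_to_weights` are hypothetical helpers, not TorchGeo API; in a real task the conversion would have to happen before the hyperparameters are saved.

```python
import pickle
from enum import Enum
from typing import Optional


class FakeResNet18Weights(Enum):
    # Hypothetical stand-in for ResNet18_Weights. object() values compare by
    # identity, mimicking weight entries whose transforms do not survive a copy.
    SENTINEL2_ALL_MOCO = object()
    SENTINEL2_RGB_MOCO = object()


def weights_to_hparam(weights: Optional[FakeResNet18Weights]) -> Optional[str]:
    """Store only the member name, which is a plain, picklable string."""
    return None if weights is None else weights.name


def hparam_to_weights(name: Optional[str]) -> Optional[FakeResNet18Weights]:
    """Resolve the stored name back to the enum member by name, not by value."""
    return None if name is None else FakeResNet18Weights[name]


hparams = {
    "model": "resnet18",
    "weights": weights_to_hparam(FakeResNet18Weights.SENTINEL2_ALL_MOCO),
}

# Simulate writing the hyperparameters to a checkpoint and reading them back.
restored = pickle.loads(pickle.dumps(hparams))
weights = hparam_to_weights(restored["weights"])
print(weights)  # FakeResNet18Weights.SENTINEL2_ALL_MOCO
```

Storing the name keeps the checkpoint readable without importing the enum's transforms, at the cost of an explicit conversion step on save and load.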
Seems related to Lightning-AI/pytorch-lightning#11494. We can use |
Description
Having trained a ClassificationTask in another notebook, I want to use load_from_checkpoint but get an error:
Steps to reproduce
Version
0.5.0