load_state_dict does not return the model #1503

Merged: adamjstewart merged 11 commits into microsoft:main on Feb 6, 2024

Conversation

Contributor

@konstantinklemmer konstantinklemmer commented Jul 31, 2023

Fixed an error in the state dict loading of the tutorial

@github-actions github-actions bot added the documentation Improvements or additions to documentation label Jul 31, 2023
@adamjstewart adamjstewart added this to the 0.4.2 milestone Aug 1, 2023
@adamjstewart
Collaborator

Not sure why tests are failing but it's clearly unrelated to this PR. Will try to investigate in a separate PR.

@adamjstewart
Collaborator

> Fixed an error in the state dict loading of the tutorial

What's the error?

@calebrob6
Member

@konstantinklemmer ping

@adamjstewart
Collaborator

P.S. I fixed the failing test, tests should pass after updating your branch. I'm just curious why this PR is needed since it doesn't fail during testing.

@adamjstewart adamjstewart removed this from the 0.4.2 milestone Sep 28, 2023
calebrob6 previously approved these changes Dec 18, 2023
@adamjstewart
Collaborator

The purpose of this PR is still unclear to me. What bug is it trying to solve?

@calebrob6
Member

There isn't a bug, this is just an example of how to load pre-trained models differently.

@adamjstewart
Collaborator

That's not what the PR description says...

@isaaccorley
Collaborator

The error that's being fixed is that model.load_state_dict() returns either None or a list of incompatible keys. Our example incorrectly does model = model.load_state_dict(...), which overwrites the model itself with None or that list.
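
The mistake is easy to reproduce. The sketch below uses a hypothetical stand-in class, `TinyModel`, instead of a real `torch.nn.Module`, so that it runs without PyTorch installed. `TinyModel` and `KeyReport` are illustrative names, not part of any library; the only assumption is that `load_state_dict` updates the object in place and returns a report of key mismatches rather than the model.

```python
from typing import NamedTuple


class KeyReport(NamedTuple):
    """Hypothetical stand-in for the object PyTorch returns."""

    missing_keys: list[str]
    unexpected_keys: list[str]


class TinyModel:
    """Hypothetical stand-in for a torch.nn.Module with one parameter."""

    def __init__(self) -> None:
        self.params = {"weight": 0.0}

    def load_state_dict(self, state_dict: dict[str, float]) -> KeyReport:
        # Copy matching entries in place, then report the mismatches.
        missing = [k for k in self.params if k not in state_dict]
        unexpected = [k for k in state_dict if k not in self.params]
        for key in self.params:
            if key in state_dict:
                self.params[key] = state_dict[key]
        return KeyReport(missing, unexpected)


state_dict = {"weight": 1.5}

# Buggy pattern: the name `model` is rebound to the report, not the model.
model = TinyModel()
model = model.load_state_dict(state_dict)
print(type(model).__name__)  # KeyReport

# Correct pattern: call the method for its side effect and keep the model.
model = TinyModel()
report = model.load_state_dict(state_dict)
print(model.params, report)
```

With a real model the call sites look the same: drop the `model =` assignment, and keep the return value only if the key report is of interest.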

@isaaccorley
Collaborator

This looks okay to me

konstantinklemmer and others added 2 commits January 3, 2024 12:11
Fixed an error in the state dict loading of the tutorial and added a comment on the num_classes parameter when creating timm models.
@adamjstewart
Collaborator

Thanks @isaaccorley, I understand the error now.

It looks like we actually make this same mistake in the README. And we define our own custom torchgeo.trainers.utils.load_state_dict that behaves differently from the builtin one. We should fix all of these at the same time.

I'm not sure about the num_classes comment. We aren't trying to teach people how to use timm, just how to use TorchGeo. I don't disagree that it's useful, just that it's in the wrong place.

@konstantinklemmer let me know if you want me to make these changes myself. If I don't hear back I'll assume this PR has been abandoned and take over.

@konstantinklemmer
Contributor Author

Sorry, no idea why I am not getting notifications for this PR...

Yes, I am happy to remove that comment and the example with num_classes=0. Does that sound alright? Basically just change model = model.load_state_dict() to model.load_state_dict().

@adamjstewart
Collaborator

Yes, and update the README and our builtin load_state_dict as well. Let me know if you want help with the latter.

@konstantinklemmer
Contributor Author

Yes please, I am not sure how to tackle the builtin load_state_dict problem (or what it even is exactly).

@adamjstewart
Collaborator

> or what it even is exactly

It's not really a problem per se, just that we define a wrapper around load_state_dict that returns the model instead of returning hits/misses like the builtin PyTorch method. It's kind of confusing to have different behavior, so I think we should match the behavior of the builtin method. Luckily, this method isn't public facing, so this change won't be backwards incompatible.

@konstantinklemmer
Contributor Author

Okay! So there are two functions that I found that, if I understand correctly, are relevant:

In both cases, if we want to keep them as standalone functions that return a model with weights loaded, we should probably rename them?

@adamjstewart
Collaborator

torchgeo.trainers.utils.load_state_dict is the one you would change. The test_load_state_dict is just to test the method, the inputs don't matter.

We can either A) change the name, or B) change the return value to match torch.nn.Module.load_state_dict. If the latter works, I would actually prefer the latter.

@konstantinklemmer
Contributor Author

konstantinklemmer commented Jan 21, 2024

Well, load_state_dict does not have a return value, no? So for B) we'd just need to get rid of the return call, i.e. this line here (unless I misunderstand): https://github.com/microsoft/torchgeo/blob/436baa9773977c789152854ac7b4eff90e0d9e95/torchgeo/trainers/utils.py#L119C5-L119C17

Then load_state_dict(model, ckpt['state_dict']) should be equivalent to model.load_state_dict(ckpt['state_dict']).

@adamjstewart
Collaborator

adamjstewart commented Jan 21, 2024

The builtin load_state_dict returns missing and unexpected keys: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict

So you'll just return the output of that call so that our wrapper matches.

@konstantinklemmer
Contributor Author

Okay, after some more digging: nn.Module.load_state_dict returns an _IncompatibleKeys object (https://github.com/pytorch/pytorch/blob/c3780010a58a84920335296ee5f091a0db18259f/torch/nn/modules/module.py#L29). This object holds the missing and unexpected keys. However, when you check the load_state_dict and _load_from_state_dict (https://github.com/pytorch/pytorch/blob/c3780010a58a84920335296ee5f091a0db18259f/torch/nn/modules/module.py#L1953) functions, it seems that both will be empty lists:

missing_keys: List[str] = []
unexpected_keys: List[str] = []

unless the strict=True flag is on. The torchgeo.trainers.utils.load_state_dict does not use that flag. So should I just return _IncompatibleKeys with those empty lists?
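
For what it's worth, the two lists in the linked source only start out empty. They are filled while the state dict is loaded regardless of `strict`, and `strict=True` (the default) only turns a non-empty list into a `RuntimeError`. The function below is a simplified, dependency-free paraphrase of that control flow, not the PyTorch source; `load` and its plain-dict "parameters" are hypothetical names used for illustration.

```python
def load(params: dict[str, float], state_dict: dict[str, float], strict: bool = True):
    """Simplified paraphrase of the control flow, not the PyTorch source."""
    missing_keys: list[str] = []
    unexpected_keys: list[str] = []

    # Both lists are filled while loading, whatever the value of `strict`.
    for key in params:
        if key in state_dict:
            params[key] = state_dict[key]
        else:
            missing_keys.append(key)
    unexpected_keys.extend(k for k in state_dict if k not in params)

    # `strict` only decides whether a mismatch is an error.
    if strict and (missing_keys or unexpected_keys):
        raise RuntimeError(f"missing={missing_keys}, unexpected={unexpected_keys}")
    return missing_keys, unexpected_keys


params = {"weight": 0.0, "bias": 0.0}
print(load(params, {"weight": 1.0, "extra": 2.0}, strict=False))
```

So a wrapper that forwards the builtin's return value reports real mismatches when called with `strict=False`; there is no need to construct the empty lists by hand.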

@adamjstewart
Collaborator

Let's just return the output of nn.Module.load_state_dict. For type hints, try:

-> Tuple[List[str], List[str]]

I think that will work, and be correct. The builtin method has no return type annotations so it shouldn't complain that we aren't using a NamedTuple. If you have any trouble with mypy, let me know and I can hack on it.
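
One way the suggestion above could look, as a sketch only. This is not the actual torchgeo implementation: `IncompatibleKeys`, `SupportsLoad`, and `FakeModel` are hypothetical names introduced so the example runs without PyTorch, and the annotation uses the builtin generics form `tuple[list[str], list[str]]` in place of `Tuple[List[str], List[str]]`.

```python
from typing import NamedTuple, Protocol


class IncompatibleKeys(NamedTuple):
    """Hypothetical stand-in for PyTorch's private _IncompatibleKeys."""

    missing_keys: list[str]
    unexpected_keys: list[str]


class SupportsLoad(Protocol):
    """Anything with a load_state_dict method (hypothetical helper type)."""

    def load_state_dict(self, state_dict: dict[str, float]) -> IncompatibleKeys: ...


def load_state_dict(
    model: SupportsLoad, state_dict: dict[str, float]
) -> tuple[list[str], list[str]]:
    """Hypothetical wrapper: forward the builtin's result, not the model."""
    # A NamedTuple of two lists is still a plain 2-tuple to the type checker.
    return model.load_state_dict(state_dict)


class FakeModel:
    """Minimal object used only to exercise the wrapper."""

    def load_state_dict(self, state_dict: dict[str, float]) -> IncompatibleKeys:
        return IncompatibleKeys([], sorted(state_dict))


missing_keys, unexpected_keys = load_state_dict(FakeModel(), {"fc.bias": 0.0})
print(missing_keys, unexpected_keys)
```

Callers that previously wrote `model = load_state_dict(model, state_dict)` would then unpack the two lists instead, or ignore the return value entirely.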

@konstantinklemmer
Contributor Author

Got it! Will try if that works and report back.

* Import Tuple from typing
* Change return of `load_state_dict` from `model` to `Tuple[List[str], List[str]]`, matching the return of the standard PyTorch builtin function.
@github-actions github-actions bot added the trainers PyTorch Lightning trainers label Jan 30, 2024
Remove example of loading pretrained model without prediction head (`num_classes=0`).
Adapt new `load_state_dict` function.
@konstantinklemmer
Contributor Author

Ok I think I updated all files (the README, the notebook, the utils.py) but tests are failing.

@github-actions github-actions bot added the testing Continuous integration testing label Jan 31, 2024
@adamjstewart adamjstewart added this to the 0.5.2 milestone Jan 31, 2024
@adamjstewart adamjstewart changed the title from "Update pretrained_weights.ipynb" to "load_state_dict does not return the model" Jan 31, 2024
@konstantinklemmer
Contributor Author

Yay! Thanks for helping with this - seems though that the original problem with the notebook test still persists?

@adamjstewart
Collaborator

Completely unrelated problem, fixed in #1838

@adamjstewart adamjstewart merged commit 55b3c50 into microsoft:main Feb 6, 2024
20 of 21 checks passed
isaaccorley pushed a commit that referenced this pull request Mar 2, 2024
* Update pretrained_weights.ipynb

Fixed an error in the state dict loading of the tutorial and added a comment on the num_classes parameter when creating timm models.

* Update docs/tutorials/pretrained_weights.ipynb

* Update utils.py

* Import Tuple from typing
* Change return of `load_state_dict` from `model` to `Tuple[List[str], List[str]]`, matching the return of the standard PyTorch builtin function.

* Update pretrained_weights.ipynb

Remove example of loading pretrained model without prediction head (`num_classes=0`).

* Update README.md

Adapt new `load_state_dict` function.

* Mimic return type of builtin load_state_dict

* Modern type hints

* Blacken

* Try being explicit

---------

Co-authored-by: Caleb Robinson <[email protected]>
Co-authored-by: Adam J. Stewart <[email protected]>
Labels
documentation Improvements or additions to documentation testing Continuous integration testing trainers PyTorch Lightning trainers
4 participants