Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix loading encoder weights trained with BYOL #524

Merged
merged 2 commits into from
Jun 13, 2022

Conversation

gaetanbahl
Copy link
Contributor

When loading a model trained with BYOLTask using the "weights" parameter in ClassificationTask, the following error occurs:

Traceback (most recent call last):                                  
File "finetune_ssl_on_eurosat.py", line 14, in <module>                     
    task = ClassificationTask(classification_model="resnet50", loss="ce",
File "/torchgeo/torchgeo/trainers/classification.py", line 103, in __init__
    self.config_task()
File "/torchgeo/torchgeo/trainers/classification.py", line 77, in config_task
    self.config_model()
File "/torchgeo/torchgeo/trainers/classification.py", line 66, in config_model
    name, state_dict = utils.extract_encoder(self.hyperparams["weights"])
File "/torchgeo/torchgeo/trainers/utils.py", line 55, in extract_encoder                       
raise ValueError(
ValueError: Unknown checkpoint task. Only encoder or classification_model extraction is supported 

The extract_encoder function is looking for the encoder key in the loaded dict, when it seems like it should be looking for encoder_name. This PR fixes this and allows to load a model trained with BYOLTask in ClassificationTask for fine-tuning.

@github-actions github-actions bot added the trainers PyTorch Lightning trainers label May 3, 2022
@adamjstewart
Copy link
Collaborator

It looks like our tests also use "encoder" and need to be updated.

@adamjstewart adamjstewart added this to the 0.2.2 milestone May 3, 2022
@ghost
Copy link

ghost commented Jun 13, 2022

CLA assistant check
All CLA requirements met.

@github-actions github-actions bot added the testing Continuous integration testing label Jun 13, 2022
@calebrob6 calebrob6 merged commit bab8050 into microsoft:main Jun 13, 2022
@adamjstewart adamjstewart modified the milestones: 0.2.2, 0.3.0 Jul 2, 2022
@adamjstewart adamjstewart mentioned this pull request Jul 11, 2022
@gaetanbahl gaetanbahl deleted the fix-encoder-load branch September 25, 2022 07:30
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request Apr 29, 2023
* Fix loading encoder weights trained with BYOL

* Update conftest.py

Co-authored-by: BAHL Gaetan <[email protected]>
Co-authored-by: Caleb Robinson <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
testing Continuous integration testing trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants