Fix load_state_dict for all timm models #1084
I looked into the seco weights again. Since they were originally saved as part of a pytorch-lightning module, the keys have different names than the default timm keys. Looking at the seco code, there are a "q" and a "k" network, and the "q" network is used as a pretrained backbone for downstream tasks. The
|
Your guess is as good as mine. Do the authors have any code for loading the pretrained model like your group does? If not, then I think your analysis makes sense. If you really want to be sure, you can try to train a model with those pretrained weights and make sure it converges quickly. |
You could use something like this -- https://gist.github.com/calebrob6/44f2e42017e2d192e837f0a1cd526c50 -- to make sure that linear-probing on a downstream dataset with the model achieves good performance. This is the notebook that I used to verify one of the SSL4EO weights I think. |
Yes, so the encoder_q I got from this code where they load a backbone from their pretrained model for a downstream task. |
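For readers following along, the extraction discussed above can be sketched roughly as follows. This is only a minimal illustration, not the code used in this PR: the function name `extract_backbone_state_dict`, the `"encoder_q."` prefix, and the optional `rename` mapping are assumptions made for the example, and the real checkpoint layout should be checked before relying on it.

```python
from typing import Any, Dict, Mapping, Optional


def extract_backbone_state_dict(
    state_dict: Mapping[str, Any],
    prefix: str = "encoder_q.",  # assumed prefix of the "q" network keys
    rename: Optional[Mapping[str, str]] = None,  # hypothetical first-segment renames
) -> Dict[str, Any]:
    """Keep only keys under `prefix`, strip it, and optionally rename the first segment."""
    rename = dict(rename or {})
    backbone: Dict[str, Any] = {}
    for key, value in state_dict.items():
        if not key.startswith(prefix):
            continue  # drops the "k" network and anything else outside the prefix
        new_key = key[len(prefix):]
        head, sep, tail = new_key.partition(".")
        head = rename.get(head, head)  # e.g. map an index-style name to a layer name
        backbone[head + sep + tail] = value
    return backbone
```

The returned dictionary could then be passed to the model's `load_state_dict`, possibly with `strict=False` if the checkpoint has no final classification layer.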
I did the extraction of the weights described above and tried @calebrob6's script with the extracted seco resnet18 weights. As they are pretrained on RGB, I only use the EuroSAT RGB bands. Here are the scores I get: Edit: This is with the preprocessing step of just dividing image values by 1000:
This is with the preprocessing step of using the provided normalization stats for the bands
|
I will try to investigate again tomorrow. This is what I am using as a script (takes about 3 minutes to run on cpu locally). And for some reason one needs |
I downloaded the bigearthnet dataset to try the linear probing script on that since that is the dataset they also report in their paper. For bigearthnet with 10,000 samples (not the full dataset), I get the following scores:
|
Interesting evidence that the weights may not be very transferable... |
But this should be sufficient proof that your approach to extracting the first layer of weights is correct and we can move forward with this PR. |
In the paper, they report quite a significant improvement when using Seco in linear probing, so I must be doing something wrong. I can also contact the authors to sort it out. |
That's prob for MSI, not RGB |
I think everything is only RGB, at least that is how I interpret it when they state "Although the collected dataset contains up to 12 spectral bands, in this work we focus on the RGB channels since it is a more general modality." |
For the moment I updated the seco weights on huggingface, and the loading works for all weights now. |
I'm still seeing the same issue: $ pytest -m slow tests/trainers/
...
FAILED tests/trainers/test_byol.py::TestBYOLTask::test_weight_enum_download[ResNet18_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_byol.py::TestBYOLTask::test_weight_enum_download[ResNet50_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_byol.py::TestBYOLTask::test_weight_str_download[ResNet18_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_byol.py::TestBYOLTask::test_weight_str_download[ResNet50_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_classification.py::TestClassificationTask::test_weight_enum_download[ResNet18_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_classification.py::TestClassificationTask::test_weight_enum_download[ResNet50_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_classification.py::TestClassificationTask::test_weight_str_download[ResNet18_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_classification.py::TestClassificationTask::test_weight_str_download[ResNet50_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_regression.py::TestRegressionTask::test_weight_enum_download[ResNet18_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_regression.py::TestRegressionTask::test_weight_enum_download[ResNet50_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_regression.py::TestRegressionTask::test_weight_str_download[ResNet18_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_regression.py::TestRegressionTask::test_weight_str_download[ResNet50_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight' |
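A `KeyError: 'conv1.weight'` like the ones above suggests the loaded file does not contain the key names the model expects. One way to check this, sketched here with a hypothetical helper `diff_state_dict_keys` that is not part of this PR, is to compare the two sets of key names directly:

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(
    expected: Iterable[str], actual: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key names, each sorted alphabetically."""
    expected_set, actual_set = set(expected), set(actual)
    missing = sorted(expected_set - actual_set)  # model wants these, file lacks them
    unexpected = sorted(actual_set - expected_set)  # file has these, model does not
    return missing, unexpected
```

In practice `expected` would come from `model.state_dict().keys()` and `actual` from the keys of the downloaded checkpoint.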
Mhm, I am not getting those errors. Maybe the old weights are still cached in your torch/hub? I had to delete those so it would reload the new ones after I uploaded them to huggingface. |
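Deleting stale cached weights can be done by hand, or with something like the sketch below. It assumes the weights were downloaded into torch.hub's default `checkpoints` directory; the helper name `remove_cached_checkpoints` and the `name_fragment` argument are hypothetical, introduced only for this example.

```python
import os
from typing import List, Optional


def remove_cached_checkpoints(
    name_fragment: str, cache_dir: Optional[str] = None
) -> List[str]:
    """Delete cached files whose name contains `name_fragment`; return removed paths."""
    if cache_dir is None:
        import torch  # imported lazily so the helper also works with an explicit dir

        # Assumption: downloads went to the default torch.hub checkpoint directory.
        cache_dir = os.path.join(torch.hub.get_dir(), "checkpoints")
    removed: List[str] = []
    if not os.path.isdir(cache_dir):
        return removed  # nothing cached yet
    for filename in sorted(os.listdir(cache_dir)):
        path = os.path.join(cache_dir, filename)
        if name_fragment in filename and os.path.isfile(path):
            os.remove(path)
            removed.append(path)
    return removed
```

After the stale files are removed, the next download should fetch the newly uploaded weights instead of reusing the old copy.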
Oh mine are prob cached, let me delete |
Yep, works now. Thanks! |
* implement isaacs solution
* simple test for function
* private function but failing tests
* Fix in_channels
* Fix model
* Test real weights
* Real weights have no final layer
* Style fixes
* expand test coverage of other trainers
* revert byol image_size

---------

Co-authored-by: Adam J. Stewart <[email protected]>
This PR closes #1049 by implementing Isaac's solution.