It seems the SeCo weights still contain the final layer from training, so they always issue a warning when loaded into a trainer. We should remove the last layer and make sure the only keys in the weights file are the ones needed to load the model.
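Removing the leftover layer amounts to filtering keys out of the state dict before saving it. A minimal sketch, assuming the head's parameters live under the `fc.` prefix (as in torchvision/timm ResNets); `strip_head` is a hypothetical helper, not an existing torchgeo API, and the toy values stand in for tensors:

```python
from collections.abc import Mapping


def strip_head(state_dict: Mapping, head_prefixes=("fc.",)) -> dict:
    """Return a copy of ``state_dict`` without keys under any of ``head_prefixes``.

    Works on any mapping, including a PyTorch ``state_dict`` (an OrderedDict
    of tensors); the remaining keys keep their original order.
    """
    return {
        key: value
        for key, value in state_dict.items()
        if not key.startswith(tuple(head_prefixes))
    }


# Toy example with placeholder integers instead of tensors.
weights = {"conv1.weight": 0, "layer1.0.conv1.weight": 1, "fc.weight": 2, "fc.bias": 3}
print(list(strip_head(weights)))  # ['conv1.weight', 'layer1.0.conv1.weight']
```

If the SeCo checkpoint stores its head under a different prefix, pass that prefix via `head_prefixes` instead.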
Steps to reproduce
Any of the following commands will reproduce the warning:
No other weights issue this warning. We could ignore it with `with pytest.warns():`, but this would cause all the other weights to fail the tests, since they don't issue a warning.
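Once every weight loads cleanly, the tests can require the absence of warnings instead of wrapping a single weight in `pytest.warns()`. A minimal sketch using only the standard `warnings` module; `assert_no_warnings`, `load_clean`, and `load_with_extra_keys` are hypothetical stand-ins, not existing test helpers (pytest can achieve a similar effect with `@pytest.mark.filterwarnings("error")`):

```python
import warnings


def assert_no_warnings(fn, *args, **kwargs):
    """Call ``fn`` and fail if it emits any warning; otherwise return its result."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning, even repeated ones
        result = fn(*args, **kwargs)
    if caught:
        messages = [str(w.message) for w in caught]
        raise AssertionError(f"unexpected warnings: {messages}")
    return result


def load_clean():  # hypothetical stand-in: weights whose keys match the model
    return "ok"


def load_with_extra_keys():  # hypothetical stand-in: weights with a leftover head
    warnings.warn("unexpected keys: ['fc.weight', 'fc.bias']")
    return "ok"


assert_no_warnings(load_clean)  # passes silently
try:
    assert_no_warnings(load_with_extra_keys)
except AssertionError as err:
    print(err)  # reports the recorded warning message
```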
The current SeCo weights are almost entirely incorrect: the only key that matches between our uploaded weights and the author's weights is `conv1.weight`.
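A quick way to quantify this kind of mismatch is to compare the two key sets directly. A sketch; `compare_keys` is a hypothetical helper, and the key lists below are made-up placeholders rather than the real checkpoint contents (with real checkpoints you would pass something like `state_dict.keys()`):

```python
def compare_keys(ours, theirs):
    """Summarize how two collections of state-dict keys line up."""
    ours, theirs = set(ours), set(theirs)
    return {
        "shared": sorted(ours & theirs),
        "only_ours": sorted(ours - theirs),
        "only_theirs": sorted(theirs - ours),
    }


# Made-up placeholder keys, for illustration only.
uploaded = ["conv1.weight", "encoder.1.weight", "encoder.4.0.conv1.weight"]
reference = ["conv1.weight", "bn1.weight", "layer1.0.conv1.weight"]
print(compare_keys(uploaded, reference)["shared"])  # ['conv1.weight']
```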
Here's how you can extract just the encoder weights and rename them into a format that timm/torchvision understands:
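```python
from copy import deepcopy

import timm
import torch
from lightning.pytorch.utilities.migration import pl_legacy_patch

from src.moco import MocoV2  # from the SeCo repo

# pl_legacy_patch helps current Lightning versions load checkpoints saved by older ones.
with pl_legacy_patch():
    backbone = MocoV2.load_from_checkpoint("../../seasonal-contrast/seco_resnet50_1m.ckpt")

# Keep only the query encoder, i.e. the backbone weights we want to publish.
model = deepcopy(backbone.encoder_q).eval()
state_dict_original = model.state_dict()

# Reference model whose key names timm/torchvision expect.
resnet50 = timm.create_model("resnet50")
key_correct_name_list = list(resnet50.state_dict().keys())

# Rename by position: this assumes both state dicts list parameters in the same order.
assert len(state_dict_original) <= len(key_correct_name_list)
state_dict_new = dict(zip(key_correct_name_list, state_dict_original.values()))

# Sanity check: load_state_dict raises on any shape mismatch; missing_keys lists
# reference keys that received no weight (ideally only the classifier head).
result = resnet50.load_state_dict(state_dict_new, strict=False)
print("missing keys:", result.missing_keys)

torch.save(state_dict_new, "seco_resnet50_1m.ckpt")
```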
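The script above renames keys purely by position, which is only safe if both state dicts enumerate their parameters in the same order. A stricter, framework-free variant of that step, as a sketch: compare shapes pair by pair and fail loudly on the first mismatch. `rename_by_position` is hypothetical, and the keys and shapes below are made-up placeholders:

```python
def rename_by_position(src_shapes, target_shapes):
    """Map source keys to target keys by position, checking shapes along the way.

    Both arguments are ordered mappings of key -> shape tuple, e.g. built with
    ``{k: tuple(v.shape) for k, v in model.state_dict().items()}``.
    Returns a dict src_key -> target_key. Trailing target keys with no source
    counterpart (such as a classifier head) are simply left unmapped.
    """
    if len(src_shapes) > len(target_shapes):
        raise ValueError(
            f"source has {len(src_shapes)} entries but target only {len(target_shapes)}"
        )
    mapping = {}
    for (src_key, src_shape), (tgt_key, tgt_shape) in zip(
        src_shapes.items(), target_shapes.items()
    ):
        if src_shape != tgt_shape:
            raise ValueError(f"{src_key} {src_shape} does not fit {tgt_key} {tgt_shape}")
        mapping[src_key] = tgt_key
    return mapping


# Made-up placeholder keys and shapes, for illustration only.
src = {"0.weight": (64, 3, 7, 7), "1.weight": (64,)}
tgt = {"conv1.weight": (64, 3, 7, 7), "bn1.weight": (64,), "fc.weight": (1000, 2048)}
print(rename_by_position(src, tgt))  # {'0.weight': 'conv1.weight', '1.weight': 'bn1.weight'}
```

Matching shapes at every position do not prove the order is right (two same-sized layers could still be swapped), but a misaligned mapping usually surfaces quickly as a shape error.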
Version
0.5.0.dev0 (35525b2)