Skip to content

Commit

Permalink
add transforms and update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
isaaccorley committed May 12, 2023
1 parent a2d9175 commit acb8be3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
1 change: 1 addition & 0 deletions docs/api/resnet_pretrained_weights.csv
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD
ResNet18_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
ResNet18_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seasonal-contrast>`__,`link <https://arxiv.org/abs/2103.16607>`__,87.27,93.14,,46.94
ResNet50_Weights.FMOW_RGB_GASSL, 3,`link <https://github.com/sustainlab-group/geography-aware-ssl>`__,`link <https://arxiv.org/abs/2011.09980>`__,,,,
ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,90.7,99.1,63.6,
ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,91.8,99.1,60.9,
Expand Down
37 changes: 24 additions & 13 deletions torchgeo/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@
data_keys=["image"],
)

# Normalization only available for RGB dataset, defined here:
# https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287 # noqa: E501
_mean = torch.tensor([0.485, 0.456, 0.406])
_std = torch.tensor([0.229, 0.224, 0.225])
_gassl_transforms = AugmentationSequential(
K.Resize(224),
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
K.Normalize(mean=_mean, std=_std),
data_keys=["image"],
)

# https://github.com/pytorch/vision/pull/6883
# https://github.com/pytorch/vision/pull/7107
# Can be removed once torchvision>=0.15 is required
Expand Down Expand Up @@ -105,6 +116,19 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
.. versionadded:: 0.4
"""

FMOW_RGB_GASSL = Weights(
url="https://huggingface.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/main/resnet50_fmow_rgb_gassl-da43d987.pth", # noqa: E501
transforms=_gassl_transforms,
meta={
"dataset": "fMoW Dataset",
"in_chans": 3,
"model": "resnet50",
"publication": "https://arxiv.org/abs/2011.09980",
"repo": "https://github.com/sustainlab-group/geography-aware-ssl",
"ssl_method": "gassl",
},
)

SENTINEL1_ALL_MOCO = Weights(
url="https://huggingface.co/torchgeo/resnet50_sentinel1_all_moco/resolve/main/resnet50_sentinel1_all_moco-906e4356.pth", # noqa: E501
transforms=_zhu_xlab_transforms,
Expand Down Expand Up @@ -170,19 +194,6 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
},
)

FMOW_RGB_GASSL = Weights(
url="https://huggingface.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/main/resnet50_fmow_rgb_gassl-44b4461b.pth", # noqa: E501
transforms=_seco_transforms,
meta={
"dataset": "fMoW Dataset",
"in_chans": 3,
"model": "resnet50",
"publication": "https://arxiv.org/abs/2011.09980",
"repo": "https://github.com/sustainlab-group/geography-aware-ssl",
"ssl_method": "gassl",
},
)


def resnet18(
weights: Optional[ResNet18_Weights] = None, *args: Any, **kwargs: Any
Expand Down

0 comments on commit acb8be3

Please sign in to comment.