Skip to content

Commit

Permalink
Add seco transforms and zhu normalization to pretrained weights (#1119)
Browse files Browse the repository at this point in the history
* add seco transforms and zhu normalization

* adapt links

* add additional comment zhu lab

* left from merge
  • Loading branch information
nilsleh authored Feb 20, 2023
1 parent 562778c commit 3d69703
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 5 deletions.
55 changes: 51 additions & 4 deletions torchgeo/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,63 @@

import kornia.augmentation as K
import timm
import torch.nn as nn
import torch
from timm.models import ResNet
from torchvision.models._api import Weights, WeightsEnum

from ..transforms import AugmentationSequential

__all__ = ["ResNet50_Weights", "ResNet18_Weights"]


# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
# Normalization either by 10K or channel-wise with band statistics
_zhu_xlab_transforms = AugmentationSequential(
K.Resize(256), K.CenterCrop(224), data_keys=["image"]
K.Resize(256),
K.CenterCrop(224),
K.Normalize(mean=0, std=10000),
data_keys=["image"],
)

# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/bigearthnet_dataset.py#L13 # noqa: E501
_seco_transforms = AugmentationSequential(
K.Resize(128),
K.Normalize(
mean=torch.Tensor(
[
340.76769064,
429.9430203,
614.21682446,
590.23569706,
950.68368468,
1792.46290469,
2075.46795189,
2218.94553375,
2266.46036911,
2246.0605464,
1594.42694882,
1009.32729131,
]
),
std=torch.Tensor(
[
554.81258967,
572.41639287,
582.87945694,
675.88746967,
729.89827633,
1096.01480586,
1273.45393088,
1365.45589904,
1356.13789355,
1302.3292881,
1079.19066363,
818.86747235,
]
),
),
data_keys=["image"],
)

# https://github.com/pytorch/vision/pull/6883
Expand Down Expand Up @@ -62,7 +109,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]

SENTINEL2_RGB_SECO = Weights(
url="https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/main/resnet18_sentinel2_rgb_seco-9976a9cb.pth", # noqa: E501
transforms=nn.Identity(),
transforms=_seco_transforms,
meta={
"dataset": "SeCo Dataset",
"in_chans": 3,
Expand Down Expand Up @@ -137,7 +184,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]

SENTINEL2_RGB_SECO = Weights(
url="https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/main/resnet50_sentinel2_rgb_seco-584035db.pth", # noqa: E501
transforms=nn.Identity(),
transforms=_seco_transforms,
meta={
"dataset": "SeCo Dataset",
"in_chans": 3,
Expand Down
8 changes: 7 additions & 1 deletion torchgeo/models/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,14 @@

__all__ = ["ViTSmall16_Weights"]

# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
# Normalization either by 10K or channel-wise with band statistics
_zhu_xlab_transforms = AugmentationSequential(
K.Resize(256), K.CenterCrop(224), data_keys=["image"]
K.Resize(256),
K.CenterCrop(224),
K.Normalize(mean=0, std=10000),
data_keys=["image"],
)

# https://github.com/pytorch/vision/pull/6883
Expand Down

0 comments on commit 3d69703

Please sign in to comment.