Skip to content

Commit

Permalink
Additional Satlas pretrained models (#1884)
Browse files Browse the repository at this point in the history
* Documentation, satellite-specific transform and weights for additional Satlas single-image rgb&multispectral Swin-v2 models. Tests pass.

* Address 3 of comments

* Address comments, fix readmydocs and isort, mypy still unhappy

* update

* Add bands to meta dicts

* Add comment about Satlas S2 RGB using TCI product

* linting

---------

Co-authored-by: Piper Wolters <[email protected]>
Co-authored-by: Piper Wolters <[email protected]>
  • Loading branch information
3 people authored Feb 21, 2024
1 parent 8defbe4 commit ad2ef7a
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 13 deletions.
1 change: 1 addition & 0 deletions docs/api/landsat_pretrained_weights.csv
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ ResNet50_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link <https://github.com/microsoft/
ResNet50_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link <https://github.com/microsoft/torchgeo>`__,`link <https://arxiv.org/abs/2306.09424>`__,"CC0-1.0",63.65,46.68,60.01,43.17
ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link <https://github.com/microsoft/torchgeo>`__,`link <https://arxiv.org/abs/2306.09424>`__,"CC0-1.0",66.81,50.16,64.17,47.24
ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link <https://github.com/microsoft/torchgeo>`__,`link <https://arxiv.org/abs/2306.09424>`__,"CC0-1.0",65.04,48.20,62.61,45.46
Swin_V2_B_Weights.LANDSAT_MS_SI_SATLAS,11,'link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY",,,,
2 changes: 1 addition & 1 deletion docs/api/naip_pretrained_weights.csv
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Weight,Channels,Source,Citation,License
Swin_V2_B_Weights.NAIP_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"Apache-2.0"
Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY"
1 change: 1 addition & 0 deletions docs/api/sentinel1_pretrained_weights.csv
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
Weight,Channels,Source,Citation,License
ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0"
Swin_V2_B_Weights.SENTINEL1_SI_SATLAS,2,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY"
3 changes: 2 additions & 1 deletion docs/api/sentinel2_pretrained_weights.csv
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO
ResNet50_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seasonal-contrast>`__,`link <https://arxiv.org/abs/2103.16607>`__,"Apache-2.0",87.81,,,
ViTSmall16_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",90.5,99.0,62.2,
ViTSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",89.9,98.6,61.6,
Swin_V2_B_Weights.SENTINEL2_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"Apache-2.0",,,,
Swin_V2_B_Weights.SENTINEL2_RGB_SI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY",,,,
Swin_V2_B_Weights.SENTINEL2_MS_SI_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY",,,,
87 changes: 79 additions & 8 deletions torchgeo/models/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,39 @@
import kornia.augmentation as K
import torch
import torchvision
from kornia.contrib import Lambda
from torchvision.models import SwinTransformer
from torchvision.models._api import Weights, WeightsEnum

from ..transforms import AugmentationSequential

__all__ = ["Swin_V2_B_Weights"]


# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42 # noqa: E501
# All Satlas imagery is uint8 and normalized to the range (0, 1) by dividing by 255
# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501
# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255. # noqa: E501
_satlas_transforms = AugmentationSequential(
K.CenterCrop(256),
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=["image"]
)

# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501
# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1). # noqa: E501
_std = torch.tensor(
[255.0, 255.0, 255.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0]
) # noqa: E501
_mean = torch.zeros_like(_std)
_sentinel2_ms_satlas_transforms = AugmentationSequential(
K.Normalize(mean=_mean, std=_std),
Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)),
data_keys=["image"],
)

# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1). # noqa: E501
_landsat_satlas_transforms = AugmentationSequential(
K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)),
Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)),
data_keys=["image"],
)

Expand All @@ -39,8 +59,8 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
.. versionadded:: 0.6
"""

NAIP_RGB_SATLAS = Weights(
url="https://huggingface.co/torchgeo/swin_v2_b_naip_rgb_satlas/resolve/main/swin_v2_b_naip_rgb_satlas-685f45bd.pth", # noqa: E501
NAIP_RGB_SI_SATLAS = Weights(
url="https://huggingface.co/allenai/satlas-pretrain/resolve/main/aerial_swinb_si.pth", # noqa: E501
transforms=_satlas_transforms,
meta={
"dataset": "Satlas",
Expand All @@ -51,8 +71,8 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
},
)

SENTINEL2_RGB_SATLAS = Weights(
url="https://huggingface.co/torchgeo/swin_v2_b_sentinel2_rgb_satlas/resolve/main/swin_v2_b_sentinel2_rgb_satlas-51471041.pth", # noqa: E501
SENTINEL2_RGB_SI_SATLAS = Weights(
url="https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swinb_si_rgb.pth", # noqa: E501
transforms=_satlas_transforms,
meta={
"dataset": "Satlas",
Expand All @@ -63,6 +83,57 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
},
)

SENTINEL2_MS_SI_SATLAS = Weights(
url="https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swinb_si_ms.pth", # noqa: E501
transforms=_sentinel2_ms_satlas_transforms,
meta={
"dataset": "Satlas",
"in_chans": 9,
"model": "swin_v2_b",
"publication": "https://arxiv.org/abs/2211.15660",
"repo": "https://github.com/allenai/satlas",
"bands": ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B11", "B12"],
},
)

SENTINEL1_SI_SATLAS = Weights(
url="https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel1_swinb_si.pth", # noqa: E501
transforms=_satlas_transforms,
meta={
"dataset": "Satlas",
"in_chans": 2,
"model": "swin_v2_b",
"publication": "https://arxiv.org/abs/2211.15660",
"repo": "https://github.com/allenai/satlas",
"bands": ["VH", "VV"],
},
)

LANDSAT_SI_SATLAS = Weights(
url="https://huggingface.co/allenai/satlas-pretrain/resolve/main/landsat_swinb_si.pth", # noqa: E501
transforms=_landsat_satlas_transforms,
meta={
"dataset": "Satlas",
"in_chans": 11,
"model": "swin_v2_b",
"publication": "https://arxiv.org/abs/2211.15660",
"repo": "https://github.com/allenai/satlas",
"bands": [
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B09",
"B10",
"B11",
], # noqa: E501
},
)


def swin_v2_b(
weights: Optional[Swin_V2_B_Weights] = None, *args: Any, **kwargs: Any
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import kornia.augmentation as K
import torch
from einops import rearrange
from kornia.contrib import extract_tensor_patches
from kornia.contrib import Lambda, extract_tensor_patches
from kornia.geometry import crop_by_indices
from kornia.geometry.boxes import Boxes
from torch import Tensor
Expand All @@ -25,7 +25,7 @@ class AugmentationSequential(Module):

def __init__(
self,
*args: Union[K.base._AugmentationBase, K.ImageSequential],
*args: Union[K.base._AugmentationBase, K.ImageSequential, Lambda],
data_keys: list[str],
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(
else:
keys.append(key)

self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs)
self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) # type: ignore[arg-type] # noqa: E501

def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Perform augmentations and update data dict.
Expand Down

0 comments on commit ad2ef7a

Please sign in to comment.