Multi-label Guided Soft Contrastive Learning for Efficient Earth Observation Pretraining
SoftCon explores the potential of two free resources beyond pure satellite imagery for multispectral and SAR pretraining: land cover land use products (e.g. Dynamic World) to guide contrastive learning, and strong vision foundation models (e.g. DINO and DINOv2) for continual pretraining. Without prohibitive compute, SoftCon is able to produce small models with high-performance representations that beat many SOTA large models on a wide range of downstream tasks.
Model | Modality | BigEarthNet-10% linear | EuroSAT linear | download |
---|---|---|---|---|
RN50 | MS | 84.8 | 98.6 | backbone |
ViT-S/14 | MS | 85.0 | 97.1 | backbone |
ViT-B/14 | MS | 86.8* | 98.0* | backbone |
RN50 | SAR | 78.9 | 87.1 | backbone |
ViT-S/14 | SAR | 80.3 | 87.1 | backbone |
ViT-B/14 | SAR | 82.5* | 89.1* | backbone |
*: Linear head with 4 output layers, referring to DINOv2 appendix B.3.
Clone this repository:
git clone https://github.com/zhu-xlab/softcon
cd softcon
Download the pretrained weights and put them in the ./pretrained
directory.
wget https://huggingface.co/wangyi111/softcon/resolve/main/B13_rn50_softcon.pth -P ./pretrained
...
Install PyTorch and torchvision following the official instructions, e.g.,
pip3 install torch torchvision # CUDA 12.1
Open a Python interpreter and run:
import torch
from torchvision.models import resnet50
# create RN50 model for multispectral
model_r50 = resnet50(pretrained=False)
model_r50.conv1 = torch.nn.Conv2d(13, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model_r50.fc = torch.nn.Identity()
# load pretrained weights
ckpt_r50 = torch.load('./pretrained/B13_rn50_softcon.pth')
model_r50.load_state_dict(ckpt_r50)
# encode one image
model_r50.eval()
img = torch.randn(1, 13, 224, 224)
with torch.no_grad():
out = model_r50(img)
print(out.shape) # torch.Size([1, 2048])
Similarly, for ViT backcbones run:
import torch
from models.dinov2 import vision_transformer as dinov2_vits
# create ViT-S/14 model for SAR
model_vits14 = dinov2_vits.__dict__['vit_small'](
img_size=224,
patch_size=14,
in_chans=2,
block_chunks=0,
init_values=1e-5,
num_register_tokens=0,
)
# load pretrained weights
ckpt_vits14 = torch.load('./pretrained/B2_vits14_softcon.pth')
model_vits14.load_state_dict(ckpt_vits14)
# encode one image
model_vits14.eval()
img = torch.randn(1, 2, 224, 224)
with torch.no_grad():
out = model_vits14(img)
print(out.shape) # torch.Size([1, 384])
It may depend on the downstream task for the most suitable data preprocessing. As a general case, we recommend using the per-channel mean/std of the SSL4EO-S12 dataset (our pretraining dataset) or the target dataset for input normalization. Our normalization function is as follows:
def normalize(img, mean, std):
min_value = mean - 2 * std
max_value = mean + 2 * std
img = (img - min_value) / (max_value - min_value) * 255.0
img = np.clip(img, 0, 255).astype(np.uint8)
return img
SSL4EO-S12-ML dataset is a large-scale multi-label land cover land use classification dataset derived from SSL4EO-S12 images and Dynamic World segmentation maps. It consists of 780,371 multispectral Sentinel-2 images with size 264×264, divided into 247,377 non-overlapping scenes each with 1-4 multi-seasonal patches. Each image has a multi-label annotation from one or more categories in 9 land cover land use classes.
We provide labels corresponding to SSl4EO-S12 image IDs as a json file in HuggingFace. The structure is shown as the following example:
...
{"0000002": # SSL4EO-S12 location ID
{"20200718T102559_20200718T103605_T31TFJ": [], # season ID & multi-label (empty means no label for this scene)
"20201011T103031_20201011T103339_T31TFJ": [],
"20210117T104259_20210117T104300_T31TFJ": ["0", "1", "2", "4", "5", "6", "8"],
"20210402T104021_20210402T104258_T31TFJ": ["1", "2", "4", "5", "6", "8"]
},
"0000003":
{"20200403T100549_20200403T101937_T31PDQ": ["4", "5", "7"],
"20200702T100559_20200702T101831_T31PDQ": ["4", "5", "7"],
"20200930T100729_20200930T102207_T31PEQ": ["0", "1", "4", "5", "6", "7"],
"20210103T101411_20210103T102025_T31PDQ": ["1", "4", "5", "7"]
},
...
}
- Add instructions for SSL4EO-S12-ML data loading
- Add instructions for pretraining and transfer learning
- Add models to torchgeo
@misc{wang2024multilabel,
title={Multi-Label Guided Soft Contrastive Learning for Efficient Earth Observation Pretraining},
author={Wang, Yi and Albrecht, Conrad M and Zhu, Xiao Xiang},
journal={arXiv preprint arXiv:2405.20462},
year={2024}
}