Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MoCo trainer #1285

Merged
merged 10 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/required.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ setuptools==67.7.0
einops==0.6.1
fiona==1.9.3
kornia==0.6.12
lightning==2.0.2
lightly==1.4.5
lightning==2.0.2
matplotlib==3.7.1
numpy==1.24.3
pillow==9.5.0
Expand Down
20 changes: 20 additions & 0 deletions tests/conf/chesapeake_cvpr_prior_moco.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module:
_target_: torchgeo.trainers.MoCoTask
model: "resnet18"
in_channels: 4

datamodule:
_target_: torchgeo.datamodules.ChesapeakeCVPRDataModule
root: "tests/data/chesapeake/cvpr"
download: false
train_splits:
- "de-test"
val_splits:
- "de-test"
test_splits:
- "de-test"
batch_size: 2
patch_size: 64
num_workers: 0
class_set: 5
use_prior_labels: True
16 changes: 16 additions & 0 deletions tests/conf/seco_moco_1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module:
_target_: torchgeo.trainers.MoCoTask
model: "resnet18"
in_channels: 3
version: 1
weight_decay: 1e-4
temperature: 0.07
memory_bank_size: 10
moco_momentum: 0.999

datamodule:
_target_: torchgeo.datamodules.SeasonalContrastS2DataModule
root: "tests/data/seco"
seasons: 1
batch_size: 2
num_workers: 0
19 changes: 19 additions & 0 deletions tests/conf/seco_moco_2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module:
_target_: torchgeo.trainers.MoCoTask
model: "resnet18"
in_channels: 3
version: 2
layers: 2
hidden_dim: 10
output_dim: 5
weight_decay: 1e-4
temperature: 0.07
memory_bank_size: 10
moco_momentum: 0.999

datamodule:
_target_: torchgeo.datamodules.SeasonalContrastS2DataModule
root: "tests/data/seco"
seasons: 2
batch_size: 2
num_workers: 0
16 changes: 16 additions & 0 deletions tests/conf/ssl4eo_s12_moco_1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module:
_target_: torchgeo.trainers.MoCoTask
model: "resnet18"
in_channels: 13
version: 1
weight_decay: 1e-4
temperature: 0.07
memory_bank_size: 10
moco_momentum: 0.999

datamodule:
_target_: torchgeo.datamodules.SSL4EOS12DataModule
root: "tests/data/ssl4eo/s12"
seasons: 1
batch_size: 2
num_workers: 0
19 changes: 19 additions & 0 deletions tests/conf/ssl4eo_s12_moco_2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module:
_target_: torchgeo.trainers.MoCoTask
model: "resnet18"
in_channels: 13
version: 2
layers: 2
hidden_dim: 10
output_dim: 5
weight_decay: 1e-4
temperature: 0.07
memory_bank_size: 10
moco_momentum: 0.999

datamodule:
_target_: torchgeo.datamodules.SSL4EOS12DataModule
root: "tests/data/ssl4eo/s12"
seasons: 2
batch_size: 2
num_workers: 0
156 changes: 156 additions & 0 deletions tests/trainers/test_moco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from pathlib import Path
from typing import Any

import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import OmegaConf
from torch.nn import Module
from torchvision.models._api import WeightsEnum

from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
from torchgeo.models import get_model_weights, list_models
from torchgeo.trainers import MoCoTask

from .test_classification import ClassificationTestModel


def create_model(*args: Any, **kwargs: Any) -> Module:
return ClassificationTestModel(**kwargs)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestMoCoTask:
@pytest.mark.parametrize(
"name",
[
"chesapeake_cvpr_prior_moco",
"seco_moco_1",
"seco_moco_2",
"ssl4eo_s12_moco_1",
"ssl4eo_s12_moco_2",
],
)
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))

if name.startswith("seco"):
monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2)

if name.startswith("ssl4eo_s12"):
monkeypatch.setattr(SSL4EOS12, "__len__", lambda self: 2)

# Instantiate datamodule
datamodule = instantiate(conf.datamodule)

# Instantiate model
monkeypatch.setattr(timm, "create_model", create_model)
model = instantiate(conf.module)

# Instantiate trainer
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)

def test_version_warnings(self) -> None:
with pytest.warns(UserWarning, match="MoCo v1 uses a memory bank"):
MoCoTask(version=1, layers=2, memory_bank_size=0)
with pytest.warns(UserWarning, match="MoCo v2 only uses 2 layers"):
MoCoTask(version=2, layers=3, memory_bank_size=10)
with pytest.warns(UserWarning, match="MoCo v2 uses a memory bank"):
MoCoTask(version=2, layers=2, memory_bank_size=0)
with pytest.warns(UserWarning, match="MoCo v3 uses 3 layers"):
MoCoTask(version=3, layers=2, memory_bank_size=0)
with pytest.warns(UserWarning, match="MoCo v3 does not use a memory bank"):
MoCoTask(version=3, layers=3, memory_bank_size=10)

@pytest.fixture(
params=[
weights for model in list_models() for weights in get_model_weights(model)
]
)
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
) -> WeightsEnum:
path = tmp_path / f"{weights}.pth"
model = timm.create_model(
weights.meta["model"], in_chans=weights.meta["in_chans"]
)
torch.save(model.state_dict(), path)
try:
monkeypatch.setattr(weights.value, "url", str(path))
except AttributeError:
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights

def test_weight_file(self, checkpoint: str) -> None:
model_kwargs: dict[str, Any] = {"model": "resnet18", "weights": checkpoint}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
MoCoTask(**model_kwargs)

def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
model_kwargs: dict[str, Any] = {
"model": mocked_weights.meta["model"],
"weights": mocked_weights,
"in_channels": mocked_weights.meta["in_chans"],
}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
MoCoTask(**model_kwargs)

def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
model_kwargs: dict[str, Any] = {
"model": mocked_weights.meta["model"],
"weights": str(mocked_weights),
"in_channels": mocked_weights.meta["in_chans"],
}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
MoCoTask(**model_kwargs)

@pytest.mark.slow
def test_weight_enum_download(self, weights: WeightsEnum) -> None:
model_kwargs: dict[str, Any] = {
"model": weights.meta["model"],
"weights": weights,
"in_channels": weights.meta["in_chans"],
}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
MoCoTask(**model_kwargs)

@pytest.mark.slow
def test_weight_str_download(self, weights: WeightsEnum) -> None:
model_kwargs: dict[str, Any] = {
"model": weights.meta["model"],
"weights": str(weights),
"in_channels": weights.meta["in_chans"],
}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
MoCoTask(**model_kwargs)
2 changes: 2 additions & 0 deletions torchgeo/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
from .byol import BYOLTask
from .classification import ClassificationTask, MultiLabelClassificationTask
from .detection import ObjectDetectionTask
from .moco import MoCoTask
from .regression import PixelwiseRegressionTask, RegressionTask
from .segmentation import SemanticSegmentationTask
from .simclr import SimCLRTask

__all__ = (
"BYOLTask",
"ClassificationTask",
"MoCoTask",
"MultiLabelClassificationTask",
"ObjectDetectionTask",
"PixelwiseRegressionTask",
Expand Down
Loading