Skip to content

Commit

Permalink
Add RandomGrayscale
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed May 3, 2023
1 parent 663e061 commit 7c56c50
Showing 1 changed file with 84 additions and 54 deletions.
138 changes: 84 additions & 54 deletions torchgeo/trainers/moco.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
)
from torchvision.models._api import WeightsEnum

import torchgeo.transforms as T

from ..models import get_weight
from . import utils

Expand All @@ -36,58 +38,75 @@
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler


# Augmentation hyperparameters follow the upstream MoCo reference scripts:
# https://github.com/facebookresearch/moco/blob/main/main_moco.py#L326
# https://github.com/facebookresearch/moco-v3/blob/main/main_moco.py#L261
# Side length of the square patch produced by RandomResizedCrop.
SIZE = 224
# Roughly one tenth of SIZE, forced to an odd integer (23 when SIZE is 224);
# intended as a blur kernel size.
KS = SIZE // 10 // 2 * 2 + 1

# Same as InstDict: https://arxiv.org/abs/1805.01978
# The chained assignment binds both names to the same pipeline object.
AUG1_V1 = AUG2_V1 = K.AugmentationSequential(
K.RandomResizedCrop(size=(SIZE, SIZE), scale=(0.2, 1)),
# Not appropriate for multispectral imagery, seasonal contrast used instead
# K.RandomGrayscale(p=0.2),
# K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4, p=1)
K.RandomHorizontalFlip(),
K.RandomVerticalFlip(), # added
data_keys=["input"],
)
# Module-level pipelines without grayscale conversion. They are kept so that
# existing references to these names keep resolving; moco_augmentations()
# below builds the equivalent pipelines with RandomGrayscale enabled.

# Similar to SimCLR: https://arxiv.org/abs/2002.05709
AUG1_V2 = AUG2_V2 = K.AugmentationSequential(
    K.RandomResizedCrop(size=(SIZE, SIZE), scale=(0.2, 1)),
    # Not appropriate for multispectral imagery, seasonal contrast used instead
    # K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8)
    # K.RandomGrayscale(p=0.2),
    K.RandomGaussianBlur(kernel_size=(KS, KS), sigma=(0.1, 2), p=0.5),
    K.RandomHorizontalFlip(),
    K.RandomVerticalFlip(),  # added
    data_keys=["input"],
)

# Same as BYOL: https://arxiv.org/abs/2006.07733
AUG1_V3 = K.AugmentationSequential(
    K.RandomResizedCrop(size=(SIZE, SIZE), scale=(0.08, 1)),
    # Not appropriate for multispectral imagery,
    # seasonal contrast used instead
    # K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8)
    # K.RandomGrayscale(p=0.2),
    K.RandomGaussianBlur(kernel_size=(KS, KS), sigma=(0.1, 2), p=1),
    K.RandomHorizontalFlip(),
    K.RandomVerticalFlip(),  # added
    data_keys=["input"],
)
AUG2_V3 = K.AugmentationSequential(
    K.RandomResizedCrop(size=(SIZE, SIZE), scale=(0.08, 1)),
    # Not appropriate for multispectral imagery,
    # seasonal contrast used instead
    # K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8)
    # K.RandomGrayscale(p=0.2),
    K.RandomGaussianBlur(kernel_size=(KS, KS), sigma=(0.1, 2), p=0.1),
    K.RandomSolarize(p=0.2),
    K.RandomHorizontalFlip(),
    K.RandomVerticalFlip(),  # added
    data_keys=["input"],
)


def moco_augmentations(
    version: int, size: int, weights: Tensor
) -> tuple[nn.Module, nn.Module]:
    """Data augmentations used by MoCo.

    Args:
        version: Version of MoCo. 1 and 2 select their dedicated pipelines;
            any other value selects the v3 pipelines.
        size: Size of patch to crop.
        weights: Weight vector for grayscale computation.

    Returns:
        Data augmentation pipelines for the first and second branch. For
        versions 1 and 2 both entries are the same module instance.
    """
    # https://github.com/facebookresearch/moco/blob/main/main_moco.py#L326
    # https://github.com/facebookresearch/moco-v3/blob/main/main_moco.py#L261
    # Blur kernel: roughly one tenth of the patch size, forced to be odd.
    ks = size // 10 // 2 * 2 + 1
    if version == 1:
        # Same as InstDict: https://arxiv.org/abs/1805.01978
        # One pipeline serves both branches. It must be bound with a chained
        # assignment: tuple-unpacking the container would iterate over its
        # child transforms instead of yielding two pipelines.
        aug1 = aug2 = K.AugmentationSequential(
            K.RandomResizedCrop(size=(size, size), scale=(0.2, 1)),
            T.RandomGrayscale(weights=weights, p=0.2),
            # Not appropriate for multispectral imagery, seasonal contrast used instead
            # K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4, p=1)
            K.RandomHorizontalFlip(),
            K.RandomVerticalFlip(),  # added
            data_keys=["input"],
        )
    elif version == 2:
        # Similar to SimCLR: https://arxiv.org/abs/2002.05709
        # Shared between both branches, see the note on version 1.
        aug1 = aug2 = K.AugmentationSequential(
            K.RandomResizedCrop(size=(size, size), scale=(0.2, 1)),
            # Not appropriate for multispectral imagery, seasonal contrast used instead
            # K.ColorJitter(
            #     brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8
            # )
            T.RandomGrayscale(weights=weights, p=0.2),
            K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2), p=0.5),
            K.RandomHorizontalFlip(),
            K.RandomVerticalFlip(),  # added
            data_keys=["input"],
        )
    else:
        # Same as BYOL: https://arxiv.org/abs/2006.07733
        # The two branches are asymmetric: the first always blurs, the second
        # blurs rarely and may additionally solarize.
        aug1 = K.AugmentationSequential(
            K.RandomResizedCrop(size=(size, size), scale=(0.08, 1)),
            # Not appropriate for multispectral imagery, seasonal contrast used instead
            # K.ColorJitter(
            #     brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
            # )
            T.RandomGrayscale(weights=weights, p=0.2),
            K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2), p=1),
            K.RandomHorizontalFlip(),
            K.RandomVerticalFlip(),  # added
            data_keys=["input"],
        )
        aug2 = K.AugmentationSequential(
            K.RandomResizedCrop(size=(size, size), scale=(0.08, 1)),
            # Not appropriate for multispectral imagery, seasonal contrast used instead
            # K.ColorJitter(
            #     brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
            # )
            T.RandomGrayscale(weights=weights, p=0.2),
            K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2), p=0.1),
            K.RandomSolarize(p=0.2),
            K.RandomHorizontalFlip(),
            K.RandomVerticalFlip(),  # added
            data_keys=["input"],
        )
    return aug1, aug2


class MoCoTask(LightningModule): # type: ignore[misc]
Expand Down Expand Up @@ -124,8 +143,10 @@ def __init__(
memory_bank_size: int = 0,
moco_momentum: float = 0.99,
gather_distributed: bool = False,
augmentation1: nn.Module = AUG1_V3,
augmentation2: nn.Module = AUG2_V3,
size: int = 224,
grayscale_weights: Optional[Tensor] = None,
augmentation1: Optional[nn.Module] = None,
augmentation2: Optional[nn.Module] = None,
) -> None:
"""Initialize a new MoCoTask instance.
Expand All @@ -152,8 +173,14 @@ def __init__(
(0.999 for v1/2, 0.99 for v3)
gather_distributed: Gather negatives from all GPUs during distributed
training (ignored if memory_bank_size > 0).
size: Size of patch to crop.
grayscale_weights: Weight vector for grayscale computation, see
:class:`~torchgeo.transforms.RandomGrayscale`. Only used when
``augmentation1`` or ``augmentation2`` is None. Defaults to average of
all bands.
augmentation1: Data augmentation for 1st branch.
Defaults to MoCo augmentation.
augmentation2: Data augmentation for 2nd branch.
Defaults to MoCo augmentation.
Raises:
AssertionError: If an invalid version of MoCo is requested.
Expand All @@ -180,8 +207,11 @@ def __init__(
warnings.warn("MoCo v3 does not use a memory bank")

self.save_hyperparameters(ignore=["augmentation1", "augmentation2"])
self.augmentation1 = augmentation1
self.augmentation2 = augmentation2

grayscale_weights = grayscale_weights or torch.ones(in_channels)
aug1, aug2 = moco_augmentations(version, size, grayscale_weights)
self.augmentation1 = augmentation1 or aug1
self.augmentation2 = augmentation2 or aug2

# Create backbone
self.backbone = timm.create_model(
Expand Down

0 comments on commit 7c56c50

Please sign in to comment.