From 7c56c50f76487ad33153627ebec0da93fa3ab735 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 3 May 2023 16:58:53 -0500 Subject: [PATCH] Add RandomGrayscale --- torchgeo/trainers/moco.py | 138 +++++++++++++++++++++++--------------- 1 file changed, 84 insertions(+), 54 deletions(-) diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index d76c24aa38b..644b9e5c4c4 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -27,6 +27,8 @@ ) from torchvision.models._api import WeightsEnum +import torchgeo.transforms as T + from ..models import get_weight from . import utils @@ -36,58 +38,75 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -# https://github.com/facebookresearch/moco/blob/main/main_moco.py#L326 -# https://github.com/facebookresearch/moco-v3/blob/main/main_moco.py#L261 -SIZE = 224 -KS = SIZE // 10 // 2 * 2 + 1 - -# Same as InstDict: https://arxiv.org/abs/1805.01978 -AUG1_V1 = AUG2_V1 = K.AugmentationSequential( - K.RandomResizedCrop(size=(SIZE, SIZE), scale=(0.2, 1)), - # Not appropriate for multispectral imagery, seasonal contrast used instead - # K.RandomGrayscale(p=0.2), - # K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4, p=1) - K.RandomHorizontalFlip(), - K.RandomVerticalFlip(), # added - data_keys=["input"], -) +def moco_augmentations( + version: int, size: int, weights: Tensor +) -> tuple[nn.Module, nn.Module]: + """Data augmentations used by MoCo. -# Similar to SimCLR: https://arxiv.org/abs/2002.05709 -AUG1_V2 = AUG2_V2 = K.AugmentationSequential( - K.RandomResizedCrop(size=(SIZE, SIZE), scale=(0.2, 1)), - # Not appropriate for multispectral imagery, seasonal contrast used instead - # K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8) - # K.RandomGrayscale(p=0.2), - K.RandomGaussianBlur(kernel_size=(KS, KS), sigma=(0.1, 2), p=0.5), - K.RandomHorizontalFlip(), - K.RandomVerticalFlip(), # added - data_keys=["input"], -) + Args: + version: Version of MoCo. + size: Size of patch to crop. + weights: Weight vector for grayscale computation. -# Same as BYOL: https://arxiv.org/abs/2006.07733 -AUG1_V3 = K.AugmentationSequential( - K.RandomResizedCrop(size=(SIZE, SIZE), scale=(0.08, 1)), - # Not appropriate for multispectral imagery, - # seasonal contrast used instead - # K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8) - # K.RandomGrayscale(p=0.2), - K.RandomGaussianBlur(kernel_size=(KS, KS), sigma=(0.1, 2), p=1), - K.RandomHorizontalFlip(), - K.RandomVerticalFlip(), # added - data_keys=["input"], -) -AUG2_V3 = K.AugmentationSequential( - K.RandomResizedCrop(size=(SIZE, SIZE), scale=(0.08, 1)), - # Not appropriate for multispectral imagery, - # seasonal contrast used instead - # K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8) - # K.RandomGrayscale(p=0.2), - K.RandomGaussianBlur(kernel_size=(KS, KS), sigma=(0.1, 2), p=0.1), - K.RandomSolarize(p=0.2), - K.RandomHorizontalFlip(), - K.RandomVerticalFlip(), # added - data_keys=["input"], -) + Returns: + Data augmentation pipelines. + """ + # https://github.com/facebookresearch/moco/blob/main/main_moco.py#L326 + # https://github.com/facebookresearch/moco-v3/blob/main/main_moco.py#L261 + ks = size // 10 // 2 * 2 + 1 + if version == 1: + # Same as InstDict: https://arxiv.org/abs/1805.01978 + aug1, aug2 = K.AugmentationSequential( + K.RandomResizedCrop(size=(size, size), scale=(0.2, 1)), + T.RandomGrayscale(weights=weights, p=0.2), + # Not appropriate for multispectral imagery, seasonal contrast used instead + # K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4, p=1) + K.RandomHorizontalFlip(), + K.RandomVerticalFlip(), # added + data_keys=["input"], + ) + elif version == 2: + # Similar to SimCLR: https://arxiv.org/abs/2002.05709 + aug1, aug2 = K.AugmentationSequential( + K.RandomResizedCrop(size=(size, size), scale=(0.2, 1)), + # Not appropriate for multispectral imagery, seasonal contrast used instead + # K.ColorJitter( + # brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8 + # ) + T.RandomGrayscale(weights=weights, p=0.2), + K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2), p=0.5), + K.RandomHorizontalFlip(), + K.RandomVerticalFlip(), # added + data_keys=["input"], + ) + else: + # Same as BYOL: https://arxiv.org/abs/2006.07733 + aug1 = K.AugmentationSequential( + K.RandomResizedCrop(size=(size, size), scale=(0.08, 1)), + # Not appropriate for multispectral imagery, seasonal contrast used instead + # K.ColorJitter( + # brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8 + # ) + T.RandomGrayscale(weights=weights, p=0.2), + K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2), p=1), + K.RandomHorizontalFlip(), + K.RandomVerticalFlip(), # added + data_keys=["input"], + ) + aug2 = K.AugmentationSequential( + K.RandomResizedCrop(size=(size, size), scale=(0.08, 1)), + # Not appropriate for multispectral imagery, seasonal contrast used instead + # K.ColorJitter( + # brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8 + # ) + T.RandomGrayscale(weights=weights, p=0.2), + K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2), p=0.1), + K.RandomSolarize(p=0.2), + K.RandomHorizontalFlip(), + K.RandomVerticalFlip(), # added + data_keys=["input"], + ) + return aug1, aug2 class MoCoTask(LightningModule): # type: ignore[misc] @@ -124,8 +143,10 @@ def __init__( memory_bank_size: int = 0, moco_momentum: float = 0.99, gather_distributed: bool = False, - augmentation1: nn.Module = AUG1_V3, - augmentation2: nn.Module = AUG2_V3, + size: int = 224, + grayscale_weights: Optional[Tensor] = None, + augmentation1: Optional[nn.Module] = None, + augmentation2: Optional[nn.Module] = None, ) -> None: """Initialize a new MoCoTask instance. @@ -152,8 +173,14 @@ def __init__( (0.999 for v1/2, 0.99 for v3) gather_distributed: Gather negatives from all GPUs during distributed training (ignored if memory_bank_size > 0). + size: Size of patch to crop. + grayscale_weights: Weight vector for grayscale computation, see + :class:`~torchgeo.transforms.RandomGrayscale`. Only used when + ``augmentations=None``. Defaults to average of all bands. augmentation1: Data augmentation for 1st branch. + Defaults to MoCo augmentation. augmentation2: Data augmentation for 2nd branch. + Defaults to MoCo augmentation. Raises: AssertionError: If an invalid version of MoCo is requested. @@ -180,8 +207,11 @@ def __init__( warnings.warn("MoCo v3 does not use a memory bank") self.save_hyperparameters(ignore=["augmentation1", "augmentation2"]) - self.augmentation1 = augmentation1 - self.augmentation2 = augmentation2 + + grayscale_weights = grayscale_weights or torch.ones(in_channels) + aug1, aug2 = moco_augmentations(version, size, grayscale_weights) + self.augmentation1 = augmentation1 or aug1 + self.augmentation2 = augmentation2 or aug2 # Create backbone self.backbone = timm.create_model(