diff --git a/environment.yml b/environment.yml index 1290e0dd2e2..de47c5cc647 100644 --- a/environment.yml +++ b/environment.yml @@ -24,7 +24,7 @@ dependencies: - hydra-core>=1 - ipywidgets>=7 - isort[colors]>=5.8 - - kornia>=0.6.5 + - kornia>=0.6.9 - laspy>=2 - lightly>=1.4.4 - lightning>=1.8 diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index d74ce8e414e..84375103261 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -4,7 +4,7 @@ setuptools==42.0.0 # install einops==0.3.0 fiona==1.8.19 -kornia==0.6.5 +kornia==0.6.9 lightly==1.4.4 lightning==1.8.0 matplotlib==3.3.3 diff --git a/setup.cfg b/setup.cfg index 4e93bcd906d..c6bd83251b0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,8 +28,8 @@ install_requires = # fiona 1.8.19+ required to fix erroneous warning # https://github.com/Toblerity/Fiona/issues/986 fiona>=1.8.19,<2 - # kornia 0.6.5+ required due to change in kornia.augmentation API - kornia>=0.6.5,<0.7 + # kornia 0.6.9+ required for kornia.augmentation.RandomBrightness + kornia>=0.6.9,<0.7 # lightly 1.4.4+ required for MoCo v3 support lightly>=1.4.4 # lightning 1.8+ is first release diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index 09431a0b11c..2e1c9b52c5a 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -62,6 +62,8 @@ def moco_augmentations( T.RandomGrayscale(weights=weights, p=0.2), # Not appropriate for multispectral imagery, seasonal contrast used instead # K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4, p=1) + K.RandomBrightness(brightness=(0.6, 1.4), p=1.0), + K.RandomContrast(contrast=(0.6, 1.4), p=1.0), K.RandomHorizontalFlip(), K.RandomVerticalFlip(), # added data_keys=["input"], @@ -74,6 +76,8 @@ def moco_augmentations( # K.ColorJitter( # brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8 # ) + K.RandomBrightness(brightness=(0.6, 1.4), p=0.8), + K.RandomContrast(contrast=(0.6, 1.4), p=0.8), T.RandomGrayscale(weights=weights, p=0.2), K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2), p=0.5), K.RandomHorizontalFlip(), @@ -88,6 +92,8 @@ def moco_augmentations( # K.ColorJitter( # brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8 # ) + K.RandomBrightness(brightness=(0.6, 1.4), p=0.8), + K.RandomContrast(contrast=(0.6, 1.4), p=0.8), T.RandomGrayscale(weights=weights, p=0.2), K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2), p=1), K.RandomHorizontalFlip(), @@ -100,6 +106,8 @@ def moco_augmentations( # K.ColorJitter( # brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8 # ) + K.RandomBrightness(brightness=(0.6, 1.4), p=0.8), + K.RandomContrast(contrast=(0.6, 1.4), p=0.8), T.RandomGrayscale(weights=weights, p=0.2), K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2), p=0.1), K.RandomSolarize(p=0.2), diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index 2d29944fd46..b09b0e591cf 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -49,6 +49,8 @@ def simclr_augmentations(size: int, weights: Tensor) -> nn.Module: K.RandomVerticalFlip(), # added # Not appropriate for multispectral imagery, seasonal contrast used instead # K.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8) + K.RandomBrightness(brightness=(0.2, 1.8), p=0.8), + K.RandomContrast(contrast=(0.2, 1.8), p=0.8), T.RandomGrayscale(weights=weights, p=0.2), K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2)), data_keys=["input"], diff --git a/torchgeo/transforms/color.py b/torchgeo/transforms/color.py index 199b1279c72..5459fc2f854 100644 --- a/torchgeo/transforms/color.py +++ b/torchgeo/transforms/color.py @@ -70,7 +70,7 @@ def apply_transform( Returns: The augmented input. """ - weights = flags["weights"][..., :, None, None] + weights = flags["weights"][..., :, None, None].to(input.device) out = input * weights out = out.sum(dim=-3) out = out.unsqueeze(-3).expand(input.shape)