From 4d2e1cd188640162e357ed34587e0d7d05179610 Mon Sep 17 00:00:00 2001 From: Dylan Stewart <94502285+dylanrstewart@users.noreply.github.com> Date: Fri, 10 Nov 2023 08:33:39 -0600 Subject: [PATCH] Trainers: skip weights and augmentations when saving hparams (#1670) * Update base.py to fix for custom augmentations * Allow subclasses to ignore specific arguments * Fix typing * Save to self.weights * pyupgrade * Add test * Save weights --------- Co-authored-by: Adam J. Stewart --- tests/conf/ssl4eo_l_moco_1.yaml | 6 ++++++ torchgeo/trainers/base.py | 13 +++++++++---- torchgeo/trainers/byol.py | 5 +++-- torchgeo/trainers/classification.py | 5 +++-- torchgeo/trainers/moco.py | 5 +++-- torchgeo/trainers/regression.py | 7 ++++--- torchgeo/trainers/segmentation.py | 5 +++-- torchgeo/trainers/simclr.py | 5 +++-- 8 files changed, 34 insertions(+), 17 deletions(-) diff --git a/tests/conf/ssl4eo_l_moco_1.yaml b/tests/conf/ssl4eo_l_moco_1.yaml index c4152e6a290..1486d29bf01 100644 --- a/tests/conf/ssl4eo_l_moco_1.yaml +++ b/tests/conf/ssl4eo_l_moco_1.yaml @@ -8,6 +8,12 @@ model: temperature: 0.07 memory_bank_size: 10 moco_momentum: 0.999 + augmentation1: + class_path: kornia.augmentation.RandomResizedCrop + init_args: + size: + - 224 + - 224 data: class_path: SSL4EOLDataModule init_args: diff --git a/torchgeo/trainers/base.py b/torchgeo/trainers/base.py index a549cbe3f7b..3a44c047a31 100644 --- a/torchgeo/trainers/base.py +++ b/torchgeo/trainers/base.py @@ -4,7 +4,8 @@ """Base classes for all :mod:`torchgeo` trainers.""" from abc import ABC, abstractmethod -from typing import Any +from collections.abc import Sequence +from typing import Any, Optional, Union import lightning from lightning.pytorch import LightningModule @@ -27,10 +28,14 @@ class BaseTask(LightningModule, ABC): #: Whether the goal is to minimize or maximize the performance metric to monitor. mode = "min" - def __init__(self) -> None: - """Initialize a new BaseTask instance.""" + def __init__(self, ignore: Optional[Union[Sequence[str], str]] = None) -> None: + """Initialize a new BaseTask instance. + + Args: + ignore: Arguments to skip when saving hyperparameters. + """ super().__init__() - self.save_hyperparameters() + self.save_hyperparameters(ignore=ignore) self.configure_losses() self.configure_metrics() self.configure_models() diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index 56b0290ca34..68bdb6c9c43 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -322,11 +322,12 @@ def __init__( *backbone*, *learning_rate*, and *learning_rate_schedule_patience* were renamed to *model*, *lr*, and *patience*. """ - super().__init__() + self.weights = weights + super().__init__(ignore="weights") def configure_models(self) -> None: """Initialize the model.""" - weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"] + weights = self.weights in_channels: int = self.hparams["in_channels"] # Create backbone diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 21a94b4be3d..76bea82118f 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -71,7 +71,8 @@ class and used with 'ce' loss. *learning_rate* and *learning_rate_schedule_patience* were renamed to *lr* and *patience*. """ - super().__init__() + self.weights = weights + super().__init__(ignore="weights") def configure_losses(self) -> None: """Initialize the loss criterion. @@ -117,7 +118,7 @@ def configure_metrics(self) -> None: def configure_models(self) -> None: """Initialize the model.""" - weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"] + weights = self.weights # Create model self.model = timm.create_model( diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index 7e7b5e3e159..d2621a8da74 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -218,7 +218,8 @@ def __init__( if memory_bank_size > 0: warnings.warn("MoCo v3 does not use a memory bank") - super().__init__() + self.weights = weights + super().__init__(ignore=["weights", "augmentation1", "augmentation2"]) grayscale_weights = grayscale_weights or torch.ones(in_channels) aug1, aug2 = moco_augmentations(version, size, grayscale_weights) @@ -236,7 +237,7 @@ def configure_losses(self) -> None: def configure_models(self) -> None: """Initialize the model.""" model: str = self.hparams["model"] - weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"] + weights = self.weights in_channels: int = self.hparams["in_channels"] version: int = self.hparams["version"] layers: int = self.hparams["layers"] diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index b2847556620..b540ceecddc 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -75,7 +75,8 @@ def __init__( *learning_rate* and *learning_rate_schedule_patience* were renamed to *lr* and *patience*. """ - super().__init__() + self.weights = weights + super().__init__(ignore="weights") def configure_losses(self) -> None: """Initialize the loss criterion. @@ -110,7 +111,7 @@ def configure_metrics(self) -> None: def configure_models(self) -> None: """Initialize the model.""" # Create model - weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"] + weights = self.weights self.model = timm.create_model( self.hparams["model"], num_classes=self.hparams["num_outputs"], @@ -256,7 +257,7 @@ class PixelwiseRegressionTask(RegressionTask): def configure_models(self) -> None: """Initialize the model.""" - weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"] + weights = self.weights if self.hparams["model"] == "unet": self.model = smp.Unet( diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 9a67d051c9a..9ee51d8262c 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -93,7 +93,8 @@ class and used with 'ce' loss. UserWarning, ) - super().__init__() + self.weights = weights + super().__init__(ignore="weights") def configure_losses(self) -> None: """Initialize the loss criterion. @@ -151,7 +152,7 @@ def configure_models(self) -> None: """ model: str = self.hparams["model"] backbone: str = self.hparams["backbone"] - weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"] + weights = self.weights in_channels: int = self.hparams["in_channels"] num_classes: int = self.hparams["num_classes"] num_filters: int = self.hparams["num_filters"] diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index ca40eabb6ae..a889be1c96f 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -134,7 +134,8 @@ def __init__( if memory_bank_size == 0: warnings.warn("SimCLR v2 uses a memory bank") - super().__init__() + self.weights = weights + super().__init__(ignore=["weights", "augmentations"]) grayscale_weights = grayscale_weights or torch.ones(in_channels) self.augmentations = augmentations or simclr_augmentations( @@ -151,7 +152,7 @@ def configure_losses(self) -> None: def configure_models(self) -> None: """Initialize the model.""" - weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"] + weights = self.weights hidden_dim: int = self.hparams["hidden_dim"] output_dim: int = self.hparams["output_dim"]