From be9d25c338be0244c605e25315f42beb1237f368 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 25 Feb 2023 15:39:48 -0700 Subject: [PATCH] Trainers: fix support for non-TensorBoardLogger (#1145) * Trainers: fix support for CSVLogger * No need to check truthiness of attr * Revert "No need to check truthiness of attr" This reverts commit 450df6c53c53c6e64cf4178a0b1ed91c33aa5c45. --- torchgeo/trainers/classification.py | 2 ++ torchgeo/trainers/detection.py | 1 + torchgeo/trainers/regression.py | 1 + torchgeo/trainers/segmentation.py | 1 + 4 files changed, 5 insertions(+) diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index c94549d7a15..57b1a7fbf9d 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -192,6 +192,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: and hasattr(self.trainer, "datamodule") and self.logger and hasattr(self.logger, "experiment") + and hasattr(self.logger.experiment, "add_figure") ): try: datamodule = self.trainer.datamodule @@ -376,6 +377,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: and hasattr(self.trainer, "datamodule") and self.logger and hasattr(self.logger, "experiment") + and hasattr(self.logger.experiment, "add_figure") ): try: datamodule = self.trainer.datamodule diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 939ae761b2d..1eafdf1b47b 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -230,6 +230,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: and hasattr(self.trainer, "datamodule") and self.logger and hasattr(self.logger, "experiment") + and hasattr(self.logger.experiment, "add_figure") ): try: datamodule = self.trainer.datamodule diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 7e6137f226a..0d32e3436c7 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -149,6 +149,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: and hasattr(self.trainer, "datamodule") and self.logger and hasattr(self.logger, "experiment") + and hasattr(self.logger.experiment, "add_figure") ): try: datamodule = self.trainer.datamodule diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 7459fa8a953..8913589b913 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -202,6 +202,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: and hasattr(self.trainer, "datamodule") and self.logger and hasattr(self.logger, "experiment") + and hasattr(self.logger.experiment, "add_figure") ): try: datamodule = self.trainer.datamodule