From 29a2aa6c89ed3f8394803e1cc3af8c667f54c9f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 5 Oct 2020 02:28:35 +0200 Subject: [PATCH] refactor --- tests/loggers/test_all.py | 22 +++++----------------- tests/loggers/test_comet.py | 16 ++++++++-------- 2 files changed, 13 insertions(+), 25 deletions(-) diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 44c2ea1de7a69..6a9f3aa1c92ad 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -19,6 +19,7 @@ ) from pytorch_lightning.loggers.base import DummyExperiment from tests.base import EvalModelTemplate +from tests.loggers.test_comet import _patch_comet_atexit def _get_logger_args(logger_class, save_dir): @@ -45,11 +46,7 @@ def _get_logger_args(logger_class, save_dir): def test_loggers_fit_test(wandb, neptune, tmpdir, monkeypatch, logger_class): """Verify that basic functionality of all loggers.""" os.environ['PL_DEV_DEBUG'] = '0' - - if logger_class == CometLogger: - # prevent comet logger from trying to print at exit, since - # pytest's stdout/stderr redirection breaks it - monkeypatch.setattr(atexit, 'register', lambda _: None) + _patch_comet_atexit(monkeypatch) model = EvalModelTemplate() @@ -110,10 +107,7 @@ def log_metrics(self, metrics, step): @mock.patch('pytorch_lightning.loggers.wandb.wandb') def test_loggers_save_dir_and_weights_save_path(wandb, tmpdir, monkeypatch, logger_class): """ Test the combinations of save_dir, weights_save_path and default_root_dir. """ - if logger_class == CometLogger: - # prevent comet logger from trying to print at exit, since - # pytest's stdout/stderr redirection breaks it - monkeypatch.setattr(atexit, 'register', lambda _: None) + _patch_comet_atexit(monkeypatch) class TestLogger(logger_class): # for this test it does not matter what these attributes are @@ -173,10 +167,7 @@ def name(self): @mock.patch('pytorch_lightning.loggers.neptune.neptune') def test_loggers_pickle(neptune, tmpdir, monkeypatch, logger_class): """Verify that pickling trainer with logger works.""" - if logger_class == CometLogger: - # prevent comet logger from trying to print at exit, since - # pytest's stdout/stderr redirection breaks it - monkeypatch.setattr(atexit, 'register', lambda _: None) + _patch_comet_atexit(monkeypatch) logger_args = _get_logger_args(logger_class, tmpdir) logger = logger_class(**logger_args) @@ -250,10 +241,7 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_ @mock.patch('pytorch_lightning.loggers.neptune.neptune') def test_logger_created_on_rank_zero_only(neptune, tmpdir, monkeypatch, logger_class): """ Test that loggers get replaced by dummy loggers on global rank > 0""" - if logger_class == CometLogger: - # prevent comet logger from trying to print at exit, since - # pytest's stdout/stderr redirection breaks it - monkeypatch.setattr(atexit, 'register', lambda _: None) + _patch_comet_atexit(monkeypatch) logger_args = _get_logger_args(logger_class, tmpdir) logger = logger_class(**logger_args) diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index 16e8d8551b6e5..0e1199e88d27a 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -9,6 +9,12 @@ from tests.base import EvalModelTemplate +def _patch_comet_atexit(monkeypatch): + """ Prevent comet logger from trying to print at exit, since pytest's stdout/stderr redirection breaks it. """ + import atexit + monkeypatch.setattr(atexit, "register", lambda _: None) + + def test_comet_logger_online(): """Test comet online with mocks.""" # Test api_key given @@ -76,11 +82,7 @@ def test_comet_logger_experiment_name(): def test_comet_logger_dirs_creation(tmpdir, monkeypatch): """ Test that the logger creates the folders and files in the right place. """ - # prevent comet logger from trying to print at exit, since - # pytest's stdout/stderr redirection breaks it - import atexit - - monkeypatch.setattr(atexit, 'register', lambda _: None) + _patch_comet_atexit(monkeypatch) logger = CometLogger(project_name='test', save_dir=tmpdir) assert not os.listdir(tmpdir) @@ -159,9 +161,7 @@ def test_comet_version_without_experiment(): def test_comet_epoch_logging(tmpdir, monkeypatch): """ Test that CometLogger removes the epoch key from the metrics dict and passes it as argument. """ - import atexit - - monkeypatch.setattr(atexit, "register", lambda _: None) + _patch_comet_atexit(monkeypatch) with patch("pytorch_lightning.loggers.comet.CometOfflineExperiment.log_metrics") as log_metrics: logger = CometLogger(project_name="test", save_dir=tmpdir) logger.log_metrics({"test": 1, "epoch": 1}, step=123)