From 3c6669a181c895d3dca693291a65248eca773ab4 Mon Sep 17 00:00:00 2001 From: DuYicong515 Date: Wed, 23 Feb 2022 13:20:52 -0800 Subject: [PATCH] fix tests --- .../connectors/logger_connector/logger_connector.py | 2 +- tests/callbacks/test_gpu_stats_monitor.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index c3352b2a7a0ca3..dca28335334622 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -223,7 +223,7 @@ def _log_gpus_metrics(self) -> None: self.trainer.lightning_module.log(key, mem, prog_bar=False, logger=True) else: gpu_id = int(key.split("/")[0].split(":")[1]) - if gpu_id in self.trainer.data_parallel_device_ids: + if self.trainer.data_parallel_device_ids and gpu_id in self.trainer.data_parallel_device_ids: self.trainer.lightning_module.log( key, mem, prog_bar=False, logger=True, on_step=True, on_epoch=False ) diff --git a/tests/callbacks/test_gpu_stats_monitor.py b/tests/callbacks/test_gpu_stats_monitor.py index 1e3a9953bf137e..5af24000e06ede 100644 --- a/tests/callbacks/test_gpu_stats_monitor.py +++ b/tests/callbacks/test_gpu_stats_monitor.py @@ -47,7 +47,8 @@ def test_gpu_stats_monitor(tmpdir): logger=logger, ) - trainer.fit(model) + with pytest.deprecated_call(match="`Trainer.data_parallel_device_ids` was deprecated in v1.6."): + trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" path_csv = os.path.join(logger.log_dir, ExperimentWriter.NAME_METRICS_FILE) @@ -84,7 +85,9 @@ def test_gpu_stats_monitor_no_queries(tmpdir): devices=1, callbacks=[gpu_stats], ) - with mock.patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics") as log_metrics_mock: + with mock.patch( + "pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics" + ) as log_metrics_mock, pytest.deprecated_call(match="`Trainer.data_parallel_device_ids` was deprecated in v1.6."): trainer.fit(model) assert log_metrics_mock.mock_calls[1:] == [