Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
DuYicong515 committed Feb 23, 2022
1 parent c48500e commit 3c6669a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def _log_gpus_metrics(self) -> None:
self.trainer.lightning_module.log(key, mem, prog_bar=False, logger=True)
else:
gpu_id = int(key.split("/")[0].split(":")[1])
if gpu_id in self.trainer.data_parallel_device_ids:
if self.trainer.data_parallel_device_ids and gpu_id in self.trainer.data_parallel_device_ids:
self.trainer.lightning_module.log(
key, mem, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
Expand Down
7 changes: 5 additions & 2 deletions tests/callbacks/test_gpu_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def test_gpu_stats_monitor(tmpdir):
logger=logger,
)

trainer.fit(model)
with pytest.deprecated_call(match="`Trainer.data_parallel_device_ids` was deprecated in v1.6."):
trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"

path_csv = os.path.join(logger.log_dir, ExperimentWriter.NAME_METRICS_FILE)
Expand Down Expand Up @@ -84,7 +85,9 @@ def test_gpu_stats_monitor_no_queries(tmpdir):
devices=1,
callbacks=[gpu_stats],
)
with mock.patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics") as log_metrics_mock:
with mock.patch(
"pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics"
) as log_metrics_mock, pytest.deprecated_call(match="`Trainer.data_parallel_device_ids` was deprecated in v1.6."):
trainer.fit(model)

assert log_metrics_mock.mock_calls[1:] == [
Expand Down

0 comments on commit 3c6669a

Please sign in to comment.