From c48500ebd44a2179189e2404640294e1e6b2d083 Mon Sep 17 00:00:00 2001
From: DuYicong515
Date: Wed, 23 Feb 2022 13:02:59 -0800
Subject: [PATCH] use trainer.data_parallel_device_ids in logger_connector

---
 .../connectors/logger_connector/logger_connector.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 0e7caa93310966..c3352b2a7a0ca3 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -19,7 +19,6 @@
 from pytorch_lightning.accelerators import GPUAccelerator
 from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
-from pytorch_lightning.strategies import ParallelStrategy, SingleDeviceStrategy
 from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities import memory
@@ -224,13 +223,7 @@ def _log_gpus_metrics(self) -> None:
                 self.trainer.lightning_module.log(key, mem, prog_bar=False, logger=True)
             else:
                 gpu_id = int(key.split("/")[0].split(":")[1])
-                parallel_device_ids = []
-                if isinstance(self.trainer.accelerator, GPUAccelerator):
-                    if isinstance(self.trainer.strategy, ParallelStrategy):
-                        parallel_device_ids = [i for i in range(len(self.trainer.strategy.parallel_devices))]
-                    elif isinstance(self.strategy, SingleDeviceStrategy):
-                        parallel_device_ids = [0]
-                if gpu_id in parallel_device_ids:
+                if gpu_id in self.trainer.data_parallel_device_ids:
                     self.trainer.lightning_module.log(
                         key, mem, prog_bar=False, logger=True, on_step=True, on_epoch=False
                     )
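
Note on the change: the removed block re-derived the GPU device ids inside the logger
connector by inspecting the accelerator and strategy. A minimal standalone sketch of that
removed logic is below; the helper name _local_gpu_ids is hypothetical and is only meant
to mirror the deleted lines, using the same classes the patch removes from the imports.

from typing import List

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.strategies import ParallelStrategy, SingleDeviceStrategy


def _local_gpu_ids(trainer: Trainer) -> List[int]:
    # Mirror of the removed block: only GPU runs report device ids.
    device_ids: List[int] = []
    if isinstance(trainer.accelerator, GPUAccelerator):
        if isinstance(trainer.strategy, ParallelStrategy):
            # One id per parallel device, enumerated from zero.
            device_ids = list(range(len(trainer.strategy.parallel_devices)))
        elif isinstance(trainer.strategy, SingleDeviceStrategy):
            device_ids = [0]
    return device_ids

The patch drops this duplicated lookup in favour of the trainer's own
data_parallel_device_ids, so the connector no longer needs to inspect the strategy; the
removed branch's reference to self.strategy (rather than self.trainer.strategy) also
disappears with it.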