
Commit

Refactor codebase to use trainer.loggers over trainer.logger when needed (#11920)
akashkw authored Feb 26, 2022
1 parent 244f365 commit 7e2f9fb
Showing 22 changed files with 185 additions and 115 deletions.
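
At its core the commit swaps single-logger access (`trainer.logger`) for iteration over the new `trainer.loggers` list, which is simply empty when logging is disabled. A minimal sketch of the pattern, with a hypothetical callback and placeholder metric names (nothing below is taken from the diff itself):

from pytorch_lightning.callbacks import Callback


class MetricsBroadcaster(Callback):
    """Illustrative callback (not part of this commit): fan metrics out to every logger."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, unused=0):
        metrics = {"illustrative/value": 1.0}  # placeholder payload
        # Pre-#11920 pattern: trainer.logger.log_metrics(...), guarded by `if trainer.logger:`.
        # New pattern: iterate the list; it is empty when Trainer(logger=False).
        for logger in trainer.loggers:
            logger.log_metrics(metrics, step=trainer.global_step)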
3 changes: 2 additions & 1 deletion pl_examples/domain_templates/generative_adversarial_net.py
@@ -206,7 +206,8 @@ def on_train_epoch_end(self):
# log sampled images
sample_imgs = self(z)
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image("generated_images", grid, self.current_epoch)
for logger in self.loggers:
logger.experiment.add_image("generated_images", grid, self.current_epoch)


def main(args: Namespace) -> None:
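With several loggers attached, each `logger.experiment` can be a different backend object, so not all of them expose `add_image`. A hedged sketch of a defensive variant of the hook above; the helper name and the `isinstance` guard are assumptions, not part of this change:

import torchvision
from pytorch_lightning.loggers import TensorBoardLogger


def log_samples_to_all_loggers(module, z) -> None:
    """Hypothetical helper mirroring the hook above, for a LightningModule `module`."""
    sample_imgs = module(z)
    grid = torchvision.utils.make_grid(sample_imgs)
    for logger in module.loggers:
        # Only TensorBoard-backed experiments are guaranteed to expose `add_image`;
        # the guard is an assumption for mixed logger setups.
        if isinstance(logger, TensorBoardLogger):
            logger.experiment.add_image("generated_images", grid, module.current_epoch)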
20 changes: 11 additions & 9 deletions pytorch_lightning/callbacks/device_stats_monitor.py
@@ -44,7 +44,7 @@ class DeviceStatsMonitor(Callback):
"""

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.")

def on_train_batch_start(
@@ -55,17 +55,18 @@ def on_train_batch_start(
batch_idx: int,
unused: Optional[int] = 0,
) -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")

if not trainer.logger_connector.should_update_logs:
return

device = trainer.strategy.root_device
device_stats = trainer.accelerator.get_device_stats(device)
separator = trainer.logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
for logger in trainer.loggers:
separator = logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

def on_train_batch_end(
self,
@@ -76,17 +77,18 @@ def on_train_batch_end(
batch_idx: int,
unused: Optional[int] = 0,
) -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")

if not trainer.logger_connector.should_update_logs:
return

device = trainer.strategy.root_device
device_stats = trainer.accelerator.get_device_stats(device)
separator = trainer.logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
for logger in trainer.loggers:
separator = logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
logger.log_metrics(prefixed_device_stats, step=trainer.global_step)


def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
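`_prefix_metric_keys` is referenced but not shown in this hunk; a plausible sketch of its behaviour, inferred from the call sites above (prefixing each key with the hook name using the logger's group separator):

from typing import Dict


def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
    # e.g. {"cpu_percent": 12.0} -> {"on_train_batch_start/cpu_percent": 12.0}
    # when the separator is "/" (TensorBoard's group separator).
    return {prefix + separator + k: v for k, v in metrics_dict.items()}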
10 changes: 5 additions & 5 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -123,7 +123,7 @@ def __init__(
self._gpu_ids: List[str] = [] # will be assigned later in setup()

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.")

if trainer.strategy.root_device.type != "cuda":
@@ -161,8 +161,8 @@ def on_train_batch_start(
# First log at beginning of second step
logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000

assert trainer.logger is not None
trainer.logger.log_metrics(logs, step=trainer.global_step)
for logger in trainer.loggers:
logger.log_metrics(logs, step=trainer.global_step)

@rank_zero_only
def on_train_batch_end(
@@ -186,8 +186,8 @@ def on_train_batch_end(
if self._log_stats.intra_step_time and self._snap_intra_step_time:
logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000

assert trainer.logger is not None
trainer.logger.log_metrics(logs, step=trainer.global_step)
for logger in trainer.loggers:
logger.log_metrics(logs, step=trainer.global_step)

@staticmethod
def _get_gpu_ids(device_ids: List[int]) -> List[str]:
10 changes: 5 additions & 5 deletions pytorch_lightning/callbacks/lr_monitor.py
@@ -104,7 +104,7 @@ def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> No
MisconfigurationException:
If ``Trainer`` has no ``logger``.
"""
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException(
"Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
)
@@ -149,7 +149,6 @@ def _check_no_key(key: str) -> bool:
self.last_momentum_values = {name + "-momentum": None for name in names_flatten}

def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
assert trainer.logger is not None
if not trainer.logger_connector.should_update_logs:
return

@@ -158,16 +157,17 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any)
latest_stat = self._extract_stats(trainer, interval)

if latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
for logger in trainer.loggers:
logger.log_metrics(latest_stat, step=trainer.global_step)

def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
assert trainer.logger is not None
if self.logging_interval != "step":
interval = "epoch" if self.logging_interval is None else "any"
latest_stat = self._extract_stats(trainer, interval)

if latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
for logger in trainer.loggers:
logger.log_metrics(latest_stat, step=trainer.global_step)

def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
latest_stat = {}
24 changes: 13 additions & 11 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -35,6 +35,7 @@
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _name, _version
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache
@@ -379,8 +380,9 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None:
self._save_last_checkpoint(trainer, monitor_candidates)

# notify loggers
if trainer.is_global_zero and trainer.logger:
trainer.logger.after_save_checkpoint(proxy(self))
if trainer.is_global_zero:
for logger in trainer.loggers:
logger.after_save_checkpoint(proxy(self))

def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
from pytorch_lightning.trainer.states import TrainerFn
@@ -572,20 +574,20 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
"""
if self.dirpath is not None:
return # short circuit

if trainer.logger is not None:
if trainer.loggers:
if trainer.weights_save_path != trainer.default_root_dir:
# the user has changed weights_save_path, it overrides anything
save_dir = trainer.weights_save_path
else:
elif len(trainer.loggers) == 1:
save_dir = trainer.logger.save_dir or trainer.default_root_dir
else:
save_dir = trainer.default_root_dir

version = (
trainer.logger.version
if isinstance(trainer.logger.version, str)
else f"version_{trainer.logger.version}"
)
ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints")
name = _name(trainer.loggers)
version = _version(trainer.loggers)
version = version if isinstance(version, str) else f"version_{version}"

ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
else:
ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")

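The new `_name` and `_version` helpers from `pytorch_lightning.utilities.logger` replace direct `trainer.logger.name` / `.version` access so checkpoint paths can still be resolved with several loggers. A rough sketch of the assumed behaviour (single logger: use its values; multiple loggers: join the distinct ones), not the authoritative implementation:

from typing import List, Union

from pytorch_lightning.loggers import LightningLoggerBase


def _name(loggers: List[LightningLoggerBase], separator: str = "_") -> str:
    # Assumed behaviour: empty string with no loggers, the single logger's name
    # otherwise, and a de-duplicated join of names for multiple loggers.
    if not loggers:
        return ""
    if len(loggers) == 1:
        return loggers[0].name
    return separator.join(dict.fromkeys(str(logger.name) for logger in loggers))


def _version(loggers: List[LightningLoggerBase], separator: str = "_") -> Union[int, str]:
    if not loggers:
        return ""
    if len(loggers) == 1:
        return loggers[0].version
    return separator.join(dict.fromkeys(str(logger.version) for logger in loggers))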
14 changes: 8 additions & 6 deletions pytorch_lightning/callbacks/progress/base.py
@@ -15,6 +15,7 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.logger import _version
from pytorch_lightning.utilities.rank_zero import rank_zero_warn


@@ -213,11 +214,12 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule")
if pl_module.truncated_bptt_steps > 0:
items_dict["split_idx"] = trainer.fit_loop.split_idx

if trainer.logger is not None and trainer.logger.version is not None:
version = trainer.logger.version
if isinstance(version, str):
# show last 4 places of long version strings
version = version[-4:]
items_dict["v_num"] = version
if trainer.loggers:
version = _version(trainer.loggers)
if version is not None:
if isinstance(version, str):
# show last 4 places of long version strings
version = version[-4:]
items_dict["v_num"] = version

return items_dict
13 changes: 7 additions & 6 deletions pytorch_lightning/callbacks/xla_stats_monitor.py
@@ -70,7 +70,7 @@ def __init__(self, verbose: bool = True) -> None:
self._verbose = verbose

def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

if isinstance(trainer.accelerator, TPUAccelerator):
@@ -88,7 +88,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
self._start_time = time.time()

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

device = trainer.strategy.root_device
@@ -102,10 +102,11 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
epoch_time = trainer.strategy.reduce(epoch_time)

trainer.logger.log_metrics(
{"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
step=trainer.current_epoch,
)
for logger in trainer.loggers:
logger.log_metrics(
{"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
step=trainer.current_epoch,
)

if self._verbose:
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -253,7 +253,7 @@ def logger(self) -> Optional[LightningLoggerBase]:

@property
def loggers(self) -> List[LightningLoggerBase]:
"""Reference to the loggers object in the Trainer."""
"""Reference to the list of loggers in the Trainer."""
return self.trainer.loggers if self.trainer else []

def _apply_batch_transfer_handler(
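For reference, `self.logger` keeps returning a single logger (or `None`) for backwards compatibility, while the new `self.loggers` always returns a list. A short hedged usage sketch with a hypothetical module:

import pytorch_lightning as pl


class IllustrativeModule(pl.LightningModule):
    # Hypothetical module: `self.loggers` is [] when the Trainer was created with
    # logger=False, so no None check is needed before logging.
    def on_train_epoch_end(self) -> None:
        for logger in self.loggers:
            logger.log_metrics({"epoch": float(self.current_epoch)}, step=self.global_step)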
6 changes: 4 additions & 2 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -504,8 +504,10 @@ def _save_loggers_on_train_batch_end(self) -> None:
"""Flushes loggers to disk."""
# when loggers should save to disk
should_flush_logs = self.trainer.logger_connector.should_flush_logs
if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
self.trainer.logger.save()
# TODO: is_global_zero check should be moved to logger.save() implementation
if should_flush_logs and self.trainer.is_global_zero:
for logger in self.trainer.loggers:
logger.save()

def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
if self._dataloader_state_dict:
6 changes: 5 additions & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -155,7 +155,11 @@ def optimizer_step(
def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
if trainer.track_grad_norm == -1:
return
kwargs = {"group_separator": trainer.logger.group_separator} if trainer.logger is not None else {}

kwargs = {}
if len(trainer.loggers) == 1:
kwargs["group_separator"] = trainer.loggers[0].group_separator

grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs)
if grad_norm_dict:
prev_fx = trainer.lightning_module._current_fx_name
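The `group_separator` kwarg is forwarded only when exactly one logger is configured, since multiple loggers may disagree on the separator; otherwise `grad_norm` falls back to its default. A small illustration of the resulting call, with the import path and the "/" separator assumed rather than taken from this hunk:

import torch
from pytorch_lightning.utilities.grads import grad_norm  # import path assumed

# Stand-in for trainer.lightning_module: any module with populated gradients.
module = torch.nn.Linear(4, 2)
module(torch.randn(8, 4)).sum().backward()

# With exactly one logger its group_separator (e.g. "/" for TensorBoard) is forwarded,
# producing keys like "grad_2.0_norm/weight"; otherwise grad_norm's default is used.
norms = grad_norm(module, norm_type=2.0, group_separator="/")
print(norms)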
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -17,7 +17,7 @@

import pytorch_lightning as pl
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
from pytorch_lightning.trainer.states import RunningStage
@@ -90,15 +90,15 @@ def should_update_logs(self) -> bool:
def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None:
if isinstance(logger, bool):
# default logger
self.trainer.logger = (
TensorBoardLogger(save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id())
self.trainer.loggers = (
[TensorBoardLogger(save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id())]
if logger
else None
else []
)
elif isinstance(logger, Iterable):
self.trainer.logger = LoggerCollection(logger)
self.trainer.loggers = list(logger)
else:
self.trainer.logger = logger
self.trainer.loggers = [logger]

def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
"""Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses
@@ -109,7 +109,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
step: Step for which metrics should be logged. Default value is `self.global_step` during training or
the total validation / test log step count during validation and testing.
"""
if self.trainer.logger is None or not metrics:
if not self.trainer.loggers or not metrics:
return

self._logged_metrics.update(metrics)
@@ -126,11 +126,12 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
step = self.trainer.global_step

# log actual metrics
if self._override_agg_and_log_metrics:
self.trainer.logger.agg_and_log_metrics(metrics=scalar_metrics, step=step)
else:
self.trainer.logger.log_metrics(metrics=scalar_metrics, step=step)
self.trainer.logger.save()
for logger in self.trainer.loggers:
if self._override_agg_and_log_metrics:
logger.agg_and_log_metrics(metrics=scalar_metrics, step=step)
else:
logger.log_metrics(metrics=scalar_metrics, step=step)
logger.save()

"""
Evaluation metric updates
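`configure_logger` now normalises every accepted `logger=` value into a list, so downstream code can always iterate `trainer.loggers`. A quick hedged illustration of the mapping (the logger choices here are arbitrary examples):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

# logger=True                  -> [TensorBoardLogger(save_dir=default_root_dir, version=<SLURM job id>)]
# logger=False                 -> []
# logger=<single logger>       -> [logger]
# logger=<iterable of loggers> -> list(iterable)
trainer = Trainer(logger=[TensorBoardLogger("logs/"), CSVLogger("logs/")])
assert len(trainer.loggers) == 2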
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/connectors/signal_connector.py
@@ -66,8 +66,8 @@ def slurm_sigusr1_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
rank_zero_info("handling SIGUSR1")

# save logger to make sure we get all the metrics
if self.trainer.logger:
self.trainer.logger.finalize("finished")
for logger in self.trainer.loggers:
logger.finalize("finished")
hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.weights_save_path)
self.trainer.save_checkpoint(hpc_save_path)

(Remaining changed files not shown.)
