Flake fixes
akashkw committed Feb 15, 2022
1 parent 3c4f258 · commit 6cda5b5
Showing 3 changed files with 27 additions and 26 deletions.
pytorch_lightning/plugins/precision/precision_plugin.py (4 changes: 2 additions & 2 deletions)
@@ -158,8 +158,8 @@ def _track_grad_norm(self, trainer: "pl.Trainer") -> None:

         if not trainer.loggers:
             kwargs = {}
-        elif len(trainer.loggers) == 1:
-            kwargs = {"group_separator": trainer.logger.group_separator}
+        elif len(trainer.loggers) > 1:
+            kwargs = {"group_separator": trainer.loggers[0].group_separator}
         else:
             kwargs = {"group_separator": "/"}

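In _track_grad_norm, this kwargs dict is presumably unpacked into the gradient-norm utility, so a logger-specific separator only overrides the callee's default when one is available. A minimal standalone sketch of that optional-kwargs pattern, using a hypothetical grad_norm_names stand-in rather than Lightning's real utility:

    from typing import Dict, List


    def grad_norm_names(param_names: List[str], norm_type: float = 2.0, group_separator: str = "/") -> List[str]:
        # hypothetical stand-in: build the metric names a grad-norm utility might log
        return [f"grad_{norm_type}_norm{group_separator}{name}" for name in param_names]


    def pick_kwargs(loggers: List[object]) -> Dict[str, str]:
        # simplified version of the branching above: with no loggers, fall back to
        # the callee's default separator; otherwise take one from the first logger
        if not loggers:
            return {}
        return {"group_separator": getattr(loggers[0], "group_separator", "/")}


    print(grad_norm_names(["layer.weight", "layer.bias"], **pick_kwargs([])))
    # ['grad_2.0_norm/layer.weight', 'grad_2.0_norm/layer.bias']
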
2 changes: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 import torch

 import pytorch_lightning as pl
-from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
+from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
 from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
 from pytorch_lightning.trainer.states import RunningStage
pytorch_lightning/trainer/trainer.py (47 changes: 24 additions & 23 deletions)
@@ -1215,29 +1215,30 @@ def _log_hyperparams(self) -> None:
         # save exp to get started (this is where the first experiment logs are written)
         datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False

-        if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
-            datamodule_hparams = self.datamodule.hparams_initial
-            lightning_hparams = self.lightning_module.hparams_initial
-            inconsistent_keys = []
-            for key in lightning_hparams.keys() & datamodule_hparams.keys():
-                lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
-                if type(lm_val) != type(dm_val):
-                    inconsistent_keys.append(key)
-                elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val):
-                    inconsistent_keys.append(key)
-                elif lm_val != dm_val:
-                    inconsistent_keys.append(key)
-            if inconsistent_keys:
-                raise MisconfigurationException(
-                    f"Error while merging hparams: the keys {inconsistent_keys} are present "
-                    "in both the LightningModule's and LightningDataModule's hparams "
-                    "but have different values."
-                )
-            hparams_initial = {**lightning_hparams, **datamodule_hparams}
-        elif self.lightning_module._log_hyperparams:
-            hparams_initial = self.lightning_module.hparams_initial
-        elif datamodule_log_hyperparams:
-            hparams_initial = self.datamodule.hparams_initial
+        if self.loggers:
+            if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
+                datamodule_hparams = self.datamodule.hparams_initial
+                lightning_hparams = self.lightning_module.hparams_initial
+                inconsistent_keys = []
+                for key in lightning_hparams.keys() & datamodule_hparams.keys():
+                    lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
+                    if type(lm_val) != type(dm_val):
+                        inconsistent_keys.append(key)
+                    elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val):
+                        inconsistent_keys.append(key)
+                    elif lm_val != dm_val:
+                        inconsistent_keys.append(key)
+                if inconsistent_keys:
+                    raise MisconfigurationException(
+                        f"Error while merging hparams: the keys {inconsistent_keys} are present "
+                        "in both the LightningModule's and LightningDataModule's hparams "
+                        "but have different values."
+                    )
+                hparams_initial = {**lightning_hparams, **datamodule_hparams}
+            elif self.lightning_module._log_hyperparams:
+                hparams_initial = self.lightning_module.hparams_initial
+            elif datamodule_log_hyperparams:
+                hparams_initial = self.datamodule.hparams_initial

         for logger in self.loggers:
             if hparams_initial is not None:
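The consistency check that this hunk moves under if self.loggers: can be exercised on its own; here is a minimal sketch of the same merge logic with a hypothetical merge_hparams helper and a plain ValueError standing in for MisconfigurationException:

    import torch


    def merge_hparams(lightning_hparams: dict, datamodule_hparams: dict) -> dict:
        # Reject keys that appear in both dicts with different values, then merge.
        inconsistent_keys = []
        for key in lightning_hparams.keys() & datamodule_hparams.keys():
            lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
            if type(lm_val) != type(dm_val):
                inconsistent_keys.append(key)
            elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val):
                # tensors are compared by object identity here: distinct tensor objects count as inconsistent
                inconsistent_keys.append(key)
            elif lm_val != dm_val:
                inconsistent_keys.append(key)
        if inconsistent_keys:
            raise ValueError(f"keys {inconsistent_keys} are present in both dicts but have different values")
        return {**lightning_hparams, **datamodule_hparams}


    print(merge_hparams({"lr": 0.1, "batch_size": 32}, {"batch_size": 32, "num_workers": 4}))
    # {'lr': 0.1, 'batch_size': 32, 'num_workers': 4}
    # merge_hparams({"lr": 0.1}, {"lr": 0.2})  # raises ValueError: keys ['lr'] ...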
