Skip to content

Commit

Permalink
refactor unit tests
Browse files — browse the repository at this point in the history
  • Loading branch information
akashkw committed Feb 14, 2022
1 parent 3cc43a7 commit 3c4f258
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,7 +1239,7 @@ def _log_hyperparams(self) -> None:
elif datamodule_log_hyperparams:
hparams_initial = self.datamodule.hparams_initial

for logger in trainer.loggers:
for logger in self.loggers:
if hparams_initial is not None:
self.logger.log_hyperparams(hparams_initial)
self.logger.log_graph(self.lightning_module)
Expand Down
7 changes: 4 additions & 3 deletions tests/loggers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,11 @@ def test_multiple_loggers_pickle(tmpdir):
trainer = Trainer(logger=[logger1, logger2])
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({"acc": 1.0}, 0)
for logger in trainer2.loggers:
logger.log_metrics({"acc": 1.0}, 0)

assert trainer2.logger[0].metrics_logged == {"acc": 1.0}
assert trainer2.logger[1].metrics_logged == {"acc": 1.0}
for logger in trainer2.loggers:
assert logger.metrics_logged == {"acc": 1.0}


def test_adding_step_key(tmpdir):
Expand Down
5 changes: 2 additions & 3 deletions tests/profiler/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging
from pytorch_lightning.loggers.base import DummyLogger, LoggerCollection
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -494,7 +493,7 @@ def look_for_trace(trace_dir):

model = BoringModel()
# Wrap the logger in a list so it becomes a LoggerCollection
logger = [TensorBoardLogger(save_dir=tmpdir), DummyLogger()]
logger = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)]
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=logger, limit_train_batches=5, max_epochs=1)
assert isinstance(trainer.logger, LoggerCollection)
trainer.fit(model)
Expand Down
5 changes: 2 additions & 3 deletions tests/trainer/properties/test_log_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
from tests.helpers.boring_model import BoringModel


Expand Down Expand Up @@ -118,7 +117,7 @@ def test_logdir_logger_collection(tmpdir):
trainer = Trainer(
default_root_dir=default_root_dir,
max_steps=2,
logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), DummyLogger()],
logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), CSVLogger(tmpdir)],
)
assert isinstance(trainer.logger, LoggerCollection)
assert trainer.log_dir == default_root_dir
Expand Down
4 changes: 3 additions & 1 deletion tests/trainer/properties/test_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,20 @@ def test_trainer_loggers_property():
# trainer.loggers should create a list of size 1
trainer = Trainer(logger=logger1)

assert trainer.logger == logger1
assert trainer.loggers == [logger1]

# trainer.loggers should be an empty list
trainer = Trainer(logger=False)

assert trainer.logger is None
assert trainer.loggers == []

# trainer.loggers should be a list of size 1 holding the default logger
trainer = Trainer(logger=True)

assert trainer.loggers == [trainer.logger]
assert type(trainer.loggers[0]) == TensorBoardLogger
assert isinstance(trainer.logger, TensorBoardLogger)


def test_trainer_loggers_setters():
Expand Down

0 comments on commit 3c4f258

Please sign in to comment.