Return the default_root_dir as the log_dir when the logger is a LoggerCollection #8187

Merged
Commits (23)
d966b35
Fix LoggerCollection/PyTorch profiler bug
gahdritz Jun 29, 2021
8f89f4c
Add test for LoggerCollection/Pytorch profiler bugfix
gahdritz Jun 29, 2021
3104da8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 29, 2021
c8d6c06
Improve bugfix formatting, documentation
gahdritz Jun 29, 2021
0c122b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 29, 2021
af0e5c4
Tweak LogCollection log_path logic
gahdritz Jun 29, 2021
14defaa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 29, 2021
5527bca
Remove dependence on warning wording from test
gahdritz Jun 30, 2021
32e944e
Add test that directly verifies logdir
gahdritz Jun 30, 2021
a62e024
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 30, 2021
910e2a4
Remove extraneous import from profiler test
gahdritz Jun 30, 2021
5dd9965
Avoid redefining builtin
gahdritz Jun 30, 2021
cab4ffe
Replace getattr
gahdritz Jun 30, 2021
a2b2277
Remove merge vestige
gahdritz Jun 30, 2021
71fe0cb
Add LoggerCollection type checks
gahdritz Jul 1, 2021
f6cac3f
Add missing import
gahdritz Jul 1, 2021
1ac77af
Reformat comment
gahdritz Jul 7, 2021
753973d
Remove redundant parentheses
gahdritz Jul 7, 2021
4a2dea0
Remove confusing comment
gahdritz Jul 7, 2021
6d6fe5f
Remove redundant ModelCheckpoint callback
gahdritz Jul 7, 2021
99b0f74
Merge if blocks
gahdritz Jul 7, 2021
741a92f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 7, 2021
6941d93
Addres Adrian's comment
carmocca Jul 19, 2021
pytorch_lightning/trainer/properties.py (7 changes: 6 additions & 1 deletion)
@@ -28,6 +28,7 @@
 from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning.loggers.base import LoggerCollection
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
 from pytorch_lightning.loops.fit_loop import FitLoop
@@ -223,8 +224,12 @@ def model(self, model: torch.nn.Module) -> None:
     def log_dir(self) -> Optional[str]:
         if self.logger is None:
             dirpath = self.default_root_dir
+        elif isinstance(self.logger, TensorBoardLogger):
+            dirpath = self.logger.log_dir
+        elif isinstance(self.logger, LoggerCollection):
+            dirpath = self.default_root_dir
         else:
-            dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')
+            dirpath = self.logger.save_dir

         dirpath = self.accelerator.broadcast(dirpath)
         return dirpath
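For reference, the resolution order implemented above, written out as a plain usage sketch. This is not part of the diff: the directory names are hypothetical, CSVLogger merely stands in for "any other logger", and the asserts assume single-process CPU execution with none of the example directories existing yet.

# Illustrative sketch of trainer.log_dir after this change (assumptions noted above).
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

# No logger: fall back to default_root_dir.
assert Trainer(logger=False, default_root_dir="root").log_dir == "root"

# A single TensorBoardLogger: use its versioned log_dir (e.g. tb_logs/default/version_0).
tb = TensorBoardLogger(save_dir="tb_logs")
assert Trainer(logger=tb, default_root_dir="root").log_dir == tb.log_dir

# Multiple loggers, wrapped in a LoggerCollection: fall back to default_root_dir (this PR).
loggers = [TensorBoardLogger(save_dir="tb_logs"), CSVLogger(save_dir="csv_logs")]
assert Trainer(logger=loggers, default_root_dir="root").log_dir == "root"

# Any other single logger: use its save_dir.
assert Trainer(logger=CSVLogger(save_dir="csv_logs"), default_root_dir="root").log_dir == "csv_logs"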
pytorch_lightning/trainer/trainer.py (5 changes: 4 additions & 1 deletion)
@@ -225,7 +225,10 @@ def __init__(
             limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches)

             logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
-                the default ``TensorBoardLogger``. ``False`` will disable logging.
+                the default ``TensorBoardLogger``. ``False`` will disable logging. If multiple loggers are
+                provided and the `save_dir` property of that logger is not set, local files (checkpoints,
+                profiler traces, etc.) are saved in ``default_root_dir`` rather than in the ``log_dir`` of any
+                of the individual loggers.

             log_gpu_memory: None, 'min_max', 'all'. Might slow performance

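To illustrate the docstring note above (a hedged sketch, not from the PR; directory names are hypothetical): passing a list of loggers makes trainer.logger a LoggerCollection, so the PyTorch profiler trace and other local artifacts end up under default_root_dir. The new test in tests/profiler/test_profiler.py below verifies the same behaviour end to end.

# Sketch only: several loggers are wrapped in a LoggerCollection, so local artifacts
# (profiler traces, checkpoints, ...) land under default_root_dir.
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from pytorch_lightning.loggers.base import LoggerCollection

trainer = Trainer(
    default_root_dir="runs",
    profiler="pytorch",
    logger=[TensorBoardLogger(save_dir="tb_logs"), CSVLogger(save_dir="csv_logs")],
)
assert isinstance(trainer.logger, LoggerCollection)
assert trainer.log_dir == "runs"  # the PyTorch profiler writes its trace here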
tests/profiler/test_profiler.py (32 changes: 32 additions & 0 deletions)
@@ -23,6 +23,8 @@
 from packaging.version import Version

 from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.loggers.base import LoggerCollection
+from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
 from pytorch_lightning.profiler.pytorch import RegisterRecordFunction
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -431,6 +433,36 @@ def test_pytorch_profiler_nested(tmpdir):
     assert events_name == expected, (events_name, torch.__version__, platform.system())


+def test_pytorch_profiler_logger_collection(tmpdir):
+    """
+    Tests whether the PyTorch profiler is able to write its trace locally when
+    the Trainer's logger is an instance of LoggerCollection. See issue #8157.
+    """
+
+    def look_for_trace(trace_dir):
+        """ Determines if a directory contains a PyTorch trace """
+        return any("trace.json" in filename for filename in os.listdir(trace_dir))
+
+    # Sanity check
+    assert not look_for_trace(tmpdir)
+
+    model = BoringModel()
+
+    # Wrap the logger in a list so it becomes a LoggerCollection
+    logger = [TensorBoardLogger(save_dir=tmpdir)]
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        profiler="pytorch",
+        logger=logger,
+        limit_train_batches=5,
+        max_epochs=1,
+    )
+
+    assert isinstance(trainer.logger, LoggerCollection)
+    trainer.fit(model)
+    assert look_for_trace(tmpdir)
+
+
 @RunIf(min_gpus=1, special=True)
 def test_pytorch_profiler_nested_emit_nvtx(tmpdir):
     """
tests/trainer/properties/log_dir.py (19 changes: 18 additions & 1 deletion)
@@ -15,7 +15,7 @@

 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
 from tests.helpers.boring_model import BoringModel


@@ -140,3 +140,20 @@ def test_logdir_custom_logger(tmpdir):
     assert trainer.log_dir == expected
     trainer.fit(model)
     assert trainer.log_dir == expected
+
+
+def test_logdir_logger_collection(tmpdir):
+    """Tests that the logdir equals the default_root_dir when the logger is a LoggerCollection"""
+    default_root_dir = tmpdir / "default_root_dir"
+    save_dir = tmpdir / "save_dir"
+    model = TestModel(default_root_dir)
+    trainer = Trainer(
+        default_root_dir=default_root_dir,
+        max_steps=2,
+        logger=[TensorBoardLogger(save_dir=save_dir, name='custom_logs')]
+    )
+    assert isinstance(trainer.logger, LoggerCollection)
+    assert trainer.log_dir == default_root_dir
+
+    trainer.fit(model)
+    assert trainer.log_dir == default_root_dir