Skip to content

Commit

Permalink
[NeMo-UX] Add more NeMo Logger tests (#9795)
Browse files Browse the repository at this point in the history
* add some more NeMoLogger tests and fix explicit_log_dir in 2.0

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* update explicit_log_dir test

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* remove reference to exp_manager

Signed-off-by: ashors1 <[email protected]>

* update checkpoint restore test

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* update comment

Signed-off-by: ashors1 <[email protected]>

---------

Signed-off-by: ashors1 <[email protected]>
Signed-off-by: ashors1 <[email protected]>
Co-authored-by: Pablo Garay <[email protected]>
Co-authored-by: ashors1 <[email protected]>
  • Loading branch information
3 people authored and web-flow committed Jul 26, 2024
1 parent 7506f98 commit 75d1c67
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 25 deletions.
59 changes: 37 additions & 22 deletions nemo/lightning/nemo_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,30 +73,45 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool =
logging.rank = self.global_rank

if self.explicit_log_dir and isinstance(trainer, pl.Trainer): # If explicit log_dir was passed, short circuit
return check_explicit_log_dir(trainer, self.explicit_log_dir, self.dir, self.name, self.version)

# Default dir to ./nemo_experiments if None was passed
_dir = self.dir
if self.dir is None:
_dir = str(Path.cwd() / 'nemo_experiments')

if not self.name:
self.name = "default"

version = self.version or os.environ.get(NEMO_ENV_VARNAME_VERSION, None)
if is_global_rank_zero():
if self.use_datetime_version:
version = time.strftime('%Y-%m-%d_%H-%M-%S')
if resume_if_exists:
logging.warning(
"No version folders would be created under the log folder as 'resume_if_exists' is enabled."
)
version = None
if version:
if trainer.logger is not None and not self.update_logger_directory:
logging.warning(
f"nemo logger received explicit_log_dir: {self.explicit_log_dir} and the pytorch lightning trainer "
f"that was passed to nemo_logger container a logger, but update_logger_directory is False. This means "
f"that the trainer's logger directory may not match with the explicit_log_dir."
)
if self.dir or self.version:
logging.error(
f"nemo logger received explicit_log_dir: {self.explicit_log_dir} and at least one of dir: {self.dir}, "
f"or version: {self.version}. Please note that dir, name, and version will be ignored."
)
if is_global_rank_zero() and Path(self.explicit_log_dir).exists():
logging.warning(f"NeMoLogger is logging to {self.explicit_log_dir}, but it already exists.")
log_dir, _dir, self.name, version = Path(self.explicit_log_dir), str(self.explicit_log_dir), "", ""

else:
# Default dir to ./nemo_experiments if None was passed
_dir = self.dir
if self.dir is None:
_dir = str(Path.cwd() / 'nemo_experiments')

if not self.name:
self.name = "default"

version = self.version or os.environ.get(NEMO_ENV_VARNAME_VERSION, None)
if is_global_rank_zero():
os.environ[NEMO_ENV_VARNAME_VERSION] = version
if self.use_datetime_version:
version = time.strftime('%Y-%m-%d_%H-%M-%S')
if resume_if_exists:
logging.warning(
"No version folders would be created under the log folder as 'resume_if_exists' is enabled."
)
version = None
if version:
if is_global_rank_zero():
os.environ[NEMO_ENV_VARNAME_VERSION] = version

log_dir = Path(_dir) / Path(str(self.name)) / Path("" if version is None else str(version))

log_dir = Path(_dir) / Path(str(self.name)) / Path("" if version is None else str(version))
# update app_state with log_dir, exp_dir, etc
app_state = AppState()
app_state.log_dir = log_dir
Expand Down
110 changes: 107 additions & 3 deletions tests/lightning/test_nemo_logger.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import os
import time
from pathlib import Path
from unittest.mock import patch

import pytest
from pytorch_lightning.callbacks import ModelCheckpoint as PTLModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from nemo import lightning as nl
from nemo.constants import NEMO_ENV_VARNAME_VERSION
from nemo.utils.exp_manager import NotFoundError


class TestNeMoLogger:
Expand All @@ -31,9 +36,18 @@ def test_explicit_log_dir(self, trainer):
explicit_dir = "explicit_test_dir"
logger = nl.NeMoLogger(name="test", explicit_log_dir=explicit_dir)

with patch("nemo.utils.exp_manager.check_explicit_log_dir") as mock_check:
logger.setup(trainer)
mock_check.assert_called_once_with(trainer, explicit_dir, None, "test", None)
app_state = logger.setup(trainer)
assert str(app_state.log_dir) == "explicit_test_dir"
assert app_state.name == "" ## name should be ignored when explicit_log_dir is passed in
assert app_state.version == ""

def test_default_log_dir(self, trainer):
    """Verify the fallback log dir is ./nemo_experiments/default when no
    name, dir, or version is provided to NeMoLogger."""
    # Clear any version pinned in the environment by an earlier test run;
    # a stale value would make NeMoLogger append a version folder and
    # break the expected path below. pop() with a default replaces the
    # original guarded `del` (EAFP-friendly, no KeyError possible).
    os.environ.pop(NEMO_ENV_VARNAME_VERSION, None)
    logger = nl.NeMoLogger(use_datetime_version=False)
    app_state = logger.setup(trainer)
    assert app_state.log_dir == Path.cwd() / "nemo_experiments" / "default"

def test_custom_version(self, trainer):
custom_version = "v1.0"
Expand All @@ -58,3 +72,93 @@ def test_model_checkpoint_setup(self, trainer):
ptl_ckpt = next(cb for cb in trainer.callbacks if isinstance(cb, PTLModelCheckpoint))
assert str(ptl_ckpt.dirpath).endswith("test_ckpt")
assert ptl_ckpt.filename == "test-{epoch:02d}-{val_loss:.2f}"

def test_resume(self, trainer, tmp_path):
    """Tests the resume capabilities of NeMoLogger + AutoResume.

    Walks through the failure modes in order (missing log dir, missing
    checkpoints folder, empty checkpoints folder, finished training) and
    then the success paths (most-recent `-last` checkpoint, resume after
    logger setup, explicit `dirpath` override).
    """

    # Clear any version pinned in the environment by an earlier test so the
    # logger resolves paths deterministically.
    if os.environ.get(NEMO_ENV_VARNAME_VERSION, None) is not None:
        del os.environ[NEMO_ENV_VARNAME_VERSION]

    # Error because explicit_log_dir does not exist
    with pytest.raises(NotFoundError):
        nl.AutoResume(
            dirpath=str(tmp_path / "test_resume"),
            resume_if_exists=True,
        ).setup(model=None, trainer=trainer)

    # Error because checkpoints folder does not exist
    with pytest.raises(NotFoundError):
        nl.AutoResume(
            dirpath=str(tmp_path / "test_resume" / "does_not_exist"),
            path="does_not_exist",
            resume_if_exists=True,
        ).setup(None, trainer)

    # No error because we tell autoresume to ignore notfounderror
    nl.AutoResume(
        dirpath=str(tmp_path / "test_resume" / "does_not_exist"),
        resume_if_exists=True,
        resume_ignore_no_checkpoint=True,
    ).setup(None, trainer)

    # BUG FIX: `Path.mkdir()` returns None, so the original
    # `path = Path(...).mkdir(parents=True)` bound `path` to None and the
    # AutoResume below was exercised with dirpath=None instead of the
    # freshly created (empty) checkpoints directory. Create the directory
    # first, then pass the Path itself.
    checkpoints_dir = Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints")
    checkpoints_dir.mkdir(parents=True)
    # Error because checkpoints do not exist in folder
    with pytest.raises(NotFoundError):
        nl.AutoResume(
            dirpath=checkpoints_dir,
            resume_if_exists=True,
        ).setup(None, trainer)

    (checkpoints_dir / "mymodel--end").mkdir()
    # Error because *end.ckpt is in folder indicating that training has already finished
    with pytest.raises(ValueError):
        nl.AutoResume(
            dirpath=checkpoints_dir,
            resume_if_exists=True,
        ).setup(None, trainer)

    ## if there are multiple "-last" checkpoints, choose the most recent one
    (checkpoints_dir / "mymodel--end").rmdir()
    (checkpoints_dir / "mymodel--last").mkdir()
    time.sleep(1)  ## sleep for a second so the checkpoints are created at different times
    (checkpoints_dir / "mymodel2--last").mkdir()
    nl.AutoResume(
        dirpath=checkpoints_dir,
        resume_if_exists=True,
    ).setup(None, trainer)
    assert str(trainer.ckpt_path) == str(checkpoints_dir / "mymodel2--last")

    # Finally succeed
    logger = nl.NeMoLogger(
        name="default",
        dir=str(tmp_path) + "/test_resume",
        version="version_0",
        use_datetime_version=False,
    )
    logger.setup(trainer)
    # Remove the newer checkpoint so resume falls back to the older one.
    (checkpoints_dir / "mymodel2--last").rmdir()
    nl.AutoResume(
        resume_if_exists=True,
    ).setup(None, trainer)
    checkpoint = checkpoints_dir / "mymodel--last"
    assert Path(trainer.ckpt_path).resolve() == checkpoint.resolve()

    trainer = nl.Trainer(accelerator="cpu", logger=False)
    # Check that model loads from `dirpath` and not <log_dir>/checkpoints
    dirpath_log_dir = Path(tmp_path / "test_resume" / "dirpath_test" / "logs")
    dirpath_log_dir.mkdir(parents=True)
    dirpath_checkpoint_dir = Path(tmp_path / "test_resume" / "dirpath_test" / "ckpts")
    dirpath_checkpoint = Path(dirpath_checkpoint_dir / "mymodel--last")
    dirpath_checkpoint.mkdir(parents=True)
    logger = nl.NeMoLogger(
        name="default",
        explicit_log_dir=dirpath_log_dir,
    )
    logger.setup(trainer)
    nl.AutoResume(
        resume_if_exists=True,
        dirpath=str(dirpath_checkpoint_dir),
    ).setup(None, trainer)
    assert Path(trainer.ckpt_path).resolve() == dirpath_checkpoint.resolve()

0 comments on commit 75d1c67

Please sign in to comment.