[NeMo-UX] Add more NeMo Logger tests #9795

Merged · 12 commits · Jul 26, 2024
59 changes: 37 additions & 22 deletions nemo/lightning/nemo_logger.py
@@ -73,30 +73,45 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool =
         logging.rank = self.global_rank
 
         if self.explicit_log_dir and isinstance(trainer, pl.Trainer):  # If explicit log_dir was passed, short circuit
-            return check_explicit_log_dir(trainer, self.explicit_log_dir, self.dir, self.name, self.version)
-
-        # Default dir to ./nemo_experiments if None was passed
-        _dir = self.dir
-        if self.dir is None:
-            _dir = str(Path.cwd() / 'nemo_experiments')
-
-        if not self.name:
-            self.name = "default"
-
-        version = self.version or os.environ.get(NEMO_ENV_VARNAME_VERSION, None)
-        if is_global_rank_zero():
-            if self.use_datetime_version:
-                version = time.strftime('%Y-%m-%d_%H-%M-%S')
-        if resume_if_exists:
-            logging.warning(
-                "No version folders would be created under the log folder as 'resume_if_exists' is enabled."
-            )
-            version = None
-        if version:
-            if is_global_rank_zero():
-                os.environ[NEMO_ENV_VARNAME_VERSION] = version
-        log_dir = Path(_dir) / Path(str(self.name)) / Path("" if version is None else str(version))
+            if trainer.logger is not None and not self.update_logger_directory:
+                logging.warning(
+                    f"nemo logger received explicit_log_dir: {self.explicit_log_dir} and the pytorch lightning trainer "
+                    f"that was passed to nemo_logger contains a logger, but update_logger_directory is False. This means "
+                    f"that the trainer's logger directory may not match with the explicit_log_dir."
+                )
+            if self.dir or self.version:
+                logging.error(
+                    f"nemo logger received explicit_log_dir: {self.explicit_log_dir} and at least one of dir: {self.dir}, "
+                    f"or version: {self.version}. Please note that dir, name, and version will be ignored."
+                )
+            if is_global_rank_zero() and Path(self.explicit_log_dir).exists():
+                logging.warning(f"NeMoLogger is logging to {self.explicit_log_dir}, but it already exists.")
+            log_dir, _dir, self.name, version = Path(self.explicit_log_dir), str(self.explicit_log_dir), "", ""
+
+        else:
+            # Default dir to ./nemo_experiments if None was passed
+            _dir = self.dir
+            if self.dir is None:
+                _dir = str(Path.cwd() / 'nemo_experiments')
+
+            if not self.name:
+                self.name = "default"
+
+            version = self.version or os.environ.get(NEMO_ENV_VARNAME_VERSION, None)
+            if is_global_rank_zero():
+                if self.use_datetime_version:
+                    version = time.strftime('%Y-%m-%d_%H-%M-%S')
+            if resume_if_exists:
+                logging.warning(
+                    "No version folders would be created under the log folder as 'resume_if_exists' is enabled."
+                )
+                version = None
+            if version:
+                if is_global_rank_zero():
+                    os.environ[NEMO_ENV_VARNAME_VERSION] = version
+
+            log_dir = Path(_dir) / Path(str(self.name)) / Path("" if version is None else str(version))
 
         # update app_state with log_dir, exp_dir, etc
         app_state = AppState()
         app_state.log_dir = log_dir
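Taken together, the rewritten branch gives NeMoLogger a single directory-resolution rule: explicit_log_dir wins outright (with name and version blanked out), otherwise the path is assembled from dir/name/version with their defaults. A minimal standalone sketch of that rule, assuming nothing beyond the diff above — the function name and flat signature are ours, not NeMo's API:

    from pathlib import Path

    def resolve_log_dir(explicit_log_dir=None, dir=None, name=None, version=None):
        # Mirrors the new NeMoLogger.setup() branch: an explicit directory
        # short-circuits everything else; dir, name, and version are ignored.
        if explicit_log_dir:
            return Path(explicit_log_dir)
        base = Path(dir) if dir else Path.cwd() / "nemo_experiments"
        name = name or "default"
        # version is None e.g. when resume_if_exists suppresses version folders;
        # Path("") normalizes to "." and drops out of the join, as in the diff.
        return base / Path(str(name)) / Path("" if version is None else str(version))

    assert resolve_log_dir(explicit_log_dir="explicit_test_dir") == Path("explicit_test_dir")
    assert resolve_log_dir() == Path.cwd() / "nemo_experiments" / "default"
    assert resolve_log_dir(version="v1.0") == Path.cwd() / "nemo_experiments" / "default" / "v1.0"

The asserts restate what the new tests below check: an explicit directory is taken verbatim, and the default resolves to ./nemo_experiments/default.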
110 changes: 107 additions & 3 deletions tests/lightning/test_nemo_logger.py
@@ -1,10 +1,15 @@
+import os
+import time
+from pathlib import Path
 from unittest.mock import patch
 
 import pytest
 from pytorch_lightning.callbacks import ModelCheckpoint as PTLModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
 
 from nemo import lightning as nl
+from nemo.constants import NEMO_ENV_VARNAME_VERSION
+from nemo.utils.exp_manager import NotFoundError
 
 
 class TestNeMoLogger:
@@ -31,9 +36,18 @@ def test_explicit_log_dir(self, trainer):
         explicit_dir = "explicit_test_dir"
         logger = nl.NeMoLogger(name="test", explicit_log_dir=explicit_dir)
 
-        with patch("nemo.utils.exp_manager.check_explicit_log_dir") as mock_check:
-            logger.setup(trainer)
-            mock_check.assert_called_once_with(trainer, explicit_dir, None, "test", None)
+        app_state = logger.setup(trainer)
+        assert str(app_state.log_dir) == "explicit_test_dir"
+        assert app_state.name == ""  ## name should be ignored when explicit_log_dir is passed in
+        assert app_state.version == ""
+
+    def test_default_log_dir(self, trainer):
+
+        if os.environ.get(NEMO_ENV_VARNAME_VERSION, None) is not None:
+            del os.environ[NEMO_ENV_VARNAME_VERSION]
+        logger = nl.NeMoLogger(use_datetime_version=False)
+        app_state = logger.setup(trainer)
+        assert app_state.log_dir == Path(Path.cwd() / "nemo_experiments" / "default")
 
     def test_custom_version(self, trainer):
         custom_version = "v1.0"
@@ -58,3 +72,93 @@ def test_model_checkpoint_setup(self, trainer):
         ptl_ckpt = next(cb for cb in trainer.callbacks if isinstance(cb, PTLModelCheckpoint))
         assert str(ptl_ckpt.dirpath).endswith("test_ckpt")
         assert ptl_ckpt.filename == "test-{epoch:02d}-{val_loss:.2f}"
+
+    def test_resume(self, trainer, tmp_path):
+        """Tests the resume capabilities of NeMoLogger + AutoResume"""
+
+        if os.environ.get(NEMO_ENV_VARNAME_VERSION, None) is not None:
+            del os.environ[NEMO_ENV_VARNAME_VERSION]
+
+        # Error because dirpath does not exist
+        with pytest.raises(NotFoundError):
+            nl.AutoResume(
+                dirpath=str(tmp_path / "test_resume"),
+                resume_if_exists=True,
+            ).setup(model=None, trainer=trainer)
+
+        # Error because the checkpoints folder does not exist
+        with pytest.raises(NotFoundError):
+            nl.AutoResume(
+                dirpath=str(tmp_path / "test_resume" / "does_not_exist"),
+                path="does_not_exist",
+                resume_if_exists=True,
+            ).setup(None, trainer)
+
+        # No error because we tell AutoResume to ignore the NotFoundError
+        nl.AutoResume(
+            dirpath=str(tmp_path / "test_resume" / "does_not_exist"),
+            resume_if_exists=True,
+            resume_ignore_no_checkpoint=True,
+        ).setup(None, trainer)
+
+        path = Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints")
+        path.mkdir(parents=True)  # mkdir() returns None, so keep the Path in its own variable
+        # Error because no checkpoints exist in the folder
+        with pytest.raises(NotFoundError):
+            nl.AutoResume(
+                dirpath=path,
+                resume_if_exists=True,
+            ).setup(None, trainer)
+
+        Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--end").mkdir()
+        # Error because a *--end checkpoint is in the folder, indicating that training has already finished
+        with pytest.raises(ValueError):
+            nl.AutoResume(
+                dirpath=Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints"),
+                resume_if_exists=True,
+            ).setup(None, trainer)
+
+        ## if there are multiple "--last" checkpoints, choose the most recent one
+        Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--end").rmdir()
+        Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last").mkdir()
+        time.sleep(1)  ## sleep for a second so the checkpoints are created at different times
+        Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last").mkdir()
+        nl.AutoResume(
+            dirpath=Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints"),
+            resume_if_exists=True,
+        ).setup(None, trainer)
+        assert str(trainer.ckpt_path) == str(
+            Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last")
+        )
+
+        # Finally succeed
+        logger = nl.NeMoLogger(
+            name="default",
+            dir=str(tmp_path) + "/test_resume",
+            version="version_0",
+            use_datetime_version=False,
+        )
+        logger.setup(trainer)
+        Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last").rmdir()
+        nl.AutoResume(
+            resume_if_exists=True,
+        ).setup(None, trainer)
+        checkpoint = Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last")
+        assert Path(trainer.ckpt_path).resolve() == checkpoint.resolve()
+
+        trainer = nl.Trainer(accelerator="cpu", logger=False)
+        # Check that the model loads from `dirpath` and not <log_dir>/checkpoints
+        dirpath_log_dir = Path(tmp_path / "test_resume" / "dirpath_test" / "logs")
+        dirpath_log_dir.mkdir(parents=True)
+        dirpath_checkpoint_dir = Path(tmp_path / "test_resume" / "dirpath_test" / "ckpts")
+        dirpath_checkpoint = Path(dirpath_checkpoint_dir / "mymodel--last")
+        dirpath_checkpoint.mkdir(parents=True)
+        logger = nl.NeMoLogger(
+            name="default",
+            explicit_log_dir=dirpath_log_dir,
+        )
+        logger.setup(trainer)
+        nl.AutoResume(
+            resume_if_exists=True,
+            dirpath=str(dirpath_checkpoint_dir),
+        ).setup(None, trainer)
+        assert Path(trainer.ckpt_path).resolve() == dirpath_checkpoint.resolve()
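The multiple "--last" section of test_resume pins down a tie-breaking rule: when several *--last checkpoints exist, AutoResume resumes from the most recently modified one. A rough sketch of just that rule as the test exercises it — the helper name is ours, and NeMo's actual resolution logic handles more cases than this:

    from pathlib import Path

    def pick_last_checkpoint(checkpoint_dir: Path) -> Path:
        # Distributed checkpoints are directories, which is why the test creates
        # them with mkdir(); glob for dirs named "*--last" and take the newest
        # by modification time (hence the time.sleep(1) between mkdir calls).
        candidates = [p for p in Path(checkpoint_dir).glob("*--last") if p.is_dir()]
        if not candidates:
            raise FileNotFoundError(f"no *--last checkpoint in {checkpoint_dir}")
        return max(candidates, key=lambda p: p.stat().st_mtime)

Under this reading, mymodel2--last wins while both checkpoints exist, and after it is removed the next resume falls back to mymodel--last — exactly the two assertions on trainer.ckpt_path above.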