diff --git a/nemo/lightning/nemo_logger.py b/nemo/lightning/nemo_logger.py
index 5ed783fdbefe..8c0508676f14 100644
--- a/nemo/lightning/nemo_logger.py
+++ b/nemo/lightning/nemo_logger.py
@@ -73,30 +73,45 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool =
         logging.rank = self.global_rank
 
         if self.explicit_log_dir and isinstance(trainer, pl.Trainer):  # If explicit log_dir was passed, short circuit
-            return check_explicit_log_dir(trainer, self.explicit_log_dir, self.dir, self.name, self.version)
-
-        # Default dir to ./nemo_experiments if None was passed
-        _dir = self.dir
-        if self.dir is None:
-            _dir = str(Path.cwd() / 'nemo_experiments')
-
-        if not self.name:
-            self.name = "default"
-
-        version = self.version or os.environ.get(NEMO_ENV_VARNAME_VERSION, None)
-        if is_global_rank_zero():
-            if self.use_datetime_version:
-                version = time.strftime('%Y-%m-%d_%H-%M-%S')
-        if resume_if_exists:
-            logging.warning(
-                "No version folders would be created under the log folder as 'resume_if_exists' is enabled."
-            )
-            version = None
-        if version:
+            if trainer.logger is not None and not self.update_logger_directory:
+                logging.warning(
+                    f"nemo logger received explicit_log_dir: {self.explicit_log_dir} and the pytorch lightning trainer "
+                    f"that was passed to nemo_logger contains a logger, but update_logger_directory is False. This means "
+                    f"that the trainer's logger directory may not match the explicit_log_dir."
+                )
+            if self.dir or self.version:
+                logging.error(
+                    f"nemo logger received explicit_log_dir: {self.explicit_log_dir} and at least one of dir: {self.dir}, "
+                    f"or version: {self.version}. Please note that dir, name, and version will be ignored."
+                )
+            if is_global_rank_zero() and Path(self.explicit_log_dir).exists():
+                logging.warning(f"NeMoLogger is logging to {self.explicit_log_dir}, but it already exists.")
+            log_dir, _dir, self.name, version = Path(self.explicit_log_dir), str(self.explicit_log_dir), "", ""
+
+        else:
+            # Default dir to ./nemo_experiments if None was passed
+            _dir = self.dir
+            if self.dir is None:
+                _dir = str(Path.cwd() / 'nemo_experiments')
+
+            if not self.name:
+                self.name = "default"
+
+            version = self.version or os.environ.get(NEMO_ENV_VARNAME_VERSION, None)
             if is_global_rank_zero():
-                os.environ[NEMO_ENV_VARNAME_VERSION] = version
+                if self.use_datetime_version:
+                    version = time.strftime('%Y-%m-%d_%H-%M-%S')
+            if resume_if_exists:
+                logging.warning(
+                    "No version folders would be created under the log folder as 'resume_if_exists' is enabled."
+                )
+                version = None
+            if version:
+                if is_global_rank_zero():
+                    os.environ[NEMO_ENV_VARNAME_VERSION] = version
+
+            log_dir = Path(_dir) / Path(str(self.name)) / Path("" if version is None else str(version))
-        log_dir = Path(_dir) / Path(str(self.name)) / Path("" if version is None else str(version))
 
         # update app_state with log_dir, exp_dir, etc
         app_state = AppState()
         app_state.log_dir = log_dir
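
For reference, a minimal usage sketch of the new explicit_log_dir short circuit, mirroring test_explicit_log_dir in the test diff below. The nl.Trainer construction is borrowed from test_resume further down; the snippet is illustrative only and not part of the patch.

# Illustrative sketch only, not part of the patch: with explicit_log_dir set,
# dir/name/version are ignored and the directory is used verbatim; setup()
# returns the updated AppState.
from nemo import lightning as nl

trainer = nl.Trainer(accelerator="cpu", logger=False)
logger = nl.NeMoLogger(name="test", explicit_log_dir="explicit_test_dir")
app_state = logger.setup(trainer)
assert str(app_state.log_dir) == "explicit_test_dir"     # used as-is
assert app_state.name == "" and app_state.version == ""  # name/version are dropped
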
diff --git a/tests/lightning/test_nemo_logger.py b/tests/lightning/test_nemo_logger.py
index 0dd49838d9e4..0b9e37c46752 100644
--- a/tests/lightning/test_nemo_logger.py
+++ b/tests/lightning/test_nemo_logger.py
@@ -1,3 +1,6 @@
+import os
+import time
+from pathlib import Path
 from unittest.mock import patch
 
 import pytest
@@ -5,6 +8,8 @@
 from pytorch_lightning.loggers import WandbLogger
 
 from nemo import lightning as nl
+from nemo.constants import NEMO_ENV_VARNAME_VERSION
+from nemo.utils.exp_manager import NotFoundError
 
 
 class TestNeMoLogger:
@@ -31,9 +36,18 @@ def test_explicit_log_dir(self, trainer):
         explicit_dir = "explicit_test_dir"
         logger = nl.NeMoLogger(name="test", explicit_log_dir=explicit_dir)
 
-        with patch("nemo.utils.exp_manager.check_explicit_log_dir") as mock_check:
-            logger.setup(trainer)
-            mock_check.assert_called_once_with(trainer, explicit_dir, None, "test", None)
+        app_state = logger.setup(trainer)
+        assert str(app_state.log_dir) == "explicit_test_dir"
+        assert app_state.name == ""  ## name should be ignored when explicit_log_dir is passed in
+        assert app_state.version == ""
+
+    def test_default_log_dir(self, trainer):
+
+        if os.environ.get(NEMO_ENV_VARNAME_VERSION, None) is not None:
+            del os.environ[NEMO_ENV_VARNAME_VERSION]
+        logger = nl.NeMoLogger(use_datetime_version=False)
+        app_state = logger.setup(trainer)
+        assert app_state.log_dir == Path(Path.cwd() / "nemo_experiments" / "default")
 
     def test_custom_version(self, trainer):
         custom_version = "v1.0"
@@ -58,3 +72,93 @@ def test_model_checkpoint_setup(self, trainer):
         ptl_ckpt = next(cb for cb in trainer.callbacks if isinstance(cb, PTLModelCheckpoint))
         assert str(ptl_ckpt.dirpath).endswith("test_ckpt")
         assert ptl_ckpt.filename == "test-{epoch:02d}-{val_loss:.2f}"
+
+    def test_resume(self, trainer, tmp_path):
+        """Tests the resume capabilities of NeMoLogger + AutoResume"""
+
+        if os.environ.get(NEMO_ENV_VARNAME_VERSION, None) is not None:
+            del os.environ[NEMO_ENV_VARNAME_VERSION]
+
+        # Error because explicit_log_dir does not exist
+        with pytest.raises(NotFoundError):
+            nl.AutoResume(
+                dirpath=str(tmp_path / "test_resume"),
+                resume_if_exists=True,
+            ).setup(model=None, trainer=trainer)
+
+        # Error because checkpoints folder does not exist
+        with pytest.raises(NotFoundError):
+            nl.AutoResume(
+                dirpath=str(tmp_path / "test_resume" / "does_not_exist"),
+                path="does_not_exist",
+                resume_if_exists=True,
+            ).setup(None, trainer)
+
+        # No error because we tell AutoResume to ignore the NotFoundError
+        nl.AutoResume(
+            dirpath=str(tmp_path / "test_resume" / "does_not_exist"),
+            resume_if_exists=True,
+            resume_ignore_no_checkpoint=True,
+        ).setup(None, trainer)
+
+        path = Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints").mkdir(parents=True)
+        # Error because checkpoints do not exist in folder
+        with pytest.raises(NotFoundError):
+            nl.AutoResume(
+                dirpath=path,
+                resume_if_exists=True,
+            ).setup(None, trainer)
+
+        Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--end").mkdir()
+        # Error because *end.ckpt is in folder indicating that training has already finished
+        with pytest.raises(ValueError):
+            nl.AutoResume(
+                dirpath=Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints"),
+                resume_if_exists=True,
+            ).setup(None, trainer)
+
+        ## if there are multiple "-last" checkpoints, choose the most recent one
+        Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--end").rmdir()
+        Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last").mkdir()
+        time.sleep(1)  ## sleep for a second so the checkpoints are created at different times
+        Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last").mkdir()
+        nl.AutoResume(
+            dirpath=Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints"),
+            resume_if_exists=True,
+        ).setup(None, trainer)
+        assert str(trainer.ckpt_path) == str(
+            Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last")
+        )
+
+        # Finally succeed
+        logger = nl.NeMoLogger(
+            name="default",
+            dir=str(tmp_path) + "/test_resume",
+            version="version_0",
+            use_datetime_version=False,
+        )
+        logger.setup(trainer)
+        Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last").rmdir()
+        nl.AutoResume(
+            resume_if_exists=True,
+        ).setup(None, trainer)
+        checkpoint = Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last")
+        assert Path(trainer.ckpt_path).resolve() == checkpoint.resolve()
+
+        trainer = nl.Trainer(accelerator="cpu", logger=False)
+        # Check that model loads from `dirpath` and not /checkpoints
+        dirpath_log_dir = Path(tmp_path / "test_resume" / "dirpath_test" / "logs")
+        dirpath_log_dir.mkdir(parents=True)
+        dirpath_checkpoint_dir = Path(tmp_path / "test_resume" / "dirpath_test" / "ckpts")
+        dirpath_checkpoint = Path(dirpath_checkpoint_dir / "mymodel--last")
+        dirpath_checkpoint.mkdir(parents=True)
+        logger = nl.NeMoLogger(
+            name="default",
+            explicit_log_dir=dirpath_log_dir,
+        )
+        logger.setup(trainer)
+        nl.AutoResume(
+            resume_if_exists=True,
+            dirpath=str(dirpath_checkpoint_dir),
+        ).setup(None, trainer)
+        assert Path(trainer.ckpt_path).resolve() == dirpath_checkpoint.resolve()
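
To close the picture, a companion sketch of the resume flow that test_resume exercises: AutoResume resolves the most recent *--last checkpoint and records it on trainer.ckpt_path. The directory names below are hypothetical, the call order (NeMoLogger.setup, then AutoResume.setup(model, trainer)) follows the test, and the snippet is illustrative only, not part of the patch.

# Illustrative sketch only, not part of the patch; paths are hypothetical.
from nemo import lightning as nl

trainer = nl.Trainer(accelerator="cpu", logger=False)
nl.NeMoLogger(name="default", explicit_log_dir="results/my_run/logs").setup(trainer)
nl.AutoResume(
    resume_if_exists=True,
    resume_ignore_no_checkpoint=True,  # don't raise NotFoundError when nothing exists yet
    dirpath="results/my_run/ckpts",    # checkpoints are searched here, not under the log dir
).setup(None, trainer)
print(trainer.ckpt_path)  # newest *--last checkpoint under dirpath, if one exists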