diff --git a/examples/speaker_recognition/speaker_reco.py b/examples/speaker_recognition/speaker_reco.py index 2cf7fd0553d4..3259a6512b56 100644 --- a/examples/speaker_recognition/speaker_reco.py +++ b/examples/speaker_recognition/speaker_reco.py @@ -18,7 +18,6 @@ from pytorch_lightning import seed_everything from nemo.collections.asr.models import EncDecSpeakerLabelModel -from nemo.collections.common.callbacks import CallbackManager from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.exp_manager import exp_manager @@ -51,12 +50,6 @@ def main(cfg): logging.info(f'Hydra config: {cfg.pretty()}') trainer = pl.Trainer(**cfg.trainer) - - callbacks = ['LogEpochTimeCallback()', 'LogTrainValidLossCallback()'] - callback_mgr = CallbackManager() - callbacks = callback_mgr.add_callback(callbacks) - trainer.callbacks.extend(callbacks) - log_dir = exp_manager(trainer, cfg.get("exp_manager", None)) speaker_model = EncDecSpeakerLabelModel(cfg=cfg.model, trainer=trainer) trainer.fit(speaker_model) diff --git a/nemo/collections/asr/data/audio_to_label.py b/nemo/collections/asr/data/audio_to_label.py index c04649752b30..28929375c741 100644 --- a/nemo/collections/asr/data/audio_to_label.py +++ b/nemo/collections/asr/data/audio_to_label.py @@ -154,7 +154,7 @@ def fixed_seq_collate_fn(self, batch): tokens_lengths = torch.stack(tokens_lengths) return audio_signal, audio_lengths, tokens, tokens_lengths - + def sliced_seq_collate_fn(self, batch): """collate batch of audio sig, audio len, tokens, tokens len Args: diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py index 56ab2200959a..4dd3ea9ad1b0 100644 --- a/nemo/collections/asr/models/label_models.py +++ b/nemo/collections/asr/models/label_models.py @@ -219,7 +219,7 @@ def test_epoch_end(self, outputs): emb_shape = embs.shape[-1] embs = embs.view(-1, emb_shape).cpu().numpy() out_embeddings = {} - start_idx=0 + start_idx = 0 with open(self.test_manifest, 'r') as manifest: for idx, line in enumerate(manifest.readlines()): line = line.strip() @@ -229,7 +229,7 @@ def test_epoch_end(self, outputs): if uniq_name in out_embeddings: raise KeyError("Embeddings for label {} already present in emb dictionary".format(uniq_name)) num_slices = slices[idx] - end_idx = start_idx+num_slices + end_idx = start_idx + num_slices out_embeddings[uniq_name] = embs[start_idx:end_idx].mean(axis=0) start_idx = end_idx diff --git a/nemo/collections/common/callbacks/__init__.py b/nemo/collections/common/callbacks/__init__.py index 92393d04ec53..96e2bff8d4d6 100644 --- a/nemo/collections/common/callbacks/__init__.py +++ b/nemo/collections/common/callbacks/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from nemo.collections.common.callbacks.callbacks import ( - CallbackManager, + AVAILABLE_CALLBACKS, LogEpochTimeCallback, LogTrainValidLossCallback, ) diff --git a/nemo/collections/common/callbacks/callbacks.py b/nemo/collections/common/callbacks/callbacks.py index e890727bcc17..96a4cc21063d 100644 --- a/nemo/collections/common/callbacks/callbacks.py +++ b/nemo/collections/common/callbacks/callbacks.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import time -from typing import List, Union from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_only @@ -45,7 +44,7 @@ def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_id print_freq = trainer.row_log_interval total_batches = trainer.num_training_batches if 0 < print_freq < 1: - print_freq = int(total_batches*print_freq) + print_freq = int(total_batches * print_freq) if batch_idx % print_freq == 0: logging.info( "Epoch: {}/{} batch: {}/{} train_loss: {:.3f} train_acc: {:.2f}".format( @@ -64,3 +63,9 @@ def on_validation_epoch_end(self, trainer, pl_module): trainer.current_epoch + 1, trainer.max_epochs, pl_module.val_loss_mean, pl_module.accuracy ) ) + + +AVAILABLE_CALLBACKS = { + 'LogEpochTimeCallback': LogEpochTimeCallback(), + 'LogTrainValidLossCallback': LogTrainValidLossCallback(), +} diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index 489fdee0020b..d8811263cfd6 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -27,6 +27,7 @@ from omegaconf import DictConfig, OmegaConf from pytorch_lightning import LightningModule, Trainer +from nemo.collections.common import callbacks from nemo.core import optim from nemo.core.classes.common import Model from nemo.core.optim import prepare_lr_scheduler diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 4433dd664b83..baa14ef10438 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -29,6 +29,7 @@ from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from pytorch_lightning.utilities import rank_zero_only +from nemo.collections.common import callbacks from nemo.constants import NEMO_ENV_VARNAME_VERSION from nemo.utils import logging from nemo.utils.exceptions import NeMoBaseException @@ -36,28 +37,6 @@ from nemo.utils.lightning_logger_patch import add_filehandlers_to_pl_logger -class CallbackManager: - def __init__(self) -> None: - self.callbacks = set(['LogEpochTimeCallback()', 'LogTrainValidLossCallback()']) - - def get_callback(self, callback_name: str): - if callback_name in self.callbacks: - return eval(callback_name) - else: - raise NameError("Provided Callback name is not part of nemo Callback system") - - def add_callback(self, callback_names: Union[str, List]): - if type(callback_names) is str: - callback_names = callback_names.split(',') - - callbacks = [] - for name in callback_names: - callbacks.append(self.get_callback(name)) - - return callbacks - - - class NotFoundError(NeMoBaseException): """ Raised when a file or folder is not found""" @@ -97,6 +76,7 @@ class ExpManagerConfig: create_checkpoint_callback: Optional[bool] = True # Additional exp_manager arguments files_to_copy: Optional[List[str]] = None + callbacks: Optional[str] = None def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictConfig, Dict]] = None) -> Path: @@ -216,6 +196,10 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo if cfg.create_checkpoint_callback: configure_checkpointing(trainer, log_dir, checkpoint_name) + # Add nemo callbacks + if cfg.callbacks: + add_callbacks(trainer, cfg.callbacks) + # Move files_to_copy to folder and add git information if present if cfg.files_to_copy: for _file in cfg.files_to_copy: @@ -577,3 +561,12 @@ def on_train_end(self, trainer, pl_module): trainer.configure_checkpoint_callback(checkpoint_callback) trainer.callbacks.append(checkpoint_callback) trainer.checkpoint_callback = checkpoint_callback + + +def add_callbacks(trainer: 'pytorch_lightning.Trainer', nemo_callbacks: Optional[List[str]]): + + for callback in nemo_callbacks: + if callback in callbacks.AVAILABLE_CALLBACKS: + trainer.callbacks.append(callbacks.AVAILABLE_CALLBACKS[callback]) + else: + raise NameError(" Request callback is not part of nemo callbacks please check callback name")