Skip to content

Commit

Permalink
callback in expmanager
Browse files Browse the repository at this point in the history
Signed-off-by: nithinraok <[email protected]>
  • Loading branch information
nithinraok committed Sep 1, 2020
1 parent c4f2d97 commit d86880f
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 35 deletions.
7 changes: 0 additions & 7 deletions examples/speaker_recognition/speaker_reco.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from pytorch_lightning import seed_everything

from nemo.collections.asr.models import EncDecSpeakerLabelModel
from nemo.collections.common.callbacks import CallbackManager
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
Expand Down Expand Up @@ -51,12 +50,6 @@ def main(cfg):

logging.info(f'Hydra config: {cfg.pretty()}')
trainer = pl.Trainer(**cfg.trainer)

callbacks = ['LogEpochTimeCallback()', 'LogTrainValidLossCallback()']
callback_mgr = CallbackManager()
callbacks = callback_mgr.add_callback(callbacks)
trainer.callbacks.extend(callbacks)

log_dir = exp_manager(trainer, cfg.get("exp_manager", None))
speaker_model = EncDecSpeakerLabelModel(cfg=cfg.model, trainer=trainer)
trainer.fit(speaker_model)
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/data/audio_to_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def fixed_seq_collate_fn(self, batch):
tokens_lengths = torch.stack(tokens_lengths)

return audio_signal, audio_lengths, tokens, tokens_lengths

def sliced_seq_collate_fn(self, batch):
"""collate batch of audio sig, audio len, tokens, tokens len
Args:
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_epoch_end(self, outputs):
emb_shape = embs.shape[-1]
embs = embs.view(-1, emb_shape).cpu().numpy()
out_embeddings = {}
start_idx=0
start_idx = 0
with open(self.test_manifest, 'r') as manifest:
for idx, line in enumerate(manifest.readlines()):
line = line.strip()
Expand All @@ -229,7 +229,7 @@ def test_epoch_end(self, outputs):
if uniq_name in out_embeddings:
raise KeyError("Embeddings for label {} already present in emb dictionary".format(uniq_name))
num_slices = slices[idx]
end_idx = start_idx+num_slices
end_idx = start_idx + num_slices
out_embeddings[uniq_name] = embs[start_idx:end_idx].mean(axis=0)
start_idx = end_idx

Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/common/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from nemo.collections.common.callbacks.callbacks import (
CallbackManager,
AVAILABLE_CALLBACKS,
LogEpochTimeCallback,
LogTrainValidLossCallback,
)
9 changes: 7 additions & 2 deletions nemo/collections/common/callbacks/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import List, Union

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_only
Expand Down Expand Up @@ -45,7 +44,7 @@ def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_id
print_freq = trainer.row_log_interval
total_batches = trainer.num_training_batches
if 0 < print_freq < 1:
print_freq = int(total_batches*print_freq)
print_freq = int(total_batches * print_freq)
if batch_idx % print_freq == 0:
logging.info(
"Epoch: {}/{} batch: {}/{} train_loss: {:.3f} train_acc: {:.2f}".format(
Expand All @@ -64,3 +63,9 @@ def on_validation_epoch_end(self, trainer, pl_module):
trainer.current_epoch + 1, trainer.max_epochs, pl_module.val_loss_mean, pl_module.accuracy
)
)


AVAILABLE_CALLBACKS = {
'LogEpochTimeCallback': LogEpochTimeCallback(),
'LogTrainValidLossCallback': LogTrainValidLossCallback(),
}
1 change: 1 addition & 0 deletions nemo/core/classes/modelPT.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import LightningModule, Trainer

from nemo.collections.common import callbacks
from nemo.core import optim
from nemo.core.classes.common import Model
from nemo.core.optim import prepare_lr_scheduler
Expand Down
37 changes: 15 additions & 22 deletions nemo/utils/exp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,35 +29,14 @@
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.utilities import rank_zero_only

from nemo.collections.common import callbacks
from nemo.constants import NEMO_ENV_VARNAME_VERSION
from nemo.utils import logging
from nemo.utils.exceptions import NeMoBaseException
from nemo.utils.get_rank import is_global_rank_zero
from nemo.utils.lightning_logger_patch import add_filehandlers_to_pl_logger


class CallbackManager:
def __init__(self) -> None:
self.callbacks = set(['LogEpochTimeCallback()', 'LogTrainValidLossCallback()'])

def get_callback(self, callback_name: str):
if callback_name in self.callbacks:
return eval(callback_name)
else:
raise NameError("Provided Callback name is not part of nemo Callback system")

def add_callback(self, callback_names: Union[str, List]):
if type(callback_names) is str:
callback_names = callback_names.split(',')

callbacks = []
for name in callback_names:
callbacks.append(self.get_callback(name))

return callbacks



class NotFoundError(NeMoBaseException):
""" Raised when a file or folder is not found"""

Expand Down Expand Up @@ -97,6 +76,7 @@ class ExpManagerConfig:
create_checkpoint_callback: Optional[bool] = True
# Additional exp_manager arguments
files_to_copy: Optional[List[str]] = None
callbacks: Optional[str] = None


def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictConfig, Dict]] = None) -> Path:
Expand Down Expand Up @@ -216,6 +196,10 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
if cfg.create_checkpoint_callback:
configure_checkpointing(trainer, log_dir, checkpoint_name)

# Add nemo callbacks
if cfg.callbacks:
add_callbacks(trainer, cfg.callbacks)

# Move files_to_copy to folder and add git information if present
if cfg.files_to_copy:
for _file in cfg.files_to_copy:
Expand Down Expand Up @@ -577,3 +561,12 @@ def on_train_end(self, trainer, pl_module):
trainer.configure_checkpoint_callback(checkpoint_callback)
trainer.callbacks.append(checkpoint_callback)
trainer.checkpoint_callback = checkpoint_callback


def add_callbacks(trainer: 'pytorch_lightning.Trainer', nemo_callbacks: Optional[List[str]]):

for callback in nemo_callbacks:
if callback in callbacks.AVAILABLE_CALLBACKS:
trainer.callbacks.append(callbacks.AVAILABLE_CALLBACKS[callback])
else:
raise NameError(" Request callback is not part of nemo callbacks please check callback name")

0 comments on commit d86880f

Please sign in to comment.