Rebase off of master; add new working prototype of loss callback
Signed-off-by: Jason <[email protected]>
blisc committed May 6, 2020
1 parent d219483 commit b16d356
Showing 8 changed files with 525 additions and 30 deletions.
298 changes: 298 additions & 0 deletions examples/asr/jasper_an4_debug.py
@@ -0,0 +1,298 @@
# Copyright (c) 2019 NVIDIA Corporation
import argparse
import math
import os
from functools import partial

from ruamel.yaml import YAML

import nemo
import nemo.collections.asr as nemo_asr
import nemo.utils.argparse as nm_argparse
from nemo.collections.asr.helpers import (
monitor_asr_train_progress,
post_process_predictions,
post_process_transcripts,
process_evaluation_batch,
process_evaluation_epoch,
word_error_rate,
)
from nemo.utils.lr_policies import CosineAnnealing

logging = nemo.logging


def create_dags(model_config_file, vocab, args, nf):

# Create a data_layer for training.
data_layer = nemo_asr.AudioToTextDataLayer.import_from_config(
model_config_file,
"AudioToTextDataLayer_train",
overwrite_params={"manifest_filepath": args.train_dataset, "batch_size": args.batch_size},
)

num_samples = len(data_layer)
steps_per_epoch = math.ceil(num_samples / (data_layer.batch_size * args.iter_per_step * nf.world_size))
total_steps = steps_per_epoch * args.num_epochs
logging.info("Train samples=", num_samples, "num_steps=", total_steps)

# # Create a data_layer for evaluation.
# data_layer_eval = nemo_asr.AudioToTextDataLayer.import_from_config(
# model_config_file, "AudioToTextDataLayer_eval", overwrite_params={"manifest_filepath": args.eval_datasets},
# )

# num_samples = len(data_layer_eval)
# logging.info(f"Eval samples={num_samples}")

# Instantiate data processor.
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor.import_from_config(
model_config_file, "AudioToMelSpectrogramPreprocessor"
)

# Instantiate JASPER encoder-decoder modules.
jasper_encoder = nemo_asr.JasperEncoder.import_from_config(model_config_file, "JasperEncoder")
jasper_decoder = nemo_asr.JasperDecoderForCTC.import_from_config(
model_config_file, "JasperDecoderForCTC", overwrite_params={"num_classes": len(vocab)}
)

# Instantiate losses.
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
greedy_decoder = nemo_asr.GreedyCTCDecoder()

# Create a training graph.
audio, audio_len, transcript, transcript_len = data_layer()
processed, processed_len = data_preprocessor(input_signal=audio, length=audio_len)
encoded, encoded_len = jasper_encoder(audio_signal=processed, length=processed_len)
log_probs = jasper_decoder(encoder_output=encoded)
predictions = greedy_decoder(log_probs=log_probs)
loss = ctc_loss(log_probs=log_probs, targets=transcript, input_length=encoded_len, target_length=transcript_len,)

# # Create an evaluation graph.
# audio_e, audio_len_e, transcript_e, transcript_len_e = data_layer_eval()
# processed_e, processed_len_e = data_preprocessor(input_signal=audio_e, length=audio_len_e)
# encoded_e, encoded_len_e = jasper_encoder(audio_signal=processed_e, length=processed_len_e)
# log_probs_e = jasper_decoder(encoder_output=encoded_e)
# predictions_e = greedy_decoder(log_probs=log_probs_e)
# loss_e = ctc_loss(
# log_probs=log_probs_e, targets=transcript_e, input_length=encoded_len_e, target_length=transcript_len_e,
# )
logging.info("Num of params in encoder: {0}".format(jasper_encoder.num_weights))

# Callbacks to print info to console and Tensorboard.
# train_callback = nemo.core.SimpleLossLoggerCallback(
# tensors=[loss, predictions, transcript, transcript_len],
# print_func=partial(monitor_asr_train_progress, labels=vocab),
# get_tb_values=lambda x: [["loss", x[0]]],
# tb_writer=nf.tb_writer,
# )

# loss.rename("test")
# train_callback = nemo.core.SimpleLossLogger(tensors_to_log=["test"])

train_callback = nemo.core.SimpleLossLogger()

# checkpointer_callback = nemo.core.CheckpointCallback(folder=nf.checkpoint_dir, step_freq=args.checkpoint_save_freq)

# eval_tensors = [loss_e, predictions_e, transcript_e, transcript_len_e]
# eval_callback = nemo.core.EvaluatorCallback(
# eval_tensors=eval_tensors,
# user_iter_callback=partial(process_evaluation_batch, labels=vocab),
# user_epochs_done_callback=process_evaluation_epoch,
# eval_step=args.eval_freq,
# tb_writer=nf.tb_writer,
# eval_at_start=not args.do_not_eval_at_start,
# )
# callbacks = [train_callback, checkpointer_callback, eval_callback]
callbacks = [train_callback]

# Return entities required by the actual training.
return (
loss,
# eval_tensors,
callbacks,
total_steps,
# log_probs_e,
# encoded_len_e,
)


def main():
parser = argparse.ArgumentParser(
parents=[nm_argparse.NemoArgParser()], description='AN4 ASR', conflict_handler='resolve',
)

# Overwrite default args
parser.add_argument("--train_dataset", type=str, help="training dataset path")
parser.add_argument("--eval_datasets", type=str, help="validation dataset path")

# Create new args
# parser.add_argument("--lm", default="./an4-lm.3gram.binary", type=str)
parser.add_argument("--batch_size", default=48, type=int, help="size of the training batch")
parser.add_argument("--lm", default=None, type=str)
parser.add_argument("--test_after_training", action='store_true')
parser.add_argument("--momentum", type=float)
parser.add_argument("--beta1", default=0.95, type=float)
parser.add_argument("--beta2", default=0.25, type=float)
parser.add_argument("--do_not_eval_at_start", action='store_true')
parser.set_defaults(
model_config="./configs/jasper_an4.yaml",
train_dataset="~/TestData/an4_dataset/an4_train.json",
eval_datasets="~/TestData/an4_dataset/an4_val.json",
work_dir="./tmp",
optimizer="novograd",
num_epochs=50,
lr=0.02,
weight_decay=0.005,
checkpoint_save_freq=1000,
eval_freq=100,
amp_opt_level="O1",
)

args = parser.parse_args()
betas = (args.beta1, args.beta2)

wer_thr = 0.20
beam_wer_thr = 0.15

nf = nemo.core.NeuralModuleFactory(
local_rank=args.local_rank,
files_to_copy=[__file__],
optimization_level=args.amp_opt_level,
random_seed=0,
log_dir=args.work_dir,
create_tb_writer=True,
cudnn_benchmark=args.cudnn_benchmark,
)
tb_writer = nf.tb_writer
checkpoint_dir = nf.checkpoint_dir

# Load model definition
yaml = YAML(typ="safe")
with open(args.model_config) as f:
jasper_params = yaml.load(f)
# Get vocabulary.
vocab = jasper_params['labels']

# (loss, eval_tensors, callbacks, total_steps, log_probs_e, encoded_len_e,) = create_dags(
# args.model_config, vocab, args, nf
# )

loss, callbacks, total_steps = create_dags(args.model_config, vocab, args, nf)

nf.train(
tensors_to_optimize=[loss],
callbacks=callbacks,
optimizer=args.optimizer,
lr_policy=CosineAnnealing(total_steps=total_steps, min_lr=args.lr / 100),
optimization_params={
"num_epochs": args.num_epochs,
"max_steps": args.max_steps,
"lr": args.lr,
"momentum": args.momentum,
"betas": betas,
"weight_decay": args.weight_decay,
"grad_norm_clip": None,
},
batches_per_step=args.iter_per_step,
amp_max_loss_scale=256.0,
# synced_batchnorm=(nf.global_rank is not None),
)

# if args.test_after_training:
# logging.info("Testing greedy and beam search with LM WER.")
# # Create BeamSearch NM
# if nf.world_size > 1 or args.lm is None:
# logging.warning("Skipping beam search WER as it does not work if doing distributed training.")
# else:
# beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
# vocab=vocab, beam_width=64, alpha=2.0, beta=1.5, lm_path=args.lm, num_cpus=max(os.cpu_count(), 1),
# )
# beam_predictions = beam_search_with_lm(log_probs=log_probs_e, log_probs_length=encoded_len_e)
# eval_tensors.append(beam_predictions)

# evaluated_tensors = nf.infer(eval_tensors)
# if nf.global_rank in [0, None]:
# greedy_hypotheses = post_process_predictions(evaluated_tensors[1], vocab)
# references = post_process_transcripts(evaluated_tensors[2], evaluated_tensors[3], vocab)
# wer = word_error_rate(hypotheses=greedy_hypotheses, references=references)
# logging.info("Greedy WER: {:.2f}%".format(wer * 100))
# if wer > wer_thr:
# nf.sync_all_processes(False)
# raise ValueError(f"Final eval greedy WER {wer * 100:.2f}% > :" f"than {wer_thr * 100:.2f}%")
# nf.sync_all_processes()

# if nf.world_size == 1 and args.lm is not None:
# beam_hypotheses = []
# # Over mini-batch
# for i in evaluated_tensors[-1]:
# # Over samples
# for j in i:
# beam_hypotheses.append(j[0][1])

# beam_wer = word_error_rate(hypotheses=beam_hypotheses, references=references)
# logging.info("Beam WER {:.2f}%".format(beam_wer * 100))
# assert beam_wer <= beam_wer_thr, "Final eval beam WER {:.2f}% > than {:.2f}%".format(
# beam_wer * 100, beam_wer_thr * 100
# )
# assert beam_wer <= wer, "Final eval beam WER > than the greedy WER."

# # Reload model weights and train for extra 10 epochs
# checkpointer_callback = nemo.core.CheckpointCallback(
# folder=checkpoint_dir, step_freq=args.checkpoint_save_freq, force_load=True,
# )

# # Distributed Data Parallel changes the underlying class so we need
# # to reinstantiate Encoder and Decoder
# args.num_epochs += 10
# previous_step_count = total_steps
# loss, eval_tensors, callbacks, total_steps, _, _ = create_dags(args.model_config, vocab, args, nf)

# nf.reset_trainer()
# nf.train(
# tensors_to_optimize=[loss],
# callbacks=callbacks,
# optimizer=args.optimizer,
# lr_policy=CosineAnnealing(warmup_steps=previous_step_count, total_steps=total_steps),
# optimization_params={
# "num_epochs": args.num_epochs,
# "lr": args.lr / 100,
# "momentum": args.momentum,
# "betas": betas,
# "weight_decay": args.weight_decay,
# "grad_norm_clip": None,
# },
# reset=True,
# amp_max_loss_scale=256.0,
# # synced_batchnorm=(nf.global_rank is not None),
# )

# evaluated_tensors = nf.infer(eval_tensors)
# if nf.global_rank in [0, None]:
# greedy_hypotheses = post_process_predictions(evaluated_tensors[1], vocab)
# references = post_process_transcripts(evaluated_tensors[2], evaluated_tensors[3], vocab)
# wer_new = word_error_rate(hypotheses=greedy_hypotheses, references=references)
# logging.info("New greedy WER: {:.2f}%".format(wer_new * 100))
# if wer_new > wer * 1.1:
# nf.sync_all_processes(False)
# raise ValueError(
# f"Fine tuning: new WER {wer_new * 100:.2f}% > than the " f"previous WER {wer * 100:.2f}%"
# )
# nf.sync_all_processes()

# # Open the log file and ensure that epochs is strictly increasing
# if nf._exp_manager.log_file:
# epochs = []
# with open(nf._exp_manager.log_file, "r") as log_file:
# line = log_file.readline()
# while line:
# index = line.find("Starting epoch")
# if index != -1:
# epochs.append(int(line[index + len("Starting epoch") :]))
# line = log_file.readline()
# for i, e in enumerate(epochs):
# if i != e:
# raise ValueError("Epochs from logfile was not understood")


if __name__ == "__main__":
main()
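Note on the callback prototype: the script above exercises the new loss callback by passing nemo.core.SimpleLossLogger() where the older SimpleLossLoggerCallback used to go. The callback classes themselves live in files of this commit that are not rendered on this page, so the sketch below is only a guess at their shape: it assumes a NeMoCallback-style hook that receives the final_loss value which actions.py now forwards through _update_callbacks (see the diff below). The class and method names here are illustrative, not the committed API.

# Illustrative sketch only: the real SimpleLossLogger / NeMoCallback definitions are
# part of this commit but not shown on this page; the hook name and signature below
# are assumptions based on how actions.py forwards final_loss to callbacks.
class SketchLossLogger:
    def __init__(self, step_freq=100):
        self.step_freq = step_freq  # how often, in optimizer steps, to print the loss
        self._step = 0

    def on_iteration_end(self, final_loss=None, registered_tensors=None):
        # Invoked after each optimizer step; final_loss is the summed loss tensor and
        # registered_tensors maps NmTensor unique names to the values produced this step.
        self._step += 1
        if final_loss is not None and self._step % self.step_freq == 0:
            print(f"step {self._step}: loss = {float(final_loss):.4f}")

Under that assumption, the script's callbacks = [train_callback] could just as well hold an instance such as SketchLossLogger(step_freq=100).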
28 changes: 15 additions & 13 deletions nemo/backends/pytorch/actions.py
@@ -20,8 +20,8 @@
from nemo.backends.pytorch.nm import DataLayerNM, TrainableNM
from nemo.backends.pytorch.optimizers import AdamW, Novograd, master_params
from nemo.core import DeploymentFormat, DeviceType, NeuralModule, NmTensor
-from nemo.core.callbacks import ActionCallback, EvaluatorCallback, SimpleLossLoggerCallback
-from nemo.core.neural_factory import Actions, OperationMode, Optimization
+from nemo.core.callbacks import ActionCallback, EvaluatorCallback, NeMoCallback, SimpleLossLoggerCallback
+from nemo.core.neural_factory import Actions, OperationMode, Optimization, TrainingState
from nemo.core.neural_types import *
from nemo.utils.helpers import get_checkpoint_from_dir

@@ -450,10 +450,10 @@ def __nm_graph_forward_pass(
if nm_tensor is None:
continue
t_name = nm_tensor.unique_name
-if t_name not in registered_tensors:
+if t_name not in registered_tensors or registered_tensors[t_name] is None:
registered_tensors[t_name] = t_tensor
else:
raise ValueError("A NMTensor was produced twice in " f"the same DAG. {t_name}")
raise ValueError(f"A NMTensor was produced twice in the same DAG. {t_name}")

@staticmethod
def pad_tensor(t: torch.Tensor, target_size: torch.Size):
@@ -1110,6 +1110,7 @@ def train(
gradient_predivide=False,
amp_max_loss_scale=2.0 ** 24,
):
+self._training_state = TrainingState()
# Analyse the arguments passed to train.
if tensors_to_optimize is not None and training_graph is not None:
raise ValueError("Cannot pass both `tensors_to_optimize` and `training_graph` to the train() function")
@@ -1204,7 +1205,7 @@ def train(
# callbacks setup
if callbacks is not None:
for callback in callbacks:
-if not isinstance(callback, ActionCallback):
+if not isinstance(callback, ActionCallback) and not isinstance(callback, NeMoCallback):
raise ValueError("A callback was received that was not a child of ActionCallback")
elif isinstance(callback, SimpleLossLoggerCallback):
if logging_callchain:
@@ -1407,20 +1408,20 @@ def train(
else:
tensors.append(d)

-registered_tensors = {
-    t.unique_name: d for t, d in zip(curr_call_chain[0][2].values(), tensors) if t is not None
-}
+for t, d in zip(curr_call_chain[0][2].values(), tensors):
+    if t is not None:
+        self.training_state.set_tensor(t, d)
disable_allreduce = batch_counter < (batches_per_step - 1)
self.__nm_graph_forward_pass(
-call_chain=curr_call_chain, registered_tensors=registered_tensors,
+call_chain=curr_call_chain, registered_tensors=self.training_state.tensor_dict,
)

curr_tensors_to_optimize = training_loop[self.step % len(training_loop)][1]
final_loss = 0
for tensor in curr_tensors_to_optimize:
if (
-torch.isnan(registered_tensors[tensor.unique_name]).any()
-or torch.isinf(registered_tensors[tensor.unique_name]).any()
+torch.isnan(self.training_state.tensor_dict[tensor.unique_name]).any()
+or torch.isinf(self.training_state.tensor_dict[tensor.unique_name]).any()
):
if (
(stop_on_nan_loss)
@@ -1436,7 +1437,7 @@
)
else:
logging.warning('Loss is NaN or inf, continuing training')
-final_loss += registered_tensors[tensor.unique_name]
+final_loss += self.training_state.tensor_dict[tensor.unique_name]

if self._optim_level in AmpOptimizations and self._optim_level != Optimization.mxprO0:
with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce) as scaled_loss:
@@ -1479,10 +1480,11 @@ def train(
batch_counter = 0
# Register iteration end with callbacks
self._update_callbacks(
-callbacks=callbacks, registered_tensors=registered_tensors,
+callbacks=callbacks, registered_tensors=self.training_state.tensor_dict, final_loss=final_loss
)
self._perform_on_iteration_end(callbacks=callbacks)
self.step += 1
+self.training_state.clear_dict()
# End of epoch for loop
# Register epochs end with callbacks
self._perform_on_epoch_end(callbacks=callbacks)
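The hunks above reference a new TrainingState imported from nemo.core.neural_factory; its definition sits in one of the changed files this page did not load. Based only on the call sites shown (set_tensor, tensor_dict, clear_dict, and the "registered_tensors[t_name] is None" check in the forward pass), a minimal sketch might look like the following; treat it as an assumption about the shape of the class, not the committed implementation.

# Minimal sketch inferred from the call sites in actions.py; the real TrainingState
# (added elsewhere in this commit) is not shown on this page.
class TrainingState:
    def __init__(self):
        # Maps an NmTensor's unique_name to the torch.Tensor computed during the
        # current training step.
        self.tensor_dict = {}

    def set_tensor(self, nm_tensor, value):
        # Record the value produced for nm_tensor on this step.
        self.tensor_dict[nm_tensor.unique_name] = value

    def clear_dict(self):
        # Reset per-step values while keeping the keys, so the
        # "registered_tensors[t_name] is None" check can tell "not yet computed
        # this step" apart from "produced twice in the same DAG".
        for name in self.tensor_dict:
            self.tensor_dict[name] = None

Holding the per-step tensors on a state object rather than in a local registered_tensors dict is what lets callbacks receive the full tensor dictionary (and the new final_loss argument) at the end of every iteration.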