From 6d3e4cab7cc6a8acb9ef5fb34dab6f6c2c3776be Mon Sep 17 00:00:00 2001
From: Somshubra Majumdar
Date: Thu, 14 May 2020 14:32:13 -0700
Subject: [PATCH] Add WandB support to contextnet.py

Signed-off-by: smajumdar
---
 examples/asr/contextnet.py | 36 +++++++++++++++++++++++++++++++-----
 1 file changed, 31 insertions(+), 5 deletions(-)

diff --git a/examples/asr/contextnet.py b/examples/asr/contextnet.py
index 69d55b743f60..0aa2ae03ef30 100644
--- a/examples/asr/contextnet.py
+++ b/examples/asr/contextnet.py
@@ -1,9 +1,23 @@
-# Copyright (c) 2019 NVIDIA Corporation
+# Copyright (C) NVIDIA CORPORATION. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import argparse
 import copy
 import os
 from functools import partial
 
+import wandb
 from ruamel.yaml import YAML
 
 import nemo
@@ -44,13 +58,17 @@ def parse_args():
 
     # Create new args
     parser.add_argument("--exp_name", default="ContextNet", type=str)
+    parser.add_argument("--project", default=None, type=str)
     parser.add_argument("--beta1", default=0.95, type=float)
     parser.add_argument("--beta2", default=0.5, type=float)
     parser.add_argument("--warmup_steps", default=1000, type=int)
-    parser.add_argument('--min_lr', default=1e-3, type=float)
+    parser.add_argument("--warmup_ratio", default=None, type=float)
+    parser.add_argument('--min_lr', default=1e-5, type=float)
     parser.add_argument("--load_dir", default=None, type=str)
     parser.add_argument("--synced_bn", action='store_true', help="Use synchronized batch norm")
     parser.add_argument("--synced_bn_groupsize", default=0, type=int)
+    parser.add_argument("--update_freq", default=50, type=int, help="Metrics update freq")
+    parser.add_argument("--eval_freq", default=1000, type=int, help="Evaluation frequency")
    parser.add_argument('--kernel_size_factor', default=1.0, type=float)
 
     args = parser.parse_args()
@@ -96,8 +114,7 @@ def create_all_dags(args, neural_factory):
         labels=vocab,
         batch_size=args.batch_size,
         num_workers=cpu_per_traindl,
-        **train_dl_params,
-        # normalize_transcripts=False
+        **train_dl_params
     )
 
     N = len(data_layer_train)
@@ -191,6 +208,7 @@ def create_all_dags(args, neural_factory):
         print_func=partial(monitor_asr_train_progress, labels=vocab),
         get_tb_values=lambda x: [["loss", x[0]]],
         tb_writer=neural_factory.tb_writer,
+        step_freq=args.update_freq,
     )
 
     callbacks = [train_callback]
@@ -202,6 +220,14 @@ def create_all_dags(args, neural_factory):
 
         callbacks.append(chpt_callback)
 
+    # Log training metrics to wandb
+    if args.project is not None:
+        wand_callback = nemo.core.WandbCallback(train_tensors=[loss_t],
+            wandb_name=args.exp_name, wandb_project=args.project,
+            update_freq=args.update_freq,
+            args=args)
+        callbacks.append(wand_callback)
+
     # assemble eval DAGs
     for i, eval_dl in enumerate(data_layers_eval):
         (audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e,) = eval_dl()
@@ -272,7 +298,7 @@ def main():
         tensors_to_optimize=[train_loss],
         callbacks=callbacks,
         lr_policy=CosineAnnealing(
-            args.num_epochs * steps_per_epoch, warmup_steps=args.warmup_steps, min_lr=args.min_lr
+            args.num_epochs * steps_per_epoch, warmup_steps=args.warmup_steps, warmup_ratio=args.warmup_ratio, min_lr=args.min_lr
         ),
         optimizer=args.optimizer,
         optimization_params={
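Usage sketch: with this patch, W&B logging is opt-in and only activates when --project is passed; --update_freq controls how often training metrics are pushed. A possible invocation is shown below, where the project name and the --warmup_ratio value are illustrative placeholders and the dataset/model-config arguments the script already requires are omitted:

    python examples/asr/contextnet.py \
        --exp_name=ContextNet \
        --project=<wandb-project-name> \
        --update_freq=50 \
        --eval_freq=1000 \
        --warmup_ratio=0.05 \
        --min_lr=1e-5

Because the patch imports wandb at module level, the wandb package must be installed and authenticated (wandb login) even when --project is left unset.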