Add WandB support to contextnet.py
Signed-off-by: smajumdar <[email protected]>
titu1994 committed May 14, 2020
1 parent a8d7f4c commit 6d3e4ca
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions examples/asr/contextnet.py
@@ -1,9 +1,23 @@
-# Copyright (c) 2019 NVIDIA Corporation
+# Copyright (C) NVIDIA CORPORATION. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

import argparse
import copy
import os
from functools import partial

+import wandb
from ruamel.yaml import YAML

import nemo
@@ -44,13 +58,17 @@ def parse_args():

    # Create new args
    parser.add_argument("--exp_name", default="ContextNet", type=str)
+    parser.add_argument("--project", default=None, type=str)
    parser.add_argument("--beta1", default=0.95, type=float)
    parser.add_argument("--beta2", default=0.5, type=float)
    parser.add_argument("--warmup_steps", default=1000, type=int)
-    parser.add_argument('--min_lr', default=1e-3, type=float)
+    parser.add_argument("--warmup_ratio", default=None, type=float)
+    parser.add_argument('--min_lr', default=1e-5, type=float)
    parser.add_argument("--load_dir", default=None, type=str)
    parser.add_argument("--synced_bn", action='store_true', help="Use synchronized batch norm")
    parser.add_argument("--synced_bn_groupsize", default=0, type=int)
+    parser.add_argument("--update_freq", default=50, type=int, help="Metrics update freq")
+    parser.add_argument("--eval_freq", default=1000, type=int, help="Evaluation frequency")
    parser.add_argument('--kernel_size_factor', default=1.0, type=float)

    args = parser.parse_args()
@@ -96,8 +114,7 @@ def create_all_dags(args, neural_factory):
        labels=vocab,
        batch_size=args.batch_size,
        num_workers=cpu_per_traindl,
-        **train_dl_params,
-        # normalize_transcripts=False
+        **train_dl_params
    )

    N = len(data_layer_train)
@@ -191,6 +208,7 @@ def create_all_dags(args, neural_factory):
        print_func=partial(monitor_asr_train_progress, labels=vocab),
        get_tb_values=lambda x: [["loss", x[0]]],
        tb_writer=neural_factory.tb_writer,
+        step_freq=args.update_freq,
    )

    callbacks = [train_callback]
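
The step_freq wired into the logger above throttles console and TensorBoard loss logging to every args.update_freq steps. Conceptually the gating amounts to the following (a simplified sketch, not NeMo's actual callback internals; it assumes tb_writer exposes the standard SummaryWriter.add_scalar API):

from torch.utils.tensorboard import SummaryWriter

def log_train_loss(step, loss_value, tb_writer, step_freq=50):
    # Emit a console line and a TensorBoard point only every step_freq steps.
    if step % step_freq == 0:
        print(f"Step {step}: train loss = {loss_value:.4f}")
        tb_writer.add_scalar("loss", loss_value, step)

# Example: log_train_loss(100, 1.234, SummaryWriter("runs/contextnet"))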
@@ -202,6 +220,14 @@ def create_all_dags(args, neural_factory):

    callbacks.append(chpt_callback)

+    # Log training metrics to wandb
+    if args.project is not None:
+        wand_callback = nemo.core.WandbCallback(train_tensors=[loss_t],
+                                                wandb_name=args.exp_name, wandb_project=args.project,
+                                                update_freq=args.update_freq,
+                                                args=args)
+        callbacks.append(wand_callback)

    # assemble eval DAGs
    for i, eval_dl in enumerate(data_layers_eval):
        (audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e,) = eval_dl()
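
Back to the WandbCallback registered above: when --project is given, its effect is roughly the following raw wandb calls (an illustrative sketch, assuming the callback forwards wandb_name/wandb_project to wandb.init and logs the training loss every update_freq steps; the real logic lives in nemo.core):

import wandb

def setup_wandb(args):
    # One W&B run per training job; vars(args) records the hyperparameters.
    wandb.init(name=args.exp_name, project=args.project, config=vars(args))

def log_to_wandb(step, loss_value, update_freq=50):
    if step % update_freq == 0:
        wandb.log({"train_loss": loss_value}, step=step)

Since the callback is only appended when args.project is not None, runs launched without --project behave exactly as before.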
@@ -272,7 +298,7 @@ def main():
        tensors_to_optimize=[train_loss],
        callbacks=callbacks,
        lr_policy=CosineAnnealing(
-            args.num_epochs * steps_per_epoch, warmup_steps=args.warmup_steps, min_lr=args.min_lr
+            args.num_epochs * steps_per_epoch, warmup_steps=args.warmup_steps, warmup_ratio=args.warmup_ratio, min_lr=args.min_lr
        ),
        optimizer=args.optimizer,
        optimization_params={
…
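
The last hunk threads the new warmup_ratio argument through to CosineAnnealing: the learning rate warms up linearly, then cosine-anneals down to min_lr. A self-contained sketch of the resulting curve (my reading of NeMo's warmup policies; treating warmup_ratio as a fraction of total steps that overrides warmup_steps is an assumption worth checking against nemo.utils.lr_policies):

import math

def lr_at_step(step, total_steps, base_lr, min_lr=1e-5,
               warmup_steps=None, warmup_ratio=None):
    # If given, warmup_ratio defines warmup as a fraction of total_steps.
    if warmup_ratio is not None:
        warmup_steps = int(warmup_ratio * total_steps)
    warmup_steps = warmup_steps or 0
    if step < warmup_steps:
        # Linear warmup from 0 up to base_lr.
        return base_lr * step / max(1, warmup_steps)
    # Cosine annealing from base_lr down to min_lr over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))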
