Skip to content

Commit

Permalink
update mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Ji Chen committed Jul 17, 2020
1 parent 345a2b8 commit 79b653b
Showing 1 changed file with 22 additions and 19 deletions.
41 changes: 22 additions & 19 deletions examples/pipeline_wavernn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ def parse_args():
"--n-freq", default=80, type=int, help="the number of spectrogram bins to use",
)
parser.add_argument(
"--n-hidden-resblock", default=128, type=int, help="the number of hidden dimensions of resblock",
"--n-hidden-resblock",
default=128,
type=int,
help="the number of hidden dimensions of resblock",
)
parser.add_argument(
"--n-output-melresnet",
Expand All @@ -132,17 +135,17 @@ def parse_args():
"--n-fft", default=2048, type=int, help="the number of Fourier bins",
)
parser.add_argument(
"--loss",
default="waveform",
choices=["waveform", "mol"],
"--loss-fn",
default="crossentropy",
choices=["crossentropy", "mol"],
type=str,
help="the type of loss",
help="the type of loss function",
)
parser.add_argument(
"--seq-len-factor",
default=5,
type=int,
help="the length factor of input waveform, the length of input waveform = hop_length * seq_len_factor",
help="the length of each waveform to process per batch = hop_length * seq_len_factor",
)
parser.add_argument(
"--val-ratio",
Expand All @@ -151,14 +154,14 @@ def parse_args():
help="the ratio of waveforms for validation",
)
parser.add_argument(
"--file-path", default="/private/home/jimchen90/datasets", type=str, help="the path of audio files",
"--file-path", default="", type=str, help="the path of audio files",
)

args = parser.parse_args()
return args


def train_one_epoch(model, loss, criterion, optimizer, data_loader, device, epoch):
def train_one_epoch(model, loss_fn, criterion, optimizer, data_loader, device, epoch):

model.train()

Expand All @@ -179,7 +182,7 @@ def train_one_epoch(model, loss, criterion, optimizer, data_loader, device, epoc
output = model(waveform, specgram)
output, target = output.squeeze(1), target.squeeze(1)

if loss == "waveform":
if loss_fn == "crossentropy":
output = output.transpose(1, 2)
target = target.long()

Expand Down Expand Up @@ -219,7 +222,7 @@ def train_one_epoch(model, loss, criterion, optimizer, data_loader, device, epoc
metric()


def validate(model, loss, criterion, data_loader, device, epoch):
def validate(model, loss_fn, criterion, data_loader, device, epoch):

with torch.no_grad():

Expand All @@ -236,7 +239,7 @@ def validate(model, loss, criterion, data_loader, device, epoch):
output = model(waveform, specgram)
output, target = output.squeeze(1), target.squeeze(1)

if loss == "waveform":
if loss_fn == "crossentropy":
output = output.transpose(1, 2)
target = target.long()

Expand Down Expand Up @@ -308,19 +311,19 @@ def main(args):
**loader_validation_params,
)

n_classes = 2 ** args.n_bits if args.loss_fn == "crossentropy" else 30

model = _WaveRNN(
upsample_scales=args.upsample_scales,
n_bits=args.n_bits,
sample_rate=args.sample_rate,
n_classes=n_classes,
hop_length=args.hop_length,
n_res_block=args.n_res_block,
n_rnn=args.n_rnn,
n_fc=args.n_fc,
kernel_size=args.kernel_size,
n_freq=args.n_freq,
n_hidden_resblock=args.n_hidden_resblock,
n_output_melresnet=args.n_output_melresnet,
loss=args.loss,
n_hidden=args.n_hidden,
n_output=args.n_output,
)

model = torch.nn.DataParallel(model)
Expand All @@ -336,7 +339,7 @@ def main(args):

optimizer = Adam(model.parameters(), **optimizer_params)

criterion = nn.CrossEntropyLoss() if args.loss == "waveform" else MoLLoss()
criterion = nn.CrossEntropyLoss() if args.loss_fn == "crossentropy" else MoLLoss()

best_loss = 10.0

Expand Down Expand Up @@ -370,13 +373,13 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs):

train_one_epoch(
model, args.loss, criterion, optimizer, train_loader, devices[0], epoch,
model, args.loss_fn, criterion, optimizer, train_loader, devices[0], epoch,
)

if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:

sum_loss = validate(
model, args.loss, criterion, val_loader, devices[0], epoch,
model, args.loss_fn, criterion, val_loader, devices[0], epoch,
)

is_best = sum_loss < best_loss
Expand Down

0 comments on commit 79b653b

Please sign in to comment.