-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference2.py
70 lines (58 loc) · 2.65 KB
/
inference2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import sys
import glob
import argparse
import torch
import audiosegment
import matplotlib.pyplot as plt
import numpy as np
from utils.audio import MelGen
from utils.plotting import plot_spectrogram_to_numpy
from utils.reconstruct import Reconstruct
from utils.constant import t_div
from utils.hparams import HParam
from model.model import MelNet
def parse_args(args):
    """Parse command-line options for inference.

    Args:
        args: list of argument strings (typically ``sys.argv[1:]``).

    Returns:
        argparse.Namespace with ``config``, ``infer_config``, ``timestep``,
        ``name`` and ``input`` attributes.
    """
    # Option table: (flags, keyword settings). Declared in one place so the
    # accepted interface is easy to scan; order matters for --help output.
    option_table = [
        (('-c', '--config'),
         dict(type=str, required=True,
              help="yaml file for configuration")),
        (('-p', '--infer_config'),
         dict(type=str, required=True,
              help="yaml file for inference configuration")),
        (('-t', '--timestep'),
         dict(type=int, default=240,
              help="timestep of mel-spectrogram to generate")),
        (('-n', '--name'),
         dict(type=str, default="result", required=False,
              help="Name for sample")),
        (('-i', '--input'),
         dict(type=str, default=None, required=False,
              help="Input for conditional generation, leave empty for unconditional")),
    ]
    parser = argparse.ArgumentParser()
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)
    return parser.parse_args(args)
if __name__ == '__main__':
    # Inference entry point: load configs, sample a mel-spectrogram from the
    # trained MelNet tiers, then save tensor, spectrogram image and audio.
    args = parse_args(sys.argv[1:])
    hp = HParam(args.config)
    infer_hp = HParam(args.infer_config)

    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would silently skip this validation.
    divisor = t_div[hp.model.tier]
    if args.timestep % divisor != 0:
        raise ValueError(
            "timestep should be divisible by %d, got %d" % (divisor, args.timestep))

    melgen = MelGen(hp)
    model = MelNet(hp, args, infer_hp).cuda()
    model.load_tiers()
    model.eval()

    # Autoregressive sampling; no gradients needed at inference time.
    with torch.no_grad():
        generated = model.sample(args.input)

    out_dir = 'temp'
    os.makedirs(out_dir, exist_ok=True)

    # Raw generated tensor, for later reuse without re-sampling.
    torch.save(generated, os.path.join(out_dir, args.name + '.pt'))

    # Rendered spectrogram image (plot_spectrogram_to_numpy returns CHW;
    # transpose to HWC for imsave).
    spectrogram = plot_spectrogram_to_numpy(generated[0].cpu().detach().numpy())
    plt.imsave(os.path.join(out_dir, args.name + '.png'),
               spectrogram.transpose((1, 2, 0)))

    # NOTE: an alternative waveform path via Reconstruct(hp).inverse(...) +
    # audiosegment export used to live here; reconstruction now goes through
    # MelGen directly.
    constructed_mel = generated[0].cpu().detach().numpy()
    audio = melgen.reconstruct_audio(constructed_mel)
    melgen.save_audio(os.path.join(out_dir, 'constructed_' + args.name + '.wav'), audio)