-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathvisualize_alignment.py
95 lines (82 loc) · 3.94 KB
/
visualize_alignment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# coding: utf-8
"""
usage: visualize_alignment.py [options] <filename>
options:
--output-prefix=<prefix> output filename prefix
"""
from docopt import docopt
import numpy as np
import tensorflow as tf
from collections import namedtuple
import matplotlib
import os
matplotlib.use('Agg')
from matplotlib import pyplot as plt
class TrainingResult(
        namedtuple(
            "TrainingResult",
            "global_step id text predicted_mel ground_truth_mel alignments")):
    """Immutable record describing one sample read from a training result file.

    Fields, in positional order: ``global_step``, ``id``, ``text``,
    ``predicted_mel``, ``ground_truth_mel`` and ``alignments``.
    """
def read_training_result(filename):
    """Yield a TrainingResult for every sample stored in a TFRecord file.

    Each record of ``filename`` is parsed as a ``tf.train.Example`` that holds
    one batch. The features read are ``global_step``, ``batch_size``, ``id``,
    ``text``, ``predicted_mel``, ``ground_truth_mel``, ``mel_length``,
    ``mel_width``, ``alignment``, ``alignment_source_length`` and
    ``alignment_target_length``.

    Args:
        filename: path handed to ``tf.python_io.tf_record_iterator``.

    Yields:
        TrainingResult: one per sample. ``text`` is decoded as UTF-8, the mel
        fields are float32 arrays with ``mel_width`` columns, and
        ``alignments`` is a list with one transposed 2-D float32 array per
        entry of the ``alignment`` feature.
    """
    for serialized in tf.python_io.tf_record_iterator(filename):
        example = tf.train.Example()
        example.ParseFromString(serialized)
        features = example.features.feature

        # Scalars describing the whole batch.
        global_step = features['global_step'].int64_list.value[0]
        batch_size = features['batch_size'].int64_list.value[0]

        # Per-sample (or per-alignment-entry) raw values.
        sample_ids = features['id'].int64_list.value
        raw_texts = features['text'].bytes_list.value
        raw_predicted = features['predicted_mel'].bytes_list.value
        raw_ground_truth = features['ground_truth_mel'].bytes_list.value
        mel_lengths = features['mel_length'].int64_list.value
        mel_width = features['mel_width'].int64_list.value[0]
        raw_alignments = features['alignment'].bytes_list.value
        source_lengths = features['alignment_source_length'].int64_list.value
        target_lengths = features['alignment_target_length'].int64_list.value

        decoded_texts = (raw.decode('utf-8') for raw in raw_texts)

        # Every alignment buffer covers the full batch:
        # (batch_size, src_len, tgt_len).
        batched_alignments = []
        for raw, src_len, tgt_len in zip(raw_alignments, source_lengths, target_lengths):
            flat = np.frombuffer(raw, dtype=np.float32)
            batched_alignments.append(flat.reshape([batch_size, src_len, tgt_len]))

        # Regroup per sample: for sample i, take slice i of every buffer and
        # transpose it to (tgt_len, src_len).
        per_sample_alignments = [
            [batched[i].T for batched in batched_alignments]
            for i in range(batch_size)
        ]

        # The predicted mel infers its row count from the buffer size (-1);
        # mel_length only takes part in the pairing here. The ground truth mel
        # is reshaped to exactly mel_length rows.
        predicted = (
            np.frombuffer(raw, dtype=np.float32).reshape([-1, mel_width])
            for raw, _ in zip(raw_predicted, mel_lengths)
        )
        ground_truth = (
            np.frombuffer(raw, dtype=np.float32).reshape([n_frames, mel_width])
            for raw, n_frames in zip(raw_ground_truth, mel_lengths)
        )

        samples = zip(sample_ids, decoded_texts, per_sample_alignments,
                      predicted, ground_truth)
        for sample_id, sample_text, sample_alignments, pred_mel, gt_mel in samples:
            yield TrainingResult(
                global_step=global_step,
                id=sample_id,
                text=sample_text,
                predicted_mel=pred_mel,
                ground_truth_mel=gt_mel,
                alignments=sample_alignments,
            )
def save_alignment(alignments, text, _id, path, info=None):
    """Plot alignment matrices as stacked heatmaps and save them as a PNG.

    One subplot is drawn per entry of ``alignments``, top to bottom, titled
    "layer 1", "layer 2", ... in iteration order. Each subplot gets its own
    colorbar and a red horizontal line at ``y = len(text)``.

    Args:
        alignments: sized iterable of 2-D arrays. Rows are drawn along the y
            axis (labelled "Encoder timestep") and columns along the x axis
            (labelled "Decoder timestep").
        text: input text; its length positions the red line and its string
            form is shown in the figure title.
        _id: record identifier shown in the figure title.
        path: destination passed to ``Figure.savefig``.
        info: optional extra text appended below the x-axis label.
    """
    num_alignment = len(alignments)
    fig = plt.figure(figsize=(12, 16))
    try:
        for i, alignment in enumerate(alignments):
            ax = fig.add_subplot(num_alignment, 1, i + 1)
            im = ax.imshow(
                alignment,
                aspect='auto',
                origin='lower',
                interpolation='none')
            fig.colorbar(im, ax=ax)
            xlabel = 'Decoder timestep'
            if info is not None:
                xlabel += '\n\n' + info
            ax.set_xlabel(xlabel)
            ax.set_ylabel('Encoder timestep')
            ax.set_title("layer {}".format(i + 1))
            # Mark the length of the input text on the encoder axis.
            ax.hlines(len(text), xmin=0, xmax=alignment.shape[1], colors=['red'])
        fig.subplots_adjust(wspace=0.4, hspace=0.6)
        fig.suptitle(f"record ID: {_id}, input text: {text!s}")
        fig.savefig(path, format='png')
    finally:
        # Close this specific figure even when plotting or saving fails;
        # otherwise pyplot keeps it registered and figures pile up across
        # calls.
        plt.close(fig)
if __name__ == "__main__":
    # Command-line interface is defined by the module docstring.
    args = docopt(__doc__)
    filename = args["<filename>"]
    prefix = args["--output-prefix"] or "alignment_"

    # Images are written to the directory of the input file and named after
    # it (extension removed).
    output_base_filename, _ = os.path.splitext(os.path.basename(filename))
    output_dir = os.path.dirname(filename)

    for result in read_training_result(filename):
        # Build the file name with the record id first and join afterwards.
        # Calling str.format() on the joined path would treat any "{" or "}"
        # in the directory, input file name or prefix as a format field.
        output_filename = f"{prefix}{output_base_filename}_{result.id}.png"
        save_alignment(
            result.alignments,
            result.text,
            result.id,
            os.path.join(output_dir, output_filename),
        )