-
Notifications
You must be signed in to change notification settings - Fork 0
/
process_text.py
107 lines (92 loc) · 3.53 KB
/
process_text.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from typing import Dict
from argparse import ArgumentParser
from pathlib import Path
import numpy as np
import logging
from tqdm import tqdm
import json
from math import floor
def build_vocab(tsv_folder: Path) -> Dict[str, int]:
    """Build a character-level vocabulary from every .tsv file in a folder.

    Each transcript line is tab-separated as ``start  end  word`` (times in
    seconds); a two-field line (no word) is treated as a single-space "word".
    All characters of the lower-cased words become vocabulary entries.

    Args:
        tsv_folder: Directory containing ``*.tsv`` transcript files.

    Returns:
        Mapping from symbol to integer index, with 'PAD' fixed at 0 and
        'UNK' fixed at 1.
    """
    vocab = {'PAD': 0, 'UNK': 1}
    unique_symbols = set()
    for tsv in tqdm(tsv_folder.glob('*.tsv')):
        with open(tsv, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split('\t')
                if len(parts) == 3:
                    _, _, word = parts
                elif len(parts) == 2:
                    # Timing-only line: treat it as a silence/space token.
                    word = ' '
                else:
                    logging.warning(f'Bad line: {line} in file {tsv}')
                    continue
                unique_symbols.update(word.lower())
    # Sort before assigning indices: set iteration order depends on string
    # hash randomization, so the original unsorted loop produced a different
    # vocab on every interpreter run, breaking reproducibility of saved
    # vocab files and embeddings.
    for s in sorted(unique_symbols):
        vocab[s] = len(vocab)
    return vocab
def process_tsv(tsv: Path, vocab: Dict[str, int], fps: float = 30.):
    """Convert a TSV transcript into a per-frame array of symbol indexes.

    Each line holds a tab-separated ``start  end  word`` triple (times in
    seconds); a two-field line is treated as a single-space "word".  Every
    word's frame span is divided evenly between its characters, and each
    frame is labeled with the vocab index of the character active there.
    Frames not covered by any word stay 0 ('PAD').

    Args:
        tsv: Path to the transcript file.
        vocab: Symbol-to-index mapping ('PAD' = 0, 'UNK' = 1).
        fps: Frame rate of the audio features.

    Returns:
        1-D numpy array of length ``floor(fps * max_end_time)`` holding one
        vocab index per frame, or a single-element zero (PAD) array when
        the file contains no usable lines.
    """
    starts, ends, words = [], [], []
    with open(tsv, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split('\t')
            if len(parts) == 3:
                start, end, word = parts
            elif len(parts) == 2:
                start, end = parts
                word = ' '
            else:
                logging.warning(f'Bad line: {line} in file {tsv}')
                continue
            starts.append(float(start))
            ends.append(float(end))
            words.append(word.lower())
    if not ends:
        # Empty transcript: return one PAD frame so callers get a valid,
        # indexable array.
        return np.zeros(1)
    # max(ends), not ends[-1]: lines are not guaranteed to appear in
    # temporal order, and an out-of-order final line would truncate the
    # array and make later slice assignments silently drop frames.
    indexes = np.zeros(floor(fps * max(ends)))
    for start, end, word in zip(starts, ends, words):
        word_start = floor(fps * start)
        word_end = floor(fps * end)
        word_len = word_end - word_start
        for i, s in enumerate(word):
            symbol_start = word_start + floor(i * word_len / len(word))
            symbol_end = min(word_start + floor((i + 1) * word_len / len(word)), word_end)
            # Scalar broadcast instead of allocating a throwaway np.ones();
            # unknown symbols fall back to 'UNK'.
            indexes[symbol_start:symbol_end] = vocab.get(s, vocab['UNK'])
    return indexes
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    arg_parser = ArgumentParser()
    arg_parser.add_argument('--src', help='Path to tsv folder')
    arg_parser.add_argument('--dst', help='Path to store results')
    arg_parser.add_argument('--vocab', help='Path to store or load vocab')
    arg_parser.add_argument('--fps', type=float, help='Audio features framerate', default=30.)
    args = arg_parser.parse_args()

    # Build the vocab from the source folder on first run; reuse the saved
    # JSON on subsequent runs so indices stay stable across invocations.
    vocab_path = Path(args.vocab)
    src_path = Path(args.src)
    if not vocab_path.exists():
        logging.info('Building new vocab...')
        vocab = build_vocab(src_path)
        with open(vocab_path, 'w') as f:
            json.dump(vocab, f, indent=4)
    else:
        logging.info('Loading vocab from file...')
        with open(vocab_path, 'r') as f:
            vocab = json.load(f)

    dst_path = Path(args.dst)
    # parents=True + exist_ok=True: the original bare mkdir() raised when
    # intermediate directories were missing, and its exists() pre-check was
    # racy if the directory appeared between check and creation.
    dst_path.mkdir(parents=True, exist_ok=True)

    # One-hot embedding table: row 0 ('PAD') is all zeros; every other
    # symbol k maps to the unit vector e_{k-1} of length len(vocab) - 1.
    weights = np.identity(len(vocab) - 1)
    weights = np.concatenate([np.zeros((1, len(weights))), weights], axis=0)

    for tsv in tqdm(src_path.glob('*.tsv')):
        indexes = process_tsv(tsv, vocab, args.fps)
        embeddings = weights[indexes.astype(int)]
        dst_sample = dst_path / tsv.name.replace('.tsv', '.npy')
        np.save(dst_sample, embeddings)