-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
118 lines (99 loc) · 4.8 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import argparse
import os
import warnings
from glob import glob
import numpy as np
import tensorflow as tf
from scipy.io import wavfile
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint
from tensorflow.keras.utils import to_categorical
from model import Conv1D, Conv2D, LSTM
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding fixed-size batches of raw waveforms.

    Each item is a tuple ``(X, Y)``:
      * ``X`` is a float32 array of shape ``(batch_size, int(sr * dt), 1)``.
      * ``Y`` is the float32 one-hot label matrix of shape
        ``(batch_size, n_classes)``.

    Args:
        wav_paths: Paths of the ``.wav`` files to read.
        labels: Integer class index for each entry of ``wav_paths``.
        sr: Sample rate; with ``dt`` it fixes the samples expected per clip.
        dt: Clip duration in seconds.
        n_classes: Number of classes used for the one-hot encoding.
        batch_size: Number of clips per batch.
        shuffle: Whether ``on_epoch_end`` shuffles the sample order.
    """

    def __init__(self, wav_paths, labels, sr, dt, n_classes,
                 batch_size=32, shuffle=True):
        self.wav_paths = wav_paths
        self.labels = labels
        self.sr = sr
        self.dt = dt
        self.n_classes = n_classes
        self.batch_size = batch_size
        # Honour the caller's choice (this was previously hard-coded to True).
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch.

        Trailing samples that do not fill a whole batch are dropped, because
        ``__getitem__`` always allocates exactly ``batch_size`` rows.
        """
        return len(self.wav_paths) // self.batch_size

    def __getitem__(self, index):
        """Read batch number ``index`` from disk and return ``(X, Y)``."""
        start = index * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        n_samples = int(self.sr * self.dt)
        X = np.empty((self.batch_size, n_samples, 1), dtype=np.float32)
        Y = np.empty((self.batch_size, self.n_classes), dtype=np.float32)
        for i, k in enumerate(batch_indexes):
            # The file's own sample rate is not checked against self.sr.
            _, wav = wavfile.read(self.wav_paths[k])
            # NOTE(review): assumes mono clips of exactly n_samples samples,
            # otherwise this assignment raises a shape error. Confirm that
            # preprocessing guarantees it.
            X[i] = wav.reshape(-1, 1)
            Y[i] = to_categorical(self.labels[k], num_classes=self.n_classes)
        return X, Y

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when ``self.shuffle`` is set.

        Called once from ``__init__`` and by Keras after every epoch.
        """
        self.indexes = np.arange(len(self.wav_paths))
        if self.shuffle:
            np.random.shuffle(self.indexes)
def _list_classes(src_root):
    """Return the sorted names of the class sub-directories of ``src_root``."""
    # Only directories count as classes. Stray files (README, .DS_Store, ...)
    # would otherwise inflate N_CLASSES and the label vocabulary.
    return sorted(name for name in os.listdir(src_root)
                  if os.path.isdir(os.path.join(src_root, name)))


def _collect_wav_paths(src_root):
    """Return every ``*.wav`` file found recursively below ``src_root``, sorted."""
    # Sorting makes the file order, and therefore the seeded train/validation
    # split, independent of the filesystem's directory listing order.
    pattern = os.path.join(src_root, '**', '*.wav')
    return sorted(p for p in glob(pattern, recursive=True) if os.path.isfile(p))


def train(args):
    """Train the selected audio-classification model on a folder of clips.

    ``args.src_root`` must contain one sub-directory per class. Each clip's
    label is the name of its parent directory. The clips are split 80/20 into
    training and validation sets (``random_state=0``). The model named by
    ``args.model_type`` is then fit for ``args.epochs`` epochs (30 when
    ``args`` has no ``epochs`` attribute).

    Outputs:
      * Per-epoch history is written to ``logs/<model_type>_history.csv``.
      * The model with the lowest ``val_loss`` is saved to
        ``models/<model_type>.h5``.

    Args:
        args: Namespace with ``src_root``, ``sample_rate``, ``delta_time``,
            ``batch_size``, ``model_type`` and optionally ``epochs``.

    Raises:
        ValueError: if ``model_type`` is unknown, ``src_root`` has no class
            directories or no ``.wav`` files, or there are fewer training
            samples than ``batch_size``.
    """
    src_root = args.src_root
    sr = args.sample_rate
    dt = args.delta_time
    batch_size = args.batch_size
    model_type = args.model_type
    # Optional, so callers whose args lack an `epochs` field keep 30 epochs.
    epochs = getattr(args, 'epochs', 30)

    # Map names to constructors and build only the requested model, instead of
    # instantiating every architecture up front.
    model_builders = {'conv1d': Conv1D,
                      'conv2d': Conv2D,
                      'lstm': LSTM}
    if model_type not in model_builders:
        raise ValueError('{} not an available model'.format(model_type))

    classes = _list_classes(src_root)
    if not classes:
        raise ValueError('No class sub-directories found in {}'.format(src_root))
    params = {'N_CLASSES': len(classes),
              'SR': sr,
              'DT': dt}

    wav_paths = _collect_wav_paths(src_root)
    if not wav_paths:
        raise ValueError('No .wav files found under {}'.format(src_root))

    le = LabelEncoder()
    le.fit(classes)
    labels = le.transform([os.path.basename(os.path.dirname(p))
                           for p in wav_paths])

    wav_train, wav_val, label_train, label_val = train_test_split(
        wav_paths, labels, test_size=0.2, random_state=0)
    if len(label_train) < batch_size:
        raise ValueError('Number of train samples must be >= batch_size')
    if len(label_val) < batch_size:
        # The generator only yields full batches, so validation would be empty.
        warnings.warn('Only {} validation samples for batch_size {}; '
                      'validation will produce no batches.'.format(
                          len(label_val), batch_size))
    if len(set(label_train)) != params['N_CLASSES']:
        warnings.warn('Found {}/{} classes in training data. Increase data size or change random_state.'.format(
            len(set(label_train)), params['N_CLASSES']))
    if len(set(label_val)) != params['N_CLASSES']:
        warnings.warn('Found {}/{} classes in validation data. Increase data size or change random_state.'.format(
            len(set(label_val)), params['N_CLASSES']))

    tg = DataGenerator(wav_train, label_train, sr, dt,
                       params['N_CLASSES'], batch_size=batch_size)
    # Validation order does not affect the reported metrics, so there is no
    # need to shuffle it.
    vg = DataGenerator(wav_val, label_val, sr, dt,
                       params['N_CLASSES'], batch_size=batch_size,
                       shuffle=False)

    # Create the output directories up front so that writing the history CSV
    # and the checkpoint does not fail on a fresh checkout.
    os.makedirs('logs', exist_ok=True)
    os.makedirs('models', exist_ok=True)
    csv_path = os.path.join('logs', '{}_history.csv'.format(model_type))
    checkpoint_path = os.path.join('models', '{}.h5'.format(model_type))

    model = model_builders[model_type](**params)
    cp = ModelCheckpoint(checkpoint_path, monitor='val_loss',
                         save_best_only=True, save_weights_only=False,
                         mode='auto', save_freq='epoch', verbose=1)
    csv_logger = CSVLogger(csv_path, append=False)
    model.fit(tg, validation_data=vg,
              epochs=epochs, verbose=1,
              callbacks=[csv_logger, cp])
if __name__ == '__main__':
    # Command-line entry point: parse the training options and run train().
    parser = argparse.ArgumentParser(description='Audio Classification Training')
    # `choices` rejects unknown model names at parse time and lists the valid ones.
    parser.add_argument('--model_type', type=str, default='lstm',
                        choices=['conv1d', 'conv2d', 'lstm'],
                        help='model to run. i.e. conv1d, conv2d, lstm')
    parser.add_argument('--src_root', type=str, default='clean',
                        help='directory of audio files in total duration')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='batch size')
    parser.add_argument('--delta_time', '-dt', type=float, default=1.0,
                        help='time in seconds to sample audio')
    parser.add_argument('--sample_rate', '-sr', type=int, default=16000,
                        help='sample rate of clean audio')
    # parse_known_args tolerates extra argv entries. This is presumably
    # deliberate, e.g. for notebook kernels; note that misspelled flags are
    # silently ignored rather than reported.
    args, _ = parser.parse_known_args()
    train(args)