forked from breizhn/DTLN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_evaluation.py
125 lines (105 loc) · 4.46 KB
/
run_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# -*- coding: utf-8 -*-
"""
Script to process a folder of .wav files with a trained DTLN model.
This script supports subfolders and names the processed files the same as the
original. The model expects 16kHz audio .wav files. Files with other
sampling rates will be resampled. Stereo files will be downmixed to mono.
This script is intended to be used for baseline or comparison purposes.
Example call:
$python run_evaluation.py -i /name/of/input/folder \
-o /name/of/output/folder \
-m /name/of/the/model.h5
Author: Nils L. Westhausen ([email protected])
Version: 13.05.2020
This code is licensed under the terms of the MIT-license.
"""
import soundfile as sf
import librosa
import numpy as np
import os
import argparse
from DTLN_model import DTLN_model
def process_file(model, audio_file_name, out_file_name):
    '''
    Read one audio file, enhance it with the network and store the result
    as a .wav file.

    Parameters
    ----------
    model : Keras model
        Keras model, which accepts audio in the size (1,timesteps).
    audio_file_name : STRING
        Name and path of the input audio file.
    out_file_name : STRING
        Name and path of the target file.

    '''

    # librosa takes care of resampling to 16 kHz and of the mono downmix
    samples, sample_rate = librosa.core.load(
        audio_file_name, sr=16000, mono=True)
    # the model wants a leading batch axis and float32 samples
    batch = np.expand_dims(samples, axis=0).astype(np.float32)
    enhanced = model.predict_on_batch(batch)
    # drop the batch axis again and write with the rate librosa reported
    sf.write(out_file_name, np.squeeze(enhanced), sample_rate)
def process_folder(model, folder_name, new_folder_name):
    '''
    Function to find .wav files in the folder and subfolders of "folder_name",
    process each .wav file with the model and write it back to disk in the
    folder "new_folder_name". The structure of the original directory is
    preserved. The processed files will be saved with the same name as the
    original file.

    All target directories are created first; the files are processed
    afterwards, in the order in which os.walk found them.

    Parameters
    ----------
    model : Keras model
        Keras model, which accepts audio in the size (1,timesteps).
    folder_name : STRING
        Input folder with .wav files.
    new_folder_name : STRING
        Target folder for the processed files.

    '''

    # one (input path, output path, file name) entry per .wav file found
    jobs = []
    # walk through the directory
    for root, _dirs, files in os.walk(folder_name):
        # look for .wav files
        wav_files = [name for name in files if name.endswith(".wav")]
        if not wav_files:
            # only directories that contain .wav files are mirrored
            continue
        # Mirror the directory by its path relative to the input folder.
        # A plain string replace of folder_name inside root breaks when
        # folder_name ends with a path separator ("in/" + "sub" became
        # "outsub") or when its text occurs again deeper in the path.
        rel_dir = os.path.relpath(root, folder_name)
        if rel_dir == os.curdir:
            target_dir = new_folder_name
        else:
            target_dir = os.path.join(new_folder_name, rel_dir)
        # create the new directory if it does not exist yet
        os.makedirs(target_dir, exist_ok=True)
        for name in wav_files:
            jobs.append((os.path.join(root, name),
                         os.path.join(target_dir, name),
                         name))
    # iterate over all .wav files
    for in_path, out_path, name in jobs:
        # process each file with the model
        process_file(model, in_path, out_path)
        print(name + ' processed successfully!')
if __name__ == '__main__':
    # argument parser for running directly from the command line
    # All three options are required: without them the script would fail
    # later with an AttributeError/TypeError on None instead of a usage
    # message.
    parser = argparse.ArgumentParser(description='data evaluation')
    parser.add_argument('--in_folder', '-i', required=True,
                        help='folder with input files')
    parser.add_argument('--out_folder', '-o', required=True,
                        help='target folder for processed files')
    parser.add_argument('--model', '-m', required=True,
                        help='weights of the enhancement model in .h5 format')
    args = parser.parse_args()
    # determine type of model: a '_norm_' anywhere in the given model path
    # switches norm_stft on
    norm_stft = '_norm_' in args.model
    # create class instance
    modelClass = DTLN_model()
    # build the model in default configuration
    modelClass.build_DTLN_model(norm_stft=norm_stft)
    # load weights of the .h5 file
    modelClass.model.load_weights(args.model)
    # process the folder
    process_folder(modelClass.model, args.in_folder, args.out_folder)