forked from mxmaxi007/Variable_Length_Emotion_Recognition
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMain.py
128 lines (100 loc) · 5.31 KB
/
Main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# encoding=UTF-8
import sys
import os
import math
import re
import time
import shutil
import numpy as np
import Preprocess.Spectrogram as Spectrogram
import Preprocess.Load_Data as Load_Data
import Preprocess.Normalization as Normalization
import Model.CNN_Const as CNN_Const
import Model.CNN_LSTM_Const as CNN_LSTM_Const
import Model.CNN_LSTM_Attention_Const as CNN_LSTM_Attention_Const
import Model.CNN_RNN_Var as CNN_RNN_Var
import Model.CNN_RNN_Const as CNN_RNN_Const
import Metrics.Accuracy as Accuracy
def _fit_normalization(x_train, output_dir, file_name):
    """Compute normalization statistics from ``x_train`` and persist them.

    The two vectors returned by ``Normalization.get_mean_variance`` are
    stacked into a single array, written to ``output_dir/file_name`` with
    ``np.save`` and then read back with ``np.load``.

    Args:
        x_train: Training inputs handed to ``Normalization.get_mean_variance``.
        output_dir: Existing directory that receives the ``.npy`` file.
        file_name: Name of the ``.npy`` file to write inside ``output_dir``.

    Returns:
        Tuple ``(mean_vec, std_vec)`` taken from the array loaded from disk.
    """
    mean_vec, std_vec = Normalization.get_mean_variance(x_train)
    para_path = os.path.join(output_dir, file_name)
    np.save(para_path, np.array([mean_vec, std_vec]))
    # Use the values exactly as they were stored in the file.
    normal_para = np.load(para_path)
    return normal_para[0], normal_para[1]


def main():
    """Command-line entry point: load spectrograms, train, then evaluate.

    Expects exactly six command-line arguments:
    ``wav_dir_path spectrogram_dir_path output_dir test_session
    classifer_type spectrogram_type``.

    ``spectrogram_type`` must be ``"Const"`` or ``"Var"``; it selects the
    loader, the held-out normalization routine, the CNN_RNN model module and
    the accuracy routine. Training only runs when ``classifer_type`` is
    ``"CNN_RNN"``.

    Side effects:
        * ``output_dir`` is deleted and recreated (only after the arguments
          have been validated).
        * ``CUDA_VISIBLE_DEVICES`` is set to ``"3"`` unless the environment
          already defines it.
        * Normalization parameters are written into ``output_dir``.

    Raises:
        SystemExit: With status 2 when the argument count is wrong or
            ``spectrogram_type`` is not recognized.
        ValueError: When ``test_session`` is not an integer literal.
    """
    if len(sys.argv) != 7:
        print('Usage: python3 ' + sys.argv[0] + ' wav_dir_path spectrogram_dir_path output_dir test_session classifer_type spectrogram_type\n',
              file=sys.stderr)
        sys.exit(2)

    start = time.time()

    wav_dir_path = sys.argv[1]  # only needed by the disabled preprocessing step below
    spectrogram_dir_path = sys.argv[2]
    output_dir = sys.argv[3]
    test_session = int(sys.argv[4])
    classifer_type = sys.argv[5]
    spectrogram_type = sys.argv[6]

    # Reject an unknown spectrogram type *before* touching the file system;
    # otherwise the output directory would be wiped and nothing produced.
    if spectrogram_type not in ("Const", "Var"):
        print("Unknown spectrogram_type '{}': expected 'Const' or 'Var'".format(spectrogram_type),
              file=sys.stderr)
        sys.exit(2)

    session_list = ["Ses01", "Ses02", "Ses03", "Ses04", "Ses05"]
    emo_dict = {0: "Neutral", 1: "Angry", 2: "Happy", 3: "Sad"}
    emo_num = 4

    # Default to GPU 3, but respect a device selection made by the caller
    # (shell export, job scheduler, ...).
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

    # Start from an empty output directory. makedirs also creates missing
    # parents and still raises if the old directory could not be removed.
    shutil.rmtree(output_dir, ignore_errors=True)
    os.makedirs(output_dir)

    # Preprocessing step, disabled in the original script:
    # Spectrogram.wav_preprocess(wav_dir_path, spectrogram_dir_path, spectrogram_type)

    # Pick the variant-specific pieces of the otherwise identical pipeline.
    if spectrogram_type == "Const":
        load_data = Load_Data.load_spectrogram_const
        para_file_name = "normal_para_const.npy"
        normalize_held_out = Normalization.normalize_list
        model = CNN_RNN_Const
        evaluate = Accuracy.accuracy_const
    else:
        load_data = Load_Data.load_spectrogram_var
        para_file_name = "normal_para_var.npy"
        normalize_held_out = Normalization.normalize
        model = CNN_RNN_Var
        evaluate = Accuracy.accuracy_var

    (x_train, y_train, x_validation, y_validation, x_test, y_test,
     weight_dict, sample_weight_list) = load_data(
        spectrogram_dir_path, session_list, test_session, emo_dict)

    x_train, y_train, sample_weight_list = Normalization.length_sort(
        x_train, y_train, sample_weight_list)

    # Statistics come from the training data only and are applied to all splits.
    mean_vec, std_vec = _fit_normalization(x_train, output_dir, para_file_name)
    x_train = Normalization.normalize(x_train, mean_vec, std_vec)
    # The validation split is normalized but not used further in this function.
    x_validation = normalize_held_out(x_validation, mean_vec, std_vec)
    x_test = normalize_held_out(x_test, mean_vec, std_vec)

    # Flip to True to call model_re_train instead of model_train.
    re_train = False

    if classifer_type == "CNN_RNN":
        train = model.model_re_train if re_train else model.model_train
        # The test split is what gets passed as the second data set here.
        train(x_train, y_train, x_test, y_test, emo_num, sample_weight_list,
              weight_dict, output_dir)

    # NOTE(review): the original nesting of this call could not be recovered;
    # it is assumed to run for every classifier type - confirm.
    evaluate(output_dir, x_test, y_test, emo_num, emo_dict)

    end = time.time()
    print("Total Time: {}s".format(end - start))
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()