forked from lyx199504/mc-lstm-time-series
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
57 lines (45 loc) · 2.33 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2022/2/11 19:02
# @Author : LYX-夜光
import numpy as np
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score
from dataPreprocessing import getDataset, standard
from optUtils import yaml_config
from optUtils.dataUtil import stratified_shuffle_index
from dl_models.cnn import CNN
from dl_models.ms_cnn import MS_CNN
from dl_models.c_lstm import C_LSTM
from dl_models.c_lstm_ae import C_LSTM_AE
from dl_models.imc_lstm import IMC_LSTM
from dl_models.smc_lstm import SMC_LSTM
from dl_models.cmc_lstm import CMC_LSTM
def _sliding_windows(X, length):
    """Expand every sample of *X* into overlapping sub-windows of *length* steps.

    Each sample ``x`` is sliced into ``len(x) - length + 1`` windows with a
    stride of 1 (``x[i:i + length]`` for every valid ``i``), and the results
    are stacked back into a single ``np.array``.
    """
    return np.array([
        [x[i: i + length] for i in range(len(x) - length + 1)]
        for x in X
    ])


def _print_class_balance(y):
    """Print the count and percentage of labels with ``y > 0`` versus the rest.

    An empty label array reports 0 for both percentages instead of raising
    ZeroDivisionError.
    """
    positives, total = int(np.count_nonzero(y > 0)), len(y)
    if total:
        pos_pct = positives / total * 100
        neg_pct = (1 - positives / total) * 100
    else:
        pos_pct = neg_pct = 0.0
    print("+: %d (%.2f%%)" % (positives, pos_pct), "-: %d (%.2f%%)" % (total - positives, neg_pct))


def main():
    """Train and evaluate every model in ``model_list`` on every dataset in ``dataset_list``.

    For each (model, dataset) pair this:
      1. loads the dataset via ``getDataset`` with ``seq_len`` and the
         ``standard`` preprocessing step;
      2. shuffles it with ``stratified_shuffle_index``, stratified on ``r``;
      3. for ``C_LSTM_AE`` only, expands each sample into length-10 sub-windows;
      4. prints the positive/negative label balance;
      5. splits it 60% / 20% / 20% into train / validation / test slices in
         shuffled order, fits the model on train+validation, and scores it on
         the test slice with F1 as the main metric and recall, precision and
         accuracy as extras.
    """
    seq_len = 60
    model_list = [CNN, MS_CNN, C_LSTM, IMC_LSTM, CMC_LSTM, SMC_LSTM, C_LSTM_AE]
    dataset_list = [
        'realAdExchange', 'realTraffic', 'realKnownCause',
        'realAWSCloudwatch', 'A1Benchmark', 'realTweets',
    ]
    # The shuffle seed and fold count do not change between runs; read them
    # from the shared YAML config once instead of on every iteration.
    seed, fold = yaml_config['cus_param']['seed'], yaml_config['cv_param']['fold']

    for model_clf in model_list:
        for dataset_name in dataset_list:
            X, y, r = getDataset(dataset_name, seq_len=seq_len, pre_list=[standard])
            # Stratified shuffle based on the distinct values taken by r.
            shuffle_index = stratified_shuffle_index(r, n_splits=fold, random_state=seed)
            X, y = X[shuffle_index], y[shuffle_index]
            if model_clf is C_LSTM_AE:
                # This model consumes each sample as a sequence of sub-windows.
                X = _sliding_windows(X, length=10)
            _print_class_balance(y)

            # Cut points for the 60% train / 20% validation / 20% test split.
            train_point, val_point = int(len(X) * 0.6), int(len(X) * 0.8)
            model = model_clf(learning_rate=0.001, batch_size=512, epochs=500,
                              random_state=1, seq_len=seq_len)
            # Tag the model name with the dataset so per-dataset runs are distinguishable.
            model.model_name += "_%s" % dataset_name
            model.param_search = False
            model.save_model = True
            # NOTE(review): hard-coded to CUDA; runs fail on machines without a GPU.
            model.device = 'cuda'
            model.metrics = f1_score
            model.metrics_list = [recall_score, precision_score, accuracy_score]
            model.fit(X[:train_point], y[:train_point],
                      X[train_point:val_point], y[train_point:val_point])
            model.test_score(X[val_point:], y[val_point:])


if __name__ == "__main__":
    main()