-
Notifications
You must be signed in to change notification settings - Fork 2
/
base_predictor_LSTM.m
83 lines (70 loc) · 1.98 KB
/
base_predictor_LSTM.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
%% LSTM base predictor
%% Reset the environment
clc
clear
close all
%% Training and test data
% Loads the struct "dataset" from dataset.mat in the current folder.
load('dataset.mat');
% outputps is the mapminmax settings struct used later to undo the output
% scaling (mapminmax 'reverse'). Presumably the "n" suffix marks
% normalized data - confirm against the script that builds dataset.mat.
outputps = dataset.outputps;
trainXn  = dataset.trainXn;
trainYn  = dataset.trainYn;
% The validation split doubles as the test set for this script; the test
% targets are kept in their original (un-normalized) form.
testXn   = dataset.validXn;
testY    = dataset.validY;
%% Build the LSTM network
% Number of input channels = number of rows (features) in trainXn.
inputSize = size(trainXn,1);
% Two stacked LSTM layers with 100 hidden units each: the first emits the
% whole sequence, the second only its last step, which feeds a small
% tanh-activated fully connected head with a single regression output.
layers = [ ...
    sequenceInputLayer(inputSize); ...
    lstmLayer(100,'OutputMode','sequence'); ...
    lstmLayer(100,'OutputMode','last'); ...
    fullyConnectedLayer(100); ...
    tanhLayer(); ...
    fullyConnectedLayer(1); ...
    tanhLayer(); ...            % squashes the prediction into (-1, 1)
    regressionLayer() ...
    ];
% Adam training options, grouped by concern (name-value order is irrelevant).
% NOTE(review): with a piecewise schedule dropping by 0.1 every 10 epochs,
% the learning rate is 1e-3*0.1^9 = 1e-12 during epochs 91-100, so the
% later epochs barely update the weights - confirm this is intended.
options = trainingOptions('adam', ...
    'InitialLearnRate', 1e-3, ...        % optimiser / schedule
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropFactor', 0.1, ...
    'LearnRateDropPeriod', 10, ...
    'GradientThreshold', 1, ...          % regularisation / stability
    'L2Regularization', 1e-6, ...
    'MaxEpochs', 100, ...                % data handling
    'MiniBatchSize', 4, ...
    'SequenceLength', 'longest', ...
    'Shuffle', 'never', ...
    'ExecutionEnvironment', 'cpu', ...   % runtime / reporting
    'Plots', 'training-progress', ...
    'Verbose', false);
%% Train the LSTM network
% num2cell(X,1) splits the feature-by-observation matrix into a 1-by-N
% cell of column vectors (what con2seq returns for a numeric matrix);
% transposing gives the N-by-1 cell array trainNetwork takes as sequence
% input. With inputSize = size(trainXn,1) channels, each observation is
% a sequence of length 1.
net = trainNetwork(num2cell(trainXn,1)', trainYn', layers, options);
%% Predict on the test (validation) set
testYn_out = predict(net, num2cell(testXn,1)');
% Undo the output scaling so predictions are comparable with testY.
testOutput = mapminmax('reverse', testYn_out', outputps);
%% Error analysis
tValue = testY;        % actual values
pValue = testOutput;   % predicted values
disp('LSTM');
myerror = myError(tValue, pValue, 'show_less');
%% Write predictions to Excel
% Sheet 1 of ./myResults/results_on_valid.xlsx receives the actual values
% in column A and the predictions in column C, each starting at row 1.
% fullfile uses the platform's file separator, so the path also works
% outside Windows.
xlsName = fullfile('.', 'myResults', 'results_on_valid.xlsx');
% The target folder may not exist yet; writing into a missing folder fails.
if ~isfolder(fileparts(xlsName))
    mkdir(fileparts(xlsName));
end
% writematrix (R2019a+) replaces the not-recommended xlswrite and does not
% rely on an Excel installation; explicit start cells make the layout clear.
writematrix(tValue', xlsName, 'Sheet', 1, 'Range', 'A1');
writematrix(pValue', xlsName, 'Sheet', 1, 'Range', 'C1');
%% Plot predicted vs. actual values
% Labels (Chinese): title = "prediction results on test samples",
% x = "sample", y = "output"; legend = "predicted", "actual".
figure
plot(pValue)     % plotted first, so it is the first legend entry
hold('on')
plot(tValue)
xlabel('样本','fontsize',14)
ylabel('输出','fontsize',14)
title('测试样本预测结果','fontsize',14)
legend('预测值','实际值')
grid('on')
%% Save variables
% Environment snapshot: the entire workspace at this point is written to
% ./myStatus/<script name>.mat.
fullpath = mfilename('fullpath');
[~,name] = fileparts(fullpath);
if isempty(name)
    % mfilename returns '' when not called from a file (e.g. code pasted
    % into the Command Window); fall back to this script's file name so
    % save() gets a real file name rather than a bare folder path.
    name = 'base_predictor_LSTM';
end
% save() does not create missing folders, so make sure they exist.
if ~isfolder(fullfile('.', 'myStatus'))
    mkdir(fullfile('.', 'myStatus'));
end
save(fullfile('.', 'myStatus', name));
% Remove the helper variables (after the snapshot, as before).
clear fullpath name;
% Key parameters: store only the trained network, as variable "lstm",
% in ./myModel/lstm.mat.
lstm = net;
if ~isfolder(fullfile('.', 'myModel'))
    mkdir(fullfile('.', 'myModel'));
end
save(fullfile('.', 'myModel', 'lstm'), 'lstm');