-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_nn.m
129 lines (117 loc) · 2.66 KB
/
test_nn.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
% Neural network evaluation script
% Author: Xiujiao Gao
% Loads the weight matrices W1/W2 and reports the classification error rate
% on a held-out validation split of train0.txt..train9.txt and on
% test0.txt..test9.txt. It then writes the predicted 1-of-10 test labels to
% class_nn.txt.
t1 = cputime; % CPU time at start (not used further by this script)
% The first ceil(rows*trainFraction) rows of each trainN.txt are discarded
% (presumably the portion used for training -- confirm). The remaining tail
% rows form the validation set.
trainFraction = 0.95;
% Weight matrices: X*W1 gives the hidden pre-activations and
% [1 tanh(X*W1)]*W2 gives the output scores.
W1 = load('W1.txt');
W2 = load('W2.txt');
% Validation set: the tail of each trainN.txt (class N) with a leading bias
% column of ones. T holds the matching 1-of-10 target rows
% (class N -> column N+1).
X = [];
T = [];
for i = 0:9
    x = load(['train', num2str(i), '.txt']);
    % Drop the leading rows; what remains is this class's validation data.
    x(1:ceil(size(x,1)*trainFraction),:) = [];
    [rows,columns] = size(x);
    X = [X; ones(rows,1) x];
    t = zeros(rows,10);
    t(:,i+1) = 1;
    T = [T; t];
end
% Forward pass on the validation set:
% hidden activations Z = tanh(X*W1), a bias column is prepended,
% output scores R = Z*W2, then a row-wise softmax gives Y.
A = X*W1;
Z = tanh(A);                    % tanh is element-wise; no loop needed
Arows = size(A,1);
Z = [ones(Arows,1) Z];          % bias unit for the output layer
R = Z*W2;
% Softmax. Subtracting each row's maximum before exp() leaves Y unchanged
% mathematically. It keeps exp() from overflowing to Inf, because Inf/Inf
% would produce NaN probabilities and wrong predictions.
R = bsxfun(@minus, R, max(R,[],2));
Sig = exp(R);
[Sigrows,Sigcolumns] = size(Sig);
Y = bsxfun(@rdivide, Sig, sum(Sig,2));
% Predicted class of each row = column index of its largest probability.
[y,n] = max(Y,[],2);
Y_Lable = zeros(Sigrows,Sigcolumns);
Y_Lable(sub2ind(size(Y_Lable), (1:Sigrows)', n)) = 1;
% Error rate: a misclassified row differs from its one-hot target in exactly
% two positions (the predicted column and the true column), so the xor count
% is halved to get the number of wrong rows.
E = xor(Y_Lable,T);
validerr = (sum(sum(E))/2)/Sigrows
% Test part
% Test set: every row of testN.txt (class N) with a leading bias column of
% ones. T holds the matching 1-of-10 target rows (class N -> column N+1).
X = [];
T = [];
for i = 0:9
    x = load(['test', num2str(i), '.txt']);
    [xrows,xcolumns] = size(x);
    X = [X; ones(xrows,1) x];
    t = zeros(xrows,10);
    t(:,i+1) = 1;
    T = [T; t];
end
% Forward pass on the test set:
% hidden activations Z = tanh(X*W1), a bias column is prepended,
% output scores R = Z*W2, then a row-wise softmax gives Y.
A = X*W1;
Z = tanh(A);                    % tanh is element-wise; no loop needed
Arows = size(A,1);
Z = [ones(Arows,1) Z];          % bias unit for the output layer
R = Z*W2;
% Softmax. Subtracting each row's maximum before exp() leaves Y unchanged
% mathematically. It keeps exp() from overflowing to Inf, because Inf/Inf
% would produce NaN probabilities and wrong predictions.
R = bsxfun(@minus, R, max(R,[],2));
Sig = exp(R);
[Sigrows,Sigcolumns] = size(Sig);
Y = bsxfun(@rdivide, Sig, sum(Sig,2));
% Predicted class of each row = column index of its largest probability.
[y,n] = max(Y,[],2);
Y_Lable = zeros(Sigrows,Sigcolumns);
Y_Lable(sub2ind(size(Y_Lable), (1:Sigrows)', n)) = 1;
% Error rate: a misclassified row differs from its one-hot target in exactly
% two positions (the predicted column and the true column), so the xor count
% is halved to get the number of wrong rows.
E = xor(Y_Lable,T);
testerr = (sum(sum(E))/2)/Sigrows
% Write the predicted 1-of-10 test labels to class_nn.txt: one line per test
% row, each entry printed as "%d \t". The 'W' mode opens for writing without
% automatic flushing.
fid = fopen('class_nn.txt','W');
if fid == -1
    error('Could not open class_nn.txt for writing.');
end
for i = 1:Sigrows
    % The format is reused for every element of the row.
    fprintf(fid,'%d \t',Y_Lable(i,:));
    fprintf(fid,'\n');
end
fclose(fid);