-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathdemo_2D.m
217 lines (151 loc) · 6.9 KB
/
demo_2D.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
% demo_2D.m -- GPz demonstration on a synthetic 2D dataset (a mixture of
% three Gaussians): sparse GP regression with heteroscedastic noise,
% noisy inputs, and randomly missing input variables.
rng(1); % fix random seed
addpath GPz/ % path to GPz
addpath(genpath('minFunc_2012/')) % path to minfunc
%%%%%%%%%%%%%% Model options %%%%%%%%%%%%%%%%
m = 50; % number of basis functions to use [required]
method = 'VD'; % select a method, options = GL, VL, GD, VD, GC and VC [required]
heteroscedastic = true; % learn a heteroscedastic noise process, set to false if only interested in point estimates [default=true]
normalize = true; % pre-process the input by subtracting the means and dividing by the standard deviations [default=true]
maxIter = 500; % maximum number of iterations [default=200]
maxAttempts = 50; % maximum iterations to attempt if there is no progress on the validation set [default=infinity]
trainSplit = 0.7; % percentage of data to use for training
validSplit = 0.15; % percentage of data to use for validation
testSplit = 0.15; % percentage of data to use for testing
inputNoise = true; % false = use mag errors as additional inputs, true = use mag errors as additional input noise
percentage = 0.5; % percentage of data with a missing variable
%%%%%%%%%%%%%% Create dataset %%%%%%%%%%%%%%
% Inputs X: 3000 points drawn from a mixture of three 2D Gaussians.
mean1 = [10 0];
Sigma1 = [10 0;0 1];
mean2 = [10 10];
Sigma2 = [5 -3;-3 3];
mean3 = [5 5];
Sigma3 = [2 0;0 2];
X = [mvnrnd(mean1,Sigma1,1000);mvnrnd(mean2,Sigma2,1000);mvnrnd(mean3,Sigma3,1000)];
[n,d] = size(X);
% Target Y: a weighted sum of the three component densities plus small
% Gaussian observation noise (std 0.01).
PHI = [mvnpdf(X,mean1,Sigma1) mvnpdf(X,mean2,Sigma2) mvnpdf(X,mean3,Sigma3)];
w = [-9;6;3];
Y = PHI*w+randn(n,1)*0.01;
if(inputNoise)
E = 0.5; % desired mean of the input noise variance
V = 0.25; % desired variance of the input noise variance
% The parameters of a gamma distribution with the desired mean and variance
a = E^2/V; b = E/V;
% sample from the gamma distribution with mean=E and variance=V
Psi = gamrnd(a,1/b,size(X));
Xn = X+randn(size(X)).*sqrt(Psi); % create a noisy input
% convert to covariances (2x2xn, diagonal) when using GC or VC
if(method(2)=='C')
Psi = reshape([Psi(:,1) zeros(n,2) Psi(:,2)]',2,2,n);
end
else
Psi = []; % no input-noise variances supplied
Xn = X;
end
[n,d] = size(Xn);
% removing random variables from training: mark a fraction `percentage` of
% the samples as having one missing input (NaN), split evenly between the
% two variables
if(percentage>0)
r = randperm(n)';
psize = ceil(percentage*n/2);
Xn(r(1:psize),1) = nan; % remove the first variable from half the selected sample
Xn(r(psize+1:2*psize),2) = nan; % remove the second variable from the other half of the selected sample
end
% split data into training, validation and testing (logical index vectors)
[training,validation,testing] = sample(n,trainSplit,validSplit,testSplit);
%%%%%%%%%%%%%% Fit the model %%%%%%%%%%%%%%
% initialize the model
model = init(Xn,Y,method,m,'heteroscedastic',heteroscedastic,'normalize',normalize,'training',training,'Psi',Psi);
% train the model
% FIX: option name corrected from 'maxAttempt' to 'maxAttempts' so the
% early-stopping limit configured above is actually passed to train
% (this matches the spelling used for the reference model further below).
model = train(model,Xn,Y,'maxIter',maxIter,'maxAttempts',maxAttempts,'training',training,'validation',validation,'Psi',Psi);
%%%%%%%%%%%%%% Display %%%%%%%%%%%%%%
% create 2D test data: a 100x100 grid covering the data range plus margin
[x,y] = meshgrid(linspace(min(X(:,1))-1,max(X(:,1))+1,100),linspace(min(X(:,2))-1,max(X(:,2))+1,100));
Xs = [x(:) y(:)];
mu = predict(Xs,model);
% Visualize prediction
figure;
subplot(2,3,1)
surf(x,y,reshape(mu,size(x)));colormap jet;axis tight;
hold on;
% plot the training points on top; map NaN (missing) coordinates to the
% grid minimum so plot3 can still draw them at the plot edge
Xd = Xn;
Xd(isnan(Xd(:,1)),1) = min(Xs(:,1));
Xd(isnan(Xd(:,2)),2) = min(Xs(:,2));
plot3(Xd(training,1),Xd(training,2),Y(training),'.');
xlabel('$x$','interpreter','latex','FontSize',12);
ylabel('$y$','interpreter','latex','FontSize',12);
zlabel('$z$','interpreter','latex','FontSize',12);
title('Predicted Model','interpreter','latex','FontSize',12);
% Visualize ground truth: the noise-free generative surface PHI*w
subplot(2,3,4)
PHI = [mvnpdf(Xs,mean1,Sigma1) mvnpdf(Xs,mean2,Sigma2) mvnpdf(Xs,mean3,Sigma3)];
truth = PHI*w;
surf(x,y,reshape(truth,size(x)));colormap jet;axis tight;
xlabel('$x$','interpreter','latex','FontSize',12);
ylabel('$y$','interpreter','latex','FontSize',12);
zlabel('$z$','interpreter','latex','FontSize',12);
title('Reference Model','interpreter','latex','FontSize',12);
%%%%%%%%%%%%%% Predict with missing variables %%%%%%%%%%%%%%
% For each variable o, compare (top row) the joint model predicting with the
% other variable missing vs (bottom row) a reference model trained on the
% observed variable alone.
labels = ['x','y'];
rmses = zeros(2,2); % row 1: joint model with missing input; row 2: reference model
for o=1:2
% create a test set with only variable 'o' observed
range_o = max(X(:,o))-min(X(:,o));
Xo = linspace(min(X(:,o))-range_o/10,max(X(:,o))+range_o/10,1000)';
% set missing variables to NaNs and observed variables to Xo
Xs = nan(size(Xo,1),2);
Xs(:,o) = Xo;
[mu,sigma] = predict(Xs,model);
% plot the results: mean curve with a +/-2 sigma shaded band
subplot(2,3,o+1);
hold on;
f = [mu+2*sqrt(sigma); flip(mu-2*sqrt(sigma))];
fill([Xo; flip(Xo)], f, [0.85 0.85 0.85]);
plot(Xn(training,o),Y(training),'b.');
plot(Xo,mu,'r-','LineWidth',2);
xlabel(['$',labels(o),'$'],'interpreter','latex','FontSize',12);
ylabel('$z$','interpreter','latex','FontSize',12);
title('Predicted Model','interpreter','latex','FontSize',12);
axis tight
ax1 = gca;
% compute the error on the test set (joint model, variable o observed only)
Xs = nan(sum(testing),2);
Xs(:,o) = X(testing,o);
mu = predict(Xs,model);
rmses(1,o) = sqrt(mean((Y(testing)-mu).^2));
% build a reference model trained only on the observed variable to compare;
% extract the matching slice of the input-noise variances/covariances
if(isempty(Psi))
Psi_oo = [];
elseif(method(2)=='C')
Psi_oo = squeeze(Psi(o,o,:));
else
Psi_oo = Psi(:,o);
end
removed = isnan(Xn(:,o));
% build and train the reference model only on the observed variable
ref_model = init(Xn(:,o),Y,method,m,'heteroscedastic',heteroscedastic,'normalize',normalize,'training',training&~removed,'Psi',Psi_oo);
ref_model = train(ref_model,Xn(:,o),Y,'maxIter',maxIter,'maxAttempts',maxAttempts,'training',training&~removed,'validation',validation&~removed,'Psi',Psi_oo);
% generate predictions using the reference model
[mu,sigma] = predict(Xo,ref_model);
% visualize the results
subplot(2,3,o+4);
hold on;
f = [mu+2*sqrt(sigma); flip(mu-2*sqrt(sigma))];
fill([Xo; flip(Xo)], f, [0.85 0.85 0.85]);
plot(Xn(training,o),Y(training),'b.');
plot(Xo,mu,'r-','LineWidth',2);
xlabel(['$',labels(o),'$'],'interpreter','latex','FontSize',12);
ylabel('$z$','interpreter','latex','FontSize',12);
title('Reference Model','interpreter','latex','FontSize',12);
axis tight
ax2 = gca;
% compute the error on the test set
% FIX: use ref_model here, not model -- this row of the table reports the
% 1-D reference model's RMSE, and the joint model was trained on 2-D input
% so it cannot be evaluated on the single-column X(testing,o) anyway.
mu = predict(X(testing,o),ref_model);
rmses(2,o) = sqrt(mean((Y(testing)-mu).^2));
% equalize axes across the predicted/reference subplots for variable o
YLim = [min(ax1.YLim(1),ax2.YLim(1)) max(ax1.YLim(2),ax2.YLim(2))];
ax1.YLim = YLim;
ax2.YLim = YLim;
ax1.XLim = [min(Xo) max(Xo)];
ax2.XLim = [min(Xo) max(Xo)];
end
%%%%%%%%%%%%%% Display Metrics %%%%%%%%%%%%%%
% Print a 2x2 table of test-set RMSEs:
% row "Predicted" = joint model with one variable missing (rmses(1,:)),
% row "Reference" = model trained on the observed variable only (rmses(2,:)).
fprintf('\t\t RMSE on the test set\n\t\tMissing y\tMissing x\nPredicted\t%f\t%f\nReference\t%f\t%f\n',rmses(1,1),rmses(1,2),rmses(2,1),rmses(2,2))