Skip to content

Commit

Permalink
added eval_var based on residual correction
Browse files Browse the repository at this point in the history
  • Loading branch information
noblec04 committed Oct 29, 2024
1 parent b1aab84 commit 9582e50
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 119 deletions.
113 changes: 0 additions & 113 deletions MatlabGP/RRNN.asv

This file was deleted.

31 changes: 28 additions & 3 deletions MatlabGP/RRNN.m
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

obj.Wout{i} = lsqminnorm([H{i}, ones(size(H{i},1),1)],R);

Ri(i) = sum(abs(R));
Ri(i,:) = sum(abs(R),1);

R = R - [H{i}, ones(size(H{i},1),1)]*obj.Wout{i};

Expand Down Expand Up @@ -103,8 +103,33 @@
y = obj.unscale_y(y);
end

function sig = eval_var(~,x)
sig = 0*x(:,1);
function sig = eval_var(obj,X)
%PREDICT Predicts the output of the trained model for new input
%data
% Inputs:
% obj - trained RRNN
% X - Input data

% Output:
% y - output output

H = obj.activation.forward([X, ones(size(X,1),1)] * obj.Win{1});
y1 = [H, ones(size(H,1),1)] * obj.Wout{1};

y = 0;
X = obj.scale(X);

for i = 1:numel(obj.Win)

H = obj.activation.forward([X, ones(size(X,1),1)] * obj.Win{i});
y = y + [H, ones(size(H,1),1)] * obj.Wout{i};

end

y = obj.unscale_y(y);
y1 = obj.unscale_y(y1);

sig = (y1 - y).^2;
end

function [mu,sig] = eval_all(obj,x)
Expand Down
12 changes: 9 additions & 3 deletions MatlabGP/examples/testRRNN.m
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,24 @@
close all
clc

xx = [0;lhsdesign(10,1);1];
xx = [0;lhsdesign(8,1);1];
yy = normrnd(forr(xx,0),0*forr(xx,0));

xmesh = linspace(0,1,100)';
ymesh = forr(xmesh,0);

nnet = RRNN(NN.SWISH(0.8),10,0,1);
nnet = RRNN(NN.SWISH(2),10,0,1);

%%

tic
nnet2 = nnet.train(xx,yy,5);%,xv,fv
[nnet2,Ri] = nnet.train(xx,yy,5);%,xv,fv
toc

%%

yp2 = nnet2.eval(xmesh);
sig2 = nnet2.eval_var(xmesh);


%%
Expand All @@ -35,6 +36,11 @@
plot(xmesh,yp2)
plot(xx,yy,'x')

figure
plot(xmesh,sig2)

1 - mean((ymesh - yp2).^2)./var(ymesh)

%%

function y = forr(x,dx)
Expand Down

0 comments on commit 9582e50

Please sign in to comment.