Skip to content

Commit

Permalink
added decider
Browse files Browse the repository at this point in the history
Added windowed Thompson Sampling
Added hypercube vertex code
  • Loading branch information
noblec04 committed Aug 31, 2024
1 parent d21083f commit 9475dfa
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 56 deletions.
67 changes: 67 additions & 0 deletions MatlabGP/+BO/WTS.asv
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
classdef WTS
    %WTS Windowed Thompson Sampling over a fixed set of arms.
    %   [1] Trovo, Paladino, Restelli, & Gatti 2020
    %
    %   Every arm keeps a sliding window of its most recent rewards. Each
    %   call to addReward appends one entry to every arm (the supplied
    %   reward for the pulled arm, 0 for the others), so all windows always
    %   have the same length. action draws one Beta sample per arm and
    %   returns the arm with the largest sample.
    %
    %   The Beta parameters are built from S = sum(rewards) and
    %   T = number of strictly positive rewards, which is only a valid
    %   success/failure count when rewards lie in [0,1]. Rewards outside
    %   that range are tolerated: the counts are clamped at zero so the
    %   Beta parameters stay positive instead of producing NaN samples.

    properties
        rewards % 1-by-narms cell array; rewards{i} is the reward window of arm i
        window  % maximum number of entries kept in each arm's window
    end

    methods

        function obj = WTS(window,narms)
            %WTS Create a sampler with NARMS empty reward windows of
            %   maximum length WINDOW.
            obj.rewards = cell(1,narms);
            obj.window = window;
        end

        function obj = addReward(obj,arm,reward)
            %ADDREWARD Record REWARD for ARM and a zero for every other arm.
            %   Returns the updated object (WTS is a value class, so the
            %   caller must keep the returned copy).

            % Entries that survive before the new one is appended. Using
            % ">" rather than "==" keeps the window bounded even if the
            % public window property was lowered after rewards were stored.
            keep = max(floor(obj.window) - 1, 0);

            for i = 1:numel(obj.rewards)
                r = obj.rewards{i};
                if numel(r) > keep
                    r = r(end-keep+1:end);
                end
                r(end+1) = 0;
                obj.rewards{i} = r;
            end

            obj.rewards{arm}(end) = reward;
        end

        function arm = action(obj,prior)
            %ACTION Draw one Beta sample per arm and return the best arm.
            %   PRIOR(i) is added to both Beta parameters of arm i. When
            %   omitted it defaults to 1 for every arm (uniform prior).

            nArms = numel(obj.rewards);

            if nargin<2
                prior = ones(1,nArms);
            end

            nu = zeros(1,nArms);

            for i = 1:nArms
                [s,f] = obj.betaCounts(obj.rewards{i});
                nu(i) = betarnd(s+prior(i),f+prior(i));
            end

            [~,arm] = max(nu);

        end

        function plotDists(obj)
            %PLOTDISTS Plot the Beta(s+1,f+1) density of every arm on [0,1]
            %   in a new figure.

            X = 0:0.001:1;

            figure
            hold on
            for i = 1:numel(obj.rewards)

                [s,f] = obj.betaCounts(obj.rewards{i});

                Y = betapdf(X,s+1,f+1);

                plot(X,Y);

            end

        end

    end

    methods (Access = private)

        function [s,f] = betaCounts(~,r)
            %BETACOUNTS Success and failure counts of one reward window.
            %   s = sum(r) and f = (number of positive entries) - sum(r),
            %   both clamped at zero. For rewards in [0,1] the clamp never
            %   triggers, so the values equal the unclamped S and T-S.
            S = sum(r);
            T = sum(double(r>0));
            s = max(S,0);
            f = max(T-S,0);
        end

    end
end
71 changes: 71 additions & 0 deletions MatlabGP/+BO/WTS.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
classdef WTS
    %WTS Windowed Thompson Sampling over a fixed set of arms.
    %   [1] Trovo, Paladino, Restelli, & Gatti 2020
    %
    %   Every arm keeps a sliding window of its most recent rewards. Each
    %   call to addReward appends one entry to every arm (the supplied
    %   reward for the pulled arm, 0 for the others), so all windows always
    %   have the same length. action draws one Beta sample per arm, scales
    %   it, and returns the arm with the largest scaled sample.
    %
    %   The Beta parameters are built from S = sum(rewards) and
    %   T = number of strictly positive rewards, which is only a valid
    %   success/failure count when rewards lie in [0,1]. Rewards outside
    %   that range are tolerated: the counts are clamped at zero so the
    %   Beta parameters stay positive instead of producing NaN samples.

    properties
        rewards % 1-by-narms cell array; rewards{i} is the reward window of arm i
        window  % maximum number of entries kept in each arm's window
    end

    methods

        function obj = WTS(window,narms)
            %WTS Create a sampler with NARMS empty reward windows of
            %   maximum length WINDOW.
            obj.rewards = cell(1,narms);
            obj.window = window;
        end

        function obj = addReward(obj,arm,reward)
            %ADDREWARD Record REWARD for ARM and a zero for every other arm.
            %   Returns the updated object (WTS is a value class, so the
            %   caller must keep the returned copy).

            % Entries that survive before the new one is appended. Using
            % ">" rather than "==" keeps the window bounded even if the
            % public window property was lowered after rewards were stored.
            keep = max(floor(obj.window) - 1, 0);

            for i = 1:numel(obj.rewards)
                r = obj.rewards{i};
                if numel(r) > keep
                    r = r(end-keep+1:end);
                end
                r(end+1) = 0;
                obj.rewards{i} = r;
            end

            obj.rewards{arm}(end) = reward;
        end

        function arm = action(obj,sig)
            %ACTION Draw one Beta(s+1,f+1) sample per arm and return the
            %   arm whose sample, multiplied by SIG(i), is largest.
            %   SIG defaults to 1 for every arm when omitted.

            nArms = numel(obj.rewards);

            if nargin<2
                sig = ones(1,nArms);
            end

            nu = zeros(1,nArms);

            for i = 1:nArms
                [s,f] = obj.betaCounts(obj.rewards{i});
                nu(i) = betarnd(s+1,f+1)*sig(i);
            end

            [~,arm] = max(nu);

        end

        function plotDists(obj)
            %PLOTDISTS Plot the Beta(s+1,f+1) density of every arm on [0,1]
            %   in a new figure.

            X = 0:0.001:1;

            figure
            hold on
            for i = 1:numel(obj.rewards)

                [s,f] = obj.betaCounts(obj.rewards{i});

                Y = betapdf(X,s+1,f+1);

                plot(X,Y);

            end

        end

    end

    methods (Access = private)

        function [s,f] = betaCounts(~,r)
            %BETACOUNTS Success and failure counts of one reward window.
            %   s = sum(r) and f = (number of positive entries) - sum(r),
            %   both clamped at zero. For rewards in [0,1] the clamp never
            %   triggers, so the values equal the unclamped S and T-S.
            S = sum(r);
            T = sum(double(r>0));
            s = max(S,0);
            f = max(T-S,0);
        end

    end
end
30 changes: 30 additions & 0 deletions MatlabGP/+BO/argmaxGrid.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
function [x,R] = argmaxGrid(FF,Z)
%ARGMAXGRID Optimise an acquisition function with a grid-seeded fmincon run.
%   [x,R] = argmaxGrid(FF,Z) evaluates FF(Z,X) on a set of candidate
%   points inside the box [Z.lb_x, Z.ub_x], starts fmincon from the best
%   candidate, and returns the optimiser x together with R, the negated
%   objective value reported by fmincon.
%
%   FF is called as FF(Z,x) and is MINIMISED by fmincon; R = -min(FF).
%   NOTE(review): this implies FF returns the negated acquisition value -
%   confirm against the acquisition functions passed in by callers.
%   Z must expose the bound vectors lb_x and ub_x.

lb = Z.lb_x;
ub = Z.ub_x;

D = length(lb);

% Candidate set: a Latin hypercube sample of 1000*D points plus every
% vertex of the unit hypercube, scaled from [0,1]^D to the box [lb, ub].
XT = lb + (ub - lb).*[lhsdesign(1000*D,D);utils.HypercubeVerts(D)];

% fmincon minimises FF, so the seed must be the candidate with the
% SMALLEST value; seeding from the largest value would start the local
% search at the worst candidate.
[~,im] = min(FF(Z,XT));

x0 = XT(im,:);

objective = @(x) FF(Z,x);

try
    % First attempt: use the gradient returned as FF's second output.
    opts = optimoptions('fmincon','SpecifyObjectiveGradient',true,'Display','off');

    [x,R] = fmincon(objective,x0,[],[],[],[],lb,ub,[],opts);

catch

    % Fallback when the gradient-based run errors (e.g. FF does not
    % provide a gradient): let fmincon use finite differences instead.
    opts = optimoptions('fmincon','SpecifyObjectiveGradient',false,'Display','off');

    [x,R] = fmincon(objective,x0,[],[],[],[],lb,ub,[],opts);

end

% Undo the sign flip of the minimised objective.
R = -1*R;

end
14 changes: 14 additions & 0 deletions MatlabGP/+utils/HypercubeVerts.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
function [vertices] = HypercubeVerts(d)
%HYPERCUBEVERTS Enumerate the vertices of the d-dimensional unit hypercube.
%   vertices = HypercubeVerts(d) returns a 2^d-by-d matrix of zeros and
%   ones. Row k holds the binary digits of k-1 with the most significant
%   digit in column 1, so the rows run from all zeros to all ones.

nVerts = 2^d;
codes = (0:nVerts-1).';

vertices = zeros(nVerts, d);
for col = 1:d
    % Column col carries the binary digit of weight 2^(d-col).
    vertices(:, col) = mod(floor(codes ./ 2^(d - col)), 2);
end

end
73 changes: 45 additions & 28 deletions MatlabGP/examples/TestRosenbrockProblem.asv
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ mb = means.linear(ones(1,D));

a = kernels.RQ(2,1,ones(1,D+nF-1));
b = kernels.RQ(2,1,ones(1,D));
a.signn = 1*10^(-7);
b.signn = 1*10^(-7);
a.signn = eps;
b.signn = eps;

%%
tic
Expand All @@ -54,27 +54,31 @@ MF = MF.condition();
MF = MF.train();
toc

%%

mc = means.linear(ones(1,D));%*means.sine(1,10,0,1);
c = kernels.RQ(2,1,ones(1,D));
c.signn = eps;
c.signn = 1;

LOO = GP(mc,c);

LOOZ{1} = LOO.condition(x{1},log(abs(Z{1}.LOO)),lb,ub);
LOOZ{1} = LOOZ{1}.train();
LOOZ{2} = LOO.condition(x{2},log(abs(Z{2}.LOO)),lb,ub);
LOOZ{2} = LOOZ{2}.train();

LOOMF = LOO.condition(x{1},log(abs(MF.LOO)),lb,ub);
LOOMF = LOOMF.train();
% LOOZ{1} = LOO.condition(x{1},log(abs(Z{1}.LOO)),lb,ub);
% LOOZ{1} = LOOZ{1}.train();
% LOOZ{2} = LOO.condition(x{2},log(abs(Z{2}.LOO)),lb,ub);
% LOOZ{2} = LOOZ{2}.train();
% LOOZ{3} = LOO.condition(x{3},log(abs(Z{3}.LOO)),lb,ub);
% LOOZ{3} = LOOZ{3}.train();
%
% LOOMF = LOO.condition(x{1},log(abs(MF.LOO)),lb,ub);
% LOOMF = LOOMF.train();


%%
figure
hold on
utils.plotSurf(Z{1},1,2,'color','r')
utils.plotSurf(Z{2},1,2,'color','b')
%utils.plotSurf(Z{3},1,2,'color','g')
utils.plotSurf(Z{3},1,2,'color','g')

%%

Expand All @@ -92,22 +96,26 @@ max(abs(yy - MF.eval_mu(xx)))./std(yy)

%%

C = [50 20 1];%20
C = [50 30 1];%20

for jj = 1:100

for jj = 1:60
dec = BO.WTS(20,3);
in = dec.action();

%[xn,Rn] = BO.argmax(@BO.UCB,LOOMF);
%[xn,Rn] = BO.argmax(@BO.MFSFDelta,MF);
[xn,Rn] = BO.argmax(@BO.maxVAR,MF);
%[xn,Rn] = BO.argmaxGrid(@BO.UCB,LOOMF);
[xn,Rn] = BO.argmaxGrid(@BO.MFSFDelta,MF);
%[xn,Rn] = BO.argmaxGrid(@BO.maxVAR,MF);

siggn(1) = exp((LOOZ{1}.eval(xn)))/(C(1));
siggn(2) = exp((LOOZ{2}.eval(xn)))/(C(2));
% siggn(1) = exp((LOOZ{1}.eval(xn)))/(C(1));
% siggn(2) = exp((LOOZ{2}.eval(xn)))/(C(2));
% siggn(3) = exp((LOOZ{3}.eval(xn)))/(C(3));

% siggn(1) = abs(Z{1}.eval_var(xn))/(C(1));
% siggn(2) = abs(Z{2}.eval_var(xn))/(C(2));
%siggn(3) = abs(Z{3}.eval_var(xn))/(C(3));

[~,in] = max(siggn);
% siggn(3) = abs(Z{3}.eval_var(xn))/(C(3));
%
% [~,in] = max(siggn);

if in==1
[x{1},flag] = utils.catunique(x{1},xn);
Expand All @@ -123,22 +131,31 @@ for jj = 1:60
end
end

% [x{3},flag] = utils.catunique(x{3},xn);
% if flag
% y{3} = [y{3}; testFuncs.Rosenbrock(xn,3)];
% end
[x{3},flag] = utils.catunique(x{3},xn);
if flag
y{3} = [y{3}; testFuncs.Rosenbrock(xn,3)];
end

for ii = 1:nF
Z{ii} = Z{ii}.condition(x{ii},y{ii},lb,ub);
end

yh1 = MF.eval(xn);

MF.GPs = Z;
MF = MF.condition();

LOOZ{1} = LOOZ{1}.condition(x{1},log(abs(Z{1}.LOO)),lb,ub);
LOOZ{2} = LOOZ{2}.condition(x{2},log(abs(Z{2}.LOO)),lb,ub);
yh2 = MF.eval(xn);

Ri = abs(yh2 - yh1)/C(in);

dec = dec.addReward(in,Ri);

%LOOMF = LOOMF.condition(x{1},log(abs(MF.LOO)),lb,ub);
% LOOZ{1} = LOOZ{1}.condition(x{1},log(abs(Z{1}.LOO)),lb,ub);
% LOOZ{2} = LOOZ{2}.condition(x{2},log(abs(Z{2}.LOO)),lb,ub);
% LOOZ{3} = LOOZ{3}.condition(x{3},log(abs(Z{3}.LOO)),lb,ub);

% LOOMF = LOOMF.condition(x{1},log(abs(MF.LOO)),lb,ub);

R2z(jj) = 1 - mean((yy - Z{1}.eval_mu(xx)).^2)./var(yy);
RMAEz(jj) = max(abs(yy - Z{1}.eval_mu(xx)))./std(yy);
Expand Down
Loading

0 comments on commit 9475dfa

Please sign in to comment.