Skip to content

Commit

Permalink
Moved WTS to RL
Browse files Browse the repository at this point in the history
modified Rosenbrock test problem with new MM scheme
  • Loading branch information
noblec04 committed Sep 1, 2024
1 parent 9475dfa commit 27f5a50
Show file tree
Hide file tree
Showing 9 changed files with 663 additions and 59 deletions.
44 changes: 31 additions & 13 deletions MatlabGP/+BO/WTS.asv → MatlabGP/+RL/TS.asv
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
classdef WTS
%Windowed Thompson Sampling
classdef TS
%Windowed Thompson Sampling - beta dist
% [1] Trovo, Paladino, Restelli, & Gatti 2020
properties
rewards
Expand All @@ -8,7 +8,7 @@ classdef WTS

methods

function obj = WTS(window,narms)
function obj = TS(window,narms)
obj.rewards = cell(1,narms);
obj.window = window;
end
Expand All @@ -25,18 +25,20 @@ classdef WTS
obj.rewards{i}(end+1)=0;
end

obj.rewards{arm}(end)=reward;
obj.rewards{arm}(end)=log(reward);
end

function arm = action(obj,prior)

function arm = action(obj)

for i = 1:numel(obj.rewards)

S = sum(obj.rewards{i});
T = sum(double(obj.rewards{i}>0));
T = sum(double(obj.rewards{i}~=0));

sig = 1./(1/10 + T);
mu = sig*S;

nu(i) = betarnd(S+prior(i),T-S+prior(i));
nu(i) = normrnd(mu,sqrt(sig));
end

[~,arm] = max(nu);
Expand All @@ -45,22 +47,38 @@ classdef WTS

function plotDists(obj)

X = 0:0.001:1;
X = -10:0.01:10;

figure
hold on
for i = 1:numel(obj.rewards)

S = sum(obj.rewards{i});
T = sum(double(obj.rewards{i}>0));
T = sum(double(obj.rewards{i}~=0));

Y = betapdf(X,S+1,T-S+1);
sig = 1./(1/10 + T);
mu = sig*S;

plot(X,Y);
Y = normpdf(X + mu,mu,sig);

plot(X+mu,Y);

end

end

function lik = likelihood(obj,R)

for i = 1:numel(obj.rewards)

S = sum(obj.rewards{i});
T = sum(double(obj.rewards{i}~=0));

sig = 1./(1/10 + T);
mu = sig*S;

lik( = normpdf(R(i),mu,sig);

end
end

end
Expand Down
85 changes: 85 additions & 0 deletions MatlabGP/+RL/TS.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
classdef TS
    %TS Windowed Thompson Sampling with a normal posterior on log-rewards.
    %
    %   Each arm keeps a sliding window of the last `window` time steps.
    %   At every step exactly one arm receives log(reward); every other arm
    %   receives a 0 placeholder. For arm i, with S the sum of the window
    %   and T the number of non-zero entries, the sampling distribution is
    %
    %       variance = 1/(1/10 + T),   mean = variance*S
    %
    %   which has the form of a conjugate-normal update with a N(0,10)
    %   prior and unit observation variance.
    %
    %   NOTE(review): a reward of exactly 1 is stored as log(1) = 0 and is
    %   therefore indistinguishable from the "arm not pulled" placeholder;
    %   it will not be counted in T. Rewards <= 0 yield -Inf/complex logs.
    %   Confirm callers only pass rewards that are > 0 and ~= 1.
    %
    %   [1] Trovo, Paladino, Restelli, & Gatti 2020
    properties
        rewards   % 1-by-narms cell; rewards{i} is arm i's windowed history
        window    % maximum number of time steps kept per arm
    end

    methods

        function obj = TS(window,narms)
            %TS Create a sampler with NARMS arms and a history of WINDOW steps.
            obj.rewards = cell(1,narms);
            obj.window = window;
        end

        function obj = addReward(obj,arm,reward)
            %ADDREWARD Record REWARD for ARM at a new time step.
            %   TS is a value class, so the updated object is returned and
            %   must be assigned by the caller.

            nArms = numel(obj.rewards);

            % Window full: drop the oldest time step from every arm so all
            % histories stay the same length.
            if length(obj.rewards{1})==obj.window
                for i = 1:nArms
                    obj.rewards{i}(1)=[];
                end
            end

            % Open a new time step with the "not pulled" placeholder ...
            for i = 1:nArms
                obj.rewards{i}(end+1)=0;
            end

            % ... and overwrite it for the arm that was actually played.
            obj.rewards{arm}(end)=log(reward);
        end

        function [arm,nu] = action(obj)
            %ACTION Draw one posterior sample per arm and pick the largest.
            %   ARM is the index of the maximising sample, NU the 1-by-narms
            %   vector of samples.

            nArms = numel(obj.rewards);
            nu = zeros(1,nArms);

            for i = 1:nArms
                [mu,sig] = obj.posterior(i);
                % normrnd takes a standard deviation; sig is a variance.
                nu(i) = normrnd(mu,sqrt(sig));
            end

            [~,arm] = max(nu);

        end

        function plotDists(obj)
            %PLOTDISTS Plot each arm's sampling density in a new figure.

            % Offsets around each arm's mean at which the pdf is evaluated.
            X = -10:0.01:10;

            figure
            hold on
            for i = 1:numel(obj.rewards)

                [mu,sig] = obj.posterior(i);

                % Use the standard deviation so the curve matches the
                % distribution that action() samples from.
                Y = normpdf(X + mu,mu,sqrt(sig));

                plot(X+mu,Y);

            end
        end

        function lik = likelihood(obj,R)
            %LIKELIHOOD Density of R(i) under arm i's sampling distribution.
            %   R must have at least one entry per arm; LIK is 1-by-narms.

            nArms = numel(obj.rewards);
            lik = zeros(1,nArms);

            for i = 1:nArms

                [mu,sig] = obj.posterior(i);

                % normpdf takes a standard deviation; sig is a variance.
                lik(i) = normpdf(R(i),mu,sqrt(sig));

            end
        end

    end

    methods (Access = private)

        function [mu,sig] = posterior(obj,i)
            %POSTERIOR Mean MU and variance SIG of arm i's sampling distribution.

            S = sum(obj.rewards{i});   % sum of windowed log-rewards
            T = nnz(obj.rewards{i});   % entries that are not the 0 placeholder

            sig = 1./(1/10 + T);
            mu = sig*S;
        end

    end
end
2 changes: 1 addition & 1 deletion MatlabGP/+BO/WTS.m → MatlabGP/+RL/WTS.m
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
classdef WTS
%Windowed Thompson Sampling
%Windowed Thompson Sampling - beta dist
% [1] Trovo, Paladino, Restelli, & Gatti 2020
properties
rewards
Expand Down
12 changes: 6 additions & 6 deletions MatlabGP/GP.m
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@
y = obj.mean.eval(x) + ksf*obj.alpha;

if nargout>1
kss = obj.kernel.build(xs(1,:),xs(1,:));

sig = abs(diag(kss) + obj.kernel.signn - dot(ksf',obj.Kinv*ksf')');
sig = abs(obj.kernel.scale + obj.kernel.signn - dot(ksf',obj.Kinv*ksf')');
end
end

Expand Down Expand Up @@ -139,13 +137,15 @@

obj.kernel.scale = 1;

[obj.K] = obj.kernel.build(xx,xx);
%[obj.K] = obj.kernel.build(xx,xx);

res = obj.Y - obj.mean.eval(obj.X);

kkp = pinv(obj.K);
%kkp = pinv(obj.K);

%sigp = sqrt(abs(res'*kkp*res./(size(obj.Y,1))));

sigp = sqrt(abs(res'*kkp*res./(size(obj.Y,1))));
sigp = std(obj.Y);

obj.kernel.scale = sigp^2;

Expand Down
Loading

0 comments on commit 27f5a50

Please sign in to comment.