-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathMetropolisHastings.m
58 lines (52 loc) · 2.1 KB
/
MetropolisHastings.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
%% MetropolisHastings  Draw samples of a label field from a Gibbs (Potts) distribution.
%
% [Y, sample] = MetropolisHastings(Yinit, B, M, k, beta, logprobs, neighbours_count)
% [Y, sample] = MetropolisHastings(..., labelCosts)
%
% Runs B + M Metropolis-Hastings sweeps over the label array. In each sweep
% every voxel whose initial label is non-zero is visited once, in random
% order. For each visited voxel a new label is drawn uniformly from 1..k and
% accepted with probability min(1, exp(new_energy - cur_energy)). Voxels
% labelled 0 in Yinit are never updated; they act as a mask.
%
% Inputs:
%   Yinit            - array holding the initial segmentation (labels 1..k;
%                      0 marks voxels that are not resampled)
%   B                - number of initial (burn-in) sweeps whose states are
%                      discarded
%   M                - number of sweeps whose states are returned
%   k                - number of segmentation classes
%   beta             - Potts model parameter; scales the number of
%                      neighbours whose label differs from the voxel's label
%   logprobs         - per-voxel data term, indexed as logprobs(label, voxel)
%                      with voxel a linear index into Yinit
%   neighbours_count - neighbourhood specification passed to GetNeighbours
%   labelCosts       - (optional) vector of k per-label terms; labelCosts(c)
%                      is added to the energy while label c is used by at
%                      least one voxel. Omit it or pass [] to disable.
%
% Outputs:
%   Y      - M x numel(Yinit) matrix; row m holds the flattened labels after
%            sweep B + m
%   sample - the last row of Y reshaped to size(Yinit), or the state after the
%            final sweep when M == 0
%
% NOTE(review): acceptance uses exp(new_energy - cur_energy), so "energy" here
% behaves like an unnormalised log-probability (higher = more likely). The
% Potts term follows that convention, but logprobs is subtracted and
% labelCosts is added. Presumably callers pass negated log-probabilities and
% negative costs. Confirm the sign convention against callers.
function [Y, sample] = MetropolisHastings(Yinit, B, M, k, beta, logprobs, neighbours_count, labelCosts)
    sz = size(Yinit);
    Yflat = Yinit(:);
    flatsz = numel(Yflat);
    % Column i lists the neighbour indices of voxel i. Entries equal to i are
    % skipped below (presumably padding; confirm against GetNeighbours).
    all_neighbours_ind = GetNeighbours(sz, neighbours_count);
    Y = zeros(M, flatsz);

    % Only voxels that start with a non-zero label are resampled. Proposals
    % are always in 1..k, so this mask never changes and is computed once.
    active = Yflat ~= 0;

    % An empty labelCosts is treated the same as omitting the argument.
    use_label_costs = nargin >= 8 && ~isempty(labelCosts);
    if use_label_costs
        labelCosts = labelCosts(:);
        % counts(c) = number of voxels currently carrying label c.
        counts = zeros(1, k);
        for c = 1:k
            counts(c) = sum(Yflat == c);
        end
    end

    for j = 1:(B + M)
        % Random visiting order over all voxels, filtered to the active ones.
        % Drawing over all flatsz voxels keeps the random stream identical to
        % the original implementation.
        permutations = randperm(flatsz);
        order = permutations(active(permutations));
        for i = reshape(order, 1, [])
            col = all_neighbours_ind(:, i);
            neighbours = col(col ~= i);
            neighbour_labels = Yflat(neighbours);

            cur_state = Yflat(i);
            new_state = randsample(k, 1);

            cur_energy = -beta * sum(neighbour_labels ~= cur_state, 1) - logprobs(cur_state, i);
            new_energy = -beta * sum(neighbour_labels ~= new_state, 1) - logprobs(new_state, i);

            if use_label_costs
                % Label counts if the proposal were accepted (unchanged when
                % new_state == cur_state).
                new_counts = counts;
                new_counts(cur_state) = new_counts(cur_state) - 1;
                new_counts(new_state) = new_counts(new_state) + 1;
                cur_energy = cur_energy + sum(labelCosts .* (counts(:) > 0));
                new_energy = new_energy + sum(labelCosts .* (new_counts(:) > 0));
            end

            % Metropolis acceptance; min(Inf, 1) == 1, so exp overflow is harmless.
            if rand() < min(exp(new_energy - cur_energy), 1)
                Yflat(i) = new_state;
                if use_label_costs
                    counts = new_counts;
                end
            end
        end
        % Record the state after every post-burn-in sweep.
        if j > B
            Y(j - B, :) = Yflat;
        end
    end

    if M > 0
        sample = reshape(Y(end, :), sz);
    else
        sample = reshape(Yflat, sz);
    end