Skip to content

Commit

Permalink
correct clamping clamp
Browse files Browse the repository at this point in the history
  • Loading branch information
jaak-s committed Jan 27, 2015
1 parent 6b69614 commit ea944e2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
23 changes: 10 additions & 13 deletions src/BMRF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function BMRF(data::RelationData;
roc_avg = 0.0
rmse_avg = 0.0

local probe_rat_all
local probe_rat_all, clamped_rat_all

## Gibbs sampling loop
for i in 1 : burnin + psamples
Expand Down Expand Up @@ -57,14 +57,7 @@ function BMRF(data::RelationData;
end

# clamping maybe needed for MovieLens data
local probe_rat
if isempty(clamp)
probe_rat = pred(rel.test_vec, data.entities[2].model.sample, data.entities[1].model.sample, rel.mean_rating)
else
probe_rat = pred(rel.test_vec, data.entities[2].model.sample, data.entities[1].model.sample, rel.mean_rating)
probe_rat[ probe_rat .< clamp[1] ] = clamp[1]
probe_rat[ probe_rat .> clamp[2] ] = clamp[2]
end
probe_rat = pred(rel.test_vec, data.entities[2].model.sample, data.entities[1].model.sample, rel.mean_rating)

if i > burnin
probe_rat_all = (counter_prob*probe_rat_all + probe_rat)/(counter_prob+1)
Expand All @@ -80,9 +73,13 @@ function BMRF(data::RelationData;
correct = (rel.test_label .== (probe_rat_all .< class_cut) )
err_avg = mean(correct)
err = mean(rel.test_label .== (probe_rat .< class_cut))
rmse_avg = haveTest ? sqrt(mean( (rel.test_vec[:,3] - probe_rat_all) .^ 2 )) : NaN
rmse = haveTest ? sqrt(mean( (rel.test_vec[:,3] - probe_rat) .^ 2 )) : NaN
roc_avg = haveTest ? AUC_ROC(rel.test_label, -vec(probe_rat_all)) : NaN

clamped_rat = isempty(clamp) ?probe_rat :makeClamped(probe_rat, clamp)
clamped_rat_all = isempty(clamp) ?probe_rat_all :makeClamped(probe_rat_all, clamp)

rmse_avg = haveTest ? sqrt(mean( (rel.test_vec[:,3] - clamped_rat_all) .^ 2 )) : NaN
rmse = haveTest ? sqrt(mean( (rel.test_vec[:,3] - clamped_rat) .^ 2 )) : NaN
roc_avg = haveTest ? AUC_ROC(rel.test_label, -vec(probe_rat_all)) : NaN
verbose && @printf("Iteration %d:\t avgAcc %6.4f Acc %6.4f | avgRMSE %6.4f | avgROC %6.4f | FU(%6.2f) FM(%6.2f) Fb(%6.2f) [%2.0fs]\n", i, err_avg, err, rmse_avg, roc_avg, vecnorm(data.entities[1].model.sample), vecnorm(data.entities[2].model.sample), vecnorm(data.entities[1].model.beta), time1 - time0)
end

Expand All @@ -91,6 +88,6 @@ function BMRF(data::RelationData;
result["RMSE"] = rmse_avg
result["accuracy"] = err_avg
result["ROC"] = roc_avg
result["predictions"] = probe_rat_all
result["predictions"] = clamped_rat_all
return result
end
7 changes: 7 additions & 0 deletions src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@ function pred(probe_vec, sample_m, sample_u, mean_rating)
sum(sample_m[probe_vec[:,2],:].*sample_u[probe_vec[:,1],:],2) + mean_rating
end

function makeClamped(x, clamp::Vector{Float64})
x2 = copy(x)
x2[x2 .< clamp[1]] = clamp[1]
x2[x2 .> clamp[2]] = clamp[2]
return x2
end

function ConditionalNormalWishart(U::Matrix{Float64}, mu::Vector{Float64}, kappa::Real, T::Matrix{Float64}, nu::Real)
N = size(U, 1)
= mean(U,1)
Expand Down

0 comments on commit ea944e2

Please sign in to comment.