Skip to content

Commit

Permalink
full matrix prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
jaak-s committed Mar 23, 2015
1 parent 40ed2e9 commit 189639f
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 3 deletions.
9 changes: 8 additions & 1 deletion src/macau.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function macau(data::RelationData;
modes_other = map(entity -> Vector{Int64}[ find(en -> en != entity, r.entities) for r in entity.relations ],
data.entities)
if full_prediction
yhat_full = zeros(size(r.relations[1]))
yhat_full = zeros(size(data.relations[1]))
end

verbose && println("Sampling")
Expand Down Expand Up @@ -107,6 +107,10 @@ function macau(data::RelationData;
rel = data.relations[1]
probe_rat = pred(rel, rel.test_vec, rel.test_F)

if full_prediction && i >= burnin ## last burnin sample is included
yhat_full += pred_all( data.relations[1] )
end

if i > burnin
if verbose && i == burnin + 1
println("--------- Burn-in complete, averaging posterior samples ----------")
Expand Down Expand Up @@ -155,6 +159,9 @@ function macau(data::RelationData;
result["RMSE"] = rmse_avg
result["accuracy"] = err_avg
result["ROC"] = roc_avg
if full_prediction
result["predictions_full"] = yhat_full / (burnin + 1)
end
if numTest(data.relations[1]) > 0
rel = data.relations[1]
result["predictions"] = copy(rel.test_vec)
Expand Down
6 changes: 4 additions & 2 deletions src/sampling.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
export pred
using Iterators

export pred, pred_all

function pred(r::Relation, probe_vec::DataFrame, F)
if ! hasFeatures(r)
Expand Down Expand Up @@ -39,7 +41,7 @@ function pred_all(r::Relation)
if hasFeatures(r)
error("Prediction of all elements is not possible when Relation has features.")
end
udot(r) + r.model.mean_value
udot_all(r) + r.model.mean_value
end

function makeClamped(x, clamp::Vector{Float64})
Expand Down
12 changes: 12 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ assignToTest!(rd.relations[1], 2)
result = macau(rd, burnin = 10, psamples = 10, verbose = false)
@test size(result["predictions"],1) == 2

# testing pred_all
Yhat = pred_all(rd.relations[1])
@test size(Yhat) == (15, 10)
@test_approx_eq Yhat[2,3] (rd.entities[1].model.sample[2,:] * rd.entities[2].model.sample[3,:]')[1] + rd.relations[1].model.mean_value

# predict all
result1 = macau(rd, burnin = 10, psamples = 10, verbose = false, full_prediction = true)
@test size(result1["predictions_full"]) == (15, 10)
x1 = result1["predictions"][1, 1:2]
y1 = result1["predictions"][:pred][1]
@test_approx_eq result1["predictions_full"][x1[1], x1[2]] y1

# custom function on latent variables
f1(a) = length(a)
result2 = macau(rd, burnin = 5, psamples = 6, verbose = false, f = f1)
Expand Down
8 changes: 8 additions & 0 deletions test/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,11 @@ end
rd = RelationData(df)
assignToTest!(rd.relations[1], 10)
result = macau(rd, burnin=50, psamples=10, num_latent=2, verbose=false)

## checking prediction for whole matrix for tensors
Yhat = pred_all(rd.relations[1])
@test size(Yhat) == size(X)
yprod = vec(rd.entities[1].model.sample[4,:])
yprod .*= vec(rd.entities[2].model.sample[2,:])
yprod .*= vec(rd.entities[3].model.sample[1,:])
@test_approx_eq Yhat[4,2,1] sum(yprod)+rd.relations[1].model.mean_value

0 comments on commit 189639f

Please sign in to comment.