Skip to content

Commit

Permalink
Bugs with truncation cases
Browse files Browse the repository at this point in the history
  • Loading branch information
sparktseung committed Oct 19, 2020
1 parent cc20a33 commit f558210
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 27 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.vscode/settings.json
Manifest.toml
docs/build/
.vscode/launch.json
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterTools = "35a29f4d-8980-5a13-9543-d66fff28ecb8"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
13 changes: 10 additions & 3 deletions src/LRMoE.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module LRMoE

import Base: size, length, convert, show, getindex, rand, vec, inv, expm1
import Base: size, length, convert, show, getindex, rand, vec, inv, expm1, abs
import Base: sum, maximum, minimum, ceil, floor, extrema, +, -, *, ==
import Base: convert, copy
import Base.Math: @horner
Expand All @@ -14,7 +14,8 @@ using Distributions
import Distributions: pdf, cdf, ccdf, logpdf, logcdf, logccdf
import Distributions: rand
import Distributions: UnivariateDistribution, DiscreteUnivariateDistribution, ContinuousUnivariateDistribution
import Distributions: LogNormal, Normal, Poisson, Bernoulli
import Distributions: Bernoulli, Multinomial
import Distributions: LogNormal, Normal, Poisson

using InvertedIndices
import InvertedIndices: Not
Expand Down Expand Up @@ -94,8 +95,12 @@ export
PoissonExpert, ZIPoissonExpert,

# fitting
fit_main
fit_main,

# simulation
sim_expert,
sim_logit_gating,
sim_dataset


### source files
Expand All @@ -109,6 +114,8 @@ include("penalty.jl")

include("fit.jl")

include("simulation.jl")

# include("experts/ll/expert_ll_pos.jl")

"""
Expand Down
4 changes: 4 additions & 0 deletions src/experts/continuous/lognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ function EM_M_expert(d::LogNormalExpert,
μ_old = d.μ
σ_old = d.σ

# println("$(μ_old), $(σ_old)")

# Further E-Step
logY_e_obs = vec( int_obs_logY.(d, yl, yu, expert_ll_pos) )
logY_e_lat = vec( int_lat_logY.(d, tl, tu, expert_tn_bar_pos) )
Expand All @@ -113,5 +115,7 @@ function EM_M_expert(d::LogNormalExpert,
μ_new = sum(term_zkz_logY)[1] / sum(term_zkz)[1]
σ_new = sqrt( 1/sum(term_zkz)[1] * (sum(term_zkz_logY_sq)[1] - 2.0*μ_new*sum(term_zkz_logY)[1] + (μ_new)^2*sum(term_zkz)[1] ) )

# println("$(μ_new), $(σ_new)")

return LogNormalExpert(μ_new, σ_new)
end
3 changes: 2 additions & 1 deletion src/experts/continuous/zilognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ function EM_M_expert(d::ZILogNormalExpert,
expert_ll_pos,
expert_tn_pos,
expert_tn_bar_pos,
z_e_obs, z_e_lat, k_e,
# z_e_obs, z_e_lat, k_e,
z_pos_e_obs, z_pos_e_lat, k_e,
penalty = penalty, pen_pararms_jk = pen_pararms_jk)

return ZILogNormalExpert(p_new, tmp_update.μ, tmp_update.σ)
Expand Down
9 changes: 7 additions & 2 deletions src/fit/em.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@ function EM_E_z_lat(gate_expert_tn_bar_comp, gate_expert_tn_bar)
return tmp
end

function EM_E_k(gate_expert_tn)
return expm1.( - gate_expert_tn )
# function EM_E_k(gate_expert_tn)
# return expm1.( - gate_expert_tn )
# end

"""
    EM_E_k(gate_expert_tn_bar_k)

Apply `x -> expm1(-log1mexp(x))` elementwise to `gate_expert_tn_bar_k` and return
the result.

Assuming `log1mexp(x) == log(1 - exp(x))` (as in StatsFuns), this is algebraically
`exp(x) / (1 - exp(x))`.
NOTE(review): the argument is presumably a log-probability of falling outside the
truncation limits -- confirm against the caller.
"""
function EM_E_k(gate_expert_tn_bar_k)
    # Earlier variant, kept for reference:
    # return exp.( gate_expert_tn_bar_k )
    log_complement = log1mexp.(gate_expert_tn_bar_k)
    return expm1.(-log_complement)
end

function EM_E_z_zero_obs_update(lower, prob, ll_vec)
Expand Down
13 changes: 9 additions & 4 deletions src/fit/fit_main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ function fit_main(Y, X, α_init, model;
print_steps = true)

# Make variables accessible within the scope of `let`
let α_em, gate_em, model_em, ll_em_list, ll_em, ll_em_np, ll_em_old, ll_em_np_old, iter
let α_em, gate_em, model_em, ll_em_list, ll_em, ll_em_np, ll_em_old, ll_em_np_old, iter, z_e_obs, z_e_lat, k_e
# Initial loglik
gate_init = LogitGating(α_init, X)
ll_np_list = loglik_np(Y, gate_init, model)
Expand Down Expand Up @@ -37,7 +37,8 @@ function fit_main(Y, X, α_init, model;
# E-Step
z_e_obs = EM_E_z_obs(ll_em_list.gate_expert_ll_comp, ll_em_list.gate_expert_ll)
z_e_lat = EM_E_z_lat(ll_em_list.gate_expert_tn_bar_comp, ll_em_list.gate_expert_tn_bar)
k_e = EM_E_k(ll_em_list.gate_expert_tn)
# k_e = EM_E_k(ll_em_list.gate_expert_tn)
k_e = EM_E_k(ll_em_list.gate_expert_tn_bar_k)


# M-Step: α
Expand All @@ -55,6 +56,7 @@ function fit_main(Y, X, α_init, model;

# M-Step: component distributions
for j in 1:size(model)[1] # by dimension
# for j in 1:1
for k in 1:size(model)[2] # by component

model_em[j,k] = EM_M_expert(model_em[j,k],
Expand All @@ -69,8 +71,11 @@ function fit_main(Y, X, α_init, model;
ll_em_np = ll_em_list.ll
ll_em_penalty = penalty ? (pen_α(α_em) + penalty_params(model_em, pen_params)) : 0.0
ll_em = ll_em_np + ll_em_penalty

s = ll_em - ll_em_temp > 0 ? "+" : "-"
pct = abs((ll_em - ll_em_temp) / ll_em_temp) * 100
if print_steps
println("Iteration $(iter), updating model[$j, $k]: $(ll_em_temp) -> $(ll_em)")
println("Iteration $(iter), updating model[$j, $k]: $(ll_em_temp) -> $(ll_em), ( $(s) $(pct) % )")
end
ll_em_temp = ll_em
end
Expand All @@ -86,7 +91,7 @@ function fit_main(Y, X, α_init, model;
ll_em = ll_em_np + ll_em_penalty
end

return (α_em, model_em)
return (α_em = α_em, model_em = model_em, z_e_obs = z_e_obs, z_e_lat = z_e_lat, k_e = k_e)
end


Expand Down
8 changes: 8 additions & 0 deletions src/loglik.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,23 @@ function loglik_np(Y, gate, model)
gate_expert_ll_pos_comp = loglik_aggre_gate_dim(gate, expert_ll_pos_comp)
gate_expert_tn_pos_comp = loglik_aggre_gate_dim(gate, expert_tn_pos_comp)
gate_expert_tn_bar_pos_comp = gate + log1mexp.(expert_tn_pos_comp)
gate_expert_tn_bar_pos_comp_k = gate + expert_tn_bar_pos_comp

gate_expert_ll_comp = loglik_aggre_gate_dim(gate, expert_ll_comp)
gate_expert_tn_comp = loglik_aggre_gate_dim(gate, expert_tn_comp)
gate_expert_tn_bar_comp = gate + log1mexp.(expert_tn_comp)
gate_expert_tn_bar_comp_k = gate + expert_tn_bar_comp

# Aggregate by component
gate_expert_ll_pos = loglik_aggre_gate_dim_comp(gate_expert_ll_pos_comp)
gate_expert_tn_pos = loglik_aggre_gate_dim_comp(gate_expert_tn_pos_comp)
gate_expert_tn_bar_pos = loglik_aggre_gate_dim_comp(gate_expert_tn_bar_pos_comp)
gate_expert_tn_bar_pos_k = loglik_aggre_gate_dim_comp(gate_expert_tn_bar_pos_comp_k)

gate_expert_ll = loglik_aggre_gate_dim_comp(gate_expert_ll_comp)
gate_expert_tn = loglik_aggre_gate_dim_comp(gate_expert_tn_comp)
gate_expert_tn_bar = loglik_aggre_gate_dim_comp(gate_expert_tn_bar_comp)
gate_expert_tn_bar_k = loglik_aggre_gate_dim_comp(gate_expert_tn_bar_comp_k)

# Normalize by tn & tn_bar
norm_gate_expert_ll_pos = gate_expert_ll_pos - gate_expert_tn_pos
Expand Down Expand Up @@ -110,18 +114,22 @@ function loglik_np(Y, gate, model)
gate_expert_ll_pos_comp = gate_expert_ll_pos_comp,
gate_expert_tn_pos_comp = gate_expert_tn_pos_comp,
gate_expert_tn_bar_pos_comp = gate_expert_tn_bar_pos_comp,
gate_expert_tn_bar_pos_comp_k = gate_expert_tn_bar_pos_comp_k,

gate_expert_ll_comp = gate_expert_ll_comp,
gate_expert_tn_comp = gate_expert_tn_comp,
gate_expert_tn_bar_comp = gate_expert_tn_bar_comp,
gate_expert_tn_bar_comp_k = gate_expert_tn_bar_comp_k,

gate_expert_ll_pos = gate_expert_ll_pos,
gate_expert_tn_pos = gate_expert_tn_pos,
gate_expert_tn_bar_pos = gate_expert_tn_bar_pos,
gate_expert_tn_bar_pos_k = gate_expert_tn_bar_pos_k,

gate_expert_ll = gate_expert_ll,
gate_expert_tn = gate_expert_tn,
gate_expert_tn_bar = gate_expert_tn_bar,
gate_expert_tn_bar_k = gate_expert_tn_bar_k,

norm_gate_expert_ll_pos = norm_gate_expert_ll_pos,
norm_gate_expert_ll = norm_gate_expert_ll,
Expand Down
16 changes: 16 additions & 0 deletions src/simulation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
    sim_components(model, sample_size)

Call `sim_expert(model[j, k], sample_size)` for every entry of `model` and return a
vector with one element per row `j` of `model`. Each element is the horizontal
concatenation, over the columns `k`, of the simulated values for that row.
"""
function sim_components(model, sample_size)
    n_dim, n_comp = size(model)[1], size(model)[2]
    per_dim = map(1:n_dim) do j
        draws = [sim_expert(model[j, k], sample_size) for k in 1:n_comp]
        return hcat(draws...)
    end
    # Re-collect through a vector literal, exactly as the one-line original did.
    return [per_dim...]
end

"""
    sim_logit_gating(α, X)

Draw one single-trial `Multinomial` sample per row of `X`, using row `i` of
`exp.(LogitGating(α, X))` as the probability vector, and return the draws stacked
so that row `i` of the result is the draw for row `i` of `X`.

NOTE(review): presumably `LogitGating` returns log-probabilities that sum to one
per row after `exp` -- confirm against its definition.
"""
function sim_logit_gating(α, X)
    probs = exp.(LogitGating(α, X))
    n_obs = size(X)[1]
    draws = map(1:n_obs) do i
        return rand(Distributions.Multinomial(1, probs[i, :]))
    end
    # Each draw is a column after `hcat`; the adjoint puts observations on rows.
    return hcat(draws...)'
end

"""
    sim_dataset(α, X, model)

Simulate a dataset from `model` with gating parameters `α` and covariates `X`.

Component values are simulated first (`sim_components`), then the gating draws
(`sim_logit_gating`). For each row `j` of `model`, the gating draws are multiplied
elementwise with the component values and summed across columns; the resulting
columns are concatenated horizontally.
"""
function sim_dataset(α, X, model)
    n_obs = first(size(X))
    dim_comp_sim = sim_components(model, n_obs)
    gating_sim = sim_logit_gating(α, X)
    columns = map(1:size(model)[1]) do j
        return sum(gating_sim .* dim_comp_sim[j], dims = 2)
    end
    return hcat(columns...)
end


96 changes: 79 additions & 17 deletions test/fit_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,102 @@ using Test
using Distributions
using StatsFuns

μ = 1
σ = 2
using LRMoE

# NOTE: every statement inside this testset is commented out, so it currently
# executes nothing and registers no `@test` results. The disabled body sketches a
# `fit_main` call on LogNormal / ZILogNormal experts started from `α_true`.
@testset "fitting" begin

# d = Distributions.Gamma(1.0, 2.0)

# μμ = rand(d, 1)
# σσ = rand(d, 1)

# for (μ, σ) in zip(μμ, σσ)
# l = Distributions.LogNormal(μ, σ)
# y = rand(l, 200)

# # LogNormal
# r = LRMoE.LogNormalExpert(μ, σ)
# Y = hcat(fill(0, length(y)), y, y, fill(Inf, length(y)), fill(0, length(y)), 0.80.*y, 1.25.*y, fill(Inf, length(y)))

# model = [LRMoE.LogNormalExpert(μ, σ) LRMoE.LogNormalExpert(0.5*μ, 0.6*σ) LRMoE.LogNormalExpert(1.5*μ, 2.0*σ);
# LRMoE.LogNormalExpert(1.2*μ, σ) LRMoE.LogNormalExpert(0.8*μ, 0.6*σ) LRMoE.LogNormalExpert(1.5*μ, 0.5*σ)]


# pen_params = [[[Inf 1.0 Inf], [Inf 1.0 Inf], [Inf 1.0 Inf]],
# [[Inf 1.0 Inf], [Inf 1.0 Inf], [Inf 1.0 Inf]]]

# X = rand(Uniform(-1, 1), 200, 5)
# α_true = rand(Uniform(-1, 1), 3, 5)
# α_true[3, :] .= 0.0


# result = fit_main(Y, X, α_true, model, penalty = false, pen_params = pen_params)

# model = [LRMoE.LogNormalExpert(μ, σ) LRMoE.LogNormalExpert(0.5*μ, 0.6*σ) LRMoE.LogNormalExpert(1.5*μ, 2.0*σ);
# LRMoE.ZILogNormalExpert(0.4, μ, σ) LRMoE.LogNormalExpert(0.5*μ, 0.6*σ) LRMoE.ZILogNormalExpert(0.80, 1.5*μ, 2.0*σ)]

# result = fit_main(Y, X, α_true, model, penalty = false, pen_params = pen_params)

# end

end

@testset "fitting simulated data" begin

d = Distributions.Gamma(1.0, 2.0)

μμ = rand(d, 1)
σσ = rand(d, 1)

for (μ, σ) in zip(μμ, σσ)
l = Distributions.LogNormal(μ, σ)
y = rand(l, 200)

# LogNormal
r = LRMoE.LogNormalExpert(μ, σ)
Y = hcat(fill(0, length(y)), y, y, fill(Inf, length(y)), fill(0, length(y)), 0.80.*y, 1.25.*y, fill(Inf, length(y)))

X = rand(Uniform(-1, 1), 20000, 5)
α_true = rand(Uniform(-1, 1), 3, 5)
α_true[3, :] .= 0.0
model = [LRMoE.LogNormalExpert(μ, σ) LRMoE.LogNormalExpert(0.5*μ, 0.6*σ) LRMoE.LogNormalExpert(1.5*μ, 2.0*σ);
LRMoE.LogNormalExpert(1.2*μ, σ) LRMoE.LogNormalExpert(0.8*μ, 0.6*σ) LRMoE.LogNormalExpert(1.5*μ, 0.5*σ)]

LRMoE.ZILogNormalExpert(0.4, μ, σ) LRMoE.LogNormalExpert(0.5*μ, 0.6*σ) LRMoE.ZILogNormalExpert(0.80, 1.5*μ, 2.0*σ)]

pen_params = [[[Inf 1.0 Inf], [Inf 1.0 Inf], [Inf 1.0 Inf]],
[[Inf 1.0 Inf], [Inf 1.0 Inf], [Inf 1.0 Inf]]]
[[Inf 1.0 Inf], [Inf 1.0 Inf], [Inf 1.0 Inf]]]

X = rand(Uniform(-1, 1), 200, 5)
α_true = rand(Uniform(-1, 1), 3, 5)
α_true[3, :] .= 0.0
Y_sim = sim_dataset(α_true, X, model)
α_guess = copy(α_true)
α_guess .= 0.0
model_guess = [LRMoE.LogNormalExpert(0.8*μ, 1.2*σ) LRMoE.LogNormalExpert(μ, 0.9*σ) LRMoE.LogNormalExpert(1.0*μ, 2.5*σ);
LRMoE.ZILogNormalExpert(0.50, 2.0*μ, 1.2*σ) LRMoE.LogNormalExpert(0.75*μ, 0.3*σ) LRMoE.ZILogNormalExpert(0.50, 1.75*μ, 1.0*σ)]

# # Exact observation
# Y = hcat(fill(0, length(Y_sim[:,1])), Y_sim[:,1], Y_sim[:,1], fill(Inf, length(Y_sim[:,1])), fill(0, length(Y_sim[:,2])), Y_sim[:,2], Y_sim[:,2], fill(Inf, length(Y_sim[:,2])))
# result = fit_main(Y, X, α_guess, model_guess, penalty = false, pen_params = pen_params)

# # With censoring
# Y = hcat(fill(0, length(Y_sim[:,1])), Y_sim[:,1], Y_sim[:,1], fill(Inf, length(Y_sim[:,1])), fill(0, length(Y_sim[:,2])), 0.80.*Y_sim[:,2], 1.20.*Y_sim[:,2], fill(Inf, length(Y_sim[:,2])))
# result = fit_main(Y, X, α_guess, model_guess, penalty = false, pen_params = pen_params)

result = fit_main(Y, X, α_true, model, penalty = false, pen_params = pen_params)
# Y = hcat(fill(0, length(Y_sim[:,1])), 0.75.*Y_sim[:,1], fill(Inf, length(Y_sim[:,1])), fill(Inf, length(Y_sim[:,1])), fill(0, length(Y_sim[:,2])), 0.80.*Y_sim[:,2], 1.20.*Y_sim[:,2], fill(Inf, length(Y_sim[:,2])))
# result = fit_main(Y, X, α_guess, model_guess, penalty = false, pen_params = pen_params)

model = [LRMoE.LogNormalExpert(μ, σ) LRMoE.LogNormalExpert(0.5*μ, 0.6*σ) LRMoE.LogNormalExpert(1.5*μ, 2.0*σ);
LRMoE.ZILogNormalExpert(0.4, μ, σ) LRMoE.LogNormalExpert(0.5*μ, 0.6*σ) LRMoE.ZILogNormalExpert(0.80, 1.5*μ, 2.0*σ)]
# With truncation
# Y = hcat(fill(0.0, length(Y_sim[:,1])), Y_sim[:,1], Y_sim[:,1], fill(Inf, length(Y_sim[:,1])), 0.30.*Y_sim[:,2], 0.80.*Y_sim[:,2], 1.20.*Y_sim[:,2], fill(Inf, length(Y_sim[:,2])))
# Y = hcat( 0.80.*Y_sim[:,1], Y_sim[:,1], Y_sim[:,1], fill(Inf, length(Y_sim[:,1])), 0.30.*Y_sim[:,2], 0.80.*Y_sim[:,2], 1.20.*Y_sim[:,2], fill(Inf, length(Y_sim[:,2])))
# Y = hcat(fill(0, length(Y_sim[:,1])), Y_sim[:,1], Y_sim[:,1], fill(Inf, length(Y_sim[:,1])), fill(0.0, length(Y_sim[:,2])), 0.80.*Y_sim[:,2], 1.20.*Y_sim[:,2], 2.50.*Y_sim[:,2])
# result = fit_main(Y, X, α_guess, model_guess, penalty = false, pen_params = pen_params)

result = fit_main(Y, X, α_true, model, penalty = false, pen_params = pen_params)

# Y = hcat(fill(0, length(Y_sim[:,1])), Y_sim[:,1], Y_sim[:,1], 2.0 .*Y_sim[:,1], fill(0.0, length(Y_sim[:,2])), Y_sim[:,2], Y_sim[:,2], fill(Inf, length(Y_sim[:,2])))
# result = fit_main(Y, X, α_guess, model_guess, penalty = false, pen_params = pen_params, ecm_iter_max = 20)



# model_guess = [LRMoE.ZILogNormalExpert(0.50, 0.8*μ, 1.2*σ) LRMoE.ZILogNormalExpert(0.50, μ, 0.9*σ) LRMoE.ZILogNormalExpert(0.50, 1.0*μ, 2.5*σ);
# LRMoE.ZILogNormalExpert(0.50, 2.0*μ, 1.2*σ) LRMoE.ZILogNormalExpert(0.50, 0.75*μ, 0.3*σ) LRMoE.ZILogNormalExpert(0.50, 1.75*μ, 1.0*σ)]

# Y = hcat(fill(0, length(Y_sim[:,1])), Y_sim[:,1], Y_sim[:,1], 2.0 .*Y_sim[:,1], fill(0.0, length(Y_sim[:,2])), Y_sim[:,2], Y_sim[:,2], fill(Inf, length(Y_sim[:,2])))
# result = fit_main(Y, X, α_guess, model_guess, penalty = false, pen_params = pen_params, ecm_iter_max = 20)



end

end

0 comments on commit f558210

Please sign in to comment.