Support of ReverseDiff.jl as AD backend #428

Merged · 60 commits · Apr 4, 2018
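This PR adds ReverseDiff.jl as a selectable automatic-differentiation backend alongside ForwardDiff.jl. A minimal usage sketch, based on the setadbackend(:reverse_diff) calls that appear in the benchmark scripts changed below; the toy model is illustrative only and not part of this PR:

using Turing

setadbackend(:reverse_diff)    # select ReverseDiff.jl for gradient computations
# setadbackend(:forward_diff)  # or switch back to ForwardDiff.jl

# Toy model for illustration only (not from this PR).
@model gdemo(x) = begin
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    for i in 1:length(x)
        x[i] ~ Normal(m, sqrt(s))
    end
end

chn = sample(gdemo([1.5, 2.0]), HMC(1000, 0.05, 5))
describe(chn)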
Commits
b80471d
Fix dep log in lad
xukai92 Jan 18, 2018
615d152
Don't send opt res
xukai92 Jan 18, 2018
9a4ded5
Fix VarInfo.show bug
xukai92 Jan 18, 2018
c9d2bc8
Fix auto tune
xukai92 Jan 18, 2018
f077a34
Change * to .* in leapfrog
xukai92 Jan 18, 2018
4ce4f0f
temp fix type
xukai92 Jan 25, 2018
c5efcd4
Merge branch 'master' into reverse-diff
xukai92 Feb 12, 2018
ef7dd28
Disable @suppress_err temporarily
xukai92 Feb 12, 2018
acc056f
Fix a dep
xukai92 Feb 12, 2018
4157483
Workable ReverseDiff v0.1 done
xukai92 Feb 13, 2018
78174ba
Add RevDiff to REQUIRE
xukai92 Feb 13, 2018
dc6447b
Fix bug in R-AD
xukai92 Feb 13, 2018
2c49882
Fix some bugs
xukai92 Feb 13, 2018
65f56b2
Fix bugs
xukai92 Feb 13, 2018
0cc917f
Update test
xukai92 Feb 14, 2018
2cc71f1
ReverseDiff.jl mutable bug fixed
xukai92 Feb 14, 2018
25854ba
Any to Real
xukai92 Feb 15, 2018
c433175
update benchmark
xukai92 Feb 15, 2018
0d7256f
Resolve mem alloc for simplex dist
xukai92 Feb 15, 2018
b311346
Fix bug and improve mem alloc
xukai92 Feb 16, 2018
c4ad9a1
Improve implementation of transformations
xukai92 Feb 16, 2018
78b6ed6
Don't include compile time in benchmark
xukai92 Feb 18, 2018
d641f42
Resolve slowness caused by use of vi.logp
xukai92 Feb 19, 2018
3a04c42
Update benchmark files
xukai92 Feb 19, 2018
24d9a00
Add line to load pickle
xukai92 Feb 20, 2018
7b68367
Bugfix with reject
xukai92 Feb 20, 2018
c1066e5
Merge branch 'reverse-diff' of https://github.com/yebai/Turing.jl int…
xukai92 Feb 20, 2018
a68e9af
Using ReverseDiff.jl and unsafe model as default
xukai92 Feb 24, 2018
117c5eb
Fix bug in test file
xukai92 Feb 24, 2018
a75a2ad
Rename vi.rs to vi.rvs
xukai92 Feb 25, 2018
dc80595
Add Naive Bayes model in Turing
xukai92 Feb 25, 2018
336b6aa
Add NB to travis
xukai92 Feb 25, 2018
858bc31
DA works
xukai92 Feb 28, 2018
429540a
Tune init
xukai92 Feb 28, 2018
cd5c666
Better init
xukai92 Feb 28, 2018
9db337e
NB MNIST Stan added
xukai92 Feb 28, 2018
ddc251e
Improve ad assignment
xukai92 Feb 28, 2018
9598755
Improve ad assignment
xukai92 Feb 28, 2018
c061627
Add Stan SV model
xukai92 Feb 28, 2018
3fa4475
Improve transform typing
xukai92 Feb 28, 2018
091f546
Finish HMM model
xukai92 Feb 28, 2018
6404924
High dim gauss done
xukai92 Mar 1, 2018
b2b8ea7
Benchmark v2.0 done
xukai92 Mar 1, 2018
1b7ec97
Modularize var estimator and fix transform.jl
xukai92 Mar 5, 2018
6c10277
Run with ForwardDiff
xukai92 Mar 6, 2018
017e2cd
Enable Stan for LDA bench
xukai92 Mar 6, 2018
577996b
Fix a bug in adapt
xukai92 Mar 14, 2018
1df6045
Improve some code
xukai92 Mar 14, 2018
6eefadf
Fix bug in NUTS MH step (#324)
yebai Apr 3, 2018
4b5f9d3
Add interface for optionally enabling adaption.
yebai Apr 3, 2018
b65bf28
Do not adapt step size when numerical error is caught.
yebai Apr 3, 2018
ccac414
Fix initial epsilon_bar.
yebai Apr 3, 2018
b9d2b76
Fix missing t_valid.
yebai Apr 3, 2018
d47bfae
Drop incorrectly adapted step size when necessary (#324)
yebai Apr 3, 2018
0a1f8d8
Edit warning message.
yebai Apr 3, 2018
f3c88f0
Small tweaks.
yebai Apr 3, 2018
240c651
reset_da ==> restart_da
yebai Apr 3, 2018
3f7fc5d
address suggested naming
xukai92 Apr 3, 2018
bf4d53d
Sampler type for WarmUpManager.paras and notation tweaks.
yebai Apr 4, 2018
1770ae9
Bugfix and adapt_step_size ==> adapt_step_size!
yebai Apr 4, 2018
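Several of the later commits above ("Do not adapt step size when numerical error is caught", "Drop incorrectly adapted step size when necessary (#324)") concern dual-averaging step-size adaptation for HMC/NUTS. The following stand-alone sketch illustrates the general idea only; it is not the PR's actual implementation, and all names are hypothetical:

# Illustrative sketch of dual-averaging step-size adaptation (Hoffman & Gelman, 2014)
# with the guard described by the commits: skip the update for an iteration that
# hit a numerical error. Not Turing's actual code.
mutable struct DAState
    m::Int               # adaptation iteration counter
    mu::Float64          # shrinkage target, log(10 * initial step size)
    Hbar::Float64        # running average of (target accept - observed accept)
    log_eps_bar::Float64 # averaged log step size
    eps::Float64         # current step size
end
DAState(eps0) = DAState(0, log(10 * eps0), 0.0, 0.0, eps0)

function adapt_step_size!(s::DAState, accept_ratio; target = 0.65,
                          gamma = 0.05, t0 = 10.0, kappa = 0.75,
                          numerical_error::Bool = false)
    # Guard: an iteration that blew up numerically gives a meaningless
    # acceptance ratio, so leave the step size untouched.
    numerical_error && return s.eps

    s.m += 1
    eta = 1 / (s.m + t0)
    s.Hbar = (1 - eta) * s.Hbar + eta * (target - accept_ratio)
    log_eps = s.mu - sqrt(s.m) / gamma * s.Hbar
    s.log_eps_bar = s.m^(-kappa) * log_eps + (1 - s.m^(-kappa)) * s.log_eps_bar
    s.eps = exp(log_eps)
    return s.eps
end

# Example: s = DAState(0.1); adapt_step_size!(s, 0.9); adapt_step_size!(s, 0.2, numerical_error = true)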
2 changes: 2 additions & 0 deletions .travis.yml
@@ -65,6 +65,8 @@ script:
elseif ENV["GROUP"] == "SV"
include(Pkg.dir("Turing")*"/benchmarks/install_deps.jl");
include(Pkg.dir("Turing")*"/benchmarks/sv.run.jl")
elseif ENV["GROUP"] == "NB"
include(Pkg.dir("Turing")*"/example-models/aistats2018/naive-bayes.run.jl")
elseif ENV["GROUP"] == "Opt"
include(Pkg.dir("Turing")*"/benchmarks/install_deps.jl");
include(Pkg.dir("Turing")*"/benchmarks/optimization.jl")
1 change: 1 addition & 0 deletions REQUIRE
@@ -3,6 +3,7 @@ julia 0.6
Stan
Distributions 0.11.0
ForwardDiff
ReverseDiff
Mamba

ProgressMeter
6 changes: 3 additions & 3 deletions benchmarks/MoC-stan.run.jl
@@ -9,20 +9,20 @@ include(Pkg.dir("Turing")*"/example-models/stan-models/MoC-stan.model.jl")
stan_model_name = "Naive_Bayes"
nbstan = Stanmodel(Sample(algorithm=Stan.Hmc(Stan.Static(0.05),Stan.diag_e(),0.01,0.0),
save_warmup=true,adapt=Stan.Adapt(engaged=false)),
num_samples=2000, num_warmup=0, thin=1,
num_samples=5000, num_warmup=0, thin=1,
name=stan_model_name, model=naivebayesstanmodel, nchains=1);

rc, nb_stan_sim = stan(nbstan, nbstandata, CmdStanDir=CMDSTAN_HOME, summary=false);
# nb_stan_sim.names

stan_d_raw = Dict()
for i = 1:4, j = 1:10
stan_d_raw["phi[$i][$j]"] = nb_stan_sim[1001:2000, ["phi.$i.$j"], :].value[:]
stan_d_raw["phi[$i][$j]"] = nb_stan_sim[1001:end, ["phi.$i.$j"], :].value[:]
end

stan_d = Dict()
for i = 1:4
stan_d["phi[$i]"] = mean([[stan_d_raw["phi[$i][$k]"][j] for k = 1:10] for j = 1:1000])
stan_d["phi[$i]"] = mean([[stan_d_raw["phi[$i][$k]"][j] for k = 1:10] for j = 1:(nbstan.num_samples-1000)])
end

nb_time = get_stan_time(stan_model_name)
4 changes: 3 additions & 1 deletion benchmarks/MoC.run.jl
@@ -6,9 +6,11 @@ include(Pkg.dir("Turing")*"/benchmarks/benchmarkhelper.jl")
include(Pkg.dir("Turing")*"/example-models/stan-models/MoC-stan.data.jl")
include(Pkg.dir("Turing")*"/example-models/stan-models/MoC.model.jl")

setadbackend(:reverse_diff)

tbenchmark("HMC(20, 0.01, 5)", "nbmodel", "data=nbstandata[1]")

bench_res = tbenchmark("HMC(2000, 0.01, 5)", "nbmodel", "data=nbstandata[1]")
bench_res = tbenchmark("HMC(5000, 0.01, 5)", "nbmodel", "data=nbstandata[1]")
# bench_res = tbenchmark("HMCDA(1000, 0.65, 0.3)", "nbmodel", "data=nbstandata[1]")
# bench_res = tbenchmark("NUTS(2000, 0.65)", "nbmodel", "data=nbstandata[1]")
bench_res[4].names = ["phi[1]", "phi[2]", "phi[3]", "phi[4]"]
8 changes: 5 additions & 3 deletions benchmarks/benchmarkhelper.jl
@@ -1,13 +1,14 @@
# Get running time of Stan
get_stan_time(stan_model_name::String) = begin
s = readlines(pwd()*"/tmp/$(stan_model_name)_samples_1.csv")
m = match(r"(?<time>[0-9].[0-9]*)", s[end-1])
println(s[end-1])
m = match(r"(?<time>[0-9]+.[0-9]*)", s[end-1])
float(m[:time])
end

# Run benchmark
tbenchmark(alg::String, model::String, data::String) = begin
chain, time, mem, _, _ = eval(parse("@timed sample($model($data), $alg)"))
chain, time, mem, _, _ = eval(parse("model_f = $model($data); @timed sample(model_f, $alg)"))
alg, sum(chain[:elapsed]), mem, chain, deepcopy(chain)
end

@@ -18,7 +19,7 @@ build_logd(name::String, engine::String, time, mem, tchain, _) = begin
"engine" => engine,
"time" => time,
"mem" => mem,
"turing" => Dict(v => mean(tchain[Symbol(v)]) for v in keys(tchain))
"turing" => Dict(v => mean(tchain[Symbol(v)][1001:end]) for v in keys(tchain))
)
end

@@ -58,6 +59,7 @@ log2str(logd::Dict, monitor=[]) = begin
end
if haskey(logd, "stan") && haskey(logd["stan"], v)
str *= ("| -> Stan = $(round(logd["stan"][v], 3)), ")
println(m, logd["stan"][v])
diff = abs(m - logd["stan"][v])
diff_output = "diff = $(round(diff, 3))"
if sum(diff) > 0.2
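For context on the regex change in get_stan_time above: CmdStan's samples CSV ends with commented timing lines, and the second-to-last line carries the sampling time. A small stand-alone sketch; the sample footer line below is an assumption about CmdStan's output format, not taken from this PR:

# Hypothetical footer line from a CmdStan samples CSV (assumed format).
line = "#                0.456 seconds (Sampling)"

# The updated pattern requires at least one digit before the decimal point.
m = match(r"(?<time>[0-9]+.[0-9]*)", line)
t = parse(Float64, m[:time])   # => 0.456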
16 changes: 9 additions & 7 deletions benchmarks/lda-stan.run.jl
@@ -9,22 +9,24 @@ include(Pkg.dir("Turing")*"/example-models/stan-models/lda-stan.model.jl")
stan_model_name = "LDA"
# ldastan = Stanmodel(Sample(save_warmup=true), name=stan_model_name, model=ldastanmodel, nchains=1);
# To understand parameters, use: ?Stan.Static, ?Stan,Hmc
ldastan = Stanmodel(Sample(algorithm=Stan.Hmc(Stan.Static(0.25),Stan.diag_e(),0.025,0.0),
ldastan = Stanmodel(Sample(algorithm=Stan.Hmc(Stan.Static(0.05),Stan.diag_e(),0.005,0.0),
save_warmup=true,adapt=Stan.Adapt(engaged=false)),
num_samples=2000, num_warmup=0, thin=1,
num_samples=3000, num_warmup=0, thin=1,
name=stan_model_name, model=ldastanmodel, nchains=1);

rc, lda_stan_sim = stan(ldastan, ldastandata, CmdStanDir=CMDSTAN_HOME, summary=false);
# lda_stan_sim.names

V = ldastandata[1]["V"]
K = ldastandata[1]["K"]
lda_stan_d_raw = Dict()
for i = 1:2, j = 1:5
lda_stan_d_raw["phi[$i][$j]"] = lda_stan_sim[1001:2000, ["phi.$i.$j"], :].value[:]
for i = 1:K, j = 1:V
lda_stan_d_raw["phi[$i][$j]"] = lda_stan_sim[1001:end, ["phi.$i.$j"], :].value[:]
end

lda_stan_d = Dict()
for i = 1:2
lda_stan_d["phi[$i]"] = mean([[lda_stan_d_raw["phi[$i][$k]"][j] for k = 1:5] for j = 1:1000])
for i = 1:K
lda_stan_d["phi[$i]"] = mean([[lda_stan_d_raw["phi[$i][$k]"][j] for k = 1:V] for j = 1:(ldastan.num_samples-1000)])
end

lda_time = get_stan_time(stan_model_name)
println("Stan time: ", lda_time)
24 changes: 18 additions & 6 deletions benchmarks/lda.run.jl
@@ -11,14 +11,26 @@ include(Pkg.dir("Turing")*"/example-models/stan-models/lda.model.jl")

include(Pkg.dir("Turing")*"/benchmarks/"*"lda-stan.run.jl")

setchunksize(60)
# setchunksize(100)
setadbackend(:reverse_diff)
# setadbackend(:forward_diff)

#for alg in ["HMC(2000, 0.25, 10)", "HMCDA(1000, 0.65, 1.5)", "NUTS(2000, 1000, 0.65)"]
tbenchmark("HMC(20, 0.025, 10)", "ldamodel_vec", "data=ldastandata[1]") # first run for compilation
# setadsafe(false)

for (modelc, modeln) in zip(["ldamodel_vec", "ldamodel"], ["LDA-vec", "LDA"])
bench_res = tbenchmark("HMC(2000, 0.025, 10)", modelc, "data=ldastandata[1]")
bench_res[4].names = ["phi[1]", "phi[2]"]
# tbenchmark("HMC(2, 0.025, 10)", "ldamodel", "data=ldastandata[1]")

turnprogress(false)

for (modelc, modeln) in zip([
# "ldamodel_vec",
"ldamodel"
], [
# "LDA-vec",
"LDA"
])
tbenchmark("HMC(2, 0.005, 10)", modelc, "data=ldastandata[1]")
bench_res = tbenchmark("HMC(3000, 0.005, 10)", modelc, "data=ldastandata[1]")
bench_res[4].names = ["phi[$k]" for k in 1:ldastandata[1]["K"]]
logd = build_logd(modeln, bench_res...)
logd["stan"] = lda_stan_d
logd["time_stan"] = lda_time
14 changes: 14 additions & 0 deletions benchmarks/profile.jl
@@ -0,0 +1,14 @@
using HDF5, JLD, ProfileView

using Turing
setadbackend(:reverse_diff)
turnprogress(false)

include(Pkg.dir("Turing")*"/example-models/stan-models/lda-stan.data.jl")
include(Pkg.dir("Turing")*"/example-models/stan-models/lda.model.jl")

sample(ldamodel(data=ldastandata[1]), HMC(2, 0.025, 10))
Profile.clear()
@profile sample(ldamodel(data=ldastandata[1]), HMC(2000, 0.025, 10))

ProfileView.svgwrite("ldamodel.svg")
19 changes: 11 additions & 8 deletions benchmarks/sv.run.jl
@@ -9,13 +9,16 @@ include(TPATH*"/example-models/nips-2017/sv.model.jl")
using HDF5, JLD
sv_data = load(TPATH*"/example-models/nips-2017/sv-data.jld.data")["data"]

model_f = sv_model(data=sv_data[1])
sample_n = 500
# setadbackend(:forward_diff)
# setchunksize(1000)
# chain_nuts = sample(model_f, NUTS(sample_n, 0.65))
# describe(chain_nuts)

setchunksize(550)
chain_nuts = sample(model_f, NUTS(sample_n, 0.65))
describe(chain_nuts)
# setchunksize(5)
# chain_gibbs = sample(model_f, Gibbs(sample_n, PG(50,1,:h), NUTS(1000,0.65,:ϕ,:σ,:μ)))
# describe(chain_gibbs)

setchunksize(5)
chain_gibbs = sample(model_f, Gibbs(sample_n, PG(50,1,:h), NUTS(1000,0.65,:ϕ,:σ,:μ)))
describe(chain_gibbs)
turnprogress(false)

mf = sv_model(data=sv_data[1])
chain_nuts = sample(mf, HMC(2000, 0.05, 10))
3 changes: 3 additions & 0 deletions example-models/aistats2018/high-dim-gauss.data.jl
@@ -0,0 +1,3 @@
using HDF5, JLD

const hdgdata = [Dict("D"=>100000)]
10 changes: 10 additions & 0 deletions example-models/aistats2018/high-dim-gauss.model.jl
@@ -0,0 +1,10 @@
using Turing

@model high_dim_gauss(D) = begin

mu ~ MvNormal(zeros(D), ones(D))

# mu = Vector{Real}(D)
# mu ~ [Normal(0, 1)]

end
9 changes: 9 additions & 0 deletions example-models/aistats2018/high-dim-gauss.run.jl
@@ -0,0 +1,9 @@
using Turing

include(Pkg.dir("Turing")*"/example-models/aistats2018/high-dim-gauss.data.jl")
include(Pkg.dir("Turing")*"/example-models/aistats2018/high-dim-gauss.model.jl")

turnprogress(false)

mf = high_dim_gauss(data=hdgdata[1])
chn = sample(mf, HMC(1000, 0.05, 5))
24 changes: 24 additions & 0 deletions example-models/aistats2018/high-dim-gauss.stan.jl
@@ -0,0 +1,24 @@
include(Pkg.dir("Turing")*"/benchmarks/benchmarkhelper.jl")
using Stan, HDF5, JLD

const hdgstanmodel = "
data {
int D;
}
parameters {
real mu[D];
}
model {
for (d in 1:D)
mu[d] ~ normal(0, 1);
}
"

include(Pkg.dir("Turing")*"/example-models/aistats2018/high-dim-gauss.data.jl")

hdgstan = Stanmodel(Sample(algorithm=Stan.Hmc(Stan.Static(0.25),Stan.diag_e(),0.05,0.0), save_warmup=true,adapt=Stan.Adapt(engaged=false)), num_samples=1000, num_warmup=0, thin=1, name="High_Dim_Gauss", model=hdgstanmodel, nchains=1);

rc, hdg_stan_sim = stan(hdgstan, hdgdata, CmdStanDir=CMDSTAN_HOME, summary=false);

hdg_time = get_stan_time("High_Dim_Gauss")
println("Time used:", hdg_time)
49 changes: 49 additions & 0 deletions example-models/aistats2018/hmm.run.jl
@@ -0,0 +1,49 @@
TPATH = Pkg.dir("Turing")
include(TPATH*"/example-models/nips-2017/hmm.model.jl")


using HDF5, JLD
const hmm_semisup_data = load(TPATH*"/example-models/nips-2017/hmm_semisup_data.jld")["data"]


# collapsed = false

# S = 4 # number of samplers
# spls = [Gibbs(N,PG(50,1,:y),HMC(1,0.25,6,:phi,:theta)),
# Gibbs(N,PG(50,1,:y),HMCDA(1,200,0.65,0.75,:phi,:theta)),
# Gibbs(N,PG(50,1,:y),NUTS(1,200,0.65,:phi,:theta)),
# PG(50,N)][1:S]


# spl_names = ["Gibbs($N,PG(50,1,:y),HMC(1,0.25,6,:phi,:theta))",
# "Gibbs($N,PG(50,1,:y),HMCDA(1,200,0.65,0.75,:phi,:theta))",
# "Gibbs($N,PG(50,1,:y),NUTS(1,200,0.65,:phi,:theta))",
# "PG(50,$N)"][1:S]
# for i in 1:S
# println("$(spl_names[i]) running")
# #if i != 1 && i != 2 # i=1 already done
# chain = sample(hmm_semisup(data=hmm_semisup_data[1]), spls[i])

# save(TPATH*"/example-models/nips-2017/hmm-uncollapsed-$(spl_names[i])-chain.jld", "chain", chain)
# #end
# end

# setadbackend(:forward_diff)
# setchunksize(70)
turnprogress(false)

collapsed = true

S = 1 # number of samplers
N = 2000
spls = [HMC(N,0.005,5)][1:S]

# S = 4 # number of samplers
# spls = [HMC(N,0.25,6),HMCDA(N,200,0.65,0.75),NUTS(N,200,0.65),PG(50,N)][1:S]
# spl_names = ["HMC($N,0.05,6)","HMCDA($N,200,0.65,0.35)","NUTS($N,200,0.65)","PG(50,$N)"][1:S]
mf = hmm_semisup_collapsed(data=hmm_semisup_data[1])
for i in 1:S
chain = sample(mf, spls[i])

# save(TPATH*"/example-models/nips-2017/hmm-collapsed-$(spl_names[i])-chain.jld", "chain", chain)
end
55 changes: 55 additions & 0 deletions example-models/aistats2018/hmm.stan.jl
@@ -0,0 +1,55 @@
include(Pkg.dir("Turing")*"/benchmarks/benchmarkhelper.jl")
using Stan, HDF5, JLD

const hmmstanmodel = "
data {
int<lower=1> K; // num categories
int<lower=1> V; // num words
int<lower=0> T; // num supervised items
int<lower=1> T_unsup; // num unsupervised items
int<lower=1,upper=V> w[T]; // words
int<lower=1,upper=K> z[T]; // categories
int<lower=1,upper=V> u[T_unsup]; // unsup words
vector<lower=0>[K] alpha; // transit prior
vector<lower=0>[V] beta; // emit prior
}
parameters {
simplex[K] theta[K]; // transit probs
simplex[V] phi[K]; // emit probs
}
model {
for (k in 1:K)
theta[k] ~ dirichlet(alpha);
for (k in 1:K)
phi[k] ~ dirichlet(beta);
for (t in 1:T)
w[t] ~ categorical(phi[z[t]]);
for (t in 2:T)
z[t] ~ categorical(theta[z[t-1]]);

{
// forward algorithm computes log p(u|...)
real acc[K];
real gamma[T_unsup,K];
for (k in 1:K)
gamma[1,k] <- log(phi[k,u[1]]);
for (t in 2:T_unsup) {
for (k in 1:K) {
for (j in 1:K)
acc[j] <- gamma[t-1,j] + log(theta[j,k]) + log(phi[k,u[t]]);
gamma[t,k] <- log_sum_exp(acc);
}
}
increment_log_prob(log_sum_exp(gamma[T_unsup]));
}
}
"

const hmm_semisup_data = load(Pkg.dir("Turing")*"/example-models/nips-2017/hmm_semisup_data.jld")["data"]

hmmstan = Stanmodel(Sample(algorithm=Stan.Hmc(Stan.Static(0.025),Stan.diag_e(),0.005,0.0), save_warmup=true,adapt=Stan.Adapt(engaged=false)), num_samples=1000, num_warmup=0, thin=1, name="Hidden_Markov", model=hmmstanmodel, nchains=1);

rc, hmm_stan_sim = stan(hmmstan, hmm_semisup_data, CmdStanDir=CMDSTAN_HOME, summary=false);

sv_time = get_stan_time("Hidden_Markov")
println("Time used:", sv_time)
Binary file added example-models/aistats2018/mnist-10000-40.data
Binary file not shown.
3 changes: 3 additions & 0 deletions example-models/aistats2018/naive-bayes.data.jl
@@ -0,0 +1,3 @@
using HDF5, JLD

const nbmnistdata = load(Pkg.dir("Turing")*"/example-models/aistats2018/mnist-10000-40.data")["data"]
31 changes: 31 additions & 0 deletions example-models/aistats2018/naive-bayes.model.jl
@@ -0,0 +1,31 @@
using Turing

@model nb(images, labels, C, D, N, label_mask) = begin

m = Vector{Vector{Real}}(C)
# @simd for c = 1:C
# @inbounds m[c] ~ MvNormal(zeros(D), 10 * ones(D))
# end

m ~ [MvNormal(zeros(D), 10 * ones(D))]

# for c = 1:C
# @simd for d = 1:D
# @inbounds _lp += sum(logpdf.(Normal(m[c][d], 1), images[d, label_mask[c]]))
# end
# end

@simd for n = 1:N
@inbounds _lp += logpdf(MvNormal(zeros(D), 10 * ones(D)), images[:,n] - m[labels[n]])
end

# @simd for c = 1:C
# @inbounds _lp += mapreduce(d -> sum(logpdf.(Normal(m[c][d], λ), images[d, label_mask[c]])), +, 1:D)
# end

# @simd for n = 1:N
# # @inbounds _lp += sum(logpdf.(Normal(m[l][d], λ), images[d, label_mask[l]]))
# @inbounds images[:,n] ~ MvNormal(m[labels[n]], ones(D))
# end

end