-
Notifications
You must be signed in to change notification settings - Fork 220
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support of ReverseDiff.jl as AD backend (#428)
* Fix dep log in lad * Dont send opt res * Fix VarInfo.show bug * Fix auto tune * Change * to .* in leapfrog * temp fix type * Disable @suppress_err temporarily * Fix a dep * Workable ReverseDiff v0.1 done * Add RevDiff to REQUIRE * Fix bug in R-AD * Fix some bugs * Fix bugs * Update test * ReversedDiff.jl mutable bug fixed * Any to Real * update benchmark * Resolve mem alloc for simplex dist * Fix bug and improve mem alloc * Improve implementaion of transformations * Don't include compile time in benchk * Resolve slowness caused by use of vi.logp * Update benchmark files * Add line to load pickle * Bugfix with reject * Using ReverseDiff.jl and unsafe model as default * Fix bug in test file * Rename vi.rs to vi.rvs * Add Naive Bayes model in Turing * Add NB to travis * DA works * Tune init * Better init * NB MNIST Stan added * Improve ad assignment * Improve ad assignment * Add Stan SV model * Improve transform typing * Finish HMM model * High dim gauss done * Benchmakr v2.0 done * Modulize var estimator and fix transform.jl * Run with ForwardDiff * Enable Stan for LDA bench * Fix a bug in adapt * Improve some code * Fix bug in NUTS MH step (#324) * Add interface for optionally enabling adaption. * Do not adapt step size when numerical error is caught. * Fix initial epsilon_bar. * Fix missing t_valid. * Drop incorrectly adapted step size when necessary (#324) * Edit warning message. * Small tweaks. * reset_da ==> restart_da * address suggested naming * Samler type for WarmUpManager.paras and notation tweaks. * Bugfix and adapt_step_size == > adapt_step_size!
- Loading branch information
Showing
61 changed files
with
1,236 additions
and
373 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ julia 0.6 | |
Stan | ||
Distributions 0.11.0 | ||
ForwardDiff | ||
ReverseDiff | ||
Mamba | ||
|
||
ProgressMeter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
using HDF5, JLD, ProfileView | ||
|
||
using Turing | ||
setadbackend(:reverse_diff) | ||
turnprogress(false) | ||
|
||
include(Pkg.dir("Turing")*"/example-models/stan-models/lda-stan.data.jl") | ||
include(Pkg.dir("Turing")*"/example-models/stan-models/lda.model.jl") | ||
|
||
sample(ldamodel(data=ldastandata[1]), HMC(2, 0.025, 10)) | ||
Profile.clear() | ||
@profile sample(ldamodel(data=ldastandata[1]), HMC(2000, 0.025, 10)) | ||
|
||
ProfileView.svgwrite("ldamodel.svg") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
using HDF5, JLD | ||
|
||
const hdgdata = [Dict("D"=>100000)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
using Turing | ||
|
||
@model high_dim_gauss(D) = begin | ||
|
||
mu ~ MvNormal(zeros(D), ones(D)) | ||
|
||
# mu = Vector{Real}(D) | ||
# mu ~ [Normal(0, 1)] | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
using Turing | ||
|
||
include(Pkg.dir("Turing")*"/example-models/aistats2018/high-dim-gauss.data.jl") | ||
include(Pkg.dir("Turing")*"/example-models/aistats2018/high-dim-gauss.model.jl") | ||
|
||
turnprogress(false) | ||
|
||
mf = high_dim_gauss(data=hdgdata[1]) | ||
chn = sample(mf, HMC(1000, 0.05, 5)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
include(Pkg.dir("Turing")*"/benchmarks/benchmarkhelper.jl") | ||
using Stan, HDF5, JLD | ||
|
||
const hdgstanmodel = " | ||
data { | ||
int D; | ||
} | ||
parameters { | ||
real mu[D]; | ||
} | ||
model { | ||
for (d in 1:D) | ||
mu[d] ~ normal(0, 1); | ||
} | ||
" | ||
|
||
include(Pkg.dir("Turing")*"/example-models/aistats2018/high-dim-gauss.data.jl") | ||
|
||
hdgstan = Stanmodel(Sample(algorithm=Stan.Hmc(Stan.Static(0.25),Stan.diag_e(),0.05,0.0), save_warmup=true,adapt=Stan.Adapt(engaged=false)), num_samples=1000, num_warmup=0, thin=1, name="High_Dim_Gauss", model=hdgstanmodel, nchains=1); | ||
|
||
rc, hdg_stan_sim = stan(hdgstan, hdgdata, CmdStanDir=CMDSTAN_HOME, summary=false); | ||
|
||
hdg_time = get_stan_time("High_Dim_Gauss") | ||
println("Time used:", hdg_time) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
TPATH = Pkg.dir("Turing") | ||
include(TPATH*"/example-models/nips-2017/hmm.model.jl") | ||
|
||
|
||
using HDF5, JLD | ||
const hmm_semisup_data = load(TPATH*"/example-models/nips-2017/hmm_semisup_data.jld")["data"] | ||
|
||
|
||
# collapsed = false | ||
|
||
# S = 4 # number of samplers | ||
# spls = [Gibbs(N,PG(50,1,:y),HMC(1,0.25,6,:phi,:theta)), | ||
# Gibbs(N,PG(50,1,:y),HMCDA(1,200,0.65,0.75,:phi,:theta)), | ||
# Gibbs(N,PG(50,1,:y),NUTS(1,200,0.65,:phi,:theta)), | ||
# PG(50,N)][1:S] | ||
|
||
|
||
# spl_names = ["Gibbs($N,PG(50,1,:y),HMC(1,0.25,6,:phi,:theta))", | ||
# "Gibbs($N,PG(50,1,:y),HMCDA(1,200,0.65,0.75,:phi,:theta))", | ||
# "Gibbs($N,PG(50,1,:y),NUTS(1,200,0.65,:phi,:theta))", | ||
# "PG(50,$N)"][1:S] | ||
# for i in 1:S | ||
# println("$(spl_names[i]) running") | ||
# #if i != 1 && i != 2 # i=1 already done | ||
# chain = sample(hmm_semisup(data=hmm_semisup_data[1]), spls[i]) | ||
|
||
# save(TPATH*"/example-models/nips-2017/hmm-uncollapsed-$(spl_names[i])-chain.jld", "chain", chain) | ||
# #end | ||
# end | ||
|
||
# setadbackend(:forward_diff) | ||
# setchunksize(70) | ||
turnprogress(false) | ||
|
||
collapsed = true | ||
|
||
S = 1 # number of samplers | ||
N = 2000 | ||
spls = [HMC(N,0.005,5)][1:S] | ||
|
||
# S = 4 # number of samplers | ||
# spls = [HMC(N,0.25,6),HMCDA(N,200,0.65,0.75),NUTS(N,200,0.65),PG(50,N)][1:S] | ||
# spl_names = ["HMC($N,0.05,6)","HMCDA($N,200,0.65,0.35)","NUTS($N,200,0.65)","PG(50,$N)"][1:S] | ||
mf = hmm_semisup_collapsed(data=hmm_semisup_data[1]) | ||
for i in 1:S | ||
chain = sample(mf, spls[i]) | ||
|
||
# save(TPATH*"/example-models/nips-2017/hmm-collapsed-$(spl_names[i])-chain.jld", "chain", chain) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
include(Pkg.dir("Turing")*"/benchmarks/benchmarkhelper.jl") | ||
using Stan, HDF5, JLD | ||
|
||
const hmmstanmodel = " | ||
data { | ||
int<lower=1> K; // num categories | ||
int<lower=1> V; // num words | ||
int<lower=0> T; // num supervised items | ||
int<lower=1> T_unsup; // num unsupervised items | ||
int<lower=1,upper=V> w[T]; // words | ||
int<lower=1,upper=K> z[T]; // categories | ||
int<lower=1,upper=V> u[T_unsup]; // unsup words | ||
vector<lower=0>[K] alpha; // transit prior | ||
vector<lower=0>[V] beta; // emit prior | ||
} | ||
parameters { | ||
simplex[K] theta[K]; // transit probs | ||
simplex[V] phi[K]; // emit probs | ||
} | ||
model { | ||
for (k in 1:K) | ||
theta[k] ~ dirichlet(alpha); | ||
for (k in 1:K) | ||
phi[k] ~ dirichlet(beta); | ||
for (t in 1:T) | ||
w[t] ~ categorical(phi[z[t]]); | ||
for (t in 2:T) | ||
z[t] ~ categorical(theta[z[t-1]]); | ||
{ | ||
// forward algorithm computes log p(u|...) | ||
real acc[K]; | ||
real gamma[T_unsup,K]; | ||
for (k in 1:K) | ||
gamma[1,k] <- log(phi[k,u[1]]); | ||
for (t in 2:T_unsup) { | ||
for (k in 1:K) { | ||
for (j in 1:K) | ||
acc[j] <- gamma[t-1,j] + log(theta[j,k]) + log(phi[k,u[t]]); | ||
gamma[t,k] <- log_sum_exp(acc); | ||
} | ||
} | ||
increment_log_prob(log_sum_exp(gamma[T_unsup])); | ||
} | ||
} | ||
" | ||
|
||
const hmm_semisup_data = load(Pkg.dir("Turing")*"/example-models/nips-2017/hmm_semisup_data.jld")["data"] | ||
|
||
hmmstan = Stanmodel(Sample(algorithm=Stan.Hmc(Stan.Static(0.025),Stan.diag_e(),0.005,0.0), save_warmup=true,adapt=Stan.Adapt(engaged=false)), num_samples=1000, num_warmup=0, thin=1, name="Hidden_Markov", model=hmmstanmodel, nchains=1); | ||
|
||
rc, hmm_stan_sim = stan(hmmstan, hmm_semisup_data, CmdStanDir=CMDSTAN_HOME, summary=false); | ||
|
||
sv_time = get_stan_time("Hidden_Markov") | ||
println("Time used:", sv_time) |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
using HDF5, JLD | ||
|
||
const nbmnistdata = load(Pkg.dir("Turing")*"/example-models/aistats2018/mnist-10000-40.data")["data"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
using Turing | ||
|
||
@model nb(images, labels, C, D, N, label_mask) = begin | ||
|
||
m = Vector{Vector{Real}}(C) | ||
# @simd for c = 1:C | ||
# @inbounds m[c] ~ MvNormal(zeros(D), 10 * ones(D)) | ||
# end | ||
|
||
m ~ [MvNormal(zeros(D), 10 * ones(D))] | ||
|
||
# for c = 1:C | ||
# @simd for d = 1:D | ||
# @inbounds _lp += sum(logpdf.(Normal(m[c][d], 1), images[d, label_mask[c]])) | ||
# end | ||
# end | ||
|
||
@simd for n = 1:N | ||
@inbounds _lp += logpdf(MvNormal(zeros(D), 10 * ones(D)), images[:,n] - m[labels[n]]) | ||
end | ||
|
||
# @simd for c = 1:C | ||
# @inbounds _lp += mapreduce(d -> sum(logpdf.(Normal(m[c][d], λ), images[d, label_mask[c]])), +, 1:D) | ||
# end | ||
|
||
# @simd for n = 1:N | ||
# # @inbounds _lp += sum(logpdf.(Normal(m[l][d], λ), images[d, label_mask[l]])) | ||
# @inbounds images[:,n] ~ MvNormal(m[labels[n]], ones(D)) | ||
# end | ||
|
||
end |
Oops, something went wrong.