Support of ReverseDiff.jl as AD backend #428
Conversation
@@ -0,0 +1,14 @@
using Turing: VarEstimator, add_sample!, get_var
@yebai Unit test for var estimator is here.
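For readers outside the PR: the imports `VarEstimator`, `add_sample!`, and `get_var` suggest an online (Welford-style) variance estimator used to build the diagonal preconditioner from warm-up samples. A minimal Python sketch of that idea, with hypothetical names (this is not Turing's actual implementation):

```python
class VarEstimator:
    """Online (Welford) estimator of per-dimension mean and variance."""
    def __init__(self, dim):
        self.n = 0
        self.mean = [0.0] * dim
        self.m2 = [0.0] * dim  # running sum of squared deviations

    def add_sample(self, x):
        self.n += 1
        for i, xi in enumerate(x):
            delta = xi - self.mean[i]
            self.mean[i] += delta / self.n
            self.m2[i] += delta * (xi - self.mean[i])

    def get_var(self):
        # unbiased sample variance per dimension
        return [m / (self.n - 1) for m in self.m2]

est = VarEstimator(2)
for x in [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]:
    est.add_sample(x)
print(est.mean)      # [2.0, 20.0]
print(est.get_var()) # [1.0, 100.0]
```

The single-pass update avoids storing all warm-up samples, which matters when the warm-up phase is long.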
src/samplers/support/adapt.jl
Outdated
update_state(wum::WarmUpManager) = begin
  wum.iter_n += 1 # update iteration number
  # Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
update_pre_cond(wum::WarmUpManager, θ_new) = begin
Change `update_pre_cond` to `update_pre_cond!`, since it modifies `wum`.
Do we only support a diagonal preconditioning matrix at the moment?
@@ -10,7 +11,7 @@ end

sample_momentum(vi::VarInfo, spl::Sampler) = begin
  dprintln(2, "sampling momentum...")
-  randn(length(getranges(vi, spl))) .* spl.info[:wum][:stds]
+  randn(length(getranges(vi, spl))) ./ spl.info[:wum][:stds]
@yebai how to use pre-cond (1)
This is correct.
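A note on the algebra here (my reading, hedged): if `stds` estimates the posterior standard deviations, the diagonal metric is M = diag(1 ./ stds.^2), and drawing p ~ N(0, M) amounts to dividing a standard-normal draw by `stds`, which is what the new line does. A small pure-Python check that the scaled draw has the intended per-coordinate variance:

```python
import random

random.seed(0)
stds = [0.5, 2.0]  # assumed posterior std estimates
N = 200_000

# p = z / stds  =>  Var(p_i) = 1 / stds_i^2
samples = [[random.gauss(0.0, 1.0) / s for s in stds] for _ in range(N)]
for i, s in enumerate(stds):
    var = sum(p[i] ** 2 for p in samples) / N
    print(f"dim {i}: empirical var {var:.3f}, target {1.0 / s ** 2:.3f}")
```

Coordinates with large estimated scale get low-variance momenta, so the sampler takes proportionally longer position steps there.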
src/samplers/support/hmc_core.jl
Outdated
@@ -30,13 +35,35 @@ leapfrog(_θ::Union{Vector,SubArray}, p::Vector{Float64}, τ::Int, ϵ::Float64,

  p_old = p; θ_old = copy(θ); old_logp = getlogp(vi)

  p -= ϵ .* grad / 2
-  θ += ϵ .* p # full step for state
+  θ += ϵ .* p .* (spl.info[:wum][:stds].^2) # full step for state
@yebai how to use pre-cond (2)
There is an error here; the correct way should be `θ += ϵ .* p`. The current code effectively cancels the pre-cond (1) operation.
Why does Stan do this? I made these changes mainly by referring to https://github.com/yebai/Turing.jl/blob/1df604591e4e4297261dfcbb478ca51059ecea9d/src/samplers/support/hmc_core.jl
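To see how Stan's scheme fits together (my summary, not from the PR): with momentum drawn as p = z ./ stds, the position step θ += ϵ .* p .* stds.^2 and the kinetic energy ‖p .* stds‖²/2 compose into leapfrog under the metric M = diag(1 ./ stds.^2), which nearly conserves the Hamiltonian. A pure-Python sketch on a diagonal Gaussian target (all names hypothetical):

```python
# Diagonal Gaussian target; assume warm-up estimated its stds exactly.
sigma = [0.5, 3.0]
stds = sigma[:]

def neg_logp_grad(theta):
    nlp = sum(t * t / (2 * s * s) for t, s in zip(theta, sigma))
    grad = [-t / (s * s) for t, s in zip(theta, sigma)]  # gradient of logp
    return nlp, grad

def hamiltonian(theta, p):
    nlp, _ = neg_logp_grad(theta)
    kinetic = sum((pi * s) ** 2 for pi, s in zip(p, stds)) / 2  # p' M^-1 p / 2
    return nlp + kinetic

theta = [0.3, -1.2]
p = [zi / s for zi, s in zip([0.7, -0.4], stds)]  # pre-cond (1): p = z ./ stds
H0 = hamiltonian(theta, p)

eps, L = 0.05, 40
_, grad = neg_logp_grad(theta)
for _ in range(L):
    p = [pi + eps / 2 * gi for pi, gi in zip(p, grad)]             # half kick
    theta = [t + eps * pi * s * s                                  # pre-cond (2)
             for t, pi, s in zip(theta, p, stds)]
    _, grad = neg_logp_grad(theta)
    p = [pi + eps / 2 * gi for pi, gi in zip(p, grad)]             # half kick
H1 = hamiltonian(theta, p)
print(abs(H1 - H0))  # small: the preconditioned leapfrog nearly conserves H
```

When `stds` matches the target exactly, this is leapfrog on a unit-scale Gaussian in disguise, which is why the combination of all three scalings is well conditioned.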
src/samplers/support/adapt.jl
Outdated
type WarmUpManager
  iter_n :: Int
  state :: Int
  adapt_n :: Int
  params :: Dict
Use a more concrete type for `params :: Dict` if possible, to improve efficiency.
@@ -52,35 +79,58 @@ find_H(p::Vector, model::Function, vi::VarInfo, spl::Sampler) = begin
  # This can be a result of link/invlink (where expand! is used)
  if getlogp(vi) == 0 vi = runmodel(model, vi, spl) end

-  p_orig = p ./ spl.info[:wum][:stds]
+  p_orig = p .* spl.info[:wum][:stds]
@yebai how to use pre-cond (3)
This is correct; perhaps we can change `p_orig` to `p_prime`:

p_prime = p .* spl.info[:wum][:stds]
H = dot(p_prime, p_prime) / 2 + realpart(-getlogp(vi))
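For the record (my algebra, hedged): since momentum is drawn as p = z ./ stds with z standard normal, multiplying by `stds` here recovers z, so dot(p_prime, p_prime) / 2 is exactly the standard-normal kinetic energy p' M⁻¹ p / 2 for M = diag(1 ./ stds.^2). A tiny Python check:

```python
stds = [0.5, 2.0, 4.0]   # assumed scale estimates
z = [0.3, -1.1, 0.8]     # a standard-normal draw

p = [zi / s for zi, s in zip(z, stds)]        # momentum under M = diag(1/stds^2)
p_prime = [pi * s for pi, s in zip(p, stds)]  # undo the scaling: recovers z

kinetic = sum(x * x for x in p_prime) / 2
# identical to the unpreconditioned kinetic energy of z
assert abs(kinetic - sum(x * x for x in z) / 2) < 1e-12
print(kinetic)
```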
src/samplers/support/adapt.jl
Outdated
wum[:stds] = ones(D)
wum[:vars] = ones(D)

wum[:stds] = 1.0
Perhaps `wum[:stds] = ones(D)`?
@yebai I tagged you in a few places in the code. For adaptation in general, the code is mostly in
src/samplers/support/adapt.jl
Outdated
  spl.info[:wum] = wum
end

update_da_params(wum::WarmUpManager, ϵ::Float64) = begin
  wum[:ϵ] = [ϵ]
reset_da(wum::WarmUpManager) = begin
Perhaps change `reset_da` to `reset_da!`.
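Since `reset_da` clears the dual-averaging state between warm-up windows, here is a compact sketch of Nesterov dual averaging for step-size adaptation as described in the NUTS paper (Hoffman & Gelman); the constants and names follow the paper, not necessarily Turing's implementation:

```python
import math

class DualAveraging:
    """Adapt log(eps) so the average acceptance statistic approaches delta."""
    def __init__(self, eps0, delta=0.8, gamma=0.05, t0=10.0, kappa=0.75):
        self.mu = math.log(10.0 * eps0)  # point log(eps) is shrunk toward
        self.delta, self.gamma, self.t0, self.kappa = delta, gamma, t0, kappa
        self.reset(eps0)

    def reset(self, eps0):
        self.m = 0
        self.h_bar = 0.0            # running average of (delta - accept_stat)
        self.log_eps = math.log(eps0)
        self.log_eps_bar = 0.0      # averaged iterate, used after warm-up

    def adapt(self, accept_stat):
        self.m += 1
        eta = 1.0 / (self.m + self.t0)
        self.h_bar = (1 - eta) * self.h_bar + eta * (self.delta - accept_stat)
        self.log_eps = self.mu - math.sqrt(self.m) / self.gamma * self.h_bar
        w = self.m ** -self.kappa
        self.log_eps_bar = w * self.log_eps + (1 - w) * self.log_eps_bar
        return math.exp(self.log_eps)

da = DualAveraging(eps0=0.1)
# feed a constant low acceptance statistic -> step size shrinks
for _ in range(50):
    eps = da.adapt(0.2)
print(eps < 0.1)  # acceptance below target drives eps down
```

Making `reset` a separate method (as the PR does) lets the sampler restart adaptation after numerical errors without rebuilding the manager.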
src/samplers/support/adapt.jl
Outdated
end

  return false

end

adapt(wum::WarmUpManager, stats::Float64, θ_new) = begin
Perhaps change `adapt` to `adapt!`?
src/samplers/support/adapt.jl
Outdated
  # Dual averaging
  wum[:ϵ] = []
  reset_da(wum)
  wum[:n_warmup] = spl.alg.n_adapt
Perhaps merge `spl.alg.n_adapt` and `spl.alg.delta` into `spl.info[:adapt_conf]`?
@yebai Anything left to do in this PR?
* Fix dep log in lad
* Dont send opt res
* Fix VarInfo.show bug
* Fix auto tune
* Change * to .* in leapfrog
* temp fix type
* Disable @suppress_err temporarily
* Fix a dep
* Workable ReverseDiff v0.1 done
* Add RevDiff to REQUIRE
* Fix bug in R-AD
* Fix some bugs
* Fix bugs
* Update test
* ReversedDiff.jl mutable bug fixed
* Any to Real
* update benchmark
* Resolve mem alloc for simplex dist
* Fix bug and improve mem alloc
* Improve implementaion of transformations
* Don't include compile time in benchk
* Resolve slowness caused by use of vi.logp
* Update benchmark files
* Add line to load pickle
* Bugfix with reject
* Using ReverseDiff.jl and unsafe model as default
* Fix bug in test file
* Rename vi.rs to vi.rvs
* Add Naive Bayes model in Turing
* Add NB to travis
* DA works
* Tune init
* Better init
* NB MNIST Stan added
* Improve ad assignment
* Improve ad assignment
* Add Stan SV model
* Improve transform typing
* Finish HMM model
* High dim gauss done
* Benchmakr v2.0 done
* Modulize var estimator and fix transform.jl
* Run with ForwardDiff
* Enable Stan for LDA bench
* Fix a bug in adapt
* Improve some code
* Fix bug in NUTS MH step (#324)
* Add interface for optionally enabling adaption.
* Do not adapt step size when numerical error is caught.
* Fix initial epsilon_bar.
* Fix missing t_valid.
* Drop incorrectly adapted step size when necessary (#324)
* Edit warning message.
* Small tweaks.
* reset_da ==> restart_da
* address suggested naming
* Samler type for WarmUpManager.paras and notation tweaks.
* Bugfix and adapt_step_size ==> adapt_step_size!
TODOs