
Support of ReverseDiff.jl as AD backend #428

Merged: 60 commits into master from reverse-diff, Apr 4, 2018
Conversation

xukai92 (Member) commented on Mar 7, 2018:

TODOs

  • Allow SGLD and SGHMC to use a ReverseDiff.jl-based gradient function (see the sketch below)
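
For context, a minimal stand-alone sketch of a ReverseDiff.jl gradient call, including the compiled-tape variant one would want for repeated gradient evaluations inside SGLD/SGHMC. The log-density and all names here are hypothetical illustrations, not Turing internals:

  using ReverseDiff

  # Hypothetical target: log-density of a standard normal in D dimensions.
  logdensity(θ) = -0.5 * sum(abs2, θ)

  θ = randn(5)

  # One-shot reverse-mode gradient.
  grad = ReverseDiff.gradient(logdensity, θ)

  # For repeated calls, record and compile the tape once, then reuse it.
  tape = ReverseDiff.compile(ReverseDiff.GradientTape(logdensity, θ))
  ReverseDiff.gradient!(grad, tape, θ)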

@@ -0,0 +1,14 @@
using Turing: VarEstimator, add_sample!, get_var
xukai92 (Member, author): @yebai Unit test for the var estimator is here.

update_state(wum::WarmUpManager) = begin
wum.iter_n += 1 # update iteration number
# Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
update_pre_cond(wum::WarmUpManager, θ_new) = begin
yebai (Member): Change update_pre_cond to update_pre_cond!, since it modifies wum.
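
The convention being invoked, shown as an illustrative stand-alone example rather than code from this PR: in Julia, a trailing ! marks a function that mutates one of its arguments.

  mutable struct Counter
      n::Int
  end

  bump!(c::Counter) = (c.n += 1; c)  # mutates c in place, hence the `!`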

yebai (Member): Do we only support a diagonal preconditioning matrix at the moment?
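
For readers following along: with a diagonal metric, only a vector of scale estimates needs to be stored, and the momentum draw is elementwise. A minimal sketch under that assumption, with hypothetical names; stds estimates the posterior standard deviations, so the inverse mass matrix is diag(stds.^2):

  # Diagonal metric: p ~ N(0, M) with M = diag(1 ./ stds.^2), drawn elementwise.
  sample_momentum_diag(stds) = randn(length(stds)) ./ stds

  # A dense metric would instead need a full covariance estimate and a
  # Cholesky factor, e.g. p = cholesky(M).L * randn(length(stds)).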

@@ -10,7 +11,7 @@ end

sample_momentum(vi::VarInfo, spl::Sampler) = begin
dprintln(2, "sampling momentum...")
-randn(length(getranges(vi, spl))) .* spl.info[:wum][:stds]
+randn(length(getranges(vi, spl))) ./ spl.info[:wum][:stds]
xukai92 (Member, author): @yebai how to use pre-cond (1).

yebai (Member): This is correct.

@@ -30,13 +35,35 @@ leapfrog(_θ::Union{Vector,SubArray}, p::Vector{Float64}, τ::Int, ϵ::Float64,
p_old = p; θ_old = copy(θ); old_logp = getlogp(vi)

p -= ϵ .* grad / 2
-θ += ϵ .* p # full step for state
+θ += ϵ .* p .* (spl.info[:wum][:stds].^2) # full step for state
xukai92 (Member, author): @yebai how to use pre-cond (2).

yebai (Member) commented on Apr 3, 2018:

There is an error here; the correct way should be

  θ += ϵ .* p

The current code effectively cancels the pre-cond (1) operation.

type WarmUpManager
iter_n :: Int
state :: Int
adapt_n :: Int
params :: Dict
yebai (Member): Use a more concrete type for params :: Dict if possible, to improve efficiency.
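
An illustrative before/after for that suggestion, with hypothetical values:

  params_untyped = Dict()                  # Dict{Any,Any}: every lookup returns Any
  params_typed   = Dict{Symbol,Float64}()  # concrete value type: lookups infer as Float64
  params_typed[:ϵ] = 0.1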

@@ -52,35 +79,58 @@ find_H(p::Vector, model::Function, vi::VarInfo, spl::Sampler) = begin
# This can be a result of link/invlink (where expand! is used)
if getlogp(vi) == 0 vi = runmodel(model, vi, spl) end

-p_orig = p ./ spl.info[:wum][:stds]
+p_orig = p .* spl.info[:wum][:stds]
xukai92 (Member, author): @yebai how to use pre-cond (3).

yebai (Member) commented on Apr 3, 2018:

This is correct; perhaps we can change p_orig to p_prime:

  p_prime = p .* spl.info[:wum][:stds]
  H = dot(p_prime, p_prime) / 2 + realpart(-getlogp(vi))
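
Pulling pre-cond (1)-(3) together: in the Stan-style convention, the same inverse mass matrix M⁻¹ = diag(stds.^2) is threaded through the momentum draw, the position update, and the kinetic energy; whether the position update should also carry M⁻¹ is exactly the point debated at pre-cond (2) above. A stand-alone sketch of the all-three convention, with hypothetical names (grad is the gradient of the negative log-density, matching the diff):

  using LinearAlgebra

  sample_momentum(stds) = randn(length(stds)) ./ stds  # pre-cond (1): p ~ N(0, M)

  # pre-cond (3): kinetic energy ½ pᵀ M⁻¹ p with M⁻¹ = diag(stds.^2).
  kinetic(p, stds) = (p_prime = p .* stds; dot(p_prime, p_prime) / 2)

  function leapfrog_step(θ, p, ϵ, grad, stds)
      p -= ϵ .* grad ./ 2         # half step for momentum
      θ += ϵ .* p .* stds.^2      # pre-cond (2): full step for state uses M⁻¹
      # (recompute grad at the new θ before the second half step)
      p -= ϵ .* grad ./ 2
      return θ, p
  end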

wum[:stds] = ones(D)
wum[:vars] = ones(D)

wum[:stds] = 1.0
yebai (Member): Perhaps wum[:stds] = ones(D)?

xukai92 (Member, author) commented on Apr 2, 2018:

@yebai I tagged you in a few places in the code. For adaptation in general, the code is mostly in adapt.jl, with the corresponding Stan reference sources noted in comments. hmc_core.jl also contains the code showing how to use the pre-conditioning matrix; I left a link to the referenced Stan code at the top of that file. Please take a look and let me know if anything is unclear.


spl.info[:wum] = wum
end

update_da_params(wum::WarmUpManager, ϵ::Float64) = begin
wum[:ϵ] = [ϵ]
reset_da(wum::WarmUpManager) = begin
yebai (Member): Perhaps change reset_da to reset_da!

end

return false

end

adapt(wum::WarmUpManager, stats::Float64, θ_new) = begin
yebai (Member): Perhaps change adapt to adapt!?

# Dual averaging
wum[:ϵ] = []
reset_da(wum)
wum[:n_warmup] = spl.alg.n_adapt
yebai (Member): Perhaps merge spl.alg.n_adapt and spl.alg.delta into spl.info[:adapt_conf]?
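
A hypothetical sketch of that grouping; the struct name and fields are illustrative, not the PR's eventual API:

  struct AdaptConf
      n_adapt::Int     # number of warm-up iterations (was spl.alg.n_adapt)
      delta::Float64   # target acceptance rate (was spl.alg.delta)
  end

  info = Dict{Symbol,Any}()
  info[:adapt_conf] = AdaptConf(1000, 0.8)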

xukai92 (Member, author) commented on Apr 3, 2018:

@yebai Anything left to do in this PR?

yebai merged commit 43017ad into master on Apr 4, 2018.
yebai deleted the reverse-diff branch on Aug 18, 2018.
yebai pushed a commit referencing this pull request on Sep 18, 2018:
* Fix dep log in lad

* Don't send opt res

* Fix VarInfo.show bug

* Fix auto tune

* Change * to .* in leapfrog

* temp fix type

* Disable @suppress_err temporarily

* Fix a dep

* Workable ReverseDiff v0.1 done

* Add RevDiff to REQUIRE

* Fix bug in R-AD

* Fix some bugs

* Fix bugs

* Update test

* ReverseDiff.jl mutable bug fixed

* Any to Real

* update benchmark

* Resolve mem alloc for simplex dist

* Fix bug and improve mem alloc

* Improve implementation of transformations

* Don't include compile time in benchmark

* Resolve slowness caused by use of vi.logp

* Update benchmark files

* Add line to load pickle

* Bugfix with reject

* Using ReverseDiff.jl and unsafe model as default

* Fix bug in test file

* Rename vi.rs to vi.rvs

* Add Naive Bayes model in Turing

* Add NB to travis

* DA works

* Tune init

* Better init

* NB MNIST Stan added

* Improve ad assignment

* Improve ad assignment

* Add Stan SV model

* Improve transform typing

* Finish HMM model

* High dim gauss done

* Benchmark v2.0 done

* Modulize var estimator and fix transform.jl

* Run with ForwardDiff

* Enable Stan for LDA bench

* Fix a bug in adapt

* Improve some code

* Fix bug in NUTS MH step (#324)

* Add interface for optionally enabling adaption.

* Do not adapt step size when numerical error is caught.

* Fix initial epsilon_bar.

* Fix missing t_valid.

* Drop incorrectly adapted step size when necessary  (#324)

* Edit warning message.

* Small tweaks.

* reset_da ==> restart_da

* address suggested naming

* Sampler type for WarmUpManager.params and notation tweaks.

* Bugfix and adapt_step_size ==> adapt_step_size!