Implement fix for NUTS missing u-turns #115

Closed
sethaxen opened this issue Dec 2, 2019 · 2 comments · Fixed by #145

sethaxen commented Dec 2, 2019

There was recently some discussion on Stan Discourse that resulted in this PR in Stan and this PR in PyMC3. They change the NUTS termination criterion to handle a previously unnoticed case in which it failed to detect a U-turn. Has this been implemented here yet? (I haven't yet researched what the fix entails.)
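For reference, here is my reading of the Stan change, as a sketch only and not the DynamicHMC implementation. When two adjacent subtrees are merged, the generalized no-U-turn criterion is no longer checked only over the merged span. It is also checked for the left subtree extended by the first point of the right subtree, and for the right subtree extended by the last point of the left subtree, so a U-turn across the junction between the two subtrees is not missed. The code assumes an identity metric, so the "sharp" momentum M⁻¹p equals p; `SubtreeSummary`, `no_turn` and `merge_subtrees` are hypothetical names, not part of DynamicHMC's API.

```julia
using LinearAlgebra: dot

# Hypothetical summary of a (sub)trajectory for the U-turn check.
# Assumes an identity metric, so the sharp momentum M⁻¹p is just p.
struct SubtreeSummary
    p_first::Vector{Float64}  # momentum at the first (leftmost) point
    p_last::Vector{Float64}   # momentum at the last (rightmost) point
    ρ::Vector{Float64}        # sum of momenta over every point in the subtree
end

# Generalized no-U-turn criterion: true while both endpoint momenta still
# have a positive projection onto the summed momentum ρ (no U-turn yet).
no_turn(p_start, p_end, ρ) = dot(p_start, ρ) > 0 && dot(p_end, ρ) > 0

# Merge two adjacent subtrees, `left` preceding `right` in trajectory order.
# Returns the merged summary and whether the trajectory may keep growing.
function merge_subtrees(left::SubtreeSummary, right::SubtreeSummary)
    merged = SubtreeSummary(left.p_first, right.p_last, left.ρ .+ right.ρ)
    # check over the merged span as a whole (the pre-fix criterion)
    ok = no_turn(merged.p_first, merged.p_last, merged.ρ)
    # extra check: left subtree extended by the first point of `right`
    ok = ok && no_turn(left.p_first, right.p_first, left.ρ .+ right.p_first)
    # extra check: right subtree extended by the last point of `left`
    ok = ok && no_turn(left.p_last, right.p_last, right.ρ .+ left.p_last)
    return merged, ok
end
```

For example, with left-subtree momenta 1, 1 and right-subtree momenta -1.5, 1 (in one dimension), the merged span passes the original check (ρ = 1.5, both endpoint momenta are 1), but the first extra check fails, since -1.5 · (1 + 1 - 1.5) < 0.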

Self-contained example that demonstrates the problem

This post gave an example model that exhibits the failure. I've reproduced it here:

```julia
using DynamicHMC, LogDensityProblems, Distributions, Random
import LogDensityProblems: capabilities, dimension, logdensity, logdensity_and_gradient

# standard normal target in N dimensions
struct StdNormalProblem{N} end

StdNormalProblem(N::Int) = StdNormalProblem{N}()
(p::StdNormalProblem)(θ) = sum(logpdf.(Normal(), θ))
capabilities(::Type{<:StdNormalProblem}) = LogDensityProblems.LogDensityOrder{1}()
dimension(p::StdNormalProblem{N}) where {N} = N
logdensity(p::StdNormalProblem, θ) = p(θ)
# the gradient of the standard normal log density is -θ
logdensity_and_gradient(p::StdNormalProblem, θ) = (logdensity(p, θ), -θ)

max_depth = 12
rng = MersenneTwister(13)
# run 20 chains of 1000 draws each and count the trees that hit the depth limit
nsat = sum(1:20) do _
    results = mcmc_with_warmup(
        rng,
        StdNormalProblem(200),
        1000;
        algorithm = DynamicHMC.NUTS(max_depth = max_depth),
        reporter = NoProgressReport(),
    )
    sum(getfield.(results.tree_statistics, :depth) .≥ max_depth)
end
```

On my machine, nsat is 9, i.e. 9 of 20000 trajectories (0.045%) reached the maximum tree depth without detecting a U-turn, which is of the same order as the Stan example before the fix was merged (0.12%).
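To compare runs of different lengths against the Stan figure, the count can be turned into a rate. A minimal sketch with a hypothetical helper `saturation_rate` (not part of DynamicHMC), taking a vector of tree depths such as `getfield.(results.tree_statistics, :depth)` from the example above:

```julia
# Hypothetical helper: fraction of trajectories whose tree reached `max_depth`,
# i.e. that never detected a U-turn before hitting the depth limit.
saturation_rate(depths, max_depth) = count(≥(max_depth), depths) / length(depths)
```

Pooled over the 20 chains of 1000 draws, 9 saturated trees give 9 / 20000 = 0.00045, i.e. 0.045%.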

tpapp (owner) commented Dec 2, 2019

Thanks for bringing this up, and the very thorough issue report. It is on my radar and I plan to fix this. Feel free to ping me if I don't get to it before January.

tpapp (owner) commented Feb 11, 2021

@sethaxen: apologies that this took so long; the fix is simple, but I wanted to understand the math first. Thanks again for suggesting this.
