Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Robust U-turn condition (#864)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #864

An issue with the U-turn condition was discovered and discussed in [this post in Stan forum](https://discourse.mc-stan.org/t/nuts-misses-u-turns-runs-in-circles-until-max-treedepth/9727)

TL;DR: we can make the U-turn condition more robust by introducing two additional checks across subtrees. This can help us avoid missing U-turns for approximately iid normal models.

{F619223264}

Since the tree combining code are almost identical in `_build_tree` and `propose`, I also take the chance to refactor them into a common function called `_combine_tree`. If you look closely you will notice that most part of `_combine_tree` are moved from existing code as-is. The only addition is the two additional call to `_is_u_turning`

Related PR that implements this change:
- Stan: stan-dev/stan#2800
- PyMC3: pymc-devs/pymc#3605
- Turing.jl: TuringLang/AdvancedHMC.jl#207
- DynamicHMC.jl: tpapp/DynamicHMC.jl#145

Reviewed By: neerajprad

Differential Revision: D28735950

fbshipit-source-id: ada4ebcad26a87ef5e697f422b5c5b17007afe42
  • Loading branch information
horizon-blue authored and facebook-github-bot committed Jun 4, 2021
1 parent 1182236 commit 518cab7
Showing 1 changed file with 22 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,28 @@ def _combine_tree(
right_tree.right.momentums,
sum_momentums,
)
# More robust U-turn condition
# https://discourse.mc-stan.org/t/nuts-misses-u-turns-runs-in-circles-until-max-treedepth/9727
if not turned_or_diverged and right_tree.num_proposals > 1:
extended_sum_momentums = {
node: left_tree.sum_momentums[node] + right_tree.left.momentums[node]
for node in sum_momentums
}
turned_or_diverged = self._is_u_turning(
left_tree.left.momentums,
right_tree.left.momentums,
extended_sum_momentums,
)
if not turned_or_diverged and left_tree.num_proposals > 1:
extended_sum_momentums = {
node: right_tree.sum_momentums[node] + left_tree.right.momentums[node]
for node in sum_momentums
}
turned_or_diverged = self._is_u_turning(
left_tree.right.momentums,
right_tree.right.momentums,
extended_sum_momentums,
)

return _Tree(
left=left_tree.left,
Expand Down

0 comments on commit 518cab7

Please sign in to comment.