diff --git a/src/beanmachine/ppl/experimental/global_inference/proposer/nuts_proposer.py b/src/beanmachine/ppl/experimental/global_inference/proposer/nuts_proposer.py index 2a1f02c7ec..95ff76044c 100644 --- a/src/beanmachine/ppl/experimental/global_inference/proposer/nuts_proposer.py +++ b/src/beanmachine/ppl/experimental/global_inference/proposer/nuts_proposer.py @@ -184,6 +184,28 @@ def _combine_tree( right_tree.right.momentums, sum_momentums, ) + # More robust U-turn condition + # https://discourse.mc-stan.org/t/nuts-misses-u-turns-runs-in-circles-until-max-treedepth/9727 + if not turned_or_diverged and right_tree.num_proposals > 1: + extended_sum_momentums = { + node: left_tree.sum_momentums[node] + right_tree.left.momentums[node] + for node in sum_momentums + } + turned_or_diverged = self._is_u_turning( + left_tree.left.momentums, + right_tree.left.momentums, + extended_sum_momentums, + ) + if not turned_or_diverged and left_tree.num_proposals > 1: + extended_sum_momentums = { + node: right_tree.sum_momentums[node] + left_tree.right.momentums[node] + for node in sum_momentums + } + turned_or_diverged = self._is_u_turning( + left_tree.right.momentums, + right_tree.right.momentums, + extended_sum_momentums, + ) return _Tree( left=left_tree.left,