From a01d015763f82b590a8355fec018ce2706982607 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Thu, 17 Oct 2019 18:20:36 +0200 Subject: [PATCH] Implement robust U-turn check (#3605) * [WIP] Robust U-turn check Following the recent discussion on the Stan side: https://github.com/stan-dev/stan/pull/2800 For experiment, do not merge. * typo fix * bug fix * Additional U turn check only when depth > 1 (to avoid redundant work). * further logic to reduce redundant U Turn check. * bug fix fix error in recording the end point of the reversed subtree * [WIP] Robust U-turn check Following the recent discussion on the Stan side: https://github.com/stan-dev/stan/pull/2800 For experiment, do not merge. * typo fix * bug fix * Additional U turn check only when depth > 1 (to avoid redundant work). * further logic to reduce redundant U Turn check. * bug fix fix error in recording the end point of the reversed subtree * Add release note. --- RELEASE-NOTES.md | 1 + pymc3/step_methods/hmc/nuts.py | 28 +++++++++++++++++++++++++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index fb2b12fa18b..3980f2c74bf 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -3,6 +3,7 @@ ## PyMC3 3.8 (on deck) ### New features +- Implemented robust U-turn check in NUTS (similar to stan-dev/stan#2800). See PR [#3605](https://github.com/pymc-devs/pymc3/pull/3605). - Add capabilities to do inference on parameters in a differential equation with `DifferentialEquation`. See [#3590](https://github.com/pymc-devs/pymc3/pull/3590). - Distinguish between `Data` and `Deterministic` variables when graphing models with graphviz. PR [#3491](https://github.com/pymc-devs/pymc3/pull/3491). - Sequential Monte Carlo - Approximate Bayesian Computation step method is now available. The implementation is in an experimental stage and will be further improved. 
diff --git a/pymc3/step_methods/hmc/nuts.py b/pymc3/step_methods/hmc/nuts.py index 01e454fb6eb..6a83394e363 100644 --- a/pymc3/step_methods/hmc/nuts.py +++ b/pymc3/step_methods/hmc/nuts.py @@ -251,10 +251,18 @@ def extend(self, direction): if direction > 0: tree, diverging, turning = self._build_subtree( self.right, self.depth, floatX(np.asarray(self.step_size))) + leftmost_begin, leftmost_end = self.left, self.right + rightmost_begin, rightmost_end = tree.left, tree.right + leftmost_p_sum = self.p_sum + rightmost_p_sum = tree.p_sum self.right = tree.right else: tree, diverging, turning = self._build_subtree( self.left, self.depth, floatX(np.asarray(-self.step_size))) + leftmost_begin, leftmost_end = tree.right, tree.left + rightmost_begin, rightmost_end = self.left, self.right + leftmost_p_sum = tree.p_sum + rightmost_p_sum = self.p_sum self.left = tree.right self.depth += 1 @@ -271,9 +279,16 @@ def extend(self, direction): self.log_size = np.logaddexp(self.log_size, tree.log_size) self.p_sum[:] += tree.p_sum - left, right = self.left, self.right - p_sum = self.p_sum - turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0) + # Additional turning check only when tree depth > 0 to avoid redundant work + if self.depth > 0: + left, right = self.left, self.right + p_sum = self.p_sum + turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0) + p_sum1 = leftmost_p_sum + rightmost_begin.p + turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (p_sum1.dot(rightmost_begin.v) <= 0) + p_sum2 = leftmost_end.p + rightmost_p_sum + turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0) + turning = (turning | turning1 | turning2) return diverging, turning @@ -324,6 +339,13 @@ def _build_subtree(self, left, depth, epsilon): if not (diverging or turning): p_sum = tree1.p_sum + tree2.p_sum turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0) + # Additional U turn check only when depth > 1 to avoid redundant work. 
+ if depth - 1 > 0: + p_sum1 = tree1.p_sum + tree2.left.p + turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (p_sum1.dot(tree2.left.v) <= 0) + p_sum2 = tree1.right.p + tree2.p_sum + turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (p_sum2.dot(tree2.right.v) <= 0) + turning = (turning | turning1 | turning2) log_size = np.logaddexp(tree1.log_size, tree2.log_size) if logbern(tree2.log_size - log_size):