Skip to content

Commit

Permalink
Implement robust U-turn check (#3605)
Browse files Browse the repository at this point in the history
* [WIP] Robust U-turn check

Following the recent discussion on the Stan side: stan-dev/stan#2800

For experiment, do not merge.

* typo fix

* bug fix

*  Additional U turn check only when depth > 1 (to avoid redundant work).

* further logic to reduce redundant U Turn check.

* bug fix

fix error in recording the end point of the reversed subtree

* [WIP] Robust U-turn check

Following the recent discussion on the Stan side: stan-dev/stan#2800

For experiment, do not merge.

* typo fix

* bug fix

*  Additional U turn check only when depth > 1 (to avoid redundant work).

* further logic to reduce redundant U Turn check.

* bug fix

fix error in recording the end point of the reversed subtree

* Add release note.
  • Loading branch information
junpenglao authored and twiecki committed Oct 17, 2019
1 parent 530bc41 commit a01d015
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## PyMC3 3.8 (on deck)

### New features
- Implemented robust u turn check in NUTS (similar to stan-dev/stan#2800). See PR [#3605]
- Add capabilities to do inference on parameters in a differential equation with `DifferentialEquation`. See [#3590](https://github.com/pymc-devs/pymc3/pull/3590).
- Distinguish between `Data` and `Deterministic` variables when graphing models with graphviz. PR [#3491](https://github.com/pymc-devs/pymc3/pull/3491).
- Sequential Monte Carlo - Approximate Bayesian Computation step method is now available. The implementation is in an experimental stage and will be further improved.
Expand Down
28 changes: 25 additions & 3 deletions pymc3/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,18 @@ def extend(self, direction):
if direction > 0:
tree, diverging, turning = self._build_subtree(
self.right, self.depth, floatX(np.asarray(self.step_size)))
leftmost_begin, leftmost_end = self.left, self.right
rightmost_begin, rightmost_end = tree.left, tree.right
leftmost_p_sum = self.p_sum
rightmost_p_sum = tree.p_sum
self.right = tree.right
else:
tree, diverging, turning = self._build_subtree(
self.left, self.depth, floatX(np.asarray(-self.step_size)))
leftmost_begin, leftmost_end = tree.right, tree.left
rightmost_begin, rightmost_end = self.left, self.right
leftmost_p_sum = tree.p_sum
rightmost_p_sum = self.p_sum
self.left = tree.right

self.depth += 1
Expand All @@ -271,9 +279,16 @@ def extend(self, direction):
self.log_size = np.logaddexp(self.log_size, tree.log_size)
self.p_sum[:] += tree.p_sum

left, right = self.left, self.right
p_sum = self.p_sum
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
# Additional turning check only when tree depth > 0 to avoid redundant work
if self.depth > 0:
left, right = self.left, self.right
p_sum = self.p_sum
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
p_sum1 = leftmost_p_sum + rightmost_begin.p
turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (p_sum1.dot(rightmost_begin.v) <= 0)
p_sum2 = leftmost_end.p + rightmost_p_sum
turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0)
turning = (turning | turning1 | turning2)

return diverging, turning

Expand Down Expand Up @@ -324,6 +339,13 @@ def _build_subtree(self, left, depth, epsilon):
if not (diverging or turning):
p_sum = tree1.p_sum + tree2.p_sum
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
# Additional U turn check only when depth > 1 to avoid redundant work.
if depth - 1 > 0:
p_sum1 = tree1.p_sum + tree2.left.p
turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (p_sum1.dot(tree2.left.v) <= 0)
p_sum2 = tree1.right.p + tree2.p_sum
turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (p_sum2.dot(tree2.right.v) <= 0)
turning = (turning | turning1 | turning2)

log_size = np.logaddexp(tree1.log_size, tree2.log_size)
if logbern(tree2.log_size - log_size):
Expand Down

0 comments on commit a01d015

Please sign in to comment.