Modify turning condition for nuts #1466
    self._eliminate_starting_point = True

def _is_turning(self, r_left, r_right, r_sum):
    # We follow the strategy in Section A.4.2 of [2] for this implementation.
Would also suggest including this reference which derives the termination criterion explicitly.
In [2], the author derives the criterion too. I find it easier to understand than the differential-geometry style (though that is a nice language).
Oh, it's the same derivation. I was looking at 4.2 rather than the appendix.
pyro/infer/mcmc/nuts.py
Outdated
accepted = True
z = new_tree.z_proposal

if self._is_turning(z_left, r_left, z_right, r_right):  # stop doubling
r_sum += new_tree.r_sum
What a bug! It took me a long time to find it. We should never use shortcut (in-place) assignment such as `+=` on a tensor.
Yeah, in-place tensor ops can give non-deterministic results in many cases. Just curious: what was the cause of failure here?
@neerajprad When tree_depth = 0, we create a new_tree with 1 element. We expect r_sum = r_left + r_right, but the shortcut assignment gives r_sum = r_left. Because of this, is_turning will return True (r_sum - r_left = 0), which makes NUTS not run at all or run with just 1 velocity Verlet step.
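A minimal sketch of how such a failure can arise, assuming that in a one-element tree `r_sum` and `r_left` refer to the same object (the comment above does not spell out the mechanism, so this is one plausible reading). The `Vec` class and `make_leaf` helper are hypothetical stand-ins for a tensor and for the tree construction; they are not Pyro code.

```python
class Vec:
    """Tiny stand-in for a tensor: `+=` mutates in place, `+` allocates a new object."""

    def __init__(self, values):
        self.values = list(values)

    def __add__(self, other):
        return Vec(a + b for a, b in zip(self.values, other.values))

    def __iadd__(self, other):
        for i, b in enumerate(other.values):
            self.values[i] += b
        return self

    def __sub__(self, other):
        return Vec(a - b for a, b in zip(self.values, other.values))


def make_leaf(r):
    # Assumption: in a one-element tree the left edge and the running sum alias each other.
    return {"r_left": r, "r_sum": r}


# Shortcut assignment: the in-place update also changes r_left through the alias.
tree = make_leaf(Vec([1.0, 2.0]))
r_sum = tree["r_sum"]
r_sum += Vec([0.5, 0.5])
buggy_diff = (r_sum - tree["r_left"]).values

# Out-of-place addition: r_left keeps its original value.
tree = make_leaf(Vec([1.0, 2.0]))
r_sum = tree["r_sum"] + Vec([0.5, 0.5])
fixed_diff = (r_sum - tree["r_left"]).values

print(buggy_diff, fixed_diff)
```

With the in-place update the difference `r_sum - r_left` collapses to zero, which is the value that makes the turning check fire immediately.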
(0.02, False, False, False),
(0.02, False, True, False),
(0.1, False, False, False),
(0.1, False, True, False),
The old step_size works too, but it is slow.
pytest.mark.skipif('CI' in os.environ or 'CUDA_TEST' in os.environ,
                   reason='Slow test - skip on CI/CUDA')]
)
TEST_CASES = [
+1, this has been bugging me too for a while. 😄
yeah, me too ^^
@neerajprad From the Gaussian test, I can see that using a mass matrix makes the target easier to sample. Now we don't need many samples to pass the test.
Makefile
Outdated
@@ -5,8 +5,8 @@ all: docs test
install: FORCE
	pip install -e .[dev,profile]

uninstall: FORCE
	pip uninstall pyro-ppl
reinstall: FORCE
When do you need this? If you do an editable install via -e, it should pick up any local changes.
Whoa, I had been uninstalling and reinstalling again and again to test examples, so I made this command. Thanks @neerajprad!
Just some minor comments. Looks great otherwise!
Makefile
Outdated
@@ -5,9 +5,6 @@ all: docs test
install: FORCE
	pip install -e .[dev,profile]

Let's put the uninstall option back in.
pyro/infer/mcmc/nuts.py
Outdated
# TODO: change to torch.dot for pytorch 1.0
if self.full_mass:
    if ((r_sum - r_left_flat) * (self._inverse_mass_matrix.matmul(r_left_flat))).sum() > 0:
        if ((r_sum - r_right_flat) * (self._inverse_mass_matrix.matmul(r_right_flat))) \
nit: why not just `and` this condition?
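A rough sketch of what combining the two nested checks with `and` could look like. The quoted diff is truncated, so this assumes the trajectory counts as "not turning" only when both dot products are positive. Plain Python lists stand in for tensors, and `matvec`, `dot`, and `is_turning` are illustrative names, not the signatures used in `nuts.py`.

```python
def matvec(matrix, vector):
    """Multiply a dense matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def dot(a, b):
    """Euclidean dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def is_turning(inverse_mass_matrix, r_left, r_right, r_sum):
    """Return True when the trajectory should stop doubling.

    Assumption: doubling continues only while both dot products are positive.
    """
    # As in the quoted diff, the momentum sum excludes the edge being tested.
    rho_left = [s - l for s, l in zip(r_sum, r_left)]
    rho_right = [s - r for s, r in zip(r_sum, r_right)]
    v_left = matvec(inverse_mass_matrix, r_left)
    v_right = matvec(inverse_mass_matrix, r_right)
    # One combined condition replaces the two nested `if` statements.
    keeps_going = dot(rho_left, v_left) > 0 and dot(rho_right, v_right) > 0
    return not keeps_going


identity = [[1.0, 0.0], [0.0, 1.0]]
print(is_turning(identity, [1.0, 0.0], [1.0, 0.0], [2.0, 0.0]))   # both edges move the same way
print(is_turning(identity, [1.0, 0.0], [-1.0, 0.0], [0.0, 0.0]))  # edges point in opposite directions
```

Note that when `r_sum` equals `r_left`, both differences are zero and the function reports a turn, which matches the failure described for the shortcut assignment.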
@neerajprad Do you think that it is better to keep the old termination condition (it is not actually a bug) by using a hidden flag?
The old terminating condition with the identity mass matrix? Do you mean that it is a valid terminating condition in terms of preserving detailed balance, even though it might not generate the longest trajectories? Unless you are already observing cases where the old terminating condition is yielding better results, I would be more inclined to just remove that option. In any case, this should be equivalent to our old code with mass matrix adaptation disabled, so we can always compare against that.
Never mind, ignoring it is totally fine with me. :)
Summary: Pull Request resolved: #853

The original U-turn condition is well-defined only for Euclidean manifolds (as discussed in [Appendix A.4.2 of this paper](https://arxiv.org/pdf/1701.02434.pdf) or, equivalently, [this paper](https://arxiv.org/pdf/1304.1920.pdf)). Before introducing more complicated kinetic energy functions, it'd be better to first replace the original U-turn condition with the generalized version, i.e. terminate when either of these conditions is true: {F619168889} where {F619168973} (M^{-1} would be an identity matrix in this diff as we haven't implemented a mass matrix adaptation scheme yet) and {F619169021}

This diff is analogous to pyro-ppl/pyro#1466. I noticed that in [Pyro and Numpyro's implementation](https://github.com/pyro-ppl/pyro/blob/dev/pyro/infer/mcmc/nuts.py#L163-L172), `rho` is defined slightly differently by excluding the momenta at the boundary; this was briefly discussed in [Stan's forum](https://discourse.mc-stan.org/t/nuts-misses-u-turns-runs-in-circles-until-max-treedepth/9727/44). The post mentioned another issue with the U-turn condition, which I will address in D28735950 to make the reviewing process easier :).

Reviewed By: jpchen, neerajprad
Differential Revision: D28424431
fbshipit-source-id: 4aa477c263f2902891f4cc39a6f6d820b3692f9f
With the introduction of the mass matrix, we have to modify the turning condition. This is not straightforward, as discussed in Section A.4.2 of A Conceptual Introduction to Hamiltonian Monte Carlo. I will follow that section for the implementation (and compare it with the current implementation in Stan).
It should be fun, btw. After this, I will add the multinomial sampling option for NUTS. I hope it will not be too complicated. ^^
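For reference, the full-mass check quoted in the `nuts.py` review can be written compactly as below. This is only a sketch: it assumes doubling continues while both inequalities hold (the quoted diff is truncated), and the symbols $r_L$, $r_R$, $\rho$, and $M^{-1}$ are shorthand introduced here for `r_left`, `r_right`, `r_sum`, and the inverse mass matrix.

```latex
\text{continue doubling} \iff
(\rho - r_L)^\top M^{-1} r_L > 0
\quad\text{and}\quad
(\rho - r_R)^\top M^{-1} r_R > 0,
\qquad
\rho = \sum_{i \,\in\, \text{trajectory}} r_i .
```

With $M^{-1}$ set to the identity, the two terms reduce to plain dot products between each edge momentum and the sum of the remaining momenta.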