This repository has been archived by the owner on Dec 18, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: #849 This diff implements the multinomial sampling scheme as introduced in Appendix 2.1. of ["A Conceptual Introduction to Hamiltonian Monte Carlo" by Michael Betancourt](https://arxiv.org/abs/1701.02434), where instead of using the slice variable, we draw from a multinomial distribution over the states in the trajectory with probabilities {F618762963} This sampling mechanism is used instead of the slice sampling (as introduced in the original NUTS paper) in [Pyro](https://github.com/pyro-ppl/pyro/blob/c340831b3478ba008fdddba0b972b4275c9036a3/pyro/infer/mcmc/nuts.py#L229-L230), [Stan](https://github.com/stan-dev/stan/blob/d8c34d315f92892a9d19b96e06b196bd7640b7e5/src/stan/mcmc/hmc/nuts/base_nuts.hpp#L321-L324), and [Tensorflow Probability](https://github.com/tensorflow/probability/blob/v0.12.2/tensorflow_probability/python/mcmc/nuts.py#L873-L876). I also added a few sanity checks to verify that the tree-building functions work roughly as we expected :). Reviewed By: neerajprad Differential Revision: D28272315 fbshipit-source-id: 9d1f6a8e29765c1f284a1e1e570c0c1fd5911126
- Loading branch information
1 parent
e788754
commit 1e8fa94
Showing
4 changed files
with
132 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 76 additions & 0 deletions
76
src/beanmachine/ppl/experimental/global_inference/proposer/tests/nuts_proposer_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import beanmachine.ppl as bm | ||
import pytest | ||
import torch | ||
import torch.distributions as dist | ||
from beanmachine.ppl.experimental.global_inference.proposer.nuts_proposer import ( | ||
NUTSProposer, | ||
_Tree, | ||
_TreeArgs, | ||
_TreeNode, | ||
) | ||
from beanmachine.ppl.experimental.global_inference.simple_world import SimpleWorld | ||
|
||
|
||
@bm.random_variable | ||
def foo(): | ||
return dist.Beta(2.0, 2.0) | ||
|
||
|
||
@bm.random_variable | ||
def bar(): | ||
return dist.Bernoulli(foo()) | ||
|
||
|
||
@pytest.fixture | ||
def nuts(): | ||
world = SimpleWorld(observations={bar(): torch.tensor(0.8)}) | ||
world.call(bar()) | ||
nuts_proposer = NUTSProposer(world) | ||
return nuts_proposer | ||
|
||
|
||
@pytest.fixture | ||
def tree_node(nuts): | ||
momentums = nuts._initialize_momentums(nuts.world) | ||
return _TreeNode(world=nuts.world, momentums=momentums, pe_grad=nuts._pe_grad) | ||
|
||
|
||
@pytest.fixture | ||
def tree_args(tree_node, nuts): | ||
initial_energy = nuts._hamiltonian(nuts.world, tree_node.momentums, nuts._pe) | ||
return _TreeArgs( | ||
log_slice=-initial_energy, | ||
direction=1, | ||
step_size=nuts.step_size, | ||
initial_energy=initial_energy, | ||
) | ||
|
||
|
||
def test_base_tree(tree_node, tree_args, nuts): | ||
nuts._multinomial_sampling = False | ||
tree_args = tree_args._replace( | ||
log_slice=torch.log1p(-torch.rand(())) - tree_args.initial_energy | ||
) | ||
tree = nuts._build_tree_base_case(root=tree_node, args=tree_args) | ||
assert isinstance(tree, _Tree) | ||
assert torch.isclose(tree.log_weight, torch.tensor(float("-inf"))) or torch.isclose( | ||
tree.log_weight, torch.tensor(0.0) | ||
) | ||
assert tree.left == tree.right | ||
|
||
|
||
def test_base_tree_multinomial(tree_node, tree_args, nuts): | ||
tree = nuts._build_tree_base_case(root=tree_node, args=tree_args) | ||
assert isinstance(tree, _Tree) | ||
# in multinomial sampling, trees are weighted by their accept prob | ||
assert torch.isclose( | ||
torch.clamp(tree.log_weight.exp(), max=1.0), tree.sum_accept_prob | ||
) | ||
|
||
|
||
def test_build_tree(tree_node, tree_args, nuts): | ||
tree_depth = 3 | ||
tree = nuts._build_tree(root=tree_node, tree_depth=tree_depth, args=tree_args) | ||
assert isinstance(tree, _Tree) | ||
assert tree.turned_or_diverged or (tree.left is not tree.right) | ||
assert tree.turned_or_diverged or tree.num_proposals == 2 ** tree_depth |