NUTS multinomial sampling (#849)
Summary:
Pull Request resolved: #849

This diff implements the multinomial sampling scheme introduced in Appendix 2.1 of ["A Conceptual Introduction to Hamiltonian Monte Carlo" by Michael Betancourt](https://arxiv.org/abs/1701.02434), where instead of using the slice variable, we draw from a multinomial distribution over the states in the trajectory with probabilities

$$P(z \mid \mathcal{T}) = \frac{e^{-H(z)}}{\sum_{z' \in \mathcal{T}} e^{-H(z')}}, \qquad z \in \mathcal{T},$$

where $\mathcal{T}$ is the set of states in the current trajectory and $H$ is the Hamiltonian.

This sampling mechanism is used instead of slice sampling (as introduced in the original NUTS paper) in [Pyro](https://github.com/pyro-ppl/pyro/blob/c340831b3478ba008fdddba0b972b4275c9036a3/pyro/infer/mcmc/nuts.py#L229-L230), [Stan](https://github.com/stan-dev/stan/blob/d8c34d315f92892a9d19b96e06b196bd7640b7e5/src/stan/mcmc/hmc/nuts/base_nuts.hpp#L321-L324), and [TensorFlow Probability](https://github.com/tensorflow/probability/blob/v0.12.2/tensorflow_probability/python/mcmc/nuts.py#L873-L876). I also added a few sanity checks to verify that the tree-building functions behave as expected :).
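For intuition, here is a minimal self-contained sketch of the two selection schemes over a finished trajectory. It is not the committed implementation (which selects progressively while the tree is built); `energies` is a hypothetical tensor holding the Hamiltonian $H(z)$ of each state, and `initial_energy` is the Hamiltonian of the start state:

```python
import torch


def pick_state_multinomial(energies: torch.Tensor) -> int:
    # P[z] is proportional to exp(-H(z)); softmax normalizes in log space,
    # so a huge energy underflows to probability ~0 instead of producing NaN.
    probs = torch.softmax(-energies, dim=0)
    return int(torch.multinomial(probs, 1))


def pick_state_slice(energies: torch.Tensor, initial_energy: torch.Tensor) -> int:
    # Original NUTS: draw the slice variable u ~ Uniform(0, exp(-H_0)) in log
    # space, then pick uniformly among states with exp(-H(z)) >= u. The start
    # state always qualifies, so `eligible` is non-empty whenever its energy
    # is included in `energies`.
    log_u = torch.log1p(-torch.rand(())) - initial_energy
    eligible = (-energies >= log_u).nonzero().flatten()
    return int(eligible[torch.randint(len(eligible), (1,))])
```

For a single state, the multinomial weight in log space is just `-(H(z) - H_0)`, i.e. `-delta_energy`, which is exactly what the tree-building base case below assigns to `log_weight`.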

Reviewed By: neerajprad

Differential Revision: D28272315

fbshipit-source-id: 9d1f6a8e29765c1f284a1e1e570c0c1fd5911126
horizon-blue authored and facebook-github-bot committed Jun 3, 2021
1 parent e788754 commit 1e8fa94
Showing 4 changed files with 132 additions and 26 deletions.
@@ -26,6 +26,7 @@ class GlobalNoUTurnSampler(BaseInference):
max_delta_energy: float = 1000.0
initial_step_size: float = 1.0
adapt_step_size: bool = True
multinomial_sampling: bool = True

def get_proposer(self, world: SimpleWorld) -> NUTSProposer:
return NUTSProposer(world, **dataclasses.asdict(self))
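A minimal usage sketch of the new flag (assuming the dataclass fields above all carry defaults as shown; the flag is forwarded to `NUTSProposer` through `dataclasses.asdict`):

```python
# multinomial sampling is on by default
nuts = GlobalNoUTurnSampler()
# fall back to the slice sampling scheme of the original NUTS paper
nuts_slice = GlobalNoUTurnSampler(multinomial_sampling=False)
```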
@@ -22,7 +22,7 @@ class _Tree(NamedTuple):
proposal: SimpleWorld
pe: torch.Tensor
pe_grad: RVDict
weight: torch.Tensor
log_weight: torch.Tensor
sum_accept_prob: torch.Tensor
num_proposals: int
turned_or_diverged: bool
@@ -39,12 +39,17 @@ class NUTSProposer(HMCProposer):
"""
The No-U-Turn Sampler (NUTS) as described in [1]. Unlike vanilla HMC, it does not
require users to specify a trajectory length. The current implementation roughly
follows Algorithm 6 of [1].
follows Algorithm 6 of [1]. If multinomial_sampling is True, then the next state
will be drawn from a multinomial distribution (weighted by acceptance probability,
as introduced in Appendix 2 of [2]) instead of being drawn uniformly.
Reference:
[1] Matthew Hoffman and Andrew Gelman. "The No-U-Turn Sampler: Adaptively
Setting Path Lengths in Hamiltonian Monte Carlo" (2014).
https://arxiv.org/abs/1111.4246
[2] Michael Betancourt. "A Conceptual Introduction to Hamiltonian Monte Carlo"
(2017). https://arxiv.org/abs/1701.02434
"""

def __init__(
@@ -54,6 +59,7 @@ def __init__(
max_delta_energy: float = 1000.0,
initial_step_size: float = 1.0,
adapt_step_size: bool = True,
multinomial_sampling: bool = True,
):
# note that trajectory_length is not used in NUTS
super().__init__(
@@ -64,6 +70,7 @@
)
self._max_tree_depth = max_tree_depth
self._max_delta_energy = max_delta_energy
self._multinomial_sampling = multinomial_sampling

def _is_u_turning(self, left_state: _TreeNode, right_state: _TreeNode) -> bool:
left_angle = 0.0
@@ -83,20 +90,23 @@ def _build_tree_base_case(self, root: _TreeNode, args: _TreeArgs) -> _Tree:
root.world, root.momentums, args.step_size * args.direction, root.pe_grad
)
new_energy = self._hamiltonian(world, momentums, pe)
new_energy = torch.nan_to_num(new_energy, float("inf"))
# initial_energy == -L(\theta^{m-1}) + 1/2 r_0^2 in Algorithm 6 of [1]
delta_energy = torch.nan_to_num(new_energy - args.initial_energy, float("inf"))
if self._multinomial_sampling:
log_weight = -delta_energy
else:
# slice sampling as introduced in the original NUTS paper [1]
log_weight = (args.log_slice <= -new_energy).log()

tree_node = _TreeNode(world=world, momentums=momentums, pe_grad=pe_grad)
return _Tree(
left=tree_node,
right=tree_node,
proposal=world,
pe=pe,
pe_grad=pe_grad,
weight=(args.log_slice <= -new_energy).float(),
sum_accept_prob=torch.clamp(
# initial_energy == -L(\theta^{m-1}) + 1/2 r_0^2 in Algorithm 6 of [1]
torch.exp(args.initial_energy - new_energy),
max=1.0,
),
log_weight=log_weight,
sum_accept_prob=torch.clamp(torch.exp(-delta_energy), max=1.0),
num_proposals=1,
turned_or_diverged=bool(
args.log_slice >= self._max_delta_energy - new_energy
@@ -121,10 +131,13 @@ def _build_tree(self, root: _TreeNode, tree_depth: int, args: _TreeArgs) -> _Tree:
args=args,
)

# randomly choose between left/right subtree based on their weights
sum_weight = sub_tree.weight + other_sub_tree.weight
# clamp with a non-zero minimum value to avoid divide-by-zero
if torch.bernoulli(other_sub_tree.weight / torch.clamp(sum_weight, min=1e-3)):
# uniform progressive sampling (Appendix 3.1 of [2])
log_weight = torch.logaddexp(sub_tree.log_weight, other_sub_tree.log_weight)
log_tree_prob = other_sub_tree.log_weight - log_weight

# if log_tree_prob is NaN then this will evaluate to False; this can happen when
# the log weights of both trees are -inf
if torch.log1p(-torch.rand(())) <= log_tree_prob:
selected_subtree = other_sub_tree
else:
selected_subtree = sub_tree
@@ -137,7 +150,7 @@ def _build_tree(self, root: _TreeNode, tree_depth: int, args: _TreeArgs) -> _Tree:
proposal=selected_subtree.proposal,
pe=selected_subtree.pe,
pe_grad=selected_subtree.pe_grad,
weight=sum_weight,
log_weight=log_weight,
sum_accept_prob=sub_tree.sum_accept_prob + other_sub_tree.sum_accept_prob,
num_proposals=sub_tree.num_proposals + other_sub_tree.num_proposals,
turned_or_diverged=other_sub_tree.turned_or_diverged
@@ -152,12 +165,16 @@ def propose(self, world: Optional[SimpleWorld] = None) -> SimpleWorld:

momentums = self._initialize_momentums(self.world)
current_energy = self._hamiltonian(self.world, momentums, self._pe)
# this is a more stable way to sample from log(Uniform(0, exp(-current_energy)))
log_slice = torch.log1p(-torch.rand(())) - current_energy
if self._multinomial_sampling:
# log_slice is only used to check for divergence
log_slice = -current_energy
else:
# this is a more stable way to sample from log(Uniform(0, exp(-current_energy)))
log_slice = torch.log1p(-torch.rand(())) - current_energy
left_tree_node = right_tree_node = _TreeNode(
self.world, momentums, self._pe_grad
)
sum_weight = 1.0
log_weight = torch.tensor(0.0) # log accept prob of staying at current state
sum_accept_prob = 0.0
num_proposals = 0

@@ -177,8 +194,11 @@ def propose(self, world: Optional[SimpleWorld] = None) -> SimpleWorld:
if tree.turned_or_diverged:
break

# biased progressive sampling (Appendix 3.2 of [2])
log_tree_prob = tree.log_weight - log_weight

# choose the new world by randomly sampling from the proposed worlds
if torch.bernoulli(torch.clamp(tree.weight / sum_weight, max=1.0)):
if torch.log1p(-torch.rand(())) <= log_tree_prob:
self.world, self._pe, self._pe_grad = (
tree.proposal,
tree.pe,
@@ -188,7 +208,7 @@ def propose(self, world: Optional[SimpleWorld] = None) -> SimpleWorld:
if self._is_u_turning(left_tree_node, right_tree_node):
break

sum_weight += tree.weight
log_weight = torch.logaddexp(log_weight, tree.log_weight)

self._alpha = sum_accept_prob / num_proposals
return self.world
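One payoff of tracking `log_weight` instead of `weight` is numerical: the old code had to clamp `sum_weight` away from zero before dividing, whereas `logaddexp` plus a log-space Bernoulli draw needs no such guard. A standalone sketch of the idiom used in `_build_tree` and `propose` above (the names here are illustrative, not part of the diff):

```python
import torch


def log_bernoulli(log_prob: torch.Tensor) -> bool:
    # log1p(-rand()) is log(U) for U ~ Uniform(0, 1), and U <= p iff
    # log(U) <= log(p). If log_prob is NaN (both subtree log weights were
    # -inf), the comparison is False, so an unreachable subtree is never chosen.
    return bool(torch.log1p(-torch.rand(())) <= log_prob)


left, right = torch.tensor(-3.0), torch.tensor(float("-inf"))
total = torch.logaddexp(left, right)       # log(exp(-3) + 0) == -3, no underflow
take_right = log_bernoulli(right - total)  # success probability exp(-inf) == 0
```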
@@ -1,4 +1,5 @@
import beanmachine.ppl as bm
import pytest
import torch
import torch.distributions as dist
from beanmachine.ppl.experimental.global_inference.proposer.hmc_proposer import (
@@ -17,12 +18,20 @@ def bar():
return dist.Normal(foo(), 1.0)


world = SimpleWorld()
world.call(bar())
hmc = HMCProposer(world, trajectory_length=1.0)
@pytest.fixture
def world():
w = SimpleWorld()
w.call(bar())
return w


def test_potential_grads():
@pytest.fixture
def hmc(world):
hmc_proposer = HMCProposer(world, trajectory_length=1.0)
return hmc_proposer


def test_potential_grads(world, hmc):
pe, pe_grad = hmc._potential_grads(world)
assert isinstance(pe, torch.Tensor)
assert pe.numel() == 1
@@ -32,15 +41,15 @@ def test_potential_grads():
assert pe_grad[node].shape == world[node].shape


def test_initialize_momentums():
def test_initialize_momentums(world, hmc):
momentums = hmc._initialize_momentums(world)
for node in world.latent_nodes:
assert node in momentums
assert isinstance(momentums[node], torch.Tensor)
assert momentums[node].shape == world[node].shape


def test_kinetic_grads():
def test_kinetic_grads(world, hmc):
momentums = hmc._initialize_momentums(world)
ke = hmc._kinetic_energy(momentums)
assert isinstance(ke, torch.Tensor)
@@ -52,7 +61,7 @@ def test_kinetic_grads():
assert ke_grad[node].shape == world[node].shape


def test_leapfrog_step():
def test_leapfrog_step(world, hmc):
step_size = 0.0
momentums = hmc._initialize_momentums(world)
new_world, new_momentums, pe, pe_grad = hmc._leapfrog_step(
@@ -0,0 +1,76 @@
import beanmachine.ppl as bm
import pytest
import torch
import torch.distributions as dist
from beanmachine.ppl.experimental.global_inference.proposer.nuts_proposer import (
NUTSProposer,
_Tree,
_TreeArgs,
_TreeNode,
)
from beanmachine.ppl.experimental.global_inference.simple_world import SimpleWorld


@bm.random_variable
def foo():
return dist.Beta(2.0, 2.0)


@bm.random_variable
def bar():
return dist.Bernoulli(foo())


@pytest.fixture
def nuts():
world = SimpleWorld(observations={bar(): torch.tensor(0.8)})
world.call(bar())
nuts_proposer = NUTSProposer(world)
return nuts_proposer


@pytest.fixture
def tree_node(nuts):
momentums = nuts._initialize_momentums(nuts.world)
return _TreeNode(world=nuts.world, momentums=momentums, pe_grad=nuts._pe_grad)


@pytest.fixture
def tree_args(tree_node, nuts):
initial_energy = nuts._hamiltonian(nuts.world, tree_node.momentums, nuts._pe)
return _TreeArgs(
log_slice=-initial_energy,
direction=1,
step_size=nuts.step_size,
initial_energy=initial_energy,
)


def test_base_tree(tree_node, tree_args, nuts):
nuts._multinomial_sampling = False
tree_args = tree_args._replace(
log_slice=torch.log1p(-torch.rand(())) - tree_args.initial_energy
)
tree = nuts._build_tree_base_case(root=tree_node, args=tree_args)
assert isinstance(tree, _Tree)
assert torch.isclose(tree.log_weight, torch.tensor(float("-inf"))) or torch.isclose(
tree.log_weight, torch.tensor(0.0)
)
assert tree.left == tree.right


def test_base_tree_multinomial(tree_node, tree_args, nuts):
tree = nuts._build_tree_base_case(root=tree_node, args=tree_args)
assert isinstance(tree, _Tree)
# in multinomial sampling, trees are weighted by their accept prob
assert torch.isclose(
torch.clamp(tree.log_weight.exp(), max=1.0), tree.sum_accept_prob
)


def test_build_tree(tree_node, tree_args, nuts):
tree_depth = 3
tree = nuts._build_tree(root=tree_node, tree_depth=tree_depth, args=tree_args)
assert isinstance(tree, _Tree)
assert tree.turned_or_diverged or (tree.left is not tree.right)
assert tree.turned_or_diverged or tree.num_proposals == 2 ** tree_depth
