Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Generalized U-turn condition (#853)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #853

The original U-turn condition is well-defined only for Euclidean manifolds (as discussed in [Appendix 4.2 of this paper](https://arxiv.org/pdf/1701.02434.pdf) or, equivalently, [this paper](https://arxiv.org/pdf/1304.1920.pdf)). Before introducing more complicated kinetic energy functions, it'd be better to first replace the original U-turn condition with the generalized version, i.e. terminate when either of these conditions is true:

{F619168889}

where

{F619168973} (M^{-1} would be an identity matrix in this diff as we haven't implemented mass matrix adaptation scheme yet)

and

{F619169021}

This diff is analogous to pyro-ppl/pyro#1466.

I noticed that in [Pyro and Numpyro's implementation](https://github.com/pyro-ppl/pyro/blob/dev/pyro/infer/mcmc/nuts.py#L163-L172), `rho` is defined slightly differently by excluding momentums at the boundary -- this was briefly discussed in [Stan's forum](https://discourse.mc-stan.org/t/nuts-misses-u-turns-runs-in-circles-until-max-treedepth/9727/44). The post mentioned another issue with the U turn condition which I will address in D28735950 to make the reviewing process easier :).

Reviewed By: jpchen, neerajprad

Differential Revision: D28424431

fbshipit-source-id: 4aa477c263f2902891f4cc39a6f6d820b3692f9f
  • Loading branch information
horizon-blue authored and facebook-github-bot committed Jun 3, 2021
1 parent 1e8fa94 commit 7117e12
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,15 @@ def _initialize_momentums(self, world: SimpleWorld) -> RVDict:
"""Randomly draw momentum from MultivariateNormal(0, I). This momentum variable
is denoted as p in [1] and r in [2]."""
return {
node: torch.randn(world.get_transformed(node).shape)
# sample (flatten) momentums
node: torch.randn((world.get_transformed(node).numel(),))
for node in world.latent_nodes
}

def _kinetic_energy(self, momentums: RVDict) -> torch.Tensor:
"""Returns the kinetic energy KE = 1/2 * p^T @ p (equation 2.6 in [1])"""
energy = torch.tensor(0.0)
for r in momentums.values():
energy += torch.sum(r ** 2)
return energy / 2
r_all = torch.cat(list(momentums.values()))
return torch.dot(r_all, r_all) / 2

def _kinetic_grads(self, momentums: RVDict) -> RVDict:
"""Returns a dictionary of gradients of kinetic energy function with respect to
Expand Down Expand Up @@ -123,20 +122,21 @@ def _leapfrog_step(

new_momentums = {}
for node, r in momentums.items():
new_momentums[node] = r - step_size * pe_grad[node] / 2
new_momentums[node] = r - step_size * pe_grad[node].flatten() / 2
ke_grad = self._kinetic_grads(new_momentums)

new_world = world.copy()
for node in world.latent_nodes:
# this should override the value of all the latent nodes in new_world
# but does not change observations and transforms
z = world.get_transformed(node)
new_world.set_transformed(
node, world.get_transformed(node) + step_size * ke_grad[node]
node, z + step_size * ke_grad[node].reshape(z.shape)
)

pe, pe_grad = self._potential_grads(new_world)
for node, r in new_momentums.items():
new_momentums[node] = r - step_size * pe_grad[node] / 2
new_momentums[node] = r - step_size * pe_grad[node].flatten() / 2

return new_world, new_momentums, pe, pe_grad

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class _Tree(NamedTuple):
pe: torch.Tensor
pe_grad: RVDict
log_weight: torch.Tensor
sum_momentums: RVDict
sum_accept_prob: torch.Tensor
num_proposals: int
turned_or_diverged: bool
Expand Down Expand Up @@ -72,16 +73,18 @@ def __init__(
self._max_delta_energy = max_delta_energy
self._multinomial_sampling = multinomial_sampling

def _is_u_turning(self, left_state: _TreeNode, right_state: _TreeNode) -> bool:
left_angle = 0.0
right_angle = 0.0
for node in left_state.world.latent_nodes:
diff = right_state.world.get_transformed(
node
) - left_state.world.get_transformed(node)
left_angle += torch.sum(diff * left_state.momentums[node])
right_angle += torch.sum(diff * right_state.momentums[node])
return bool((left_angle <= 0) or (right_angle <= 0))
def _is_u_turning(
self,
left_momentums: RVDict,
right_momentums: RVDict,
sum_momentums: RVDict,
) -> bool:
"""The generalized U-turn condition, as described in [2] Appendix 4.2"""
left_r = torch.cat(list(left_momentums.values()))
right_r = torch.cat(list(right_momentums.values()))
rho = torch.cat(list(sum_momentums.values()))

return bool((torch.dot(left_r, rho) <= 0) or (torch.dot(right_r, rho) <= 0))

def _build_tree_base_case(self, root: _TreeNode, args: _TreeArgs) -> _Tree:
"""Base case of the recursive tree building algorithm: take a single leapfrog
Expand All @@ -106,6 +109,7 @@ def _build_tree_base_case(self, root: _TreeNode, args: _TreeArgs) -> _Tree:
pe=pe,
pe_grad=pe_grad,
log_weight=log_weight,
sum_momentums=momentums,
sum_accept_prob=torch.clamp(torch.exp(-delta_energy), max=1.0),
num_proposals=1,
turned_or_diverged=bool(
Expand Down Expand Up @@ -144,17 +148,26 @@ def _build_tree(self, root: _TreeNode, tree_depth: int, args: _TreeArgs) -> _Tre

left_state = other_sub_tree.left if args.direction == -1 else sub_tree.left
right_state = sub_tree.right if args.direction == -1 else other_sub_tree.right
sum_momentums = {
node: sub_tree.sum_momentums[node] + other_sub_tree.sum_momentums[node]
for node in sub_tree.sum_momentums
}
return _Tree(
left=left_state,
right=right_state,
proposal=selected_subtree.proposal,
pe=selected_subtree.pe,
pe_grad=selected_subtree.pe_grad,
log_weight=log_weight,
sum_momentums=sum_momentums,
sum_accept_prob=sub_tree.sum_accept_prob + other_sub_tree.sum_accept_prob,
num_proposals=sub_tree.num_proposals + other_sub_tree.num_proposals,
turned_or_diverged=other_sub_tree.turned_or_diverged
or self._is_u_turning(left_state, right_state),
or self._is_u_turning(
left_state.momentums,
right_state.momentums,
sum_momentums,
),
)

def propose(self, world: Optional[SimpleWorld] = None) -> SimpleWorld:
Expand All @@ -177,6 +190,7 @@ def propose(self, world: Optional[SimpleWorld] = None) -> SimpleWorld:
log_weight = torch.tensor(0.0) # log accept prob of staying at current state
sum_accept_prob = 0.0
num_proposals = 0
sum_momentums = momentums

for j in range(self._max_tree_depth):
direction = 1 if torch.rand(()) > 0.5 else -1
Expand Down Expand Up @@ -204,8 +218,15 @@ def propose(self, world: Optional[SimpleWorld] = None) -> SimpleWorld:
tree.pe,
tree.pe_grad,
)

if self._is_u_turning(left_tree_node, right_tree_node):
sum_momentums = {
node: sum_momentums[node] + tree.sum_momentums[node]
for node in sum_momentums
}
if self._is_u_turning(
left_tree_node.momentums,
right_tree_node.momentums,
sum_momentums,
):
break

log_weight = torch.logaddexp(log_weight, tree.log_weight)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ def test_potential_grads(world, hmc):
for node in world.latent_nodes:
assert node in pe_grad
assert isinstance(pe_grad[node], torch.Tensor)
assert pe_grad[node].shape == world[node].shape
assert pe_grad[node].shape == world.get_transformed(node).shape


def test_initialize_momentums(world, hmc):
    """Momentums are drawn per latent node as flat 1-D tensors whose length
    matches the number of elements in the node's transformed value."""
    momentums = hmc._initialize_momentums(world)
    for node in world.latent_nodes:
        assert node in momentums
        assert isinstance(momentums[node], torch.Tensor)
        # Momentums are flattened on initialization, so compare against
        # numel() of the transformed value rather than its shape.
        assert len(momentums[node]) == world.get_transformed(node).numel()


def test_kinetic_grads(world, hmc):
Expand All @@ -58,7 +58,7 @@ def test_kinetic_grads(world, hmc):
for node in world.latent_nodes:
assert node in ke_grad
assert isinstance(ke_grad[node], torch.Tensor)
assert ke_grad[node].shape == world[node].shape
assert len(ke_grad[node]) == world.get_transformed(node).numel()


def test_leapfrog_step(world, hmc):
Expand Down

0 comments on commit 7117e12

Please sign in to comment.