Refactor tree combining logic (#863)
Summary:
Pull Request resolved: #863

The tree-combining logic in `_build_tree` and in the main `propose` method is almost identical, except that one weights both subtrees equally while the other is biased toward the new tree. The goal of this refactoring is to make subsequent changes to the tree-combining logic easier (e.g., a change like D28735950 no longer needs to be repeated in two places).
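To make the two schemes concrete, here is a minimal standalone sketch of both sampling rules in log space (our illustration, not the library code; `pick_new_subtree`, `old_lw`, and `new_lw` are made-up names for the subtrees' log weights):

    import torch

    def pick_new_subtree(old_lw: torch.Tensor, new_lw: torch.Tensor, biased: bool) -> bool:
        """Return True if the next state should be drawn from the new subtree."""
        if biased:
            # biased: P(new) = min(1, w_new / w_old), favoring states far from the start
            log_prob_new = new_lw - old_lw
        else:
            # uniform: P(new) = w_new / (w_old + w_new)
            log_prob_new = new_lw - torch.logaddexp(old_lw, new_lw)
        # draw log(Uniform(0, 1)) and compare in log space to avoid underflow
        return bool(torch.rand(()).log() < log_prob_new)

    # with equal weights, uniform sampling picks the new subtree half the time
    print(pick_new_subtree(torch.tensor(0.0), torch.tensor(0.0), biased=False))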

The refactored `_combine_tree` method is roughly analogous to the [`_combine_tree` function in NumPyro](https://github.com/pyro-ppl/numpyro/blob/master/numpyro/infer/hmc_util.py#L762-L843).

Other than the refactoring, this diff makes no change to the algorithm.

Reviewed By: neerajprad

Differential Revision: D28817142

fbshipit-source-id: 87814af96d28506ea4fd5ae55ec949f3dfc296fb
horizon-blue authored and facebook-github-bot committed Jun 4, 2021
1 parent 7117e12 commit 1182236
Showing 1 changed file with 68 additions and 60 deletions.
@@ -135,39 +135,67 @@ def _build_tree(self, root: _TreeNode, tree_depth: int, args: _TreeArgs) -> _Tree:
             args=args,
         )
 
-        # uniform progressive sampling (Appendix 3.1 of [2])
-        log_weight = torch.logaddexp(sub_tree.log_weight, other_sub_tree.log_weight)
-        log_tree_prob = other_sub_tree.log_weight - log_weight
-
-        # if log_tree_prob is NaN then this will evaluate to False; this can happen
-        # when the log weights of both trees are -inf
-        if torch.log1p(-torch.rand(())) <= log_tree_prob:
-            selected_subtree = other_sub_tree
-        else:
-            selected_subtree = sub_tree
+        return self._combine_tree(
+            sub_tree, other_sub_tree, args.direction, biased=False
+        )
+
+    def _combine_tree(
+        self, old_tree: _Tree, new_tree: _Tree, direction: int, biased: bool
+    ) -> _Tree:
+        """Combine the old tree and the new tree into a single (larger) tree. The
+        new tree is added to the left of the old tree if direction is -1; otherwise
+        it is added to the right. If biased is True, we prefer sampling the next
+        state from the new tree (which is farther from the starting location) over
+        the old tree. This function assumes old_tree has not turned or diverged."""
+        # if the old tree had turned or diverged, we shouldn't have built the new
+        # tree in the first place
+        assert not old_tree.turned_or_diverged
+        # log of the sum of the weights from both trees
+        log_weight = torch.logaddexp(old_tree.log_weight, new_tree.log_weight)
+
+        if new_tree.turned_or_diverged:
+            selected_subtree = old_tree
+        else:
+            # progressively sample from the trajectory
+            if biased:
+                # biased progressive sampling (Appendix 3.2 of [2])
+                log_tree_prob = new_tree.log_weight - old_tree.log_weight
+            else:
+                # uniform progressive sampling (Appendix 3.1 of [2])
+                log_tree_prob = new_tree.log_weight - log_weight
+
+            # a NaN log_tree_prob (e.g. when both log weights are -inf) makes this
+            # comparison False, so we fall back to the old tree
+            if torch.rand_like(log_tree_prob).log() < log_tree_prob:
+                selected_subtree = new_tree
+            else:
+                selected_subtree = old_tree
+
+        if direction == -1:
+            left_tree, right_tree = new_tree, old_tree
+        else:
+            left_tree, right_tree = old_tree, new_tree
 
-        left_state = other_sub_tree.left if args.direction == -1 else sub_tree.left
-        right_state = sub_tree.right if args.direction == -1 else other_sub_tree.right
         sum_momentums = {
-            node: sub_tree.sum_momentums[node] + other_sub_tree.sum_momentums[node]
-            for node in sub_tree.sum_momentums
+            node: left_tree.sum_momentums[node] + right_tree.sum_momentums[node]
+            for node in left_tree.sum_momentums
         }
+        turned_or_diverged = new_tree.turned_or_diverged or self._is_u_turning(
+            left_tree.left.momentums,
+            right_tree.right.momentums,
+            sum_momentums,
+        )
 
         return _Tree(
-            left=left_state,
-            right=right_state,
+            left=left_tree.left,
+            right=right_tree.right,
             proposal=selected_subtree.proposal,
             pe=selected_subtree.pe,
             pe_grad=selected_subtree.pe_grad,
             log_weight=log_weight,
             sum_momentums=sum_momentums,
-            sum_accept_prob=sub_tree.sum_accept_prob + other_sub_tree.sum_accept_prob,
-            num_proposals=sub_tree.num_proposals + other_sub_tree.num_proposals,
-            turned_or_diverged=other_sub_tree.turned_or_diverged
-            or self._is_u_turning(
-                left_state.momentums,
-                right_state.momentums,
-                sum_momentums,
-            ),
+            sum_accept_prob=old_tree.sum_accept_prob + new_tree.sum_accept_prob,
+            num_proposals=old_tree.num_proposals + new_tree.num_proposals,
+            turned_or_diverged=turned_or_diverged,
         )
 
     def propose(self, world: Optional[SimpleWorld] = None) -> SimpleWorld:
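A side note on the two ways of drawing `log(Uniform(0, 1))` that appear above: `torch.rand` samples from `[0, 1)`, so `torch.rand(...).log()` can in principle return `-inf`, whereas `torch.log1p(-torch.rand(...))` is always finite because `1 - u` lies in `(0, 1]`. A quick self-contained check (variable names are ours):

    import torch

    torch.manual_seed(0)
    u = torch.rand(100_000)

    log_u = u.log()                  # can hit log(0) = -inf, since u may be exactly 0
    log_u_stable = torch.log1p(-u)   # 1 - u lies in (0, 1], so this is always finite

    assert torch.isfinite(log_u_stable).all()
    # both are draws of log(Uniform(0, 1)), whose mean is -1
    print(log_u.mean(), log_u_stable.mean())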
@@ -184,52 +212,32 @@ def propose(self, world: Optional[SimpleWorld] = None) -> SimpleWorld:
         else:
             # this is a more stable way to sample from log(Uniform(0, exp(-current_energy)))
             log_slice = torch.log1p(-torch.rand(())) - current_energy
-        left_tree_node = right_tree_node = _TreeNode(
-            self.world, momentums, self._pe_grad
+        tree_node = _TreeNode(self.world, momentums, self._pe_grad)
+        tree = _Tree(
+            left=tree_node,
+            right=tree_node,
+            proposal=self.world,
+            pe=self._pe,
+            pe_grad=self._pe_grad,
+            log_weight=torch.tensor(0.0),  # log accept prob of staying at current state
+            sum_momentums=momentums,
+            sum_accept_prob=torch.tensor(0.0),
+            num_proposals=0,
+            turned_or_diverged=False,
         )
-        log_weight = torch.tensor(0.0)  # log accept prob of staying at current state
-        sum_accept_prob = 0.0
-        num_proposals = 0
-        sum_momentums = momentums
 
         for j in range(self._max_tree_depth):
             direction = 1 if torch.rand(()) > 0.5 else -1
             tree_args = _TreeArgs(log_slice, direction, self.step_size, current_energy)
             if direction == -1:
-                tree = self._build_tree(left_tree_node, j, tree_args)
-                left_tree_node = tree.left
+                new_tree = self._build_tree(tree.left, j, tree_args)
             else:
-                tree = self._build_tree(right_tree_node, j, tree_args)
-                right_tree_node = tree.right
-
-            sum_accept_prob += tree.sum_accept_prob
-            num_proposals += tree.num_proposals
+                new_tree = self._build_tree(tree.right, j, tree_args)
 
+            tree = self._combine_tree(tree, new_tree, direction, biased=True)
             if tree.turned_or_diverged:
                 break
 
-            # biased progressive sampling (Appendix 3.2 of [2])
-            log_tree_prob = tree.log_weight - log_weight
-
-            # choose the new world by randomly sampling from the proposed worlds
-            if torch.log1p(-torch.rand(())) <= log_tree_prob:
-                self.world, self._pe, self._pe_grad = (
-                    tree.proposal,
-                    tree.pe,
-                    tree.pe_grad,
-                )
-            sum_momentums = {
-                node: sum_momentums[node] + tree.sum_momentums[node]
-                for node in sum_momentums
-            }
-            if self._is_u_turning(
-                left_tree_node.momentums,
-                right_tree_node.momentums,
-                sum_momentums,
-            ):
-                break
-
-            log_weight = torch.logaddexp(log_weight, tree.log_weight)
-
-        self._alpha = sum_accept_prob / num_proposals
+        self.world, self._pe, self._pe_grad = tree.proposal, tree.pe, tree.pe_grad
+        self._alpha = tree.sum_accept_prob / tree.num_proposals
         return self.world
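After the refactor, the doubling loop in `propose` has the shape below. This is a toy stand-in, not the real API: "trees" are just `(state, log_weight)` pairs and the hypothetical `combine` plays the role of `_combine_tree`:

    import torch

    def combine(old, new, biased):
        """Toy _combine_tree: merge the weights and pick one of the two states."""
        (old_state, old_lw), (new_state, new_lw) = old, new
        log_weight = torch.logaddexp(old_lw, new_lw)
        log_prob_new = new_lw - (old_lw if biased else log_weight)
        state = new_state if torch.rand(()).log() < log_prob_new else old_state
        return state, log_weight

    torch.manual_seed(0)
    tree = ("start", torch.tensor(0.0))  # singleton tree, like the _Tree built above
    for j in range(4):
        # pretend _build_tree returned a new subtree with some random log weight
        new_tree = (f"state_{j}", torch.randn(()))
        tree = combine(tree, new_tree, biased=True)  # biased toward newer states
    print("proposed state:", tree[0])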

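`_is_u_turning` itself is not shown in this diff, but its call sites (the leftmost momentums, the rightmost momentums, and the summed momentums) match the generalized no-U-turn criterion used by NUTS implementations such as NumPyro's: stop when the trajectory's total momentum points away from either endpoint. A sketch under that assumption, with the per-node dict layout guessed from the call sites:

    import torch

    def is_u_turning(left_momentums, right_momentums, sum_momentums) -> bool:
        """Generalized U-turn check over dicts mapping node -> momentum tensor."""
        left_angle = sum(
            (sum_momentums[node] * p).sum() for node, p in left_momentums.items()
        )
        right_angle = sum(
            (sum_momentums[node] * p).sum() for node, p in right_momentums.items()
        )
        # u-turning if the accumulated momentum opposes the momentum at either end
        return bool(left_angle <= 0) or bool(right_angle <= 0)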