diff --git a/src/beanmachine/ppl/experimental/global_inference/proposer/hmc_proposer.py b/src/beanmachine/ppl/experimental/global_inference/proposer/hmc_proposer.py
index 7151e911e6..9acab36199 100644
--- a/src/beanmachine/ppl/experimental/global_inference/proposer/hmc_proposer.py
+++ b/src/beanmachine/ppl/experimental/global_inference/proposer/hmc_proposer.py
@@ -58,16 +58,15 @@ def _initialize_momentums(self, world: SimpleWorld) -> RVDict:
         """Randomly draw momentum from MultivariateNormal(0, I). This momentum variable
         is denoted as p in [1] and r in [2]."""
         return {
-            node: torch.randn(world.get_transformed(node).shape)
+            # sample (flatten) momentums
+            node: torch.randn((world.get_transformed(node).numel(),))
             for node in world.latent_nodes
         }

     def _kinetic_energy(self, momentums: RVDict) -> torch.Tensor:
         """Returns the kinetic energy KE = 1/2 * p^T @ p (equation 2.6 in [1])"""
-        energy = torch.tensor(0.0)
-        for r in momentums.values():
-            energy += torch.sum(r ** 2)
-        return energy / 2
+        r_all = torch.cat(list(momentums.values()))
+        return torch.dot(r_all, r_all) / 2

     def _kinetic_grads(self, momentums: RVDict) -> RVDict:
         """Returns a dictionary of gradients of kinetic energy function with respect to
@@ -123,20 +122,21 @@ def _leapfrog_step(

         new_momentums = {}
         for node, r in momentums.items():
-            new_momentums[node] = r - step_size * pe_grad[node] / 2
+            new_momentums[node] = r - step_size * pe_grad[node].flatten() / 2
         ke_grad = self._kinetic_grads(new_momentums)

         new_world = world.copy()
         for node in world.latent_nodes:
             # this should override the value of all the latent nodes in new_world
             # but does not change observations and transforms
+            z = world.get_transformed(node)
             new_world.set_transformed(
-                node, world.get_transformed(node) + step_size * ke_grad[node]
+                node, z + step_size * ke_grad[node].reshape(z.shape)
             )

         pe, pe_grad = self._potential_grads(new_world)
         for node, r in new_momentums.items():
-            new_momentums[node] = r - step_size * pe_grad[node] / 2
+            new_momentums[node] = r - step_size * pe_grad[node].flatten() / 2

         return new_world, new_momentums, pe, pe_grad
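Note on the HMC changes above: momentums are now sampled as flat vectors (one per latent node), so the kinetic energy reduces to a single dot product over the concatenated momentum. A minimal standalone sketch of that equivalence (plain torch, not beanmachine code; the node names are made up):

    import torch

    # flattened momentums for two hypothetical latent nodes
    momentums = {
        "mu": torch.randn(3),      # flattened from a (3,) value
        "sigma": torch.randn(4),   # flattened from a (2, 2) value
    }

    # kinetic energy via one dot product over the concatenated vector ...
    r_all = torch.cat(list(momentums.values()))
    ke_dot = torch.dot(r_all, r_all) / 2

    # ... matches the per-node sum the old implementation computed
    ke_sum = sum(torch.sum(r ** 2) for r in momentums.values()) / 2
    assert torch.isclose(ke_dot, ke_sum)
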
diff --git a/src/beanmachine/ppl/experimental/global_inference/proposer/nuts_proposer.py b/src/beanmachine/ppl/experimental/global_inference/proposer/nuts_proposer.py
index 0b5611b9e2..e4f27f72e0 100644
--- a/src/beanmachine/ppl/experimental/global_inference/proposer/nuts_proposer.py
+++ b/src/beanmachine/ppl/experimental/global_inference/proposer/nuts_proposer.py
@@ -23,6 +23,7 @@ class _Tree(NamedTuple):
     pe: torch.Tensor
     pe_grad: RVDict
     log_weight: torch.Tensor
+    sum_momentums: RVDict
     sum_accept_prob: torch.Tensor
     num_proposals: int
     turned_or_diverged: bool
@@ -72,16 +73,18 @@ def __init__(
         self._max_delta_energy = max_delta_energy
         self._multinomial_sampling = multinomial_sampling

-    def _is_u_turning(self, left_state: _TreeNode, right_state: _TreeNode) -> bool:
-        left_angle = 0.0
-        right_angle = 0.0
-        for node in left_state.world.latent_nodes:
-            diff = right_state.world.get_transformed(
-                node
-            ) - left_state.world.get_transformed(node)
-            left_angle += torch.sum(diff * left_state.momentums[node])
-            right_angle += torch.sum(diff * right_state.momentums[node])
-        return bool((left_angle <= 0) or (right_angle <= 0))
+    def _is_u_turning(
+        self,
+        left_momentums: RVDict,
+        right_momentums: RVDict,
+        sum_momentums: RVDict,
+    ) -> bool:
+        """The generalized U-turn condition, as described in [2] Appendix 4.2"""
+        left_r = torch.cat(list(left_momentums.values()))
+        right_r = torch.cat(list(right_momentums.values()))
+        rho = torch.cat(list(sum_momentums.values()))
+
+        return bool((torch.dot(left_r, rho) <= 0) or (torch.dot(right_r, rho) <= 0))

     def _build_tree_base_case(self, root: _TreeNode, args: _TreeArgs) -> _Tree:
         """Base case of the recursive tree building algorithm: take a single leapfrog
@@ -106,6 +109,7 @@ def _build_tree_base_case(self, root: _TreeNode, args: _TreeArgs) -> _Tree:
             pe=pe,
             pe_grad=pe_grad,
             log_weight=log_weight,
+            sum_momentums=momentums,
             sum_accept_prob=torch.clamp(torch.exp(-delta_energy), max=1.0),
             num_proposals=1,
             turned_or_diverged=bool(
@@ -144,6 +148,10 @@ def _build_tree(self, root: _TreeNode, tree_depth: int, args: _TreeArgs) -> _Tre

         left_state = other_sub_tree.left if args.direction == -1 else sub_tree.left
         right_state = sub_tree.right if args.direction == -1 else other_sub_tree.right
+        sum_momentums = {
+            node: sub_tree.sum_momentums[node] + other_sub_tree.sum_momentums[node]
+            for node in sub_tree.sum_momentums
+        }
         return _Tree(
             left=left_state,
             right=right_state,
@@ -151,10 +159,15 @@ def _build_tree(self, root: _TreeNode, tree_depth: int, args: _TreeArgs) -> _Tre
             pe=selected_subtree.pe,
             pe_grad=selected_subtree.pe_grad,
             log_weight=log_weight,
+            sum_momentums=sum_momentums,
             sum_accept_prob=sub_tree.sum_accept_prob + other_sub_tree.sum_accept_prob,
             num_proposals=sub_tree.num_proposals + other_sub_tree.num_proposals,
             turned_or_diverged=other_sub_tree.turned_or_diverged
-            or self._is_u_turning(left_state, right_state),
+            or self._is_u_turning(
+                left_state.momentums,
+                right_state.momentums,
+                sum_momentums,
+            ),
         )

     def propose(self, world: Optional[SimpleWorld] = None) -> SimpleWorld:
@@ -177,6 +190,7 @@ def propose(self, world: Optional[SimpleWorld] = None) -> SimpleWorld:
         log_weight = torch.tensor(0.0)  # log accept prob of staying at current state
         sum_accept_prob = 0.0
         num_proposals = 0
+        sum_momentums = momentums

         for j in range(self._max_tree_depth):
             direction = 1 if torch.rand(()) > 0.5 else -1
@@ -204,8 +218,15 @@ def propose(self, world: Optional[SimpleWorld] = None) -> SimpleWorld:
                 tree.pe,
                 tree.pe_grad,
             )
-
-            if self._is_u_turning(left_tree_node, right_tree_node):
+            sum_momentums = {
+                node: sum_momentums[node] + tree.sum_momentums[node]
+                for node in sum_momentums
+            }
+            if self._is_u_turning(
+                left_tree_node.momentums,
+                right_tree_node.momentums,
+                sum_momentums,
+            ):
                 break

             log_weight = torch.logaddexp(log_weight, tree.log_weight)
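Note on the NUTS changes above: the per-node angle check is replaced by the generalized U-turn condition, which compares each endpoint's momentum against the running sum of momentums (rho) accumulated over the subtree, all on flattened vectors. A standalone sketch of the check (plain torch; the helper and node names are illustrative, not beanmachine's API):

    import torch

    def is_u_turning(left_momentums, right_momentums, sum_momentums):
        # the trajectory has turned if either endpoint momentum points away from rho
        left_r = torch.cat(list(left_momentums.values()))
        right_r = torch.cat(list(right_momentums.values()))
        rho = torch.cat(list(sum_momentums.values()))
        return bool((torch.dot(left_r, rho) <= 0) or (torch.dot(right_r, rho) <= 0))

    left = {"x": torch.tensor([1.0, 0.0])}
    right = {"x": torch.tensor([0.5, 0.5])}
    rho = {"x": left["x"] + right["x"]}    # running sum over the subtree
    print(is_u_turning(left, right, rho))  # False: both endpoints still follow rho
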
diff --git a/src/beanmachine/ppl/experimental/global_inference/proposer/tests/hmc_proposer_test.py b/src/beanmachine/ppl/experimental/global_inference/proposer/tests/hmc_proposer_test.py
index ca4b9ea8dd..2b29fccf87 100644
--- a/src/beanmachine/ppl/experimental/global_inference/proposer/tests/hmc_proposer_test.py
+++ b/src/beanmachine/ppl/experimental/global_inference/proposer/tests/hmc_proposer_test.py
@@ -38,7 +38,7 @@ def test_potential_grads(world, hmc):
     for node in world.latent_nodes:
         assert node in pe_grad
         assert isinstance(pe_grad[node], torch.Tensor)
-        assert pe_grad[node].shape == world[node].shape
+        assert pe_grad[node].shape == world.get_transformed(node).shape


 def test_initialize_momentums(world, hmc):
@@ -46,7 +46,7 @@ def test_initialize_momentums(world, hmc):
     for node in world.latent_nodes:
         assert node in momentums
         assert isinstance(momentums[node], torch.Tensor)
-        assert momentums[node].shape == world[node].shape
+        assert len(momentums[node]) == world.get_transformed(node).numel()


 def test_kinetic_grads(world, hmc):
@@ -58,7 +58,7 @@ def test_kinetic_grads(world, hmc):
     for node in world.latent_nodes:
         assert node in ke_grad
         assert isinstance(ke_grad[node], torch.Tensor)
-        assert ke_grad[node].shape == world[node].shape
+        assert len(ke_grad[node]) == world.get_transformed(node).numel()


 def test_leapfrog_step(world, hmc):
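Note on the test changes above: potential gradients keep the shape of the transformed value, while momentums and kinetic gradients are flat vectors of length numel(); the leapfrog step reshapes the kinetic gradient back before applying it. A small sketch of that flatten/reshape round trip (plain torch, independent of the fixtures above):

    import torch

    z = torch.randn(2, 3)        # a transformed latent value
    r = torch.randn(z.numel())   # flattened momentum, shape (6,)
    step_size = 0.1

    # position update in the style of the leapfrog step: reshape back to z's shape
    z_new = z + step_size * r.reshape(z.shape)
    assert z_new.shape == z.shape
    assert len(r) == z.numel()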