-
-
Notifications
You must be signed in to change notification settings - Fork 983
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Modify turning condition for nuts #1466
Changes from 5 commits
9e33250
a88e5ac
27f0535
59a6b84
0dcab88
ec9ea88
d098434
db009bf
5e71225
ff1ca4e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,10 +14,11 @@ | |
# sum_accept_probs and num_proposals are used to calculate | ||
# the statistic accept_prob for Dual Averaging scheme; | ||
# z_left_grads and z_right_grads are kept to avoid recalculating | ||
# grads at left and right leaves | ||
# grads at left and right leaves; | ||
# r_sum is used to check turning condition | ||
_TreeInfo = namedtuple("TreeInfo", ["z_left", "r_left", "z_left_grads", | ||
"z_right", "r_right", "z_right_grads", | ||
"z_proposal", "size", "turning", "diverging", | ||
"z_proposal", "r_sum", "size", "turning", "diverging", | ||
"sum_accept_probs", "num_proposals"]) | ||
|
||
|
||
|
@@ -115,19 +116,27 @@ def __init__(self, | |
# Here, as suggested in [1], we set dE_max = 1000. | ||
self._max_sliced_energy = 1000 | ||
|
||
def _is_turning(self, z_left, r_left, z_right, r_right): | ||
diff_left = 0 | ||
diff_right = 0 | ||
for name in self._r_shapes: | ||
dz = z_right[name] - z_left[name] | ||
diff_left += (dz * r_left[name]).sum() | ||
diff_right += (dz * r_right[name]).sum() | ||
return diff_left < 0 or diff_right < 0 | ||
# Set a flag to decide if we want to eliminate the initial point from the candidates to | ||
# choose uniformly along the trajectory. In [1], this flag is True, but in Stan, they set | ||
# it to False (implicitly). | ||
self._eliminate_starting_point = True | ||
|
||
def _is_turning(self, r_left, r_right, r_sum): | ||
# We follow the strategy in Section A.4.2 of [2] for this implementation. | ||
r_left_flat = torch.cat([r_left[site_name].reshape(-1) for site_name in sorted(r_left)]) | ||
r_right_flat = torch.cat([r_right[site_name].reshape(-1) for site_name in sorted(r_right)]) | ||
if self.full_mass: | ||
return (r_sum - r_left_flat).dot(self._inverse_mass_matrix.matmul(r_left_flat)) <= 0 \ | ||
or (r_sum - r_right_flat).dot(self._inverse_mass_matrix.matmul(r_right_flat)) <= 0 | ||
else: | ||
return self._inverse_mass_matrix.dot((r_sum - r_left_flat) * r_left_flat) <= 0 \ | ||
or self._inverse_mass_matrix.dot((r_sum - r_right_flat) * r_right_flat) <= 0 | ||
|
||
def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current): | ||
step_size = self.step_size if direction == 1 else -self.step_size | ||
z_new, r_new, z_grads, potential_energy = single_step_velocity_verlet( | ||
z, r, self._potential_energy, self._inverse_mass_matrix, step_size, z_grads=z_grads) | ||
r_new_flat = torch.cat([r_new[site_name].reshape(-1) for site_name in sorted(r_new)]) | ||
energy_new = potential_energy + self._kinetic_energy(r_new) | ||
sliced_energy = energy_new + log_slice | ||
|
||
|
@@ -148,7 +157,7 @@ def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current): | |
delta_energy = energy_new - energy_current | ||
accept_prob = (-delta_energy).exp().clamp(max=1.0) | ||
return _TreeInfo(z_new, r_new, z_grads, z_new, r_new, z_grads, | ||
z_new, tree_size, False, diverging, accept_prob, 1) | ||
z_new, r_new_flat, tree_size, False, diverging, accept_prob, 1) | ||
|
||
def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_current): | ||
if tree_depth == 0: | ||
|
@@ -180,6 +189,7 @@ def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_cu | |
tree_size = half_tree.size + other_half_tree.size | ||
sum_accept_probs = half_tree.sum_accept_probs + other_half_tree.sum_accept_probs | ||
num_proposals = half_tree.num_proposals + other_half_tree.num_proposals | ||
r_sum = half_tree.r_sum + other_half_tree.r_sum | ||
|
||
# Under the slice sampling process, a proposal for z is uniformly picked. | ||
# The probability of that proposal belongs to which half of tree | ||
|
@@ -212,20 +222,20 @@ def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_cu | |
|
||
# We already check if first half tree is turning. Now, we check | ||
# if the other half tree or full tree are turning. | ||
turning = other_half_tree.turning or self._is_turning(z_left, r_left, z_right, r_right) | ||
turning = other_half_tree.turning or self._is_turning(r_left, r_right, r_sum) | ||
|
||
# The divergence is checked by the second half tree (the first half is already checked). | ||
diverging = other_half_tree.diverging | ||
|
||
return _TreeInfo(z_left, r_left, z_left_grads, z_right, r_right, z_right_grads, z_proposal, | ||
tree_size, turning, diverging, sum_accept_probs, num_proposals) | ||
r_sum, tree_size, turning, diverging, sum_accept_probs, num_proposals) | ||
|
||
def sample(self, trace): | ||
z = {name: node["value"].detach() for name, node in self._iter_latent_nodes(trace)} | ||
# automatically transform `z` to unconstrained space, if needed. | ||
for name, transform in self.transforms.items(): | ||
z[name] = transform(z[name]) | ||
r = self._sample_r(name="r_t={}".format(self._t)) | ||
r, r_flat = self._sample_r(name="r_t={}".format(self._t)) | ||
energy_current = self._energy(z, r) | ||
|
||
# Ideally, following a symplectic integrator trajectory, the energy is constant. | ||
|
@@ -240,8 +250,8 @@ def sample(self, trace): | |
# (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}). | ||
# | ||
# For more information about slice sampling method, see [3]. | ||
# For another version of NUTS which uses multinomial sampling instead of slice sampling, see | ||
# [2]. | ||
# For another version of NUTS which uses multinomial sampling instead of slice sampling, | ||
# see [2]. | ||
|
||
# Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can | ||
# sample log_slice directly using `energy`, so as to avoid potential underflow or | ||
|
@@ -253,8 +263,9 @@ def sample(self, trace): | |
z_left = z_right = z | ||
r_left = r_right = r | ||
z_left_grads = z_right_grads = None | ||
tree_size = 1 | ||
tree_size = 0 if self._eliminate_starting_point else 1 | ||
accepted = False | ||
r_sum = r_flat | ||
|
||
# Temporarily disable distributions args checking as | ||
# NaNs are expected during step size adaptation. | ||
|
@@ -283,11 +294,13 @@ def sample(self, trace): | |
|
||
rand = pyro.sample("rand_t={}_treedepth={}".format(self._t, tree_depth), | ||
dist.Uniform(torch.zeros(1), torch.ones(1))) | ||
if rand < new_tree.size / tree_size: | ||
if ((tree_size > 0) and (rand < new_tree.size / tree_size)) \ | ||
or ((tree_size == 0) and (new_tree.size > 0)): | ||
accepted = True | ||
z = new_tree.z_proposal | ||
|
||
if self._is_turning(z_left, r_left, z_right, r_right): # stop doubling | ||
r_sum += new_tree.r_sum | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What a bug! It took me a lot of time to detect it. We should never use shortcut (`+=`) assignment on a tensor. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, in-place tensor ops can give nondeterministic results in many cases. Just curious: what was the cause of the failure here? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @neerajprad When tree_depth = 0, we create a new_tree with 1 element. Hence r_sum = r_left + r_right, but the shortcut assignment gives r_sum = r_left. Because of this, |
||
if self._is_turning(r_left, r_right, r_sum): # stop doubling | ||
break | ||
else: # update tree_size | ||
tree_size += new_tree.size | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would also suggest including this reference, which derives the termination criterion explicitly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In [2], the author also derives the criterion. I find it easier to understand than the differential-geometry style (though it is a nice language).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, it's the same derivation. I was looking at Section 4.2 rather than the appendix.