-
-
Notifications
You must be signed in to change notification settings - Fork 983
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Modify turning condition for nuts #1466
Changes from 7 commits
9e33250
a88e5ac
27f0535
59a6b84
0dcab88
ec9ea88
d098434
db009bf
5e71225
ff1ca4e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,8 +5,8 @@ all: docs test | |
install: FORCE | ||
pip install -e .[dev,profile] | ||
|
||
uninstall: FORCE | ||
pip uninstall pyro-ppl | ||
reinstall: FORCE | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When do you need this? If you do an editable install via -e, it should pick up any local changes. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. whoa, I have uninstalled and installed again and again to test examples, so I made this command. Thx @neerajprad! |
||
pip uninstall -y pyro-ppl && pip install . --no-deps | ||
|
||
docs: FORCE | ||
$(MAKE) -C docs html | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,10 +14,11 @@ | |
# sum_accept_probs and num_proposals are used to calculate | ||
# the statistic accept_prob for Dual Averaging scheme; | ||
# z_left_grads and z_right_grads are kept to avoid recalculating | ||
# grads at left and right leaves | ||
# grads at left and right leaves; | ||
# r_sum is used to check turning condition | ||
_TreeInfo = namedtuple("TreeInfo", ["z_left", "r_left", "z_left_grads", | ||
"z_right", "r_right", "z_right_grads", | ||
"z_proposal", "size", "turning", "diverging", | ||
"z_proposal", "r_sum", "size", "turning", "diverging", | ||
"sum_accept_probs", "num_proposals"]) | ||
|
||
|
||
|
@@ -115,19 +116,27 @@ def __init__(self, | |
# Here, as suggested in [1], we set dE_max = 1000. | ||
self._max_sliced_energy = 1000 | ||
|
||
def _is_turning(self, z_left, r_left, z_right, r_right): | ||
diff_left = 0 | ||
diff_right = 0 | ||
for name in self._r_shapes: | ||
dz = z_right[name] - z_left[name] | ||
diff_left += (dz * r_left[name]).sum() | ||
diff_right += (dz * r_right[name]).sum() | ||
return diff_left < 0 or diff_right < 0 | ||
# Set a flag to decide if we want to eliminate the initial point from the candidates to | ||
# choose uniformly along the trajectory. In [1], this flag is True, but in Stan, they set | ||
# it to False (implicitly). | ||
self._eliminate_starting_point = True | ||
|
||
def _is_turning(self, r_left, r_right, r_sum): | ||
# We follow the strategy in Section A.4.2 of [2] for this implementation. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would also suggest including this reference which derives the termination criterion explicitly. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In [2], the author also derives the criterion. I find it easier to understand than the differential geometry style (though it is a nice language). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oh... it's the same derivation. I was looking at 4.2 rather than the appendix. |
||
r_left_flat = torch.cat([r_left[site_name].reshape(-1) for site_name in sorted(r_left)]) | ||
r_right_flat = torch.cat([r_right[site_name].reshape(-1) for site_name in sorted(r_right)]) | ||
if self.full_mass: | ||
return (r_sum - r_left_flat).dot(self._inverse_mass_matrix.matmul(r_left_flat)) <= 0 \ | ||
or (r_sum - r_right_flat).dot(self._inverse_mass_matrix.matmul(r_right_flat)) <= 0 | ||
else: | ||
return self._inverse_mass_matrix.dot((r_sum - r_left_flat) * r_left_flat) <= 0 \ | ||
or self._inverse_mass_matrix.dot((r_sum - r_right_flat) * r_right_flat) <= 0 | ||
|
||
def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current): | ||
step_size = self.step_size if direction == 1 else -self.step_size | ||
z_new, r_new, z_grads, potential_energy = single_step_velocity_verlet( | ||
z, r, self._potential_energy, self._inverse_mass_matrix, step_size, z_grads=z_grads) | ||
r_new_flat = torch.cat([r_new[site_name].reshape(-1) for site_name in sorted(r_new)]) | ||
energy_new = potential_energy + self._kinetic_energy(r_new) | ||
sliced_energy = energy_new + log_slice | ||
|
||
|
@@ -148,7 +157,7 @@ def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current): | |
delta_energy = energy_new - energy_current | ||
accept_prob = (-delta_energy).exp().clamp(max=1.0) | ||
return _TreeInfo(z_new, r_new, z_grads, z_new, r_new, z_grads, | ||
z_new, tree_size, False, diverging, accept_prob, 1) | ||
z_new, r_new_flat, tree_size, False, diverging, accept_prob, 1) | ||
|
||
def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_current): | ||
if tree_depth == 0: | ||
|
@@ -180,6 +189,7 @@ def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_cu | |
tree_size = half_tree.size + other_half_tree.size | ||
sum_accept_probs = half_tree.sum_accept_probs + other_half_tree.sum_accept_probs | ||
num_proposals = half_tree.num_proposals + other_half_tree.num_proposals | ||
r_sum = half_tree.r_sum + other_half_tree.r_sum | ||
|
||
# Under the slice sampling process, a proposal for z is uniformly picked. | ||
# The probability of that proposal belongs to which half of tree | ||
|
@@ -212,20 +222,20 @@ def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_cu | |
|
||
# We already check if first half tree is turning. Now, we check | ||
# if the other half tree or full tree are turning. | ||
turning = other_half_tree.turning or self._is_turning(z_left, r_left, z_right, r_right) | ||
turning = other_half_tree.turning or self._is_turning(r_left, r_right, r_sum) | ||
|
||
# The divergence is checked by the second half tree (the first half is already checked). | ||
diverging = other_half_tree.diverging | ||
|
||
return _TreeInfo(z_left, r_left, z_left_grads, z_right, r_right, z_right_grads, z_proposal, | ||
tree_size, turning, diverging, sum_accept_probs, num_proposals) | ||
r_sum, tree_size, turning, diverging, sum_accept_probs, num_proposals) | ||
|
||
def sample(self, trace): | ||
z = {name: node["value"].detach() for name, node in self._iter_latent_nodes(trace)} | ||
# automatically transform `z` to unconstrained space, if needed. | ||
for name, transform in self.transforms.items(): | ||
z[name] = transform(z[name]) | ||
r = self._sample_r(name="r_t={}".format(self._t)) | ||
r, r_flat = self._sample_r(name="r_t={}".format(self._t)) | ||
energy_current = self._energy(z, r) | ||
|
||
# Ideally, following a symplectic integrator trajectory, the energy is constant. | ||
|
@@ -240,8 +250,8 @@ def sample(self, trace): | |
# (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}). | ||
# | ||
# For more information about slice sampling method, see [3]. | ||
# For another version of NUTS which uses multinomial sampling instead of slice sampling, see | ||
# [2]. | ||
# For another version of NUTS which uses multinomial sampling instead of slice sampling, | ||
# see [2]. | ||
|
||
# Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can | ||
# sample log_slice directly using `energy`, so as to avoid potential underflow or | ||
|
@@ -253,8 +263,9 @@ def sample(self, trace): | |
z_left = z_right = z | ||
r_left = r_right = r | ||
z_left_grads = z_right_grads = None | ||
tree_size = 1 | ||
tree_size = 0 if self._eliminate_starting_point else 1 | ||
accepted = False | ||
r_sum = r_flat | ||
|
||
# Temporarily disable distributions args checking as | ||
# NaNs are expected during step size adaptation. | ||
|
@@ -283,11 +294,13 @@ def sample(self, trace): | |
|
||
rand = pyro.sample("rand_t={}_treedepth={}".format(self._t, tree_depth), | ||
dist.Uniform(torch.zeros(1), torch.ones(1))) | ||
if rand < new_tree.size / tree_size: | ||
if ((tree_size > 0) and (rand < new_tree.size / tree_size)) \ | ||
or ((tree_size == 0) and (new_tree.size > 0)): | ||
accepted = True | ||
z = new_tree.z_proposal | ||
|
||
if self._is_turning(z_left, r_left, z_right, r_right): # stop doubling | ||
r_sum = r_sum + new_tree.r_sum | ||
if self._is_turning(r_left, r_right, r_sum): # stop doubling | ||
break | ||
else: # update tree_size | ||
tree_size += new_tree.size | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,19 +14,57 @@ | |
import pyro.poutine as poutine | ||
from tests.common import assert_equal | ||
|
||
from .test_hmc import TEST_CASES, TEST_IDS, T, rmse | ||
from .test_hmc import GaussianChain, T, rmse | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
T2 = T(*TEST_CASES[2].values)._replace(num_samples=800, warmup_steps=200) | ||
TEST_CASES[2] = pytest.param(*T2, marks=pytest.mark.skipif( | ||
'CI' in os.environ or 'CUDA_TEST' in os.environ, | ||
reason='Slow test - skip on CI/CUDA')) | ||
T3 = T(*TEST_CASES[3].values)._replace(num_samples=1000, warmup_steps=200) | ||
TEST_CASES[3] = pytest.param(*T3, marks=[ | ||
pytest.mark.skipif('CI' in os.environ or 'CUDA_TEST' in os.environ, | ||
reason='Slow test - skip on CI/CUDA')] | ||
) | ||
TEST_CASES = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. +1, this has been bugging me too for a while. 😄 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. yeah, me too ^^ |
||
T( | ||
GaussianChain(dim=10, chain_len=3, num_obs=1), | ||
num_samples=800, | ||
warmup_steps=200, | ||
hmc_params=None, | ||
expected_means=[0.25, 0.50, 0.75], | ||
expected_precs=[1.33, 1, 1.33], | ||
mean_tol=0.06, | ||
std_tol=0.06, | ||
), | ||
T( | ||
GaussianChain(dim=10, chain_len=4, num_obs=1), | ||
num_samples=800, | ||
warmup_steps=200, | ||
hmc_params=None, | ||
expected_means=[0.20, 0.40, 0.60, 0.80], | ||
expected_precs=[1.25, 0.83, 0.83, 1.25], | ||
mean_tol=0.06, | ||
std_tol=0.06, | ||
), | ||
pytest.param(*T( | ||
GaussianChain(dim=5, chain_len=2, num_obs=10000), | ||
num_samples=800, | ||
warmup_steps=200, | ||
hmc_params=None, | ||
expected_means=[0.5, 1.0], | ||
expected_precs=[2.0, 10000], | ||
mean_tol=0.04, | ||
std_tol=0.04, | ||
), marks=[pytest.mark.skipif('CI' in os.environ or 'CUDA_TEST' in os.environ, | ||
reason='Slow test - skip on CI/CUDA')]), | ||
pytest.param(*T( | ||
GaussianChain(dim=5, chain_len=9, num_obs=1), | ||
num_samples=1200, | ||
warmup_steps=200, | ||
hmc_params=None, | ||
expected_means=[0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90], | ||
expected_precs=[1.11, 0.63, 0.48, 0.42, 0.4, 0.42, 0.48, 0.63, 1.11], | ||
mean_tol=0.07, | ||
std_tol=0.07, | ||
), marks=[pytest.mark.skipif('CI' in os.environ or 'CUDA_TEST' in os.environ, | ||
reason='Slow test - skip on CI/CUDA')]) | ||
] | ||
|
||
TEST_IDS = [t[0].id_fn() if type(t).__name__ == 'TestExample' | ||
else t[0][0].id_fn() for t in TEST_CASES] | ||
|
||
|
||
@pytest.mark.parametrize( | ||
|
@@ -44,7 +82,7 @@ def test_nuts_conjugate_gaussian(fixture, | |
mean_tol, | ||
std_tol): | ||
pyro.get_param_store().clear() | ||
nuts_kernel = NUTS(fixture.model, hmc_params['step_size']) | ||
nuts_kernel = NUTS(fixture.model) | ||
mcmc_run = MCMC(nuts_kernel, num_samples, warmup_steps).run(fixture.data) | ||
for i in range(1, fixture.chain_len + 1): | ||
param_name = 'loc_' + str(i) | ||
|
@@ -90,8 +128,8 @@ def model(data): | |
@pytest.mark.parametrize( | ||
"step_size, adapt_step_size, adapt_mass_matrix, full_mass", | ||
[ | ||
(0.02, False, False, False), | ||
(0.02, False, True, False), | ||
(0.1, False, False, False), | ||
(0.1, False, True, False), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the old step_size works too but it is slow |
||
(None, True, False, False), | ||
(None, True, True, False), | ||
(None, True, True, True), | ||
|
@@ -154,7 +192,7 @@ def model(data): | |
true_beta = torch.tensor(1.) | ||
data = dist.Beta(concentration1=true_alpha, concentration0=true_beta).sample(torch.Size((5000,))) | ||
nuts_kernel = NUTS(model) | ||
mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=200).run(data) | ||
mcmc_run = MCMC(nuts_kernel, num_samples=600, warmup_steps=200).run(data) | ||
posterior = EmpiricalMarginal(mcmc_run, sites=['alpha', 'beta']) | ||
assert_equal(posterior.mean, torch.stack([true_alpha, true_beta]), prec=0.05) | ||
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
let us put the uninstall option back in.