Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify turning condition for nuts #1466

Merged
merged 10 commits into from
Oct 19, 2018
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ all: docs test
install: FORCE
pip install -e .[dev,profile]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let us put the uninstall option back in.

uninstall: FORCE
pip uninstall pyro-ppl
reinstall: FORCE
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When do you need this? If you do an editable install via -e, it should pick up any local changes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoa, I have uninstalled and installed again and again to test examples, so I made this command. Thx @neerajprad!

pip uninstall -y pyro-ppl && pip install . --no-deps

docs: FORCE
$(MAKE) -C docs html
Expand Down
6 changes: 3 additions & 3 deletions pyro/infer/mcmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _find_reasonable_step_size(self, z):
# We are going to find a step_size which make accept_prob (Metropolis correction)
# near the target_accept_prob. If accept_prob:=exp(-delta_energy) is small,
# then we have to decrease step_size; otherwise, increase step_size.
r = self._sample_r(name="r_presample")
r, _ = self._sample_r(name="r_presample")
energy_current = self._energy(z, r)
z_new, r_new, z_grads, potential_energy = single_step_velocity_verlet(
z, r, self._potential_energy, self._inverse_mass_matrix, step_size)
Expand Down Expand Up @@ -322,7 +322,7 @@ def _sample_r(self, name):
r[name] = r_flat[pos:next_pos].reshape(self._r_shapes[name])
pos = next_pos
assert pos == r_flat.size(0)
return r
return r, r_flat

def _validate_trace(self, trace):
trace_eval = TraceEinsumEvaluator if self.use_einsum else TraceTreeEvaluator
Expand Down Expand Up @@ -378,7 +378,7 @@ def sample(self, trace):
for name, transform in self.transforms.items():
z[name] = transform(z[name])

r = self._sample_r(name="r_t={}".format(self._t))
r, _ = self._sample_r(name="r_t={}".format(self._t))

# Temporarily disable distributions args checking as
# NaNs are expected during step size adaptation
Expand Down
51 changes: 32 additions & 19 deletions pyro/infer/mcmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
# sum_accept_probs and num_proposals are used to calculate
# the statistic accept_prob for Dual Averaging scheme;
# z_left_grads and z_right_grads are kept to avoid recalculating
# grads at left and right leaves
# grads at left and right leaves;
# r_sum is used to check turning condition
_TreeInfo = namedtuple("TreeInfo", ["z_left", "r_left", "z_left_grads",
"z_right", "r_right", "z_right_grads",
"z_proposal", "size", "turning", "diverging",
"z_proposal", "r_sum", "size", "turning", "diverging",
"sum_accept_probs", "num_proposals"])


Expand Down Expand Up @@ -115,19 +116,27 @@ def __init__(self,
# Here, as suggested in [1], we set dE_max = 1000.
self._max_sliced_energy = 1000

def _is_turning(self, z_left, r_left, z_right, r_right):
diff_left = 0
diff_right = 0
for name in self._r_shapes:
dz = z_right[name] - z_left[name]
diff_left += (dz * r_left[name]).sum()
diff_right += (dz * r_right[name]).sum()
return diff_left < 0 or diff_right < 0
# Set a flag to decide if we want to eliminate the initial point from the candidates to
# choose uniformly along the trajectory. In [1], this flag is True, but in Stan, they set
# it to False (implicitly).
self._eliminate_starting_point = True

def _is_turning(self, r_left, r_right, r_sum):
# We follow the strategy in Section A.4.2 of [2] for this implementation.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would also suggest including this reference which derives the termination criterion explicitly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In [2], the author also derives the criterion too. I find it easier to understand than differential geometry style (though it is a nice language).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh..its the same derivation. I was looking at 4.2 rather than the appendix.

r_left_flat = torch.cat([r_left[site_name].reshape(-1) for site_name in sorted(r_left)])
r_right_flat = torch.cat([r_right[site_name].reshape(-1) for site_name in sorted(r_right)])
if self.full_mass:
return (r_sum - r_left_flat).dot(self._inverse_mass_matrix.matmul(r_left_flat)) <= 0 \
or (r_sum - r_right_flat).dot(self._inverse_mass_matrix.matmul(r_right_flat)) <= 0
else:
return self._inverse_mass_matrix.dot((r_sum - r_left_flat) * r_left_flat) <= 0 \
or self._inverse_mass_matrix.dot((r_sum - r_right_flat) * r_right_flat) <= 0

def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current):
step_size = self.step_size if direction == 1 else -self.step_size
z_new, r_new, z_grads, potential_energy = single_step_velocity_verlet(
z, r, self._potential_energy, self._inverse_mass_matrix, step_size, z_grads=z_grads)
r_new_flat = torch.cat([r_new[site_name].reshape(-1) for site_name in sorted(r_new)])
energy_new = potential_energy + self._kinetic_energy(r_new)
sliced_energy = energy_new + log_slice

Expand All @@ -148,7 +157,7 @@ def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current):
delta_energy = energy_new - energy_current
accept_prob = (-delta_energy).exp().clamp(max=1.0)
return _TreeInfo(z_new, r_new, z_grads, z_new, r_new, z_grads,
z_new, tree_size, False, diverging, accept_prob, 1)
z_new, r_new_flat, tree_size, False, diverging, accept_prob, 1)

def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_current):
if tree_depth == 0:
Expand Down Expand Up @@ -180,6 +189,7 @@ def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_cu
tree_size = half_tree.size + other_half_tree.size
sum_accept_probs = half_tree.sum_accept_probs + other_half_tree.sum_accept_probs
num_proposals = half_tree.num_proposals + other_half_tree.num_proposals
r_sum = half_tree.r_sum + other_half_tree.r_sum

# Under the slice sampling process, a proposal for z is uniformly picked.
# The probability of that proposal belongs to which half of tree
Expand Down Expand Up @@ -212,20 +222,20 @@ def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_cu

# We already check if first half tree is turning. Now, we check
# if the other half tree or full tree are turning.
turning = other_half_tree.turning or self._is_turning(z_left, r_left, z_right, r_right)
turning = other_half_tree.turning or self._is_turning(r_left, r_right, r_sum)

# The divergence is checked by the second half tree (the first half is already checked).
diverging = other_half_tree.diverging

return _TreeInfo(z_left, r_left, z_left_grads, z_right, r_right, z_right_grads, z_proposal,
tree_size, turning, diverging, sum_accept_probs, num_proposals)
r_sum, tree_size, turning, diverging, sum_accept_probs, num_proposals)

def sample(self, trace):
z = {name: node["value"].detach() for name, node in self._iter_latent_nodes(trace)}
# automatically transform `z` to unconstrained space, if needed.
for name, transform in self.transforms.items():
z[name] = transform(z[name])
r = self._sample_r(name="r_t={}".format(self._t))
r, r_flat = self._sample_r(name="r_t={}".format(self._t))
energy_current = self._energy(z, r)

# Ideally, following a symplectic integrator trajectory, the energy is constant.
Expand All @@ -240,8 +250,8 @@ def sample(self, trace):
# (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
#
# For more information about slice sampling method, see [3].
# For another version of NUTS which uses multinomial sampling instead of slice sampling, see
# [2].
# For another version of NUTS which uses multinomial sampling instead of slice sampling,
# see [2].

# Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
# sample log_slice directly using `energy`, so as to avoid potential underflow or
Expand All @@ -253,8 +263,9 @@ def sample(self, trace):
z_left = z_right = z
r_left = r_right = r
z_left_grads = z_right_grads = None
tree_size = 1
tree_size = 0 if self._eliminate_starting_point else 1
accepted = False
r_sum = r_flat

# Temporarily disable distributions args checking as
# NaNs are expected during step size adaptation.
Expand Down Expand Up @@ -283,11 +294,13 @@ def sample(self, trace):

rand = pyro.sample("rand_t={}_treedepth={}".format(self._t, tree_depth),
dist.Uniform(torch.zeros(1), torch.ones(1)))
if rand < new_tree.size / tree_size:
if ((tree_size > 0) and (rand < new_tree.size / tree_size)) \
or ((tree_size == 0) and (new_tree.size > 0)):
accepted = True
z = new_tree.z_proposal

if self._is_turning(z_left, r_left, z_right, r_right): # stop doubling
r_sum = r_sum + new_tree.r_sum
if self._is_turning(r_left, r_right, r_sum): # stop doubling
break
else: # update tree_size
tree_size += new_tree.size
Expand Down
66 changes: 52 additions & 14 deletions tests/infer/mcmc/test_nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,57 @@
import pyro.poutine as poutine
from tests.common import assert_equal

from .test_hmc import TEST_CASES, TEST_IDS, T, rmse
from .test_hmc import GaussianChain, T, rmse

logger = logging.getLogger(__name__)

T2 = T(*TEST_CASES[2].values)._replace(num_samples=800, warmup_steps=200)
TEST_CASES[2] = pytest.param(*T2, marks=pytest.mark.skipif(
'CI' in os.environ or 'CUDA_TEST' in os.environ,
reason='Slow test - skip on CI/CUDA'))
T3 = T(*TEST_CASES[3].values)._replace(num_samples=1000, warmup_steps=200)
TEST_CASES[3] = pytest.param(*T3, marks=[
pytest.mark.skipif('CI' in os.environ or 'CUDA_TEST' in os.environ,
reason='Slow test - skip on CI/CUDA')]
)
# Conjugate Gaussian chain fixtures for NUTS. The slower cases are skipped
# on CI/CUDA via pytest.param marks.
TEST_CASES = [
    T(
        GaussianChain(dim=10, chain_len=3, num_obs=1),
        num_samples=800,
        warmup_steps=200,
        hmc_params=None,
        expected_means=[0.25, 0.50, 0.75],
        expected_precs=[1.33, 1, 1.33],
        mean_tol=0.06,
        std_tol=0.06,
    ),
    T(
        GaussianChain(dim=10, chain_len=4, num_obs=1),
        num_samples=800,
        warmup_steps=200,
        hmc_params=None,
        expected_means=[0.20, 0.40, 0.60, 0.80],
        expected_precs=[1.25, 0.83, 0.83, 1.25],
        mean_tol=0.06,
        std_tol=0.06,
    ),
    pytest.param(*T(
        GaussianChain(dim=5, chain_len=2, num_obs=10000),
        num_samples=800,
        warmup_steps=200,
        hmc_params=None,
        expected_means=[0.5, 1.0],
        expected_precs=[2.0, 10000],
        mean_tol=0.04,
        std_tol=0.04,
    ), marks=[pytest.mark.skipif('CI' in os.environ or 'CUDA_TEST' in os.environ,
                                 reason='Slow test - skip on CI/CUDA')]),
    pytest.param(*T(
        GaussianChain(dim=5, chain_len=9, num_obs=1),
        num_samples=1200,
        warmup_steps=200,
        hmc_params=None,
        expected_means=[0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90],
        expected_precs=[1.11, 0.63, 0.48, 0.42, 0.4, 0.42, 0.48, 0.63, 1.11],
        mean_tol=0.07,
        std_tol=0.07,
    ), marks=[pytest.mark.skipif('CI' in os.environ or 'CUDA_TEST' in os.environ,
                                 reason='Slow test - skip on CI/CUDA')])
]

# Derive readable pytest ids; plain cases are TestExample tuples, skipped
# cases are wrapped in pytest.param, so the fixture sits one level deeper.
TEST_IDS = [
    (case[0] if type(case).__name__ == 'TestExample' else case[0][0]).id_fn()
    for case in TEST_CASES
]


@pytest.mark.parametrize(
Expand All @@ -44,7 +82,7 @@ def test_nuts_conjugate_gaussian(fixture,
mean_tol,
std_tol):
pyro.get_param_store().clear()
nuts_kernel = NUTS(fixture.model, hmc_params['step_size'])
nuts_kernel = NUTS(fixture.model)
mcmc_run = MCMC(nuts_kernel, num_samples, warmup_steps).run(fixture.data)
for i in range(1, fixture.chain_len + 1):
param_name = 'loc_' + str(i)
Expand Down Expand Up @@ -90,8 +128,8 @@ def model(data):
@pytest.mark.parametrize(
"step_size, adapt_step_size, adapt_mass_matrix, full_mass",
[
(0.02, False, False, False),
(0.02, False, True, False),
(0.1, False, False, False),
(0.1, False, True, False),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the old step_size works too but it is slow

(None, True, False, False),
(None, True, True, False),
(None, True, True, True),
Expand Down Expand Up @@ -154,7 +192,7 @@ def model(data):
true_beta = torch.tensor(1.)
data = dist.Beta(concentration1=true_alpha, concentration0=true_beta).sample(torch.Size((5000,)))
nuts_kernel = NUTS(model)
mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=200).run(data)
mcmc_run = MCMC(nuts_kernel, num_samples=600, warmup_steps=200).run(data)
posterior = EmpiricalMarginal(mcmc_run, sites=['alpha', 'beta'])
assert_equal(posterior.mean, torch.stack([true_alpha, true_beta]), prec=0.05)

Expand Down