From 590446b957653defce079e1af362f03d83bc44d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 7 Oct 2022 17:54:09 +0200 Subject: [PATCH] Refactor the base SMC kernel Base SMC is neatly divided in 3 steps: - particle update - particle weighting - resampling --- blackjax/kernels.py | 28 ++++++------- blackjax/smc/adaptive_tempered.py | 19 +++++---- blackjax/smc/base.py | 66 +++++++++++++++---------------- blackjax/smc/tempered.py | 33 ++++++++++------ tests/test_smc.py | 8 ++-- tests/test_tempered_smc.py | 15 ++++--- 6 files changed, 88 insertions(+), 81 deletions(-) diff --git a/blackjax/kernels.py b/blackjax/kernels.py index aad55fc9d..cc1db4254 100644 --- a/blackjax/kernels.py +++ b/blackjax/kernels.py @@ -51,27 +51,25 @@ def __new__( # type: ignore[misc] cls, logprior_fn: Callable, loglikelihood_fn: Callable, - mcmc_algorithm: SamplingAlgorithm, + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, mcmc_parameters: Dict, resampling_fn: Callable, target_ess: float, root_solver: Callable = smc.solver.dichotomy, use_log_ess: bool = True, - mcmc_iter: int = 10, + num_mcmc_steps: int = 10, ) -> SamplingAlgorithm: - def kernel_factory(logprob_fn): - return mcmc_algorithm(logprob_fn, **mcmc_parameters).step step = cls.kernel( logprior_fn, loglikelihood_fn, - kernel_factory, - mcmc_algorithm.init, + mcmc_step_fn, + mcmc_init_fn, resampling_fn, target_ess, root_solver, use_log_ess, - mcmc_iter, ) def init_fn(position: PyTree): @@ -81,6 +79,8 @@ def step_fn(rng_key: PRNGKey, state): return step( rng_key, state, + num_mcmc_steps, + mcmc_parameters, ) return SamplingAlgorithm(init_fn, step_fn) @@ -103,21 +103,19 @@ def __new__( # type: ignore[misc] cls, logprior_fn: Callable, loglikelihood_fn: Callable, - mcmc_algorithm: SamplingAlgorithm, + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, mcmc_parameters: Dict, resampling_fn: Callable, - mcmc_iter: int = 10, + num_mcmc_steps: int = 10, ) -> SamplingAlgorithm: - def kernel_factory(logprob_fn): - return mcmc_algorithm(logprob_fn, **mcmc_parameters).step step = cls.kernel( logprior_fn, loglikelihood_fn, - kernel_factory, - mcmc_algorithm.init, + mcmc_step_fn, + mcmc_init_fn, resampling_fn, - mcmc_iter, ) def init_fn(position: PyTree): @@ -127,7 +125,9 @@ def step_fn(rng_key: PRNGKey, state, lmbda): return step( rng_key, state, + num_mcmc_steps, lmbda, + mcmc_parameters, ) return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/smc/adaptive_tempered.py b/blackjax/smc/adaptive_tempered.py index a694c25e9..8c78e2db2 100644 --- a/blackjax/smc/adaptive_tempered.py +++ b/blackjax/smc/adaptive_tempered.py @@ -15,13 +15,12 @@ def kernel( logprior_fn: Callable, loglikelihood_fn: Callable, - mcmc_kernel_factory: Callable, - make_mcmc_state: Callable, + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, resampling_fn: Callable, target_ess: float, root_solver: Callable = solver.dichotomy, use_log_ess: bool = True, - mcmc_iter: int = 10, ) -> Callable: r"""Build a Tempered SMC step using an adaptive schedule. @@ -47,8 +46,6 @@ def kernel( use_log_ess: bool, optional Use ESS in log space to solve for delta, default is `True`. This is usually more stable when using gradient based solvers. - mcmc_iter: int - Number of iterations in the MCMC chain. Returns ------- @@ -75,17 +72,19 @@ def compute_delta(state: tempered.TemperedSMCState) -> float: kernel = tempered.kernel( logprior_fn, loglikelihood_fn, - mcmc_kernel_factory, - make_mcmc_state, + mcmc_step_fn, + mcmc_init_fn, resampling_fn, - mcmc_iter, ) def one_step( - rng_key: PRNGKey, state: tempered.TemperedSMCState + rng_key: PRNGKey, + state: tempered.TemperedSMCState, + num_mcmc_steps: int, + mcmc_parameters: dict, ) -> Tuple[tempered.TemperedSMCState, base.SMCInfo]: delta = compute_delta(state) lmbda = delta + state.lmbda - return kernel(rng_key, state, lmbda) + return kernel(rng_key, state, num_mcmc_steps, lmbda, mcmc_parameters) return one_step diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 4dbe70bbd..90145fd61 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -1,4 +1,3 @@ -import functools as ft from typing import Callable, NamedTuple, Tuple import jax @@ -28,20 +27,19 @@ class SMCInfo(NamedTuple): def kernel( - mcmc_kernel: Callable, - mcmc_init: Callable, + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, resampling_fn: Callable, - num_mcmc_steps: int, ): """Build a generic SMC kernel. In Feynman-Kac equivalent terms, the algo goes roughly as follows: ``` - M_t = mcmc_kernel(logprob_fn, **parameters) - for i in range(num_mcmc_steps): - x_t^i = M_t(..., x_t^i) + M_t = mcmc_kernel G_t = log_weights_fn + for i in range(num_mcmc_steps): + x_t^i = M_t(..., x_t^i, logprob_fn, **parameters) log_weights = G_t(x_t) idx = resample(log_weights) x_t = x_t[idx] @@ -50,15 +48,13 @@ def kernel( Parameters ---------- - mcmc_kernel: Callable - A MCMC kernel that generates a new sample from a give state. - mcmc_init: Callable + mcmc_step_fn: Callable + A MCMC step function that generates a new sample from a give state. + mcmc_init_fn: Callable Creates a new MCMC state from a position. resampling_fn: Callable A function that resamples the particles generated by the MCMC kernel, based of previously computed weights. - num_mcmc_steps: int - Number of iterations of the MCMC kernel Returns ------- @@ -73,18 +69,11 @@ def one_step( particles: PyTree, logprob_fn: Callable, log_weight_fn: Callable, + num_mcmc_steps: int, mcmc_parameters: dict, ) -> Tuple[PyTree, SMCInfo]: """ - We could write this in a much better way? - - particles = vmap(f)(particles, *parameters) - weights = vmap(logweightfn)(weights, particles) - resampled_weights = sample(particles, weights) - - Plus the problem is that you may want to parallelize in a different way later - Parameters ---------- rng_key: DeviceArray[int], @@ -95,6 +84,8 @@ def one_step( Log probability function we wish to sample from. log_weight_fn: Callable A function that represents the Feynman-Kac log potential at time t. + num_mcmc_steps: int + Number of iterations of the MCMC kernel mcmc_parameters: dict A dictionary that contains the parameters of the MCMC kernel. @@ -106,25 +97,32 @@ def one_step( Additional information on the SMC step """ - num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] - scan_key, resampling_key = jax.random.split(rng_key, 2) + update_key, resampling_key = jax.random.split(rng_key, 2) + + # TODO: Consider asking the caller to provide the particle_update_fn + # instead + def mcmc_update_particle(rng_key, position): + state = mcmc_init_fn(position, logprob_fn) - applied_mcmc_kernel = ft.partial(mcmc_kernel, **mcmc_parameters) + def body_fn(state, rng_key): + new_state, _ = mcmc_step_fn( + rng_key, state, logprob_fn, **mcmc_parameters + ) + return new_state, new_state - def mcmc_body_fn(curr_particles, curr_key): - keys = jax.random.split(curr_key, num_particles) - new_particles, _ = jax.vmap(applied_mcmc_kernel, in_axes=(0, 0, None))( - keys, curr_particles, logprob_fn - ) - return new_particles, None + keys = jax.random.split(rng_key, num_mcmc_steps) + last_state, _ = jax.lax.scan(body_fn, state, keys) + return last_state.position - mcmc_state = jax.vmap(mcmc_init, in_axes=(0, None))(particles, logprob_fn) - keys = jax.random.split(scan_key, num_mcmc_steps) - proposed_states, _ = jax.lax.scan(mcmc_body_fn, mcmc_state, keys) - proposed_particles = proposed_states.position + # Update the particles (parallel) + num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] + keys = jax.random.split(update_key, num_particles) + proposed_particles = jax.vmap(mcmc_update_particle)(keys, particles) - # Resample the particles depending on their respective weights + # Compute the particles' respective weight (parallel) log_weights = jax.vmap(log_weight_fn, in_axes=(0,))(proposed_particles) + + # Resample the particles (sync) weights, log_likelihood_increment = _normalize(log_weights) resampling_index = resampling_fn(weights, resampling_key) particles = jax.tree_map(lambda x: x[resampling_index], proposed_particles) diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index e8641b7b1..42b259fa8 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -26,10 +26,9 @@ def init(position: PyTree): def kernel( logprior_fn: Callable, loglikelihood_fn: Callable, - mcmc_kernel_factory: Callable, - make_mcmc_state: Callable, + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, resampling_fn: Callable, - num_mcmc_iterations: int, ) -> Callable: """Build the base Tempered SMC kernel. @@ -38,9 +37,10 @@ def kernel( .. math:: p(x) \\propto p_0(x) \\exp(-V(x)) \\mathrm{d}x - where :math:`p_0` is the prior distribution, typically easy to sample from and for which the density - is easy to compute, and :math:`\\exp(-V(x))` is an unnormalized likelihood term for which :math:`V(x)` is easy - to compute pointwise. + where :math:`p_0` is the prior distribution, typically easy to sample from + and for which the density is easy to compute, and :math:`\\exp(-V(x))` is an + unnormalized likelihood term for which :math:`V(x)` is easy to compute + pointwise. Parameters ---------- @@ -49,9 +49,9 @@ def kernel( loglikelihood_fn A function that returns the probability at a given position. - mcmc_kernel_factory + mcmc_step_fn A function that creates a mcmc kernel from a log-probability density function. - make_mcmc_state: Callable + mcmc_init_fn: Callable A function that creates a new mcmc state from a position and a log-probability density function. resampling_fn @@ -66,12 +66,14 @@ def kernel( information about the transition. """ - kernel = smc.base.kernel( - mcmc_kernel_factory, make_mcmc_state, resampling_fn, num_mcmc_iterations - ) + kernel = smc.base.kernel(mcmc_step_fn, mcmc_init_fn, resampling_fn) def one_step( - rng_key: PRNGKey, state: TemperedSMCState, lmbda: float + rng_key: PRNGKey, + state: TemperedSMCState, + num_mcmc_steps: int, + lmbda: float, + mcmc_parameters: dict, ) -> Tuple[TemperedSMCState, smc.base.SMCInfo]: """Move the particles one step using the Tempered SMC algorithm. @@ -102,7 +104,12 @@ def tempered_logposterior_fn(position: PyTree) -> float: return logprior + tempered_loglikelihood smc_state, info = kernel( - rng_key, state.particles, tempered_logposterior_fn, log_weights_fn + rng_key, + state.particles, + tempered_logposterior_fn, + log_weights_fn, + num_mcmc_steps, + mcmc_parameters, ) state = TemperedSMCState(smc_state, state.lmbda + delta) diff --git a/tests/test_smc.py b/tests/test_smc.py index 0f9a5b830..034a7d118 100644 --- a/tests/test_smc.py +++ b/tests/test_smc.py @@ -6,7 +6,7 @@ import jax.numpy as jnp import jax.scipy.stats as stats import numpy as np -from absl.testing import absltest, parameterized +from absl.testing import absltest import blackjax import blackjax.smc.base as base @@ -27,9 +27,9 @@ def setUp(self): self.key = jax.random.PRNGKey(42) @chex.all_variants(with_pmap=False) - @parameterized.parameters([500, 1000, 5000]) - def test_smc(self, N): + def test_smc(self): + N = 100 mcmc_parameters = { "step_size": 1e-2, "inverse_mass_matrix": jnp.eye(1), @@ -42,7 +42,6 @@ def test_smc(self, N): blackjax.mcmc.hmc.kernel(), blackjax.mcmc.hmc.init, resampling.systematic, - 1000, ) # Don't use exactly the invariant distribution for the MCMC kernel @@ -53,6 +52,7 @@ def test_smc(self, N): smc_kernel, logprob_fn=kernel_logprob_fn, log_weight_fn=specialized_log_weights_fn, + num_mcmc_steps=1500, mcmc_parameters=mcmc_parameters, ) )(self.key, init_particles) diff --git a/tests/test_tempered_smc.py b/tests/test_tempered_smc.py index 0a879469e..ae1a3079e 100644 --- a/tests/test_tempered_smc.py +++ b/tests/test_tempered_smc.py @@ -49,7 +49,7 @@ def logprob_fn(self, scale, coefs, preds, x): return jnp.sum(logpdf) @chex.all_variants(without_jit=False, with_pmap=False) - @parameterized.parameters(itertools.product([100, 5000], [True, False])) + @parameterized.parameters(itertools.product([100], [True, False])) def test_adaptive_tempered_smc(self, N, use_log): x_data = np.random.normal(0, 1, size=(1000, 1)) y_data = 3 * x_data + np.random.normal(size=x_data.shape) @@ -75,7 +75,8 @@ def test_adaptive_tempered_smc(self, N, use_log): tempering = adaptive_tempered_smc( prior, conditioned_logprob, - blackjax.hmc, + blackjax.hmc.kernel(), + blackjax.hmc.init, hmc_parameters, resampling.systematic, target_ess, @@ -97,7 +98,7 @@ def test_adaptive_tempered_smc(self, N, use_log): assert iterates[1] >= iterates[0] @chex.all_variants(without_jit=False, with_pmap=False) - @parameterized.parameters(itertools.product([100, 1000], [10, 100])) + @parameterized.parameters(itertools.product([100], [10])) def test_fixed_schedule_tempered_smc(self, N, n_schedule): x_data = np.random.normal(0, 1, size=(1000, 1)) y_data = 3 * x_data + np.random.normal(size=x_data.shape) @@ -120,7 +121,8 @@ def test_fixed_schedule_tempered_smc(self, N, n_schedule): tempering = tempered_smc( prior, conditionned_logprob, - blackjax.hmc, + blackjax.hmc.kernel(), + blackjax.hmc.init, hmc_parameters, resampling.systematic, 10, @@ -155,7 +157,7 @@ class NormalizingConstantTest(chex.TestCase): """Test normalizing constant estimate.""" @chex.all_variants(without_jit=False, with_pmap=False) - @parameterized.parameters(itertools.product([500, 1_000], [2, 10])) + @parameterized.parameters(itertools.product([500], [2])) def test_normalizing_constant(self, N, dim): rng_key = jax.random.PRNGKey(2356) rng_key, cov_key = jax.random.split(rng_key, 2) @@ -181,7 +183,8 @@ def test_normalizing_constant(self, N, dim): tempering = adaptive_tempered_smc( prior, conditionned_logprob, - blackjax.hmc, + blackjax.hmc.kernel(), + blackjax.hmc.init, hmc_parameters, resampling.systematic, 0.9,