Skip to content

Commit

Permalink
Refactor the base SMC kernel
Browse files Browse the repository at this point in the history
Base SMC is neatly divided in 3 steps:
- particle update
- particle weighting
- resampling
  • Loading branch information
rlouf committed Oct 9, 2022
1 parent 0bea265 commit 590446b
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 81 deletions.
28 changes: 14 additions & 14 deletions blackjax/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,27 +51,25 @@ def __new__( # type: ignore[misc]
cls,
logprior_fn: Callable,
loglikelihood_fn: Callable,
mcmc_algorithm: SamplingAlgorithm,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
mcmc_parameters: Dict,
resampling_fn: Callable,
target_ess: float,
root_solver: Callable = smc.solver.dichotomy,
use_log_ess: bool = True,
mcmc_iter: int = 10,
num_mcmc_steps: int = 10,
) -> SamplingAlgorithm:
def kernel_factory(logprob_fn):
return mcmc_algorithm(logprob_fn, **mcmc_parameters).step

step = cls.kernel(
logprior_fn,
loglikelihood_fn,
kernel_factory,
mcmc_algorithm.init,
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
target_ess,
root_solver,
use_log_ess,
mcmc_iter,
)

def init_fn(position: PyTree):
Expand All @@ -81,6 +79,8 @@ def step_fn(rng_key: PRNGKey, state):
return step(
rng_key,
state,
num_mcmc_steps,
mcmc_parameters,
)

return SamplingAlgorithm(init_fn, step_fn)
Expand All @@ -103,21 +103,19 @@ def __new__( # type: ignore[misc]
cls,
logprior_fn: Callable,
loglikelihood_fn: Callable,
mcmc_algorithm: SamplingAlgorithm,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
mcmc_parameters: Dict,
resampling_fn: Callable,
mcmc_iter: int = 10,
num_mcmc_steps: int = 10,
) -> SamplingAlgorithm:
def kernel_factory(logprob_fn):
return mcmc_algorithm(logprob_fn, **mcmc_parameters).step

step = cls.kernel(
logprior_fn,
loglikelihood_fn,
kernel_factory,
mcmc_algorithm.init,
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
mcmc_iter,
)

def init_fn(position: PyTree):
Expand All @@ -127,7 +125,9 @@ def step_fn(rng_key: PRNGKey, state, lmbda):
return step(
rng_key,
state,
num_mcmc_steps,
lmbda,
mcmc_parameters,
)

return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
Expand Down
19 changes: 9 additions & 10 deletions blackjax/smc/adaptive_tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
def kernel(
logprior_fn: Callable,
loglikelihood_fn: Callable,
mcmc_kernel_factory: Callable,
make_mcmc_state: Callable,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
target_ess: float,
root_solver: Callable = solver.dichotomy,
use_log_ess: bool = True,
mcmc_iter: int = 10,
) -> Callable:
r"""Build a Tempered SMC step using an adaptive schedule.
Expand All @@ -47,8 +46,6 @@ def kernel(
use_log_ess: bool, optional
Use ESS in log space to solve for delta, default is `True`.
This is usually more stable when using gradient based solvers.
mcmc_iter: int
Number of iterations in the MCMC chain.
Returns
-------
Expand All @@ -75,17 +72,19 @@ def compute_delta(state: tempered.TemperedSMCState) -> float:
kernel = tempered.kernel(
logprior_fn,
loglikelihood_fn,
mcmc_kernel_factory,
make_mcmc_state,
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
mcmc_iter,
)

def one_step(
rng_key: PRNGKey, state: tempered.TemperedSMCState
rng_key: PRNGKey,
state: tempered.TemperedSMCState,
num_mcmc_steps: int,
mcmc_parameters: dict,
) -> Tuple[tempered.TemperedSMCState, base.SMCInfo]:
delta = compute_delta(state)
lmbda = delta + state.lmbda
return kernel(rng_key, state, lmbda)
return kernel(rng_key, state, num_mcmc_steps, lmbda, mcmc_parameters)

return one_step
66 changes: 32 additions & 34 deletions blackjax/smc/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import functools as ft
from typing import Callable, NamedTuple, Tuple

import jax
Expand Down Expand Up @@ -28,20 +27,19 @@ class SMCInfo(NamedTuple):


def kernel(
mcmc_kernel: Callable,
mcmc_init: Callable,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
num_mcmc_steps: int,
):
"""Build a generic SMC kernel.
In Feynman-Kac equivalent terms, the algo goes roughly as follows:
```
M_t = mcmc_kernel(logprob_fn, **parameters)
for i in range(num_mcmc_steps):
x_t^i = M_t(..., x_t^i)
M_t = mcmc_kernel
G_t = log_weights_fn
for i in range(num_mcmc_steps):
x_t^i = M_t(..., x_t^i, logprob_fn, **parameters)
log_weights = G_t(x_t)
idx = resample(log_weights)
x_t = x_t[idx]
Expand All @@ -50,15 +48,13 @@ def kernel(
Parameters
----------
mcmc_kernel: Callable
A MCMC kernel that generates a new sample from a give state.
mcmc_init: Callable
mcmc_step_fn: Callable
A MCMC step function that generates a new sample from a give state.
mcmc_init_fn: Callable
Creates a new MCMC state from a position.
resampling_fn: Callable
A function that resamples the particles generated by the MCMC kernel,
based of previously computed weights.
num_mcmc_steps: int
Number of iterations of the MCMC kernel
Returns
-------
Expand All @@ -73,18 +69,11 @@ def one_step(
particles: PyTree,
logprob_fn: Callable,
log_weight_fn: Callable,
num_mcmc_steps: int,
mcmc_parameters: dict,
) -> Tuple[PyTree, SMCInfo]:
"""
We could write this in a much better way?
particles = vmap(f)(particles, *parameters)
weights = vmap(logweightfn)(weights, particles)
resampled_weights = sample(particles, weights)
Plus the problem is that you may want to parallelize in a different way later
Parameters
----------
rng_key: DeviceArray[int],
Expand All @@ -95,6 +84,8 @@ def one_step(
Log probability function we wish to sample from.
log_weight_fn: Callable
A function that represents the Feynman-Kac log potential at time t.
num_mcmc_steps: int
Number of iterations of the MCMC kernel
mcmc_parameters: dict
A dictionary that contains the parameters of the MCMC kernel.
Expand All @@ -106,25 +97,32 @@ def one_step(
Additional information on the SMC step
"""

num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
scan_key, resampling_key = jax.random.split(rng_key, 2)
update_key, resampling_key = jax.random.split(rng_key, 2)

# TODO: Consider asking the caller to provide the particle_update_fn
# instead
def mcmc_update_particle(rng_key, position):
state = mcmc_init_fn(position, logprob_fn)

applied_mcmc_kernel = ft.partial(mcmc_kernel, **mcmc_parameters)
def body_fn(state, rng_key):
new_state, _ = mcmc_step_fn(
rng_key, state, logprob_fn, **mcmc_parameters
)
return new_state, new_state

def mcmc_body_fn(curr_particles, curr_key):
keys = jax.random.split(curr_key, num_particles)
new_particles, _ = jax.vmap(applied_mcmc_kernel, in_axes=(0, 0, None))(
keys, curr_particles, logprob_fn
)
return new_particles, None
keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, _ = jax.lax.scan(body_fn, state, keys)
return last_state.position

mcmc_state = jax.vmap(mcmc_init, in_axes=(0, None))(particles, logprob_fn)
keys = jax.random.split(scan_key, num_mcmc_steps)
proposed_states, _ = jax.lax.scan(mcmc_body_fn, mcmc_state, keys)
proposed_particles = proposed_states.position
# Update the particles (parallel)
num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
keys = jax.random.split(update_key, num_particles)
proposed_particles = jax.vmap(mcmc_update_particle)(keys, particles)

# Resample the particles depending on their respective weights
# Compute the particles' respective weight (parallel)
log_weights = jax.vmap(log_weight_fn, in_axes=(0,))(proposed_particles)

# Resample the particles (sync)
weights, log_likelihood_increment = _normalize(log_weights)
resampling_index = resampling_fn(weights, resampling_key)
particles = jax.tree_map(lambda x: x[resampling_index], proposed_particles)
Expand Down
33 changes: 20 additions & 13 deletions blackjax/smc/tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ def init(position: PyTree):
def kernel(
logprior_fn: Callable,
loglikelihood_fn: Callable,
mcmc_kernel_factory: Callable,
make_mcmc_state: Callable,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
num_mcmc_iterations: int,
) -> Callable:
"""Build the base Tempered SMC kernel.
Expand All @@ -38,9 +37,10 @@ def kernel(
.. math::
p(x) \\propto p_0(x) \\exp(-V(x)) \\mathrm{d}x
where :math:`p_0` is the prior distribution, typically easy to sample from and for which the density
is easy to compute, and :math:`\\exp(-V(x))` is an unnormalized likelihood term for which :math:`V(x)` is easy
to compute pointwise.
where :math:`p_0` is the prior distribution, typically easy to sample from
and for which the density is easy to compute, and :math:`\\exp(-V(x))` is an
unnormalized likelihood term for which :math:`V(x)` is easy to compute
pointwise.
Parameters
----------
Expand All @@ -49,9 +49,9 @@ def kernel(
loglikelihood_fn
A function that returns the probability at a given
position.
mcmc_kernel_factory
mcmc_step_fn
A function that creates a mcmc kernel from a log-probability density function.
make_mcmc_state: Callable
mcmc_init_fn: Callable
A function that creates a new mcmc state from a position and a
log-probability density function.
resampling_fn
Expand All @@ -66,12 +66,14 @@ def kernel(
information about the transition.
"""
kernel = smc.base.kernel(
mcmc_kernel_factory, make_mcmc_state, resampling_fn, num_mcmc_iterations
)
kernel = smc.base.kernel(mcmc_step_fn, mcmc_init_fn, resampling_fn)

def one_step(
rng_key: PRNGKey, state: TemperedSMCState, lmbda: float
rng_key: PRNGKey,
state: TemperedSMCState,
num_mcmc_steps: int,
lmbda: float,
mcmc_parameters: dict,
) -> Tuple[TemperedSMCState, smc.base.SMCInfo]:
"""Move the particles one step using the Tempered SMC algorithm.
Expand Down Expand Up @@ -102,7 +104,12 @@ def tempered_logposterior_fn(position: PyTree) -> float:
return logprior + tempered_loglikelihood

smc_state, info = kernel(
rng_key, state.particles, tempered_logposterior_fn, log_weights_fn
rng_key,
state.particles,
tempered_logposterior_fn,
log_weights_fn,
num_mcmc_steps,
mcmc_parameters,
)
state = TemperedSMCState(smc_state, state.lmbda + delta)

Expand Down
8 changes: 4 additions & 4 deletions tests/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np
from absl.testing import absltest, parameterized
from absl.testing import absltest

import blackjax
import blackjax.smc.base as base
Expand All @@ -27,9 +27,9 @@ def setUp(self):
self.key = jax.random.PRNGKey(42)

@chex.all_variants(with_pmap=False)
@parameterized.parameters([500, 1000, 5000])
def test_smc(self, N):
def test_smc(self):

N = 100
mcmc_parameters = {
"step_size": 1e-2,
"inverse_mass_matrix": jnp.eye(1),
Expand All @@ -42,7 +42,6 @@ def test_smc(self, N):
blackjax.mcmc.hmc.kernel(),
blackjax.mcmc.hmc.init,
resampling.systematic,
1000,
)

# Don't use exactly the invariant distribution for the MCMC kernel
Expand All @@ -53,6 +52,7 @@ def test_smc(self, N):
smc_kernel,
logprob_fn=kernel_logprob_fn,
log_weight_fn=specialized_log_weights_fn,
num_mcmc_steps=1500,
mcmc_parameters=mcmc_parameters,
)
)(self.key, init_particles)
Expand Down
Loading

0 comments on commit 590446b

Please sign in to comment.