From ef584fc902947088a7b13c7cf8c40e76e90d11de Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Thu, 30 Nov 2023 10:58:06 -0500 Subject: [PATCH 01/23] ensemble sampling draft --- numpyro/infer/__init__.py | 3 + numpyro/infer/ensemble.py | 719 +++++++++++++++++++++++++++++++++ numpyro/infer/ensemble_util.py | 108 +++++ 3 files changed, 830 insertions(+) create mode 100644 numpyro/infer/ensemble.py create mode 100644 numpyro/infer/ensemble_util.py diff --git a/numpyro/infer/__init__.py b/numpyro/infer/__init__.py index d9e0e337f..9abf96fa2 100644 --- a/numpyro/infer/__init__.py +++ b/numpyro/infer/__init__.py @@ -10,6 +10,7 @@ TraceGraph_ELBO, TraceMeanField_ELBO, ) +from numpyro.infer.ensemble import AIES, ESS from numpyro.infer.hmc import HMC, NUTS from numpyro.infer.hmc_gibbs import HMCECS, DiscreteHMCGibbs, HMCGibbs from numpyro.infer.initialization import ( @@ -29,6 +30,7 @@ from . import autoguide, reparam __all__ = [ + "AIES", "autoguide", "init_to_feasible", "init_to_mean", @@ -41,6 +43,7 @@ "BarkerMH", "DiscreteHMCGibbs", "ELBO", + "ESS", "HMC", "HMCECS", "HMCGibbs", diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py new file mode 100644 index 000000000..99e1064e4 --- /dev/null +++ b/numpyro/infer/ensemble.py @@ -0,0 +1,719 @@ +from abc import ABC, abstractmethod +from collections import namedtuple + +import jax +from jax import random, vmap +import jax.numpy as jnp +from jax.scipy.stats import gaussian_kde + +import numpyro.distributions as dist +from numpyro.infer.ensemble_util import _get_nondiagonal_pairs, batch_ravel_pytree +from numpyro.infer.initialization import init_to_uniform +from numpyro.infer.mcmc import MCMCKernel +from numpyro.infer.util import initialize_model +from numpyro.util import identity, is_prng_key + +EnsembleSamplerState = namedtuple( + "EnsembleSamplerState", ["z", "inner_state", "rng_key"] +) +""" +A :func:`~collections.namedtuple` consisting of the following fields: + + - **z** - Python collection representing values (unconstrained samples from + the posterior) at latent sites. + - **inner_state** - A namedtuple containing information needed to update half the ensemble. + - **rng_key** - random number generator seed used for generating proposals, etc. +""" + +AIESState = namedtuple("AIESState", ["i", "accept_prob", "mean_accept_prob", "rng_key"]) +""" +A :func:`~collections.namedtuple` consisting of the following fields. + + - **i** - iteration. + - **accept_prob** - Acceptance probability of the proposal. Note that ``z`` + does not correspond to the proposal if it is rejected. + - **mean_accept_prob** - Mean acceptance probability until current iteration + during warmup adaptation or sampling (for diagnostics). + - **rng_key** - random number generator seed used for generating proposals, etc. +""" + +ESSState = namedtuple("ESSState", ["i", + "n_expansions", + "n_contractions", + "mu", + "rng_key" + ] + ) +""" +A :func:`~collections.namedtuple` used as an inner state for Ensemble Sampler. +This consists of the following fields: + + - **i** - iteration. + - **n_expansions** - number of expansions in the current batch. Used for tuning mu. + - **n_contractions** - number of contractions in the current batch. Used for tuning mu. + - **mu** - Scale factor. This is tuned if tune_mu=True. + - **rng_key** - random number generator seed used for generating proposals, etc. +""" + + +class EnsembleSampler(MCMCKernel, ABC): + """ + Abstract class for ensemble samplers. + + :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. + If model is provided, `potential_fn` will be inferred using the model. + :param potential_fn: XXX currently unsupported. + :param bool randomize_split: whether or not to permute the chain order at each iteration. + :param callable init_strategy: a per-site initialization function. + See :ref:`init_strategy` section for available functions. + """ + + def __init__(self, model=None, potential_fn=None, randomize_split=False, init_strategy=init_to_uniform): + if not (model is None) ^ (potential_fn is None): + raise ValueError("Only one of `model` or `potential_fn` must be specified.") + + self._model = model + self._potential_fn = potential_fn + self._batch_log_density = None + # unravel an (n_chains, n_params) Array into a pytree and + # evaluate the log density at each chain + + # --- other hyperparams go here + self._num_chains = None # must be an even number >= 2 + self._randomize_split = randomize_split # whether or not to permute the chain order at each iteration + # --- + + self._init_strategy = init_strategy + self._postprocess_fn = None + + @property + def model(self): + return self._model + + @property + def sample_field(self): + return "z" + + @abstractmethod + def init_inner_state(self, rng_key): + """return inner_state""" + raise NotImplementedError + + @abstractmethod + def update_active_chains(self, active, inactive, inner_state): + """return (updated active set of chains, updated inner state)""" + raise NotImplementedError + + def _init_state(self, rng_key, model_args, model_kwargs, init_params): + if self._model is not None: + new_params_info, potential_fn_gen, self._postprocess_fn, _ = initialize_model( + rng_key, + self._model, + dynamic_args=True, + init_strategy=self._init_strategy, + model_args=model_args, + model_kwargs=model_kwargs, + validate_grad=False, + ) + new_init_params = new_params_info[0] + self._potential_fn = potential_fn_gen(*model_args, **model_kwargs) + + _, unravel_fn = batch_ravel_pytree(new_init_params) + self._batch_log_density = lambda z: -vmap(self._potential_fn)(unravel_fn(z)) + + if init_params is None: + init_params = new_init_params + + return init_params + + def init( + self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={} + ): + assert not is_prng_key( + rng_key + ), "EnsembleSampler only supports chain_method='vectorized' or chain_method='parallel'." + assert rng_key.shape[0] % 2 == 0, "Number of chains must be even." + + self._num_chains = rng_key.shape[0] + rng_key, rng_key_inner_state, rng_key_init_model = random.split(rng_key[0], 3) + rng_key_init_model = random.split(rng_key_init_model, self._num_chains) + + init_params = self._init_state( + rng_key_init_model, model_args, model_kwargs, init_params + ) + + if self._potential_fn and init_params is None: + raise ValueError( + "Valid value of `init_params` must be provided with" " `potential_fn`." + ) + + self._num_warmup = num_warmup + + return EnsembleSamplerState( + init_params, self.init_inner_state(rng_key_inner_state), rng_key + ) + + def postprocess_fn(self, args, kwargs): + if self._postprocess_fn is None: + return identity + return self._postprocess_fn(*args, **kwargs) + + def sample(self, state, model_args, model_kwargs): + z, inner_state, rng_key = state + rng_key, _ = random.split(rng_key) + z_flat, unravel_fn = batch_ravel_pytree(z) + + split_ind = self._num_chains // 2 + active_start_idx = [0, split_ind] + active_stop_idx = [split_ind, self._num_chains] + inactive_start_idx = [split_ind, 0] + inactive_stop_idx = [self._num_chains, split_ind] + + if self._randomize_split: + z_flat = random.permutation(rng_key, z_flat, axis=0) + + # TODO: is there a way to do this without having to compile twice? + # indexing depends on the iteration which makes scan/foriloop tricky + for split in range(2): + active = z_flat[active_start_idx[split] : active_stop_idx[split]] + inactive = z_flat[inactive_start_idx[split] : inactive_stop_idx[split]] + + z_updates, inner_state = self.update_active_chains(active, inactive, inner_state) + + z_flat = z_flat.at[active_start_idx[split] : active_stop_idx[split]].set(z_updates) + + return EnsembleSamplerState(unravel_fn(z_flat), inner_state, rng_key) + + +class AIES(EnsembleSampler): + """ + Affine-Invariant Ensemble Sampling: a gradient free method. Suitable for low to moderate dimensional models. + Generally, `num_chains` should be at least twice the dimensionality of the model. + + .. note:: This kernel must be used with `num_chains` > 1 and `chain_method="vectorized` + or `chain_method="parallel` in :class:`MCMC`. The number of chains must be divisible by 2. + + **References:** + + 1. *emcee: The MCMC Hammer* (https://iopscience.iop.org/article/10.1086/670067), + Daniel Foreman-Mackey, David W. Hogg, Dustin Lang, and Jonathan Goodman. + + :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. + If model is provided, `potential_fn` will be inferred using the model. + :param potential_fn: XXX currently unsupported. + :param bool randomize_split: whether or not to permute the chain order at each iteration. + :param moves: a dictionary mapping moves to their respective probabilities of being selected. + If left empty, defaults to `AIES.DEMove()`. + :param callable init_strategy: a per-site initialization function. + See :ref:`init_strategy` section for available functions. + + **Example** + + .. code-block:: python + import jax + import jax.numpy as jnp + import numpyro + import numpyro.distributions as dist + from numpyro.infer import MCMC, AIES + + def model(): + x = numpyro.sample("x", dist.Normal().expand([10])) + numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10)) + + kernel = AIES(model, moves={AIES.DEMove() : .5, + AIES.StretchMove() : .5}) + mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized') + mcmc.run(jax.random.PRNGKey(0)) + mcmc.print_summary() + """ + + def __init__(self, model=None, potential_fn=None, randomize_split=False, moves=None, init_strategy=init_to_uniform): + if not moves: + self._moves = [AIES.DEMove()] + self._weights = jnp.array([1.0]) + else: + self._moves = list(moves.keys()) + self._weights = jnp.array([weight for weight in moves.values()]) / len(moves) + + super().__init__(model, potential_fn, randomize_split, init_strategy) + + # XXX: this doesn't show because state_method='vectorized' shuts off diagnostics_str + def get_diagnostics_str(self, state): + return "acc. prob={:.2f}".format(state.inner_state.mean_accept_prob) + + def init_inner_state(self, rng_key): + # XXX hack -- we don't know num_chains until we init the inner state + self._moves = [move(self._num_chains) if move.__name__ == 'make_de_move' + else move for move in self._moves] + + return AIESState(jnp.array(0.0), jnp.array(0.0), jnp.array(0.0), rng_key) + + def update_active_chains(self, active, inactive, inner_state): + i, _, mean_accept_prob, rng_key = inner_state + rng_key, move_key, proposal_key, accept_key = random.split(rng_key, 4) + + move_i = random.choice(move_key, len(self._moves), p=self._weights) + proposal, factors = jax.lax.switch( + move_i, self._moves, proposal_key, active, inactive + ) + + # --- evaluate the proposal --- + log_accept_prob = ( + factors + + self._batch_log_density(proposal) + - self._batch_log_density(active) + ) + + accepted = dist.Uniform().sample(accept_key, (active.shape[0],)) < jnp.exp( + log_accept_prob + ) + updated_active_chains = jnp.where(accepted[:, jnp.newaxis], proposal, active) + + accept_prob = jnp.count_nonzero(accepted) / accepted.shape[0] + itr = i + 0.5 + n = jnp.where(i < self._num_warmup, itr, itr - self._num_warmup) + mean_accept_prob = mean_accept_prob + (accept_prob - mean_accept_prob) / n + + return updated_active_chains, AIESState( + itr, accept_prob, mean_accept_prob, rng_key + ) + + @staticmethod + def DEMove(sigma=1.0e-5, g0=None): + """A proposal using differential evolution. + + This `Differential evolution proposal + `_ is + implemented following `Nelson et al. (2013) + `_. + Args: + sigma (float): The standard deviation of the Gaussian used to stretch + the proposal vector. + gamma0 (Optional[float]): The mean stretch factor for the proposal + vector. By default, it is `2.38 / sqrt(2*ndim)` + as recommended by the two references. + """ + def make_de_move(n_chains): + PAIRS = _get_nondiagonal_pairs(n_chains // 2) + + def de_move(rng_key, active, inactive): + pairs_key, gamma_key = random.split(rng_key) + n_active_chains, n_params = inactive.shape + + # TODO: if we pass in n_params to parent scope we don't need to recompute this each time + g = 2.38 / jnp.sqrt(2.0 * n_params) if not g0 else g0 + + selected_pairs = random.choice(pairs_key, PAIRS, shape=(n_active_chains,)) + + # Compute diff vectors + diffs = jnp.diff(inactive[selected_pairs], axis=1).squeeze(axis=1) + + # Sample a gamma value for each walker following Nelson et al. (2013) + gamma = dist.Normal(g, g * sigma).sample( + gamma_key, sample_shape=(n_active_chains, 1) + ) + + # In this way, sigma is the standard deviation of the distribution of gamma, + # instead of the standard deviation of the distribution of the proposal as proposed by Ter Braak (2006). + # Otherwise, sigma should be tuned for each dimension, which confronts the idea of affine-invariance. + proposal = active + gamma * diffs + + return proposal, jnp.zeros(n_active_chains) + + return de_move + + return make_de_move + + @staticmethod + def StretchMove(a=2.0): + """ + A `Goodman & Weare (2010) + `_ "stretch move" with + parallelization as described in `Foreman-Mackey et al. (2013) + `_. + + :param a: (optional) + The stretch scale parameter. (default: ``2.0``) + """ + def stretch_move(rng_key, active, inactive): + n_active_chains, n_params = active.shape + unif_key, idx_key = random.split(rng_key) + + zz = ( + (a - 1.0) * random.uniform(unif_key, shape=(n_active_chains,)) + 1 + ) ** 2.0 / a + factors = (n_params - 1.0) * jnp.log(zz) + r_idxs = random.randint( + idx_key, shape=(n_active_chains,), minval=0, maxval=n_active_chains + ) + + proposal = inactive[r_idxs] - (inactive[r_idxs] - active) * zz[:, jnp.newaxis] + + return proposal, factors + + return stretch_move + + +class ESS(EnsembleSampler): + """ + Ensemble Slice Sampling: a gradient free method. Suitable for low to moderate dimensional models. + Generally, `num_chains` should be at least twice the dimensionality of the model. + + .. note:: This kernel must be used with `num_chains` > 1 and `chain_method="vectorized` + or `chain_method="parallel` in :class:`MCMC`. The number of chains must be divisible by 2. + + **References:** + + 1. *zeus: a PYTHON implementation of ensemble slice sampling for efficient Bayesian parameter inference* (https://academic.oup.com/mnras/article/508/3/3589/6381726), + Minas Karamanis, Florian Beutler, and John A. Peacock. + 2. *Ensemble slice sampling* (https://link.springer.com/article/10.1007/s11222-021-10038-2), + Minas Karamanis, Florian Beutler. + + :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. + If model is provided, `potential_fn` will be inferred using the model. + :param potential_fn: XXX currently unsupported. + :param bool randomize_split: whether or not to permute the chain order at each iteration. + Strongly recommended to set to True. + :param moves: a dictionary mapping moves to their respective probabilities of being selected. + If left empty, defaults to `ESS.DifferentialMove()`. + :param int max_steps: number of maximum stepping-out steps per sample. + :param int max_iter: number of maximum expansions/contractions per sample. + :param float init_mu: initial scale factor. + :param bool tune_mu: whether or not to tune the intial scale factor. + :param callable init_strategy: a per-site initialization function. + See :ref:`init_strategy` section for available functions. + + **Example** + + .. code-block:: python + import jax + import jax.numpy as jnp + import numpyro + import numpyro.distributions as dist + from numpyro.infer import MCMC, AIES + + def model(): + x = numpyro.sample("x", dist.Normal().expand([10])) + numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10)) + + kernel = AIES(model, moves={ESS.DifferentialMove() : .8, + ESS.KDEMove() : .2}) + mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized') + mcmc.run(jax.random.PRNGKey(0)) + mcmc.print_summary() + """ + def __init__( + self, + model=None, + potential_fn=None, + randomize_split=True, + moves=None, + max_steps=10_000, + max_iter=10_000, + init_mu=1.0, + tune_mu=True, + init_strategy=init_to_uniform, + ): + if not moves: + self._moves = [ESS.DifferentialMove()] + self._weights = jnp.array([1.0]) + else: + self._moves = list(moves.keys()) + self._weights = jnp.array([weight for weight in moves.values()]) / len(moves) + + self._max_steps = max_steps # max number of stepping out steps + self._max_iter = max_iter # max number of expansions/contractions + self._init_mu = init_mu + self._tune_mu = tune_mu + + super().__init__(model, potential_fn, randomize_split, init_strategy) + + def init_inner_state(self, rng_key): + self.batch_log_density = lambda x: self._batch_log_density(x)[:, jnp.newaxis] + + # XXX hack -- we don't know num_chains until we init the inner state + self._moves = [move(self._num_chains) if move.__name__ == 'make_differential_move' + else move for move in self._moves] + + return ESSState(jnp.array(0.0), jnp.array(0), jnp.array(0), self._init_mu, rng_key) + + def update_active_chains(self, active, inactive, inner_state): + i, n_expansions, n_contractions, mu, rng_key = inner_state + (rng_key, + move_key, + dir_key, + height_key, + step_out_key, + shrink_key) = random.split(rng_key, 6) + + n_active_chains, n_params = active.shape + + move_i = random.choice(move_key, len(self._moves), p=self._weights) + directions = jax.lax.switch(move_i, self._moves, dir_key, inactive, mu) + + log_slice_height = self.batch_log_density(active) - dist.Exponential().sample( + height_key, sample_shape=(n_active_chains, 1) + ) + + curr_n_expansions, L, R = self._step_out( + step_out_key, log_slice_height, active, directions + ) + proposal, curr_n_contractions = self._shrink( + shrink_key, log_slice_height, L, R, active, directions + ) + + n_expansions += curr_n_expansions + n_contractions += curr_n_contractions + itr = i + 0.5 + + if self._tune_mu: + safe_n_expansions = jnp.max(jnp.array([1, n_expansions])) + + # only update tuning scale if a full iteration has passed + mu, n_expansions, n_contractions = jax.lax.cond(jnp.all(itr % 1 == 0), + lambda n_exp, n_con: (2.0 * n_exp / (n_exp + n_con), + jnp.array(0), + jnp.array(0) + ), + lambda _, __: (mu, + n_expansions, + n_contractions + ), + safe_n_expansions, n_contractions) + + return proposal, ESSState(itr, n_expansions, n_contractions, mu, rng_key) + + + @staticmethod + def RandomMove(): + """ + The `Karamanis & Beutler (2020) `_ "Random Move" with parallelization. + When this move is used the walkers move along random directions. There is no communication between the + walkers and this Move corresponds to the vanilla Slice Sampling method. This Move should be used for + debugging purposes only. + """ + def random_move(rng_key, inactive, mu): + directions = dist.Normal(loc=0, scale=1).sample( + rng_key, sample_shape=inactive.shape + ) + directions /= jnp.linalg.norm(directions, axis=0) + + return 2.0 * mu * directions + return random_move + + @staticmethod + def KDEMove(bw_method=None): + """ + The `Karamanis & Beutler (2020) `_ "KDE Move" with parallelization. + When this Move is used the distribution of the walkers of the complementary ensemble is traced using + a Gaussian Kernel Density Estimation methods. The walkers then move along random direction vectos + sampled from this distribution. + """ + def kde_move(rng_key, inactive, mu): + n_active_chains, n_params = inactive.shape + + kde = gaussian_kde(inactive.T, bw_method=bw_method) + + vectors = kde.resample(rng_key, (2 * n_active_chains,)).T + directions = vectors[:n_active_chains] - vectors[n_active_chains:] + + return 2.0 * mu * directions + return kde_move + + @staticmethod + def GaussianMove(): + """ + The `Karamanis & Beutler (2020) `_ "Gaussian Move" with parallelization. + When this Move is used the walkers move along directions defined by random vectors sampled from the Gaussian + approximation of the walkers of the complementary ensemble. + """ + def gaussian_move(rng_key, inactive, mu): + n_active_chains, n_params = inactive.shape + cov = jnp.cov(inactive, rowvar=False) + + return ( + 2.0 + * mu + * dist.MultivariateNormal(0, cov).sample( + rng_key, sample_shape=(n_active_chains,) + ) + ) + return gaussian_move + + @staticmethod + def DifferentialMove(): + """ + The `Karamanis & Beutler (2020) `_ "Differential Move" with parallelization. + When this Move is used the walkers move along directions defined by random pairs of walkers sampled (with no + replacement) from the complementary ensemble. This is the default choice and performs well along a wide range + of target distributions. + """ + def make_differential_move(n_chains): + PAIRS = _get_nondiagonal_pairs(n_chains // 2) + + def differential_move(rng_key, inactive, mu): + n_active_chains, n_params = inactive.shape + + selected_pairs = random.choice(rng_key, PAIRS, shape=(n_active_chains,)) + diffs = jnp.diff(inactive[selected_pairs], axis=1).squeeze( + axis=1 + ) # get the pairwise difference of each vector + + return 2.0 * mu * diffs + return differential_move + + return make_differential_move + + + def _step_out(self, rng_key, log_slice_height, active, directions): + init_L_key, init_J_key = random.split(rng_key) + n_active_chains, n_params = active.shape + + iteration = 0 + n_expansions = 0 + # set initial interval boundaries + L = -dist.Uniform().sample(init_L_key, sample_shape=(n_active_chains, 1)) + R = L + 1.0 + + # stepping out + J = jnp.floor( + dist.Uniform(low=0, high=self._max_steps).sample( + init_J_key, sample_shape=(n_active_chains, 1) + ) + ) + K = (self._max_steps - 1) - J + + # left stepping-out initialisation + mask_J = jnp.full((n_active_chains, 1), True) + # right stepping-out initialisation + mask_K = jnp.full((n_active_chains, 1), True) + + init_values = (n_expansions, L, R, J, K, mask_J, mask_K, iteration) + + def cond_fn(args): + n_expansions, L, R, J, K, mask_J, mask_K, iteration = args + + return (jnp.count_nonzero(mask_J) + jnp.count_nonzero(mask_K) > 0) & ( + iteration < self._max_iter + ) + + def body_fn(args): + n_expansions, L, R, J, K, mask_J, mask_K, iteration = args + + log_prob_L = self.batch_log_density(directions * L + active) + log_prob_R = self.batch_log_density(directions * R + active) + + can_expand_L = log_prob_L > log_slice_height + L = jnp.where(can_expand_L, L - 1, L) + J = jnp.where(can_expand_L, J - 1, J) + mask_J = jnp.where(can_expand_L, mask_J, False) + + can_expand_R = log_prob_R > log_slice_height + R = jnp.where(can_expand_R, R + 1, R) + K = jnp.where(can_expand_R, K - 1, K) + mask_K = jnp.where(can_expand_R, mask_K, False) + + iteration += 1 + n_expansions += jnp.count_nonzero(can_expand_L) + jnp.count_nonzero( + can_expand_R + ) + + return (n_expansions, L, R, J, K, mask_J, mask_K, iteration) + + n_expansions, L, R, J, K, mask_J, mask_K, iteration = jax.lax.while_loop( + cond_fn, body_fn, init_values + ) + + return n_expansions, L, R + + def _shrink(self, rng_key, log_slice_height, L, R, active, directions): + n_active_chains, n_params = active.shape + + iteration = 0 + n_contractions = 0 + widths = jnp.zeros((n_active_chains, 1)) + proposed = jnp.zeros((n_active_chains, n_params)) + can_shrink = jnp.full((n_active_chains, 1), True) + + init_values = ( + rng_key, + proposed, + n_contractions, + L, + R, + widths, + can_shrink, + iteration, + ) + + def cond_fn(args): + ( + rng_key, + proposed, + n_contractions, + L, + R, + widths, + can_shrink, + iteration, + ) = args + + return (jnp.count_nonzero(can_shrink) > 0) & (iteration < self._max_iter) + + def body_fn(args): + ( + rng_key, + proposed, + n_contractions, + L, + R, + widths, + can_shrink, + iteration, + ) = args + + rng_key, _ = random.split(rng_key) + + widths = jnp.where( + can_shrink, dist.Uniform(low=L, high=R).sample(rng_key), widths + ) + + # compute new positions + proposed = jnp.where(can_shrink, directions * widths + active, proposed) + proposed_log_prob = self.batch_log_density(proposed) + + # shrink slices + can_shrink = proposed_log_prob < log_slice_height + + L_cond = can_shrink & (widths < 0.0) + L = jnp.where(L_cond, widths, L) + + R_cond = can_shrink & (widths > 0.0) + R = jnp.where(R_cond, widths, R) + + iteration += 1 + n_contractions += jnp.count_nonzero(L_cond) + jnp.count_nonzero(R_cond) + + return ( + rng_key, + proposed, + n_contractions, + L, + R, + widths, + can_shrink, + iteration, + ) + + ( + rng_key, + proposed, + n_contractions, + L, + R, + widths, + can_shrink, + iteration, + ) = jax.lax.while_loop(cond_fn, body_fn, init_values) + + return proposed, n_contractions diff --git a/numpyro/infer/ensemble_util.py b/numpyro/infer/ensemble_util.py new file mode 100644 index 000000000..f26e64d70 --- /dev/null +++ b/numpyro/infer/ensemble_util.py @@ -0,0 +1,108 @@ +import warnings + +import numpy as np + +from jax import lax, vmap +import jax.numpy as jnp + +from jax._src import dtypes +from jax._src.tree_util import tree_flatten, tree_unflatten +from jax._src.util import safe_zip, unzip2, HashablePartial + +zip = safe_zip + + +def _get_nondiagonal_pairs(n): + """ + From https://github.com/dfm/emcee/blob/main/src/emcee/moves/de.py: + + Get the indices of a square matrix with size n, excluding the diagonal. + """ + + rows, cols = np.tril_indices(n, -1) # -1 to exclude diagonal + + # Combine rows-cols and cols-rows pairs + pairs = np.column_stack([np.concatenate([rows, cols]), + np.concatenate([cols, rows])]) + + return jnp.asarray(pairs) + + +def batch_ravel_pytree(pytree): + """Ravel (flatten) a pytree of arrays with leading batch dimension down to a (batch_size, 1D) array. + Args: + pytree: a pytree of arrays and scalars to ravel. + Returns: + A pair where the first element is a (batch_size, 1D) array representing the flattened and + concatenated leaf values, with dtype determined by promoting the dtypes of + leaf values, and the second element is a callable for unflattening a (batch_size, 1D) + vector of the same length back to a pytree of of the same structure as the + input ``pytree``. If the input pytree is empty (i.e. has no leaves) then as + a convention a 1D empty array of dtype float32 is returned in the first + component of the output. + For details on dtype promotion, see + https://jax.readthedocs.io/en/latest/type_promotion.html. + """ + + leaves, treedef = tree_flatten(pytree) + flat, unravel_list = _ravel_list(leaves) + return flat, HashablePartial(unravel_pytree, treedef, unravel_list) + +def unravel_pytree(treedef, unravel_list, flat): + return tree_unflatten(treedef, unravel_list(flat)) + +@vmap +def vmapped_ravel(a): + return jnp.ravel(a) + +def _ravel_list(lst): + if not lst: return jnp.array([], jnp.float32), lambda _: [] + from_dtypes = tuple(dtypes.dtype(l) for l in lst) + to_dtype = dtypes.result_type(*from_dtypes) + + # here 1 is n_leading_batch_dimensions + sizes, shapes = unzip2((np.prod(jnp.shape(x)[1:]), jnp.shape(x)[1:]) for x in lst) + indices = tuple(np.cumsum(sizes)) + + if all(dt == to_dtype for dt in from_dtypes): + # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`. + # See https://github.com/google/jax/issues/7809. + del from_dtypes, to_dtype + + # axis = n_leading_batch_dimensions + # vmap n_leading_batch_dimensions times + raveled = jnp.concatenate([vmapped_ravel(e) for e in lst], axis=1) + return raveled, HashablePartial(_unravel_list_single_dtype, indices, shapes) + + # When there is more than one distinct input dtype, we perform type + # conversions and produce a dtype-specific unravel function. + ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype)) + raveled = jnp.concatenate([vmapped_ravel(e) for e in lst]) + unrav = HashablePartial(_unravel_list, indices, shapes, from_dtypes, to_dtype) + return raveled, unrav + + +def _unravel_list_single_dtype(indices, shapes, arr): + # axis is n_leading_batch_dimensions + chunks = jnp.split(arr, indices[:-1], axis=1) + + # the number of -1s is the number of leading batch dimensions + return [chunk.reshape((-1, *shape)) for chunk, shape in zip(chunks, shapes)] + + +def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr): + arr_dtype = dtypes.dtype(arr) + if arr_dtype != to_dtype: + raise TypeError(f"unravel function given array of dtype {arr_dtype}, " + f"but expected dtype {to_dtype}") + + # axis is n_leading_batch_dimensions + chunks = jnp.split(arr, indices[:-1], axis=1) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") # ignore complex-to-real cast warning + # the number of -1s is the number of leading batch dimensions + return [lax.convert_element_type(chunk.reshape((-1, *shape)), dtype) + for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)] + + + From f3b62d2473be6cb45b9ff1b3fa9f82e8545ada75 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Thu, 30 Nov 2023 11:36:18 -0500 Subject: [PATCH 02/23] rewrite for loop as fori_loop --- numpyro/infer/ensemble.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index 99e1064e4..bde8c5f09 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -163,24 +163,29 @@ def sample(self, state, model_args, model_kwargs): rng_key, _ = random.split(rng_key) z_flat, unravel_fn = batch_ravel_pytree(z) - split_ind = self._num_chains // 2 - active_start_idx = [0, split_ind] - active_stop_idx = [split_ind, self._num_chains] - inactive_start_idx = [split_ind, 0] - inactive_stop_idx = [self._num_chains, split_ind] - if self._randomize_split: z_flat = random.permutation(rng_key, z_flat, axis=0) - # TODO: is there a way to do this without having to compile twice? - # indexing depends on the iteration which makes scan/foriloop tricky - for split in range(2): - active = z_flat[active_start_idx[split] : active_stop_idx[split]] - inactive = z_flat[inactive_start_idx[split] : inactive_stop_idx[split]] + split_ind = self._num_chains // 2 + def body_fn(i, z_flat_inner_state): + z_flat, inner_state = z_flat_inner_state + + active, inactive = jax.lax.cond(i == 0, + lambda x: (x[:split_ind], x[split_ind:]), + lambda x: (x[split_ind:], x[split_ind:]), + z_flat) + z_updates, inner_state = self.update_active_chains(active, inactive, inner_state) + + z_flat = jax.lax.cond(i == 0, + lambda x: x.at[:split_ind].set(z_updates), + lambda x: x.at[split_ind:].set(z_updates), + z_flat) + return (z_flat, inner_state) + + z_flat, inner_state = jax.lax.fori_loop(0, 2, body_fn, (z_flat, inner_state)) - z_flat = z_flat.at[active_start_idx[split] : active_stop_idx[split]].set(z_updates) return EnsembleSamplerState(unravel_fn(z_flat), inner_state, rng_key) From 6fff9270c0c25c15b06ec5ffa1ba7470cb28aed6 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Sat, 2 Dec 2023 14:47:02 -0500 Subject: [PATCH 03/23] added efficiency comment for ESS GaussianMove --- numpyro/infer/ensemble.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index bde8c5f09..3a4a4e3f2 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -532,6 +532,14 @@ def GaussianMove(): When this Move is used the walkers move along directions defined by random vectors sampled from the Gaussian approximation of the walkers of the complementary ensemble. """ + + # In high dimensional regimes with sufficiently small n_active_chains, + # it is more efficient to sample without computing the Cholesky + # decomposition of the covariance matrix: + + # eps = dist.Normal(0, 1).sample(rng_key, (n_active_chains, n_params)) + # return 2.0 * mu * (eps @ (inactive - jnp.mean(inactive, axis=0)) / jnp.sqrt(n_active_chains)) + def gaussian_move(rng_key, inactive, mu): n_active_chains, n_params = inactive.shape cov = jnp.cov(inactive, rowvar=False) From dc582e52d80fc4553d3cdbb745a1bdad46a4e947 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Sat, 2 Dec 2023 14:55:18 -0500 Subject: [PATCH 04/23] fix typo --- numpyro/infer/ensemble.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index 3a4a4e3f2..4983ce9a6 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -537,7 +537,7 @@ def GaussianMove(): # it is more efficient to sample without computing the Cholesky # decomposition of the covariance matrix: - # eps = dist.Normal(0, 1).sample(rng_key, (n_active_chains, n_params)) + # eps = dist.Normal(0, 1).sample(rng_key, (n_active_chains, n_active_chains)) # return 2.0 * mu * (eps @ (inactive - jnp.mean(inactive, axis=0)) / jnp.sqrt(n_active_chains)) def gaussian_move(rng_key, inactive, mu): From 49e4cdb229471b56003775914aa57f249b6a569d Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Sat, 23 Dec 2023 20:08:31 -0500 Subject: [PATCH 05/23] fixed ravel for mixed dtype --- numpyro/infer/ensemble.py | 3 +++ numpyro/infer/ensemble_util.py | 10 ++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index 4983ce9a6..ec618961f 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -1,3 +1,6 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + from abc import ABC, abstractmethod from collections import namedtuple diff --git a/numpyro/infer/ensemble_util.py b/numpyro/infer/ensemble_util.py index f26e64d70..f0732b083 100644 --- a/numpyro/infer/ensemble_util.py +++ b/numpyro/infer/ensemble_util.py @@ -1,3 +1,6 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + import warnings import numpy as np @@ -51,9 +54,8 @@ def batch_ravel_pytree(pytree): def unravel_pytree(treedef, unravel_list, flat): return tree_unflatten(treedef, unravel_list(flat)) -@vmap def vmapped_ravel(a): - return jnp.ravel(a) + return vmap(jnp.ravel)(a) def _ravel_list(lst): if not lst: return jnp.array([], jnp.float32), lambda _: [] @@ -76,8 +78,8 @@ def _ravel_list(lst): # When there is more than one distinct input dtype, we perform type # conversions and produce a dtype-specific unravel function. - ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype)) - raveled = jnp.concatenate([vmapped_ravel(e) for e in lst]) + convert_vmapped_ravel = lambda e: vmapped_ravel(lax.convert_element_type(e, to_dtype)) + raveled = jnp.concatenate([convert_vmapped_ravel(e) for e in lst]) unrav = HashablePartial(_unravel_list, indices, shapes, from_dtypes, to_dtype) return raveled, unrav From bc7c79252aca72895cdaf5652dbdab1efd118ff4 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Sun, 21 Jan 2024 10:34:27 -0500 Subject: [PATCH 06/23] add defaults --- numpyro/infer/ensemble.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index ec618961f..7caa0fb59 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -71,7 +71,7 @@ class EnsembleSampler(MCMCKernel, ABC): See :ref:`init_strategy` section for available functions. """ - def __init__(self, model=None, potential_fn=None, randomize_split=False, init_strategy=init_to_uniform): + def __init__(self, model=None, potential_fn=None, *, randomize_split, init_strategy): if not (model is None) ^ (potential_fn is None): raise ValueError("Only one of `model` or `potential_fn` must be specified.") @@ -210,6 +210,7 @@ class AIES(EnsembleSampler): If model is provided, `potential_fn` will be inferred using the model. :param potential_fn: XXX currently unsupported. :param bool randomize_split: whether or not to permute the chain order at each iteration. + Defaults to False. :param moves: a dictionary mapping moves to their respective probabilities of being selected. If left empty, defaults to `AIES.DEMove()`. :param callable init_strategy: a per-site initialization function. @@ -243,7 +244,10 @@ def __init__(self, model=None, potential_fn=None, randomize_split=False, moves=N self._moves = list(moves.keys()) self._weights = jnp.array([weight for weight in moves.values()]) / len(moves) - super().__init__(model, potential_fn, randomize_split, init_strategy) + super().__init__(model, + potential_fn, + randomize_split=randomize_split, + init_strategy=init_strategy) # XXX: this doesn't show because state_method='vectorized' shuts off diagnostics_str def get_diagnostics_str(self, state): From 1ebff6c036bfaaeaf254a07c4c30845e4b4997e3 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Sun, 21 Jan 2024 11:13:23 -0500 Subject: [PATCH 07/23] add support for potential_fn --- numpyro/infer/ensemble.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index 7caa0fb59..5d25b60c6 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -65,7 +65,10 @@ class EnsembleSampler(MCMCKernel, ABC): :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. If model is provided, `potential_fn` will be inferred using the model. - :param potential_fn: XXX currently unsupported. + :param potential_fn: Python callable that computes the potential energy + given input parameters. The input parameters to `potential_fn` can be + any python collection type, provided that `init_params` argument to + :meth:`init` has the same type. :param bool randomize_split: whether or not to permute the chain order at each iteration. :param callable init_strategy: a per-site initialization function. See :ref:`init_strategy` section for available functions. @@ -121,12 +124,12 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): new_init_params = new_params_info[0] self._potential_fn = potential_fn_gen(*model_args, **model_kwargs) - _, unravel_fn = batch_ravel_pytree(new_init_params) - self._batch_log_density = lambda z: -vmap(self._potential_fn)(unravel_fn(z)) - if init_params is None: init_params = new_init_params + _, unravel_fn = batch_ravel_pytree(init_params) + self._batch_log_density = lambda z: -vmap(self._potential_fn)(unravel_fn(z)) + return init_params def init( @@ -137,6 +140,11 @@ def init( ), "EnsembleSampler only supports chain_method='vectorized' or chain_method='parallel'." assert rng_key.shape[0] % 2 == 0, "Number of chains must be even." + if self._potential_fn and init_params is None: + raise ValueError( + "Valid value of `init_params` must be provided with `potential_fn`." + ) + self._num_chains = rng_key.shape[0] rng_key, rng_key_inner_state, rng_key_init_model = random.split(rng_key[0], 3) rng_key_init_model = random.split(rng_key_init_model, self._num_chains) @@ -145,11 +153,6 @@ def init( rng_key_init_model, model_args, model_kwargs, init_params ) - if self._potential_fn and init_params is None: - raise ValueError( - "Valid value of `init_params` must be provided with" " `potential_fn`." - ) - self._num_warmup = num_warmup return EnsembleSamplerState( @@ -208,7 +211,10 @@ class AIES(EnsembleSampler): :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. If model is provided, `potential_fn` will be inferred using the model. - :param potential_fn: XXX currently unsupported. + :param potential_fn: Python callable that computes the potential energy + given input parameters. The input parameters to `potential_fn` can be + any python collection type, provided that `init_params` argument to + :meth:`init` has the same type. :param bool randomize_split: whether or not to permute the chain order at each iteration. Defaults to False. :param moves: a dictionary mapping moves to their respective probabilities of being selected. @@ -241,6 +247,7 @@ def __init__(self, model=None, potential_fn=None, randomize_split=False, moves=N self._moves = [AIES.DEMove()] self._weights = jnp.array([1.0]) else: + # TODO: check that moves are valid self._moves = list(moves.keys()) self._weights = jnp.array([weight for weight in moves.values()]) / len(moves) @@ -383,7 +390,10 @@ class ESS(EnsembleSampler): :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. If model is provided, `potential_fn` will be inferred using the model. - :param potential_fn: XXX currently unsupported. + :param potential_fn: Python callable that computes the potential energy + given input parameters. The input parameters to `potential_fn` can be + any python collection type, provided that `init_params` argument to + :meth:`init` has the same type. :param bool randomize_split: whether or not to permute the chain order at each iteration. Strongly recommended to set to True. :param moves: a dictionary mapping moves to their respective probabilities of being selected. From 3625d757dbd57412004622a80706ab7aa381aa15 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Sun, 21 Jan 2024 14:08:06 -0500 Subject: [PATCH 08/23] AIES tests, warnings for AIES --- numpyro/infer/ensemble.py | 13 +++++- test/infer/test_mcmc.py | 91 ++++++++++++++++++++++++++------------- 2 files changed, 73 insertions(+), 31 deletions(-) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index 5d25b60c6..272a258b1 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from collections import namedtuple +import warnings import jax from jax import random, vmap @@ -86,7 +87,7 @@ def __init__(self, model=None, potential_fn=None, *, randomize_split, init_strat # --- other hyperparams go here self._num_chains = None # must be an even number >= 2 - self._randomize_split = randomize_split # whether or not to permute the chain order at each iteration + self._randomize_split = randomize_split # --- self._init_strategy = init_strategy @@ -127,9 +128,13 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): if init_params is None: init_params = new_init_params - _, unravel_fn = batch_ravel_pytree(init_params) + flat_params, unravel_fn = batch_ravel_pytree(init_params) self._batch_log_density = lambda z: -vmap(self._potential_fn)(unravel_fn(z)) + if self._num_chains < 2 * flat_params.shape[1]: + warnings.warn("Setting n_chains to at least 2*n_params is strongly recommended.\n" + f"n_chains: {self._num_chains}, n_params: {flat_params.shape[1]}") + return init_params def init( @@ -138,12 +143,16 @@ def init( assert not is_prng_key( rng_key ), "EnsembleSampler only supports chain_method='vectorized' or chain_method='parallel'." + assert rng_key.shape[0] % 2 == 0, "Number of chains must be even." if self._potential_fn and init_params is None: raise ValueError( "Valid value of `init_params` must be provided with `potential_fn`." ) + + # TODO: if init_params is specified, check that the batch dimension of each array + # matches n_chains self._num_chains = rng_key.shape[0] rng_key, rng_key_inner_state, rng_key_init_model = random.split(rng_key[0], 3) diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index 04a5bc64d..ab46ed9d0 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -17,7 +17,7 @@ import numpyro import numpyro.distributions as dist from numpyro.distributions.transforms import AffineTransform -from numpyro.infer import HMC, MCMC, NUTS, SA, BarkerMH +from numpyro.infer import HMC, MCMC, NUTS, SA, BarkerMH, AIES, ESS # TODO: get ESS working from numpyro.infer.hmc import hmc from numpyro.infer.reparam import TransformReparam from numpyro.infer.sa import _get_proposal_loc_and_scale, _numpy_delete @@ -25,7 +25,7 @@ from numpyro.util import fori_collect, is_prng_key -@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH]) +@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES]) @pytest.mark.parametrize("dense_mass", [False, True]) def test_unnormalized_normal_x64(kernel_cls, dense_mass): true_mean, true_std = 1.0, 0.5 @@ -34,16 +34,27 @@ def test_unnormalized_normal_x64(kernel_cls, dense_mass): def potential_fn(z): return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2) - init_params = jnp.array(0.0) - if kernel_cls in [SA, BarkerMH]: - kernel = kernel_cls(potential_fn=potential_fn, dense_mass=dense_mass) + if kernel_cls in [AIES, ESS]: + num_chains = 10 + kernel = kernel_cls(potential_fn=potential_fn) + + init_params = random.normal(random.PRNGKey(1), (num_chains,)) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False, + num_chains=num_chains, chain_method='vectorized' + ) else: - kernel = kernel_cls( - potential_fn=potential_fn, trajectory_length=8, dense_mass=dense_mass + if kernel_cls in [SA, BarkerMH]: + kernel = kernel_cls(potential_fn=potential_fn, dense_mass=dense_mass) + else: + kernel = kernel_cls( + potential_fn=potential_fn, trajectory_length=8, dense_mass=dense_mass + ) + init_params = jnp.array(0.0) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False ) - mcmc = MCMC( - kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False - ) + mcmc.run(random.PRNGKey(0), init_params=init_params) mcmc.print_summary() hmc_states = mcmc.get_samples() @@ -83,13 +94,17 @@ def potential_fn(z): assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02 -@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH]) +@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES]) def test_logistic_regression_x64(kernel_cls): N, dim = 3000, 3 if kernel_cls is SA: num_warmup, num_samples = (100000, 100000) elif kernel_cls is BarkerMH: num_warmup, num_samples = (2000, 12000) + elif kernel_cls in [AIES, ESS]: + num_chains = 10 + samples_each_chain = 8000 + num_warmup, num_samples = (10_000, samples_each_chain * num_chains) else: num_warmup, num_samples = (1000, 8000) data = random.normal(random.PRNGKey(0), (N, dim)) @@ -102,17 +117,25 @@ def model(labels): logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1)) return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels) - if kernel_cls is SA: - kernel = SA(model=model, adapt_state_size=9) - elif kernel_cls is BarkerMH: - kernel = BarkerMH(model=model) + if kernel_cls in [AIES, ESS]: + kernel = kernel_cls(model) + + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=samples_each_chain, + progress_bar=False, num_chains=num_chains, chain_method='vectorized' + ) else: - kernel = kernel_cls( - model=model, trajectory_length=8, find_heuristic_step_size=True + if kernel_cls is SA: + kernel = SA(model=model, adapt_state_size=9) + elif kernel_cls is BarkerMH: + kernel = BarkerMH(model=model) + else: + kernel = kernel_cls( + model=model, trajectory_length=8, find_heuristic_step_size=True + ) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False ) - mcmc = MCMC( - kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False - ) mcmc.run(random.PRNGKey(2), labels) mcmc.print_summary() samples = mcmc.get_samples() @@ -185,7 +208,7 @@ def model(data): assert_allclose(jnp.mean(samples["loc"], 0), true_coef, atol=0.007) -@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH]) +@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES]) def test_beta_bernoulli_x64(kernel_cls): if kernel_cls is SA and "CI" in os.environ and "JAX_ENABLE_X64" in os.environ: pytest.skip("The test is flaky on CI x64.") @@ -200,15 +223,25 @@ def model(data): true_probs = jnp.array([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000,)) - if kernel_cls is SA: - kernel = SA(model=model) - elif kernel_cls is BarkerMH: - kernel = BarkerMH(model=model) + + if kernel_cls in [AIES, ESS]: + num_chains = 10 + kernel = kernel_cls(model=model) + + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, + progress_bar=False, num_chains=num_chains, chain_method='vectorized' + ) else: - kernel = kernel_cls(model=model, trajectory_length=0.1) - mcmc = MCMC( - kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False - ) + if kernel_cls is SA: + kernel = SA(model=model) + elif kernel_cls is BarkerMH: + kernel = BarkerMH(model=model) + else: + kernel = kernel_cls(model=model, trajectory_length=0.1) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False + ) mcmc.run(random.PRNGKey(2), data) mcmc.print_summary() samples = mcmc.get_samples() From cc501a4e99e870481d31747749305b59fcca60d6 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Sun, 21 Jan 2024 16:39:10 -0500 Subject: [PATCH 09/23] AIES input validation --- numpyro/infer/ensemble.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index 272a258b1..58b24de22 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -146,15 +146,17 @@ def init( assert rng_key.shape[0] % 2 == 0, "Number of chains must be even." + self._num_chains = rng_key.shape[0] + if self._potential_fn and init_params is None: raise ValueError( "Valid value of `init_params` must be provided with `potential_fn`." ) + if init_params: + assert all([param.shape[0] == self._num_chains + for param in jax.tree_leaves(init_params)]), ("The batch dimension of each " + "param must match n_chains") - # TODO: if init_params is specified, check that the batch dimension of each array - # matches n_chains - - self._num_chains = rng_key.shape[0] rng_key, rng_key_inner_state, rng_key_init_model = random.split(rng_key[0], 3) rng_key_init_model = random.split(rng_key_init_model, self._num_chains) @@ -256,10 +258,13 @@ def __init__(self, model=None, potential_fn=None, randomize_split=False, moves=N self._moves = [AIES.DEMove()] self._weights = jnp.array([1.0]) else: - # TODO: check that moves are valid self._moves = list(moves.keys()) self._weights = jnp.array([weight for weight in moves.values()]) / len(moves) + assert all([hasattr(move, '__call__') for move in self._moves]), ( + "Each move must be a callable (one of AIES.DEMove(), or AIES.StretchMove()).") + assert jnp.all(self._weights >= 0), "Each specified move must have probability >= 0" + super().__init__(model, potential_fn, randomize_split=randomize_split, @@ -328,7 +333,8 @@ def de_move(rng_key, active, inactive): pairs_key, gamma_key = random.split(rng_key) n_active_chains, n_params = inactive.shape - # TODO: if we pass in n_params to parent scope we don't need to recompute this each time + # XXX: if we pass in n_params to parent scope we don't need to + # recompute this each time g = 2.38 / jnp.sqrt(2.0 * n_params) if not g0 else g0 selected_pairs = random.choice(pairs_key, PAIRS, shape=(n_active_chains,)) From 73ca58d871676c595b21754ba8a09c53e3107f11 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Sun, 21 Jan 2024 17:45:13 -0500 Subject: [PATCH 10/23] better docs, more input validation --- numpyro/infer/ensemble.py | 41 +++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index 58b24de22..32e0d0425 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -229,7 +229,8 @@ class AIES(EnsembleSampler): :param bool randomize_split: whether or not to permute the chain order at each iteration. Defaults to False. :param moves: a dictionary mapping moves to their respective probabilities of being selected. - If left empty, defaults to `AIES.DEMove()`. + Valid keys are `AIES.DEMove()` and `AIES.StretchMove()`. Both tend to work well in practice. + If the sum of probabilites exceeds 1, the probabilities will be normalized. Defaults to `{AIES.DEMove(): 1.0}`. :param callable init_strategy: a per-site initialization function. See :ref:`init_strategy` section for available functions. @@ -246,8 +247,8 @@ def model(): x = numpyro.sample("x", dist.Normal().expand([10])) numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10)) - kernel = AIES(model, moves={AIES.DEMove() : .5, - AIES.StretchMove() : .5}) + kernel = AIES(model, moves={AIES.DEMove() : 0.5, + AIES.StretchMove() : 0.5}) mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized') mcmc.run(jax.random.PRNGKey(0)) mcmc.print_summary() @@ -391,7 +392,7 @@ def stretch_move(rng_key, active, inactive): class ESS(EnsembleSampler): """ Ensemble Slice Sampling: a gradient free method. Suitable for low to moderate dimensional models. - Generally, `num_chains` should be at least twice the dimensionality of the model. + Generally, `num_chains` should be at least twice the dimensionality of the model. Increasing .. note:: This kernel must be used with `num_chains` > 1 and `chain_method="vectorized` or `chain_method="parallel` in :class:`MCMC`. The number of chains must be divisible by 2. @@ -409,14 +410,23 @@ class ESS(EnsembleSampler): given input parameters. The input parameters to `potential_fn` can be any python collection type, provided that `init_params` argument to :meth:`init` has the same type. - :param bool randomize_split: whether or not to permute the chain order at each iteration. - Strongly recommended to set to True. + :param bool randomize_split: whether or not to permute the chain order at each iteration. + Defaults to True. :param moves: a dictionary mapping moves to their respective probabilities of being selected. - If left empty, defaults to `ESS.DifferentialMove()`. - :param int max_steps: number of maximum stepping-out steps per sample. - :param int max_iter: number of maximum expansions/contractions per sample. - :param float init_mu: initial scale factor. - :param bool tune_mu: whether or not to tune the intial scale factor. + If the sum of probabilites exceeds 1, the probabilities will be normalized. Valid keys include: + - `ESS.DifferentialMove()` -> default proposal, works well along a wide range + of target distributions + - `ESS.GaussianMove()` -> for approximately normally distributed targets + - `ESS.KDEMove()` -> for multimodal posteriors - requires large `num_chains`, and + they must be well initialized + - `ESS.RandomMove()` -> no chain interaction, useful for debugging + + Defaults to `{ESS.DifferentialMove(): 1.0}`. + + :param int max_steps: number of maximum stepping-out steps per sample. Defaults to 10,000. + :param int max_iter: number of maximum expansions/contractions per sample. Defaults to 10,000. + :param float init_mu: initial scale factor. Defaults to 1.0. + :param bool tune_mu: whether or not to tune the initial scale factor. Defaults to True. :param callable init_strategy: a per-site initialization function. See :ref:`init_strategy` section for available functions. @@ -434,7 +444,7 @@ def model(): numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10)) kernel = AIES(model, moves={ESS.DifferentialMove() : .8, - ESS.KDEMove() : .2}) + ESS.RandomMove() : .2}) mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized') mcmc.run(jax.random.PRNGKey(0)) mcmc.print_summary() @@ -457,6 +467,13 @@ def __init__( else: self._moves = list(moves.keys()) self._weights = jnp.array([weight for weight in moves.values()]) / len(moves) + + assert all([hasattr(move, '__call__') for move in self._moves]), ( + "Each move must be a callable (one of `ESS.DifferentialMove()`, " + "`ESS.GaussianMove()`, `ESS.KDEMove()`, `ESS.RandomMove()`)") + + assert jnp.all(self._weights >= 0), "Each specified move must have probability >= 0" + assert init_mu > 0, "Scale factor should be strictly positive" self._max_steps = max_steps # max number of stepping out steps self._max_iter = max_iter # max number of expansions/contractions From e3687812890f26bba8de4564c545674f7338cc40 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Sun, 21 Jan 2024 22:02:46 -0500 Subject: [PATCH 11/23] ESS passing test cases --- numpyro/infer/ensemble.py | 7 +++++-- test/infer/test_mcmc.py | 32 ++++++++++++++++---------------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index 32e0d0425..bbfa2bebe 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -152,7 +152,7 @@ def init( raise ValueError( "Valid value of `init_params` must be provided with `potential_fn`." ) - if init_params: + if init_params is not None: assert all([param.shape[0] == self._num_chains for param in jax.tree_leaves(init_params)]), ("The batch dimension of each " "param must match n_chains") @@ -480,7 +480,10 @@ def __init__( self._init_mu = init_mu self._tune_mu = tune_mu - super().__init__(model, potential_fn, randomize_split, init_strategy) + super().__init__(model, + potential_fn, + randomize_split=randomize_split, + init_strategy=init_strategy) def init_inner_state(self, rng_key): self.batch_log_density = lambda x: self._batch_log_density(x)[:, jnp.newaxis] diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index ab46ed9d0..7300bcf85 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -17,7 +17,7 @@ import numpyro import numpyro.distributions as dist from numpyro.distributions.transforms import AffineTransform -from numpyro.infer import HMC, MCMC, NUTS, SA, BarkerMH, AIES, ESS # TODO: get ESS working +from numpyro.infer import HMC, MCMC, NUTS, SA, BarkerMH, AIES, ESS from numpyro.infer.hmc import hmc from numpyro.infer.reparam import TransformReparam from numpyro.infer.sa import _get_proposal_loc_and_scale, _numpy_delete @@ -25,7 +25,7 @@ from numpyro.util import fori_collect, is_prng_key -@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES]) +@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES, ESS]) @pytest.mark.parametrize("dense_mass", [False, True]) def test_unnormalized_normal_x64(kernel_cls, dense_mass): true_mean, true_std = 1.0, 0.5 @@ -94,19 +94,10 @@ def potential_fn(z): assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02 -@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES]) +@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES, ESS]) def test_logistic_regression_x64(kernel_cls): N, dim = 3000, 3 - if kernel_cls is SA: - num_warmup, num_samples = (100000, 100000) - elif kernel_cls is BarkerMH: - num_warmup, num_samples = (2000, 12000) - elif kernel_cls in [AIES, ESS]: - num_chains = 10 - samples_each_chain = 8000 - num_warmup, num_samples = (10_000, samples_each_chain * num_chains) - else: - num_warmup, num_samples = (1000, 8000) + data = random.normal(random.PRNGKey(0), (N, dim)) true_coefs = jnp.arange(1.0, dim + 1.0) logits = jnp.sum(true_coefs * data, axis=-1) @@ -116,26 +107,35 @@ def model(labels): coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim))) logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1)) return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels) - + if kernel_cls in [AIES, ESS]: + num_chains = 10 + samples_each_chain = 8000 + num_warmup, num_samples = (10_000, samples_each_chain * num_chains) kernel = kernel_cls(model) - + mcmc = MCMC( kernel, num_warmup=num_warmup, num_samples=samples_each_chain, progress_bar=False, num_chains=num_chains, chain_method='vectorized' ) else: if kernel_cls is SA: + num_warmup, num_samples = (100000, 100000) kernel = SA(model=model, adapt_state_size=9) + elif kernel_cls is BarkerMH: + num_warmup, num_samples = (2000, 12000) kernel = BarkerMH(model=model) else: + num_warmup, num_samples = (1000, 8000) kernel = kernel_cls( model=model, trajectory_length=8, find_heuristic_step_size=True ) + mcmc = MCMC( kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False ) + mcmc.run(random.PRNGKey(2), labels) mcmc.print_summary() samples = mcmc.get_samples() @@ -208,7 +208,7 @@ def model(data): assert_allclose(jnp.mean(samples["loc"], 0), true_coef, atol=0.007) -@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES]) +@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES, ESS]) def test_beta_bernoulli_x64(kernel_cls): if kernel_cls is SA and "CI" in os.environ and "JAX_ENABLE_X64" in os.environ: pytest.skip("The test is flaky on CI x64.") From ff38013dd272d92f50e72ecddb34d0382b74a08b Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Sun, 21 Jan 2024 23:00:32 -0500 Subject: [PATCH 12/23] add tests for other files --- numpyro/infer/ensemble.py | 3 +- test/infer/test_ensemble_mcmc.py | 57 ++++++++++++++++++++++++++++++++ test/infer/test_ensemble_util.py | 30 +++++++++++++++++ 3 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 test/infer/test_ensemble_mcmc.py create mode 100644 test/infer/test_ensemble_util.py diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index bbfa2bebe..f4c19336c 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -142,7 +142,8 @@ def init( ): assert not is_prng_key( rng_key - ), "EnsembleSampler only supports chain_method='vectorized' or chain_method='parallel'." + ), ("EnsembleSampler only supports chain_method='vectorized' or chain_method='parallel'." + " (num_chains must be greater than 1)") assert rng_key.shape[0] % 2 == 0, "Number of chains must be even." diff --git a/test/infer/test_ensemble_mcmc.py b/test/infer/test_ensemble_mcmc.py new file mode 100644 index 000000000..76ce2b0d7 --- /dev/null +++ b/test/infer/test_ensemble_mcmc.py @@ -0,0 +1,57 @@ +import pytest + +import jax +import jax.numpy as jnp +import jax.random as random + +import numpyro +import numpyro.distributions as dist +from numpyro.infer import MCMC, AIES, ESS + +# --- +# reused for all smoke-tests +N, dim = 3000, 3 + +data = random.normal(random.PRNGKey(0), (N, dim)) +true_coefs = jnp.arange(1.0, dim + 1.0) +logits = jnp.sum(true_coefs * data, axis=-1) +labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1)) + +def model(labels): + coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim))) + logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1)) + return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels) +# --- + +@pytest.mark.parametrize("kernel_cls, n_chain, method", + [(AIES, 10, "sequential"), + (AIES, 1, "vectorized"), + (ESS, 10, "sequential"), + (ESS, 1, "vectorized")]) +def test_chain_smoke(kernel_cls, n_chain, method): + kernel = kernel_cls(model) + + mcmc = MCMC(kernel, num_warmup=10, num_samples=10, + progress_bar=False, num_chains=n_chain, chain_method=method) + + with pytest.raises(AssertionError, match="chain_method"): + mcmc.run(random.PRNGKey(2), labels) + +@pytest.mark.parametrize("kernel_cls", [AIES, ESS]) +def test_out_shape_smoke(kernel_cls): + n_chains = 10 + kernel = kernel_cls(model) + + mcmc = MCMC(kernel, num_warmup=10, num_samples=10, + progress_bar=False, num_chains=n_chains, chain_method='vectorized') + mcmc.run(random.PRNGKey(2), labels) + + assert (mcmc.get_samples(group_by_chain=True)['coefs'].shape[0] + == n_chains) + +@pytest.mark.parametrize("kernel_cls", [AIES, ESS]) +def test_invalid_moves(kernel_cls): + with pytest.raises(AssertionError, match="Each move"): + kernel_cls(model, moves={'invalid': 1.}) + + \ No newline at end of file diff --git a/test/infer/test_ensemble_util.py b/test/infer/test_ensemble_util.py new file mode 100644 index 000000000..3df794c71 --- /dev/null +++ b/test/infer/test_ensemble_util.py @@ -0,0 +1,30 @@ +import jax +import jax.numpy as jnp +from numpyro.infer.ensemble_util import _get_nondiagonal_pairs, batch_ravel_pytree + +def test_nondiagonal_pairs(): + truth = jnp.array( + [[1, 0], + [2, 0], + [2, 1], + [0, 1], + [0, 2], + [1, 2]], dtype=jnp.int32) + + assert jnp.all(_get_nondiagonal_pairs(3) == truth) + +def test_batch_ravel_pytree(): + arr1 = jnp.arange(10).reshape((5, 2)) + arr2 = jnp.arange(15).reshape((5, 3)) + arr3 = jnp.arange(20).reshape((5, 4)) + + tree = {'arr1': arr1, 'arr2': arr2, 'arr3': arr3} + + flattened, unravel_fn = batch_ravel_pytree(tree) + unflattened = unravel_fn(flattened) + + assert flattened.shape == (5, 2 + 3 + 4) + + for unflattened_leaf, original_leaf in zip(jax.tree_util.tree_leaves(unflattened), + jax.tree_util.tree_leaves(tree)): + assert jnp.all(unflattened_leaf == original_leaf) From 78945bf2ce114f6c06cdd1e126f83de69dc6c11d Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Sun, 21 Jan 2024 23:04:04 -0500 Subject: [PATCH 13/23] linting --- numpyro/infer/ensemble.py | 134 +++++++++++++++---------------- numpyro/infer/ensemble_util.py | 45 +++++------ test/infer/test_ensemble_mcmc.py | 31 +++---- test/infer/test_ensemble_util.py | 13 ++- test/infer/test_mcmc.py | 22 ++--- 5 files changed, 125 insertions(+), 120 deletions(-) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index f4c19336c..8aa3bd3c1 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -63,7 +63,7 @@ class EnsembleSampler(MCMCKernel, ABC): """ Abstract class for ensemble samplers. - + :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. If model is provided, `potential_fn` will be inferred using the model. :param potential_fn: Python callable that computes the potential energy @@ -74,7 +74,7 @@ class EnsembleSampler(MCMCKernel, ABC): :param callable init_strategy: a per-site initialization function. See :ref:`init_strategy` section for available functions. """ - + def __init__(self, model=None, potential_fn=None, *, randomize_split, init_strategy): if not (model is None) ^ (potential_fn is None): raise ValueError("Only one of `model` or `potential_fn` must be specified.") @@ -128,12 +128,12 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): if init_params is None: init_params = new_init_params - flat_params, unravel_fn = batch_ravel_pytree(init_params) + flat_params, unravel_fn = batch_ravel_pytree(init_params) self._batch_log_density = lambda z: -vmap(self._potential_fn)(unravel_fn(z)) if self._num_chains < 2 * flat_params.shape[1]: warnings.warn("Setting n_chains to at least 2*n_params is strongly recommended.\n" - f"n_chains: {self._num_chains}, n_params: {flat_params.shape[1]}") + f"n_chains: {self._num_chains}, n_params: {flat_params.shape[1]}") return init_params @@ -154,10 +154,10 @@ def init( "Valid value of `init_params` must be provided with `potential_fn`." ) if init_params is not None: - assert all([param.shape[0] == self._num_chains - for param in jax.tree_leaves(init_params)]), ("The batch dimension of each " - "param must match n_chains") - + assert all([param.shape[0] == self._num_chains + for param in jax.tree_leaves(init_params)]), ("The batch dimension of each " + "param must match n_chains") + rng_key, rng_key_inner_state, rng_key_init_model = random.split(rng_key[0], 3) rng_key_init_model = random.split(rng_key_init_model, self._num_chains) @@ -188,20 +188,20 @@ def sample(self, state, model_args, model_kwargs): def body_fn(i, z_flat_inner_state): z_flat, inner_state = z_flat_inner_state - - active, inactive = jax.lax.cond(i == 0, + + active, inactive = jax.lax.cond(i == 0, lambda x: (x[:split_ind], x[split_ind:]), lambda x: (x[split_ind:], x[split_ind:]), z_flat) - + z_updates, inner_state = self.update_active_chains(active, inactive, inner_state) - - z_flat = jax.lax.cond(i == 0, + + z_flat = jax.lax.cond(i == 0, lambda x: x.at[:split_ind].set(z_updates), lambda x: x.at[split_ind:].set(z_updates), z_flat) return (z_flat, inner_state) - + z_flat, inner_state = jax.lax.fori_loop(0, 2, body_fn, (z_flat, inner_state)) @@ -212,15 +212,15 @@ class AIES(EnsembleSampler): """ Affine-Invariant Ensemble Sampling: a gradient free method. Suitable for low to moderate dimensional models. Generally, `num_chains` should be at least twice the dimensionality of the model. - + .. note:: This kernel must be used with `num_chains` > 1 and `chain_method="vectorized` or `chain_method="parallel` in :class:`MCMC`. The number of chains must be divisible by 2. - + **References:** - + 1. *emcee: The MCMC Hammer* (https://iopscience.iop.org/article/10.1086/670067), - Daniel Foreman-Mackey, David W. Hogg, Dustin Lang, and Jonathan Goodman. - + Daniel Foreman-Mackey, David W. Hogg, Dustin Lang, and Jonathan Goodman. + :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. If model is provided, `potential_fn` will be inferred using the model. :param potential_fn: Python callable that computes the potential energy @@ -231,10 +231,10 @@ class AIES(EnsembleSampler): Defaults to False. :param moves: a dictionary mapping moves to their respective probabilities of being selected. Valid keys are `AIES.DEMove()` and `AIES.StretchMove()`. Both tend to work well in practice. - If the sum of probabilites exceeds 1, the probabilities will be normalized. Defaults to `{AIES.DEMove(): 1.0}`. + If the sum of probabilites exceeds 1, the probabilities will be normalized. Defaults to `{AIES.DEMove(): 1.0}`. :param callable init_strategy: a per-site initialization function. See :ref:`init_strategy` section for available functions. - + **Example** .. code-block:: python @@ -247,14 +247,14 @@ class AIES(EnsembleSampler): def model(): x = numpyro.sample("x", dist.Normal().expand([10])) numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10)) - - kernel = AIES(model, moves={AIES.DEMove() : 0.5, + + kernel = AIES(model, moves={AIES.DEMove() : 0.5, AIES.StretchMove() : 0.5}) mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized') mcmc.run(jax.random.PRNGKey(0)) mcmc.print_summary() """ - + def __init__(self, model=None, potential_fn=None, randomize_split=False, moves=None, init_strategy=init_to_uniform): if not moves: self._moves = [AIES.DEMove()] @@ -267,7 +267,7 @@ def __init__(self, model=None, potential_fn=None, randomize_split=False, moves=N "Each move must be a callable (one of AIES.DEMove(), or AIES.StretchMove()).") assert jnp.all(self._weights >= 0), "Each specified move must have probability >= 0" - super().__init__(model, + super().__init__(model, potential_fn, randomize_split=randomize_split, init_strategy=init_strategy) @@ -316,7 +316,7 @@ def update_active_chains(self, active, inactive, inner_state): @staticmethod def DEMove(sigma=1.0e-5, g0=None): """A proposal using differential evolution. - + This `Differential evolution proposal `_ is implemented following `Nelson et al. (2013) @@ -331,11 +331,11 @@ def DEMove(sigma=1.0e-5, g0=None): def make_de_move(n_chains): PAIRS = _get_nondiagonal_pairs(n_chains // 2) - def de_move(rng_key, active, inactive): + def de_move(rng_key, active, inactive): pairs_key, gamma_key = random.split(rng_key) n_active_chains, n_params = inactive.shape - # XXX: if we pass in n_params to parent scope we don't need to + # XXX: if we pass in n_params to parent scope we don't need to # recompute this each time g = 2.38 / jnp.sqrt(2.0 * n_params) if not g0 else g0 @@ -357,7 +357,7 @@ def de_move(rng_key, active, inactive): return proposal, jnp.zeros(n_active_chains) return de_move - + return make_de_move @staticmethod @@ -386,25 +386,25 @@ def stretch_move(rng_key, active, inactive): proposal = inactive[r_idxs] - (inactive[r_idxs] - active) * zz[:, jnp.newaxis] return proposal, factors - + return stretch_move - + class ESS(EnsembleSampler): """ Ensemble Slice Sampling: a gradient free method. Suitable for low to moderate dimensional models. - Generally, `num_chains` should be at least twice the dimensionality of the model. Increasing - + Generally, `num_chains` should be at least twice the dimensionality of the model. Increasing + .. note:: This kernel must be used with `num_chains` > 1 and `chain_method="vectorized` or `chain_method="parallel` in :class:`MCMC`. The number of chains must be divisible by 2. - + **References:** - + 1. *zeus: a PYTHON implementation of ensemble slice sampling for efficient Bayesian parameter inference* (https://academic.oup.com/mnras/article/508/3/3589/6381726), Minas Karamanis, Florian Beutler, and John A. Peacock. 2. *Ensemble slice sampling* (https://link.springer.com/article/10.1007/s11222-021-10038-2), Minas Karamanis, Florian Beutler. - + :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. If model is provided, `potential_fn` will be inferred using the model. :param potential_fn: Python callable that computes the potential energy @@ -412,27 +412,27 @@ class ESS(EnsembleSampler): any python collection type, provided that `init_params` argument to :meth:`init` has the same type. :param bool randomize_split: whether or not to permute the chain order at each iteration. - Defaults to True. + Defaults to True. :param moves: a dictionary mapping moves to their respective probabilities of being selected. If the sum of probabilites exceeds 1, the probabilities will be normalized. Valid keys include: - - `ESS.DifferentialMove()` -> default proposal, works well along a wide range + - `ESS.DifferentialMove()` -> default proposal, works well along a wide range of target distributions - `ESS.GaussianMove()` -> for approximately normally distributed targets - `ESS.KDEMove()` -> for multimodal posteriors - requires large `num_chains`, and - they must be well initialized - - `ESS.RandomMove()` -> no chain interaction, useful for debugging - - Defaults to `{ESS.DifferentialMove(): 1.0}`. - + they must be well initialized + - `ESS.RandomMove()` -> no chain interaction, useful for debugging + + Defaults to `{ESS.DifferentialMove(): 1.0}`. + :param int max_steps: number of maximum stepping-out steps per sample. Defaults to 10,000. :param int max_iter: number of maximum expansions/contractions per sample. Defaults to 10,000. :param float init_mu: initial scale factor. Defaults to 1.0. :param bool tune_mu: whether or not to tune the initial scale factor. Defaults to True. :param callable init_strategy: a per-site initialization function. See :ref:`init_strategy` section for available functions. - + **Example** - + .. code-block:: python import jax import jax.numpy as jnp @@ -443,8 +443,8 @@ class ESS(EnsembleSampler): def model(): x = numpyro.sample("x", dist.Normal().expand([10])) numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10)) - - kernel = AIES(model, moves={ESS.DifferentialMove() : .8, + + kernel = AIES(model, moves={ESS.DifferentialMove() : .8, ESS.RandomMove() : .2}) mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized') mcmc.run(jax.random.PRNGKey(0)) @@ -468,22 +468,22 @@ def __init__( else: self._moves = list(moves.keys()) self._weights = jnp.array([weight for weight in moves.values()]) / len(moves) - + assert all([hasattr(move, '__call__') for move in self._moves]), ( "Each move must be a callable (one of `ESS.DifferentialMove()`, " "`ESS.GaussianMove()`, `ESS.KDEMove()`, `ESS.RandomMove()`)") - + assert jnp.all(self._weights >= 0), "Each specified move must have probability >= 0" assert init_mu > 0, "Scale factor should be strictly positive" - + self._max_steps = max_steps # max number of stepping out steps self._max_iter = max_iter # max number of expansions/contractions self._init_mu = init_mu self._tune_mu = tune_mu - super().__init__(model, + super().__init__(model, potential_fn, - randomize_split=randomize_split, + randomize_split=randomize_split, init_strategy=init_strategy) def init_inner_state(self, rng_key): @@ -523,22 +523,22 @@ def update_active_chains(self, active, inactive, inner_state): n_expansions += curr_n_expansions n_contractions += curr_n_contractions itr = i + 0.5 - + if self._tune_mu: - safe_n_expansions = jnp.max(jnp.array([1, n_expansions])) + safe_n_expansions = jnp.max(jnp.array([1, n_expansions])) # only update tuning scale if a full iteration has passed - mu, n_expansions, n_contractions = jax.lax.cond(jnp.all(itr % 1 == 0), - lambda n_exp, n_con: (2.0 * n_exp / (n_exp + n_con), + mu, n_expansions, n_contractions = jax.lax.cond(jnp.all(itr % 1 == 0), + lambda n_exp, n_con: (2.0 * n_exp / (n_exp + n_con), jnp.array(0), jnp.array(0) - ), - lambda _, __: (mu, + ), + lambda _, __: (mu, n_expansions, n_contractions ), safe_n_expansions, n_contractions) - + return proposal, ESSState(itr, n_expansions, n_contractions, mu, rng_key) @@ -586,10 +586,10 @@ def GaussianMove(): approximation of the walkers of the complementary ensemble. """ - # In high dimensional regimes with sufficiently small n_active_chains, + # In high dimensional regimes with sufficiently small n_active_chains, # it is more efficient to sample without computing the Cholesky # decomposition of the covariance matrix: - + # eps = dist.Normal(0, 1).sample(rng_key, (n_active_chains, n_active_chains)) # return 2.0 * mu * (eps @ (inactive - jnp.mean(inactive, axis=0)) / jnp.sqrt(n_active_chains)) @@ -616,7 +616,7 @@ def DifferentialMove(): """ def make_differential_move(n_chains): PAIRS = _get_nondiagonal_pairs(n_chains // 2) - + def differential_move(rng_key, inactive, mu): n_active_chains, n_params = inactive.shape @@ -627,7 +627,7 @@ def differential_move(rng_key, inactive, mu): return 2.0 * mu * diffs return differential_move - + return make_differential_move @@ -648,11 +648,11 @@ def _step_out(self, rng_key, log_slice_height, active, directions): ) ) K = (self._max_steps - 1) - J - + # left stepping-out initialisation mask_J = jnp.full((n_active_chains, 1), True) - # right stepping-out initialisation - mask_K = jnp.full((n_active_chains, 1), True) + # right stepping-out initialisation + mask_K = jnp.full((n_active_chains, 1), True) init_values = (n_expansions, L, R, J, K, mask_J, mask_K, iteration) @@ -666,7 +666,7 @@ def cond_fn(args): def body_fn(args): n_expansions, L, R, J, K, mask_J, mask_K, iteration = args - log_prob_L = self.batch_log_density(directions * L + active) + log_prob_L = self.batch_log_density(directions * L + active) log_prob_R = self.batch_log_density(directions * R + active) can_expand_L = log_prob_L > log_slice_height diff --git a/numpyro/infer/ensemble_util.py b/numpyro/infer/ensemble_util.py index f0732b083..bef62121c 100644 --- a/numpyro/infer/ensemble_util.py +++ b/numpyro/infer/ensemble_util.py @@ -6,11 +6,10 @@ import numpy as np from jax import lax, vmap -import jax.numpy as jnp - from jax._src import dtypes from jax._src.tree_util import tree_flatten, tree_unflatten -from jax._src.util import safe_zip, unzip2, HashablePartial +from jax._src.util import HashablePartial, safe_zip, unzip2 +import jax.numpy as jnp zip = safe_zip @@ -18,23 +17,23 @@ def _get_nondiagonal_pairs(n): """ From https://github.com/dfm/emcee/blob/main/src/emcee/moves/de.py: - + Get the indices of a square matrix with size n, excluding the diagonal. """ - + rows, cols = np.tril_indices(n, -1) # -1 to exclude diagonal # Combine rows-cols and cols-rows pairs - pairs = np.column_stack([np.concatenate([rows, cols]), + pairs = np.column_stack([np.concatenate([rows, cols]), np.concatenate([cols, rows])]) return jnp.asarray(pairs) def batch_ravel_pytree(pytree): - """Ravel (flatten) a pytree of arrays with leading batch dimension down to a (batch_size, 1D) array. + """Ravel (flatten) a pytree of arrays with leading batch dimension down to a (batch_size, 1D) array. Args: - pytree: a pytree of arrays and scalars to ravel. + pytree: a pytree of arrays and scalars to ravel. Returns: A pair where the first element is a (batch_size, 1D) array representing the flattened and concatenated leaf values, with dtype determined by promoting the dtypes of @@ -42,11 +41,11 @@ def batch_ravel_pytree(pytree): vector of the same length back to a pytree of of the same structure as the input ``pytree``. If the input pytree is empty (i.e. has no leaves) then as a convention a 1D empty array of dtype float32 is returned in the first - component of the output. + component of the output. For details on dtype promotion, see - https://jax.readthedocs.io/en/latest/type_promotion.html. + https://jax.readthedocs.io/en/latest/type_promotion.html. """ - + leaves, treedef = tree_flatten(pytree) flat, unravel_list = _ravel_list(leaves) return flat, HashablePartial(unravel_pytree, treedef, unravel_list) @@ -61,29 +60,29 @@ def _ravel_list(lst): if not lst: return jnp.array([], jnp.float32), lambda _: [] from_dtypes = tuple(dtypes.dtype(l) for l in lst) to_dtype = dtypes.result_type(*from_dtypes) - - # here 1 is n_leading_batch_dimensions + + # here 1 is n_leading_batch_dimensions sizes, shapes = unzip2((np.prod(jnp.shape(x)[1:]), jnp.shape(x)[1:]) for x in lst) indices = tuple(np.cumsum(sizes)) - + if all(dt == to_dtype for dt in from_dtypes): # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`. # See https://github.com/google/jax/issues/7809. del from_dtypes, to_dtype - + # axis = n_leading_batch_dimensions # vmap n_leading_batch_dimensions times raveled = jnp.concatenate([vmapped_ravel(e) for e in lst], axis=1) return raveled, HashablePartial(_unravel_list_single_dtype, indices, shapes) - + # When there is more than one distinct input dtype, we perform type # conversions and produce a dtype-specific unravel function. - convert_vmapped_ravel = lambda e: vmapped_ravel(lax.convert_element_type(e, to_dtype)) + convert_vmapped_ravel = lambda e: vmapped_ravel(lax.convert_element_type(e, to_dtype)) raveled = jnp.concatenate([convert_vmapped_ravel(e) for e in lst]) unrav = HashablePartial(_unravel_list, indices, shapes, from_dtypes, to_dtype) return raveled, unrav - - + + def _unravel_list_single_dtype(indices, shapes, arr): # axis is n_leading_batch_dimensions chunks = jnp.split(arr, indices[:-1], axis=1) @@ -97,7 +96,7 @@ def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr): if arr_dtype != to_dtype: raise TypeError(f"unravel function given array of dtype {arr_dtype}, " f"but expected dtype {to_dtype}") - + # axis is n_leading_batch_dimensions chunks = jnp.split(arr, indices[:-1], axis=1) with warnings.catch_warnings(): @@ -105,6 +104,6 @@ def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr): # the number of -1s is the number of leading batch dimensions return [lax.convert_element_type(chunk.reshape((-1, *shape)), dtype) for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)] - - - + + + diff --git a/test/infer/test_ensemble_mcmc.py b/test/infer/test_ensemble_mcmc.py index 76ce2b0d7..314200357 100644 --- a/test/infer/test_ensemble_mcmc.py +++ b/test/infer/test_ensemble_mcmc.py @@ -1,17 +1,19 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + import pytest -import jax import jax.numpy as jnp import jax.random as random import numpyro import numpyro.distributions as dist -from numpyro.infer import MCMC, AIES, ESS +from numpyro.infer import AIES, ESS, MCMC # --- # reused for all smoke-tests N, dim = 3000, 3 - + data = random.normal(random.PRNGKey(0), (N, dim)) true_coefs = jnp.arange(1.0, dim + 1.0) logits = jnp.sum(true_coefs * data, axis=-1) @@ -28,30 +30,29 @@ def model(labels): (AIES, 1, "vectorized"), (ESS, 10, "sequential"), (ESS, 1, "vectorized")]) -def test_chain_smoke(kernel_cls, n_chain, method): +def test_chain_smoke(kernel_cls, n_chain, method): kernel = kernel_cls(model) - - mcmc = MCMC(kernel, num_warmup=10, num_samples=10, + + mcmc = MCMC(kernel, num_warmup=10, num_samples=10, progress_bar=False, num_chains=n_chain, chain_method=method) - + with pytest.raises(AssertionError, match="chain_method"): mcmc.run(random.PRNGKey(2), labels) @pytest.mark.parametrize("kernel_cls", [AIES, ESS]) -def test_out_shape_smoke(kernel_cls): +def test_out_shape_smoke(kernel_cls): n_chains = 10 kernel = kernel_cls(model) - - mcmc = MCMC(kernel, num_warmup=10, num_samples=10, + + mcmc = MCMC(kernel, num_warmup=10, num_samples=10, progress_bar=False, num_chains=n_chains, chain_method='vectorized') mcmc.run(random.PRNGKey(2), labels) - - assert (mcmc.get_samples(group_by_chain=True)['coefs'].shape[0] - == n_chains) + + assert (mcmc.get_samples(group_by_chain=True)['coefs'].shape[0] + == n_chains) @pytest.mark.parametrize("kernel_cls", [AIES, ESS]) def test_invalid_moves(kernel_cls): with pytest.raises(AssertionError, match="Each move"): kernel_cls(model, moves={'invalid': 1.}) - - \ No newline at end of file + diff --git a/test/infer/test_ensemble_util.py b/test/infer/test_ensemble_util.py index 3df794c71..3b4cd1b56 100644 --- a/test/infer/test_ensemble_util.py +++ b/test/infer/test_ensemble_util.py @@ -1,23 +1,28 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + import jax import jax.numpy as jnp + from numpyro.infer.ensemble_util import _get_nondiagonal_pairs, batch_ravel_pytree + def test_nondiagonal_pairs(): truth = jnp.array( - [[1, 0], + [[1, 0], [2, 0], [2, 1], [0, 1], [0, 2], [1, 2]], dtype=jnp.int32) - + assert jnp.all(_get_nondiagonal_pairs(3) == truth) - + def test_batch_ravel_pytree(): arr1 = jnp.arange(10).reshape((5, 2)) arr2 = jnp.arange(15).reshape((5, 3)) arr3 = jnp.arange(20).reshape((5, 4)) - + tree = {'arr1': arr1, 'arr2': arr2, 'arr3': arr3} flattened, unravel_fn = batch_ravel_pytree(tree) diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index 7300bcf85..89392d02f 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -17,7 +17,7 @@ import numpyro import numpyro.distributions as dist from numpyro.distributions.transforms import AffineTransform -from numpyro.infer import HMC, MCMC, NUTS, SA, BarkerMH, AIES, ESS +from numpyro.infer import AIES, ESS, HMC, MCMC, NUTS, SA, BarkerMH from numpyro.infer.hmc import hmc from numpyro.infer.reparam import TransformReparam from numpyro.infer.sa import _get_proposal_loc_and_scale, _numpy_delete @@ -54,7 +54,7 @@ def potential_fn(z): mcmc = MCMC( kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False ) - + mcmc.run(random.PRNGKey(0), init_params=init_params) mcmc.print_summary() hmc_states = mcmc.get_samples() @@ -97,7 +97,7 @@ def potential_fn(z): @pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES, ESS]) def test_logistic_regression_x64(kernel_cls): N, dim = 3000, 3 - + data = random.normal(random.PRNGKey(0), (N, dim)) true_coefs = jnp.arange(1.0, dim + 1.0) logits = jnp.sum(true_coefs * data, axis=-1) @@ -107,22 +107,22 @@ def model(labels): coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim))) logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1)) return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels) - + if kernel_cls in [AIES, ESS]: num_chains = 10 samples_each_chain = 8000 num_warmup, num_samples = (10_000, samples_each_chain * num_chains) kernel = kernel_cls(model) - + mcmc = MCMC( - kernel, num_warmup=num_warmup, num_samples=samples_each_chain, + kernel, num_warmup=num_warmup, num_samples=samples_each_chain, progress_bar=False, num_chains=num_chains, chain_method='vectorized' ) else: if kernel_cls is SA: num_warmup, num_samples = (100000, 100000) kernel = SA(model=model, adapt_state_size=9) - + elif kernel_cls is BarkerMH: num_warmup, num_samples = (2000, 12000) kernel = BarkerMH(model=model) @@ -131,11 +131,11 @@ def model(labels): kernel = kernel_cls( model=model, trajectory_length=8, find_heuristic_step_size=True ) - + mcmc = MCMC( kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False ) - + mcmc.run(random.PRNGKey(2), labels) mcmc.print_summary() samples = mcmc.get_samples() @@ -223,13 +223,13 @@ def model(data): true_probs = jnp.array([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000,)) - + if kernel_cls in [AIES, ESS]: num_chains = 10 kernel = kernel_cls(model=model) mcmc = MCMC( - kernel, num_warmup=num_warmup, num_samples=num_samples, + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False, num_chains=num_chains, chain_method='vectorized' ) else: From b4a969582b5cdae89779e81e9c4225add9bfbe93 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Sun, 21 Jan 2024 23:11:06 -0500 Subject: [PATCH 14/23] refactor ensemble_util --- numpyro/infer/ensemble_util.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/numpyro/infer/ensemble_util.py b/numpyro/infer/ensemble_util.py index bef62121c..8f5e3c463 100644 --- a/numpyro/infer/ensemble_util.py +++ b/numpyro/infer/ensemble_util.py @@ -56,9 +56,16 @@ def unravel_pytree(treedef, unravel_list, flat): def vmapped_ravel(a): return vmap(jnp.ravel)(a) +# When there is more than one distinct input dtype, we perform type +# conversions and produce a dtype-specific unravel function. +def convert_vmapped_ravel(e, to_dtype): + return vmapped_ravel(lax.convert_element_type(e, to_dtype)) + def _ravel_list(lst): - if not lst: return jnp.array([], jnp.float32), lambda _: [] - from_dtypes = tuple(dtypes.dtype(l) for l in lst) + if not lst: + return jnp.array([], jnp.float32) + lambda _: [] + from_dtypes = tuple(dtypes.dtype(item) for item in lst) to_dtype = dtypes.result_type(*from_dtypes) # here 1 is n_leading_batch_dimensions @@ -75,9 +82,6 @@ def _ravel_list(lst): raveled = jnp.concatenate([vmapped_ravel(e) for e in lst], axis=1) return raveled, HashablePartial(_unravel_list_single_dtype, indices, shapes) - # When there is more than one distinct input dtype, we perform type - # conversions and produce a dtype-specific unravel function. - convert_vmapped_ravel = lambda e: vmapped_ravel(lax.convert_element_type(e, to_dtype)) raveled = jnp.concatenate([convert_vmapped_ravel(e) for e in lst]) unrav = HashablePartial(_unravel_list, indices, shapes, from_dtypes, to_dtype) return raveled, unrav From d865eaccc048e8ba01bbce1edeba134d202deff0 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Mon, 22 Jan 2024 08:02:40 -0500 Subject: [PATCH 15/23] make test result less close to margin in CI, swap deprecated function --- numpyro/infer/ensemble.py | 4 ++-- test/infer/test_mcmc.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index 8aa3bd3c1..7492bfef0 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -155,8 +155,8 @@ def init( ) if init_params is not None: assert all([param.shape[0] == self._num_chains - for param in jax.tree_leaves(init_params)]), ("The batch dimension of each " - "param must match n_chains") + for param in jax.tree_util.tree_leaves(init_params)]), ( + "The batch dimension of each param must match n_chains") rng_key, rng_key_inner_state, rng_key_init_model = random.split(rng_key[0], 3) rng_key_init_model = random.split(rng_key_init_model, self._num_chains) diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index 89392d02f..1e3771ed0 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -109,7 +109,10 @@ def model(labels): return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels) if kernel_cls in [AIES, ESS]: - num_chains = 10 + if kernel_cls is AIES: + num_chains = 16 + else: + num_chains = 10 samples_each_chain = 8000 num_warmup, num_samples = (10_000, samples_each_chain * num_chains) kernel = kernel_cls(model) From d2b5a09bc5ad615a884fbd9f4c7652d52e4de0cb Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Tue, 23 Jan 2024 09:21:15 -0500 Subject: [PATCH 16/23] rename get_nondiagonal_indices, fix batch_ravel_pytree --- numpyro/infer/ensemble.py | 6 +-- numpyro/infer/ensemble_util.py | 88 ++++---------------------------- test/infer/test_ensemble_util.py | 6 +-- 3 files changed, 17 insertions(+), 83 deletions(-) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index 7492bfef0..14bbbe3a6 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -11,7 +11,7 @@ from jax.scipy.stats import gaussian_kde import numpyro.distributions as dist -from numpyro.infer.ensemble_util import _get_nondiagonal_pairs, batch_ravel_pytree +from numpyro.infer.ensemble_util import batch_ravel_pytree, get_nondiagonal_indices from numpyro.infer.initialization import init_to_uniform from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import initialize_model @@ -329,7 +329,7 @@ def DEMove(sigma=1.0e-5, g0=None): as recommended by the two references. """ def make_de_move(n_chains): - PAIRS = _get_nondiagonal_pairs(n_chains // 2) + PAIRS = get_nondiagonal_indices(n_chains // 2) def de_move(rng_key, active, inactive): pairs_key, gamma_key = random.split(rng_key) @@ -615,7 +615,7 @@ def DifferentialMove(): of target distributions. """ def make_differential_move(n_chains): - PAIRS = _get_nondiagonal_pairs(n_chains // 2) + PAIRS = get_nondiagonal_indices(n_chains // 2) def differential_move(rng_key, inactive, mu): n_active_chains, n_params = inactive.shape diff --git a/numpyro/infer/ensemble_util.py b/numpyro/infer/ensemble_util.py index 8f5e3c463..028d694a5 100644 --- a/numpyro/infer/ensemble_util.py +++ b/numpyro/infer/ensemble_util.py @@ -1,26 +1,20 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -import warnings - import numpy as np -from jax import lax, vmap -from jax._src import dtypes -from jax._src.tree_util import tree_flatten, tree_unflatten -from jax._src.util import HashablePartial, safe_zip, unzip2 +import jax +from jax.flatten_util import ravel_pytree import jax.numpy as jnp - -zip = safe_zip +from jax.tree_util import tree_map -def _get_nondiagonal_pairs(n): +def get_nondiagonal_indices(n): """ From https://github.com/dfm/emcee/blob/main/src/emcee/moves/de.py: Get the indices of a square matrix with size n, excluding the diagonal. """ - rows, cols = np.tril_indices(n, -1) # -1 to exclude diagonal # Combine rows-cols and cols-rows pairs @@ -31,83 +25,23 @@ def _get_nondiagonal_pairs(n): def batch_ravel_pytree(pytree): - """Ravel (flatten) a pytree of arrays with leading batch dimension down to a (batch_size, 1D) array. + """ + Ravel (flatten) a pytree of arrays with leading batch dimension down to a (batch_size, 1D) array. + Args: pytree: a pytree of arrays and scalars to ravel. Returns: A pair where the first element is a (batch_size, 1D) array representing the flattened and concatenated leaf values, with dtype determined by promoting the dtypes of leaf values, and the second element is a callable for unflattening a (batch_size, 1D) - vector of the same length back to a pytree of of the same structure as the + array of the same length back to a pytree of the same structure as the input ``pytree``. If the input pytree is empty (i.e. has no leaves) then as a convention a 1D empty array of dtype float32 is returned in the first component of the output. - For details on dtype promotion, see - https://jax.readthedocs.io/en/latest/type_promotion.html. """ + flat = jax.vmap(lambda x: ravel_pytree(x)[0])(pytree) + unravel_fn = jax.vmap(ravel_pytree(tree_map(lambda z: z[0], pytree))[1]) - leaves, treedef = tree_flatten(pytree) - flat, unravel_list = _ravel_list(leaves) - return flat, HashablePartial(unravel_pytree, treedef, unravel_list) - -def unravel_pytree(treedef, unravel_list, flat): - return tree_unflatten(treedef, unravel_list(flat)) - -def vmapped_ravel(a): - return vmap(jnp.ravel)(a) - -# When there is more than one distinct input dtype, we perform type -# conversions and produce a dtype-specific unravel function. -def convert_vmapped_ravel(e, to_dtype): - return vmapped_ravel(lax.convert_element_type(e, to_dtype)) - -def _ravel_list(lst): - if not lst: - return jnp.array([], jnp.float32) - lambda _: [] - from_dtypes = tuple(dtypes.dtype(item) for item in lst) - to_dtype = dtypes.result_type(*from_dtypes) - - # here 1 is n_leading_batch_dimensions - sizes, shapes = unzip2((np.prod(jnp.shape(x)[1:]), jnp.shape(x)[1:]) for x in lst) - indices = tuple(np.cumsum(sizes)) - - if all(dt == to_dtype for dt in from_dtypes): - # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`. - # See https://github.com/google/jax/issues/7809. - del from_dtypes, to_dtype - - # axis = n_leading_batch_dimensions - # vmap n_leading_batch_dimensions times - raveled = jnp.concatenate([vmapped_ravel(e) for e in lst], axis=1) - return raveled, HashablePartial(_unravel_list_single_dtype, indices, shapes) - - raveled = jnp.concatenate([convert_vmapped_ravel(e) for e in lst]) - unrav = HashablePartial(_unravel_list, indices, shapes, from_dtypes, to_dtype) - return raveled, unrav - - -def _unravel_list_single_dtype(indices, shapes, arr): - # axis is n_leading_batch_dimensions - chunks = jnp.split(arr, indices[:-1], axis=1) - - # the number of -1s is the number of leading batch dimensions - return [chunk.reshape((-1, *shape)) for chunk, shape in zip(chunks, shapes)] - - -def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr): - arr_dtype = dtypes.dtype(arr) - if arr_dtype != to_dtype: - raise TypeError(f"unravel function given array of dtype {arr_dtype}, " - f"but expected dtype {to_dtype}") - - # axis is n_leading_batch_dimensions - chunks = jnp.split(arr, indices[:-1], axis=1) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") # ignore complex-to-real cast warning - # the number of -1s is the number of leading batch dimensions - return [lax.convert_element_type(chunk.reshape((-1, *shape)), dtype) - for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)] - + return flat, unravel_fn diff --git a/test/infer/test_ensemble_util.py b/test/infer/test_ensemble_util.py index 3b4cd1b56..5a066c69d 100644 --- a/test/infer/test_ensemble_util.py +++ b/test/infer/test_ensemble_util.py @@ -4,10 +4,10 @@ import jax import jax.numpy as jnp -from numpyro.infer.ensemble_util import _get_nondiagonal_pairs, batch_ravel_pytree +from numpyro.infer.ensemble_util import batch_ravel_pytree, get_nondiagonal_indices -def test_nondiagonal_pairs(): +def test_nondiagonal_indices(): truth = jnp.array( [[1, 0], [2, 0], @@ -16,7 +16,7 @@ def test_nondiagonal_pairs(): [0, 2], [1, 2]], dtype=jnp.int32) - assert jnp.all(_get_nondiagonal_pairs(3) == truth) + assert jnp.all(get_nondiagonal_indices(3) == truth) def test_batch_ravel_pytree(): arr1 = jnp.arange(10).reshape((5, 2)) From b20829d4e6078209586441819331815eb2d662c5 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Tue, 23 Jan 2024 12:05:18 -0500 Subject: [PATCH 17/23] print ensemble kernel diagnostics, smoke test parallel arg --- numpyro/infer/ensemble.py | 9 ++++++--- numpyro/infer/mcmc.py | 11 ++++++++++- test/infer/test_ensemble_mcmc.py | 5 ++++- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index 14bbbe3a6..e99adde99 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -101,6 +101,10 @@ def model(self): def sample_field(self): return "z" + @property + def is_ensemble_kernel(self): + return True + @abstractmethod def init_inner_state(self, rng_key): """return inner_state""" @@ -142,8 +146,8 @@ def init( ): assert not is_prng_key( rng_key - ), ("EnsembleSampler only supports chain_method='vectorized' or chain_method='parallel'." - " (num_chains must be greater than 1)") + ), ("EnsembleSampler only supports chain_method='vectorized' with num_chains > 1.\n" + "If you want to run chains in parallel, please raise a github issue.") assert rng_key.shape[0] % 2 == 0, "Number of chains must be even." @@ -272,7 +276,6 @@ def __init__(self, model=None, potential_fn=None, randomize_split=False, moves=N randomize_split=randomize_split, init_strategy=init_strategy) - # XXX: this doesn't show because state_method='vectorized' shuts off diagnostics_str def get_diagnostics_str(self, state): return "acc. prob={:.2f}".format(state.inner_state.mean_accept_prob) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index b4ac1ec4a..faae7a16e 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -139,6 +139,15 @@ def default_fields(self): """ return (self.sample_field,) + @property + def is_ensemble_kernel(self): + """ + Denotes whether the kernel is an ensemble kernel. If True, `get_diagnostics_str()` + will be displayed during the MCMC run (when :meth:`MCMC.run() + ` is called) if `chain_method`='vectorized' + """ + return False + def get_diagnostics_str(self, state): """ Given the current `state`, returns the diagnostics string to @@ -424,7 +433,7 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields): sample_fn, postprocess_fn = self._get_cached_fns() diagnostics = ( # noqa: E731 lambda x: self.sampler.get_diagnostics_str(x[0]) - if is_prng_key(rng_key) + if is_prng_key(rng_key) or self.sampler.is_ensemble_kernel else "" ) init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,) diff --git a/test/infer/test_ensemble_mcmc.py b/test/infer/test_ensemble_mcmc.py index 314200357..7cc289d89 100644 --- a/test/infer/test_ensemble_mcmc.py +++ b/test/infer/test_ensemble_mcmc.py @@ -10,6 +10,7 @@ import numpyro.distributions as dist from numpyro.infer import AIES, ESS, MCMC +numpyro.set_host_device_count(2) # --- # reused for all smoke-tests N, dim = 3000, 3 @@ -28,8 +29,10 @@ def model(labels): @pytest.mark.parametrize("kernel_cls, n_chain, method", [(AIES, 10, "sequential"), (AIES, 1, "vectorized"), + (AIES, 2, "parallel"), (ESS, 10, "sequential"), - (ESS, 1, "vectorized")]) + (ESS, 1, "vectorized"), + (ESS, 2, "parallel")]) def test_chain_smoke(kernel_cls, n_chain, method): kernel = kernel_cls(model) From 84bf11c7b0e9be8dac2f22e083402e282473eca0 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Tue, 23 Jan 2024 12:15:51 -0500 Subject: [PATCH 18/23] fix docstring build --- numpyro/infer/mcmc.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index faae7a16e..8ffd38871 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -142,9 +142,10 @@ def default_fields(self): @property def is_ensemble_kernel(self): """ - Denotes whether the kernel is an ensemble kernel. If True, `get_diagnostics_str()` - will be displayed during the MCMC run (when :meth:`MCMC.run() - ` is called) if `chain_method`='vectorized' + Denotes whether the kernel is an ensemble kernel. If True, + diagnostics_str will be displayed during the MCMC run + (when :meth:`MCMC.run() ` is called) + if `chain_method` = "vectorized". """ return False From 33133e43c00c531cb34564c1a9a875a5a15af700 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Tue, 23 Jan 2024 18:13:29 -0500 Subject: [PATCH 19/23] documentation --- docs/source/mcmc.rst | 34 ++++++++++++- numpyro/infer/ensemble.py | 104 +++++++++++++++++++------------------- 2 files changed, 85 insertions(+), 53 deletions(-) diff --git a/docs/source/mcmc.rst b/docs/source/mcmc.rst index edfda92d2..39cf7be3a 100644 --- a/docs/source/mcmc.rst +++ b/docs/source/mcmc.rst @@ -9,7 +9,9 @@ We provide a high-level overview of the MCMC algorithms in NumPyro: * `BarkerMH `_ is a gradient-based MCMC method that may be competitive with HMC and NUTS for some models. It is applicable to models with continuous latent variables. * `HMCGibbs `_ combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user. * `DiscreteHMCGibbs `_ combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically. -* `SA `_ is the only MCMC method in NumPyro that does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast. +* `SA `_ does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast. +* `AIES `_ is a gradient-free ensemble method that informs Metropolis-Hastings proposals by sharing information between chains. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities, and can be robust to likelihood-free models. AIES generally requires the number of chains to be twice as large as the number of latent parameters, (and ideally larger). +* `ESS `_ is a gradient-free ensemble method that shares information between chains to find good slice sampling directions. It tends to be more sample efficient than AIES. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate and may be a good choice for models with non-differentiable log densities. ESS generally requires the number of chains to be twice as large as the number of latent parameters, (and ideally larger). Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables if possible (see `restrictions `_). Enumerated sites need to be marked with `infer={'enumerate': 'parallel'}` like in the `annotation example `_. @@ -101,6 +103,30 @@ SA :show-inheritance: :member-order: bysource +EnsembleSampler +^^^^^^^^^^^^^^^ +.. autoclass:: numpyro.infer.ensemble.EnsembleSampler + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + +AIES +^^^^ +.. autoclass:: numpyro.infer.ensemble.AIES + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + +ESS +^^^ +.. autoclass:: numpyro.infer.ensemble.ESS + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + .. autofunction:: numpyro.infer.hmc.hmc .. autofunction:: numpyro.infer.hmc.hmc.init_kernel @@ -117,6 +143,12 @@ SA .. autodata:: numpyro.infer.sa.SAState +.. autodata:: numpyro.infer.ensemble.EnsembleSamplerState + +.. autodata:: numpyro.infer.ensemble.AIESState + +.. autodata:: numpyro.infer.ensemble.ESSState + TensorFlow Kernels ------------------ diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index e99adde99..fe6af936e 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -62,7 +62,8 @@ class EnsembleSampler(MCMCKernel, ABC): """ - Abstract class for ensemble samplers. + Abstract class for ensemble samplers. Each MCMC sample is divided into two sub-iterations + in which half of the ensemble is updated. :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. If model is provided, `potential_fn` will be inferred using the model. @@ -214,11 +215,12 @@ def body_fn(i, z_flat_inner_state): class AIES(EnsembleSampler): """ - Affine-Invariant Ensemble Sampling: a gradient free method. Suitable for low to moderate dimensional models. + Affine-Invariant Ensemble Sampling: a gradient free method that informs Metropolis-Hastings + proposals by sharing information between chains. Suitable for low to moderate dimensional models. Generally, `num_chains` should be at least twice the dimensionality of the model. .. note:: This kernel must be used with `num_chains` > 1 and `chain_method="vectorized` - or `chain_method="parallel` in :class:`MCMC`. The number of chains must be divisible by 2. + in :class:`MCMC`. The number of chains must be divisible by 2. **References:** @@ -241,22 +243,21 @@ class AIES(EnsembleSampler): **Example** - .. code-block:: python - import jax - import jax.numpy as jnp - import numpyro - import numpyro.distributions as dist - from numpyro.infer import MCMC, AIES - - def model(): - x = numpyro.sample("x", dist.Normal().expand([10])) - numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10)) - - kernel = AIES(model, moves={AIES.DEMove() : 0.5, - AIES.StretchMove() : 0.5}) - mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized') - mcmc.run(jax.random.PRNGKey(0)) - mcmc.print_summary() + .. doctest:: + >>> import jax + >>> import jax.numpy as jnp + >>> import numpyro + >>> import numpyro.distributions as dist + >>> from numpyro.infer import MCMC, AIES + + >>> def model(): + ... x = numpyro.sample("x", dist.Normal().expand([10])) + ... numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10)) + >>> + >>> kernel = AIES(model, moves={AIES.DEMove() : 0.5, + ... AIES.StretchMove() : 0.5}) + >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized') + >>> mcmc.run(jax.random.PRNGKey(0)) """ def __init__(self, model=None, potential_fn=None, randomize_split=False, moves=None, init_strategy=init_to_uniform): @@ -318,18 +319,20 @@ def update_active_chains(self, active, inactive, inner_state): @staticmethod def DEMove(sigma=1.0e-5, g0=None): - """A proposal using differential evolution. + """ + A proposal using differential evolution. This `Differential evolution proposal `_ is implemented following `Nelson et al. (2013) `_. - Args: - sigma (float): The standard deviation of the Gaussian used to stretch - the proposal vector. - gamma0 (Optional[float]): The mean stretch factor for the proposal - vector. By default, it is `2.38 / sqrt(2*ndim)` - as recommended by the two references. + + :param sigma: (optional) + The standard deviation of the Gaussian used to stretch the proposal vector. + Defaults to `1.0.e-5`. + :param g0 (optional): + The mean stretch factor for the proposal vector. By default, + it is `2.38 / sqrt(2*ndim)` as recommended by the two references. """ def make_de_move(n_chains): PAIRS = get_nondiagonal_indices(n_chains // 2) @@ -395,11 +398,12 @@ def stretch_move(rng_key, active, inactive): class ESS(EnsembleSampler): """ - Ensemble Slice Sampling: a gradient free method. Suitable for low to moderate dimensional models. - Generally, `num_chains` should be at least twice the dimensionality of the model. Increasing + Ensemble Slice Sampling: a gradient free method that finds better slice sampling directions + by sharing information between chains. Suitable for low to moderate dimensional models. + Generally, `num_chains` should be at least twice the dimensionality of the model. .. note:: This kernel must be used with `num_chains` > 1 and `chain_method="vectorized` - or `chain_method="parallel` in :class:`MCMC`. The number of chains must be divisible by 2. + in :class:`MCMC`. The number of chains must be divisible by 2. **References:** @@ -418,13 +422,10 @@ class ESS(EnsembleSampler): Defaults to True. :param moves: a dictionary mapping moves to their respective probabilities of being selected. If the sum of probabilites exceeds 1, the probabilities will be normalized. Valid keys include: - - `ESS.DifferentialMove()` -> default proposal, works well along a wide range - of target distributions - - `ESS.GaussianMove()` -> for approximately normally distributed targets - - `ESS.KDEMove()` -> for multimodal posteriors - requires large `num_chains`, and - they must be well initialized - - `ESS.RandomMove()` -> no chain interaction, useful for debugging - + `ESS.DifferentialMove()` -> default proposal, works well along a wide range of target distributions, + `ESS.GaussianMove()` -> for approximately normally distributed targets, + `ESS.KDEMove()` -> for multimodal posteriors - requires large `num_chains`, and they must be well initialized + `ESS.RandomMove()` -> no chain interaction, useful for debugging. Defaults to `{ESS.DifferentialMove(): 1.0}`. :param int max_steps: number of maximum stepping-out steps per sample. Defaults to 10,000. @@ -436,22 +437,21 @@ class ESS(EnsembleSampler): **Example** - .. code-block:: python - import jax - import jax.numpy as jnp - import numpyro - import numpyro.distributions as dist - from numpyro.infer import MCMC, AIES - - def model(): - x = numpyro.sample("x", dist.Normal().expand([10])) - numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10)) - - kernel = AIES(model, moves={ESS.DifferentialMove() : .8, - ESS.RandomMove() : .2}) - mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized') - mcmc.run(jax.random.PRNGKey(0)) - mcmc.print_summary() + .. doctest:: + >>> import jax + >>> import jax.numpy as jnp + >>> import numpyro + >>> import numpyro.distributions as dist + >>> from numpyro.infer import MCMC, ESS + + >>> def model(): + ... x = numpyro.sample("x", dist.Normal().expand([10])) + ... numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10)) + >>> + >>> kernel = ESS(model, moves={ESS.DifferentialMove() : 0.8, + ... ESS.RandomMove() : 0.2}) + >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized') + >>> mcmc.run(jax.random.PRNGKey(0)) """ def __init__( self, From 2c325a1ff4e51e6110a98bb94fe69865e6ddacde Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Tue, 23 Jan 2024 18:25:00 -0500 Subject: [PATCH 20/23] skip slow CI tests, unnest test if statements --- test/infer/test_mcmc.py | 67 +++++++++++++++++++++++++---------------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index 1e3771ed0..3f52c6c1e 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -43,13 +43,16 @@ def potential_fn(z): kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False, num_chains=num_chains, chain_method='vectorized' ) + elif kernel_cls in [SA, BarkerMH]: + kernel = kernel_cls(potential_fn=potential_fn, dense_mass=dense_mass) + init_params = jnp.array(0.0) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False + ) else: - if kernel_cls in [SA, BarkerMH]: - kernel = kernel_cls(potential_fn=potential_fn, dense_mass=dense_mass) - else: - kernel = kernel_cls( - potential_fn=potential_fn, trajectory_length=8, dense_mass=dense_mass - ) + kernel = kernel_cls( + potential_fn=potential_fn, trajectory_length=8, dense_mass=dense_mass + ) init_params = jnp.array(0.0) mcmc = MCMC( kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False @@ -96,6 +99,9 @@ def potential_fn(z): @pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES, ESS]) def test_logistic_regression_x64(kernel_cls): + if kernel_cls in [AIES, ESS] and "CI" in os.environ: + pytest.skip("reduce time for CI.") + N, dim = 3000, 3 data = random.normal(random.PRNGKey(0), (N, dim)) @@ -121,20 +127,23 @@ def model(labels): kernel, num_warmup=num_warmup, num_samples=samples_each_chain, progress_bar=False, num_chains=num_chains, chain_method='vectorized' ) + elif kernel_cls is SA: + num_warmup, num_samples = (100000, 100000) + kernel = SA(model=model, adapt_state_size=9) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False + ) + elif kernel_cls is BarkerMH: + num_warmup, num_samples = (2000, 12000) + kernel = BarkerMH(model=model) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False + ) else: - if kernel_cls is SA: - num_warmup, num_samples = (100000, 100000) - kernel = SA(model=model, adapt_state_size=9) - - elif kernel_cls is BarkerMH: - num_warmup, num_samples = (2000, 12000) - kernel = BarkerMH(model=model) - else: - num_warmup, num_samples = (1000, 8000) - kernel = kernel_cls( - model=model, trajectory_length=8, find_heuristic_step_size=True - ) - + num_warmup, num_samples = (1000, 8000) + kernel = kernel_cls( + model=model, trajectory_length=8, find_heuristic_step_size=True + ) mcmc = MCMC( kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False ) @@ -215,6 +224,8 @@ def model(data): def test_beta_bernoulli_x64(kernel_cls): if kernel_cls is SA and "CI" in os.environ and "JAX_ENABLE_X64" in os.environ: pytest.skip("The test is flaky on CI x64.") + if kernel_cls is ESS and "CI" in os.environ: + pytest.skip("reduce time for CI.") num_warmup, num_samples = (100000, 100000) if kernel_cls is SA else (500, 20000) def model(data): @@ -230,18 +241,22 @@ def model(data): if kernel_cls in [AIES, ESS]: num_chains = 10 kernel = kernel_cls(model=model) - mcmc = MCMC( kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False, num_chains=num_chains, chain_method='vectorized' ) + elif kernel_cls is SA: + kernel = SA(model=model) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False + ) + elif kernel_cls is BarkerMH: + kernel = BarkerMH(model=model) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False + ) else: - if kernel_cls is SA: - kernel = SA(model=model) - elif kernel_cls is BarkerMH: - kernel = BarkerMH(model=model) - else: - kernel = kernel_cls(model=model, trajectory_length=0.1) + kernel = kernel_cls(model=model, trajectory_length=0.1) mcmc = MCMC( kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False ) From e35d34ada423d92c6f0c922e71788fac185c6540 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Tue, 23 Jan 2024 18:30:39 -0500 Subject: [PATCH 21/23] fix doctest --- numpyro/infer/ensemble.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index fe6af936e..5d7a9c119 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -244,6 +244,7 @@ class AIES(EnsembleSampler): **Example** .. doctest:: + >>> import jax >>> import jax.numpy as jnp >>> import numpyro @@ -438,6 +439,7 @@ class ESS(EnsembleSampler): **Example** .. doctest:: + >>> import jax >>> import jax.numpy as jnp >>> import numpyro From 550b0231cb0418266696c00d4f95c59f05381747 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Tue, 23 Jan 2024 19:15:34 -0500 Subject: [PATCH 22/23] doc rewrite --- docs/source/mcmc.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/mcmc.rst b/docs/source/mcmc.rst index 39cf7be3a..f89d7f04c 100644 --- a/docs/source/mcmc.rst +++ b/docs/source/mcmc.rst @@ -9,9 +9,9 @@ We provide a high-level overview of the MCMC algorithms in NumPyro: * `BarkerMH `_ is a gradient-based MCMC method that may be competitive with HMC and NUTS for some models. It is applicable to models with continuous latent variables. * `HMCGibbs `_ combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user. * `DiscreteHMCGibbs `_ combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically. -* `SA `_ does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast. -* `AIES `_ is a gradient-free ensemble method that informs Metropolis-Hastings proposals by sharing information between chains. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities, and can be robust to likelihood-free models. AIES generally requires the number of chains to be twice as large as the number of latent parameters, (and ideally larger). -* `ESS `_ is a gradient-free ensemble method that shares information between chains to find good slice sampling directions. It tends to be more sample efficient than AIES. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate and may be a good choice for models with non-differentiable log densities. ESS generally requires the number of chains to be twice as large as the number of latent parameters, (and ideally larger). +* `SA `_ is a gradient-free MCMC method. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast. +* `AIES `_ is a gradient-free ensemble MCMC method that informs Metropolis-Hastings proposals by sharing information between chains. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities, and can be robust to likelihood-free models. AIES generally requires the number of chains to be twice as large as the number of latent parameters, (and ideally larger). +* `ESS `_ is a gradient-free ensemble MCMC method that shares information between chains to find good slice sampling directions. It tends to be more sample efficient than AIES. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate and may be a good choice for models with non-differentiable log densities. ESS generally requires the number of chains to be twice as large as the number of latent parameters, (and ideally larger). Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables if possible (see `restrictions `_). Enumerated sites need to be marked with `infer={'enumerate': 'parallel'}` like in the `annotation example `_. From 96e9483f48bdc477c393c2adcca9e7b479335cce Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Thu, 25 Jan 2024 14:41:03 -0500 Subject: [PATCH 23/23] fix distribution test --- test/test_distributions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index 714fe2871..f80f6bcf6 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1744,7 +1744,7 @@ def test_mean_var(jax_dist, sp_dist, params): sp_var = jnp.diag(d_sp.cov()) except TypeError: # mvn does not have .cov() method sp_var = jnp.diag(d_sp.cov) - except AttributeError: + except (AttributeError, ValueError): sp_var = d_sp.var() else: sp_var = d_sp.var()