Use jaxified logp for initial point evaluation when sampling via Jax #7610

Open
wants to merge 11 commits into base: main
14 changes: 12 additions & 2 deletions pymc/initial_point.py
@@ -26,6 +26,7 @@

from pymc.logprob.transforms import Transform
from pymc.pytensorf import (
+SeedSequenceSeed,
compile,
find_rng_nodes,
replace_rng_nodes,
@@ -67,7 +68,7 @@ def make_initial_point_fns_per_chain(
overrides: StartDict | Sequence[StartDict | None] | None,
jitter_rvs: set[TensorVariable] | None = None,
chains: int,
-) -> list[Callable]:
+) -> list[Callable[[SeedSequenceSeed], PointType]]:
"""Create an initial point function for each chain, as defined by initvals.

If a single initval dictionary is passed, the function is replicated for each
@@ -82,6 +83,11 @@
Random variable tensors for which U(-1, 1) jitter shall be applied.
(To the transformed space if applicable.)

+Returns
+-------
+ipfns : list[Callable[[SeedSequenceSeed], dict[str, np.ndarray]]]
+list of functions that return initial points for each chain.
+
Raises
------
ValueError
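For orientation, here is how these per-chain functions are used downstream; a minimal sketch, assuming a trivial one-variable model (the model and seeds are illustrative, not from this diff):

```python
import numpy as np
import pymc as pm
from pymc.initial_point import make_initial_point_fns_per_chain

with pm.Model() as model:
    x = pm.Normal("x")

fns = make_initial_point_fns_per_chain(
    model=model, overrides=None, jitter_rvs=set(), chains=4
)
# each function consumes a SeedSequenceSeed and returns one dict per chain,
# e.g. {"x": array(0.13...)}
points = [fn(seed) for fn, seed in zip(fns, np.random.SeedSequence(0).spawn(4))]
```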
@@ -124,7 +130,7 @@ def make_initial_point_fn(
jitter_rvs: set[TensorVariable] | None = None,
default_strategy: str = "support_point",
return_transformed: bool = True,
-) -> Callable:
+) -> Callable[[SeedSequenceSeed], PointType]:
"""Create seeded function that computes initial values for all free model variables.

Parameters
@@ -138,6 +144,10 @@
Initial value (strategies) to use instead of what's specified in `Model.initial_values`.
return_transformed : bool
If `True` the returned variables will correspond to transformed initial values.

+Returns
+-------
+initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]]
"""
sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
initval_strats = {
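Before moving on to the sampling module, the single-function variant in use; a sketch assuming the same toy model as above:

```python
from pymc.initial_point import make_initial_point_fn

# returns a single seeded function; names are transformed if applicable
ipfn = make_initial_point_fn(model=model, jitter_rvs=set(model.free_RVs))
point = ipfn(42)  # e.g. {"x": array(...)}
```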
189 changes: 120 additions & 69 deletions pymc/sampling/jax.py
@@ -18,6 +18,7 @@
from collections.abc import Callable, Sequence
from datetime import datetime
from functools import partial
+from types import ModuleType
from typing import Any, Literal

import arviz as az
@@ -28,6 +29,7 @@

from arviz.data.base import make_attrs
from jax.lax import scan
+from numpy.typing import ArrayLike
from pytensor.compile import SharedVariable, Supervisor, mode
from pytensor.graph.basic import graph_inputs
from pytensor.graph.fg import FunctionGraph
@@ -121,7 +123,7 @@ def _replace_shared_variables(graph: list[TensorVariable]) -> list[TensorVariabl
def get_jaxified_graph(
inputs: list[TensorVariable] | None = None,
outputs: list[TensorVariable] | None = None,
-) -> list[TensorVariable]:
+) -> Callable[[list[TensorVariable]], list[TensorVariable]]:
"""Compile a PyTensor graph into an optimized JAX function."""
graph = _replace_shared_variables(outputs) if outputs is not None else None

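For orientation, a sketch of what the returned callable looks like in use, assuming a model with a single value variable (names are illustrative):

```python
import numpy as np

# jaxify the logp graph of a model; the result is a plain JAX-traceable
# callable over positional inputs, returning one entry per requested output
logp_graph_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model.logp()])
(logp_value,) = logp_graph_fn(np.array(0.1))
```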
@@ -144,13 +146,13 @@ def get_jaxified_graph(
return jax_funcify(fgraph)


-def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
+def get_jaxified_logp(model: Model, negative_logp: bool = True) -> Callable[[ArrayLike], jax.Array]:
model_logp = model.logp()
if not negative_logp:
model_logp = -model_logp
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])

-def logp_fn_wrap(x):
+def logp_fn_wrap(x: ArrayLike) -> jax.Array:
return logp_fn(*x)[0]

return logp_fn_wrap
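Because the wrapper closes over a pure JAX function, the result composes with JAX transforms. A usage sketch, assuming the one-variable toy model above (`x0` is illustrative, not from the diff):

```python
import jax
import numpy as np

logp_fn = get_jaxified_logp(model)
x0 = [np.array(0.1)]        # one array per entry of model.value_vars
lp = logp_fn(x0)            # a jax scalar
grad = jax.grad(logp_fn)(x0)  # differentiable end-to-end, one grad per input
```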
@@ -211,23 +213,43 @@ def _get_batched_jittered_initial_points(
chains: int,
initvals: StartDict | Sequence[StartDict | None] | None,
random_seed: RandomSeed,
+logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
jitter: bool = True,
jitter_max_retries: int = 10,
) -> np.ndarray | list[np.ndarray]:
"""Get jittered initial point in format expected by NumPyro MCMC kernel.
"""Get jittered initial point in format expected by Jax MCMC kernel.

+Parameters
+----------
+logp_fn : Callable[[ArrayLike], jax.Array], optional
+Jaxified logp function
+
Returns
-------
out: list of ndarrays
list with one item per variable and number of chains as batch dimension.
Each item has shape `(chains, *var.shape)`
"""
+if logp_fn is None:
+eval_logp_initial_point = None
+
+else:
+
+def eval_logp_initial_point(point: dict[str, np.ndarray]) -> jax.Array:
+"""Wrap logp_fn to conform to _init_jitter logic.
+
+Wraps jaxified logp function to accept a dict of
+{model_variable: np.array} key:value pairs.
+"""
+return logp_fn(point.values())

initial_points = _init_jitter(
model,
initvals,
seeds=_get_seeds_per_chain(random_seed, chains),
jitter=jitter,
jitter_max_retries=jitter_max_retries,
+logp_fn=eval_logp_initial_point,
)
initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
if chains == 1:
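Threading `logp_fn` through to `_init_jitter` is the point of this PR: jittered draws can be validated with the already-jaxified logp instead of compiling a separate PyTensor function for the job. The retry logic is roughly the following (a sketch, not `_init_jitter`'s exact code; `ipfn` and `seed` are illustrative):

```python
# rough sketch of the jitter-retry loop
for _ in range(jitter_max_retries + 1):
    point = ipfn(seed)                   # draw a (jittered) initial point
    lp = eval_logp_initial_point(point)  # reuses the jaxified logp
    if np.isfinite(lp):
        break                            # keep the first finite-logp point
    # otherwise re-jitter and try again
```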
@@ -236,7 +258,7 @@ def _get_batched_jittered_initial_points(


def _blackjax_inference_loop(
-seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs
+seed, init_position, logp_fn, draws, tune, target_accept, **adaptation_kwargs
):
import blackjax

@@ -252,13 +274,13 @@

adapt = blackjax.window_adaptation(
algorithm=algorithm,
-logdensity_fn=logprob_fn,
+logdensity_fn=logp_fn,
target_acceptance_rate=target_accept,
adaptation_info_fn=get_filter_adapt_info_fn(),
**adaptation_kwargs,
)
(last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
-kernel = algorithm(logprob_fn, **tuned_params).step
+kernel = algorithm(logp_fn, **tuned_params).step

def _one_step(state, xs):
_, rng_key = xs
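The loop body is collapsed by the diff view; roughly, it advances the tuned kernel one draw at a time under `scan` (a sketch of the continuation, reusing names from the surrounding code, not the module's exact implementation):

```python
import jax.numpy as jnp

def _one_step(state, xs):
    _, rng_key = xs
    state, info = kernel(rng_key, state)  # one NUTS transition
    return state, (state.position, info)  # carry state, stack the draws

keys = jax.random.split(seed, draws)
last_state, (samples, infos) = scan(_one_step, last_state, (jnp.arange(draws), keys))
```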
@@ -289,67 +311,51 @@ def _sample_blackjax_nuts(
tune: int,
draws: int,
chains: int,
-chain_method: str | None,
+chain_method: Literal["parallel", "vectorized"],
progressbar: bool,
random_seed: int,
-initial_points,
-nuts_kwargs,
-) -> az.InferenceData:
+initial_points: np.ndarray | list[np.ndarray],
+nuts_kwargs: dict[str, Any],
+logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
+) -> tuple[Any, dict[str, Any], ModuleType]:
"""
Draw samples from the posterior using the NUTS method from the ``blackjax`` library.

Parameters
----------
-draws : int, default 1000
-The number of samples to draw. The number of tuned samples are discarded by
-default.
-tune : int, default 1000
+model : Model
+Model to sample from. The model needs to have free random variables.
+target_accept : float in [0, 1].
+The step size is tuned such that we approximate this acceptance rate. Higher
+values like 0.9 or 0.95 often work better for problematic posteriors.
+tune : int
Number of iterations to tune. Samplers adjust the step sizes, scalings or
similar during tuning. Tuning samples will be drawn in addition to the number
specified in the ``draws`` argument.
-chains : int, default 4
+draws : int
+The number of samples to draw. The number of tuned samples are discarded by default.
+chains : int
The number of chains to sample.
-target_accept : float in [0, 1].
-The step size is tuned such that we approximate this acceptance rate. Higher
-values like 0.9 or 0.95 often work better for problematic posteriors.
-random_seed : int, RandomState or Generator, optional
+chain_method : "parallel" or "vectorized"
+Specify how samples should be drawn.
+progressbar : bool
+Whether to show progressbar or not during sampling.
+random_seed : int, RandomState or Generator
Random seed used by the sampling steps.
-initvals: StartDict or Sequence[Optional[StartDict]], optional
-Initial values for random variables provided as a dictionary (or sequence of
-dictionaries) mapping the random variable (by name or reference) to desired
-starting values.
-jitter: bool, default True
-If True, add jitter to initial points.
-model : Model, optional
-Model to sample from. The model needs to have free random variables. When inside
-a ``with`` model context, it defaults to that model, otherwise the model must be
-passed explicitly.
-var_names : sequence of str, optional
-Names of variables for which to compute the posterior samples. Defaults to all
-variables in the posterior.
-keep_untransformed : bool, default False
-Include untransformed variables in the posterior samples. Defaults to False.
-chain_method : str, default "parallel"
-Specify how samples should be drawn. The choices include "parallel", and
-"vectorized".
-postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None,
-Specify how postprocessing should be computed. gpu or cpu
-postprocessing_vectorize: Literal["vmap", "scan"], default "scan"
-How to vectorize the postprocessing: vmap or sequential scan
-idata_kwargs : dict, optional
-Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
-value for the ``log_likelihood`` key to indicate that the pointwise log
-likelihood should not be included in the returned object. Values for
-``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
-the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
-``dims`` are provided, they are used to update the inferred dictionaries.
+initial_points : np.ndarray or list[np.ndarray]
+Initial point(s) for the sampler to begin sampling from.
+nuts_kwargs : dict
+Keyword arguments for the blackjax NUTS sampler.
+logp_fn : Callable[[ArrayLike], jax.Array], optional, default None
+Jaxified logp function. If not provided, it is created from the model.

Returns
-------
-InferenceData
-ArviZ ``InferenceData`` object that contains the posterior samples, together
-with their respective sample stats and pointwise log likelihood values (unless
-skipped with ``idata_kwargs``).
+raw_mcmc_samples
+Data structure containing the raw MCMC samples.
+sample_stats : dict[str, Any]
+Dictionary containing sample stats.
+blackjax : ModuleType
+The ``blackjax`` module.
"""
import blackjax

@@ -366,15 +372,16 @@
if chains == 1:
initial_points = [np.stack(init_state) for init_state in zip(initial_points)]

-logprob_fn = get_jaxified_logp(model)
+if logp_fn is None:
+logp_fn = get_jaxified_logp(model)

seed = jax.random.PRNGKey(random_seed)
keys = jax.random.split(seed, chains)

nuts_kwargs["progress_bar"] = progressbar
get_posterior_samples = partial(
_blackjax_inference_loop,
-logprob_fn=logprob_fn,
+logp_fn=logp_fn,
tune=tune,
draws=draws,
target_accept=target_accept,
@@ -386,7 +393,7 @@


# Adapted from arviz numpyro extractor
-def _numpyro_stats_to_dict(posterior):
+def _numpyro_stats_to_dict(posterior) -> dict[str, Any]:
"""Extract sample_stats from NumPyro posterior."""
rename_key = {
"potential_energy": "lp",
@@ -412,17 +419,58 @@ def _sample_numpyro_nuts(
tune: int,
draws: int,
chains: int,
-chain_method: str | None,
+chain_method: Literal["parallel", "vectorized"],
progressbar: bool,
random_seed: int,
-initial_points,
+initial_points: np.ndarray | list[np.ndarray],
nuts_kwargs: dict[str, Any],
-):
+logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
+) -> tuple[Any, dict[str, Any], ModuleType]:
"""
Draw samples from the posterior using the NUTS method from the ``numpyro`` library.

Parameters
----------
model : Model
Model to sample from. The model needs to have free random variables.
target_accept : float in [0, 1].
The step size is tuned such that we approximate this acceptance rate. Higher
values like 0.9 or 0.95 often work better for problematic posteriors.
tune : int
Number of iterations to tune. Samplers adjust the step sizes, scalings or
similar during tuning. Tuning samples will be drawn in addition to the number
specified in the ``draws`` argument.
draws : int
The number of samples to draw. The number of tuned samples are discarded by default.
chains : int
The number of chains to sample.
chain_method : "parallel" or "vectorized"
Specify how samples should be drawn.
progressbar : bool
Whether to show progressbar or not during sampling.
random_seed : int, RandomState or Generator
Random seed used by the sampling steps.
initial_points : np.ndarray or list[np.ndarray]
Initial point(s) for sampler to begin sampling from.
nuts_kwargs : dict
Keyword arguments for the underlying numpyro nuts sampler
logp_fn : Callable[[ArrayLike], jax.Array], optional, default None
jaxified logp function. If not passed in it will be created anew.

Returns
-------
raw_mcmc_samples
Datastructure containing raw mcmc samples
sample_stats : dict[str, Any]
Dictionary containing sample stats
numpyro : ModuleType["numpyro"]
"""
import numpyro

from numpyro.infer import MCMC, NUTS

-logp_fn = get_jaxified_logp(model, negative_logp=False)
+if logp_fn is None:
+logp_fn = get_jaxified_logp(model, negative_logp=False)

nuts_kwargs.setdefault("adapt_step_size", True)
nuts_kwargs.setdefault("adapt_mass_matrix", True)
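For orientation, roughly how the jaxified function plugs into NumPyro: because `negative_logp=False` yields the negated logp, it serves directly as the potential function. A sketch with assumed kwargs, not the module's exact call sequence:

```python
# wire the jaxified potential into numpyro's NUTS kernel
kernel = NUTS(potential_fn=logp_fn, target_accept_prob=target_accept, **nuts_kwargs)
mcmc = MCMC(
    kernel,
    num_warmup=tune,
    num_samples=draws,
    num_chains=chains,
    chain_method=chain_method,
    progress_bar=progressbar,
)
# with potential_fn, numpyro requires explicit initial parameters
mcmc.run(jax.random.PRNGKey(random_seed), init_params=initial_points)
```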
@@ -480,7 +528,7 @@ def sample_jax_nuts(
nuts_kwargs: dict | None = None,
progressbar: bool = True,
keep_untransformed: bool = False,
chain_method: str = "parallel",
chain_method: Literal["parallel", "vectorized"] = "parallel",
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
postprocessing_vectorize: Literal["vmap", "scan"] | None = None,
postprocessing_chunks=None,
@@ -526,7 +574,7 @@ def sample_jax_nuts(
If True, display a progressbar while sampling
keep_untransformed : bool, default False
Include untransformed variables in the posterior samples.
-chain_method : str, default "parallel"
+chain_method : Literal["parallel", "vectorized"], default "parallel"
Specify how samples should be drawn. The choices include "parallel", and
"vectorized".
postprocessing_backend : Optional[Literal["cpu", "gpu"]], default None,
@@ -590,6 +638,15 @@ def sample_jax_nuts(
get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
)

if nuts_sampler == "numpyro":
sampler_fn = _sample_numpyro_nuts
logp_fn = get_jaxified_logp(model, negative_logp=False)
elif nuts_sampler == "blackjax":
sampler_fn = _sample_blackjax_nuts
logp_fn = get_jaxified_logp(model)
else:
raise ValueError(f"{nuts_sampler=} not recognized")
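The opposite `negative_logp` flags encode the two backends' sign conventions, which follows directly from the `get_jaxified_logp` definition above:

```python
# blackjax: logdensity_fn(x) ==  logp(x)  -> get_jaxified_logp(model)
# numpyro:  potential_fn(x)  == -logp(x)  -> get_jaxified_logp(model, negative_logp=False)
```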

(random_seed,) = _get_seeds_per_chain(random_seed, 1)

initial_points = _get_batched_jittered_initial_points(
Expand All @@ -598,15 +655,9 @@ def sample_jax_nuts(
initvals=initvals,
random_seed=random_seed,
jitter=jitter,
+logp_fn=logp_fn,
)

if nuts_sampler == "numpyro":
sampler_fn = _sample_numpyro_nuts
elif nuts_sampler == "blackjax":
sampler_fn = _sample_blackjax_nuts
else:
raise ValueError(f"{nuts_sampler=} not recognized")

tic1 = datetime.now()
raw_mcmc_samples, sample_stats, library = sampler_fn(
model=model,