Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding log_likelihood, observed_data, and sample_stats to numpyro sampler #5189

Merged
40 changes: 21 additions & 19 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,26 @@
Var = Any # pylint: disable=invalid-name


def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]:
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
"""If there are observations available, return them as a dictionary."""
if model is None:
return None

observations = {}
for obs in model.observed_RVs:
aux_obs = getattr(obs.tag, "observations", None)
if aux_obs is not None:
try:
obs_data = extract_obs_data(aux_obs)
observations[obs.name] = obs_data
except TypeError:
warnings.warn(f"Could not extract data from symbolic observation {obs}")
else:
warnings.warn(f"No data for observation {obs}")

return observations


class _DefaultTrace:
"""
Utility for collecting samples into a dictionary.
Expand Down Expand Up @@ -196,25 +216,7 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
self.dims = {**model_dims, **self.dims}

self.density_dist_obs = density_dist_obs
self.observations = self.find_observations()

def find_observations(self) -> Optional[Dict[str, Var]]:
"""If there are observations available, return them as a dictionary."""
if self.model is None:
return None
observations = {}
for obs in self.model.observed_RVs:
aux_obs = getattr(obs.tag, "observations", None)
if aux_obs is not None:
try:
obs_data = extract_obs_data(aux_obs)
observations[obs.name] = obs_data
except TypeError:
warnings.warn(f"Could not extract data from symbolic observation {obs}")
else:
warnings.warn(f"No data for observation {obs}")

return observations
self.observations = find_observations(self.model)

def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
"""Split MultiTrace object into posterior and warmup.
Expand Down
62 changes: 58 additions & 4 deletions pymc/sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from aesara.link.jax.dispatch import jax_funcify

from pymc import Model, modelcontext
from pymc.aesaraf import compile_rv_inplace, inputvars
from pymc.aesaraf import compile_rv_inplace
from pymc.backends.arviz import find_observations
from pymc.distributions import logpt
from pymc.util import get_default_varnames

warnings.warn("This module is experimental.")
Expand Down Expand Up @@ -95,6 +97,39 @@ def logp_fn_wrap(x):
return logp_fn_wrap


# Adopted from arviz numpyro extractor
def _sample_stats_to_xarray(posterior):
"""Extract sample_stats from NumPyro posterior."""
rename_key = {
"potential_energy": "lp",
"adapt_state.step_size": "step_size",
"num_steps": "n_steps",
"accept_prob": "acceptance_rate",
}
data = {}
for stat, value in posterior.get_extra_fields(group_by_chain=True).items():
if isinstance(value, (dict, tuple)):
continue
name = rename_key.get(stat, stat)
value = value.copy()
data[name] = value
if stat == "num_steps":
data["tree_depth"] = np.log2(value).astype(int) + 1
return data


def _get_log_likelihood(model, samples):
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
"Compute log-likelihood for all observations"
data = {}
for v in model.observed_RVs:
logp_v = replace_shared_variables([logpt(v)])
fgraph = FunctionGraph(model.value_vars, logp_v, clone=False)
jax_fn = jax_funcify(fgraph)
result = jax.vmap(jax.vmap(jax_fn))(*samples)[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, would we expect any benefits to jit_compiling this outer vmap?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to use a similar approach with Aesara directly?

Here we only loop over observed variables in order to get the pointwise log likelihood. We had some discussion about this in #4489 but ended up keeping the 3 nested loops over variables, chains and draws.

Copy link
Member

@ricardoV94 ricardoV94 Nov 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be possible, but requires a Aesara Scan, and at least for small models this was not faster than python looping when I checked it. Here is a Notebook that documents some things I tried: https://gist.github.com/ricardoV94/6089a8c46a0e19665f01c79ea04e1cb2

It might be faster if using shared variables...

Copy link
Contributor Author

@zaxtax zaxtax Nov 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No idea. I think the easiest thing to do is just benchmark it. I don't even call optimize_graph on either the graph in this function or the main sample routine.

When I run the model in the unit test with the change

result = jax.vmap(jax.vmap(jax_fn))(*samples)[0] to
result = jax.jit(jax.vmap(jax.vmap(jax_fn)))(*samples)[0]

I don't really get a speed-up until there are millions of samples.

Copy link
Member

@ricardoV94 ricardoV94 Nov 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't even call optimize_graph on either the graph in this function or the main sample routine

We should definitely call optimize_graph, otherwise the computed logps may not correspond to the ones used during sampling. For instance we have many optimizations that improve numerically stability, so you might get underflows to -inf for some of the posterior samples (which would never have been accepted by NUTS) which could screw up things downstream.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be possible, but requires a Aesara Scan, and at least for small models this was not faster than python looping when I checked it.

Then it's probably not worth it. I was under the impression it would be possible to vectorize/broadcast the operation from the conversations in #4489 and in slack.

Copy link
Member

@ricardoV94 ricardoV94 Nov 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It must be possible, since the vmap above works just fine. I just have no idea how they do it xD, or how/if you could do it in Aesara. I also wonder whether the vmap works for more complicated models with multivariate distributions and the like

Copy link
Contributor Author

@zaxtax zaxtax Nov 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright. I'm going to make a separate PR for some of this other stuff.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, feel free to tag me if you want me to review, I am not watching PRs. I can already say I won't be able to help with the vectorized log_likelihood thing, I tried and I lost much more time with that than what would have been healthy. I should be able to help with coords and dims though

data[v.name] = result
return data


def sample_numpyro_nuts(
draws=1000,
tune=1000,
Expand Down Expand Up @@ -151,9 +186,23 @@ def sample_numpyro_nuts(
map_seed = jax.random.split(seed, chains)

if chains == 1:
pmap_numpyro.run(seed, init_params=init_state, extra_fields=("num_steps",))
init_params = init_state
map_seed = seed
else:
pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",))
init_params = init_state_batched

pmap_numpyro.run(
map_seed,
init_params=init_params,
extra_fields=(
"num_steps",
"potential_energy",
"energy",
"adapt_state.step_size",
"accept_prob",
"diverging",
),
)

raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

Expand All @@ -172,6 +221,11 @@ def sample_numpyro_nuts(
print("Transformation time = ", tic4 - tic3, file=sys.stdout)

posterior = mcmc_samples
az_trace = az.from_dict(posterior=posterior)
az_posterior = az.from_dict(posterior=posterior)

az_obs = az.from_dict(observed_data=find_observations(model))
az_stats = az.from_dict(sample_stats=_sample_stats_to_xarray(pmap_numpyro))
az_ll = az.from_dict(log_likelihood=_get_log_likelihood(model, raw_mcmc_samples))
az_trace = az.concat(az_posterior, az_ll, az_obs, az_stats)

return az_trace
19 changes: 19 additions & 0 deletions pymc/tests/test_sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pymc as pm

from pymc.sampling_jax import (
_get_log_likelihood,
get_jaxified_logp,
replace_shared_variables,
sample_numpyro_nuts,
Expand Down Expand Up @@ -61,6 +62,24 @@ def test_deterministic_samples():
assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 2)


def test_get_log_likelihood():
obs = np.random.normal(10, 2, size=100)
obs_at = aesara.shared(obs, borrow=True, name="obs")
with pm.Model() as model:
a = pm.Normal("a", 0, 2)
sigma = pm.HalfNormal("sigma")
b = pm.Normal("b", a, sigma=sigma, observed=obs_at)

trace = pm.sample(tune=10, draws=10, chains=2, random_seed=1322)

b_true = trace.log_likelihood.b.values
a = np.array(trace.posterior.a)
sigma_log_ = np.log(np.array(trace.posterior.sigma))
b_jax = _get_log_likelihood(model, [a, sigma_log_])["b"]

assert np.allclose(b_jax.reshape(-1), b_true.reshape(-1))


def test_replace_shared_variables():
x = aesara.shared(5, name="shared_x")

Expand Down