From 549f7141030219b7acf0aa7fc3ac261488d91e79 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Tue, 16 Nov 2021 02:13:44 +0000 Subject: [PATCH 01/13] Adding observed_data and sample_stats to numpyro sampler --- pymc/sampling_jax.py | 66 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 61 insertions(+), 5 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 6fdbd64e730..312381d463f 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -4,7 +4,7 @@ import sys import warnings -from typing import Callable, List +from typing import Callable, List, Dict, Optional, Any from aesara.graph import optimize_graph from aesara.tensor import TensorVariable @@ -26,7 +26,7 @@ from aesara.link.jax.dispatch import jax_funcify from pymc import Model, modelcontext -from pymc.aesaraf import compile_rv_inplace, inputvars +from pymc.aesaraf import compile_rv_inplace, extract_obs_data from pymc.util import get_default_varnames warnings.warn("This module is experimental.") @@ -95,6 +95,45 @@ def logp_fn_wrap(x): return logp_fn_wrap +# Adopted from pm.to_inference_data +def find_observations(model: Model) -> Dict[str, Any]: + """If there are observations available, return them as a dictionary.""" + observations = {} + for obs in model.observed_RVs: + aux_obs = getattr(obs.tag, "observations", None) + if aux_obs is not None: + try: + obs_data = extract_obs_data(aux_obs) + observations[obs.name] = obs_data + except TypeError: + warnings.warn(f"Could not extract data from symbolic observation {obs}") + else: + warnings.warn(f"No data for observation {obs}") + + return observations + + +# Adopted from arviz numpyro extractor +def sample_stats_to_xarray(posterior): + """Extract sample_stats from NumPyro posterior.""" + rename_key = { + "potential_energy": "lp", + "adapt_state.step_size": "step_size", + "num_steps": "n_steps", + "accept_prob": "acceptance_rate", + } + data = {} + for stat, value in posterior.get_extra_fields(group_by_chain=True).items(): + if isinstance(value, (dict, tuple)): + continue + name = rename_key.get(stat, stat) + value = value.copy() + data[name] = value + if stat == "num_steps": + data["tree_depth"] = np.log2(value).astype(int) + 1 + return data + + def sample_numpyro_nuts( draws=1000, tune=1000, @@ -151,9 +190,22 @@ def sample_numpyro_nuts( map_seed = jax.random.split(seed, chains) if chains == 1: - pmap_numpyro.run(seed, init_params=init_state, extra_fields=("num_steps",)) + init_params=init_state else: - pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",)) + init_params=init_state_batched + + pmap_numpyro.run( + map_seed, + init_params=init_params, + extra_fields=( + "num_steps", + "potential_energy", + "energy", + "adapt_state.step_size", + "accept_prob", + "diverging", + ), + ) raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True) @@ -172,6 +224,10 @@ def sample_numpyro_nuts( print("Transformation time = ", tic4 - tic3, file=sys.stdout) posterior = mcmc_samples - az_trace = az.from_dict(posterior=posterior) + az_posterior = az.from_dict(posterior=posterior) + + az_obs = az.from_dict(observed_data=find_observations(model)) + az_stats = az.from_dict(sample_stats=sample_stats_to_xarray(pmap_numpyro)) + az_trace = az.concat(az_posterior, az_obs, az_stats) return az_trace From ff3a0cd0d33fcb32ed0f3ec3c6e81177692d8de9 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Tue, 16 Nov 2021 02:21:22 +0000 Subject: [PATCH 02/13] isort fix --- pymc/sampling_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 312381d463f..5300a3b569c 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -4,7 +4,7 @@ import sys import warnings -from typing import Callable, List, Dict, Optional, Any +from typing import Any, Callable, Dict, List, Optional from aesara.graph import optimize_graph from aesara.tensor import TensorVariable From 6987baca81e84762a37de4f2088e50d45bfcb511 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Tue, 16 Nov 2021 13:29:16 +0000 Subject: [PATCH 03/13] Refactor find_observations --- pymc/backends/arviz.py | 42 ++++++++++++++++++++++-------------------- pymc/sampling_jax.py | 19 +------------------ 2 files changed, 23 insertions(+), 38 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index c08eb068ac2..28919a83523 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -25,14 +25,13 @@ from pymc.aesaraf import extract_obs_data from pymc.distributions import logpt -from pymc.model import modelcontext +from pymc.model import modelcontext, Model from pymc.util import get_default_varnames if TYPE_CHECKING: from typing import Set # pylint: disable=ungrouped-imports from pymc.backends.base import MultiTrace # pylint: disable=invalid-name - from pymc.model import Model ___all__ = [""] @@ -42,6 +41,26 @@ Var = Any # pylint: disable=invalid-name +def find_observations(model: Model) -> Optional[Dict[str, Var]]: + """If there are observations available, return them as a dictionary.""" + if model is None: + return None + + observations = {} + for obs in model.observed_RVs: + aux_obs = getattr(obs.tag, "observations", None) + if aux_obs is not None: + try: + obs_data = extract_obs_data(aux_obs) + observations[obs.name] = obs_data + except TypeError: + warnings.warn(f"Could not extract data from symbolic observation {obs}") + else: + warnings.warn(f"No data for observation {obs}") + + return observations + + class _DefaultTrace: """ Utility for collecting samples into a dictionary. @@ -196,25 +215,8 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray: self.dims = {**model_dims, **self.dims} self.density_dist_obs = density_dist_obs - self.observations = self.find_observations() - - def find_observations(self) -> Optional[Dict[str, Var]]: - """If there are observations available, return them as a dictionary.""" - if self.model is None: - return None - observations = {} - for obs in self.model.observed_RVs: - aux_obs = getattr(obs.tag, "observations", None) - if aux_obs is not None: - try: - obs_data = extract_obs_data(aux_obs) - observations[obs.name] = obs_data - except TypeError: - warnings.warn(f"Could not extract data from symbolic observation {obs}") - else: - warnings.warn(f"No data for observation {obs}") + self.observations = find_observations(self.model) - return observations def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]: """Split MultiTrace object into posterior and warmup. diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 5300a3b569c..c23ee0f5509 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -27,6 +27,7 @@ from pymc import Model, modelcontext from pymc.aesaraf import compile_rv_inplace, extract_obs_data +from pymc.backends.arviz import find_observations from pymc.util import get_default_varnames warnings.warn("This module is experimental.") @@ -95,24 +96,6 @@ def logp_fn_wrap(x): return logp_fn_wrap -# Adopted from pm.to_inference_data -def find_observations(model: Model) -> Dict[str, Any]: - """If there are observations available, return them as a dictionary.""" - observations = {} - for obs in model.observed_RVs: - aux_obs = getattr(obs.tag, "observations", None) - if aux_obs is not None: - try: - obs_data = extract_obs_data(aux_obs) - observations[obs.name] = obs_data - except TypeError: - warnings.warn(f"Could not extract data from symbolic observation {obs}") - else: - warnings.warn(f"No data for observation {obs}") - - return observations - - # Adopted from arviz numpyro extractor def sample_stats_to_xarray(posterior): """Extract sample_stats from NumPyro posterior.""" From af108d43b7b4a5f185d4dfc4517a9b1d6df48da4 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Tue, 16 Nov 2021 13:36:22 +0000 Subject: [PATCH 04/13] Format fixes --- pymc/backends/arviz.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 28919a83523..34a511db7e5 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -25,7 +25,7 @@ from pymc.aesaraf import extract_obs_data from pymc.distributions import logpt -from pymc.model import modelcontext, Model +from pymc.model import Model, modelcontext from pymc.util import get_default_varnames if TYPE_CHECKING: @@ -217,7 +217,6 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray: self.density_dist_obs = density_dist_obs self.observations = find_observations(self.model) - def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]: """Split MultiTrace object into posterior and warmup. From f295288a946ebd74b23c83d7179232085e88ee83 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Tue, 16 Nov 2021 13:37:40 +0000 Subject: [PATCH 05/13] Remove more unused bits --- pymc/sampling_jax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index c23ee0f5509..88dfae07785 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -4,7 +4,7 @@ import sys import warnings -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, List from aesara.graph import optimize_graph from aesara.tensor import TensorVariable @@ -26,7 +26,7 @@ from aesara.link.jax.dispatch import jax_funcify from pymc import Model, modelcontext -from pymc.aesaraf import compile_rv_inplace, extract_obs_data +from pymc.aesaraf import compile_rv_inplace from pymc.backends.arviz import find_observations from pymc.util import get_default_varnames From 8a28fad225324c072ae81bbcf2d3fc5076b48c7a Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Tue, 16 Nov 2021 14:18:31 +0000 Subject: [PATCH 06/13] Fix typo annotations --- pymc/backends/arviz.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 34a511db7e5..26b784b36f7 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -25,13 +25,14 @@ from pymc.aesaraf import extract_obs_data from pymc.distributions import logpt -from pymc.model import Model, modelcontext +from pymc.model import modelcontext from pymc.util import get_default_varnames if TYPE_CHECKING: from typing import Set # pylint: disable=ungrouped-imports from pymc.backends.base import MultiTrace # pylint: disable=invalid-name + from pymc.model import Model ___all__ = [""] @@ -41,7 +42,7 @@ Var = Any # pylint: disable=invalid-name -def find_observations(model: Model) -> Optional[Dict[str, Var]]: +def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]: """If there are observations available, return them as a dictionary.""" if model is None: return None From 59ebfdb22286c6402f99440899fbad032586e835 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Tue, 16 Nov 2021 15:36:28 +0000 Subject: [PATCH 07/13] Fix seed --- pymc/sampling_jax.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 88dfae07785..9d865b6df24 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -174,6 +174,7 @@ def sample_numpyro_nuts( if chains == 1: init_params=init_state + map_seed = seed else: init_params=init_state_batched From f5aeaf68c683c96f822ac4de79e31289bca320d1 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Tue, 16 Nov 2021 15:41:25 +0000 Subject: [PATCH 08/13] Format fix --- pymc/sampling_jax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 9d865b6df24..324be22f775 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -173,10 +173,10 @@ def sample_numpyro_nuts( map_seed = jax.random.split(seed, chains) if chains == 1: - init_params=init_state + init_params = init_state map_seed = seed else: - init_params=init_state_batched + init_params = init_state_batched pmap_numpyro.run( map_seed, From bf2ad0d1921206e761969b36588b7f8615a64ddb Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 17 Nov 2021 16:36:30 +0000 Subject: [PATCH 09/13] Add log likehoods to trace object --- pymc/sampling_jax.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 324be22f775..efb975c2f00 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -28,6 +28,7 @@ from pymc import Model, modelcontext from pymc.aesaraf import compile_rv_inplace from pymc.backends.arviz import find_observations +from pymc.distributions import logpt from pymc.util import get_default_varnames warnings.warn("This module is experimental.") @@ -97,7 +98,7 @@ def logp_fn_wrap(x): # Adopted from arviz numpyro extractor -def sample_stats_to_xarray(posterior): +def _sample_stats_to_xarray(posterior): """Extract sample_stats from NumPyro posterior.""" rename_key = { "potential_energy": "lp", @@ -117,6 +118,18 @@ def sample_stats_to_xarray(posterior): return data +def _get_log_likelihood(model, samples): + "Compute log-likelihood for all observations" + data = {} + for v in model.observed_RVs: + logp_v = replace_shared_variables([logpt(v)]) + fgraph = FunctionGraph(model.value_vars, logp_v, clone=False) + jax_fn = jax_funcify(fgraph) + result = jax.vmap(jax.vmap(jax_fn))(*samples)[0] + data[v.name] = result + return data + + def sample_numpyro_nuts( draws=1000, tune=1000, @@ -211,7 +224,8 @@ def sample_numpyro_nuts( az_posterior = az.from_dict(posterior=posterior) az_obs = az.from_dict(observed_data=find_observations(model)) - az_stats = az.from_dict(sample_stats=sample_stats_to_xarray(pmap_numpyro)) - az_trace = az.concat(az_posterior, az_obs, az_stats) + az_stats = az.from_dict(sample_stats=_sample_stats_to_xarray(pmap_numpyro)) + az_ll = az.from_dict(log_likelihood=_get_log_likelihood(model, raw_mcmc_samples)) + az_trace = az.concat(az_posterior, az_ll, az_obs, az_stats) return az_trace From a3c6cc2df5e68f303077f7a2a014ee2f02d94ef6 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 17 Nov 2021 20:19:53 +0000 Subject: [PATCH 10/13] Add test for _get_log_likelihood in sampling_jax --- pymc/tests/test_sampling_jax.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index 3fd04059c09..21ff77e1806 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -12,6 +12,7 @@ get_jaxified_logp, replace_shared_variables, sample_numpyro_nuts, + _get_log_likelihood, ) @@ -61,6 +62,24 @@ def test_deterministic_samples(): assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 2) +def test_get_log_log_likelihood(): + obs = np.random.normal(10, 2, size=100) + obs_at = aesara.shared(obs, borrow=True, name="obs") + with pm.Model() as model: + a = pm.Normal("a", 0, 2) + sigma = pm.HalfNormal("sigma") + b = pm.Normal("b", a, sigma=sigma, observed=obs_at) + + trace = pm.sample(chains=1, random_seed=1322) + + b_true = trace.log_likelihood.b.values + a = np.array(trace.posterior.a) + sigma_log_ = np.log(np.array(trace.posterior.sigma)) + b_jax = _get_log_likelihood(model, [a, sigma_log_])["b"][0] + + assert np.allclose(b_jax.reshape(-1), b_true.reshape(-1)) + + def test_replace_shared_variables(): x = aesara.shared(5, name="shared_x") From 009887627584c5f52696c2019e5349f8fe126ec4 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Wed, 17 Nov 2021 20:27:46 +0000 Subject: [PATCH 11/13] Format fix --- pymc/tests/test_sampling_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index 21ff77e1806..6fd710a2e51 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -9,10 +9,10 @@ import pymc as pm from pymc.sampling_jax import ( + _get_log_likelihood, get_jaxified_logp, replace_shared_variables, sample_numpyro_nuts, - _get_log_likelihood, ) From 5f5cd87945620e12bfd10f9fa8cf8487adcd3535 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Thu, 18 Nov 2021 11:02:18 +0000 Subject: [PATCH 12/13] Use multiple chains --- pymc/tests/test_sampling_jax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index 6fd710a2e51..ea4be51e588 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -70,12 +70,12 @@ def test_get_log_log_likelihood(): sigma = pm.HalfNormal("sigma") b = pm.Normal("b", a, sigma=sigma, observed=obs_at) - trace = pm.sample(chains=1, random_seed=1322) + trace = pm.sample(tune=10, draws=10, chains=2, random_seed=1322) b_true = trace.log_likelihood.b.values a = np.array(trace.posterior.a) sigma_log_ = np.log(np.array(trace.posterior.sigma)) - b_jax = _get_log_likelihood(model, [a, sigma_log_])["b"][0] + b_jax = _get_log_likelihood(model, [a, sigma_log_])["b"] assert np.allclose(b_jax.reshape(-1), b_true.reshape(-1)) From 6e4fcab41296b7987aef7970af8a7de5034ab166 Mon Sep 17 00:00:00 2001 From: zaxtax Date: Thu, 18 Nov 2021 11:48:31 +0000 Subject: [PATCH 13/13] Update pymc/tests/test_sampling_jax.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/tests/test_sampling_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index ea4be51e588..172eceb4d07 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -62,7 +62,7 @@ def test_deterministic_samples(): assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 2) -def test_get_log_log_likelihood(): +def test_get_log_likelihood(): obs = np.random.normal(10, 2, size=100) obs_at = aesara.shared(obs, borrow=True, name="obs") with pm.Model() as model: