diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 795050e5aa6..6fdbd64e730 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -26,7 +26,8 @@ from aesara.link.jax.dispatch import jax_funcify from pymc import Model, modelcontext -from pymc.aesaraf import compile_rv_inplace +from pymc.aesaraf import compile_rv_inplace, inputvars +from pymc.util import get_default_varnames warnings.warn("This module is experimental.") @@ -101,6 +102,7 @@ def sample_numpyro_nuts( target_accept=0.8, random_seed=10, model=None, + var_names=None, progress_bar=True, keep_untransformed=False, ): @@ -108,6 +110,11 @@ def sample_numpyro_nuts( model = modelcontext(model) + if var_names is None: + var_names = model.unobserved_value_vars + + vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed)) + tic1 = pd.Timestamp.now() print("Compiling...", file=sys.stdout) @@ -143,45 +150,28 @@ def sample_numpyro_nuts( seed = jax.random.PRNGKey(random_seed) map_seed = jax.random.split(seed, chains) - pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",)) + if chains == 1: + pmap_numpyro.run(seed, init_params=init_state, extra_fields=("num_steps",)) + else: + pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",)) + raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True) tic3 = pd.Timestamp.now() print("Sampling time = ", tic3 - tic2, file=sys.stdout) print("Transforming variables...", file=sys.stdout) - mcmc_samples = [] - for i, (value_var, raw_samples) in enumerate(zip(model.value_vars, raw_mcmc_samples)): - raw_samples = at.constant(np.asarray(raw_samples)) - - rv = model.values_to_rvs[value_var] - transform = getattr(value_var.tag, "transform", None) - - if transform is not None: - # TODO: This will fail when the transformation depends on another variable - # such as in interval transform with RVs as edges - trans_samples = transform.backward(raw_samples, *rv.owner.inputs) - trans_samples.name = rv.name - mcmc_samples.append(trans_samples) - - if keep_untransformed: - raw_samples.name = value_var.name - mcmc_samples.append(raw_samples) - else: - raw_samples.name = rv.name - mcmc_samples.append(raw_samples) - - mcmc_varnames = [var.name for var in mcmc_samples] - mcmc_samples = compile_rv_inplace( - [], - mcmc_samples, - mode="JAX", - )() + mcmc_samples = {} + for v in vars_to_sample: + fgraph = FunctionGraph(model.value_vars, [v], clone=False) + jax_fn = jax_funcify(fgraph) + result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0] + mcmc_samples[v.name] = result tic4 = pd.Timestamp.now() print("Transformation time = ", tic4 - tic3, file=sys.stdout) - posterior = {k: v for k, v in zip(mcmc_varnames, mcmc_samples)} + posterior = mcmc_samples az_trace = az.from_dict(posterior=posterior) return az_trace diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index e5180e19f8d..3fd04059c09 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -44,6 +44,23 @@ def test_transform_samples(): assert 1.5 < trace.posterior["sigma"].mean() < 2.5 +def test_deterministic_samples(): + aesara.config.on_opt_error = "raise" + np.random.seed(13244) + + obs = np.random.normal(10, 2, size=100) + obs_at = aesara.shared(obs, borrow=True, name="obs") + with pm.Model() as model: + a = pm.Uniform("a", -20, 20) + b = pm.Deterministic("b", a / 2.0) + c = pm.Normal("c", a, sigma=1.0, observed=obs_at) + + trace = sample_numpyro_nuts(chains=2, random_seed=1322, keep_untransformed=True) + + assert 8 < trace.posterior["a"].mean() < 11 + assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 2) + + def test_replace_shared_variables(): x = aesara.shared(5, name="shared_x")