From 807e06866e477dc713ff95cc00e3dafa6d110bbd Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 14 Nov 2021 18:57:31 +0000 Subject: [PATCH 1/6] Adding Deterministic for sampling_numpyro --- pymc/sampling_jax.py | 61 +++++++++++++++++---------------- pymc/tests/test_sampling_jax.py | 17 +++++++++ 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 795050e5aa6..48dde69ad7a 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -27,6 +27,7 @@ from pymc import Model, modelcontext from pymc.aesaraf import compile_rv_inplace +from pymc.util import get_default_varnames warnings.warn("This module is experimental.") @@ -101,6 +102,7 @@ def sample_numpyro_nuts( target_accept=0.8, random_seed=10, model=None, + var_names=None, progress_bar=True, keep_untransformed=False, ): @@ -108,6 +110,11 @@ def sample_numpyro_nuts( model = modelcontext(model) + if var_names is None: + var_names = model.unobserved_value_vars + + vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed)) + tic1 = pd.Timestamp.now() print("Compiling...", file=sys.stdout) @@ -116,6 +123,7 @@ def sample_numpyro_nuts( init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state) logp_fn = get_jaxified_logp(model) + fn = model.fastfn(vars_to_sample) nuts_kernel = NUTS( potential_fn=logp_fn, @@ -143,45 +151,40 @@ def sample_numpyro_nuts( seed = jax.random.PRNGKey(random_seed) map_seed = jax.random.split(seed, chains) - pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",)) + if chains == 1: + pmap_numpyro.run(seed, init_params=init_state, extra_fields=("num_steps",)) + else: + pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",)) + raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True) tic3 = pd.Timestamp.now() print("Sampling time = ", tic3 - tic2, file=sys.stdout) print("Transforming variables...", file=sys.stdout) - mcmc_samples = [] - for i, (value_var, raw_samples) in enumerate(zip(model.value_vars, raw_mcmc_samples)): - raw_samples = at.constant(np.asarray(raw_samples)) - - rv = model.values_to_rvs[value_var] - transform = getattr(value_var.tag, "transform", None) - - if transform is not None: - # TODO: This will fail when the transformation depends on another variable - # such as in interval transform with RVs as edges - trans_samples = transform.backward(raw_samples, *rv.owner.inputs) - trans_samples.name = rv.name - mcmc_samples.append(trans_samples) - - if keep_untransformed: - raw_samples.name = value_var.name - mcmc_samples.append(raw_samples) - else: - raw_samples.name = rv.name - mcmc_samples.append(raw_samples) - - mcmc_varnames = [var.name for var in mcmc_samples] - mcmc_samples = compile_rv_inplace( - [], - mcmc_samples, - mode="JAX", - )() + mcmc_samples = {} + for v in vars_to_sample: + mcmc_samples[v.name] = [] + + for i in range(draws): + for c in range(chains): + draw = dict( + (value_var.name, raw_samples[c, i]) + for value_var, raw_samples in zip(model.value_vars, raw_mcmc_samples) + ) + sample = fn(draw) + for vi, v in enumerate(vars_to_sample): + mcmc_samples[v.name].append(sample[vi]) + + for v in vars_to_sample: + mcmc_samples[v.name] = np.array(mcmc_samples[v.name]).reshape( + (chains, draws) + mcmc_samples[v.name][-1].shape + ) tic4 = pd.Timestamp.now() print("Transformation time = ", tic4 - tic3, file=sys.stdout) - posterior = {k: v for k, v in zip(mcmc_varnames, mcmc_samples)} + posterior = mcmc_samples az_trace = az.from_dict(posterior=posterior) return az_trace diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index e5180e19f8d..b23b1d801e4 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -44,6 +44,23 @@ def test_transform_samples(): assert 1.5 < trace.posterior["sigma"].mean() < 2.5 +def test_deterministic_samples(): + aesara.config.on_opt_error = "raise" + np.random.seed(13244) + + obs = np.random.normal(10, 2, size=100) + obs_at = aesara.shared(obs, borrow=True, name="obs") + with pm.Model() as model: + a = pm.Normal("a", 0, 1) + b = pm.Deterministic("b", a / 2.0) + c = pm.Normal("c", a, sigma=1.0, observed=obs_at) + + trace = sample_numpyro_nuts(chains=2, random_seed=1322, keep_untransformed=True) + + assert 8 < trace.posterior["a"].mean() < 11 + assert 4 < trace.posterior["b"].mean() < 6 + + def test_replace_shared_variables(): x = aesara.shared(5, name="shared_x") From 743d15a69c972180ecf581720d8d6e3985c27b21 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 14 Nov 2021 19:13:44 +0000 Subject: [PATCH 2/6] Obey precommit hook --- pymc/sampling_jax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 48dde69ad7a..9fce39dbde3 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -168,10 +168,10 @@ def sample_numpyro_nuts( for i in range(draws): for c in range(chains): - draw = dict( - (value_var.name, raw_samples[c, i]) + draw = { + value_var.name: raw_samples[c, i] for value_var, raw_samples in zip(model.value_vars, raw_mcmc_samples) - ) + } sample = fn(draw) for vi, v in enumerate(vars_to_sample): mcmc_samples[v.name].append(sample[vi]) From 31db9e62ec094d553c411bb506cb9a12115726ae Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Mon, 15 Nov 2021 17:02:14 +0000 Subject: [PATCH 3/6] Extract values efficiently --- pymc/sampling_jax.py | 25 +++++++------------------ pymc/tests/test_sampling_jax.py | 2 +- 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 9fce39dbde3..78a3176a7ba 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -26,7 +26,7 @@ from aesara.link.jax.dispatch import jax_funcify from pymc import Model, modelcontext -from pymc.aesaraf import compile_rv_inplace +from pymc.aesaraf import compile_rv_inplace, inputvars from pymc.util import get_default_varnames warnings.warn("This module is experimental.") @@ -114,6 +114,8 @@ def sample_numpyro_nuts( var_names = model.unobserved_value_vars vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed)) + inputs = [model.rvs_to_values[i] for i in model.free_RVs] + tic1 = pd.Timestamp.now() print("Compiling...", file=sys.stdout) @@ -123,7 +125,6 @@ def sample_numpyro_nuts( init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state) logp_fn = get_jaxified_logp(model) - fn = model.fastfn(vars_to_sample) nuts_kernel = NUTS( potential_fn=logp_fn, @@ -164,22 +165,10 @@ def sample_numpyro_nuts( print("Transforming variables...", file=sys.stdout) mcmc_samples = {} for v in vars_to_sample: - mcmc_samples[v.name] = [] - - for i in range(draws): - for c in range(chains): - draw = { - value_var.name: raw_samples[c, i] - for value_var, raw_samples in zip(model.value_vars, raw_mcmc_samples) - } - sample = fn(draw) - for vi, v in enumerate(vars_to_sample): - mcmc_samples[v.name].append(sample[vi]) - - for v in vars_to_sample: - mcmc_samples[v.name] = np.array(mcmc_samples[v.name]).reshape( - (chains, draws) + mcmc_samples[v.name][-1].shape - ) + fgraph = FunctionGraph(inputs, [v], clone=False) + jax_fn = jax_funcify(fgraph) + result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0] + mcmc_samples[v.name] = result tic4 = pd.Timestamp.now() print("Transformation time = ", tic4 - tic3, file=sys.stdout) diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index b23b1d801e4..d9eb8ad1092 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -58,7 +58,7 @@ def test_deterministic_samples(): trace = sample_numpyro_nuts(chains=2, random_seed=1322, keep_untransformed=True) assert 8 < trace.posterior["a"].mean() < 11 - assert 4 < trace.posterior["b"].mean() < 6 + assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 2) def test_replace_shared_variables(): From 5fab7a7d014c0b079cc7dcc96e0bd6c1fe0fc887 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Mon, 15 Nov 2021 17:11:26 +0000 Subject: [PATCH 4/6] Change Normal to Uniform to test combinations of transforms and deterministic --- pymc/tests/test_sampling_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index d9eb8ad1092..3fd04059c09 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -51,7 +51,7 @@ def test_deterministic_samples(): obs = np.random.normal(10, 2, size=100) obs_at = aesara.shared(obs, borrow=True, name="obs") with pm.Model() as model: - a = pm.Normal("a", 0, 1) + a = pm.Uniform("a", -20, 20) b = pm.Deterministic("b", a / 2.0) c = pm.Normal("c", a, sigma=1.0, observed=obs_at) From 13211abc570b19cede70808833c6a905a2926e94 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Mon, 15 Nov 2021 17:13:58 +0000 Subject: [PATCH 5/6] lint fix --- pymc/sampling_jax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 78a3176a7ba..5d318a33d05 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -116,7 +116,6 @@ def sample_numpyro_nuts( vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed)) inputs = [model.rvs_to_values[i] for i in model.free_RVs] - tic1 = pd.Timestamp.now() print("Compiling...", file=sys.stdout) From 5ead708444c6a42012aa69b43649a145321b08ef Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Mon, 15 Nov 2021 17:28:01 +0000 Subject: [PATCH 6/6] Use value_vars directly --- pymc/sampling_jax.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 5d318a33d05..6fdbd64e730 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -114,7 +114,6 @@ def sample_numpyro_nuts( var_names = model.unobserved_value_vars vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed)) - inputs = [model.rvs_to_values[i] for i in model.free_RVs] tic1 = pd.Timestamp.now() print("Compiling...", file=sys.stdout) @@ -164,7 +163,7 @@ def sample_numpyro_nuts( print("Transforming variables...", file=sys.stdout) mcmc_samples = {} for v in vars_to_sample: - fgraph = FunctionGraph(inputs, [v], clone=False) + fgraph = FunctionGraph(model.value_vars, [v], clone=False) jax_fn = jax_funcify(fgraph) result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0] mcmc_samples[v.name] = result