From 493855fbbc5e74de0307b38d816deb444dd15e7c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 5 Dec 2024 14:57:26 +0100 Subject: [PATCH] Add Normal summary stats optimization --- pymc_experimental/sampling/__init__.py | 3 + pymc_experimental/sampling/mcmc.py | 2 +- .../sampling/optimizations/__init__.py | 0 .../sampling/optimizations/summary_stats.py | 79 +++++++++++++++++++ tests/sampling/__init__.py | 0 tests/sampling/mcmc/__init__.py | 0 tests/sampling/mcmc/test_mcmc.py | 36 +++++++++ tests/sampling/optimizations/__init__.py | 0 .../optimizations/test_summary_stats.py | 47 +++++++++++ 9 files changed, 166 insertions(+), 1 deletion(-) create mode 100644 pymc_experimental/sampling/optimizations/__init__.py create mode 100644 pymc_experimental/sampling/optimizations/summary_stats.py create mode 100644 tests/sampling/__init__.py create mode 100644 tests/sampling/mcmc/__init__.py create mode 100644 tests/sampling/mcmc/test_mcmc.py create mode 100644 tests/sampling/optimizations/__init__.py create mode 100644 tests/sampling/optimizations/test_summary_stats.py diff --git a/pymc_experimental/sampling/__init__.py b/pymc_experimental/sampling/__init__.py index e69de29b..53c2fc07 100644 --- a/pymc_experimental/sampling/__init__.py +++ b/pymc_experimental/sampling/__init__.py @@ -0,0 +1,3 @@ +# ruff: noqa: F401 +# Add rewrites to the optimization DBs +import pymc_experimental.sampling.optimizations.summary_stats diff --git a/pymc_experimental/sampling/mcmc.py b/pymc_experimental/sampling/mcmc.py index 76b32f6f..ec79143e 100644 --- a/pymc_experimental/sampling/mcmc.py +++ b/pymc_experimental/sampling/mcmc.py @@ -53,7 +53,7 @@ def opt_sample( fgraph, _ = fgraph_from_model(model) if rewriter is None: - rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=[])) + rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=["summary_stats"])) _, _, rewrite_counters, *_ = rewriter.rewrite(fgraph) if verbose: diff --git 
a/pymc_experimental/sampling/optimizations/__init__.py b/pymc_experimental/sampling/optimizations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pymc_experimental/sampling/optimizations/summary_stats.py b/pymc_experimental/sampling/optimizations/summary_stats.py new file mode 100644 index 00000000..c2a717ad --- /dev/null +++ b/pymc_experimental/sampling/optimizations/summary_stats.py @@ -0,0 +1,79 @@ +import pytensor.tensor as pt + +from pymc.distributions import Gamma, Normal +from pymc.model.fgraph import ModelObservedRV, model_observed_rv +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.rewriting.basic import node_rewriter + +from pymc_experimental.sampling.mcmc import posterior_optimization_db + + +@node_rewriter(tracks=[ModelObservedRV]) +def summary_stats_normal(fgraph: FunctionGraph, node): + """Applies the equivalence (up to a normalizing constant) described in: + +    https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics +    """ + [observed_rv] = node.outputs + [rv, data] = node.inputs + + if not isinstance(rv.owner.op, Normal): + return None + + # Check the normal RV is not just a scalar + if all(rv.type.broadcastable): + return None + + # Check that the observed RV is not used anywhere else (like a Potential or Deterministic) + # There should be only one use: as an "output" + if len(fgraph.clients[observed_rv]) > 1: + return None + + mu, sigma = rv.owner.op.dist_params(rv.owner) + + # Check if mu and sigma are scalar RVs (the sufficient-statistic rewrite is only valid + # when every observation shares the same scalar mu AND the same scalar sigma, so bail + # out if EITHER parameter is non-scalar) + if not all(mu.type.broadcastable) or not all(sigma.type.broadcastable): + return None + + # Check that mu and sigma are not used anywhere else + # Note: This is too restrictive, it's fine if they're used in Deterministics! 
+ # There should only be two uses: as an "output" and as the param of the `rv` + if len(fgraph.clients[mu]) > 2 or len(fgraph.clients[sigma]) > 2: + return None + + # Remove expand_dims + mu = mu.squeeze() + sigma = sigma.squeeze() + + # Apply the rewrite + mean_data = pt.mean(data) + mean_data.name = None + var_data = pt.var(data, ddof=1) + var_data.name = None + N = data.size + sqrt_N = pt.sqrt(N) + nm1_over2 = (N - 1) / 2 + + observed_mean = model_observed_rv( + Normal.dist(mu=mu, sigma=sigma / sqrt_N), + mean_data, + ) + observed_mean.name = f"{rv.name}_mean" + + observed_var = model_observed_rv( + Gamma.dist(alpha=nm1_over2, beta=nm1_over2 / (sigma**2)), + var_data, + ) + observed_var.name = f"{rv.name}_var" + + fgraph.add_output(observed_mean, import_missing=True) + fgraph.add_output(observed_var, import_missing=True) + fgraph.remove_node(node) + # Just so it shows in the profile for verbose=True, + # It won't do anything because node is not in the fgraph anymore + return [node.out.copy()] + + +posterior_optimization_db.register( + summary_stats_normal.__name__, summary_stats_normal, "summary_stats" +) diff --git a/tests/sampling/__init__.py b/tests/sampling/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/sampling/mcmc/__init__.py b/tests/sampling/mcmc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/sampling/mcmc/test_mcmc.py b/tests/sampling/mcmc/test_mcmc.py new file mode 100644 index 00000000..d9bc7936 --- /dev/null +++ b/tests/sampling/mcmc/test_mcmc.py @@ -0,0 +1,36 @@ +import numpy as np + +from pymc.distributions import HalfNormal, Normal +from pymc.model.core import Model +from pymc.sampling.mcmc import sample + +from pymc_experimental import opt_sample + + +def test_sample_opt_summary_stats(capsys): + rng = np.random.default_rng(3) + y_data = rng.normal(loc=1, scale=0.5, size=(1000,)) + + with Model() as m: + mu = Normal("mu") + sigma = HalfNormal("sigma") + y = Normal("y", mu=mu, sigma=sigma, 
observed=y_data) + + sample_kwargs = dict( + chains=1, tune=500, draws=500, compute_convergence_checks=False, progressbar=False + ) + idata = sample(**sample_kwargs) + opt_idata = opt_sample(**sample_kwargs, verbose=True) + + captured_out = capsys.readouterr().out + assert "Applied optimization: summary_stats_normal 1x" in captured_out + + assert opt_idata.posterior.sizes["chain"] == 1 + assert opt_idata.posterior.sizes["draw"] == 500 + np.testing.assert_allclose( + idata.posterior["mu"].mean(), opt_idata.posterior["mu"].mean(), rtol=1e-3 + ) + np.testing.assert_allclose( + idata.posterior["sigma"].mean(), opt_idata.posterior["sigma"].mean(), rtol=1e-2 + ) + assert idata.sample_stats.sampling_time > opt_idata.sample_stats.sampling_time diff --git a/tests/sampling/optimizations/__init__.py b/tests/sampling/optimizations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/sampling/optimizations/test_summary_stats.py b/tests/sampling/optimizations/test_summary_stats.py new file mode 100644 index 00000000..cc5e24d2 --- /dev/null +++ b/tests/sampling/optimizations/test_summary_stats.py @@ -0,0 +1,47 @@ +import numpy as np + +from pymc.distributions import HalfNormal, Normal +from pymc.model.core import Model +from pymc.model.fgraph import fgraph_from_model, model_from_fgraph +from pytensor.graph.rewriting.basic import out2in + +from pymc_experimental.sampling.optimizations.summary_stats import summary_stats_normal + + +def test_summary_stats_normal(): + rng = np.random.default_rng(3) + y_data = rng.normal(loc=1, scale=0.5, size=(1000,)) + + with Model() as m: + mu = Normal("mu") + sigma = HalfNormal("sigma") + y = Normal("y", mu=mu, sigma=sigma, observed=y_data) + + assert len(m.free_RVs) == 2 + assert len(m.observed_RVs) == 1 + + fgraph, _ = fgraph_from_model(m) + summary_stats_rewrite = out2in(summary_stats_normal) + _ = summary_stats_rewrite.apply(fgraph) + new_m = model_from_fgraph(fgraph) + + assert len(new_m.free_RVs) == 2 + assert 
len(new_m.observed_RVs) == 2 + + # Confirm equivalent (up to an additive normalization constant) + m_logp = m.compile_logp() + new_m_logp = new_m.compile_logp() + + ip = m.initial_point() + first_logp_diff = m_logp(ip) - new_m_logp(ip) + + ip["mu"] += 0.5 + ip["sigma_log__"] += 1.5 + second_logp_diff = m_logp(ip) - new_m_logp(ip) + + np.testing.assert_allclose(first_logp_diff, second_logp_diff) + + # dlogp should be the same + m_dlogp = m.compile_dlogp() + new_m_dlogp = new_m.compile_dlogp() + np.testing.assert_allclose(m_dlogp(ip), new_m_dlogp(ip))