Skip to content

Commit

Permalink
Add Normal summary stats optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 5, 2024
1 parent 3c19a5e commit 493855f
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 1 deletion.
3 changes: 3 additions & 0 deletions pymc_experimental/sampling/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# ruff: noqa: F401
# Add rewrites to the optimization DBs
import pymc_experimental.sampling.optimizations.summary_stats
2 changes: 1 addition & 1 deletion pymc_experimental/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def opt_sample(
fgraph, _ = fgraph_from_model(model)

if rewriter is None:
rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=[]))
rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=["summary_stats"]))
_, _, rewrite_counters, *_ = rewriter.rewrite(fgraph)

if verbose:
Expand Down
Empty file.
79 changes: 79 additions & 0 deletions pymc_experimental/sampling/optimizations/summary_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pytensor.tensor as pt

from pymc.distributions import Gamma, Normal
from pymc.model.fgraph import ModelObservedRV, model_observed_rv
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter

from pymc_experimental.sampling.mcmc import posterior_optimization_db


@node_rewriter(tracks=[ModelObservedRV])
def summary_stats_normal(fgraph: FunctionGraph, node):
"""Applies the equivalence (up to a normalizing constant) described in:
https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics
"""
[observed_rv] = node.outputs
[rv, data] = node.inputs

if not isinstance(rv.owner.op, Normal):
return None

# Check the normal RV is not just a scalar
if all(rv.type.broadcastable):
return None

# Check that the observed RV is not used anywhere else (like a Potential or Deterministic)
# There should be only one use: as an "output"
if len(fgraph.clients[observed_rv]) > 1:
return None

mu, sigma = rv.owner.op.dist_params(rv.owner)

# Check if mu and sigma are scalar RVs
if not all(mu.type.broadcastable) and not all(sigma.type.broadcastable):
return None

# Check that mu and sigma are not used anywhere else
# Note: This is too restrictive, it's fine if they're used in Deterministics!
# There should only be two uses: as an "output" and as the param of the `rv`
if len(fgraph.clients[mu]) > 2 or len(fgraph.clients[sigma]) > 2:
return None

# Remove expand_dims
mu = mu.squeeze()
sigma = sigma.squeeze()

# Apply the rewrite
mean_data = pt.mean(data)
mean_data.name = None
var_data = pt.var(data, ddof=1)
var_data.name = None
N = data.size
sqrt_N = pt.sqrt(N)
nm1_over2 = (N - 1) / 2

observed_mean = model_observed_rv(
Normal.dist(mu=mu, sigma=sigma / sqrt_N),
mean_data,
)
observed_mean.name = f"{rv.name}_mean"

observed_var = model_observed_rv(
Gamma.dist(alpha=nm1_over2, beta=nm1_over2 / (sigma**2)),
var_data,
)
observed_var.name = f"{rv.name}_var"

fgraph.add_output(observed_mean, import_missing=True)
fgraph.add_output(observed_var, import_missing=True)
fgraph.remove_node(node)
# Just so it shows in the profile for verbose=True,
# It won't do anything because node is not in the fgraph anymore
return [node.out.copy()]


posterior_optimization_db.register(
summary_stats_normal.__name__, summary_stats_normal, "summary_stats"
)
Empty file added tests/sampling/__init__.py
Empty file.
Empty file added tests/sampling/mcmc/__init__.py
Empty file.
36 changes: 36 additions & 0 deletions tests/sampling/mcmc/test_mcmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np

from pymc.distributions import HalfNormal, Normal
from pymc.model.core import Model
from pymc.sampling.mcmc import sample

from pymc_experimental import opt_sample


def test_sample_opt_summary_stats(capsys):
rng = np.random.default_rng(3)
y_data = rng.normal(loc=1, scale=0.5, size=(1000,))

with Model() as m:
mu = Normal("mu")
sigma = HalfNormal("sigma")
y = Normal("y", mu=mu, sigma=sigma, observed=y_data)

sample_kwargs = dict(
chains=1, tune=500, draws=500, compute_convergence_checks=False, progressbar=False
)
idata = sample(**sample_kwargs)
opt_idata = opt_sample(**sample_kwargs, verbose=True)

captured_out = capsys.readouterr().out
assert "Applied optimization: summary_stats_normal 1x" in captured_out

assert opt_idata.posterior.sizes["chain"] == 1
assert opt_idata.posterior.sizes["draw"] == 500
np.testing.assert_allclose(
idata.posterior["mu"].mean(), opt_idata.posterior["mu"].mean(), rtol=1e-3
)
np.testing.assert_allclose(
idata.posterior["sigma"].mean(), opt_idata.posterior["sigma"].mean(), rtol=1e-2
)
assert idata.sample_stats.sampling_time > opt_idata.sample_stats.sampling_time
Empty file.
47 changes: 47 additions & 0 deletions tests/sampling/optimizations/test_summary_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np

from pymc.distributions import HalfNormal, Normal
from pymc.model.core import Model
from pymc.model.fgraph import fgraph_from_model, model_from_fgraph
from pytensor.graph.rewriting.basic import out2in

from pymc_experimental.sampling.optimizations.summary_stats import summary_stats_normal


def test_summary_stats_normal():
rng = np.random.default_rng(3)
y_data = rng.normal(loc=1, scale=0.5, size=(1000,))

with Model() as m:
mu = Normal("mu")
sigma = HalfNormal("sigma")
y = Normal("y", mu=mu, sigma=sigma, observed=y_data)

assert len(m.free_RVs) == 2
assert len(m.observed_RVs) == 1

fgraph, _ = fgraph_from_model(m)
summary_stats_rewrite = out2in(summary_stats_normal)
_ = summary_stats_rewrite.apply(fgraph)
new_m = model_from_fgraph(fgraph)

assert len(new_m.free_RVs) == 2
assert len(new_m.observed_RVs) == 2

# Confirm equivalent (up to an additive normalization constant)
m_logp = m.compile_logp()
new_m_logp = new_m.compile_logp()

ip = m.initial_point()
first_logp_diff = m_logp(ip) - new_m_logp(ip)

ip["mu"] += 0.5
ip["sigma_log__"] += 1.5
second_logp_diff = m_logp(ip) - new_m_logp(ip)

np.testing.assert_allclose(first_logp_diff, second_logp_diff)

# dlogp should be the same
m_dlogp = m.compile_dlogp()
new_m_dlogp = new_m.compile_dlogp()
np.testing.assert_allclose(m_dlogp(ip), new_m_dlogp(ip))

0 comments on commit 493855f

Please sign in to comment.