-
-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Normal summary stats optimization
- Loading branch information
1 parent
3c19a5e
commit 493855f
Showing
9 changed files
with
166 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# ruff: noqa: F401 | ||
# Add rewrites to the optimization DBs | ||
import pymc_experimental.sampling.optimizations.summary_stats |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import pytensor.tensor as pt | ||
|
||
from pymc.distributions import Gamma, Normal | ||
from pymc.model.fgraph import ModelObservedRV, model_observed_rv | ||
from pytensor.graph.fg import FunctionGraph | ||
from pytensor.graph.rewriting.basic import node_rewriter | ||
|
||
from pymc_experimental.sampling.mcmc import posterior_optimization_db | ||
|
||
|
||
@node_rewriter(tracks=[ModelObservedRV])
def summary_stats_normal(fgraph: FunctionGraph, node):
    """Replace an i.i.d. Normal likelihood by likelihoods on its sufficient statistics.

    For ``y[i] ~ Normal(mu, sigma)`` with scalar ``mu`` and ``sigma``, the joint
    log-density of the observations equals (up to an additive constant) the sum of

    * ``mean(y) ~ Normal(mu, sigma / sqrt(N))`` and
    * ``var(y, ddof=1) ~ Gamma((N - 1) / 2, (N - 1) / (2 * sigma**2))``,

    as described in
    https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics

    Parameters
    ----------
    fgraph
        Function graph of the model being rewritten. It is mutated in place:
        two new observed outputs are added and ``node`` is removed.
    node
        A ``ModelObservedRV`` node whose inputs are ``[rv, data]``.

    Returns
    -------
    list or None
        ``None`` when the rewrite does not apply. Otherwise a one-element list
        holding a copy of the old output, returned only so that the rewrite is
        reported as applied.
    """
    [observed_rv] = node.outputs
    [rv, data] = node.inputs

    # Only Normal likelihoods are handled; a variable without an owner cannot be one
    if rv.owner is None or not isinstance(rv.owner.op, Normal):
        return None

    # Check the normal RV is not just a scalar
    if all(rv.type.broadcastable):
        return None

    # Check that the observed RV is not used anywhere else (like a Potential or Deterministic)
    # There should be only one use: as an "output"
    if len(fgraph.clients[observed_rv]) > 1:
        return None

    mu, sigma = rv.owner.op.dist_params(rv.owner)

    # Both mu and sigma must be scalar: the sufficient-statistics identity only holds
    # when every observation shares the same mean and standard deviation.
    # (The previous `and` only bailed out when *both* parameters were non-scalar.)
    if not (all(mu.type.broadcastable) and all(sigma.type.broadcastable)):
        return None

    # Check that mu and sigma are not used anywhere else
    # Note: This is too restrictive, it's fine if they're used in Deterministics!
    # There should only be two uses: as an "output" and as the param of the `rv`
    if len(fgraph.clients[mu]) > 2 or len(fgraph.clients[sigma]) > 2:
        return None

    # Remove expand_dims
    mu = mu.squeeze()
    sigma = sigma.squeeze()

    # Apply the rewrite
    mean_data = pt.mean(data)
    mean_data.name = None
    var_data = pt.var(data, ddof=1)
    var_data.name = None
    N = data.size
    sqrt_N = pt.sqrt(N)
    nm1_over2 = (N - 1) / 2

    # Sampling distribution of the sample mean
    observed_mean = model_observed_rv(
        Normal.dist(mu=mu, sigma=sigma / sqrt_N),
        mean_data,
    )
    observed_mean.name = f"{rv.name}_mean"

    # Sampling distribution of the unbiased sample variance
    observed_var = model_observed_rv(
        Gamma.dist(alpha=nm1_over2, beta=nm1_over2 / (sigma**2)),
        var_data,
    )
    observed_var.name = f"{rv.name}_var"

    fgraph.add_output(observed_mean, import_missing=True)
    fgraph.add_output(observed_var, import_missing=True)
    fgraph.remove_node(node)
    # Just so it shows in the profile for verbose=True,
    # It won't do anything because node is not in the fgraph anymore
    return [node.out.copy()]
|
||
|
||
# Make the rewrite available to the posterior optimization database,
# keyed by the function's own name and tagged "summary_stats".
posterior_optimization_db.register(
    summary_stats_normal.__name__,
    summary_stats_normal,
    "summary_stats",
)
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import numpy as np | ||
|
||
from pymc.distributions import HalfNormal, Normal | ||
from pymc.model.core import Model | ||
from pymc.sampling.mcmc import sample | ||
|
||
from pymc_experimental import opt_sample | ||
|
||
|
||
def test_sample_opt_summary_stats(capsys):
    """Check that `opt_sample` applies the summary-stats rewrite and matches `sample`.

    Both samplers receive the same keyword arguments, including the same
    `random_seed`, so the posterior-mean comparison does not depend on
    run-to-run Monte Carlo noise.
    """
    rng = np.random.default_rng(3)
    y_data = rng.normal(loc=1, scale=0.5, size=(1000,))

    with Model():
        mu = Normal("mu")
        sigma = HalfNormal("sigma")
        Normal("y", mu=mu, sigma=sigma, observed=y_data)

        sample_kwargs = dict(
            chains=1,
            tune=500,
            draws=500,
            compute_convergence_checks=False,
            progressbar=False,
            # Seed both runs identically; unseeded chains made the tight
            # tolerances below flaky.
            random_seed=3,
        )
        idata = sample(**sample_kwargs)
        opt_idata = opt_sample(**sample_kwargs, verbose=True)

    captured_out = capsys.readouterr().out
    assert "Applied optimization: summary_stats_normal 1x" in captured_out

    assert opt_idata.posterior.sizes["chain"] == 1
    assert opt_idata.posterior.sizes["draw"] == 500
    np.testing.assert_allclose(
        idata.posterior["mu"].mean(), opt_idata.posterior["mu"].mean(), rtol=1e-3
    )
    np.testing.assert_allclose(
        idata.posterior["sigma"].mean(), opt_idata.posterior["sigma"].mean(), rtol=1e-2
    )
    # NOTE(review): wall-clock comparison; may be sensitive to machine load.
    assert idata.sample_stats.sampling_time > opt_idata.sample_stats.sampling_time
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import numpy as np | ||
|
||
from pymc.distributions import HalfNormal, Normal | ||
from pymc.model.core import Model | ||
from pymc.model.fgraph import fgraph_from_model, model_from_fgraph | ||
from pytensor.graph.rewriting.basic import out2in | ||
|
||
from pymc_experimental.sampling.optimizations.summary_stats import summary_stats_normal | ||
|
||
|
||
def test_summary_stats_normal():
    """Verify the rewrite yields a model equivalent up to an additive logp constant."""
    rng = np.random.default_rng(3)
    observations = rng.normal(loc=1, scale=0.5, size=(1000,))

    with Model() as m:
        mu = Normal("mu")
        sigma = HalfNormal("sigma")
        y = Normal("y", mu=mu, sigma=sigma, observed=observations)

    # Original model: two free parameters, one observed likelihood
    assert len(m.free_RVs) == 2
    assert len(m.observed_RVs) == 1

    # Apply the rewrite on the model's function graph and rebuild a model from it
    fgraph, _ = fgraph_from_model(m)
    rewriter = out2in(summary_stats_normal)
    rewriter.apply(fgraph)
    new_m = model_from_fgraph(fgraph)

    # Same free parameters; the likelihood is now split into mean and variance terms
    assert len(new_m.free_RVs) == 2
    assert len(new_m.observed_RVs) == 2

    # Confirm equivalent (up to an additive normalization constant)
    m_logp = m.compile_logp()
    new_m_logp = new_m.compile_logp()

    def logp_gap(point):
        # Difference between the original and the rewritten model's logp
        return m_logp(point) - new_m_logp(point)

    ip = m.initial_point()
    gap_at_start = logp_gap(ip)

    # Move to a different point; the gap must not change
    ip["mu"] += 0.5
    ip["sigma_log__"] += 1.5
    gap_after_shift = logp_gap(ip)

    np.testing.assert_allclose(gap_at_start, gap_after_shift)

    # dlogp should be the same
    m_dlogp = m.compile_dlogp()
    new_m_dlogp = new_m.compile_dlogp()
    np.testing.assert_allclose(m_dlogp(ip), new_m_dlogp(ip))