-
-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Beta-Binomial conjugacy optimization
- Loading branch information
1 parent
e051965
commit 83953eb
Showing
9 changed files
with
365 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
# Add rewrites to the optimization DBs | ||
import pymc_experimental.sampling.optimizations.conjugacy | ||
import pymc_experimental.sampling.optimizations.summary_stats |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
from typing import Sequence | ||
|
||
from pymc import STEP_METHODS | ||
from pytensor.tensor.random.type import RandomGeneratorType | ||
|
||
from pytensor.compile.builders import OpFromGraph | ||
|
||
from pymc_experimental.sampling.mcmc import posterior_optimization_db | ||
from pymc_experimental.sampling.optimizations.conjugate_sampler import ConjugateRV, ConjugateRVSampler | ||
|
||
STEP_METHODS.append(ConjugateRVSampler) | ||
|
||
from pytensor.graph.fg import Output | ||
from pytensor.tensor.elemwise import DimShuffle | ||
from pymc.model.fgraph import model_free_rv, ModelValuedVar | ||
|
||
|
||
from pytensor.graph.basic import Variable | ||
from pytensor.graph.fg import FunctionGraph | ||
from pytensor.graph.rewriting.basic import node_rewriter | ||
from pymc.model.fgraph import ModelFreeRV | ||
from pymc.distributions import Beta, Binomial | ||
from pymc.pytensorf import collect_default_updates | ||
|
||
|
||
def get_model_var_of_rv(fgraph: FunctionGraph, rv: Variable) -> Variable: | ||
"""Return the Model dummy var that wraps the RV""" | ||
for client, _ in fgraph.clients[rv]: | ||
if isinstance(client.op, ModelValuedVar): | ||
return client.outputs[0] | ||
|
||
|
||
def get_dist_params(rv: Variable) -> tuple[Variable]: | ||
return rv.owner.op.dist_params(rv.owner) | ||
|
||
|
||
def rv_used_by(fgraph: FunctionGraph, rv: Variable, used_by_type: type, used_as_arg_idx: int | Sequence[int], strict: bool = True) -> list[Variable]: | ||
"""Return the RVs that use `rv` as an argument in an operation of type `used_by_type`. | ||
RV may be used directly or broadcasted before being used. | ||
Parameters | ||
---------- | ||
fgraph : FunctionGraph | ||
The function graph containing the RVs | ||
rv : Variable | ||
The RV to check for uses. | ||
used_by_type : type | ||
The type of operation that may use the RV. | ||
used_as_arg_idx : int | Sequence[int] | ||
The index of the RV in the operation's inputs. | ||
strict : bool, default=True | ||
If True, return no results when the RV is used in an unrecognized way. | ||
""" | ||
if isinstance(used_as_arg_idx, int): | ||
used_as_arg_idx = (used_as_arg_idx,) | ||
|
||
clients = fgraph.clients | ||
used_by : list[Variable] = [] | ||
for client, inp_idx in clients[rv]: | ||
if isinstance(client.op, Output): | ||
continue | ||
|
||
if isinstance(client.op, used_by_type) and inp_idx in used_as_arg_idx: | ||
# RV is directly used by the RV type | ||
used_by.append(client.default_output()) | ||
|
||
elif isinstance(client.op, DimShuffle) and client.op.is_left_expand_dims: | ||
for sub_client, sub_inp_idx in clients[client.outputs[0]]: | ||
if isinstance(sub_client.op, used_by_type) and sub_inp_idx in used_as_arg_idx: | ||
# RV is broadcasted and then used by the RV type | ||
used_by.append(sub_client.default_output()) | ||
elif strict: | ||
# Some other unrecognized use, bail out | ||
return [] | ||
elif strict: | ||
# Some other unrecognized use, bail out | ||
return [] | ||
|
||
return used_by | ||
|
||
|
||
def wrap_rv_and_conjugate_rv(fgraph: FunctionGraph, rv: Variable, conjugate_rv: Variable, inputs: Sequence[Variable]) -> Variable: | ||
"""Wrap the RV and its conjugate posterior RV in a ConjugateRV node. | ||
Also takes care of handling the random number generators used in the conjugate posterior. | ||
""" | ||
rngs, next_rngs = zip(*collect_default_updates(conjugate_rv, inputs=[rv, *inputs]).items()) | ||
for rng in rngs: | ||
if rng not in fgraph.inputs: | ||
fgraph.add_input(rng) | ||
conjugate_op = ConjugateRV(inputs=[rv, *inputs, *rngs], outputs=[rv, conjugate_rv, *next_rngs]) | ||
return conjugate_op(rv, *inputs, *rngs)[0] | ||
|
||
|
||
def create_untransformed_free_rv(fgraph: FunctionGraph, rv: Variable, name: str, dims: Sequence[str | Variable]) -> Variable: | ||
"""Create a model FreeRV without transform.""" | ||
transform = None | ||
value = rv.type(name=name) | ||
fgraph.add_input(value) | ||
free_rv = model_free_rv(rv, value, transform, *dims) | ||
free_rv.name = name | ||
return free_rv | ||
|
||
|
||
@node_rewriter(tracks=[ModelFreeRV]) | ||
def beta_binomial_conjugacy(fgraph: FunctionGraph, node): | ||
"""This applies the equivalence (up to a normalizing constant) described in: | ||
https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics | ||
""" | ||
[beta_free_rv] = node.outputs | ||
beta_rv, beta_value, *beta_dims = node.inputs | ||
|
||
if not isinstance(beta_rv.owner.op, Beta): | ||
return None | ||
|
||
p_arg_idx = 3 # inputs to Binomial are (rng, size, n, p) | ||
binomial_rvs = rv_used_by(fgraph, beta_free_rv, Binomial, p_arg_idx) | ||
|
||
if len(binomial_rvs) != 1: | ||
# Question: Can we apply conjugacy when RV is used by more than one binomial? | ||
return None | ||
|
||
[binomial_rv] = binomial_rvs | ||
|
||
binomial_model_var = get_model_var_of_rv(fgraph, binomial_rv) | ||
if binomial_model_var is None: | ||
return None | ||
|
||
# We want to replace free_rv by ConjugateRV()->(free_rv, conjugate_posterior_rv) | ||
a, b = get_dist_params(beta_rv) | ||
n, _ = get_dist_params(binomial_rv) | ||
|
||
# Use value of y in new graph to avoid circularity | ||
y = binomial_model_var.owner.inputs[1] | ||
|
||
conjugate_a = a + y | ||
conjugate_b = b + (n - y) | ||
extra_dims = range(binomial_rv.type.ndim - beta_rv.type.ndim) | ||
if extra_dims: | ||
conjugate_a = conjugate_a.sum(extra_dims) | ||
conjugate_b = conjugate_b.sum(extra_dims) | ||
conjugate_beta_rv = Beta.dist(conjugate_a, conjugate_b) | ||
|
||
new_beta_rv = wrap_rv_and_conjugate_rv(fgraph, beta_rv, conjugate_beta_rv, [a, b, n, y]) | ||
new_beta_free_rv = create_untransformed_free_rv(fgraph, new_beta_rv, beta_free_rv.name, beta_dims) | ||
return [new_beta_free_rv] | ||
|
||
|
||
posterior_optimization_db.register( | ||
beta_binomial_conjugacy.__name__, | ||
beta_binomial_conjugacy, | ||
"conjugacy" | ||
) |
106 changes: 106 additions & 0 deletions
106
pymc_experimental/sampling/optimizations/conjugate_sampler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import numpy as np | ||
|
||
from pymc_experimental.utils.ofg import inline_ofg_outputs | ||
from pytensor.compile.builders import OpFromGraph | ||
from pymc.logprob.abstract import MeasurableOp, _logprob | ||
from pymc.distributions.distribution import _support_point | ||
from pymc.step_methods.compound import BlockedStep, StepMethodState, Competence | ||
from pymc.model.core import modelcontext | ||
from pymc.util import get_value_vars_from_user_vars | ||
from pymc.pytensorf import compile_pymc | ||
from pytensor import shared | ||
from pytensor.tensor.random.type import RandomGeneratorType | ||
from pytensor.link.jax.linker import JAXLinker | ||
from pymc.initial_point import PointType | ||
|
||
class ConjugateRV(OpFromGraph, MeasurableOp): | ||
"""Wrapper for ConjugateRVs, that outputs the original RV and the conjugate posterior expression. | ||
For partial step samplers to work, the logp and initial point correspond to the original RV | ||
while the variable itself is sampled by default by the `ConjugateRVSampler` by evaluating directly the | ||
conjugate posterior expression (i.e., taking forward random draws). | ||
""" | ||
|
||
|
||
@_logprob.register(ConjugateRV) | ||
def conjugate_rv_logp(op, values, rv, *params, **kwargs): | ||
# Logp is the same as the original RV | ||
return _logprob(rv.owner.op, values, *rv.owner.inputs) | ||
|
||
|
||
@_support_point.register(ConjugateRV) | ||
def conjugate_rv_support_point(op, conjugate_rv, rv, *params): | ||
# Support point is the same as the original RV | ||
return _support_point(rv.owner.op, rv, *rv.owner.inputs) | ||
|
||
|
||
class ConjugateRVSampler(BlockedStep): | ||
name = "conjugate_rv_sampler" | ||
_state_class = StepMethodState | ||
|
||
def __init__(self, vars, model=None, rng=None, compile_kwargs: dict | None = None, **kwargs): | ||
if len(vars) != 1: | ||
raise ValueError("ConjugateRVSampler can only be assigned to one variable at a time") | ||
|
||
model = modelcontext(model) | ||
[value] = get_value_vars_from_user_vars(vars, model=model) | ||
rv = model.values_to_rvs[value] | ||
self.vars = (value,) | ||
self.rv_name = value.name | ||
|
||
if model.rvs_to_transforms[rv] is not None: | ||
raise ValueError("Variable assigned to ConjugateRVSampler cannot be transformed") | ||
|
||
rv_and_posterior_rv_node = rv.owner | ||
op = rv_and_posterior_rv_node.op | ||
if not isinstance(op, ConjugateRV): | ||
raise ValueError("Variable must be a ConjugateRV") | ||
|
||
# Replace RVs in inputs of rv_posterior_rv_node by the corresponding value variables | ||
value_inputs = model.replace_rvs_by_values( | ||
[rv_and_posterior_rv_node.outputs[1]], | ||
)[0].owner.inputs | ||
# Inline the ConjugateRV graph to only compile `posterior_rv` | ||
_, posterior_rv, *_ = inline_ofg_outputs(op, value_inputs) | ||
|
||
if compile_kwargs is None: | ||
compile_kwargs = {} | ||
self.posterior_fn = compile_pymc( | ||
model.value_vars, | ||
posterior_rv, | ||
random_seed=rng, | ||
on_unused_input="ignore", | ||
**compile_kwargs, | ||
) | ||
self.posterior_fn.trust_input = True | ||
if isinstance(self.posterior_fn.maker.linker, JAXLinker): | ||
# Reseeding RVs in JAX backend requires a different logic, becuase the SharedVariables | ||
# used internally are not the ones that `function.get_shared()` returns. | ||
raise ValueError("ConjugateRVSampler is not compatible with JAX backend") | ||
|
||
def set_rng(self, rng: np.random.Generator): | ||
# Copy the function and replace any shared RNGs | ||
# This is needed so that it can work correctly with multiple traces | ||
# This will be costly if set_rng is called too often! | ||
shared_rngs = [ | ||
var for var in self.posterior_fn.get_shared() if isinstance(var.type, RandomGeneratorType) | ||
] | ||
n_shared_rngs = len(shared_rngs) | ||
swap = { | ||
old_shared_rng: shared(rng, borrow=True) | ||
for old_shared_rng, rng in zip(shared_rngs, rng.spawn(n_shared_rngs), strict=True) | ||
} | ||
self.posterior_fn = self.posterior_fn.copy(swap=swap) | ||
|
||
def step(self, point: PointType) -> tuple[PointType, list]: | ||
new_point = point.copy() | ||
new_point[self.rv_name] = self.posterior_fn(**point) | ||
return new_point, [] | ||
|
||
@staticmethod | ||
def competence(var, has_grad): | ||
"""BinaryMetropolis is only suitable for Bernoulli and Categorical variables with k=2.""" | ||
if isinstance(var.owner.op, ConjugateRV): | ||
return Competence.IDEAL | ||
|
||
return Competence.INCOMPATIBLE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from pytensor.graph.basic import Variable | ||
from pytensor.graph.replace import clone_replace | ||
from pytensor.compile.builders import OpFromGraph | ||
from typing import Sequence | ||
|
||
|
||
def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]: | ||
"""Inline the inner graph (outputs) of an OpFromGraph Op. | ||
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps" | ||
the inner graph. | ||
""" | ||
return clone_replace( | ||
op.inner_outputs, | ||
replace=tuple(zip(op.inner_inputs, inputs)), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.