Skip to content

Commit

Permalink
Add Beta-Binomial conjugacy optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 7, 2024
1 parent 1469915 commit a5c036a
Show file tree
Hide file tree
Showing 8 changed files with 485 additions and 13 deletions.
14 changes: 1 addition & 13 deletions pymc_experimental/model/marginal/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pymc.logprob.abstract import MeasurableOp, _logprob
from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import constant_fold
from pytensor import Variable
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.mode import Mode
from pytensor.graph import Op, vectorize_graph
Expand All @@ -17,6 +16,7 @@
from pytensor.tensor import TensorVariable

from pymc_experimental.distributions import DiscreteMarkovChain
from pymc_experimental.utils.ofg import inline_ofg_outputs


class MarginalRV(OpFromGraph, MeasurableOp):
Expand Down Expand Up @@ -126,18 +126,6 @@ def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> Tens
return logp.transpose(*dims_alignment)


def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
"""Inline the inner graph (outputs) of an OpFromGraph Op.
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
the inner graph.
"""
return clone_replace(
op.inner_outputs,
replace=tuple(zip(op.inner_inputs, inputs)),
)


DUMMY_ZERO = pt.constant(0, name="dummy_zero")


Expand Down
38 changes: 38 additions & 0 deletions pymc_experimental/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,44 @@ def opt_sample(
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)
idata = pmx.opt_sample(verbose=True)
# Applied optimization: beta_binomial_conjugacy 1x
# ConjugateRVSampler: [p]
You can control which optimizations are applied using the `include` and `exclude` arguments:
.. code:: python
import pymc as pm
import pymc_experimental as pmx
with pm.Model() as m:
p = pm.Beta("p", 1, 1, shape=(1000,))
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)
idata = pmx.opt_sample(exclude="conjugacy", verbose=True)
# No optimizations applied
# NUTS: [p]
.. code:: python
import pymc as pm
import pymc_experimental as pmx
with pm.Model() as m:
a = pm.InverseGamma("a", 1, 1)
b = pm.InverseGamma("b", 1, 1)
p = pm.Beta("p", a, b, shape=(1000,))
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)
# By default, the conjugacy of p will not be applied because it depends on other free variables
idata = pmx.opt_sample(include="conjugacy-eager", verbose=True)
# Applied optimization: beta_binomial_conjugacy_eager 1x
# CompoundStep
# >NUTS: [a, b]
# >ConjugateRVSampler: [p]
"""
if kwargs.get("step", None) is not None:
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions pymc_experimental/sampling/optimizations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# ruff: noqa: F401
# Add rewrites to the optimization DBs
import pymc_experimental.sampling.optimizations.conjugacy
import pymc_experimental.sampling.optimizations.summary_stats

from pymc_experimental.sampling.optimizations.optimize import (
Expand Down
199 changes: 199 additions & 0 deletions pymc_experimental/sampling/optimizations/conjugacy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
from collections.abc import Sequence
from functools import partial

from pymc.distributions import Beta, Binomial
from pymc.model.fgraph import ModelFreeRV, ModelValuedVar, model_free_rv
from pymc.pytensorf import collect_default_updates
from pytensor.graph.basic import Variable, ancestors
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.subtensor import _sum_grad_over_bcasted_dims as sum_bcasted_dims

from pymc_experimental.sampling.optimizations.conjugate_sampler import (
ConjugateRV,
)
from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db


def register_conjugacy_rewrites_variants(rewrite_fn, tracks=(ModelFreeRV,)):
"""Register a rewrite function and its force variant in the posterior optimization DB."""
name = rewrite_fn.__name__

rewrite_fn_default = partial(rewrite_fn, eager=False)
rewrite_fn_default.__name__ = name
rewrite_default = node_rewriter(tracks=tracks)(rewrite_fn_default)

rewrite_fn_eager = partial(rewrite_fn, eager=True)
rewrite_fn_eager.__name__ = f"{name}_eager"
rewrite_eager = node_rewriter(tracks=tracks)(rewrite_fn_eager)

posterior_optimization_db.register(
rewrite_default.__name__,
rewrite_default,
"default",
"conjugacy",
)

posterior_optimization_db.register(
rewrite_eager.__name__,
rewrite_eager,
"non-default",
"conjugacy-eager",
)

return rewrite_default, rewrite_eager


def has_free_rv_ancestor(vars: Variable | Sequence[Variable]) -> bool:
"""Return True if any of the variables have a model variable as an ancestor."""
if not isinstance(vars, Sequence):
vars = (vars,)

# TODO: It should stop at observed RVs, it doesn't matter if they have a free RV above
# Did not implement due to laziness and it being a rare case
return any(
var.owner is not None and isinstance(var.owner.op, ModelFreeRV) for var in ancestors(vars)
)


def get_model_var_of_rv(fgraph: FunctionGraph, rv: Variable) -> Variable:
"""Return the Model dummy var that wraps the RV"""
for client, _ in fgraph.clients[rv]:
if isinstance(client.op, ModelValuedVar):
return client.outputs[0]


def get_dist_params(rv: Variable) -> tuple[Variable]:
return rv.owner.op.dist_params(rv.owner)


def rv_used_by(
fgraph: FunctionGraph,
rv: Variable,
used_by_type: type,
used_as_arg_idx: int | Sequence[int],
strict: bool = True,
) -> list[Variable]:
"""Return the RVs that use `rv` as an argument in an operation of type `used_by_type`.
RV may be used directly or broadcasted before being used.
Parameters
----------
fgraph : FunctionGraph
The function graph containing the RVs
rv : Variable
The RV to check for uses.
used_by_type : type
The type of operation that may use the RV.
used_as_arg_idx : int | Sequence[int]
The index of the RV in the operation's inputs.
strict : bool, default=True
If True, return no results when the RV is used in an unrecognized way.
"""
if isinstance(used_as_arg_idx, int):
used_as_arg_idx = (used_as_arg_idx,)

clients = fgraph.clients
used_by: list[Variable] = []
for client, inp_idx in clients[rv]:
if isinstance(client.op, Output):
continue

if isinstance(client.op, used_by_type) and inp_idx in used_as_arg_idx:
# RV is directly used by the RV type
used_by.append(client.default_output())

elif isinstance(client.op, DimShuffle) and client.op.is_left_expand_dims:
for sub_client, sub_inp_idx in clients[client.outputs[0]]:
if isinstance(sub_client.op, used_by_type) and sub_inp_idx in used_as_arg_idx:
# RV is broadcasted and then used by the RV type
used_by.append(sub_client.default_output())
elif strict:
# Some other unrecognized use, bail out
return []
elif strict:
# Some other unrecognized use, bail out
return []

return used_by


def wrap_rv_and_conjugate_rv(
fgraph: FunctionGraph, rv: Variable, conjugate_rv: Variable, inputs: Sequence[Variable]
) -> Variable:
"""Wrap the RV and its conjugate posterior RV in a ConjugateRV node.
Also takes care of handling the random number generators used in the conjugate posterior.
"""
rngs, next_rngs = zip(*collect_default_updates(conjugate_rv, inputs=[rv, *inputs]).items())
for rng in rngs:
if rng not in fgraph.inputs:
fgraph.add_input(rng)
conjugate_op = ConjugateRV(inputs=[rv, *inputs, *rngs], outputs=[rv, conjugate_rv, *next_rngs])
return conjugate_op(rv, *inputs, *rngs)[0]


def create_untransformed_free_rv(
fgraph: FunctionGraph, rv: Variable, name: str, dims: Sequence[str | Variable]
) -> Variable:
"""Create a model FreeRV without transform."""
transform = None
value = rv.type(name=name)
fgraph.add_input(value)
free_rv = model_free_rv(rv, value, transform, *dims)
free_rv.name = name
return free_rv


def beta_binomial_conjugacy(fgraph: FunctionGraph, node, eager: bool = False):
if not isinstance(node.op, ModelFreeRV):
return None

[beta_free_rv] = node.outputs
beta_rv, _, *beta_dims = node.inputs

if not isinstance(beta_rv.owner.op, Beta):
return None

_, beta_rv_size, a, b = beta_rv.owner.inputs
if not eager and has_free_rv_ancestor([a, b]):
# Don't apply rewrite if a, b depend on other model variables as that will force a Gibbs sampling scheme
return None

p_arg_idx = 3 # inputs to Binomial are (rng, size, n, p)
binomial_rvs = rv_used_by(fgraph, beta_free_rv, Binomial, p_arg_idx)

if len(binomial_rvs) != 1:
# Question: Can we apply conjugacy when RV is used by more than one binomial?
return None

[binomial_rv] = binomial_rvs

binomial_model_var = get_model_var_of_rv(fgraph, binomial_rv)
if binomial_model_var is None:
return None

# We want to replace free_rv by ConjugateRV()->(free_rv, conjugate_posterior_rv)
n, _ = get_dist_params(binomial_rv)

# Use value of y in new graph to avoid circularity
y = binomial_model_var.owner.inputs[1]

conjugate_a = sum_bcasted_dims(beta_rv, a + y)
conjugate_b = sum_bcasted_dims(beta_rv, b + (n - y))

conjugate_beta_rv = Beta.dist(conjugate_a, conjugate_b, shape=beta_rv_size)

new_beta_rv = wrap_rv_and_conjugate_rv(fgraph, beta_rv, conjugate_beta_rv, [a, b, n, y])
new_beta_free_rv = create_untransformed_free_rv(
fgraph, new_beta_rv, beta_free_rv.name, beta_dims
)
return [new_beta_free_rv]


beta_binomial_conjugacy_default, beta_binomial_conjugacy_force = (
register_conjugacy_rewrites_variants(beta_binomial_conjugacy)
)
Loading

0 comments on commit a5c036a

Please sign in to comment.