Add Beta-Binomial conjugacy optimization
ricardoV94 committed Dec 5, 2024
1 parent e051965 commit dbe62e8
Showing 9 changed files with 368 additions and 16 deletions.
14 changes: 1 addition & 13 deletions pymc_experimental/model/marginal/distributions.py
@@ -7,7 +7,6 @@
from pymc.logprob.abstract import MeasurableOp, _logprob
from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import constant_fold
from pytensor import Variable
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.mode import Mode
from pytensor.graph import Op, vectorize_graph
@@ -17,6 +16,7 @@
from pytensor.tensor import TensorVariable

from pymc_experimental.distributions import DiscreteMarkovChain
from pymc_experimental.utils.ofg import inline_ofg_outputs


class MarginalRV(OpFromGraph, MeasurableOp):
@@ -126,18 +126,6 @@ def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> TensorVariable:
return logp.transpose(*dims_alignment)


def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
"""Inline the inner graph (outputs) of an OpFromGraph Op.
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
the inner graph.
"""
return clone_replace(
op.inner_outputs,
replace=tuple(zip(op.inner_inputs, inputs)),
)


DUMMY_ZERO = pt.constant(0, name="dummy_zero")


1 change: 1 addition & 0 deletions pymc_experimental/sampling/__init__.py
@@ -1,2 +1,3 @@
# Add rewrites to the optimization DBs
import pymc_experimental.sampling.optimizations.conjugacy
import pymc_experimental.sampling.optimizations.summary_stats
8 changes: 7 additions & 1 deletion pymc_experimental/sampling/mcmc.py
@@ -20,6 +20,9 @@ def opt_sample(
):
"""Sample from a model after applying optimizations.
.. warning:: There is no guarantee that the optimizations will improve the sampling performance. For instance, conjugacy optimizations can lead to less efficient sampling for the remaining variables (if any), due to imposing a Gibbs sampling scheme.
Parameters
----------
model : Model (optional)
@@ -47,13 +50,16 @@ def opt_sample(
y = pm.Binomial("y", n=10, p=p, observed=5)
idata = pmx.opt_sample(verbose=True)
# Applied optimization: beta_binomial_conjugacy 1x
# ConjugateRVSampler: [p]
"""

model = modelcontext(model)
fgraph, _ = fgraph_from_model(model)

if rewriter is None:
rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=["summary_stats"]))
rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=["summary_stats", "conjugacy"]))
_, _, rewrite_counters, *_ = rewriter.rewrite(fgraph)

if verbose:
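For illustration, the `rewriter` argument of `opt_sample` can be used to select only the new conjugacy rewrites. A minimal sketch (not part of this commit), assuming the `posterior_optimization_db` and `RewriteDatabaseQuery` shown in the diff above:

import pymc as pm
import pymc_experimental as pmx
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pymc_experimental.sampling.mcmc import posterior_optimization_db

# Build a rewriter that applies only the "conjugacy" rewrites,
# skipping "summary_stats"
rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=["conjugacy"]))

with pm.Model():
    p = pm.Beta("p", 1, 1)
    y = pm.Binomial("y", n=10, p=p, observed=5)
    idata = pmx.opt_sample(rewriter=rewriter, verbose=True)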
156 changes: 156 additions & 0 deletions pymc_experimental/sampling/optimizations/conjugacy.py
@@ -0,0 +1,156 @@
from typing import Sequence

from pymc import STEP_METHODS
from pymc.distributions import Beta, Binomial
from pymc.model.fgraph import ModelFreeRV, ModelValuedVar, model_free_rv
from pymc.pytensorf import collect_default_updates
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import Variable
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.type import RandomGeneratorType

from pymc_experimental.sampling.mcmc import posterior_optimization_db
from pymc_experimental.sampling.optimizations.conjugate_sampler import ConjugateRV, ConjugateRVSampler

STEP_METHODS.append(ConjugateRVSampler)


def get_model_var_of_rv(fgraph: FunctionGraph, rv: Variable) -> Variable:
"""Return the Model dummy var that wraps the RV"""
for client, _ in fgraph.clients[rv]:
if isinstance(client.op, ModelValuedVar):
return client.outputs[0]


def get_dist_params(rv: Variable) -> tuple[Variable]:
return rv.owner.op.dist_params(rv.owner)


def rv_used_by(fgraph: FunctionGraph, rv: Variable, used_by_type: type, used_as_arg_idx: int | Sequence[int], strict: bool = True) -> list[Variable]:
"""Return the RVs that use `rv` as an argument in an operation of type `used_by_type`.
RV may be used directly or broadcasted before being used.
Parameters
----------
fgraph : FunctionGraph
The function graph containing the RVs
rv : Variable
The RV to check for uses.
used_by_type : type
The type of operation that may use the RV.
used_as_arg_idx : int | Sequence[int]
The index of the RV in the operation's inputs.
strict : bool, default=True
If True, return no results when the RV is used in an unrecognized way.
"""
if isinstance(used_as_arg_idx, int):
used_as_arg_idx = (used_as_arg_idx,)

clients = fgraph.clients
used_by : list[Variable] = []
for client, inp_idx in clients[rv]:
if isinstance(client.op, Output):
continue

if isinstance(client.op, used_by_type) and inp_idx in used_as_arg_idx:
# RV is directly used by the RV type
used_by.append(client.default_output())

elif isinstance(client.op, DimShuffle) and client.op.is_left_expand_dims:
for sub_client, sub_inp_idx in clients[client.outputs[0]]:
if isinstance(sub_client.op, used_by_type) and sub_inp_idx in used_as_arg_idx:
# RV is broadcasted and then used by the RV type
used_by.append(sub_client.default_output())
elif strict:
# Some other unrecognized use, bail out
return []
elif strict:
# Some other unrecognized use, bail out
return []

return used_by


def wrap_rv_and_conjugate_rv(fgraph: FunctionGraph, rv: Variable, conjugate_rv: Variable, inputs: Sequence[Variable]) -> Variable:
"""Wrap the RV and its conjugate posterior RV in a ConjugateRV node.
Also takes care of handling the random number generators used in the conjugate posterior.
"""
rngs, next_rngs = zip(*collect_default_updates(conjugate_rv, inputs=[rv, *inputs]).items())
for rng in rngs:
if rng not in fgraph.inputs:
fgraph.add_input(rng)
conjugate_op = ConjugateRV(inputs=[rv, *inputs, *rngs], outputs=[rv, conjugate_rv, *next_rngs])
return conjugate_op(rv, *inputs, *rngs)[0]


def create_untransformed_free_rv(fgraph: FunctionGraph, rv: Variable, name: str, dims: Sequence[str | Variable]) -> Variable:
"""Create a model FreeRV without transform."""
transform = None
value = rv.type(name=name)
fgraph.add_input(value)
free_rv = model_free_rv(rv, value, transform, *dims)
free_rv.name = name
return free_rv


@node_rewriter(tracks=[ModelFreeRV])
def beta_binomial_conjugacy(fgraph: FunctionGraph, node):
"""This applies the equivalence (up to a normalizing constant) described in:
https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics
"""
[beta_free_rv] = node.outputs
beta_rv, beta_value, *beta_dims = node.inputs

if not isinstance(beta_rv.owner.op, Beta):
return None

p_arg_idx = 3 # inputs to Binomial are (rng, size, n, p)
binomial_rvs = rv_used_by(fgraph, beta_free_rv, Binomial, p_arg_idx)

if len(binomial_rvs) != 1:
# Question: Can we apply conjugacy when RV is used by more than one binomial?
return None

[binomial_rv] = binomial_rvs

binomial_model_var = get_model_var_of_rv(fgraph, binomial_rv)
if binomial_model_var is None:
return None

# We want to replace free_rv by ConjugateRV()->(free_rv, conjugate_posterior_rv)
a, b = get_dist_params(beta_rv)
n, _ = get_dist_params(binomial_rv)

# Use value of y in new graph to avoid circularity
y = binomial_model_var.owner.inputs[1]

conjugate_a = a + y
conjugate_b = b + (n - y)
extra_dims = range(binomial_rv.type.ndim - beta_rv.type.ndim)
if extra_dims:
conjugate_a = conjugate_a.sum(extra_dims)
conjugate_b = conjugate_b.sum(extra_dims)
conjugate_beta_rv = Beta.dist(conjugate_a, conjugate_b)

new_beta_rv = wrap_rv_and_conjugate_rv(fgraph, beta_rv, conjugate_beta_rv, [a, b, n, y])
new_beta_free_rv = create_untransformed_free_rv(fgraph, new_beta_rv, beta_free_rv.name, beta_dims)
return [new_beta_free_rv]


posterior_optimization_db.register(
beta_binomial_conjugacy.__name__,
beta_binomial_conjugacy,
"conjugacy"
)
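For reference, the standard conjugacy result the rewrite exploits (not part of the diff):

p ~ Beta(a, b),  y | p ~ Binomial(n, p)  =>  p | y ~ Beta(a + y, b + (n - y))

which is exactly what `conjugate_a` and `conjugate_b` compute above. When the Binomial has extra leading batch dimensions relative to the Beta (several observations sharing one `p`), the sufficient statistics `y` and `n - y` are summed over those dimensions, matching the `.sum(extra_dims)` calls.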
106 changes: 106 additions & 0 deletions pymc_experimental/sampling/optimizations/conjugate_sampler.py
@@ -0,0 +1,106 @@
import numpy as np

from pymc.distributions.distribution import _support_point
from pymc.initial_point import PointType
from pymc.logprob.abstract import MeasurableOp, _logprob
from pymc.model.core import modelcontext
from pymc.pytensorf import compile_pymc
from pymc.step_methods.compound import BlockedStep, Competence, StepMethodState
from pymc.util import get_value_vars_from_user_vars
from pytensor import shared
from pytensor.compile.builders import OpFromGraph
from pytensor.link.jax.linker import JAXLinker
from pytensor.tensor.random.type import RandomGeneratorType

from pymc_experimental.utils.ofg import inline_ofg_outputs

class ConjugateRV(OpFromGraph, MeasurableOp):
"""Wrapper for ConjugateRVs, that outputs the original RV and the conjugate posterior expression.
For partial step samplers to work, the logp and initial point correspond to the original RV
while the variable itself is sampled by default by the `ConjugateRVSampler` by evaluating directly the
conjugate posterior expression (i.e., taking forward random draws).
"""


@_logprob.register(ConjugateRV)
def conjugate_rv_logp(op, values, rv, *params, **kwargs):
# Logp is the same as the original RV
return _logprob(rv.owner.op, values, *rv.owner.inputs)


@_support_point.register(ConjugateRV)
def conjugate_rv_support_point(op, conjugate_rv, rv, *params):
# Support point is the same as the original RV
return _support_point(rv.owner.op, rv, *rv.owner.inputs)


class ConjugateRVSampler(BlockedStep):
name = "conjugate_rv_sampler"
_state_class = StepMethodState

def __init__(self, vars, model=None, rng=None, compile_kwargs: dict | None = None, **kwargs):
if len(vars) != 1:
raise ValueError("ConjugateRVSampler can only be assigned to one variable at a time")

model = modelcontext(model)
[value] = get_value_vars_from_user_vars(vars, model=model)
rv = model.values_to_rvs[value]
self.vars = (value,)
self.rv_name = value.name

if model.rvs_to_transforms[rv] is not None:
raise ValueError("Variable assigned to ConjugateRVSampler cannot be transformed")

rv_and_posterior_rv_node = rv.owner
op = rv_and_posterior_rv_node.op
if not isinstance(op, ConjugateRV):
raise ValueError("Variable must be a ConjugateRV")

# Replace RVs in inputs of rv_posterior_rv_node by the corresponding value variables
value_inputs = model.replace_rvs_by_values(
[rv_and_posterior_rv_node.outputs[1]],
)[0].owner.inputs
# Inline the ConjugateRV graph to only compile `posterior_rv`
_, posterior_rv, *_ = inline_ofg_outputs(op, value_inputs)

if compile_kwargs is None:
compile_kwargs = {}
self.posterior_fn = compile_pymc(
model.value_vars,
posterior_rv,
random_seed=rng,
on_unused_input="ignore",
**compile_kwargs,
)
self.posterior_fn.trust_input = True
if isinstance(self.posterior_fn.maker.linker, JAXLinker):
# Reseeding RVs in the JAX backend requires different logic, because the SharedVariables
# used internally are not the ones that `function.get_shared()` returns.
raise ValueError("ConjugateRVSampler is not compatible with JAX backend")

def set_rng(self, rng: np.random.Generator):
# Copy the function and replace any shared RNGs
# This is needed so that it can work correctly with multiple traces
# This will be costly if set_rng is called too often!
shared_rngs = [
var for var in self.posterior_fn.get_shared() if isinstance(var.type, RandomGeneratorType)
]
n_shared_rngs = len(shared_rngs)
swap = {
old_shared_rng: shared(rng, borrow=True)
for old_shared_rng, rng in zip(shared_rngs, rng.spawn(n_shared_rngs), strict=True)
}
self.posterior_fn = self.posterior_fn.copy(swap=swap)

def step(self, point: PointType) -> tuple[PointType, list]:
new_point = point.copy()
new_point[self.rv_name] = self.posterior_fn(**point)
return new_point, []

@staticmethod
def competence(var, has_grad):
"""BinaryMetropolis is only suitable for Bernoulli and Categorical variables with k=2."""
if isinstance(var.owner.op, ConjugateRV):
return Competence.IDEAL

return Competence.INCOMPATIBLE
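Conceptually, `ConjugateRVSampler.step` is an exact Gibbs update: it evaluates the compiled conjugate posterior at the current point and draws the variable directly, with no accept/reject step. A plain NumPy sketch of the Beta-Binomial case (illustrative only, not the function the sampler actually compiles):

import numpy as np

rng = np.random.default_rng(123)
a, b = 1.0, 1.0  # Beta prior parameters
n, y = 10, 5     # Binomial trials and observed successes

point = {"p": 0.3}                         # current chain state
point["p"] = rng.beta(a + y, b + (n - y))  # exact draw from p | y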
16 changes: 16 additions & 0 deletions pymc_experimental/utils/ofg.py
@@ -0,0 +1,16 @@
from pytensor.graph.basic import Variable
from pytensor.graph.replace import clone_replace
from pytensor.compile.builders import OpFromGraph
from typing import Sequence


def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
"""Inline the inner graph (outputs) of an OpFromGraph Op.
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
the inner graph.
"""
return clone_replace(
op.inner_outputs,
replace=tuple(zip(op.inner_inputs, inputs)),
)
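A minimal usage sketch for the new helper (hypothetical example, not part of the commit): `inline_ofg_outputs` substitutes the given variables for the inner inputs of an `OpFromGraph` and returns the cloned inner outputs, so the wrapped graph can be used without the Op indirection:

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pymc_experimental.utils.ofg import inline_ofg_outputs

x = pt.scalar("x")
op = OpFromGraph([x], [x + 1])  # wraps the graph `x + 1` in a single Op

y = pt.scalar("y")
[inlined] = inline_ofg_outputs(op, [y])  # plain graph computing `y + 1`
print(inlined.eval({y: 41.0}))           # 42.0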
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -8,7 +8,7 @@ addopts = [
]

filterwarnings =[
"error",
# "error",
# Raised by arviz when the model_builder class adds non-standard group names to InferenceData
"ignore::UserWarning:arviz.data.inference_data",
