Skip to content

Commit

Permalink
Add compile_forward_sampling_function
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz authored and ricardoV94 committed May 17, 2022
1 parent 4969460 commit 862bd05
Show file tree
Hide file tree
Showing 4 changed files with 392 additions and 37 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
- `pymc.sampling_jax` samplers support `log_likelihood`, `observed_data`, and `sample_stats` in returned InferenceData object (see [#5189](https://github.com/pymc-devs/pymc/pull/5189))
- Adding support for `pm.Deterministic` in `pymc.sampling_jax` (see [#5182](https://github.com/pymc-devs/pymc/pull/5182))
- Added an alternative parametrization, `logit_p`, to `pm.Binomial` and `pm.Categorical` distributions (see [#5637](https://github.com/pymc-devs/pymc/pull/5637)).
- Added the low-level `compile_forward_sampling_function` function, which compiles the Aesara function responsible for generating forward samples (see [#5759](https://github.com/pymc-devs/pymc/pull/5759)).
- ...


Expand Down
208 changes: 171 additions & 37 deletions pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,22 @@
import numpy as np
import xarray

from aesara.graph.basic import Constant, Variable
from aesara.tensor import TensorVariable
from aesara import tensor as at
from aesara.graph.basic import Apply, Constant, Variable, general_toposort, walk
from aesara.graph.fg import FunctionGraph
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.var import (
RandomGeneratorSharedVariable,
RandomStateSharedVariable,
)
from aesara.tensor.sharedvar import SharedVariable
from arviz import InferenceData
from fastprogress.fastprogress import progress_bar
from typing_extensions import TypeAlias

import pymc as pm

from pymc.aesaraf import change_rv_size, compile_pymc, inputvars, walk_model
from pymc.aesaraf import change_rv_size, compile_pymc
from pymc.backends.arviz import _DefaultTrace
from pymc.backends.base import BaseTrace, MultiTrace
from pymc.backends.ndarray import NDArray
Expand All @@ -75,6 +81,7 @@
get_default_varnames,
get_untransformed_name,
is_transformed_name,
point_wrapper,
)
from pymc.vartypes import discrete_types

Expand All @@ -83,6 +90,7 @@
__all__ = [
"sample",
"iter_sample",
"compile_forward_sampling_function",
"sample_posterior_predictive",
"sample_posterior_predictive_w",
"init_nuts",
Expand Down Expand Up @@ -1534,6 +1542,147 @@ def stop_tuning(step):
return step


def get_vars_in_point_list(trace, model):
    """Return the model variables whose names have values stored in ``trace``.

    For a ``MultiTrace`` the names come from ``trace.varnames``. For any other
    kind of trace the names are obtained by iterating over its first element
    (``trace[0]``). Every name is then looked up in ``model`` via ``model[name]``.
    """
    if isinstance(trace, MultiTrace):
        stored_names = trace.varnames
    else:
        # Not a MultiTrace: the first entry is iterated to recover the names
        stored_names = list(trace[0])
    return [model[name] for name in stored_names]


def compile_forward_sampling_function(
    outputs: List[Variable],
    vars_in_trace: List[Variable],
    basic_rvs: Optional[List[Variable]] = None,
    givens_dict: Optional[Dict[Variable, Any]] = None,
    **kwargs,
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
    """Compile a function to draw samples, conditioned on the values of some variables.

    The goal of this function is to walk the aesara computational graph from the list
    of output nodes down to the root nodes, and then compile a function that will produce
    values for these output nodes. The compiled function will take as inputs the subset of
    variables in the ``vars_in_trace`` that are deemed to not be **volatile**.

    Volatile variables are variables whose values could change between runs of the
    compiled function or after inference has been run. These variables are:

    - Variables in the outputs list
    - ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or
      ``RandomGeneratorSharedVariable``
    - Basic RVs that are not in the ``vars_in_trace`` list
    - Variables that are keys in the ``givens_dict``
    - Variables that have volatile inputs

    Where by basic RVs we mean ``Variable`` instances produced by a ``RandomVariable`` ``Op``
    that are in the ``basic_rvs`` list.

    Concretely, this function can be used to compile a function to sample from the
    posterior predictive distribution of a model that has variables that are conditioned
    on ``MutableData`` instances. The variables that depend on the mutable data will be
    considered volatile, and as such, they won't be included as inputs into the compiled
    function. This means that if they have values stored in the posterior, these values will
    be ignored and new values will be computed (in the case of deterministics and potentials)
    or sampled (in the case of random variables).

    This function also enables a way to impute values for any variable in the computational
    graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used
    to set the ``givens`` argument of the aesara function compilation. This will essentially
    replace a node in the computational graph with any other expression that has the same
    type as the desired node. Passing variables in the givens_dict is considered an
    intervention that might lead to different variable values from those that could have been
    seen during inference, as such, **any variable that is passed in the ``givens_dict`` will
    be considered volatile**.

    Parameters
    ----------
    outputs : List[aesara.graph.basic.Variable]
        The list of variables that will be returned by the compiled function
    vars_in_trace : List[aesara.graph.basic.Variable]
        The list of variables that are assumed to have values stored in the trace
    basic_rvs : Optional[List[aesara.graph.basic.Variable]]
        A list of random variables that are defined in the model (this could be the
        output of ``model.basic_RVs``). A node is only tagged as a volatile basic RV when
        it is in this list, is produced by a ``RandomVariable`` ``Op`` and is not in
        ``vars_in_trace``. If ``None``, the list is treated as empty, so no node is tagged
        as volatile through this rule.
    givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]]
        A dictionary that maps tensor variables to the values that should be used to replace
        them in the compiled function. Values that are not ``Variable`` or ``Apply``
        instances are wrapped with ``at.constant`` using the dtype and name of the key.
        The types of the key and value should match or an error will be raised during
        compilation.
    **kwargs
        Extra keyword arguments forwarded to ``compile_pymc``.

    Returns
    -------
    Callable
        The function returned by ``compile_pymc`` for the collected inputs and the
        requested outputs.
    """
    if givens_dict is None:
        givens_dict = {}

    # Membership is tested once per graph node below, so build the lookups a single time
    # instead of scanning the caller's lists on every test.
    basic_rv_set = set(basic_rvs) if basic_rvs is not None else set()
    trace_var_set = set(vars_in_trace)

    # We need a function graph to walk the clients and propagate the volatile property
    fg = FunctionGraph(outputs=outputs, clone=False)
    output_set = set(fg.outputs)

    # Walk the graph from inputs to outputs and tag the volatile variables. The
    # topological order guarantees that the inputs of a node are tagged before the node.
    nodes: List[Variable] = general_toposort(
        fg.outputs, deps=lambda x: x.owner.inputs if x.owner else []
    )
    volatile_nodes: Set[Any] = set()
    for node in nodes:
        if (
            node in output_set
            or node in givens_dict
            or (  # SharedVariables, except RandomState/Generators
                isinstance(node, SharedVariable)
                and not isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable))
            )
            or (  # Basic RVs that are not in the trace
                node.owner
                and isinstance(node.owner.op, RandomVariable)
                and node in basic_rv_set
                and node not in trace_var_set
            )
            or (  # Variables that have any volatile input
                node.owner and any(inp in volatile_nodes for inp in node.owner.inputs)
            )
        ):
            volatile_nodes.add(node)

    # Collect the function inputs by walking the graph from the outputs. Inputs will be:
    # 1. Random variables that are not volatile
    # 2. Variables that have no owner and are not constant or shared
    inputs = []

    def expand(node):
        is_root_placeholder = node.owner is None and not isinstance(
            node, (Constant, SharedVariable)
        )
        if (is_root_placeholder or node in trace_var_set) and node not in volatile_nodes:
            # This test will include variables without owners, and that are not constant
            # or shared, because these nodes will never be considered volatile
            inputs.append(node)
        if node.owner:
            return node.owner.inputs

    # walk produces a generator, so we have to actually exhaust the generator in a list to
    # walk the entire graph
    list(walk(fg.outputs, expand))

    # Populate the givens list, wrapping raw values as constants typed like the node they
    # replace
    givens = [
        (
            node,
            value
            if isinstance(value, (Variable, Apply))
            else at.constant(value, dtype=getattr(node, "dtype", None), name=node.name),
        )
        for node, value in givens_dict.items()
    ]

    return compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs)


def sample_posterior_predictive(
trace,
samples: Optional[int] = None,
Expand Down Expand Up @@ -1718,38 +1867,23 @@ def sample_posterior_predictive(
return trace
return {}

inputs: Sequence[TensorVariable]
input_names: Sequence[str]
if not isinstance(_trace, MultiTrace):
names_in_trace = list(_trace[0])
else:
names_in_trace = _trace.varnames
inputs_and_names = [
(rv, rv.name)
for rv in walk_model(vars_to_sample, walk_past_rvs=True)
if rv not in vars_to_sample
and rv in model.named_vars.values()
and not isinstance(rv, (Constant, SharedVariable))
and rv.name in names_in_trace
]
if inputs_and_names:
inputs, input_names = zip(*inputs_and_names)
else:
inputs, input_names = [], []

if size is not None:
vars_to_sample = [change_rv_size(v, size, expand=True) for v in vars_to_sample]
vars_in_trace = get_vars_in_point_list(_trace, model)

if compile_kwargs is None:
compile_kwargs = {}

sampler_fn = compile_pymc(
inputs,
vars_to_sample,
allow_input_downcast=True,
accept_inplace=True,
on_unused_input="ignore",
**compile_kwargs,
compile_kwargs.setdefault("allow_input_downcast", True)
compile_kwargs.setdefault("accept_inplace", True)

sampler_fn = point_wrapper(
compile_forward_sampling_function(
outputs=vars_to_sample,
vars_in_trace=vars_in_trace,
basic_rvs=model.basic_RVs,
givens_dict=None,
**compile_kwargs,
)
)

ppc_trace_t = _DefaultTrace(samples)
Expand All @@ -1775,7 +1909,7 @@ def sample_posterior_predictive(
else:
param = _trace[idx % len_trace]

values = sampler_fn(*(param[n] for n in input_names))
values = sampler_fn(**param)

for k, v in zip(vars_, values):
ppc_trace_t.insert(k.name, v, idx)
Expand Down Expand Up @@ -2063,16 +2197,16 @@ def sample_prior_predictive(
names.append(rv_var.name)
vars_to_sample.append(rv_var)

inputs = [i for i in inputvars(vars_to_sample) if not isinstance(i, (Constant, SharedVariable))]

if compile_kwargs is None:
compile_kwargs = {}
compile_kwargs.setdefault("allow_input_downcast", True)
compile_kwargs.setdefault("accept_inplace", True)

sampler_fn = compile_pymc(
inputs,
sampler_fn = compile_forward_sampling_function(
vars_to_sample,
allow_input_downcast=True,
accept_inplace=True,
vars_in_trace=[],
basic_rvs=model.basic_RVs,
givens_dict=None,
**compile_kwargs,
)

Expand Down
Loading

0 comments on commit 862bd05

Please sign in to comment.