diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index 1b23e0af027..79dcc44914f 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -139,6 +139,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
 - `pymc.sampling_jax` samplers support `log_likelihood`, `observed_data`, and `sample_stats` in returned InferenceData object (see [#5189](https://github.com/pymc-devs/pymc/pull/5189))
 - Adding support for `pm.Deterministic` in `pymc.sampling_jax` (see [#5182](https://github.com/pymc-devs/pymc/pull/5182))
 - Added an alternative parametrization, `logit_p` to `pm.Binomial` and `pm.Categorical` distributions (see [5637](https://github.com/pymc-devs/pymc/pull/5637)).
+- Added the low level `compile_forward_sampling_function` method to compile the aesara function responsible for generating forward samples (see [#5759](https://github.com/pymc-devs/pymc/pull/5759)).
 - ...
diff --git a/pymc/sampling.py b/pymc/sampling.py
index 452e2303229..1ec0a1ee8f8 100644
--- a/pymc/sampling.py
+++ b/pymc/sampling.py
@@ -43,8 +43,14 @@
 import numpy as np
 import xarray
 
-from aesara.graph.basic import Constant, Variable
-from aesara.tensor import TensorVariable
+from aesara import tensor as at
+from aesara.graph.basic import Apply, Constant, Variable, general_toposort, walk
+from aesara.graph.fg import FunctionGraph
+from aesara.tensor.random.op import RandomVariable
+from aesara.tensor.random.var import (
+    RandomGeneratorSharedVariable,
+    RandomStateSharedVariable,
+)
 from aesara.tensor.sharedvar import SharedVariable
 from arviz import InferenceData
 from fastprogress.fastprogress import progress_bar
@@ -52,7 +58,7 @@
 
 import pymc as pm
 
-from pymc.aesaraf import change_rv_size, compile_pymc, inputvars, walk_model
+from pymc.aesaraf import change_rv_size, compile_pymc
 from pymc.backends.arviz import _DefaultTrace
 from pymc.backends.base import BaseTrace, MultiTrace
 from pymc.backends.ndarray import NDArray
@@ -75,6 +81,7 @@
     get_default_varnames,
     get_untransformed_name,
     is_transformed_name,
+    point_wrapper,
 )
 from pymc.vartypes import discrete_types
@@ -83,6 +90,7 @@
 __all__ = [
     "sample",
     "iter_sample",
+    "compile_forward_sampling_function",
     "sample_posterior_predictive",
     "sample_posterior_predictive_w",
     "init_nuts",
@@ -1534,6 +1542,147 @@ def stop_tuning(step):
     return step
 
 
+def get_vars_in_point_list(trace, model):
+    """Get the list of Variable instances in the model that have values stored in the trace."""
+    if not isinstance(trace, MultiTrace):
+        names_in_trace = list(trace[0])
+    else:
+        names_in_trace = trace.varnames
+    vars_in_trace = [model[v] for v in names_in_trace]
+    return vars_in_trace
+
+
+def compile_forward_sampling_function(
+    outputs: List[Variable],
+    vars_in_trace: List[Variable],
+    basic_rvs: Optional[List[Variable]] = None,
+    givens_dict: Optional[Dict[Variable, Any]] = None,
+    **kwargs,
+) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
+    """Compile a function to draw samples, conditioned on the values of some variables.
+
+    The goal of this function is to walk the aesara computational graph from the list
+    of output nodes down to the root nodes, and then compile a function that will produce
+    values for these output nodes. The compiled function will take as inputs the subset of
+    variables in the ``vars_in_trace`` that are deemed not to be **volatile**.
+
+    Volatile variables are variables whose values could change between runs of the
+    compiled function or after inference has been run. These variables are:
+
+    - Variables in the outputs list
+    - ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``
+    - Basic RVs that are not in the ``vars_in_trace`` list
+    - Variables that are keys in the ``givens_dict``
+    - Variables that have volatile inputs
+
+    By basic RVs we mean ``Variable`` instances produced by a ``RandomVariable`` ``Op``
+    that are in the ``basic_rvs`` list.
+
+    Concretely, this function can be used to compile a function to sample from the
+    posterior predictive distribution of a model that has variables that are conditioned
+    on ``MutableData`` instances. The variables that depend on the mutable data will be
+    considered volatile and, as such, won't be included as inputs to the compiled function.
+    This means that if they have values stored in the posterior, these values will be ignored
+    and new values will be computed (in the case of deterministics and potentials) or sampled
+    (in the case of random variables).
+
+    This function also enables a way to impute values for any variable in the computational
+    graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used
+    to set the ``givens`` argument of the aesara function compilation. This will essentially
+    replace a node in the computational graph with any other expression that has the same
+    type as the desired node. Passing variables in the givens_dict is considered an intervention
+    that might lead to different variable values from those that could have been seen during
+    inference; as such, **any variable that is passed in the ``givens_dict`` will be considered
+    volatile**.
+
+    Parameters
+    ----------
+    outputs : List[aesara.graph.basic.Variable]
+        The list of variables that will be returned by the compiled function.
+    vars_in_trace : List[aesara.graph.basic.Variable]
+        The list of variables that are assumed to have values stored in the trace.
+    basic_rvs : Optional[List[aesara.graph.basic.Variable]]
+        A list of random variables that are defined in the model. This list (which could be the
+        output of ``model.basic_RVs``) should have a reference to the variables that should
+        be considered as random variable instances. This includes variables that have
+        a ``RandomVariable`` owner op, but also higher order random variables like Mixtures or
+        Censored distributions. If ``None``, only pure random variables will be considered
+        as potential random variables.
+    givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]]
+        A dictionary that maps tensor variables to the values that should be used to replace them
+        in the compiled function. The types of the key and value should match, or an error will be
+        raised during compilation.
+ """ + if givens_dict is None: + givens_dict = {} + + if basic_rvs is None: + basic_rvs = [] + + # We need a function graph to walk the clients and propagate the volatile property + fg = FunctionGraph(outputs=outputs, clone=False) + + # Walk the graph from inputs to outputs and tag the volatile variables + nodes: List[Variable] = general_toposort( + fg.outputs, deps=lambda x: x.owner.inputs if x.owner else [] + ) + volatile_nodes: Set[Any] = set() + for node in nodes: + if ( + node in fg.outputs + or node in givens_dict + or ( # SharedVariables, except RandomState/Generators + isinstance(node, SharedVariable) + and not isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable)) + ) + or ( # Basic RVs that are not in the trace + node.owner + and isinstance(node.owner.op, RandomVariable) + and node in basic_rvs + and node not in vars_in_trace + ) + or ( # Variables that have any volatile input + node.owner and any(inp in volatile_nodes for inp in node.owner.inputs) + ) + ): + volatile_nodes.add(node) + + # Collect the function inputs by walking the graph from the outputs. Inputs will be: + # 1. Random variables that are not volatile + # 2. Variables that have no owner and are not constant or shared + inputs = [] + + def expand(node): + if ( + ( + node.owner is None and not isinstance(node, (Constant, SharedVariable)) + ) # Variables without owners that are not constant or shared + or node in vars_in_trace # Variables in the trace + ) and node not in volatile_nodes: + # This test will include variables without owners, and that are not constant + # or shared, because these nodes will never be considered volatile + inputs.append(node) + if node.owner: + return node.owner.inputs + + # walk produces a generator, so we have to actually exhaust the generator in a list to walk + # the entire graph + list(walk(fg.outputs, expand)) + + # Populate the givens list + givens = [ + ( + node, + value + if isinstance(value, (Variable, Apply)) + else at.constant(value, dtype=getattr(node, "dtype", None), name=node.name), + ) + for node, value in givens_dict.items() + ] + + return compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs) + + def sample_posterior_predictive( trace, samples: Optional[int] = None, @@ -1718,38 +1867,23 @@ def sample_posterior_predictive( return trace return {} - inputs: Sequence[TensorVariable] - input_names: Sequence[str] - if not isinstance(_trace, MultiTrace): - names_in_trace = list(_trace[0]) - else: - names_in_trace = _trace.varnames - inputs_and_names = [ - (rv, rv.name) - for rv in walk_model(vars_to_sample, walk_past_rvs=True) - if rv not in vars_to_sample - and rv in model.named_vars.values() - and not isinstance(rv, (Constant, SharedVariable)) - and rv.name in names_in_trace - ] - if inputs_and_names: - inputs, input_names = zip(*inputs_and_names) - else: - inputs, input_names = [], [] - if size is not None: vars_to_sample = [change_rv_size(v, size, expand=True) for v in vars_to_sample] + vars_in_trace = get_vars_in_point_list(_trace, model) if compile_kwargs is None: compile_kwargs = {} - - sampler_fn = compile_pymc( - inputs, - vars_to_sample, - allow_input_downcast=True, - accept_inplace=True, - on_unused_input="ignore", - **compile_kwargs, + compile_kwargs.setdefault("allow_input_downcast", True) + compile_kwargs.setdefault("accept_inplace", True) + + sampler_fn = point_wrapper( + compile_forward_sampling_function( + outputs=vars_to_sample, + vars_in_trace=vars_in_trace, + basic_rvs=model.basic_RVs, + givens_dict=None, + 
+        )
     )
 
     ppc_trace_t = _DefaultTrace(samples)
@@ -1775,7 +1909,7 @@
         else:
             param = _trace[idx % len_trace]
 
-        values = sampler_fn(*(param[n] for n in input_names))
+        values = sampler_fn(**param)
 
         for k, v in zip(vars_, values):
             ppc_trace_t.insert(k.name, v, idx)
@@ -2063,16 +2197,16 @@ def sample_prior_predictive(
         names.append(rv_var.name)
         vars_to_sample.append(rv_var)
 
-    inputs = [i for i in inputvars(vars_to_sample) if not isinstance(i, (Constant, SharedVariable))]
-
     if compile_kwargs is None:
         compile_kwargs = {}
+    compile_kwargs.setdefault("allow_input_downcast", True)
+    compile_kwargs.setdefault("accept_inplace", True)
 
-    sampler_fn = compile_pymc(
-        inputs,
+    sampler_fn = compile_forward_sampling_function(
         vars_to_sample,
-        allow_input_downcast=True,
-        accept_inplace=True,
+        vars_in_trace=[],
+        basic_rvs=model.basic_RVs,
+        givens_dict=None,
         **compile_kwargs,
     )
diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py
index 8e3550021ec..e3a093d9195 100644
--- a/pymc/tests/test_sampling.py
+++ b/pymc/tests/test_sampling.py
@@ -27,6 +27,7 @@
 import xarray as xr
 
 from aesara import Mode, shared
+from aesara.compile import SharedVariable
 from arviz import InferenceData
 from arviz import from_dict as az_from_dict
 from arviz.tests.helpers import check_multiple_attrs
@@ -38,6 +39,7 @@
 from pymc.backends.base import MultiTrace
 from pymc.backends.ndarray import NDArray
 from pymc.exceptions import IncorrectArgumentsError, SamplingError
+from pymc.sampling import compile_forward_sampling_function
 from pymc.tests.helpers import SeededTest, fast_unstable_sampling_mode
 from pymc.tests.models import simple_init
@@ -1446,3 +1448,207 @@ def test_no_init_nuts_compound(caplog):
         b = pm.Poisson("b", 1)
         pm.sample(10, tune=10)
     assert "Initializing NUTS" not in caplog.text
+
+
+class TestCompileForwardSampler:
+    @staticmethod
+    def get_function_roots(function):
+        return [
+            var
+            for var in aesara.graph.basic.graph_inputs(function.maker.fgraph.outputs)
+            if var.name
+        ]
+
+    @staticmethod
+    def get_function_inputs(function):
+        return {i for i in function.maker.fgraph.inputs if not isinstance(i, SharedVariable)}
+
+    def test_linear_model(self):
+        with pm.Model() as model:
+            x = pm.MutableData("x", np.linspace(0, 1, 10))
+            y = pm.MutableData("y", np.ones(10))
+
+            alpha = pm.Normal("alpha", 0, 0.1)
+            beta = pm.Normal("beta", 0, 0.1)
+            mu = pm.Deterministic("mu", alpha + beta * x)
+            sigma = pm.HalfNormal("sigma", 0.1)
+            obs = pm.Normal("obs", mu, sigma, observed=y)
+
+        f = compile_forward_sampling_function(
+            [obs],
+            vars_in_trace=[alpha, beta, sigma, mu],
+            basic_rvs=model.basic_RVs,
+        )
+        assert {i.name for i in self.get_function_inputs(f)} == {"alpha", "beta", "sigma"}
+        assert {i.name for i in self.get_function_roots(f)} == {"x", "alpha", "beta", "sigma"}
+
+        with pm.Model() as model:
+            x = pm.ConstantData("x", np.linspace(0, 1, 10))
+            y = pm.MutableData("y", np.ones(10))
+
+            alpha = pm.Normal("alpha", 0, 0.1)
+            beta = pm.Normal("beta", 0, 0.1)
+            mu = pm.Deterministic("mu", alpha + beta * x)
+            sigma = pm.HalfNormal("sigma", 0.1)
+            obs = pm.Normal("obs", mu, sigma, observed=y)
+
+        f = compile_forward_sampling_function(
+            [obs],
+            vars_in_trace=[alpha, beta, sigma, mu],
+            basic_rvs=model.basic_RVs,
+        )
+        assert {i.name for i in self.get_function_inputs(f)} == {"alpha", "beta", "sigma", "mu"}
+        assert {i.name for i in self.get_function_roots(f)} == {"mu", "sigma"}
+
+    def test_nested_observed_model(self):
+        with pm.Model() as model:
+            p = pm.ConstantData("p", np.array([0.25, 0.5, 0.25]))
pm.ConstantData("p", np.array([0.25, 0.5, 0.25])) + x = pm.MutableData("x", np.zeros(10)) + y = pm.MutableData("y", np.ones(10)) + + category = pm.Categorical("category", p, observed=x) + beta = pm.Normal("beta", 0, 0.1, size=p.shape) + mu = pm.Deterministic("mu", beta[category]) + sigma = pm.HalfNormal("sigma", 0.1) + pm.Normal("obs", mu, sigma, observed=y) + + f = compile_forward_sampling_function( + outputs=model.observed_RVs, + vars_in_trace=[beta, mu, sigma], + basic_rvs=model.basic_RVs, + ) + assert {i.name for i in self.get_function_inputs(f)} == {"beta", "sigma"} + assert {i.name for i in self.get_function_roots(f)} == {"x", "p", "beta", "sigma"} + + f = compile_forward_sampling_function( + outputs=model.observed_RVs, + vars_in_trace=[beta, mu, sigma], + basic_rvs=model.basic_RVs, + givens_dict={category: np.zeros(10, dtype=category.dtype)}, + ) + assert {i.name for i in self.get_function_inputs(f)} == {"beta", "sigma"} + assert {i.name for i in self.get_function_roots(f)} == { + "x", + "p", + "category", + "beta", + "sigma", + } + + def test_volatile_parameters(self): + with pm.Model() as model: + y = pm.MutableData("y", np.ones(10)) + mu = pm.Normal("mu", 0, 1) + nested_mu = pm.Normal("nested_mu", mu, 1, size=10) + sigma = pm.HalfNormal("sigma", 1) + pm.Normal("obs", nested_mu, sigma, observed=y) + + f = compile_forward_sampling_function( + outputs=model.observed_RVs, + vars_in_trace=[nested_mu, sigma], # mu isn't in the trace and will be deemed volatile + basic_rvs=model.basic_RVs, + ) + assert {i.name for i in self.get_function_inputs(f)} == {"sigma"} + assert {i.name for i in self.get_function_roots(f)} == {"sigma"} + + f = compile_forward_sampling_function( + outputs=model.observed_RVs, + vars_in_trace=[mu, nested_mu, sigma], + basic_rvs=model.basic_RVs, + givens_dict={ + mu: np.array(1.0) + }, # mu will be considered volatile because it's in givens + ) + assert {i.name for i in self.get_function_inputs(f)} == {"sigma"} + assert {i.name for i in self.get_function_roots(f)} == {"mu", "sigma"} + + def test_distributions_op_from_graph(self): + with pm.Model() as model: + w = pm.Dirichlet("w", a=np.ones(3), size=(5, 3)) + + mu = pm.Normal("mu", mu=np.arange(3), sigma=1) + + components = pm.Normal.dist(mu=mu, sigma=1, size=w.shape) + mix_mu = pm.Mixture("mix_mu", w=w, comp_dists=components) + obs = pm.Normal("obs", mix_mu, 1, observed=np.ones((5, 3))) + + f = compile_forward_sampling_function( + outputs=[obs], + vars_in_trace=[mix_mu, mu, w], + basic_rvs=model.basic_RVs, + ) + assert {i.name for i in self.get_function_inputs(f)} == {"w", "mu", "mix_mu"} + assert {i.name for i in self.get_function_roots(f)} == {"mix_mu"} + + f = compile_forward_sampling_function( + outputs=[obs], + vars_in_trace=[mu, w], + basic_rvs=model.basic_RVs, + ) + assert {i.name for i in self.get_function_inputs(f)} == {"w", "mu"} + assert {i.name for i in self.get_function_roots(f)} == {"w", "mu"} + + f = compile_forward_sampling_function( + outputs=[obs], + vars_in_trace=[mix_mu, mu], + basic_rvs=model.basic_RVs, + ) + assert {i.name for i in self.get_function_inputs(f)} == {"mu"} + assert {i.name for i in self.get_function_roots(f)} == {"mu"} + + def test_distributions_no_op_from_graph(self): + with pm.Model() as model: + latent_mu = pm.Normal("latent_mu", mu=np.arange(3), sigma=1) + mu = pm.Censored("mu", pm.Normal.dist(mu=latent_mu, sigma=1), lower=-1, upper=1) + obs = pm.Normal("obs", mu, 1, observed=np.ones((10, 3))) + + f = compile_forward_sampling_function( + outputs=[obs], + 
+            basic_rvs=model.basic_RVs,
+        )
+        assert {i.name for i in self.get_function_inputs(f)} == {"latent_mu", "mu"}
+        assert {i.name for i in self.get_function_roots(f)} == {"mu"}
+
+        f = compile_forward_sampling_function(
+            outputs=[obs],
+            vars_in_trace=[mu],
+            basic_rvs=model.basic_RVs,
+        )
+        assert {i.name for i in self.get_function_inputs(f)} == set()
+        assert {i.name for i in self.get_function_roots(f)} == set()
+
+    def test_lkj_cholesky_cov(self):
+        with pm.Model() as model:
+            mu = np.zeros(3)
+            sd_dist = pm.Exponential.dist(1.0, size=3)
+            chol, corr, stds = pm.LKJCholeskyCov(  # pylint: disable=unpacking-non-sequence
+                "chol_packed", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
+            )
+            chol_packed = model["chol_packed"]
+            chol = pm.Deterministic("chol", chol)
+            obs = pm.MvNormal("obs", mu=mu, chol=chol, observed=np.zeros(3))
+
+        f = compile_forward_sampling_function(
+            outputs=[obs],
+            vars_in_trace=[chol_packed, chol],
+            basic_rvs=model.basic_RVs,
+        )
+        assert {i.name for i in self.get_function_inputs(f)} == {"chol_packed", "chol"}
+        assert {i.name for i in self.get_function_roots(f)} == {"chol"}
+
+        f = compile_forward_sampling_function(
+            outputs=[obs],
+            vars_in_trace=[chol_packed],
+            basic_rvs=model.basic_RVs,
+        )
+        assert {i.name for i in self.get_function_inputs(f)} == {"chol_packed"}
+        assert {i.name for i in self.get_function_roots(f)} == {"chol_packed"}
+
+        f = compile_forward_sampling_function(
+            outputs=[obs],
+            vars_in_trace=[chol],
+            basic_rvs=model.basic_RVs,
+        )
+        assert {i.name for i in self.get_function_inputs(f)} == set()
+        assert {i.name for i in self.get_function_roots(f)} == set()
diff --git a/pymc/util.py b/pymc/util.py
index 8ef7d886d32..657b9a8f521 100644
--- a/pymc/util.py
+++ b/pymc/util.py
@@ -21,6 +21,7 @@
 import numpy as np
 import xarray
 
+from aesara.compile import SharedVariable
 from cachetools import LRUCache, cachedmethod
@@ -349,3 +350,16 @@ def check_dist_not_registered(dist, model=None):
             f"You should use an unregistered (unnamed) distribution created via "
             f"the `.dist()` API instead, such as:\n`dist=pm.Normal.dist(0, 1)`"
         )
+
+
+def point_wrapper(core_function):
+    """Wrap an aesara compiled function so that it can ingest point dictionaries,
+    ignoring any keys that are not valid inputs to the core function.
+    """
+    ins = [i.name for i in core_function.maker.fgraph.inputs if not isinstance(i, SharedVariable)]
+
+    def wrapped(**kwargs):
+        input_point = {k: v for k, v in kwargs.items() if k in ins}
+        return core_function(**input_point)
+
+    return wrapped
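
Usage sketch (editor's note, not part of the diff): the docstring of `compile_forward_sampling_function` explains how volatility decides what becomes a compiled-function input versus what is drawn anew, and how `givens_dict` pins a variable. The minimal example below illustrates that behavior; the toy model, names, and values are illustrative assumptions, not code from this PR.

import numpy as np
import pymc as pm

from pymc.sampling import compile_forward_sampling_function

with pm.Model() as model:
    x = pm.MutableData("x", np.linspace(0, 1, 10))
    beta = pm.Normal("beta", 0, 1)
    obs = pm.Normal("obs", beta * x, 0.1, observed=np.ones(10))

# beta is in vars_in_trace and not volatile, so it becomes a function input;
# obs is a requested output, hence volatile, and is drawn anew on every call.
fn = compile_forward_sampling_function(
    outputs=[obs],
    vars_in_trace=[beta],
    basic_rvs=model.basic_RVs,
)
draws = fn(beta=0.5)  # a list with one array of 10 forward draws

# Passing beta through givens_dict instead marks it volatile: the compiled
# function then takes no inputs, and beta is pinned to the given value.
fn_pinned = compile_forward_sampling_function(
    outputs=[obs],
    vars_in_trace=[beta],
    basic_rvs=model.basic_RVs,
    givens_dict={beta: np.array(0.5)},
)
pinned_draws = fn_pinned()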
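
A second sketch (likewise illustrative, not from the PR) of why `sample_posterior_predictive` wraps the compiled function in `point_wrapper`: a posterior point usually carries more keys than the compiled function accepts, and the wrapper silently drops the extras.

import numpy as np
import pymc as pm

from pymc.sampling import compile_forward_sampling_function
from pymc.util import point_wrapper

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    obs = pm.Normal("obs", mu, 1, observed=np.zeros(5))

fn = point_wrapper(
    compile_forward_sampling_function(
        outputs=[obs],
        vars_in_trace=[mu],
        basic_rvs=model.basic_RVs,
    )
)

# "some_other_var" is not an input of the compiled function, so it is ignored.
point = {"mu": np.array(0.3), "some_other_var": np.array(1.0)}
draws = fn(**point)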