diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index c4a68c6e028..cf9d87eb0b3 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -19,7 +19,7 @@
 
 from abc import ABCMeta
 from functools import singledispatch
-from typing import Callable, Optional, Sequence
+from typing import Callable, Iterable, Optional, Sequence
 
 import aesara
 
@@ -354,6 +354,275 @@ def dist(
         return rv_out
 
 
+class SymbolicDistribution:
+    """Base class for distributions defined by a graph with one or more RandomVariables."""
+
+    def __new__(
+        cls,
+        name: str,
+        *args,
+        rngs: Optional[Iterable] = None,
+        dims: Optional[Dims] = None,
+        initval=None,
+        observed=None,
+        total_size=None,
+        transform=UNSET,
+        **kwargs,
+    ) -> TensorVariable:
+        """Adds a TensorVariable corresponding to a PyMC symbolic distribution to the
+        current model.
+
+        While traditional PyMC distributions are represented by a single RandomVariable
+        graph, Symbolic distributions correspond to a larger graph that contains one or
+        more RandomVariables and an arbitrary number of deterministic operations, which
+        represent their own kind of distribution.
+
+        The graphs returned by symbolic distributions can be evaluated directly to
+        obtain valid draws and can further be parsed by Aeppl to derive the
+        corresponding logp at runtime.
+
+        Check pymc.distributions.Censored for an example of a symbolic distribution.
+
+        Symbolic distributions must implement the following classmethods:
+        cls.dist
+            Performs input validation and converts optional alternative parametrizations
+            to a canonical parametrization. It should call `super().dist()`, passing a
+            list with the default parameters as the first and only non keyword argument,
+            followed by other keyword arguments like size and rngs, and return the result
+        cls.rv_op
+            Returns a TensorVariable that represents the symbolic distribution
+            parametrized by a default set of parameters and a size and rngs arguments
+        cls.ndim_supp
+            Returns the support of the symbolic distribution, given the default
+            parameters. This may not always be constant, for instance if the symbolic
+            distribution can be defined based on an arbitrary base distribution.
+        cls.change_size
+            Returns an equivalent symbolic distribution with a different size. This is
+            analogous to `pymc.aesaraf.change_rv_size` for `RandomVariable`s.
+        cls.graph_rvs
+            Returns base RVs in a symbolic distribution.
+
+        Parameters
+        ----------
+        cls : type
+            A distribution class that inherits from SymbolicDistribution.
+        name : str
+            Name for the new model variable.
+        rngs : optional
+            Random number generator(s) to use for the RandomVariable(s) in the graph.
+            A value that is not a list or tuple is wrapped in a list. When omitted, one
+            generator is requested from the model for each RV in ``cls.graph_rvs``.
+        dims : tuple, optional
+            A tuple of dimension names known to the model.
+        initval : optional
+            Numeric or symbolic untransformed initial value of matching shape,
+            or one of the following initial value strategies: "moment", "prior".
+            Depending on the sampler's settings, a random jitter may be added to numeric,
+            symbolic or moment-based initial values in the transformed space.
+        observed : optional
+            Observed data to be passed when registering the random variable in the model.
+            See ``Model.register_rv``.
+        total_size : float, optional
+            See ``Model.register_rv``.
+        transform : optional
+            See ``Model.register_rv``.
+        **kwargs
+            Keyword arguments that will be forwarded to ``.dist()``.
+            Most prominently: ``shape`` and ``size``
+
+        Returns
+        -------
+        var : TensorVariable
+            The created variable, registered in the Model.
+        """
+
+        try:
+            from pymc.model import Model
+
+            model = Model.get_context()
+        except TypeError:
+            raise TypeError(
+                "No model on context stack, which is needed to "
+                "instantiate distributions. Add variable inside "
+                "a 'with model:' block, or use the '.dist' syntax "
+                "for a standalone distribution."
+            )
+
+        if "testval" in kwargs:
+            initval = kwargs.pop("testval")
+            warnings.warn(
+                "The `testval` argument is deprecated; use `initval`.",
+                FutureWarning,
+                stacklevel=2,
+            )
+
+        if not isinstance(name, string_types):
+            raise TypeError(f"Name needs to be a string but got: {name}")
+
+        if dims is not None and "shape" in kwargs:
+            raise ValueError(
+                f"Passing both `dims` ({dims}) and `shape` ({kwargs['shape']}) is not supported!"
+            )
+        if dims is not None and "size" in kwargs:
+            raise ValueError(
+                f"Passing both `dims` ({dims}) and `size` ({kwargs['size']}) is not supported!"
+            )
+        dims = convert_dims(dims)
+
+        if rngs is None:
+            # Create a temporary rv to obtain number of rngs needed
+            temp_graph = cls.dist(*args, rngs=None, **kwargs)
+            rngs = [model.next_rng() for _ in cls.graph_rvs(temp_graph)]
+        elif not isinstance(rngs, (list, tuple)):
+            rngs = [rngs]
+
+        # Create the RV without dims information, because that's not something tracked at the Aesara level.
+        # If necessary we'll later replicate to a different size implied by already known dims.
+        rv_out = cls.dist(*args, rngs=rngs, **kwargs)
+        ndim_actual = rv_out.ndim
+        resize_shape = None
+
+        # # `dims` are only available with this API, because `.dist()` can be used
+        # # without a modelcontext and dims are not tracked at the Aesara level.
+        if dims is not None:
+            ndim_resize, resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
+        elif observed is not None:
+            ndim_resize, resize_shape, observed = resize_from_observed(observed, ndim_actual)
+
+        if resize_shape:
+            # A batch size was specified through `dims`, or implied by `observed`.
+            rv_out = cls.change_size(
+                rv=rv_out,
+                new_size=resize_shape,
+            )
+
+        rv_out = model.register_rv(
+            rv_out,
+            name,
+            observed,
+            total_size,
+            dims=dims,
+            transform=transform,
+            initval=initval,
+        )
+
+        # TODO: Refactor this
+        # add in pretty-printing support
+        rv_out.str_repr = lambda *args, **kwargs: name
+        # `_repr_latex_` must be callable, and the doubled braces keep the LaTeX
+        # group around the name (``\text{x}`` instead of ``\textx``).
+        rv_out._repr_latex_ = lambda *args, **kwargs: f"\\text{{{name}}}"
+        # rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
+        # rv_out._repr_latex_ = types.MethodType(
+        #     functools.partial(str_for_dist, formatting="latex"), rv_out
+        # )
+
+        return rv_out
+
+    @classmethod
+    def dist(
+        cls,
+        dist_params,
+        *,
+        shape: Optional[Shape] = None,
+        size: Optional[Size] = None,
+        **kwargs,
+    ) -> TensorVariable:
+        """Creates a TensorVariable corresponding to the `cls` symbolic distribution.
+
+        Parameters
+        ----------
+        dist_params : array-like
+            The inputs to the `RandomVariable` `Op`.
+        shape : int, tuple, Variable, optional
+            A tuple of sizes for each dimension of the new RV.
+
+            An Ellipsis (...) may be inserted in the last position to short-hand refer to
+            all the dimensions that the RV would get if no shape/size/dims were passed at all.
+        size : int, tuple, Variable, optional
+            For creating the RV like in Aesara/NumPy.
+        **kwargs
+            Forwarded to ``cls.rv_op``. If it contains ``rngs``, there must be one entry
+            per variable returned by ``cls.graph_rvs``.
+
+        Returns
+        -------
+        var : TensorVariable
+        """
+
+        if "testval" in kwargs:
+            kwargs.pop("testval")
+            warnings.warn(
+                "The `.dist(testval=...)` argument is deprecated and has no effect. "
+                "Initial values for sampling/optimization can be specified with `initval` in a modelcontext. "
+                "For using Aesara's test value features, you must assign the `.tag.test_value` yourself.",
+                FutureWarning,
+                stacklevel=2,
+            )
+        if "initval" in kwargs:
+            raise TypeError(
+                "Unexpected keyword argument `initval`. "
+                "This argument is not available for the `.dist()` API."
+            )
+
+        if "dims" in kwargs:
+            raise NotImplementedError("The use of a `.dist(dims=...)` API is not supported.")
+        if shape is not None and size is not None:
+            raise ValueError(
+                f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!"
+            )
+
+        shape = convert_shape(shape)
+        size = convert_size(size)
+
+        create_size, ndim_expected, ndim_batch, ndim_supp = find_size(
+            shape=shape, size=size, ndim_supp=cls.ndim_supp(*dist_params)
+        )
+        # Create the RV with a `size` right away.
+        # This is not necessarily the final result.
+        graph = cls.rv_op(*dist_params, size=create_size, **kwargs)
+        graph = maybe_resize(
+            graph,
+            cls.rv_op,
+            dist_params,
+            ndim_expected,
+            ndim_batch,
+            ndim_supp,
+            shape,
+            size,
+            change_rv_size_fn=cls.change_size,
+            **kwargs,
+        )
+
+        rngs = kwargs.pop("rngs", None)
+        if rngs is not None:
+            graph_rvs = cls.graph_rvs(graph)
+            if len(rngs) != len(graph_rvs):
+                raise ValueError(f"Expected {len(graph_rvs)} rngs but got {len(rngs)}.")
+            for rng, rv_out in zip(rngs, graph_rvs):
+                if (
+                    rv_out.owner
+                    and isinstance(rv_out.owner.op, RandomVariable)
+                    and isinstance(rng, RandomStateSharedVariable)
+                    and not getattr(rng, "default_update", None)
+                ):
+                    # This tells `aesara.function` that the shared RNG variable
+                    # is mutable, which--in turn--tells the `FunctionGraph`
+                    # `Supervisor` feature to allow in-place updates on the variable.
+                    # Without it, the `RandomVariable`s could not be optimized to allow
+                    # in-place RNG updates, forcing all sample results from compiled
+                    # functions to be the same on repeated evaluations.
+                    new_rng = rv_out.owner.outputs[0]
+                    rv_out.update = (rng, new_rng)
+                    rng.default_update = new_rng
+
+        # TODO: Create new attr error stating that these are not available for DerivedDistribution
+        # rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
+        # rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
+        # rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
+        return graph
+
+
 @singledispatch
 def _get_moment(op, rv, *rv_inputs) -> TensorVariable:
     raise NotImplementedError(f"Variable {rv} of type {op} has no get_moment implementation.")
diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py
index f74b61e7a0d..28117c3353b 100644
--- a/pymc/distributions/shape_utils.py
+++ b/pymc/distributions/shape_utils.py
@@ -20,6 +20,7 @@
 
 import warnings
 
+from functools import partial
 from typing import Optional, Sequence, Tuple, Union
 
 import numpy as np
@@ -612,6 +613,8 @@ def maybe_resize(
     ndim_supp,
     shape,
     size,
+    *,
+    change_rv_size_fn=partial(change_rv_size, expand=True),
     **kwargs,
 ):
     """Resize a distribution if necessary.
@@ -635,6 +638,8 @@ def maybe_resize(
         A tuple specifying the final shape of a distribution
     size : tuple
         A tuple specifying the size of a distribution
+    change_rv_size_fn : callable
+        Called as ``fn(rv, new_size=...)``; returns an equivalent RV with a different size
 
     Returns
     -------
@@ -647,7 +652,7 @@ def maybe_resize(
     if shape is not None and ndims_unexpected:
         if Ellipsis in shape:
             # Resize and we're done!
-            rv_out = change_rv_size(rv_var=rv_out, new_size=shape[:-1], expand=True)
+            rv_out = change_rv_size_fn(rv_out, new_size=shape[:-1])
         else:
             # This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)).
             # Recreate the RV without passing `size` to created it with just the implied dimensions.
@@ -656,7 +661,7 @@ def maybe_resize(
     # Now resize by any remaining "extra" dimensions that were not implied from support and parameters
     if rv_out.ndim < ndim_expected:
         expand_shape = shape[: ndim_expected - rv_out.ndim]
-        rv_out = change_rv_size(rv_var=rv_out, new_size=expand_shape, expand=True)
+        rv_out = change_rv_size_fn(rv_out, new_size=expand_shape)
     if not rv_out.ndim == ndim_expected:
         raise ShapeError(
             f"Failed to create the RV with the expected dimensionality. "