Commit

Create base SymbolicDistribution

ricardoV94 committed Jan 4, 2022
1 parent 920597a commit e5f49d9
Showing 2 changed files with 267 additions and 3 deletions.
261 changes: 260 additions & 1 deletion pymc/distributions/distribution.py
@@ -19,7 +19,7 @@

from abc import ABCMeta
from functools import singledispatch
-from typing import Callable, Optional, Sequence
+from typing import Callable, Iterable, Optional, Sequence

import aesara

@@ -354,6 +354,265 @@ def dist(
return rv_out


class SymbolicDistribution:
def __new__(
cls,
name: str,
*args,
rngs: Optional[Iterable] = None,
dims: Optional[Dims] = None,
initval=None,
observed=None,
total_size=None,
transform=UNSET,
**kwargs,
) -> TensorVariable:
"""Adds a TensorVariable corresponding to a PyMC symbolic distribution to the
current model.
While traditional PyMC distributions are represented by a single RandomVariable
graph, Symbolic distributions correspond to a larger graph that contains one or
more RandomVariables and an arbitrary number of deterministic operations, which
represent their own kind of distribution.
The graphs returned by symbolic distributions can be evaluated directly to
obtain valid draws and can further be parsed by Aeppl to derive the
corresponding logp at runtime.
Check pymc.distributions.Censored for an example of a symbolic distribution.
Symbolic distributions must implement the following classmethods:
cls.dist
Performs input validation and converts optional alternative parametrizations
to a canonical parametrization. It should call `super().dist()`, passing a
list with the default parameters as the first and only non keyword argument,
followed by other keyword arguments like size and rngs, and return the result
cls.rv_op
Returns a TensorVariable that represents the symbolic distribution
parametrized by a default set of parameters and a size and rngs arguments
cls.ndim_supp
Returns the support of the symbolic distribution, given the default
parameters. This may not always be constant, for instance if the symbolic
distribution can be defined based on an arbitrary base distribution.
cls.change_size
Returns an equivalent symbolic distribution with a different size. This is
analogous to `pymc.aesaraf.change_rv_size` for `RandomVariable`s.
cls.graph_rvs
Returns base RVs in a symbolic distribution.
Parameters
----------
cls : type
A distribution class that inherits from SymbolicDistribution.
name : str
Name for the new model variable.
        rngs : optional
            Random number generators to use for the RandomVariable(s) in the graph.
dims : tuple, optional
A tuple of dimension names known to the model.
initval : optional
Numeric or symbolic untransformed initial value of matching shape,
or one of the following initial value strategies: "moment", "prior".
Depending on the sampler's settings, a random jitter may be added to numeric,
symbolic or moment-based initial values in the transformed space.
observed : optional
Observed data to be passed when registering the random variable in the model.
See ``Model.register_rv``.
total_size : float, optional
See ``Model.register_rv``.
transform : optional
See ``Model.register_rv``.
**kwargs
Keyword arguments that will be forwarded to ``.dist()``.
            Most prominently: ``shape`` and ``size``.

        Returns
-------
var : TensorVariable
The created variable, registered in the Model.
"""

try:
from pymc.model import Model

model = Model.get_context()
except TypeError:
raise TypeError(
"No model on context stack, which is needed to "
"instantiate distributions. Add variable inside "
"a 'with model:' block, or use the '.dist' syntax "
"for a standalone distribution."
)

if "testval" in kwargs:
initval = kwargs.pop("testval")
warnings.warn(
"The `testval` argument is deprecated; use `initval`.",
FutureWarning,
stacklevel=2,
)

if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

if dims is not None and "shape" in kwargs:
raise ValueError(
f"Passing both `dims` ({dims}) and `shape` ({kwargs['shape']}) is not supported!"
)
if dims is not None and "size" in kwargs:
raise ValueError(
f"Passing both `dims` ({dims}) and `size` ({kwargs['size']}) is not supported!"
)
dims = convert_dims(dims)

if rngs is None:
# Create a temporary rv to obtain number of rngs needed
temp_graph = cls.dist(*args, rngs=None, **kwargs)
rngs = [model.next_rng() for _ in cls.graph_rvs(temp_graph)]
elif not isinstance(rngs, (list, tuple)):
rngs = [rngs]

        # Create the RV without dims information, because that's not something tracked at the Aesara level.
        # If necessary, we'll later resize it to the size implied by already-known dims.
rv_out = cls.dist(*args, rngs=rngs, **kwargs)
ndim_actual = rv_out.ndim
resize_shape = None

        # `dims` are only available with this API, because `.dist()` can be used
        # without a modelcontext and dims are not tracked at the Aesara level.
if dims is not None:
ndim_resize, resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
elif observed is not None:
ndim_resize, resize_shape, observed = resize_from_observed(observed, ndim_actual)

if resize_shape:
# A batch size was specified through `dims`, or implied by `observed`.
            rv_out = cls.change_size(rv_out, new_size=resize_shape)

rv_out = model.register_rv(
rv_out,
name,
observed,
total_size,
dims=dims,
transform=transform,
initval=initval,
)

# TODO: Refactor this
# add in pretty-printing support
rv_out.str_repr = lambda *args, **kwargs: name
        rv_out._repr_latex_ = f"\\text{{{name}}}"
# rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
# rv_out._repr_latex_ = types.MethodType(
# functools.partial(str_for_dist, formatting="latex"), rv_out
# )

return rv_out

@classmethod
def dist(
cls,
dist_params,
*,
shape: Optional[Shape] = None,
size: Optional[Size] = None,
**kwargs,
) -> TensorVariable:
"""Creates a TensorVariable corresponding to the `cls` symbolic distribution.
Parameters
----------
dist_params : array-like
The inputs to the `RandomVariable` `Op`.
shape : int, tuple, Variable, optional
A tuple of sizes for each dimension of the new RV.
An Ellipsis (...) may be inserted in the last position to short-hand refer to
all the dimensions that the RV would get if no shape/size/dims were passed at all.
size : int, tuple, Variable, optional
For creating the RV like in Aesara/NumPy.
Returns
-------
var : TensorVariable
"""

if "testval" in kwargs:
kwargs.pop("testval")
warnings.warn(
"The `.dist(testval=...)` argument is deprecated and has no effect. "
"Initial values for sampling/optimization can be specified with `initval` in a modelcontext. "
"For using Aesara's test value features, you must assign the `.tag.test_value` yourself.",
FutureWarning,
stacklevel=2,
)
if "initval" in kwargs:
raise TypeError(
"Unexpected keyword argument `initval`. "
"This argument is not available for the `.dist()` API."
)

if "dims" in kwargs:
raise NotImplementedError("The use of a `.dist(dims=...)` API is not supported.")
if shape is not None and size is not None:
raise ValueError(
f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!"
)

shape = convert_shape(shape)
size = convert_size(size)

create_size, ndim_expected, ndim_batch, ndim_supp = find_size(
shape=shape, size=size, ndim_supp=cls.ndim_supp(*dist_params)
)
# Create the RV with a `size` right away.
# This is not necessarily the final result.
graph = cls.rv_op(*dist_params, size=create_size, **kwargs)
graph = maybe_resize(
graph,
cls.rv_op,
dist_params,
ndim_expected,
ndim_batch,
ndim_supp,
shape,
size,
change_rv_size_fn=cls.change_size,
**kwargs,
)

rngs = kwargs.pop("rngs", None)
if rngs is not None:
graph_rvs = cls.graph_rvs(graph)
assert len(rngs) == len(graph_rvs)
for rng, rv_out in zip(rngs, graph_rvs):
if (
rv_out.owner
and isinstance(rv_out.owner.op, RandomVariable)
and isinstance(rng, RandomStateSharedVariable)
and not getattr(rng, "default_update", None)
):
# This tells `aesara.function` that the shared RNG variable
# is mutable, which--in turn--tells the `FunctionGraph`
# `Supervisor` feature to allow in-place updates on the variable.
# Without it, the `RandomVariable`s could not be optimized to allow
# in-place RNG updates, forcing all sample results from compiled
# functions to be the same on repeated evaluations.
new_rng = rv_out.owner.outputs[0]
rv_out.update = (rng, new_rng)
rng.default_update = new_rng

        # TODO: Create new attr error stating that these are not available for SymbolicDistribution
# rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
# rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
# rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
return graph


@singledispatch
def _get_moment(op, rv, *rv_inputs) -> TensorVariable:
raise NotImplementedError(f"Variable {rv} of type {op} has no get_moment implementation.")
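
To make the contract above concrete, here is a minimal sketch of a hypothetical subclass. `ClippedNormal`, its parameters, and the assumed graph structure `clip(base_rv, lower, upper)` are invented for illustration only; the reference implementation to study is `pymc.distributions.Censored`.

import aesara.tensor as at

from pymc.aesaraf import change_rv_size
from pymc.distributions.distribution import SymbolicDistribution


class ClippedNormal(SymbolicDistribution):
    """Hypothetical symbolic distribution: a Normal whose draws are clipped."""

    @classmethod
    def dist(cls, mu=0, sigma=1, lower=-1, upper=1, **kwargs):
        # Validate/convert inputs, then pass the canonical parameters as a
        # single list, as required by SymbolicDistribution.dist.
        params = [at.as_tensor_variable(p) for p in (mu, sigma, lower, upper)]
        return super().dist(params, **kwargs)

    @classmethod
    def rv_op(cls, mu, sigma, lower, upper, size=None, rngs=None):
        # One base RandomVariable plus a deterministic operation on top.
        rng = rngs[0] if rngs else None
        base_rv = at.random.normal(mu, sigma, size=size, rng=rng)
        return at.clip(base_rv, lower, upper)

    @classmethod
    def ndim_supp(cls, mu, sigma, lower, upper):
        # Scalar (univariate) support, like the underlying Normal.
        return 0

    @classmethod
    def change_size(cls, rv, new_size, expand=False):
        # Rebuild the graph around a resized copy of the base RV
        # (assumes the clip(base_rv, lower, upper) structure).
        base_rv, lower, upper = rv.owner.inputs
        new_base_rv = change_rv_size(base_rv, new_size, expand=expand)
        return at.clip(new_base_rv, lower, upper)

    @classmethod
    def graph_rvs(cls, rv):
        # The base RandomVariables; __new__ assigns one model rng to each.
        return (rv.owner.inputs[0],)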
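The `default_update` handling at the end of `dist` is what lets compiled functions draw fresh samples on every call. A self-contained sketch of that mechanism in plain Aesara (no PyMC involved):

import aesara
import aesara.tensor as at
import numpy as np

rng = aesara.shared(np.random.RandomState(123))
x = at.random.normal(0, 1, rng=rng)

# Without a default update, the RNG state is never advanced in place and
# every call returns the same draw:
fn = aesara.function([], x)
assert fn() == fn()

# Setting default_update to the RV's updated-rng output (exactly what the
# loop in `dist` does) lets aesara.function update the shared RNG in place:
rng.default_update = x.owner.outputs[0]
fn = aesara.function([], x)
assert fn() != fn()
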
9 changes: 7 additions & 2 deletions pymc/distributions/shape_utils.py
@@ -20,6 +20,7 @@

import warnings

from functools import partial
from typing import Optional, Sequence, Tuple, Union

import numpy as np
@@ -612,6 +613,8 @@ def maybe_resize(
ndim_supp,
shape,
size,
*,
change_rv_size_fn=partial(change_rv_size, expand=True),
**kwargs,
):
"""Resize a distribution if necessary.
@@ -635,6 +638,8 @@
A tuple specifying the final shape of a distribution
size : tuple
A tuple specifying the size of a distribution
    change_rv_size_fn : callable
        A function that returns an equivalent RV with a different size.
Returns
-------
@@ -647,7 +652,7 @@
if shape is not None and ndims_unexpected:
if Ellipsis in shape:
# Resize and we're done!
-            rv_out = change_rv_size(rv_var=rv_out, new_size=shape[:-1], expand=True)
+            rv_out = change_rv_size_fn(rv_out, new_size=shape[:-1])
else:
# This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)).
            # Recreate the RV without passing `size`, to create it with just the implied dimensions.
@@ -656,7 +661,7 @@
# Now resize by any remaining "extra" dimensions that were not implied from support and parameters
if rv_out.ndim < ndim_expected:
expand_shape = shape[: ndim_expected - rv_out.ndim]
-            rv_out = change_rv_size(rv_var=rv_out, new_size=expand_shape, expand=True)
+            rv_out = change_rv_size_fn(rv_out, new_size=expand_shape)
if not rv_out.ndim == ndim_expected:
raise ShapeError(
f"Failed to create the RV with the expected dimensionality. "
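
The Ellipsis branch above is what powers shapes like `shape=(2, ...)`: the dimensions implied by the parameters are kept and the leading entries are added by resizing. A short sketch of the observable behavior, using the regular `MvNormal`, which goes through the same `maybe_resize` helper:

import numpy as np
import pymc as pm

# (3,) is implied by the parameters; the leading 2 comes from resizing.
rv = pm.MvNormal.dist(mu=np.zeros(3), cov=np.eye(3), shape=(2, ...))
assert rv.eval().shape == (2, 3)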

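Likewise, the `dims`-based resizing wired into `SymbolicDistribution.__new__` mirrors the behavior of the existing `Distribution` API, sketched here with a plain `Normal`:

import pymc as pm

with pm.Model(coords={"city": ["A", "B", "C"]}) as model:
    # The batch shape (3,) is implied by the named dimension; the variable
    # is created first and then resized to match it.
    x = pm.Normal("x", mu=0, sigma=1, dims="city")

assert x.eval().shape == (3,)
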
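Finally, a sketch of how the symbolic distribution referenced in the docstring, `pymc.distributions.Censored`, is used through this API, both registered in a model and standalone via `.dist()`; argument names follow the Censored implementation in this development branch:

import pymc as pm

with pm.Model() as model:
    # Registered in the model: the logp is derived by Aeppl at runtime.
    x = pm.Censored("x", pm.Normal.dist(mu=0.0, sigma=1.0), lower=-1.0, upper=1.0)

# Standalone graph: can be evaluated directly to obtain valid draws.
x_dist = pm.Censored.dist(pm.Normal.dist(), lower=-1.0, upper=1.0, size=(5,))
print(x_dist.eval())  # five draws censored to [-1, 1]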