Skip to content

Commit

Permalink
Refactor DensityDist into v4
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Sep 25, 2021
1 parent 37ba9a3 commit 8a6b9b3
Show file tree
Hide file tree
Showing 12 changed files with 364 additions and 196 deletions.
5 changes: 5 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
- `pm.sample` now returns results as `InferenceData` instead of `MultiTrace` by default (see [#4744](https://github.com/pymc-devs/pymc/pull/4744)).
- `pm.sample_prior_predictive` no longer returns transformed variable values by default. Pass them by name in `var_names` if you want to obtain these draws (see [4769](https://github.com/pymc-devs/pymc/pull/4769)).
-`pm.Bound` interface no longer accepts a callable class as argument, instead it requires an instantiated distribution (created via the `.dist()` API) to be passed as an argument. In addition, Bound no longer returns a class instance but works as a normal PyMC distribution. Finally, it is no longer possible to do predictive random sampling from Bounded variables. Please, consult the new documentation for details on how to use Bounded variables (see [4815](https://github.com/pymc-devs/pymc/pull/4815)).
- `pm.DensityDist` no longer accepts the `logp` as its first position argument. It is now an optional keyword argument. If you pass a callable as the first positional argument, a `TypeError` will be raised (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
- `pm.DensityDist` now accepts distribution parameters as positional arguments. Passing them as a dictionary in the `observed` keyword argument is no longer supported and will raise an error (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
- The signature of the `logp` and `random` functions that can be passed into a `pm.DensityDist` has been changed (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
- ...

### New Features
Expand All @@ -25,6 +28,8 @@
- The `Polya-Gamma` distribution has been added (see [#4531](https://github.com/pymc-devs/pymc/pull/4531)). To make use of this distribution, the [`polyagamma>=1.3.1`](https://pypi.org/project/polyagamma/) library must be installed and available in the user's environment.
- A small change to the mass matrix tuning methods jitter+adapt_diag (the default) and adapt_diag improves performance early on during tuning for some models. [#5004](https://github.com/pymc-devs/pymc/pull/5004)
- New experimental mass matrix tuning method jitter+adapt_diag_grad. [#5004](https://github.com/pymc-devs/pymc/pull/5004)
- `pm.DensityDist` can now accept an optional `logcdf` keyword argument to pass in a function to compute the cummulative density function of the distribution (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
- `pm.DensityDist` can now accept an optional `get_moment` keyword argument to pass in a function to compute the moment of the distribution (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
- ...

### Maintenance
Expand Down
8 changes: 4 additions & 4 deletions docs/source/Probability_Distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,16 @@ An exponential survival function, where :math:`c=0` denotes failure (or non-surv
f(c, t) = \left\{ \begin{array}{l} \exp(-\lambda t), \text{if c=1} \\
\lambda \exp(-\lambda t), \text{if c=0} \end{array} \right.
Such a function can be implemented as a PyMC distribution by writing a function that specifies the log-probability, then passing that function as an argument to the ``DensityDist`` function, which creates an instance of a PyMC distribution with the custom function as its log-probability.
Such a function can be implemented as a PyMC distribution by writing a function that specifies the log-probability, then passing that function as a keyword argument to the ``DensityDist`` function, which creates an instance of a PyMC3 distribution with the custom function as its log-probability.

For the exponential survival function, this is:

::

def logp(failure, value):
return (failure * log(λ) - λ * value).sum()
def logp(value, t, λ):
return (value * log(λ) - λ * t).sum()

exp_surv = pm.DensityDist('exp_surv', logp, observed={'failure':failure, 'value':t})
exp_surv = pm.DensityDist('exp_surv', t, λ, logp=logp, observed=failure)

Similarly, if a random number generator is required, a function returning random numbers corresponding to the probability distribution can be passed as the ``random`` argument.

Expand Down
266 changes: 207 additions & 59 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,18 @@
# limitations under the License.
import contextvars
import functools
import multiprocessing
import sys
import types
import warnings

from abc import ABCMeta
from functools import singledispatch
from typing import Optional
from typing import Callable, Optional, Sequence

import aesara
import numpy as np

from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.var import RandomStateSharedVariable
from aesara.tensor.var import TensorVariable
Expand All @@ -41,12 +42,14 @@
maybe_resize,
resize_from_dims,
resize_from_observed,
to_tuple,
)
from pymc.printing import str_for_dist
from pymc.util import UNSET
from pymc.vartypes import string_types

__all__ = [
"DensityDistRV",
"DensityDist",
"Distribution",
"Continuous",
Expand Down Expand Up @@ -387,96 +390,241 @@ class NoDistribution(Distribution):
"""


class DensityDist(Distribution):
"""Distribution based on a given log density function.
class DensityDistRV(RandomVariable):
"""
Base class for DensityDistRV
This should be subclassed when defining custom DensityDist objects.
"""

name = "DensityDistRV"
_print_name = ("DensityDist", "\\operatorname{DensityDist}")

@classmethod
def rng_fn(cls, rng, *args):
args = list(args)
size = args.pop(-1)
return cls._random_fn(*args, rng=rng, size=size)

A distribution with the passed log density function is created.
Requires a custom random function passed as kwarg `random` to
enable prior or posterior predictive sampling.

class DensityDist(NoDistribution):
"""A distribution that can be used to wrap black-box log density functions.
Creates a Distribution and registers the supplied log density function to be used
for inference. It is also possible to supply a `random` method in order to be able
to sample from the prior or posterior predictive distributions.
"""

def __init__(
self,
logp,
shape=(),
dtype=None,
initval=0,
random=None,
wrap_random_with_dist_shape=True,
check_shape_in_random=True,
*args,
def __new__(
cls,
name: str,
*dist_params,
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
random: Optional[Callable] = None,
get_moment: Optional[Callable] = None,
ndim_supp: int = 0,
ndims_params: Optional[Sequence[int]] = None,
dtype: str = "floatX",
**kwargs,
):
"""
Parameters
----------
logp: callable
A callable that has the following signature ``logp(value)`` and
returns an Aesara tensor that represents the distribution's log
probability density.
shape: tuple (Optional): defaults to `()`
The shape of the distribution. The default value indicates a scalar.
If the distribution is *not* scalar-valued, the programmer should pass
a value here.
dtype: None, str (Optional)
The dtype of the distribution.
initval: number or array (Optional)
The ``initval`` of the RV's tensor that follow the ``DensityDist``
distribution.
args, kwargs: (Optional)
These are passed to the parent class' ``__init__``.
name : str
dist_params : Tuple
A sequence of the distribution's parameter. These will be converted into
aesara tensors internally. These parameters could be other ``RandomVariable``
instances.
logp : Optional[Callable]
A callable that calculates the log density of some given observed ``value``
conditioned on certain distribution parameter values. It must have the
following signature: ``logp(value, *dist_params)``, where ``value`` is
an Aesara tensor that represents the observed value, and ``dist_params``
are the tensors that hold the values of the distribution parameters.
This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
error will be raised when trying to compute the distribution's logp.
logcdf : Optional[Callable]
A callable that calculates the log cummulative probability of some given observed
``value`` conditioned on certain distribution parameter values. It must have the
following signature: ``logcdf(value, *dist_params)``, where ``value`` is
an Aesara tensor that represents the observed value, and ``dist_params``
are the tensors that hold the values of the distribution parameters.
This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
error will be raised when trying to compute the distribution's logcdf.
random : Optional[Callable]
A callable that can be used to generate random draws from the distribution.
It must have the following signature: ``random(*dist_params, rng=None, size=None)``.
The distribution parameters are passed as positional arguments in the
same order as they are supplied when the ``DensityDist`` is constructed.
The keyword arguments are ``rnd``, which will provide the random variable's
associated :py:class:`~numpy.random.Generator`, and ``size``, that will represent
the desired size of the random draw. If ``None``, a ``NotImplemented``
error will be raised when trying to draw random samples from the distribution's
prior or posterior predictive.
get_moment : Optional[Callable]
A callable that can be used to compute the moments of the distribution.
It must have the following signature: ``get_moment(rv, size, *rv_inputs)``.
The distribution's :class:`~aesara.tensor.random.op.RandomVariable` is passed
as the first argument ``rv``. ``size`` is the random variable's size implied
by the ``dims``, ``size`` and parameters supplied to the distribution. Finally,
``rv_inputs`` is the sequence of the distribution parameters, in the same order
as they were supplied when the DensityDist was created. If ``None``, a
``NotImplemented`` error will be raised when trying to draw random samples from
the distribution's prior or posterior predictive.
ndim_supp : int
The number of dimensions in the support of the distribution. Defaults to assuming
a scalar distribution, i.e. ``ndim_supp = 0``.
ndims_params : Optional[Sequence[int]]
The list of number of dimensions in the support of each of the distribution's
parameters. If ``None``, it is assumed that all parameters are scalars, hence
the number of dimensions of their support will be 0.
dtype : str
The dtype of the distribution. All draws and observations passed into the distribution
will be casted onto this dtype.
kwargs :
Extra keyword arguments are passed to the parent's class ``__new__`` method.
Examples
--------
.. code-block:: python
def logp(value, mu):
return -(value - mu)**2
with pm.Model():
mu = pm.Normal('mu',0,1)
normal_dist = pm.Normal.dist(mu, 1)
pm.DensityDist(
'density_dist',
normal_dist.logp,
mu,
logp=logp,
observed=np.random.randn(100),
)
idata = pm.sample(100)
.. code-block:: python
def logp(value, mu):
return -(value - mu)**2
def random(mu, rng=None, size=None):
return rng.normal(loc=mu, scale=1, size=size)
with pm.Model():
mu = pm.Normal('mu', 0 , 1)
normal_dist = pm.Normal.dist(mu, 1, shape=3)
dens = pm.DensityDist(
'density_dist',
normal_dist.logp,
mu,
logp=logp,
random=random,
observed=np.random.randn(100, 3),
shape=3,
size=(100, 3),
)
prior = pm.sample_prior_predictive(10)['density_dist']
assert prior.shape == (10, 100, 3)
"""
if dtype is None:

if dist_params is None:
dist_params = []
elif len(dist_params) > 0 and callable(dist_params[0]):
raise TypeError(
"The DensityDist API has changed, you are using the old API "
"where logp was the first positional argument. In the current API, "
"the logp is a keyword argument, amongst other changes. Please refer "
"to the API documentation for more information on how to use the "
"new DensityDist API."
)
dist_params = [as_tensor_variable(param) for param in dist_params]

# Assume scalar ndims_params
if ndims_params is None:
ndims_params = [0] * len(dist_params)

if logp is None:
logp = default_not_implemented(name, "logp")

if logcdf is None:
logcdf = default_not_implemented(name, "logcdf")

if random is None:
random = default_not_implemented(name, "random")

if get_moment is None:
get_moment = default_not_implemented(name, "get_moment")

rv_op = type(
f"DensityDist_{name}",
(DensityDistRV,),
dict(
name=f"DensityDist_{name}",
inplace=False,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
dtype=dtype,
# Specifc to DensityDist
_random_fn=random,
),
)()

# Register custom logp
rv_type = type(rv_op)

@_logp.register(rv_type)
def density_dist_logp(op, rv, rvs_to_values, *dist_params, **kwargs):
value_var = rvs_to_values.get(rv, rv)
return logp(
value_var,
*dist_params,
)

@_logcdf.register(rv_type)
def density_dist_logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
value_var = rvs_to_values.get(var, var)
return logcdf(value_var, *dist_params, **kwargs)

@_get_moment.register(rv_type)
def density_dist_get_moment(op, rv, size, *rv_inputs):
return get_moment(rv, size, *rv_inputs)

cls.rv_op = rv_op
return super().__new__(cls, name, *dist_params, **kwargs)

@classmethod
def dist(cls, *args, **kwargs):
output = super().dist(args, **kwargs)
if cls.rv_op.dtype == "floatX":
dtype = aesara.config.floatX
super().__init__(shape, dtype, initval, *args, **kwargs)
self.logp = logp
if type(self.logp) == types.MethodType:
if PLATFORM != "linux":
warnings.warn(
"You are passing a bound method as logp for DensityDist, this can lead to "
"errors when sampling on platforms other than Linux. Consider using a "
"plain function instead, or subclass Distribution."
)
elif type(multiprocessing.get_context()) != multiprocessing.context.ForkContext:
warnings.warn(
"You are passing a bound method as logp for DensityDist, this can lead to "
"errors when sampling when multiprocessing cannot rely on forking. Consider using a "
"plain function instead, or subclass Distribution."
)
self.rand = random
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
self.check_shape_in_random = check_shape_in_random
else:
dtype = cls.rv_op.dtype
ndim_supp = cls.rv_op.ndim_supp
if not hasattr(output.tag, "test_value"):
size = to_tuple(kwargs.get("size", None)) + (1,) * ndim_supp
output.tag.test_value = np.zeros(size, dtype)
return output


def default_not_implemented(rv_name, method_name):
if method_name == "random":
# This is a hack to catch the NotImplementedError when creating the RV without random
# If the message starts with "Cannot sample from", then it uses the test_value as
# the initial_val.
message = (
f"Cannot sample from the DensityDist '{rv_name}' because the {method_name} "
"keyword argument was not provided when the distribution was "
f"but this method had not been provided when the distribution was "
f"constructed. Please re-build your model and provide a callable "
f"to '{rv_name}'s {method_name} keyword argument.\n"
)
else:
message = (
f"Attempted to run {method_name} on the DensityDist '{rv_name}', "
f"but this method had not been provided when the distribution was "
f"constructed. Please re-build your model and provide a callable "
f"to '{rv_name}'s {method_name} keyword argument.\n"
)

def func(*args, **kwargs):
raise NotImplementedError(message)

def _distr_parameters_for_repr(self):
return []
return func
Loading

0 comments on commit 8a6b9b3

Please sign in to comment.