Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap DensityDist's random with generate_samples #3554

Merged
merged 11 commits into from
Aug 16, 2019
3 changes: 3 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
- SMC is no longer a step method of `pm.sample` now it should be called using `pm.sample_smc` [3579](https://github.com/pymc-devs/pymc3/pull/3579)
- Now uses `multiprocessong` rather than `psutil` to count CPUs, which results in reliable core counts on Chromebooks.
- `sample_posterior_predictive` now preallocates the memory required for its output to improve memory usage. Addresses problems raised in this [discourse thread](https://discourse.pymc.io/t/memory-error-with-posterior-predictive-sample/2891/4).
- Fixed a bug in `Categorical.logp`. In the case of multidimensional `p`'s, the indexing was done wrong leading to incorrectly shaped tensors that consumed `O(n**2)` memory instead of `O(n)`. This fixes issue [#3535](https://github.com/pymc-devs/pymc3/issues/3535)
- Fixed a defect in `OrderedLogistic.__init__` that unnecessarily increased the dimensionality of the underlying `p`. Related to issue issue [#3535](https://github.com/pymc-devs/pymc3/issues/3535) but was not the true cause of it.
- Wrapped `DensityDist.rand` with `generate_samples` to make it aware of the distribution's shape. Added control flow attributes to still be able to behave as in earlier versions, and to control how to interpret the `size` parameter in the `random` callable signature. Fixes [3553](https://github.com/pymc-devs/pymc3/issues/3553)


## PyMC3 3.7 (May 29 2019)
Expand Down
1 change: 1 addition & 0 deletions docs/source/api/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ Distributions
distributions/multivariate
distributions/mixture
distributions/timeseries
distributions/utilities
29 changes: 29 additions & 0 deletions docs/source/api/distributions/utilities.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
*******************************************
Distribution utility classes and functions
*******************************************

.. currentmodule:: pymc3.distributions
.. autosummary::

Distribution
Discrete
Continuous
NoDistribution
DensityDist
TensorType

draw_values
generate_samples


.. autoclass:: Distribution
.. autoclass:: Discrete
.. autoclass:: Continuous
.. autoclass:: NoDistribution
.. autoclass:: DensityDist
:members:
.. autofunction:: TensorType

.. autofunction:: draw_values
.. autofunction:: generate_samples

261 changes: 241 additions & 20 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,29 +198,250 @@ class DensityDist(Distribution):

A distribution with the passed log density function is created.
Requires a custom random function passed as kwarg `random` to
enable sampling.
enable prior or posterior predictive sampling.

Example:
"""

def __init__(
self,
logp,
shape=(),
dtype=None,
testval=0,
random=None,
wrap_random_with_dist_shape=True,
check_shape_in_random=True,
*args,
**kwargs
):
"""
Parameters
----------
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NICE!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The one question I have -- and this is not a question about your change -- is whether we should have a default shape that is not None (i.e., that records the fact that the programmer did not specify shape).
I'm sure that the shape is most likely (), but worried about the case where it isn't, and the programmer/modeler has not realized that they need to specify the shape.
But this is a different issue, and should not delay merging this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say no, we shouldn't add a special empty shape value. Ideally, the default random variable's shape should be deduced from its parameters. We are far from this goal. At the moment the shape information must be explicitly supplied to every distribution in pymc3, not just DensityDist. The default behavior package wise is to assume a scalar distribution.
When we start to move towards a more symbolic handling of shapes, we should address whether to reserve a special empty value or not.


logp: callable
A callable that has the following signature ``logp(value)`` and
returns a theano tensor that represents the distribution's log
probability density.
shape: tuple (Optional): defaults to `()`
The shape of the distribution. The default value indicates a scalar.
If the distribution is *not* scalar-valued, the programmer should pass
a value here.
dtype: None, str (Optional)
The dtype of the distribution.
testval: number or array (Optional)
The ``testval`` of the RV's tensor that follow the ``DensityDist``
distribution.
random: None or callable (Optional)
If ``None``, no random method is attached to the ``DensityDist``
instance.
If a callable, it is used as the distribution's ``random`` method.
The behavior of this callable can be altered with the
``wrap_random_with_dist_shape`` parameter.
The supplied callable must have the following signature:
``random(size=None, **kwargs)``, where ``size`` is the number of
IID draws to take from the distribution. Any extra keyword
argument can be added as required.
wrap_random_with_dist_shape: bool (Optional)
If ``True``, the provided ``random`` callable is passed through
``generate_samples`` to make the random number generator aware of
the ``DensityDist`` instance's ``shape``.
If ``False``, it is used exactly as it was provided.
check_shape_in_random: bool (Optional)
If ``True``, the shape of the random samples generate in the
``random`` method is checked with the expected return shape. This
test is only performed if ``wrap_random_with_dist_shape is False``.
args, kwargs: (Optional)
These are passed to the parent class' ``__init__``.

Note
----
If the ``random`` method is wrapped with dist shape, what this
means is that the ``random`` callable will be wrapped with the
:func:`~genereate_samples` function. The distribution's shape will
be passed to :func:`~generate_samples` as the ``dist_shape``
parameter. Any extra ``kwargs`` provided to ``random`` will be
passed as ``not_broadcast_kwargs`` of :func:`~generate_samples`.

Examples
--------
.. code-block:: python
with pm.Model():
mu = pm.Normal('mu',0,1)
normal_dist = pm.Normal.dist(mu, 1)
pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100), random=normal_dist.random)
trace = pm.sample(100)
.. code-block:: python

with pm.Model():
mu = pm.Normal('mu',0,1)
normal_dist = pm.Normal.dist(mu, 1)
pm.DensityDist(
'density_dist',
normal_dist.logp,
observed=np.random.randn(100),
random=normal_dist.random
)
trace = pm.sample(100)

If the ``DensityDist`` is multidimensional, some care must be taken
with the supplied ``random`` method. By default, the supplied random
is wrapped by :func:`~generate_samples` to make it aware of the
multidimensional distribution's shape.
This can be prevented setting ``wrap_random_with_dist_shape=False``.
Furthermore, the ``size`` parameter is interpreted as the number of
IID draws to take from this multidimensional distribution.


.. code-block:: python

with pm.Model():
mu = pm.Normal('mu', 0 , 1)
normal_dist = pm.Normal.dist(mu, 1, shape=3)
dens = pm.DensityDist(
'density_dist',
normal_dist.logp,
observed=np.random.randn(100, 3),
shape=3,
random=normal_dist.random,
)
prior = pm.sample_prior_predictive(10)['density_dist']
assert prior.shape == (10, 100, 3)

If ``wrap_random_with_dist_shape=False``, we start to get samples of
an incorrect shape. By default, we can try to catch these situations.

"""

def __init__(self, logp, shape=(), dtype=None, testval=0, random=None, *args, **kwargs):
.. code-block:: python

with pm.Model():
mu = pm.Normal('mu', 0 , 1)
normal_dist = pm.Normal.dist(mu, 1, shape=3)
dens = pm.DensityDist(
'density_dist',
normal_dist.logp,
observed=np.random.randn(100, 3),
shape=3,
random=normal_dist.random,
wrap_random_with_dist_shape=False, # Is True by default
)
err = None
try:
prior = pm.sample_prior_predictive(10)['density_dist']
except RuntimeError as e:
err = e
assert isinstance(err, RuntimeError)

The default catching can be disabled with the
``check_shape_in_random`` parameter.


.. code-block:: python

with pm.Model():
mu = pm.Normal('mu', 0 , 1)
normal_dist = pm.Normal.dist(mu, 1, shape=3)
dens = pm.DensityDist(
'density_dist',
normal_dist.logp,
observed=np.random.randn(100, 3),
shape=3,
random=normal_dist.random,
wrap_random_with_dist_shape=False, # Is True by default
check_shape_in_random=False, # Is True by default
)
prior = pm.sample_prior_predictive(10)['density_dist']
# We get samples with an incorrect shape
assert prior.shape != (10, 100, 3)

If you use callables that work with ``scipy.stats`` rvs, you must
be aware that their ``size`` parameter is not the number of IID
samples to draw from a distribution, but the desired ``shape`` of
the returned array of samples. It is the user's responsibility to
wrap the callable to make it comply with PyMC3's interpretation
of ``size``.


.. code-block:: python

with pm.Model():
mu = pm.Normal('mu', 0 , 1)
normal_dist = pm.Normal.dist(mu, 1, shape=3)
dens = pm.DensityDist(
'density_dist',
normal_dist.logp,
observed=np.random.randn(100, 3),
shape=3,
random=stats.norm.rvs,
pymc3_size_interpretation=False, # Is True by default
)
prior = pm.sample_prior_predictive(10)['density_dist']
assert prior.shape == (10, 100, 3)

"""
if dtype is None:
dtype = theano.config.floatX
super().__init__(shape, dtype, testval, *args, **kwargs)
self.logp = logp
self.rand = random
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
self.check_shape_in_random = check_shape_in_random

def random(self, *args, **kwargs):
def random(self, point=None, size=None, **kwargs):
if self.rand is not None:
return self.rand(*args, **kwargs)
not_broadcast_kwargs = dict(point=point)
not_broadcast_kwargs.update(**kwargs)
if self.wrap_random_with_dist_shape:
size = to_tuple(size)
with _DrawValuesContextBlocker():
test_draw = generate_samples(
self.rand,
size=None,
not_broadcast_kwargs=not_broadcast_kwargs,
)
test_shape = test_draw.shape
if self.shape[:len(size)] == size:
dist_shape = size + self.shape
else:
dist_shape = self.shape
broadcast_shape = broadcast_dist_samples_shape(
[dist_shape, test_shape],
size=size
)
broadcast_shape = broadcast_shape[
:len(broadcast_shape) - len(test_shape)
]
samples = generate_samples(
self.rand,
broadcast_shape=broadcast_shape,
size=size,
not_broadcast_kwargs=not_broadcast_kwargs,
)
else:
samples = self.rand(point=point, size=size, **kwargs)
if self.check_shape_in_random:
expected_shape = (
self.shape
if size is None else
to_tuple(size) + self.shape
)
if not expected_shape == samples.shape:
raise RuntimeError(
"DensityDist encountered a shape inconsistency "
"while drawing samples using the supplied random "
"function. Was expecting to get samples of shape "
"{expected} but got {got} instead.\n"
"Whenever possible wrap_random_with_dist_shape = True "
"is recommended.\n"
"Be aware that the random callable provided as the "
"DensityDist random method cannot "
"adapt to shape changes in the distribution's "
"shape, which sometimes are necessary for sampling "
"when the model uses pymc3.Data or theano shared "
"tensors, or when the DensityDist has observed "
"values.\n"
"This check can be disabled by passing "
"check_shape_in_random=False when the DensityDist "
"is initialized.".
format(
expected=expected_shape,
got=samples.shape,
)
)
return samples
else:
raise ValueError("Distribution was not passed any random method "
"Define a custom random method and pass it as kwarg random")
Expand Down Expand Up @@ -290,17 +511,17 @@ def draw_values(params, point=None, size=None):
Draw (fix) parameter values. Handles a number of cases:

1) The parameter is a scalar
2) The parameter is an *RV
2) The parameter is an RV

a) parameter can be fixed to the value in the point
b) parameter can be fixed by sampling from the *RV
b) parameter can be fixed by sampling from the RV
c) parameter can be fixed using tag.test_value (last resort)

3) The parameter is a tensor variable/constant. Can be evaluated using
theano.function, but a variable may contain nodes which

a) are named parameters in the point
b) are *RVs with a random method
b) are RVs with a random method
"""
# Get fast drawable values (i.e. things in point or numbers, arrays,
# constants or shares, or things that were already drawn in related
Expand Down Expand Up @@ -646,13 +867,13 @@ def generate_samples(generator, *args, **kwargs):
generator : function
Function to generate the random samples. The function is
expected take parameters for generating samples and
a keyword argument `size` which determines the shape
a keyword argument ``size`` which determines the shape
of the samples.
The *args and **kwargs (stripped of the keywords below) will be
The args and kwargs (stripped of the keywords below) will be
passed to the generator function.

keyword arguments
~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~

dist_shape : int or tuple of int
The shape of the random variable (i.e., the shape attribute).
Expand All @@ -666,9 +887,9 @@ def generate_samples(generator, *args, **kwargs):
the shape of the probabilities in the Categorical distribution.
not_broadcast_kwargs: dict or None
Key word argument dictionary to provide to the random generator, which
must not be broadcasted with the rest of the *args and **kwargs.
must not be broadcasted with the rest of the args and kwargs.

Any remaining *args and **kwargs are passed on to the generator function.
Any remaining args and kwargs are passed on to the generator function.
"""
dist_shape = kwargs.pop('dist_shape', ())
one_d = _is_one_d(dist_shape)
Expand Down
Loading