Skip to content

Commit

Permalink
Add Censored distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jan 4, 2022
1 parent e5f49d9 commit c54207b
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)

from pymc.distributions.bound import Bound
from pymc.distributions.censored import Censored
from pymc.distributions.continuous import (
AsymmetricLaplace,
Beta,
Expand Down Expand Up @@ -187,6 +188,7 @@
"Rice",
"Moyal",
"Simulator",
"Censored",
"CAR",
"PolyaGamma",
"logpt",
Expand Down
146 changes: 146 additions & 0 deletions pymc/distributions/censored.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import aesara.tensor as at
import numpy as np

from aesara.scalar import Clip
from aesara.tensor import TensorVariable
from aesara.tensor.random.op import RandomVariable

from pymc.distributions.distribution import SymbolicDistribution, _get_moment
from pymc.util import check_dist_not_registered


class Censored(SymbolicDistribution):
r"""
Censored distribution
The pdf of a censored distribution is
.. math::
\begin{cases}
0 & \text{for } x < lower, \\
\text{CDF}(lower, dist) & \text{for } x = lower, \\
\text{PDF}(x, dist) & \text{for } lower < x < upper, \\
1-\text{CDF}(upper, dist) & \text {for} x = upper, \\
0 & \text{for } x > upper,
\end{cases}
Parameters
----------
dist: PyMC unnamed distribution
PyMC distribution created via the .dit() API, which will be censored. This
distribution must be univariate and have a logcdf method implemeted.
lower: float or None
Lower (left) censoring point. If `None` the distribution will not be left censored
upper: float or None
Upper (right) censoring point. If `None`, the distribution will not be right censored.
Examples
--------
.. code-block:: python
with pm.Model():
normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
censored_normal = pm.Censored("censored_normal", normal_dist, lower=-1, upper=1)
"""

@classmethod
def dist(cls, dist, lower, upper, **kwargs):
if not isinstance(dist, TensorVariable) or not isinstance(dist.owner.op, RandomVariable):
raise ValueError(
f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"
)
if dist.owner.op.ndim_supp > 0:
raise NotImplementedError(
"Censoring of multivariate distributions has not been implemented yet"
)
check_dist_not_registered(dist)
return super().dist([dist, lower, upper], **kwargs)

@classmethod
def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
if lower is None:
lower = at.constant(-np.inf)
if upper is None:
upper = at.constant(np.inf)

# Censoring is achieved by clipping the base distribution between lower and upper
rv_out = at.clip(dist, lower, upper)

# Reference nodes to facilitate identification in other classmethods, without
# worring about possible dimshuffles
rv_out.tag.dist = dist
rv_out.tag.lower = lower
rv_out.tag.upper = upper

if size is not None:
rv_out = cls.change_size(rv_out, size)
if rngs is not None:
rv_out = cls.change_rngs(rv_out, rngs)

return rv_out

@classmethod
def ndim_supp(cls, *dist_params):
return 0

@classmethod
def change_size(cls, rv, new_size):
dist_node = rv.tag.dist.owner
lower = rv.tag.lower
upper = rv.tag.upper
rng, old_size, dtype, *dist_params = dist_node.inputs
new_dist = dist_node.op.make_node(rng, new_size, dtype, *dist_params).default_output()
return cls.rv_op(new_dist, lower, upper)

@classmethod
def change_rngs(cls, rv, new_rngs):
(new_rng,) = new_rngs
dist_node = rv.tag.dist.owner
lower = rv.tag.lower
upper = rv.tag.upper
olg_rng, size, dtype, *dist_params = dist_node.inputs
new_dist = dist_node.op.make_node(new_rng, size, dtype, *dist_params).default_output()
return cls.rv_op(new_dist, lower, upper)

@classmethod
def graph_rvs(cls, rv):
return (rv.tag.dist,)


@_get_moment.register(Clip)
def get_moment_censored(op, rv, dist, lower, upper):
moment = at.switch(
at.eq(lower, -np.inf),
at.switch(
at.isinf(upper),
# lower = -inf, upper = inf
0,
# lower = -inf, upper = x
upper - 1,
),
at.switch(
at.eq(upper, np.inf),
# lower = x, upper = inf
lower + 1,
# lower = x, upper = x
(lower + upper) / 2,
),
)
moment = at.full_like(dist, moment)
return moment
7 changes: 7 additions & 0 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from aeppl.logprob import _logcdf, _logprob
from aesara import tensor as at
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.var import RandomStateSharedVariable
from aesara.tensor.var import TensorVariable
Expand Down Expand Up @@ -628,6 +629,12 @@ def get_moment(rv: TensorVariable) -> TensorVariable:
return _get_moment(rv.owner.op, rv, *rv.owner.inputs).astype(rv.dtype)


@_get_moment.register(Elemwise)
def _get_moment_elemwise(op, rv, *dist_params):
"""For Elemwise Ops, dispatch on respective scalar_op"""
return _get_moment(op.scalar_op, rv, *dist_params)


class Discrete(Distribution):
"""Base class for discrete distributions"""

Expand Down
70 changes: 70 additions & 0 deletions pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3275,3 +3275,73 @@ def logp(value, mu):
).shape
== to_tuple(size)
)


class TestCensored:
@pytest.mark.parametrize("censored", (False, True))
def test_censored_workflow(self, censored):
# Based on pymc-examples/censored_data
rng = np.random.default_rng(1234)
size = 500
true_mu = 13.0
true_sigma = 5.0

# Set censoring limits
low = 3.0
high = 16.0

# Draw censored samples
data = rng.normal(true_mu, true_sigma, size)
data[data <= low] = low
data[data >= high] = high

with pm.Model(rng_seeder=17092021) as m:
mu = pm.Normal(
"mu",
mu=((high - low) / 2) + low,
sigma=(high - low) / 2.0,
initval="moment",
)
sigma = pm.HalfNormal("sigma", sigma=(high - low) / 2.0, initval="moment")
observed = pm.Censored(
"observed",
pm.Normal.dist(mu=mu, sigma=sigma),
lower=low if censored else None,
upper=high if censored else None,
observed=data,
)

prior_pred = pm.sample_prior_predictive()
posterior = pm.sample(tune=500, draws=500)
posterior_pred = pm.sample_posterior_predictive(posterior)

expected = True if censored else False
assert (9 < prior_pred.prior_predictive.mean() < 10) == expected
assert (13 < posterior.posterior["mu"].mean() < 14) == expected
assert (4.5 < posterior.posterior["sigma"].mean() < 5.5) == expected
assert (12 < posterior_pred.posterior_predictive.mean() < 13) == expected

def test_censored_invalid_dist(self):
with pm.Model():
invalid_dist = pm.Normal
with pytest.raises(
ValueError,
match=r"Censoring dist must be a distribution created via the",
):
x = pm.Censored("x", invalid_dist, lower=None, upper=None)

with pm.Model():
mv_dist = pm.Dirichlet.dist(a=[1, 1, 1])
with pytest.raises(
NotImplementedError,
match="Censoring of multivariate distributions has not been implemented yet",
):
x = pm.Censored("x", mv_dist, lower=None, upper=None)

with pm.Model():
registered_dist = pm.Normal("dist")
with pytest.raises(
ValueError,
match="The dist dist was already registered in the current model",
):
x = pm.Censored("x", registered_dist, lower=None, upper=None)
17 changes: 17 additions & 0 deletions pymc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,20 @@ def cf(self):
return cf

return cachedmethod(self_cache_fn(f.__name__), key=hash_key)(f)


def check_dist_not_registered(dist, model=None):
"""Check that a dist is not registered in the model already"""
from pymc.model import modelcontext

try:
model = modelcontext(None)
except TypeError:
pass
else:
if dist in model.basic_RVs:
raise ValueError(
f"The dist {dist} was already registered in the current model.\n"
f"You should use an unregistered (unnamed) distribution created via "
f"the `.dist()` API instead, such as:\n`dist=pm.Normal.dist(0, 1)`"
)

0 comments on commit c54207b

Please sign in to comment.