diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py index bb91ea23c2a..2aff0c18c03 100644 --- a/pymc/distributions/__init__.py +++ b/pymc/distributions/__init__.py @@ -21,6 +21,7 @@ ) from pymc.distributions.bound import Bound +from pymc.distributions.censored import Censored from pymc.distributions.continuous import ( AsymmetricLaplace, Beta, @@ -187,6 +188,7 @@ "Rice", "Moyal", "Simulator", + "Censored", "CAR", "PolyaGamma", "logpt", diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py new file mode 100644 index 00000000000..14b92f30359 --- /dev/null +++ b/pymc/distributions/censored.py @@ -0,0 +1,146 @@ +# Copyright 2020 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import aesara.tensor as at +import numpy as np + +from aesara.scalar import Clip +from aesara.tensor import TensorVariable +from aesara.tensor.random.op import RandomVariable + +from pymc.distributions.distribution import SymbolicDistribution, _get_moment +from pymc.util import check_dist_not_registered + + +class Censored(SymbolicDistribution): + r""" + Censored distribution + + The pdf of a censored distribution is + + .. math:: + + \begin{cases} + 0 & \text{for } x < lower, \\ + \text{CDF}(lower, dist) & \text{for } x = lower, \\ + \text{PDF}(x, dist) & \text{for } lower < x < upper, \\ + 1-\text{CDF}(upper, dist) & \text {for} x = upper, \\ + 0 & \text{for } x > upper, + \end{cases} + + + Parameters + ---------- + dist: PyMC unnamed distribution + PyMC distribution created via the .dit() API, which will be censored. This + distribution must be univariate and have a logcdf method implemeted. + lower: float or None + Lower (left) censoring point. If `None` the distribution will not be left censored + upper: float or None + Upper (right) censoring point. If `None`, the distribution will not be right censored. + + + Examples + -------- + .. code-block:: python + + with pm.Model(): + normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0) + censored_normal = pm.Censored("censored_normal", normal_dist, lower=-1, upper=1) + """ + + @classmethod + def dist(cls, dist, lower, upper, **kwargs): + if not isinstance(dist, TensorVariable) or not isinstance(dist.owner.op, RandomVariable): + raise ValueError( + f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}" + ) + if dist.owner.op.ndim_supp > 0: + raise NotImplementedError( + "Censoring of multivariate distributions has not been implemented yet" + ) + check_dist_not_registered(dist) + return super().dist([dist, lower, upper], **kwargs) + + @classmethod + def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None): + if lower is None: + lower = at.constant(-np.inf) + if upper is None: + upper = at.constant(np.inf) + + # Censoring is achieved by clipping the base distribution between lower and upper + rv_out = at.clip(dist, lower, upper) + + # Reference nodes to facilitate identification in other classmethods, without + # worring about possible dimshuffles + rv_out.tag.dist = dist + rv_out.tag.lower = lower + rv_out.tag.upper = upper + + if size is not None: + rv_out = cls.change_size(rv_out, size) + if rngs is not None: + rv_out = cls.change_rngs(rv_out, rngs) + + return rv_out + + @classmethod + def ndim_supp(cls, *dist_params): + return 0 + + @classmethod + def change_size(cls, rv, new_size): + dist_node = rv.tag.dist.owner + lower = rv.tag.lower + upper = rv.tag.upper + rng, old_size, dtype, *dist_params = dist_node.inputs + new_dist = dist_node.op.make_node(rng, new_size, dtype, *dist_params).default_output() + return cls.rv_op(new_dist, lower, upper) + + @classmethod + def change_rngs(cls, rv, new_rngs): + (new_rng,) = new_rngs + dist_node = rv.tag.dist.owner + lower = rv.tag.lower + upper = rv.tag.upper + olg_rng, size, dtype, *dist_params = dist_node.inputs + new_dist = dist_node.op.make_node(new_rng, size, dtype, *dist_params).default_output() + return cls.rv_op(new_dist, lower, upper) + + @classmethod + def graph_rvs(cls, rv): + return (rv.tag.dist,) + + +@_get_moment.register(Clip) +def get_moment_censored(op, rv, dist, lower, upper): + moment = at.switch( + at.eq(lower, -np.inf), + at.switch( + at.isinf(upper), + # lower = -inf, upper = inf + 0, + # lower = -inf, upper = x + upper - 1, + ), + at.switch( + at.eq(upper, np.inf), + # lower = x, upper = inf + lower + 1, + # lower = x, upper = x + (lower + upper) / 2, + ), + ) + moment = at.full_like(dist, moment) + return moment diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index cf9d87eb0b3..86a1c1e7fbd 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -26,6 +26,7 @@ from aeppl.logprob import _logcdf, _logprob from aesara import tensor as at from aesara.tensor.basic import as_tensor_variable +from aesara.tensor.elemwise import Elemwise from aesara.tensor.random.op import RandomVariable from aesara.tensor.random.var import RandomStateSharedVariable from aesara.tensor.var import TensorVariable @@ -628,6 +629,12 @@ def get_moment(rv: TensorVariable) -> TensorVariable: return _get_moment(rv.owner.op, rv, *rv.owner.inputs).astype(rv.dtype) +@_get_moment.register(Elemwise) +def _get_moment_elemwise(op, rv, *dist_params): + """For Elemwise Ops, dispatch on respective scalar_op""" + return _get_moment(op.scalar_op, rv, *dist_params) + + class Discrete(Distribution): """Base class for discrete distributions""" diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index 0095638e576..efe42006f45 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -3275,3 +3275,73 @@ def logp(value, mu): ).shape == to_tuple(size) ) + + +class TestCensored: + @pytest.mark.parametrize("censored", (False, True)) + def test_censored_workflow(self, censored): + # Based on pymc-examples/censored_data + rng = np.random.default_rng(1234) + size = 500 + true_mu = 13.0 + true_sigma = 5.0 + + # Set censoring limits + low = 3.0 + high = 16.0 + + # Draw censored samples + data = rng.normal(true_mu, true_sigma, size) + data[data <= low] = low + data[data >= high] = high + + with pm.Model(rng_seeder=17092021) as m: + mu = pm.Normal( + "mu", + mu=((high - low) / 2) + low, + sigma=(high - low) / 2.0, + initval="moment", + ) + sigma = pm.HalfNormal("sigma", sigma=(high - low) / 2.0, initval="moment") + observed = pm.Censored( + "observed", + pm.Normal.dist(mu=mu, sigma=sigma), + lower=low if censored else None, + upper=high if censored else None, + observed=data, + ) + + prior_pred = pm.sample_prior_predictive() + posterior = pm.sample(tune=500, draws=500) + posterior_pred = pm.sample_posterior_predictive(posterior) + + expected = True if censored else False + assert (9 < prior_pred.prior_predictive.mean() < 10) == expected + assert (13 < posterior.posterior["mu"].mean() < 14) == expected + assert (4.5 < posterior.posterior["sigma"].mean() < 5.5) == expected + assert (12 < posterior_pred.posterior_predictive.mean() < 13) == expected + + def test_censored_invalid_dist(self): + with pm.Model(): + invalid_dist = pm.Normal + with pytest.raises( + ValueError, + match=r"Censoring dist must be a distribution created via the", + ): + x = pm.Censored("x", invalid_dist, lower=None, upper=None) + + with pm.Model(): + mv_dist = pm.Dirichlet.dist(a=[1, 1, 1]) + with pytest.raises( + NotImplementedError, + match="Censoring of multivariate distributions has not been implemented yet", + ): + x = pm.Censored("x", mv_dist, lower=None, upper=None) + + with pm.Model(): + registered_dist = pm.Normal("dist") + with pytest.raises( + ValueError, + match="The dist dist was already registered in the current model", + ): + x = pm.Censored("x", registered_dist, lower=None, upper=None) diff --git a/pymc/util.py b/pymc/util.py index 92f2462e1fb..f2183688540 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -332,3 +332,20 @@ def cf(self): return cf return cachedmethod(self_cache_fn(f.__name__), key=hash_key)(f) + + +def check_dist_not_registered(dist, model=None): + """Check that a dist is not registered in the model already""" + from pymc.model import modelcontext + + try: + model = modelcontext(None) + except TypeError: + pass + else: + if dist in model.basic_RVs: + raise ValueError( + f"The dist {dist} was already registered in the current model.\n" + f"You should use an unregistered (unnamed) distribution created via " + f"the `.dist()` API instead, such as:\n`dist=pm.Normal.dist(0, 1)`" + )