Skip to content

Commit

Permalink
Small fixes to pm.Bound
Browse files Browse the repository at this point in the history
* Fix invalid code example in docstrings
* Rename distribution parameter to dist
* Use `check_dist_not_registered` helper
  • Loading branch information
ricardoV94 committed Jan 4, 2022
1 parent 36e023c commit 766ac8d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 36 deletions.
60 changes: 25 additions & 35 deletions pymc/distributions/bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pymc.distributions.logprob import logp
from pymc.distributions.shape_utils import to_tuple
from pymc.model import modelcontext
from pymc.util import check_dist_not_registered

__all__ = ["Bound"]

Expand Down Expand Up @@ -144,8 +145,9 @@ class Bound:
Parameters
----------
distribution: pymc distribution
Distribution to be transformed into a bounded distribution.
dist: PyMC unnamed distribution
Distribution to be transformed into a bounded distribution created via the
`.dist()` API.
lower: float or array like, optional
Lower bound of the distribution.
upper: float or array like, optional
Expand All @@ -156,15 +158,15 @@ class Bound:
.. code-block:: python
with pm.Model():
normal_dist = Normal.dist(mu=0.0, sigma=1.0, initval=-0.5)
negative_normal = pm.Bound(normal_dist, upper=0.0)
normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
negative_normal = pm.Bound("negative_normal", normal_dist, upper=0.0)
"""

def __new__(
cls,
name,
distribution,
dist,
lower=None,
upper=None,
size=None,
Expand All @@ -174,7 +176,7 @@ def __new__(
**kwargs,
):

cls._argument_checks(distribution, **kwargs)
cls._argument_checks(dist, **kwargs)

if dims is not None:
model = modelcontext(None)
Expand All @@ -185,12 +187,12 @@ def __new__(
raise ValueError("Given dims do not exist in model coordinates.")

lower, upper, initval = cls._set_values(lower, upper, size, shape, initval)
distribution.tag.ignore_logprob = True
dist.tag.ignore_logprob = True

if isinstance(distribution.owner.op, Continuous):
if isinstance(dist.owner.op, Continuous):
res = _ContinuousBounded(
name,
[distribution, lower, upper],
[dist, lower, upper],
initval=floatX(initval),
size=size,
shape=shape,
Expand All @@ -199,7 +201,7 @@ def __new__(
else:
res = _DiscreteBounded(
name,
[distribution, lower, upper],
[dist, lower, upper],
initval=intX(initval),
size=size,
shape=shape,
Expand All @@ -210,28 +212,28 @@ def __new__(
@classmethod
def dist(
cls,
distribution,
dist,
lower=None,
upper=None,
size=None,
shape=None,
**kwargs,
):

cls._argument_checks(distribution, **kwargs)
cls._argument_checks(dist, **kwargs)
lower, upper, initval = cls._set_values(lower, upper, size, shape, initval=None)
distribution.tag.ignore_logprob = True
if isinstance(distribution.owner.op, Continuous):
dist.tag.ignore_logprob = True
if isinstance(dist.owner.op, Continuous):
res = _ContinuousBounded.dist(
[distribution, lower, upper],
[dist, lower, upper],
size=size,
shape=shape,
**kwargs,
)
res.tag.test_value = floatX(initval)
else:
res = _DiscreteBounded.dist(
[distribution, lower, upper],
[dist, lower, upper],
size=size,
shape=shape,
**kwargs,
Expand All @@ -240,7 +242,7 @@ def dist(
return res

@classmethod
def _argument_checks(cls, distribution, **kwargs):
def _argument_checks(cls, dist, **kwargs):
if "observed" in kwargs:
raise ValueError(
"Observed Bound distributions are not supported. "
Expand All @@ -249,34 +251,22 @@ def _argument_checks(cls, distribution, **kwargs):
"with the cumulative probability function."
)

if not isinstance(distribution, TensorVariable):
if not isinstance(dist, TensorVariable):
raise ValueError(
"Passing a distribution class to `Bound` is no longer supported.\n"
"Please pass the output of a distribution instantiated via the "
"`.dist()` API such as:\n"
'`pm.Bound("bound", pm.Normal.dist(0, 1), lower=0)`'
)

try:
model = modelcontext(None)
except TypeError:
pass
else:
if distribution in model.basic_RVs:
raise ValueError(
f"The distribution passed into `Bound` was already registered "
f"in the current model.\nYou should pass an unregistered "
f"(unnamed) distribution created via the `.dist()` API, such as:\n"
f'`pm.Bound("bound", pm.Normal.dist(0, 1), lower=0)`'
)

if distribution.owner.op.ndim_supp != 0:
check_dist_not_registered(dist)

if dist.owner.op.ndim_supp != 0:
raise NotImplementedError("Bounding of MultiVariate RVs is not yet supported.")

if not isinstance(distribution.owner.op, (Discrete, Continuous)):
if not isinstance(dist.owner.op, (Discrete, Continuous)):
raise ValueError(
f"`distribution` {distribution} must be a Discrete or Continuous"
" distribution subclass"
f"`distribution` {dist} must be a Discrete or Continuous" " distribution subclass"
)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2701,7 +2701,7 @@ def test_arguments_checks(self):
with pytest.raises(ValueError, match=msg):
pm.Bound("bound", x, dims="random_dims")

msg = "The distribution passed into `Bound` was already registered"
msg = "The dist x was already registered in the current model"
with pm.Model() as m:
x = pm.Normal("x", 0, 1)
with pytest.raises(ValueError, match=msg):
Expand Down

0 comments on commit 766ac8d

Please sign in to comment.