Skip to content

Commit

Permalink
Create helper wrapper around Aeppl IntervalTransform
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Mar 17, 2022
1 parent 345d98d commit 521df93
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 34 deletions.
20 changes: 11 additions & 9 deletions docs/source/api/distributions/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,30 @@ Transform instances are the entities that should be used in the

simplex
logodds
interval
log_exp_m1
ordered
log
sum_to_1
circular

Transform Composition Classes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated

Chain
CholeskyCovPacked

Specific Transform Classes
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated

CholeskyCovPacked
Interval
LogExpM1
Ordered
SumTo1


Transform Composition Classes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated

Chain
4 changes: 2 additions & 2 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def transform_params(*args):

return lower, upper

return transforms.interval(transform_params)
return transforms.Interval(bounds_fn=transform_params)


def assert_negative_support(var, label, distname, value=-1e-6):
Expand Down Expand Up @@ -3796,7 +3796,7 @@ def transform_params(*params):
_, _, _, x_points, _, _ = params
return floatX(x_points[0]), floatX(x_points[-1])

kwargs["transform"] = transforms.interval(transform_params)
kwargs["transform"] = transforms.Interval(bounds_fn=transform_params)
return super().__new__(cls, *args, **kwargs)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
rv_size_is_none,
to_tuple,
)
from pymc.distributions.transforms import interval
from pymc.distributions.transforms import Interval
from pymc.math import kron_diag, kron_dot
from pymc.util import UNSET, check_dist_not_registered

Expand Down Expand Up @@ -1571,7 +1571,7 @@ class LKJCorr(BoundedContinuous):
def __new__(cls, *args, **kwargs):
transform = kwargs.get("transform", UNSET)
if transform is UNSET:
kwargs["transform"] = interval(lambda *args: (floatX(-1.0), floatX(1.0)))
kwargs["transform"] = Interval(floatX(-1.0), floatX(1.0))
return super().__new__(cls, *args, **kwargs)

@classmethod
Expand Down
88 changes: 83 additions & 5 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import aesara.tensor as at
import numpy as np

from aeppl.transforms import (
CircularTransform,
Expand All @@ -27,7 +28,7 @@
"RVTransform",
"simplex",
"logodds",
"interval",
"Interval",
"log_exp_m1",
"ordered",
"log",
Expand Down Expand Up @@ -174,10 +175,87 @@ def log_jac_det(self, value, *inputs):
Instantiation of :class:`aeppl.transforms.LogOddsTransform`
for use in the ``transform`` argument of a random variable."""

interval = IntervalTransform
interval.__doc__ = """
Instantiation of :class:`aeppl.transforms.IntervalTransform`
for use in the ``transform`` argument of a random variable."""

class Interval(IntervalTransform):
"""Wrapper around :class:`aeppl.transforms.IntervalTransform` for use in the
``transform`` argument of a random variable.
Parameters
----------
lower: int, float, or None
Lower bound of the interval transform. Must be a constant value. If ``None``, the
interval is not bounded below.
upper: int, float or None
Upper bound of the interval transfrom. Must be a finite value. If ``None``, the
interval is not bounded above.
bounds_fn: callable
Alternative to lower and upper. Must return a tuple of lower and upper bounds
as a symbolic function of the respective distribution inputs. If lower or
upper is ``None``, the interval is unbounded on that edge.
.. warning:: Expressions returned by `bounds_fn` should depend only on the distribution
inputs or other constants.
Expressions that depend on other symbolic variables, including nonlocal distribution
variables defined in the model context will likely break sampling. This can happen
silently, and is difficult to debug
Examples
--------
.. code-block:: python
# Create an interval transform between -1 and +1
with pm.Model():
x = pm.Normal("x", transform=pm.transforms.Interval(lower=-1, upper=1))
.. code-block:: python
# Create an interval transform between -1 and +1 using a callable
def get_bounds(rng, size, dtype, loc, scale):
return 0, None
with pm.Model():
x = pm.Normal("x", transform=pm.transforms.Interval(bouns_fn=get_bounds))
.. code-block:: python
# Create a lower bounded interval transform based on a distribution parameter
def get_bounds(rng, size, dtype, loc, scale):
return loc, None
interval = pm.transforms.Interval(bounds_fn=get_bounds)
with pm.Model():
loc = pm.Normal("loc")
x = pm.Normal("x", mu=loc, sigma=2, transform=interval)
"""

def __init__(self, lower=None, upper=None, *, bounds_fn=None):
if bounds_fn is None:
try:
bounds = tuple(
None if bound is None else at.constant(bound, ndim=0).data
for bound in (lower, upper)
)
except (ValueError, TypeError):
raise ValueError(
"Interval bounds must be constant values. If you need expressions that "
"depend on symbolic variables use `args_fn`"
)

lower, upper = (
None if (bound is None or np.isinf(bound)) else bound for bound in bounds
)

if lower is None and upper is None:
raise ValueError("Lower and upper interval bounds cannot both be None")

def bounds_fn(*rv_inputs):
return lower, upper

super().__init__(args_fn=bounds_fn)


log_exp_m1 = LogExpM1()
log_exp_m1.__doc__ = """
Expand Down
2 changes: 1 addition & 1 deletion pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2710,7 +2710,7 @@ def test_arguments_checks(self):
with pm.Model() as m:
x = pm.Poisson.dist(0.5)
with pytest.raises(ValueError, match=msg):
pm.Bound("bound", x, transform=pm.transforms.interval)
pm.Bound("bound", x, transform=pm.transforms.log)

msg = "Given dims do not exist in model coordinates."
with pm.Model() as m:
Expand Down
31 changes: 16 additions & 15 deletions pymc/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,7 @@ def test_logodds():


def test_lowerbound():
def transform_params(*inputs):
return 0.0, None

trans = tr.interval(transform_params)
trans = tr.Interval(0.0, None)
check_transform(trans, Rplusbig)

check_jacobian_det(trans, Rplusbig, elemwise=True)
Expand All @@ -191,10 +188,7 @@ def transform_params(*inputs):


def test_upperbound():
def transform_params(*inputs):
return None, 0.0

trans = tr.interval(transform_params)
trans = tr.Interval(None, 0.0)
check_transform(trans, Rminusbig)

check_jacobian_det(trans, Rminusbig, elemwise=True)
Expand All @@ -208,10 +202,7 @@ def test_interval():
for a, b in [(-4, 5.5), (0.1, 0.7), (-10, 4.3)]:
domain = Unit * np.float64(b - a) + np.float64(a)

def transform_params(z=a, y=b):
return z, y

trans = tr.interval(transform_params)
trans = tr.Interval(a, b)
check_transform(trans, domain)

check_jacobian_det(trans, domain, elemwise=True)
Expand Down Expand Up @@ -375,7 +366,7 @@ def transform_params(*inputs):
upper = at.as_tensor_variable(upper) if upper is not None else None
return lower, upper

interval = tr.interval(transform_params)
interval = tr.Interval(bounds_fn=transform_params)
model = self.build_model(
pm.Uniform, {"lower": lower, "upper": upper}, size=size, transform=interval
)
Expand All @@ -396,7 +387,7 @@ def transform_params(*inputs):
upper = at.as_tensor_variable(upper) if upper is not None else None
return lower, upper

interval = tr.interval(transform_params)
interval = tr.Interval(bounds_fn=transform_params)
model = self.build_model(
pm.Triangular, {"lower": lower, "c": c, "upper": upper}, size=size, transform=interval
)
Expand Down Expand Up @@ -491,7 +482,7 @@ def transform_params(*inputs):
upper = at.as_tensor_variable(upper) if upper is not None else None
return lower, upper

interval = tr.interval(transform_params)
interval = tr.Interval(bounds_fn=transform_params)

initval = np.sort(np.abs(np.random.rand(*size)))
model = self.build_model(
Expand Down Expand Up @@ -556,3 +547,13 @@ def test_triangular_transform():
transform = x.tag.value_var.tag.transform
assert np.isclose(transform.backward(-np.inf, *x.owner.inputs).eval(), 0)
assert np.isclose(transform.backward(np.inf, *x.owner.inputs).eval(), 2)


def test_interval_transform_raises():
with pytest.raises(ValueError, match="Lower and upper interval bounds cannot both be None"):
tr.Interval(None, None)

with pytest.raises(ValueError, match="Interval bounds must be constant values"):
tr.Interval(at.constant(5) + 1, None)

assert tr.Interval(at.constant(5), None)

0 comments on commit 521df93

Please sign in to comment.