diff --git a/docs/source/api/distributions/transforms.rst b/docs/source/api/distributions/transforms.rst index 8a08baca4fd..ffccd979889 100644 --- a/docs/source/api/distributions/transforms.rst +++ b/docs/source/api/distributions/transforms.rst @@ -15,21 +15,12 @@ Transform instances are the entities that should be used in the simplex logodds - interval log_exp_m1 ordered log sum_to_1 circular -Transform Composition Classes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autosummary:: - :toctree: generated - - Chain - CholeskyCovPacked Specific Transform Classes ~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -37,6 +28,17 @@ Specific Transform Classes .. autosummary:: :toctree: generated + CholeskyCovPacked + Interval LogExpM1 Ordered SumTo1 + + +Transform Composition Classes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: generated + + Chain diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 03c4a5a2e36..567db3ef5e0 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -195,7 +195,7 @@ def transform_params(*args): return lower, upper - return transforms.interval(transform_params) + return transforms.Interval(bounds_fn=transform_params) def assert_negative_support(var, label, distname, value=-1e-6): @@ -3796,7 +3796,7 @@ def transform_params(*params): _, _, _, x_points, _, _ = params return floatX(x_points[0]), floatX(x_points[-1]) - kwargs["transform"] = transforms.interval(transform_params) + kwargs["transform"] = transforms.Interval(bounds_fn=transform_params) return super().__new__(cls, *args, **kwargs) @classmethod diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index de500a33533..2b997dc0064 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -62,7 +62,7 @@ rv_size_is_none, to_tuple, ) -from pymc.distributions.transforms import interval +from pymc.distributions.transforms import Interval from pymc.math import kron_diag, kron_dot from pymc.util import UNSET, check_dist_not_registered @@ -1571,7 +1571,7 @@ class LKJCorr(BoundedContinuous): def __new__(cls, *args, **kwargs): transform = kwargs.get("transform", UNSET) if transform is UNSET: - kwargs["transform"] = interval(lambda *args: (floatX(-1.0), floatX(1.0))) + kwargs["transform"] = Interval(floatX(-1.0), floatX(1.0)) return super().__new__(cls, *args, **kwargs) @classmethod diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index d02fa210feb..b21ba41f0a5 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -13,6 +13,7 @@ # limitations under the License. import aesara.tensor as at +import numpy as np from aeppl.transforms import ( CircularTransform, @@ -27,7 +28,7 @@ "RVTransform", "simplex", "logodds", - "interval", + "Interval", "log_exp_m1", "ordered", "log", @@ -174,10 +175,87 @@ def log_jac_det(self, value, *inputs): Instantiation of :class:`aeppl.transforms.LogOddsTransform` for use in the ``transform`` argument of a random variable.""" -interval = IntervalTransform -interval.__doc__ = """ -Instantiation of :class:`aeppl.transforms.IntervalTransform` -for use in the ``transform`` argument of a random variable.""" + +class Interval(IntervalTransform): + """Wrapper around :class:`aeppl.transforms.IntervalTransform` for use in the + ``transform`` argument of a random variable. + + Parameters + ---------- + lower: int, float, or None + Lower bound of the interval transform. Must be a constant value. If ``None``, the + interval is not bounded below. + upper: int, float or None + Upper bound of the interval transfrom. Must be a finite value. If ``None``, the + interval is not bounded above. + bounds_fn: callable + Alternative to lower and upper. Must return a tuple of lower and upper bounds + as a symbolic function of the respective distribution inputs. If lower or + upper is ``None``, the interval is unbounded on that edge. + + .. warning:: Expressions returned by `bounds_fn` should depend only on the distribution + inputs or other constants. + + Expressions that depend on other symbolic variables, including nonlocal distribution + variables defined in the model context will likely break sampling. This can happen + silently, and is difficult to debug + + + Examples + -------- + .. code-block:: python + + # Create an interval transform between -1 and +1 + with pm.Model(): + x = pm.Normal("x", transform=pm.transforms.Interval(lower=-1, upper=1)) + + .. code-block:: python + + # Create an interval transform between -1 and +1 using a callable + def get_bounds(rng, size, dtype, loc, scale): + return 0, None + + with pm.Model(): + x = pm.Normal("x", transform=pm.transforms.Interval(bouns_fn=get_bounds)) + + .. code-block:: python + + # Create a lower bounded interval transform based on a distribution parameter + def get_bounds(rng, size, dtype, loc, scale): + return loc, None + + interval = pm.transforms.Interval(bounds_fn=get_bounds) + + with pm.Model(): + loc = pm.Normal("loc") + x = pm.Normal("x", mu=loc, sigma=2, transform=interval) + """ + + def __init__(self, lower=None, upper=None, *, bounds_fn=None): + if bounds_fn is None: + try: + bounds = tuple( + None if bound is None else at.constant(bound, ndim=0).data + for bound in (lower, upper) + ) + except (ValueError, TypeError): + raise ValueError( + "Interval bounds must be constant values. If you need expressions that " + "depend on symbolic variables use `args_fn`" + ) + + lower, upper = ( + None if (bound is None or np.isinf(bound)) else bound for bound in bounds + ) + + if lower is None and upper is None: + raise ValueError("Lower and upper interval bounds cannot both be None") + + def bounds_fn(*rv_inputs): + return lower, upper + + super().__init__(args_fn=bounds_fn) + log_exp_m1 = LogExpM1() log_exp_m1.__doc__ = """ diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index 0e4e65baf40..d47d276038b 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -2710,7 +2710,7 @@ def test_arguments_checks(self): with pm.Model() as m: x = pm.Poisson.dist(0.5) with pytest.raises(ValueError, match=msg): - pm.Bound("bound", x, transform=pm.transforms.interval) + pm.Bound("bound", x, transform=pm.transforms.log) msg = "Given dims do not exist in model coordinates." with pm.Model() as m: diff --git a/pymc/tests/test_transforms.py b/pymc/tests/test_transforms.py index 6931b3d668c..04e1a588b24 100644 --- a/pymc/tests/test_transforms.py +++ b/pymc/tests/test_transforms.py @@ -177,10 +177,7 @@ def test_logodds(): def test_lowerbound(): - def transform_params(*inputs): - return 0.0, None - - trans = tr.interval(transform_params) + trans = tr.Interval(0.0, None) check_transform(trans, Rplusbig) check_jacobian_det(trans, Rplusbig, elemwise=True) @@ -191,10 +188,7 @@ def transform_params(*inputs): def test_upperbound(): - def transform_params(*inputs): - return None, 0.0 - - trans = tr.interval(transform_params) + trans = tr.Interval(None, 0.0) check_transform(trans, Rminusbig) check_jacobian_det(trans, Rminusbig, elemwise=True) @@ -208,10 +202,7 @@ def test_interval(): for a, b in [(-4, 5.5), (0.1, 0.7), (-10, 4.3)]: domain = Unit * np.float64(b - a) + np.float64(a) - def transform_params(z=a, y=b): - return z, y - - trans = tr.interval(transform_params) + trans = tr.Interval(a, b) check_transform(trans, domain) check_jacobian_det(trans, domain, elemwise=True) @@ -375,7 +366,7 @@ def transform_params(*inputs): upper = at.as_tensor_variable(upper) if upper is not None else None return lower, upper - interval = tr.interval(transform_params) + interval = tr.Interval(bounds_fn=transform_params) model = self.build_model( pm.Uniform, {"lower": lower, "upper": upper}, size=size, transform=interval ) @@ -396,7 +387,7 @@ def transform_params(*inputs): upper = at.as_tensor_variable(upper) if upper is not None else None return lower, upper - interval = tr.interval(transform_params) + interval = tr.Interval(bounds_fn=transform_params) model = self.build_model( pm.Triangular, {"lower": lower, "c": c, "upper": upper}, size=size, transform=interval ) @@ -491,7 +482,7 @@ def transform_params(*inputs): upper = at.as_tensor_variable(upper) if upper is not None else None return lower, upper - interval = tr.interval(transform_params) + interval = tr.Interval(bounds_fn=transform_params) initval = np.sort(np.abs(np.random.rand(*size))) model = self.build_model( @@ -556,3 +547,13 @@ def test_triangular_transform(): transform = x.tag.value_var.tag.transform assert np.isclose(transform.backward(-np.inf, *x.owner.inputs).eval(), 0) assert np.isclose(transform.backward(np.inf, *x.owner.inputs).eval(), 2) + + +def test_interval_transform_raises(): + with pytest.raises(ValueError, match="Lower and upper interval bounds cannot both be None"): + tr.Interval(None, None) + + with pytest.raises(ValueError, match="Interval bounds must be constant values"): + tr.Interval(at.constant(5) + 1, None) + + assert tr.Interval(at.constant(5), None)