Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add helper wrapper aound Interval transform #5347

Merged
merged 3 commits into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ All of the above apply to:
- `pm.sample_prior_predictive`, `pm.sample_posterior_predictive` and `pm.sample_posterior_predictive_w` now return an `InferenceData` object by default, instead of a dictionary (see [#5073](https://github.com/pymc-devs/pymc/pull/5073)).
- `pm.sample_prior_predictive` no longer returns transformed variable values by default. Pass them by name in `var_names` if you want to obtain these draws (see [4769](https://github.com/pymc-devs/pymc/pull/4769)).
- `pm.sample(trace=...)` no longer accepts `MultiTrace` or `len(.) > 0` traces ([see 5019#](https://github.com/pymc-devs/pymc/pull/5019)).
- `transforms` module is no longer accessible ta the root level. It is accessible at `pymc.distributions.transforms` (see[#5347](https://github.com/pymc-devs/pymc/pull/5347)).
- `logpt`, `logpt_sum`, `logp_elemwiset` and `nojac` variations were removed. Use `Model.logpt(jacobian=True/False, sum=True/False)` instead.
- `dlogp_nojact` and `d2logp_nojact` were removed. Use `Model.dlogpt` and `d2logpt` with `jacobian=False` instead.
- `logp`, `dlogp`, and `d2logp` and `nojac` variations were removed. Use `Model.compile_logp`, `compile_dlgop` and `compile_d2logp` with `jacobian` keyword instead.
Expand Down
22 changes: 12 additions & 10 deletions docs/source/api/distributions/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Transformations
***************

.. currentmodule:: pymc.transforms
.. currentmodule:: pymc.distributions.transforms

Transform Instances
~~~~~~~~~~~~~~~~~~~
Expand All @@ -15,28 +15,30 @@ Transform instances are the entities that should be used in the

simplex
logodds
interval
log_exp_m1
ordered
log
sum_to_1
circular

Transform Composition Classes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated

Chain
CholeskyCovPacked

Specific Transform Classes
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated

CholeskyCovPacked
Interval
LogExpM1
Ordered
SumTo1


Transform Composition Classes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated

Chain
1 change: 0 additions & 1 deletion pymc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def __set_compiler_flags():
from pymc.blocking import *
from pymc.data import *
from pymc.distributions import *
from pymc.distributions import transforms
from pymc.exceptions import *
from pymc.func_utils import find_constrained_prior
from pymc.math import (
Expand Down
4 changes: 2 additions & 2 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def transform_params(*args):

return lower, upper

return transforms.interval(transform_params)
return transforms.Interval(bounds_fn=transform_params)


def assert_negative_support(var, label, distname, value=-1e-6):
Expand Down Expand Up @@ -3796,7 +3796,7 @@ def transform_params(*params):
_, _, _, x_points, _, _ = params
return floatX(x_points[0]), floatX(x_points[-1])

kwargs["transform"] = transforms.interval(transform_params)
kwargs["transform"] = transforms.Interval(bounds_fn=transform_params)
return super().__new__(cls, *args, **kwargs)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
rv_size_is_none,
to_tuple,
)
from pymc.distributions.transforms import interval
from pymc.distributions.transforms import Interval
from pymc.math import kron_diag, kron_dot
from pymc.util import UNSET, check_dist_not_registered

Expand Down Expand Up @@ -1554,7 +1554,7 @@ class LKJCorr(BoundedContinuous):
def __new__(cls, *args, **kwargs):
transform = kwargs.get("transform", UNSET)
if transform is UNSET:
kwargs["transform"] = interval(lambda *args: (floatX(-1.0), floatX(1.0)))
kwargs["transform"] = Interval(floatX(-1.0), floatX(1.0))
return super().__new__(cls, *args, **kwargs)

@classmethod
Expand Down
97 changes: 89 additions & 8 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import aesara.tensor as at
import numpy as np

from aeppl.transforms import (
CircularTransform,
Expand All @@ -27,7 +28,7 @@
"RVTransform",
"simplex",
"logodds",
"interval",
"Interval",
"log_exp_m1",
"ordered",
"log",
Expand Down Expand Up @@ -165,25 +166,105 @@ def log_jac_det(self, value, *inputs):


simplex = Simplex()
simplex.__doc__ = """
Instantiation of :class:`aeppl.transforms.Simplex`
for use in the ``transform`` argument of a random variable."""

logodds = LogOddsTransform()
logodds.__doc__ = """
Instantiation of :class:`aeppl.transforms.LogOddsTransform`
for use in the ``transform`` argument of a random variable."""

interval = IntervalTransform
interval.__doc__ = """
Instantiation of :class:`aeppl.transforms.IntervalTransform`
for use in the ``transform`` argument of a random variable."""

class Interval(IntervalTransform):
"""Wrapper around :class:`aeppl.transforms.IntervalTransform` for use in the
``transform`` argument of a random variable.

Parameters
----------
lower : int, float, or None
Lower bound of the interval transform. Must be a constant value. If ``None``, the
interval is not bounded below.
upper : int, float or None
Upper bound of the interval transfrom. Must be a finite value. If ``None``, the
interval is not bounded above.
bounds_fn : callable
Alternative to lower and upper. Must return a tuple of lower and upper bounds
as a symbolic function of the respective distribution inputs. If lower or
upper is ``None``, the interval is unbounded on that edge.

.. warning:: Expressions returned by `bounds_fn` should depend only on the
distribution inputs or other constants. Expressions that depend on other
symbolic variables, including nonlocal variables defined in the model
context will likely break sampling.


Examples
--------
.. code-block:: python

# Create an interval transform between -1 and +1
with pm.Model():
interval = pm.distributions.transforms.Interval(lower=-1, upper=1)
x = pm.Normal("x", transform=interval)

.. code-block:: python

# Create an interval transform between -1 and +1 using a callable
def get_bounds(rng, size, dtype, loc, scale):
return 0, None

with pm.Model():
interval = pm.distributions.transforms.Interval(bouns_fn=get_bounds)
x = pm.Normal("x", transform=interval)

.. code-block:: python

# Create a lower bounded interval transform based on a distribution parameter
def get_bounds(rng, size, dtype, loc, scale):
return loc, None

interval = pm.distributions.transforms.Interval(bounds_fn=get_bounds)

with pm.Model():
loc = pm.Normal("loc")
x = pm.Normal("x", mu=loc, sigma=2, transform=interval)
"""

def __init__(self, lower=None, upper=None, *, bounds_fn=None):
if bounds_fn is None:
try:
bounds = tuple(
None if bound is None else at.constant(bound, ndim=0).data
for bound in (lower, upper)
)
except (ValueError, TypeError):
raise ValueError(
"Interval bounds must be constant values. If you need expressions that "
"depend on symbolic variables use `args_fn`"
)

lower, upper = (
None if (bound is None or np.isinf(bound)) else bound for bound in bounds
)

if lower is None and upper is None:
raise ValueError("Lower and upper interval bounds cannot both be None")

def bounds_fn(*rv_inputs):
return lower, upper

super().__init__(args_fn=bounds_fn)


log_exp_m1 = LogExpM1()
log_exp_m1.__doc__ = """
Instantiation of :class:`pymc.transforms.LogExpM1`
Instantiation of :class:`pymc.distributions.transforms.LogExpM1`
for use in the ``transform`` argument of a random variable."""

ordered = Ordered()
ordered.__doc__ = """
Instantiation of :class:`pymc.transforms.Ordered`
Instantiation of :class:`pymc.distributions.transforms.Ordered`
for use in the ``transform`` argument of a random variable."""

log = LogTransform()
Expand All @@ -193,7 +274,7 @@ def log_jac_det(self, value, *inputs):

sum_to_1 = SumTo1()
sum_to_1.__doc__ = """
Instantiation of :class:`pymc.transforms.SumTo1`
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
for use in the ``transform`` argument of a random variable."""

circular = CircularTransform()
Expand Down
2 changes: 1 addition & 1 deletion pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2710,7 +2710,7 @@ def test_arguments_checks(self):
with pm.Model() as m:
x = pm.Poisson.dist(0.5)
with pytest.raises(ValueError, match=msg):
pm.Bound("bound", x, transform=pm.transforms.interval)
pm.Bound("bound", x, transform=pm.distributions.transforms.log)

msg = "Given dims do not exist in model coordinates."
with pm.Model() as m:
Expand Down
6 changes: 4 additions & 2 deletions pymc/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,13 @@ def test_deterministic_of_unobserved(self):

np.testing.assert_allclose(idata.posterior["y"], idata.posterior["x"] + 100)

def test_transform_with_rv_depenency(self):
def test_transform_with_rv_dependency(self):
# Test that untransformed variables that depend on upstream variables are properly handled
with pm.Model() as m:
x = pm.HalfNormal("x", observed=1)
transform = pm.transforms.IntervalTransform(lambda *inputs: (inputs[-2], inputs[-1]))
transform = pm.distributions.transforms.Interval(
bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
)
y = pm.Uniform("y", lower=0, upper=x, transform=transform)
trace = pm.sample(tune=10, draws=50, return_inferencedata=False, random_seed=336)

Expand Down
31 changes: 16 additions & 15 deletions pymc/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,7 @@ def test_logodds():


def test_lowerbound():
def transform_params(*inputs):
return 0.0, None

trans = tr.interval(transform_params)
trans = tr.Interval(0.0, None)
check_transform(trans, Rplusbig)

check_jacobian_det(trans, Rplusbig, elemwise=True)
Expand All @@ -191,10 +188,7 @@ def transform_params(*inputs):


def test_upperbound():
def transform_params(*inputs):
return None, 0.0

trans = tr.interval(transform_params)
trans = tr.Interval(None, 0.0)
check_transform(trans, Rminusbig)

check_jacobian_det(trans, Rminusbig, elemwise=True)
Expand All @@ -208,10 +202,7 @@ def test_interval():
for a, b in [(-4, 5.5), (0.1, 0.7), (-10, 4.3)]:
domain = Unit * np.float64(b - a) + np.float64(a)

def transform_params(z=a, y=b):
return z, y

trans = tr.interval(transform_params)
trans = tr.Interval(a, b)
check_transform(trans, domain)

check_jacobian_det(trans, domain, elemwise=True)
Expand Down Expand Up @@ -375,7 +366,7 @@ def transform_params(*inputs):
upper = at.as_tensor_variable(upper) if upper is not None else None
return lower, upper

interval = tr.interval(transform_params)
interval = tr.Interval(bounds_fn=transform_params)
model = self.build_model(
pm.Uniform, {"lower": lower, "upper": upper}, size=size, transform=interval
)
Expand All @@ -396,7 +387,7 @@ def transform_params(*inputs):
upper = at.as_tensor_variable(upper) if upper is not None else None
return lower, upper

interval = tr.interval(transform_params)
interval = tr.Interval(bounds_fn=transform_params)
model = self.build_model(
pm.Triangular, {"lower": lower, "c": c, "upper": upper}, size=size, transform=interval
)
Expand Down Expand Up @@ -491,7 +482,7 @@ def transform_params(*inputs):
upper = at.as_tensor_variable(upper) if upper is not None else None
return lower, upper

interval = tr.interval(transform_params)
interval = tr.Interval(bounds_fn=transform_params)

initval = np.sort(np.abs(np.random.rand(*size)))
model = self.build_model(
Expand Down Expand Up @@ -556,3 +547,13 @@ def test_triangular_transform():
transform = x.tag.value_var.tag.transform
assert np.isclose(transform.backward(-np.inf, *x.owner.inputs).eval(), 0)
assert np.isclose(transform.backward(np.inf, *x.owner.inputs).eval(), 2)


def test_interval_transform_raises():
with pytest.raises(ValueError, match="Lower and upper interval bounds cannot both be None"):
tr.Interval(None, None)

with pytest.raises(ValueError, match="Interval bounds must be constant values"):
tr.Interval(at.constant(5) + 1, None)

assert tr.Interval(at.constant(5), None)