From 9f8ea526526e88ae28c3837c0157a35d8e87a9ea Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 30 Jun 2023 11:20:40 +0200 Subject: [PATCH 1/2] Derive probability for broadcast operation --- .github/workflows/tests.yml | 1 + mypy.ini | 1 + pymc/logprob/__init__.py | 1 + pymc/logprob/shape.py | 121 ++++++++++++++++++++++++++++++++++++ tests/logprob/test_shape.py | 50 +++++++++++++++ 5 files changed, 174 insertions(+) create mode 100644 pymc/logprob/shape.py create mode 100644 tests/logprob/test_shape.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d72431194e9..e5e2b2fac93 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -112,6 +112,7 @@ jobs: tests/logprob/test_mixture.py tests/logprob/test_rewriting.py tests/logprob/test_scan.py + tests/logprob/test_shape.py tests/logprob/test_tensor.py tests/logprob/test_transforms.py tests/logprob/test_utils.py diff --git a/mypy.ini b/mypy.ini index 360e64c2e51..e420cc7697f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -10,3 +10,4 @@ disallow_untyped_defs = False disallow_untyped_decorators = False ignore_missing_imports = True warn_unused_ignores = False +disable_error_code = annotation-unchecked diff --git a/pymc/logprob/__init__.py b/pymc/logprob/__init__.py index 0ddea90b6fb..ad001546c5c 100644 --- a/pymc/logprob/__init__.py +++ b/pymc/logprob/__init__.py @@ -50,6 +50,7 @@ import pymc.logprob.checks import pymc.logprob.mixture import pymc.logprob.scan +import pymc.logprob.shape import pymc.logprob.tensor import pymc.logprob.transforms diff --git a/pymc/logprob/shape.py b/pymc/logprob/shape.py new file mode 100644 index 00000000000..ed20cab9ae0 --- /dev/null +++ b/pymc/logprob/shape.py @@ -0,0 +1,121 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import numpy as np +import pytensor.tensor as pt + +from pytensor.graph import node_rewriter +from pytensor.tensor.extra_ops import BroadcastTo + +from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper +from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db + + +class MeasurableBroadcast(BroadcastTo): + pass + + +MeasurableVariable.register(MeasurableBroadcast) + + +measurable_broadcast = MeasurableBroadcast() + + +@_logprob.register(MeasurableBroadcast) +def broadcast_logprob(op, values, rv, *shape, **kwargs): + """Log-probability expression for (statically-)broadcasted RV + + The probability is the same as the base RV, if no broadcasting had happened: + + ``logp(broadcast_to(normal(size=(3, 1)), (2, 3, 4)), zeros((2, 3, 4))) == logp(normal(size=(3, 1)), zeros((3, 1)))`` + + And zero if the value couldn't have possibly originated via broadcasting: + + ``logp(broadcast_to(normal(size=(1,)), (3,)), [1, 2, 3]) == [-np.inf]`` + + """ + [value] = values + + n_new_dims = len(shape) - rv.ndim + assert n_new_dims >= 0 + + # Enumerate broadcasted dims + expanded_dims = tuple(range(n_new_dims)) + broadcast_dims = tuple( + i + n_new_dims + for i, (v_bcast, rv_bcast) in enumerate( + zip(value.broadcastable[n_new_dims:], rv.broadcastable) + ) + if (not v_bcast) and rv_bcast + ) + + # "Unbroadcast" value via indexing. + # All entries in the broadcasted dimensions should be the same, so we simply select the first of each. + indices = [] + for i in range(value.ndim): + # Remove expanded dims + if i in expanded_dims: + indices.append(0) + # Keep first entry of broadcasted (but not expanded) dims + elif i in broadcast_dims: + indices.append(slice(0, 1)) + else: + indices.append(slice(None)) + + unbroadcast_value = value[tuple(indices)] + logp = _logprob_helper(rv, unbroadcast_value) + + # Check that dependent values were indeed identical, by comparing with a re-broadcasted value + valid_value = pt.broadcast_to(unbroadcast_value, shape) + # Note: This could fail due to float-precision issues. + # If that proves to be a problem we should switch to `pt.allclose` + check = pt.all(pt.eq(value, valid_value)) + logp = pt.switch(check, logp, -np.inf) + + # Reintroduce expanded_dims in the returned logp + if n_new_dims > 0: + logp = pt.shape_padleft(logp, n_new_dims) + + return logp + + +@node_rewriter([BroadcastTo]) +def find_measurable_broadcast(fgraph, node): + r"""Finds `BroadcastTo`\s for which a `logprob` can be computed.""" + + if isinstance(node.op, MeasurableBroadcast): + return None # pragma: no cover + + rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) + + if rv_map_feature is None: + return None # pragma: no cover + + base_rv, *shape = node.inputs + + if not rv_map_feature.request_measurable([base_rv]): + return None + + new_rv = measurable_broadcast.make_node(base_rv, *shape).default_output() + + return [new_rv] + + +measurable_ir_rewrites_db.register( + "find_measurable_broadcast", + find_measurable_broadcast, + "basic", + "shape", +) diff --git a/tests/logprob/test_shape.py b/tests/logprob/test_shape.py new file mode 100644 index 00000000000..6b9dbed7666 --- /dev/null +++ b/tests/logprob/test_shape.py @@ -0,0 +1,50 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytensor +import pytensor.tensor as pt +import scipy.stats as st + +from pymc import logp + + +def test_measurable_broadcast(): + b_shape = pt.vector("b_shape", shape=(3,), dtype=int) + + x = pt.random.normal(size=(3, 1)) + bcast_x = pt.broadcast_to(x, shape=b_shape) + bcast_x.name = "bcast_x" + + bcast_x_value = bcast_x.clone() + logp_bcast_x = logp(bcast_x, bcast_x_value) + logp_fn = pytensor.function([b_shape, bcast_x_value], logp_bcast_x, on_unused_input="ignore") + + # assert_allclose also asserts shapes match (if neither is scalar) + np.testing.assert_allclose( + logp_fn([1, 3, 1], np.zeros((1, 3, 1))), + st.norm.logpdf(np.zeros((1, 3, 1))), + ) + np.testing.assert_allclose( + logp_fn([1, 3, 5], np.zeros((1, 3, 5))), + st.norm.logpdf(np.zeros((1, 3, 1))), + ) + np.testing.assert_allclose( + logp_fn([2, 3, 5], np.broadcast_to(np.arange(3).reshape(1, 3, 1), (2, 3, 5))), + st.norm.logpdf(np.arange(3).reshape(1, 3, 1)), + ) + # Invalid broadcast value + np.testing.assert_array_equal( + logp_fn([1, 3, 5], np.arange(3 * 5).reshape(1, 3, 5)), + np.full(shape=(1, 3, 1), fill_value=-np.inf), + ) From 042c9f3c0c1a107f4b5e86ca3f5244c65d9d3a34 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 30 Jun 2023 11:27:16 +0200 Subject: [PATCH 2/2] Derive probability for transforms with implicit broadcasting A warning is issued as this graph is unlikely to be desired for most users. --- pymc/logprob/transforms.py | 69 ++++++++++++++++++++++---------- tests/logprob/test_transforms.py | 32 ++++++++++++--- 2 files changed, 73 insertions(+), 28 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 5f751b5bf1f..2621ad8d9b5 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -35,6 +35,7 @@ # SOFTWARE. import abc +import warnings from copy import copy from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -111,6 +112,7 @@ cleanup_ir_rewrites_db, measurable_ir_rewrites_db, ) +from pymc.logprob.shape import measurable_broadcast from pymc.logprob.utils import check_potential_measurability @@ -564,10 +566,13 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li scalar_op = node.op.scalar_op measurable_input_idx = 0 + measurable_input_broadcast = ( + measurable_input.type.broadcastable != node.default_output().type.broadcastable + ) transform_inputs: Tuple[TensorVariable, ...] = (measurable_input,) transform: RVTransform - transform_dict = { + unary_transforms_dict = { Exp: ExpTransform(), Log: LogTransform(), Abs: AbsTransform(), @@ -581,29 +586,49 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li Erfc: ErfcTransform(), Erfcx: ErfcxTransform(), } - transform = transform_dict.get(type(scalar_op), None) - if isinstance(scalar_op, Pow): - # We only allow for the base to be measurable - if measurable_input_idx != 0: - return None - try: - (power,) = other_inputs - power = pt.get_underlying_scalar_constant_value(power).item() - # Power needs to be a constant - except NotScalarConstantError: + transform = unary_transforms_dict.get(type(scalar_op), None) + if transform is None: + if isinstance(scalar_op, Pow): + # We only allow for the base to be measurable + if measurable_input_idx != 0: + return None + try: + (power,) = other_inputs + base_power = pt.get_underlying_scalar_constant_value(power).item() + # Power needs to be a constant + except NotScalarConstantError: + return None + transform_inputs = (measurable_input, power) + transform = PowerTransform(power=base_power) + elif isinstance(scalar_op, Add): + transform_inputs = (measurable_input, pt.add(*other_inputs)) + transform = LocTransform( + transform_args_fn=lambda *inputs: inputs[-1], + ) + elif isinstance(scalar_op, Mul): + transform_inputs = (measurable_input, pt.mul(*other_inputs)) + transform = ScaleTransform( + transform_args_fn=lambda *inputs: inputs[-1], + ) + else: + raise TypeError( + f"Scalar Op not supported: {scalar_op}. Rewrite should not have been triggered" + ) # pragma: no cover + + if measurable_input_broadcast: + # This rewrite logic only supports broadcasting for transforms with two inputs, where the first is measurable. + # This covers all current cases, update if other cases are supported in the future. + if len(transform_inputs) != 2 or measurable_input_idx != 0: return None - transform_inputs = (measurable_input, power) - transform = PowerTransform(power=power) - elif isinstance(scalar_op, Add): - transform_inputs = (measurable_input, pt.add(*other_inputs)) - transform = LocTransform( - transform_args_fn=lambda *inputs: inputs[-1], - ) - elif transform is None: - transform_inputs = (measurable_input, pt.mul(*other_inputs)) - transform = ScaleTransform( - transform_args_fn=lambda *inputs: inputs[-1], + warnings.warn( + "MeasurableTransform with implicit broadcasting detected. This corresponds to a potentially degenerate probability graph.\n" + "If you did not intend this, make sure the base measurable variable is created with all the dimensions from the start." + "Otherwise, an explicit `broadcast_to` operation can be used to silence this warning.\n", + UserWarning, ) + measurable_input, other_input = transform_inputs + measurable_input = measurable_broadcast(measurable_input, other_input.shape) + transform_inputs = (measurable_input, other_input) transform_op = MeasurableTransform( scalar_op=scalar_op, diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 9960dff9487..91644fc0e89 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -807,16 +807,36 @@ def test_discrete_rv_multinary_transform_fails(): conditional_logp({y_rv: y_rv.clone()}) -@pytest.mark.xfail(reason="Check not implemented yet") -def test_invalid_broadcasted_transform_rv_fails(): +@pytest.mark.filterwarnings("error") # Fail if unexpected warning is issued +@pytest.mark.parametrize("implicit_broadcast", (True, False)) +def test_broadcasted_transform_rv(implicit_broadcast): loc = pt.vector("loc") - y_rv = loc + pt.random.normal(0, 1, size=1, name="base_rv") + base_rv = pt.random.normal(0, 1, size=1, name="base_rv") + if implicit_broadcast: + y_rv = loc + base_rv + else: + y_rv = loc + pt.broadcast_to(base_rv, shape=loc.shape) y_rv.name = "y" y_vv = y_rv.clone() - # This logp derivation should fail or count only once the values that are broadcasted - logprob = logp(y_rv, y_vv) - assert logprob.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]}).shape == () + if implicit_broadcast: + with pytest.warns(UserWarning, match="implicit broadcasting detected"): + logprob = logp(y_rv, y_vv) + else: + logprob = logp(y_rv, y_vv) + logprob_fn = pytensor.function([loc, y_vv], logprob) + + # All values must have the same offset from `loc` + np.testing.assert_allclose( + logprob_fn([1, 1, 1, 1], [0, 0, 0, 0]), sp.stats.norm.logpdf([0], loc=1) + ) + np.testing.assert_allclose( + logprob_fn([1, 2, 3, 4], [0, 1, 2, 3]), sp.stats.norm.logpdf([0], loc=1) + ) + + # Otherwise probability is 0 + np.testing.assert_array_equal(logprob_fn([1, 1, 1, 1], [0, 0, 0, 1]), [-np.inf]) + np.testing.assert_array_equal(logprob_fn([1, 2, 3, 4], [0, 0, 0, 0]), [-np.inf]) @pytest.mark.parametrize("numerator", (1.0, 2.0))