Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derive probability for broadcasting operations #6808

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ jobs:
tests/logprob/test_mixture.py
tests/logprob/test_rewriting.py
tests/logprob/test_scan.py
tests/logprob/test_shape.py
tests/logprob/test_tensor.py
tests/logprob/test_transforms.py
tests/logprob/test_utils.py
Expand Down
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ disallow_untyped_defs = False
disallow_untyped_decorators = False
ignore_missing_imports = True
warn_unused_ignores = False
disable_error_code = annotation-unchecked
1 change: 1 addition & 0 deletions pymc/logprob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import pymc.logprob.checks
import pymc.logprob.mixture
import pymc.logprob.scan
import pymc.logprob.shape
import pymc.logprob.tensor
import pymc.logprob.transforms

Expand Down
121 changes: 121 additions & 0 deletions pymc/logprob/shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import numpy as np
import pytensor.tensor as pt

from pytensor.graph import node_rewriter
from pytensor.tensor.extra_ops import BroadcastTo

from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db


class MeasurableBroadcast(BroadcastTo):
pass


MeasurableVariable.register(MeasurableBroadcast)


measurable_broadcast = MeasurableBroadcast()


@_logprob.register(MeasurableBroadcast)
def broadcast_logprob(op, values, rv, *shape, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking out loud: could this possibly result in inconsistencies elsewhere? For instance, having Mixture components that have been broadcasted which would render them dependent, if that would be an issue

Copy link
Member Author

@ricardoV94 ricardoV94 Jul 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The index mixture only works for basic RVs still so that's fine.

The switch mixture could actually wrongly broadcast the logp. In fact we should also check for invalid switches that mix support dimensions. The current implementation is only correct for ndim_supp==0!

This is another example of why it's so important to have the meta-info for all the MeasurableOps (#6360).

Once we have the meta-info, the Mixture will unambiguously know what kind of measurable variable it is dealing with. In the case of MeasurableBroadcasting, for example, the ndim_supp will have to be at least as large as the number of broadcasted dims (which means we should collapse that logp dimension instead of leaving it as we were doing now!).

We will also know where those support dims are, so that Mixture can know whether we are sub-selecting across core dims.

Without the meta-info, the only way of knowing ndim_supp is by checking the dimensionality of the value vs the logp. We use this logic in some places already:

if input_logprob.ndim < value.ndim:
# For multivariate variables, the Jacobian is diagonal.
# We can get the right result by summing the last dimensions
# of `transform_elemwise.log_jac_det`
ndim_supp = value.ndim - input_logprob.ndim
jacobian = jacobian.sum(axis=tuple(range(-ndim_supp, 0)))

pymc/pymc/logprob/tensor.py

Lines 185 to 189 in f67ff8b

if len({logp.ndim for logp in logps}) != 1:
raise ValueError(
"Joined logps have different number of dimensions, this can happen when "
"joining univariate and multivariate distributions",
)

Which makes me worry whether the probability of a transformed broadcasted variable may be invalid because the "Jacobian" term is going to be counted multiple times?

Copy link
Member Author

@ricardoV94 ricardoV94 Jul 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You raised a very good point, which makes me wonder to what extent #6797 is correct in general?

For instance, if you scale a 3-vector Dirichlet you shouldn't count the Jacobian 3 times, because one of the entries is redundant.

Do we need to propagate information about over-determined elements in multi-dimensional RVs?

Copy link
Member Author

@ricardoV94 ricardoV94 Jul 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first part of this answer suggests you count it 3 times indeed: https://stats.stackexchange.com/a/487538

I'm surprised :D

Edit: As seen below, that answer is wrong

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This I think says something else and correct? https://upcommons.upc.edu/bitstream/handle/2117/366723/p20-CoDaWork2011.pdf?sequence=1&isAllowed=y

I think these should match:

import pymc as pm
import numpy as np

x = 0.75
print(
    pm.logp(pm.Beta.dist(5, 9), x).eval(),
    pm.logp(pm.Dirichlet.dist([5, 9]), [x, 1-x]).eval(),
)  # -3.471576058736023 -3.471576058736023

print(
    pm.logp(2 * pm.Beta.dist(5, 9), 2 * x).eval(),
    pm.logp(2 * pm.Dirichlet.dist([5, 9]), 2*np.array([x, 1-x])).eval(),
)  # -4.164723239295968 -4.857870419855914

print(
    pm.logp(2 * pm.Beta.dist(5, 9), 2 * x).eval(),
    (pm.logp(pm.Dirichlet.dist([5, 9]), ([x, 1-x])) - np.log(2)).eval(),
)  # -4.164723239295968 -4.164723239295968

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once we have the meta-info, the Mixture will unambiguously know what kind of measurable variable it is dealing with. In the case of MeasurableBroadcasting, for example, the ndim_supp will have to be at least as large as the number of broadcasted dims (which means we should collapse that logp dimension instead of leaving it as we were doing now!).

This makes sense! Would you say that it's better to wait for #6360?

The first part of this answer suggests you count it 3 times indeed: https://stats.stackexchange.com/a/487538

I'm surprised :D

I'm not sure if I fully follow 😅 Nonetheless, I'm glad that this question raised some interesting concerns

"""Log-probability expression for (statically-)broadcasted RV

The probability is the same as the base RV, if no broadcasting had happened:

``logp(broadcast_to(normal(size=(3, 1)), (2, 3, 4)), zeros((2, 3, 4))) == logp(normal(size=(3, 1)), zeros((3, 1)))``

And zero if the value couldn't have possibly originated via broadcasting:

``logp(broadcast_to(normal(size=(1,)), (3,)), [1, 2, 3]) == [-np.inf]``

"""
[value] = values

n_new_dims = len(shape) - rv.ndim
assert n_new_dims >= 0

# Enumerate broadcasted dims
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to follow along here, this comment is more for "mental scribbles".

rv = pt.random.normal(size=(3, 1))
x = pt.broadcast_to(rv, (5, 2, 3, 4)) # a bit more than your example above
# rv.broadcastable = (False, False, False, False)

n_new_dims = 2 # 4 - 2
expanded_dims = (0, 1)

value.broadcastable[n_new_dims:] = (False, False) # (3, 4)
rv.broadcastable = (False, True) # (3, 1)

# condition is True only: if (not v_bcast) and rv_bcast = if (not False) and True
# condition is True only if v_bast is False and rv_bcast is True
broadcast_dims = (3,) # (0 + 2, 1 + 2) but conditions are (False, True)?

expanded_dims = tuple(range(n_new_dims))
broadcast_dims = tuple(
i + n_new_dims
for i, (v_bcast, rv_bcast) in enumerate(
zip(value.broadcastable[n_new_dims:], rv.broadcastable)
)
if (not v_bcast) and rv_bcast
)
larryshamalama marked this conversation as resolved.
Show resolved Hide resolved

# "Unbroadcast" value via indexing.
# All entries in the broadcasted dimensions should be the same, so we simply select the first of each.
indices = []
for i in range(value.ndim):
# Remove expanded dims
if i in expanded_dims:
indices.append(0)
# Keep first entry of broadcasted (but not expanded) dims
elif i in broadcast_dims:
indices.append(slice(0, 1))
else:
indices.append(slice(None))

unbroadcast_value = value[tuple(indices)]
logp = _logprob_helper(rv, unbroadcast_value)

# Check that dependent values were indeed identical, by comparing with a re-broadcasted value
valid_value = pt.broadcast_to(unbroadcast_value, shape)
# Note: This could fail due to float-precision issues.
# If that proves to be a problem we should switch to `pt.allclose`
check = pt.all(pt.eq(value, valid_value))
larryshamalama marked this conversation as resolved.
Show resolved Hide resolved
logp = pt.switch(check, logp, -np.inf)

# Reintroduce expanded_dims in the returned logp
if n_new_dims > 0:
logp = pt.shape_padleft(logp, n_new_dims)

return logp


@node_rewriter([BroadcastTo])
def find_measurable_broadcast(fgraph, node):
r"""Finds `BroadcastTo`\s for which a `logprob` can be computed."""

if isinstance(node.op, MeasurableBroadcast):
return None # pragma: no cover

rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

base_rv, *shape = node.inputs

if not rv_map_feature.request_measurable([base_rv]):
larryshamalama marked this conversation as resolved.
Show resolved Hide resolved
return None

new_rv = measurable_broadcast.make_node(base_rv, *shape).default_output()

return [new_rv]


measurable_ir_rewrites_db.register(
"find_measurable_broadcast",
find_measurable_broadcast,
"basic",
"shape",
)
69 changes: 47 additions & 22 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
# SOFTWARE.

import abc
import warnings

from copy import copy
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -111,6 +112,7 @@
cleanup_ir_rewrites_db,
measurable_ir_rewrites_db,
)
from pymc.logprob.shape import measurable_broadcast
from pymc.logprob.utils import check_potential_measurability


Expand Down Expand Up @@ -564,10 +566,13 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li

scalar_op = node.op.scalar_op
measurable_input_idx = 0
measurable_input_broadcast = (
measurable_input.type.broadcastable != node.default_output().type.broadcastable
)
transform_inputs: Tuple[TensorVariable, ...] = (measurable_input,)
transform: RVTransform

transform_dict = {
unary_transforms_dict = {
Exp: ExpTransform(),
Log: LogTransform(),
Abs: AbsTransform(),
Expand All @@ -581,29 +586,49 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
Erfc: ErfcTransform(),
Erfcx: ErfcxTransform(),
}
transform = transform_dict.get(type(scalar_op), None)
if isinstance(scalar_op, Pow):
# We only allow for the base to be measurable
if measurable_input_idx != 0:
return None
try:
(power,) = other_inputs
power = pt.get_underlying_scalar_constant_value(power).item()
# Power needs to be a constant
except NotScalarConstantError:
transform = unary_transforms_dict.get(type(scalar_op), None)
if transform is None:
if isinstance(scalar_op, Pow):
# We only allow for the base to be measurable
if measurable_input_idx != 0:
return None
try:
(power,) = other_inputs
base_power = pt.get_underlying_scalar_constant_value(power).item()
# Power needs to be a constant
except NotScalarConstantError:
return None
transform_inputs = (measurable_input, power)
transform = PowerTransform(power=base_power)
elif isinstance(scalar_op, Add):
transform_inputs = (measurable_input, pt.add(*other_inputs))
transform = LocTransform(
transform_args_fn=lambda *inputs: inputs[-1],
)
elif isinstance(scalar_op, Mul):
transform_inputs = (measurable_input, pt.mul(*other_inputs))
transform = ScaleTransform(
transform_args_fn=lambda *inputs: inputs[-1],
)
else:
raise TypeError(
f"Scalar Op not supported: {scalar_op}. Rewrite should not have been triggered"
) # pragma: no cover

if measurable_input_broadcast:
# This rewrite logic only supports broadcasting for transforms with two inputs, where the first is measurable.
# This covers all current cases, update if other cases are supported in the future.
if len(transform_inputs) != 2 or measurable_input_idx != 0:
return None
transform_inputs = (measurable_input, power)
transform = PowerTransform(power=power)
elif isinstance(scalar_op, Add):
transform_inputs = (measurable_input, pt.add(*other_inputs))
transform = LocTransform(
transform_args_fn=lambda *inputs: inputs[-1],
)
elif transform is None:
transform_inputs = (measurable_input, pt.mul(*other_inputs))
transform = ScaleTransform(
transform_args_fn=lambda *inputs: inputs[-1],
warnings.warn(
"MeasurableTransform with implicit broadcasting detected. This corresponds to a potentially degenerate probability graph.\n"
"If you did not intend this, make sure the base measurable variable is created with all the dimensions from the start."
"Otherwise, an explicit `broadcast_to` operation can be used to silence this warning.\n",
UserWarning,
)
measurable_input, other_input = transform_inputs
measurable_input = measurable_broadcast(measurable_input, other_input.shape)
transform_inputs = (measurable_input, other_input)

transform_op = MeasurableTransform(
scalar_op=scalar_op,
Expand Down
50 changes: 50 additions & 0 deletions tests/logprob/test_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytensor
import pytensor.tensor as pt
import scipy.stats as st

from pymc import logp


def test_measurable_broadcast():
b_shape = pt.vector("b_shape", shape=(3,), dtype=int)

x = pt.random.normal(size=(3, 1))
bcast_x = pt.broadcast_to(x, shape=b_shape)
bcast_x.name = "bcast_x"

bcast_x_value = bcast_x.clone()
logp_bcast_x = logp(bcast_x, bcast_x_value)
logp_fn = pytensor.function([b_shape, bcast_x_value], logp_bcast_x, on_unused_input="ignore")

# assert_allclose also asserts shapes match (if neither is scalar)
np.testing.assert_allclose(
logp_fn([1, 3, 1], np.zeros((1, 3, 1))),
st.norm.logpdf(np.zeros((1, 3, 1))),
)
np.testing.assert_allclose(
logp_fn([1, 3, 5], np.zeros((1, 3, 5))),
st.norm.logpdf(np.zeros((1, 3, 1))),
)
np.testing.assert_allclose(
logp_fn([2, 3, 5], np.broadcast_to(np.arange(3).reshape(1, 3, 1), (2, 3, 5))),
st.norm.logpdf(np.arange(3).reshape(1, 3, 1)),
)
# Invalid broadcast value
np.testing.assert_array_equal(
logp_fn([1, 3, 5], np.arange(3 * 5).reshape(1, 3, 5)),
np.full(shape=(1, 3, 1), fill_value=-np.inf),
)
32 changes: 26 additions & 6 deletions tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,16 +807,36 @@ def test_discrete_rv_multinary_transform_fails():
conditional_logp({y_rv: y_rv.clone()})


@pytest.mark.xfail(reason="Check not implemented yet")
def test_invalid_broadcasted_transform_rv_fails():
@pytest.mark.filterwarnings("error") # Fail if unexpected warning is issued
@pytest.mark.parametrize("implicit_broadcast", (True, False))
def test_broadcasted_transform_rv(implicit_broadcast):
larryshamalama marked this conversation as resolved.
Show resolved Hide resolved
loc = pt.vector("loc")
y_rv = loc + pt.random.normal(0, 1, size=1, name="base_rv")
base_rv = pt.random.normal(0, 1, size=1, name="base_rv")
if implicit_broadcast:
y_rv = loc + base_rv
else:
y_rv = loc + pt.broadcast_to(base_rv, shape=loc.shape)
y_rv.name = "y"
y_vv = y_rv.clone()

# This logp derivation should fail or count only once the values that are broadcasted
logprob = logp(y_rv, y_vv)
assert logprob.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]}).shape == ()
if implicit_broadcast:
with pytest.warns(UserWarning, match="implicit broadcasting detected"):
logprob = logp(y_rv, y_vv)
else:
logprob = logp(y_rv, y_vv)
logprob_fn = pytensor.function([loc, y_vv], logprob)

# All values must have the same offset from `loc`
np.testing.assert_allclose(
logprob_fn([1, 1, 1, 1], [0, 0, 0, 0]), sp.stats.norm.logpdf([0], loc=1)
)
np.testing.assert_allclose(
logprob_fn([1, 2, 3, 4], [0, 1, 2, 3]), sp.stats.norm.logpdf([0], loc=1)
)

# Otherwise probability is 0
np.testing.assert_array_equal(logprob_fn([1, 1, 1, 1], [0, 0, 0, 1]), [-np.inf])
np.testing.assert_array_equal(logprob_fn([1, 2, 3, 4], [0, 0, 0, 0]), [-np.inf])


@pytest.mark.parametrize("numerator", (1.0, 2.0))
Expand Down