Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derive probability for broadcasting operations #6808

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Jun 30, 2023

Related to #6398

TODO:

  • Cover Second/Alloc which are other froms of broadcasting

📚 Documentation preview 📚: https://pymc--6808.org.readthedocs.build/en/6808/

CC @shreyas3156

@codecov
Copy link

codecov bot commented Jun 30, 2023

Codecov Report

Merging #6808 (042c9f3) into main (413af04) will increase coverage by 0.10%.
The diff coverage is 94.52%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6808      +/-   ##
==========================================
+ Coverage   91.93%   92.03%   +0.10%     
==========================================
  Files          95       96       +1     
  Lines       16226    16317      +91     
==========================================
+ Hits        14917    15018     +101     
+ Misses       1309     1299      -10     
Impacted Files Coverage Δ
pymc/logprob/transforms.py 94.60% <89.28%> (-0.08%) ⬇️
pymc/logprob/shape.py 97.72% <97.72%> (ø)
pymc/logprob/__init__.py 100.00% <100.00%> (ø)

... and 5 files with indirect coverage changes

@ricardoV94 ricardoV94 force-pushed the measurable_broadcast branch from 3fbb4ae to 4ac4189 Compare June 30, 2023 09:44
@ricardoV94 ricardoV94 force-pushed the measurable_broadcast branch from 4ac4189 to 2d84ff1 Compare June 30, 2023 10:09
A warning is issued as this graph is unlikely to be desired for most users.
@ricardoV94 ricardoV94 force-pushed the measurable_broadcast branch from 2d84ff1 to 042c9f3 Compare June 30, 2023 17:36
@larryshamalama
Copy link
Member

I will get to this tomorrow morning, sorry about the delay!

Copy link
Member

@larryshamalama larryshamalama left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work @ricardoV94! :) A lot of nice abstractions, which, together, are why I have many questions



@_logprob.register(MeasurableBroadcast)
def broadcast_logprob(op, values, rv, *shape, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking out loud: could this possibly result in inconsistencies elsewhere? For instance, having Mixture components that have been broadcasted which would render them dependent, if that would be an issue

Copy link
Member Author

@ricardoV94 ricardoV94 Jul 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The index mixture only works for basic RVs still so that's fine.

The switch mixture could actually wrongly broadcast the logp. In fact we should also check for invalid switches that mix support dimensions. The current implementation is only correct for ndim_supp==0!

This is another example of why it's so important to have the meta-info for all the MeasurableOps (#6360).

Once we have the meta-info, the Mixture will unambiguously know what kind of measurable variable it is dealing with. In the case of MeasurableBroadcasting, for example, the ndim_supp will have to be at least as large as the number of broadcasted dims (which means we should collapse that logp dimension instead of leaving it as we were doing now!).

We will also know where those support dims are, so that Mixture can know whether we are sub-selecting across core dims.

Without the meta-info, the only way of knowing ndim_supp is by checking the dimensionality of the value vs the logp. We use this logic in some places already:

if input_logprob.ndim < value.ndim:
# For multivariate variables, the Jacobian is diagonal.
# We can get the right result by summing the last dimensions
# of `transform_elemwise.log_jac_det`
ndim_supp = value.ndim - input_logprob.ndim
jacobian = jacobian.sum(axis=tuple(range(-ndim_supp, 0)))

pymc/pymc/logprob/tensor.py

Lines 185 to 189 in f67ff8b

if len({logp.ndim for logp in logps}) != 1:
raise ValueError(
"Joined logps have different number of dimensions, this can happen when "
"joining univariate and multivariate distributions",
)

Which makes me worry whether the probability of a transformed broadcasted variable may be invalid because the "Jacobian" term is going to be counted multiple times?

Copy link
Member Author

@ricardoV94 ricardoV94 Jul 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You raised a very good point, which makes me wonder to what extent #6797 is correct in general?

For instance, if you scale a 3-vector Dirichlet you shouldn't count the Jacobian 3 times, because one of the entries is redundant.

Do we need to propagate information about over-determined elements in multi-dimensional RVs?

Copy link
Member Author

@ricardoV94 ricardoV94 Jul 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first part of this answer suggests you count it 3 times indeed: https://stats.stackexchange.com/a/487538

I'm surprised :D

Edit: As seen below, that answer is wrong

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This I think says something else and correct? https://upcommons.upc.edu/bitstream/handle/2117/366723/p20-CoDaWork2011.pdf?sequence=1&isAllowed=y

I think these should match:

import pymc as pm
import numpy as np

x = 0.75
print(
    pm.logp(pm.Beta.dist(5, 9), x).eval(),
    pm.logp(pm.Dirichlet.dist([5, 9]), [x, 1-x]).eval(),
)  # -3.471576058736023 -3.471576058736023

print(
    pm.logp(2 * pm.Beta.dist(5, 9), 2 * x).eval(),
    pm.logp(2 * pm.Dirichlet.dist([5, 9]), 2*np.array([x, 1-x])).eval(),
)  # -4.164723239295968 -4.857870419855914

print(
    pm.logp(2 * pm.Beta.dist(5, 9), 2 * x).eval(),
    (pm.logp(pm.Dirichlet.dist([5, 9]), ([x, 1-x])) - np.log(2)).eval(),
)  # -4.164723239295968 -4.164723239295968

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once we have the meta-info, the Mixture will unambiguously know what kind of measurable variable it is dealing with. In the case of MeasurableBroadcasting, for example, the ndim_supp will have to be at least as large as the number of broadcasted dims (which means we should collapse that logp dimension instead of leaving it as we were doing now!).

This makes sense! Would you say that it's better to wait for #6360?

The first part of this answer suggests you count it 3 times indeed: https://stats.stackexchange.com/a/487538

I'm surprised :D

I'm not sure if I fully follow 😅 Nonetheless, I'm glad that this question raised some interesting concerns

n_new_dims = len(shape) - rv.ndim
assert n_new_dims >= 0

# Enumerate broadcasted dims
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to follow along here, this comment is more for "mental scribbles".

rv = pt.random.normal(size=(3, 1))
x = pt.broadcast_to(rv, (5, 2, 3, 4)) # a bit more than your example above
# rv.broadcastable = (False, False, False, False)

n_new_dims = 2 # 4 - 2
expanded_dims = (0, 1)

value.broadcastable[n_new_dims:] = (False, False) # (3, 4)
rv.broadcastable = (False, True) # (3, 1)

# condition is True only: if (not v_bcast) and rv_bcast = if (not False) and True
# condition is True only if v_bast is False and rv_bcast is True
broadcast_dims = (3,) # (0 + 2, 1 + 2) but conditions are (False, True)?

pymc/logprob/shape.py Show resolved Hide resolved
pymc/logprob/shape.py Show resolved Hide resolved
pymc/logprob/shape.py Show resolved Hide resolved
tests/logprob/test_transforms.py Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants