Derive probability for broadcasting operations #6808
Conversation
Codecov Report

@@ Coverage Diff @@
##             main    #6808      +/-   ##
==========================================
+ Coverage   91.93%   92.03%   +0.10%
==========================================
  Files          95       96       +1
  Lines       16226    16317      +91
==========================================
+ Hits        14917    15018     +101
+ Misses       1309     1299      -10
A warning is issued, as this graph is unlikely to be desired by most users.
I will get to this tomorrow morning, sorry about the delay!
Great work @ricardoV94! :) A lot of nice abstractions which, taken together, are why I have many questions.
@_logprob.register(MeasurableBroadcast)
def broadcast_logprob(op, values, rv, *shape, **kwargs):
Thinking out loud: could this possibly result in inconsistencies elsewhere? For instance, having Mixture components that have been broadcasted, which would render them dependent, if that would be an issue.
The index mixture only works for basic RVs still, so that's fine.
The switch mixture could actually wrongly broadcast the logp. In fact, we should also check for invalid switches that mix support dimensions. The current implementation is only correct for ndim_supp == 0.
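To make the overcounting concern concrete, here is a minimal numpy sketch (my own toy example, not PyMC code): broadcasting repeats entries, so naively summing an elementwise logp over the broadcast shape multiplies the true logp by the broadcast factor.

```python
import numpy as np

def normal_logpdf(x):
    # Elementwise standard-normal log-density (toy stand-in for a logp graph)
    return -0.5 * x**2 - 0.5 * np.log(2 * np.pi)

rng = np.random.default_rng(0)
rv_draw = rng.normal(size=(3, 1))
bcast_draw = np.broadcast_to(rv_draw, (3, 4))

lp_rv = normal_logpdf(rv_draw).sum()
lp_bcast = normal_logpdf(bcast_draw).sum()
# Each underlying value appears 4 times after broadcasting, so the
# naive elementwise sum overcounts the logp by that factor.
assert np.isclose(lp_bcast, 4 * lp_rv)
```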
This is another example of why it's so important to have the meta-info for all the MeasurableOps (#6360).
Once we have the meta-info, the Mixture will unambiguously know what kind of measurable variable it is dealing with. In the case of MeasurableBroadcast, for example, the ndim_supp will have to be at least as large as the number of broadcasted dims (which means we should collapse that logp dimension instead of leaving it as we are doing now!).
We will also know where those support dims are, so that Mixture can know whether we are sub-selecting across core dims.
Without the meta-info, the only way of knowing ndim_supp is by checking the dimensionality of the value vs the logp. We use this logic in some places already:
pymc/pymc/logprob/transforms.py, lines 432 to 437 in f67ff8b:

if input_logprob.ndim < value.ndim:
    # For multivariate variables, the Jacobian is diagonal.
    # We can get the right result by summing the last dimensions
    # of `transform_elemwise.log_jac_det`
    ndim_supp = value.ndim - input_logprob.ndim
    jacobian = jacobian.sum(axis=tuple(range(-ndim_supp, 0)))
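As a toy numpy illustration of that inference (the shapes are my own example, not from the PR): with a batch of 3-vector draws and one logp entry per draw, the value/logp ndim gap recovers ndim_supp, and the per-element Jacobian terms get summed over that support dim.

```python
import numpy as np

value = np.zeros((5, 3))      # 5 draws of a 3-vector (one support dim)
input_logprob = np.zeros(5)   # one logp entry per draw

ndim_supp = value.ndim - input_logprob.ndim   # 2 - 1 = 1
jacobian = np.full((5, 3), 0.1)               # per-element log-Jacobian terms
jacobian = jacobian.sum(axis=tuple(range(-ndim_supp, 0)))
print(ndim_supp, jacobian.shape)  # 1 (5,)
```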
Lines 185 to 189 in f67ff8b:

if len({logp.ndim for logp in logps}) != 1:
    raise ValueError(
        "Joined logps have different number of dimensions, this can happen when "
        "joining univariate and multivariate distributions",
    )
Which makes me worry whether the probability of a transformed broadcasted variable may be invalid, because the "Jacobian" term is going to be counted multiple times?
You raised a very good point, which makes me wonder to what extent #6797 is correct in general?
For instance, if you scale a 3-vector Dirichlet, you shouldn't count the Jacobian 3 times, because one of the entries is redundant.
Do we need to propagate information about over-determined elements in multi-dimensional RVs?
The first part of this answer suggests you count it 3 times indeed: https://stats.stackexchange.com/a/487538
I'm surprised :D
Edit: As seen below, that answer is wrong
I think this says something else, and is correct: https://upcommons.upc.edu/bitstream/handle/2117/366723/p20-CoDaWork2011.pdf?sequence=1&isAllowed=y
I think these should match:
import pymc as pm
import numpy as np
x = 0.75
print(
pm.logp(pm.Beta.dist(5, 9), x).eval(),
pm.logp(pm.Dirichlet.dist([5, 9]), [x, 1-x]).eval(),
) # -3.471576058736023 -3.471576058736023
print(
pm.logp(2 * pm.Beta.dist(5, 9), 2 * x).eval(),
pm.logp(2 * pm.Dirichlet.dist([5, 9]), 2*np.array([x, 1-x])).eval(),
) # -4.164723239295968 -4.857870419855914
print(
pm.logp(2 * pm.Beta.dist(5, 9), 2 * x).eval(),
(pm.logp(pm.Dirichlet.dist([5, 9]), ([x, 1-x])) - np.log(2)).eval(),
) # -4.164723239295968 -4.164723239295968
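For what it's worth, the first and third prints can be reproduced with only the standard library, assuming the usual scalar change-of-variables formula (beta_logpdf here is my own helper, not PyMC API): for Y = 2X there is exactly one log(2) Jacobian term.

```python
import math

def beta_logpdf(x, a, b):
    # Log-density of Beta(a, b) at x
    log_B = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return (a - 1) * math.log(x) + (b - 1) * math.log(1 - x) - log_B

x = 0.75
lp = beta_logpdf(x, 5, 9)      # ~ -3.471576, as in the first print
lp_scaled = lp - math.log(2)   # ~ -4.164723, a single Jacobian factor
print(lp, lp_scaled)
```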
Once we have the meta-info, the Mixture will unambiguously know what kind of measurable variable it is dealing with. In the case of MeasurableBroadcast, for example, the ndim_supp will have to be at least as large as the number of broadcasted dims (which means we should collapse that logp dimension instead of leaving it as we are doing now!).
This makes sense! Would you say that it's better to wait for #6360?
The first part of this answer suggests you count it 3 times indeed: https://stats.stackexchange.com/a/487538
I'm surprised :D
I'm not sure if I fully follow 😅 Nonetheless, I'm glad that this question raised some interesting concerns
n_new_dims = len(shape) - rv.ndim
assert n_new_dims >= 0

# Enumerate broadcasted dims
Trying to follow along here, this comment is more for "mental scribbles".
rv = pt.random.normal(size=(3, 1))
x = pt.broadcast_to(rv, (5, 2, 3, 4))  # a bit more than your example above
# x.broadcastable = (False, False, False, False)
n_new_dims = 2  # 4 - 2
expanded_dims = (0, 1)
# value.broadcastable[n_new_dims:] = (False, False)  # dims of sizes (3, 4)
# rv.broadcastable = (False, True)  # shape (3, 1)
# condition is True only if v_bcast is False and rv_bcast is True,
# i.e. if (not v_bcast) and rv_bcast = (not False) and True
broadcast_dims = (3,)  # (0 + 2, 1 + 2) but conditions are (False, True)?
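The scribbles above can be sketched as a plain-Python helper (hypothetical, not the PR's actual code), assuming the condition "value dim not broadcastable and rv dim broadcastable" for a broadcasted dim:

```python
def enumerate_broadcast_dims(rv_bcast, target_shape):
    # rv_bcast: broadcastable pattern of the rv, e.g. (False, True) for shape (3, 1)
    # target_shape: the shape being broadcast to, e.g. (5, 2, 3, 4)
    n_new_dims = len(target_shape) - len(rv_bcast)
    assert n_new_dims >= 0
    # Leading dims that did not exist on the rv at all
    expanded_dims = tuple(range(n_new_dims))
    # Existing rv dims of length 1 stretched to a larger size
    broadcast_dims = tuple(
        i + n_new_dims
        for i, (rv_b, size) in enumerate(zip(rv_bcast, target_shape[n_new_dims:]))
        if rv_b and size > 1
    )
    return expanded_dims, broadcast_dims

# rv of shape (3, 1) broadcast to (5, 2, 3, 4), as in the example above
print(enumerate_broadcast_dims((False, True), (5, 2, 3, 4)))  # ((0, 1), (3,))
```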
Related to #6398
TODO: Support Second/Alloc, which are other forms of broadcasting.

📚 Documentation preview 📚: https://pymc--6808.org.readthedocs.build/en/6808/

CC @shreyas3156