Simple stick breaking (Formerly #3620) #3638
Conversation
The test that seems to be failing is […]
```diff
-        self.eps = eps
+    def __init__(self, eps=None):
+        if eps is not None:
+            warnings.warn("The argument `eps` is depricated and will not be used.",
```
deprecated
This is great -- tests would be useful.
The failing test indicates that the Jacobian is incorrect.
It would probably be best to make this new stick-breaking only an alternative to the current one, and not the default. I observed that samples would sometimes distribute more normally in the transformed space and hence allow better approximation by ADVI. Here is an example where I used the current stick-breaking to sample and the new alternative altStickBreaking to transform the results:
```python
ax = pd.DataFrame(trace['decomp_stickbreaking__']).plot.kde(figsize=(10,4))
ax.set_xlim(-10, 10)
```
Even though the 10 components of trace['decomp'] are equivalent, they each look different in the transformed space, though all are quite normally distributed. In the alternative...
```python
md = altStickBreaking().forward(trace['decomp']).eval()
ax = pd.DataFrame(md).plot.kde(figsize=(10,4))
ax.set_xlim(-10, 10)
```
...there is a bump to the right. The model generating this trace can be found here: https://discourse.pymc.io/t/numerical-issues-with-stickbreaking-in-advi/3825
Also, forcing the samples towards the edge of the simplex, e.g. with a bad prior, seems more stable in the current version:
```python
import pymc3 as pm
import pandas as pd
import numpy as np

with pm.Model() as model:
    decomp = pm.Dirichlet('decomp', np.ones(10)*.1, shape=10)
    trace = pm.sample()

pd.DataFrame(trace['decomp_stickbreaking__']).plot.kde(figsize=(10,4));
```
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [decomp]
Sampling 4 chains, 309 divergences: 100%|██████████| 4000/4000 [00:04<00:00, 817.95draws/s]
There were 66 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.8791365205500101, but should be close to 0.8. Try to increase the number of tuning steps.
There were 99 divergences after tuning. Increase `target_accept` or reparameterize.
There were 67 divergences after tuning. Increase `target_accept` or reparameterize.
There were 76 divergences after tuning. Increase `target_accept` or reparameterize.
```python
with pm.Model() as model:
    decomp = pm.Dirichlet('decomp', np.ones(10)*.1, shape=10,
                          transform=altStickBreaking())
    trace2 = pm.sample()

pd.DataFrame(trace2['decomp_stickbreaking__']).plot.kde(figsize=(10,4));
```
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [decomp]
Sampling 4 chains, 2,000 divergences: 100%|██████████| 4000/4000 [00:05<00:00, 707.18draws/s]
The chain contains only diverging samples. The model is probably misspecified.
The chain contains only diverging samples. The model is probably misspecified.
The chain contains only diverging samples. The model is probably misspecified.
The chain contains only diverging samples. The model is probably misspecified.
The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.
The estimated number of effective samples is smaller than 200 for some parameters.
This is probably due to the Jacobian being incorrect, especially close to the edge of the simplex.
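One way to probe this would be to compare a transform's analytic `jacobian_det` against a finite-difference estimate at points near the simplex edge. A rough sketch, assuming pymc3's transform API (`backward`, `jacobian_det`) applied to a single 1-D point, and using the current `StickBreaking` only as an example of the check itself:

```python
import numpy as np
from pymc3.distributions import transforms

def numeric_logdet(transform, y, h=1e-5):
    """Finite-difference log|det| of d(backward(y)[:-1]) / dy at a single point y."""
    Km1 = y.size
    J = np.empty((Km1, Km1))
    for j in range(Km1):
        d = np.zeros_like(y)
        d[j] = h
        hi = transform.backward(y + d).eval()[:-1]   # first K-1 simplex coordinates
        lo = transform.backward(y - d).eval()[:-1]
        J[:, j] = (hi - lo) / (2 * h)
    return np.linalg.slogdet(J)[1]

t = transforms.StickBreaking()
y = np.array([6.0, 4.0, -6.0])  # far from the origin, i.e. close to the simplex edge
print(numeric_logdet(t, y), np.asarray(t.jacobian_det(y).eval()))
```

If the two numbers disagree at such points but match near the origin, that would support the suspicion that the Jacobian is wrong near the edge.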
```diff
         return floatX(y.T)

-    def forward_val(self, x_, point=None):
+    def forward_val(self, x_):
```
We probably have to keep the argument `point=None`, since it is used in https://github.com/pymc-devs/pymc3/blob/e81df2d19ddc4066648b4b2dfc72431c6824f96f/pymc3/util.py#L141-L142
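A small self-contained sketch of keeping the argument for backwards compatibility (the class name and its `forward` body below are placeholders, not the PR's actual implementation):

```python
import numpy as np

class AltStickBreakingSketch:
    """Stand-in for the PR's transform; only the forward_val signature matters here."""

    def forward(self, x_):
        # placeholder forward map, not the PR's formula
        return np.log(x_[..., :-1]) - np.log(x_[..., -1:])

    def forward_val(self, x_, point=None):
        # `point` is accepted but unused, purely so callers that still pass it
        # (e.g. the pymc3/util.py lines linked above) keep working
        return self.forward(x_)

print(AltStickBreakingSketch().forward_val(np.array([0.2, 0.3, 0.5]), point={}))
```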
```python
    def __init__(self, eps=None):
        if eps is not None:
            warnings.warn("The argument `eps` is depricated and will not be used.",
                          DeprecationWarning)
```
I did this because I replaced the original stick-breaking, which took the argument, and didn't want to break existing code. Maybe we don't need it anymore.
Is anyone able to troubleshoot the bad Jacobian? I don't think I can wrap my head around it at the moment, and that's the important blocker here. Happy to work through the rest of the details of this PR if there's an obvious fix for that part.
Maybe this helps the troubleshooting: the Jacobian is a little tricky since the transformation comes with a change in dimensionality. Calculating the Jacobian with the dimensional change yields a non-square matrix that has no determinant. However, the distributions using it, e.g. the Dirichlet distribution with […]. I used the Wolfram Language to calculate the determinant for arbitrary but fixed […].
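My reading of the dimensionality point (an assumption on my side, not necessarily exactly what the notebook computes): the inverse transform maps y in R^(K-1) to a point x on the K-simplex, so its full Jacobian ∂x/∂y is a K×(K-1) matrix and has no determinant. Because the last coordinate is implied by x_K = 1 − Σ_{k<K} x_k, the quantity that actually enters the log-density correction is the determinant of the square (K−1)×(K−1) Jacobian of the first K−1 coordinates, i.e. log |det ∂x_{1:K−1}/∂y|.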
Finally getting back to this, although I'm not sure I'll be able to help much... It's not clear to me if the WolframCloud notebook is what you based your original Jacobian implementation off of, or something you hacked together in response to this PR, but either way I think the two are not equivalent...? […] Versus the implementation in this PR as it currently stands: […]
Isn't this missing the +1 in the denominator?
@bsmith89 The +1 denominator is realized by appending a 0 to […].
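If I read that correctly (my interpretation, not necessarily the exact construction used in the PR), appending a zero before the softmax-style normalization is what turns the denominator into 1 + Σ exp(y), since exp(0) = 1. A quick numpy check of that identity:

```python
import numpy as np

y = np.array([0.3, -1.2, 0.7])

# denominator written with an explicit "+1"
explicit = np.exp(y) / (1.0 + np.exp(y).sum())

# same values obtained by appending a 0 (exp(0) == 1) and normalizing
z = np.concatenate([y, [0.0]])
via_append = np.exp(z) / np.exp(z).sum()

print(np.allclose(explicit, via_append[:-1]))  # True
```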
What's the status of this? Should we close?
We never found out why the tests are failing. The current stick-breaking transformation, which comes from Stan, seems to only concern mapping to the correct range. The first components of the simplex are mapped directly to R, but the mapping of later components depends on the values of previous components. Hence the variance of early components increases the variance of later components in the simplex posterior. I believe the results of ADVI may suffer from bias in later components of the simplex. The suggested transformation is somewhat equivalent to the isometric log-ratio transformation (https://link.springer.com/article/10.1023/A:1023818214614), but without the need to calculate an appropriate Helmert matrix. Only the last component of the simplex is completely implied by the others, while all other components are independent, and according to the cited paper a lot of metric properties are conserved through the transformation. However, it is unclear how much better this transformation is in practice, and whether it is worth fixing whatever is going wrong in the test.
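For reference, a minimal numpy/scipy sketch of the isometric log-ratio transform mentioned above (purely illustrative, not part of this PR; it relies on `scipy.linalg.helmert` for the contrast matrix):

```python
import numpy as np
from scipy.linalg import helmert

def ilr(x):
    """Isometric log-ratio transform of a point x on the K-simplex to R^{K-1}."""
    V = helmert(len(x))        # (K-1) x K matrix of orthonormal, zero-sum rows
    # because every row of V sums to zero, V @ log(x) equals V @ clr(x)
    return V @ np.log(x)

x = np.array([0.2, 0.3, 0.1, 0.4])
print(ilr(x))                  # three real coordinates for a 4-component simplex point
```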
```python
    def jacobian_det(self, y_):
        y = y_.T
        Km1 = y.shape[0]
```
We need to write `Km1 = y.shape[0] + 1` here.
I found the mistake in the Jacobian: […]
I created the new PR #4129 to explain some details and run the tests.
I found a slight update of #3620 by @katosh to be very useful for my work, so I'm opening a new version of that now-closed PR.
I'm not entirely sure what katosh meant by […], but I think that the dimensionality of the Jacobian was wrong in his patch, so I've added a few additional commits that squeeze the last dimension.
I can also spend some time developing a minimal example (slash tests) showing the problems with stick-breaking for the current master (e81df2d as of this PR), but I don't currently have those. Briefly, I've found through experience that the extreme edges of the simplex space are numerically unstable, and result in `inf`s terminating gradient descent. NUTS also usually ends up with many divergences (often 100%) for what I think are related reasons. This patch seems to have fixed both of those problems, and hasn't misbehaved for any 2-dimensional variables in my tests over the last few weeks.
It does NOT, however, work for random variables that are more complicated than a stacked array of Dirichlet RVs. I believe that this is also an issue for the implementation at master. I have no idea what the fix to that problem would be.