
Simple stick breaking (Formerly #3620) #3638

Closed
wants to merge 9 commits

Conversation

bsmith89 (Contributor) commented Oct 1, 2019

I found a slight update of #3620 by @katosh to be very useful for my work, so I'm opening a new version of that now-closed PR.

I'm not entirely sure what katosh meant by

I don't get the tests to work. My jacobian must be wrong or not stable enough, And the problem at hand could not be solved.

But I think that the dimensionality of the Jacobian was wrong in his patch, so I've added a few additional commits that squeeze the last dimension.

I can also spend some time developing a minimal example (and tests) showing the problems with stick-breaking on the current master (e81df2d as of this PR), but I don't currently have those. Briefly, I've found through experience that the extreme edges of the simplex space are numerically unstable and result in infs that terminate gradient descent. NUTS also usually ends up with many divergences (often 100%) for what I think are related reasons.

This patch seems to have fixed both of those problems, and hasn't misbehaved for any 2-dimensional variables in my tests over the last few weeks.

It does NOT, however, work for random variables that are more complicated than a stacked array of Dirichlet RVs. I believe that this is also an issue for the implementation at master. I have no idea what the fix to that problem would be.

bsmith89 (Contributor, Author) commented Oct 1, 2019

The test that seems to be failing is test_simplex_bounds. Can anybody advise on what might be happening? I can't quite grok the error message.

Review comment on this change (diff):

    -        self.eps = eps
    +    def __init__(self, eps=None):
    +        if eps is not None:
    +            warnings.warn("The argument `eps` is depricated and will not be used.",
Member (review comment): "deprecated" (pointing out the typo in the warning string)

twiecki (Member) commented Oct 2, 2019

This is great -- tests would be useful.

junpenglao (Member) commented:

The test failure indicates that the Jacobian is incorrect.
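As a side note on how such a failure can be caught numerically: a minimal sketch that compares a candidate log-Jacobian determinant against a finite-difference Jacobian of the backward map restricted to the free coordinates. The backward and log_jac_det below use a plain softmax-with-implied-last-coordinate map for illustration, not the exact transform in this PR:

    import numpy as np

    # Illustrative backward map: softmax over y with an implied last
    # coordinate fixed at 0 (NOT the exact transform in this PR).
    def backward(y):
        z = np.concatenate([y, [0.0]])
        e = np.exp(z - z.max())
        return e / e.sum()

    # Closed-form log|det J| of `backward` restricted to the first n
    # coordinates; for this map it is sum(y) - (n+1)*log(1 + sum(exp(y))).
    def log_jac_det(y):
        return y.sum() - (len(y) + 1) * np.log1p(np.exp(y).sum())

    # Finite-difference check: build the restricted Jacobian column by
    # column and compare its log-determinant with the closed form.
    def check(y, eps=1e-6):
        n = len(y)
        J = np.empty((n, n))
        for j in range(n):
            dy = np.zeros(n)
            dy[j] = eps
            J[:, j] = (backward(y + dy)[:n] - backward(y - dy)[:n]) / (2 * eps)
        _, logdet = np.linalg.slogdet(J)
        return np.isclose(logdet, log_jac_det(y), atol=1e-5)

    print(check(np.array([0.3, -1.2, 0.5])))  # True when the candidate matches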

katosh (Contributor) left a review:

It would probably be best to make this new stick-breaking only an alternative to the current one, and not the default. I observed that samples would sometimes distribute more normally in the transformed space and hence allow better approximation by ADVI. Here is an example where I used the current stick-breaking to sample and the new alternative altStickBreaking to transform the results:

    ax = pd.DataFrame(trace['decomp_stickbreaking__']).plot.kde(figsize=(10,4))
    ax.set_xlim(-10, 10)

[image: KDE plots of the 10 components of trace['decomp_stickbreaking__'] under the current transform]
Even though the 10 components of trace['decomp'] are equivalent, they each look different in the transformed space, but they are quite normally distributed. In the alternative...

    md = altStickBreaking().forward(trace['decomp']).eval()
    ax = pd.DataFrame(md).plot.kde(figsize=(10,4))
    ax.set_xlim(-10, 10)

[image: KDE plots of the same components under altStickBreaking]
...there is a bump to the right. The model generating this trace can be found here: https://discourse.pymc.io/t/numerical-issues-with-stickbreaking-in-advi/3825

Also, forcing the samples towards the edge of the simplex, e.g. with a bad prior, seems more stable in the current version:

import pymc3 as pm
import pandas as pd
import numpy as np
with pm.Model() as model:
    decomp = pm.Dirichlet('decomp', np.ones(10)*.1, shape=10)
    trace = pm.sample()
pd.DataFrame(trace['decomp_stickbreaking__']).plot.kde(figsize=(10,4));
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [decomp]
Sampling 4 chains, 309 divergences: 100%|██████████| 4000/4000 [00:04<00:00, 817.95draws/s]
There were 66 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.8791365205500101, but should be close to 0.8. Try to increase the number of tuning steps.
There were 99 divergences after tuning. Increase `target_accept` or reparameterize.
There were 67 divergences after tuning. Increase `target_accept` or reparameterize.
There were 76 divergences after tuning. Increase `target_accept` or reparameterize.

[image: KDE plots of the transformed samples; current StickBreaking]

with pm.Model() as model:
    decomp = pm.Dirichlet('decomp', np.ones(10)*.1, shape=10,
                          transform=altStickBreaking())
    trace2 = pm.sample()
pd.DataFrame(trace2['decomp_stickbreaking__']).plot.kde(figsize=(10,4));
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [decomp]
Sampling 4 chains, 2,000 divergences: 100%|██████████| 4000/4000 [00:05<00:00, 707.18draws/s]
The chain contains only diverging samples. The model is probably misspecified.
The chain contains only diverging samples. The model is probably misspecified.
The chain contains only diverging samples. The model is probably misspecified.
The chain contains only diverging samples. The model is probably misspecified.
The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.
The estimated number of effective samples is smaller than 200 for some parameters.

[image: KDE plots of the transformed samples; altStickBreaking]

This is probably due to the Jacobian being incorrect, especially close to the edge of the simplex.

Review thread on this change (the forward_val signature was changed):

        return floatX(y.T)

    def forward_val(self, x_, point=None):   # old signature
    def forward_val(self, x_):               # new signature

Comment on lines +523 to +526:

    def __init__(self, eps=None):
        if eps is not None:
            warnings.warn("The argument `eps` is depricated and will not be used.",
                          DeprecationWarning)
katosh (Contributor):

I did this because I replaced the original stick-breaking, which took the eps argument, and didn't want to break existing code. Maybe we don't need it anymore.

bsmith89 (Contributor, Author) commented Oct 3, 2019

Is anyone able to troubleshoot the bad Jacobian? I don't think I can wrap my head around it at the moment, and that's the important blocker here.

Happy to work through the rest of the details of this PR if there's an obvious fix for that part.

katosh (Contributor) commented Oct 4, 2019

Maybe this helps the troubleshooting:

The Jacobian is a little tricky since the transformation comes with a change in dimensionality. Calculating the Jacobian with the dimensional change yields a non-square matrix that has no determinant. However, the distributions using it, e.g. the Dirichlet distribution with n categories, implicitly have only n-1 degrees of freedom. So what needs to be implemented is the logarithm of the absolute value of the determinant of the Jacobian of the backward transformation restricted to n-1 coordinates.

I used the Wolfram Language to calculate the determinant for arbitrary but fixed n here (https://www.wolframcloud.com/obj/dominik.otto/Published/stick-breaking-jacobian.nb) and implemented the logarithm. In the notebook, the function f is the backward transformation restricted to the first n-1 coordinates. I used a helper function znp that represents the missing n-th coordinate, since it is still needed in the softmax denominator.
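A rough Python analogue of that notebook computation, using sympy in place of the Wolfram Language; f is again the backward map restricted to the first n-1 coordinates, with a softmax-with-implied-zero map standing in for the PR's exact transform, and the appended 0 playing the role of the znp helper:

    import sympy as sp

    n = 4  # simplex dimension; the transformed space has n - 1 coordinates
    y = sp.symbols(f'y0:{n - 1}')
    z = list(y) + [sp.Integer(0)]    # implied n-th coordinate (znp's role)
    S = sum(sp.exp(zi) for zi in z)  # softmax denominator
    # Backward map restricted to the first n - 1 simplex coordinates.
    f = [sp.exp(z[i]) / S for i in range(n - 1)]
    J = sp.Matrix([[sp.diff(fi, yj) for yj in y] for fi in f])
    print(sp.simplify(sp.log(sp.Abs(J.det()))))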

bsmith89 (Contributor, Author) commented:

@katosh

Finally getting back to this, although I'm not sure I'll be able to help much...

It's not clear to me whether the WolframCloud notebook is what you based your original Jacobian implementation on, or something you hacked together in response to this PR, but either way I think the two are not equivalent...?

Your derivation at k=4:

[image: the notebook's expression for the log-Jacobian determinant at k=4]

Versus the implementation in this PR as it currently stands:

    def jacobian_det(self, y_):
        y = y_.T
        k = y.shape[0]
        sy = tt.sum(y, 0, keepdims=True)
        r = tt.concatenate([y+sy, tt.zeros(sy.shape)])
        sr = tt.log(tt.sum(tt.exp(r), 0, keepdims=True))
        d = tt.log(k) + (k*sy) - (k*sr)
        return tt.sum(d, 0).T

Isn't this missing the +1 in the denominator, or am I confused?

katosh (Contributor) commented Dec 17, 2019

@bsmith89 The +1 in the denominator is realized by appending a 0 to r and then taking the exponential. I did it that way because I was not sure whether Theano would optimize the log(sum(exp(...))) correctly if I wrote log(sum(exp(...)) + 1).
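The arithmetic identity itself is easy to check numerically (whether Theano rewrites either form into a numerically stable logsumexp is the separate question raised above):

    import numpy as np

    r = np.array([-2.0, 0.5, 3.0])
    with_zero = np.log(np.sum(np.exp(np.concatenate([r, [0.0]]))))  # 0 appended, as in the PR
    plus_one = np.log(np.sum(np.exp(r)) + 1.0)                      # explicit +1
    print(np.isclose(with_zero, plus_one))  # True, since exp(0) == 1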

twiecki (Member) commented Jul 27, 2020

What's the status of this? Should we close?

katosh (Contributor) commented Jul 27, 2020

We never found out why the tests are failing.

The current stick-breaking transformation, which comes from Stan, seems to concern only mapping to the correct range. The first components of the simplex are mapped directly to R, but the mapping of later components depends on the values of previous components (see the sketch below). Hence the variance of early components increases the variance of later components in the simplex posterior. I believe the results of ADVI may suffer from bias in the later components of the simplex.
https://mc-stan.org/docs/2_18/reference-manual/simplex-transform-section.html
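For illustration, a plain stick-breaking forward map in numpy; the Stan version adds a constant centering offset per coordinate, which does not change the dependence structure described above:

    import numpy as np
    from scipy.special import logit

    def stick_breaking_forward(x):
        # Remaining stick length before each break: 1, 1-x0, 1-x0-x1, ...
        K = len(x)
        stick = 1.0 - np.concatenate([[0.0], np.cumsum(x)])[:K - 1]
        # Each coordinate is the logit of the fraction broken off, so later
        # coordinates depend on all earlier simplex components.
        return logit(x[:-1] / stick)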

The suggested transformation is somewhat equivalent to the isometric log-ratio (ILR) transformation (https://link.springer.com/article/10.1023/A:1023818214614), but without the need to calculate an appropriate Helmert matrix. Only the last component of the simplex is completely implied by the others; all other components are independent, and according to the cited paper many metric properties are preserved by the transformation.
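For reference, a minimal sketch of that ILR transform using scipy's Helmert matrix (illustrative only):

    import numpy as np
    from scipy.linalg import helmert

    def ilr(x):
        # Centered log-ratio, then projection onto an orthonormal basis of
        # the hyperplane orthogonal to the all-ones vector.
        V = helmert(len(x))  # (n-1, n); orthonormal rows, each summing to zero
        clr = np.log(x) - np.log(x).mean()
        return V @ clr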

However, it is unclear how much better this transformation is in practice, and whether it is worth fixing whatever is going wrong in the test.

twiecki (Member) commented Jul 27, 2020

Thanks @katosh. I'm closing this then. If @bsmith89 is motivated to find out what's wrong with the test we can revisit.

twiecki closed this Jul 27, 2020

    def jacobian_det(self, y_):
        y = y_.T
        Km1 = y.shape[0]
katosh (Contributor) commented:

We need to write Km1 = y.shape[0] + 1 here.

katosh (Contributor) commented Sep 23, 2020

I found the mistake in the Jacobian: y.shape[0] equals n-1, not the dimension n of the simplex. To fix the implementation we need to add 1 to Km1 here: https://github.com/bsmith89/pymc3/blob/c0cfbd516b407dbe0619c2d1679c7c936427aeaa/pymc3/distributions/transforms.py#L560
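Applying that fix to the jacobian_det quoted earlier would presumably give the following (a sketch only, untested here; variable names differ between commits, and Km1 then holds n despite its name):

    def jacobian_det(self, y_):
        y = y_.T
        Km1 = y.shape[0] + 1  # y has n - 1 rows; the simplex has n components
        sy = tt.sum(y, 0, keepdims=True)
        r = tt.concatenate([y + sy, tt.zeros(sy.shape)])
        sr = tt.log(tt.sum(tt.exp(r), 0, keepdims=True))
        d = tt.log(Km1) + (Km1 * sy) - (Km1 * sr)
        return tt.sum(d, 0).T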

katosh mentioned this pull request Sep 23, 2020
katosh (Contributor) commented Sep 23, 2020

I created the new PR #4129 to explain some details and run the tests.
