-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simple stick breaking #4129
Simple stick breaking #4129
Conversation
…e transformation.
Codecov Report
@@ Coverage Diff @@
## master #4129 +/- ##
=======================================
Coverage 88.74% 88.74%
=======================================
Files 89 89
Lines 14037 14024 -13
=======================================
- Hits 12457 12446 -11
+ Misses 1580 1578 -2
|
I will remove the NumPy implementation of the backward transformation |
I investigated sampling divergencies in the examples above. I changed the parameter for the Dirichlet distribution to import matplotlib.pyplot as plt
def pairplot_divergence(trace, var1, var2, i1=0, i2=0):
v1 = trace.get_values(varname=var1, combine=True)[:, i1]
v2 = trace.get_values(varname=var2, combine=True)[:, i2]
_, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.plot(v1, v2, 'o', color='b', alpha=.5)
divergent = trace['diverging']
ax.plot(v1[divergent], v2[divergent], 'o', color='r')
ax.set_xlabel('{}[{}]'.format(var1, i1))
ax.set_ylabel('{}[{}]'.format(var2, i2))
ax.set_title('scatter plot between {}[{}] and {}[{}]'.format(var1, i1, var2, i2));
return ax Current StickBreakingpairplot_divergence(trace1, 'decomp', 'decomp', i1=2, i2=3) New StickBreaking2pairplot_divergence(trace2, 'decomp', 'decomp', i1=2, i2=3) ConclusionThe parameterization from |
It seems |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall, this alternative stickbreaking seems fine, but why wouldn't it entirely replace the old one?
Also, if possible, this PR should add a new test that confirms one of the advantages of this transform over the old one. Ideally, such a test wouldn't require anything as costly as sampling. Is there a value range that demonstrates the improved numerical stability?
@katosh I would then still take the |
Code coverage is reduced since:
|
You can add a |
why not write a test with with pytest.warns(DeprecationWarning("<warning text>")):
<test which sets `eps` parameter> which covers it? |
You can definitely do that, but we're not really testing much of our own code in this case, so it's not a particularly relevant unit test. |
already done the test :) |
I tested how close to the edge of the simplex we can go before the transformation starts to break and for the cases I tested it seems to work down to the smallest float64: >>> import numpy as np
>>> from pymc3.distributions.transforms import stick_breaking
>>> a = 5e-324
>>> vec = np.array([a, a, a, 1-(3*a)]) # a point very close to the edge of the 4-simplex
>>> stick_breaking.backward(stick_breaking.forward(vec).eval()).eval()
array([5.e-324, 5.e-324, 5.e-324, 1.e+000]) # very close to vec However, the same can possibly not be said about the jacobian! |
@katosh This looks great and quite thorough. Is there anything missing before merging from your end? |
I am done so far but of course, I can do further testing if someone has a request. |
I think this is great, thanks so much for the contribution! |
Awesome, thank you for having me be part of this project! |
It appears that StickBreaking.forward_val is being eliminated, with no equivalent in the new version. This would concern me, as I happen to use it in a public repository. I could work around it, but perhaps there are others using it also. Is there any depreciation warning in the meantime? I only found out about this because I wanted to add a backward_val. |
Do you mind opening a separate issue for that? This one is pretty long and the forward_val were removed for all distributions not just StickBreaking |
See discourse topic here |
This is another attempt to introduce a new transformation of the n-simplex. The
stickbreaking
transformation is prominently used by the Dirichlet distribution as it maps the range of the Distribution (the n-simplex) to R^(n-1) where we can sample freely and apply, e.g., ADVI. The issue with the current implementation is that the transformation of later values in the vector depends on previous values. This introduces a dependency that can be confounding for ADVI and seems to produce numerical inaccuracies in some cases.There was a previous attempt to merge the new transformation but it had a mistake in the determinant of the jacobian: #3638
The current strikebreaking in master is an implementation of the transformation from Stan: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html Which is just a repeated application of the logit transformation with adjusting range.
Advantages
eps
for numeric stability (as in https://github.com/pymc-devs/pymc3/blob/ba77d8502704e8aeb112782ee104fb339393cb19/pymc3/distributions/transforms.py#L475)Current StickBreaking
New StickBreaking2
The PR includes tests and there are no breaking changes as it only introduces a new transformation
pymc3.distributions.transforms.StickBreaking2
and leaves the originalpymc3.distributions.transforms.StickBreaking
untouched.