Add tutorial for Boosting Black Box Variational Inference #2308
Conversation
@lorenzkuhn this looks great! LGTM on code and documentation. @martinjankowiak can you review the relbo math?
thanks @lorenzkuhn looks great! here are some nits/comments/suggestions:
Thanks a lot for the feedback @fritzo and @martinjankowiak, that sounds good. @martinjankowiak: Thanks, I'll fix the issues from the first three bullet points. What would be the recommended approach to avoid drawing extra samples in Trace_ELBO?
@lorenzkuhn i think the easiest thing to do would be to replace the Trace_ELBO with a simple custom elbo along the lines of what's here; this way you can reuse your guide trace.
As I understand, @martinjankowiak is concerned with the extra samples drawn in Trace_ELBO. The refactored relbo below records the guide trace once and reuses it:

```python
def relbo(model, guide, *args, **kwargs):
    approximation = kwargs.pop('approximation')
    # We first compute the elbo, but record a guide trace for use below.
    traced_guide = trace(guide)
    elbo = pyro.infer.Trace_ELBO(max_plate_nesting=1)
    loss_fn = elbo.differentiable_loss(model, traced_guide, *args, **kwargs)
    # We do not want to update parameters of previously fitted components
    # and thus block all parameters in the approximation apart from z.
    guide_trace = traced_guide.trace
    replayed_approximation = trace(replay(block(approximation, expose=['z']), guide_trace))
    approximation_trace = replayed_approximation.get_trace(*args, **kwargs)
    relbo = -loss_fn - approximation_trace.log_prob_sum()
    # By convention, the negative (R)ELBO is returned.
    return -relbo
```
@lorenzkuhn FYI we're planning on a 1.3 release on Monday. There's no rush 😄 but if you would like to get this in by Monday morning, we can publish it with the release.
Ah, I see, thanks @martinjankowiak and @fritzo. I've refactored the relbo as you suggested @fritzo and fixed the typos/inconsistencies that @martinjankowiak found 👍
LGTM. Any further suggestions, @martinjankowiak?
lgtm thanks for the tutorial!
In this pull request, we (@gideonite, @sharrison5, @TNU-yaoy and I) would like to submit a tutorial on implementing Boosting Black Box Variational Inference using Pyro. In the tutorial, we summarise the paper, show how Boosting Black Box Variational Inference can be implemented as an extension of Pyro's SVI, and apply the method to the problem of approximating a bimodal posterior.
I've run `make test-tutorial` (which passes), `make scrub` and `make license`. The changes to `neutra.py` and `test_neutra.py` were caused by `make license`.

We look forward to your feedback!