Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tutorial for Boosting Black Box Variational Inference #2308

Merged
merged 5 commits into from
Feb 29, 2020

Conversation

lorenzkuhn
Copy link
Contributor

In this pull request, we (@gideonite, @sharrison5, @TNU-yaoy and I) would like to submit a tutorial on implementing Boosting Black Box Variational Inference using Pyro. In the tutorial, we summarise the paper, show how boosting black box Variational Inference can be implemented as an extension of Pyro's SVI and apply the method to the problem of approximating a bimodal posterior.

I've run make test-tutorial(which passes), make scrub and make license. The changes to neutra.py and test_neutra.py were caused by make license.

We look forward to your feedback!

@fritzo
Copy link
Member

fritzo commented Feb 18, 2020

@lorenzkuhn this looks great! LGTM on code and documentation.

@martinjankowiak can you review the relbo math?

@martinjankowiak
Copy link
Collaborator

thanks @lorenzkuhn looks great!

here are some nits/comments/suggestions:

  • the "Variational Inference/ELBO" section is inconsistent about s(z)/q(z); i.e. s => q
  • typo 'mulitmodal'
  • "in our case, we include loc in both the model and the guide." do you mean z?
  • note that the way you use Trace_ELBO leads to extra samples being drawn. either comment or fix?

@lorenzkuhn
Copy link
Contributor Author

Thanks a lot for the feedback @fritzo and @martinjankowiak, that sounds good.

@martinjankowiak : Thanks, I'll fix the issues from the first three bullet points. What would be the recommended approach to avoid drawing extra samples in Trace_ELBO?

@martinjankowiak
Copy link
Collaborator

@lorenzkuhn i think the easiest thing to do would be to replace

loss_fn = pyro.infer.Trace_ELBO(max_plate_nesting=1).differentiable_loss(model,
                                                                             guide,
                                                                             *args,
                                                                             **kwargs)

with a simple custom elbo along the lines of what's here, i.e. this way you can reuse your guide trace.
@fritzo wdyt?

@fritzo
Copy link
Member

fritzo commented Feb 28, 2020

As I understand, @martinjankowiak is concerned with the guide being run twice in relbo(). I believe this could be avoided by passing a traced guide to elbo and using that trace to compute the approximation. WDYT of the following refactoring?

def relbo(model, guide, *args, **kwargs):
    approximation = kwargs.pop('approximation')

    # We first compute the elbo, but record a guide trace for use below.
    traced_guide = trace(guide)
    elbo = pyro.infer.Trace_ELBO(max_plate_nesting=1)
    loss_fn = elbo.differentiable_loss(model, traced_guide, *args, **kwargs)

    # We do not want to update parameters of previously fitted components
    # and thus block all parameters in the approximation apart from z.
    guide_trace = traced_guide.trace
    replayed_approximation = trace(replay(block(approximation, expose=['z']), guide_trace))
    approximation_trace = replayed_approximation.get_trace(*args, **kwargs)

    relbo = -loss_fn - approximation_trace.log_prob_sum()
    
    # By convention, the negative (R)ELBO is returned.
    return -relbo

@fritzo
Copy link
Member

fritzo commented Feb 28, 2020

@lorenzkuhn FYI we're planning on a 1.3 release on Monday. There's no rush 😄but if you would like to get this in by Monday morning, we can publish it with the release.

@lorenzkuhn
Copy link
Contributor Author

Ah, I see, thanks @martinjankowiak and @fritzo. I've refactored the relbo as you suggested @fritzo and fixed the typos/inconsistencies that @martinjankowiak found 👍

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Any further suggestions, @martinjankowiak ?

Copy link
Collaborator

@martinjankowiak martinjankowiak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm thanks for the tutorial!

@martinjankowiak martinjankowiak merged commit a14b4eb into pyro-ppl:dev Feb 29, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants