
Performance improvements for chirho.robust #459

Closed · wants to merge 11 commits

Conversation

agrawalraj (Contributor) commented Dec 19, 2023

There are three main performance improvements:

  1. Optimization when the guide is a point estimate
    If the guide is only a point estimate, using NMCLogPredictiveLikelihood, which relies on torch.func.vmap to batch over multiple datapoints, can be orders of magnitude slower than manually batching the likelihood. In test_empirical_fisher_vp_performance_with_likelihood in test_performance.py, NMCLogPredictiveLikelihood can be 200x slower than manually vectorizing the likelihood with the new class PointLogPredictiveLikelihood.

  2. Manual batching in _flat_conjugate_gradient_solve
    One big performance hit was using torch.func.vmap to batch over multiple conjugate gradient solves. Since torch.func.vmap does not allow data-dependent if/else statements, the conjugate gradient solver cannot terminate early once the error tolerance is met. In this PR, we batch the conjugate gradient solver without torch.func.vmap so that we can use conditional statements (a minimal sketch of the idea appears after this list). In one experiment, this let me run only 23 conjugate gradient steps instead of a few hundred to reach the error tolerance. The same behavior is exploited in https://arxiv.org/abs/1903.08114.

  3. Batching over multiple points in linearize and influence_fn
    As shown in test_performance.py, simulating synthetic data from the model can be the most time-consuming step. In the previous implementation, data is simulated every time the influence function is evaluated at a single point. In this PR, we evaluate the influence function over multiple points at once so that data only needs to be simulated once.
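
For reference, here is a minimal sketch of the early-termination idea in point 2, written as an ordinary Python loop over batched tensors rather than under torch.func.vmap. This is not the PR's _flat_conjugate_gradient_solve; the batched matrix-vector product batched_Avp, the tensor shapes, and the stopping rule are illustrative assumptions.

import torch

def batched_cg_solve(batched_Avp, b, cg_iters=100, residual_tol=1e-10):
    # b: (batch, dim) right-hand sides; batched_Avp maps (batch, dim) -> (batch, dim).
    x = torch.zeros_like(b)
    r = b - batched_Avp(x)
    p = r.clone()
    rs_old = (r * r).sum(dim=-1, keepdim=True)
    for _ in range(cg_iters):
        Ap = batched_Avp(p)
        alpha = rs_old / (p * Ap).sum(dim=-1, keepdim=True)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = (r * r).sum(dim=-1, keepdim=True)
        # Ordinary Python control flow: stop as soon as every system in the
        # batch meets the tolerance, which torch.func.vmap would not allow.
        if bool((rs_new.sqrt() < residual_tol).all()):
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

In practice, batch elements that have already converged could additionally be masked out of the update to avoid needless work.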

Closes #451

agrawalraj added the enhancement and module:robust labels Dec 19, 2023
agrawalraj self-assigned this Dec 19, 2023
agrawalraj requested a review from eb8680 December 19, 2023 01:08
agrawalraj added the status:awaiting review label Dec 19, 2023
agrawalraj linked an issue Dec 19, 2023 that may be closed by this pull request
eb8680 (Contributor) left a comment

Nice detective work!

I can see why 2 and 3 would cause slowdowns, but 1 is still a bit mysterious to me even after spending some time yesterday playing around with the code and tests. There are a couple of places where a naive implementation of vmap could introduce unnecessarily repeated computation, but I made some smaller changes that I thought would work around those, and they didn't seem to make a difference.

At any rate, whatever is really going on, you've convinced me that it would be a good idea to replace all uses of torch.func.vmap for batching over datapoints in chirho.robust with a manually vectorized version of NMCLogPredictiveLikelihood that generalizes your PointPredictiveLikelihood. I'll create a separate issue for this and assign myself.

In the meantime, can we break this into smaller PRs that address the issues independently, starting with 2 and 3? I have some comments on API changes here but I think it would be easiest to address them in more isolated contexts.
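
As a toy illustration of what "manually vectorized" means here (this is not chirho's NMCLogPredictiveLikelihood or the PR's PointLogPredictiveLikelihood, just the broadcasting idea behind them): for a point-estimate guide, the per-datapoint log density can often be scored for all datapoints at once by relying on torch.distributions broadcasting instead of torch.func.vmap.

import torch
from torch.distributions import MultivariateNormal

def pointwise_log_prob(loc, cov, x):
    # log density of a single datapoint x of shape (dim,)
    return MultivariateNormal(loc, cov).log_prob(x)

def batched_log_prob(loc, cov, X):
    # X: (num_points, dim); log_prob broadcasts over the leading batch
    # dimension, so all points are scored in one call without vmap.
    return MultivariateNormal(loc, cov).log_prob(X)

# The vmap version computes the same values but routes every call through
# torch.func's transform machinery:
vmapped_log_prob = torch.func.vmap(pointwise_log_prob, in_dims=(None, None, 0))

Both forms compute the same quantities; the difference is only in how the batching over datapoints is carried out.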

log_prob_params = {"loc": loc}
num_monte_carlo = 10000
start_time = time.time()
data = Predictive(GaussianModel(cov_mat), num_samples=num_monte_carlo)(loc)
eb8680 (Contributor) commented on the excerpt above:

This data generation step is slow because it's done sequentially. Passing parallel=True to Predictive (which we do in linearize) speeds it up.
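
For reference, a sketch of the suggested change to this snippet (GaussianModel, cov_mat, num_monte_carlo, and loc are the names from the surrounding test; whether the model's batch dimensions are set up for vectorized sampling is an assumption):

from pyro.infer import Predictive

# parallel=True draws all samples in a single vectorized pass instead of
# looping over num_samples sequential model executions; it requires the
# model's batch dimensions to be declared with pyro.plate.
data = Predictive(
    GaussianModel(cov_mat),
    num_samples=num_monte_carlo,
    parallel=True,
)(loc)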

agrawalraj (Contributor Author) replied:

Got it, thanks!

guide()

start_time = time.time()
data = Predictive(
eb8680 (Contributor) commented on the excerpt above:

ditto: this is slow because it's not being vectorized. Passing parallel=True should speed it up.

agrawalraj (Contributor Author) replied:

Makes sense, thanks!

agrawalraj (Contributor Author) commented:

> In the meantime, can we break this into smaller PRs that address the issues independently, starting with 2 and 3?

Yup, I'll do that now!

agrawalraj (Contributor Author) commented:

This PR will be broken up into 3 separate issues, so I will close it.

Labels: enhancement, module:robust, status:awaiting review
Successfully merging this pull request may close these issues:

Potential slowdown of vmapping on the outside in influence_fn