Performance improvements for chirho.robust #459

Conversation
Nice detective work!

I can see why 2 and 3 would cause slowdowns, but 1 is still a bit mysterious to me even after spending some time yesterday playing around with the code and tests. There are a couple of places where a naive implementation of `vmap` could introduce unnecessarily repeated computation, but I made some smaller changes that I thought would work around those, and they didn't seem to make a difference.

At any rate, whatever is really going on, you've convinced me that it would be a good idea to replace all uses of `torch.func.vmap` in `chirho.robust` for batching over datapoints with a manually vectorized version of `NMCLogPredictiveLikelihood` that generalizes your `PointLogPredictiveLikelihood`. I'll create a separate issue for this and assign it to myself.

In the meantime, can we break this into smaller PRs that address the issues independently, starting with 2 and 3? I have some comments on API changes here, but I think it would be easiest to address them in more isolated contexts.
```python
log_prob_params = {"loc": loc}
num_monte_carlo = 10000
start_time = time.time()
data = Predictive(GaussianModel(cov_mat), num_samples=num_monte_carlo)(loc)
```
This data generation step is slow because it's done sequentially. Passing `parallel=True` to `Predictive` (which we do in `linearize`) speeds it up.
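For example, here is a minimal self-contained sketch of the two modes (a toy model, not the test's `GaussianModel`; note that `parallel=True` requires the model's batch dimensions to be declared so an extra sample dimension broadcasts correctly):

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Predictive

def model(loc):
    # to_event(1) declares the dimensions of `loc` as event dimensions,
    # so an extra leftmost sample dimension broadcasts cleanly.
    return pyro.sample("x", dist.Normal(loc, 1.0).to_event(1))

loc = torch.zeros(2)

# Sequential (default): runs the model once per sample in a Python loop.
slow = Predictive(model, num_samples=1000)(loc)

# Vectorized: a single model run with a broadcasted sample dimension.
fast = Predictive(model, num_samples=1000, parallel=True)(loc)
```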
Got it, thanks!
```python
guide()

start_time = time.time()
data = Predictive(
```
Ditto: this is slow because it's not being vectorized. Passing `parallel=True` should speed it up.
Makes sense, thanks!

Yup, I'll do that now!

This PR will be broken up into 3 separate issues, so I will close it.
There are three main performance improvements:

**1. Optimization when the guide is a point estimate**

If the guide is only a point estimate, using `NMCLogPredictiveLikelihood`, which relies on `torch.func.vmap` to batch over multiple datapoints, can be orders of magnitude slower than manually batching the likelihood. In `test_empirical_fisher_vp_performance_with_likelihood` in `test_performance.py`, we see that `NMCLogPredictiveLikelihood` can be 200x slower than manually vectorizing the likelihood using the new class `PointLogPredictiveLikelihood`. A sketch of the two batching strategies is shown below.
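The following toy comparison is illustrative only (a unit Normal likelihood with made-up shapes, not the PR's actual models); it shows the two strategies computing the same per-datapoint log-likelihoods, with the manual broadcast being what `PointLogPredictiveLikelihood` does in spirit:

```python
import torch
import torch.distributions as dist

loc = torch.randn(3)           # point estimate of the parameters
data = torch.randn(10000, 3)   # N datapoints to score

# vmap-based batching: the likelihood is evaluated once per datapoint.
vmapped = torch.func.vmap(
    lambda x: dist.Normal(loc, 1.0).log_prob(x).sum()
)(data)

# Manual batching: a single broadcasted log_prob call over all datapoints.
manual = dist.Normal(loc, 1.0).log_prob(data).sum(-1)

assert torch.allclose(vmapped, manual)
```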
**2. Manually batching in `_flat_conjugate_gradient_solve`**

One big performance hit was using `torch.func.vmap` to batch over multiple conjugate gradient solves. Since `torch.func.vmap` does not allow conditional if/else statements, the conjugate gradient solver cannot terminate early when the error tolerance is met. In this PR, we batch the conjugate gradient solver without using `torch.func.vmap` so that we can use conditional statements. In one experiment, I was able to run only 23 conjugate gradient steps instead of a few hundred to reach the error tolerance. This behavior is also exploited here: https://arxiv.org/abs/1903.08114. A sketch of the early-exit pattern is shown below.
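Here is a minimal sketch of the idea (not the PR's actual `_flat_conjugate_gradient_solve`; shapes and tolerances are illustrative). Because the batch dimension is handled by ordinary broadcasting rather than `torch.func.vmap`, the Python loop can break as soon as every solve in the batch meets the tolerance:

```python
import torch

def batched_cg(matvec, b, tol=1e-6, max_steps=1000):
    """Solve A x = b for a batch of right-hand sides b of shape (batch, d).

    `matvec` maps (batch, d) -> (batch, d) via broadcasting.
    """
    x = torch.zeros_like(b)
    r = b - matvec(x)            # residuals, shape (batch, d)
    p = r.clone()
    rs = (r * r).sum(dim=-1)     # squared residual norms, shape (batch,)
    for _ in range(max_steps):
        Ap = matvec(p)
        alpha = rs / (p * Ap).sum(dim=-1)
        x = x + alpha.unsqueeze(-1) * p
        r = r - alpha.unsqueeze(-1) * Ap
        rs_new = (r * r).sum(dim=-1)
        # Data-dependent early exit -- exactly what torch.func.vmap forbids.
        if rs_new.max().sqrt() < tol:
            break
        p = r + (rs_new / rs).unsqueeze(-1) * p
        rs = rs_new
    return x

# Example: a shared SPD matrix with a batch of 5 right-hand sides.
A = torch.eye(4) + 0.1 * torch.ones(4, 4)
b = torch.randn(5, 4)
x = batched_cg(lambda v: v @ A.T, b)
assert torch.allclose(x @ A.T, b, atol=1e-4)
```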
**3. Batching over multiple points in `linearize` and `influence_fn`**

As shown in `test_performance.py`, simulating synthetic data from the model can be the most time-consuming step. In the previous implementation, data was simulated each time the influence function was evaluated at a single point. In this PR, we evaluate the influence function over multiple points so that we only need to simulate data once (a toy illustration of this simulate-once pattern appears after this description).

Closes #451
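To illustrate the simulate-once pattern with a self-contained toy, here `simulate` and `statistic` are hypothetical stand-ins for drawing Monte Carlo samples from the model and evaluating the influence function at a batch of points; neither is chirho.robust's actual API:

```python
import torch

def simulate(num_samples=10000, dim=3):
    # Stand-in for the expensive step: simulating data from the model.
    return torch.randn(num_samples, dim)

def statistic(points, samples):
    # Stand-in for a batched influence-function evaluation:
    # points (P, d), samples (S, d) -> (P,) via broadcasting.
    return ((points.unsqueeze(1) - samples) ** 2).mean(dim=(1, 2))

points = torch.randn(50, 3)

# Previous pattern: one fresh simulation per evaluation point (50x the cost).
slow = torch.stack([statistic(p.unsqueeze(0), simulate())[0] for p in points])

# This PR's pattern: simulate once, evaluate all points in one batched call.
samples = simulate()
fast = statistic(points, samples)
```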