Batch apply_lse_kernel for online=True #23

Merged
merged 12 commits into from
Mar 2, 2022


@michalk8 michalk8 commented Mar 2, 2022

As discussed in #20, this PR fixes online by batching apply_lse_kernel: the fully vectorized computation now runs on a batch of shape n * batch_size or m * batch_size (depending on the axis) instead of n * m. A minor inefficiency is that this approach computes the kernel application for an extra {n,m} % batch_size points, since the result of each iteration of jax.lax.scan must have the same shape.
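The batching idea can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in this PR: the function names `lse_kernel_dense` and `lse_kernel_batched`, the squared Euclidean cost, and the zero-padding strategy are all hypothetical choices made for the example. Padding is one way to satisfy the fixed-shape requirement of `jax.lax.scan`; the PR may handle the remainder differently.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def lse_kernel_dense(x, y, g, eps):
    """Reference version: materializes the full (n, m) cost matrix."""
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return logsumexp((g[None, :] - cost) / eps, axis=1)


def lse_kernel_batched(x, y, g, eps, batch_size):
    """Same result, but only a (batch_size, m) block is live at a time."""
    n, d = x.shape
    # Pad x with zero rows so that n becomes a multiple of batch_size;
    # every scan iteration must return an array of the same shape.
    pad = (-n) % batch_size
    x_padded = jnp.pad(x, ((0, pad), (0, 0)))
    x_batches = x_padded.reshape(-1, batch_size, d)

    def body(carry, x_batch):
        # Fully vectorized computation on one (batch_size, m) block.
        cost = jnp.sum((x_batch[:, None, :] - y[None, :, :]) ** 2, axis=-1)
        return carry, logsumexp((g[None, :] - cost) / eps, axis=1)

    # No state is carried between iterations, so the carry is None.
    _, out = jax.lax.scan(body, None, x_batches)
    # Drop the values computed for the padded rows.
    return out.reshape(-1)[:n]
```

In this sketch the wasted work is the rows added by padding, which are computed and then discarded; memory use scales with batch_size * m rather than n * m.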

@LaetitiaPapaxanthos there are 2 points I am unsure how you want to handle:

  • should backward compatibility be retained, so that online=True is still allowed? In that case, we'd need a default value, which can either be fixed or depend on the number of points (in our benchmarks, a value of 1024 seems to work well). UPDATE: online=True is the same as online=1024
  • currently, the batch size is the same for axis=0 and axis=1, but it could be made axis-specific. UPDATE: kept the same batch size for both axes
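The resolution of the first point (online=True behaving like online=1024) could look roughly like the helper below. The name `resolve_batch_size` and the constant `DEFAULT_BATCH_SIZE` are hypothetical, introduced only for illustration; the PR itself may normalize the argument differently.

```python
from typing import Optional, Union

# Assumed default, taken from the benchmark value mentioned above.
DEFAULT_BATCH_SIZE = 1024


def resolve_batch_size(online: Union[bool, int]) -> Optional[int]:
    """Map the `online` argument to a batch size, or None for offline mode."""
    # bool is a subclass of int in Python, so it must be checked first.
    if isinstance(online, bool):
        return DEFAULT_BATCH_SIZE if online else None
    if online <= 0:
        raise ValueError(f"online must be positive, got {online}.")
    return int(online)
```

With this sketch, a single batch size is returned and would be shared by both axes, matching the second UPDATE.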

TODOs:

  • add new tests (and depending on the 1st point above, might need to adjust old tests where online=True)
  • update docs

closes #20


michalk8 commented Mar 2, 2022

There was a bug that caused the coupling/marginals not to match those from online=False (neither the tests here nor running the tutorial notebook caught it for the particular values of online). It should be fixed now: I added a test for this plus 2 more for jitting. Tests pass locally (but it would be great to enable CI on PRs).
Performance-wise, it should even be slightly faster, since apply_lse_kernel doesn't do any extra work (985.29s for a coupling of shape (131072, 65536) on an Nvidia A100). As for corner cases, there should be tests for them (batch sizes of 1, n, m and some prime number).

@marcocuturi

Thanks a lot, Michal! This is fantastic.

@marcocuturi marcocuturi merged commit 0909dfe into ott-jax:master Mar 2, 2022
michalk8 pushed a commit that referenced this pull request Jun 27, 2024
Batch `apply_lse_kernel` for `online=True`
Successfully merging this pull request may close these issues.

Vectorization memory issues with PointCloud.apply_lse_kernel with online=True