Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/batched vmap #588

Merged
merged 79 commits into from
Oct 16, 2024
Merged

Feature/batched vmap #588

merged 79 commits into from
Oct 16, 2024

Conversation

michalk8
Copy link
Collaborator

@michalk8 michalk8 commented Oct 10, 2024

TODOs:

  • negative in/out axes
  • clean-up the pointcloud.py (re-order methods/properties, etc.)
  • update LRCGeometry
  • test the batched_vmap
  • docs

closes #504

Copy link

codecov bot commented Oct 10, 2024

Codecov Report

Attention: Patch coverage is 97.55245% with 7 lines in your changes missing coverage. Please review.

Project coverage is 88.13%. Comparing base (706cef7) to head (e390e64).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
src/ott/geometry/pointcloud.py 95.79% 3 Missing and 2 partials ⚠️
src/ott/utils.py 98.05% 1 Missing and 1 partial ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #588      +/-   ##
==========================================
+ Coverage   88.00%   88.13%   +0.12%     
==========================================
  Files          73       73              
  Lines        7820     7768      -52     
  Branches      567      556      -11     
==========================================
- Hits         6882     6846      -36     
+ Misses        789      779      -10     
+ Partials      149      143       -6     
Files with missing lines Coverage Δ
src/ott/geometry/costs.py 97.18% <100.00%> (-0.04%) ⬇️
src/ott/geometry/distrib_costs.py 100.00% <100.00%> (ø)
src/ott/geometry/geodesic.py 94.39% <100.00%> (+0.05%) ⬆️
src/ott/geometry/geometry.py 94.46% <100.00%> (+1.21%) ⬆️
src/ott/geometry/graph.py 95.83% <100.00%> (+0.04%) ⬆️
src/ott/geometry/grid.py 96.24% <100.00%> (+0.02%) ⬆️
src/ott/geometry/low_rank.py 96.73% <100.00%> (-0.31%) ⬇️
src/ott/neural/methods/monge_gap.py 91.20% <ø> (ø)
src/ott/problems/linear/potentials.py 91.50% <ø> (ø)
src/ott/problems/quadratic/gw_barycenter.py 90.00% <ø> (ø)
... and 7 more

@michalk8 michalk8 requested a review from marcocuturi October 15, 2024 17:53
@michalk8 michalk8 added the enhancement New feature or request label Oct 15, 2024
@michalk8 michalk8 marked this pull request as ready for review October 15, 2024 17:53
Copy link
Contributor

@marcocuturi marcocuturi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

src/ott/geometry/costs.py Outdated Show resolved Hide resolved
src/ott/geometry/geodesic.py Outdated Show resolved Hide resolved
src/ott/geometry/geometry.py Outdated Show resolved Hide resolved
src/ott/geometry/graph.py Outdated Show resolved Hide resolved
src/ott/geometry/graph.py Outdated Show resolved Hide resolved
@@ -105,6 +105,7 @@ def __call__(
dual_initialization, weights=weights, axis=0
)[jnp.newaxis, :]

# TODO(michalk8): geom.is_symmetric is not static
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

true!

@michalk8 michalk8 merged commit c9d3a49 into main Oct 16, 2024
13 checks passed
@michalk8 michalk8 deleted the feature/batched-vmap branch October 16, 2024 16:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Increased GPU memory usage when using a cost_fn different from costs.SqEuclidean()
2 participants