Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implements sliced W #576

Merged
merged 13 commits into from
Sep 13, 2024
Merged

implements sliced W #576

merged 13 commits into from
Sep 13, 2024

Conversation

marcocuturi
Copy link
Contributor

@marcocuturi marcocuturi commented Sep 12, 2024

Fairly primitive implementation of sliced_w distance.

essentially a wrapper on top of the ott.solvers.linear.solve_univariate function

@michalk8 michalk8 added the enhancement New feature or request label Sep 12, 2024
src/ott/tools/sliced.py Show resolved Hide resolved
src/ott/tools/sliced.py Outdated Show resolved Hide resolved
src/ott/tools/sliced.py Outdated Show resolved Hide resolved
src/ott/tools/sliced.py Outdated Show resolved Hide resolved
src/ott/tools/sliced.py Outdated Show resolved Hide resolved
src/ott/tools/sliced.py Show resolved Hide resolved
src/ott/tools/sliced.py Outdated Show resolved Hide resolved
src/ott/tools/sliced.py Outdated Show resolved Hide resolved
src/ott/tools/sliced.py Outdated Show resolved Hide resolved
tests/tools/sliced_test.py Outdated Show resolved Hide resolved
Copy link

codecov bot commented Sep 12, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 87.84%. Comparing base (4aed3ec) to head (46bd3dd).
Report is 31 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #576      +/-   ##
==========================================
+ Coverage   87.78%   87.84%   +0.05%     
==========================================
  Files          72       73       +1     
  Lines        7801     7823      +22     
  Branches     1126     1127       +1     
==========================================
+ Hits         6848     6872      +24     
+ Misses        802      798       -4     
- Partials      151      153       +2     
Files with missing lines Coverage Δ
src/ott/tools/sliced.py 100.00% <100.00%> (ø)

... and 1 file with indirect coverage changes

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

src/ott/tools/sliced.py Outdated Show resolved Hide resolved
src/ott/tools/sliced.py Outdated Show resolved Hide resolved
src/ott/tools/sliced.py Outdated Show resolved Hide resolved
src/ott/tools/sliced.py Outdated Show resolved Hide resolved
src/ott/tools/sliced.py Outdated Show resolved Hide resolved
docs/tools.rst Outdated Show resolved Hide resolved
docs/tools.rst Outdated Show resolved Hide resolved
docs/tools.rst Outdated Show resolved Hide resolved
docs/tools.rst Outdated Show resolved Hide resolved
docs/tutorials/Monge_Gap_Simple.ipynb Outdated Show resolved Hide resolved
@@ -88,10 +88,15 @@ def sliced_wasserstein(
Returns:
The sliced Wasserstein distance with the corresponding output object.
"""
if proj_fn is None:
if projector is None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a bit complex, no? it just saves the user lambda arr, **_: arr @ projector.T when passing proj_fn as matrix product, at the cost of this ambiguity

My hunch is that most users will likely want to change the way directions are sampled in the proj, rather than pass a matrix (that was the idea of passing a random generator first), or, alternatively, define a proper proj_fn feature extractor (that's also why I had added the features word above, as, when the proj_fn becomes more complex, one should rather see it as a feature extractor)

@marcocuturi
Copy link
Contributor Author

Thanks a lot @michalk8 for all these great comments!

@marcocuturi marcocuturi merged commit 27b639e into main Sep 13, 2024
8 of 11 checks passed
@marcocuturi marcocuturi deleted the sliced branch September 13, 2024 09:05
Copy link
Collaborator

@Algue-Rythme Algue-Rythme left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

rng = utils.default_prng_key(rng)
dim = x.shape[-1]
proj_m = jax.random.normal(rng, (n_proj, dim))
proj_m /= jnp.linalg.norm(proj_m, axis=1, keepdims=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is nitpicking, but are we sure that this division is safe? In theory the probability of getting the null vector is zero, but in practice I'm not sure of what is happening in the worst case.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants