implements sliced W #576
Conversation
Codecov Report: all modified and coverable lines are covered by tests ✅

@@           Coverage Diff            @@
##             main     #576      +/-   ##
==========================================
+ Coverage   87.78%   87.84%   +0.05%
==========================================
  Files          72       73       +1
  Lines        7801     7823      +22
  Branches     1126     1127       +1
==========================================
+ Hits         6848     6872      +24
+ Misses        802      798       -4
- Partials      151      153       +2
src/ott/tools/sliced.py (outdated)

@@ -88,10 +88,15 @@ def sliced_wasserstein(
  Returns:
    The sliced Wasserstein distance with the corresponding output object.
  """
-  if proj_fn is None:
+  if projector is None:
I think this is a bit complex, no? It just saves the user `lambda arr, **_: arr @ projector.T` when passing `proj_fn` as a matrix product, at the cost of this ambiguity.

My hunch is that most users will likely want to change the way directions are sampled in the projection, rather than pass a matrix (that was the idea of passing a random generator first), or, alternatively, define a proper `proj_fn` feature extractor (that's also why I had added the "features" word above: when the `proj_fn` becomes more complex, one should rather see it as a feature extractor).
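To make the equivalence in the comment above concrete, here is a hedged sketch: passing a `(n_proj, dim)` matrix `projector` buys nothing beyond the one-line callable. Names and shapes are illustrative, not taken from the PR.

```python
import jax
import jax.numpy as jnp

# Illustrative data: 5 points in dimension 3, 4 projection directions.
x = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
projector = jax.random.normal(jax.random.PRNGKey(1), (4, 3))

# Matrix-based projection ...
proj_a = x @ projector.T

# ... is exactly the lambda the comment mentions:
proj_fn = lambda arr, **_: arr @ projector.T
proj_b = proj_fn(x)
```

Both paths produce the same `(5, 4)` array of projected coordinates, which is the reviewer's point: the matrix argument only duplicates what the callable already expresses.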
Thanks a lot @michalk8 for all these great comments!
LGTM
rng = utils.default_prng_key(rng)
dim = x.shape[-1]
proj_m = jax.random.normal(rng, (n_proj, dim))
proj_m /= jnp.linalg.norm(proj_m, axis=1, keepdims=True)
This is nitpicking, but are we sure this division is safe? In theory the probability of drawing the null vector is zero, but in practice I'm not sure what happens in the worst case.
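One way to address the concern above is to guard the denominator so that a null row (measure-zero in theory, but conceivable in finite precision) cannot produce NaNs. This is a hypothetical sketch mirroring the snippet's variable names, not code from the PR:

```python
import jax
import jax.numpy as jnp

proj_m = jax.random.normal(jax.random.PRNGKey(0), (8, 3))
proj_m = proj_m.at[0].set(0.0)  # force the worst case: a null direction

norms = jnp.linalg.norm(proj_m, axis=1, keepdims=True)
# Divide by 1.0 wherever the norm is zero: 0 / 1 = 0, so no NaN appears
# and non-degenerate rows are still normalized to unit length.
proj_m = proj_m / jnp.where(norms > 0.0, norms, 1.0)
```

The unguarded `proj_m /= jnp.linalg.norm(...)` would turn the null row into `nan`s, which then propagate through every projected coordinate.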
Fairly primitive implementation of the sliced Wasserstein distance; essentially a wrapper on top of the ott.solvers.linear.solve_univariate function.
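For readers unfamiliar with the construction the PR wraps, here is a self-contained sketch of the sliced-Wasserstein idea: project both point clouds onto random unit directions, use the closed-form 1-D optimal transport (sorting) per direction, and average. This is illustrative code under simplifying assumptions (equal cloud sizes, uniform weights), not the library implementation, which delegates the 1-D step to `ott.solvers.linear.solve_univariate`.

```python
import jax
import jax.numpy as jnp

def sliced_w2_sketch(rng, x, y, n_proj=16):
    # Sample random unit directions, as in the snippet above.
    dim = x.shape[-1]
    proj_m = jax.random.normal(rng, (n_proj, dim))
    proj_m /= jnp.linalg.norm(proj_m, axis=1, keepdims=True)
    # 1-D OT between equal-size uniform clouds = match sorted samples.
    x_s = jnp.sort(x @ proj_m.T, axis=0)  # (n, n_proj)
    y_s = jnp.sort(y @ proj_m.T, axis=0)
    # Average squared quantile differences over points and directions.
    return jnp.mean((x_s - y_s) ** 2)

rng = jax.random.PRNGKey(0)
x = jax.random.normal(jax.random.PRNGKey(1), (32, 3))
d_same = sliced_w2_sketch(rng, x, x)         # a cloud vs itself: zero
d_shift = sliced_w2_sketch(rng, x, x + 2.0)  # shifted cloud: positive
```

The sort-based 1-D solve is what makes the sliced variant cheap relative to a full OT solve in dimension `dim`.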