-
Notifications
You must be signed in to change notification settings - Fork 84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
introduce multivariate cdf / quantiles #447
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #447 +/- ##
==========================================
+ Coverage 90.56% 90.59% +0.02%
==========================================
Files 57 57
Lines 6256 6274 +18
Branches 884 888 +4
==========================================
+ Hits 5666 5684 +18
Misses 448 448
Partials 142 142
|
@@ -499,31 +500,33 @@ def multivariate_cdf_quantile_maps( | |||
to be uniform by default. | |||
kwargs: keyword arguments passed on to the :func:`~ott.solvers.linear.solve` | |||
function, which solves the OT problem between ``inputs`` and ``targets`` | |||
using the Sinkhorn algorithm. | |||
using the :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here i wasn't sure about the reference, because we use solve
... that being said, probably not a good idea to use LR
on this, because would crash :) so maybe an instance where we should force kwargs to only refer to sinkhorn...
@@ -479,12 +480,12 @@ def multivariate_cdf_quantile_maps( | |||
|
|||
Args: | |||
inputs: 2D array of ``[n, d]`` vectors. | |||
target_sampler: Callable that takes a ``key`` and ``[m,d]`` shape. | |||
target_sampler: Callable that takes a ``rng`` and ``[m, d]`` shape. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry about this one!!!
thanks a lot for the last fixes Michal! |
* introduce multiv cdf / quantiles * fix online memory * incorporate feedback from review * various fixes * remove doi * adding jittability * adding docs * Fix docstrings * Update test * Fix spellchecker --------- Co-authored-by: Michal Klein <[email protected]>
* Bump `jax>=0.4` * Update Docker image * Change GPU test name * Fix typo * Don't pre-allocate memory on GPU * Update step name * Fix GPU device number and jax installation * introduce multivariate cdf / quantiles (#447) (#449) * introduce multiv cdf / quantiles * fix online memory * incorporate feedback from review * various fixes * remove doi * adding jittability * adding docs * Fix docstrings * Update test * Fix spellchecker --------- Co-authored-by: Michal Klein <[email protected]>
* introduce multiv cdf / quantiles * fix online memory * incorporate feedback from review * various fixes * remove doi * adding jittability * adding docs * Fix docstrings * Update test * Fix spellchecker --------- Co-authored-by: Michal Klein <[email protected]>
* Bump `jax>=0.4` * Update Docker image * Change GPU test name * Fix typo * Don't pre-allocate memory on GPU * Update step name * Fix GPU device number and jax installation * introduce multivariate cdf / quantiles (#447) (#449) * introduce multiv cdf / quantiles * fix online memory * incorporate feedback from review * various fixes * remove doi * adding jittability * adding docs * Fix docstrings * Update test * Fix spellchecker --------- Co-authored-by: Michal Klein <[email protected]>
The idea comes from the great work of Marc Hallin and colleagues
https://arxiv.org/abs/1412.8434
The implementation essentially relies on entropic maps (Pooladian/Niles Weed) to approximate both CDF and Quantiles Monge maps (forward and backward between input measure and reference/uniform measure).
We can expect performance to degrade with dimension + sensitivity w.r.t. epsilon, and this can be observed in tests (hence fairly loose
atol
values).