Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce multivariate cdf / quantiles #447

Merged
merged 10 commits into from
Oct 25, 2023
Merged

introduce multivariate cdf / quantiles #447

merged 10 commits into from
Oct 25, 2023

Conversation

marcocuturi
Copy link
Contributor

The idea comes from the great work of Marc Hallin and colleagues

https://arxiv.org/abs/1412.8434

The implementation essentially relies on entropic maps (Pooladian/Niles Weed) to approximate both CDF and Quantiles Monge maps (forward and backward between input measure and reference/uniform measure).

We can expect performance to degrade with dimension + sensitivity w.r.t. epsilon, and this can be observed in tests (hence fairly loose atol values).

@codecov
Copy link

codecov bot commented Oct 25, 2023

Codecov Report

Merging #447 (82f700a) into main (369db8c) will increase coverage by 0.02%.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #447      +/-   ##
==========================================
+ Coverage   90.56%   90.59%   +0.02%     
==========================================
  Files          57       57              
  Lines        6256     6274      +18     
  Branches      884      888       +4     
==========================================
+ Hits         5666     5684      +18     
  Misses        448      448              
  Partials      142      142              
Files Coverage Δ
src/ott/tools/soft_sort.py 96.07% <100.00%> (+0.52%) ⬆️

@michalk8 michalk8 self-requested a review October 25, 2023 11:10
@michalk8 michalk8 added the enhancement New feature or request label Oct 25, 2023
src/ott/tools/soft_sort.py Outdated Show resolved Hide resolved
src/ott/tools/soft_sort.py Outdated Show resolved Hide resolved
src/ott/tools/soft_sort.py Outdated Show resolved Hide resolved
src/ott/tools/soft_sort.py Outdated Show resolved Hide resolved
src/ott/tools/soft_sort.py Outdated Show resolved Hide resolved
src/ott/tools/soft_sort.py Outdated Show resolved Hide resolved
tests/tools/soft_sort_test.py Outdated Show resolved Hide resolved
src/ott/tools/soft_sort.py Outdated Show resolved Hide resolved
tests/tools/soft_sort_test.py Outdated Show resolved Hide resolved
tests/tools/soft_sort_test.py Outdated Show resolved Hide resolved
src/ott/tools/soft_sort.py Outdated Show resolved Hide resolved
src/ott/tools/soft_sort.py Outdated Show resolved Hide resolved
src/ott/tools/soft_sort.py Outdated Show resolved Hide resolved
src/ott/tools/soft_sort.py Outdated Show resolved Hide resolved
src/ott/tools/soft_sort.py Outdated Show resolved Hide resolved
src/ott/tools/soft_sort.py Outdated Show resolved Hide resolved
src/ott/tools/soft_sort.py Outdated Show resolved Hide resolved
@@ -499,31 +500,33 @@ def multivariate_cdf_quantile_maps(
to be uniform by default.
kwargs: keyword arguments passed on to the :func:`~ott.solvers.linear.solve`
function, which solves the OT problem between ``inputs`` and ``targets``
using the Sinkhorn algorithm.
using the :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here i wasn't sure about the reference, because we use solve... that being said, probably not a good idea to use LR on this, because would crash :) so maybe an instance where we should force kwargs to only refer to sinkhorn...

@@ -479,12 +480,12 @@ def multivariate_cdf_quantile_maps(

Args:
inputs: 2D array of ``[n, d]`` vectors.
target_sampler: Callable that takes a ``key`` and ``[m,d]`` shape.
target_sampler: Callable that takes a ``rng`` and ``[m, d]`` shape.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry about this one!!!

@marcocuturi
Copy link
Contributor Author

thanks a lot for the last fixes Michal!

@marcocuturi marcocuturi merged commit 3706511 into main Oct 25, 2023
marcocuturi added a commit that referenced this pull request Oct 25, 2023
* introduce multiv cdf / quantiles

* fix online memory

* incorporate feedback from review

* various fixes

* remove doi

* adding jittability

* adding docs

* Fix docstrings

* Update test

* Fix spellchecker

---------

Co-authored-by: Michal Klein <[email protected]>
michalk8 added a commit that referenced this pull request Oct 25, 2023
* Bump `jax>=0.4`

* Update Docker image

* Change GPU test name

* Fix typo

* Don't pre-allocate memory on GPU

* Update step name

* Fix GPU device number and jax installation

* introduce multivariate cdf / quantiles (#447) (#449)

* introduce multiv cdf / quantiles

* fix online memory

* incorporate feedback from review

* various fixes

* remove doi

* adding jittability

* adding docs

* Fix docstrings

* Update test

* Fix spellchecker

---------

Co-authored-by: Michal Klein <[email protected]>
@marcocuturi marcocuturi deleted the multivq branch October 31, 2023 11:03
michalk8 added a commit that referenced this pull request Jun 27, 2024
* introduce multiv cdf / quantiles

* fix online memory

* incorporate feedback from review

* various fixes

* remove doi

* adding jittability

* adding docs

* Fix docstrings

* Update test

* Fix spellchecker

---------

Co-authored-by: Michal Klein <[email protected]>
michalk8 added a commit that referenced this pull request Jun 27, 2024
* Bump `jax>=0.4`

* Update Docker image

* Change GPU test name

* Fix typo

* Don't pre-allocate memory on GPU

* Update step name

* Fix GPU device number and jax installation

* introduce multivariate cdf / quantiles (#447) (#449)

* introduce multiv cdf / quantiles

* fix online memory

* incorporate feedback from review

* various fixes

* remove doi

* adding jittability

* adding docs

* Fix docstrings

* Update test

* Fix spellchecker

---------

Co-authored-by: Michal Klein <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants