Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/geodesic sinkhorn #457

Merged
merged 48 commits into from
Nov 28, 2023
Merged

Conversation

guillaumehu
Copy link
Collaborator

@guillaumehu guillaumehu commented Nov 8, 2023

Implementation in: https://arxiv.org/abs/2211.00805

previous PR by @diegoabt can be found in #425

diegoabt and others added 30 commits September 5, 2023 13:43
This module contains the implementation of the Geodesic Sinkhorn algorithm [1].

[1] Huguet, G., Tong, A., Zapatero, M. R., Wolf, G., & Krishnaswamy, S. (2022). Geodesic Sinkhorn: optimal transport for high-dimensional datasets. arXiv preprint arXiv:2211.00805.
@michalk8 michalk8 added the enhancement New feature or request label Nov 8, 2023
@michalk8 michalk8 mentioned this pull request Nov 8, 2023
Copy link

codecov bot commented Nov 8, 2023

Codecov Report

Merging #457 (b7d1df3) into main (f5ade02) will increase coverage by 8.63%.
Report is 10 commits behind head on main.
The diff coverage is 94.84%.

❗ Current head b7d1df3 differs from pull request most recent head cf17ce4. Consider uploading reports for the commit cf17ce4 to get more accurate results

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #457      +/-   ##
==========================================
+ Coverage   81.97%   90.60%   +8.63%     
==========================================
  Files          59       60       +1     
  Lines        6345     6464     +119     
  Branches      615      913     +298     
==========================================
+ Hits         5201     5857     +656     
+ Misses       1024      465     -559     
- Partials      120      142      +22     
Files Coverage Δ
src/ott/geometry/geodesic.py 94.84% <94.84%> (ø)

... and 22 files with indirect coverage changes

@michalk8 michalk8 marked this pull request as ready for review November 8, 2023 21:52
@michalk8 michalk8 marked this pull request as draft November 8, 2023 21:52
@guillaumehu guillaumehu marked this pull request as ready for review November 13, 2023 11:36
@guillaumehu
Copy link
Collaborator Author

Hey @michalk8 let me know if you have any suggestions on the PR. I am not sure what is the best way to deal with the dtype in _scipy_compute_chebychev_coeff_all, and I think we could remove the laplacian as an instance attribute since we only use the scaled_laplacian.

docs/references.bib Outdated Show resolved Hide resolved
src/ott/geometry/geodesic.py Show resolved Hide resolved
src/ott/geometry/geodesic.py Outdated Show resolved Hide resolved
src/ott/geometry/geodesic.py Outdated Show resolved Hide resolved
src/ott/geometry/geodesic.py Outdated Show resolved Hide resolved
tests/geometry/geo_test.py Outdated Show resolved Hide resolved
src/ott/geometry/geodesic.py Outdated Show resolved Hide resolved
self.eigval = eigval
self.chebyshev_coeffs = chebyshev_coeffs
self.t = t
self.order = order
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should remove this, as it's not being used anywhere in the class.
The docs will need to be adapted a bit as well.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only removed self.order, because self.t is used for the attribute cost_matrix, and self.eigval, self.chebyshev_coeffs are used inexpm_multiply. I renamed phi to eigval in expm_multiply to make it clearer.

src/ott/geometry/geodesic.py Outdated Show resolved Hide resolved
tests/geometry/geo_test.py Show resolved Hide resolved
@guillaumehu guillaumehu requested a review from michalk8 November 28, 2023 14:24
Copy link
Collaborator

@michalk8 michalk8 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot @guillaumehu , lgtm!

@michalk8 michalk8 merged commit 75c6b3d into ott-jax:main Nov 28, 2023
9 of 10 checks passed
michalk8 pushed a commit that referenced this pull request Jun 27, 2024
* Create `geodesic` module 

This module contains the implementation of the Geodesic Sinkhorn algorithm [1].

[1] Huguet, G., Tong, A., Zapatero, M. R., Wolf, G., & Krishnaswamy, S. (2022). Geodesic Sinkhorn: optimal transport for high-dimensional datasets. arXiv preprint arXiv:2211.00805.

* Lint code

* Lint code

* Add Geodesic kernel citation to `docs/references.bib`

* Remove unused functions; update docstrings; remove `n_steps`; remove forced symm

* Remove hardcoded random key at `compute_largest_eigenvalue`

* Fix docstrings of functions inside `apply_kernel`

* Add chebyshev coeff computation to `from_graph`

* Change `jax.exp.sparse` import

* Change `tree_flatten` outputs

* Change input of `lobpcg_std` to be a sparsified product

* Change definition of cost from kernel

* Add `Geodesic` to `docs/geometry`

* Remove `np.max` from max eigenval computation

* Add `default_prng_key` to eigenval computation

* Add `safe_log` to cost matrix computation

* Change to `jesp.BCOO` at chebyshev approx

* Restructure `from_graph`; coeffs are computed earlier now

* fn outside of the class

* mv fn & process L once

* wip tests geo

* jax pure_callback & new fn cheb

* symmetric kernel & wip tests

* fix formatting ruff

* wrap eigenval fn

* rm num_scheme & _scale

* type dense or sparse

* expm with scan & fix hardcode dty & lobpcg iter

* test compare with BE and CN

* rm lap_mn_id & test sink spd & fix tree_flatten

* default and type of Cheb. co. & docstrings

* dtype in purecallback depending on Lap

* rm laplacian & update test

* simpler wrapper for `ive`

* lint-docs spell check

* fix mistake rm setup

* typo & rm fn scale lap &

* rm t order init & typing

* use rng util

* t for cost & test with ground truth

* rm order & change `phi` to `eigval` & type hints

* differentiability test

* condition to symmetrize kernel

* fix indentation

---------

Co-authored-by: Diego Baptista Theuerkauf <[email protected]>
Co-authored-by: diegoabt <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants