Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor sinkhorn_divergence #577

Merged
merged 2 commits into from
Sep 15, 2024
Merged

refactor sinkhorn_divergence #577

merged 2 commits into from
Sep 15, 2024

Conversation

marcocuturi
Copy link
Contributor

Ensure that sinkhorn_divergence function in tools returns a tuple. First element in Tuple (what most users will want) is divergence float value. Second element is detailed output.

This new convention will likely trigger

AttributeError: 'tuple' object has no attribute 'divergence'

errors in code that is using the former API, apologies for this!

Also, we may think about featuring the wrappers in progot.py L. 410 or appearing also in sinkhorn_divergence_test.py L.432 more prominently, as they follow the more intuitive API sdiv(x,y) which our implementation does not allow directly, as we need to pass geometry type.

As I believe 99% of users use Sinkhorn divergences with a point cloud geometry, we might want to remove that ambiguity/possibility?

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link

codecov bot commented Sep 13, 2024

Codecov Report

Attention: Patch coverage is 66.66667% with 2 lines in your changes missing coverage. Please review.

Project coverage is 87.83%. Comparing base (27b639e) to head (e7b8bfd).
Report is 30 commits behind head on main.

Files with missing lines Patch % Lines
src/ott/tools/progot.py 0.00% 2 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #577      +/-   ##
==========================================
- Coverage   87.84%   87.83%   -0.01%     
==========================================
  Files          73       73              
  Lines        7823     7826       +3     
  Branches     1127     1127              
==========================================
+ Hits         6872     6874       +2     
- Misses        798      799       +1     
  Partials      153      153              
Files with missing lines Coverage Δ
src/ott/tools/sinkhorn_divergence.py 91.86% <100.00%> (+0.19%) ⬆️
src/ott/tools/progot.py 29.16% <0.00%> (-0.21%) ⬇️

src/ott/tools/progot.py Outdated Show resolved Hide resolved
src/ott/tools/progot.py Outdated Show resolved Hide resolved
@@ -115,7 +115,8 @@ def sinkhorn_divergence(
geometry.

Returns:
Sinkhorn divergence value, three pairs of potentials, three costs.
Sinkhorn divergence value, in addition to
:class:`~ott.tools.sinkhorn_divergence.SinkhornDivergenceOutput` object.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would simplify to (also in the one other place), but we can also keep:
The Sinkhorn divergence and the output object.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks.

i have also made a small change. _sinkhorn_divergence now only outputs the output object, it's sinkhorn_divergence that chooses to expose explicitly the value first in the Tuple, with the output second.

@michalk8 michalk8 added the enhancement New feature or request label Sep 13, 2024
@marcocuturi marcocuturi merged commit aa33bd9 into main Sep 15, 2024
12 checks passed
@marcocuturi marcocuturi deleted the sinkdiv branch September 15, 2024 18:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants