-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor sinkhorn_divergence #577
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #577 +/- ##
==========================================
- Coverage 87.84% 87.83% -0.01%
==========================================
Files 73 73
Lines 7823 7826 +3
Branches 1127 1127
==========================================
+ Hits 6872 6874 +2
- Misses 798 799 +1
Partials 153 153
|
src/ott/tools/sinkhorn_divergence.py
Outdated
@@ -115,7 +115,8 @@ def sinkhorn_divergence( | |||
geometry. | |||
|
|||
Returns: | |||
Sinkhorn divergence value, three pairs of potentials, three costs. | |||
Sinkhorn divergence value, in addition to | |||
:class:`~ott.tools.sinkhorn_divergence.SinkhornDivergenceOutput` object. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would simplify to (also in the one other place), but we can also keep:
The Sinkhorn divergence and the output object.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks.
i have also made a small change. _sinkhorn_divergence
now only outputs the output
object, it's sinkhorn_divergence
that chooses to expose explicitly the value first in the Tuple, with the output second.
Ensure that
sinkhorn_divergence
function in tools returns a tuple. First element in Tuple (what most users will want) is divergence float value. Second element is detailed output.This new convention will likely trigger
AttributeError: 'tuple' object has no attribute 'divergence'
errors in code that is using the former API, apologies for this!
Also, we may think about featuring the wrappers in
progot.py
L. 410 or appearing also insinkhorn_divergence_test.py
L.432 more prominently, as they follow the more intuitive APIsdiv(x,y)
which our implementation does not allow directly, as we need to pass geometry type.As I believe 99% of users use Sinkhorn divergences with a point cloud geometry, we might want to remove that ambiguity/possibility?