-
Notifications
You must be signed in to change notification settings - Fork 502
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Implementation of FUGW and UCOOT #677
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #677 +/- ##
==========================================
+ Coverage 96.88% 96.99% +0.10%
==========================================
Files 93 96 +3
Lines 18166 19117 +951
==========================================
+ Hits 17600 18542 +942
- Misses 566 575 +9 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot @6Ulm for this impressive work.
Follows some remarks and comments to conclude the PR.
You also forgot to update ot.gromov.__init__.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your last updates, you are almost there ! follows small comments on the documentation to correct then we'll be able to merge with the main branch :)
Types of changes
This PR is dedicated to the implementation of
1.Fused Unbalanced GW (or more correctly, its lower bound)
2. (Fused) Unbalanced COOT.
Since their structures, it is enough to write a common template, then write a wrapper for each divergence.
More precisely, we create a method called
fused_unbalanced_cross_spaces_divergence
, in whichreg_type="independent"
corresponds to (Fused) UCOOT. This yieldsunbalanced_co_optimal_transport
method.reg_type="joint"
corresponds to FUGW. This yieldsfused_unbalanced_gromov_wasserstein
method.We also allow for unregularized approximation of FUGW and UCOOT, i.e.$\varepsilon = 0$ , thanks to the Majorization-Minization
ot.unbalanced.mm_unbalanced
andot.unbalanced.lbfgsb_unbalanced
L-BFGS-B methods.This implementation also allows for$2$ types of marginal penalization: Kullback-Leibler divergence and squared L2 norm. We also allow the cost to be sub-differentiable w.r.t the input matrices and reference distributions. This is implemented in
unbalanced_co_optimal_transport2
andfused_unbalanced_gromov_wasserstein
methods.Motivation and context / Related issue
How has this been tested (if it applies)
PR checklist