-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add monge gap #361
Add monge gap #361
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Codecov Report
❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more. Additional details and impacted files@@ Coverage Diff @@
## main #361 +/- ##
==========================================
+ Coverage 87.95% 88.00% +0.04%
==========================================
Files 52 54 +2
Lines 5679 5769 +90
Branches 841 857 +16
==========================================
+ Hits 4995 5077 +82
- Misses 561 566 +5
- Partials 123 126 +3
|
hi Théo, can you fix the linter issues? you can check first the CONTRIBUTING.md file and install Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is fantastic Théo! thanks so much, I am sure this will be very useful.
I would suggest moving the estimation pipeline in nn/solvers
, I am sure other people will benefit from it.
src/ott/solvers/nn/losses.py
Outdated
from typing import Any, Mapping, Callable | ||
from types import MappingProxyType | ||
|
||
class MongeGap: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds like this is missing a simple test.
This could be running the Monge gap on the gradient of a convex function on a small batch of 10 points.
The test should be put in the same location as losses.py
but in the test
folder, and be losses_test.py
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree with Marco, this function needs some tests.
@@ -0,0 +1,734 @@ | |||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feels like the first "big" equation (in the begin{equation} env) is not rendering in this NB.
squared-Euclidean (capital letter)
Explain in what sense this provides an alternative to the solver = Sinkhorn(norm_error=1)
solver, and, ideally, we might want to run it in this solver too?
When defining the monge gap, mention to cost_fn
is not provided.
Reply via ReviewNB
src/ott/solvers/nn/losses.py
Outdated
self.sinkhorn_kwargs = sinkhorn_kwargs | ||
|
||
def __call__( | ||
self, samples: jnp.ndarray, T: Callable[[jnp.ndarray], jnp.ndarray] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this function should receive the function computing the transport, as it limits what can be passed (e.g., classes implementing __call__
). I'd rather you pass the transported samples.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it might be worth revisiting this choice.
In fact, I think the notion of Monge gap only has meaning when it is passed with callable (the Monge gap is defined as a function defined on maps in the paper), so I think the initial implementation was better conceptually (i.e. I would pass callable T
and source
points as arguments, just to instantiate target
as being T
run on samples.
We can have a function with source
and target
points passed directly, but this should be called something different (e.g. id_coupling_gap
).
src/ott/solvers/nn/losses.py
Outdated
from typing import Any, Mapping, Callable | ||
from types import MappingProxyType | ||
|
||
class MongeGap: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree with Marco, this function needs some tests.
@@ -0,0 +1,734 @@ | |||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Line #163. if self.plot_fitted_map:
I feel like plotting shouldn't be part of the training loop - it creates a lot of arguments and pollutes the class (+ it's simple enough to be added after the map has been fitted).
Reply via ReviewNB
@@ -0,0 +1,734 @@ | |||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
... instantiated for the squared Euclidean cost.
... that will be bot -> ... that will be both
Please use numbered lists instead of unnumbered lists + (i)/(ii).
The rendering of "... class. For its instantiation, we need ..." seems to have strike-through enabled.
Please add link to the MLP class as {class}...
.
TODO(michalk8): add a section header for MapEstimator
and use it when referring to it.
Reply via ReviewNB
@@ -0,0 +1,734 @@ | |||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Line #35. dict_monge_gaps = {
I'd slightly prefer to defined 2 variables for the 2 different Monge gaps and later use:
for key, monge_gap in zip(["no_monge_gap", ...], [None, ]):
...
But if you prefer this, let's keep it!
Reply via ReviewNB
@@ -0,0 +1,734 @@ | |||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Line #1. # neural network and optimizer
Please add above a cell that briefly explains what you're doing in this cell.
Reply via ReviewNB
@@ -0,0 +1,734 @@ | |||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
... instantiated for the squared Euclidean cost.
... that will be bot -> ... that will be both
Please use numbered lists instead of unnumbered lists + (i)/(ii).
The rendering of "... class. For its instantiation, we need ..." seems to have strike-through enabled.
Please add link to the MLP class as {class}...
.
TODO(michalk8): add a section header for MapEstimator
and use it when referring to it.
Reply via ReviewNB
@@ -0,0 +1,734 @@ | |||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Line #35. dict_monge_gaps = {
I'd slightly prefer to defined 2 variables for the 2 different Monge gaps and later use:
for key, monge_gap in zip(["no_monge_gap", ...], [None, ]):
...
But if you prefer this, let's keep it!
Reply via ReviewNB
Hi @theouscidda6 , any progress on this? |
Thank you very much for your help and comments! Here is a revised pull request. |
make branch up to date
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @theouscidda6 , I think the tests are shaping up nicely, just some minor comments! Further, I think we should test for the following:
- case when the source/target is the same
- differentiating through the MongeGap
- this comment by @marcocuturi : Add monge gap #361 (comment)
This LGTM Théo, thanks a lot for all this work!! |
* add monge gap and corresponding tutorial * add monge gap and corresponding tutorial * update bib * revising the pull-request * revising the pull request * revising the pull request * revising the pull request * fix linter issues * fix linter issues * fix linter issues * fix linter issues * fix linter issues * fix linter issues * fix linter issues * fix linter issues * fix linter issues * fix linter issues * fix linter issues * fix linter issues * fix lint cod issue * update losses * update output * update doc * update doc * update doc * Update docs v3 * Fix docs linter * Refer to shapes * Add a skeleton test * update rng type for icnn test for consistentency * add monge gap test * add map_estimator test * add_new_tests * add_new_tests * Clean tests * Reduce number of tests * Further reduce number of tests * Pin numpy * CHange k-means threshold * Fix missing fixture --------- Co-authored-by: Michal Klein <[email protected]> Co-authored-by: Marco Cuturi <[email protected]>
This pull request proposes to add the
MongeGap
class toott-jax
, by creating a newlosses.py
file in theott/src/ott/solvers/nn
folder. It includes the addition of a tutorial showing how to recode, from scratch, a simple solver to fit a neural OT map using the Monge gap.