Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add monge gap #361

Merged
merged 44 commits into from
Jul 4, 2023
Merged

Add monge gap #361

merged 44 commits into from
Jul 4, 2023

Conversation

theouscidda6
Copy link
Contributor

This pull request proposes to add the MongeGap class to ott-jax, by creating a new losses.py file in the ott/src/ott/solvers/nn folder. It includes the addition of a tutorial showing how to recode, from scratch, a simple solver to fit a neural OT map using the Monge gap.

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@codecov-commenter
Copy link

codecov-commenter commented May 4, 2023

Codecov Report

Merging #361 (ff9832a) into main (b2b7ebb) will increase coverage by 0.04%.
The diff coverage is 88.88%.

❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #361      +/-   ##
==========================================
+ Coverage   87.95%   88.00%   +0.04%     
==========================================
  Files          52       54       +2     
  Lines        5679     5769      +90     
  Branches      841      857      +16     
==========================================
+ Hits         4995     5077      +82     
- Misses        561      566       +5     
- Partials      123      126       +3     
Impacted Files Coverage Δ
src/ott/geometry/geometry.py 92.69% <ø> (ø)
src/ott/tools/map_estimator.py 87.01% <87.01%> (ø)
src/ott/solvers/nn/losses.py 100.00% <100.00%> (ø)

... and 1 file with indirect coverage changes

@michalk8 michalk8 added the enhancement New feature or request label May 5, 2023
@michalk8 michalk8 self-requested a review May 11, 2023 13:32
@marcocuturi
Copy link
Contributor

hi Théo, can you fix the linter issues?

you can check first the CONTRIBUTING.md file and install pre-commit

Thanks!

Copy link
Contributor

@marcocuturi marcocuturi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fantastic Théo! thanks so much, I am sure this will be very useful.

I would suggest moving the estimation pipeline in nn/solvers, I am sure other people will benefit from it.

src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
from typing import Any, Mapping, Callable
from types import MappingProxyType

class MongeGap:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds like this is missing a simple test.

This could be running the Monge gap on the gradient of a convex function on a small batch of 10 points.

The test should be put in the same location as losses.py but in the test folder, and be losses_test.py.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with Marco, this function needs some tests.

@@ -0,0 +1,734 @@
{
Copy link
Contributor

@marcocuturi marcocuturi May 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels like the first "big" equation (in the begin{equation} env) is not rendering in this NB.

squared-Euclidean (capital letter)

Explain in what sense this provides an alternative to the solver = Sinkhorn(norm_error=1) solver, and, ideally, we might want to run it in this solver too?

When defining the monge gap, mention to cost_fn is not provided.


Reply via ReviewNB

docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
self.sinkhorn_kwargs = sinkhorn_kwargs

def __call__(
self, samples: jnp.ndarray, T: Callable[[jnp.ndarray], jnp.ndarray]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this function should receive the function computing the transport, as it limits what can be passed (e.g., classes implementing __call__). I'd rather you pass the transported samples.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might be worth revisiting this choice.

In fact, I think the notion of Monge gap only has meaning when it is passed with callable (the Monge gap is defined as a function defined on maps in the paper), so I think the initial implementation was better conceptually (i.e. I would pass callable T and source points as arguments, just to instantiate target as being T run on samples.

We can have a function with source and target points passed directly, but this should be called something different (e.g. id_coupling_gap).

from typing import Any, Mapping, Callable
from types import MappingProxyType

class MongeGap:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with Marco, this function needs some tests.

src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
@@ -0,0 +1,734 @@
{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line #163.            if self.plot_fitted_map:

I feel like plotting shouldn't be part of the training loop - it creates a lot of arguments and pollutes the class (+ it's simple enough to be added after the map has been fitted).


Reply via ReviewNB

@@ -0,0 +1,734 @@
{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... instantiated for the squared Euclidean cost.

... that will be bot -> ... that will be both

Please use numbered lists instead of unnumbered lists + (i)/(ii).

The rendering of "... class. For its instantiation, we need ..." seems to have strike-through enabled.

Please add link to the MLP class as {class}....

TODO(michalk8): add a section header for MapEstimator and use it when referring to it.


Reply via ReviewNB

docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
@@ -0,0 +1,734 @@
{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line #35.    dict_monge_gaps = {

I'd slightly prefer to defined 2 variables for the 2 different Monge gaps and later use:

for key, monge_gap in zip(["no_monge_gap", ...], [None, ]):

...

But if you prefer this, let's keep it!


Reply via ReviewNB

@@ -0,0 +1,734 @@
{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line #1.    # neural network and optimizer

Please add above a cell that briefly explains what you're doing in this cell.


Reply via ReviewNB

docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
@@ -0,0 +1,734 @@
{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... instantiated for the squared Euclidean cost.

... that will be bot -> ... that will be both

Please use numbered lists instead of unnumbered lists + (i)/(ii).

The rendering of "... class. For its instantiation, we need ..." seems to have strike-through enabled.

Please add link to the MLP class as {class}....

TODO(michalk8): add a section header for MapEstimator and use it when referring to it.


Reply via ReviewNB

docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
@@ -0,0 +1,734 @@
{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line #35.    dict_monge_gaps = {

I'd slightly prefer to defined 2 variables for the 2 different Monge gaps and later use:

for key, monge_gap in zip(["no_monge_gap", ...], [None, ]):

...

But if you prefer this, let's keep it!


Reply via ReviewNB

docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
docs/tutorials/notebooks/Monge_Gap.ipynb Show resolved Hide resolved
@michalk8
Copy link
Collaborator

Hi @theouscidda6 , any progress on this?

@theouscidda6
Copy link
Contributor Author

Thank you very much for your help and comments! Here is a revised pull request.

docs/references.bib Show resolved Hide resolved
docs/references.bib Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/solvers/nn/losses.py Outdated Show resolved Hide resolved
src/ott/tools/map_estimator.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@michalk8 michalk8 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @theouscidda6 , I think the tests are shaping up nicely, just some minor comments! Further, I think we should test for the following:

tests/solvers/nn/icnn_test.py Outdated Show resolved Hide resolved
tests/solvers/nn/losses_test.py Outdated Show resolved Hide resolved
tests/solvers/nn/losses_test.py Outdated Show resolved Hide resolved
tests/solvers/nn/losses_test.py Outdated Show resolved Hide resolved
tests/solvers/nn/losses_test.py Show resolved Hide resolved
tests/solvers/nn/losses_test.py Outdated Show resolved Hide resolved
tests/solvers/nn/losses_test.py Outdated Show resolved Hide resolved
tests/solvers/nn/losses_test.py Outdated Show resolved Hide resolved
tests/solvers/nn/losses_test.py Outdated Show resolved Hide resolved
tests/solvers/nn/losses_test.py Outdated Show resolved Hide resolved
@marcocuturi
Copy link
Contributor

This LGTM Théo, thanks a lot for all this work!!

@marcocuturi marcocuturi merged commit b47aaaf into ott-jax:main Jul 4, 2023
@michalk8 michalk8 mentioned this pull request Jul 21, 2023
3 tasks
michalk8 added a commit that referenced this pull request Jun 27, 2024
* add monge gap and corresponding tutorial

* add monge gap and corresponding tutorial

* update bib

* revising the pull-request

* revising the pull request

* revising the pull request

* revising the pull request

* fix linter issues

* fix linter issues

* fix linter issues

* fix linter issues

* fix linter issues

* fix linter issues

* fix linter issues

* fix linter issues

* fix linter issues

* fix linter issues

* fix linter issues

* fix linter issues

* fix lint cod issue

* update losses

* update output

* update doc

* update doc

* update doc

* Update docs v3

* Fix docs linter

* Refer to shapes

* Add a skeleton test

* update rng type for icnn test for consistentency

* add monge gap test

* add map_estimator test

* add_new_tests

* add_new_tests

* Clean tests

* Reduce number of tests

* Further reduce number of tests

* Pin numpy

* CHange k-means threshold

* Fix missing fixture

---------

Co-authored-by: Michal Klein <[email protected]>
Co-authored-by: Marco Cuturi <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants