
introduce new norm operator that can differentiate through lambda x: norm(x-x) + lighter tests across the board #411

Merged: 18 commits merged into main from new_norm on Aug 15, 2023

Conversation

@marcocuturi (Contributor) commented Aug 14, 2023:

Computing the gradient of Sinkhorn divergences with the Euclidean cost results in NaN values because of the reliance on jnp.linalg.norm, and in particular the differentiation of the distance of a point to itself, e.g. norm(x - x) on the diagonal of a symmetric cost matrix.

The fact that such a gradient is (and should be, in general) NaN is well documented; see e.g. jax-ml/jax#6484.

However, in the context of OT this NaN is unnecessary: these contributions can safely be ignored, and therefore treated as having a zero gradient.

This PR introduces a new norm function whose gradient does not blow up with a NaN at 0.

import jax
import jax.numpy as jnp
import ott.math.utils

# Compare the NaN-safe norm introduced in this PR with jnp.linalg.norm.
f = lambda x: ott.math.utils.norm(x - x)
g = lambda x: jnp.linalg.norm(x - x)
x = jnp.array([1.2, 3.2, 4.1])
print(jax.grad(f)(x))
print(jax.grad(g)(x))

results in

[0. 0. 0.]
[nan nan nan]
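
For intuition, below is a minimal sketch of one way to build such a NaN-safe norm with jax.custom_jvp, masking the 0/0 that appears in the derivative at 0; safe_norm is a hypothetical name, and this is not necessarily the implementation added in this PR.

import jax
import jax.numpy as jnp

@jax.custom_jvp
def safe_norm(x):
  # Forward pass: the usual Euclidean norm.
  return jnp.linalg.norm(x)

@safe_norm.defjvp
def _safe_norm_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  norm = jnp.linalg.norm(x)
  # d||x|| = <x, x_dot> / ||x||; replace the denominator by 1 where the
  # norm is 0 so the masked branch never evaluates 0 / 0.
  denom = jnp.where(norm > 0.0, norm, 1.0)
  return norm, jnp.where(norm > 0.0, jnp.vdot(x, x_dot) / denom, 0.0)

x = jnp.array([1.2, 3.2, 4.1])
print(jax.grad(lambda x: safe_norm(x - x))(x))  # [0. 0. 0.]

Masking both the output and the denominator ensures that neither the forward tangent nor its reverse-mode transpose ever divides by zero.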

As a result, it should now be possible to differentiate through a sinkhorn_divergence with a more elaborate cost without producing NaNs.
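
To illustrate, here is a hedged sketch of such a gradient check, assuming ott's sinkhorn_divergence and PointCloud APIs with the (non-squared) Euclidean cost; the point clouds and epsilon below are made up for illustration and do not come from the PR.

import jax
import jax.numpy as jnp
from ott.geometry import costs, pointcloud
from ott.tools import sinkhorn_divergence

x = jax.random.normal(jax.random.PRNGKey(0), (13, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (17, 2)) + 1.0

def div(x):
  # Sinkhorn divergence between point clouds x and y with the Euclidean
  # cost, whose symmetric terms involve norm(. - .) on the diagonal.
  out = sinkhorn_divergence.sinkhorn_divergence(
      pointcloud.PointCloud, x, y, cost_fn=costs.Euclidean(), epsilon=1e-1)
  return out.divergence

grad = jax.grad(div)(x)
print(jnp.any(jnp.isnan(grad)))  # expected: False with the NaN-safe norm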

Also, to speed up CI, some tests are pruned.

@marcocuturi changed the title from "introduce new norm operator" to "introduce new norm operator that can differentiate through lambda x: norm(x-x)" on Aug 14, 2023
@marcocuturi requested a review from michalk8 on August 14, 2023 14:51
@marcocuturi (Contributor, Author):

Also of interest to @theouscidda6, as potentially useful in the Monge gap.

Resolved review threads on src/ott/math/utils.py, tests/math/math_utils_test.py, and tests/tools/sinkhorn_divergence_test.py, including the following exchange on this test snippet:
# Test div of x to itself close to 0.
# Check differentiability of Sinkhorn divergence works, without NaN's.
grad = jax.grad(lambda x: div(x).divergence)(x)
assert jnp.all(jnp.logical_not(jnp.isnan(grad)))
Collaborator:
Use np.testing....

@marcocuturi (Contributor, Author):
Not sure I can find the adequate test there: https://numpy.org/doc/stable/reference/routines.testing.html

Collaborator:
You can use np.testing.assert_array_equal(jnp.isnan(grad), False).
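
Applied to the snippet above (assuming the same div and x as in the test), the suggestion reads:

import numpy as np

grad = jax.grad(lambda x: div(x).divergence)(x)
np.testing.assert_array_equal(jnp.isnan(grad), False)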

@marcocuturi (Contributor, Author):

thanks Michal!

codecov bot commented Aug 15, 2023:

Codecov Report

Merging #411 (c9f5780) into main (f275dc4) will decrease coverage by 0.07%.
The diff coverage is 100.00%.

❗ Current head c9f5780 differs from pull request most recent head 704c98a. Consider uploading reports for the commit 704c98a to get more accurate results


@@            Coverage Diff             @@
##             main     #411      +/-   ##
==========================================
- Coverage   91.51%   91.44%   -0.07%     
==========================================
  Files          54       54              
  Lines        5891     5902      +11     
  Branches      856      860       +4     
==========================================
+ Hits         5391     5397       +6     
- Misses        364      370       +6     
+ Partials      136      135       -1     
Files Changed Coverage Δ
src/ott/geometry/costs.py 92.83% <100.00%> (ø)
src/ott/math/utils.py 94.54% <100.00%> (+1.36%) ⬆️

... and 3 files with indirect coverage changes

@marcocuturi changed the title from "introduce new norm operator that can differentiate through lambda x: norm(x-x)" to "introduce new norm operator that can differentiate through lambda x: norm(x-x) + lighter tests across the board" on Aug 15, 2023
@marcocuturi requested a review from michalk8 on August 15, 2023 08:40
@marcocuturi merged commit 12e78a7 into main on Aug 15, 2023
@marcocuturi deleted the new_norm branch on August 15, 2023 14:36
michalk8 pushed a commit that referenced this pull request Jun 27, 2024
…x: norm(x-x)` + lighter tests across the board (#411)

* introduce new norm operator
* add
* resolve name conflict in (math).utils_test.py
* docs
* vjp -> jvp
* fix
* add axis in test
* type error
* speed up some tests
* fix
* fix
* fix apply_jacobian test
* fix mem size
* batch size
* yet another mem fix
* minor fixes + add in docs
* minor docs fixes
* docs