introduce new norm operator that can differentiate through lambda x: norm(x-x) + lighter tests across the board (#411)
Conversation
also of interest to @theouscidda6 as potentially useful in Monge gap.
# Test div of x to itself close to 0.
# Check differentiability of Sinkhorn divergence works, without NaN's.
grad = jax.grad(lambda x: div(x).divergence)(x)
assert jnp.all(jnp.logical_not(jnp.isnan(grad)))
Use np.testing...
not sure I can find an adequate test in there, https://numpy.org/doc/stable/reference/routines.testing.html
You can use np.testing.assert_array_equal(jnp.isnan(grad), False).
thanks Michal!
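For context, here is a minimal self-contained sketch of the two assertion styles on a toy gradient; the actual test checks the gradient of the Sinkhorn divergence `div` from the fixture above, and the toy function here is only an illustration:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Toy stand-in for the divergence gradient computed in the test above.
grad = jax.grad(lambda x: jnp.sum(x ** 2))(jnp.ones(3))

# Original style: a bare assert, with no diagnostic output on failure.
assert jnp.all(jnp.logical_not(jnp.isnan(grad)))

# Suggested style: np.testing reports which entries are NaN if it fails.
np.testing.assert_array_equal(jnp.isnan(grad), False)
```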
Codecov Report
Additional details and impacted files

@@            Coverage Diff             @@
##             main     #411      +/-   ##
==========================================
- Coverage   91.51%   91.44%    -0.07%
==========================================
  Files          54       54
  Lines        5891     5902      +11
  Branches      856      860       +4
==========================================
+ Hits         5391     5397       +6
- Misses        364      370       +6
+ Partials      136      135       -1
…x: norm(x-x)` + lighter tests across the board (#411)

* introduce new norm operator
* add
* resolve name conflict in (math).utils_test.py
* docs
* vjp -> jvp
* fix
* add axis in test
* type error
* speed up some tests
* fix
* fix
* fix apply_jacobian test
* fix mem size
* batch size
* yet another mem fix
* minor fixes + add in docs
* minor docs fixes
* docs
Computing the gradient of Sinkhorn divergences using the Euclidean cost results in NaN values, because of the reliance on jnp.linalg.norm and, in particular, differentiation of the distance of a point against itself, e.g. norm(x-x) in the diagonal of a symmetric cost matrix. The fact that such a gradient is (and should be, in general) NaN is well documented, see e.g. jax-ml/jax#6484. However, in the context of OT this poses problems, since it is safe to ignore these contributions and therefore treat them as having 0 gradient.

This PR introduces a new norm function that does not blow up with a NaN at 0. As a result, it should now be possible to differentiate through a sinkhorn_divergence with a more elaborate cost without producing NaN's.

Also, in order to speed things up in CI, prune out some tests.
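To make the failure mode concrete: the first gradient below is NaN because the backward pass of the Euclidean norm divides by norm(x - x) == 0, while the second uses the common double-where workaround to obtain a 0 gradient at 0. This is only a sketch of the idea; `safe_norm` is a hypothetical name and not necessarily the implementation added in this PR.

```python
import jax
import jax.numpy as jnp

# Plain Euclidean norm: its gradient at 0 is NaN (0 * inf in the backward pass).
nan_grad = jax.grad(lambda x: jnp.linalg.norm(x - x))(jnp.ones(3))
print(nan_grad)  # [nan nan nan]


def safe_norm(x, axis=None):
  """Euclidean norm whose gradient is 0 (rather than NaN) at x == 0."""
  sq = jnp.sum(x ** 2, axis=axis)
  # Inner where: keep sqrt away from 0 so its derivative stays finite.
  # Outer where: still return exactly 0 for a zero input.
  return jnp.where(sq == 0.0, 0.0, jnp.sqrt(jnp.where(sq == 0.0, 1.0, sq)))


zero_grad = jax.grad(lambda x: safe_norm(x - x))(jnp.ones(3))
print(zero_grad)  # [0. 0. 0.]
```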