Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/h legendre tests #529

Merged
merged 8 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/ott/math/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,11 @@ def logsumexp_jvp(axis, keepdims, return_sign, primals, tangents):
if return_sign:
lse, sign = lse
lse = jnp.where(jnp.isfinite(lse), lse, 0.0)
centered_exp = jnp.exp(mat - jnp.expand_dims(lse, axis=axis))

if axis is not None:
centered_exp = jnp.exp(mat - jnp.expand_dims(lse, axis=axis))
else:
centered_exp = jnp.exp(mat - lse)

if b is None:
res = jnp.sum(centered_exp * tan_mat, axis=axis, keepdims=keepdims)
Expand Down
63 changes: 63 additions & 0 deletions tests/geometry/costs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np

from ott.geometry import costs, pointcloud
from ott.math import utils as mu
from ott.solvers import linear

try:
Expand Down Expand Up @@ -128,6 +129,45 @@ def test_bures(self, rng: jax.Array):
np.testing.assert_equal(diffs.shape[0], max_iterations // inner_iterations)


class TestTIRegCost:

@pytest.mark.parametrize(
"cost_fn", [
costs.SqPNorm(p=1.0),
costs.SqPNorm(2.4),
costs.PNormP(p=1.1),
costs.PNormP(1.3),
costs.SqEuclidean()
]
)
def test_h_legendre(self, rng: jax.Array, cost_fn: costs.TICost):
x = jax.random.normal(rng, (15, 3))
h_transform = cost_fn.h_transform(mu.logsumexp)
h_transform = jax.jit(jax.vmap(jax.grad(h_transform)))

np.testing.assert_array_equal(jnp.isfinite(h_transform(x)), True)

@pytest.mark.parametrize("ridge", [1e-12, 1e-6])
def test_h_legendre_sqeucl(self, rng: jax.Array, ridge: float):
n, d = 12, 4
rngs = jax.random.split(rng, 2)
u = jnp.abs(jax.random.uniform(rngs[0], (d,)))
x = jax.random.normal(rngs[1], (n, d))

sqeucl = costs.SqEuclidean()
el_l2 = costs.ElasticL2(scaling_reg=0.0)

h_concave = lambda z: 0.5 * (-sqeucl.h(z) + jnp.dot(z, u))
h_concave_half = lambda z: -sqeucl.h(z) + jnp.dot(z, u)

pred = jax.jit(
jax.vmap(jax.grad(sqeucl.h_transform(h_concave, ridge=ridge)))
)
gt = jax.jit(jax.vmap(jax.grad(el_l2.h_transform(h_concave_half))))

np.testing.assert_allclose(pred(x), gt(x), rtol=1e-5, atol=1e-5)


@pytest.mark.fast()
class TestRegTICost:

Expand Down Expand Up @@ -216,6 +256,29 @@ def test_stronger_regularization_increases_sparsity(
for fwd in [False, True]:
np.testing.assert_array_equal(np.diff(sparsity[fwd]) > 0.0, True)

@pytest.mark.parametrize("d", [5, 10])
def test_h_legendre_elastic_l2(self, rng: jax.Array, d: int):
n, d = 13, d
rngs = jax.random.split(rng, 2)
x = jax.random.normal(rngs[0], (n, d))
u = jax.random.normal(rngs[1], (d,))

elastic_l2 = costs.ElasticL2(scaling_reg=0.0)
p_norm_p = costs.PNormP(p=2)

concave_fn = lambda z: -elastic_l2.h(z) + jnp.dot(z, u)

p_grad_h = jax.jit(
jax.vmap(jax.grad(p_norm_p.h_transform(concave_fn, tol=1e-5)))
)
elastic_grad_h = jax.vmap(
jax.grad(elastic_l2.h_transform(concave_fn, tol=1e-5))
)

np.testing.assert_allclose(
elastic_grad_h(x), p_grad_h(x), rtol=1e-4, atol=1e-4
)


@pytest.mark.skipif(ts_metrics is None, reason="Not supported for Python 3.11")
@pytest.mark.fast()
Expand Down
Loading