Skip to content

Commit

Permalink
Add more h_legendre tests
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Apr 29, 2024
1 parent b0ad84e commit 35c2113
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions tests/geometry/costs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ class TestTIRegCost:
@pytest.mark.parametrize(
"cost_fn", [
costs.SqPNorm(p=1.0),
costs.SqPNorm(2.3),
costs.PNormP(p=1.0),
costs.SqPNorm(2.4),
costs.PNormP(p=1.1),
costs.PNormP(1.3),
costs.SqEuclidean()
]
Expand All @@ -147,8 +147,25 @@ def test_h_legendre(self, rng: jax.Array, cost_fn: costs.TICost):

np.testing.assert_array_equal(jnp.isfinite(h_transform(x)), True)

def test_h_legendre_sqeucl(self):
pass
@pytest.mark.parametrize("ridge", [1e-12, 1e-6])
def test_h_legendre_sqeucl(self, rng: jax.Array, ridge: float):
n, d = 12, 4
rngs = jax.random.split(rng, 2)
u = jnp.abs(jax.random.uniform(rngs[0], (d,)))
x = jax.random.normal(rngs[1], (n, d))

sqeucl = costs.SqEuclidean()
el_l2 = costs.ElasticL2(scaling_reg=0.0)

h_concave = lambda z: 0.5 * (-sqeucl.h(z) + jnp.dot(z, u))
h_concave_half = lambda z: -sqeucl.h(z) + jnp.dot(z, u)

pred = jax.jit(
jax.vmap(jax.grad(sqeucl.h_transform(h_concave, ridge=ridge)))
)
gt = jax.jit(jax.vmap(jax.grad(el_l2.h_transform(h_concave_half))))

np.testing.assert_allclose(pred(x), gt(x), rtol=1e-5, atol=1e-5)


@pytest.mark.fast()
Expand Down

0 comments on commit 35c2113

Please sign in to comment.