Skip to content

Commit

Permalink
Merge pull request #35 from ott-jax/526517853C347DEEFE613233BE082B1B
Browse files Browse the repository at this point in the history
Migrate away from using JaxTestCase in tests
  • Loading branch information
LaetitiaPapaxanthos authored Mar 18, 2022
2 parents f108eb8 + 79a7f2c commit 5246188
Show file tree
Hide file tree
Showing 39 changed files with 353 additions and 360 deletions.
4 changes: 1 addition & 3 deletions tests/core/discrete_barycenter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,12 @@
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import jax.test_util
from ott.core import discrete_barycenter as db
from ott.geometry import grid
from ott.geometry import pointcloud


@jax.test_util.with_config(jax_numpy_rank_promotion='allow')
class DiscreteBarycenterTest(jax.test_util.JaxTestCase):
class DiscreteBarycenterTest(parameterized.TestCase):

def setUp(self):
super().setUp()
Expand Down
43 changes: 22 additions & 21 deletions tests/core/fused_gromov_wasserstein_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import jax.test_util
import numpy as np
from ott.core import gromov_wasserstein
from ott.geometry import geometry
from ott.geometry import pointcloud


@jax.test_util.with_config(jax_numpy_rank_promotion='allow')
class FusedGromovWassersteinTest(jax.test_util.JaxTestCase):
class FusedGromovWassersteinTest(parameterized.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -120,12 +119,14 @@ def reg_gw(a, b, implicit):
grad_manual_b = aux[1] - jnp.log(self.b)
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[0])), True)
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[1])), True)
self.assertAllClose(grad_manual_a, grad_reg_gw[0], rtol=1e-2, atol=1e-2)
self.assertAllClose(grad_manual_b, grad_reg_gw[1], rtol=1e-2, atol=1e-2)
self.assertAllClose(grad_matrices[0][0], grad_matrices[1][0],
rtol=1e-02, atol=1e-02)
self.assertAllClose(grad_matrices[0][1], grad_matrices[1][1],
rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_manual_a, grad_reg_gw[0], rtol=1e-2, atol=1e-2)
np.testing.assert_allclose(
grad_manual_b, grad_reg_gw[1], rtol=1e-2, atol=1e-2)
np.testing.assert_allclose(
grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02)

@parameterized.parameters([True], [False])
def test_fused_gromov_wasserstein_pointcloud(self, lse_mode):
Expand Down Expand Up @@ -184,10 +185,10 @@ def reg_gw(x, y, x_2, y_2, fused_penalty, a, b, implicit):
grad_matrices[i] = grad_reg_gw
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[0])), True)
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[1])), True)
self.assertAllClose(grad_matrices[0][0], grad_matrices[1][0],
rtol=1e-02, atol=1e-02)
self.assertAllClose(grad_matrices[0][1], grad_matrices[1][1],
rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02)

@parameterized.parameters([True], [False])
def test_gradient_fused_gromov_wasserstein_geometry(self, lse_mode):
Expand Down Expand Up @@ -223,12 +224,12 @@ def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit):
grad_matrices[i] = grad_reg_gw
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[0])), True)
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[1])), True)
self.assertAllClose(grad_matrices[0][0], grad_matrices[1][0],
rtol=1e-02, atol=1e-02)
self.assertAllClose(grad_matrices[0][1], grad_matrices[1][1],
rtol=1e-02, atol=1e-02)
self.assertAllClose(grad_matrices[0][2], grad_matrices[1][2],
rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][2], grad_matrices[1][2], rtol=1e-02, atol=1e-02)

def test_adaptive_threshold_fused(self):
"""Checking solution is improved with smaller threshold for convergence."""
Expand Down Expand Up @@ -282,8 +283,8 @@ def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit):
implicit)
grad_matrices[i] = grad_reg_gw
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[0])), True)
self.assertAllClose(grad_matrices[0][0], grad_matrices[1][0],
rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02)

def test_effect_fused_penalty(self):

Expand Down
39 changes: 20 additions & 19 deletions tests/core/gromov_wasserstein_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import jax.test_util
import numpy as np
from ott.core import gromov_wasserstein
from ott.core import quad_problems
from ott.geometry import geometry
from ott.geometry import pointcloud


@jax.test_util.with_config(jax_numpy_rank_promotion='allow')
class GromovWassersteinTest(jax.test_util.JaxTestCase):
class GromovWassersteinTest(parameterized.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -98,12 +97,14 @@ def reg_gw(a, b, implicit):
grad_manual_b = aux[1] - jnp.log(self.b)
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[0])), True)
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[1])), True)
self.assertAllClose(grad_manual_a, grad_reg_gw[0], rtol=1e-2, atol=1e-2)
self.assertAllClose(grad_manual_b, grad_reg_gw[1], rtol=1e-2, atol=1e-2)
self.assertAllClose(grad_matrices[0][0], grad_matrices[1][0],
rtol=1e-02, atol=1e-02)
self.assertAllClose(grad_matrices[0][1], grad_matrices[1][1],
rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_manual_a, grad_reg_gw[0], rtol=1e-2, atol=1e-2)
np.testing.assert_allclose(
grad_manual_b, grad_reg_gw[1], rtol=1e-2, atol=1e-2)
np.testing.assert_allclose(
grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02)

def test_gromov_wasserstein_pointcloud(self):
"""Test basic computations pointclouds."""
Expand Down Expand Up @@ -136,10 +137,10 @@ def reg_gw(x, y, a, b, implicit):
grad_matrices[i] = grad_reg_gw
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[0])), True)
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[1])), True)
self.assertAllClose(grad_matrices[0][0], grad_matrices[1][0],
rtol=1e-02, atol=1e-02)
self.assertAllClose(grad_matrices[0][1], grad_matrices[1][1],
rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02)

@parameterized.parameters([True], [False])
def test_gradient_gromov_wasserstein_geometry(self, lse_mode):
Expand All @@ -161,10 +162,10 @@ def reg_gw(cx, cy, a, b, implicit):
grad_matrices[i] = grad_reg_gw
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[0])), True)
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[1])), True)
self.assertAllClose(grad_matrices[0][0], grad_matrices[1][0],
rtol=1e-02, atol=1e-02)
self.assertAllClose(grad_matrices[0][1], grad_matrices[1][1],
rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02)

def test_adaptive_threshold(self):
"""Checking solution is improved with smaller threshold for convergence."""
Expand Down Expand Up @@ -196,7 +197,7 @@ def test_gw_lr(self):
ot_gwlr = solver(prob)
solver = gromov_wasserstein.GromovWasserstein(epsilon=0.2)
ot_gw = solver(prob)
self.assertAllClose(ot_gwlr.costs, ot_gw.costs, rtol=5e-2)
np.testing.assert_allclose(ot_gwlr.costs, ot_gw.costs, rtol=5e-2)

def test_gw_lr_fused(self):
"""Checking LR and Entropic have similar outputs on same fused problem."""
Expand All @@ -223,7 +224,7 @@ def test_gw_lr_fused(self):
solver = gromov_wasserstein.GromovWasserstein(epsilon=5e-2)
ot_gw = solver(prob)

# Test solutions look alike
# Test solutions look alike
self.assertGreater(0.1, jnp.linalg.norm(ot_gwlr.matrix - ot_gw.matrix))
self.assertGreater(0.1, jnp.linalg.norm(ot_gwlr.matrix - ot_gwlreps.matrix))
# Test at least some difference when adding bigger entropic regularization
Expand Down
21 changes: 11 additions & 10 deletions tests/core/gromov_wasserstein_unbalanced_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import numpy as np
import jax.test_util
from ott.core import gromov_wasserstein
from ott.geometry import geometry
from ott.geometry import pointcloud

@jax.test_util.with_config(jax_numpy_rank_promotion='allow')
class GromovWassersteinUnbalancedTest(jax.test_util.JaxTestCase):

class GromovWassersteinUnbalancedTest(parameterized.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -81,10 +82,10 @@ def reg_gw(x, y, a, b, implicit):
grad_matrices[i] = grad_reg_gw
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[0])), True)
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[1])), True)
self.assertAllClose(grad_matrices[0][0], grad_matrices[1][0],
rtol=1e-02, atol=1e-02)
self.assertAllClose(grad_matrices[0][1], grad_matrices[1][1],
rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02)

@parameterized.parameters([True], [False])
def test_gradient_gromov_wasserstein_geometry(self, gw_unbalanced_correction):
Expand All @@ -108,10 +109,10 @@ def reg_gw(cx, cy, a, b, implicit):
grad_matrices[i] = grad_reg_gw
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[0])), True)
self.assertIsNot(jnp.any(jnp.isnan(grad_reg_gw[1])), True)
self.assertAllClose(grad_matrices[0][0], grad_matrices[1][0],
rtol=1e-02, atol=1e-02)
self.assertAllClose(grad_matrices[0][1], grad_matrices[1][1],
rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02)
np.testing.assert_allclose(
grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02)


if __name__ == '__main__':
Expand Down
3 changes: 1 addition & 2 deletions tests/core/icnn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import jax.test_util
from ott.core.icnn import ICNN


class ICNNTest(jax.test_util.JaxTestCase):
class ICNNTest(parameterized.TestCase):

def setUp(self):
super().setUp()
Expand Down
5 changes: 2 additions & 3 deletions tests/core/sinkhorn_anderson_acceleration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import jax.test_util
from ott.core import sinkhorn
from ott.geometry import pointcloud

@jax.test_util.with_config(jax_numpy_rank_promotion='allow')
class SinkhornAndersonTest(jax.test_util.JaxTestCase):

class SinkhornAndersonTest(parameterized.TestCase):
"""Tests for Anderson acceleration."""

def setUp(self):
Expand Down
7 changes: 3 additions & 4 deletions tests/core/sinkhorn_bures_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import jax.test_util
import numpy as np
from ott.core import sinkhorn
from ott.geometry import costs
from ott.geometry import pointcloud


@jax.test_util.with_config(jax_numpy_rank_promotion='allow')
class SinkhornTest(jax.test_util.JaxTestCase):
class SinkhornTest(parameterized.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -83,7 +82,7 @@ def test_regularized_unbalanced_bures(self):
rub = costs.UnbalancedBures(self.dim, 1, 0.8)
self.assertIsNot(jnp.any(jnp.isnan(rub(x, y))), True)
self.assertIsNot(jnp.any(jnp.isnan(rub(y, x))), True)
self.assertAllClose(rub(x, y), rub(y, x), rtol=1e-3, atol=1e-3)
np.testing.assert_allclose(rub(x, y), rub(y, x), rtol=1e-3, atol=1e-3)


if __name__ == '__main__':
Expand Down
12 changes: 6 additions & 6 deletions tests/core/sinkhorn_diff_grid_loc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import jax.test_util
import numpy as np
from ott.core import sinkhorn
from ott.geometry import grid


@jax.test_util.with_config(jax_numpy_rank_promotion='allow')
class SinkhornGradGridTest(jax.test_util.JaxTestCase):
class SinkhornGradGridTest(parameterized.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -63,9 +62,10 @@ def reg_ot(x):
delta_dot_grad = jnp.sum(jnp.array(
[jnp.sum(delt * gr, axis=None) for delt, gr in zip(delta, grad_reg_ot)]
))
self.assertAllClose(delta_dot_grad,
(reg_ot_delta_plus - reg_ot_delta_minus) / (2 * eps),
rtol=1e-03, atol=1e-02)
np.testing.assert_allclose(
delta_dot_grad, (reg_ot_delta_plus - reg_ot_delta_minus) / (2 * eps),
rtol=1e-03,
atol=1e-02)


if __name__ == '__main__':
Expand Down
12 changes: 6 additions & 6 deletions tests/core/sinkhorn_diff_grid_weights_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import jax.test_util
import numpy as np
from ott.core import sinkhorn
from ott.geometry import grid


@jax.test_util.with_config(jax_numpy_rank_promotion='allow')
class SinkhornGradGridTest(jax.test_util.JaxTestCase):
class SinkhornGradGridTest(parameterized.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -60,9 +59,10 @@ def reg_ot(a, b):
reg_ot_delta_plus = reg_ot(a + eps * delta, b)
reg_ot_delta_minus = reg_ot(a - eps * delta, b)
delta_dot_grad = jnp.sum(delta * grad_reg_ot)
self.assertAllClose(delta_dot_grad,
(reg_ot_delta_plus - reg_ot_delta_minus) / (2 * eps),
rtol=1e-03, atol=1e-02)
np.testing.assert_allclose(
delta_dot_grad, (reg_ot_delta_plus - reg_ot_delta_minus) / (2 * eps),
rtol=1e-03,
atol=1e-02)


if __name__ == '__main__':
Expand Down
13 changes: 6 additions & 7 deletions tests/core/sinkhorn_diff_precond_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import jax.test_util
import numpy as np
from ott.tools import transport


@jax.test_util.with_config(jax_numpy_rank_promotion='allow')
class SinkhornJacobianPreconditioningTest(jax.test_util.JaxTestCase):
class SinkhornJacobianPreconditioningTest(parameterized.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -106,16 +105,16 @@ def loss_from_potential(a,
val_p, _ = loss_imp_no_precond(a_p, x_p)
val_m, _ = loss_imp_no_precond(a_m, x_m)
fin_dif = (val_p - val_m) / (2 * perturb_scale)
self.assertAllClose(fin_dif, imp_dif_lp, atol=1e-2, rtol=1e-2)
self.assertAllClose(fin_dif, imp_dif_np, atol=1e-2, rtol=1e-2)
self.assertAllClose(imp_dif_np, imp_dif_lp, atol=1e-2, rtol=1e-2)
np.testing.assert_allclose(fin_dif, imp_dif_lp, atol=1e-2, rtol=1e-2)
np.testing.assert_allclose(fin_dif, imp_dif_np, atol=1e-2, rtol=1e-2)
np.testing.assert_allclose(imp_dif_np, imp_dif_lp, atol=1e-2, rtol=1e-2)

# center both if balanced problem testing gradient w.r.t weights
if tau_a == 1.0 and tau_b == 1.0 and arg == 0:
g_imp_np = g_imp_np - jnp.mean(g_imp_np)
g_imp_lp = g_imp_lp - jnp.mean(g_imp_lp)

self.assertAllClose(g_imp_np, g_imp_lp, atol=1e-2, rtol=1e-2)
np.testing.assert_allclose(g_imp_np, g_imp_lp, atol=1e-2, rtol=1e-2)

if __name__ == '__main__':
absltest.main()
Loading

0 comments on commit 5246188

Please sign in to comment.