Skip to content

Commit

Permalink
Merge pull request #40 from ott-jax/revert-39-branch_meyer
Browse files Browse the repository at this point in the history
Revert "new test for for entropic LRSinhorn"
  • Loading branch information
marcocuturi authored Mar 24, 2022
2 parents 1216ee2 + 1d42ad1 commit c2c5b89
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 64 deletions.
43 changes: 10 additions & 33 deletions ott/core/sinkhorn_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def __init__(self,
rank: int = 10,
gamma: float = 1.0,
epsilon: float = 1e-4,
init_type: str = 'random',
lse_mode: bool = True,
threshold: float = 1e-3,
norm_error: int = 1,
Expand All @@ -217,7 +216,6 @@ def __init__(self,
self.rank = rank
self.gamma = gamma
self.epsilon = epsilon
self.init_type = init_type
self.lse_mode = lse_mode
assert lse_mode, "Kernel mode not yet implemented for LRSinkhorn."
self.threshold = threshold
Expand All @@ -241,36 +239,15 @@ def __call__(self,
# Random initialization for q, r, g using rng_key
rng = jax.random.split(jax.random.PRNGKey(self.rng_key), 3)
a, b = ot_prob.a, ot_prob.b
if self.init_type == 'random':
if init_g is None:
init_g = jnp.abs(jax.random.uniform(rng[0], (self.rank,))) + 1
init_g = init_g / jnp.sum(init_g)
if init_q is None:
init_q = jnp.abs(jax.random.normal(rng[1], (a.shape[0], self.rank)))
init_q = init_q * (a / jnp.sum(init_q, axis=1))[:, None]
if init_r is None:
init_r = jnp.abs(jax.random.normal(rng[2], (b.shape[0], self.rank)))
init_r = init_r * (b / jnp.sum(init_r, axis=1))[:, None]
if self.init_type == 'rank_2':
if init_g is None:
init_g = jnp.ones((self.rank,)) / self.rank
lambda_1 = min(jnp.min(a), jnp.min(init_g), jnp.min(b)) / 2
a1 = jnp.arange(1, a.shape[0] + 1)
a1 = a1 / jnp.sum(a1)
a2 = (a - lambda_1 * a1) / (1 - lambda_1)
b1 = jnp.arange(1, b.shape[0] + 1)
b1 = b1 / jnp.sum(b1)
b2 = (b - lambda_1 * b1) / (1 - lambda_1)
g1 = jnp.arange(1, self.rank + 1)
g1 = g1 / jnp.sum(g1)
g2 = (init_g - lambda_1 * g1) / (1 - lambda_1)
if init_q is None:
init_q = lambda_1 * jnp.dot(a1[:, None], g1.reshape(1, -1))
init_q += (1 - lambda_1) * jnp.dot(a2[:, None], g2.reshape(1, -1))
if init_r is None:
init_r = lambda_1 * jnp.dot(b1[:, None], g1.reshape(1, -1))
init_r += (1 - lambda_1) * jnp.dot(b2[:, None], g2.reshape(1, -1))

if init_g is None:
init_g = jnp.abs(jax.random.uniform(rng[0], (self.rank,))) + 1
init_g = init_g / jnp.sum(init_g)
if init_q is None:
init_q = jnp.abs(jax.random.normal(rng[1], (a.shape[0], self.rank)))
init_q = init_q * (a / jnp.sum(init_q, axis=1))[:, None]
if init_r is None:
init_r = jnp.abs(jax.random.normal(rng[2], (b.shape[0], self.rank)))
init_r = init_r * (b / jnp.sum(init_r, axis=1))[:, None]
run_fn = run if not self.jit else jax.jit(run)
return run_fn(ot_prob, self, (init_q, init_r, init_g))

Expand Down Expand Up @@ -461,7 +438,7 @@ def make(
jit: bool = True,
rng_key: int = 0,
kwargs_dys: Any = None) -> LRSinkhorn:

return LRSinkhorn(
rank=rank,
gamma=gamma,
Expand Down
55 changes: 24 additions & 31 deletions tests/core/sinkhorn_lr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

# Lint as: python3
"""Tests for the Policy."""

from absl.testing import absltest
from absl.testing import parameterized
import jax
Expand All @@ -29,7 +30,7 @@ class SinkhornLRTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.rng = jax.random.PRNGKey(0)
self.dim = 2
self.dim = 4
self.n = 19
self.m = 17
self.rng, *rngs = jax.random.split(self.rng, 5)
Expand All @@ -47,39 +48,31 @@ def setUp(self):
@parameterized.parameters([True], [False])
def test_euclidean_point_cloud(self, use_lrcgeom):
"""Two point clouds, tested with various parameters."""
init_type_arr = ['rank_2','random']
for init_type in init_type_arr:
threshold = 1e-9
gamma = 100
geom = pointcloud.PointCloud(self.x, self.y)
if use_lrcgeom:
geom = geom.to_LRCGeometry()
ot_prob = problems.LinearProblem(geom, self.a, self.b)
solver = sinkhorn_lr.LRSinkhorn(threshold=threshold, gamma=gamma, rank=2, epsilon=0.0, init_type=init_type)
costs = solver(ot_prob).costs
self.assertTrue(jnp.isclose(costs[-2], costs[-1], rtol=threshold))
cost_1 = costs[costs > -1][-1]

solver = sinkhorn_lr.LRSinkhorn(threshold=threshold, gamma=gamma, rank=10, epsilon=0.0, init_type=init_type)
out = solver(ot_prob)
costs = out.costs
cost_2 = costs[costs > -1][-1]
self.assertGreater(cost_1, cost_2)
threshold = 1e-3
geom = pointcloud.PointCloud(self.x, self.y)
if use_lrcgeom:
geom = geom.to_LRCGeometry()
ot_prob = problems.LinearProblem(geom, self.a, self.b)
solver = sinkhorn_lr.LRSinkhorn(threshold=threshold, rank=10)
costs = solver(ot_prob).costs
self.assertTrue(jnp.isclose(costs[-2], costs[-1], rtol=threshold))
cost_1 = costs[costs > -1][-1]

other_geom = pointcloud.PointCloud(self.x, self.y + 0.3)
cost_other = out.cost_at_geom(other_geom)
self.assertGreater(cost_other, 0.0)
solver = sinkhorn_lr.LRSinkhorn(threshold=threshold, rank=20, epsilon=0.0)
out = solver(ot_prob)
costs = out.costs
cost_2 = costs[costs > -1][-1]
self.assertGreater(cost_1, cost_2)

solver = sinkhorn_lr.LRSinkhorn(threshold=threshold, gamma=gamma, rank=14, epsilon=1e-1, init_type=init_type)
out = solver(ot_prob)
costs = out.costs
cost_3 = costs[costs > -1][-1]
other_geom = pointcloud.PointCloud(self.x, self.y + 0.3)
cost_other = out.cost_at_geom(other_geom)
self.assertGreater(cost_other, 0.0)

solver = sinkhorn_lr.LRSinkhorn(threshold=threshold, gamma=gamma, rank=14, epsilon=1e-3, init_type=init_type)
out = solver(ot_prob)
costs = out.costs
cost_4 = costs[costs > -1][-1]
self.assertGreater(cost_3, cost_4)
solver = sinkhorn_lr.LRSinkhorn(threshold=threshold, rank=20, epsilon=1e-2)
out = solver(ot_prob)
costs = out.costs
cost_3 = costs[costs > -1][-1]
self.assertGreater(cost_3, cost_2)


if __name__ == '__main__':
Expand Down

0 comments on commit c2c5b89

Please sign in to comment.