Skip to content

Commit

Permalink
Merge of PR #19
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 411057494
  • Loading branch information
marcocuturi committed Nov 19, 2021
1 parent c822977 commit 9a26193
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
2 changes: 2 additions & 0 deletions ott/core/gromov_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def gromov_wasserstein(
if loss_fn is None:
raise ValueError('Unknown loss. Either pass an instance of GWLoss or '
f'a string among: [{",".join(GW_LOSSES.keys())}]')
sinkhorn_kwargs = {} if sinkhorn_kwargs is None else sinkhorn_kwargs

tau_a = sinkhorn_kwargs.get('tau_a', 1.0)
tau_b = sinkhorn_kwargs.get('tau_b', 1.0)
if tau_a != 1.0 or tau_b != 1.0:
Expand Down
6 changes: 1 addition & 5 deletions ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,8 @@ def _epsilon(self):

@property
def cost_matrix(self):
"""Returns cost matrix, computes it if only kernel was specified."""
if self._cost_matrix is None:
# If no epsilon was passed on to the geometry, then assume it is one by
# default.
cost = -jnp.log(self._kernel_matrix)
return cost if self._epsilon_init is None else self.epsilon * cost
return -self.epsilon * jnp.log(self._kernel_matrix)
return self._cost_matrix

@property
Expand Down
20 changes: 16 additions & 4 deletions tests/core/gromov_wasserstein_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def setUp(self):
@parameterized.parameters([True], [False])
def test_gradient_marginals_gromov_wasserstein(self, jit):
"""Test gradient w.r.t. probability weights."""
geom_x = pointcloud.PointCloud(self.x, self.x)
geom_y = pointcloud.PointCloud(self.y, self.y)
geom_x = pointcloud.PointCloud(self.x)
geom_y = pointcloud.PointCloud(self.y)

def reg_gw(a, b, implicit):
sinkhorn_kwargs = {'implicit_differentiation': implicit,
Expand Down Expand Up @@ -77,13 +77,25 @@ def reg_gw(a, b, implicit):
self.assertAllClose(grad_matrices[0][1], grad_matrices[1][1],
rtol=1e-02, atol=1e-02)

@parameterized.parameters([True], [False])
def test_gromov_wasserstein_pointcloud(self, lse_mode):
"""Test basic computations pointclouds."""

def reg_gw(x, y, a, b):
geom_x = pointcloud.PointCloud(x)
geom_y = pointcloud.PointCloud(y)
return gromov_wasserstein.gromov_wasserstein(
geom_x, geom_y, a=a, b=b, epsilon=1.0, max_iterations=10).reg_gw_cost

self.assertIsNot(jnp.isnan(reg_gw(self.x, self.y, self.a, self.b)), True)

@parameterized.parameters([True], [False])
def test_gradient_gromov_wasserstein_pointcloud(self, lse_mode):
"""Test gradient w.r.t. pointclouds."""

def reg_gw(x, y, a, b, implicit):
geom_x = pointcloud.PointCloud(x, x)
geom_y = pointcloud.PointCloud(y, y)
geom_x = pointcloud.PointCloud(x)
geom_y = pointcloud.PointCloud(y)
sinkhorn_kwargs = {'implicit_differentiation': implicit,
'max_iterations': 1001, 'lse_mode': lse_mode}
return gromov_wasserstein.gromov_wasserstein(
Expand Down

0 comments on commit 9a26193

Please sign in to comment.