Skip to content

Commit

Permalink
add primal_cost property to output of GW, to compute cost without a…
Browse files Browse the repository at this point in the history
…ny regularization. (#286)

* add primal_cost

* precommit

* avoid using unbalanced for LRSinkhorn

* take comments into account.

* bib

* incorporating comments.

* update
  • Loading branch information
marcocuturi authored Feb 15, 2023
1 parent f2be655 commit 963f3b1
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 38 deletions.
16 changes: 7 additions & 9 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -195,15 +195,13 @@ @Misc{richter-powell:21
copyright = {arXiv.org perpetual, non-exclusive license}
}

@Misc{bunne:22,
doi = {10.48550/ARXIV.2206.14262},
url = {https://arxiv.org/abs/2206.14262},
author = {Bunne, Charlotte and Krause, Andreas and Cuturi, Marco},
keywords = {Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {Supervised Training of Conditional Monge Maps},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution Non Commercial Share Alike 4.0 International}
@inproceedings{bunne:22,
title={Supervised Training of Conditional Monge Maps},
author={Charlotte Bunne and Andreas Krause and marco cuturi},
booktitle={Advances in Neural Information Processing Systems},
editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
year={2022},
url={https://openreview.net/forum?id=sPNtVVUq7wi}
}

@Article{gelbrich:90,
Expand Down
29 changes: 17 additions & 12 deletions src/ott/problems/quadratic/quadratic_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,23 +138,28 @@ def marginal_dependent_cost(
Uses the first term in eq. 6, p. 1 of :cite:`peyre:16`.
Let :math:`p` [num_a,] be the marginal of the transport matrix for samples
from `geom_xx` and :math:`q` [num_b,] be the marginal of the transport
matrix for samples from `geom_yy`. `cost_xx` (resp. `cost_yy`) is the
cost matrix of `geom_xx` (resp. `geom_yy`). The cost term that
depends on these marginals can be written as:
Let :math:`p` be the `[n,]` marginal of the transport matrix for samples
from :attr:`geom_xx` and :math:`q` the `[m,]` marginal of the
transport matrix for samples from :attr:`geom_yy`.
`marginal_dep_term` = `lin1`(`cost_xx`) :math:`p \mathbb{1}_{num_b}^T`
+ (`lin2`(`cost_yy`) :math:`q \mathbb{1}_{num_a}^T)^T`
When ``cost_xx`` (resp. ``cost_yy``) is the cost matrix of :attr:`geom_xx`
(resp. :attr:`geom_yy`), the cost term that depends on these marginals can
be written as:
.. math::
\text{marginal_dep_term} = \text{lin1}(\text{cost_xx}) p \mathbb{1}_{m}^T
+ \mathbb{1}_{n}(\text{lin2}(\text{cost_yy}) q)^T
This helper function instantiates these two low-rank matrices and groups
them into a single low-rank cost geometry object.
Args:
marginal_1: jnp.ndarray<float>[num_a,], marginal of the transport matrix
for samples from geom_xx
marginal_2: jnp.ndarray<float>[num_b,], marginal of the transport matrix
for samples from geom_yy
marginal_1: [n,], first marginal of transport matrix.
marginal_2: [m,], second marginal of transport matrix.
Returns:
Low-rank geometry.
Low-rank geometry of rank 2, storing normalization constants.
"""
if self._loss_name == 'sqeucl': # quadratic apply, efficient for LR
tmp1 = self.geom_xx.apply_square_cost(marginal_1, axis=1)
Expand Down
2 changes: 1 addition & 1 deletion src/ott/solvers/linear/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def dual_cost(self) -> jnp.ndarray:
return dual_cost

@property
def primal_cost(self) -> jnp.ndarray:
def primal_cost(self) -> float:
"""Return transport cost of current solution at geometry."""
return self.transport_cost_at_geom(other_geom=self.geom)

Expand Down
2 changes: 1 addition & 1 deletion src/ott/solvers/linear/sinkhorn_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def transport_cost_at_geom(self, other_geom: geometry.Geometry) -> float:
return self.cost_at_geom(other_geom)

@property
def primal_cost(self) -> jnp.ndarray:
def primal_cost(self) -> float:
"""Return (by recomputing it) transport cost of current solution."""
return self.transport_cost_at_geom(other_geom=self.geom)

Expand Down
5 changes: 5 additions & 0 deletions src/ott/solvers/quadratic/gromov_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def reg_gw_cost(self) -> float:
def _rescale_factor(self) -> float:
return jnp.sqrt(self.old_transport_mass / self.linear_state.transport_mass)

@property
def primal_cost(self) -> float:
"""Return transport cost of current linear OT solution at geometry."""
return self.linear_state.transport_cost_at_geom(other_geom=self.geom)


class GWState(NamedTuple):
"""Holds the state of the Gromov-Wasserstein solver.
Expand Down
39 changes: 24 additions & 15 deletions tests/solvers/quadratic/gw_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,25 +196,34 @@ def reg_gw(a: jnp.ndarray, b: jnp.ndarray,
)

@pytest.mark.fast
@pytest.mark.parametrize("unbalanced", [False, True])
def test_gw_pointcloud(self, unbalanced: bool):
@pytest.mark.parametrize(
"balanced,rank", [(True, -1), (False, -1), (True, 3)]
)
def test_gw_pointcloud(self, balanced: bool, rank: int):
"""Test basic computations pointclouds."""
geom_x = pointcloud.PointCloud(self.x)
geom_y = pointcloud.PointCloud(self.y)
tau_a, tau_b = (1.0, 1.0) if balanced else (self.tau_a, self.tau_b)
prob = quadratic_problem.QuadraticProblem(
geom_x, geom_y, a=self.a, b=self.b, tau_a=tau_a, tau_b=tau_b
)
solver = gromov_wasserstein.GromovWasserstein(
rank=rank, epsilon=0.0 if rank > 0 else 1.0, max_iterations=10
)

def reg_gw(
x: jnp.ndarray, y: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray
) -> float:
geom_x = pointcloud.PointCloud(x)
geom_y = pointcloud.PointCloud(y)
tau_a, tau_b = (self.tau_a, self.tau_b) if unbalanced else (1.0, 1.0)
prob = quadratic_problem.QuadraticProblem(
geom_x, geom_y, a=a, b=b, tau_a=tau_a, tau_b=tau_b
)
solver = gromov_wasserstein.GromovWasserstein(
epsilon=1.0, max_iterations=10
out = solver(prob)
# TODO(cuturi): test primal cost for un-balanced case as well.
if balanced:
u = geom_x.apply_square_cost(out.matrix.sum(axis=-1)).squeeze()
v = geom_y.apply_square_cost(out.matrix.sum(axis=0)).squeeze()
c = (geom_x.cost_matrix @ out.matrix) @ geom_y.cost_matrix
c = (u[:, None] + v[None, :] - 2 * c)

np.testing.assert_allclose(
out.primal_cost, jnp.sum(c * out.matrix), rtol=1e-3
)
return solver(prob).reg_gw_cost

assert not jnp.isnan(reg_gw(self.x, self.y, self.a, self.b))
assert not jnp.isnan(out.reg_gw_cost)

@pytest.mark.parametrize(
"unbalanced,unbalanced_correction", [(False, False), (True, False),
Expand Down

0 comments on commit 963f3b1

Please sign in to comment.