diff --git a/docs/references.bib b/docs/references.bib index 14ab4ee66..e77d7e136 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -195,15 +195,13 @@ @Misc{richter-powell:21 copyright = {arXiv.org perpetual, non-exclusive license} } -@Misc{bunne:22, - doi = {10.48550/ARXIV.2206.14262}, - url = {https://arxiv.org/abs/2206.14262}, - author = {Bunne, Charlotte and Krause, Andreas and Cuturi, Marco}, - keywords = {Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences}, - title = {Supervised Training of Conditional Monge Maps}, - publisher = {arXiv}, - year = {2022}, - copyright = {Creative Commons Attribution Non Commercial Share Alike 4.0 International} +@inproceedings{bunne:22, + title={Supervised Training of Conditional Monge Maps}, + author={Charlotte Bunne and Andreas Krause and Marco Cuturi}, + booktitle={Advances in Neural Information Processing Systems}, + editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho}, + year={2022}, + url={https://openreview.net/forum?id=sPNtVVUq7wi} } @Article{gelbrich:90, diff --git a/src/ott/problems/quadratic/quadratic_problem.py b/src/ott/problems/quadratic/quadratic_problem.py index aa7192ca5..76c472ec3 100644 --- a/src/ott/problems/quadratic/quadratic_problem.py +++ b/src/ott/problems/quadratic/quadratic_problem.py @@ -138,23 +138,28 @@ def marginal_dependent_cost( Uses the first term in eq. 6, p. 1 of :cite:`peyre:16`. - Let :math:`p` [num_a,] be the marginal of the transport matrix for samples - from `geom_xx` and :math:`q` [num_b,] be the marginal of the transport - matrix for samples from `geom_yy`. `cost_xx` (resp. `cost_yy`) is the - cost matrix of `geom_xx` (resp. `geom_yy`). The cost term that - depends on these marginals can be written as: + Let :math:`p` be the `[n,]` marginal of the transport matrix for samples + from :attr:`geom_xx` and :math:`q` the `[m,]` marginal of the + transport matrix for samples from :attr:`geom_yy`. 
- `marginal_dep_term` = `lin1`(`cost_xx`) :math:`p \mathbb{1}_{num_b}^T` - + (`lin2`(`cost_yy`) :math:`q \mathbb{1}_{num_a}^T)^T` + When ``cost_xx`` (resp. ``cost_yy``) is the cost matrix of :attr:`geom_xx` + (resp. :attr:`geom_yy`), the cost term that depends on these marginals can + be written as: + + .. math:: + + \text{marginal_dep_term} = \text{lin1}(\text{cost_xx}) p \mathbb{1}_{m}^T + + \mathbb{1}_{n}(\text{lin2}(\text{cost_yy}) q)^T + + This helper function instantiates these two low-rank matrices and groups + them into a single low-rank cost geometry object. Args: - marginal_1: jnp.ndarray[num_a,], marginal of the transport matrix - for samples from geom_xx - marginal_2: jnp.ndarray[num_b,], marginal of the transport matrix - for samples from geom_yy + marginal_1: [n,], first marginal of transport matrix. + marginal_2: [m,], second marginal of transport matrix. Returns: - Low-rank geometry. + Low-rank geometry of rank 2, storing normalization constants. """ if self._loss_name == 'sqeucl': # quadratic apply, efficient for LR tmp1 = self.geom_xx.apply_square_cost(marginal_1, axis=1) diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index 232bc16e3..735084dcc 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -308,7 +308,7 @@ def dual_cost(self) -> jnp.ndarray: return dual_cost @property - def primal_cost(self) -> jnp.ndarray: + def primal_cost(self) -> float: """Return transport cost of current solution at geometry.""" return self.transport_cost_at_geom(other_geom=self.geom) diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index c0b824d20..d8296daf0 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -209,7 +209,7 @@ def transport_cost_at_geom(self, other_geom: geometry.Geometry) -> float: return self.cost_at_geom(other_geom) @property - def primal_cost(self) -> jnp.ndarray: + def 
primal_cost(self) -> float: """Return (by recomputing it) transport cost of current solution.""" return self.transport_cost_at_geom(other_geom=self.geom) diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index c6475a1b7..327b2243c 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -89,6 +89,11 @@ def reg_gw_cost(self) -> float: def _rescale_factor(self) -> float: return jnp.sqrt(self.old_transport_mass / self.linear_state.transport_mass) + @property + def primal_cost(self) -> float: + """Return transport cost of current linear OT solution at geometry.""" + return self.linear_state.transport_cost_at_geom(other_geom=self.geom) + class GWState(NamedTuple): """Holds the state of the Gromov-Wasserstein solver. diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py index e45feeee3..bde02d30b 100644 --- a/tests/solvers/quadratic/gw_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ -196,25 +196,34 @@ def reg_gw(a: jnp.ndarray, b: jnp.ndarray, ) @pytest.mark.fast - @pytest.mark.parametrize("unbalanced", [False, True]) - def test_gw_pointcloud(self, unbalanced: bool): + @pytest.mark.parametrize( + "balanced,rank", [(True, -1), (False, -1), (True, 3)] + ) + def test_gw_pointcloud(self, balanced: bool, rank: int): """Test basic computations pointclouds.""" + geom_x = pointcloud.PointCloud(self.x) + geom_y = pointcloud.PointCloud(self.y) + tau_a, tau_b = (1.0, 1.0) if balanced else (self.tau_a, self.tau_b) + prob = quadratic_problem.QuadraticProblem( + geom_x, geom_y, a=self.a, b=self.b, tau_a=tau_a, tau_b=tau_b + ) + solver = gromov_wasserstein.GromovWasserstein( + rank=rank, epsilon=0.0 if rank > 0 else 1.0, max_iterations=10 + ) - def reg_gw( - x: jnp.ndarray, y: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray - ) -> float: - geom_x = pointcloud.PointCloud(x) - geom_y = pointcloud.PointCloud(y) - tau_a, tau_b = 
(self.tau_a, self.tau_b) if unbalanced else (1.0, 1.0) - prob = quadratic_problem.QuadraticProblem( - geom_x, geom_y, a=a, b=b, tau_a=tau_a, tau_b=tau_b - ) - solver = gromov_wasserstein.GromovWasserstein( - epsilon=1.0, max_iterations=10 + out = solver(prob) + # TODO(cuturi): test primal cost for unbalanced case as well. + if balanced: + u = geom_x.apply_square_cost(out.matrix.sum(axis=-1)).squeeze() + v = geom_y.apply_square_cost(out.matrix.sum(axis=0)).squeeze() + c = (geom_x.cost_matrix @ out.matrix) @ geom_y.cost_matrix + c = (u[:, None] + v[None, :] - 2 * c) + + np.testing.assert_allclose( + out.primal_cost, jnp.sum(c * out.matrix), rtol=1e-3 ) - return solver(prob).reg_gw_cost - assert not jnp.isnan(reg_gw(self.x, self.y, self.a, self.b)) + assert not jnp.isnan(out.reg_gw_cost) @pytest.mark.parametrize( "unbalanced,unbalanced_correction", [(False, False), (True, False),