diff --git a/src/ott/initializers/quadratic/initializers.py b/src/ott/initializers/quadratic/initializers.py
index de4b6118c..62950bc4e 100644
--- a/src/ott/initializers/quadratic/initializers.py
+++ b/src/ott/initializers/quadratic/initializers.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
-from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple
 
 import jax
 import jax.numpy as jnp
@@ -121,7 +121,11 @@ class QuadraticInitializer(BaseQuadraticInitializer):
   """
 
   def _create_geometry(
-      self, quad_prob: "quadratic_problem.QuadraticProblem", *, epsilon: float,
+      self,
+      quad_prob: "quadratic_problem.QuadraticProblem",
+      *,
+      epsilon: float,
+      relative_epsilon: Optional[bool] = None,
       **kwargs: Any
   ) -> geometry.Geometry:
     """Compute initial geometry for linearization.
@@ -129,7 +133,8 @@ def _create_geometry(
     Args:
       quad_prob: Quadratic OT problem.
       epsilon: Epsilon regularization.
-      kwargs: Additional keyword arguments, unused.
+      relative_epsilon: Whether to use relative epsilon in the geometry.
+      kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`.
 
     Returns:
       The initial geometry used to initialize the linearized problem.
@@ -160,8 +165,12 @@ def _create_geometry(
     )
     cost_matrix = marginal_cost.cost_matrix - tmp + unbalanced_correction
 
-    cost_matrix += quad_prob.fused_penalty * quad_prob._fused_cost_matrix()
-    return geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon)
+    cost_matrix += quad_prob.fused_penalty * quad_prob._fused_cost_matrix
+    return geometry.Geometry(
+        cost_matrix=cost_matrix,
+        epsilon=epsilon,
+        relative_epsilon=relative_epsilon
+    )
 
 
 class LRQuadraticInitializer(BaseQuadraticInitializer):
@@ -176,12 +185,16 @@ def __init__(self, lr_linear_initializer: "initializers_lr.LRInitializer"):
     self._linear_lr_initializer = lr_linear_initializer
 
   def _create_geometry(
-      self, quad_prob: "quadratic_problem.QuadraticProblem", **kwargs: Any
+      self,
+      quad_prob: "quadratic_problem.QuadraticProblem",
+      relative_epsilon: Optional[bool] = False,
+      **kwargs: Any
   ) -> geometry.Geometry:
     """Compute initial geometry for linearization.
 
     Args:
       quad_prob: Quadratic OT problem.
+      relative_epsilon: Whether to use relative epsilon in the geometry.
       kwargs: Keyword arguments for
         :meth:`~ott.initializers.linear.initializers_lr.LRInitializer.__call__`.
 
@@ -201,7 +214,7 @@ def _create_geometry(
         epsilon=None,
     )
 
-    return quad_prob.update_lr_geom(tmp_out)
+    return quad_prob.update_lr_geom(tmp_out, relative_epsilon=relative_epsilon)
 
   @property
   def rank(self) -> int:
diff --git a/src/ott/problems/quadratic/quadratic_problem.py b/src/ott/problems/quadratic/quadratic_problem.py
index 84f01ff6e..a14273a92 100644
--- a/src/ott/problems/quadratic/quadratic_problem.py
+++ b/src/ott/problems/quadratic/quadratic_problem.py
@@ -134,8 +134,6 @@ def marginal_dependent_cost(
       self,
       marginal_1: jnp.ndarray,
       marginal_2: jnp.ndarray,
-      *,
-      remove_scale: bool = False,
   ) -> low_rank.LRCGeometry:
     r"""Initialize cost term that depends on the marginals of the transport.
 
@@ -160,17 +158,11 @@ def marginal_dependent_cost(
     Args:
       marginal_1: [n,], first marginal of transport matrix.
       marginal_2: [m,], second marginal of transport matrix.
-      remove_scale: Whether to remove any scaling from the cost matrices before
-        computing the linearization.
 
     Returns:
       Low-rank geometry of rank 2, storing normalization constants.
     """
     geom_xx, geom_yy = self.geom_xx, self.geom_yy
-    if remove_scale:
-      geom_xx = geom_xx.set_scale_cost(1.0)
-      geom_yy = geom_yy.set_scale_cost(1.0)
-
     if self._loss_name == "sqeucl":  # quadratic apply, efficient for LR
       tmp1 = geom_xx.apply_square_cost(marginal_1, axis=1)
       tmp2 = geom_yy.apply_square_cost(marginal_2, axis=1)
@@ -251,14 +243,12 @@ def init_transport_mass(self) -> float:
   def update_lr_geom(
       self,
       lr_sink: "sinkhorn_lr.LRSinkhornOutput",
-      remove_scale: bool = False,
+      relative_epsilon: Optional[bool] = None,
   ) -> geometry.Geometry:
     """Recompute (possibly LRC) linearization using LR Sinkhorn output."""
     marginal_1 = lr_sink.marginal(1)
     marginal_2 = lr_sink.marginal(0)
-    marginal_cost = self.marginal_dependent_cost(
-        marginal_1, marginal_2, remove_scale=remove_scale
-    )
+    marginal_cost = self.marginal_dependent_cost(marginal_1, marginal_2)
 
     # Extract factors from LR Sinkhorn output
     q, r, inv_sqg = lr_sink.q, lr_sink.r, 1.0 / jnp.sqrt(lr_sink.g)
@@ -268,20 +258,20 @@ def update_lr_geom(
     # Handle LRC Geometry case.
     h1, h2 = self.quad_loss
     geom_xx, geom_yy, geom_xy = self.geom_xx, self.geom_yy, self.geom_xy
-    if remove_scale:
-      geom_xx = geom_xx.set_scale_cost(1.0)
-      geom_yy = geom_yy.set_scale_cost(1.0)
-      geom_xy = geom_xy.set_scale_cost(1.0) if self.is_fused else None
     tmp1 = apply_cost(geom_xx, q, axis=1, fn=h1)
     tmp2 = apply_cost(geom_yy, r, axis=1, fn=h2)
     if self.is_low_rank:
-      geom = low_rank.LRCGeometry(cost_1=tmp1, cost_2=-tmp2) + marginal_cost
+      geom = low_rank.LRCGeometry(
+          cost_1=tmp1, cost_2=-tmp2, relative_epsilon=relative_epsilon
+      ) + marginal_cost
       if self.is_fused:
         geom = geom + geom_xy
     else:
       cost_matrix = marginal_cost.cost_matrix - jnp.dot(tmp1, tmp2.T)
-      cost_matrix += self.fused_penalty * self._fused_cost_matrix(remove_scale)
-      geom = geometry.Geometry(cost_matrix=cost_matrix)
+      cost_matrix += self.fused_penalty * self._fused_cost_matrix
+      geom = geometry.Geometry(
+          cost_matrix=cost_matrix, relative_epsilon=relative_epsilon
+      )
     return geom  # noqa: RET504
 
   def update_linearization(
@@ -289,7 +279,7 @@ def update_linearization(
       transport: Transport,
       epsilon: Optional[Union[epsilon_scheduler.Epsilon, float]] = None,
       old_transport_mass: float = 1.0,
-      remove_scale: bool = False,
+      relative_epsilon: Optional[bool] = None,
   ) -> linear_problem.LinearProblem:
     """Update linearization of GW problem by updating cost matrix.
 
@@ -307,11 +297,8 @@ def update_linearization(
       epsilon: An epsilon scheduler or a float passed on to the linearization.
       old_transport_mass: Sum of the elements of the transport matrix at the
        previous iteration.
-      remove_scale: Whether to remove any scaling from the cost matrices when
-        computing the linearization of the quadratic cost. At the moment, this
-        is only used when doing this update at the last outer iteration of the
-        :class:`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein`
-        solver.
+      relative_epsilon: Whether to use relative epsilon in the linearized
+        geometry.
 
     Returns:
       Updated linear OT problem, a new local linearization of GW problem.
@@ -326,9 +313,7 @@ def update_linearization(
     marginal_1 = transport.marginal(axis=1) * rescale_factor
     marginal_2 = transport.marginal(axis=0) * rescale_factor
 
-    marginal_cost = self.marginal_dependent_cost(
-        marginal_1, marginal_2, remove_scale=remove_scale
-    )
+    marginal_cost = self.marginal_dependent_cost(marginal_1, marginal_2)
 
     transport_matrix = transport.matrix * rescale_factor
 
@@ -342,18 +327,18 @@ def update_linearization(
 
     h1, h2 = self.quad_loss
     geom_xx, geom_yy = self.geom_xx, self.geom_yy
-    if remove_scale:
-      geom_xx = geom_xx.set_scale_cost(1.0)
-      geom_yy = geom_yy.set_scale_cost(1.0)
     tmp = apply_cost(geom_xx, transport_matrix, axis=1, fn=h1)
     tmp = apply_cost(geom_yy, tmp.T, axis=1, fn=h2).T
 
     cost_matrix = marginal_cost.cost_matrix - tmp + unbalanced_correction
-    cost_matrix += self.fused_penalty * rescale_factor * \
-        self._fused_cost_matrix(remove_scale)
+    cost_matrix += self.fused_penalty * rescale_factor * self._fused_cost_matrix
 
-    geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon)
+    geom = geometry.Geometry(
+        cost_matrix=cost_matrix,
+        epsilon=epsilon,
+        relative_epsilon=relative_epsilon
+    )
 
     return linear_problem.LinearProblem(
         geom, self.a, self.b, tau_a=self.tau_a, tau_b=self.tau_b
     )
@@ -363,24 +348,22 @@ def update_lr_linearization(
       self,
       lr_sink: "sinkhorn_lr.LRSinkhornOutput",
       *,
-      remove_scale: bool = False,
+      relative_epsilon: Optional[bool] = None,
   ) -> linear_problem.LinearProblem:
     """Update a Quad problem linearization using a LR Sinkhorn."""
     return linear_problem.LinearProblem(
-        self.update_lr_geom(lr_sink, remove_scale=remove_scale),
+        self.update_lr_geom(lr_sink, relative_epsilon=relative_epsilon),
        self.a,
        self.b,
        tau_a=self.tau_a,
        tau_b=self.tau_b
     )
 
-  def _fused_cost_matrix(self,
-                         unscale: bool = False) -> Union[float, jnp.ndarray]:
+  @property
+  def _fused_cost_matrix(self) -> Union[float, jnp.ndarray]:
     if not self.is_fused:
       return 0.0
 
     geom_xy = self.geom_xy
-    if unscale:
-      geom_xy = geom_xy.set_scale_cost(1.0)
     if isinstance(geom_xy, pointcloud.PointCloud) and geom_xy.is_online:
       return geom_xy._compute_cost_matrix() * geom_xy.inv_scale_cost
diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py
index 037da1482..ca3c24090 100644
--- a/src/ott/solvers/quadratic/gromov_wasserstein.py
+++ b/src/ott/solvers/quadratic/gromov_wasserstein.py
@@ -164,13 +164,8 @@ class GromovWasserstein(was_solver.WassersteinSolver):
     warm_start: Whether to initialize (low-rank) Sinkhorn calls using values
       from the previous iteration. If `None`, warm starts are not used for
      standard Sinkhorn, but used for low-rank Sinkhorn.
-    unscale_last_linearization: Whether to remove any scaling from the
-      cost matrices of the last linearization stored in
-      :attr:`~ott.solvers.quadratic.gromov_wasserstein.GWOutput.geom`.
-      This has the practical benefit that, while the OT coupling matrices
-      obtained with GW might have been computed by re-scaling cost matrices for
-      numerical stability, the last linearization stored in the geometry will be
-      unscaled and recomputed with the original cost values.
+    relative_epsilon: Whether to use relative epsilon in the linearized
+      geometry.
     quad_initializer: Quadratic initializer. If the solver is entropic,
       :class:`~ott.initializers.quadratic.initializers.QuadraticInitializer`
       is always used. Otherwise, the quadratic initializer wraps the low-rank
@@ -194,7 +189,7 @@ def __init__(
       self,
       *args: Any,
       warm_start: Optional[bool] = None,
-      unscale_last_linearization: bool = False,
+      relative_epsilon: Optional[bool] = None,
       quad_initializer: Optional[
           Union[Literal["random", "rank2", "k-means", "generalized-k-means"],
                 quad_initializers.BaseQuadraticInitializer]] = None,
@@ -204,7 +199,7 @@ def __init__(
   ):
     super().__init__(*args, **kwargs)
     self._warm_start = warm_start
-    self.unscale_last_linearization = unscale_last_linearization
+    self.relative_epsilon = relative_epsilon
     self.quad_initializer = quad_initializer
     self.progress_fn = progress_fn
     self.kwargs_init = {} if kwargs_init is None else kwargs_init
@@ -236,21 +231,27 @@ def __call__(
 
     if init is None:
       initializer = self.create_initializer(prob)
-      init = initializer(prob, epsilon=self.epsilon, rng=rng1, **kwargs)
+      init = initializer(
+          prob,
+          epsilon=self.epsilon,
+          rng=rng1,
+          relative_epsilon=self.relative_epsilon,
+          **kwargs
+      )
 
     out = iterations(self, prob, init, rng2)
     # TODO(lpapaxanthoos): remove stop_gradient when using backprop
     if self.is_low_rank:
       linearization = prob.update_lr_linearization(
           jax.lax.stop_gradient(out.linear_state),
-          remove_scale=self.unscale_last_linearization
+          relative_epsilon=self.relative_epsilon,
       )
     else:
       linearization = prob.update_linearization(
           jax.lax.stop_gradient(out.linear_state),
           epsilon=self.epsilon,
           old_transport_mass=jax.lax.stop_gradient(out.old_transport_mass),
-          remove_scale=self.unscale_last_linearization,
+          relative_epsilon=self.relative_epsilon,
       )
     linear_state = out.linear_state.set_cost(linearization, True, True)
 
@@ -366,7 +367,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
     children, aux_data = super().tree_flatten()
     aux_data["warm_start"] = self._warm_start
     aux_data["progress_fn"] = self.progress_fn
-    aux_data["unscale_last_linearization"] = self.unscale_last_linearization
+    aux_data["relative_epsilon"] = self.relative_epsilon
     aux_data["quad_initializer"] = self.quad_initializer
     aux_data["kwargs_init"] = self.kwargs_init
     return children, aux_data
@@ -396,12 +397,17 @@ def body_fn(
       rng = state.rngs[iteration]
       init = (lin_state.q, lin_state.r,
               lin_state.g) if solver.warm_start else (None, None, None)
-      linear_pb = prob.update_lr_linearization(state.linear_state)
+      linear_pb = prob.update_lr_linearization(
+          state.linear_state, relative_epsilon=solver.relative_epsilon
+      )
       out = solver.linear_ot_solver(linear_pb, init=init, rng=rng)
     else:
       init = (lin_state.f, lin_state.g) if solver.warm_start else (None, None)
       linear_pb = prob.update_linearization(
-          lin_state, solver.epsilon, state.old_transport_mass
+          lin_state,
+          solver.epsilon,
+          state.old_transport_mass,
+          relative_epsilon=solver.relative_epsilon,
       )
       out = solver.linear_ot_solver(linear_pb, init=init)
 
diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py
index 9b6a6cec5..c43cc92d1 100644
--- a/tests/solvers/quadratic/gw_test.py
+++ b/tests/solvers/quadratic/gw_test.py
@@ -394,36 +394,31 @@ def test_gw_lr_warm_start_helps(self, rng: jax.random.PRNGKeyArray):
     with pytest.raises(AssertionError):
       np.testing.assert_allclose(out_cold.matrix, out_warm.matrix)
 
-  @pytest.mark.parametrize("scale_cost", [1.15, 2.3])
-  def test_unscale_last_linearization(
-      self, rng: jax.random.PRNGKeyArray, scale_cost: float
+  @pytest.mark.parametrize("scale_cost", [1.0, "mean"])
+  def test_relative_epsilon(
+      self,
+      rng: jax.random.PRNGKeyArray,
+      scale_cost: Union[float, str],
   ):
+    eps = 1e-2
     rng1, rng2 = jax.random.split(rng, 2)
-    n, m = 7, 16
-    rtol = atol = 1e-3
-
     geom_x = pointcloud.PointCloud(
-        jax.random.normal(rng1, (n, 2)), scale_cost=scale_cost
+        jax.random.normal(rng1, (49, 5)), scale_cost=scale_cost
     )
     geom_y = pointcloud.PointCloud(
-        jax.random.normal(rng2, (m, 6)), scale_cost=scale_cost
+        jax.random.normal(rng2, (78, 6)), scale_cost=scale_cost
     )
 
-    # hold true only when `scale_cost` is the same for both geometries
-    expected = 1.0 / (geom_x.inv_scale_cost * geom_y.inv_scale_cost)
-
     prob = quadratic_problem.QuadraticProblem(geom_x, geom_y)
-    solver_scaled = gromov_wasserstein.GromovWasserstein(
-        unscale_last_linearization=False
-    )
-    solver_unscaled = gromov_wasserstein.GromovWasserstein(
-        unscale_last_linearization=True
+
+    solver = gromov_wasserstein.GromovWasserstein(
+        epsilon=eps, relative_epsilon=True
     )
-    out_scaled = solver_scaled(prob)
-    out_unscaled = solver_unscaled(prob)
-    actual = out_unscaled.primal_cost / out_scaled.primal_cost
+    out = solver(prob)
 
-    np.testing.assert_allclose(
-        out_scaled.matrix, out_unscaled.matrix, rtol=rtol, atol=atol
-    )
-    np.testing.assert_allclose(expected, actual, rtol=rtol, atol=atol)
+    if scale_cost == 1.0:
+      assert 40 < out.reg_gw_cost < 41
+      assert 38 < out.primal_cost < 39
+    else:
+      assert 0.215 < out.reg_gw_cost < 0.22
+      assert 0.19 < out.primal_cost < 0.20
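
Below the patch, a minimal usage sketch (not part of the diff) of the new `relative_epsilon` flag, mirroring the updated `test_relative_epsilon` test. The import paths, class names, and output attributes come from the files touched above; the array shapes, epsilon value, and PRNG seed are illustrative assumptions only.

import jax

from ott.geometry import pointcloud
from ott.problems.quadratic import quadratic_problem
from ott.solvers.quadratic import gromov_wasserstein

rng1, rng2 = jax.random.split(jax.random.PRNGKey(0), 2)
# Two point clouds of different sizes; `scale_cost="mean"` rescales each cost
# matrix, which is exactly the setting where a relative epsilon matters.
geom_x = pointcloud.PointCloud(
    jax.random.normal(rng1, (49, 5)), scale_cost="mean"
)
geom_y = pointcloud.PointCloud(
    jax.random.normal(rng2, (78, 6)), scale_cost="mean"
)
prob = quadratic_problem.QuadraticProblem(geom_x, geom_y)

# `relative_epsilon=True` is forwarded to every linearized geometry created by
# the solver, so `epsilon=1e-2` is interpreted relative to the scale of each
# linearized cost matrix (see ott.geometry.geometry.Geometry) rather than as
# an absolute value.
solver = gromov_wasserstein.GromovWasserstein(
    epsilon=1e-2, relative_epsilon=True
)
out = solver(prob)
print(out.reg_gw_cost, out.primal_cost)

With the default `relative_epsilon=None`, the flag is simply passed through and the linearized geometries keep :class:`~ott.geometry.geometry.Geometry`'s default epsilon handling, so existing callers should be unaffected.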