Skip to content

Commit

Permalink
Feature/quadratic solve (#433)
Browse files Browse the repository at this point in the history
* Add `solve` function for quadratic problems

* Update docs

* Fix typo in notebook

* Fix implicit diff passing in LR Sinkhorn

* Fix typo

* Clean `GWLoss` docs
  • Loading branch information
michalk8 authored Sep 12, 2023
1 parent cea562b commit 06b7428
Show file tree
Hide file tree
Showing 11 changed files with 168 additions and 145 deletions.
1 change: 1 addition & 0 deletions docs/initializers/linear.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Low-rank Sinkhorn Initializers
.. autosummary::
:toctree: _autosummary

initializers_lr.LRInitializer
initializers_lr.RandomInitializer
initializers_lr.Rank2Initializer
initializers_lr.KMeansInitializer
Expand Down
1 change: 1 addition & 0 deletions docs/problems/quadratic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ Costs
.. autosummary::
:toctree: _autosummary

quadratic_costs.GWLoss
quadratic_costs.make_square_loss
quadratic_costs.make_kl_loss
2 changes: 1 addition & 1 deletion docs/solvers/quadratic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Gromov-Wasserstein Solvers
.. autosummary::
:toctree: _autosummary

gromov_wasserstein.solve
solve
gromov_wasserstein.GromovWasserstein
gromov_wasserstein.GWOutput
gromov_wasserstein_lr.LRGromovWasserstein
Expand Down
41 changes: 36 additions & 5 deletions src/ott/problems/quadratic/quadratic_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
import jax.scipy as jsp

__all__ = ["make_square_loss", "make_kl_loss"]

Expand All @@ -24,23 +24,54 @@ class Loss(NamedTuple): # noqa: D101
is_linear: bool


class GWLoss(NamedTuple): # noqa: D101
class GWLoss(NamedTuple):
r"""Efficient decomposition of the Gromov-Wasserstein loss function.
The loss function :math:`L` is assumed to match the form given in eq. 5. of
:cite:`peyre:16`:
.. math::
L(x, y) = f_1(x) + f_2(y) - h_1(x) h_2(y)
Args:
f1: First linear term.
f2: Second linear term.
h1: First quadratic term.
h2: Second quadratic term.
"""
f1: Loss
f2: Loss
h1: Loss
h2: Loss


def make_square_loss() -> GWLoss: # noqa: D103
def make_square_loss() -> GWLoss:
"""Squared Euclidean loss for Gromov-Wasserstein.
See Prop. 1 and Remark 1 of :cite:`peyre:16` for more information.
Returns:
The squared Euclidean loss.
"""
f1 = Loss(lambda x: x ** 2, is_linear=False)
f2 = Loss(lambda y: y ** 2, is_linear=False)
h1 = Loss(lambda x: x, is_linear=True)
h2 = Loss(lambda y: 2.0 * y, is_linear=True)
return GWLoss(f1, f2, h1, h2)


def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss: # noqa: D103
f1 = Loss(lambda x: -jax.scipy.special.entr(x) - x, is_linear=False)
def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss:
r"""Kullback-Leibler loss for Gromov-Wasserstein.
See Prop. 1 and Remark 1 of :cite:`peyre:16` for more information.
Args:
clipping_value: Value used to avoid :math:`\log(0)`.
Returns:
The KL loss.
"""
f1 = Loss(lambda x: -jsp.special.entr(x) - x, is_linear=False)
f2 = Loss(lambda y: y, is_linear=True)
h1 = Loss(lambda x: x, is_linear=True)
h2 = Loss(lambda y: jnp.log(jnp.clip(y, clipping_value)), is_linear=False)
Expand Down
35 changes: 14 additions & 21 deletions src/ott/problems/quadratic/quadratic_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,16 @@ class QuadraticProblem:
is assumed to match the form given in eq. 5., with our notations:
.. math::
L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y)
L(x, y) = f_1(x) + f_2(y) - h_1(x) h_2(y)
Args:
geom_xx: Ground geometry of the first space.
geom_yy: Ground geometry of the second space.
geom_xy: Geometry defining the linear penalty term for
Fused Gromov-Wasserstein. If `None`, the problem reduces to a plain
Gromov-Wasserstein problem.
fused_penalty: multiplier of the linear term in Fused Gromov-Wasserstein,
i.e. problem = purely quadratic + fused_penalty * linear problem.
Ignored if ``geom_xy`` is not specified.
fused Gromov-Wasserstein :cite:`vayer:19`. If :obj:`None`, the problem
reduces to a plain Gromov-Wasserstein problem :cite:`peyre:16`.
fused_penalty: Multiplier of the linear term in fused Gromov-Wasserstein,
i.e. ``problem = purely quadratic + fused_penalty * linear problem``.
scale_cost: option to rescale the cost matrices:
- if :obj:`True`, use the default for each geometry.
Expand All @@ -62,19 +60,14 @@ class QuadraticProblem:
:class:`~ott.geometry.pointcloud.PointCloud`.
- if :obj:`None`, do not scale the cost matrices.
a: array representing the probability weights of the samples
from ``geom_xx``. If `None`, it will be uniform.
b: array representing the probability weights of the samples
from ``geom_yy``. If `None`, it will be uniform.
loss: a 2-tuple of 2-tuples of Callable. The first tuple is the linear
part of the loss. The second one is the quadratic part (quad1, quad2).
By default, the loss is set as the 4 functions representing the squared
Euclidean loss, and this property is taken advantage of in subsequent
computations. Alternatively, KL loss can be specified in no less optimized
way.
tau_a: if `< 1.0`, defines how much unbalanced the problem is on
a: The first marginal. If :obj:`None`, it will be uniform.
b: The second marginal. If :obj:`None`, it will be uniform.
loss: Gromov-Wasserstein loss function, see
:class:`~ott.problems.quadratic.quadratic_costs.GWLoss` for more
information.
tau_a: If :math:`< 1.0`, defines how much unbalanced the problem is on
the first marginal.
tau_b: if `< 1.0`, defines how much unbalanced the problem is on
tau_b: If :math:`< 1.0`, defines how much unbalanced the problem is on
the second marginal.
gw_unbalanced_correction: Whether the unbalanced version of
:cite:`sejourne:21` is used. Otherwise, ``tau_a`` and ``tau_b``
Expand All @@ -101,8 +94,8 @@ def __init__(
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None,
loss: Union[Literal["sqeucl", "kl"], quadratic_costs.GWLoss] = "sqeucl",
tau_a: Optional[float] = 1.0,
tau_b: Optional[float] = 1.0,
tau_a: float = 1.0,
tau_b: float = 1.0,
gw_unbalanced_correction: bool = True,
ranks: Union[int, Tuple[int, ...]] = -1,
tolerances: Union[float, Tuple[float, ...]] = 1e-2,
Expand Down
19 changes: 15 additions & 4 deletions src/ott/solvers/linear/_solve.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

import jax.numpy as jnp

from ott.geometry import geometry
from ott.problems.linear import linear_problem

#if TYPE_CHECKING:
from ott.solvers.linear import sinkhorn, sinkhorn_lr

__all__ = ["solve"]
Expand All @@ -23,7 +34,7 @@ def solve(
"""Solve linear regularized OT problem using Sinkhorn iterations.
Args:
geom: The ground geometry cost of the linear problem.
geom: The ground geometry of the linear problem.
a: The first marginal. If :obj:`None`, it will be uniform.
b: The second marginal. If :obj:`None`, it will be uniform.
tau_a: If :math:`< 1`, defines how much unbalanced the problem is
Expand All @@ -36,7 +47,7 @@ def solve(
kwargs: Keyword arguments for
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` or
:class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn`,
depending on ``rank``.
depending on the ``rank``.
Returns:
The Sinkhorn output.
Expand Down
14 changes: 4 additions & 10 deletions src/ott/solvers/linear/sinkhorn_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Jax implementation of the Low-Rank Sinkhorn algorithm."""
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -272,7 +271,7 @@ def _inv_g(self) -> jnp.ndarray:

@jax.tree_util.register_pytree_node_class
class LRSinkhorn(sinkhorn.Sinkhorn):
r"""A Low-Rank Sinkhorn solver for linear reg-OT problems.
r"""Low-Rank Sinkhorn solver for linear reg-OT problems.
The algorithm is described in :cite:`scetbon:21` and the implementation
contained here is adapted from `LOT <https://github.com/meyerscetbon/LOT>`_.
Expand All @@ -288,15 +287,12 @@ class LRSinkhorn(sinkhorn.Sinkhorn):
described in :cite:`scetbon:22b`.
epsilon: Entropic regularization added on top of low-rank problem.
initializer: How to initialize the :math:`Q`, :math:`R` and :math:`g`
factors. Valid options are `'random'`, `'rank2'`, `'k-means'`, and
`'generalized-k-means'`.
factors.
lse_mode: Whether to run computations in LSE or kernel mode.
inner_iterations: Number of inner iterations used by the algorithm before
re-evaluating progress.
use_danskin: Use Danskin theorem to evaluate gradient of objective w.r.t.
input parameters. Only `True` handled at this moment.
implicit_diff: Whether to use implicit differentiation. Currently, only
``implicit_diff = False`` is implemented.
progress_fn: callback function which gets called during the Sinkhorn
iterations, so the user can display the error at each iteration,
e.g., using a progress bar. See :func:`~ott.utils.default_progress_fn`
Expand All @@ -316,25 +312,23 @@ def __init__(
rank: int,
gamma: float = 10.,
gamma_rescale: bool = True,
epsilon: float = 0.,
epsilon: float = 0.0,
initializer: Union[Literal["random", "rank2", "k-means",
"generalized-k-means"],
initializers_lr.LRInitializer] = "random",
lse_mode: bool = True,
inner_iterations: int = 10,
use_danskin: bool = True,
implicit_diff: bool = False,
kwargs_dys: Optional[Mapping[str, Any]] = None,
kwargs_init: Optional[Mapping[str, Any]] = None,
progress_fn: Optional[ProgressCallbackFn_t] = None,
**kwargs: Any,
):
assert not implicit_diff, "Implicit diff. not yet implemented."
kwargs["implicit_diff"] = None # not yet implemented
super().__init__(
lse_mode=lse_mode,
inner_iterations=inner_iterations,
use_danskin=use_danskin,
implicit_diff=implicit_diff,
**kwargs
)
self.rank = rank
Expand Down
1 change: 1 addition & 0 deletions src/ott/solvers/quadratic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import gromov_wasserstein, gromov_wasserstein_lr, gw_barycenter
from ._solve import solve
91 changes: 91 additions & 0 deletions src/ott/solvers/quadratic/_solve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal, Optional, Union

import jax.numpy as jnp

from ott.geometry import geometry
from ott.problems.quadratic import quadratic_costs, quadratic_problem
from ott.solvers.quadratic import gromov_wasserstein as gw
from ott.solvers.quadratic import gromov_wasserstein_lr as lrgw

__all__ = ["solve"]


def solve(
geom_xx: geometry.Geometry,
geom_yy: geometry.Geometry,
geom_xy: Optional[geometry.Geometry] = None,
fused_penalty: float = 1.0,
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None,
tau_a: float = 1.0,
tau_b: float = 1.0,
loss: Union[Literal["sqeucl", "kl"], quadratic_costs.GWLoss] = "sqeucl",
gw_unbalanced_correction: bool = True,
rank: int = -1,
**kwargs: Any,
) -> Union[gw.GWOutput, lrgw.LRGWOutput]:
"""Solve quadratic regularized OT problem using a Gromov-Wasserstein solver.
Args:
geom_xx: Ground geometry of the first space.
geom_yy: Ground geometry of the second space.
geom_xy: Geometry defining the linear penalty term for
fused Gromov-Wasserstein :cite:`vayer:19`. If :obj:`None`, the problem
reduces to a plain Gromov-Wasserstein problem :cite:`peyre:16`.
fused_penalty: Multiplier of the linear term in fused Gromov-Wasserstein,
i.e. ``problem = purely quadratic + fused_penalty * linear problem``.
a: The first marginal. If :obj:`None`, it will be uniform.
b: The second marginal. If :obj:`None`, it will be uniform.
tau_a: If :math:`< 1`, defines how much unbalanced the problem is
on the first marginal.
tau_b: If :math:`< 1`, defines how much unbalanced the problem is
on the second marginal.
loss: Gromov-Wasserstein loss function, see
:class:`~ott.problems.quadratic.quadratic_costs.GWLoss` for more
information. If ``rank > 0``, ``'sqeucl'`` is always used.
gw_unbalanced_correction: Whether the unbalanced version of
:cite:`sejourne:21` is used. Otherwise, ``tau_a`` and ``tau_b``
only affect the resolution of the linearization of the GW problem
in the inner loop. Only used when ``rank = -1``.
rank: Rank constraint on the coupling to minimize the quadratic OT problem
:cite:`scetbon:22`. If :math:`-1`, no rank constraint is used.
kwargs: Keyword arguments for
:class:`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein` or
:class:`~ott.solvers.quadratic.gromov_wasserstein_lr.LRGromovWasserstein`,
depending on the ``rank``
Returns:
The Gromov-Wasserstein output.
"""
prob = quadratic_problem.QuadraticProblem(
geom_xx=geom_xx,
geom_yy=geom_yy,
geom_xy=geom_xy,
fused_penalty=fused_penalty,
a=a,
b=b,
tau_a=tau_a,
tau_b=tau_b,
loss=loss,
gw_unbalanced_correction=gw_unbalanced_correction
)

if rank > 0:
solver = lrgw.LRGromovWasserstein(rank=rank, **kwargs)
else:
solver = gw.GromovWasserstein(rank=rank, **kwargs)

return solver(prob)
Loading

0 comments on commit 06b7428

Please sign in to comment.