Feature/quadratic solve #433

Merged: 7 commits, Sep 12, 2023
1 change: 1 addition & 0 deletions docs/initializers/linear.rst
@@ -20,6 +20,7 @@ Low-rank Sinkhorn Initializers
.. autosummary::
:toctree: _autosummary

initializers_lr.LRInitializer
initializers_lr.RandomInitializer
initializers_lr.Rank2Initializer
initializers_lr.KMeansInitializer
1 change: 1 addition & 0 deletions docs/problems/quadratic.rst
@@ -20,5 +20,6 @@ Costs
.. autosummary::
:toctree: _autosummary

quadratic_costs.GWLoss
quadratic_costs.make_square_loss
quadratic_costs.make_kl_loss
2 changes: 1 addition & 1 deletion docs/solvers/quadratic.rst
@@ -12,7 +12,7 @@ Gromov-Wasserstein Solvers
.. autosummary::
:toctree: _autosummary

gromov_wasserstein.solve
solve
gromov_wasserstein.GromovWasserstein
gromov_wasserstein.GWOutput
gromov_wasserstein_lr.LRGromovWasserstein
19 changes: 2 additions & 17 deletions docs/tutorials/notebooks/gromov_wasserstein_multiomics.ipynb
@@ -1,7 +1,6 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "BB8VjJrVsuuG"
@@ -21,7 +20,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "BYAvtCwhsuuJ"
@@ -31,7 +29,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "pJXAdeCCsuuJ"
@@ -134,7 +131,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "m9uzbXpZsuuT"
@@ -187,7 +183,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "brfhPBHWsuuV"
@@ -214,7 +209,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "KtvRj6RosuuY"
@@ -224,7 +218,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "gmS0MXbwsuuZ"
@@ -234,7 +227,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "ZbiuslH0suua"
@@ -298,7 +290,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "LNp6t08Ysuuc"
@@ -347,7 +338,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "Lo-CPjPWsuue"
@@ -359,11 +349,10 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We provide the Average FOSCTTM to align `X` (chromatinaccessibility domain) to `Y` (gene expression domain) for each implementation:"
"We provide the Average FOSCTTM to align `X` (chromatin accessibility domain) to `Y` (gene expression domain) for each implementation:"
]
},
{
@@ -415,7 +404,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "UtcqjKXC26uG"
@@ -425,7 +413,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "svoMyh1dwPvg"
@@ -539,7 +526,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "ZruK0o4Gsuul"
@@ -4879,7 +4865,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "7y-DPRLCsuuo"
@@ -16909,7 +16894,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
41 changes: 36 additions & 5 deletions src/ott/problems/quadratic/quadratic_costs.py
@@ -13,8 +13,8 @@
# limitations under the License.
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
import jax.scipy as jsp

__all__ = ["make_square_loss", "make_kl_loss"]

@@ -24,23 +24,54 @@ class Loss(NamedTuple): # noqa: D101
is_linear: bool


class GWLoss(NamedTuple): # noqa: D101
class GWLoss(NamedTuple):
r"""Efficient decomposition of the Gromov-Wasserstein loss function.

The loss function :math:`L` is assumed to match the form given in eq. 5 of
:cite:`peyre:16`:

.. math::
L(x, y) = f_1(x) + f_2(y) - h_1(x) h_2(y)

Args:
f1: First linear term.
f2: Second linear term.
h1: First quadratic term.
h2: Second quadratic term.
"""
f1: Loss
f2: Loss
h1: Loss
h2: Loss


def make_square_loss() -> GWLoss: # noqa: D103
def make_square_loss() -> GWLoss:
"""Squared Euclidean loss for Gromov-Wasserstein.

See Prop. 1 and Remark 1 of :cite:`peyre:16` for more information.

Returns:
The squared Euclidean loss.
"""
f1 = Loss(lambda x: x ** 2, is_linear=False)
f2 = Loss(lambda y: y ** 2, is_linear=False)
h1 = Loss(lambda x: x, is_linear=True)
h2 = Loss(lambda y: 2.0 * y, is_linear=True)
return GWLoss(f1, f2, h1, h2)


def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss: # noqa: D103
f1 = Loss(lambda x: -jax.scipy.special.entr(x) - x, is_linear=False)
def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss:
r"""Kullback-Leibler loss for Gromov-Wasserstein.

See Prop. 1 and Remark 1 of :cite:`peyre:16` for more information.

Args:
clipping_value: Value used to avoid :math:`\log(0)`.

Returns:
The KL loss.
"""
f1 = Loss(lambda x: -jsp.special.entr(x) - x, is_linear=False)
f2 = Loss(lambda y: y, is_linear=True)
h1 = Loss(lambda x: x, is_linear=True)
h2 = Loss(lambda y: jnp.log(jnp.clip(y, clipping_value)), is_linear=False)
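The two decompositions defined in this file can be checked numerically. A minimal plain-Python sketch (not part of the PR; it mirrors the lambdas in the diff above, with `-entr(x) - x` expanded to `x*log(x) - x`):

```python
import math

# Squared Euclidean loss: f1(x) + f2(y) - h1(x) * h2(y) == (x - y) ** 2.
f1 = lambda x: x ** 2
f2 = lambda y: y ** 2
h1 = lambda x: x
h2 = lambda y: 2.0 * y
x, y = 3.0, 1.0
assert abs((f1(x) + f2(y) - h1(x) * h2(y)) - (x - y) ** 2) < 1e-12

# KL loss: entr(x) = -x * log(x), so f1(x) = -entr(x) - x = x * log(x) - x.
# The decomposition then recovers the generalized KL divergence
# x * log(x / y) - x + y pointwise (for positive x, y).
g1 = lambda x: x * math.log(x) - x
g2 = lambda y: y
k1 = lambda x: x
k2 = lambda y: math.log(y)
kl = x * math.log(x / y) - x + y
assert abs((g1(x) + g2(y) - k1(x) * k2(y)) - kl) < 1e-12
```

This separability is what makes the decomposition "efficient": the quartic GW sum factors into matrix products of the `f` and `h` images of the cost matrices.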
35 changes: 14 additions & 21 deletions src/ott/problems/quadratic/quadratic_problem.py
@@ -41,18 +41,16 @@ class QuadraticProblem:
is assumed to match the form given in eq. 5, in our notation:

.. math::

L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y)
L(x, y) = f_1(x) + f_2(y) - h_1(x) h_2(y)

Args:
geom_xx: Ground geometry of the first space.
geom_yy: Ground geometry of the second space.
geom_xy: Geometry defining the linear penalty term for
Fused Gromov-Wasserstein. If `None`, the problem reduces to a plain
Gromov-Wasserstein problem.
fused_penalty: multiplier of the linear term in Fused Gromov-Wasserstein,
i.e. problem = purely quadratic + fused_penalty * linear problem.
Ignored if ``geom_xy`` is not specified.
fused Gromov-Wasserstein :cite:`vayer:19`. If :obj:`None`, the problem
reduces to a plain Gromov-Wasserstein problem :cite:`peyre:16`.
fused_penalty: Multiplier of the linear term in fused Gromov-Wasserstein,
i.e. ``problem = purely quadratic + fused_penalty * linear problem``.
scale_cost: option to rescale the cost matrices:

- if :obj:`True`, use the default for each geometry.
@@ -62,19 +60,14 @@
:class:`~ott.geometry.pointcloud.PointCloud`.
- if :obj:`None`, do not scale the cost matrices.

a: array representing the probability weights of the samples
from ``geom_xx``. If `None`, it will be uniform.
b: array representing the probability weights of the samples
from ``geom_yy``. If `None`, it will be uniform.
loss: a 2-tuple of 2-tuples of Callable. The first tuple is the linear
part of the loss. The second one is the quadratic part (quad1, quad2).
By default, the loss is set as the 4 functions representing the squared
Euclidean loss, and this property is taken advantage of in subsequent
computations. Alternatively, KL loss can be specified in no less optimized
way.
tau_a: if `< 1.0`, defines how much unbalanced the problem is on
a: The first marginal. If :obj:`None`, it will be uniform.
b: The second marginal. If :obj:`None`, it will be uniform.
loss: Gromov-Wasserstein loss function, see
:class:`~ott.problems.quadratic.quadratic_costs.GWLoss` for more
information.
tau_a: If :math:`< 1.0`, defines how much unbalanced the problem is on
the first marginal.
tau_b: if `< 1.0`, defines how much unbalanced the problem is on
tau_b: If :math:`< 1.0`, defines how much unbalanced the problem is on
the second marginal.
gw_unbalanced_correction: Whether the unbalanced version of
:cite:`sejourne:21` is used. Otherwise, ``tau_a`` and ``tau_b``
@@ -101,8 +94,8 @@ def __init__(
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None,
loss: Union[Literal["sqeucl", "kl"], quadratic_costs.GWLoss] = "sqeucl",
tau_a: Optional[float] = 1.0,
tau_b: Optional[float] = 1.0,
tau_a: float = 1.0,
tau_b: float = 1.0,
gw_unbalanced_correction: bool = True,
ranks: Union[int, Tuple[int, ...]] = -1,
tolerances: Union[float, Tuple[float, ...]] = 1e-2,
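For context on the objective this problem class encodes: in the balanced, unfused case it is :math:`\sum_{i,j,k,l} L(C^x_{ik}, C^y_{jl}) P_{ij} P_{kl}`. A hedged brute-force sketch in plain Python (illustrative only, with the squared Euclidean loss written directly; ott evaluates this through the decomposition above, never via this :math:`O(n^2 m^2)` sum):

```python
def gw_objective(Cx, Cy, P):
    # Brute-force sum_{i,j,k,l} (Cx[i][k] - Cy[j][l])**2 * P[i][j] * P[k][l].
    n, m = len(Cx), len(Cy)
    return sum(
        (Cx[i][k] - Cy[j][l]) ** 2 * P[i][j] * P[k][l]
        for i in range(n) for k in range(n)
        for j in range(m) for l in range(m)
    )

Cx = [[0.0, 1.0], [1.0, 0.0]]    # pairwise costs of the first space
P_id = [[0.5, 0.0], [0.0, 0.5]]  # plan matching point i to point i

# Identical spaces under the matching plan incur zero distortion.
assert abs(gw_objective(Cx, Cx, P_id)) < 1e-12

# Stretching the second space's distances makes the objective positive.
Cy = [[0.0, 2.0], [2.0, 0.0]]
assert abs(gw_objective(Cx, Cy, P_id) - 0.5) < 1e-12
```

The objective is invariant to how each space is indexed or embedded, which is why GW can align spaces whose points live in different ambient domains.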
19 changes: 15 additions & 4 deletions src/ott/solvers/linear/_solve.py
@@ -1,11 +1,22 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

import jax.numpy as jnp

from ott.geometry import geometry
from ott.problems.linear import linear_problem

#if TYPE_CHECKING:
from ott.solvers.linear import sinkhorn, sinkhorn_lr

__all__ = ["solve"]
@@ -23,7 +34,7 @@ def solve(
"""Solve linear regularized OT problem using Sinkhorn iterations.

Args:
geom: The ground geometry cost of the linear problem.
geom: The ground geometry of the linear problem.
a: The first marginal. If :obj:`None`, it will be uniform.
b: The second marginal. If :obj:`None`, it will be uniform.
tau_a: If :math:`< 1`, defines how much unbalanced the problem is
@@ -36,7 +47,7 @@
kwargs: Keyword arguments for
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` or
:class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn`,
depending on ``rank``.
depending on the ``rank``.

Returns:
The Sinkhorn output.
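The wrapper dispatches to Sinkhorn, whose fixed-point iteration can be sketched in a few lines of plain Python (kernel mode, balanced case; the names here are illustrative, not ott's API):

```python
import math

def sinkhorn_plan(cost, a, b, epsilon=0.1, n_iters=100):
    """Alternately rescale the Gibbs kernel K = exp(-cost / epsilon)
    until the plan diag(u) @ K @ diag(v) has marginals a and b."""
    n, m = len(cost), len(cost[0])
    K = [[math.exp(-cost[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iters):
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]

P = sinkhorn_plan([[0.0, 1.0], [1.0, 0.0]], [0.5, 0.5], [0.5, 0.5])
# Column marginals are exact right after a v-update; rows converge quickly.
assert abs(P[0][0] + P[1][0] - 0.5) < 1e-9
assert abs(P[0][0] + P[0][1] - 0.5) < 1e-6
```

The real solver works in log space when `lse_mode=True` for numerical stability at small `epsilon`; this kernel-mode sketch only conveys the alternating-scaling structure.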
14 changes: 4 additions & 10 deletions src/ott/solvers/linear/sinkhorn_lr.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Jax implementation of the Low-Rank Sinkhorn algorithm."""
from typing import (
Any,
Callable,
@@ -272,7 +271,7 @@ def _inv_g(self) -> jnp.ndarray:

@jax.tree_util.register_pytree_node_class
class LRSinkhorn(sinkhorn.Sinkhorn):
r"""A Low-Rank Sinkhorn solver for linear reg-OT problems.
r"""Low-Rank Sinkhorn solver for linear reg-OT problems.

The algorithm is described in :cite:`scetbon:21` and the implementation
contained here is adapted from `LOT <https://github.com/meyerscetbon/LOT>`_.
@@ -288,15 +287,12 @@ class LRSinkhorn(sinkhorn.Sinkhorn):
described in :cite:`scetbon:22b`.
epsilon: Entropic regularization added on top of low-rank problem.
initializer: How to initialize the :math:`Q`, :math:`R` and :math:`g`
factors. Valid options are `'random'`, `'rank2'`, `'k-means'`, and
`'generalized-k-means'`.
factors.
lse_mode: Whether to run computations in LSE or kernel mode.
inner_iterations: Number of inner iterations used by the algorithm before
re-evaluating progress.
use_danskin: Use Danskin theorem to evaluate gradient of objective w.r.t.
input parameters. Only `True` handled at this moment.
implicit_diff: Whether to use implicit differentiation. Currently, only
``implicit_diff = False`` is implemented.
progress_fn: callback function which gets called during the Sinkhorn
iterations, so the user can display the error at each iteration,
e.g., using a progress bar. See :func:`~ott.utils.default_progress_fn`
@@ -316,25 +312,23 @@ def __init__(
rank: int,
gamma: float = 10.,
gamma_rescale: bool = True,
epsilon: float = 0.,
epsilon: float = 0.0,
initializer: Union[Literal["random", "rank2", "k-means",
"generalized-k-means"],
initializers_lr.LRInitializer] = "random",
lse_mode: bool = True,
inner_iterations: int = 10,
use_danskin: bool = True,
implicit_diff: bool = False,
kwargs_dys: Optional[Mapping[str, Any]] = None,
kwargs_init: Optional[Mapping[str, Any]] = None,
progress_fn: Optional[ProgressCallbackFn_t] = None,
**kwargs: Any,
):
assert not implicit_diff, "Implicit diff. not yet implemented."
kwargs["implicit_diff"] = None # not yet implemented
super().__init__(
lse_mode=lse_mode,
inner_iterations=inner_iterations,
use_danskin=use_danskin,
implicit_diff=implicit_diff,
**kwargs
)
self.rank = rank
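The point of the :math:`Q`, :math:`R`, :math:`g` factors mentioned in the docstring above is that the plan is stored as :math:`P = Q \,\mathrm{diag}(1/g)\, R^\top` (see :cite:`scetbon:21`), so it is never materialized. A hedged plain-Python sketch of why that matters (names are illustrative, not ott's API): applying :math:`P` to a vector costs :math:`O((n+m)r)` instead of :math:`O(nm)`.

```python
def apply_plan(Q, R, g, x):
    """Compute P @ x with P = Q @ diag(1/g) @ R.T, without forming P.

    Q is n x r, R is m x r, g has length r, x has length m.
    """
    r = len(g)
    # t = (R.T @ x) / g, an r-vector: O(m * r) work.
    t = [sum(R[j][k] * x[j] for j in range(len(R))) / g[k] for k in range(r)]
    # Q @ t, an n-vector: O(n * r) work.
    return [sum(Q[i][k] * t[k] for k in range(r)) for i in range(len(Q))]

# Trivial rank-2 factors acting as the identity on a 2-vector.
Q = [[1.0, 0.0], [0.0, 1.0]]
R = [[1.0, 0.0], [0.0, 1.0]]
g = [1.0, 1.0]
assert apply_plan(Q, R, g, [2.0, 3.0]) == [2.0, 3.0]
```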
1 change: 1 addition & 0 deletions src/ott/solvers/quadratic/__init__.py
@@ -12,3 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import gromov_wasserstein, gromov_wasserstein_lr, gw_barycenter
from ._solve import solve