Skip to content

Commit

Permalink
LR Sinkhorn improvements (#111)
Browse files Browse the repository at this point in the history
* test

* update lr-sinkhorn

* restored_branch

* check

* review

* circular fixed

* update review

* Fix bugs in `LRSinkhorn`

* Use new `k-means` implementation

* Fix linter

* Refactor `LRSinkhorn` initializers

* Use `if` for `is_entropic`, remove dead variables

* Slightly improve types

* Do not use stateful `gamma`

* Fix typo in tests

* Fix using `state.gamma` instead of `self.gamma`

* Fix point cloud size in notebook

* Add assertion to k-means

* Use `jax.lax.cond` instead of `jax.numpy.where`

* Change convergence criterion

* Use safe log

* Fix more tests

* Fix tests

* Fix `tree_flatten` in `KMeansInitializer`

* Fix defaults, change `rank_2` -> `rank2`

* Simplify `apply`

* Update TODOs

* Update docs, make `lr_costs` private

* Increase tolerance in failing test

* Update LR notebook

* Address comments

* Remove LR Sinkhorn notebook from testing, too slow

Co-authored-by: Michal Klein
  • Loading branch information
meyerscetbon authored Sep 1, 2022
1 parent 3d1c7d2 commit 1419513
Show file tree
Hide file tree
Showing 15 changed files with 760 additions and 263 deletions.
9 changes: 9 additions & 0 deletions docs/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ Low-Rank Sinkhorn
sinkhorn_lr.LRSinkhorn
sinkhorn_lr.LRSinkhornOutput

Low-Rank Sinkhorn Initializers
------------------------------
.. autosummary::
:toctree: _autosummary

initializers_lr.RandomInitializer
initializers_lr.Rank2Initializer
initializers_lr.KMeansInitializer

Barycenters (Entropic and LR)
-----------------------------
.. autosummary::
Expand Down
48 changes: 20 additions & 28 deletions docs/notebooks/LRSinkhorn.ipynb

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -564,3 +564,15 @@ @article{crane:13
keywords = {heat kernel, discrete differential geometry, geodesic distance, Digital geometry processing,
distance transform}
}

@misc{scetbon:22b,
doi = {10.48550/ARXIV.2205.12365},
url = {https://arxiv.org/abs/2205.12365},
author = {Scetbon, Meyer and Cuturi, Marco},
keywords = {Machine Learning (stat.ML), Machine Learning (cs.LG), FOS: Computer and information sciences,
FOS: Computer and information sciences},
title = {Low-rank Optimal Transport: Approximation, Statistics and Debiasing},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution 4.0 International}
}
2 changes: 2 additions & 0 deletions ott/core/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from ott.core import linear_problems
from ott.geometry import pointcloud

__all__ = ["DefaultInitializer", "GaussianInitializer", "SortingInitializer"]


@jax.tree_util.register_pytree_node_class
class SinkhornInitializer(ABC):
Expand Down
345 changes: 345 additions & 0 deletions ott/core/initializers_lr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,345 @@
import functools
from abc import ABC, abstractmethod
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union

import jax
from jax import numpy as jnp
from typing_extensions import Literal

from ott.core import linear_problems
from ott.geometry import low_rank, pointcloud

__all__ = ["RandomInitializer", "Rank2Initializer", "KMeansInitializer"]


@jax.tree_util.register_pytree_node_class
class LRSinkhornInitializer(ABC):
  """Base class for low-rank Sinkhorn initializers.

  Concrete subclasses provide the three low-rank factors
  :math:`Q`, :math:`R` and :math:`g` of the transport plan.

  Args:
    rank: Rank of the factorization.
  """

  def __init__(self, rank: int):
    self._rank = rank

  @abstractmethod
  def init_q(
      self,
      ot_prob: linear_problems.LinearProblem,
      key: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Initialize the low-rank factor :math:`Q`.

    Args:
      ot_prob: Linear OT problem.
      key: Random key for seeding.
      kwargs: Additional keyword arguments.

    Returns:
      Array of shape ``[n, rank]``.
    """

  @abstractmethod
  def init_r(
      self,
      ot_prob: linear_problems.LinearProblem,
      key: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Initialize the low-rank factor :math:`R`.

    Args:
      ot_prob: Linear OT problem.
      key: Random key for seeding.
      kwargs: Additional keyword arguments.

    Returns:
      Array of shape ``[m, rank]``.
    """

  @abstractmethod
  def init_g(
      self,
      ot_prob: linear_problems.LinearProblem,
      key: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Initialize the low-rank factor :math:`g`.

    Args:
      ot_prob: Linear OT problem.
      key: Random key for seeding.
      kwargs: Additional keyword arguments.

    Returns:
      Array of shape ``[rank,]``.
    """

  def __call__(
      self,
      ot_prob: Optional[linear_problems.LinearProblem],
      q: Optional[jnp.ndarray] = None,
      r: Optional[jnp.ndarray] = None,
      g: Optional[jnp.ndarray] = None,
      *,
      key: Optional[jnp.ndarray] = None,
      **kwargs: Any
  ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Initialize the factors :math:`Q`, :math:`R` and :math:`g`.

    Args:
      ot_prob: Linear OT problem.
      q: Factor of shape ``[n, rank]``. If `None`, it is computed
        using :meth:`init_q`.
      r: Factor of shape ``[m, rank]``. If `None`, it is computed
        using :meth:`init_r`.
      g: Factor of shape ``[rank,]``. If `None`, it is computed
        using :meth:`init_g`.
      key: Random key for seeding.
      kwargs: Additional keyword arguments for :meth:`init_q`, :meth:`init_r`
        and :meth:`init_g`.

    Returns:
      The factors :math:`Q`, :math:`R` and :math:`g`, respectively.
    """
    keys = jax.random.split(
        jax.random.PRNGKey(0) if key is None else key, 3
    )

    # `g` is computed first because `init_q`/`init_r` are seeded from it.
    if g is None:
      g = self.init_g(ot_prob, keys[0], **kwargs)
    if q is None:
      q = self.init_q(ot_prob, keys[1], init_g=g, **kwargs)
    if r is None:
      r = self.init_r(ot_prob, keys[2], init_g=g, **kwargs)

    n, m = ot_prob.a.shape[0], ot_prob.b.shape[0]
    assert g.shape == (self.rank,)
    assert q.shape == (n, self.rank)
    assert r.shape == (m, self.rank)

    return q, r, g

  @property
  def rank(self) -> int:
    """Rank of the transport matrix factorization."""
    return self._rank

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    # `rank` is the only state; expose it as a child for JAX transforms.
    return [self.rank], {}

  @classmethod
  def tree_unflatten(
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "LRSinkhornInitializer":
    return cls(*children, **aux_data)


@jax.tree_util.register_pytree_node_class
class RandomInitializer(LRSinkhornInitializer):
  """Low-rank Sinkhorn factorization using random factors.

  :math:`Q` and :math:`R` are drawn from the absolute value of a standard
  normal and rescaled row-wise to match the marginals; :math:`g` is drawn
  from a shifted uniform and normalized to a distribution.

  Args:
    rank: Rank of the factorization.
  """

  def init_q(
      self,
      ot_prob: linear_problems.LinearProblem,
      key: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Sample a non-negative ``[n, rank]`` factor with row sums ``ot_prob.a``."""
    del kwargs
    a = ot_prob.a
    sample = jnp.abs(jax.random.normal(key, (a.shape[0], self.rank)))
    # Turn each row into a distribution, then scale by the marginal.
    row_sums = jnp.sum(sample, axis=1, keepdims=True)
    return a[:, None] * (sample / row_sums)

  def init_r(
      self,
      ot_prob: linear_problems.LinearProblem,
      key: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Sample a non-negative ``[m, rank]`` factor with row sums ``ot_prob.b``."""
    del kwargs
    b = ot_prob.b
    sample = jnp.abs(jax.random.normal(key, (b.shape[0], self.rank)))
    # Turn each row into a distribution, then scale by the marginal.
    row_sums = jnp.sum(sample, axis=1, keepdims=True)
    return b[:, None] * (sample / row_sums)

  def init_g(
      self,
      ot_prob: linear_problems.LinearProblem,
      key: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Sample a random ``[rank,]`` distribution bounded away from zero."""
    del kwargs
    # The +1 shift keeps entries well away from zero before normalizing.
    sample = jnp.abs(jax.random.uniform(key, (self.rank,))) + 1.
    return sample / jnp.sum(sample)


@jax.tree_util.register_pytree_node_class
class Rank2Initializer(LRSinkhornInitializer):
  """Low-rank Sinkhorn factorization using rank-2 factors :cite:`scetbon:21`.

  Each factor is built deterministically as a convex combination of two
  rank-1 outer products chosen so that the factor's row sums equal the
  corresponding marginal and its column sums equal ``init_g``.

  Args:
    rank: Rank of the factorization.
  """

  def _compute_factor(
      self,
      ot_prob: linear_problems.LinearProblem,
      init_g: jnp.ndarray,
      *,
      which: Literal["q", "r"],
  ) -> jnp.ndarray:
    """Build the rank-2 factor for ``q`` (marginal ``a``) or ``r`` (marginal ``b``)."""
    a, b = ot_prob.a, ot_prob.b
    marginal = a if which == "q" else b
    n, r = marginal.shape[0], self.rank

    # Mixing weight, strictly below the smallest entry of all marginals so
    # that the residuals `g2` and `y` below remain non-negative.
    lambda_1 = jnp.min(
        jnp.array([jnp.min(a), jnp.min(init_g),
                   jnp.min(b)])
    ) * .5

    # `g1` and `x` are arithmetic-progression distributions (entries
    # proportional to 1..r and 1..n); `g2` and `y` are the residual
    # distributions so that lambda_1 * g1 + (1 - lambda_1) * g2 == init_g
    # and lambda_1 * x + (1 - lambda_1) * y == marginal.
    # normalization to 1 can overflow in i32 (e.g., n=128k)
    # using the formula: r * (r + 1) / 2 will raise:
    # OverflowError: Python int 16384128000 too large to convert to int32
    # normalizing by `jnp.sum()` overflows silently
    g1 = 2. * jnp.arange(1, r + 1) / (r ** 2 + r)
    g2 = (init_g - lambda_1 * g1) / (1. - lambda_1)
    x = 2. * jnp.arange(1, n + 1) / (n ** 2 + n)
    y = (marginal - lambda_1 * x) / (1. - lambda_1)

    # Convex combination of two rank-1 couplings: row sums give `marginal`,
    # column sums give `init_g`.
    return ((lambda_1 * x[:, None] @ g1.reshape(1, -1)) +
            ((1 - lambda_1) * y[:, None] @ g2.reshape(1, -1)))

  def init_q(
      self,
      ot_prob: linear_problems.LinearProblem,
      key: jnp.ndarray,
      *,
      init_g: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Deterministic rank-2 factor of shape ``[n, rank]``; `key` is unused."""
    del key, kwargs
    return self._compute_factor(ot_prob, init_g, which="q")

  def init_r(
      self,
      ot_prob: linear_problems.LinearProblem,
      key: jnp.ndarray,
      *,
      init_g: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Deterministic rank-2 factor of shape ``[m, rank]``; `key` is unused."""
    del key, kwargs
    return self._compute_factor(ot_prob, init_g, which="r")

  def init_g(
      self,
      ot_prob: linear_problems.LinearProblem,
      key: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Uniform distribution of shape ``[rank,]``; `key` is unused."""
    del key, kwargs
    return jnp.ones((self.rank,)) / self.rank


@jax.tree_util.register_pytree_node_class
class KMeansInitializer(LRSinkhornInitializer):
  """K-means initializer for low-rank Sinkhorn :cite:`scetbon:22b`.

  Clusters the point cloud (or low-rank cost factor) into ``rank`` centroids
  and uses an entropic OT plan between points and centroids as the factor.

  Args:
    rank: Rank of the factorization.
    sinkhorn_kwargs: Keyword arguments for :class:`~ott.core.sinkhorn.Sinkhorn`.
    kwargs: Keyword arguments for :func:`~ott.tools.k_means.k_means`.
  """

  def __init__(
      self,
      rank: int,
      sinkhorn_kwargs: Optional[Mapping[str, Any]] = None,
      **kwargs: Any
  ):
    super().__init__(rank)
    self._sinkhorn_kwargs = {} if sinkhorn_kwargs is None else sinkhorn_kwargs
    self._k_means_kwargs = kwargs

  @staticmethod
  def _extract_array(
      geom: Union[pointcloud.PointCloud, low_rank.LRCGeometry], *, first: bool
  ) -> jnp.ndarray:
    """Return the point array (or cost factor) for the first/second marginal."""
    if isinstance(geom, pointcloud.PointCloud):
      return geom.x if first else geom.y
    if isinstance(geom, low_rank.LRCGeometry):
      return geom.cost_1 if first else geom.cost_2
    raise TypeError(
        f"k-means initializer not implemented for `{type(geom).__name__}`."
    )

  def _compute_factor(
      self,
      ot_prob: linear_problems.LinearProblem,
      key: jnp.ndarray,
      *,
      init_g: jnp.ndarray,
      which: Literal["q", "r"],
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Compute one factor as an OT plan between points and k-means centroids."""
    # Local imports avoid a circular dependency with `ott.core.sinkhorn`.
    from ott.core import sinkhorn
    from ott.tools import k_means

    del kwargs
    use_first = which == "q"
    data = self._extract_array(ot_prob.geom, first=use_first)
    marginals = ot_prob.a if use_first else ot_prob.b

    run_k_means = functools.partial(k_means.k_means, **self._k_means_kwargs)
    # Follow the solver's `jit` preference; `k` must be static for tracing.
    if self._sinkhorn_kwargs.get("jit", True):
      run_k_means = jax.jit(run_k_means, static_argnames="k")
    centroids = run_k_means(data, self.rank, key=key).centroids

    geom = pointcloud.PointCloud(
        data, centroids, epsilon=0.1, scale_cost="max_cost"
    )
    prob = linear_problems.LinearProblem(geom, marginals, init_g)
    return sinkhorn.Sinkhorn(**self._sinkhorn_kwargs)(prob).matrix

  def init_q(
      self,
      ot_prob: linear_problems.LinearProblem,
      key: jnp.ndarray,
      *,
      init_g: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Factor of shape ``[n, rank]`` from clustering the first point cloud."""
    return self._compute_factor(
        ot_prob, key, init_g=init_g, which="q", **kwargs
    )

  def init_r(
      self,
      ot_prob: linear_problems.LinearProblem,
      key: jnp.ndarray,
      *,
      init_g: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Factor of shape ``[m, rank]`` from clustering the second point cloud."""
    return self._compute_factor(
        ot_prob, key, init_g=init_g, which="r", **kwargs
    )

  def init_g(
      self,
      ot_prob: linear_problems.LinearProblem,
      key: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Uniform distribution of shape ``[rank,]``; `key` is unused."""
    del key, kwargs
    return jnp.ones((self.rank,)) / self.rank

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    children, aux_data = super().tree_flatten()
    # Keyword argument dicts are static auxiliary data, not traced leaves.
    return children, {
        **aux_data,
        "sinkhorn_kwargs": self._sinkhorn_kwargs,
        **self._k_means_kwargs,
    }
2 changes: 1 addition & 1 deletion ott/core/quad_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def init_linearization(
)

def init_lr_linearization(
self, rank: int = 10, **kwargs: Any
self, rank: int, **kwargs: Any
) -> linear_problems.LinearProblem:
"""Linearizes a Quad problem with a predefined initializer."""
x_ = self.geom_xx.apply_square_cost(self.a)
Expand Down
Loading

0 comments on commit 1419513

Please sign in to comment.