Skip to content

Commit

Permalink
Generic LR cost decomposition (#99)
Browse files Browse the repository at this point in the history
* Initial implementation of generic LR cost decomp

* Add subset method

* Annotate array sizes, use multi_dot

* [ci skip] Make `to_LRCGeometry` in LR geom no-op

* Fix ``to_LRCGeometry`` when online, update docs

* Add factorization tests

* Add test for subsetting

* Polish documentation, add bibtex

* Fix unnecessary indents

* Disable `pytest-xdist` for all tests on CI

* Update GW to include generic LR cost decomp

* Fix LR cost conversion check in GW, add  test

* Fix `{GW,}LR` tutorial, use_danskin=False in LROut
  • Loading branch information
michalk8 authored Jul 12, 2022
1 parent 03f850d commit a463991
Show file tree
Hide file tree
Showing 18 changed files with 507 additions and 64 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
- name: Run all tests
if: ${{ matrix.test_mark == 'all' }}
run: |
pytest --cov=ott --cov-append --cov-report=xml --cov-report=term-missing --cov-config=setup.cfg --memray
pytest --cov=ott --cov-append --cov-report=xml --cov-report=term-missing --cov-config=setup.cfg --memray -n 0
- name: Upload coverage
uses: codecov/codecov-action@v3
Expand Down
6 changes: 6 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinxcontrib.bibtex',
'nbsphinx',
'IPython.sphinxext.ipython_console_highlighting',
'sphinx_autodoc_typehints',
Expand All @@ -75,6 +76,11 @@
pygments_lexer = 'ipython3'
nbsphinx_execute = 'never'

# bibliography
bibtex_bibfiles = ["references.bib"]
bibtex_reference_style = "author_year"
bibtex_default_style = "alpha"

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

Expand Down
1 change: 1 addition & 0 deletions docs/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Gromov-Wasserstein (Entropic and LR)
:toctree: _autosummary

gromov_wasserstein.gromov_wasserstein
gromov_wasserstein.GromovWasserstein
gromov_wasserstein.GWOutput

Neural Potentials
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ There are currently three packages, ``geometry``, ``core`` and ``tools``, playin
geometry
core
tools
references

Indices and tables
==================
Expand Down
29 changes: 29 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
@InProceedings{indyk:19,
title = {Sample-Optimal Low-Rank Approximation of Distance Matrices},
author = {Indyk, Piotr and Vakilian, Ali and Wagner, Tal and Woodruff, David P},
booktitle = {Proceedings of the Thirty-Second Conference on Learning Theory},
pages = {1723--1751},
year = {2019},
editor = {Beygelzimer, Alina and Hsu, Daniel},
volume = {99},
series = {Proceedings of Machine Learning Research},
month = {25--28 Jun},
publisher = {PMLR},
pdf = {http://proceedings.mlr.press/v99/indyk19a/indyk19a.pdf},
url = {https://proceedings.mlr.press/v99/indyk19a.html},
}

@InProceedings{scetbon:21,
title = {Low-Rank Sinkhorn Factorization},
author = {Scetbon, Meyer and Cuturi, Marco and Peyr{\'e}, Gabriel},
booktitle = {Proceedings of the 38th International Conference on Machine Learning},
pages = {9344--9354},
year = {2021},
editor = {Meila, Marina and Zhang, Tong},
volume = {139},
series = {Proceedings of Machine Learning Research},
month = {18--24 Jul},
publisher = {PMLR},
pdf = {http://proceedings.mlr.press/v139/scetbon21a/scetbon21a.pdf},
url = {https://proceedings.mlr.press/v139/scetbon21a.html},
}
5 changes: 5 additions & 0 deletions docs/references.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
References
==========

.. bibliography::
:cited:
75 changes: 55 additions & 20 deletions ott/core/gromov_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
sinkhorn_lr,
was_solver,
)
from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud
from ott.geometry import epsilon_scheduler, geometry, pointcloud

LinearOutput = Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput]

Expand Down Expand Up @@ -138,30 +138,52 @@ def update(

@jax.tree_util.register_pytree_node_class
class GromovWasserstein(was_solver.WassersteinSolver):
"""A Gromov Wasserstein solver, built on generic template."""
"""A Gromov Wasserstein solver, built on generic template.
Args:
args: Positional arguments for
:class:`~ott.core.was_solver.WassersteinSolver`.
cost_rank: Rank of the cost matrix, see
:meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry`. Used when
geometries are *not* :class:`~ott.geometry.pointcloud.PointCloud` with
`'sqeucl'` cost function. If `-1`, these geometries will not be converted
to low-rank.
cost_tol: Tolerance used when converting geometries to low-rank. Used when
geometries are *not* :class:`~ott.geometry.pointcloud.PointCloud` with
`'sqeucl'` cost function.
kwargs: Keyword arguments for
:class:`~ott.core.was_solver.WassersteinSolver`.
"""

def __init__(
self,
*args: Any,
cost_rank: int = -1,
cost_tol: float = 1e-2,
**kwargs: Any
):
super().__init__(*args, **kwargs)
self.cost_rank = cost_rank
self.cost_tol = cost_tol

def __call__(self, prob: quad_problems.QuadraticProblem) -> GWOutput:
# Consider converting problem first if using low-rank solver
if self.is_low_rank:
convert = (
isinstance(prob.geom_xx, pointcloud.PointCloud) and
prob.geom_xx.is_squared_euclidean and
isinstance(prob.geom_yy, pointcloud.PointCloud) and
prob.geom_yy.is_squared_euclidean
if self.is_low_rank and self._convert_geoms_to_lr(prob):
prob.geom_xx = prob.geom_xx.to_LRCGeometry(
rank=self.cost_rank, tol=self.cost_tol
)
# Consider converting
if convert:
if not prob.is_fused or isinstance(prob.geom_xy, low_rank.LRCGeometry):
prob.geom_xx = prob.geom_xx.to_LRCGeometry()
prob.geom_yy = prob.geom_yy.to_LRCGeometry()
prob.geom_yy = prob.geom_yy.to_LRCGeometry(
rank=self.cost_rank, tol=self.cost_tol
)
if prob.geom_xy is not None:
if isinstance(
prob.geom_xy, pointcloud.PointCloud
) and prob.geom_xy.is_squared_euclidean:
prob.geom_xy = prob.geom_xy.to_LRCGeometry(prob.fused_penalty)
else:
if (
isinstance(prob.geom_xy, pointcloud.PointCloud) and
prob.geom_xy.is_squared_euclidean
):
prob.geom_xy = prob.geom_xy.to_LRCGeometry(prob.fused_penalty)
prob.geom_xx = prob.geom_xx.to_LRCGeometry()
prob.geom_yy = prob.geom_yy.to_LRCGeometry()
prob.geom_xy = prob.geom_xy.to_LRCGeometry(
rank=self.cost_rank, tol=self.cost_tol
)

# Possibly jit iteration functions and run. Closure on rank to
# avoid jitting issues, since rank value will be used to branch between
Expand Down Expand Up @@ -226,6 +248,19 @@ def output_from_state(self, state: GWState) -> GWOutput:
old_transport_mass=state.old_transport_mass
)

def _convert_geoms_to_lr(self, prob: quad_problems.QuadraticProblem) -> bool:
  """Decide whether the geometries of ``prob`` should be made low-rank.

  Conversion is requested either explicitly, when ``self.cost_rank`` differs
  from ``-1``, or implicitly, when ``geom_xx`` and ``geom_yy`` (and
  ``geom_xy``, if it is not ``None``) are all point clouds with squared
  Euclidean cost.

  Args:
    prob: Quadratic problem whose geometries are inspected.

  Returns:
    ``True`` if the geometries should be converted, ``False`` otherwise.
  """

  def is_sqeucl_pc(geom: geometry.Geometry) -> bool:
    if not isinstance(geom, pointcloud.PointCloud):
      return False
    return geom.is_squared_euclidean

  quad_x, quad_y, linear = prob.geom_xx, prob.geom_yy, prob.geom_xy

  # An explicit cost rank always triggers the conversion.
  if self.cost_rank != -1:
    return True
  # Otherwise both quadratic geometries must be squared-Euclidean point clouds.
  if not (is_sqeucl_pc(quad_x) and is_sqeucl_pc(quad_y)):
    return False
  # The linear (fused) term is optional; when present it must qualify as well.
  return linear is None or is_sqeucl_pc(linear)


def iterations(
solver: GromovWasserstein, prob: quad_problems.QuadraticProblem, rank: int
Expand Down
14 changes: 10 additions & 4 deletions ott/core/sinkhorn_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,18 @@ def set(self, **kwargs: Any) -> 'LRSinkhornOutput':
return self._replace(**kwargs)

def set_cost(
self, ot_prob: linear_problems.LinearProblem, lse_mode: bool,
use_danskin: bool
self,
ot_prob: linear_problems.LinearProblem,
lse_mode: bool,
use_danskin: bool = False
) -> 'LRSinkhornOutput':
del lse_mode
return self.set(reg_ot_cost=self.compute_reg_ot_cost(ot_prob, use_danskin))

def compute_reg_ot_cost(
self, ot_prob: linear_problems.LinearProblem, use_danskin: bool
self,
ot_prob: linear_problems.LinearProblem,
use_danskin: bool = False,
) -> float:
return compute_reg_ot_cost(self.q, self.r, self.g, ot_prob, use_danskin)

Expand Down Expand Up @@ -533,7 +537,9 @@ def run(
) -> LRSinkhornOutput:
"""Run loop of the solver, outputting a state upgraded to an output."""
out = sinkhorn.iterations(ot_prob, solver, init)
out = out.set_cost(ot_prob, solver.lse_mode, solver.use_danskin)
out = out.set_cost(
ot_prob, lse_mode=solver.lse_mode, use_danskin=solver.use_danskin
)
return out.set(ot_prob=ot_prob)


Expand Down
1 change: 1 addition & 0 deletions ott/core/was_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(

@property
def is_low_rank(self) -> bool:
  """Whether the solver is low-rank, i.e. whether ``self.rank`` is positive."""
  return self.rank > 0

def tree_flatten(self):
Expand Down
130 changes: 128 additions & 2 deletions ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
# Lint as: python3
"""A class describing operations used to instantiate and use a geometry."""
import functools
from typing import Any, Callable, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union

if TYPE_CHECKING:
from ott.geometry import low_rank

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from typing_extensions import Literal

from ott.geometry import epsilon_scheduler, ops
Expand Down Expand Up @@ -212,7 +216,7 @@ def _set_scale_cost(
aux_data["scale_cost"] = scale_cost
return type(self).tree_unflatten(aux_data, children)

def copy_epsilon(self, other: epsilon_scheduler.Epsilon) -> "Geometry":
def copy_epsilon(self, other: 'Geometry') -> "Geometry":
"""Copy the epsilon parameters from another geometry."""
scheduler = other._epsilon
self._epsilon_init = scheduler._target_init
Expand Down Expand Up @@ -614,6 +618,128 @@ def prepare_divergences(
for arg1, arg2, _ in zip(cost_matrices, kernel_matrices, range(size))
)

def to_LRCGeometry(
    self,
    rank: int,
    tol: float = 1e-2,
    seed: int = 0
) -> 'low_rank.LRCGeometry':
  r"""Factorize the cost matrix in sublinear time :cite:`indyk:19`.

  Uses the implementation of :cite:`scetbon:21`, algorithm 4.
  It holds that with probability *0.99*,
  :math:`||A - UV||_F^2 \leq || A - A_k ||_F^2 + tol \cdot ||A||_F^2`,
  where :math:`A` is the ``n x m`` cost matrix, :math:`UV` the factorization
  computed in sublinear time and :math:`A_k` the best rank-k approximation.

  Only subsets of rows/columns of the cost matrix are accessed, through
  :meth:`subset`.

  Args:
    rank: Target rank of the :attr:`cost_matrix`. Must be positive.
    tol: Tolerance of the error. The total number of sampled points is
      :math:`min(n, m,\frac{rank}{tol})`.
    seed: Random seed.

  Returns:
    Low-rank geometry.

  Raises:
    AssertionError: If ``rank`` is not positive.
  """
  # Imported locally; at module level ``low_rank`` is only imported under
  # ``TYPE_CHECKING`` (presumably to avoid a circular import - confirm).
  from ott.geometry import low_rank

  # NOTE(review): ``assert`` is stripped when Python runs with ``-O``.
  assert rank > 0, f"Rank must be positive, got {rank}."
  rng = jax.random.PRNGKey(seed)
  # One independent key per random draw below.
  key1, key2, key3, key4, key5 = jax.random.split(rng, 5)
  n, m = self.shape
  # Number of sampled rows/columns, capped by the matrix dimensions.
  n_subset = min(int(rank / tol), n, m)

  # Random anchor row ``i_star`` and anchor column ``j_star``.
  i_star = jax.random.randint(key1, shape=(), minval=0, maxval=n)
  j_star = jax.random.randint(key2, shape=(), minval=0, maxval=m)

  # force `batch_size=None`, otherwise `cost_matrix` would be `None`
  # Squared costs of the anchor row ...
  ci_star = self.subset(
      i_star, None, batch_size=None
  ).cost_matrix.ravel() ** 2  # (m,)
  # ... and of the anchor column.
  cj_star = self.subset(
      None, j_star, batch_size=None
  ).cost_matrix.ravel() ** 2  # (n,)

  # Row-sampling distribution built from the anchor row/column only.
  p_row = cj_star + ci_star[j_star] + jnp.mean(ci_star)  # (n,)
  p_row /= jnp.sum(p_row)
  # Sample ``n_subset`` rows according to ``p_row`` (``jax.random.choice``
  # samples with replacement by default).
  row_ixs = jax.random.choice(key3, n, shape=(n_subset,), p=p_row)
  # (n_subset, m)
  S = self.subset(row_ixs, None, batch_size=None).cost_matrix
  # Rescale each sampled row by its sampling probability.
  S /= jnp.sqrt(n_subset * p_row[row_ixs][:, None])

  # Column-sampling distribution: squared column norms of ``S``, normalized.
  p_col = jnp.sum(S ** 2, axis=0)  # (m,)
  p_col /= jnp.sum(p_col)
  # (n_subset,)
  col_ixs = jax.random.choice(key4, m, shape=(n_subset,), p=p_col)
  # (n_subset, n_subset)
  W = S[:, col_ixs] / jnp.sqrt(n_subset * p_col[col_ixs][None, :])

  # Keep the first ``rank`` left singular vectors of the small matrix ``W``;
  # the singular values and the third output (``V``) are not used, ``V`` is
  # overwritten below.
  U, _, V = jsp.linalg.svd(W)
  U = U[:, :rank]  # (n_subset, rank)
  # Map them through ``S.T`` and normalize column-wise by ``||W.T @ U||``.
  U = (S.T @ U) / jnp.linalg.norm(W.T @ U, axis=0)  # (m, rank)

  # lls (presumably "linear least squares" for the remaining factor, using
  # a uniform sample of columns - TODO confirm against algorithm 4)
  # Eigendecomposition of the Gram matrix; eigenvectors are scaled by the
  # inverse square root of their eigenvalues.
  d, v = jnp.linalg.eigh(U.T @ U)  # (k,), (k, k)
  v /= jnp.sqrt(d)[None, :]

  inv_scale = (1. / jnp.sqrt(n_subset))
  # Uniform column sample (no ``p`` given); reuses the name ``col_ixs``.
  col_ixs = jax.random.choice(key5, m, shape=(n_subset,))  # (n_subset,)

  # (n, n_subset)
  A_trans = self.subset(
      None, col_ixs, batch_size=None
  ).cost_matrix * inv_scale
  B = (U[col_ixs, :] @ v * inv_scale)  # (n_subset, k)
  M = jnp.linalg.inv(B.T @ B)  # (k, k)
  V = jnp.linalg.multi_dot([A_trans, B, M.T, v.T])  # (n, k)

  # ``V`` is (n, k) and ``U`` is (m, rank); epsilon and cost-scaling
  # settings are forwarded from this geometry.
  return low_rank.LRCGeometry(
      cost_1=V,
      cost_2=U,
      epsilon=self._epsilon_init,
      relative_epsilon=self._relative_epsilon,
      scale=self._scale_epsilon,
      scale_cost=self._scale_cost,
      **self._kwargs
  )

def subset(
    self,
    src_ixs: Optional[jnp.ndarray],
    tgt_ixs: Optional[jnp.ndarray],
    **kwargs: Any,
) -> "Geometry":
  """Return a geometry restricted to the given rows and/or columns.

  The cost and kernel matrices obtained from :meth:`tree_flatten` are
  sliced when they are not ``None``; all remaining children are passed
  through unchanged.

  Args:
    src_ixs: Source (row) indices. If ``None``, use all rows.
    tgt_ixs: Target (column) indices. If ``None``, use all columns.
    kwargs: Keyword arguments for :class:`ott.geometry.geometry.Geometry`;
      they override the matching entries of the flattened auxiliary data.

  Returns:
    Subset of a geometry, of the same type as ``self``.
  """

  def take(mat: Optional[jnp.ndarray]) -> Optional[jnp.ndarray]:
    # A matrix that is not materialized stays ``None``.
    if mat is None:
      return None
    if src_ixs is not None:
      # ``atleast_1d`` keeps the result 2-D even for a scalar index.
      mat = mat[jnp.atleast_1d(src_ixs), :]
    if tgt_ixs is not None:
      mat = mat[:, jnp.atleast_1d(tgt_ixs)]
    return mat

  leaves, aux_data = self.tree_flatten()
  cost, kernel, *rest = leaves
  merged_aux = {**aux_data, **kwargs}
  return type(self).tree_unflatten(
      merged_aux, [take(cost), take(kernel), *rest]
  )

def tree_flatten(self):
return (
self._cost_matrix, self._kernel_matrix, self._epsilon_init,
Expand Down
5 changes: 5 additions & 0 deletions ott/geometry/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,11 @@ def transport_from_scalings(
' cloud geometry instead'
)

def subset(
    self,
    src_ixs: Optional[jnp.ndarray],
    tgt_ixs: Optional[jnp.ndarray],
    **kwargs: object,
) -> NoReturn:
  """Subsetting is not supported for grids; always raises.

  The signature mirrors :meth:`ott.geometry.geometry.Geometry.subset`,
  including ``**kwargs``, so that callers passing extra keyword arguments
  (e.g. ``batch_size=None``) get the intended :class:`NotImplementedError`
  rather than a :class:`TypeError` about an unexpected keyword.

  Args:
    src_ixs: Source indices. Ignored.
    tgt_ixs: Target indices. Ignored.
    kwargs: Extra keyword arguments. Ignored.

  Raises:
    NotImplementedError: Always.
  """
  raise NotImplementedError("Subsetting grid is not implemented.")

@classmethod
def prepare_divergences(
cls,
Expand Down
Loading

0 comments on commit a463991

Please sign in to comment.