Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic LR cost decomposition #99

Merged
merged 15 commits into from
Jul 12, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
- name: Run all tests
if: ${{ matrix.test_mark == 'all' }}
run: |
pytest --cov=ott --cov-append --cov-report=xml --cov-report=term-missing --cov-config=setup.cfg --memray
pytest --cov=ott --cov-append --cov-report=xml --cov-report=term-missing --cov-config=setup.cfg --memray -n 0
- name: Upload coverage
uses: codecov/codecov-action@v3
Expand Down
6 changes: 6 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinxcontrib.bibtex',
'nbsphinx',
'IPython.sphinxext.ipython_console_highlighting',
'sphinx_autodoc_typehints',
Expand All @@ -75,6 +76,11 @@
pygments_lexer = 'ipython3'
nbsphinx_execute = 'never'

# bibliography
bibtex_bibfiles = ["references.bib"]
bibtex_reference_style = "author_year"
bibtex_default_style = "alpha"

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ There are currently three packages, ``geometry``, ``core`` and ``tools``, playin
geometry
core
tools
references

Indices and tables
==================
Expand Down
29 changes: 29 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
@InProceedings{indyk:19,
title = {Sample-Optimal Low-Rank Approximation of Distance Matrices},
author = {Indyk, Pitor and Vakilian, Ali and Wagner, Tal and Woodruff, David P},
booktitle = {Proceedings of the Thirty-Second Conference on Learning Theory},
pages = {1723--1751},
year = {2019},
editor = {Beygelzimer, Alina and Hsu, Daniel},
volume = {99},
series = {Proceedings of Machine Learning Research},
month = {25--28 Jun},
publisher = {PMLR},
pdf = {http://proceedings.mlr.press/v99/indyk19a/indyk19a.pdf},
url = {https://proceedings.mlr.press/v99/indyk19a.html},
}

@InProceedings{scetbon:21,
title = {Low-Rank Sinkhorn Factorization},
author = {Scetbon, Meyer and Cuturi, Marco and Peyr{\'e}, Gabriel},
booktitle = {Proceedings of the 38th International Conference on Machine Learning},
pages = {9344--9354},
year = {2021},
editor = {Meila, Marina and Zhang, Tong},
volume = {139},
series = {Proceedings of Machine Learning Research},
month = {18--24 Jul},
publisher = {PMLR},
pdf = {http://proceedings.mlr.press/v139/scetbon21a/scetbon21a.pdf},
url = {https://proceedings.mlr.press/v139/scetbon21a.html},
}
5 changes: 5 additions & 0 deletions docs/references.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
References
==========

.. bibliography::
:cited:
130 changes: 128 additions & 2 deletions ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
# Lint as: python3
"""A class describing operations used to instantiate and use a geometry."""
import functools
from typing import Any, Callable, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union

if TYPE_CHECKING:
from ott.geometry import low_rank

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from typing_extensions import Literal

from ott.geometry import epsilon_scheduler, ops
Expand Down Expand Up @@ -212,7 +216,7 @@ def _set_scale_cost(
aux_data["scale_cost"] = scale_cost
return type(self).tree_unflatten(aux_data, children)

def copy_epsilon(self, other: epsilon_scheduler.Epsilon) -> "Geometry":
def copy_epsilon(self, other: 'Geometry') -> "Geometry":
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
"""Copy the epsilon parameters from another geometry."""
scheduler = other._epsilon
self._epsilon_init = scheduler._target_init
Expand Down Expand Up @@ -614,6 +618,128 @@ def prepare_divergences(
for arg1, arg2, _ in zip(cost_matrices, kernel_matrices, range(size))
)

def to_LRCGeometry(
self,
rank: int,
tol: float = 1e-2,
seed: int = 0
) -> 'low_rank.LRCGeometry':
r"""Factorize the cost matrix in sublinear time :cite:`indyk:19`.
Uses the implementation of :cite:`scetbon:21`, algorithm 4.
It holds that with probability *0.99*,
:math:`||A - UV||_F^2 \leq || A - A_k ||_F^2 + tol \cdot ||A||_F^2`,
where :math:`A` is ``n x m`` cost matrix, :math:`UV` the factorization
computed in sublinear time and :math:`A_k` the best rank-k approximation.
Args:
rank: Target rank of the :attr:`cost_matrix`.
tol: Tolerance of the error. The total number of sampled points is
:math:`min(n, m,\frac{rank}{tol})`.
seed: Random seed.
Returns:
Low-rank geometry.
"""
from ott.geometry import low_rank

assert rank > 0, f"Rank must be positive, got {rank}."
rng = jax.random.PRNGKey(seed)
key1, key2, key3, key4, key5 = jax.random.split(rng, 5)
n, m = self.shape
n_subset = min(int(rank / tol), n, m)

i_star = jax.random.randint(key1, shape=(), minval=0, maxval=n)
j_star = jax.random.randint(key2, shape=(), minval=0, maxval=m)

# force `batch_size=None` since `cost_matrix` would be `None`
ci_star = self.subset(
i_star, None, batch_size=None
).cost_matrix.ravel() ** 2 # (m,)
cj_star = self.subset(
None, j_star, batch_size=None
).cost_matrix.ravel() ** 2 # (n,)

p_row = cj_star + ci_star[j_star] + jnp.mean(ci_star) # (n,)
p_row /= jnp.sum(p_row)
row_ixs = jax.random.choice(key3, n, shape=(n_subset,), p=p_row)
# (n_subset, m)
S = self.subset(row_ixs, None, batch_size=None).cost_matrix
S /= jnp.sqrt(n_subset * p_row[row_ixs][:, None])

p_col = jnp.sum(S ** 2, axis=0) # (m,)
p_col /= jnp.sum(p_col)
# (n_subset,)
col_ixs = jax.random.choice(key4, m, shape=(n_subset,), p=p_col)
# (n_subset, n_subset)
W = S[:, col_ixs] / jnp.sqrt(n_subset * p_col[col_ixs][None, :])

U, _, V = jsp.linalg.svd(W)
U = U[:, :rank] # (n_subset, rank)
U = (S.T @ U) / jnp.linalg.norm(W.T @ U, axis=0) # (m, rank)

# lls
d, v = jnp.linalg.eigh(U.T @ U) # (k,), (k, k)
v /= jnp.sqrt(d)[None, :]

inv_scale = (1. / jnp.sqrt(n_subset))
col_ixs = jax.random.choice(key5, m, shape=(n_subset,)) # (n_subset,)

# (n, n_subset)
A_trans = self.subset(
None, col_ixs, batch_size=None
).cost_matrix * inv_scale
B = (U[col_ixs, :] @ v * inv_scale) # (n_subset, k)
M = jnp.linalg.inv(B.T @ B) # (k, k)
V = jnp.linalg.multi_dot([A_trans, B, M.T, v.T]) # (n, k)

return low_rank.LRCGeometry(
cost_1=V,
cost_2=U,
epsilon=self._epsilon_init,
relative_epsilon=self._relative_epsilon,
scale=self._scale_epsilon,
scale_cost=self._scale_cost,
**self._kwargs
)

def subset(
self,
src_ixs: Optional[jnp.ndarray],
tgt_ixs: Optional[jnp.ndarray],
**kwargs: Any,
) -> "Geometry":
"""Subset rows and/or columns of a geometry.
Args:
src_ixs: Source indices. If ``None``, use all rows.
tgt_ixs: Target indices. If ``None``, use all columns.
kwargs: Keyword arguments for :class:`ott.geometry.geometry.Geometry`.
Returns:
Subset of a geometry.
"""

def sub(
arr: jnp.ndarray, src_ixs: Optional[jnp.ndarray],
tgt_ixs: Optional[jnp.ndarray]
) -> jnp.ndarray:
if src_ixs is not None:
arr = arr[jnp.atleast_1d(src_ixs), :]
if tgt_ixs is not None:
arr = arr[:, jnp.atleast_1d(tgt_ixs)]
return arr

(cost, kernel, *children), aux_data = self.tree_flatten()
if cost is not None:
cost = sub(cost, src_ixs, tgt_ixs)
if kernel is not None:
kernel = sub(kernel, src_ixs, tgt_ixs)

aux_data = {**aux_data, **kwargs}
return type(self).tree_unflatten(aux_data, [cost, kernel] + children)

def tree_flatten(self):
return (
self._cost_matrix, self._kernel_matrix, self._epsilon_init,
Expand Down
5 changes: 5 additions & 0 deletions ott/geometry/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,11 @@ def transport_from_scalings(
' cloud geometry instead'
)

def subset(
self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray]
) -> NoReturn:
raise NotImplementedError("Subsetting grid is not implemented.")

@classmethod
def prepare_divergences(
cls,
Expand Down
29 changes: 29 additions & 0 deletions ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,35 @@ def finalize(carry):
max_value = jnp.max(jnp.concatenate((out, last_slice.reshape(-1))))
return max_value + self._bias

def to_LRCGeometry(
self, rank: int, tol: float = 1e-2, seed: int = 0
) -> 'LRCGeometry':
"""Return self."""
return self

def subset(
self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray],
**kwargs: Any
) -> "LRCGeometry":
"""Subset rows and/or columns of a geometry.

Args:
src_ixs: Source indices. If ``None``, use all rows.
tgt_ixs: Target indices. If ``None``, use all columns.
kwargs: Keyword arguments for :class:`ott.geometry.low_rank.LRCGeometry`.

Returns:
The subsetted geometry.
"""
(c1, c2, *children), aux_data = self.tree_flatten()
if src_ixs is not None:
c1 = c1[jnp.atleast_1d(src_ixs), :]
if tgt_ixs is not None:
c2 = c2[jnp.atleast_1d(tgt_ixs), :]

aux_data = {**aux_data, **kwargs}
return type(self).tree_unflatten(aux_data, [c1, c2] + children)

def tree_flatten(self):
return (self._cost_1, self._cost_2, self._kwargs), {
'bias': self._bias,
Expand Down
92 changes: 61 additions & 31 deletions ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,41 +558,71 @@ def tree_unflatten(cls, aux_data, children):
x, y, eps, cost_fn = children
return cls(x, y, epsilon=eps, cost_fn=cost_fn, **aux_data)

def to_LRCGeometry(self, scale: float = 1.0) -> low_rank.LRCGeometry:
def to_LRCGeometry(
self,
scale: float = 1.0,
**kwargs: Any,
) -> Union[low_rank.LRCGeometry, 'PointCloud']:
"""Convert sqEuc. PointCloud to LRCGeometry if useful, and rescale."""
if self.is_squared_euclidean:
(n, m), d = self.shape, self.x.shape[1]
if n * m > (n + m) * d: # here apply_cost using LRCGeometry preferable.
cost_1 = jnp.concatenate((
jnp.sum(self.x ** 2, axis=1, keepdims=True),
jnp.ones((self.shape[0], 1)), -jnp.sqrt(2) * self.x
),
axis=1)
cost_2 = jnp.concatenate((
jnp.ones(
(self.shape[1], 1)
), jnp.sum(self.y ** 2, axis=1, keepdims=True), jnp.sqrt(2) * self.y
),
axis=1)
cost_1 *= jnp.sqrt(scale)
cost_2 *= jnp.sqrt(scale)

return low_rank.LRCGeometry(
cost_1=cost_1,
cost_2=cost_2,
epsilon=self._epsilon_init,
relative_epsilon=self._relative_epsilon,
scale=self._scale_epsilon,
scale_cost=self._scale_cost,
**self._kwargs
)
else:
(x, y, *children), aux_data = self.tree_flatten()
x = x * jnp.sqrt(scale)
y = y * jnp.sqrt(scale)
return PointCloud.tree_unflatten(aux_data, [x, y] + children)
else:
raise ValueError('Cannot turn non-sq-Euclidean geometry into low-rank')
return self._sqeucl_to_lr(scale)
(x, y, *children), aux_data = self.tree_flatten()
x = x * jnp.sqrt(scale)
y = y * jnp.sqrt(scale)
return PointCloud.tree_unflatten(aux_data, [x, y] + children)
return super().to_LRCGeometry(**kwargs)

def _sqeucl_to_lr(self, scale: float = 1.0) -> low_rank.LRCGeometry:
assert self.is_squared_euclidean, "Geometry must be squared Euclidean."
n, m = self.shape
cost_1 = jnp.concatenate((
jnp.sum(self.x ** 2, axis=1, keepdims=True), jnp.ones(
(n, 1)
), -jnp.sqrt(2) * self.x
),
axis=1)
cost_2 = jnp.concatenate((
jnp.ones((m, 1)), jnp.sum(self.y ** 2, axis=1,
keepdims=True), jnp.sqrt(2) * self.y
),
axis=1)
cost_1 *= jnp.sqrt(scale)
cost_2 *= jnp.sqrt(scale)

return low_rank.LRCGeometry(
cost_1=cost_1,
cost_2=cost_2,
epsilon=self._epsilon_init,
relative_epsilon=self._relative_epsilon,
scale=self._scale_epsilon,
scale_cost=self._scale_cost,
**self._kwargs
)

def subset(
self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray],
**kwargs: Any
) -> "PointCloud":
"""Subset rows and/or columns of a geometry.

Args:
src_ixs: Source indices. If ``None``, use all rows.
tgt_ixs: Target indices. If ``None``, use all columns.
kwargs: Keyword arguments for :class:`ott.geometry.pointcloud.PointCloud`.

Returns:
The subsetted geometry.
"""
(x, y, *children), aux_data = self.tree_flatten()
if src_ixs is not None:
x = x[jnp.atleast_1d(src_ixs), :]
if tgt_ixs is not None:
y = y[jnp.atleast_1d(tgt_ixs), :]

aux_data = {**aux_data, **kwargs}
return type(self).tree_unflatten(aux_data, [x, y] + children)

@property
def batch_size(self) -> Optional[int]:
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ docs =
ipython>=7.20.0
sphinx_autodoc_typehints>=1.12.0
sphinx-book-theme
sphinxcontrib-bibtex
dev =
pre-commit

Expand Down
2 changes: 0 additions & 2 deletions tests/core/continuous_barycenter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ def test_euclidean_barycenter(
lse_mode=[False, True],
epsilon=[1e-1, 5e-1],
jit=[False, True],
# TODO(michalk8): finalize the API
# might be beneficial to all for more than 1 test to be selected
only_fast={
"lse_mode": True,
"epsilon": 1e-1,
Expand Down
Loading