Feature/geodesic sinkhorn #457

Merged 48 commits on Nov 28, 2023
Changes from 40 commits
Commits
d3d840f
Create `geodesic` module
diegoabt Sep 5, 2023
4306382
Merge pull request #1 from diegoabt/feature/geodesic-sinkhorn
diegoabt Sep 5, 2023
03a57f9
Lint code
diegoabt Sep 5, 2023
fb4de65
Lint code
diegoabt Sep 5, 2023
1f3cd9f
Merge branch 'main' of https://github.com/diegoabt/ott into main
diegoabt Sep 7, 2023
9473a68
Merge branch 'ott-jax:main' into main
diegoabt Sep 7, 2023
dd1fd95
Add Geodesic kernel citation to `docs/references.bib`
diegoabt Sep 8, 2023
8c087f4
Remove unused functions; update docstrings; remove `n_steps`; remove …
diegoabt Sep 8, 2023
408b14c
Remove hardcoded random key at `compute_largest_eigenvalue`
diegoabt Sep 8, 2023
1e33f10
Fix docstrings of functions inside `apply_kernel`
diegoabt Sep 8, 2023
d901e5c
Add chebyshev coeff computation to `from_graph`
diegoabt Sep 25, 2023
01eb7e1
Change `jax.exp.sparse` import
diegoabt Sep 26, 2023
f223d7a
Change `tree_flatten` outputs
diegoabt Sep 26, 2023
7bfa71c
Change input of `lobpcg_std` to be a sparsified product
diegoabt Sep 26, 2023
0a3b8c8
Change definition of cost from kernel
diegoabt Sep 27, 2023
c2e5770
Add `Geodesic` to `docs/geometry`
diegoabt Sep 28, 2023
c1ec39d
Remove `np.max` from max eigenval computation
diegoabt Sep 29, 2023
5830ebc
Add `default_prng_key` to eigenval computation
diegoabt Oct 5, 2023
1d06049
Add `safe_log` to cost matrix computation
diegoabt Oct 30, 2023
8e14cdb
Change to `jesp.BCOO` at chebyshev approx
diegoabt Oct 30, 2023
0b7364b
Restructure `from_graph`; coeffs are computed earlier now
diegoabt Oct 30, 2023
0a22bcd
Merge branch 'main' of https://github.com/diegoabt/ott into feature/g…
guillaumehu Oct 31, 2023
8bc10c1
fn outside of the class
guillaumehu Nov 2, 2023
3650cda
mv fn & process L once
guillaumehu Nov 2, 2023
06d04ef
wip tests geo
guillaumehu Nov 3, 2023
3e62805
jax pure_callback & new fn cheb
guillaumehu Nov 7, 2023
3640e45
symmetric kernel & wip tests
guillaumehu Nov 7, 2023
4f4447c
fix formatting ruff
guillaumehu Nov 8, 2023
2d5e0f0
wrap eigenval fn
guillaumehu Nov 8, 2023
083b021
rm num_scheme & _scale
guillaumehu Nov 8, 2023
406d9c8
type dense or sparse
guillaumehu Nov 8, 2023
9239524
expm with scan & fix hardcode dty & lobpcg iter
guillaumehu Nov 8, 2023
0dd5a97
test compare with BE and CN
guillaumehu Nov 8, 2023
43df649
rm lap_mn_id & test sink spd & fix tree_flatten
guillaumehu Nov 13, 2023
acda12b
default and type of Cheb. co. & docstrings
guillaumehu Nov 13, 2023
8bbb19e
dtype in purecallback depending on Lap
guillaumehu Nov 13, 2023
d640821
rm laplacian & update test
guillaumehu Nov 16, 2023
bab7d94
simpler wrapper for `ive`
guillaumehu Nov 17, 2023
b296da4
lint-docs spell check
guillaumehu Nov 17, 2023
b7d1df3
fix mistake rm setup
guillaumehu Nov 20, 2023
999b226
typo & rm fn scale lap &
guillaumehu Nov 27, 2023
e19cf01
rm t order init & typing
guillaumehu Nov 27, 2023
298c544
use rng util
guillaumehu Nov 27, 2023
7d13198
t for cost & test with ground truth
guillaumehu Nov 27, 2023
620245d
rm order & change `phi` to `eigval` & type hints
guillaumehu Nov 28, 2023
3a93b22
differentiability test
guillaumehu Nov 28, 2023
3e5ef04
condition to symmetrize kernel
guillaumehu Nov 28, 2023
cf17ce4
fix indentation
guillaumehu Nov 28, 2023
1 change: 1 addition & 0 deletions docs/geometry.rst
@@ -45,6 +45,7 @@ Geometries
    pointcloud.PointCloud
    grid.Grid
    graph.Graph
    geodesic.Geodesic
    low_rank.LRCGeometry
    epsilon_scheduler.Epsilon

7 changes: 7 additions & 0 deletions docs/references.bib
@@ -805,3 +805,10 @@ @misc{klein:23
  title = {Learning Costs for Structured Monge Displacements},
  year = {2023},
}

@article{huguet:2022,
  title={Geodesic Sinkhorn: optimal transport for high-dimensional datasets},
  author={Huguet, Guillaume and Tong, Alexander and Zapatero, Mar{\'\i}a Ramos and Wolf, Guy and Krishnaswamy, Smita},
  journal={arXiv preprint arXiv:2211.00805},
  year={2022}
}
3 changes: 3 additions & 0 deletions docs/spelling/technical.txt
@@ -1,6 +1,7 @@
Barycenters
Brenier
Bures
Chebyshev
Cholesky
DTW
Danskin
@@ -24,6 +25,7 @@ SGD
Schur
Seidel
Sinkhorn
UNet
Unbalancedness
Wasserstein
adaptively
@@ -93,6 +95,7 @@ parameterization
parameterizing
piecewise
pluripotent
polynomials
positivity
postfix
potentials
1 change: 1 addition & 0 deletions src/ott/geometry/__init__.py
@@ -14,6 +14,7 @@
from . import (
    costs,
    epsilon_scheduler,
    geodesic,
    geometry,
    graph,
    grid,
282 changes: 282 additions & 0 deletions src/ott/geometry/geodesic.py
@@ -0,0 +1,282 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Sequence, Tuple

import jax
import jax.experimental.sparse as jesp
import jax.numpy as jnp
import numpy as np
from scipy.special import ive

from ott import utils
from ott.geometry import geometry
from ott.math import utils as mu
from ott.types import Array_g

__all__ = ["Geodesic"]


@jax.tree_util.register_pytree_node_class
class Geodesic(geometry.Geometry):
  r"""Graph distance approximation using heat kernel :cite:`huguet:2022`.

  Approximates the heat kernel using Chebyshev polynomials of the
  first kind of max order ``order``, which for small ``t`` approximates the
  geodesic exponential kernel :math:`e^{\frac{-d(x, y)^2}{t}}`.

  Args:
    scaled_laplacian: The Laplacian scaled by the largest eigenvalue.
    eigval: Largest eigenvalue of the Laplacian.
    chebyshev_coeffs: Coefficients of the Chebyshev polynomials.
    t: Time parameter for heat kernel.
    order: Max order of Chebyshev polynomial.
    kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`.
  """

  def __init__(
      self,
      scaled_laplacian: Array_g,
      eigval: jnp.ndarray,
      chebyshev_coeffs: jnp.ndarray,
      t: float = 1e-3,
      order: int = 100,
      **kwargs: Any
  ):
    super().__init__(epsilon=1., **kwargs)
    self.scaled_laplacian = scaled_laplacian
    self.eigval = eigval
    self.chebyshev_coeffs = chebyshev_coeffs
    self.t = t
    self.order = order
Collaborator

I think we should remove this, as it's not being used anywhere in the class.
The docs will need to be adapted a bit as well.

Collaborator Author

I only removed `self.order`, because `self.t` is used for the `cost_matrix` attribute, and `self.eigval` and `self.chebyshev_coeffs` are used in `expm_multiply`. I renamed `phi` to `eigval` in `expm_multiply` to make it clearer.
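
For readers following this thread, here is a minimal NumPy sketch of why `t` is needed when turning a kernel into a cost: `cost_matrix` in this diff computes `-4 * t * log(kernel)`. The helper name `cost_from_kernel`, the clipping used in place of `mu.safe_log`, and the toy Gaussian kernel are assumptions for illustration only, not code from the PR.

```python
import numpy as np


def cost_from_kernel(kernel: np.ndarray, t: float) -> np.ndarray:
  """Hypothetical helper: invert a Gaussian-like kernel as -4 t log(kernel)."""
  # Clip to avoid log(0); this only stands in for `safe_log` from the diff.
  tiny = np.finfo(kernel.dtype).tiny
  return -4.0 * t * np.log(np.clip(kernel, tiny, None))


# Toy data: three points on a line and an exact kernel exp(-d^2 / (4 t)).
x = np.array([0.0, 0.5, 2.0])
sq_dist = (x[:, None] - x[None, :]) ** 2
t = 0.1
kernel = np.exp(-sq_dist / (4.0 * t))

# With the same t on both sides, the squared distances are recovered.
cost = cost_from_kernel(kernel, t)
```

In this toy setting the recovered `cost` matches `sq_dist` only because the same `t` is used to build and to invert the kernel, which is why the attribute has to stay on the class.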


  @classmethod
  def from_graph(
      cls,
      G: Array_g,
      t: Optional[float] = 1e-3,
      eigval: Optional[jnp.ndarray] = None,
      order: int = 100,
      directed: bool = False,
      normalize: bool = False,
      **kwargs: Any
  ) -> "Geodesic":
    r"""Construct a Geodesic geometry from an adjacency matrix.

    Args:
      G: Adjacency matrix.
      t: Time parameter for approximating the geodesic exponential kernel.
        If `None`, it defaults to :math:`\frac{1}{|E|} \sum_{(u, v) \in E}
        \text{weight}(u, v)` :cite:`crane:13`. In this case, the ``graph``
        must be specified and the edge weights are assumed to be positive.
      eigval: Largest eigenvalue of the Laplacian. If `None`, it's computed
        at initialization.
      order: Max order of Chebyshev polynomials.
      directed: Whether the ``graph`` is directed. If `True`, it's made
        undirected as :math:`G + G^T`. This parameter is ignored when passing
        the Laplacian directly, assumed to be symmetric.
      normalize: Whether to normalize the Laplacian as
        :math:`L^{sym} = \left(D^+\right)^{\frac{1}{2}} L
        \left(D^+\right)^{\frac{1}{2}}`, where :math:`L` is the
        non-normalized Laplacian and :math:`D` is the degree matrix.
      kwargs: Keyword arguments for the Geodesic class.

    Returns:
      The Geodesic geometry.
    """
    assert G.shape[0] == G.shape[1], G.shape

    if directed:
      G = G + G.T

    degree = jnp.sum(G, axis=1)
    laplacian = jnp.diag(degree) - G

    if normalize:
      inv_sqrt_deg = jnp.diag(
          jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0)
      )
      laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg

    eigval = compute_largest_eigenvalue(
        laplacian, k=1
    ) if eigval is None else eigval
    scaled_laplacian = rescale_laplacian(laplacian, eigval)

    if t is None:
      t = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2

    # Compute the coeffs of the Chebyshev pols approx using Bessel functs.
    chebyshev_coeffs = compute_chebychev_coeff_all(
        eigval, t, order, laplacian.dtype
    )

    return cls(
        scaled_laplacian=scaled_laplacian,
        eigval=eigval,
        chebyshev_coeffs=chebyshev_coeffs,
        t=t,
        order=order,
        **kwargs
    )

  def apply_kernel(
      self,
      scaling: jnp.ndarray,
      eps: Optional[float] = None,
      axis: int = 0,
  ) -> jnp.ndarray:
    r"""Apply :attr:`kernel_matrix` on positive scaling vector.

    Args:
      scaling: Scaling to apply the kernel to.
      eps: passed for consistency, not used yet.
      axis: passed for consistency, not used yet.

    Returns:
      Kernel applied to ``scaling``.
    """
    return expm_multiply(
        self.scaled_laplacian, scaling, self.chebyshev_coeffs, self.eigval
    )

  @property
  def kernel_matrix(self) -> jnp.ndarray:  # noqa: D102
    n, _ = self.shape
    kernel = self.apply_kernel(jnp.eye(n))
    # check if the kernel is symmetric
    if jnp.any(kernel != kernel.T):
      kernel = (kernel + kernel.T) / 2.0
    return kernel

  @property
  def cost_matrix(self) -> jnp.ndarray:  # noqa: D102
    # Calculate the cost matrix using the formula (5) from the main reference
    return -4 * self.t * mu.safe_log(self.kernel_matrix)

  @property
  def shape(self) -> Tuple[int, int]:  # noqa: D102
    return self.scaled_laplacian.shape

  @property
  def is_symmetric(self) -> bool:  # noqa: D102
    return True

  @property
  def dtype(self) -> jnp.dtype:  # noqa: D102
    return self.scaled_laplacian.dtype

  def transport_from_potentials(
      self, f: jnp.ndarray, g: jnp.ndarray
  ) -> jnp.ndarray:
    """Not implemented."""
    raise ValueError("Not implemented.")

  def apply_transport_from_potentials(
      self,
      f: jnp.ndarray,
      g: jnp.ndarray,
      vec: jnp.ndarray,
      axis: int = 0
  ) -> jnp.ndarray:
    """Since applying from potentials is not feasible in grids, use scalings."""
    u, v = self.scaling_from_potential(f), self.scaling_from_potential(g)
    return self.apply_transport_from_scalings(u, v, vec, axis=axis)

  def marginal_from_potentials(
      self,
      f: jnp.ndarray,
      g: jnp.ndarray,
      axis: int = 0,
  ) -> jnp.ndarray:
    """Not implemented."""
    raise ValueError("Not implemented.")

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
    return [
        self.scaled_laplacian,
        self.eigval,
        self.chebyshev_coeffs,
        self.t,
        self.order,
    ], {}

  @classmethod
  def tree_unflatten(  # noqa: D102
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "Geodesic":
    return cls(*children, **aux_data)


def compute_largest_eigenvalue(laplacian_matrix, k, rng=None):
  # Compute the largest eigenvalue of the Laplacian matrix.
  if rng is None:
    rng = utils.default_prng_key(rng)
  n = laplacian_matrix.shape[0]
  # Generate random initial directions for eigenvalue computation
  initial_dirs = jax.random.normal(rng, (n, k))

  # Create a sparse matrix-vector product function using sparsify
  # This function multiplies the sparse laplacian_matrix with a vector
  lapl_vector_product = jesp.sparsify(lambda v: laplacian_matrix @ v)

  # Compute eigenvalues using the sparse matrix-vector product
  eigvals, _, _ = jesp.linalg.lobpcg_standard(
      lapl_vector_product,
      initial_dirs,
      m=100,
  )
  return jnp.max(eigvals)


def rescale_laplacian(
    laplacian_matrix: jnp.ndarray, largest_eigenvalue: jnp.ndarray
) -> jnp.ndarray:
  # Rescale the Laplacian matrix.
  return jax.lax.cond((largest_eigenvalue > 2),
                      lambda l: 2 * l / largest_eigenvalue, lambda l: l,
                      laplacian_matrix)


def expm_multiply(L, X, coeff, phi):

  def body(carry, c):
    T0, T1, Y = carry
    T2 = (2 / phi) * L @ T1 - 2 * T1 - T0
    Y = Y + c * T2
    return (T1, T2, Y), None

  T0 = X
  Y = 0.5 * coeff[0] * T0
  T1 = (1 / phi) * L @ X - T0
  Y = Y + coeff[1] * T1

  initial_state = (T0, T1, Y)
  carry, _ = jax.lax.scan(body, initial_state, coeff[2:])
  _, _, Y = carry
  return Y


def compute_chebychev_coeff_all(phi, tau, K, dtype=jnp.float32):
  """Jax wrapper to compute the K+1 Chebychev coefficients."""
  result_shape_dtype = jax.ShapeDtypeStruct(
      shape=(K + 1,),
      dtype=dtype,
  )

  chebychev_coeff = lambda phi, tau, K: (
      2 * ive(np.arange(0, K + 1), -tau * phi)
  ).astype(dtype)

  return jax.pure_callback(chebychev_coeff, result_shape_dtype, phi, tau, K)
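
As a reading aid for `expm_multiply` and `compute_chebychev_coeff_all`, the sketch below shows the general technique in plain NumPy/SciPy: a Chebyshev expansion of the heat kernel `exp(-t L)` with Bessel-function coefficients obtained from `scipy.special.ive`. It is an illustration under stated assumptions, not the PR's implementation: the function name `heat_kernel_chebyshev` is hypothetical, the Laplacian is left unscaled, and `phi` is taken to be half of the largest eigenvalue so that `L / phi - I` has its spectrum in `[-1, 1]`. The scaling convention used in the diff may differ.

```python
import numpy as np
from scipy.special import ive


def heat_kernel_chebyshev(laplacian, x, t, order=50):
  """Hypothetical sketch: approximate exp(-t * laplacian) @ x."""
  # For a symmetric PSD Laplacian the spectrum lies in [0, lam_max].
  lam_max = np.linalg.eigvalsh(laplacian)[-1]
  phi = lam_max / 2.0  # assumption: half of the largest eigenvalue
  # Coefficients 2 * exp(-t * phi) * I_k(-t * phi) for k = 0..order.
  coeffs = 2.0 * ive(np.arange(order + 1), -t * phi)
  # Three-term recurrence in M = laplacian / phi - I (spectrum in [-1, 1]).
  t0 = x
  t1 = laplacian @ x / phi - x
  y = 0.5 * coeffs[0] * t0 + coeffs[1] * t1
  for c in coeffs[2:]:
    t2 = 2.0 * (laplacian @ t1 / phi - t1) - t0
    y = y + c * t2
    t0, t1 = t1, t2
  return y


# Toy example: unweighted path graph on 5 nodes.
adj = np.diag(np.ones(4), 1)
adj = adj + adj.T
lap = np.diag(adj.sum(axis=1)) - adj
t = 0.5
approx = heat_kernel_chebyshev(lap, np.eye(5), t)

# Reference heat kernel from the eigendecomposition of the Laplacian.
w, v = np.linalg.eigh(lap)
exact = (v * np.exp(-t * w)) @ v.T
```

The recurrence only needs matrix-vector products with the Laplacian, which is what makes the approach usable with sparse graphs; the dense eigendecomposition here serves purely as a reference for the toy example.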
6 changes: 5 additions & 1 deletion src/ott/types.py
@@ -11,14 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Protocol
from typing import Protocol, Union

import jax.experimental.sparse as jesp
import jax.numpy as jnp

__all__ = ["Transport"]

# TODO(michalk8): introduce additional types here

# Either a dense or sparse array.
Array_g = Union[jnp.ndarray, jesp.BCOO]


class Transport(Protocol):
  """Interface for the solution of a transport problem.