diff --git a/docs/geometry.rst b/docs/geometry.rst index adc0a3880..5b0d31e88 100644 --- a/docs/geometry.rst +++ b/docs/geometry.rst @@ -45,6 +45,7 @@ Geometries pointcloud.PointCloud grid.Grid graph.Graph + geodesic.Geodesic low_rank.LRCGeometry epsilon_scheduler.Epsilon diff --git a/docs/references.bib b/docs/references.bib index 35ba274ba..3060bf191 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -805,3 +805,10 @@ @misc{klein:23 title = {Learning Costs for Structured Monge Displacements}, year = {2023}, } + +@misc{huguet:2022, + title={Geodesic Sinkhorn: optimal transport for high-dimensional datasets}, + author={Huguet, Guillaume and Tong, Alexander and Zapatero, Mar{\'\i}a Ramos and Wolf, Guy and Krishnaswamy, Smita}, + journal={arXiv preprint arXiv:2211.00805}, + year={2022} +} diff --git a/docs/spelling/technical.txt b/docs/spelling/technical.txt index 4030a8292..63a7851fd 100644 --- a/docs/spelling/technical.txt +++ b/docs/spelling/technical.txt @@ -1,6 +1,7 @@ Barycenters Brenier Bures +Chebyshev Cholesky DTW Danskin @@ -24,6 +25,7 @@ SGD Schur Seidel Sinkhorn +UNet Unbalancedness Wasserstein adaptively @@ -93,6 +95,7 @@ parameterization parameterizing piecewise pluripotent +polynomials positivity postfix potentials diff --git a/src/ott/geometry/__init__.py b/src/ott/geometry/__init__.py index 5890e0935..c16ba687a 100644 --- a/src/ott/geometry/__init__.py +++ b/src/ott/geometry/__init__.py @@ -14,6 +14,7 @@ from . import ( costs, epsilon_scheduler, + geodesic, geometry, graph, grid, diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py new file mode 100644 index 000000000..f07cce777 --- /dev/null +++ b/src/ott/geometry/geodesic.py @@ -0,0 +1,274 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Sequence, Tuple + +import jax +import jax.experimental.sparse as jesp +import jax.numpy as jnp +import numpy as np +from scipy.special import ive + +from ott.geometry import geometry +from ott.math import utils as mu +from ott.types import Array_g +from ott.utils import default_prng_key + +__all__ = ["Geodesic"] + + +@jax.tree_util.register_pytree_node_class +class Geodesic(geometry.Geometry): + r"""Graph distance approximation using heat kernel :cite:`huguet:2022`. + + .. important:: + This constructor is not meant to be called by the user, + please use the :meth:`from_graph` method instead. + + Approximates the heat kernel using + `Chebyshev polynomials `_ + of the first kind of max order ``order``, which for small ``t`` + approximates the geodesic exponential kernel :math:`e^{\frac{-d(x, y)^2}{t}}`. + + Args: + scaled_laplacian: The Laplacian scaled by the largest eigenvalue. + eigval: The largest eigenvalue of the Laplacian. + chebyshev_coeffs: Coefficients of the Chebyshev polynomials. + t: Time parameter for the heat kernel. + kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`. + """ + + def __init__( + self, + scaled_laplacian: Array_g, + eigval: jnp.ndarray, + chebyshev_coeffs: jnp.ndarray, + t: float = 1e-3, + **kwargs: Any + ): + super().__init__(epsilon=1., **kwargs) + self.scaled_laplacian = scaled_laplacian + self.eigval = eigval + self.chebyshev_coeffs = chebyshev_coeffs + self.t = t + + @classmethod + def from_graph( + cls, + G: Array_g, + t: Optional[float] = 1e-3, + eigval: Optional[jnp.ndarray] = None, + order: int = 100, + directed: bool = False, + normalize: bool = False, + **kwargs: Any + ) -> "Geodesic": + r"""Construct a Geodesic geometry from an adjacency matrix. + + Args: + G: Adjacency matrix. + t: Time parameter for approximating the geodesic exponential kernel. + If `None`, it defaults to :math:`\frac{1}{|E|} \sum_{(u, v) \in E} + \text{weight}(u, v)` :cite:`crane:13`. In this case, the ``graph`` + must be specified and the edge weights are assumed to be positive. + eigval: Largest eigenvalue of the Laplacian. If `None`, it's computed + at initialization. + order: Max order of Chebyshev polynomials. + directed: Whether the ``graph`` is directed. If `True`, it's made + undirected as :math:`G + G^T`. This parameter is ignored when passing + the Laplacian directly, assumed to be symmetric. + normalize: Whether to normalize the Laplacian as + :math:`L^{sym} = \left(D^+\right)^{\frac{1}{2}} L + \left(D^+\right)^{\frac{1}{2}}`, where :math:`L` is the + non-normalized Laplacian and :math:`D` is the degree matrix. + kwargs: Keyword arguments for :class:`~ott.geometry.geodesic.Geodesic`. + + Returns: + The Geodesic geometry. + """ + assert G.shape[0] == G.shape[1], G.shape + + if directed: + G = G + G.T + + degree = jnp.sum(G, axis=1) + laplacian = jnp.diag(degree) - G + + if normalize: + inv_sqrt_deg = jnp.diag( + jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0) + ) + laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg + + eigval = compute_largest_eigenvalue(laplacian) if eigval is None else eigval + + scaled_laplacian = jax.lax.cond((eigval > 2.0), lambda l: 2.0 * l / eigval, + lambda l: l, laplacian) + + if t is None: + t = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2.0 + + # Compute the coeffs of the Chebyshev pols approx using Bessel functs. + chebyshev_coeffs = compute_chebychev_coeff_all( + eigval, t, order, laplacian.dtype + ) + + return cls( + scaled_laplacian=scaled_laplacian, + eigval=eigval, + chebyshev_coeffs=chebyshev_coeffs, + t=t, + **kwargs + ) + + def apply_kernel( + self, + scaling: jnp.ndarray, + eps: Optional[float] = None, + axis: int = 0, + ) -> jnp.ndarray: + r"""Apply :attr:`kernel_matrix` on positive scaling vector. + + Args: + scaling: Scaling to apply the kernel to. + eps: passed for consistency, not used yet. + axis: passed for consistency, not used yet. + + Returns: + Kernel applied to ``scaling``. + """ + return expm_multiply( + self.scaled_laplacian, scaling, self.chebyshev_coeffs, self.eigval + ) + + @property + def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 + n, _ = self.shape + kernel = self.apply_kernel(jnp.eye(n)) + return jax.lax.cond( + jnp.allclose(kernel, kernel.T, atol=1e-8, rtol=1e-8), lambda x: x, + lambda x: (x + x.T) / 2.0, kernel + ) + + @property + def cost_matrix(self) -> jnp.ndarray: # noqa: D102 + # Calculate the cost matrix using the formula (5) from the main reference + return -4.0 * self.t * mu.safe_log(self.kernel_matrix) + + @property + def shape(self) -> Tuple[int, int]: # noqa: D102 + return self.scaled_laplacian.shape + + @property + def is_symmetric(self) -> bool: # noqa: D102 + return True + + @property + def dtype(self) -> jnp.dtype: # noqa: D102 + return self.scaled_laplacian.dtype + + def transport_from_potentials( + self, f: jnp.ndarray, g: jnp.ndarray + ) -> jnp.ndarray: + """Not implemented.""" + raise ValueError("Not implemented.") + + def apply_transport_from_potentials( + self, + f: jnp.ndarray, + g: jnp.ndarray, + vec: jnp.ndarray, + axis: int = 0 + ) -> jnp.ndarray: + """Since applying from potentials is not feasible in grids, use scalings.""" + u, v = self.scaling_from_potential(f), self.scaling_from_potential(g) + return self.apply_transport_from_scalings(u, v, vec, axis=axis) + + def marginal_from_potentials( + self, + f: jnp.ndarray, + g: jnp.ndarray, + axis: int = 0, + ) -> jnp.ndarray: + """Not implemented.""" + raise ValueError("Not implemented.") + + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 + return [ + self.scaled_laplacian, + self.eigval, + self.chebyshev_coeffs, + self.t, + ], {} + + @classmethod + def tree_unflatten( # noqa: D102 + cls, aux_data: Dict[str, Any], children: Sequence[Any] + ) -> "Geodesic": + return cls(*children, **aux_data) + + +def compute_largest_eigenvalue( + laplacian_matrix: jnp.ndarray, rng: Optional[jax.Array] = None +) -> float: + # Compute the largest eigenvalue of the Laplacian matrix. + n = laplacian_matrix.shape[0] + # Generate random initial directions for eigenvalue computation + initial_dirs = jax.random.normal(default_prng_key(rng), (n, 1)) + + # Create a sparse matrix-vector product function using sparsify + # This function multiplies the sparse laplacian_matrix with a vector + lapl_vector_product = jesp.sparsify(lambda v: laplacian_matrix @ v) + + # Compute eigenvalues using the sparse matrix-vector product + eigvals, _, _ = jesp.linalg.lobpcg_standard( + lapl_vector_product, + initial_dirs, + ) + return jnp.max(eigvals) + + +def expm_multiply( + L: jnp.ndarray, X: jnp.ndarray, coeff: jnp.ndarray, eigval: float +) -> jnp.ndarray: + + def body(carry, c): + T0, T1, Y = carry + T2 = (2.0 / eigval) * L @ T1 - 2.0 * T1 - T0 + Y = Y + c * T2 + return (T1, T2, Y), None + + T0 = X + Y = 0.5 * coeff[0] * T0 + T1 = (1.0 / eigval) * L @ X - T0 + Y = Y + coeff[1] * T1 + + initial_state = (T0, T1, Y) + (_, _, Y), _ = jax.lax.scan(body, initial_state, coeff[2:]) + return Y + + +def compute_chebychev_coeff_all( + eigval: float, tau: float, K: int, dtype: np.dtype +) -> jnp.ndarray: + """Jax wrapper to compute the K+1 Chebychev coefficients.""" + result_shape_dtype = jax.ShapeDtypeStruct( + shape=(K + 1,), + dtype=dtype, + ) + + chebychev_coeff = lambda eigval, tau, K: ( + 2.0 * ive(np.arange(0, K + 1), -tau * eigval) + ).astype(dtype) + + return jax.pure_callback(chebychev_coeff, result_shape_dtype, eigval, tau, K) diff --git a/src/ott/geometry/graph.py b/src/ott/geometry/graph.py index c7dac0c99..5109929f0 100644 --- a/src/ott/geometry/graph.py +++ b/src/ott/geometry/graph.py @@ -189,9 +189,12 @@ def body_fn( def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 n, _ = self.shape kernel = self.apply_kernel(jnp.eye(n)) - # force symmetry because of numerical imprecision + # Symmetrize the kernel if needed. Numerical imprecision # happens when `numerical_scheme='backward_euler'` and small `t` - return (kernel + kernel.T) * 0.5 + return jax.lax.cond( + jnp.allclose(kernel, kernel.T, atol=1e-8, rtol=1e-8), lambda x: x, + lambda x: (x + x.T) / 2.0, kernel + ) @property def cost_matrix(self) -> jnp.ndarray: # noqa: D102 diff --git a/src/ott/types.py b/src/ott/types.py index 7a4c88716..dc7e552d9 100644 --- a/src/ott/types.py +++ b/src/ott/types.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Protocol +from typing import Protocol, Union +import jax.experimental.sparse as jesp import jax.numpy as jnp __all__ = ["Transport"] # TODO(michalk8): introduce additional types here +# Either a dense or sparse array. +Array_g = Union[jnp.ndarray, jesp.BCOO] + class Transport(Protocol): """Interface for the solution of a transport problem. diff --git a/tests/geometry/geo_test.py b/tests/geometry/geo_test.py new file mode 100644 index 000000000..c54261e3a --- /dev/null +++ b/tests/geometry/geo_test.py @@ -0,0 +1,244 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union + +import jax +import jax.numpy as jnp +import networkx as nx +import numpy as np +import pytest +from networkx.algorithms import shortest_paths +from networkx.generators import balanced_tree, random_graphs +from ott.geometry import geodesic, geometry, graph +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn + + +def random_graph( + n: int, + p: float = 0.3, + seed: Optional[int] = 0, + *, + return_laplacian: bool = False, + directed: bool = False, +) -> jnp.ndarray: + G = random_graphs.fast_gnp_random_graph(n, p, seed=seed, directed=directed) + if not directed: + assert nx.is_connected(G), "Generated graph is not connected." + + rng = np.random.RandomState(seed) + for _, _, w in G.edges(data=True): + w["weight"] = rng.uniform(0, 10) + + G = nx.linalg.laplacian_matrix( + G + ) if return_laplacian else nx.linalg.adjacency_matrix(G) + + return jnp.asarray(G.toarray()) + + +def gt_geometry( + G: Union[jnp.ndarray, nx.Graph], + *, + epsilon: float = 1e-2 +) -> geometry.Geometry: + if not isinstance(G, nx.Graph): + G = nx.from_numpy_array(np.asarray(G)) + + n = len(G) + cost = np.zeros((n, n)) + + path = dict( + shortest_paths.all_pairs_bellman_ford_path_length(G, weight="weight") + ) + for i, src in enumerate(G.nodes): + for j, tgt in enumerate(G.nodes): + cost[i, j] = path[src][tgt] ** 2 + + cost = jnp.asarray(cost) + kernel = jnp.asarray(np.exp(-cost / epsilon)) + return geometry.Geometry(cost_matrix=cost, kernel_matrix=kernel, epsilon=1.) + + +class TestGeodesic: + + def test_kernel_is_symmetric_positive_definite( + self, + rng: jax.Array, + ): + n, tol = 100, 0.02 + t = 1 + order = 50 + x = jax.random.normal(rng, (n,)) + G = random_graph(n) + geom = geodesic.Geodesic.from_graph(G, t=t, order=order) + + kernel = geom.kernel_matrix + + vec0 = geom.apply_kernel(x, axis=0) + vec1 = geom.apply_kernel(x, axis=1) + vec_direct0 = geom.kernel_matrix.T @ x + vec_direct1 = geom.kernel_matrix @ x + + # we symmetrize the kernel explicitly when materializing it, because + # numerical errors can make it non-symmetric. + np.testing.assert_allclose(kernel, kernel.T, rtol=tol, atol=tol) + eigenvalues = jnp.linalg.eigvals(kernel) + neg_eigenvalues = eigenvalues[eigenvalues < 0] + # check that the negative eigenvalues are all very small + np.testing.assert_array_less(jnp.abs(neg_eigenvalues), 1e-3) + # internally, the axis is ignored because the kernel is symmetric + np.testing.assert_array_equal(vec0, vec1) + np.testing.assert_array_equal(vec_direct0, vec_direct1) + + np.testing.assert_allclose(vec0, vec_direct0, rtol=tol, atol=tol) + np.testing.assert_allclose(vec1, vec_direct1, rtol=tol, atol=tol) + + # compute the distance matrix and check that it is symmetric + cost_matrix = geom.cost_matrix + np.testing.assert_array_equal(cost_matrix, cost_matrix.T) + # and all dissimilarities are positive + np.testing.assert_array_less(0, cost_matrix) + + @pytest.mark.fast.with_args( + order=[50, 100, 200], + t=[1e-4, 1e-5], + ) + def test_approximates_ground_truth(self, t: Optional[float], order: int): + tol = 1e-2 + G = nx.linalg.adjacency_matrix(balanced_tree(r=2, h=5)) + G = jnp.asarray(G.toarray(), dtype=float) + eye = jnp.eye(G.shape[0]) + eps = jnp.finfo(eye.dtype).tiny + + gt_geom = gt_geometry(G, epsilon=eps) + + geo = geodesic.Geodesic.from_graph(G, t=t, order=order) + + np.testing.assert_allclose( + gt_geom.kernel_matrix, geo.kernel_matrix, rtol=tol, atol=tol + ) + + @pytest.mark.parametrize(("jit", "normalize"), [(False, True), (True, False)]) + def test_directed_graph(self, jit: bool, normalize: bool): + + def create_graph(G: jnp.ndarray) -> graph.Graph: + return geodesic.Geodesic.from_graph(G, directed=True, normalize=normalize) + + G = random_graph(16, p=0.25, directed=True) + create_fn = jax.jit(create_graph) if jit else create_graph + geom = create_fn(G) + + with pytest.raises(AssertionError): + np.testing.assert_allclose(G, G.T) + + L = geom.scaled_laplacian + + with pytest.raises(AssertionError): + # make sure that original graph was directed + np.testing.assert_allclose(G, G.T, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(L, L.T, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize("directed", [False, True]) + @pytest.mark.parametrize("normalize", [False, True]) + def test_normalize_laplacian(self, directed: bool, normalize: bool): + + def laplacian(G: jnp.ndarray) -> jnp.ndarray: + if directed: + G = G + G.T + + data = jnp.sum(G, axis=1) + lap = jnp.diag(data) - G + if normalize: + inv_sqrt_deg = jnp.diag( + jnp.where(data > 0.0, 1.0 / jnp.sqrt(data), 0.0) + ) + return inv_sqrt_deg @ lap @ inv_sqrt_deg + return lap + + G = random_graph(51, p=0.35, directed=directed) + geom = geodesic.Geodesic.from_graph( + G, directed=directed, normalize=normalize + ) + + expected = laplacian(G) + eigenvalues = jnp.linalg.eigvals(expected) + eigval = jnp.max(eigenvalues) + #rescale the laplacian + expected = 2 * expected / eigval if eigval > 2 else expected + + actual = geom.scaled_laplacian + + np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) + + @pytest.mark.fast.with_args(jit=[False, True], only_fast=0) + def test_geo_sinkhorn(self, rng: jax.Array, jit: bool): + + def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: + solver = sinkhorn.Sinkhorn(lse_mode=False) + problem = linear_problem.LinearProblem(geom) + return solver(problem) + + n, eps, tol = 11, 1e-5, 1e-3 + G = random_graph(n, p=0.35) + x = jax.random.normal(rng, (n,)) + + gt_geom = gt_geometry(G, epsilon=eps) + graph_geom = geodesic.Geodesic.from_graph(G, t=eps) + + fn = jax.jit(callback) if jit else callback + gt_out = fn(gt_geom) + graph_out = fn(graph_geom) + + assert gt_out.converged + assert graph_out.converged + np.testing.assert_allclose( + graph_out.reg_ot_cost, gt_out.reg_ot_cost, rtol=tol, atol=tol + ) + np.testing.assert_allclose(graph_out.f, gt_out.f, rtol=tol, atol=tol) + np.testing.assert_allclose(graph_out.g, gt_out.g, rtol=tol, atol=tol) + + for axis in [0, 1]: + y_gt = gt_out.apply(x, axis=axis) + y_out = graph_out.apply(x, axis=axis) + # note the high tolerance + np.testing.assert_allclose(y_gt, y_out, rtol=5e-1, atol=5e-1) + + np.testing.assert_allclose( + gt_out.matrix, graph_out.matrix, rtol=1e-1, atol=1e-1 + ) + + def test_geometry_differentiability(self, rng: jax.Array): + + def callback(geom) -> float: + + solver = sinkhorn.Sinkhorn(lse_mode=False) + problem = linear_problem.LinearProblem(geom) + + return solver(problem).reg_ot_cost + + eps = 1e-3 + G = random_graph(20, p=0.5) + geom = geodesic.Geodesic.from_graph(G, t=1.) + + v_w = jax.random.normal(rng, shape=G.shape) + v_w = (v_w / jnp.linalg.norm(v_w, axis=-1, keepdims=True)) * eps + + grad_sl = jax.grad(callback)(geom).scaled_laplacian + geom__finite_right = geodesic.Geodesic.from_graph(G + v_w, t=1.) + geom__finite_left = geodesic.Geodesic.from_graph(G - v_w, t=1.) + + expected = callback(geom__finite_right) - callback(geom__finite_left) + actual = 2 * jnp.vdot(v_w, grad_sl) + np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4) diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py index 18e683bb4..7ba5b2df7 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal, Optional, Tuple +from typing import Literal, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -50,12 +50,16 @@ def random_graph( return jnp.asarray(G.toarray()) -def gt_geometry(G: jnp.ndarray, *, epsilon: float = 1e-2) -> geometry.Geometry: +def gt_geometry( + G: Union[jnp.ndarray, nx.Graph], + *, + epsilon: float = 1e-2 +) -> geometry.Geometry: if not isinstance(G, nx.Graph): G = nx.from_numpy_array(np.asarray(G)) n = len(G) - cost = np.zeros((n, n), dtype=float) + cost = np.zeros((n, n)) path = dict( shortest_paths.all_pairs_bellman_ford_path_length(G, weight="weight")