From d3d840f3e80c1e0bcd141f17325deb98542f63ff Mon Sep 17 00:00:00 2001 From: Diego Baptista Theuerkauf <34717973+diegoabt@users.noreply.github.com> Date: Tue, 5 Sep 2023 13:43:48 +0200 Subject: [PATCH 01/44] Create `geodesic` module This module contains the implementation of the Geodesic Sinkhorn algorithm [1]. [1] Huguet, G., Tong, A., Zapatero, M. R., Wolf, G., & Krishnaswamy, S. (2022). Geodesic Sinkhorn: optimal transport for high-dimensional datasets. arXiv preprint arXiv:2211.00805. --- src/ott/geometry/geodesic.py | 321 +++++++++++++++++++++++++++++++++++ 1 file changed, 321 insertions(+) create mode 100644 src/ott/geometry/geodesic.py diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py new file mode 100644 index 000000000..812f582b7 --- /dev/null +++ b/src/ott/geometry/geodesic.py @@ -0,0 +1,321 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Literal, Optional, Sequence, Tuple, List + +import jax +import jax.numpy as jnp +from jax.experimental.sparse.linalg import lobpcg_standard + +import numpy as np +from scipy.special import ive + +from ott.geometry import geometry +from ott.math import utils as mu + + +__all__ = ["Geodesic"] + + +@jax.tree_util.register_pytree_node_class +class Geodesic(geometry.Geometry): #TODO: Rewrite docstring + r"""Graph distance approximation using heat kernel :cite:`heitz:21,crane:13`. + + Approximates the heat kernel for large ``n_steps``, which for small ``t`` + approximates the geodesic exponential kernel :math:`e^{\frac{-d(x, y)^2}{t}}`. + + Args: + laplacian: Symmetric graph Laplacian. The check for symmetry is **NOT** + performed. See also :meth:`from_graph`. + n_steps: Maximum number of steps used to approximate the heat kernel. + numerical_scheme: Numerical scheme used to solve the heat diffusion. + normalize: Whether to normalize the Laplacian as + :math:`L^{sym} = \left(D^+\right)^{\frac{1}{2}} L + \left(D^+\right)^{\frac{1}{2}}`, where :math:`L` is the + non-normalized Laplacian and :math:`D` is the degree matrix. + tol: Relative tolerance with respect to the Hilbert metric, see + :cite:`peyre:19`, Remark 4.12. Used when iteratively updating scalings. + If negative, this option is ignored and only ``n_steps`` is used. + kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`. + """ + + def __init__( + self, + laplacian: jnp.ndarray, + t: float = 1e-3, + n_steps: int = 100, + tol: float = -1.0, + **kwargs: Any + ): + super().__init__(epsilon=1., **kwargs) + self.laplacian = laplacian + self.t = t + self.n_steps = n_steps + self.tol = tol + + @classmethod + def from_graph( + cls, + G: jnp.ndarray, + t: Optional[float] = 1e-3, + directed: bool = False, + normalize: bool = False, + **kwargs: Any + ) -> "Geodesic": + r"""Construct :class:`~ott.geometry.graph.Graph` from an adjacency matrix. + + Args: + G: Adjacency matrix. + t: Constant used when approximating the geodesic exponential kernel. + If `None`, use :math:`\frac{1}{|E|} \sum_{(u, v) \in E} weight(u, v)` + :cite:`crane:13`. In this case, the ``graph`` must be specified + and the edge weights are all assumed to be positive. + directed: Whether the ``graph`` is directed. If not, it will be made + undirected as :math:`G + G^T`. This parameter is ignored when directly + passing the Laplacian, which is assumed to be symmetric. + normalize: Whether to normalize the Laplacian as + :math:`L^{sym} = \left(D^+\right)^{\frac{1}{2}} L + \left(D^+\right)^{\frac{1}{2}}`, where :math:`L` is the + non-normalized Laplacian and :math:`D` is the degree matrix. + kwargs: Keyword arguments for :class:`~ott.geometry.graph.Graph`. + + Returns: + The graph geometry. + """ + assert G.shape[0] == G.shape[1], G.shape + print(G.shape) + if directed: + G = G + G.T + + degree = jnp.sum(G, axis=1) + laplacian = jnp.diag(degree) - G + + if normalize: + inv_sqrt_deg = jnp.diag( + jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0) + ) + laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg + + if t is None: + t = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2 + + return cls(laplacian, t=t, **kwargs) + + + + def apply_kernel( + self, + scaling: jnp.ndarray, + eps: Optional[float] = None, + axis: int = 0, + ) -> jnp.ndarray: + r"""Apply :attr:`kernel_matrix` on positive scaling vector. + + Args: + scaling: Scaling to apply the kernel to. + eps: passed for consistency, not used yet. + axis: passed for consistency, not used yet. + + Returns: + Kernel applied to ``scaling``. + """ + + def compute_laplacian(adjacency_matrix: jnp.ndarray) -> jnp.ndarray: + """ + Compute the Laplacian matrix from the adjacency matrix. + + Args: + adjacency_matrix: An (n, n) array representing the adjacency matrix of a graph. + + Returns: + An (n, n) array representing the Laplacian matrix. + """ + degree_matrix = jnp.diag(jnp.sum(adjacency_matrix, axis=0)) + laplacian_matrix = degree_matrix - adjacency_matrix + return laplacian_matrix + + def compute_largest_eigenvalue(laplacian_matrix, k): + """ + Compute the largest eigenvalue of the Laplacian matrix. + + Args: + laplacian_matrix: An (n, n) array representing the Laplacian matrix. + k: Number of eigenvalues/vectors to compute. + + Returns: + The largest eigenvalue of the Laplacian matrix. + """ + n = laplacian_matrix.shape[0] + initial_directions = jax.random.normal(jax.random.PRNGKey(0), (n, k)) + # Convert the Laplacian matrix to a dense array + #laplacian_array = laplacian_matrix.toarray() + eigvals, _, _ = lobpcg_standard(laplacian_matrix, initial_directions, m=k) + largest_eigenvalue = np.max(eigvals) + return largest_eigenvalue + + def rescale_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: + """ + Rescale the Laplacian matrix. + + Args: + laplacian_matrix: An (n, n) array representing the Laplacian matrix. + + Returns: + The rescaled Laplacian matrix. + """ + largest_eigenvalue = compute_largest_eigenvalue(laplacian_matrix, k=1) + if largest_eigenvalue > 2: + rescaled_laplacian = laplacian_matrix.copy() + rescaled_laplacian /= largest_eigenvalue + laplacian_matrix = 2 * rescaled_laplacian + return laplacian_matrix + + def define_scaled_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: + """ + Define the scaled Laplacian matrix. + + Args: + laplacian_matrix: An (n, n) array representing the Laplacian matrix. + + Returns: + The scaled Laplacian matrix. + """ + n = laplacian_matrix.shape[0] + identity = jnp.eye(n) + scaled_laplacian = laplacian_matrix - identity + return scaled_laplacian + + def chebyshev_coefficients(t: float, max_order: int) -> List[float]: + """ + Compute the coefficients of the Chebyshev polynomial approximation using Bessel functions. + + Args: + t: Time parameter. + max_order: Maximum order of the Chebyshev polynomial approximation. + + Returns: + A list of coefficients. + """ + return (2 * ive(jnp.arange(0, max_order + 1), -t)).tolist() + + def compute_chebyshev_approximation( + x: jnp.ndarray, coeffs: List[float] + ) -> jnp.ndarray: + """ + Compute the Chebyshev polynomial approximation for the given input and coefficients. + + Args: + x: Input to evaluate the polynomial at. + coeffs: List of Chebyshev polynomial coefficients. + + Returns: + The Chebyshev polynomial approximation evaluated at x. + """ + return self.apply_kernel(x, coeffs) + + #laplacian_matrix = compute_laplacian(self.adjacency_matrix) + rescaled_laplacian = rescale_laplacian(self.laplacian) + scaled_laplacian = define_scaled_laplacian(rescaled_laplacian) + chebyshev_coeffs = chebyshev_coefficients(self.t, self.n_steps) + + laplacian_times_signal = scaled_laplacian.dot(scaling) # Apply the kernel + + chebyshev_approx_on_signal = compute_chebyshev_approximation( + laplacian_times_signal, chebyshev_coeffs) + + return chebyshev_approx_on_signal + + @property + def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 + n, _ = self.shape + kernel = self.apply_kernel(jnp.eye(n)) + # force symmetry because of numerical imprecision + # happens when `numerical_scheme='backward_euler'` and small `t` + return (kernel + kernel.T) * 0.5 + + @property + def cost_matrix(self) -> jnp.ndarray: # noqa: D102 + return -self.t * mu.safe_log(self.kernel_matrix) + + @property + def _scale(self) -> float: + """Constant used to scale the Laplacian.""" + if self.numerical_scheme == "backward_euler": + return self.t / (4. * self.n_steps) + if self.numerical_scheme == "crank_nicolson": + return self.t / (2. * self.n_steps) + raise NotImplementedError( + f"Numerical scheme `{self.numerical_scheme}` is not implemented." + ) + + @property + def _scaled_laplacian(self) -> jnp.ndarray: + """Laplacian scaled by a constant, depending on the numerical scheme.""" + return self._scale * self.laplacian + + @property + def _M(self) -> jnp.ndarray: + n, _ = self.shape + return self._scaled_laplacian + jnp.eye(n) + + @property + def shape(self) -> Tuple[int, int]: # noqa: D102 + return self.laplacian.shape + + @property + def is_symmetric(self) -> bool: # noqa: D102 + return True + + @property + def dtype(self) -> jnp.dtype: # noqa: D102 + return self.laplacian.dtype + + def transport_from_potentials( + self, f: jnp.ndarray, g: jnp.ndarray + ) -> jnp.ndarray: + """Not implemented.""" + raise ValueError("Not implemented.") + + def apply_transport_from_potentials( + self, + f: jnp.ndarray, + g: jnp.ndarray, + vec: jnp.ndarray, + axis: int = 0 + ) -> jnp.ndarray: + """Since applying from potentials is not feasible in grids, use scalings.""" + u, v = self.scaling_from_potential(f), self.scaling_from_potential(g) + return self.apply_transport_from_scalings(u, v, vec, axis=axis) + + def marginal_from_potentials( + self, + f: jnp.ndarray, + g: jnp.ndarray, + axis: int = 0, + ) -> jnp.ndarray: + """Not implemented.""" + raise ValueError("Not implemented.") + + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 + return [self.laplacian, self.t], { + "n_steps": self.n_steps, + "numerical_scheme": self.numerical_scheme, + "tol": self.tol, + } + + @classmethod + def tree_unflatten( # noqa: D102 + cls, aux_data: Dict[str, Any], children: Sequence[Any] + ) -> "Graph": + return cls(*children, **aux_data) \ No newline at end of file From 03a57f953f6b20371f4245de1b26952809910249 Mon Sep 17 00:00:00 2001 From: diegoabt Date: Tue, 5 Sep 2023 15:59:53 +0200 Subject: [PATCH 02/44] Lint code --- src/ott/geometry/geodesic.py | 83 +++++++++++++++--------------------- 1 file changed, 35 insertions(+), 48 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 812f582b7..bee4dd105 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -12,24 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Literal, Optional, Sequence, Tuple, List +from typing import Any, Dict, List, Optional, Sequence, Tuple import jax import jax.numpy as jnp -from jax.experimental.sparse.linalg import lobpcg_standard - import numpy as np +from jax.experimental.sparse.linalg import lobpcg_standard from scipy.special import ive from ott.geometry import geometry from ott.math import utils as mu - __all__ = ["Geodesic"] @jax.tree_util.register_pytree_node_class -class Geodesic(geometry.Geometry): #TODO: Rewrite docstring +class Geodesic(geometry.Geometry): r"""Graph distance approximation using heat kernel :cite:`heitz:21,crane:13`. Approximates the heat kernel for large ``n_steps``, which for small ``t`` @@ -94,7 +92,7 @@ def from_graph( The graph geometry. """ assert G.shape[0] == G.shape[1], G.shape - print(G.shape) + if directed: G = G + G.T @@ -112,42 +110,38 @@ def from_graph( return cls(laplacian, t=t, **kwargs) - - def apply_kernel( - self, - scaling: jnp.ndarray, - eps: Optional[float] = None, - axis: int = 0, - ) -> jnp.ndarray: + self, + scaling: jnp.ndarray, + eps: Optional[float] = None, + axis: int = 0, + ) -> jnp.ndarray: r"""Apply :attr:`kernel_matrix` on positive scaling vector. - Args: - scaling: Scaling to apply the kernel to. - eps: passed for consistency, not used yet. - axis: passed for consistency, not used yet. + Args: + scaling: Scaling to apply the kernel to. + eps: passed for consistency, not used yet. + axis: passed for consistency, not used yet. - Returns: - Kernel applied to ``scaling``. - """ + Returns: + Kernel applied to ``scaling``. + """ def compute_laplacian(adjacency_matrix: jnp.ndarray) -> jnp.ndarray: - """ - Compute the Laplacian matrix from the adjacency matrix. + """Compute the Laplacian matrix from the adjacency matrix. Args: - adjacency_matrix: An (n, n) array representing the adjacency matrix of a graph. + adjacency_matrix: An (n, n) array representing the + adjacency matrix of a graph. Returns: An (n, n) array representing the Laplacian matrix. """ degree_matrix = jnp.diag(jnp.sum(adjacency_matrix, axis=0)) - laplacian_matrix = degree_matrix - adjacency_matrix - return laplacian_matrix + return degree_matrix - adjacency_matrix def compute_largest_eigenvalue(laplacian_matrix, k): - """ - Compute the largest eigenvalue of the Laplacian matrix. + """Compute the largest eigenvalue of the Laplacian matrix. Args: laplacian_matrix: An (n, n) array representing the Laplacian matrix. @@ -161,12 +155,11 @@ def compute_largest_eigenvalue(laplacian_matrix, k): # Convert the Laplacian matrix to a dense array #laplacian_array = laplacian_matrix.toarray() eigvals, _, _ = lobpcg_standard(laplacian_matrix, initial_directions, m=k) - largest_eigenvalue = np.max(eigvals) - return largest_eigenvalue + + return np.max(eigvals) def rescale_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: - """ - Rescale the Laplacian matrix. + """Rescale the Laplacian matrix. Args: laplacian_matrix: An (n, n) array representing the Laplacian matrix. @@ -178,12 +171,10 @@ def rescale_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: if largest_eigenvalue > 2: rescaled_laplacian = laplacian_matrix.copy() rescaled_laplacian /= largest_eigenvalue - laplacian_matrix = 2 * rescaled_laplacian - return laplacian_matrix + return 2 * rescaled_laplacian def define_scaled_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: - """ - Define the scaled Laplacian matrix. + """Define the scaled Laplacian matrix. Args: laplacian_matrix: An (n, n) array representing the Laplacian matrix. @@ -193,12 +184,10 @@ def define_scaled_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: """ n = laplacian_matrix.shape[0] identity = jnp.eye(n) - scaled_laplacian = laplacian_matrix - identity - return scaled_laplacian + return laplacian_matrix - identity def chebyshev_coefficients(t: float, max_order: int) -> List[float]: - """ - Compute the coefficients of the Chebyshev polynomial approximation using Bessel functions. + """Compute the coeffs of the Chebyshev pols approx using Bessel functs. Args: t: Time parameter. @@ -210,10 +199,9 @@ def chebyshev_coefficients(t: float, max_order: int) -> List[float]: return (2 * ive(jnp.arange(0, max_order + 1), -t)).tolist() def compute_chebyshev_approximation( - x: jnp.ndarray, coeffs: List[float] + x: jnp.ndarray, coeffs: List[float] ) -> jnp.ndarray: - """ - Compute the Chebyshev polynomial approximation for the given input and coefficients. + """Compute the Chebyshev polynomial approx for the given input and coeffs. Args: x: Input to evaluate the polynomial at. @@ -231,10 +219,9 @@ def compute_chebyshev_approximation( laplacian_times_signal = scaled_laplacian.dot(scaling) # Apply the kernel - chebyshev_approx_on_signal = compute_chebyshev_approximation( - laplacian_times_signal, chebyshev_coeffs) - - return chebyshev_approx_on_signal + return compute_chebyshev_approximation( + laplacian_times_signal, chebyshev_coeffs + ) @property def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 @@ -317,5 +304,5 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 @classmethod def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] - ) -> "Graph": - return cls(*children, **aux_data) \ No newline at end of file + ) -> "Geodesic": + return cls(*children, **aux_data) From fb4de65c1a21291ab8efc61f034cd5faa6bf893b Mon Sep 17 00:00:00 2001 From: Diego Baptista Theuerkauf <34717973+diegoabt@users.noreply.github.com> Date: Tue, 5 Sep 2023 16:40:36 +0200 Subject: [PATCH 03/44] Lint code --- src/ott/geometry/geodesic.py | 83 +++++++++++++++--------------------- 1 file changed, 35 insertions(+), 48 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 812f582b7..bee4dd105 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -12,24 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Literal, Optional, Sequence, Tuple, List +from typing import Any, Dict, List, Optional, Sequence, Tuple import jax import jax.numpy as jnp -from jax.experimental.sparse.linalg import lobpcg_standard - import numpy as np +from jax.experimental.sparse.linalg import lobpcg_standard from scipy.special import ive from ott.geometry import geometry from ott.math import utils as mu - __all__ = ["Geodesic"] @jax.tree_util.register_pytree_node_class -class Geodesic(geometry.Geometry): #TODO: Rewrite docstring +class Geodesic(geometry.Geometry): r"""Graph distance approximation using heat kernel :cite:`heitz:21,crane:13`. Approximates the heat kernel for large ``n_steps``, which for small ``t`` @@ -94,7 +92,7 @@ def from_graph( The graph geometry. """ assert G.shape[0] == G.shape[1], G.shape - print(G.shape) + if directed: G = G + G.T @@ -112,42 +110,38 @@ def from_graph( return cls(laplacian, t=t, **kwargs) - - def apply_kernel( - self, - scaling: jnp.ndarray, - eps: Optional[float] = None, - axis: int = 0, - ) -> jnp.ndarray: + self, + scaling: jnp.ndarray, + eps: Optional[float] = None, + axis: int = 0, + ) -> jnp.ndarray: r"""Apply :attr:`kernel_matrix` on positive scaling vector. - Args: - scaling: Scaling to apply the kernel to. - eps: passed for consistency, not used yet. - axis: passed for consistency, not used yet. + Args: + scaling: Scaling to apply the kernel to. + eps: passed for consistency, not used yet. + axis: passed for consistency, not used yet. - Returns: - Kernel applied to ``scaling``. - """ + Returns: + Kernel applied to ``scaling``. + """ def compute_laplacian(adjacency_matrix: jnp.ndarray) -> jnp.ndarray: - """ - Compute the Laplacian matrix from the adjacency matrix. + """Compute the Laplacian matrix from the adjacency matrix. Args: - adjacency_matrix: An (n, n) array representing the adjacency matrix of a graph. + adjacency_matrix: An (n, n) array representing the + adjacency matrix of a graph. Returns: An (n, n) array representing the Laplacian matrix. """ degree_matrix = jnp.diag(jnp.sum(adjacency_matrix, axis=0)) - laplacian_matrix = degree_matrix - adjacency_matrix - return laplacian_matrix + return degree_matrix - adjacency_matrix def compute_largest_eigenvalue(laplacian_matrix, k): - """ - Compute the largest eigenvalue of the Laplacian matrix. + """Compute the largest eigenvalue of the Laplacian matrix. Args: laplacian_matrix: An (n, n) array representing the Laplacian matrix. @@ -161,12 +155,11 @@ def compute_largest_eigenvalue(laplacian_matrix, k): # Convert the Laplacian matrix to a dense array #laplacian_array = laplacian_matrix.toarray() eigvals, _, _ = lobpcg_standard(laplacian_matrix, initial_directions, m=k) - largest_eigenvalue = np.max(eigvals) - return largest_eigenvalue + + return np.max(eigvals) def rescale_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: - """ - Rescale the Laplacian matrix. + """Rescale the Laplacian matrix. Args: laplacian_matrix: An (n, n) array representing the Laplacian matrix. @@ -178,12 +171,10 @@ def rescale_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: if largest_eigenvalue > 2: rescaled_laplacian = laplacian_matrix.copy() rescaled_laplacian /= largest_eigenvalue - laplacian_matrix = 2 * rescaled_laplacian - return laplacian_matrix + return 2 * rescaled_laplacian def define_scaled_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: - """ - Define the scaled Laplacian matrix. + """Define the scaled Laplacian matrix. Args: laplacian_matrix: An (n, n) array representing the Laplacian matrix. @@ -193,12 +184,10 @@ def define_scaled_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: """ n = laplacian_matrix.shape[0] identity = jnp.eye(n) - scaled_laplacian = laplacian_matrix - identity - return scaled_laplacian + return laplacian_matrix - identity def chebyshev_coefficients(t: float, max_order: int) -> List[float]: - """ - Compute the coefficients of the Chebyshev polynomial approximation using Bessel functions. + """Compute the coeffs of the Chebyshev pols approx using Bessel functs. Args: t: Time parameter. @@ -210,10 +199,9 @@ def chebyshev_coefficients(t: float, max_order: int) -> List[float]: return (2 * ive(jnp.arange(0, max_order + 1), -t)).tolist() def compute_chebyshev_approximation( - x: jnp.ndarray, coeffs: List[float] + x: jnp.ndarray, coeffs: List[float] ) -> jnp.ndarray: - """ - Compute the Chebyshev polynomial approximation for the given input and coefficients. + """Compute the Chebyshev polynomial approx for the given input and coeffs. Args: x: Input to evaluate the polynomial at. @@ -231,10 +219,9 @@ def compute_chebyshev_approximation( laplacian_times_signal = scaled_laplacian.dot(scaling) # Apply the kernel - chebyshev_approx_on_signal = compute_chebyshev_approximation( - laplacian_times_signal, chebyshev_coeffs) - - return chebyshev_approx_on_signal + return compute_chebyshev_approximation( + laplacian_times_signal, chebyshev_coeffs + ) @property def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 @@ -317,5 +304,5 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 @classmethod def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] - ) -> "Graph": - return cls(*children, **aux_data) \ No newline at end of file + ) -> "Geodesic": + return cls(*children, **aux_data) From dd1fd959649aff2197e7c81896c1d70621607b85 Mon Sep 17 00:00:00 2001 From: diegoabt Date: Fri, 8 Sep 2023 17:02:32 +0200 Subject: [PATCH 04/44] Add Geodesic kernel citation to `docs/references.bib` --- docs/references.bib | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/references.bib b/docs/references.bib index f2f59d870..47623967f 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -794,3 +794,10 @@ @misc{klein:23 title = {Learning Costs for Structured Monge Displacements}, year = {2023}, } + +@article{huguet:2022, + title={Geodesic Sinkhorn: optimal transport for high-dimensional datasets}, + author={Huguet, Guillaume and Tong, Alexander and Zapatero, Mar{\'\i}a Ramos and Wolf, Guy and Krishnaswamy, Smita}, + journal={arXiv preprint arXiv:2211.00805}, + year={2022} +} From 8c087f49fb87fe48e7c8a8f9ee7b1c6a4db38be9 Mon Sep 17 00:00:00 2001 From: diegoabt Date: Fri, 8 Sep 2023 17:20:50 +0200 Subject: [PATCH 05/44] Remove unused functions; update docstrings; remove `n_steps`; remove forced symm --- src/ott/geometry/geodesic.py | 77 ++++++++++++------------------------ 1 file changed, 25 insertions(+), 52 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index bee4dd105..2a075b706 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -21,30 +21,22 @@ from scipy.special import ive from ott.geometry import geometry -from ott.math import utils as mu __all__ = ["Geodesic"] @jax.tree_util.register_pytree_node_class class Geodesic(geometry.Geometry): - r"""Graph distance approximation using heat kernel :cite:`heitz:21,crane:13`. + r"""Graph distance approximation using heat kernel :cite:`huguet:2022`. - Approximates the heat kernel for large ``n_steps``, which for small ``t`` - approximates the geodesic exponential kernel :math:`e^{\frac{-d(x, y)^2}{t}}`. + Approximates the heat-geodesic kernel using thee Chebyshev polynomials of the + first kind of max order ``order``, which for small ``t`` approximates the + geodesic exponential kernel :math:`e^{\frac{-d(x, y)^2}{t}}`. Args: - laplacian: Symmetric graph Laplacian. The check for symmetry is **NOT** - performed. See also :meth:`from_graph`. - n_steps: Maximum number of steps used to approximate the heat kernel. - numerical_scheme: Numerical scheme used to solve the heat diffusion. - normalize: Whether to normalize the Laplacian as - :math:`L^{sym} = \left(D^+\right)^{\frac{1}{2}} L - \left(D^+\right)^{\frac{1}{2}}`, where :math:`L` is the - non-normalized Laplacian and :math:`D` is the degree matrix. - tol: Relative tolerance with respect to the Hilbert metric, see - :cite:`peyre:19`, Remark 4.12. Used when iteratively updating scalings. - If negative, this option is ignored and only ``n_steps`` is used. + laplacian: Symmetric graph Laplacian. + t: Time parameter for heat kernel. + order: Max order of Chebyshev polynomial. kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`. """ @@ -52,15 +44,13 @@ def __init__( self, laplacian: jnp.ndarray, t: float = 1e-3, - n_steps: int = 100, - tol: float = -1.0, + order: int = 100, **kwargs: Any ): super().__init__(epsilon=1., **kwargs) self.laplacian = laplacian self.t = t - self.n_steps = n_steps - self.tol = tol + self.order = order @classmethod def from_graph( @@ -71,25 +61,25 @@ def from_graph( normalize: bool = False, **kwargs: Any ) -> "Geodesic": - r"""Construct :class:`~ott.geometry.graph.Graph` from an adjacency matrix. + r"""Construct a Geodesic geometry from an adjacency matrix. Args: G: Adjacency matrix. - t: Constant used when approximating the geodesic exponential kernel. - If `None`, use :math:`\frac{1}{|E|} \sum_{(u, v) \in E} weight(u, v)` - :cite:`crane:13`. In this case, the ``graph`` must be specified - and the edge weights are all assumed to be positive. - directed: Whether the ``graph`` is directed. If not, it will be made - undirected as :math:`G + G^T`. This parameter is ignored when directly - passing the Laplacian, which is assumed to be symmetric. + t: Time parameter for approximating the geodesic exponential kernel. + If `None`, it defaults to :math:`\frac{1}{|E|} \sum_{(u, v) \in E} + \text{weight}(u, v)` :cite:`crane:13`. In this case, the ``graph`` + must be specified and the edge weights are assumed to be positive. + directed: Whether the ``graph`` is directed. If not, it's made + undirected as :math:`G + G^T`. This parameter is ignored when passing + the Laplacian directly, assumed to be symmetric. normalize: Whether to normalize the Laplacian as :math:`L^{sym} = \left(D^+\right)^{\frac{1}{2}} L \left(D^+\right)^{\frac{1}{2}}`, where :math:`L` is the non-normalized Laplacian and :math:`D` is the degree matrix. - kwargs: Keyword arguments for :class:`~ott.geometry.graph.Graph`. + kwargs: Keyword arguments for the Geodesic class. Returns: - The graph geometry. + The Geodesic geometry. """ assert G.shape[0] == G.shape[1], G.shape @@ -150,7 +140,7 @@ def compute_largest_eigenvalue(laplacian_matrix, k): Returns: The largest eigenvalue of the Laplacian matrix. """ - n = laplacian_matrix.shape[0] + n, _ = self.shape initial_directions = jax.random.normal(jax.random.PRNGKey(0), (n, k)) # Convert the Laplacian matrix to a dense array #laplacian_array = laplacian_matrix.toarray() @@ -215,7 +205,7 @@ def compute_chebyshev_approximation( #laplacian_matrix = compute_laplacian(self.adjacency_matrix) rescaled_laplacian = rescale_laplacian(self.laplacian) scaled_laplacian = define_scaled_laplacian(rescaled_laplacian) - chebyshev_coeffs = chebyshev_coefficients(self.t, self.n_steps) + chebyshev_coeffs = chebyshev_coefficients(self.t, self.order) laplacian_times_signal = scaled_laplacian.dot(scaling) # Apply the kernel @@ -226,36 +216,19 @@ def compute_chebyshev_approximation( @property def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 n, _ = self.shape - kernel = self.apply_kernel(jnp.eye(n)) - # force symmetry because of numerical imprecision - # happens when `numerical_scheme='backward_euler'` and small `t` - return (kernel + kernel.T) * 0.5 - - @property - def cost_matrix(self) -> jnp.ndarray: # noqa: D102 - return -self.t * mu.safe_log(self.kernel_matrix) + return self.apply_kernel(jnp.eye(n)) @property def _scale(self) -> float: """Constant used to scale the Laplacian.""" if self.numerical_scheme == "backward_euler": - return self.t / (4. * self.n_steps) + return self.t / (4. * self.order) if self.numerical_scheme == "crank_nicolson": - return self.t / (2. * self.n_steps) + return self.t / (2. * self.order) raise NotImplementedError( f"Numerical scheme `{self.numerical_scheme}` is not implemented." ) - @property - def _scaled_laplacian(self) -> jnp.ndarray: - """Laplacian scaled by a constant, depending on the numerical scheme.""" - return self._scale * self.laplacian - - @property - def _M(self) -> jnp.ndarray: - n, _ = self.shape - return self._scaled_laplacian + jnp.eye(n) - @property def shape(self) -> Tuple[int, int]: # noqa: D102 return self.laplacian.shape @@ -296,7 +269,7 @@ def marginal_from_potentials( def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [self.laplacian, self.t], { - "n_steps": self.n_steps, + "order": self.order, "numerical_scheme": self.numerical_scheme, "tol": self.tol, } From 408b14c18abab7d5eb47f4122e8b518e42bae6f3 Mon Sep 17 00:00:00 2001 From: diegoabt Date: Fri, 8 Sep 2023 17:42:45 +0200 Subject: [PATCH 06/44] Remove hardcoded random key at `compute_largest_eigenvalue` --- src/ott/geometry/geodesic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 2a075b706..b7096ad04 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -21,6 +21,7 @@ from scipy.special import ive from ott.geometry import geometry +from ott.utils import default_prng_key __all__ = ["Geodesic"] @@ -141,9 +142,8 @@ def compute_largest_eigenvalue(laplacian_matrix, k): The largest eigenvalue of the Laplacian matrix. """ n, _ = self.shape - initial_directions = jax.random.normal(jax.random.PRNGKey(0), (n, k)) - # Convert the Laplacian matrix to a dense array - #laplacian_array = laplacian_matrix.toarray() + prng_key = default_prng_key() + initial_directions = jax.random.normal(prng_key, (n, k)) eigvals, _, _ = lobpcg_standard(laplacian_matrix, initial_directions, m=k) return np.max(eigvals) From 1e33f10470cbfa63b6dd0e19960394f1aa6a0215 Mon Sep 17 00:00:00 2001 From: diegoabt Date: Fri, 8 Sep 2023 17:48:31 +0200 Subject: [PATCH 07/44] Fix docstrings of functions inside `apply_kernel` --- src/ott/geometry/geodesic.py | 62 +++--------------------------------- 1 file changed, 5 insertions(+), 57 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index b7096ad04..32ad8b952 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -118,29 +118,8 @@ def apply_kernel( Kernel applied to ``scaling``. """ - def compute_laplacian(adjacency_matrix: jnp.ndarray) -> jnp.ndarray: - """Compute the Laplacian matrix from the adjacency matrix. - - Args: - adjacency_matrix: An (n, n) array representing the - adjacency matrix of a graph. - - Returns: - An (n, n) array representing the Laplacian matrix. - """ - degree_matrix = jnp.diag(jnp.sum(adjacency_matrix, axis=0)) - return degree_matrix - adjacency_matrix - def compute_largest_eigenvalue(laplacian_matrix, k): - """Compute the largest eigenvalue of the Laplacian matrix. - - Args: - laplacian_matrix: An (n, n) array representing the Laplacian matrix. - k: Number of eigenvalues/vectors to compute. - - Returns: - The largest eigenvalue of the Laplacian matrix. - """ + # Compute the largest eigenvalue of the Laplacian matrix. n, _ = self.shape prng_key = default_prng_key() initial_directions = jax.random.normal(prng_key, (n, k)) @@ -149,14 +128,7 @@ def compute_largest_eigenvalue(laplacian_matrix, k): return np.max(eigvals) def rescale_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: - """Rescale the Laplacian matrix. - - Args: - laplacian_matrix: An (n, n) array representing the Laplacian matrix. - - Returns: - The rescaled Laplacian matrix. - """ + # Rescale the Laplacian matrix. largest_eigenvalue = compute_largest_eigenvalue(laplacian_matrix, k=1) if largest_eigenvalue > 2: rescaled_laplacian = laplacian_matrix.copy() @@ -164,45 +136,21 @@ def rescale_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: return 2 * rescaled_laplacian def define_scaled_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: - """Define the scaled Laplacian matrix. - - Args: - laplacian_matrix: An (n, n) array representing the Laplacian matrix. - - Returns: - The scaled Laplacian matrix. - """ + # Define the scaled Laplacian matrix. n = laplacian_matrix.shape[0] identity = jnp.eye(n) return laplacian_matrix - identity def chebyshev_coefficients(t: float, max_order: int) -> List[float]: - """Compute the coeffs of the Chebyshev pols approx using Bessel functs. - - Args: - t: Time parameter. - max_order: Maximum order of the Chebyshev polynomial approximation. - - Returns: - A list of coefficients. - """ + # Compute the coeffs of the Chebyshev pols approx using Bessel functs. return (2 * ive(jnp.arange(0, max_order + 1), -t)).tolist() def compute_chebyshev_approximation( x: jnp.ndarray, coeffs: List[float] ) -> jnp.ndarray: - """Compute the Chebyshev polynomial approx for the given input and coeffs. - - Args: - x: Input to evaluate the polynomial at. - coeffs: List of Chebyshev polynomial coefficients. - - Returns: - The Chebyshev polynomial approximation evaluated at x. - """ + # Compute the Chebyshev polynomial approx for the given input and coeffs. return self.apply_kernel(x, coeffs) - #laplacian_matrix = compute_laplacian(self.adjacency_matrix) rescaled_laplacian = rescale_laplacian(self.laplacian) scaled_laplacian = define_scaled_laplacian(rescaled_laplacian) chebyshev_coeffs = chebyshev_coefficients(self.t, self.order) From d901e5ce1a07b09987af6ffbc79ce8bce2f87f7d Mon Sep 17 00:00:00 2001 From: diegoabt Date: Mon, 25 Sep 2023 11:07:01 +0200 Subject: [PATCH 08/44] Add chebyshev coeff computation to `from_graph` --- src/ott/geometry/geodesic.py | 44 ++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 32ad8b952..f492983a6 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -17,11 +17,13 @@ import jax import jax.numpy as jnp import numpy as np +import scipy.sparse as sp from jax.experimental.sparse.linalg import lobpcg_standard from scipy.special import ive from ott.geometry import geometry -from ott.utils import default_prng_key + +#from ott.utils import default_prng_key __all__ = ["Geodesic"] @@ -52,12 +54,14 @@ def __init__( self.laplacian = laplacian self.t = t self.order = order + self.chebyshev_coeffs = None @classmethod def from_graph( cls, G: jnp.ndarray, t: Optional[float] = 1e-3, + order=100, directed: bool = False, normalize: bool = False, **kwargs: Any @@ -70,6 +74,7 @@ def from_graph( If `None`, it defaults to :math:`\frac{1}{|E|} \sum_{(u, v) \in E} \text{weight}(u, v)` :cite:`crane:13`. In this case, the ``graph`` must be specified and the edge weights are assumed to be positive. + order: Max order of Chebyshev polynomial. directed: Whether the ``graph`` is directed. If not, it's made undirected as :math:`G + G^T`. This parameter is ignored when passing the Laplacian directly, assumed to be symmetric. @@ -99,7 +104,19 @@ def from_graph( if t is None: t = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2 - return cls(laplacian, t=t, **kwargs) + # Create an instance of the Geodesic class and set the attribute + geodesic_instance = cls(laplacian, t=t, order=order, **kwargs) + + # Compute the coeffs of the Chebyshev pols approx using Bessel functs. + chebyshev_coeffs = ( + 2 * + ive(jnp.arange(0, geodesic_instance.order + 1), -geodesic_instance.t) + ).tolist() + + # Set the attribute + geodesic_instance.chebyshev_coeffs = chebyshev_coeffs + + return geodesic_instance def apply_kernel( self, @@ -118,11 +135,12 @@ def apply_kernel( Kernel applied to ``scaling``. """ - def compute_largest_eigenvalue(laplacian_matrix, k): + def compute_largest_eigenvalue(laplacian_matrix, k, seed=None): # Compute the largest eigenvalue of the Laplacian matrix. + if seed is None: + seed = jax.random.PRNGKey(0) n, _ = self.shape - prng_key = default_prng_key() - initial_directions = jax.random.normal(prng_key, (n, k)) + initial_directions = jax.random.normal(seed, (n, k)) eigvals, _, _ = lobpcg_standard(laplacian_matrix, initial_directions, m=k) return np.max(eigvals) @@ -133,7 +151,8 @@ def rescale_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: if largest_eigenvalue > 2: rescaled_laplacian = laplacian_matrix.copy() rescaled_laplacian /= largest_eigenvalue - return 2 * rescaled_laplacian + return 2 * rescaled_laplacian + return laplacian_matrix def define_scaled_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: # Define the scaled Laplacian matrix. @@ -141,24 +160,21 @@ def define_scaled_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: identity = jnp.eye(n) return laplacian_matrix - identity - def chebyshev_coefficients(t: float, max_order: int) -> List[float]: - # Compute the coeffs of the Chebyshev pols approx using Bessel functs. - return (2 * ive(jnp.arange(0, max_order + 1), -t)).tolist() - def compute_chebyshev_approximation( x: jnp.ndarray, coeffs: List[float] ) -> jnp.ndarray: - # Compute the Chebyshev polynomial approx for the given input and coeffs. - return self.apply_kernel(x, coeffs) + # Compute the Chebyshev polynomial approximation for + # the given input and coefficients. + x_dense = x.toarray() if sp.issparse(x) else x + return np.polynomial.chebyshev.chebval(x_dense, coeffs) rescaled_laplacian = rescale_laplacian(self.laplacian) scaled_laplacian = define_scaled_laplacian(rescaled_laplacian) - chebyshev_coeffs = chebyshev_coefficients(self.t, self.order) laplacian_times_signal = scaled_laplacian.dot(scaling) # Apply the kernel return compute_chebyshev_approximation( - laplacian_times_signal, chebyshev_coeffs + laplacian_times_signal, self.chebyshev_coeffs ) @property From 01eb7e1341183a2ada5c81146ac62f4d13119291 Mon Sep 17 00:00:00 2001 From: diegoabt Date: Tue, 26 Sep 2023 17:21:11 +0200 Subject: [PATCH 09/44] Change `jax.exp.sparse` import --- src/ott/geometry/geodesic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index f492983a6..a12fc60fd 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -15,16 +15,14 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple import jax +import jax.experimental.sparse as jesp import jax.numpy as jnp import numpy as np import scipy.sparse as sp -from jax.experimental.sparse.linalg import lobpcg_standard from scipy.special import ive from ott.geometry import geometry -#from ott.utils import default_prng_key - __all__ = ["Geodesic"] @@ -141,7 +139,9 @@ def compute_largest_eigenvalue(laplacian_matrix, k, seed=None): seed = jax.random.PRNGKey(0) n, _ = self.shape initial_directions = jax.random.normal(seed, (n, k)) - eigvals, _, _ = lobpcg_standard(laplacian_matrix, initial_directions, m=k) + eigvals, _, _ = jesp.linalg.lobpcg_standard( + laplacian_matrix, initial_directions, m=k + ) return np.max(eigvals) From f223d7a3957b7b336f0c5411aa1d836a523d9b6b Mon Sep 17 00:00:00 2001 From: diegoabt Date: Tue, 26 Sep 2023 17:32:04 +0200 Subject: [PATCH 10/44] Change `tree_flatten` outputs --- src/ott/geometry/geodesic.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index a12fc60fd..1e84d40c7 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple import jax import jax.experimental.sparse as jesp @@ -46,13 +46,17 @@ def __init__( laplacian: jnp.ndarray, t: float = 1e-3, order: int = 100, + chebyshev_coeffs: Optional[List[float]] = None, + numerical_scheme: Literal["backward_euler", + "crank_nicolson"] = "backward_euler", **kwargs: Any ): super().__init__(epsilon=1., **kwargs) self.laplacian = laplacian self.t = t self.order = order - self.chebyshev_coeffs = None + self.chebyshev_coeffs = chebyshev_coeffs + self.numerical_scheme = numerical_scheme @classmethod def from_graph( @@ -232,10 +236,9 @@ def marginal_from_potentials( raise ValueError("Not implemented.") def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 - return [self.laplacian, self.t], { - "order": self.order, + return [self.laplacian, self.t, self.order], { + "chebyshev_coeffs": self.chebyshev_coeffs, "numerical_scheme": self.numerical_scheme, - "tol": self.tol, } @classmethod From 7bfa71c11ea03c205f66f5649de74ed160de6a61 Mon Sep 17 00:00:00 2001 From: diegoabt Date: Tue, 26 Sep 2023 18:28:20 +0200 Subject: [PATCH 11/44] Change input of `lobpcg_std` to be a sparsified product --- src/ott/geometry/geodesic.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 1e84d40c7..2ceb1c0c5 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -142,9 +142,16 @@ def compute_largest_eigenvalue(laplacian_matrix, k, seed=None): if seed is None: seed = jax.random.PRNGKey(0) n, _ = self.shape - initial_directions = jax.random.normal(seed, (n, k)) + # Generate random initial directions for eigenvalue computation + initial_dirs = jax.random.normal(seed, (n, k)) + + # Create a sparse matrix-vector product function using sparsify + # This function multiplies the sparse laplacian_matrix with a vector + lapl_vector_product = jesp.sparsify(lambda v: laplacian_matrix @ v) + + # Compute eigenvalues using the sparse matrix-vector product eigvals, _, _ = jesp.linalg.lobpcg_standard( - laplacian_matrix, initial_directions, m=k + lapl_vector_product, initial_dirs, m=k ) return np.max(eigvals) @@ -237,8 +244,8 @@ def marginal_from_potentials( def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [self.laplacian, self.t, self.order], { - "chebyshev_coeffs": self.chebyshev_coeffs, "numerical_scheme": self.numerical_scheme, + "chebyshev_coeffs": self.chebyshev_coeffs, } @classmethod From 0a3b8c86faf729003782ce96a83c0eaf9fa16aaa Mon Sep 17 00:00:00 2001 From: diegoabt Date: Wed, 27 Sep 2023 15:09:41 +0200 Subject: [PATCH 12/44] Change definition of cost from kernel --- src/ott/geometry/geodesic.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 2ceb1c0c5..b01e154a5 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -193,6 +193,11 @@ def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 n, _ = self.shape return self.apply_kernel(jnp.eye(n)) + @property + def cost_matrix(self) -> jnp.ndarray: # noqa: D102 + # Calculate the cost matrix using the formula (5) from the main reference + return -4 * self.t * jnp.log(self.kernel_matrix) + @property def _scale(self) -> float: """Constant used to scale the Laplacian.""" From c2e5770e4b40e6bc88fa6cd93b915d2907893f32 Mon Sep 17 00:00:00 2001 From: diegoabt Date: Thu, 28 Sep 2023 14:21:34 +0200 Subject: [PATCH 13/44] Add `Geodesic` to `docs/geometry` --- docs/geometry.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/geometry.rst b/docs/geometry.rst index adc0a3880..5b0d31e88 100644 --- a/docs/geometry.rst +++ b/docs/geometry.rst @@ -45,6 +45,7 @@ Geometries pointcloud.PointCloud grid.Grid graph.Graph + geodesic.Geodesic low_rank.LRCGeometry epsilon_scheduler.Epsilon From c1ec39d8adb18bc56d1c23ab8ac1df550fd8c463 Mon Sep 17 00:00:00 2001 From: diegoabt Date: Fri, 29 Sep 2023 14:38:51 +0200 Subject: [PATCH 14/44] Remove `np.max` from max eigenval computation --- src/ott/geometry/geodesic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index b01e154a5..33f3015af 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -154,7 +154,7 @@ def compute_largest_eigenvalue(laplacian_matrix, k, seed=None): lapl_vector_product, initial_dirs, m=k ) - return np.max(eigvals) + return jnp.max(eigvals) def rescale_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: # Rescale the Laplacian matrix. From 5830ebce40e9303dcb78479c42891f3bbc552817 Mon Sep 17 00:00:00 2001 From: diegoabt Date: Thu, 5 Oct 2023 14:00:46 +0200 Subject: [PATCH 15/44] Add `default_prng_key` to eigenval computation --- src/ott/geometry/geodesic.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 33f3015af..eaacb122c 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -21,6 +21,7 @@ import scipy.sparse as sp from scipy.special import ive +from ott import utils from ott.geometry import geometry __all__ = ["Geodesic"] @@ -137,13 +138,13 @@ def apply_kernel( Kernel applied to ``scaling``. """ - def compute_largest_eigenvalue(laplacian_matrix, k, seed=None): + def compute_largest_eigenvalue(laplacian_matrix, k, rng=None): # Compute the largest eigenvalue of the Laplacian matrix. - if seed is None: - seed = jax.random.PRNGKey(0) + if rng is None: + rng = utils.default_prng_key(rng) n, _ = self.shape # Generate random initial directions for eigenvalue computation - initial_dirs = jax.random.normal(seed, (n, k)) + initial_dirs = jax.random.normal(rng, (n, k)) # Create a sparse matrix-vector product function using sparsify # This function multiplies the sparse laplacian_matrix with a vector From 1d060496dea9832346e93071649aa014b738c733 Mon Sep 17 00:00:00 2001 From: diegoabt Date: Mon, 30 Oct 2023 11:55:47 +0100 Subject: [PATCH 16/44] Add `safe_log` to cost matrix computation --- src/ott/geometry/geodesic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index eaacb122c..32962acc5 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -23,6 +23,7 @@ from ott import utils from ott.geometry import geometry +from ott.math import utils as mu __all__ = ["Geodesic"] @@ -197,7 +198,7 @@ def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 @property def cost_matrix(self) -> jnp.ndarray: # noqa: D102 # Calculate the cost matrix using the formula (5) from the main reference - return -4 * self.t * jnp.log(self.kernel_matrix) + return -4 * self.t * mu.safe_log(self.kernel_matrix) @property def _scale(self) -> float: From 8e14cdb1547a83822b6b959649d8244d122a3804 Mon Sep 17 00:00:00 2001 From: diegoabt Date: Mon, 30 Oct 2023 12:07:35 +0100 Subject: [PATCH 17/44] Change to `jesp.BCOO` at chebyshev approx --- src/ott/geometry/geodesic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 32962acc5..ddeb1e968 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -18,7 +18,6 @@ import jax.experimental.sparse as jesp import jax.numpy as jnp import numpy as np -import scipy.sparse as sp from scipy.special import ive from ott import utils @@ -65,7 +64,7 @@ def from_graph( cls, G: jnp.ndarray, t: Optional[float] = 1e-3, - order=100, + order: int = 100, directed: bool = False, normalize: bool = False, **kwargs: Any @@ -178,7 +177,8 @@ def compute_chebyshev_approximation( ) -> jnp.ndarray: # Compute the Chebyshev polynomial approximation for # the given input and coefficients. - x_dense = x.toarray() if sp.issparse(x) else x + x_dense = x.todense( + ) if type(x) is jesp.BCOO else x # this should be true all the time return np.polynomial.chebyshev.chebval(x_dense, coeffs) rescaled_laplacian = rescale_laplacian(self.laplacian) From 0b7364b9f2fb3896e69b91b18d0e4f7f2b31c459 Mon Sep 17 00:00:00 2001 From: diegoabt Date: Mon, 30 Oct 2023 12:33:55 +0100 Subject: [PATCH 18/44] Restructure `from_graph`; coeffs are computed earlier now --- src/ott/geometry/geodesic.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index ddeb1e968..f9d6049d2 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -107,19 +107,16 @@ def from_graph( if t is None: t = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2 - # Create an instance of the Geodesic class and set the attribute - geodesic_instance = cls(laplacian, t=t, order=order, **kwargs) - # Compute the coeffs of the Chebyshev pols approx using Bessel functs. - chebyshev_coeffs = ( - 2 * - ive(jnp.arange(0, geodesic_instance.order + 1), -geodesic_instance.t) - ).tolist() - - # Set the attribute - geodesic_instance.chebyshev_coeffs = chebyshev_coeffs - - return geodesic_instance + chebyshev_coeffs = (2 * ive(jnp.arange(0, order + 1), -t)).tolist() + + return cls( + laplacian, + t=t, + order=order, + chebyshev_coeffs=chebyshev_coeffs, + **kwargs + ) def apply_kernel( self, From 8bc10c1e6b8f617c663ede909c2d8aec93370756 Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Thu, 2 Nov 2023 09:56:46 +0100 Subject: [PATCH 19/44] fn outside of the class --- src/ott/geometry/geodesic.py | 73 +++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index f9d6049d2..dfd4a6dfa 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -135,40 +135,6 @@ def apply_kernel( Kernel applied to ``scaling``. """ - def compute_largest_eigenvalue(laplacian_matrix, k, rng=None): - # Compute the largest eigenvalue of the Laplacian matrix. - if rng is None: - rng = utils.default_prng_key(rng) - n, _ = self.shape - # Generate random initial directions for eigenvalue computation - initial_dirs = jax.random.normal(rng, (n, k)) - - # Create a sparse matrix-vector product function using sparsify - # This function multiplies the sparse laplacian_matrix with a vector - lapl_vector_product = jesp.sparsify(lambda v: laplacian_matrix @ v) - - # Compute eigenvalues using the sparse matrix-vector product - eigvals, _, _ = jesp.linalg.lobpcg_standard( - lapl_vector_product, initial_dirs, m=k - ) - - return jnp.max(eigvals) - - def rescale_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: - # Rescale the Laplacian matrix. - largest_eigenvalue = compute_largest_eigenvalue(laplacian_matrix, k=1) - if largest_eigenvalue > 2: - rescaled_laplacian = laplacian_matrix.copy() - rescaled_laplacian /= largest_eigenvalue - return 2 * rescaled_laplacian - return laplacian_matrix - - def define_scaled_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: - # Define the scaled Laplacian matrix. - n = laplacian_matrix.shape[0] - identity = jnp.eye(n) - return laplacian_matrix - identity - def compute_chebyshev_approximation( x: jnp.ndarray, coeffs: List[float] ) -> jnp.ndarray: @@ -257,3 +223,42 @@ def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "Geodesic": return cls(*children, **aux_data) + + +# TODO: +# just moving some function here for now, idk if we want them in the class +# or in a utils file. + +def compute_largest_eigenvalue(laplacian_matrix, k, rng=None): + # Compute the largest eigenvalue of the Laplacian matrix. + if rng is None: + rng = utils.default_prng_key(rng) + n = laplacian_matrix.shape[0] + # Generate random initial directions for eigenvalue computation + initial_dirs = jax.random.normal(rng, (n, k)) + + # Create a sparse matrix-vector product function using sparsify + # This function multiplies the sparse laplacian_matrix with a vector + lapl_vector_product = jesp.sparsify(lambda v: laplacian_matrix @ v) + + # Compute eigenvalues using the sparse matrix-vector product + eigvals, _, _ = jesp.linalg.lobpcg_standard( + lapl_vector_product, initial_dirs, m=k + ) + + return jnp.max(eigvals) + +def rescale_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: + # Rescale the Laplacian matrix. + largest_eigenvalue = compute_largest_eigenvalue(laplacian_matrix, k=1) + if largest_eigenvalue > 2: + rescaled_laplacian = laplacian_matrix.copy() + rescaled_laplacian /= largest_eigenvalue + return 2 * rescaled_laplacian + return laplacian_matrix + +def define_scaled_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: + # Define the scaled Laplacian matrix. + n = laplacian_matrix.shape[0] + identity = jnp.eye(n) + return laplacian_matrix - identity \ No newline at end of file From 3650cda7e943cfd4e44380744b2e245fcce4240d Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Thu, 2 Nov 2023 11:06:44 +0100 Subject: [PATCH 20/44] mv fn & process L once --- src/ott/geometry/geodesic.py | 46 ++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index dfd4a6dfa..01da8e701 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -50,6 +50,8 @@ def __init__( chebyshev_coeffs: Optional[List[float]] = None, numerical_scheme: Literal["backward_euler", "crank_nicolson"] = "backward_euler", + lap_m_id: Optional[jnp.ndarray] = None, # Rescale Laplacian minus identity + eigval: Optional[jnp.ndarray] = None, # (Second)Largest eigenvalue of Laplacian **kwargs: Any ): super().__init__(epsilon=1., **kwargs) @@ -58,6 +60,8 @@ def __init__( self.order = order self.chebyshev_coeffs = chebyshev_coeffs self.numerical_scheme = numerical_scheme + self.lap_m_id = lap_m_id + self.eigval = eigval @classmethod def from_graph( @@ -104,6 +108,10 @@ def from_graph( ) laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg + eigval = compute_largest_eigenvalue(laplacian, k=1) + rescaled_laplacian = rescale_laplacian(laplacian, eigval) + lap_min_id = define_scaled_laplacian(rescaled_laplacian) + if t is None: t = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2 @@ -115,6 +123,8 @@ def from_graph( t=t, order=order, chebyshev_coeffs=chebyshev_coeffs, + lap_m_id=lap_min_id, + eigval=eigval, **kwargs ) @@ -124,10 +134,13 @@ def apply_kernel( eps: Optional[float] = None, axis: int = 0, ) -> jnp.ndarray: + # TODO: fix indentation + # NOTE: GH: We could also input time, since we only need to recompute the coeffs, + # i.e. we can use the same laplacian, scales laplaciant for different times. r"""Apply :attr:`kernel_matrix` on positive scaling vector. Args: - scaling: Scaling to apply the kernel to. + scaling: Scaling to apply the kernel to. eps: passed for consistency, not used yet. axis: passed for consistency, not used yet. @@ -135,19 +148,7 @@ def apply_kernel( Kernel applied to ``scaling``. """ - def compute_chebyshev_approximation( - x: jnp.ndarray, coeffs: List[float] - ) -> jnp.ndarray: - # Compute the Chebyshev polynomial approximation for - # the given input and coefficients. - x_dense = x.todense( - ) if type(x) is jesp.BCOO else x # this should be true all the time - return np.polynomial.chebyshev.chebval(x_dense, coeffs) - - rescaled_laplacian = rescale_laplacian(self.laplacian) - scaled_laplacian = define_scaled_laplacian(rescaled_laplacian) - - laplacian_times_signal = scaled_laplacian.dot(scaling) # Apply the kernel + laplacian_times_signal = self.lap_m_id.dot(scaling) # Apply the kernel return compute_chebyshev_approximation( laplacian_times_signal, self.chebyshev_coeffs @@ -226,7 +227,7 @@ def tree_unflatten( # noqa: D102 # TODO: -# just moving some function here for now, idk if we want them in the class +# Moving some function here for now, idk if we want them in the class # or in a utils file. def compute_largest_eigenvalue(laplacian_matrix, k, rng=None): @@ -248,9 +249,8 @@ def compute_largest_eigenvalue(laplacian_matrix, k, rng=None): return jnp.max(eigvals) -def rescale_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: +def rescale_laplacian(laplacian_matrix: jnp.ndarray, largest_eigenvalue: jnp.ndarray) -> jnp.ndarray: # Rescale the Laplacian matrix. - largest_eigenvalue = compute_largest_eigenvalue(laplacian_matrix, k=1) if largest_eigenvalue > 2: rescaled_laplacian = laplacian_matrix.copy() rescaled_laplacian /= largest_eigenvalue @@ -261,4 +261,14 @@ def define_scaled_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: # Define the scaled Laplacian matrix. n = laplacian_matrix.shape[0] identity = jnp.eye(n) - return laplacian_matrix - identity \ No newline at end of file + return laplacian_matrix - identity + +@jax.pure_callback +def compute_chebyshev_approximation( + x: jnp.ndarray, coeffs: List[float] +) -> jnp.ndarray: + # Compute the Chebyshev polynomial approximation for + # the given input and coefficients. + x_dense = x.todense( + ) if type(x) is jesp.BCOO else x # this should be true all the time + return np.polynomial.chebyshev.chebval(x_dense, coeffs) \ No newline at end of file From 06d04ef6b925a8ba4605a58eaea5cbbf43b3d931 Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Fri, 3 Nov 2023 13:28:52 +0100 Subject: [PATCH 21/44] wip tests geo --- setup.py | 17 -- tests/geometry/geo_test.py | 309 +++++++++++++++++++++++++++++++++++++ 2 files changed, 309 insertions(+), 17 deletions(-) delete mode 100644 setup.py create mode 100644 tests/geometry/geo_test.py diff --git a/setup.py b/setup.py deleted file mode 100644 index 9ae2b1026..000000000 --- a/setup.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from setuptools import setup - -# for packaging tools not supporting, e.g., PEP 517, PEP 660 -setup() diff --git a/tests/geometry/geo_test.py b/tests/geometry/geo_test.py new file mode 100644 index 000000000..8f6feba0b --- /dev/null +++ b/tests/geometry/geo_test.py @@ -0,0 +1,309 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import jax.numpy as jnp +import networkx as nx +import numpy as np +from networkx.algorithms import shortest_paths +from networkx.generators import random_graphs +from ott.geometry import geodesic, geometry + + +def random_graph( + n: int, + p: float = 0.3, + seed: Optional[int] = 0, + *, + return_laplacian: bool = False, + directed: bool = False, +) -> jnp.ndarray: + G = random_graphs.fast_gnp_random_graph(n, p, seed=seed, directed=directed) + if not directed: + assert nx.is_connected(G), "Generated graph is not connected." + + rng = np.random.RandomState(seed) + for _, _, w in G.edges(data=True): + w["weight"] = rng.uniform(0, 10) + + G = nx.linalg.laplacian_matrix( + G + ) if return_laplacian else nx.linalg.adjacency_matrix(G) + + return jnp.asarray(G.toarray()) + + +def gt_geometry(G: jnp.ndarray, *, epsilon: float = 1e-2) -> geometry.Geometry: + if not isinstance(G, nx.Graph): + G = nx.from_numpy_array(np.asarray(G)) + + n = len(G) + cost = np.zeros((n, n), dtype=float) + + path = dict( + shortest_paths.all_pairs_bellman_ford_path_length(G, weight="weight") + ) + for i, src in enumerate(G.nodes): + for j, tgt in enumerate(G.nodes): + cost[i, j] = path[src][tgt] ** 2 + + cost = jnp.asarray(cost) + kernel = jnp.asarray(np.exp(-cost / epsilon)) + return geometry.Geometry(cost_matrix=cost, kernel_matrix=kernel, epsilon=1.) + + +class TestGeodesic: + + def test_init(self): + n, order = 10, 100 + G = random_graph(n, p=0.5) + geom = geodesic.Geodesic.from_graph(G, order=order) + + np.testing.assert_equal(geom.order, order) + np.testing.assert_equal(geom.t, 1e-3) + np.testing.assert_equal(geom.chebyshev_coeffs, None) + np.testing.assert_equal(geom.laplacian, G) + np.testing.assert_equal(geom.numerical_scheme, "backward_euler") + + +# class TestGraph: + +# def test_kernel_is_symmetric_positive_definite( +# self, rng: jax.random.PRNGKeyArray +# ): +# n, tol = 65, 0.02 +# x = jax.random.normal(rng, (n,)) +# geom = graph.Graph.from_graph(random_graph(n), t=1e-3) + +# kernel = geom.kernel_matrix + +# vec0 = geom.apply_kernel(x, axis=0) +# vec1 = geom.apply_kernel(x, axis=1) +# vec_direct0 = geom.kernel_matrix.T @ x +# vec_direct1 = geom.kernel_matrix @ x + +# # we symmetrize the kernel explicitly when materializing it, because +# # numerical error arise for small `t` and `backward_euler` +# np.testing.assert_array_equal(kernel, kernel.T) +# np.testing.assert_array_equal(jnp.linalg.eigvals(kernel) > 0., True) +# # internally, the axis is ignored because the kernel is symmetric +# np.testing.assert_array_equal(vec0, vec1) +# np.testing.assert_array_equal(vec_direct0, vec_direct1) + +# np.testing.assert_allclose(vec0, vec_direct0, rtol=tol, atol=tol) +# np.testing.assert_allclose(vec1, vec_direct1, rtol=tol, atol=tol) + +# def test_automatic_t(self): +# G = random_graph(38, return_laplacian=False) +# geom = graph.Graph.from_graph(G, t=None) + +# expected = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2 +# actual = geom.t +# np.testing.assert_equal(actual, expected) + +# @pytest.mark.fast.with_args( +# numerical_scheme=["backward_euler", "crank_nicolson"], +# only_fast=0, +# ) +# def test_approximates_ground_truth( +# self, +# rng: jax.random.PRNGKeyArray, +# numerical_scheme: Literal["backward_euler", "crank_nicolson"], +# ): +# eps, n_steps = 1e-5, 20 +# G = random_graph(37, p=0.5) +# x = jax.random.normal(rng, (G.shape[0],)) + +# gt_geom = gt_geometry(G, epsilon=eps) +# graph_geom = graph.Graph.from_graph( +# G, t=eps, n_steps=n_steps, numerical_scheme=numerical_scheme +# ) + +# np.testing.assert_allclose( +# gt_geom.kernel_matrix, graph_geom.kernel_matrix, rtol=1e-2, atol=1e-2 +# ) +# np.testing.assert_allclose( +# gt_geom.apply_kernel(x), +# graph_geom.apply_kernel(x), +# rtol=1e-2, +# atol=1e-2 +# ) + +# @pytest.mark.fast.with_args( +# n_steps=[50, 100, 200], +# t=[1e-4, 1e-5], +# only_fast=0, +# ) +# def test_crank_nicolson_more_stable(self, t: Optional[float], n_steps: int): +# tol = 5 * t +# G = nx.linalg.adjacency_matrix(balanced_tree(r=2, h=5)) +# G = jnp.asarray(G.toarray(), dtype=float) +# eye = jnp.eye(G.shape[0]) + +# be_geom = graph.Graph.from_graph( +# G, t=t, n_steps=n_steps, numerical_scheme="backward_euler" +# ) +# cn_geom = graph.Graph.from_graph( +# G, t=t, n_steps=n_steps, numerical_scheme="crank_nicolson" +# ) +# eps = jnp.finfo(eye.dtype).tiny + +# be_cost = -t * jnp.log(be_geom.apply_kernel(eye) + eps) +# cn_cost = -t * jnp.log(cn_geom.apply_kernel(eye) + eps) + +# np.testing.assert_allclose(cn_cost, cn_cost.T, rtol=tol, atol=tol) +# with pytest.raises(AssertionError): +# np.testing.assert_allclose(be_cost, be_cost.T, rtol=tol, atol=tol) + +# @pytest.mark.parametrize(("jit", "normalize"), [(False, True), (True, False)]) +# def test_directed_graph(self, jit: bool, normalize: bool): + +# def create_graph(G: jnp.ndarray) -> graph.Graph: +# return graph.Graph.from_graph(G, directed=True, normalize=normalize) + +# G = random_graph(16, p=0.25, directed=True) +# create_fn = jax.jit(create_graph) if jit else create_graph +# geom = create_fn(G) + +# with pytest.raises(AssertionError): +# np.testing.assert_allclose(G, G.T) + +# L = geom.laplacian + +# with pytest.raises(AssertionError): +# # make sure that original graph was directed +# np.testing.assert_allclose(G, G.T, rtol=1e-6, atol=1e-6) +# np.testing.assert_allclose(L, L.T, rtol=1e-6, atol=1e-6) + +# @pytest.mark.parametrize("directed", [False, True]) +# @pytest.mark.parametrize("normalize", [False, True]) +# def test_normalize_laplacian(self, directed: bool, normalize: bool): + +# def laplacian(G: jnp.ndarray) -> jnp.ndarray: +# if directed: +# G = G + G.T + +# data = jnp.sum(G, axis=1) +# lap = jnp.diag(data) - G +# if normalize: +# inv_sqrt_deg = jnp.diag( +# jnp.where(data > 0.0, 1.0 / jnp.sqrt(data), 0.0) +# ) +# return inv_sqrt_deg @ lap @ inv_sqrt_deg +# return lap + +# G = random_graph(51, p=0.35, directed=directed) +# geom = graph.Graph.from_graph(G, directed=directed, normalize=normalize) + +# expected = laplacian(G) +# actual = geom.laplacian + +# np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) + +# @pytest.mark.fast.with_args(jit=[False, True], only_fast=0) +# def test_graph_sinkhorn(self, rng: jax.random.PRNGKeyArray, jit: bool): + +# def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: +# solver = sinkhorn.Sinkhorn(lse_mode=False) +# problem = linear_problem.LinearProblem(geom) +# return solver(problem) + +# n, eps, tol = 11, 1e-5, 1e-3 +# G = random_graph(n, p=0.35) +# x = jax.random.normal(rng, (n,)) + +# gt_geom = gt_geometry(G, epsilon=eps) +# graph_geom = graph.Graph.from_graph(G, t=eps) + +# fn = jax.jit(callback) if jit else callback + +# gt_out = fn(gt_geom) +# graph_out = fn(graph_geom) + +# assert gt_out.converged +# assert graph_out.converged +# np.testing.assert_allclose( +# graph_out.reg_ot_cost, gt_out.reg_ot_cost, rtol=tol, atol=tol +# ) +# np.testing.assert_allclose(graph_out.f, gt_out.f, rtol=tol, atol=tol) +# np.testing.assert_allclose(graph_out.g, gt_out.g, rtol=tol, atol=tol) + +# for axis in [0, 1]: +# y_gt = gt_out.apply(x, axis=axis) +# y_out = graph_out.apply(x, axis=axis) +# # note the high tolerance +# np.testing.assert_allclose(y_gt, y_out, rtol=5e-1, atol=5e-1) + +# np.testing.assert_allclose( +# gt_out.matrix, graph_out.matrix, rtol=1e-1, atol=1e-1 +# ) + +# @pytest.mark.parametrize( +# "implicit_diff", +# [False, True], +# ids=["not-implicit", "implicit"], +# ) +# def test_dense_graph_differentiability( +# self, rng: jax.random.PRNGKeyArray, implicit_diff: bool +# ): + +# def callback( +# data: jnp.ndarray, rows: jnp.ndarray, cols: jnp.ndarray, +# shape: Tuple[int, int] +# ) -> float: +# G = sparse.BCOO((data, jnp.c_[rows, cols]), shape=shape).todense() + +# geom = graph.Graph.from_graph(G, t=1.) +# solver = sinkhorn.Sinkhorn(lse_mode=False, **kwargs) +# problem = linear_problem.LinearProblem(geom) + +# return solver(problem).reg_ot_cost + +# if implicit_diff: +# kwargs = {"implicit_diff": implicit_lib.ImplicitDiff()} +# else: +# kwargs = {"implicit_diff": None} + +# eps = 1e-3 +# G = random_graph(20, p=0.5) +# G = sparse.BCOO.fromdense(G) + +# w, rows, cols = G.data, G.indices[:, 0], G.indices[:, 1] +# v_w = jax.random.normal(rng, shape=w.shape) +# v_w = (v_w / jnp.linalg.norm(v_w, axis=-1, keepdims=True)) * eps + +# grad_w = jax.grad(callback)(w, rows, cols, shape=G.shape) + +# expected = callback(w + v_w, rows, cols, +# G.shape) - callback(w - v_w, rows, cols, G.shape) +# actual = 2 * jnp.vdot(v_w, grad_w) +# np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4) + +# def test_tolerance_hilbert_metric(self, rng: jax.random.PRNGKeyArray): +# n, n_steps, t, tol = 256, 1000, 1e-4, 3e-4 +# G = random_graph(n, p=0.15) +# x = jnp.abs(jax.random.normal(rng, (n,))) + +# graph_no_tol = graph.Graph.from_graph(G, t=t, n_steps=n_steps, tol=-1) +# graph_low_tol = graph.Graph.from_graph(G, t=t, n_steps=n_steps, tol=2.5e-4) +# graph_high_tol = graph.Graph.from_graph(G, t=t, n_steps=n_steps, tol=1e-1) + +# app_no_tol = graph_no_tol.apply_kernel(x) +# app_low_tol = graph_low_tol.apply_kernel(x) # does 1 iteration +# app_high_tol = graph_high_tol.apply_kernel(x) # does 961 iterations + +# np.testing.assert_allclose(app_no_tol, app_low_tol, rtol=tol, atol=tol) +# np.testing.assert_allclose(app_no_tol, app_high_tol, rtol=5e-2, atol=5e-2) +# with pytest.raises(AssertionError): +# np.testing.assert_allclose(app_no_tol, app_high_tol, rtol=tol, atol=tol) From 3e62805e2098c4c75fccf90826450e3598acc503 Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Tue, 7 Nov 2023 17:52:35 +0100 Subject: [PATCH 22/44] jax pure_callback & new fn cheb --- src/ott/geometry/__init__.py | 1 + src/ott/geometry/geodesic.py | 94 +++++++++++++++++++++++++++--------- 2 files changed, 71 insertions(+), 24 deletions(-) diff --git a/src/ott/geometry/__init__.py b/src/ott/geometry/__init__.py index 5890e0935..c16ba687a 100644 --- a/src/ott/geometry/__init__.py +++ b/src/ott/geometry/__init__.py @@ -14,6 +14,7 @@ from . import ( costs, epsilon_scheduler, + geodesic, geometry, graph, grid, diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 01da8e701..81cb56735 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -50,8 +50,10 @@ def __init__( chebyshev_coeffs: Optional[List[float]] = None, numerical_scheme: Literal["backward_euler", "crank_nicolson"] = "backward_euler", - lap_m_id: Optional[jnp.ndarray] = None, # Rescale Laplacian minus identity - eigval: Optional[jnp.ndarray] = None, # (Second)Largest eigenvalue of Laplacian + lap_min_id: Optional[jnp.ndarray + ] = None, # Rescale Laplacian minus identity + eigval: Optional[jnp.ndarray + ] = None, # (Second)Largest eigenvalue of Laplacian **kwargs: Any ): super().__init__(epsilon=1., **kwargs) @@ -60,8 +62,8 @@ def __init__( self.order = order self.chebyshev_coeffs = chebyshev_coeffs self.numerical_scheme = numerical_scheme - self.lap_m_id = lap_m_id self.eigval = eigval + self.lap_min_id = lap_min_id @classmethod def from_graph( @@ -110,7 +112,9 @@ def from_graph( eigval = compute_largest_eigenvalue(laplacian, k=1) rescaled_laplacian = rescale_laplacian(laplacian, eigval) - lap_min_id = define_scaled_laplacian(rescaled_laplacian) + lap_min_id = define_scaled_laplacian( + rescaled_laplacian + ) # TODO: remove if not needed. if t is None: t = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2 @@ -119,11 +123,11 @@ def from_graph( chebyshev_coeffs = (2 * ive(jnp.arange(0, order + 1), -t)).tolist() return cls( - laplacian, + laplacian=laplacian, t=t, order=order, chebyshev_coeffs=chebyshev_coeffs, - lap_m_id=lap_min_id, + lap_min_id=lap_min_id, eigval=eigval, **kwargs ) @@ -135,25 +139,24 @@ def apply_kernel( axis: int = 0, ) -> jnp.ndarray: # TODO: fix indentation - # NOTE: GH: We could also input time, since we only need to recompute the coeffs, + # NOTE: GH: We could also input time, since we only need to recompute the coeffs, # i.e. we can use the same laplacian, scales laplaciant for different times. r"""Apply :attr:`kernel_matrix` on positive scaling vector. Args: - scaling: Scaling to apply the kernel to. + scaling: Scaling to apply the kernel to. eps: passed for consistency, not used yet. axis: passed for consistency, not used yet. Returns: Kernel applied to ``scaling``. """ - - laplacian_times_signal = self.lap_m_id.dot(scaling) # Apply the kernel - - return compute_chebyshev_approximation( - laplacian_times_signal, self.chebyshev_coeffs + diff_signal = expm_multiply( + self.laplacian, scaling, self.t, self.eigval, self.order ) + return diff_signal + @property def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 n, _ = self.shape @@ -228,7 +231,8 @@ def tree_unflatten( # noqa: D102 # TODO: # Moving some function here for now, idk if we want them in the class -# or in a utils file. +# or in a utils file. + def compute_largest_eigenvalue(laplacian_matrix, k, rng=None): # Compute the largest eigenvalue of the Laplacian matrix. @@ -249,7 +253,10 @@ def compute_largest_eigenvalue(laplacian_matrix, k, rng=None): return jnp.max(eigvals) -def rescale_laplacian(laplacian_matrix: jnp.ndarray, largest_eigenvalue: jnp.ndarray) -> jnp.ndarray: + +def rescale_laplacian( + laplacian_matrix: jnp.ndarray, largest_eigenvalue: jnp.ndarray +) -> jnp.ndarray: # Rescale the Laplacian matrix. if largest_eigenvalue > 2: rescaled_laplacian = laplacian_matrix.copy() @@ -257,18 +264,57 @@ def rescale_laplacian(laplacian_matrix: jnp.ndarray, largest_eigenvalue: jnp.nda return 2 * rescaled_laplacian return laplacian_matrix + def define_scaled_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: # Define the scaled Laplacian matrix. n = laplacian_matrix.shape[0] identity = jnp.eye(n) return laplacian_matrix - identity -@jax.pure_callback -def compute_chebyshev_approximation( - x: jnp.ndarray, coeffs: List[float] -) -> jnp.ndarray: - # Compute the Chebyshev polynomial approximation for - # the given input and coefficients. - x_dense = x.todense( - ) if type(x) is jesp.BCOO else x # this should be true all the time - return np.polynomial.chebyshev.chebval(x_dense, coeffs) \ No newline at end of file + +def _scipy_compute_chebychev_coeff_all(phi, tau, K): + """Compute the K+1 Chebychev coefficients for our functions.""" + coeff = 2 * ive(np.arange(0, K + 1), -tau * phi) + if coeff.dtype == np.float64: + coeff = np.float32(coeff) + return coeff + + +def expm_multiply( + L, + X, + phi, + tau, + K=None, +): + # NOTE: Modified the signature, to reuse computation during the Sinkhorn iteration. + # Compute coefficients (they should all fit in memory, no problem) + coeff = compute_chebychev_coeff_all(phi, tau, K) + # Initialize the accumulator with only the first coeff*polynomial + T0 = X + Y = 0.5 * coeff[0] * T0 + # Add the second coeff*polynomial to the accumulator + T1 = (1 / phi) * L @ X - T0 + Y = Y + coeff[1] * T1 + # Recursively add the next coeff*polynomial + for j in range(2, K + 1): + T2 = (2 / phi) * L @ T1 - 2 * T1 - T0 + Y = Y + coeff[j] * T2 + T0 = T1 + T1 = T2 + return Y + + +def compute_chebychev_coeff_all(phi, tau, K): + """Jax wrapper to compute the K+1 Chebychev coefficients.""" + if not isinstance(phi, jnp.ndarray): + phi = jnp.asarray(phi) + + result_shape_dtype = jax.ShapeDtypeStruct( + shape=(K + 1,), + dtype=jax.numpy.float32, + ) # TODO: not sure about the best type here. Maybe the best if to have + # the same type as the laplacian. + return jax.pure_callback( + _scipy_compute_chebychev_coeff_all, result_shape_dtype, phi, tau, K + ) From 3640e451e89718ebb47a358449c5ba738eb7606d Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Tue, 7 Nov 2023 18:38:39 +0100 Subject: [PATCH 23/44] symmetric kernel & wip tests --- src/ott/geometry/geodesic.py | 6 ++- tests/geometry/geo_test.py | 88 ++++++++++++++++++++---------------- 2 files changed, 55 insertions(+), 39 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 81cb56735..0540e9aab 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -160,7 +160,11 @@ def apply_kernel( @property def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 n, _ = self.shape - return self.apply_kernel(jnp.eye(n)) + kernel = self.apply_kernel(jnp.eye(n)) + # check if the kernel is symmetric + if jnp.any((kernel != kernel.T)): + kernel = (kernel + kernel.T) / 2.0 + return kernel @property def cost_matrix(self) -> jnp.ndarray: # noqa: D102 diff --git a/tests/geometry/geo_test.py b/tests/geometry/geo_test.py index 8f6feba0b..d32bbfa4e 100644 --- a/tests/geometry/geo_test.py +++ b/tests/geometry/geo_test.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Optional +import jax import jax.numpy as jnp import networkx as nx import numpy as np @@ -67,51 +68,62 @@ class TestGeodesic: def test_init(self): n, order = 10, 100 + t = 10 G = random_graph(n, p=0.5) - geom = geodesic.Geodesic.from_graph(G, order=order) + geom = geodesic.Geodesic.from_graph(G, t=t, order=order) np.testing.assert_equal(geom.order, order) - np.testing.assert_equal(geom.t, 1e-3) - np.testing.assert_equal(geom.chebyshev_coeffs, None) - np.testing.assert_equal(geom.laplacian, G) - np.testing.assert_equal(geom.numerical_scheme, "backward_euler") + np.testing.assert_equal(geom.t, t) + # np.testing.assert_equal(geom.laplacian, G) # TODO: check for the normalized laplacian + + def test_kernel_is_symmetric_positive_definite( + self, rng: jax.random.PRNGKeyArray + ): + n, tol = 65, 0.02 + t = 20 + order = 50 + x = jax.random.normal(rng, (n,)) + G = random_graph(n) + geom = geodesic.Geodesic.from_graph(G, t=t, order=order) + + kernel = geom.kernel_matrix + + vec0 = geom.apply_kernel(x, axis=0) + vec1 = geom.apply_kernel(x, axis=1) + vec_direct0 = geom.kernel_matrix.T @ x + vec_direct1 = geom.kernel_matrix @ x + + # we symmetrize the kernel explicitly when materializing it, because + # numerical error arise for small `t` and `backward_euler`, or Chebyshev approximation. + np.testing.assert_array_equal(kernel, kernel.T) + eigenvalues = jnp.linalg.eigvals(kernel) + neg_eigenvalues = eigenvalues[eigenvalues < 0] + # check that the negative eigenvalues are all very small + np.testing.assert_array_less(jnp.abs(neg_eigenvalues), 1e-3) + # internally, the axis is ignored because the kernel is symmetric + np.testing.assert_array_equal(vec0, vec1) + np.testing.assert_array_equal(vec_direct0, vec_direct1) + + np.testing.assert_allclose(vec0, vec_direct0, rtol=tol, atol=tol) + np.testing.assert_allclose(vec1, vec_direct1, rtol=tol, atol=tol) + + # compute the distance matrix and check that it is symmetric + cost_matrix = geom.cost_matrix + np.testing.assert_array_equal(cost_matrix, cost_matrix.T) + # and all dissimilarities are positive + np.testing.assert_array_less(0, cost_matrix) + + def test_automatic_t(self): + G = random_graph(38, return_laplacian=False) + geom = geodesic.Geodesic.from_graph(G, t=None) + + expected = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2 + actual = geom.t + np.testing.assert_equal(actual, expected) # class TestGraph: -# def test_kernel_is_symmetric_positive_definite( -# self, rng: jax.random.PRNGKeyArray -# ): -# n, tol = 65, 0.02 -# x = jax.random.normal(rng, (n,)) -# geom = graph.Graph.from_graph(random_graph(n), t=1e-3) - -# kernel = geom.kernel_matrix - -# vec0 = geom.apply_kernel(x, axis=0) -# vec1 = geom.apply_kernel(x, axis=1) -# vec_direct0 = geom.kernel_matrix.T @ x -# vec_direct1 = geom.kernel_matrix @ x - -# # we symmetrize the kernel explicitly when materializing it, because -# # numerical error arise for small `t` and `backward_euler` -# np.testing.assert_array_equal(kernel, kernel.T) -# np.testing.assert_array_equal(jnp.linalg.eigvals(kernel) > 0., True) -# # internally, the axis is ignored because the kernel is symmetric -# np.testing.assert_array_equal(vec0, vec1) -# np.testing.assert_array_equal(vec_direct0, vec_direct1) - -# np.testing.assert_allclose(vec0, vec_direct0, rtol=tol, atol=tol) -# np.testing.assert_allclose(vec1, vec_direct1, rtol=tol, atol=tol) - -# def test_automatic_t(self): -# G = random_graph(38, return_laplacian=False) -# geom = graph.Graph.from_graph(G, t=None) - -# expected = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2 -# actual = geom.t -# np.testing.assert_equal(actual, expected) - # @pytest.mark.fast.with_args( # numerical_scheme=["backward_euler", "crank_nicolson"], # only_fast=0, From 4f4447c554e0560f7fa4e544d28400f00bdb3d2e Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Wed, 8 Nov 2023 09:21:18 +0100 Subject: [PATCH 24/44] fix formatting ruff --- src/ott/geometry/geodesic.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 0540e9aab..56e2bbb36 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -139,7 +139,8 @@ def apply_kernel( axis: int = 0, ) -> jnp.ndarray: # TODO: fix indentation - # NOTE: GH: We could also input time, since we only need to recompute the coeffs, + # NOTE: GH: We could also input time, + # since we only need to recompute the coeffs, # i.e. we can use the same laplacian, scales laplaciant for different times. r"""Apply :attr:`kernel_matrix` on positive scaling vector. @@ -151,18 +152,16 @@ def apply_kernel( Returns: Kernel applied to ``scaling``. """ - diff_signal = expm_multiply( + return expm_multiply( self.laplacian, scaling, self.t, self.eigval, self.order ) - return diff_signal - @property def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 n, _ = self.shape kernel = self.apply_kernel(jnp.eye(n)) # check if the kernel is symmetric - if jnp.any((kernel != kernel.T)): + if jnp.any(kernel != kernel.T): kernel = (kernel + kernel.T) / 2.0 return kernel @@ -291,7 +290,6 @@ def expm_multiply( tau, K=None, ): - # NOTE: Modified the signature, to reuse computation during the Sinkhorn iteration. # Compute coefficients (they should all fit in memory, no problem) coeff = compute_chebychev_coeff_all(phi, tau, K) # Initialize the accumulator with only the first coeff*polynomial From 2d5e0f02bb49fdc50d9acd98427bf0a4ebaf2b76 Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Wed, 8 Nov 2023 10:38:20 +0100 Subject: [PATCH 25/44] wrap eigenval fn --- src/ott/geometry/geodesic.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 56e2bbb36..4c4f7ee18 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -18,6 +18,7 @@ import jax.experimental.sparse as jesp import jax.numpy as jnp import numpy as np +from scipy.sparse.linalg import eigsh as _scipy_eigsh from scipy.special import ive from ott import utils @@ -73,6 +74,7 @@ def from_graph( order: int = 100, directed: bool = False, normalize: bool = False, + eigenval_scipy: bool = False, **kwargs: Any ) -> "Geodesic": r"""Construct a Geodesic geometry from an adjacency matrix. @@ -91,6 +93,8 @@ def from_graph( :math:`L^{sym} = \left(D^+\right)^{\frac{1}{2}} L \left(D^+\right)^{\frac{1}{2}}`, where :math:`L` is the non-normalized Laplacian and :math:`D` is the degree matrix. + eigenval_scipy: Whether to use the scipy implementation of the + eigenvalue computation. kwargs: Keyword arguments for the Geodesic class. Returns: @@ -109,8 +113,10 @@ def from_graph( jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0) ) laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg - - eigval = compute_largest_eigenvalue(laplacian, k=1) + if eigenval_scipy: + eigval = compute_eigenvalue(laplacian) + else: + eigval = compute_largest_eigenvalue(laplacian, k=1) rescaled_laplacian = rescale_laplacian(laplacian, eigval) lap_min_id = define_scaled_laplacian( rescaled_laplacian @@ -120,7 +126,7 @@ def from_graph( t = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2 # Compute the coeffs of the Chebyshev pols approx using Bessel functs. - chebyshev_coeffs = (2 * ive(jnp.arange(0, order + 1), -t)).tolist() + chebyshev_coeffs = compute_chebychev_coeff_all(eigval, t, order) return cls( laplacian=laplacian, @@ -153,7 +159,8 @@ def apply_kernel( Kernel applied to ``scaling``. """ return expm_multiply( - self.laplacian, scaling, self.t, self.eigval, self.order + self.laplacian, scaling, self.chebyshev_coeffs, self.t, self.eigval, + self.order ) @property @@ -286,12 +293,11 @@ def _scipy_compute_chebychev_coeff_all(phi, tau, K): def expm_multiply( L, X, + coeff, phi, tau, K=None, ): - # Compute coefficients (they should all fit in memory, no problem) - coeff = compute_chebychev_coeff_all(phi, tau, K) # Initialize the accumulator with only the first coeff*polynomial T0 = X Y = 0.5 * coeff[0] * T0 @@ -320,3 +326,13 @@ def compute_chebychev_coeff_all(phi, tau, K): return jax.pure_callback( _scipy_compute_chebychev_coeff_all, result_shape_dtype, phi, tau, K ) + + +def compute_eigenvalue(L): + """Jax wrapper to compute the largest eigenvalue of the Laplacian.""" + result_shape_dtype = jax.ShapeDtypeStruct( + shape=(1,), + dtype=jax.numpy.float32, + ) + eval_only = lambda x: _scipy_eigsh(x, k=1)[0] / 2.0 + return jax.pure_callback(eval_only, result_shape_dtype, L) From 083b02125acea001d8b1f6649d115c841228dc0a Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Wed, 8 Nov 2023 13:54:51 +0100 Subject: [PATCH 26/44] rm num_scheme & _scale --- src/ott/geometry/geodesic.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 4c4f7ee18..51b638e26 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple import jax import jax.experimental.sparse as jesp @@ -49,8 +49,6 @@ def __init__( t: float = 1e-3, order: int = 100, chebyshev_coeffs: Optional[List[float]] = None, - numerical_scheme: Literal["backward_euler", - "crank_nicolson"] = "backward_euler", lap_min_id: Optional[jnp.ndarray ] = None, # Rescale Laplacian minus identity eigval: Optional[jnp.ndarray @@ -62,7 +60,6 @@ def __init__( self.t = t self.order = order self.chebyshev_coeffs = chebyshev_coeffs - self.numerical_scheme = numerical_scheme self.eigval = eigval self.lap_min_id = lap_min_id @@ -151,12 +148,12 @@ def apply_kernel( r"""Apply :attr:`kernel_matrix` on positive scaling vector. Args: - scaling: Scaling to apply the kernel to. - eps: passed for consistency, not used yet. - axis: passed for consistency, not used yet. + scaling: Scaling to apply the kernel to. + eps: passed for consistency, not used yet. + axis: passed for consistency, not used yet. Returns: - Kernel applied to ``scaling``. + Kernel applied to ``scaling``. """ return expm_multiply( self.laplacian, scaling, self.chebyshev_coeffs, self.t, self.eigval, @@ -177,17 +174,6 @@ def cost_matrix(self) -> jnp.ndarray: # noqa: D102 # Calculate the cost matrix using the formula (5) from the main reference return -4 * self.t * mu.safe_log(self.kernel_matrix) - @property - def _scale(self) -> float: - """Constant used to scale the Laplacian.""" - if self.numerical_scheme == "backward_euler": - return self.t / (4. * self.order) - if self.numerical_scheme == "crank_nicolson": - return self.t / (2. * self.order) - raise NotImplementedError( - f"Numerical scheme `{self.numerical_scheme}` is not implemented." - ) - @property def shape(self) -> Tuple[int, int]: # noqa: D102 return self.laplacian.shape @@ -228,7 +214,6 @@ def marginal_from_potentials( def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [self.laplacian, self.t, self.order], { - "numerical_scheme": self.numerical_scheme, "chebyshev_coeffs": self.chebyshev_coeffs, } From 406d9c899bba506d85dda20dd2d08dc12bbce327 Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Wed, 8 Nov 2023 14:10:05 +0100 Subject: [PATCH 27/44] type dense or sparse --- src/ott/geometry/geodesic.py | 36 ++++++++++++++++++++++++++++-------- src/ott/types.py | 6 +++++- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 51b638e26..09688a564 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -24,9 +24,34 @@ from ott import utils from ott.geometry import geometry from ott.math import utils as mu +from ott.types import Array_g __all__ = ["Geodesic"] +# TODO: +# - Finalize the docstrings. +# - Add tests. +# - Verify sparse graph + cholesky. + +# Previous meetings todos: +# 1) wrap all scipy and numpy +# 2) move all the comp in the init, just call it once +# 3) make sure it works with sparse graph + cholesky +# 4) differentiablity , graph geo uses cholesky (triangle solve). + +# NOTE: Meeting questions: +# - i moved some fn outside of the class. Where do we want them? +# - review type, float64 vs float32 (see the TODOs in the code). +# - Currently two implementations of the eigenvalue computation, +# they give different results. I trust the one from scipy. +# - Changed the chebyshev computstion to the one inspired from +# https://github.com/sibyllema/Fast-Multiscale-Diffusion-on-Graphs. +# - I started working on test. +# - I see that there is also another method using backward Euler.. +# Do we just want to have a HeatFilter class that includes both? +# - Do we want docstrings for all methods? e.g. the wrapper? +# GH: I think only the class is enough. + @jax.tree_util.register_pytree_node_class class Geodesic(geometry.Geometry): @@ -45,12 +70,11 @@ class Geodesic(geometry.Geometry): def __init__( self, - laplacian: jnp.ndarray, + laplacian: Array_g, t: float = 1e-3, order: int = 100, chebyshev_coeffs: Optional[List[float]] = None, - lap_min_id: Optional[jnp.ndarray - ] = None, # Rescale Laplacian minus identity + lap_min_id: Optional[Array_g] = None, # Rescale Laplacian minus identity eigval: Optional[jnp.ndarray ] = None, # (Second)Largest eigenvalue of Laplacian **kwargs: Any @@ -66,7 +90,7 @@ def __init__( @classmethod def from_graph( cls, - G: jnp.ndarray, + G: Array_g, t: Optional[float] = 1e-3, order: int = 100, directed: bool = False, @@ -141,10 +165,6 @@ def apply_kernel( eps: Optional[float] = None, axis: int = 0, ) -> jnp.ndarray: - # TODO: fix indentation - # NOTE: GH: We could also input time, - # since we only need to recompute the coeffs, - # i.e. we can use the same laplacian, scales laplaciant for different times. r"""Apply :attr:`kernel_matrix` on positive scaling vector. Args: diff --git a/src/ott/types.py b/src/ott/types.py index 7a4c88716..dc7e552d9 100644 --- a/src/ott/types.py +++ b/src/ott/types.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Protocol +from typing import Protocol, Union +import jax.experimental.sparse as jesp import jax.numpy as jnp __all__ = ["Transport"] # TODO(michalk8): introduce additional types here +# Either a dense or sparse array. +Array_g = Union[jnp.ndarray, jesp.BCOO] + class Transport(Protocol): """Interface for the solution of a transport problem. From 9239524d80ee30efd8088ef4538c1fab9b42288c Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Wed, 8 Nov 2023 18:32:48 +0100 Subject: [PATCH 28/44] expm with scan & fix hardcode dty & lobpcg iter --- src/ott/geometry/geodesic.py | 100 +++++++++++------------------------ 1 file changed, 32 insertions(+), 68 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 09688a564..ffcb04e5a 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -18,7 +18,6 @@ import jax.experimental.sparse as jesp import jax.numpy as jnp import numpy as np -from scipy.sparse.linalg import eigsh as _scipy_eigsh from scipy.special import ive from ott import utils @@ -34,24 +33,11 @@ # - Verify sparse graph + cholesky. # Previous meetings todos: -# 1) wrap all scipy and numpy -# 2) move all the comp in the init, just call it once +# 1) wrap all scipy and numpy (done) +# 2) move all the comp in the init, just call it once (done) # 3) make sure it works with sparse graph + cholesky # 4) differentiablity , graph geo uses cholesky (triangle solve). -# NOTE: Meeting questions: -# - i moved some fn outside of the class. Where do we want them? -# - review type, float64 vs float32 (see the TODOs in the code). -# - Currently two implementations of the eigenvalue computation, -# they give different results. I trust the one from scipy. -# - Changed the chebyshev computstion to the one inspired from -# https://github.com/sibyllema/Fast-Multiscale-Diffusion-on-Graphs. -# - I started working on test. -# - I see that there is also another method using backward Euler.. -# Do we just want to have a HeatFilter class that includes both? -# - Do we want docstrings for all methods? e.g. the wrapper? -# GH: I think only the class is enough. - @jax.tree_util.register_pytree_node_class class Geodesic(geometry.Geometry): @@ -95,7 +81,6 @@ def from_graph( order: int = 100, directed: bool = False, normalize: bool = False, - eigenval_scipy: bool = False, **kwargs: Any ) -> "Geodesic": r"""Construct a Geodesic geometry from an adjacency matrix. @@ -114,8 +99,6 @@ def from_graph( :math:`L^{sym} = \left(D^+\right)^{\frac{1}{2}} L \left(D^+\right)^{\frac{1}{2}}`, where :math:`L` is the non-normalized Laplacian and :math:`D` is the degree matrix. - eigenval_scipy: Whether to use the scipy implementation of the - eigenvalue computation. kwargs: Keyword arguments for the Geodesic class. Returns: @@ -134,10 +117,8 @@ def from_graph( jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0) ) laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg - if eigenval_scipy: - eigval = compute_eigenvalue(laplacian) - else: - eigval = compute_largest_eigenvalue(laplacian, k=1) + + eigval = compute_largest_eigenvalue(laplacian, k=1) rescaled_laplacian = rescale_laplacian(laplacian, eigval) lap_min_id = define_scaled_laplacian( rescaled_laplacian @@ -176,8 +157,7 @@ def apply_kernel( Kernel applied to ``scaling``. """ return expm_multiply( - self.laplacian, scaling, self.chebyshev_coeffs, self.t, self.eigval, - self.order + self.laplacian, scaling, self.chebyshev_coeffs, self.eigval, self.order ) @property @@ -244,11 +224,6 @@ def tree_unflatten( # noqa: D102 return cls(*children, **aux_data) -# TODO: -# Moving some function here for now, idk if we want them in the class -# or in a utils file. - - def compute_largest_eigenvalue(laplacian_matrix, k, rng=None): # Compute the largest eigenvalue of the Laplacian matrix. if rng is None: @@ -263,9 +238,10 @@ def compute_largest_eigenvalue(laplacian_matrix, k, rng=None): # Compute eigenvalues using the sparse matrix-vector product eigvals, _, _ = jesp.linalg.lobpcg_standard( - lapl_vector_product, initial_dirs, m=k + lapl_vector_product, + initial_dirs, + m=100, ) - return jnp.max(eigvals) @@ -273,11 +249,9 @@ def rescale_laplacian( laplacian_matrix: jnp.ndarray, largest_eigenvalue: jnp.ndarray ) -> jnp.ndarray: # Rescale the Laplacian matrix. - if largest_eigenvalue > 2: - rescaled_laplacian = laplacian_matrix.copy() - rescaled_laplacian /= largest_eigenvalue - return 2 * rescaled_laplacian - return laplacian_matrix + return jax.lax.cond((largest_eigenvalue > 2), + lambda l: 2 * l / largest_eigenvalue, lambda l: l, + laplacian_matrix) def define_scaled_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: @@ -295,49 +269,39 @@ def _scipy_compute_chebychev_coeff_all(phi, tau, K): return coeff -def expm_multiply( - L, - X, - coeff, - phi, - tau, - K=None, -): - # Initialize the accumulator with only the first coeff*polynomial +def expm_multiply(L, X, coeff, phi, K): + + def body(carry, c): + T0, T1, Y = carry + T2 = (2 / phi) * L @ T1 - 2 * T1 - T0 + Y = Y + c * T2 + return (T1, T2, Y), None + T0 = X Y = 0.5 * coeff[0] * T0 - # Add the second coeff*polynomial to the accumulator T1 = (1 / phi) * L @ X - T0 Y = Y + coeff[1] * T1 - # Recursively add the next coeff*polynomial - for j in range(2, K + 1): - T2 = (2 / phi) * L @ T1 - 2 * T1 - T0 - Y = Y + coeff[j] * T2 - T0 = T1 - T1 = T2 + + initial_state = (T0, T1, Y) + carry, _ = jax.lax.scan(body, initial_state, coeff[2:]) + _, _, Y = carry return Y def compute_chebychev_coeff_all(phi, tau, K): """Jax wrapper to compute the K+1 Chebychev coefficients.""" - if not isinstance(phi, jnp.ndarray): - phi = jnp.asarray(phi) + if hasattr(phi, "dtype") and phi.dtype == jnp.float64: + _type = jnp.float64 + else: + _type = jnp.float32 result_shape_dtype = jax.ShapeDtypeStruct( shape=(K + 1,), - dtype=jax.numpy.float32, - ) # TODO: not sure about the best type here. Maybe the best if to have - # the same type as the laplacian. - return jax.pure_callback( - _scipy_compute_chebychev_coeff_all, result_shape_dtype, phi, tau, K + dtype=_type, ) + chebychev_coeff = lambda phi, tau, K: _scipy_compute_chebychev_coeff_all( + phi, tau, K + ).astype(_type) -def compute_eigenvalue(L): - """Jax wrapper to compute the largest eigenvalue of the Laplacian.""" - result_shape_dtype = jax.ShapeDtypeStruct( - shape=(1,), - dtype=jax.numpy.float32, - ) - eval_only = lambda x: _scipy_eigsh(x, k=1)[0] / 2.0 - return jax.pure_callback(eval_only, result_shape_dtype, L) + return jax.pure_callback(chebychev_coeff, result_shape_dtype, phi, tau, K) From 0dd5a97e70ab4177b1baf3758a82d85d9f3d46c0 Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Wed, 8 Nov 2023 18:37:57 +0100 Subject: [PATCH 29/44] test compare with BE and CN --- tests/geometry/geo_test.py | 96 ++++++++++++++++++++------------------ 1 file changed, 51 insertions(+), 45 deletions(-) diff --git a/tests/geometry/geo_test.py b/tests/geometry/geo_test.py index d32bbfa4e..999ce9306 100644 --- a/tests/geometry/geo_test.py +++ b/tests/geometry/geo_test.py @@ -17,9 +17,10 @@ import jax.numpy as jnp import networkx as nx import numpy as np +import pytest from networkx.algorithms import shortest_paths -from networkx.generators import random_graphs -from ott.geometry import geodesic, geometry +from networkx.generators import balanced_tree, random_graphs +from ott.geometry import geodesic, geometry, graph def random_graph( @@ -74,13 +75,12 @@ def test_init(self): np.testing.assert_equal(geom.order, order) np.testing.assert_equal(geom.t, t) - # np.testing.assert_equal(geom.laplacian, G) # TODO: check for the normalized laplacian def test_kernel_is_symmetric_positive_definite( self, rng: jax.random.PRNGKeyArray ): - n, tol = 65, 0.02 - t = 20 + n, tol = 100, 0.02 + t = 1 order = 50 x = jax.random.normal(rng, (n,)) G = random_graph(n) @@ -94,7 +94,7 @@ def test_kernel_is_symmetric_positive_definite( vec_direct1 = geom.kernel_matrix @ x # we symmetrize the kernel explicitly when materializing it, because - # numerical error arise for small `t` and `backward_euler`, or Chebyshev approximation. + # numerical errors can make it non-symmetric. np.testing.assert_array_equal(kernel, kernel.T) eigenvalues = jnp.linalg.eigvals(kernel) neg_eigenvalues = eigenvalues[eigenvalues < 0] @@ -121,7 +121,6 @@ def test_automatic_t(self): actual = geom.t np.testing.assert_equal(actual, expected) - # class TestGraph: # @pytest.mark.fast.with_args( @@ -152,51 +151,57 @@ def test_automatic_t(self): # atol=1e-2 # ) -# @pytest.mark.fast.with_args( -# n_steps=[50, 100, 200], -# t=[1e-4, 1e-5], -# only_fast=0, -# ) -# def test_crank_nicolson_more_stable(self, t: Optional[float], n_steps: int): -# tol = 5 * t -# G = nx.linalg.adjacency_matrix(balanced_tree(r=2, h=5)) -# G = jnp.asarray(G.toarray(), dtype=float) -# eye = jnp.eye(G.shape[0]) - -# be_geom = graph.Graph.from_graph( -# G, t=t, n_steps=n_steps, numerical_scheme="backward_euler" -# ) -# cn_geom = graph.Graph.from_graph( -# G, t=t, n_steps=n_steps, numerical_scheme="crank_nicolson" -# ) -# eps = jnp.finfo(eye.dtype).tiny + @pytest.mark.fast.with_args( + n_steps=[50, 100, 200], + t=[1e-4, 1e-5], + only_fast=0, + ) + def cheb_be_cn(self, t: Optional[float], n_steps: int): + tol = 5 * t + G = nx.linalg.adjacency_matrix(balanced_tree(r=2, h=5)) + G = jnp.asarray(G.toarray(), dtype=float) + eye = jnp.eye(G.shape[0]) -# be_cost = -t * jnp.log(be_geom.apply_kernel(eye) + eps) -# cn_cost = -t * jnp.log(cn_geom.apply_kernel(eye) + eps) + be_geom = graph.Graph.from_graph( + G, t=t, n_steps=n_steps, numerical_scheme="backward_euler" + ) + cn_geom = graph.Graph.from_graph( + G, t=t, n_steps=n_steps, numerical_scheme="crank_nicolson" + ) + geo = geodesic.Geodesic.from_graph(G, t=t, order=n_steps) + eps = jnp.finfo(eye.dtype).tiny -# np.testing.assert_allclose(cn_cost, cn_cost.T, rtol=tol, atol=tol) -# with pytest.raises(AssertionError): -# np.testing.assert_allclose(be_cost, be_cost.T, rtol=tol, atol=tol) + be_cost = -t * jnp.log(be_geom.apply_kernel(eye) + eps) + cn_cost = -t * jnp.log(cn_geom.apply_kernel(eye) + eps) + cheb_cost = -t * jnp.log(geo.apply_kernel(eye) + eps) -# @pytest.mark.parametrize(("jit", "normalize"), [(False, True), (True, False)]) -# def test_directed_graph(self, jit: bool, normalize: bool): + np.testing.assert_allclose(cheb_cost, cheb_cost.T, rtol=tol, atol=tol) + # check that it is close to the BE CN + np.testing.assert_allclose(be_cost, cheb_cost, rtol=tol, atol=tol) + np.testing.assert_allclose(cn_cost, cheb_cost, rtol=tol, atol=tol) + with pytest.raises(AssertionError): + np.testing.assert_allclose(be_cost, be_cost.T, rtol=tol, atol=tol) -# def create_graph(G: jnp.ndarray) -> graph.Graph: -# return graph.Graph.from_graph(G, directed=True, normalize=normalize) + @pytest.mark.parametrize(("jit", "normalize"), [(False, True), (True, False)]) + def test_directed_graph(self, jit: bool, normalize: bool): -# G = random_graph(16, p=0.25, directed=True) -# create_fn = jax.jit(create_graph) if jit else create_graph -# geom = create_fn(G) + def create_graph(G: jnp.ndarray) -> graph.Graph: + return geodesic.Geodesic.from_graph(G, directed=True, normalize=normalize) -# with pytest.raises(AssertionError): -# np.testing.assert_allclose(G, G.T) + G = random_graph(16, p=0.25, directed=True) + create_fn = jax.jit(create_graph) if jit else create_graph + geom = create_fn(G) -# L = geom.laplacian + with pytest.raises(AssertionError): + np.testing.assert_allclose(G, G.T) + + L = geom.laplacian + + with pytest.raises(AssertionError): + # make sure that original graph was directed + np.testing.assert_allclose(G, G.T, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(L, L.T, rtol=1e-6, atol=1e-6) -# with pytest.raises(AssertionError): -# # make sure that original graph was directed -# np.testing.assert_allclose(G, G.T, rtol=1e-6, atol=1e-6) -# np.testing.assert_allclose(L, L.T, rtol=1e-6, atol=1e-6) # @pytest.mark.parametrize("directed", [False, True]) # @pytest.mark.parametrize("normalize", [False, True]) @@ -308,7 +313,8 @@ def test_automatic_t(self): # x = jnp.abs(jax.random.normal(rng, (n,))) # graph_no_tol = graph.Graph.from_graph(G, t=t, n_steps=n_steps, tol=-1) -# graph_low_tol = graph.Graph.from_graph(G, t=t, n_steps=n_steps, tol=2.5e-4) +# graph_low_tol = graph.Graph.from_graph(G, t=t, +# n_steps=n_steps, tol=2.5e-4) # graph_high_tol = graph.Graph.from_graph(G, t=t, n_steps=n_steps, tol=1e-1) # app_no_tol = graph_no_tol.apply_kernel(x) From 43df64901fe681adc239440ff171f7837faf1bae Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Mon, 13 Nov 2023 11:51:20 +0100 Subject: [PATCH 30/44] rm lap_mn_id & test sink spd & fix tree_flatten --- src/ott/geometry/geodesic.py | 43 ++++--- tests/geometry/geo_test.py | 214 +++++++++++------------------------ 2 files changed, 82 insertions(+), 175 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index ffcb04e5a..ae8b82eac 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -57,27 +57,27 @@ class Geodesic(geometry.Geometry): def __init__( self, laplacian: Array_g, + scaled_laplacian: Array_g, + eigval: jnp.ndarray, t: float = 1e-3, order: int = 100, chebyshev_coeffs: Optional[List[float]] = None, - lap_min_id: Optional[Array_g] = None, # Rescale Laplacian minus identity - eigval: Optional[jnp.ndarray - ] = None, # (Second)Largest eigenvalue of Laplacian **kwargs: Any ): super().__init__(epsilon=1., **kwargs) self.laplacian = laplacian + self.scaled_laplacian = scaled_laplacian + self.eigval = eigval self.t = t self.order = order self.chebyshev_coeffs = chebyshev_coeffs - self.eigval = eigval - self.lap_min_id = lap_min_id @classmethod def from_graph( cls, G: Array_g, t: Optional[float] = 1e-3, + eigval: Optional[jnp.ndarray] = None, # Largest eigenvalue of Laplacian order: int = 100, directed: bool = False, normalize: bool = False, @@ -91,6 +91,8 @@ def from_graph( If `None`, it defaults to :math:`\frac{1}{|E|} \sum_{(u, v) \in E} \text{weight}(u, v)` :cite:`crane:13`. In this case, the ``graph`` must be specified and the edge weights are assumed to be positive. + eigval: Largest eigenvalue of the Laplacian. If `None`, it's computed + at initialization. order: Max order of Chebyshev polynomial. directed: Whether the ``graph`` is directed. If not, it's made undirected as :math:`G + G^T`. This parameter is ignored when passing @@ -118,11 +120,10 @@ def from_graph( ) laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg - eigval = compute_largest_eigenvalue(laplacian, k=1) - rescaled_laplacian = rescale_laplacian(laplacian, eigval) - lap_min_id = define_scaled_laplacian( - rescaled_laplacian - ) # TODO: remove if not needed. + eigval = compute_largest_eigenvalue( + laplacian, k=1 + ) if eigval is None else eigval + scaled_laplacian = rescale_laplacian(laplacian, eigval) if t is None: t = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2 @@ -132,11 +133,11 @@ def from_graph( return cls( laplacian=laplacian, + scaled_laplacian=scaled_laplacian, + eigval=eigval, t=t, order=order, chebyshev_coeffs=chebyshev_coeffs, - lap_min_id=lap_min_id, - eigval=eigval, **kwargs ) @@ -157,7 +158,7 @@ def apply_kernel( Kernel applied to ``scaling``. """ return expm_multiply( - self.laplacian, scaling, self.chebyshev_coeffs, self.eigval, self.order + self.scaled_laplacian, scaling, self.chebyshev_coeffs, self.eigval ) @property @@ -213,9 +214,10 @@ def marginal_from_potentials( raise ValueError("Not implemented.") def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 - return [self.laplacian, self.t, self.order], { - "chebyshev_coeffs": self.chebyshev_coeffs, - } + return [ + self.laplacian, self.scaled_laplacian, self.eigval, self.t, self.order, + self.chebyshev_coeffs + ], {} @classmethod def tree_unflatten( # noqa: D102 @@ -254,13 +256,6 @@ def rescale_laplacian( laplacian_matrix) -def define_scaled_laplacian(laplacian_matrix: jnp.ndarray) -> jnp.ndarray: - # Define the scaled Laplacian matrix. - n = laplacian_matrix.shape[0] - identity = jnp.eye(n) - return laplacian_matrix - identity - - def _scipy_compute_chebychev_coeff_all(phi, tau, K): """Compute the K+1 Chebychev coefficients for our functions.""" coeff = 2 * ive(np.arange(0, K + 1), -tau * phi) @@ -269,7 +264,7 @@ def _scipy_compute_chebychev_coeff_all(phi, tau, K): return coeff -def expm_multiply(L, X, coeff, phi, K): +def expm_multiply(L, X, coeff, phi): def body(carry, c): T0, T1, Y = carry diff --git a/tests/geometry/geo_test.py b/tests/geometry/geo_test.py index 999ce9306..2f205e26d 100644 --- a/tests/geometry/geo_test.py +++ b/tests/geometry/geo_test.py @@ -21,6 +21,8 @@ from networkx.algorithms import shortest_paths from networkx.generators import balanced_tree, random_graphs from ott.geometry import geodesic, geometry, graph +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn def random_graph( @@ -121,36 +123,6 @@ def test_automatic_t(self): actual = geom.t np.testing.assert_equal(actual, expected) -# class TestGraph: - -# @pytest.mark.fast.with_args( -# numerical_scheme=["backward_euler", "crank_nicolson"], -# only_fast=0, -# ) -# def test_approximates_ground_truth( -# self, -# rng: jax.random.PRNGKeyArray, -# numerical_scheme: Literal["backward_euler", "crank_nicolson"], -# ): -# eps, n_steps = 1e-5, 20 -# G = random_graph(37, p=0.5) -# x = jax.random.normal(rng, (G.shape[0],)) - -# gt_geom = gt_geometry(G, epsilon=eps) -# graph_geom = graph.Graph.from_graph( -# G, t=eps, n_steps=n_steps, numerical_scheme=numerical_scheme -# ) - -# np.testing.assert_allclose( -# gt_geom.kernel_matrix, graph_geom.kernel_matrix, rtol=1e-2, atol=1e-2 -# ) -# np.testing.assert_allclose( -# gt_geom.apply_kernel(x), -# graph_geom.apply_kernel(x), -# rtol=1e-2, -# atol=1e-2 -# ) - @pytest.mark.fast.with_args( n_steps=[50, 100, 200], t=[1e-4, 1e-5], @@ -202,126 +174,66 @@ def create_graph(G: jnp.ndarray) -> graph.Graph: np.testing.assert_allclose(G, G.T, rtol=1e-6, atol=1e-6) np.testing.assert_allclose(L, L.T, rtol=1e-6, atol=1e-6) + @pytest.mark.parametrize("directed", [False, True]) + @pytest.mark.parametrize("normalize", [False, True]) + def test_normalize_laplacian(self, directed: bool, normalize: bool): + + def laplacian(G: jnp.ndarray) -> jnp.ndarray: + if directed: + G = G + G.T + + data = jnp.sum(G, axis=1) + lap = jnp.diag(data) - G + if normalize: + inv_sqrt_deg = jnp.diag( + jnp.where(data > 0.0, 1.0 / jnp.sqrt(data), 0.0) + ) + return inv_sqrt_deg @ lap @ inv_sqrt_deg + return lap + + G = random_graph(51, p=0.35, directed=directed) + geom = geodesic.Geodesic.from_graph( + G, directed=directed, normalize=normalize + ) + + expected = laplacian(G) + actual = geom.laplacian + + np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) -# @pytest.mark.parametrize("directed", [False, True]) -# @pytest.mark.parametrize("normalize", [False, True]) -# def test_normalize_laplacian(self, directed: bool, normalize: bool): + @pytest.mark.fast.with_args(jit=[False, True], only_fast=0) + def test_graph_sinkhorn(self, rng: jax.random.PRNGKeyArray, jit: bool): + + def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: + solver = sinkhorn.Sinkhorn(lse_mode=False) + problem = linear_problem.LinearProblem(geom) + return solver(problem) + + n, eps, tol = 11, 1e-5, 1e-3 + G = random_graph(n, p=0.35) + x = jax.random.normal(rng, (n,)) -# def laplacian(G: jnp.ndarray) -> jnp.ndarray: -# if directed: -# G = G + G.T - -# data = jnp.sum(G, axis=1) -# lap = jnp.diag(data) - G -# if normalize: -# inv_sqrt_deg = jnp.diag( -# jnp.where(data > 0.0, 1.0 / jnp.sqrt(data), 0.0) -# ) -# return inv_sqrt_deg @ lap @ inv_sqrt_deg -# return lap - -# G = random_graph(51, p=0.35, directed=directed) -# geom = graph.Graph.from_graph(G, directed=directed, normalize=normalize) - -# expected = laplacian(G) -# actual = geom.laplacian - -# np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) - -# @pytest.mark.fast.with_args(jit=[False, True], only_fast=0) -# def test_graph_sinkhorn(self, rng: jax.random.PRNGKeyArray, jit: bool): - -# def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: -# solver = sinkhorn.Sinkhorn(lse_mode=False) -# problem = linear_problem.LinearProblem(geom) -# return solver(problem) - -# n, eps, tol = 11, 1e-5, 1e-3 -# G = random_graph(n, p=0.35) -# x = jax.random.normal(rng, (n,)) - -# gt_geom = gt_geometry(G, epsilon=eps) -# graph_geom = graph.Graph.from_graph(G, t=eps) - -# fn = jax.jit(callback) if jit else callback - -# gt_out = fn(gt_geom) -# graph_out = fn(graph_geom) - -# assert gt_out.converged -# assert graph_out.converged -# np.testing.assert_allclose( -# graph_out.reg_ot_cost, gt_out.reg_ot_cost, rtol=tol, atol=tol -# ) -# np.testing.assert_allclose(graph_out.f, gt_out.f, rtol=tol, atol=tol) -# np.testing.assert_allclose(graph_out.g, gt_out.g, rtol=tol, atol=tol) - -# for axis in [0, 1]: -# y_gt = gt_out.apply(x, axis=axis) -# y_out = graph_out.apply(x, axis=axis) -# # note the high tolerance -# np.testing.assert_allclose(y_gt, y_out, rtol=5e-1, atol=5e-1) - -# np.testing.assert_allclose( -# gt_out.matrix, graph_out.matrix, rtol=1e-1, atol=1e-1 -# ) - -# @pytest.mark.parametrize( -# "implicit_diff", -# [False, True], -# ids=["not-implicit", "implicit"], -# ) -# def test_dense_graph_differentiability( -# self, rng: jax.random.PRNGKeyArray, implicit_diff: bool -# ): - -# def callback( -# data: jnp.ndarray, rows: jnp.ndarray, cols: jnp.ndarray, -# shape: Tuple[int, int] -# ) -> float: -# G = sparse.BCOO((data, jnp.c_[rows, cols]), shape=shape).todense() - -# geom = graph.Graph.from_graph(G, t=1.) -# solver = sinkhorn.Sinkhorn(lse_mode=False, **kwargs) -# problem = linear_problem.LinearProblem(geom) - -# return solver(problem).reg_ot_cost - -# if implicit_diff: -# kwargs = {"implicit_diff": implicit_lib.ImplicitDiff()} -# else: -# kwargs = {"implicit_diff": None} - -# eps = 1e-3 -# G = random_graph(20, p=0.5) -# G = sparse.BCOO.fromdense(G) - -# w, rows, cols = G.data, G.indices[:, 0], G.indices[:, 1] -# v_w = jax.random.normal(rng, shape=w.shape) -# v_w = (v_w / jnp.linalg.norm(v_w, axis=-1, keepdims=True)) * eps - -# grad_w = jax.grad(callback)(w, rows, cols, shape=G.shape) - -# expected = callback(w + v_w, rows, cols, -# G.shape) - callback(w - v_w, rows, cols, G.shape) -# actual = 2 * jnp.vdot(v_w, grad_w) -# np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4) - -# def test_tolerance_hilbert_metric(self, rng: jax.random.PRNGKeyArray): -# n, n_steps, t, tol = 256, 1000, 1e-4, 3e-4 -# G = random_graph(n, p=0.15) -# x = jnp.abs(jax.random.normal(rng, (n,))) - -# graph_no_tol = graph.Graph.from_graph(G, t=t, n_steps=n_steps, tol=-1) -# graph_low_tol = graph.Graph.from_graph(G, t=t, -# n_steps=n_steps, tol=2.5e-4) -# graph_high_tol = graph.Graph.from_graph(G, t=t, n_steps=n_steps, tol=1e-1) - -# app_no_tol = graph_no_tol.apply_kernel(x) -# app_low_tol = graph_low_tol.apply_kernel(x) # does 1 iteration -# app_high_tol = graph_high_tol.apply_kernel(x) # does 961 iterations - -# np.testing.assert_allclose(app_no_tol, app_low_tol, rtol=tol, atol=tol) -# np.testing.assert_allclose(app_no_tol, app_high_tol, rtol=5e-2, atol=5e-2) -# with pytest.raises(AssertionError): -# np.testing.assert_allclose(app_no_tol, app_high_tol, rtol=tol, atol=tol) + gt_geom = gt_geometry(G, epsilon=eps) + graph_geom = geodesic.Geodesic.from_graph(G, t=eps) + + fn = jax.jit(callback) if jit else callback + gt_out = fn(gt_geom) + graph_out = fn(graph_geom) + + assert gt_out.converged + assert graph_out.converged + np.testing.assert_allclose( + graph_out.reg_ot_cost, gt_out.reg_ot_cost, rtol=tol, atol=tol + ) + np.testing.assert_allclose(graph_out.f, gt_out.f, rtol=tol, atol=tol) + np.testing.assert_allclose(graph_out.g, gt_out.g, rtol=tol, atol=tol) + + for axis in [0, 1]: + y_gt = gt_out.apply(x, axis=axis) + y_out = graph_out.apply(x, axis=axis) + # note the high tolerance + np.testing.assert_allclose(y_gt, y_out, rtol=5e-1, atol=5e-1) + + np.testing.assert_allclose( + gt_out.matrix, graph_out.matrix, rtol=1e-1, atol=1e-1 + ) From acda12bc6e187809c91a4a89f4b4e3476824f39f Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Mon, 13 Nov 2023 12:13:42 +0100 Subject: [PATCH 31/44] default and type of Cheb. co. & docstrings --- src/ott/geometry/geodesic.py | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index ae8b82eac..3a4442ea3 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple import jax import jax.experimental.sparse as jesp @@ -27,28 +27,20 @@ __all__ = ["Geodesic"] -# TODO: -# - Finalize the docstrings. -# - Add tests. -# - Verify sparse graph + cholesky. - -# Previous meetings todos: -# 1) wrap all scipy and numpy (done) -# 2) move all the comp in the init, just call it once (done) -# 3) make sure it works with sparse graph + cholesky -# 4) differentiablity , graph geo uses cholesky (triangle solve). - @jax.tree_util.register_pytree_node_class class Geodesic(geometry.Geometry): r"""Graph distance approximation using heat kernel :cite:`huguet:2022`. - Approximates the heat-geodesic kernel using thee Chebyshev polynomials of the + Approximates the heat kernel using Chebyshev polynomials of the first kind of max order ``order``, which for small ``t`` approximates the geodesic exponential kernel :math:`e^{\frac{-d(x, y)^2}{t}}`. Args: laplacian: Symmetric graph Laplacian. + scaled_laplacian: The Laplacian scaled by the largest eigenvalue. + eigval: Largest eigenvalue of the Laplacian. + chebyshev_coeffs: Coefficients of the Chebyshev polynomials. t: Time parameter for heat kernel. order: Max order of Chebyshev polynomial. kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`. @@ -59,25 +51,25 @@ def __init__( laplacian: Array_g, scaled_laplacian: Array_g, eigval: jnp.ndarray, + chebyshev_coeffs: jnp.ndarray, t: float = 1e-3, order: int = 100, - chebyshev_coeffs: Optional[List[float]] = None, **kwargs: Any ): super().__init__(epsilon=1., **kwargs) self.laplacian = laplacian self.scaled_laplacian = scaled_laplacian self.eigval = eigval + self.chebyshev_coeffs = chebyshev_coeffs self.t = t self.order = order - self.chebyshev_coeffs = chebyshev_coeffs @classmethod def from_graph( cls, G: Array_g, t: Optional[float] = 1e-3, - eigval: Optional[jnp.ndarray] = None, # Largest eigenvalue of Laplacian + eigval: Optional[jnp.ndarray] = None, order: int = 100, directed: bool = False, normalize: bool = False, @@ -94,7 +86,7 @@ def from_graph( eigval: Largest eigenvalue of the Laplacian. If `None`, it's computed at initialization. order: Max order of Chebyshev polynomial. - directed: Whether the ``graph`` is directed. If not, it's made + directed: Whether the ``graph`` is directed. If `True`, it's made undirected as :math:`G + G^T`. This parameter is ignored when passing the Laplacian directly, assumed to be symmetric. normalize: Whether to normalize the Laplacian as @@ -135,9 +127,9 @@ def from_graph( laplacian=laplacian, scaled_laplacian=scaled_laplacian, eigval=eigval, + chebyshev_coeffs=chebyshev_coeffs, t=t, order=order, - chebyshev_coeffs=chebyshev_coeffs, **kwargs ) @@ -215,8 +207,12 @@ def marginal_from_potentials( def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [ - self.laplacian, self.scaled_laplacian, self.eigval, self.t, self.order, - self.chebyshev_coeffs + self.laplacian, + self.scaled_laplacian, + self.eigval, + self.chebyshev_coeffs, + self.t, + self.order, ], {} @classmethod From 8bbb19eafc41c6cd9fc8d421414c625348d01bfa Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Mon, 13 Nov 2023 12:28:53 +0100 Subject: [PATCH 32/44] dtype in purecallback depending on Lap --- src/ott/geometry/geodesic.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 3a4442ea3..c48821818 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -121,7 +121,9 @@ def from_graph( t = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2 # Compute the coeffs of the Chebyshev pols approx using Bessel functs. - chebyshev_coeffs = compute_chebychev_coeff_all(eigval, t, order) + chebyshev_coeffs = compute_chebychev_coeff_all( + eigval, t, order, laplacian.dtype + ) return cls( laplacian=laplacian, @@ -252,11 +254,15 @@ def rescale_laplacian( laplacian_matrix) -def _scipy_compute_chebychev_coeff_all(phi, tau, K): +def _scipy_compute_chebychev_coeff_all(phi, tau, K, dtype=jnp.float32): """Compute the K+1 Chebychev coefficients for our functions.""" coeff = 2 * ive(np.arange(0, K + 1), -tau * phi) - if coeff.dtype == np.float64: + if dtype == jnp.float32 and coeff.dtype != np.float32: coeff = np.float32(coeff) + elif dtype == jnp.float64 and coeff.dtype != np.float64: + coeff = np.float64(coeff) + else: + raise ValueError("Invalid dtype.") return coeff @@ -279,20 +285,15 @@ def body(carry, c): return Y -def compute_chebychev_coeff_all(phi, tau, K): +def compute_chebychev_coeff_all(phi, tau, K, dtype=jnp.float32): """Jax wrapper to compute the K+1 Chebychev coefficients.""" - if hasattr(phi, "dtype") and phi.dtype == jnp.float64: - _type = jnp.float64 - else: - _type = jnp.float32 - result_shape_dtype = jax.ShapeDtypeStruct( shape=(K + 1,), - dtype=_type, + dtype=dtype, ) chebychev_coeff = lambda phi, tau, K: _scipy_compute_chebychev_coeff_all( - phi, tau, K - ).astype(_type) + phi, tau, K, dtype=dtype + ).astype(dtype) return jax.pure_callback(chebychev_coeff, result_shape_dtype, phi, tau, K) From d6408213a34af4c51a6a23be9161d8a918bb62cd Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Thu, 16 Nov 2023 18:07:40 +0100 Subject: [PATCH 33/44] rm laplacian & update test --- src/ott/geometry/geodesic.py | 9 ++------- tests/geometry/geo_test.py | 9 +++++++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index c48821818..8924eaf7c 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -37,7 +37,6 @@ class Geodesic(geometry.Geometry): geodesic exponential kernel :math:`e^{\frac{-d(x, y)^2}{t}}`. Args: - laplacian: Symmetric graph Laplacian. scaled_laplacian: The Laplacian scaled by the largest eigenvalue. eigval: Largest eigenvalue of the Laplacian. chebyshev_coeffs: Coefficients of the Chebyshev polynomials. @@ -48,7 +47,6 @@ class Geodesic(geometry.Geometry): def __init__( self, - laplacian: Array_g, scaled_laplacian: Array_g, eigval: jnp.ndarray, chebyshev_coeffs: jnp.ndarray, @@ -57,7 +55,6 @@ def __init__( **kwargs: Any ): super().__init__(epsilon=1., **kwargs) - self.laplacian = laplacian self.scaled_laplacian = scaled_laplacian self.eigval = eigval self.chebyshev_coeffs = chebyshev_coeffs @@ -126,7 +123,6 @@ def from_graph( ) return cls( - laplacian=laplacian, scaled_laplacian=scaled_laplacian, eigval=eigval, chebyshev_coeffs=chebyshev_coeffs, @@ -171,7 +167,7 @@ def cost_matrix(self) -> jnp.ndarray: # noqa: D102 @property def shape(self) -> Tuple[int, int]: # noqa: D102 - return self.laplacian.shape + return self.scaled_laplacian.shape @property def is_symmetric(self) -> bool: # noqa: D102 @@ -179,7 +175,7 @@ def is_symmetric(self) -> bool: # noqa: D102 @property def dtype(self) -> jnp.dtype: # noqa: D102 - return self.laplacian.dtype + return self.scaled_laplacian.dtype def transport_from_potentials( self, f: jnp.ndarray, g: jnp.ndarray @@ -209,7 +205,6 @@ def marginal_from_potentials( def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [ - self.laplacian, self.scaled_laplacian, self.eigval, self.chebyshev_coeffs, diff --git a/tests/geometry/geo_test.py b/tests/geometry/geo_test.py index 2f205e26d..c8d505489 100644 --- a/tests/geometry/geo_test.py +++ b/tests/geometry/geo_test.py @@ -167,7 +167,7 @@ def create_graph(G: jnp.ndarray) -> graph.Graph: with pytest.raises(AssertionError): np.testing.assert_allclose(G, G.T) - L = geom.laplacian + L = geom.scaled_laplacian with pytest.raises(AssertionError): # make sure that original graph was directed @@ -197,7 +197,12 @@ def laplacian(G: jnp.ndarray) -> jnp.ndarray: ) expected = laplacian(G) - actual = geom.laplacian + eigenvalues = jnp.linalg.eigvals(expected) + eigval = jnp.max(eigenvalues) + #rescale the laplacian + expected = 2 * expected / eigval if eigval > 2 else expected + + actual = geom.scaled_laplacian np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) From bab7d9428c0fa6b0b56f15e8a1be6140f1c46952 Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Fri, 17 Nov 2023 10:01:22 +0100 Subject: [PATCH 34/44] simpler wrapper for `ive` --- src/ott/geometry/geodesic.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 8924eaf7c..4d0efe2b5 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -249,18 +249,6 @@ def rescale_laplacian( laplacian_matrix) -def _scipy_compute_chebychev_coeff_all(phi, tau, K, dtype=jnp.float32): - """Compute the K+1 Chebychev coefficients for our functions.""" - coeff = 2 * ive(np.arange(0, K + 1), -tau * phi) - if dtype == jnp.float32 and coeff.dtype != np.float32: - coeff = np.float32(coeff) - elif dtype == jnp.float64 and coeff.dtype != np.float64: - coeff = np.float64(coeff) - else: - raise ValueError("Invalid dtype.") - return coeff - - def expm_multiply(L, X, coeff, phi): def body(carry, c): @@ -287,8 +275,8 @@ def compute_chebychev_coeff_all(phi, tau, K, dtype=jnp.float32): dtype=dtype, ) - chebychev_coeff = lambda phi, tau, K: _scipy_compute_chebychev_coeff_all( - phi, tau, K, dtype=dtype + chebychev_coeff = lambda phi, tau, K: ( + 2 * ive(np.arange(0, K + 1), -tau * phi) ).astype(dtype) return jax.pure_callback(chebychev_coeff, result_shape_dtype, phi, tau, K) From b296da49d25402ae24c7eddfba119311d03c3872 Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Fri, 17 Nov 2023 10:53:44 +0100 Subject: [PATCH 35/44] lint-docs spell check --- docs/spelling/technical.txt | 3 +++ src/ott/geometry/geodesic.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/spelling/technical.txt b/docs/spelling/technical.txt index 4030a8292..63a7851fd 100644 --- a/docs/spelling/technical.txt +++ b/docs/spelling/technical.txt @@ -1,6 +1,7 @@ Barycenters Brenier Bures +Chebyshev Cholesky DTW Danskin @@ -24,6 +25,7 @@ SGD Schur Seidel Sinkhorn +UNet Unbalancedness Wasserstein adaptively @@ -93,6 +95,7 @@ parameterization parameterizing piecewise pluripotent +polynomials positivity postfix potentials diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 4d0efe2b5..679f6c868 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -82,7 +82,7 @@ def from_graph( must be specified and the edge weights are assumed to be positive. eigval: Largest eigenvalue of the Laplacian. If `None`, it's computed at initialization. - order: Max order of Chebyshev polynomial. + order: Max order of Chebyshev polynomials. directed: Whether the ``graph`` is directed. If `True`, it's made undirected as :math:`G + G^T`. This parameter is ignored when passing the Laplacian directly, assumed to be symmetric. From b7d1df3693bab182538ac6058d22d1c72d269356 Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Mon, 20 Nov 2023 14:27:07 +0100 Subject: [PATCH 36/44] fix mistake rm setup --- setup.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 setup.py diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..9ae2b1026 --- /dev/null +++ b/setup.py @@ -0,0 +1,17 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from setuptools import setup + +# for packaging tools not supporting, e.g., PEP 517, PEP 660 +setup() From 999b226f22642e27c2a50830c4f902153e1e7175 Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Mon, 27 Nov 2023 18:10:18 +0100 Subject: [PATCH 37/44] typo & rm fn scale lap & --- src/ott/geometry/geodesic.py | 48 +++++++++++++++++------------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 679f6c868..aed0c20b3 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -20,7 +20,6 @@ import numpy as np from scipy.special import ive -from ott import utils from ott.geometry import geometry from ott.math import utils as mu from ott.types import Array_g @@ -32,15 +31,20 @@ class Geodesic(geometry.Geometry): r"""Graph distance approximation using heat kernel :cite:`huguet:2022`. - Approximates the heat kernel using Chebyshev polynomials of the - first kind of max order ``order``, which for small ``t`` approximates the + important:: + This constructor is not meant to be called by the user, + please use the :meth:`from_graph` method instead. + + Approximates the heat kernel using `Chebyshev polynomials + `_ of the first kind of + max order ``order``, which for small ``t`` approximates the geodesic exponential kernel :math:`e^{\frac{-d(x, y)^2}{t}}`. Args: scaled_laplacian: The Laplacian scaled by the largest eigenvalue. - eigval: Largest eigenvalue of the Laplacian. + eigval: The largest eigenvalue of the Laplacian. chebyshev_coeffs: Coefficients of the Chebyshev polynomials. - t: Time parameter for heat kernel. + t: Time parameter for the heat kernel. order: Max order of Chebyshev polynomial. kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`. """ @@ -90,7 +94,7 @@ def from_graph( :math:`L^{sym} = \left(D^+\right)^{\frac{1}{2}} L \left(D^+\right)^{\frac{1}{2}}`, where :math:`L` is the non-normalized Laplacian and :math:`D` is the degree matrix. - kwargs: Keyword arguments for the Geodesic class. + kwargs: Keyword arguments for :class:`~ott.geometry.geodesic.Geodesic`. Returns: The Geodesic geometry. @@ -112,10 +116,12 @@ def from_graph( eigval = compute_largest_eigenvalue( laplacian, k=1 ) if eigval is None else eigval - scaled_laplacian = rescale_laplacian(laplacian, eigval) + + scaled_laplacian = jax.lax.cond((eigval > 2.0), lambda l: 2.0 * l / eigval, + lambda l: l, laplacian) if t is None: - t = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2 + t = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2.0 # Compute the coeffs of the Chebyshev pols approx using Bessel functs. chebyshev_coeffs = compute_chebychev_coeff_all( @@ -163,7 +169,7 @@ def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 @property def cost_matrix(self) -> jnp.ndarray: # noqa: D102 # Calculate the cost matrix using the formula (5) from the main reference - return -4 * self.t * mu.safe_log(self.kernel_matrix) + return -4.0 * self.t * mu.safe_log(self.kernel_matrix) @property def shape(self) -> Tuple[int, int]: # noqa: D102 @@ -219,13 +225,13 @@ def tree_unflatten( # noqa: D102 return cls(*children, **aux_data) -def compute_largest_eigenvalue(laplacian_matrix, k, rng=None): +def compute_largest_eigenvalue( + laplacian_matrix: jnp.ndarray, rng: Optional[jax.Array] = None +) -> float: # Compute the largest eigenvalue of the Laplacian matrix. - if rng is None: - rng = utils.default_prng_key(rng) n = laplacian_matrix.shape[0] # Generate random initial directions for eigenvalue computation - initial_dirs = jax.random.normal(rng, (n, k)) + initial_dirs = jax.random.normal(rng, (n, 1)) # Create a sparse matrix-vector product function using sparsify # This function multiplies the sparse laplacian_matrix with a vector @@ -235,31 +241,21 @@ def compute_largest_eigenvalue(laplacian_matrix, k, rng=None): eigvals, _, _ = jesp.linalg.lobpcg_standard( lapl_vector_product, initial_dirs, - m=100, ) return jnp.max(eigvals) -def rescale_laplacian( - laplacian_matrix: jnp.ndarray, largest_eigenvalue: jnp.ndarray -) -> jnp.ndarray: - # Rescale the Laplacian matrix. - return jax.lax.cond((largest_eigenvalue > 2), - lambda l: 2 * l / largest_eigenvalue, lambda l: l, - laplacian_matrix) - - def expm_multiply(L, X, coeff, phi): def body(carry, c): T0, T1, Y = carry - T2 = (2 / phi) * L @ T1 - 2 * T1 - T0 + T2 = (2.0 / phi) * L @ T1 - 2.0 * T1 - T0 Y = Y + c * T2 return (T1, T2, Y), None T0 = X Y = 0.5 * coeff[0] * T0 - T1 = (1 / phi) * L @ X - T0 + T1 = (1.0 / phi) * L @ X - T0 Y = Y + coeff[1] * T1 initial_state = (T0, T1, Y) @@ -276,7 +272,7 @@ def compute_chebychev_coeff_all(phi, tau, K, dtype=jnp.float32): ) chebychev_coeff = lambda phi, tau, K: ( - 2 * ive(np.arange(0, K + 1), -tau * phi) + 2.0 * ive(np.arange(0, K + 1), -tau * phi) ).astype(dtype) return jax.pure_callback(chebychev_coeff, result_shape_dtype, phi, tau, K) From e19cf0183237979a07cb74acfba94fc0cd374fe6 Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Mon, 27 Nov 2023 19:06:33 +0100 Subject: [PATCH 38/44] rm t order init & typing --- src/ott/geometry/geodesic.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index aed0c20b3..7b9f6832b 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -55,15 +55,12 @@ def __init__( eigval: jnp.ndarray, chebyshev_coeffs: jnp.ndarray, t: float = 1e-3, - order: int = 100, **kwargs: Any ): super().__init__(epsilon=1., **kwargs) self.scaled_laplacian = scaled_laplacian self.eigval = eigval self.chebyshev_coeffs = chebyshev_coeffs - self.t = t - self.order = order @classmethod def from_graph( @@ -132,8 +129,6 @@ def from_graph( scaled_laplacian=scaled_laplacian, eigval=eigval, chebyshev_coeffs=chebyshev_coeffs, - t=t, - order=order, **kwargs ) @@ -214,8 +209,6 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 self.scaled_laplacian, self.eigval, self.chebyshev_coeffs, - self.t, - self.order, ], {} @classmethod @@ -245,7 +238,9 @@ def compute_largest_eigenvalue( return jnp.max(eigvals) -def expm_multiply(L, X, coeff, phi): +def expm_multiply( + L: jnp.ndarray, X: jnp.ndarray, coeff: jnp.ndarray, phi: float +) -> jnp.ndarray: def body(carry, c): T0, T1, Y = carry @@ -259,12 +254,13 @@ def body(carry, c): Y = Y + coeff[1] * T1 initial_state = (T0, T1, Y) - carry, _ = jax.lax.scan(body, initial_state, coeff[2:]) - _, _, Y = carry + (_, _, Y), _ = jax.lax.scan(body, initial_state, coeff[2:]) return Y -def compute_chebychev_coeff_all(phi, tau, K, dtype=jnp.float32): +def compute_chebychev_coeff_all( + phi: float, tau: float, K: int, dtype: np.dtype +) -> jnp.ndarray: """Jax wrapper to compute the K+1 Chebychev coefficients.""" result_shape_dtype = jax.ShapeDtypeStruct( shape=(K + 1,), From 298c544f7410fd3cebed60bd6142719c0cf79827 Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Mon, 27 Nov 2023 19:12:16 +0100 Subject: [PATCH 39/44] use rng util --- src/ott/geometry/geodesic.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 7b9f6832b..2ce08cd0a 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -23,6 +23,7 @@ from ott.geometry import geometry from ott.math import utils as mu from ott.types import Array_g +from ott.utils import default_prng_key __all__ = ["Geodesic"] @@ -110,9 +111,7 @@ def from_graph( ) laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg - eigval = compute_largest_eigenvalue( - laplacian, k=1 - ) if eigval is None else eigval + eigval = compute_largest_eigenvalue(laplacian) if eigval is None else eigval scaled_laplacian = jax.lax.cond((eigval > 2.0), lambda l: 2.0 * l / eigval, lambda l: l, laplacian) @@ -224,7 +223,7 @@ def compute_largest_eigenvalue( # Compute the largest eigenvalue of the Laplacian matrix. n = laplacian_matrix.shape[0] # Generate random initial directions for eigenvalue computation - initial_dirs = jax.random.normal(rng, (n, 1)) + initial_dirs = jax.random.normal(default_prng_key(rng), (n, 1)) # Create a sparse matrix-vector product function using sparsify # This function multiplies the sparse laplacian_matrix with a vector From 7d13198ad2adcc7b85005fa7b83e3bd7a9166d6b Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Mon, 27 Nov 2023 21:37:45 +0100 Subject: [PATCH 40/44] t for cost & test with ground truth --- docs/references.bib | 2 +- src/ott/geometry/geodesic.py | 3 +++ tests/geometry/geo_test.py | 47 +++++++----------------------------- 3 files changed, 13 insertions(+), 39 deletions(-) diff --git a/docs/references.bib b/docs/references.bib index a208c2db4..3060bf191 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -806,7 +806,7 @@ @misc{klein:23 year = {2023}, } -@article{huguet:2022, +@misc{huguet:2022, title={Geodesic Sinkhorn: optimal transport for high-dimensional datasets}, author={Huguet, Guillaume and Tong, Alexander and Zapatero, Mar{\'\i}a Ramos and Wolf, Guy and Krishnaswamy, Smita}, journal={arXiv preprint arXiv:2211.00805}, diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 2ce08cd0a..a64790539 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -62,6 +62,7 @@ def __init__( self.scaled_laplacian = scaled_laplacian self.eigval = eigval self.chebyshev_coeffs = chebyshev_coeffs + self.t = t @classmethod def from_graph( @@ -128,6 +129,7 @@ def from_graph( scaled_laplacian=scaled_laplacian, eigval=eigval, chebyshev_coeffs=chebyshev_coeffs, + t=t, **kwargs ) @@ -208,6 +210,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 self.scaled_laplacian, self.eigval, self.chebyshev_coeffs, + self.t, ], {} @classmethod diff --git a/tests/geometry/geo_test.py b/tests/geometry/geo_test.py index c8d505489..49655be33 100644 --- a/tests/geometry/geo_test.py +++ b/tests/geometry/geo_test.py @@ -69,15 +69,6 @@ def gt_geometry(G: jnp.ndarray, *, epsilon: float = 1e-2) -> geometry.Geometry: class TestGeodesic: - def test_init(self): - n, order = 10, 100 - t = 10 - G = random_graph(n, p=0.5) - geom = geodesic.Geodesic.from_graph(G, t=t, order=order) - - np.testing.assert_equal(geom.order, order) - np.testing.assert_equal(geom.t, t) - def test_kernel_is_symmetric_positive_definite( self, rng: jax.random.PRNGKeyArray ): @@ -115,44 +106,24 @@ def test_kernel_is_symmetric_positive_definite( # and all dissimilarities are positive np.testing.assert_array_less(0, cost_matrix) - def test_automatic_t(self): - G = random_graph(38, return_laplacian=False) - geom = geodesic.Geodesic.from_graph(G, t=None) - - expected = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2 - actual = geom.t - np.testing.assert_equal(actual, expected) - @pytest.mark.fast.with_args( - n_steps=[50, 100, 200], + order=[50, 100, 200], t=[1e-4, 1e-5], - only_fast=0, ) - def cheb_be_cn(self, t: Optional[float], n_steps: int): - tol = 5 * t + def test_approximates_ground_truth(self, t: Optional[float], order: int): + tol = 1e-2 G = nx.linalg.adjacency_matrix(balanced_tree(r=2, h=5)) G = jnp.asarray(G.toarray(), dtype=float) eye = jnp.eye(G.shape[0]) - - be_geom = graph.Graph.from_graph( - G, t=t, n_steps=n_steps, numerical_scheme="backward_euler" - ) - cn_geom = graph.Graph.from_graph( - G, t=t, n_steps=n_steps, numerical_scheme="crank_nicolson" - ) - geo = geodesic.Geodesic.from_graph(G, t=t, order=n_steps) eps = jnp.finfo(eye.dtype).tiny - be_cost = -t * jnp.log(be_geom.apply_kernel(eye) + eps) - cn_cost = -t * jnp.log(cn_geom.apply_kernel(eye) + eps) - cheb_cost = -t * jnp.log(geo.apply_kernel(eye) + eps) + gt_geom = gt_geometry(G, epsilon=eps) - np.testing.assert_allclose(cheb_cost, cheb_cost.T, rtol=tol, atol=tol) - # check that it is close to the BE CN - np.testing.assert_allclose(be_cost, cheb_cost, rtol=tol, atol=tol) - np.testing.assert_allclose(cn_cost, cheb_cost, rtol=tol, atol=tol) - with pytest.raises(AssertionError): - np.testing.assert_allclose(be_cost, be_cost.T, rtol=tol, atol=tol) + geo = geodesic.Geodesic.from_graph(G, t=t, order=order) + + np.testing.assert_allclose( + gt_geom.kernel_matrix, geo.kernel_matrix, rtol=tol, atol=tol + ) @pytest.mark.parametrize(("jit", "normalize"), [(False, True), (True, False)]) def test_directed_graph(self, jit: bool, normalize: bool): From 620245d1d4d88122d020fce08a46de1e5e67be1f Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Tue, 28 Nov 2023 10:13:36 +0100 Subject: [PATCH 41/44] rm order & change `phi` to `eigval` & type hints --- src/ott/geometry/geodesic.py | 16 +++++++--------- tests/geometry/geo_test.py | 15 ++++++++++----- tests/geometry/graph_test.py | 10 +++++++--- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index a64790539..20e5cb11d 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Any, Dict, Optional, Sequence, Tuple import jax @@ -46,7 +45,6 @@ class Geodesic(geometry.Geometry): eigval: The largest eigenvalue of the Laplacian. chebyshev_coeffs: Coefficients of the Chebyshev polynomials. t: Time parameter for the heat kernel. - order: Max order of Chebyshev polynomial. kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`. """ @@ -241,18 +239,18 @@ def compute_largest_eigenvalue( def expm_multiply( - L: jnp.ndarray, X: jnp.ndarray, coeff: jnp.ndarray, phi: float + L: jnp.ndarray, X: jnp.ndarray, coeff: jnp.ndarray, eigval: float ) -> jnp.ndarray: def body(carry, c): T0, T1, Y = carry - T2 = (2.0 / phi) * L @ T1 - 2.0 * T1 - T0 + T2 = (2.0 / eigval) * L @ T1 - 2.0 * T1 - T0 Y = Y + c * T2 return (T1, T2, Y), None T0 = X Y = 0.5 * coeff[0] * T0 - T1 = (1.0 / phi) * L @ X - T0 + T1 = (1.0 / eigval) * L @ X - T0 Y = Y + coeff[1] * T1 initial_state = (T0, T1, Y) @@ -261,7 +259,7 @@ def body(carry, c): def compute_chebychev_coeff_all( - phi: float, tau: float, K: int, dtype: np.dtype + eigval: float, tau: float, K: int, dtype: np.dtype ) -> jnp.ndarray: """Jax wrapper to compute the K+1 Chebychev coefficients.""" result_shape_dtype = jax.ShapeDtypeStruct( @@ -269,8 +267,8 @@ def compute_chebychev_coeff_all( dtype=dtype, ) - chebychev_coeff = lambda phi, tau, K: ( - 2.0 * ive(np.arange(0, K + 1), -tau * phi) + chebychev_coeff = lambda eigval, tau, K: ( + 2.0 * ive(np.arange(0, K + 1), -tau * eigval) ).astype(dtype) - return jax.pure_callback(chebychev_coeff, result_shape_dtype, phi, tau, K) + return jax.pure_callback(chebychev_coeff, result_shape_dtype, eigval, tau, K) diff --git a/tests/geometry/geo_test.py b/tests/geometry/geo_test.py index 49655be33..a2eb00105 100644 --- a/tests/geometry/geo_test.py +++ b/tests/geometry/geo_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Union import jax import jax.numpy as jnp @@ -48,12 +48,16 @@ def random_graph( return jnp.asarray(G.toarray()) -def gt_geometry(G: jnp.ndarray, *, epsilon: float = 1e-2) -> geometry.Geometry: +def gt_geometry( + G: Union[jnp.ndarray, nx.Graph], + *, + epsilon: float = 1e-2 +) -> geometry.Geometry: if not isinstance(G, nx.Graph): G = nx.from_numpy_array(np.asarray(G)) n = len(G) - cost = np.zeros((n, n), dtype=float) + cost = np.zeros((n, n)) path = dict( shortest_paths.all_pairs_bellman_ford_path_length(G, weight="weight") @@ -70,7 +74,8 @@ def gt_geometry(G: jnp.ndarray, *, epsilon: float = 1e-2) -> geometry.Geometry: class TestGeodesic: def test_kernel_is_symmetric_positive_definite( - self, rng: jax.random.PRNGKeyArray + self, + rng: jax.Array, ): n, tol = 100, 0.02 t = 1 @@ -178,7 +183,7 @@ def laplacian(G: jnp.ndarray) -> jnp.ndarray: np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) @pytest.mark.fast.with_args(jit=[False, True], only_fast=0) - def test_graph_sinkhorn(self, rng: jax.random.PRNGKeyArray, jit: bool): + def test_geo_sinkhorn(self, rng: jax.Array, jit: bool): def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: solver = sinkhorn.Sinkhorn(lse_mode=False) diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py index 18e683bb4..7ba5b2df7 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal, Optional, Tuple +from typing import Literal, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -50,12 +50,16 @@ def random_graph( return jnp.asarray(G.toarray()) -def gt_geometry(G: jnp.ndarray, *, epsilon: float = 1e-2) -> geometry.Geometry: +def gt_geometry( + G: Union[jnp.ndarray, nx.Graph], + *, + epsilon: float = 1e-2 +) -> geometry.Geometry: if not isinstance(G, nx.Graph): G = nx.from_numpy_array(np.asarray(G)) n = len(G) - cost = np.zeros((n, n), dtype=float) + cost = np.zeros((n, n)) path = dict( shortest_paths.all_pairs_bellman_ford_path_length(G, weight="weight") From 3a93b22b46a01b88d58087a60de948f1b8ce9a1b Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Tue, 28 Nov 2023 13:12:58 +0100 Subject: [PATCH 42/44] differentiability test --- tests/geometry/geo_test.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/geometry/geo_test.py b/tests/geometry/geo_test.py index a2eb00105..21976f34c 100644 --- a/tests/geometry/geo_test.py +++ b/tests/geometry/geo_test.py @@ -218,3 +218,27 @@ def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: np.testing.assert_allclose( gt_out.matrix, graph_out.matrix, rtol=1e-1, atol=1e-1 ) + + def test_geometry_differentiability(self, rng: jax.Array): + + def callback(geom) -> float: + + solver = sinkhorn.Sinkhorn(lse_mode=False) + problem = linear_problem.LinearProblem(geom) + + return solver(problem).reg_ot_cost + + eps = 1e-3 + G = random_graph(20, p=0.5) + geom = geodesic.Geodesic.from_graph(G, t=1.) + + v_w = jax.random.normal(rng, shape=G.shape) + v_w = (v_w / jnp.linalg.norm(v_w, axis=-1, keepdims=True)) * eps + + grad_sl = jax.grad(callback)(geom).scaled_laplacian + geom__finite_right = geodesic.Geodesic.from_graph(G + v_w, t=1.) + geom__finite_left = geodesic.Geodesic.from_graph(G - v_w, t=1.) + + expected = callback(geom__finite_right) - callback(geom__finite_left) + actual = 2 * jnp.vdot(v_w, grad_sl) + np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4) From 3e5ef045f056eb6d5a96a25a5931ddd314fa61c9 Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Tue, 28 Nov 2023 13:39:16 +0100 Subject: [PATCH 43/44] condition to symmetrize kernel --- src/ott/geometry/geodesic.py | 8 ++++---- src/ott/geometry/graph.py | 7 +++++-- tests/geometry/geo_test.py | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 20e5cb11d..9d61af63c 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -155,10 +155,10 @@ def apply_kernel( def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 n, _ = self.shape kernel = self.apply_kernel(jnp.eye(n)) - # check if the kernel is symmetric - if jnp.any(kernel != kernel.T): - kernel = (kernel + kernel.T) / 2.0 - return kernel + return jax.lax.cond( + jnp.allclose(kernel, kernel.T, atol=1e-8, rtol=1e-8), lambda x: x, + lambda x: (x + x.T) / 2.0, kernel + ) @property def cost_matrix(self) -> jnp.ndarray: # noqa: D102 diff --git a/src/ott/geometry/graph.py b/src/ott/geometry/graph.py index c7dac0c99..5109929f0 100644 --- a/src/ott/geometry/graph.py +++ b/src/ott/geometry/graph.py @@ -189,9 +189,12 @@ def body_fn( def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 n, _ = self.shape kernel = self.apply_kernel(jnp.eye(n)) - # force symmetry because of numerical imprecision + # Symmetrize the kernel if needed. Numerical imprecision # happens when `numerical_scheme='backward_euler'` and small `t` - return (kernel + kernel.T) * 0.5 + return jax.lax.cond( + jnp.allclose(kernel, kernel.T, atol=1e-8, rtol=1e-8), lambda x: x, + lambda x: (x + x.T) / 2.0, kernel + ) @property def cost_matrix(self) -> jnp.ndarray: # noqa: D102 diff --git a/tests/geometry/geo_test.py b/tests/geometry/geo_test.py index 21976f34c..c54261e3a 100644 --- a/tests/geometry/geo_test.py +++ b/tests/geometry/geo_test.py @@ -93,7 +93,7 @@ def test_kernel_is_symmetric_positive_definite( # we symmetrize the kernel explicitly when materializing it, because # numerical errors can make it non-symmetric. - np.testing.assert_array_equal(kernel, kernel.T) + np.testing.assert_allclose(kernel, kernel.T, rtol=tol, atol=tol) eigenvalues = jnp.linalg.eigvals(kernel) neg_eigenvalues = eigenvalues[eigenvalues < 0] # check that the negative eigenvalues are all very small From cf17ce4d66de44e687e8493f6bec1805bbb3543c Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Tue, 28 Nov 2023 15:12:37 +0100 Subject: [PATCH 44/44] fix indentation --- src/ott/geometry/geodesic.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 9d61af63c..f07cce777 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -31,14 +31,14 @@ class Geodesic(geometry.Geometry): r"""Graph distance approximation using heat kernel :cite:`huguet:2022`. - important:: - This constructor is not meant to be called by the user, - please use the :meth:`from_graph` method instead. - - Approximates the heat kernel using `Chebyshev polynomials - `_ of the first kind of - max order ``order``, which for small ``t`` approximates the - geodesic exponential kernel :math:`e^{\frac{-d(x, y)^2}{t}}`. + .. important:: + This constructor is not meant to be called by the user, + please use the :meth:`from_graph` method instead. + + Approximates the heat kernel using + `Chebyshev polynomials `_ + of the first kind of max order ``order``, which for small ``t`` + approximates the geodesic exponential kernel :math:`e^{\frac{-d(x, y)^2}{t}}`. Args: scaled_laplacian: The Laplacian scaled by the largest eigenvalue.