Skip to content

Commit

Permalink
Merge pull request #1 from diegoabt/feature/geodesic-sinkhorn
Browse files Browse the repository at this point in the history
Create `geodesic` module
  • Loading branch information
diegoabt authored Sep 5, 2023
2 parents 937e049 + d3d840f commit 4306382
Showing 1 changed file with 321 additions and 0 deletions.
321 changes: 321 additions & 0 deletions src/ott/geometry/geodesic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,321 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Literal, Optional, Sequence, Tuple, List

import jax
import jax.numpy as jnp
from jax.experimental.sparse.linalg import lobpcg_standard

import numpy as np
from scipy.special import ive

from ott.geometry import geometry
from ott.math import utils as mu


__all__ = ["Geodesic"]


@jax.tree_util.register_pytree_node_class
class Geodesic(geometry.Geometry):
  r"""Graph distance approximation using the heat kernel :cite:`heitz:21,crane:13`.

  Approximates the heat kernel :math:`e^{-tL}` of the graph Laplacian
  :math:`L` with a Chebyshev polynomial expansion, which for small ``t``
  approximates the geodesic exponential kernel
  :math:`e^{\frac{-d(x, y)^2}{t}}`.

  Args:
    laplacian: Symmetric graph Laplacian. The check for symmetry is **NOT**
      performed. See also :meth:`from_graph`.
    t: Time parameter of the heat kernel.
    n_steps: Maximum order of the Chebyshev polynomial approximation of the
      heat kernel.
    tol: Relative tolerance with respect to the Hilbert metric, see
      :cite:`peyre:19`, Remark 4.12. Used when iteratively updating scalings.
      If negative, this option is ignored and only ``n_steps`` is used.
    numerical_scheme: Numerical scheme used by :attr:`_scale` when solving
      the heat diffusion implicitly.
    kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`.
  """

  def __init__(
      self,
      laplacian: jnp.ndarray,
      t: float = 1e-3,
      n_steps: int = 100,
      tol: float = -1.0,
      numerical_scheme: str = "backward_euler",
      **kwargs: Any
  ):
    super().__init__(epsilon=1., **kwargs)
    self.laplacian = laplacian
    self.t = t
    self.n_steps = n_steps
    self.tol = tol
    # Read by `_scale` and `tree_flatten`; previously it was never assigned,
    # so both raised an `AttributeError`.
    self.numerical_scheme = numerical_scheme

  @classmethod
  def from_graph(
      cls,
      G: jnp.ndarray,
      t: Optional[float] = 1e-3,
      directed: bool = False,
      normalize: bool = False,
      **kwargs: Any
  ) -> "Geodesic":
    r"""Construct a :class:`Geodesic` geometry from an adjacency matrix.

    Args:
      G: Adjacency matrix.
      t: Constant used when approximating the geodesic exponential kernel.
        If `None`, use :math:`\left(\frac{1}{|E|} \sum_{(u, v) \in E}
        weight(u, v)\right)^2` :cite:`crane:13`. In this case, the edge
        weights are all assumed to be positive.
      directed: Whether the graph is directed. If `True`, it is made
        undirected as :math:`G + G^T`, so the resulting Laplacian is
        symmetric.
      normalize: Whether to normalize the Laplacian as
        :math:`L^{sym} = \left(D^+\right)^{\frac{1}{2}} L
        \left(D^+\right)^{\frac{1}{2}}`, where :math:`L` is the
        non-normalized Laplacian and :math:`D` is the degree matrix.
      kwargs: Keyword arguments for :class:`Geodesic`.

    Returns:
      The geodesic geometry.
    """
    assert G.shape[0] == G.shape[1], G.shape

    if directed:
      # Symmetrize so that the Laplacian below is symmetric.
      G = G + G.T

    degree = jnp.sum(G, axis=1)
    laplacian = jnp.diag(degree) - G

    if normalize:
      # Pseudo-inverse square root of the degree matrix: isolated vertices
      # (zero degree) map to 0 instead of `inf`.
      inv_sqrt_deg = jnp.diag(
          jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0)
      )
      laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg

    if t is None:
      t = (jnp.sum(G) / jnp.sum(G > 0.)) ** 2

    return cls(laplacian, t=t, **kwargs)

  @staticmethod
  def _compute_largest_eigenvalue(
      laplacian_matrix: jnp.ndarray, k: int = 1
  ) -> jnp.ndarray:
    """Estimate the largest eigenvalue of ``laplacian_matrix`` with LOBPCG.

    Args:
      laplacian_matrix: An ``(n, n)`` symmetric Laplacian matrix.
      k: Number of extremal eigenvalues to compute.

    Returns:
      The largest eigenvalue estimate (scalar array).
    """
    n = laplacian_matrix.shape[0]
    # Fixed key: the estimate is deterministic across calls.
    initial_directions = jax.random.normal(jax.random.PRNGKey(0), (n, k))
    # NOTE: `m` in `lobpcg_standard` is the iteration budget, NOT the number
    # of eigenvalues; the previous code passed `m=k` (a single iteration),
    # which yields a very poor estimate. Use the default iteration count.
    eigvals, _, _ = lobpcg_standard(laplacian_matrix, initial_directions)
    return jnp.max(eigvals)

  @classmethod
  def _rescale_laplacian(cls, laplacian_matrix: jnp.ndarray) -> jnp.ndarray:
    """Rescale the Laplacian so that its spectrum lies in ``[0, 2]``."""
    largest_eigenvalue = cls._compute_largest_eigenvalue(laplacian_matrix)
    # NOTE(review): Python-level branch on a traced value; this prevents
    # `jax.jit` of the caller but preserves the original eager behavior.
    if largest_eigenvalue > 2:
      laplacian_matrix = 2 * laplacian_matrix / largest_eigenvalue
    return laplacian_matrix

  @staticmethod
  def _chebyshev_coefficients(t: float, max_order: int) -> List[float]:
    """Chebyshev coefficients of :math:`e^{-tx}` on ``[-1, 1]``.

    Uses the modified Bessel function identity
    :math:`e^{ax} = I_0(a) + 2\\sum_{k\\ge 1} I_k(a) T_k(x)`; ``ive`` returns
    the exponentially scaled :math:`I_k(-t)e^{-t}`, so summing the series of
    the spectrum-shifted Laplacian reproduces :math:`e^{-tL}` directly.

    Args:
      t: Time parameter.
      max_order: Maximum order of the Chebyshev polynomial approximation.

    Returns:
      Coefficients ``c_0, ..., c_max_order``; ``c_0`` must be halved when
      the series is evaluated.
    """
    return (2 * ive(np.arange(0, max_order + 1), -t)).tolist()

  @staticmethod
  def _evaluate_chebyshev_approximation(
      mat: jnp.ndarray, coeffs: List[float], x: jnp.ndarray
  ) -> jnp.ndarray:
    """Evaluate :math:`\\sum_k c_k T_k(M) x` via the three-term recurrence.

    Args:
      mat: Matrix ``M`` with spectrum in ``[-1, 1]``.
      coeffs: Chebyshev coefficients; ``coeffs[0]`` is halved here.
      x: Vector (or matrix of stacked vectors) the series is applied to.

    Returns:
      The Chebyshev series applied to ``x``.
    """
    # T_0(M) x = x; the zeroth coefficient enters the sum halved.
    t_prev = x
    result = 0.5 * coeffs[0] * t_prev
    if len(coeffs) > 1:
      # T_1(M) x = M x.
      t_curr = mat @ x
      result = result + coeffs[1] * t_curr
      # T_{k+1}(M) x = 2 M T_k(M) x - T_{k-1}(M) x.
      for c in coeffs[2:]:
        t_prev, t_curr = t_curr, 2.0 * (mat @ t_curr) - t_prev
        result = result + c * t_curr
    return result

  def apply_kernel(
      self,
      scaling: jnp.ndarray,
      eps: Optional[float] = None,
      axis: int = 0,
  ) -> jnp.ndarray:
    r"""Apply :attr:`kernel_matrix` on positive scaling vector.

    The heat kernel :math:`e^{-tL}` is applied through a degree-``n_steps``
    Chebyshev polynomial approximation, so the dense kernel matrix is never
    materialized.

    Args:
      scaling: Scaling to apply the kernel to.
      eps: passed for consistency, not used yet.
      axis: passed for consistency, not used yet.

    Returns:
      Kernel applied to ``scaling``.
    """
    rescaled_laplacian = self._rescale_laplacian(self.laplacian)
    # The rescaled spectrum lies in [0, 2]; shifting by -I maps it to
    # [-1, 1], the natural domain of Chebyshev polynomials.
    n = rescaled_laplacian.shape[0]
    shifted_laplacian = rescaled_laplacian - jnp.eye(n)
    chebyshev_coeffs = self._chebyshev_coefficients(self.t, self.n_steps)
    # NOTE: the previous implementation called `self.apply_kernel` here
    # recursively, which never terminated.
    return self._evaluate_chebyshev_approximation(
        shifted_laplacian, chebyshev_coeffs, scaling
    )

  @property
  def kernel_matrix(self) -> jnp.ndarray:  # noqa: D102
    n, _ = self.shape
    kernel = self.apply_kernel(jnp.eye(n))
    # Force symmetry because of numerical imprecision.
    return (kernel + kernel.T) * 0.5

  @property
  def cost_matrix(self) -> jnp.ndarray:  # noqa: D102
    return -self.t * mu.safe_log(self.kernel_matrix)

  @property
  def _scale(self) -> float:
    """Constant used to scale the Laplacian."""
    if self.numerical_scheme == "backward_euler":
      return self.t / (4. * self.n_steps)
    if self.numerical_scheme == "crank_nicolson":
      return self.t / (2. * self.n_steps)
    raise NotImplementedError(
        f"Numerical scheme `{self.numerical_scheme}` is not implemented."
    )

  @property
  def _scaled_laplacian(self) -> jnp.ndarray:
    """Laplacian scaled by a constant, depending on the numerical scheme."""
    return self._scale * self.laplacian

  @property
  def _M(self) -> jnp.ndarray:
    # Implicit-scheme system matrix `I + scale * L`.
    n, _ = self.shape
    return self._scaled_laplacian + jnp.eye(n)

  @property
  def shape(self) -> Tuple[int, int]:  # noqa: D102
    return self.laplacian.shape

  @property
  def is_symmetric(self) -> bool:  # noqa: D102
    return True

  @property
  def dtype(self) -> jnp.dtype:  # noqa: D102
    return self.laplacian.dtype

  def transport_from_potentials(
      self, f: jnp.ndarray, g: jnp.ndarray
  ) -> jnp.ndarray:
    """Not implemented."""
    raise ValueError("Not implemented.")

  def apply_transport_from_potentials(
      self,
      f: jnp.ndarray,
      g: jnp.ndarray,
      vec: jnp.ndarray,
      axis: int = 0
  ) -> jnp.ndarray:
    """Since applying from potentials is not feasible in grids, use scalings."""
    u, v = self.scaling_from_potential(f), self.scaling_from_potential(g)
    return self.apply_transport_from_scalings(u, v, vec, axis=axis)

  def marginal_from_potentials(
      self,
      f: jnp.ndarray,
      g: jnp.ndarray,
      axis: int = 0,
  ) -> jnp.ndarray:
    """Not implemented."""
    raise ValueError("Not implemented.")

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
    return [self.laplacian, self.t], {
        "n_steps": self.n_steps,
        "numerical_scheme": self.numerical_scheme,
        "tol": self.tol,
    }

  @classmethod
  def tree_unflatten(  # noqa: D102
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "Geodesic":
    return cls(*children, **aux_data)

0 comments on commit 4306382

Please sign in to comment.