From 425fc0e4d5c74320af442ec31f4620759b441ea8 Mon Sep 17 00:00:00 2001 From: Michal Klein Date: Thu, 10 Nov 2022 18:07:51 -0800 Subject: [PATCH 01/34] Reorganize repo structure --- {ott/examples => examples}/fairness/config.py | 0 {ott/examples => examples}/fairness/data.py | 0 {ott/examples => examples}/fairness/losses.py | 0 {ott/examples => examples}/fairness/main.py | 0 {ott/examples => examples}/fairness/models.py | 0 {ott/examples => examples}/fairness/train.py | 0 .../soft_error/config.py | 0 {ott/examples => examples}/soft_error/data.py | 0 .../soft_error/losses.py | 0 {ott/examples => examples}/soft_error/main.py | 0 .../examples => examples}/soft_error/model.py | 0 .../examples => examples}/soft_error/train.py | 0 ott/__init__.py | 2 +- ott/core/__init__.py | 44 ---- ott/core/_math_utils.py | 33 --- ott/core/momentum.py | 71 ------ ott/core/problems.py | 93 ------- ott/geometry/__init__.py | 7 +- ott/geometry/costs.py | 18 +- ott/geometry/geometry.py | 9 +- ott/geometry/graph.py | 4 +- ott/geometry/grid.py | 16 +- ott/geometry/pointcloud.py | 15 +- ott/initializers/__init__.py | 1 + ott/initializers/linear/__init__.py | 0 .../linear}/initializers.py | 240 ++---------------- .../linear}/initializers_lr.py | 44 ++-- ott/initializers/nn/__init__.py | 0 ott/initializers/nn/initializers.py | 188 ++++++++++++++ ott/initializers/nn/layers.py | 39 +++ ott/initializers/quadratic/__init__.py | 0 .../quadratic/initializers.py} | 42 +-- ott/math/__init__.py | 0 ott/{core => math}/decomposition.py | 0 ott/{core => math}/fixed_point_loop.py | 2 +- .../implicit_differentiation.py | 4 +- ott/{geometry => math}/matrix_square_root.py | 2 +- ott/{core => math}/potentials.py | 0 ott/{core => math}/unbalanced_functions.py | 0 ott/{geometry/ops.py => math/utils.py} | 56 ++-- ott/problems/__init__.py | 1 + ott/problems/linear/__init__.py | 2 + ott/problems/linear/barycenter_problem.py | 182 +++++++++++++ .../linear/linear_problem.py} | 2 + ott/problems/quadratic/__init__.py | 3 
+ .../quadratic/barycenter_problem.py} | 208 ++------------- ott/problems/quadratic/quadratic_costs.py | 34 +++ .../quadratic/quadratic_problem.py} | 82 ++---- ott/solvers/__init__.py | 1 + ott/solvers/linear/__init__.py | 0 .../linear/acceleration.py} | 86 +++++-- .../linear}/continuous_barycenter.py | 23 +- .../linear}/discrete_barycenter.py | 20 +- ott/{core => solvers/linear}/sinkhorn.py | 80 +++--- ott/{core => solvers/linear}/sinkhorn_lr.py | 47 ++-- ott/solvers/nn/__init__.py | 1 + ott/{core => solvers/nn}/icnn.py | 8 +- ott/{core => solvers/nn}/layers.py | 23 +- ott/{core => solvers/nn}/neuraldual.py | 5 +- ott/solvers/quadratic/__init__.py | 0 .../quadratic}/gromov_wasserstein.py | 41 ++- .../quadratic}/gw_barycenter.py | 27 +- .../gaussian_mixture/gaussian_mixture_pair.py | 2 +- ott/tools/gaussian_mixture/scale_tril.py | 3 +- ott/tools/k_means.py | 2 +- ott/tools/segment_sinkhorn.py | 3 +- ott/tools/sinkhorn_divergence.py | 4 +- ott/tools/transport.py | 91 ++++++- ott/typing.py | 22 ++ ott/utils/__init__.py | 0 ott/{core => utils}/dataclasses.py | 0 ott/{core => utils}/segment.py | 0 ott/{core => utils}/was_solver.py | 2 +- 73 files changed, 964 insertions(+), 971 deletions(-) rename {ott/examples => examples}/fairness/config.py (100%) rename {ott/examples => examples}/fairness/data.py (100%) rename {ott/examples => examples}/fairness/losses.py (100%) rename {ott/examples => examples}/fairness/main.py (100%) rename {ott/examples => examples}/fairness/models.py (100%) rename {ott/examples => examples}/fairness/train.py (100%) rename {ott/examples => examples}/soft_error/config.py (100%) rename {ott/examples => examples}/soft_error/data.py (100%) rename {ott/examples => examples}/soft_error/losses.py (100%) rename {ott/examples => examples}/soft_error/main.py (100%) rename {ott/examples => examples}/soft_error/model.py (100%) rename {ott/examples => examples}/soft_error/train.py (100%) delete mode 100644 ott/core/__init__.py delete mode 100644 
ott/core/_math_utils.py delete mode 100644 ott/core/momentum.py delete mode 100644 ott/core/problems.py create mode 100644 ott/initializers/__init__.py create mode 100644 ott/initializers/linear/__init__.py rename ott/{core => initializers/linear}/initializers.py (52%) rename ott/{core => initializers/linear}/initializers_lr.py (94%) create mode 100644 ott/initializers/nn/__init__.py create mode 100644 ott/initializers/nn/initializers.py create mode 100644 ott/initializers/nn/layers.py create mode 100644 ott/initializers/quadratic/__init__.py rename ott/{core/quad_initializers.py => initializers/quadratic/initializers.py} (82%) create mode 100644 ott/math/__init__.py rename ott/{core => math}/decomposition.py (100%) rename ott/{core => math}/fixed_point_loop.py (99%) rename ott/{core => math}/implicit_differentiation.py (98%) rename ott/{geometry => math}/matrix_square_root.py (99%) rename ott/{core => math}/potentials.py (100%) rename ott/{core => math}/unbalanced_functions.py (100%) rename ott/{geometry/ops.py => math/utils.py} (62%) create mode 100644 ott/problems/__init__.py create mode 100644 ott/problems/linear/__init__.py create mode 100644 ott/problems/linear/barycenter_problem.py rename ott/{core/linear_problems.py => problems/linear/linear_problem.py} (99%) create mode 100644 ott/problems/quadratic/__init__.py rename ott/{core/bar_problems.py => problems/quadratic/barycenter_problem.py} (57%) create mode 100644 ott/problems/quadratic/quadratic_costs.py rename ott/{core/quad_problems.py => problems/quadratic/quadratic_problem.py} (90%) create mode 100644 ott/solvers/__init__.py create mode 100644 ott/solvers/linear/__init__.py rename ott/{core/anderson.py => solvers/linear/acceleration.py} (62%) rename ott/{core => solvers/linear}/continuous_barycenter.py (89%) rename ott/{core => solvers/linear}/discrete_barycenter.py (94%) rename ott/{core => solvers/linear}/sinkhorn.py (94%) rename ott/{core => solvers/linear}/sinkhorn_lr.py (94%) create mode 100644 
ott/solvers/nn/__init__.py rename ott/{core => solvers/nn}/icnn.py (97%) rename ott/{core => solvers/nn}/layers.py (86%) rename ott/{core => solvers/nn}/neuraldual.py (99%) create mode 100644 ott/solvers/quadratic/__init__.py rename ott/{core => solvers/quadratic}/gromov_wasserstein.py (95%) rename ott/{core => solvers/quadratic}/gw_barycenter.py (93%) create mode 100644 ott/typing.py create mode 100644 ott/utils/__init__.py rename ott/{core => utils}/dataclasses.py (100%) rename ott/{core => utils}/segment.py (100%) rename ott/{core => utils}/was_solver.py (98%) diff --git a/ott/examples/fairness/config.py b/examples/fairness/config.py similarity index 100% rename from ott/examples/fairness/config.py rename to examples/fairness/config.py diff --git a/ott/examples/fairness/data.py b/examples/fairness/data.py similarity index 100% rename from ott/examples/fairness/data.py rename to examples/fairness/data.py diff --git a/ott/examples/fairness/losses.py b/examples/fairness/losses.py similarity index 100% rename from ott/examples/fairness/losses.py rename to examples/fairness/losses.py diff --git a/ott/examples/fairness/main.py b/examples/fairness/main.py similarity index 100% rename from ott/examples/fairness/main.py rename to examples/fairness/main.py diff --git a/ott/examples/fairness/models.py b/examples/fairness/models.py similarity index 100% rename from ott/examples/fairness/models.py rename to examples/fairness/models.py diff --git a/ott/examples/fairness/train.py b/examples/fairness/train.py similarity index 100% rename from ott/examples/fairness/train.py rename to examples/fairness/train.py diff --git a/ott/examples/soft_error/config.py b/examples/soft_error/config.py similarity index 100% rename from ott/examples/soft_error/config.py rename to examples/soft_error/config.py diff --git a/ott/examples/soft_error/data.py b/examples/soft_error/data.py similarity index 100% rename from ott/examples/soft_error/data.py rename to examples/soft_error/data.py diff 
--git a/ott/examples/soft_error/losses.py b/examples/soft_error/losses.py similarity index 100% rename from ott/examples/soft_error/losses.py rename to examples/soft_error/losses.py diff --git a/ott/examples/soft_error/main.py b/examples/soft_error/main.py similarity index 100% rename from ott/examples/soft_error/main.py rename to examples/soft_error/main.py diff --git a/ott/examples/soft_error/model.py b/examples/soft_error/model.py similarity index 100% rename from ott/examples/soft_error/model.py rename to examples/soft_error/model.py diff --git a/ott/examples/soft_error/train.py b/examples/soft_error/train.py similarity index 100% rename from ott/examples/soft_error/train.py rename to examples/soft_error/train.py diff --git a/ott/__init__.py b/ott/__init__.py index df1438bef..8015c263b 100644 --- a/ott/__init__.py +++ b/ott/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. """OTT library.""" -from . import core, geometry, tools +from . import geometry, initializers, math, problems, solvers, tools, utils from ._version import __version__ diff --git a/ott/core/__init__.py b/ott/core/__init__.py deleted file mode 100644 index 4a39d58fc..000000000 --- a/ott/core/__init__.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2022 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""OTT core libraries: the engine behind most computations happening in OTT.""" - -# pytype: disable=import-error # kwargs-checking -from . import ( - anderson, - bar_problems, - continuous_barycenter, - dataclasses, - decomposition, - discrete_barycenter, - gromov_wasserstein, - gw_barycenter, - implicit_differentiation, - initializers, - initializers_lr, - linear_problems, - momentum, - potentials, - quad_initializers, - quad_problems, - sinkhorn, - sinkhorn_lr, -) - -# from . import neuraldual -from .implicit_differentiation import ImplicitDiff -from .linear_problems import LinearProblem -from .sinkhorn import Sinkhorn -from .sinkhorn_lr import LRSinkhorn - -# pytype: enable=import-error # kwargs-checking diff --git a/ott/core/_math_utils.py b/ott/core/_math_utils.py deleted file mode 100644 index 7a269a77b..000000000 --- a/ott/core/_math_utils.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Optional, Union - -import jax.experimental.sparse as jesp -import jax.numpy as jnp - -__all__ = ["safe_log", "kl", "js"] - -Sparse_t = Union[jesp.CSR, jesp.CSC, jesp.COO, jesp.BCOO] - - -def safe_log(x: jnp.ndarray, *, eps: Optional[float] = None) -> jnp.ndarray: - if eps is None: - eps = jnp.finfo(x.dtype).tiny - return jnp.where(x > 0., jnp.log(x), jnp.log(eps)) - - -def kl(p: jnp.ndarray, q: jnp.ndarray) -> float: - """Kullback-Leilbler divergence.""" - return jnp.vdot(p, (safe_log(p) - safe_log(q))) - - -def js(p: jnp.ndarray, q: jnp.ndarray, *, c: float = 0.5) -> float: - """Jensen-Shannon divergence.""" - return c * (kl(p, q) + kl(q, p)) - - -def sparse_scale(c: float, mat: Sparse_t) -> Sparse_t: - """Scale a sparse matrix by a constant.""" - if isinstance(mat, jesp.BCOO): - # most feature complete, defer to original impl. 
- return c * mat - (data, *children), aux_data = mat.tree_flatten() - return type(mat).tree_unflatten(aux_data, [c * data] + children) diff --git a/ott/core/momentum.py b/ott/core/momentum.py deleted file mode 100644 index 380df4fc1..000000000 --- a/ott/core/momentum.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2022 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Functions related to momemtum.""" - -from typing import TYPE_CHECKING - -import jax -import jax.numpy as jnp - -from ott.core import dataclasses - -if TYPE_CHECKING: - from ott.core import sinkhorn - - -@dataclasses.register_pytree_node -class Momentum: - """Momentum for Sinkhorn updates, either constant or adaptive.""" - - start: int = 0 - error_threshold: float = jnp.inf - value: float = 1.0 - inner_iterations: int = 1 - - def weight(self, state: "sinkhorn.SinkhornState", iteration: int) -> float: - """Compute momentum term if needed, using previously seen errors.""" - if self.start == 0: - return self.value - idx = self.start // self.inner_iterations - - weight = jax.lax.cond( - jnp.logical_and( - iteration >= self.start, - state.errors[idx - 1, -1] < self.error_threshold - ), lambda state: self.lehmann(state), lambda state: self.value, state - ) - return weight - - def lehmann(self, state: "sinkhorn.SinkhornState") -> float: - """Momentum formula :cite:`lehmann:21`, eq. 
5.""" - idx = self.start // self.inner_iterations - error_ratio = jnp.minimum( - state.errors[idx - 1, -1] / state.errors[idx - 2, -1], 0.99 - ) - power = 1.0 / self.inner_iterations - return 2.0 / (1.0 + jnp.sqrt(1.0 - error_ratio ** power)) - - def __call__( - self, - weight: float, - value: jnp.ndarray, - new_value: jnp.ndarray, - lse_mode: bool = True - ) -> jnp.ndarray: - if lse_mode: - value = jnp.where(jnp.isfinite(value), value, 0.0) - return (1.0 - weight) * value + weight * new_value - else: - value = jnp.where(value > 0.0, value, 1.0) - return value ** (1.0 - weight) * new_value ** weight diff --git a/ott/core/problems.py b/ott/core/problems.py deleted file mode 100644 index e60b4deb1..000000000 --- a/ott/core/problems.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2022 The OTT Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Utility to make a problem class from arrays.""" -from typing import Any, Optional, Union - -import jax.numpy as jnp -import numpy as np - -from ott.core import linear_problems, quad_problems -from ott.geometry import geometry, pointcloud - - -def make( - *args: Union[jnp.ndarray, geometry.Geometry, linear_problems.LinearProblem, - quad_problems.QuadraticProblem], - a: Optional[jnp.ndarray] = None, - b: Optional[jnp.ndarray] = None, - tau_a: float = 1.0, - tau_b: float = 1.0, - objective: Optional[str] = None, - gw_unbalanced_correction: Optional[bool] = True, - fused_penalty: Optional[float] = None, - scale_cost: Optional[Union[bool, float, str]] = False, - **kwargs: Any, -): - """Make a problem from arrays, assuming PointCloud geometries.""" - if isinstance(args[0], (jnp.ndarray, np.ndarray)): - x = args[0] - y = args[1] if len(args) > 1 else args[0] - if ((objective == 'linear') or - (objective is None and x.shape[1] == y.shape[1])): # noqa: E129 - geom_xy = pointcloud.PointCloud(x, y, **kwargs) - return linear_problems.LinearProblem( - geom_xy, a=a, b=b, tau_a=tau_a, tau_b=tau_b - ) - elif ((objective == 'quadratic') or - (objective is None and x.shape[1] != y.shape[1])): - geom_xx = pointcloud.PointCloud(x, x, **kwargs) - geom_yy = pointcloud.PointCloud(y, y, **kwargs) - return quad_problems.QuadraticProblem( - geom_xx=geom_xx, - geom_yy=geom_yy, - geom_xy=None, - scale_cost=scale_cost, - a=a, - b=b, - tau_a=tau_a, - tau_b=tau_b, - gw_unbalanced_correction=gw_unbalanced_correction - ) - elif objective == 'fused': - geom_xx = pointcloud.PointCloud(x, x, **kwargs) - geom_yy = pointcloud.PointCloud(y, y, **kwargs) - geom_xy = pointcloud.PointCloud(x, y, **kwargs) - return quad_problems.QuadraticProblem( - geom_xx=geom_xx, - geom_yy=geom_yy, - geom_xy=geom_xy, - fused_penalty=fused_penalty, - scale_cost=scale_cost, - a=a, - b=b, - tau_a=tau_a, - tau_b=tau_b, - gw_unbalanced_correction=gw_unbalanced_correction - ) - else: - raise ValueError(f'Unknown transport 
problem `{objective}`') - elif isinstance(args[0], geometry.Geometry): - if len(args) == 1: - return linear_problems.LinearProblem( - *args, a=a, b=b, tau_a=tau_a, tau_b=tau_b - ) - return quad_problems.QuadraticProblem( - *args, a=a, b=b, tau_a=tau_a, tau_b=tau_b, scale_cost=scale_cost - ) - elif isinstance( - args[0], (linear_problems.LinearProblem, quad_problems.QuadraticProblem) - ): - return args[0] - else: - raise ValueError('Cannot instantiate a transport problem.') diff --git a/ott/geometry/__init__.py b/ott/geometry/__init__.py index 38b4ada53..7309ab155 100644 --- a/ott/geometry/__init__.py +++ b/ott/geometry/__init__.py @@ -12,9 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. """OTT ground geometries: Classes and cost functions to instantiate them.""" -from . import costs, low_rank, ops -from .epsilon_scheduler import Epsilon -from .geometry import Geometry -from .graph import Graph -from .grid import Grid -from .pointcloud import PointCloud +from . import costs, epsilon_scheduler, geometry, graph, grid, pointcloud diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py index fae64a596..1af441d73 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -17,13 +17,12 @@ import abc import functools import math -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Tuple, Union import jax import jax.numpy as jnp -from ott.core import fixed_point_loop -from ott.geometry import matrix_square_root +from ott.math import fixed_point_loop, matrix_square_root @jax.tree_util.register_pytree_node_class @@ -32,9 +31,11 @@ class CostFn(abc.ABC): Cost functions evaluate a function on a pair of inputs. 
For convenience, that function is split into two norms -- evaluated on each input separately -- - followed by a pairwise cost that involves both inputs, as in + followed by a pairwise cost that involves both inputs, as in: - c(x,y) = norm(x) + norm(y) + pairwise(x,y) + .. math:: + + c(x,y) = norm(x) + norm(y) + pairwise(x,y) If the norm function is not implemented, that value is handled as a 0. """ @@ -46,6 +47,7 @@ class CostFn(abc.ABC): def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: pass + @abc.abstractmethod def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float: pass @@ -340,6 +342,7 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: log_m_pi += -0.5 * ldet_c_ab # If all logdet signs are 1, output value, nan otherwise. + # TODO(michalk8): use lax.cond return jnp.where( sldet_c == 1 and sldet_c_ab == 1 and sldet_ab == 1 and sldet_t_ab == 1, 2 * sig2 * mass_x * mass_y - 2 * (sig2 + gam) * jnp.exp(log_m_pi), @@ -355,7 +358,8 @@ def tree_unflatten(cls, aux_data, children): return cls(aux_data[0], aux_data[1], aux_data[2], **aux_data[3]) -def x_to_means_and_covs(x: jnp.ndarray, dimension: jnp.ndarray) -> jnp.ndarray: +def x_to_means_and_covs(x: jnp.ndarray, + dimension: int) -> Tuple[jnp.ndarray, jnp.ndarray]: """Extract means and covariance matrices of Gaussians from raveled vector. Args: @@ -367,7 +371,7 @@ def x_to_means_and_covs(x: jnp.ndarray, dimension: jnp.ndarray) -> jnp.ndarray: covariances: [num_gaussians, dimension] array that holds the covariances. 
""" x = jnp.atleast_2d(x) - means = x[:, 0:dimension] + means = x[:, :dimension] covariances = jnp.reshape( x[:, dimension:dimension + dimension ** 2], (-1, dimension, dimension) ) diff --git a/ott/geometry/geometry.py b/ott/geometry/geometry.py index dc59740b7..ff3d5000c 100644 --- a/ott/geometry/geometry.py +++ b/ott/geometry/geometry.py @@ -25,7 +25,8 @@ import jax.scipy as jsp from typing_extensions import Literal -from ott.geometry import epsilon_scheduler, ops +from ott.geometry import epsilon_scheduler +from ott.math import utils @jax.tree_util.register_pytree_node_class @@ -410,12 +411,12 @@ def _softmax( if vec is not None: if axis == 0: vec = vec.reshape((-1, 1)) - lse_output = ops.logsumexp( + lse_output = utils.logsumexp( self._center(f, g) / eps, b=vec, axis=axis, return_sign=True ) return eps * lse_output[0], lse_output[1] else: - lse_output = ops.logsumexp( + lse_output = utils.logsumexp( self._center(f, g) / eps, axis=axis, return_sign=False ) return eps * lse_output, jnp.array([1.0]) @@ -639,7 +640,7 @@ def to_LRCGeometry( Useful when this geometry is used in the linear term of fused GW. Returns: - The low-rank geometry. + Low-rank geometry. 
""" from ott.geometry import low_rank diff --git a/ott/geometry/graph.py b/ott/geometry/graph.py index b7e6e9024..8cd6c05a7 100644 --- a/ott/geometry/graph.py +++ b/ott/geometry/graph.py @@ -5,9 +5,9 @@ import jax.numpy as jnp from typing_extensions import Literal -from ott.core import _math_utils as mu -from ott.core import decomposition, fixed_point_loop from ott.geometry import geometry +from ott.math import decomposition, fixed_point_loop +from ott.math import utils as mu __all__ = ["Graph"] diff --git a/ott/geometry/grid.py b/ott/geometry/grid.py index 25e9809eb..88bd67f0a 100644 --- a/ott/geometry/grid.py +++ b/ott/geometry/grid.py @@ -21,19 +21,20 @@ import jax.numpy as jnp import numpy as np -from ott.geometry import costs, geometry, ops, pointcloud +from ott.geometry import costs, geometry, pointcloud +from ott.math import utils @jax.tree_util.register_pytree_node_class class Grid(geometry.Geometry): - r"""Class describing the geometry of points taken in a cartestian product. + r"""Class describing the geometry of points taken in a Cartesian product. This class implements a geometry in which probability measures are supported on a :math:`d`-dimensional cartesian grid, a cartesian product of :math:`d` lists of values, each list being itself of size :math:`n_i`. The transportation cost between points in the grid is assumed to be separable, - namely a sum of coordinate-wise cost functions, as in + namely a sum of coordinate-wise cost functions, as in: .. math:: @@ -52,7 +53,7 @@ class Grid(geometry.Geometry): Args: x : list of arrays of varying sizes, describing the locations of the grid. - Locations are provided as a list of jnp.ndarrays, that is :math:`d` + Locations are provided as a list of arrays, that is :math:`d` vectors of (possibly varying) size :math:`n_i`. The resulting grid is the Cartesian product of these vectors. 
grid_size: tuple of integers describing grid sizes, namely @@ -201,14 +202,13 @@ def _apply_lse_kernel_one_dimension(self, dimension, f, g, eps, vec=None): if vec is not None: vec = jnp.transpose(vec, indices) - softmax_res, softmax_sgn = ops.logsumexp( + softmax_res, softmax_sgn = utils.logsumexp( centered_cost, b=vec, axis=1, return_sign=True ) return eps * jnp.transpose(softmax_res, indices), jnp.transpose(softmax_sgn, indices) - else: - softmax_res = ops.logsumexp(centered_cost, axis=1) - return eps * jnp.transpose(softmax_res, indices), None + softmax_res = utils.logsumexp(centered_cost, axis=1) + return eps * jnp.transpose(softmax_res, indices), None def _apply_cost_to_vec( self, vec: jnp.ndarray, axis: int = 0, fn=None diff --git a/ott/geometry/pointcloud.py b/ott/geometry/pointcloud.py index 9b518090e..9e153f267 100644 --- a/ott/geometry/pointcloud.py +++ b/ott/geometry/pointcloud.py @@ -21,19 +21,20 @@ import jax.numpy as jnp from typing_extensions import Literal -from ott.geometry import costs, geometry, low_rank, ops +from ott.geometry import costs, geometry, low_rank +from ott.math import utils @jax.tree_util.register_pytree_node_class class PointCloud(geometry.Geometry): - """Defines geometry for 2 point clouds (possibly 1 vs itself) using CostFn. + """Defines geometry for 2 point clouds (possibly 1 vs itself). Creates a geometry, specifying a cost function passed as CostFn type object. - When the number of points is large, setting the `online` flag to `True` - implies that cost and kernel matrices used to update potentials or scalings + When the number of points is large, setting the ``batch_size`` flag implies + that cost and kernel matrices used to update potentials or scalings will be recomputed on the fly, rather than stored in memory. 
More precisely, - when setting `online`, the cost function will be partially cached by storing - norm values for each point in both point clouds, but the pairwise cost + when setting ``batch_size``, the cost function will be partially cached by + storing norm values for each point in both point clouds, but the pairwise cost function evaluations won't be. The sum of norms + the pairwise cost term is raised to `power`. @@ -737,7 +738,7 @@ def _apply_lse_kernel_xy( x, y, norm_x, norm_y, f, g, eps, vec, cost_fn, cost_pow, scale_cost ): c = _cost(x, y, norm_x, norm_y, cost_fn, cost_pow, scale_cost) - return ops.logsumexp((f + g - c) / eps, b=vec, return_sign=True, axis=-1) + return utils.logsumexp((f + g - c) / eps, b=vec, return_sign=True, axis=-1) def _transport_from_potentials_xy( diff --git a/ott/initializers/__init__.py b/ott/initializers/__init__.py new file mode 100644 index 000000000..15cfac006 --- /dev/null +++ b/ott/initializers/__init__.py @@ -0,0 +1 @@ +from . import linear, nn, quadratic diff --git a/ott/initializers/linear/__init__.py b/ott/initializers/linear/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ott/core/initializers.py b/ott/initializers/linear/initializers.py similarity index 52% rename from ott/core/initializers.py rename to ott/initializers/linear/initializers.py index 700aa47a5..2b10ed6e2 100644 --- a/ott/core/initializers.py +++ b/ott/initializers/linear/initializers.py @@ -12,44 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Sinkhorn initializers.""" -import functools -from abc import ABC, abstractmethod +import abc from typing import Any, Dict, Optional, Sequence, Tuple import jax import jax.numpy as jnp -import optax -from flax import linen as nn -from flax.training import train_state -from ott.core import linear_problems, sinkhorn -from ott.geometry import geometry, pointcloud +from ott.geometry import pointcloud +from ott.problems.linear import linear_problem __all__ = [ - "DefaultInitializer", "GaussianInitializer", "SortingInitializer", - "MetaInitializer" + "SinkhornInitializer", "DefaultInitializer", "GaussianInitializer", + "SortingInitializer" ] @jax.tree_util.register_pytree_node_class -class SinkhornInitializer(ABC): +class SinkhornInitializer(abc.ABC): """Base class for Sinkhorn initializers.""" - @abstractmethod + @abc.abstractmethod def init_dual_a( - self, ot_prob: linear_problems.LinearProblem, lse_mode: bool + self, ot_prob: linear_problem.LinearProblem, lse_mode: bool ) -> jnp.ndarray: """Initialization for Sinkhorn potential/scaling f_u.""" - @abstractmethod + @abc.abstractmethod def init_dual_b( - self, ot_prob: linear_problems.LinearProblem, lse_mode: bool + self, ot_prob: linear_problem.LinearProblem, lse_mode: bool ) -> jnp.ndarray: """Initialization for Sinkhorn potential/scaling g_v.""" def __call__( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, a: Optional[jnp.ndarray], b: Optional[jnp.ndarray], lse_mode: bool, @@ -101,7 +97,7 @@ class DefaultInitializer(SinkhornInitializer): """Default initialization of Sinkhorn dual potentials/primal scalings.""" def init_dual_a( - self, ot_prob: linear_problems.LinearProblem, lse_mode: bool + self, ot_prob: linear_problem.LinearProblem, lse_mode: bool ) -> jnp.ndarray: """Initialize Sinkhorn potential/scaling f_u. 
@@ -117,7 +113,7 @@ def init_dual_a( return init_dual_a def init_dual_b( - self, ot_prob: linear_problems.LinearProblem, lse_mode: bool + self, ot_prob: linear_problem.LinearProblem, lse_mode: bool ) -> jnp.ndarray: """Initialize Sinkhorn potential/scaling g_v. @@ -145,7 +141,7 @@ class GaussianInitializer(DefaultInitializer): def init_dual_a( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, lse_mode: bool, ) -> jnp.ndarray: """Gaussian initialization function. @@ -247,7 +243,7 @@ def cond_fn(state: Tuple[jnp.ndarray, float, int]) -> bool: def init_dual_a( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, lse_mode: bool, init_f: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: @@ -285,212 +281,6 @@ def init_dual_a( return f_u -@jax.tree_util.register_pytree_node_class -class MetaInitializer(DefaultInitializer): - """Meta OT Initializer with a fixed geometry :cite:`amos:22`. - - This initializer consists of a predictive model that outputs the - :math:`f` duals to solve the entropy-regularized OT problem given - input probability weights ``a`` and ``b``, and a given (assumed to be - fixed) geometry ``geom``. - The model's parameters are learned using a training set of OT - instances (multiple pairs of probability weights), that assume the - **same** geometry ``geom`` is used throughout, both for training and - evaluation. The meta model defaults to the MLP in - :class:`~ott.core.initializers.MetaMLP` and, with batched problem - instances passed into :meth:`update`. - - **Sample training usage.** The following code shows a simple - example of using ``update`` to train the model, where - ``a`` and ``b`` are the weights of the measures and - ``geom`` is the fixed geometry. - - .. 
code-block:: python - - meta_initializer = init_lib.MetaInitializer(geom=geom) - while training(): - a, b = sample_batch() - loss, init_f, meta_initializer.state = meta_initializer.update( - meta_initializer.state, a=a, b=b) - - Args: - geom: The fixed geometry of the problem instances. - meta_model: The model to predict the potential :math:`f` from the measures. - opt: The optimizer to update the parameters. - rng: The PRNG key to use for initializing the model. - state: The training state of the model to start from. - """ - - def __init__( - self, - geom: geometry.Geometry, - meta_model: Optional[nn.Module] = None, - opt: optax.GradientTransformation = optax.adam(learning_rate=1e-3), - rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), - state: Optional[train_state.TrainState] = None - ): - self.geom = geom - self.dtype = geom.x.dtype - self.opt = opt - self.rng = rng - - na, nb = geom.shape - self.meta_model = MetaMLP( - potential_size=na - ) if meta_model is None else meta_model - - if state is None: - # Initialize the model's training state. - a_placeholder = jnp.zeros(na, dtype=self.dtype) - b_placeholder = jnp.zeros(nb, dtype=self.dtype) - params = self.meta_model.init(rng, a_placeholder, b_placeholder)['params'] - self.state = train_state.TrainState.create( - apply_fn=self.meta_model.apply, params=params, tx=opt - ) - else: - self.state = state - - self.update_impl = self._get_update_fn() - - def update( - self, state: train_state.TrainState, a: jnp.ndarray, b: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray, train_state.TrainState]: - r"""Update the meta model with the dual objective. - - The goal is for the model to match the optimal duals, i.e., - :math:`\hat f_\theta \approx f^\star`. - This can be done by training the predictions of :math:`\hat f_\theta` - to optimize the dual objective, which :math:`f^\star` also optimizes for. - The overall learning setup can thus be written as: - - .. 
math:: - \min_\theta\; {\mathbb E}_{(\alpha,\beta)\sim{\mathcal{D}}}\; - J(\hat f_\theta(a, b); \alpha, \beta), - - where :math:`a,b` are the probabilities of the measures :math:`\alpha,\beta`, - :math:`\mathcal{D}` is a meta distribution of optimal transport problems, - - .. math:: - -J(f; \alpha, \beta, c) := \langle f, a\rangle + \langle g, b \rangle - - \varepsilon\left\langle \exp\{f/\varepsilon\}, K\exp\{g/\varepsilon\}\right\rangle - - is the entropic dual objective, - and :math:`K_{i,j} := -C_{i,j}/\varepsilon` is the *Gibbs kernel*. - - Args: - state: Optimizer state of the meta model. - a: Probabilites of the :math:`\alpha` measure's atoms. - b: Probabilites of the :math:`\beta` measure's atoms. - - Returns: - The training loss, :math:`f`, and updated state. - """ - return self.update_impl(state, a, b) - - def init_dual_a( - self, ot_prob: linear_problems.LinearProblem, lse_mode: bool - ) -> jnp.ndarray: - # Detect if the problem is batched. - assert ot_prob.a.ndim in (1, 2) and ot_prob.b.ndim in (1, 2) - vmap_a_val = 0 if ot_prob.a.ndim == 2 else None - vmap_b_val = 0 if ot_prob.b.ndim == 2 else None - - if vmap_a_val is not None or vmap_b_val is not None: - compute_f_maybe_batch = jax.vmap( - self._compute_f, in_axes=(vmap_a_val, vmap_b_val, None) - ) - else: - compute_f_maybe_batch = self._compute_f - - init_f = compute_f_maybe_batch(ot_prob.a, ot_prob.b, self.state.params) - f_u = init_f if lse_mode else ot_prob.geom.scaling_from_potential(init_f) - return f_u - - def _get_update_fn(self): - """Return the implementation (and jitted) update function.""" - - def dual_obj_loss_single(params, a, b): - f_pred = self._compute_f(a, b, params) - g_pred = self.geom.update_potential( - f_pred, jnp.zeros_like(b), jnp.log(b), 0, axis=0 - ) - g_pred = jnp.where(jnp.isfinite(g_pred), g_pred, 0.) 
- - ot_prob = linear_problems.LinearProblem(geom=self.geom, a=a, b=b) - dual_obj = sinkhorn.ent_reg_cost(f_pred, g_pred, ot_prob, lse_mode=True) - loss = -dual_obj - return loss, f_pred - - def loss_batch(params, a, b): - loss_fn = functools.partial(dual_obj_loss_single, params=params) - loss, f_pred = jax.vmap(loss_fn)(a=a, b=b) - return jnp.mean(loss), f_pred - - @jax.jit - def update(state, a, b): - a = jnp.atleast_2d(a) - b = jnp.atleast_2d(b) - grad_fn = jax.value_and_grad(loss_batch, has_aux=True) - (loss, init_f), grads = grad_fn(state.params, a, b) - return loss, init_f, state.apply_gradients(grads=grads) - - return update - - def _compute_f(self, a, b, params): - r"""Predict the optimal :math:`f` potential. - - Args: - a: Probabilites of the :math:`\alpha` measure's atoms. - b: Probabilites of the :math:`\beta` measure's atoms. - params: The parameters of the Meta model. - - Returns: - The :math:`f` potential. - """ - return self.meta_model.apply({'params': params}, a, b) - - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: - return [self.geom, self.meta_model, self.opt], { - 'rng': self.rng, - 'state': self.state - } - - -class MetaMLP(nn.Module): - r"""A Meta MLP potential for :class:`~ott.core.initializers.MetaInitializer`. - - This provides an MLP :math:`\hat f_\theta(a, b)` that maps from the probabilities - of the measures to the optimal dual potentials :math:`f`. - - Args: - potential_size: The dimensionality of :math:`f`. - num_hidden_units: The number of hidden units in each layer. - num_hidden_layers: The number of hidden layers. - """ - - potential_size: int - num_hidden_units: int = 512 - num_hidden_layers: int = 3 - - @nn.compact - def __call__(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: - r"""Make a prediction. - - Args: - a: Probabilites of the :math:`\alpha` measure's atoms. - b: Probabilites of the :math:`\beta` measure's atoms. - - Returns: - The :math:`f` potential. 
- """ - dtype = a.dtype - z = jnp.concatenate((a, b)) - for _ in range(self.num_hidden_layers): - z = nn.relu(nn.Dense(self.num_hidden_units, dtype=dtype)(z)) - f = nn.Dense(self.potential_size, dtype=dtype)(z) - return f - - def _vectorized_update( f: jnp.ndarray, modified_cost: jnp.ndarray ) -> jnp.ndarray: diff --git a/ott/core/initializers_lr.py b/ott/initializers/linear/initializers_lr.py similarity index 94% rename from ott/core/initializers_lr.py rename to ott/initializers/linear/initializers_lr.py index 641f1b284..2999bf850 100644 --- a/ott/core/initializers_lr.py +++ b/ott/initializers/linear/initializers_lr.py @@ -1,5 +1,5 @@ +import abc import functools -from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, Any, @@ -17,8 +17,9 @@ from jax import numpy as jnp from typing_extensions import Literal -from ott.core import _math_utils as mu from ott.geometry import geometry, low_rank, pointcloud +from ott.math import fixed_point_loop +from ott.math import utils as mu __all__ = [ "RandomInitializer", "Rank2Initializer", "KMeansInitializer", @@ -26,23 +27,21 @@ ] if TYPE_CHECKING: - from ott.core import ( - gromov_wasserstein, - linear_problems, - quad_problems, - sinkhorn, - sinkhorn_lr, - ) - Problem_t = Union[linear_problems.LinearProblem, - quad_problems.QuadraticProblem] + from ott.problems.linear import linear_problem + from ott.problems.quadratic import quadratic_problem + from ott.solvers.linear import sinkhorn, sinkhorn_lr + from ott.solvers.quadratic import gromov_wasserstein + + Problem_t = Union[linear_problem.LinearProblem, + quadratic_problem.QuadraticProblem] else: Problem_t = "Union[linear_problems.LinearProblem, " \ "quad_problems.QuadraticProblem]" @jax.tree_util.register_pytree_node_class -class LRInitializer(ABC): - """Low-rank initializer for linear/quadratic problems. +class LRInitializer(abc.ABC): + """Base class for low-rank initializers. Args: rank: Rank of the factorization. 
@@ -53,7 +52,7 @@ def __init__(self, rank: int, **kwargs: Any): self._rank = rank self._kwargs = kwargs - @abstractmethod + @abc.abstractmethod def init_q( self, ot_prob: Problem_t, @@ -67,13 +66,14 @@ def init_q( Args: ot_prob: OT problem. key: Random key for seeding. + init_g: Initial value for :math:`g` factor. kwargs: Additional keyword arguments. Returns: Array of shape ``[n, rank]``. """ - @abstractmethod + @abc.abstractmethod def init_r( self, ot_prob: Problem_t, @@ -87,13 +87,14 @@ def init_r( Args: ot_prob: Linear OT problem. key: Random key for seeding. + init_g: Initial value for :math:`g` factor. kwargs: Additional keyword arguments. Returns: Array of shape ``[m, rank]``. """ - @abstractmethod + @abc.abstractmethod def init_g( self, ot_prob: Problem_t, @@ -130,7 +131,7 @@ def from_solver( Returns: The low-rank initializer. """ - from ott.core import gromov_wasserstein + from ott.solvers.quadratic import gromov_wasserstein if isinstance(solver, gromov_wasserstein.GromovWasserstein): assert solver.is_low_rank, "GW solver is not low-rank." 
@@ -382,7 +383,9 @@ def _compute_factor( which: Literal["q", "r"], **kwargs: Any, ) -> jnp.ndarray: - from ott.core import linear_problems, quad_problems, sinkhorn + from ott.problems import linear as linear_problems + from ott.problems import quadratic as quad_problems + from ott.solvers.linear import sinkhorn from ott.tools import k_means del kwargs @@ -512,7 +515,9 @@ def _compute_factor( which: Literal["q", "r"], **kwargs: Any, ) -> jnp.ndarray: - from ott.core import fixed_point_loop, linear_problems, sinkhorn + from ott.problems import linear as linear_problems + from ott.problems import quadratic as quad_problems + from ott.solvers.linear import sinkhorn def init_fn() -> GeneralizedKMeansInitializer.State: n = geom.shape[0] @@ -611,7 +616,6 @@ def body_fn( ) del kwargs - from ott.core import quad_problems if isinstance(ot_prob, quad_problems.QuadraticProblem): geom = ot_prob.geom_xx if which == "q" else ot_prob.geom_yy diff --git a/ott/initializers/nn/__init__.py b/ott/initializers/nn/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ott/initializers/nn/initializers.py b/ott/initializers/nn/initializers.py new file mode 100644 index 000000000..05f173b42 --- /dev/null +++ b/ott/initializers/nn/initializers.py @@ -0,0 +1,188 @@ +import functools +from typing import Any, Dict, Optional, Sequence, Tuple + +import jax +import jax.numpy as jnp +import optax +from flax import linen as nn +from flax.training import train_state + +from ott.geometry import geometry +from ott.initializers.linear import DefaultInitializer +from ott.initializers.nn import layers +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn + +# TODO(michalk8): add Charlotte's initializer? +__all__ = ["MetaInitializer"] + + +@jax.tree_util.register_pytree_node_class +class MetaInitializer(DefaultInitializer): + """Meta OT Initializer with a fixed geometry :cite:`amos:22`. 
+ + This initializer consists of a predictive model that outputs the + :math:`f` duals to solve the entropy-regularized OT problem given + input probability weights ``a`` and ``b``, and a given (assumed to be + fixed) geometry ``geom``. + The model's parameters are learned using a training set of OT + instances (multiple pairs of probability weights), that assume the + **same** geometry ``geom`` is used throughout, both for training and + evaluation. The meta model defaults to the MLP in + :class:`~ott.core.initializers.MetaMLP` and, with batched problem + instances passed into :meth:`update`. + + **Sample training usage.** The following code shows a simple + example of using ``update`` to train the model, where + ``a`` and ``b`` are the weights of the measures and + ``geom`` is the fixed geometry. + + .. code-block:: python + + meta_initializer = init_lib.MetaInitializer(geom=geom) + while training(): + a, b = sample_batch() + loss, init_f, meta_initializer.state = meta_initializer.update( + meta_initializer.state, a=a, b=b) + + Args: + geom: The fixed geometry of the problem instances. + meta_model: The model to predict the potential :math:`f` from the measures. + opt: The optimizer to update the parameters. + rng: The PRNG key to use for initializing the model. + state: The training state of the model to start from. + """ + + def __init__( + self, + geom: geometry.Geometry, + meta_model: Optional[nn.Module] = None, + opt: optax.GradientTransformation = optax.adam(learning_rate=1e-3), + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), + state: Optional[train_state.TrainState] = None + ): + self.geom = geom + self.dtype = geom.x.dtype + self.opt = opt + self.rng = rng + + na, nb = geom.shape + self.meta_model = layers.MetaMLP( + potential_size=na + ) if meta_model is None else meta_model + + if state is None: + # Initialize the model's training state. 
+ a_placeholder = jnp.zeros(na, dtype=self.dtype) + b_placeholder = jnp.zeros(nb, dtype=self.dtype) + params = self.meta_model.init(rng, a_placeholder, b_placeholder)['params'] + self.state = train_state.TrainState.create( + apply_fn=self.meta_model.apply, params=params, tx=opt + ) + else: + self.state = state + + self.update_impl = self._get_update_fn() + + def update( + self, state: train_state.TrainState, a: jnp.ndarray, b: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray, train_state.TrainState]: + r"""Update the meta model with the dual objective. + + The goal is for the model to match the optimal duals, i.e., + :math:`\hat f_\theta \approx f^\star`. + This can be done by training the predictions of :math:`\hat f_\theta` + to optimize the dual objective, which :math:`f^\star` also optimizes for. + The overall learning setup can thus be written as: + + .. math:: + \min_\theta\; {\mathbb E}_{(\alpha,\beta)\sim{\mathcal{D}}}\; + J(\hat f_\theta(a, b); \alpha, \beta), + + where :math:`a,b` are the probabilities of the measures :math:`\alpha,\beta`, + :math:`\mathcal{D}` is a meta distribution of optimal transport problems, + + .. math:: + -J(f; \alpha, \beta, c) := \langle f, a\rangle + \langle g, b \rangle - + \varepsilon\left\langle \exp\{f/\varepsilon\}, K\exp\{g/\varepsilon\}\right\rangle + + is the entropic dual objective, + and :math:`K_{i,j} := -C_{i,j}/\varepsilon` is the *Gibbs kernel*. + + Args: + state: Optimizer state of the meta model. + a: Probabilites of the :math:`\alpha` measure's atoms. + b: Probabilites of the :math:`\beta` measure's atoms. + + Returns: + The training loss, :math:`f`, and updated state. + """ + return self.update_impl(state, a, b) + + def init_dual_a( + self, ot_prob: linear_problem.LinearProblem, lse_mode: bool + ) -> jnp.ndarray: + # Detect if the problem is batched. 
+ assert ot_prob.a.ndim in (1, 2) and ot_prob.b.ndim in (1, 2) + vmap_a_val = 0 if ot_prob.a.ndim == 2 else None + vmap_b_val = 0 if ot_prob.b.ndim == 2 else None + + if vmap_a_val is not None or vmap_b_val is not None: + compute_f_maybe_batch = jax.vmap( + self._compute_f, in_axes=(vmap_a_val, vmap_b_val, None) + ) + else: + compute_f_maybe_batch = self._compute_f + + init_f = compute_f_maybe_batch(ot_prob.a, ot_prob.b, self.state.params) + f_u = init_f if lse_mode else ot_prob.geom.scaling_from_potential(init_f) + return f_u + + def _get_update_fn(self): + """Return the implementation (and jitted) update function.""" + + def dual_obj_loss_single(params, a, b): + f_pred = self._compute_f(a, b, params) + g_pred = self.geom.update_potential( + f_pred, jnp.zeros_like(b), jnp.log(b), 0, axis=0 + ) + g_pred = jnp.where(jnp.isfinite(g_pred), g_pred, 0.) + + ot_prob = linear_problem.LinearProblem(geom=self.geom, a=a, b=b) + dual_obj = sinkhorn.ent_reg_cost(f_pred, g_pred, ot_prob, lse_mode=True) + loss = -dual_obj + return loss, f_pred + + def loss_batch(params, a, b): + loss_fn = functools.partial(dual_obj_loss_single, params=params) + loss, f_pred = jax.vmap(loss_fn)(a=a, b=b) + return jnp.mean(loss), f_pred + + @jax.jit + def update(state, a, b): + a = jnp.atleast_2d(a) + b = jnp.atleast_2d(b) + grad_fn = jax.value_and_grad(loss_batch, has_aux=True) + (loss, init_f), grads = grad_fn(state.params, a, b) + return loss, init_f, state.apply_gradients(grads=grads) + + return update + + def _compute_f(self, a, b, params): + r"""Predict the optimal :math:`f` potential. + + Args: + a: Probabilites of the :math:`\alpha` measure's atoms. + b: Probabilites of the :math:`\beta` measure's atoms. + params: The parameters of the Meta model. + + Returns: + The :math:`f` potential. 
+ """ + return self.meta_model.apply({'params': params}, a, b) + + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + return [self.geom, self.meta_model, self.opt], { + 'rng': self.rng, + 'state': self.state + } diff --git a/ott/initializers/nn/layers.py b/ott/initializers/nn/layers.py new file mode 100644 index 000000000..4f6b25382 --- /dev/null +++ b/ott/initializers/nn/layers.py @@ -0,0 +1,39 @@ +from flax import linen as nn +from jax import numpy as jnp + +__all__ = ["MetaMLP"] + + +class MetaMLP(nn.Module): + r"""A Meta MLP potential for :class:`~ott.core.initializers.MetaInitializer`. + + This provides an MLP :math:`\hat f_\theta(a, b)` that maps from the + probabilities of the measures to the optimal dual potentials :math:`f`. + + Args: + potential_size: The dimensionality of :math:`f`. + num_hidden_units: The number of hidden units in each layer. + num_hidden_layers: The number of hidden layers. + """ + + potential_size: int + num_hidden_units: int = 512 + num_hidden_layers: int = 3 + + @nn.compact + def __call__(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + r"""Make a prediction. + + Args: + a: Probabilities of the :math:`\alpha` measure's atoms. + b: Probabilities of the :math:`\beta` measure's atoms. + + Returns: + The :math:`f` potential. 
+ """ + dtype = a.dtype + z = jnp.concatenate((a, b)) + for _ in range(self.num_hidden_layers): + z = nn.relu(nn.Dense(self.num_hidden_units, dtype=dtype)(z)) + f = nn.Dense(self.potential_size, dtype=dtype)(z) + return f diff --git a/ott/initializers/quadratic/__init__.py b/ott/initializers/quadratic/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ott/core/quad_initializers.py b/ott/initializers/quadratic/initializers.py similarity index 82% rename from ott/core/quad_initializers.py rename to ott/initializers/quadratic/initializers.py index 6bd66698d..bab5a8f54 100644 --- a/ott/core/quad_initializers.py +++ b/ott/initializers/quadratic/initializers.py @@ -1,19 +1,21 @@ -from abc import ABC, abstractmethod +import abc from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple import jax -from ott.core import linear_problems, sinkhorn_lr from ott.geometry import geometry +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn_lr if TYPE_CHECKING: - from ott.core import initializers_lr, quad_problems + from ott.initializers.linear import initializers_lr + from ott.problems.quadratic import quadratic_problem __all__ = ["QuadraticInitializer", "LRQuadraticInitializer"] @jax.tree_util.register_pytree_node_class -class BaseQuadraticInitializer(ABC): +class BaseQuadraticInitializer(abc.ABC): """Base class for quadratic initializers. Args: @@ -24,8 +26,8 @@ def __init__(self, **kwargs: Any): self._kwargs = kwargs def __call__( - self, quad_prob: 'quad_problems.QuadraticProblem', **kwargs: Any - ) -> linear_problems.LinearProblem: + self, quad_prob: 'quadratic_problem.QuadraticProblem', **kwargs: Any + ) -> linear_problem.LinearProblem: """Compute the initial linearization of a quadratic problem. Args: @@ -39,7 +41,7 @@ def __call__( geom = self._create_geometry(quad_prob, **kwargs) assert geom.shape == (n, m), f"Expected geometry of shape `{n, m}`, " \ f"found `{geom.shape}`." 
- return linear_problems.LinearProblem( + return linear_problem.LinearProblem( geom, a=quad_prob.a, b=quad_prob.b, @@ -47,9 +49,9 @@ def __call__( tau_b=quad_prob.tau_b ) - @abstractmethod + @abc.abstractmethod def _create_geometry( - self, quad_prob: 'quad_problems.QuadraticProblem', **kwargs: Any + self, quad_prob: 'quadratic_problem.QuadraticProblem', **kwargs: Any ) -> geometry.Geometry: """Compute initial geometry for linearization. @@ -58,7 +60,7 @@ def _create_geometry( kwargs: Additional keyword arguments. Returns: - The initial geometry used to initialize a linear problem. + Geometry used to initialize the linearized problem. """ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: @@ -104,9 +106,9 @@ class QuadraticInitializer(BaseQuadraticInitializer): """ def _create_geometry( - self, quad_prob: 'quad_problems.QuadraticProblem', *, epsilon: float, + self, quad_prob: 'quadratic_problem.QuadraticProblem', *, epsilon: float, **kwargs: Any - ) -> linear_problems.LinearProblem: + ) -> geometry.Geometry: """Compute initial geometry for linearization. Args: @@ -115,9 +117,9 @@ def _create_geometry( kwargs: Additional keyword arguments, unused. Returns: - The initial geometry used to initialize a linear problem. + The initial geometry used to initialize the linearized problem. 
""" - from ott.core.quad_problems import apply_cost, update_epsilon_unbalanced + from ott.problems.quadratic import quadratic_problem del kwargs unbalanced_correction = 0.0 @@ -131,14 +133,18 @@ def _create_geometry( if not quad_prob.is_balanced: transport_mass = marginal_1.sum() # Initialises epsilon for Unbalanced GW according to Sejourne et al (2021) - epsilon = update_epsilon_unbalanced(epsilon, transport_mass) + epsilon = quadratic_problem.update_epsilon_unbalanced( + epsilon, transport_mass + ) unbalanced_correction = quad_prob.cost_unbalanced_correction( tmp, marginal_1, marginal_2, epsilon ) h1, h2 = quad_prob.quad_loss - tmp = apply_cost(quad_prob.geom_xx, tmp, axis=1, fn=h1) - tmp = apply_cost(quad_prob.geom_yy, tmp.T, axis=1, fn=h2).T + tmp = quadratic_problem.apply_cost(quad_prob.geom_xx, tmp, axis=1, fn=h1) + tmp = quadratic_problem.apply_cost( + quad_prob.geom_yy, tmp.T, axis=1, fn=h2 + ).T cost_matrix = (marginal_cost.cost_matrix - tmp + unbalanced_correction) cost_matrix += quad_prob.fused_penalty * quad_prob._fused_cost_matrix @@ -158,7 +164,7 @@ def __init__(self, lr_linear_initializer: 'initializers_lr.LRInitializer'): self._linear_lr_initializer = lr_linear_initializer def _create_geometry( - self, quad_prob: 'quad_problems.QuadraticProblem', **kwargs: Any + self, quad_prob: 'quadratic_problem.QuadraticProblem', **kwargs: Any ) -> geometry.Geometry: """Compute initial geometry for linearization. 
diff --git a/ott/math/__init__.py b/ott/math/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ott/core/decomposition.py b/ott/math/decomposition.py similarity index 100% rename from ott/core/decomposition.py rename to ott/math/decomposition.py diff --git a/ott/core/fixed_point_loop.py b/ott/math/fixed_point_loop.py similarity index 99% rename from ott/core/fixed_point_loop.py rename to ott/math/fixed_point_loop.py index 3482bb031..fbfb04789 100644 --- a/ott/core/fixed_point_loop.py +++ b/ott/math/fixed_point_loop.py @@ -17,9 +17,9 @@ from typing import Any, Callable import jax +import jax.numpy as jnp import numpy as np from jax import dtypes -from jax import numpy as jnp def fixpoint_iter( diff --git a/ott/core/implicit_differentiation.py b/ott/math/implicit_differentiation.py similarity index 98% rename from ott/core/implicit_differentiation.py rename to ott/math/implicit_differentiation.py index 8b68f10ef..44fc946e5 100644 --- a/ott/core/implicit_differentiation.py +++ b/ott/math/implicit_differentiation.py @@ -18,7 +18,9 @@ import jax import jax.numpy as jnp -from ott.core import dataclasses, linear_problems, unbalanced_functions +from ott.math import unbalanced_functions +from ott.problems import linear as linear_problems +from ott.utils import dataclasses @dataclasses.register_pytree_node diff --git a/ott/geometry/matrix_square_root.py b/ott/math/matrix_square_root.py similarity index 99% rename from ott/geometry/matrix_square_root.py rename to ott/math/matrix_square_root.py index 22eac70b9..9cca55695 100644 --- a/ott/geometry/matrix_square_root.py +++ b/ott/math/matrix_square_root.py @@ -22,7 +22,7 @@ import jax.numpy as jnp import numpy as np -from ott.core import fixed_point_loop +from ott.math import fixed_point_loop @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5)) diff --git a/ott/core/potentials.py b/ott/math/potentials.py similarity index 100% rename from ott/core/potentials.py rename to ott/math/potentials.py 
diff --git a/ott/core/unbalanced_functions.py b/ott/math/unbalanced_functions.py similarity index 100% rename from ott/core/unbalanced_functions.py rename to ott/math/unbalanced_functions.py diff --git a/ott/geometry/ops.py b/ott/math/utils.py similarity index 62% rename from ott/geometry/ops.py rename to ott/math/utils.py index 1bf3248a3..6b1398aad 100644 --- a/ott/geometry/ops.py +++ b/ott/math/utils.py @@ -1,23 +1,42 @@ -# Copyright 2022 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Low level functions used within the scope of Geometric processing.""" - import functools +from typing import Optional, Union import jax +import jax.experimental.sparse as jesp import jax.numpy as jnp +__all__ = [ + "safe_log", "kl", "js", "sparse_scale", "logsumexp", + "barycentric_projection" +] + +Sparse_t = Union[jesp.CSR, jesp.CSC, jesp.COO, jesp.BCOO] + + +def safe_log(x: jnp.ndarray, *, eps: Optional[float] = None) -> jnp.ndarray: + if eps is None: + eps = jnp.finfo(x.dtype).tiny + return jnp.where(x > 0., jnp.log(x), jnp.log(eps)) + + +def kl(p: jnp.ndarray, q: jnp.ndarray) -> float: + """Kullback-Leilbler divergence.""" + return jnp.vdot(p, (safe_log(p) - safe_log(q))) + + +def js(p: jnp.ndarray, q: jnp.ndarray, *, c: float = 0.5) -> float: + """Jensen-Shannon divergence.""" + return c * (kl(p, q) + kl(q, p)) + + +def sparse_scale(c: float, mat: Sparse_t) -> Sparse_t: + """Scale a sparse matrix by a constant.""" + if isinstance(mat, jesp.BCOO): + # most feature complete, defer to original impl. + return c * mat + (data, *children), aux_data = mat.tree_flatten() + return type(mat).tree_unflatten(aux_data, [c * data] + children) + @functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2, 4)) def logsumexp(mat, axis=None, keepdims=False, b=None, return_sign=False): @@ -69,3 +88,10 @@ def logsumexp_jvp(axis, keepdims, return_sign, primals, tangents): return (lse, sign), (sign * res, jnp.zeros_like(sign)) else: return lse, res + + +@functools.partial(jax.vmap, in_axes=[0, 0, None]) +def barycentric_projection( + matrix: jnp.ndarray, y: jnp.ndarray, cost_fn +) -> jnp.ndarray: + return jax.vmap(cost_fn.barycenter, in_axes=[0, None])(matrix, y) diff --git a/ott/problems/__init__.py b/ott/problems/__init__.py new file mode 100644 index 000000000..87714fd6b --- /dev/null +++ b/ott/problems/__init__.py @@ -0,0 +1 @@ +from . 
import linear, quadratic diff --git a/ott/problems/linear/__init__.py b/ott/problems/linear/__init__.py new file mode 100644 index 000000000..1ab3d4ccb --- /dev/null +++ b/ott/problems/linear/__init__.py @@ -0,0 +1,2 @@ +from .barycenter_problem import BarycenterProblem +from .linear_problem import LinearProblem diff --git a/ott/problems/linear/barycenter_problem.py b/ott/problems/linear/barycenter_problem.py new file mode 100644 index 000000000..a34ee728f --- /dev/null +++ b/ott/problems/linear/barycenter_problem.py @@ -0,0 +1,182 @@ +# Copyright 2022 Apple Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classes defining OT problem(s) (objective function + utilities).""" +from typing import Any, Dict, Optional, Sequence, Tuple + +import jax +import jax.numpy as jnp + +from ott.geometry import costs +from ott.utils import segment + +__all__ = ["BarycenterProblem"] + + +@jax.tree_util.register_pytree_node_class +class BarycenterProblem: + """Wasserstein barycenter problem :cite:`cuturi:14`. + + Args: + y: Array of shape ``[num_total_points, ndim]`` merging the points of all + measures. Alternatively, already segmented array of shape + ``[num_measures, max_measure_size, ndim]`` can be passed. + See also :func:`~ott.utils.segment.segment_point_cloud`. + b: Array of shape ``[num_total_points,]`` containing the weights of all + the points within the measures that define the barycenter problem. 
+ Similarly as ``y``, segmented array of weights of shape + ``[num_measures, max_measure_size]`` can be passed. + If ``y`` is already pre-segmented, this array must always be specified. + weights: Array of shape ``[num_measures,]`` containing the weights of the + measures. + cost_fn: Cost function used. If `None`, + use :class:`~ott.geometry.costs.SqEuclidean` cost. + epsilon: Epsilon regularization used to solve reg-OT problems. + debiased: **Currently not implemented.** + Whether the problem is debiased, in the sense that + the regularized transportation cost of barycenter to itself will + be considered when computing gradient. Note that if the debiased option + is used, the barycenter size in + :meth:`~ott.solvers.linear.continuous_barycenter.WassersteinBarycenter.init_state` + needs to be smaller than the maximum measure size for parallelization to + operate efficiently. + kwargs: Keyword arguments for :func:`~ott.utils.segment.segment_point_cloud`. + Only used when ``y`` is not already segmented. When passing + ``segment_ids``, 2 arguments must be specified for jitting to work: + + - ``num_segments`` - the total number of measures. + - ``max_measure_size`` - maximum of support sizes of these measures. 
+ """ + + def __init__( + self, + y: jnp.ndarray, + b: Optional[jnp.ndarray] = None, + weights: Optional[jnp.ndarray] = None, + cost_fn: Optional[costs.CostFn] = None, + epsilon: Optional[float] = None, + debiased: bool = False, + **kwargs: Any, + ): + self._y = y + if y.ndim == 3 and b is None: + raise ValueError("Specify weights if `y` is already segmented.") + self._b = b + self._weights = weights + self.cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn + self.epsilon = epsilon + self.debiased = debiased + self._kwargs = kwargs + + if self._is_segmented: + # (num_measures, max_measure_size, ndim) + # (num_measures, max_measure_size) + assert self._y.shape[:2] == self._b.shape + else: + # (num_total_points, ndim) + # (num_total_points,) + assert self._b is None or self._y.shape[0] == self._b.shape[0] + + @property + def segmented_y_b(self) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Tuple of arrays containing the segmented measures and weights. + + Additional segment may be added when the problem is debiased. + + - Segmented measures of shape ``[num_measures, max_measure_size, ndim]``. + - Segmented weights of shape ``[num_measures, max_measure_size]``. 
+ """ + if self._is_segmented: + y, b = self._y, self._b + else: + y, b = segment.segment_point_cloud( + x=self._y, + a=self._b, + padding_vector=self.cost_fn.padder(self.ndim), + **self._kwargs + ) + + if self.debiased: + return self._add_slice_for_debiased(y, b) + return y, b + + def _add_slice_for_debiased( + self, y: jnp.ndarray, b: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + y, b = self._y, self._b + _, n, ndim = y.shape # (num_measures, max_measure_size, ndim) + # yapf: disable + y = jnp.concatenate((y, jnp.zeros((1, n, ndim))), axis=0) + b = jnp.concatenate((b, jnp.zeros((1, n))), axis=0) + # yapf: enable + return y, b + + @property + def flattened_y(self) -> jnp.ndarray: + """Array of shape ``[num_measures * (N_1 + N_2 + ...), ndim]``.""" + if self._is_segmented: + return self._y.reshape((-1, self._y.shape[-1])) + return self._y + + @property + def flattened_b(self) -> Optional[jnp.ndarray]: + """Array of shape ``[num_measures * (N_1 + N_2 + ...),]``.""" + return None if self._b is None else self._b.ravel() + + @property + def num_measures(self) -> int: + """Number of measures.""" + return self.segmented_y_b[0].shape[0] + + @property + def max_measure_size(self) -> int: + """Maximum number of points across all measures.""" + return self.segmented_y_b[0].shape[1] + + @property + def ndim(self) -> int: + """Number of dimensions of ``y``.""" + return self._y.shape[-1] + + @property + def weights(self) -> jnp.ndarray: + """Barycenter weights of shape ``[num_measures,]`` that sum to 1.""" + if self._weights is None: + weights = jnp.ones((self.num_measures,)) / self.num_measures + else: + # Check that the number of measures coincides with the weights' size. + assert self._weights.shape[0] == self.num_measures + # By default, we assume that weights sum to 1, and enforce this if needed. 
+ weights = self._weights / jnp.sum(self._weights) + if self.debiased: + weights = jnp.concatenate((weights, jnp.array([-0.5]))) + return weights + + @property + def _is_segmented(self) -> bool: + return self._y.ndim == 3 + + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + return ([self._y, self._b, self._weights], { + 'cost_fn': self.cost_fn, + 'epsilon': self.epsilon, + 'debiased': self.debiased, + **self._kwargs, + }) + + @classmethod + def tree_unflatten( + cls, aux_data: Dict[str, Any], children: Sequence[Any] + ) -> "BarycenterProblem": + y, b, weights = children + return cls(y=y, b=b, weights=weights, **aux_data) diff --git a/ott/core/linear_problems.py b/ott/problems/linear/linear_problem.py similarity index 99% rename from ott/core/linear_problems.py rename to ott/problems/linear/linear_problem.py index 4357267ff..97c54d4e8 100644 --- a/ott/core/linear_problems.py +++ b/ott/problems/linear/linear_problem.py @@ -20,6 +20,8 @@ from ott.geometry import geometry +__all__ = ["LinearProblem"] + MarginalFunc = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] TransportAppFunc = Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray, int], jnp.ndarray] diff --git a/ott/problems/quadratic/__init__.py b/ott/problems/quadratic/__init__.py new file mode 100644 index 000000000..48c54ef97 --- /dev/null +++ b/ott/problems/quadratic/__init__.py @@ -0,0 +1,3 @@ +from . 
import quadratic_costs +from .barycenter_problem import GWBarycenterProblem +from .quadratic_problem import QuadraticProblem diff --git a/ott/core/bar_problems.py b/ott/problems/quadratic/barycenter_problem.py similarity index 57% rename from ott/core/bar_problems.py rename to ott/problems/quadratic/barycenter_problem.py index 68c29f325..089e645de 100644 --- a/ott/core/bar_problems.py +++ b/ott/problems/quadratic/barycenter_problem.py @@ -1,191 +1,21 @@ -# Copyright 2022 Apple Inc -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Classes defining OT problem(s) (objective function + utilities).""" import functools from typing import Any, Dict, Optional, Sequence, Tuple, Union import jax -import jax.numpy as jnp +from jax import numpy as jnp from typing_extensions import Literal -from ott.core import quad_problems, segment from ott.geometry import costs, geometry, pointcloud +from ott.math import utils as mu +from ott.problems.linear import barycenter_problem +from ott.problems.quadratic import quadratic_costs, quadratic_problem +from ott.utils import segment -__all__ = ["BarycenterProblem", "GWBarycenterProblem", "barycentric_projection"] +__all__ = ["GWBarycenterProblem"] @jax.tree_util.register_pytree_node_class -class BarycenterProblem: - """Wasserstein barycenter problem :cite:`cuturi:14`. - - Args: - y: Array of shape ``[num_total_points, ndim]`` merging the points of all - measures. 
Alternatively, already segmented array of shape - ``[num_measures, max_measure_size, ndim]`` can be passed. - See also :func:`~ott.core.segment.segment_point_cloud`. - b: Array of shape ``[num_total_points,]`` containing the weights of all - the points within the measures that define the barycenter problem. - Similarly as ``y``, segmented array of weights of shape - ``[num_measures, max_measure_size]`` can be passed. - If ``y`` is already pre-segmented, this array must be always specified. - weights: Array of shape ``[num_measures,]`` containing the weights of the - measures. - cost_fn: Cost function used. If `None`, - use :class:`~ott.geometry.costs.SqEuclidean` cost. - epsilon: Epsilon regularization used to solve reg-OT problems. - debiased: **Currently not implemented.** - Whether the problem is debiased, in the sense that - the regularized transportation cost of barycenter to itself will - be considered when computing gradient. Note that if the debiased option - is used, the barycenter size in - :meth:`~ott.core.continuous_barycenter.WassersteinBarycenter.init_state` - needs to be smaller than the maximum measure size for parallelization to - operate efficiently. - kwargs: Keyword arguments :func:`~ott.core.segment.segment_point_cloud`. - Only used when ``y`` is not already segmented. When passing - ``segment_ids``, 2 arguments must be specified for jitting to work: - - - ``num_segments`` - the total number of measures. - - ``max_measure_size`` - maximum of support sizes of these measures. 
- """ - - def __init__( - self, - y: jnp.ndarray, - b: Optional[jnp.ndarray] = None, - weights: Optional[jnp.ndarray] = None, - cost_fn: Optional[costs.CostFn] = None, - epsilon: Optional[float] = None, - debiased: bool = False, - **kwargs: Any, - ): - self._y = y - if y.ndim == 3 and b is None: - raise ValueError("Specify weights if `y` is already segmented.") - self._b = b - self._weights = weights - self.cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn - self.epsilon = epsilon - self.debiased = debiased - self._kwargs = kwargs - - if self._is_segmented: - # (num_measures, max_measure_size, ndim) - # (num_measures, max_measure_size) - assert self._y.shape[:2] == self._b.shape - else: - # (num_total_points, ndim) - # (num_total_points,) - assert self._b is None or self._y.shape[0] == self._b.shape[0] - - @property - def segmented_y_b(self) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Tuple of arrays containing the segmented measures and weights. - - Additional segment may be added when the problem is debiased. - - - Segmented measures of shape ``[num_measures, max_measure_size, ndim]``. - - Segmented weights of shape ``[num_measures, max_measure_size]``. 
- """ - if self._is_segmented: - y, b = self._y, self._b - else: - y, b = segment.segment_point_cloud( - x=self._y, - a=self._b, - padding_vector=self.cost_fn.padder(self.ndim), - **self._kwargs - ) - - if self.debiased: - return self._add_slice_for_debiased(y, b) - return y, b - - def _add_slice_for_debiased( - self, y: jnp.ndarray, b: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - y, b = self._y, self._b - _, n, ndim = y.shape # (num_measures, max_measure_size, ndim) - # yapf: disable - y = jnp.concatenate((y, jnp.zeros((1, n, ndim))), axis=0) - b = jnp.concatenate((b, jnp.zeros((1, n))), axis=0) - # yapf: enable - return y, b - - @property - def flattened_y(self) -> jnp.ndarray: - """Array of shape ``[num_measures * (N_1 + N_2 + ...), ndim]``.""" - if self._is_segmented: - return self._y.reshape((-1, self._y.shape[-1])) - return self._y - - @property - def flattened_b(self) -> Optional[jnp.ndarray]: - """Array of shape ``[num_measures * (N_1 + N_2 + ...),]``.""" - return None if self._b is None else self._b.ravel() - - @property - def num_measures(self) -> int: - """Number of measures.""" - return self.segmented_y_b[0].shape[0] - - @property - def max_measure_size(self) -> int: - """Maximum number of points across all measures.""" - return self.segmented_y_b[0].shape[1] - - @property - def ndim(self) -> int: - """Number of dimensions of ``y``.""" - return self._y.shape[-1] - - @property - def weights(self) -> jnp.ndarray: - """Barycenter weights of shape ``[num_measures,]`` that sum to 1.""" - if self._weights is None: - weights = jnp.ones((self.num_measures,)) / self.num_measures - else: - # Check that the number of measures coincides with the weights' size. - assert self._weights.shape[0] == self.num_measures - # By default, we assume that weights sum to 1, and enforce this if needed. 
- weights = self._weights / jnp.sum(self._weights) - if self.debiased: - weights = jnp.concatenate((weights, jnp.array([-0.5]))) - return weights - - @property - def _is_segmented(self) -> bool: - return self._y.ndim == 3 - - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: - return ([self._y, self._b, self._weights], { - 'cost_fn': self.cost_fn, - 'epsilon': self.epsilon, - 'debiased': self.debiased, - **self._kwargs, - }) - - @classmethod - def tree_unflatten( - cls, aux_data: Dict[str, Any], children: Sequence[Any] - ) -> "BarycenterProblem": - y, b, weights = children - return cls(y=y, b=b, weights=weights, **aux_data) - - -@jax.tree_util.register_pytree_node_class -class GWBarycenterProblem(BarycenterProblem): +class GWBarycenterProblem(barycenter_problem.BarycenterProblem): """(Fused) Gromov-Wasserstein barycenter problem :cite:`peyre:16,vayer:19`. Args: @@ -195,7 +25,7 @@ class GWBarycenterProblem(BarycenterProblem): See also :func:`~ott.core.segment.segment_point_cloud`. b: Array of shape ``[num_total_points,]`` containing the weights of all the points within the measures that define the barycenter problem. - Similarly as ``y``, segmented array of weights of shape + Similarly, as ``y``, segmented array of weights of shape ``[num_measures, max_measure_size]`` can be passed. If ``y`` is already pre-segmented, this array must be passed. weights: Array of shape ``[num_measures,]`` containing the weights of the @@ -277,7 +107,7 @@ def project( y: jnp.ndarray, b: jnp.ndarray, transport: jnp.ndarray, - fn: Optional[quad_problems.Loss], + fn: Optional[quadratic_costs.Loss], ) -> jnp.ndarray: geom = self._create_y_geometry(y, mask=b > 0.) 
fn, lin = (None, True) if fn is None else (fn.func, fn.is_linear) @@ -332,7 +162,8 @@ def update_features(self, transports: jnp.ndarray, if self._loss_name == "sqeucl": cost = costs.SqEuclidean() return jnp.sum( - weights * barycentric_projection(transports, y_fused, cost), axis=0 + weights * mu.barycentric_projection(transports, y_fused, cost), + axis=0 ) raise NotImplementedError(self._loss_name) @@ -395,7 +226,7 @@ def _create_problem( y: jnp.ndarray, b: jnp.ndarray, f: Optional[jnp.ndarray] = None - ) -> quad_problems.QuadraticProblem: + ) -> quadratic_problem.QuadraticProblem: # TODO(michalk8): in the future, mask in the problem for convenience? bary_mask = state.a > 0. y_mask = b > 0. @@ -411,7 +242,7 @@ def _create_problem( else: geom_xy = None - return quad_problems.QuadraticProblem( + return quadratic_problem.QuadraticProblem( geom_xx=geom_xx, geom_yy=geom_yy, geom_xy=geom_xy, @@ -448,16 +279,16 @@ def ndim_fused(self) -> Optional[int]: return self._y_fused.shape[-1] if self.is_fused else None @property - def gw_loss(self) -> quad_problems.GWLoss: + def gw_loss(self) -> quadratic_costs.GWLoss: """Gromov-Wasserstein loss.""" # TODO(michalk8): custom losses would require inverting some fns; # `https://jax.readthedocs.io/en/latest/notebooks/ some fns; # Writing_custom_interpreters_in_Jax.html#your-first-interpreter-invert` # might be useful if self._loss_name == 'sqeucl': - return quad_problems.make_square_loss() + return quadratic_costs.make_square_loss() if self._loss_name == 'kl': - return quad_problems.make_kl_loss() + return quadratic_costs.make_kl_loss() raise NotImplementedError( f"Loss `{self._loss_name}` is not yet implemented." 
) @@ -481,10 +312,3 @@ def tree_unflatten( return cls( y=y, b=b, weights=weights, costs=costs, y_fused=y_fused, **aux_data ) - - -@functools.partial(jax.vmap, in_axes=[0, 0, None]) -def barycentric_projection( - matrix: jnp.ndarray, y: jnp.ndarray, cost_fn -) -> jnp.ndarray: - return jax.vmap(cost_fn.barycenter, in_axes=[0, None])(matrix, y) diff --git a/ott/problems/quadratic/quadratic_costs.py b/ott/problems/quadratic/quadratic_costs.py new file mode 100644 index 000000000..8de1b398d --- /dev/null +++ b/ott/problems/quadratic/quadratic_costs.py @@ -0,0 +1,34 @@ +from typing import Callable, NamedTuple + +import jax +import jax.numpy as jnp + +__all__ = ["make_square_loss", "make_kl_loss"] + + +class Loss(NamedTuple): + func: Callable[[jnp.ndarray], jnp.ndarray] + is_linear: bool + + +class GWLoss(NamedTuple): + f1: Loss + f2: Loss + h1: Loss + h2: Loss + + +def make_square_loss() -> GWLoss: + f1 = Loss(lambda x: x ** 2, is_linear=False) + f2 = Loss(lambda y: y ** 2, is_linear=False) + h1 = Loss(lambda x: x, is_linear=True) + h2 = Loss(lambda y: 2.0 * y, is_linear=True) + return GWLoss(f1, f2, h1, h2) + + +def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss: + f1 = Loss(lambda x: -jax.scipy.special.entr(x) - x, is_linear=False) + f2 = Loss(lambda y: y, is_linear=True) + h1 = Loss(lambda x: x, is_linear=True) + h2 = Loss(lambda y: jnp.log(jnp.clip(y, clipping_value)), is_linear=False) + return GWLoss(f1, f2, h1, h2) diff --git a/ott/core/quad_problems.py b/ott/problems/quadratic/quadratic_problem.py similarity index 90% rename from ott/core/quad_problems.py rename to ott/problems/quadratic/quadratic_problem.py index e4a649ecf..03d0bdbb7 100644 --- a/ott/core/quad_problems.py +++ b/ott/problems/quadratic/quadratic_problem.py @@ -13,62 +13,19 @@ # limitations under the License. 
"""Classes defining OT problem(s) (objective function + utilities).""" -from typing import Callable, NamedTuple, Optional, Tuple, Union +from typing import Optional, Tuple, Union import jax import jax.numpy as jnp +from typing_extensions import Literal -# Because Protocol is not available in Python < 3.8 -from typing_extensions import Literal, Protocol - -from ott.core import linear_problems, sinkhorn_lr from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud +from ott.problems.linear import linear_problem +from ott.problems.quadratic import quadratic_costs +from ott.solvers.linear import sinkhorn_lr +from ott.typing import Transport - -class Transport(Protocol): - """Interface for the solution of a transport problem. - - Classes implementing those function do not have to inherit from it, the - class can however be used in type hints to support duck typing. - """ - - @property - def matrix(self) -> jnp.ndarray: - ... - - def apply(self, inputs: jnp.ndarray, axis: int) -> jnp.ndarray: - ... - - def marginal(self, axis: int = 0) -> jnp.ndarray: - ... 
- - -class Loss(NamedTuple): - func: Callable[[jnp.ndarray], jnp.ndarray] - is_linear: bool - - -class GWLoss(NamedTuple): - f1: Loss - f2: Loss - h1: Loss - h2: Loss - - -def make_square_loss() -> GWLoss: - f1 = Loss(lambda x: x ** 2, is_linear=False) - f2 = Loss(lambda y: y ** 2, is_linear=False) - h1 = Loss(lambda x: x, is_linear=True) - h2 = Loss(lambda y: 2.0 * y, is_linear=True) - return GWLoss(f1, f2, h1, h2) - - -def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss: - f1 = Loss(lambda x: -jax.scipy.special.entr(x) - x, is_linear=False) - f2 = Loss(lambda y: y, is_linear=True) - h1 = Loss(lambda x: x, is_linear=True) - h2 = Loss(lambda y: jnp.log(jnp.clip(y, clipping_value)), is_linear=False) - return GWLoss(f1, f2, h1, h2) +__all__ = ["QuadraticProblem"] @jax.tree_util.register_pytree_node_class @@ -142,7 +99,7 @@ def __init__( scale_cost: Optional[Union[bool, float, str]] = False, a: Optional[jnp.ndarray] = None, b: Optional[jnp.ndarray] = None, - loss: Union[Literal['sqeucl', 'kl'], GWLoss] = 'sqeucl', + loss: Union[Literal['sqeucl', 'kl'], quadratic_costs.GWLoss] = 'sqeucl', tau_a: Optional[float] = 1.0, tau_b: Optional[float] = 1.0, gw_unbalanced_correction: bool = True, @@ -167,9 +124,9 @@ def __init__( self._loss_name = loss if self._loss_name == 'sqeucl': - self.loss = make_square_loss() + self.loss = quadratic_costs.make_square_loss() elif loss == 'kl': - self.loss = make_kl_loss() + self.loss = quadratic_costs.make_kl_loss() else: self.loss = loss @@ -312,7 +269,7 @@ def update_linearization( transport: Transport, epsilon: Optional[Union[epsilon_scheduler.Epsilon, float]] = None, old_transport_mass: float = 1.0 - ) -> linear_problems.LinearProblem: + ) -> linear_problem.LinearProblem: """Update linearization of GW problem by updating cost matrix. 
If the problem is balanced (``tau_a = 1.0 and tau_b = 1.0``), the equation @@ -363,15 +320,15 @@ def update_linearization( cost_matrix += self.fused_penalty * self._fused_cost_matrix * rescale_factor geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon) - return linear_problems.LinearProblem( + return linear_problem.LinearProblem( geom, self.a, self.b, tau_a=self.tau_a, tau_b=self.tau_b ) def update_lr_linearization( self, lr_sink: sinkhorn_lr.LRSinkhornOutput - ) -> linear_problems.LinearProblem: + ) -> linear_problem.LinearProblem: """Update a Quad problem linearization using a LR Sinkhorn.""" - return linear_problems.LinearProblem( + return linear_problem.LinearProblem( self.update_lr_geom(lr_sink), self.a, self.b, @@ -491,12 +448,12 @@ def is_low_rank(self) -> bool: ) @property - def linear_loss(self) -> Tuple[Loss, Loss]: + def linear_loss(self) -> Tuple[quadratic_costs.Loss, quadratic_costs.Loss]: """Linear part of the Gromov-Wasserstein loss.""" return self.loss.f1, self.loss.f2 @property - def quad_loss(self) -> Tuple[Loss, Loss]: + def quad_loss(self) -> Tuple[quadratic_costs.Loss, quadratic_costs.Loss]: """Quadratic part of the Gromov-Wasserstein loss.""" return self.loss.h1, self.loss.h2 @@ -524,7 +481,9 @@ def tree_unflatten(cls, aux_data, children): return cls(*geoms, a=a, b=b, **aux_data) -def update_epsilon_unbalanced(epsilon, transport_mass): +def update_epsilon_unbalanced( + epsilon: Union[float, epsilon_scheduler.Epsilon], transport_mass: float +) -> epsilon_scheduler.Epsilon: updated_epsilon = epsilon_scheduler.Epsilon.make(epsilon) updated_epsilon._scale_epsilon = ( updated_epsilon._scale_epsilon * transport_mass @@ -533,6 +492,7 @@ def update_epsilon_unbalanced(epsilon, transport_mass): def apply_cost( - geom: geometry.Geometry, arr: jnp.ndarray, *, axis: int, fn: Loss + geom: geometry.Geometry, arr: jnp.ndarray, *, axis: int, + fn: quadratic_costs.Loss ) -> jnp.ndarray: return geom.apply_cost(arr, axis=axis, fn=fn.func, 
is_linear=fn.is_linear) diff --git a/ott/solvers/__init__.py b/ott/solvers/__init__.py new file mode 100644 index 000000000..15cfac006 --- /dev/null +++ b/ott/solvers/__init__.py @@ -0,0 +1 @@ +from . import linear, nn, quadratic diff --git a/ott/solvers/linear/__init__.py b/ott/solvers/linear/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ott/core/anderson.py b/ott/solvers/linear/acceleration.py similarity index 62% rename from ott/core/anderson.py rename to ott/solvers/linear/acceleration.py index e20b6d0ec..8b638abc9 100644 --- a/ott/core/anderson.py +++ b/ott/solvers/linear/acceleration.py @@ -1,25 +1,14 @@ -# Copyright 2022 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tools for Anderson acceleration.""" -from typing import Any +from typing import TYPE_CHECKING import jax -import jax.numpy as jnp +from jax import numpy as jnp -from ott.core import dataclasses +if TYPE_CHECKING: + from ott.solvers.linear import sinkhorn -SinkhornState = Any +from ott.utils import dataclasses + +__all__ = ["AndersonAcceleration", "Momentum"] @dataclasses.register_pytree_node @@ -30,7 +19,7 @@ class AndersonAcceleration: refresh_every: int = 1 # Recompute interpolation periodically. ridge_identity: float = 1e-2 # Ridge used in the linear system. 
- def extrapolation(self, xs, fxs): + def extrapolation(self, xs: jnp.ndarray, fxs: jnp.ndarray) -> jnp.ndarray: """Compute Anderson extrapolation from past observations.""" # Remove -inf values to instantiate quadratic problem. All others # remain since they might be caused by a valid issue. @@ -53,7 +42,9 @@ def extrapolation(self, xs, fxs): combination = jnp.sum(fxs * weights[None, :], axis=1) return jnp.where(jnp.isfinite(combination), combination, -jnp.inf) - def update(self, state: SinkhornState, iteration: int, pb, lse_mode: bool): + def update( + self, state: 'sinkhorn.SinkhornState', iteration: int, pb, lse_mode: bool + ) -> 'sinkhorn.SinkhornState': """Anderson acceleration update. When using Anderson acceleration, first update the dual variable f_u with @@ -101,12 +92,63 @@ def update(self, state: SinkhornState, iteration: int, pb, lse_mode: bool): ) return state.set(fu=fu, old_fus=old_fus) - def init_maps(self, pb, state): + def init_maps( + self, pb, state: 'sinkhorn.SinkhornState' + ) -> 'sinkhorn.SinkhornState': """Initialize log matrix used in Anderson acceleration with nan values.""" fus = jnp.ones((pb.geom.shape[0], self.memory)) * jnp.nan return state.set(old_fus=fus, old_mapped_fus=fus) - def update_history(self, state, pb, lse_mode: bool): + def update_history( + self, state: 'sinkhorn.SinkhornState', pb, lse_mode: bool + ) -> 'sinkhorn.SinkhornState': f = state.fu if lse_mode else pb.geom.potential_from_scaling(state.fu) mapped = jnp.concatenate((state.old_mapped_fus[:, 1:], f[:, None]), axis=1) return state.set(old_mapped_fus=mapped) + + +@dataclasses.register_pytree_node +class Momentum: + """Momentum for Sinkhorn updates, either constant or adaptive.""" + + start: int = 0 + error_threshold: float = jnp.inf + value: float = 1.0 + inner_iterations: int = 1 + + def weight(self, state: 'sinkhorn.SinkhornState', iteration: int) -> float: + """Compute momentum term if needed, using previously seen errors.""" + if self.start == 0: + return 
self.value + idx = self.start // self.inner_iterations + + weight = jax.lax.cond( + jnp.logical_and( + iteration >= self.start, + state.errors[idx - 1, -1] < self.error_threshold + ), lambda state: self.lehmann(state), lambda state: self.value, state + ) + return weight + + def lehmann(self, state: 'sinkhorn.SinkhornState') -> float: + """Momentum formula :cite:`lehmann:21`, eq. 5.""" + idx = self.start // self.inner_iterations + error_ratio = jnp.minimum( + state.errors[idx - 1, -1] / state.errors[idx - 2, -1], 0.99 + ) + power = 1.0 / self.inner_iterations + return 2.0 / (1.0 + jnp.sqrt(1.0 - error_ratio ** power)) + + def __call__( + self, + weight: float, + value: jnp.ndarray, + new_value: jnp.ndarray, + lse_mode: bool = True + ) -> jnp.ndarray: + if lse_mode: + value = jnp.where(jnp.isfinite(value), value, 0.0) + return (1.0 - weight) * value + weight * new_value + else: + value = jnp.where(value > 0.0, value, 1.0) + return value ** (1.0 - weight) * new_value ** weight diff --git a/ott/core/continuous_barycenter.py b/ott/solvers/linear/continuous_barycenter.py similarity index 89% rename from ott/core/continuous_barycenter.py rename to ott/solvers/linear/continuous_barycenter.py index 8a328264c..5bc105321 100644 --- a/ott/core/continuous_barycenter.py +++ b/ott/solvers/linear/continuous_barycenter.py @@ -20,8 +20,11 @@ import jax import jax.numpy as jnp -from ott.core import bar_problems, fixed_point_loop, linear_problems, was_solver from ott.geometry import pointcloud +from ott.math import fixed_point_loop +from ott.math import utils as mu +from ott.problems.linear import barycenter_problem, linear_problem +from ott.utils import was_solver __all__ = ["BarycenterState", "WassersteinBarycenter"] @@ -53,7 +56,7 @@ def set(self, **kwargs: Any) -> 'BarycenterState': return self._replace(**kwargs) def update( - self, iteration: int, bar_prob: bar_problems.BarycenterProblem, + self, iteration: int, bar_prob: barycenter_problem.BarycenterProblem, linear_ot_solver: 
Any, store_errors: bool ) -> 'BarycenterState': seg_y, seg_b = bar_prob.segmented_y_b @@ -63,7 +66,7 @@ def solve_linear_ot( a: Optional[jnp.ndarray], x: jnp.ndarray, b: jnp.ndarray, y: jnp.ndarray ): out = linear_ot_solver( - linear_problems.LinearProblem( + linear_problem.LinearProblem( pointcloud.PointCloud( x, y, @@ -101,7 +104,7 @@ def solve_linear_ot( # Approximation of barycenter as barycenter of barycenters per measure. - barycenters_per_measure = bar_problems.barycentric_projection( + barycenters_per_measure = mu.barycentric_projection( matrices, seg_y, bar_prob.cost_fn ) @@ -123,7 +126,7 @@ class WassersteinBarycenter(was_solver.WassersteinSolver): def __call__( self, - bar_prob: bar_problems.BarycenterProblem, + bar_prob: barycenter_problem.BarycenterProblem, bar_size: int = 100, x_init: Optional[jnp.ndarray] = None, rng: int = 0 @@ -134,7 +137,7 @@ def __call__( def init_state( self, - bar_prob: bar_problems.BarycenterProblem, + bar_prob: barycenter_problem.BarycenterProblem, bar_size: int, x_init: Optional[jnp.ndarray] = None, # TODO(michalk8): change the API to pass the PRNG key directly @@ -183,18 +186,20 @@ def init_state( ) def output_from_state(self, state: BarycenterState) -> BarycenterState: + # TODO(michalk8): create an output variable to match rest of the framework return state def iterations( solver: WassersteinBarycenter, bar_size: int, - bar_prob: bar_problems.BarycenterProblem, x_init: jnp.ndarray, rng: int + bar_prob: barycenter_problem.BarycenterProblem, x_init: jnp.ndarray, + rng: int ) -> BarycenterState: """Jittable Wasserstein barycenter outer loop.""" def cond_fn( iteration: int, constants: Tuple[WassersteinBarycenter, - bar_problems.BarycenterProblem], + barycenter_problem.BarycenterProblem], state: BarycenterState ) -> bool: solver, _ = constants @@ -202,7 +207,7 @@ def cond_fn( def body_fn( iteration, constants: Tuple[WassersteinBarycenter, - bar_problems.BarycenterProblem], + barycenter_problem.BarycenterProblem], state: 
BarycenterState, compute_error: bool ) -> BarycenterState: del compute_error # Always assumed True diff --git a/ott/core/discrete_barycenter.py b/ott/solvers/linear/discrete_barycenter.py similarity index 94% rename from ott/core/discrete_barycenter.py rename to ott/solvers/linear/discrete_barycenter.py index dd53806df..6373c9657 100644 --- a/ott/core/discrete_barycenter.py +++ b/ott/solvers/linear/discrete_barycenter.py @@ -13,23 +13,29 @@ # limitations under the License. # Lint as: python3 -"""Implementation of Janati+(2020) Wasserstein barycenter algorithm.""" +"""Implementation of :cite:`janati:20` Wasserstein barycenter algorithm.""" -import collections import functools -from typing import Optional, Sequence +from typing import NamedTuple, Optional, Sequence import jax import jax.numpy as jnp -from ott.core import fixed_point_loop, sinkhorn from ott.geometry import geometry +from ott.math import fixed_point_loop +from ott.solvers.linear import sinkhorn -SinkhornBarycenterOutput = collections.namedtuple( - 'Barycenter', ['f', 'g', 'histogram', 'errors'] -) +__all__ = ["SinkhornBarycenterOutput", "discrete_barycenter"] +class SinkhornBarycenterOutput(NamedTuple): + f: jnp.ndarray + g: jnp.ndarray + histogram: jnp.ndarray + errors: jnp.ndarray + + +# TODO(michalk8): refactor as a solver def discrete_barycenter( geom: geometry.Geometry, a: jnp.ndarray, diff --git a/ott/core/sinkhorn.py b/ott/solvers/linear/sinkhorn.py similarity index 94% rename from ott/core/sinkhorn.py rename to ott/solvers/linear/sinkhorn.py index aee00d797..28f41657d 100644 --- a/ott/core/sinkhorn.py +++ b/ott/solvers/linear/sinkhorn.py @@ -21,14 +21,16 @@ import numpy as np from typing_extensions import Literal -from ott.core import anderson as anderson_lib -from ott.core import fixed_point_loop -from ott.core import implicit_differentiation as implicit_lib -from ott.core import initializers as init_lib -from ott.core import linear_problems -from ott.core import momentum as momentum_lib -from 
ott.core import potentials, unbalanced_functions from ott.geometry import geometry +from ott.initializers.linear import initializers as init_lib +from ott.math import fixed_point_loop +from ott.math import implicit_differentiation as implicit_lib +from ott.math import potentials, unbalanced_functions +from ott.problems.linear import linear_problem + +__all__ = ["Sinkhorn", "SinkhornOutput"] + +from ott.solvers.linear import acceleration class SinkhornState(NamedTuple): @@ -45,19 +47,19 @@ def set(self, **kwargs: Any) -> 'SinkhornState': return self._replace(**kwargs) def solution_error( - self, ot_prob: linear_problems.LinearProblem, norm_error: Sequence[int], + self, ot_prob: linear_problem.LinearProblem, norm_error: Sequence[int], lse_mode: bool ) -> jnp.ndarray: return solution_error(self.fu, self.gv, ot_prob, norm_error, lse_mode) def ent_reg_cost( - self, ot_prob: linear_problems.LinearProblem, lse_mode: bool + self, ot_prob: linear_problem.LinearProblem, lse_mode: bool ) -> float: return ent_reg_cost(self.fu, self.gv, ot_prob, lse_mode) def solution_error( - f_u: jnp.ndarray, g_v: jnp.ndarray, ot_prob: linear_problems.LinearProblem, + f_u: jnp.ndarray, g_v: jnp.ndarray, ot_prob: linear_problem.LinearProblem, norm_error: Sequence[int], lse_mode: bool ) -> jnp.ndarray: """Given two potential/scaling solutions, computes deviation to optimality. @@ -142,7 +144,7 @@ def marginal_error( def ent_reg_cost( - f: jnp.ndarray, g: jnp.ndarray, ot_prob: linear_problems.LinearProblem, + f: jnp.ndarray, g: jnp.ndarray, ot_prob: linear_problem.LinearProblem, lse_mode: bool ) -> float: r"""Compute objective of Sinkhorn for OT problem given dual solutions. 
@@ -209,14 +211,14 @@ class SinkhornOutput(NamedTuple): g: Optional[jnp.ndarray] = None errors: Optional[jnp.ndarray] = None reg_ot_cost: Optional[float] = None - ot_prob: Optional[linear_problems.LinearProblem] = None + ot_prob: Optional[linear_problem.LinearProblem] = None def set(self, **kwargs: Any) -> 'SinkhornOutput': """Return a copy of self, with potential overwrites.""" return self._replace(**kwargs) def set_cost( - self, ot_prob: linear_problems.LinearProblem, lse_mode: bool, + self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, use_danskin: bool ) -> 'SinkhornOutput': f = jax.lax.stop_gradient(self.f) if use_danskin else self.f @@ -225,7 +227,7 @@ def set_cost( @property def linear(self) -> bool: - return isinstance(self.ot_prob, linear_problems.LinearProblem) + return isinstance(self.ot_prob, linear_problem.LinearProblem) @property def geom(self) -> geometry.Geometry: @@ -352,8 +354,8 @@ def __init__( inner_iterations: int = 10, min_iterations: int = 0, max_iterations: int = 2000, - momentum: Optional[momentum_lib.Momentum] = None, - anderson: Optional[anderson_lib.AndersonAcceleration] = None, + momentum: Optional[acceleration.Momentum] = None, + anderson: Optional[acceleration.AndersonAcceleration] = None, parallel_dual_updates: bool = False, use_danskin: Optional[bool] = None, implicit_diff: Optional[implicit_lib.ImplicitDiff @@ -373,20 +375,20 @@ def __init__( self.implicit_diff = implicit_diff if momentum is not None: - self.momentum = momentum_lib.Momentum( + self.momentum = acceleration.Momentum( momentum.start, momentum.error_threshold, momentum.value, self.inner_iterations ) else: # Use no momentum if using Anderson or unrolling. if self.anderson is not None or self.implicit_diff is None: - self.momentum = momentum_lib.Momentum( + self.momentum = acceleration.Momentum( inner_iterations=self.inner_iterations ) # Use adaptive momentum from 300th iteration. Only do so # if error is already below threshold below. 
else: - self.momentum = momentum_lib.Momentum( + self.momentum = acceleration.Momentum( start=300, error_threshold=1e-2, inner_iterations=self.inner_iterations @@ -404,7 +406,7 @@ def __init__( implicit_lib.ImplicitDiff() if self.implicit_diff is None else self.implicit_diff ) - self.momentum = momentum_lib.Momentum( + self.momentum = acceleration.Momentum( inner_iterations=self.inner_iterations ) @@ -415,7 +417,7 @@ def __init__( def __call__( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]] = (None, None), ) -> SinkhornOutput: """Run Sinkhorn algorithm. @@ -436,7 +438,7 @@ def __call__( return run_fn(ot_prob, self, (init_dual_a, init_dual_b)) def lse_step( - self, ot_prob: linear_problems.LinearProblem, state: SinkhornState, + self, ot_prob: linear_problem.LinearProblem, state: SinkhornState, iteration: int ) -> SinkhornState: """Sinkhorn LSE update.""" @@ -458,7 +460,7 @@ def lse_step( return state.set(fu=fu, gv=gv) def kernel_step( - self, ot_prob: linear_problems.LinearProblem, state: SinkhornState, + self, ot_prob: linear_problem.LinearProblem, state: SinkhornState, iteration: int ) -> SinkhornState: """Sinkhorn multiplicative update.""" @@ -478,7 +480,7 @@ def kernel_step( return state.set(fu=fu, gv=gv) def one_iteration( - self, ot_prob: linear_problems.LinearProblem, state: SinkhornState, + self, ot_prob: linear_problem.LinearProblem, state: SinkhornState, iteration: int, compute_error: bool ) -> SinkhornState: """Carries out sinkhorn iteration. 
@@ -544,8 +546,8 @@ def outer_iterations(self) -> int: return np.ceil(self.max_iterations / self.inner_iterations).astype(int) def init_state( - self, ot_prob: linear_problems.LinearProblem, init: Tuple[jnp.ndarray, - jnp.ndarray] + self, ot_prob: linear_problem.LinearProblem, init: Tuple[jnp.ndarray, + jnp.ndarray] ) -> SinkhornState: """Return the initial state of the loop.""" fu, gv = init @@ -555,7 +557,7 @@ def init_state( return self.anderson.init_maps(ot_prob, state) if self.anderson else state def output_from_state( - self, ot_prob: linear_problems.LinearProblem, state: SinkhornState + self, ot_prob: linear_problem.LinearProblem, state: SinkhornState ) -> SinkhornOutput: """Create an output from a loop state. @@ -625,7 +627,7 @@ def tree_unflatten(cls, aux_data, children): def run( - ot_prob: linear_problems.LinearProblem, solver: Sinkhorn, + ot_prob: linear_problem.LinearProblem, solver: Sinkhorn, init: Tuple[jnp.ndarray, ...] ) -> SinkhornOutput: """Run loop of the solver, outputting a state upgraded to an output.""" @@ -638,20 +640,20 @@ def run( def iterations( - ot_prob: linear_problems.LinearProblem, solver: Sinkhorn, + ot_prob: linear_problem.LinearProblem, solver: Sinkhorn, init: Tuple[jnp.ndarray, ...] ) -> SinkhornOutput: """Jittable Sinkhorn loop. 
args contain initialization variables.""" def cond_fn( - iteration: int, const: Tuple[linear_problems.LinearProblem, Sinkhorn], + iteration: int, const: Tuple[linear_problem.LinearProblem, Sinkhorn], state: SinkhornState ) -> bool: _, solver = const return solver._continue(state, iteration) def body_fn( - iteration: int, const: Tuple[linear_problems.LinearProblem, Sinkhorn], + iteration: int, const: Tuple[linear_problem.LinearProblem, Sinkhorn], state: SinkhornState, compute_error: bool ) -> SinkhornState: ot_prob, solver = const @@ -675,10 +677,10 @@ def body_fn( def _iterations_taped( - ot_prob: linear_problems.LinearProblem, solver: Sinkhorn, + ot_prob: linear_problem.LinearProblem, solver: Sinkhorn, init: Tuple[jnp.ndarray, ...] ) -> Tuple[SinkhornOutput, Tuple[jnp.ndarray, jnp.ndarray, - linear_problems.LinearProblem, Sinkhorn]]: + linear_problem.LinearProblem, Sinkhorn]]: """Run forward pass of the Sinkhorn algorithm storing side information.""" state = iterations(ot_prob, solver, init) return state, (state.f, state.g, ot_prob, solver) @@ -750,16 +752,16 @@ def make( ) # If no params are passed, align default with that provide in Sinkhorn solver. 
if momentum is None and chg_momentum_from is None: - mom = momentum_lib.Momentum(start=300, error_threshold=1e-2) + mom = acceleration.Momentum(start=300, error_threshold=1e-2) elif momentum is None: - mom = momentum_lib.Momentum(start=chg_momentum_from) + mom = acceleration.Momentum(start=chg_momentum_from) elif chg_momentum_from is None: - mom = momentum_lib.Momentum(value=momentum) + mom = acceleration.Momentum(value=momentum) else: - mom = momentum_lib.Momentum(start=chg_momentum_from, value=momentum) + mom = acceleration.Momentum(start=chg_momentum_from, value=momentum) if anderson_acceleration > 0: - anderson = anderson_lib.AndersonAcceleration( + anderson = acceleration.AndersonAcceleration( memory=anderson_acceleration, refresh_every=refresh_anderson_frequency ) else: @@ -1100,5 +1102,5 @@ def sinkhorn( by the user. """ sink = make(**kwargs) - ot_prob = linear_problems.LinearProblem(geom, a, b, tau_a, tau_b) + ot_prob = linear_problem.LinearProblem(geom, a, b, tau_a, tau_b) return sink(ot_prob, (init_dual_a, init_dual_b)) diff --git a/ott/core/sinkhorn_lr.py b/ott/solvers/linear/sinkhorn_lr.py similarity index 94% rename from ott/core/sinkhorn_lr.py rename to ott/solvers/linear/sinkhorn_lr.py index e27989f44..62364ef14 100644 --- a/ott/core/sinkhorn_lr.py +++ b/ott/solvers/linear/sinkhorn_lr.py @@ -21,11 +21,14 @@ import jax.scipy as jsp from typing_extensions import Literal -from ott.core import _math_utils as mu -from ott.core import fixed_point_loop -from ott.core import initializers_lr as init_lib -from ott.core import linear_problems, sinkhorn from ott.geometry import geometry, low_rank, pointcloud +from ott.initializers.linear import initializers_lr as init_lib +from ott.math import fixed_point_loop +from ott.math import utils as mu +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn + +__all__ = ["LRSinkhorn", "LRSinkhornOutput"] class LRSinkhornState(NamedTuple): @@ -48,13 +51,13 @@ def compute_error(self, 
previous_state: "LRSinkhornState") -> float: def reg_ot_cost( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, use_danskin: bool = False ) -> float: return compute_reg_ot_cost(self.q, self.r, self.g, ot_prob, use_danskin) def solution_error( - self, ot_prob: linear_problems.LinearProblem, norm_error: Tuple[int, ...], + self, ot_prob: linear_problem.LinearProblem, norm_error: Tuple[int, ...], lse_mode: bool ) -> jnp.ndarray: return solution_error(self.q, self.r, ot_prob, norm_error, lse_mode) @@ -68,7 +71,7 @@ def compute_reg_ot_cost( q: jnp.ndarray, r: jnp.ndarray, g: jnp.ndarray, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, use_danskin: bool = False ) -> float: q = jax.lax.stop_gradient(q) if use_danskin else q @@ -78,7 +81,7 @@ def compute_reg_ot_cost( def solution_error( - q: jnp.ndarray, r: jnp.ndarray, ot_prob: linear_problems.LinearProblem, + q: jnp.ndarray, r: jnp.ndarray, ot_prob: linear_problem.LinearProblem, norm_error: Tuple[int, ...], lse_mode: bool ) -> jnp.ndarray: """Compute solution error. 
@@ -122,7 +125,7 @@ class LRSinkhornOutput(NamedTuple): # TODO(michalk8): must be called `errors`, because of `store_inner_errors` # in future, enforce via class hierarchy errors: jnp.ndarray - ot_prob: linear_problems.LinearProblem + ot_prob: linear_problem.LinearProblem # TODO(michalk8): Optional is an artifact of the current impl., refactor reg_ot_cost: Optional[float] = None @@ -132,7 +135,7 @@ def set(self, **kwargs: Any) -> 'LRSinkhornOutput': def set_cost( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, lse_mode: bool, use_danskin: bool = False ) -> 'LRSinkhornOutput': @@ -141,14 +144,14 @@ def set_cost( def compute_reg_ot_cost( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, use_danskin: bool = False, ) -> float: return compute_reg_ot_cost(self.q, self.r, self.g, ot_prob, use_danskin) @property def linear(self) -> bool: - return isinstance(self.ot_prob, linear_problems.LinearProblem) + return isinstance(self.ot_prob, linear_problem.LinearProblem) @property def geom(self) -> geometry.Geometry: @@ -285,7 +288,7 @@ def __init__( def __call__( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray]] = (None, None, None), key: Optional[jnp.ndarray] = None, @@ -316,7 +319,7 @@ def __call__( def _lr_costs( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, state: LRSinkhornState, ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, float]: log_q, log_r, log_g = ( @@ -353,7 +356,7 @@ def dykstra_update( c_r: jnp.ndarray, h: jnp.ndarray, gamma: float, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, min_entry_value: float = 1e-6, tolerance: float = 1e-3, min_iter: int = 0, @@ -465,7 +468,7 @@ def recompute_couplings( return recompute_couplings(f1, g1_old, c_q, f2, g2_old, c_r, h_old, gamma) def lse_step( - 
self, ot_prob: linear_problems.LinearProblem, state: LRSinkhornState, + self, ot_prob: linear_problem.LinearProblem, state: LRSinkhornState, iteration: int ) -> LRSinkhornState: """LR Sinkhorn LSE update.""" @@ -476,7 +479,7 @@ def lse_step( return state.set(q=q, g=g, r=r, gamma=gamma) def kernel_step( - self, ot_prob: linear_problems.LinearProblem, state: LRSinkhornState, + self, ot_prob: linear_problem.LinearProblem, state: LRSinkhornState, iteration: int ) -> NoReturn: """Not implemented.""" @@ -484,7 +487,7 @@ def kernel_step( raise NotImplementedError("Not implemented.") def one_iteration( - self, ot_prob: linear_problems.LinearProblem, state: LRSinkhornState, + self, ot_prob: linear_problem.LinearProblem, state: LRSinkhornState, iteration: int, compute_error: bool ) -> LRSinkhornState: """Carries out one LR sinkhorn iteration. @@ -539,7 +542,7 @@ def is_entropic(self) -> bool: return self.epsilon > 0. def create_initializer( - self, prob: linear_problems.LinearProblem + self, prob: linear_problem.LinearProblem ) -> init_lib.LRInitializer: """Create a low-rank Sinkhorn initializer. @@ -569,7 +572,7 @@ def create_initializer( return initializer def init_state( - self, ot_prob: linear_problems.LinearProblem, + self, ot_prob: linear_problem.LinearProblem, init: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] ) -> LRSinkhornState: """Return the initial state of the loop.""" @@ -585,7 +588,7 @@ def init_state( ) def output_from_state( - self, ot_prob: linear_problems.LinearProblem, state: LRSinkhornState + self, ot_prob: linear_problem.LinearProblem, state: LRSinkhornState ) -> LRSinkhornOutput: """Create an output from a loop state. 
@@ -641,7 +644,7 @@ def _diverged(self, state: LRSinkhornState, iteration: int) -> bool: def run( - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, solver: LRSinkhorn, init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray]], diff --git a/ott/solvers/nn/__init__.py b/ott/solvers/nn/__init__.py new file mode 100644 index 000000000..e6c465331 --- /dev/null +++ b/ott/solvers/nn/__init__.py @@ -0,0 +1 @@ +# TODO(michalk8): imports diff --git a/ott/core/icnn.py b/ott/solvers/nn/icnn.py similarity index 97% rename from ott/core/icnn.py rename to ott/solvers/nn/icnn.py index 8be945f4e..4e93bc91c 100644 --- a/ott/core/icnn.py +++ b/ott/solvers/nn/icnn.py @@ -13,7 +13,7 @@ # limitations under the License. # Lint as: python3 -"""Implementation of Amos+(2017) input convex neural networks (ICNN).""" +"""Implementation of :cite:`amos:17` input convex neural networks (ICNN).""" from typing import Any, Callable, Sequence, Tuple, Union @@ -24,8 +24,10 @@ from flax.training import train_state from jax.nn import initializers -from ott.core.layers import PosDefPotentials, PositiveDense -from ott.geometry import matrix_square_root +from ott.math import matrix_square_root +from ott.solvers.nn.layers import PosDefPotentials, PositiveDense + +__all__ = ["ICNN"] PRNGKey = Any Shape = Tuple[int] diff --git a/ott/core/layers.py b/ott/solvers/nn/layers.py similarity index 86% rename from ott/core/layers.py rename to ott/solvers/nn/layers.py index 5a4b0a396..a0add5a9d 100644 --- a/ott/core/layers.py +++ b/ott/solvers/nn/layers.py @@ -11,7 +11,7 @@ # limitations under the License. 
# Lint as: python3 -"""Layers used in input convex neural networks (Amos+(2017), Bunne+(2022)).""" +"""Layers used in input convex neural networks :cite:`amos:17,bunne:22`.""" from typing import Any, Callable, Tuple @@ -19,6 +19,8 @@ import jax.numpy as jnp from flax import linen as nn +__all__ = ["PositiveDense", "PosDefPotentials"] + PRNGKey = Any Shape = Tuple[int] Dtype = Any @@ -30,8 +32,8 @@ class PositiveDense(nn.Module): Args: dim_hidden: the number of output dim_hidden. - rectifier_fn: choice of rectiver function (default: softplus function). - inv_rectifier_fn: choice of inverse rectiver function + rectifier_fn: choice of rectifier function (default: softplus function). + inv_rectifier_fn: choice of inverse rectifier function (default: inverse softplus function). dtype: the dtype of the computation (default: float32). precision: numerical precision of computation see `jax.lax.Precision` @@ -50,11 +52,11 @@ class PositiveDense(nn.Module): bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros @nn.compact - def __call__(self, inputs): + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: """Applies a linear transformation to inputs along the last dimension. Args: - inputs: The nd-array to be transformed. + inputs: Array to be transformed. Returns: The transformed input. """ @@ -79,8 +81,8 @@ class PosDefPotentials(nn.Module): """A layer to output (0.5 [A_i A_i^T] (x - b_i)_i potentials. Args: - use_bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). + use_bias: whether to add a bias to the output. + dtype: the dtype of the computation. precision: numerical precision of computation see `jax.lax.Precision` for details. kernel_init: initializer function for the weight matrix. 
@@ -96,11 +98,12 @@ class PosDefPotentials(nn.Module): bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros @nn.compact - def __call__(self, inputs): - """Applies a few quadratic forms. + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + """Apply a few quadratic forms. Args: - inputs: The nd-array to be transformed (possibly batched) + inputs: Array to be transformed (possibly batched). + Returns: The transformed input. """ diff --git a/ott/core/neuraldual.py b/ott/solvers/nn/neuraldual.py similarity index 99% rename from ott/core/neuraldual.py rename to ott/solvers/nn/neuraldual.py index b4716a06d..b7b773164 100644 --- a/ott/core/neuraldual.py +++ b/ott/solvers/nn/neuraldual.py @@ -23,7 +23,10 @@ from flax import core from typing_extensions import Literal -from ott.core import icnn, potentials +from ott.math import potentials +from ott.solvers.nn import icnn + +__all__ = ["NeuralDualSolver"] Train_t = Dict[Literal["training_logs", "validation_logs"], List[float]] Potentials_t = potentials.DualPotentials diff --git a/ott/solvers/quadratic/__init__.py b/ott/solvers/quadratic/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ott/core/gromov_wasserstein.py b/ott/solvers/quadratic/gromov_wasserstein.py similarity index 95% rename from ott/core/gromov_wasserstein.py rename to ott/solvers/quadratic/gromov_wasserstein.py index 0f9f61305..98c3a0af6 100644 --- a/ott/core/gromov_wasserstein.py +++ b/ott/solvers/quadratic/gromov_wasserstein.py @@ -20,17 +20,14 @@ import jax.numpy as jnp from typing_extensions import Literal -from ott.core import ( - fixed_point_loop, - initializers_lr, - linear_problems, - quad_initializers, - quad_problems, - sinkhorn, - sinkhorn_lr, - was_solver, -) from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud +from ott.initializers.linear import initializers_lr +from ott.initializers.quadratic import initializers as quad_initializers +from ott.math import fixed_point_loop 
+from ott.problems.linear import linear_problem +from ott.problems.quadratic import quadratic_costs, quadratic_problem +from ott.solvers.linear import sinkhorn, sinkhorn_lr +from ott.utils import was_solver LinearOutput = Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput] @@ -107,7 +104,7 @@ class GWState(NamedTuple): costs: jnp.ndarray linear_convergence: jnp.ndarray linear_state: LinearOutput - linear_pb: linear_problems.LinearProblem + linear_pb: linear_problem.LinearProblem old_transport_mass: float keys: Optional[jnp.ndarray] = None errors: Optional[jnp.ndarray] = None @@ -118,7 +115,7 @@ def set(self, **kwargs: Any) -> 'GWState': def update( self, iteration: int, linear_sol: LinearOutput, - linear_pb: linear_problems.LinearProblem, store_errors: bool, + linear_pb: linear_problem.LinearProblem, store_errors: bool, old_transport_mass: float ) -> 'GWState': costs = self.costs.at[iteration].set(linear_sol.reg_ot_cost) @@ -191,8 +188,8 @@ def __init__( def __call__( self, - prob: quad_problems.QuadraticProblem, - init: Optional[linear_problems.LinearProblem] = None, + prob: quadratic_problem.QuadraticProblem, + init: Optional[linear_problem.LinearProblem] = None, key: Optional[jnp.ndarray] = None, **kwargs: Any, ) -> GWOutput: @@ -240,8 +237,8 @@ def __call__( def init_state( self, - prob: quad_problems.QuadraticProblem, - init: linear_problems.LinearProblem, + prob: quadratic_problem.QuadraticProblem, + init: linear_problem.LinearProblem, key: jnp.ndarray, ) -> GWState: """Initialize the state of the Gromov-Wasserstein iterations. @@ -292,7 +289,7 @@ def output_from_state(self, state: GWState) -> GWOutput: ) def create_initializer( - self, prob: quad_problems.QuadraticProblem + self, prob: quadratic_problem.QuadraticProblem ) -> quad_initializers.BaseQuadraticInitializer: """Create quadratic, possibly low-rank initializer. 
@@ -309,7 +306,7 @@ def create_initializer( assert isinstance( self.quad_initializer, quad_initializers.LRQuadraticInitializer ), f"Expected quadratic initializer to be low rank, " \ - f"found `{type(self.quad_initializer).__name___}`." + f"found `{type(self.quad_initializer).__name__}`." assert self.quad_initializer.rank == self.rank, \ f"Expected quadratic initializer of rank `{self.rank}`, " \ f"found `{self.quad_initializer.rank}`." @@ -345,8 +342,8 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: def iterations( solver: GromovWasserstein, - prob: quad_problems.QuadraticProblem, - init: linear_problems.LinearProblem, + prob: quadratic_problem.QuadraticProblem, + init: linear_problem.LinearProblem, key: jnp.ndarray, ) -> GWOutput: """Jittable Gromov-Wasserstein outer loop.""" @@ -469,7 +466,7 @@ def gromov_wasserstein( scale_cost: Optional[Union[bool, float, str]] = False, a: Optional[jnp.ndarray] = None, b: Optional[jnp.ndarray] = None, - loss: Union[Literal['sqeucl', 'kl'], quad_problems.GWLoss] = 'sqeucl', + loss: Union[Literal['sqeucl', 'kl'], quadratic_costs.GWLoss] = 'sqeucl', tau_a: Optional[float] = 1.0, tau_b: Optional[float] = 1.0, gw_unbalanced_correction: bool = True, @@ -535,7 +532,7 @@ def gromov_wasserstein( Returns: A GromovWassersteinState named tuple. 
""" - prob = quad_problems.QuadraticProblem( + prob = quadratic_problem.QuadraticProblem( geom_xx, geom_yy, geom_xy=geom_xy, diff --git a/ott/core/gw_barycenter.py b/ott/solvers/quadratic/gw_barycenter.py similarity index 93% rename from ott/core/gw_barycenter.py rename to ott/solvers/quadratic/gw_barycenter.py index 18891afd1..8020f99e3 100644 --- a/ott/core/gw_barycenter.py +++ b/ott/solvers/quadratic/gw_barycenter.py @@ -4,14 +4,12 @@ import jax import jax.numpy as jnp -from ott.core import ( - bar_problems, - fixed_point_loop, - gromov_wasserstein, - linear_problems, - was_solver, -) from ott.geometry import pointcloud +from ott.math import fixed_point_loop +from ott.problems.linear import linear_problem +from ott.problems.quadratic import barycenter_problem +from ott.solvers.quadratic import gromov_wasserstein +from ott.utils import was_solver __all__ = ["GWBarycenterState", "GromovWassersteinBarycenter"] @@ -93,7 +91,7 @@ def __init__( self._quad_solver = gromov_wasserstein.GromovWasserstein(**kwargs) def __call__( - self, problem: bar_problems.GWBarycenterProblem, bar_size: int, + self, problem: barycenter_problem.GWBarycenterProblem, bar_size: int, **kwargs: Any ) -> GWBarycenterState: """Solver the (fused) GW barycenter problem. 
@@ -113,7 +111,7 @@ def __call__( def init_state( self, - problem: bar_problems.GWBarycenterProblem, + problem: barycenter_problem.GWBarycenterProblem, bar_size: int, bar_init: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, @@ -188,7 +186,7 @@ def update_state( self, state: GWBarycenterState, iteration: int, - problem: bar_problems.GWBarycenterProblem, + problem: barycenter_problem.GWBarycenterProblem, store_errors: bool = True, ) -> Tuple[float, bool, jnp.ndarray, Optional[jnp.ndarray]]: @@ -236,7 +234,7 @@ def solve_gw( def output_from_state(self, state: GWBarycenterState) -> GWBarycenterState: """No-op.""" # TODO(michalk8): just for consistency with continuous barycenter - # will be refactored in the future + # will be refactored in the future to create an output return state def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: @@ -279,13 +277,14 @@ def init_transports( geom = pointcloud.PointCloud( x, y, epsilon=epsilon, src_mask=a > 0, tgt_mask=b > 0 ) - problem = linear_problems.LinearProblem(geom, a=a, b=b) + problem = linear_problem.LinearProblem(geom, a=a, b=b) return solver(problem).matrix def iterations( solver: GromovWassersteinBarycenter, - problem: bar_problems.GWBarycenterProblem, init_state: GWBarycenterState + problem: barycenter_problem.GWBarycenterProblem, + init_state: GWBarycenterState ) -> GWBarycenterState: def cond_fn( @@ -297,7 +296,7 @@ def cond_fn( def body_fn( iteration, constants: Tuple[GromovWassersteinBarycenter, - bar_problems.GWBarycenterProblem], + barycenter_problem.GWBarycenterProblem], state: GWBarycenterState, compute_error: bool ) -> GWBarycenterState: del compute_error # always assumed true diff --git a/ott/tools/gaussian_mixture/gaussian_mixture_pair.py b/ott/tools/gaussian_mixture/gaussian_mixture_pair.py index c5b8968c1..eccd3da68 100644 --- a/ott/tools/gaussian_mixture/gaussian_mixture_pair.py +++ b/ott/tools/gaussian_mixture/gaussian_mixture_pair.py @@ -16,8 +16,8 @@ import jax import 
jax.numpy as jnp -from ott.core import sinkhorn from ott.geometry import costs, geometry, pointcloud +from ott.solvers.linear import sinkhorn from ott.tools.gaussian_mixture import gaussian_mixture diff --git a/ott/tools/gaussian_mixture/scale_tril.py b/ott/tools/gaussian_mixture/scale_tril.py index c4f7ea077..53a8df1fc 100644 --- a/ott/tools/gaussian_mixture/scale_tril.py +++ b/ott/tools/gaussian_mixture/scale_tril.py @@ -18,7 +18,8 @@ import jax import jax.numpy as jnp -from ott.geometry import costs, matrix_square_root +from ott.geometry import costs +from ott.math import matrix_square_root from ott.tools.gaussian_mixture import linalg diff --git a/ott/tools/k_means.py b/ott/tools/k_means.py index 335d16dc8..4ad02723d 100644 --- a/ott/tools/k_means.py +++ b/ott/tools/k_means.py @@ -19,8 +19,8 @@ import jax.numpy as jnp from typing_extensions import Literal -from ott.core import fixed_point_loop from ott.geometry import costs, pointcloud +from ott.math import fixed_point_loop __all__ = ["k_means", "KMeansOutput"] diff --git a/ott/tools/segment_sinkhorn.py b/ott/tools/segment_sinkhorn.py index b770e06be..4a5427c6e 100644 --- a/ott/tools/segment_sinkhorn.py +++ b/ott/tools/segment_sinkhorn.py @@ -17,8 +17,9 @@ from jax import numpy as jnp -from ott.core import segment, sinkhorn from ott.geometry import costs, pointcloud +from ott.solvers.linear import sinkhorn +from ott.utils import segment def segment_sinkhorn( diff --git a/ott/tools/sinkhorn_divergence.py b/ott/tools/sinkhorn_divergence.py index faa3e1ba0..46b196339 100644 --- a/ott/tools/sinkhorn_divergence.py +++ b/ott/tools/sinkhorn_divergence.py @@ -17,8 +17,10 @@ import jax.numpy as jnp -from ott.core import potentials, segment, sinkhorn from ott.geometry import costs, geometry, pointcloud +from ott.math import potentials +from ott.solvers.linear import sinkhorn +from ott.utils import segment __all__ = [ "sinkhorn_divergence", "segment_sinkhorn_divergence", diff --git a/ott/tools/transport.py 
b/ott/tools/transport.py index 73674ed78..c4e98fa68 100644 --- a/ott/tools/transport.py +++ b/ott/tools/transport.py @@ -27,13 +27,19 @@ ott.core.gromov_wasserstein) for better control over the parameters. """ -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple, Optional, Union import jax.numpy as jnp +import numpy as np from typing_extensions import Literal -from ott.core import gromov_wasserstein, linear_problems, problems, sinkhorn -from ott.geometry import geometry +from ott.geometry import geometry, pointcloud +from ott.problems.linear import linear_problem +from ott.problems.quadratic import quadratic_problem +from ott.solvers.linear import sinkhorn +from ott.solvers.quadratic import gromov_wasserstein + +__all__ = ["Transport"] class Transport(NamedTuple): @@ -44,7 +50,7 @@ class Transport(NamedTuple): @property def linear(self) -> bool: - return isinstance(self.problem, linear_problems.LinearProblem) + return isinstance(self.problem, linear_problem.LinearProblem) @property def geom(self) -> geometry.Geometry: @@ -109,7 +115,7 @@ def solve( fused_penalty = kwargs.pop('fused_penalty', None) eps_keys = ['epsilon', 'init', 'target', 'decay'] pb_kwargs = {k: v for k, v in kwargs.items() if k in eps_keys} - pb = problems.make( + pb = make( *args, objective=objective, a=a, @@ -120,7 +126,7 @@ def solve( fused_penalty=fused_penalty, **pb_kwargs ) - linear = isinstance(pb, linear_problems.LinearProblem) + linear = isinstance(pb, linear_problem.LinearProblem) solver_fn = sinkhorn.make if linear else gromov_wasserstein.make geom_keys = ['cost_fn', 'power', 'online'] @@ -130,3 +136,76 @@ def solve( solver = solver_fn(**kwargs) output = solver(pb, (init_dual_a, init_dual_b)) return Transport(pb, output) + + +def make( + *args: Union[jnp.ndarray, geometry.Geometry, linear_problem.LinearProblem, + quadratic_problem.QuadraticProblem], + a: Optional[jnp.ndarray] = None, + b: Optional[jnp.ndarray] = None, + tau_a: float = 1.0, + tau_b: float = 
1.0, + objective: Optional[str] = None, + gw_unbalanced_correction: Optional[bool] = True, + fused_penalty: Optional[float] = None, + scale_cost: Optional[Union[bool, float, str]] = False, + **kwargs: Any, +): + """Make a problem from arrays, assuming PointCloud geometries.""" + if isinstance(args[0], (jnp.ndarray, np.ndarray)): + x = args[0] + y = args[1] if len(args) > 1 else args[0] + if ((objective == 'linear') or + (objective is None and x.shape[1] == y.shape[1])): # noqa: E129 + geom_xy = pointcloud.PointCloud(x, y, **kwargs) + return linear_problem.LinearProblem( + geom_xy, a=a, b=b, tau_a=tau_a, tau_b=tau_b + ) + elif ((objective == 'quadratic') or + (objective is None and x.shape[1] != y.shape[1])): + geom_xx = pointcloud.PointCloud(x, x, **kwargs) + geom_yy = pointcloud.PointCloud(y, y, **kwargs) + return quadratic_problem.QuadraticProblem( + geom_xx=geom_xx, + geom_yy=geom_yy, + geom_xy=None, + scale_cost=scale_cost, + a=a, + b=b, + tau_a=tau_a, + tau_b=tau_b, + gw_unbalanced_correction=gw_unbalanced_correction + ) + elif objective == 'fused': + geom_xx = pointcloud.PointCloud(x, x, **kwargs) + geom_yy = pointcloud.PointCloud(y, y, **kwargs) + geom_xy = pointcloud.PointCloud(x, y, **kwargs) + return quadratic_problem.QuadraticProblem( + geom_xx=geom_xx, + geom_yy=geom_yy, + geom_xy=geom_xy, + fused_penalty=fused_penalty, + scale_cost=scale_cost, + a=a, + b=b, + tau_a=tau_a, + tau_b=tau_b, + gw_unbalanced_correction=gw_unbalanced_correction + ) + else: + raise ValueError(f'Unknown transport problem `{objective}`') + elif isinstance(args[0], geometry.Geometry): + if len(args) == 1: + return linear_problem.LinearProblem( + *args, a=a, b=b, tau_a=tau_a, tau_b=tau_b + ) + return quadratic_problem.QuadraticProblem( + *args, a=a, b=b, tau_a=tau_a, tau_b=tau_b, scale_cost=scale_cost + ) + elif isinstance( + args[0], + (linear_problem.LinearProblem, quadratic_problem.QuadraticProblem) + ): + return args[0] + else: + raise ValueError('Cannot instantiate a 
transport problem.') diff --git a/ott/typing.py b/ott/typing.py new file mode 100644 index 000000000..1b27fa3aa --- /dev/null +++ b/ott/typing.py @@ -0,0 +1,22 @@ +from jax import numpy as jnp +from typing_extensions import Protocol + +# TODO(michalk8): introduce additional types here + + +class Transport(Protocol): + """Interface for the solution of a transport problem. + + Classes implementing those function do not have to inherit from it, the + class can however be used in type hints to support duck typing. + """ + + @property + def matrix(self) -> jnp.ndarray: + ... + + def apply(self, inputs: jnp.ndarray, axis: int) -> jnp.ndarray: + ... + + def marginal(self, axis: int = 0) -> jnp.ndarray: + ... diff --git a/ott/utils/__init__.py b/ott/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ott/core/dataclasses.py b/ott/utils/dataclasses.py similarity index 100% rename from ott/core/dataclasses.py rename to ott/utils/dataclasses.py diff --git a/ott/core/segment.py b/ott/utils/segment.py similarity index 100% rename from ott/core/segment.py rename to ott/utils/segment.py diff --git a/ott/core/was_solver.py b/ott/utils/was_solver.py similarity index 98% rename from ott/core/was_solver.py rename to ott/utils/was_solver.py index eab88c19a..7e81c4377 100644 --- a/ott/core/was_solver.py +++ b/ott/utils/was_solver.py @@ -19,7 +19,7 @@ import jax import jax.numpy as jnp -from ott.core import sinkhorn, sinkhorn_lr +from ott.solvers.linear import sinkhorn, sinkhorn_lr State = Union[sinkhorn.SinkhornState, sinkhorn_lr.LRSinkhornState, "continuous_barycenter.BarycenterState"] # noqa: F821 From 7f2fa029c88fae726e2b3e822f06a2b2acb0afdc Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 15 Nov 2022 14:19:41 +0100 Subject: [PATCH 02/34] Update PC docs --- ott/geometry/pointcloud.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ott/geometry/pointcloud.py b/ott/geometry/pointcloud.py 
index 9e153f267..dcd576a48 100644 --- a/ott/geometry/pointcloud.py +++ b/ott/geometry/pointcloud.py @@ -600,7 +600,7 @@ def to_LRCGeometry( Useful when this geometry is used in the linear term of fused GW. kwargs: Keyword arguments, such as ``rank``, to :meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry` used when - the point cloud does not squared Euclidean cost. + the point cloud does not have squared Euclidean cost. Returns: Returns the unmodified point cloud if :math:`n m \ge (n + m) d`, where From a41ade869ed36645634144b2c0335f39a9b88bc3 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 16 Nov 2022 12:03:39 +0100 Subject: [PATCH 03/34] Update imports, fix some types --- ott/geometry/costs.py | 2 + ott/geometry/epsilon_scheduler.py | 2 + ott/geometry/geometry.py | 2 + ott/geometry/grid.py | 2 + ott/geometry/low_rank.py | 2 + ott/geometry/pointcloud.py | 2 + ott/initializers/linear/__init__.py | 1 + ott/initializers/linear/initializers.py | 25 +++++---- ott/initializers/linear/initializers_lr.py | 17 +++--- ott/initializers/nn/__init__.py | 1 + ott/initializers/nn/initializers.py | 60 +++++++++++++++++---- ott/initializers/nn/layers.py | 39 -------------- ott/initializers/quadratic/__init__.py | 1 + ott/initializers/quadratic/initializers.py | 13 +++-- ott/math/__init__.py | 9 ++++ ott/math/fixed_point_loop.py | 5 +- ott/math/implicit_differentiation.py | 20 ++++--- ott/math/matrix_square_root.py | 17 +++--- ott/math/potentials.py | 10 ++-- ott/math/utils.py | 1 + ott/problems/linear/__init__.py | 3 +- ott/problems/linear/linear_problem.py | 1 + ott/problems/quadratic/__init__.py | 4 +- ott/problems/quadratic/quadratic_problem.py | 10 ++-- ott/solvers/linear/__init__.py | 7 +++ ott/solvers/linear/acceleration.py | 2 +- ott/solvers/linear/continuous_barycenter.py | 2 - ott/solvers/linear/discrete_barycenter.py | 2 +- ott/solvers/linear/sinkhorn.py | 3 +- ott/solvers/nn/__init__.py | 2 +- 
ott/solvers/nn/neuraldual.py | 14 ++--- ott/solvers/quadratic/__init__.py | 1 + ott/solvers/quadratic/gromov_wasserstein.py | 2 + ott/utils/__init__.py | 1 + ott/utils/dataclasses.py | 2 + ott/utils/segment.py | 2 + ott/utils/was_solver.py | 16 +++--- 37 files changed, 176 insertions(+), 129 deletions(-) delete mode 100644 ott/initializers/nn/layers.py diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py index 1af441d73..22225a16a 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -24,6 +24,8 @@ from ott.math import fixed_point_loop, matrix_square_root +__all__ = ["Euclidean", "SqEuclidean", "Cosine", "Bures", "UnbalancedBures"] + @jax.tree_util.register_pytree_node_class class CostFn(abc.ABC): diff --git a/ott/geometry/epsilon_scheduler.py b/ott/geometry/epsilon_scheduler.py index 1e0fc30e7..8f5666385 100644 --- a/ott/geometry/epsilon_scheduler.py +++ b/ott/geometry/epsilon_scheduler.py @@ -19,6 +19,8 @@ import jax import jax.numpy as jnp +__all__ = ["Epsilon"] + @jax.tree_util.register_pytree_node_class class Epsilon: diff --git a/ott/geometry/geometry.py b/ott/geometry/geometry.py index ff3d5000c..3550c881b 100644 --- a/ott/geometry/geometry.py +++ b/ott/geometry/geometry.py @@ -28,6 +28,8 @@ from ott.geometry import epsilon_scheduler from ott.math import utils +__all__ = ["Geometry"] + @jax.tree_util.register_pytree_node_class class Geometry: diff --git a/ott/geometry/grid.py b/ott/geometry/grid.py index 88bd67f0a..3cc94c77b 100644 --- a/ott/geometry/grid.py +++ b/ott/geometry/grid.py @@ -24,6 +24,8 @@ from ott.geometry import costs, geometry, pointcloud from ott.math import utils +__all__ = ["Grid"] + @jax.tree_util.register_pytree_node_class class Grid(geometry.Geometry): diff --git a/ott/geometry/low_rank.py b/ott/geometry/low_rank.py index bf8bfa2a3..1fc748059 100644 --- a/ott/geometry/low_rank.py +++ b/ott/geometry/low_rank.py @@ -22,6 +22,8 @@ from ott.geometry import geometry +__all__ = ["LRCGeometry"] + 
@jax.tree_util.register_pytree_node_class class LRCGeometry(geometry.Geometry): diff --git a/ott/geometry/pointcloud.py b/ott/geometry/pointcloud.py index dcd576a48..c435a8ded 100644 --- a/ott/geometry/pointcloud.py +++ b/ott/geometry/pointcloud.py @@ -24,6 +24,8 @@ from ott.geometry import costs, geometry, low_rank from ott.math import utils +__all__ = ["PointCloud"] + @jax.tree_util.register_pytree_node_class class PointCloud(geometry.Geometry): diff --git a/ott/initializers/linear/__init__.py b/ott/initializers/linear/__init__.py index e69de29bb..1ce1a00cd 100644 --- a/ott/initializers/linear/__init__.py +++ b/ott/initializers/linear/__init__.py @@ -0,0 +1 @@ +from . import initializers, initializers_lr diff --git a/ott/initializers/linear/initializers.py b/ott/initializers/linear/initializers.py index 2b10ed6e2..65b91cc52 100644 --- a/ott/initializers/linear/initializers.py +++ b/ott/initializers/linear/initializers.py @@ -13,18 +13,17 @@ # limitations under the License. """Sinkhorn initializers.""" import abc -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple import jax import jax.numpy as jnp from ott.geometry import pointcloud -from ott.problems.linear import linear_problem -__all__ = [ - "SinkhornInitializer", "DefaultInitializer", "GaussianInitializer", - "SortingInitializer" -] +if TYPE_CHECKING: + from ott.problems.linear import linear_problem + +__all__ = ["DefaultInitializer", "GaussianInitializer", "SortingInitializer"] @jax.tree_util.register_pytree_node_class @@ -33,19 +32,19 @@ class SinkhornInitializer(abc.ABC): @abc.abstractmethod def init_dual_a( - self, ot_prob: linear_problem.LinearProblem, lse_mode: bool + self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool ) -> jnp.ndarray: """Initialization for Sinkhorn potential/scaling f_u.""" @abc.abstractmethod def init_dual_b( - self, ot_prob: linear_problem.LinearProblem, lse_mode: bool + self, ot_prob: 
'linear_problem.LinearProblem', lse_mode: bool ) -> jnp.ndarray: """Initialization for Sinkhorn potential/scaling g_v.""" def __call__( self, - ot_prob: linear_problem.LinearProblem, + ot_prob: 'linear_problem.LinearProblem', a: Optional[jnp.ndarray], b: Optional[jnp.ndarray], lse_mode: bool, @@ -97,7 +96,7 @@ class DefaultInitializer(SinkhornInitializer): """Default initialization of Sinkhorn dual potentials/primal scalings.""" def init_dual_a( - self, ot_prob: linear_problem.LinearProblem, lse_mode: bool + self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool ) -> jnp.ndarray: """Initialize Sinkhorn potential/scaling f_u. @@ -113,7 +112,7 @@ def init_dual_a( return init_dual_a def init_dual_b( - self, ot_prob: linear_problem.LinearProblem, lse_mode: bool + self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool ) -> jnp.ndarray: """Initialize Sinkhorn potential/scaling g_v. @@ -141,7 +140,7 @@ class GaussianInitializer(DefaultInitializer): def init_dual_a( self, - ot_prob: linear_problem.LinearProblem, + ot_prob: 'linear_problem.LinearProblem', lse_mode: bool, ) -> jnp.ndarray: """Gaussian initialization function. 
@@ -243,7 +242,7 @@ def cond_fn(state: Tuple[jnp.ndarray, float, int]) -> bool: def init_dual_a( self, - ot_prob: linear_problem.LinearProblem, + ot_prob: 'linear_problem.LinearProblem', lse_mode: bool, init_f: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: diff --git a/ott/initializers/linear/initializers_lr.py b/ott/initializers/linear/initializers_lr.py index 2999bf850..a8132816f 100644 --- a/ott/initializers/linear/initializers_lr.py +++ b/ott/initializers/linear/initializers_lr.py @@ -21,22 +21,19 @@ from ott.math import fixed_point_loop from ott.math import utils as mu -__all__ = [ - "RandomInitializer", "Rank2Initializer", "KMeansInitializer", - "GeneralizedKMeansInitializer" -] - if TYPE_CHECKING: from ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import sinkhorn, sinkhorn_lr from ott.solvers.quadratic import gromov_wasserstein - Problem_t = Union[linear_problem.LinearProblem, - quadratic_problem.QuadraticProblem] -else: - Problem_t = "Union[linear_problems.LinearProblem, " \ - "quad_problems.QuadraticProblem]" +Problem_t = Union["linear_problem.LinearProblem", + "quadratic_problem.QuadraticProblem"] + +__all__ = [ + "RandomInitializer", "Rank2Initializer", "KMeansInitializer", + "GeneralizedKMeansInitializer" +] @jax.tree_util.register_pytree_node_class diff --git a/ott/initializers/nn/__init__.py b/ott/initializers/nn/__init__.py index e69de29bb..7ccb321da 100644 --- a/ott/initializers/nn/__init__.py +++ b/ott/initializers/nn/__init__.py @@ -0,0 +1 @@ +from . 
import initializers diff --git a/ott/initializers/nn/initializers.py b/ott/initializers/nn/initializers.py index 05f173b42..09fe1b26b 100644 --- a/ott/initializers/nn/initializers.py +++ b/ott/initializers/nn/initializers.py @@ -1,30 +1,31 @@ import functools -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple import jax -import jax.numpy as jnp import optax from flax import linen as nn from flax.training import train_state +from jax import numpy as jnp from ott.geometry import geometry -from ott.initializers.linear import DefaultInitializer -from ott.initializers.nn import layers -from ott.problems.linear import linear_problem -from ott.solvers.linear import sinkhorn +from ott.initializers.linear import initializers -# TODO(michalk8): add Charlotte's initializer? -__all__ = ["MetaInitializer"] +if TYPE_CHECKING: + from ott.problems.linear import linear_problem + +# TODO(michalk8): add initializer for NeuralDual? +__all__ = ["MetaInitializer", "MetaMLP"] @jax.tree_util.register_pytree_node_class -class MetaInitializer(DefaultInitializer): +class MetaInitializer(initializers.DefaultInitializer): """Meta OT Initializer with a fixed geometry :cite:`amos:22`. This initializer consists of a predictive model that outputs the :math:`f` duals to solve the entropy-regularized OT problem given input probability weights ``a`` and ``b``, and a given (assumed to be fixed) geometry ``geom``. 
+ The model's parameters are learned using a training set of OT instances (multiple pairs of probability weights), that assume the **same** geometry ``geom`` is used throughout, both for training and @@ -67,7 +68,7 @@ def __init__( self.rng = rng na, nb = geom.shape - self.meta_model = layers.MetaMLP( + self.meta_model = MetaMLP( potential_size=na ) if meta_model is None else meta_model @@ -120,7 +121,7 @@ def update( return self.update_impl(state, a, b) def init_dual_a( - self, ot_prob: linear_problem.LinearProblem, lse_mode: bool + self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool ) -> jnp.ndarray: # Detect if the problem is batched. assert ot_prob.a.ndim in (1, 2) and ot_prob.b.ndim in (1, 2) @@ -140,6 +141,8 @@ def init_dual_a( def _get_update_fn(self): """Return the implementation (and jitted) update function.""" + from ott.problems.linear import linear_problem + from ott.solvers.linear import sinkhorn def dual_obj_loss_single(params, a, b): f_pred = self._compute_f(a, b, params) @@ -186,3 +189,38 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: 'rng': self.rng, 'state': self.state } + + +class MetaMLP(nn.Module): + r"""A Meta MLP potential for :class:`~ott.core.initializers.MetaInitializer`. + + This provides an MLP :math:`\hat f_\theta(a, b)` that maps from the + probabilities of the measures to the optimal dual potentials :math:`f`. + + Args: + potential_size: The dimensionality of :math:`f`. + num_hidden_units: The number of hidden units in each layer. + num_hidden_layers: The number of hidden layers. + """ + + potential_size: int + num_hidden_units: int = 512 + num_hidden_layers: int = 3 + + @nn.compact + def __call__(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + r"""Make a prediction. + + Args: + a: Probabilities of the :math:`\alpha` measure's atoms. + b: Probabilities of the :math:`\beta` measure's atoms. + + Returns: + The :math:`f` potential. 
+ """ + dtype = a.dtype + z = jnp.concatenate((a, b)) + for _ in range(self.num_hidden_layers): + z = nn.relu(nn.Dense(self.num_hidden_units, dtype=dtype)(z)) + f = nn.Dense(self.potential_size, dtype=dtype)(z) + return f diff --git a/ott/initializers/nn/layers.py b/ott/initializers/nn/layers.py deleted file mode 100644 index 4f6b25382..000000000 --- a/ott/initializers/nn/layers.py +++ /dev/null @@ -1,39 +0,0 @@ -from flax import linen as nn -from jax import numpy as jnp - -__all__ = ["MetaMLP"] - - -class MetaMLP(nn.Module): - r"""A Meta MLP potential for :class:`~ott.core.initializers.MetaInitializer`. - - This provides an MLP :math:`\hat f_\theta(a, b)` that maps from the - probabilities of the measures to the optimal dual potentials :math:`f`. - - Args: - potential_size: The dimensionality of :math:`f`. - num_hidden_units: The number of hidden units in each layer. - num_hidden_layers: The number of hidden layers. - """ - - potential_size: int - num_hidden_units: int = 512 - num_hidden_layers: int = 3 - - @nn.compact - def __call__(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: - r"""Make a prediction. - - Args: - a: Probabilities of the :math:`\alpha` measure's atoms. - b: Probabilities of the :math:`\beta` measure's atoms. - - Returns: - The :math:`f` potential. - """ - dtype = a.dtype - z = jnp.concatenate((a, b)) - for _ in range(self.num_hidden_layers): - z = nn.relu(nn.Dense(self.num_hidden_units, dtype=dtype)(z)) - f = nn.Dense(self.potential_size, dtype=dtype)(z) - return f diff --git a/ott/initializers/quadratic/__init__.py b/ott/initializers/quadratic/__init__.py index e69de29bb..7ccb321da 100644 --- a/ott/initializers/quadratic/__init__.py +++ b/ott/initializers/quadratic/__init__.py @@ -0,0 +1 @@ +from . 
import initializers diff --git a/ott/initializers/quadratic/initializers.py b/ott/initializers/quadratic/initializers.py index bab5a8f54..0c97f83a5 100644 --- a/ott/initializers/quadratic/initializers.py +++ b/ott/initializers/quadratic/initializers.py @@ -4,11 +4,10 @@ import jax from ott.geometry import geometry -from ott.problems.linear import linear_problem -from ott.solvers.linear import sinkhorn_lr if TYPE_CHECKING: from ott.initializers.linear import initializers_lr + from ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_problem __all__ = ["QuadraticInitializer", "LRQuadraticInitializer"] @@ -27,7 +26,7 @@ def __init__(self, **kwargs: Any): def __call__( self, quad_prob: 'quadratic_problem.QuadraticProblem', **kwargs: Any - ) -> linear_problem.LinearProblem: + ) -> 'linear_problem.LinearProblem': """Compute the initial linearization of a quadratic problem. Args: @@ -37,6 +36,8 @@ def __call__( Returns: Linear problem. """ + from ott.problems.linear import linear_problem + n, m = quad_prob.geom_xx.shape[0], quad_prob.geom_yy.shape[0] geom = self._create_geometry(quad_prob, **kwargs) assert geom.shape == (n, m), f"Expected geometry of shape `{n, m}`, " \ @@ -134,10 +135,10 @@ def _create_geometry( transport_mass = marginal_1.sum() # Initialises epsilon for Unbalanced GW according to Sejourne et al (2021) epsilon = quadratic_problem.update_epsilon_unbalanced( - epsilon, transport_mass + epsilon=epsilon, transport_mass=transport_mass ) unbalanced_correction = quad_prob.cost_unbalanced_correction( - tmp, marginal_1, marginal_2, epsilon + tmp, marginal_1, marginal_2, epsilon=epsilon ) h1, h2 = quad_prob.quad_loss @@ -176,6 +177,8 @@ def _create_geometry( Returns: The initial geometry used to initialize a linear problem. 
""" + from ott.solvers.linear import sinkhorn_lr + q, r, g = self._linear_lr_initializer(quad_prob, **kwargs) tmp_out = sinkhorn_lr.LRSinkhornOutput( q=q, r=r, g=g, costs=None, errors=None, ot_prob=None diff --git a/ott/math/__init__.py b/ott/math/__init__.py index e69de29bb..faac35ec8 100644 --- a/ott/math/__init__.py +++ b/ott/math/__init__.py @@ -0,0 +1,9 @@ +from . import ( + decomposition, + fixed_point_loop, + implicit_differentiation, + matrix_square_root, + potentials, + unbalanced_functions, + utils, +) diff --git a/ott/math/fixed_point_loop.py b/ott/math/fixed_point_loop.py index fbfb04789..3b22ad20a 100644 --- a/ott/math/fixed_point_loop.py +++ b/ott/math/fixed_point_loop.py @@ -19,7 +19,8 @@ import jax import jax.numpy as jnp import numpy as np -from jax import dtypes + +__all__ = ["fixpoint_iter", "fixpoint_iter_backprop"] def fixpoint_iter( @@ -123,7 +124,7 @@ def fixpoint_iter_fwd( states = jax.tree_util.tree_map( lambda x: jnp.zeros( (max_iterations // inner_iterations + 1,) + jnp.shape(x), - dtype=dtypes.result_type(x) + dtype=jax.dtypes.result_type(x) ), state ) diff --git a/ott/math/implicit_differentiation.py b/ott/math/implicit_differentiation.py index 44fc946e5..1f42888c0 100644 --- a/ott/math/implicit_differentiation.py +++ b/ott/math/implicit_differentiation.py @@ -13,22 +13,26 @@ # limitations under the License. """Functions entering the implicit differentiation of Sinkhorn.""" -from typing import Callable, Optional, Tuple +from typing import TYPE_CHECKING, Callable, Optional, Tuple import jax import jax.numpy as jnp from ott.math import unbalanced_functions -from ott.problems import linear as linear_problems from ott.utils import dataclasses +if TYPE_CHECKING: + from ott.problems.linear import linear_problem + +__all__ = ["ImplicitDiff"] + @dataclasses.register_pytree_node class ImplicitDiff: """Implicit differentiation of Sinkhorn algorithm. Attributes: - implicit_solver_fun: Callable, should return (solution, ...) 
+ solver_fun: Callable, should return (solution, ...) ridge_kernel: promotes zero-sum solutions. only used if tau_a = tau_b = 1.0 ridge_identity: handles rank deficient transport matrices (this happens typically when rows/cols in cost/kernel matrices are colinear, or, @@ -48,9 +52,9 @@ class ImplicitDiff: def solve( self, gr: Tuple[jnp.ndarray, - jnp.ndarray], ot_prob: linear_problems.LinearProblem, + jnp.ndarray], ot_prob: "linear_problem.LinearProblem", f: jnp.ndarray, g: jnp.ndarray, lse_mode: bool - ): + ) -> jnp.ndarray: r"""Apply minus inverse of [hessian ``reg_ot_cost`` w.r.t. ``f``, ``g``]. This function is used to carry out implicit differentiation of ``sinkhorn`` @@ -271,9 +275,9 @@ def first_order_conditions( return jnp.concatenate((result_a, result_b)) def gradient( - self, prob: linear_problems.LinearProblem, f: jnp.ndarray, g: jnp.ndarray, - lse_mode: bool, gr: Tuple[jnp.ndarray, jnp.ndarray] - ) -> linear_problems.LinearProblem: + self, prob: "linear_problem.LinearProblem", f: jnp.ndarray, + g: jnp.ndarray, lse_mode: bool, gr: Tuple[jnp.ndarray, jnp.ndarray] + ) -> "linear_problem.LinearProblem": """Apply vjp to recover gradient in reverse mode differentiation.""" # Applies first part of vjp to gr: inverse part of implicit function theorem vjp_gr = self.solve(gr, prob, f, g, lse_mode) diff --git a/ott/math/matrix_square_root.py b/ott/math/matrix_square_root.py index 9cca55695..74a7651ad 100644 --- a/ott/math/matrix_square_root.py +++ b/ott/math/matrix_square_root.py @@ -24,6 +24,8 @@ from ott.math import fixed_point_loop +__all__ = ["sqrtm"] + @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5)) def sqrtm( @@ -33,7 +35,7 @@ def sqrtm( inner_iterations: int = 10, max_iterations: int = 1000, regularization: float = 1e-3 -) -> jnp.ndarray: +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Higham algorithm to compute matrix square root of p.d. matrix. See :cite:`higham:97`, eq. 
2.6b @@ -171,7 +173,7 @@ def sqrtm_fwd( max_iterations=max_iterations, regularization=regularization, ) - return ((sqrt_x, inv_sqrt_x, errors), (sqrt_x, inv_sqrt_x)) + return (sqrt_x, inv_sqrt_x, errors), (sqrt_x, inv_sqrt_x) def sqrtm_bwd( @@ -226,7 +228,7 @@ def sqrtm_bwd( axis1=-1, axis2=-2 ) - return (vjp_cot_sqrt + vjp_cot_inv_sqrt,) + return vjp_cot_sqrt + vjp_cot_inv_sqrt, sqrtm.defvjp(sqrtm_fwd, sqrtm_bwd) @@ -254,7 +256,7 @@ def sqrtm_only_bwd(sqrt_x: jnp.ndarray, axis1=-2, axis2=-1 ) - return (vjp,) + return vjp, sqrtm_only.defvjp(sqrtm_only_fwd, sqrtm_only_bwd) @@ -270,9 +272,8 @@ def inv_sqrtm_only_fwd(x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: return inv_sqrt_x, inv_sqrt_x -def inv_sqrtm_only_bwd( - residual: jnp.ndarray, cotangent: jnp.ndarray -) -> jnp.ndarray: +def inv_sqrtm_only_bwd(residual: jnp.ndarray, + cotangent: jnp.ndarray) -> Tuple[jnp.ndarray]: inv_sqrt_x = residual inv_x = jnp.matmul(inv_sqrt_x, inv_sqrt_x) vjp = jnp.swapaxes( @@ -287,7 +288,7 @@ def inv_sqrtm_only_bwd( axis1=-1, axis2=-2 ) - return (vjp,) + return vjp, inv_sqrtm_only.defvjp(inv_sqrtm_only_fwd, inv_sqrtm_only_bwd) diff --git a/ott/math/potentials.py b/ott/math/potentials.py index 0170825cc..6c84ae48e 100644 --- a/ott/math/potentials.py +++ b/ott/math/potentials.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Tuple import jax import jax.numpy as jnp @@ -6,7 +6,8 @@ import jax.tree_util as jtu from typing_extensions import Literal -from ott.geometry import pointcloud +if TYPE_CHECKING: + from ott.geometry import pointcloud __all__ = ["DualPotentials", "EntropicPotentials"] Potential_t = Callable[[jnp.ndarray], float] @@ -130,7 +131,7 @@ class EntropicPotentials(DualPotentials): """ def __init__( - self, f: jnp.ndarray, g: jnp.ndarray, geom: pointcloud.PointCloud, + self, f: jnp.ndarray, g: jnp.ndarray, geom: "pointcloud.PointCloud", a: jnp.ndarray, b: jnp.ndarray ): n, m = 
geom.shape @@ -157,6 +158,7 @@ def g(self) -> Potential_t: def _create_potential_function( self, *, kind: Literal["f", "g"] ) -> Potential_t: + from ott.geometry import pointcloud def callback(x: jnp.ndarray) -> float: cost = pointcloud.PointCloud( @@ -164,7 +166,7 @@ def callback(x: jnp.ndarray) -> float: y, cost_fn=self._geom.cost_fn, power=self._geom.power, - epsilon=1.0 # epsilon is not used + epsilon=1.0 # epsilon is not used ).cost_matrix return -eps * jsp.special.logsumexp((potential - cost) / eps, b=prob_weights) diff --git a/ott/math/utils.py b/ott/math/utils.py index 6b1398aad..3c83f8e82 100644 --- a/ott/math/utils.py +++ b/ott/math/utils.py @@ -10,6 +10,7 @@ "barycentric_projection" ] +# TODO(michalk8): move to typing.py when refactoring types Sparse_t = Union[jesp.CSR, jesp.CSC, jesp.COO, jesp.BCOO] diff --git a/ott/problems/linear/__init__.py b/ott/problems/linear/__init__.py index 1ab3d4ccb..1681cefbb 100644 --- a/ott/problems/linear/__init__.py +++ b/ott/problems/linear/__init__.py @@ -1,2 +1 @@ -from .barycenter_problem import BarycenterProblem -from .linear_problem import LinearProblem +from . import barycenter_problem, linear_problem diff --git a/ott/problems/linear/linear_problem.py b/ott/problems/linear/linear_problem.py index 97c54d4e8..6f7aad57d 100644 --- a/ott/problems/linear/linear_problem.py +++ b/ott/problems/linear/linear_problem.py @@ -22,6 +22,7 @@ __all__ = ["LinearProblem"] +# TODO(michalk8): move to typing.py when refactoring the types MarginalFunc = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] TransportAppFunc = Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray, int], jnp.ndarray] diff --git a/ott/problems/quadratic/__init__.py b/ott/problems/quadratic/__init__.py index 48c54ef97..44f37cc34 100644 --- a/ott/problems/quadratic/__init__.py +++ b/ott/problems/quadratic/__init__.py @@ -1,3 +1 @@ -from . 
import quadratic_costs -from .barycenter_problem import GWBarycenterProblem -from .quadratic_problem import QuadraticProblem +from . import barycenter_problem, quadratic_costs, quadratic_problem diff --git a/ott/problems/quadratic/quadratic_problem.py b/ott/problems/quadratic/quadratic_problem.py index 03d0bdbb7..e2060015d 100644 --- a/ott/problems/quadratic/quadratic_problem.py +++ b/ott/problems/quadratic/quadratic_problem.py @@ -13,7 +13,7 @@ # limitations under the License. """Classes defining OT problem(s) (objective function + utilities).""" -from typing import Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -22,9 +22,11 @@ from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud from ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_costs -from ott.solvers.linear import sinkhorn_lr from ott.typing import Transport +if TYPE_CHECKING: + from ott.solvers.linear import sinkhorn_lr + __all__ = ["QuadraticProblem"] @@ -238,7 +240,7 @@ def init_transport_mass(self) -> float: return a.sum() * b.sum() def update_lr_geom( - self, lr_sink: sinkhorn_lr.LRSinkhornOutput + self, lr_sink: 'sinkhorn_lr.LRSinkhornOutput' ) -> geometry.Geometry: """Recompute (possibly LRC) linearization using LR Sinkhorn output.""" marginal_1 = lr_sink.marginal(1) @@ -325,7 +327,7 @@ def update_linearization( ) def update_lr_linearization( - self, lr_sink: sinkhorn_lr.LRSinkhornOutput + self, lr_sink: 'sinkhorn_lr.LRSinkhornOutput' ) -> linear_problem.LinearProblem: """Update a Quad problem linearization using a LR Sinkhorn.""" return linear_problem.LinearProblem( diff --git a/ott/solvers/linear/__init__.py b/ott/solvers/linear/__init__.py index e69de29bb..7f9d3a1bb 100644 --- a/ott/solvers/linear/__init__.py +++ b/ott/solvers/linear/__init__.py @@ -0,0 +1,7 @@ +from . 
import ( + acceleration, + continuous_barycenter, + discrete_barycenter, + sinkhorn, + sinkhorn_lr, +) diff --git a/ott/solvers/linear/acceleration.py b/ott/solvers/linear/acceleration.py index 8b638abc9..9684f2956 100644 --- a/ott/solvers/linear/acceleration.py +++ b/ott/solvers/linear/acceleration.py @@ -38,7 +38,7 @@ def extrapolation(self, xs: jnp.ndarray, fxs: jnp.ndarray) -> jnp.ndarray: # Recover linear combination and return it with NaN (caused # by 0 weights leading to -jnp.inf potentials, mixed with weights - # coefficiences of different signs), disambiguated to -inf. + # coefficients of different signs), disambiguated to -inf. combination = jnp.sum(fxs * weights[None, :], axis=1) return jnp.where(jnp.isfinite(combination), combination, -jnp.inf) diff --git a/ott/solvers/linear/continuous_barycenter.py b/ott/solvers/linear/continuous_barycenter.py index 5bc105321..e1d4a96e2 100644 --- a/ott/solvers/linear/continuous_barycenter.py +++ b/ott/solvers/linear/continuous_barycenter.py @@ -39,8 +39,6 @@ class BarycenterState(NamedTuple): inner Sinkhorn iterations. errors: Holds sequence of vectors of errors of the Sinkhorn algorithm at each iteration. - linear_states: State used to solve and store solutions to the OT problems - from the barycenter to the measures. x: barycenter points. a: barycenter weights. """ diff --git a/ott/solvers/linear/discrete_barycenter.py b/ott/solvers/linear/discrete_barycenter.py index 6373c9657..044295129 100644 --- a/ott/solvers/linear/discrete_barycenter.py +++ b/ott/solvers/linear/discrete_barycenter.py @@ -35,7 +35,7 @@ class SinkhornBarycenterOutput(NamedTuple): errors: jnp.ndarray -# TODO(michalk8): refactor as a solver +# TODO(michalk8): refactor as a solver? 
def discrete_barycenter( geom: geometry.Geometry, a: jnp.ndarray, diff --git a/ott/solvers/linear/sinkhorn.py b/ott/solvers/linear/sinkhorn.py index 28f41657d..b8201d6e5 100644 --- a/ott/solvers/linear/sinkhorn.py +++ b/ott/solvers/linear/sinkhorn.py @@ -27,11 +27,10 @@ from ott.math import implicit_differentiation as implicit_lib from ott.math import potentials, unbalanced_functions from ott.problems.linear import linear_problem +from ott.solvers.linear import acceleration __all__ = ["Sinkhorn", "SinkhornOutput"] -from ott.solvers.linear import acceleration - class SinkhornState(NamedTuple): """Holds the state variables used to solve OT with Sinkhorn.""" diff --git a/ott/solvers/nn/__init__.py b/ott/solvers/nn/__init__.py index e6c465331..a695e215c 100644 --- a/ott/solvers/nn/__init__.py +++ b/ott/solvers/nn/__init__.py @@ -1 +1 @@ -# TODO(michalk8): imports +from . import icnn, layers, neuraldual diff --git a/ott/solvers/nn/neuraldual.py b/ott/solvers/nn/neuraldual.py index b7b773164..dcde36866 100644 --- a/ott/solvers/nn/neuraldual.py +++ b/ott/solvers/nn/neuraldual.py @@ -28,8 +28,7 @@ __all__ = ["NeuralDualSolver"] -Train_t = Dict[Literal["training_logs", "validation_logs"], List[float]] -Potentials_t = potentials.DualPotentials +Train_t = Dict[Literal["train_logs", "valid_logs"], Dict[str, List[float]]] class NeuralDualSolver: @@ -139,7 +138,8 @@ def __call__( trainloader_target: Iterator[jnp.ndarray], validloader_source: Iterator[jnp.ndarray], validloader_target: Iterator[jnp.ndarray], - ) -> Union[Potentials_t, Tuple[Potentials_t, Train_t]]: + ) -> Union[potentials.DualPotentials, Tuple[potentials.DualPotentials, + Train_t]]: logs = self.train_neuraldual( trainloader_source, trainloader_target, @@ -152,10 +152,10 @@ def __call__( def train_neuraldual( self, - trainloader_source, - trainloader_target, - validloader_source, - validloader_target, + trainloader_source: Iterator[jnp.ndarray], + trainloader_target: Iterator[jnp.ndarray], + validloader_source: 
Iterator[jnp.ndarray], + validloader_target: Iterator[jnp.ndarray], ) -> Train_t: """Implementation of the training and validation script.""" # noqa: D401 try: diff --git a/ott/solvers/quadratic/__init__.py b/ott/solvers/quadratic/__init__.py index e69de29bb..af9e5d01e 100644 --- a/ott/solvers/quadratic/__init__.py +++ b/ott/solvers/quadratic/__init__.py @@ -0,0 +1 @@ +from . import gromov_wasserstein, gw_barycenter diff --git a/ott/solvers/quadratic/gromov_wasserstein.py b/ott/solvers/quadratic/gromov_wasserstein.py index 98c3a0af6..805354fa1 100644 --- a/ott/solvers/quadratic/gromov_wasserstein.py +++ b/ott/solvers/quadratic/gromov_wasserstein.py @@ -29,6 +29,8 @@ from ott.solvers.linear import sinkhorn, sinkhorn_lr from ott.utils import was_solver +__all__ = ["GWOutput", "GromovWasserstein", "gromov_wasserstein"] + LinearOutput = Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput] diff --git a/ott/utils/__init__.py b/ott/utils/__init__.py index e69de29bb..3f5072c03 100644 --- a/ott/utils/__init__.py +++ b/ott/utils/__init__.py @@ -0,0 +1 @@ +from . import dataclasses, segment, was_solver diff --git a/ott/utils/dataclasses.py b/ott/utils/dataclasses.py index 78aebac51..5f3de2ef6 100644 --- a/ott/utils/dataclasses.py +++ b/ott/utils/dataclasses.py @@ -17,6 +17,8 @@ import jax +__all__ = ["register_pytree_node"] + def register_pytree_node(cls: type) -> type: """Register dataclasses as pytree_nodes.""" diff --git a/ott/utils/segment.py b/ott/utils/segment.py index 82c4c98d3..2e15db821 100644 --- a/ott/utils/segment.py +++ b/ott/utils/segment.py @@ -16,6 +16,8 @@ import jax from jax import numpy as jnp +__all__ = ["segment_point_cloud"] + def segment_point_cloud( x: jnp.ndarray, diff --git a/ott/utils/was_solver.py b/ott/utils/was_solver.py index 7e81c4377..2ff7789d9 100644 --- a/ott/utils/was_solver.py +++ b/ott/utils/was_solver.py @@ -14,15 +14,18 @@ # Lint as: python3 """A Jax version of the regularised GW Solver (Peyre et al. 
2016).""" -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union import jax import jax.numpy as jnp -from ott.solvers.linear import sinkhorn, sinkhorn_lr +if TYPE_CHECKING: + from ott.solvers.linear import continuous_barycenter, sinkhorn, sinkhorn_lr -State = Union[sinkhorn.SinkhornState, sinkhorn_lr.LRSinkhornState, - "continuous_barycenter.BarycenterState"] # noqa: F821 +__all__ = ["WassersteinSolver"] + +State = Union["sinkhorn.SinkhornState", "sinkhorn_lr.LRSinkhornState", + "continuous_barycenter.BarycenterState"] @jax.tree_util.register_pytree_node_class @@ -33,8 +36,8 @@ def __init__( self, epsilon: Optional[float] = None, rank: int = -1, - linear_ot_solver: Optional[Union[sinkhorn.Sinkhorn, - sinkhorn_lr.LRSinkhorn]] = None, + linear_ot_solver: Optional[Union["sinkhorn.Sinkhorn", + "sinkhorn_lr.LRSinkhorn"]] = None, min_iterations: int = 5, max_iterations: int = 50, threshold: float = 1e-3, @@ -42,6 +45,7 @@ def __init__( store_inner_errors: bool = False, **kwargs: Any, ): + from ott.solvers.linear import sinkhorn, sinkhorn_lr default_epsilon = 1.0 # Set epsilon value to default if needed, but keep track of whether None was # passed to handle the case where a linear_ot_solver is passed or not. 
From 0d30656cc2001cf5a3569799c1cf6d0d32f857eb Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 16 Nov 2022 14:28:29 +0100 Subject: [PATCH 04/34] Fix more types, pet-peeves --- ott/problems/quadratic/__init__.py | 2 +- ...{barycenter_problem.py => gw_barycenter.py} | 2 +- ott/solvers/nn/neuraldual.py | 18 +++++++++--------- ott/solvers/quadratic/gw_barycenter.py | 13 ++++++------- ott/tools/gaussian_mixture/fit_gmm.py | 15 +++++++-------- ott/tools/segment_sinkhorn.py | 4 ++-- ott/tools/sinkhorn_divergence.py | 4 ++-- 7 files changed, 28 insertions(+), 30 deletions(-) rename ott/problems/quadratic/{barycenter_problem.py => gw_barycenter.py} (99%) diff --git a/ott/problems/quadratic/__init__.py b/ott/problems/quadratic/__init__.py index 44f37cc34..18ff1c517 100644 --- a/ott/problems/quadratic/__init__.py +++ b/ott/problems/quadratic/__init__.py @@ -1 +1 @@ -from . import barycenter_problem, quadratic_costs, quadratic_problem +from . import gw_barycenter, quadratic_costs, quadratic_problem diff --git a/ott/problems/quadratic/barycenter_problem.py b/ott/problems/quadratic/gw_barycenter.py similarity index 99% rename from ott/problems/quadratic/barycenter_problem.py rename to ott/problems/quadratic/gw_barycenter.py index 089e645de..1bfd6a28d 100644 --- a/ott/problems/quadratic/barycenter_problem.py +++ b/ott/problems/quadratic/gw_barycenter.py @@ -36,7 +36,7 @@ class GWBarycenterProblem(barycenter_problem.BarycenterProblem): Only one of ``y`` and ``cost`` can be specified. y_fused: Array of shape ``[num_total_points, ndim_fused]`` containing the data of the points of all measures used to define the linear term - in the fused case. Similarly as ``y``, can be specified as a pre-segmented + in the fused case. Same as ``y``, it can be specified as a pre-segmented array of shape ``[num_measures, max_measure_size, ndim_fused]``. gw_loss: Gromov-Wasserstein loss. fused_penalty: Multiplier of the linear term. 
Only used when diff --git a/ott/solvers/nn/neuraldual.py b/ott/solvers/nn/neuraldual.py index dcde36866..537a7e4e3 100644 --- a/ott/solvers/nn/neuraldual.py +++ b/ott/solvers/nn/neuraldual.py @@ -14,7 +14,7 @@ """A Jax implementation of the ICNN based Kantorovich dual.""" import warnings -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union import flax.linen as nn import jax @@ -134,10 +134,10 @@ def setup( def __call__( self, - trainloader_source: Iterator[jnp.ndarray], - trainloader_target: Iterator[jnp.ndarray], - validloader_source: Iterator[jnp.ndarray], - validloader_target: Iterator[jnp.ndarray], + trainloader_source: Iterable[jnp.ndarray], + trainloader_target: Iterable[jnp.ndarray], + validloader_source: Iterable[jnp.ndarray], + validloader_target: Iterable[jnp.ndarray], ) -> Union[potentials.DualPotentials, Tuple[potentials.DualPotentials, Train_t]]: logs = self.train_neuraldual( @@ -152,10 +152,10 @@ def __call__( def train_neuraldual( self, - trainloader_source: Iterator[jnp.ndarray], - trainloader_target: Iterator[jnp.ndarray], - validloader_source: Iterator[jnp.ndarray], - validloader_target: Iterator[jnp.ndarray], + trainloader_source: Iterable[jnp.ndarray], + trainloader_target: Iterable[jnp.ndarray], + validloader_source: Iterable[jnp.ndarray], + validloader_target: Iterable[jnp.ndarray], ) -> Train_t: """Implementation of the training and validation script.""" # noqa: D401 try: diff --git a/ott/solvers/quadratic/gw_barycenter.py b/ott/solvers/quadratic/gw_barycenter.py index 8020f99e3..528f676c4 100644 --- a/ott/solvers/quadratic/gw_barycenter.py +++ b/ott/solvers/quadratic/gw_barycenter.py @@ -7,7 +7,7 @@ from ott.geometry import pointcloud from ott.math import fixed_point_loop from ott.problems.linear import linear_problem -from ott.problems.quadratic import barycenter_problem +from ott.problems.quadratic import gw_barycenter from ott.solvers.quadratic import 
gromov_wasserstein from ott.utils import was_solver @@ -91,7 +91,7 @@ def __init__( self._quad_solver = gromov_wasserstein.GromovWasserstein(**kwargs) def __call__( - self, problem: barycenter_problem.GWBarycenterProblem, bar_size: int, + self, problem: gw_barycenter.GWBarycenterProblem, bar_size: int, **kwargs: Any ) -> GWBarycenterState: """Solver the (fused) GW barycenter problem. @@ -111,7 +111,7 @@ def __call__( def init_state( self, - problem: barycenter_problem.GWBarycenterProblem, + problem: gw_barycenter.GWBarycenterProblem, bar_size: int, bar_init: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, @@ -186,7 +186,7 @@ def update_state( self, state: GWBarycenterState, iteration: int, - problem: barycenter_problem.GWBarycenterProblem, + problem: gw_barycenter.GWBarycenterProblem, store_errors: bool = True, ) -> Tuple[float, bool, jnp.ndarray, Optional[jnp.ndarray]]: @@ -283,8 +283,7 @@ def init_transports( def iterations( solver: GromovWassersteinBarycenter, - problem: barycenter_problem.GWBarycenterProblem, - init_state: GWBarycenterState + problem: gw_barycenter.GWBarycenterProblem, init_state: GWBarycenterState ) -> GWBarycenterState: def cond_fn( @@ -296,7 +295,7 @@ def cond_fn( def body_fn( iteration, constants: Tuple[GromovWassersteinBarycenter, - barycenter_problem.GWBarycenterProblem], + gw_barycenter.GWBarycenterProblem], state: GWBarycenterState, compute_error: bool ) -> GWBarycenterState: del compute_error # always assumed true diff --git a/ott/tools/gaussian_mixture/fit_gmm.py b/ott/tools/gaussian_mixture/fit_gmm.py index d09cf6288..112ac44d7 100644 --- a/ott/tools/gaussian_mixture/fit_gmm.py +++ b/ott/tools/gaussian_mixture/fit_gmm.py @@ -185,7 +185,7 @@ def fit_model_em( def _get_dist_sq(points: jnp.ndarray, loc: jnp.ndarray) -> jnp.ndarray: """Get the squared distance from each point to each loc.""" - def _dist_sq_one_loc(points: jnp.ndarray, loc: jnp.ndarray): + def _dist_sq_one_loc(points: jnp.ndarray, loc: jnp.ndarray) 
-> jnp.ndarray: return jnp.sum((points - loc[None]) ** 2., axis=-1) dist_sq_fn = jax.vmap(_dist_sq_one_loc, in_axes=(None, 0), out_axes=1) @@ -266,9 +266,9 @@ def initialize( key: jnp.ndarray, points: jnp.ndarray, point_weights: Optional[jnp.ndarray], - n_components: jnp.ndarray, - n_attempts=50, - verbose=False + n_components: int, + n_attempts: int = 50, + verbose: bool = False ) -> gaussian_mixture.GaussianMixture: """Initialize a GMM via K-means++ with retries on failure. @@ -289,7 +289,7 @@ def initialize( for attempt in range(n_attempts): key, subkey = jax.random.split(key) try: - gmm = from_kmeans_plusplus( + return from_kmeans_plusplus( key=subkey, points=points, point_weights=point_weights, @@ -297,6 +297,5 @@ def initialize( ) except ValueError: if verbose: - print(f'Failed to initialize, attempt {attempt}', flush=True) - return gmm - raise ValueError('Failed to initialize') + print(f'Failed to initialize, attempt {attempt}.', flush=True) + raise ValueError('Failed to initialize.') diff --git a/ott/tools/segment_sinkhorn.py b/ott/tools/segment_sinkhorn.py index 4a5427c6e..7f4afeb60 100644 --- a/ott/tools/segment_sinkhorn.py +++ b/ott/tools/segment_sinkhorn.py @@ -31,8 +31,8 @@ def segment_sinkhorn( segment_ids_x: Optional[jnp.ndarray] = None, segment_ids_y: Optional[jnp.ndarray] = None, indices_are_sorted: Optional[bool] = None, - num_per_segment_x: Tuple[int] = None, - num_per_segment_y: Tuple[int] = None, + num_per_segment_x: Optional[Tuple[int, ...]] = None, + num_per_segment_y: Optional[Tuple[int, ...]] = None, weights_x: Optional[jnp.ndarray] = None, weights_y: Optional[jnp.ndarray] = None, sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), diff --git a/ott/tools/sinkhorn_divergence.py b/ott/tools/sinkhorn_divergence.py index 46b196339..119f7be5d 100644 --- a/ott/tools/sinkhorn_divergence.py +++ b/ott/tools/sinkhorn_divergence.py @@ -192,8 +192,8 @@ def segment_sinkhorn_divergence( segment_ids_x: Optional[jnp.ndarray] = None, segment_ids_y: 
Optional[jnp.ndarray] = None, indices_are_sorted: Optional[bool] = None, - num_per_segment_x: Optional[jnp.ndarray] = None, - num_per_segment_y: Optional[jnp.ndarray] = None, + num_per_segment_x: Optional[Tuple[int, ...]] = None, + num_per_segment_y: Optional[Tuple[int, ...]] = None, weights_x: Optional[jnp.ndarray] = None, weights_y: Optional[jnp.ndarray] = None, sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), From da2c97abe592984a8d05271a727fe39bce7f32d0 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 16 Nov 2022 14:30:09 +0100 Subject: [PATCH 05/34] Fix tests --- tests/core/continuous_barycenter_test.py | 56 +++++++------ tests/core/discrete_barycenter_test.py | 2 +- tests/core/fused_gromov_wasserstein_test.py | 39 ++++----- tests/core/gromov_wasserstein_test.py | 13 +-- tests/core/icnn_test.py | 18 ++-- tests/core/initializers_test.py | 84 +++++++++---------- tests/core/neuraldual_test.py | 6 +- tests/core/potentials_test.py | 19 +++-- tests/core/sinkhorn_diff_test.py | 7 +- tests/core/sinkhorn_extra_test.py | 6 +- tests/core/sinkhorn_grid_test.py | 2 +- tests/core/sinkhorn_lr_test.py | 7 +- tests/core/sinkhorn_test.py | 5 +- tests/geometry/geometry_lse_test.py | 4 +- tests/geometry/graph_test.py | 11 +-- tests/geometry/matrix_square_root_test.py | 4 +- tests/geometry/scaling_cost_test.py | 5 +- .../tools/gaussian_mixture/scale_tril_test.py | 2 +- tests/tools/segment_sinkhorn_test.py | 2 +- tests/tools/sinkhorn_divergence_test.py | 2 +- tests/tools/transport_test.py | 4 +- 21 files changed, 154 insertions(+), 144 deletions(-) diff --git a/tests/core/continuous_barycenter_test.py b/tests/core/continuous_barycenter_test.py index 5c862a47f..6ff16784e 100644 --- a/tests/core/continuous_barycenter_test.py +++ b/tests/core/continuous_barycenter_test.py @@ -12,7 +12,7 @@ # limitations under the License. 
# Lint as: python3 -"""Tests for Continuous barycenters.""" +"""Tests for continuous barycenter.""" import functools from typing import Any, Optional, Sequence, Tuple @@ -21,9 +21,13 @@ import numpy as np import pytest -from ott.core import bar_problems, continuous_barycenter, gw_barycenter, segment from ott.geometry import costs, pointcloud +from ott.problems.linear import barycenter_problem +from ott.problems.quadratic import gw_barycenter as gwb +from ott.solvers.linear import continuous_barycenter as cb +from ott.solvers.quadratic import gw_barycenter as gwb_solver from ott.tools.gaussian_mixture import gaussian_mixture +from ott.utils import segment means_and_covs_to_x = jax.vmap(costs.mean_and_cov_to_x, in_axes=[0, 0, None]) @@ -70,7 +74,7 @@ def test_euclidean_barycenter( b.append(c / jnp.sum(c)) b = jnp.concatenate(b, axis=0) # Set a barycenter problem with 8 measures, of irregular sizes. - bar_prob = bar_problems.BarycenterProblem( + bar_prob = barycenter_problem.BarycenterProblem( y, b, epsilon=epsilon, @@ -84,9 +88,7 @@ def test_euclidean_barycenter( # Define solver threshold = 1e-3 - solver = continuous_barycenter.WassersteinBarycenter( - rank=rank, threshold=threshold, jit=jit - ) + solver = cb.WassersteinBarycenter(rank=rank, threshold=threshold, jit=jit) # Set barycenter size to 31. 
bar_size = 31 @@ -119,19 +121,21 @@ def test_barycenter_jit(self, rng: jnp.ndarray, segment_before: bool): @functools.partial(jax.jit, static_argnums=(2, 3)) def barycenter( - y: jnp.ndarray, b: jnp.ndarray, segment_before: bool, - num_per_segment: int - ) -> continuous_barycenter.BarycenterState: + y: jnp.ndarray, + b: jnp.ndarray, + segment_before: bool, + num_per_segment: Tuple[int, ...], + ) -> cb.BarycenterState: if segment_before: y, b = segment.segment_point_cloud( x=y, a=b, num_per_segment=num_per_segment ) - bar_prob = bar_problems.BarycenterProblem(y, b, epsilon=1e-1) + bar_prob = barycenter_problem.BarycenterProblem(y, b, epsilon=1e-1) else: - bar_prob = bar_problems.BarycenterProblem( + bar_prob = barycenter_problem.BarycenterProblem( y, b, epsilon=1e-1, num_per_segment=num_per_segment ) - solver = continuous_barycenter.WassersteinBarycenter(threshold=threshold) + solver = cb.WassersteinBarycenter(threshold=threshold) return solver(bar_prob) rngs = jax.random.split(rng, 20) @@ -220,7 +224,7 @@ def test_bures_barycenter( num_per_segment=(num_components, num_components), padding_vector=bures_cost.padder(y.shape[1]), ) - bar_p = bar_problems.BarycenterProblem( + bar_p = barycenter_problem.BarycenterProblem( seg_y, seg_b, weights=barycentric_weights, @@ -231,9 +235,7 @@ def test_bures_barycenter( assert bar_p.max_measure_size == seg_y.shape[1] assert bar_p.ndim == seg_y.shape[2] - solver = continuous_barycenter.WassersteinBarycenter( - lse_mode=lse_mode, jit=jit - ) + solver = cb.WassersteinBarycenter(lse_mode=lse_mode, jit=jit) out = solver(bar_p, bar_size=bar_size, x_init=x_init) barycenter = out.x @@ -318,17 +320,17 @@ def test_bures_barycenter_different_number_of_components( for i in range(num_measures)] # positions of mass of the measures - ys = jnp.vstack( + ys = jnp.vstack([ means_and_covs_to_x(means_covs[i][0], means_covs[i][1], dim) for i in range(num_measures) - ) + ]) # mass distribution of the measures weights = [ 
gmm_generators[i].component_weight_ob.probs() for i in range(num_measures) ] - bs = jnp.hstack(jnp.array(weights[i]) for i in range(num_measures)) + bs = jnp.hstack([jnp.array(weights[i]) for i in range(num_measures)]) # random initialization of the barycenter gmm_generator = gaussian_mixture.GaussianMixture.from_random( @@ -340,7 +342,7 @@ def test_bures_barycenter_different_number_of_components( # test second interface for segmentation seg_ids = jnp.repeat(jnp.arange(num_measures), n_components) - bar_p = bar_problems.BarycenterProblem( + bar_p = barycenter_problem.BarycenterProblem( y=ys, b=bs, weights=barycentric_weights, @@ -354,7 +356,7 @@ def test_bures_barycenter_different_number_of_components( assert bar_p.num_measures == num_measures assert bar_p.ndim == ys.shape[-1] - solver = continuous_barycenter.WassersteinBarycenter(lse_mode=True, jit=jit) + solver = cb.WassersteinBarycenter(lse_mode=True, jit=jit) # Compute the barycenter. out = solver(bar_p, bar_size=bar_size, x_init=x_init) @@ -439,8 +441,8 @@ def test_gw_barycenter( "epsilon": epsilon } - problem_pc = bar_problems.GWBarycenterProblem(y=ys, b=bs, **kwargs) - problem_cost = bar_problems.GWBarycenterProblem( + problem_pc = gwb.GWBarycenterProblem(y=ys, b=bs, **kwargs) + problem_cost = gwb.GWBarycenterProblem( costs=costs, b=cbs, **kwargs, @@ -454,7 +456,7 @@ def test_gw_barycenter( assert problem_pc.ndim == self.NDIM assert problem_cost.ndim is None - solver = gw_barycenter.GromovWassersteinBarycenter(jit=True) + solver = gwb_solver.GromovWassersteinBarycenter(jit=True) out_pc = solver(problem_pc, bar_size=bar_size) out_cost = solver(problem_cost, bar_size=bar_size) @@ -479,8 +481,8 @@ def test_fgw_barycenter( def barycenter( y: jnp.ndim, y_fused: jnp.ndarray, num_per_segment: Tuple[int, ...] 
- ) -> gw_barycenter.GWBarycenterState: - prob = bar_problems.GWBarycenterProblem( + ) -> gwb_solver.GWBarycenterState: + prob = gwb.GWBarycenterProblem( y=y, y_fused=y_fused, num_per_segment=num_per_segment, @@ -495,7 +497,7 @@ def barycenter( assert prob.ndim == self.NDIM assert prob.ndim_fused == self.NDIM_F - solver = gw_barycenter.GromovWassersteinBarycenter( + solver = gwb_solver.GromovWassersteinBarycenter( jit=False, store_inner_errors=True ) diff --git a/tests/core/discrete_barycenter_test.py b/tests/core/discrete_barycenter_test.py index c17d031d5..2b09ac991 100644 --- a/tests/core/discrete_barycenter_test.py +++ b/tests/core/discrete_barycenter_test.py @@ -18,8 +18,8 @@ import jax.numpy as jnp import pytest -from ott.core import discrete_barycenter as db from ott.geometry import grid, pointcloud +from ott.solvers.linear import discrete_barycenter as db class TestDiscreteBarycenter: diff --git a/tests/core/fused_gromov_wasserstein_test.py b/tests/core/fused_gromov_wasserstein_test.py index a2333c96e..ebd15e7de 100644 --- a/tests/core/fused_gromov_wasserstein_test.py +++ b/tests/core/fused_gromov_wasserstein_test.py @@ -21,8 +21,9 @@ import numpy as np import pytest -from ott.core import gromov_wasserstein, quad_problems from ott.geometry import geometry, low_rank, pointcloud +from ott.problems.quadratic import quadratic_problem +from ott.solvers.quadratic import gromov_wasserstein as gwb_solver class TestFusedGromovWasserstein: @@ -55,7 +56,7 @@ def test_flag_store_errors_fused(self): geom_x = pointcloud.PointCloud(self.x) geom_y = pointcloud.PointCloud(self.y) geom_xy = pointcloud.PointCloud(self.x_2, self.y_2) - out = gromov_wasserstein.gromov_wasserstein( + out = gwb_solver.gromov_wasserstein( geom_xx=geom_x, geom_yy=geom_y, geom_xy=geom_xy, @@ -66,7 +67,7 @@ def test_flag_store_errors_fused(self): ).errors assert out is None - out = gromov_wasserstein.gromov_wasserstein( + out = gwb_solver.gromov_wasserstein( geom_xx=geom_x, geom_yy=geom_y, 
geom_xy=geom_xy, @@ -86,7 +87,7 @@ def test_flag_store_errors_fused(self): assert out.ndim == 2 @pytest.mark.fast.with_args(jit=[False, True], only_fast=1) - def test_gradient_marginals_fused_gromov_wasserstein(self, jit: bool): + def test_gradient_marginals_fused_gwb_solver(self, jit: bool): """Test gradient w.r.t. probability weights.""" geom_x = pointcloud.PointCloud(self.x) geom_y = pointcloud.PointCloud(self.y) @@ -98,7 +99,7 @@ def reg_gw(a, b, implicit): 'implicit_differentiation': implicit, 'max_iterations': 1001 } - out = gromov_wasserstein.gromov_wasserstein( + out = gwb_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -136,14 +137,14 @@ def reg_gw(a, b, implicit): ) @pytest.mark.fast.with_args(lse_mode=[False, True], only_fast=1) - def test_fused_gromov_wasserstein_pointcloud(self, lse_mode: bool): + def test_fused_gwb_solver_pointcloud(self, lse_mode: bool): """Test basic computations pointclouds.""" def reg_gw(x, y, x_2, y_2, fused_penalty, a, b): geom_x = pointcloud.PointCloud(x) geom_y = pointcloud.PointCloud(y) geom_xy = pointcloud.PointCloud(x_2, y_2) - return gromov_wasserstein.gromov_wasserstein( + return gwb_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -163,7 +164,7 @@ def reg_gw(x, y, x_2, y_2, fused_penalty, a, b): assert cost is not None @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_fused_gromov_wasserstein_pointcloud(self, lse_mode: bool): + def test_gradient_fused_gwb_solver_pointcloud(self, lse_mode: bool): """Test gradient w.r.t. 
pointclouds.""" def reg_gw(x, y, x_2, y_2, fused_penalty, a, b, implicit): @@ -175,7 +176,7 @@ def reg_gw(x, y, x_2, y_2, fused_penalty, a, b, implicit): 'max_iterations': 1001, 'lse_mode': lse_mode } - return gromov_wasserstein.gromov_wasserstein( + return gwb_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -206,7 +207,7 @@ def reg_gw(x, y, x_2, y_2, fused_penalty, a, b, implicit): ) @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_fused_gromov_wasserstein_geometry(self, lse_mode: bool): + def test_gradient_fused_gwb_solver_geometry(self, lse_mode: bool): """Test gradient w.r.t. cost matrices.""" def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): @@ -218,7 +219,7 @@ def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): 'max_iterations': 1001, 'lse_mode': lse_mode } - return gromov_wasserstein.gromov_wasserstein( + return gwb_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -259,7 +260,7 @@ def test_adaptive_threshold_fused(self): # without warm start for calls to sinkhorn def loss_thre(threshold: float) -> float: - return gromov_wasserstein.gromov_wasserstein( + return gwb_solver.gromov_wasserstein( geom_xx=geom_x, geom_yy=geom_y, geom_xy=geom_xy, @@ -274,7 +275,7 @@ def loss_thre(threshold: float) -> float: assert loss_thre(1e-3) > loss_thre(1e-5) @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_fused_gromov_wasserstein_penalty(self, lse_mode: bool): + def test_gradient_fused_gwb_solver_penalty(self, lse_mode: bool): """Test gradient w.r.t. 
penalty.""" def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): @@ -286,7 +287,7 @@ def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): 'max_iterations': 1001, 'lse_mode': lse_mode } - return gromov_wasserstein.gromov_wasserstein( + return gwb_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -318,7 +319,7 @@ def reg_fgw(x, y, x_2, y_2, fused_penalty, a, b): geom_y = pointcloud.PointCloud(y) geom_xy = pointcloud.PointCloud(x_2, y_2) sinkhorn_kwargs = {'max_iterations': 1001} - return gromov_wasserstein.gromov_wasserstein( + return gwb_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -333,7 +334,7 @@ def reg_gw(x, y, a, b): geom_x = pointcloud.PointCloud(x) geom_y = pointcloud.PointCloud(y) sinkhorn_kwargs = {'max_iterations': 1001} - return gromov_wasserstein.gromov_wasserstein( + return gwb_solver.gromov_wasserstein( geom_x, geom_y, a=a, @@ -366,7 +367,7 @@ def test_fgw_lr_memory(self, rng: jnp.ndarray, jit: bool): geom_y = pointcloud.PointCloud(y) geom_xy = pointcloud.PointCloud(xx, yy) - ot_gwlr = gromov_wasserstein.gromov_wasserstein( + ot_gwlr = gwb_solver.gromov_wasserstein( geom_x, geom_y, geom_xy, rank=5, jit=jit ) res0 = ot_gwlr.apply(x.T, axis=0) @@ -391,14 +392,14 @@ def test_fgw_lr_generic_cost_matrix( geom_y = geometry.Geometry(cost_matrix=y @ y.T) geom_xy = geometry.Geometry(cost_matrix=xx @ yy.T) - problem = quad_problems.QuadraticProblem( + problem = quadratic_problem.QuadraticProblem( geom_x, geom_y, geom_xy, ranks=cost_rank, tolerances=5e-1 ) assert problem._is_low_rank_convertible lr_prob = problem.to_low_rank() assert lr_prob.is_low_rank - solver = gromov_wasserstein.GromovWasserstein(rank=5, epsilon=1) + solver = gwb_solver.GromovWasserstein(rank=5, epsilon=1) out = solver(problem) assert solver.rank == 5 diff --git a/tests/core/gromov_wasserstein_test.py b/tests/core/gromov_wasserstein_test.py index 0ded70d48..27e7f58b8 100644 --- a/tests/core/gromov_wasserstein_test.py +++ 
b/tests/core/gromov_wasserstein_test.py @@ -21,8 +21,9 @@ import numpy as np import pytest -from ott.core import gromov_wasserstein, quad_problems from ott.geometry import geometry, low_rank, pointcloud +from ott.problems.quadratic import quadratic_problem +from ott.solvers.quadratic import gromov_wasserstein @pytest.mark.fast @@ -48,7 +49,9 @@ def test_quad_to_low_rank( geom_yy = geometry.Geometry(geom_yy.cost_matrix) geom_xy = geometry.Geometry(geom_xy.cost_matrix) - prob = quad_problems.QuadraticProblem(geom_xx, geom_yy, geom_xy, ranks=rank) + prob = quadratic_problem.QuadraticProblem( + geom_xx, geom_yy, geom_xy, ranks=rank + ) assert not prob.is_low_rank # point clouds are always converted, if possible @@ -94,7 +97,7 @@ def test_implicit_conversion_mixed_input(self, rng: jnp.ndarray): geom_xx = pointcloud.PointCloud(x) geom_yy = pointcloud.PointCloud(y).to_LRCGeometry() - prob = quad_problems.QuadraticProblem(geom_xx, geom_yy, ranks=-1) + prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy, ranks=-1) lr_prob = prob.to_low_rank() assert prob._is_low_rank_convertible @@ -325,7 +328,7 @@ def test_gw_lr(self, rng: jnp.ndarray): geom_xx = pointcloud.PointCloud(x) geom_yy = pointcloud.PointCloud(y) - prob = quad_problems.QuadraticProblem(geom_xx, geom_yy, a=a, b=b) + prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy, a=a, b=b) solver = gromov_wasserstein.GromovWasserstein(rank=5, epsilon=0.2) ot_gwlr = solver(prob) solver = gromov_wasserstein.GromovWasserstein(epsilon=0.2) @@ -347,7 +350,7 @@ def test_gw_lr_matches_fused(self, rng: jnp.ndarray): geom_xx = pointcloud.PointCloud(x) geom_yy = pointcloud.PointCloud(y) geom_xy = pointcloud.PointCloud(x, z) # only used to compute n x m matrix - prob = quad_problems.QuadraticProblem( + prob = quadratic_problem.QuadraticProblem( geom_xx, geom_yy, geom_xy=geom_xy, fused_penalty=1.3, a=a, b=b ) solver = gromov_wasserstein.GromovWasserstein(rank=6) diff --git a/tests/core/icnn_test.py b/tests/core/icnn_test.py 
index 43b5be23c..b45d2c521 100644 --- a/tests/core/icnn_test.py +++ b/tests/core/icnn_test.py @@ -20,7 +20,7 @@ import numpy as np import pytest -from ott.core.icnn import ICNN +from ott.solvers.nn import icnn @pytest.mark.fast @@ -32,22 +32,22 @@ def test_icnn_convexity(self, rng: jnp.ndarray): dim_hidden = (64, 64) # define icnn model - icnn = ICNN(dim_hidden) + model = icnn.ICNN(dim_hidden) # initialize model key1, key2, key3 = jax.random.split(rng, 3) - params = icnn.init(key1, jnp.ones(n_features))['params'] + params = model.init(key1, jnp.ones(n_features))['params'] # check convexity x = jax.random.normal(key1, (n_samples, n_features)) * 0.1 y = jax.random.normal(key2, (n_samples, n_features)) - out_x = icnn.apply({'params': params}, x) - out_y = icnn.apply({'params': params}, y) + out_x = model.apply({'params': params}, x) + out_y = model.apply({'params': params}, y) out = list() for t in jnp.linspace(0, 1): - out_xy = icnn.apply({'params': params}, t * x + (1 - t) * y) + out_xy = model.apply({'params': params}, t * x + (1 - t) * y) out.append((t * out_x + (1 - t) * out_y) - out_xy) np.testing.assert_array_equal(jnp.asarray(out) >= 0, True) @@ -58,17 +58,17 @@ def test_icnn_hessian(self, rng: jnp.ndarray): # define icnn model n_samples = 2 dim_hidden = (64, 64) - icnn = ICNN(dim_hidden) + model = icnn.ICNN(dim_hidden) # initialize model key1, key2 = jax.random.split(rng) - params = icnn.init(key1, jnp.ones(n_samples))['params'] + params = model.init(key1, jnp.ones(n_samples))['params'] # check if Hessian is positive-semidefinite via eigenvalues data = jax.random.normal(key2, (n_samples,)) # compute Hessian - hessian = jax.jacfwd(jax.jacrev(icnn.apply, argnums=1), argnums=1) + hessian = jax.jacfwd(jax.jacrev(model.apply, argnums=1), argnums=1) icnn_hess = hessian({'params': params}, data) # compute eigenvalues diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py index 01df6ad6d..47cdc10e4 100644 --- a/tests/core/initializers_test.py 
+++ b/tests/core/initializers_test.py @@ -19,17 +19,15 @@ import numpy as np import pytest -from ott.core import gromov_wasserstein -from ott.core import initializers as init_lib -from ott.core import ( - initializers_lr, - linear_problems, - quad_initializers, - quad_problems, - sinkhorn, - sinkhorn_lr, -) +import ott.initializers.nn.initializers from ott.geometry import geometry, low_rank, pointcloud +from ott.initializers.linear import initializers as lin_init +from ott.initializers.linear import initializers_lr +from ott.initializers.quadratic import initializers as quad_init +from ott.problems.linear import linear_problem +from ott.problems.quadratic import quadratic_problem +from ott.solvers.linear import sinkhorn, sinkhorn_lr +from ott.solvers.quadratic import gromov_wasserstein def create_sorting_problem(rng, n, epsilon=0.01, online=False): @@ -56,7 +54,7 @@ def create_sorting_problem(rng, n, epsilon=0.01, online=False): epsilon=epsilon, batch_size=batch_size ) - ot_problem = linear_problems.LinearProblem(geom=geom, a=a, b=b) + ot_problem = linear_problem.LinearProblem(geom=geom, a=a, b=b) return ot_problem @@ -77,7 +75,7 @@ def create_ot_problem(rng, n, m, d, epsilon=0.01, online=False): batch_size = 3 if online else None geom = pointcloud.PointCloud(x, y, epsilon=epsilon, batch_size=batch_size) - ot_problem = linear_problems.LinearProblem(geom=geom, a=a, b=b) + ot_problem = linear_problem.LinearProblem(geom=geom, a=a, b=b) return ot_problem @@ -87,7 +85,7 @@ def run_sinkhorn_sort_init( x, y, a=None, b=None, epsilon=0.01, vector_min=True, lse_mode=True ): geom = pointcloud.PointCloud(x, y, epsilon=epsilon) - sort_init = init_lib.SortingInitializer(vectorized_update=vector_min) + sort_init = lin_init.SortingInitializer(vectorized_update=vector_min) out = sinkhorn.sinkhorn( geom, a=a, b=b, jit=True, initializer=sort_init, lse_mode=lse_mode ) @@ -109,7 +107,7 @@ def run_sinkhorn_gaus_init(x, y, a=None, b=None, epsilon=0.01, lse_mode=True): a=a, b=b, jit=True, 
- initializer=init_lib.GaussianInitializer(), + initializer=lin_init.GaussianInitializer(), lse_mode=lse_mode ) return out @@ -122,12 +120,12 @@ def test_init_pytree(self): @jax.jit def init_sort(): - init = init_lib.SortingInitializer() + init = lin_init.SortingInitializer() return init @jax.jit def init_gaus(): - init = init_lib.GaussianInitializer() + init = lin_init.GaussianInitializer() return init _ = init_gaus() @@ -136,18 +134,18 @@ def init_gaus(): @pytest.mark.parametrize( "init", [ "default", "gaussian", "sorting", - init_lib.DefaultInitializer(), "non-existent" + lin_init.DefaultInitializer(), "non-existent" ] ) def test_create_initializer(self, init: str): solver = sinkhorn.Sinkhorn(initializer=init) expected_types = { - "default": init_lib.DefaultInitializer, - "gaussian": init_lib.GaussianInitializer, - "sorting": init_lib.SortingInitializer, + "default": lin_init.DefaultInitializer, + "gaussian": lin_init.GaussianInitializer, + "sorting": lin_init.SortingInitializer, } - if isinstance(init, init_lib.SinkhornInitializer): + if isinstance(init, lin_init.SinkhornInitializer): assert solver.create_initializer() is init elif init == "non-existent": with pytest.raises(NotImplementedError, match=r""): @@ -201,7 +199,7 @@ def test_sorting_init_online(self, rng: jnp.ndarray): ot_problem = create_sorting_problem( rng=rng, n=n, epsilon=epsilon, online=True ) - sort_init = init_lib.SortingInitializer(vectorized_update=True) + sort_init = lin_init.SortingInitializer(vectorized_update=True) with pytest.raises(AssertionError, match=r"online"): sort_init.init_dual_a(ot_problem, lse_mode=True) @@ -212,7 +210,7 @@ def test_sorting_init_square_cost(self, rng: jnp.ndarray): epsilon = 0.01 ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) - sort_init = init_lib.SortingInitializer(vectorized_update=True) + sort_init = lin_init.SortingInitializer(vectorized_update=True) with pytest.raises(AssertionError, match=r"square"): 
sort_init.init_dual_a(ot_problem, lse_mode=True) @@ -225,10 +223,10 @@ def test_default_initializer(self, rng: jnp.ndarray): ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) - default_potential_a = init_lib.DefaultInitializer().init_dual_a( + default_potential_a = lin_init.DefaultInitializer().init_dual_a( ot_problem, lse_mode=True ) - default_potential_b = init_lib.DefaultInitializer().init_dual_b( + default_potential_b = lin_init.DefaultInitializer().init_dual_b( ot_problem, lse_mode=True ) @@ -244,11 +242,11 @@ def test_gauss_pointcloud_geom(self, rng: jnp.ndarray): ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) - gaus_init = init_lib.GaussianInitializer() + gaus_init = lin_init.GaussianInitializer() new_geom = geometry.Geometry( cost_matrix=ot_problem.geom.cost_matrix, epsilon=epsilon ) - ot_problem = linear_problems.LinearProblem( + ot_problem = linear_problem.LinearProblem( geom=new_geom, a=ot_problem.a, b=ot_problem.b ) @@ -316,7 +314,7 @@ def test_meta_initializer(self, lse_mode, rng: jnp.ndarray): base_num_iter = jnp.sum(sink_out.errors > -1) # Overfit the initializer to the problem. 
- meta_initializer = init_lib.MetaInitializer(geom) + meta_initializer = ott.initializers.nn.initializers.MetaInitializer(geom) for _ in range(100): _, _, meta_initializer.state = meta_initializer.update( meta_initializer.state, a=a, b=b @@ -354,7 +352,7 @@ def test_create_default_initializer(self, rng: jnp.ndarray, kind: str): geom = geometry.Geometry(geom.cost_matrix) else: raise NotImplementedError(geom) - prob = linear_problems.LinearProblem(geom) + prob = linear_problem.LinearProblem(geom) solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer=None) initializer = solver.create_initializer(prob) @@ -389,7 +387,7 @@ def test_partial_initialization( key1, key2, key3, key4 = jax.random.split(rng, 4) x = jax.random.normal(key1, (n, d)) pc = pointcloud.PointCloud(x, epsilon=5e-1) - prob = linear_problems.LinearProblem(pc) + prob = linear_problem.LinearProblem(pc) q_init = jax.random.normal(key2, (n, rank)) r_init = jax.random.normal(key2, (n, rank)) g_init = jax.random.normal(key2, (rank,)) @@ -416,7 +414,7 @@ def test_generalized_k_means_has_correct_rank( n, d = 100, 10 x = jax.random.normal(rng, (n, d)) pc = pointcloud.PointCloud(x, epsilon=5e-1) - prob = linear_problems.LinearProblem(pc) + prob = linear_problem.LinearProblem(pc) solver = sinkhorn_lr.LRSinkhorn( rank=rank, initializer="generalized-k-means" @@ -437,8 +435,8 @@ def test_generalized_k_means_matches_k_means(self, rng: jnp.ndarray): pc = pointcloud.PointCloud(x, y, epsilon=eps) geom = geometry.Geometry(cost_matrix=pc.cost_matrix, epsilon=eps) - pc_problem = linear_problems.LinearProblem(pc) - geom_problem = linear_problems.LinearProblem(geom) + pc_problem = linear_problem.LinearProblem(pc) + geom_problem = linear_problem.LinearProblem(geom) solver = sinkhorn_lr.LRSinkhorn( rank=rank, initializer="k-means", max_iterations=5000 @@ -464,7 +462,7 @@ def test_better_initialization_helps(self, rng: jnp.ndarray, epsilon: float): x = jax.random.normal(key1, (n, d)) y = jax.random.normal(key2, (n, d)) pc = 
pointcloud.PointCloud(x, y, epsilon=5e-1) - prob = linear_problems.LinearProblem(pc) + prob = linear_problem.LinearProblem(pc) solver_random = sinkhorn_lr.LRSinkhorn( rank=rank, epsilon=epsilon, initializer="random", max_iterations=10000 @@ -507,14 +505,14 @@ def test_create_default_lr_initializer(self, rng: jnp.ndarray, kind: str): geom_y = geometry.Geometry(geom_y.cost_matrix, epsilon=eps) else: raise NotImplementedError(kind) - prob = quad_problems.QuadraticProblem(geom_x, geom_y) + prob = quadratic_problem.QuadraticProblem(geom_x, geom_y) solver = gromov_wasserstein.GromovWasserstein( rank=rank, quad_initializer=None, kwargs_init=kwargs_init ) initializer = solver.create_initializer(prob) - assert isinstance(initializer, quad_initializers.LRQuadraticInitializer) + assert isinstance(initializer, quad_init.LRQuadraticInitializer) assert initializer.rank == rank linear_init = initializer._linear_lr_initializer if kind in ("pc", "lrc"): @@ -528,24 +526,24 @@ def test_non_lr_initializer(self): rank=-1, quad_initializer="not used" ) initializer = solver.create_initializer(prob="not used") - assert isinstance(initializer, quad_initializers.QuadraticInitializer) + assert isinstance(initializer, quad_init.QuadraticInitializer) @pytest.mark.parametrize("rank", [-1, 2]) def test_explicitly_passing_initializer(self, rank: int): if rank == -1: - linear_init = init_lib.SortingInitializer() - quad_init = quad_initializers.QuadraticInitializer() + linear_init = lin_init.SortingInitializer() + q_init = quad_init.QuadraticInitializer() else: linear_init = initializers_lr.Rank2Initializer(rank) - quad_init = quad_initializers.LRQuadraticInitializer(linear_init) + q_init = quad_init.LRQuadraticInitializer(linear_init) solver = gromov_wasserstein.GromovWasserstein( initializer=linear_init, - quad_initializer=quad_init, + quad_initializer=q_init, ) assert solver.linear_ot_solver.initializer is linear_init - assert solver.quad_initializer is quad_init + assert solver.quad_initializer 
is q_init if solver.is_low_rank: assert solver.quad_initializer.rank == rank @@ -564,7 +562,7 @@ def test_gw_better_initialization_helps(self, rng: jnp.ndarray, eps: float): jax.random.normal(key4, (m, d2)), epsilon=eps, ) - problem = quad_problems.QuadraticProblem(geom_x, geom_y) + problem = quadratic_problem.QuadraticProblem(geom_x, geom_y) solver_random = gromov_wasserstein.GromovWasserstein( rank=rank, initializer="random", diff --git a/tests/core/neuraldual_test.py b/tests/core/neuraldual_test.py index 289fc59b5..4fa9c009b 100644 --- a/tests/core/neuraldual_test.py +++ b/tests/core/neuraldual_test.py @@ -21,7 +21,7 @@ import pytest from typing_extensions import Literal -from ott.core.neuraldual import NeuralDualSolver +from ott.solvers.nn import neuraldual class ToyDataset: @@ -94,7 +94,7 @@ def decreasing(losses: Sequence[float]) -> bool: dataloader_source, dataloader_target = toy_dataset # initialize neural dual - neural_dual_solver = NeuralDualSolver( + neural_dual_solver = neuraldual.NeuralDualSolver( input_dim=2, num_train_iters=num_train_iters, logging=True, @@ -113,7 +113,7 @@ def test_neural_dual_jit(self, toy_dataset: Tuple[ToyDataset, ToyDataset]): num_train_iters = 10 dataloader_source, dataloader_target = toy_dataset # initialize neural dual - neural_dual_solver = NeuralDualSolver( + neural_dual_solver = neuraldual.NeuralDualSolver( input_dim=2, num_train_iters=num_train_iters ) neural_dual = neural_dual_solver( diff --git a/tests/core/potentials_test.py b/tests/core/potentials_test.py index a381bf5d8..00476a72a 100644 --- a/tests/core/potentials_test.py +++ b/tests/core/potentials_test.py @@ -3,8 +3,9 @@ import numpy as np import pytest -from ott.core import Sinkhorn, linear_problems from ott.geometry import pointcloud +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn from ott.tools import sinkhorn_divergence from ott.tools.gaussian_mixture import gaussian @@ -24,8 +25,8 @@ def 
test_entropic_potentials_dist(self, rng: jnp.ndarray, eps: float): y = g2.sample(key2, n2) geom = pointcloud.PointCloud(x, y, epsilon=eps) - prob = linear_problems.LinearProblem(geom) - out = Sinkhorn()(prob) + prob = linear_problem.LinearProblem(geom) + out = sinkhorn.Sinkhorn()(prob) assert out.converged potentials = out.to_dual_potentials() @@ -50,8 +51,8 @@ def test_entropic_potentials_displacement( y = g2.sample(key2, n2) geom = pointcloud.PointCloud(x, y, epsilon=eps) - prob = linear_problems.LinearProblem(geom) - out = Sinkhorn()(prob) + prob = linear_problem.LinearProblem(geom) + out = sinkhorn.Sinkhorn()(prob) assert out.converged potentials = out.to_dual_potentials() @@ -75,11 +76,11 @@ def test_distance_differentiability(self, rng: jnp.ndarray, jit: bool): x = jax.random.normal(key1, (n, d)) y = jax.random.normal(key2, (m, d)) - prob = linear_problems.LinearProblem(pointcloud.PointCloud(x, y)) + prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, y)) v_x = jax.random.normal(key3, shape=x.shape) v_x = (v_x / jnp.linalg.norm(v_x, axis=-1, keepdims=True)) * 1e-3 - pots = Sinkhorn()(prob).to_dual_potentials() + pots = sinkhorn.Sinkhorn()(prob).to_dual_potentials() grad_dist = jax.grad(pots.distance) if jit: @@ -103,9 +104,9 @@ def test_potentials_sinkhorn_divergence( y = jax.random.normal(key2, (m, d)) + mu1 x_test = jax.random.normal(key3, (n, d)) + mu0 geom = pointcloud.PointCloud(x, y, epsilon=eps) - prob = linear_problems.LinearProblem(geom) + prob = linear_problem.LinearProblem(geom) - sink_pots = Sinkhorn()(prob).to_dual_potentials() + sink_pots = sinkhorn.Sinkhorn()(prob).to_dual_potentials() div_pots = sinkhorn_divergence.sinkhorn_divergence( type(geom), x, y, epsilon=eps ).to_dual_potentials() diff --git a/tests/core/sinkhorn_diff_test.py b/tests/core/sinkhorn_diff_test.py index 7e60a8309..c55711dc7 100644 --- a/tests/core/sinkhorn_diff_test.py +++ b/tests/core/sinkhorn_diff_test.py @@ -22,9 +22,10 @@ import numpy as np import pytest -from 
ott.core import implicit_differentiation as implicit_lib -from ott.core import linear_problems, sinkhorn from ott.geometry import costs, geometry, grid, pointcloud +from ott.math import implicit_differentiation as implicit_lib +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn from ott.tools import transport @@ -765,7 +766,7 @@ def test_hessian_sinkhorn( def loss(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True): geom = pointcloud.PointCloud(x, y, epsilon=epsilon) - prob = linear_problems.LinearProblem(geom, a, b, tau_a, tau_b) + prob = linear_problem.LinearProblem(geom, a, b, tau_a, tau_b) implicit_diff = ( None if not implicit else implicit_lib.ImplicitDiff(ridge_kernel=ridge, ridge_identity=ridge) diff --git a/tests/core/sinkhorn_extra_test.py b/tests/core/sinkhorn_extra_test.py index 6bdc0b258..2f64a4c4a 100644 --- a/tests/core/sinkhorn_extra_test.py +++ b/tests/core/sinkhorn_extra_test.py @@ -15,7 +15,7 @@ # Lint as: python3 """Tests Anderson acceleration for sinkhorn.""" import functools -from typing import Any, Callable, Tuple +from typing import Callable, Tuple import chex import jax @@ -23,8 +23,8 @@ import numpy as np import pytest -from ott.core import sinkhorn from ott.geometry import costs, geometry, pointcloud +from ott.solvers.linear import sinkhorn non_jitted_sinkhorn = functools.partial(sinkhorn.sinkhorn, jit=False) @@ -330,7 +330,7 @@ def f( def test_jit_vs_non_jit_bwd(self, implicit: bool): def loss( - a: jnp.ndarray, x: jnp.ndarray, fun: Callable[[Any], + a: jnp.ndarray, x: jnp.ndarray, fun: Callable[..., sinkhorn.SinkhornOutput] ): out = fun( diff --git a/tests/core/sinkhorn_grid_test.py b/tests/core/sinkhorn_grid_test.py index 0c4180561..7937ce717 100644 --- a/tests/core/sinkhorn_grid_test.py +++ b/tests/core/sinkhorn_grid_test.py @@ -20,8 +20,8 @@ import numpy as np import pytest -from ott.core import sinkhorn from ott.geometry import grid, pointcloud +from ott.solvers.linear import sinkhorn class 
TestSinkhornGrid: diff --git a/tests/core/sinkhorn_lr_test.py b/tests/core/sinkhorn_lr_test.py index 3782b6e0b..afb6961e4 100644 --- a/tests/core/sinkhorn_lr_test.py +++ b/tests/core/sinkhorn_lr_test.py @@ -19,8 +19,9 @@ import numpy as np import pytest -from ott.core import linear_problems, sinkhorn_lr from ott.geometry import low_rank, pointcloud +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn_lr class TestLRSinkhorn: @@ -58,7 +59,7 @@ def test_euclidean_point_cloud_lr( if use_lrcgeom: geom = geom.to_LRCGeometry() assert isinstance(geom, low_rank.LRCGeometry) - ot_prob = linear_problems.LinearProblem(geom, self.a, self.b) + ot_prob = linear_problem.LinearProblem(geom, self.a, self.b) # Start with a low rank parameter solver = sinkhorn_lr.LRSinkhorn( @@ -131,7 +132,7 @@ def test_output_apply_batch_size(self, axis: int): data = self.a if axis == 0 else self.b geom = pointcloud.PointCloud(self.x, self.y) - ot_prob = linear_problems.LinearProblem(geom, self.a, self.b) + ot_prob = linear_problem.LinearProblem(geom, self.a, self.b) solver = sinkhorn_lr.LRSinkhorn( threshold=threshold, rank=10, diff --git a/tests/core/sinkhorn_test.py b/tests/core/sinkhorn_test.py index ae3d04255..fba7a243e 100644 --- a/tests/core/sinkhorn_test.py +++ b/tests/core/sinkhorn_test.py @@ -20,8 +20,9 @@ import numpy as np import pytest -from ott.core import linear_problems, sinkhorn from ott.geometry import costs, geometry, pointcloud +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn class TestSinkhorn: @@ -455,7 +456,7 @@ def test_sinkhorn_online_memory(self, batch_size: int): x = jax.random.uniform(rngs[0], (n, 2)) y = jax.random.uniform(rngs[1], (m, 2)) geom = pointcloud.PointCloud(x, y, batch_size=batch_size, epsilon=1) - problem = linear_problems.LinearProblem(geom) + problem = linear_problem.LinearProblem(geom) solver = sinkhorn.Sinkhorn() out = solver(problem) diff --git 
a/tests/geometry/geometry_lse_test.py b/tests/geometry/geometry_lse_test.py index d7a030107..f0b076147 100644 --- a/tests/geometry/geometry_lse_test.py +++ b/tests/geometry/geometry_lse_test.py @@ -20,7 +20,7 @@ import numpy as np import pytest -from ott.geometry import ops +from ott.math import utils as mu @pytest.mark.fast @@ -36,7 +36,7 @@ def test_lse(self, rng: jnp.ndarray): b_1 = jax.random.normal(keys[2], (n, 1)) def lse_(x, axis, b, return_sign): - out = ops.logsumexp(x, axis, False, b, return_sign) + out = mu.logsumexp(x, axis, False, b, return_sign) return jnp.sum(out[0] if return_sign else out) lse = jax.value_and_grad(lse_, argnums=(0, 2)) diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py index 3bf976e9f..97b245a59 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -11,10 +11,11 @@ from networkx.generators import balanced_tree, random_graphs from typing_extensions import Literal -from ott.core import decomposition -from ott.core import implicit_differentiation as implicit_lib -from ott.core import linear_problems, sinkhorn from ott.geometry import geometry, graph +from ott.math import decomposition +from ott.math import implicit_differentiation as implicit_lib +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn # we mix both dense/sparse tests sksparse = pytest.importorskip("sksparse") @@ -373,7 +374,7 @@ def test_graph_sinkhorn( def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: solver = sinkhorn.Sinkhorn(lse_mode=False) - problem = linear_problems.LinearProblem(geom) + problem = linear_problem.LinearProblem(geom) return solver(problem) n, eps, tol = 11, 1e-5, 1e-3 @@ -422,7 +423,7 @@ def callback( geom = graph.Graph(G, t=1.) 
solver = sinkhorn.Sinkhorn(lse_mode=False, **kwargs) - problem = linear_problems.LinearProblem(geom) + problem = linear_problem.LinearProblem(geom) return solver(problem).reg_ot_cost diff --git a/tests/geometry/matrix_square_root_test.py b/tests/geometry/matrix_square_root_test.py index 0b48fb706..b5d3c08d8 100644 --- a/tests/geometry/matrix_square_root_test.py +++ b/tests/geometry/matrix_square_root_test.py @@ -21,7 +21,7 @@ import numpy as np import pytest -from ott.geometry import matrix_square_root +from ott.math import matrix_square_root def _get_random_spd_matrix(dim: int, key: jnp.ndarray): @@ -56,7 +56,7 @@ def _get_test_fn( unit = jax.random.normal(key=subkey3, shape=(dim, dim)) unit /= jnp.sqrt(jnp.sum(unit ** 2.)) - def _test_fn(x: float) -> float: + def _test_fn(x: jnp.ndarray) -> jnp.ndarray: # m is the product of 2 symmetric, positive definite matrices # so it will be positive definite but not necessarily symmetric m = jnp.matmul(m0, m1 + x * dx) diff --git a/tests/geometry/scaling_cost_test.py b/tests/geometry/scaling_cost_test.py index 2767c1ea5..8a6752608 100644 --- a/tests/geometry/scaling_cost_test.py +++ b/tests/geometry/scaling_cost_test.py @@ -19,8 +19,9 @@ import numpy as np import pytest -from ott.core import linear_problems, sinkhorn, sinkhorn_lr from ott.geometry import geometry, low_rank, pointcloud +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn, sinkhorn_lr class TestScaleCost: @@ -154,7 +155,7 @@ def test_scale_cost_low_rank(self, scale: Union[str, float]): def apply_sinkhorn(cost1, cost2, scale_cost): geom = low_rank.LRCGeometry(cost1, cost2, scale_cost=scale_cost) - ot_prob = linear_problems.LinearProblem(geom, self.a, self.b) + ot_prob = linear_problem.LinearProblem(geom, self.a, self.b) solver = sinkhorn_lr.LRSinkhorn(rank=5, threshold=1e-3) out = solver(ot_prob) return geom, out diff --git a/tests/tools/gaussian_mixture/scale_tril_test.py b/tests/tools/gaussian_mixture/scale_tril_test.py 
index ef37cff8f..32f2f93f3 100644 --- a/tests/tools/gaussian_mixture/scale_tril_test.py +++ b/tests/tools/gaussian_mixture/scale_tril_test.py @@ -18,7 +18,7 @@ import numpy as np import pytest -from ott.geometry import matrix_square_root +from ott.math import matrix_square_root from ott.tools.gaussian_mixture import scale_tril diff --git a/tests/tools/segment_sinkhorn_test.py b/tests/tools/segment_sinkhorn_test.py index e7269d2da..5a1e81c1e 100644 --- a/tests/tools/segment_sinkhorn_test.py +++ b/tests/tools/segment_sinkhorn_test.py @@ -20,8 +20,8 @@ import numpy as np import pytest -from ott.core import sinkhorn from ott.geometry import costs, pointcloud +from ott.solvers.linear import sinkhorn from ott.tools import segment_sinkhorn from ott.tools.gaussian_mixture import gaussian_mixture diff --git a/tests/tools/sinkhorn_divergence_test.py b/tests/tools/sinkhorn_divergence_test.py index 152002638..2d5d63d5d 100644 --- a/tests/tools/sinkhorn_divergence_test.py +++ b/tests/tools/sinkhorn_divergence_test.py @@ -21,8 +21,8 @@ import numpy as np import pytest -from ott.core import sinkhorn from ott.geometry import costs, geometry, pointcloud +from ott.solvers.linear import sinkhorn from ott.tools import sinkhorn_divergence from ott.tools.gaussian_mixture import gaussian_mixture diff --git a/tests/tools/transport_test.py b/tests/tools/transport_test.py index 44c157261..4954c85bd 100644 --- a/tests/tools/transport_test.py +++ b/tests/tools/transport_test.py @@ -18,8 +18,8 @@ import numpy as np import pytest -from ott.core import linear_problems from ott.geometry import pointcloud +from ott.problems.linear import linear_problem from ott.tools import transport @@ -60,7 +60,7 @@ def test_transport_from_problem(self, rng: jnp.ndarray): geom = pointcloud.PointCloud(x, y, batch_size=9) b = jax.random.uniform(rngs[2], (num_b,)) b /= jnp.sum(b) - pb = linear_problems.LinearProblem(geom, b=b) + pb = linear_problem.LinearProblem(geom, b=b) ot = transport.solve(pb) 
np.testing.assert_array_equal(ot.matrix.shape, (num_a, num_b)) From eb384eb7e831c8c40b0c8acccbadcf02374d017a Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 16 Nov 2022 15:34:43 +0100 Subject: [PATCH 06/34] Update cost funcs and potentials --- ott/geometry/costs.py | 18 ++++++++++-------- ott/math/potentials.py | 32 +++++++++++++++++++++----------- ott/solvers/linear/sinkhorn.py | 2 +- 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py index 065426661..084673300 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -52,7 +52,6 @@ class CostFn(abc.ABC): def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: pass - @abc.abstractmethod def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float: raise NotImplementedError("Barycenter not yet implemented for this cost.") @@ -115,14 +114,14 @@ class TICost(CostFn): @abc.abstractmethod def h(self, z: jnp.ndarray) -> float: - """RBF function acting on difference of `x-y` to ouput cost.""" + """TI function acting on difference of :math:`x-y` to output cost.""" def h_legendre(self, z: jnp.ndarray) -> float: - """Legendre transform of RBF function `h` (when latter is convex).""" + """Legendre transform of TI function :func:`h` (when latter is convex).""" raise NotImplementedError("`h_legendre` not implemented.") def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: - """Compute cost as evaluation of :func:`h` on `x-y`.""" + """Compute cost as evaluation of :func:`h` on :math:`x-y`.""" return self.h(x - y) @@ -135,9 +134,10 @@ class SqPNorm(TICost): """ def __init__(self, p: float): + super().__init__() assert p >= 1.0, "p parameter in sq. p-norm should be >= 1.0" self.p = p - self.q = 1. / (1 - 1 / self.p) if p > 1.0 else 'inf' + self.q = 1. / (1. - 1. 
/ self.p) if p > 1.0 else "inf" def h(self, z: jnp.ndarray) -> float: return 0.5 * jnp.linalg.norm(z, self.p) ** 2 @@ -159,9 +159,10 @@ class PNorm(TICost): """p-norm (to the power p) of the difference of two vectors.""" def __init__(self, p: float): + super().__init__() assert p >= 1.0, "p parameter in p-norm should be >= 1.0" self.p = p - self.q = 1. / (1 - 1 / self.p) + self.q = 1. / (1. - 1. / self.p) if p > 1. else "inf" def h(self, z: jnp.ndarray) -> float: return jnp.linalg.norm(z, self.p) ** self.p / self.p @@ -182,8 +183,9 @@ def tree_unflatten(cls, aux_data, children): class Euclidean(CostFn): """Euclidean distance. - Note that the Euclidean distance is not cast as a `TICost`, because this - would correspond to `h = jnp.linalg.norm`, whose gradient is not invertible, + Note that the Euclidean distance is not cast as a + :class:`~ott.geometry.costs.TICost`, since this would correspond to :math:`h` + being :func:`jax.numpy.linalg.norm`, whose gradient is not invertible, because the function is not strictly convex (it is linear on rays). """ diff --git a/ott/math/potentials.py b/ott/math/potentials.py index 50d62b3a6..cf2914c67 100644 --- a/ott/math/potentials.py +++ b/ott/math/potentials.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Tuple import jax import jax.numpy as jnp @@ -6,10 +6,8 @@ import jax.tree_util as jtu from typing_extensions import Literal -from ott.geometry import costs - if TYPE_CHECKING: - from ott.geometry import pointcloud + from ott.geometry import costs, pointcloud __all__ = ["DualPotentials", "EntropicPotentials"] Potential_t = Callable[[jnp.ndarray], float] @@ -27,15 +25,16 @@ class DualPotentials: g: The second dual potential function. cost_fn: The cost function used to solve the OT problem. corr: Whether the duals solve the problem in distance form, or correlation - form (as used for instance for ICNNs, see e.g. 
top right of p.3 in :cite:`makkuva:20`) + form (as used for instance for ICNNs, see, e.g., top right of p.3 in + :cite:`makkuva:20`) """ def __init__( self, f: Potential_t, g: Potential_t, - cost_fn: costs.CostFn, *, + cost_fn: 'costs.CostFn', corr: bool = False ): self._f = f @@ -44,7 +43,7 @@ def __init__( self._corr = corr def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray: - r"""Transport ``vec`` according to Brenier formula. + r"""Transport ``vec`` according to Brenier formula :cite:`brenier:91`. Uses Theorem 1.17 from :cite:`santambrogio:15` to compute an OT map when given the Legendre transform of the dual potentials. @@ -66,6 +65,8 @@ def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray: Returns: The transported points. """ + from ott.geometry import costs + vec = jnp.atleast_2d(vec) if self._corr and isinstance(self.cost_fn, costs.SqEuclidean): return self._grad_g(vec) if forward else self._grad_f(vec) @@ -127,6 +128,8 @@ def _grad_g(self) -> Callable[[jnp.ndarray], jnp.ndarray]: @property def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]: + from ott.geometry import costs + assert isinstance(self.cost_fn, costs.TICost), ( "Cost must be a `TICost` and " "provide access to Legendre transform of `h`." @@ -152,15 +155,22 @@ class EntropicPotentials(DualPotentials): g: The second dual potential vector of shape ``[m,]``. geom: Geometry used to compute the dual potentials using :class:`~ott.core.sinkhorn.Sinkhorn`. - a: probability weights for the first measure. - b: probaility weights for the second measure. + a: Probability weights for the first measure. If `None`, use uniform. + b: Probability weights for the second measure. If `None`, use uniform. 
""" def __init__( - self, f: jnp.ndarray, g: jnp.ndarray, geom: "pointcloud.PointCloud", - a: jnp.ndarray, b: jnp.ndarray + self, + f: jnp.ndarray, + g: jnp.ndarray, + geom: "pointcloud.PointCloud", + a: Optional[jnp.ndarray] = None, + b: Optional[jnp.ndarray] = None, ): n, m = geom.shape + a = jnp.ones(n) / n if a is None else a + b = jnp.ones(m) / m if b is None else b + assert f.shape == (n,) and a.shape == (n,), \ f"Expected `f` and `a` to be of shape `{n,}`, found `{f.shape}`." assert g.shape == (m,) and b.shape == (m,), \ diff --git a/ott/solvers/linear/sinkhorn.py b/ott/solvers/linear/sinkhorn.py index b8201d6e5..84c6100d4 100644 --- a/ott/solvers/linear/sinkhorn.py +++ b/ott/solvers/linear/sinkhorn.py @@ -295,7 +295,7 @@ def transport_mass(self) -> float: def to_dual_potentials(self) -> potentials.EntropicPotentials: """Return the entropic map estimator.""" return potentials.EntropicPotentials( - self.f, self.g, self.geom, self.a, self.b + self.f, self.g, geom=self.geom, a=self.a, b=self.b ) From 5280021027795ba32faedd08f1f348a948150253 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 16 Nov 2022 16:20:02 +0100 Subject: [PATCH 07/34] Fix LR initializer --- ott/initializers/linear/initializers_lr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ott/initializers/linear/initializers_lr.py b/ott/initializers/linear/initializers_lr.py index a8132816f..24b4fcaa8 100644 --- a/ott/initializers/linear/initializers_lr.py +++ b/ott/initializers/linear/initializers_lr.py @@ -380,8 +380,8 @@ def _compute_factor( which: Literal["q", "r"], **kwargs: Any, ) -> jnp.ndarray: - from ott.problems import linear as linear_problems - from ott.problems import quadratic as quad_problems + from ott.problems.linear import linear_problem + from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import sinkhorn from ott.tools import k_means @@ -395,7 +395,7 @@ def _compute_factor( ) fn = 
jax.jit(fn, static_argnames="k") if jit else fn - if isinstance(ot_prob, quad_problems.QuadraticProblem): + if isinstance(ot_prob, quadratic_problem.QuadraticProblem): geom = ot_prob.geom_xx if which == "q" else ot_prob.geom_yy else: geom = ot_prob.geom @@ -407,7 +407,7 @@ def _compute_factor( arr, centroids, epsilon=0.1, scale_cost="max_cost" ) - prob = linear_problems.LinearProblem(geom, marginals, init_g) + prob = linear_problem.LinearProblem(geom, marginals, init_g) solver = sinkhorn.Sinkhorn(**self._sinkhorn_kwargs) return solver(prob).matrix From a5999b50eb048496ad778e821795fbc913cce998 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 16 Nov 2022 16:23:49 +0100 Subject: [PATCH 08/34] Fix k-means initializer --- ott/initializers/linear/initializers_lr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ott/initializers/linear/initializers_lr.py b/ott/initializers/linear/initializers_lr.py index 24b4fcaa8..c3d83f735 100644 --- a/ott/initializers/linear/initializers_lr.py +++ b/ott/initializers/linear/initializers_lr.py @@ -512,8 +512,8 @@ def _compute_factor( which: Literal["q", "r"], **kwargs: Any, ) -> jnp.ndarray: - from ott.problems import linear as linear_problems - from ott.problems import quadratic as quad_problems + from ott.problems.linear import linear_problem + from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import sinkhorn def init_fn() -> GeneralizedKMeansInitializer.State: @@ -588,7 +588,7 @@ def body_fn( cost_matrix=cost, epsilon=eps, ) - problem = linear_problems.LinearProblem( + problem = linear_problem.LinearProblem( cost, a=consts.marginal, b=consts.g ) @@ -614,7 +614,7 @@ def body_fn( del kwargs - if isinstance(ot_prob, quad_problems.QuadraticProblem): + if isinstance(ot_prob, quadratic_problem.QuadraticProblem): geom = ot_prob.geom_xx if which == "q" else ot_prob.geom_yy else: geom = ot_prob.geom From 
35ae68a02e9e4a58b43b3a8dad702f6216329136 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 16 Nov 2022 16:42:29 +0100 Subject: [PATCH 09/34] Move `utils` --- ott/geometry/__init__.py | 2 +- ott/{utils => geometry}/segment.py | 2 +- ott/initializers/linear/initializers_lr.py | 2 +- ott/initializers/nn/initializers.py | 2 +- ott/math/__init__.py | 1 - ott/math/implicit_differentiation.py | 4 ++-- ott/problems/linear/__init__.py | 2 +- ott/problems/linear/barycenter_problem.py | 3 +-- ott/{math => problems/linear}/potentials.py | 0 ott/problems/quadratic/gw_barycenter.py | 5 ++--- ott/solvers/linear/acceleration.py | 10 +++++----- ott/solvers/linear/continuous_barycenter.py | 2 +- ott/solvers/linear/sinkhorn.py | 4 ++-- ott/solvers/nn/neuraldual.py | 2 +- ott/solvers/quadratic/gromov_wasserstein.py | 2 +- ott/solvers/quadratic/gw_barycenter.py | 2 +- ott/{utils => solvers}/was_solver.py | 1 + ott/tools/segment_sinkhorn.py | 5 ++--- ott/tools/sinkhorn_divergence.py | 5 ++--- ott/typing.py | 2 +- ott/{utils/dataclasses.py => utils.py} | 0 ott/utils/__init__.py | 1 - tests/core/continuous_barycenter_test.py | 3 +-- 23 files changed, 28 insertions(+), 34 deletions(-) rename ott/{utils => geometry}/segment.py (99%) rename ott/{math => problems/linear}/potentials.py (100%) rename ott/{utils => solvers}/was_solver.py (98%) rename ott/{utils/dataclasses.py => utils.py} (100%) delete mode 100644 ott/utils/__init__.py diff --git a/ott/geometry/__init__.py b/ott/geometry/__init__.py index 7309ab155..e57c7b3ec 100644 --- a/ott/geometry/__init__.py +++ b/ott/geometry/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. """OTT ground geometries: Classes and cost functions to instantiate them.""" -from . import costs, epsilon_scheduler, geometry, graph, grid, pointcloud +from . 
import costs, epsilon_scheduler, geometry, graph, grid, pointcloud, segment diff --git a/ott/utils/segment.py b/ott/geometry/segment.py similarity index 99% rename from ott/utils/segment.py rename to ott/geometry/segment.py index 2e15db821..1051f348a 100644 --- a/ott/utils/segment.py +++ b/ott/geometry/segment.py @@ -14,7 +14,7 @@ from typing import Callable, Optional, Tuple import jax -from jax import numpy as jnp +import jax.numpy as jnp __all__ = ["segment_point_cloud"] diff --git a/ott/initializers/linear/initializers_lr.py b/ott/initializers/linear/initializers_lr.py index c3d83f735..6059a17bb 100644 --- a/ott/initializers/linear/initializers_lr.py +++ b/ott/initializers/linear/initializers_lr.py @@ -13,8 +13,8 @@ ) import jax +import jax.numpy as jnp import numpy as np -from jax import numpy as jnp from typing_extensions import Literal from ott.geometry import geometry, low_rank, pointcloud diff --git a/ott/initializers/nn/initializers.py b/ott/initializers/nn/initializers.py index 09fe1b26b..b10a58dac 100644 --- a/ott/initializers/nn/initializers.py +++ b/ott/initializers/nn/initializers.py @@ -2,10 +2,10 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple import jax +import jax.numpy as jnp import optax from flax import linen as nn from flax.training import train_state -from jax import numpy as jnp from ott.geometry import geometry from ott.initializers.linear import initializers diff --git a/ott/math/__init__.py b/ott/math/__init__.py index faac35ec8..60c303c2e 100644 --- a/ott/math/__init__.py +++ b/ott/math/__init__.py @@ -3,7 +3,6 @@ fixed_point_loop, implicit_differentiation, matrix_square_root, - potentials, unbalanced_functions, utils, ) diff --git a/ott/math/implicit_differentiation.py b/ott/math/implicit_differentiation.py index 1f42888c0..58ed326bf 100644 --- a/ott/math/implicit_differentiation.py +++ b/ott/math/implicit_differentiation.py @@ -18,8 +18,8 @@ import jax import jax.numpy as jnp +from ott import utils from 
ott.math import unbalanced_functions -from ott.utils import dataclasses if TYPE_CHECKING: from ott.problems.linear import linear_problem @@ -27,7 +27,7 @@ __all__ = ["ImplicitDiff"] -@dataclasses.register_pytree_node +@utils.register_pytree_node class ImplicitDiff: """Implicit differentiation of Sinkhorn algorithm. diff --git a/ott/problems/linear/__init__.py b/ott/problems/linear/__init__.py index 1681cefbb..2088e5a3c 100644 --- a/ott/problems/linear/__init__.py +++ b/ott/problems/linear/__init__.py @@ -1 +1 @@ -from . import barycenter_problem, linear_problem +from . import barycenter_problem, linear_problem, potentials diff --git a/ott/problems/linear/barycenter_problem.py b/ott/problems/linear/barycenter_problem.py index a34ee728f..7a6e63129 100644 --- a/ott/problems/linear/barycenter_problem.py +++ b/ott/problems/linear/barycenter_problem.py @@ -17,8 +17,7 @@ import jax import jax.numpy as jnp -from ott.geometry import costs -from ott.utils import segment +from ott.geometry import costs, segment __all__ = ["BarycenterProblem"] diff --git a/ott/math/potentials.py b/ott/problems/linear/potentials.py similarity index 100% rename from ott/math/potentials.py rename to ott/problems/linear/potentials.py diff --git a/ott/problems/quadratic/gw_barycenter.py b/ott/problems/quadratic/gw_barycenter.py index 8d95a2773..a052e6224 100644 --- a/ott/problems/quadratic/gw_barycenter.py +++ b/ott/problems/quadratic/gw_barycenter.py @@ -2,14 +2,13 @@ from typing import Any, Dict, Optional, Sequence, Tuple, Union import jax -from jax import numpy as jnp +import jax.numpy as jnp from typing_extensions import Literal -from ott.geometry import costs, geometry, pointcloud +from ott.geometry import costs, geometry, pointcloud, segment from ott.math import utils as mu from ott.problems.linear import barycenter_problem from ott.problems.quadratic import quadratic_costs, quadratic_problem -from ott.utils import segment __all__ = ["GWBarycenterProblem"] diff --git 
a/ott/solvers/linear/acceleration.py b/ott/solvers/linear/acceleration.py index 9684f2956..a22ac9ab6 100644 --- a/ott/solvers/linear/acceleration.py +++ b/ott/solvers/linear/acceleration.py @@ -1,17 +1,17 @@ from typing import TYPE_CHECKING import jax -from jax import numpy as jnp +import jax.numpy as jnp + +from ott import utils if TYPE_CHECKING: from ott.solvers.linear import sinkhorn -from ott.utils import dataclasses - __all__ = ["AndersonAcceleration", "Momentum"] -@dataclasses.register_pytree_node +@utils.register_pytree_node class AndersonAcceleration: """Implements Anderson acceleration for Sinkhorn.""" @@ -107,7 +107,7 @@ def update_history( return state.set(old_mapped_fus=mapped) -@dataclasses.register_pytree_node +@utils.register_pytree_node class Momentum: """Momentum for Sinkhorn updates, either constant or adaptive.""" diff --git a/ott/solvers/linear/continuous_barycenter.py b/ott/solvers/linear/continuous_barycenter.py index e1d4a96e2..a1534647c 100644 --- a/ott/solvers/linear/continuous_barycenter.py +++ b/ott/solvers/linear/continuous_barycenter.py @@ -24,7 +24,7 @@ from ott.math import fixed_point_loop from ott.math import utils as mu from ott.problems.linear import barycenter_problem, linear_problem -from ott.utils import was_solver +from ott.solvers import was_solver __all__ = ["BarycenterState", "WassersteinBarycenter"] diff --git a/ott/solvers/linear/sinkhorn.py b/ott/solvers/linear/sinkhorn.py index 84c6100d4..9f824e142 100644 --- a/ott/solvers/linear/sinkhorn.py +++ b/ott/solvers/linear/sinkhorn.py @@ -25,8 +25,8 @@ from ott.initializers.linear import initializers as init_lib from ott.math import fixed_point_loop from ott.math import implicit_differentiation as implicit_lib -from ott.math import potentials, unbalanced_functions -from ott.problems.linear import linear_problem +from ott.math import unbalanced_functions +from ott.problems.linear import linear_problem, potentials from ott.solvers.linear import acceleration __all__ = ["Sinkhorn", 
"SinkhornOutput"] diff --git a/ott/solvers/nn/neuraldual.py b/ott/solvers/nn/neuraldual.py index 5d107e78c..306afa5c2 100644 --- a/ott/solvers/nn/neuraldual.py +++ b/ott/solvers/nn/neuraldual.py @@ -24,7 +24,7 @@ from typing_extensions import Literal from ott.geometry import costs -from ott.math import potentials +from ott.problems.linear import potentials from ott.solvers.nn import icnn __all__ = ["NeuralDualSolver"] diff --git a/ott/solvers/quadratic/gromov_wasserstein.py b/ott/solvers/quadratic/gromov_wasserstein.py index 805354fa1..05a427a74 100644 --- a/ott/solvers/quadratic/gromov_wasserstein.py +++ b/ott/solvers/quadratic/gromov_wasserstein.py @@ -26,8 +26,8 @@ from ott.math import fixed_point_loop from ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_costs, quadratic_problem +from ott.solvers import was_solver from ott.solvers.linear import sinkhorn, sinkhorn_lr -from ott.utils import was_solver __all__ = ["GWOutput", "GromovWasserstein", "gromov_wasserstein"] diff --git a/ott/solvers/quadratic/gw_barycenter.py b/ott/solvers/quadratic/gw_barycenter.py index 528f676c4..46ee6618f 100644 --- a/ott/solvers/quadratic/gw_barycenter.py +++ b/ott/solvers/quadratic/gw_barycenter.py @@ -8,8 +8,8 @@ from ott.math import fixed_point_loop from ott.problems.linear import linear_problem from ott.problems.quadratic import gw_barycenter +from ott.solvers import was_solver from ott.solvers.quadratic import gromov_wasserstein -from ott.utils import was_solver __all__ = ["GWBarycenterState", "GromovWassersteinBarycenter"] diff --git a/ott/utils/was_solver.py b/ott/solvers/was_solver.py similarity index 98% rename from ott/utils/was_solver.py rename to ott/solvers/was_solver.py index 2ff7789d9..c823ef1d4 100644 --- a/ott/utils/was_solver.py +++ b/ott/solvers/was_solver.py @@ -28,6 +28,7 @@ "continuous_barycenter.BarycenterState"] +# TODO(michalk8): refactor to have generic nested solver API @jax.tree_util.register_pytree_node_class class 
WassersteinSolver: """A generic solver for problems that use a linear reg-OT pb in inner loop.""" diff --git a/ott/tools/segment_sinkhorn.py b/ott/tools/segment_sinkhorn.py index 7f4afeb60..fca0a4226 100644 --- a/ott/tools/segment_sinkhorn.py +++ b/ott/tools/segment_sinkhorn.py @@ -15,11 +15,10 @@ from types import MappingProxyType from typing import Any, Mapping, Optional, Tuple -from jax import numpy as jnp +import jax.numpy as jnp -from ott.geometry import costs, pointcloud +from ott.geometry import costs, pointcloud, segment from ott.solvers.linear import sinkhorn -from ott.utils import segment def segment_sinkhorn( diff --git a/ott/tools/sinkhorn_divergence.py b/ott/tools/sinkhorn_divergence.py index 119f7be5d..2b8404439 100644 --- a/ott/tools/sinkhorn_divergence.py +++ b/ott/tools/sinkhorn_divergence.py @@ -17,10 +17,9 @@ import jax.numpy as jnp -from ott.geometry import costs, geometry, pointcloud -from ott.math import potentials +from ott.geometry import costs, geometry, pointcloud, segment +from ott.problems.linear import potentials from ott.solvers.linear import sinkhorn -from ott.utils import segment __all__ = [ "sinkhorn_divergence", "segment_sinkhorn_divergence", diff --git a/ott/typing.py b/ott/typing.py index 1b27fa3aa..28e746713 100644 --- a/ott/typing.py +++ b/ott/typing.py @@ -1,4 +1,4 @@ -from jax import numpy as jnp +import jax.numpy as jnp from typing_extensions import Protocol # TODO(michalk8): introduce additional types here diff --git a/ott/utils/dataclasses.py b/ott/utils.py similarity index 100% rename from ott/utils/dataclasses.py rename to ott/utils.py diff --git a/ott/utils/__init__.py b/ott/utils/__init__.py deleted file mode 100644 index 3f5072c03..000000000 --- a/ott/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . 
import dataclasses, segment, was_solver diff --git a/tests/core/continuous_barycenter_test.py b/tests/core/continuous_barycenter_test.py index 2f3430d27..2a0efa9d1 100644 --- a/tests/core/continuous_barycenter_test.py +++ b/tests/core/continuous_barycenter_test.py @@ -21,13 +21,12 @@ import numpy as np import pytest -from ott.geometry import costs, pointcloud +from ott.geometry import costs, pointcloud, segment from ott.problems.linear import barycenter_problem from ott.problems.quadratic import gw_barycenter as gwb from ott.solvers.linear import continuous_barycenter as cb from ott.solvers.quadratic import gw_barycenter as gwb_solver from ott.tools.gaussian_mixture import gaussian_mixture -from ott.utils import segment means_and_covs_to_x = jax.vmap(costs.mean_and_cov_to_x, in_axes=[0, 0, None]) From c1dfdb469d3852955ad274522e65bd778236febb Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 16 Nov 2022 19:31:17 +0100 Subject: [PATCH 10/34] Update imports in notebooks --- docs/notebooks/GWLRSinkhorn.ipynb | 49 +++++++--------- docs/notebooks/Hessians.ipynb | 27 +++------ docs/notebooks/LRSinkhorn.ipynb | 56 ++++++++----------- docs/notebooks/MetaOT.ipynb | 32 +++++++---- docs/notebooks/OTT_&_POT.ipynb | 20 +++++-- docs/notebooks/One_Sinkhorn.ipynb | 6 +- docs/notebooks/Sinkhorn_Barycenters.ipynb | 31 ++++++---- docs/notebooks/application_biology.ipynb | 23 ++++++-- docs/notebooks/gromov_wasserstein.ipynb | 8 +-- .../gromov_wasserstein_multiomics.ipynb | 9 ++- docs/notebooks/icnn_inits.ipynb | 11 ++-- docs/notebooks/introduction_grid.ipynb | 20 +++++-- docs/notebooks/neural_dual.ipynb | 17 +++--- docs/notebooks/point_clouds.ipynb | 30 ++++++---- docs/notebooks/soft_sort.ipynb | 4 +- .../wasserstein_barycenters_gmms.ipynb | 13 +++-- 16 files changed, 195 insertions(+), 161 deletions(-) diff --git a/docs/notebooks/GWLRSinkhorn.ipynb b/docs/notebooks/GWLRSinkhorn.ipynb index 376212527..4b1cd42e1 100644 --- 
a/docs/notebooks/GWLRSinkhorn.ipynb +++ b/docs/notebooks/GWLRSinkhorn.ipynb @@ -42,9 +42,11 @@ }, "outputs": [], "source": [ - "import jax.numpy as jnp\n", "import jax\n", - "import ott\n", + "import jax.numpy as jnp\n", + "from ott.geometry import pointcloud\n", + "from ott.solvers.quadratic import gromov_wasserstein\n", + "from ott.problems.quadratic import quadratic_problem\n", "import matplotlib.pyplot as plt" ] }, @@ -65,15 +67,7 @@ }, "id": "PfiRNdhVW8hT" }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - } - ], + "outputs": [], "source": [ "def create_points(rng, n, m, d1, d2):\n", " rngs = jax.random.split(rng, 5)\n", @@ -120,11 +114,10 @@ }, "outputs": [], "source": [ - "geom_xx = ott.geometry.pointcloud.PointCloud(x)\n", - "geom_yy = ott.geometry.pointcloud.PointCloud(y)\n", - "# below `z` is there only to create n x m geometry\n", - "geom_xy = ott.geometry.pointcloud.PointCloud(x, z)\n", - "prob = ott.core.quad_problems.QuadraticProblem(\n", + "geom_xx = pointcloud.PointCloud(x)\n", + "geom_yy = pointcloud.PointCloud(y)\n", + "geom_xy = pointcloud.PointCloud(x, z)\n", + "prob = quadratic_problem.QuadraticProblem(\n", " geom_xx,\n", " geom_yy,\n", " geom_xy=geom_xy,\n", @@ -161,7 +154,7 @@ }, "outputs": [], "source": [ - "solver = ott.core.gromov_wasserstein.GromovWasserstein(rank=6)\n", + "solver = gromov_wasserstein.GromovWasserstein(rank=6)\n", "ot_gwlr = solver(prob)" ] }, @@ -193,7 +186,7 @@ }, "outputs": [], "source": [ - "solver = ott.core.gromov_wasserstein.GromovWasserstein(epsilon=0.05)\n", + "solver = gromov_wasserstein.GromovWasserstein(epsilon=0.05)\n", "ot_gw = solver(prob)" ] }, @@ -230,26 +223,22 @@ "outputs": [ { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ - "
" + "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" }, { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ - "
" + "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" } ], @@ -282,7 +271,7 @@ ] }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -296,7 +285,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/docs/notebooks/Hessians.ipynb b/docs/notebooks/Hessians.ipynb index 59a6e6a55..51383ca35 100644 --- a/docs/notebooks/Hessians.ipynb +++ b/docs/notebooks/Hessians.ipynb @@ -22,7 +22,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -34,7 +34,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, "metadata": { "id": "F0ESAZHMV_vL" }, @@ -43,10 +43,9 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "import ott\n", "from ott.tools import sinkhorn_divergence\n", "from ott.geometry import pointcloud\n", - "from ott.core import implicit_differentiation as implicit_lib\n", + "from ott.math import implicit_differentiation as implicit_lib\n", "import matplotlib.pyplot as plt" ] }, @@ -61,7 +60,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": { "id": "0jfa6mSiWAw6" }, @@ -80,19 +79,11 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": { "id": "79peUzQOVqcJ" }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - } - ], + "outputs": [], "source": [ "a, x, b, y = sample(15, 17, 3)" ] @@ -257,9 +248,9 @@ "provenance": [] }, "kernelspec": { - "display_name": "ott", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "ott" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -271,7 +262,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/docs/notebooks/LRSinkhorn.ipynb b/docs/notebooks/LRSinkhorn.ipynb index dc304a367..5328cfb91 100644 --- a/docs/notebooks/LRSinkhorn.ipynb +++ b/docs/notebooks/LRSinkhorn.ipynb @@ -43,6 +43,9 @@ "import jax.numpy as jnp\n", "import jax\n", "import ott\n", + "from ott.geometry import pointcloud\n", + "from ott.problems.linear import linear_problem\n", + "from ott.solvers.linear import sinkhorn, sinkhorn_lr\n", "import matplotlib.pyplot as plt\n", "\n", "plt.rcParams.update({\"font.size\": 18})" @@ -56,9 +59,6 @@ }, "outputs": [], "source": [ - "import ott\n", - "\n", - "\n", "def create_points(rng, n, m, d):\n", " rngs = jax.random.split(rng, 4)\n", " x = jax.random.normal(rngs[0], (n, d)) + 1\n", @@ -91,8 +91,8 @@ "n, m, d = 19, 35, 2\n", "x, y, a, b = create_points(rng, n=n, m=m, d=d)\n", "\n", - "geom = ott.geometry.pointcloud.PointCloud(x, y, epsilon=0.1)\n", - "ot_prob = ott.core.linear_problems.LinearProblem(geom, a, b)" + "geom = pointcloud.PointCloud(x, y, epsilon=0.1)\n", + "ot_prob = linear_problem.LinearProblem(geom, a, b)" ] }, { @@ -128,31 +128,27 @@ "outputs": [ { "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAW0AAADtCAYAAAB0xiROAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAleUlEQVR4nO3de7xcZX3v8c83e+dmyIVkQyQXCJKgDQoRItjKsQoHCK0aekBNyhFsqbSvyqmtl9PUUwEpr4o9bbFV6jkgKKIQkNYSFQ1XjwVqIFwUAka3CJJwDQm5kXt+54/1bBx29qw1e2b2npnl953XemVmPWs988zas3/7mWf91rMUEZiZWWcY0eoGmJlZ7Ry0zcw6iIO2mVkHcdA2M+sgDtpmZh2ku9UNMDMbKpM1O3bxcuF2W3hmeUQsGIYmNcxB28xKaxcvM18fKtzu+3FRzzA0pykctM2s3FTDNh10uYqDtpmVlgCNqCFq7x3ypjSNg7aZlZdAtfS0O4iDtpmVWk097Q7ioG1mJSYHbTOzjiFKNz7ioG1mpVaymO2gbWbllXW0yxW1HbTNrNzKFbMdtM2sxAQjusoVtR20zazcSjY84ln+zKzUpOKltnq0QNJqSb2SlgxQPlrS9al8haRZaf2xkh5Ky48k/V6tdQ7EQdvMyktZnnbRUlyNuoDLgFOBucBiSXP7bXYOsCEiZgOXAp9N6x8B5kfEPGAB8H8ldddY5z4ctM2s3JrT1T4W6I2IxyNiJ7AUWNhvm4XA1enxjcCJkhQRL0fE7rR+DL+anqqWOvfhoG1mpSVgxAgVLkCPpJUVy7n9qpoOPFXxfE1aN+A2KUhvBKYASDpO0irgYeBPUnktde7DJyLNrNxqG7NeFxHzh6oJEbECOELSbwBXS/puvXU5aJtZealpE0atBWZWPJ+R1g20zRpJ3cBE4MXKDSLiMUlbgDfWWOc+PDxiZuWmGpZi9wFzJB0qaRSwCFjWb5tlwNnp8RnAHRERaZ9uAEmHAG8Anqixzn24p21mpdaMy9gjYrek84DlQBdwVUSsknQRsDIilgFXAtdI6gXWkwVhgOOBJZJ2kd1u4U8jYl1q2z51Fr6fiA66z06HknQmcHZEnFzDth8E/igijh9MmZnta9KomfGO136kcLubnvrE/UM5pt1MHh5pEknHS7pH0kZJ6yXdLektABHx9VoCdqeRdLikb0hal973jyV9NOWf1lvnhZK+Nsh9zktn/HdI+sog9rtdUvR9de1X9tup7OJ+6/9C0rOSNkm6StLoirInJG2TtCUtt1SUnS3p/rTfGkl/V/m6kn5D0h3pOPb2uwDjrZJuTZ+rF9IxP2gQh+jXlgTqUuHSSRy0m0DSBODbwOeByWRpO58GdrSyXXkGClSD3P8wYAVZytKbImIi8F5gPjC+8RYOytPAxcBVte6Qvv2MrFI2EvgnsvdXuf4UYAlwInAI8Dqyn3Old0fEfmmp/EP9GuDPgR7guFTHx1O93cBNZJ+hycC5wNckHZ723R+4HJiVXncz8OVa3+uvu2ZdEdkuHLSb43CAiLguIvZExLaIuCUifgzZsIaku/o2Tj24P5H0M0kvSbpMVQbeJP1vSXdJmlix7u8lbZD0C0mnVqyfJmlZ6pH1SvpQRdmFkm6U9DVJm4APSvq+pL9J3wo2S7pFUk+N7/nTwD0R8dGIeCa9/9UR8fsR8VJ6zfdIWpXe4/dTulNfe/5S0tr0uqslnShpAfBJ4P2pp/qjWhoSEf8WEf9OvzP11aRjeQHwP6ts8jHgFuAn/dafDVwZEasiYgPwN8AHa2zjFyPiPyJiZ0SsBb4OvC0VvwGYBlyaPj93AHcDH0j7fjcivhERmyLiZeALFftakZJFbQft5vgpsEfS1ZJOlbR/Dfu8C3gLcCTwPuCUykJJIyRdkcpPjoiNqeg4YDVZj+3vgCsrAv5SsgT9aWRnr/9W0gkV1S4ku1JrElnQAPh94A+AA4FRpN5fDf5rqmtAqZd
4HVnv8gDgZuBbkkZJej1wHvCWiBif3vsTEfE94G+B61NP9ahU1xJJ366xXbX4W+CLwLMDtPsQ4A+BiwbY7wig8g/Jj4CpkqZUrPt6GsK4RdJROW14O5B30klkaWH17Gt9aojXHRazHbSbISI2kZ0hDuAK4IXU452as9slEfFSRPwSuBOYV1E2kizgTSb7uv1yRdmTEXFFROwhu2T2ILLAMZOs9/WXEbE9Ih4CvgScVbHvf0bEv0fE3ojYltZ9OSJ+mp7f0K8deaYAz+SUvx/4TkTcGhG7gL8HxgK/BewBRgNzJY2MiCci4ufVKoqISyLiXTW2K5ek+WTH6fNVNvln4FMRsWWAsv3IrnLr0/e4bzjoTH41hHEnsFzSpAHa8Idkw0h/n1atBp4HPiFppKSTgd8mG1Lpv++RwPnAJ6q03/ppxtwj7cRBu0ki4rGI+GBEzCDrIU0DPpezS2Uv72WygNBnNlmv+NNpToIB96sI5vul11sfEZsrtn2SV18WW3nJbC3tyPMi2R+Maqal1+9r6970+tMjopesB34h8LykpZKm1fi6dZM0AvgX4CMVc0FUlr8bGB8R11epYgswoeJ53+PNABFxdxoaezkiPgO8BPyXfq9xGvAZ4NS+tK/0R+004HfJfh4fI/sDuqbfvrOB76b2/0dt79rK1tV20B4CEfET4CtU/3pb5DGyIYvvpqGEWjwNTJZUeRLwYF59hVUz8ztvA04vaM8hfU/SEM7MvvZExLUpdfGQ1K6+GdGGMgd1AlkP93pJz5Jd3ADZFWz/hezk4PyUHfIs2beFP5d0U9puFVA55HEU8FxEVBtLDyou3Uhj9leQfXt6+FUbRvw4In47IqZExClkJznvrdj3ELJj/jcRcU09b/7XkVQ878gI97R//Uh6g6SPSZqRns8EFgM/rLfOiLiO7KTcbSlTo2j7p4B7gM9IGpO+Rp8DDCp9rpKyFLYPVim+APitdKL0tWn72elE5ySynuLvphOMI8l6jzuAeyS9XtIJytLltgPbyC46AHgOmJV6xbW2s1vSGLILFLrS+x8oO2Yj2TeAeWn5nbT+GLJMkU+RnVTuK19GFmT/IG33VeAcSXPTe/xrsj/OSDpY0tvSmP0YSZ8gO+9wdyo/gew8wukR8UowrngPR6b9XiPp42TfYvrqng7cAXwhIv5PrcfFkhE1LB2kw5rbtjaTnSBcIWkrWbB+hCxQ1S0iriY7IXaH0oTqBRaTjak+DXwTuCAibqvntZVdVjuFKn940hj0b6bXWyVpI/CvwEpgc0SsBv472djxOuDdZD3MnWTj2Zek9c+SnQT9q1T1N9L/L0p6ILXlk8qfYOevyQL/kvSa29K6vmC6RdLBkXm2bwFeSPs/l7I6Nvcr3wZsjYj16T1/j+zk753AL8mGfy5IdYwnO7m5gezbxAKyIZC+XvinyOaiuFm/yuOufE8fIDtH8DxZj/+kiOhLGf0jsp73hRX7DjTmbgOQVLh0El8RaQOSdDzw4YhY3Oq2mNVr8tiD4+TZxedsr3/kzzrmikjPPWIDioi7gLsKNzRrZ4LaB9o6g4O2mZWWaNrUrG3DQdvMyqt582m3DQdtMyuxzsvDLjKsQbunpydmHTJrOF+y5aIg7Vg1zsDebvbs3Vu1rGtE/iDi6geezi1//dFDfp3NgMr6s9q1u/rPCmBkd3sO+t7/wP3rIuKARuspWcxuLGiniwX+iSw/9ksRcUne9rMOmcWKFfukqJba3pzgBjCiIMC1qy1bqk9guN9+o6uWAZw48sLc8ttX5JcPld279+SWd3fXPeNsS61btzW3vKdn3DC1ZHC6R3Y9WbxVsbINj9QdMZTNmXwZcCowF1gsaW6zGmZm1jBRusvYG+lpHwv0RsTjAJKWks2X8WgzGmZm1igBIzrsJgdFGvluPp1XT0C0hldPTgSApHOV3VVk5QvrXuhfbGY2tJpzY9+2MeQDqhFxeUTMj4j5B/Q0fE7BzKx2Kt9l7I0Mj6w
lm7WtzwxePaOcmVmLdd582UUa6WnfB8yRdGiaXGgR2axoZmZto2TnIevvaUfEbknnAcvJUv6uiojcWyDt3rOXDetfrlq+/+R9btQxLPbsyU/L27UzPxVszNgB7w8LwNq1m3L3nTlzUm55uxo3blTd+96284LijVqgU1P6iuw/aWyrmzCgnTv2uQ/F0Oi0qFygoTztiLiZ7N5/ZmZtR3L2iJlZZ2nS+IikBZJWS+qVtGSA8tGSrk/lK/rmwJd0kqT7JT2c/j+hYp/vpzofSsuBRe3w3CNmVmrNyA6puJjwJLL05vskLYuIyutSzgE2RMRsSYvIbqH3ftJNQCLiaUlvJBtSrkyPPjMiVtbaFve0zay80nzaRUsNXrmYMN19qe9iwkoLgavT4xuBEyUpIh6MiL4Jd1YBY9Ot9urioG1mJVbD0EjWE+/puwgwLef2q6iWiwlf2SYidpPdk3RKv21OBx6ouJUcwJfT0MinVMPXAg+PmFl51X4ict1Q325M0hFkQyYnV6w+MyLWShpPdo/VD5DdQLqqYQ3aXV1iwoQxw/mSr/jpz9ZVLZt92OTcfUePyT9M27ftqlo2Y8bE/IYVWP9i9RnaJkxs7FjmzTBYdO/Qrq7q+zY6W17eaxd1RHbtyn/tz/9T9Tuo/Y+PHJ+778iR+e3eu7d6u4uO51335E9o97bfPLhq2ctbd+buO2Fifsrfs89WT0udOnV87r5FaXvdOcds+/bqvzdN1ZyUv1ouJuzbZo2kbrIbOb+YNUEzyG62fVa6KTYAEbE2/b9Z0rVkwzC5QdvDI2ZWWk2c5K+WiwmXAWenx2cAd0RESJoEfAdYEhF3v9I2qVtST3o8EngX8EhRQzw8Ymal1ozL2KtdTCjpImBlRCwDrgSukdQLrCcL7ADnAbOB8yWdn9adDGwFlqeA3QXcBlxR1BYHbTMrryZepz7QxYQRcX7F4+3AewfY72Lg4irVHjPYdjhom1mplewqdgdtMysxwYick+adyEHbzMrNPW0zs84gyndj32EN2k+sXscfnfClquVf/kH/i5Ca5/A5PVXLVtz7VNUygMNn97+o6dXyppR97tnNuftOfW1+HuzkKUN3p+y83OFG7hLfSB42NDZXRFEu9V987O11v25Ru0fkBof8ut9+/Kzc8ry2jS+49iEvfxzgta+dkFueZ/SY6tMSQ/4xe+Sx4bn9YKfdmaaIe9pmVl4SuKdtZtY5StbRdtA2sxITyNkjZmadwz1tM7MO4ewRM7NOU7KutoO2mZWX5JS/Rsw6vIfLbz9nOF+yJscdOzO3fNu2/PmK83zmo9/JLf/ctYtyy4dSIx/mPbv3Vi3r6s4/8bPxpW255ZP2r5733qi891w0F3dRDnirFOVhP/1M9fmyAWbOmNTE1rxa3vF+y9HThux1K5XtbuzuaZtZubmnbWbWIVTzjXs7hoO2mZVWduca97TNzDqHU/7MzDqEs0fMzDqLnD1Sv5c2budbyx6tWv7fTn/TkL123hSR37j+R7n7nn7GkbnleSmB//i19+U3bAht2bw9tzwvVaxous+itL48Q5nSV+S556pPlTt1av40uUWpn6NHV/91KprqtmDW19wEiPwpYYtT+vI+B0V1F8lLo3zwx882VHet3NOuIOkJYDOwB9gdEfOb0Sgzs6Zo3n1920YzetrvjIh1TajHzKz5fCLSzKwzlDHlr9G08wBukXS/pAHvFSbpXEkrJa3ctGlDgy9nZjYIEhpRvHSSRnvax0fEWkkHArdK+klE/KByg4i4HLgc4LDDjig43WJm1lydFpSLNNTTjoi16f/ngW8CxzajUWZmzdKsnrakBZJWS+qVtGSA8tGSrk/lKyTNSutPSqMRD6f/T6jY55i0vlfSP6uGsZy6g7akcZLG9z0GTgYeqbc+M7OmUzamXbQUViN1AZcBpwJzgcWS5vbb7BxgQ0TMBi4FPpvWrwPeHRFvAs4GrqnY54vAh4A5aVlQ1JZGhkemAt9Mb7g
buDYivpe3w/6TxubmYu/dW326z6VffzC3Mb//gWNyy/N+MO9bNC933yJju0c1tP9Q2W98fq51J3qpaFrXSWNzy4tysfOMHTt0P+dG8qEbPdHWyGvnXf8A+dPZHjt/Rt2vOyjNGR05FuiNiMcBJC0FFgKVF54sBC5Mj28EviBJEVEZvFYBYyWNBiYDEyLih6nOrwKnAd/Na0jdQTs1/qh69zczG2qDyB7pkbSy4vnl6Xxcn+nAUxXP1wDH9avjlW0iYrekjcAUsp52n9OBByJih6TpqZ7KOqcXNdQpf2ZWagUXo/ZZN9QXB0o6gmzI5ORG6inZTLNmZhVqGM+usSe+Fqi8xdWMtG7AbSR1AxOBF9PzGWTJGmdFxM8rtq8cIxqozn04aJtZqUnFSw3uA+ZIOlTSKGARsKzfNsvITjQCnAHcEREhaRLwHWBJRNzdt3FEPANskvTWlDVyFnBTUUMctM2stPrGtBvtaUfEbuA8YDnwGHBDRKySdJGk96TNrgSmSOoFPgr0pQWeB8wGzpf0UFoOTGV/CnwJ6AV+TsFJSPCYtpmVXLOuYo+Im4Gb+607v+LxduC9A+x3MXBxlTpXAm8cTDsctM2s1Mo290hbBe2tW6rPV3zEUdMaqnvby9XrHvua1uVZP/6L9bnlY8dU/xEddNCEZjenKbZu3ZFbPm7c6Lrrfva5LbnlN3+r+nztkJ/Pv2dP9esEALq66h9NXLdua2553s8ZYFTOXN1dBZP8F83lneelDS/nlm/YmD9n+6GzJtf92k0hGOGgbWbWGbIx7Va3orkctM2s1By0zcw6iMe0zcw6SMlitoO2mZWYOu8mB0UctM2stMp4u7G2CtrjJ1SfSvSoIw/K3Xfv3vwpIluZ1pfndYe2OCWqiq1bCtL29quettdISl+RN7z+gIbK8zSS0lekp2fckNU9lCbt/5qGyttByWJ2ewVtM7Nmc0/bzKxT1D4hVMdw0DazUitZzHbQNrPyEo3dTq0dOWibWal5TNvMrIOULGY7aJtZidV+O7GOMaxBO4jc6S8byZNtZNxqw/r86Sf3n5yfi7pp47aqZRMmjq2rTc2wfPlPc8u/c9XKqmWXXHV6s5vTFrZty5mid2x+Lv+3CqZ9PfXU11ct6+7uym9YC+3dW/13spFpXYvqHo5g6ln+zMw6jIO2mVkHcfaImVmnkLNHzMw6S7litoO2mZWXZ/kzM+swDtpmZp1C8onIRgjVnYsdkT9f9vZtu3LLR48ZWbWsKA+7yFDO1f3Iqueqlr3h9T25+5500uzc8lNOObxq2c8ffzF338NeN3RzZu/csbtq2chR+fnORb2q51/YWrXskIPzf47vfvfc3PKiz2gj+w5lb/Ej77+uatnnv3Fm7r5btxbMuz6Ec6vXoox52oURVNJVkp6X9EjFusmSbpX0s/T//kPbTDOz+ihdFZm31FjPAkmrJfVKWjJA+WhJ16fyFZJmpfVTJN0paYukL/Tb5/upzofScmBRO2rp9n4FWNBv3RLg9oiYA9yenpuZtZ1mBG1JXcBlwKnAXGCxpP5fvc4BNkTEbOBS4LNp/XbgU8DHq1R/ZkTMS8vzRW0pDNoR8QNgfb/VC4Gr0+OrgdOK6jEzG3bpJghFSw2OBXoj4vGI2AksJYuDlSrj4o3AiZIUEVsj4i6y4N2weicWmBoRz6THzwJTq20o6VxJKyWtfGHdC3W+nJlZfWrsaff0xam0nNuvmunAUxXP16R1A24TEbuBjcCUGpr45TQ08inV0O1v+ERkRISkqmdRIuJy4HKA+cfMr/9MjZnZIAkY0VVTV3pdRMwf4uYM5MyIWCtpPPCvwAeAr+btUG9P+zlJBwGk/wvHYczMhl0NvewaT0SuBWZWPJ+R1g24jaRuYCKQm4YVEWvT/5uBa8mGYXLVG7SXAWenx2cDN9VZj5nZkGrSmPZ9wBxJh0oaBSwii4OVKuPiGcAdkZPLKalbUk96PBJ4F/BIte37FA6PSLoOeAf
ZmM8a4ALgEuAGSecATwLvK6oHslzUXbv2VC0fObJ6Dm7RX8OiXOmXc/JJi/bduzd/VCev3XlzbUPxfNtHzK2eATSUubuHva6WobiBFeXMd4/M7yuMGl3/qF3e/M0AM6ZNrFr2y6deyt334JmTcsvzfh5Fedg/XPFUbvmb5x1UtezpZzbn7nvIwZNyy4tysfM0koe9bl31nPlmasbvSUTslnQesBzoAq6KiFWSLgJWRsQy4ErgGkm9ZMkbiyra8AQwARgl6TTgZLLYuTwF7C7gNuCKorYU/nZExOIqRScW7Wtm1krNnHskIm4Gbu637vyKx9uB91bZd1aVao8ZbDt8GbuZlVrZroh00Daz8hKowVumtRsHbTMrNfe0zcw6hpBn+TMz6wyDSOnrGMM7NauUmx63Z0/1dK16p3R9Zf/u+tMJu2q7ompARSl99z/YPz//1Y6eN63u126Vz/3jf+SW/+Un31l33UWpc/c/+HRu+VuOmVG17IknX8rdtyjlL8/WrTtzy6dNG59b/tSajVXLph80IXffRn53du6sPk0u5Ke7Qv7vVk/PuLraNFi+CYKZWQfxTRDMzDqIe9pmZh0iG9N20DYz6xgli9kO2mZWZrXfTqxTOGibWak5aJuZdQip5psgdIy2Ctp5+aS7CvJFtxVMB9rIFJJF8qZuLUo3OubN/e9Y9Gp/fOqXq5Zd9q2zcvctys/N64Fs3pR/O7vxE8ZULVvyv07I3bfIjh3Vf9ZF2VtFx3P9+perlr39+Fn5lTdg3Lj86X+Lyoeyt/iRxddVLfvctYuqlgHce9+a3PLjjp2ZWz4cStbRbq+gbWbWbKJcUdtB28zKrVwx20HbzMrNJyLNzDqFJ4wyM+scQp57xMysk3h4xMysg5QsZndO0B45Kr+pReWbNm6rWlY053WRvK9fRfM/F/UC/uVbZ1ct6+oeunvf5eVhF9m2LX/u6LFj83OSR4+u/2O5Y3t+vv7EidXf19atO3L3bSTXf++e/M/B6p+tyy0/7HWTq5YVXcOw3/j8n2VeLnbBx5f5x+TnxefJu76haTxhlJlZ5xDuaZuZdZQRJYva5bq3vJlZP333icxbaqtHCyStltQrackA5aMlXZ/KV0ialdZPkXSnpC2SvtBvn2MkPZz2+WfVMJbjoG1mpSapcKmhji7gMuBUYC6wWNLcfpudA2yIiNnApcBn0/rtwKeAjw9Q9ReBDwFz0rKgqC0O2mZWWrX0smvsaR8L9EbE4xGxE1gKLOy3zULg6vT4RuBESYqIrRFxF1nwrmibDgImRMQPI8tY+CpwWlFDHLTNrMSKe9mpp90jaWXFcm6/iqYDT1U8X5PWDbhNROwGNgJTcho3PdWTV+c+fm1ORG7cXD2dq9GUv6G0NyfnqmsY2zEY27blp6CNGTMyt7yRFK1du/bklnd3Vz9qWzbnpyo2lPK3d29u+YP3PZVbfvCMCVXLNhW0uyjlL09RympReUHtDexbuxo/TusiYv4QN6UpCnvakq6S9LykRyrWXShpraSH0vI7Q9tMM7P6aIQKlxqsBSonB5+R1g24jaRuYCLwYkGdMwrq3EctwyNfYeDB8UsjYl5abq6hHjOz4aXmnIgE7gPmSDpU0ihgEbCs3zbLgL6r4c4A7oicryIR8QywSdJbU9bIWcBNRQ0pHB6JiB/0pa6YmXWSZl1cExG7JZ0HLCcbmbwqIlZJughYGRHLgCuBayT1AuvJAnvWDukJYAIwStJpwMkR8Sjwp2Qd47HAd9OSq5Ex7fMknQWsBD4WERsG2igN6J8LcPDBBzfwcmZmg9esy9jTiMLN/dadX/F4O/DeKvvOqrJ+JfDGwbSj3uyRLwKHAfOAZ4B/qLZhRFweEfMjYv4BPQfU+XJmZvVRDUsnqaunHRHP9T2WdAXw7aa1yMysico2n3ZdPe2UFN7n94BHqm1rZtYqtZyE7LRZAAt72pKuA95Blny+BrgAeIekeWSJlk8Af9yMxjz+i/VVy1b
c80TuvovPPDq3fNpB1fNch1LRB+LelWtyy98098CqZSNHtmem9oQJ+fnMjfyS3PPDX+aW9/S8Jrf88Nk9VcsOnLpfXW2qxe49+Xnai858c2553jFrJA+7qO6dO/Knui3Ki8+7BmLEiOG5tq/DYnKhWrJHFg+w+sohaIuZWdP92gVtM7NO1mnDH0UctM2s1EoWsx20zay85NuNmZl1FgdtM7MOUrKY7fm0zcw6SVv1tF936OS6ymrR1dWav09F8w0fPe+g3PK8+Z/b1c03/yS3/LCcXGmAI+ZOrVr2W2/Nn7+mkfmdFx/5+dzypQ//Wd11jx07Krd89+78fOcNG7ZVLSvKiy/67OeVj31Nfrvbdyb6XylbT7utgraZWbOp42YXyeegbWalJYFKNgjsoG1mJSb3tM3MOkq5YraDtpmVW8litoO2mZWbL65pwJ69e9myeXvV8kanmKzXj378TG75UUfmp+XlKfrADGVKX1H6256c6UIbadd73nNE3fs2qpFf0KKUvnUvbMkt7zmg/qldi453T8+4uutuV42kZw5GyWK2e9pmVl7ZjX3LFbVLlgxjZlZu7mmbWXnJwyNmZh3FwyNmZtYy7mmbWYmJEe5pm5l1ENWw1FKNtEDSakm9kpYMUD5a0vWpfIWkWRVlf5XWr5Z0SsX6JyQ9LOkhSStracew9rQfeujBdZMmj3uyYlUPsG4421Ajt2vw2rVtbtfgtFO7Dmm0gizlr/GGSOoCLgNOAtYA90laFhGPVmx2DrAhImZLWgR8Fni/pLnAIuAIYBpwm6TDI6JvPt53RkTNx3xYg3ZEHFD5XNLKiJg/nG2ohds1eO3aNrdrcNq1XY1o0uDIsUBvRDwOIGkpsBCoDNoLgQvT4xuBLyg7C7oQWBoRO4BfSOpN9f1nPQ3x8IiZlVdfV7togR5JKyuWc/vVNB14quL5mrRuwG0iYjewEZhSsG8At0i6f4DXHJBPRJpZqdXY017Xom8Yx0fEWkkHArdK+klE/CBvh1b3tC9v8etX43YNXru2ze0anHZtV900QoVLDdYCMyuez0jrBtxGUjcwEXgxb9+I6Pv/eeCbZMMmuVoatCOiLT8gbtfgtWvb3K7Badd2NaJJySP3AXMkHSppFNmJxWX9tlkGnJ0enwHcEdmsWMuARSm75FBgDnCvpHGSxgNIGgecDDxS1BAPj5hZaTUreyQidks6D1gOdAFXRcQqSRcBKyNiGXAlcE060bieLLCTtruB7KTlbuDDEbFH0lTgm+mKzW7g2oj4XuF7Gq7pEc3Mhtub5x0dd9x+V+F2k3vG3d8pWTMtGR4pSlJvpXqS3YeoHVdJel7SIxXrJku6VdLP0v/7t0m7LpS0Nh2zhyT9TgvaNVPSnZIelbRK0kfS+pYes5x2tcMxGyPpXkk/Sm37dFp/aLo4pDddLDJquNvWTLUlj3SOYQ/aFUnqpwJzgcUp+bydvDMi5rX4L+9XgAX91i0Bbo+IOcDt6flw+wr7tgvg0nTM5kXEzcPcJsi+dn4sIuYCbwU+nD5XrT5m1doFrT9mO4ATIuIoYB6wQNJbyS4KuTQiZgMbyC4a6Uw1BGwH7WKvJKlHxE6gL0ndKqS0n/X9Vi8Erk6PrwZOG842QdV2tVxEPBMRD6THm4HHyHJhW3rMctrVcpHpux3PyLQEcALZxSHQos9ZczXpVGSbaEXQriVJvZUGnew+jKZGRN+90Z4FprayMf2cJ+nHafhk2IdtKqU5H94MrKCNjlm/dkEbHDNJXZIeAp4HbgV+DryULg6B9vv9HDT3tMvv+Ig4mmz45sOS3t7qBg0kpRK1y1nkLwKHkX3Ffgb4h1Y1RNJ+wL8Cfx4RmyrLWnnMBmhXWxyziNgTEfPIcoePBd7QinZY7VoRtGtJUm+ZepLdh9Fzkg4CSP8/3+L2ABARz6Vf/r3AFbTomEkaSRYYvx4R/5ZWt/yYDdSudjlmfSLiJeBO4DeBSeniEGiz38+6lGt0pCVBu5Y
k9ZaoN9l9GFUm758N3NTCtryiLygmv0cLjlmamOdK4LGI+MeKopYes2rtapNjdoCkSenxWLIZ7B4jC95npM3a5nNWD9X4r5MM+8U11ZLUh7sdVdSV7D4UJF0HvINsIps1wAXAJcANks4BngTe1ybteoekeWRDD08Afzzc7QLeBnwAeDiN0QJ8ktYfs2rtWtwGx+wg4OqU0TUCuCEivi3pUWCppIuBB8n+6HSsThuzLuKLa8ystI5+8zHxg/93d+F24yeO7ZiLa3wZu5mVV7OuY28jDtpmVmrlCtkO2mZWdiWL2g7aZlZqJYvZDtpmVnIe0zYz6xzlCtkO2mZWdiWL2g7aZlZa2VXq5YraDtpmVm7litkO2mZWYh049WoRB20zK7lyRW0HbTMrtXKFbAdtMyu7kkVtB20zK7WSxWwHbTMrs/KdiXTQNrNSK1nM9o19zcxqIWmBpNWSeiUtGaB8tKTrU/kKSbMqyv4qrV8t6ZRa6xyIg7aZlVZ2DwQVLoX1ZLdkuww4FZhLdru4uf02OwfYEBGzgUuBz6Z955LdC/cIYAHwL5K6aqxzHw7aZmbFjgV6I+LxiNgJLAUW9ttmIXB1enwjcGK6sfNCYGlE7IiIXwC9qb5a6tyHx7TNrLTuf+D+5d0ju3pq2HSMpJUVzy+PiMsrnk8Hnqp4vgY4rl8dr2yTbmC+EZiS1v+w377T0+OiOvfhoG1mpRURC1rdhmbz8IiZWbG1wMyK5zPSugG3kdQNTARezNm3ljr34aBtZlbsPmCOpEMljSI7sbis3zbLgLPT4zOAOyIi0vpFKbvkUGAOcG+Nde7DwyNmZgXSGPV5wHKgC7gqIlZJughYGRHLgCuBayT1AuvJgjBpuxuAR4HdwIcjYg/AQHUWtUXZHwIzM+sEHh4xM+sgDtpmZh3EQdvMrIM4aJuZdRAHbTOzDuKgbWbWQRy0zcw6yP8H5WVeSINXBckAAAAASUVORK5CYII=\n", + "image/png": "\n", "text/plain": [ - "
" + "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" }, { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ - "
" + "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" } ], "source": [ - "solver = ott.core.sinkhorn.Sinkhorn()\n", + "solver = sinkhorn.Sinkhorn()\n", "ot_sink = solver(ot_prob)\n", "\n", "transp_cost = jnp.sum(ot_sink.matrix * geom.cost_matrix)\n", @@ -198,31 +194,27 @@ "outputs": [ { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ - "
" + "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" }, { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ - "
" + "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" } ], "source": [ - "solver = ott.core.sinkhorn_lr.LRSinkhorn(rank=int(min(n, m) / 2))\n", + "solver = sinkhorn_lr.LRSinkhorn(rank=int(min(n, m) / 2))\n", "ot_lr = solver(ot_prob)\n", "\n", "transp_cost = ot_lr.compute_reg_ot_cost(ot_prob)\n", @@ -273,12 +265,12 @@ }, "outputs": [], "source": [ - "geom = ott.geometry.pointcloud.PointCloud(x, y, epsilon=0.1)\n", - "ot_prob = ott.core.linear_problems.LinearProblem(geom, a, b)\n", + "geom = pointcloud.PointCloud(x, y, epsilon=0.1)\n", + "ot_prob = linear_problem.LinearProblem(geom, a, b)\n", "costs = []\n", "ranks = [15, 20, 35, 50, 100]\n", "for rank in ranks:\n", - " solver = ott.core.sinkhorn_lr.LRSinkhorn(rank=rank, initializer=\"k-means\")\n", + " solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer=\"k-means\")\n", " ot_lr = solver(ot_prob)\n", " costs.append(ot_lr.reg_ot_cost)" ] @@ -359,7 +351,7 @@ ] }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -373,7 +365,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/docs/notebooks/MetaOT.ipynb b/docs/notebooks/MetaOT.ipynb index a574c9ba1..dc4c3acd2 100644 --- a/docs/notebooks/MetaOT.ipynb +++ b/docs/notebooks/MetaOT.ipynb @@ -42,10 +42,19 @@ "execution_count": 1, "id": "9fde1353", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], "source": [ - "%pip install ott-jax\n", - "%pip install torchvision" + "%pip install -q ott-jax\n", + "%pip install -q torchvision" ] }, { @@ -56,8 +65,9 @@ "outputs": [], "source": [ "from ott.geometry import pointcloud\n", 
- "from ott.core import initializers as init_lib\n", - "from ott.core import linear_problems, sinkhorn\n", + "from ott.initializers.linear import initializers as init_lib\n", + "from ott.problems.linear import linear_problem\n", + "from ott.solvers.linear import sinkhorn\n", "\n", "import jax\n", "import jax.numpy as jnp\n", @@ -505,7 +515,7 @@ ], "source": [ "a, b = demo_batch.a[0], demo_batch.b[0]\n", - "ot_problem = linear_problems.LinearProblem(geom, a, b)\n", + "ot_problem = linear_problem.LinearProblem(geom, a, b)\n", "\n", "# Predict the optimal f duals.\n", "f = meta_initializer.init_dual_a(ot_problem, lse_mode=True)\n", @@ -606,7 +616,7 @@ "\n", "\n", "def get_meta_ot_potentials(a, b):\n", - " ot_problem = linear_problems.LinearProblem(geom, a, b)\n", + " ot_problem = linear_problem.LinearProblem(geom, a, b)\n", " f = meta_initializer.init_dual_a(ot_problem, lse_mode=True)\n", " g = geom.update_potential(f, jnp.zeros_like(b), jnp.log(b), 0, axis=0)\n", " return f, g\n", @@ -618,7 +628,7 @@ "\n", "\n", "def get_gaussian_potentials(a, b):\n", - " ot_problem = linear_problems.LinearProblem(geom, a=a, b=b)\n", + " ot_problem = linear_problem.LinearProblem(geom, a=a, b=b)\n", " f = init_lib.GaussianInitializer().init_dual_a(ot_problem, lse_mode=True)\n", " g = geom.update_potential(f, jnp.zeros_like(b), jnp.log(b), 0, axis=0)\n", " return f, g\n", @@ -677,7 +687,7 @@ " \"max_iterations\": 26,\n", " }\n", "\n", - " ot_problem = linear_problems.LinearProblem(geom, a=a, b=b)\n", + " ot_problem = linear_problem.LinearProblem(geom, a=a, b=b)\n", " base_sink_out = sinkhorn.sinkhorn(\n", " geom, a=a, b=b, init_dual_a=None, **sink_kwargs\n", " )\n", @@ -750,7 +760,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -764,7 +774,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.4" + "version": "3.10.6" } }, 
"nbformat": 4, diff --git a/docs/notebooks/OTT_&_POT.ipynb b/docs/notebooks/OTT_&_POT.ipynb index 15181472b..5d4ca6815 100644 --- a/docs/notebooks/OTT_&_POT.ipynb +++ b/docs/notebooks/OTT_&_POT.ipynb @@ -28,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "id": "IO2KLVZ1KWvq" }, @@ -52,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": { "id": "ysURew0UKhHE" }, @@ -63,7 +63,7 @@ "import jax.numpy as jnp\n", "import ott\n", "from ott.geometry import pointcloud\n", - "from ott.core import sinkhorn\n", + "from ott.solvers.linear import sinkhorn\n", "\n", "# import OT, from POT\n", "import numpy as np\n", @@ -510,9 +510,21 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" } }, "nbformat": 4, diff --git a/docs/notebooks/One_Sinkhorn.ipynb b/docs/notebooks/One_Sinkhorn.ipynb index 3735ccf20..e7c1f580e 100644 --- a/docs/notebooks/One_Sinkhorn.ipynb +++ b/docs/notebooks/One_Sinkhorn.ipynb @@ -40,7 +40,7 @@ "source": [ "import ott\n", "from ott.tools.sinkhorn_divergence import sinkhorn_divergence\n", - "from ott.core.sinkhorn import sinkhorn\n", + "from ott.solvers.linear import sinkhorn\n", "from ott.geometry.pointcloud import PointCloud\n", "from ott.geometry.geometry import Geometry" ] @@ -1129,7 +1129,7 @@ "toc_visible": true }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -1143,7 +1143,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.6" } }, "nbformat": 4, 
diff --git a/docs/notebooks/Sinkhorn_Barycenters.ipynb b/docs/notebooks/Sinkhorn_Barycenters.ipynb index db3116b76..02e388fba 100644 --- a/docs/notebooks/Sinkhorn_Barycenters.ipynb +++ b/docs/notebooks/Sinkhorn_Barycenters.ipynb @@ -27,7 +27,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "id": "LraeyMnPxVyC" }, @@ -65,14 +65,13 @@ "source": [ "import jax\n", "import jax.numpy as jnp\n", - "import ott\n", - "from ott.geometry import grid\n", - "from ott.core import discrete_barycenter" + "from ott.geometry import grid, costs, epsilon_scheduler\n", + "from ott.solvers.linear import discrete_barycenter" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { "id": "Vm4s3i53yZXa" }, @@ -246,7 +245,7 @@ "outputs": [], "source": [ "@jax.tree_util.register_pytree_node_class\n", - "class Custom(ott.geometry.costs.CostFn):\n", + "class Custom(costs.CostFn):\n", " \"\"\"Custom function.\"\"\"\n", "\n", " def pairwise(self, x, y):\n", @@ -257,9 +256,7 @@ "g_grid = grid.Grid(\n", " x=[jnp.arange(0, n) / 100 for n in grid_size],\n", " cost_fns=[Custom()],\n", - " epsilon=ott.geometry.epsilon_scheduler.Epsilon(\n", - " target=1e-4, init=1e-1, decay=0.95\n", - " ),\n", + " epsilon=epsilon_scheduler.Epsilon(target=1e-4, init=1e-1, decay=0.95),\n", ")" ] }, @@ -478,9 +475,21 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" } }, "nbformat": 4, diff --git a/docs/notebooks/application_biology.ipynb 
b/docs/notebooks/application_biology.ipynb index d8a4f869a..ffbcd0306 100644 --- a/docs/notebooks/application_biology.ipynb +++ b/docs/notebooks/application_biology.ipynb @@ -27,7 +27,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -39,16 +39,17 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": { "id": "n8JBHUyPHSJE" }, "outputs": [], "source": [ + "import numpy as np\n", "import matplotlib as mpl\n", "from matplotlib import pyplot as plt\n", - "import numpy as np\n", - "from ott.core import sinkhorn\n", + "\n", + "from ott.solvers.linear import sinkhorn\n", "from ott.geometry import pointcloud" ] }, @@ -388,9 +389,21 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" } }, "nbformat": 4, diff --git a/docs/notebooks/gromov_wasserstein.ipynb b/docs/notebooks/gromov_wasserstein.ipynb index d6c09bf08..23c2db96b 100644 --- a/docs/notebooks/gromov_wasserstein.ipynb +++ b/docs/notebooks/gromov_wasserstein.ipynb @@ -41,7 +41,7 @@ "source": [ "from IPython import display\n", "import jax\n", - "from jax import numpy as jnp\n", + "import jax.numpy as jnp\n", "from jax import random\n", "import numpy as np\n", "from matplotlib import animation\n", @@ -50,7 +50,7 @@ "import mpl_toolkits.mplot3d.axes3d as p3\n", "\n", "import ott\n", - "from ott.core import gromov_wasserstein as gw\n", + "from ott.solvers.quadratic import gromov_wasserstein as gw\n", "from ott.geometry import pointcloud" ] }, @@ -8203,7 +8203,7 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 
(ipykernel)", "language": "python", "name": "python3" }, @@ -8217,7 +8217,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/docs/notebooks/gromov_wasserstein_multiomics.ipynb b/docs/notebooks/gromov_wasserstein_multiomics.ipynb index 5b3db54b7..574d35292 100644 --- a/docs/notebooks/gromov_wasserstein_multiomics.ipynb +++ b/docs/notebooks/gromov_wasserstein_multiomics.ipynb @@ -58,9 +58,8 @@ "source": [ "import sys\n", "\n", - "!git clone -q https://github.com/rsinghlab/SCOT\n", - "\n", "if \"google.colab\" in sys.modules:\n", + " !git clone -q https://github.com/rsinghlab/SCOT\n", " !pip install -r SCOT/src/requirements.txt\n", " !pip install -q git+https://github.com/ott-jax/ott@main\n", " !pip install -q seaborn" @@ -68,7 +67,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -82,7 +81,7 @@ "from IPython import display\n", "\n", "import ott\n", - "from ott.core import gromov_wasserstein as gw\n", + "from ott.solvers.quadratic import gromov_wasserstein as gw\n", "\n", "from ot.gromov import init_matrix, gwloss\n", "from SCOT.src.scot import SCOT" @@ -16943,7 +16942,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/docs/notebooks/icnn_inits.ipynb b/docs/notebooks/icnn_inits.ipynb index e75704b73..6fb90cc9b 100644 --- a/docs/notebooks/icnn_inits.ipynb +++ b/docs/notebooks/icnn_inits.ipynb @@ -17,7 +17,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -42,8 +42,7 @@ "from torch.utils.data import DataLoader\n", "from ott.tools.sinkhorn_divergence import 
sinkhorn_divergence\n", "from ott.geometry import pointcloud\n", - "from ott.core.neuraldual import NeuralDualSolver\n", - "from ott.core import icnn" + "from ott.solvers.nn import icnn, neuraldual" ] }, { @@ -551,7 +550,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -565,7 +564,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.5" + "version": "3.10.6" }, "vscode": { "interpreter": { diff --git a/docs/notebooks/introduction_grid.ipynb b/docs/notebooks/introduction_grid.ipynb index 402041b86..cadf21432 100644 --- a/docs/notebooks/introduction_grid.ipynb +++ b/docs/notebooks/introduction_grid.ipynb @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": { "executionInfo": { "elapsed": 320, @@ -58,7 +58,7 @@ "import jax.numpy as jnp\n", "import numpy as np\n", "\n", - "from ott.core import sinkhorn\n", + "from ott.solvers.linear import sinkhorn\n", "from ott.geometry import costs\n", "from ott.geometry import grid\n", "from ott.geometry import pointcloud" @@ -460,9 +460,21 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" } }, "nbformat": 4, diff --git a/docs/notebooks/neural_dual.ipynb b/docs/notebooks/neural_dual.ipynb index 30f5f621c..1ba8b126e 100644 --- a/docs/notebooks/neural_dual.ipynb +++ b/docs/notebooks/neural_dual.ipynb @@ -17,7 +17,7 @@ }, { "cell_type": "code", - 
"execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -40,10 +40,9 @@ "import matplotlib.pyplot as plt\n", "from torch.utils.data import IterableDataset\n", "from torch.utils.data import DataLoader\n", - "from ott.tools.sinkhorn_divergence import sinkhorn_divergence\n", + "from ott.tools import sinkhorn_divergence\n", "from ott.geometry import pointcloud\n", - "from ott.core.neuraldual import NeuralDualSolver\n", - "from ott.core import icnn" + "from ott.solvers.nn import neuraldual, icnn" ] }, { @@ -163,7 +162,7 @@ " a = jnp.ones(len(x)) / len(x)\n", " b = jnp.ones(len(y)) / len(y)\n", "\n", - " sdiv = sinkhorn_divergence(\n", + " sdiv = sinkhorn_divergence.sinkhorn_divergence(\n", " pointcloud.PointCloud, x, y, epsilon=epsilon, a=a, b=b\n", " )\n", " return sdiv.divergence" @@ -331,7 +330,7 @@ } ], "source": [ - "neural_dual_solver = NeuralDualSolver(\n", + "neural_dual_solver = neuraldual.NeuralDualSolver(\n", " input_dim,\n", " neural_f,\n", " neural_g,\n", @@ -535,7 +534,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.15 64-bit", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -549,7 +548,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.15" + "version": "3.10.6" }, "vscode": { "interpreter": { diff --git a/docs/notebooks/point_clouds.ipynb b/docs/notebooks/point_clouds.ipynb index 7447d8cd6..763022977 100644 --- a/docs/notebooks/point_clouds.ipynb +++ b/docs/notebooks/point_clouds.ipynb @@ -21,7 +21,15 @@ "id": "O2Qs8m9SN1ag", "outputId": "ed53b82f-b649-4836-994a-453b16377772" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], "source": [ "%pip 
install -q git+https://github.com/ott-jax/ott@main" ] @@ -41,8 +49,8 @@ "\n", "import ott\n", "from ott.geometry import costs, pointcloud\n", - "from ott.core import sinkhorn\n", - "from ott.core import linear_problems" + "from ott.problems.linear import linear_problem\n", + "from ott.solvers.linear import sinkhorn" ] }, { @@ -89,7 +97,7 @@ "\n", "This geometry object defines a `LinearProblem` object, which contains all the data needed to instantiate a linear OT problem (see Gromov-Wasserstein tutorials for *quadratic* OT problems).\n", "\n", - "We can then call a `Sinkhorn` solver to solve that problem, and compute the OT between these points clouds. Note that all weights are assumed to be uniform in this notebook, but non-uniform weights can be passed as `a= .. ,b= ..` arguments when defining the `LinearProblem` below." + "We can then call a `Sinkhorn` solver to solve that problem, and compute the OT between these points clouds. Note that all weights are assumed to be uniform in this notebook, but non-uniform weights can be passed as `a= ..., b= ...` arguments when defining the `LinearProblem` below." ] }, { @@ -117,7 +125,7 @@ ], "source": [ "# Define a linear problem with that cost structure.\n", - "ot_prob = linear_problems.LinearProblem(geom)\n", + "ot_prob = linear_problem.LinearProblem(geom)\n", "# Create a Sinkhorn solver\n", "solver = sinkhorn.Sinkhorn()\n", "# Solve OT problem\n", @@ -252,7 +260,7 @@ "):\n", " # Wrapper function that returns OT cost and OT output given a geometry.\n", " def reg_ot_cost(geom):\n", - " out = sinkhorn.Sinkhorn()(linear_problems.LinearProblem(geom))\n", + " out = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom))\n", " return out.reg_ot_cost, out\n", "\n", " # The jax.value_and_grad operator. 
Note that we make explicit that\n", @@ -217025,7 +217033,7 @@ "outputs": [], "source": [ "geom = pointcloud.PointCloud(x, y)\n", - "out = sinkhorn.Sinkhorn()(linear_problems.LinearProblem(geom))\n", + "out = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom))\n", "dual_potentials = out.to_dual_potentials()" ] }, @@ -217087,7 +217095,7 @@ "outputs": [], "source": [ "geom = pointcloud.PointCloud(x, y, cost_fn=costs.SqPNorm(p=1.1))\n", - "out = sinkhorn.Sinkhorn()(linear_problems.LinearProblem(geom))\n", + "out = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom))\n", "dual_potentials = out.to_dual_potentials()" ] }, @@ -217144,7 +217152,7 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3.9.15 64-bit", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -217158,7 +217166,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.15" + "version": "3.10.6" }, "vscode": { "interpreter": { @@ -217167,5 +217175,5 @@ } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 1 } diff --git a/docs/notebooks/soft_sort.ipynb b/docs/notebooks/soft_sort.ipynb index 69a801f3d..244cf2976 100644 --- a/docs/notebooks/soft_sort.ipynb +++ b/docs/notebooks/soft_sort.ipynb @@ -760,7 +760,7 @@ "toc_visible": true }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -774,7 +774,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/docs/notebooks/wasserstein_barycenters_gmms.ipynb b/docs/notebooks/wasserstein_barycenters_gmms.ipynb index 547f7a720..1703efe78 100644 --- a/docs/notebooks/wasserstein_barycenters_gmms.ipynb +++ b/docs/notebooks/wasserstein_barycenters_gmms.ipynb @@ -52,9 +52,10 @@ "metadata": {}, "outputs": [], "source": [ + "from ott.geometry import costs\n", "from 
ott.tools.gaussian_mixture import gaussian_mixture\n", - "from ott.core import bar_problems, continuous_barycenter\n", - "from ott.geometry import costs" + "from ott.problems.linear import barycenter_problem\n", + "from ott.solvers.linear import continuous_barycenter" ] }, { @@ -406,7 +407,7 @@ "outputs": [], "source": [ "# create a barycenter problem.\n", - "bar_p = bar_problems.BarycenterProblem(\n", + "bar_p = barycenter_problem.BarycenterProblem(\n", " y=ys,\n", " b=bs,\n", " weights=barycentric_weights,\n", @@ -606,9 +607,9 @@ ], "metadata": { "kernelspec": { - "display_name": "ott", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "ott" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -620,7 +621,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.6" } }, "nbformat": 4, From f213453032c356d851f705ac323d1f459b3ed336 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 18 Nov 2022 10:48:00 +0100 Subject: [PATCH 11/34] Update geometry docs --- ott/geometry/costs.py | 114 +++++++++++++++------- ott/geometry/epsilon_scheduler.py | 1 + ott/geometry/geometry.py | 6 +- ott/problems/linear/barycenter_problem.py | 2 +- ott/problems/quadratic/gw_barycenter.py | 2 +- ott/tools/segment_sinkhorn.py | 4 +- ott/tools/sinkhorn_divergence.py | 4 +- tests/core/continuous_barycenter_test.py | 2 +- 8 files changed, 89 insertions(+), 46 deletions(-) diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py index 084673300..5b11005d3 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -52,11 +52,29 @@ class CostFn(abc.ABC): def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: pass + # TODO(michalk8): make weights optional? def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float: - raise NotImplementedError("Barycenter not yet implemented for this cost.") + """Barycentric projection. 
+ + Args: + weights: TODO. + xs: TODO. + + Returns: + TODO. + """ + raise NotImplementedError("Barycenter is not yet implemented.") @classmethod - def padder(cls, dim: int) -> jnp.ndarray: + def _padder(cls, dim: int) -> jnp.ndarray: + """TODO. + + Args: + dim: TODO. + + Returns: + TODO. + """ return jnp.zeros((1, dim)) def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float: @@ -104,6 +122,7 @@ class TICost(CostFn): real-values, to be used as: .. math:: + c(x,y) = h(z), z := x-y. If that cost function is used to form an Entropic map using the @@ -117,7 +136,7 @@ def h(self, z: jnp.ndarray) -> float: """TI function acting on difference of :math:`x-y` to output cost.""" def h_legendre(self, z: jnp.ndarray) -> float: - """Legendre transform of TI function :func:`h` (when latter is convex).""" + """Legendre transform of :func:`h` when it is convex.""" raise NotImplementedError("`h_legendre` not implemented.") def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: @@ -131,6 +150,9 @@ class SqPNorm(TICost): For details on the derivation of the Legendre transform of the norm, see e.g. the reference :cite:`boyd:04`, p.93/94. + + Args: + p: TODO. """ def __init__(self, p: float): @@ -156,7 +178,11 @@ def tree_unflatten(cls, aux_data, children): @jax.tree_util.register_pytree_node_class class PNorm(TICost): - """p-norm (to the power p) of the difference of two vectors.""" + """p-norm (to the power p) of the difference of two vectors. + + Args: + p: TODO. + """ def __init__(self, p: float): super().__init__() @@ -219,7 +245,11 @@ def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: @jax.tree_util.register_pytree_node_class class Cosine(CostFn): - """Cosine distance CostFn.""" + """Cosine distance cost function. + + Args: + ridge: TODO. + """ def __init__(self, ridge: float = 1e-8): super().__init__() @@ -236,27 +266,32 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: return jnp.clip(cosine_distance, 0., 2.) 
@classmethod - def padder(cls, dim: int) -> jnp.ndarray: + def _padder(cls, dim: int) -> jnp.ndarray: return jnp.ones((1, dim)) @jax.tree_util.register_pytree_node_class class Bures(CostFn): - """Bures distance between a pair of (mean, cov matrix) raveled as vectors.""" + """Bures distance between a pair of (mean, cov matrix) raveled as vectors. + + Args: + dimension: Dimensionality of the data. + kwargs: Keyword arguments for :func:`ott.math.matrix_square_root.sqrtm`. + """ def __init__(self, dimension: int, **kwargs: Any): super().__init__() self._dimension = dimension self._sqrtm_kw = kwargs - def norm(self, x: jnp.ndarray): + def norm(self, x: jnp.ndarray) -> jnp.ndarray: """Compute norm of Gaussian, sq. 2-norm of mean + trace of covariance.""" mean, cov = x_to_means_and_covs(x, self._dimension) norm = jnp.sum(mean ** 2, axis=-1) norm += jnp.trace(cov, axis1=-2, axis2=-1) return norm - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray): + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Compute - 2 x Bures dot-product.""" mean_x, cov_x = x_to_means_and_covs(x, self._dimension) mean_y, cov_y = x_to_means_and_covs(y, self._dimension) @@ -268,17 +303,6 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray): )[0] return -2 * (mean_dot_prod + jnp.trace(sq__sq_x_y_sq_x, axis1=-2, axis2=-1)) - @functools.partial(jax.vmap, in_axes=[None, None, 0, 0]) - def scale_covariances(self, cov_sqrt, cov_i, lambda_i): - """Iterate update needed to compute barycenter of covariances.""" - return lambda_i * matrix_square_root.sqrtm_only( - jnp.matmul(jnp.matmul(cov_sqrt, cov_i), cov_sqrt) - ) - - def relative_diff(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: - """Monitor change in two successive estimates of matrices.""" - return jnp.sum(jnp.square(x - y)) / jnp.prod(jnp.array(x.shape)) - def covariance_fixpoint_iter( self, covs: jnp.ndarray, @@ -287,27 +311,40 @@ def covariance_fixpoint_iter( ) -> jnp.ndarray: """Iterate fix-point updates to compute 
barycenter of Gaussians.""" - def cond_fn(iteration, constants, state): + @functools.partial(jax.vmap, in_axes=[None, 0, 0]) + def scale_covariances( + cov_sqrt: jnp.ndarray, cov_i: jnp.ndarray, lambda_i: jnp.ndarray + ) -> jnp.ndarray: + """Iterate update needed to compute barycenter of covariances.""" + return lambda_i * matrix_square_root.sqrtm_only( + (cov_sqrt @ cov_i) @ cov_sqrt + ) + + def cond_fn(iteration: int, constants: Tuple[...], state) -> bool: + del iteration, constants _, diff = state - return diff > jnp.array(rtol) + return diff > rtol - def body_fn(iteration, constants, state, compute_error): - del compute_error + def body_fn( + iteration: int, constants: Tuple[...], state: Tuple[jnp.ndarray, float], + compute_error: bool + ) -> Tuple[jnp.ndarray, float]: + del iteration, constants, compute_error cov, _ = state cov_sqrt, cov_inv_sqrt, _ = matrix_square_root.sqrtm(cov) scaled_cov = jnp.linalg.matrix_power( jnp.sum(self.scale_covariances(cov_sqrt, covs, lambdas), axis=0), 2 ) - next_cov = jnp.matmul(jnp.matmul(cov_inv_sqrt, scaled_cov), cov_inv_sqrt) - diff = self.relative_diff(next_cov, cov) + next_cov = (cov_inv_sqrt @ scaled_cov) @ cov_inv_sqrt + diff = jnp.sum((next_cov - cov) ** 2) / jnp.prod(jnp.array(cov.shape)) return next_cov, diff - def init_state(): + def init_state() -> Tuple[jnp.ndarray, float]: cov_init = jnp.eye(self._dimension) diff = jnp.inf return cov_init, diff - state = fixed_point_loop.fixpoint_iter( + cov, _ = fixed_point_loop.fixpoint_iter( cond_fn=cond_fn, body_fn=body_fn, min_iterations=10, @@ -316,8 +353,6 @@ def init_state(): constants=(), state=init_state() ) - - cov, _ = state return cov def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: @@ -345,7 +380,7 @@ def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: return barycenter @classmethod - def padder(cls, dim: int) -> jnp.ndarray: + def _padder(cls, dim: int) -> jnp.ndarray: """Pad with concatenated zero means and \ 
raveled identity covariance matrix.""" dimension = int((-1 + math.sqrt(1 + 4 * dim)) / 2) @@ -370,6 +405,12 @@ class UnbalancedBures(CostFn): This cost implements the value defined in :cite:`janati:20`, eq. 37, 39, 40. We follow their notations. It is assumed inputs are given as triplets (mass, mean, covariance) raveled as vectors, in that order. + + Args: + dimension: TODO. + gamma: TODO. + sigma: TODO. + kwargs: TODO. """ def __init__( @@ -382,10 +423,10 @@ def __init__( super().__init__() self._dimension = dimension self._gamma = gamma - self._sigma2 = sigma ** 2 + self._sigma = sigma self._sqrtm_kw = kwargs - def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]: + def norm(self, x: jnp.ndarray) -> jnp.ndarray: """Compute norm of Gaussian for unbalanced Bures.""" return self._gamma * x[0] @@ -393,7 +434,7 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Compute dot-product for unbalanced Bures.""" # Sets a few constants gam = self._gamma - sig2 = self._sigma2 + sig2 = self._sigma ** 2 lam = sig2 + gam / 2 tau = gam / (2 * lam) @@ -446,12 +487,13 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: ) def tree_flatten(self): - return (), (self._dimension, self._gamma, self._sigma2, self._sqrtm_kw) + return (), (self._dimension, self._gamma, self._sigma, self._sqrtm_kw) @classmethod def tree_unflatten(cls, aux_data, children): del children - return cls(aux_data[0], aux_data[1], aux_data[2], **aux_data[3]) + dim, gamma, sigma, kwargs = aux_data + return cls(dim, gamma=gamma, sigma=sigma, **kwargs) def x_to_means_and_covs(x: jnp.ndarray, diff --git a/ott/geometry/epsilon_scheduler.py b/ott/geometry/epsilon_scheduler.py index 8f5666385..284fa29f7 100644 --- a/ott/geometry/epsilon_scheduler.py +++ b/ott/geometry/epsilon_scheduler.py @@ -44,6 +44,7 @@ class Epsilon: decay: geometric decay factor, smaller than 1. 
""" + # TODO(michalk8): directly use the defaults instead of `None` def __init__( self, target: Optional[float] = None, diff --git a/ott/geometry/geometry.py b/ott/geometry/geometry.py index 3550c881b..f33a8225c 100644 --- a/ott/geometry/geometry.py +++ b/ott/geometry/geometry.py @@ -28,7 +28,7 @@ from ott.geometry import epsilon_scheduler from ott.math import utils -__all__ = ["Geometry"] +__all__ = ["Geometry", "is_linear", "is_affine"] @jax.tree_util.register_pytree_node_class @@ -37,7 +37,7 @@ class Geometry: Optimal transport problems are intrinsically geometric: they compute an optimal way to transport mass from one configuration onto another. To define - what is meant by optimality of a transport requires defining a cost, of moving + what is meant by optimality of transport requires defining a cost, of moving mass from one among several sources, towards one out of multiple targets. These sources and targets can be provided as points in vectors spaces, grids, or more generally exclusively described through a (dissimilarity) cost matrix, @@ -900,4 +900,4 @@ def is_affine(fn) -> bool: def is_linear(fn) -> bool: """Test heuristically if a function is linear.""" - return fn(0.0) == 0.0 and is_affine(fn) + return jnp.logical_and(fn(0.0) == 0.0, is_affine(fn)) diff --git a/ott/problems/linear/barycenter_problem.py b/ott/problems/linear/barycenter_problem.py index 7a6e63129..943292bc2 100644 --- a/ott/problems/linear/barycenter_problem.py +++ b/ott/problems/linear/barycenter_problem.py @@ -101,7 +101,7 @@ def segmented_y_b(self) -> Tuple[jnp.ndarray, jnp.ndarray]: y, b = segment.segment_point_cloud( x=self._y, a=self._b, - padding_vector=self.cost_fn.padder(self.ndim), + padding_vector=self.cost_fn._padder(self.ndim), **self._kwargs ) diff --git a/ott/problems/quadratic/gw_barycenter.py b/ott/problems/quadratic/gw_barycenter.py index a052e6224..be87ec098 100644 --- a/ott/problems/quadratic/gw_barycenter.py +++ b/ott/problems/quadratic/gw_barycenter.py @@ -263,7 
+263,7 @@ def segmented_y_fused(self) -> Optional[jnp.ndarray]: return self._y_fused y_fused, _ = segment.segment_point_cloud( x=self._y_fused, - padding_vector=self.cost_fn.padder(self.ndim_fused), + padding_vector=self.cost_fn._padder(self.ndim_fused), **self._kwargs ) return y_fused diff --git a/ott/tools/segment_sinkhorn.py b/ott/tools/segment_sinkhorn.py index fca0a4226..3b149c0f5 100644 --- a/ott/tools/segment_sinkhorn.py +++ b/ott/tools/segment_sinkhorn.py @@ -98,9 +98,9 @@ def segment_sinkhorn( dim = x.shape[1] if cost_fn is None: # default padder - padding_vector = costs.CostFn.padder(dim=dim) + padding_vector = costs.CostFn._padder(dim=dim) else: - padding_vector = cost_fn.padder(dim=dim) + padding_vector = cost_fn._padder(dim=dim) def eval_fn( padded_x: jnp.ndarray, diff --git a/ott/tools/sinkhorn_divergence.py b/ott/tools/sinkhorn_divergence.py index 2b8404439..baf5688c3 100644 --- a/ott/tools/sinkhorn_divergence.py +++ b/ott/tools/sinkhorn_divergence.py @@ -270,9 +270,9 @@ def segment_sinkhorn_divergence( dim = x.shape[1] if cost_fn is None: # default padder - padding_vector = costs.CostFn.padder(dim=dim) + padding_vector = costs.CostFn._padder(dim=dim) else: - padding_vector = cost_fn.padder(dim=dim) + padding_vector = cost_fn._padder(dim=dim) def eval_fn( padded_x: jnp.ndarray, diff --git a/tests/core/continuous_barycenter_test.py b/tests/core/continuous_barycenter_test.py index 2a0efa9d1..2578a459c 100644 --- a/tests/core/continuous_barycenter_test.py +++ b/tests/core/continuous_barycenter_test.py @@ -221,7 +221,7 @@ def test_bures_barycenter( num_segments=num_measures, max_measure_size=num_components, num_per_segment=(num_components, num_components), - padding_vector=bures_cost.padder(y.shape[1]), + padding_vector=bures_cost._padder(y.shape[1]), ) bar_p = barycenter_problem.BarycenterProblem( seg_y, From 4b560d242f8f111a5c8fac7fb840f2ea9447abad Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 18 
Nov 2022 10:55:29 +0100 Subject: [PATCH 12/34] Update initializers --- ott/initializers/linear/initializers_lr.py | 6 ++++-- ott/initializers/nn/initializers.py | 4 ++-- ott/initializers/quadratic/initializers.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/ott/initializers/linear/initializers_lr.py b/ott/initializers/linear/initializers_lr.py index 6059a17bb..3eb4670ee 100644 --- a/ott/initializers/linear/initializers_lr.py +++ b/ott/initializers/linear/initializers_lr.py @@ -342,7 +342,8 @@ class KMeansInitializer(LRInitializer): rank: Rank of the factorization. min_iterations: Minimum number of k-means iterations. max_iterations: Maximum number of k-means iterations. - sinkhorn_kwargs: Keyword arguments for :class:`~ott.core.sinkhorn.Sinkhorn`. + sinkhorn_kwargs: Keyword arguments for + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`. kwargs: Keyword arguments for :func:`~ott.tools.k_means.k_means`. """ @@ -466,7 +467,8 @@ class GeneralizedKMeansInitializer(KMeansInitializer): inner_iterations: Number of iterations used by the algorithm before re-evaluating progress. threshold: Convergence threshold. - sinkhorn_kwargs: Keyword arguments for :class:`~ott.core.sinkhorn.Sinkhorn`. + sinkhorn_kwargs: Keyword arguments for + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`. """ def __init__( diff --git a/ott/initializers/nn/initializers.py b/ott/initializers/nn/initializers.py index b10a58dac..0f6ce2e07 100644 --- a/ott/initializers/nn/initializers.py +++ b/ott/initializers/nn/initializers.py @@ -30,7 +30,7 @@ class MetaInitializer(initializers.DefaultInitializer): instances (multiple pairs of probability weights), that assume the **same** geometry ``geom`` is used throughout, both for training and evaluation. The meta model defaults to the MLP in - :class:`~ott.core.initializers.MetaMLP` and, with batched problem + :class:`~ott.initializers.nn.initializers.MetaMLP` and, with batched problem instances passed into :meth:`update`. 
**Sample training usage.** The following code shows a simple @@ -192,7 +192,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: class MetaMLP(nn.Module): - r"""A Meta MLP potential for :class:`~ott.core.initializers.MetaInitializer`. + r"""Potential for :class:`~ott.initializers.nn.initializers.MetaInitializer`. This provides an MLP :math:`\hat f_\theta(a, b)` that maps from the probabilities of the measures to the optimal dual potentials :math:`f`. diff --git a/ott/initializers/quadratic/initializers.py b/ott/initializers/quadratic/initializers.py index 0c97f83a5..2b09306d9 100644 --- a/ott/initializers/quadratic/initializers.py +++ b/ott/initializers/quadratic/initializers.py @@ -172,7 +172,7 @@ def _create_geometry( Args: quad_prob: Quadratic OT problem. kwargs: Keyword arguments for - :meth:`ott.core.initializers_lr.LRInitializer.__call__`. + :meth:`~ott.initializers.linear.initializers_lr.LRInitializer.__call__`. Returns: The initial geometry used to initialize a linear problem. From 75602d290f1283a17ae09cebbc8954341b13f892 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 18 Nov 2022 10:57:28 +0100 Subject: [PATCH 13/34] Update math docs --- ott/math/decomposition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ott/math/decomposition.py b/ott/math/decomposition.py index cbdf61a8b..a05ee1b23 100644 --- a/ott/math/decomposition.py +++ b/ott/math/decomposition.py @@ -65,12 +65,12 @@ def _solve(self, L: Optional[T], b: jnp.ndarray) -> jnp.ndarray: def create(cls, A: Union[T, sp.spmatrix], **kwargs: Any) -> "CholeskySolver": """Instantiate sparse or dense Cholesky solver. - Optionally converts :class:`scipy.sparse.spmatrix` to its + And optionally convert :class:`scipy.sparse.spmatrix` to its :mod:`jax` equivalent. Args: A: Symmetric positive definite matrix of shape ``[n, n]``. - kwargs: Keyword arguments for the initialization. 
+ kwargs: Keyword arguments for the solver initialization. Returns: Sparse or dense Cholesky solver. From 7dcbea69367abcfd3c3c9e5795abf3b22b296106 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 18 Nov 2022 11:11:50 +0100 Subject: [PATCH 14/34] Update problem docstrings --- ott/problems/linear/barycenter_problem.py | 8 ++++---- ott/problems/linear/potentials.py | 2 +- ott/problems/quadratic/gw_barycenter.py | 9 +++++---- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/ott/problems/linear/barycenter_problem.py b/ott/problems/linear/barycenter_problem.py index 943292bc2..ac30a2b01 100644 --- a/ott/problems/linear/barycenter_problem.py +++ b/ott/problems/linear/barycenter_problem.py @@ -30,16 +30,16 @@ class BarycenterProblem: y: Array of shape ``[num_total_points, ndim]`` merging the points of all measures. Alternatively, already segmented array of shape ``[num_measures, max_measure_size, ndim]`` can be passed. - See also :func:`~ott.core.segment.segment_point_cloud`. + See also :func:`~ott.geometry.segment.segment_point_cloud`. b: Array of shape ``[num_total_points,]`` containing the weights of all the points within the measures that define the barycenter problem. - Similarly as ``y``, segmented array of weights of shape + Same as ``y``, pre-segmented array of weights of shape ``[num_measures, max_measure_size]`` can be passed. If ``y`` is already pre-segmented, this array must be always specified. weights: Array of shape ``[num_measures,]`` containing the weights of the measures. cost_fn: Cost function used. If `None`, - use :class:`~ott.geometry.costs.SqEuclidean` cost. + use the :class:`~ott.geometry.costs.SqEuclidean` cost. epsilon: Epsilon regularization used to solve reg-OT problems. 
debiased: **Currently not implemented.** Whether the problem is debiased, in the sense that @@ -49,7 +49,7 @@ class BarycenterProblem: :meth:`~ott.core.continuous_barycenter.WassersteinBarycenter.init_state` needs to be smaller than the maximum measure size for parallelization to operate efficiently. - kwargs: Keyword arguments :func:`~ott.core.segment.segment_point_cloud`. + kwargs: Keyword arguments :func:`~ott.geometry.segment.segment_point_cloud`. Only used when ``y`` is not already segmented. When passing ``segment_ids``, 2 arguments must be specified for jitting to work: diff --git a/ott/problems/linear/potentials.py b/ott/problems/linear/potentials.py index cf2914c67..dde7366d6 100644 --- a/ott/problems/linear/potentials.py +++ b/ott/problems/linear/potentials.py @@ -154,7 +154,7 @@ class EntropicPotentials(DualPotentials): f: The first dual potential vector of shape ``[n,]``. g: The second dual potential vector of shape ``[m,]``. geom: Geometry used to compute the dual potentials using - :class:`~ott.core.sinkhorn.Sinkhorn`. + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`. a: Probability weights for the first measure. If `None`, use uniform. b: Probability weights for the second measure. If `None`, use uniform. """ diff --git a/ott/problems/quadratic/gw_barycenter.py b/ott/problems/quadratic/gw_barycenter.py index be87ec098..9d582ce1c 100644 --- a/ott/problems/quadratic/gw_barycenter.py +++ b/ott/problems/quadratic/gw_barycenter.py @@ -13,6 +13,7 @@ __all__ = ["GWBarycenterProblem"] +# TODO(michalk8): better abstraction (common superclass for Wasserstein bary) @jax.tree_util.register_pytree_node_class class GWBarycenterProblem(barycenter_problem.BarycenterProblem): """(Fused) Gromov-Wasserstein barycenter problem :cite:`peyre:16,vayer:19`. @@ -21,10 +22,10 @@ class GWBarycenterProblem(barycenter_problem.BarycenterProblem): y: Array of shape ``[num_total_points, ndim]`` merging the points of all measures. 
Alternatively, already segmented array of shape ``[num_measures, max_measure_size, ndim]`` can be passed. - See also :func:`~ott.core.segment.segment_point_cloud`. + See also :func:`~ott.geometry.segment.segment_point_cloud`. b: Array of shape ``[num_total_points,]`` containing the weights of all the points within the measures that define the barycenter problem. - Similarly, as ``y``, segmented array of weights of shape + Same as ``y``, pre-segmented array of weights of shape ``[num_measures, max_measure_size]`` can be passed. If ``y`` is already pre-segmented, this array must be passed. weights: Array of shape ``[num_measures,]`` containing the weights of the @@ -42,7 +43,7 @@ class GWBarycenterProblem(barycenter_problem.BarycenterProblem): ``y_fused != None``. scale_cost: Scaling of cost matrices passed to geometries. kwargs: Keyword arguments for - :class:`~ott.core.bar_problems.BarycenterProblem`. + :class:`~ott.problems.linear.barycenter_problem.BarycenterProblem`. """ def __init__( @@ -138,7 +139,7 @@ def update_features(self, transports: jnp.ndarray, """Update the barycenter features in the fused case :cite:`vayer:19`. Uses :cite:`cuturi:14` eq. 8, and is implemented only - for the squared :class:`~ott.geometry.costs.SqEuclidean` cost. + for the :class:`~ott.geometry.costs.SqEuclidean` cost. 
Args: transports: Transport maps of shape From 3365f247f135bd6d219efab2d3b7ac3ce89860bc Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 18 Nov 2022 12:27:19 +0100 Subject: [PATCH 15/34] Update `solvers` docstrings --- ott/geometry/geometry.py | 4 +-- ott/geometry/low_rank.py | 13 +++---- ott/solvers/linear/sinkhorn.py | 3 +- ott/solvers/linear/sinkhorn_lr.py | 38 ++++++++++---------- ott/solvers/quadratic/gromov_wasserstein.py | 40 ++++++++++----------- ott/solvers/quadratic/gw_barycenter.py | 13 +++---- 6 files changed, 56 insertions(+), 55 deletions(-) diff --git a/ott/geometry/geometry.py b/ott/geometry/geometry.py index f33a8225c..d76282e51 100644 --- a/ott/geometry/geometry.py +++ b/ott/geometry/geometry.py @@ -65,11 +65,11 @@ class Geometry: 'median', 'mean' and 'max_cost'. Alternatively, a float factor can be given to rescale the cost such that ``cost_matrix /= scale_cost``. If `True`, use 'mean'. - tgt_mask: Mask specifying valid rows when computing some statistics of + src_mask: Mask specifying valid rows when computing some statistics of :attr:`cost_matrix`, see :attr:`src_mask`. tgt_mask: Mask specifying valid columns when computing some statistics of :attr:`cost_matrix`, see :attr:`tgt_mask`. - kwargs: additional kwargs to epsilon scheduler. + kwargs: additional kwargs for epsilon scheduler. Note: When defining a ``Geometry`` through a ``cost_matrix``, it is important to diff --git a/ott/geometry/low_rank.py b/ott/geometry/low_rank.py index 1fc748059..eeffa0733 100644 --- a/ott/geometry/low_rank.py +++ b/ott/geometry/low_rank.py @@ -47,10 +47,11 @@ class LRCGeometry(geometry.Geometry): ``cost_matrix /= scale_cost``. If `True`, use 'mean'. batch_size: optional size of the batch to compute online (without instantiating the matrix) the scale factor ``scale_cost`` of the - ``cost_matrix`` when ``scale_cost='max_cost'``. 
If set to ``None``, the - batch size is set to 1024 or to the largest number of samples between - ``cost_1`` and ``cost_2`` if smaller than `1024`. - kwargs: Additional kwargs to :class:`~ott.geometry.geometry.Geometry`. + :attr:`cost_matrix` when ``scale_cost = 'max_cost'``. If `None`, the batch + size is set to `1024` or to the largest number of samples between + :attr:`cost_1` and :attr:`cost_2` if smaller than `1024`. + kwargs: Additional keyword arguments for + :class:`~ott.geometry.geometry.Geometry`. """ def __init__( @@ -189,10 +190,10 @@ def compute_max_cost(self) -> float: Three cases are taken into account: - If the number of samples of ``cost_1`` and ``cost_2`` are both smaller - than 1024 and if ``batch_size`` is ``None``, the ``cost_matrix`` is + than 1024 and if ``batch_size`` is `None`, the ``cost_matrix`` is computed to obtain its maximum entry. - If one of the number of samples of ``cost_1`` or ``cost_2`` is larger - than 1024 and if ``batch_size`` is ``None``, then the maximum of the + than 1024 and if ``batch_size`` is `None`, then the maximum of the cost matrix is calculated by batch. The batches are created on the longest axis of the cost matrix and their size is fixed to 1024. - If ``batch_size`` is provided as a float, then the maximum of the cost diff --git a/ott/solvers/linear/sinkhorn.py b/ott/solvers/linear/sinkhorn.py index 9f824e142..624a2a115 100644 --- a/ott/solvers/linear/sinkhorn.py +++ b/ott/solvers/linear/sinkhorn.py @@ -305,7 +305,8 @@ class Sinkhorn: A Sinkhorn solver takes a linear OT problem object as an input and returns a SinkhornOutput object that contains all the information required to compute - transports. See :func:`~ott.core.sinkhorn.sinkhorn` for a functional wrapper. + transports. See :func:`~ott.solvers.linear.sinkhorn.sinkhorn` + for a functional wrapper. 
Args: lse_mode: ``True`` for log-sum-exp computations, ``False`` for kernel diff --git a/ott/solvers/linear/sinkhorn_lr.py b/ott/solvers/linear/sinkhorn_lr.py index 62364ef14..5f4dbd7ec 100644 --- a/ott/solvers/linear/sinkhorn_lr.py +++ b/ott/solvers/linear/sinkhorn_lr.py @@ -217,38 +217,38 @@ class LRSinkhorn(sinkhorn.Sinkhorn): case. Args: - rank: the rank constraint on the coupling to minimize the linear OT problem - gamma: the (inverse of) gradient step size used by mirror descent. + rank: The rank constraint on the coupling to minimize the linear OT problem + gamma: The (inverse of) gradient step size used by mirror descent. gamma_rescale: Whether to rescale :math:`\gamma` every iteration as described in :cite:`scetbon:22b`. - epsilon: entropic regularization added on top of low-rank problem. + epsilon: Entropic regularization added on top of low-rank problem. initializer: How to initialize the :math:`Q`, :math:`R` and :math:`g` factors. Valid options are: - - `'random'` - :class:`~ott.core.initializers_lr.RandomInitializer`. - - `'rank2'` - :class:`~ott.core.initializers_lr.Rank2Initializer`. - - `'k-means'` - :class:`~ott.core.initializers_lr.KMeansInitializer`. - - `'generalized-k-means'` - - :class:`~ott.core.initializers_lr.GeneralizedKMeansInitializer`. + - `'random'` - :class:`~ott.initializers.linear.initializers_lr.RandomInitializer`. + - `'rank2'` - :class:`~ott.initializers.linear.initializers_lr.Rank2Initializer`. + - `'k-means'` - :class:`~ott.initializers.linear.initializers_lr.KMeansInitializer`. + - `'generalized-k-means'` - :class:`~ott.initializers.linear.initializers_lr.GeneralizedKMeansInitializer`. - If `None`, :class:`~ott.core.initializers_lr.KMeansInitializer` + If `None`, :class:`~ott.initializers.linear.initializers_lr.KMeansInitializer` is used when the linear problem's geometry is :class:`~ott.geometry.pointcloud.PointCloud` or :class:`~ott.geometry.low_rank.LRCGeometry`. 
- Otherwise, use :class:`~ott.core.initializers_lr.RandomInitializer`. + Otherwise, use :class:`~ott.initializers.linear.initializers_lr.RandomInitializer`. - lse_mode: whether to run computations in lse or kernel mode. At the moment, + lse_mode: Whether to run computations in lse or kernel mode. At the moment, only ``lse_mode = True`` is implemented. - inner_iterations: number of inner iterations used by the algorithm before + inner_iterations: Number of inner iterations used by the algorithm before re-evaluating progress. - use_danskin: use Danskin theorem to evaluate gradient of objective w.r.t. + use_danskin: Use Danskin theorem to evaluate gradient of objective w.r.t. input parameters. Only `True` handled at this moment. implicit_diff: Whether to use implicit differentiation. Currently, only ``implicit_diff = False`` is implemented. - kwargs_dys: keyword arguments passed to :meth:`dykstra_update`. - kwargs_init: keyword arguments for - :class:`~ott.core.initializers_lr.LRInitializer`. - kwargs: Keyword arguments for :class:`~ott.core.sinkhorn.Sinkhorn`. + kwargs_dys: Keyword arguments passed to :meth:`dykstra_update`. + kwargs_init: Keyword arguments for + :class:`~ott.initializers.linear.initializers_lr.LRInitializer`. + kwargs: Keyword arguments for + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`. """ def __init__( @@ -268,8 +268,8 @@ def __init__( kwargs_init: Optional[Mapping[str, Any]] = None, **kwargs: Any, ): - assert lse_mode, "Kernel mode not yet implemented for LRSinkhorn." - assert not implicit_diff, "Implicit diff. not yet implemented for LRSink." + assert lse_mode, "Kernel mode not yet implemented." + assert not implicit_diff, "Implicit diff. not yet implemented." 
super().__init__( lse_mode=lse_mode, inner_iterations=inner_iterations, diff --git a/ott/solvers/quadratic/gromov_wasserstein.py b/ott/solvers/quadratic/gromov_wasserstein.py index 05a427a74..3468147e3 100644 --- a/ott/solvers/quadratic/gromov_wasserstein.py +++ b/ott/solvers/quadratic/gromov_wasserstein.py @@ -144,33 +144,33 @@ class GromovWasserstein(was_solver.WassersteinSolver): Args: args: Positional_arguments for - :class:`~ott.core.was_solver.WassersteinSolver`. + :class:`~ott.solvers.was_solver.WassersteinSolver`. warm_start: Whether to initialize (low-rank) Sinkhorn calls using values from the previous iteration. If `None`, warm starts are not used for standard Sinkhorn, but used for low-rank Sinkhorn. quad_initializer: Quadratic initializer. If the solver is entropic, - :class:`~ott.core.quad_initializers.QuadraticInitializer` is always used. - Otherwise, the quadratic initializer wraps low-rank Sinkhorn initializers: + :class:`~ott.initializers.quadratic.initializers.QuadraticInitializer` + is always used. Otherwise, the quadratic initializer wraps the low-rank + Sinkhorn initializers: - - `'random'` - :class:`~ott.core.initializers_lr.RandomInitializer`. - - `'rank2'` - :class:`~ott.core.initializers_lr.Rank2Initializer`. - - `'k-means'` - :class:`~ott.core.initializers_lr.KMeansInitializer`. - - `'generalized-k-means'` - - :class:`~ott.core.initializers_lr.GeneralizedKMeansInitializer`. + - `'random'` - :class:`~ott.initializers.linear.initializers_lr.RandomInitializer`. + - `'rank2'` - :class:`~ott.initializers.linear.initializers_lr.Rank2Initializer`. + - `'k-means'` - :class:`~ott.initializers.linear.initializers_lr.KMeansInitializer`. + - `'generalized-k-means'` - :class:`~ott.initializers.linear.initializers_lr.GeneralizedKMeansInitializer`. 
If `None`, the low-rank initializer will be selected in a problem-specific manner: - - if both :attr:`~ott.core.quad_problems.QuadraticProblem.geom_xx` and - :attr:`~ott.core.quad_problems.QuadraticProblem.geom_yy` are - :class:`~ott.geometry.pointcloud.PointCloud` or - :class:`~ott.geometry.low_rank.LRCGeometry`, - :class:`~ott.core.initializers_lr.KMeansInitializer` is used. - - otherwise, use :class:`~ott.core.initializers_lr.RandomInitializer`. + - if both :attr:`~ott.problems.quadratic.quadratic_problem.QuadraticProblem.geom_xx` + and :attr:`~ott.problems.quadratic.quadratic_problem.QuadraticProblem.geom_yy` + are :class:`~ott.geometry.pointcloud.PointCloud` or :class:`~ott.geometry.low_rank.LRCGeometry`, + :class:`~ott.initializers.linear.initializers_lr.KMeansInitializer` + is used. + - otherwise, use :class:`~ott.initializers.linear.initializers_lr.RandomInitializer`. kwargs_init: Keyword arguments when creating the initializer. kwargs: Keyword arguments for - :class:`~ott.core.was_solver.WassersteinSolver`. + :class:`~ott.solvers.was_solver.WassersteinSolver`. """ def __init__( @@ -493,16 +493,14 @@ def gromov_wasserstein( - if `True`, use the default for each geometry. - if `False`, keep the original scaling in geometries. - if :class:`str`, use a specific method available in - :class:`ott.geometry.geometry.Geometry` or - :class:`ott.geometry.pointcloud.PointCloud`. + :class:`~ott.geometry.geometry.Geometry` or + :class:`~ott.geometry.pointcloud.PointCloud`. - if `None`, do not scale the cost matrices. a: jnp.ndarray[num_a,] or jnp.ndarray[batch,num_a] weights. b: jnp.ndarray[num_b,] or jnp.ndarray[batch,num_b] weights. loss: defaults to the square Euclidean distance. Can also pass 'kl' to define the GW loss as KL loss. - See :class:`~ott.core.gromov_wasserstein.GromovWasserstein` on how to pass - custom loss. 
tau_a: float between 0 and 1.0, parameter that controls the strength of the KL divergence constraint between the weights and marginals of the transport for the first view. If set to 1.0, then it is equivalent to a @@ -528,8 +526,8 @@ def gromov_wasserstein( geometries are *not* :class:`~ott.geometry.pointcloud.PointCloud` with `'sqeucl'` cost. If :class:`float`, that tolerance is shared across all 3 geometries. - kwargs: Keyword arguments to - :class:`~ott.core.gromov_wasserstein.GromovWasserstein`. + kwargs: Keyword arguments for + :class:`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein`. Returns: A GromovWassersteinState named tuple. diff --git a/ott/solvers/quadratic/gw_barycenter.py b/ott/solvers/quadratic/gw_barycenter.py index 46ee6618f..f88e175cd 100644 --- a/ott/solvers/quadratic/gw_barycenter.py +++ b/ott/solvers/quadratic/gw_barycenter.py @@ -15,7 +15,8 @@ class GWBarycenterState(NamedTuple): - """Holds the state of the :class:`~ott.core.bar_problems.GWBarycenterProblem`. + """Holds the state of the \ + :class:`~ott.problems.quadratic.gw_barycenter.GWBarycenterProblem`. Args: c: Barycenter cost matrix of shape ``[bar_size, bar_size]``. @@ -44,7 +45,7 @@ def set(self, **kwargs: Any) -> 'GWBarycenterState': @jax.tree_util.register_pytree_node_class class GromovWassersteinBarycenter(was_solver.WassersteinSolver): """Gromov-Wasserstein barycenter solver of the \ - :class:`~ott.core.bar_problems.GWBarycenterProblem`. + :class:`~ott.problems.quadratic.gw_barycenter.GWBarycenterProblem`. Args: epsilon: Entropy regulariser. @@ -56,7 +57,7 @@ class GromovWassersteinBarycenter(was_solver.WassersteinSolver): as its linear solver, at each iteration for each measure. quad_solver: The GW solver. kwargs: Keyword argument for - :class:`~ott.core.gromov_wasserstein.GromovWasserstein`. + :class:`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein`. Only used when ``quad_solver = None``. 
""" @@ -129,9 +130,9 @@ def init_state( - :class:`jax.numpy.ndarray` - barycenter cost matrix of shape ``[bar_size, bar_size]``. Only used in the non-fused case. - - 2- :class:`tuple` of :class:`jax.numpy.ndarray` - the 1st array - corresponds to ``[bar_size, bar_size]`` cost matrix, - the 2nd array is ``[bar_size, ndim_fused]`` a feature matrix used in + - :class:`tuple` of :class:`jax.numpy.ndarray` - the 1st array + corresponds to a cost matrix of shape ``[bar_size, bar_size]``, + the 2nd array is a ``[bar_size, ndim_fused]`` feature matrix used in the fused case. a: An array of shape ``[bar_size,]`` containing the barycenter weights. From 861bdfa5bf3b53935b7c8f616b46fd96e6255100 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 18 Nov 2022 12:36:29 +0100 Subject: [PATCH 16/34] Update `tools` docstrings --- ott/tools/segment_sinkhorn.py | 29 ++++++++++++++------------- ott/tools/sinkhorn_divergence.py | 34 +++++++++++++++++--------------- 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/ott/tools/segment_sinkhorn.py b/ott/tools/segment_sinkhorn.py index 3b149c0f5..886430d07 100644 --- a/ott/tools/segment_sinkhorn.py +++ b/ott/tools/segment_sinkhorn.py @@ -29,7 +29,7 @@ def segment_sinkhorn( cost_fn: Optional[costs.CostFn] = None, segment_ids_x: Optional[jnp.ndarray] = None, segment_ids_y: Optional[jnp.ndarray] = None, - indices_are_sorted: Optional[bool] = None, + indices_are_sorted: bool = False, num_per_segment_x: Optional[Tuple[int, ...]] = None, num_per_segment_y: Optional[Tuple[int, ...]] = None, weights_x: Optional[jnp.ndarray] = None, @@ -56,22 +56,23 @@ def segment_sinkhorn( parallel. Args: - x: Array of input points, of shape [num_x, feature]. Multiple segments are - held in this single array. - y: Array of target points, of shape [num_y, feature]. - num_segments: Number of segments contained in x and y. 
Providing this number - is required for JIT compilation to work, see also - :func:`~ott.core.segment.segment_point_cloud`. + x: Array of input points, of shape `[num_x, feature]`. + Multiple segments are held in this single array. + y: Array of target points, of shape `[num_y, feature]`. + num_segments: Number of segments contained in `x` and `y`. + Providing this is required for JIT compilation to work, + see also :func:`~ott.geometry.segment.segment_point_cloud`. max_measure_size: Total size of measures after padding. Should ideally be set to an upper bound on points clouds processed with the segment - interface. Providing this number is required for JIT compilation to work. - cost_fn: Cost function, defaults to :class:`~ott.core.costs.SqEuclidean`. - segment_ids_x: **1st interface** The segment ID for which each row of x + interface. Providing this is required for JIT compilation to work. + cost_fn: Cost function, defaults to + :class:`~ott.geometry.costs.SqEuclidean`. + segment_ids_x: **1st interface** The segment ID for which each row of `x` belongs. This is a similar interface to `jax.ops.segment_sum`. - segment_ids_y: **1st interface** The segment ID for which each row of y + segment_ids_y: **1st interface** The segment ID for which each row of `y` belongs. indices_are_sorted: **1st interface** Whether `segment_ids_x` and - `segment_ids_y` are sorted. Default false. + `segment_ids_y` are sorted. num_per_segment_x: **2nd interface** Number of points in each segment in `x`. For example, [100, 20, 30] would imply that `x` is segmented into three arrays of length `[100]`, `[20]`, and `[30]` respectively. @@ -87,9 +88,9 @@ def segment_sinkhorn( `y`/`y` (except when `static_b` is `True`, in which case `y`/`y` is not evaluated). 
kwargs: keywords arguments passed to form - :class:`ott.geometry.pointcloud.PointCloud` geometry objects from the + :class:`~ott.geometry.pointcloud.PointCloud` geometry objects from the subsets of points and masses selected in `x` and `y`, possibly a - :class:`ott.geometry.costs.CostFn` or an entropy regularizer. + :class:`~ott.geometry.costs.CostFn` or an entropy regularizer. Returns: An array of sinkhorn reg_ot_cost for each segment. diff --git a/ott/tools/sinkhorn_divergence.py b/ott/tools/sinkhorn_divergence.py index baf5688c3..28c32fa82 100644 --- a/ott/tools/sinkhorn_divergence.py +++ b/ott/tools/sinkhorn_divergence.py @@ -70,8 +70,9 @@ def sinkhorn_divergence( match that of `b` to converge. b: the weight of each target point. The sum of all elements of `b` must match that of `a` to converge. - sinkhorn_kwargs: keywords arguments for :func:`~ott.core.sinkhorn.sinkhorn` - that is called twice if ``static_b = True`` else 3 times. + sinkhorn_kwargs: keyword arguments for + :func:`~ott.solvers.linear.sinkhorn.sinkhorn` that is called twice + if ``static_b = True`` else 3 times. static_b: if True, divergence of measure `b` against itself is **not** computed. share_epsilon: if True, enforces that the same epsilon regularizer is shared @@ -138,7 +139,7 @@ def _sinkhorn_divergence( all elements of b must match that of a to converge. symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for symmetric terms comparing x/x and y/y. - kwargs: Keyword arguments to :func:`ott.core.sinkhorn.sinkhorn`. + kwargs: Keyword arguments to :func:`~ott.solvers.linear.sinkhorn.sinkhorn`. Returns: SinkhornDivergenceOutput named tuple. 
@@ -190,7 +191,7 @@ def segment_sinkhorn_divergence( cost_fn: Optional[costs.CostFn] = None, segment_ids_x: Optional[jnp.ndarray] = None, segment_ids_y: Optional[jnp.ndarray] = None, - indices_are_sorted: Optional[bool] = None, + indices_are_sorted: bool = False, num_per_segment_x: Optional[Tuple[int, ...]] = None, num_per_segment_y: Optional[Tuple[int, ...]] = None, weights_x: Optional[jnp.ndarray] = None, @@ -219,23 +220,24 @@ def segment_sinkhorn_divergence( a tensor, and `vmap` used to evaluate sinkhorn divergences in parallel. Args: - x: Array of input points, of shape [num_x, feature]. Multiple segments are - held in this single array. - y: Array of target points, of shape [num_y, feature]. - num_segments: Number of segments contained in x and y. Providing this number - is required for JIT compilation to work, see also - :func:`~ott.core.segment.segment_point_cloud`. + x: Array of input points, of shape `[num_x, feature]`. + Multiple segments are held in this single array. + y: Array of target points, of shape `[num_y, feature]`. + num_segments: Number of segments contained in `x` and `y`. + Providing this is required for JIT compilation to work, + see also :func:`~ott.geometry.segment.segment_point_cloud`. max_measure_size: Total size of measures after padding. Should ideally be set to an upper bound on points clouds processed with the segment interface. Should also be smaller than total length of `x` or `y`. - Providing this number is required for JIT compilation to work. - cost_fn: Cost function, defaults to :class:`~ott.core.costs.SqEuclidean`. - segment_ids_x: **1st interface** The segment ID for which each row of x + Providing this is required for JIT compilation to work. + cost_fn: Cost function, + defaults to :class:`~ott.geometry.costs.SqEuclidean`. + segment_ids_x: **1st interface** The segment ID for which each row of `x` belongs. This is a similar interface to :func:`jax.ops.segment_sum`. 
- segment_ids_y: **1st interface** The segment ID for which each row of y + segment_ids_y: **1st interface** The segment ID for which each row of `y` belongs. indices_are_sorted: **1st interface** Whether `segment_ids_x` and - `segment_ids_y` are sorted. Default false. + `segment_ids_y` are sorted. num_per_segment_x: **2nd interface** Number of points in each segment in `x`. For example, [100, 20, 30] would imply that `x` is segmented into three arrays of length `[100]`, `[20]`, and `[30]` respectively. @@ -260,7 +262,7 @@ def segment_sinkhorn_divergence( symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for symmetric terms comparing x/x and y/y. kwargs: keywords arguments passed to form - :class:`ott.geometry.pointcloud.PointCloud` geometry objects from the + :class:`~ott.geometry.pointcloud.PointCloud` geometry objects from the subsets of points and masses selected in `x` and `y`, this could be for instance entropy regularization float, scheduler or normalization. 
Returns: From 4faeca54f585d71193a056a8203ffa7e4cb2f582 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 18 Nov 2022 12:47:25 +0100 Subject: [PATCH 17/34] Remove remaining `core` mentions from docstrings --- ott/geometry/graph.py | 2 +- ott/initializers/linear/initializers_lr.py | 2 +- ott/problems/linear/barycenter_problem.py | 2 +- ott/solvers/linear/continuous_barycenter.py | 2 +- ott/solvers/linear/sinkhorn.py | 4 ++-- ott/solvers/linear/sinkhorn_lr.py | 6 +++--- ott/tools/transport.py | 9 ++++----- 7 files changed, 13 insertions(+), 14 deletions(-) diff --git a/ott/geometry/graph.py b/ott/geometry/graph.py index 8cd6c05a7..dd361f67a 100644 --- a/ott/geometry/graph.py +++ b/ott/geometry/graph.py @@ -192,7 +192,7 @@ def laplacian(self) -> Union[jnp.ndarray, Sparse_t]: # in the sparse case, we don't sum duplicates here because # we need to know `nnz` a priori for JIT (could be exposed in `__init__`) - # instead, `ott.core.decomposition._jax_sparse_to_scipy` handles it on host + # instead, `ott.math.decomposition._jax_sparse_to_scipy` handles it on host return D - self.graph @property diff --git a/ott/initializers/linear/initializers_lr.py b/ott/initializers/linear/initializers_lr.py index 3eb4670ee..8e60ba51a 100644 --- a/ott/initializers/linear/initializers_lr.py +++ b/ott/initializers/linear/initializers_lr.py @@ -531,7 +531,7 @@ def init_fn() -> GeneralizedKMeansInitializer.State: crossed_threshold=False ) - # see the explanation in `ott.core.sinkhorn_lr` + # see the explanation in `ott.solvers.linear.sinkhorn_lr` def converged( state: GeneralizedKMeansInitializer.State, consts: GeneralizedKMeansInitializer.Constants, iteration: int diff --git a/ott/problems/linear/barycenter_problem.py b/ott/problems/linear/barycenter_problem.py index ac30a2b01..c2f28860f 100644 --- a/ott/problems/linear/barycenter_problem.py +++ b/ott/problems/linear/barycenter_problem.py @@ -46,7 +46,7 @@ class BarycenterProblem: the 
regularized transportation cost of barycenter to itself will be considered when computing gradient. Note that if the debiased option is used, the barycenter size in - :meth:`~ott.core.continuous_barycenter.WassersteinBarycenter.init_state` + :meth:`~ott.solvers.linear.continuous_barycenter.WassersteinBarycenter.init_state` needs to be smaller than the maximum measure size for parallelization to operate efficiently. kwargs: Keyword arguments :func:`~ott.geometry.segment.segment_point_cloud`. diff --git a/ott/solvers/linear/continuous_barycenter.py b/ott/solvers/linear/continuous_barycenter.py index a1534647c..715a72384 100644 --- a/ott/solvers/linear/continuous_barycenter.py +++ b/ott/solvers/linear/continuous_barycenter.py @@ -149,7 +149,7 @@ def init_state( x_init: Initial barycenter estimate of shape ``[bar_size, ndim]``. If `None`, ``bar_size`` points will be sampled from the input measures according to their weights - :attr:`~ott.core.bar_problems.BarycenterProblem.flattened_y`. + :attr:`~ott.problems.linear.barycenter_problem.BarycenterProblem.flattened_y`. rng: Seed for :func:`jax.random.PRNGKey`. Returns: diff --git a/ott/solvers/linear/sinkhorn.py b/ott/solvers/linear/sinkhorn.py index 624a2a115..bfc86c4a5 100644 --- a/ott/solvers/linear/sinkhorn.py +++ b/ott/solvers/linear/sinkhorn.py @@ -327,8 +327,8 @@ class Sinkhorn: unroll-able :func:`jax.lax.while_loop` that monitors convergence. In that case the error is not monitored and the ``converged`` flag will return ``False`` as a consequence. - momentum: a Momentum instance. See ott.core.momentum - anderson: an AndersonAcceleration instance. See ott.core.anderson. + momentum: Momentum instance. + anderson: AndersonAcceleration instance. implicit_diff: instance used to solve implicit differentiation. Unrolls iterations if None. 
parallel_dual_updates: updates potentials or scalings in parallel if True, diff --git a/ott/solvers/linear/sinkhorn_lr.py b/ott/solvers/linear/sinkhorn_lr.py index 5f4dbd7ec..13f014818 100644 --- a/ott/solvers/linear/sinkhorn_lr.py +++ b/ott/solvers/linear/sinkhorn_lr.py @@ -300,9 +300,9 @@ def __call__( ot_prob: Linear OT problem. init: Initial values for the low-rank factors: - - :attr:`~ott.core.sinkhorn_lr.LRSinkhornOutput.q`. - - :attr:`~ott.core.sinkhorn_lr.LRSinkhornOutput.r`. - - :attr:`~ott.core.sinkhorn_lr.LRSinkhornOutput.g`. + - :attr:`~ott.solvers.linear.sinkhorn_lr.LRSinkhornOutput.q`. + - :attr:`~ott.solvers.linear.sinkhorn_lr.LRSinkhornOutput.r`. + - :attr:`~ott.solvers.linear.sinkhorn_lr.LRSinkhornOutput.g`. Any `None` values will be initialized using the initializer. key: Random key for seeding. diff --git a/ott/tools/transport.py b/ott/tools/transport.py index cc736d859..1f25a8d8f 100644 --- a/ott/tools/transport.py +++ b/ott/tools/transport.py @@ -21,10 +21,9 @@ >>> ot = ott.transport.solve(x, y) >>> Tz = ot.apply(z) -Even if the transport.solve sole function can support many complex use cases, we -suggest more advanced users to instantiate directly their problem (see -ott.core.problems) and their solvers (see ott.core.sinkhorn and -ott.core.gromov_wasserstein) for better control over the parameters. +Even if the `transport.solve` sole function can support many complex use cases, +we suggest more advanced users to instantiate directly their :mod:`ott.problems` +and their :mod:`ott.solvers` for better control over the parameters. 
""" from typing import Any, NamedTuple, Optional, Union @@ -43,7 +42,7 @@ class Transport(NamedTuple): - """Implement a core.problems.Transport interface to transport solutions.""" + """Transport interface to transport solutions.""" problem: Any = None solver_output: Any = None From ba29e10ee7060b6afd5f645455430ce0d6d67be1 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 18 Nov 2022 14:58:14 +0100 Subject: [PATCH 18/34] Start updating documentation --- .gitignore | 2 +- docs/index.rst | 36 ++---------------------------------- docs/problems/index.rst | 13 +++++++++++++ docs/problems/linear.rst | 20 ++++++++++++++++++++ docs/problems/quadratic.rst | 20 ++++++++++++++++++++ docs/solvers/index.rst | 14 ++++++++++++++ docs/solvers/linear.rst | 32 ++++++++++++++++++++++++++++++++ docs/solvers/nn.rst | 2 ++ docs/solvers/quadratic.rst | 2 ++ 9 files changed, 106 insertions(+), 35 deletions(-) create mode 100644 docs/problems/index.rst create mode 100644 docs/problems/linear.rst create mode 100644 docs/problems/quadratic.rst create mode 100644 docs/solvers/index.rst create mode 100644 docs/solvers/linear.rst create mode 100644 docs/solvers/nn.rst create mode 100644 docs/solvers/quadratic.rst diff --git a/.gitignore b/.gitignore index e289be07d..d8f1864bd 100644 --- a/.gitignore +++ b/.gitignore @@ -161,7 +161,7 @@ cython_debug/ # generated documentation docs/html -docs/_autosummary +**/_autosummary # macos **/.DS_Store diff --git a/docs/index.rst b/docs/index.rst index 52af7189f..35771b557 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -57,46 +57,14 @@ There are currently three packages, ``geometry``, ``core`` and ``tools``, playin between GMMs, or computing differentiable sort and quantile operations :cite:`cuturi:19`. -.. toctree:: - :maxdepth: 1 - :caption: Tutorials: - - notebooks/point_clouds.ipynb - notebooks/introduction_grid.ipynb - -.. 
toctree:: - :maxdepth: 1 - :caption: Benchmarks: - - notebooks/OTT_&_POT.ipynb - notebooks/One_Sinkhorn.ipynb - notebooks/LRSinkhorn.ipynb - -.. toctree:: - :maxdepth: 1 - :caption: Advanced Applications: - - notebooks/Sinkhorn_Barycenters.ipynb - notebooks/gromov_wasserstein.ipynb - notebooks/GWLRSinkhorn.ipynb - notebooks/Hessians.ipynb - notebooks/soft_sort.ipynb - notebooks/application_biology.ipynb - notebooks/gromov_wasserstein_multiomics.ipynb - notebooks/fairness.ipynb - notebooks/neural_dual.ipynb - notebooks/icnn_inits.ipynb - notebooks/wasserstein_barycenters_gmms.ipynb - notebooks/gmm_pair_demo.ipynb - notebooks/MetaOT.ipynb .. toctree:: :maxdepth: 1 :caption: Public API: ott packages geometry - core - tools + problems/index + solvers/index .. toctree:: :maxdepth: 1 diff --git a/docs/problems/index.rst b/docs/problems/index.rst new file mode 100644 index 000000000..89a0b44db --- /dev/null +++ b/docs/problems/index.rst @@ -0,0 +1,13 @@ +ott.problems package +==================== + +TODO(cuturi): add some nice text here please. + +.. currentmodule:: ott.problems +.. automodule:: ott.problems + +.. toctree:: + :maxdepth: 2 + + linear + quadratic diff --git a/docs/problems/linear.rst b/docs/problems/linear.rst new file mode 100644 index 000000000..4f4ce5cb2 --- /dev/null +++ b/docs/problems/linear.rst @@ -0,0 +1,20 @@ +ott.problems.linear package +=========================== +.. currentmodule:: ott.problems.linear +.. automodule:: ott.problems.linear + +OT Problems +----------- +.. autosummary:: + :toctree: _autosummary + + linear_problem.LinearProblem + barycenter_problem.BarycenterProblem + +Dual Potentials +--------------- +.. 
autosummary:: + :toctree: _autosummary + + potentials.DualPotentials + potentials.EntropicPotentials diff --git a/docs/problems/quadratic.rst b/docs/problems/quadratic.rst new file mode 100644 index 000000000..900081871 --- /dev/null +++ b/docs/problems/quadratic.rst @@ -0,0 +1,20 @@ +ott.problems.quadratic package +============================== +.. currentmodule:: ott.problems.quadratic +.. automodule:: ott.problems.quadratic + +OT Problems +----------- +.. autosummary:: + :toctree: _autosummary + + quadratic_problem.QuadraticProblem + gw_barycenter.GWBarycenterProblem + +Costs +----- +.. autosummary:: + :toctree: _autosummary + + quadratic_costs.make_square_loss + quadratic_costs.make_kl_loss diff --git a/docs/solvers/index.rst b/docs/solvers/index.rst new file mode 100644 index 000000000..cf054d858 --- /dev/null +++ b/docs/solvers/index.rst @@ -0,0 +1,14 @@ +ott.solvers package +=================== + +TODO(cuturi): add some nice text here please. + +.. currentmodule:: ott.solvers +.. automodule:: ott.solvers + +.. toctree:: + :maxdepth: 2 + + linear + quadratic + neural diff --git a/docs/solvers/linear.rst b/docs/solvers/linear.rst new file mode 100644 index 000000000..16ddd20fe --- /dev/null +++ b/docs/solvers/linear.rst @@ -0,0 +1,32 @@ +ott.solvers.linear package +========================== +.. currentmodule:: ott.solvers.linear +.. automodule:: ott.solvers.linear + +Sinkhorn Solvers +---------------- +.. autosummary:: + :toctree: _autosummary + + sinkhorn.Sinkhorn + sinkhorn.SinkhornOutput + sinkhorn_lr.LRSinkhorn + sinkhorn_lr.LRSinkhornOutput + +Barycenter Solvers +------------------ +.. autosummary:: + :toctree: _autosummary + + continuous_barycenter.WassersteinBarycenter + continuous_barycenter.BarycenterState + discrete_barycenter.discrete_barycenter + discrete_barycenter.SinkhornBarycenterOutput + +Sinkhorn Acceleration +--------------------- +.. 
autosummary:: + :toctree: _autosummary + + acceleration.Momentum + acceleration.AndersonAcceleration diff --git a/docs/solvers/nn.rst b/docs/solvers/nn.rst new file mode 100644 index 000000000..887929c9f --- /dev/null +++ b/docs/solvers/nn.rst @@ -0,0 +1,2 @@ +ott.solvers.nn package +====================== diff --git a/docs/solvers/quadratic.rst b/docs/solvers/quadratic.rst new file mode 100644 index 000000000..c1e85ca70 --- /dev/null +++ b/docs/solvers/quadratic.rst @@ -0,0 +1,2 @@ +ott.solvers.quadratic package +============================= From 2bd6d8b41aae15262e29aaecab53e436d4eb5ead Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 21 Nov 2022 10:28:19 +0100 Subject: [PATCH 19/34] Fix typing --- ott/geometry/costs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py index 5b11005d3..4521b89e7 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -320,14 +320,14 @@ def scale_covariances( (cov_sqrt @ cov_i) @ cov_sqrt ) - def cond_fn(iteration: int, constants: Tuple[...], state) -> bool: + def cond_fn(iteration: int, constants: Tuple[Any, ...], state) -> bool: del iteration, constants _, diff = state return diff > rtol def body_fn( - iteration: int, constants: Tuple[...], state: Tuple[jnp.ndarray, float], - compute_error: bool + iteration: int, constants: Tuple[Any, ...], + state: Tuple[jnp.ndarray, float], compute_error: bool ) -> Tuple[jnp.ndarray, float]: del iteration, constants, compute_error cov, _ = state From d54074aa09c549da36f64970e123432bc6d307af Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 21 Nov 2022 10:31:31 +0100 Subject: [PATCH 20/34] Update solvers docs --- docs/solvers/index.rst | 2 +- docs/solvers/nn.rst | 17 +++++++++++++++++ docs/solvers/quadratic.rst | 19 +++++++++++++++++++ 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/docs/solvers/index.rst 
b/docs/solvers/index.rst index cf054d858..9954a585c 100644 --- a/docs/solvers/index.rst +++ b/docs/solvers/index.rst @@ -11,4 +11,4 @@ TODO(cuturi): add some nice text here please. linear quadratic - neural + nn diff --git a/docs/solvers/nn.rst b/docs/solvers/nn.rst index 887929c9f..cd28c655f 100644 --- a/docs/solvers/nn.rst +++ b/docs/solvers/nn.rst @@ -1,2 +1,19 @@ ott.solvers.nn package ====================== +.. currentmodule:: ott.solvers.nn +.. automodule:: ott.solvers.nn + +Neural Dual +----------- +.. autosummary:: + :toctree: _autosummary + + neuraldual.NeuralDualSolver + +ICNN +---- +.. autosummary:: + :toctree: _autosummary + + icnn.ICNN + layers.PositiveDense diff --git a/docs/solvers/quadratic.rst b/docs/solvers/quadratic.rst index c1e85ca70..5ac42cabd 100644 --- a/docs/solvers/quadratic.rst +++ b/docs/solvers/quadratic.rst @@ -1,2 +1,21 @@ ott.solvers.quadratic package ============================= +.. currentmodule:: ott.solvers.quadratic +.. automodule:: ott.solvers.quadratic + +Gromov Wasserstein Solvers +-------------------------- +.. autosummary:: + :toctree: _autosummary + + gromov_wasserstein.GromovWasserstein + gromov_wasserstein.GWOutput + gromov_wasserstein.gromov_wasserstein + +Barycenter Solvers +------------------ +.. 
autosummary:: + :toctree: _autosummary + + gw_barycenter.GWBarycenterState + gw_barycenter.GromovWassersteinBarycenter From 6a1521345823be1985b44939f5da20b335203da4 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 21 Nov 2022 11:00:31 +0100 Subject: [PATCH 21/34] Add initializers --- docs/Makefile | 1 + docs/index.rst | 1 + docs/initializers/index.rst | 14 ++++++++++++++ docs/initializers/linear.rst | 23 +++++++++++++++++++++++ docs/initializers/nn.rst | 12 ++++++++++++ docs/initializers/quadratic.rst | 12 ++++++++++++ docs/solvers/quadratic.rst | 2 +- ott/geometry/costs.py | 2 +- 8 files changed, 65 insertions(+), 2 deletions(-) create mode 100644 docs/initializers/index.rst create mode 100644 docs/initializers/linear.rst create mode 100644 docs/initializers/nn.rst create mode 100644 docs/initializers/quadratic.rst diff --git a/docs/Makefile b/docs/Makefile index 3db4deda9..2dab86e59 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -22,3 +22,4 @@ help: clean: @rm -rf $(BUILDDIR)/ @rm -rf $(SOURCEDIR)/_autosummary + @rm -rf $(SOURCEDIR)/**/_autosummary diff --git a/docs/index.rst b/docs/index.rst index 35771b557..7a2bebe6b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -65,6 +65,7 @@ There are currently three packages, ``geometry``, ``core`` and ``tools``, playin geometry problems/index solvers/index + initializers/index .. toctree:: :maxdepth: 1 diff --git a/docs/initializers/index.rst b/docs/initializers/index.rst new file mode 100644 index 000000000..a2f4162f3 --- /dev/null +++ b/docs/initializers/index.rst @@ -0,0 +1,14 @@ +ott.initializers package +======================== + +TODO(cuturi): add some nice text here please. + +.. currentmodule:: ott.initializers +.. automodule:: ott.initializers + +.. 
toctree:: + :maxdepth: 2 + + linear + quadratic + nn diff --git a/docs/initializers/linear.rst b/docs/initializers/linear.rst new file mode 100644 index 000000000..8fb5d89ea --- /dev/null +++ b/docs/initializers/linear.rst @@ -0,0 +1,23 @@ +ott.initializers.linear package +=============================== +.. currentmodule:: ott.initializers.linear +.. automodule:: ott.initializers.linear + +Sinkhorn Initializers +--------------------- +.. autosummary:: + :toctree: _autosummary + + initializers.DefaultInitializer + initializers.GaussianInitializer + initializers.SinkhornInitializer + +Low-rank Sinkhorn Initializers +------------------------------ +.. autosummary:: + :toctree: _autosummary + + initializers_lr.RandomInitializer + initializers_lr.Rank2Initializer + initializers_lr.KMeansInitializer + initializers_lr.GeneralizedKMeansInitializer diff --git a/docs/initializers/nn.rst b/docs/initializers/nn.rst new file mode 100644 index 000000000..6f439f6a8 --- /dev/null +++ b/docs/initializers/nn.rst @@ -0,0 +1,12 @@ +ott.initializers.nn package +=========================== +.. currentmodule:: ott.initializers.nn +.. automodule:: ott.initializers.nn + +Neural Initializers +------------------- +.. autosummary:: + :toctree: _autosummary + + initializers.MetaInitializer + initializers.MetaMLP diff --git a/docs/initializers/quadratic.rst b/docs/initializers/quadratic.rst new file mode 100644 index 000000000..d3ea718a9 --- /dev/null +++ b/docs/initializers/quadratic.rst @@ -0,0 +1,12 @@ +ott.initializers.quadratic package +================================== +.. currentmodule:: ott.initializers.quadratic +.. automodule:: ott.initializers.quadratic + +Gromov-Wasserstein Initializers +------------------------------- +.. 
autosummary:: + :toctree: _autosummary + + initializers.QuadraticInitializer + initializers.LRQuadraticInitializer diff --git a/docs/solvers/quadratic.rst b/docs/solvers/quadratic.rst index 5ac42cabd..33b0a9014 100644 --- a/docs/solvers/quadratic.rst +++ b/docs/solvers/quadratic.rst @@ -3,7 +3,7 @@ ott.solvers.quadratic package .. currentmodule:: ott.solvers.quadratic .. automodule:: ott.solvers.quadratic -Gromov Wasserstein Solvers +Gromov-Wasserstein Solvers -------------------------- .. autosummary:: :toctree: _autosummary diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py index 4521b89e7..fdf52e7de 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -333,7 +333,7 @@ def body_fn( cov, _ = state cov_sqrt, cov_inv_sqrt, _ = matrix_square_root.sqrtm(cov) scaled_cov = jnp.linalg.matrix_power( - jnp.sum(self.scale_covariances(cov_sqrt, covs, lambdas), axis=0), 2 + jnp.sum(scale_covariances(cov_sqrt, covs, lambdas), axis=0), 2 ) next_cov = (cov_inv_sqrt @ scaled_cov) @ cov_inv_sqrt diff = jnp.sum((next_cov - cov) ** 2) / jnp.prod(jnp.array(cov.shape)) From 987ad2595575d544d8d48d9009b06a448f311014 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 21 Nov 2022 12:31:02 +0100 Subject: [PATCH 22/34] Update docs --- docs/conf.py | 3 + docs/core.rst | 115 -------------------- docs/geometry.rst | 7 ++ docs/index.rst | 49 +++++++-- docs/initializers/index.rst | 2 + docs/math.rst | 35 ++++++ docs/problems/index.rst | 2 + docs/references.bib | 3 +- docs/solvers/index.rst | 9 ++ docs/solvers/linear.rst | 1 + docs/tools.rst | 1 - ott/_version.py | 6 +- ott/initializers/nn/initializers.py | 28 ++--- ott/math/decomposition.py | 2 +- ott/math/implicit_differentiation.py | 28 ++--- ott/math/matrix_square_root.py | 2 +- ott/math/utils.py | 7 +- ott/problems/quadratic/quadratic_problem.py | 2 +- ott/solvers/quadratic/gromov_wasserstein.py | 2 +- ott/{typing.py => types.py} | 0 20 files changed, 142 insertions(+), 
162 deletions(-) delete mode 100644 docs/core.rst create mode 100644 docs/math.rst rename ott/{typing.py => types.py} (100%) diff --git a/docs/conf.py b/docs/conf.py index 6b8dbd0bc..c53993a05 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -73,6 +73,9 @@ source_suffix = ['.rst'] autosummary_generate = True +autosummary_filename_map = { + "ott.solvers.linear.sinkhorn.sinkhorn": "sinkhorn-function" +} autodoc_typehints = 'description' diff --git a/docs/core.rst b/docs/core.rst deleted file mode 100644 index 0c62c88e5..000000000 --- a/docs/core.rst +++ /dev/null @@ -1,115 +0,0 @@ -.. _core: - -ott.core package -================ -.. currentmodule:: ott.core -.. automodule:: ott.core - -The core package contains definitions of various OT problems, starting -from the most simple, the linear OT problem, to more advanced problems -such as quadratic, or involving multiple measures, the barycenter problem. -We follow with the classic :class:`~ott.core.sinkhorn.sinkhorn` routine (essentially a -wrapper for the :class:`~ott.core.sinkhorn.Sinkhorn` solver class) :cite:`cuturi:13,sejourne:19`. -We also provide an analogous low-rank Sinkhorn solver :cite:`scetbon:21` to handle very large instances. -Both are used within our Wasserstein barycenter solvers :cite:`benamou:15,janati:20a`, as well as our -Gromov-Wasserstein solver :cite:`memoli:11,scetbon:22`. We also provide an implementation of -input convex neural networks :cite:`amos:17`, a NN that can be used to estimate OT :cite:`makkuva:20`. - -OT Problems ------------ -.. autosummary:: - :toctree: _autosummary - - linear_problems.LinearProblem - quad_problems.QuadraticProblem - bar_problems.BarycenterProblem - bar_problems.GWBarycenterProblem - -Sinkhorn --------- -.. autosummary:: - :toctree: _autosummary - - sinkhorn.sinkhorn - sinkhorn.Sinkhorn - sinkhorn.SinkhornOutput - -Sinkhorn Dual Initializers --------------------------- -.. 
autosummary:: - :toctree: _autosummary - - initializers.DefaultInitializer - initializers.GaussianInitializer - initializers.SortingInitializer - initializers.MetaInitializer - initializers.MetaMLP - -Low-Rank Sinkhorn ------------------ -.. autosummary:: - :toctree: _autosummary - - sinkhorn_lr.LRSinkhorn - sinkhorn_lr.LRSinkhornOutput - -Low-Rank Sinkhorn Initializers ------------------------------- -.. autosummary:: - :toctree: _autosummary - - initializers_lr.RandomInitializer - initializers_lr.Rank2Initializer - initializers_lr.KMeansInitializer - initializers_lr.GeneralizedKMeansInitializer - -Quadratic Initializers ----------------------- -.. autosummary:: - :toctree: _autosummary - - quad_initializers.QuadraticInitializer - quad_initializers.LRQuadraticInitializer - -Barycenters (Entropic and LR) ------------------------------ -.. autosummary:: - :toctree: _autosummary - - discrete_barycenter.discrete_barycenter - continuous_barycenter.WassersteinBarycenter - continuous_barycenter.BarycenterState - gw_barycenter.GromovWassersteinBarycenter - gw_barycenter.GWBarycenterState - -Gromov-Wasserstein (Entropic and LR) ------------------------------------- -.. autosummary:: - :toctree: _autosummary - - gromov_wasserstein.gromov_wasserstein - gromov_wasserstein.GromovWasserstein - gromov_wasserstein.GWOutput - -Dual Potentials ---------------- -.. autosummary:: - :toctree: _autosummary - - potentials.DualPotentials - potentials.EntropicPotentials - -Neural Dual Potentials ----------------------- -.. autosummary:: - :toctree: _autosummary - - icnn.ICNN - neuraldual.NeuralDualSolver - -Padding Utilities ------------------ -.. autosummary:: - :toctree: _autosummary - - segment.segment_point_cloud diff --git a/docs/geometry.rst b/docs/geometry.rst index edabf6b3b..c857dd8a7 100644 --- a/docs/geometry.rst +++ b/docs/geometry.rst @@ -58,3 +58,10 @@ Cost Functions costs.Cosine costs.Bures costs.UnbalancedBures + +Utilities +--------- +.. 
autosummary:: + :toctree: _autosummary + + segment.segment_point_cloud diff --git a/docs/index.rst b/docs/index.rst index 7a2bebe6b..24a376771 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,11 +17,10 @@ The first family consists in *discrete* solvers computing transport between poin using the Sinkhorn :cite:`cuturi:13` and low-rank Sinkhorn :cite:`scetbon:21` algorithms, and moving up towards Gromov-Wasserstein :cite:`memoli:11`, :cite:`memoli:11`; the second family consists in *continuous* solvers, using suitable neural architectures :cite:`amos:17` coupled -with SGD type estimators :cite:`makkuva:20`, :cite:`korotin:21`. +with SGD type estimators :cite:`makkuva:20,korotin:21`. Design Choices -------------- - `OTT` is designed with the following choices: - Take advantage whenever possible of JAX features, such as `Just-in-time (JIT) compilation`_, @@ -44,19 +43,51 @@ Design Choices Packages -------- -There are currently three packages, ``geometry``, ``core`` and ``tools``, playing the following roles: - - :ref:`geometry` contains classes to instantiate objects that describe *two point clouds* paired with a *cost* function. Geometry objects are used to - describe OT problems, handled by solvers in ``core``. -- :ref:`core` classes describe OT problems (linear, quadratic, barycenters), and - solver classes, to instantiate algorithms that will output an OT. + describe OT problems, handled by solvers in the :ref:`solvers`. +- :ref:`problems` TODO(marcocuturi) +- :ref:`solvers` TODO(marcocuturi) +- :ref:`initializers` TODO(marcocuturi) - :ref:`tools` provides an interface to exploit OT solutions, as produced by - solvers in the ``core`` package. Such tasks include computing approximations + solvers in the :ref:`solvers`. Such tasks include computing approximations to Wasserstein distances :cite:`genevay:18,sejourne:19`, approximating OT between GMMs, or computing differentiable sort and quantile operations :cite:`cuturi:19`. 
+- :ref:`math` TODO(marcocuturi) + +.. toctree:: + :maxdepth: 1 + :caption: Tutorials: + + notebooks/point_clouds.ipynb + notebooks/introduction_grid.ipynb + +.. toctree:: + :maxdepth: 1 + :caption: Benchmarks: + + notebooks/OTT_&_POT.ipynb + notebooks/One_Sinkhorn.ipynb + notebooks/LRSinkhorn.ipynb + +.. toctree:: + :maxdepth: 1 + :caption: Advanced Applications: + notebooks/Sinkhorn_Barycenters.ipynb + notebooks/gromov_wasserstein.ipynb + notebooks/GWLRSinkhorn.ipynb + notebooks/Hessians.ipynb + notebooks/soft_sort.ipynb + notebooks/application_biology.ipynb + notebooks/gromov_wasserstein_multiomics.ipynb + notebooks/fairness.ipynb + notebooks/neural_dual.ipynb + notebooks/icnn_inits.ipynb + notebooks/wasserstein_barycenters_gmms.ipynb + notebooks/gmm_pair_demo.ipynb + notebooks/MetaOT.ipynb .. toctree:: :maxdepth: 1 @@ -66,6 +97,8 @@ There are currently three packages, ``geometry``, ``core`` and ``tools``, playin problems/index solvers/index initializers/index + tools + math .. toctree:: :maxdepth: 1 diff --git a/docs/initializers/index.rst b/docs/initializers/index.rst index a2f4162f3..4baba57f5 100644 --- a/docs/initializers/index.rst +++ b/docs/initializers/index.rst @@ -1,3 +1,5 @@ +.. _initializers: + ott.initializers package ======================== diff --git a/docs/math.rst b/docs/math.rst new file mode 100644 index 000000000..de58a29e2 --- /dev/null +++ b/docs/math.rst @@ -0,0 +1,35 @@ +.. _math: + +ott.math package +================ +.. currentmodule:: ott.math +.. automodule:: ott.math + +Implicit Differentiation +------------------------ +.. autosummary:: + :toctree: _autosummary + + implicit_differentiation.ImplicitDiff + +Fixed-point Iteration +--------------------- +.. autosummary:: + :toctree: _autosummary + + fixed_point_loop.fixpoint_iter + +Cholesky Decomposition +---------------------- +.. 
autosummary:: + :toctree: _autosummary + + decomposition.DenseCholeskySolver + decomposition.SparseCholeskySolver + +Matrix Square Root +------------------ +.. autosummary:: + :toctree: _autosummary + + matrix_square_root.sqrtm diff --git a/docs/problems/index.rst b/docs/problems/index.rst index 89a0b44db..462411f7d 100644 --- a/docs/problems/index.rst +++ b/docs/problems/index.rst @@ -1,3 +1,5 @@ +.. _problems: + ott.problems package ==================== diff --git a/docs/references.bib b/docs/references.bib index 74abe1adb..882822f55 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -29,7 +29,6 @@ @InProceedings{peyre:16 url = {https://proceedings.mlr.press/v48/peyre16.html}, } - @InProceedings{feydy:19, title = {Interpolating between Optimal Transport and MMD using Sinkhorn Divergences}, author = {Feydy, Jean and S\'{e}journ\'{e}, Thibault and Vialard, Fran\c{c}ois-Xavier and Amari, Shun-ichi and Trouve, Alain and Peyr\'{e}, Gabriel}, @@ -492,7 +491,7 @@ @inproceedings{chizat:20 year = {2020} } -@Article{higham:1997, +@Article{higham:97, author = "Higham, Nicholas J.", title = "Stable iterations for the matrix square root", journal = "Numerical Algorithms", diff --git a/docs/solvers/index.rst b/docs/solvers/index.rst index 9954a585c..b903983a7 100644 --- a/docs/solvers/index.rst +++ b/docs/solvers/index.rst @@ -1,3 +1,5 @@ +.. _solvers: + ott.solvers package =================== @@ -12,3 +14,10 @@ TODO(cuturi): add some nice text here please. linear quadratic nn + +Wasserstein Solver +------------------ +.. autosummary:: + :toctree: _autosummary + + was_solver.WassersteinSolver diff --git a/docs/solvers/linear.rst b/docs/solvers/linear.rst index 16ddd20fe..8e39c984e 100644 --- a/docs/solvers/linear.rst +++ b/docs/solvers/linear.rst @@ -8,6 +8,7 @@ Sinkhorn Solvers .. 
autosummary:: :toctree: _autosummary + sinkhorn.sinkhorn sinkhorn.Sinkhorn sinkhorn.SinkhornOutput sinkhorn_lr.LRSinkhorn diff --git a/docs/tools.rst b/docs/tools.rst index 2a890d52a..a0847d506 100644 --- a/docs/tools.rst +++ b/docs/tools.rst @@ -23,7 +23,6 @@ Segmented Sinkhorn segment_sinkhorn.segment_sinkhorn - Sinkhorn Divergence ------------------- .. autosummary:: diff --git a/ott/_version.py b/ott/_version.py index 23af74edc..689bed779 100644 --- a/ott/_version.py +++ b/ott/_version.py @@ -1,13 +1,11 @@ -from packaging.version import parse - try: from importlib_metadata import PackageNotFoundError, version # Python < 3.8 except ImportError: from importlib.metadata import PackageNotFoundError, version try: - __version__ = str(parse(version("ott-jax"))) + __version__ = version("ott-jax") except PackageNotFoundError: __version__ = "" -del parse, version, PackageNotFoundError +del version, PackageNotFoundError diff --git a/ott/initializers/nn/initializers.py b/ott/initializers/nn/initializers.py index 0f6ce2e07..f87dce7d0 100644 --- a/ott/initializers/nn/initializers.py +++ b/ott/initializers/nn/initializers.py @@ -33,25 +33,27 @@ class MetaInitializer(initializers.DefaultInitializer): :class:`~ott.initializers.nn.initializers.MetaMLP` and, with batched problem instances passed into :meth:`update`. - **Sample training usage.** The following code shows a simple - example of using ``update`` to train the model, where - ``a`` and ``b`` are the weights of the measures and - ``geom`` is the fixed geometry. - - .. code-block:: python - - meta_initializer = init_lib.MetaInitializer(geom=geom) - while training(): - a, b = sample_batch() - loss, init_f, meta_initializer.state = meta_initializer.update( - meta_initializer.state, a=a, b=b) - Args: geom: The fixed geometry of the problem instances. meta_model: The model to predict the potential :math:`f` from the measures. opt: The optimizer to update the parameters. rng: The PRNG key to use for initializing the model. 
state: The training state of the model to start from. + + Examples: + The following code shows a simple + example of using ``update`` to train the model, where + ``a`` and ``b`` are the weights of the measures and + ``geom`` is the fixed geometry. + + .. code-block:: python + + meta_initializer = init_lib.MetaInitializer(geom) + while training(): + a, b = sample_batch() + loss, init_f, meta_initializer.state = meta_initializer.update( + meta_initializer.state, a=a, b=b + ) """ def __init__( diff --git a/ott/math/decomposition.py b/ott/math/decomposition.py index a05ee1b23..e05888e81 100644 --- a/ott/math/decomposition.py +++ b/ott/math/decomposition.py @@ -25,7 +25,7 @@ except ImportError: cholmod = None -__all__ = ["CholeskySolver", "DenseCholeskySolver", "SparseCholeskySolver"] +__all__ = ["DenseCholeskySolver", "SparseCholeskySolver"] T = TypeVar("T") diff --git a/ott/math/implicit_differentiation.py b/ott/math/implicit_differentiation.py index 58ed326bf..26bbbea97 100644 --- a/ott/math/implicit_differentiation.py +++ b/ott/math/implicit_differentiation.py @@ -31,24 +31,26 @@ class ImplicitDiff: """Implicit differentiation of Sinkhorn algorithm. - Attributes: + Args: solver_fun: Callable, should return (solution, ...) ridge_kernel: promotes zero-sum solutions. only used if tau_a = tau_b = 1.0 ridge_identity: handles rank deficient transport matrices (this happens - typically when rows/cols in cost/kernel matrices are colinear, or, + typically when rows/cols in cost/kernel matrices are collinear, or, equivalently when two points from either measure are close). symmetric: flag used to figure out whether the linear system solved in the implicit function theorem is symmetric or not. This happens when either ``a == b`` or the precondition_fun is the identity. False by default, and, at the moment, needs to be set manually by the user in the more favorable case where the system is guaranteed to be symmetric. 
+ precondition_fun: TODO(marcocuturi) """ - solver_fun: Callable = jax.scipy.sparse.linalg.cg # pylint: disable=g-bare-generic + solver_fun: Callable[[jnp.ndarray, jnp.ndarray], + Tuple[jnp.ndarray, ...]] = jax.scipy.sparse.linalg.cg ridge_kernel: float = 0.0 ridge_identity: float = 0.0 symmetric: bool = False - precondition_fun: Optional[Callable[[float], float]] = None + precondition_fun: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None def solve( self, gr: Tuple[jnp.ndarray, @@ -64,9 +66,11 @@ def solve( Given a ``precondition_fun``, written here for short as :math:`h`, the first order conditions for the dual energy - :math:`E(K, \epsilon, a, b, f, g) :=- + - \langle\exp^{f/\epsilon}, K - \exp^{g/\epsilon}>` + + .. math:: + + E(K, \epsilon, a, b, f, g) :=- + - \langle\exp^{f/\epsilon}, K \exp^{g/\epsilon}> form the basis of the Sinkhorn algorithm. To differentiate optimal solutions to that problem, we exploit the fact that :math:`h(\nabla E = 0)` and @@ -97,7 +101,7 @@ def solve( application elementwise of :math:`h'` to the row (respectively column) marginal sum of the transport. - Note that we take great care in not instantiatiating these transport + Note that we take great care in not instantiating these transport matrices, to rely instead on calls to the ``app_transport`` method from the ``Geometry`` object ``geom`` (which will either use potentials or scalings, depending on ``lse_mode``) @@ -112,7 +116,7 @@ def solve( that subspace to enforce solutions have zero sum. The Schur complement can also be rank deficient if two lines or columns of T - are colinear. This will typically happen it two rows or columns of the cost + are collinear. This will typically happen it two rows or columns of the cost or kernel matrix are numerically close. To avoid this, we add a more global ``ridge_identity * z`` regularizer to achieve better conditioning. @@ -120,10 +124,8 @@ def solve( ``implicit_solver_fun``, which is set by default to ``cg``. 
When the system is symmetric (as detected by the corresponding flag ``symmetric``), ``cg`` is applied directly. When - it - is not, normal equations are used (i.e. the Schur complement is multiplied - by - its transpose before solving the system). + it is not, normal equations are used (i.e. the Schur complement is + multiplied by its transpose before solving the system). Args: gr: 2-tuple, (vector of size ``n``, vector of size ``m``). diff --git a/ott/math/matrix_square_root.py b/ott/math/matrix_square_root.py index 74a7651ad..761e67af2 100644 --- a/ott/math/matrix_square_root.py +++ b/ott/math/matrix_square_root.py @@ -24,7 +24,7 @@ from ott.math import fixed_point_loop -__all__ = ["sqrtm"] +__all__ = ["sqrtm", "sqrtm_only", "inv_sqrtm_only"] @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5)) diff --git a/ott/math/utils.py b/ott/math/utils.py index 3c83f8e82..fef5ae667 100644 --- a/ott/math/utils.py +++ b/ott/math/utils.py @@ -1,10 +1,13 @@ import functools -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import jax import jax.experimental.sparse as jesp import jax.numpy as jnp +if TYPE_CHECKING: + from ott.geometry import costs + __all__ = [ "safe_log", "kl", "js", "sparse_scale", "logsumexp", "barycentric_projection" @@ -93,6 +96,6 @@ def logsumexp_jvp(axis, keepdims, return_sign, primals, tangents): @functools.partial(jax.vmap, in_axes=[0, 0, None]) def barycentric_projection( - matrix: jnp.ndarray, y: jnp.ndarray, cost_fn + matrix: jnp.ndarray, y: jnp.ndarray, cost_fn: "costs.CostFn" ) -> jnp.ndarray: return jax.vmap(cost_fn.barycenter, in_axes=[0, None])(matrix, y) diff --git a/ott/problems/quadratic/quadratic_problem.py b/ott/problems/quadratic/quadratic_problem.py index e2060015d..2bf975c5d 100644 --- a/ott/problems/quadratic/quadratic_problem.py +++ b/ott/problems/quadratic/quadratic_problem.py @@ -22,7 +22,7 @@ from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud from 
ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_costs -from ott.typing import Transport +from ott.types import Transport if TYPE_CHECKING: from ott.solvers.linear import sinkhorn_lr diff --git a/ott/solvers/quadratic/gromov_wasserstein.py b/ott/solvers/quadratic/gromov_wasserstein.py index 3468147e3..0489b9680 100644 --- a/ott/solvers/quadratic/gromov_wasserstein.py +++ b/ott/solvers/quadratic/gromov_wasserstein.py @@ -249,7 +249,7 @@ def init_state( prob: Quadratic OT problem. init: Initial linearization of the quadratic problem. key: Random key for low-rank initializers. Only used when - :attr:`warm_start` is `False`. + :attr:`warm_start` is `False`. Returns: The initial Gromov-Wasserstein state. diff --git a/ott/typing.py b/ott/types.py similarity index 100% rename from ott/typing.py rename to ott/types.py From 8e0fcb7873163ce8a50301df524fcba591e7e619 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 21 Nov 2022 14:17:35 +0100 Subject: [PATCH 23/34] Fix MetaOT links --- docs/conf.py | 3 ++- docs/notebooks/MetaOT.ipynb | 16 ++++++++-------- docs/notebooks/gmm_pair_demo.ipynb | 19 ++++++++++++++++--- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index c53993a05..fc7b55099 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -74,7 +74,8 @@ autosummary_generate = True autosummary_filename_map = { - "ott.solvers.linear.sinkhorn.sinkhorn": "sinkhorn-function" + "ott.solvers.linear.sinkhorn.sinkhorn": + "ott.solvers.linear.sinkhorn.sinkhorn-function" } autodoc_typehints = 'description' diff --git a/docs/notebooks/MetaOT.ipynb b/docs/notebooks/MetaOT.ipynb index dc4c3acd2..ebe8bc21d 100644 --- a/docs/notebooks/MetaOT.ipynb +++ b/docs/notebooks/MetaOT.ipynb @@ -24,9 +24,9 @@ "\n", "We will cover:\n", "\n", - "+ [ott.core.initializers.MetaInitializer](../_autosummary/ott.core.initializers.MetaInitializer.html): The main class for the Meta OT 
initializer\n", - "+ [ott.core.initializers.MetaMLP](../_autosummary/ott.core.initializers.MetaMLP.html): A Meta MLP to predict the dual potentials from the weights of the measures\n", - "+ [ott.core.initializers.GaussianInitializer](../_autosummary/ott.core.initializers.GaussianInitializer.html): The main initialization class for the Gasusian initializer" + "+ [MetaInitializer](../initializers/_autosummary/ott.initializers.nn.initializers.MetaInitializer.html): The main class for the Meta OT initializer\n", + "+ [MetaMLP](../initializers/_autosummary/ott.initializers.nn.initializers.MetaMLP.html): A Meta MLP to predict the dual potentials from the weights of the measures\n", + "+ [GaussianInitializer](../initializers/_autosummary/ott.initializers.linear.initializers.GaussianInitializer.html): The main initialization class for the Gaussian initializer" ] }, { @@ -220,7 +220,7 @@ "This tutorial shows how to train a meta OT model to predict\n", "the optimal Sinkhorn potentials from the image pairs.\n", "We will reproduce their results using \n", - "[OTT's Meta OT initializer](../_autosummary/ott.core.initializers.MetaInitializer.html),\n", + "[OTT's Meta OT initializer](../initializers/_autosummary/ott.initializers.nn.initializers.MetaInitializer.html),\n", "which provides an easy-to-use interface\n", "for training and using Meta OT models.\n", "\n", @@ -296,7 +296,7 @@ "We interpret the pair of MNIST digits as discrete measures\n", "$\\alpha = \\sum_{i=1}^{n_a} a_i \\delta_{x_i}$ and $\\beta = \\sum_{j=1}^{n_b} b_j \\delta_{y_j}$.\n", "The default Sinkhorn implementation in \n", - "[ott.core.sinkhorn.sinkhorn](../_autosummary/ott.core.sinkhorn.sinkhorn.html)\n", + "[ott.solvers.linear.sinkhorn.sinkhorn](../solvers/_autosummary/ott.solvers.linear.sinkhorn.sinkhorn-function.html)\n", "can easily compute their optimal coupling and associated\n", "dual potentials $f$ and $g$ from scratch.\n", "The optimal coupling between the measures can be used\n", @@ -380,16 +380,16 
@@ "in the meta distribution $\\mathcal{D}$ during training.\n", "\n", "The following instantiates\n", - "[ott.core.initializers.MetaInitializer](../_autosummary/ott.core.initializers.MetaInitializer.html),\n", + "[MetaInitializer](../initializers/_autosummary/ott.initializers.nn.initializers.MetaInitializer.html),\n", "which provides an implementation for training and deploying Meta OT models.\n", "The default meta potential model for $f_\\theta$ is a standard multi-layer MLP\n", - "defined in [ott.core.initializers.MetaMLP](../_autosummary/ott.core.initializers.MetaMLP.html)\n", + "defined in [MetaMLP](../initializers/_autosummary/ott.initializers.nn.initializers.MetaMLP.html)\n", "and it is optimized with Adam by default.\n", "\n", "**Custom model and optimizers**.\n", "The model and training procedure use\n", "[flax](https://flax.readthedocs.io/en/latest/) and\n", - "[optax](https://optax.readthedocs.io/en/latest).\n", + "[optax](https://optax.readthedocs.io/en/latest/).\n", "The Meta OT initializer can take a custom-written Flax module\n", "in `init_model` or optimizer in `opt` that may be better-suited\n", "to your setting than an MLP." diff --git a/docs/notebooks/gmm_pair_demo.ipynb b/docs/notebooks/gmm_pair_demo.ipynb index 176410897..ceb701d4d 100644 --- a/docs/notebooks/gmm_pair_demo.ipynb +++ b/docs/notebooks/gmm_pair_demo.ipynb @@ -16,7 +16,7 @@ "\n", "[1] Y. Chen, T. T. Georgiou, and A. Tannenbaum, [Optimal transport for Gaussian mixture models](https://arxiv.org/abs/1710.07876), arXiv, (2017).\n", "\n", - "[2] Y. Chen, T. T. Georgiou, and A. Tannenbaum, [Optimal Transport for Gaussian Mixture Models](), *IEEE Access*, 7 (2019), pp. 6269–6278, https://doi.org/10.1109/ACCESS.2018.2889838.\n", + "[2] Y. Chen, T. T. Georgiou, and A. Tannenbaum, [Optimal Transport for Gaussian Mixture Models](https://doi.org/10.1109/ACCESS.2018.2889838), *IEEE Access*, 7 (2019), pp. 6269–6278.\n", "\n", "[3] Y. Chen, J. Ye, and J. 
Li, [A distance for HMMS based on aggregated Wasserstein metric and state registration](https://arxiv.org/abs/1608.01747), arXiv, (2016).\n", "\n", @@ -813,10 +813,23 @@ ] }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", + "language": "python", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 1 } From f7b92957fa1cd8a523bcfc51c50507bc9954a8f7 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 21 Nov 2022 16:26:23 +0100 Subject: [PATCH 24/34] Fix bibliography links --- docs/notebooks/One_Sinkhorn.ipynb | 4 ++-- docs/notebooks/gromov_wasserstein_multiomics.ipynb | 2 +- docs/references.bib | 11 +---------- examples/fairness/data.py | 6 ++---- ott/geometry/costs.py | 1 + ott/math/implicit_differentiation.py | 5 ++--- 6 files changed, 9 insertions(+), 20 deletions(-) diff --git a/docs/notebooks/One_Sinkhorn.ipynb b/docs/notebooks/One_Sinkhorn.ipynb index e7c1f580e..01aefa7ec 100644 --- a/docs/notebooks/One_Sinkhorn.ipynb +++ b/docs/notebooks/One_Sinkhorn.ipynb @@ -54,7 +54,7 @@ "## From Texts to Word Histograms\n", "\n", "\n", - "We adapt a [keras NLP tutorial](https://keras.io/examples/nlp/pretrained_word_embeddings/) to preprocess raw text (here a subset of texts from the [newsgroup20](https://kdd.ics.uci.edu/databases/20newsgroups/20newsgroups.html) database) and turn them into word embeddings histograms. See [colab](https://colab.research.google.com/drive/1uCK_qBpOb8yY32ABU_GcykSKE-Q-yjfi#scrollTo=zsekRny9wZoI) for detailed pre-processing." 
+ "We adapt a [keras NLP tutorial](https://keras.io/examples/nlp/pretrained_word_embeddings/) to preprocess raw text (here a subset of texts from the [newsgroup20](https://kdd.ics.uci.edu/databases/20newsgroups/20newsgroups.html) database) and turn them into word embeddings histograms. See [colab](https://colab.research.google.com/drive/1uCK_qBpOb8yY32ABU_GcykSKE-Q-yjfi) for detailed pre-processing." ] }, { @@ -497,7 +497,7 @@ "id": "o05rHAyPK3pN" }, "source": [ - "We now take a closer look at the actual convergence curves of the error of the `sinkhorn` algorithm (i.e. marginal error). We introduce a `plot_results` function to visualize this convergence (See [colab](https://colab.research.google.com/drive/1uCK_qBpOb8yY32ABU_GcykSKE-Q-yjfi#scrollTo=zsekRny9wZoI))." + "We now take a closer look at the actual convergence curves of the error of the `sinkhorn` algorithm (i.e. marginal error). We introduce a `plot_results` function to visualize this convergence (See [colab](https://colab.research.google.com/drive/1uCK_qBpOb8yY32ABU_GcykSKE-Q-yjfi))." ] }, { diff --git a/docs/notebooks/gromov_wasserstein_multiomics.ipynb b/docs/notebooks/gromov_wasserstein_multiomics.ipynb index 574d35292..545f9cd9f 100644 --- a/docs/notebooks/gromov_wasserstein_multiomics.ipynb +++ b/docs/notebooks/gromov_wasserstein_multiomics.ipynb @@ -15,7 +15,7 @@ "id": "BB8VjJrVsuuG" }, "source": [ - "A [variety of single-cell measurements](https://en.wikipedia.org/wiki/Single-cell_analysis) can help explore cell characteristics that are helpful to understand biological mechanisms. These measurements can for instance [describe epigenetic changes](https://en.wikipedia.org/wiki/Single_cell_epigenomics) (DNA methylation, chromatin accessibility, histone modifications, chromosome conformation), the genome itself, as well as the proteins present in the cell ([single cell sequencing](https://en.wikipedia.org/wiki/Single_cell_sequencing#Single-cell_genome_(DNA)_sequencing)). 
However, performing measures of different natures rises a major challenge: that of establishing an alignment across two (possibly several) measurement spaces that are unrelated, in the sense that no biological-based theory allows to construct such correspondences between them. In the absence of supervised information, the alignment can be constructed from first-hand principles, such as that of preserving geometry (i.e. an isomorphism) between the two measurement spaces. Indeed, since the population of cells measured is (statistically) the same across measurements, we expect that cells with similar genomes will be mapped to cells with similar transcriptomes, proteomes and epigenetic changes. \n", + "A [variety of single-cell measurements](https://en.wikipedia.org/wiki/Single-cell_analysis) can help explore cell characteristics that are helpful to understand biological mechanisms. These measurements can for instance [describe epigenetic changes](https://en.wikipedia.org/wiki/Single_cell_epigenomics) (DNA methylation, chromatin accessibility, histone modifications, chromosome conformation), the genome itself, as well as the proteins present in the cell ([single cell sequencing](https://en.wikipedia.org/wiki/Single_cell_sequencing)). However, performing measures of different natures raises a major challenge: that of establishing an alignment across two (possibly several) measurement spaces that are unrelated, in the sense that no biological-based theory allows to construct such correspondences between them. In the absence of supervised information, the alignment can be constructed from first-hand principles, such as that of preserving geometry (i.e. an isomorphism) between the two measurement spaces. Indeed, since the population of cells measured is (statistically) the same across measurements, we expect that cells with similar genomes will be mapped to cells with similar transcriptomes, proteomes and epigenetic changes. 
\n", "\n", "The Gromov Wasserstein optimal transport framework, implemented in OTT, is a relevant candidate to perform such an unsupervised cell alignment. \n", "This approach was proposed by Demetci et al. in *Gromov-Wasserstein optimal transport to align single-cell multi-omics data, ICML 2020 Workshop on Computational Biology, 2020* ([ICML article](https://icml-compbio.github.io/icml-website-2020/2020/papers/WCBICML2020_paper_51.pdf), pre-print ) who called it SCOT ([GitHub repo](https://github.com/rsinghlab/SCOT)), from which this notebook is adapted.\n", diff --git a/docs/references.bib b/docs/references.bib index 882822f55..3bef70d15 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -160,10 +160,8 @@ @Article{demetci:20 title = {Gromov-Wasserstein optimal transport to align single-cell multi-omics data}, elocation-id = {2020.04.28.066787}, year = {2020}, - doi = {10.1101/2020.04.28.066787}, publisher = {Cold Spring Harbor Laboratory}, URL = {https://www.biorxiv.org/content/early/2020/11/11/2020.04.28.066787}, - eprint = {https://www.biorxiv.org/content/early/2020/11/11/2020.04.28.066787.full.pdf}, journal = {bioRxiv} } @@ -213,8 +211,6 @@ @Article{gelbrich:90 number = {1}, pages = {185-203}, doi = {https://doi.org/10.1002/mana.19901470121}, - url = {https://onlinelibrary.wiley.com/doi/abs/10.1002/mana.19901470121}, - eprint = {https://onlinelibrary.wiley.com/doi/pdf/10.1002/mana.19901470121}, year = {1990} } @@ -293,9 +289,8 @@ @Article{benamou:15 pages = {A1111-A1138}, year = {2015}, doi = {10.1137/141000439}, - URL = {https://doi.org/10.1137/141000439}, - eprint = {https://doi.org/10.1137/141000439} } + @article{brenier:91, title={Polar factorization and monotone rearrangement of vector-valued functions}, author={Brenier, Yann}, @@ -406,8 +401,6 @@ @Article{delon:20 pages = {936-970}, year = {2020}, doi = {10.1137/19M1301047}, - URL = {https://doi.org/10.1137/19M1301047}, - eprint = {https://doi.org/10.1137/19M1301047}, } 
@InProceedings{janati:20a, @@ -435,8 +428,6 @@ @Article{schmitz:18 pages = {643-678}, year = {2018}, doi = {10.1137/17M1140431}, - URL = {https://doi.org/10.1137/17M1140431}, - eprint = {https://doi.org/10.1137/17M1140431}, } @Article{alvarez-esteban:16, diff --git a/examples/fairness/data.py b/examples/fairness/data.py index 075e46367..91547d421 100644 --- a/examples/fairness/data.py +++ b/examples/fairness/data.py @@ -19,8 +19,6 @@ import numpy as np import pandas as pd -open_fn = open - def load_df( data_path: str, @@ -30,12 +28,12 @@ def load_df( **kwargs ): """Load a pandas dataframe from two filenames.""" - with open_fn(data_path, 'r') as fp: + with open(data_path) as fp: df = pd.read_csv(fp, skipinitialspace=True, header=None, **kwargs) headers = [] targets = [] - with open_fn(info_path, 'r') as fp: + with open(info_path) as fp: for line in fp: if line.startswith('|') or not line.strip(): continue diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py index fdf52e7de..eefcf1613 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -188,6 +188,7 @@ def __init__(self, p: float): super().__init__() assert p >= 1.0, "p parameter in p-norm should be >= 1.0" self.p = p + # TODO(marcocuturi): fix case when `p=1` self.q = 1. / (1. - 1. / self.p) if p > 1. else "inf" def h(self, z: jnp.ndarray) -> float: diff --git a/ott/math/implicit_differentiation.py b/ott/math/implicit_differentiation.py index 26bbbea97..8e5cf30c8 100644 --- a/ott/math/implicit_differentiation.py +++ b/ott/math/implicit_differentiation.py @@ -120,8 +120,7 @@ def solve( or kernel matrix are numerically close. To avoid this, we add a more global ``ridge_identity * z`` regularizer to achieve better conditioning. - These linear systems are solved using the user defined - ``implicit_solver_fun``, + These linear systems are solved using the user defined ``solver_fun``, which is set by default to ``cg``. 
When the system is symmetric (as detected by the corresponding flag ``symmetric``), ``cg`` is applied directly. When it is not, normal equations are used (i.e. the Schur complement is @@ -129,7 +128,7 @@ def solve( Args: gr: 2-tuple, (vector of size ``n``, vector of size ``m``). - ot_prob: the instantiation of the regularizad transport problem. + ot_prob: the instantiation of the regularized transport problem. f: potential, w.r.t marginal a. g: potential, w.r.t marginal b. lse_mode: bool, log-sum-exp mode if True, kernel else. From 8875b2fd0bba89fdfadabc7ad85a4002d9ed411d Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 21 Nov 2022 16:55:36 +0100 Subject: [PATCH 25/34] Fix more links in the notebooks --- docs/notebooks/gromov_wasserstein_multiomics.ipynb | 4 ++-- docs/notebooks/soft_sort.ipynb | 4 ++-- ott/solvers/nn/layers.py | 5 +++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/notebooks/gromov_wasserstein_multiomics.ipynb b/docs/notebooks/gromov_wasserstein_multiomics.ipynb index 545f9cd9f..778b97cc1 100644 --- a/docs/notebooks/gromov_wasserstein_multiomics.ipynb +++ b/docs/notebooks/gromov_wasserstein_multiomics.ipynb @@ -20,7 +20,7 @@ "The Gromov Wasserstein optimal transport framework, implemented in OTT, is a relevant candidate to perform such an unsupervised cell alignment. \n", "This approach was proposed by Demetci et al. in *Gromov-Wasserstein optimal transport to align single-cell multi-omics data, ICML 2020 Workshop on Computational Biology, 2020* ([ICML article](https://icml-compbio.github.io/icml-website-2020/2020/papers/WCBICML2020_paper_51.pdf), pre-print ) who called it SCOT ([GitHub repo](https://github.com/rsinghlab/SCOT)), from which this notebook is adapted.\n", "\n", - "The original SCOT code code uses [POT](https://github.com/PythonOT/POT) (Python Optimal Transport). 
Here, we propose a slight modification of the SCOT code to use the `gromov_wasserstein` [function](https://ott-jax.readthedocs.io/en/latest/notebooks/gromov_wasserstein.html) from `ott.core`, which we have found to be faster than the POT implementation `ot.gromov.entropic_gromov_wasserstein` on GPU (see *Alignment and evaluation*). We then use this OTT version of the SCOT algorithm to perform cell alignment for the SNARE-seq dataset , which contains measures of two natures:\n", + "The original SCOT code uses [POT](https://github.com/PythonOT/POT) (Python Optimal Transport). Here, we propose a slight modification of the SCOT code to use the [gromov_wasserstein function](../solvers/_autosummary/ott.solvers.quadratic.gromov_wasserstein.gromov_wasserstein.html), which we have found to be faster than the POT implementation `ot.gromov.entropic_gromov_wasserstein` on GPU (see *Alignment and evaluation*). We then use this OTT version of the SCOT algorithm to perform cell alignment for the SNARE-seq dataset , which contains measures of two natures:\n", "\n", " - Chromatin accessibility ([scATAC-seq](https://en.wikipedia.org/wiki/ATAC-seq) data)\n", " - Gene expression ([scRNA-seq](https://en.wikipedia.org/wiki/Single_cell_sequencing#scRNA-Seq) data)" @@ -137,7 +137,7 @@ "id": "G8nC8PpHsuuS" }, "source": [ - "## Using `gromov_wasserstein` from `ott.core` to perform OT" + "## Using `gromov_wasserstein` from OTT" ] }, { diff --git a/docs/notebooks/soft_sort.ipynb b/docs/notebooks/soft_sort.ipynb index 244cf2976..0a2f3d661 100644 --- a/docs/notebooks/soft_sort.ipynb +++ b/docs/notebooks/soft_sort.ipynb @@ -183,7 +183,7 @@ "\n", "This colab shows how _soft_ counterparts to these operators are defined in OTT. By _soft_, we mean **differentiable**, **approximate** proxies to these original _\"hard\"_ operators. 
For instance `soft_sort.ranks` returned by OTT operators won't be integer valued, but instead floating point approximations; `soft_sort.sort` will not contain exactly the `n` values contained in the input array, reordered, but instead `n` combinaisons of thoses values that look very close to them.\n", "\n", - "**These soft operators trade off accuracy for a more informative Jacobian**. This trade-off is controlled by a non-negative parameter `epsilon`: The *smaller* `epsilon`, the closer to the original ranking and sorting operations; The *bigger*, the more bias yet the more informative gradients. That `epsilon` also correponds to that used in regularized OT (see doc on [sinkhorn](https://ott-jax.readthedocs.io/en/latest/_autosummary/ott.core.sinkhorn.sinkhorn.html#ott.core.sinkhorn.sinkhorn)).\n", + "**These soft operators trade off accuracy for a more informative Jacobian**. This trade-off is controlled by a non-negative parameter `epsilon`: The *smaller* `epsilon`, the closer to the original ranking and sorting operations; The *bigger*, the more bias yet the more informative gradients. That `epsilon` also corresponds to that used in regularized OT (see doc on [sinkhorn](../solvers/_autosummary/ott.solvers.linear.sinkhorn.sinkhorn-function.html)).\n", "\n", "The behavior of these operators is illustrated below." ] @@ -599,7 +599,7 @@ "\n", "In this tutorial we show how OTT can be used to implement a loss based on soft ranks. That soft 0-1 loss is used here to train a neural network for image classification, as done by Cuturi et al.\n", "\n", - "This implementation relies on [FLAX](https://github.com/google/flax) a neural network library for JAX." + "This implementation relies on [Flax](https://github.com/google/flax) a neural network library for JAX." 
] }, { diff --git a/ott/solvers/nn/layers.py b/ott/solvers/nn/layers.py index a0add5a9d..770d966dd 100644 --- a/ott/solvers/nn/layers.py +++ b/ott/solvers/nn/layers.py @@ -42,8 +42,9 @@ class PositiveDense(nn.Module): bias_init: initializer function for the bias. """ dim_hidden: int - rectifier_fn: Callable = nn.softplus - inv_rectifier_fn: Callable = lambda x: jnp.log(jnp.exp(x) - 1) + rectifier_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.softplus + inv_rectifier_fn: Callable[[jnp.ndarray], + jnp.ndarray] = lambda x: jnp.log(jnp.exp(x) - 1) use_bias: bool = True dtype: Any = jnp.float32 precision: Any = None From 791bfed58e2330321982d2d641de49dc0c3f544e Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 21 Nov 2022 17:43:56 +0100 Subject: [PATCH 26/34] Follow line length in README.md --- .editorconfig | 2 +- CONTRIBUTING.md | 2 +- README.md | 48 +++++++++++++++++++++++++++++++++--------------- 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/.editorconfig b/.editorconfig index a20daede6..e480103dd 100644 --- a/.editorconfig +++ b/.editorconfig @@ -3,9 +3,9 @@ root = true [*] end_of_line = lf insert_final_newline = true +charset = utf-8 [*py] -charset = utf-8 indent_size = 2 indent_style = space max_line_length = 80 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6038bcf66..34eb5da00 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -9,7 +9,7 @@ to the project, participating in discussions or raising issues. 1. fork the repository using the **Fork** button on GitHub or the following [link](https://github.com/ott-jax/ott/fork) 2. ```bash - git clone https://github.com/YOUR_USERNAME/ott + git clone https://github.com//ott cd ott pip install -e .'[dev,test]' pre-commit install diff --git a/README.md b/README.md index 48e8e75f9..6dede0b38 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,43 @@ logo -# Optimal Transport Tools (OTT). 
- +# Optimal Transport Tools (OTT) [![Tests](https://img.shields.io/github/workflow/status/ott-jax/ott/tests/main)](https://github.com/ott-jax/ott/actions/workflows/tests.yml) [![Docs](https://img.shields.io/readthedocs/ott-jax/latest)](https://ott-jax.readthedocs.io/en/latest/) [![Coverage](https://img.shields.io/codecov/c/github/ott-jax/ott/main)](https://app.codecov.io/gh/ott-jax/ott) -**See [full documentation](https://ott-jax.readthedocs.io/en/latest/).** +**See the [full documentation](https://ott-jax.readthedocs.io/en/latest/).** ## What is OTT-JAX? - -A JAX powered library to compute optimal transport at scale and on accelerators, OTT-JAX includes the fastest implementation of the Sinkhorn algorithm you will find around. We have implemented all tweaks (scheduling, acceleration, initializations) and extensions (low-rank), that can be used directly, or within more advanced problems (Gromov-Wasserstein, barycenters). Some of JAX features, including [JIT](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Using-jit-to-speed-up-functions), [auto-vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Auto-vectorization-with-vmap) and [implicit differentiation](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) work towards the goal of having end-to-end differentiable outputs. OTT-JAX is developed by a team of researchers from Apple, Google, Meta and many academic contributors, including TU München, Oxford, ENSAE/IP Paris and the Hebrew University. +A JAX powered library to compute optimal transport at scale and on accelerators, OTT-JAX includes the fastest +implementation of the Sinkhorn algorithm you will find around. We have implemented all tweaks (scheduling, +acceleration, initializations) and extensions (low-rank), that can be used directly, or within more advanced problems +(Gromov-Wasserstein, barycenters). 
Some of JAX features, including +[JIT](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Using-jit-to-speed-up-functions), +[auto-vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Auto-vectorization-with-vmap) and +[implicit differentiation](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +work towards the goal of having end-to-end differentiable outputs. OTT-JAX is developed by a team of researchers +from Apple, Google, Meta and many academic contributors, including TU München, Oxford, ENSAE/IP Paris and the +Hebrew University. ## What is optimal transport? +Optimal transport can be loosely described as the branch of mathematics and optimization that studies +*matching problems*: given two families of points, and a cost function on pairs of points, find a `good' (low cost) way +to associate bijectively to every point in the first family another in the second. -Optimal transport can be loosely described as the branch of mathematics and optimization that studies *matching problems*: given two families of points, and a cost function on pairs of points, find a `good' (low cost) way to associate bijectively to every point in the first family another in the second. - -Such problems appear in all areas of science, are easy to describe, yet hard to solve. Indeed, while matching optimally two sets of *n* points using a pairwise cost can be solved with the [Hungarian algorithm](https://en.wikipedia.org/wiki/Hungarian_algorithm), solving it costs an order of $O(n^3)$ operations, and lacks flexibility, since one may want to couple families of different sizes. +Such problems appear in all areas of science, are easy to describe, yet hard to solve. 
Indeed, while matching optimally +two sets of *n* points using a pairwise cost can be solved with the +[Hungarian algorithm](https://en.wikipedia.org/wiki/Hungarian_algorithm), solving it costs an order of $O(n^3)$ +operations, and lacks flexibility, since one may want to couple families of different sizes. -Optimal transport extends all of this, through faster algorithms (in $n^2$ or even linear in $n$) along with numerous generalizations that can help it handle weighted sets of different size, partial matchings, and even more evolved so-called quadratic matching problems. +Optimal transport extends all of this, through faster algorithms (in $n^2$ or even linear in $n$) along with numerous +generalizations that can help it handle weighted sets of different size, partial matchings, and even more evolved +so-called quadratic matching problems. -In the simple toy example below, we compute the optimal coupling matrix between two point clouds sampled randomly (2D vectors, compared with the squared Euclidean distance): +In the simple toy example below, we compute the optimal coupling matrix between two point clouds sampled randomly +(2D vectors, compared with the squared Euclidean distance): ## Example - -```py +```python import jax import jax.numpy as jnp from ott.tools import transport @@ -41,17 +54,22 @@ ot = transport.solve(x, y, a=a, b=b) P = ot.matrix ``` -The call to `solve` above works out the optimal transport solution. The `ot` object contains a transport matrix (here of size $12\times 14$) that quantifies a `link strength` between each point of the first point cloud, to one or more points from the second, as illustrated in the plot below. In this toy example, most choices were arbitrary, and are reflected in the crude `solve` API. We provide far more flexibility to define custom cost functions, objectives, and solvers, as detailed in the [full documentation](https://ott-jax.readthedocs.io/en/latest/). 
+The call to `solve` above works out the optimal transport solution. The `ot` object contains a transport matrix +(here of size $12\times 14$) that quantifies a `link strength` between each point of the first point cloud, to one or +more points from the second, as illustrated in the plot below. In this toy example, most choices were arbitrary, and +are reflected in the crude `solve` API. We provide far more flexibility to define custom cost functions, objectives, +and solvers, as detailed in the [full documentation](https://ott-jax.readthedocs.io/en/latest/). ![obtained coupling](https://raw.githubusercontent.com/ott-jax/ott/main/images/couplings.png) -## Citation +## Citation If you have found this work useful, please consider citing this reference: ``` @article{cuturi2022optimal, title={Optimal Transport Tools (OTT): A JAX Toolbox for all things Wasserstein}, - author={Cuturi, Marco and Meng-Papaxanthos, Laetitia and Tian, Yingtao and Bunne, Charlotte and Davis, Geoff and Teboul, Olivier}, + author={Cuturi, Marco and Meng-Papaxanthos, Laetitia and Tian, Yingtao and Bunne, Charlotte and + Davis, Geoff and Teboul, Olivier}, journal={arXiv preprint arXiv:2201.12324}, year={2022} } From 41a7ca7e59269f1780d21516fef1de054aa9a025 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 22 Nov 2022 09:31:26 +0100 Subject: [PATCH 27/34] Update `tests` structure --- .../{geometry_costs_test.py => costs_test.py} | 0 .../{geometry_lr_test.py => low_rank_test.py} | 0 ...cloud_apply_test.py => pointcloud_test.py} | 0 ...etry_subset_test.py => subsetting_test.py} | 4 +- .../linear/sinkhorn_init_test.py} | 261 +----------------- .../linear/sinkhorn_lr_init_test.py | 171 ++++++++++++ tests/initializers/quadratic/gw_init_test.py | 132 +++++++++ .../geometry_lse_test.py => math/lse_test.py} | 0 .../matrix_square_root_test.py | 0 .../linear}/potentials_test.py | 0 .../linear}/continuous_barycenter_test.py | 158 +---------- 
.../linear}/discrete_barycenter_test.py | 2 - .../linear}/sinkhorn_diff_test.py | 0 .../linear}/sinkhorn_grid_test.py | 0 .../linear}/sinkhorn_lr_test.py | 0 .../linear/sinkhorn_misc_test.py} | 2 +- .../{core => solvers/linear}/sinkhorn_test.py | 0 tests/{core => solvers/nn}/icnn_test.py | 0 tests/{core => solvers/nn}/neuraldual_test.py | 0 .../solvers/quadratic/fgw_barycenter_test.py | 78 ++++++ .../quadratic/fgw_test.py} | 40 +-- tests/solvers/quadratic/gw_barycenter_test.py | 113 ++++++++ .../quadratic/gw_test.py} | 22 +- 23 files changed, 531 insertions(+), 452 deletions(-) rename tests/geometry/{geometry_costs_test.py => costs_test.py} (100%) rename tests/geometry/{geometry_lr_test.py => low_rank_test.py} (100%) rename tests/geometry/{geometry_pointcloud_apply_test.py => pointcloud_test.py} (100%) rename tests/geometry/{geometry_subset_test.py => subsetting_test.py} (98%) rename tests/{core/initializers_test.py => initializers/linear/sinkhorn_init_test.py} (50%) create mode 100644 tests/initializers/linear/sinkhorn_lr_init_test.py create mode 100644 tests/initializers/quadratic/gw_init_test.py rename tests/{geometry/geometry_lse_test.py => math/lse_test.py} (100%) rename tests/{geometry => math}/matrix_square_root_test.py (100%) rename tests/{core => problems/linear}/potentials_test.py (100%) rename tests/{core => solvers/linear}/continuous_barycenter_test.py (69%) rename tests/{core => solvers/linear}/discrete_barycenter_test.py (99%) rename tests/{core => solvers/linear}/sinkhorn_diff_test.py (100%) rename tests/{core => solvers/linear}/sinkhorn_grid_test.py (100%) rename tests/{core => solvers/linear}/sinkhorn_lr_test.py (100%) rename tests/{core/sinkhorn_extra_test.py => solvers/linear/sinkhorn_misc_test.py} (99%) rename tests/{core => solvers/linear}/sinkhorn_test.py (100%) rename tests/{core => solvers/nn}/icnn_test.py (100%) rename tests/{core => solvers/nn}/neuraldual_test.py (100%) create mode 100644 tests/solvers/quadratic/fgw_barycenter_test.py 
rename tests/{core/fused_gromov_wasserstein_test.py => solvers/quadratic/fgw_test.py} (92%) create mode 100644 tests/solvers/quadratic/gw_barycenter_test.py rename tests/{core/gromov_wasserstein_test.py => solvers/quadratic/gw_test.py} (96%) diff --git a/tests/geometry/geometry_costs_test.py b/tests/geometry/costs_test.py similarity index 100% rename from tests/geometry/geometry_costs_test.py rename to tests/geometry/costs_test.py diff --git a/tests/geometry/geometry_lr_test.py b/tests/geometry/low_rank_test.py similarity index 100% rename from tests/geometry/geometry_lr_test.py rename to tests/geometry/low_rank_test.py diff --git a/tests/geometry/geometry_pointcloud_apply_test.py b/tests/geometry/pointcloud_test.py similarity index 100% rename from tests/geometry/geometry_pointcloud_apply_test.py rename to tests/geometry/pointcloud_test.py diff --git a/tests/geometry/geometry_subset_test.py b/tests/geometry/subsetting_test.py similarity index 98% rename from tests/geometry/geometry_subset_test.py rename to tests/geometry/subsetting_test.py index 5d57e6eb9..2369e4b28 100644 --- a/tests/geometry/geometry_subset_test.py +++ b/tests/geometry/subsetting_test.py @@ -11,7 +11,9 @@ @pytest.fixture() -def pc_masked(rng: jnp.ndarray) -> Tuple[pointcloud.PointCloud, Tuple]: +def pc_masked( + rng: jnp.ndarray +) -> Tuple[pointcloud.PointCloud, pointcloud.PointCloud]: n, m = 20, 30 key1, key2 = jax.random.split(rng, 2) # x = jnp.full((n,), fill_value=1.) 
diff --git a/tests/core/initializers_test.py b/tests/initializers/linear/sinkhorn_init_test.py similarity index 50% rename from tests/core/initializers_test.py rename to tests/initializers/linear/sinkhorn_init_test.py index 47cdc10e4..27e3f347d 100644 --- a/tests/core/initializers_test.py +++ b/tests/initializers/linear/sinkhorn_init_test.py @@ -20,14 +20,10 @@ import pytest import ott.initializers.nn.initializers -from ott.geometry import geometry, low_rank, pointcloud +from ott.geometry import geometry, pointcloud from ott.initializers.linear import initializers as lin_init -from ott.initializers.linear import initializers_lr -from ott.initializers.quadratic import initializers as quad_init from ott.problems.linear import linear_problem -from ott.problems.quadratic import quadratic_problem -from ott.solvers.linear import sinkhorn, sinkhorn_lr -from ott.solvers.quadratic import gromov_wasserstein +from ott.solvers.linear import sinkhorn def create_sorting_problem(rng, n, epsilon=0.01, online=False): @@ -333,256 +329,3 @@ def test_meta_initializer(self, lse_mode, rng: jnp.ndarray): # check initializer is better if lse_mode: assert base_num_iter >= meta_num_iter - - -class TestLRInitializers: - - @pytest.mark.fast.with_args("kind", ["pc", "lrc", "geom"], only_fast=0) - def test_create_default_initializer(self, rng: jnp.ndarray, kind: str): - n, d, rank = 110, 2, 3 - x = jax.random.normal(rng, (n, d)) - geom = pointcloud.PointCloud(x) - - if kind == "pc": - pass - elif kind == "lrc": - geom = geom.to_LRCGeometry() - assert isinstance(geom, low_rank.LRCGeometry) - elif kind == "geom": - geom = geometry.Geometry(geom.cost_matrix) - else: - raise NotImplementedError(geom) - prob = linear_problem.LinearProblem(geom) - - solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer=None) - initializer = solver.create_initializer(prob) - - assert initializer.rank == rank - if kind in ("pc", "lrc"): - assert isinstance(initializer, initializers_lr.KMeansInitializer) - else: - 
assert isinstance(initializer, initializers_lr.RandomInitializer) - - q, r, g = initializer(prob) - - assert q.shape == (n, rank) - assert r.shape == (n, rank) - assert g.shape == (rank,) - - def test_explicitly_passing_initializer(self): - rank = 2 - initializer = initializers_lr.RandomInitializer(rank=rank) - solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer=initializer) - - assert solver.create_initializer(prob="not used") is initializer - - @pytest.mark.parametrize( - "initializer", ["random", "rank2", "k-means", "generalized-k-means"] - ) - @pytest.mark.parametrize("partial_init", ["q", "r", "g"]) - def test_partial_initialization( - self, rng: jnp.ndarray, initializer: str, partial_init: str - ): - n, d, rank = 100, 10, 6 - key1, key2, key3, key4 = jax.random.split(rng, 4) - x = jax.random.normal(key1, (n, d)) - pc = pointcloud.PointCloud(x, epsilon=5e-1) - prob = linear_problem.LinearProblem(pc) - q_init = jax.random.normal(key2, (n, rank)) - r_init = jax.random.normal(key2, (n, rank)) - g_init = jax.random.normal(key2, (rank,)) - - solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer=initializer) - initializer = solver.create_initializer(prob) - - if partial_init == "q": - q, _, _ = initializer(prob, q=q_init) - np.testing.assert_array_equal(q, q_init) - elif partial_init == "r": - _, r, _ = initializer(prob, r=r_init) - np.testing.assert_array_equal(r, r_init) - elif partial_init == "g": - _, _, g = initializer(prob, g=g_init) - np.testing.assert_array_equal(g, g_init) - else: - raise NotImplementedError(partial_init) - - @pytest.mark.fast.with_args("rank", [2, 4, 10, 13], only_fast=True) - def test_generalized_k_means_has_correct_rank( - self, rng: jnp.ndarray, rank: int - ): - n, d = 100, 10 - x = jax.random.normal(rng, (n, d)) - pc = pointcloud.PointCloud(x, epsilon=5e-1) - prob = linear_problem.LinearProblem(pc) - - solver = sinkhorn_lr.LRSinkhorn( - rank=rank, initializer="generalized-k-means" - ) - initializer = solver.create_initializer(prob) 
- - q, r, g = initializer(prob) - - assert jnp.linalg.matrix_rank(q) == rank - assert jnp.linalg.matrix_rank(r) == rank - - def test_generalized_k_means_matches_k_means(self, rng: jnp.ndarray): - n, d, rank = 120, 15, 5 - eps = 1e-1 - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, (n, d)) - y = jax.random.normal(key1, (n, d)) - - pc = pointcloud.PointCloud(x, y, epsilon=eps) - geom = geometry.Geometry(cost_matrix=pc.cost_matrix, epsilon=eps) - pc_problem = linear_problem.LinearProblem(pc) - geom_problem = linear_problem.LinearProblem(geom) - - solver = sinkhorn_lr.LRSinkhorn( - rank=rank, initializer="k-means", max_iterations=5000 - ) - pc_out = solver(pc_problem) - - solver = sinkhorn_lr.LRSinkhorn( - rank=rank, initializer="generalized-k-means", max_iterations=5000 - ) - geom_out = solver(geom_problem) - - with pytest.raises(AssertionError): - np.testing.assert_allclose(pc_out.costs, geom_out.costs) - - np.testing.assert_allclose( - pc_out.reg_ot_cost, geom_out.reg_ot_cost, atol=0.5, rtol=0.02 - ) - - @pytest.mark.parametrize("epsilon", [0., 1e-1]) - def test_better_initialization_helps(self, rng: jnp.ndarray, epsilon: float): - n, d, rank = 81, 13, 3 - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, (n, d)) - y = jax.random.normal(key2, (n, d)) - pc = pointcloud.PointCloud(x, y, epsilon=5e-1) - prob = linear_problem.LinearProblem(pc) - - solver_random = sinkhorn_lr.LRSinkhorn( - rank=rank, epsilon=epsilon, initializer="random", max_iterations=10000 - ) - solver_init = sinkhorn_lr.LRSinkhorn( - rank=rank, epsilon=epsilon, initializer="k-means", max_iterations=10000 - ) - - out_random = solver_random(prob) - out_init = solver_init(prob) - - assert out_random.converged - assert out_init.converged - # converged earlier - assert (out_init.errors > -1).sum() < (out_random.errors > -1).sum() - # converged to a better solution - assert out_init.reg_ot_cost < out_random.reg_ot_cost - - -class TestQuadraticInitializers: - - 
@pytest.mark.parametrize("kind", ["pc", "lrc", "geom"]) - def test_create_default_lr_initializer(self, rng: jnp.ndarray, kind: str): - n, d1, d2, rank = 150, 2, 3, 5 - eps = 1e-1 - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, (n, d1)) - y = jax.random.normal(key1, (n, d2)) - kwargs_init = {"foo": "bar"} - - geom_x = pointcloud.PointCloud(x, epsilon=eps) - geom_y = pointcloud.PointCloud(y, epsilon=eps) - if kind == "pc": - pass - elif kind == "lrc": - geom_x = geom_x.to_LRCGeometry() - geom_y = geom_y.to_LRCGeometry() - elif kind == "geom": - geom_x = geometry.Geometry(geom_x.cost_matrix, epsilon=eps) - geom_y = geometry.Geometry(geom_y.cost_matrix, epsilon=eps) - else: - raise NotImplementedError(kind) - prob = quadratic_problem.QuadraticProblem(geom_x, geom_y) - - solver = gromov_wasserstein.GromovWasserstein( - rank=rank, quad_initializer=None, kwargs_init=kwargs_init - ) - initializer = solver.create_initializer(prob) - - assert isinstance(initializer, quad_init.LRQuadraticInitializer) - assert initializer.rank == rank - linear_init = initializer._linear_lr_initializer - if kind in ("pc", "lrc"): - assert isinstance(linear_init, initializers_lr.KMeansInitializer) - else: - assert isinstance(linear_init, initializers_lr.RandomInitializer) - assert linear_init._kwargs == kwargs_init - - def test_non_lr_initializer(self): - solver = gromov_wasserstein.GromovWasserstein( - rank=-1, quad_initializer="not used" - ) - initializer = solver.create_initializer(prob="not used") - assert isinstance(initializer, quad_init.QuadraticInitializer) - - @pytest.mark.parametrize("rank", [-1, 2]) - def test_explicitly_passing_initializer(self, rank: int): - if rank == -1: - linear_init = lin_init.SortingInitializer() - q_init = quad_init.QuadraticInitializer() - else: - linear_init = initializers_lr.Rank2Initializer(rank) - q_init = quad_init.LRQuadraticInitializer(linear_init) - - solver = gromov_wasserstein.GromovWasserstein( - initializer=linear_init, - 
quad_initializer=q_init, - ) - - assert solver.linear_ot_solver.initializer is linear_init - assert solver.quad_initializer is q_init - if solver.is_low_rank: - assert solver.quad_initializer.rank == rank - - @pytest.mark.parametrize("eps", [0., 1e-2]) - def test_gw_better_initialization_helps(self, rng: jnp.ndarray, eps: float): - n, m, d1, d2, rank = 123, 124, 12, 10, 5 - key1, key2, key3, key4 = jax.random.split(rng, 4) - - geom_x = pointcloud.PointCloud( - jax.random.normal(key1, (n, d1)), - jax.random.normal(key2, (n, d1)), - epsilon=eps, - ) - geom_y = pointcloud.PointCloud( - jax.random.normal(key3, (m, d2)), - jax.random.normal(key4, (m, d2)), - epsilon=eps, - ) - problem = quadratic_problem.QuadraticProblem(geom_x, geom_y) - solver_random = gromov_wasserstein.GromovWasserstein( - rank=rank, - initializer="random", - quad_initializer="random", - epsilon=eps, - store_inner_errors=True, - ) - solver_kmeans = gromov_wasserstein.GromovWasserstein( - rank=rank, - initializer="k-means", - quad_initializer="k-means", - epsilon=eps, - store_inner_errors=True - ) - - out_random = solver_random(problem) - out_kmeans = solver_kmeans(problem) - - assert out_random.reg_gw_cost - out_kmeans.reg_gw_cost >= 1. - random_errors = out_random.errors[out_random.errors > -1] - kmeans_errors = out_kmeans.errors[out_kmeans.errors > -1] - np.testing.assert_array_equal(random_errors >= 0., True) - np.testing.assert_array_equal(kmeans_errors >= 0., True) diff --git a/tests/initializers/linear/sinkhorn_lr_init_test.py b/tests/initializers/linear/sinkhorn_lr_init_test.py new file mode 100644 index 000000000..bff23ec2e --- /dev/null +++ b/tests/initializers/linear/sinkhorn_lr_init_test.py @@ -0,0 +1,171 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for Sinkhorn initializers.""" + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from ott.geometry import geometry, low_rank, pointcloud +from ott.initializers.linear import initializers_lr +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn_lr + + +class TestLRInitializers: + + @pytest.mark.fast.with_args("kind", ["pc", "lrc", "geom"], only_fast=0) + def test_create_default_initializer(self, rng: jnp.ndarray, kind: str): + n, d, rank = 110, 2, 3 + x = jax.random.normal(rng, (n, d)) + geom = pointcloud.PointCloud(x) + + if kind == "pc": + pass + elif kind == "lrc": + geom = geom.to_LRCGeometry() + assert isinstance(geom, low_rank.LRCGeometry) + elif kind == "geom": + geom = geometry.Geometry(geom.cost_matrix) + else: + raise NotImplementedError(geom) + prob = linear_problem.LinearProblem(geom) + + solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer=None) + initializer = solver.create_initializer(prob) + + assert initializer.rank == rank + if kind in ("pc", "lrc"): + assert isinstance(initializer, initializers_lr.KMeansInitializer) + else: + assert isinstance(initializer, initializers_lr.RandomInitializer) + + q, r, g = initializer(prob) + + assert q.shape == (n, rank) + assert r.shape == (n, rank) + assert g.shape == (rank,) + + def test_explicitly_passing_initializer(self): + rank = 2 + initializer = initializers_lr.RandomInitializer(rank=rank) + solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer=initializer) + + assert solver.create_initializer(prob="not 
used") is initializer + + @pytest.mark.parametrize( + "initializer", ["random", "rank2", "k-means", "generalized-k-means"] + ) + @pytest.mark.parametrize("partial_init", ["q", "r", "g"]) + def test_partial_initialization( + self, rng: jnp.ndarray, initializer: str, partial_init: str + ): + n, d, rank = 100, 10, 6 + key1, key2, key3, key4 = jax.random.split(rng, 4) + x = jax.random.normal(key1, (n, d)) + pc = pointcloud.PointCloud(x, epsilon=5e-1) + prob = linear_problem.LinearProblem(pc) + q_init = jax.random.normal(key2, (n, rank)) + r_init = jax.random.normal(key2, (n, rank)) + g_init = jax.random.normal(key2, (rank,)) + + solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer=initializer) + initializer = solver.create_initializer(prob) + + if partial_init == "q": + q, _, _ = initializer(prob, q=q_init) + np.testing.assert_array_equal(q, q_init) + elif partial_init == "r": + _, r, _ = initializer(prob, r=r_init) + np.testing.assert_array_equal(r, r_init) + elif partial_init == "g": + _, _, g = initializer(prob, g=g_init) + np.testing.assert_array_equal(g, g_init) + else: + raise NotImplementedError(partial_init) + + @pytest.mark.fast.with_args("rank", [2, 4, 10, 13], only_fast=True) + def test_generalized_k_means_has_correct_rank( + self, rng: jnp.ndarray, rank: int + ): + n, d = 100, 10 + x = jax.random.normal(rng, (n, d)) + pc = pointcloud.PointCloud(x, epsilon=5e-1) + prob = linear_problem.LinearProblem(pc) + + solver = sinkhorn_lr.LRSinkhorn( + rank=rank, initializer="generalized-k-means" + ) + initializer = solver.create_initializer(prob) + + q, r, g = initializer(prob) + + assert jnp.linalg.matrix_rank(q) == rank + assert jnp.linalg.matrix_rank(r) == rank + + def test_generalized_k_means_matches_k_means(self, rng: jnp.ndarray): + n, d, rank = 120, 15, 5 + eps = 1e-1 + key1, key2 = jax.random.split(rng, 2) + x = jax.random.normal(key1, (n, d)) + y = jax.random.normal(key1, (n, d)) + + pc = pointcloud.PointCloud(x, y, epsilon=eps) + geom = 
geometry.Geometry(cost_matrix=pc.cost_matrix, epsilon=eps) + pc_problem = linear_problem.LinearProblem(pc) + geom_problem = linear_problem.LinearProblem(geom) + + solver = sinkhorn_lr.LRSinkhorn( + rank=rank, initializer="k-means", max_iterations=5000 + ) + pc_out = solver(pc_problem) + + solver = sinkhorn_lr.LRSinkhorn( + rank=rank, initializer="generalized-k-means", max_iterations=5000 + ) + geom_out = solver(geom_problem) + + with pytest.raises(AssertionError): + np.testing.assert_allclose(pc_out.costs, geom_out.costs) + + np.testing.assert_allclose( + pc_out.reg_ot_cost, geom_out.reg_ot_cost, atol=0.5, rtol=0.02 + ) + + @pytest.mark.parametrize("epsilon", [0., 1e-1]) + def test_better_initialization_helps(self, rng: jnp.ndarray, epsilon: float): + n, d, rank = 81, 13, 3 + key1, key2 = jax.random.split(rng, 2) + x = jax.random.normal(key1, (n, d)) + y = jax.random.normal(key2, (n, d)) + pc = pointcloud.PointCloud(x, y, epsilon=5e-1) + prob = linear_problem.LinearProblem(pc) + + solver_random = sinkhorn_lr.LRSinkhorn( + rank=rank, epsilon=epsilon, initializer="random", max_iterations=10000 + ) + solver_init = sinkhorn_lr.LRSinkhorn( + rank=rank, epsilon=epsilon, initializer="k-means", max_iterations=10000 + ) + + out_random = solver_random(prob) + out_init = solver_init(prob) + + assert out_random.converged + assert out_init.converged + # converged earlier + assert (out_init.errors > -1).sum() < (out_random.errors > -1).sum() + # converged to a better solution + assert out_init.reg_ot_cost < out_random.reg_ot_cost diff --git a/tests/initializers/quadratic/gw_init_test.py b/tests/initializers/quadratic/gw_init_test.py new file mode 100644 index 000000000..b900b60fc --- /dev/null +++ b/tests/initializers/quadratic/gw_init_test.py @@ -0,0 +1,132 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for Gromov-Wasserstein initializers.""" + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from ott.geometry import geometry, pointcloud +from ott.initializers.linear import initializers as lin_init +from ott.initializers.linear import initializers_lr +from ott.initializers.quadratic import initializers as quad_init +from ott.problems.quadratic import quadratic_problem +from ott.solvers.quadratic import gromov_wasserstein + + +class TestQuadraticInitializers: + + @pytest.mark.parametrize("kind", ["pc", "lrc", "geom"]) + def test_create_default_lr_initializer(self, rng: jnp.ndarray, kind: str): + n, d1, d2, rank = 150, 2, 3, 5 + eps = 1e-1 + key1, key2 = jax.random.split(rng, 2) + x = jax.random.normal(key1, (n, d1)) + y = jax.random.normal(key1, (n, d2)) + kwargs_init = {"foo": "bar"} + + geom_x = pointcloud.PointCloud(x, epsilon=eps) + geom_y = pointcloud.PointCloud(y, epsilon=eps) + if kind == "pc": + pass + elif kind == "lrc": + geom_x = geom_x.to_LRCGeometry() + geom_y = geom_y.to_LRCGeometry() + elif kind == "geom": + geom_x = geometry.Geometry(geom_x.cost_matrix, epsilon=eps) + geom_y = geometry.Geometry(geom_y.cost_matrix, epsilon=eps) + else: + raise NotImplementedError(kind) + prob = quadratic_problem.QuadraticProblem(geom_x, geom_y) + + solver = gromov_wasserstein.GromovWasserstein( + rank=rank, quad_initializer=None, kwargs_init=kwargs_init + ) + initializer = solver.create_initializer(prob) + + assert isinstance(initializer, quad_init.LRQuadraticInitializer) + assert initializer.rank == 
rank + linear_init = initializer._linear_lr_initializer + if kind in ("pc", "lrc"): + assert isinstance(linear_init, initializers_lr.KMeansInitializer) + else: + assert isinstance(linear_init, initializers_lr.RandomInitializer) + assert linear_init._kwargs == kwargs_init + + def test_non_lr_initializer(self): + solver = gromov_wasserstein.GromovWasserstein( + rank=-1, quad_initializer="not used" + ) + initializer = solver.create_initializer(prob="not used") + assert isinstance(initializer, quad_init.QuadraticInitializer) + + @pytest.mark.parametrize("rank", [-1, 2]) + def test_explicitly_passing_initializer(self, rank: int): + if rank == -1: + linear_init = lin_init.SortingInitializer() + q_init = quad_init.QuadraticInitializer() + else: + linear_init = initializers_lr.Rank2Initializer(rank) + q_init = quad_init.LRQuadraticInitializer(linear_init) + + solver = gromov_wasserstein.GromovWasserstein( + initializer=linear_init, + quad_initializer=q_init, + ) + + assert solver.linear_ot_solver.initializer is linear_init + assert solver.quad_initializer is q_init + if solver.is_low_rank: + assert solver.quad_initializer.rank == rank + + @pytest.mark.parametrize("eps", [0., 1e-2]) + def test_gw_better_initialization_helps(self, rng: jnp.ndarray, eps: float): + n, m, d1, d2, rank = 123, 124, 12, 10, 5 + key1, key2, key3, key4 = jax.random.split(rng, 4) + + geom_x = pointcloud.PointCloud( + jax.random.normal(key1, (n, d1)), + jax.random.normal(key2, (n, d1)), + epsilon=eps, + ) + geom_y = pointcloud.PointCloud( + jax.random.normal(key3, (m, d2)), + jax.random.normal(key4, (m, d2)), + epsilon=eps, + ) + problem = quadratic_problem.QuadraticProblem(geom_x, geom_y) + solver_random = gromov_wasserstein.GromovWasserstein( + rank=rank, + initializer="random", + quad_initializer="random", + epsilon=eps, + store_inner_errors=True, + ) + solver_kmeans = gromov_wasserstein.GromovWasserstein( + rank=rank, + initializer="k-means", + quad_initializer="k-means", + epsilon=eps, + 
store_inner_errors=True + ) + + out_random = solver_random(problem) + out_kmeans = solver_kmeans(problem) + + assert out_random.reg_gw_cost - out_kmeans.reg_gw_cost >= 1. + random_errors = out_random.errors[out_random.errors > -1] + kmeans_errors = out_kmeans.errors[out_kmeans.errors > -1] + np.testing.assert_array_equal(random_errors >= 0., True) + np.testing.assert_array_equal(kmeans_errors >= 0., True) diff --git a/tests/geometry/geometry_lse_test.py b/tests/math/lse_test.py similarity index 100% rename from tests/geometry/geometry_lse_test.py rename to tests/math/lse_test.py diff --git a/tests/geometry/matrix_square_root_test.py b/tests/math/matrix_square_root_test.py similarity index 100% rename from tests/geometry/matrix_square_root_test.py rename to tests/math/matrix_square_root_test.py diff --git a/tests/core/potentials_test.py b/tests/problems/linear/potentials_test.py similarity index 100% rename from tests/core/potentials_test.py rename to tests/problems/linear/potentials_test.py diff --git a/tests/core/continuous_barycenter_test.py b/tests/solvers/linear/continuous_barycenter_test.py similarity index 69% rename from tests/core/continuous_barycenter_test.py rename to tests/solvers/linear/continuous_barycenter_test.py index 2578a459c..36fdf1d97 100644 --- a/tests/core/continuous_barycenter_test.py +++ b/tests/solvers/linear/continuous_barycenter_test.py @@ -14,18 +14,16 @@ # Lint as: python3 """Tests for continuous barycenter.""" import functools -from typing import Any, Optional, Sequence, Tuple +from typing import Tuple import jax import jax.numpy as jnp import numpy as np import pytest -from ott.geometry import costs, pointcloud, segment +from ott.geometry import costs, segment from ott.problems.linear import barycenter_problem -from ott.problems.quadratic import gw_barycenter as gwb from ott.solvers.linear import continuous_barycenter as cb -from ott.solvers.quadratic import gw_barycenter as gwb_solver from ott.tools.gaussian_mixture import 
gaussian_mixture means_and_covs_to_x = jax.vmap(costs.mean_and_cov_to_x, in_axes=[0, 0, None]) @@ -378,155 +376,3 @@ def test_bures_barycenter_different_number_of_components( jax.vmap(is_positive_semidefinite, in_axes=0, out_axes=0)(covs_bary), True ) - - -class TestGWBarycenter: - ndim = 3 - ndim_f = 4 - - @staticmethod - def random_pc( - n: int, - d: int, - rng: jnp.ndarray, - m: Optional[int] = None, - **kwargs: Any - ) -> pointcloud.PointCloud: - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, (n, d)) - y = x if m is None else jax.random.normal(key2, (m, d)) - return pointcloud.PointCloud(x, y, batch_size=None, **kwargs) - - @staticmethod - def pad_cost_matrices( - costs: Sequence[jnp.ndarray], - shape: Optional[Tuple[int, int]] = None - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - if shape is None: - shape = jnp.asarray([arr.shape for arr in costs]).max() - shape = (shape, shape) - else: - assert shape[0] == shape[1], shape - - cs, weights = [], [] - for cost in costs: - r, c = cost.shape - cs.append(jnp.zeros(shape).at[:r, :c].set(cost)) - w = jnp.ones(r) / r - weights.append(jnp.concatenate([w, jnp.zeros(shape[0] - r)])) - return jnp.stack(cs), jnp.stack(weights) - - # TODO(cuturi) add back KL test when KL cost GW is fixed. 
- @pytest.mark.parametrize( - "gw_loss,bar_size,epsilon", - [("sqeucl", 17, None)] #, ("kl", 22, 1e-2)] - ) - def test_gw_barycenter( - self, rng: jnp.ndarray, gw_loss: str, bar_size: int, - epsilon: Optional[float] - ): - tol = 1e-3 if gw_loss == "sqeucl" else 1e-1 - num_per_segment = (13, 15, 21) - rngs = jax.random.split(rng, len(num_per_segment)) - pcs = [ - self.random_pc(n, d=self.ndim, rng=rng) - for n, rng in zip(num_per_segment, rngs) - ] - costs = [pc._compute_cost_matrix() for pc, n in zip(pcs, num_per_segment)] - costs, cbs = self.pad_cost_matrices(costs) - ys = jnp.concatenate([pc.x for pc in pcs]) - bs = jnp.concatenate([jnp.ones(n) / n for n in num_per_segment]) - kwargs = { - "gw_loss": gw_loss, - "num_per_segment": num_per_segment, - "epsilon": epsilon - } - - problem_pc = gwb.GWBarycenterProblem(y=ys, b=bs, **kwargs) - problem_cost = gwb.GWBarycenterProblem( - costs=costs, - b=cbs, - **kwargs, - ) - for prob in [problem_pc, problem_cost]: - assert not prob.is_fused - assert prob.ndim_fused is None - assert prob.num_measures == len(num_per_segment) - assert prob.max_measure_size == max(num_per_segment) - assert prob._loss_name == gw_loss - assert problem_pc.ndim == self.ndim - assert problem_cost.ndim is None - - solver = gwb_solver.GromovWassersteinBarycenter(jit=True) - out_pc = solver(problem_pc, bar_size=bar_size) - out_cost = solver(problem_cost, bar_size=bar_size) - - assert out_pc.x is None - assert out_cost.x is None - assert out_pc.cost.shape == (bar_size, bar_size) - np.testing.assert_allclose(out_pc.cost, out_cost.cost, rtol=tol, atol=tol) - np.testing.assert_allclose(out_pc.costs, out_cost.costs, rtol=tol, atol=tol) - - @pytest.mark.fast( - "jit,fused_penalty,scale_cost", [(False, 1.5, "mean"), - (True, 3.1, "max_cost")], - only_fast=0 - ) - def test_fgw_barycenter( - self, - rng: jnp.ndarray, - jit: bool, - fused_penalty: float, - scale_cost: str, - ): - - def barycenter( - y: jnp.ndim, y_fused: jnp.ndarray, num_per_segment: Tuple[int, 
...] - ) -> gwb_solver.GWBarycenterState: - prob = gwb.GWBarycenterProblem( - y=y, - y_fused=y_fused, - num_per_segment=num_per_segment, - fused_penalty=fused_penalty, - scale_cost=scale_cost, - ) - assert prob.is_fused - assert prob.fused_penalty == fused_penalty - assert not prob._y_as_costs - assert prob.max_measure_size == max(num_per_segment) - assert prob.num_measures == len(num_per_segment) - assert prob.ndim == self.ndim - assert prob.ndim_fused == self.ndim_f - - solver = gwb_solver.GromovWassersteinBarycenter( - jit=False, store_inner_errors=True, epsilon=epsilon - ) - - x_init = jax.random.normal(rng, (bar_size, self.ndim_f)) - cost_init = pointcloud.PointCloud(x_init).cost_matrix - - return solver(prob, bar_size=bar_size, bar_init=(cost_init, x_init)) - - bar_size, epsilon, = 10, 1e-1 - num_per_segment = (7, 12) - - key1, *rngs = jax.random.split(rng, len(num_per_segment) + 1) - y = jnp.concatenate([ - self.random_pc(n, d=self.ndim, rng=rng).x - for n, rng in zip(num_per_segment, rngs) - ]) - rngs = jax.random.split(key1, len(num_per_segment)) - y_fused = jnp.concatenate([ - self.random_pc(n, d=self.ndim_f, rng=rng).x - for n, rng in zip(num_per_segment, rngs) - ]) - - fn = jax.jit(barycenter, static_argnums=2) if jit else barycenter - out = fn(y, y_fused, num_per_segment) - - assert out.cost.shape == (bar_size, bar_size) - assert out.x.shape == (bar_size, self.ndim_f) - np.testing.assert_array_equal(jnp.isfinite(out.cost), True) - np.testing.assert_array_equal(jnp.isfinite(out.x), True) - np.testing.assert_array_equal(jnp.isfinite(out.costs), True) - np.testing.assert_array_equal(jnp.isfinite(out.errors), True) diff --git a/tests/core/discrete_barycenter_test.py b/tests/solvers/linear/discrete_barycenter_test.py similarity index 99% rename from tests/core/discrete_barycenter_test.py rename to tests/solvers/linear/discrete_barycenter_test.py index 2b09ac991..8f2bcfa0b 100644 --- a/tests/core/discrete_barycenter_test.py +++ 
b/tests/solvers/linear/discrete_barycenter_test.py @@ -13,8 +13,6 @@ # limitations under the License. # Lint as: python3 -"""Tests for the Policy.""" - import jax.numpy as jnp import pytest diff --git a/tests/core/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py similarity index 100% rename from tests/core/sinkhorn_diff_test.py rename to tests/solvers/linear/sinkhorn_diff_test.py diff --git a/tests/core/sinkhorn_grid_test.py b/tests/solvers/linear/sinkhorn_grid_test.py similarity index 100% rename from tests/core/sinkhorn_grid_test.py rename to tests/solvers/linear/sinkhorn_grid_test.py diff --git a/tests/core/sinkhorn_lr_test.py b/tests/solvers/linear/sinkhorn_lr_test.py similarity index 100% rename from tests/core/sinkhorn_lr_test.py rename to tests/solvers/linear/sinkhorn_lr_test.py diff --git a/tests/core/sinkhorn_extra_test.py b/tests/solvers/linear/sinkhorn_misc_test.py similarity index 99% rename from tests/core/sinkhorn_extra_test.py rename to tests/solvers/linear/sinkhorn_misc_test.py index 2f64a4c4a..fd1664b50 100644 --- a/tests/core/sinkhorn_extra_test.py +++ b/tests/solvers/linear/sinkhorn_misc_test.py @@ -13,7 +13,7 @@ # limitations under the License. 
# Lint as: python3 -"""Tests Anderson acceleration for sinkhorn.""" +"""Tests Anderson acceleration for Sinkhorn.""" import functools from typing import Callable, Tuple diff --git a/tests/core/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py similarity index 100% rename from tests/core/sinkhorn_test.py rename to tests/solvers/linear/sinkhorn_test.py diff --git a/tests/core/icnn_test.py b/tests/solvers/nn/icnn_test.py similarity index 100% rename from tests/core/icnn_test.py rename to tests/solvers/nn/icnn_test.py diff --git a/tests/core/neuraldual_test.py b/tests/solvers/nn/neuraldual_test.py similarity index 100% rename from tests/core/neuraldual_test.py rename to tests/solvers/nn/neuraldual_test.py diff --git a/tests/solvers/quadratic/fgw_barycenter_test.py b/tests/solvers/quadratic/fgw_barycenter_test.py new file mode 100644 index 000000000..d3dca9ad6 --- /dev/null +++ b/tests/solvers/quadratic/fgw_barycenter_test.py @@ -0,0 +1,78 @@ +"""Tests for Fused Gromov-Wasserstein barycenter.""" +from typing import Tuple + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from ott.geometry import pointcloud +from ott.problems.quadratic import gw_barycenter as gwb +from ott.solvers.quadratic import gw_barycenter as gwb_solver + + +class FGWBarycenterTest: + + @pytest.mark.fast( + "jit,fused_penalty,scale_cost", [(False, 1.5, "mean"), + (True, 3.1, "max_cost")], + only_fast=0 + ) + def test_fgw_barycenter( + self, + rng: jnp.ndarray, + jit: bool, + fused_penalty: float, + scale_cost: str, + ): + + def barycenter( + y: jnp.ndim, y_fused: jnp.ndarray, num_per_segment: Tuple[int, ...] 
+ ) -> gwb_solver.GWBarycenterState: + prob = gwb.GWBarycenterProblem( + y=y, + y_fused=y_fused, + num_per_segment=num_per_segment, + fused_penalty=fused_penalty, + scale_cost=scale_cost, + ) + assert prob.is_fused + assert prob.fused_penalty == fused_penalty + assert not prob._y_as_costs + assert prob.max_measure_size == max(num_per_segment) + assert prob.num_measures == len(num_per_segment) + assert prob.ndim == self.ndim + assert prob.ndim_fused == self.ndim_f + + solver = gwb_solver.GromovWassersteinBarycenter( + jit=False, store_inner_errors=True, epsilon=epsilon + ) + + x_init = jax.random.normal(rng, (bar_size, self.ndim_f)) + cost_init = pointcloud.PointCloud(x_init).cost_matrix + + return solver(prob, bar_size=bar_size, bar_init=(cost_init, x_init)) + + bar_size, epsilon, = 10, 1e-1 + num_per_segment = (7, 12) + + key1, *rngs = jax.random.split(rng, len(num_per_segment) + 1) + y = jnp.concatenate([ + self.random_pc(n, d=self.ndim, rng=rng).x + for n, rng in zip(num_per_segment, rngs) + ]) + rngs = jax.random.split(key1, len(num_per_segment)) + y_fused = jnp.concatenate([ + self.random_pc(n, d=self.ndim_f, rng=rng).x + for n, rng in zip(num_per_segment, rngs) + ]) + + fn = jax.jit(barycenter, static_argnums=2) if jit else barycenter + out = fn(y, y_fused, num_per_segment) + + assert out.cost.shape == (bar_size, bar_size) + assert out.x.shape == (bar_size, self.ndim_f) + np.testing.assert_array_equal(jnp.isfinite(out.cost), True) + np.testing.assert_array_equal(jnp.isfinite(out.x), True) + np.testing.assert_array_equal(jnp.isfinite(out.costs), True) + np.testing.assert_array_equal(jnp.isfinite(out.errors), True) diff --git a/tests/core/fused_gromov_wasserstein_test.py b/tests/solvers/quadratic/fgw_test.py similarity index 92% rename from tests/core/fused_gromov_wasserstein_test.py rename to tests/solvers/quadratic/fgw_test.py index ebd15e7de..338ffafb1 100644 --- a/tests/core/fused_gromov_wasserstein_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -23,7 
+23,7 @@ from ott.geometry import geometry, low_rank, pointcloud from ott.problems.quadratic import quadratic_problem -from ott.solvers.quadratic import gromov_wasserstein as gwb_solver +from ott.solvers.quadratic import gromov_wasserstein as gw_solver class TestFusedGromovWasserstein: @@ -50,13 +50,13 @@ def initialize(self, rng: jnp.ndarray): self.cy = jax.random.uniform(keys[5], (self.m, self.m)) self.cxy = jax.random.uniform(keys[6], (self.n, self.m)) - def test_flag_store_errors_fused(self): + def test_fgw_flag_store_errors_fused(self): """Tests whether errors are properly stored if requested.""" threshold_sinkhorn = 1e-2 geom_x = pointcloud.PointCloud(self.x) geom_y = pointcloud.PointCloud(self.y) geom_xy = pointcloud.PointCloud(self.x_2, self.y_2) - out = gwb_solver.gromov_wasserstein( + out = gw_solver.gromov_wasserstein( geom_xx=geom_x, geom_yy=geom_y, geom_xy=geom_xy, @@ -67,7 +67,7 @@ def test_flag_store_errors_fused(self): ).errors assert out is None - out = gwb_solver.gromov_wasserstein( + out = gw_solver.gromov_wasserstein( geom_xx=geom_x, geom_yy=geom_y, geom_xy=geom_xy, @@ -87,7 +87,7 @@ def test_flag_store_errors_fused(self): assert out.ndim == 2 @pytest.mark.fast.with_args(jit=[False, True], only_fast=1) - def test_gradient_marginals_fused_gwb_solver(self, jit: bool): + def test_gradient_marginals_fgw_solver(self, jit: bool): """Test gradient w.r.t. 
probability weights.""" geom_x = pointcloud.PointCloud(self.x) geom_y = pointcloud.PointCloud(self.y) @@ -99,7 +99,7 @@ def reg_gw(a, b, implicit): 'implicit_differentiation': implicit, 'max_iterations': 1001 } - out = gwb_solver.gromov_wasserstein( + out = gw_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -137,14 +137,14 @@ def reg_gw(a, b, implicit): ) @pytest.mark.fast.with_args(lse_mode=[False, True], only_fast=1) - def test_fused_gwb_solver_pointcloud(self, lse_mode: bool): + def test_fgw_solver_pointcloud(self, lse_mode: bool): """Test basic computations pointclouds.""" def reg_gw(x, y, x_2, y_2, fused_penalty, a, b): geom_x = pointcloud.PointCloud(x) geom_y = pointcloud.PointCloud(y) geom_xy = pointcloud.PointCloud(x_2, y_2) - return gwb_solver.gromov_wasserstein( + return gw_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -164,7 +164,7 @@ def reg_gw(x, y, x_2, y_2, fused_penalty, a, b): assert cost is not None @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_fused_gwb_solver_pointcloud(self, lse_mode: bool): + def test_gradient_fgw_solver_pointcloud(self, lse_mode: bool): """Test gradient w.r.t. pointclouds.""" def reg_gw(x, y, x_2, y_2, fused_penalty, a, b, implicit): @@ -176,7 +176,7 @@ def reg_gw(x, y, x_2, y_2, fused_penalty, a, b, implicit): 'max_iterations': 1001, 'lse_mode': lse_mode } - return gwb_solver.gromov_wasserstein( + return gw_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -207,7 +207,7 @@ def reg_gw(x, y, x_2, y_2, fused_penalty, a, b, implicit): ) @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_fused_gwb_solver_geometry(self, lse_mode: bool): + def test_gradient_fgw_solver_geometry(self, lse_mode: bool): """Test gradient w.r.t. 
cost matrices.""" def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): @@ -219,7 +219,7 @@ def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): 'max_iterations': 1001, 'lse_mode': lse_mode } - return gwb_solver.gromov_wasserstein( + return gw_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -252,7 +252,7 @@ def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): grad_matrices[0][2], grad_matrices[1][2], rtol=1e-02, atol=1e-02 ) - def test_adaptive_threshold_fused(self): + def test_fgw_adaptive_threshold(self): """Checking solution is improved with smaller threshold for convergence.""" geom_x = pointcloud.PointCloud(self.x, self.x) geom_y = pointcloud.PointCloud(self.y, self.y) @@ -260,7 +260,7 @@ def test_adaptive_threshold_fused(self): # without warm start for calls to sinkhorn def loss_thre(threshold: float) -> float: - return gwb_solver.gromov_wasserstein( + return gw_solver.gromov_wasserstein( geom_xx=geom_x, geom_yy=geom_y, geom_xy=geom_xy, @@ -275,7 +275,7 @@ def loss_thre(threshold: float) -> float: assert loss_thre(1e-3) > loss_thre(1e-5) @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_fused_gwb_solver_penalty(self, lse_mode: bool): + def test_gradient_fgw_solver_penalty(self, lse_mode: bool): """Test gradient w.r.t. 
penalty.""" def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): @@ -287,7 +287,7 @@ def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): 'max_iterations': 1001, 'lse_mode': lse_mode } - return gwb_solver.gromov_wasserstein( + return gw_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -319,7 +319,7 @@ def reg_fgw(x, y, x_2, y_2, fused_penalty, a, b): geom_y = pointcloud.PointCloud(y) geom_xy = pointcloud.PointCloud(x_2, y_2) sinkhorn_kwargs = {'max_iterations': 1001} - return gwb_solver.gromov_wasserstein( + return gw_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -334,7 +334,7 @@ def reg_gw(x, y, a, b): geom_x = pointcloud.PointCloud(x) geom_y = pointcloud.PointCloud(y) sinkhorn_kwargs = {'max_iterations': 1001} - return gwb_solver.gromov_wasserstein( + return gw_solver.gromov_wasserstein( geom_x, geom_y, a=a, @@ -367,7 +367,7 @@ def test_fgw_lr_memory(self, rng: jnp.ndarray, jit: bool): geom_y = pointcloud.PointCloud(y) geom_xy = pointcloud.PointCloud(xx, yy) - ot_gwlr = gwb_solver.gromov_wasserstein( + ot_gwlr = gw_solver.gromov_wasserstein( geom_x, geom_y, geom_xy, rank=5, jit=jit ) res0 = ot_gwlr.apply(x.T, axis=0) @@ -399,7 +399,7 @@ def test_fgw_lr_generic_cost_matrix( lr_prob = problem.to_low_rank() assert lr_prob.is_low_rank - solver = gwb_solver.GromovWasserstein(rank=5, epsilon=1) + solver = gw_solver.GromovWasserstein(rank=5, epsilon=1) out = solver(problem) assert solver.rank == 5 diff --git a/tests/solvers/quadratic/gw_barycenter_test.py b/tests/solvers/quadratic/gw_barycenter_test.py new file mode 100644 index 000000000..94cd5759b --- /dev/null +++ b/tests/solvers/quadratic/gw_barycenter_test.py @@ -0,0 +1,113 @@ +# Copyright 2022 Apple +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for Gromov-Wasserstein barycenter.""" +from typing import Any, Optional, Sequence, Tuple + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from ott.geometry import pointcloud +from ott.problems.quadratic import gw_barycenter as gwb +from ott.solvers.quadratic import gw_barycenter as gwb_solver + + +class TestGWBarycenter: + ndim = 3 + ndim_f = 4 + + @staticmethod + def random_pc( + n: int, + d: int, + rng: jnp.ndarray, + m: Optional[int] = None, + **kwargs: Any + ) -> pointcloud.PointCloud: + key1, key2 = jax.random.split(rng, 2) + x = jax.random.normal(key1, (n, d)) + y = x if m is None else jax.random.normal(key2, (m, d)) + return pointcloud.PointCloud(x, y, batch_size=None, **kwargs) + + @staticmethod + def pad_cost_matrices( + costs: Sequence[jnp.ndarray], + shape: Optional[Tuple[int, int]] = None + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + if shape is None: + shape = jnp.asarray([arr.shape for arr in costs]).max() + shape = (shape, shape) + else: + assert shape[0] == shape[1], shape + + cs, weights = [], [] + for cost in costs: + r, c = cost.shape + cs.append(jnp.zeros(shape).at[:r, :c].set(cost)) + w = jnp.ones(r) / r + weights.append(jnp.concatenate([w, jnp.zeros(shape[0] - r)])) + return jnp.stack(cs), jnp.stack(weights) + + # TODO(cuturi) add back KL test when KL cost GW is fixed. 
+ @pytest.mark.parametrize( + "gw_loss,bar_size,epsilon", + [("sqeucl", 17, None)] # , ("kl", 22, 1e-2)] + ) + def test_gw_barycenter( + self, rng: jnp.ndarray, gw_loss: str, bar_size: int, + epsilon: Optional[float] + ): + tol = 1e-3 if gw_loss == "sqeucl" else 1e-1 + num_per_segment = (13, 15, 21) + rngs = jax.random.split(rng, len(num_per_segment)) + pcs = [ + self.random_pc(n, d=self.ndim, rng=rng) + for n, rng in zip(num_per_segment, rngs) + ] + costs = [pc._compute_cost_matrix() for pc, n in zip(pcs, num_per_segment)] + costs, cbs = self.pad_cost_matrices(costs) + ys = jnp.concatenate([pc.x for pc in pcs]) + bs = jnp.concatenate([jnp.ones(n) / n for n in num_per_segment]) + kwargs = { + "gw_loss": gw_loss, + "num_per_segment": num_per_segment, + "epsilon": epsilon + } + + problem_pc = gwb.GWBarycenterProblem(y=ys, b=bs, **kwargs) + problem_cost = gwb.GWBarycenterProblem( + costs=costs, + b=cbs, + **kwargs, + ) + for prob in [problem_pc, problem_cost]: + assert not prob.is_fused + assert prob.ndim_fused is None + assert prob.num_measures == len(num_per_segment) + assert prob.max_measure_size == max(num_per_segment) + assert prob._loss_name == gw_loss + assert problem_pc.ndim == self.ndim + assert problem_cost.ndim is None + + solver = gwb_solver.GromovWassersteinBarycenter(jit=True) + out_pc = solver(problem_pc, bar_size=bar_size) + out_cost = solver(problem_cost, bar_size=bar_size) + + assert out_pc.x is None + assert out_cost.x is None + assert out_pc.cost.shape == (bar_size, bar_size) + np.testing.assert_allclose(out_pc.cost, out_cost.cost, rtol=tol, atol=tol) + np.testing.assert_allclose(out_pc.costs, out_cost.costs, rtol=tol, atol=tol) diff --git a/tests/core/gromov_wasserstein_test.py b/tests/solvers/quadratic/gw_test.py similarity index 96% rename from tests/core/gromov_wasserstein_test.py rename to tests/solvers/quadratic/gw_test.py index 27e7f58b8..0e37a9c6d 100644 --- a/tests/core/gromov_wasserstein_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ 
-88,7 +88,7 @@ def test_quad_to_low_rank( assert lr_prob._is_low_rank_convertible assert lr_prob.to_low_rank() is lr_prob - def test_implicit_conversion_mixed_input(self, rng: jnp.ndarray): + def test_gw_implicit_conversion_mixed_input(self, rng: jnp.ndarray): n, m, d1, d2 = 200, 300, 20, 25 k1, k2 = jax.random.split(rng, 2) x = jax.random.normal(k1, (n, d1)) @@ -152,7 +152,7 @@ def test_flag_store_errors(self): assert out.ndim == 2 @pytest.mark.parametrize("jit", [False, True]) - def test_gradient_marginals_gromov_wasserstein(self, jit: bool): + def test_gradient_marginals_gw(self, jit: bool): """Test gradient w.r.t. probability weights.""" geom_x = pointcloud.PointCloud(self.x) geom_y = pointcloud.PointCloud(self.y) @@ -199,7 +199,7 @@ def reg_gw(a, b, implicit): ) @pytest.mark.fast - def test_gromov_wasserstein_pointcloud(self): + def test_gw_pointcloud(self): """Test basic computations pointclouds.""" def reg_gw(x, y, a, b): @@ -212,7 +212,7 @@ def reg_gw(x, y, a, b): assert not jnp.isnan(reg_gw(self.x, self.y, self.a, self.b)) @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_gromov_wasserstein_pointcloud(self, lse_mode: bool): + def test_gradient_gw_pointcloud(self, lse_mode: bool): """Test gradient w.r.t. pointclouds.""" def reg_gw(x, y, a, b, implicit): @@ -254,7 +254,7 @@ def reg_gw(x, y, a, b, implicit): ) @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_gromov_wasserstein_geometry(self, lse_mode: bool): + def test_gradient_gw_geometry(self, lse_mode: bool): """Test gradient w.r.t. 
cost matrices.""" def reg_gw(cx, cy, a, b, implicit): @@ -296,7 +296,7 @@ def reg_gw(cx, cy, a, b, implicit): grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02 ) - def test_adaptive_threshold(self): + def test_gw_adaptive_threshold(self): """Checking solution is improved with smaller threshold for convergence.""" geom_x = pointcloud.PointCloud(self.x, self.x) geom_y = pointcloud.PointCloud(self.y, self.y) @@ -466,7 +466,7 @@ def initialize(self, rng: jnp.ndarray): self.tau_b = 0.9 @pytest.mark.fast - def test_gromov_wasserstein_pointcloud(self): + def test_gw_pointcloud(self): """Test basic computations pointclouds.""" def reg_gw(x, y, a, b): @@ -487,9 +487,7 @@ def reg_gw(x, y, a, b): assert not jnp.isnan(cost) @pytest.mark.parametrize("gw_unbalanced_correction", [False, True]) - def test_gradient_gromov_wasserstein_pointcloud( - self, gw_unbalanced_correction: bool - ): + def test_gradient_gw_pointcloud(self, gw_unbalanced_correction: bool): """Test gradient w.r.t. pointclouds.""" def reg_gw(x, y, a, b, implicit): @@ -533,9 +531,7 @@ def reg_gw(x, y, a, b, implicit): ) @pytest.mark.parametrize("gw_unbalanced_correction", [False, True]) - def test_gradient_gromov_wasserstein_geometry( - self, gw_unbalanced_correction: bool - ): + def test_gradient_gw_geometry(self, gw_unbalanced_correction: bool): """Test gradient w.r.t. 
cost matrices.""" def reg_gw(cx, cy, a, b, implicit): From 0f1ad3b109909f7dd04063476ba4c511e4c05622 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 22 Nov 2022 09:43:52 +0100 Subject: [PATCH 28/34] Update badges --- README.md | 1 + docs/index.rst | 21 ++++++++++++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 6dede0b38..cb44680a9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ logo # Optimal Transport Tools (OTT) +[![Downloads](https://pepy.tech/badge/ott-jax)](https://pypi.org/project/ott-jax/) [![Tests](https://img.shields.io/github/workflow/status/ott-jax/ott/tests/main)](https://github.com/ott-jax/ott/actions/workflows/tests.yml) [![Docs](https://img.shields.io/readthedocs/ott-jax/latest)](https://ott-jax.readthedocs.io/en/latest/) [![Coverage](https://img.shields.io/codecov/c/github/ott-jax/ott/main)](https://app.codecov.io/gh/ott-jax/ott) diff --git a/docs/index.rst b/docs/index.rst index 24a376771..8623d61d5 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,6 +1,8 @@ +|Downloads| |Tests| |Docs| |Coverage| + Optimal Transport Tools (OTT) documentation =========================================== -`Code `_ on github. +`Code `_ on GitHub. To install, simply run ``pip install ott-jax``. Intro @@ -106,6 +108,23 @@ Packages references + +.. |Downloads| image:: https://pepy.tech/badge/ott-jax + :target: https://pypi.org/project/ott-jax/ + :alt: Downloads + +.. |Tests| image:: https://img.shields.io/github/workflow/status/ott-jax/ott/tests/main + :target: https://github.com/ott-jax/ott/actions/workflows/tests.yml + :alt: Tests + +.. |Docs| image:: https://img.shields.io/readthedocs/ott-jax/latest + :target: https://ott-jax.readthedocs.io/en/latest/ + :alt: Documentation + +.. |Coverage| image:: https://img.shields.io/codecov/c/github/ott-jax/ott/main + :target: https://app.codecov.io/gh/ott-jax/ott + :alt: Coverage + .. 
_Just-in-time (JIT) compilation: https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit .. _auto-vectorization (VMAP): https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap .. _automatic: https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation From 75e358e56fa4233d9087d2e84ce8e740ec2444e6 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 22 Nov 2022 10:38:49 +0100 Subject: [PATCH 29/34] Add TODOs, fix citation in `index.rst`, move `implicit_diff` --- docs/index.rst | 12 +++++++----- docs/initializers/index.rst | 2 +- docs/initializers/linear.rst | 2 ++ docs/initializers/nn.rst | 2 ++ docs/initializers/quadratic.rst | 2 ++ docs/math.rst | 7 +------ docs/problems/index.rst | 2 +- docs/problems/linear.rst | 2 ++ docs/problems/quadratic.rst | 2 ++ docs/solvers/index.rst | 2 +- docs/solvers/linear.rst | 9 +++++++++ docs/solvers/nn.rst | 2 ++ docs/solvers/quadratic.rst | 2 ++ ott/math/__init__.py | 1 - ott/solvers/linear/__init__.py | 1 + .../linear}/implicit_differentiation.py | 0 ott/solvers/linear/sinkhorn.py | 5 ++--- tests/geometry/graph_test.py | 2 +- 18 files changed, 38 insertions(+), 19 deletions(-) rename ott/{math => solvers/linear}/implicit_differentiation.py (100%) diff --git a/docs/index.rst b/docs/index.rst index 8623d61d5..037233463 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,7 +17,7 @@ such as differentiable approximations to ranking or even clustering. 
To achieve this, `OTT` rests on two families of tools: The first family consists in *discrete* solvers computing transport between point clouds, using the Sinkhorn :cite:`cuturi:13` and low-rank Sinkhorn :cite:`scetbon:21` algorithms, -and moving up towards Gromov-Wasserstein :cite:`memoli:11`, :cite:`memoli:11`; +and moving up towards Gromov-Wasserstein :cite:`memoli:11,peyre:16`; the second family consists in *continuous* solvers, using suitable neural architectures :cite:`amos:17` coupled with SGD type estimators :cite:`makkuva:20,korotin:21`. @@ -43,20 +43,22 @@ Design Choices automatically in higher level calls (e.g. updates in Gromov-Wasserstein), without requiring any attention from the user. +.. TODO(marcocuturi): add missing package descriptions below + Packages -------- - :ref:`geometry` contains classes to instantiate objects that describe *two point clouds* paired with a *cost* function. Geometry objects are used to describe OT problems, handled by solvers in the :ref:`solvers`. -- :ref:`problems` TODO(marcocuturi) -- :ref:`solvers` TODO(marcocuturi) -- :ref:`initializers` TODO(marcocuturi) +- :ref:`problems` +- :ref:`solvers` +- :ref:`initializers` - :ref:`tools` provides an interface to exploit OT solutions, as produced by solvers in the :ref:`solvers`. Such tasks include computing approximations to Wasserstein distances :cite:`genevay:18,sejourne:19`, approximating OT between GMMs, or computing differentiable sort and quantile operations :cite:`cuturi:19`. -- :ref:`math` TODO(marcocuturi) +- :ref:`math` .. toctree:: :maxdepth: 1 diff --git a/docs/initializers/index.rst b/docs/initializers/index.rst index 4baba57f5..b591be3b2 100644 --- a/docs/initializers/index.rst +++ b/docs/initializers/index.rst @@ -3,7 +3,7 @@ ott.initializers package ======================== -TODO(cuturi): add some nice text here please. +.. TODO(cuturi): add some nice text here please .. currentmodule:: ott.initializers .. 
automodule:: ott.initializers diff --git a/docs/initializers/linear.rst b/docs/initializers/linear.rst index 8fb5d89ea..ad3143fa1 100644 --- a/docs/initializers/linear.rst +++ b/docs/initializers/linear.rst @@ -3,6 +3,8 @@ ott.initializers.linear package .. currentmodule:: ott.initializers.linear .. automodule:: ott.initializers.linear +.. TODO(marcocuturi): maybe add some text here + Sinkhorn Initializers --------------------- .. autosummary:: diff --git a/docs/initializers/nn.rst b/docs/initializers/nn.rst index 6f439f6a8..2f88e0999 100644 --- a/docs/initializers/nn.rst +++ b/docs/initializers/nn.rst @@ -3,6 +3,8 @@ ott.initializers.nn package .. currentmodule:: ott.initializers.nn .. automodule:: ott.initializers.nn +.. TODO(marcocuturi): maybe add some text here + Neural Initializers ------------------- .. autosummary:: diff --git a/docs/initializers/quadratic.rst b/docs/initializers/quadratic.rst index d3ea718a9..1929bd380 100644 --- a/docs/initializers/quadratic.rst +++ b/docs/initializers/quadratic.rst @@ -3,6 +3,8 @@ ott.initializers.quadratic package .. currentmodule:: ott.initializers.quadratic .. automodule:: ott.initializers.quadratic +.. TODO(marcocuturi): maybe add some text here + Gromov-Wasserstein Initializers ------------------------------- .. autosummary:: diff --git a/docs/math.rst b/docs/math.rst index de58a29e2..20ea2fc3f 100644 --- a/docs/math.rst +++ b/docs/math.rst @@ -5,12 +5,7 @@ ott.math package .. currentmodule:: ott.math .. automodule:: ott.math -Implicit Differentiation ------------------------- -.. autosummary:: - :toctree: _autosummary - - implicit_differentiation.ImplicitDiff +.. TODO(marcocuturi): maybe add some text here Fixed-point Iteration --------------------- diff --git a/docs/problems/index.rst b/docs/problems/index.rst index 462411f7d..16e5ead90 100644 --- a/docs/problems/index.rst +++ b/docs/problems/index.rst @@ -3,7 +3,7 @@ ott.problems package ==================== -TODO(cuturi): add some nice text here please. +.. 
TODO(marcocuturi): add some nice text here please .. currentmodule:: ott.problems .. automodule:: ott.problems diff --git a/docs/problems/linear.rst b/docs/problems/linear.rst index 4f4ce5cb2..d8b442e15 100644 --- a/docs/problems/linear.rst +++ b/docs/problems/linear.rst @@ -3,6 +3,8 @@ ott.problems.linear package .. currentmodule:: ott.problems.linear .. automodule:: ott.problems.linear +.. TODO(marcocuturi): maybe add some text here + OT Problems ----------- .. autosummary:: diff --git a/docs/problems/quadratic.rst b/docs/problems/quadratic.rst index 900081871..e7e8c32d1 100644 --- a/docs/problems/quadratic.rst +++ b/docs/problems/quadratic.rst @@ -3,6 +3,8 @@ ott.problems.quadratic package .. currentmodule:: ott.problems.quadratic .. automodule:: ott.problems.quadratic +.. TODO(marcocuturi): maybe add some text here + OT Problems ----------- .. autosummary:: diff --git a/docs/solvers/index.rst b/docs/solvers/index.rst index b903983a7..8b7d62532 100644 --- a/docs/solvers/index.rst +++ b/docs/solvers/index.rst @@ -3,7 +3,7 @@ ott.solvers package =================== -TODO(cuturi): add some nice text here please. +.. TODO(marcocuturi): add some nice text here please .. currentmodule:: ott.solvers .. automodule:: ott.solvers diff --git a/docs/solvers/linear.rst b/docs/solvers/linear.rst index 8e39c984e..0605bd7e1 100644 --- a/docs/solvers/linear.rst +++ b/docs/solvers/linear.rst @@ -3,6 +3,8 @@ ott.solvers.linear package .. currentmodule:: ott.solvers.linear .. automodule:: ott.solvers.linear +.. TODO(marcocuturi): maybe add some text here + Sinkhorn Solvers ---------------- .. autosummary:: @@ -31,3 +33,10 @@ Sinkhorn Acceleration acceleration.Momentum acceleration.AndersonAcceleration + +Implicit Differentiation +------------------------ +.. 
autosummary:: + :toctree: _autosummary + + implicit_differentiation.ImplicitDiff diff --git a/docs/solvers/nn.rst b/docs/solvers/nn.rst index cd28c655f..08bd7fc1a 100644 --- a/docs/solvers/nn.rst +++ b/docs/solvers/nn.rst @@ -3,6 +3,8 @@ ott.solvers.nn package .. currentmodule:: ott.solvers.nn .. automodule:: ott.solvers.nn +.. TODO(marcocuturi): maybe add some text here + Neural Dual ----------- .. autosummary:: diff --git a/docs/solvers/quadratic.rst b/docs/solvers/quadratic.rst index 33b0a9014..9f6ea7a38 100644 --- a/docs/solvers/quadratic.rst +++ b/docs/solvers/quadratic.rst @@ -3,6 +3,8 @@ ott.solvers.quadratic package .. currentmodule:: ott.solvers.quadratic .. automodule:: ott.solvers.quadratic +.. TODO(marcocuturi): maybe add some text here + Gromov-Wasserstein Solvers -------------------------- .. autosummary:: diff --git a/ott/math/__init__.py b/ott/math/__init__.py index 60c303c2e..67aca8931 100644 --- a/ott/math/__init__.py +++ b/ott/math/__init__.py @@ -1,7 +1,6 @@ from . import ( decomposition, fixed_point_loop, - implicit_differentiation, matrix_square_root, unbalanced_functions, utils, diff --git a/ott/solvers/linear/__init__.py b/ott/solvers/linear/__init__.py index 7f9d3a1bb..40034b929 100644 --- a/ott/solvers/linear/__init__.py +++ b/ott/solvers/linear/__init__.py @@ -2,6 +2,7 @@ acceleration, continuous_barycenter, discrete_barycenter, + implicit_differentiation, sinkhorn, sinkhorn_lr, ) diff --git a/ott/math/implicit_differentiation.py b/ott/solvers/linear/implicit_differentiation.py similarity index 100% rename from ott/math/implicit_differentiation.py rename to ott/solvers/linear/implicit_differentiation.py diff --git a/ott/solvers/linear/sinkhorn.py b/ott/solvers/linear/sinkhorn.py index bfc86c4a5..aebee0038 100644 --- a/ott/solvers/linear/sinkhorn.py +++ b/ott/solvers/linear/sinkhorn.py @@ -23,11 +23,10 @@ from ott.geometry import geometry from ott.initializers.linear import initializers as init_lib -from ott.math import fixed_point_loop 
-from ott.math import implicit_differentiation as implicit_lib -from ott.math import unbalanced_functions +from ott.math import fixed_point_loop, unbalanced_functions from ott.problems.linear import linear_problem, potentials from ott.solvers.linear import acceleration +from ott.solvers.linear import implicit_differentiation as implicit_lib __all__ = ["Sinkhorn", "SinkhornOutput"] diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py index 97b245a59..58f766890 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -13,8 +13,8 @@ from ott.geometry import geometry, graph from ott.math import decomposition -from ott.math import implicit_differentiation as implicit_lib from ott.problems.linear import linear_problem +from ott.solvers.linear import implicit_differentiation as implicit_lib from ott.solvers.linear import sinkhorn # we mix both dense/sparse tests From 8c1ed2ae1d662b6327aa2f6d37e5d16939738504 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 22 Nov 2022 11:03:15 +0100 Subject: [PATCH 30/34] Fix implicit_diff, TODOs in costs --- docs/notebooks/Hessians.ipynb | 2 +- ott/geometry/costs.py | 45 +++++++++++----------- tests/solvers/linear/sinkhorn_diff_test.py | 2 +- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/docs/notebooks/Hessians.ipynb b/docs/notebooks/Hessians.ipynb index 51383ca35..e0ce67c72 100644 --- a/docs/notebooks/Hessians.ipynb +++ b/docs/notebooks/Hessians.ipynb @@ -45,7 +45,7 @@ "\n", "from ott.tools import sinkhorn_divergence\n", "from ott.geometry import pointcloud\n", - "from ott.math import implicit_differentiation as implicit_lib\n", + "from ott.solvers.linear import implicit_differentiation as implicit_lib\n", "import matplotlib.pyplot as plt" ] }, diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py index eefcf1613..7f35046e5 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -49,35 +49,35 @@ class 
CostFn(abc.ABC): norm: Optional[Callable[[jnp.ndarray], Union[float, jnp.ndarray]]] = None @abc.abstractmethod - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: pass # TODO(michalk8): make weights optional? - def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float: + def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: """Barycentric projection. Args: - weights: TODO. - xs: TODO. + weights: Weights of the points. + xs: Points to project. Returns: - TODO. + The barycentric projection. """ raise NotImplementedError("Barycenter is not yet implemented.") @classmethod def _padder(cls, dim: int) -> jnp.ndarray: - """TODO. + """Create a padding vector for easier jitting. Args: - dim: TODO. + dim: Dimensionality of the data. Returns: - TODO. + The padding vector. """ return jnp.zeros((1, dim)) - def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: cost = self.pairwise(x, y) if self.norm is None: return cost @@ -139,7 +139,7 @@ def h_legendre(self, z: jnp.ndarray) -> float: """Legendre transform of :func:`h` when it is convex.""" raise NotImplementedError("`h_legendre` not implemented.") - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Compute cost as evaluation of :func:`h` on :math:`x-y`.""" return self.h(x - y) @@ -152,7 +152,7 @@ class SqPNorm(TICost): the reference :cite:`boyd:04`, p.93/94. Args: - p: TODO. + p: Power of the p-norm. """ def __init__(self, p: float): @@ -181,14 +181,14 @@ class PNorm(TICost): """p-norm (to the power p) of the difference of two vectors. Args: - p: TODO. + p: Power of the p-norm. 
""" def __init__(self, p: float): super().__init__() assert p >= 1.0, "p parameter in p-norm should be >= 1.0" self.p = p - # TODO(marcocuturi): fid case when `p=1` + # TODO(marcocuturi): fix case when `p=1` self.q = 1. / (1. - 1. / self.p) if p > 1. else "inf" def h(self, z: jnp.ndarray) -> float: @@ -216,7 +216,7 @@ class Euclidean(CostFn): because the function is not strictly convex (it is linear on rays). """ - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Compute Euclidean norm.""" return jnp.linalg.norm(x - y) @@ -229,7 +229,7 @@ def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]: """Compute squared Euclidean norm for vector.""" return jnp.sum(x ** 2, axis=-1) - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Compute minus twice the dot-product between vectors.""" return -2. * jnp.vdot(x, y) @@ -249,14 +249,14 @@ class Cosine(CostFn): """Cosine distance cost function. Args: - ridge: TODO. + ridge: Ridge regularization. """ def __init__(self, ridge: float = 1e-8): super().__init__() self._ridge = ridge - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Cosine distance between vectors, denominator regularized with ridge.""" ridge = self._ridge x_norm = jnp.linalg.norm(x, axis=-1) @@ -408,10 +408,10 @@ class UnbalancedBures(CostFn): triplets (mass, mean, covariance) raveled as vectors, in that order. Args: - dimension: TODO. - gamma: TODO. - sigma: TODO. - kwargs: TODO. + dimension: Dimensionality of the data. + gamma: KL-divergence regularization for the marginals. + sigma: Entropic regularization. + kwargs: Keyword arguments for :func:`~ott.math.matrix_square_root.sqrtm`. 
""" def __init__( @@ -431,7 +431,7 @@ def norm(self, x: jnp.ndarray) -> jnp.ndarray: """Compute norm of Gaussian for unbalanced Bures.""" return self._gamma * x[0] - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Compute dot-product for unbalanced Bures.""" # Sets a few constants gam = self._gamma @@ -481,6 +481,7 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: # If all logdet signs are 1, output value, nan otherwise. # TODO(michalk8): use lax.cond + return jnp.where( sldet_c == 1 and sldet_c_ab == 1 and sldet_ab == 1 and sldet_t_ab == 1, 2 * sig2 * mass_x * mass_y - 2 * (sig2 + gam) * jnp.exp(log_m_pi), diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index f307ec954..80d9e62c8 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -23,8 +23,8 @@ import pytest from ott.geometry import costs, geometry, grid, pointcloud -from ott.math import implicit_differentiation as implicit_lib from ott.problems.linear import linear_problem +from ott.solvers.linear import implicit_differentiation as implicit_lib from ott.solvers.linear import sinkhorn from ott.tools import transport From 9186c616b7949539daf475cdd2d6dff6c02fa98a Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 22 Nov 2022 11:11:26 +0100 Subject: [PATCH 31/34] Use `jax.lax.cond` in `UnbalancedBures` --- ott/geometry/costs.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py index 7f35046e5..ad0fdbf8c 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -52,7 +52,6 @@ class CostFn(abc.ABC): def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: pass - # TODO(michalk8): make weights optional? 
def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: """Barycentric projection. @@ -436,8 +435,8 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: # Sets a few constants gam = self._gamma sig2 = self._sigma ** 2 - lam = sig2 + gam / 2 - tau = gam / (2 * lam) + lam = sig2 + gam / 2.0 + tau = gam / (2.0 * lam) # Extracts mass, mean vector, covariance matrices mass_x, mass_y = x[0], y[0] @@ -463,29 +462,24 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: sldet_c, ldet_c = jnp.linalg.slogdet(c_mat) sldet_t_ab, ldet_t_ab = jnp.linalg.slogdet(tilde_a_b) sldet_ab, ldet_ab = jnp.linalg.slogdet(jnp.matmul(cov_x, cov_y)) - sldet_c_ab, ldet_c_ab = jnp.linalg.slogdet(c_mat - 2 * tilde_a_b / gam) + sldet_c_ab, ldet_c_ab = jnp.linalg.slogdet(c_mat - 2.0 * tilde_a_b / gam) # Gathers all these results to compute log total mass of transport log_m_pi = (0.5 * self._dimension * sig2 / (gam + sig2)) * jnp.log(sig2) - - log_m_pi += (1 / (tau + 1)) * ( + log_m_pi += (1.0 / (tau + 1.0)) * ( jnp.log(mass_x) + jnp.log(mass_y) + ldet_c + 0.5 * (tau * ldet_t_ab - ldet_ab) ) - log_m_pi += -jnp.sum( diff_means * jnp.linalg.solve(cov_x + cov_y + lam * iden, diff_means) - ) / (2 * (tau + 1)) - + ) / (2.0 * (tau + 1.0)) log_m_pi += -0.5 * ldet_c_ab - # If all logdet signs are 1, output value, nan otherwise. 
- # TODO(michalk8): use lax.cond - - return jnp.where( - sldet_c == 1 and sldet_c_ab == 1 and sldet_ab == 1 and sldet_t_ab == 1, - 2 * sig2 * mass_x * mass_y - 2 * (sig2 + gam) * jnp.exp(log_m_pi), - jnp.nan + # if all logdet signs are 1, output value, nan otherwise + pos_signs = (sldet_c + sldet_c_ab + sldet_t_ab + sldet_t_ab) == 4 + return jax.lax.cond( + pos_signs, lambda: 2 * sig2 * mass_x * mass_y - 2 * + (sig2 + gam) * jnp.exp(log_m_pi), lambda: jnp.nan ) def tree_flatten(self): From f49351ca30d8f93e762bfa0e0625e8bb58afebc9 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 22 Nov 2022 15:09:05 +0100 Subject: [PATCH 32/34] Fix `UnbalancedBures` --- ott/geometry/costs.py | 46 +++++++++++++++------- tests/solvers/linear/sinkhorn_misc_test.py | 39 ++++++++++++------ 2 files changed, 59 insertions(+), 26 deletions(-) diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py index ad0fdbf8c..42559f1ef 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -400,38 +400,55 @@ def tree_unflatten(cls, aux_data, children): @jax.tree_util.register_pytree_node_class class UnbalancedBures(CostFn): - """Regularized/unbalanced Bures dist between two triplets of (mass,mean,cov). + """Unbalanced Bures distance between two triplets of `(mass, mean, cov)`. - This cost implements the value defined in :cite:`janati:20`, eq. 37, 39, 40. - We follow their notations. It is assumed inputs are given as - triplets (mass, mean, covariance) raveled as vectors, in that order. + This cost uses the notation defined in :cite:`janati:20`, eq. 37, 39, 40. Args: dimension: Dimensionality of the data. - gamma: KL-divergence regularization for the marginals. sigma: Entropic regularization. + gamma: KL-divergence regularization for the marginals. kwargs: Keyword arguments for :func:`~ott.math.matrix_square_root.sqrtm`. 
""" def __init__( self, dimension: int, - gamma: float = 1.0, + *, sigma: float = 1.0, + gamma: float = 1.0, **kwargs: Any, ): super().__init__() self._dimension = dimension - self._gamma = gamma self._sigma = sigma + self._gamma = gamma self._sqrtm_kw = kwargs def norm(self, x: jnp.ndarray) -> jnp.ndarray: - """Compute norm of Gaussian for unbalanced Bures.""" - return self._gamma * x[0] + """Compute norm of Gaussian for unbalanced Bures. + + Args: + x: Array of shape ``[n_points + n_points + n_dim ** 2,]`` + corresponding to the raveled mass, means and the covariance matrix. + + Returns: + The norm, array of shape ``[n_points,]``. + """ + return self._gamma * x[:, 0] def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: - """Compute dot-product for unbalanced Bures.""" + """Compute dot-product for unbalanced Bures. + + Args: + x: Array of shape ``[n_points + n_points + n_dim ** 2,]`` + corresponding to the raveled mass, means and the covariance matrix. + y: Array of shape ``[n_points + n_points + n_dim ** 2,]`` + corresponding to the raveled mass, means and the covariance matrix. + + Returns: + The cost. 
+ """ # Sets a few constants gam = self._gamma sig2 = self._sigma ** 2 @@ -477,19 +494,20 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: # if all logdet signs are 1, output value, nan otherwise pos_signs = (sldet_c + sldet_c_ab + sldet_t_ab + sldet_t_ab) == 4 + return jax.lax.cond( pos_signs, lambda: 2 * sig2 * mass_x * mass_y - 2 * (sig2 + gam) * jnp.exp(log_m_pi), lambda: jnp.nan ) def tree_flatten(self): - return (), (self._dimension, self._gamma, self._sigma, self._sqrtm_kw) + return (), (self._dimension, self._sigma, self._gamma, self._sqrtm_kw) @classmethod def tree_unflatten(cls, aux_data, children): del children - dim, gamma, sigma, kwargs = aux_data - return cls(dim, gamma=gamma, sigma=sigma, **kwargs) + dim, sigma, gamma, kwargs = aux_data + return cls(dim, sigma=sigma, gamma=gamma, **kwargs) def x_to_means_and_covs(x: jnp.ndarray, @@ -498,7 +516,7 @@ def x_to_means_and_covs(x: jnp.ndarray, Args: x: [num_gaussians, dimension, (1 + dimension)] array of concatenated means - and covariances (raveled) dimension: the dimension of the Gaussians. + and covariances (raveled) dimension: the dimension of the Gaussians. Returns: means: [num_gaussians, dimension] array that holds the means. 
diff --git a/tests/solvers/linear/sinkhorn_misc_test.py b/tests/solvers/linear/sinkhorn_misc_test.py index fd1664b50..96da5c6bc 100644 --- a/tests/solvers/linear/sinkhorn_misc_test.py +++ b/tests/solvers/linear/sinkhorn_misc_test.py @@ -131,25 +131,40 @@ def initialize(self): self.a = a / jnp.sum(a) self.b = b / jnp.sum(b) - def test_bures_point_cloud(self): + @pytest.mark.parametrize("lse_mode", [False, True]) + @pytest.mark.parametrize("unbalanced,thresh", [(False, 1e-3), (True, 1e-4)]) + def test_bures_point_cloud( + self, rng: jnp.ndarray, lse_mode: bool, unbalanced: bool, thresh: float + ): """Two point clouds of Gaussians, tested with various parameters.""" - threshold = 1e-3 - geom = pointcloud.PointCloud( - self.x, - self.y, - cost_fn=costs.Bures(dimension=self.dim, regularization=1e-4), - epsilon=self.eps + if unbalanced: + rng1, rng2 = jax.random.split(rng, 2) + ws_x = jnp.abs(jax.random.normal(rng1, (self.x.shape[0], 1))) + 1e-1 + ws_y = jnp.abs(jax.random.normal(rng2, (self.y.shape[0], 1))) + 1e-1 + ws_x = ws_x.at[0].set(0.) 
+ x = jnp.concatenate([ws_x, self.x], axis=1) + y = jnp.concatenate([ws_y, self.y], axis=1) + cost_fn = costs.UnbalancedBures(dimension=self.dim, gamma=0.9, sigma=0.98) + else: + x, y = self.x, self.y + cost_fn = costs.Bures(dimension=self.dim, regularization=1e-4) + + geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=self.eps) + + out = sinkhorn.sinkhorn( + geom, a=self.a, b=self.b, lse_mode=lse_mode, threshold=thresh ) - errors = sinkhorn.sinkhorn(geom, a=self.a, b=self.b, lse_mode=False).errors - err = errors[errors > -1][-1] - assert threshold > err + err = out.errors[out.errors > -1][-1] + + assert out.converged + assert thresh > err - def test_regularized_unbalanced_bures(self): + def test_regularized_unbalanced_bures_cost(self): """Tests Regularized Unbalanced Bures.""" x = jnp.concatenate((jnp.array([0.9]), self.x[0, :])) y = jnp.concatenate((jnp.array([1.1]), self.y[0, :])) - rub = costs.UnbalancedBures(self.dim, 1, 0.8) + rub = costs.UnbalancedBures(self.dim, gamma=1.0, sigma=0.8) assert not jnp.any(jnp.isnan(rub(x, y))) assert not jnp.any(jnp.isnan(rub(y, x))) np.testing.assert_allclose(rub(x, y), rub(y, x), rtol=5e-3, atol=5e-3) From 687e5944b5e02be51186334bbeebd4672c582683 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 22 Nov 2022 15:50:14 +0100 Subject: [PATCH 33/34] Update CI versions --- .github/workflows/lint.yml | 6 +++--- .github/workflows/notebook_tests.yml | 6 +++--- .github/workflows/publish_to_pypi.yml | 4 ++-- .github/workflows/tests.yml | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b280db886..9c091ef1b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -15,13 +15,13 @@ jobs: os: [ubuntu-latest] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 
with: python-version: ${{ matrix.python }} - - uses: actions/cache@v2 + - uses: actions/cache@v3 with: path: ~/.cache/pre-commit key: precommit-${{ env.pythonLocation }}-${{ hashFiles('**/.pre-commit-config.yaml') }} diff --git a/.github/workflows/notebook_tests.yml b/.github/workflows/notebook_tests.yml index b89b301e8..a3f327e5d 100644 --- a/.github/workflows/notebook_tests.yml +++ b/.github/workflows/notebook_tests.yml @@ -12,12 +12,12 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - python-version: ['3.8'] + python-version: [3.8] os: [ubuntu-latest] steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/publish_to_pypi.yml b/.github/workflows/publish_to_pypi.yml index b9efc2088..6d26315fd 100644 --- a/.github/workflows/publish_to_pypi.yml +++ b/.github/workflows/publish_to_pypi.yml @@ -13,9 +13,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.x - name: Install dependencies diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9570b6a7d..be1e3dbd5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,8 +18,8 @@ jobs: test_mark: [fast, all] steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} From aea8912df5db6705af684a8b3edf182e85032fb8 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 22 Nov 2022 16:14:38 +0100 Subject: [PATCH 34/34] Fix UnbalancedBures's norm --- ott/geometry/costs.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py 
index 42559f1ef..94c06112a 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -429,13 +429,14 @@ def norm(self, x: jnp.ndarray) -> jnp.ndarray: """Compute norm of Gaussian for unbalanced Bures. Args: - x: Array of shape ``[n_points + n_points + n_dim ** 2,]`` - corresponding to the raveled mass, means and the covariance matrix. + x: Array of shape ``[n_points + n_points + n_dim ** 2,]``, potentially + batched, corresponding to the raveled mass, means and the covariance + matrix. Returns: - The norm, array of shape ``[n_points,]``. + The norm, array of shape ``[]`` or ``[batch,]`` in the batched case. """ - return self._gamma * x[:, 0] + return self._gamma * x[..., 0] def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Compute dot-product for unbalanced Bures.