From 0f91857f135ba5dd3b140eee2db2a1f4122ca8c8 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 22 Nov 2023 09:54:14 +0100 Subject: [PATCH 001/186] draft of BaseSolver and UnbalancedMixin --- docs/conf.py | 3 +- src/ott/neural/solvers/base_solver.py | 248 ++++++++++++++++++ tests/geometry/costs_test.py | 1 + tests/geometry/graph_test.py | 1 + tests/geometry/low_rank_test.py | 1 + tests/geometry/pointcloud_test.py | 1 + tests/geometry/scaling_cost_test.py | 1 + tests/geometry/subsetting_test.py | 1 + .../initializers/linear/sinkhorn_init_test.py | 1 + .../linear/sinkhorn_lr_init_test.py | 1 + tests/initializers/quadratic/gw_init_test.py | 1 + tests/math/lse_test.py | 1 + tests/math/math_utils_test.py | 1 + tests/math/matrix_square_root_test.py | 1 + tests/neural/icnn_test.py | 1 + tests/neural/losses_test.py | 1 + tests/neural/map_estimator_test.py | 1 + tests/neural/meta_initializer_test.py | 1 + tests/neural/neuraldual_test.py | 1 + tests/problems/linear/potentials_test.py | 1 + .../linear/continuous_barycenter_test.py | 1 + .../linear/discrete_barycenter_test.py | 1 + tests/solvers/linear/sinkhorn_diff_test.py | 1 + tests/solvers/linear/sinkhorn_grid_test.py | 1 + tests/solvers/linear/sinkhorn_lr_test.py | 1 + tests/solvers/linear/sinkhorn_misc_test.py | 1 + tests/solvers/linear/sinkhorn_test.py | 1 + tests/solvers/linear/univariate_test.py | 1 + tests/solvers/quadratic/fgw_test.py | 1 + tests/solvers/quadratic/gw_barycenter_test.py | 1 + tests/solvers/quadratic/gw_test.py | 1 + tests/solvers/quadratic/lower_bound_test.py | 1 + .../gaussian_mixture/fit_gmm_pair_test.py | 1 + tests/tools/gaussian_mixture/fit_gmm_test.py | 1 + .../gaussian_mixture_pair_test.py | 1 + .../gaussian_mixture/gaussian_mixture_test.py | 1 + tests/tools/gaussian_mixture/gaussian_test.py | 1 + tests/tools/gaussian_mixture/linalg_test.py | 1 + .../gaussian_mixture/probabilities_test.py | 1 + .../tools/gaussian_mixture/scale_tril_test.py | 1 + tests/tools/k_means_test.py | 5 +- 
tests/tools/plot_test.py | 1 + tests/tools/segment_sinkhorn_test.py | 1 + tests/tools/sinkhorn_divergence_test.py | 1 + tests/tools/soft_sort_test.py | 1 + tests/utils_test.py | 1 + 46 files changed, 296 insertions(+), 3 deletions(-) create mode 100644 src/ott/neural/solvers/base_solver.py diff --git a/docs/conf.py b/docs/conf.py index 6158c668d..c2a2c7102 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -26,9 +26,10 @@ import logging from datetime import datetime -import ott from sphinx.util import logging as sphinx_logging +import ott + # -- Project information ----------------------------------------------------- needs_sphinx = "4.0" diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py new file mode 100644 index 000000000..cd12684e0 --- /dev/null +++ b/src/ott/neural/solvers/base_solver.py @@ -0,0 +1,248 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from types import Mapping, MappingProxyType +from typing import ( + Any, + Callable, + Dict, + Literal, + Optional, + Tuple, + Union, +) + +import jax +import jax.numpy as jnp +import optax +from flax import train_state + +from ott.geometry.pointcloud import PointCloud +from ott.neural.solvers import models +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn + + +class BaseNeuralSolver(ABC): + """Base class for neural solvers. + + Args: + iterations: Number of iterations to train for. + valid_freq: Frequency at which to run validation. 
+ """ + + def __init__(self, iterations: int, valid_freq: int, **_: Any) -> Any: + self.iterations = iterations + self.valid_freq = valid_freq + + @abstractmethod + def setup(self, *args: Any, **kwargs: Any) -> None: + pass + + @abstractmethod + def __call__(self, *args: Any, **kwargs: Any) -> None: + """Train the model.""" + pass + + @abstractmethod + def save(self, path: Path): + """Save the model.""" + pass + + @abstractmethod + @property + def is_balanced(self) -> Dict[str, Any]: + """Return the training logs.""" + pass + + @abstractmethod + @property + def training_logs(self) -> Dict[str, Any]: + """Return the training logs.""" + pass + + +class UnbalancednessMixin: + + def __init__( + self, + source_dim: int, + target_dim: int, + cond_dim: Optional[int], + tau_a: float = 1.0, + tau_b: float = 1.0, + mlp_eta: Optional[models.ModelBase] = None, + mlp_xi: Optional[models.ModelBase] = None, + seed: Optional[int] = None, + opt_eta: Optional[optax.GradientTransformation] = None, + opt_xi: Optional[optax.GradientTransformation] = None, + resample_epsilon: float = 1e-2, + scale_cost: Union[bool, int, float, Literal["mean", "max_cost", + "median"]] = "mean", + sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), + **_: Any, + ) -> None: + self.source_dim = source_dim + self.target_dim = target_dim + self.cond_dim = cond_dim + self.tau_a = tau_a + self.tau_b = tau_b + self.mlp_eta = mlp_eta + self.mlp_xi = mlp_xi + self.seed = seed + self.opt_eta = opt_eta + self.opt_xi = opt_xi + self.resample_epsilon = resample_epsilon + self.scale_cost = scale_cost + + self._compute_unbalanced_marginals = self._get_compute_unbalanced_marginals( + tau_a=tau_a, + tau_b=tau_b, + resample_epsilon=resample_epsilon, + scale_cost=scale_cost, + sinkhorn_kwargs=sinkhorn_kwargs + ) + self._setup(source_dim=source_dim, target_dim=target_dim, cond_dim=cond_dim) + + def _get_compute_unbalanced_marginals( + self, + tau_a: float, + tau_b: float, + resample_epsilon: float, + scale_cost: 
Union[bool, int, float, Literal["mean", "max_cost", + "median"]] = "mean", + sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}), + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Compute the unbalanced source and target marginals for a batch.""" + + @jax.jit + def compute_unbalanced_marginals( + batch_source: jnp.ndarray, batch_target: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + geom = PointCloud( + batch_source, + batch_target, + epsilon=resample_epsilon, + scale_cost=scale_cost + ) + out = sinkhorn.Sinkhorn(**sinkhorn_kwargs)( + linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) + ) + return out.matrix.sum(axis=1), out.matrix.sum(axis=0) + + return compute_unbalanced_marginals + + @jax.jit + def _resample( + self, + key: jax.random.KeyArray, + batch: Tuple[jnp.ndarray, ...], + marginals: jnp.ndarray, + ) -> Tuple[jnp.ndarray, ...]: + """Resample a batch based upon marginals.""" + indices = jax.random.choice( + key, a=len(marginals), p=jnp.squeeze(marginals), shape=[len(marginals)] + ) + return tuple(b[indices] if b is not None else None for b in batch) + + def _setup(self, source_dim: int, target_dim: int, cond_dim: int): + self.unbalancedness_step_fn = self._get_step_fn() + if self.mlp_eta is not None: + self.opt_eta = ( + self.opt_eta if self.opt_eta is not None else + optax.adamw(learning_rate=1e-4, weight_decay=1e-10) + ) + self.state_eta = self.mlp_eta.create_train_state( + self._key, self.opt_eta, source_dim + cond_dim + ) + if self.mlp_xi is not None: + self.opt_xi = ( + self.opt_xi if self.opt_xi is not None else + optax.adamw(learning_rate=1e-4, weight_decay=1e-10) + ) + self.state_xi = self.mlp_xi.create_train_state( + self._key, self.opt_xi, target_dim + cond_dim + ) + + def _get_step_fn(self) -> Callable: # type:ignore[type-arg] + + def loss_a_fn( + params_eta: Optional[jnp.ndarray], + apply_fn_eta: Callable[[Dict[str, jnp.ndarray], jnp.ndarray], + jnp.ndarray], + x: jnp.ndarray, + a: jnp.ndarray, + expectation_reweighting: float, + 
) -> Tuple[float, jnp.ndarray]: + eta_predictions = apply_fn_eta({"params": params_eta}, x) + return ( + optax.l2_loss(eta_predictions[:, 0], a).mean() + + optax.l2_loss(jnp.mean(eta_predictions) - expectation_reweighting), + eta_predictions, + ) + + def loss_b_fn( + params_xi: Optional[jnp.ndarray], + apply_fn_xi: Callable[[Dict[str, jnp.ndarray], jnp.ndarray], + jnp.ndarray], + x: jnp.ndarray, + b: jnp.ndarray, + expectation_reweighting: float, + ) -> Tuple[float, jnp.ndarray]: + xi_predictions = apply_fn_xi({"params": params_xi}, x) + return ( + optax.l2_loss(xi_predictions[:, 0], b).mean() + + optax.l2_loss(jnp.mean(xi_predictions) - expectation_reweighting), + xi_predictions, + ) + + @jax.jit + def step_fn( + source: jnp.ndarray, + target: jnp.ndarray, + condition: Optional[jnp.ndarray], + a: jnp.ndarray, + b: jnp.ndarray, + state_eta: Optional[train_state.TrainState] = None, + state_xi: Optional[train_state.TrainState] = None, + *, + is_training: bool = True, + ): + if condition is None: + input_source = source + input_target = target + else: + input_source = jnp.concatenate([source, condition], axis=-1) + input_target = jnp.concatenate([target, condition], axis=-1) + if state_eta is not None: + grad_a_fn = jax.value_and_grad(loss_a_fn, argnums=0, has_aux=True) + (loss_a, eta_predictions), grads_eta = grad_a_fn( + state_eta.params, + state_eta.apply_fn, + input_source, + a * len(a), + jnp.sum(b), + ) + new_state_eta = state_eta.apply_gradients( + grads=grads_eta + ) if is_training else None + + else: + new_state_eta = eta_predictions = loss_a = None + if state_xi is not None: + grad_b_fn = jax.value_and_grad(loss_b_fn, argnums=0, has_aux=True) + (loss_b, xi_predictions), grads_xi = grad_b_fn( + state_xi.params, + state_xi.apply_fn, + input_target, + b * len(b), + jnp.sum(a), + ) + new_state_xi = state_xi.apply_gradients( + grads=grads_xi + ) if is_training else None + else: + new_state_xi = xi_predictions = loss_b = None + + return new_state_eta, 
new_state_xi, eta_predictions, xi_predictions, loss_a, loss_b + + return step_fn diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index 20158042a..57a4d8874 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import costs, pointcloud from ott.solvers import linear diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py index eb80735f2..c242b192f 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -21,6 +21,7 @@ from jax.experimental import sparse from networkx.algorithms import shortest_paths from networkx.generators import balanced_tree, random_graphs + from ott.geometry import geometry, graph from ott.problems.linear import linear_problem from ott.solvers.linear import implicit_differentiation as implicit_lib diff --git a/tests/geometry/low_rank_test.py b/tests/geometry/low_rank_test.py index e0d937b35..87dd98db2 100644 --- a/tests/geometry/low_rank_test.py +++ b/tests/geometry/low_rank_test.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import costs, geometry, grid, low_rank, pointcloud diff --git a/tests/geometry/pointcloud_test.py b/tests/geometry/pointcloud_test.py index 17c0ac7aa..5f75ddb8e 100644 --- a/tests/geometry/pointcloud_test.py +++ b/tests/geometry/pointcloud_test.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import costs, geometry, pointcloud diff --git a/tests/geometry/scaling_cost_test.py b/tests/geometry/scaling_cost_test.py index 9cb905cfa..94ce97cf4 100644 --- a/tests/geometry/scaling_cost_test.py +++ b/tests/geometry/scaling_cost_test.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import geometry, low_rank, pointcloud from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn, 
sinkhorn_lr diff --git a/tests/geometry/subsetting_test.py b/tests/geometry/subsetting_test.py index 579180c8c..c07929436 100644 --- a/tests/geometry/subsetting_test.py +++ b/tests/geometry/subsetting_test.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import geometry, low_rank, pointcloud Geom_t = Union[pointcloud.PointCloud, geometry.Geometry, low_rank.LRCGeometry] diff --git a/tests/initializers/linear/sinkhorn_init_test.py b/tests/initializers/linear/sinkhorn_init_test.py index 0b5979c05..6acf77f11 100644 --- a/tests/initializers/linear/sinkhorn_init_test.py +++ b/tests/initializers/linear/sinkhorn_init_test.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import geometry, pointcloud from ott.initializers.linear import initializers as linear_init from ott.problems.linear import linear_problem diff --git a/tests/initializers/linear/sinkhorn_lr_init_test.py b/tests/initializers/linear/sinkhorn_lr_init_test.py index b71ff0aac..f3fe7acd1 100644 --- a/tests/initializers/linear/sinkhorn_lr_init_test.py +++ b/tests/initializers/linear/sinkhorn_lr_init_test.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import geometry, pointcloud from ott.initializers.linear import initializers_lr from ott.problems.linear import linear_problem diff --git a/tests/initializers/quadratic/gw_init_test.py b/tests/initializers/quadratic/gw_init_test.py index 7298bfafe..4c39bafb4 100644 --- a/tests/initializers/quadratic/gw_init_test.py +++ b/tests/initializers/quadratic/gw_init_test.py @@ -14,6 +14,7 @@ import jax import numpy as np import pytest + from ott.geometry import pointcloud from ott.initializers.linear import initializers as lin_init from ott.initializers.linear import initializers_lr diff --git a/tests/math/lse_test.py b/tests/math/lse_test.py index 342726ebe..b842afe21 100644 --- a/tests/math/lse_test.py +++ 
b/tests/math/lse_test.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.math import utils as mu diff --git a/tests/math/math_utils_test.py b/tests/math/math_utils_test.py index 5a5e3a69a..b8451355b 100644 --- a/tests/math/math_utils_test.py +++ b/tests/math/math_utils_test.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.math import utils as mu diff --git a/tests/math/matrix_square_root_test.py b/tests/math/matrix_square_root_test.py index 2bee2ea70..fcd557957 100644 --- a/tests/math/matrix_square_root_test.py +++ b/tests/math/matrix_square_root_test.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.math import matrix_square_root diff --git a/tests/neural/icnn_test.py b/tests/neural/icnn_test.py index fabc4f422..fd6c07f2b 100644 --- a/tests/neural/icnn_test.py +++ b/tests/neural/icnn_test.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.neural.models import models diff --git a/tests/neural/losses_test.py b/tests/neural/losses_test.py index f675dbc76..8cff7bd64 100644 --- a/tests/neural/losses_test.py +++ b/tests/neural/losses_test.py @@ -15,6 +15,7 @@ import jax import numpy as np import pytest + from ott.geometry import costs from ott.neural.models import models from ott.neural.solvers import losses diff --git a/tests/neural/map_estimator_test.py b/tests/neural/map_estimator_test.py index 680cfcd01..7c506aa38 100644 --- a/tests/neural/map_estimator_test.py +++ b/tests/neural/map_estimator_test.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import pytest + from ott import datasets from ott.geometry import pointcloud from ott.neural.models import models diff --git a/tests/neural/meta_initializer_test.py b/tests/neural/meta_initializer_test.py index f9b1e4cd0..f711366ec 100644 --- a/tests/neural/meta_initializer_test.py +++ b/tests/neural/meta_initializer_test.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import 
pytest from flax import linen as nn + from ott.geometry import pointcloud from ott.initializers.linear import initializers as linear_init from ott.neural.models import models as nn_init diff --git a/tests/neural/neuraldual_test.py b/tests/neural/neuraldual_test.py index c1aed055d..252bf817a 100644 --- a/tests/neural/neuraldual_test.py +++ b/tests/neural/neuraldual_test.py @@ -16,6 +16,7 @@ import jax import numpy as np import pytest + from ott import datasets from ott.neural.models import conjugate_solvers, models from ott.neural.solvers import neuraldual diff --git a/tests/problems/linear/potentials_test.py b/tests/problems/linear/potentials_test.py index 619537297..c9fa9cf17 100644 --- a/tests/problems/linear/potentials_test.py +++ b/tests/problems/linear/potentials_test.py @@ -17,6 +17,7 @@ import matplotlib.pyplot as plt import numpy as np import pytest + from ott.geometry import costs, pointcloud from ott.problems.linear import linear_problem, potentials from ott.solvers.linear import sinkhorn diff --git a/tests/solvers/linear/continuous_barycenter_test.py b/tests/solvers/linear/continuous_barycenter_test.py index de4724200..5512263c7 100644 --- a/tests/solvers/linear/continuous_barycenter_test.py +++ b/tests/solvers/linear/continuous_barycenter_test.py @@ -18,6 +18,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import costs, segment from ott.problems.linear import barycenter_problem from ott.solvers.linear import continuous_barycenter as cb diff --git a/tests/solvers/linear/discrete_barycenter_test.py b/tests/solvers/linear/discrete_barycenter_test.py index 9e31d85d0..dc90e15c0 100644 --- a/tests/solvers/linear/discrete_barycenter_test.py +++ b/tests/solvers/linear/discrete_barycenter_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
import jax.numpy as jnp import pytest + from ott.geometry import grid, pointcloud from ott.problems.linear import barycenter_problem as bp from ott.solvers.linear import discrete_barycenter as db diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index 892233127..d80f94251 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -18,6 +18,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import costs, geometry, grid, pointcloud from ott.problems.linear import linear_problem from ott.solvers.linear import implicit_differentiation as implicit_lib diff --git a/tests/solvers/linear/sinkhorn_grid_test.py b/tests/solvers/linear/sinkhorn_grid_test.py index 925af7278..b2aa4da3e 100644 --- a/tests/solvers/linear/sinkhorn_grid_test.py +++ b/tests/solvers/linear/sinkhorn_grid_test.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import grid, pointcloud from ott.problems.linear import linear_problem from ott.solvers import linear diff --git a/tests/solvers/linear/sinkhorn_lr_test.py b/tests/solvers/linear/sinkhorn_lr_test.py index 975570ac9..9b360bdf0 100644 --- a/tests/solvers/linear/sinkhorn_lr_test.py +++ b/tests/solvers/linear/sinkhorn_lr_test.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import low_rank, pointcloud from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn_lr diff --git a/tests/solvers/linear/sinkhorn_misc_test.py b/tests/solvers/linear/sinkhorn_misc_test.py index 8fd2623e5..aeb37918b 100644 --- a/tests/solvers/linear/sinkhorn_misc_test.py +++ b/tests/solvers/linear/sinkhorn_misc_test.py @@ -18,6 +18,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import costs, geometry, pointcloud from ott.problems.linear import linear_problem from ott.solvers import linear diff 
--git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index 2ff49b57e..ce7f9919a 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -19,6 +19,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott import utils from ott.geometry import costs, epsilon_scheduler, geometry, grid, pointcloud from ott.problems.linear import linear_problem diff --git a/tests/solvers/linear/univariate_test.py b/tests/solvers/linear/univariate_test.py index 221f295cd..1a5529167 100644 --- a/tests/solvers/linear/univariate_test.py +++ b/tests/solvers/linear/univariate_test.py @@ -16,6 +16,7 @@ import numpy as np import pytest import scipy as sp + from ott.geometry import costs, pointcloud from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn, univariate diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index acc40ba36..0a2a2fff4 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import geometry, low_rank, pointcloud from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import implicit_differentiation as implicit_lib diff --git a/tests/solvers/quadratic/gw_barycenter_test.py b/tests/solvers/quadratic/gw_barycenter_test.py index d07247fef..6bc843477 100644 --- a/tests/solvers/quadratic/gw_barycenter_test.py +++ b/tests/solvers/quadratic/gw_barycenter_test.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import pointcloud from ott.problems.quadratic import gw_barycenter as gwb from ott.solvers.quadratic import gw_barycenter as gwb_solver diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py index 192de8ed6..e7b77cd58 100644 --- a/tests/solvers/quadratic/gw_test.py +++ 
b/tests/solvers/quadratic/gw_test.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import geometry, low_rank, pointcloud from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import implicit_differentiation as implicit_lib diff --git a/tests/solvers/quadratic/lower_bound_test.py b/tests/solvers/quadratic/lower_bound_test.py index 6a15bd20a..ba90d6362 100644 --- a/tests/solvers/quadratic/lower_bound_test.py +++ b/tests/solvers/quadratic/lower_bound_test.py @@ -19,6 +19,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import costs, pointcloud from ott.initializers.linear import initializers from ott.problems.quadratic import quadratic_problem diff --git a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py index 98b7619ea..8f43eaa4e 100644 --- a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py +++ b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py @@ -14,6 +14,7 @@ import jax import jax.numpy as jnp import pytest + from ott.tools.gaussian_mixture import ( fit_gmm, fit_gmm_pair, diff --git a/tests/tools/gaussian_mixture/fit_gmm_test.py b/tests/tools/gaussian_mixture/fit_gmm_test.py index 18c930740..e39633b19 100644 --- a/tests/tools/gaussian_mixture/fit_gmm_test.py +++ b/tests/tools/gaussian_mixture/fit_gmm_test.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import jax.test_util import pytest + from ott.tools.gaussian_mixture import fit_gmm, gaussian_mixture diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py index ce57fa533..ccf1e50cd 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.tools.gaussian_mixture import gaussian_mixture, gaussian_mixture_pair diff --git 
a/tests/tools/gaussian_mixture/gaussian_mixture_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_test.py index 3e6fcde83..af52860be 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_test.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.tools.gaussian_mixture import gaussian_mixture, linalg diff --git a/tests/tools/gaussian_mixture/gaussian_test.py b/tests/tools/gaussian_mixture/gaussian_test.py index 1d05d5056..8b720861c 100644 --- a/tests/tools/gaussian_mixture/gaussian_test.py +++ b/tests/tools/gaussian_mixture/gaussian_test.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.tools.gaussian_mixture import gaussian, scale_tril diff --git a/tests/tools/gaussian_mixture/linalg_test.py b/tests/tools/gaussian_mixture/linalg_test.py index b92552a23..4db928264 100644 --- a/tests/tools/gaussian_mixture/linalg_test.py +++ b/tests/tools/gaussian_mixture/linalg_test.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.tools.gaussian_mixture import linalg diff --git a/tests/tools/gaussian_mixture/probabilities_test.py b/tests/tools/gaussian_mixture/probabilities_test.py index 9d51be1a4..4924924df 100644 --- a/tests/tools/gaussian_mixture/probabilities_test.py +++ b/tests/tools/gaussian_mixture/probabilities_test.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.tools.gaussian_mixture import probabilities diff --git a/tests/tools/gaussian_mixture/scale_tril_test.py b/tests/tools/gaussian_mixture/scale_tril_test.py index 049f9a043..3e53fd543 100644 --- a/tests/tools/gaussian_mixture/scale_tril_test.py +++ b/tests/tools/gaussian_mixture/scale_tril_test.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.math import matrix_square_root from ott.tools.gaussian_mixture import scale_tril diff --git 
a/tests/tools/k_means_test.py b/tests/tools/k_means_test.py index b77f4ce5f..a36c4b5c1 100644 --- a/tests/tools/k_means_test.py +++ b/tests/tools/k_means_test.py @@ -19,12 +19,13 @@ import jax.numpy as jnp import numpy as np import pytest -from ott.geometry import costs, pointcloud -from ott.tools import k_means from sklearn import datasets from sklearn.cluster import KMeans, kmeans_plusplus from sklearn.cluster._k_means_common import _is_same_clustering +from ott.geometry import costs, pointcloud +from ott.tools import k_means + def make_blobs( *args: Any, diff --git a/tests/tools/plot_test.py b/tests/tools/plot_test.py index 80e374bb6..8c8b81a1c 100644 --- a/tests/tools/plot_test.py +++ b/tests/tools/plot_test.py @@ -14,6 +14,7 @@ import jax import matplotlib.pyplot as plt + import ott from ott.geometry import pointcloud from ott.problems.linear import linear_problem diff --git a/tests/tools/segment_sinkhorn_test.py b/tests/tools/segment_sinkhorn_test.py index 2e56af4c3..119dbf93a 100644 --- a/tests/tools/segment_sinkhorn_test.py +++ b/tests/tools/segment_sinkhorn_test.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import costs, pointcloud from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn diff --git a/tests/tools/sinkhorn_divergence_test.py b/tests/tools/sinkhorn_divergence_test.py index 1c180fe31..d46c220d0 100644 --- a/tests/tools/sinkhorn_divergence_test.py +++ b/tests/tools/sinkhorn_divergence_test.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import numpy as np import pytest + from ott.geometry import costs, geometry, pointcloud from ott.solvers import linear from ott.solvers.linear import acceleration diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py index 735764f8a..372420a9e 100644 --- a/tests/tools/soft_sort_test.py +++ b/tests/tools/soft_sort_test.py @@ -18,6 +18,7 @@ import jax.numpy as jnp import numpy as np import pytest + from 
ott.solvers.linear import acceleration from ott.solvers.linear import implicit_differentiation as implicit_lib from ott.tools import soft_sort diff --git a/tests/utils_test.py b/tests/utils_test.py index 768a498b5..192ed59f4 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -14,6 +14,7 @@ from typing import Optional import pytest + from ott import utils From 3706970e21a6732eb1bd7bb659be752218f460ee Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 22 Nov 2023 10:46:49 +0100 Subject: [PATCH 002/186] draft of BaseSolver and UnbalancedMixin --- src/ott/neural/solvers/base_solver.py | 14 ++-- src/ott/neural/solvers/flow_matching.py | 100 ++++++++++++++++++++++++ src/ott/neural/solvers/neuraldual.py | 4 + 3 files changed, 111 insertions(+), 7 deletions(-) create mode 100644 src/ott/neural/solvers/flow_matching.py diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index cd12684e0..caca5e732 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -2,13 +2,13 @@ from pathlib import Path from types import Mapping, MappingProxyType from typing import ( - Any, - Callable, - Dict, - Literal, - Optional, - Tuple, - Union, + Any, + Callable, + Dict, + Literal, + Optional, + Tuple, + Union, ) import jax diff --git a/src/ott/neural/solvers/flow_matching.py b/src/ott/neural/solvers/flow_matching.py new file mode 100644 index 000000000..98c05cb8f --- /dev/null +++ b/src/ott/neural/solvers/flow_matching.py @@ -0,0 +1,100 @@ +from typing import Any, Callable, Dict, Optional, Type + +import jax.numpy as jnp +import orbax as obx + +from ott.geometry import costs +from ott.neural.models.models import BaseNeuralVectorField +from ott.neural.solver.base_solver import BaseNeuralSolver, UnbalancednessMixin +from ott.solvers import was_solver + + +class FlowMatching(BaseNeuralSolver, UnbalancednessMixin): + + def __init__( + self, + neural_vector_field: Type[BaseNeuralVectorField], + input_dim: 
int, + iterations: int, + valid_freq: int, + ot_solver: Type[was_solver.WassersteinSolver], + optimizer: Optional[Any] = None, + checkpoint_manager: Type[obx.CheckpointManager] = None, + epsilon: float = 1e-2, + cost_fn: Type[costs.CostFn] = costs.SqEuclidean(), + tau_a: float = 1.0, + tau_b: float = 1.0, + mlp_eta: Callable[[jnp.ndarray], float] = None, + mlp_xi: Callable[[jnp.ndarray], float] = None, + unbalanced_kwargs: Dict[str, Any] = {}, + callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], + Any]] = None, + seed: int = 0, + **kwargs: Any, + ) -> None: + + super().__init__(iterations=iterations, valid_freq=valid_freq) + super(UnbalancednessMixin, self).__init__( + mlp_eta=mlp_eta, + mlp_xi=mlp_xi, + tau_a=tau_a, + tau_b=tau_b, + **unbalanced_kwargs + ) + self.neural_vector_field = neural_vector_field + self.input_dim = input_dim + self.ot_solver = ot_solver + self.optimizer = optimizer + self.epsilon = epsilon + self.cost_fn = cost_fn + self.callback_fn = callback_fn + self.checkpoint_manager = checkpoint_manager + self.seed = seed + + def setup(self, **kwargs: Any) -> None: + self.state_neural_vector_field = self.neural_vector_field.create_train_state( + self.rng, self.optimizer, self.output_dim + ) + + self.step_fn = self._get_step_fn() + + self.match_fn = self._get_match_fn( + self.ot_solver, + epsilon=self.epsilon, + cost_fn=self.cost_fn, + tau_a=self.tau_a, + tau_b=self.tau_b, + scale_cost=self.scale_cost, + ) + + def _get_match_fn(self): + pass + + def __call__(self, train_loader, valid_loader) -> None: + for iter in range(self.iterations): + batch = next(train_loader) + batch, a, b = self.match_fn(batch) + self.state_neural_vector_field, logs = self.step_fn( + self.state_neural_vector_field, batch + ) + if not self.is_balanced: + self.unbalancedness_step_fn(batch, a, b) + if iter % self.valid_freq == 0: + self._valid_step(valid_loader, iter) + if self.checkpoint_manager is not None: + states_to_save = { + 
"state_neural_vector_field": self.state_neural_vector_field + } + if self.state_mlp is not None: + states_to_save["state_eta"] = self.state_mlp + if self.state_xi is not None: + states_to_save["state_xi"] = self.state_xi + self.checkpoint_manager.save(iter, states_to_save) + + def _valid_step(self, valid_loader, iter) -> None: + batch = next(valid_loader) + batch, a, b = self.match_fn(batch) + if not self.is_balanced: + self.unbalancedness_step_fn(batch, a, b) + if self.callback_fn is not None: + self.callback_fn(batch, a, b) diff --git a/src/ott/neural/solvers/neuraldual.py b/src/ott/neural/solvers/neuraldual.py index fef8f873c..6ac3d1c79 100644 --- a/src/ott/neural/solvers/neuraldual.py +++ b/src/ott/neural/solvers/neuraldual.py @@ -68,6 +68,10 @@ class W2NeuralTrainState(train_state.TrainState): ) +class BaseNeuralVectorField(nn.Module): + pass + + class BaseW2NeuralDual(abc.ABC, nn.Module): """Base class for the neural solver models.""" From 42dd2b8263e05cc2ba01c169d2c82f73b06d590e Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 22 Nov 2023 14:12:42 +0100 Subject: [PATCH 003/186] [ci skip] continue flow matching implementation --- src/ott/neural/data/dataloaders.py | 22 +++++ src/ott/neural/solvers/base_solver.py | 33 ++++++- src/ott/neural/solvers/flow_matching.py | 117 ++++++++++++++++++++---- src/ott/neural/solvers/flows.py | 56 ++++++++++++ 4 files changed, 209 insertions(+), 19 deletions(-) create mode 100644 src/ott/neural/data/dataloaders.py create mode 100644 src/ott/neural/solvers/flows.py diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py new file mode 100644 index 000000000..8ea1f5571 --- /dev/null +++ b/src/ott/neural/data/dataloaders.py @@ -0,0 +1,22 @@ +from typing import Dict + +import jax +import jax.numpy as jnp +import tensorflow as tf + + +class ConditionalDataLoader: + + def __init__( + self, rng: jax.random.KeyArray, dataloaders: Dict[str, tf.Dataloader], + p: jax.Array + ) -> None: + 
super().__init__() + self.rng = rng + self.conditions = dataloaders.keys() + self.p = p + + def __next__(self) -> jnp.ndarray: + self.rng, rng = jax.random.split(self.rng, 2) + condition = jax.random.choice(rng, self.conditions, p=self.p) + return next(self.dataloaders[condition]) diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index caca5e732..045e8bc86 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import optax from flax import train_state +from jax import random from ott.geometry.pointcloud import PointCloud from ott.neural.solvers import models @@ -61,6 +62,36 @@ def training_logs(self) -> Dict[str, Any]: pass +class ResampleMixin: + + def _resample_data( + self, + key: jax.random.KeyArray, + tmat: jnp.ndarray, + source_arrays: Tuple[jnp.ndarray, ...], + target_arrays: Tuple[jnp.ndarray, ...], + ) -> Tuple[jnp.ndarray, ...]: + """Resample a batch according to coupling `tmat`.""" + transition_matrix = tmat.flatten() + indices = random.choice( + key, transition_matrix.flatten(), shape=[len(transition_matrix) ** 2] + ) + indices_source = indices // self.batch_size + indices_target = indices % self.batch_size + return tuple( + b[indices_source] if b is not None else None for b in source_arrays + ), tuple( + b[indices_target] if b is not None else None for b in target_arrays + ) + + def _resample_data_conditionally( + self, + *args: Any, + **kwargs: Any, + ): + raise NotImplementedError + + class UnbalancednessMixin: def __init__( @@ -132,7 +163,7 @@ def compute_unbalanced_marginals( return compute_unbalanced_marginals @jax.jit - def _resample( + def _resample_unbalanced( self, key: jax.random.KeyArray, batch: Tuple[jnp.ndarray, ...], diff --git a/src/ott/neural/solvers/flow_matching.py b/src/ott/neural/solvers/flow_matching.py index 98c05cb8f..966a1f3b0 100644 --- a/src/ott/neural/solvers/flow_matching.py +++ 
b/src/ott/neural/solvers/flow_matching.py @@ -1,15 +1,27 @@ -from typing import Any, Callable, Dict, Optional, Type +import functools +from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type +import jax import jax.numpy as jnp import orbax as obx +from jax import random -from ott.geometry import costs +from ott.geometry import costs, pointcloud from ott.neural.models.models import BaseNeuralVectorField -from ott.neural.solver.base_solver import BaseNeuralSolver, UnbalancednessMixin +from ott.neural.solver.base_solver import ( + BaseNeuralSolver, + MatchMixin, + UnbalancednessMixin, +) +from ott.neural.solvers.flows import ( + BaseFlow, + ConstantNoiseFlow, +) +from ott.problems.linear import linear_problem from ott.solvers import was_solver -class FlowMatching(BaseNeuralSolver, UnbalancednessMixin): +class FlowMatching(BaseNeuralSolver, MatchMixin, UnbalancednessMixin): def __init__( self, @@ -18,6 +30,7 @@ def __init__( iterations: int, valid_freq: int, ot_solver: Type[was_solver.WassersteinSolver], + flow: Type[BaseFlow] = ConstantNoiseFlow(0), optimizer: Optional[Any] = None, checkpoint_manager: Type[obx.CheckpointManager] = None, epsilon: float = 1e-2, @@ -32,18 +45,21 @@ def __init__( seed: int = 0, **kwargs: Any, ) -> None: - - super().__init__(iterations=iterations, valid_freq=valid_freq) - super(UnbalancednessMixin, self).__init__( - mlp_eta=mlp_eta, - mlp_xi=mlp_xi, + super().__init__( + iterations=iterations, + valid_freq=valid_freq, tau_a=tau_a, tau_b=tau_b, - **unbalanced_kwargs + mlp_eta=mlp_eta, + mlp_xi=mlp_xi, + unbalanced_kwargs=unbalanced_kwargs, + **kwargs ) + self.neural_vector_field = neural_vector_field self.input_dim = input_dim self.ot_solver = ot_solver + self.flow = flow self.optimizer = optimizer self.epsilon = epsilon self.cost_fn = cost_fn @@ -57,7 +73,6 @@ def setup(self, **kwargs: Any) -> None: ) self.step_fn = self._get_step_fn() - self.match_fn = self._get_match_fn( self.ot_solver, epsilon=self.epsilon, @@ -67,18 
+82,80 @@ def setup(self, **kwargs: Any) -> None: scale_cost=self.scale_cost, ) - def _get_match_fn(self): - pass + def _get_step_fn(self) -> Callable: + + def step_fn( + key: random.PRNGKeyArray, + state_neural_vector_field: Any, + batch: Dict[str, jnp.ndarray], + ) -> Tuple[Any, Any]: + + def loss_fn( + params: jax.Array, t: jax.Array, noise: jax.Array, + batch: Dict[str, jnp.ndarray], keys_model: random.PRNGKeyArray + ) -> jnp.ndarray: + + x_t = self.flow.compute_xt(noise, t, batch["source"], batch["target"]) + apply_fn = functools.partial( + state_neural_vector_field.apply, {"params": params} + ) + v_t = jax.vmap(apply_fn)( + t=t, x_t=x_t, condition=batch["condition"], keys_model=keys_model + ) + u_t = self.flow.compute_ut(t, batch["source"], batch["target"]) + return jnp.mean((v_t - u_t) ** 2) + + batch_size = len(batch["source"]) + key_noise, key_t, key_model = random.split(key, 3) + keys_model = random.split(key_model, batch_size) + t = self.sample_t(key_t, batch_size) + noise = self.sample_noise(key_noise, batch_size) + loss_grad = jax.value_and_grad(loss_fn) + loss, grads = loss_grad( + state_neural_vector_field.params, t, noise, batch, keys_model + ) + return state_neural_vector_field.apply_gradients(grads), loss + + return step_fn + + def _get_match_fn( + self, + ot_solver: Any, + epsilon: float, + cost_fn: str, + tau_a: float, + tau_b: float, + scale_cost: Any, + ) -> Callable: + + def match_pairs( + x: jax.Array, y: jax.Array + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + geom = pointcloud.PointCloud( + x, y, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn + ) + return ot_solver( + linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) + ).matrix + + return match_pairs def __call__(self, train_loader, valid_loader) -> None: + batch: Mapping[str, jnp.ndarray] = {} for iter in range(self.iterations): - batch = next(train_loader) - batch, a, b = self.match_fn(batch) - self.state_neural_vector_field, logs = self.step_fn( 
+ batch["source"], batch["target"], batch["condition"] = next(train_loader) + tmat = self.match_fn(batch) + batch = self.resample( + batch, tmat, (batch["source"], batch["condition"]), + (batch["target"], batch["condition"]) + ) + self.state_neural_vector_field, loss = self.step_fn( self.state_neural_vector_field, batch ) - if not self.is_balanced: - self.unbalancedness_step_fn(batch, a, b) + if self.learn_rescaling: + self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, loss_b = self.unbalancedness_step_fn( + batch, tmat.sum(axis=1), tmat.sum(axis=0) + ) if iter % self.valid_freq == 0: self._valid_step(valid_loader, iter) if self.checkpoint_manager is not None: @@ -98,3 +175,7 @@ def _valid_step(self, valid_loader, iter) -> None: self.unbalancedness_step_fn(batch, a, b) if self.callback_fn is not None: self.callback_fn(batch, a, b) + + @property + def learn_rescaling(self) -> bool: + return self.mlp_eta is not None or self.mlp_xi is not None diff --git a/src/ott/neural/solvers/flows.py b/src/ott/neural/solvers/flows.py new file mode 100644 index 000000000..b8c635f61 --- /dev/null +++ b/src/ott/neural/solvers/flows.py @@ -0,0 +1,56 @@ +import abc + +import jax +import jax.numpy as jnp + + +class BaseFlow(abc.ABC): + + def __init__(self, sigma: float) -> None: + self.sigma = sigma + + @abc.abstractmethod + def compute_mu_t(self, t: jax.Array, x_0: jax.Array, x_1: jax.Array): + pass + + @abc.abstractmethod + def compute_sigma_t(self, t: jax.Array): + pass + + @abc.abstractmethod + def compute_ut( + self, t: jax.Array, x_0: jax.Array, x_1: jax.Array + ) -> jax.Array: + pass + + def compute_xt( + self, noise: jax.Array, t: jax.Array, x_0: jax.Array, x_1: jax.Array + ) -> jax.Array: + mu_t = self.compute_mu_t(t, x_0, x_1) + sigma_t = self.compute_sigma_t(t, x_0, x_1) + return mu_t + sigma_t * noise + + +class StraightFlow(BaseFlow): + + def compute_mu_t( + self, t: jax.Array, x_0: jax.Array, x_1: jax.Array + ) -> jax.Array: + return t * x_0 + (1 - t) 
* x_1 + + def compute_ut( + self, t: jax.Array, x_0: jax.Array, x_1: jax.Array + ) -> jax.Array: + return x_1 - x_0 + + +class ConstantNoiseFlow(StraightFlow): + + def compute_sigma_t(self, t: jax.Array): + return self.sigma + + +class BrownianNoiseFlow(StraightFlow): + + def compute_sigma_t(self, t: jax.Array): + return jnp.sqrt(self.sigma * t * (1 - t)) From 65832815ae3b1706fa61dd43b8194ec80a635bed Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 22 Nov 2023 14:35:36 +0100 Subject: [PATCH 004/186] [ci skip] continue flow matching implementation --- src/ott/neural/models/models.py | 12 ++++++++ src/ott/neural/solvers/base_solver.py | 5 +++ src/ott/neural/solvers/flow_matching.py | 41 ++++++++++++++++++++++--- tests/neural/test_flow_matching.py | 4 +++ 4 files changed, 57 insertions(+), 5 deletions(-) create mode 100644 tests/neural/test_flow_matching.py diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index c5485ef59..1983075b0 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import abc import functools from typing import Any, Callable, Dict, Optional, Sequence, Tuple @@ -403,3 +404,14 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 "rng": self.rng, "state": self.state } + + +class BaseNeuralVectorField(abc.ABC): + + def __call__( + self, + t: jax.Array, + condition: Optional[jax.Array] = None, + keys_model: Optional[jax.Array] = None + ) -> jnp.ndarray: # noqa: D102): + pass diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index 045e8bc86..bb0a30f22 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -44,6 +44,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: """Train the model.""" pass + @abstractmethod + def transport(self, *args: Any, forward: bool, **kwargs: Any) -> Any: + """Transport.""" + pass + @abstractmethod def save(self, path: Path): """Save the model.""" diff --git a/src/ott/neural/solvers/flow_matching.py b/src/ott/neural/solvers/flow_matching.py index 966a1f3b0..e573d8444 100644 --- a/src/ott/neural/solvers/flow_matching.py +++ b/src/ott/neural/solvers/flow_matching.py @@ -1,6 +1,8 @@ import functools +import types from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type +import diffrax import jax import jax.numpy as jnp import orbax as obx @@ -168,13 +170,42 @@ def __call__(self, train_loader, valid_loader) -> None: states_to_save["state_xi"] = self.state_xi self.checkpoint_manager.save(iter, states_to_save) + def transport( + self, + data: jnp.array, + condition: Optional[jax.Array], + rng: random.PRNGKey, + forward: bool = True, + diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}) + ) -> diffrax.Solution: + diffeqsolve_kwargs = dict(diffeqsolve_kwargs) + t0, t1 = (0, 1) if forward else (1, 0) + return diffrax.diffeqsolve( + diffrax.ODETerm( + lambda t, y: self.state_neural_vector_field. 
+ apply({"params": self.state_neural_vector_field.params}, + t=t, + x=y, + condition=condition) + ), + diffeqsolve_kwargs.pop("solver", diffrax.Tsit5()), + t0=t0, + t1=t1, + dt0=diffeqsolve_kwargs.pop("dt0", None), + y0=data, + stepsize_controller=diffeqsolve_kwargs.pop( + "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5) + ), + **diffeqsolve_kwargs, + ) + def _valid_step(self, valid_loader, iter) -> None: batch = next(valid_loader) - batch, a, b = self.match_fn(batch) - if not self.is_balanced: - self.unbalancedness_step_fn(batch, a, b) - if self.callback_fn is not None: - self.callback_fn(batch, a, b) + tmat = self.match_fn(batch) + batch = self.resample( + batch, tmat, (batch["source"], batch["condition"]), + (batch["target"], batch["condition"]) + ) @property def learn_rescaling(self) -> bool: diff --git a/tests/neural/test_flow_matching.py b/tests/neural/test_flow_matching.py new file mode 100644 index 000000000..ca55f901d --- /dev/null +++ b/tests/neural/test_flow_matching.py @@ -0,0 +1,4 @@ +class TestFlowMatching: + + def test_flow_matching(self): + pass From f5a043cd7b88b16702b164e8faf66793e28993f7 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 22 Nov 2023 16:15:32 +0100 Subject: [PATCH 005/186] [ci skip] add neural networks --- src/ott/neural/models/models.py | 211 +++++++++++++++++++++++- src/ott/neural/solvers/flow_matching.py | 10 +- 2 files changed, 215 insertions(+), 6 deletions(-) diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 1983075b0..2277f6145 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -15,7 +15,9 @@ import functools from typing import Any, Callable, Dict, Optional, Sequence, Tuple +import flax.linen as nn import jax +import jax.numpy as jnp import optax from flax import linen as nn from flax.core import frozen_dict @@ -30,6 +32,7 @@ from ott.neural.models import layers from ott.neural.solvers import neuraldual from ott.problems.linear 
import linear_problem +from ott.solvers.nn.models import NeuralTrainState __all__ = ["ICNN", "MLP", "MetaInitializer"] @@ -406,12 +409,218 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 } -class BaseNeuralVectorField(abc.ABC): +class Block(nn.Module): + dim: int = 128 + num_layers: int = 3 + activation_fn: Any = nn.silu + out_dim: int = 32 + @nn.compact + def __call__(self, x): + for i in range(self.num_layers): + x = nn.Dense(self.dim, name="fc{0}".format(i))(x) + x = self.activation_fn(x) + x = nn.Dense(self.out_dim, name="fc_final")(x) + return x + + +class BaseNeuralVectorField(nn.Module, abc.ABC): + + @abc.abstractmethod def __call__( self, t: jax.Array, + x: jax.Array, condition: Optional[jax.Array] = None, keys_model: Optional[jax.Array] = None ) -> jnp.ndarray: # noqa: D102): pass + + +class Block(nn.Module): + dim: int = 128 + out_dim: int = 32 + num_layers: int = 3 + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu + + @nn.compact + def __call__(self, x): + for i in range(self.num_layers): + x = nn.Dense(self.dim, name="fc{0}".format(i))(x) + x = self.act_fn(x) + x = nn.Dense(self.out_dim, name="fc_final")(x) + return x + + +class NeuralVectorField(BaseNeuralVectorField): + condition_dim: int + latent_embed_dim: int + condition_embed_dim: Optional[int] = None + t_embed_dim: Optional[int] = None + joint_hidden_dim: Optional[int] = None + num_layers_per_block: int = 3 + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu + n_frequencies: int = 128 + + def time_encoder(self, t: jax.Array) -> jnp.array: + freq = 2 * jnp.arange(self.n_frequencies) * jnp.pi + t = freq * t + return jnp.concatenate((jnp.cos(t), jnp.sin(t)), axis=-1) + + def __post_init__(self): + + # set embedded dim from latent embedded dim + if self.condition_embed_dim is None: + self.condition_embed_dim = self.latent_embed_dim + if self.t_embed_dim is None: + self.t_embed_dim = self.latent_embed_dim + + # set joint hidden dim from all embedded dim + 
concat_embed_dim = ( + self.latent_embed_dim + self.condition_embed_dim + self.t_embed_dim + ) + if self.joint_hidden_dim is not None: + assert (self.joint_hidden_dim >= concat_embed_dim), ( + "joint_hidden_dim must be greater than or equal to the sum of " + "all embedded dimensions. " + ) + self.joint_hidden_dim = self.latent_embed_dim + else: + self.joint_hidden_dim = concat_embed_dim + super().__post_init__() + + @nn.compact + def __call__( + self, + t: jax.Array, + condition: Optional[jax.Array], + latent: jax.Array, + keys_model: Optional[jax.Array] = None, + ) -> jax.Array: + + t = self.time_encoder(t) + t = Block( + dim=self.t_embed_dim, + out_dim=self.t_embed_dim, + num_layers=self.num_layers_per_block, + act_fn=self.act_fn, + )( + t + ) + + data = Block( + dim=self.latent_embed_dim, + out_dim=self.latent_embed_dim, + num_layers=self.num_layers_per_block, + act_fn=self.act_fn + )( + data + ) + + if self.condition_dim > 0: + condition = Block( + dim=self.condition_embed_dim, + out_dim=self.condition_embed_dim, + num_layers=self.num_layers_per_block, + act_fn=self.act_fn + )( + condition + ) + concatenated = jnp.concatenate((t, data, condition), axis=-1) + else: + concatenated = jnp.concatenate((t, data), axis=-1) + + out = Block( + dim=self.joint_hidden_dim, + out_dim=self.joint_hidden_dim, + num_layers=self.num_layers_per_block, + act_fn=self.act_fn, + )( + concatenated + ) + + return nn.Dense( + self.output_dim, + use_bias=True, + )( + out + ) + + def create_train_state( + self, + rng: jax.random.PRNGKeyArray, + optimizer: optax.OptState, + input_dim: int, + ) -> NeuralTrainState: + params = self.init( + rng, jnp.ones((1, 1)), jnp.ones((1, input_dim)), + jnp.ones((1, self.condition_dim)) + )["params"] + return train_state.TrainState.create( + apply_fn=self.apply, params=params, tx=optimizer + ) + + +class BaseRescalingNet(nn.Module, abc.ABC): + + @abc.abstractmethod + def __call___( + self, x: jax.Array, condition: Optional[jax.Array] = None + ) -> 
jax.Array: + pass + + +class Rescaling_MLP(nn.Module): + hidden_dim: int + cond_dim: int + is_potential: bool = False + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.selu + + @nn.compact + def __call__( + self, x: jnp.ndarray, condition: Optional[jax.Array] + ) -> jnp.ndarray: # noqa: D102 + x = Block( + dim=self.latent_embed_dim, + out_dim=self.latent_embed_dim, + num_layers=self.num_layers_per_block, + act_fn=self.act_fn + )( + x + ) + if self.condition_dim > 0: + condition = Block( + dim=self.condition_embed_dim, + out_dim=self.condition_embed_dim, + num_layers=self.num_layers_per_block, + act_fn=self.act_fn + )( + condition + ) + concatenated = jnp.concatenate((x, condition), axis=-1) + else: + concatenated = x + + out = Block( + dim=self.joint_hidden_dim, + out_dim=self.joint_hidden_dim, + num_layers=self.num_layers_per_block, + act_fn=self.act_fn, + )( + concatenated + ) + + return jnp.exp(out) + + def create_train_state( + self, + rng: jax.random.PRNGKeyArray, + optimizer: optax.OptState, + input_dim: int, + ) -> train_state.TrainState: + params = self.init( + rng, jnp.ones((1, input_dim)), jnp.ones((1, self.cond_dim)) + )["params"] + return train_state.TrainState.create( + apply_fn=self.apply, params=params, tx=optimizer + ) diff --git a/src/ott/neural/solvers/flow_matching.py b/src/ott/neural/solvers/flow_matching.py index e573d8444..2db7833ee 100644 --- a/src/ott/neural/solvers/flow_matching.py +++ b/src/ott/neural/solvers/flow_matching.py @@ -11,13 +11,13 @@ from ott.geometry import costs, pointcloud from ott.neural.models.models import BaseNeuralVectorField from ott.neural.solver.base_solver import ( - BaseNeuralSolver, - MatchMixin, - UnbalancednessMixin, + BaseNeuralSolver, + MatchMixin, + UnbalancednessMixin, ) from ott.neural.solvers.flows import ( - BaseFlow, - ConstantNoiseFlow, + BaseFlow, + ConstantNoiseFlow, ) from ott.problems.linear import linear_problem from ott.solvers import was_solver From 34bb10fe41185d52dbe514abb8f8d09df71877e2 Mon 
Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 22 Nov 2023 17:01:05 +0100 Subject: [PATCH 006/186] [ci skip] add test --- src/ott/neural/models/models.py | 30 ++++--------------- src/ott/neural/solvers/flow_matching.py | 6 ++-- tests/neural/conftest.py | 38 +++++++++++++++++++++++++ tests/neural/flow_matching_test.py | 27 ++++++++++++++++++ tests/neural/test_flow_matching.py | 4 --- 5 files changed, 74 insertions(+), 31 deletions(-) create mode 100644 tests/neural/conftest.py create mode 100644 tests/neural/flow_matching_test.py delete mode 100644 tests/neural/test_flow_matching.py diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 2277f6145..0d28ab9fe 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -19,10 +19,8 @@ import jax import jax.numpy as jnp import optax -from flax import linen as nn from flax.core import frozen_dict from flax.training import train_state -from jax import numpy as jnp from jax.nn import initializers from ott import utils @@ -420,8 +418,7 @@ def __call__(self, x): for i in range(self.num_layers): x = nn.Dense(self.dim, name="fc{0}".format(i))(x) x = self.activation_fn(x) - x = nn.Dense(self.out_dim, name="fc_final")(x) - return x + return nn.Dense(self.out_dim)(x) class BaseNeuralVectorField(nn.Module, abc.ABC): @@ -437,21 +434,6 @@ def __call__( pass -class Block(nn.Module): - dim: int = 128 - out_dim: int = 32 - num_layers: int = 3 - act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu - - @nn.compact - def __call__(self, x): - for i in range(self.num_layers): - x = nn.Dense(self.dim, name="fc{0}".format(i))(x) - x = self.act_fn(x) - x = nn.Dense(self.out_dim, name="fc_final")(x) - return x - - class NeuralVectorField(BaseNeuralVectorField): condition_dim: int latent_embed_dim: int @@ -493,8 +475,8 @@ def __post_init__(self): def __call__( self, t: jax.Array, + x: jax.Array, condition: Optional[jax.Array], - latent: jax.Array, keys_model: Optional[jax.Array] = 
None, ) -> jax.Array: @@ -508,13 +490,13 @@ def __call__( t ) - data = Block( + x = Block( dim=self.latent_embed_dim, out_dim=self.latent_embed_dim, num_layers=self.num_layers_per_block, act_fn=self.act_fn )( - data + x ) if self.condition_dim > 0: @@ -526,9 +508,9 @@ def __call__( )( condition ) - concatenated = jnp.concatenate((t, data, condition), axis=-1) + concatenated = jnp.concatenate((t, x, condition), axis=-1) else: - concatenated = jnp.concatenate((t, data), axis=-1) + concatenated = jnp.concatenate((t, x), axis=-1) out = Block( dim=self.joint_hidden_dim, diff --git a/src/ott/neural/solvers/flow_matching.py b/src/ott/neural/solvers/flow_matching.py index 2db7833ee..0ff68604e 100644 --- a/src/ott/neural/solvers/flow_matching.py +++ b/src/ott/neural/solvers/flow_matching.py @@ -5,6 +5,7 @@ import diffrax import jax import jax.numpy as jnp +import optax import orbax as obx from jax import random @@ -17,7 +18,6 @@ ) from ott.neural.solvers.flows import ( BaseFlow, - ConstantNoiseFlow, ) from ott.problems.linear import linear_problem from ott.solvers import was_solver @@ -32,8 +32,8 @@ def __init__( iterations: int, valid_freq: int, ot_solver: Type[was_solver.WassersteinSolver], - flow: Type[BaseFlow] = ConstantNoiseFlow(0), - optimizer: Optional[Any] = None, + flow: Type[BaseFlow], + optimizer: Type[optax.GradientTransformation], checkpoint_manager: Type[obx.CheckpointManager] = None, epsilon: float = 1e-2, cost_fn: Type[costs.CostFn] = costs.SqEuclidean(), diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py new file mode 100644 index 000000000..682aa5df8 --- /dev/null +++ b/tests/neural/conftest.py @@ -0,0 +1,38 @@ +from typing import Iterator + +import pytest + +from ott import datasets + + +class UnconditionalDataLoader: + + def __init__(self, iter: Iterator): + self.iter = iter + + def __next__(self): + return next(self.iter), None + + +@pytest.fixture(scope="module") +def data_loader_gaussian_1(): + """Returns a data loader for a simple 
Gaussian mixture.""" + loader = datasets.create_gaussian_mixture_samplers( + name_source="simple", + name_target="circle", + train_batch_size=30, + valid_batch_size=30, + ) + return UnconditionalDataLoader(loader[0]) + + +@pytest.fixture(scope="module") +def data_loader_gaussian_2(): + """Returns a data loader for a simple Gaussian mixture.""" + loader = datasets.create_gaussian_mixture_samplers( + name_source="simple", + name_target="circle", + train_batch_size=30, + valid_batch_size=30, + ) + return UnconditionalDataLoader(loader[0] + 1) diff --git a/tests/neural/flow_matching_test.py b/tests/neural/flow_matching_test.py new file mode 100644 index 000000000..2cab8608f --- /dev/null +++ b/tests/neural/flow_matching_test.py @@ -0,0 +1,27 @@ +import optax + +from ott.neural.flow_matching import FlowMatching +from ott.neural.flows import ConstantNoiseFlow +from ott.neural.models import NeuralVectorField +from ott.solvers.linear import sinkhorn + + +class TestFlowMatching: + + def test_flow_matching(self, data_loader_gaussian_1, data_loader_gaussian_2): + neural_vf = NeuralVectorField( + input_dim=2, hidden_dims=[32, 32], output_dim=2, activation="relu" + ) + ot_solver = sinkhorn.SinkhornSolver() + flow = ConstantNoiseFlow(sigma=0) + optimizer = optax.adam(learning_rate=1e-3) + fm = FlowMatching( + neural_vf, + input_dim=2, + iterations=3, + valid_freq=2, + ot_solver=ot_solver, + flow=flow, + optimizer=optimizer + ) + fm(data_loader_gaussian_1, data_loader_gaussian_2) diff --git a/tests/neural/test_flow_matching.py b/tests/neural/test_flow_matching.py deleted file mode 100644 index ca55f901d..000000000 --- a/tests/neural/test_flow_matching.py +++ /dev/null @@ -1,4 +0,0 @@ -class TestFlowMatching: - - def test_flow_matching(self): - pass From 374d05194018a869358aff01a6e4aaff9e2d35ed Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 22 Nov 2023 17:52:05 +0100 Subject: [PATCH 007/186] [ci skip] resolve import errors --- src/ott/neural/__init__.py | 2 +- 
src/ott/neural/data/__init__.py | 14 ++++++++++++ src/ott/neural/data/dataloaders.py | 30 ++++++++++++------------- src/ott/neural/models/models.py | 3 +-- src/ott/neural/solvers/base_solver.py | 26 ++++++++------------- src/ott/neural/solvers/flow_matching.py | 16 ++++++------- tests/neural/flow_matching_test.py | 6 ++--- 7 files changed, 50 insertions(+), 47 deletions(-) create mode 100644 src/ott/neural/data/__init__.py diff --git a/src/ott/neural/__init__.py b/src/ott/neural/__init__.py index f448c5dbe..326fae432 100644 --- a/src/ott/neural/__init__.py +++ b/src/ott/neural/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import models, solvers +from . import data, models, solvers diff --git a/src/ott/neural/data/__init__.py b/src/ott/neural/data/__init__.py new file mode 100644 index 000000000..51f8dd2af --- /dev/null +++ b/src/ott/neural/data/__init__.py @@ -0,0 +1,14 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from . 
import dataloaders diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py index 8ea1f5571..c8976d348 100644 --- a/src/ott/neural/data/dataloaders.py +++ b/src/ott/neural/data/dataloaders.py @@ -1,22 +1,20 @@ -from typing import Dict -import jax -import jax.numpy as jnp -import tensorflow as tf +#import tensorflow as tf class ConditionalDataLoader: + pass - def __init__( - self, rng: jax.random.KeyArray, dataloaders: Dict[str, tf.Dataloader], - p: jax.Array - ) -> None: - super().__init__() - self.rng = rng - self.conditions = dataloaders.keys() - self.p = p + #def __init__( + # self, rng: jax.random.KeyArray, dataloaders: Dict[str, tf.Dataloader], + # p: jax.Array + #) -> None: + # super().__init__() + # self.rng = rng + # self.conditions = dataloaders.keys() + # self.p = p - def __next__(self) -> jnp.ndarray: - self.rng, rng = jax.random.split(self.rng, 2) - condition = jax.random.choice(rng, self.conditions, p=self.p) - return next(self.dataloaders[condition]) + #def __next__(self) -> jnp.ndarray: + # self.rng, rng = jax.random.split(self.rng, 2) + # condition = jax.random.choice(rng, self.conditions, p=self.p) + # return next(self.dataloaders[condition]) diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 0d28ab9fe..18d86144c 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -30,7 +30,6 @@ from ott.neural.models import layers from ott.neural.solvers import neuraldual from ott.problems.linear import linear_problem -from ott.solvers.nn.models import NeuralTrainState __all__ = ["ICNN", "MLP", "MetaInitializer"] @@ -533,7 +532,7 @@ def create_train_state( rng: jax.random.PRNGKeyArray, optimizer: optax.OptState, input_dim: int, - ) -> NeuralTrainState: + ) -> train_state.TrainState: params = self.init( rng, jnp.ones((1, 1)), jnp.ones((1, input_dim)), jnp.ones((1, self.condition_dim)) diff --git a/src/ott/neural/solvers/base_solver.py 
b/src/ott/neural/solvers/base_solver.py index bb0a30f22..ed716b50c 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -1,24 +1,16 @@ from abc import ABC, abstractmethod from pathlib import Path -from types import Mapping, MappingProxyType -from typing import ( - Any, - Callable, - Dict, - Literal, - Optional, - Tuple, - Union, -) +from types import MappingProxyType +from typing import Any, Callable, Dict, Literal, Mapping, Optional, Tuple, Union import jax import jax.numpy as jnp import optax -from flax import train_state +from flax.training import train_state from jax import random from ott.geometry.pointcloud import PointCloud -from ott.neural.solvers import models +from ott.neural.models import models from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn @@ -54,14 +46,14 @@ def save(self, path: Path): """Save the model.""" pass - @abstractmethod @property - def is_balanced(self) -> Dict[str, Any]: + @abstractmethod + def is_balanced(self) -> bool: """Return the training logs.""" pass - @abstractmethod @property + @abstractmethod def training_logs(self) -> Dict[str, Any]: """Return the training logs.""" pass @@ -106,8 +98,8 @@ def __init__( cond_dim: Optional[int], tau_a: float = 1.0, tau_b: float = 1.0, - mlp_eta: Optional[models.ModelBase] = None, - mlp_xi: Optional[models.ModelBase] = None, + mlp_eta: Optional[models.BaseRescalingNet] = None, + mlp_xi: Optional[models.BaseRescalingNet] = None, seed: Optional[int] = None, opt_eta: Optional[optax.GradientTransformation] = None, opt_xi: Optional[optax.GradientTransformation] = None, diff --git a/src/ott/neural/solvers/flow_matching.py b/src/ott/neural/solvers/flow_matching.py index 0ff68604e..2a99c9163 100644 --- a/src/ott/neural/solvers/flow_matching.py +++ b/src/ott/neural/solvers/flow_matching.py @@ -6,24 +6,24 @@ import jax import jax.numpy as jnp import optax -import orbax as obx from jax import random +from orbax import 
checkpoint from ott.geometry import costs, pointcloud from ott.neural.models.models import BaseNeuralVectorField -from ott.neural.solver.base_solver import ( - BaseNeuralSolver, - MatchMixin, - UnbalancednessMixin, +from ott.neural.solvers.base_solver import ( + BaseNeuralSolver, + ResampleMixin, + UnbalancednessMixin, ) from ott.neural.solvers.flows import ( - BaseFlow, + BaseFlow, ) from ott.problems.linear import linear_problem from ott.solvers import was_solver -class FlowMatching(BaseNeuralSolver, MatchMixin, UnbalancednessMixin): +class FlowMatching(BaseNeuralSolver, ResampleMixin, UnbalancednessMixin): def __init__( self, @@ -34,7 +34,7 @@ def __init__( ot_solver: Type[was_solver.WassersteinSolver], flow: Type[BaseFlow], optimizer: Type[optax.GradientTransformation], - checkpoint_manager: Type[obx.CheckpointManager] = None, + checkpoint_manager: Type[checkpoint.CheckpointManager] = None, epsilon: float = 1e-2, cost_fn: Type[costs.CostFn] = costs.SqEuclidean(), tau_a: float = 1.0, diff --git a/tests/neural/flow_matching_test.py b/tests/neural/flow_matching_test.py index 2cab8608f..8c31bcc26 100644 --- a/tests/neural/flow_matching_test.py +++ b/tests/neural/flow_matching_test.py @@ -1,8 +1,8 @@ import optax -from ott.neural.flow_matching import FlowMatching -from ott.neural.flows import ConstantNoiseFlow -from ott.neural.models import NeuralVectorField +from ott.neural.models.models import NeuralVectorField +from ott.neural.solvers.flow_matching import FlowMatching +from ott.neural.solvers.flows import ConstantNoiseFlow from ott.solvers.linear import sinkhorn From a9e9a8c521008df5f65cd1812bdabb35d9e1ff25 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 22 Nov 2023 18:50:35 +0100 Subject: [PATCH 008/186] [ci skip] MRO not working --- src/ott/neural/data/dataloaders.py | 1 - src/ott/neural/models/models.py | 5 +- src/ott/neural/solvers/base_solver.py | 9 ++-- src/ott/neural/solvers/flow_matching.py | 40 ++++++++++++---- tests/neural/conftest.py | 63 
+++++++++++++------------ tests/neural/flow_matching_test.py | 11 +++-- 6 files changed, 77 insertions(+), 52 deletions(-) diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py index c8976d348..2bddebaa1 100644 --- a/src/ott/neural/data/dataloaders.py +++ b/src/ott/neural/data/dataloaders.py @@ -1,4 +1,3 @@ - #import tensorflow as tf diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 18d86144c..0e4d65203 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -409,14 +409,14 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 class Block(nn.Module): dim: int = 128 num_layers: int = 3 - activation_fn: Any = nn.silu + act_fn: Any = nn.silu out_dim: int = 32 @nn.compact def __call__(self, x): for i in range(self.num_layers): x = nn.Dense(self.dim, name="fc{0}".format(i))(x) - x = self.activation_fn(x) + x = self.act_fn(x) return nn.Dense(self.out_dim)(x) @@ -434,6 +434,7 @@ def __call__( class NeuralVectorField(BaseNeuralVectorField): + output_dim: int condition_dim: int latent_embed_dim: int condition_embed_dim: Optional[int] = None diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index ed716b50c..b1a5c108d 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -46,12 +46,6 @@ def save(self, path: Path): """Save the model.""" pass - @property - @abstractmethod - def is_balanced(self) -> bool: - """Return the training logs.""" - pass - @property @abstractmethod def training_logs(self) -> Dict[str, Any]: @@ -61,6 +55,9 @@ def training_logs(self) -> Dict[str, Any]: class ResampleMixin: + def __init__(*args, **kwargs): + pass + def _resample_data( self, key: jax.random.KeyArray, diff --git a/src/ott/neural/solvers/flow_matching.py b/src/ott/neural/solvers/flow_matching.py index 2a99c9163..e53b9e45d 100644 --- a/src/ott/neural/solvers/flow_matching.py +++ 
b/src/ott/neural/solvers/flow_matching.py @@ -12,23 +12,24 @@ from ott.geometry import costs, pointcloud from ott.neural.models.models import BaseNeuralVectorField from ott.neural.solvers.base_solver import ( - BaseNeuralSolver, - ResampleMixin, - UnbalancednessMixin, + BaseNeuralSolver, + ResampleMixin, + UnbalancednessMixin, ) from ott.neural.solvers.flows import ( - BaseFlow, + BaseFlow, ) from ott.problems.linear import linear_problem from ott.solvers import was_solver -class FlowMatching(BaseNeuralSolver, ResampleMixin, UnbalancednessMixin): +class FlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): def __init__( self, neural_vector_field: Type[BaseNeuralVectorField], input_dim: int, + cond_dim: int, iterations: int, valid_freq: int, ot_solver: Type[was_solver.WassersteinSolver], @@ -44,12 +45,15 @@ def __init__( unbalanced_kwargs: Dict[str, Any] = {}, callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any]] = None, - seed: int = 0, + rng: random.PRNGKeyArray = random.PRNGKey(0), **kwargs: Any, ) -> None: super().__init__( iterations=iterations, valid_freq=valid_freq, + source_dim=input_dim, + target_dim=input_dim, + cond_dim=cond_dim, tau_a=tau_a, tau_b=tau_b, mlp_eta=mlp_eta, @@ -67,11 +71,13 @@ def __init__( self.cost_fn = cost_fn self.callback_fn = callback_fn self.checkpoint_manager = checkpoint_manager - self.seed = seed + self.rng = rng - def setup(self, **kwargs: Any) -> None: + self.setup() + + def setup(self) -> None: self.state_neural_vector_field = self.neural_vector_field.create_train_state( - self.rng, self.optimizer, self.output_dim + self.rng, self.optimizer, self.input_dim ) self.step_fn = self._get_step_fn() @@ -210,3 +216,19 @@ def _valid_step(self, valid_loader, iter) -> None: @property def learn_rescaling(self) -> bool: return self.mlp_eta is not None or self.mlp_xi is not None + + def save(self, path: str) -> None: + raise NotImplementedError + + def training_logs(self) -> Dict[str, Any]: + raise 
NotImplementedError + + def sample_t( + self, key: random.PRNGKey, batch_size: int + ) -> jnp.ndarray: #TODO: make more general + return random.uniform(key, batch_size) + + def sample_noise( + self, key: random.PRNGKey, batch_size: int + ) -> jnp.ndarray: #TODO: make more general + return random.normal(key, shape=(batch_size, self.input_dim)) diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 682aa5df8..57dba63bf 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -1,38 +1,41 @@ -from typing import Iterator +from typing import Optional +import jax import pytest -from ott import datasets - -class UnconditionalDataLoader: - - def __init__(self, iter: Iterator): - self.iter = iter - - def __next__(self): - return next(self.iter), None - - -@pytest.fixture(scope="module") -def data_loader_gaussian_1(): - """Returns a data loader for a simple Gaussian mixture.""" - loader = datasets.create_gaussian_mixture_samplers( - name_source="simple", - name_target="circle", - train_batch_size=30, - valid_batch_size=30, - ) - return UnconditionalDataLoader(loader[0]) +class DataLoader: + + def __init__( + self, + source_data: jax.Array, + target_data: jax.Array, + conditions: Optional[jax.Array], + batch_size: int = 64 + ) -> None: + super().__init__() + self.source_data = source_data + self.target_data = target_data + self.conditions = conditions + self.batch_size = batch_size + self.key = jax.random.PRNGKey(0) + + def __next__(self) -> jax.Array: + key, self.key = jax.random.split(self.key) + inds_source = jax.random.choice( + key, len(self.source_data), shape=[self.batch_size] + ) + inds_target = jax.random.choice( + key, len(self.target_data), shape=[self.batch_size] + ) + return self.source_data[inds_source, :], self.target_data[ + inds_target, :], self.conditions[ + inds_source, :] if self.conditions is not None else None @pytest.fixture(scope="module") -def data_loader_gaussian_2(): +def data_loader_gaussian(): """Returns a data loader 
for a simple Gaussian mixture.""" - loader = datasets.create_gaussian_mixture_samplers( - name_source="simple", - name_target="circle", - train_batch_size=30, - valid_batch_size=30, - ) - return UnconditionalDataLoader(loader[0] + 1) + source = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + target = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + return DataLoader(source, target, None, 16) diff --git a/tests/neural/flow_matching_test.py b/tests/neural/flow_matching_test.py index 8c31bcc26..23ce1e178 100644 --- a/tests/neural/flow_matching_test.py +++ b/tests/neural/flow_matching_test.py @@ -8,20 +8,23 @@ class TestFlowMatching: - def test_flow_matching(self, data_loader_gaussian_1, data_loader_gaussian_2): + def test_flow_matching(self, data_loader_gaussian): neural_vf = NeuralVectorField( - input_dim=2, hidden_dims=[32, 32], output_dim=2, activation="relu" + output_dim=2, + condition_dim=0, + latent_embed_dim=5, ) - ot_solver = sinkhorn.SinkhornSolver() + ot_solver = sinkhorn.Sinkhorn() flow = ConstantNoiseFlow(sigma=0) optimizer = optax.adam(learning_rate=1e-3) fm = FlowMatching( neural_vf, input_dim=2, + cond_dim=0, iterations=3, valid_freq=2, ot_solver=ot_solver, flow=flow, optimizer=optimizer ) - fm(data_loader_gaussian_1, data_loader_gaussian_2) + fm(data_loader_gaussian, data_loader_gaussian) From e4f89918a0568904d51c209940c23262b53ccb37 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Thu, 23 Nov 2023 10:48:37 +0100 Subject: [PATCH 009/186] [ci skip] basic test for flow matching passes --- pyproject.toml | 1 + src/ott/neural/models/models.py | 2 -- src/ott/neural/solvers/base_solver.py | 14 ++++---- src/ott/neural/solvers/flow_matching.py | 44 ++++++++++++------------- src/ott/neural/solvers/flows.py | 2 +- 5 files changed, 31 insertions(+), 32 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e168ec984..56aad6a91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ Changelog = 
"https://github.com/ott-jax/ott/releases" neural = [ "flax>=0.6.6", "optax>=0.1.1", + "diffrax>=0.4.1", ] dev = [ "pre-commit>=2.16.0", diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 0e4d65203..853a1d69e 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -150,8 +150,6 @@ def _compute_gaussian_map_params( ) -> Tuple[jnp.ndarray, jnp.ndarray]: from ott.tools.gaussian_mixture import gaussian source, target = samples - # print(source) - # print(type(source)) g_s = gaussian.Gaussian.from_samples(source) g_t = gaussian.Gaussian.from_samples(target) lin_op = g_s.scale.gaussian_map(g_t.scale) diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index b1a5c108d..da7577565 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -23,7 +23,7 @@ class BaseNeuralSolver(ABC): valid_freq: Frequency at which to run validation. """ - def __init__(self, iterations: int, valid_freq: int, **_: Any) -> Any: + def __init__(self, iterations: int, valid_freq: int, **_: Any) -> None: self.iterations = iterations self.valid_freq = valid_freq @@ -66,16 +66,16 @@ def _resample_data( target_arrays: Tuple[jnp.ndarray, ...], ) -> Tuple[jnp.ndarray, ...]: """Resample a batch according to coupling `tmat`.""" - transition_matrix = tmat.flatten() + tmat_flattened = tmat.flatten() indices = random.choice( - key, transition_matrix.flatten(), shape=[len(transition_matrix) ** 2] + key, len(tmat_flattened), shape=[len(tmat_flattened)] ) - indices_source = indices // self.batch_size - indices_target = indices % self.batch_size + indices_source = indices // tmat.shape[1] + indices_target = indices % tmat.shape[1] return tuple( - b[indices_source] if b is not None else None for b in source_arrays + b[indices_source, :] if b is not None else None for b in source_arrays ), tuple( - b[indices_target] if b is not None else None for b in target_arrays + 
b[indices_target, :] if b is not None else None for b in target_arrays ) def _resample_data_conditionally( diff --git a/src/ott/neural/solvers/flow_matching.py b/src/ott/neural/solvers/flow_matching.py index e53b9e45d..a0e48d414 100644 --- a/src/ott/neural/solvers/flow_matching.py +++ b/src/ott/neural/solvers/flow_matching.py @@ -48,9 +48,12 @@ def __init__( rng: random.PRNGKeyArray = random.PRNGKey(0), **kwargs: Any, ) -> None: - super().__init__( - iterations=iterations, - valid_freq=valid_freq, + BaseNeuralSolver.__init__( + self, iterations=iterations, valid_freq=valid_freq + ) + ResampleMixin.__init__(self) + UnbalancednessMixin.__init__( + self, source_dim=input_dim, target_dim=input_dim, cond_dim=cond_dim, @@ -59,7 +62,6 @@ def __init__( mlp_eta=mlp_eta, mlp_xi=mlp_xi, unbalanced_kwargs=unbalanced_kwargs, - **kwargs ) self.neural_vector_field = neural_vector_field @@ -105,10 +107,10 @@ def loss_fn( x_t = self.flow.compute_xt(noise, t, batch["source"], batch["target"]) apply_fn = functools.partial( - state_neural_vector_field.apply, {"params": params} + state_neural_vector_field.apply_fn, {"params": params} ) v_t = jax.vmap(apply_fn)( - t=t, x_t=x_t, condition=batch["condition"], keys_model=keys_model + t=t, x=x_t, condition=batch["condition"], keys_model=keys_model ) u_t = self.flow.compute_ut(t, batch["source"], batch["target"]) return jnp.mean((v_t - u_t) ** 2) @@ -122,7 +124,7 @@ def loss_fn( loss, grads = loss_grad( state_neural_vector_field.params, t, noise, batch, keys_model ) - return state_neural_vector_field.apply_gradients(grads), loss + return state_neural_vector_field.apply_gradients(grads=grads), loss return step_fn @@ -151,14 +153,16 @@ def match_pairs( def __call__(self, train_loader, valid_loader) -> None: batch: Mapping[str, jnp.ndarray] = {} for iter in range(self.iterations): + rng_resample, rng_step_fn, self.rng = random.split(self.rng, 3) batch["source"], batch["target"], batch["condition"] = next(train_loader) - tmat = 
self.match_fn(batch) - batch = self.resample( - batch, tmat, (batch["source"], batch["condition"]), - (batch["target"], batch["condition"]) - ) + tmat = self.match_fn(batch["source"], batch["target"]) + (batch["source"], + batch["condition"]), (batch["target"],) = self._resample_data( + rng_resample, tmat, (batch["source"], batch["condition"]), + (batch["target"],) + ) self.state_neural_vector_field, loss = self.step_fn( - self.state_neural_vector_field, batch + rng_step_fn, self.state_neural_vector_field, batch ) if self.learn_rescaling: self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, loss_b = self.unbalancedness_step_fn( @@ -206,12 +210,8 @@ def transport( ) def _valid_step(self, valid_loader, iter) -> None: - batch = next(valid_loader) - tmat = self.match_fn(batch) - batch = self.resample( - batch, tmat, (batch["source"], batch["condition"]), - (batch["target"], batch["condition"]) - ) + next(valid_loader) + # TODO: add callback and logging @property def learn_rescaling(self) -> bool: @@ -223,12 +223,12 @@ def save(self, path: str) -> None: def training_logs(self) -> Dict[str, Any]: raise NotImplementedError - def sample_t( + def sample_t( #TODO: make more general self, key: random.PRNGKey, batch_size: int ) -> jnp.ndarray: #TODO: make more general - return random.uniform(key, batch_size) + return random.uniform(key, [batch_size, 1]) - def sample_noise( + def sample_noise( #TODO: make more general self, key: random.PRNGKey, batch_size: int ) -> jnp.ndarray: #TODO: make more general return random.normal(key, shape=(batch_size, self.input_dim)) diff --git a/src/ott/neural/solvers/flows.py b/src/ott/neural/solvers/flows.py index b8c635f61..68cc84f5f 100644 --- a/src/ott/neural/solvers/flows.py +++ b/src/ott/neural/solvers/flows.py @@ -27,7 +27,7 @@ def compute_xt( self, noise: jax.Array, t: jax.Array, x_0: jax.Array, x_1: jax.Array ) -> jax.Array: mu_t = self.compute_mu_t(t, x_0, x_1) - sigma_t = self.compute_sigma_t(t, x_0, x_1) + sigma_t = 
self.compute_sigma_t(t) return mu_t + sigma_t * noise From 7869e3774277c213d26d89925618de913e051cf1 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Thu, 23 Nov 2023 11:59:39 +0100 Subject: [PATCH 010/186] [ci skip] add tests for FM with conditions and conditional OT with FM --- src/ott/neural/data/dataloaders.py | 2 +- src/ott/neural/solvers/base_solver.py | 5 ++ src/ott/neural/solvers/flow_matching.py | 54 +++++++----- tests/neural/conftest.py | 51 +++++++++++- tests/neural/flow_matching_test.py | 106 +++++++++++++++++++++++- 5 files changed, 190 insertions(+), 28 deletions(-) diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py index 2bddebaa1..fe0c367b7 100644 --- a/src/ott/neural/data/dataloaders.py +++ b/src/ott/neural/data/dataloaders.py @@ -1,7 +1,7 @@ #import tensorflow as tf -class ConditionalDataLoader: +class ConditionalDataLoader: #TODO(@MUCDK) uncomment, resolve installation issues with TF pass #def __init__( diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index da7577565..02db3cae3 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -46,6 +46,11 @@ def save(self, path: Path): """Save the model.""" pass + @abstractmethod + def load(self, path: Path): + """Load the model.""" + pass + @property @abstractmethod def training_logs(self) -> Dict[str, Any]: diff --git a/src/ott/neural/solvers/flow_matching.py b/src/ott/neural/solvers/flow_matching.py index a0e48d414..4d8ea15ce 100644 --- a/src/ott/neural/solvers/flow_matching.py +++ b/src/ott/neural/solvers/flow_matching.py @@ -32,7 +32,7 @@ def __init__( cond_dim: int, iterations: int, valid_freq: int, - ot_solver: Type[was_solver.WassersteinSolver], + ot_solver: Optional[Type[was_solver.WassersteinSolver]], flow: Type[BaseFlow], optimizer: Type[optax.GradientTransformation], checkpoint_manager: Type[checkpoint.CheckpointManager] = None, @@ -83,14 +83,17 @@ def setup(self) -> None: ) 
self.step_fn = self._get_step_fn() - self.match_fn = self._get_match_fn( - self.ot_solver, - epsilon=self.epsilon, - cost_fn=self.cost_fn, - tau_a=self.tau_a, - tau_b=self.tau_b, - scale_cost=self.scale_cost, - ) + if self.ot_solver is not None: + self.match_fn = self._get_match_fn( + self.ot_solver, + epsilon=self.epsilon, + cost_fn=self.cost_fn, + tau_a=self.tau_a, + tau_b=self.tau_b, + scale_cost=self.scale_cost, + ) + else: + self.match_fn = None def _get_step_fn(self) -> Callable: @@ -155,12 +158,13 @@ def __call__(self, train_loader, valid_loader) -> None: for iter in range(self.iterations): rng_resample, rng_step_fn, self.rng = random.split(self.rng, 3) batch["source"], batch["target"], batch["condition"] = next(train_loader) - tmat = self.match_fn(batch["source"], batch["target"]) - (batch["source"], - batch["condition"]), (batch["target"],) = self._resample_data( - rng_resample, tmat, (batch["source"], batch["condition"]), - (batch["target"],) - ) + if self.ot_solver is not None: + tmat = self.match_fn(batch["source"], batch["target"]) + (batch["source"], + batch["condition"]), (batch["target"],) = self._resample_data( + rng_resample, tmat, (batch["source"], batch["condition"]), + (batch["target"],) + ) self.state_neural_vector_field, loss = self.step_fn( rng_step_fn, self.state_neural_vector_field, batch ) @@ -184,19 +188,22 @@ def transport( self, data: jnp.array, condition: Optional[jax.Array], - rng: random.PRNGKey, forward: bool = True, diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}) ) -> diffrax.Solution: diffeqsolve_kwargs = dict(diffeqsolve_kwargs) - t0, t1 = (0, 1) if forward else (1, 0) + arr = jnp.ones((len(data), 1)) + t0, t1 = (arr * 0.0, arr * 1.0) if forward else (arr * 1.0, arr * 0.0) + apply_fn_partial = functools.partial( + self.state_neural_vector_field.apply_fn, condition=condition + ) return diffrax.diffeqsolve( diffrax.ODETerm( - lambda t, y: self.state_neural_vector_field. 
- apply({"params": self.state_neural_vector_field.params}, - t=t, - x=y, - condition=condition) + lambda t, y, *args: apply_fn_partial( + {"params": self.state_neural_vector_field.params}, + t=t, + x=y, + ) ), diffeqsolve_kwargs.pop("solver", diffrax.Tsit5()), t0=t0, @@ -220,6 +227,9 @@ def learn_rescaling(self) -> bool: def save(self, path: str) -> None: raise NotImplementedError + def load(self, path: str) -> "FlowMatching": + raise NotImplementedError + def training_logs(self) -> Dict[str, Any]: raise NotImplementedError diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 57dba63bf..f1b6e16f8 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -1,6 +1,7 @@ -from typing import Optional +from typing import Dict, Iterator, Optional import jax +import jax.numpy as jnp import pytest @@ -33,9 +34,55 @@ def __next__(self) -> jax.Array: inds_source, :] if self.conditions is not None else None +class ConditionalDataLoader: + + def __init__( + self, rng: jax.random.KeyArray, dataloaders: Dict[str, Iterator], + p: jax.Array + ) -> None: + super().__init__() + self.rng = rng + self.dataloaders = dataloaders + self.conditions = list(dataloaders.keys()) + self.p = p + + def __next__(self) -> jnp.ndarray: + self.rng, rng = jax.random.split(self.rng, 2) + idx = jax.random.choice(rng, len(self.conditions), p=self.p) + return next(self.dataloaders[self.conditions[idx]]) + + @pytest.fixture(scope="module") def data_loader_gaussian(): """Returns a data loader for a simple Gaussian mixture.""" source = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) - target = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + target = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 1.0 return DataLoader(source, target, None, 16) + + +@pytest.fixture(scope="module") +def data_loader_gaussian_conditional(): + """Returns a data loader for Gaussian mixtures with conditions.""" + source_0 = jax.random.normal(jax.random.PRNGKey(0), 
shape=(100, 2)) + target_0 = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 2.0 + + source_1 = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + target_1 = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) - 2.0 + dl0 = DataLoader(source_0, target_0, jnp.zeros_like(source_0) * 0.0, 16) + dl1 = DataLoader(source_1, target_1, jnp.ones_like(source_1) * 1.0, 16) + + return ConditionalDataLoader( + jax.random.PRNGKey(0), { + "0": dl0, + "1": dl1 + }, jnp.array([0.5, 0.5]) + ) + + +@pytest.fixture(scope="module") +def data_loader_gaussian_with_conditions(): + """Returns a data loader for a simple Gaussian mixture with conditions.""" + source = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + conditions = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 1)) + target = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 1.0 + return DataLoader(source, target, conditions, 16) diff --git a/tests/neural/flow_matching_test.py b/tests/neural/flow_matching_test.py index 23ce1e178..a1c6e8f3d 100644 --- a/tests/neural/flow_matching_test.py +++ b/tests/neural/flow_matching_test.py @@ -1,21 +1,35 @@ +from typing import Type + +import diffrax +import jax.numpy as jnp import optax +import pytest from ott.neural.models.models import NeuralVectorField from ott.neural.solvers.flow_matching import FlowMatching -from ott.neural.solvers.flows import ConstantNoiseFlow +from ott.neural.solvers.flows import ( + BaseFlow, + BrownianNoiseFlow, + ConstantNoiseFlow, +) from ott.solvers.linear import sinkhorn class TestFlowMatching: - def test_flow_matching(self, data_loader_gaussian): + @pytest.mark.parametrize( + "flow", + [ConstantNoiseFlow(0.0), + ConstantNoiseFlow(1.0), + BrownianNoiseFlow(0.2)] + ) + def test_flow_matching(self, data_loader_gaussian, flow: Type[BaseFlow]): neural_vf = NeuralVectorField( output_dim=2, condition_dim=0, latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - flow = ConstantNoiseFlow(sigma=0) optimizer = 
optax.adam(learning_rate=1e-3) fm = FlowMatching( neural_vf, @@ -28,3 +42,89 @@ def test_flow_matching(self, data_loader_gaussian): optimizer=optimizer ) fm(data_loader_gaussian, data_loader_gaussian) + + source, target, condition = next(data_loader_gaussian) + result_forward = fm.transport(source, condition=condition, forward=True) + assert isinstance(result_forward, diffrax.Solution) + assert jnp.sum(jnp.isnan(result_forward.y)) == 0 + + result_backward = fm.transport(target, condition=condition, forward=False) + assert isinstance(result_backward, diffrax.Solution) + assert jnp.sum(jnp.isnan(result_backward.y)) == 0 + + @pytest.mark.parametrize( + "flow", + [ConstantNoiseFlow(0.0), + ConstantNoiseFlow(1.0), + BrownianNoiseFlow(0.2)] + ) + def test_flow_matching_with_conditions( + self, data_loader_gaussian_with_conditions, flow: Type[BaseFlow] + ): + neural_vf = NeuralVectorField( + output_dim=2, + condition_dim=1, + latent_embed_dim=5, + ) + ot_solver = sinkhorn.Sinkhorn() + optimizer = optax.adam(learning_rate=1e-3) + fm = FlowMatching( + neural_vf, + input_dim=2, + cond_dim=1, + iterations=3, + valid_freq=2, + ot_solver=ot_solver, + flow=flow, + optimizer=optimizer + ) + fm( + data_loader_gaussian_with_conditions, + data_loader_gaussian_with_conditions + ) + + source, target, condition = next(data_loader_gaussian_with_conditions) + result_forward = fm.transport(source, condition=condition, forward=True) + assert isinstance(result_forward, diffrax.Solution) + assert jnp.sum(jnp.isnan(result_forward.y)) == 0 + + result_backward = fm.transport(target, condition=condition, forward=False) + assert isinstance(result_backward, diffrax.Solution) + assert jnp.sum(jnp.isnan(result_backward.y)) == 0 + + @pytest.mark.parametrize( + "flow", + [ConstantNoiseFlow(0.0), + ConstantNoiseFlow(1.0), + BrownianNoiseFlow(0.2)] + ) + def test_flow_matching_conditional( + self, data_loader_gaussian_conditional, flow: Type[BaseFlow] + ): + neural_vf = NeuralVectorField( + 
output_dim=2, + condition_dim=0, + latent_embed_dim=5, + ) + ot_solver = sinkhorn.Sinkhorn() + optimizer = optax.adam(learning_rate=1e-3) + fm = FlowMatching( + neural_vf, + input_dim=2, + cond_dim=0, + iterations=3, + valid_freq=2, + ot_solver=ot_solver, + flow=flow, + optimizer=optimizer + ) + fm(data_loader_gaussian_conditional, data_loader_gaussian_conditional) + + source, target, condition = next(data_loader_gaussian_conditional) + result_forward = fm.transport(source, condition=condition, forward=True) + assert isinstance(result_forward, diffrax.Solution) + assert jnp.sum(jnp.isnan(result_forward.y)) == 0 + + result_backward = fm.transport(target, condition=condition, forward=False) + assert isinstance(result_backward, diffrax.Solution) + assert jnp.sum(jnp.isnan(result_backward.y)) == 0 From 5a90dc1b2c4b20f10309af00256e6d0bbf136130 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Thu, 23 Nov 2023 18:03:45 +0100 Subject: [PATCH 011/186] [ci skip] add genot outline --- src/ott/neural/solvers/base_solver.py | 131 +++++++- src/ott/neural/solvers/flow_matching.py | 90 +++--- src/ott/neural/solvers/genot.py | 391 ++++++++++++++++++++++++ tests/neural/flow_matching_test.py | 9 +- 4 files changed, 574 insertions(+), 47 deletions(-) create mode 100644 src/ott/neural/solvers/genot.py diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index 02db3cae3..662fa73a3 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -9,9 +9,11 @@ from flax.training import train_state from jax import random +from ott.geometry import pointcloud from ott.geometry.pointcloud import PointCloud from ott.neural.models import models from ott.problems.linear import linear_problem +from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import sinkhorn @@ -83,12 +85,131 @@ def _resample_data( b[indices_target, :] if b is not None else None for b in target_arrays ) - def 
_resample_data_conditionally( + def sample_conditional_indices_from_tmap( + key: jax.random.PRNGKeyArray, + tmat: jnp.ndarray, + k_samples_per_x: Union[int, jnp.ndarray], + source_arrays: Tuple[jnp.ndarray, ...], + target_arrays: Tuple[jnp.ndarray, ...], + *, + is_balanced: bool, + ) -> Tuple[jnp.array, jnp.array]: + left_marginals = tmat.sum(axis=1) + if not is_balanced: + key, key2 = jax.random.split(key, 2) + indices = jax.random.choice( + key=key2, + a=jnp.arange(len(left_marginals)), + p=left_marginals, + shape=(len(left_marginals),) + ) + else: + indices = jnp.arange(tmat.shape[0]) + tmat_adapted = tmat[indices] + indices_per_row = jax.vmap( + lambda tmat_adapted: jax.random.choice( + key=key, + a=jnp.arange(tmat.shape[1]), + p=tmat_adapted, + shape=(k_samples_per_x,) + ), + in_axes=0, + out_axes=0, + )( + tmat_adapted + ) + + indices_source = jnp.repeat(indices, k_samples_per_x) + indices_target = indices_per_row % tmat.shape[1] + return tuple( + b[indices_source, :] if b is not None else None for b in source_arrays + ), tuple( + b[indices_target, :] if b is not None else None for b in target_arrays + ) + + def _get_sinkhorn_match_fn( + self, + ot_solver: Any, + epsilon: float, + cost_fn: str, + scale_cost: Any, + tau_a: float, + tau_b: float, + ) -> Callable: + + def match_pairs( + x: jax.Array, y: jax.Array + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + geom = pointcloud.PointCloud( + x, y, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn + ) + return ot_solver( + linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) + ).matrix + + return match_pairs + + def _get_gromov_match_fn( self, - *args: Any, - **kwargs: Any, - ): - raise NotImplementedError + ot_solver: Any, + cost_fn: Union[Any, Mapping[str, Any]], + scale_cost: Union[Any, Mapping[str, Any]], + tau_a: float, + tau_b: float, + fused_penalty: float, + ) -> Callable: + if isinstance(cost_fn, Mapping): + assert "x_cost_fn" in cost_fn + assert "y_cost_fn" in cost_fn + 
x_cost_fn = cost_fn["x_cost_fn"] + y_cost_fn = cost_fn["y_cost_fn"] + if fused_penalty > 0: + assert "xy_cost_fn" in x_cost_fn + xy_cost_fn = cost_fn["xy_cost_fn"] + else: + x_cost_fn = y_cost_fn = xy_cost_fn = cost_fn + + if isinstance(scale_cost, Mapping): + assert "x_scale_cost" in scale_cost + assert "y_scale_cost" in scale_cost + x_scale_cost = scale_cost["x_scale_cost"] + y_scale_cost = scale_cost["y_scale_cost"] + if fused_penalty > 0: + assert "xy_scale_cost" in scale_cost + xy_scale_cost = cost_fn["xy_scale_cost"] + else: + x_scale_cost = y_scale_cost = xy_scale_cost = scale_cost + + def match_pairs( + x_quad: Tuple[jnp.ndarray, jnp.ndarray], + y_quad: Tuple[jnp.ndarray, jnp.ndarray], + x_lin: Optional[jax.Array], + y_lin: Optional[jax.Array], + ) -> Tuple[jnp.array, jnp.array]: + geom_xx = pointcloud.PointCloud( + x=x_quad, y=x_quad, cost_fn=x_cost_fn, scale_cost=x_scale_cost + ) + geom_yy = pointcloud.PointCloud( + x=y_quad, y=y_quad, cost_fn=y_cost_fn, scale_cost=y_scale_cost + ) + if fused_penalty > 0: + geom_xy = pointcloud.PointCloud( + x=x_lin, y=y_lin, cost_fn=xy_cost_fn, scale_cost=xy_scale_cost + ) + else: + geom_xy = None + prob = quadratic_problem.QuadraticProblem( + geom_xx, + geom_yy, + geom_xy, + fused_penalty=fused_penalty, + tau_a=tau_a, + tau_b=tau_b + ) + out = ot_solver(prob) + return out.matrix + + return match_pairs class UnbalancednessMixin: diff --git a/src/ott/neural/solvers/flow_matching.py b/src/ott/neural/solvers/flow_matching.py index 4d8ea15ce..3532d09c1 100644 --- a/src/ott/neural/solvers/flow_matching.py +++ b/src/ott/neural/solvers/flow_matching.py @@ -9,7 +9,7 @@ from jax import random from orbax import checkpoint -from ott.geometry import costs, pointcloud +from ott.geometry import costs from ott.neural.models.models import BaseNeuralVectorField from ott.neural.solvers.base_solver import ( BaseNeuralSolver, @@ -19,7 +19,6 @@ from ott.neural.solvers.flows import ( BaseFlow, ) -from ott.problems.linear import linear_problem 
from ott.solvers import was_solver @@ -84,13 +83,13 @@ def setup(self) -> None: self.step_fn = self._get_step_fn() if self.ot_solver is not None: - self.match_fn = self._get_match_fn( + self.match_fn = self._get_sinkhorn_match_fn( self.ot_solver, epsilon=self.epsilon, cost_fn=self.cost_fn, + scale_cost=self.scale_cost, tau_a=self.tau_a, tau_b=self.tau_b, - scale_cost=self.scale_cost, ) else: self.match_fn = None @@ -131,28 +130,6 @@ def loss_fn( return step_fn - def _get_match_fn( - self, - ot_solver: Any, - epsilon: float, - cost_fn: str, - tau_a: float, - tau_b: float, - scale_cost: Any, - ) -> Callable: - - def match_pairs( - x: jax.Array, y: jax.Array - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: - geom = pointcloud.PointCloud( - x, y, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn - ) - return ot_solver( - linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) - ).matrix - - return match_pairs - def __call__(self, train_loader, valid_loader) -> None: batch: Mapping[str, jnp.ndarray] = {} for iter in range(self.iterations): @@ -192,27 +169,64 @@ def transport( diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}) ) -> diffrax.Solution: diffeqsolve_kwargs = dict(diffeqsolve_kwargs) + + def solve_ode( + t0: jax.Array, t1: jax.Array, input: jax.Array, cond: jax.Array + ): + return diffrax.diffeqsolve( + diffrax.ODETerm( + lambda t, x, args: self.state_neural_vector_field. 
+ apply_fn({"params": self.state_neural_vector_field.params}, + t=t, + x=x, + condition=cond) + ), + diffeqsolve_kwargs.pop("solver", diffrax.Tsit5()), + t0=t0, + t1=t1, + dt0=diffeqsolve_kwargs.pop("dt0", None), + y0=input, + stepsize_controller=diffeqsolve_kwargs.pop( + "stepsize_controller", + diffrax.PIDController(rtol=1e-5, atol=1e-5) + ), + **diffeqsolve_kwargs, + ).solution.y + + arr = jnp.ones((len(data), 1)) + t0, t1 = (arr * 0.0, arr * 1.0) if forward else (arr * 1.0, arr * 0.0) + + out = jax.vmap(solve_ode)(t0, t1, data, condition) + return out + + def _transport( + self, + data: jnp.array, + condition: Optional[jax.Array], + forward: bool = True, + diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}) + ) -> diffrax.Solution: + diffeqsolve_kwargs = dict(diffeqsolve_kwargs) arr = jnp.ones((len(data), 1)) t0, t1 = (arr * 0.0, arr * 1.0) if forward else (arr * 1.0, arr * 0.0) apply_fn_partial = functools.partial( - self.state_neural_vector_field.apply_fn, condition=condition + self.state_neural_vector_field.apply_fn, + params={"params": self.state_neural_vector_field.params}, + condition=condition + ) + term = diffrax.ODETerm(lambda t, y, *args: apply_fn_partial(t, y, *args)) + solver = diffeqsolve_kwargs.pop("solver", diffrax.Tsit5()) + stepsize_controller = diffeqsolve_kwargs.pop( + "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5) ) return diffrax.diffeqsolve( - diffrax.ODETerm( - lambda t, y, *args: apply_fn_partial( - {"params": self.state_neural_vector_field.params}, - t=t, - x=y, - ) - ), - diffeqsolve_kwargs.pop("solver", diffrax.Tsit5()), + term, + solver, t0=t0, t1=t1, dt0=diffeqsolve_kwargs.pop("dt0", None), y0=data, - stepsize_controller=diffeqsolve_kwargs.pop( - "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5) - ), + stepsize_controller=stepsize_controller, **diffeqsolve_kwargs, ) diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py new file mode 100644 index 
000000000..26ae93ede --- /dev/null +++ b/src/ott/neural/solvers/genot.py @@ -0,0 +1,391 @@ +import types +from functools import partial +from typing import ( + Any, + Callable, + Dict, + Literal, + Mapping, + Optional, + Tuple, + Type, + Union, +) + +import diffrax +import jax +import jax.numpy as jnp +import optax +from flax.training.train_state import TrainState +from jax import random +from tqdm import tqdm + +from ott.geometry import costs +from ott.neural.models.models import BaseNeuralVectorField +from ott.neural.solvers.base_solver import ( + BaseNeuralSolver, + ResampleMixin, + UnbalancednessMixin, +) +from ott.neural.solvers.flows import BaseFlow, ConstantNoiseFlow +from ott.solvers import was_solver +from ott.solvers.linear import sinkhorn +from ott.solvers.quadratic import gromov_wasserstein + +Match_fn_T = Callable[[jax.random.PRNGKeyArray, jnp.array, jnp.array], + Tuple[jnp.array, jnp.array, jnp.array, jnp.array]] +Match_latent_fn_T = Callable[[jax.random.PRNGKeyArray, jnp.array, jnp.array], + Tuple[jnp.array, jnp.array]] + + +class GENOT(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): + + def __init__( + self, + neural_vector_field: Type[BaseNeuralVectorField], + input_dim: int, + output_dim: int, + cond_dim: int, + iterations: int, + valid_freq: int, + ot_solver: Type[was_solver.WassersteinSolver], + optimizer: Type[optax.GradientTransformation], + flow: Type[BaseFlow] = ConstantNoiseFlow(0.0), + k_noise_per_x: int = 1, + t_offset: float = 1e-5, + epsilon: float = 1e-2, + cost_fn: Union[costs.CostFn, Literal["graph"]] = costs.SqEuclidean(), + solver_latent_to_data: Optional[Type[was_solver.WassersteinSolver] + ] = None, + latent_to_data_epsilon: float = 1e-2, + latent_to_data_scale_cost: Any = 1.0, + scale_cost: Union[Any, Mapping[str, Any]] = 1.0, + graph_kwargs: Dict[str, Any] = types.MappingProxyType({}), + fused_penalty: float = 0.0, + tau_a: float = 1.0, + tau_b: float = 1.0, + mlp_eta: Callable[[jnp.ndarray], float] = None, + mlp_xi: 
Callable[[jnp.ndarray], float] = None, + unbalanced_kwargs: Dict[str, Any] = {}, + callback: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], + Any]] = None, + callback_kwargs: Dict[str, Any] = {}, + callback_iters: int = 10, + rng: random.PRNGKeyArray = random.PRNGKey(0), + **kwargs: Any, + ) -> None: + """The GENOT training class. + + Parameters + ---------- + neural_vector_field + Neural vector field + input_dim + Dimension of the source distribution + output_dim + Dimension of the target distribution + cond_dim + Dimension of the condition + iterations + Number of iterations to train + valid_freq + Number of iterations after which to perform a validation step + ot_solver + Solver to match samples from the source to the target distribution + optimizer + Optimizer for the neural vector field + flow + Flow to use in the target space from noise to data. Should be of type + `ConstantNoiseFlow` to recover the setup in the paper TODO. + k_noise_per_x + Number of samples to draw from the conditional distribution + t_offset + Offset for sampling from the time t + epsilon + Entropy regularization parameter for the discrete solver + cost_fn + Cost function to use for the discrete OT solver + solver_latent_to_data + Linear OT solver to match samples from the noise to the conditional distribution + latent_to_data_epsilon + Entropy regularization term for `solver_latent_to_data` + latent_to_data_scale_cost + How to scale the cost matrix for the `solver_latent_to_data` solver + scale_cost + How to scale the cost matrix in each discrete OT problem + graph_kwargs + Keyword arguments for the graph cost computation in case `cost="graph"` + fused_penalty + Penalisation term for the linear term in a Fused GW setting + split_dim + Dimension to split the data into fused term and purely quadratic term in the FGW setting + mlp_eta + Neural network to learn the left rescaling function + mlp_xi + Neural network to learn the right rescaling function + tau_a + Left unbalancedness 
parameter + tau_b + Right unbalancedness parameter + callback + Callback function + callback_kwargs + Keyword arguments to the callback function + callback_iters + Number of iterations after which to evaluate callback function + seed + Random seed + kwargs + Keyword arguments passed to `setup`, e.g. custom choice of optimizers for learning rescaling functions + """ + BaseNeuralSolver.__init__( + self, iterations=iterations, valid_freq=valid_freq + ) + ResampleMixin.__init__(self) + UnbalancednessMixin.__init__( + self, + source_dim=input_dim, + target_dim=input_dim, + cond_dim=cond_dim, + tau_a=tau_a, + tau_b=tau_b, + mlp_eta=mlp_eta, + mlp_xi=mlp_xi, + unbalanced_kwargs=unbalanced_kwargs, + ) + + if isinstance( + ot_solver, gromov_wasserstein.GromovWasserstein + ) and epsilon is not None: + raise ValueError( + "If `ot_solver` is `GromovWasserstein`, `epsilon` must be `None`. This check is performed " + "to ensure that in the (fused) Gromov case the `epsilon` parameter is passed via the `ot_solver`." 
+ ) + + # setup parameters + self.rng = rng + self.metrics = {"loss": [], "loss_eta": [], "loss_xi": []} + + # neural parameters + self.neural_vector_field = neural_vector_field + self.state_neural_vector_field: Optional[TrainState] = None + self.optimizer = optimizer + self.noise_fn = jax.tree_util.Partial( + jax.random.multivariate_normal, + mean=jnp.zeros((output_dim,)), + cov=jnp.diag(jnp.ones((output_dim,))) + ) + self.input_dim = input_dim + self.output_dim = output_dim + self.cond_dim = cond_dim + self.k_noise_per_x = k_noise_per_x + + # OT data-data matching parameters + self.ot_solver = ot_solver + self.epsilon = epsilon + self.cost_fn = cost_fn + self.scale_cost = scale_cost + self.graph_kwargs = graph_kwargs # "k_neighbors", kwargs for graph.Graph.from_graph() + self.fused_penalty = fused_penalty + + # OT latent-data matching parameters + self.solver_latent_to_data = solver_latent_to_data + self.latent_to_data_epsilon = latent_to_data_epsilon + self.latent_to_data_scale_cost = latent_to_data_scale_cost + + # callback parameteres + self.callback = callback + self.callback_kwargs = callback_kwargs + self.callback_iters = callback_iters + + #TODO: check how to handle this + self.t_offset = t_offset + + self.setup(**kwargs) + + def setup(self) -> None: + """Set up the model. 
+ + Parameters + ---------- + kwargs + Keyword arguments for the setup function + """ + self.state_neural_vector_field = self.neural_vector_field.create_train_state( + self.rng, self.optimizer, self.input_dim + ) + self.step_fn = self._get_step_fn() + if self.solver_latent_to_data is not None: + self.match_latent_to_data_fn = self._get_match_latent_fn( + self.solver_latent_to_data, self.latent_to_data_epsilon, + self.latent_to_data_scale_cost + ) + else: + self.match_latent_to_data_fn = lambda key, x, y, **_: (x, y) + + if isinstance(self.ot_solver, sinkhorn.Sinkhorn): + self.match_fn = self._get_sinkhorn_match_fn( + self.ot_solver, self.epsilon, self.cost_fn, self.tau_a, self.tau_b, + self.scale_cost + ) + else: + self._get_gromov_match_fn( + self.ot_solver, self.cost_fn, self.tau_a, self.tau_b, self.scale_cost, + self.fused_penalty + ) + + def __call__(self, train_loader, valid_loader) -> None: + """Train GENOT.""" + batch: Dict[str, jnp.array] = {} + for step in tqdm(range(self.iterations)): + batch["source"], batch["source_q"], batch["target"], batch[ + "target_q"], batch["condition"] = next(train_loader) + + self.rng, rng_time, rng_match, rng_resample, rng_noise, rng_step_fn = jax.random.split( + self.rng, 6 + ) + n_samples = len(batch["source"]) * self.k_noise_per_k + t = ( + jax.random.uniform(rng_time, (1,)) + jnp.arange(n_samples) / n_samples + ) % (1 - self.t_offset) + batch["time"] = t[:, None] + batch["noise"] = self.noise_fn( + rng_noise, shape=(batch["source"], self.k_noise_per_x) + ) + + tmat = self.match_fn(rng_match, batch["source"], batch["target"]) + (batch["source"], batch["source_q"], batch["condition"] + ), (batch["target"], batch["target_q"]) = self._resample_data( + rng_resample, tmat, + (batch["source"], batch["source_q"], batch["condition"]), + (batch["target"], batch["target_q"]) + ) + rng_noise = jax.random.split(rng_noise, (len(batch["target"]))) + + noise_matched, conditional_target = jax.vmap( + self.match_latent_to_data_fn, 0, 0 + 
)(key=rng_noise, x=batch["noise"], y=batch["target"]) + + batch["source"] = jnp.reshape(batch["source"], (len(batch["source"]), -1)) + batch["target"] = jnp.reshape( + conditional_target, (len(batch["source"]), -1) + ) + batch["noise"] = jnp.reshape(noise_matched, (len(batch["soruce"]), -1)) + + self.state_neural_vector_field, loss = self.step_fn( + rng_step_fn, self.state_neural_vector_field, batch + ) + if self.learn_rescaling: + self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, loss_b = self.unbalancedness_step_fn( + batch, tmat.sum(axis=1), tmat.sum(axis=0) + ) + if iter % self.valid_freq == 0: + self._valid_step(valid_loader, iter) + if self.checkpoint_manager is not None: + states_to_save = { + "state_neural_vector_field": self.state_neural_vector_field + } + if self.state_mlp is not None: + states_to_save["state_eta"] = self.state_mlp + if self.state_xi is not None: + states_to_save["state_xi"] = self.state_xi + self.checkpoint_manager.save(iter, states_to_save) + + def _get_step_fn(self) -> Callable: + + def loss_fn( + params_mlp: jnp.array, + apply_fn_mlp: Callable, + batch: Dict[str, jnp.array], + ): + + def phi_t( + x_0: jnp.ndarray, x_1: jnp.ndarray, t: jnp.ndarray + ) -> jnp.ndarray: + return (1 - t) * x_0 + t * x_1 + + def u_t(x_0: jnp.ndarray, x_1: jnp.ndarray) -> jnp.ndarray: + return x_1 - x_0 + + phi_t_eval = phi_t(batch["noise"], batch["target"], batch["time"]) + mlp_pred = apply_fn_mlp({"params": params_mlp}, + t=batch["time"], + latent=phi_t_eval, + condition=batch["source"]) + d_psi = u_t(batch["noise"], batch["target"]) + + return jnp.mean(optax.l2_loss(mlp_pred, d_psi)) + + @jax.jit + def step_fn( + key: jax.random.PRNGKeyArray, + state_neural_net: TrainState, + batch: Dict[str, jnp.array], + ): + + grad_fn = jax.value_and_grad(loss_fn, has_aux=False) + loss, grads_mlp = grad_fn( + state_neural_net.params, + state_neural_net.apply_fn, + batch, + ) + metrics = {} + metrics["loss"] = loss + + return 
(state_neural_net.apply_gradients(grads=grads_mlp), loss) + + return step_fn + + def transport( + self, + source: jnp.array, + seed: int = 0, + diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}) + ) -> Union[jnp.array, diffrax.Solution, Optional[jnp.ndarray]]: + """Transport the distribution. + + Parameters + ---------- + source + Source distribution to transport + seed + Random seed for sampling from the latent distribution + diffeqsolve_kwargs + Keyword arguments for the ODE solver. + + Returns: + ------- + The transported samples, the solution of the neural ODE, and the rescaling factor. + """ + diffeqsolve_kwargs = dict(diffeqsolve_kwargs) + rng = jax.random.PRNGKey(seed) + latent_shape = (len(source),) + latent_batch = self.noise_fn(rng, shape=latent_shape) + apply_fn_partial = partial( + self.state_neural_vector_field.apply_fn, condition=source + ) + solution = diffrax.diffeqsolve( + diffrax.ODETerm( + lambda t, y, *args: + apply_fn_partial({"params": self.state_neural_vector_field.params}, + t=t, + latent=y) + ), + diffeqsolve_kwargs.pop("solver", diffrax.Tsit5()), + t0=0, + t1=1, + dt0=diffeqsolve_kwargs.pop("dt0", None), + y0=latent_batch, + stepsize_controller=diffeqsolve_kwargs.pop( + "stepsize_controller", diffrax.PIDController(rtol=1e-3, atol=1e-6) + ), + **diffeqsolve_kwargs, + ) + if self.state_eta is not None: + weight_factors = self.state_eta.apply_fn({ + "params": self.state_eta.params + }, + x=source) + else: + weight_factors = jnp.ones(source.shape) + return solution.ys, solution, weight_factors diff --git a/tests/neural/flow_matching_test.py b/tests/neural/flow_matching_test.py index a1c6e8f3d..39199106b 100644 --- a/tests/neural/flow_matching_test.py +++ b/tests/neural/flow_matching_test.py @@ -1,6 +1,7 @@ from typing import Type import diffrax +import jax import jax.numpy as jnp import optax import pytest @@ -85,12 +86,12 @@ def test_flow_matching_with_conditions( source, target, condition = 
next(data_loader_gaussian_with_conditions) result_forward = fm.transport(source, condition=condition, forward=True) - assert isinstance(result_forward, diffrax.Solution) - assert jnp.sum(jnp.isnan(result_forward.y)) == 0 + assert isinstance(result_forward, jax.Array) + assert jnp.sum(jnp.isnan(result_forward)) == 0 result_backward = fm.transport(target, condition=condition, forward=False) - assert isinstance(result_backward, diffrax.Solution) - assert jnp.sum(jnp.isnan(result_backward.y)) == 0 + assert isinstance(result_backward, jax.Array) + assert jnp.sum(jnp.isnan(result_backward)) == 0 @pytest.mark.parametrize( "flow", From c843758c8510db019b2892f08959043537508d7b Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Thu, 23 Nov 2023 19:15:27 +0100 Subject: [PATCH 012/186] [ci skip] restructure genot --- src/ott/neural/solvers/base_solver.py | 2 +- src/ott/neural/solvers/flow_matching.py | 7 +- src/ott/neural/solvers/genot.py | 150 +++++++++++++----------- 3 files changed, 89 insertions(+), 70 deletions(-) diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index 662fa73a3..eb14fadee 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -85,7 +85,7 @@ def _resample_data( b[indices_target, :] if b is not None else None for b in target_arrays ) - def sample_conditional_indices_from_tmap( + def _sample_conditional_indices_from_tmap( key: jax.random.PRNGKeyArray, tmat: jnp.ndarray, k_samples_per_x: Union[int, jnp.ndarray], diff --git a/src/ott/neural/solvers/flow_matching.py b/src/ott/neural/solvers/flow_matching.py index 3532d09c1..bfa5b5110 100644 --- a/src/ott/neural/solvers/flow_matching.py +++ b/src/ott/neural/solvers/flow_matching.py @@ -6,6 +6,7 @@ import jax import jax.numpy as jnp import optax +from flax.training import train_state from jax import random from orbax import checkpoint @@ -98,7 +99,7 @@ def _get_step_fn(self) -> Callable: def step_fn( key: random.PRNGKeyArray, - 
state_neural_vector_field: Any, + state_neural_vector_field: train_state.TrainState, batch: Dict[str, jnp.ndarray], ) -> Tuple[Any, Any]: @@ -122,8 +123,8 @@ def loss_fn( keys_model = random.split(key_model, batch_size) t = self.sample_t(key_t, batch_size) noise = self.sample_noise(key_noise, batch_size) - loss_grad = jax.value_and_grad(loss_fn) - loss, grads = loss_grad( + grad_fn = jax.value_and_grad(loss_fn) + loss, grads = grad_fn( state_neural_vector_field.params, t, noise, batch, keys_model ) return state_neural_vector_field.apply_gradients(grads=grads), loss diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 26ae93ede..110d65738 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -1,21 +1,23 @@ +import functools import types from functools import partial from typing import ( - Any, - Callable, - Dict, - Literal, - Mapping, - Optional, - Tuple, - Type, - Union, + Any, + Callable, + Dict, + Literal, + Mapping, + Optional, + Tuple, + Type, + Union, ) import diffrax import jax import jax.numpy as jnp import optax +from flax.training import train_state from flax.training.train_state import TrainState from jax import random from tqdm import tqdm @@ -23,9 +25,9 @@ from ott.geometry import costs from ott.neural.models.models import BaseNeuralVectorField from ott.neural.solvers.base_solver import ( - BaseNeuralSolver, - ResampleMixin, - UnbalancednessMixin, + BaseNeuralSolver, + ResampleMixin, + UnbalancednessMixin, ) from ott.neural.solvers.flows import BaseFlow, ConstantNoiseFlow from ott.solvers import was_solver @@ -57,8 +59,7 @@ def __init__( cost_fn: Union[costs.CostFn, Literal["graph"]] = costs.SqEuclidean(), solver_latent_to_data: Optional[Type[was_solver.WassersteinSolver] ] = None, - latent_to_data_epsilon: float = 1e-2, - latent_to_data_scale_cost: Any = 1.0, + kwargs_solver_latent_to_data: Dict[str, Any] = types.MappingProxyType({}), scale_cost: Union[Any, Mapping[str, Any]] = 1.0, 
graph_kwargs: Dict[str, Any] = types.MappingProxyType({}), fused_penalty: float = 0.0, @@ -190,8 +191,7 @@ def __init__( # OT latent-data matching parameters self.solver_latent_to_data = solver_latent_to_data - self.latent_to_data_epsilon = latent_to_data_epsilon - self.latent_to_data_scale_cost = latent_to_data_scale_cost + self.kwargs_solver_latent_to_data = kwargs_solver_latent_to_data # callback parameteres self.callback = callback @@ -216,13 +216,13 @@ def setup(self) -> None: ) self.step_fn = self._get_step_fn() if self.solver_latent_to_data is not None: - self.match_latent_to_data_fn = self._get_match_latent_fn( - self.solver_latent_to_data, self.latent_to_data_epsilon, - self.latent_to_data_scale_cost + self.match_latent_to_data_fn = self._get_sinkhorn_match_fn( + self.solver_latent_to_data, **self.kwargs_solver_latent_to_data ) else: self.match_latent_to_data_fn = lambda key, x, y, **_: (x, y) + # TODO: add graph construction function if isinstance(self.ot_solver, sinkhorn.Sinkhorn): self.match_fn = self._get_sinkhorn_match_fn( self.ot_solver, self.epsilon, self.cost_fn, self.tau_a, self.tau_b, @@ -241,36 +241,39 @@ def __call__(self, train_loader, valid_loader) -> None: batch["source"], batch["source_q"], batch["target"], batch[ "target_q"], batch["condition"] = next(train_loader) - self.rng, rng_time, rng_match, rng_resample, rng_noise, rng_step_fn = jax.random.split( - self.rng, 6 + self.rng, rng_time, rng_match, rng_resample, rng_noise, rng_latent_data_match, rng_step_fn = jax.random.split( + self.rng, 7 ) n_samples = len(batch["source"]) * self.k_noise_per_k - t = ( - jax.random.uniform(rng_time, (1,)) + jnp.arange(n_samples) / n_samples - ) % (1 - self.t_offset) - batch["time"] = t[:, None] + batch["time"] = self.sample_t(key, n_samples) batch["noise"] = self.noise_fn( rng_noise, shape=(batch["source"], self.k_noise_per_x) ) tmat = self.match_fn(rng_match, batch["source"], batch["target"]) (batch["source"], batch["source_q"], batch["condition"] + ), 
(batch["target"], + batch["target_q"]) = self._sample_conditional_indices_from_tmap( + rng_resample, tmat, self.k_noise_per_x, + (batch["source"], batch["source_q"], batch["condition"]), + (batch["target"], batch["target_q"]) + ) + rng_noise = jax.random.split(rng_noise, (len(batch["target"]))) + + tmat_latent_data = jax.vmap(self.match_latent_to_data_fn, 0, 0)( + key=rng_noise, x=batch["noise"], y=batch["target"] + ) + (batch["source"], batch["source_q"], batch["condition"] ), (batch["target"], batch["target_q"]) = self._resample_data( - rng_resample, tmat, + rng_latent_data_match, tmat_latent_data, (batch["source"], batch["source_q"], batch["condition"]), (batch["target"], batch["target_q"]) ) - rng_noise = jax.random.split(rng_noise, (len(batch["target"]))) - - noise_matched, conditional_target = jax.vmap( - self.match_latent_to_data_fn, 0, 0 - )(key=rng_noise, x=batch["noise"], y=batch["target"]) - batch["source"] = jnp.reshape(batch["source"], (len(batch["source"]), -1)) - batch["target"] = jnp.reshape( - conditional_target, (len(batch["source"]), -1) - ) - batch["noise"] = jnp.reshape(noise_matched, (len(batch["soruce"]), -1)) + batch = { + key: jnp.reshape(arr, (len(batch["source"]), -1)) + for key, arr in batch.items() + } self.state_neural_vector_field, loss = self.step_fn( rng_step_fn, self.state_neural_vector_field, batch @@ -293,46 +296,38 @@ def __call__(self, train_loader, valid_loader) -> None: def _get_step_fn(self) -> Callable: - def loss_fn( - params_mlp: jnp.array, - apply_fn_mlp: Callable, - batch: Dict[str, jnp.array], - ): - - def phi_t( - x_0: jnp.ndarray, x_1: jnp.ndarray, t: jnp.ndarray - ) -> jnp.ndarray: - return (1 - t) * x_0 + t * x_1 - - def u_t(x_0: jnp.ndarray, x_1: jnp.ndarray) -> jnp.ndarray: - return x_1 - x_0 - - phi_t_eval = phi_t(batch["noise"], batch["target"], batch["time"]) - mlp_pred = apply_fn_mlp({"params": params_mlp}, - t=batch["time"], - latent=phi_t_eval, - condition=batch["source"]) - d_psi = u_t(batch["noise"], 
batch["target"]) - - return jnp.mean(optax.l2_loss(mlp_pred, d_psi)) - @jax.jit def step_fn( key: jax.random.PRNGKeyArray, - state_neural_net: TrainState, + state_neural_vector_field: train_state.TrainState, batch: Dict[str, jnp.array], ): + def loss_fn( + params: jax.Array, t: jax.Array, noise: jax.Array, + batch: Dict[str, jnp.array], keys_model: random.PRNGKeyArray + ): + + x_t = self.flow.compute_xt(noise, t, batch["latent"], batch["target"]) + apply_fn = functools.partial( + state_neural_vector_field.apply_fn, {"params": params} + ) + cond_input = jnp.concatenate([batch["source"], batch["condition"]], + axis=-1) + v_t = jax.vmap(apply_fn)( + t=t, x=x_t, condition=cond_input, keys_model=keys_model + ) + u_t = self.flow.compute_ut(t, batch["latent"], batch["target"]) + return jnp.mean((v_t - u_t) ** 2) + grad_fn = jax.value_and_grad(loss_fn, has_aux=False) - loss, grads_mlp = grad_fn( - state_neural_net.params, - state_neural_net.apply_fn, + loss, grads = grad_fn( + state_neural_vector_field.params, + state_neural_vector_field.apply_fn, batch, ) - metrics = {} - metrics["loss"] = loss - return (state_neural_net.apply_gradients(grads=grads_mlp), loss) + return state_neural_vector_field.apply_gradients(grads=grads), loss return step_fn @@ -389,3 +384,26 @@ def transport( else: weight_factors = jnp.ones(source.shape) return solution.ys, solution, weight_factors + + def _valid_step(self, valid_loader, iter) -> None: + next(valid_loader) + + # TODO: add callback and logging + + @property + def learn_rescaling(self) -> bool: + return self.mlp_eta is not None or self.mlp_xi is not None + + def save(self, path: str) -> None: + raise NotImplementedError + + def load(self, path: str) -> "GENOT": + raise NotImplementedError + + def training_logs(self) -> Dict[str, Any]: + raise NotImplementedError + + def sample_t( #TODO: make more general + self, key: random.PRNGKey, batch_size: int + ) -> jnp.ndarray: #TODO: make more general + return random.uniform(key, [batch_size, 1]) 
From ef86c540651addcc4d5d3d774621f51f57b6c255 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 24 Nov 2023 11:02:48 +0100 Subject: [PATCH 013/186] [ci skip] restructure genot --- src/ott/neural/solvers/base_solver.py | 23 +++- src/ott/neural/solvers/genot.py | 164 ++++++++++++++++---------- tests/neural/conftest.py | 66 +++++++++++ tests/neural/genot_test.py | 39 ++++++ 4 files changed, 226 insertions(+), 66 deletions(-) create mode 100644 tests/neural/genot_test.py diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index eb14fadee..9d323b13c 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -86,16 +86,17 @@ def _resample_data( ) def _sample_conditional_indices_from_tmap( + self, key: jax.random.PRNGKeyArray, tmat: jnp.ndarray, k_samples_per_x: Union[int, jnp.ndarray], source_arrays: Tuple[jnp.ndarray, ...], target_arrays: Tuple[jnp.ndarray, ...], *, - is_balanced: bool, + source_is_balanced: bool, ) -> Tuple[jnp.array, jnp.array]: left_marginals = tmat.sum(axis=1) - if not is_balanced: + if not source_is_balanced: key, key2 = jax.random.split(key, 2) indices = jax.random.choice( key=key2, @@ -135,6 +136,8 @@ def _get_sinkhorn_match_fn( scale_cost: Any, tau_a: float, tau_b: float, + *, + filter_input: bool = False, ) -> Callable: def match_pairs( @@ -147,7 +150,17 @@ def match_pairs( linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) ).matrix - return match_pairs + def match_pairs_filtered( + x_lin: jax.Array, x_quad: jax.Array, y_lin: jax.Array, y_quad: jax.Array + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + geom = pointcloud.PointCloud( + x_lin, y_lin, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn + ) + return ot_solver( + linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) + ).matrix + + return match_pairs_filtered if filter_input else match_pairs def _get_gromov_match_fn( self, @@ -181,10 +194,10 @@ def 
_get_gromov_match_fn( x_scale_cost = y_scale_cost = xy_scale_cost = scale_cost def match_pairs( - x_quad: Tuple[jnp.ndarray, jnp.ndarray], - y_quad: Tuple[jnp.ndarray, jnp.ndarray], x_lin: Optional[jax.Array], + x_quad: Tuple[jnp.ndarray, jnp.ndarray], y_lin: Optional[jax.Array], + y_quad: Tuple[jnp.ndarray, jnp.ndarray], ) -> Tuple[jnp.array, jnp.array]: geom_xx = pointcloud.PointCloud( x=x_quad, y=x_quad, cost_fn=x_cost_fn, scale_cost=x_scale_cost diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 110d65738..19cf7536c 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -20,7 +20,7 @@ from flax.training import train_state from flax.training.train_state import TrainState from jax import random -from tqdm import tqdm +from orbax import checkpoint from ott.geometry import costs from ott.neural.models.models import BaseNeuralVectorField @@ -52,6 +52,7 @@ def __init__( valid_freq: int, ot_solver: Type[was_solver.WassersteinSolver], optimizer: Type[optax.GradientTransformation], + checkpoint_manager: Type[checkpoint.CheckpointManager] = None, flow: Type[BaseFlow] = ConstantNoiseFlow(0.0), k_noise_per_x: int = 1, t_offset: float = 1e-5, @@ -61,7 +62,6 @@ def __init__( ] = None, kwargs_solver_latent_to_data: Dict[str, Any] = types.MappingProxyType({}), scale_cost: Union[Any, Mapping[str, Any]] = 1.0, - graph_kwargs: Dict[str, Any] = types.MappingProxyType({}), fused_penalty: float = 0.0, tau_a: float = 1.0, tau_b: float = 1.0, @@ -163,15 +163,13 @@ def __init__( "to ensure that in the (fused) Gromov case the `epsilon` parameter is passed via the `ot_solver`." 
) - # setup parameters self.rng = rng - self.metrics = {"loss": [], "loss_eta": [], "loss_xi": []} - - # neural parameters self.neural_vector_field = neural_vector_field self.state_neural_vector_field: Optional[TrainState] = None + self.flow = flow self.optimizer = optimizer - self.noise_fn = jax.tree_util.Partial( + self.checkpoint_manager = checkpoint_manager + self.latent_noise_fn = jax.tree_util.Partial( jax.random.multivariate_normal, mean=jnp.zeros((output_dim,)), cov=jnp.diag(jnp.ones((output_dim,))) @@ -186,7 +184,6 @@ def __init__( self.epsilon = epsilon self.cost_fn = cost_fn self.scale_cost = scale_cost - self.graph_kwargs = graph_kwargs # "k_neighbors", kwargs for graph.Graph.from_graph() self.fused_penalty = fused_penalty # OT latent-data matching parameters @@ -225,8 +222,13 @@ def setup(self) -> None: # TODO: add graph construction function if isinstance(self.ot_solver, sinkhorn.Sinkhorn): self.match_fn = self._get_sinkhorn_match_fn( - self.ot_solver, self.epsilon, self.cost_fn, self.tau_a, self.tau_b, - self.scale_cost + self.ot_solver, + self.epsilon, + self.cost_fn, + self.tau_a, + self.tau_b, + self.scale_cost, + filter_input=True ) else: self._get_gromov_match_fn( @@ -237,41 +239,65 @@ def setup(self) -> None: def __call__(self, train_loader, valid_loader) -> None: """Train GENOT.""" batch: Dict[str, jnp.array] = {} - for step in tqdm(range(self.iterations)): + for iteration in range(self.iterations): batch["source"], batch["source_q"], batch["target"], batch[ "target_q"], batch["condition"] = next(train_loader) self.rng, rng_time, rng_match, rng_resample, rng_noise, rng_latent_data_match, rng_step_fn = jax.random.split( self.rng, 7 ) - n_samples = len(batch["source"]) * self.k_noise_per_k - batch["time"] = self.sample_t(key, n_samples) - batch["noise"] = self.noise_fn( - rng_noise, shape=(batch["source"], self.k_noise_per_x) + batch_size = len(batch["source"] + ) if "source" in batch else len(batch["source_q"]) + n_samples = batch_size * 
self.k_noise_per_x + batch["time"] = self.sample_t(rng_time, n_samples) + batch["noise"] = self.sample_noise(rng_noise, n_samples) + batch["latent"] = self.latent_noise_fn( + rng_noise, shape=(batch_size, self.k_noise_per_x) ) - tmat = self.match_fn(rng_match, batch["source"], batch["target"]) + tmat = self.match_fn( + batch["source"], batch["source_q"], batch["target"], batch["target_q"] + ) (batch["source"], batch["source_q"], batch["condition"] ), (batch["target"], batch["target_q"]) = self._sample_conditional_indices_from_tmap( - rng_resample, tmat, self.k_noise_per_x, + rng_resample, + tmat, + self.k_noise_per_x, (batch["source"], batch["source_q"], batch["condition"]), - (batch["target"], batch["target_q"]) + (batch["target"], batch["target_q"]), + source_is_balanced=(self.tau_a == 1.0) ) rng_noise = jax.random.split(rng_noise, (len(batch["target"]))) - tmat_latent_data = jax.vmap(self.match_latent_to_data_fn, 0, 0)( - key=rng_noise, x=batch["noise"], y=batch["target"] - ) - (batch["source"], batch["source_q"], batch["condition"] - ), (batch["target"], batch["target_q"]) = self._resample_data( - rng_latent_data_match, tmat_latent_data, - (batch["source"], batch["source_q"], batch["condition"]), - (batch["target"], batch["target_q"]) - ) + if self.solver_latent_to_data is not None: + tmats_latent_data = jnp.array( + jax.vmap(self.match_latent_to_data_fn, 0, + 0)(key=rng_noise, x=batch["noise"], y=batch["target"]) + ) + + if self.k_noise_per_x > 1: + rng_latent_data_match = jax.random.split( + rng_latent_data_match, batch_size + ) + (batch["source"], batch["source_q"], batch["condition"] + ), (batch["target"], + batch["target_q"]) = jax.vmap(self._resample_data, 0, 0)( + rng_latent_data_match, tmats_latent_data, + (batch["source"], batch["source_q"], batch["condition"]), + (batch["target"], batch["target_q"]) + ) + #(batch["source"], batch["source_q"], batch["condition"] + #), (batch["target"], batch["target_q"]) = self._resample_data( + # 
rng_latent_data_match, tmat_latent_data, + # (batch["source"], batch["source_q"], batch["condition"]), + # (batch["target"], batch["target_q"]) + #) batch = { - key: jnp.reshape(arr, (len(batch["source"]), -1)) + key: + jnp.reshape(arr, (len(batch["source"]), + -1)) if arr is not None else None for key, arr in batch.items() } @@ -282,8 +308,8 @@ def __call__(self, train_loader, valid_loader) -> None: self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, loss_b = self.unbalancedness_step_fn( batch, tmat.sum(axis=1), tmat.sum(axis=0) ) - if iter % self.valid_freq == 0: - self._valid_step(valid_loader, iter) + if iteration % self.valid_freq == 0: + self._valid_step(valid_loader, iteration) if self.checkpoint_manager is not None: states_to_save = { "state_neural_vector_field": self.state_neural_vector_field @@ -292,7 +318,7 @@ def __call__(self, train_loader, valid_loader) -> None: states_to_save["state_eta"] = self.state_mlp if self.state_xi is not None: states_to_save["state_xi"] = self.state_xi - self.checkpoint_manager.save(iter, states_to_save) + self.checkpoint_manager.save(iteration, states_to_save) def _get_step_fn(self) -> Callable: @@ -304,28 +330,34 @@ def step_fn( ): def loss_fn( - params: jax.Array, t: jax.Array, noise: jax.Array, - batch: Dict[str, jnp.array], keys_model: random.PRNGKeyArray + params: jax.Array, batch: Dict[str, jnp.array], + keys_model: random.PRNGKeyArray ): - x_t = self.flow.compute_xt(noise, t, batch["latent"], batch["target"]) + x_t = self.flow.compute_xt( + batch["noise"], batch["time"], batch["latent"], batch["target"] + ) apply_fn = functools.partial( state_neural_vector_field.apply_fn, {"params": params} ) - cond_input = jnp.concatenate([batch["source"], batch["condition"]], - axis=-1) + + if batch["condition"] is None: + cond_input = batch["source"] + else: + cond_input = jnp.concatenate([batch["source"], batch["condition"]], + axis=-1) v_t = jax.vmap(apply_fn)( - t=t, x=x_t, condition=cond_input, 
keys_model=keys_model + t=batch["time"], x=x_t, condition=cond_input, keys_model=keys_model + ) + u_t = self.flow.compute_ut( + batch["time"], batch["latent"], batch["target"] ) - u_t = self.flow.compute_ut(t, batch["latent"], batch["target"]) return jnp.mean((v_t - u_t) ** 2) + keys_model = random.split(key, len(batch["noise"])) + grad_fn = jax.value_and_grad(loss_fn, has_aux=False) - loss, grads = grad_fn( - state_neural_vector_field.params, - state_neural_vector_field.apply_fn, - batch, - ) + loss, grads = grad_fn(state_neural_vector_field.params, batch, keys_model) return state_neural_vector_field.apply_gradients(grads=grads), loss @@ -333,9 +365,11 @@ def loss_fn( def transport( self, - source: jnp.array, - seed: int = 0, - diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}) + source: jax.Array, + condition: jax.Array, + rng: random.PRNGKeyArray = random.PRNGKey(0), + diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}), + forward: bool = True, ) -> Union[jnp.array, diffrax.Solution, Optional[jnp.ndarray]]: """Transport the distribution. @@ -352,23 +386,33 @@ def transport( ------- The transported samples, the solution of the neural ODE, and the rescaling factor. 
""" + if not forward: + raise NotImplementedError diffeqsolve_kwargs = dict(diffeqsolve_kwargs) - rng = jax.random.PRNGKey(seed) - latent_shape = (len(source),) - latent_batch = self.noise_fn(rng, shape=latent_shape) + assert len(source) == len(condition) if condition is not None else True + + latent_batch = self.latent_noise_fn( + rng, shape=(len(source), self.output_dim) + ) + cond_input = source if condition is None else jnp.concatenate([ + source, condition + ], + axis=-1) apply_fn_partial = partial( - self.state_neural_vector_field.apply_fn, condition=source + self.state_neural_vector_field.apply_fn, condition=cond_input ) + t0 = jnp.zeros((len(source),1)) + t1 = jnp.ones((len(source),1)) solution = diffrax.diffeqsolve( diffrax.ODETerm( lambda t, y, *args: apply_fn_partial({"params": self.state_neural_vector_field.params}, t=t, - latent=y) + x=y) ), diffeqsolve_kwargs.pop("solver", diffrax.Tsit5()), - t0=0, - t1=1, + t0=t0, + t1=t1, dt0=diffeqsolve_kwargs.pop("dt0", None), y0=latent_batch, stepsize_controller=diffeqsolve_kwargs.pop( @@ -376,14 +420,7 @@ def transport( ), **diffeqsolve_kwargs, ) - if self.state_eta is not None: - weight_factors = self.state_eta.apply_fn({ - "params": self.state_eta.params - }, - x=source) - else: - weight_factors = jnp.ones(source.shape) - return solution.ys, solution, weight_factors + return solution.ys def _valid_step(self, valid_loader, iter) -> None: next(valid_loader) @@ -407,3 +444,8 @@ def sample_t( #TODO: make more general self, key: random.PRNGKey, batch_size: int ) -> jnp.ndarray: #TODO: make more general return random.uniform(key, [batch_size, 1]) + + def sample_noise( #TODO: make more general + self, key: random.PRNGKey, batch_size: int + ) -> jnp.ndarray: #TODO: make more general + return random.normal(key, shape=(batch_size, self.input_dim)) diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index f1b6e16f8..161a5a1ab 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -86,3 
+86,69 @@ def data_loader_gaussian_with_conditions(): conditions = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 1)) target = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 1.0 return DataLoader(source, target, conditions, 16) + + +class GENOTDataLoader: + + def __init__( + self, + source_lin: Optional[jax.Array], + source_quad: Optional[jax.Array], + target_lin: Optional[jax.Array], + target_quad: Optional[jax.Array], + conditions: Optional[jax.Array], + batch_size: int = 64 + ) -> None: + super().__init__() + self.source_lin = source_lin + self.target_lin = target_lin + self.source_quad = source_quad + self.target_quad = target_quad + self.conditions = conditions + self.batch_size = batch_size + self.key = jax.random.PRNGKey(0) + + def __next__(self) -> jax.Array: + key, self.key = jax.random.split(self.key) + inds_source = jax.random.choice( + key, len(self.source_lin), shape=[self.batch_size] + ) + inds_target = jax.random.choice( + key, len(self.target_lin), shape=[self.batch_size] + ) + return self.source_lin[ + inds_source, : + ] if self.source_lin is not None else None, self.source_quad[ + inds_source, : + ] if self.source_quad is not None else None, self.target_lin[ + inds_target, : + ] if self.target_lin is not None else None, self.target_quad[ + inds_target, : + ] if self.target_quad is not None else None, self.conditions[ + inds_source, :] if self.conditions is not None else None + + +@pytest.fixture(scope="module") +def genot_data_loader_linear(): + """Returns a data loader for a simple Gaussian mixture.""" + source = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + target = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 1.0 + return GENOTDataLoader(source, None, target, None, None, 16) + + +@pytest.fixture(scope="module") +def genot_data_loader_quad(): + """Returns a data loader for a simple Gaussian mixture.""" + source = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + target = 
jax.random.normal(jax.random.PRNGKey(0), shape=(100, 1)) + 1.0 + return GENOTDataLoader(None, source, None, target, None, 16) + + +@pytest.fixture(scope="module") +def genot_data_loader_fused(): + """Returns a data loader for a simple Gaussian mixture.""" + source_q = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + target_q = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 1)) + 1.0 + source_lin = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + target_lin = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 1.0 + return GENOTDataLoader(source_lin, source_q, target_lin, target_q, None, 16) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py new file mode 100644 index 000000000..55e2a351d --- /dev/null +++ b/tests/neural/genot_test.py @@ -0,0 +1,39 @@ +import diffrax +import jax.numpy as jnp +import optax + +from ott.neural.models.models import NeuralVectorField +from ott.neural.solvers.genot import GENOT +from ott.solvers.linear import sinkhorn + + +class TestGENOT: + + def test_genot_linear(self, genot_data_loader_linear): + neural_vf = NeuralVectorField( + output_dim=2, + condition_dim=0, + latent_embed_dim=5, + ) + ot_solver = sinkhorn.Sinkhorn() + optimizer = optax.adam(learning_rate=1e-3) + genot = GENOT( + neural_vf, + input_dim=2, + output_dim=2, + cond_dim=0, + iterations=3, + valid_freq=2, + ot_solver=ot_solver, + optimizer=optimizer + ) + genot(genot_data_loader_linear, genot_data_loader_linear) + + source_lin, source_quad, target_lin, target_quad, condition = next( + genot_data_loader_linear + ) + result_forward = genot.transport( + source_lin, condition=condition, forward=True + ) + assert isinstance(result_forward, diffrax.Solution) + assert jnp.sum(jnp.isnan(result_forward.y)) == 0 From 70a6173715070d956db486fd3e23ae0e553b5f80 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 24 Nov 2023 12:02:20 +0100 Subject: [PATCH 014/186] [ci skip] fix transport --- src/ott/neural/solvers/flow_matching.py | 
10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/ott/neural/solvers/flow_matching.py b/src/ott/neural/solvers/flow_matching.py index bfa5b5110..8ad52310f 100644 --- a/src/ott/neural/solvers/flow_matching.py +++ b/src/ott/neural/solvers/flow_matching.py @@ -171,8 +171,9 @@ def transport( ) -> diffrax.Solution: diffeqsolve_kwargs = dict(diffeqsolve_kwargs) + t0, t1 = (0.0, 1.0) if forward else (1.0, 0.0) def solve_ode( - t0: jax.Array, t1: jax.Array, input: jax.Array, cond: jax.Array + input: jax.Array, cond: jax.Array ): return diffrax.diffeqsolve( diffrax.ODETerm( @@ -192,12 +193,9 @@ def solve_ode( diffrax.PIDController(rtol=1e-5, atol=1e-5) ), **diffeqsolve_kwargs, - ).solution.y + ).ys[0] - arr = jnp.ones((len(data), 1)) - t0, t1 = (arr * 0.0, arr * 1.0) if forward else (arr * 1.0, arr * 0.0) - - out = jax.vmap(solve_ode)(t0, t1, data, condition) + out = jax.vmap(solve_ode)(data, condition) return out def _transport( From 40570e683591325d91941452cd46478775116488 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 24 Nov 2023 12:51:57 +0100 Subject: [PATCH 015/186] [ci skip] flow matching tests passing --- tests/neural/flow_matching_test.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/neural/flow_matching_test.py b/tests/neural/flow_matching_test.py index 39199106b..ab8b7713f 100644 --- a/tests/neural/flow_matching_test.py +++ b/tests/neural/flow_matching_test.py @@ -46,12 +46,12 @@ def test_flow_matching(self, data_loader_gaussian, flow: Type[BaseFlow]): source, target, condition = next(data_loader_gaussian) result_forward = fm.transport(source, condition=condition, forward=True) - assert isinstance(result_forward, diffrax.Solution) - assert jnp.sum(jnp.isnan(result_forward.y)) == 0 + assert isinstance(result_forward, jax.Array) + assert jnp.sum(jnp.isnan(result_forward)) == 0 result_backward = fm.transport(target, condition=condition, forward=False) - assert 
isinstance(result_backward, diffrax.Solution) - assert jnp.sum(jnp.isnan(result_backward.y)) == 0 + assert isinstance(result_backward, jax.Array) + assert jnp.sum(jnp.isnan(result_backward)) == 0 @pytest.mark.parametrize( "flow", @@ -123,9 +123,9 @@ def test_flow_matching_conditional( source, target, condition = next(data_loader_gaussian_conditional) result_forward = fm.transport(source, condition=condition, forward=True) - assert isinstance(result_forward, diffrax.Solution) - assert jnp.sum(jnp.isnan(result_forward.y)) == 0 + assert isinstance(result_forward, jax.Array) + assert jnp.sum(jnp.isnan(result_forward)) == 0 result_backward = fm.transport(target, condition=condition, forward=False) - assert isinstance(result_backward, diffrax.Solution) - assert jnp.sum(jnp.isnan(result_backward.y)) == 0 + assert isinstance(result_backward, jax.Array) + assert jnp.sum(jnp.isnan(result_backward)) == 0 From b0910ea1355f94e97fc57cbf2b1dec38c0df6a62 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 24 Nov 2023 14:40:26 +0100 Subject: [PATCH 016/186] [ci skip] add more tests genot --- src/ott/neural/solvers/flow_matching.py | 5 +- src/ott/neural/solvers/genot.py | 87 ++++++++++++------------- tests/neural/conftest.py | 34 ++++++++-- tests/neural/flow_matching_test.py | 1 - tests/neural/genot_test.py | 71 ++++++++++++++++++-- 5 files changed, 139 insertions(+), 59 deletions(-) diff --git a/src/ott/neural/solvers/flow_matching.py b/src/ott/neural/solvers/flow_matching.py index 8ad52310f..909c54207 100644 --- a/src/ott/neural/solvers/flow_matching.py +++ b/src/ott/neural/solvers/flow_matching.py @@ -172,9 +172,8 @@ def transport( diffeqsolve_kwargs = dict(diffeqsolve_kwargs) t0, t1 = (0.0, 1.0) if forward else (1.0, 0.0) - def solve_ode( - input: jax.Array, cond: jax.Array - ): + + def solve_ode(input: jax.Array, cond: jax.Array): return diffrax.diffeqsolve( diffrax.ODETerm( lambda t, x, args: self.state_neural_vector_field. 
diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 19cf7536c..b72d84c48 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -1,6 +1,5 @@ import functools import types -from functools import partial from typing import ( Any, Callable, @@ -209,7 +208,7 @@ def setup(self) -> None: Keyword arguments for the setup function """ self.state_neural_vector_field = self.neural_vector_field.create_train_state( - self.rng, self.optimizer, self.input_dim + self.rng, self.optimizer, self.input_dim + self.cond_dim ) self.step_fn = self._get_step_fn() if self.solver_latent_to_data is not None: @@ -231,7 +230,7 @@ def setup(self) -> None: filter_input=True ) else: - self._get_gromov_match_fn( + self.match_fn = self._get_gromov_match_fn( self.ot_solver, self.cost_fn, self.tau_a, self.tau_b, self.scale_cost, self.fused_penalty ) @@ -243,11 +242,11 @@ def __call__(self, train_loader, valid_loader) -> None: batch["source"], batch["source_q"], batch["target"], batch[ "target_q"], batch["condition"] = next(train_loader) - self.rng, rng_time, rng_match, rng_resample, rng_noise, rng_latent_data_match, rng_step_fn = jax.random.split( - self.rng, 7 + self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, rng_step_fn = jax.random.split( + self.rng, 6 ) batch_size = len(batch["source"] - ) if "source" in batch else len(batch["source_q"]) + ) if batch["source"] is not None else len(batch["source_q"]) n_samples = batch_size * self.k_noise_per_x batch["time"] = self.sample_t(rng_time, n_samples) batch["noise"] = self.sample_noise(rng_noise, n_samples) @@ -268,12 +267,13 @@ def __call__(self, train_loader, valid_loader) -> None: (batch["target"], batch["target_q"]), source_is_balanced=(self.tau_a == 1.0) ) - rng_noise = jax.random.split(rng_noise, (len(batch["target"]))) - + rng_latent = jax.random.split(rng_noise, batch_size * self.k_noise_per_x) + if self.solver_latent_to_data is not None: + target = 
jnp.concatenate([batch[el] for el in ["target", "target_q"] if batch[el] is not None], axis=1) tmats_latent_data = jnp.array( jax.vmap(self.match_latent_to_data_fn, 0, - 0)(key=rng_noise, x=batch["noise"], y=batch["target"]) + 0)(key=rng_latent, x=batch["latent"], y=target) ) if self.k_noise_per_x > 1: @@ -296,7 +296,7 @@ def __call__(self, train_loader, valid_loader) -> None: batch = { key: - jnp.reshape(arr, (len(batch["source"]), + jnp.reshape(arr, (batch_size*self.k_noise_per_x, -1)) if arr is not None else None for key, arr in batch.items() } @@ -333,24 +333,21 @@ def loss_fn( params: jax.Array, batch: Dict[str, jnp.array], keys_model: random.PRNGKeyArray ): - + target = jnp.concatenate([batch[el] for el in ["target", "target_q"] if batch[el] is not None], axis=1) x_t = self.flow.compute_xt( - batch["noise"], batch["time"], batch["latent"], batch["target"] + batch["noise"], batch["time"], batch["latent"], target ) apply_fn = functools.partial( state_neural_vector_field.apply_fn, {"params": params} ) - if batch["condition"] is None: - cond_input = batch["source"] - else: - cond_input = jnp.concatenate([batch["source"], batch["condition"]], - axis=-1) + cond_input = jnp.concatenate([batch[el] for el in ["source", "source_q", "condition"] if batch[el] is not None], axis=1) + v_t = jax.vmap(apply_fn)( t=batch["time"], x=x_t, condition=cond_input, keys_model=keys_model ) u_t = self.flow.compute_ut( - batch["time"], batch["latent"], batch["target"] + batch["time"], batch["latent"], target ) return jnp.mean((v_t - u_t) ** 2) @@ -366,7 +363,7 @@ def loss_fn( def transport( self, source: jax.Array, - condition: jax.Array, + condition: Optional[jax.Array], rng: random.PRNGKeyArray = random.PRNGKey(0), diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}), forward: bool = True, @@ -391,36 +388,36 @@ def transport( diffeqsolve_kwargs = dict(diffeqsolve_kwargs) assert len(source) == len(condition) if condition is not None else True - latent_batch = 
self.latent_noise_fn( - rng, shape=(len(source), self.output_dim) - ) + latent_batch = self.latent_noise_fn(rng, shape=(len(source),)) cond_input = source if condition is None else jnp.concatenate([ source, condition ], axis=-1) - apply_fn_partial = partial( - self.state_neural_vector_field.apply_fn, condition=cond_input - ) - t0 = jnp.zeros((len(source),1)) - t1 = jnp.ones((len(source),1)) - solution = diffrax.diffeqsolve( - diffrax.ODETerm( - lambda t, y, *args: - apply_fn_partial({"params": self.state_neural_vector_field.params}, - t=t, - x=y) - ), - diffeqsolve_kwargs.pop("solver", diffrax.Tsit5()), - t0=t0, - t1=t1, - dt0=diffeqsolve_kwargs.pop("dt0", None), - y0=latent_batch, - stepsize_controller=diffeqsolve_kwargs.pop( - "stepsize_controller", diffrax.PIDController(rtol=1e-3, atol=1e-6) - ), - **diffeqsolve_kwargs, - ) - return solution.ys + t0, t1 = (0.0, 1.0) + + def solve_ode(input: jax.Array, cond: jax.Array): + return diffrax.diffeqsolve( + diffrax.ODETerm( + lambda t, x, args: self.state_neural_vector_field. 
+ apply_fn({"params": self.state_neural_vector_field.params}, + t=t, + x=x, + condition=cond) + ), + diffeqsolve_kwargs.pop("solver", diffrax.Tsit5()), + t0=t0, + t1=t1, + dt0=diffeqsolve_kwargs.pop("dt0", None), + y0=input, + stepsize_controller=diffeqsolve_kwargs.pop( + "stepsize_controller", + diffrax.PIDController(rtol=1e-5, atol=1e-5) + ), + **diffeqsolve_kwargs, + ).ys[0] + + out = jax.vmap(solve_ode)(latent_batch, cond_input) + return out def _valid_step(self, valid_loader, iter) -> None: next(valid_loader) diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 161a5a1ab..c9b226ce6 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -100,6 +100,23 @@ def __init__( batch_size: int = 64 ) -> None: super().__init__() + if source_lin is not None: + if source_quad is not None: + assert len(source_lin) == len(source_quad) + else: + self.n_source = len(source_lin) + else: + self.n_source = len(source_quad) + if conditions is not None: + assert len(conditions) == self.n_source + if target_lin is not None: + if target_quad is not None: + assert len(target_lin) == len(target_quad) + else: + self.n_target = len(target_lin) + else: + self.n_target = len(target_quad) + self.source_lin = source_lin self.target_lin = target_lin self.source_quad = source_quad @@ -110,12 +127,8 @@ def __init__( def __next__(self) -> jax.Array: key, self.key = jax.random.split(self.key) - inds_source = jax.random.choice( - key, len(self.source_lin), shape=[self.batch_size] - ) - inds_target = jax.random.choice( - key, len(self.target_lin), shape=[self.batch_size] - ) + inds_source = jax.random.choice(key, self.n_source, shape=[self.batch_size]) + inds_target = jax.random.choice(key, self.n_target, shape=[self.batch_size]) return self.source_lin[ inds_source, : ] if self.source_lin is not None else None, self.source_quad[ @@ -136,6 +149,15 @@ def genot_data_loader_linear(): return GENOTDataLoader(source, None, target, None, None, 16) 
+@pytest.fixture(scope="module") +def genot_data_loader_linear_conditional(): + """Returns a data loader for a simple Gaussian mixture.""" + source = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + target = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 1.0 + conditions = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 4)) + return GENOTDataLoader(source, None, target, None, conditions, 16) + + @pytest.fixture(scope="module") def genot_data_loader_quad(): """Returns a data loader for a simple Gaussian mixture.""" diff --git a/tests/neural/flow_matching_test.py b/tests/neural/flow_matching_test.py index ab8b7713f..858c90084 100644 --- a/tests/neural/flow_matching_test.py +++ b/tests/neural/flow_matching_test.py @@ -1,6 +1,5 @@ from typing import Type -import diffrax import jax import jax.numpy as jnp import optax diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 55e2a351d..c1bf870bd 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -1,15 +1,16 @@ -import diffrax +import jax import jax.numpy as jnp import optax from ott.neural.models.models import NeuralVectorField from ott.neural.solvers.genot import GENOT from ott.solvers.linear import sinkhorn +from ott.solvers.quadratic import gromov_wasserstein class TestGENOT: - def test_genot_linear(self, genot_data_loader_linear): + def test_genot_linear_unconditional(self, genot_data_loader_linear): neural_vf = NeuralVectorField( output_dim=2, condition_dim=0, @@ -35,5 +36,67 @@ def test_genot_linear(self, genot_data_loader_linear): result_forward = genot.transport( source_lin, condition=condition, forward=True ) - assert isinstance(result_forward, diffrax.Solution) - assert jnp.sum(jnp.isnan(result_forward.y)) == 0 + assert isinstance(result_forward, jax.Array) + assert jnp.sum(jnp.isnan(result_forward)) == 0 + + def test_genot_quad_unconditional(self, genot_data_loader_quad): + neural_vf = NeuralVectorField( + output_dim=2, + condition_dim=0, + 
latent_embed_dim=5, + ) + ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + optimizer = optax.adam(learning_rate=1e-3) + genot = GENOT( + neural_vf, + input_dim=1, + output_dim=2, + cond_dim=0, + epsilon=None, + iterations=3, + valid_freq=2, + ot_solver=ot_solver, + optimizer=optimizer + ) + genot(genot_data_loader_quad, genot_data_loader_quad) + + source_lin, source_quad, target_lin, target_quad, condition = next( + genot_data_loader_quad + ) + result_forward = genot.transport( + source_quad, condition=condition, forward=True + ) + assert isinstance(result_forward, jax.Array) + assert jnp.sum(jnp.isnan(result_forward)) == 0 + + def test_genot_linear_conditional(self, genot_data_loader_linear_conditional): + neural_vf = NeuralVectorField( + output_dim=2, + condition_dim=4, + latent_embed_dim=5, + ) + ot_solver = sinkhorn.Sinkhorn() + optimizer = optax.adam(learning_rate=1e-3) + genot = GENOT( + neural_vf, + input_dim=2, + output_dim=2, + cond_dim=4, + iterations=3, + valid_freq=2, + ot_solver=ot_solver, + optimizer=optimizer + ) + genot( + genot_data_loader_linear_conditional, + genot_data_loader_linear_conditional + ) + + source_lin, source_quad, target_lin, target_quad, condition = next( + genot_data_loader_linear_conditional + ) + result_forward = genot.transport( + source_lin, condition=condition, forward=True + ) + assert isinstance(result_forward, jax.Array) + assert jnp.sum(jnp.isnan(result_forward)) == 0 From 542f5122ebdc3cf5d7ff0eefa0ed03d86246691e Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 24 Nov 2023 15:37:06 +0100 Subject: [PATCH 017/186] [ci skip] add more tests genot --- src/ott/neural/solvers/flow_matching.py | 34 +---- src/ott/neural/solvers/genot.py | 43 +++--- tests/neural/conftest.py | 24 +++ tests/neural/genot_test.py | 190 +++++++++++++++++++++--- 4 files changed, 219 insertions(+), 72 deletions(-) diff --git a/src/ott/neural/solvers/flow_matching.py b/src/ott/neural/solvers/flow_matching.py index 
909c54207..a80f4cc49 100644 --- a/src/ott/neural/solvers/flow_matching.py +++ b/src/ott/neural/solvers/flow_matching.py @@ -194,39 +194,7 @@ def solve_ode(input: jax.Array, cond: jax.Array): **diffeqsolve_kwargs, ).ys[0] - out = jax.vmap(solve_ode)(data, condition) - return out - - def _transport( - self, - data: jnp.array, - condition: Optional[jax.Array], - forward: bool = True, - diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}) - ) -> diffrax.Solution: - diffeqsolve_kwargs = dict(diffeqsolve_kwargs) - arr = jnp.ones((len(data), 1)) - t0, t1 = (arr * 0.0, arr * 1.0) if forward else (arr * 1.0, arr * 0.0) - apply_fn_partial = functools.partial( - self.state_neural_vector_field.apply_fn, - params={"params": self.state_neural_vector_field.params}, - condition=condition - ) - term = diffrax.ODETerm(lambda t, y, *args: apply_fn_partial(t, y, *args)) - solver = diffeqsolve_kwargs.pop("solver", diffrax.Tsit5()) - stepsize_controller = diffeqsolve_kwargs.pop( - "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5) - ) - return diffrax.diffeqsolve( - term, - solver, - t0=t0, - t1=t1, - dt0=diffeqsolve_kwargs.pop("dt0", None), - y0=data, - stepsize_controller=stepsize_controller, - **diffeqsolve_kwargs, - ) + return jax.vmap(solve_ode)(data, condition) def _valid_step(self, valid_loader, iter) -> None: next(valid_loader) diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index b72d84c48..1d9a9fc72 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -208,7 +208,7 @@ def setup(self) -> None: Keyword arguments for the setup function """ self.state_neural_vector_field = self.neural_vector_field.create_train_state( - self.rng, self.optimizer, self.input_dim + self.cond_dim + self.rng, self.optimizer, self.output_dim ) self.step_fn = self._get_step_fn() if self.solver_latent_to_data is not None: @@ -245,13 +245,16 @@ def __call__(self, train_loader, valid_loader) -> None: self.rng, 
rng_time, rng_resample, rng_noise, rng_latent_data_match, rng_step_fn = jax.random.split( self.rng, 6 ) - batch_size = len(batch["source"] - ) if batch["source"] is not None else len(batch["source_q"]) + batch_size = len(batch["source"]) if batch["source"] is not None else len( + batch["source_q"] + ) n_samples = batch_size * self.k_noise_per_x batch["time"] = self.sample_t(rng_time, n_samples) batch["noise"] = self.sample_noise(rng_noise, n_samples) batch["latent"] = self.latent_noise_fn( - rng_noise, shape=(batch_size, self.k_noise_per_x) + rng_noise, + shape=(batch_size, self.k_noise_per_x) if self.k_noise_per_x > 1 else + (batch_size,) ) tmat = self.match_fn( @@ -268,9 +271,12 @@ def __call__(self, train_loader, valid_loader) -> None: source_is_balanced=(self.tau_a == 1.0) ) rng_latent = jax.random.split(rng_noise, batch_size * self.k_noise_per_x) - + if self.solver_latent_to_data is not None: - target = jnp.concatenate([batch[el] for el in ["target", "target_q"] if batch[el] is not None], axis=1) + target = jnp.concatenate([ + batch[el] for el in ["target", "target_q"] if batch[el] is not None + ], + axis=1) tmats_latent_data = jnp.array( jax.vmap(self.match_latent_to_data_fn, 0, 0)(key=rng_latent, x=batch["latent"], y=target) @@ -293,10 +299,9 @@ def __call__(self, train_loader, valid_loader) -> None: # (batch["source"], batch["source_q"], batch["condition"]), # (batch["target"], batch["target_q"]) #) - batch = { key: - jnp.reshape(arr, (batch_size*self.k_noise_per_x, + jnp.reshape(arr, (batch_size * self.k_noise_per_x, -1)) if arr is not None else None for key, arr in batch.items() } @@ -333,7 +338,10 @@ def loss_fn( params: jax.Array, batch: Dict[str, jnp.array], keys_model: random.PRNGKeyArray ): - target = jnp.concatenate([batch[el] for el in ["target", "target_q"] if batch[el] is not None], axis=1) + target = jnp.concatenate([ + batch[el] for el in ["target", "target_q"] if batch[el] is not None + ], + axis=1) x_t = self.flow.compute_xt( batch["noise"], 
batch["time"], batch["latent"], target ) @@ -341,14 +349,16 @@ def loss_fn( state_neural_vector_field.apply_fn, {"params": params} ) - cond_input = jnp.concatenate([batch[el] for el in ["source", "source_q", "condition"] if batch[el] is not None], axis=1) - + cond_input = jnp.concatenate([ + batch[el] + for el in ["source", "source_q", "condition"] + if batch[el] is not None + ], + axis=1) v_t = jax.vmap(apply_fn)( t=batch["time"], x=x_t, condition=cond_input, keys_model=keys_model ) - u_t = self.flow.compute_ut( - batch["time"], batch["latent"], target - ) + u_t = self.flow.compute_ut(batch["time"], batch["latent"], target) return jnp.mean((v_t - u_t) ** 2) keys_model = random.split(key, len(batch["noise"])) @@ -416,8 +426,7 @@ def solve_ode(input: jax.Array, cond: jax.Array): **diffeqsolve_kwargs, ).ys[0] - out = jax.vmap(solve_ode)(latent_batch, cond_input) - return out + return jax.vmap(solve_ode)(latent_batch, cond_input) def _valid_step(self, valid_loader, iter) -> None: next(valid_loader) @@ -445,4 +454,4 @@ def sample_t( #TODO: make more general def sample_noise( #TODO: make more general self, key: random.PRNGKey, batch_size: int ) -> jnp.ndarray: #TODO: make more general - return random.normal(key, shape=(batch_size, self.input_dim)) + return random.normal(key, shape=(batch_size, self.output_dim)) diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index c9b226ce6..2dc9f1e43 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -103,6 +103,7 @@ def __init__( if source_lin is not None: if source_quad is not None: assert len(source_lin) == len(source_quad) + self.n_source = len(source_lin) else: self.n_source = len(source_lin) else: @@ -112,6 +113,7 @@ def __init__( if target_lin is not None: if target_quad is not None: assert len(target_lin) == len(target_quad) + self.n_target = len(target_lin) else: self.n_target = len(target_lin) else: @@ -166,6 +168,15 @@ def genot_data_loader_quad(): return GENOTDataLoader(None, source, 
None, target, None, 16) +@pytest.fixture(scope="module") +def genot_data_loader_quad_conditional(): + """Returns a data loader for a simple Gaussian mixture.""" + source = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + target = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 1)) + 1.0 + conditions = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 7)) + return GENOTDataLoader(None, source, None, target, conditions, 16) + + @pytest.fixture(scope="module") def genot_data_loader_fused(): """Returns a data loader for a simple Gaussian mixture.""" @@ -174,3 +185,16 @@ def genot_data_loader_fused(): source_lin = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) target_lin = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 1.0 return GENOTDataLoader(source_lin, source_q, target_lin, target_q, None, 16) + + +@pytest.fixture(scope="module") +def genot_data_loader_fused_conditional(): + """Returns a data loader for a simple Gaussian mixture.""" + source_q = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + target_q = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 1)) + 1.0 + source_lin = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + target_lin = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 1.0 + conditions = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 1)) + return GENOTDataLoader( + source_lin, source_q, target_lin, target_q, conditions, 16 + ) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index c1bf870bd..5f59db542 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -1,6 +1,9 @@ +from typing import Iterator + import jax import jax.numpy as jnp import optax +import pytest from ott.neural.models.models import NeuralVectorField from ott.neural.solvers.genot import GENOT @@ -9,24 +12,36 @@ class TestGENOT: + #TODO: add tests for unbalancedness + + @pytest.mark.parametrize("k_noise_per_x", [1, 2]) + def test_genot_linear_unconditional( + self, 
genot_data_loader_linear: Iterator, k_noise_per_x: int + ): + source_lin, source_quad, target_lin, target_quad, condition = next( + genot_data_loader_linear + ) + source_dim = source_lin.shape[1] + target_dim = target_lin.shape[1] + condition_dim = 0 - def test_genot_linear_unconditional(self, genot_data_loader_linear): neural_vf = NeuralVectorField( - output_dim=2, - condition_dim=0, + output_dim=target_dim, + condition_dim=condition_dim, latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, - input_dim=2, - output_dim=2, - cond_dim=0, + input_dim=source_dim, + output_dim=target_dim, + cond_dim=condition_dim, iterations=3, valid_freq=2, ot_solver=ot_solver, - optimizer=optimizer + optimizer=optimizer, + k_noise_per_x=k_noise_per_x, ) genot(genot_data_loader_linear, genot_data_loader_linear) @@ -39,53 +54,109 @@ def test_genot_linear_unconditional(self, genot_data_loader_linear): assert isinstance(result_forward, jax.Array) assert jnp.sum(jnp.isnan(result_forward)) == 0 - def test_genot_quad_unconditional(self, genot_data_loader_quad): + @pytest.mark.parametrize("k_noise_per_x", [1, 2]) + def test_genot_quad_unconditional( + self, genot_data_loader_quad: Iterator, k_noise_per_x: int + ): + source_lin, source_quad, target_lin, target_quad, condition = next( + genot_data_loader_quad + ) + source_dim = source_quad.shape[1] + target_dim = target_quad.shape[1] + condition_dim = 0 neural_vf = NeuralVectorField( - output_dim=2, - condition_dim=0, + output_dim=target_dim, + condition_dim=condition_dim, latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, - input_dim=1, - output_dim=2, - cond_dim=0, + input_dim=source_dim, + output_dim=target_dim, + cond_dim=condition_dim, epsilon=None, iterations=3, valid_freq=2, ot_solver=ot_solver, - optimizer=optimizer + optimizer=optimizer, + 
k_noise_per_x=k_noise_per_x, ) genot(genot_data_loader_quad, genot_data_loader_quad) + result_forward = genot.transport( + source_quad, condition=condition, forward=True + ) + assert isinstance(result_forward, jax.Array) + assert jnp.sum(jnp.isnan(result_forward)) == 0 + + @pytest.mark.parametrize("k_noise_per_x", [1, 2]) + def test_genot_fused_unconditional( + self, genot_data_loader_fused: Iterator, k_noise_per_x: int + ): source_lin, source_quad, target_lin, target_quad, condition = next( - genot_data_loader_quad + genot_data_loader_fused + ) + source_dim = source_lin.shape[1] + source_quad.shape[1] + target_dim = target_lin.shape[1] + target_quad.shape[1] + condition_dim = 0 + neural_vf = NeuralVectorField( + output_dim=target_dim, + condition_dim=condition_dim, + latent_embed_dim=5, + ) + ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + optimizer = optax.adam(learning_rate=1e-3) + genot = GENOT( + neural_vf, + input_dim=source_dim, + output_dim=target_dim, + cond_dim=condition_dim, + epsilon=None, + iterations=3, + valid_freq=2, + ot_solver=ot_solver, + optimizer=optimizer, + fused_penalty=0.5, + k_noise_per_x=k_noise_per_x, ) + genot(genot_data_loader_fused, genot_data_loader_fused) + result_forward = genot.transport( source_quad, condition=condition, forward=True ) assert isinstance(result_forward, jax.Array) assert jnp.sum(jnp.isnan(result_forward)) == 0 - def test_genot_linear_conditional(self, genot_data_loader_linear_conditional): + @pytest.mark.parametrize("k_noise_per_x", [1, 2]) + def test_genot_linear_conditional( + self, genot_data_loader_linear_conditional: Iterator, k_noise_per_x: int + ): + source_lin, source_quad, target_lin, target_quad, condition = next( + genot_data_loader_linear_conditional + ) + source_dim = source_lin.shape[1] + target_dim = target_lin.shape[1] + condition_dim = condition.shape[1] + neural_vf = NeuralVectorField( - output_dim=2, - condition_dim=4, + output_dim=target_dim, + condition_dim=source_dim + 
condition_dim, latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, - input_dim=2, - output_dim=2, - cond_dim=4, + input_dim=source_dim, + output_dim=target_dim, + cond_dim=condition_dim, iterations=3, valid_freq=2, ot_solver=ot_solver, - optimizer=optimizer + optimizer=optimizer, + k_noise_per_x=k_noise_per_x, ) genot( genot_data_loader_linear_conditional, @@ -100,3 +171,78 @@ def test_genot_linear_conditional(self, genot_data_loader_linear_conditional): ) assert isinstance(result_forward, jax.Array) assert jnp.sum(jnp.isnan(result_forward)) == 0 + + @pytest.mark.parametrize("k_noise_per_x", [1, 2]) + def test_genot_quad_conditional( + self, genot_data_loader_quad: Iterator, k_noise_per_x: int + ): + source_lin, source_quad, target_lin, target_quad, condition = next( + genot_data_loader_quad + ) + source_dim = source_quad.shape[1] + target_dim = target_quad.shape[1] + condition_dim = condition.shape[1] + neural_vf = NeuralVectorField( + output_dim=target_dim, + condition_dim=condition_dim, + latent_embed_dim=5, + ) + ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + optimizer = optax.adam(learning_rate=1e-3) + genot = GENOT( + neural_vf, + input_dim=source_dim, + output_dim=target_dim, + cond_dim=condition_dim, + epsilon=None, + iterations=3, + valid_freq=2, + ot_solver=ot_solver, + optimizer=optimizer, + k_noise_per_x=k_noise_per_x, + ) + genot(genot_data_loader_quad, genot_data_loader_quad) + + result_forward = genot.transport( + source_quad, condition=condition, forward=True + ) + assert isinstance(result_forward, jax.Array) + assert jnp.sum(jnp.isnan(result_forward)) == 0 + + @pytest.mark.parametrize("k_noise_per_x", [1, 2]) + def test_genot_fused_conditional( + self, genot_data_loader_fused: Iterator, k_noise_per_x: int + ): + source_lin, source_quad, target_lin, target_quad, condition = next( + genot_data_loader_fused + ) + source_dim = source_lin.shape[1] + 
source_quad.shape[1] + target_dim = target_lin.shape[1] + target_quad.shape[1] + condition_dim = condition.shape[1] + neural_vf = NeuralVectorField( + output_dim=target_dim, + condition_dim=condition_dim, + latent_embed_dim=5, + ) + ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + optimizer = optax.adam(learning_rate=1e-3) + genot = GENOT( + neural_vf, + input_dim=source_dim, + output_dim=target_dim, + cond_dim=condition_dim, + epsilon=None, + iterations=3, + valid_freq=2, + ot_solver=ot_solver, + optimizer=optimizer, + fused_penalty=0.5, + k_noise_per_x=k_noise_per_x, + ) + genot(genot_data_loader_fused, genot_data_loader_fused) + + result_forward = genot.transport( + source_quad, condition=condition, forward=True + ) + assert isinstance(result_forward, jax.Array) + assert jnp.sum(jnp.isnan(result_forward)) == 0 From c067f45b04c18dc7f265aa6a42cd645a6eca761f Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 26 Nov 2023 13:26:12 +0100 Subject: [PATCH 018/186] [ci skip] add TimeSampler --- src/ott/neural/data/dataloaders.py | 13 +++++ src/ott/neural/solvers/base_solver.py | 13 +++++ src/ott/neural/solvers/flows.py | 48 +++++++++++++++++ src/ott/neural/solvers/genot.py | 53 ++++++++++++------- .../solvers/{flow_matching.py => otfm.py} | 28 ++++++---- tests/neural/flow_matching_test.py | 37 ++++++++++--- tests/neural/genot_test.py | 26 +++++++++ 7 files changed, 182 insertions(+), 36 deletions(-) rename src/ott/neural/solvers/{flow_matching.py => otfm.py} (89%) diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py index fe0c367b7..acceb36c1 100644 --- a/src/ott/neural/data/dataloaders.py +++ b/src/ott/neural/data/dataloaders.py @@ -1,3 +1,16 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. #import tensorflow as tf diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index 9d323b13c..66d3ecbef 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -1,3 +1,16 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from abc import ABC, abstractmethod from pathlib import Path from types import MappingProxyType diff --git a/src/ott/neural/solvers/flows.py b/src/ott/neural/solvers/flows.py index 68cc84f5f..1eba46982 100644 --- a/src/ott/neural/solvers/flows.py +++ b/src/ott/neural/solvers/flows.py @@ -1,3 +1,16 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import abc import jax @@ -54,3 +67,38 @@ class BrownianNoiseFlow(StraightFlow): def compute_sigma_t(self, t: jax.Array): return jnp.sqrt(self.sigma * t * (1 - t)) + + +class BaseTimeSampler(abc.ABC): + + @abc.abstractmethod + def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: + pass + + +class UniformSampler(BaseTimeSampler): + + def __init__(self, low: float = 0.0, high: float = 1.0) -> None: + self.low = low + self.high = high + + def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: + return jax.random.uniform( + rng, (num_samples, 1), minval=self.low, maxval=self.high + ) + + +class OffsetUniformSampler(BaseTimeSampler): + + def __init__( + self, offset: float, low: float = 0.0, high: float = 1.0 + ) -> None: + self.offset = offset + self.low = low + self.high = high + + def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: + return ( + jax.random.uniform(rng, (1, 1), minval=self.low, maxval=self.high) + + jnp.arange(num_samples)[:, None] / num_samples + ) % ((self.high - self.low) - self.offset) diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 1d9a9fc72..f120ea6d8 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -1,15 +1,28 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import functools import types from typing import ( - Any, - Callable, - Dict, - Literal, - Mapping, - Optional, - Tuple, - Type, - Union, + Any, + Callable, + Dict, + Literal, + Mapping, + Optional, + Tuple, + Type, + Union, ) import diffrax @@ -24,11 +37,16 @@ from ott.geometry import costs from ott.neural.models.models import BaseNeuralVectorField from ott.neural.solvers.base_solver import ( - BaseNeuralSolver, - ResampleMixin, - UnbalancednessMixin, + BaseNeuralSolver, + ResampleMixin, + UnbalancednessMixin, +) +from ott.neural.solvers.flows import ( + BaseFlow, + BaseTimeSampler, + ConstantNoiseFlow, + UniformSampler, ) -from ott.neural.solvers.flows import BaseFlow, ConstantNoiseFlow from ott.solvers import was_solver from ott.solvers.linear import sinkhorn from ott.solvers.quadratic import gromov_wasserstein @@ -53,6 +71,7 @@ def __init__( optimizer: Type[optax.GradientTransformation], checkpoint_manager: Type[checkpoint.CheckpointManager] = None, flow: Type[BaseFlow] = ConstantNoiseFlow(0.0), + time_sampler: Type[BaseTimeSampler] = UniformSampler(), k_noise_per_x: int = 1, t_offset: float = 1e-5, epsilon: float = 1e-2, @@ -166,6 +185,7 @@ def __init__( self.neural_vector_field = neural_vector_field self.state_neural_vector_field: Optional[TrainState] = None self.flow = flow + self.time_sampler = time_sampler self.optimizer = optimizer self.checkpoint_manager = checkpoint_manager self.latent_noise_fn = jax.tree_util.Partial( @@ -249,7 +269,7 @@ def __call__(self, train_loader, valid_loader) -> None: batch["source_q"] ) n_samples = batch_size * 
self.k_noise_per_x - batch["time"] = self.sample_t(rng_time, n_samples) + batch["time"] = self.time_sampler(rng_time, n_samples) batch["noise"] = self.sample_noise(rng_noise, n_samples) batch["latent"] = self.latent_noise_fn( rng_noise, @@ -446,11 +466,6 @@ def load(self, path: str) -> "GENOT": def training_logs(self) -> Dict[str, Any]: raise NotImplementedError - def sample_t( #TODO: make more general - self, key: random.PRNGKey, batch_size: int - ) -> jnp.ndarray: #TODO: make more general - return random.uniform(key, [batch_size, 1]) - def sample_noise( #TODO: make more general self, key: random.PRNGKey, batch_size: int ) -> jnp.ndarray: #TODO: make more general diff --git a/src/ott/neural/solvers/flow_matching.py b/src/ott/neural/solvers/otfm.py similarity index 89% rename from src/ott/neural/solvers/flow_matching.py rename to src/ott/neural/solvers/otfm.py index a80f4cc49..ec0be23da 100644 --- a/src/ott/neural/solvers/flow_matching.py +++ b/src/ott/neural/solvers/otfm.py @@ -1,3 +1,16 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import functools import types from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type @@ -19,11 +32,12 @@ ) from ott.neural.solvers.flows import ( BaseFlow, + BaseTimeSampler, ) from ott.solvers import was_solver -class FlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): +class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): def __init__( self, @@ -34,6 +48,7 @@ def __init__( valid_freq: int, ot_solver: Optional[Type[was_solver.WassersteinSolver]], flow: Type[BaseFlow], + time_sampler: Type[BaseTimeSampler], optimizer: Type[optax.GradientTransformation], checkpoint_manager: Type[checkpoint.CheckpointManager] = None, epsilon: float = 1e-2, @@ -46,7 +61,6 @@ def __init__( callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any]] = None, rng: random.PRNGKeyArray = random.PRNGKey(0), - **kwargs: Any, ) -> None: BaseNeuralSolver.__init__( self, iterations=iterations, valid_freq=valid_freq @@ -68,6 +82,7 @@ def __init__( self.input_dim = input_dim self.ot_solver = ot_solver self.flow = flow + self.time_sampler = time_sampler self.optimizer = optimizer self.epsilon = epsilon self.cost_fn = cost_fn @@ -121,7 +136,7 @@ def loss_fn( batch_size = len(batch["source"]) key_noise, key_t, key_model = random.split(key, 3) keys_model = random.split(key_model, batch_size) - t = self.sample_t(key_t, batch_size) + t = self.time_sampler(key_t, batch_size) noise = self.sample_noise(key_noise, batch_size) grad_fn = jax.value_and_grad(loss_fn) loss, grads = grad_fn( @@ -207,17 +222,12 @@ def learn_rescaling(self) -> bool: def save(self, path: str) -> None: raise NotImplementedError - def load(self, path: str) -> "FlowMatching": + def load(self, path: str) -> "OTFlowMatching": raise NotImplementedError def training_logs(self) -> Dict[str, Any]: raise NotImplementedError - def sample_t( #TODO: make more general - self, key: random.PRNGKey, batch_size: int - ) -> jnp.ndarray: #TODO: make more general - return 
random.uniform(key, [batch_size, 1]) - def sample_noise( #TODO: make more general self, key: random.PRNGKey, batch_size: int ) -> jnp.ndarray: #TODO: make more general diff --git a/tests/neural/flow_matching_test.py b/tests/neural/flow_matching_test.py index 858c90084..9529e8a62 100644 --- a/tests/neural/flow_matching_test.py +++ b/tests/neural/flow_matching_test.py @@ -1,3 +1,16 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Type import jax @@ -6,16 +19,18 @@ import pytest from ott.neural.models.models import NeuralVectorField -from ott.neural.solvers.flow_matching import FlowMatching from ott.neural.solvers.flows import ( - BaseFlow, - BrownianNoiseFlow, - ConstantNoiseFlow, + BaseFlow, + BrownianNoiseFlow, + ConstantNoiseFlow, + OffsetUniformSampler, + UniformSampler, ) +from ott.neural.solvers.otfm import OTFlowMatching from ott.solvers.linear import sinkhorn -class TestFlowMatching: +class TestOTFlowMatching: @pytest.mark.parametrize( "flow", @@ -30,8 +45,9 @@ def test_flow_matching(self, data_loader_gaussian, flow: Type[BaseFlow]): latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() + time_sampler = UniformSampler() optimizer = optax.adam(learning_rate=1e-3) - fm = FlowMatching( + fm = OTFlowMatching( neural_vf, input_dim=2, cond_dim=0, @@ -39,6 +55,7 @@ def test_flow_matching(self, data_loader_gaussian, flow: Type[BaseFlow]): valid_freq=2, ot_solver=ot_solver, flow=flow, + time_sampler=time_sampler, 
optimizer=optimizer ) fm(data_loader_gaussian, data_loader_gaussian) @@ -67,8 +84,9 @@ def test_flow_matching_with_conditions( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() + time_sampler = OffsetUniformSampler(1e-6) optimizer = optax.adam(learning_rate=1e-3) - fm = FlowMatching( + fm = OTFlowMatching( neural_vf, input_dim=2, cond_dim=1, @@ -76,6 +94,7 @@ def test_flow_matching_with_conditions( valid_freq=2, ot_solver=ot_solver, flow=flow, + time_sampler=time_sampler, optimizer=optimizer ) fm( @@ -107,8 +126,9 @@ def test_flow_matching_conditional( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() + time_sampler = UniformSampler() optimizer = optax.adam(learning_rate=1e-3) - fm = FlowMatching( + fm = OTFlowMatching( neural_vf, input_dim=2, cond_dim=0, @@ -116,6 +136,7 @@ def test_flow_matching_conditional( valid_freq=2, ot_solver=ot_solver, flow=flow, + time_sampler=time_sampler, optimizer=optimizer ) fm(data_loader_gaussian_conditional, data_loader_gaussian_conditional) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 5f59db542..183af8419 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -1,3 +1,16 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import Iterator import jax @@ -6,6 +19,7 @@ import pytest from ott.neural.models.models import NeuralVectorField +from ott.neural.solvers.flows import OffsetUniformSampler, UniformSampler from ott.neural.solvers.genot import GENOT from ott.solvers.linear import sinkhorn from ott.solvers.quadratic import gromov_wasserstein @@ -31,6 +45,7 @@ def test_genot_linear_unconditional( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() + time_sampler = UniformSampler() optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -40,6 +55,7 @@ def test_genot_linear_unconditional( iterations=3, valid_freq=2, ot_solver=ot_solver, + time_sampler=time_sampler, optimizer=optimizer, k_noise_per_x=k_noise_per_x, ) @@ -70,6 +86,7 @@ def test_genot_quad_unconditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + time_sampler = OffsetUniformSampler(1e-3) optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -80,6 +97,7 @@ def test_genot_quad_unconditional( iterations=3, valid_freq=2, ot_solver=ot_solver, + time_sampler=time_sampler, optimizer=optimizer, k_noise_per_x=k_noise_per_x, ) @@ -107,6 +125,7 @@ def test_genot_fused_unconditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + time_sampler = UniformSampler() optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -117,6 +136,7 @@ def test_genot_fused_unconditional( iterations=3, valid_freq=2, ot_solver=ot_solver, + time_sampler=time_sampler, optimizer=optimizer, fused_penalty=0.5, k_noise_per_x=k_noise_per_x, @@ -146,6 +166,7 @@ def test_genot_linear_conditional( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() + time_sampler = UniformSampler() optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -155,6 +176,7 @@ def test_genot_linear_conditional( iterations=3, valid_freq=2, ot_solver=ot_solver, + time_sampler=time_sampler, optimizer=optimizer, 
k_noise_per_x=k_noise_per_x, ) @@ -188,6 +210,7 @@ def test_genot_quad_conditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + time_sampler = UniformSampler() optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -198,6 +221,7 @@ def test_genot_quad_conditional( iterations=3, valid_freq=2, ot_solver=ot_solver, + time_sampler=time_sampler, optimizer=optimizer, k_noise_per_x=k_noise_per_x, ) @@ -225,6 +249,7 @@ def test_genot_fused_conditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + time_sampler = UniformSampler() optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -235,6 +260,7 @@ def test_genot_fused_conditional( iterations=3, valid_freq=2, ot_solver=ot_solver, + time_sampler=time_sampler, optimizer=optimizer, fused_penalty=0.5, k_noise_per_x=k_noise_per_x, From 2546afc4e95b5ab73c719f5a1b6b3fde52ea23ca Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 26 Nov 2023 14:18:21 +0100 Subject: [PATCH 019/186] [ci skip] add docs for TimeSampler and Flow --- src/ott/neural/solvers/flows.py | 112 ++++++++++++++++++++++++++++- src/ott/neural/solvers/genot.py | 32 ++++----- tests/neural/flow_matching_test.py | 10 +-- 3 files changed, 132 insertions(+), 22 deletions(-) diff --git a/src/ott/neural/solvers/flows.py b/src/ott/neural/solvers/flows.py index 1eba46982..19c3d2f67 100644 --- a/src/ott/neural/solvers/flows.py +++ b/src/ott/neural/solvers/flows.py @@ -18,77 +18,178 @@ class BaseFlow(abc.ABC): + """Base class for all flows. + + Args: + sigma: Constant noise used for computing time-dependent noise schedule. + """ def __init__(self, sigma: float) -> None: self.sigma = sigma @abc.abstractmethod def compute_mu_t(self, t: jax.Array, x_0: jax.Array, x_1: jax.Array): + """Compute the mean of the probablitiy path between :math:`x` and :math:`y` at time :math:`t`. + + Args: + t: Time :math:`t`. + x_0: Sample from the source distribution. 
+ x_1: Sample from the target distribution. + """ pass @abc.abstractmethod def compute_sigma_t(self, t: jax.Array): + """Compute the standard deviation of the probablity path at time :math:`t`. + + Args: + t: Time :math:`t`. + """ pass @abc.abstractmethod def compute_ut( self, t: jax.Array, x_0: jax.Array, x_1: jax.Array ) -> jax.Array: + """Evaluate the conditional vector field defined between :math:`x_0` and :math:`x_1` at time :math:`t`. + + Args: + t: Time :math:`t`. + x_0: Sample from the source distribution. + x_1: Sample from the target distribution. + """ pass def compute_xt( self, noise: jax.Array, t: jax.Array, x_0: jax.Array, x_1: jax.Array ) -> jax.Array: + """Sample from the probability path between :math:`x_0` and :math:`x_1` at time :math:`t`. + + Args: + noise: Noise sampled from a standard normal distribution. + t: Time :math:`t`. + x_0: Sample from the source distribution. + x_1: Sample from the target distribution. + + Returns: + Samples from the probability path between :math:`x_0` and :math:`x_1` at time :math:`t`. + """ mu_t = self.compute_mu_t(t, x_0, x_1) sigma_t = self.compute_sigma_t(t) return mu_t + sigma_t * noise -class StraightFlow(BaseFlow): +class StraightFlow(BaseFlow, abc.ABC): + """Base class for flows with straight paths.""" def compute_mu_t( self, t: jax.Array, x_0: jax.Array, x_1: jax.Array ) -> jax.Array: + """Compute the mean of the probablitiy path between :math:`x` and :math:`y` at time :math:`t`. + + Args: + t: Time :math:`t`. + x_0: Sample from the source distribution. + x_1: Sample from the target distribution. + """ return t * x_0 + (1 - t) * x_1 def compute_ut( self, t: jax.Array, x_0: jax.Array, x_1: jax.Array ) -> jax.Array: + """Evaluate the conditional vector field defined between :math:`x_0` and :math:`x_1` at time :math:`t`. + + Args: + t: Time :math:`t`. + x_0: Sample from the source distribution. + x_1: Sample from the target distribution. + + Returns: + Conditional vector field evaluated at time :math:`t`. 
+ """ return x_1 - x_0 class ConstantNoiseFlow(StraightFlow): + r"""Flow with straight paths and constant flow noise :math:`\sigma`.""" def compute_sigma_t(self, t: jax.Array): + r"""Compute noise of the flow at time :math:`t`. + + Args: + t: Time :math:`t`. + + Returns: + Constant, time-independent standard deviation :math:`\sigma`. + """ return self.sigma class BrownianNoiseFlow(StraightFlow): + r"""Sampler for sampling noise implicitly defined by a Schroedinger Bridge problem with parameter `\sigma` such that :math:`\sigma_t = \sigma * \sqrt(t * (1-t))`.""" def compute_sigma_t(self, t: jax.Array): + """Compute the standard deviation of the probablity path at time :math:`t`. + + Args: + t: Time :math:`t`. + + Returns: + Standard deviation of the probablity path at time :math:`t`. + """ return jnp.sqrt(self.sigma * t * (1 - t)) class BaseTimeSampler(abc.ABC): + """Base class for time samplers.""" @abc.abstractmethod def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: + """Generate `num_samples` samples of the time `math`:t:. + + Args: + rng: Random number generator. + num_samples: Number of samples to generate. + + """ pass class UniformSampler(BaseTimeSampler): + """Sample :math:`t` from a uniform distribution :math:`[low, high]`. + + Args: + low: Lower bound of the uniform distribution. + high: Upper bound of the uniform distribution. + """ def __init__(self, low: float = 0.0, high: float = 1.0) -> None: self.low = low self.high = high def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: + """Generate `num_samples` samples of the time `math`:t:. + + Args: + rng: Random number generator. + num_samples: Number of samples to generate. + + Returns: + `num_samples` samples of the time :math:`t``. + """ return jax.random.uniform( rng, (num_samples, 1), minval=self.low, maxval=self.high ) class OffsetUniformSampler(BaseTimeSampler): + """Sample :math:`t` from a uniform distribution :math:`[low, high]` with offset `offset`. 
+ + Args: + offset: Offset of the uniform distribution. + low: Lower bound of the uniform distribution. + high: Upper bound of the uniform distribution. + """ def __init__( self, offset: float, low: float = 0.0, high: float = 1.0 @@ -98,6 +199,15 @@ def __init__( self.high = high def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: + """Generate `num_samples` samples of the time `math`:t:. + + Args: + rng: Random number generator. + num_samples: Number of samples to generate. + + Returns: + An array with `num_samples` samples of the time `math`:t:. + """ return ( jax.random.uniform(rng, (1, 1), minval=self.low, maxval=self.high) + jnp.arange(num_samples)[:, None] / num_samples diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index f120ea6d8..3d6b3fafb 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -14,15 +14,15 @@ import functools import types from typing import ( - Any, - Callable, - Dict, - Literal, - Mapping, - Optional, - Tuple, - Type, - Union, + Any, + Callable, + Dict, + Literal, + Mapping, + Optional, + Tuple, + Type, + Union, ) import diffrax @@ -37,15 +37,15 @@ from ott.geometry import costs from ott.neural.models.models import BaseNeuralVectorField from ott.neural.solvers.base_solver import ( - BaseNeuralSolver, - ResampleMixin, - UnbalancednessMixin, + BaseNeuralSolver, + ResampleMixin, + UnbalancednessMixin, ) from ott.neural.solvers.flows import ( - BaseFlow, - BaseTimeSampler, - ConstantNoiseFlow, - UniformSampler, + BaseFlow, + BaseTimeSampler, + ConstantNoiseFlow, + UniformSampler, ) from ott.solvers import was_solver from ott.solvers.linear import sinkhorn diff --git a/tests/neural/flow_matching_test.py b/tests/neural/flow_matching_test.py index 9529e8a62..a1135cf2d 100644 --- a/tests/neural/flow_matching_test.py +++ b/tests/neural/flow_matching_test.py @@ -20,11 +20,11 @@ from ott.neural.models.models import NeuralVectorField from ott.neural.solvers.flows 
import ( - BaseFlow, - BrownianNoiseFlow, - ConstantNoiseFlow, - OffsetUniformSampler, - UniformSampler, + BaseFlow, + BrownianNoiseFlow, + ConstantNoiseFlow, + OffsetUniformSampler, + UniformSampler, ) from ott.neural.solvers.otfm import OTFlowMatching from ott.solvers.linear import sinkhorn From 579852f93e6a5c5be9c6a632a8c005e62fb6dc8e Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 26 Nov 2023 15:56:46 +0100 Subject: [PATCH 020/186] [ci skip] add docs for OTFlowMatching and replace jnp.ndarray by jax.Array --- docs/tutorials/Hessians.ipynb | 2 +- docs/tutorials/Monge_Gap.ipynb | 6 +- docs/tutorials/One_Sinkhorn.ipynb | 4 +- .../tutorials/basic_ot_between_datasets.ipynb | 2 +- docs/tutorials/point_clouds.ipynb | 4 +- .../sinkhorn_divergence_gradient_flow.ipynb | 4 +- .../sparse_monge_displacements.ipynb | 2 +- src/ott/datasets.py | 4 +- src/ott/geometry/costs.py | 166 ++++++++-------- src/ott/geometry/geometry.py | 186 +++++++++--------- src/ott/geometry/graph.py | 44 ++--- src/ott/geometry/grid.py | 40 ++-- src/ott/geometry/low_rank.py | 56 +++--- src/ott/geometry/pointcloud.py | 108 +++++----- src/ott/geometry/segment.py | 33 ++-- src/ott/initializers/linear/initializers.py | 42 ++-- .../initializers/linear/initializers_lr.py | 70 +++---- .../initializers/quadratic/initializers.py | 4 +- src/ott/math/fixed_point_loop.py | 2 +- src/ott/math/matrix_square_root.py | 51 +++-- src/ott/math/unbalanced_functions.py | 23 +-- src/ott/math/utils.py | 22 +-- src/ott/neural/data/dataloaders.py | 2 +- src/ott/neural/models/conjugate_solvers.py | 13 +- src/ott/neural/models/layers.py | 10 +- src/ott/neural/models/models.py | 38 ++-- src/ott/neural/solvers/base_solver.py | 68 ++++--- src/ott/neural/solvers/flows.py | 25 +-- src/ott/neural/solvers/genot.py | 10 +- src/ott/neural/solvers/losses.py | 8 +- src/ott/neural/solvers/map_estimator.py | 28 +-- src/ott/neural/solvers/neuraldual.py | 50 ++--- src/ott/neural/solvers/otfm.py | 100 ++++++++-- 
src/ott/problems/linear/barycenter_problem.py | 20 +- src/ott/problems/linear/linear_problem.py | 13 +- src/ott/problems/linear/potentials.py | 34 ++-- src/ott/problems/quadratic/gw_barycenter.py | 50 +++-- src/ott/problems/quadratic/quadratic_costs.py | 3 +- .../problems/quadratic/quadratic_problem.py | 30 +-- src/ott/solvers/linear/_solve.py | 6 +- src/ott/solvers/linear/acceleration.py | 8 +- .../solvers/linear/continuous_barycenter.py | 18 +- src/ott/solvers/linear/discrete_barycenter.py | 18 +- .../linear/implicit_differentiation.py | 31 ++- src/ott/solvers/linear/lineax_implicit.py | 4 +- src/ott/solvers/linear/lr_utils.py | 42 ++-- src/ott/solvers/linear/sinkhorn.py | 86 ++++---- src/ott/solvers/linear/sinkhorn_lr.py | 129 ++++++------ src/ott/solvers/linear/univariate.py | 14 +- src/ott/solvers/quadratic/_solve.py | 6 +- .../solvers/quadratic/gromov_wasserstein.py | 16 +- .../quadratic/gromov_wasserstein_lr.py | 133 +++++++------ src/ott/solvers/quadratic/gw_barycenter.py | 32 ++- src/ott/tools/gaussian_mixture/fit_gmm.py | 34 ++-- .../tools/gaussian_mixture/fit_gmm_pair.py | 26 +-- src/ott/tools/gaussian_mixture/gaussian.py | 30 +-- .../gaussian_mixture/gaussian_mixture.py | 39 ++-- .../gaussian_mixture/gaussian_mixture_pair.py | 6 +- src/ott/tools/gaussian_mixture/linalg.py | 30 ++- .../tools/gaussian_mixture/probabilities.py | 12 +- src/ott/tools/gaussian_mixture/scale_tril.py | 36 ++-- src/ott/tools/k_means.py | 58 +++--- src/ott/tools/plot.py | 6 +- src/ott/tools/segment_sinkhorn.py | 24 +-- src/ott/tools/sinkhorn_divergence.py | 44 ++--- src/ott/tools/soft_sort.py | 78 ++++---- src/ott/types.py | 8 +- tests/conftest.py | 3 +- tests/geometry/costs_test.py | 2 +- tests/geometry/graph_test.py | 12 +- tests/geometry/low_rank_test.py | 2 +- tests/geometry/scaling_cost_test.py | 6 +- .../initializers/linear/sinkhorn_init_test.py | 8 +- tests/math/matrix_square_root_test.py | 8 +- tests/neural/conftest.py | 2 +- tests/neural/map_estimator_test.py | 6 +- 
tests/neural/meta_initializer_test.py | 10 +- .../linear/continuous_barycenter_test.py | 6 +- tests/solvers/linear/sinkhorn_diff_test.py | 34 ++-- tests/solvers/linear/sinkhorn_misc_test.py | 8 +- tests/solvers/quadratic/fgw_test.py | 12 +- tests/solvers/quadratic/gw_barycenter_test.py | 6 +- tests/solvers/quadratic/gw_test.py | 7 +- tests/solvers/quadratic/lower_bound_test.py | 4 +- tests/tools/k_means_test.py | 16 +- tests/tools/sinkhorn_divergence_test.py | 2 +- tests/tools/soft_sort_test.py | 4 +- 87 files changed, 1271 insertions(+), 1238 deletions(-) diff --git a/docs/tutorials/Hessians.ipynb b/docs/tutorials/Hessians.ipynb index 0e50ec959..f7c8b56d1 100644 --- a/docs/tutorials/Hessians.ipynb +++ b/docs/tutorials/Hessians.ipynb @@ -103,7 +103,7 @@ }, "outputs": [], "source": [ - "def loss(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True) -> float:\n", + "def loss(a: jax.Array, x: jax.Array, implicit: bool = True) -> float:\n", " return sinkhorn_divergence.sinkhorn_divergence(\n", " pointcloud.PointCloud,\n", " x,\n", diff --git a/docs/tutorials/Monge_Gap.ipynb b/docs/tutorials/Monge_Gap.ipynb index a1622a8c5..ac38d89b4 100644 --- a/docs/tutorials/Monge_Gap.ipynb +++ b/docs/tutorials/Monge_Gap.ipynb @@ -94,13 +94,13 @@ "\n", " name: Literal[\"moon\", \"s_curve\"]\n", " theta_rotation: float = 0.0\n", - " mean: Optional[jnp.ndarray] = None\n", + " mean: Optional[jax.Array] = None\n", " noise: float = 0.01\n", " scale: float = 1.0\n", " batch_size: int = 1024\n", " rng: Optional[jax.Array] = (None,)\n", "\n", - " def __iter__(self) -> Iterator[jnp.ndarray]:\n", + " def __iter__(self) -> Iterator[jax.Array]:\n", " \"\"\"Random sample generator from Gaussian mixture.\n", "\n", " Returns:\n", @@ -108,7 +108,7 @@ " \"\"\"\n", " return self._create_sample_generators()\n", "\n", - " def _create_sample_generators(self) -> Iterator[jnp.ndarray]:\n", + " def _create_sample_generators(self) -> Iterator[jax.Array]:\n", " rng = jax.random.PRNGKey(0) if self.rng is None 
else self.rng\n", "\n", " # define rotation matrix tp rotate samples\n", diff --git a/docs/tutorials/One_Sinkhorn.ipynb b/docs/tutorials/One_Sinkhorn.ipynb index 8c3d98e2e..9465441d8 100644 --- a/docs/tutorials/One_Sinkhorn.ipynb +++ b/docs/tutorials/One_Sinkhorn.ipynb @@ -555,9 +555,7 @@ }, "outputs": [], "source": [ - "def my_sinkhorn(\n", - " geom: geometry.Geometry, a: jnp.ndarray, b: jnp.ndarray, **kwargs\n", - "):\n", + "def my_sinkhorn(geom: geometry.Geometry, a: jax.Array, b: jax.Array, **kwargs):\n", " return linear.solve(\n", " geom, a, b, inner_iterations=1, max_iterations=10_000, **kwargs\n", " )" diff --git a/docs/tutorials/basic_ot_between_datasets.ipynb b/docs/tutorials/basic_ot_between_datasets.ipynb index 3cc61d403..b3c452d36 100644 --- a/docs/tutorials/basic_ot_between_datasets.ipynb +++ b/docs/tutorials/basic_ot_between_datasets.ipynb @@ -260,7 +260,7 @@ "metadata": {}, "outputs": [], "source": [ - "def reg_ot_cost(x: jnp.ndarray, y: jnp.ndarray) -> float:\n", + "def reg_ot_cost(x: jax.Array, y: jax.Array) -> float:\n", " geom = pointcloud.PointCloud(x, y)\n", " ot = linear.solve(geom)\n", " return ot.reg_ot_cost" diff --git a/docs/tutorials/point_clouds.ipynb b/docs/tutorials/point_clouds.ipynb index fd20ffc9a..e1b77edca 100644 --- a/docs/tutorials/point_clouds.ipynb +++ b/docs/tutorials/point_clouds.ipynb @@ -241,8 +241,8 @@ "outputs": [], "source": [ "def optimize(\n", - " x: jnp.ndarray,\n", - " y: jnp.ndarray,\n", + " x: jax.Array,\n", + " y: jax.Array,\n", " num_iter: int = 300,\n", " dump_every: int = 5,\n", " learning_rate: float = 0.2,\n", diff --git a/docs/tutorials/sinkhorn_divergence_gradient_flow.ipynb b/docs/tutorials/sinkhorn_divergence_gradient_flow.ipynb index c3b73039c..ff84f53b4 100644 --- a/docs/tutorials/sinkhorn_divergence_gradient_flow.ipynb +++ b/docs/tutorials/sinkhorn_divergence_gradient_flow.ipynb @@ -145,8 +145,8 @@ "outputs": [], "source": [ "def gradient_flow(\n", - " x: jnp.ndarray,\n", - " y: jnp.ndarray,\n", + " 
x: jax.Array,\n", + " y: jax.Array,\n", " cost_fn: callable,\n", " num_iter: int = 500,\n", " lr: float = 0.2,\n", diff --git a/docs/tutorials/sparse_monge_displacements.ipynb b/docs/tutorials/sparse_monge_displacements.ipynb index a21213703..8fcb49096 100644 --- a/docs/tutorials/sparse_monge_displacements.ipynb +++ b/docs/tutorials/sparse_monge_displacements.ipynb @@ -241,7 +241,7 @@ "solver = jax.jit(sinkhorn.Sinkhorn())\n", "\n", "\n", - "def entropic_map(x, y, cost_fn: costs.TICost) -> jnp.ndarray:\n", + "def entropic_map(x, y, cost_fn: costs.TICost) -> jax.Array:\n", " geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn)\n", " output = solver(linear_problem.LinearProblem(geom))\n", " dual_potentials = output.to_dual_potentials()\n", diff --git a/src/ott/datasets.py b/src/ott/datasets.py index 4d67c5976..12dda06bb 100644 --- a/src/ott/datasets.py +++ b/src/ott/datasets.py @@ -32,8 +32,8 @@ class Dataset(NamedTuple): source_iter: loader for the source measure target_iter: loader for the target measure """ - source_iter: Iterator[jnp.ndarray] - target_iter: Iterator[jnp.ndarray] + source_iter: Iterator[jax.Array] + target_iter: Iterator[jax.Array] @dataclasses.dataclass diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index 9f1a6c3a0..aeaf89b72 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -56,10 +56,10 @@ class CostFn(abc.ABC): """ # no norm function created by default. - norm: Optional[Callable[[jnp.ndarray], Union[float, jnp.ndarray]]] = None + norm: Optional[Callable[[jax.Array], Union[float, jax.Array]]] = None @abc.abstractmethod - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jax.Array, y: jax.Array) -> float: """Compute cost between :math:`x` and :math:`y`. Args: @@ -70,8 +70,8 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: The cost. 
""" - def barycenter(self, weights: jnp.ndarray, - xs: jnp.ndarray) -> Tuple[jnp.ndarray, Any]: + def barycenter(self, weights: jax.Array, + xs: jax.Array) -> Tuple[jax.Array, Any]: """Barycentric operator. Args: @@ -86,7 +86,7 @@ def barycenter(self, weights: jnp.ndarray, raise NotImplementedError("Barycenter is not implemented.") @classmethod - def _padder(cls, dim: int) -> jnp.ndarray: + def _padder(cls, dim: int) -> jax.Array: """Create a padding vector of adequate dimension, well-suited to a cost. Args: @@ -97,7 +97,7 @@ def _padder(cls, dim: int) -> jnp.ndarray: """ return jnp.zeros((1, dim)) - def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def __call__(self, x: jax.Array, y: jax.Array) -> float: """Compute cost between :math:`x` and :math:`y`. Args: @@ -113,7 +113,7 @@ def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float: return cost return cost + self.norm(x) + self.norm(y) - def all_pairs(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + def all_pairs(self, x: jax.Array, y: jax.Array) -> jax.Array: """Compute matrix of all pairwise costs, including the :attr:`norms `. Args: @@ -125,7 +125,7 @@ def all_pairs(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """ return jax.vmap(lambda x_: jax.vmap(lambda y_: self(x_, y_))(y))(x) - def all_pairs_pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + def all_pairs_pairwise(self, x: jax.Array, y: jax.Array) -> jax.Array: """Compute matrix of all pairwise costs, excluding the :attr:`norms `. Args: @@ -163,7 +163,7 @@ class TICost(CostFn): """ @abc.abstractmethod - def h(self, z: jnp.ndarray) -> float: + def h(self, z: jax.Array) -> float: """TI function acting on difference of :math:`x-y` to output cost. Args: @@ -173,11 +173,11 @@ def h(self, z: jnp.ndarray) -> float: The cost. 
""" - def h_legendre(self, z: jnp.ndarray) -> float: + def h_legendre(self, z: jax.Array) -> float: """Legendre transform of :func:`h` when it is convex.""" raise NotImplementedError("Legendre transform of `h` is not implemented.") - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jax.Array, y: jax.Array) -> float: """Compute cost as evaluation of :func:`h` on :math:`x-y`.""" return self.h(x - y) @@ -198,10 +198,10 @@ def __init__(self, p: float): self.p = p self.q = 1.0 / (1.0 - (1.0 / p)) if p > 1.0 else jnp.inf - def h(self, z: jnp.ndarray) -> float: # noqa: D102 + def h(self, z: jax.Array) -> float: # noqa: D102 return 0.5 * mu.norm(z, self.p) ** 2 - def h_legendre(self, z: jnp.ndarray) -> float: + def h_legendre(self, z: jax.Array) -> float: """Legendre transform of :func:`h`. For details on the derivation, see e.g., :cite:`boyd:04`, p. 93/94. @@ -234,10 +234,10 @@ def __init__(self, p: float): self.p = p self.q = 1.0 / (1.0 - (1.0 / p)) if p > 1.0 else jnp.inf - def h(self, z: jnp.ndarray) -> float: # noqa: D102 + def h(self, z: jax.Array) -> float: # noqa: D102 return mu.norm(z, self.p) ** self.p / self.p - def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 + def h_legendre(self, z: jax.Array) -> float: # noqa: D102 # not defined for `p=1` return mu.norm(z, self.q) ** self.q / self.q @@ -260,7 +260,7 @@ class Euclidean(CostFn): because the function is not strictly convex (it is linear on rays). """ - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jax.Array, y: jax.Array) -> float: """Compute Euclidean norm using custom jvp implementation. Here we use a custom jvp implementation for the norm that does not yield @@ -277,22 +277,22 @@ class SqEuclidean(TICost): Implemented as a translation invariant cost, :math:`h(z) = \|z\|^2`. 
""" - def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]: + def norm(self, x: jax.Array) -> Union[float, jax.Array]: """Compute squared Euclidean norm for vector.""" return jnp.sum(x ** 2, axis=-1) - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jax.Array, y: jax.Array) -> float: """Compute minus twice the dot-product between vectors.""" return -2. * jnp.vdot(x, y) - def h(self, z: jnp.ndarray) -> float: # noqa: D102 + def h(self, z: jax.Array) -> float: # noqa: D102 return jnp.sum(z ** 2) - def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 + def h_legendre(self, z: jax.Array) -> float: # noqa: D102 return 0.25 * jnp.sum(z ** 2) - def barycenter(self, weights: jnp.ndarray, - xs: jnp.ndarray) -> Tuple[jnp.ndarray, Any]: + def barycenter(self, weights: jax.Array, + xs: jax.Array) -> Tuple[jax.Array, Any]: """Output barycenter of vectors when using squared-Euclidean distance.""" return jnp.average(xs, weights=weights, axis=0), None @@ -309,7 +309,7 @@ def __init__(self, ridge: float = 1e-8): super().__init__() self._ridge = ridge - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jax.Array, y: jax.Array) -> float: """Cosine distance between vectors, denominator regularized with ridge.""" ridge = self._ridge x_norm = jnp.linalg.norm(x, axis=-1) @@ -318,7 +318,7 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: return 1.0 - cosine_similarity @classmethod - def _padder(cls, dim: int) -> jnp.ndarray: + def _padder(cls, dim: int) -> jax.Array: return jnp.ones((1, dim)) @@ -341,7 +341,7 @@ class RegTICost(TICost, abc.ABC): def __init__( self, scaling_reg: float = 1.0, - matrix: Optional[jnp.ndarray] = None, + matrix: Optional[jax.Array] = None, orthogonal: bool = False, ): super().__init__() @@ -350,16 +350,16 @@ def __init__( self.orthogonal = orthogonal @abc.abstractmethod - def _reg(self, z: jnp.ndarray) -> float: + def _reg(self, z: jax.Array) -> float: 
"""Regularization function.""" - def _reg_stiefel_orth(self, z: jnp.ndarray) -> float: + def _reg_stiefel_orth(self, z: jax.Array) -> float: raise NotImplementedError( "Regularization in the orthogonal " "subspace is not implemented." ) - def reg(self, z: jnp.ndarray) -> float: + def reg(self, z: jax.Array) -> float: """Regularization function. Args: @@ -374,7 +374,7 @@ def reg(self, z: jnp.ndarray) -> float: return self._reg_stiefel_orth(z) return self._reg(self.matrix @ z) - def prox_reg(self, z: jnp.ndarray, tau: float = 1.0) -> jnp.ndarray: + def prox_reg(self, z: jax.Array, tau: float = 1.0) -> jax.Array: """Proximal operator of :meth:`reg`. Args: @@ -391,26 +391,24 @@ def prox_reg(self, z: jnp.ndarray, tau: float = 1.0) -> jnp.ndarray: return self._prox_reg_stiefel_orth(z, tau) return self._prox_reg_stiefel(z, tau) - def _prox_reg(self, z: jnp.ndarray, tau: float = 1.0) -> jnp.ndarray: + def _prox_reg(self, z: jax.Array, tau: float = 1.0) -> jax.Array: raise NotImplementedError("Proximal operator is not implemented.") - def _prox_reg_stiefel_orth( - self, z: jnp.ndarray, tau: float = 1.0 - ) -> jnp.ndarray: + def _prox_reg_stiefel_orth(self, z: jax.Array, tau: float = 1.0) -> jax.Array: - def orth(x: jnp.ndarray) -> jnp.ndarray: + def orth(x: jax.Array) -> jax.Array: return x - self.matrix.T @ (self.matrix @ x) # assumes `matrix` has orthogonal rows tmp = orth(z) return z - orth(tmp - self._prox_reg(tmp, tau)) - def _prox_reg_stiefel(self, z: jnp.ndarray, tau: float) -> jnp.ndarray: + def _prox_reg_stiefel(self, z: jax.Array, tau: float) -> jax.Array: # assumes `matrix` has orthogonal rows tmp = self.matrix @ z return z - self.matrix.T @ (tmp - self._prox_reg(tmp, tau)) - def prox_legendre_reg(self, z: jnp.ndarray, tau: float = 1.0) -> jnp.ndarray: + def prox_legendre_reg(self, z: jax.Array, tau: float = 1.0) -> jax.Array: r"""Proximal operator of the Legendre transform of :meth:`reg`. 
Uses Moreau's decomposition: @@ -428,16 +426,16 @@ def prox_legendre_reg(self, z: jnp.ndarray, tau: float = 1.0) -> jnp.ndarray: """ return z - tau * self.prox_reg(z / tau, 1.0 / tau) - def h(self, z: jnp.ndarray) -> float: # noqa: D102 + def h(self, z: jax.Array) -> float: # noqa: D102 out = 0.5 * jnp.sum(z ** 2) return out + self.scaling_reg * self.reg(z) - def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 + def h_legendre(self, z: jax.Array) -> float: # noqa: D102 q = jax.lax.stop_gradient(self.prox_reg(z)) return jnp.sum(q * z) - self.h(q) - def h_transform(self, f: Callable[[jnp.ndarray], float], - **kwargs: Any) -> Callable[[jnp.ndarray], float]: + def h_transform(self, f: Callable[[jax.Array], float], + **kwargs: Any) -> Callable[[jax.Array], float]: r"""Compute the h-transform of a concave function. Return a callable :math:`f_h` defined as: @@ -467,18 +465,16 @@ def h_transform(self, f: Callable[[jnp.ndarray], float], The h-transform of ``f``. """ - def minus_f(z: jnp.ndarray, x: jnp.ndarray) -> float: + def minus_f(z: jax.Array, x: jax.Array) -> float: return -f(x - z) - def prox( - x: jnp.ndarray, scaling_reg: float, scaling_h: float - ) -> jnp.ndarray: + def prox(x: jax.Array, scaling_reg: float, scaling_h: float) -> jax.Array: # https://web.stanford.edu/~boyd/papers/pdf/prox_algs.pdf 2.2. tmp = 1.0 / (1.0 + scaling_h) tau = scaling_reg * scaling_h * tmp return self.prox_reg(x * tmp, tau) - def f_h(x: jnp.ndarray) -> float: + def f_h(x: jax.Array) -> float: pg = jaxopt.ProximalGradient(fun=minus_f, prox=prox, **kwargs) pg_run = pg.run(x, self.scaling_reg, x=x) pg_sol = jax.lax.stop_gradient(pg_run.params) @@ -508,10 +504,10 @@ class ElasticL1(RegTICost): to promote displacements in the span of ``matrix``. 
""" - def _reg(self, z: jnp.ndarray) -> float: # noqa: D102 + def _reg(self, z: jax.Array) -> float: # noqa: D102 return jnp.linalg.norm(z, ord=1) - def _prox_reg(self, z: jnp.ndarray, tau: float = 1.0) -> jnp.ndarray: + def _prox_reg(self, z: jax.Array, tau: float = 1.0) -> jax.Array: return jnp.sign(z) * jax.nn.relu(jnp.abs(z) - tau * self.scaling_reg) @@ -529,19 +525,17 @@ class ElasticL2(RegTICost): to promote displacements in the span of ``matrix``. """ - def _reg(self, z: jnp.ndarray) -> float: # noqa: D102 + def _reg(self, z: jax.Array) -> float: # noqa: D102 return 0.5 * jnp.sum(z ** 2) - def _reg_stiefel_orth(self, z: jnp.ndarray) -> float: + def _reg_stiefel_orth(self, z: jax.Array) -> float: # Pythagorean identity return self._reg(z) - self._reg(self.matrix @ z) - def _prox_reg(self, z: jnp.ndarray, tau: float = 1.0) -> jnp.ndarray: + def _prox_reg(self, z: jax.Array, tau: float = 1.0) -> jax.Array: return z / (1.0 + tau * self.scaling_reg) - def _prox_reg_stiefel_orth( - self, z: jnp.ndarray, tau: float = 1.0 - ) -> jnp.ndarray: + def _prox_reg_stiefel_orth(self, z: jax.Array, tau: float = 1.0) -> jax.Array: out = z + tau * self.scaling_reg * self.matrix.T @ (self.matrix @ z) return self._prox_reg(out, tau) @@ -565,7 +559,7 @@ class ElasticSTVS(RegTICost): to promote displacements in the span of ``matrix``. 
""" # noqa: D205,E501 - def _reg(self, z: jnp.ndarray) -> float: # noqa: D102 + def _reg(self, z: jax.Array) -> float: # noqa: D102 u = jnp.arcsinh(jnp.abs(z) / (2 * self.scaling_reg)) out = u - 0.5 * jnp.exp(-2.0 * u) # Lemma 2.1 of `schreck:15`; @@ -573,8 +567,8 @@ def _reg(self, z: jnp.ndarray) -> float: # noqa: D102 return self.scaling_reg * jnp.sum(out + 0.5) # make positive def _prox_reg( # noqa: D102 - self, z: jnp.ndarray, tau: float = 1.0 - ) -> jnp.ndarray: + self, z: jax.Array, tau: float = 1.0 + ) -> jax.Array: tmp = 1.0 - (self.scaling_reg * tau / (jnp.abs(z) + 1e-12)) ** 2 return jax.nn.relu(tmp) * z @@ -600,7 +594,7 @@ def __init__(self, k: int, *args, **kwargs: Any): super().__init__(*args, **kwargs) self.k = k - def _reg(self, z: jnp.ndarray) -> float: # noqa: D102 + def _reg(self, z: jax.Array) -> float: # noqa: D102 # Prop 2.1 in :cite:`argyriou:12` k = self.k top_w = jax.lax.top_k(jnp.abs(z), k)[0] # Fetch largest k values @@ -621,15 +615,14 @@ def _reg(self, z: jnp.ndarray) -> float: # noqa: D102 return 0.5 * (s + (r + 1) * cesaro[r] ** 2) - def prox_reg(self, z: jnp.ndarray, tau: float = 1.0) -> float: # noqa: D102 + def prox_reg(self, z: jax.Array, tau: float = 1.0) -> float: # noqa: D102 @functools.partial(jax.vmap, in_axes=[0, None, None]) - def find_indices(r: int, l: jnp.ndarray, - z: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + def find_indices(r: int, l: jax.Array, + z: jax.Array) -> Tuple[jax.Array, jax.Array]: @functools.partial(jax.vmap, in_axes=[None, 0, None]) - def inner(r: int, l: int, - z: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + def inner(r: int, l: int, z: jax.Array) -> Tuple[jax.Array, jax.Array]: i = k - r - 1 res = jnp.sum(z * ((i <= ixs) & (ixs < l))) res /= l - k + (beta + 1) * r + beta + 1 @@ -692,14 +685,14 @@ def __init__(self, dimension: int, sqrtm_kw: Optional[Dict[str, Any]] = None): self._dimension = dimension self._sqrtm_kw = {} if sqrtm_kw is None else sqrtm_kw - def norm(self, x: jnp.ndarray) -> 
jnp.ndarray: + def norm(self, x: jax.Array) -> jax.Array: """Compute norm of Gaussian, sq. 2-norm of mean + trace of covariance.""" mean, cov = x_to_means_and_covs(x, self._dimension) norm = jnp.sum(mean ** 2, axis=-1) norm += jnp.trace(cov, axis1=-2, axis2=-1) return norm - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jax.Array, y: jax.Array) -> float: """Compute - 2 x Bures dot-product.""" mean_x, cov_x = x_to_means_and_covs(x, self._dimension) mean_y, cov_y = x_to_means_and_covs(y, self._dimension) @@ -713,12 +706,12 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: def covariance_fixpoint_iter( self, - covs: jnp.ndarray, - weights: jnp.ndarray, + covs: jax.Array, + weights: jax.Array, tolerance: float = 1e-4, sqrtm_kw: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> jnp.ndarray: + ) -> jax.Array: """Iterate fix-point updates to compute barycenter of Gaussians. Args: @@ -744,8 +737,8 @@ def covariance_fixpoint_iter( @functools.partial(jax.vmap, in_axes=[None, 0, 0]) def scale_covariances( - cov_sqrt: jnp.ndarray, cov: jnp.ndarray, weight: jnp.ndarray - ) -> jnp.ndarray: + cov_sqrt: jax.Array, cov: jax.Array, weight: jax.Array + ) -> jax.Array: """Rescale covariance in barycenter step.""" return weight * matrix_square_root.sqrtm_only((cov_sqrt @ cov) @ cov_sqrt, **sqrtm_kw) @@ -757,8 +750,8 @@ def cond_fn(iteration: int, constants: Tuple[Any, ...], state) -> bool: def body_fn( iteration: int, constants: Tuple[Any, ...], - state: Tuple[jnp.ndarray, float], compute_error: bool - ) -> Tuple[jnp.ndarray, float]: + state: Tuple[jax.Array, float], compute_error: bool + ) -> Tuple[jax.Array, float]: del constants, compute_error cov, diffs = state cov_sqrt, cov_inv_sqrt, _ = matrix_square_root.sqrtm(cov, **sqrtm_kw) @@ -770,7 +763,7 @@ def body_fn( diffs = diffs.at[iteration // inner_iterations].set(diff) return next_cov, diffs - def init_state() -> Tuple[jnp.ndarray, float]: + def init_state() -> 
Tuple[jax.Array, float]: cov_init = jnp.eye(self._dimension) diffs = -jnp.ones( (np.ceil(max_iterations / inner_iterations).astype(int),), @@ -791,12 +784,12 @@ def init_state() -> Tuple[jnp.ndarray, float]: def barycenter( self, - weights: jnp.ndarray, - xs: jnp.ndarray, + weights: jax.Array, + xs: jax.Array, tolerance: float = 1e-4, sqrtm_kw: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array]: """Compute the Bures barycenter of weighted Gaussian distributions. Implements the fixed point approach proposed in :cite:`alvarez-esteban:16` @@ -842,7 +835,7 @@ def barycenter( return mean_and_cov_to_x(mu_bary, cov_bary, self._dimension), diffs @classmethod - def _padder(cls, dim: int) -> jnp.ndarray: + def _padder(cls, dim: int) -> jax.Array: dimension = int((-1 + math.sqrt(1 + 4 * dim)) / 2) padding = mean_and_cov_to_x( jnp.zeros((dimension,)), jnp.eye(dimension), dimension @@ -885,7 +878,7 @@ def __init__( self._gamma = gamma self._sqrtm_kw = kwargs - def norm(self, x: jnp.ndarray) -> jnp.ndarray: + def norm(self, x: jax.Array) -> jax.Array: """Compute norm of Gaussian for unbalanced Bures. Args: @@ -898,7 +891,7 @@ def norm(self, x: jnp.ndarray) -> jnp.ndarray: """ return self._gamma * x[..., 0] - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jax.Array, y: jax.Array) -> float: """Compute dot-product for unbalanced Bures. 
Args: @@ -992,18 +985,17 @@ def __init__( self.ground_cost = SqEuclidean() if ground_cost is None else ground_cost self.debiased = debiased - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: # noqa: D102 + def pairwise(self, x: jax.Array, y: jax.Array) -> float: # noqa: D102 c_xy = self._soft_dtw(x, y) if self.debiased: return c_xy - 0.5 * (self._soft_dtw(x, x) + self._soft_dtw(y, y)) return c_xy - def _soft_dtw(self, t1: jnp.ndarray, t2: jnp.ndarray) -> float: + def _soft_dtw(self, t1: jax.Array, t2: jax.Array) -> float: def body( - carry: Tuple[jnp.ndarray, jnp.ndarray], - current_antidiagonal: jnp.ndarray - ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]: + carry: Tuple[jax.Array, jax.Array], current_antidiagonal: jax.Array + ) -> Tuple[Tuple[jax.Array, jax.Array], jax.Array]: # modified from: https://github.com/khdlr/softdtw_jax two_ago, one_ago = carry @@ -1050,8 +1042,8 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls(*children, **aux_data) -def x_to_means_and_covs(x: jnp.ndarray, - dimension: int) -> Tuple[jnp.ndarray, jnp.ndarray]: +def x_to_means_and_covs(x: jax.Array, + dimension: int) -> Tuple[jax.Array, jax.Array]: """Extract means and covariance matrices of Gaussians from raveled vector. 
Args: @@ -1071,8 +1063,8 @@ def x_to_means_and_covs(x: jnp.ndarray, def mean_and_cov_to_x( - mean: jnp.ndarray, covariance: jnp.ndarray, dimension: int -) -> jnp.ndarray: + mean: jax.Array, covariance: jax.Array, dimension: int +) -> jax.Array: """Ravel a Gaussian's mean and covariance matrix to d(1 + d) vector.""" return jnp.concatenate( (mean, jnp.reshape(covariance, (dimension * dimension))) diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index f953bf38c..5d3db3ee6 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -79,14 +79,14 @@ class Geometry: def __init__( self, - cost_matrix: Optional[jnp.ndarray] = None, - kernel_matrix: Optional[jnp.ndarray] = None, + cost_matrix: Optional[jax.Array] = None, + kernel_matrix: Optional[jax.Array] = None, epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None, relative_epsilon: Optional[bool] = None, scale_cost: Union[bool, int, float, Literal["mean", "max_cost", "median"]] = 1.0, - src_mask: Optional[jnp.ndarray] = None, - tgt_mask: Optional[jnp.ndarray] = None, + src_mask: Optional[jax.Array] = None, + tgt_mask: Optional[jax.Array] = None, ): self._cost_matrix = cost_matrix self._kernel_matrix = kernel_matrix @@ -107,7 +107,7 @@ def cost_rank(self) -> Optional[int]: """Output rank of cost matrix, if any was provided.""" @property - def cost_matrix(self) -> jnp.ndarray: + def cost_matrix(self) -> jax.Array: """Cost matrix, recomputed from kernel if only kernel was specified.""" if self._cost_matrix is None: # If no epsilon was passed on to the geometry, then assume it is one by @@ -131,7 +131,7 @@ def mean_cost_matrix(self) -> float: return jnp.sum(tmp * self._m_normed_ones) @property - def kernel_matrix(self) -> jnp.ndarray: + def kernel_matrix(self) -> jax.Array: """Kernel matrix. Either provided by user or recomputed from :attr:`cost_matrix`. 
@@ -245,12 +245,12 @@ def copy_epsilon(self, other: "Geometry") -> "Geometry": def apply_lse_kernel( self, - f: jnp.ndarray, - g: jnp.ndarray, + f: jax.Array, + g: jax.Array, eps: float, - vec: jnp.ndarray = None, + vec: jax.Array = None, axis: int = 0 - ) -> jnp.ndarray: + ) -> jax.Array: r"""Apply :attr:`kernel_matrix` in log domain. This function applies the ground geometry's kernel in log domain, using @@ -267,10 +267,10 @@ def apply_lse_kernel( f and g in iterations 1 & 2 respectively. Args: - f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix - g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix + f: jax.Array [num_a,] , potential of size num_rows of cost_matrix + g: jax.Array [num_b,] , potential of size num_cols of cost_matrix eps: float, regularization strength - vec: jnp.ndarray [num_a or num_b,] , when not None, this has the effect of + vec: jax.Array [num_a or num_b,] , when not None, this has the effect of doing log-Kernel computations with an addition elementwise multiplication of exp(g / eps) by a vector. This is carried out by adding weights to the log-sum-exp function, and needs to handle signs @@ -278,7 +278,7 @@ def apply_lse_kernel( axis: summing over axis 0 when doing (2), or over axis 1 when doing (1) Returns: - A jnp.ndarray corresponding to output above, depending on axis. + A jax.Array corresponding to output above, depending on axis. """ w_res, w_sgn = self._softmax(f, g, eps, vec, axis) remove = f if axis == 1 else g @@ -286,20 +286,20 @@ def apply_lse_kernel( def apply_kernel( self, - scaling: jnp.ndarray, + scaling: jax.Array, eps: Optional[float] = None, axis: int = 0, - ) -> jnp.ndarray: + ) -> jax.Array: """Apply :attr:`kernel_matrix` on positive scaling vector. Args: - scaling: jnp.ndarray [num_a or num_b] , scaling of size num_rows or + scaling: jax.Array [num_a or num_b] , scaling of size num_rows or num_cols of kernel_matrix eps: passed for consistency, not used yet. 
axis: standard kernel product if axis is 1, transpose if 0. Returns: - a jnp.ndarray corresponding to output above, depending on axis. + a jax.Array corresponding to output above, depending on axis. """ if eps is None: kernel = self.kernel_matrix @@ -311,10 +311,10 @@ def apply_kernel( def marginal_from_potentials( self, - f: jnp.ndarray, - g: jnp.ndarray, + f: jax.Array, + g: jax.Array, axis: int = 0, - ) -> jnp.ndarray: + ) -> jax.Array: """Output marginal of transportation matrix from potentials. This applies first lse kernel in the standard way, removes the @@ -323,8 +323,8 @@ def marginal_from_potentials( by potentials. Args: - f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix - g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix + f: jax.Array [num_a,] , potential of size num_rows of cost_matrix + g: jax.Array [num_b,] , potential of size num_cols of cost_matrix axis: axis along which to integrate, returns marginal on other axis. Returns: @@ -336,23 +336,19 @@ def marginal_from_potentials( def marginal_from_scalings( self, - u: jnp.ndarray, - v: jnp.ndarray, + u: jax.Array, + v: jax.Array, axis: int = 0, - ) -> jnp.ndarray: + ) -> jax.Array: """Output marginal of transportation matrix from scalings.""" u, v = (v, u) if axis == 0 else (u, v) return u * self.apply_kernel(v, eps=self.epsilon, axis=axis) - def transport_from_potentials( - self, f: jnp.ndarray, g: jnp.ndarray - ) -> jnp.ndarray: + def transport_from_potentials(self, f: jax.Array, g: jax.Array) -> jax.Array: """Output transport matrix from potentials.""" return jnp.exp(self._center(f, g) / self.epsilon) - def transport_from_scalings( - self, u: jnp.ndarray, v: jnp.ndarray - ) -> jnp.ndarray: + def transport_from_scalings(self, u: jax.Array, v: jax.Array) -> jax.Array: """Output transport matrix from pair of scalings.""" return self.kernel_matrix * u[:, jnp.newaxis] * v[jnp.newaxis, :] @@ -361,17 +357,17 @@ def transport_from_scalings( def update_potential( self, - 
f: jnp.ndarray, - g: jnp.ndarray, - log_marginal: jnp.ndarray, + f: jax.Array, + g: jax.Array, + log_marginal: jax.Array, iteration: Optional[int] = None, axis: int = 0, - ) -> jnp.ndarray: + ) -> jax.Array: """Carry out one Sinkhorn update for potentials, i.e. in log space. Args: - f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix - g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix + f: jax.Array [num_a,] , potential of size num_rows of cost_matrix + g: jax.Array [num_b,] , potential of size num_cols of cost_matrix log_marginal: targeted marginal iteration: used to compute epsilon from schedule, if provided. axis: axis along which the update should be carried out. @@ -385,15 +381,15 @@ def update_potential( def update_scaling( self, - scaling: jnp.ndarray, - marginal: jnp.ndarray, + scaling: jax.Array, + marginal: jax.Array, iteration: Optional[int] = None, axis: int = 0, - ) -> jnp.ndarray: + ) -> jax.Array: """Carry out one Sinkhorn update for scalings, using kernel directly. Args: - scaling: jnp.ndarray of num_a or num_b positive values. + scaling: jax.Array of num_a or num_b positive values. marginal: targeted marginal iteration: used to compute epsilon from schedule, if provided. axis: axis along which the update should be carried out. 
@@ -406,13 +402,13 @@ def update_scaling( return marginal / jnp.where(app_kernel > 0, app_kernel, 1.0) # Helper functions - def _center(self, f: jnp.ndarray, g: jnp.ndarray) -> jnp.ndarray: + def _center(self, f: jax.Array, g: jax.Array) -> jax.Array: return f[:, jnp.newaxis] + g[jnp.newaxis, :] - self.cost_matrix def _softmax( - self, f: jnp.ndarray, g: jnp.ndarray, eps: float, - vec: Optional[jnp.ndarray], axis: int - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + self, f: jax.Array, g: jax.Array, eps: float, vec: Optional[jax.Array], + axis: int + ) -> Tuple[jax.Array, jax.Array]: """Apply softmax row or column wise, weighted by vec.""" if vec is not None: if axis == 0: @@ -429,8 +425,8 @@ def _softmax( @functools.partial(jax.vmap, in_axes=[None, None, None, 0, None]) def _apply_transport_from_potentials( - self, f: jnp.ndarray, g: jnp.ndarray, vec: jnp.ndarray, axis: int - ) -> jnp.ndarray: + self, f: jax.Array, g: jax.Array, vec: jax.Array, axis: int + ) -> jax.Array: """Apply lse_kernel to arbitrary vector while keeping track of signs.""" lse_res, lse_sgn = self.apply_lse_kernel( f, g, self.epsilon, vec=vec, axis=axis @@ -441,11 +437,11 @@ def _apply_transport_from_potentials( # wrapper to allow default option for axis. def apply_transport_from_potentials( self, - f: jnp.ndarray, - g: jnp.ndarray, - vec: jnp.ndarray, + f: jax.Array, + g: jax.Array, + vec: jax.Array, axis: int = 0 - ) -> jnp.ndarray: + ) -> jax.Array: """Apply transport matrix computed from potentials to a (batched) vec. This approach does not instantiate the transport matrix itself, but uses @@ -456,9 +452,9 @@ def apply_transport_from_potentials( (b=..., return_sign=True) optional parameters of logsumexp. 
Args: - f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix - g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix - vec: jnp.ndarray [batch, num_a or num_b], vector that will be multiplied + f: jax.Array [num_a,] , potential of size num_rows of cost_matrix + g: jax.Array [num_b,] , potential of size num_cols of cost_matrix + vec: jax.Array [batch, num_a or num_b], vector that will be multiplied by transport matrix corresponding to potentials f, g, and geom. axis: axis to differentiate left (0) or right (1) multiply. @@ -473,7 +469,7 @@ def apply_transport_from_potentials( @functools.partial(jax.vmap, in_axes=[None, None, None, 0, None]) def _apply_transport_from_scalings( - self, u: jnp.ndarray, v: jnp.ndarray, vec: jnp.ndarray, axis: int + self, u: jax.Array, v: jax.Array, vec: jax.Array, axis: int ): u, v = (u, v * vec) if axis == 1 else (v, u * vec) return u * self.apply_kernel(v, eps=self.epsilon, axis=axis) @@ -481,20 +477,20 @@ def _apply_transport_from_scalings( # wrapper to allow default option for axis def apply_transport_from_scalings( self, - u: jnp.ndarray, - v: jnp.ndarray, - vec: jnp.ndarray, + u: jax.Array, + v: jax.Array, + vec: jax.Array, axis: int = 0 - ) -> jnp.ndarray: + ) -> jax.Array: """Apply transport matrix computed from scalings to a (batched) vec. This approach does not instantiate the transport matrix itself, but relies instead on the apply_kernel function. Args: - u: jnp.ndarray [num_a,] , scaling of size num_rows of cost_matrix - v: jnp.ndarray [num_b,] , scaling of size num_cols of cost_matrix - vec: jnp.ndarray [batch, num_a or num_b], vector that will be multiplied + u: jax.Array [num_a,] , scaling of size num_rows of cost_matrix + v: jax.Array [num_b,] , scaling of size num_cols of cost_matrix + vec: jax.Array [batch, num_a or num_b], vector that will be multiplied by transport matrix corresponding to scalings u, v, and geom. axis: axis to differentiate left (0) or right (1) multiply. 
@@ -507,7 +503,7 @@ def apply_transport_from_scalings( )[0, :] return self._apply_transport_from_scalings(u, v, vec, axis) - def potential_from_scaling(self, scaling: jnp.ndarray) -> jnp.ndarray: + def potential_from_scaling(self, scaling: jax.Array) -> jax.Array: """Compute dual potential vector from scaling vector. Args: @@ -518,7 +514,7 @@ def potential_from_scaling(self, scaling: jnp.ndarray) -> jnp.ndarray: """ return self.epsilon * jnp.log(scaling) - def scaling_from_potential(self, potential: jnp.ndarray) -> jnp.ndarray: + def scaling_from_potential(self, potential: jax.Array) -> jax.Array: """Compute scaling vector from dual potential. Args: @@ -532,7 +528,7 @@ def scaling_from_potential(self, potential: jnp.ndarray) -> jnp.ndarray: finite, jnp.exp(jnp.where(finite, potential / self.epsilon, 0.0)), 0.0 ) - def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray: + def apply_square_cost(self, arr: jax.Array, axis: int = 0) -> jax.Array: """Apply elementwise-square of cost matrix to array (vector or matrix). This function applies the ground geometry's cost matrix, to perform either @@ -553,11 +549,11 @@ def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray: def apply_cost( self, - arr: jnp.ndarray, + arr: jax.Array, axis: int = 0, - fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, + fn: Optional[Callable[[jax.Array], jax.Array]] = None, **kwargs: Any - ) -> jnp.ndarray: + ) -> jax.Array: """Apply :attr:`cost_matrix` to array (vector or matrix). This function applies the ground geometry's cost matrix, to perform either @@ -566,7 +562,7 @@ def apply_cost( where C is [num_a, num_b] Args: - arr: jnp.ndarray [num_a or num_b, p], vector that will be multiplied by + arr: jax.Array [num_a or num_b, p], vector that will be multiplied by the cost matrix. 
axis: standard cost matrix if axis=1, transpose if 0 fn: function to apply to cost matrix element-wise before the dot product @@ -583,21 +579,21 @@ def apply_cost( def _apply_cost_to_vec( self, - vec: jnp.ndarray, + vec: jax.Array, axis: int = 0, fn=None, **_: Any, - ) -> jnp.ndarray: + ) -> jax.Array: """Apply ``[num_a, num_b]`` fn(cost) (or transpose) to vector. Args: - vec: jnp.ndarray [num_a,] ([num_b,] if axis=1) vector + vec: jax.Array [num_a,] ([num_b,] if axis=1) vector axis: axis on which the reduction is done. fn: function optionally applied to cost matrix element-wise, before the doc product Returns: - A jnp.ndarray corresponding to cost x vector + A jax.Array corresponding to cost x vector """ matrix = self.cost_matrix.T if axis == 0 else self.cost_matrix matrix = fn(matrix) if fn is not None else matrix @@ -718,7 +714,7 @@ def to_LRCGeometry( ) def subset( - self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray], + self, src_ixs: Optional[jax.Array], tgt_ixs: Optional[jax.Array], **kwargs: Any ) -> "Geometry": """Subset rows or columns of a geometry. @@ -733,10 +729,10 @@ def subset( """ def subset_fn( - arr: Optional[jnp.ndarray], - src_ixs: Optional[jnp.ndarray], - tgt_ixs: Optional[jnp.ndarray], - ) -> Optional[jnp.ndarray]: + arr: Optional[jax.Array], + src_ixs: Optional[jax.Array], + tgt_ixs: Optional[jax.Array], + ) -> Optional[jax.Array]: if arr is None: return None if src_ixs is not None: @@ -755,8 +751,8 @@ def subset_fn( def mask( self, - src_mask: Optional[jnp.ndarray], - tgt_mask: Optional[jnp.ndarray], + src_mask: Optional[jax.Array], + tgt_mask: Optional[jax.Array], mask_value: float = 0., ) -> "Geometry": """Mask rows or columns of a geometry. 
@@ -780,10 +776,10 @@ def mask( """ def mask_fn( - arr: Optional[jnp.ndarray], - src_mask: Optional[jnp.ndarray], - tgt_mask: Optional[jnp.ndarray], - ) -> Optional[jnp.ndarray]: + arr: Optional[jax.Array], + src_mask: Optional[jax.Array], + tgt_mask: Optional[jax.Array], + ) -> Optional[jax.Array]: if arr is None: return arr assert arr.ndim == 2, arr.ndim @@ -801,12 +797,12 @@ def mask_fn( def _mask_subset_helper( self, - src_ixs: Optional[jnp.ndarray], - tgt_ixs: Optional[jnp.ndarray], + src_ixs: Optional[jax.Array], + tgt_ixs: Optional[jax.Array], *, fn: Callable[ - [Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray]], - Optional[jnp.ndarray]], + [Optional[jax.Array], Optional[jax.Array], Optional[jax.Array]], + Optional[jax.Array]], propagate_mask: bool, **kwargs: Any, ) -> "Geometry": @@ -825,7 +821,7 @@ def _mask_subset_helper( ) @property - def src_mask(self) -> Optional[jnp.ndarray]: + def src_mask(self) -> Optional[jax.Array]: """Mask of shape ``[num_a,]`` to compute :attr:`cost_matrix` statistics. Specifically, it is used when computing: @@ -837,7 +833,7 @@ def src_mask(self) -> Optional[jnp.ndarray]: return self._normalize_mask(self._src_mask, self.shape[0]) @property - def tgt_mask(self) -> Optional[jnp.ndarray]: + def tgt_mask(self) -> Optional[jax.Array]: """Mask of shape ``[num_b,]`` to compute :attr:`cost_matrix` statistics. Specifically, it is used when computing: @@ -863,22 +859,22 @@ def _masked_geom(self, mask_value: float = 0.) 
-> "Geometry": return self.mask(src_mask, tgt_mask, mask_value=mask_value) @property - def _n_normed_ones(self) -> jnp.ndarray: + def _n_normed_ones(self) -> jax.Array: """Normalized array of shape ``[num_a,]``.""" mask = self.src_mask arr = jnp.ones(self.shape[0]) if mask is None else mask return arr / jnp.sum(arr) @property - def _m_normed_ones(self) -> jnp.ndarray: + def _m_normed_ones(self) -> jax.Array: """Normalized array of shape ``[num_b,]``.""" mask = self.tgt_mask arr = jnp.ones(self.shape[1]) if mask is None else mask return arr / jnp.sum(arr) @staticmethod - def _normalize_mask(mask: Optional[Union[int, jnp.ndarray]], - size: int) -> Optional[jnp.ndarray]: + def _normalize_mask(mask: Optional[Union[int, jax.Array]], + size: int) -> Optional[jax.Array]: """Convert array of indices to a boolean mask.""" if mask is None: return None diff --git a/src/ott/geometry/graph.py b/src/ott/geometry/graph.py index c7dac0c99..ab0fe8768 100644 --- a/src/ott/geometry/graph.py +++ b/src/ott/geometry/graph.py @@ -48,7 +48,7 @@ class Graph(geometry.Geometry): def __init__( self, - laplacian: jnp.ndarray, + laplacian: jax.Array, t: float = 1e-3, n_steps: int = 100, numerical_scheme: Literal["backward_euler", @@ -66,7 +66,7 @@ def __init__( @classmethod def from_graph( cls, - G: jnp.ndarray, + G: jax.Array, t: Optional[float] = 1e-3, directed: bool = False, normalize: bool = False, @@ -113,10 +113,10 @@ def from_graph( def apply_kernel( self, - scaling: jnp.ndarray, + scaling: jax.Array, eps: Optional[float] = None, axis: int = 0, - ) -> jnp.ndarray: + ) -> jax.Array: r"""Apply :attr:`kernel_matrix` on positive scaling vector. 
Args: @@ -129,8 +129,8 @@ def apply_kernel( """ def conf_fn( - iteration: int, consts: Tuple[jnp.ndarray, Optional[jnp.ndarray]], - old_new: Tuple[jnp.ndarray, jnp.ndarray] + iteration: int, consts: Tuple[jax.Array, Optional[jax.Array]], + old_new: Tuple[jax.Array, jax.Array] ) -> bool: del iteration, consts @@ -143,9 +143,9 @@ def conf_fn( return (jnp.nanmax(f) - jnp.nanmin(f)) > self.tol def body_fn( - iteration: int, consts: Tuple[jnp.ndarray, Optional[jnp.ndarray]], - old_new: Tuple[jnp.ndarray, jnp.ndarray], compute_errors: bool - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + iteration: int, consts: Tuple[jax.Array, Optional[jax.Array]], + old_new: Tuple[jax.Array, jax.Array], compute_errors: bool + ) -> Tuple[jax.Array, jax.Array]: del iteration, compute_errors L, scaled_lap = consts @@ -186,7 +186,7 @@ def body_fn( )[1] @property - def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 + def kernel_matrix(self) -> jax.Array: # noqa: D102 n, _ = self.shape kernel = self.apply_kernel(jnp.eye(n)) # force symmetry because of numerical imprecision @@ -194,7 +194,7 @@ def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 return (kernel + kernel.T) * 0.5 @property - def cost_matrix(self) -> jnp.ndarray: # noqa: D102 + def cost_matrix(self) -> jax.Array: # noqa: D102 return -self.t * mu.safe_log(self.kernel_matrix) @property @@ -209,12 +209,12 @@ def _scale(self) -> float: ) @property - def _scaled_laplacian(self) -> jnp.ndarray: + def _scaled_laplacian(self) -> jax.Array: """Laplacian scaled by a constant, depending on the numerical scheme.""" return self._scale * self.laplacian @property - def _M(self) -> jnp.ndarray: + def _M(self) -> jax.Array: n, _ = self.shape return self._scaled_laplacian + jnp.eye(n) @@ -230,29 +230,27 @@ def is_symmetric(self) -> bool: # noqa: D102 def dtype(self) -> jnp.dtype: # noqa: D102 return self.laplacian.dtype - def transport_from_potentials( - self, f: jnp.ndarray, g: jnp.ndarray - ) -> jnp.ndarray: + def transport_from_potentials(self, f: 
jax.Array, g: jax.Array) -> jax.Array: """Not implemented.""" raise ValueError("Not implemented.") def apply_transport_from_potentials( self, - f: jnp.ndarray, - g: jnp.ndarray, - vec: jnp.ndarray, + f: jax.Array, + g: jax.Array, + vec: jax.Array, axis: int = 0 - ) -> jnp.ndarray: + ) -> jax.Array: """Since applying from potentials is not feasible in grids, use scalings.""" u, v = self.scaling_from_potential(f), self.scaling_from_potential(g) return self.apply_transport_from_scalings(u, v, vec, axis=axis) def marginal_from_potentials( self, - f: jnp.ndarray, - g: jnp.ndarray, + f: jax.Array, + g: jax.Array, axis: int = 0, - ) -> jnp.ndarray: + ) -> jax.Array: """Not implemented.""" raise ValueError("Not implemented.") diff --git a/src/ott/geometry/grid.py b/src/ott/geometry/grid.py index fd64500c9..3401f52c7 100644 --- a/src/ott/geometry/grid.py +++ b/src/ott/geometry/grid.py @@ -71,7 +71,7 @@ class Grid(geometry.Geometry): def __init__( self, - x: Optional[Sequence[jnp.ndarray]] = None, + x: Optional[Sequence[jax.Array]] = None, grid_size: Optional[Sequence[int]] = None, cost_fns: Optional[Sequence[costs.CostFn]] = None, num_a: Optional[int] = None, @@ -146,12 +146,12 @@ def is_symmetric(self) -> bool: # noqa: D102 # Reimplemented functions to be used in regularized OT def apply_lse_kernel( self, - f: jnp.ndarray, - g: jnp.ndarray, + f: jax.Array, + g: jax.Array, eps: float, - vec: Optional[jnp.ndarray] = None, + vec: Optional[jax.Array] = None, axis: int = 0 - ) -> jnp.ndarray: + ) -> jax.Array: """Apply grid kernel in log space. See notes in parent class for use case. Reshapes vector inputs below as grids, applies kernels onto each slice, and @@ -160,10 +160,10 @@ def apply_lse_kernel( More implementation details in :cite:`schmitz:18`. 
Args: - f: jnp.ndarray, a vector of potentials - g: jnp.ndarray, a vector of potentials + f: jax.Array, a vector of potentials + g: jax.Array, a vector of potentials eps: float, regularization strength - vec: jnp.ndarray, if needed, a vector onto which apply the kernel weighted + vec: jax.Array, if needed, a vector onto which apply the kernel weighted by f and g. axis: axis (0 or 1) along which summation should be carried out. @@ -209,8 +209,8 @@ def _apply_lse_kernel_one_dimension(self, dimension, f, g, eps, vec=None): return jnp.transpose(softmax_res, indices), None def _apply_cost_to_vec( - self, vec: jnp.ndarray, axis: int = 0, fn=None - ) -> jnp.ndarray: + self, vec: jax.Array, axis: int = 0, fn=None + ) -> jax.Array: r"""Apply grid's cost matrix (without instantiating it) to a vector. The `apply_cost` operation on grids rests on the following identity. @@ -229,13 +229,13 @@ def _apply_cost_to_vec( summation while keeping dimensions. Args: - vec: jnp.ndarray, flat vector of total size prod(grid_size). + vec: jax.Array, flat vector of total size prod(grid_size). axis: axis 0 if applying transpose costs, 1 if using the original cost. fn: function optionally applied to cost matrix element-wise, before the dot product. Returns: - A jnp.ndarray corresponding to cost x matrix + A jax.Array corresponding to cost x matrix """ vec = jnp.reshape(vec, self.grid_size) accum_vec = jnp.zeros_like(vec) @@ -255,10 +255,10 @@ def _apply_cost_to_vec( def apply_kernel( self, - scaling: jnp.ndarray, + scaling: jax.Array, eps: Optional[float] = None, axis: Optional[int] = None - ) -> jnp.ndarray: + ) -> jax.Array: """Apply grid kernel on scaling vector. See notes in parent class for use. @@ -269,7 +269,7 @@ def apply_kernel( More implementation details in :cite:`schmitz:18`, Args: - scaling: jnp.ndarray, a vector of scaling (>0) values. + scaling: jax.Array, a vector of scaling (>0) values. 
eps: float, regularization strength axis: axis (0 or 1) along which summation should be carried out. @@ -289,7 +289,7 @@ def apply_kernel( return scaling.ravel() def transport_from_potentials( - self, f: jnp.ndarray, g: jnp.ndarray, axis: int = 0 + self, f: jax.Array, g: jax.Array, axis: int = 0 ) -> NoReturn: """Not implemented, use :meth:`apply_transport_from_potentials` instead.""" raise ValueError( @@ -300,7 +300,7 @@ def transport_from_potentials( ) def transport_from_scalings( - self, f: jnp.ndarray, g: jnp.ndarray, axis: int = 0 + self, f: jax.Array, g: jax.Array, axis: int = 0 ) -> NoReturn: """Not implemented, use :meth:`apply_transport_from_scalings` instead.""" raise ValueError( @@ -311,15 +311,15 @@ def transport_from_scalings( ) def subset( - self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray] + self, src_ixs: Optional[jax.Array], tgt_ixs: Optional[jax.Array] ) -> NoReturn: """Not implemented.""" raise NotImplementedError("Subsetting is not implemented for grids.") def mask( self, - src_mask: Optional[jnp.ndarray], - tgt_mask: Optional[jnp.ndarray], + src_mask: Optional[jax.Array], + tgt_mask: Optional[jax.Array], mask_value: float = 0., ) -> NoReturn: """Not implemented.""" diff --git a/src/ott/geometry/low_rank.py b/src/ott/geometry/low_rank.py index e759b4cb9..750d8db62 100644 --- a/src/ott/geometry/low_rank.py +++ b/src/ott/geometry/low_rank.py @@ -33,8 +33,8 @@ class LRCGeometry(geometry.Geometry): if :math:`C = AB^T` and :math:`D = EF^T` then :math:`C + D = [A,E][B,F]^T` Args: - cost_1: jnp.ndarray[num_a, r] - cost_2: jnp.ndarray[num_b, r] + cost_1: jax.Array[num_a, r] + cost_2: jax.Array[num_b, r] bias: constant added to entire cost matrix. scale: Value used to rescale the factors of the low-rank geometry. scale_cost: option to rescale the cost matrix. 
Implemented scalings are @@ -51,8 +51,8 @@ class LRCGeometry(geometry.Geometry): def __init__( self, - cost_1: jnp.ndarray, - cost_2: jnp.ndarray, + cost_1: jax.Array, + cost_2: jax.Array, bias: float = 0.0, scale_factor: float = 1.0, scale_cost: Union[bool, int, float, Literal["mean", "max_bound", @@ -69,13 +69,13 @@ def __init__( self.batch_size = batch_size @property - def cost_1(self) -> jnp.ndarray: + def cost_1(self) -> jax.Array: """First factor of the :attr:`cost_matrix`.""" scale_factor = jnp.sqrt(self._scale_factor * self.inv_scale_cost) return scale_factor * self._cost_1 @property - def cost_2(self) -> jnp.ndarray: + def cost_2(self) -> jax.Array: """Second factor of the :attr:`cost_matrix`.""" scale_factor = jnp.sqrt(self._scale_factor * self.inv_scale_cost) return scale_factor * self._cost_2 @@ -90,7 +90,7 @@ def cost_rank(self) -> int: # noqa: D102 return self._cost_1.shape[1] @property - def cost_matrix(self) -> jnp.ndarray: + def cost_matrix(self) -> jax.Array: """Materialize the cost matrix.""" return jnp.matmul(self.cost_1, self.cost_2.T) + self.bias @@ -124,7 +124,7 @@ def inv_scale_cost(self) -> float: # noqa: D102 return 1.0 / self.compute_max_cost() raise ValueError(f"Scaling {self._scale_cost} not implemented.") - def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray: + def apply_square_cost(self, arr: jax.Array, axis: int = 0) -> jax.Array: """Apply elementwise-square of cost matrix to array (vector or matrix).""" (n, m), r = self.shape, self.cost_rank # When applying square of a LRCGeometry, one can either elementwise square @@ -142,15 +142,15 @@ def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray: def _apply_cost_to_vec( self, - vec: jnp.ndarray, + vec: jax.Array, axis: int = 0, - fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, + fn: Optional[Callable[[jax.Array], jax.Array]] = None, is_linear: bool = False, - ) -> jnp.ndarray: + ) -> jax.Array: """Apply [num_a, num_b] fn(cost) (or 
transpose) to vector. Args: - vec: jnp.ndarray [num_a,] ([num_b,] if axis=1) vector + vec: jax.Array [num_a,] ([num_b,] if axis=1) vector axis: axis on which the reduction is done. fn: function optionally applied to cost matrix element-wise, before the doc product @@ -159,12 +159,12 @@ def _apply_cost_to_vec( for a heuristic to help determine if a function is linear. Returns: - A jnp.ndarray corresponding to cost x vector + A jax.Array corresponding to cost x vector """ def linear_apply( - vec: jnp.ndarray, axis: int, fn: Callable[[jnp.ndarray], jnp.ndarray] - ) -> jnp.ndarray: + vec: jax.Array, axis: int, fn: Callable[[jax.Array], jax.Array] + ) -> jax.Array: c1 = self.cost_1 if axis == 1 else self.cost_2 c2 = self.cost_2 if axis == 1 else self.cost_1 c2 = fn(c2) if fn is not None else c2 @@ -241,14 +241,14 @@ def can_LRC(self): # noqa: D102 return True def subset( # noqa: D102 - self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray], + self, src_ixs: Optional[jax.Array], tgt_ixs: Optional[jax.Array], **kwargs: Any ) -> "LRCGeometry": def subset_fn( - arr: Optional[jnp.ndarray], - ixs: Optional[jnp.ndarray], - ) -> jnp.ndarray: + arr: Optional[jax.Array], + ixs: Optional[jax.Array], + ) -> jax.Array: return arr if arr is None or ixs is None else arr[jnp.atleast_1d(ixs)] return self._mask_subset_helper( @@ -257,15 +257,15 @@ def subset_fn( def mask( # noqa: D102 self, - src_mask: Optional[jnp.ndarray], - tgt_mask: Optional[jnp.ndarray], + src_mask: Optional[jax.Array], + tgt_mask: Optional[jax.Array], mask_value: float = 0., ) -> "LRCGeometry": def mask_fn( - arr: Optional[jnp.ndarray], - mask: Optional[jnp.ndarray], - ) -> Optional[jnp.ndarray]: + arr: Optional[jax.Array], + mask: Optional[jax.Array], + ) -> Optional[jax.Array]: if arr is None or mask is None: return arr return jnp.where(mask[:, None], arr, mask_value) @@ -278,11 +278,11 @@ def mask_fn( def _mask_subset_helper( self, - src_ixs: Optional[jnp.ndarray], - tgt_ixs: Optional[jnp.ndarray], 
+ src_ixs: Optional[jax.Array], + tgt_ixs: Optional[jax.Array], *, - fn: Callable[[Optional[jnp.ndarray], Optional[jnp.ndarray]], - Optional[jnp.ndarray]], + fn: Callable[[Optional[jax.Array], Optional[jax.Array]], + Optional[jax.Array]], propagate_mask: bool, **kwargs: Any, ) -> "LRCGeometry": diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index e7f46a020..c5d48a096 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -56,8 +56,8 @@ class PointCloud(geometry.Geometry): def __init__( self, - x: jnp.ndarray, - y: Optional[jnp.ndarray] = None, + x: jax.Array, + y: Optional[jax.Array] = None, cost_fn: Optional[costs.CostFn] = None, batch_size: Optional[int] = None, scale_cost: Union[bool, int, float, @@ -77,13 +77,13 @@ def __init__( self._scale_cost = "mean" if scale_cost is True else scale_cost @property - def _norm_x(self) -> Union[float, jnp.ndarray]: + def _norm_x(self) -> Union[float, jax.Array]: if self._axis_norm == 0: return self.cost_fn.norm(self.x) return 0. @property - def _norm_y(self) -> Union[float, jnp.ndarray]: + def _norm_y(self) -> Union[float, jax.Array]: if self._axis_norm == 0: return self.cost_fn.norm(self.y) return 0. 
@@ -98,14 +98,14 @@ def _check_LRC_dim(self): return n * m > (n + m) * d @property - def cost_matrix(self) -> Optional[jnp.ndarray]: # noqa: D102 + def cost_matrix(self) -> Optional[jax.Array]: # noqa: D102 if self.is_online: return None cost_matrix = self._compute_cost_matrix() return cost_matrix * self.inv_scale_cost @property - def kernel_matrix(self) -> Optional[jnp.ndarray]: # noqa: D102 + def kernel_matrix(self) -> Optional[jax.Array]: # noqa: D102 if self.is_online: return None return jnp.exp(-self.cost_matrix / self.epsilon) @@ -183,7 +183,7 @@ def inv_scale_cost(self) -> float: # noqa: D102 ) raise ValueError(f"Scaling {self._scale_cost} not implemented.") - def _compute_cost_matrix(self) -> jnp.ndarray: + def _compute_cost_matrix(self) -> jax.Array: cost_matrix = self.cost_fn.all_pairs_pairwise(self.x, self.y) if self._axis_norm is not None: cost_matrix += self._norm_x[:, jnp.newaxis] + self._norm_y[jnp.newaxis, :] @@ -191,12 +191,12 @@ def _compute_cost_matrix(self) -> jnp.ndarray: def apply_lse_kernel( # noqa: D102 self, - f: jnp.ndarray, - g: jnp.ndarray, + f: jax.Array, + g: jax.Array, eps: float, - vec: Optional[jnp.ndarray] = None, + vec: Optional[jax.Array] = None, axis: int = 0 - ) -> jnp.ndarray: + ) -> jax.Array: def body0(carry, i: int): f, g, eps, vec = carry @@ -278,10 +278,10 @@ def finalize(i: int): def apply_kernel( # noqa: D102 self, - scaling: jnp.ndarray, + scaling: jax.Array, eps: Optional[float] = None, axis: int = 0 - ) -> jnp.ndarray: + ) -> jax.Array: if eps is None: eps = self.epsilon @@ -303,8 +303,8 @@ def apply_kernel( # noqa: D102 ) def transport_from_potentials( # noqa: D102 - self, f: jnp.ndarray, g: jnp.ndarray - ) -> jnp.ndarray: + self, f: jax.Array, g: jax.Array + ) -> jax.Array: if not self.is_online: return super().transport_from_potentials(f, g) transport = jax.vmap( @@ -317,8 +317,8 @@ def transport_from_potentials( # noqa: D102 ) def transport_from_scalings( # noqa: D102 - self, u: jnp.ndarray, v: jnp.ndarray - ) -> 
jnp.ndarray: + self, u: jax.Array, v: jax.Array + ) -> jax.Array: if not self.is_online: return super().transport_from_scalings(u, v) transport = jax.vmap( @@ -342,11 +342,11 @@ def transport_from_scalings( # noqa: D102 def apply_cost( self, - arr: jnp.ndarray, + arr: jax.Array, axis: int = 0, - fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, + fn: Optional[Callable[[jax.Array], jax.Array]] = None, is_linear: bool = False, - ) -> jnp.ndarray: + ) -> jax.Array: """Apply cost matrix to array (vector or matrix). This function applies the geometry's cost matrix, to perform either @@ -356,7 +356,7 @@ def apply_cost( application of fn to each entry of the :attr:`cost_matrix`. Args: - arr: jnp.ndarray [num_a or num_b, batch], vector that will be multiplied + arr: jax.Array [num_a or num_b, batch], vector that will be multiplied by the cost matrix. axis: standard cost matrix if axis=1, transpose if 0. fn: function optionally applied to cost matrix element-wise, before the @@ -367,7 +367,7 @@ def apply_cost( for a heuristic to help determine if a function is linear. Returns: - A jnp.ndarray, [num_b, batch] if axis=0 or [num_a, batch] if axis=1 + A jax.Array, [num_b, batch] if axis=0 or [num_a, batch] if axis=1 """ # switch to efficient computation for the squared euclidean case. 
if self.is_squared_euclidean and (fn is None or is_linear): @@ -375,9 +375,7 @@ def apply_cost( return self._apply_cost(arr, axis, fn=fn) - def _apply_cost( - self, arr: jnp.ndarray, axis: int = 0, fn=None - ) -> jnp.ndarray: + def _apply_cost(self, arr: jax.Array, axis: int = 0, fn=None) -> jax.Array: """See :meth:`apply_cost`.""" if not self.is_online: return super().apply_cost(arr, axis, fn) @@ -401,24 +399,24 @@ def _apply_cost( def vec_apply_cost( self, - arr: jnp.ndarray, + arr: jax.Array, axis: int = 0, - fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None - ) -> jnp.ndarray: + fn: Optional[Callable[[jax.Array], jax.Array]] = None + ) -> jax.Array: """Apply the geometry's cost matrix in a vectorized way. This function can be used when the cost matrix is squared euclidean and ``fn`` is a linear function. Args: - arr: jnp.ndarray [num_a or num_b, p], vector that will be multiplied + arr: jax.Array [num_a or num_b, p], vector that will be multiplied by the cost matrix. axis: standard cost matrix if axis=1, transport if 0. fn: function optionally applied to cost matrix element-wise, before the application. Returns: - A jnp.ndarray, [num_b, p] if axis=0 or [num_a, p] if axis=1 + A jax.Array, [num_b, p] if axis=0 or [num_a, p] if axis=1 """ assert self.is_squared_euclidean, "Cost matrix is not a squared Euclidean." rank = arr.ndim @@ -434,7 +432,7 @@ def vec_apply_cost( applied_cost = fn(applied_cost) return self.inv_scale_cost * applied_cost - def _leading_slice(self, t: jnp.ndarray, i: int) -> jnp.ndarray: + def _leading_slice(self, t: jax.Array, i: int) -> jax.Array: start_indices = [i * self.batch_size] + (t.ndim - 1) * [0] slice_sizes = [self.batch_size] + list(t.shape[1:]) return jax.lax.dynamic_slice(t, start_indices, slice_sizes) @@ -525,18 +523,18 @@ def finalize(i: int): f"Scaling method {summary} does not exist for online mode." 
) - def barycenter(self, weights: jnp.ndarray) -> jnp.ndarray: + def barycenter(self, weights: jax.Array) -> jax.Array: """Compute barycenter of points in self.x using weights.""" return self.cost_fn.barycenter(self.x, weights)[0] @classmethod def prepare_divergences( cls, - x: jnp.ndarray, - y: jnp.ndarray, + x: jax.Array, + y: jax.Array, static_b: bool = False, - src_mask: Optional[jnp.ndarray] = None, - tgt_mask: Optional[jnp.ndarray] = None, + src_mask: Optional[jax.Array] = None, + tgt_mask: Optional[jax.Array] = None, **kwargs: Any ) -> Tuple["PointCloud", ...]: """Instantiate the geometries used for a divergence computation.""" @@ -640,14 +638,14 @@ def _sqeucl_to_lr(self, scale: float = 1.0) -> low_rank.LRCGeometry: ) def subset( # noqa: D102 - self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray], + self, src_ixs: Optional[jax.Array], tgt_ixs: Optional[jax.Array], **kwargs: Any ) -> "PointCloud": def subset_fn( - arr: Optional[jnp.ndarray], - ixs: Optional[jnp.ndarray], - ) -> jnp.ndarray: + arr: Optional[jax.Array], + ixs: Optional[jax.Array], + ) -> jax.Array: return arr if arr is None or ixs is None else arr[jnp.atleast_1d(ixs)] return self._mask_subset_helper( @@ -656,15 +654,15 @@ def subset_fn( def mask( # noqa: D102 self, - src_mask: Optional[jnp.ndarray], - tgt_mask: Optional[jnp.ndarray], + src_mask: Optional[jax.Array], + tgt_mask: Optional[jax.Array], mask_value: float = 0., ) -> "PointCloud": def mask_fn( - arr: Optional[jnp.ndarray], - mask: Optional[jnp.ndarray], - ) -> Optional[jnp.ndarray]: + arr: Optional[jax.Array], + mask: Optional[jax.Array], + ) -> Optional[jax.Array]: if arr is None or mask is None: return arr return jnp.where(mask[:, None], arr, mask_value) @@ -677,11 +675,11 @@ def mask_fn( def _mask_subset_helper( self, - src_ixs: Optional[jnp.ndarray], - tgt_ixs: Optional[jnp.ndarray], + src_ixs: Optional[jax.Array], + tgt_ixs: Optional[jax.Array], *, - fn: Callable[[Optional[jnp.ndarray], Optional[jnp.ndarray]], - 
Optional[jnp.ndarray]], + fn: Callable[[Optional[jax.Array], Optional[jax.Array]], + Optional[jax.Array]], propagate_mask: bool, **kwargs: Any, ) -> "PointCloud": @@ -767,18 +765,18 @@ def _apply_cost_xy(x, y, norm_x, norm_y, vec, cost_fn, scale_cost, fn=None): fn(cost) matrix (or transpose) to vector. Args: - x: jnp.ndarray [num_a, d], first pointcloud - y: jnp.ndarray [num_b, d], second pointcloud - norm_x: jnp.ndarray [num_a,], (squared) norm as defined in by cost_fn - norm_y: jnp.ndarray [num_b,], (squared) norm as defined in by cost_fn - vec: jnp.ndarray [num_a,] ([num_b,] if axis=1 from `apply_cost`) vector + x: jax.Array [num_a, d], first pointcloud + y: jax.Array [num_b, d], second pointcloud + norm_x: jax.Array [num_a,], (squared) norm as defined in by cost_fn + norm_y: jax.Array [num_b,], (squared) norm as defined in by cost_fn + vec: jax.Array [num_a,] ([num_b,] if axis=1 from `apply_cost`) vector cost_fn: a CostFn function between two points in dimension d. scale_cost: scaling factor of the cost matrix. fn: function optionally applied to cost matrix element-wise, before the apply. 
Returns: - A jnp.ndarray corresponding to cost x vector + A jax.Array corresponding to cost x vector """ c = _cost(x, y, norm_x, norm_y, cost_fn, scale_cost) return jnp.dot(c, vec) if fn is None else jnp.dot(fn(c), vec) diff --git a/src/ott/geometry/segment.py b/src/ott/geometry/segment.py index 20a1ee92b..5e2c764c8 100644 --- a/src/ott/geometry/segment.py +++ b/src/ott/geometry/segment.py @@ -21,15 +21,15 @@ def segment_point_cloud( - x: jnp.ndarray, - a: Optional[jnp.ndarray] = None, + x: jax.Array, + a: Optional[jax.Array] = None, num_segments: Optional[int] = None, max_measure_size: Optional[int] = None, - segment_ids: Optional[jnp.ndarray] = None, + segment_ids: Optional[jax.Array] = None, indices_are_sorted: bool = False, num_per_segment: Optional[Tuple[int, ...]] = None, - padding_vector: Optional[jnp.ndarray] = None -) -> Tuple[jnp.ndarray, jnp.ndarray]: + padding_vector: Optional[jax.Array] = None +) -> Tuple[jax.Array, jax.Array]: """Segment and pad as needed the entries of a point cloud. 
There are two interfaces: @@ -129,21 +129,20 @@ def segment_point_cloud( def _segment_interface( - x: jnp.ndarray, - y: jnp.ndarray, - eval_fn: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], - jnp.ndarray], + x: jax.Array, + y: jax.Array, + eval_fn: Callable[[jax.Array, jax.Array, jax.Array, jax.Array], jax.Array], num_segments: Optional[int] = None, max_measure_size: Optional[int] = None, - segment_ids_x: Optional[jnp.ndarray] = None, - segment_ids_y: Optional[jnp.ndarray] = None, + segment_ids_x: Optional[jax.Array] = None, + segment_ids_y: Optional[jax.Array] = None, indices_are_sorted: bool = False, - num_per_segment_x: Optional[jnp.ndarray] = None, - num_per_segment_y: Optional[jnp.ndarray] = None, - weights_x: Optional[jnp.ndarray] = None, - weights_y: Optional[jnp.ndarray] = None, - padding_vector: Optional[jnp.ndarray] = None, -) -> jnp.ndarray: + num_per_segment_x: Optional[jax.Array] = None, + num_per_segment_y: Optional[jax.Array] = None, + weights_x: Optional[jax.Array] = None, + weights_y: Optional[jax.Array] = None, + padding_vector: Optional[jax.Array] = None, +) -> jax.Array: """Wrapper to segment two point clouds and return parallel evaluations. Utility function that segments two point clouds using the approach outlined diff --git a/src/ott/initializers/linear/initializers.py b/src/ott/initializers/linear/initializers.py index f3ba93321..58744cfb0 100644 --- a/src/ott/initializers/linear/initializers.py +++ b/src/ott/initializers/linear/initializers.py @@ -37,7 +37,7 @@ def init_dual_a( ot_prob: linear_problem.LinearProblem, lse_mode: bool, rng: Optional[jax.Array] = None, - ) -> jnp.ndarray: + ) -> jax.Array: """Initialize Sinkhorn potential/scaling f_u. Args: @@ -55,7 +55,7 @@ def init_dual_b( ot_prob: linear_problem.LinearProblem, lse_mode: bool, rng: Optional[jax.Array] = None, - ) -> jnp.ndarray: + ) -> jax.Array: """Initialize Sinkhorn potential/scaling g_v. 
Args: @@ -70,11 +70,11 @@ def init_dual_b( def __call__( self, ot_prob: linear_problem.LinearProblem, - a: Optional[jnp.ndarray], - b: Optional[jnp.ndarray], + a: Optional[jax.Array], + b: Optional[jax.Array], lse_mode: bool, rng: Optional[jax.Array] = None, - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array]: """Initialize Sinkhorn potentials/scalings f_u and g_v. Args: @@ -129,7 +129,7 @@ def init_dual_a( # noqa: D102 ot_prob: linear_problem.LinearProblem, lse_mode: bool, rng: Optional[jax.Array] = None, - ) -> jnp.ndarray: + ) -> jax.Array: del rng return jnp.zeros_like(ot_prob.a) if lse_mode else jnp.ones_like(ot_prob.a) @@ -138,7 +138,7 @@ def init_dual_b( # noqa: D102 ot_prob: linear_problem.LinearProblem, lse_mode: bool, rng: Optional[jax.Array] = None, - ) -> jnp.ndarray: + ) -> jax.Array: del rng return jnp.zeros_like(ot_prob.b) if lse_mode else jnp.ones_like(ot_prob.b) @@ -159,7 +159,7 @@ def init_dual_a( # noqa: D102 ot_prob: linear_problem.LinearProblem, lse_mode: bool, rng: Optional[jax.Array] = None, - ) -> jnp.ndarray: + ) -> jax.Array: # import Gaussian here due to circular imports from ott.tools.gaussian_mixture import gaussian @@ -207,8 +207,8 @@ def __init__( self.vectorized_update = vectorized_update def _init_sorting_dual( - self, modified_cost: jnp.ndarray, init_f: jnp.ndarray - ) -> jnp.ndarray: + self, modified_cost: jax.Array, init_f: jax.Array + ) -> jax.Array: """Run DualSort algorithm. 
Args: @@ -221,15 +221,15 @@ def _init_sorting_dual( """ def body_fn( - state: Tuple[jnp.ndarray, float, int] - ) -> Tuple[jnp.ndarray, float, int]: + state: Tuple[jax.Array, float, int] + ) -> Tuple[jax.Array, float, int]: prev_f, _, it = state new_f = fn(prev_f, modified_cost) diff = jnp.sum((new_f - prev_f) ** 2) it += 1 return new_f, diff, it - def cond_fn(state: Tuple[jnp.ndarray, float, int]) -> bool: + def cond_fn(state: Tuple[jax.Array, float, int]) -> bool: _, diff, it = state return jnp.logical_and(diff > self.tolerance, it < self.max_iter) @@ -246,8 +246,8 @@ def init_dual_a( ot_prob: linear_problem.LinearProblem, lse_mode: bool, rng: Optional[jax.Array] = None, - init_f: Optional[jnp.ndarray] = None, - ) -> jnp.ndarray: + init_f: Optional[jax.Array] = None, + ) -> jax.Array: """Apply DualSort algorithm. Args: @@ -325,7 +325,7 @@ def init_dual_a( # noqa: D102 ot_prob: linear_problem.LinearProblem, lse_mode: bool, rng: Optional[jax.Array] = None, - ) -> jnp.ndarray: + ) -> jax.Array: from ott.solvers import linear assert isinstance( @@ -373,9 +373,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 }) -def _vectorized_update( - f: jnp.ndarray, modified_cost: jnp.ndarray -) -> jnp.ndarray: +def _vectorized_update(f: jax.Array, modified_cost: jax.Array) -> jax.Array: """Inner loop DualSort Update. Args: @@ -388,9 +386,7 @@ def _vectorized_update( return jnp.min(modified_cost + f[None, :], axis=1) -def _coordinate_update( - f: jnp.ndarray, modified_cost: jnp.ndarray -) -> jnp.ndarray: +def _coordinate_update(f: jax.Array, modified_cost: jax.Array) -> jax.Array: """Coordinate-wise updates within inner loop. Args: @@ -401,7 +397,7 @@ def _coordinate_update( updated potential vector, f. 
""" - def body_fn(i: int, f: jnp.ndarray) -> jnp.ndarray: + def body_fn(i: int, f: jax.Array) -> jax.Array: new_f = jnp.min(modified_cost[i, :] + f) return f.at[i].set(new_f) diff --git a/src/ott/initializers/linear/initializers_lr.py b/src/ott/initializers/linear/initializers_lr.py index 5c2302156..9eb8e1231 100644 --- a/src/ott/initializers/linear/initializers_lr.py +++ b/src/ott/initializers/linear/initializers_lr.py @@ -69,9 +69,9 @@ def init_q( ot_prob: Problem_t, rng: jax.Array, *, - init_g: jnp.ndarray, + init_g: jax.Array, **kwargs: Any, - ) -> jnp.ndarray: + ) -> jax.Array: """Initialize the low-rank factor :math:`Q`. Args: @@ -90,9 +90,9 @@ def init_r( ot_prob: Problem_t, rng: jax.Array, *, - init_g: jnp.ndarray, + init_g: jax.Array, **kwargs: Any, - ) -> jnp.ndarray: + ) -> jax.Array: """Initialize the low-rank factor :math:`R`. Args: @@ -111,7 +111,7 @@ def init_g( ot_prob: Problem_t, rng: jax.Array, **kwargs: Any, - ) -> jnp.ndarray: + ) -> jax.Array: """Initialize the low-rank factor :math:`g`. Args: @@ -165,13 +165,13 @@ def from_solver( def __call__( self, ot_prob: Problem_t, - q: Optional[jnp.ndarray] = None, - r: Optional[jnp.ndarray] = None, - g: Optional[jnp.ndarray] = None, + q: Optional[jax.Array] = None, + r: Optional[jax.Array] = None, + g: Optional[jax.Array] = None, *, rng: Optional[jax.Array] = None, **kwargs: Any - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array, jax.Array]: """Initialize the factors :math:`Q`, :math:`R` and :math:`g`. 
Args: @@ -234,9 +234,9 @@ def init_q( # noqa: D102 ot_prob: Problem_t, rng: jax.Array, *, - init_g: jnp.ndarray, + init_g: jax.Array, **kwargs: Any, - ) -> jnp.ndarray: + ) -> jax.Array: del kwargs, init_g a = ot_prob.a init_q = jnp.abs(jax.random.normal(rng, (a.shape[0], self.rank))) @@ -247,9 +247,9 @@ def init_r( # noqa: D102 ot_prob: Problem_t, rng: jax.Array, *, - init_g: jnp.ndarray, + init_g: jax.Array, **kwargs: Any, - ) -> jnp.ndarray: + ) -> jax.Array: del kwargs, init_g b = ot_prob.b init_r = jnp.abs(jax.random.normal(rng, (b.shape[0], self.rank))) @@ -260,7 +260,7 @@ def init_g( # noqa: D102 ot_prob: Problem_t, rng: jax.Array, **kwargs: Any, - ) -> jnp.ndarray: + ) -> jax.Array: del kwargs init_g = jnp.abs(jax.random.uniform(rng, (self.rank,))) + 1. return init_g / jnp.sum(init_g) @@ -278,10 +278,10 @@ class Rank2Initializer(LRInitializer): def _compute_factor( self, ot_prob: Problem_t, - init_g: jnp.ndarray, + init_g: jax.Array, *, which: Literal["q", "r"], - ) -> jnp.ndarray: + ) -> jax.Array: a, b = ot_prob.a, ot_prob.b marginal = a if which == "q" else b n, r = marginal.shape[0], self.rank @@ -307,9 +307,9 @@ def init_q( # noqa: D102 ot_prob: Problem_t, rng: jax.Array, *, - init_g: jnp.ndarray, + init_g: jax.Array, **kwargs: Any, - ) -> jnp.ndarray: + ) -> jax.Array: del rng, kwargs return self._compute_factor(ot_prob, init_g, which="q") @@ -318,9 +318,9 @@ def init_r( # noqa: D102 ot_prob: Problem_t, rng: jax.Array, *, - init_g: jnp.ndarray, + init_g: jax.Array, **kwargs: Any, - ) -> jnp.ndarray: + ) -> jax.Array: del rng, kwargs return self._compute_factor(ot_prob, init_g, which="r") @@ -329,7 +329,7 @@ def init_g( # noqa: D102 ot_prob: Problem_t, rng: jax.Array, **kwargs: Any, - ) -> jnp.ndarray: + ) -> jax.Array: del rng, kwargs return jnp.ones((self.rank,)) / self.rank @@ -364,7 +364,7 @@ def __init__( self._sinkhorn_kwargs = {} if sinkhorn_kwargs is None else sinkhorn_kwargs @staticmethod - def _extract_array(geom: geometry.Geometry, *, first: 
bool) -> jnp.ndarray: + def _extract_array(geom: geometry.Geometry, *, first: bool) -> jax.Array: if isinstance(geom, pointcloud.PointCloud): return geom.x if first else geom.y if isinstance(geom, low_rank.LRCGeometry): @@ -378,10 +378,10 @@ def _compute_factor( ot_prob: Problem_t, rng: jax.Array, *, - init_g: jnp.ndarray, + init_g: jax.Array, which: Literal["q", "r"], **kwargs: Any, - ) -> jnp.ndarray: + ) -> jax.Array: from ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import sinkhorn @@ -420,9 +420,9 @@ def init_q( # noqa: D102 ot_prob: Problem_t, rng: jax.Array, *, - init_g: jnp.ndarray, + init_g: jax.Array, **kwargs: Any, - ) -> jnp.ndarray: + ) -> jax.Array: return self._compute_factor( ot_prob, rng, init_g=init_g, which="q", **kwargs ) @@ -432,9 +432,9 @@ def init_r( # noqa: D102 ot_prob: Problem_t, rng: jax.Array, *, - init_g: jnp.ndarray, + init_g: jax.Array, **kwargs: Any, - ) -> jnp.ndarray: + ) -> jax.Array: return self._compute_factor( ot_prob, rng, init_g=init_g, which="r", **kwargs ) @@ -444,7 +444,7 @@ def init_g( # noqa: D102 ot_prob: Problem_t, rng: jax.Array, **kwargs: Any, - ) -> jnp.ndarray: + ) -> jax.Array: del rng, kwargs return jnp.ones((self.rank,)) / self.rank @@ -498,14 +498,14 @@ def __init__( class Constants(NamedTuple): # noqa: D106 solver: "sinkhorn.Sinkhorn" geom: geometry.Geometry # (n, n) - marginal: jnp.ndarray # (n,) - g: jnp.ndarray # (r,) + marginal: jax.Array # (n,) + g: jax.Array # (r,) gamma: float threshold: float class State(NamedTuple): # noqa: D106 - factor: jnp.ndarray - criterions: jnp.ndarray + factor: jax.Array + criterions: jax.Array crossed_threshold: bool def _compute_factor( @@ -513,10 +513,10 @@ def _compute_factor( ot_prob: Problem_t, rng: jax.Array, *, - init_g: jnp.ndarray, + init_g: jax.Array, which: Literal["q", "r"], **kwargs: Any, - ) -> jnp.ndarray: + ) -> jax.Array: from ott.problems.linear import linear_problem from 
ott.problems.quadratic import quadratic_problem from ott.solvers.linear import sinkhorn diff --git a/src/ott/initializers/quadratic/initializers.py b/src/ott/initializers/quadratic/initializers.py index 795e81ccc..323570770 100644 --- a/src/ott/initializers/quadratic/initializers.py +++ b/src/ott/initializers/quadratic/initializers.py @@ -125,9 +125,7 @@ class QuadraticInitializer(BaseQuadraticInitializer): defaults to the product coupling :math:`ab^T`. """ - def __init__( - self, init_coupling: Optional[jnp.ndarray] = None, **kwargs: Any - ): + def __init__(self, init_coupling: Optional[jax.Array] = None, **kwargs: Any): super().__init__(**kwargs) self.init_coupling = init_coupling diff --git a/src/ott/math/fixed_point_loop.py b/src/ott/math/fixed_point_loop.py index 9034eba62..5c8b7b94d 100644 --- a/src/ott/math/fixed_point_loop.py +++ b/src/ott/math/fixed_point_loop.py @@ -179,7 +179,7 @@ def fixpoint_iter_bwd( # The tree may contain some python floats g_constants = jax.tree_util.tree_map( lambda x: jnp.zeros_like(x, dtype=x.dtype) - if isinstance(x, (np.ndarray, jnp.ndarray)) else 0, constants + if isinstance(x, (np.ndarray, jax.Array)) else 0, constants ) def bwd_cond_fn(iteration_g_gconst): diff --git a/src/ott/math/matrix_square_root.py b/src/ott/math/matrix_square_root.py index 4a0177780..5089f14a0 100644 --- a/src/ott/math/matrix_square_root.py +++ b/src/ott/math/matrix_square_root.py @@ -25,13 +25,13 @@ @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5)) def sqrtm( - x: jnp.ndarray, + x: jax.Array, threshold: float = 1e-6, min_iterations: int = 0, inner_iterations: int = 10, max_iterations: int = 1000, regularization: float = 1e-6 -) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: +) -> Tuple[jax.Array, jax.Array, jax.Array]: """Higham algorithm to compute matrix square root of p.d. matrix. See :cite:`higham:97`, eq. 
2.6b @@ -118,10 +118,10 @@ def new_err(x, norm_x, y): def solve_sylvester_bartels_stewart( - a: jnp.ndarray, - b: jnp.ndarray, - c: jnp.ndarray, -) -> jnp.ndarray: + a: jax.Array, + b: jax.Array, + c: jax.Array, +) -> jax.Array: """Solve the real Sylvester equation AX - XB = C using Bartels-Stewart.""" # See https://nhigham.com/2020/09/01/what-is-the-sylvester-equation/ for # discussion of the algorithm (but note that in the derivation, the sign on @@ -153,14 +153,13 @@ def solve_sylvester_bartels_stewart( def sqrtm_fwd( - x: jnp.ndarray, + x: jax.Array, threshold: float, min_iterations: int, inner_iterations: int, max_iterations: int, regularization: float, -) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, - jnp.ndarray]]: +) -> Tuple[Tuple[jax.Array, jax.Array, jax.Array], Tuple[jax.Array, jax.Array]]: """Forward pass of custom VJP.""" sqrt_x, inv_sqrt_x, errors = sqrtm( x=x, @@ -179,9 +178,9 @@ def sqrtm_bwd( inner_iterations: int, max_iterations: int, regularization: float, - residual: Tuple[jnp.ndarray, jnp.ndarray], - cotangent: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], -) -> Tuple[jnp.ndarray]: + residual: Tuple[jax.Array, jax.Array], + cotangent: Tuple[jax.Array, jax.Array, jax.Array], +) -> Tuple[jax.Array]: """Compute the derivative by solving a Sylvester equation.""" del threshold, min_iterations, inner_iterations, \ max_iterations, regularization @@ -237,13 +236,13 @@ def sqrtm_bwd( @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5)) def sqrtm_only( # noqa: D103 - x: jnp.ndarray, + x: jax.Array, threshold: float = 1e-6, min_iterations: int = 0, inner_iterations: int = 10, max_iterations: int = 1000, regularization: float = 1e-6 -) -> jnp.ndarray: +) -> jax.Array: return sqrtm( x, threshold, min_iterations, inner_iterations, max_iterations, regularization @@ -251,9 +250,9 @@ def sqrtm_only( # noqa: D103 def sqrtm_only_fwd( # noqa: D103 - x: jnp.ndarray, threshold: float, min_iterations: int, + x: jax.Array, 
threshold: float, min_iterations: int, inner_iterations: int, max_iterations: int, regularization: float -) -> Tuple[jnp.ndarray, jnp.ndarray]: +) -> Tuple[jax.Array, jax.Array]: sqrt_x = sqrtm( x, threshold, min_iterations, inner_iterations, max_iterations, regularization @@ -263,9 +262,9 @@ def sqrtm_only_fwd( # noqa: D103 def sqrtm_only_bwd( # noqa: D103 threshold: float, min_iterations: int, inner_iterations: int, - max_iterations: int, regularization: float, sqrt_x: jnp.ndarray, - cotangent: jnp.ndarray -) -> Tuple[jnp.ndarray]: + max_iterations: int, regularization: float, sqrt_x: jax.Array, + cotangent: jax.Array +) -> Tuple[jax.Array]: del threshold, min_iterations, inner_iterations, \ max_iterations, regularization vjp = jnp.swapaxes( @@ -283,13 +282,13 @@ def sqrtm_only_bwd( # noqa: D103 @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5)) def inv_sqrtm_only( # noqa: D103 - x: jnp.ndarray, + x: jax.Array, threshold: float = 1e-6, min_iterations: int = 0, inner_iterations: int = 10, max_iterations: int = 1000, regularization: float = 1e-6 -) -> jnp.ndarray: +) -> jax.Array: return sqrtm( x, threshold, min_iterations, inner_iterations, max_iterations, regularization @@ -297,13 +296,13 @@ def inv_sqrtm_only( # noqa: D103 def inv_sqrtm_only_fwd( # noqa: D103 - x: jnp.ndarray, + x: jax.Array, threshold: float, min_iterations: int, inner_iterations: int, max_iterations: int, regularization: float, -) -> Tuple[jnp.ndarray, jnp.ndarray]: +) -> Tuple[jax.Array, jax.Array]: inv_sqrt_x = sqrtm( x, threshold, min_iterations, inner_iterations, max_iterations, regularization @@ -313,9 +312,9 @@ def inv_sqrtm_only_fwd( # noqa: D103 def inv_sqrtm_only_bwd( # noqa: D103 threshold: float, min_iterations: int, inner_iterations: int, - max_iterations: int, regularization: float, residual: jnp.ndarray, - cotangent: jnp.ndarray -) -> Tuple[jnp.ndarray]: + max_iterations: int, regularization: float, residual: jax.Array, + cotangent: jax.Array +) -> 
Tuple[jax.Array]: del threshold, min_iterations, inner_iterations, \ max_iterations, regularization diff --git a/src/ott/math/unbalanced_functions.py b/src/ott/math/unbalanced_functions.py index fc1aca9f3..2d7baebb7 100644 --- a/src/ott/math/unbalanced_functions.py +++ b/src/ott/math/unbalanced_functions.py @@ -13,31 +13,32 @@ # limitations under the License. from typing import Callable +import jax import jax.numpy as jnp -def phi_star(h: jnp.ndarray, rho: float) -> jnp.ndarray: +def phi_star(h: jax.Array, rho: float) -> jax.Array: """Legendre transform of KL, :cite:`sejourne:19`, p. 9.""" return rho * (jnp.exp(h / rho) - 1) -def derivative_phi_star(f: jnp.ndarray, rho: float) -> jnp.ndarray: +def derivative_phi_star(f: jax.Array, rho: float) -> jax.Array: """Derivative of Legendre transform of phi_starKL, see phi_star.""" # TODO(cuturi): use jax.grad directly. return jnp.exp(f / rho) def grad_of_marginal_fit( - c: jnp.ndarray, h: jnp.ndarray, tau: float, epsilon: float -) -> jnp.ndarray: + c: jax.Array, h: jax.Array, tau: float, epsilon: float +) -> jax.Array: """Compute grad of terms linked to marginals in objective. Computes gradient w.r.t. f ( or g) of terms in :cite:`sejourne:19`, left-hand-side of eq. 15 terms involving phi_star). Args: - c: jnp.ndarray, first target marginal (either a or b in practice) - h: jnp.ndarray, potential (either f or g in practice) + c: jax.Array, first target marginal (either a or b in practice) + h: jax.Array, potential (either f or g in practice) tau: float, strength (in ]0,1]) of regularizer w.r.t. 
marginal epsilon: regularization @@ -50,14 +51,14 @@ def grad_of_marginal_fit( return jnp.where(c > 0, c * derivative_phi_star(-h, r), 0.0) -def second_derivative_phi_star(f: jnp.ndarray, rho: float) -> jnp.ndarray: +def second_derivative_phi_star(f: jax.Array, rho: float) -> jax.Array: """Second Derivative of Legendre transform of KL, see phi_star.""" return jnp.exp(f / rho) / rho def diag_jacobian_of_marginal_fit( - c: jnp.ndarray, h: jnp.ndarray, tau: float, epsilon: float, - derivative: Callable[[jnp.ndarray, float], jnp.ndarray] + c: jax.Array, h: jax.Array, tau: float, epsilon: float, + derivative: Callable[[jax.Array, float], jax.Array] ): """Compute grad of terms linked to marginals in objective. @@ -65,8 +66,8 @@ def diag_jacobian_of_marginal_fit( left-hand-side of eq. 32 (terms involving phi_star) Args: - c: jnp.ndarray, first target marginal (either a or b in practice) - h: jnp.ndarray, potential (either f or g in practice) + c: jax.Array, first target marginal (either a or b in practice) + h: jax.Array, potential (either f or g in practice) tau: float, strength (in ]0,1]) of regularizer w.r.t. 
marginal epsilon: regularization derivative: Callable diff --git a/src/ott/math/utils.py b/src/ott/math/utils.py index 8e7ea90ee..188707c10 100644 --- a/src/ott/math/utils.py +++ b/src/ott/math/utils.py @@ -34,10 +34,10 @@ def safe_log( # noqa: D103 - x: jnp.ndarray, + x: jax.Array, *, eps: Optional[float] = None -) -> jnp.ndarray: +) -> jax.Array: if eps is None: eps = jnp.finfo(x.dtype).tiny return jnp.where(x > 0., jnp.log(x), jnp.log(eps)) @@ -46,11 +46,11 @@ def safe_log( # noqa: D103 @functools.partial(jax.custom_jvp, nondiff_argnums=[1, 2, 3]) @functools.partial(jax.jit, static_argnames=("ord", "axis", "keepdims")) def norm( - x: jnp.ndarray, + x: jax.Array, ord: Union[int, str, None] = None, axis: Union[None, Sequence[int], int] = None, keepdims: bool = False -) -> jnp.ndarray: +) -> jax.Array: """Computes order ord norm of vector, using `jnp.linalg` in forward pass. Evaluations of distances between a vector and itself using translation @@ -105,18 +105,18 @@ def norm_jvp(ord, axis, keepdims, primals, tangents): # TODO(michalk8): add axis argument -def kl(p: jnp.ndarray, q: jnp.ndarray) -> float: +def kl(p: jax.Array, q: jax.Array) -> float: """Kullback-Leibler divergence.""" return jnp.vdot(p, (safe_log(p) - safe_log(q))) -def gen_kl(p: jnp.ndarray, q: jnp.ndarray) -> float: +def gen_kl(p: jax.Array, q: jax.Array) -> float: """Generalized Kullback-Leibler divergence.""" return jnp.vdot(p, (safe_log(p) - safe_log(q))) + jnp.sum(q) - jnp.sum(p) # TODO(michalk8): add axis argument -def gen_js(p: jnp.ndarray, q: jnp.ndarray, c: float = 0.5) -> float: +def gen_js(p: jax.Array, q: jax.Array, c: float = 0.5) -> float: """Jensen-Shannon divergence.""" return c * (gen_kl(p, q) + gen_kl(q, p)) @@ -176,8 +176,8 @@ def logsumexp_jvp(axis, keepdims, return_sign, primals, tangents): @functools.partial(jax.custom_vjp, nondiff_argnums=(2,)) def softmin( - x: jnp.ndarray, gamma: float, axis: Optional[int] = None -) -> jnp.ndarray: + x: jax.Array, gamma: float, axis: 
Optional[int] = None +) -> jax.Array: r"""Soft-min operator. Args: @@ -205,8 +205,8 @@ def softmin( @functools.partial(jax.vmap, in_axes=[0, 0, None]) def barycentric_projection( - matrix: jnp.ndarray, y: jnp.ndarray, cost_fn: "costs.CostFn" -) -> jnp.ndarray: + matrix: jax.Array, y: jax.Array, cost_fn: "costs.CostFn" +) -> jax.Array: """Compute the barycentric projection of a matrix. Args: diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py index acceb36c1..0ebfc77a0 100644 --- a/src/ott/neural/data/dataloaders.py +++ b/src/ott/neural/data/dataloaders.py @@ -26,7 +26,7 @@ class ConditionalDataLoader: #TODO(@MUCDK) uncomment, resolve installation issu # self.conditions = dataloaders.keys() # self.p = p - #def __next__(self) -> jnp.ndarray: + #def __next__(self) -> jax.Array: # self.rng, rng = jax.random.split(self.rng, 2) # condition = jax.random.choice(rng, self.conditions, p=self.p) # return next(self.dataloaders[condition]) diff --git a/src/ott/neural/models/conjugate_solvers.py b/src/ott/neural/models/conjugate_solvers.py index 0758cf1ad..4d3d8eea0 100644 --- a/src/ott/neural/models/conjugate_solvers.py +++ b/src/ott/neural/models/conjugate_solvers.py @@ -14,6 +14,7 @@ import abc from typing import Callable, Literal, NamedTuple, Optional +import jax import jax.numpy as jnp from jaxopt import LBFGS @@ -36,7 +37,7 @@ class ConjugateResults(NamedTuple): num_iter: the number of iterations taken by the solver """ val: float - grad: jnp.ndarray + grad: jax.Array num_iter: int @@ -50,9 +51,9 @@ class FenchelConjugateSolver(abc.ABC): @abc.abstractmethod def solve( self, - f: Callable[[jnp.ndarray], jnp.ndarray], - y: jnp.ndarray, - x_init: Optional[jnp.ndarray] = None + f: Callable[[jax.Array], jax.Array], + y: jax.Array, + x_init: Optional[jax.Array] = None ) -> ConjugateResults: """Solve for the conjugate. 
@@ -90,8 +91,8 @@ class FenchelConjugateLBFGS(FenchelConjugateSolver): def solve( # noqa: D102 self, - f: Callable[[jnp.ndarray], jnp.ndarray], - y: jnp.ndarray, + f: Callable[[jax.Array], jax.Array], + y: jax.Array, x_init: Optional[jnp.array] = None ) -> ConjugateResults: assert y.ndim == 1, y.ndim diff --git a/src/ott/neural/models/layers.py b/src/ott/neural/models/layers.py index dfd222c60..0eac7e626 100644 --- a/src/ott/neural/models/layers.py +++ b/src/ott/neural/models/layers.py @@ -40,9 +40,9 @@ class PositiveDense(nn.Module): bias_init: initializer function for the bias. """ dim_hidden: int - rectifier_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.softplus - inv_rectifier_fn: Callable[[jnp.ndarray], - jnp.ndarray] = lambda x: jnp.log(jnp.exp(x) - 1) + rectifier_fn: Callable[[jax.Array], jax.Array] = nn.softplus + inv_rectifier_fn: Callable[[jax.Array], + jax.Array] = lambda x: jnp.log(jnp.exp(x) - 1) use_bias: bool = True dtype: Any = jnp.float32 precision: Any = None @@ -51,7 +51,7 @@ class PositiveDense(nn.Module): bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros @nn.compact - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + def __call__(self, inputs: jax.Array) -> jax.Array: """Applies a linear transformation to inputs along the last dimension. Args: @@ -99,7 +99,7 @@ class PosDefPotentials(nn.Module): bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros @nn.compact - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + def __call__(self, inputs: jax.Array) -> jax.Array: """Apply a few quadratic forms. 
Args: diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 853a1d69e..5ec8fb292 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -62,9 +62,9 @@ class ICNN(neuraldual.BaseW2NeuralDual): dim_hidden: Sequence[int] init_std: float = 1e-2 init_fn: Callable = jax.nn.initializers.normal - act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + act_fn: Callable[[jax.Array], jax.Array] = nn.relu pos_weights: bool = True - gaussian_map_samples: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None + gaussian_map_samples: Optional[Tuple[jax.Array, jax.Array]] = None @property def is_potential(self) -> bool: # noqa: D102 @@ -146,8 +146,8 @@ def setup(self) -> None: # noqa: D102 @staticmethod def _compute_gaussian_map_params( - samples: Tuple[jnp.ndarray, jnp.ndarray] - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + samples: Tuple[jax.Array, jax.Array] + ) -> Tuple[jax.Array, jax.Array]: from ott.tools.gaussian_mixture import gaussian source, target = samples g_s = gaussian.Gaussian.from_samples(source) @@ -160,13 +160,13 @@ def _compute_gaussian_map_params( @staticmethod def _compute_identity_map_params( input_dim: int - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array]: A = jnp.eye(input_dim).reshape((1, input_dim, input_dim)) b = jnp.zeros((1, input_dim)) return A, b @nn.compact - def __call__(self, x: jnp.ndarray) -> float: # noqa: D102 + def __call__(self, x: jax.Array) -> float: # noqa: D102 z = self.act_fn(self.w_xs[0](x)) for i in range(self.num_hidden): z = jnp.add(self.w_zs[i](z), self.w_xs[i + 1](x)) @@ -189,10 +189,10 @@ class MLP(neuraldual.BaseW2NeuralDual): dim_hidden: Sequence[int] is_potential: bool = True - act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.leaky_relu + act_fn: Callable[[jax.Array], jax.Array] = nn.leaky_relu @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # noqa: D102 + def __call__(self, x: jax.Array) -> jax.Array: # noqa: D102 squeeze = x.ndim 
== 1 if squeeze: x = jnp.expand_dims(x, 0) @@ -289,8 +289,8 @@ def __init__( self.update_impl = self._get_update_fn() def update( - self, state: train_state.TrainState, a: jnp.ndarray, b: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray, train_state.TrainState]: + self, state: train_state.TrainState, a: jax.Array, b: jax.Array + ) -> Tuple[jax.Array, jax.Array, train_state.TrainState]: r"""Update the meta model with the dual objective. The goal is for the model to match the optimal duals, i.e., @@ -329,7 +329,7 @@ def init_dual_a( # noqa: D102 ot_prob: "linear_problem.LinearProblem", lse_mode: bool, rng: Optional[jax.Array] = None, - ) -> jnp.ndarray: + ) -> jax.Array: del rng # Detect if the problem is batched. assert ot_prob.a.ndim in (1, 2) @@ -382,9 +382,9 @@ def update(state, a, b): return update def _compute_f( - self, a: jnp.ndarray, b: jnp.ndarray, - params: frozen_dict.FrozenDict[str, jnp.ndarray] - ) -> jnp.ndarray: + self, a: jax.Array, b: jax.Array, + params: frozen_dict.FrozenDict[str, jax.Array] + ) -> jax.Array: r"""Predict the optimal :math:`f` potential. 
Args: @@ -427,7 +427,7 @@ def __call__( x: jax.Array, condition: Optional[jax.Array] = None, keys_model: Optional[jax.Array] = None - ) -> jnp.ndarray: # noqa: D102): + ) -> jax.Array: # noqa: D102): pass @@ -439,7 +439,7 @@ class NeuralVectorField(BaseNeuralVectorField): t_embed_dim: Optional[int] = None joint_hidden_dim: Optional[int] = None num_layers_per_block: int = 3 - act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu + act_fn: Callable[[jax.Array], jax.Array] = nn.silu n_frequencies: int = 128 def time_encoder(self, t: jax.Array) -> jnp.array: @@ -554,12 +554,12 @@ class Rescaling_MLP(nn.Module): hidden_dim: int cond_dim: int is_potential: bool = False - act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.selu + act_fn: Callable[[jax.Array], jax.Array] = nn.selu @nn.compact def __call__( - self, x: jnp.ndarray, condition: Optional[jax.Array] - ) -> jnp.ndarray: # noqa: D102 + self, x: jax.Array, condition: Optional[jax.Array] + ) -> jax.Array: # noqa: D102 x = Block( dim=self.latent_embed_dim, out_dim=self.latent_embed_dim, diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index 66d3ecbef..69f510d81 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -81,10 +81,10 @@ def __init__(*args, **kwargs): def _resample_data( self, key: jax.random.KeyArray, - tmat: jnp.ndarray, - source_arrays: Tuple[jnp.ndarray, ...], - target_arrays: Tuple[jnp.ndarray, ...], - ) -> Tuple[jnp.ndarray, ...]: + tmat: jax.Array, + source_arrays: Tuple[jax.Array, ...], + target_arrays: Tuple[jax.Array, ...], + ) -> Tuple[jax.Array, ...]: """Resample a batch according to coupling `tmat`.""" tmat_flattened = tmat.flatten() indices = random.choice( @@ -101,10 +101,10 @@ def _resample_data( def _sample_conditional_indices_from_tmap( self, key: jax.random.PRNGKeyArray, - tmat: jnp.ndarray, - k_samples_per_x: Union[int, jnp.ndarray], - source_arrays: Tuple[jnp.ndarray, ...], - target_arrays: 
Tuple[jnp.ndarray, ...], + tmat: jax.Array, + k_samples_per_x: Union[int, jax.Array], + source_arrays: Tuple[jax.Array, ...], + target_arrays: Tuple[jax.Array, ...], *, source_is_balanced: bool, ) -> Tuple[jnp.array, jnp.array]: @@ -155,7 +155,7 @@ def _get_sinkhorn_match_fn( def match_pairs( x: jax.Array, y: jax.Array - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array]: geom = pointcloud.PointCloud( x, y, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn ) @@ -165,7 +165,7 @@ def match_pairs( def match_pairs_filtered( x_lin: jax.Array, x_quad: jax.Array, y_lin: jax.Array, y_quad: jax.Array - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array]: geom = pointcloud.PointCloud( x_lin, y_lin, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn ) @@ -208,9 +208,9 @@ def _get_gromov_match_fn( def match_pairs( x_lin: Optional[jax.Array], - x_quad: Tuple[jnp.ndarray, jnp.ndarray], + x_quad: Tuple[jax.Array, jax.Array], y_lin: Optional[jax.Array], - y_quad: Tuple[jnp.ndarray, jnp.ndarray], + y_quad: Tuple[jax.Array, jax.Array], ) -> Tuple[jnp.array, jnp.array]: geom_xx = pointcloud.PointCloud( x=x_quad, y=x_quad, cost_fn=x_cost_fn, scale_cost=x_scale_cost @@ -288,13 +288,13 @@ def _get_compute_unbalanced_marginals( scale_cost: Union[bool, int, float, Literal["mean", "max_cost", "median"]] = "mean", sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}), - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array]: """Compute the unbalanced source and target marginals for a batch.""" @jax.jit def compute_unbalanced_marginals( - batch_source: jnp.ndarray, batch_target: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + batch_source: jax.Array, batch_target: jax.Array + ) -> Tuple[jax.Array, jax.Array]: geom = PointCloud( batch_source, batch_target, @@ -312,9 +312,9 @@ def compute_unbalanced_marginals( def 
_resample_unbalanced( self, key: jax.random.KeyArray, - batch: Tuple[jnp.ndarray, ...], - marginals: jnp.ndarray, - ) -> Tuple[jnp.ndarray, ...]: + batch: Tuple[jax.Array, ...], + marginals: jax.Array, + ) -> Tuple[jax.Array, ...]: """Resample a batch based upon marginals.""" indices = jax.random.choice( key, a=len(marginals), p=jnp.squeeze(marginals), shape=[len(marginals)] @@ -343,13 +343,12 @@ def _setup(self, source_dim: int, target_dim: int, cond_dim: int): def _get_step_fn(self) -> Callable: # type:ignore[type-arg] def loss_a_fn( - params_eta: Optional[jnp.ndarray], - apply_fn_eta: Callable[[Dict[str, jnp.ndarray], jnp.ndarray], - jnp.ndarray], - x: jnp.ndarray, - a: jnp.ndarray, + params_eta: Optional[jax.Array], + apply_fn_eta: Callable[[Dict[str, jax.Array], jax.Array], jax.Array], + x: jax.Array, + a: jax.Array, expectation_reweighting: float, - ) -> Tuple[float, jnp.ndarray]: + ) -> Tuple[float, jax.Array]: eta_predictions = apply_fn_eta({"params": params_eta}, x) return ( optax.l2_loss(eta_predictions[:, 0], a).mean() + @@ -358,13 +357,12 @@ def loss_a_fn( ) def loss_b_fn( - params_xi: Optional[jnp.ndarray], - apply_fn_xi: Callable[[Dict[str, jnp.ndarray], jnp.ndarray], - jnp.ndarray], - x: jnp.ndarray, - b: jnp.ndarray, + params_xi: Optional[jax.Array], + apply_fn_xi: Callable[[Dict[str, jax.Array], jax.Array], jax.Array], + x: jax.Array, + b: jax.Array, expectation_reweighting: float, - ) -> Tuple[float, jnp.ndarray]: + ) -> Tuple[float, jax.Array]: xi_predictions = apply_fn_xi({"params": params_xi}, x) return ( optax.l2_loss(xi_predictions[:, 0], b).mean() + @@ -374,11 +372,11 @@ def loss_b_fn( @jax.jit def step_fn( - source: jnp.ndarray, - target: jnp.ndarray, - condition: Optional[jnp.ndarray], - a: jnp.ndarray, - b: jnp.ndarray, + source: jax.Array, + target: jax.Array, + condition: Optional[jax.Array], + a: jax.Array, + b: jax.Array, state_eta: Optional[train_state.TrainState] = None, state_xi: Optional[train_state.TrainState] = None, *, diff 
--git a/src/ott/neural/solvers/flows.py b/src/ott/neural/solvers/flows.py index 19c3d2f67..6552048fb 100644 --- a/src/ott/neural/solvers/flows.py +++ b/src/ott/neural/solvers/flows.py @@ -141,16 +141,24 @@ def compute_sigma_t(self, t: jax.Array): class BaseTimeSampler(abc.ABC): - """Base class for time samplers.""" + """Base class for time samplers. + + Args: + low: Lower bound of the distribution to sample from. + high: Upper bound of the distribution to sample from . + """ + + def __init__(self, low: float, high: float) -> None: + self.low = low + self.high = high @abc.abstractmethod - def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: + def __call__(self, rng: jax.Array, num_samples: int) -> jax.Array: """Generate `num_samples` samples of the time `math`:t:. Args: rng: Random number generator. num_samples: Number of samples to generate. - """ pass @@ -163,11 +171,7 @@ class UniformSampler(BaseTimeSampler): high: Upper bound of the uniform distribution. """ - def __init__(self, low: float = 0.0, high: float = 1.0) -> None: - self.low = low - self.high = high - - def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: + def __call__(self, rng: jax.Array, num_samples: int) -> jax.Array: """Generate `num_samples` samples of the time `math`:t:. Args: @@ -194,11 +198,10 @@ class OffsetUniformSampler(BaseTimeSampler): def __init__( self, offset: float, low: float = 0.0, high: float = 1.0 ) -> None: + super().__init__(low=low, high=high) self.offset = offset - self.low = low - self.high = high - def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: + def __call__(self, rng: jax.Array, num_samples: int) -> jax.Array: """Generate `num_samples` samples of the time `math`:t:. 
Args: diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 3d6b3fafb..efdf5af29 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -83,10 +83,10 @@ def __init__( fused_penalty: float = 0.0, tau_a: float = 1.0, tau_b: float = 1.0, - mlp_eta: Callable[[jnp.ndarray], float] = None, - mlp_xi: Callable[[jnp.ndarray], float] = None, + mlp_eta: Callable[[jax.Array], float] = None, + mlp_xi: Callable[[jax.Array], float] = None, unbalanced_kwargs: Dict[str, Any] = {}, - callback: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], + callback: Optional[Callable[[jax.Array, jax.Array, jax.Array], Any]] = None, callback_kwargs: Dict[str, Any] = {}, callback_iters: int = 10, @@ -397,7 +397,7 @@ def transport( rng: random.PRNGKeyArray = random.PRNGKey(0), diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}), forward: bool = True, - ) -> Union[jnp.array, diffrax.Solution, Optional[jnp.ndarray]]: + ) -> Union[jnp.array, diffrax.Solution, Optional[jax.Array]]: """Transport the distribution. 
Parameters @@ -468,5 +468,5 @@ def training_logs(self) -> Dict[str, Any]: def sample_noise( #TODO: make more general self, key: random.PRNGKey, batch_size: int - ) -> jnp.ndarray: #TODO: make more general + ) -> jax.Array: #TODO: make more general return random.normal(key, shape=(batch_size, self.output_dim)) diff --git a/src/ott/neural/solvers/losses.py b/src/ott/neural/solvers/losses.py index bec0f3916..fbf091b22 100644 --- a/src/ott/neural/solvers/losses.py +++ b/src/ott/neural/solvers/losses.py @@ -25,8 +25,8 @@ def monge_gap( - map_fn: Callable[[jnp.ndarray], jnp.ndarray], - reference_points: jnp.ndarray, + map_fn: Callable[[jax.Array], jax.Array], + reference_points: jax.Array, cost_fn: Optional[costs.CostFn] = None, epsilon: Optional[float] = None, relative_epsilon: Optional[bool] = None, @@ -91,8 +91,8 @@ def monge_gap( def monge_gap_from_samples( - source: jnp.ndarray, - target: jnp.ndarray, + source: jax.Array, + target: jax.Array, cost_fn: Optional[costs.CostFn] = None, epsilon: Optional[float] = None, relative_epsilon: Optional[bool] = None, diff --git a/src/ott/neural/solvers/map_estimator.py b/src/ott/neural/solvers/map_estimator.py index 27745f9ca..53bcdc7dd 100644 --- a/src/ott/neural/solvers/map_estimator.py +++ b/src/ott/neural/solvers/map_estimator.py @@ -79,9 +79,9 @@ def __init__( dim_data: int, model: neuraldual.BaseW2NeuralDual, optimizer: Optional[optax.OptState] = None, - fitting_loss: Optional[Callable[[jnp.ndarray, jnp.ndarray], + fitting_loss: Optional[Callable[[jax.Array, jax.Array], Tuple[float, Optional[Any]]]] = None, - regularizer: Optional[Callable[[jnp.ndarray, jnp.ndarray], + regularizer: Optional[Callable[[jax.Array, jax.Array], Tuple[float, Optional[Any]]]] = None, regularizer_strength: Union[float, Sequence[float]] = 1., num_train_iters: int = 10_000, @@ -126,7 +126,7 @@ def setup( self.step_fn = self._get_step_fn() @property - def regularizer(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]: + def regularizer(self) -> 
Callable[[jax.Array, jax.Array], float]: """Regularizer added to the fitting loss. Can be e.g. the :func:`~ott.solvers.nn.losses.monge_gap_from_samples`. @@ -139,7 +139,7 @@ def regularizer(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]: return lambda *args, **kwargs: (0., None) @property - def fitting_loss(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]: + def fitting_loss(self) -> Callable[[jax.Array, jax.Array], float]: """Fitting loss to fit the marginal constraint. Can be for instance the @@ -153,9 +153,9 @@ def fitting_loss(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]: @staticmethod def _generate_batch( - loader_source: Iterator[jnp.ndarray], - loader_target: Iterator[jnp.ndarray], - ) -> Dict[str, jnp.ndarray]: + loader_source: Iterator[jax.Array], + loader_target: Iterator[jax.Array], + ) -> Dict[str, jax.Array]: """Generate batches a batch of samples. ``loader_source`` and ``loader_target`` can be training or @@ -168,10 +168,10 @@ def _generate_batch( def train_map_estimator( self, - trainloader_source: Iterator[jnp.ndarray], - trainloader_target: Iterator[jnp.ndarray], - validloader_source: Iterator[jnp.ndarray], - validloader_target: Iterator[jnp.ndarray], + trainloader_source: Iterator[jax.Array], + trainloader_target: Iterator[jax.Array], + validloader_source: Iterator[jax.Array], + validloader_target: Iterator[jax.Array], ) -> Tuple[train_state.TrainState, Dict[str, Any]]: """Training loop.""" # define logs @@ -230,7 +230,7 @@ def _get_step_fn(self) -> Callable: def loss_fn( params: frozen_dict.FrozenDict, apply_fn: Callable, - batch: Dict[str, jnp.ndarray], step: int + batch: Dict[str, jax.Array], step: int ) -> Tuple[float, Dict[str, float]]: """Loss function.""" # map samples with the fitted map @@ -261,8 +261,8 @@ def loss_fn( @functools.partial(jax.jit, static_argnums=3) def step_fn( state_neural_net: train_state.TrainState, - train_batch: Dict[str, jnp.ndarray], - valid_batch: Optional[Dict[str, jnp.ndarray]] = None, + 
train_batch: Dict[str, jax.Array], + valid_batch: Optional[Dict[str, jax.Array]] = None, is_logging_step: bool = False, step: int = 0 ) -> Tuple[train_state.TrainState, Dict[str, float]]: diff --git a/src/ott/neural/solvers/neuraldual.py b/src/ott/neural/solvers/neuraldual.py index 6ac3d1c79..7d4d5800f 100644 --- a/src/ott/neural/solvers/neuraldual.py +++ b/src/ott/neural/solvers/neuraldual.py @@ -44,8 +44,8 @@ Callback_t = Callable[[int, potentials.DualPotentials], None] Conj_t = Optional[conjugate_solvers.FenchelConjugateSolver] -PotentialValueFn_t = Callable[[jnp.ndarray], jnp.ndarray] -PotentialGradientFn_t = Callable[[jnp.ndarray], jnp.ndarray] +PotentialValueFn_t = Callable[[jax.Array], jax.Array] +PotentialGradientFn_t = Callable[[jax.Array], jax.Array] class W2NeuralTrainState(train_state.TrainState): @@ -60,9 +60,9 @@ class W2NeuralTrainState(train_state.TrainState): potential_gradient_fn: the potential's gradient function """ potential_value_fn: Callable[ - [frozen_dict.FrozenDict[str, jnp.ndarray], Optional[PotentialValueFn_t]], + [frozen_dict.FrozenDict[str, jax.Array], Optional[PotentialValueFn_t]], PotentialValueFn_t] = struct.field(pytree_node=False) - potential_gradient_fn: Callable[[frozen_dict.FrozenDict[str, jnp.ndarray]], + potential_gradient_fn: Callable[[frozen_dict.FrozenDict[str, jax.Array]], PotentialGradientFn_t] = struct.field( pytree_node=False ) @@ -87,7 +87,7 @@ def is_potential(self) -> bool: def potential_value_fn( self, - params: frozen_dict.FrozenDict[str, jnp.ndarray], + params: frozen_dict.FrozenDict[str, jax.Array], other_potential_value_fn: Optional[PotentialValueFn_t] = None, ) -> PotentialValueFn_t: r"""Return a function giving the value of the potential. @@ -119,7 +119,7 @@ def potential_value_fn( "The value of the gradient-based potential depends " \ "on the value of the other potential." 
- def value_fn(x: jnp.ndarray) -> jnp.ndarray: + def value_fn(x: jax.Array) -> jax.Array: squeeze = x.ndim == 1 if squeeze: x = jnp.expand_dims(x, 0) @@ -132,7 +132,7 @@ def value_fn(x: jnp.ndarray) -> jnp.ndarray: def potential_gradient_fn( self, - params: frozen_dict.FrozenDict[str, jnp.ndarray], + params: frozen_dict.FrozenDict[str, jax.Array], ) -> PotentialGradientFn_t: """Return a function returning a vector or the gradient of the potential. @@ -358,10 +358,10 @@ def setup( def __call__( # noqa: D102 self, - trainloader_source: Iterator[jnp.ndarray], - trainloader_target: Iterator[jnp.ndarray], - validloader_source: Iterator[jnp.ndarray], - validloader_target: Iterator[jnp.ndarray], + trainloader_source: Iterator[jax.Array], + trainloader_target: Iterator[jax.Array], + validloader_source: Iterator[jax.Array], + validloader_target: Iterator[jax.Array], callback: Optional[Callback_t] = None, ) -> Union[potentials.DualPotentials, Tuple[potentials.DualPotentials, Train_t]]: @@ -378,10 +378,10 @@ def __call__( # noqa: D102 def train_neuraldual_parallel( self, - trainloader_source: Iterator[jnp.ndarray], - trainloader_target: Iterator[jnp.ndarray], - validloader_source: Iterator[jnp.ndarray], - validloader_target: Iterator[jnp.ndarray], + trainloader_source: Iterator[jax.Array], + trainloader_target: Iterator[jax.Array], + validloader_source: Iterator[jax.Array], + validloader_target: Iterator[jax.Array], callback: Optional[Callback_t] = None, ) -> Train_t: """Training and validation with parallel updates.""" @@ -453,10 +453,10 @@ def train_neuraldual_parallel( def train_neuraldual_alternating( self, - trainloader_source: Iterator[jnp.ndarray], - trainloader_target: Iterator[jnp.ndarray], - validloader_source: Iterator[jnp.ndarray], - validloader_target: Iterator[jnp.ndarray], + trainloader_source: Iterator[jax.Array], + trainloader_target: Iterator[jax.Array], + validloader_source: Iterator[jax.Array], + validloader_target: Iterator[jax.Array], callback: 
Optional[Callback_t] = None, ) -> Train_t: """Training and validation with alternating updates.""" @@ -533,7 +533,7 @@ def loss_fn(params_f, params_g, f_value, g_value, g_gradient, batch): init_source_hat = g_gradient(params_g)(target) - def g_value_partial(y: jnp.ndarray) -> jnp.ndarray: + def g_value_partial(y: jax.Array) -> jax.Array: """Lazy way of evaluating g if f's computation needs it.""" return g_value(params_g)(y) @@ -661,7 +661,7 @@ def to_dual_potentials( self.state_g.params, f_value ) - def g_value_finetuned(y: jnp.ndarray) -> jnp.ndarray: + def g_value_finetuned(y: jax.Array) -> jax.Array: x_hat = jax.grad(g_value_prediction)(y) grad_g_y = jax.lax.stop_gradient( self.conjugate_solver.solve(f_value, y, x_init=x_hat).grad @@ -686,7 +686,7 @@ def _clip_weights_icnn(params): return core.freeze(params) @staticmethod - def _penalize_weights_icnn(params: Dict[str, jnp.ndarray]) -> float: + def _penalize_weights_icnn(params: Dict[str, jax.Array]) -> float: penalty = 0.0 for k, param in params.items(): if k.startswith("w_z"): @@ -696,9 +696,9 @@ def _penalize_weights_icnn(params: Dict[str, jnp.ndarray]) -> float: @staticmethod def _update_logs( logs: Dict[str, List[Union[float, str]]], - loss_f: jnp.ndarray, - loss_g: jnp.ndarray, - w_dist: jnp.ndarray, + loss_f: jax.Array, + loss_g: jax.Array, + w_dist: jax.Array, ) -> None: logs["loss_f"].append(float(loss_f)) logs["loss_g"].append(float(loss_g)) diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index ec0be23da..3b5aa3319 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -38,6 +38,35 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): + """Flow matching as introduced in :cite:`TODO, with extension to OT-FM (). + + Args: + neural_vector_field: Neural vector field parameterized by a neural network. + input_dim: Dimension of the input data. + cond_dim: Dimension of the conditioning variable. 
+ iterations: Number of iterations. + valid_freq: Frequency of validation. + ot_solver: OT solver to match samples from the source and the target distribution as proposed in :cite:`TODO`. If `None`, no matching will be performed as proposed in :cite:`TODO`. + flow: Flow between source and target distribution. + time_sampler: Sampler for the time. + optimizer: Optimizer for `neural_vector_field`. + checkpoint_manager: Checkpoint manager. + epsilon: Entropy regularization term for the `ot_solver`. + cost_fn: Cost function for the OT problem solved by the `ot_solver`. + tau_a: If :math:`<1`, defines how much unbalanced the problem is + on the first marginal. + tau_b: If :math:`< 1`, defines how much unbalanced the problem is + on the second marginal. + mlp_eta: Neural network to learn the left rescaling function as suggested in :cite:`TODO`. If `None`, the left rescaling factor is not learnt. + mlp_xi: Neural network to learn the right rescaling function as suggested in :cite:`TODO`. If `None`, the right rescaling factor is not learnt. + unbalanced_kwargs: Keyword arguments for the unbalancedness solver. + callback_fn: Callback function. + rng: Random number generator. 
+ + Returns: + None + + """ def __init__( self, @@ -55,10 +84,10 @@ def __init__( cost_fn: Type[costs.CostFn] = costs.SqEuclidean(), tau_a: float = 1.0, tau_b: float = 1.0, - mlp_eta: Callable[[jnp.ndarray], float] = None, - mlp_xi: Callable[[jnp.ndarray], float] = None, + mlp_eta: Callable[[jax.Array], float] = None, + mlp_xi: Callable[[jax.Array], float] = None, unbalanced_kwargs: Dict[str, Any] = {}, - callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], + callback_fn: Optional[Callable[[jax.Array, jax.Array, jax.Array], Any]] = None, rng: random.PRNGKeyArray = random.PRNGKey(0), ) -> None: @@ -93,6 +122,7 @@ def __init__( self.setup() def setup(self) -> None: + """Setup :class:`OTFlowMatching`.""" self.state_neural_vector_field = self.neural_vector_field.create_train_state( self.rng, self.optimizer, self.input_dim ) @@ -115,13 +145,13 @@ def _get_step_fn(self) -> Callable: def step_fn( key: random.PRNGKeyArray, state_neural_vector_field: train_state.TrainState, - batch: Dict[str, jnp.ndarray], + batch: Dict[str, jax.Array], ) -> Tuple[Any, Any]: def loss_fn( params: jax.Array, t: jax.Array, noise: jax.Array, - batch: Dict[str, jnp.ndarray], keys_model: random.PRNGKeyArray - ) -> jnp.ndarray: + batch: Dict[str, jax.Array], keys_model: random.PRNGKeyArray + ) -> jax.Array: x_t = self.flow.compute_xt(noise, t, batch["source"], batch["target"]) apply_fn = functools.partial( @@ -147,7 +177,16 @@ def loss_fn( return step_fn def __call__(self, train_loader, valid_loader) -> None: - batch: Mapping[str, jnp.ndarray] = {} + """Train :class:`OTFlowMatching`. + + Args; + train_loader: Dataloader for the training data. + valid_loader: Dataloader for the validation data. 
+ + Returns: + None + """ + batch: Mapping[str, jax.Array] = {} for iter in range(self.iterations): rng_resample, rng_step_fn, self.rng = random.split(self.rng, 3) batch["source"], batch["target"], batch["condition"] = next(train_loader) @@ -184,9 +223,26 @@ def transport( forward: bool = True, diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}) ) -> diffrax.Solution: + """Transport data with the learnt map. + + This method solves the neural ODE parameterized by the :attr:`~ott.neural.solvers.OTFlowMatching.neural_vector_field` from + :attr:`~ott.neural.flows.BaseTimeSampler.low` to :attr:`~ott.neural.flows.BaseTimeSampler.high` if `forward` is `True`, + else the other way round. + + Args: + data: Initial condition of the ODE. + condition: Condition of the input data. + forward: If `True` integrates forward, otherwise backwards. + diffeqsovle_kwargs: Keyword arguments for the ODE solver. + + Returns: + The push-forward or pull-back distribution defined by the learnt transport plan. + + """ diffeqsolve_kwargs = dict(diffeqsolve_kwargs) - t0, t1 = (0.0, 1.0) if forward else (1.0, 0.0) + t0, t1 = (self.time_sampler.low, self.time_sampler.high + ) if forward else (self.time_sampler.high, self.time_sampler.low) def solve_ode(input: jax.Array, cond: jax.Array): return diffrax.diffeqsolve( @@ -217,18 +273,40 @@ def _valid_step(self, valid_loader, iter) -> None: @property def learn_rescaling(self) -> bool: + """Whether to learn at least one rescaling factor of the marginal distributions.""" return self.mlp_eta is not None or self.mlp_xi is not None def save(self, path: str) -> None: + """Save the model. + + Args: + path: Where to save the model to. + """ raise NotImplementedError def load(self, path: str) -> "OTFlowMatching": + """Load a model. + + Args: + path: Where to load the model from. + + Returns: + An instance of :class:`ott.neural.solvers.OTFlowMatching`. 
+ """ raise NotImplementedError def training_logs(self) -> Dict[str, Any]: + """Logs of the training.""" raise NotImplementedError - def sample_noise( #TODO: make more general - self, key: random.PRNGKey, batch_size: int - ) -> jnp.ndarray: #TODO: make more general + def sample_noise(self, key: random.PRNGKey, batch_size: int) -> jax.Array: + """Sample noise from a standard-normal distribution. + + Args: + key: Random key for seeding. + batch_size: Number of samples to draw. + + Returns: + Samples from the standard normal distribution. + """ return random.normal(key, shape=(batch_size, self.input_dim)) diff --git a/src/ott/problems/linear/barycenter_problem.py b/src/ott/problems/linear/barycenter_problem.py index ca5333a8e..c94cc578d 100644 --- a/src/ott/problems/linear/barycenter_problem.py +++ b/src/ott/problems/linear/barycenter_problem.py @@ -50,9 +50,9 @@ class FreeBarycenterProblem: def __init__( self, - y: jnp.ndarray, - b: Optional[jnp.ndarray] = None, - weights: Optional[jnp.ndarray] = None, + y: jax.Array, + b: Optional[jax.Array] = None, + weights: Optional[jax.Array] = None, cost_fn: Optional[costs.CostFn] = None, epsilon: Optional[float] = None, **kwargs: Any, @@ -76,7 +76,7 @@ def __init__( assert self._b is None or self._y.shape[0] == self._b.shape[0] @property - def segmented_y_b(self) -> Tuple[jnp.ndarray, jnp.ndarray]: + def segmented_y_b(self) -> Tuple[jax.Array, jax.Array]: """Tuple of arrays containing the segmented measures and weights. - Segmented measures of shape ``[num_measures, max_measure_size, ndim]``. 
@@ -94,14 +94,14 @@ def segmented_y_b(self) -> Tuple[jnp.ndarray, jnp.ndarray]: return y, b @property - def flattened_y(self) -> jnp.ndarray: + def flattened_y(self) -> jax.Array: """Array of shape ``[num_measures * (N_1 + N_2 + ...), ndim]``.""" if self._is_segmented: return self._y.reshape((-1, self._y.shape[-1])) return self._y @property - def flattened_b(self) -> Optional[jnp.ndarray]: + def flattened_b(self) -> Optional[jax.Array]: """Array of shape ``[num_measures * (N_1 + N_2 + ...),]``.""" return None if self._b is None else self._b.ravel() @@ -121,7 +121,7 @@ def ndim(self) -> int: return self._y.shape[-1] @property - def weights(self) -> jnp.ndarray: + def weights(self) -> jax.Array: """Barycenter weights of shape ``[num_measures,]`` that sum to 1.""" if self._weights is None: return jnp.ones((self.num_measures,)) / self.num_measures @@ -165,8 +165,8 @@ class FixedBarycenterProblem: def __init__( self, geom: geometry.Geometry, - a: jnp.ndarray, - weights: Optional[jnp.ndarray] = None, + a: jax.Array, + weights: Optional[jax.Array] = None, ): self.geom = geom self.a = a @@ -178,7 +178,7 @@ def num_measures(self) -> int: return self.a.shape[0] @property - def weights(self) -> jnp.ndarray: + def weights(self) -> jax.Array: """Barycenter weights of shape ``[num_measures,]`` that sum to :math`1`.""" if self._weights is None: return jnp.ones((self.num_measures,)) / self.num_measures diff --git a/src/ott/problems/linear/linear_problem.py b/src/ott/problems/linear/linear_problem.py index 7c206aa63..3e09c0e59 100644 --- a/src/ott/problems/linear/linear_problem.py +++ b/src/ott/problems/linear/linear_problem.py @@ -21,9 +21,8 @@ __all__ = ["LinearProblem"] # TODO(michalk8): move to typing.py when refactoring the types -MarginalFunc = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] -TransportAppFunc = Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray, int], - jnp.ndarray] +MarginalFunc = Callable[[jax.Array, jax.Array], jax.Array] +TransportAppFunc = 
Callable[[jax.Array, jax.Array, jax.Array, int], jax.Array] @jax.tree_util.register_pytree_node_class @@ -50,8 +49,8 @@ class LinearProblem: def __init__( self, geom: geometry.Geometry, - a: Optional[jnp.ndarray] = None, - b: Optional[jnp.ndarray] = None, + a: Optional[jax.Array] = None, + b: Optional[jax.Array] = None, tau_a: float = 1.0, tau_b: float = 1.0 ): @@ -62,13 +61,13 @@ def __init__( self.tau_b = tau_b @property - def a(self) -> jnp.ndarray: + def a(self) -> jax.Array: """First marginal.""" num_a = self.geom.shape[0] return jnp.ones((num_a,)) / num_a if self._a is None else self._a @property - def b(self) -> jnp.ndarray: + def b(self) -> jax.Array: """Second marginal.""" num_b = self.geom.shape[1] return jnp.ones((num_b,)) / num_b if self._b is None else self._b diff --git a/src/ott/problems/linear/potentials.py b/src/ott/problems/linear/potentials.py index 7ab226072..718aa22a1 100644 --- a/src/ott/problems/linear/potentials.py +++ b/src/ott/problems/linear/potentials.py @@ -37,7 +37,7 @@ mpl = plt = None __all__ = ["DualPotentials", "EntropicPotentials"] -Potential_t = Callable[[jnp.ndarray], float] +Potential_t = Callable[[jax.Array], float] @jtu.register_pytree_node_class @@ -72,7 +72,7 @@ def __init__( self.cost_fn = cost_fn self._corr = corr - def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray: + def transport(self, vec: jax.Array, forward: bool = True) -> jax.Array: r"""Transport ``vec`` according to Brenier formula :cite:`brenier:91`. Uses Theorem 1.17 from :cite:`santambrogio:15` to compute an OT map when @@ -105,7 +105,7 @@ def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray: return vec - self._grad_h_inv(self._grad_f(vec)) return vec - self._grad_h_inv(self._grad_g(vec)) - def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float: + def distance(self, src: jax.Array, tgt: jax.Array) -> float: r"""Evaluate Wasserstein distance between samples using dual potentials. 
This uses direct estimation of potentials against measures when dual @@ -146,17 +146,17 @@ def g(self) -> Potential_t: return self._g @property - def _grad_f(self) -> Callable[[jnp.ndarray], jnp.ndarray]: + def _grad_f(self) -> Callable[[jax.Array], jax.Array]: """Vectorized gradient of the potential function :attr:`f`.""" return jax.vmap(jax.grad(self.f, argnums=0)) @property - def _grad_g(self) -> Callable[[jnp.ndarray], jnp.ndarray]: + def _grad_g(self) -> Callable[[jax.Array], jax.Array]: """Vectorized gradient of the potential function :attr:`g`.""" return jax.vmap(jax.grad(self.g, argnums=0)) @property - def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]: + def _grad_h_inv(self) -> Callable[[jax.Array], jax.Array]: from ott.geometry import costs assert isinstance(self.cost_fn, costs.TICost), ( @@ -181,9 +181,9 @@ def tree_unflatten( # noqa: D102 def plot_ot_map( self, - source: jnp.ndarray, - target: jnp.ndarray, - samples: Optional[jnp.ndarray] = None, + source: jax.Array, + target: jax.Array, + samples: Optional[jax.Array] = None, forward: bool = True, ax: Optional["plt.Axes"] = None, legend_kwargs: Optional[Dict[str, Any]] = None, @@ -348,11 +348,11 @@ class EntropicPotentials(DualPotentials): def __init__( self, - f_xy: jnp.ndarray, - g_xy: jnp.ndarray, + f_xy: jax.Array, + g_xy: jax.Array, prob: linear_problem.LinearProblem, - f_xx: Optional[jnp.ndarray] = None, - g_yy: Optional[jnp.ndarray] = None, + f_xx: Optional[jax.Array] = None, + g_yy: Optional[jax.Array] = None, ): # we pass directly the arrays and override the properties # since only the properties need to be callable @@ -373,11 +373,11 @@ def _potential_fn(self, *, kind: Literal["f", "g"]) -> Potential_t: from ott.geometry import pointcloud def callback( - x: jnp.ndarray, + x: jax.Array, *, - potential: jnp.ndarray, - y: jnp.ndarray, - weights: jnp.ndarray, + potential: jax.Array, + y: jax.Array, + weights: jax.Array, epsilon: float, ) -> float: x = jnp.atleast_2d(x) diff --git 
a/src/ott/problems/quadratic/gw_barycenter.py b/src/ott/problems/quadratic/gw_barycenter.py index dfe562d98..7170f1064 100644 --- a/src/ott/problems/quadratic/gw_barycenter.py +++ b/src/ott/problems/quadratic/gw_barycenter.py @@ -60,11 +60,11 @@ class GWBarycenterProblem(barycenter_problem.FreeBarycenterProblem): def __init__( self, - y: Optional[jnp.ndarray] = None, - b: Optional[jnp.ndarray] = None, - weights: Optional[jnp.ndarray] = None, - costs: Optional[jnp.ndarray] = None, - y_fused: Optional[jnp.ndarray] = None, + y: Optional[jax.Array] = None, + b: Optional[jax.Array] = None, + weights: Optional[jax.Array] = None, + costs: Optional[jax.Array] = None, + y_fused: Optional[jax.Array] = None, fused_penalty: float = 1.0, gw_loss: Literal["sqeucl", "kl"] = "sqeucl", scale_cost: Union[int, float, Literal["mean", "max_cost"]] = 1.0, @@ -98,9 +98,7 @@ def __init__( # TODO(michalk8): in the future, consider checking the other 2 cases # using `segmented_y` and `segmented_y_fused`? - def update_barycenter( - self, transports: jnp.ndarray, a: jnp.ndarray - ) -> jnp.ndarray: + def update_barycenter(self, transports: jax.Array, a: jax.Array) -> jax.Array: """Update the barycenter cost matrix. Uses the eq. 14 and 15 of :cite:`peyre:16`. @@ -116,11 +114,11 @@ def update_barycenter( @functools.partial(jax.vmap, in_axes=[0, 0, 0, None]) def project( - y: jnp.ndarray, - b: jnp.ndarray, - transport: jnp.ndarray, + y: jax.Array, + b: jax.Array, + transport: jax.Array, fn: Optional[quadratic_costs.Loss], - ) -> jnp.ndarray: + ) -> jax.Array: geom = self._create_y_geometry(y, mask=b > 0.) 
fn, lin = (None, True) if fn is None else (fn.func, fn.is_linear) @@ -146,8 +144,8 @@ def project( return jnp.exp(barycenter) return barycenter - def update_features(self, transports: jnp.ndarray, - a: jnp.ndarray) -> Optional[jnp.ndarray]: + def update_features(self, transports: jax.Array, + a: jax.Array) -> Optional[jax.Array]: """Update the barycenter features in the fused case :cite:`vayer:19`. Uses :cite:`cuturi:14` eq. 8, and is implemented only @@ -181,8 +179,8 @@ def update_features(self, transports: jnp.ndarray, def _create_bary_geometry( self, - cost_matrix: jnp.ndarray, - mask: Optional[jnp.ndarray] = None + cost_matrix: jax.Array, + mask: Optional[jax.Array] = None ) -> geometry.Geometry: return geometry.Geometry( cost_matrix=cost_matrix, @@ -194,8 +192,8 @@ def _create_bary_geometry( def _create_y_geometry( self, - y: jnp.ndarray, - mask: Optional[jnp.ndarray] = None + y: jax.Array, + mask: Optional[jax.Array] = None ) -> geometry.Geometry: if self._y_as_costs: assert y.shape[0] == y.shape[1], y.shape @@ -217,10 +215,10 @@ def _create_y_geometry( def _create_fused_geometry( self, - x: jnp.ndarray, - y: jnp.ndarray, - src_mask: Optional[jnp.ndarray] = None, - tgt_mask: Optional[jnp.ndarray] = None + x: jax.Array, + y: jax.Array, + src_mask: Optional[jax.Array] = None, + tgt_mask: Optional[jax.Array] = None ) -> pointcloud.PointCloud: return pointcloud.PointCloud( x, @@ -235,9 +233,9 @@ def _create_fused_geometry( def _create_problem( self, state: "GWBarycenterState", # noqa: F821 - y: jnp.ndarray, - b: jnp.ndarray, - f: Optional[jnp.ndarray] = None + y: jax.Array, + b: jax.Array, + f: Optional[jax.Array] = None ) -> quadratic_problem.QuadraticProblem: # TODO(michalk8): in future, mask in the problem for convenience? bary_mask = state.a > 0. 
@@ -269,7 +267,7 @@ def is_fused(self) -> bool: return self._y_fused is not None @property - def segmented_y_fused(self) -> Optional[jnp.ndarray]: + def segmented_y_fused(self) -> Optional[jax.Array]: """Feature array of shape used in the fused case.""" if not self.is_fused or self._y_fused.ndim == 3: return self._y_fused diff --git a/src/ott/problems/quadratic/quadratic_costs.py b/src/ott/problems/quadratic/quadratic_costs.py index 70f2bf5ad..060c3c537 100644 --- a/src/ott/problems/quadratic/quadratic_costs.py +++ b/src/ott/problems/quadratic/quadratic_costs.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Callable, NamedTuple +import jax import jax.numpy as jnp import jax.scipy as jsp @@ -20,7 +21,7 @@ class Loss(NamedTuple): # noqa: D101 - func: Callable[[jnp.ndarray], jnp.ndarray] + func: Callable[[jax.Array], jax.Array] is_linear: bool diff --git a/src/ott/problems/quadratic/quadratic_problem.py b/src/ott/problems/quadratic/quadratic_problem.py index 5deb4558c..cf5b804a2 100644 --- a/src/ott/problems/quadratic/quadratic_problem.py +++ b/src/ott/problems/quadratic/quadratic_problem.py @@ -91,8 +91,8 @@ def __init__( geom_xy: Optional[geometry.Geometry] = None, fused_penalty: float = 1.0, scale_cost: Optional[Union[bool, float, str]] = False, - a: Optional[jnp.ndarray] = None, - b: Optional[jnp.ndarray] = None, + a: Optional[jax.Array] = None, + b: Optional[jax.Array] = None, loss: Union[Literal["sqeucl", "kl"], quadratic_costs.GWLoss] = "sqeucl", tau_a: float = 1.0, tau_b: float = 1.0, @@ -125,8 +125,8 @@ def __init__( def marginal_dependent_cost( self, - marginal_1: jnp.ndarray, - marginal_2: jnp.ndarray, + marginal_1: jax.Array, + marginal_2: jax.Array, ) -> low_rank.LRCGeometry: r"""Initialize cost term that depends on the marginals of the transport. 
@@ -169,9 +169,9 @@ def marginal_dependent_cost( def cost_unbalanced_correction( self, - transport_matrix: jnp.ndarray, - marginal_1: jnp.ndarray, - marginal_2: jnp.ndarray, + transport_matrix: jax.Array, + marginal_1: jax.Array, + marginal_2: jax.Array, epsilon: epsilon_scheduler.Epsilon, ) -> float: r"""Calculate cost term from the quadratic divergence when unbalanced. @@ -193,10 +193,10 @@ def cost_unbalanced_correction( :math:`+ epsilon * \sum(KL(P|ab'))` Args: - transport_matrix: jnp.ndarray[num_a, num_b], transport matrix. - marginal_1: jnp.ndarray[num_a,], marginal of the transport matrix + transport_matrix: jax.Array[num_a, num_b], transport matrix. + marginal_1: jax.Array[num_a,], marginal of the transport matrix for samples from :attr:`geom_xx`. - marginal_2: jnp.ndarray[num_b,], marginal of the transport matrix + marginal_2: jax.Array[num_b,], marginal of the transport matrix for samples from :attr:`geom_yy`. epsilon: entropy regularizer. @@ -353,7 +353,7 @@ def update_lr_linearization( ) @property - def _fused_cost_matrix(self) -> Union[float, jnp.ndarray]: + def _fused_cost_matrix(self) -> Union[float, jax.Array]: if not self.is_fused: return 0.0 geom_xy = self.geom_xy @@ -442,13 +442,13 @@ def geom_xy(self) -> Optional[geometry.Geometry]: return self._geom_xy @property - def a(self) -> jnp.ndarray: + def a(self) -> jax.Array: """First marginal.""" num_a = self.geom_xx.shape[0] return jnp.ones((num_a,)) / num_a if self._a is None else self._a @property - def b(self) -> jnp.ndarray: + def b(self) -> jax.Array: """Second marginal.""" num_b = self.geom_yy.shape[0] return jnp.ones((num_b,)) / num_b if self._b is None else self._b @@ -510,7 +510,7 @@ def update_epsilon_unbalanced( # noqa: D103 def apply_cost( # noqa: D103 - geom: geometry.Geometry, arr: jnp.ndarray, *, axis: int, + geom: geometry.Geometry, arr: jax.Array, *, axis: int, fn: quadratic_costs.Loss -) -> jnp.ndarray: +) -> jax.Array: return geom.apply_cost(arr, axis=axis, fn=fn.func, 
is_linear=fn.is_linear) diff --git a/src/ott/solvers/linear/_solve.py b/src/ott/solvers/linear/_solve.py index 2bca6a825..fad5a4e7d 100644 --- a/src/ott/solvers/linear/_solve.py +++ b/src/ott/solvers/linear/_solve.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Optional, Union -import jax.numpy as jnp +import jax from ott.geometry import geometry from ott.problems.linear import linear_problem @@ -24,8 +24,8 @@ def solve( geom: geometry.Geometry, - a: Optional[jnp.ndarray] = None, - b: Optional[jnp.ndarray] = None, + a: Optional[jax.Array] = None, + b: Optional[jax.Array] = None, tau_a: float = 1.0, tau_b: float = 1.0, rank: int = -1, diff --git a/src/ott/solvers/linear/acceleration.py b/src/ott/solvers/linear/acceleration.py index 4529e7f78..7ce602194 100644 --- a/src/ott/solvers/linear/acceleration.py +++ b/src/ott/solvers/linear/acceleration.py @@ -34,7 +34,7 @@ class AndersonAcceleration: refresh_every: int = 1 # Recompute interpolation periodically. ridge_identity: float = 1e-2 # Ridge used in the linear system. - def extrapolation(self, xs: jnp.ndarray, fxs: jnp.ndarray) -> jnp.ndarray: + def extrapolation(self, xs: jax.Array, fxs: jax.Array) -> jax.Array: """Compute Anderson extrapolation from past observations.""" # Remove -inf values to instantiate quadratic problem. All others # remain since they might be caused by a valid issue. 
@@ -161,10 +161,10 @@ def lehmann(self, state: "sinkhorn.SinkhornState") -> float: def __call__( # noqa: D102 self, weight: float, - value: jnp.ndarray, - new_value: jnp.ndarray, + value: jax.Array, + new_value: jax.Array, lse_mode: bool = True - ) -> jnp.ndarray: + ) -> jax.Array: if lse_mode: value = jnp.where(jnp.isfinite(value), value, 0.0) return (1.0 - weight) * value + weight * new_value diff --git a/src/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py index 2d89a74ea..0094c3a3c 100644 --- a/src/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -41,11 +41,11 @@ class FreeBarycenterState(NamedTuple): a: barycenter weights. """ - costs: Optional[jnp.ndarray] = None - linear_convergence: Optional[jnp.ndarray] = None - errors: Optional[jnp.ndarray] = None - x: Optional[jnp.ndarray] = None - a: Optional[jnp.ndarray] = None + costs: Optional[jax.Array] = None + linear_convergence: Optional[jax.Array] = None + errors: Optional[jax.Array] = None + x: Optional[jax.Array] = None + a: Optional[jax.Array] = None def set(self, **kwargs: Any) -> "FreeBarycenterState": """Return a copy of self, possibly with overwrites.""" @@ -70,7 +70,7 @@ def update( @functools.partial(jax.vmap, in_axes=[None, None, 0, 0]) def solve_linear_ot( - a: Optional[jnp.ndarray], x: jnp.ndarray, b: jnp.ndarray, y: jnp.ndarray + a: Optional[jax.Array], x: jax.Array, b: jax.Array, y: jax.Array ): out = linear_ot_solver( linear_problem.LinearProblem( @@ -129,7 +129,7 @@ def __call__( # noqa: D102 self, bar_prob: barycenter_problem.FreeBarycenterProblem, bar_size: int = 100, - x_init: Optional[jnp.ndarray] = None, + x_init: Optional[jax.Array] = None, rng: Optional[jax.Array] = None, ) -> FreeBarycenterState: # TODO(michalk8): no reason for iterations to be outside this class @@ -140,7 +140,7 @@ def init_state( self, bar_prob: barycenter_problem.FreeBarycenterProblem, bar_size: int, - x_init: 
Optional[jnp.ndarray] = None, + x_init: Optional[jax.Array] = None, rng: Optional[jax.Array] = None, ) -> FreeBarycenterState: """Initialize the state of the Wasserstein barycenter iterations. @@ -195,7 +195,7 @@ def output_from_state( # noqa: D102 def iterations( solver: FreeWassersteinBarycenter, bar_size: int, - bar_prob: barycenter_problem.FreeBarycenterProblem, x_init: jnp.ndarray, + bar_prob: barycenter_problem.FreeBarycenterProblem, x_init: jax.Array, rng: jax.Array ) -> FreeBarycenterState: """Jittable Wasserstein barycenter outer loop.""" diff --git a/src/ott/solvers/linear/discrete_barycenter.py b/src/ott/solvers/linear/discrete_barycenter.py index dcfdc1470..85adaa795 100644 --- a/src/ott/solvers/linear/discrete_barycenter.py +++ b/src/ott/solvers/linear/discrete_barycenter.py @@ -26,10 +26,10 @@ class SinkhornBarycenterOutput(NamedTuple): # noqa: D101 - f: jnp.ndarray - g: jnp.ndarray - histogram: jnp.ndarray - errors: jnp.ndarray + f: jax.Array + g: jax.Array + histogram: jax.Array + errors: jax.Array @jax.tree_util.register_pytree_node_class @@ -79,7 +79,7 @@ def __init__( def __call__( self, fixed_bp: barycenter_problem.FixedBarycenterProblem, - dual_initialization: Optional[jnp.ndarray] = None, + dual_initialization: Optional[jax.Array] = None, ) -> SinkhornBarycenterOutput: """Solve barycenter problem, possibly using clever initialization. 
@@ -128,10 +128,10 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102 @functools.partial(jax.jit, static_argnums=(5, 6, 7, 8, 9, 10, 11, 12)) def _discrete_barycenter( - geom: geometry.Geometry, a: jnp.ndarray, weights: jnp.ndarray, - dual_initialization: jnp.ndarray, threshold: float, - norm_error: Sequence[int], inner_iterations: int, min_iterations: int, - max_iterations: int, lse_mode: bool, debiased: bool, num_a: int, num_b: int + geom: geometry.Geometry, a: jax.Array, weights: jax.Array, + dual_initialization: jax.Array, threshold: float, norm_error: Sequence[int], + inner_iterations: int, min_iterations: int, max_iterations: int, + lse_mode: bool, debiased: bool, num_a: int, num_b: int ) -> SinkhornBarycenterOutput: """Jit'able function to compute discrete barycenters.""" if lse_mode: diff --git a/src/ott/solvers/linear/implicit_differentiation.py b/src/ott/solvers/linear/implicit_differentiation.py index fbf98ce81..c5e7cb0f3 100644 --- a/src/ott/solvers/linear/implicit_differentiation.py +++ b/src/ott/solvers/linear/implicit_differentiation.py @@ -23,9 +23,8 @@ if TYPE_CHECKING: from ott.problems.linear import linear_problem -LinOp_t = Callable[[jnp.ndarray], jnp.ndarray] -Solver_t = Callable[[LinOp_t, jnp.ndarray, Optional[LinOp_t], bool], - jnp.ndarray] +LinOp_t = Callable[[jax.Array], jax.Array] +Solver_t = Callable[[LinOp_t, jax.Array, Optional[LinOp_t], bool], jax.Array] __all__ = ["ImplicitDiff", "solve_jax_cg"] @@ -70,16 +69,16 @@ class ImplicitDiff: solver: Optional[Solver_t] = None solver_kwargs: Optional[Dict[str, Any]] = None symmetric: bool = False - precondition_fun: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None + precondition_fun: Optional[Callable[[jax.Array], jax.Array]] = None def solve( self, - gr: Tuple[jnp.ndarray, jnp.ndarray], + gr: Tuple[jax.Array, jax.Array], ot_prob: "linear_problem.LinearProblem", - f: jnp.ndarray, - g: jnp.ndarray, + f: jax.Array, + g: jax.Array, lse_mode: bool, - ) -> jnp.ndarray: + ) -> 
jax.Array: r"""Apply minus inverse of [hessian ``reg_ot_cost`` w.r.t. ``f``, ``g``]. This function is used to carry out implicit differentiation of ``sinkhorn`` @@ -224,7 +223,7 @@ def solve( return jnp.concatenate((-vjp_gr_f, -vjp_gr_g)) def first_order_conditions( - self, prob, f: jnp.ndarray, g: jnp.ndarray, lse_mode: bool + self, prob, f: jax.Array, g: jax.Array, lse_mode: bool ): r"""Compute vector of first order conditions for the reg-OT problem. @@ -238,12 +237,12 @@ def first_order_conditions( Args: prob: definition of the linear optimal transport problem. - f: jnp.ndarray, first potential - g: jnp.ndarray, second potential + f: jax.Array, first potential + g: jax.Array, second potential lse_mode: bool Returns: - a jnp.ndarray of size (size of ``n + m``) quantifying deviation to + a jax.Array of size (size of ``n + m``) quantifying deviation to optimality for variables ``f`` and ``g``. """ geom = prob.geom @@ -266,8 +265,8 @@ def first_order_conditions( return jnp.concatenate((result_a, result_b)) def gradient( - self, prob: "linear_problem.LinearProblem", f: jnp.ndarray, - g: jnp.ndarray, lse_mode: bool, gr: Tuple[jnp.ndarray, jnp.ndarray] + self, prob: "linear_problem.LinearProblem", f: jax.Array, g: jax.Array, + lse_mode: bool, gr: Tuple[jax.Array, jax.Array] ) -> "linear_problem.LinearProblem": """Apply VJP to recover gradient in reverse mode differentiation.""" # Applies first part of vjp to gr: inverse part of implicit function theorem @@ -287,13 +286,13 @@ def replace(self, **kwargs: Any) -> "ImplicitDiff": # noqa: D102 def solve_jax_cg( lin: LinOp_t, - b: jnp.ndarray, + b: jax.Array, lin_t: Optional[LinOp_t] = None, symmetric: bool = False, ridge_identity: float = 0.0, ridge_kernel: float = 0.0, **kwargs: Any -) -> jnp.ndarray: +) -> jax.Array: """Wrapper around JAX native linear solvers. 
Args: diff --git a/src/ott/solvers/linear/lineax_implicit.py b/src/ott/solvers/linear/lineax_implicit.py index 79b9e7c95..ac3978462 100644 --- a/src/ott/solvers/linear/lineax_implicit.py +++ b/src/ott/solvers/linear/lineax_implicit.py @@ -46,14 +46,14 @@ def transpose(self): def solve_lineax( lin: Callable, - b: jnp.ndarray, + b: jax.Array, lin_t: Optional[Callable] = None, symmetric: bool = False, nonsym_solver: Optional[lx.AbstractLinearSolver] = None, ridge_identity: float = 0.0, ridge_kernel: float = 0.0, **kwargs: Any -) -> jnp.ndarray: +) -> jax.Array: """Wrapper around lineax solvers. Args: diff --git a/src/ott/solvers/linear/lr_utils.py b/src/ott/solvers/linear/lr_utils.py index 8ade265c9..2eb4c32ed 100644 --- a/src/ott/solvers/linear/lr_utils.py +++ b/src/ott/solvers/linear/lr_utils.py @@ -24,27 +24,27 @@ class State(NamedTuple): # noqa: D101 - v1: jnp.ndarray - v2: jnp.ndarray - u1: jnp.ndarray - u2: jnp.ndarray - g: jnp.ndarray + v1: jax.Array + v2: jax.Array + u1: jax.Array + u2: jax.Array + g: jax.Array err: float class Constants(NamedTuple): # noqa: D101 - a: jnp.ndarray - b: jnp.ndarray + a: jax.Array + b: jax.Array rho_a: float rho_b: float - supp_a: Optional[jnp.ndarray] = None - supp_b: Optional[jnp.ndarray] = None + supp_a: Optional[jax.Array] = None + supp_b: Optional[jax.Array] = None def unbalanced_dykstra_lse( - c_q: jnp.ndarray, - c_r: jnp.ndarray, - c_g: jnp.ndarray, + c_q: jax.Array, + c_r: jax.Array, + c_g: jax.Array, gamma: float, ot_prob: linear_problem.LinearProblem, translation_invariant: bool = True, @@ -52,7 +52,7 @@ def unbalanced_dykstra_lse( min_iter: int = 0, inner_iter: int = 10, max_iter: int = 10000 -) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: +) -> Tuple[jax.Array, jax.Array, jax.Array]: """Dykstra's algorithm for the unbalanced :class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` in LSE mode. 
@@ -74,10 +74,10 @@ def unbalanced_dykstra_lse( """ # noqa: D205 def _softm( - v: jnp.ndarray, - c: jnp.ndarray, + v: jax.Array, + c: jax.Array, axis: int, - ) -> jnp.ndarray: + ) -> jax.Array: v = jnp.expand_dims(v, axis=1 - axis) return jsp.special.logsumexp(v + c, axis=axis) @@ -181,9 +181,9 @@ def body_fn( def unbalanced_dykstra_kernel( - k_q: jnp.ndarray, - k_r: jnp.ndarray, - k_g: jnp.ndarray, + k_q: jax.Array, + k_r: jax.Array, + k_g: jax.Array, gamma: float, ot_prob: linear_problem.LinearProblem, translation_invariant: bool = True, @@ -191,7 +191,7 @@ def unbalanced_dykstra_kernel( min_iter: int = 0, inner_iter: int = 10, max_iter: int = 10000 -) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: +) -> Tuple[jax.Array, jax.Array, jax.Array]: """Dykstra's algorithm for the unbalanced :class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` in kernel mode. @@ -317,7 +317,7 @@ def body_fn( def compute_lambdas( - const: Constants, state: State, gamma: float, g: jnp.ndarray, *, + const: Constants, state: State, gamma: float, g: jax.Array, *, lse_mode: bool ) -> Tuple[float, float]: """TODO.""" diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index 7ab17e870..44afe1833 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -52,11 +52,11 @@ class SinkhornState(NamedTuple): """Holds the state variables used to solve OT with Sinkhorn.""" - errors: Optional[jnp.ndarray] = None - fu: Optional[jnp.ndarray] = None - gv: Optional[jnp.ndarray] = None - old_fus: Optional[jnp.ndarray] = None - old_mapped_fus: Optional[jnp.ndarray] = None + errors: Optional[jax.Array] = None + fu: Optional[jax.Array] = None + gv: Optional[jax.Array] = None + old_fus: Optional[jax.Array] = None + old_mapped_fus: Optional[jax.Array] = None def set(self, **kwargs: Any) -> "SinkhornState": """Return a copy of self, with potential overwrites.""" @@ -70,7 +70,7 @@ def solution_error( lse_mode: bool, parallel_dual_updates: bool, 
recenter: bool, - ) -> jnp.ndarray: + ) -> jax.Array: """State dependent function to return error.""" fu, gv = self.fu, self.gv if recenter and lse_mode: @@ -92,10 +92,10 @@ def compute_kl_reg_cost( # noqa: D102 def recenter( self, - f: jnp.ndarray, - g: jnp.ndarray, + f: jax.Array, + g: jax.Array, ot_prob: linear_problem.LinearProblem, - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array]: """Re-center dual potentials. If the ``ot_prob`` is balanced, the ``f`` potential is zero-centered. @@ -132,14 +132,14 @@ def recenter( def solution_error( - f_u: jnp.ndarray, - g_v: jnp.ndarray, + f_u: jax.Array, + g_v: jax.Array, ot_prob: linear_problem.LinearProblem, *, norm_error: Sequence[int], lse_mode: bool, parallel_dual_updates: bool, -) -> jnp.ndarray: +) -> jax.Array: """Given two potential/scaling solutions, computes deviation to optimality. When the ``ot_prob`` problem is balanced and the usual Sinkhorn updates are @@ -153,8 +153,8 @@ def solution_error( additional quantities to qualify optimality must be taken into account. Args: - f_u: jnp.ndarray, potential or scaling - g_v: jnp.ndarray, potential or scaling + f_u: jax.Array, potential or scaling + g_v: jax.Array, potential or scaling ot_prob: linear OT problem norm_error: int, p-norm used to compute error. lse_mode: True if log-sum-exp operations, False if kernel vector products. @@ -196,9 +196,9 @@ def solution_error( def marginal_error( - f_u: jnp.ndarray, - g_v: jnp.ndarray, - target: jnp.ndarray, + f_u: jax.Array, + g_v: jax.Array, + target: jax.Array, geom: geometry.Geometry, axis: int = 0, norm_error: Sequence[int] = (1,), @@ -229,7 +229,7 @@ def marginal_error( def compute_kl_reg_cost( - f: jnp.ndarray, g: jnp.ndarray, ot_prob: linear_problem.LinearProblem, + f: jax.Array, g: jax.Array, ot_prob: linear_problem.LinearProblem, lse_mode: bool ) -> float: r"""Compute objective of Sinkhorn for OT problem given dual solutions. 
@@ -243,8 +243,8 @@ def compute_kl_reg_cost( values, ``jnp.where`` is used to cancel these contributions. Args: - f: jnp.ndarray, potential - g: jnp.ndarray, potential + f: jax.Array, potential + g: jax.Array, potential ot_prob: linear optimal transport problem. lse_mode: bool, whether to compute total mass in lse or kernel mode. @@ -320,12 +320,12 @@ class SinkhornOutput(NamedTuple): computations of errors. """ - f: Optional[jnp.ndarray] = None - g: Optional[jnp.ndarray] = None - errors: Optional[jnp.ndarray] = None + f: Optional[jax.Array] = None + g: Optional[jax.Array] = None + errors: Optional[jax.Array] = None reg_ot_cost: Optional[float] = None ot_prob: Optional[linear_problem.LinearProblem] = None - threshold: Optional[jnp.ndarray] = None + threshold: Optional[jax.Array] = None converged: Optional[bool] = None inner_iterations: Optional[int] = None @@ -342,7 +342,7 @@ def set_cost( # noqa: D102 return self.set(reg_ot_cost=compute_kl_reg_cost(f, g, ot_prob, lse_mode)) @property - def dual_cost(self) -> jnp.ndarray: + def dual_cost(self) -> jax.Array: """Return dual transport cost, without considering regularizer.""" a, b = self.ot_prob.a, self.ot_prob.b dual_cost = jnp.sum(jnp.where(a > 0.0, a * self.f, 0)) @@ -399,9 +399,7 @@ def kl_reg_cost(self) -> float: """ return self.reg_ot_cost - def transport_cost_at_geom( - self, other_geom: geometry.Geometry - ) -> jnp.ndarray: + def transport_cost_at_geom(self, other_geom: geometry.Geometry) -> jax.Array: r"""Return bare transport cost of current solution at any geometry. 
In order to compute cost, we check first if the geometry can be converted @@ -428,11 +426,11 @@ def geom(self) -> geometry.Geometry: # noqa: D102 return self.ot_prob.geom @property - def a(self) -> jnp.ndarray: # noqa: D102 + def a(self) -> jax.Array: # noqa: D102 return self.ot_prob.a @property - def b(self) -> jnp.ndarray: # noqa: D102 + def b(self) -> jax.Array: # noqa: D102 return self.ot_prob.b @property @@ -441,13 +439,13 @@ def n_iters(self) -> int: # noqa: D102 return jnp.sum(self.errors != -1) * self.inner_iterations @property - def scalings(self) -> Tuple[jnp.ndarray, jnp.ndarray]: # noqa: D102 + def scalings(self) -> Tuple[jax.Array, jax.Array]: # noqa: D102 u = self.ot_prob.geom.scaling_from_potential(self.f) v = self.ot_prob.geom.scaling_from_potential(self.g) return u, v @property - def matrix(self) -> jnp.ndarray: + def matrix(self) -> jax.Array: """Transport matrix if it can be instantiated.""" try: return self.ot_prob.geom.transport_from_potentials(self.f, self.g) @@ -459,13 +457,13 @@ def transport_mass(self) -> float: """Sum of transport matrix.""" return self.marginal(0).sum() - def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray: + def apply(self, inputs: jax.Array, axis: int = 0) -> jax.Array: """Apply the transport to a ndarray; axis=1 for its transpose.""" return self.ot_prob.geom.apply_transport_from_potentials( self.f, self.g, inputs, axis=axis ) - def marginal(self, axis: int) -> jnp.ndarray: # noqa: D102 + def marginal(self, axis: int) -> jax.Array: # noqa: D102 return self.ot_prob.geom.marginal_from_potentials(self.f, self.g, axis=axis) def cost_at_geom(self, other_geom: geometry.Geometry) -> float: @@ -832,7 +830,7 @@ def __init__( def __call__( self, ot_prob: linear_problem.LinearProblem, - init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]] = (None, None), + init: Tuple[Optional[jax.Array], Optional[jax.Array]] = (None, None), rng: Optional[jax.Array] = None, ) -> SinkhornOutput: """Run Sinkhorn algorithm. 
@@ -868,9 +866,7 @@ def xi(tau_i: float, tau_j: float) -> float: k_ij = k(tau_i, tau_j) return k_ij / (1. - k_ij) - def smin( - potential: jnp.ndarray, marginal: jnp.ndarray, tau: float - ) -> float: + def smin(potential: jax.Array, marginal: jax.Array, tau: float) -> float: rho = uf.rho(ot_prob.epsilon, tau) return -rho * mu.logsumexp(-potential / rho, b=marginal) @@ -1015,8 +1011,8 @@ def outer_iterations(self) -> int: return np.ceil(self.max_iterations / self.inner_iterations).astype(int) def init_state( - self, ot_prob: linear_problem.LinearProblem, init: Tuple[jnp.ndarray, - jnp.ndarray] + self, ot_prob: linear_problem.LinearProblem, init: Tuple[jax.Array, + jax.Array] ) -> SinkhornState: """Return the initial state of the loop.""" fu, gv = init @@ -1124,7 +1120,7 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102 def run( ot_prob: linear_problem.LinearProblem, solver: Sinkhorn, - init: Tuple[jnp.ndarray, ...] + init: Tuple[jax.Array, ...] ) -> SinkhornOutput: """Run loop of the solver, outputting a state upgraded to an output.""" iter_fun = _iterations_implicit if solver.implicit_diff else iterations @@ -1137,7 +1133,7 @@ def run( def iterations( ot_prob: linear_problem.LinearProblem, solver: Sinkhorn, - init: Tuple[jnp.ndarray, ...] + init: Tuple[jax.Array, ...] ) -> SinkhornOutput: """Jittable Sinkhorn loop. args contain initialization variables.""" @@ -1174,8 +1170,8 @@ def body_fn( def _iterations_taped( ot_prob: linear_problem.LinearProblem, solver: Sinkhorn, - init: Tuple[jnp.ndarray, ...] -) -> Tuple[SinkhornOutput, Tuple[jnp.ndarray, jnp.ndarray, + init: Tuple[jax.Array, ...] +) -> Tuple[SinkhornOutput, Tuple[jax.Array, jax.Array, linear_problem.LinearProblem, Sinkhorn]]: """Run forward pass of the Sinkhorn algorithm storing side information.""" state = iterations(ot_prob, solver, init) @@ -1194,7 +1190,7 @@ def _iterations_implicit_bwd(res, gr): considered. 
Returns: - a tuple of gradients: PyTree for geom, one jnp.ndarray for each of a and b. + a tuple of gradients: PyTree for geom, one jax.Array for each of a and b. """ f, g, ot_prob, solver = res gr = gr[:2] diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index db948cf8b..b6732f76f 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -43,12 +43,12 @@ class LRSinkhornState(NamedTuple): """State of the Low Rank Sinkhorn algorithm.""" - q: jnp.ndarray - r: jnp.ndarray - g: jnp.ndarray + q: jax.Array + r: jax.Array + g: jax.Array gamma: float - costs: jnp.ndarray - errors: jnp.ndarray + costs: jax.Array + errors: jax.Array crossed_threshold: bool def compute_error( # noqa: D102 @@ -79,7 +79,7 @@ def reg_ot_cost( # noqa: D102 def solution_error( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, norm_error: Tuple[int, ...] - ) -> jnp.ndarray: + ) -> jax.Array: return solution_error(self.q, self.r, ot_prob, norm_error) def set(self, **kwargs: Any) -> "LRSinkhornState": @@ -88,9 +88,9 @@ def set(self, **kwargs: Any) -> "LRSinkhornState": def compute_reg_ot_cost( - q: jnp.ndarray, - r: jnp.ndarray, - g: jnp.ndarray, + q: jax.Array, + r: jax.Array, + g: jax.Array, ot_prob: linear_problem.LinearProblem, epsilon: float, use_danskin: bool = False @@ -110,7 +110,7 @@ def compute_reg_ot_cost( regularized OT cost, the (primal) transport cost of the low-rank solution. """ - def ent(x: jnp.ndarray) -> float: + def ent(x: jax.Array) -> float: # generalized entropy return jnp.sum(jsp.special.entr(x) + x) @@ -131,9 +131,9 @@ def ent(x: jnp.ndarray) -> float: def solution_error( - q: jnp.ndarray, r: jnp.ndarray, ot_prob: linear_problem.LinearProblem, + q: jax.Array, r: jax.Array, ot_prob: linear_problem.LinearProblem, norm_error: Tuple[int, ...] -) -> jnp.ndarray: +) -> jax.Array: """Compute solution error. Since only balanced case is available for LR, this is marginal deviation. 
@@ -166,13 +166,13 @@ def solution_error( class LRSinkhornOutput(NamedTuple): """Transport interface for a low-rank Sinkhorn solution.""" - q: jnp.ndarray - r: jnp.ndarray - g: jnp.ndarray - costs: jnp.ndarray + q: jax.Array + r: jax.Array + g: jax.Array + costs: jax.Array # TODO(michalk8): must be called `errors`, because of `store_inner_errors` # in future, enforce via class hierarchy - errors: jnp.ndarray + errors: jax.Array ot_prob: linear_problem.LinearProblem epsilon: float inner_iterations: int @@ -211,11 +211,11 @@ def geom(self) -> geometry.Geometry: # noqa: D102 return self.ot_prob.geom @property - def a(self) -> jnp.ndarray: # noqa: D102 + def a(self) -> jax.Array: # noqa: D102 return self.ot_prob.a @property - def b(self) -> jnp.ndarray: # noqa: D102 + def b(self) -> jax.Array: # noqa: D102 return self.ot_prob.b @property @@ -229,17 +229,17 @@ def converged(self) -> bool: # noqa: D102 ) @property - def matrix(self) -> jnp.ndarray: + def matrix(self) -> jax.Array: """Transport matrix if it can be instantiated.""" return (self.q * self._inv_g) @ self.r.T - def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray: + def apply(self, inputs: jax.Array, axis: int = 0) -> jax.Array: """Apply the transport to a array; axis=1 for its transpose.""" q, r = (self.q, self.r) if axis == 1 else (self.r, self.q) # for `axis=0`: (batch, m), (m, r), (r,), (r, n) return ((inputs @ r) * self._inv_g) @ q.T - def marginal(self, axis: int) -> jnp.ndarray: # noqa: D102 + def marginal(self, axis: int) -> jax.Array: # noqa: D102 length = self.q.shape[0] if axis == 0 else self.r.shape[0] return self.apply(jnp.ones(length,), axis=axis) @@ -262,7 +262,7 @@ def transport_mass(self) -> float: return self.marginal(0).sum() @property - def _inv_g(self) -> jnp.ndarray: + def _inv_g(self) -> jax.Array: return 1. 
/ self.g @@ -341,8 +341,8 @@ def __init__( def __call__( self, ot_prob: linear_problem.LinearProblem, - init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], - Optional[jnp.ndarray]] = (None, None, None), + init: Tuple[Optional[jax.Array], Optional[jax.Array], + Optional[jax.Array]] = (None, None, None), rng: Optional[jax.Array] = None, **kwargs: Any, ) -> LRSinkhornOutput: @@ -371,7 +371,7 @@ def _get_costs( self, ot_prob: linear_problem.LinearProblem, state: LRSinkhornState, - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, float]: + ) -> Tuple[jax.Array, jax.Array, jax.Array, float]: log_q, log_r, log_g = ( mu.safe_log(state.q), mu.safe_log(state.r), mu.safe_log(state.g) ) @@ -407,9 +407,9 @@ def _get_costs( # TODO(michalk8): move to `lr_utils` when refactoring this def dykstra_update_lse( self, - c_q: jnp.ndarray, - c_r: jnp.ndarray, - h: jnp.ndarray, + c_q: jax.Array, + c_r: jax.Array, + h: jax.Array, gamma: float, ot_prob: linear_problem.LinearProblem, min_entry_value: float = 1e-6, @@ -417,7 +417,7 @@ def dykstra_update_lse( min_iter: int = 0, inner_iter: int = 10, max_iter: int = 10000 - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array, jax.Array]: """Run Dykstra's algorithm.""" # shortcuts for problem's definition. r = self.rank @@ -435,24 +435,24 @@ def dykstra_update_lse( constants = c_q, c_r, loga, logb def cond_fn( - iteration: int, constants: Tuple[jnp.ndarray, ...], - state_inner: Tuple[jnp.ndarray, ...] + iteration: int, constants: Tuple[jax.Array, ...], + state_inner: Tuple[jax.Array, ...] 
) -> bool: del iteration, constants *_, err = state_inner return err > tolerance def _softm( - f: jnp.ndarray, g: jnp.ndarray, c: jnp.ndarray, axis: int - ) -> jnp.ndarray: + f: jax.Array, g: jax.Array, c: jax.Array, axis: int + ) -> jax.Array: return jsp.special.logsumexp( gamma * (f[:, None] + g[None, :] - c), axis=axis ) def body_fn( - iteration: int, constants: Tuple[jnp.ndarray, ...], - state_inner: Tuple[jnp.ndarray, ...], compute_error: bool - ) -> Tuple[jnp.ndarray, ...]: + iteration: int, constants: Tuple[jax.Array, ...], + state_inner: Tuple[jax.Array, ...], compute_error: bool + ) -> Tuple[jax.Array, ...]: # TODO(michalk8): in the future, use `NamedTuple` f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err = state_inner c_q, c_r, loga, logb = constants @@ -501,15 +501,15 @@ def body_fn( return f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err def recompute_couplings( - f1: jnp.ndarray, - g1: jnp.ndarray, - c_q: jnp.ndarray, - f2: jnp.ndarray, - g2: jnp.ndarray, - c_r: jnp.ndarray, - h: jnp.ndarray, + f1: jax.Array, + g1: jax.Array, + c_q: jax.Array, + f2: jax.Array, + g2: jax.Array, + c_r: jax.Array, + h: jax.Array, gamma: float, - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array, jax.Array]: q = jnp.exp(gamma * (f1[:, None] + g1[None, :] - c_q)) r = jnp.exp(gamma * (f2[:, None] + g2[None, :] - c_r)) g = jnp.exp(gamma * h) @@ -524,9 +524,9 @@ def recompute_couplings( def dykstra_update_kernel( self, - k_q: jnp.ndarray, - k_r: jnp.ndarray, - k_g: jnp.ndarray, + k_q: jax.Array, + k_r: jax.Array, + k_g: jax.Array, gamma: float, ot_prob: linear_problem.LinearProblem, min_entry_value: float = 1e-6, @@ -534,7 +534,7 @@ def dykstra_update_kernel( min_iter: int = 0, inner_iter: int = 10, max_iter: int = 10000 - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array, jax.Array]: """Run Dykstra's algorithm.""" # shortcuts for problem's definition. 
rank = self.rank @@ -553,17 +553,17 @@ def dykstra_update_kernel( constants = k_q, k_r, k_g, a, b def cond_fn( - iteration: int, constants: Tuple[jnp.ndarray, ...], - state_inner: Tuple[jnp.ndarray, ...] + iteration: int, constants: Tuple[jax.Array, ...], + state_inner: Tuple[jax.Array, ...] ) -> bool: del iteration, constants *_, err = state_inner return err > tolerance def body_fn( - iteration: int, constants: Tuple[jnp.ndarray, ...], - state_inner: Tuple[jnp.ndarray, ...], compute_error: bool - ) -> Tuple[jnp.ndarray, ...]: + iteration: int, constants: Tuple[jax.Array, ...], + state_inner: Tuple[jax.Array, ...], compute_error: bool + ) -> Tuple[jax.Array, ...]: # TODO(michalk8): in the future, use `NamedTuple` u1, u2, v1_old, v2_old, g_old, q_gi, q_gp, q_q, q_r, err = state_inner k_q, k_r, k_g, a, b = constants @@ -600,14 +600,14 @@ def body_fn( return u1, u2, v1_old, v2_old, g_old, q_gi, q_gp, q_q, q_r, err def recompute_couplings( - u1: jnp.ndarray, - v1: jnp.ndarray, - k_q: jnp.ndarray, - u2: jnp.ndarray, - v2: jnp.ndarray, - k_r: jnp.ndarray, - g: jnp.ndarray, - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + u1: jax.Array, + v1: jax.Array, + k_q: jax.Array, + u2: jax.Array, + v2: jax.Array, + k_r: jax.Array, + g: jax.Array, + ) -> Tuple[jax.Array, jax.Array, jax.Array]: q = u1.reshape((-1, 1)) * k_q * v1.reshape((1, -1)) r = u2.reshape((-1, 1)) * k_r * v2.reshape((1, -1)) return q, r, g @@ -736,7 +736,7 @@ def create_initializer( def init_state( self, ot_prob: linear_problem.LinearProblem, - init: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] + init: Tuple[jax.Array, jax.Array, jax.Array] ) -> LRSinkhornState: """Return the initial state of the loop.""" q, r, g = init @@ -811,8 +811,7 @@ def _diverged(self, state: LRSinkhornState, iteration: int) -> bool: def run( ot_prob: linear_problem.LinearProblem, solver: LRSinkhorn, - init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], - Optional[jnp.ndarray]], + init: Tuple[Optional[jax.Array], 
Optional[jax.Array], Optional[jax.Array]], ) -> LRSinkhornOutput: """Run loop of the solver, outputting a state upgraded to an output.""" out = sinkhorn.iterations(ot_prob, solver, init) diff --git a/src/ott/solvers/linear/univariate.py b/src/ott/solvers/linear/univariate.py index 2b6392227..1f2a47b6f 100644 --- a/src/ott/solvers/linear/univariate.py +++ b/src/ott/solvers/linear/univariate.py @@ -53,7 +53,7 @@ class UnivariateSolver: def __init__( self, - sort_fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, + sort_fn: Optional[Callable[[jax.Array], jax.Array]] = None, cost_fn: Optional[costs.CostFn] = None, method: Literal["subsample", "quantile", "wasserstein", "equal"] = "subsample", @@ -66,10 +66,10 @@ def __init__( def __call__( self, - x: jnp.ndarray, - y: jnp.ndarray, - a: Optional[jnp.ndarray] = None, - b: Optional[jnp.ndarray] = None + x: jax.Array, + y: jax.Array, + a: Optional[jax.Array] = None, + b: Optional[jax.Array] = None ) -> float: """Computes the Univariate OT Distance between `x` and `y`. @@ -113,8 +113,8 @@ def __call__( return self.cost_fn.pairwise(xx, yy) * (n / xx.shape[0]) def _cdf_distance( - self, x: jnp.ndarray, y: jnp.ndarray, a: Optional[jnp.ndarray], - b: Optional[jnp.ndarray] + self, x: jax.Array, y: jax.Array, a: Optional[jax.Array], + b: Optional[jax.Array] ): # Implementation based on `scipy` implementation for # :func: diff --git a/src/ott/solvers/quadratic/_solve.py b/src/ott/solvers/quadratic/_solve.py index 9cdefec93..986680637 100644 --- a/src/ott/solvers/quadratic/_solve.py +++ b/src/ott/solvers/quadratic/_solve.py @@ -13,7 +13,7 @@ # limitations under the License. 
from typing import Any, Literal, Optional, Union -import jax.numpy as jnp +import jax from ott.geometry import geometry from ott.problems.quadratic import quadratic_costs, quadratic_problem @@ -28,8 +28,8 @@ def solve( geom_yy: geometry.Geometry, geom_xy: Optional[geometry.Geometry] = None, fused_penalty: float = 1.0, - a: Optional[jnp.ndarray] = None, - b: Optional[jnp.ndarray] = None, + a: Optional[jax.Array] = None, + b: Optional[jax.Array] = None, tau_a: float = 1.0, tau_b: float = 1.0, loss: Union[Literal["sqeucl", "kl"], quadratic_costs.GWLoss] = "sqeucl", diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index a7890e1c9..554cdaaed 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -63,10 +63,10 @@ class GWOutput(NamedTuple): old_transport_mass: Holds total mass of transport at previous iteration. """ - costs: Optional[jnp.ndarray] = None - linear_convergence: Optional[jnp.ndarray] = None + costs: Optional[jax.Array] = None + linear_convergence: Optional[jax.Array] = None converged: bool = False - errors: Optional[jnp.ndarray] = None + errors: Optional[jax.Array] = None linear_state: Optional[LinearOutput] = None geom: Optional[geometry.Geometry] = None # Intermediate values. @@ -77,11 +77,11 @@ def set(self, **kwargs: Any) -> "GWOutput": return self._replace(**kwargs) @property - def matrix(self) -> jnp.ndarray: + def matrix(self) -> jax.Array: """Transport matrix.""" return self._rescale_factor * self.linear_state.matrix - def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray: + def apply(self, inputs: jax.Array, axis: int = 0) -> jax.Array: """Apply the transport to an array; axis=1 for its transpose.""" return self._rescale_factor * self.linear_state.apply(inputs, axis=axis) @@ -124,13 +124,13 @@ class GWState(NamedTuple): at each iteration. 
""" - costs: jnp.ndarray - linear_convergence: jnp.ndarray + costs: jax.Array + linear_convergence: jax.Array linear_state: LinearOutput linear_pb: linear_problem.LinearProblem old_transport_mass: float rngs: Optional[jax.Array] = None - errors: Optional[jnp.ndarray] = None + errors: Optional[jax.Array] = None def set(self, **kwargs: Any) -> "GWState": """Return a copy of self, possibly with overwrites.""" diff --git a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py index 214853f4c..62a5592bc 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py @@ -46,12 +46,12 @@ class LRGWState(NamedTuple): """State of the low-rank GW algorithm.""" - q: jnp.ndarray - r: jnp.ndarray - g: jnp.ndarray + q: jax.Array + r: jax.Array + g: jax.Array gamma: float - costs: jnp.ndarray - errors: jnp.ndarray + costs: jax.Array + errors: jax.Array crossed_threshold: bool def compute_error( # noqa: D102 @@ -85,9 +85,9 @@ def set(self, **kwargs: Any) -> "LRGWState": def compute_reg_gw_cost( - q: jnp.ndarray, - r: jnp.ndarray, - g: jnp.ndarray, + q: jax.Array, + r: jax.Array, + g: jax.Array, ot_prob: quadratic_problem.QuadraticProblem, epsilon: float, use_danskin: bool = False @@ -107,7 +107,7 @@ def compute_reg_gw_cost( regularized OT cost, the (primal) transport cost of the low-rank solution. 
""" - def ent(x: jnp.ndarray) -> float: + def ent(x: jax.Array) -> float: # generalized entropy return jnp.sum(jsp.special.entr(x) + x) @@ -139,13 +139,13 @@ def ent(x: jnp.ndarray) -> float: class LRGWOutput(NamedTuple): """Transport interface for a low-rank GW solution.""" - q: jnp.ndarray - r: jnp.ndarray - g: jnp.ndarray - costs: jnp.ndarray + q: jax.Array + r: jax.Array + g: jax.Array + costs: jax.Array # TODO(michalk8): must be called `errors`, because of `store_inner_errors` # in future, enforce via class hierarchy - errors: jnp.ndarray + errors: jax.Array ot_prob: quadratic_problem.QuadraticProblem epsilon: float inner_iterations: int @@ -184,11 +184,11 @@ def geom(self) -> geometry.Geometry: # noqa: D102 return _linearized_geometry(self.ot_prob, q=self.q, r=self.r, g=self.g) @property - def a(self) -> jnp.ndarray: # noqa: D102 + def a(self) -> jax.Array: # noqa: D102 return self.ot_prob.a @property - def b(self) -> jnp.ndarray: # noqa: D102 + def b(self) -> jax.Array: # noqa: D102 return self.ot_prob.b @property @@ -202,17 +202,17 @@ def converged(self) -> bool: # noqa: D102 ) @property - def matrix(self) -> jnp.ndarray: + def matrix(self) -> jax.Array: """Transport matrix if it can be instantiated.""" return (self.q * self._inv_g) @ self.r.T - def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray: + def apply(self, inputs: jax.Array, axis: int = 0) -> jax.Array: """Apply the transport to a array; axis=1 for its transpose.""" q, r = (self.q, self.r) if axis == 1 else (self.r, self.q) # for `axis=0`: (batch, m), (m, r), (r,), (r, n) return ((inputs @ r) * self._inv_g) @ q.T - def marginal(self, axis: int) -> jnp.ndarray: # noqa: D102 + def marginal(self, axis: int) -> jax.Array: # noqa: D102 length = self.q.shape[0] if axis == 0 else self.r.shape[0] return self.apply(jnp.ones(length,), axis=axis) @@ -250,7 +250,7 @@ def transport_mass(self) -> float: return self.marginal(0).sum() @property - def _inv_g(self) -> jnp.ndarray: + def _inv_g(self) 
-> jax.Array: return 1.0 / self.g @@ -334,8 +334,8 @@ def __init__( def __call__( self, ot_prob: quadratic_problem.QuadraticProblem, - init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], - Optional[jnp.ndarray]] = (None, None, None), + init: Tuple[Optional[jax.Array], Optional[jax.Array], + Optional[jax.Array]] = (None, None, None), rng: Optional[jax.Array] = None, **kwargs: Any, ) -> LRGWOutput: @@ -370,7 +370,7 @@ def _get_costs( self, ot_prob: quadratic_problem.QuadraticProblem, state: LRGWState, - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, float]: + ) -> Tuple[jax.Array, jax.Array, jax.Array, float]: q, r, g = state.q, state.r, state.g log_q, log_r, log_g = mu.safe_log(q), mu.safe_log(r), mu.safe_log(g) inv_g = 1.0 / g[None, :] @@ -427,9 +427,9 @@ def _get_costs( # TODO(michalk8): move to `lr_utils` when refactoring this the future def dykstra_update_lse( self, - c_q: jnp.ndarray, - c_r: jnp.ndarray, - h: jnp.ndarray, + c_q: jax.Array, + c_r: jax.Array, + h: jax.Array, gamma: float, ot_prob: quadratic_problem.QuadraticProblem, min_entry_value: float = 1e-6, @@ -437,7 +437,7 @@ def dykstra_update_lse( min_iter: int = 0, inner_iter: int = 10, max_iter: int = 10000 - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array, jax.Array]: """Run Dykstra's algorithm.""" # shortcuts for problem's definition. r = self.rank @@ -455,24 +455,24 @@ def dykstra_update_lse( constants = c_q, c_r, loga, logb def cond_fn( - iteration: int, constants: Tuple[jnp.ndarray, ...], - state_inner: Tuple[jnp.ndarray, ...] + iteration: int, constants: Tuple[jax.Array, ...], + state_inner: Tuple[jax.Array, ...] 
) -> bool: del iteration, constants *_, err = state_inner return err > tolerance def _softm( - f: jnp.ndarray, g: jnp.ndarray, c: jnp.ndarray, axis: int - ) -> jnp.ndarray: + f: jax.Array, g: jax.Array, c: jax.Array, axis: int + ) -> jax.Array: return jsp.special.logsumexp( gamma * (f[:, None] + g[None, :] - c), axis=axis ) def body_fn( - iteration: int, constants: Tuple[jnp.ndarray, ...], - state_inner: Tuple[jnp.ndarray, ...], compute_error: bool - ) -> Tuple[jnp.ndarray, ...]: + iteration: int, constants: Tuple[jax.Array, ...], + state_inner: Tuple[jax.Array, ...], compute_error: bool + ) -> Tuple[jax.Array, ...]: # TODO(michalk8): in the future, use `NamedTuple` f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err = state_inner c_q, c_r, loga, logb = constants @@ -522,15 +522,15 @@ def body_fn( return f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err def recompute_couplings( - f1: jnp.ndarray, - g1: jnp.ndarray, - c_q: jnp.ndarray, - f2: jnp.ndarray, - g2: jnp.ndarray, - c_r: jnp.ndarray, - h: jnp.ndarray, + f1: jax.Array, + g1: jax.Array, + c_q: jax.Array, + f2: jax.Array, + g2: jax.Array, + c_r: jax.Array, + h: jax.Array, gamma: float, - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array, jax.Array]: q = jnp.exp(gamma * (f1[:, None] + g1[None, :] - c_q)) r = jnp.exp(gamma * (f2[:, None] + g2[None, :] - c_r)) g = jnp.exp(gamma * h) @@ -545,9 +545,9 @@ def recompute_couplings( def dykstra_update_kernel( self, - k_q: jnp.ndarray, - k_r: jnp.ndarray, - k_g: jnp.ndarray, + k_q: jax.Array, + k_r: jax.Array, + k_g: jax.Array, gamma: float, ot_prob: quadratic_problem.QuadraticProblem, min_entry_value: float = 1e-6, @@ -555,7 +555,7 @@ def dykstra_update_kernel( min_iter: int = 0, inner_iter: int = 10, max_iter: int = 10000 - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array, jax.Array]: """Run Dykstra's algorithm.""" # shortcuts for problem's definition. 
del gamma @@ -575,17 +575,17 @@ def dykstra_update_kernel( constants = k_q, k_r, k_g, a, b def cond_fn( - iteration: int, constants: Tuple[jnp.ndarray, ...], - state_inner: Tuple[jnp.ndarray, ...] + iteration: int, constants: Tuple[jax.Array, ...], + state_inner: Tuple[jax.Array, ...] ) -> bool: del iteration, constants *_, err = state_inner return err > tolerance def body_fn( - iteration: int, constants: Tuple[jnp.ndarray, ...], - state_inner: Tuple[jnp.ndarray, ...], compute_error: bool - ) -> Tuple[jnp.ndarray, ...]: + iteration: int, constants: Tuple[jax.Array, ...], + state_inner: Tuple[jax.Array, ...], compute_error: bool + ) -> Tuple[jax.Array, ...]: # TODO(michalk8): in the future, use `NamedTuple` u1, u2, v1_old, v2_old, g_old, q_gi, q_gp, q_q, q_r, err = state_inner k_q, k_r, k_g, a, b = constants @@ -623,14 +623,14 @@ def body_fn( return u1, u2, v1_old, v2_old, g_old, q_gi, q_gp, q_q, q_r, err def recompute_couplings( - u1: jnp.ndarray, - v1: jnp.ndarray, - k_q: jnp.ndarray, - u2: jnp.ndarray, - v2: jnp.ndarray, - k_r: jnp.ndarray, - g: jnp.ndarray, - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + u1: jax.Array, + v1: jax.Array, + k_q: jax.Array, + u2: jax.Array, + v2: jax.Array, + k_r: jax.Array, + g: jax.Array, + ) -> Tuple[jax.Array, jax.Array, jax.Array]: q = u1.reshape((-1, 1)) * k_q * v1.reshape((1, -1)) r = u2.reshape((-1, 1)) * k_r * v2.reshape((1, -1)) return q, r, g @@ -762,7 +762,7 @@ def create_initializer( def init_state( self, ot_prob: quadratic_problem.QuadraticProblem, - init: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] + init: Tuple[jax.Array, jax.Array, jax.Array] ) -> LRGWState: """Return the initial state of the loop.""" q, r, g = init @@ -837,8 +837,7 @@ def _diverged(self, state: LRGWState, iteration: int) -> bool: def run( ot_prob: quadratic_problem.QuadraticProblem, solver: LRGromovWasserstein, - init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], - Optional[jnp.ndarray]], + init: Tuple[Optional[jax.Array], 
Optional[jax.Array], Optional[jax.Array]], ) -> LRGWOutput: """Run loop of the solver, outputting a state upgraded to an output.""" out = sinkhorn.iterations(ot_prob, solver, init) @@ -849,9 +848,9 @@ def run( def dykstra_solution_error( - q: jnp.ndarray, r: jnp.ndarray, ot_prob: quadratic_problem.QuadraticProblem, + q: jax.Array, r: jax.Array, ot_prob: quadratic_problem.QuadraticProblem, norm_error: Tuple[int, ...] -) -> jnp.ndarray: +) -> jax.Array: """Compute solution error. Since only balanced case is available for LR, this is marginal deviation. @@ -884,9 +883,9 @@ def dykstra_solution_error( def _linearized_geometry( prob: quadratic_problem.QuadraticProblem, *, - q: jnp.ndarray, - r: jnp.ndarray, - g: jnp.ndarray, + q: jax.Array, + r: jax.Array, + g: jax.Array, ) -> low_rank.LRCGeometry: inv_sqrt_g = 1.0 / jnp.sqrt(g[None, :]) diff --git a/src/ott/solvers/quadratic/gw_barycenter.py b/src/ott/solvers/quadratic/gw_barycenter.py index f0d350b08..0f753793e 100644 --- a/src/ott/solvers/quadratic/gw_barycenter.py +++ b/src/ott/solvers/quadratic/gw_barycenter.py @@ -45,13 +45,13 @@ class GWBarycenterState(NamedTuple): gw_convergence: Array of shape ``[max_iter,]`` containing the convergence of all GW problems at each iteration. 
""" - cost: Optional[jnp.ndarray] = None - x: Optional[jnp.ndarray] = None - a: Optional[jnp.ndarray] = None - errors: Optional[jnp.ndarray] = None - costs: Optional[jnp.ndarray] = None - costs_bary: Optional[jnp.ndarray] = None - gw_convergence: Optional[jnp.ndarray] = None + cost: Optional[jax.Array] = None + x: Optional[jax.Array] = None + a: Optional[jax.Array] = None + errors: Optional[jax.Array] = None + costs: Optional[jax.Array] = None + costs_bary: Optional[jax.Array] = None + gw_convergence: Optional[jax.Array] = None def set(self, **kwargs: Any) -> "GWBarycenterState": """Return a copy of self, possibly with overwrites.""" @@ -133,9 +133,8 @@ def init_state( self, problem: gw_barycenter.GWBarycenterProblem, bar_size: int, - bar_init: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, - jnp.ndarray]]] = None, - a: Optional[jnp.ndarray] = None, + bar_init: Optional[Union[jax.Array, Tuple[jax.Array, jax.Array]]] = None, + a: Optional[jax.Array] = None, rng: Optional[jax.Array] = None, ) -> GWBarycenterState: """Initialize the (fused) Gromov-Wasserstein barycenter state. 
@@ -210,13 +209,13 @@ def update_state( iteration: int, problem: gw_barycenter.GWBarycenterProblem, store_errors: bool = True, - ) -> Tuple[float, bool, jnp.ndarray, Optional[jnp.ndarray]]: + ) -> Tuple[float, bool, jax.Array, Optional[jax.Array]]: """Solve the (fused) Gromov-Wasserstein barycenter problem.""" def solve_gw( - state: GWBarycenterState, b: jnp.ndarray, y: jnp.ndarray, - f: Optional[jnp.ndarray] - ) -> Tuple[float, bool, jnp.ndarray, Optional[jnp.ndarray]]: + state: GWBarycenterState, b: jax.Array, y: jax.Array, + f: Optional[jax.Array] + ) -> Tuple[float, bool, jax.Array, Optional[jax.Array]]: quad_problem = problem._create_problem(state, y=y, b=b, f=f) out = self._quad_solver(quad_problem) return ( @@ -282,9 +281,8 @@ def tree_unflatten( # noqa: D102 @partial(jax.vmap, in_axes=[None, 0, None, 0, None]) def init_transports( - solver, rng: jax.Array, a: jnp.ndarray, b: jnp.ndarray, - epsilon: Optional[float] -) -> jnp.ndarray: + solver, rng: jax.Array, a: jax.Array, b: jax.Array, epsilon: Optional[float] +) -> jax.Array: """Initialize random 2D point cloud and solve the linear OT problem. Args: diff --git a/src/ott/tools/gaussian_mixture/fit_gmm.py b/src/ott/tools/gaussian_mixture/fit_gmm.py index 4c62bded7..45d8e0935 100644 --- a/src/ott/tools/gaussian_mixture/fit_gmm.py +++ b/src/ott/tools/gaussian_mixture/fit_gmm.py @@ -62,8 +62,8 @@ def get_assignment_probs( - gmm: gaussian_mixture.GaussianMixture, points: jnp.ndarray -) -> jnp.ndarray: + gmm: gaussian_mixture.GaussianMixture, points: jax.Array +) -> jax.Array: r"""Get component assignment probabilities used in the E step of EM. 
Here we compute the component assignment probabilities p(Z|X, \Theta^{(t)}) @@ -81,9 +81,9 @@ def get_assignment_probs( def get_q( gmm: gaussian_mixture.GaussianMixture, - assignment_probs: jnp.ndarray, - points: jnp.ndarray, - point_weights: Optional[jnp.ndarray] = None, + assignment_probs: jax.Array, + points: jax.Array, + point_weights: Optional[jax.Array] = None, ) -> float: r"""Get Q(\Theta|\Theta^{(t)}). @@ -109,8 +109,8 @@ def get_q( def log_prob_loss( gmm: gaussian_mixture.GaussianMixture, - points: jnp.ndarray, - point_weights: Optional[jnp.ndarray] = None, + points: jax.Array, + point_weights: Optional[jax.Array] = None, ) -> float: """Loss function: weighted mean of (-log prob of observations). @@ -130,8 +130,8 @@ def log_prob_loss( def fit_model_em( gmm: gaussian_mixture.GaussianMixture, - points: jnp.ndarray, - point_weights: Optional[jnp.ndarray], + points: jax.Array, + point_weights: Optional[jax.Array], steps: int, jit: bool = True, verbose: bool = False, @@ -184,10 +184,10 @@ def fit_model_em( # See https://en.wikipedia.org/wiki/K-means%2B%2B for details -def _get_dist_sq(points: jnp.ndarray, loc: jnp.ndarray) -> jnp.ndarray: +def _get_dist_sq(points: jax.Array, loc: jax.Array) -> jax.Array: """Get the squared distance from each point to each loc.""" - def _dist_sq_one_loc(points: jnp.ndarray, loc: jnp.ndarray) -> jnp.ndarray: + def _dist_sq_one_loc(points: jax.Array, loc: jax.Array) -> jax.Array: return jnp.sum((points - loc[None]) ** 2., axis=-1) dist_sq_fn = jax.vmap(_dist_sq_one_loc, in_axes=(None, 0), out_axes=1) @@ -195,8 +195,8 @@ def _dist_sq_one_loc(points: jnp.ndarray, loc: jnp.ndarray) -> jnp.ndarray: def _get_locs( - rng: jax.Array, points: jnp.ndarray, n_components: int -) -> jnp.ndarray: + rng: jax.Array, points: jax.Array, n_components: int +) -> jax.Array: """Get the initial component means. 
Args: @@ -230,8 +230,8 @@ def _get_locs( def from_kmeans_plusplus( rng: jax.Array, - points: jnp.ndarray, - point_weights: Optional[jnp.ndarray], + points: jax.Array, + point_weights: Optional[jax.Array], n_components: int, ) -> gaussian_mixture.GaussianMixture: """Initialize a GMM via a single pass of K-means++. @@ -266,8 +266,8 @@ def from_kmeans_plusplus( def initialize( rng: jax.Array, - points: jnp.ndarray, - point_weights: Optional[jnp.ndarray], + points: jax.Array, + point_weights: Optional[jax.Array], n_components: int, n_attempts: int = 50, verbose: bool = False diff --git a/src/ott/tools/gaussian_mixture/fit_gmm_pair.py b/src/ott/tools/gaussian_mixture/fit_gmm_pair.py index 7ecde263c..35222caf9 100644 --- a/src/ott/tools/gaussian_mixture/fit_gmm_pair.py +++ b/src/ott/tools/gaussian_mixture/fit_gmm_pair.py @@ -98,9 +98,9 @@ class Observations(NamedTuple): """Weighted observations and their E-step assignment probabilities.""" - points: jnp.ndarray - point_weights: jnp.ndarray - assignment_probs: jnp.ndarray + points: jax.Array + point_weights: jax.Array + assignment_probs: jax.Array # Model fit @@ -108,7 +108,7 @@ class Observations(NamedTuple): def get_q( gmm: gaussian_mixture.GaussianMixture, obs: Observations -) -> jnp.ndarray: +) -> jax.Array: r"""Get Q(\Theta|\Theta^{(t)}). Here Q is the log likelihood for our observations based on the current @@ -159,7 +159,7 @@ def _objective_fn( pair: gaussian_mixture_pair.GaussianMixturePair, obs0: Observations, obs1: Observations, - ) -> jnp.ndarray: + ) -> jax.Array: """Compute the objective function for a pair of GMMs. 
Args: @@ -204,11 +204,11 @@ def print_losses( def do_e_step( # noqa: D103 - e_step_fn: Callable[[gaussian_mixture.GaussianMixture, jnp.ndarray], - jnp.ndarray], + e_step_fn: Callable[[gaussian_mixture.GaussianMixture, jax.Array], + jax.Array], gmm: gaussian_mixture.GaussianMixture, - points: jnp.ndarray, - point_weights: jnp.ndarray, + points: jax.Array, + point_weights: jax.Array, ) -> Observations: assignment_probs = e_step_fn(gmm, points) return Observations( @@ -307,10 +307,10 @@ def get_fit_model_em_fn( def _fit_model_em( pair: gaussian_mixture_pair.GaussianMixturePair, - points0: jnp.ndarray, - points1: jnp.ndarray, - point_weights0: Optional[jnp.ndarray], - point_weights1: Optional[jnp.ndarray], + points0: jax.Array, + points1: jax.Array, + point_weights0: Optional[jax.Array], + point_weights1: Optional[jax.Array], em_steps: int, m_steps: int = 50, verbose: bool = False, diff --git a/src/ott/tools/gaussian_mixture/gaussian.py b/src/ott/tools/gaussian_mixture/gaussian.py index 6e0a8ccb7..b8c8e227b 100644 --- a/src/ott/tools/gaussian_mixture/gaussian.py +++ b/src/ott/tools/gaussian_mixture/gaussian.py @@ -28,15 +28,15 @@ class Gaussian: """Normal distribution.""" - def __init__(self, loc: jnp.ndarray, scale: scale_tril.ScaleTriL): + def __init__(self, loc: jax.Array, scale: scale_tril.ScaleTriL): self._loc = loc self._scale = scale @classmethod def from_samples( cls, - points: jnp.ndarray, - weights: Optional[jnp.ndarray] = None + points: jax.Array, + weights: Optional[jax.Array] = None ) -> "Gaussian": """Construct a Gaussian from weighted samples. @@ -67,7 +67,7 @@ def from_random( n_dimensions: int, stdev_mean: float = 0.1, stdev_cov: float = 0.1, - ridge: Union[float, jnp.ndarray] = 0, + ridge: Union[float, jax.Array] = 0, dtype: Optional[jnp.dtype] = None ) -> "Gaussian": """Construct a random Gaussian. 
@@ -94,13 +94,13 @@ def from_random( return cls(loc=loc, scale=scale) @classmethod - def from_mean_and_cov(cls, mean: jnp.ndarray, cov: jnp.ndarray) -> "Gaussian": + def from_mean_and_cov(cls, mean: jax.Array, cov: jax.Array) -> "Gaussian": """Construct a Gaussian from a mean and covariance.""" scale = scale_tril.ScaleTriL.from_covariance(cov) return cls(loc=mean, scale=scale) @property - def loc(self) -> jnp.ndarray: + def loc(self) -> jax.Array: """Mean of the Gaussian.""" return self._loc @@ -114,22 +114,22 @@ def n_dimensions(self) -> int: """Dimensionality of the Gaussian.""" return self.loc.shape[-1] - def covariance(self) -> jnp.ndarray: + def covariance(self) -> jax.Array: """Covariance of the Gaussian.""" return self.scale.covariance() - def to_z(self, x: jnp.ndarray) -> jnp.ndarray: + def to_z(self, x: jax.Array) -> jax.Array: r"""Transform :math:`x` to :math:`z = \frac{x - loc}{scale}`.""" return self.scale.centered_to_z(x_centered=x - self.loc) - def from_z(self, z: jnp.ndarray) -> jnp.ndarray: + def from_z(self, z: jax.Array) -> jax.Array: r"""Transform :math:`z` to :math:`x = loc + scale \cdot z`.""" return self.scale.z_to_centered(z=z) + self.loc def log_prob( self, - x: jnp.ndarray, # (?, d) - ) -> jnp.ndarray: # (?, d) + x: jax.Array, # (?, d) + ) -> jax.Array: # (?, d) """Log probability for a Gaussian with a diagonal covariance.""" d = x.shape[-1] z = self.to_z(x) @@ -138,7 +138,7 @@ def log_prob( -0.5 * (d * LOG2PI + log_det[None] + jnp.sum(z ** 2., axis=-1)) ) # (?, k) - def sample(self, rng: jax.Array, size: int) -> jnp.ndarray: + def sample(self, rng: jax.Array, size: int) -> jax.Array: """Generate samples from the distribution.""" std_samples_t = jax.random.normal(key=rng, shape=(self.n_dimensions, size)) return self.loc[None] + ( @@ -149,7 +149,7 @@ def sample(self, rng: jax.Array, size: int) -> jnp.ndarray: ) ) - def w2_dist(self, other: "Gaussian") -> jnp.ndarray: + def w2_dist(self, other: "Gaussian") -> jax.Array: r"""Wasserstein 
distance :math:`W_2^2` to another Gaussian. .. math:: @@ -167,7 +167,7 @@ def w2_dist(self, other: "Gaussian") -> jnp.ndarray: delta_sigma = self.scale.w2_dist(other.scale) return delta_mean + delta_sigma - def f_potential(self, dest: "Gaussian", points: jnp.ndarray) -> jnp.ndarray: + def f_potential(self, dest: "Gaussian", points: jax.Array) -> jax.Array: """Optimal potential for W2 distance between Gaussians. Evaluated on points. Args: @@ -191,7 +191,7 @@ def batch_inner_product(x, y): points.dot(dest.loc) ) - def transport(self, dest: "Gaussian", points: jnp.ndarray) -> jnp.ndarray: + def transport(self, dest: "Gaussian", points: jax.Array) -> jax.Array: """Transport points according to map between two Gaussian measures. Args: diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture.py b/src/ott/tools/gaussian_mixture/gaussian_mixture.py index 313689939..a9cb2b326 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture.py @@ -27,9 +27,8 @@ def get_summary_stats_from_points_and_assignment_probs( - points: jnp.ndarray, point_weights: jnp.ndarray, - assignment_probs: jnp.ndarray -) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + points: jax.Array, point_weights: jax.Array, assignment_probs: jax.Array +) -> Tuple[jax.Array, jax.Array, jax.Array]: """Get component summary stats from points and component probabilities. 
Args: @@ -68,7 +67,7 @@ class GaussianMixture: """Gaussian Mixture model.""" def __init__( - self, loc: jnp.ndarray, scale_params: jnp.ndarray, + self, loc: jax.Array, scale_params: jax.Array, component_weight_ob: probabilities.Probabilities ): self._loc = loc @@ -113,7 +112,7 @@ def from_random( @classmethod def from_mean_cov_component_weights( - cls, mean: jnp.ndarray, cov: jnp.ndarray, component_weights: jnp.ndarray + cls, mean: jax.Array, cov: jax.Array, component_weights: jax.Array ): """Construct a GMM from means, covariances, and component weights.""" scale_params = [] @@ -128,9 +127,9 @@ def from_mean_cov_component_weights( @classmethod def from_points_and_assignment_probs( cls, - points: jnp.ndarray, - point_weights: jnp.ndarray, - assignment_probs: jnp.ndarray, + points: jax.Array, + point_weights: jax.Array, + assignment_probs: jax.Array, ) -> "GaussianMixture": """Estimate a GMM from points and a set of component probabilities.""" mean, cov, wts = get_summary_stats_from_points_and_assignment_probs( @@ -158,17 +157,17 @@ def n_components(self): return self._loc.shape[-2] @property - def loc(self) -> jnp.ndarray: + def loc(self) -> jax.Array: """Location parameters of the GMM.""" return self._loc @property - def scale_params(self) -> jnp.ndarray: + def scale_params(self) -> jax.Array: """Scale parameters of the GMM.""" return self._scale_params @property - def cholesky(self) -> jnp.ndarray: + def cholesky(self) -> jax.Array: """Cholesky decomposition of the GMM covariance matrices.""" size = self.n_dimensions @@ -178,7 +177,7 @@ def _get_cholesky(scale_params): return jax.vmap(_get_cholesky, in_axes=0, out_axes=0)(self.scale_params) @property - def covariance(self) -> jnp.ndarray: + def covariance(self) -> jax.Array: """Covariance matrices of the GMM.""" size = self.n_dimensions @@ -193,16 +192,16 @@ def component_weight_ob(self) -> probabilities.Probabilities: return self._component_weight_ob @property - def component_weights(self) -> jnp.ndarray: + def 
component_weights(self) -> jax.Array: """Component weights probabilities.""" return self._component_weight_ob.probs() - def log_component_weights(self) -> jnp.ndarray: + def log_component_weights(self) -> jax.Array: """Log component weights probabilities.""" return self._component_weight_ob.log_probs() def _get_normal( - self, loc: jnp.ndarray, scale_params: jnp.ndarray + self, loc: jax.Array, scale_params: jax.Array ) -> gaussian.Gaussian: size = loc.shape[-1] return gaussian.Gaussian( @@ -219,7 +218,7 @@ def components(self) -> List[gaussian.Gaussian]: """List of all GMM components.""" return [self.get_component(i) for i in range(self.n_components)] - def sample(self, rng: jax.Array, size: int) -> jnp.ndarray: + def sample(self, rng: jax.Array, size: int) -> jax.Array: """Generate samples from the distribution.""" subrng0, subrng1 = jax.random.split(rng) component = self.component_weight_ob.sample(rng=subrng0, size=size) @@ -244,7 +243,7 @@ def _transform_single_value(single_component, single_x): axis=0 ) - def conditional_log_prob(self, x: jnp.ndarray) -> jnp.ndarray: + def conditional_log_prob(self, x: jax.Array) -> jax.Array: """Compute the component-conditional log probability of x. Args: @@ -256,7 +255,7 @@ def conditional_log_prob(self, x: jnp.ndarray) -> jnp.ndarray: """ def _log_prob_single_component( - loc: jnp.ndarray, scale_params: jnp.ndarray, x: jnp.ndarray + loc: jax.Array, scale_params: jax.Array, x: jax.Array ): norm = self._get_normal(loc=loc, scale_params=scale_params) return norm.log_prob(x) @@ -266,7 +265,7 @@ def _log_prob_single_component( ) return conditional_log_prob_fn(self._loc, self._scale_params, x) - def log_prob(self, x: jnp.ndarray) -> jnp.ndarray: + def log_prob(self, x: jax.Array) -> jax.Array: """Compute the log probability of the observations x. 
Args: @@ -282,7 +281,7 @@ def log_prob(self, x: jnp.ndarray) -> jnp.ndarray: log_prob_conditional + log_component_weight[None, :], axis=-1 ) - def get_log_component_posterior(self, x: jnp.ndarray) -> jnp.ndarray: + def get_log_component_posterior(self, x: jax.Array) -> jax.Array: """Compute the posterior probability that x came from each component. Args: diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py b/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py index b24506fcc..21d4dbaf1 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py @@ -128,12 +128,12 @@ def get_bures_geometry(self) -> pointcloud.PointCloud: epsilon=self.epsilon ) - def get_cost_matrix(self) -> jnp.ndarray: + def get_cost_matrix(self) -> jax.Array: """Get matrix of :math:`W_2^2` costs between all pairs of components.""" return self.get_bures_geometry().cost_matrix def get_sinkhorn( - self, cost_matrix: jnp.ndarray, **kwargs: Any + self, cost_matrix: jax.Array, **kwargs: Any ) -> sinkhorn.SinkhornOutput: """Get the output of Sinkhorn's method for a given cost matrix.""" # We use a Geometry here rather than the PointCloud created in @@ -152,7 +152,7 @@ def get_sinkhorn( def get_normalized_sinkhorn_coupling( self, sinkhorn_output: sinkhorn.SinkhornOutput, - ) -> jnp.ndarray: + ) -> jax.Array: """Get the normalized coupling matrix for the specified Sinkhorn output. 
Args: diff --git a/src/ott/tools/gaussian_mixture/linalg.py b/src/ott/tools/gaussian_mixture/linalg.py index 8e71369f3..2a5114d69 100644 --- a/src/ott/tools/gaussian_mixture/linalg.py +++ b/src/ott/tools/gaussian_mixture/linalg.py @@ -18,9 +18,9 @@ def get_mean_and_var( - points: jnp.ndarray, # (n, d) - weights: jnp.ndarray, # (n,) -) -> Tuple[jnp.ndarray, jnp.ndarray]: + points: jax.Array, # (n, d) + weights: jax.Array, # (n,) +) -> Tuple[jax.Array, jax.Array]: """Get the mean and variance of a weighted set of points.""" weights_sum = jnp.sum(weights, axis=-1) # (1,) mean = ( @@ -37,9 +37,9 @@ def get_mean_and_var( def get_mean_and_cov( - points: jnp.ndarray, # (n, d) - weights: jnp.ndarray, # (n,) -) -> Tuple[jnp.ndarray, jnp.ndarray]: + points: jax.Array, # (n, d) + weights: jax.Array, # (n,) +) -> Tuple[jax.Array, jax.Array]: """Get the mean and covariance of a weighted set of points.""" weights_sum = jnp.sum(weights, axis=-1, keepdims=True) # (1,) mean = ( @@ -59,7 +59,7 @@ def get_mean_and_cov( return mean, cov -def flat_to_tril(x: jnp.ndarray, size: int) -> jnp.ndarray: +def flat_to_tril(x: jax.Array, size: int) -> jax.Array: """Map flat values to lower triangular matrices. Args: @@ -76,7 +76,7 @@ def flat_to_tril(x: jnp.ndarray, size: int) -> jnp.ndarray: return m.at[..., tril[0], tril[1]].set(x) -def tril_to_flat(m: jnp.ndarray) -> jnp.ndarray: +def tril_to_flat(m: jax.Array) -> jax.Array: """Flatten lower triangular matrices. 
Args: @@ -91,8 +91,8 @@ def tril_to_flat(m: jnp.ndarray) -> jnp.ndarray: def apply_to_diag( - m: jnp.ndarray, fn: Callable[[jnp.ndarray], jnp.ndarray] -) -> jnp.ndarray: + m: jax.Array, fn: Callable[[jax.Array], jax.Array] +) -> jax.Array: """Apply a function to the diagonal of a matrix.""" size = m.shape[-1] diag = jnp.diagonal(m, axis1=-2, axis2=-1) @@ -101,9 +101,9 @@ def apply_to_diag( def matrix_powers( - m: jnp.ndarray, + m: jax.Array, powers: Iterable[float], -) -> List[jnp.ndarray]: +) -> List[jax.Array]: """Raise a real, symmetric matrix to multiple powers.""" eigs, q = jnp.linalg.eigh(m) qt = jnp.swapaxes(q, axis1=-2, axis2=-1) @@ -113,9 +113,7 @@ def matrix_powers( return ret -def invmatvectril( - m: jnp.ndarray, x: jnp.ndarray, lower: bool = True -) -> jnp.ndarray: +def invmatvectril(m: jax.Array, x: jax.Array, lower: bool = True) -> jax.Array: """Multiply x by the inverse of a triangular matrix. Args: @@ -133,7 +131,7 @@ def invmatvectril( def get_random_orthogonal( rng: jax.Array, dim: int, dtype: Optional[jnp.dtype] = None -) -> jnp.ndarray: +) -> jax.Array: """Get a random orthogonal matrix with the specified dimension.""" m = jax.random.normal(key=rng, shape=[dim, dim], dtype=dtype) q, _ = jnp.linalg.qr(m) diff --git a/src/ott/tools/gaussian_mixture/probabilities.py b/src/ott/tools/gaussian_mixture/probabilities.py index 6df3bb023..c3bb253a5 100644 --- a/src/ott/tools/gaussian_mixture/probabilities.py +++ b/src/ott/tools/gaussian_mixture/probabilities.py @@ -27,7 +27,7 @@ class Probabilities: to a length n simplex by appending a 0 and taking a softmax. 
""" - _params: jnp.ndarray + _params: jax.Array def __init__(self, params): self._params = params @@ -47,7 +47,7 @@ def from_random( ) @classmethod - def from_probs(cls, probs: jnp.ndarray) -> "Probabilities": + def from_probs(cls, probs: jax.Array) -> "Probabilities": """Construct Probabilities from a vector of probabilities.""" log_probs = jnp.log(probs) log_probs_normalized, norm = log_probs[:-1], log_probs[-1] @@ -62,21 +62,21 @@ def params(self): # noqa: D102 def dtype(self): # noqa: D102 return self._params.dtype - def unnormalized_log_probs(self) -> jnp.ndarray: + def unnormalized_log_probs(self) -> jax.Array: """Get the unnormalized log probabilities.""" return jnp.concatenate([self._params, jnp.zeros((1,), dtype=self.dtype)], axis=-1) - def log_probs(self) -> jnp.ndarray: + def log_probs(self) -> jax.Array: """Get the log probabilities.""" return jax.nn.log_softmax(self.unnormalized_log_probs()) - def probs(self) -> jnp.ndarray: + def probs(self) -> jax.Array: """Get the probabilities.""" return jax.nn.softmax(self.unnormalized_log_probs()) - def sample(self, rng: jax.Array, size: int) -> jnp.ndarray: + def sample(self, rng: jax.Array, size: int) -> jax.Array: """Sample from the distribution.""" return jax.random.categorical( key=rng, logits=self.unnormalized_log_probs(), shape=(size,) diff --git a/src/ott/tools/gaussian_mixture/scale_tril.py b/src/ott/tools/gaussian_mixture/scale_tril.py index b286cc74e..ee708d5ac 100644 --- a/src/ott/tools/gaussian_mixture/scale_tril.py +++ b/src/ott/tools/gaussian_mixture/scale_tril.py @@ -27,16 +27,16 @@ class ScaleTriL: """Pytree for a lower triangular Cholesky-factored covariance matrix.""" - def __init__(self, params: jnp.ndarray, size: int): + def __init__(self, params: jax.Array, size: int): self._params = params self._size = size @classmethod def from_points_and_weights( cls, - points: jnp.ndarray, - weights: jnp.ndarray, - ) -> Tuple[jnp.ndarray, "ScaleTriL"]: + points: jax.Array, + weights: jax.Array, + ) -> 
Tuple[jax.Array, "ScaleTriL"]: """Get a mean and a ScaleTriL from a set of points and weights.""" mean, cov = linalg.get_mean_and_cov(points=points, weights=weights) return mean, cls.from_covariance(cov) @@ -80,7 +80,7 @@ def from_random( return cls(params=flat, size=n_dimensions) @classmethod - def from_cholesky(cls, cholesky: jnp.ndarray) -> "ScaleTriL": + def from_cholesky(cls, cholesky: jax.Array) -> "ScaleTriL": """Construct ScaleTriL from a Cholesky factor of a covariance matrix.""" m = linalg.apply_to_diag(cholesky, jnp.log) flat = linalg.tril_to_flat(m) @@ -89,14 +89,14 @@ def from_cholesky(cls, cholesky: jnp.ndarray) -> "ScaleTriL": @classmethod def from_covariance( cls, - covariance: jnp.ndarray, + covariance: jax.Array, ) -> "ScaleTriL": """Construct ScaleTriL from a covariance matrix.""" cholesky = jnp.linalg.cholesky(covariance) return cls.from_cholesky(cholesky) @property - def params(self) -> jnp.ndarray: + def params(self) -> jax.Array: """Internal representation.""" return self._params @@ -110,34 +110,34 @@ def dtype(self): """Data type of the covariance matrix.""" return self._params.dtype - def cholesky(self) -> jnp.ndarray: + def cholesky(self) -> jax.Array: """Get a lower triangular Cholesky factor for the covariance matrix.""" m = linalg.flat_to_tril(self._params, size=self._size) return linalg.apply_to_diag(m, jnp.exp) - def covariance(self) -> jnp.ndarray: + def covariance(self) -> jax.Array: """Get the covariance matrix.""" cholesky = self.cholesky() return cholesky @ cholesky.T - def covariance_sqrt(self) -> jnp.ndarray: + def covariance_sqrt(self) -> jax.Array: """Get the square root of the covariance matrix.""" return linalg.matrix_powers(self.covariance(), (0.5,))[0] - def log_det_covariance(self) -> jnp.ndarray: + def log_det_covariance(self) -> jax.Array: """Get the log of the determinant of the covariance matrix.""" diag = jnp.diagonal(self.cholesky(), axis1=-2, axis2=-1) return 2. 
* jnp.sum(jnp.log(diag), axis=-1) - def centered_to_z(self, x_centered: jnp.ndarray) -> jnp.ndarray: + def centered_to_z(self, x_centered: jax.Array) -> jax.Array: """Map centered points to standardized centered points (i.e. cov(z) = I).""" return linalg.invmatvectril(m=self.cholesky(), x=x_centered, lower=True) - def z_to_centered(self, z: jnp.ndarray) -> jnp.ndarray: + def z_to_centered(self, z: jax.Array) -> jax.Array: """Scale standardized points to points with the specified covariance.""" return (self.cholesky() @ z.T).T - def w2_dist(self, other: "ScaleTriL") -> jnp.ndarray: + def w2_dist(self, other: "ScaleTriL") -> jax.Array: r"""Wasserstein distance W_2^2 to another Gaussian with same mean. Args: @@ -148,7 +148,7 @@ def w2_dist(self, other: "ScaleTriL") -> jnp.ndarray: """ dimension = self.size - def _flatten_cov(cov: jnp.ndarray) -> jnp.ndarray: + def _flatten_cov(cov: jax.Array) -> jax.Array: cov = cov.reshape(cov.shape[:-2] + (dimension * dimension,)) return jnp.concatenate([jnp.zeros(dimension), cov], axis=-1) @@ -159,7 +159,7 @@ def _flatten_cov(cov: jnp.ndarray) -> jnp.ndarray: ..., ] - def gaussian_map(self, dest_scale: "ScaleTriL") -> jnp.ndarray: + def gaussian_map(self, dest_scale: "ScaleTriL") -> jax.Array: """Scaling matrix used in transport between 0-mean Gaussians. Sigma_mu^{-1/2} @ @@ -179,9 +179,7 @@ def gaussian_map(self, dest_scale: "ScaleTriL") -> jnp.ndarray: ) return jnp.matmul(sqrt0_inv, jnp.matmul(m, sqrt0_inv)) - def transport( - self, dest_scale: "ScaleTriL", points: jnp.ndarray - ) -> jnp.ndarray: + def transport(self, dest_scale: "ScaleTriL", points: jax.Array) -> jax.Array: """Apply Monge map, computed between two 0-mean Gaussians, to points. 
Args: diff --git a/src/ott/tools/k_means.py b/src/ott/tools/k_means.py index 986b919d0..c8fc8189d 100644 --- a/src/ott/tools/k_means.py +++ b/src/ott/tools/k_means.py @@ -25,29 +25,29 @@ __all__ = ["k_means", "KMeansOutput"] Init_t = Union[Literal["k-means++", "random"], - Callable[[pointcloud.PointCloud, int, jnp.ndarray], jnp.ndarray]] + Callable[[pointcloud.PointCloud, int, jax.Array], jax.Array]] class KPPState(NamedTuple): # noqa: D101 rng: jax.Array - centroids: jnp.ndarray - centroid_dists: jnp.ndarray + centroids: jax.Array + centroid_dists: jax.Array class KMeansState(NamedTuple): # noqa: D101 - centroids: jnp.ndarray - prev_assignment: jnp.ndarray - assignment: jnp.ndarray - errors: jnp.ndarray + centroids: jax.Array + prev_assignment: jax.Array + assignment: jax.Array + errors: jax.Array center_shift: float class KMeansConst(NamedTuple): # noqa: D101 geom: pointcloud.PointCloud - x_weights: jnp.ndarray + x_weights: jax.Array @property - def x(self) -> jnp.ndarray: + def x(self) -> jax.Array: """Array of shape ``[n, ndim]`` containing the unweighted point cloud.""" return self.geom.x @@ -57,7 +57,7 @@ def weighted_x(self): return self.x_weights[:, :-1] @property - def weights(self) -> jnp.ndarray: + def weights(self) -> jax.Array: """Array of shape ``[n, 1]`` containing weights for each point.""" return self.x_weights[:, -1:] @@ -75,12 +75,12 @@ class KMeansOutput(NamedTuple): inner_errors: Array of shape ``[max_iterations,]`` containing the ``error`` at every iteration. 
""" - centroids: jnp.ndarray - assignment: jnp.ndarray + centroids: jax.Array + assignment: jax.Array converged: bool iteration: int error: float - inner_errors: Optional[jnp.ndarray] + inner_errors: Optional[jax.Array] @classmethod def _from_state( @@ -110,7 +110,7 @@ def _from_state( def _random_init( geom: pointcloud.PointCloud, k: int, rng: jax.Array -) -> jnp.ndarray: +) -> jax.Array: ixs = jnp.arange(geom.shape[0]) ixs = jax.random.choice(rng, ixs, shape=(k,), replace=False) return geom.subset(ixs, None).x @@ -121,7 +121,7 @@ def _k_means_plus_plus( k: int, rng: jax.Array, n_local_trials: Optional[int] = None, -) -> jnp.ndarray: +) -> jax.Array: def init_fn(geom: pointcloud.PointCloud, rng: jax.Array) -> KPPState: rng, next_rng = jax.random.split(rng, 2) @@ -131,7 +131,7 @@ def init_fn(geom: pointcloud.PointCloud, rng: jax.Array) -> KPPState: return KPPState(rng=next_rng, centroids=centroids, centroid_dists=dists) def body_fn( - iteration: int, const: Tuple[pointcloud.PointCloud, jnp.ndarray], + iteration: int, const: Tuple[pointcloud.PointCloud, jax.Array], state: KPPState, compute_error: bool ) -> KPPState: del compute_error @@ -177,10 +177,10 @@ def body_fn( @functools.partial(jax.vmap, in_axes=[None, 0, 0, 0], out_axes=0) def _reallocate_centroids( const: KMeansConst, - ix: jnp.ndarray, - centroid: jnp.ndarray, - weight: jnp.ndarray, -) -> Tuple[jnp.ndarray, jnp.ndarray]: + ix: jax.Array, + centroid: jax.Array, + weight: jax.Array, +) -> Tuple[jax.Array, jax.Array]: is_empty = weight <= 0. 
new_centroid = (1 - is_empty) * centroid + is_empty * const.x[ix] # (ndim,) centroid_to_remove = is_empty * const.weighted_x[ix] # (ndim,) @@ -190,8 +190,8 @@ def _reallocate_centroids( def _update_assignment( const: KMeansConst, - centroids: jnp.ndarray, -) -> Tuple[jnp.ndarray, jnp.ndarray]: + centroids: jax.Array, +) -> Tuple[jax.Array, jax.Array]: (x, _, *args), aux_data = const.geom.tree_flatten() cost_matrix = type( const.geom @@ -203,9 +203,9 @@ def _update_assignment( def _update_centroids( - const: KMeansConst, k: int, assignment: jnp.ndarray, - dist_to_centers: jnp.ndarray -) -> jnp.ndarray: + const: KMeansConst, k: int, assignment: jax.Array, + dist_to_centers: jax.Array +) -> jax.Array: # TODO(michalk8): # cannot put `k` into `const`, see https://github.com/ott-jax/ott/issues/129 x_weights = jax.ops.segment_sum(const.x_weights, assignment, num_segments=k) @@ -227,7 +227,7 @@ def _k_means( rng: jax.Array, geom: pointcloud.PointCloud, k: int, - weights: Optional[jnp.ndarray] = None, + weights: Optional[jax.Array] = None, init: Init_t = "k-means++", n_local_trials: Optional[int] = None, tol: float = 1e-4, @@ -342,9 +342,9 @@ def finalize_fn(const: KMeansConst, state: KMeansState) -> KMeansState: def k_means( - geom: Union[jnp.ndarray, pointcloud.PointCloud], + geom: Union[jax.Array, pointcloud.PointCloud], k: int, - weights: Optional[jnp.ndarray] = None, + weights: Optional[jax.Array] = None, init: Init_t = "k-means++", n_init: int = 10, n_local_trials: Optional[int] = None, @@ -386,7 +386,7 @@ def k_means( """ assert geom.shape[ 0] >= k, f"Cannot cluster `{geom.shape[0]}` points into `{k}` clusters." 
- if isinstance(geom, jnp.ndarray): + if isinstance(geom, jax.Array): geom = pointcloud.PointCloud(geom) if isinstance(geom.cost_fn, costs.Cosine): geom = geom._cosine_to_sqeucl() diff --git a/src/ott/tools/plot.py b/src/ott/tools/plot.py index bd1f42e91..d83868fd5 100644 --- a/src/ott/tools/plot.py +++ b/src/ott/tools/plot.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import List, Optional, Sequence, Tuple, Union +import jax import jax.numpy as jnp import numpy as np import scipy @@ -32,8 +33,7 @@ gromov_wasserstein.GWOutput] -def bidimensional(x: jnp.ndarray, - y: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: +def bidimensional(x: jax.Array, y: jax.Array) -> Tuple[jax.Array, jax.Array]: """Apply PCA to reduce to bi-dimensional data.""" if x.shape[1] < 3: return x, y @@ -121,7 +121,7 @@ def _scatter(self, ot: Transport): scales_y = b * self._scale * b.shape[0] return x, y, scales_x, scales_y - def _mapping(self, x: jnp.ndarray, y: jnp.ndarray, matrix: jnp.ndarray): + def _mapping(self, x: jax.Array, y: jax.Array, matrix: jax.Array): """Compute the lines representing the mapping between the 2 point clouds.""" # Only plot the lines with a cost above the threshold. 
u, v = jnp.where(matrix > self._threshold) diff --git a/src/ott/tools/segment_sinkhorn.py b/src/ott/tools/segment_sinkhorn.py index 223f2a30f..ca5e5c228 100644 --- a/src/ott/tools/segment_sinkhorn.py +++ b/src/ott/tools/segment_sinkhorn.py @@ -14,7 +14,7 @@ from types import MappingProxyType from typing import Any, Mapping, Optional, Tuple -import jax.numpy as jnp +import jax from ott.geometry import costs, pointcloud, segment from ott.problems.linear import linear_problem @@ -22,21 +22,21 @@ def segment_sinkhorn( - x: jnp.ndarray, - y: jnp.ndarray, + x: jax.Array, + y: jax.Array, num_segments: Optional[int] = None, max_measure_size: Optional[int] = None, cost_fn: Optional[costs.CostFn] = None, - segment_ids_x: Optional[jnp.ndarray] = None, - segment_ids_y: Optional[jnp.ndarray] = None, + segment_ids_x: Optional[jax.Array] = None, + segment_ids_y: Optional[jax.Array] = None, indices_are_sorted: bool = False, num_per_segment_x: Optional[Tuple[int, ...]] = None, num_per_segment_y: Optional[Tuple[int, ...]] = None, - weights_x: Optional[jnp.ndarray] = None, - weights_y: Optional[jnp.ndarray] = None, + weights_x: Optional[jax.Array] = None, + weights_y: Optional[jax.Array] = None, sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs: Any -) -> jnp.ndarray: +) -> jax.Array: """Compute regularized OT cost between subsets of vectors in `x` and `y`. Helper function designed to compute Sinkhorn regularized OT cost between @@ -104,10 +104,10 @@ def segment_sinkhorn( padding_vector = cost_fn._padder(dim=dim) def eval_fn( - padded_x: jnp.ndarray, - padded_y: jnp.ndarray, - padded_weight_x: jnp.ndarray, - padded_weight_y: jnp.ndarray, + padded_x: jax.Array, + padded_y: jax.Array, + padded_weight_x: jax.Array, + padded_weight_y: jax.Array, ) -> float: mask_x = padded_weight_x > 0. mask_y = padded_weight_y > 0. 
diff --git a/src/ott/tools/sinkhorn_divergence.py b/src/ott/tools/sinkhorn_divergence.py index 51de97613..2ff1cbc4e 100644 --- a/src/ott/tools/sinkhorn_divergence.py +++ b/src/ott/tools/sinkhorn_divergence.py @@ -14,6 +14,7 @@ from types import MappingProxyType from typing import Any, Mapping, Optional, Tuple, Type +import jax import jax.numpy as jnp from ott import utils @@ -27,7 +28,7 @@ "SinkhornDivergenceOutput" ] -Potentials_t = Tuple[jnp.ndarray, jnp.ndarray] +Potentials_t = Tuple[jax.Array, jax.Array] @utils.register_pytree_node @@ -35,11 +36,10 @@ class SinkhornDivergenceOutput: # noqa: D101 divergence: float potentials: Tuple[Potentials_t, Potentials_t, Potentials_t] geoms: Tuple[geometry.Geometry, geometry.Geometry, geometry.Geometry] - errors: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], - Optional[jnp.ndarray]] + errors: Tuple[Optional[jax.Array], Optional[jax.Array], Optional[jax.Array]] converged: Tuple[bool, bool, bool] - a: jnp.ndarray - b: jnp.ndarray + a: jax.Array + b: jax.Array n_iters: Tuple[int, int, int] def to_dual_potentials(self) -> "potentials.EntropicPotentials": @@ -73,8 +73,8 @@ def tree_unflatten_foo(cls, aux_data, children): # noqa: D102 def sinkhorn_divergence( geom: Type[geometry.Geometry], *args: Any, - a: Optional[jnp.ndarray] = None, - b: Optional[jnp.ndarray] = None, + a: Optional[jax.Array] = None, + b: Optional[jax.Array] = None, sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), static_b: bool = False, share_epsilon: bool = True, @@ -138,8 +138,8 @@ def _sinkhorn_divergence( geometry_xy: geometry.Geometry, geometry_xx: geometry.Geometry, geometry_yy: Optional[geometry.Geometry], - a: jnp.ndarray, - b: jnp.ndarray, + a: jax.Array, + b: jax.Array, symmetric_sinkhorn: bool, **kwargs: Any, ) -> SinkhornDivergenceOutput: @@ -155,9 +155,9 @@ def _sinkhorn_divergence( between elements of the view X. geometry_yy: a Cost object able to apply kernels with a certain epsilon, between elements of the view Y. 
- a: jnp.ndarray[n]: the weight of each input point. The sum of + a: jax.Array[n]: the weight of each input point. The sum of all elements of ``b`` must match that of ``a`` to converge. - b: jnp.ndarray[m]: the weight of each target point. The sum of + b: jax.Array[m]: the weight of each target point. The sum of all elements of ``b`` must match that of ``a`` to converge. symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for symmetric terms comparing x/x and y/y. @@ -219,24 +219,24 @@ def _sinkhorn_divergence( def segment_sinkhorn_divergence( - x: jnp.ndarray, - y: jnp.ndarray, + x: jax.Array, + y: jax.Array, num_segments: Optional[int] = None, max_measure_size: Optional[int] = None, cost_fn: Optional[costs.CostFn] = None, - segment_ids_x: Optional[jnp.ndarray] = None, - segment_ids_y: Optional[jnp.ndarray] = None, + segment_ids_x: Optional[jax.Array] = None, + segment_ids_y: Optional[jax.Array] = None, indices_are_sorted: bool = False, num_per_segment_x: Optional[Tuple[int, ...]] = None, num_per_segment_y: Optional[Tuple[int, ...]] = None, - weights_x: Optional[jnp.ndarray] = None, - weights_y: Optional[jnp.ndarray] = None, + weights_x: Optional[jax.Array] = None, + weights_y: Optional[jax.Array] = None, sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), static_b: bool = False, share_epsilon: bool = True, symmetric_sinkhorn: bool = False, **kwargs: Any -) -> jnp.ndarray: +) -> jax.Array: """Compute Sinkhorn divergence between subsets of vectors in `x` and `y`. Helper function designed to compute Sinkhorn divergences between several point @@ -313,10 +313,10 @@ def segment_sinkhorn_divergence( padding_vector = cost_fn._padder(dim=dim) def eval_fn( - padded_x: jnp.ndarray, - padded_y: jnp.ndarray, - padded_weight_x: jnp.ndarray, - padded_weight_y: jnp.ndarray, + padded_x: jax.Array, + padded_y: jax.Array, + padded_weight_x: jax.Array, + padded_weight_y: jax.Array, ) -> float: mask_x = padded_weight_x > 0. mask_y = padded_weight_y > 0. 
diff --git a/src/ott/tools/soft_sort.py b/src/ott/tools/soft_sort.py index 646b3eb0c..b5b33e183 100644 --- a/src/ott/tools/soft_sort.py +++ b/src/ott/tools/soft_sort.py @@ -30,14 +30,14 @@ "quantize", "topk_mask", "multivariate_cdf_quantile_maps" ] -Func_t = Callable[[jnp.ndarray], jnp.ndarray] +Func_t = Callable[[jax.Array], jax.Array] def transport_for_sort( - inputs: jnp.ndarray, - weights: Optional[jnp.ndarray] = None, - target_weights: Optional[jnp.ndarray] = None, - squashing_fun: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, + inputs: jax.Array, + weights: Optional[jax.Array] = None, + target_weights: Optional[jax.Array] = None, + squashing_fun: Optional[Callable[[jax.Array], jax.Array]] = None, epsilon: float = 1e-2, **kwargs: Any, ) -> sinkhorn.SinkhornOutput: @@ -83,7 +83,7 @@ def transport_for_sort( return solver(prob) -def apply_on_axis(op, inputs, axis, *args, **kwargs: Any) -> jnp.ndarray: +def apply_on_axis(op, inputs, axis, *args, **kwargs: Any) -> jax.Array: """Apply a differentiable operator on a given axis of the input. Args: @@ -120,8 +120,8 @@ def apply_on_axis(op, inputs, axis, *args, **kwargs: Any) -> jnp.ndarray: def _sort( - inputs: jnp.ndarray, topk: int, num_targets: Optional[int], **kwargs: Any -) -> jnp.ndarray: + inputs: jax.Array, topk: int, num_targets: Optional[int], **kwargs: Any +) -> jax.Array: """Apply the soft sort operator on a one dimensional array.""" num_points = inputs.shape[0] a = jnp.ones((num_points,)) / num_points @@ -145,12 +145,12 @@ def _sort( def sort( - inputs: jnp.ndarray, + inputs: jax.Array, axis: int = -1, topk: int = -1, num_targets: Optional[int] = None, **kwargs: Any, -) -> jnp.ndarray: +) -> jax.Array: r"""Apply the soft sort operator on a given axis of the input. 
For instance: @@ -203,8 +203,8 @@ def sort( def _ranks( - inputs: jnp.ndarray, num_targets, target_weights, **kwargs: Any -) -> jnp.ndarray: + inputs: jax.Array, num_targets, target_weights, **kwargs: Any +) -> jax.Array: """Apply the soft ranks operator on a one dimensional array.""" num_points = inputs.shape[0] if target_weights is None: @@ -220,12 +220,12 @@ def _ranks( def ranks( - inputs: jnp.ndarray, + inputs: jax.Array, axis: int = -1, num_targets: Optional[int] = None, - target_weights: Optional[jnp.ndarray] = None, + target_weights: Optional[jax.Array] = None, **kwargs: Any, -) -> jnp.ndarray: +) -> jax.Array: r"""Apply the soft rank operator on input tensor. For instance: @@ -278,11 +278,11 @@ def ranks( def topk_mask( - inputs: jnp.ndarray, + inputs: jax.Array, axis: int = -1, k: int = 1, **kwargs: Any, -) -> jnp.ndarray: +) -> jax.Array: r"""Soft :math:`\text{top-}k` selection mask. For instance: @@ -337,12 +337,12 @@ def topk_mask( def quantile( - inputs: jnp.ndarray, - q: Optional[Union[float, jnp.ndarray]], + inputs: jax.Array, + q: Optional[Union[float, jax.Array]], axis: Union[int, Tuple[int, ...]] = -1, - weight: Optional[Union[float, jnp.ndarray]] = None, + weight: Optional[Union[float, jax.Array]] = None, **kwargs: Any, -) -> jnp.ndarray: +) -> jax.Array: r"""Apply the soft quantiles operator on the input tensor. 
For instance: @@ -395,8 +395,8 @@ def quantile( """ def _quantile( - inputs: jnp.ndarray, q: float, weight: float, **kwargs - ) -> jnp.ndarray: + inputs: jax.Array, q: float, weight: float, **kwargs + ) -> jax.Array: num_points = inputs.shape[0] q = jnp.array([0.2, 0.5, 0.8]) if q is None else jnp.atleast_1d(q) num_quantiles = q.shape[0] @@ -456,15 +456,15 @@ def _quantile( def multivariate_cdf_quantile_maps( - inputs: jnp.ndarray, + inputs: jax.Array, target_sampler: Optional[Callable[[jax.Array, Tuple[int, int]], - jnp.ndarray]] = None, + jax.Array]] = None, rng: Optional[jax.Array] = None, num_target_samples: Optional[int] = None, cost_fn: Optional[costs.CostFn] = None, epsilon: Optional[float] = None, - input_weights: Optional[jnp.ndarray] = None, - target_weights: Optional[jnp.ndarray] = None, + input_weights: Optional[jax.Array] = None, + target_weights: Optional[jax.Array] = None, **kwargs: Any ) -> Tuple[Func_t, Func_t]: r"""Returns multivariate CDF and quantile maps, given input samples. @@ -534,8 +534,8 @@ def multivariate_cdf_quantile_maps( def _quantile_normalization( - inputs: jnp.ndarray, targets: jnp.ndarray, weights: float, **kwargs: Any -) -> jnp.ndarray: + inputs: jax.Array, targets: jax.Array, weights: float, **kwargs: Any +) -> jax.Array: """Apply soft quantile normalization on a one dimensional array.""" num_points = inputs.shape[0] a = jnp.ones((num_points,)) / num_points @@ -544,12 +544,12 @@ def _quantile_normalization( def quantile_normalization( - inputs: jnp.ndarray, - targets: jnp.ndarray, - weights: Optional[jnp.ndarray] = None, + inputs: jax.Array, + targets: jax.Array, + weights: Optional[jax.Array] = None, axis: int = -1, **kwargs: Any, -) -> jnp.ndarray: +) -> jax.Array: r"""Re-normalize inputs so that its quantiles match those of targets/weights. 
Quantile normalization rearranges the values in inputs to values that match @@ -600,11 +600,11 @@ def quantile_normalization( def sort_with( - inputs: jnp.ndarray, - criterion: jnp.ndarray, + inputs: jax.Array, + criterion: jax.Array, topk: int = -1, **kwargs: Any, -) -> jnp.ndarray: +) -> jax.Array: r"""Sort a multidimensional array according to a real valued criterion. Given ``batch`` vectors of dimension `dim`, to which, for each, a real value @@ -655,7 +655,7 @@ def sort_with( return sort_fn(inputs) -def _quantize(inputs: jnp.ndarray, num_q: int, **kwargs: Any) -> jnp.ndarray: +def _quantize(inputs: jax.Array, num_q: int, **kwargs: Any) -> jax.Array: """Apply the soft quantization operator on a one dimensional array.""" num_points = inputs.shape[0] a = jnp.ones((num_points,)) / num_points @@ -665,11 +665,11 @@ def _quantize(inputs: jnp.ndarray, num_q: int, **kwargs: Any) -> jnp.ndarray: def quantize( - inputs: jnp.ndarray, + inputs: jax.Array, num_levels: int = 10, axis: int = -1, **kwargs: Any, -) -> jnp.ndarray: +) -> jax.Array: r"""Soft quantizes an input according using ``num_levels`` values along axis. The quantization operator consists in concentrating several values around diff --git a/src/ott/types.py b/src/ott/types.py index 7a4c88716..5c4609ec2 100644 --- a/src/ott/types.py +++ b/src/ott/types.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Protocol -import jax.numpy as jnp +import jax __all__ = ["Transport"] @@ -28,11 +28,11 @@ class can however be used in type hints to support duck typing. """ @property - def matrix(self) -> jnp.ndarray: + def matrix(self) -> jax.Array: ... - def apply(self, inputs: jnp.ndarray, axis: int) -> jnp.ndarray: + def apply(self, inputs: jax.Array, axis: int) -> jax.Array: ... - def marginal(self, axis: int = 0) -> jnp.ndarray: + def marginal(self, axis: int = 0) -> jax.Array: ... 
diff --git a/tests/conftest.py b/tests/conftest.py index bc4570343..a8118845c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,6 @@ import jax import jax.experimental -import jax.numpy as jnp import pytest from _pytest.python import Metafunc @@ -69,7 +68,7 @@ def pytest_generate_tests(metafunc: Metafunc) -> None: @pytest.fixture(scope="session") -def rng() -> jnp.ndarray: +def rng() -> jax.Array: return jax.random.PRNGKey(0) diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index 57a4d8874..47446a4fd 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -27,7 +27,7 @@ ts_metrics = None -def _proj(matrix: jnp.ndarray) -> jnp.ndarray: +def _proj(matrix: jax.Array) -> jax.Array: u, _, v_h = jnp.linalg.svd(matrix, full_matrices=False) return u.dot(v_h) diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py index c242b192f..cda2900a8 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -35,7 +35,7 @@ def random_graph( *, return_laplacian: bool = False, directed: bool = False, -) -> jnp.ndarray: +) -> jax.Array: G = random_graphs.fast_gnp_random_graph(n, p, seed=seed, directed=directed) if not directed: assert nx.is_connected(G), "Generated graph is not connected." 
@@ -51,7 +51,7 @@ def random_graph( return jnp.asarray(G.toarray()) -def gt_geometry(G: jnp.ndarray, *, epsilon: float = 1e-2) -> geometry.Geometry: +def gt_geometry(G: jax.Array, *, epsilon: float = 1e-2) -> geometry.Geometry: if not isinstance(G, nx.Graph): G = nx.from_numpy_array(np.asarray(G)) @@ -160,7 +160,7 @@ def test_crank_nicolson_more_stable(self, t: Optional[float], n_steps: int): @pytest.mark.parametrize(("jit", "normalize"), [(False, True), (True, False)]) def test_directed_graph(self, jit: bool, normalize: bool): - def create_graph(G: jnp.ndarray) -> graph.Graph: + def create_graph(G: jax.Array) -> graph.Graph: return graph.Graph.from_graph(G, directed=True, normalize=normalize) G = random_graph(16, p=0.25, directed=True) @@ -181,7 +181,7 @@ def create_graph(G: jnp.ndarray) -> graph.Graph: @pytest.mark.parametrize("normalize", [False, True]) def test_normalize_laplacian(self, directed: bool, normalize: bool): - def laplacian(G: jnp.ndarray) -> jnp.ndarray: + def laplacian(G: jax.Array) -> jax.Array: if directed: G = G + G.T @@ -250,8 +250,8 @@ def test_dense_graph_differentiability( ): def callback( - data: jnp.ndarray, rows: jnp.ndarray, cols: jnp.ndarray, - shape: Tuple[int, int] + data: jax.Array, rows: jax.Array, cols: jax.Array, shape: Tuple[int, + int] ) -> float: G = sparse.BCOO((data, jnp.c_[rows, cols]), shape=shape).todense() diff --git a/tests/geometry/low_rank_test.py b/tests/geometry/low_rank_test.py index 87dd98db2..6b3c36edd 100644 --- a/tests/geometry/low_rank_test.py +++ b/tests/geometry/low_rank_test.py @@ -160,7 +160,7 @@ def test_add_lr_geoms_scale_factor( @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize("fn", [lambda x: x + 10, lambda x: x * 2]) def test_apply_affine_function_efficient( - self, rng: jax.Array, fn: Callable[[jnp.ndarray], jnp.ndarray], axis: int + self, rng: jax.Array, fn: Callable[[jax.Array], jax.Array], axis: int ): n, m, d = 21, 13, 3 rngs = jax.random.split(rng, 3) diff --git 
a/tests/geometry/scaling_cost_test.py b/tests/geometry/scaling_cost_test.py index 94ce97cf4..b60805d34 100644 --- a/tests/geometry/scaling_cost_test.py +++ b/tests/geometry/scaling_cost_test.py @@ -53,7 +53,7 @@ def test_scale_cost_pointcloud( """Test various scale cost options for pointcloud.""" def apply_sinkhorn( - x: jnp.ndarray, y: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray, + x: jax.Array, y: jax.Array, a: jax.Array, b: jax.Array, scale_cost: Union[str, float] ): geom = pointcloud.PointCloud( @@ -120,8 +120,8 @@ def test_scale_cost_geometry(self, scale: Union[str, float]): """Test various scale cost options for geometry.""" def apply_sinkhorn( - cost: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray, - scale_cost: Union[str, float] + cost: jax.Array, a: jax.Array, b: jax.Array, scale_cost: Union[str, + float] ): geom = geometry.Geometry(cost, epsilon=self.eps, scale_cost=scale_cost) prob = linear_problem.LinearProblem(geom, a, b) diff --git a/tests/initializers/linear/sinkhorn_init_test.py b/tests/initializers/linear/sinkhorn_init_test.py index 6acf77f11..73c0ddaaa 100644 --- a/tests/initializers/linear/sinkhorn_init_test.py +++ b/tests/initializers/linear/sinkhorn_init_test.py @@ -80,12 +80,12 @@ def create_ot_problem( def run_sinkhorn( - x: jnp.ndarray, - y: jnp.ndarray, + x: jax.Array, + y: jax.Array, *, initializer: linear_init.SinkhornInitializer, - a: Optional[jnp.ndarray] = None, - b: Optional[jnp.ndarray] = None, + a: Optional[jax.Array] = None, + b: Optional[jax.Array] = None, epsilon: float = 1e-2, lse_mode: bool = True, ) -> sinkhorn.SinkhornOutput: diff --git a/tests/math/matrix_square_root_test.py b/tests/math/matrix_square_root_test.py index fcd557957..2263ea8b9 100644 --- a/tests/math/matrix_square_root_test.py +++ b/tests/math/matrix_square_root_test.py @@ -37,9 +37,9 @@ def _get_random_spd_matrix(dim: int, rng: jax.Array): def _get_test_fn( - fn: Callable[[jnp.ndarray], jnp.ndarray], dim: int, rng: jax.Array, + fn: Callable[[jax.Array], 
jax.Array], dim: int, rng: jax.Array, **kwargs: Any -) -> Callable[[jnp.ndarray], jnp.ndarray]: +) -> Callable[[jax.Array], jax.Array]: # We want to test gradients of a function fn that maps positive definite # matrices to positive definite matrices by comparing them to finite # difference approximations. We'll do so via a test function that @@ -54,7 +54,7 @@ def _get_test_fn( unit = jax.random.normal(key=subrng3, shape=(dim, dim)) unit /= jnp.sqrt(jnp.sum(unit ** 2.)) - def _test_fn(x: jnp.ndarray, **kwargs: Any) -> jnp.ndarray: + def _test_fn(x: jax.Array, **kwargs: Any) -> jax.Array: # m is the product of 2 symmetric, positive definite matrices # so it will be positive definite but not necessarily symmetric m = jnp.matmul(m0, m1 + x * dx) @@ -63,7 +63,7 @@ def _test_fn(x: jnp.ndarray, **kwargs: Any) -> jnp.ndarray: return _test_fn -def _sqrt_plus_inv_sqrt(x: jnp.ndarray) -> jnp.ndarray: +def _sqrt_plus_inv_sqrt(x: jax.Array) -> jax.Array: sqrtm = matrix_square_root.sqrtm(x) return sqrtm[0] + sqrtm[1] diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 2dc9f1e43..a5fdc1c2b 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -46,7 +46,7 @@ def __init__( self.conditions = list(dataloaders.keys()) self.p = p - def __next__(self) -> jnp.ndarray: + def __next__(self) -> jax.Array: self.rng, rng = jax.random.split(self.rng, 2) idx = jax.random.choice(rng, len(self.conditions), p=self.p) return next(self.dataloaders[self.conditions[idx]]) diff --git a/tests/neural/map_estimator_test.py b/tests/neural/map_estimator_test.py index 7c506aa38..0454db751 100644 --- a/tests/neural/map_estimator_test.py +++ b/tests/neural/map_estimator_test.py @@ -13,7 +13,7 @@ # limitations under the License. 
from typing import Optional -import jax.numpy as jnp +import jax import pytest from ott import datasets @@ -34,8 +34,8 @@ def test_map_estimator_convergence(self): # define the fitting loss and the regularizer def fitting_loss( - samples: jnp.ndarray, - mapped_samples: jnp.ndarray, + samples: jax.Array, + mapped_samples: jax.Array, ) -> Optional[float]: r"""Sinkhorn divergence fitting loss.""" div = sinkhorn_divergence.sinkhorn_divergence( diff --git a/tests/neural/meta_initializer_test.py b/tests/neural/meta_initializer_test.py index f711366ec..25a88907e 100644 --- a/tests/neural/meta_initializer_test.py +++ b/tests/neural/meta_initializer_test.py @@ -31,7 +31,7 @@ class MetaMLP(nn.Module): num_hidden_layers: int = 3 @nn.compact - def __call__(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + def __call__(self, a: jax.Array, b: jax.Array) -> jax.Array: dtype = a.dtype z = jnp.concatenate((a, b)) for _ in range(self.num_hidden_layers): @@ -65,12 +65,12 @@ def create_ot_problem( def run_sinkhorn( - x: jnp.ndarray, - y: jnp.ndarray, + x: jax.Array, + y: jax.Array, *, initializer: linear_init.SinkhornInitializer, - a: Optional[jnp.ndarray] = None, - b: Optional[jnp.ndarray] = None, + a: Optional[jax.Array] = None, + b: Optional[jax.Array] = None, epsilon: float = 1e-2, lse_mode: bool = True, ) -> sinkhorn.SinkhornOutput: diff --git a/tests/solvers/linear/continuous_barycenter_test.py b/tests/solvers/linear/continuous_barycenter_test.py index 5512263c7..5c7fabd67 100644 --- a/tests/solvers/linear/continuous_barycenter_test.py +++ b/tests/solvers/linear/continuous_barycenter_test.py @@ -27,7 +27,7 @@ means_and_covs_to_x = jax.vmap(costs.mean_and_cov_to_x, in_axes=[0, 0, None]) -def is_positive_semidefinite(c: jnp.ndarray) -> bool: +def is_positive_semidefinite(c: jax.Array) -> bool: # GPU friendly, eigvals not implemented for non-symmetric matrices w = jnp.linalg.eigvalsh((c + c.T) / 2.0) return jnp.all(w >= 0) @@ -119,8 +119,8 @@ def test_barycenter_jit(self, 
rng: jax.Array, segment_before: bool): @functools.partial(jax.jit, static_argnums=(2, 3)) def barycenter( - y: jnp.ndarray, - b: jnp.ndarray, + y: jax.Array, + b: jax.Array, segment_before: bool, num_per_segment: Tuple[int, ...], ) -> cb.FreeBarycenterState: diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index d80f94251..a608c0d71 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -49,7 +49,7 @@ def test_implicit_differentiation_versus_autodiff( ): epsilon = 0.05 - def loss_g(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True) -> float: + def loss_g(a: jax.Array, x: jax.Array, implicit: bool = True) -> float: implicit_diff = implicit_lib.ImplicitDiff() if implicit else None geom = geometry.Geometry( cost_matrix=jnp.sum(x ** 2, axis=1)[:, jnp.newaxis] + @@ -65,9 +65,7 @@ def loss_g(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True) -> float: ) return solver(prob).reg_ot_cost - def loss_pcg( - a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True - ) -> float: + def loss_pcg(a: jax.Array, x: jax.Array, implicit: bool = True) -> float: implicit_diff = implicit_lib.ImplicitDiff() if implicit else None geom = pointcloud.PointCloud(x, self.y, epsilon=epsilon) prob = linear_problem.LinearProblem( @@ -154,7 +152,7 @@ def test_autograd_sinkhorn( a = a / jnp.sum(a) b = b / jnp.sum(b) - def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float: + def reg_ot(a: jax.Array, b: jax.Array) -> float: geom = pointcloud.PointCloud(x, y, epsilon=1e-1) prob = linear_problem.LinearProblem(geom, a=a, b=b) solver = sinkhorn.Sinkhorn(lse_mode=lse_mode) @@ -190,7 +188,7 @@ def test_gradient_sinkhorn_geometry( delta = delta / jnp.sqrt(jnp.vdot(delta, delta)) eps = 1e-3 # perturbation magnitude - def loss_fn(cm: jnp.ndarray): + def loss_fn(cm: jax.Array): a = jnp.ones(cm.shape[0]) / cm.shape[0] b = jnp.ones(cm.shape[1]) / cm.shape[1] geom = geometry.Geometry(cm, epsilon=0.5) @@ 
-264,8 +262,8 @@ def test_gradient_sinkhorn_euclidean( # Adding some near-zero distances to test proper handling with p_norm=1. y = y.at[0].set(x[0, :] + 1e-3) - def loss_fn(x: jnp.ndarray, - y: jnp.ndarray) -> Tuple[float, sinkhorn.SinkhornOutput]: + def loss_fn(x: jax.Array, + y: jax.Array) -> Tuple[float, sinkhorn.SinkhornOutput]: implicit_diff = implicit_lib.ImplicitDiff() if implicit else None geom = pointcloud.PointCloud(x, y, epsilon=epsilon, cost_fn=cost_fn) prob = linear_problem.LinearProblem(geom, a, b) @@ -320,7 +318,7 @@ def loss_fn(x: jnp.ndarray, def test_autoepsilon_differentiability(self, rng: jax.Array): cost = jax.random.uniform(rng, (15, 17)) - def reg_ot_cost(c: jnp.ndarray) -> float: + def reg_ot_cost(c: jax.Array) -> float: geom = geometry.Geometry(c, epsilon=None) # auto epsilon prob = linear_problem.LinearProblem(geom) return sinkhorn.Sinkhorn()(prob).reg_ot_cost @@ -331,7 +329,7 @@ def reg_ot_cost(c: jnp.ndarray) -> float: @pytest.mark.fast() def test_differentiability_with_jit(self, rng: jax.Array): - def reg_ot_cost(c: jnp.ndarray) -> float: + def reg_ot_cost(c: jax.Array) -> float: geom = geometry.Geometry(c, epsilon=1e-2) prob = linear_problem.LinearProblem(geom) return sinkhorn.Sinkhorn()(prob).reg_ot_cost @@ -385,7 +383,7 @@ def test_apply_transport_jacobian( # general rule, even more so when using backprop. epsilon = 0.01 if lse_mode else 0.1 - def apply_ot(a: jnp.ndarray, x: jnp.ndarray, implicit: bool) -> jnp.ndarray: + def apply_ot(a: jax.Array, x: jax.Array, implicit: bool) -> jax.Array: geom = pointcloud.PointCloud(x, y, epsilon=epsilon) prob = linear_problem.LinearProblem(geom, a, b, tau_a=tau_a, tau_b=tau_b) @@ -488,7 +486,7 @@ def test_potential_jacobian_sinkhorn( # with small epsilon when differentiating. 
epsilon = 0.01 if lse_mode else 0.1 - def loss_from_potential(a: jnp.ndarray, x: jnp.ndarray, implicit: bool): + def loss_from_potential(a: jax.Array, x: jax.Array, implicit: bool): geom = pointcloud.PointCloud(x, y, epsilon=epsilon) prob = linear_problem.LinearProblem(geom, a, b, tau_a=tau_a, tau_b=tau_b) @@ -556,7 +554,7 @@ def test_diff_sinkhorn_x_grid_x_perturbation( a = a.ravel() / jnp.sum(a) b = b.ravel() / jnp.sum(b) - def reg_ot(x: List[jnp.ndarray]) -> float: + def reg_ot(x: List[jax.Array]) -> float: geom = grid.Grid(x=x, epsilon=1.0) prob = linear_problem.LinearProblem(geom, a=a, b=b) solver = sinkhorn.Sinkhorn(threshold=1e-1, lse_mode=lse_mode) @@ -605,7 +603,7 @@ def test_diff_sinkhorn_x_grid_weights_perturbation( b = b.ravel() / jnp.sum(b) geom = grid.Grid(x=x, epsilon=1) - def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float: + def reg_ot(a: jax.Array, b: jax.Array) -> float: prob = linear_problem.LinearProblem(geom, a, b) solver = sinkhorn.Sinkhorn(threshold=1e-3, lse_mode=lse_mode) return solver(prob).reg_ot_cost @@ -667,9 +665,9 @@ def test_potential_jacobian_sinkhorn_precond( epsilon = 0.05 if lse_mode else 0.1 def loss_from_potential( - a: jnp.ndarray, - x: jnp.ndarray, - precondition_fun: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, + a: jax.Array, + x: jax.Array, + precondition_fun: Optional[Callable[[jax.Array], jax.Array]] = None, symmetric: bool = False ) -> float: geom = pointcloud.PointCloud(x, y, epsilon=epsilon) @@ -771,7 +769,7 @@ def test_hessian_sinkhorn( imp_dif = implicit_lib.ImplicitDiff(solver_kwargs=solver_kwargs) - def loss(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True): + def loss(a: jax.Array, x: jax.Array, implicit: bool = True): geom = pointcloud.PointCloud(x, y, epsilon=epsilon) prob = linear_problem.LinearProblem(geom, a, b, tau_a, tau_b) implicit_diff = imp_dif if implicit else None diff --git a/tests/solvers/linear/sinkhorn_misc_test.py b/tests/solvers/linear/sinkhorn_misc_test.py index 
aeb37918b..e88ff4d0c 100644 --- a/tests/solvers/linear/sinkhorn_misc_test.py +++ b/tests/solvers/linear/sinkhorn_misc_test.py @@ -346,12 +346,10 @@ def assert_output_close( ) -> None: """Assert SinkhornOutputs are close.""" x = tuple( - a for a in x - if (a is not None and (isinstance(a, (jnp.ndarray, int)))) + a for a in x if (a is not None and (isinstance(a, (jax.Array, int)))) ) y = tuple( - a for a in y - if (a is not None and (isinstance(a, (jnp.ndarray, int)))) + a for a in y if (a is not None and (isinstance(a, (jax.Array, int)))) ) return chex.assert_trees_all_close(x, y, atol=1e-6, rtol=0) @@ -364,7 +362,7 @@ def assert_output_close( def test_jit_vs_non_jit_bwd(self, implicit: bool): @jax.value_and_grad - def val_grad(a: jnp.ndarray, x: jnp.ndarray) -> float: + def val_grad(a: jax.Array, x: jax.Array) -> float: implicit_diff = implicit_lib.ImplicitDiff() if implicit else None geom = geometry.Geometry( cost_matrix=( diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index 0a2a2fff4..508fedcb2 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -56,7 +56,7 @@ def test_gradient_marginals_fgw_solver(self, jit: bool): geom_y = pointcloud.PointCloud(self.y) geom_xy = pointcloud.PointCloud(self.x_2, self.y_2) - def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): + def reg_gw(a: jax.Array, b: jax.Array, implicit: bool): prob = quadratic_problem.QuadraticProblem( geom_x, geom_y, geom_xy, fused_penalty=self.fused_penalty, a=a, b=b ) @@ -101,9 +101,9 @@ def test_gradient_fgw_solver_geometry(self, lse_mode: bool, is_cost: bool): """Test gradient w.r.t. 
the geometries.""" def reg_gw( - x: jnp.ndarray, y: jnp.ndarray, - xy: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], - fused_penalty: float, a: jnp.ndarray, b: jnp.ndarray, implicit: bool + x: jax.Array, y: jax.Array, xy: Union[jax.Array, Tuple[jax.Array, + jax.Array]], + fused_penalty: float, a: jax.Array, b: jax.Array, implicit: bool ): if is_cost: geom_x = geometry.Geometry(cost_matrix=x) @@ -182,8 +182,8 @@ def test_gradient_fgw_solver_penalty(self): lse_mode = True def reg_gw( - cx: jnp.ndarray, cy: jnp.ndarray, cxy: jnp.ndarray, - fused_penalty: float, a: jnp.ndarray, b: jnp.ndarray, implicit: bool + cx: jax.Array, cy: jax.Array, cxy: jax.Array, fused_penalty: float, + a: jax.Array, b: jax.Array, implicit: bool ) -> float: geom_x = geometry.Geometry(cost_matrix=cx) geom_y = geometry.Geometry(cost_matrix=cy) diff --git a/tests/solvers/quadratic/gw_barycenter_test.py b/tests/solvers/quadratic/gw_barycenter_test.py index 6bc843477..d5dadd691 100644 --- a/tests/solvers/quadratic/gw_barycenter_test.py +++ b/tests/solvers/quadratic/gw_barycenter_test.py @@ -42,9 +42,9 @@ def random_pc( @staticmethod def pad_cost_matrices( - costs: Sequence[jnp.ndarray], + costs: Sequence[jax.Array], shape: Optional[Tuple[int, int]] = None - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array]: if shape is None: shape = jnp.asarray([arr.shape for arr in costs]).max() shape = (shape, shape) @@ -133,7 +133,7 @@ def test_fgw_barycenter( ): def barycenter( - y: jnp.ndim, y_fused: jnp.ndarray, num_per_segment: Tuple[int, ...] + y: jnp.ndim, y_fused: jax.Array, num_per_segment: Tuple[int, ...] 
) -> gwb_solver.GWBarycenterState: prob = gwb.GWBarycenterProblem( y=y, diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py index e7b77cd58..e7d0ff106 100644 --- a/tests/solvers/quadratic/gw_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ -156,8 +156,8 @@ def test_flag_store_errors(self): def test_gradient_marginals_gw(self, jit: bool): """Test gradient w.r.t. probability weights.""" - def reg_gw(a: jnp.ndarray, b: jnp.ndarray, - implicit: bool) -> Tuple[float, Tuple[jnp.ndarray, jnp.ndarray]]: + def reg_gw(a: jax.Array, b: jax.Array, + implicit: bool) -> Tuple[float, Tuple[jax.Array, jax.Array]]: prob = quadratic_problem.QuadraticProblem(geom_x, geom_y, a=a, b=b) implicit_diff = implicit_lib.ImplicitDiff() if implicit else None linear_solver = sinkhorn.Sinkhorn( @@ -245,8 +245,7 @@ def test_gradient_gw_geometry( """Test gradient w.r.t. the geometries.""" def reg_gw( - x: jnp.ndarray, y: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray, - implicit: bool + x: jax.Array, y: jax.Array, a: jax.Array, b: jax.Array, implicit: bool ) -> float: if is_cost: geom_x = geometry.Geometry(cost_matrix=x) diff --git a/tests/solvers/quadratic/lower_bound_test.py b/tests/solvers/quadratic/lower_bound_test.py index ba90d6362..bf32aad87 100644 --- a/tests/solvers/quadratic/lower_bound_test.py +++ b/tests/solvers/quadratic/lower_bound_test.py @@ -118,11 +118,11 @@ def test_lb_pointcloud( ] ) def test_lb_grad( - self, rng: jax.Array, sort_fn: Callable[[jnp.ndarray], jnp.ndarray], + self, rng: jax.Array, sort_fn: Callable[[jax.Array], jax.Array], method: str ): - def fn(x: jnp.ndarray, y: jnp.ndarray) -> float: + def fn(x: jax.Array, y: jax.Array) -> float: geom_x = pointcloud.PointCloud(x) geom_y = pointcloud.PointCloud(y) prob = quadratic_problem.QuadraticProblem(geom_x, geom_y) diff --git a/tests/tools/k_means_test.py b/tests/tools/k_means_test.py index a36c4b5c1..55cacde02 100644 --- a/tests/tools/k_means_test.py +++ b/tests/tools/k_means_test.py @@ 
-31,7 +31,7 @@ def make_blobs( *args: Any, cost_fn: Optional[Literal["sqeucl", "cosine"]] = None, **kwargs: Any -) -> Tuple[Union[jnp.ndarray, pointcloud.PointCloud], jnp.ndarray, jnp.ndarray]: +) -> Tuple[Union[jax.Array, pointcloud.PointCloud], jax.Array, jax.Array]: X, y, c = datasets.make_blobs(*args, return_centers=True, **kwargs) X, y, c = jnp.asarray(X), jnp.asarray(y), jnp.asarray(c) if cost_fn is None: @@ -47,10 +47,10 @@ def make_blobs( def compute_assignment( - x: jnp.ndarray, - centers: jnp.ndarray, - weights: Optional[jnp.ndarray] = None -) -> Tuple[jnp.ndarray, float]: + x: jax.Array, + centers: jax.Array, + weights: Optional[jax.Array] = None +) -> Tuple[jax.Array, float]: if weights is None: weights = jnp.ones(x.shape[0]) cost_matrix = pointcloud.PointCloud(x, centers).cost_matrix @@ -104,7 +104,7 @@ def test_matches_sklearn(self, rng: jax.Array, k: int): def test_initialization_differentiable(self, rng: jax.Array): - def callback(x: jnp.ndarray) -> float: + def callback(x: jax.Array) -> float: geom = pointcloud.PointCloud(x) centers = k_means._k_means_plus_plus(geom, k=3, rng=rng) _, inertia = compute_assignment(x, centers) @@ -336,7 +336,7 @@ def test_k_means_jitting( self, rng: jax.Array, init: Literal["k-means++", "random"] ): - def callback(x: jnp.ndarray) -> k_means.KMeansOutput: + def callback(x: jax.Array) -> k_means.KMeansOutput: return k_means.k_means( x, k=k, init=init, store_inner_errors=True, rng=rng ) @@ -368,7 +368,7 @@ def test_k_means_differentiability( self, rng: jax.Array, jit: bool, force_scan: bool ): - def inertia(x: jnp.ndarray, w: jnp.ndarray) -> float: + def inertia(x: jax.Array, w: jax.Array) -> float: return k_means.k_means( x, k=k, diff --git a/tests/tools/sinkhorn_divergence_test.py b/tests/tools/sinkhorn_divergence_test.py index d46c220d0..07bcf535e 100644 --- a/tests/tools/sinkhorn_divergence_test.py +++ b/tests/tools/sinkhorn_divergence_test.py @@ -403,7 +403,7 @@ def test_gradient_generic_point_cloud_wrapper(self): x 
= jax.random.uniform(rngs[0], (self._num_points[0], self._dim)) y = jax.random.uniform(rngs[1], (self._num_points[1], self._dim)) - def loss_fn(cloud_a: jnp.ndarray, cloud_b: jnp.ndarray) -> float: + def loss_fn(cloud_a: jax.Array, cloud_b: jax.Array) -> float: div = sinkhorn_divergence.sinkhorn_divergence( pointcloud.PointCloud, cloud_a, diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py index 372420a9e..4f3a12c10 100644 --- a/tests/tools/soft_sort_test.py +++ b/tests/tools/soft_sort_test.py @@ -108,7 +108,7 @@ def test_multivariate_cdf_quantiles(self, rng: jax.Array): # Check passing custom sampler, must be still symmetric / centered on {.5}^d # Check passing custom epsilon also works. - def ball_sampler(k: jax.Array, s: Tuple[int, int]) -> jnp.ndarray: + def ball_sampler(k: jax.Array, s: Tuple[int, int]) -> jax.Array: return 0.5 * (jax.random.ball(k, d=s[1], p=4, shape=(s[0],)) + 1.) num_target_samples = 473 @@ -283,7 +283,7 @@ def test_soft_sort_jacobian(self, rng: jax.Array, implicit: bool): z = jax.random.uniform(rngs[0], ((b, n))) random_dir = jax.random.normal(rngs[1], (b,)) / b - def loss_fn(logits: jnp.ndarray) -> float: + def loss_fn(logits: jax.Array) -> float: im_d = None if implicit: # Ridge parameters are only used when using JAX's CG. 
From b075758cd367f4f24b6cd5e3e0dd0ca9b74302ad Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 26 Nov 2023 16:27:49 +0100 Subject: [PATCH 021/186] [ci skip] change init arguments of GENOT and add docstrings to GENOT --- src/ott/neural/solvers/genot.py | 184 ++++++++++++++------------------ src/ott/neural/solvers/otfm.py | 45 ++++---- tests/neural/genot_test.py | 12 +-- 3 files changed, 110 insertions(+), 131 deletions(-) diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index efdf5af29..808293f22 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -17,8 +17,6 @@ Any, Callable, Dict, - Literal, - Mapping, Optional, Tuple, Type, @@ -58,6 +56,37 @@ class GENOT(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): + """The GENOT training class as introduced in :cite:`TODO`. + + Args: + neural_vector_field: Neural vector field parameterized by a neural network. + input_dim: Dimension of the data in the source distribution. + output_dim: Dimension of the data in the target distribution. + cond_dim: Dimension of the conditioning variable. + iterations: Number of iterations. + valid_freq: Frequency of validation. + ot_solver: OT solver to match samples from the source and the target distribution. + epsilon: Entropy regularization term of the OT problem solved by `ot_solver`. + cost_fn: Cost function for the OT problem solved by the `ot_solver`. In the linear case, this is always expected to be of type `str`. If the problem is of quadratic type and `cost_fn` is a string, the `cost_fn` is used for all terms, i.e. both quadratic terms and, if applicable, the linear temr. If of type :class:`dict`, the keys are expected to be `x_cost_fn`, `y_cost_fn`, and if applicable, `xy_cost_fn`. + scale_cost: How to scale the cost matrix for the OT problem solved by the `ot_solver`. In the linear case, this is always expected to be not a :class:`dict`. 
If the problem is of quadratic type and `scale_cost` is a string, the `scale_cost` argument is used for all terms, i.e. both quadratic terms and, if applicable, the linear temr. If of type :class:`dict`, the keys are expected to be `x_scale_cost`, `y_scale_cost`, and if applicable, `xy_scale_cost`. + optimizer: Optimizer for `neural_vector_field`. + flow: Flow between latent distribution and target distribution. + time_sampler: Sampler for the time. + checkpoint_manager: Checkpoint manager. + k_samples_per_x: Number of samples drawn from the conditional distribution of an input sample, see algorithm TODO. + solver_latent_to_data: Linear OT solver to match the latent distribution with the conditional distribution. Only applicable if `k_samples_per_x` is larger than :math:`1`. #TODO: adapt + kwargs_solver_latent_to_data: Keyword arguments for `solver_latent_to_data`. #TODO: adapt + fused_penalty: Fused penalty of the linear/fused term in the Fused Gromov-Wasserstein problem. + tau_a: If :math:`<1`, defines how much unbalanced the problem is + on the first marginal. + tau_b: If :math:`< 1`, defines how much unbalanced the problem is + on the second marginal. + mlp_eta: Neural network to learn the left rescaling function as suggested in :cite:`TODO`. If `None`, the left rescaling factor is not learnt. + mlp_xi: Neural network to learn the right rescaling function as suggested in :cite:`TODO`. If `None`, the right rescaling factor is not learnt. + unbalanced_kwargs: Keyword arguments for the unbalancedness solver. + callback_fn: Callback function. + rng: Random number generator. 
+ """ def __init__( self, @@ -68,95 +97,27 @@ def __init__( iterations: int, valid_freq: int, ot_solver: Type[was_solver.WassersteinSolver], + epsilon: float, + cost_fn: Union[costs.CostFn, Dict[str, costs.CostFn]], + scale_cost: Union[Any, Dict[str, Any]], #TODO: replace `Any` optimizer: Type[optax.GradientTransformation], - checkpoint_manager: Type[checkpoint.CheckpointManager] = None, flow: Type[BaseFlow] = ConstantNoiseFlow(0.0), time_sampler: Type[BaseTimeSampler] = UniformSampler(), - k_noise_per_x: int = 1, - t_offset: float = 1e-5, - epsilon: float = 1e-2, - cost_fn: Union[costs.CostFn, Literal["graph"]] = costs.SqEuclidean(), + checkpoint_manager: Type[checkpoint.CheckpointManager] = None, + k_samples_per_x: int = 1, solver_latent_to_data: Optional[Type[was_solver.WassersteinSolver] ] = None, kwargs_solver_latent_to_data: Dict[str, Any] = types.MappingProxyType({}), - scale_cost: Union[Any, Mapping[str, Any]] = 1.0, fused_penalty: float = 0.0, tau_a: float = 1.0, tau_b: float = 1.0, mlp_eta: Callable[[jax.Array], float] = None, mlp_xi: Callable[[jax.Array], float] = None, unbalanced_kwargs: Dict[str, Any] = {}, - callback: Optional[Callable[[jax.Array, jax.Array, jax.Array], - Any]] = None, - callback_kwargs: Dict[str, Any] = {}, - callback_iters: int = 10, + callback_fn: Optional[Callable[[jax.Array, jax.Array, jax.Array], + Any]] = None, rng: random.PRNGKeyArray = random.PRNGKey(0), - **kwargs: Any, ) -> None: - """The GENOT training class. 
- - Parameters - ---------- - neural_vector_field - Neural vector field - input_dim - Dimension of the source distribution - output_dim - Dimension of the target distribution - cond_dim - Dimension of the condition - iterations - Number of iterations to train - valid_freq - Number of iterations after which to perform a validation step - ot_solver - Solver to match samples from the source to the target distribution - optimizer - Optimizer for the neural vector field - flow - Flow to use in the target space from noise to data. Should be of type - `ConstantNoiseFlow` to recover the setup in the paper TODO. - k_noise_per_x - Number of samples to draw from the conditional distribution - t_offset - Offset for sampling from the time t - epsilon - Entropy regularization parameter for the discrete solver - cost_fn - Cost function to use for the discrete OT solver - solver_latent_to_data - Linear OT solver to match samples from the noise to the conditional distribution - latent_to_data_epsilon - Entropy regularization term for `solver_latent_to_data` - latent_to_data_scale_cost - How to scale the cost matrix for the `solver_latent_to_data` solver - scale_cost - How to scale the cost matrix in each discrete OT problem - graph_kwargs - Keyword arguments for the graph cost computation in case `cost="graph"` - fused_penalty - Penalisation term for the linear term in a Fused GW setting - split_dim - Dimension to split the data into fused term and purely quadratic term in the FGW setting - mlp_eta - Neural network to learn the left rescaling function - mlp_xi - Neural network to learn the right rescaling function - tau_a - Left unbalancedness parameter - tau_b - Right unbalancedness parameter - callback - Callback function - callback_kwargs - Keyword arguments to the callback function - callback_iters - Number of iterations after which to evaluate callback function - seed - Random seed - kwargs - Keyword arguments passed to `setup`, e.g. 
custom choice of optimizers for learning rescaling functions - """ BaseNeuralSolver.__init__( self, iterations=iterations, valid_freq=valid_freq ) @@ -196,7 +157,7 @@ def __init__( self.input_dim = input_dim self.output_dim = output_dim self.cond_dim = cond_dim - self.k_noise_per_x = k_noise_per_x + self.k_noise_per_x = k_samples_per_x # OT data-data matching parameters self.ot_solver = ot_solver @@ -210,14 +171,8 @@ def __init__( self.kwargs_solver_latent_to_data = kwargs_solver_latent_to_data # callback parameteres - self.callback = callback - self.callback_kwargs = callback_kwargs - self.callback_iters = callback_iters - - #TODO: check how to handle this - self.t_offset = t_offset - - self.setup(**kwargs) + self.callbac_fn = callback_fn + self.setup() def setup(self) -> None: """Set up the model. @@ -395,23 +350,24 @@ def transport( source: jax.Array, condition: Optional[jax.Array], rng: random.PRNGKeyArray = random.PRNGKey(0), - diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}), forward: bool = True, + diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}), ) -> Union[jnp.array, diffrax.Solution, Optional[jax.Array]]: - """Transport the distribution. + """Transport data with the learnt plan. - Parameters - ---------- - source - Source distribution to transport - seed - Random seed for sampling from the latent distribution - diffeqsolve_kwargs - Keyword arguments for the ODE solver. + This method pushes-forward the `source` to its conditional distribution by solving the neural ODE parameterized by the :attr:`~ott.neural.solvers.GENOTg.neural_vector_field` from + :attr:`~ott.neural.flows.BaseTimeSampler.low` to :attr:`~ott.neural.flows.BaseTimeSampler.high`. + + Args: + data: Initial condition of the ODE. + condition: Condition of the input data. + rng: random seed for sampling from the latent distribution. + forward: If `True` integrates forward, otherwise backwards. + diffeqsovle_kwargs: Keyword arguments for the ODE solver. 
Returns: - ------- - The transported samples, the solution of the neural ODE, and the rescaling factor. + The push-forward or pull-back distribution defined by the learnt transport plan. + """ if not forward: raise NotImplementedError @@ -449,24 +405,46 @@ def solve_ode(input: jax.Array, cond: jax.Array): return jax.vmap(solve_ode)(latent_batch, cond_input) def _valid_step(self, valid_loader, iter) -> None: + """TODO.""" next(valid_loader) - # TODO: add callback and logging - @property def learn_rescaling(self) -> bool: + """Whether to learn at least one rescaling factor of the marginal distributions.""" return self.mlp_eta is not None or self.mlp_xi is not None def save(self, path: str) -> None: + """Save the model. + + Args: + path: Where to save the model to. + """ raise NotImplementedError def load(self, path: str) -> "GENOT": + """Load a model. + + Args: + path: Where to load the model from. + + Returns: + An instance of :class:`ott.neural.solvers.OTFlowMatching`. + """ raise NotImplementedError + @property def training_logs(self) -> Dict[str, Any]: + """Logs of the training.""" raise NotImplementedError - def sample_noise( #TODO: make more general - self, key: random.PRNGKey, batch_size: int - ) -> jax.Array: #TODO: make more general + def sample_noise(self, key: random.PRNGKey, batch_size: int) -> jax.Array: + """Sample noise from a standard-normal distribution. + + Args: + key: Random key for seeding. + batch_size: Number of samples to draw. + + Returns: + Samples from the standard normal distribution. + """ return random.normal(key, shape=(batch_size, self.output_dim)) diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index 3b5aa3319..8001a00a1 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -41,30 +41,30 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): """Flow matching as introduced in :cite:`TODO, with extension to OT-FM (). 
Args: - neural_vector_field: Neural vector field parameterized by a neural network. - input_dim: Dimension of the input data. - cond_dim: Dimension of the conditioning variable. - iterations: Number of iterations. - valid_freq: Frequency of validation. - ot_solver: OT solver to match samples from the source and the target distribution as proposed in :cite:`TODO`. If `None`, no matching will be performed as proposed in :cite:`TODO`. - flow: Flow between source and target distribution. - time_sampler: Sampler for the time. - optimizer: Optimizer for `neural_vector_field`. - checkpoint_manager: Checkpoint manager. - epsilon: Entropy regularization term for the `ot_solver`. - cost_fn: Cost function for the OT problem solved by the `ot_solver`. - tau_a: If :math:`<1`, defines how much unbalanced the problem is - on the first marginal. - tau_b: If :math:`< 1`, defines how much unbalanced the problem is - on the second marginal. - mlp_eta: Neural network to learn the left rescaling function as suggested in :cite:`TODO`. If `None`, the left rescaling factor is not learnt. - mlp_xi: Neural network to learn the right rescaling function as suggested in :cite:`TODO`. If `None`, the right rescaling factor is not learnt. - unbalanced_kwargs: Keyword arguments for the unbalancedness solver. - callback_fn: Callback function. - rng: Random number generator. + neural_vector_field: Neural vector field parameterized by a neural network. + input_dim: Dimension of the input data. + cond_dim: Dimension of the conditioning variable. + iterations: Number of iterations. + valid_freq: Frequency of validation. + ot_solver: OT solver to match samples from the source and the target distribution as proposed in :cite:`TODO`. If `None`, no matching will be performed as proposed in :cite:`TODO`. + flow: Flow between source and target distribution. + time_sampler: Sampler for the time. + optimizer: Optimizer for `neural_vector_field`. + checkpoint_manager: Checkpoint manager. 
+ epsilon: Entropy regularization term of the OT OT problem solved by the `ot_solver`. + cost_fn: Cost function for the OT problem solved by the `ot_solver`. + tau_a: If :math:`<1`, defines how much unbalanced the problem is + on the first marginal. + tau_b: If :math:`< 1`, defines how much unbalanced the problem is + on the second marginal. + mlp_eta: Neural network to learn the left rescaling function as suggested in :cite:`TODO`. If `None`, the left rescaling factor is not learnt. + mlp_xi: Neural network to learn the right rescaling function as suggested in :cite:`TODO`. If `None`, the right rescaling factor is not learnt. + unbalanced_kwargs: Keyword arguments for the unbalancedness solver. + callback_fn: Callback function. + rng: Random number generator. Returns: - None + None """ @@ -295,6 +295,7 @@ def load(self, path: str) -> "OTFlowMatching": """ raise NotImplementedError + @property def training_logs(self) -> Dict[str, Any]: """Logs of the training.""" raise NotImplementedError diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 183af8419..cf09cad33 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -57,7 +57,7 @@ def test_genot_linear_unconditional( ot_solver=ot_solver, time_sampler=time_sampler, optimizer=optimizer, - k_noise_per_x=k_noise_per_x, + k_samples_per_x=k_noise_per_x, ) genot(genot_data_loader_linear, genot_data_loader_linear) @@ -99,7 +99,7 @@ def test_genot_quad_unconditional( ot_solver=ot_solver, time_sampler=time_sampler, optimizer=optimizer, - k_noise_per_x=k_noise_per_x, + k_samples_per_x=k_noise_per_x, ) genot(genot_data_loader_quad, genot_data_loader_quad) @@ -139,7 +139,7 @@ def test_genot_fused_unconditional( time_sampler=time_sampler, optimizer=optimizer, fused_penalty=0.5, - k_noise_per_x=k_noise_per_x, + k_samples_per_x=k_noise_per_x, ) genot(genot_data_loader_fused, genot_data_loader_fused) @@ -178,7 +178,7 @@ def test_genot_linear_conditional( ot_solver=ot_solver, 
time_sampler=time_sampler, optimizer=optimizer, - k_noise_per_x=k_noise_per_x, + k_samples_per_x=k_noise_per_x, ) genot( genot_data_loader_linear_conditional, @@ -223,7 +223,7 @@ def test_genot_quad_conditional( ot_solver=ot_solver, time_sampler=time_sampler, optimizer=optimizer, - k_noise_per_x=k_noise_per_x, + k_samples_per_x=k_noise_per_x, ) genot(genot_data_loader_quad, genot_data_loader_quad) @@ -263,7 +263,7 @@ def test_genot_fused_conditional( time_sampler=time_sampler, optimizer=optimizer, fused_penalty=0.5, - k_noise_per_x=k_noise_per_x, + k_samples_per_x=k_noise_per_x, ) genot(genot_data_loader_fused, genot_data_loader_fused) From 95e8707bbef94a02639f164f3c6f4f0c39acd4ef Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 26 Nov 2023 16:38:47 +0100 Subject: [PATCH 022/186] [ci skip] split nets into base_models and models --- src/ott/neural/models/__init__.py | 2 +- src/ott/neural/models/base_models.py | 42 ++++++++++++++++++++++++++++ src/ott/neural/models/models.py | 33 ++++++---------------- src/ott/neural/solvers/flows.py | 5 ++++ src/ott/neural/solvers/genot.py | 2 ++ src/ott/neural/solvers/otfm.py | 2 ++ 6 files changed, 60 insertions(+), 26 deletions(-) create mode 100644 src/ott/neural/models/base_models.py diff --git a/src/ott/neural/models/__init__.py b/src/ott/neural/models/__init__.py index ec2ab6f3f..d2a583f34 100644 --- a/src/ott/neural/models/__init__.py +++ b/src/ott/neural/models/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import conjugate_solvers, layers, models +from . 
import base_models, conjugate_solvers, layers, models diff --git a/src/ott/neural/models/base_models.py b/src/ott/neural/models/base_models.py new file mode 100644 index 000000000..daf161abf --- /dev/null +++ b/src/ott/neural/models/base_models.py @@ -0,0 +1,42 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +from typing import Optional + +import flax.linen as nn +import jax + +__all__ = ["BaseNeuralVectorField", "BaseRescalingNet"] + + +class BaseNeuralVectorField(nn.Module, abc.ABC): + + @abc.abstractmethod + def __call__( + self, + t: jax.Array, + x: jax.Array, + condition: Optional[jax.Array] = None, + keys_model: Optional[jax.Array] = None + ) -> jax.Array: # noqa: D102): + pass + + +class BaseRescalingNet(nn.Module, abc.ABC): + + @abc.abstractmethod + def __call___( + self, x: jax.Array, condition: Optional[jax.Array] = None + ) -> jax.Array: + pass diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 5ec8fb292..ea191d99d 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import abc import functools from typing import Any, Callable, Dict, Optional, Sequence, Tuple @@ -28,10 +27,16 @@ from ott.initializers.linear import initializers as lin_init from ott.math import matrix_square_root from ott.neural.models import layers +from ott.neural.models.base_models import ( + BaseNeuralVectorField, + BaseRescalingNet, +) from ott.neural.solvers import neuraldual from ott.problems.linear import linear_problem -__all__ = ["ICNN", "MLP", "MetaInitializer"] +__all__ = [ + "ICNN", "MLP", "MetaInitializer", "NeuralVectorField", "Rescaling_MLP" +] class ICNN(neuraldual.BaseW2NeuralDual): @@ -418,19 +423,6 @@ def __call__(self, x): return nn.Dense(self.out_dim)(x) -class BaseNeuralVectorField(nn.Module, abc.ABC): - - @abc.abstractmethod - def __call__( - self, - t: jax.Array, - x: jax.Array, - condition: Optional[jax.Array] = None, - keys_model: Optional[jax.Array] = None - ) -> jax.Array: # noqa: D102): - pass - - class NeuralVectorField(BaseNeuralVectorField): output_dim: int condition_dim: int @@ -541,16 +533,7 @@ def create_train_state( ) -class BaseRescalingNet(nn.Module, abc.ABC): - - @abc.abstractmethod - def __call___( - self, x: jax.Array, condition: Optional[jax.Array] = None - ) -> jax.Array: - pass - - -class Rescaling_MLP(nn.Module): +class Rescaling_MLP(BaseRescalingNet): hidden_dim: int cond_dim: int is_potential: bool = False diff --git a/src/ott/neural/solvers/flows.py b/src/ott/neural/solvers/flows.py index 6552048fb..148c7b188 100644 --- a/src/ott/neural/solvers/flows.py +++ b/src/ott/neural/solvers/flows.py @@ -16,6 +16,11 @@ import jax import jax.numpy as jnp +__all__ = [ + "BaseFlow", "StraightFlow", "ConstantNoiseFlow", "BrownianNoiseFlow", + "BaseTimeSampler", "UniformSampler", "OffsetUniformSampler" +] + class BaseFlow(abc.ABC): """Base class for all flows. 
diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 808293f22..3c96a269d 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -54,6 +54,8 @@ Match_latent_fn_T = Callable[[jax.random.PRNGKeyArray, jnp.array, jnp.array], Tuple[jnp.array, jnp.array]] +__all__ = ["GENOT"] + class GENOT(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): """The GENOT training class as introduced in :cite:`TODO`. diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index 8001a00a1..27c4fab64 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -36,6 +36,8 @@ ) from ott.solvers import was_solver +__all__ = ["OTFlowMatching"] + class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): """Flow matching as introduced in :cite:`TODO, with extension to OT-FM (). From 3b1791d586a302a80739d76b907c581196ad2e1a Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 26 Nov 2023 16:52:52 +0100 Subject: [PATCH 023/186] [ci skip] add references --- docs/references.bib | 41 +++++++++++++++++++++++++++++++++ src/ott/neural/solvers/genot.py | 6 ++--- src/ott/neural/solvers/otfm.py | 4 ++-- 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/docs/references.bib b/docs/references.bib index 35ba274ba..0df2ad9ce 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -805,3 +805,44 @@ @misc{klein:23 title = {Learning Costs for Structured Monge Displacements}, year = {2023}, } + +@misc{klein_uscidda:23, + author = {Klein, Dominik and Uscidda, Th{\'e}o and Theis, Fabian and Cuturi, Marco}, + doi = {10.48550/arXiv.2310.09254}, + eprint = {2310.09254}, + eprintclass = {stat.ML}, + eprinttype = {arXiv}, + title = {Generative Entropic Neural Optimal Transport To Map Within and Across Spaces}, + year = {2023}, +} + +@misc{lipman:22, + author = {Lipman, Yaron and Chen, Ricky TQ and Ben-Hamu, Heli and Nickel, Maximilian and Le, Matt}, + doi = 
{10.48550/arXiv.2210.02747, + eprint = {2210.02747}, + eprintclass = {stat.ML}, + eprinttype = {arXiv}, + title = {Flow matching for generative modeling}, + year = {2022}, +} + + +@misc{tong:23, + author={Tong, Alexander and Malkin, Nikolay and Huguet, Guillaume and Zhang, Yanlei and Rector-Brooks, Jarrid and Fatras, Kilian and Wolf, Guy and Bengio, Yoshua}, + doi = {TODO}, + eprint = {TODO}, + eprintclass = {TODO}, + eprinttype = {TODO}, + title={Improving and generalizing flow-based generative models with minibatch optimal transport}, + year={2023} +} + +@misc{pooladian:23, + author={Pooladian, Aram-Alexandre and Ben-Hamu, Heli and Domingo-Enrich, Carles and Amos, Brandon and Lipman, Yaron and Chen, Ricky}, + doi = {10.48550/arXiv.2304.14772, + eprint = {2304.14772}, + eprintclass = {stat.ML}, + eprinttype = {arXiv}, + title = {Multisample flow matching: Straightening flows with minibatch couplings}, + year = {2022}, +} diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 3c96a269d..e3f17c7f3 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -58,7 +58,7 @@ class GENOT(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): - """The GENOT training class as introduced in :cite:`TODO`. + """The GENOT training class as introduced in :cite:`klein_uscidda:23`. Args: neural_vector_field: Neural vector field parameterized by a neural network. @@ -83,8 +83,8 @@ class GENOT(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): on the first marginal. tau_b: If :math:`< 1`, defines how much unbalanced the problem is on the second marginal. - mlp_eta: Neural network to learn the left rescaling function as suggested in :cite:`TODO`. If `None`, the left rescaling factor is not learnt. - mlp_xi: Neural network to learn the right rescaling function as suggested in :cite:`TODO`. If `None`, the right rescaling factor is not learnt. + mlp_eta: Neural network to learn the left rescaling function. 
If `None`, the left rescaling factor is not learnt. + mlp_xi: Neural network to learn the right rescaling function. If `None`, the right rescaling factor is not learnt. unbalanced_kwargs: Keyword arguments for the unbalancedness solver. callback_fn: Callback function. rng: Random number generator. diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index 27c4fab64..7faf68c71 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -40,7 +40,7 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): - """Flow matching as introduced in :cite:`TODO, with extension to OT-FM (). + """Flow matching as introduced in :cite:`lipman:22`, with extension to OT-FM (:cite`tong:23`, :cite:`pooladian:23`). Args: neural_vector_field: Neural vector field parameterized by a neural network. @@ -48,7 +48,7 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): cond_dim: Dimension of the conditioning variable. iterations: Number of iterations. valid_freq: Frequency of validation. - ot_solver: OT solver to match samples from the source and the target distribution as proposed in :cite:`TODO`. If `None`, no matching will be performed as proposed in :cite:`TODO`. + ot_solver: OT solver to match samples from the source and the target distribution as proposed in :cite:`tong:23`, :cite:`pooladian:23`. If `None`, no matching will be performed as proposed in :cite:`lipman:22`. flow: Flow between source and target distribution. time_sampler: Sampler for the time. optimizer: Optimizer for `neural_vector_field`. 
From eca77c055024c0e068553aa7bede2c127137da75 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 26 Nov 2023 18:51:33 +0100 Subject: [PATCH 024/186] add tests for learning the rescaling factors --- src/ott/neural/solvers/base_solver.py | 39 +++++++++++++++ src/ott/neural/solvers/genot.py | 2 +- src/ott/neural/solvers/otfm.py | 2 +- tests/neural/genot_test.py | 47 +++++++++++++++++ .../{flow_matching_test.py => otfm_test.py} | 50 ++++++++++++++++++- 5 files changed, 136 insertions(+), 4 deletions(-) rename tests/neural/{flow_matching_test.py => otfm_test.py} (74%) diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index 69f510d81..e18f0db96 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -44,6 +44,7 @@ def __init__(self, iterations: int, valid_freq: int, **_: Any) -> None: @abstractmethod def setup(self, *args: Any, **kwargs: Any) -> None: + """Setup the model.""" pass @abstractmethod @@ -74,6 +75,7 @@ def training_logs(self) -> Dict[str, Any]: class ResampleMixin: + """Mixin class for mini-batch OT in neural optimal transport solvers.""" def __init__(*args, **kwargs): pass @@ -239,6 +241,7 @@ def match_pairs( class UnbalancednessMixin: + """Mixin class to incorporate unbalancedness into neural OT models.""" def __init__( self, @@ -421,3 +424,39 @@ def step_fn( return new_state_eta, new_state_xi, eta_predictions, xi_predictions, loss_a, loss_b return step_fn + + def evaluate_eta( + self, source: jax.Array, condition: Optional[jax.Array] + ) -> jax.Array: + """Evaluate the left learnt rescaling factor. + + Args: + source: Samples from the source distribution to evaluate rescaling function on. + condition: Condition belonging to the samples in the source distribution. + + Returns: + Learnt left rescaling factors. 
+ """ + if self.state_eta is None: + raise ValueError("The left rescaling factor was not parameterized.") + return self.state_xi.apply_fn({"params": self.state_eta.params}, + x=source, + condition=condition) + + def evaluate_xi( + self, target: jax.Array, condition: Optional[jax.Array] + ) -> jax.Array: + """Evaluate the right learnt rescaling factor. + + Args: + target: Samples from the target distribution to evaluate the rescaling function on. + condition: Condition belonging to the samples in the target distribution. + + Returns: + Learnt right rescaling factors. + """ + if self.state_xi is None: + raise ValueError("The right rescaling factor was not parameterized.") + return self.state_xi.apply_fn({"params": self.state_xi.params}, + x=target, + condition=condition) diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index e3f17c7f3..00b01cb6f 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -365,7 +365,7 @@ def transport( condition: Condition of the input data. rng: random seed for sampling from the latent distribution. forward: If `True` integrates forward, otherwise backwards. - diffeqsovle_kwargs: Keyword arguments for the ODE solver. + diffeqsolve_kwargs: Keyword arguments for the ODE solver. Returns: The push-forward or pull-back distribution defined by the learnt transport plan. diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index 7faf68c71..1ca5b0b0e 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -235,7 +235,7 @@ def transport( data: Initial condition of the ODE. condition: Condition of the input data. forward: If `True` integrates forward, otherwise backwards. - diffeqsovle_kwargs: Keyword arguments for the ODE solver. + diffeqsolve_kwargs: Keyword arguments for the ODE solver. Returns: The push-forward or pull-back distribution defined by the learnt transport plan. 
diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index cf09cad33..e13eb0aaf 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -272,3 +272,50 @@ def test_genot_fused_conditional( ) assert isinstance(result_forward, jax.Array) assert jnp.sum(jnp.isnan(result_forward)) == 0 + + @pytest.mark.parametrize("conditional", [False, True]) + def test_genot_linear_learn_rescaling( + self, conditional: bool, genot_data_loader_linear: Iterator, + genot_data_loader_linear_conditional: Iterator + ): + data_loader = genot_data_loader_linear_conditional if conditional else genot_data_loader_linear + + source_lin, source_quad, target_lin, target_quad, condition = next( + genot_data_loader_linear + ) + source_dim = source_lin.shape[1] + target_dim = target_lin.shape[1] + condition_dim = condition.shape[1] if conditional else 0 + + neural_vf = NeuralVectorField( + output_dim=target_dim, + condition_dim=condition_dim, + latent_embed_dim=5, + ) + ot_solver = sinkhorn.Sinkhorn() + time_sampler = UniformSampler() + optimizer = optax.adam(learning_rate=1e-3) + genot = GENOT( + neural_vf, + input_dim=source_dim, + output_dim=target_dim, + cond_dim=condition_dim, + iterations=3, + valid_freq=2, + ot_solver=ot_solver, + time_sampler=time_sampler, + optimizer=optimizer, + ) + genot(data_loader, data_loader) + + source_lin, source_quad, target_lin, target_quad, condition = next( + genot_data_loader_linear + ) + + result_eta = genot.evaluate_eta(source_lin, condition=condition) + assert isinstance(result_eta, jax.Array) + assert jnp.sum(jnp.isnan(result_eta)) == 0 + + result_xi = genot.evaluate_xi(target_lin, condition=condition) + assert isinstance(result_xi, jax.Array) + assert jnp.sum(jnp.isnan(result_xi)) == 0 diff --git a/tests/neural/flow_matching_test.py b/tests/neural/otfm_test.py similarity index 74% rename from tests/neural/flow_matching_test.py rename to tests/neural/otfm_test.py index a1135cf2d..e26920253 100644 --- 
a/tests/neural/flow_matching_test.py +++ b/tests/neural/otfm_test.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Type +from typing import Iterator, Type import jax import jax.numpy as jnp import optax import pytest -from ott.neural.models.models import NeuralVectorField +from ott.neural.models.models import NeuralVectorField, Rescaling_MLP from ott.neural.solvers.flows import ( BaseFlow, BrownianNoiseFlow, @@ -149,3 +149,49 @@ def test_flow_matching_conditional( result_backward = fm.transport(target, condition=condition, forward=False) assert isinstance(result_backward, jax.Array) assert jnp.sum(jnp.isnan(result_backward)) == 0 + + @pytest.mark.parametrize("conditional", [True, False]) + def test_flow_matching_learn_rescaling( + self, conditional: bool, data_loader_gaussian: Iterator, + data_loader_gaussian_conditional: Iterator + ): + data_loader = data_loader_gaussian_conditional if conditional else data_loader_gaussian + neural_vf = NeuralVectorField( + output_dim=2, + condition_dim=0, + latent_embed_dim=5, + ) + ot_solver = sinkhorn.Sinkhorn() + time_sampler = UniformSampler() + flow = ConstantNoiseFlow(1.0) + optimizer = optax.adam(learning_rate=1e-3) + + tau_a = 0.9 + tau_b = 0.2 + mlp_eta = Rescaling_MLP(hidden_dim=4, cond_dim=0) + mlp_xi = Rescaling_MLP(hidden_dim=4, cond_dim=0) + fm = OTFlowMatching( + neural_vf, + input_dim=2, + cond_dim=0, + iterations=3, + valid_freq=2, + ot_solver=ot_solver, + flow=flow, + time_sampler=time_sampler, + optimizer=optimizer, + tau_a=tau_a, + tau_b=tau_b, + mlp_eta=mlp_eta, + mlp_xi=mlp_xi, + ) + fm(data_loader, data_loader) + + source, target, condition = next(data_loader_gaussian) + result_eta = fm.evaluate_eta(source, condition=condition) + assert isinstance(result_eta, jax.Array) + assert jnp.sum(jnp.isnan(result_eta)) == 0 + + result_xi = 
fm.evaluate_xi(target, condition=condition) + assert isinstance(result_xi, jax.Array) + assert jnp.sum(jnp.isnan(result_xi)) == 0 From 62b266655482701b5cb28b8e1520dc3956873041 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 26 Nov 2023 19:38:58 +0100 Subject: [PATCH 025/186] [ci skip] partially fix rescaling factor learning --- src/ott/neural/models/base_models.py | 2 +- src/ott/neural/models/models.py | 18 +++++++++--------- src/ott/neural/solvers/base_solver.py | 21 +++++++++++++++------ src/ott/neural/solvers/flows.py | 3 +++ src/ott/neural/solvers/genot.py | 2 ++ src/ott/neural/solvers/otfm.py | 4 +++- tests/neural/genot_test.py | 8 +++++++- tests/neural/otfm_test.py | 11 +++++++---- 8 files changed, 47 insertions(+), 22 deletions(-) diff --git a/src/ott/neural/models/base_models.py b/src/ott/neural/models/base_models.py index daf161abf..74a87df93 100644 --- a/src/ott/neural/models/base_models.py +++ b/src/ott/neural/models/base_models.py @@ -36,7 +36,7 @@ def __call__( class BaseRescalingNet(nn.Module, abc.ABC): @abc.abstractmethod - def __call___( + def __call__( self, x: jax.Array, condition: Optional[jax.Array] = None ) -> jax.Array: pass diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index ea191d99d..62edee7ef 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -535,8 +535,8 @@ def create_train_state( class Rescaling_MLP(BaseRescalingNet): hidden_dim: int - cond_dim: int - is_potential: bool = False + condition_dim: int + num_layers_per_block: int = 3 act_fn: Callable[[jax.Array], jax.Array] = nn.selu @nn.compact @@ -544,8 +544,8 @@ def __call__( self, x: jax.Array, condition: Optional[jax.Array] ) -> jax.Array: # noqa: D102 x = Block( - dim=self.latent_embed_dim, - out_dim=self.latent_embed_dim, + dim=self.hidden_dim, + out_dim=self.hidden_dim, num_layers=self.num_layers_per_block, act_fn=self.act_fn )( @@ -553,8 +553,8 @@ def __call__( ) if self.condition_dim > 0: condition = 
Block( - dim=self.condition_embed_dim, - out_dim=self.condition_embed_dim, + dim=self.hidden_dim, + out_dim=self.hidden_dim, num_layers=self.num_layers_per_block, act_fn=self.act_fn )( @@ -565,8 +565,8 @@ def __call__( concatenated = x out = Block( - dim=self.joint_hidden_dim, - out_dim=self.joint_hidden_dim, + dim=self.hidden_dim, + out_dim=self.hidden_dim, num_layers=self.num_layers_per_block, act_fn=self.act_fn, )( @@ -582,7 +582,7 @@ def create_train_state( input_dim: int, ) -> train_state.TrainState: params = self.init( - rng, jnp.ones((1, input_dim)), jnp.ones((1, self.cond_dim)) + rng, jnp.ones((1, input_dim)), jnp.ones((1, self.condition_dim)) )["params"] return train_state.TrainState.create( apply_fn=self.apply, params=params, tx=optimizer diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index e18f0db96..996f2fe0e 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -245,6 +245,7 @@ class UnbalancednessMixin: def __init__( self, + rng: jax.Array, source_dim: int, target_dim: int, cond_dim: Optional[int], @@ -261,6 +262,7 @@ def __init__( sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), **_: Any, ) -> None: + self.rng_unbalanced = rng self.source_dim = source_dim self.target_dim = target_dim self.cond_dim = cond_dim @@ -325,14 +327,17 @@ def _resample_unbalanced( return tuple(b[indices] if b is not None else None for b in batch) def _setup(self, source_dim: int, target_dim: int, cond_dim: int): - self.unbalancedness_step_fn = self._get_step_fn() + self.rng_unbalanced, rng_eta, rng_xi = jax.random.split( + self.rng_unbalanced, 3 + ) + self.unbalancedness_step_fn = self._get_rescaling_step_fn() if self.mlp_eta is not None: self.opt_eta = ( self.opt_eta if self.opt_eta is not None else optax.adamw(learning_rate=1e-4, weight_decay=1e-10) ) self.state_eta = self.mlp_eta.create_train_state( - self._key, self.opt_eta, source_dim + cond_dim + rng_eta, self.opt_eta, 
source_dim + cond_dim ) if self.mlp_xi is not None: self.opt_xi = ( @@ -340,19 +345,20 @@ def _setup(self, source_dim: int, target_dim: int, cond_dim: int): optax.adamw(learning_rate=1e-4, weight_decay=1e-10) ) self.state_xi = self.mlp_xi.create_train_state( - self._key, self.opt_xi, target_dim + cond_dim + rng_xi, self.opt_xi, target_dim + cond_dim ) - def _get_step_fn(self) -> Callable: # type:ignore[type-arg] + def _get_rescaling_step_fn(self) -> Callable: # type:ignore[type-arg] def loss_a_fn( params_eta: Optional[jax.Array], apply_fn_eta: Callable[[Dict[str, jax.Array], jax.Array], jax.Array], x: jax.Array, + condition: Optional[jax.Array], a: jax.Array, expectation_reweighting: float, ) -> Tuple[float, jax.Array]: - eta_predictions = apply_fn_eta({"params": params_eta}, x) + eta_predictions = apply_fn_eta({"params": params_eta}, x, condition) return ( optax.l2_loss(eta_predictions[:, 0], a).mean() + optax.l2_loss(jnp.mean(eta_predictions) - expectation_reweighting), @@ -363,10 +369,11 @@ def loss_b_fn( params_xi: Optional[jax.Array], apply_fn_xi: Callable[[Dict[str, jax.Array], jax.Array], jax.Array], x: jax.Array, + condition: Optional[jax.Array], b: jax.Array, expectation_reweighting: float, ) -> Tuple[float, jax.Array]: - xi_predictions = apply_fn_xi({"params": params_xi}, x) + xi_predictions = apply_fn_xi({"params": params_xi}, x, condition) return ( optax.l2_loss(xi_predictions[:, 0], b).mean() + optax.l2_loss(jnp.mean(xi_predictions) - expectation_reweighting), @@ -397,6 +404,7 @@ def step_fn( state_eta.params, state_eta.apply_fn, input_source, + condition, a * len(a), jnp.sum(b), ) @@ -412,6 +420,7 @@ def step_fn( state_xi.params, state_xi.apply_fn, input_target, + condition, b * len(b), jnp.sum(a), ) diff --git a/src/ott/neural/solvers/flows.py b/src/ott/neural/solvers/flows.py index 148c7b188..6450e2c1b 100644 --- a/src/ott/neural/solvers/flows.py +++ b/src/ott/neural/solvers/flows.py @@ -176,6 +176,9 @@ class UniformSampler(BaseTimeSampler): high: 
Upper bound of the uniform distribution. """ + def __init__(self, low: float = 0.0, high: float = 1.0) -> None: + super().__init__(low=low, high=high) + def __call__(self, rng: jax.Array, num_samples: int) -> jax.Array: """Generate `num_samples` samples of the time `math`:t:. diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 00b01cb6f..55888d878 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -120,12 +120,14 @@ def __init__( Any]] = None, rng: random.PRNGKeyArray = random.PRNGKey(0), ) -> None: + rng, rng_unbalanced = random.split(rng) BaseNeuralSolver.__init__( self, iterations=iterations, valid_freq=valid_freq ) ResampleMixin.__init__(self) UnbalancednessMixin.__init__( self, + rng=rng_unbalanced, source_dim=input_dim, target_dim=input_dim, cond_dim=cond_dim, diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index 1ca5b0b0e..74392b2c5 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -93,12 +93,14 @@ def __init__( Any]] = None, rng: random.PRNGKeyArray = random.PRNGKey(0), ) -> None: + rng, rng_unbalanced = random.split(rng) BaseNeuralSolver.__init__( self, iterations=iterations, valid_freq=valid_freq ) ResampleMixin.__init__(self) UnbalancednessMixin.__init__( self, + rng=rng_unbalanced, source_dim=input_dim, target_dim=input_dim, cond_dim=cond_dim, @@ -204,7 +206,7 @@ def __call__(self, train_loader, valid_loader) -> None: ) if self.learn_rescaling: self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, loss_b = self.unbalancedness_step_fn( - batch, tmat.sum(axis=1), tmat.sum(axis=0) + source=batch["source"], target=batch["target"], condition=batch["condition"], a=tmat.sum(axis=1), b=tmat.sum(axis=0), state_eta=self.state_eta, state_xi=self.state_xi, ) if iter % self.valid_freq == 0: self._valid_step(valid_loader, iter) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 
e13eb0aaf..cba3f1ef7 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -18,7 +18,7 @@ import optax import pytest -from ott.neural.models.models import NeuralVectorField +from ott.neural.models.models import NeuralVectorField, Rescaling_MLP from ott.neural.solvers.flows import OffsetUniformSampler, UniformSampler from ott.neural.solvers.genot import GENOT from ott.solvers.linear import sinkhorn @@ -295,6 +295,8 @@ def test_genot_linear_learn_rescaling( ot_solver = sinkhorn.Sinkhorn() time_sampler = UniformSampler() optimizer = optax.adam(learning_rate=1e-3) + mlp_eta = Rescaling_MLP(hidden_dim=4, condition_dim=condition_dim) + mlp_xi = Rescaling_MLP(hidden_dim=4, condition_dim=condition_dim) genot = GENOT( neural_vf, input_dim=source_dim, @@ -305,6 +307,10 @@ def test_genot_linear_learn_rescaling( ot_solver=ot_solver, time_sampler=time_sampler, optimizer=optimizer, + tau_a=tau_a, + tau_b=tau_b, + mlp_eta=mlp_eta, + mlp_xi=mlp_xi, ) genot(data_loader, data_loader) diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index e26920253..9e83251e7 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -156,6 +156,9 @@ def test_flow_matching_learn_rescaling( data_loader_gaussian_conditional: Iterator ): data_loader = data_loader_gaussian_conditional if conditional else data_loader_gaussian + source, target, condition = next(data_loader) + source_dim = source.shape[1] + condition_dim = condition.shape[1] if conditional else 0 neural_vf = NeuralVectorField( output_dim=2, condition_dim=0, @@ -168,12 +171,12 @@ def test_flow_matching_learn_rescaling( tau_a = 0.9 tau_b = 0.2 - mlp_eta = Rescaling_MLP(hidden_dim=4, cond_dim=0) - mlp_xi = Rescaling_MLP(hidden_dim=4, cond_dim=0) + mlp_eta = Rescaling_MLP(hidden_dim=4, condition_dim=condition_dim) + mlp_xi = Rescaling_MLP(hidden_dim=4, condition_dim=condition_dim) fm = OTFlowMatching( neural_vf, - input_dim=2, - cond_dim=0, + input_dim=source_dim, + cond_dim=condition_dim, 
iterations=3, valid_freq=2, ot_solver=ot_solver, From 2ceceea191464d779a95f4ba6b65cc4ed3792595 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 26 Nov 2023 19:58:54 +0100 Subject: [PATCH 026/186] [ci skip] fix rescaling factor learning --- src/ott/neural/solvers/base_solver.py | 19 ++++++------------- src/ott/neural/solvers/genot.py | 22 +++++++++++++++++----- src/ott/neural/solvers/otfm.py | 8 +++++++- tests/neural/genot_test.py | 2 ++ tests/neural/otfm_test.py | 3 +-- 5 files changed, 33 insertions(+), 21 deletions(-) diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index 996f2fe0e..b58710dfd 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -89,9 +89,7 @@ def _resample_data( ) -> Tuple[jax.Array, ...]: """Resample a batch according to coupling `tmat`.""" tmat_flattened = tmat.flatten() - indices = random.choice( - key, len(tmat_flattened), shape=[len(tmat_flattened)] - ) + indices = random.choice(key, len(tmat_flattened), shape=[tmat.shape[0]]) indices_source = indices // tmat.shape[1] indices_target = indices % tmat.shape[1] return tuple( @@ -337,7 +335,7 @@ def _setup(self, source_dim: int, target_dim: int, cond_dim: int): optax.adamw(learning_rate=1e-4, weight_decay=1e-10) ) self.state_eta = self.mlp_eta.create_train_state( - rng_eta, self.opt_eta, source_dim + cond_dim + rng_eta, self.opt_eta, source_dim ) if self.mlp_xi is not None: self.opt_xi = ( @@ -345,7 +343,7 @@ def _setup(self, source_dim: int, target_dim: int, cond_dim: int): optax.adamw(learning_rate=1e-4, weight_decay=1e-10) ) self.state_xi = self.mlp_xi.create_train_state( - rng_xi, self.opt_xi, target_dim + cond_dim + rng_xi, self.opt_xi, target_dim ) def _get_rescaling_step_fn(self) -> Callable: # type:ignore[type-arg] @@ -392,18 +390,13 @@ def step_fn( *, is_training: bool = True, ): - if condition is None: - input_source = source - input_target = target - else: - input_source = 
jnp.concatenate([source, condition], axis=-1) - input_target = jnp.concatenate([target, condition], axis=-1) if state_eta is not None: grad_a_fn = jax.value_and_grad(loss_a_fn, argnums=0, has_aux=True) + print(source.shape, (a * len(a)).shape) (loss_a, eta_predictions), grads_eta = grad_a_fn( state_eta.params, state_eta.apply_fn, - input_source, + source, condition, a * len(a), jnp.sum(b), @@ -419,7 +412,7 @@ def step_fn( (loss_b, xi_predictions), grads_xi = grad_b_fn( state_xi.params, state_xi.apply_fn, - input_target, + target, condition, b * len(b), jnp.sum(a), diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 55888d878..9cdab7a08 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -249,19 +249,25 @@ def __call__(self, train_loader, valid_loader) -> None: (batch["target"], batch["target_q"]), source_is_balanced=(self.tau_a == 1.0) ) + source = jnp.concatenate([ + batch[el] for el in ["source", "source_q"] if batch[el] is not None + ], + axis=1) + target = jnp.concatenate([ + batch[el] for el in ["target", "target_q"] if batch[el] is not None + ], + axis=1) + rng_latent = jax.random.split(rng_noise, batch_size * self.k_noise_per_x) if self.solver_latent_to_data is not None: - target = jnp.concatenate([ - batch[el] for el in ["target", "target_q"] if batch[el] is not None - ], - axis=1) tmats_latent_data = jnp.array( jax.vmap(self.match_latent_to_data_fn, 0, 0)(key=rng_latent, x=batch["latent"], y=target) ) if self.k_noise_per_x > 1: + raise NotImplementedError rng_latent_data_match = jax.random.split( rng_latent_data_match, batch_size ) @@ -290,7 +296,13 @@ def __call__(self, train_loader, valid_loader) -> None: ) if self.learn_rescaling: self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, loss_b = self.unbalancedness_step_fn( - batch, tmat.sum(axis=1), tmat.sum(axis=0) + source=source, + target=target, + condition=batch["condition"], + a=tmat.sum(axis=1), + 
b=tmat.sum(axis=0), + state_eta=self.state_eta, + state_xi=self.state_xi, ) if iteration % self.valid_freq == 0: self._valid_step(valid_loader, iteration) diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index 74392b2c5..25684016f 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -206,7 +206,13 @@ def __call__(self, train_loader, valid_loader) -> None: ) if self.learn_rescaling: self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, loss_b = self.unbalancedness_step_fn( - source=batch["source"], target=batch["target"], condition=batch["condition"], a=tmat.sum(axis=1), b=tmat.sum(axis=0), state_eta=self.state_eta, state_xi=self.state_xi, + source=batch["source"], + target=batch["target"], + condition=batch["condition"], + a=tmat.sum(axis=1), + b=tmat.sum(axis=0), + state_eta=self.state_eta, + state_xi=self.state_xi, ) if iter % self.valid_freq == 0: self._valid_step(valid_loader, iter) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index cba3f1ef7..ea4280f8f 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -295,6 +295,8 @@ def test_genot_linear_learn_rescaling( ot_solver = sinkhorn.Sinkhorn() time_sampler = UniformSampler() optimizer = optax.adam(learning_rate=1e-3) + tau_a = 0.9 + tau_b = 0.2 mlp_eta = Rescaling_MLP(hidden_dim=4, condition_dim=condition_dim) mlp_xi = Rescaling_MLP(hidden_dim=4, condition_dim=condition_dim) genot = GENOT( diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 9e83251e7..1993432f2 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -150,7 +150,7 @@ def test_flow_matching_conditional( assert isinstance(result_backward, jax.Array) assert jnp.sum(jnp.isnan(result_backward)) == 0 - @pytest.mark.parametrize("conditional", [True, False]) + @pytest.mark.parametrize("conditional", [False, True]) def test_flow_matching_learn_rescaling( self, conditional: bool, 
data_loader_gaussian: Iterator, data_loader_gaussian_conditional: Iterator @@ -190,7 +190,6 @@ def test_flow_matching_learn_rescaling( ) fm(data_loader, data_loader) - source, target, condition = next(data_loader_gaussian) result_eta = fm.evaluate_eta(source, condition=condition) assert isinstance(result_eta, jax.Array) assert jnp.sum(jnp.isnan(result_eta)) == 0 From e8f8171aee55a7b167c4e86c9ed59ea517833c01 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Mon, 27 Nov 2023 12:04:36 +0100 Subject: [PATCH 027/186] [ci skip] all tests passing but k_samples_per_x in genot --- src/ott/neural/solvers/base_solver.py | 1 - src/ott/neural/solvers/genot.py | 67 +++++++++++---------- tests/neural/genot_test.py | 83 +++++++++++++++++---------- 3 files changed, 85 insertions(+), 66 deletions(-) diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index b58710dfd..80dc8e616 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -392,7 +392,6 @@ def step_fn( ): if state_eta is not None: grad_a_fn = jax.value_and_grad(loss_a_fn, argnums=0, has_aux=True) - print(source.shape, (a * len(a)).shape) (loss_a, eta_predictions), grads_eta = grad_a_fn( state_eta.params, state_eta.apply_fn, diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 9cdab7a08..cfb1b6158 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -218,15 +218,15 @@ def __call__(self, train_loader, valid_loader) -> None: """Train GENOT.""" batch: Dict[str, jnp.array] = {} for iteration in range(self.iterations): - batch["source"], batch["source_q"], batch["target"], batch[ + batch["source_lin"], batch["source_q"], batch["target_lin"], batch[ "target_q"], batch["condition"] = next(train_loader) self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, rng_step_fn = jax.random.split( self.rng, 6 ) - batch_size = len(batch["source"]) if batch["source"] is not None else 
len( - batch["source_q"] - ) + batch_size = len( + batch["source_lin"] + ) if batch["source_lin"] is not None else len(batch["source_q"]) n_samples = batch_size * self.k_noise_per_x batch["time"] = self.time_sampler(rng_time, n_samples) batch["noise"] = self.sample_noise(rng_noise, n_samples) @@ -237,33 +237,36 @@ def __call__(self, train_loader, valid_loader) -> None: ) tmat = self.match_fn( - batch["source"], batch["source_q"], batch["target"], batch["target_q"] + batch["source_lin"], batch["source_q"], batch["target_lin"], + batch["target_q"] ) - (batch["source"], batch["source_q"], batch["condition"] - ), (batch["target"], - batch["target_q"]) = self._sample_conditional_indices_from_tmap( - rng_resample, - tmat, - self.k_noise_per_x, - (batch["source"], batch["source_q"], batch["condition"]), - (batch["target"], batch["target_q"]), - source_is_balanced=(self.tau_a == 1.0) - ) - source = jnp.concatenate([ - batch[el] for el in ["source", "source_q"] if batch[el] is not None + + batch["source"] = jnp.concatenate([ + batch[el] + for el in ["source_lin", "source_q"] + if batch[el] is not None ], - axis=1) - target = jnp.concatenate([ - batch[el] for el in ["target", "target_q"] if batch[el] is not None + axis=1) + batch["target"] = jnp.concatenate([ + batch[el] + for el in ["target_lin", "target_q"] + if batch[el] is not None ], - axis=1) - + axis=1) + (batch["source"], batch["condition"] + ), (batch["target"],) = self._sample_conditional_indices_from_tmap( + rng_resample, + tmat, + self.k_noise_per_x, (batch["source"], batch["condition"]), + (batch["target"],), + source_is_balanced=(self.tau_a == 1.0) + ) rng_latent = jax.random.split(rng_noise, batch_size * self.k_noise_per_x) if self.solver_latent_to_data is not None: tmats_latent_data = jnp.array( jax.vmap(self.match_latent_to_data_fn, 0, - 0)(key=rng_latent, x=batch["latent"], y=target) + 0)(key=rng_latent, x=batch["latent"], y=batch["target"]) ) if self.k_noise_per_x > 1: @@ -296,8 +299,8 @@ def 
__call__(self, train_loader, valid_loader) -> None: ) if self.learn_rescaling: self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, loss_b = self.unbalancedness_step_fn( - source=source, - target=target, + source=batch["source"], + target=batch["target"], condition=batch["condition"], a=tmat.sum(axis=1), b=tmat.sum(axis=0), @@ -329,27 +332,23 @@ def loss_fn( params: jax.Array, batch: Dict[str, jnp.array], keys_model: random.PRNGKeyArray ): - target = jnp.concatenate([ - batch[el] for el in ["target", "target_q"] if batch[el] is not None - ], - axis=1) x_t = self.flow.compute_xt( - batch["noise"], batch["time"], batch["latent"], target + batch["noise"], batch["time"], batch["latent"], batch["target"] ) apply_fn = functools.partial( state_neural_vector_field.apply_fn, {"params": params} ) cond_input = jnp.concatenate([ - batch[el] - for el in ["source", "source_q", "condition"] - if batch[el] is not None + batch[el] for el in ["source", "condition"] if batch[el] is not None ], axis=1) v_t = jax.vmap(apply_fn)( t=batch["time"], x=x_t, condition=cond_input, keys_model=keys_model ) - u_t = self.flow.compute_ut(batch["time"], batch["latent"], target) + u_t = self.flow.compute_ut( + batch["time"], batch["latent"], batch["target"] + ) return jnp.mean((v_t - u_t) ** 2) keys_model = random.split(key, len(batch["noise"])) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index ea4280f8f..be2aa7a3c 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -18,6 +18,7 @@ import optax import pytest +from ott.geometry import costs from ott.neural.models.models import NeuralVectorField, Rescaling_MLP from ott.neural.solvers.flows import OffsetUniformSampler, UniformSampler from ott.neural.solvers.genot import GENOT @@ -41,7 +42,7 @@ def test_genot_linear_unconditional( neural_vf = NeuralVectorField( output_dim=target_dim, - condition_dim=condition_dim, + condition_dim=source_dim + condition_dim, latent_embed_dim=5, ) 
ot_solver = sinkhorn.Sinkhorn() @@ -55,8 +56,11 @@ def test_genot_linear_unconditional( iterations=3, valid_freq=2, ot_solver=ot_solver, - time_sampler=time_sampler, + epsilon=0.1, + cost_fn=costs.SqEuclidean(), + scale_cost=1.0, optimizer=optimizer, + time_sampler=time_sampler, k_samples_per_x=k_noise_per_x, ) genot(genot_data_loader_linear, genot_data_loader_linear) @@ -82,7 +86,7 @@ def test_genot_quad_unconditional( condition_dim = 0 neural_vf = NeuralVectorField( output_dim=target_dim, - condition_dim=condition_dim, + condition_dim=source_dim + condition_dim, latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) @@ -93,12 +97,14 @@ def test_genot_quad_unconditional( input_dim=source_dim, output_dim=target_dim, cond_dim=condition_dim, - epsilon=None, iterations=3, valid_freq=2, ot_solver=ot_solver, - time_sampler=time_sampler, + epsilon=None, + cost_fn=costs.SqEuclidean(), + scale_cost=1.0, optimizer=optimizer, + time_sampler=time_sampler, k_samples_per_x=k_noise_per_x, ) genot(genot_data_loader_quad, genot_data_loader_quad) @@ -121,11 +127,11 @@ def test_genot_fused_unconditional( condition_dim = 0 neural_vf = NeuralVectorField( output_dim=target_dim, - condition_dim=condition_dim, + condition_dim=source_dim + condition_dim, latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) - time_sampler = UniformSampler() + UniformSampler() optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -136,7 +142,8 @@ def test_genot_fused_unconditional( iterations=3, valid_freq=2, ot_solver=ot_solver, - time_sampler=time_sampler, + cost_fn=costs.SqEuclidean(), + scale_cost=1.0, optimizer=optimizer, fused_penalty=0.5, k_samples_per_x=k_noise_per_x, @@ -144,7 +151,9 @@ def test_genot_fused_unconditional( genot(genot_data_loader_fused, genot_data_loader_fused) result_forward = genot.transport( - source_quad, condition=condition, forward=True + jnp.concatenate((source_lin, source_quad), axis=1), + 
condition=condition, + forward=True ) assert isinstance(result_forward, jax.Array) assert jnp.sum(jnp.isnan(result_forward)) == 0 @@ -176,8 +185,11 @@ def test_genot_linear_conditional( iterations=3, valid_freq=2, ot_solver=ot_solver, - time_sampler=time_sampler, + epsilon=0.1, + cost_fn=costs.SqEuclidean(), + scale_cost=1.0, optimizer=optimizer, + time_sampler=time_sampler, k_samples_per_x=k_noise_per_x, ) genot( @@ -196,17 +208,17 @@ def test_genot_linear_conditional( @pytest.mark.parametrize("k_noise_per_x", [1, 2]) def test_genot_quad_conditional( - self, genot_data_loader_quad: Iterator, k_noise_per_x: int + self, genot_data_loader_quad_conditional: Iterator, k_noise_per_x: int ): source_lin, source_quad, target_lin, target_quad, condition = next( - genot_data_loader_quad + genot_data_loader_quad_conditional ) source_dim = source_quad.shape[1] target_dim = target_quad.shape[1] condition_dim = condition.shape[1] neural_vf = NeuralVectorField( output_dim=target_dim, - condition_dim=condition_dim, + condition_dim=source_dim + condition_dim, latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) @@ -217,15 +229,19 @@ def test_genot_quad_conditional( input_dim=source_dim, output_dim=target_dim, cond_dim=condition_dim, - epsilon=None, iterations=3, valid_freq=2, ot_solver=ot_solver, - time_sampler=time_sampler, + epsilon=None, + cost_fn=costs.SqEuclidean(), + scale_cost=1.0, optimizer=optimizer, + time_sampler=time_sampler, k_samples_per_x=k_noise_per_x, ) - genot(genot_data_loader_quad, genot_data_loader_quad) + genot( + genot_data_loader_quad_conditional, genot_data_loader_quad_conditional + ) result_forward = genot.transport( source_quad, condition=condition, forward=True @@ -235,17 +251,17 @@ def test_genot_quad_conditional( @pytest.mark.parametrize("k_noise_per_x", [1, 2]) def test_genot_fused_conditional( - self, genot_data_loader_fused: Iterator, k_noise_per_x: int + self, genot_data_loader_fused_conditional: Iterator, 
k_noise_per_x: int ): source_lin, source_quad, target_lin, target_quad, condition = next( - genot_data_loader_fused + genot_data_loader_fused_conditional ) source_dim = source_lin.shape[1] + source_quad.shape[1] target_dim = target_lin.shape[1] + target_quad.shape[1] condition_dim = condition.shape[1] neural_vf = NeuralVectorField( output_dim=target_dim, - condition_dim=condition_dim, + condition_dim=source_dim + condition_dim, latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) @@ -256,19 +272,24 @@ def test_genot_fused_conditional( input_dim=source_dim, output_dim=target_dim, cond_dim=condition_dim, - epsilon=None, iterations=3, valid_freq=2, ot_solver=ot_solver, - time_sampler=time_sampler, + epsilon=None, + cost_fn=costs.SqEuclidean(), + scale_cost=1.0, optimizer=optimizer, - fused_penalty=0.5, + time_sampler=time_sampler, k_samples_per_x=k_noise_per_x, ) - genot(genot_data_loader_fused, genot_data_loader_fused) + genot( + genot_data_loader_fused_conditional, genot_data_loader_fused_conditional + ) result_forward = genot.transport( - source_quad, condition=condition, forward=True + jnp.concatenate((source_lin, source_quad), axis=1), + condition=condition, + forward=True ) assert isinstance(result_forward, jax.Array) assert jnp.sum(jnp.isnan(result_forward)) == 0 @@ -281,7 +302,7 @@ def test_genot_linear_learn_rescaling( data_loader = genot_data_loader_linear_conditional if conditional else genot_data_loader_linear source_lin, source_quad, target_lin, target_quad, condition = next( - genot_data_loader_linear + data_loader ) source_dim = source_lin.shape[1] target_dim = target_lin.shape[1] @@ -289,7 +310,7 @@ def test_genot_linear_learn_rescaling( neural_vf = NeuralVectorField( output_dim=target_dim, - condition_dim=condition_dim, + condition_dim=source_dim + condition_dim, latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() @@ -307,18 +328,18 @@ def test_genot_linear_learn_rescaling( iterations=3, valid_freq=2, 
ot_solver=ot_solver, - time_sampler=time_sampler, + epsilon=0.1, + cost_fn=costs.SqEuclidean(), + scale_cost=1.0, optimizer=optimizer, + time_sampler=time_sampler, tau_a=tau_a, tau_b=tau_b, mlp_eta=mlp_eta, mlp_xi=mlp_xi, ) - genot(data_loader, data_loader) - source_lin, source_quad, target_lin, target_quad, condition = next( - genot_data_loader_linear - ) + genot(data_loader, data_loader) result_eta = genot.evaluate_eta(source_lin, condition=condition) assert isinstance(result_eta, jax.Array) From add1348cec204d6ffb7162fda8301b74ed145761 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Mon, 27 Nov 2023 15:37:23 +0100 Subject: [PATCH 028/186] k_samples_per_x working in GENOT --- src/ott/neural/solvers/base_solver.py | 29 ++++++++----- src/ott/neural/solvers/genot.py | 50 ++++++++++------------ tests/neural/genot_test.py | 61 ++++++++++++++++++--------- 3 files changed, 82 insertions(+), 58 deletions(-) diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index 80dc8e616..2c8f7541e 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -22,7 +22,7 @@ from flax.training import train_state from jax import random -from ott.geometry import pointcloud +from ott.geometry import costs, pointcloud from ott.geometry.pointcloud import PointCloud from ott.neural.models import models from ott.problems.linear import linear_problem @@ -108,6 +108,7 @@ def _sample_conditional_indices_from_tmap( *, source_is_balanced: bool, ) -> Tuple[jnp.array, jnp.array]: + batch_size = tmat.shape[0] left_marginals = tmat.sum(axis=1) if not source_is_balanced: key, key2 = jax.random.split(key, 2) @@ -118,12 +119,12 @@ def _sample_conditional_indices_from_tmap( shape=(len(left_marginals),) ) else: - indices = jnp.arange(tmat.shape[0]) + indices = jnp.arange(batch_size) tmat_adapted = tmat[indices] indices_per_row = jax.vmap( lambda tmat_adapted: jax.random.choice( key=key, - a=jnp.arange(tmat.shape[1]), + 
a=jnp.arange(batch_size), p=tmat_adapted, shape=(k_samples_per_x,) ), @@ -134,21 +135,27 @@ def _sample_conditional_indices_from_tmap( ) indices_source = jnp.repeat(indices, k_samples_per_x) - indices_target = indices_per_row % tmat.shape[1] + indices_target = jnp.reshape( + indices_per_row % tmat.shape[1], (batch_size * k_samples_per_x,) + ) return tuple( - b[indices_source, :] if b is not None else None for b in source_arrays + jnp.reshape(b[indices_source], (k_samples_per_x, batch_size, + -1)) if b is not None else None + for b in source_arrays ), tuple( - b[indices_target, :] if b is not None else None for b in target_arrays + jnp.reshape(b[indices_target, :], (k_samples_per_x, batch_size, + -1)) if b is not None else None + for b in target_arrays ) def _get_sinkhorn_match_fn( self, ot_solver: Any, - epsilon: float, - cost_fn: str, - scale_cost: Any, - tau_a: float, - tau_b: float, + epsilon: float = 1e-2, + cost_fn: Any = costs.SqEuclidean(), + scale_cost: Any = "mean", + tau_a: float = 1.0, + tau_b: float = 1.0, *, filter_input: bool = False, ) -> Callable: diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index cfb1b6158..377ef033d 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -137,7 +137,6 @@ def __init__( mlp_xi=mlp_xi, unbalanced_kwargs=unbalanced_kwargs, ) - if isinstance( ot_solver, gromov_wasserstein.GromovWasserstein ) and epsilon is not None: @@ -161,7 +160,7 @@ def __init__( self.input_dim = input_dim self.output_dim = output_dim self.cond_dim = cond_dim - self.k_noise_per_x = k_samples_per_x + self.k_samples_per_x = k_samples_per_x # OT data-data matching parameters self.ot_solver = ot_solver @@ -175,7 +174,7 @@ def __init__( self.kwargs_solver_latent_to_data = kwargs_solver_latent_to_data # callback parameteres - self.callbac_fn = callback_fn + self.callback_fn = callback_fn self.setup() def setup(self) -> None: @@ -227,13 +226,11 @@ def __call__(self, train_loader, valid_loader) 
-> None: batch_size = len( batch["source_lin"] ) if batch["source_lin"] is not None else len(batch["source_q"]) - n_samples = batch_size * self.k_noise_per_x + n_samples = batch_size * self.k_samples_per_x batch["time"] = self.time_sampler(rng_time, n_samples) batch["noise"] = self.sample_noise(rng_noise, n_samples) batch["latent"] = self.latent_noise_fn( - rng_noise, - shape=(batch_size, self.k_noise_per_x) if self.k_noise_per_x > 1 else - (batch_size,) + rng_noise, shape=(self.k_samples_per_x, batch_size) ) tmat = self.match_fn( @@ -253,43 +250,40 @@ def __call__(self, train_loader, valid_loader) -> None: if batch[el] is not None ], axis=1) + + batch = { + k: v + for k, v in batch.items() + if k in ["source", "target", "condition", "time", "noise", "latent"] + } + (batch["source"], batch["condition"] ), (batch["target"],) = self._sample_conditional_indices_from_tmap( rng_resample, tmat, - self.k_noise_per_x, (batch["source"], batch["condition"]), + self.k_samples_per_x, (batch["source"], batch["condition"]), (batch["target"],), source_is_balanced=(self.tau_a == 1.0) ) - rng_latent = jax.random.split(rng_noise, batch_size * self.k_noise_per_x) + jax.random.split(rng_noise, batch_size * self.k_samples_per_x) if self.solver_latent_to_data is not None: tmats_latent_data = jnp.array( jax.vmap(self.match_latent_to_data_fn, 0, - 0)(key=rng_latent, x=batch["latent"], y=batch["target"]) + 0)(x=batch["latent"], y=batch["target"]) ) - if self.k_noise_per_x > 1: - raise NotImplementedError rng_latent_data_match = jax.random.split( - rng_latent_data_match, batch_size + rng_latent_data_match, self.k_samples_per_x + ) + (batch["source"], batch["condition"] + ), (batch["target"],) = jax.vmap(self._resample_data, 0, 0)( + rng_latent_data_match, tmats_latent_data, + (batch["source"], batch["condition"]), (batch["target"],) ) - (batch["source"], batch["source_q"], batch["condition"] - ), (batch["target"], - batch["target_q"]) = jax.vmap(self._resample_data, 0, 0)( - 
rng_latent_data_match, tmats_latent_data, - (batch["source"], batch["source_q"], batch["condition"]), - (batch["target"], batch["target_q"]) - ) - #(batch["source"], batch["source_q"], batch["condition"] - #), (batch["target"], batch["target_q"]) = self._resample_data( - # rng_latent_data_match, tmat_latent_data, - # (batch["source"], batch["source_q"], batch["condition"]), - # (batch["target"], batch["target_q"]) - #) batch = { key: - jnp.reshape(arr, (batch_size * self.k_noise_per_x, + jnp.reshape(arr, (batch_size * self.k_samples_per_x, -1)) if arr is not None else None for key, arr in batch.items() } @@ -374,7 +368,7 @@ def transport( :attr:`~ott.neural.flows.BaseTimeSampler.low` to :attr:`~ott.neural.flows.BaseTimeSampler.high`. Args: - data: Initial condition of the ODE. + source: Data to transport. condition: Condition of the input data. rng: random seed for sampling from the latent distribution. forward: If `True` integrates forward, otherwise backwards. diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index be2aa7a3c..0c4abb55e 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Iterator +from typing import Iterator, Optional import jax import jax.numpy as jnp @@ -29,10 +29,14 @@ class TestGENOT: #TODO: add tests for unbalancedness - @pytest.mark.parametrize("k_noise_per_x", [1, 2]) + @pytest.mark.parametrize("k_samples_per_x", [1, 2]) + @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) def test_genot_linear_unconditional( - self, genot_data_loader_linear: Iterator, k_noise_per_x: int + self, genot_data_loader_linear: Iterator, k_samples_per_x: int, + solver_latent_to_data: Optional[str] ): + solver_latent_to_data = None if solver_latent_to_data is None else sinkhorn.Sinkhorn( + ) source_lin, source_quad, target_lin, target_quad, condition = next( genot_data_loader_linear ) @@ -61,7 +65,8 @@ def test_genot_linear_unconditional( scale_cost=1.0, optimizer=optimizer, time_sampler=time_sampler, - k_samples_per_x=k_noise_per_x, + k_samples_per_x=k_samples_per_x, + solver_latent_to_data=solver_latent_to_data, ) genot(genot_data_loader_linear, genot_data_loader_linear) @@ -74,10 +79,13 @@ def test_genot_linear_unconditional( assert isinstance(result_forward, jax.Array) assert jnp.sum(jnp.isnan(result_forward)) == 0 - @pytest.mark.parametrize("k_noise_per_x", [1, 2]) + @pytest.mark.parametrize("k_samples_per_x", [1, 2]) + @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) def test_genot_quad_unconditional( - self, genot_data_loader_quad: Iterator, k_noise_per_x: int + self, genot_data_loader_quad: Iterator, k_samples_per_x: int, + solver_latent_to_data: Optional[str] ): + None if solver_latent_to_data is None else sinkhorn.Sinkhorn() source_lin, source_quad, target_lin, target_quad, condition = next( genot_data_loader_quad ) @@ -105,7 +113,7 @@ def test_genot_quad_unconditional( scale_cost=1.0, optimizer=optimizer, time_sampler=time_sampler, - k_samples_per_x=k_noise_per_x, + k_samples_per_x=k_samples_per_x, ) genot(genot_data_loader_quad, genot_data_loader_quad) @@ -115,10 +123,13 @@ def 
test_genot_quad_unconditional( assert isinstance(result_forward, jax.Array) assert jnp.sum(jnp.isnan(result_forward)) == 0 - @pytest.mark.parametrize("k_noise_per_x", [1, 2]) + @pytest.mark.parametrize("k_samples_per_x", [1, 2]) + @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) def test_genot_fused_unconditional( - self, genot_data_loader_fused: Iterator, k_noise_per_x: int + self, genot_data_loader_fused: Iterator, k_samples_per_x: int, + solver_latent_to_data: Optional[str] ): + None if solver_latent_to_data is None else sinkhorn.Sinkhorn() source_lin, source_quad, target_lin, target_quad, condition = next( genot_data_loader_fused ) @@ -146,7 +157,7 @@ def test_genot_fused_unconditional( scale_cost=1.0, optimizer=optimizer, fused_penalty=0.5, - k_samples_per_x=k_noise_per_x, + k_samples_per_x=k_samples_per_x, ) genot(genot_data_loader_fused, genot_data_loader_fused) @@ -158,10 +169,13 @@ def test_genot_fused_unconditional( assert isinstance(result_forward, jax.Array) assert jnp.sum(jnp.isnan(result_forward)) == 0 - @pytest.mark.parametrize("k_noise_per_x", [1, 2]) + @pytest.mark.parametrize("k_samples_per_x", [1, 2]) + @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) def test_genot_linear_conditional( - self, genot_data_loader_linear_conditional: Iterator, k_noise_per_x: int + self, genot_data_loader_linear_conditional: Iterator, + k_samples_per_x: int, solver_latent_to_data: Optional[str] ): + None if solver_latent_to_data is None else sinkhorn.Sinkhorn() source_lin, source_quad, target_lin, target_quad, condition = next( genot_data_loader_linear_conditional ) @@ -190,7 +204,7 @@ def test_genot_linear_conditional( scale_cost=1.0, optimizer=optimizer, time_sampler=time_sampler, - k_samples_per_x=k_noise_per_x, + k_samples_per_x=k_samples_per_x, ) genot( genot_data_loader_linear_conditional, @@ -206,10 +220,13 @@ def test_genot_linear_conditional( assert isinstance(result_forward, jax.Array) assert 
jnp.sum(jnp.isnan(result_forward)) == 0 - @pytest.mark.parametrize("k_noise_per_x", [1, 2]) + @pytest.mark.parametrize("k_samples_per_x", [1, 2]) + @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) def test_genot_quad_conditional( - self, genot_data_loader_quad_conditional: Iterator, k_noise_per_x: int + self, genot_data_loader_quad_conditional: Iterator, k_samples_per_x: int, + solver_latent_to_data: Optional[str] ): + None if solver_latent_to_data is None else sinkhorn.Sinkhorn() source_lin, source_quad, target_lin, target_quad, condition = next( genot_data_loader_quad_conditional ) @@ -237,7 +254,7 @@ def test_genot_quad_conditional( scale_cost=1.0, optimizer=optimizer, time_sampler=time_sampler, - k_samples_per_x=k_noise_per_x, + k_samples_per_x=k_samples_per_x, ) genot( genot_data_loader_quad_conditional, genot_data_loader_quad_conditional @@ -249,10 +266,13 @@ def test_genot_quad_conditional( assert isinstance(result_forward, jax.Array) assert jnp.sum(jnp.isnan(result_forward)) == 0 - @pytest.mark.parametrize("k_noise_per_x", [1, 2]) + @pytest.mark.parametrize("k_samples_per_x", [1, 2]) + @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) def test_genot_fused_conditional( - self, genot_data_loader_fused_conditional: Iterator, k_noise_per_x: int + self, genot_data_loader_fused_conditional: Iterator, k_samples_per_x: int, + solver_latent_to_data: Optional[str] ): + None if solver_latent_to_data is None else sinkhorn.Sinkhorn() source_lin, source_quad, target_lin, target_quad, condition = next( genot_data_loader_fused_conditional ) @@ -280,7 +300,7 @@ def test_genot_fused_conditional( scale_cost=1.0, optimizer=optimizer, time_sampler=time_sampler, - k_samples_per_x=k_noise_per_x, + k_samples_per_x=k_samples_per_x, ) genot( genot_data_loader_fused_conditional, genot_data_loader_fused_conditional @@ -295,10 +315,13 @@ def test_genot_fused_conditional( assert jnp.sum(jnp.isnan(result_forward)) == 0 
@pytest.mark.parametrize("conditional", [False, True]) + @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) def test_genot_linear_learn_rescaling( self, conditional: bool, genot_data_loader_linear: Iterator, + solver_latent_to_data: Optional[str], genot_data_loader_linear_conditional: Iterator ): + None if solver_latent_to_data is None else sinkhorn.Sinkhorn() data_loader = genot_data_loader_linear_conditional if conditional else genot_data_loader_linear source_lin, source_quad, target_lin, target_quad, condition = next( From 993d1de2f9f982e1571af8f6c79bcc6806f5214b Mon Sep 17 00:00:00 2001 From: lucaeyring Date: Mon, 27 Nov 2023 20:16:22 +0100 Subject: [PATCH 029/186] [ci skip] changed dataloaders to numpy and dict return --- src/ott/neural/models/models.py | 1 + src/ott/neural/solvers/base_solver.py | 2 +- src/ott/neural/solvers/otfm.py | 69 +++++-- tests/neural/conftest.py | 248 ++++++++++++++++---------- tests/neural/otfm_test.py | 28 +-- 5 files changed, 219 insertions(+), 129 deletions(-) diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 62edee7ef..6bb075ff3 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -552,6 +552,7 @@ def __call__( x ) if self.condition_dim > 0: + condition = jnp.atleast_1d(condition) condition = Block( dim=self.hidden_dim, out_dim=self.hidden_dim, diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index 2c8f7541e..e7216da46 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -447,7 +447,7 @@ def evaluate_eta( """ if self.state_eta is None: raise ValueError("The left rescaling factor was not parameterized.") - return self.state_xi.apply_fn({"params": self.state_eta.params}, + return self.state_eta.apply_fn({"params": self.state_eta.params}, x=source, condition=condition) diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index 
25684016f..6e6f419ca 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict import functools import types from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type @@ -35,6 +36,7 @@ BaseTimeSampler, ) from ott.solvers import was_solver +from ott.tools.sinkhorn_divergence import sinkhorn_divergence __all__ = ["OTFlowMatching"] @@ -63,6 +65,7 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): mlp_xi: Neural network to learn the right rescaling function as suggested in :cite:`TODO`. If `None`, the right rescaling factor is not learnt. unbalanced_kwargs: Keyword arguments for the unbalancedness solver. callback_fn: Callback function. + num_eval_samples: Number of samples to evaluate on during evaluation. rng: Random number generator. 
Returns: @@ -76,7 +79,6 @@ def __init__( input_dim: int, cond_dim: int, iterations: int, - valid_freq: int, ot_solver: Optional[Type[was_solver.WassersteinSolver]], flow: Type[BaseFlow], time_sampler: Type[BaseTimeSampler], @@ -91,6 +93,9 @@ def __init__( unbalanced_kwargs: Dict[str, Any] = {}, callback_fn: Optional[Callable[[jax.Array, jax.Array, jax.Array], Any]] = None, + logging_freq: int = 100, + valid_freq: int = 5000, + num_eval_samples: int = 1000, rng: random.PRNGKeyArray = random.PRNGKey(0), ) -> None: rng, rng_unbalanced = random.split(rng) @@ -122,6 +127,9 @@ def __init__( self.callback_fn = callback_fn self.checkpoint_manager = checkpoint_manager self.rng = rng + self.logging_freq = logging_freq + self.num_eval_samples = num_eval_samples + self._training_logs: Mapping[str, Any] = defaultdict(list) self.setup() @@ -146,6 +154,7 @@ def setup(self) -> None: def _get_step_fn(self) -> Callable: + @jax.jit def step_fn( key: random.PRNGKeyArray, state_neural_vector_field: train_state.TrainState, @@ -157,17 +166,17 @@ def loss_fn( batch: Dict[str, jax.Array], keys_model: random.PRNGKeyArray ) -> jax.Array: - x_t = self.flow.compute_xt(noise, t, batch["source"], batch["target"]) + x_t = self.flow.compute_xt(noise, t, batch["source_lin"], batch["target_lin"]) apply_fn = functools.partial( state_neural_vector_field.apply_fn, {"params": params} ) v_t = jax.vmap(apply_fn)( - t=t, x=x_t, condition=batch["condition"], keys_model=keys_model + t=t, x=x_t, condition=batch["source_conditions"], keys_model=keys_model ) - u_t = self.flow.compute_ut(t, batch["source"], batch["target"]) + u_t = self.flow.compute_ut(t, batch["source_lin"], batch["target_lin"]) return jnp.mean((v_t - u_t) ** 2) - batch_size = len(batch["source"]) + batch_size = len(batch["source_lin"]) key_noise, key_t, key_model = random.split(key, 3) keys_model = random.split(key_model, batch_size) t = self.time_sampler(key_t, batch_size) @@ -191,24 +200,50 @@ def __call__(self, train_loader, valid_loader) 
-> None: None """ batch: Mapping[str, jax.Array] = {} + curr_loss = 0.0 + """ + if self.num_eval_samples > 0: + eval_batch_source, eval_batch_target = [], [] + for iter in range(self.num_eval_samples): + batch = next( + valid_loader + ) + eval_batch_source.append(batch["source_lin"]) + eval_batch_target.append(batch["target_lin"]) + eval_batch_source = jnp.stack(eval_batch_source) + eval_batch_target = jnp.stack(eval_batch_target) + self._training_logs["data_sink_div"].append( + sinkhorn_divergence( + eval_batch_source, + eval_batch_target, + self.epsilon, + self.cost_fn, + self.scale_cost, + ) + )""" + for iter in range(self.iterations): rng_resample, rng_step_fn, self.rng = random.split(self.rng, 3) - batch["source"], batch["target"], batch["condition"] = next(train_loader) + batch = next(train_loader) if self.ot_solver is not None: - tmat = self.match_fn(batch["source"], batch["target"]) - (batch["source"], - batch["condition"]), (batch["target"],) = self._resample_data( - rng_resample, tmat, (batch["source"], batch["condition"]), - (batch["target"],) + tmat = self.match_fn(batch["source_lin"], batch["target_lin"]) + (batch["source_lin"], + batch["source_conditions"]), (batch["target_lin"], batch["target_conditions"]) = self._resample_data( + rng_resample, tmat, (batch["source_lin"], batch["source_conditions"]), + (batch["target_lin"], batch["target_conditions"]) ) self.state_neural_vector_field, loss = self.step_fn( rng_step_fn, self.state_neural_vector_field, batch ) + curr_loss += loss + if iter % self.logging_freq == 0: + self._training_logs["loss"].append(curr_loss / self.logging_freq) + curr_loss = 0.0 if self.learn_rescaling: self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, loss_b = self.unbalancedness_step_fn( - source=batch["source"], - target=batch["target"], - condition=batch["condition"], + source=batch["source_lin"], + target=batch["target_lin"], + condition=batch["source_conditions"], a=tmat.sum(axis=1), b=tmat.sum(axis=0), 
state_eta=self.state_eta, @@ -220,8 +255,8 @@ def __call__(self, train_loader, valid_loader) -> None: states_to_save = { "state_neural_vector_field": self.state_neural_vector_field } - if self.state_mlp is not None: - states_to_save["state_eta"] = self.state_mlp + if self.state_eta is not None: + states_to_save["state_eta"] = self.state_eta if self.state_xi is not None: states_to_save["state_xi"] = self.state_xi self.checkpoint_manager.save(iter, states_to_save) @@ -254,6 +289,7 @@ def transport( t0, t1 = (self.time_sampler.low, self.time_sampler.high ) if forward else (self.time_sampler.high, self.time_sampler.low) + @jax.jit def solve_ode(input: jax.Array, cond: jax.Array): return diffrax.diffeqsolve( diffrax.ODETerm( @@ -280,6 +316,7 @@ def solve_ode(input: jax.Array, cond: jax.Array): def _valid_step(self, valid_loader, iter) -> None: next(valid_loader) # TODO: add callback and logging + @property def learn_rescaling(self) -> bool: diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index a5fdc1c2b..008036790 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -1,7 +1,6 @@ -from typing import Dict, Iterator, Optional +from typing import Dict, Iterator, Mapping, Optional -import jax -import jax.numpy as jnp +import numpy as np import pytest @@ -9,95 +8,105 @@ class DataLoader: def __init__( self, - source_data: jax.Array, - target_data: jax.Array, - conditions: Optional[jax.Array], - batch_size: int = 64 + source_data: np.ndarray, + target_data: np.ndarray, + batch_size: int = 64, + source_conditions: Optional[np.ndarray] = None, + target_conditions: Optional[np.ndarray] = None, ) -> None: super().__init__() self.source_data = source_data self.target_data = target_data - self.conditions = conditions + self.source_conditions = source_conditions + self.target_conditions = target_conditions self.batch_size = batch_size - self.key = jax.random.PRNGKey(0) + self.rng = np.random.default_rng(seed=0) - def __next__(self) -> jax.Array: - 
key, self.key = jax.random.split(self.key) - inds_source = jax.random.choice( - key, len(self.source_data), shape=[self.batch_size] + def __next__(self) -> Mapping[str, np.ndarray]: + inds_source = self.rng.choice( + len(self.source_data), size=[self.batch_size] ) - inds_target = jax.random.choice( - key, len(self.target_data), shape=[self.batch_size] + inds_target = self.rng.choice( + len(self.target_data), size=[self.batch_size] ) - return self.source_data[inds_source, :], self.target_data[ - inds_target, :], self.conditions[ - inds_source, :] if self.conditions is not None else None + return { + "source_lin": + self.source_data[inds_source, :], + "target_lin": + self.target_data[inds_target, :], + "source_conditions": + self.source_conditions[inds_source, :] + if self.source_conditions is not None else None, + "target_conditions": + self.target_conditions[inds_target, :] + if self.target_conditions is not None else None, + } class ConditionalDataLoader: - def __init__( - self, rng: jax.random.KeyArray, dataloaders: Dict[str, Iterator], - p: jax.Array - ) -> None: + def __init__(self, dataloaders: Dict[str, Iterator], p: np.ndarray) -> None: super().__init__() - self.rng = rng self.dataloaders = dataloaders self.conditions = list(dataloaders.keys()) self.p = p - - def __next__(self) -> jax.Array: - self.rng, rng = jax.random.split(self.rng, 2) - idx = jax.random.choice(rng, len(self.conditions), p=self.p) + self.rng = np.random.default_rng(seed=0) + + def __next__(self, cond: str = None) -> Mapping[str, np.ndarray]: + if cond is not None: + if cond not in self.conditions: + raise ValueError(f"Condition {cond} not in {self.conditions}") + return next(self.dataloaders[cond]) + idx = self.rng.choice(len(self.conditions), p=self.p) return next(self.dataloaders[self.conditions[idx]]) @pytest.fixture(scope="module") def data_loader_gaussian(): """Returns a data loader for a simple Gaussian mixture.""" - source = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) 
- target = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 1.0 - return DataLoader(source, target, None, 16) + rng = np.random.default_rng(seed=0) + source = rng.normal(size=(100, 2)) + target = rng.normal(size=(100, 2)) + 1.0 + return DataLoader(source, target, 16) @pytest.fixture(scope="module") def data_loader_gaussian_conditional(): """Returns a data loader for Gaussian mixtures with conditions.""" - source_0 = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) - target_0 = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 2.0 - - source_1 = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) - target_1 = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) - 2.0 - dl0 = DataLoader(source_0, target_0, jnp.zeros_like(source_0) * 0.0, 16) - dl1 = DataLoader(source_1, target_1, jnp.ones_like(source_1) * 1.0, 16) - - return ConditionalDataLoader( - jax.random.PRNGKey(0), { - "0": dl0, - "1": dl1 - }, jnp.array([0.5, 0.5]) - ) + rng = np.random.default_rng(seed=0) + source_0 = rng.normal(size=(100, 2)) + target_0 = rng.normal(size=(100, 2)) + 2.0 + + source_1 = rng.normal(size=(100, 2)) + target_1 = rng.normal(size=(100, 2)) - 2.0 + dl0 = DataLoader(source_0, target_0, 16, source_conditions=np.zeros_like(source_0) * 0.0) + dl1 = DataLoader(source_1, target_1, 16, source_conditions=np.ones_like(source_1) * 1.0) + + return ConditionalDataLoader({"0": dl0, "1": dl1}, np.array([0.5, 0.5])) @pytest.fixture(scope="module") def data_loader_gaussian_with_conditions(): """Returns a data loader for a simple Gaussian mixture with conditions.""" - source = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) - conditions = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 1)) - target = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 1.0 - return DataLoader(source, target, conditions, 16) + rng = np.random.default_rng(seed=0) + source = rng.normal(size=(100, 2)) + target = rng.normal(size=(100, 2)) + 1.0 + source_conditions = 
rng.normal(size=(100, 1)) + target_conditions = rng.normal(size=(100, 1)) - 1.0 + return DataLoader(source, target, 16, source_conditions, target_conditions) class GENOTDataLoader: def __init__( self, - source_lin: Optional[jax.Array], - source_quad: Optional[jax.Array], - target_lin: Optional[jax.Array], - target_quad: Optional[jax.Array], - conditions: Optional[jax.Array], - batch_size: int = 64 + batch_size: int = 64, + source_lin: Optional[np.ndarray] = None, + source_quad: Optional[np.ndarray] = None, + target_lin: Optional[np.ndarray] = None, + target_quad: Optional[np.ndarray] = None, + source_conditions: Optional[np.ndarray] = None, + target_conditions: Optional[np.ndarray] = None, ) -> None: super().__init__() if source_lin is not None: @@ -108,8 +117,8 @@ def __init__( self.n_source = len(source_lin) else: self.n_source = len(source_quad) - if conditions is not None: - assert len(conditions) == self.n_source + if source_conditions is not None: + assert len(source_conditions) == self.n_source if target_lin is not None: if target_quad is not None: assert len(target_lin) == len(target_quad) @@ -118,83 +127,126 @@ def __init__( self.n_target = len(target_lin) else: self.n_target = len(target_quad) + if target_conditions is not None: + assert len(target_conditions) == self.n_target self.source_lin = source_lin self.target_lin = target_lin self.source_quad = source_quad self.target_quad = target_quad - self.conditions = conditions + self.source_conditions = source_conditions + self.target_conditions = target_conditions self.batch_size = batch_size - self.key = jax.random.PRNGKey(0) - - def __next__(self) -> jax.Array: - key, self.key = jax.random.split(self.key) - inds_source = jax.random.choice(key, self.n_source, shape=[self.batch_size]) - inds_target = jax.random.choice(key, self.n_target, shape=[self.batch_size]) - return self.source_lin[ - inds_source, : - ] if self.source_lin is not None else None, self.source_quad[ - inds_source, : - ] if 
self.source_quad is not None else None, self.target_lin[ - inds_target, : - ] if self.target_lin is not None else None, self.target_quad[ - inds_target, : - ] if self.target_quad is not None else None, self.conditions[ - inds_source, :] if self.conditions is not None else None + self.rng = np.random.default_rng(seed=0) + + def __next__(self) -> Mapping[str, np.ndarray]: + inds_source = self.rng.choice(self.n_source, size=[self.batch_size]) + inds_target = self.rng.choice(self.n_target, size=[self.batch_size]) + return { + "source_lin": + self.source_lin[inds_source, :] + if self.source_lin is not None else None, + "source_quad": + self.source_quad[inds_source, :] + if self.source_quad is not None else None, + "target_lin": + self.target_lin[inds_target, :] + if self.target_lin is not None else None, + "target_quad": + self.target_quad[inds_target, :] + if self.target_quad is not None else None, + "source_conditions": + self.source_conditions[inds_source, :] + if self.source_conditions is not None else None, + "target_conditions": + self.target_conditions[inds_target, :] + if self.target_conditions is not None else None, + } @pytest.fixture(scope="module") def genot_data_loader_linear(): """Returns a data loader for a simple Gaussian mixture.""" - source = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) - target = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 1.0 - return GENOTDataLoader(source, None, target, None, None, 16) + rng = np.random.default_rng(seed=0) + source = rng.normal(size=(100, 2)) + target = rng.normal(size=(100, 2)) + 1.0 + return GENOTDataLoader(16, source_lin=source, target_lin=target) @pytest.fixture(scope="module") def genot_data_loader_linear_conditional(): """Returns a data loader for a simple Gaussian mixture.""" - source = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) - target = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 1.0 - conditions = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 
4)) - return GENOTDataLoader(source, None, target, None, conditions, 16) + rng = np.random.default_rng(seed=0) + source = rng.normal(size=(100, 2)) + target = rng.normal(size=(100, 2)) + 1.0 + conditions_source = rng.normal(size=(100, 4)) + conditions_target = rng.normal(size=(100, 4)) - 1.0 + return GENOTDataLoader( + 16, + source_lin=source, + target_lin=target, + conditions_source=conditions_source, + conditions_target=conditions_target + ) @pytest.fixture(scope="module") def genot_data_loader_quad(): """Returns a data loader for a simple Gaussian mixture.""" - source = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) - target = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 1)) + 1.0 - return GENOTDataLoader(None, source, None, target, None, 16) + rng = np.random.default_rng(seed=0) + source = rng.normal(size=(100, 2)) + target = rng.normal(size=(100, 1)) + 1.0 + return GENOTDataLoader(16, source_quad=source, target_quad=target) @pytest.fixture(scope="module") def genot_data_loader_quad_conditional(): """Returns a data loader for a simple Gaussian mixture.""" - source = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) - target = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 1)) + 1.0 - conditions = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 7)) - return GENOTDataLoader(None, source, None, target, conditions, 16) + rng = np.random.default_rng(seed=0) + source = rng.normal(size=(100, 2)) + target = rng.normal(size=(100, 1)) + 1.0 + conditions = rng.normal(size=(100, 7)) + return GENOTDataLoader( + 16, + source_quad=source, + target_quad=target, + source_conditions=conditions, + target_conditions=conditions + ) @pytest.fixture(scope="module") def genot_data_loader_fused(): """Returns a data loader for a simple Gaussian mixture.""" - source_q = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) - target_q = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 1)) + 1.0 - source_lin = 
jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) - target_lin = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 1.0 - return GENOTDataLoader(source_lin, source_q, target_lin, target_q, None, 16) + rng = np.random.default_rng(seed=0) + source_q = rng.normal(size=(100, 2)) + target_q = rng.normal(size=(100, 1)) + 1.0 + source_lin = rng.normal(size=(100, 2)) + target_lin = rng.normal(size=(100, 2)) + 1.0 + return GENOTDataLoader( + 16, + source_lin=source_lin, + source_quad=source_q, + target_lin=target_lin, + target_quad=target_q + ) @pytest.fixture(scope="module") def genot_data_loader_fused_conditional(): """Returns a data loader for a simple Gaussian mixture.""" - source_q = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) - target_q = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 1)) + 1.0 - source_lin = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) - target_lin = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2)) + 1.0 - conditions = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 1)) + rng = np.random.default_rng(seed=0) + source_q = rng.normal(size=(100, 2)) + target_q = rng.normal(size=(100, 1)) + 1.0 + source_lin = rng.normal(size=(100, 2)) + target_lin = rng.normal(size=(100, 2)) + 1.0 + conditions = rng.normal(size=(100, 7)) return GENOTDataLoader( - source_lin, source_q, target_lin, target_q, conditions, 16 + 16, + source_lin=source_lin, + source_quad=source_q, + target_lin=target_lin, + target_quad=target_q, + source_conditions=conditions, + target_conditions=conditions ) diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 1993432f2..9a75cb3fb 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -60,12 +60,12 @@ def test_flow_matching(self, data_loader_gaussian, flow: Type[BaseFlow]): ) fm(data_loader_gaussian, data_loader_gaussian) - source, target, condition = next(data_loader_gaussian) - result_forward = fm.transport(source, condition=condition, 
forward=True) + batch = next(data_loader_gaussian) + result_forward = fm.transport(batch["source_lin"], condition=batch["source_conditions"], forward=True) assert isinstance(result_forward, jax.Array) assert jnp.sum(jnp.isnan(result_forward)) == 0 - result_backward = fm.transport(target, condition=condition, forward=False) + result_backward = fm.transport(batch["target_lin"], condition=batch["target_conditions"], forward=False) assert isinstance(result_backward, jax.Array) assert jnp.sum(jnp.isnan(result_backward)) == 0 @@ -102,12 +102,12 @@ def test_flow_matching_with_conditions( data_loader_gaussian_with_conditions ) - source, target, condition = next(data_loader_gaussian_with_conditions) - result_forward = fm.transport(source, condition=condition, forward=True) + batch = next(data_loader_gaussian_with_conditions) + result_forward = fm.transport(batch["source_lin"], condition=batch["source_conditions"], forward=True) assert isinstance(result_forward, jax.Array) assert jnp.sum(jnp.isnan(result_forward)) == 0 - result_backward = fm.transport(target, condition=condition, forward=False) + result_backward = fm.transport(batch["target_lin"], condition=batch["target_conditions"], forward=False) assert isinstance(result_backward, jax.Array) assert jnp.sum(jnp.isnan(result_backward)) == 0 @@ -141,12 +141,12 @@ def test_flow_matching_conditional( ) fm(data_loader_gaussian_conditional, data_loader_gaussian_conditional) - source, target, condition = next(data_loader_gaussian_conditional) - result_forward = fm.transport(source, condition=condition, forward=True) + batch = next(data_loader_gaussian_conditional) + result_forward = fm.transport(batch["source_lin"], condition=batch["source_conditions"], forward=True) assert isinstance(result_forward, jax.Array) assert jnp.sum(jnp.isnan(result_forward)) == 0 - result_backward = fm.transport(target, condition=condition, forward=False) + result_backward = fm.transport(batch["target_lin"], condition=batch["target_conditions"], 
forward=False) assert isinstance(result_backward, jax.Array) assert jnp.sum(jnp.isnan(result_backward)) == 0 @@ -156,9 +156,9 @@ def test_flow_matching_learn_rescaling( data_loader_gaussian_conditional: Iterator ): data_loader = data_loader_gaussian_conditional if conditional else data_loader_gaussian - source, target, condition = next(data_loader) - source_dim = source.shape[1] - condition_dim = condition.shape[1] if conditional else 0 + batch = next(data_loader) + source_dim = batch["source_lin"].shape[1] + condition_dim = batch["source_conditions"].shape[1] if conditional else 0 neural_vf = NeuralVectorField( output_dim=2, condition_dim=0, @@ -190,10 +190,10 @@ def test_flow_matching_learn_rescaling( ) fm(data_loader, data_loader) - result_eta = fm.evaluate_eta(source, condition=condition) + result_eta = fm.evaluate_eta(batch["source_lin"], condition=batch["source_conditions"]) assert isinstance(result_eta, jax.Array) assert jnp.sum(jnp.isnan(result_eta)) == 0 - result_xi = fm.evaluate_xi(target, condition=condition) + result_xi = fm.evaluate_xi(batch["target_lin"], condition=batch["target_conditions"]) assert isinstance(result_xi, jax.Array) assert jnp.sum(jnp.isnan(result_xi)) == 0 From beee22dff6c5741efc7b3269e0eb9de8d609ec65 Mon Sep 17 00:00:00 2001 From: lucaeyring Date: Mon, 27 Nov 2023 20:17:47 +0100 Subject: [PATCH 030/186] [ci skip] changed dataloaders to numpy and dict return --- src/ott/neural/solvers/otfm.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index 6e6f419ca..82a6b67aa 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -201,26 +201,6 @@ def __call__(self, train_loader, valid_loader) -> None: """ batch: Mapping[str, jax.Array] = {} curr_loss = 0.0 - """ - if self.num_eval_samples > 0: - eval_batch_source, eval_batch_target = [], [] - for iter in range(self.num_eval_samples): - batch = next( - valid_loader - ) - 
eval_batch_source.append(batch["source_lin"]) - eval_batch_target.append(batch["target_lin"]) - eval_batch_source = jnp.stack(eval_batch_source) - eval_batch_target = jnp.stack(eval_batch_target) - self._training_logs["data_sink_div"].append( - sinkhorn_divergence( - eval_batch_source, - eval_batch_target, - self.epsilon, - self.cost_fn, - self.scale_cost, - ) - )""" for iter in range(self.iterations): rng_resample, rng_step_fn, self.rng = random.split(self.rng, 3) From f26e07292a1a9cbfb31a79326edce114c1b636dd Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 28 Nov 2023 10:07:06 +0100 Subject: [PATCH 031/186] revert jax.Array to jnp.ndarray --- docs/tutorials/GWLRSinkhorn.ipynb | 2 +- docs/tutorials/Hessians.ipynb | 2 +- docs/tutorials/Monge_Gap.ipynb | 12 +- docs/tutorials/One_Sinkhorn.ipynb | 4 +- .../tutorials/basic_ot_between_datasets.ipynb | 2 +- docs/tutorials/point_clouds.ipynb | 4 +- .../sinkhorn_divergence_gradient_flow.ipynb | 4 +- .../sparse_monge_displacements.ipynb | 2 +- src/ott/datasets.py | 8 +- src/ott/geometry/costs.py | 166 +++++++-------- src/ott/geometry/geometry.py | 190 +++++++++--------- src/ott/geometry/graph.py | 44 ++-- src/ott/geometry/grid.py | 40 ++-- src/ott/geometry/low_rank.py | 60 +++--- src/ott/geometry/pointcloud.py | 110 +++++----- src/ott/geometry/segment.py | 33 +-- src/ott/initializers/linear/initializers.py | 58 +++--- .../initializers/linear/initializers_lr.py | 100 ++++----- .../initializers/quadratic/initializers.py | 4 +- src/ott/math/fixed_point_loop.py | 2 +- src/ott/math/matrix_square_root.py | 51 ++--- src/ott/math/unbalanced_functions.py | 23 +-- src/ott/math/utils.py | 22 +- src/ott/neural/data/dataloaders.py | 4 +- src/ott/neural/models/base_models.py | 18 +- src/ott/neural/models/conjugate_solvers.py | 13 +- src/ott/neural/models/layers.py | 12 +- src/ott/neural/models/models.py | 52 ++--- src/ott/neural/solvers/base_solver.py | 95 ++++----- src/ott/neural/solvers/flows.py | 31 +-- 
src/ott/neural/solvers/genot.py | 18 +- src/ott/neural/solvers/losses.py | 8 +- src/ott/neural/solvers/map_estimator.py | 30 +-- src/ott/neural/solvers/neuraldual.py | 56 +++--- src/ott/neural/solvers/otfm.py | 49 +++-- src/ott/problems/linear/barycenter_problem.py | 20 +- src/ott/problems/linear/linear_problem.py | 13 +- src/ott/problems/linear/potentials.py | 34 ++-- src/ott/problems/quadratic/gw_barycenter.py | 50 ++--- src/ott/problems/quadratic/quadratic_costs.py | 3 +- .../problems/quadratic/quadratic_problem.py | 32 +-- src/ott/solvers/linear/_solve.py | 6 +- src/ott/solvers/linear/acceleration.py | 8 +- .../solvers/linear/continuous_barycenter.py | 24 +-- src/ott/solvers/linear/discrete_barycenter.py | 18 +- .../linear/implicit_differentiation.py | 31 +-- src/ott/solvers/linear/lineax_implicit.py | 4 +- src/ott/solvers/linear/lr_utils.py | 42 ++-- src/ott/solvers/linear/sinkhorn.py | 88 ++++---- src/ott/solvers/linear/sinkhorn_lr.py | 131 ++++++------ src/ott/solvers/linear/univariate.py | 14 +- src/ott/solvers/quadratic/_solve.py | 6 +- .../solvers/quadratic/gromov_wasserstein.py | 24 +-- .../quadratic/gromov_wasserstein_lr.py | 135 +++++++------ src/ott/solvers/quadratic/gw_barycenter.py | 34 ++-- src/ott/tools/gaussian_mixture/fit_gmm.py | 38 ++-- .../tools/gaussian_mixture/fit_gmm_pair.py | 26 +-- src/ott/tools/gaussian_mixture/gaussian.py | 32 +-- .../gaussian_mixture/gaussian_mixture.py | 41 ++-- .../gaussian_mixture/gaussian_mixture_pair.py | 6 +- src/ott/tools/gaussian_mixture/linalg.py | 34 ++-- .../tools/gaussian_mixture/probabilities.py | 14 +- src/ott/tools/gaussian_mixture/scale_tril.py | 38 ++-- src/ott/tools/k_means.py | 70 +++---- src/ott/tools/plot.py | 6 +- src/ott/tools/segment_sinkhorn.py | 24 +-- src/ott/tools/sinkhorn_divergence.py | 44 ++-- src/ott/tools/soft_sort.py | 82 ++++---- src/ott/types.py | 8 +- src/ott/utils.py | 3 +- tests/conftest.py | 3 +- tests/geometry/costs_test.py | 22 +- tests/geometry/graph_test.py | 22 +- 
tests/geometry/low_rank_test.py | 26 +-- tests/geometry/pointcloud_test.py | 10 +- tests/geometry/scaling_cost_test.py | 8 +- tests/geometry/subsetting_test.py | 8 +- .../initializers/linear/sinkhorn_init_test.py | 26 +-- .../linear/sinkhorn_lr_init_test.py | 8 +- tests/initializers/quadratic/gw_init_test.py | 3 +- tests/math/lse_test.py | 2 +- tests/math/math_utils_test.py | 2 +- tests/math/matrix_square_root_test.py | 12 +- tests/neural/conftest.py | 16 +- tests/neural/genot_test.py | 17 +- tests/neural/icnn_test.py | 4 +- tests/neural/losses_test.py | 7 +- tests/neural/map_estimator_test.py | 6 +- tests/neural/meta_initializer_test.py | 14 +- tests/neural/otfm_test.py | 55 +++-- tests/problems/linear/potentials_test.py | 14 +- .../linear/continuous_barycenter_test.py | 14 +- tests/solvers/linear/sinkhorn_diff_test.py | 58 +++--- tests/solvers/linear/sinkhorn_grid_test.py | 8 +- tests/solvers/linear/sinkhorn_lr_test.py | 2 +- tests/solvers/linear/sinkhorn_misc_test.py | 18 +- tests/solvers/linear/sinkhorn_test.py | 2 +- tests/solvers/linear/univariate_test.py | 4 +- tests/solvers/quadratic/fgw_test.py | 18 +- tests/solvers/quadratic/gw_barycenter_test.py | 12 +- tests/solvers/quadratic/gw_test.py | 19 +- tests/solvers/quadratic/lower_bound_test.py | 6 +- .../gaussian_mixture/fit_gmm_pair_test.py | 2 +- tests/tools/gaussian_mixture/fit_gmm_test.py | 2 +- .../gaussian_mixture_pair_test.py | 2 +- .../gaussian_mixture/gaussian_mixture_test.py | 16 +- tests/tools/gaussian_mixture/gaussian_test.py | 18 +- tests/tools/gaussian_mixture/linalg_test.py | 20 +- .../gaussian_mixture/probabilities_test.py | 4 +- .../tools/gaussian_mixture/scale_tril_test.py | 12 +- tests/tools/k_means_test.py | 48 ++--- tests/tools/segment_sinkhorn_test.py | 2 +- tests/tools/sinkhorn_divergence_test.py | 6 +- tests/tools/soft_sort_test.py | 28 +-- 114 files changed, 1603 insertions(+), 1515 deletions(-) diff --git a/docs/tutorials/GWLRSinkhorn.ipynb b/docs/tutorials/GWLRSinkhorn.ipynb index 
ace06be8f..590671428 100644 --- a/docs/tutorials/GWLRSinkhorn.ipynb +++ b/docs/tutorials/GWLRSinkhorn.ipynb @@ -66,7 +66,7 @@ }, "outputs": [], "source": [ - "def create_points(rng: jax.Array, n: int, m: int, d1: int, d2: int):\n", + "def create_points(rng: jnp.ndarray, n: int, m: int, d1: int, d2: int):\n", " rngs = jax.random.split(rng, 5)\n", " x = jax.random.uniform(rngs[0], (n, d1))\n", " y = jax.random.uniform(rngs[1], (m, d2))\n", diff --git a/docs/tutorials/Hessians.ipynb b/docs/tutorials/Hessians.ipynb index f7c8b56d1..0e50ec959 100644 --- a/docs/tutorials/Hessians.ipynb +++ b/docs/tutorials/Hessians.ipynb @@ -103,7 +103,7 @@ }, "outputs": [], "source": [ - "def loss(a: jax.Array, x: jax.Array, implicit: bool = True) -> float:\n", + "def loss(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True) -> float:\n", " return sinkhorn_divergence.sinkhorn_divergence(\n", " pointcloud.PointCloud,\n", " x,\n", diff --git a/docs/tutorials/Monge_Gap.ipynb b/docs/tutorials/Monge_Gap.ipynb index ac38d89b4..53bc670dc 100644 --- a/docs/tutorials/Monge_Gap.ipynb +++ b/docs/tutorials/Monge_Gap.ipynb @@ -94,13 +94,13 @@ "\n", " name: Literal[\"moon\", \"s_curve\"]\n", " theta_rotation: float = 0.0\n", - " mean: Optional[jax.Array] = None\n", + " mean: Optional[jnp.ndarray] = None\n", " noise: float = 0.01\n", " scale: float = 1.0\n", " batch_size: int = 1024\n", - " rng: Optional[jax.Array] = (None,)\n", + " rng: Optional[jnp.ndarray] = (None,)\n", "\n", - " def __iter__(self) -> Iterator[jax.Array]:\n", + " def __iter__(self) -> Iterator[jnp.ndarray]:\n", " \"\"\"Random sample generator from Gaussian mixture.\n", "\n", " Returns:\n", @@ -108,7 +108,7 @@ " \"\"\"\n", " return self._create_sample_generators()\n", "\n", - " def _create_sample_generators(self) -> Iterator[jax.Array]:\n", + " def _create_sample_generators(self) -> Iterator[jnp.ndarray]:\n", " rng = jax.random.PRNGKey(0) if self.rng is None else self.rng\n", "\n", " # define rotation matrix tp rotate samples\n", 
@@ -151,7 +151,7 @@ " target_kwargs: Mapping[str, Any] = MappingProxyType({}),\n", " train_batch_size: int = 256,\n", " valid_batch_size: int = 256,\n", - " rng: Optional[jax.Array] = None,\n", + " rng: Optional[jnp.ndarray] = None,\n", ") -> Tuple[dataset.Dataset, dataset.Dataset, int]:\n", " \"\"\"Samplers from ``SklearnDistribution``.\"\"\"\n", " rng = jax.random.PRNGKey(0) if rng is None else rng\n", @@ -202,7 +202,7 @@ " num_points: Optional[int] = None,\n", " title: Optional[str] = None,\n", " figsize: Tuple[int, int] = (8, 6),\n", - " rng: Optional[jax.Array] = None,\n", + " rng: Optional[jnp.ndarray] = None,\n", "):\n", " \"\"\"Plot samples from the source and target measures.\n", "\n", diff --git a/docs/tutorials/One_Sinkhorn.ipynb b/docs/tutorials/One_Sinkhorn.ipynb index 9465441d8..8c3d98e2e 100644 --- a/docs/tutorials/One_Sinkhorn.ipynb +++ b/docs/tutorials/One_Sinkhorn.ipynb @@ -555,7 +555,9 @@ }, "outputs": [], "source": [ - "def my_sinkhorn(geom: geometry.Geometry, a: jax.Array, b: jax.Array, **kwargs):\n", + "def my_sinkhorn(\n", + " geom: geometry.Geometry, a: jnp.ndarray, b: jnp.ndarray, **kwargs\n", + "):\n", " return linear.solve(\n", " geom, a, b, inner_iterations=1, max_iterations=10_000, **kwargs\n", " )" diff --git a/docs/tutorials/basic_ot_between_datasets.ipynb b/docs/tutorials/basic_ot_between_datasets.ipynb index b3c452d36..3cc61d403 100644 --- a/docs/tutorials/basic_ot_between_datasets.ipynb +++ b/docs/tutorials/basic_ot_between_datasets.ipynb @@ -260,7 +260,7 @@ "metadata": {}, "outputs": [], "source": [ - "def reg_ot_cost(x: jax.Array, y: jax.Array) -> float:\n", + "def reg_ot_cost(x: jnp.ndarray, y: jnp.ndarray) -> float:\n", " geom = pointcloud.PointCloud(x, y)\n", " ot = linear.solve(geom)\n", " return ot.reg_ot_cost" diff --git a/docs/tutorials/point_clouds.ipynb b/docs/tutorials/point_clouds.ipynb index e1b77edca..fd20ffc9a 100644 --- a/docs/tutorials/point_clouds.ipynb +++ b/docs/tutorials/point_clouds.ipynb @@ -241,8 +241,8 @@ 
"outputs": [], "source": [ "def optimize(\n", - " x: jax.Array,\n", - " y: jax.Array,\n", + " x: jnp.ndarray,\n", + " y: jnp.ndarray,\n", " num_iter: int = 300,\n", " dump_every: int = 5,\n", " learning_rate: float = 0.2,\n", diff --git a/docs/tutorials/sinkhorn_divergence_gradient_flow.ipynb b/docs/tutorials/sinkhorn_divergence_gradient_flow.ipynb index ff84f53b4..c3b73039c 100644 --- a/docs/tutorials/sinkhorn_divergence_gradient_flow.ipynb +++ b/docs/tutorials/sinkhorn_divergence_gradient_flow.ipynb @@ -145,8 +145,8 @@ "outputs": [], "source": [ "def gradient_flow(\n", - " x: jax.Array,\n", - " y: jax.Array,\n", + " x: jnp.ndarray,\n", + " y: jnp.ndarray,\n", " cost_fn: callable,\n", " num_iter: int = 500,\n", " lr: float = 0.2,\n", diff --git a/docs/tutorials/sparse_monge_displacements.ipynb b/docs/tutorials/sparse_monge_displacements.ipynb index 8fcb49096..a21213703 100644 --- a/docs/tutorials/sparse_monge_displacements.ipynb +++ b/docs/tutorials/sparse_monge_displacements.ipynb @@ -241,7 +241,7 @@ "solver = jax.jit(sinkhorn.Sinkhorn())\n", "\n", "\n", - "def entropic_map(x, y, cost_fn: costs.TICost) -> jax.Array:\n", + "def entropic_map(x, y, cost_fn: costs.TICost) -> jnp.ndarray:\n", " geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn)\n", " output = solver(linear_problem.LinearProblem(geom))\n", " dual_potentials = output.to_dual_potentials()\n", diff --git a/src/ott/datasets.py b/src/ott/datasets.py index 12dda06bb..07bd87fb9 100644 --- a/src/ott/datasets.py +++ b/src/ott/datasets.py @@ -32,8 +32,8 @@ class Dataset(NamedTuple): source_iter: loader for the source measure target_iter: loader for the target measure """ - source_iter: Iterator[jax.Array] - target_iter: Iterator[jax.Array] + source_iter: Iterator[jnp.ndarray] + target_iter: Iterator[jnp.ndarray] @dataclasses.dataclass @@ -57,7 +57,7 @@ class GaussianMixture: """ name: Name_t batch_size: int - init_rng: jax.Array + init_rng: jnp.ndarray scale: float = 5.0 std: float = 0.5 @@ -110,7 +110,7 @@ def 
create_gaussian_mixture_samplers( name_target: Name_t, train_batch_size: int = 2048, valid_batch_size: int = 2048, - rng: Optional[jax.Array] = None, + rng: Optional[jnp.ndarray] = None, ) -> Tuple[Dataset, Dataset, int]: """Gaussian samplers for :class:`~ott.solvers.nn.neuraldual.W2NeuralDual`. diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index aeaf89b72..9f1a6c3a0 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -56,10 +56,10 @@ class CostFn(abc.ABC): """ # no norm function created by default. - norm: Optional[Callable[[jax.Array], Union[float, jax.Array]]] = None + norm: Optional[Callable[[jnp.ndarray], Union[float, jnp.ndarray]]] = None @abc.abstractmethod - def pairwise(self, x: jax.Array, y: jax.Array) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Compute cost between :math:`x` and :math:`y`. Args: @@ -70,8 +70,8 @@ def pairwise(self, x: jax.Array, y: jax.Array) -> float: The cost. """ - def barycenter(self, weights: jax.Array, - xs: jax.Array) -> Tuple[jax.Array, Any]: + def barycenter(self, weights: jnp.ndarray, + xs: jnp.ndarray) -> Tuple[jnp.ndarray, Any]: """Barycentric operator. Args: @@ -86,7 +86,7 @@ def barycenter(self, weights: jax.Array, raise NotImplementedError("Barycenter is not implemented.") @classmethod - def _padder(cls, dim: int) -> jax.Array: + def _padder(cls, dim: int) -> jnp.ndarray: """Create a padding vector of adequate dimension, well-suited to a cost. Args: @@ -97,7 +97,7 @@ def _padder(cls, dim: int) -> jax.Array: """ return jnp.zeros((1, dim)) - def __call__(self, x: jax.Array, y: jax.Array) -> float: + def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Compute cost between :math:`x` and :math:`y`. 
Args: @@ -113,7 +113,7 @@ def __call__(self, x: jax.Array, y: jax.Array) -> float: return cost return cost + self.norm(x) + self.norm(y) - def all_pairs(self, x: jax.Array, y: jax.Array) -> jax.Array: + def all_pairs(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Compute matrix of all pairwise costs, including the :attr:`norms `. Args: @@ -125,7 +125,7 @@ def all_pairs(self, x: jax.Array, y: jax.Array) -> jax.Array: """ return jax.vmap(lambda x_: jax.vmap(lambda y_: self(x_, y_))(y))(x) - def all_pairs_pairwise(self, x: jax.Array, y: jax.Array) -> jax.Array: + def all_pairs_pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Compute matrix of all pairwise costs, excluding the :attr:`norms `. Args: @@ -163,7 +163,7 @@ class TICost(CostFn): """ @abc.abstractmethod - def h(self, z: jax.Array) -> float: + def h(self, z: jnp.ndarray) -> float: """TI function acting on difference of :math:`x-y` to output cost. Args: @@ -173,11 +173,11 @@ def h(self, z: jax.Array) -> float: The cost. """ - def h_legendre(self, z: jax.Array) -> float: + def h_legendre(self, z: jnp.ndarray) -> float: """Legendre transform of :func:`h` when it is convex.""" raise NotImplementedError("Legendre transform of `h` is not implemented.") - def pairwise(self, x: jax.Array, y: jax.Array) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Compute cost as evaluation of :func:`h` on :math:`x-y`.""" return self.h(x - y) @@ -198,10 +198,10 @@ def __init__(self, p: float): self.p = p self.q = 1.0 / (1.0 - (1.0 / p)) if p > 1.0 else jnp.inf - def h(self, z: jax.Array) -> float: # noqa: D102 + def h(self, z: jnp.ndarray) -> float: # noqa: D102 return 0.5 * mu.norm(z, self.p) ** 2 - def h_legendre(self, z: jax.Array) -> float: + def h_legendre(self, z: jnp.ndarray) -> float: """Legendre transform of :func:`h`. For details on the derivation, see e.g., :cite:`boyd:04`, p. 93/94. 
@@ -234,10 +234,10 @@ def __init__(self, p: float): self.p = p self.q = 1.0 / (1.0 - (1.0 / p)) if p > 1.0 else jnp.inf - def h(self, z: jax.Array) -> float: # noqa: D102 + def h(self, z: jnp.ndarray) -> float: # noqa: D102 return mu.norm(z, self.p) ** self.p / self.p - def h_legendre(self, z: jax.Array) -> float: # noqa: D102 + def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 # not defined for `p=1` return mu.norm(z, self.q) ** self.q / self.q @@ -260,7 +260,7 @@ class Euclidean(CostFn): because the function is not strictly convex (it is linear on rays). """ - def pairwise(self, x: jax.Array, y: jax.Array) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Compute Euclidean norm using custom jvp implementation. Here we use a custom jvp implementation for the norm that does not yield @@ -277,22 +277,22 @@ class SqEuclidean(TICost): Implemented as a translation invariant cost, :math:`h(z) = \|z\|^2`. """ - def norm(self, x: jax.Array) -> Union[float, jax.Array]: + def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]: """Compute squared Euclidean norm for vector.""" return jnp.sum(x ** 2, axis=-1) - def pairwise(self, x: jax.Array, y: jax.Array) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Compute minus twice the dot-product between vectors.""" return -2. 
* jnp.vdot(x, y) - def h(self, z: jax.Array) -> float: # noqa: D102 + def h(self, z: jnp.ndarray) -> float: # noqa: D102 return jnp.sum(z ** 2) - def h_legendre(self, z: jax.Array) -> float: # noqa: D102 + def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 return 0.25 * jnp.sum(z ** 2) - def barycenter(self, weights: jax.Array, - xs: jax.Array) -> Tuple[jax.Array, Any]: + def barycenter(self, weights: jnp.ndarray, + xs: jnp.ndarray) -> Tuple[jnp.ndarray, Any]: """Output barycenter of vectors when using squared-Euclidean distance.""" return jnp.average(xs, weights=weights, axis=0), None @@ -309,7 +309,7 @@ def __init__(self, ridge: float = 1e-8): super().__init__() self._ridge = ridge - def pairwise(self, x: jax.Array, y: jax.Array) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Cosine distance between vectors, denominator regularized with ridge.""" ridge = self._ridge x_norm = jnp.linalg.norm(x, axis=-1) @@ -318,7 +318,7 @@ def pairwise(self, x: jax.Array, y: jax.Array) -> float: return 1.0 - cosine_similarity @classmethod - def _padder(cls, dim: int) -> jax.Array: + def _padder(cls, dim: int) -> jnp.ndarray: return jnp.ones((1, dim)) @@ -341,7 +341,7 @@ class RegTICost(TICost, abc.ABC): def __init__( self, scaling_reg: float = 1.0, - matrix: Optional[jax.Array] = None, + matrix: Optional[jnp.ndarray] = None, orthogonal: bool = False, ): super().__init__() @@ -350,16 +350,16 @@ def __init__( self.orthogonal = orthogonal @abc.abstractmethod - def _reg(self, z: jax.Array) -> float: + def _reg(self, z: jnp.ndarray) -> float: """Regularization function.""" - def _reg_stiefel_orth(self, z: jax.Array) -> float: + def _reg_stiefel_orth(self, z: jnp.ndarray) -> float: raise NotImplementedError( "Regularization in the orthogonal " "subspace is not implemented." ) - def reg(self, z: jax.Array) -> float: + def reg(self, z: jnp.ndarray) -> float: """Regularization function. 
Args: @@ -374,7 +374,7 @@ def reg(self, z: jax.Array) -> float: return self._reg_stiefel_orth(z) return self._reg(self.matrix @ z) - def prox_reg(self, z: jax.Array, tau: float = 1.0) -> jax.Array: + def prox_reg(self, z: jnp.ndarray, tau: float = 1.0) -> jnp.ndarray: """Proximal operator of :meth:`reg`. Args: @@ -391,24 +391,26 @@ def prox_reg(self, z: jax.Array, tau: float = 1.0) -> jax.Array: return self._prox_reg_stiefel_orth(z, tau) return self._prox_reg_stiefel(z, tau) - def _prox_reg(self, z: jax.Array, tau: float = 1.0) -> jax.Array: + def _prox_reg(self, z: jnp.ndarray, tau: float = 1.0) -> jnp.ndarray: raise NotImplementedError("Proximal operator is not implemented.") - def _prox_reg_stiefel_orth(self, z: jax.Array, tau: float = 1.0) -> jax.Array: + def _prox_reg_stiefel_orth( + self, z: jnp.ndarray, tau: float = 1.0 + ) -> jnp.ndarray: - def orth(x: jax.Array) -> jax.Array: + def orth(x: jnp.ndarray) -> jnp.ndarray: return x - self.matrix.T @ (self.matrix @ x) # assumes `matrix` has orthogonal rows tmp = orth(z) return z - orth(tmp - self._prox_reg(tmp, tau)) - def _prox_reg_stiefel(self, z: jax.Array, tau: float) -> jax.Array: + def _prox_reg_stiefel(self, z: jnp.ndarray, tau: float) -> jnp.ndarray: # assumes `matrix` has orthogonal rows tmp = self.matrix @ z return z - self.matrix.T @ (tmp - self._prox_reg(tmp, tau)) - def prox_legendre_reg(self, z: jax.Array, tau: float = 1.0) -> jax.Array: + def prox_legendre_reg(self, z: jnp.ndarray, tau: float = 1.0) -> jnp.ndarray: r"""Proximal operator of the Legendre transform of :meth:`reg`. 
Uses Moreau's decomposition: @@ -426,16 +428,16 @@ def prox_legendre_reg(self, z: jax.Array, tau: float = 1.0) -> jax.Array: """ return z - tau * self.prox_reg(z / tau, 1.0 / tau) - def h(self, z: jax.Array) -> float: # noqa: D102 + def h(self, z: jnp.ndarray) -> float: # noqa: D102 out = 0.5 * jnp.sum(z ** 2) return out + self.scaling_reg * self.reg(z) - def h_legendre(self, z: jax.Array) -> float: # noqa: D102 + def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 q = jax.lax.stop_gradient(self.prox_reg(z)) return jnp.sum(q * z) - self.h(q) - def h_transform(self, f: Callable[[jax.Array], float], - **kwargs: Any) -> Callable[[jax.Array], float]: + def h_transform(self, f: Callable[[jnp.ndarray], float], + **kwargs: Any) -> Callable[[jnp.ndarray], float]: r"""Compute the h-transform of a concave function. Return a callable :math:`f_h` defined as: @@ -465,16 +467,18 @@ def h_transform(self, f: Callable[[jax.Array], float], The h-transform of ``f``. """ - def minus_f(z: jax.Array, x: jax.Array) -> float: + def minus_f(z: jnp.ndarray, x: jnp.ndarray) -> float: return -f(x - z) - def prox(x: jax.Array, scaling_reg: float, scaling_h: float) -> jax.Array: + def prox( + x: jnp.ndarray, scaling_reg: float, scaling_h: float + ) -> jnp.ndarray: # https://web.stanford.edu/~boyd/papers/pdf/prox_algs.pdf 2.2. tmp = 1.0 / (1.0 + scaling_h) tau = scaling_reg * scaling_h * tmp return self.prox_reg(x * tmp, tau) - def f_h(x: jax.Array) -> float: + def f_h(x: jnp.ndarray) -> float: pg = jaxopt.ProximalGradient(fun=minus_f, prox=prox, **kwargs) pg_run = pg.run(x, self.scaling_reg, x=x) pg_sol = jax.lax.stop_gradient(pg_run.params) @@ -504,10 +508,10 @@ class ElasticL1(RegTICost): to promote displacements in the span of ``matrix``. 
""" - def _reg(self, z: jax.Array) -> float: # noqa: D102 + def _reg(self, z: jnp.ndarray) -> float: # noqa: D102 return jnp.linalg.norm(z, ord=1) - def _prox_reg(self, z: jax.Array, tau: float = 1.0) -> jax.Array: + def _prox_reg(self, z: jnp.ndarray, tau: float = 1.0) -> jnp.ndarray: return jnp.sign(z) * jax.nn.relu(jnp.abs(z) - tau * self.scaling_reg) @@ -525,17 +529,19 @@ class ElasticL2(RegTICost): to promote displacements in the span of ``matrix``. """ - def _reg(self, z: jax.Array) -> float: # noqa: D102 + def _reg(self, z: jnp.ndarray) -> float: # noqa: D102 return 0.5 * jnp.sum(z ** 2) - def _reg_stiefel_orth(self, z: jax.Array) -> float: + def _reg_stiefel_orth(self, z: jnp.ndarray) -> float: # Pythagorean identity return self._reg(z) - self._reg(self.matrix @ z) - def _prox_reg(self, z: jax.Array, tau: float = 1.0) -> jax.Array: + def _prox_reg(self, z: jnp.ndarray, tau: float = 1.0) -> jnp.ndarray: return z / (1.0 + tau * self.scaling_reg) - def _prox_reg_stiefel_orth(self, z: jax.Array, tau: float = 1.0) -> jax.Array: + def _prox_reg_stiefel_orth( + self, z: jnp.ndarray, tau: float = 1.0 + ) -> jnp.ndarray: out = z + tau * self.scaling_reg * self.matrix.T @ (self.matrix @ z) return self._prox_reg(out, tau) @@ -559,7 +565,7 @@ class ElasticSTVS(RegTICost): to promote displacements in the span of ``matrix``. 
""" # noqa: D205,E501 - def _reg(self, z: jax.Array) -> float: # noqa: D102 + def _reg(self, z: jnp.ndarray) -> float: # noqa: D102 u = jnp.arcsinh(jnp.abs(z) / (2 * self.scaling_reg)) out = u - 0.5 * jnp.exp(-2.0 * u) # Lemma 2.1 of `schreck:15`; @@ -567,8 +573,8 @@ def _reg(self, z: jax.Array) -> float: # noqa: D102 return self.scaling_reg * jnp.sum(out + 0.5) # make positive def _prox_reg( # noqa: D102 - self, z: jax.Array, tau: float = 1.0 - ) -> jax.Array: + self, z: jnp.ndarray, tau: float = 1.0 + ) -> jnp.ndarray: tmp = 1.0 - (self.scaling_reg * tau / (jnp.abs(z) + 1e-12)) ** 2 return jax.nn.relu(tmp) * z @@ -594,7 +600,7 @@ def __init__(self, k: int, *args, **kwargs: Any): super().__init__(*args, **kwargs) self.k = k - def _reg(self, z: jax.Array) -> float: # noqa: D102 + def _reg(self, z: jnp.ndarray) -> float: # noqa: D102 # Prop 2.1 in :cite:`argyriou:12` k = self.k top_w = jax.lax.top_k(jnp.abs(z), k)[0] # Fetch largest k values @@ -615,14 +621,15 @@ def _reg(self, z: jax.Array) -> float: # noqa: D102 return 0.5 * (s + (r + 1) * cesaro[r] ** 2) - def prox_reg(self, z: jax.Array, tau: float = 1.0) -> float: # noqa: D102 + def prox_reg(self, z: jnp.ndarray, tau: float = 1.0) -> float: # noqa: D102 @functools.partial(jax.vmap, in_axes=[0, None, None]) - def find_indices(r: int, l: jax.Array, - z: jax.Array) -> Tuple[jax.Array, jax.Array]: + def find_indices(r: int, l: jnp.ndarray, + z: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: @functools.partial(jax.vmap, in_axes=[None, 0, None]) - def inner(r: int, l: int, z: jax.Array) -> Tuple[jax.Array, jax.Array]: + def inner(r: int, l: int, + z: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: i = k - r - 1 res = jnp.sum(z * ((i <= ixs) & (ixs < l))) res /= l - k + (beta + 1) * r + beta + 1 @@ -685,14 +692,14 @@ def __init__(self, dimension: int, sqrtm_kw: Optional[Dict[str, Any]] = None): self._dimension = dimension self._sqrtm_kw = {} if sqrtm_kw is None else sqrtm_kw - def norm(self, x: jax.Array) -> 
jax.Array: + def norm(self, x: jnp.ndarray) -> jnp.ndarray: """Compute norm of Gaussian, sq. 2-norm of mean + trace of covariance.""" mean, cov = x_to_means_and_covs(x, self._dimension) norm = jnp.sum(mean ** 2, axis=-1) norm += jnp.trace(cov, axis1=-2, axis2=-1) return norm - def pairwise(self, x: jax.Array, y: jax.Array) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Compute - 2 x Bures dot-product.""" mean_x, cov_x = x_to_means_and_covs(x, self._dimension) mean_y, cov_y = x_to_means_and_covs(y, self._dimension) @@ -706,12 +713,12 @@ def pairwise(self, x: jax.Array, y: jax.Array) -> float: def covariance_fixpoint_iter( self, - covs: jax.Array, - weights: jax.Array, + covs: jnp.ndarray, + weights: jnp.ndarray, tolerance: float = 1e-4, sqrtm_kw: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> jax.Array: + ) -> jnp.ndarray: """Iterate fix-point updates to compute barycenter of Gaussians. Args: @@ -737,8 +744,8 @@ def covariance_fixpoint_iter( @functools.partial(jax.vmap, in_axes=[None, 0, 0]) def scale_covariances( - cov_sqrt: jax.Array, cov: jax.Array, weight: jax.Array - ) -> jax.Array: + cov_sqrt: jnp.ndarray, cov: jnp.ndarray, weight: jnp.ndarray + ) -> jnp.ndarray: """Rescale covariance in barycenter step.""" return weight * matrix_square_root.sqrtm_only((cov_sqrt @ cov) @ cov_sqrt, **sqrtm_kw) @@ -750,8 +757,8 @@ def cond_fn(iteration: int, constants: Tuple[Any, ...], state) -> bool: def body_fn( iteration: int, constants: Tuple[Any, ...], - state: Tuple[jax.Array, float], compute_error: bool - ) -> Tuple[jax.Array, float]: + state: Tuple[jnp.ndarray, float], compute_error: bool + ) -> Tuple[jnp.ndarray, float]: del constants, compute_error cov, diffs = state cov_sqrt, cov_inv_sqrt, _ = matrix_square_root.sqrtm(cov, **sqrtm_kw) @@ -763,7 +770,7 @@ def body_fn( diffs = diffs.at[iteration // inner_iterations].set(diff) return next_cov, diffs - def init_state() -> Tuple[jax.Array, float]: + def init_state() -> Tuple[jnp.ndarray, 
float]: cov_init = jnp.eye(self._dimension) diffs = -jnp.ones( (np.ceil(max_iterations / inner_iterations).astype(int),), @@ -784,12 +791,12 @@ def init_state() -> Tuple[jax.Array, float]: def barycenter( self, - weights: jax.Array, - xs: jax.Array, + weights: jnp.ndarray, + xs: jnp.ndarray, tolerance: float = 1e-4, sqrtm_kw: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> Tuple[jax.Array, jax.Array]: + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Compute the Bures barycenter of weighted Gaussian distributions. Implements the fixed point approach proposed in :cite:`alvarez-esteban:16` @@ -835,7 +842,7 @@ def barycenter( return mean_and_cov_to_x(mu_bary, cov_bary, self._dimension), diffs @classmethod - def _padder(cls, dim: int) -> jax.Array: + def _padder(cls, dim: int) -> jnp.ndarray: dimension = int((-1 + math.sqrt(1 + 4 * dim)) / 2) padding = mean_and_cov_to_x( jnp.zeros((dimension,)), jnp.eye(dimension), dimension @@ -878,7 +885,7 @@ def __init__( self._gamma = gamma self._sqrtm_kw = kwargs - def norm(self, x: jax.Array) -> jax.Array: + def norm(self, x: jnp.ndarray) -> jnp.ndarray: """Compute norm of Gaussian for unbalanced Bures. Args: @@ -891,7 +898,7 @@ def norm(self, x: jax.Array) -> jax.Array: """ return self._gamma * x[..., 0] - def pairwise(self, x: jax.Array, y: jax.Array) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Compute dot-product for unbalanced Bures. 
Args: @@ -985,17 +992,18 @@ def __init__( self.ground_cost = SqEuclidean() if ground_cost is None else ground_cost self.debiased = debiased - def pairwise(self, x: jax.Array, y: jax.Array) -> float: # noqa: D102 + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: # noqa: D102 c_xy = self._soft_dtw(x, y) if self.debiased: return c_xy - 0.5 * (self._soft_dtw(x, x) + self._soft_dtw(y, y)) return c_xy - def _soft_dtw(self, t1: jax.Array, t2: jax.Array) -> float: + def _soft_dtw(self, t1: jnp.ndarray, t2: jnp.ndarray) -> float: def body( - carry: Tuple[jax.Array, jax.Array], current_antidiagonal: jax.Array - ) -> Tuple[Tuple[jax.Array, jax.Array], jax.Array]: + carry: Tuple[jnp.ndarray, jnp.ndarray], + current_antidiagonal: jnp.ndarray + ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]: # modified from: https://github.com/khdlr/softdtw_jax two_ago, one_ago = carry @@ -1042,8 +1050,8 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls(*children, **aux_data) -def x_to_means_and_covs(x: jax.Array, - dimension: int) -> Tuple[jax.Array, jax.Array]: +def x_to_means_and_covs(x: jnp.ndarray, + dimension: int) -> Tuple[jnp.ndarray, jnp.ndarray]: """Extract means and covariance matrices of Gaussians from raveled vector. 
Args: @@ -1063,8 +1071,8 @@ def x_to_means_and_covs(x: jax.Array, def mean_and_cov_to_x( - mean: jax.Array, covariance: jax.Array, dimension: int -) -> jax.Array: + mean: jnp.ndarray, covariance: jnp.ndarray, dimension: int +) -> jnp.ndarray: """Ravel a Gaussian's mean and covariance matrix to d(1 + d) vector.""" return jnp.concatenate( (mean, jnp.reshape(covariance, (dimension * dimension))) diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index 5d3db3ee6..6894176a6 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -79,14 +79,14 @@ class Geometry: def __init__( self, - cost_matrix: Optional[jax.Array] = None, - kernel_matrix: Optional[jax.Array] = None, + cost_matrix: Optional[jnp.ndarray] = None, + kernel_matrix: Optional[jnp.ndarray] = None, epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None, relative_epsilon: Optional[bool] = None, scale_cost: Union[bool, int, float, Literal["mean", "max_cost", "median"]] = 1.0, - src_mask: Optional[jax.Array] = None, - tgt_mask: Optional[jax.Array] = None, + src_mask: Optional[jnp.ndarray] = None, + tgt_mask: Optional[jnp.ndarray] = None, ): self._cost_matrix = cost_matrix self._kernel_matrix = kernel_matrix @@ -107,7 +107,7 @@ def cost_rank(self) -> Optional[int]: """Output rank of cost matrix, if any was provided.""" @property - def cost_matrix(self) -> jax.Array: + def cost_matrix(self) -> jnp.ndarray: """Cost matrix, recomputed from kernel if only kernel was specified.""" if self._cost_matrix is None: # If no epsilon was passed on to the geometry, then assume it is one by @@ -131,7 +131,7 @@ def mean_cost_matrix(self) -> float: return jnp.sum(tmp * self._m_normed_ones) @property - def kernel_matrix(self) -> jax.Array: + def kernel_matrix(self) -> jnp.ndarray: """Kernel matrix. Either provided by user or recomputed from :attr:`cost_matrix`. 
@@ -201,7 +201,7 @@ def is_symmetric(self) -> bool: @property def inv_scale_cost(self) -> float: """Compute and return inverse of scaling factor for cost matrix.""" - if isinstance(self._scale_cost, (int, float, np.number, jax.Array)): + if isinstance(self._scale_cost, (int, float, np.number, jnp.ndarray)): return 1.0 / self._scale_cost self = self._masked_geom(mask_value=jnp.nan) if self._scale_cost == "max_cost": @@ -245,12 +245,12 @@ def copy_epsilon(self, other: "Geometry") -> "Geometry": def apply_lse_kernel( self, - f: jax.Array, - g: jax.Array, + f: jnp.ndarray, + g: jnp.ndarray, eps: float, - vec: jax.Array = None, + vec: jnp.ndarray = None, axis: int = 0 - ) -> jax.Array: + ) -> jnp.ndarray: r"""Apply :attr:`kernel_matrix` in log domain. This function applies the ground geometry's kernel in log domain, using @@ -267,10 +267,10 @@ def apply_lse_kernel( f and g in iterations 1 & 2 respectively. Args: - f: jax.Array [num_a,] , potential of size num_rows of cost_matrix - g: jax.Array [num_b,] , potential of size num_cols of cost_matrix + f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix + g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix eps: float, regularization strength - vec: jax.Array [num_a or num_b,] , when not None, this has the effect of + vec: jnp.ndarray [num_a or num_b,] , when not None, this has the effect of doing log-Kernel computations with an addition elementwise multiplication of exp(g / eps) by a vector. This is carried out by adding weights to the log-sum-exp function, and needs to handle signs @@ -278,7 +278,7 @@ def apply_lse_kernel( axis: summing over axis 0 when doing (2), or over axis 1 when doing (1) Returns: - A jax.Array corresponding to output above, depending on axis. + A jnp.ndarray corresponding to output above, depending on axis. 
""" w_res, w_sgn = self._softmax(f, g, eps, vec, axis) remove = f if axis == 1 else g @@ -286,20 +286,20 @@ def apply_lse_kernel( def apply_kernel( self, - scaling: jax.Array, + scaling: jnp.ndarray, eps: Optional[float] = None, axis: int = 0, - ) -> jax.Array: + ) -> jnp.ndarray: """Apply :attr:`kernel_matrix` on positive scaling vector. Args: - scaling: jax.Array [num_a or num_b] , scaling of size num_rows or + scaling: jnp.ndarray [num_a or num_b] , scaling of size num_rows or num_cols of kernel_matrix eps: passed for consistency, not used yet. axis: standard kernel product if axis is 1, transpose if 0. Returns: - a jax.Array corresponding to output above, depending on axis. + a jnp.ndarray corresponding to output above, depending on axis. """ if eps is None: kernel = self.kernel_matrix @@ -311,10 +311,10 @@ def apply_kernel( def marginal_from_potentials( self, - f: jax.Array, - g: jax.Array, + f: jnp.ndarray, + g: jnp.ndarray, axis: int = 0, - ) -> jax.Array: + ) -> jnp.ndarray: """Output marginal of transportation matrix from potentials. This applies first lse kernel in the standard way, removes the @@ -323,8 +323,8 @@ def marginal_from_potentials( by potentials. Args: - f: jax.Array [num_a,] , potential of size num_rows of cost_matrix - g: jax.Array [num_b,] , potential of size num_cols of cost_matrix + f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix + g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix axis: axis along which to integrate, returns marginal on other axis. 
Returns: @@ -336,19 +336,23 @@ def marginal_from_potentials( def marginal_from_scalings( self, - u: jax.Array, - v: jax.Array, + u: jnp.ndarray, + v: jnp.ndarray, axis: int = 0, - ) -> jax.Array: + ) -> jnp.ndarray: """Output marginal of transportation matrix from scalings.""" u, v = (v, u) if axis == 0 else (u, v) return u * self.apply_kernel(v, eps=self.epsilon, axis=axis) - def transport_from_potentials(self, f: jax.Array, g: jax.Array) -> jax.Array: + def transport_from_potentials( + self, f: jnp.ndarray, g: jnp.ndarray + ) -> jnp.ndarray: """Output transport matrix from potentials.""" return jnp.exp(self._center(f, g) / self.epsilon) - def transport_from_scalings(self, u: jax.Array, v: jax.Array) -> jax.Array: + def transport_from_scalings( + self, u: jnp.ndarray, v: jnp.ndarray + ) -> jnp.ndarray: """Output transport matrix from pair of scalings.""" return self.kernel_matrix * u[:, jnp.newaxis] * v[jnp.newaxis, :] @@ -357,17 +361,17 @@ def transport_from_scalings(self, u: jax.Array, v: jax.Array) -> jax.Array: def update_potential( self, - f: jax.Array, - g: jax.Array, - log_marginal: jax.Array, + f: jnp.ndarray, + g: jnp.ndarray, + log_marginal: jnp.ndarray, iteration: Optional[int] = None, axis: int = 0, - ) -> jax.Array: + ) -> jnp.ndarray: """Carry out one Sinkhorn update for potentials, i.e. in log space. Args: - f: jax.Array [num_a,] , potential of size num_rows of cost_matrix - g: jax.Array [num_b,] , potential of size num_cols of cost_matrix + f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix + g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix log_marginal: targeted marginal iteration: used to compute epsilon from schedule, if provided. axis: axis along which the update should be carried out. 
@@ -381,15 +385,15 @@ def update_potential( def update_scaling( self, - scaling: jax.Array, - marginal: jax.Array, + scaling: jnp.ndarray, + marginal: jnp.ndarray, iteration: Optional[int] = None, axis: int = 0, - ) -> jax.Array: + ) -> jnp.ndarray: """Carry out one Sinkhorn update for scalings, using kernel directly. Args: - scaling: jax.Array of num_a or num_b positive values. + scaling: jnp.ndarray of num_a or num_b positive values. marginal: targeted marginal iteration: used to compute epsilon from schedule, if provided. axis: axis along which the update should be carried out. @@ -402,13 +406,13 @@ def update_scaling( return marginal / jnp.where(app_kernel > 0, app_kernel, 1.0) # Helper functions - def _center(self, f: jax.Array, g: jax.Array) -> jax.Array: + def _center(self, f: jnp.ndarray, g: jnp.ndarray) -> jnp.ndarray: return f[:, jnp.newaxis] + g[jnp.newaxis, :] - self.cost_matrix def _softmax( - self, f: jax.Array, g: jax.Array, eps: float, vec: Optional[jax.Array], - axis: int - ) -> Tuple[jax.Array, jax.Array]: + self, f: jnp.ndarray, g: jnp.ndarray, eps: float, + vec: Optional[jnp.ndarray], axis: int + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Apply softmax row or column wise, weighted by vec.""" if vec is not None: if axis == 0: @@ -425,8 +429,8 @@ def _softmax( @functools.partial(jax.vmap, in_axes=[None, None, None, 0, None]) def _apply_transport_from_potentials( - self, f: jax.Array, g: jax.Array, vec: jax.Array, axis: int - ) -> jax.Array: + self, f: jnp.ndarray, g: jnp.ndarray, vec: jnp.ndarray, axis: int + ) -> jnp.ndarray: """Apply lse_kernel to arbitrary vector while keeping track of signs.""" lse_res, lse_sgn = self.apply_lse_kernel( f, g, self.epsilon, vec=vec, axis=axis @@ -437,11 +441,11 @@ def _apply_transport_from_potentials( # wrapper to allow default option for axis. 
def apply_transport_from_potentials( self, - f: jax.Array, - g: jax.Array, - vec: jax.Array, + f: jnp.ndarray, + g: jnp.ndarray, + vec: jnp.ndarray, axis: int = 0 - ) -> jax.Array: + ) -> jnp.ndarray: """Apply transport matrix computed from potentials to a (batched) vec. This approach does not instantiate the transport matrix itself, but uses @@ -452,9 +456,9 @@ def apply_transport_from_potentials( (b=..., return_sign=True) optional parameters of logsumexp. Args: - f: jax.Array [num_a,] , potential of size num_rows of cost_matrix - g: jax.Array [num_b,] , potential of size num_cols of cost_matrix - vec: jax.Array [batch, num_a or num_b], vector that will be multiplied + f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix + g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix + vec: jnp.ndarray [batch, num_a or num_b], vector that will be multiplied by transport matrix corresponding to potentials f, g, and geom. axis: axis to differentiate left (0) or right (1) multiply. @@ -469,7 +473,7 @@ def apply_transport_from_potentials( @functools.partial(jax.vmap, in_axes=[None, None, None, 0, None]) def _apply_transport_from_scalings( - self, u: jax.Array, v: jax.Array, vec: jax.Array, axis: int + self, u: jnp.ndarray, v: jnp.ndarray, vec: jnp.ndarray, axis: int ): u, v = (u, v * vec) if axis == 1 else (v, u * vec) return u * self.apply_kernel(v, eps=self.epsilon, axis=axis) @@ -477,20 +481,20 @@ def _apply_transport_from_scalings( # wrapper to allow default option for axis def apply_transport_from_scalings( self, - u: jax.Array, - v: jax.Array, - vec: jax.Array, + u: jnp.ndarray, + v: jnp.ndarray, + vec: jnp.ndarray, axis: int = 0 - ) -> jax.Array: + ) -> jnp.ndarray: """Apply transport matrix computed from scalings to a (batched) vec. This approach does not instantiate the transport matrix itself, but relies instead on the apply_kernel function. 
Args: - u: jax.Array [num_a,] , scaling of size num_rows of cost_matrix - v: jax.Array [num_b,] , scaling of size num_cols of cost_matrix - vec: jax.Array [batch, num_a or num_b], vector that will be multiplied + u: jnp.ndarray [num_a,] , scaling of size num_rows of cost_matrix + v: jnp.ndarray [num_b,] , scaling of size num_cols of cost_matrix + vec: jnp.ndarray [batch, num_a or num_b], vector that will be multiplied by transport matrix corresponding to scalings u, v, and geom. axis: axis to differentiate left (0) or right (1) multiply. @@ -503,7 +507,7 @@ def apply_transport_from_scalings( )[0, :] return self._apply_transport_from_scalings(u, v, vec, axis) - def potential_from_scaling(self, scaling: jax.Array) -> jax.Array: + def potential_from_scaling(self, scaling: jnp.ndarray) -> jnp.ndarray: """Compute dual potential vector from scaling vector. Args: @@ -514,7 +518,7 @@ def potential_from_scaling(self, scaling: jax.Array) -> jax.Array: """ return self.epsilon * jnp.log(scaling) - def scaling_from_potential(self, potential: jax.Array) -> jax.Array: + def scaling_from_potential(self, potential: jnp.ndarray) -> jnp.ndarray: """Compute scaling vector from dual potential. Args: @@ -528,7 +532,7 @@ def scaling_from_potential(self, potential: jax.Array) -> jax.Array: finite, jnp.exp(jnp.where(finite, potential / self.epsilon, 0.0)), 0.0 ) - def apply_square_cost(self, arr: jax.Array, axis: int = 0) -> jax.Array: + def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray: """Apply elementwise-square of cost matrix to array (vector or matrix). 
This function applies the ground geometry's cost matrix, to perform either @@ -549,11 +553,11 @@ def apply_square_cost(self, arr: jax.Array, axis: int = 0) -> jax.Array: def apply_cost( self, - arr: jax.Array, + arr: jnp.ndarray, axis: int = 0, - fn: Optional[Callable[[jax.Array], jax.Array]] = None, + fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, **kwargs: Any - ) -> jax.Array: + ) -> jnp.ndarray: """Apply :attr:`cost_matrix` to array (vector or matrix). This function applies the ground geometry's cost matrix, to perform either @@ -562,7 +566,7 @@ def apply_cost( where C is [num_a, num_b] Args: - arr: jax.Array [num_a or num_b, p], vector that will be multiplied by + arr: jnp.ndarray [num_a or num_b, p], vector that will be multiplied by the cost matrix. axis: standard cost matrix if axis=1, transpose if 0 fn: function to apply to cost matrix element-wise before the dot product @@ -579,21 +583,21 @@ def apply_cost( def _apply_cost_to_vec( self, - vec: jax.Array, + vec: jnp.ndarray, axis: int = 0, fn=None, **_: Any, - ) -> jax.Array: + ) -> jnp.ndarray: """Apply ``[num_a, num_b]`` fn(cost) (or transpose) to vector. Args: - vec: jax.Array [num_a,] ([num_b,] if axis=1) vector + vec: jnp.ndarray [num_a,] ([num_b,] if axis=1) vector axis: axis on which the reduction is done. fn: function optionally applied to cost matrix element-wise, before the doc product Returns: - A jax.Array corresponding to cost x vector + A jnp.ndarray corresponding to cost x vector """ matrix = self.cost_matrix.T if axis == 0 else self.cost_matrix matrix = fn(matrix) if fn is not None else matrix @@ -621,7 +625,7 @@ def to_LRCGeometry( self, rank: int = 0, tol: float = 1e-2, - rng: Optional[jax.Array] = None, + rng: Optional[jnp.ndarray] = None, scale: float = 1. ) -> "low_rank.LRCGeometry": r"""Factorize the cost matrix using either SVD (full) or :cite:`indyk:19`. 
@@ -714,7 +718,7 @@ def to_LRCGeometry( ) def subset( - self, src_ixs: Optional[jax.Array], tgt_ixs: Optional[jax.Array], + self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray], **kwargs: Any ) -> "Geometry": """Subset rows or columns of a geometry. @@ -729,10 +733,10 @@ def subset( """ def subset_fn( - arr: Optional[jax.Array], - src_ixs: Optional[jax.Array], - tgt_ixs: Optional[jax.Array], - ) -> Optional[jax.Array]: + arr: Optional[jnp.ndarray], + src_ixs: Optional[jnp.ndarray], + tgt_ixs: Optional[jnp.ndarray], + ) -> Optional[jnp.ndarray]: if arr is None: return None if src_ixs is not None: @@ -751,8 +755,8 @@ def subset_fn( def mask( self, - src_mask: Optional[jax.Array], - tgt_mask: Optional[jax.Array], + src_mask: Optional[jnp.ndarray], + tgt_mask: Optional[jnp.ndarray], mask_value: float = 0., ) -> "Geometry": """Mask rows or columns of a geometry. @@ -776,10 +780,10 @@ def mask( """ def mask_fn( - arr: Optional[jax.Array], - src_mask: Optional[jax.Array], - tgt_mask: Optional[jax.Array], - ) -> Optional[jax.Array]: + arr: Optional[jnp.ndarray], + src_mask: Optional[jnp.ndarray], + tgt_mask: Optional[jnp.ndarray], + ) -> Optional[jnp.ndarray]: if arr is None: return arr assert arr.ndim == 2, arr.ndim @@ -797,12 +801,12 @@ def mask_fn( def _mask_subset_helper( self, - src_ixs: Optional[jax.Array], - tgt_ixs: Optional[jax.Array], + src_ixs: Optional[jnp.ndarray], + tgt_ixs: Optional[jnp.ndarray], *, fn: Callable[ - [Optional[jax.Array], Optional[jax.Array], Optional[jax.Array]], - Optional[jax.Array]], + [Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray]], + Optional[jnp.ndarray]], propagate_mask: bool, **kwargs: Any, ) -> "Geometry": @@ -821,7 +825,7 @@ def _mask_subset_helper( ) @property - def src_mask(self) -> Optional[jax.Array]: + def src_mask(self) -> Optional[jnp.ndarray]: """Mask of shape ``[num_a,]`` to compute :attr:`cost_matrix` statistics. 
Specifically, it is used when computing: @@ -833,7 +837,7 @@ def src_mask(self) -> Optional[jax.Array]: return self._normalize_mask(self._src_mask, self.shape[0]) @property - def tgt_mask(self) -> Optional[jax.Array]: + def tgt_mask(self) -> Optional[jnp.ndarray]: """Mask of shape ``[num_b,]`` to compute :attr:`cost_matrix` statistics. Specifically, it is used when computing: @@ -859,22 +863,22 @@ def _masked_geom(self, mask_value: float = 0.) -> "Geometry": return self.mask(src_mask, tgt_mask, mask_value=mask_value) @property - def _n_normed_ones(self) -> jax.Array: + def _n_normed_ones(self) -> jnp.ndarray: """Normalized array of shape ``[num_a,]``.""" mask = self.src_mask arr = jnp.ones(self.shape[0]) if mask is None else mask return arr / jnp.sum(arr) @property - def _m_normed_ones(self) -> jax.Array: + def _m_normed_ones(self) -> jnp.ndarray: """Normalized array of shape ``[num_b,]``.""" mask = self.tgt_mask arr = jnp.ones(self.shape[1]) if mask is None else mask return arr / jnp.sum(arr) @staticmethod - def _normalize_mask(mask: Optional[Union[int, jax.Array]], - size: int) -> Optional[jax.Array]: + def _normalize_mask(mask: Optional[Union[int, jnp.ndarray]], + size: int) -> Optional[jnp.ndarray]: """Convert array of indices to a boolean mask.""" if mask is None: return None diff --git a/src/ott/geometry/graph.py b/src/ott/geometry/graph.py index ab0fe8768..c7dac0c99 100644 --- a/src/ott/geometry/graph.py +++ b/src/ott/geometry/graph.py @@ -48,7 +48,7 @@ class Graph(geometry.Geometry): def __init__( self, - laplacian: jax.Array, + laplacian: jnp.ndarray, t: float = 1e-3, n_steps: int = 100, numerical_scheme: Literal["backward_euler", @@ -66,7 +66,7 @@ def __init__( @classmethod def from_graph( cls, - G: jax.Array, + G: jnp.ndarray, t: Optional[float] = 1e-3, directed: bool = False, normalize: bool = False, @@ -113,10 +113,10 @@ def from_graph( def apply_kernel( self, - scaling: jax.Array, + scaling: jnp.ndarray, eps: Optional[float] = None, axis: int = 0, - ) 
-> jax.Array: + ) -> jnp.ndarray: r"""Apply :attr:`kernel_matrix` on positive scaling vector. Args: @@ -129,8 +129,8 @@ def apply_kernel( """ def conf_fn( - iteration: int, consts: Tuple[jax.Array, Optional[jax.Array]], - old_new: Tuple[jax.Array, jax.Array] + iteration: int, consts: Tuple[jnp.ndarray, Optional[jnp.ndarray]], + old_new: Tuple[jnp.ndarray, jnp.ndarray] ) -> bool: del iteration, consts @@ -143,9 +143,9 @@ def conf_fn( return (jnp.nanmax(f) - jnp.nanmin(f)) > self.tol def body_fn( - iteration: int, consts: Tuple[jax.Array, Optional[jax.Array]], - old_new: Tuple[jax.Array, jax.Array], compute_errors: bool - ) -> Tuple[jax.Array, jax.Array]: + iteration: int, consts: Tuple[jnp.ndarray, Optional[jnp.ndarray]], + old_new: Tuple[jnp.ndarray, jnp.ndarray], compute_errors: bool + ) -> Tuple[jnp.ndarray, jnp.ndarray]: del iteration, compute_errors L, scaled_lap = consts @@ -186,7 +186,7 @@ def body_fn( )[1] @property - def kernel_matrix(self) -> jax.Array: # noqa: D102 + def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 n, _ = self.shape kernel = self.apply_kernel(jnp.eye(n)) # force symmetry because of numerical imprecision @@ -194,7 +194,7 @@ def kernel_matrix(self) -> jax.Array: # noqa: D102 return (kernel + kernel.T) * 0.5 @property - def cost_matrix(self) -> jax.Array: # noqa: D102 + def cost_matrix(self) -> jnp.ndarray: # noqa: D102 return -self.t * mu.safe_log(self.kernel_matrix) @property @@ -209,12 +209,12 @@ def _scale(self) -> float: ) @property - def _scaled_laplacian(self) -> jax.Array: + def _scaled_laplacian(self) -> jnp.ndarray: """Laplacian scaled by a constant, depending on the numerical scheme.""" return self._scale * self.laplacian @property - def _M(self) -> jax.Array: + def _M(self) -> jnp.ndarray: n, _ = self.shape return self._scaled_laplacian + jnp.eye(n) @@ -230,27 +230,29 @@ def is_symmetric(self) -> bool: # noqa: D102 def dtype(self) -> jnp.dtype: # noqa: D102 return self.laplacian.dtype - def transport_from_potentials(self, f: 
jax.Array, g: jax.Array) -> jax.Array: + def transport_from_potentials( + self, f: jnp.ndarray, g: jnp.ndarray + ) -> jnp.ndarray: """Not implemented.""" raise ValueError("Not implemented.") def apply_transport_from_potentials( self, - f: jax.Array, - g: jax.Array, - vec: jax.Array, + f: jnp.ndarray, + g: jnp.ndarray, + vec: jnp.ndarray, axis: int = 0 - ) -> jax.Array: + ) -> jnp.ndarray: """Since applying from potentials is not feasible in grids, use scalings.""" u, v = self.scaling_from_potential(f), self.scaling_from_potential(g) return self.apply_transport_from_scalings(u, v, vec, axis=axis) def marginal_from_potentials( self, - f: jax.Array, - g: jax.Array, + f: jnp.ndarray, + g: jnp.ndarray, axis: int = 0, - ) -> jax.Array: + ) -> jnp.ndarray: """Not implemented.""" raise ValueError("Not implemented.") diff --git a/src/ott/geometry/grid.py b/src/ott/geometry/grid.py index 3401f52c7..fd64500c9 100644 --- a/src/ott/geometry/grid.py +++ b/src/ott/geometry/grid.py @@ -71,7 +71,7 @@ class Grid(geometry.Geometry): def __init__( self, - x: Optional[Sequence[jax.Array]] = None, + x: Optional[Sequence[jnp.ndarray]] = None, grid_size: Optional[Sequence[int]] = None, cost_fns: Optional[Sequence[costs.CostFn]] = None, num_a: Optional[int] = None, @@ -146,12 +146,12 @@ def is_symmetric(self) -> bool: # noqa: D102 # Reimplemented functions to be used in regularized OT def apply_lse_kernel( self, - f: jax.Array, - g: jax.Array, + f: jnp.ndarray, + g: jnp.ndarray, eps: float, - vec: Optional[jax.Array] = None, + vec: Optional[jnp.ndarray] = None, axis: int = 0 - ) -> jax.Array: + ) -> jnp.ndarray: """Apply grid kernel in log space. See notes in parent class for use case. Reshapes vector inputs below as grids, applies kernels onto each slice, and @@ -160,10 +160,10 @@ def apply_lse_kernel( More implementation details in :cite:`schmitz:18`. 
Args: - f: jax.Array, a vector of potentials - g: jax.Array, a vector of potentials + f: jnp.ndarray, a vector of potentials + g: jnp.ndarray, a vector of potentials eps: float, regularization strength - vec: jax.Array, if needed, a vector onto which apply the kernel weighted + vec: jnp.ndarray, if needed, a vector onto which apply the kernel weighted by f and g. axis: axis (0 or 1) along which summation should be carried out. @@ -209,8 +209,8 @@ def _apply_lse_kernel_one_dimension(self, dimension, f, g, eps, vec=None): return jnp.transpose(softmax_res, indices), None def _apply_cost_to_vec( - self, vec: jax.Array, axis: int = 0, fn=None - ) -> jax.Array: + self, vec: jnp.ndarray, axis: int = 0, fn=None + ) -> jnp.ndarray: r"""Apply grid's cost matrix (without instantiating it) to a vector. The `apply_cost` operation on grids rests on the following identity. @@ -229,13 +229,13 @@ def _apply_cost_to_vec( summation while keeping dimensions. Args: - vec: jax.Array, flat vector of total size prod(grid_size). + vec: jnp.ndarray, flat vector of total size prod(grid_size). axis: axis 0 if applying transpose costs, 1 if using the original cost. fn: function optionally applied to cost matrix element-wise, before the dot product. Returns: - A jax.Array corresponding to cost x matrix + A jnp.ndarray corresponding to cost x matrix """ vec = jnp.reshape(vec, self.grid_size) accum_vec = jnp.zeros_like(vec) @@ -255,10 +255,10 @@ def _apply_cost_to_vec( def apply_kernel( self, - scaling: jax.Array, + scaling: jnp.ndarray, eps: Optional[float] = None, axis: Optional[int] = None - ) -> jax.Array: + ) -> jnp.ndarray: """Apply grid kernel on scaling vector. See notes in parent class for use. @@ -269,7 +269,7 @@ def apply_kernel( More implementation details in :cite:`schmitz:18`, Args: - scaling: jax.Array, a vector of scaling (>0) values. + scaling: jnp.ndarray, a vector of scaling (>0) values. 
eps: float, regularization strength axis: axis (0 or 1) along which summation should be carried out. @@ -289,7 +289,7 @@ def apply_kernel( return scaling.ravel() def transport_from_potentials( - self, f: jax.Array, g: jax.Array, axis: int = 0 + self, f: jnp.ndarray, g: jnp.ndarray, axis: int = 0 ) -> NoReturn: """Not implemented, use :meth:`apply_transport_from_potentials` instead.""" raise ValueError( @@ -300,7 +300,7 @@ def transport_from_potentials( ) def transport_from_scalings( - self, f: jax.Array, g: jax.Array, axis: int = 0 + self, f: jnp.ndarray, g: jnp.ndarray, axis: int = 0 ) -> NoReturn: """Not implemented, use :meth:`apply_transport_from_scalings` instead.""" raise ValueError( @@ -311,15 +311,15 @@ def transport_from_scalings( ) def subset( - self, src_ixs: Optional[jax.Array], tgt_ixs: Optional[jax.Array] + self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray] ) -> NoReturn: """Not implemented.""" raise NotImplementedError("Subsetting is not implemented for grids.") def mask( self, - src_mask: Optional[jax.Array], - tgt_mask: Optional[jax.Array], + src_mask: Optional[jnp.ndarray], + tgt_mask: Optional[jnp.ndarray], mask_value: float = 0., ) -> NoReturn: """Not implemented.""" diff --git a/src/ott/geometry/low_rank.py b/src/ott/geometry/low_rank.py index 750d8db62..1bfaeae0a 100644 --- a/src/ott/geometry/low_rank.py +++ b/src/ott/geometry/low_rank.py @@ -33,8 +33,8 @@ class LRCGeometry(geometry.Geometry): if :math:`C = AB^T` and :math:`D = EF^T` then :math:`C + D = [A,E][B,F]^T` Args: - cost_1: jax.Array[num_a, r] - cost_2: jax.Array[num_b, r] + cost_1: jnp.ndarray[num_a, r] + cost_2: jnp.ndarray[num_b, r] bias: constant added to entire cost matrix. scale: Value used to rescale the factors of the low-rank geometry. scale_cost: option to rescale the cost matrix. 
Implemented scalings are @@ -51,8 +51,8 @@ class LRCGeometry(geometry.Geometry): def __init__( self, - cost_1: jax.Array, - cost_2: jax.Array, + cost_1: jnp.ndarray, + cost_2: jnp.ndarray, bias: float = 0.0, scale_factor: float = 1.0, scale_cost: Union[bool, int, float, Literal["mean", "max_bound", @@ -69,13 +69,13 @@ def __init__( self.batch_size = batch_size @property - def cost_1(self) -> jax.Array: + def cost_1(self) -> jnp.ndarray: """First factor of the :attr:`cost_matrix`.""" scale_factor = jnp.sqrt(self._scale_factor * self.inv_scale_cost) return scale_factor * self._cost_1 @property - def cost_2(self) -> jax.Array: + def cost_2(self) -> jnp.ndarray: """Second factor of the :attr:`cost_matrix`.""" scale_factor = jnp.sqrt(self._scale_factor * self.inv_scale_cost) return scale_factor * self._cost_2 @@ -90,7 +90,7 @@ def cost_rank(self) -> int: # noqa: D102 return self._cost_1.shape[1] @property - def cost_matrix(self) -> jax.Array: + def cost_matrix(self) -> jnp.ndarray: """Materialize the cost matrix.""" return jnp.matmul(self.cost_1, self.cost_2.T) + self.bias @@ -107,7 +107,7 @@ def is_symmetric(self) -> bool: # noqa: D102 @property def inv_scale_cost(self) -> float: # noqa: D102 - if isinstance(self._scale_cost, (int, float, jax.Array)): + if isinstance(self._scale_cost, (int, float, jnp.ndarray)): return 1.0 / self._scale_cost self = self._masked_geom() if self._scale_cost == "max_bound": @@ -124,7 +124,7 @@ def inv_scale_cost(self) -> float: # noqa: D102 return 1.0 / self.compute_max_cost() raise ValueError(f"Scaling {self._scale_cost} not implemented.") - def apply_square_cost(self, arr: jax.Array, axis: int = 0) -> jax.Array: + def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray: """Apply elementwise-square of cost matrix to array (vector or matrix).""" (n, m), r = self.shape, self.cost_rank # When applying square of a LRCGeometry, one can either elementwise square @@ -142,15 +142,15 @@ def apply_square_cost(self, arr: 
jax.Array, axis: int = 0) -> jax.Array: def _apply_cost_to_vec( self, - vec: jax.Array, + vec: jnp.ndarray, axis: int = 0, - fn: Optional[Callable[[jax.Array], jax.Array]] = None, + fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, is_linear: bool = False, - ) -> jax.Array: + ) -> jnp.ndarray: """Apply [num_a, num_b] fn(cost) (or transpose) to vector. Args: - vec: jax.Array [num_a,] ([num_b,] if axis=1) vector + vec: jnp.ndarray [num_a,] ([num_b,] if axis=1) vector axis: axis on which the reduction is done. fn: function optionally applied to cost matrix element-wise, before the doc product @@ -159,12 +159,12 @@ def _apply_cost_to_vec( for a heuristic to help determine if a function is linear. Returns: - A jax.Array corresponding to cost x vector + A jnp.ndarray corresponding to cost x vector """ def linear_apply( - vec: jax.Array, axis: int, fn: Callable[[jax.Array], jax.Array] - ) -> jax.Array: + vec: jnp.ndarray, axis: int, fn: Callable[[jnp.ndarray], jnp.ndarray] + ) -> jnp.ndarray: c1 = self.cost_1 if axis == 1 else self.cost_2 c2 = self.cost_2 if axis == 1 else self.cost_1 c2 = fn(c2) if fn is not None else c2 @@ -229,7 +229,7 @@ def to_LRCGeometry( self, rank: int = 0, tol: float = 1e-2, - rng: Optional[jax.Array] = None, + rng: Optional[jnp.ndarray] = None, scale: float = 1.0, ) -> "LRCGeometry": """Return self.""" @@ -241,14 +241,14 @@ def can_LRC(self): # noqa: D102 return True def subset( # noqa: D102 - self, src_ixs: Optional[jax.Array], tgt_ixs: Optional[jax.Array], + self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray], **kwargs: Any ) -> "LRCGeometry": def subset_fn( - arr: Optional[jax.Array], - ixs: Optional[jax.Array], - ) -> jax.Array: + arr: Optional[jnp.ndarray], + ixs: Optional[jnp.ndarray], + ) -> jnp.ndarray: return arr if arr is None or ixs is None else arr[jnp.atleast_1d(ixs)] return self._mask_subset_helper( @@ -257,15 +257,15 @@ def subset_fn( def mask( # noqa: D102 self, - src_mask: Optional[jax.Array], - 
tgt_mask: Optional[jax.Array], + src_mask: Optional[jnp.ndarray], + tgt_mask: Optional[jnp.ndarray], mask_value: float = 0., ) -> "LRCGeometry": def mask_fn( - arr: Optional[jax.Array], - mask: Optional[jax.Array], - ) -> Optional[jax.Array]: + arr: Optional[jnp.ndarray], + mask: Optional[jnp.ndarray], + ) -> Optional[jnp.ndarray]: if arr is None or mask is None: return arr return jnp.where(mask[:, None], arr, mask_value) @@ -278,11 +278,11 @@ def mask_fn( def _mask_subset_helper( self, - src_ixs: Optional[jax.Array], - tgt_ixs: Optional[jax.Array], + src_ixs: Optional[jnp.ndarray], + tgt_ixs: Optional[jnp.ndarray], *, - fn: Callable[[Optional[jax.Array], Optional[jax.Array]], - Optional[jax.Array]], + fn: Callable[[Optional[jnp.ndarray], Optional[jnp.ndarray]], + Optional[jnp.ndarray]], propagate_mask: bool, **kwargs: Any, ) -> "LRCGeometry": diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index c5d48a096..2050e1562 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -56,8 +56,8 @@ class PointCloud(geometry.Geometry): def __init__( self, - x: jax.Array, - y: Optional[jax.Array] = None, + x: jnp.ndarray, + y: Optional[jnp.ndarray] = None, cost_fn: Optional[costs.CostFn] = None, batch_size: Optional[int] = None, scale_cost: Union[bool, int, float, @@ -77,13 +77,13 @@ def __init__( self._scale_cost = "mean" if scale_cost is True else scale_cost @property - def _norm_x(self) -> Union[float, jax.Array]: + def _norm_x(self) -> Union[float, jnp.ndarray]: if self._axis_norm == 0: return self.cost_fn.norm(self.x) return 0. @property - def _norm_y(self) -> Union[float, jax.Array]: + def _norm_y(self) -> Union[float, jnp.ndarray]: if self._axis_norm == 0: return self.cost_fn.norm(self.y) return 0. 
@@ -98,14 +98,14 @@ def _check_LRC_dim(self): return n * m > (n + m) * d @property - def cost_matrix(self) -> Optional[jax.Array]: # noqa: D102 + def cost_matrix(self) -> Optional[jnp.ndarray]: # noqa: D102 if self.is_online: return None cost_matrix = self._compute_cost_matrix() return cost_matrix * self.inv_scale_cost @property - def kernel_matrix(self) -> Optional[jax.Array]: # noqa: D102 + def kernel_matrix(self) -> Optional[jnp.ndarray]: # noqa: D102 if self.is_online: return None return jnp.exp(-self.cost_matrix / self.epsilon) @@ -141,7 +141,7 @@ def cost_rank(self) -> int: # noqa: D102 @property def inv_scale_cost(self) -> float: # noqa: D102 - if isinstance(self._scale_cost, (int, float, jax.Array)): + if isinstance(self._scale_cost, (int, float, jnp.ndarray)): return 1.0 / self._scale_cost self = self._masked_geom() if self._scale_cost == "max_cost": @@ -183,7 +183,7 @@ def inv_scale_cost(self) -> float: # noqa: D102 ) raise ValueError(f"Scaling {self._scale_cost} not implemented.") - def _compute_cost_matrix(self) -> jax.Array: + def _compute_cost_matrix(self) -> jnp.ndarray: cost_matrix = self.cost_fn.all_pairs_pairwise(self.x, self.y) if self._axis_norm is not None: cost_matrix += self._norm_x[:, jnp.newaxis] + self._norm_y[jnp.newaxis, :] @@ -191,12 +191,12 @@ def _compute_cost_matrix(self) -> jax.Array: def apply_lse_kernel( # noqa: D102 self, - f: jax.Array, - g: jax.Array, + f: jnp.ndarray, + g: jnp.ndarray, eps: float, - vec: Optional[jax.Array] = None, + vec: Optional[jnp.ndarray] = None, axis: int = 0 - ) -> jax.Array: + ) -> jnp.ndarray: def body0(carry, i: int): f, g, eps, vec = carry @@ -278,10 +278,10 @@ def finalize(i: int): def apply_kernel( # noqa: D102 self, - scaling: jax.Array, + scaling: jnp.ndarray, eps: Optional[float] = None, axis: int = 0 - ) -> jax.Array: + ) -> jnp.ndarray: if eps is None: eps = self.epsilon @@ -303,8 +303,8 @@ def apply_kernel( # noqa: D102 ) def transport_from_potentials( # noqa: D102 - self, f: jax.Array, g: 
jax.Array - ) -> jax.Array: + self, f: jnp.ndarray, g: jnp.ndarray + ) -> jnp.ndarray: if not self.is_online: return super().transport_from_potentials(f, g) transport = jax.vmap( @@ -317,8 +317,8 @@ def transport_from_potentials( # noqa: D102 ) def transport_from_scalings( # noqa: D102 - self, u: jax.Array, v: jax.Array - ) -> jax.Array: + self, u: jnp.ndarray, v: jnp.ndarray + ) -> jnp.ndarray: if not self.is_online: return super().transport_from_scalings(u, v) transport = jax.vmap( @@ -342,11 +342,11 @@ def transport_from_scalings( # noqa: D102 def apply_cost( self, - arr: jax.Array, + arr: jnp.ndarray, axis: int = 0, - fn: Optional[Callable[[jax.Array], jax.Array]] = None, + fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, is_linear: bool = False, - ) -> jax.Array: + ) -> jnp.ndarray: """Apply cost matrix to array (vector or matrix). This function applies the geometry's cost matrix, to perform either @@ -356,7 +356,7 @@ def apply_cost( application of fn to each entry of the :attr:`cost_matrix`. Args: - arr: jax.Array [num_a or num_b, batch], vector that will be multiplied + arr: jnp.ndarray [num_a or num_b, batch], vector that will be multiplied by the cost matrix. axis: standard cost matrix if axis=1, transpose if 0. fn: function optionally applied to cost matrix element-wise, before the @@ -367,7 +367,7 @@ def apply_cost( for a heuristic to help determine if a function is linear. Returns: - A jax.Array, [num_b, batch] if axis=0 or [num_a, batch] if axis=1 + A jnp.ndarray, [num_b, batch] if axis=0 or [num_a, batch] if axis=1 """ # switch to efficient computation for the squared euclidean case. 
if self.is_squared_euclidean and (fn is None or is_linear): @@ -375,7 +375,9 @@ def apply_cost( return self._apply_cost(arr, axis, fn=fn) - def _apply_cost(self, arr: jax.Array, axis: int = 0, fn=None) -> jax.Array: + def _apply_cost( + self, arr: jnp.ndarray, axis: int = 0, fn=None + ) -> jnp.ndarray: """See :meth:`apply_cost`.""" if not self.is_online: return super().apply_cost(arr, axis, fn) @@ -399,24 +401,24 @@ def _apply_cost(self, arr: jax.Array, axis: int = 0, fn=None) -> jax.Array: def vec_apply_cost( self, - arr: jax.Array, + arr: jnp.ndarray, axis: int = 0, - fn: Optional[Callable[[jax.Array], jax.Array]] = None - ) -> jax.Array: + fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None + ) -> jnp.ndarray: """Apply the geometry's cost matrix in a vectorized way. This function can be used when the cost matrix is squared euclidean and ``fn`` is a linear function. Args: - arr: jax.Array [num_a or num_b, p], vector that will be multiplied + arr: jnp.ndarray [num_a or num_b, p], vector that will be multiplied by the cost matrix. axis: standard cost matrix if axis=1, transport if 0. fn: function optionally applied to cost matrix element-wise, before the application. Returns: - A jax.Array, [num_b, p] if axis=0 or [num_a, p] if axis=1 + A jnp.ndarray, [num_b, p] if axis=0 or [num_a, p] if axis=1 """ assert self.is_squared_euclidean, "Cost matrix is not a squared Euclidean." rank = arr.ndim @@ -432,7 +434,7 @@ def vec_apply_cost( applied_cost = fn(applied_cost) return self.inv_scale_cost * applied_cost - def _leading_slice(self, t: jax.Array, i: int) -> jax.Array: + def _leading_slice(self, t: jnp.ndarray, i: int) -> jnp.ndarray: start_indices = [i * self.batch_size] + (t.ndim - 1) * [0] slice_sizes = [self.batch_size] + list(t.shape[1:]) return jax.lax.dynamic_slice(t, start_indices, slice_sizes) @@ -523,18 +525,18 @@ def finalize(i: int): f"Scaling method {summary} does not exist for online mode." 
) - def barycenter(self, weights: jax.Array) -> jax.Array: + def barycenter(self, weights: jnp.ndarray) -> jnp.ndarray: """Compute barycenter of points in self.x using weights.""" return self.cost_fn.barycenter(self.x, weights)[0] @classmethod def prepare_divergences( cls, - x: jax.Array, - y: jax.Array, + x: jnp.ndarray, + y: jnp.ndarray, static_b: bool = False, - src_mask: Optional[jax.Array] = None, - tgt_mask: Optional[jax.Array] = None, + src_mask: Optional[jnp.ndarray] = None, + tgt_mask: Optional[jnp.ndarray] = None, **kwargs: Any ) -> Tuple["PointCloud", ...]: """Instantiate the geometries used for a divergence computation.""" @@ -638,14 +640,14 @@ def _sqeucl_to_lr(self, scale: float = 1.0) -> low_rank.LRCGeometry: ) def subset( # noqa: D102 - self, src_ixs: Optional[jax.Array], tgt_ixs: Optional[jax.Array], + self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray], **kwargs: Any ) -> "PointCloud": def subset_fn( - arr: Optional[jax.Array], - ixs: Optional[jax.Array], - ) -> jax.Array: + arr: Optional[jnp.ndarray], + ixs: Optional[jnp.ndarray], + ) -> jnp.ndarray: return arr if arr is None or ixs is None else arr[jnp.atleast_1d(ixs)] return self._mask_subset_helper( @@ -654,15 +656,15 @@ def subset_fn( def mask( # noqa: D102 self, - src_mask: Optional[jax.Array], - tgt_mask: Optional[jax.Array], + src_mask: Optional[jnp.ndarray], + tgt_mask: Optional[jnp.ndarray], mask_value: float = 0., ) -> "PointCloud": def mask_fn( - arr: Optional[jax.Array], - mask: Optional[jax.Array], - ) -> Optional[jax.Array]: + arr: Optional[jnp.ndarray], + mask: Optional[jnp.ndarray], + ) -> Optional[jnp.ndarray]: if arr is None or mask is None: return arr return jnp.where(mask[:, None], arr, mask_value) @@ -675,11 +677,11 @@ def mask_fn( def _mask_subset_helper( self, - src_ixs: Optional[jax.Array], - tgt_ixs: Optional[jax.Array], + src_ixs: Optional[jnp.ndarray], + tgt_ixs: Optional[jnp.ndarray], *, - fn: Callable[[Optional[jax.Array], Optional[jax.Array]], - 
Optional[jax.Array]], + fn: Callable[[Optional[jnp.ndarray], Optional[jnp.ndarray]], + Optional[jnp.ndarray]], propagate_mask: bool, **kwargs: Any, ) -> "PointCloud": @@ -765,18 +767,18 @@ def _apply_cost_xy(x, y, norm_x, norm_y, vec, cost_fn, scale_cost, fn=None): fn(cost) matrix (or transpose) to vector. Args: - x: jax.Array [num_a, d], first pointcloud - y: jax.Array [num_b, d], second pointcloud - norm_x: jax.Array [num_a,], (squared) norm as defined in by cost_fn - norm_y: jax.Array [num_b,], (squared) norm as defined in by cost_fn - vec: jax.Array [num_a,] ([num_b,] if axis=1 from `apply_cost`) vector + x: jnp.ndarray [num_a, d], first pointcloud + y: jnp.ndarray [num_b, d], second pointcloud + norm_x: jnp.ndarray [num_a,], (squared) norm as defined in by cost_fn + norm_y: jnp.ndarray [num_b,], (squared) norm as defined in by cost_fn + vec: jnp.ndarray [num_a,] ([num_b,] if axis=1 from `apply_cost`) vector cost_fn: a CostFn function between two points in dimension d. scale_cost: scaling factor of the cost matrix. fn: function optionally applied to cost matrix element-wise, before the apply. 
Returns: - A jax.Array corresponding to cost x vector + A jnp.ndarray corresponding to cost x vector """ c = _cost(x, y, norm_x, norm_y, cost_fn, scale_cost) return jnp.dot(c, vec) if fn is None else jnp.dot(fn(c), vec) diff --git a/src/ott/geometry/segment.py b/src/ott/geometry/segment.py index 5e2c764c8..20a1ee92b 100644 --- a/src/ott/geometry/segment.py +++ b/src/ott/geometry/segment.py @@ -21,15 +21,15 @@ def segment_point_cloud( - x: jax.Array, - a: Optional[jax.Array] = None, + x: jnp.ndarray, + a: Optional[jnp.ndarray] = None, num_segments: Optional[int] = None, max_measure_size: Optional[int] = None, - segment_ids: Optional[jax.Array] = None, + segment_ids: Optional[jnp.ndarray] = None, indices_are_sorted: bool = False, num_per_segment: Optional[Tuple[int, ...]] = None, - padding_vector: Optional[jax.Array] = None -) -> Tuple[jax.Array, jax.Array]: + padding_vector: Optional[jnp.ndarray] = None +) -> Tuple[jnp.ndarray, jnp.ndarray]: """Segment and pad as needed the entries of a point cloud. 
There are two interfaces: @@ -129,20 +129,21 @@ def segment_point_cloud( def _segment_interface( - x: jax.Array, - y: jax.Array, - eval_fn: Callable[[jax.Array, jax.Array, jax.Array, jax.Array], jax.Array], + x: jnp.ndarray, + y: jnp.ndarray, + eval_fn: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], + jnp.ndarray], num_segments: Optional[int] = None, max_measure_size: Optional[int] = None, - segment_ids_x: Optional[jax.Array] = None, - segment_ids_y: Optional[jax.Array] = None, + segment_ids_x: Optional[jnp.ndarray] = None, + segment_ids_y: Optional[jnp.ndarray] = None, indices_are_sorted: bool = False, - num_per_segment_x: Optional[jax.Array] = None, - num_per_segment_y: Optional[jax.Array] = None, - weights_x: Optional[jax.Array] = None, - weights_y: Optional[jax.Array] = None, - padding_vector: Optional[jax.Array] = None, -) -> jax.Array: + num_per_segment_x: Optional[jnp.ndarray] = None, + num_per_segment_y: Optional[jnp.ndarray] = None, + weights_x: Optional[jnp.ndarray] = None, + weights_y: Optional[jnp.ndarray] = None, + padding_vector: Optional[jnp.ndarray] = None, +) -> jnp.ndarray: """Wrapper to segment two point clouds and return parallel evaluations. Utility function that segments two point clouds using the approach outlined diff --git a/src/ott/initializers/linear/initializers.py b/src/ott/initializers/linear/initializers.py index 58744cfb0..bc4871841 100644 --- a/src/ott/initializers/linear/initializers.py +++ b/src/ott/initializers/linear/initializers.py @@ -36,8 +36,8 @@ def init_dual_a( self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jax.Array] = None, - ) -> jax.Array: + rng: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: """Initialize Sinkhorn potential/scaling f_u. 
Args: @@ -54,8 +54,8 @@ def init_dual_b( self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jax.Array] = None, - ) -> jax.Array: + rng: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: """Initialize Sinkhorn potential/scaling g_v. Args: @@ -70,11 +70,11 @@ def init_dual_b( def __call__( self, ot_prob: linear_problem.LinearProblem, - a: Optional[jax.Array], - b: Optional[jax.Array], + a: Optional[jnp.ndarray], + b: Optional[jnp.ndarray], lse_mode: bool, - rng: Optional[jax.Array] = None, - ) -> Tuple[jax.Array, jax.Array]: + rng: Optional[jnp.ndarray] = None, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Initialize Sinkhorn potentials/scalings f_u and g_v. Args: @@ -128,8 +128,8 @@ def init_dual_a( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jax.Array] = None, - ) -> jax.Array: + rng: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: del rng return jnp.zeros_like(ot_prob.a) if lse_mode else jnp.ones_like(ot_prob.a) @@ -137,8 +137,8 @@ def init_dual_b( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jax.Array] = None, - ) -> jax.Array: + rng: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: del rng return jnp.zeros_like(ot_prob.b) if lse_mode else jnp.ones_like(ot_prob.b) @@ -158,8 +158,8 @@ def init_dual_a( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jax.Array] = None, - ) -> jax.Array: + rng: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: # import Gaussian here due to circular imports from ott.tools.gaussian_mixture import gaussian @@ -207,8 +207,8 @@ def __init__( self.vectorized_update = vectorized_update def _init_sorting_dual( - self, modified_cost: jax.Array, init_f: jax.Array - ) -> jax.Array: + self, modified_cost: jnp.ndarray, init_f: jnp.ndarray + ) -> jnp.ndarray: """Run DualSort algorithm. 
Args: @@ -221,15 +221,15 @@ def _init_sorting_dual( """ def body_fn( - state: Tuple[jax.Array, float, int] - ) -> Tuple[jax.Array, float, int]: + state: Tuple[jnp.ndarray, float, int] + ) -> Tuple[jnp.ndarray, float, int]: prev_f, _, it = state new_f = fn(prev_f, modified_cost) diff = jnp.sum((new_f - prev_f) ** 2) it += 1 return new_f, diff, it - def cond_fn(state: Tuple[jax.Array, float, int]) -> bool: + def cond_fn(state: Tuple[jnp.ndarray, float, int]) -> bool: _, diff, it = state return jnp.logical_and(diff > self.tolerance, it < self.max_iter) @@ -245,9 +245,9 @@ def init_dual_a( self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jax.Array] = None, - init_f: Optional[jax.Array] = None, - ) -> jax.Array: + rng: Optional[jnp.ndarray] = None, + init_f: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: """Apply DualSort algorithm. Args: @@ -324,8 +324,8 @@ def init_dual_a( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jax.Array] = None, - ) -> jax.Array: + rng: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: from ott.solvers import linear assert isinstance( @@ -373,7 +373,9 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 }) -def _vectorized_update(f: jax.Array, modified_cost: jax.Array) -> jax.Array: +def _vectorized_update( + f: jnp.ndarray, modified_cost: jnp.ndarray +) -> jnp.ndarray: """Inner loop DualSort Update. Args: @@ -386,7 +388,9 @@ def _vectorized_update(f: jax.Array, modified_cost: jax.Array) -> jax.Array: return jnp.min(modified_cost + f[None, :], axis=1) -def _coordinate_update(f: jax.Array, modified_cost: jax.Array) -> jax.Array: +def _coordinate_update( + f: jnp.ndarray, modified_cost: jnp.ndarray +) -> jnp.ndarray: """Coordinate-wise updates within inner loop. Args: @@ -397,7 +401,7 @@ def _coordinate_update(f: jax.Array, modified_cost: jax.Array) -> jax.Array: updated potential vector, f. 
""" - def body_fn(i: int, f: jax.Array) -> jax.Array: + def body_fn(i: int, f: jnp.ndarray) -> jnp.ndarray: new_f = jnp.min(modified_cost[i, :] + f) return f.at[i].set(new_f) diff --git a/src/ott/initializers/linear/initializers_lr.py b/src/ott/initializers/linear/initializers_lr.py index 9eb8e1231..b1f70d912 100644 --- a/src/ott/initializers/linear/initializers_lr.py +++ b/src/ott/initializers/linear/initializers_lr.py @@ -67,11 +67,11 @@ def __init__(self, rank: int, **kwargs: Any): def init_q( self, ot_prob: Problem_t, - rng: jax.Array, + rng: jnp.ndarray, *, - init_g: jax.Array, + init_g: jnp.ndarray, **kwargs: Any, - ) -> jax.Array: + ) -> jnp.ndarray: """Initialize the low-rank factor :math:`Q`. Args: @@ -88,11 +88,11 @@ def init_q( def init_r( self, ot_prob: Problem_t, - rng: jax.Array, + rng: jnp.ndarray, *, - init_g: jax.Array, + init_g: jnp.ndarray, **kwargs: Any, - ) -> jax.Array: + ) -> jnp.ndarray: """Initialize the low-rank factor :math:`R`. Args: @@ -109,9 +109,9 @@ def init_r( def init_g( self, ot_prob: Problem_t, - rng: jax.Array, + rng: jnp.ndarray, **kwargs: Any, - ) -> jax.Array: + ) -> jnp.ndarray: """Initialize the low-rank factor :math:`g`. Args: @@ -165,13 +165,13 @@ def from_solver( def __call__( self, ot_prob: Problem_t, - q: Optional[jax.Array] = None, - r: Optional[jax.Array] = None, - g: Optional[jax.Array] = None, + q: Optional[jnp.ndarray] = None, + r: Optional[jnp.ndarray] = None, + g: Optional[jnp.ndarray] = None, *, - rng: Optional[jax.Array] = None, + rng: Optional[jnp.ndarray] = None, **kwargs: Any - ) -> Tuple[jax.Array, jax.Array, jax.Array]: + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Initialize the factors :math:`Q`, :math:`R` and :math:`g`. 
Args: @@ -232,11 +232,11 @@ class RandomInitializer(LRInitializer): def init_q( # noqa: D102 self, ot_prob: Problem_t, - rng: jax.Array, + rng: jnp.ndarray, *, - init_g: jax.Array, + init_g: jnp.ndarray, **kwargs: Any, - ) -> jax.Array: + ) -> jnp.ndarray: del kwargs, init_g a = ot_prob.a init_q = jnp.abs(jax.random.normal(rng, (a.shape[0], self.rank))) @@ -245,11 +245,11 @@ def init_q( # noqa: D102 def init_r( # noqa: D102 self, ot_prob: Problem_t, - rng: jax.Array, + rng: jnp.ndarray, *, - init_g: jax.Array, + init_g: jnp.ndarray, **kwargs: Any, - ) -> jax.Array: + ) -> jnp.ndarray: del kwargs, init_g b = ot_prob.b init_r = jnp.abs(jax.random.normal(rng, (b.shape[0], self.rank))) @@ -258,9 +258,9 @@ def init_r( # noqa: D102 def init_g( # noqa: D102 self, ot_prob: Problem_t, - rng: jax.Array, + rng: jnp.ndarray, **kwargs: Any, - ) -> jax.Array: + ) -> jnp.ndarray: del kwargs init_g = jnp.abs(jax.random.uniform(rng, (self.rank,))) + 1. return init_g / jnp.sum(init_g) @@ -278,10 +278,10 @@ class Rank2Initializer(LRInitializer): def _compute_factor( self, ot_prob: Problem_t, - init_g: jax.Array, + init_g: jnp.ndarray, *, which: Literal["q", "r"], - ) -> jax.Array: + ) -> jnp.ndarray: a, b = ot_prob.a, ot_prob.b marginal = a if which == "q" else b n, r = marginal.shape[0], self.rank @@ -305,31 +305,31 @@ def _compute_factor( def init_q( # noqa: D102 self, ot_prob: Problem_t, - rng: jax.Array, + rng: jnp.ndarray, *, - init_g: jax.Array, + init_g: jnp.ndarray, **kwargs: Any, - ) -> jax.Array: + ) -> jnp.ndarray: del rng, kwargs return self._compute_factor(ot_prob, init_g, which="q") def init_r( # noqa: D102 self, ot_prob: Problem_t, - rng: jax.Array, + rng: jnp.ndarray, *, - init_g: jax.Array, + init_g: jnp.ndarray, **kwargs: Any, - ) -> jax.Array: + ) -> jnp.ndarray: del rng, kwargs return self._compute_factor(ot_prob, init_g, which="r") def init_g( # noqa: D102 self, ot_prob: Problem_t, - rng: jax.Array, + rng: jnp.ndarray, **kwargs: Any, - ) -> jax.Array: + ) -> 
jnp.ndarray: del rng, kwargs return jnp.ones((self.rank,)) / self.rank @@ -364,7 +364,7 @@ def __init__( self._sinkhorn_kwargs = {} if sinkhorn_kwargs is None else sinkhorn_kwargs @staticmethod - def _extract_array(geom: geometry.Geometry, *, first: bool) -> jax.Array: + def _extract_array(geom: geometry.Geometry, *, first: bool) -> jnp.ndarray: if isinstance(geom, pointcloud.PointCloud): return geom.x if first else geom.y if isinstance(geom, low_rank.LRCGeometry): @@ -376,12 +376,12 @@ def _extract_array(geom: geometry.Geometry, *, first: bool) -> jax.Array: def _compute_factor( self, ot_prob: Problem_t, - rng: jax.Array, + rng: jnp.ndarray, *, - init_g: jax.Array, + init_g: jnp.ndarray, which: Literal["q", "r"], **kwargs: Any, - ) -> jax.Array: + ) -> jnp.ndarray: from ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import sinkhorn @@ -418,11 +418,11 @@ def _compute_factor( def init_q( # noqa: D102 self, ot_prob: Problem_t, - rng: jax.Array, + rng: jnp.ndarray, *, - init_g: jax.Array, + init_g: jnp.ndarray, **kwargs: Any, - ) -> jax.Array: + ) -> jnp.ndarray: return self._compute_factor( ot_prob, rng, init_g=init_g, which="q", **kwargs ) @@ -430,11 +430,11 @@ def init_q( # noqa: D102 def init_r( # noqa: D102 self, ot_prob: Problem_t, - rng: jax.Array, + rng: jnp.ndarray, *, - init_g: jax.Array, + init_g: jnp.ndarray, **kwargs: Any, - ) -> jax.Array: + ) -> jnp.ndarray: return self._compute_factor( ot_prob, rng, init_g=init_g, which="r", **kwargs ) @@ -442,9 +442,9 @@ def init_r( # noqa: D102 def init_g( # noqa: D102 self, ot_prob: Problem_t, - rng: jax.Array, + rng: jnp.ndarray, **kwargs: Any, - ) -> jax.Array: + ) -> jnp.ndarray: del rng, kwargs return jnp.ones((self.rank,)) / self.rank @@ -498,25 +498,25 @@ def __init__( class Constants(NamedTuple): # noqa: D106 solver: "sinkhorn.Sinkhorn" geom: geometry.Geometry # (n, n) - marginal: jax.Array # (n,) - g: jax.Array # (r,) + marginal: 
jnp.ndarray # (n,) + g: jnp.ndarray # (r,) gamma: float threshold: float class State(NamedTuple): # noqa: D106 - factor: jax.Array - criterions: jax.Array + factor: jnp.ndarray + criterions: jnp.ndarray crossed_threshold: bool def _compute_factor( self, ot_prob: Problem_t, - rng: jax.Array, + rng: jnp.ndarray, *, - init_g: jax.Array, + init_g: jnp.ndarray, which: Literal["q", "r"], **kwargs: Any, - ) -> jax.Array: + ) -> jnp.ndarray: from ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import sinkhorn diff --git a/src/ott/initializers/quadratic/initializers.py b/src/ott/initializers/quadratic/initializers.py index 323570770..795e81ccc 100644 --- a/src/ott/initializers/quadratic/initializers.py +++ b/src/ott/initializers/quadratic/initializers.py @@ -125,7 +125,9 @@ class QuadraticInitializer(BaseQuadraticInitializer): defaults to the product coupling :math:`ab^T`. """ - def __init__(self, init_coupling: Optional[jax.Array] = None, **kwargs: Any): + def __init__( + self, init_coupling: Optional[jnp.ndarray] = None, **kwargs: Any + ): super().__init__(**kwargs) self.init_coupling = init_coupling diff --git a/src/ott/math/fixed_point_loop.py b/src/ott/math/fixed_point_loop.py index 5c8b7b94d..9034eba62 100644 --- a/src/ott/math/fixed_point_loop.py +++ b/src/ott/math/fixed_point_loop.py @@ -179,7 +179,7 @@ def fixpoint_iter_bwd( # The tree may contain some python floats g_constants = jax.tree_util.tree_map( lambda x: jnp.zeros_like(x, dtype=x.dtype) - if isinstance(x, (np.ndarray, jax.Array)) else 0, constants + if isinstance(x, (np.ndarray, jnp.ndarray)) else 0, constants ) def bwd_cond_fn(iteration_g_gconst): diff --git a/src/ott/math/matrix_square_root.py b/src/ott/math/matrix_square_root.py index 5089f14a0..4a0177780 100644 --- a/src/ott/math/matrix_square_root.py +++ b/src/ott/math/matrix_square_root.py @@ -25,13 +25,13 @@ @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5)) def 
sqrtm( - x: jax.Array, + x: jnp.ndarray, threshold: float = 1e-6, min_iterations: int = 0, inner_iterations: int = 10, max_iterations: int = 1000, regularization: float = 1e-6 -) -> Tuple[jax.Array, jax.Array, jax.Array]: +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Higham algorithm to compute matrix square root of p.d. matrix. See :cite:`higham:97`, eq. 2.6b @@ -118,10 +118,10 @@ def new_err(x, norm_x, y): def solve_sylvester_bartels_stewart( - a: jax.Array, - b: jax.Array, - c: jax.Array, -) -> jax.Array: + a: jnp.ndarray, + b: jnp.ndarray, + c: jnp.ndarray, +) -> jnp.ndarray: """Solve the real Sylvester equation AX - XB = C using Bartels-Stewart.""" # See https://nhigham.com/2020/09/01/what-is-the-sylvester-equation/ for # discussion of the algorithm (but note that in the derivation, the sign on @@ -153,13 +153,14 @@ def solve_sylvester_bartels_stewart( def sqrtm_fwd( - x: jax.Array, + x: jnp.ndarray, threshold: float, min_iterations: int, inner_iterations: int, max_iterations: int, regularization: float, -) -> Tuple[Tuple[jax.Array, jax.Array, jax.Array], Tuple[jax.Array, jax.Array]]: +) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, + jnp.ndarray]]: """Forward pass of custom VJP.""" sqrt_x, inv_sqrt_x, errors = sqrtm( x=x, @@ -178,9 +179,9 @@ def sqrtm_bwd( inner_iterations: int, max_iterations: int, regularization: float, - residual: Tuple[jax.Array, jax.Array], - cotangent: Tuple[jax.Array, jax.Array, jax.Array], -) -> Tuple[jax.Array]: + residual: Tuple[jnp.ndarray, jnp.ndarray], + cotangent: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], +) -> Tuple[jnp.ndarray]: """Compute the derivative by solving a Sylvester equation.""" del threshold, min_iterations, inner_iterations, \ max_iterations, regularization @@ -236,13 +237,13 @@ def sqrtm_bwd( @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5)) def sqrtm_only( # noqa: D103 - x: jax.Array, + x: jnp.ndarray, threshold: float = 1e-6, min_iterations: int = 0, 
inner_iterations: int = 10, max_iterations: int = 1000, regularization: float = 1e-6 -) -> jax.Array: +) -> jnp.ndarray: return sqrtm( x, threshold, min_iterations, inner_iterations, max_iterations, regularization @@ -250,9 +251,9 @@ def sqrtm_only( # noqa: D103 def sqrtm_only_fwd( # noqa: D103 - x: jax.Array, threshold: float, min_iterations: int, + x: jnp.ndarray, threshold: float, min_iterations: int, inner_iterations: int, max_iterations: int, regularization: float -) -> Tuple[jax.Array, jax.Array]: +) -> Tuple[jnp.ndarray, jnp.ndarray]: sqrt_x = sqrtm( x, threshold, min_iterations, inner_iterations, max_iterations, regularization @@ -262,9 +263,9 @@ def sqrtm_only_fwd( # noqa: D103 def sqrtm_only_bwd( # noqa: D103 threshold: float, min_iterations: int, inner_iterations: int, - max_iterations: int, regularization: float, sqrt_x: jax.Array, - cotangent: jax.Array -) -> Tuple[jax.Array]: + max_iterations: int, regularization: float, sqrt_x: jnp.ndarray, + cotangent: jnp.ndarray +) -> Tuple[jnp.ndarray]: del threshold, min_iterations, inner_iterations, \ max_iterations, regularization vjp = jnp.swapaxes( @@ -282,13 +283,13 @@ def sqrtm_only_bwd( # noqa: D103 @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5)) def inv_sqrtm_only( # noqa: D103 - x: jax.Array, + x: jnp.ndarray, threshold: float = 1e-6, min_iterations: int = 0, inner_iterations: int = 10, max_iterations: int = 1000, regularization: float = 1e-6 -) -> jax.Array: +) -> jnp.ndarray: return sqrtm( x, threshold, min_iterations, inner_iterations, max_iterations, regularization @@ -296,13 +297,13 @@ def inv_sqrtm_only( # noqa: D103 def inv_sqrtm_only_fwd( # noqa: D103 - x: jax.Array, + x: jnp.ndarray, threshold: float, min_iterations: int, inner_iterations: int, max_iterations: int, regularization: float, -) -> Tuple[jax.Array, jax.Array]: +) -> Tuple[jnp.ndarray, jnp.ndarray]: inv_sqrt_x = sqrtm( x, threshold, min_iterations, inner_iterations, max_iterations, regularization @@ -312,9 +313,9 
@@ def inv_sqrtm_only_fwd( # noqa: D103 def inv_sqrtm_only_bwd( # noqa: D103 threshold: float, min_iterations: int, inner_iterations: int, - max_iterations: int, regularization: float, residual: jax.Array, - cotangent: jax.Array -) -> Tuple[jax.Array]: + max_iterations: int, regularization: float, residual: jnp.ndarray, + cotangent: jnp.ndarray +) -> Tuple[jnp.ndarray]: del threshold, min_iterations, inner_iterations, \ max_iterations, regularization diff --git a/src/ott/math/unbalanced_functions.py b/src/ott/math/unbalanced_functions.py index 2d7baebb7..fc1aca9f3 100644 --- a/src/ott/math/unbalanced_functions.py +++ b/src/ott/math/unbalanced_functions.py @@ -13,32 +13,31 @@ # limitations under the License. from typing import Callable -import jax import jax.numpy as jnp -def phi_star(h: jax.Array, rho: float) -> jax.Array: +def phi_star(h: jnp.ndarray, rho: float) -> jnp.ndarray: """Legendre transform of KL, :cite:`sejourne:19`, p. 9.""" return rho * (jnp.exp(h / rho) - 1) -def derivative_phi_star(f: jax.Array, rho: float) -> jax.Array: +def derivative_phi_star(f: jnp.ndarray, rho: float) -> jnp.ndarray: """Derivative of Legendre transform of phi_starKL, see phi_star.""" # TODO(cuturi): use jax.grad directly. return jnp.exp(f / rho) def grad_of_marginal_fit( - c: jax.Array, h: jax.Array, tau: float, epsilon: float -) -> jax.Array: + c: jnp.ndarray, h: jnp.ndarray, tau: float, epsilon: float +) -> jnp.ndarray: """Compute grad of terms linked to marginals in objective. Computes gradient w.r.t. f ( or g) of terms in :cite:`sejourne:19`, left-hand-side of eq. 15 terms involving phi_star). Args: - c: jax.Array, first target marginal (either a or b in practice) - h: jax.Array, potential (either f or g in practice) + c: jnp.ndarray, first target marginal (either a or b in practice) + h: jnp.ndarray, potential (either f or g in practice) tau: float, strength (in ]0,1]) of regularizer w.r.t. 
marginal epsilon: regularization @@ -51,14 +50,14 @@ def grad_of_marginal_fit( return jnp.where(c > 0, c * derivative_phi_star(-h, r), 0.0) -def second_derivative_phi_star(f: jax.Array, rho: float) -> jax.Array: +def second_derivative_phi_star(f: jnp.ndarray, rho: float) -> jnp.ndarray: """Second Derivative of Legendre transform of KL, see phi_star.""" return jnp.exp(f / rho) / rho def diag_jacobian_of_marginal_fit( - c: jax.Array, h: jax.Array, tau: float, epsilon: float, - derivative: Callable[[jax.Array, float], jax.Array] + c: jnp.ndarray, h: jnp.ndarray, tau: float, epsilon: float, + derivative: Callable[[jnp.ndarray, float], jnp.ndarray] ): """Compute grad of terms linked to marginals in objective. @@ -66,8 +65,8 @@ def diag_jacobian_of_marginal_fit( left-hand-side of eq. 32 (terms involving phi_star) Args: - c: jax.Array, first target marginal (either a or b in practice) - h: jax.Array, potential (either f or g in practice) + c: jnp.ndarray, first target marginal (either a or b in practice) + h: jnp.ndarray, potential (either f or g in practice) tau: float, strength (in ]0,1]) of regularizer w.r.t. 
marginal epsilon: regularization derivative: Callable diff --git a/src/ott/math/utils.py b/src/ott/math/utils.py index 188707c10..8e7ea90ee 100644 --- a/src/ott/math/utils.py +++ b/src/ott/math/utils.py @@ -34,10 +34,10 @@ def safe_log( # noqa: D103 - x: jax.Array, + x: jnp.ndarray, *, eps: Optional[float] = None -) -> jax.Array: +) -> jnp.ndarray: if eps is None: eps = jnp.finfo(x.dtype).tiny return jnp.where(x > 0., jnp.log(x), jnp.log(eps)) @@ -46,11 +46,11 @@ def safe_log( # noqa: D103 @functools.partial(jax.custom_jvp, nondiff_argnums=[1, 2, 3]) @functools.partial(jax.jit, static_argnames=("ord", "axis", "keepdims")) def norm( - x: jax.Array, + x: jnp.ndarray, ord: Union[int, str, None] = None, axis: Union[None, Sequence[int], int] = None, keepdims: bool = False -) -> jax.Array: +) -> jnp.ndarray: """Computes order ord norm of vector, using `jnp.linalg` in forward pass. Evaluations of distances between a vector and itself using translation @@ -105,18 +105,18 @@ def norm_jvp(ord, axis, keepdims, primals, tangents): # TODO(michalk8): add axis argument -def kl(p: jax.Array, q: jax.Array) -> float: +def kl(p: jnp.ndarray, q: jnp.ndarray) -> float: """Kullback-Leibler divergence.""" return jnp.vdot(p, (safe_log(p) - safe_log(q))) -def gen_kl(p: jax.Array, q: jax.Array) -> float: +def gen_kl(p: jnp.ndarray, q: jnp.ndarray) -> float: """Generalized Kullback-Leibler divergence.""" return jnp.vdot(p, (safe_log(p) - safe_log(q))) + jnp.sum(q) - jnp.sum(p) # TODO(michalk8): add axis argument -def gen_js(p: jax.Array, q: jax.Array, c: float = 0.5) -> float: +def gen_js(p: jnp.ndarray, q: jnp.ndarray, c: float = 0.5) -> float: """Jensen-Shannon divergence.""" return c * (gen_kl(p, q) + gen_kl(q, p)) @@ -176,8 +176,8 @@ def logsumexp_jvp(axis, keepdims, return_sign, primals, tangents): @functools.partial(jax.custom_vjp, nondiff_argnums=(2,)) def softmin( - x: jax.Array, gamma: float, axis: Optional[int] = None -) -> jax.Array: + x: jnp.ndarray, gamma: float, axis: 
Optional[int] = None +) -> jnp.ndarray: r"""Soft-min operator. Args: @@ -205,8 +205,8 @@ def softmin( @functools.partial(jax.vmap, in_axes=[0, 0, None]) def barycentric_projection( - matrix: jax.Array, y: jax.Array, cost_fn: "costs.CostFn" -) -> jax.Array: + matrix: jnp.ndarray, y: jnp.ndarray, cost_fn: "costs.CostFn" +) -> jnp.ndarray: """Compute the barycentric projection of a matrix. Args: diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py index 0ebfc77a0..466460384 100644 --- a/src/ott/neural/data/dataloaders.py +++ b/src/ott/neural/data/dataloaders.py @@ -19,14 +19,14 @@ class ConditionalDataLoader: #TODO(@MUCDK) uncomment, resolve installation issu #def __init__( # self, rng: jax.random.KeyArray, dataloaders: Dict[str, tf.Dataloader], - # p: jax.Array + # p: jnp.ndarray #) -> None: # super().__init__() # self.rng = rng # self.conditions = dataloaders.keys() # self.p = p - #def __next__(self) -> jax.Array: + #def __next__(self) -> jnp.ndarray: # self.rng, rng = jax.random.split(self.rng, 2) # condition = jax.random.choice(rng, self.conditions, p=self.p) # return next(self.dataloaders[condition]) diff --git a/src/ott/neural/models/base_models.py b/src/ott/neural/models/base_models.py index 74a87df93..c96ad5b29 100644 --- a/src/ott/neural/models/base_models.py +++ b/src/ott/neural/models/base_models.py @@ -15,7 +15,7 @@ from typing import Optional import flax.linen as nn -import jax +import jax.numpy as jnp __all__ = ["BaseNeuralVectorField", "BaseRescalingNet"] @@ -25,11 +25,11 @@ class BaseNeuralVectorField(nn.Module, abc.ABC): @abc.abstractmethod def __call__( self, - t: jax.Array, - x: jax.Array, - condition: Optional[jax.Array] = None, - keys_model: Optional[jax.Array] = None - ) -> jax.Array: # noqa: D102): + t: jnp.ndarray, + x: jnp.ndarray, + condition: Optional[jnp.ndarray] = None, + keys_model: Optional[jnp.ndarray] = None + ) -> jnp.ndarray: # noqa: D102): pass @@ -37,6 +37,8 @@ class BaseRescalingNet(nn.Module, 
abc.ABC): @abc.abstractmethod def __call__( - self, x: jax.Array, condition: Optional[jax.Array] = None - ) -> jax.Array: + self, + x: jnp.ndarray, + condition: Optional[jnp.ndarray] = None + ) -> jnp.ndarray: pass diff --git a/src/ott/neural/models/conjugate_solvers.py b/src/ott/neural/models/conjugate_solvers.py index 4d3d8eea0..0758cf1ad 100644 --- a/src/ott/neural/models/conjugate_solvers.py +++ b/src/ott/neural/models/conjugate_solvers.py @@ -14,7 +14,6 @@ import abc from typing import Callable, Literal, NamedTuple, Optional -import jax import jax.numpy as jnp from jaxopt import LBFGS @@ -37,7 +36,7 @@ class ConjugateResults(NamedTuple): num_iter: the number of iterations taken by the solver """ val: float - grad: jax.Array + grad: jnp.ndarray num_iter: int @@ -51,9 +50,9 @@ class FenchelConjugateSolver(abc.ABC): @abc.abstractmethod def solve( self, - f: Callable[[jax.Array], jax.Array], - y: jax.Array, - x_init: Optional[jax.Array] = None + f: Callable[[jnp.ndarray], jnp.ndarray], + y: jnp.ndarray, + x_init: Optional[jnp.ndarray] = None ) -> ConjugateResults: """Solve for the conjugate. @@ -91,8 +90,8 @@ class FenchelConjugateLBFGS(FenchelConjugateSolver): def solve( # noqa: D102 self, - f: Callable[[jax.Array], jax.Array], - y: jax.Array, + f: Callable[[jnp.ndarray], jnp.ndarray], + y: jnp.ndarray, x_init: Optional[jnp.array] = None ) -> ConjugateResults: assert y.ndim == 1, y.ndim diff --git a/src/ott/neural/models/layers.py b/src/ott/neural/models/layers.py index 0eac7e626..dffd48276 100644 --- a/src/ott/neural/models/layers.py +++ b/src/ott/neural/models/layers.py @@ -19,7 +19,7 @@ __all__ = ["PositiveDense", "PosDefPotentials"] -PRNGKey = jax.Array +PRNGKey = jnp.ndarray Shape = Tuple[int, ...] Dtype = Any Array = Any @@ -40,9 +40,9 @@ class PositiveDense(nn.Module): bias_init: initializer function for the bias. 
""" dim_hidden: int - rectifier_fn: Callable[[jax.Array], jax.Array] = nn.softplus - inv_rectifier_fn: Callable[[jax.Array], - jax.Array] = lambda x: jnp.log(jnp.exp(x) - 1) + rectifier_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.softplus + inv_rectifier_fn: Callable[[jnp.ndarray], + jnp.ndarray] = lambda x: jnp.log(jnp.exp(x) - 1) use_bias: bool = True dtype: Any = jnp.float32 precision: Any = None @@ -51,7 +51,7 @@ class PositiveDense(nn.Module): bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros @nn.compact - def __call__(self, inputs: jax.Array) -> jax.Array: + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: """Applies a linear transformation to inputs along the last dimension. Args: @@ -99,7 +99,7 @@ class PosDefPotentials(nn.Module): bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros @nn.compact - def __call__(self, inputs: jax.Array) -> jax.Array: + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: """Apply a few quadratic forms. 
Args: diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 6bb075ff3..80326d3ca 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -67,9 +67,9 @@ class ICNN(neuraldual.BaseW2NeuralDual): dim_hidden: Sequence[int] init_std: float = 1e-2 init_fn: Callable = jax.nn.initializers.normal - act_fn: Callable[[jax.Array], jax.Array] = nn.relu + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu pos_weights: bool = True - gaussian_map_samples: Optional[Tuple[jax.Array, jax.Array]] = None + gaussian_map_samples: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None @property def is_potential(self) -> bool: # noqa: D102 @@ -151,8 +151,8 @@ def setup(self) -> None: # noqa: D102 @staticmethod def _compute_gaussian_map_params( - samples: Tuple[jax.Array, jax.Array] - ) -> Tuple[jax.Array, jax.Array]: + samples: Tuple[jnp.ndarray, jnp.ndarray] + ) -> Tuple[jnp.ndarray, jnp.ndarray]: from ott.tools.gaussian_mixture import gaussian source, target = samples g_s = gaussian.Gaussian.from_samples(source) @@ -165,13 +165,13 @@ def _compute_gaussian_map_params( @staticmethod def _compute_identity_map_params( input_dim: int - ) -> Tuple[jax.Array, jax.Array]: + ) -> Tuple[jnp.ndarray, jnp.ndarray]: A = jnp.eye(input_dim).reshape((1, input_dim, input_dim)) b = jnp.zeros((1, input_dim)) return A, b @nn.compact - def __call__(self, x: jax.Array) -> float: # noqa: D102 + def __call__(self, x: jnp.ndarray) -> float: # noqa: D102 z = self.act_fn(self.w_xs[0](x)) for i in range(self.num_hidden): z = jnp.add(self.w_zs[i](z), self.w_xs[i + 1](x)) @@ -194,10 +194,10 @@ class MLP(neuraldual.BaseW2NeuralDual): dim_hidden: Sequence[int] is_potential: bool = True - act_fn: Callable[[jax.Array], jax.Array] = nn.leaky_relu + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.leaky_relu @nn.compact - def __call__(self, x: jax.Array) -> jax.Array: # noqa: D102 + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # noqa: D102 squeeze = x.ndim 
== 1 if squeeze: x = jnp.expand_dims(x, 0) @@ -267,7 +267,7 @@ def __init__( meta_model: nn.Module, opt: Optional[optax.GradientTransformation ] = optax.adam(learning_rate=1e-3), # noqa: B008 - rng: Optional[jax.Array] = None, + rng: Optional[jnp.ndarray] = None, state: Optional[train_state.TrainState] = None ): self.geom = geom @@ -294,8 +294,8 @@ def __init__( self.update_impl = self._get_update_fn() def update( - self, state: train_state.TrainState, a: jax.Array, b: jax.Array - ) -> Tuple[jax.Array, jax.Array, train_state.TrainState]: + self, state: train_state.TrainState, a: jnp.ndarray, b: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray, train_state.TrainState]: r"""Update the meta model with the dual objective. The goal is for the model to match the optimal duals, i.e., @@ -333,8 +333,8 @@ def init_dual_a( # noqa: D102 self, ot_prob: "linear_problem.LinearProblem", lse_mode: bool, - rng: Optional[jax.Array] = None, - ) -> jax.Array: + rng: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: del rng # Detect if the problem is batched. assert ot_prob.a.ndim in (1, 2) @@ -387,9 +387,9 @@ def update(state, a, b): return update def _compute_f( - self, a: jax.Array, b: jax.Array, - params: frozen_dict.FrozenDict[str, jax.Array] - ) -> jax.Array: + self, a: jnp.ndarray, b: jnp.ndarray, + params: frozen_dict.FrozenDict[str, jnp.ndarray] + ) -> jnp.ndarray: r"""Predict the optimal :math:`f` potential. 
Args: @@ -431,10 +431,10 @@ class NeuralVectorField(BaseNeuralVectorField): t_embed_dim: Optional[int] = None joint_hidden_dim: Optional[int] = None num_layers_per_block: int = 3 - act_fn: Callable[[jax.Array], jax.Array] = nn.silu + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu n_frequencies: int = 128 - def time_encoder(self, t: jax.Array) -> jnp.array: + def time_encoder(self, t: jnp.ndarray) -> jnp.array: freq = 2 * jnp.arange(self.n_frequencies) * jnp.pi t = freq * t return jnp.concatenate((jnp.cos(t), jnp.sin(t)), axis=-1) @@ -464,11 +464,11 @@ def __post_init__(self): @nn.compact def __call__( self, - t: jax.Array, - x: jax.Array, - condition: Optional[jax.Array], - keys_model: Optional[jax.Array] = None, - ) -> jax.Array: + t: jnp.ndarray, + x: jnp.ndarray, + condition: Optional[jnp.ndarray], + keys_model: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: t = self.time_encoder(t) t = Block( @@ -537,12 +537,12 @@ class Rescaling_MLP(BaseRescalingNet): hidden_dim: int condition_dim: int num_layers_per_block: int = 3 - act_fn: Callable[[jax.Array], jax.Array] = nn.selu + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.selu @nn.compact def __call__( - self, x: jax.Array, condition: Optional[jax.Array] - ) -> jax.Array: # noqa: D102 + self, x: jnp.ndarray, condition: Optional[jnp.ndarray] + ) -> jnp.ndarray: # noqa: D102 x = Block( dim=self.hidden_dim, out_dim=self.hidden_dim, diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index e7216da46..0ad159a8f 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -83,10 +83,10 @@ def __init__(*args, **kwargs): def _resample_data( self, key: jax.random.KeyArray, - tmat: jax.Array, - source_arrays: Tuple[jax.Array, ...], - target_arrays: Tuple[jax.Array, ...], - ) -> Tuple[jax.Array, ...]: + tmat: jnp.ndarray, + source_arrays: Tuple[jnp.ndarray, ...], + target_arrays: Tuple[jnp.ndarray, ...], + ) -> Tuple[jnp.ndarray, ...]: 
"""Resample a batch according to coupling `tmat`.""" tmat_flattened = tmat.flatten() indices = random.choice(key, len(tmat_flattened), shape=[tmat.shape[0]]) @@ -101,10 +101,10 @@ def _resample_data( def _sample_conditional_indices_from_tmap( self, key: jax.random.PRNGKeyArray, - tmat: jax.Array, - k_samples_per_x: Union[int, jax.Array], - source_arrays: Tuple[jax.Array, ...], - target_arrays: Tuple[jax.Array, ...], + tmat: jnp.ndarray, + k_samples_per_x: Union[int, jnp.ndarray], + source_arrays: Tuple[jnp.ndarray, ...], + target_arrays: Tuple[jnp.ndarray, ...], *, source_is_balanced: bool, ) -> Tuple[jnp.array, jnp.array]: @@ -161,8 +161,8 @@ def _get_sinkhorn_match_fn( ) -> Callable: def match_pairs( - x: jax.Array, y: jax.Array - ) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array]: + x: jnp.ndarray, y: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: geom = pointcloud.PointCloud( x, y, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn ) @@ -171,8 +171,9 @@ def match_pairs( ).matrix def match_pairs_filtered( - x_lin: jax.Array, x_quad: jax.Array, y_lin: jax.Array, y_quad: jax.Array - ) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array]: + x_lin: jnp.ndarray, x_quad: jnp.ndarray, y_lin: jnp.ndarray, + y_quad: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: geom = pointcloud.PointCloud( x_lin, y_lin, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn ) @@ -214,10 +215,10 @@ def _get_gromov_match_fn( x_scale_cost = y_scale_cost = xy_scale_cost = scale_cost def match_pairs( - x_lin: Optional[jax.Array], - x_quad: Tuple[jax.Array, jax.Array], - y_lin: Optional[jax.Array], - y_quad: Tuple[jax.Array, jax.Array], + x_lin: Optional[jnp.ndarray], + x_quad: Tuple[jnp.ndarray, jnp.ndarray], + y_lin: Optional[jnp.ndarray], + y_quad: Tuple[jnp.ndarray, jnp.ndarray], ) -> Tuple[jnp.array, jnp.array]: geom_xx = pointcloud.PointCloud( x=x_quad, y=x_quad, cost_fn=x_cost_fn, scale_cost=x_scale_cost @@ -250,7 
+251,7 @@ class UnbalancednessMixin: def __init__( self, - rng: jax.Array, + rng: jnp.ndarray, source_dim: int, target_dim: int, cond_dim: Optional[int], @@ -298,13 +299,13 @@ def _get_compute_unbalanced_marginals( scale_cost: Union[bool, int, float, Literal["mean", "max_cost", "median"]] = "mean", sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}), - ) -> Tuple[jax.Array, jax.Array]: + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Compute the unbalanced source and target marginals for a batch.""" @jax.jit def compute_unbalanced_marginals( - batch_source: jax.Array, batch_target: jax.Array - ) -> Tuple[jax.Array, jax.Array]: + batch_source: jnp.ndarray, batch_target: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray]: geom = PointCloud( batch_source, batch_target, @@ -322,9 +323,9 @@ def compute_unbalanced_marginals( def _resample_unbalanced( self, key: jax.random.KeyArray, - batch: Tuple[jax.Array, ...], - marginals: jax.Array, - ) -> Tuple[jax.Array, ...]: + batch: Tuple[jnp.ndarray, ...], + marginals: jnp.ndarray, + ) -> Tuple[jnp.ndarray, ...]: """Resample a batch based upon marginals.""" indices = jax.random.choice( key, a=len(marginals), p=jnp.squeeze(marginals), shape=[len(marginals)] @@ -356,13 +357,14 @@ def _setup(self, source_dim: int, target_dim: int, cond_dim: int): def _get_rescaling_step_fn(self) -> Callable: # type:ignore[type-arg] def loss_a_fn( - params_eta: Optional[jax.Array], - apply_fn_eta: Callable[[Dict[str, jax.Array], jax.Array], jax.Array], - x: jax.Array, - condition: Optional[jax.Array], - a: jax.Array, + params_eta: Optional[jnp.ndarray], + apply_fn_eta: Callable[[Dict[str, jnp.ndarray], jnp.ndarray], + jnp.ndarray], + x: jnp.ndarray, + condition: Optional[jnp.ndarray], + a: jnp.ndarray, expectation_reweighting: float, - ) -> Tuple[float, jax.Array]: + ) -> Tuple[float, jnp.ndarray]: eta_predictions = apply_fn_eta({"params": params_eta}, x, condition) return ( optax.l2_loss(eta_predictions[:, 0], a).mean() + @@ -371,13 +373,14 @@ def 
loss_a_fn( ) def loss_b_fn( - params_xi: Optional[jax.Array], - apply_fn_xi: Callable[[Dict[str, jax.Array], jax.Array], jax.Array], - x: jax.Array, - condition: Optional[jax.Array], - b: jax.Array, + params_xi: Optional[jnp.ndarray], + apply_fn_xi: Callable[[Dict[str, jnp.ndarray], jnp.ndarray], + jnp.ndarray], + x: jnp.ndarray, + condition: Optional[jnp.ndarray], + b: jnp.ndarray, expectation_reweighting: float, - ) -> Tuple[float, jax.Array]: + ) -> Tuple[float, jnp.ndarray]: xi_predictions = apply_fn_xi({"params": params_xi}, x, condition) return ( optax.l2_loss(xi_predictions[:, 0], b).mean() + @@ -387,11 +390,11 @@ def loss_b_fn( @jax.jit def step_fn( - source: jax.Array, - target: jax.Array, - condition: Optional[jax.Array], - a: jax.Array, - b: jax.Array, + source: jnp.ndarray, + target: jnp.ndarray, + condition: Optional[jnp.ndarray], + a: jnp.ndarray, + b: jnp.ndarray, state_eta: Optional[train_state.TrainState] = None, state_xi: Optional[train_state.TrainState] = None, *, @@ -434,8 +437,8 @@ def step_fn( return step_fn def evaluate_eta( - self, source: jax.Array, condition: Optional[jax.Array] - ) -> jax.Array: + self, source: jnp.ndarray, condition: Optional[jnp.ndarray] + ) -> jnp.ndarray: """Evaluate the left learnt rescaling factor. Args: @@ -448,12 +451,12 @@ def evaluate_eta( if self.state_eta is None: raise ValueError("The left rescaling factor was not parameterized.") return self.state_eta.apply_fn({"params": self.state_eta.params}, - x=source, - condition=condition) + x=source, + condition=condition) def evaluate_xi( - self, target: jax.Array, condition: Optional[jax.Array] - ) -> jax.Array: + self, target: jnp.ndarray, condition: Optional[jnp.ndarray] + ) -> jnp.ndarray: """Evaluate the right learnt rescaling factor. 
Args: diff --git a/src/ott/neural/solvers/flows.py b/src/ott/neural/solvers/flows.py index 6450e2c1b..b02981fc9 100644 --- a/src/ott/neural/solvers/flows.py +++ b/src/ott/neural/solvers/flows.py @@ -33,7 +33,7 @@ def __init__(self, sigma: float) -> None: self.sigma = sigma @abc.abstractmethod - def compute_mu_t(self, t: jax.Array, x_0: jax.Array, x_1: jax.Array): + def compute_mu_t(self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray): """Compute the mean of the probablitiy path between :math:`x` and :math:`y` at time :math:`t`. Args: @@ -44,7 +44,7 @@ def compute_mu_t(self, t: jax.Array, x_0: jax.Array, x_1: jax.Array): pass @abc.abstractmethod - def compute_sigma_t(self, t: jax.Array): + def compute_sigma_t(self, t: jnp.ndarray): """Compute the standard deviation of the probablity path at time :math:`t`. Args: @@ -54,8 +54,8 @@ def compute_sigma_t(self, t: jax.Array): @abc.abstractmethod def compute_ut( - self, t: jax.Array, x_0: jax.Array, x_1: jax.Array - ) -> jax.Array: + self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray + ) -> jnp.ndarray: """Evaluate the conditional vector field defined between :math:`x_0` and :math:`x_1` at time :math:`t`. Args: @@ -66,8 +66,9 @@ def compute_ut( pass def compute_xt( - self, noise: jax.Array, t: jax.Array, x_0: jax.Array, x_1: jax.Array - ) -> jax.Array: + self, noise: jnp.ndarray, t: jnp.ndarray, x_0: jnp.ndarray, + x_1: jnp.ndarray + ) -> jnp.ndarray: """Sample from the probability path between :math:`x_0` and :math:`x_1` at time :math:`t`. Args: @@ -88,8 +89,8 @@ class StraightFlow(BaseFlow, abc.ABC): """Base class for flows with straight paths.""" def compute_mu_t( - self, t: jax.Array, x_0: jax.Array, x_1: jax.Array - ) -> jax.Array: + self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray + ) -> jnp.ndarray: """Compute the mean of the probablitiy path between :math:`x` and :math:`y` at time :math:`t`. 
Args: @@ -100,8 +101,8 @@ def compute_mu_t( return t * x_0 + (1 - t) * x_1 def compute_ut( - self, t: jax.Array, x_0: jax.Array, x_1: jax.Array - ) -> jax.Array: + self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray + ) -> jnp.ndarray: """Evaluate the conditional vector field defined between :math:`x_0` and :math:`x_1` at time :math:`t`. Args: @@ -118,7 +119,7 @@ def compute_ut( class ConstantNoiseFlow(StraightFlow): r"""Flow with straight paths and constant flow noise :math:`\sigma`.""" - def compute_sigma_t(self, t: jax.Array): + def compute_sigma_t(self, t: jnp.ndarray): r"""Compute noise of the flow at time :math:`t`. Args: @@ -133,7 +134,7 @@ def compute_sigma_t(self, t: jax.Array): class BrownianNoiseFlow(StraightFlow): r"""Sampler for sampling noise implicitly defined by a Schroedinger Bridge problem with parameter `\sigma` such that :math:`\sigma_t = \sigma * \sqrt(t * (1-t))`.""" - def compute_sigma_t(self, t: jax.Array): + def compute_sigma_t(self, t: jnp.ndarray): """Compute the standard deviation of the probablity path at time :math:`t`. Args: @@ -158,7 +159,7 @@ def __init__(self, low: float, high: float) -> None: self.high = high @abc.abstractmethod - def __call__(self, rng: jax.Array, num_samples: int) -> jax.Array: + def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: """Generate `num_samples` samples of the time `math`:t:. Args: @@ -179,7 +180,7 @@ class UniformSampler(BaseTimeSampler): def __init__(self, low: float = 0.0, high: float = 1.0) -> None: super().__init__(low=low, high=high) - def __call__(self, rng: jax.Array, num_samples: int) -> jax.Array: + def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: """Generate `num_samples` samples of the time `math`:t:. 
Args: @@ -209,7 +210,7 @@ def __init__( super().__init__(low=low, high=high) self.offset = offset - def __call__(self, rng: jax.Array, num_samples: int) -> jax.Array: + def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: """Generate `num_samples` samples of the time `math`:t:. Args: diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 377ef033d..7b0867e44 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -113,10 +113,10 @@ def __init__( fused_penalty: float = 0.0, tau_a: float = 1.0, tau_b: float = 1.0, - mlp_eta: Callable[[jax.Array], float] = None, - mlp_xi: Callable[[jax.Array], float] = None, + mlp_eta: Callable[[jnp.ndarray], float] = None, + mlp_xi: Callable[[jnp.ndarray], float] = None, unbalanced_kwargs: Dict[str, Any] = {}, - callback_fn: Optional[Callable[[jax.Array, jax.Array, jax.Array], + callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any]] = None, rng: random.PRNGKeyArray = random.PRNGKey(0), ) -> None: @@ -323,7 +323,7 @@ def step_fn( ): def loss_fn( - params: jax.Array, batch: Dict[str, jnp.array], + params: jnp.ndarray, batch: Dict[str, jnp.array], keys_model: random.PRNGKeyArray ): x_t = self.flow.compute_xt( @@ -356,12 +356,12 @@ def loss_fn( def transport( self, - source: jax.Array, - condition: Optional[jax.Array], + source: jnp.ndarray, + condition: Optional[jnp.ndarray], rng: random.PRNGKeyArray = random.PRNGKey(0), forward: bool = True, diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}), - ) -> Union[jnp.array, diffrax.Solution, Optional[jax.Array]]: + ) -> Union[jnp.array, diffrax.Solution, Optional[jnp.ndarray]]: """Transport data with the learnt plan. 
This method pushes-forward the `source` to its conditional distribution by solving the neural ODE parameterized by the :attr:`~ott.neural.solvers.GENOTg.neural_vector_field` from @@ -390,7 +390,7 @@ def transport( axis=-1) t0, t1 = (0.0, 1.0) - def solve_ode(input: jax.Array, cond: jax.Array): + def solve_ode(input: jnp.ndarray, cond: jnp.ndarray): return diffrax.diffeqsolve( diffrax.ODETerm( lambda t, x, args: self.state_neural_vector_field. @@ -446,7 +446,7 @@ def training_logs(self) -> Dict[str, Any]: """Logs of the training.""" raise NotImplementedError - def sample_noise(self, key: random.PRNGKey, batch_size: int) -> jax.Array: + def sample_noise(self, key: random.PRNGKey, batch_size: int) -> jnp.ndarray: """Sample noise from a standard-normal distribution. Args: diff --git a/src/ott/neural/solvers/losses.py b/src/ott/neural/solvers/losses.py index fbf091b22..bec0f3916 100644 --- a/src/ott/neural/solvers/losses.py +++ b/src/ott/neural/solvers/losses.py @@ -25,8 +25,8 @@ def monge_gap( - map_fn: Callable[[jax.Array], jax.Array], - reference_points: jax.Array, + map_fn: Callable[[jnp.ndarray], jnp.ndarray], + reference_points: jnp.ndarray, cost_fn: Optional[costs.CostFn] = None, epsilon: Optional[float] = None, relative_epsilon: Optional[bool] = None, @@ -91,8 +91,8 @@ def monge_gap( def monge_gap_from_samples( - source: jax.Array, - target: jax.Array, + source: jnp.ndarray, + target: jnp.ndarray, cost_fn: Optional[costs.CostFn] = None, epsilon: Optional[float] = None, relative_epsilon: Optional[bool] = None, diff --git a/src/ott/neural/solvers/map_estimator.py b/src/ott/neural/solvers/map_estimator.py index 53bcdc7dd..65edb1d60 100644 --- a/src/ott/neural/solvers/map_estimator.py +++ b/src/ott/neural/solvers/map_estimator.py @@ -79,15 +79,15 @@ def __init__( dim_data: int, model: neuraldual.BaseW2NeuralDual, optimizer: Optional[optax.OptState] = None, - fitting_loss: Optional[Callable[[jax.Array, jax.Array], + fitting_loss: Optional[Callable[[jnp.ndarray, 
jnp.ndarray], Tuple[float, Optional[Any]]]] = None, - regularizer: Optional[Callable[[jax.Array, jax.Array], + regularizer: Optional[Callable[[jnp.ndarray, jnp.ndarray], Tuple[float, Optional[Any]]]] = None, regularizer_strength: Union[float, Sequence[float]] = 1., num_train_iters: int = 10_000, logging: bool = False, valid_freq: int = 500, - rng: Optional[jax.Array] = None, + rng: Optional[jnp.ndarray] = None, ): self._fitting_loss = fitting_loss self._regularizer = regularizer @@ -126,7 +126,7 @@ def setup( self.step_fn = self._get_step_fn() @property - def regularizer(self) -> Callable[[jax.Array, jax.Array], float]: + def regularizer(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]: """Regularizer added to the fitting loss. Can be e.g. the :func:`~ott.solvers.nn.losses.monge_gap_from_samples`. @@ -139,7 +139,7 @@ def regularizer(self) -> Callable[[jax.Array, jax.Array], float]: return lambda *args, **kwargs: (0., None) @property - def fitting_loss(self) -> Callable[[jax.Array, jax.Array], float]: + def fitting_loss(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]: """Fitting loss to fit the marginal constraint. Can be for instance the @@ -153,9 +153,9 @@ def fitting_loss(self) -> Callable[[jax.Array, jax.Array], float]: @staticmethod def _generate_batch( - loader_source: Iterator[jax.Array], - loader_target: Iterator[jax.Array], - ) -> Dict[str, jax.Array]: + loader_source: Iterator[jnp.ndarray], + loader_target: Iterator[jnp.ndarray], + ) -> Dict[str, jnp.ndarray]: """Generate batches a batch of samples. 
``loader_source`` and ``loader_target`` can be training or @@ -168,10 +168,10 @@ def _generate_batch( def train_map_estimator( self, - trainloader_source: Iterator[jax.Array], - trainloader_target: Iterator[jax.Array], - validloader_source: Iterator[jax.Array], - validloader_target: Iterator[jax.Array], + trainloader_source: Iterator[jnp.ndarray], + trainloader_target: Iterator[jnp.ndarray], + validloader_source: Iterator[jnp.ndarray], + validloader_target: Iterator[jnp.ndarray], ) -> Tuple[train_state.TrainState, Dict[str, Any]]: """Training loop.""" # define logs @@ -230,7 +230,7 @@ def _get_step_fn(self) -> Callable: def loss_fn( params: frozen_dict.FrozenDict, apply_fn: Callable, - batch: Dict[str, jax.Array], step: int + batch: Dict[str, jnp.ndarray], step: int ) -> Tuple[float, Dict[str, float]]: """Loss function.""" # map samples with the fitted map @@ -261,8 +261,8 @@ def loss_fn( @functools.partial(jax.jit, static_argnums=3) def step_fn( state_neural_net: train_state.TrainState, - train_batch: Dict[str, jax.Array], - valid_batch: Optional[Dict[str, jax.Array]] = None, + train_batch: Dict[str, jnp.ndarray], + valid_batch: Optional[Dict[str, jnp.ndarray]] = None, is_logging_step: bool = False, step: int = 0 ) -> Tuple[train_state.TrainState, Dict[str, float]]: diff --git a/src/ott/neural/solvers/neuraldual.py b/src/ott/neural/solvers/neuraldual.py index 7d4d5800f..23c63fa3f 100644 --- a/src/ott/neural/solvers/neuraldual.py +++ b/src/ott/neural/solvers/neuraldual.py @@ -44,8 +44,8 @@ Callback_t = Callable[[int, potentials.DualPotentials], None] Conj_t = Optional[conjugate_solvers.FenchelConjugateSolver] -PotentialValueFn_t = Callable[[jax.Array], jax.Array] -PotentialGradientFn_t = Callable[[jax.Array], jax.Array] +PotentialValueFn_t = Callable[[jnp.ndarray], jnp.ndarray] +PotentialGradientFn_t = Callable[[jnp.ndarray], jnp.ndarray] class W2NeuralTrainState(train_state.TrainState): @@ -60,9 +60,9 @@ class W2NeuralTrainState(train_state.TrainState): 
potential_gradient_fn: the potential's gradient function """ potential_value_fn: Callable[ - [frozen_dict.FrozenDict[str, jax.Array], Optional[PotentialValueFn_t]], + [frozen_dict.FrozenDict[str, jnp.ndarray], Optional[PotentialValueFn_t]], PotentialValueFn_t] = struct.field(pytree_node=False) - potential_gradient_fn: Callable[[frozen_dict.FrozenDict[str, jax.Array]], + potential_gradient_fn: Callable[[frozen_dict.FrozenDict[str, jnp.ndarray]], PotentialGradientFn_t] = struct.field( pytree_node=False ) @@ -87,7 +87,7 @@ def is_potential(self) -> bool: def potential_value_fn( self, - params: frozen_dict.FrozenDict[str, jax.Array], + params: frozen_dict.FrozenDict[str, jnp.ndarray], other_potential_value_fn: Optional[PotentialValueFn_t] = None, ) -> PotentialValueFn_t: r"""Return a function giving the value of the potential. @@ -119,7 +119,7 @@ def potential_value_fn( "The value of the gradient-based potential depends " \ "on the value of the other potential." - def value_fn(x: jax.Array) -> jax.Array: + def value_fn(x: jnp.ndarray) -> jnp.ndarray: squeeze = x.ndim == 1 if squeeze: x = jnp.expand_dims(x, 0) @@ -132,7 +132,7 @@ def value_fn(x: jax.Array) -> jax.Array: def potential_gradient_fn( self, - params: frozen_dict.FrozenDict[str, jax.Array], + params: frozen_dict.FrozenDict[str, jnp.ndarray], ) -> PotentialGradientFn_t: """Return a function returning a vector or the gradient of the potential. 
@@ -148,7 +148,7 @@ def potential_gradient_fn( def create_train_state( self, - rng: jax.Array, + rng: jnp.ndarray, optimizer: optax.OptState, input: Union[int, Tuple[int, ...]], **kwargs: Any, @@ -243,7 +243,7 @@ def __init__( valid_freq: int = 1000, log_freq: int = 1000, logging: bool = False, - rng: Optional[jax.Array] = None, + rng: Optional[jnp.ndarray] = None, pos_weights: bool = True, beta: float = 1.0, conjugate_solver: Conj_t = conjugate_solvers.DEFAULT_CONJUGATE_SOLVER, @@ -288,7 +288,7 @@ def __init__( def setup( self, - rng: jax.Array, + rng: jnp.ndarray, neural_f: BaseW2NeuralDual, neural_g: BaseW2NeuralDual, dim_data: int, @@ -358,10 +358,10 @@ def setup( def __call__( # noqa: D102 self, - trainloader_source: Iterator[jax.Array], - trainloader_target: Iterator[jax.Array], - validloader_source: Iterator[jax.Array], - validloader_target: Iterator[jax.Array], + trainloader_source: Iterator[jnp.ndarray], + trainloader_target: Iterator[jnp.ndarray], + validloader_source: Iterator[jnp.ndarray], + validloader_target: Iterator[jnp.ndarray], callback: Optional[Callback_t] = None, ) -> Union[potentials.DualPotentials, Tuple[potentials.DualPotentials, Train_t]]: @@ -378,10 +378,10 @@ def __call__( # noqa: D102 def train_neuraldual_parallel( self, - trainloader_source: Iterator[jax.Array], - trainloader_target: Iterator[jax.Array], - validloader_source: Iterator[jax.Array], - validloader_target: Iterator[jax.Array], + trainloader_source: Iterator[jnp.ndarray], + trainloader_target: Iterator[jnp.ndarray], + validloader_source: Iterator[jnp.ndarray], + validloader_target: Iterator[jnp.ndarray], callback: Optional[Callback_t] = None, ) -> Train_t: """Training and validation with parallel updates.""" @@ -453,10 +453,10 @@ def train_neuraldual_parallel( def train_neuraldual_alternating( self, - trainloader_source: Iterator[jax.Array], - trainloader_target: Iterator[jax.Array], - validloader_source: Iterator[jax.Array], - validloader_target: Iterator[jax.Array], + 
trainloader_source: Iterator[jnp.ndarray], + trainloader_target: Iterator[jnp.ndarray], + validloader_source: Iterator[jnp.ndarray], + validloader_target: Iterator[jnp.ndarray], callback: Optional[Callback_t] = None, ) -> Train_t: """Training and validation with alternating updates.""" @@ -533,7 +533,7 @@ def loss_fn(params_f, params_g, f_value, g_value, g_gradient, batch): init_source_hat = g_gradient(params_g)(target) - def g_value_partial(y: jax.Array) -> jax.Array: + def g_value_partial(y: jnp.ndarray) -> jnp.ndarray: """Lazy way of evaluating g if f's computation needs it.""" return g_value(params_g)(y) @@ -661,7 +661,7 @@ def to_dual_potentials( self.state_g.params, f_value ) - def g_value_finetuned(y: jax.Array) -> jax.Array: + def g_value_finetuned(y: jnp.ndarray) -> jnp.ndarray: x_hat = jax.grad(g_value_prediction)(y) grad_g_y = jax.lax.stop_gradient( self.conjugate_solver.solve(f_value, y, x_init=x_hat).grad @@ -686,7 +686,7 @@ def _clip_weights_icnn(params): return core.freeze(params) @staticmethod - def _penalize_weights_icnn(params: Dict[str, jax.Array]) -> float: + def _penalize_weights_icnn(params: Dict[str, jnp.ndarray]) -> float: penalty = 0.0 for k, param in params.items(): if k.startswith("w_z"): @@ -696,9 +696,9 @@ def _penalize_weights_icnn(params: Dict[str, jax.Array]) -> float: @staticmethod def _update_logs( logs: Dict[str, List[Union[float, str]]], - loss_f: jax.Array, - loss_g: jax.Array, - w_dist: jax.Array, + loss_f: jnp.ndarray, + loss_g: jnp.ndarray, + w_dist: jnp.ndarray, ) -> None: logs["loss_f"].append(float(loss_f)) logs["loss_g"].append(float(loss_g)) diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index 82a6b67aa..2afc94e6d 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from collections import defaultdict import functools import types +from collections import defaultdict from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type import diffrax @@ -36,7 +36,6 @@ BaseTimeSampler, ) from ott.solvers import was_solver -from ott.tools.sinkhorn_divergence import sinkhorn_divergence __all__ = ["OTFlowMatching"] @@ -88,10 +87,10 @@ def __init__( cost_fn: Type[costs.CostFn] = costs.SqEuclidean(), tau_a: float = 1.0, tau_b: float = 1.0, - mlp_eta: Callable[[jax.Array], float] = None, - mlp_xi: Callable[[jax.Array], float] = None, + mlp_eta: Callable[[jnp.ndarray], float] = None, + mlp_xi: Callable[[jnp.ndarray], float] = None, unbalanced_kwargs: Dict[str, Any] = {}, - callback_fn: Optional[Callable[[jax.Array, jax.Array, jax.Array], + callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any]] = None, logging_freq: int = 100, valid_freq: int = 5000, @@ -158,20 +157,25 @@ def _get_step_fn(self) -> Callable: def step_fn( key: random.PRNGKeyArray, state_neural_vector_field: train_state.TrainState, - batch: Dict[str, jax.Array], + batch: Dict[str, jnp.ndarray], ) -> Tuple[Any, Any]: def loss_fn( - params: jax.Array, t: jax.Array, noise: jax.Array, - batch: Dict[str, jax.Array], keys_model: random.PRNGKeyArray - ) -> jax.Array: + params: jnp.ndarray, t: jnp.ndarray, noise: jnp.ndarray, + batch: Dict[str, jnp.ndarray], keys_model: random.PRNGKeyArray + ) -> jnp.ndarray: - x_t = self.flow.compute_xt(noise, t, batch["source_lin"], batch["target_lin"]) + x_t = self.flow.compute_xt( + noise, t, batch["source_lin"], batch["target_lin"] + ) apply_fn = functools.partial( state_neural_vector_field.apply_fn, {"params": params} ) v_t = jax.vmap(apply_fn)( - t=t, x=x_t, condition=batch["source_conditions"], keys_model=keys_model + t=t, + x=x_t, + condition=batch["source_conditions"], + keys_model=keys_model ) u_t = self.flow.compute_ut(t, batch["source_lin"], batch["target_lin"]) return jnp.mean((v_t - u_t) ** 2) @@ -199,19 
+203,21 @@ def __call__(self, train_loader, valid_loader) -> None: Returns: None """ - batch: Mapping[str, jax.Array] = {} + batch: Mapping[str, jnp.ndarray] = {} curr_loss = 0.0 - + for iter in range(self.iterations): rng_resample, rng_step_fn, self.rng = random.split(self.rng, 3) batch = next(train_loader) if self.ot_solver is not None: tmat = self.match_fn(batch["source_lin"], batch["target_lin"]) - (batch["source_lin"], - batch["source_conditions"]), (batch["target_lin"], batch["target_conditions"]) = self._resample_data( - rng_resample, tmat, (batch["source_lin"], batch["source_conditions"]), - (batch["target_lin"], batch["target_conditions"]) - ) + (batch["source_lin"], batch["source_conditions"] + ), (batch["target_lin"], + batch["target_conditions"]) = self._resample_data( + rng_resample, tmat, + (batch["source_lin"], batch["source_conditions"]), + (batch["target_lin"], batch["target_conditions"]) + ) self.state_neural_vector_field, loss = self.step_fn( rng_step_fn, self.state_neural_vector_field, batch ) @@ -244,7 +250,7 @@ def __call__(self, train_loader, valid_loader) -> None: def transport( self, data: jnp.array, - condition: Optional[jax.Array], + condition: Optional[jnp.ndarray], forward: bool = True, diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}) ) -> diffrax.Solution: @@ -270,7 +276,7 @@ def transport( ) if forward else (self.time_sampler.high, self.time_sampler.low) @jax.jit - def solve_ode(input: jax.Array, cond: jax.Array): + def solve_ode(input: jnp.ndarray, cond: jnp.ndarray): return diffrax.diffeqsolve( diffrax.ODETerm( lambda t, x, args: self.state_neural_vector_field. 
@@ -296,7 +302,6 @@ def solve_ode(input: jax.Array, cond: jax.Array): def _valid_step(self, valid_loader, iter) -> None: next(valid_loader) # TODO: add callback and logging - @property def learn_rescaling(self) -> bool: @@ -327,7 +332,7 @@ def training_logs(self) -> Dict[str, Any]: """Logs of the training.""" raise NotImplementedError - def sample_noise(self, key: random.PRNGKey, batch_size: int) -> jax.Array: + def sample_noise(self, key: random.PRNGKey, batch_size: int) -> jnp.ndarray: """Sample noise from a standard-normal distribution. Args: diff --git a/src/ott/problems/linear/barycenter_problem.py b/src/ott/problems/linear/barycenter_problem.py index c94cc578d..ca5333a8e 100644 --- a/src/ott/problems/linear/barycenter_problem.py +++ b/src/ott/problems/linear/barycenter_problem.py @@ -50,9 +50,9 @@ class FreeBarycenterProblem: def __init__( self, - y: jax.Array, - b: Optional[jax.Array] = None, - weights: Optional[jax.Array] = None, + y: jnp.ndarray, + b: Optional[jnp.ndarray] = None, + weights: Optional[jnp.ndarray] = None, cost_fn: Optional[costs.CostFn] = None, epsilon: Optional[float] = None, **kwargs: Any, @@ -76,7 +76,7 @@ def __init__( assert self._b is None or self._y.shape[0] == self._b.shape[0] @property - def segmented_y_b(self) -> Tuple[jax.Array, jax.Array]: + def segmented_y_b(self) -> Tuple[jnp.ndarray, jnp.ndarray]: """Tuple of arrays containing the segmented measures and weights. - Segmented measures of shape ``[num_measures, max_measure_size, ndim]``. 
@@ -94,14 +94,14 @@ def segmented_y_b(self) -> Tuple[jax.Array, jax.Array]: return y, b @property - def flattened_y(self) -> jax.Array: + def flattened_y(self) -> jnp.ndarray: """Array of shape ``[num_measures * (N_1 + N_2 + ...), ndim]``.""" if self._is_segmented: return self._y.reshape((-1, self._y.shape[-1])) return self._y @property - def flattened_b(self) -> Optional[jax.Array]: + def flattened_b(self) -> Optional[jnp.ndarray]: """Array of shape ``[num_measures * (N_1 + N_2 + ...),]``.""" return None if self._b is None else self._b.ravel() @@ -121,7 +121,7 @@ def ndim(self) -> int: return self._y.shape[-1] @property - def weights(self) -> jax.Array: + def weights(self) -> jnp.ndarray: """Barycenter weights of shape ``[num_measures,]`` that sum to 1.""" if self._weights is None: return jnp.ones((self.num_measures,)) / self.num_measures @@ -165,8 +165,8 @@ class FixedBarycenterProblem: def __init__( self, geom: geometry.Geometry, - a: jax.Array, - weights: Optional[jax.Array] = None, + a: jnp.ndarray, + weights: Optional[jnp.ndarray] = None, ): self.geom = geom self.a = a @@ -178,7 +178,7 @@ def num_measures(self) -> int: return self.a.shape[0] @property - def weights(self) -> jax.Array: + def weights(self) -> jnp.ndarray: """Barycenter weights of shape ``[num_measures,]`` that sum to :math`1`.""" if self._weights is None: return jnp.ones((self.num_measures,)) / self.num_measures diff --git a/src/ott/problems/linear/linear_problem.py b/src/ott/problems/linear/linear_problem.py index 3e09c0e59..7c206aa63 100644 --- a/src/ott/problems/linear/linear_problem.py +++ b/src/ott/problems/linear/linear_problem.py @@ -21,8 +21,9 @@ __all__ = ["LinearProblem"] # TODO(michalk8): move to typing.py when refactoring the types -MarginalFunc = Callable[[jax.Array, jax.Array], jax.Array] -TransportAppFunc = Callable[[jax.Array, jax.Array, jax.Array, int], jax.Array] +MarginalFunc = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] +TransportAppFunc = Callable[[jnp.ndarray, 
jnp.ndarray, jnp.ndarray, int], + jnp.ndarray] @jax.tree_util.register_pytree_node_class @@ -49,8 +50,8 @@ class LinearProblem: def __init__( self, geom: geometry.Geometry, - a: Optional[jax.Array] = None, - b: Optional[jax.Array] = None, + a: Optional[jnp.ndarray] = None, + b: Optional[jnp.ndarray] = None, tau_a: float = 1.0, tau_b: float = 1.0 ): @@ -61,13 +62,13 @@ def __init__( self.tau_b = tau_b @property - def a(self) -> jax.Array: + def a(self) -> jnp.ndarray: """First marginal.""" num_a = self.geom.shape[0] return jnp.ones((num_a,)) / num_a if self._a is None else self._a @property - def b(self) -> jax.Array: + def b(self) -> jnp.ndarray: """Second marginal.""" num_b = self.geom.shape[1] return jnp.ones((num_b,)) / num_b if self._b is None else self._b diff --git a/src/ott/problems/linear/potentials.py b/src/ott/problems/linear/potentials.py index 718aa22a1..7ab226072 100644 --- a/src/ott/problems/linear/potentials.py +++ b/src/ott/problems/linear/potentials.py @@ -37,7 +37,7 @@ mpl = plt = None __all__ = ["DualPotentials", "EntropicPotentials"] -Potential_t = Callable[[jax.Array], float] +Potential_t = Callable[[jnp.ndarray], float] @jtu.register_pytree_node_class @@ -72,7 +72,7 @@ def __init__( self.cost_fn = cost_fn self._corr = corr - def transport(self, vec: jax.Array, forward: bool = True) -> jax.Array: + def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray: r"""Transport ``vec`` according to Brenier formula :cite:`brenier:91`. Uses Theorem 1.17 from :cite:`santambrogio:15` to compute an OT map when @@ -105,7 +105,7 @@ def transport(self, vec: jax.Array, forward: bool = True) -> jax.Array: return vec - self._grad_h_inv(self._grad_f(vec)) return vec - self._grad_h_inv(self._grad_g(vec)) - def distance(self, src: jax.Array, tgt: jax.Array) -> float: + def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float: r"""Evaluate Wasserstein distance between samples using dual potentials. 
This uses direct estimation of potentials against measures when dual @@ -146,17 +146,17 @@ def g(self) -> Potential_t: return self._g @property - def _grad_f(self) -> Callable[[jax.Array], jax.Array]: + def _grad_f(self) -> Callable[[jnp.ndarray], jnp.ndarray]: """Vectorized gradient of the potential function :attr:`f`.""" return jax.vmap(jax.grad(self.f, argnums=0)) @property - def _grad_g(self) -> Callable[[jax.Array], jax.Array]: + def _grad_g(self) -> Callable[[jnp.ndarray], jnp.ndarray]: """Vectorized gradient of the potential function :attr:`g`.""" return jax.vmap(jax.grad(self.g, argnums=0)) @property - def _grad_h_inv(self) -> Callable[[jax.Array], jax.Array]: + def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]: from ott.geometry import costs assert isinstance(self.cost_fn, costs.TICost), ( @@ -181,9 +181,9 @@ def tree_unflatten( # noqa: D102 def plot_ot_map( self, - source: jax.Array, - target: jax.Array, - samples: Optional[jax.Array] = None, + source: jnp.ndarray, + target: jnp.ndarray, + samples: Optional[jnp.ndarray] = None, forward: bool = True, ax: Optional["plt.Axes"] = None, legend_kwargs: Optional[Dict[str, Any]] = None, @@ -348,11 +348,11 @@ class EntropicPotentials(DualPotentials): def __init__( self, - f_xy: jax.Array, - g_xy: jax.Array, + f_xy: jnp.ndarray, + g_xy: jnp.ndarray, prob: linear_problem.LinearProblem, - f_xx: Optional[jax.Array] = None, - g_yy: Optional[jax.Array] = None, + f_xx: Optional[jnp.ndarray] = None, + g_yy: Optional[jnp.ndarray] = None, ): # we pass directly the arrays and override the properties # since only the properties need to be callable @@ -373,11 +373,11 @@ def _potential_fn(self, *, kind: Literal["f", "g"]) -> Potential_t: from ott.geometry import pointcloud def callback( - x: jax.Array, + x: jnp.ndarray, *, - potential: jax.Array, - y: jax.Array, - weights: jax.Array, + potential: jnp.ndarray, + y: jnp.ndarray, + weights: jnp.ndarray, epsilon: float, ) -> float: x = jnp.atleast_2d(x) diff --git 
a/src/ott/problems/quadratic/gw_barycenter.py b/src/ott/problems/quadratic/gw_barycenter.py index 7170f1064..dfe562d98 100644 --- a/src/ott/problems/quadratic/gw_barycenter.py +++ b/src/ott/problems/quadratic/gw_barycenter.py @@ -60,11 +60,11 @@ class GWBarycenterProblem(barycenter_problem.FreeBarycenterProblem): def __init__( self, - y: Optional[jax.Array] = None, - b: Optional[jax.Array] = None, - weights: Optional[jax.Array] = None, - costs: Optional[jax.Array] = None, - y_fused: Optional[jax.Array] = None, + y: Optional[jnp.ndarray] = None, + b: Optional[jnp.ndarray] = None, + weights: Optional[jnp.ndarray] = None, + costs: Optional[jnp.ndarray] = None, + y_fused: Optional[jnp.ndarray] = None, fused_penalty: float = 1.0, gw_loss: Literal["sqeucl", "kl"] = "sqeucl", scale_cost: Union[int, float, Literal["mean", "max_cost"]] = 1.0, @@ -98,7 +98,9 @@ def __init__( # TODO(michalk8): in the future, consider checking the other 2 cases # using `segmented_y` and `segmented_y_fused`? - def update_barycenter(self, transports: jax.Array, a: jax.Array) -> jax.Array: + def update_barycenter( + self, transports: jnp.ndarray, a: jnp.ndarray + ) -> jnp.ndarray: """Update the barycenter cost matrix. Uses the eq. 14 and 15 of :cite:`peyre:16`. @@ -114,11 +116,11 @@ def update_barycenter(self, transports: jax.Array, a: jax.Array) -> jax.Array: @functools.partial(jax.vmap, in_axes=[0, 0, 0, None]) def project( - y: jax.Array, - b: jax.Array, - transport: jax.Array, + y: jnp.ndarray, + b: jnp.ndarray, + transport: jnp.ndarray, fn: Optional[quadratic_costs.Loss], - ) -> jax.Array: + ) -> jnp.ndarray: geom = self._create_y_geometry(y, mask=b > 0.) 
fn, lin = (None, True) if fn is None else (fn.func, fn.is_linear) @@ -144,8 +146,8 @@ def project( return jnp.exp(barycenter) return barycenter - def update_features(self, transports: jax.Array, - a: jax.Array) -> Optional[jax.Array]: + def update_features(self, transports: jnp.ndarray, + a: jnp.ndarray) -> Optional[jnp.ndarray]: """Update the barycenter features in the fused case :cite:`vayer:19`. Uses :cite:`cuturi:14` eq. 8, and is implemented only @@ -179,8 +181,8 @@ def update_features(self, transports: jax.Array, def _create_bary_geometry( self, - cost_matrix: jax.Array, - mask: Optional[jax.Array] = None + cost_matrix: jnp.ndarray, + mask: Optional[jnp.ndarray] = None ) -> geometry.Geometry: return geometry.Geometry( cost_matrix=cost_matrix, @@ -192,8 +194,8 @@ def _create_bary_geometry( def _create_y_geometry( self, - y: jax.Array, - mask: Optional[jax.Array] = None + y: jnp.ndarray, + mask: Optional[jnp.ndarray] = None ) -> geometry.Geometry: if self._y_as_costs: assert y.shape[0] == y.shape[1], y.shape @@ -215,10 +217,10 @@ def _create_y_geometry( def _create_fused_geometry( self, - x: jax.Array, - y: jax.Array, - src_mask: Optional[jax.Array] = None, - tgt_mask: Optional[jax.Array] = None + x: jnp.ndarray, + y: jnp.ndarray, + src_mask: Optional[jnp.ndarray] = None, + tgt_mask: Optional[jnp.ndarray] = None ) -> pointcloud.PointCloud: return pointcloud.PointCloud( x, @@ -233,9 +235,9 @@ def _create_fused_geometry( def _create_problem( self, state: "GWBarycenterState", # noqa: F821 - y: jax.Array, - b: jax.Array, - f: Optional[jax.Array] = None + y: jnp.ndarray, + b: jnp.ndarray, + f: Optional[jnp.ndarray] = None ) -> quadratic_problem.QuadraticProblem: # TODO(michalk8): in future, mask in the problem for convenience? bary_mask = state.a > 0. 
@@ -267,7 +269,7 @@ def is_fused(self) -> bool: return self._y_fused is not None @property - def segmented_y_fused(self) -> Optional[jax.Array]: + def segmented_y_fused(self) -> Optional[jnp.ndarray]: """Feature array of shape used in the fused case.""" if not self.is_fused or self._y_fused.ndim == 3: return self._y_fused diff --git a/src/ott/problems/quadratic/quadratic_costs.py b/src/ott/problems/quadratic/quadratic_costs.py index 060c3c537..70f2bf5ad 100644 --- a/src/ott/problems/quadratic/quadratic_costs.py +++ b/src/ott/problems/quadratic/quadratic_costs.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Callable, NamedTuple -import jax import jax.numpy as jnp import jax.scipy as jsp @@ -21,7 +20,7 @@ class Loss(NamedTuple): # noqa: D101 - func: Callable[[jax.Array], jax.Array] + func: Callable[[jnp.ndarray], jnp.ndarray] is_linear: bool diff --git a/src/ott/problems/quadratic/quadratic_problem.py b/src/ott/problems/quadratic/quadratic_problem.py index cf5b804a2..a17aaf9fb 100644 --- a/src/ott/problems/quadratic/quadratic_problem.py +++ b/src/ott/problems/quadratic/quadratic_problem.py @@ -91,8 +91,8 @@ def __init__( geom_xy: Optional[geometry.Geometry] = None, fused_penalty: float = 1.0, scale_cost: Optional[Union[bool, float, str]] = False, - a: Optional[jax.Array] = None, - b: Optional[jax.Array] = None, + a: Optional[jnp.ndarray] = None, + b: Optional[jnp.ndarray] = None, loss: Union[Literal["sqeucl", "kl"], quadratic_costs.GWLoss] = "sqeucl", tau_a: float = 1.0, tau_b: float = 1.0, @@ -125,8 +125,8 @@ def __init__( def marginal_dependent_cost( self, - marginal_1: jax.Array, - marginal_2: jax.Array, + marginal_1: jnp.ndarray, + marginal_2: jnp.ndarray, ) -> low_rank.LRCGeometry: r"""Initialize cost term that depends on the marginals of the transport. 
@@ -169,9 +169,9 @@ def marginal_dependent_cost( def cost_unbalanced_correction( self, - transport_matrix: jax.Array, - marginal_1: jax.Array, - marginal_2: jax.Array, + transport_matrix: jnp.ndarray, + marginal_1: jnp.ndarray, + marginal_2: jnp.ndarray, epsilon: epsilon_scheduler.Epsilon, ) -> float: r"""Calculate cost term from the quadratic divergence when unbalanced. @@ -193,10 +193,10 @@ def cost_unbalanced_correction( :math:`+ epsilon * \sum(KL(P|ab'))` Args: - transport_matrix: jax.Array[num_a, num_b], transport matrix. - marginal_1: jax.Array[num_a,], marginal of the transport matrix + transport_matrix: jnp.ndarray[num_a, num_b], transport matrix. + marginal_1: jnp.ndarray[num_a,], marginal of the transport matrix for samples from :attr:`geom_xx`. - marginal_2: jax.Array[num_b,], marginal of the transport matrix + marginal_2: jnp.ndarray[num_b,], marginal of the transport matrix for samples from :attr:`geom_yy`. epsilon: entropy regularizer. @@ -353,7 +353,7 @@ def update_lr_linearization( ) @property - def _fused_cost_matrix(self) -> Union[float, jax.Array]: + def _fused_cost_matrix(self) -> Union[float, jnp.ndarray]: if not self.is_fused: return 0.0 geom_xy = self.geom_xy @@ -382,7 +382,7 @@ def convertible(geom: geometry.Geometry) -> bool: def to_low_rank( self, - rng: Optional[jax.Array] = None, + rng: Optional[jnp.ndarray] = None, ) -> "QuadraticProblem": """Convert geometries to low-rank. 
@@ -442,13 +442,13 @@ def geom_xy(self) -> Optional[geometry.Geometry]: return self._geom_xy @property - def a(self) -> jax.Array: + def a(self) -> jnp.ndarray: """First marginal.""" num_a = self.geom_xx.shape[0] return jnp.ones((num_a,)) / num_a if self._a is None else self._a @property - def b(self) -> jax.Array: + def b(self) -> jnp.ndarray: """Second marginal.""" num_b = self.geom_yy.shape[0] return jnp.ones((num_b,)) / num_b if self._b is None else self._b @@ -510,7 +510,7 @@ def update_epsilon_unbalanced( # noqa: D103 def apply_cost( # noqa: D103 - geom: geometry.Geometry, arr: jax.Array, *, axis: int, + geom: geometry.Geometry, arr: jnp.ndarray, *, axis: int, fn: quadratic_costs.Loss -) -> jax.Array: +) -> jnp.ndarray: return geom.apply_cost(arr, axis=axis, fn=fn.func, is_linear=fn.is_linear) diff --git a/src/ott/solvers/linear/_solve.py b/src/ott/solvers/linear/_solve.py index fad5a4e7d..2bca6a825 100644 --- a/src/ott/solvers/linear/_solve.py +++ b/src/ott/solvers/linear/_solve.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Optional, Union -import jax +import jax.numpy as jnp from ott.geometry import geometry from ott.problems.linear import linear_problem @@ -24,8 +24,8 @@ def solve( geom: geometry.Geometry, - a: Optional[jax.Array] = None, - b: Optional[jax.Array] = None, + a: Optional[jnp.ndarray] = None, + b: Optional[jnp.ndarray] = None, tau_a: float = 1.0, tau_b: float = 1.0, rank: int = -1, diff --git a/src/ott/solvers/linear/acceleration.py b/src/ott/solvers/linear/acceleration.py index 7ce602194..4529e7f78 100644 --- a/src/ott/solvers/linear/acceleration.py +++ b/src/ott/solvers/linear/acceleration.py @@ -34,7 +34,7 @@ class AndersonAcceleration: refresh_every: int = 1 # Recompute interpolation periodically. ridge_identity: float = 1e-2 # Ridge used in the linear system. 
- def extrapolation(self, xs: jax.Array, fxs: jax.Array) -> jax.Array: + def extrapolation(self, xs: jnp.ndarray, fxs: jnp.ndarray) -> jnp.ndarray: """Compute Anderson extrapolation from past observations.""" # Remove -inf values to instantiate quadratic problem. All others # remain since they might be caused by a valid issue. @@ -161,10 +161,10 @@ def lehmann(self, state: "sinkhorn.SinkhornState") -> float: def __call__( # noqa: D102 self, weight: float, - value: jax.Array, - new_value: jax.Array, + value: jnp.ndarray, + new_value: jnp.ndarray, lse_mode: bool = True - ) -> jax.Array: + ) -> jnp.ndarray: if lse_mode: value = jnp.where(jnp.isfinite(value), value, 0.0) return (1.0 - weight) * value + weight * new_value diff --git a/src/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py index 0094c3a3c..b93c14032 100644 --- a/src/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -41,11 +41,11 @@ class FreeBarycenterState(NamedTuple): a: barycenter weights. 
""" - costs: Optional[jax.Array] = None - linear_convergence: Optional[jax.Array] = None - errors: Optional[jax.Array] = None - x: Optional[jax.Array] = None - a: Optional[jax.Array] = None + costs: Optional[jnp.ndarray] = None + linear_convergence: Optional[jnp.ndarray] = None + errors: Optional[jnp.ndarray] = None + x: Optional[jnp.ndarray] = None + a: Optional[jnp.ndarray] = None def set(self, **kwargs: Any) -> "FreeBarycenterState": """Return a copy of self, possibly with overwrites.""" @@ -70,7 +70,7 @@ def update( @functools.partial(jax.vmap, in_axes=[None, None, 0, 0]) def solve_linear_ot( - a: Optional[jax.Array], x: jax.Array, b: jax.Array, y: jax.Array + a: Optional[jnp.ndarray], x: jnp.ndarray, b: jnp.ndarray, y: jnp.ndarray ): out = linear_ot_solver( linear_problem.LinearProblem( @@ -129,8 +129,8 @@ def __call__( # noqa: D102 self, bar_prob: barycenter_problem.FreeBarycenterProblem, bar_size: int = 100, - x_init: Optional[jax.Array] = None, - rng: Optional[jax.Array] = None, + x_init: Optional[jnp.ndarray] = None, + rng: Optional[jnp.ndarray] = None, ) -> FreeBarycenterState: # TODO(michalk8): no reason for iterations to be outside this class rng = utils.default_prng_key(rng) @@ -140,8 +140,8 @@ def init_state( self, bar_prob: barycenter_problem.FreeBarycenterProblem, bar_size: int, - x_init: Optional[jax.Array] = None, - rng: Optional[jax.Array] = None, + x_init: Optional[jnp.ndarray] = None, + rng: Optional[jnp.ndarray] = None, ) -> FreeBarycenterState: """Initialize the state of the Wasserstein barycenter iterations. 
@@ -195,8 +195,8 @@ def output_from_state( # noqa: D102 def iterations( solver: FreeWassersteinBarycenter, bar_size: int, - bar_prob: barycenter_problem.FreeBarycenterProblem, x_init: jax.Array, - rng: jax.Array + bar_prob: barycenter_problem.FreeBarycenterProblem, x_init: jnp.ndarray, + rng: jnp.ndarray ) -> FreeBarycenterState: """Jittable Wasserstein barycenter outer loop.""" diff --git a/src/ott/solvers/linear/discrete_barycenter.py b/src/ott/solvers/linear/discrete_barycenter.py index 85adaa795..dcfdc1470 100644 --- a/src/ott/solvers/linear/discrete_barycenter.py +++ b/src/ott/solvers/linear/discrete_barycenter.py @@ -26,10 +26,10 @@ class SinkhornBarycenterOutput(NamedTuple): # noqa: D101 - f: jax.Array - g: jax.Array - histogram: jax.Array - errors: jax.Array + f: jnp.ndarray + g: jnp.ndarray + histogram: jnp.ndarray + errors: jnp.ndarray @jax.tree_util.register_pytree_node_class @@ -79,7 +79,7 @@ def __init__( def __call__( self, fixed_bp: barycenter_problem.FixedBarycenterProblem, - dual_initialization: Optional[jax.Array] = None, + dual_initialization: Optional[jnp.ndarray] = None, ) -> SinkhornBarycenterOutput: """Solve barycenter problem, possibly using clever initialization. 
@@ -128,10 +128,10 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102 @functools.partial(jax.jit, static_argnums=(5, 6, 7, 8, 9, 10, 11, 12)) def _discrete_barycenter( - geom: geometry.Geometry, a: jax.Array, weights: jax.Array, - dual_initialization: jax.Array, threshold: float, norm_error: Sequence[int], - inner_iterations: int, min_iterations: int, max_iterations: int, - lse_mode: bool, debiased: bool, num_a: int, num_b: int + geom: geometry.Geometry, a: jnp.ndarray, weights: jnp.ndarray, + dual_initialization: jnp.ndarray, threshold: float, + norm_error: Sequence[int], inner_iterations: int, min_iterations: int, + max_iterations: int, lse_mode: bool, debiased: bool, num_a: int, num_b: int ) -> SinkhornBarycenterOutput: """Jit'able function to compute discrete barycenters.""" if lse_mode: diff --git a/src/ott/solvers/linear/implicit_differentiation.py b/src/ott/solvers/linear/implicit_differentiation.py index c5e7cb0f3..fbf98ce81 100644 --- a/src/ott/solvers/linear/implicit_differentiation.py +++ b/src/ott/solvers/linear/implicit_differentiation.py @@ -23,8 +23,9 @@ if TYPE_CHECKING: from ott.problems.linear import linear_problem -LinOp_t = Callable[[jax.Array], jax.Array] -Solver_t = Callable[[LinOp_t, jax.Array, Optional[LinOp_t], bool], jax.Array] +LinOp_t = Callable[[jnp.ndarray], jnp.ndarray] +Solver_t = Callable[[LinOp_t, jnp.ndarray, Optional[LinOp_t], bool], + jnp.ndarray] __all__ = ["ImplicitDiff", "solve_jax_cg"] @@ -69,16 +70,16 @@ class ImplicitDiff: solver: Optional[Solver_t] = None solver_kwargs: Optional[Dict[str, Any]] = None symmetric: bool = False - precondition_fun: Optional[Callable[[jax.Array], jax.Array]] = None + precondition_fun: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None def solve( self, - gr: Tuple[jax.Array, jax.Array], + gr: Tuple[jnp.ndarray, jnp.ndarray], ot_prob: "linear_problem.LinearProblem", - f: jax.Array, - g: jax.Array, + f: jnp.ndarray, + g: jnp.ndarray, lse_mode: bool, - ) -> jax.Array: + ) -> 
jnp.ndarray: r"""Apply minus inverse of [hessian ``reg_ot_cost`` w.r.t. ``f``, ``g``]. This function is used to carry out implicit differentiation of ``sinkhorn`` @@ -223,7 +224,7 @@ def solve( return jnp.concatenate((-vjp_gr_f, -vjp_gr_g)) def first_order_conditions( - self, prob, f: jax.Array, g: jax.Array, lse_mode: bool + self, prob, f: jnp.ndarray, g: jnp.ndarray, lse_mode: bool ): r"""Compute vector of first order conditions for the reg-OT problem. @@ -237,12 +238,12 @@ def first_order_conditions( Args: prob: definition of the linear optimal transport problem. - f: jax.Array, first potential - g: jax.Array, second potential + f: jnp.ndarray, first potential + g: jnp.ndarray, second potential lse_mode: bool Returns: - a jax.Array of size (size of ``n + m``) quantifying deviation to + a jnp.ndarray of size (size of ``n + m``) quantifying deviation to optimality for variables ``f`` and ``g``. """ geom = prob.geom @@ -265,8 +266,8 @@ def first_order_conditions( return jnp.concatenate((result_a, result_b)) def gradient( - self, prob: "linear_problem.LinearProblem", f: jax.Array, g: jax.Array, - lse_mode: bool, gr: Tuple[jax.Array, jax.Array] + self, prob: "linear_problem.LinearProblem", f: jnp.ndarray, + g: jnp.ndarray, lse_mode: bool, gr: Tuple[jnp.ndarray, jnp.ndarray] ) -> "linear_problem.LinearProblem": """Apply VJP to recover gradient in reverse mode differentiation.""" # Applies first part of vjp to gr: inverse part of implicit function theorem @@ -286,13 +287,13 @@ def replace(self, **kwargs: Any) -> "ImplicitDiff": # noqa: D102 def solve_jax_cg( lin: LinOp_t, - b: jax.Array, + b: jnp.ndarray, lin_t: Optional[LinOp_t] = None, symmetric: bool = False, ridge_identity: float = 0.0, ridge_kernel: float = 0.0, **kwargs: Any -) -> jax.Array: +) -> jnp.ndarray: """Wrapper around JAX native linear solvers. 
Args: diff --git a/src/ott/solvers/linear/lineax_implicit.py b/src/ott/solvers/linear/lineax_implicit.py index ac3978462..79b9e7c95 100644 --- a/src/ott/solvers/linear/lineax_implicit.py +++ b/src/ott/solvers/linear/lineax_implicit.py @@ -46,14 +46,14 @@ def transpose(self): def solve_lineax( lin: Callable, - b: jax.Array, + b: jnp.ndarray, lin_t: Optional[Callable] = None, symmetric: bool = False, nonsym_solver: Optional[lx.AbstractLinearSolver] = None, ridge_identity: float = 0.0, ridge_kernel: float = 0.0, **kwargs: Any -) -> jax.Array: +) -> jnp.ndarray: """Wrapper around lineax solvers. Args: diff --git a/src/ott/solvers/linear/lr_utils.py b/src/ott/solvers/linear/lr_utils.py index 2eb4c32ed..8ade265c9 100644 --- a/src/ott/solvers/linear/lr_utils.py +++ b/src/ott/solvers/linear/lr_utils.py @@ -24,27 +24,27 @@ class State(NamedTuple): # noqa: D101 - v1: jax.Array - v2: jax.Array - u1: jax.Array - u2: jax.Array - g: jax.Array + v1: jnp.ndarray + v2: jnp.ndarray + u1: jnp.ndarray + u2: jnp.ndarray + g: jnp.ndarray err: float class Constants(NamedTuple): # noqa: D101 - a: jax.Array - b: jax.Array + a: jnp.ndarray + b: jnp.ndarray rho_a: float rho_b: float - supp_a: Optional[jax.Array] = None - supp_b: Optional[jax.Array] = None + supp_a: Optional[jnp.ndarray] = None + supp_b: Optional[jnp.ndarray] = None def unbalanced_dykstra_lse( - c_q: jax.Array, - c_r: jax.Array, - c_g: jax.Array, + c_q: jnp.ndarray, + c_r: jnp.ndarray, + c_g: jnp.ndarray, gamma: float, ot_prob: linear_problem.LinearProblem, translation_invariant: bool = True, @@ -52,7 +52,7 @@ def unbalanced_dykstra_lse( min_iter: int = 0, inner_iter: int = 10, max_iter: int = 10000 -) -> Tuple[jax.Array, jax.Array, jax.Array]: +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Dykstra's algorithm for the unbalanced :class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` in LSE mode. 
@@ -74,10 +74,10 @@ def unbalanced_dykstra_lse( """ # noqa: D205 def _softm( - v: jax.Array, - c: jax.Array, + v: jnp.ndarray, + c: jnp.ndarray, axis: int, - ) -> jax.Array: + ) -> jnp.ndarray: v = jnp.expand_dims(v, axis=1 - axis) return jsp.special.logsumexp(v + c, axis=axis) @@ -181,9 +181,9 @@ def body_fn( def unbalanced_dykstra_kernel( - k_q: jax.Array, - k_r: jax.Array, - k_g: jax.Array, + k_q: jnp.ndarray, + k_r: jnp.ndarray, + k_g: jnp.ndarray, gamma: float, ot_prob: linear_problem.LinearProblem, translation_invariant: bool = True, @@ -191,7 +191,7 @@ def unbalanced_dykstra_kernel( min_iter: int = 0, inner_iter: int = 10, max_iter: int = 10000 -) -> Tuple[jax.Array, jax.Array, jax.Array]: +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Dykstra's algorithm for the unbalanced :class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` in kernel mode. @@ -317,7 +317,7 @@ def body_fn( def compute_lambdas( - const: Constants, state: State, gamma: float, g: jax.Array, *, + const: Constants, state: State, gamma: float, g: jnp.ndarray, *, lse_mode: bool ) -> Tuple[float, float]: """TODO.""" diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index 44afe1833..058c2905b 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -52,11 +52,11 @@ class SinkhornState(NamedTuple): """Holds the state variables used to solve OT with Sinkhorn.""" - errors: Optional[jax.Array] = None - fu: Optional[jax.Array] = None - gv: Optional[jax.Array] = None - old_fus: Optional[jax.Array] = None - old_mapped_fus: Optional[jax.Array] = None + errors: Optional[jnp.ndarray] = None + fu: Optional[jnp.ndarray] = None + gv: Optional[jnp.ndarray] = None + old_fus: Optional[jnp.ndarray] = None + old_mapped_fus: Optional[jnp.ndarray] = None def set(self, **kwargs: Any) -> "SinkhornState": """Return a copy of self, with potential overwrites.""" @@ -70,7 +70,7 @@ def solution_error( lse_mode: bool, parallel_dual_updates: bool, 
recenter: bool, - ) -> jax.Array: + ) -> jnp.ndarray: """State dependent function to return error.""" fu, gv = self.fu, self.gv if recenter and lse_mode: @@ -92,10 +92,10 @@ def compute_kl_reg_cost( # noqa: D102 def recenter( self, - f: jax.Array, - g: jax.Array, + f: jnp.ndarray, + g: jnp.ndarray, ot_prob: linear_problem.LinearProblem, - ) -> Tuple[jax.Array, jax.Array]: + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Re-center dual potentials. If the ``ot_prob`` is balanced, the ``f`` potential is zero-centered. @@ -132,14 +132,14 @@ def recenter( def solution_error( - f_u: jax.Array, - g_v: jax.Array, + f_u: jnp.ndarray, + g_v: jnp.ndarray, ot_prob: linear_problem.LinearProblem, *, norm_error: Sequence[int], lse_mode: bool, parallel_dual_updates: bool, -) -> jax.Array: +) -> jnp.ndarray: """Given two potential/scaling solutions, computes deviation to optimality. When the ``ot_prob`` problem is balanced and the usual Sinkhorn updates are @@ -153,8 +153,8 @@ def solution_error( additional quantities to qualify optimality must be taken into account. Args: - f_u: jax.Array, potential or scaling - g_v: jax.Array, potential or scaling + f_u: jnp.ndarray, potential or scaling + g_v: jnp.ndarray, potential or scaling ot_prob: linear OT problem norm_error: int, p-norm used to compute error. lse_mode: True if log-sum-exp operations, False if kernel vector products. @@ -196,9 +196,9 @@ def solution_error( def marginal_error( - f_u: jax.Array, - g_v: jax.Array, - target: jax.Array, + f_u: jnp.ndarray, + g_v: jnp.ndarray, + target: jnp.ndarray, geom: geometry.Geometry, axis: int = 0, norm_error: Sequence[int] = (1,), @@ -229,7 +229,7 @@ def marginal_error( def compute_kl_reg_cost( - f: jax.Array, g: jax.Array, ot_prob: linear_problem.LinearProblem, + f: jnp.ndarray, g: jnp.ndarray, ot_prob: linear_problem.LinearProblem, lse_mode: bool ) -> float: r"""Compute objective of Sinkhorn for OT problem given dual solutions. 
@@ -243,8 +243,8 @@ def compute_kl_reg_cost( values, ``jnp.where`` is used to cancel these contributions. Args: - f: jax.Array, potential - g: jax.Array, potential + f: jnp.ndarray, potential + g: jnp.ndarray, potential ot_prob: linear optimal transport problem. lse_mode: bool, whether to compute total mass in lse or kernel mode. @@ -320,12 +320,12 @@ class SinkhornOutput(NamedTuple): computations of errors. """ - f: Optional[jax.Array] = None - g: Optional[jax.Array] = None - errors: Optional[jax.Array] = None + f: Optional[jnp.ndarray] = None + g: Optional[jnp.ndarray] = None + errors: Optional[jnp.ndarray] = None reg_ot_cost: Optional[float] = None ot_prob: Optional[linear_problem.LinearProblem] = None - threshold: Optional[jax.Array] = None + threshold: Optional[jnp.ndarray] = None converged: Optional[bool] = None inner_iterations: Optional[int] = None @@ -342,7 +342,7 @@ def set_cost( # noqa: D102 return self.set(reg_ot_cost=compute_kl_reg_cost(f, g, ot_prob, lse_mode)) @property - def dual_cost(self) -> jax.Array: + def dual_cost(self) -> jnp.ndarray: """Return dual transport cost, without considering regularizer.""" a, b = self.ot_prob.a, self.ot_prob.b dual_cost = jnp.sum(jnp.where(a > 0.0, a * self.f, 0)) @@ -399,7 +399,9 @@ def kl_reg_cost(self) -> float: """ return self.reg_ot_cost - def transport_cost_at_geom(self, other_geom: geometry.Geometry) -> jax.Array: + def transport_cost_at_geom( + self, other_geom: geometry.Geometry + ) -> jnp.ndarray: r"""Return bare transport cost of current solution at any geometry. 
In order to compute cost, we check first if the geometry can be converted @@ -426,11 +428,11 @@ def geom(self) -> geometry.Geometry: # noqa: D102 return self.ot_prob.geom @property - def a(self) -> jax.Array: # noqa: D102 + def a(self) -> jnp.ndarray: # noqa: D102 return self.ot_prob.a @property - def b(self) -> jax.Array: # noqa: D102 + def b(self) -> jnp.ndarray: # noqa: D102 return self.ot_prob.b @property @@ -439,13 +441,13 @@ def n_iters(self) -> int: # noqa: D102 return jnp.sum(self.errors != -1) * self.inner_iterations @property - def scalings(self) -> Tuple[jax.Array, jax.Array]: # noqa: D102 + def scalings(self) -> Tuple[jnp.ndarray, jnp.ndarray]: # noqa: D102 u = self.ot_prob.geom.scaling_from_potential(self.f) v = self.ot_prob.geom.scaling_from_potential(self.g) return u, v @property - def matrix(self) -> jax.Array: + def matrix(self) -> jnp.ndarray: """Transport matrix if it can be instantiated.""" try: return self.ot_prob.geom.transport_from_potentials(self.f, self.g) @@ -457,13 +459,13 @@ def transport_mass(self) -> float: """Sum of transport matrix.""" return self.marginal(0).sum() - def apply(self, inputs: jax.Array, axis: int = 0) -> jax.Array: + def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray: """Apply the transport to a ndarray; axis=1 for its transpose.""" return self.ot_prob.geom.apply_transport_from_potentials( self.f, self.g, inputs, axis=axis ) - def marginal(self, axis: int) -> jax.Array: # noqa: D102 + def marginal(self, axis: int) -> jnp.ndarray: # noqa: D102 return self.ot_prob.geom.marginal_from_potentials(self.f, self.g, axis=axis) def cost_at_geom(self, other_geom: geometry.Geometry) -> float: @@ -830,8 +832,8 @@ def __init__( def __call__( self, ot_prob: linear_problem.LinearProblem, - init: Tuple[Optional[jax.Array], Optional[jax.Array]] = (None, None), - rng: Optional[jax.Array] = None, + init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]] = (None, None), + rng: Optional[jnp.ndarray] = None, ) -> 
SinkhornOutput: """Run Sinkhorn algorithm. @@ -866,7 +868,9 @@ def xi(tau_i: float, tau_j: float) -> float: k_ij = k(tau_i, tau_j) return k_ij / (1. - k_ij) - def smin(potential: jax.Array, marginal: jax.Array, tau: float) -> float: + def smin( + potential: jnp.ndarray, marginal: jnp.ndarray, tau: float + ) -> float: rho = uf.rho(ot_prob.epsilon, tau) return -rho * mu.logsumexp(-potential / rho, b=marginal) @@ -1011,8 +1015,8 @@ def outer_iterations(self) -> int: return np.ceil(self.max_iterations / self.inner_iterations).astype(int) def init_state( - self, ot_prob: linear_problem.LinearProblem, init: Tuple[jax.Array, - jax.Array] + self, ot_prob: linear_problem.LinearProblem, init: Tuple[jnp.ndarray, + jnp.ndarray] ) -> SinkhornState: """Return the initial state of the loop.""" fu, gv = init @@ -1120,7 +1124,7 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102 def run( ot_prob: linear_problem.LinearProblem, solver: Sinkhorn, - init: Tuple[jax.Array, ...] + init: Tuple[jnp.ndarray, ...] ) -> SinkhornOutput: """Run loop of the solver, outputting a state upgraded to an output.""" iter_fun = _iterations_implicit if solver.implicit_diff else iterations @@ -1133,7 +1137,7 @@ def run( def iterations( ot_prob: linear_problem.LinearProblem, solver: Sinkhorn, - init: Tuple[jax.Array, ...] + init: Tuple[jnp.ndarray, ...] ) -> SinkhornOutput: """Jittable Sinkhorn loop. args contain initialization variables.""" @@ -1170,8 +1174,8 @@ def body_fn( def _iterations_taped( ot_prob: linear_problem.LinearProblem, solver: Sinkhorn, - init: Tuple[jax.Array, ...] -) -> Tuple[SinkhornOutput, Tuple[jax.Array, jax.Array, + init: Tuple[jnp.ndarray, ...] +) -> Tuple[SinkhornOutput, Tuple[jnp.ndarray, jnp.ndarray, linear_problem.LinearProblem, Sinkhorn]]: """Run forward pass of the Sinkhorn algorithm storing side information.""" state = iterations(ot_prob, solver, init) @@ -1190,7 +1194,7 @@ def _iterations_implicit_bwd(res, gr): considered. 
Returns: - a tuple of gradients: PyTree for geom, one jax.Array for each of a and b. + a tuple of gradients: PyTree for geom, one jnp.ndarray for each of a and b. """ f, g, ot_prob, solver = res gr = gr[:2] diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index b6732f76f..ba83aeb99 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -43,12 +43,12 @@ class LRSinkhornState(NamedTuple): """State of the Low Rank Sinkhorn algorithm.""" - q: jax.Array - r: jax.Array - g: jax.Array + q: jnp.ndarray + r: jnp.ndarray + g: jnp.ndarray gamma: float - costs: jax.Array - errors: jax.Array + costs: jnp.ndarray + errors: jnp.ndarray crossed_threshold: bool def compute_error( # noqa: D102 @@ -79,7 +79,7 @@ def reg_ot_cost( # noqa: D102 def solution_error( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, norm_error: Tuple[int, ...] - ) -> jax.Array: + ) -> jnp.ndarray: return solution_error(self.q, self.r, ot_prob, norm_error) def set(self, **kwargs: Any) -> "LRSinkhornState": @@ -88,9 +88,9 @@ def set(self, **kwargs: Any) -> "LRSinkhornState": def compute_reg_ot_cost( - q: jax.Array, - r: jax.Array, - g: jax.Array, + q: jnp.ndarray, + r: jnp.ndarray, + g: jnp.ndarray, ot_prob: linear_problem.LinearProblem, epsilon: float, use_danskin: bool = False @@ -110,7 +110,7 @@ def compute_reg_ot_cost( regularized OT cost, the (primal) transport cost of the low-rank solution. """ - def ent(x: jax.Array) -> float: + def ent(x: jnp.ndarray) -> float: # generalized entropy return jnp.sum(jsp.special.entr(x) + x) @@ -131,9 +131,9 @@ def ent(x: jax.Array) -> float: def solution_error( - q: jax.Array, r: jax.Array, ot_prob: linear_problem.LinearProblem, + q: jnp.ndarray, r: jnp.ndarray, ot_prob: linear_problem.LinearProblem, norm_error: Tuple[int, ...] -) -> jax.Array: +) -> jnp.ndarray: """Compute solution error. Since only balanced case is available for LR, this is marginal deviation. 
@@ -166,13 +166,13 @@ def solution_error( class LRSinkhornOutput(NamedTuple): """Transport interface for a low-rank Sinkhorn solution.""" - q: jax.Array - r: jax.Array - g: jax.Array - costs: jax.Array + q: jnp.ndarray + r: jnp.ndarray + g: jnp.ndarray + costs: jnp.ndarray # TODO(michalk8): must be called `errors`, because of `store_inner_errors` # in future, enforce via class hierarchy - errors: jax.Array + errors: jnp.ndarray ot_prob: linear_problem.LinearProblem epsilon: float inner_iterations: int @@ -211,11 +211,11 @@ def geom(self) -> geometry.Geometry: # noqa: D102 return self.ot_prob.geom @property - def a(self) -> jax.Array: # noqa: D102 + def a(self) -> jnp.ndarray: # noqa: D102 return self.ot_prob.a @property - def b(self) -> jax.Array: # noqa: D102 + def b(self) -> jnp.ndarray: # noqa: D102 return self.ot_prob.b @property @@ -229,17 +229,17 @@ def converged(self) -> bool: # noqa: D102 ) @property - def matrix(self) -> jax.Array: + def matrix(self) -> jnp.ndarray: """Transport matrix if it can be instantiated.""" return (self.q * self._inv_g) @ self.r.T - def apply(self, inputs: jax.Array, axis: int = 0) -> jax.Array: + def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray: """Apply the transport to a array; axis=1 for its transpose.""" q, r = (self.q, self.r) if axis == 1 else (self.r, self.q) # for `axis=0`: (batch, m), (m, r), (r,), (r, n) return ((inputs @ r) * self._inv_g) @ q.T - def marginal(self, axis: int) -> jax.Array: # noqa: D102 + def marginal(self, axis: int) -> jnp.ndarray: # noqa: D102 length = self.q.shape[0] if axis == 0 else self.r.shape[0] return self.apply(jnp.ones(length,), axis=axis) @@ -262,7 +262,7 @@ def transport_mass(self) -> float: return self.marginal(0).sum() @property - def _inv_g(self) -> jax.Array: + def _inv_g(self) -> jnp.ndarray: return 1. 
/ self.g @@ -341,9 +341,9 @@ def __init__( def __call__( self, ot_prob: linear_problem.LinearProblem, - init: Tuple[Optional[jax.Array], Optional[jax.Array], - Optional[jax.Array]] = (None, None, None), - rng: Optional[jax.Array] = None, + init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], + Optional[jnp.ndarray]] = (None, None, None), + rng: Optional[jnp.ndarray] = None, **kwargs: Any, ) -> LRSinkhornOutput: """Run low-rank Sinkhorn. @@ -371,7 +371,7 @@ def _get_costs( self, ot_prob: linear_problem.LinearProblem, state: LRSinkhornState, - ) -> Tuple[jax.Array, jax.Array, jax.Array, float]: + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, float]: log_q, log_r, log_g = ( mu.safe_log(state.q), mu.safe_log(state.r), mu.safe_log(state.g) ) @@ -407,9 +407,9 @@ def _get_costs( # TODO(michalk8): move to `lr_utils` when refactoring this def dykstra_update_lse( self, - c_q: jax.Array, - c_r: jax.Array, - h: jax.Array, + c_q: jnp.ndarray, + c_r: jnp.ndarray, + h: jnp.ndarray, gamma: float, ot_prob: linear_problem.LinearProblem, min_entry_value: float = 1e-6, @@ -417,7 +417,7 @@ def dykstra_update_lse( min_iter: int = 0, inner_iter: int = 10, max_iter: int = 10000 - ) -> Tuple[jax.Array, jax.Array, jax.Array]: + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Run Dykstra's algorithm.""" # shortcuts for problem's definition. r = self.rank @@ -435,24 +435,24 @@ def dykstra_update_lse( constants = c_q, c_r, loga, logb def cond_fn( - iteration: int, constants: Tuple[jax.Array, ...], - state_inner: Tuple[jax.Array, ...] + iteration: int, constants: Tuple[jnp.ndarray, ...], + state_inner: Tuple[jnp.ndarray, ...] 
) -> bool: del iteration, constants *_, err = state_inner return err > tolerance def _softm( - f: jax.Array, g: jax.Array, c: jax.Array, axis: int - ) -> jax.Array: + f: jnp.ndarray, g: jnp.ndarray, c: jnp.ndarray, axis: int + ) -> jnp.ndarray: return jsp.special.logsumexp( gamma * (f[:, None] + g[None, :] - c), axis=axis ) def body_fn( - iteration: int, constants: Tuple[jax.Array, ...], - state_inner: Tuple[jax.Array, ...], compute_error: bool - ) -> Tuple[jax.Array, ...]: + iteration: int, constants: Tuple[jnp.ndarray, ...], + state_inner: Tuple[jnp.ndarray, ...], compute_error: bool + ) -> Tuple[jnp.ndarray, ...]: # TODO(michalk8): in the future, use `NamedTuple` f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err = state_inner c_q, c_r, loga, logb = constants @@ -501,15 +501,15 @@ def body_fn( return f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err def recompute_couplings( - f1: jax.Array, - g1: jax.Array, - c_q: jax.Array, - f2: jax.Array, - g2: jax.Array, - c_r: jax.Array, - h: jax.Array, + f1: jnp.ndarray, + g1: jnp.ndarray, + c_q: jnp.ndarray, + f2: jnp.ndarray, + g2: jnp.ndarray, + c_r: jnp.ndarray, + h: jnp.ndarray, gamma: float, - ) -> Tuple[jax.Array, jax.Array, jax.Array]: + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: q = jnp.exp(gamma * (f1[:, None] + g1[None, :] - c_q)) r = jnp.exp(gamma * (f2[:, None] + g2[None, :] - c_r)) g = jnp.exp(gamma * h) @@ -524,9 +524,9 @@ def recompute_couplings( def dykstra_update_kernel( self, - k_q: jax.Array, - k_r: jax.Array, - k_g: jax.Array, + k_q: jnp.ndarray, + k_r: jnp.ndarray, + k_g: jnp.ndarray, gamma: float, ot_prob: linear_problem.LinearProblem, min_entry_value: float = 1e-6, @@ -534,7 +534,7 @@ def dykstra_update_kernel( min_iter: int = 0, inner_iter: int = 10, max_iter: int = 10000 - ) -> Tuple[jax.Array, jax.Array, jax.Array]: + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Run Dykstra's algorithm.""" # shortcuts for problem's definition. 
rank = self.rank @@ -553,17 +553,17 @@ def dykstra_update_kernel( constants = k_q, k_r, k_g, a, b def cond_fn( - iteration: int, constants: Tuple[jax.Array, ...], - state_inner: Tuple[jax.Array, ...] + iteration: int, constants: Tuple[jnp.ndarray, ...], + state_inner: Tuple[jnp.ndarray, ...] ) -> bool: del iteration, constants *_, err = state_inner return err > tolerance def body_fn( - iteration: int, constants: Tuple[jax.Array, ...], - state_inner: Tuple[jax.Array, ...], compute_error: bool - ) -> Tuple[jax.Array, ...]: + iteration: int, constants: Tuple[jnp.ndarray, ...], + state_inner: Tuple[jnp.ndarray, ...], compute_error: bool + ) -> Tuple[jnp.ndarray, ...]: # TODO(michalk8): in the future, use `NamedTuple` u1, u2, v1_old, v2_old, g_old, q_gi, q_gp, q_q, q_r, err = state_inner k_q, k_r, k_g, a, b = constants @@ -600,14 +600,14 @@ def body_fn( return u1, u2, v1_old, v2_old, g_old, q_gi, q_gp, q_q, q_r, err def recompute_couplings( - u1: jax.Array, - v1: jax.Array, - k_q: jax.Array, - u2: jax.Array, - v2: jax.Array, - k_r: jax.Array, - g: jax.Array, - ) -> Tuple[jax.Array, jax.Array, jax.Array]: + u1: jnp.ndarray, + v1: jnp.ndarray, + k_q: jnp.ndarray, + u2: jnp.ndarray, + v2: jnp.ndarray, + k_r: jnp.ndarray, + g: jnp.ndarray, + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: q = u1.reshape((-1, 1)) * k_q * v1.reshape((1, -1)) r = u2.reshape((-1, 1)) * k_r * v2.reshape((1, -1)) return q, r, g @@ -736,7 +736,7 @@ def create_initializer( def init_state( self, ot_prob: linear_problem.LinearProblem, - init: Tuple[jax.Array, jax.Array, jax.Array] + init: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] ) -> LRSinkhornState: """Return the initial state of the loop.""" q, r, g = init @@ -811,7 +811,8 @@ def _diverged(self, state: LRSinkhornState, iteration: int) -> bool: def run( ot_prob: linear_problem.LinearProblem, solver: LRSinkhorn, - init: Tuple[Optional[jax.Array], Optional[jax.Array], Optional[jax.Array]], + init: Tuple[Optional[jnp.ndarray], 
Optional[jnp.ndarray], + Optional[jnp.ndarray]], ) -> LRSinkhornOutput: """Run loop of the solver, outputting a state upgraded to an output.""" out = sinkhorn.iterations(ot_prob, solver, init) diff --git a/src/ott/solvers/linear/univariate.py b/src/ott/solvers/linear/univariate.py index 1f2a47b6f..2b6392227 100644 --- a/src/ott/solvers/linear/univariate.py +++ b/src/ott/solvers/linear/univariate.py @@ -53,7 +53,7 @@ class UnivariateSolver: def __init__( self, - sort_fn: Optional[Callable[[jax.Array], jax.Array]] = None, + sort_fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, cost_fn: Optional[costs.CostFn] = None, method: Literal["subsample", "quantile", "wasserstein", "equal"] = "subsample", @@ -66,10 +66,10 @@ def __init__( def __call__( self, - x: jax.Array, - y: jax.Array, - a: Optional[jax.Array] = None, - b: Optional[jax.Array] = None + x: jnp.ndarray, + y: jnp.ndarray, + a: Optional[jnp.ndarray] = None, + b: Optional[jnp.ndarray] = None ) -> float: """Computes the Univariate OT Distance between `x` and `y`. @@ -113,8 +113,8 @@ def __call__( return self.cost_fn.pairwise(xx, yy) * (n / xx.shape[0]) def _cdf_distance( - self, x: jax.Array, y: jax.Array, a: Optional[jax.Array], - b: Optional[jax.Array] + self, x: jnp.ndarray, y: jnp.ndarray, a: Optional[jnp.ndarray], + b: Optional[jnp.ndarray] ): # Implementation based on `scipy` implementation for # :func: diff --git a/src/ott/solvers/quadratic/_solve.py b/src/ott/solvers/quadratic/_solve.py index 986680637..9cdefec93 100644 --- a/src/ott/solvers/quadratic/_solve.py +++ b/src/ott/solvers/quadratic/_solve.py @@ -13,7 +13,7 @@ # limitations under the License. 
from typing import Any, Literal, Optional, Union -import jax +import jax.numpy as jnp from ott.geometry import geometry from ott.problems.quadratic import quadratic_costs, quadratic_problem @@ -28,8 +28,8 @@ def solve( geom_yy: geometry.Geometry, geom_xy: Optional[geometry.Geometry] = None, fused_penalty: float = 1.0, - a: Optional[jax.Array] = None, - b: Optional[jax.Array] = None, + a: Optional[jnp.ndarray] = None, + b: Optional[jnp.ndarray] = None, tau_a: float = 1.0, tau_b: float = 1.0, loss: Union[Literal["sqeucl", "kl"], quadratic_costs.GWLoss] = "sqeucl", diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 554cdaaed..862b91999 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -63,10 +63,10 @@ class GWOutput(NamedTuple): old_transport_mass: Holds total mass of transport at previous iteration. """ - costs: Optional[jax.Array] = None - linear_convergence: Optional[jax.Array] = None + costs: Optional[jnp.ndarray] = None + linear_convergence: Optional[jnp.ndarray] = None converged: bool = False - errors: Optional[jax.Array] = None + errors: Optional[jnp.ndarray] = None linear_state: Optional[LinearOutput] = None geom: Optional[geometry.Geometry] = None # Intermediate values. @@ -77,11 +77,11 @@ def set(self, **kwargs: Any) -> "GWOutput": return self._replace(**kwargs) @property - def matrix(self) -> jax.Array: + def matrix(self) -> jnp.ndarray: """Transport matrix.""" return self._rescale_factor * self.linear_state.matrix - def apply(self, inputs: jax.Array, axis: int = 0) -> jax.Array: + def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray: """Apply the transport to an array; axis=1 for its transpose.""" return self._rescale_factor * self.linear_state.apply(inputs, axis=axis) @@ -124,13 +124,13 @@ class GWState(NamedTuple): at each iteration. 
""" - costs: jax.Array - linear_convergence: jax.Array + costs: jnp.ndarray + linear_convergence: jnp.ndarray linear_state: LinearOutput linear_pb: linear_problem.LinearProblem old_transport_mass: float - rngs: Optional[jax.Array] = None - errors: Optional[jax.Array] = None + rngs: Optional[jnp.ndarray] = None + errors: Optional[jnp.ndarray] = None def set(self, **kwargs: Any) -> "GWState": """Return a copy of self, possibly with overwrites.""" @@ -213,7 +213,7 @@ def __call__( self, prob: quadratic_problem.QuadraticProblem, init: Optional[linear_problem.LinearProblem] = None, - rng: Optional[jax.Array] = None, + rng: Optional[jnp.ndarray] = None, **kwargs: Any, ) -> GWOutput: """Run the Gromov-Wasserstein solver. @@ -272,7 +272,7 @@ def init_state( self, prob: quadratic_problem.QuadraticProblem, init: linear_problem.LinearProblem, - rng: jax.Array, + rng: jnp.ndarray, ) -> GWState: """Initialize the state of the Gromov-Wasserstein iterations. @@ -361,7 +361,7 @@ def iterations( solver: GromovWasserstein, prob: quadratic_problem.QuadraticProblem, init: linear_problem.LinearProblem, - rng: jax.Array, + rng: jnp.ndarray, ) -> GWOutput: """Jittable Gromov-Wasserstein outer loop.""" diff --git a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py index 62a5592bc..710d8f617 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py @@ -46,12 +46,12 @@ class LRGWState(NamedTuple): """State of the low-rank GW algorithm.""" - q: jax.Array - r: jax.Array - g: jax.Array + q: jnp.ndarray + r: jnp.ndarray + g: jnp.ndarray gamma: float - costs: jax.Array - errors: jax.Array + costs: jnp.ndarray + errors: jnp.ndarray crossed_threshold: bool def compute_error( # noqa: D102 @@ -85,9 +85,9 @@ def set(self, **kwargs: Any) -> "LRGWState": def compute_reg_gw_cost( - q: jax.Array, - r: jax.Array, - g: jax.Array, + q: jnp.ndarray, + r: jnp.ndarray, + g: jnp.ndarray, ot_prob: 
quadratic_problem.QuadraticProblem, epsilon: float, use_danskin: bool = False @@ -107,7 +107,7 @@ def compute_reg_gw_cost( regularized OT cost, the (primal) transport cost of the low-rank solution. """ - def ent(x: jax.Array) -> float: + def ent(x: jnp.ndarray) -> float: # generalized entropy return jnp.sum(jsp.special.entr(x) + x) @@ -139,13 +139,13 @@ def ent(x: jax.Array) -> float: class LRGWOutput(NamedTuple): """Transport interface for a low-rank GW solution.""" - q: jax.Array - r: jax.Array - g: jax.Array - costs: jax.Array + q: jnp.ndarray + r: jnp.ndarray + g: jnp.ndarray + costs: jnp.ndarray # TODO(michalk8): must be called `errors`, because of `store_inner_errors` # in future, enforce via class hierarchy - errors: jax.Array + errors: jnp.ndarray ot_prob: quadratic_problem.QuadraticProblem epsilon: float inner_iterations: int @@ -184,11 +184,11 @@ def geom(self) -> geometry.Geometry: # noqa: D102 return _linearized_geometry(self.ot_prob, q=self.q, r=self.r, g=self.g) @property - def a(self) -> jax.Array: # noqa: D102 + def a(self) -> jnp.ndarray: # noqa: D102 return self.ot_prob.a @property - def b(self) -> jax.Array: # noqa: D102 + def b(self) -> jnp.ndarray: # noqa: D102 return self.ot_prob.b @property @@ -202,17 +202,17 @@ def converged(self) -> bool: # noqa: D102 ) @property - def matrix(self) -> jax.Array: + def matrix(self) -> jnp.ndarray: """Transport matrix if it can be instantiated.""" return (self.q * self._inv_g) @ self.r.T - def apply(self, inputs: jax.Array, axis: int = 0) -> jax.Array: + def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray: """Apply the transport to a array; axis=1 for its transpose.""" q, r = (self.q, self.r) if axis == 1 else (self.r, self.q) # for `axis=0`: (batch, m), (m, r), (r,), (r, n) return ((inputs @ r) * self._inv_g) @ q.T - def marginal(self, axis: int) -> jax.Array: # noqa: D102 + def marginal(self, axis: int) -> jnp.ndarray: # noqa: D102 length = self.q.shape[0] if axis == 0 else self.r.shape[0] 
return self.apply(jnp.ones(length,), axis=axis) @@ -250,7 +250,7 @@ def transport_mass(self) -> float: return self.marginal(0).sum() @property - def _inv_g(self) -> jax.Array: + def _inv_g(self) -> jnp.ndarray: return 1.0 / self.g @@ -334,9 +334,9 @@ def __init__( def __call__( self, ot_prob: quadratic_problem.QuadraticProblem, - init: Tuple[Optional[jax.Array], Optional[jax.Array], - Optional[jax.Array]] = (None, None, None), - rng: Optional[jax.Array] = None, + init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], + Optional[jnp.ndarray]] = (None, None, None), + rng: Optional[jnp.ndarray] = None, **kwargs: Any, ) -> LRGWOutput: """Run low-rank Gromov-Wasserstein solver. @@ -370,7 +370,7 @@ def _get_costs( self, ot_prob: quadratic_problem.QuadraticProblem, state: LRGWState, - ) -> Tuple[jax.Array, jax.Array, jax.Array, float]: + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, float]: q, r, g = state.q, state.r, state.g log_q, log_r, log_g = mu.safe_log(q), mu.safe_log(r), mu.safe_log(g) inv_g = 1.0 / g[None, :] @@ -427,9 +427,9 @@ def _get_costs( # TODO(michalk8): move to `lr_utils` when refactoring this the future def dykstra_update_lse( self, - c_q: jax.Array, - c_r: jax.Array, - h: jax.Array, + c_q: jnp.ndarray, + c_r: jnp.ndarray, + h: jnp.ndarray, gamma: float, ot_prob: quadratic_problem.QuadraticProblem, min_entry_value: float = 1e-6, @@ -437,7 +437,7 @@ def dykstra_update_lse( min_iter: int = 0, inner_iter: int = 10, max_iter: int = 10000 - ) -> Tuple[jax.Array, jax.Array, jax.Array]: + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Run Dykstra's algorithm.""" # shortcuts for problem's definition. r = self.rank @@ -455,24 +455,24 @@ def dykstra_update_lse( constants = c_q, c_r, loga, logb def cond_fn( - iteration: int, constants: Tuple[jax.Array, ...], - state_inner: Tuple[jax.Array, ...] + iteration: int, constants: Tuple[jnp.ndarray, ...], + state_inner: Tuple[jnp.ndarray, ...] 
) -> bool: del iteration, constants *_, err = state_inner return err > tolerance def _softm( - f: jax.Array, g: jax.Array, c: jax.Array, axis: int - ) -> jax.Array: + f: jnp.ndarray, g: jnp.ndarray, c: jnp.ndarray, axis: int + ) -> jnp.ndarray: return jsp.special.logsumexp( gamma * (f[:, None] + g[None, :] - c), axis=axis ) def body_fn( - iteration: int, constants: Tuple[jax.Array, ...], - state_inner: Tuple[jax.Array, ...], compute_error: bool - ) -> Tuple[jax.Array, ...]: + iteration: int, constants: Tuple[jnp.ndarray, ...], + state_inner: Tuple[jnp.ndarray, ...], compute_error: bool + ) -> Tuple[jnp.ndarray, ...]: # TODO(michalk8): in the future, use `NamedTuple` f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err = state_inner c_q, c_r, loga, logb = constants @@ -522,15 +522,15 @@ def body_fn( return f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err def recompute_couplings( - f1: jax.Array, - g1: jax.Array, - c_q: jax.Array, - f2: jax.Array, - g2: jax.Array, - c_r: jax.Array, - h: jax.Array, + f1: jnp.ndarray, + g1: jnp.ndarray, + c_q: jnp.ndarray, + f2: jnp.ndarray, + g2: jnp.ndarray, + c_r: jnp.ndarray, + h: jnp.ndarray, gamma: float, - ) -> Tuple[jax.Array, jax.Array, jax.Array]: + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: q = jnp.exp(gamma * (f1[:, None] + g1[None, :] - c_q)) r = jnp.exp(gamma * (f2[:, None] + g2[None, :] - c_r)) g = jnp.exp(gamma * h) @@ -545,9 +545,9 @@ def recompute_couplings( def dykstra_update_kernel( self, - k_q: jax.Array, - k_r: jax.Array, - k_g: jax.Array, + k_q: jnp.ndarray, + k_r: jnp.ndarray, + k_g: jnp.ndarray, gamma: float, ot_prob: quadratic_problem.QuadraticProblem, min_entry_value: float = 1e-6, @@ -555,7 +555,7 @@ def dykstra_update_kernel( min_iter: int = 0, inner_iter: int = 10, max_iter: int = 10000 - ) -> Tuple[jax.Array, jax.Array, jax.Array]: + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Run Dykstra's algorithm.""" # shortcuts for problem's definition. 
del gamma @@ -575,17 +575,17 @@ def dykstra_update_kernel( constants = k_q, k_r, k_g, a, b def cond_fn( - iteration: int, constants: Tuple[jax.Array, ...], - state_inner: Tuple[jax.Array, ...] + iteration: int, constants: Tuple[jnp.ndarray, ...], + state_inner: Tuple[jnp.ndarray, ...] ) -> bool: del iteration, constants *_, err = state_inner return err > tolerance def body_fn( - iteration: int, constants: Tuple[jax.Array, ...], - state_inner: Tuple[jax.Array, ...], compute_error: bool - ) -> Tuple[jax.Array, ...]: + iteration: int, constants: Tuple[jnp.ndarray, ...], + state_inner: Tuple[jnp.ndarray, ...], compute_error: bool + ) -> Tuple[jnp.ndarray, ...]: # TODO(michalk8): in the future, use `NamedTuple` u1, u2, v1_old, v2_old, g_old, q_gi, q_gp, q_q, q_r, err = state_inner k_q, k_r, k_g, a, b = constants @@ -623,14 +623,14 @@ def body_fn( return u1, u2, v1_old, v2_old, g_old, q_gi, q_gp, q_q, q_r, err def recompute_couplings( - u1: jax.Array, - v1: jax.Array, - k_q: jax.Array, - u2: jax.Array, - v2: jax.Array, - k_r: jax.Array, - g: jax.Array, - ) -> Tuple[jax.Array, jax.Array, jax.Array]: + u1: jnp.ndarray, + v1: jnp.ndarray, + k_q: jnp.ndarray, + u2: jnp.ndarray, + v2: jnp.ndarray, + k_r: jnp.ndarray, + g: jnp.ndarray, + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: q = u1.reshape((-1, 1)) * k_q * v1.reshape((1, -1)) r = u2.reshape((-1, 1)) * k_r * v2.reshape((1, -1)) return q, r, g @@ -762,7 +762,7 @@ def create_initializer( def init_state( self, ot_prob: quadratic_problem.QuadraticProblem, - init: Tuple[jax.Array, jax.Array, jax.Array] + init: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] ) -> LRGWState: """Return the initial state of the loop.""" q, r, g = init @@ -837,7 +837,8 @@ def _diverged(self, state: LRGWState, iteration: int) -> bool: def run( ot_prob: quadratic_problem.QuadraticProblem, solver: LRGromovWasserstein, - init: Tuple[Optional[jax.Array], Optional[jax.Array], Optional[jax.Array]], + init: Tuple[Optional[jnp.ndarray], 
Optional[jnp.ndarray], + Optional[jnp.ndarray]], ) -> LRGWOutput: """Run loop of the solver, outputting a state upgraded to an output.""" out = sinkhorn.iterations(ot_prob, solver, init) @@ -848,9 +849,9 @@ def run( def dykstra_solution_error( - q: jax.Array, r: jax.Array, ot_prob: quadratic_problem.QuadraticProblem, + q: jnp.ndarray, r: jnp.ndarray, ot_prob: quadratic_problem.QuadraticProblem, norm_error: Tuple[int, ...] -) -> jax.Array: +) -> jnp.ndarray: """Compute solution error. Since only balanced case is available for LR, this is marginal deviation. @@ -883,9 +884,9 @@ def dykstra_solution_error( def _linearized_geometry( prob: quadratic_problem.QuadraticProblem, *, - q: jax.Array, - r: jax.Array, - g: jax.Array, + q: jnp.ndarray, + r: jnp.ndarray, + g: jnp.ndarray, ) -> low_rank.LRCGeometry: inv_sqrt_g = 1.0 / jnp.sqrt(g[None, :]) diff --git a/src/ott/solvers/quadratic/gw_barycenter.py b/src/ott/solvers/quadratic/gw_barycenter.py index 0f753793e..8816c5ada 100644 --- a/src/ott/solvers/quadratic/gw_barycenter.py +++ b/src/ott/solvers/quadratic/gw_barycenter.py @@ -45,13 +45,13 @@ class GWBarycenterState(NamedTuple): gw_convergence: Array of shape ``[max_iter,]`` containing the convergence of all GW problems at each iteration. 
""" - cost: Optional[jax.Array] = None - x: Optional[jax.Array] = None - a: Optional[jax.Array] = None - errors: Optional[jax.Array] = None - costs: Optional[jax.Array] = None - costs_bary: Optional[jax.Array] = None - gw_convergence: Optional[jax.Array] = None + cost: Optional[jnp.ndarray] = None + x: Optional[jnp.ndarray] = None + a: Optional[jnp.ndarray] = None + errors: Optional[jnp.ndarray] = None + costs: Optional[jnp.ndarray] = None + costs_bary: Optional[jnp.ndarray] = None + gw_convergence: Optional[jnp.ndarray] = None def set(self, **kwargs: Any) -> "GWBarycenterState": """Return a copy of self, possibly with overwrites.""" @@ -133,9 +133,10 @@ def init_state( self, problem: gw_barycenter.GWBarycenterProblem, bar_size: int, - bar_init: Optional[Union[jax.Array, Tuple[jax.Array, jax.Array]]] = None, - a: Optional[jax.Array] = None, - rng: Optional[jax.Array] = None, + bar_init: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, + jnp.ndarray]]] = None, + a: Optional[jnp.ndarray] = None, + rng: Optional[jnp.ndarray] = None, ) -> GWBarycenterState: """Initialize the (fused) Gromov-Wasserstein barycenter state. 
@@ -209,13 +210,13 @@ def update_state( iteration: int, problem: gw_barycenter.GWBarycenterProblem, store_errors: bool = True, - ) -> Tuple[float, bool, jax.Array, Optional[jax.Array]]: + ) -> Tuple[float, bool, jnp.ndarray, Optional[jnp.ndarray]]: """Solve the (fused) Gromov-Wasserstein barycenter problem.""" def solve_gw( - state: GWBarycenterState, b: jax.Array, y: jax.Array, - f: Optional[jax.Array] - ) -> Tuple[float, bool, jax.Array, Optional[jax.Array]]: + state: GWBarycenterState, b: jnp.ndarray, y: jnp.ndarray, + f: Optional[jnp.ndarray] + ) -> Tuple[float, bool, jnp.ndarray, Optional[jnp.ndarray]]: quad_problem = problem._create_problem(state, y=y, b=b, f=f) out = self._quad_solver(quad_problem) return ( @@ -281,8 +282,9 @@ def tree_unflatten( # noqa: D102 @partial(jax.vmap, in_axes=[None, 0, None, 0, None]) def init_transports( - solver, rng: jax.Array, a: jax.Array, b: jax.Array, epsilon: Optional[float] -) -> jax.Array: + solver, rng: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray, + epsilon: Optional[float] +) -> jnp.ndarray: """Initialize random 2D point cloud and solve the linear OT problem. Args: diff --git a/src/ott/tools/gaussian_mixture/fit_gmm.py b/src/ott/tools/gaussian_mixture/fit_gmm.py index 45d8e0935..0e3fbc4e8 100644 --- a/src/ott/tools/gaussian_mixture/fit_gmm.py +++ b/src/ott/tools/gaussian_mixture/fit_gmm.py @@ -62,8 +62,8 @@ def get_assignment_probs( - gmm: gaussian_mixture.GaussianMixture, points: jax.Array -) -> jax.Array: + gmm: gaussian_mixture.GaussianMixture, points: jnp.ndarray +) -> jnp.ndarray: r"""Get component assignment probabilities used in the E step of EM. 
Here we compute the component assignment probabilities p(Z|X, \Theta^{(t)}) @@ -81,9 +81,9 @@ def get_assignment_probs( def get_q( gmm: gaussian_mixture.GaussianMixture, - assignment_probs: jax.Array, - points: jax.Array, - point_weights: Optional[jax.Array] = None, + assignment_probs: jnp.ndarray, + points: jnp.ndarray, + point_weights: Optional[jnp.ndarray] = None, ) -> float: r"""Get Q(\Theta|\Theta^{(t)}). @@ -109,8 +109,8 @@ def get_q( def log_prob_loss( gmm: gaussian_mixture.GaussianMixture, - points: jax.Array, - point_weights: Optional[jax.Array] = None, + points: jnp.ndarray, + point_weights: Optional[jnp.ndarray] = None, ) -> float: """Loss function: weighted mean of (-log prob of observations). @@ -130,8 +130,8 @@ def log_prob_loss( def fit_model_em( gmm: gaussian_mixture.GaussianMixture, - points: jax.Array, - point_weights: Optional[jax.Array], + points: jnp.ndarray, + point_weights: Optional[jnp.ndarray], steps: int, jit: bool = True, verbose: bool = False, @@ -184,10 +184,10 @@ def fit_model_em( # See https://en.wikipedia.org/wiki/K-means%2B%2B for details -def _get_dist_sq(points: jax.Array, loc: jax.Array) -> jax.Array: +def _get_dist_sq(points: jnp.ndarray, loc: jnp.ndarray) -> jnp.ndarray: """Get the squared distance from each point to each loc.""" - def _dist_sq_one_loc(points: jax.Array, loc: jax.Array) -> jax.Array: + def _dist_sq_one_loc(points: jnp.ndarray, loc: jnp.ndarray) -> jnp.ndarray: return jnp.sum((points - loc[None]) ** 2., axis=-1) dist_sq_fn = jax.vmap(_dist_sq_one_loc, in_axes=(None, 0), out_axes=1) @@ -195,8 +195,8 @@ def _dist_sq_one_loc(points: jax.Array, loc: jax.Array) -> jax.Array: def _get_locs( - rng: jax.Array, points: jax.Array, n_components: int -) -> jax.Array: + rng: jnp.ndarray, points: jnp.ndarray, n_components: int +) -> jnp.ndarray: """Get the initial component means. 
Args: @@ -229,9 +229,9 @@ def _get_locs( def from_kmeans_plusplus( - rng: jax.Array, - points: jax.Array, - point_weights: Optional[jax.Array], + rng: jnp.ndarray, + points: jnp.ndarray, + point_weights: Optional[jnp.ndarray], n_components: int, ) -> gaussian_mixture.GaussianMixture: """Initialize a GMM via a single pass of K-means++. @@ -265,9 +265,9 @@ def from_kmeans_plusplus( def initialize( - rng: jax.Array, - points: jax.Array, - point_weights: Optional[jax.Array], + rng: jnp.ndarray, + points: jnp.ndarray, + point_weights: Optional[jnp.ndarray], n_components: int, n_attempts: int = 50, verbose: bool = False diff --git a/src/ott/tools/gaussian_mixture/fit_gmm_pair.py b/src/ott/tools/gaussian_mixture/fit_gmm_pair.py index 35222caf9..7ecde263c 100644 --- a/src/ott/tools/gaussian_mixture/fit_gmm_pair.py +++ b/src/ott/tools/gaussian_mixture/fit_gmm_pair.py @@ -98,9 +98,9 @@ class Observations(NamedTuple): """Weighted observations and their E-step assignment probabilities.""" - points: jax.Array - point_weights: jax.Array - assignment_probs: jax.Array + points: jnp.ndarray + point_weights: jnp.ndarray + assignment_probs: jnp.ndarray # Model fit @@ -108,7 +108,7 @@ class Observations(NamedTuple): def get_q( gmm: gaussian_mixture.GaussianMixture, obs: Observations -) -> jax.Array: +) -> jnp.ndarray: r"""Get Q(\Theta|\Theta^{(t)}). Here Q is the log likelihood for our observations based on the current @@ -159,7 +159,7 @@ def _objective_fn( pair: gaussian_mixture_pair.GaussianMixturePair, obs0: Observations, obs1: Observations, - ) -> jax.Array: + ) -> jnp.ndarray: """Compute the objective function for a pair of GMMs. 
Args: @@ -204,11 +204,11 @@ def print_losses( def do_e_step( # noqa: D103 - e_step_fn: Callable[[gaussian_mixture.GaussianMixture, jax.Array], - jax.Array], + e_step_fn: Callable[[gaussian_mixture.GaussianMixture, jnp.ndarray], + jnp.ndarray], gmm: gaussian_mixture.GaussianMixture, - points: jax.Array, - point_weights: jax.Array, + points: jnp.ndarray, + point_weights: jnp.ndarray, ) -> Observations: assignment_probs = e_step_fn(gmm, points) return Observations( @@ -307,10 +307,10 @@ def get_fit_model_em_fn( def _fit_model_em( pair: gaussian_mixture_pair.GaussianMixturePair, - points0: jax.Array, - points1: jax.Array, - point_weights0: Optional[jax.Array], - point_weights1: Optional[jax.Array], + points0: jnp.ndarray, + points1: jnp.ndarray, + point_weights0: Optional[jnp.ndarray], + point_weights1: Optional[jnp.ndarray], em_steps: int, m_steps: int = 50, verbose: bool = False, diff --git a/src/ott/tools/gaussian_mixture/gaussian.py b/src/ott/tools/gaussian_mixture/gaussian.py index b8c8e227b..70ac505f2 100644 --- a/src/ott/tools/gaussian_mixture/gaussian.py +++ b/src/ott/tools/gaussian_mixture/gaussian.py @@ -28,15 +28,15 @@ class Gaussian: """Normal distribution.""" - def __init__(self, loc: jax.Array, scale: scale_tril.ScaleTriL): + def __init__(self, loc: jnp.ndarray, scale: scale_tril.ScaleTriL): self._loc = loc self._scale = scale @classmethod def from_samples( cls, - points: jax.Array, - weights: Optional[jax.Array] = None + points: jnp.ndarray, + weights: Optional[jnp.ndarray] = None ) -> "Gaussian": """Construct a Gaussian from weighted samples. @@ -63,11 +63,11 @@ def from_samples( @classmethod def from_random( cls, - rng: jax.Array, + rng: jnp.ndarray, n_dimensions: int, stdev_mean: float = 0.1, stdev_cov: float = 0.1, - ridge: Union[float, jax.Array] = 0, + ridge: Union[float, jnp.ndarray] = 0, dtype: Optional[jnp.dtype] = None ) -> "Gaussian": """Construct a random Gaussian. 
@@ -94,13 +94,13 @@ def from_random( return cls(loc=loc, scale=scale) @classmethod - def from_mean_and_cov(cls, mean: jax.Array, cov: jax.Array) -> "Gaussian": + def from_mean_and_cov(cls, mean: jnp.ndarray, cov: jnp.ndarray) -> "Gaussian": """Construct a Gaussian from a mean and covariance.""" scale = scale_tril.ScaleTriL.from_covariance(cov) return cls(loc=mean, scale=scale) @property - def loc(self) -> jax.Array: + def loc(self) -> jnp.ndarray: """Mean of the Gaussian.""" return self._loc @@ -114,22 +114,22 @@ def n_dimensions(self) -> int: """Dimensionality of the Gaussian.""" return self.loc.shape[-1] - def covariance(self) -> jax.Array: + def covariance(self) -> jnp.ndarray: """Covariance of the Gaussian.""" return self.scale.covariance() - def to_z(self, x: jax.Array) -> jax.Array: + def to_z(self, x: jnp.ndarray) -> jnp.ndarray: r"""Transform :math:`x` to :math:`z = \frac{x - loc}{scale}`.""" return self.scale.centered_to_z(x_centered=x - self.loc) - def from_z(self, z: jax.Array) -> jax.Array: + def from_z(self, z: jnp.ndarray) -> jnp.ndarray: r"""Transform :math:`z` to :math:`x = loc + scale \cdot z`.""" return self.scale.z_to_centered(z=z) + self.loc def log_prob( self, - x: jax.Array, # (?, d) - ) -> jax.Array: # (?, d) + x: jnp.ndarray, # (?, d) + ) -> jnp.ndarray: # (?, d) """Log probability for a Gaussian with a diagonal covariance.""" d = x.shape[-1] z = self.to_z(x) @@ -138,7 +138,7 @@ def log_prob( -0.5 * (d * LOG2PI + log_det[None] + jnp.sum(z ** 2., axis=-1)) ) # (?, k) - def sample(self, rng: jax.Array, size: int) -> jax.Array: + def sample(self, rng: jnp.ndarray, size: int) -> jnp.ndarray: """Generate samples from the distribution.""" std_samples_t = jax.random.normal(key=rng, shape=(self.n_dimensions, size)) return self.loc[None] + ( @@ -149,7 +149,7 @@ def sample(self, rng: jax.Array, size: int) -> jax.Array: ) ) - def w2_dist(self, other: "Gaussian") -> jax.Array: + def w2_dist(self, other: "Gaussian") -> jnp.ndarray: r"""Wasserstein 
distance :math:`W_2^2` to another Gaussian. .. math:: @@ -167,7 +167,7 @@ def w2_dist(self, other: "Gaussian") -> jax.Array: delta_sigma = self.scale.w2_dist(other.scale) return delta_mean + delta_sigma - def f_potential(self, dest: "Gaussian", points: jax.Array) -> jax.Array: + def f_potential(self, dest: "Gaussian", points: jnp.ndarray) -> jnp.ndarray: """Optimal potential for W2 distance between Gaussians. Evaluated on points. Args: @@ -191,7 +191,7 @@ def batch_inner_product(x, y): points.dot(dest.loc) ) - def transport(self, dest: "Gaussian", points: jax.Array) -> jax.Array: + def transport(self, dest: "Gaussian", points: jnp.ndarray) -> jnp.ndarray: """Transport points according to map between two Gaussian measures. Args: diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture.py b/src/ott/tools/gaussian_mixture/gaussian_mixture.py index a9cb2b326..5d40a870d 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture.py @@ -27,8 +27,9 @@ def get_summary_stats_from_points_and_assignment_probs( - points: jax.Array, point_weights: jax.Array, assignment_probs: jax.Array -) -> Tuple[jax.Array, jax.Array, jax.Array]: + points: jnp.ndarray, point_weights: jnp.ndarray, + assignment_probs: jnp.ndarray +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Get component summary stats from points and component probabilities. 
Args: @@ -67,7 +68,7 @@ class GaussianMixture: """Gaussian Mixture model.""" def __init__( - self, loc: jax.Array, scale_params: jax.Array, + self, loc: jnp.ndarray, scale_params: jnp.ndarray, component_weight_ob: probabilities.Probabilities ): self._loc = loc @@ -77,7 +78,7 @@ def __init__( @classmethod def from_random( cls, - rng: jax.Array, + rng: jnp.ndarray, n_components: int, n_dimensions: int, stdev_mean: float = 0.1, @@ -112,7 +113,7 @@ def from_random( @classmethod def from_mean_cov_component_weights( - cls, mean: jax.Array, cov: jax.Array, component_weights: jax.Array + cls, mean: jnp.ndarray, cov: jnp.ndarray, component_weights: jnp.ndarray ): """Construct a GMM from means, covariances, and component weights.""" scale_params = [] @@ -127,9 +128,9 @@ def from_mean_cov_component_weights( @classmethod def from_points_and_assignment_probs( cls, - points: jax.Array, - point_weights: jax.Array, - assignment_probs: jax.Array, + points: jnp.ndarray, + point_weights: jnp.ndarray, + assignment_probs: jnp.ndarray, ) -> "GaussianMixture": """Estimate a GMM from points and a set of component probabilities.""" mean, cov, wts = get_summary_stats_from_points_and_assignment_probs( @@ -157,17 +158,17 @@ def n_components(self): return self._loc.shape[-2] @property - def loc(self) -> jax.Array: + def loc(self) -> jnp.ndarray: """Location parameters of the GMM.""" return self._loc @property - def scale_params(self) -> jax.Array: + def scale_params(self) -> jnp.ndarray: """Scale parameters of the GMM.""" return self._scale_params @property - def cholesky(self) -> jax.Array: + def cholesky(self) -> jnp.ndarray: """Cholesky decomposition of the GMM covariance matrices.""" size = self.n_dimensions @@ -177,7 +178,7 @@ def _get_cholesky(scale_params): return jax.vmap(_get_cholesky, in_axes=0, out_axes=0)(self.scale_params) @property - def covariance(self) -> jax.Array: + def covariance(self) -> jnp.ndarray: """Covariance matrices of the GMM.""" size = self.n_dimensions @@ -192,16 
+193,16 @@ def component_weight_ob(self) -> probabilities.Probabilities: return self._component_weight_ob @property - def component_weights(self) -> jax.Array: + def component_weights(self) -> jnp.ndarray: """Component weights probabilities.""" return self._component_weight_ob.probs() - def log_component_weights(self) -> jax.Array: + def log_component_weights(self) -> jnp.ndarray: """Log component weights probabilities.""" return self._component_weight_ob.log_probs() def _get_normal( - self, loc: jax.Array, scale_params: jax.Array + self, loc: jnp.ndarray, scale_params: jnp.ndarray ) -> gaussian.Gaussian: size = loc.shape[-1] return gaussian.Gaussian( @@ -218,7 +219,7 @@ def components(self) -> List[gaussian.Gaussian]: """List of all GMM components.""" return [self.get_component(i) for i in range(self.n_components)] - def sample(self, rng: jax.Array, size: int) -> jax.Array: + def sample(self, rng: jnp.ndarray, size: int) -> jnp.ndarray: """Generate samples from the distribution.""" subrng0, subrng1 = jax.random.split(rng) component = self.component_weight_ob.sample(rng=subrng0, size=size) @@ -243,7 +244,7 @@ def _transform_single_value(single_component, single_x): axis=0 ) - def conditional_log_prob(self, x: jax.Array) -> jax.Array: + def conditional_log_prob(self, x: jnp.ndarray) -> jnp.ndarray: """Compute the component-conditional log probability of x. Args: @@ -255,7 +256,7 @@ def conditional_log_prob(self, x: jax.Array) -> jax.Array: """ def _log_prob_single_component( - loc: jax.Array, scale_params: jax.Array, x: jax.Array + loc: jnp.ndarray, scale_params: jnp.ndarray, x: jnp.ndarray ): norm = self._get_normal(loc=loc, scale_params=scale_params) return norm.log_prob(x) @@ -265,7 +266,7 @@ def _log_prob_single_component( ) return conditional_log_prob_fn(self._loc, self._scale_params, x) - def log_prob(self, x: jax.Array) -> jax.Array: + def log_prob(self, x: jnp.ndarray) -> jnp.ndarray: """Compute the log probability of the observations x. 
Args: @@ -281,7 +282,7 @@ def log_prob(self, x: jax.Array) -> jax.Array: log_prob_conditional + log_component_weight[None, :], axis=-1 ) - def get_log_component_posterior(self, x: jax.Array) -> jax.Array: + def get_log_component_posterior(self, x: jnp.ndarray) -> jnp.ndarray: """Compute the posterior probability that x came from each component. Args: diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py b/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py index 21d4dbaf1..b24506fcc 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py @@ -128,12 +128,12 @@ def get_bures_geometry(self) -> pointcloud.PointCloud: epsilon=self.epsilon ) - def get_cost_matrix(self) -> jax.Array: + def get_cost_matrix(self) -> jnp.ndarray: """Get matrix of :math:`W_2^2` costs between all pairs of components.""" return self.get_bures_geometry().cost_matrix def get_sinkhorn( - self, cost_matrix: jax.Array, **kwargs: Any + self, cost_matrix: jnp.ndarray, **kwargs: Any ) -> sinkhorn.SinkhornOutput: """Get the output of Sinkhorn's method for a given cost matrix.""" # We use a Geometry here rather than the PointCloud created in @@ -152,7 +152,7 @@ def get_sinkhorn( def get_normalized_sinkhorn_coupling( self, sinkhorn_output: sinkhorn.SinkhornOutput, - ) -> jax.Array: + ) -> jnp.ndarray: """Get the normalized coupling matrix for the specified Sinkhorn output. 
Args: diff --git a/src/ott/tools/gaussian_mixture/linalg.py b/src/ott/tools/gaussian_mixture/linalg.py index 2a5114d69..9c88df0cc 100644 --- a/src/ott/tools/gaussian_mixture/linalg.py +++ b/src/ott/tools/gaussian_mixture/linalg.py @@ -18,9 +18,9 @@ def get_mean_and_var( - points: jax.Array, # (n, d) - weights: jax.Array, # (n,) -) -> Tuple[jax.Array, jax.Array]: + points: jnp.ndarray, # (n, d) + weights: jnp.ndarray, # (n,) +) -> Tuple[jnp.ndarray, jnp.ndarray]: """Get the mean and variance of a weighted set of points.""" weights_sum = jnp.sum(weights, axis=-1) # (1,) mean = ( @@ -37,9 +37,9 @@ def get_mean_and_var( def get_mean_and_cov( - points: jax.Array, # (n, d) - weights: jax.Array, # (n,) -) -> Tuple[jax.Array, jax.Array]: + points: jnp.ndarray, # (n, d) + weights: jnp.ndarray, # (n,) +) -> Tuple[jnp.ndarray, jnp.ndarray]: """Get the mean and covariance of a weighted set of points.""" weights_sum = jnp.sum(weights, axis=-1, keepdims=True) # (1,) mean = ( @@ -59,7 +59,7 @@ def get_mean_and_cov( return mean, cov -def flat_to_tril(x: jax.Array, size: int) -> jax.Array: +def flat_to_tril(x: jnp.ndarray, size: int) -> jnp.ndarray: """Map flat values to lower triangular matrices. Args: @@ -76,7 +76,7 @@ def flat_to_tril(x: jax.Array, size: int) -> jax.Array: return m.at[..., tril[0], tril[1]].set(x) -def tril_to_flat(m: jax.Array) -> jax.Array: +def tril_to_flat(m: jnp.ndarray) -> jnp.ndarray: """Flatten lower triangular matrices. 
Args: @@ -91,8 +91,8 @@ def tril_to_flat(m: jax.Array) -> jax.Array: def apply_to_diag( - m: jax.Array, fn: Callable[[jax.Array], jax.Array] -) -> jax.Array: + m: jnp.ndarray, fn: Callable[[jnp.ndarray], jnp.ndarray] +) -> jnp.ndarray: """Apply a function to the diagonal of a matrix.""" size = m.shape[-1] diag = jnp.diagonal(m, axis1=-2, axis2=-1) @@ -101,9 +101,9 @@ def apply_to_diag( def matrix_powers( - m: jax.Array, + m: jnp.ndarray, powers: Iterable[float], -) -> List[jax.Array]: +) -> List[jnp.ndarray]: """Raise a real, symmetric matrix to multiple powers.""" eigs, q = jnp.linalg.eigh(m) qt = jnp.swapaxes(q, axis1=-2, axis2=-1) @@ -113,7 +113,9 @@ def matrix_powers( return ret -def invmatvectril(m: jax.Array, x: jax.Array, lower: bool = True) -> jax.Array: +def invmatvectril( + m: jnp.ndarray, x: jnp.ndarray, lower: bool = True +) -> jnp.ndarray: """Multiply x by the inverse of a triangular matrix. Args: @@ -130,8 +132,10 @@ def invmatvectril(m: jax.Array, x: jax.Array, lower: bool = True) -> jax.Array: def get_random_orthogonal( - rng: jax.Array, dim: int, dtype: Optional[jnp.dtype] = None -) -> jax.Array: + rng: jnp.ndarray, + dim: int, + dtype: Optional[jnp.dtype] = None +) -> jnp.ndarray: """Get a random orthogonal matrix with the specified dimension.""" m = jax.random.normal(key=rng, shape=[dim, dim], dtype=dtype) q, _ = jnp.linalg.qr(m) diff --git a/src/ott/tools/gaussian_mixture/probabilities.py b/src/ott/tools/gaussian_mixture/probabilities.py index c3bb253a5..66a90c1a7 100644 --- a/src/ott/tools/gaussian_mixture/probabilities.py +++ b/src/ott/tools/gaussian_mixture/probabilities.py @@ -27,7 +27,7 @@ class Probabilities: to a length n simplex by appending a 0 and taking a softmax. 
""" - _params: jax.Array + _params: jnp.ndarray def __init__(self, params): self._params = params @@ -35,7 +35,7 @@ def __init__(self, params): @classmethod def from_random( cls, - rng: jax.Array, + rng: jnp.ndarray, n_dimensions: int, stdev: Optional[float] = 0.1, dtype: Optional[jnp.dtype] = None @@ -47,7 +47,7 @@ def from_random( ) @classmethod - def from_probs(cls, probs: jax.Array) -> "Probabilities": + def from_probs(cls, probs: jnp.ndarray) -> "Probabilities": """Construct Probabilities from a vector of probabilities.""" log_probs = jnp.log(probs) log_probs_normalized, norm = log_probs[:-1], log_probs[-1] @@ -62,21 +62,21 @@ def params(self): # noqa: D102 def dtype(self): # noqa: D102 return self._params.dtype - def unnormalized_log_probs(self) -> jax.Array: + def unnormalized_log_probs(self) -> jnp.ndarray: """Get the unnormalized log probabilities.""" return jnp.concatenate([self._params, jnp.zeros((1,), dtype=self.dtype)], axis=-1) - def log_probs(self) -> jax.Array: + def log_probs(self) -> jnp.ndarray: """Get the log probabilities.""" return jax.nn.log_softmax(self.unnormalized_log_probs()) - def probs(self) -> jax.Array: + def probs(self) -> jnp.ndarray: """Get the probabilities.""" return jax.nn.softmax(self.unnormalized_log_probs()) - def sample(self, rng: jax.Array, size: int) -> jax.Array: + def sample(self, rng: jnp.ndarray, size: int) -> jnp.ndarray: """Sample from the distribution.""" return jax.random.categorical( key=rng, logits=self.unnormalized_log_probs(), shape=(size,) diff --git a/src/ott/tools/gaussian_mixture/scale_tril.py b/src/ott/tools/gaussian_mixture/scale_tril.py index ee708d5ac..95b812d99 100644 --- a/src/ott/tools/gaussian_mixture/scale_tril.py +++ b/src/ott/tools/gaussian_mixture/scale_tril.py @@ -27,16 +27,16 @@ class ScaleTriL: """Pytree for a lower triangular Cholesky-factored covariance matrix.""" - def __init__(self, params: jax.Array, size: int): + def __init__(self, params: jnp.ndarray, size: int): self._params = params 
self._size = size @classmethod def from_points_and_weights( cls, - points: jax.Array, - weights: jax.Array, - ) -> Tuple[jax.Array, "ScaleTriL"]: + points: jnp.ndarray, + weights: jnp.ndarray, + ) -> Tuple[jnp.ndarray, "ScaleTriL"]: """Get a mean and a ScaleTriL from a set of points and weights.""" mean, cov = linalg.get_mean_and_cov(points=points, weights=weights) return mean, cls.from_covariance(cov) @@ -44,7 +44,7 @@ def from_points_and_weights( @classmethod def from_random( cls, - rng: jax.Array, + rng: jnp.ndarray, n_dimensions: int, stdev: Optional[float] = 0.1, dtype: jnp.dtype = jnp.float32, @@ -80,7 +80,7 @@ def from_random( return cls(params=flat, size=n_dimensions) @classmethod - def from_cholesky(cls, cholesky: jax.Array) -> "ScaleTriL": + def from_cholesky(cls, cholesky: jnp.ndarray) -> "ScaleTriL": """Construct ScaleTriL from a Cholesky factor of a covariance matrix.""" m = linalg.apply_to_diag(cholesky, jnp.log) flat = linalg.tril_to_flat(m) @@ -89,14 +89,14 @@ def from_cholesky(cls, cholesky: jax.Array) -> "ScaleTriL": @classmethod def from_covariance( cls, - covariance: jax.Array, + covariance: jnp.ndarray, ) -> "ScaleTriL": """Construct ScaleTriL from a covariance matrix.""" cholesky = jnp.linalg.cholesky(covariance) return cls.from_cholesky(cholesky) @property - def params(self) -> jax.Array: + def params(self) -> jnp.ndarray: """Internal representation.""" return self._params @@ -110,34 +110,34 @@ def dtype(self): """Data type of the covariance matrix.""" return self._params.dtype - def cholesky(self) -> jax.Array: + def cholesky(self) -> jnp.ndarray: """Get a lower triangular Cholesky factor for the covariance matrix.""" m = linalg.flat_to_tril(self._params, size=self._size) return linalg.apply_to_diag(m, jnp.exp) - def covariance(self) -> jax.Array: + def covariance(self) -> jnp.ndarray: """Get the covariance matrix.""" cholesky = self.cholesky() return cholesky @ cholesky.T - def covariance_sqrt(self) -> jax.Array: + def covariance_sqrt(self) 
-> jnp.ndarray: """Get the square root of the covariance matrix.""" return linalg.matrix_powers(self.covariance(), (0.5,))[0] - def log_det_covariance(self) -> jax.Array: + def log_det_covariance(self) -> jnp.ndarray: """Get the log of the determinant of the covariance matrix.""" diag = jnp.diagonal(self.cholesky(), axis1=-2, axis2=-1) return 2. * jnp.sum(jnp.log(diag), axis=-1) - def centered_to_z(self, x_centered: jax.Array) -> jax.Array: + def centered_to_z(self, x_centered: jnp.ndarray) -> jnp.ndarray: """Map centered points to standardized centered points (i.e. cov(z) = I).""" return linalg.invmatvectril(m=self.cholesky(), x=x_centered, lower=True) - def z_to_centered(self, z: jax.Array) -> jax.Array: + def z_to_centered(self, z: jnp.ndarray) -> jnp.ndarray: """Scale standardized points to points with the specified covariance.""" return (self.cholesky() @ z.T).T - def w2_dist(self, other: "ScaleTriL") -> jax.Array: + def w2_dist(self, other: "ScaleTriL") -> jnp.ndarray: r"""Wasserstein distance W_2^2 to another Gaussian with same mean. Args: @@ -148,7 +148,7 @@ def w2_dist(self, other: "ScaleTriL") -> jax.Array: """ dimension = self.size - def _flatten_cov(cov: jax.Array) -> jax.Array: + def _flatten_cov(cov: jnp.ndarray) -> jnp.ndarray: cov = cov.reshape(cov.shape[:-2] + (dimension * dimension,)) return jnp.concatenate([jnp.zeros(dimension), cov], axis=-1) @@ -159,7 +159,7 @@ def _flatten_cov(cov: jax.Array) -> jax.Array: ..., ] - def gaussian_map(self, dest_scale: "ScaleTriL") -> jax.Array: + def gaussian_map(self, dest_scale: "ScaleTriL") -> jnp.ndarray: """Scaling matrix used in transport between 0-mean Gaussians. 
Sigma_mu^{-1/2} @ @@ -179,7 +179,9 @@ def gaussian_map(self, dest_scale: "ScaleTriL") -> jax.Array: ) return jnp.matmul(sqrt0_inv, jnp.matmul(m, sqrt0_inv)) - def transport(self, dest_scale: "ScaleTriL", points: jax.Array) -> jax.Array: + def transport( + self, dest_scale: "ScaleTriL", points: jnp.ndarray + ) -> jnp.ndarray: """Apply Monge map, computed between two 0-mean Gaussians, to points. Args: diff --git a/src/ott/tools/k_means.py b/src/ott/tools/k_means.py index c8fc8189d..9175abe2c 100644 --- a/src/ott/tools/k_means.py +++ b/src/ott/tools/k_means.py @@ -25,29 +25,29 @@ __all__ = ["k_means", "KMeansOutput"] Init_t = Union[Literal["k-means++", "random"], - Callable[[pointcloud.PointCloud, int, jax.Array], jax.Array]] + Callable[[pointcloud.PointCloud, int, jnp.ndarray], jnp.ndarray]] class KPPState(NamedTuple): # noqa: D101 - rng: jax.Array - centroids: jax.Array - centroid_dists: jax.Array + rng: jnp.ndarray + centroids: jnp.ndarray + centroid_dists: jnp.ndarray class KMeansState(NamedTuple): # noqa: D101 - centroids: jax.Array - prev_assignment: jax.Array - assignment: jax.Array - errors: jax.Array + centroids: jnp.ndarray + prev_assignment: jnp.ndarray + assignment: jnp.ndarray + errors: jnp.ndarray center_shift: float class KMeansConst(NamedTuple): # noqa: D101 geom: pointcloud.PointCloud - x_weights: jax.Array + x_weights: jnp.ndarray @property - def x(self) -> jax.Array: + def x(self) -> jnp.ndarray: """Array of shape ``[n, ndim]`` containing the unweighted point cloud.""" return self.geom.x @@ -57,7 +57,7 @@ def weighted_x(self): return self.x_weights[:, :-1] @property - def weights(self) -> jax.Array: + def weights(self) -> jnp.ndarray: """Array of shape ``[n, 1]`` containing weights for each point.""" return self.x_weights[:, -1:] @@ -75,12 +75,12 @@ class KMeansOutput(NamedTuple): inner_errors: Array of shape ``[max_iterations,]`` containing the ``error`` at every iteration. 
""" - centroids: jax.Array - assignment: jax.Array + centroids: jnp.ndarray + assignment: jnp.ndarray converged: bool iteration: int error: float - inner_errors: Optional[jax.Array] + inner_errors: Optional[jnp.ndarray] @classmethod def _from_state( @@ -109,8 +109,8 @@ def _from_state( def _random_init( - geom: pointcloud.PointCloud, k: int, rng: jax.Array -) -> jax.Array: + geom: pointcloud.PointCloud, k: int, rng: jnp.ndarray +) -> jnp.ndarray: ixs = jnp.arange(geom.shape[0]) ixs = jax.random.choice(rng, ixs, shape=(k,), replace=False) return geom.subset(ixs, None).x @@ -119,11 +119,11 @@ def _random_init( def _k_means_plus_plus( geom: pointcloud.PointCloud, k: int, - rng: jax.Array, + rng: jnp.ndarray, n_local_trials: Optional[int] = None, -) -> jax.Array: +) -> jnp.ndarray: - def init_fn(geom: pointcloud.PointCloud, rng: jax.Array) -> KPPState: + def init_fn(geom: pointcloud.PointCloud, rng: jnp.ndarray) -> KPPState: rng, next_rng = jax.random.split(rng, 2) ix = jax.random.choice(rng, jnp.arange(geom.shape[0]), shape=()) centroids = jnp.full((k, geom.cost_rank), jnp.inf).at[0].set(geom.x[ix]) @@ -131,7 +131,7 @@ def init_fn(geom: pointcloud.PointCloud, rng: jax.Array) -> KPPState: return KPPState(rng=next_rng, centroids=centroids, centroid_dists=dists) def body_fn( - iteration: int, const: Tuple[pointcloud.PointCloud, jax.Array], + iteration: int, const: Tuple[pointcloud.PointCloud, jnp.ndarray], state: KPPState, compute_error: bool ) -> KPPState: del compute_error @@ -177,10 +177,10 @@ def body_fn( @functools.partial(jax.vmap, in_axes=[None, 0, 0, 0], out_axes=0) def _reallocate_centroids( const: KMeansConst, - ix: jax.Array, - centroid: jax.Array, - weight: jax.Array, -) -> Tuple[jax.Array, jax.Array]: + ix: jnp.ndarray, + centroid: jnp.ndarray, + weight: jnp.ndarray, +) -> Tuple[jnp.ndarray, jnp.ndarray]: is_empty = weight <= 0. 
new_centroid = (1 - is_empty) * centroid + is_empty * const.x[ix] # (ndim,) centroid_to_remove = is_empty * const.weighted_x[ix] # (ndim,) @@ -190,8 +190,8 @@ def _reallocate_centroids( def _update_assignment( const: KMeansConst, - centroids: jax.Array, -) -> Tuple[jax.Array, jax.Array]: + centroids: jnp.ndarray, +) -> Tuple[jnp.ndarray, jnp.ndarray]: (x, _, *args), aux_data = const.geom.tree_flatten() cost_matrix = type( const.geom @@ -203,9 +203,9 @@ def _update_assignment( def _update_centroids( - const: KMeansConst, k: int, assignment: jax.Array, - dist_to_centers: jax.Array -) -> jax.Array: + const: KMeansConst, k: int, assignment: jnp.ndarray, + dist_to_centers: jnp.ndarray +) -> jnp.ndarray: # TODO(michalk8): # cannot put `k` into `const`, see https://github.com/ott-jax/ott/issues/129 x_weights = jax.ops.segment_sum(const.x_weights, assignment, num_segments=k) @@ -224,10 +224,10 @@ def _update_centroids( @functools.partial(jax.vmap, in_axes=[0] + [None] * 9) def _k_means( - rng: jax.Array, + rng: jnp.ndarray, geom: pointcloud.PointCloud, k: int, - weights: Optional[jax.Array] = None, + weights: Optional[jnp.ndarray] = None, init: Init_t = "k-means++", n_local_trials: Optional[int] = None, tol: float = 1e-4, @@ -342,9 +342,9 @@ def finalize_fn(const: KMeansConst, state: KMeansState) -> KMeansState: def k_means( - geom: Union[jax.Array, pointcloud.PointCloud], + geom: Union[jnp.ndarray, pointcloud.PointCloud], k: int, - weights: Optional[jax.Array] = None, + weights: Optional[jnp.ndarray] = None, init: Init_t = "k-means++", n_init: int = 10, n_local_trials: Optional[int] = None, @@ -352,7 +352,7 @@ def k_means( min_iterations: int = 0, max_iterations: int = 300, store_inner_errors: bool = False, - rng: Optional[jax.Array] = None, + rng: Optional[jnp.ndarray] = None, ) -> KMeansOutput: r"""K-means clustering using Lloyd's algorithm :cite:`lloyd:82`. 
@@ -386,7 +386,7 @@ def k_means( """ assert geom.shape[ 0] >= k, f"Cannot cluster `{geom.shape[0]}` points into `{k}` clusters." - if isinstance(geom, jax.Array): + if isinstance(geom, jnp.ndarray): geom = pointcloud.PointCloud(geom) if isinstance(geom.cost_fn, costs.Cosine): geom = geom._cosine_to_sqeucl() diff --git a/src/ott/tools/plot.py b/src/ott/tools/plot.py index d83868fd5..bd1f42e91 100644 --- a/src/ott/tools/plot.py +++ b/src/ott/tools/plot.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import List, Optional, Sequence, Tuple, Union -import jax import jax.numpy as jnp import numpy as np import scipy @@ -33,7 +32,8 @@ gromov_wasserstein.GWOutput] -def bidimensional(x: jax.Array, y: jax.Array) -> Tuple[jax.Array, jax.Array]: +def bidimensional(x: jnp.ndarray, + y: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: """Apply PCA to reduce to bi-dimensional data.""" if x.shape[1] < 3: return x, y @@ -121,7 +121,7 @@ def _scatter(self, ot: Transport): scales_y = b * self._scale * b.shape[0] return x, y, scales_x, scales_y - def _mapping(self, x: jax.Array, y: jax.Array, matrix: jax.Array): + def _mapping(self, x: jnp.ndarray, y: jnp.ndarray, matrix: jnp.ndarray): """Compute the lines representing the mapping between the 2 point clouds.""" # Only plot the lines with a cost above the threshold. 
u, v = jnp.where(matrix > self._threshold) diff --git a/src/ott/tools/segment_sinkhorn.py b/src/ott/tools/segment_sinkhorn.py index ca5e5c228..223f2a30f 100644 --- a/src/ott/tools/segment_sinkhorn.py +++ b/src/ott/tools/segment_sinkhorn.py @@ -14,7 +14,7 @@ from types import MappingProxyType from typing import Any, Mapping, Optional, Tuple -import jax +import jax.numpy as jnp from ott.geometry import costs, pointcloud, segment from ott.problems.linear import linear_problem @@ -22,21 +22,21 @@ def segment_sinkhorn( - x: jax.Array, - y: jax.Array, + x: jnp.ndarray, + y: jnp.ndarray, num_segments: Optional[int] = None, max_measure_size: Optional[int] = None, cost_fn: Optional[costs.CostFn] = None, - segment_ids_x: Optional[jax.Array] = None, - segment_ids_y: Optional[jax.Array] = None, + segment_ids_x: Optional[jnp.ndarray] = None, + segment_ids_y: Optional[jnp.ndarray] = None, indices_are_sorted: bool = False, num_per_segment_x: Optional[Tuple[int, ...]] = None, num_per_segment_y: Optional[Tuple[int, ...]] = None, - weights_x: Optional[jax.Array] = None, - weights_y: Optional[jax.Array] = None, + weights_x: Optional[jnp.ndarray] = None, + weights_y: Optional[jnp.ndarray] = None, sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs: Any -) -> jax.Array: +) -> jnp.ndarray: """Compute regularized OT cost between subsets of vectors in `x` and `y`. Helper function designed to compute Sinkhorn regularized OT cost between @@ -104,10 +104,10 @@ def segment_sinkhorn( padding_vector = cost_fn._padder(dim=dim) def eval_fn( - padded_x: jax.Array, - padded_y: jax.Array, - padded_weight_x: jax.Array, - padded_weight_y: jax.Array, + padded_x: jnp.ndarray, + padded_y: jnp.ndarray, + padded_weight_x: jnp.ndarray, + padded_weight_y: jnp.ndarray, ) -> float: mask_x = padded_weight_x > 0. mask_y = padded_weight_y > 0. 
diff --git a/src/ott/tools/sinkhorn_divergence.py b/src/ott/tools/sinkhorn_divergence.py index 2ff1cbc4e..51de97613 100644 --- a/src/ott/tools/sinkhorn_divergence.py +++ b/src/ott/tools/sinkhorn_divergence.py @@ -14,7 +14,6 @@ from types import MappingProxyType from typing import Any, Mapping, Optional, Tuple, Type -import jax import jax.numpy as jnp from ott import utils @@ -28,7 +27,7 @@ "SinkhornDivergenceOutput" ] -Potentials_t = Tuple[jax.Array, jax.Array] +Potentials_t = Tuple[jnp.ndarray, jnp.ndarray] @utils.register_pytree_node @@ -36,10 +35,11 @@ class SinkhornDivergenceOutput: # noqa: D101 divergence: float potentials: Tuple[Potentials_t, Potentials_t, Potentials_t] geoms: Tuple[geometry.Geometry, geometry.Geometry, geometry.Geometry] - errors: Tuple[Optional[jax.Array], Optional[jax.Array], Optional[jax.Array]] + errors: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], + Optional[jnp.ndarray]] converged: Tuple[bool, bool, bool] - a: jax.Array - b: jax.Array + a: jnp.ndarray + b: jnp.ndarray n_iters: Tuple[int, int, int] def to_dual_potentials(self) -> "potentials.EntropicPotentials": @@ -73,8 +73,8 @@ def tree_unflatten_foo(cls, aux_data, children): # noqa: D102 def sinkhorn_divergence( geom: Type[geometry.Geometry], *args: Any, - a: Optional[jax.Array] = None, - b: Optional[jax.Array] = None, + a: Optional[jnp.ndarray] = None, + b: Optional[jnp.ndarray] = None, sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), static_b: bool = False, share_epsilon: bool = True, @@ -138,8 +138,8 @@ def _sinkhorn_divergence( geometry_xy: geometry.Geometry, geometry_xx: geometry.Geometry, geometry_yy: Optional[geometry.Geometry], - a: jax.Array, - b: jax.Array, + a: jnp.ndarray, + b: jnp.ndarray, symmetric_sinkhorn: bool, **kwargs: Any, ) -> SinkhornDivergenceOutput: @@ -155,9 +155,9 @@ def _sinkhorn_divergence( between elements of the view X. geometry_yy: a Cost object able to apply kernels with a certain epsilon, between elements of the view Y. 
- a: jax.Array[n]: the weight of each input point. The sum of + a: jnp.ndarray[n]: the weight of each input point. The sum of all elements of ``b`` must match that of ``a`` to converge. - b: jax.Array[m]: the weight of each target point. The sum of + b: jnp.ndarray[m]: the weight of each target point. The sum of all elements of ``b`` must match that of ``a`` to converge. symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for symmetric terms comparing x/x and y/y. @@ -219,24 +219,24 @@ def _sinkhorn_divergence( def segment_sinkhorn_divergence( - x: jax.Array, - y: jax.Array, + x: jnp.ndarray, + y: jnp.ndarray, num_segments: Optional[int] = None, max_measure_size: Optional[int] = None, cost_fn: Optional[costs.CostFn] = None, - segment_ids_x: Optional[jax.Array] = None, - segment_ids_y: Optional[jax.Array] = None, + segment_ids_x: Optional[jnp.ndarray] = None, + segment_ids_y: Optional[jnp.ndarray] = None, indices_are_sorted: bool = False, num_per_segment_x: Optional[Tuple[int, ...]] = None, num_per_segment_y: Optional[Tuple[int, ...]] = None, - weights_x: Optional[jax.Array] = None, - weights_y: Optional[jax.Array] = None, + weights_x: Optional[jnp.ndarray] = None, + weights_y: Optional[jnp.ndarray] = None, sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), static_b: bool = False, share_epsilon: bool = True, symmetric_sinkhorn: bool = False, **kwargs: Any -) -> jax.Array: +) -> jnp.ndarray: """Compute Sinkhorn divergence between subsets of vectors in `x` and `y`. Helper function designed to compute Sinkhorn divergences between several point @@ -313,10 +313,10 @@ def segment_sinkhorn_divergence( padding_vector = cost_fn._padder(dim=dim) def eval_fn( - padded_x: jax.Array, - padded_y: jax.Array, - padded_weight_x: jax.Array, - padded_weight_y: jax.Array, + padded_x: jnp.ndarray, + padded_y: jnp.ndarray, + padded_weight_x: jnp.ndarray, + padded_weight_y: jnp.ndarray, ) -> float: mask_x = padded_weight_x > 0. mask_y = padded_weight_y > 0. 
diff --git a/src/ott/tools/soft_sort.py b/src/ott/tools/soft_sort.py index b5b33e183..beb88365f 100644 --- a/src/ott/tools/soft_sort.py +++ b/src/ott/tools/soft_sort.py @@ -30,14 +30,14 @@ "quantize", "topk_mask", "multivariate_cdf_quantile_maps" ] -Func_t = Callable[[jax.Array], jax.Array] +Func_t = Callable[[jnp.ndarray], jnp.ndarray] def transport_for_sort( - inputs: jax.Array, - weights: Optional[jax.Array] = None, - target_weights: Optional[jax.Array] = None, - squashing_fun: Optional[Callable[[jax.Array], jax.Array]] = None, + inputs: jnp.ndarray, + weights: Optional[jnp.ndarray] = None, + target_weights: Optional[jnp.ndarray] = None, + squashing_fun: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, epsilon: float = 1e-2, **kwargs: Any, ) -> sinkhorn.SinkhornOutput: @@ -83,7 +83,7 @@ def transport_for_sort( return solver(prob) -def apply_on_axis(op, inputs, axis, *args, **kwargs: Any) -> jax.Array: +def apply_on_axis(op, inputs, axis, *args, **kwargs: Any) -> jnp.ndarray: """Apply a differentiable operator on a given axis of the input. Args: @@ -120,8 +120,8 @@ def apply_on_axis(op, inputs, axis, *args, **kwargs: Any) -> jax.Array: def _sort( - inputs: jax.Array, topk: int, num_targets: Optional[int], **kwargs: Any -) -> jax.Array: + inputs: jnp.ndarray, topk: int, num_targets: Optional[int], **kwargs: Any +) -> jnp.ndarray: """Apply the soft sort operator on a one dimensional array.""" num_points = inputs.shape[0] a = jnp.ones((num_points,)) / num_points @@ -145,12 +145,12 @@ def _sort( def sort( - inputs: jax.Array, + inputs: jnp.ndarray, axis: int = -1, topk: int = -1, num_targets: Optional[int] = None, **kwargs: Any, -) -> jax.Array: +) -> jnp.ndarray: r"""Apply the soft sort operator on a given axis of the input. 
For instance: @@ -203,8 +203,8 @@ def sort( def _ranks( - inputs: jax.Array, num_targets, target_weights, **kwargs: Any -) -> jax.Array: + inputs: jnp.ndarray, num_targets, target_weights, **kwargs: Any +) -> jnp.ndarray: """Apply the soft ranks operator on a one dimensional array.""" num_points = inputs.shape[0] if target_weights is None: @@ -220,12 +220,12 @@ def _ranks( def ranks( - inputs: jax.Array, + inputs: jnp.ndarray, axis: int = -1, num_targets: Optional[int] = None, - target_weights: Optional[jax.Array] = None, + target_weights: Optional[jnp.ndarray] = None, **kwargs: Any, -) -> jax.Array: +) -> jnp.ndarray: r"""Apply the soft rank operator on input tensor. For instance: @@ -278,11 +278,11 @@ def ranks( def topk_mask( - inputs: jax.Array, + inputs: jnp.ndarray, axis: int = -1, k: int = 1, **kwargs: Any, -) -> jax.Array: +) -> jnp.ndarray: r"""Soft :math:`\text{top-}k` selection mask. For instance: @@ -337,12 +337,12 @@ def topk_mask( def quantile( - inputs: jax.Array, - q: Optional[Union[float, jax.Array]], + inputs: jnp.ndarray, + q: Optional[Union[float, jnp.ndarray]], axis: Union[int, Tuple[int, ...]] = -1, - weight: Optional[Union[float, jax.Array]] = None, + weight: Optional[Union[float, jnp.ndarray]] = None, **kwargs: Any, -) -> jax.Array: +) -> jnp.ndarray: r"""Apply the soft quantiles operator on the input tensor. 
For instance: @@ -395,8 +395,8 @@ def quantile( """ def _quantile( - inputs: jax.Array, q: float, weight: float, **kwargs - ) -> jax.Array: + inputs: jnp.ndarray, q: float, weight: float, **kwargs + ) -> jnp.ndarray: num_points = inputs.shape[0] q = jnp.array([0.2, 0.5, 0.8]) if q is None else jnp.atleast_1d(q) num_quantiles = q.shape[0] @@ -456,15 +456,15 @@ def _quantile( def multivariate_cdf_quantile_maps( - inputs: jax.Array, - target_sampler: Optional[Callable[[jax.Array, Tuple[int, int]], - jax.Array]] = None, - rng: Optional[jax.Array] = None, + inputs: jnp.ndarray, + target_sampler: Optional[Callable[[jnp.ndarray, Tuple[int, int]], + jnp.ndarray]] = None, + rng: Optional[jnp.ndarray] = None, num_target_samples: Optional[int] = None, cost_fn: Optional[costs.CostFn] = None, epsilon: Optional[float] = None, - input_weights: Optional[jax.Array] = None, - target_weights: Optional[jax.Array] = None, + input_weights: Optional[jnp.ndarray] = None, + target_weights: Optional[jnp.ndarray] = None, **kwargs: Any ) -> Tuple[Func_t, Func_t]: r"""Returns multivariate CDF and quantile maps, given input samples. @@ -534,8 +534,8 @@ def multivariate_cdf_quantile_maps( def _quantile_normalization( - inputs: jax.Array, targets: jax.Array, weights: float, **kwargs: Any -) -> jax.Array: + inputs: jnp.ndarray, targets: jnp.ndarray, weights: float, **kwargs: Any +) -> jnp.ndarray: """Apply soft quantile normalization on a one dimensional array.""" num_points = inputs.shape[0] a = jnp.ones((num_points,)) / num_points @@ -544,12 +544,12 @@ def _quantile_normalization( def quantile_normalization( - inputs: jax.Array, - targets: jax.Array, - weights: Optional[jax.Array] = None, + inputs: jnp.ndarray, + targets: jnp.ndarray, + weights: Optional[jnp.ndarray] = None, axis: int = -1, **kwargs: Any, -) -> jax.Array: +) -> jnp.ndarray: r"""Re-normalize inputs so that its quantiles match those of targets/weights. 
Quantile normalization rearranges the values in inputs to values that match @@ -600,11 +600,11 @@ def quantile_normalization( def sort_with( - inputs: jax.Array, - criterion: jax.Array, + inputs: jnp.ndarray, + criterion: jnp.ndarray, topk: int = -1, **kwargs: Any, -) -> jax.Array: +) -> jnp.ndarray: r"""Sort a multidimensional array according to a real valued criterion. Given ``batch`` vectors of dimension `dim`, to which, for each, a real value @@ -655,7 +655,7 @@ def sort_with( return sort_fn(inputs) -def _quantize(inputs: jax.Array, num_q: int, **kwargs: Any) -> jax.Array: +def _quantize(inputs: jnp.ndarray, num_q: int, **kwargs: Any) -> jnp.ndarray: """Apply the soft quantization operator on a one dimensional array.""" num_points = inputs.shape[0] a = jnp.ones((num_points,)) / num_points @@ -665,11 +665,11 @@ def _quantize(inputs: jax.Array, num_q: int, **kwargs: Any) -> jax.Array: def quantize( - inputs: jax.Array, + inputs: jnp.ndarray, num_levels: int = 10, axis: int = -1, **kwargs: Any, -) -> jax.Array: +) -> jnp.ndarray: r"""Soft quantizes an input according using ``num_levels`` values along axis. The quantization operator consists in concentrating several values around diff --git a/src/ott/types.py b/src/ott/types.py index 5c4609ec2..7a4c88716 100644 --- a/src/ott/types.py +++ b/src/ott/types.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Protocol -import jax +import jax.numpy as jnp __all__ = ["Transport"] @@ -28,11 +28,11 @@ class can however be used in type hints to support duck typing. """ @property - def matrix(self) -> jax.Array: + def matrix(self) -> jnp.ndarray: ... - def apply(self, inputs: jax.Array, axis: int) -> jax.Array: + def apply(self, inputs: jnp.ndarray, axis: int) -> jnp.ndarray: ... - def marginal(self, axis: int = 0) -> jax.Array: + def marginal(self, axis: int = 0) -> jnp.ndarray: ... 
diff --git a/src/ott/utils.py b/src/ott/utils.py index 558f4ba1c..2acfd8420 100644 --- a/src/ott/utils.py +++ b/src/ott/utils.py @@ -18,6 +18,7 @@ from typing import Any, Callable, NamedTuple, Optional, Tuple import jax +import jax.numpy as jnp import numpy as np try: @@ -68,7 +69,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return functools.wraps(func)(wrapper) -def default_prng_key(rng: Optional[jax.Array] = None) -> jax.Array: +def default_prng_key(rng: Optional[jnp.ndarray] = None) -> jnp.ndarray: """Get the default PRNG key. Args: diff --git a/tests/conftest.py b/tests/conftest.py index a8118845c..bc4570343 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,7 @@ import jax import jax.experimental +import jax.numpy as jnp import pytest from _pytest.python import Metafunc @@ -68,7 +69,7 @@ def pytest_generate_tests(metafunc: Metafunc) -> None: @pytest.fixture(scope="session") -def rng() -> jax.Array: +def rng() -> jnp.ndarray: return jax.random.PRNGKey(0) diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index 47446a4fd..b23e79071 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -27,7 +27,7 @@ ts_metrics = None -def _proj(matrix: jax.Array) -> jax.Array: +def _proj(matrix: jnp.ndarray) -> jnp.ndarray: u, _, v_h = jnp.linalg.svd(matrix, full_matrices=False) return u.dot(v_h) @@ -35,7 +35,7 @@ def _proj(matrix: jax.Array) -> jax.Array: @pytest.mark.fast() class TestCostFn: - def test_cosine(self, rng: jax.Array): + def test_cosine(self, rng: jnp.ndarray): """Test the cosine cost function.""" x = jnp.array([0, 0]) y = jnp.array([0, 0]) @@ -84,7 +84,7 @@ def test_cosine(self, rng: jax.Array): @pytest.mark.fast() class TestBuresBarycenter: - def test_bures(self, rng: jax.Array): + def test_bures(self, rng: jnp.ndarray): d = 3 r = jnp.array([1.2036, 0.2825, 0.013]) Sigma1 = r * jnp.eye(d) @@ -141,7 +141,7 @@ class TestRegTICost: ) def test_reg_cost_legendre( self, - rng: jax.Array, + rng: 
jnp.ndarray, scaling_reg: float, cost_fn_t: Type[costs.RegTICost], use_mat: bool, @@ -163,7 +163,7 @@ def test_reg_cost_legendre( @pytest.mark.parametrize("k", [1, 3, 10]) @pytest.mark.parametrize("d", [10, 50]) - def test_elastic_sq_k_overlap(self, rng: jax.Array, k: int, d: int): + def test_elastic_sq_k_overlap(self, rng: jnp.ndarray, k: int, d: int): expected = jax.random.normal(rng, (d,)) cost_fn = costs.ElasticSqKOverlap(k=k, scaling_reg=1e-2) @@ -178,7 +178,9 @@ def test_elastic_sq_k_overlap(self, rng: jax.Array, k: int, d: int): costs.ElasticSqKOverlap(k=3, scaling_reg=17) ] ) - def test_sparse_displacement(self, rng: jax.Array, cost_fn: costs.RegTICost): + def test_sparse_displacement( + self, rng: jnp.ndarray, cost_fn: costs.RegTICost + ): frac_sparse = 0.7 rng1, rng2 = jax.random.split(rng, 2) d = 17 @@ -194,7 +196,7 @@ def test_sparse_displacement(self, rng: jax.Array, cost_fn: costs.RegTICost): @pytest.mark.parametrize("cost_type_t", [costs.ElasticL1, costs.ElasticSTVS]) def test_stronger_regularization_increases_sparsity( - self, rng: jax.Array, cost_type_t: Type[costs.RegTICost] + self, rng: jnp.ndarray, cost_type_t: Type[costs.RegTICost] ): d, rngs = 17, jax.random.split(rng, 4) x = jax.random.normal(rngs[0], (50, d)) @@ -223,7 +225,7 @@ class TestSoftDTW: @pytest.mark.parametrize("n", [7, 10]) @pytest.mark.parametrize("m", [9, 10]) @pytest.mark.parametrize("gamma", [1e-3, 5]) - def test_soft_dtw(self, rng: jax.Array, n: int, m: int, gamma: float): + def test_soft_dtw(self, rng: jnp.ndarray, n: int, m: int, gamma: float): rng1, rng2 = jax.random.split(rng, 2) t1 = jax.random.normal(rng1, (n,)) t2 = jax.random.normal(rng2, (m,)) @@ -236,7 +238,7 @@ def test_soft_dtw(self, rng: jax.Array, n: int, m: int, gamma: float): @pytest.mark.parametrize(("debiased", "jit"), [(False, True), (True, False)]) def test_soft_dtw_debiased( self, - rng: jax.Array, + rng: jnp.ndarray, debiased: bool, jit: bool, ): @@ -263,7 +265,7 @@ def test_soft_dtw_debiased( 
@pytest.mark.parametrize(("debiased", "jit"), [(False, False), (True, True)]) @pytest.mark.parametrize("gamma", [1e-2, 1]) def test_soft_dtw_grad( - self, rng: jax.Array, debiased: bool, jit: bool, gamma: float + self, rng: jnp.ndarray, debiased: bool, jit: bool, gamma: float ): rngs = jax.random.split(rng, 4) eps, tol = 1e-3, 1e-5 diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py index cda2900a8..b0c194c23 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -35,7 +35,7 @@ def random_graph( *, return_laplacian: bool = False, directed: bool = False, -) -> jax.Array: +) -> jnp.ndarray: G = random_graphs.fast_gnp_random_graph(n, p, seed=seed, directed=directed) if not directed: assert nx.is_connected(G), "Generated graph is not connected." @@ -51,7 +51,7 @@ def random_graph( return jnp.asarray(G.toarray()) -def gt_geometry(G: jax.Array, *, epsilon: float = 1e-2) -> geometry.Geometry: +def gt_geometry(G: jnp.ndarray, *, epsilon: float = 1e-2) -> geometry.Geometry: if not isinstance(G, nx.Graph): G = nx.from_numpy_array(np.asarray(G)) @@ -72,7 +72,7 @@ def gt_geometry(G: jax.Array, *, epsilon: float = 1e-2) -> geometry.Geometry: class TestGraph: - def test_kernel_is_symmetric_positive_definite(self, rng: jax.Array): + def test_kernel_is_symmetric_positive_definite(self, rng: jnp.ndarray): n, tol = 65, 0.02 x = jax.random.normal(rng, (n,)) geom = graph.Graph.from_graph(random_graph(n), t=1e-3) @@ -109,7 +109,7 @@ def test_automatic_t(self): ) def test_approximates_ground_truth( self, - rng: jax.Array, + rng: jnp.ndarray, numerical_scheme: Literal["backward_euler", "crank_nicolson"], ): eps, n_steps = 1e-5, 20 @@ -160,7 +160,7 @@ def test_crank_nicolson_more_stable(self, t: Optional[float], n_steps: int): @pytest.mark.parametrize(("jit", "normalize"), [(False, True), (True, False)]) def test_directed_graph(self, jit: bool, normalize: bool): - def create_graph(G: jax.Array) -> graph.Graph: + def create_graph(G: 
jnp.ndarray) -> graph.Graph: return graph.Graph.from_graph(G, directed=True, normalize=normalize) G = random_graph(16, p=0.25, directed=True) @@ -181,7 +181,7 @@ def create_graph(G: jax.Array) -> graph.Graph: @pytest.mark.parametrize("normalize", [False, True]) def test_normalize_laplacian(self, directed: bool, normalize: bool): - def laplacian(G: jax.Array) -> jax.Array: + def laplacian(G: jnp.ndarray) -> jnp.ndarray: if directed: G = G + G.T @@ -203,7 +203,7 @@ def laplacian(G: jax.Array) -> jax.Array: np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) @pytest.mark.fast.with_args(jit=[False, True], only_fast=0) - def test_graph_sinkhorn(self, rng: jax.Array, jit: bool): + def test_graph_sinkhorn(self, rng: jnp.ndarray, jit: bool): def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: solver = sinkhorn.Sinkhorn(lse_mode=False) @@ -246,12 +246,12 @@ def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: ids=["not-implicit", "implicit"], ) def test_dense_graph_differentiability( - self, rng: jax.Array, implicit_diff: bool + self, rng: jnp.ndarray, implicit_diff: bool ): def callback( - data: jax.Array, rows: jax.Array, cols: jax.Array, shape: Tuple[int, - int] + data: jnp.ndarray, rows: jnp.ndarray, cols: jnp.ndarray, + shape: Tuple[int, int] ) -> float: G = sparse.BCOO((data, jnp.c_[rows, cols]), shape=shape).todense() @@ -281,7 +281,7 @@ def callback( actual = 2 * jnp.vdot(v_w, grad_w) np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4) - def test_tolerance_hilbert_metric(self, rng: jax.Array): + def test_tolerance_hilbert_metric(self, rng: jnp.ndarray): n, n_steps, t, tol = 256, 1000, 1e-4, 3e-4 G = random_graph(n, p=0.15) x = jnp.abs(jax.random.normal(rng, (n,))) diff --git a/tests/geometry/low_rank_test.py b/tests/geometry/low_rank_test.py index 6b3c36edd..b3cda89cf 100644 --- a/tests/geometry/low_rank_test.py +++ b/tests/geometry/low_rank_test.py @@ -24,7 +24,7 @@ @pytest.mark.fast() class TestLRGeometry: - 
def test_apply(self, rng: jax.Array): + def test_apply(self, rng: jnp.ndarray): """Test application of cost to vec or matrix.""" n, m, r = 17, 11, 7 rngs = jax.random.split(rng, 5) @@ -45,7 +45,7 @@ def test_apply(self, rng: jax.Array): @pytest.mark.parametrize("scale_cost", ["mean", "max_cost", "max_bound", 42.]) def test_conversion_pointcloud( - self, rng: jax.Array, scale_cost: Union[str, float] + self, rng: jnp.ndarray, scale_cost: Union[str, float] ): """Test conversion from PointCloud to LRCGeometry.""" n, m, d = 17, 11, 3 @@ -69,7 +69,7 @@ def test_conversion_pointcloud( rtol=1e-4 ) - def test_apply_squared(self, rng: jax.Array): + def test_apply_squared(self, rng: jnp.ndarray): """Test application of squared cost to vec or matrix.""" n, m = 27, 25 rngs = jax.random.split(rng, 5) @@ -94,7 +94,7 @@ def test_apply_squared(self, rng: jax.Array): @pytest.mark.parametrize("bias", [(0, 0), (4, 5)]) @pytest.mark.parametrize("scale_factor", [(1, 1), (2, 3)]) def test_add_lr_geoms( - self, rng: jax.Array, bias: Tuple[float, float], + self, rng: jnp.ndarray, bias: Tuple[float, float], scale_factor: Tuple[float, float] ): """Test application of cost to vec or matrix.""" @@ -133,7 +133,7 @@ def test_add_lr_geoms( @pytest.mark.parametrize(("scale", "scale_cost", "epsilon"), [(0.1, "mean", None), (0.9, "max_cost", 1e-2)]) def test_add_lr_geoms_scale_factor( - self, rng: jax.Array, scale: float, scale_cost: str, + self, rng: jnp.ndarray, scale: float, scale_cost: str, epsilon: Optional[float] ): n, d = 71, 2 @@ -160,7 +160,8 @@ def test_add_lr_geoms_scale_factor( @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize("fn", [lambda x: x + 10, lambda x: x * 2]) def test_apply_affine_function_efficient( - self, rng: jax.Array, fn: Callable[[jax.Array], jax.Array], axis: int + self, rng: jnp.ndarray, fn: Callable[[jnp.ndarray], jnp.ndarray], + axis: int ): n, m, d = 21, 13, 3 rngs = jax.random.split(rng, 3) @@ -180,7 +181,7 @@ def 
test_apply_affine_function_efficient( np.testing.assert_allclose(res_ineff, res_eff, rtol=1e-4, atol=1e-4) @pytest.mark.parametrize("rank", [5, 1000]) - def test_point_cloud_to_lr(self, rng: jax.Array, rank: int): + def test_point_cloud_to_lr(self, rng: jnp.ndarray, rank: int): n, m = 1500, 1000 scale = 2.0 rngs = jax.random.split(rng, 2) @@ -220,7 +221,7 @@ def assert_upper_bound( assert lhs <= rhs @pytest.mark.fast.with_args(rank=[2, 3], tol=[5e-1, 1e-2], only_fast=0) - def test_geometry_to_lr(self, rng: jax.Array, rank: int, tol: float): + def test_geometry_to_lr(self, rng: jnp.ndarray, rank: int, tol: float): rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, shape=(370, 3)) y = jax.random.normal(rng2, shape=(460, 3)) @@ -241,7 +242,8 @@ def test_geometry_to_lr(self, rng: jax.Array, rank: int, tol: float): only_fast=1 ) def test_point_cloud_to_lr( - self, rng: jax.Array, batch_size: Optional[int], scale_cost: Optional[str] + self, rng: jnp.ndarray, batch_size: Optional[int], + scale_cost: Optional[str] ): rank, tol = 7, 1e-1 rng1, rng2 = jax.random.split(rng, 2) @@ -265,7 +267,7 @@ def test_point_cloud_to_lr( assert geom_lr.cost_rank == rank self.assert_upper_bound(geom, geom_lr, rank=rank, tol=tol) - def test_to_lrc_geometry_noop(self, rng: jax.Array): + def test_to_lrc_geometry_noop(self, rng: jnp.ndarray): rng1, rng2 = jax.random.split(rng, 2) cost1 = jax.random.normal(rng1, shape=(32, 2)) cost2 = jax.random.normal(rng2, shape=(23, 2)) @@ -287,7 +289,7 @@ def test_apply_transport_from_potentials(self): np.testing.assert_allclose(res, 1.1253539e-07, rtol=1e-6, atol=1e-6) @pytest.mark.limit_memory("190 MB") - def test_large_scale_factorization(self, rng: jax.Array): + def test_large_scale_factorization(self, rng: jnp.ndarray): rank, tol = 4, 1e-2 rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, shape=(10_000, 7)) @@ -318,7 +320,7 @@ def test_conversion_grid(self): cost_matrix, cost_matrix_lrc, rtol=1e-5, atol=1e-5 ) - def 
test_full_to_lrc_geometry(self, rng: jax.Array): + def test_full_to_lrc_geometry(self, rng: jnp.ndarray): rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, shape=(13, 7)) y = jax.random.normal(rng2, shape=(29, 7)) diff --git a/tests/geometry/pointcloud_test.py b/tests/geometry/pointcloud_test.py index 5f75ddb8e..ff32789fe 100644 --- a/tests/geometry/pointcloud_test.py +++ b/tests/geometry/pointcloud_test.py @@ -24,7 +24,7 @@ @pytest.mark.fast() class TestPointCloudApply: - def test_apply_cost_and_kernel(self, rng: jax.Array): + def test_apply_cost_and_kernel(self, rng: jnp.ndarray): """Test consistency of cost/kernel apply to vec.""" n, m, p, b = 5, 8, 10, 7 rngs = jax.random.split(rng, 5) @@ -68,7 +68,7 @@ def test_apply_cost_and_kernel(self, rng: jax.Array): np.testing.assert_allclose(prod0_online, prod0, rtol=1e-03, atol=1e-02) np.testing.assert_allclose(prod1_online, prod1, rtol=1e-03, atol=1e-02) - def test_general_cost_fn(self, rng: jax.Array): + def test_general_cost_fn(self, rng: jnp.ndarray): """Test non-vec cost apply to vec.""" n, m, p, b = 5, 8, 10, 7 rngs = jax.random.split(rng, 5) @@ -97,7 +97,7 @@ def test_correct_shape(self): np.testing.assert_array_equal(pc.shape, (n, m)) @pytest.mark.parametrize("axis", [0, 1]) - def test_apply_cost_without_norm(self, rng: jax.Array, axis: 1): + def test_apply_cost_without_norm(self, rng: jnp.ndarray, axis: 1): rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, shape=(17, 3)) y = jax.random.normal(rng2, shape=(12, 3)) @@ -122,7 +122,7 @@ class TestPointCloudCosineConversion: "scale_cost", ["mean", "median", "max_cost", "max_norm", 41] ) def test_cosine_to_sqeucl_conversion( - self, rng: jax.Array, scale_cost: Union[str, float] + self, rng: jnp.ndarray, scale_cost: Union[str, float] ): rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, shape=(101, 4)) @@ -155,7 +155,7 @@ def test_cosine_to_sqeucl_conversion( ) @pytest.mark.parametrize("axis", [0, 1]) def 
test_apply_cost_cosine_to_sqeucl( - self, rng: jax.Array, axis: int, scale_cost: Union[str, float] + self, rng: jnp.ndarray, axis: int, scale_cost: Union[str, float] ): rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, shape=(17, 5)) diff --git a/tests/geometry/scaling_cost_test.py b/tests/geometry/scaling_cost_test.py index b60805d34..ce3f616ce 100644 --- a/tests/geometry/scaling_cost_test.py +++ b/tests/geometry/scaling_cost_test.py @@ -26,7 +26,7 @@ class TestScaleCost: @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): self.dim = 4 self.n = 7 self.m = 9 @@ -53,7 +53,7 @@ def test_scale_cost_pointcloud( """Test various scale cost options for pointcloud.""" def apply_sinkhorn( - x: jax.Array, y: jax.Array, a: jax.Array, b: jax.Array, + x: jnp.ndarray, y: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray, scale_cost: Union[str, float] ): geom = pointcloud.PointCloud( @@ -120,8 +120,8 @@ def test_scale_cost_geometry(self, scale: Union[str, float]): """Test various scale cost options for geometry.""" def apply_sinkhorn( - cost: jax.Array, a: jax.Array, b: jax.Array, scale_cost: Union[str, - float] + cost: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray, + scale_cost: Union[str, float] ): geom = geometry.Geometry(cost, epsilon=self.eps, scale_cost=scale_cost) prob = linear_problem.LinearProblem(geom, a, b) diff --git a/tests/geometry/subsetting_test.py b/tests/geometry/subsetting_test.py index c07929436..5d7306682 100644 --- a/tests/geometry/subsetting_test.py +++ b/tests/geometry/subsetting_test.py @@ -25,7 +25,7 @@ @pytest.fixture() def pc_masked( - rng: jax.Array + rng: jnp.ndarray ) -> Tuple[pointcloud.PointCloud, pointcloud.PointCloud]: n, m = 20, 30 rng1, rng2 = jax.random.split(rng, 2) @@ -66,7 +66,7 @@ class TestMaskPointCloud: "clazz", [geometry.Geometry, pointcloud.PointCloud, low_rank.LRCGeometry] ) def test_mask( - self, rng: jax.Array, clazz: Type[geometry.Geometry], + self, rng: 
jnp.ndarray, clazz: Type[geometry.Geometry], src_ixs: Optional[Union[int, Sequence[int]]], tgt_ixs: Optional[Union[int, Sequence[int]]] ): @@ -140,7 +140,7 @@ def test_masked_summary( ) def test_mask_permutation( - self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], rng: jax.Array + self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], rng: jnp.ndarray ): rng1, rng2 = jax.random.split(rng) geom, _ = geom_masked @@ -162,7 +162,7 @@ def test_mask_permutation( ) def test_boolean_mask( - self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], rng: jax.Array + self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], rng: jnp.ndarray ): rng1, rng2 = jax.random.split(rng) p = jnp.array([0.5, 0.5]) diff --git a/tests/initializers/linear/sinkhorn_init_test.py b/tests/initializers/linear/sinkhorn_init_test.py index 73c0ddaaa..7686ddfa9 100644 --- a/tests/initializers/linear/sinkhorn_init_test.py +++ b/tests/initializers/linear/sinkhorn_init_test.py @@ -25,7 +25,7 @@ def create_sorting_problem( - rng: jax.Array, + rng: jnp.ndarray, n: int, epsilon: float = 1e-2, batch_size: Optional[int] = None @@ -55,7 +55,7 @@ def create_sorting_problem( def create_ot_problem( - rng: jax.Array, + rng: jnp.ndarray, n: int, m: int, d: int, @@ -80,12 +80,12 @@ def create_ot_problem( def run_sinkhorn( - x: jax.Array, - y: jax.Array, + x: jnp.ndarray, + y: jnp.ndarray, *, initializer: linear_init.SinkhornInitializer, - a: Optional[jax.Array] = None, - b: Optional[jax.Array] = None, + a: Optional[jnp.ndarray] = None, + b: Optional[jnp.ndarray] = None, epsilon: float = 1e-2, lse_mode: bool = True, ) -> sinkhorn.SinkhornOutput: @@ -132,7 +132,9 @@ def test_create_initializer(self, init: str): @pytest.mark.parametrize(("vector_min", "lse_mode"), [(True, True), (True, False), (False, True)]) - def test_sorting_init(self, vector_min: bool, lse_mode: bool, rng: jax.Array): + def test_sorting_init( + self, vector_min: bool, lse_mode: bool, rng: jnp.ndarray + ): """Tests sorting dual 
initializer.""" n = 50 epsilon = 1e-2 @@ -166,7 +168,7 @@ def test_sorting_init(self, vector_min: bool, lse_mode: bool, rng: jax.Array): assert sink_out_init.converged assert sink_out_base.n_iters > sink_out_init.n_iters - def test_sorting_init_online(self, rng: jax.Array): + def test_sorting_init_online(self, rng: jnp.ndarray): n = 10 epsilon = 1e-2 @@ -177,7 +179,7 @@ def test_sorting_init_online(self, rng: jax.Array): with pytest.raises(AssertionError, match=r"online"): sort_init.init_dual_a(ot_problem, lse_mode=True) - def test_sorting_init_square_cost(self, rng: jax.Array): + def test_sorting_init_square_cost(self, rng: jnp.ndarray): n, m, d = 10, 15, 1 epsilon = 1e-2 @@ -186,7 +188,7 @@ def test_sorting_init_square_cost(self, rng: jax.Array): with pytest.raises(AssertionError, match=r"square"): sort_init.init_dual_a(ot_problem, lse_mode=True) - def test_default_initializer(self, rng: jax.Array): + def test_default_initializer(self, rng: jnp.ndarray): """Tests default initializer""" n, m, d = 20, 20, 2 epsilon = 1e-2 @@ -204,7 +206,7 @@ def test_default_initializer(self, rng: jax.Array): np.testing.assert_array_equal(0., default_potential_a) np.testing.assert_array_equal(0., default_potential_b) - def test_gauss_pointcloud_geom(self, rng: jax.Array): + def test_gauss_pointcloud_geom(self, rng: jnp.ndarray): n, m, d = 20, 20, 2 epsilon = 1e-2 @@ -225,7 +227,7 @@ def test_gauss_pointcloud_geom(self, rng: jax.Array): @pytest.mark.parametrize("jit", [False, True]) @pytest.mark.parametrize("initializer", ["sorting", "gaussian", "subsample"]) def test_initializer_n_iter( - self, rng: jax.Array, lse_mode: bool, jit: bool, + self, rng: jnp.ndarray, lse_mode: bool, jit: bool, initializer: Literal["sorting", "gaussian", "subsample"] ): """Tests Gaussian initializer""" diff --git a/tests/initializers/linear/sinkhorn_lr_init_test.py b/tests/initializers/linear/sinkhorn_lr_init_test.py index f3fe7acd1..e954fec76 100644 --- 
a/tests/initializers/linear/sinkhorn_lr_init_test.py +++ b/tests/initializers/linear/sinkhorn_lr_init_test.py @@ -36,7 +36,7 @@ def test_explicit_initializer(self): ) @pytest.mark.parametrize("partial_init", ["q", "r", "g"]) def test_partial_initialization( - self, rng: jax.Array, initializer: str, partial_init: str + self, rng: jnp.ndarray, initializer: str, partial_init: str ): n, d, rank = 27, 5, 6 rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) @@ -64,7 +64,7 @@ def test_partial_initialization( @pytest.mark.fast.with_args("rank", [2, 4, 10, 13], only_fast=True) def test_generalized_k_means_has_correct_rank( - self, rng: jax.Array, rank: int + self, rng: jnp.ndarray, rank: int ): n, d = 27, 5 x = jax.random.normal(rng, (n, d)) @@ -81,7 +81,7 @@ def test_generalized_k_means_has_correct_rank( assert jnp.linalg.matrix_rank(q) == rank assert jnp.linalg.matrix_rank(r) == rank - def test_generalized_k_means_matches_k_means(self, rng: jax.Array): + def test_generalized_k_means_matches_k_means(self, rng: jnp.ndarray): n, d, rank = 27, 7, 5 eps = 1e-1 rng1, rng2 = jax.random.split(rng, 2) @@ -111,7 +111,7 @@ def test_generalized_k_means_matches_k_means(self, rng: jax.Array): ) @pytest.mark.parametrize("epsilon", [0., 1e-1]) - def test_better_initialization_helps(self, rng: jax.Array, epsilon: float): + def test_better_initialization_helps(self, rng: jnp.ndarray, epsilon: float): n, d, rank = 81, 13, 3 rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, (n, d)) diff --git a/tests/initializers/quadratic/gw_init_test.py b/tests/initializers/quadratic/gw_init_test.py index 4c39bafb4..e680e9c01 100644 --- a/tests/initializers/quadratic/gw_init_test.py +++ b/tests/initializers/quadratic/gw_init_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import jax +import jax.numpy as jnp import numpy as np import pytest @@ -49,7 +50,7 @@ def test_explicit_initializer_lr(self): assert solver.initializer.rank == rank @pytest.mark.parametrize("eps", [0., 1e-2]) - def test_gw_better_initialization_helps(self, rng: jax.Array, eps: float): + def test_gw_better_initialization_helps(self, rng: jnp.ndarray, eps: float): n, m, d1, d2, rank = 83, 84, 8, 6, 4 rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) diff --git a/tests/math/lse_test.py b/tests/math/lse_test.py index b842afe21..36e7eba7f 100644 --- a/tests/math/lse_test.py +++ b/tests/math/lse_test.py @@ -22,7 +22,7 @@ @pytest.mark.fast() class TestGeometryLse: - def test_lse(self, rng: jax.Array): + def test_lse(self, rng: jnp.ndarray): """Test consistency of custom lse's jvp.""" n, m = 12, 8 rngs = jax.random.split(rng, 5) diff --git a/tests/math/math_utils_test.py b/tests/math/math_utils_test.py index b8451355b..7848bc2a9 100644 --- a/tests/math/math_utils_test.py +++ b/tests/math/math_utils_test.py @@ -26,7 +26,7 @@ class TestNorm: @pytest.mark.parametrize("ord", [1.1, 2.0, jnp.inf]) def test_norm( self, - rng: jax.Array, + rng: jnp.ndarray, ord, ): d = 5 diff --git a/tests/math/matrix_square_root_test.py b/tests/math/matrix_square_root_test.py index 2263ea8b9..7c8a1e7d5 100644 --- a/tests/math/matrix_square_root_test.py +++ b/tests/math/matrix_square_root_test.py @@ -21,7 +21,7 @@ from ott.math import matrix_square_root -def _get_random_spd_matrix(dim: int, rng: jax.Array): +def _get_random_spd_matrix(dim: int, rng: jnp.ndarray): # Get a random symmetric, positive definite matrix of a specified size. 
rng, subrng0, subrng1 = jax.random.split(rng, num=3) @@ -37,9 +37,9 @@ def _get_random_spd_matrix(dim: int, rng: jax.Array): def _get_test_fn( - fn: Callable[[jax.Array], jax.Array], dim: int, rng: jax.Array, + fn: Callable[[jnp.ndarray], jnp.ndarray], dim: int, rng: jnp.ndarray, **kwargs: Any -) -> Callable[[jax.Array], jax.Array]: +) -> Callable[[jnp.ndarray], jnp.ndarray]: # We want to test gradients of a function fn that maps positive definite # matrices to positive definite matrices by comparing them to finite # difference approximations. We'll do so via a test function that @@ -54,7 +54,7 @@ def _get_test_fn( unit = jax.random.normal(key=subrng3, shape=(dim, dim)) unit /= jnp.sqrt(jnp.sum(unit ** 2.)) - def _test_fn(x: jax.Array, **kwargs: Any) -> jax.Array: + def _test_fn(x: jnp.ndarray, **kwargs: Any) -> jnp.ndarray: # m is the product of 2 symmetric, positive definite matrices # so it will be positive definite but not necessarily symmetric m = jnp.matmul(m0, m1 + x * dx) @@ -63,7 +63,7 @@ def _test_fn(x: jax.Array, **kwargs: Any) -> jax.Array: return _test_fn -def _sqrt_plus_inv_sqrt(x: jax.Array) -> jax.Array: +def _sqrt_plus_inv_sqrt(x: jnp.ndarray) -> jnp.ndarray: sqrtm = matrix_square_root.sqrtm(x) return sqrtm[0] + sqrtm[1] @@ -71,7 +71,7 @@ def _sqrt_plus_inv_sqrt(x: jax.Array) -> jax.Array: class TestMatrixSquareRoot: @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): self.dim = 13 self.batch = 3 # Values for testing the Sylvester solver diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 008036790..0dd65ba57 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -23,12 +23,8 @@ def __init__( self.rng = np.random.default_rng(seed=0) def __next__(self) -> Mapping[str, np.ndarray]: - inds_source = self.rng.choice( - len(self.source_data), size=[self.batch_size] - ) - inds_target = self.rng.choice( - len(self.target_data), size=[self.batch_size] - ) + 
inds_source = self.rng.choice(len(self.source_data), size=[self.batch_size]) + inds_target = self.rng.choice(len(self.target_data), size=[self.batch_size]) return { "source_lin": self.source_data[inds_source, :], @@ -79,8 +75,12 @@ def data_loader_gaussian_conditional(): source_1 = rng.normal(size=(100, 2)) target_1 = rng.normal(size=(100, 2)) - 2.0 - dl0 = DataLoader(source_0, target_0, 16, source_conditions=np.zeros_like(source_0) * 0.0) - dl1 = DataLoader(source_1, target_1, 16, source_conditions=np.ones_like(source_1) * 1.0) + dl0 = DataLoader( + source_0, target_0, 16, source_conditions=np.zeros_like(source_0) * 0.0 + ) + dl1 = DataLoader( + source_1, target_1, 16, source_conditions=np.ones_like(source_1) * 1.0 + ) return ConditionalDataLoader({"0": dl0, "1": dl1}, np.array([0.5, 0.5])) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 0c4abb55e..ed65fc657 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -13,7 +13,6 @@ # limitations under the License. 
from typing import Iterator, Optional -import jax import jax.numpy as jnp import optax import pytest @@ -76,7 +75,7 @@ def test_genot_linear_unconditional( result_forward = genot.transport( source_lin, condition=condition, forward=True ) - assert isinstance(result_forward, jax.Array) + assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @@ -120,7 +119,7 @@ def test_genot_quad_unconditional( result_forward = genot.transport( source_quad, condition=condition, forward=True ) - assert isinstance(result_forward, jax.Array) + assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @@ -166,7 +165,7 @@ def test_genot_fused_unconditional( condition=condition, forward=True ) - assert isinstance(result_forward, jax.Array) + assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @@ -217,7 +216,7 @@ def test_genot_linear_conditional( result_forward = genot.transport( source_lin, condition=condition, forward=True ) - assert isinstance(result_forward, jax.Array) + assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @@ -263,7 +262,7 @@ def test_genot_quad_conditional( result_forward = genot.transport( source_quad, condition=condition, forward=True ) - assert isinstance(result_forward, jax.Array) + assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @@ -311,7 +310,7 @@ def test_genot_fused_conditional( condition=condition, forward=True ) - assert isinstance(result_forward, jax.Array) + assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @pytest.mark.parametrize("conditional", [False, 
True]) @@ -365,9 +364,9 @@ def test_genot_linear_learn_rescaling( genot(data_loader, data_loader) result_eta = genot.evaluate_eta(source_lin, condition=condition) - assert isinstance(result_eta, jax.Array) + assert isinstance(result_eta, jnp.ndarray) assert jnp.sum(jnp.isnan(result_eta)) == 0 result_xi = genot.evaluate_xi(target_lin, condition=condition) - assert isinstance(result_xi, jax.Array) + assert isinstance(result_xi, jnp.ndarray) assert jnp.sum(jnp.isnan(result_xi)) == 0 diff --git a/tests/neural/icnn_test.py b/tests/neural/icnn_test.py index fd6c07f2b..4d760557f 100644 --- a/tests/neural/icnn_test.py +++ b/tests/neural/icnn_test.py @@ -22,7 +22,7 @@ @pytest.mark.fast() class TestICNN: - def test_icnn_convexity(self, rng: jax.Array): + def test_icnn_convexity(self, rng: jnp.ndarray): """Tests convexity of ICNN.""" n_samples, n_features = 10, 2 dim_hidden = (64, 64) @@ -48,7 +48,7 @@ def test_icnn_convexity(self, rng: jax.Array): np.testing.assert_array_equal(jnp.asarray(out) >= 0, True) - def test_icnn_hessian(self, rng: jax.Array): + def test_icnn_hessian(self, rng: jnp.ndarray): """Tests if Hessian of ICNN is positive-semidefinite.""" # define icnn model diff --git a/tests/neural/losses_test.py b/tests/neural/losses_test.py index 8cff7bd64..6ad2c0b3e 100644 --- a/tests/neural/losses_test.py +++ b/tests/neural/losses_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
import jax +import jax.numpy as jnp import numpy as np import pytest @@ -27,7 +28,7 @@ class TestMongeGap: @pytest.mark.parametrize("n_samples", [5, 25]) @pytest.mark.parametrize("n_features", [10, 50, 100]) def test_monge_gap_non_negativity( - self, rng: jax.Array, n_samples: int, n_features: int + self, rng: jnp.ndarray, n_samples: int, n_features: int ): # generate data @@ -53,7 +54,7 @@ def test_monge_gap_non_negativity( np.testing.assert_array_equal(monge_gap_value, monge_gap_from_samples_value) - def test_monge_gap_jit(self, rng: jax.Array): + def test_monge_gap_jit(self, rng: jnp.ndarray): n_samples, n_features = 31, 17 # generate data rng1, rng2 = jax.random.split(rng, 2) @@ -85,7 +86,7 @@ def test_monge_gap_jit(self, rng: jax.Array): ], ) def test_monge_gap_from_samples_different_cost( - self, rng: jax.Array, cost_fn: costs.CostFn, n_samples: int, + self, rng: jnp.ndarray, cost_fn: costs.CostFn, n_samples: int, n_features: int ): """Test that the Monge gap for different costs. diff --git a/tests/neural/map_estimator_test.py b/tests/neural/map_estimator_test.py index 0454db751..7c506aa38 100644 --- a/tests/neural/map_estimator_test.py +++ b/tests/neural/map_estimator_test.py @@ -13,7 +13,7 @@ # limitations under the License. 
from typing import Optional -import jax +import jax.numpy as jnp import pytest from ott import datasets @@ -34,8 +34,8 @@ def test_map_estimator_convergence(self): # define the fitting loss and the regularizer def fitting_loss( - samples: jax.Array, - mapped_samples: jax.Array, + samples: jnp.ndarray, + mapped_samples: jnp.ndarray, ) -> Optional[float]: r"""Sinkhorn divergence fitting loss.""" div = sinkhorn_divergence.sinkhorn_divergence( diff --git a/tests/neural/meta_initializer_test.py b/tests/neural/meta_initializer_test.py index 25a88907e..f978e8206 100644 --- a/tests/neural/meta_initializer_test.py +++ b/tests/neural/meta_initializer_test.py @@ -31,7 +31,7 @@ class MetaMLP(nn.Module): num_hidden_layers: int = 3 @nn.compact - def __call__(self, a: jax.Array, b: jax.Array) -> jax.Array: + def __call__(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: dtype = a.dtype z = jnp.concatenate((a, b)) for _ in range(self.num_hidden_layers): @@ -40,7 +40,7 @@ def __call__(self, a: jax.Array, b: jax.Array) -> jax.Array: def create_ot_problem( - rng: jax.Array, + rng: jnp.ndarray, n: int, m: int, d: int, @@ -65,12 +65,12 @@ def create_ot_problem( def run_sinkhorn( - x: jax.Array, - y: jax.Array, + x: jnp.ndarray, + y: jnp.ndarray, *, initializer: linear_init.SinkhornInitializer, - a: Optional[jax.Array] = None, - b: Optional[jax.Array] = None, + a: Optional[jnp.ndarray] = None, + b: Optional[jnp.ndarray] = None, epsilon: float = 1e-2, lse_mode: bool = True, ) -> sinkhorn.SinkhornOutput: @@ -86,7 +86,7 @@ def run_sinkhorn( class TestMetaInitializer: @pytest.mark.parametrize("lse_mode", [True, False]) - def test_meta_initializer(self, rng: jax.Array, lse_mode: bool): + def test_meta_initializer(self, rng: jnp.ndarray, lse_mode: bool): """Tests Meta initializer""" n, m, d = 20, 20, 2 epsilon = 1e-2 diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 9a75cb3fb..4346b6be8 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -13,7 
+13,6 @@ # limitations under the License. from typing import Iterator, Type -import jax import jax.numpy as jnp import optax import pytest @@ -61,12 +60,18 @@ def test_flow_matching(self, data_loader_gaussian, flow: Type[BaseFlow]): fm(data_loader_gaussian, data_loader_gaussian) batch = next(data_loader_gaussian) - result_forward = fm.transport(batch["source_lin"], condition=batch["source_conditions"], forward=True) - assert isinstance(result_forward, jax.Array) + result_forward = fm.transport( + batch["source_lin"], condition=batch["source_conditions"], forward=True + ) + assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 - result_backward = fm.transport(batch["target_lin"], condition=batch["target_conditions"], forward=False) - assert isinstance(result_backward, jax.Array) + result_backward = fm.transport( + batch["target_lin"], + condition=batch["target_conditions"], + forward=False + ) + assert isinstance(result_backward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_backward)) == 0 @pytest.mark.parametrize( @@ -103,12 +108,18 @@ def test_flow_matching_with_conditions( ) batch = next(data_loader_gaussian_with_conditions) - result_forward = fm.transport(batch["source_lin"], condition=batch["source_conditions"], forward=True) - assert isinstance(result_forward, jax.Array) + result_forward = fm.transport( + batch["source_lin"], condition=batch["source_conditions"], forward=True + ) + assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 - result_backward = fm.transport(batch["target_lin"], condition=batch["target_conditions"], forward=False) - assert isinstance(result_backward, jax.Array) + result_backward = fm.transport( + batch["target_lin"], + condition=batch["target_conditions"], + forward=False + ) + assert isinstance(result_backward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_backward)) == 0 @pytest.mark.parametrize( @@ -142,12 +153,18 @@ def test_flow_matching_conditional( 
fm(data_loader_gaussian_conditional, data_loader_gaussian_conditional) batch = next(data_loader_gaussian_conditional) - result_forward = fm.transport(batch["source_lin"], condition=batch["source_conditions"], forward=True) - assert isinstance(result_forward, jax.Array) + result_forward = fm.transport( + batch["source_lin"], condition=batch["source_conditions"], forward=True + ) + assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 - result_backward = fm.transport(batch["target_lin"], condition=batch["target_conditions"], forward=False) - assert isinstance(result_backward, jax.Array) + result_backward = fm.transport( + batch["target_lin"], + condition=batch["target_conditions"], + forward=False + ) + assert isinstance(result_backward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_backward)) == 0 @pytest.mark.parametrize("conditional", [False, True]) @@ -190,10 +207,14 @@ def test_flow_matching_learn_rescaling( ) fm(data_loader, data_loader) - result_eta = fm.evaluate_eta(batch["source_lin"], condition=batch["source_conditions"]) - assert isinstance(result_eta, jax.Array) + result_eta = fm.evaluate_eta( + batch["source_lin"], condition=batch["source_conditions"] + ) + assert isinstance(result_eta, jnp.ndarray) assert jnp.sum(jnp.isnan(result_eta)) == 0 - result_xi = fm.evaluate_xi(batch["target_lin"], condition=batch["target_conditions"]) - assert isinstance(result_xi, jax.Array) + result_xi = fm.evaluate_xi( + batch["target_lin"], condition=batch["target_conditions"] + ) + assert isinstance(result_xi, jnp.ndarray) assert jnp.sum(jnp.isnan(result_xi)) == 0 diff --git a/tests/problems/linear/potentials_test.py b/tests/problems/linear/potentials_test.py index c9fa9cf17..dd5d4bbd6 100644 --- a/tests/problems/linear/potentials_test.py +++ b/tests/problems/linear/potentials_test.py @@ -36,7 +36,7 @@ def test_device_put(self): class TestEntropicPotentials: - def test_device_put(self, rng: jax.Array): + def test_device_put(self, rng: 
jnp.ndarray): n = 10 device = jax.devices()[0] rngs = jax.random.split(rng, 5) @@ -53,7 +53,7 @@ def test_device_put(self, rng: jax.Array): _ = jax.device_put(pot, device) @pytest.mark.fast.with_args(eps=[5e-2, 1e-1], only_fast=0) - def test_entropic_potentials_dist(self, rng: jax.Array, eps: float): + def test_entropic_potentials_dist(self, rng: jnp.ndarray, eps: float): n1, n2, d = 64, 96, 2 rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) @@ -91,7 +91,7 @@ def test_entropic_potentials_dist(self, rng: jax.Array, eps: float): @pytest.mark.fast.with_args(forward=[False, True], only_fast=0) def test_entropic_potentials_displacement( - self, rng: jax.Array, forward: bool, monkeypatch + self, rng: jnp.ndarray, forward: bool, monkeypatch ): """Tests entropic displacements, as well as their plots.""" n1, n2, d = 96, 128, 2 @@ -134,7 +134,7 @@ def test_entropic_potentials_displacement( p=[1.3, 2.2, 1.0], forward=[False, True], only_fast=0 ) def test_entropic_potentials_sqpnorm( - self, rng: jax.Array, p: float, forward: bool + self, rng: jnp.ndarray, p: float, forward: bool ): epsilon = None cost_fn = costs.SqPNorm(p=p) @@ -174,7 +174,7 @@ def test_entropic_potentials_sqpnorm( p=[1.45, 2.2, 1.0], forward=[False, True], only_fast=0 ) def test_entropic_potentials_pnorm( - self, rng: jax.Array, p: float, forward: bool + self, rng: jnp.ndarray, p: float, forward: bool ): epsilon = None cost_fn = costs.PNormP(p=p) @@ -216,7 +216,7 @@ def test_entropic_potentials_pnorm( assert div < .1 * div_0 @pytest.mark.parametrize("jit", [False, True]) - def test_distance_differentiability(self, rng: jax.Array, jit: bool): + def test_distance_differentiability(self, rng: jnp.ndarray, jit: bool): rng1, rng2, rng3 = jax.random.split(rng, 3) n, m, d = 18, 36, 5 @@ -238,7 +238,7 @@ def test_distance_differentiability(self, rng: jax.Array, jit: bool): np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4) @pytest.mark.parametrize("eps", [None, 1e-1, 1e1, 1e2, 1e3]) - def 
test_potentials_sinkhorn_divergence(self, rng: jax.Array, eps: float): + def test_potentials_sinkhorn_divergence(self, rng: jnp.ndarray, eps: float): rng1, rng2, rng3 = jax.random.split(rng, 3) n, m, d = 32, 36, 4 fwd = True diff --git a/tests/solvers/linear/continuous_barycenter_test.py b/tests/solvers/linear/continuous_barycenter_test.py index 5c7fabd67..4989cc1db 100644 --- a/tests/solvers/linear/continuous_barycenter_test.py +++ b/tests/solvers/linear/continuous_barycenter_test.py @@ -27,7 +27,7 @@ means_and_covs_to_x = jax.vmap(costs.mean_and_cov_to_x, in_axes=[0, 0, None]) -def is_positive_semidefinite(c: jax.Array) -> bool: +def is_positive_semidefinite(c: jnp.ndarray) -> bool: # GPU friendly, eigvals not implemented for non-symmetric matrices w = jnp.linalg.eigvalsh((c + c.T) / 2.0) return jnp.all(w >= 0) @@ -50,7 +50,7 @@ class TestBarycenter: }, ) def test_euclidean_barycenter( - self, rng: jax.Array, rank: int, epsilon: float, init_random: bool, + self, rng: jnp.ndarray, rank: int, epsilon: float, init_random: bool, jit: bool ): rngs = jax.random.split(rng, 20) @@ -115,12 +115,12 @@ def test_euclidean_barycenter( assert jnp.all(out.x.ravel() > .7) @pytest.mark.parametrize("segment_before", [False, True]) - def test_barycenter_jit(self, rng: jax.Array, segment_before: bool): + def test_barycenter_jit(self, rng: jnp.ndarray, segment_before: bool): @functools.partial(jax.jit, static_argnums=(2, 3)) def barycenter( - y: jax.Array, - b: jax.Array, + y: jnp.ndarray, + b: jnp.ndarray, segment_before: bool, num_per_segment: Tuple[int, ...], ) -> cb.FreeBarycenterState: @@ -170,7 +170,7 @@ def barycenter( @pytest.mark.fast() def test_bures_barycenter( self, - rng: jax.Array, + rng: jnp.ndarray, ): lse_mode = True, epsilon = 1e-1 @@ -256,7 +256,7 @@ def test_bures_barycenter( @pytest.mark.fast() def test_bures_barycenter_different_number_of_components( self, - rng: jax.Array, + rng: jnp.ndarray, ): alpha = 5. 
epsilon = 0.01 diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index a608c0d71..944534e14 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -29,7 +29,7 @@ class TestSinkhornImplicit: """Check implicit and autodiff match for Sinkhorn.""" @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): self.dim = 3 self.n = 38 self.m = 73 @@ -49,7 +49,7 @@ def test_implicit_differentiation_versus_autodiff( ): epsilon = 0.05 - def loss_g(a: jax.Array, x: jax.Array, implicit: bool = True) -> float: + def loss_g(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True) -> float: implicit_diff = implicit_lib.ImplicitDiff() if implicit else None geom = geometry.Geometry( cost_matrix=jnp.sum(x ** 2, axis=1)[:, jnp.newaxis] + @@ -65,7 +65,9 @@ def loss_g(a: jax.Array, x: jax.Array, implicit: bool = True) -> float: ) return solver(prob).reg_ot_cost - def loss_pcg(a: jax.Array, x: jax.Array, implicit: bool = True) -> float: + def loss_pcg( + a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True + ) -> float: implicit_diff = implicit_lib.ImplicitDiff() if implicit else None geom = pointcloud.PointCloud(x, self.y, epsilon=epsilon) prob = linear_problem.LinearProblem( @@ -135,7 +137,7 @@ class TestSinkhornJacobian: only_fast=0, ) def test_autograd_sinkhorn( - self, rng: jax.Array, lse_mode: bool, shape_data: Tuple[int, int] + self, rng: jnp.ndarray, lse_mode: bool, shape_data: Tuple[int, int] ): """Test gradient w.r.t. 
probability weights.""" n, m = shape_data @@ -152,7 +154,7 @@ def test_autograd_sinkhorn( a = a / jnp.sum(a) b = b / jnp.sum(b) - def reg_ot(a: jax.Array, b: jax.Array) -> float: + def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float: geom = pointcloud.PointCloud(x, y, epsilon=1e-1) prob = linear_problem.LinearProblem(geom, a=a, b=b) solver = sinkhorn.Sinkhorn(lse_mode=lse_mode) @@ -178,7 +180,7 @@ def reg_ot(a: jax.Array, b: jax.Array) -> float: @pytest.mark.parametrize(("lse_mode", "shape_data"), [(True, (7, 9)), (False, (11, 5))]) def test_gradient_sinkhorn_geometry( - self, rng: jax.Array, lse_mode: bool, shape_data: Tuple[int, int] + self, rng: jnp.ndarray, lse_mode: bool, shape_data: Tuple[int, int] ): """Test gradient w.r.t. cost matrix.""" n, m = shape_data @@ -188,7 +190,7 @@ def test_gradient_sinkhorn_geometry( delta = delta / jnp.sqrt(jnp.vdot(delta, delta)) eps = 1e-3 # perturbation magnitude - def loss_fn(cm: jax.Array): + def loss_fn(cm: jnp.ndarray): a = jnp.ones(cm.shape[0]) / cm.shape[0] b = jnp.ones(cm.shape[1]) / cm.shape[1] geom = geometry.Geometry(cm, epsilon=0.5) @@ -241,7 +243,7 @@ def loss_fn(cm: jax.Array): only_fast=[0, 1], ) def test_gradient_sinkhorn_euclidean( - self, rng: jax.Array, lse_mode: bool, implicit: bool, min_iter: int, + self, rng: jnp.ndarray, lse_mode: bool, implicit: bool, min_iter: int, max_iter: int, epsilon: float, cost_fn: costs.CostFn ): """Test gradient w.r.t. locations x of reg-ot-cost.""" @@ -262,8 +264,8 @@ def test_gradient_sinkhorn_euclidean( # Adding some near-zero distances to test proper handling with p_norm=1. 
y = y.at[0].set(x[0, :] + 1e-3) - def loss_fn(x: jax.Array, - y: jax.Array) -> Tuple[float, sinkhorn.SinkhornOutput]: + def loss_fn(x: jnp.ndarray, + y: jnp.ndarray) -> Tuple[float, sinkhorn.SinkhornOutput]: implicit_diff = implicit_lib.ImplicitDiff() if implicit else None geom = pointcloud.PointCloud(x, y, epsilon=epsilon, cost_fn=cost_fn) prob = linear_problem.LinearProblem(geom, a, b) @@ -315,10 +317,10 @@ def loss_fn(x: jax.Array, ) np.testing.assert_array_equal(jnp.isnan(custom_grad), False) - def test_autoepsilon_differentiability(self, rng: jax.Array): + def test_autoepsilon_differentiability(self, rng: jnp.ndarray): cost = jax.random.uniform(rng, (15, 17)) - def reg_ot_cost(c: jax.Array) -> float: + def reg_ot_cost(c: jnp.ndarray) -> float: geom = geometry.Geometry(c, epsilon=None) # auto epsilon prob = linear_problem.LinearProblem(geom) return sinkhorn.Sinkhorn()(prob).reg_ot_cost @@ -327,9 +329,9 @@ def reg_ot_cost(c: jax.Array) -> float: np.testing.assert_array_equal(jnp.isnan(gradient), False) @pytest.mark.fast() - def test_differentiability_with_jit(self, rng: jax.Array): + def test_differentiability_with_jit(self, rng: jnp.ndarray): - def reg_ot_cost(c: jax.Array) -> float: + def reg_ot_cost(c: jnp.ndarray) -> float: geom = geometry.Geometry(c, epsilon=1e-2) prob = linear_problem.LinearProblem(geom) return sinkhorn.Sinkhorn()(prob).reg_ot_cost @@ -345,7 +347,7 @@ def reg_ot_cost(c: jax.Array) -> float: only_fast=0 ) def test_apply_transport_jacobian( - self, rng: jax.Array, lse_mode: bool, tau_a: float, tau_b: float, + self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float, arg: int, axis: int ): """Tests Jacobian of application of OT to vector, w.r.t. @@ -383,7 +385,7 @@ def test_apply_transport_jacobian( # general rule, even more so when using backprop. 
epsilon = 0.01 if lse_mode else 0.1 - def apply_ot(a: jax.Array, x: jax.Array, implicit: bool) -> jax.Array: + def apply_ot(a: jnp.ndarray, x: jnp.ndarray, implicit: bool) -> jnp.ndarray: geom = pointcloud.PointCloud(x, y, epsilon=epsilon) prob = linear_problem.LinearProblem(geom, a, b, tau_a=tau_a, tau_b=tau_b) @@ -457,7 +459,7 @@ def apply_ot(a: jax.Array, x: jax.Array, implicit: bool) -> jax.Array: only_fast=0, ) def test_potential_jacobian_sinkhorn( - self, rng: jax.Array, lse_mode: bool, tau_a: float, tau_b: float, + self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float, shape: Tuple[int, int], arg: int ): """Test Jacobian of optimal potential w.r.t. weights and locations.""" @@ -486,7 +488,7 @@ def test_potential_jacobian_sinkhorn( # with small epsilon when differentiating. epsilon = 0.01 if lse_mode else 0.1 - def loss_from_potential(a: jax.Array, x: jax.Array, implicit: bool): + def loss_from_potential(a: jnp.ndarray, x: jnp.ndarray, implicit: bool): geom = pointcloud.PointCloud(x, y, epsilon=epsilon) prob = linear_problem.LinearProblem(geom, a, b, tau_a=tau_a, tau_b=tau_b) @@ -539,7 +541,7 @@ class TestSinkhornGradGrid: @pytest.mark.parametrize("lse_mode", [False, True]) def test_diff_sinkhorn_x_grid_x_perturbation( - self, rng: jax.Array, lse_mode: bool + self, rng: jnp.ndarray, lse_mode: bool ): """Test gradient w.r.t. 
probability weights.""" eps = 1e-3 # perturbation magnitude @@ -554,7 +556,7 @@ def test_diff_sinkhorn_x_grid_x_perturbation( a = a.ravel() / jnp.sum(a) b = b.ravel() / jnp.sum(b) - def reg_ot(x: List[jax.Array]) -> float: + def reg_ot(x: List[jnp.ndarray]) -> float: geom = grid.Grid(x=x, epsilon=1.0) prob = linear_problem.LinearProblem(geom, a=a, b=b) solver = sinkhorn.Sinkhorn(threshold=1e-1, lse_mode=lse_mode) @@ -584,7 +586,7 @@ def reg_ot(x: List[jax.Array]) -> float: @pytest.mark.parametrize("lse_mode", [False, True]) def test_diff_sinkhorn_x_grid_weights_perturbation( - self, rng: jax.Array, lse_mode: bool + self, rng: jnp.ndarray, lse_mode: bool ): """Test gradient w.r.t. probability weights.""" eps = 1e-4 # perturbation magnitude @@ -603,7 +605,7 @@ def test_diff_sinkhorn_x_grid_weights_perturbation( b = b.ravel() / jnp.sum(b) geom = grid.Grid(x=x, epsilon=1) - def reg_ot(a: jax.Array, b: jax.Array) -> float: + def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float: prob = linear_problem.LinearProblem(geom, a, b) solver = sinkhorn.Sinkhorn(threshold=1e-3, lse_mode=lse_mode) return solver(prob).reg_ot_cost @@ -635,7 +637,7 @@ class TestSinkhornJacobianPreconditioning: only_fast=[0, -1], ) def test_potential_jacobian_sinkhorn_precond( - self, rng: jax.Array, lse_mode: bool, tau_a: float, tau_b: float, + self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float, shape: Tuple[int, int], arg: int ): """Test Jacobian of optimal potential works across 2 precond_fun.""" @@ -665,9 +667,9 @@ def test_potential_jacobian_sinkhorn_precond( epsilon = 0.05 if lse_mode else 0.1 def loss_from_potential( - a: jax.Array, - x: jax.Array, - precondition_fun: Optional[Callable[[jax.Array], jax.Array]] = None, + a: jnp.ndarray, + x: jnp.ndarray, + precondition_fun: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, symmetric: bool = False ) -> float: geom = pointcloud.PointCloud(x, y, epsilon=epsilon) @@ -738,7 +740,7 @@ class TestSinkhornHessian: only_fast=-1 ) def 
test_hessian_sinkhorn( - self, rng: jax.Array, lse_mode: bool, tau_a: float, tau_b: float, + self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float, arg: int, lineax_ridge: float ): """Test hessian w.r.t. weights and locations.""" @@ -769,7 +771,7 @@ def test_hessian_sinkhorn( imp_dif = implicit_lib.ImplicitDiff(solver_kwargs=solver_kwargs) - def loss(a: jax.Array, x: jax.Array, implicit: bool = True): + def loss(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True): geom = pointcloud.PointCloud(x, y, epsilon=epsilon) prob = linear_problem.LinearProblem(geom, a, b, tau_a, tau_b) implicit_diff = imp_dif if implicit else None diff --git a/tests/solvers/linear/sinkhorn_grid_test.py b/tests/solvers/linear/sinkhorn_grid_test.py index b2aa4da3e..dd22f63b7 100644 --- a/tests/solvers/linear/sinkhorn_grid_test.py +++ b/tests/solvers/linear/sinkhorn_grid_test.py @@ -25,7 +25,7 @@ class TestSinkhornGrid: @pytest.mark.parametrize("lse_mode", [False, True]) - def test_separable_grid(self, rng: jax.Array, lse_mode: bool): + def test_separable_grid(self, rng: jnp.ndarray, lse_mode: bool): """Two histograms in a grid of size 5 x 6 x 7 in the hypercube^3.""" grid_size = (5, 6, 7) rngs = jax.random.split(rng, 2) @@ -46,7 +46,7 @@ def test_separable_grid(self, rng: jax.Array, lse_mode: bool): assert threshold > err @pytest.mark.fast.with_args("lse_mode", [False, True], only_fast=0) - def test_grid_vs_euclidean(self, rng: jax.Array, lse_mode: bool): + def test_grid_vs_euclidean(self, rng: jnp.ndarray, lse_mode: bool): grid_size = (5, 6, 7) rngs = jax.random.split(rng, 2) a = jax.random.uniform(rngs[0], grid_size) @@ -69,7 +69,7 @@ def test_grid_vs_euclidean(self, rng: jax.Array, lse_mode: bool): ) @pytest.mark.fast.with_args("lse_mode", [False, True], only_fast=1) - def test_apply_transport_grid(self, rng: jax.Array, lse_mode: bool): + def test_apply_transport_grid(self, rng: jnp.ndarray, lse_mode: bool): grid_size = (5, 6, 7) rngs = jax.random.split(rng, 4) a = 
jax.random.uniform(rngs[0], grid_size) @@ -118,7 +118,7 @@ def test_apply_transport_grid(self, rng: jax.Array, lse_mode: bool): np.testing.assert_array_equal(jnp.isnan(mat_transport_t_vec_a), False) @pytest.mark.fast() - def test_apply_cost(self, rng: jax.Array): + def test_apply_cost(self, rng: jnp.ndarray): grid_size = (5, 6, 7) geom_grid = grid.Grid(grid_size=grid_size, epsilon=0.1) diff --git a/tests/solvers/linear/sinkhorn_lr_test.py b/tests/solvers/linear/sinkhorn_lr_test.py index 9b360bdf0..90b149ea8 100644 --- a/tests/solvers/linear/sinkhorn_lr_test.py +++ b/tests/solvers/linear/sinkhorn_lr_test.py @@ -26,7 +26,7 @@ class TestLRSinkhorn: @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): self.dim = 4 self.n = 23 self.m = 27 diff --git a/tests/solvers/linear/sinkhorn_misc_test.py b/tests/solvers/linear/sinkhorn_misc_test.py index e88ff4d0c..e97a34228 100644 --- a/tests/solvers/linear/sinkhorn_misc_test.py +++ b/tests/solvers/linear/sinkhorn_misc_test.py @@ -36,7 +36,7 @@ class TestSinkhornAnderson: only_fast=0, ) def test_anderson( - self, rng: jax.Array, lse_mode: bool, tau_a: float, tau_b: float + self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float ): """Test efficiency of Anderson acceleration. 
@@ -128,7 +128,7 @@ def initialize(self): @pytest.mark.parametrize(("unbalanced", "thresh"), [(False, 1e-3), (True, 1e-4)]) def test_bures_point_cloud( - self, rng: jax.Array, lse_mode: bool, unbalanced: bool, thresh: float + self, rng: jnp.ndarray, lse_mode: bool, unbalanced: bool, thresh: float ): """Two point clouds of Gaussians, tested with various parameters.""" if unbalanced: @@ -169,7 +169,7 @@ def test_regularized_unbalanced_bures_cost(self): class TestSinkhornOnline: @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): self.dim = 3 self.n = 100 self.m = 42 @@ -234,7 +234,7 @@ def callback(epsilon: float, batch_size: int) -> sinkhorn.SinkhornOutput: class TestSinkhornUnbalanced: @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): self.dim = 4 self.n = 17 self.m = 23 @@ -315,7 +315,7 @@ class TestSinkhornJIT: """Check jitted and non jit match for Sinkhorn, and that everything jits.""" @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): self.dim = 3 self.n = 10 self.m = 11 @@ -346,10 +346,12 @@ def assert_output_close( ) -> None: """Assert SinkhornOutputs are close.""" x = tuple( - a for a in x if (a is not None and (isinstance(a, (jax.Array, int)))) + a for a in x + if (a is not None and (isinstance(a, (jnp.ndarray, int)))) ) y = tuple( - a for a in y if (a is not None and (isinstance(a, (jax.Array, int)))) + a for a in y + if (a is not None and (isinstance(a, (jnp.ndarray, int)))) ) return chex.assert_trees_all_close(x, y, atol=1e-6, rtol=0) @@ -362,7 +364,7 @@ def assert_output_close( def test_jit_vs_non_jit_bwd(self, implicit: bool): @jax.value_and_grad - def val_grad(a: jax.Array, x: jax.Array) -> float: + def val_grad(a: jnp.ndarray, x: jnp.ndarray) -> float: implicit_diff = implicit_lib.ImplicitDiff() if implicit else None geom = geometry.Geometry( cost_matrix=( diff --git 
a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index ce7f9919a..c7475c4f3 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -30,7 +30,7 @@ class TestSinkhorn: @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): self.rng = rng self.dim = 4 self.n = 17 diff --git a/tests/solvers/linear/univariate_test.py b/tests/solvers/linear/univariate_test.py index 1a5529167..47a34f7ce 100644 --- a/tests/solvers/linear/univariate_test.py +++ b/tests/solvers/linear/univariate_test.py @@ -25,7 +25,7 @@ class TestUnivariate: @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): self.rng = rng self.n = 17 self.m = 29 @@ -86,7 +86,7 @@ def test_cdf_distance_and_scipy(self): @pytest.mark.fast() def test_cdf_grad( self, - rng: jax.Array, + rng: jnp.ndarray, ): # TODO: Once a `check_grad` function is implemented, replace the code # blocks before with `check_grad`'s. 
diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index 508fedcb2..10361d088 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -29,7 +29,7 @@ class TestFusedGromovWasserstein: # TODO(michalk8): refactor me in the future @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): d_x = 2 d_y = 3 d_xy = 4 @@ -56,7 +56,7 @@ def test_gradient_marginals_fgw_solver(self, jit: bool): geom_y = pointcloud.PointCloud(self.y) geom_xy = pointcloud.PointCloud(self.x_2, self.y_2) - def reg_gw(a: jax.Array, b: jax.Array, implicit: bool): + def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): prob = quadratic_problem.QuadraticProblem( geom_x, geom_y, geom_xy, fused_penalty=self.fused_penalty, a=a, b=b ) @@ -101,9 +101,9 @@ def test_gradient_fgw_solver_geometry(self, lse_mode: bool, is_cost: bool): """Test gradient w.r.t. the geometries.""" def reg_gw( - x: jax.Array, y: jax.Array, xy: Union[jax.Array, Tuple[jax.Array, - jax.Array]], - fused_penalty: float, a: jax.Array, b: jax.Array, implicit: bool + x: jnp.ndarray, y: jnp.ndarray, + xy: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], + fused_penalty: float, a: jnp.ndarray, b: jnp.ndarray, implicit: bool ): if is_cost: geom_x = geometry.Geometry(cost_matrix=x) @@ -182,8 +182,8 @@ def test_gradient_fgw_solver_penalty(self): lse_mode = True def reg_gw( - cx: jax.Array, cy: jax.Array, cxy: jax.Array, fused_penalty: float, - a: jax.Array, b: jax.Array, implicit: bool + cx: jnp.ndarray, cy: jnp.ndarray, cxy: jnp.ndarray, + fused_penalty: float, a: jnp.ndarray, b: jnp.ndarray, implicit: bool ) -> float: geom_x = geometry.Geometry(cost_matrix=cx) geom_y = geometry.Geometry(cost_matrix=cy) @@ -216,7 +216,7 @@ def reg_gw( @pytest.mark.limit_memory("200 MB") @pytest.mark.parametrize("jit", [False, True]) - def test_fgw_lr_memory(self, rng: jax.Array, jit: bool): + def test_fgw_lr_memory(self, rng: 
jnp.ndarray, jit: bool): rngs = jax.random.split(rng, 4) n, m, d1, d2 = 5_000, 2_500, 1, 2 x = jax.random.uniform(rngs[0], (n, d1)) @@ -243,7 +243,7 @@ def test_fgw_lr_memory(self, rng: jax.Array, jit: bool): @pytest.mark.parametrize("cost_rank", [4, (2, 3, 4)]) def test_fgw_lr_generic_cost_matrix( - self, rng: jax.Array, cost_rank: Union[int, Tuple[int, int, int]] + self, rng: jnp.ndarray, cost_rank: Union[int, Tuple[int, int, int]] ): n, m = 20, 30 rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) diff --git a/tests/solvers/quadratic/gw_barycenter_test.py b/tests/solvers/quadratic/gw_barycenter_test.py index d5dadd691..a157f27e5 100644 --- a/tests/solvers/quadratic/gw_barycenter_test.py +++ b/tests/solvers/quadratic/gw_barycenter_test.py @@ -31,7 +31,7 @@ class TestGWBarycenter: def random_pc( n: int, d: int, - rng: jax.Array, + rng: jnp.ndarray, m: Optional[int] = None, **kwargs: Any ) -> pointcloud.PointCloud: @@ -42,9 +42,9 @@ def random_pc( @staticmethod def pad_cost_matrices( - costs: Sequence[jax.Array], + costs: Sequence[jnp.ndarray], shape: Optional[Tuple[int, int]] = None - ) -> Tuple[jax.Array, jax.Array]: + ) -> Tuple[jnp.ndarray, jnp.ndarray]: if shape is None: shape = jnp.asarray([arr.shape for arr in costs]).max() shape = (shape, shape) @@ -65,7 +65,7 @@ def pad_cost_matrices( [("sqeucl", 17, None)] # , ("kl", 22, 1e-2)] ) def test_gw_barycenter( - self, rng: jax.Array, gw_loss: str, bar_size: int, + self, rng: jnp.ndarray, gw_loss: str, bar_size: int, epsilon: Optional[float] ): tol = 1e-3 if gw_loss == "sqeucl" else 1e-1 @@ -126,14 +126,14 @@ def test_gw_barycenter( ) def test_fgw_barycenter( self, - rng: jax.Array, + rng: jnp.ndarray, jit: bool, fused_penalty: float, scale_cost: str, ): def barycenter( - y: jnp.ndim, y_fused: jax.Array, num_per_segment: Tuple[int, ...] + y: jnp.ndim, y_fused: jnp.ndarray, num_per_segment: Tuple[int, ...] 
) -> gwb_solver.GWBarycenterState: prob = gwb.GWBarycenterProblem( y=y, diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py index e7d0ff106..e7ef7b558 100644 --- a/tests/solvers/quadratic/gw_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ -31,7 +31,7 @@ class TestQuadraticProblem: @pytest.mark.parametrize("as_pc", [False, True]) @pytest.mark.parametrize("rank", [-1, 5, (1, 2, 3), (2, 3, 5)]) def test_quad_to_low_rank( - self, rng: jax.Array, as_pc: bool, rank: Union[int, Tuple[int, ...]] + self, rng: jnp.ndarray, as_pc: bool, rank: Union[int, Tuple[int, ...]] ): n, m, d1, d2, d = 100, 120, 4, 6, 10 rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) @@ -87,7 +87,7 @@ def test_quad_to_low_rank( assert lr_prob._is_low_rank_convertible assert lr_prob.to_low_rank() is lr_prob - def test_gw_implicit_conversion_mixed_input(self, rng: jax.Array): + def test_gw_implicit_conversion_mixed_input(self, rng: jnp.ndarray): n, m, d1, d2 = 13, 77, 3, 4 rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, (n, d1)) @@ -107,7 +107,7 @@ def test_gw_implicit_conversion_mixed_input(self, rng: jax.Array): class TestGromovWasserstein: @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): d_x = 2 d_y = 3 self.n, self.m = 6, 7 @@ -156,8 +156,8 @@ def test_flag_store_errors(self): def test_gradient_marginals_gw(self, jit: bool): """Test gradient w.r.t. probability weights.""" - def reg_gw(a: jax.Array, b: jax.Array, - implicit: bool) -> Tuple[float, Tuple[jax.Array, jax.Array]]: + def reg_gw(a: jnp.ndarray, b: jnp.ndarray, + implicit: bool) -> Tuple[float, Tuple[jnp.ndarray, jnp.ndarray]]: prob = quadratic_problem.QuadraticProblem(geom_x, geom_y, a=a, b=b) implicit_diff = implicit_lib.ImplicitDiff() if implicit else None linear_solver = sinkhorn.Sinkhorn( @@ -245,7 +245,8 @@ def test_gradient_gw_geometry( """Test gradient w.r.t. 
the geometries.""" def reg_gw( - x: jax.Array, y: jax.Array, a: jax.Array, b: jax.Array, implicit: bool + x: jnp.ndarray, y: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray, + implicit: bool ) -> float: if is_cost: geom_x = geometry.Geometry(cost_matrix=x) @@ -309,7 +310,7 @@ def loss_thre(threshold: float) -> float: assert loss_thre(1e-3) >= loss_thre(1e-5) @pytest.mark.fast() - def test_gw_lr(self, rng: jax.Array): + def test_gw_lr(self, rng: jnp.ndarray): """Checking LR and Entropic have similar outputs on same problem.""" rngs = jax.random.split(rng, 4) n, m, d1, d2 = 24, 17, 2, 3 @@ -333,7 +334,7 @@ def test_gw_lr(self, rng: jax.Array): ot_gwlr.primal_cost, ot_gw.primal_cost, rtol=5e-2 ) - def test_gw_lr_matches_fused(self, rng: jax.Array): + def test_gw_lr_matches_fused(self, rng: jnp.ndarray): """Checking LR and Entropic have similar outputs on same fused problem.""" rngs = jax.random.split(rng, 5) n, m, d1, d2 = 24, 17, 2, 3 @@ -384,7 +385,7 @@ def test_gw_lr_apply(self, axis: int): @pytest.mark.parametrize("scale_cost", [1.0, "mean"]) def test_relative_epsilon( self, - rng: jax.Array, + rng: jnp.ndarray, scale_cost: Union[float, str], ): eps = 1e-2 diff --git a/tests/solvers/quadratic/lower_bound_test.py b/tests/solvers/quadratic/lower_bound_test.py index bf32aad87..68f7a804a 100644 --- a/tests/solvers/quadratic/lower_bound_test.py +++ b/tests/solvers/quadratic/lower_bound_test.py @@ -31,7 +31,7 @@ class TestLowerBoundSolver: @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): d_x = 2 d_y = 3 self.n, self.m = 13, 15 @@ -118,11 +118,11 @@ def test_lb_pointcloud( ] ) def test_lb_grad( - self, rng: jax.Array, sort_fn: Callable[[jax.Array], jax.Array], + self, rng: jnp.ndarray, sort_fn: Callable[[jnp.ndarray], jnp.ndarray], method: str ): - def fn(x: jax.Array, y: jax.Array) -> float: + def fn(x: jnp.ndarray, y: jnp.ndarray) -> float: geom_x = pointcloud.PointCloud(x) geom_y = pointcloud.PointCloud(y) 
prob = quadratic_problem.QuadraticProblem(geom_x, geom_y) diff --git a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py index 8f43eaa4e..75fc3bef5 100644 --- a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py +++ b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py @@ -29,7 +29,7 @@ class TestFitGmmPair: @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): mean_generator0 = jnp.array([[2., -1.], [-2., 0.], [4., 3.]]) cov_generator0 = jnp.array([[[0.2, 0.], [0., 0.1]], [[0.6, 0.], [0., 0.3]], [[0.5, 0.4], [0.4, 0.5]]]) diff --git a/tests/tools/gaussian_mixture/fit_gmm_test.py b/tests/tools/gaussian_mixture/fit_gmm_test.py index e39633b19..1cfb4f95e 100644 --- a/tests/tools/gaussian_mixture/fit_gmm_test.py +++ b/tests/tools/gaussian_mixture/fit_gmm_test.py @@ -23,7 +23,7 @@ class TestFitGmm: @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): mean_generator = jnp.array([[2., -1.], [-2., 0.], [4., 3.]]) cov_generator = jnp.array([[[0.2, 0.], [0., 0.1]], [[0.6, 0.], [0., 0.3]], [[0.5, 0.4], [0.4, 0.5]]]) diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py index ccf1e50cd..bf2b01699 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py @@ -22,7 +22,7 @@ class TestGaussianMixturePair: @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): self.n_components = 3 self.n_dimensions = 2 self.epsilon = 1.e-3 diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_test.py index af52860be..fd7675d51 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_test.py @@ -23,7 
+23,7 @@ class TestGaussianMixture: def test_get_summary_stats_from_points_and_assignment_probs( - self, rng: jax.Array + self, rng: jnp.ndarray ): n = 50 rng, subrng0, subrng1 = jax.random.split(rng, num=3) @@ -56,7 +56,7 @@ def test_get_summary_stats_from_points_and_assignment_probs( np.testing.assert_allclose(expected_cov, cov, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(expected_wt, comp_wt, atol=1e-4, rtol=1e-4) - def test_from_random(self, rng: jax.Array): + def test_from_random(self, rng: jnp.ndarray): gmm = gaussian_mixture.GaussianMixture.from_random( rng=rng, n_components=3, n_dimensions=2 ) @@ -79,7 +79,7 @@ def test_from_mean_cov_component_weights(self,): comp_wts, gmm.component_weights, atol=1e-4, rtol=1e-4 ) - def test_covariance(self, rng: jax.Array): + def test_covariance(self, rng: jnp.ndarray): gmm = gaussian_mixture.GaussianMixture.from_random( rng=rng, n_components=3, n_dimensions=2 ) @@ -89,7 +89,7 @@ def test_covariance(self, rng: jax.Array): cov[i], component.covariance(), atol=1e-4, rtol=1e-4 ) - def test_sample(self, rng: jax.Array): + def test_sample(self, rng: jnp.ndarray): gmm = gaussian_mixture.GaussianMixture.from_mean_cov_component_weights( mean=jnp.array([[-1., 0.], [1., 0.]]), cov=jnp.array([[[0.01, 0.], [0., 0.01]], [[0.01, 0.], [0., 0.01]]]), @@ -111,7 +111,7 @@ def test_sample(self, rng: jax.Array): atol=1.e-1 ) - def test_log_prob(self, rng: jax.Array): + def test_log_prob(self, rng: jnp.ndarray): n_components = 3 size = 100 subrng0, subrng1 = jax.random.split(rng, num=2) @@ -135,7 +135,7 @@ def test_log_prob(self, rng: jax.Array): np.testing.assert_allclose(expected, actual, atol=1e-4, rtol=1e-4) - def test_log_component_posterior(self, rng: jax.Array): + def test_log_component_posterior(self, rng: jnp.ndarray): gmm = gaussian_mixture.GaussianMixture.from_random( rng=rng, n_components=3, n_dimensions=2 ) @@ -149,7 +149,7 @@ def test_log_component_posterior(self, rng: jax.Array): expected, 
gmm.get_log_component_posterior(x), atol=1e-4, rtol=1e-4 ) - def test_flatten_unflatten(self, rng: jax.Array): + def test_flatten_unflatten(self, rng: jnp.ndarray): gmm = gaussian_mixture.GaussianMixture.from_random( rng=rng, n_components=3, n_dimensions=2 ) @@ -158,7 +158,7 @@ def test_flatten_unflatten(self, rng: jax.Array): assert gmm == gmm_new - def test_pytree_mapping(self, rng: jax.Array): + def test_pytree_mapping(self, rng: jnp.ndarray): gmm = gaussian_mixture.GaussianMixture.from_random( rng=rng, n_components=3, n_dimensions=2 ) diff --git a/tests/tools/gaussian_mixture/gaussian_test.py b/tests/tools/gaussian_mixture/gaussian_test.py index 8b720861c..0eac630e3 100644 --- a/tests/tools/gaussian_mixture/gaussian_test.py +++ b/tests/tools/gaussian_mixture/gaussian_test.py @@ -22,7 +22,7 @@ @pytest.mark.fast() class TestGaussian: - def test_from_random(self, rng: jax.Array): + def test_from_random(self, rng: jnp.ndarray): g = gaussian.Gaussian.from_random(rng=rng, n_dimensions=3) np.testing.assert_array_equal(g.loc.shape, (3,)) @@ -36,7 +36,7 @@ def test_from_mean_and_cov(self): np.testing.assert_array_equal(mean, g.loc) np.testing.assert_allclose(cov, g.covariance(), atol=1e-4, rtol=1e-4) - def test_to_z(self, rng: jax.Array): + def test_to_z(self, rng: jnp.ndarray): g = gaussian.Gaussian( loc=jnp.array([1., 2.]), scale=scale_tril.ScaleTriL( @@ -52,7 +52,7 @@ def test_to_z(self, rng: jax.Array): np.testing.assert_allclose(sample_mean, jnp.zeros(2), atol=0.1) np.testing.assert_allclose(sample_cov, jnp.eye(2), atol=0.1) - def test_from_z(self, rng: jax.Array): + def test_from_z(self, rng: jnp.ndarray): g = gaussian.Gaussian( loc=jnp.array([0., 0.]), scale=scale_tril.ScaleTriL( @@ -64,7 +64,7 @@ def test_from_z(self, rng: jax.Array): xnew = g.from_z(z) np.testing.assert_allclose(x, xnew, atol=1e-4, rtol=1e-4) - def test_log_prob(self, rng: jax.Array): + def test_log_prob(self, rng: jnp.ndarray): g = gaussian.Gaussian( loc=jnp.array([0., 0.]), 
scale=scale_tril.ScaleTriL( @@ -78,7 +78,7 @@ def test_log_prob(self, rng: jax.Array): ) np.testing.assert_allclose(expected, actual, atol=1e-5, rtol=1e-5) - def test_sample(self, rng: jax.Array): + def test_sample(self, rng: jnp.ndarray): mean = jnp.array([1., 2.]) cov = jnp.diag(jnp.array([1., 4.])) g = gaussian.Gaussian.from_mean_and_cov(mean, cov) @@ -89,7 +89,7 @@ def test_sample(self, rng: jax.Array): np.testing.assert_allclose(sample_mean, mean, atol=3. * 2. / 100.) np.testing.assert_allclose(sample_cov, cov, atol=2e-1) - def test_w2_dist(self, rng: jax.Array): + def test_w2_dist(self, rng: jnp.ndarray): # make sure distance between a random normal and itself is 0 rng, subrng = jax.random.split(rng) n = gaussian.Gaussian.from_random(rng=subrng, n_dimensions=3) @@ -118,7 +118,7 @@ def test_w2_dist(self, rng: jax.Array): expected = delta_mean + delta_sigma np.testing.assert_allclose(expected, w2, rtol=1e-6, atol=1e-6) - def test_transport(self, rng: jax.Array): + def test_transport(self, rng: jnp.ndarray): diag0 = jnp.array([1.]) diag1 = jnp.array([4.]) g0 = gaussian.Gaussian( @@ -134,14 +134,14 @@ def test_transport(self, rng: jax.Array): expected = 2. * points + 1. 
np.testing.assert_allclose(expected, actual, atol=1e-5, rtol=1e-5) - def test_flatten_unflatten(self, rng: jax.Array): + def test_flatten_unflatten(self, rng: jnp.ndarray): g = gaussian.Gaussian.from_random(rng, n_dimensions=3) children, aux_data = jax.tree_util.tree_flatten(g) g_new = jax.tree_util.tree_unflatten(aux_data, children) assert g == g_new - def test_pytree_mapping(self, rng: jax.Array): + def test_pytree_mapping(self, rng: jnp.ndarray): g = gaussian.Gaussian.from_random(rng, n_dimensions=3) g_x_2 = jax.tree_map(lambda x: 2 * x, g) diff --git a/tests/tools/gaussian_mixture/linalg_test.py b/tests/tools/gaussian_mixture/linalg_test.py index 4db928264..6fedb13ae 100644 --- a/tests/tools/gaussian_mixture/linalg_test.py +++ b/tests/tools/gaussian_mixture/linalg_test.py @@ -22,7 +22,7 @@ @pytest.mark.fast() class TestLinalg: - def test_get_mean_and_var(self, rng: jax.Array): + def test_get_mean_and_var(self, rng: jnp.ndarray): points = jax.random.normal(key=rng, shape=(10, 2)) weights = jnp.ones(10) expected_mean = jnp.mean(points, axis=0) @@ -33,7 +33,7 @@ def test_get_mean_and_var(self, rng: jax.Array): np.testing.assert_allclose(expected_mean, actual_mean, atol=1E-5, rtol=1E-5) np.testing.assert_allclose(expected_var, actual_var, atol=1E-5, rtol=1E-5) - def test_get_mean_and_var_nonuniform_weights(self, rng: jax.Array): + def test_get_mean_and_var_nonuniform_weights(self, rng: jnp.ndarray): points = jax.random.normal(key=rng, shape=(10, 2)) weights = jnp.concatenate([jnp.ones(5), jnp.zeros(5)], axis=-1) expected_mean = jnp.mean(points[:5], axis=0) @@ -44,7 +44,7 @@ def test_get_mean_and_var_nonuniform_weights(self, rng: jax.Array): np.testing.assert_allclose(expected_mean, actual_mean, rtol=1e-6, atol=1e-6) np.testing.assert_allclose(expected_var, actual_var, rtol=1e-6, atol=1e-6) - def test_get_mean_and_cov(self, rng: jax.Array): + def test_get_mean_and_cov(self, rng: jnp.ndarray): points = jax.random.normal(key=rng, shape=(10, 2)) weights = jnp.ones(10) 
expected_mean = jnp.mean(points, axis=0) @@ -55,7 +55,7 @@ def test_get_mean_and_cov(self, rng: jax.Array): np.testing.assert_allclose(expected_mean, actual_mean, atol=1e-5, rtol=1e-5) np.testing.assert_allclose(expected_cov, actual_cov, atol=1e-5, rtol=1e-5) - def test_get_mean_and_cov_nonuniform_weights(self, rng: jax.Array): + def test_get_mean_and_cov_nonuniform_weights(self, rng: jnp.ndarray): points = jax.random.normal(key=rng, shape=(10, 2)) weights = jnp.concatenate([jnp.ones(5), jnp.zeros(5)], axis=-1) expected_mean = jnp.mean(points[:5], axis=0) @@ -66,7 +66,7 @@ def test_get_mean_and_cov_nonuniform_weights(self, rng: jax.Array): np.testing.assert_allclose(expected_mean, actual_mean, rtol=1e-6, atol=1e-6) np.testing.assert_allclose(expected_cov, actual_cov, rtol=1e-6, atol=1e-6) - def test_flat_to_tril(self, rng: jax.Array): + def test_flat_to_tril(self, rng: jnp.ndarray): size = 3 x = jax.random.normal(key=rng, shape=(5, 4, size * (size + 1) // 2)) m = linalg.flat_to_tril(x, size) @@ -86,7 +86,7 @@ def test_flat_to_tril(self, rng: jax.Array): actual = linalg.tril_to_flat(m) np.testing.assert_allclose(x, actual) - def test_tril_to_flat(self, rng: jax.Array): + def test_tril_to_flat(self, rng: jnp.ndarray): size = 3 m = jax.random.normal(key=rng, shape=(5, 4, size, size)) for i in range(size): @@ -103,7 +103,7 @@ def test_tril_to_flat(self, rng: jax.Array): inverted = linalg.flat_to_tril(flat, size) np.testing.assert_allclose(m, inverted) - def test_apply_to_diag(self, rng: jax.Array): + def test_apply_to_diag(self, rng: jnp.ndarray): size = 3 m = jax.random.normal(key=rng, shape=(5, 4, size, size)) mnew = linalg.apply_to_diag(m, jnp.exp) @@ -114,7 +114,7 @@ def test_apply_to_diag(self, rng: jax.Array): else: np.testing.assert_allclose(jnp.exp(m[..., i, j]), mnew[..., i, j]) - def test_matrix_powers(self, rng: jax.Array): + def test_matrix_powers(self, rng: jnp.ndarray): rng, subrng = jax.random.split(rng) m = jax.random.normal(key=subrng, shape=(4, 4)) m 
+= jnp.swapaxes(m, axis1=-2, axis2=-1) # symmetric @@ -125,7 +125,7 @@ def test_matrix_powers(self, rng: jax.Array): np.testing.assert_allclose(m, actual[0], rtol=1.e-5) np.testing.assert_allclose(inv_m, actual[1], rtol=1.e-4) - def test_invmatvectril(self, rng: jax.Array): + def test_invmatvectril(self, rng: jnp.ndarray): rng, subrng = jax.random.split(rng) m = jax.random.normal(key=subrng, shape=(2, 2)) m += jnp.swapaxes(m, axis1=-2, axis2=-1) # symmetric @@ -138,7 +138,7 @@ def test_invmatvectril(self, rng: jax.Array): actual = linalg.invmatvectril(m=cholesky, x=x, lower=True) np.testing.assert_allclose(expected, actual, atol=1e-4, rtol=1.e-4) - def test_get_random_orthogonal(self, rng: jax.Array): + def test_get_random_orthogonal(self, rng: jnp.ndarray): rng, subrng = jax.random.split(rng) q = linalg.get_random_orthogonal(rng=subrng, dim=3) qt = jnp.transpose(q) diff --git a/tests/tools/gaussian_mixture/probabilities_test.py b/tests/tools/gaussian_mixture/probabilities_test.py index 4924924df..5d28a52aa 100644 --- a/tests/tools/gaussian_mixture/probabilities_test.py +++ b/tests/tools/gaussian_mixture/probabilities_test.py @@ -39,7 +39,7 @@ def test_log_probs(self): np.testing.assert_allclose(jnp.sum(probs), 1.0, rtol=1e-6, atol=1e-6) np.testing.assert_array_equal(probs > 0., True) - def test_from_random(self, rng: jax.Array): + def test_from_random(self, rng: jnp.ndarray): n_dimensions = 4 pp = probabilities.Probabilities.from_random( rng=rng, n_dimensions=n_dimensions, stdev=0.1 @@ -51,7 +51,7 @@ def test_from_probs(self): pp = probabilities.Probabilities.from_probs(probs) np.testing.assert_allclose(probs, pp.probs(), rtol=1e-6, atol=1e-6) - def test_sample(self, rng: jax.Array): + def test_sample(self, rng: jnp.ndarray): p = 0.4 probs = jnp.array([p, 1. 
- p]) pp = probabilities.Probabilities.from_probs(probs) diff --git a/tests/tools/gaussian_mixture/scale_tril_test.py b/tests/tools/gaussian_mixture/scale_tril_test.py index 3e53fd543..36643b6d7 100644 --- a/tests/tools/gaussian_mixture/scale_tril_test.py +++ b/tests/tools/gaussian_mixture/scale_tril_test.py @@ -47,7 +47,7 @@ def test_log_det_covariance(self, chol: scale_tril.ScaleTriL): actual = chol.log_det_covariance() np.testing.assert_almost_equal(actual, expected) - def test_from_random(self, rng: jax.Array): + def test_from_random(self, rng: jnp.ndarray): n_dimensions = 4 cov = scale_tril.ScaleTriL.from_random( rng=rng, n_dimensions=n_dimensions, stdev=0.1 @@ -56,7 +56,7 @@ def test_from_random(self, rng: jax.Array): cov.cholesky().shape, (n_dimensions, n_dimensions) ) - def test_from_cholesky(self, rng: jax.Array): + def test_from_cholesky(self, rng: jnp.ndarray): n_dimensions = 4 cholesky = scale_tril.ScaleTriL.from_random( rng=rng, n_dimensions=n_dimensions, stdev=1. @@ -64,7 +64,7 @@ def test_from_cholesky(self, rng: jax.Array): scale = scale_tril.ScaleTriL.from_cholesky(cholesky) np.testing.assert_allclose(cholesky, scale.cholesky(), atol=1e-4, rtol=1e-4) - def test_w2_dist(self, rng: jax.Array): + def test_w2_dist(self, rng: jnp.ndarray): # make sure distance between a random normal and itself is 0 rng, subrng = jax.random.split(rng) s = scale_tril.ScaleTriL.from_random(rng=subrng, n_dimensions=3) @@ -85,7 +85,7 @@ def test_w2_dist(self, rng: jax.Array): delta_sigma = jnp.sum((jnp.sqrt(diag0) - jnp.sqrt(diag1)) ** 2.) 
np.testing.assert_allclose(delta_sigma, w2, atol=1e-4, rtol=1e-4) - def test_transport(self, rng: jax.Array): + def test_transport(self, rng: jnp.ndarray): size = 4 rng, subrng0, subrng1 = jax.random.split(rng, num=3) diag0 = jnp.exp(jax.random.normal(key=subrng0, shape=(size,))) @@ -99,14 +99,14 @@ def test_transport(self, rng: jax.Array): expected = x * jnp.sqrt(diag1)[None] / jnp.sqrt(diag0)[None] np.testing.assert_allclose(expected, transported, atol=1e-4, rtol=1e-4) - def test_flatten_unflatten(self, rng: jax.Array): + def test_flatten_unflatten(self, rng: jnp.ndarray): scale = scale_tril.ScaleTriL.from_random(rng=rng, n_dimensions=3) children, aux_data = jax.tree_util.tree_flatten(scale) scale_new = jax.tree_util.tree_unflatten(aux_data, children) np.testing.assert_array_equal(scale.params, scale_new.params) assert scale == scale_new - def test_pytree_mapping(self, rng: jax.Array): + def test_pytree_mapping(self, rng: jnp.ndarray): scale = scale_tril.ScaleTriL.from_random(rng=rng, n_dimensions=3) scale_x_2 = jax.tree_map(lambda x: 2 * x, scale) np.testing.assert_allclose(2. 
* scale.params, scale_x_2.params) diff --git a/tests/tools/k_means_test.py b/tests/tools/k_means_test.py index 55cacde02..9b504a82d 100644 --- a/tests/tools/k_means_test.py +++ b/tests/tools/k_means_test.py @@ -31,7 +31,7 @@ def make_blobs( *args: Any, cost_fn: Optional[Literal["sqeucl", "cosine"]] = None, **kwargs: Any -) -> Tuple[Union[jax.Array, pointcloud.PointCloud], jax.Array, jax.Array]: +) -> Tuple[Union[jnp.ndarray, pointcloud.PointCloud], jnp.ndarray, jnp.ndarray]: X, y, c = datasets.make_blobs(*args, return_centers=True, **kwargs) X, y, c = jnp.asarray(X), jnp.asarray(y), jnp.asarray(c) if cost_fn is None: @@ -47,10 +47,10 @@ def make_blobs( def compute_assignment( - x: jax.Array, - centers: jax.Array, - weights: Optional[jax.Array] = None -) -> Tuple[jax.Array, float]: + x: jnp.ndarray, + centers: jnp.ndarray, + weights: Optional[jnp.ndarray] = None +) -> Tuple[jnp.ndarray, float]: if weights is None: weights = jnp.ones(x.shape[0]) cost_matrix = pointcloud.PointCloud(x, centers).cost_matrix @@ -63,7 +63,7 @@ def compute_assignment( class TestKmeansPlusPlus: @pytest.mark.fast.with_args("n_local_trials", [None, 3], only_fast=-1) - def test_n_local_trials(self, rng: jax.Array, n_local_trials): + def test_n_local_trials(self, rng: jnp.ndarray, n_local_trials): n, k = 100, 4 rng1, rng2 = jax.random.split(rng) geom, _, c = make_blobs( @@ -78,7 +78,7 @@ def test_n_local_trials(self, rng: jax.Array, n_local_trials): assert shift1 > shift2 @pytest.mark.fast.with_args("k", [3, 5], only_fast=0) - def test_matches_sklearn(self, rng: jax.Array, k: int): + def test_matches_sklearn(self, rng: jnp.ndarray, k: int): ndim = 2 geom, _, _ = make_blobs( n_samples=100, @@ -102,9 +102,9 @@ def test_matches_sklearn(self, rng: jax.Array, k: int): ) assert jnp.abs(pred_inertia - gt_inertia) <= 200 - def test_initialization_differentiable(self, rng: jax.Array): + def test_initialization_differentiable(self, rng: jnp.ndarray): - def callback(x: jax.Array) -> float: + def 
callback(x: jnp.ndarray) -> float: geom = pointcloud.PointCloud(x) centers = k_means._k_means_plus_plus(geom, k=3, rng=rng) _, inertia = compute_assignment(x, centers) @@ -122,7 +122,7 @@ class TestKmeans: @pytest.mark.fast() @pytest.mark.parametrize("k", [1, 6]) - def test_k_means_output(self, rng: jax.Array, k: int): + def test_k_means_output(self, rng: jnp.ndarray, k: int): max_iter, ndim = 10, 4 geom, gt_assignment, _ = make_blobs( n_samples=50, n_features=ndim, centers=k, random_state=42 @@ -160,7 +160,7 @@ def test_k_means_simple_example(self): ["k-means++", "random", "callable", "wrong-callable"], only_fast=1, ) - def test_init_method(self, rng: jax.Array, init: str): + def test_init_method(self, rng: jnp.ndarray, init: str): if init == "callable": init_fn = lambda geom, k, _: geom.x[:k] elif init == "wrong-callable": @@ -176,7 +176,7 @@ def test_init_method(self, rng: jax.Array, init: str): else: _ = k_means.k_means(geom, k, init=init_fn) - def test_k_means_plus_plus_better_than_random(self, rng: jax.Array): + def test_k_means_plus_plus_better_than_random(self, rng: jnp.ndarray): k = 5 rng1, rng2 = jax.random.split(rng, 2) geom, _, _ = make_blobs(n_samples=50, centers=k, random_state=10) @@ -189,7 +189,7 @@ def test_k_means_plus_plus_better_than_random(self, rng: jax.Array): assert res_kpp.iteration < res_random.iteration assert res_kpp.error <= res_random.error - def test_larger_n_init_helps(self, rng: jax.Array): + def test_larger_n_init_helps(self, rng: jnp.ndarray): k = 10 geom, _, _ = make_blobs(n_samples=150, centers=k, random_state=0) @@ -199,7 +199,7 @@ def test_larger_n_init_helps(self, rng: jax.Array): assert res_larger_n_init.error < res.error @pytest.mark.parametrize("max_iter", [8, 16]) - def test_store_inner_errors(self, rng: jax.Array, max_iter: int): + def test_store_inner_errors(self, rng: jnp.ndarray, max_iter: int): ndim, k = 10, 4 geom, _, _ = make_blobs( n_samples=40, n_features=ndim, centers=k, random_state=43 @@ -215,7 +215,7 @@ def 
test_store_inner_errors(self, rng: jax.Array, max_iter: int): # check if error is decreasing np.testing.assert_array_equal(jnp.diff(errors[::-1]) >= 0., True) - def test_strict_tolerance(self, rng: jax.Array): + def test_strict_tolerance(self, rng: jnp.ndarray): k = 11 geom, _, _ = make_blobs(n_samples=200, centers=k, random_state=39) @@ -229,7 +229,7 @@ def test_strict_tolerance(self, rng: jax.Array): @pytest.mark.parametrize( "tol", [1e-3, 0.], ids=["weak-convergence", "strict-convergence"] ) - def test_convergence_force_scan(self, rng: jax.Array, tol: float): + def test_convergence_force_scan(self, rng: jnp.ndarray, tol: float): k, n_iter = 9, 20 geom, _, _ = make_blobs(n_samples=100, centers=k, random_state=37) @@ -247,7 +247,7 @@ def test_convergence_force_scan(self, rng: jax.Array, tol: float): assert res.iteration == n_iter np.testing.assert_array_equal(res.inner_errors == -1, False) - def test_k_means_min_iterations(self, rng: jax.Array): + def test_k_means_min_iterations(self, rng: jnp.ndarray): k, min_iter = 8, 12 geom, _, _ = make_blobs(n_samples=160, centers=k, random_state=38) @@ -264,7 +264,7 @@ def test_k_means_min_iterations(self, rng: jax.Array): assert res.converged assert jnp.sum(res.inner_errors != -1) >= min_iter - def test_weight_scaling_effects_only_inertia(self, rng: jax.Array): + def test_weight_scaling_effects_only_inertia(self, rng: jnp.ndarray): k = 10 rng1, rng2 = jax.random.split(rng) geom, _, _ = make_blobs(n_samples=130, centers=k, random_state=3) @@ -285,7 +285,7 @@ def test_weight_scaling_effects_only_inertia(self, rng: jax.Array): ) @pytest.mark.fast() - def test_empty_weights(self, rng: jax.Array): + def test_empty_weights(self, rng: jnp.ndarray): n, ndim, k, d = 20, 2, 3, 5. 
gen = np.random.RandomState(0) x = gen.normal(size=(n, ndim)) @@ -333,10 +333,10 @@ def test_cosine_cost_fn(self): @pytest.mark.fast.with_args("init", ["k-means++", "random"], only_fast=0) def test_k_means_jitting( - self, rng: jax.Array, init: Literal["k-means++", "random"] + self, rng: jnp.ndarray, init: Literal["k-means++", "random"] ): - def callback(x: jax.Array) -> k_means.KMeansOutput: + def callback(x: jnp.ndarray) -> k_means.KMeansOutput: return k_means.k_means( x, k=k, init=init, store_inner_errors=True, rng=rng ) @@ -365,10 +365,10 @@ def callback(x: jax.Array) -> k_means.KMeansOutput: (False, True)], ids=["jit-while-loop", "nojit-for-loop"]) def test_k_means_differentiability( - self, rng: jax.Array, jit: bool, force_scan: bool + self, rng: jnp.ndarray, jit: bool, force_scan: bool ): - def inertia(x: jax.Array, w: jax.Array) -> float: + def inertia(x: jnp.ndarray, w: jnp.ndarray) -> float: return k_means.k_means( x, k=k, @@ -404,7 +404,7 @@ def inertia(x: jax.Array, w: jax.Array) -> float: @pytest.mark.parametrize("tol", [1e-3, 0.]) @pytest.mark.parametrize(("n", "k"), [(37, 4), (128, 6)]) def test_clustering_matches_sklearn( - self, rng: jax.Array, n: int, k: int, tol: float + self, rng: jnp.ndarray, n: int, k: int, tol: float ): x, _, _ = make_blobs(n_samples=n, centers=k, random_state=41) diff --git a/tests/tools/segment_sinkhorn_test.py b/tests/tools/segment_sinkhorn_test.py index 119dbf93a..6e8a8fb8c 100644 --- a/tests/tools/segment_sinkhorn_test.py +++ b/tests/tools/segment_sinkhorn_test.py @@ -26,7 +26,7 @@ class TestSegmentSinkhorn: @pytest.fixture(autouse=True) - def setUp(self, rng: jax.Array): + def setUp(self, rng: jnp.ndarray): self._dim = 4 self._num_points = 13, 17 self._max_measure_size = 20 diff --git a/tests/tools/sinkhorn_divergence_test.py b/tests/tools/sinkhorn_divergence_test.py index 07bcf535e..0f3e56bfc 100644 --- a/tests/tools/sinkhorn_divergence_test.py +++ b/tests/tools/sinkhorn_divergence_test.py @@ -28,7 +28,7 @@ class 
TestSinkhornDivergence: @pytest.fixture(autouse=True) - def setUp(self, rng: jax.Array): + def setUp(self, rng: jnp.ndarray): self._dim = 4 self._num_points = 13, 17 self.rng, *rngs = jax.random.split(rng, 3) @@ -389,7 +389,7 @@ def test_euclidean_momentum_params( class TestSinkhornDivergenceGrad: @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): + def initialize(self, rng: jnp.ndarray): self._dim = 3 self._num_points = 13, 12 self.rng, *rngs = jax.random.split(rng, 3) @@ -403,7 +403,7 @@ def test_gradient_generic_point_cloud_wrapper(self): x = jax.random.uniform(rngs[0], (self._num_points[0], self._dim)) y = jax.random.uniform(rngs[1], (self._num_points[1], self._dim)) - def loss_fn(cloud_a: jax.Array, cloud_b: jax.Array) -> float: + def loss_fn(cloud_a: jnp.ndarray, cloud_b: jnp.ndarray) -> float: div = sinkhorn_divergence.sinkhorn_divergence( pointcloud.PointCloud, cloud_a, diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py index 4f3a12c10..2432a2dee 100644 --- a/tests/tools/soft_sort_test.py +++ b/tests/tools/soft_sort_test.py @@ -27,14 +27,14 @@ class TestSoftSort: @pytest.mark.parametrize("shape", [(20,), (20, 1)]) - def test_sort_one_array(self, rng: jax.Array, shape: Tuple[int, ...]): + def test_sort_one_array(self, rng: jnp.ndarray, shape: Tuple[int, ...]): x = jax.random.uniform(rng, shape) xs = soft_sort.sort(x, axis=0) np.testing.assert_array_equal(x.shape, xs.shape) np.testing.assert_array_equal(jnp.diff(xs, axis=0) >= 0.0, True) - def test_sort_array_squashing_momentum(self, rng: jax.Array): + def test_sort_array_squashing_momentum(self, rng: jnp.ndarray): shape = (33, 1) x = jax.random.uniform(rng, shape) xs_lin = soft_sort.sort( @@ -61,7 +61,7 @@ def test_sort_array_squashing_momentum(self, rng: jax.Array): @pytest.mark.fast() @pytest.mark.parametrize("k", [-1, 4, 100]) - def test_topk_one_array(self, rng: jax.Array, k: int): + def test_topk_one_array(self, rng: jnp.ndarray, k: int): n = 20 x = 
jax.random.uniform(rng, (n,)) axis = 0 @@ -75,7 +75,7 @@ def test_topk_one_array(self, rng: jax.Array, k: int): np.testing.assert_allclose(xs, jnp.sort(x, axis=axis)[-outsize:], atol=0.01) @pytest.mark.fast.with_args("topk", [-1, 2, 11], only_fast=-1) - def test_sort_batch(self, rng: jax.Array, topk: int): + def test_sort_batch(self, rng: jnp.ndarray, topk: int): x = jax.random.uniform(rng, (32, 10, 6, 4)) axis = 1 xs = soft_sort.sort(x, axis=axis, topk=topk) @@ -85,7 +85,7 @@ def test_sort_batch(self, rng: jax.Array, topk: int): np.testing.assert_array_equal(xs.shape, expected_shape) np.testing.assert_array_equal(jnp.diff(xs, axis=axis) >= 0.0, True) - def test_multivariate_cdf_quantiles(self, rng: jax.Array): + def test_multivariate_cdf_quantiles(self, rng: jnp.ndarray): n, d = 512, 3 key1, key2, key3 = jax.random.split(rng, 3) @@ -108,7 +108,7 @@ def test_multivariate_cdf_quantiles(self, rng: jax.Array): # Check passing custom sampler, must be still symmetric / centered on {.5}^d # Check passing custom epsilon also works. - def ball_sampler(k: jax.Array, s: Tuple[int, int]) -> jax.Array: + def ball_sampler(k: jnp.ndarray, s: Tuple[int, int]) -> jnp.ndarray: return 0.5 * (jax.random.ball(k, d=s[1], p=4, shape=(s[0],)) + 1.) 
num_target_samples = 473 @@ -128,7 +128,7 @@ def mv_c_q(inputs, num_target_samples, rng, epsilon): np.testing.assert_allclose(z, qua(q), atol=atol) @pytest.mark.fast.with_args("axis,jit", [(0, False), (1, True)], only_fast=0) - def test_ranks(self, axis, rng: jax.Array, jit: bool): + def test_ranks(self, axis, rng: jnp.ndarray, jit: bool): rng1, rng2 = jax.random.split(rng, 2) num_targets = 13 x = jax.random.uniform(rng1, (8, 5, 2)) @@ -163,7 +163,7 @@ def test_ranks(self, axis, rng: jax.Array, jit: bool): np.testing.assert_allclose(ranks, expected_ranks, atol=0.3, rtol=0.1) @pytest.mark.fast.with_args("axis,jit", [(0, False), (1, True)], only_fast=0) - def test_topk_mask(self, axis, rng: jax.Array, jit: bool): + def test_topk_mask(self, axis, rng: jnp.ndarray, jit: bool): def boolean_topk_mask(u, k): return u >= jnp.flip(jax.numpy.sort(u))[k - 1] @@ -194,7 +194,7 @@ def test_quantile(self, q: float): np.testing.assert_allclose(x_q, q, atol=1e-3, rtol=1e-2) - def test_quantile_on_several_axes(self, rng: jax.Array): + def test_quantile_on_several_axes(self, rng: jnp.ndarray): batch, height, width, channels = 4, 47, 45, 3 x = jax.random.uniform(rng, shape=(batch, height, width, channels)) q = soft_sort.quantile( @@ -208,7 +208,7 @@ def test_quantile_on_several_axes(self, rng: jax.Array): @pytest.mark.fast() @pytest.mark.parametrize("jit", [False, True]) - def test_quantiles(self, rng: jax.Array, jit: bool): + def test_quantiles(self, rng: jnp.ndarray, jit: bool): inputs = jax.random.uniform(rng, (100, 2, 3)) q = jnp.array([.1, .8, .4]) quantile_fn = soft_sort.quantile @@ -220,7 +220,7 @@ def test_quantiles(self, rng: jax.Array, jit: bool): np.testing.assert_allclose(m1.mean(axis=[1, 2]), q, atol=5e-2) @pytest.mark.parametrize("jit", [False, True]) - def test_soft_quantile_normalization(self, rng: jax.Array, jit: bool): + def test_soft_quantile_normalization(self, rng: jnp.ndarray, jit: bool): rngs = jax.random.split(rng, 2) x = jax.random.uniform(rngs[0], 
shape=(100,)) mu, sigma = 2.0, 1.2 @@ -237,7 +237,7 @@ def test_soft_quantile_normalization(self, rng: jax.Array, jit: bool): [mu_target, sigma_target], rtol=0.05) - def test_sort_with(self, rng: jax.Array): + def test_sort_with(self, rng: jnp.ndarray): n, d = 20, 4 inputs = jax.random.uniform(rng, shape=(n, d)) criterion = jnp.linspace(0.1, 1.2, n) @@ -269,7 +269,7 @@ def test_quantize(self, jit: bool): np.testing.assert_allclose(min_distances, min_distances, atol=0.05) @pytest.mark.parametrize("implicit", [False, True]) - def test_soft_sort_jacobian(self, rng: jax.Array, implicit: bool): + def test_soft_sort_jacobian(self, rng: jnp.ndarray, implicit: bool): # Add a ridge when using JAX solvers. try: from ott.solvers.linear import lineax_implicit # noqa: F401 @@ -283,7 +283,7 @@ def test_soft_sort_jacobian(self, rng: jax.Array, implicit: bool): z = jax.random.uniform(rngs[0], ((b, n))) random_dir = jax.random.normal(rngs[1], (b,)) / b - def loss_fn(logits: jax.Array) -> float: + def loss_fn(logits: jnp.ndarray) -> float: im_d = None if implicit: # Ridge parameters are only used when using JAX's CG. From 8fa3683a5b91d5848231236cb97d5b1632d1538d Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 28 Nov 2023 14:53:21 +0100 Subject: [PATCH 032/186] move dataloader from tests to module --- src/ott/neural/data/dataloaders.py | 131 +++++++++++++++++++++++--- tests/neural/conftest.py | 143 +++-------------------------- 2 files changed, 127 insertions(+), 147 deletions(-) diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py index 466460384..121af2c94 100644 --- a/src/ott/neural/data/dataloaders.py +++ b/src/ott/neural/data/dataloaders.py @@ -11,22 +11,123 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-#import tensorflow as tf +from typing import Dict, Iterator, Mapping, Optional -class ConditionalDataLoader: #TODO(@MUCDK) uncomment, resolve installation issues with TF - pass +import numpy as np - #def __init__( - # self, rng: jax.random.KeyArray, dataloaders: Dict[str, tf.Dataloader], - # p: jnp.ndarray - #) -> None: - # super().__init__() - # self.rng = rng - # self.conditions = dataloaders.keys() - # self.p = p - #def __next__(self) -> jnp.ndarray: - # self.rng, rng = jax.random.split(self.rng, 2) - # condition = jax.random.choice(rng, self.conditions, p=self.p) - # return next(self.dataloaders[condition]) +__all__ =[ "OTDataLoader", "ConditionalDataLoader"] + +class OTDataLoader: + """Data loader for OT problems. + + Args: + batch_size: Number of samples per batch. + source_lin: Linear part of the source measure. + source_quad: Quadratic part of the source measure. + target_lin: Linear part of the target measure. + target_quad: Quadratic part of the target measure. + source_conditions: Conditions of the source measure. + target_conditions: Conditions of the target measure. + seed: Random seed. 
+ """ + + def __init__( + self, + batch_size: int = 64, + source_lin: Optional[np.ndarray] = None, + source_quad: Optional[np.ndarray] = None, + target_lin: Optional[np.ndarray] = None, + target_quad: Optional[np.ndarray] = None, + source_conditions: Optional[np.ndarray] = None, + target_conditions: Optional[np.ndarray] = None, + seed: int = 0, + ) -> None: + super().__init__() + if source_lin is not None: + if source_quad is not None: + assert len(source_lin) == len(source_quad) + self.n_source = len(source_lin) + else: + self.n_source = len(source_lin) + else: + self.n_source = len(source_quad) + if source_conditions is not None: + assert len(source_conditions) == self.n_source + if target_lin is not None: + if target_quad is not None: + assert len(target_lin) == len(target_quad) + self.n_target = len(target_lin) + else: + self.n_target = len(target_lin) + else: + self.n_target = len(target_quad) + if target_conditions is not None: + assert len(target_conditions) == self.n_target + + self.source_lin = source_lin + self.target_lin = target_lin + self.source_quad = source_quad + self.target_quad = target_quad + self.source_conditions = source_conditions + self.target_conditions = target_conditions + self.batch_size = batch_size + self.rng = np.random.default_rng(seed=seed) + + def __next__(self) -> Mapping[str, np.ndarray]: + inds_source = self.rng.choice(self.n_source, size=[self.batch_size]) + inds_target = self.rng.choice(self.n_target, size=[self.batch_size]) + return { + "source_lin": + self.source_lin[inds_source, :] + if self.source_lin is not None else None, + "source_quad": + self.source_quad[inds_source, :] + if self.source_quad is not None else None, + "target_lin": + self.target_lin[inds_target, :] + if self.target_lin is not None else None, + "target_quad": + self.target_quad[inds_target, :] + if self.target_quad is not None else None, + "source_conditions": + self.source_conditions[inds_source, :] + if self.source_conditions is not None else None, + 
"target_conditions": + self.target_conditions[inds_target, :] + if self.target_conditions is not None else None, + } + + +class ConditionalDataLoader: + """Data loader for OT problems with conditions. + + This data loader wraps several data loaders and samples from them according to their conditions. + + Args: + dataloaders: Dictionary of data loaders with keys corresponding to conditions. + p: Probability of sampling from each data loader. + seed: Random seed. + + """ + + def __init__( + self, + dataloaders: Dict[str, Iterator], + p: np.ndarray, + seed: int = 0 + ) -> None: + super().__init__() + self.dataloaders = dataloaders + self.conditions = list(dataloaders.keys()) + self.p = p + self.rng = np.random.default_rng(seed=seed) + + def __next__(self, cond: str = None) -> Mapping[str, np.ndarray]: + if cond is not None: + if cond not in self.conditions: + raise ValueError(f"Condition {cond} not in {self.conditions}") + return next(self.dataloaders[cond]) + idx = self.rng.choice(len(self.conditions), p=self.p) + return next(self.dataloaders[self.conditions[idx]]) diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 0dd65ba57..edb635e90 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -1,60 +1,7 @@ -from typing import Dict, Iterator, Mapping, Optional - import numpy as np import pytest - -class DataLoader: - - def __init__( - self, - source_data: np.ndarray, - target_data: np.ndarray, - batch_size: int = 64, - source_conditions: Optional[np.ndarray] = None, - target_conditions: Optional[np.ndarray] = None, - ) -> None: - super().__init__() - self.source_data = source_data - self.target_data = target_data - self.source_conditions = source_conditions - self.target_conditions = target_conditions - self.batch_size = batch_size - self.rng = np.random.default_rng(seed=0) - - def __next__(self) -> Mapping[str, np.ndarray]: - inds_source = self.rng.choice(len(self.source_data), size=[self.batch_size]) - inds_target = 
self.rng.choice(len(self.target_data), size=[self.batch_size]) - return { - "source_lin": - self.source_data[inds_source, :], - "target_lin": - self.target_data[inds_target, :], - "source_conditions": - self.source_conditions[inds_source, :] - if self.source_conditions is not None else None, - "target_conditions": - self.target_conditions[inds_target, :] - if self.target_conditions is not None else None, - } - - -class ConditionalDataLoader: - - def __init__(self, dataloaders: Dict[str, Iterator], p: np.ndarray) -> None: - super().__init__() - self.dataloaders = dataloaders - self.conditions = list(dataloaders.keys()) - self.p = p - self.rng = np.random.default_rng(seed=0) - - def __next__(self, cond: str = None) -> Mapping[str, np.ndarray]: - if cond is not None: - if cond not in self.conditions: - raise ValueError(f"Condition {cond} not in {self.conditions}") - return next(self.dataloaders[cond]) - idx = self.rng.choice(len(self.conditions), p=self.p) - return next(self.dataloaders[self.conditions[idx]]) +from ott.neural.data.dataloaders import ConditionalDataLoader, OTDataLoader @pytest.fixture(scope="module") @@ -63,7 +10,7 @@ def data_loader_gaussian(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 - return DataLoader(source, target, 16) + return OTDataLoader(source, target, 16) @pytest.fixture(scope="module") @@ -75,10 +22,10 @@ def data_loader_gaussian_conditional(): source_1 = rng.normal(size=(100, 2)) target_1 = rng.normal(size=(100, 2)) - 2.0 - dl0 = DataLoader( + dl0 = OTDataLoader( source_0, target_0, 16, source_conditions=np.zeros_like(source_0) * 0.0 ) - dl1 = DataLoader( + dl1 = OTDataLoader( source_1, target_1, 16, source_conditions=np.ones_like(source_1) * 1.0 ) @@ -93,75 +40,7 @@ def data_loader_gaussian_with_conditions(): target = rng.normal(size=(100, 2)) + 1.0 source_conditions = rng.normal(size=(100, 1)) target_conditions = rng.normal(size=(100, 1)) - 1.0 - return 
DataLoader(source, target, 16, source_conditions, target_conditions) - - -class GENOTDataLoader: - - def __init__( - self, - batch_size: int = 64, - source_lin: Optional[np.ndarray] = None, - source_quad: Optional[np.ndarray] = None, - target_lin: Optional[np.ndarray] = None, - target_quad: Optional[np.ndarray] = None, - source_conditions: Optional[np.ndarray] = None, - target_conditions: Optional[np.ndarray] = None, - ) -> None: - super().__init__() - if source_lin is not None: - if source_quad is not None: - assert len(source_lin) == len(source_quad) - self.n_source = len(source_lin) - else: - self.n_source = len(source_lin) - else: - self.n_source = len(source_quad) - if source_conditions is not None: - assert len(source_conditions) == self.n_source - if target_lin is not None: - if target_quad is not None: - assert len(target_lin) == len(target_quad) - self.n_target = len(target_lin) - else: - self.n_target = len(target_lin) - else: - self.n_target = len(target_quad) - if target_conditions is not None: - assert len(target_conditions) == self.n_target - - self.source_lin = source_lin - self.target_lin = target_lin - self.source_quad = source_quad - self.target_quad = target_quad - self.source_conditions = source_conditions - self.target_conditions = target_conditions - self.batch_size = batch_size - self.rng = np.random.default_rng(seed=0) - - def __next__(self) -> Mapping[str, np.ndarray]: - inds_source = self.rng.choice(self.n_source, size=[self.batch_size]) - inds_target = self.rng.choice(self.n_target, size=[self.batch_size]) - return { - "source_lin": - self.source_lin[inds_source, :] - if self.source_lin is not None else None, - "source_quad": - self.source_quad[inds_source, :] - if self.source_quad is not None else None, - "target_lin": - self.target_lin[inds_target, :] - if self.target_lin is not None else None, - "target_quad": - self.target_quad[inds_target, :] - if self.target_quad is not None else None, - "source_conditions": - 
self.source_conditions[inds_source, :] - if self.source_conditions is not None else None, - "target_conditions": - self.target_conditions[inds_target, :] - if self.target_conditions is not None else None, - } + return OTDataLoader(source, target, 16, source_conditions, target_conditions) @pytest.fixture(scope="module") @@ -170,7 +49,7 @@ def genot_data_loader_linear(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 - return GENOTDataLoader(16, source_lin=source, target_lin=target) + return OTDataLoader(16, source_lin=source, target_lin=target) @pytest.fixture(scope="module") @@ -181,7 +60,7 @@ def genot_data_loader_linear_conditional(): target = rng.normal(size=(100, 2)) + 1.0 conditions_source = rng.normal(size=(100, 4)) conditions_target = rng.normal(size=(100, 4)) - 1.0 - return GENOTDataLoader( + return OTDataLoader( 16, source_lin=source, target_lin=target, @@ -196,7 +75,7 @@ def genot_data_loader_quad(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 1)) + 1.0 - return GENOTDataLoader(16, source_quad=source, target_quad=target) + return OTDataLoader(16, source_quad=source, target_quad=target) @pytest.fixture(scope="module") @@ -206,7 +85,7 @@ def genot_data_loader_quad_conditional(): source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 1)) + 1.0 conditions = rng.normal(size=(100, 7)) - return GENOTDataLoader( + return OTDataLoader( 16, source_quad=source, target_quad=target, @@ -223,7 +102,7 @@ def genot_data_loader_fused(): target_q = rng.normal(size=(100, 1)) + 1.0 source_lin = rng.normal(size=(100, 2)) target_lin = rng.normal(size=(100, 2)) + 1.0 - return GENOTDataLoader( + return OTDataLoader( 16, source_lin=source_lin, source_quad=source_q, @@ -241,7 +120,7 @@ def genot_data_loader_fused_conditional(): source_lin = rng.normal(size=(100, 2)) target_lin = rng.normal(size=(100, 2)) + 1.0 conditions = rng.normal(size=(100, 7)) 
- return GENOTDataLoader( + return OTDataLoader( 16, source_lin=source_lin, source_quad=source_q, From 2e2f9f344822c98fcc1b54f606b44208a366f129 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 28 Nov 2023 15:33:39 +0100 Subject: [PATCH 033/186] add docstrings to neurcal networks --- src/ott/neural/data/dataloaders.py | 2 +- src/ott/neural/models/base_models.py | 18 +++++- src/ott/neural/models/models.py | 83 +++++++++++++++++++++++++++- 3 files changed, 99 insertions(+), 4 deletions(-) diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py index 121af2c94..938dabb96 100644 --- a/src/ott/neural/data/dataloaders.py +++ b/src/ott/neural/data/dataloaders.py @@ -16,8 +16,8 @@ import numpy as np +__all__ = ["OTDataLoader", "ConditionalDataLoader"] -__all__ =[ "OTDataLoader", "ConditionalDataLoader"] class OTDataLoader: """Data loader for OT problems. diff --git a/src/ott/neural/models/base_models.py b/src/ott/neural/models/base_models.py index c96ad5b29..8b5dc126a 100644 --- a/src/ott/neural/models/base_models.py +++ b/src/ott/neural/models/base_models.py @@ -21,6 +21,7 @@ class BaseNeuralVectorField(nn.Module, abc.ABC): + """Base class for neural vector field models.""" @abc.abstractmethod def __call__( @@ -29,11 +30,20 @@ def __call__( x: jnp.ndarray, condition: Optional[jnp.ndarray] = None, keys_model: Optional[jnp.ndarray] = None - ) -> jnp.ndarray: # noqa: D102): + ) -> jnp.ndarray: + """"Evaluate the vector field. + + Args: + t: Time. + x: Input data. + condition: Condition. + keys_model: Random keys for the model. + """ pass class BaseRescalingNet(nn.Module, abc.ABC): + """Base class for models to learn distributional rescaling factors.""" @abc.abstractmethod def __call__( @@ -41,4 +51,10 @@ def __call__( x: jnp.ndarray, condition: Optional[jnp.ndarray] = None ) -> jnp.ndarray: + """Evaluate the model. + + Args: + x: Input data. + condition: Condition. 
+ """ pass diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 80326d3ca..0c424c588 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -180,7 +180,9 @@ def __call__(self, x: jnp.ndarray) -> float: # noqa: D102 return z.squeeze() -class MLP(neuraldual.BaseW2NeuralDual): +class MLP( + neuraldual.BaseW2NeuralDual +): #TODO don't let this inherit from BaseW2NeuralDual """A generic, typically not-convex (w.r.t input) MLP. Args: @@ -418,12 +420,34 @@ class Block(nn.Module): @nn.compact def __call__(self, x): for i in range(self.num_layers): - x = nn.Dense(self.dim, name="fc{0}".format(i))(x) + x = nn.Dense(self.dim)(x) x = self.act_fn(x) return nn.Dense(self.out_dim)(x) class NeuralVectorField(BaseNeuralVectorField): + """Parameterized neural vector field. + + Each of the input, condition, and time embeddings are passed through a block + consisting of ``num_layers_per_block`` layers of dimension ``latent_embed_dim``, ``condition_embed_dim``, and ``time_embed_dim``, respectively. + The output of each block is concatenated and passed through a final block of dimension ``joint_hidden_dim``. + + Args: + output_dim: Dimensionality of the neural vector field. + condition_dim: Dimensionality of the conditioning vector. + latent_embed_dim: Dimensionality of the embedding of the data. + condition_embed_dim: Dimensionality of the embedding of the condition. + If ``None``, set to ``latent_embed_dim``. + t_embed_dim: Dimensionality of the time embedding. + If ``None``, set to ``latent_embed_dim``. + joint_hidden_dim: Dimensionality of the hidden layers of the joint network. + If ``None``, set to ``latent_embed_dim + condition_embed_dim + + t_embed_dim``. + num_layers_per_block: Number of layers per block. + act_fn: Activation function. + n_frequencies: Number of frequencies to use for the time embedding. 
+ + """ output_dim: int condition_dim: int latent_embed_dim: int @@ -435,6 +459,14 @@ class NeuralVectorField(BaseNeuralVectorField): n_frequencies: int = 128 def time_encoder(self, t: jnp.ndarray) -> jnp.array: + """Encode the time. + + Args: + t: Time. + + Returns: + Encoded time. + """ freq = 2 * jnp.arange(self.n_frequencies) * jnp.pi t = freq * t return jnp.concatenate((jnp.cos(t), jnp.sin(t)), axis=-1) @@ -469,7 +501,17 @@ def __call__( condition: Optional[jnp.ndarray], keys_model: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: + """Forward pass through the neural vector field. + + Args: + t: Time. + x: Data. + condition: Conditioning vector. + keys_model: Random number generator. + Returns: + Output of the neural vector field. + """ t = self.time_encoder(t) t = Block( dim=self.t_embed_dim, @@ -524,6 +566,16 @@ def create_train_state( optimizer: optax.OptState, input_dim: int, ) -> train_state.TrainState: + """Create the training state. + + Args: + rng: Random number generator. + optimizer: Optimizer. + input_dim: Dimensionality of the input. + + Returns: + Training state. + """ params = self.init( rng, jnp.ones((1, 1)), jnp.ones((1, input_dim)), jnp.ones((1, self.condition_dim)) @@ -534,6 +586,23 @@ def create_train_state( class Rescaling_MLP(BaseRescalingNet): + """Network to learn distributional rescaling factors based on a MLP. + + The input is passed through a block consisting of ``num_layers_per_block`` with size ``hidden_dim``. + If ``condition_dim`` is greater than 0, the conditioning vector is passed through a block of the same size. + Both outputs are concatenated and passed through another block of the same size. + + To ensure non-negativity of the output, the output is exponentiated. + + Args: + hidden_dim: Dimensionality of the hidden layers. + condition_dim: Dimensionality of the conditioning vector. + num_layers_per_block: Number of layers per block. + act_fn: Activation function. + + Returns: + Rescaling factors. 
+ """ hidden_dim: int condition_dim: int num_layers_per_block: int = 3 @@ -582,6 +651,16 @@ def create_train_state( optimizer: optax.OptState, input_dim: int, ) -> train_state.TrainState: + """Create the training state. + + Args: + rng: Random number generator. + optimizer: Optimizer. + input_dim: Dimensionality of the input. + + Returns: + Training state. + """ params = self.init( rng, jnp.ones((1, input_dim)), jnp.ones((1, self.condition_dim)) )["params"] From 8c71deb5229d8929cde52293c83dee6e98146b0b Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 28 Nov 2023 17:20:18 +0100 Subject: [PATCH 034/186] [ci skip] adapt type of scale_cost and cost_fn --- src/ott/neural/solvers/base_solver.py | 13 ++++++++--- src/ott/neural/solvers/genot.py | 31 ++++++++++++--------------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index 0ad159a8f..cc9a4c310 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -152,8 +152,10 @@ def _get_sinkhorn_match_fn( self, ot_solver: Any, epsilon: float = 1e-2, - cost_fn: Any = costs.SqEuclidean(), - scale_cost: Any = "mean", + cost_fn: costs.CostFn = costs.SqEuclidean(), + scale_cost: Union[bool, int, float, + Literal["mean", "max_norm", "max_bound", "max_cost", + "median"]] = "mean", tau_a: float = 1.0, tau_b: float = 1.0, *, @@ -187,7 +189,12 @@ def _get_gromov_match_fn( self, ot_solver: Any, cost_fn: Union[Any, Mapping[str, Any]], - scale_cost: Union[Any, Mapping[str, Any]], + scale_cost: Union[Union[bool, int, float, + Literal["mean", "max_norm", "max_bound", + "max_cost", "median"]], + Dict[str, Union[bool, int, float, + Literal["mean", "max_norm", "max_bound", + "max_cost", "median"]]]], tau_a: float, tau_b: float, fused_penalty: float, diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 7b0867e44..a71f34760 100644 --- a/src/ott/neural/solvers/genot.py +++ 
b/src/ott/neural/solvers/genot.py @@ -13,15 +13,7 @@ # limitations under the License. import functools import types -from typing import ( - Any, - Callable, - Dict, - Optional, - Tuple, - Type, - Union, -) +from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, Union import diffrax import jax @@ -35,15 +27,15 @@ from ott.geometry import costs from ott.neural.models.models import BaseNeuralVectorField from ott.neural.solvers.base_solver import ( - BaseNeuralSolver, - ResampleMixin, - UnbalancednessMixin, + BaseNeuralSolver, + ResampleMixin, + UnbalancednessMixin, ) from ott.neural.solvers.flows import ( - BaseFlow, - BaseTimeSampler, - ConstantNoiseFlow, - UniformSampler, + BaseFlow, + BaseTimeSampler, + ConstantNoiseFlow, + UniformSampler, ) from ott.solvers import was_solver from ott.solvers.linear import sinkhorn @@ -101,7 +93,12 @@ def __init__( ot_solver: Type[was_solver.WassersteinSolver], epsilon: float, cost_fn: Union[costs.CostFn, Dict[str, costs.CostFn]], - scale_cost: Union[Any, Dict[str, Any]], #TODO: replace `Any` + scale_cost: Union[Union[bool, int, float, + Literal["mean", "max_norm", "max_bound", + "max_cost", "median"]], + Dict[str, Union[bool, int, float, + Literal["mean", "max_norm", "max_bound", + "max_cost", "median"]]]], optimizer: Type[optax.GradientTransformation], flow: Type[BaseFlow] = ConstantNoiseFlow(0.0), time_sampler: Type[BaseTimeSampler] = UniformSampler(), From a25b6c22d6f0ec2feee8facb262287f1c8dd11b2 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 28 Nov 2023 17:53:27 +0100 Subject: [PATCH 035/186] [ci skip] clean code --- src/ott/neural/models/models.py | 2 +- src/ott/neural/solvers/genot.py | 21 ++++++++------------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 0c424c588..c8bcb00ce 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -419,7 +419,7 @@ class Block(nn.Module): 
@nn.compact def __call__(self, x): - for i in range(self.num_layers): + for _ in range(self.num_layers): x = nn.Dense(self.dim)(x) x = self.act_fn(x) return nn.Dense(self.out_dim)(x) diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index a71f34760..0e8de2705 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -13,7 +13,7 @@ # limitations under the License. import functools import types -from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Literal, Optional, Type, Union import diffrax import jax @@ -27,25 +27,20 @@ from ott.geometry import costs from ott.neural.models.models import BaseNeuralVectorField from ott.neural.solvers.base_solver import ( - BaseNeuralSolver, - ResampleMixin, - UnbalancednessMixin, + BaseNeuralSolver, + ResampleMixin, + UnbalancednessMixin, ) from ott.neural.solvers.flows import ( - BaseFlow, - BaseTimeSampler, - ConstantNoiseFlow, - UniformSampler, + BaseFlow, + BaseTimeSampler, + ConstantNoiseFlow, + UniformSampler, ) from ott.solvers import was_solver from ott.solvers.linear import sinkhorn from ott.solvers.quadratic import gromov_wasserstein -Match_fn_T = Callable[[jax.random.PRNGKeyArray, jnp.array, jnp.array], - Tuple[jnp.array, jnp.array, jnp.array, jnp.array]] -Match_latent_fn_T = Callable[[jax.random.PRNGKeyArray, jnp.array, jnp.array], - Tuple[jnp.array, jnp.array]] - __all__ = ["GENOT"] From 75437db12328dce43a9de1cf94caeeceed10944a Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 28 Nov 2023 18:25:04 +0100 Subject: [PATCH 036/186] [ci skip] fix genot tests --- src/ott/neural/solvers/genot.py | 36 ++++++------- tests/neural/conftest.py | 16 +++--- tests/neural/genot_test.py | 89 ++++++++++++++++++--------------- 3 files changed, 73 insertions(+), 68 deletions(-) diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 0e8de2705..f6e7cd3e3 100644 --- 
a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -209,15 +209,14 @@ def __call__(self, train_loader, valid_loader) -> None: """Train GENOT.""" batch: Dict[str, jnp.array] = {} for iteration in range(self.iterations): - batch["source_lin"], batch["source_q"], batch["target_lin"], batch[ - "target_q"], batch["condition"] = next(train_loader) + batch = next(train_loader) self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, rng_step_fn = jax.random.split( self.rng, 6 ) batch_size = len( batch["source_lin"] - ) if batch["source_lin"] is not None else len(batch["source_q"]) + ) if batch["source_lin"] is not None else len(batch["source_quad"]) n_samples = batch_size * self.k_samples_per_x batch["time"] = self.time_sampler(rng_time, n_samples) batch["noise"] = self.sample_noise(rng_noise, n_samples) @@ -226,34 +225,33 @@ def __call__(self, train_loader, valid_loader) -> None: ) tmat = self.match_fn( - batch["source_lin"], batch["source_q"], batch["target_lin"], - batch["target_q"] + batch["source_lin"], batch["source_quad"], batch["target_lin"], + batch["target_quad"] ) batch["source"] = jnp.concatenate([ batch[el] - for el in ["source_lin", "source_q"] + for el in ["source_lin", "source_quad"] if batch[el] is not None ], axis=1) batch["target"] = jnp.concatenate([ batch[el] - for el in ["target_lin", "target_q"] + for el in ["target_lin", "target_quad"] if batch[el] is not None ], axis=1) batch = { - k: v - for k, v in batch.items() - if k in ["source", "target", "condition", "time", "noise", "latent"] + k: v for k, v in batch.items() if k in + ["source", "target", "source_conditions", "time", "noise", "latent"] } - (batch["source"], batch["condition"] + (batch["source"], batch["source_conditions"] ), (batch["target"],) = self._sample_conditional_indices_from_tmap( rng_resample, tmat, - self.k_samples_per_x, (batch["source"], batch["condition"]), + self.k_samples_per_x, (batch["source"], batch["source_conditions"]), 
(batch["target"],), source_is_balanced=(self.tau_a == 1.0) ) @@ -268,10 +266,10 @@ def __call__(self, train_loader, valid_loader) -> None: rng_latent_data_match = jax.random.split( rng_latent_data_match, self.k_samples_per_x ) - (batch["source"], batch["condition"] + (batch["source"], batch["source_conditions"] ), (batch["target"],) = jax.vmap(self._resample_data, 0, 0)( rng_latent_data_match, tmats_latent_data, - (batch["source"], batch["condition"]), (batch["target"],) + (batch["source"], batch["source_conditions"]), (batch["target"],) ) batch = { key: @@ -287,7 +285,7 @@ def __call__(self, train_loader, valid_loader) -> None: self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, loss_b = self.unbalancedness_step_fn( source=batch["source"], target=batch["target"], - condition=batch["condition"], + condition=batch["source_conditions"], a=tmat.sum(axis=1), b=tmat.sum(axis=0), state_eta=self.state_eta, @@ -299,8 +297,8 @@ def __call__(self, train_loader, valid_loader) -> None: states_to_save = { "state_neural_vector_field": self.state_neural_vector_field } - if self.state_mlp is not None: - states_to_save["state_eta"] = self.state_mlp + if self.state_eta is not None: + states_to_save["state_eta"] = self.state_eta if self.state_xi is not None: states_to_save["state_xi"] = self.state_xi self.checkpoint_manager.save(iteration, states_to_save) @@ -326,7 +324,9 @@ def loss_fn( ) cond_input = jnp.concatenate([ - batch[el] for el in ["source", "condition"] if batch[el] is not None + batch[el] + for el in ["source", "source_conditions"] + if batch[el] is not None ], axis=1) v_t = jax.vmap(apply_fn)( diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index edb635e90..fb70cbb4d 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -58,14 +58,12 @@ def genot_data_loader_linear_conditional(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 - conditions_source = 
rng.normal(size=(100, 4)) - conditions_target = rng.normal(size=(100, 4)) - 1.0 + source_conditions = rng.normal(size=(100, 4)) return OTDataLoader( 16, source_lin=source, target_lin=target, - conditions_source=conditions_source, - conditions_target=conditions_target + source_conditions=source_conditions, ) @@ -84,13 +82,12 @@ def genot_data_loader_quad_conditional(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 1)) + 1.0 - conditions = rng.normal(size=(100, 7)) + source_conditions = rng.normal(size=(100, 7)) return OTDataLoader( 16, source_quad=source, target_quad=target, - source_conditions=conditions, - target_conditions=conditions + source_conditions=source_conditions, ) @@ -119,13 +116,12 @@ def genot_data_loader_fused_conditional(): target_q = rng.normal(size=(100, 1)) + 1.0 source_lin = rng.normal(size=(100, 2)) target_lin = rng.normal(size=(100, 2)) + 1.0 - conditions = rng.normal(size=(100, 7)) + source_conditions = rng.normal(size=(100, 7)) return OTDataLoader( 16, source_lin=source_lin, source_quad=source_q, target_lin=target_lin, target_quad=target_q, - source_conditions=conditions, - target_conditions=conditions + source_conditions=source_conditions, ) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index ed65fc657..620c92138 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -36,9 +36,11 @@ def test_genot_linear_unconditional( ): solver_latent_to_data = None if solver_latent_to_data is None else sinkhorn.Sinkhorn( ) - source_lin, source_quad, target_lin, target_quad, condition = next( - genot_data_loader_linear - ) + batch = next(genot_data_loader_linear) + source_lin, source_quad, target_lin, target_quad, source_condition = batch[ + "source_lin"], batch["source_quad"], batch["target_lin"], batch[ + "target_quad"], batch["source_conditions"] + source_dim = source_lin.shape[1] target_dim = target_lin.shape[1] condition_dim = 0 @@ -69,11 +71,13 @@ def 
test_genot_linear_unconditional( ) genot(genot_data_loader_linear, genot_data_loader_linear) - source_lin, source_quad, target_lin, target_quad, condition = next( - genot_data_loader_linear - ) + batch = next(genot_data_loader_linear) + source_lin, source_quad, target_lin, target_quad, source_condition = batch[ + "source_lin"], batch["source_quad"], batch["target_lin"], batch[ + "target_quad"], batch["source_conditions"] + result_forward = genot.transport( - source_lin, condition=condition, forward=True + source_lin, condition=source_condition, forward=True ) assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @@ -85,9 +89,11 @@ def test_genot_quad_unconditional( solver_latent_to_data: Optional[str] ): None if solver_latent_to_data is None else sinkhorn.Sinkhorn() - source_lin, source_quad, target_lin, target_quad, condition = next( - genot_data_loader_quad - ) + batch = next(genot_data_loader_quad) + source_lin, source_quad, target_lin, target_quad, source_condition = batch[ + "source_lin"], batch["source_quad"], batch["target_lin"], batch[ + "target_quad"], batch["source_conditions"] + source_dim = source_quad.shape[1] target_dim = target_quad.shape[1] condition_dim = 0 @@ -117,7 +123,7 @@ def test_genot_quad_unconditional( genot(genot_data_loader_quad, genot_data_loader_quad) result_forward = genot.transport( - source_quad, condition=condition, forward=True + source_quad, condition=source_condition, forward=True ) assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @@ -129,9 +135,11 @@ def test_genot_fused_unconditional( solver_latent_to_data: Optional[str] ): None if solver_latent_to_data is None else sinkhorn.Sinkhorn() - source_lin, source_quad, target_lin, target_quad, condition = next( - genot_data_loader_fused - ) + batch = next(genot_data_loader_fused) + batch = source_lin, source_quad, target_lin, target_quad, source_condition = batch[ + "source_lin"], 
batch["source_quad"], batch["target_lin"], batch[ + "target_quad"], batch["source_conditions"] + source_dim = source_lin.shape[1] + source_quad.shape[1] target_dim = target_lin.shape[1] + target_quad.shape[1] condition_dim = 0 @@ -162,7 +170,7 @@ def test_genot_fused_unconditional( result_forward = genot.transport( jnp.concatenate((source_lin, source_quad), axis=1), - condition=condition, + condition=source_condition, forward=True ) assert isinstance(result_forward, jnp.ndarray) @@ -175,12 +183,12 @@ def test_genot_linear_conditional( k_samples_per_x: int, solver_latent_to_data: Optional[str] ): None if solver_latent_to_data is None else sinkhorn.Sinkhorn() - source_lin, source_quad, target_lin, target_quad, condition = next( - genot_data_loader_linear_conditional - ) + batch = next(genot_data_loader_linear_conditional) + source_lin, target_lin, source_condition = batch["source_lin"], batch[ + "target_lin"], batch["source_conditions"] source_dim = source_lin.shape[1] target_dim = target_lin.shape[1] - condition_dim = condition.shape[1] + condition_dim = source_condition.shape[1] neural_vf = NeuralVectorField( output_dim=target_dim, @@ -209,12 +217,8 @@ def test_genot_linear_conditional( genot_data_loader_linear_conditional, genot_data_loader_linear_conditional ) - - source_lin, source_quad, target_lin, target_quad, condition = next( - genot_data_loader_linear_conditional - ) result_forward = genot.transport( - source_lin, condition=condition, forward=True + source_lin, condition=source_condition, forward=True ) assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @@ -226,12 +230,14 @@ def test_genot_quad_conditional( solver_latent_to_data: Optional[str] ): None if solver_latent_to_data is None else sinkhorn.Sinkhorn() - source_lin, source_quad, target_lin, target_quad, condition = next( - genot_data_loader_quad_conditional - ) + batch = next(genot_data_loader_quad_conditional) + source_lin, source_quad, target_lin, 
target_quad, source_condition = batch[ + "source_lin"], batch["source_quad"], batch["target_lin"], batch[ + "target_quad"], batch["source_conditions"] + source_dim = source_quad.shape[1] target_dim = target_quad.shape[1] - condition_dim = condition.shape[1] + condition_dim = source_condition.shape[1] neural_vf = NeuralVectorField( output_dim=target_dim, condition_dim=source_dim + condition_dim, @@ -260,7 +266,7 @@ def test_genot_quad_conditional( ) result_forward = genot.transport( - source_quad, condition=condition, forward=True + source_quad, condition=source_condition, forward=True ) assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @@ -272,12 +278,13 @@ def test_genot_fused_conditional( solver_latent_to_data: Optional[str] ): None if solver_latent_to_data is None else sinkhorn.Sinkhorn() - source_lin, source_quad, target_lin, target_quad, condition = next( - genot_data_loader_fused_conditional - ) + batch = next(genot_data_loader_fused_conditional) + source_lin, source_quad, target_lin, target_quad, source_condition = batch[ + "source_lin"], batch["source_quad"], batch["target_lin"], batch[ + "target_quad"], batch["source_conditions"] source_dim = source_lin.shape[1] + source_quad.shape[1] target_dim = target_lin.shape[1] + target_quad.shape[1] - condition_dim = condition.shape[1] + condition_dim = source_condition.shape[1] neural_vf = NeuralVectorField( output_dim=target_dim, condition_dim=source_dim + condition_dim, @@ -307,7 +314,7 @@ def test_genot_fused_conditional( result_forward = genot.transport( jnp.concatenate((source_lin, source_quad), axis=1), - condition=condition, + condition=source_condition, forward=True ) assert isinstance(result_forward, jnp.ndarray) @@ -323,12 +330,14 @@ def test_genot_linear_learn_rescaling( None if solver_latent_to_data is None else sinkhorn.Sinkhorn() data_loader = genot_data_loader_linear_conditional if conditional else genot_data_loader_linear - source_lin, source_quad, 
target_lin, target_quad, condition = next( - data_loader - ) + batch = next(data_loader) + source_lin, source_quad, target_lin, target_quad, source_condition = batch[ + "source_lin"], batch["source_quad"], batch["target_lin"], batch[ + "target_quad"], batch["source_conditions"] + source_dim = source_lin.shape[1] target_dim = target_lin.shape[1] - condition_dim = condition.shape[1] if conditional else 0 + condition_dim = source_condition.shape[1] if conditional else 0 neural_vf = NeuralVectorField( output_dim=target_dim, @@ -363,10 +372,10 @@ def test_genot_linear_learn_rescaling( genot(data_loader, data_loader) - result_eta = genot.evaluate_eta(source_lin, condition=condition) + result_eta = genot.evaluate_eta(source_lin, condition=source_condition) assert isinstance(result_eta, jnp.ndarray) assert jnp.sum(jnp.isnan(result_eta)) == 0 - result_xi = genot.evaluate_xi(target_lin, condition=condition) + result_xi = genot.evaluate_xi(target_lin, condition=source_condition) assert isinstance(result_xi, jnp.ndarray) assert jnp.sum(jnp.isnan(result_xi)) == 0 From bfcfcbdbf667d714e899af409db2e9dbd7203e54 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 28 Nov 2023 18:29:27 +0100 Subject: [PATCH 037/186] [ci skip] fix otfm tests --- tests/neural/conftest.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index fb70cbb4d..1aa567d8f 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -10,7 +10,7 @@ def data_loader_gaussian(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 - return OTDataLoader(source, target, 16) + return OTDataLoader(16, source_lin=source, target_lin=target) @pytest.fixture(scope="module") @@ -23,10 +23,10 @@ def data_loader_gaussian_conditional(): source_1 = rng.normal(size=(100, 2)) target_1 = rng.normal(size=(100, 2)) - 2.0 dl0 = OTDataLoader( - source_0, target_0, 16, 
source_conditions=np.zeros_like(source_0) * 0.0 + 16, source_lin=source_0, target_lin=target_0, source_conditions=np.zeros_like(source_0) * 0.0 ) dl1 = OTDataLoader( - source_1, target_1, 16, source_conditions=np.ones_like(source_1) * 1.0 + 16, source_lin=source_1, target_lin=target_1, source_conditions=np.ones_like(source_1) * 1.0 ) return ConditionalDataLoader({"0": dl0, "1": dl1}, np.array([0.5, 0.5])) @@ -40,7 +40,7 @@ def data_loader_gaussian_with_conditions(): target = rng.normal(size=(100, 2)) + 1.0 source_conditions = rng.normal(size=(100, 1)) target_conditions = rng.normal(size=(100, 1)) - 1.0 - return OTDataLoader(source, target, 16, source_conditions, target_conditions) + return OTDataLoader(16, source_lin=source, target_lin=target, source_conditions=source_conditions, target_conditions=target_conditions) @pytest.fixture(scope="module") From f27bc22ba84c462d00e5f25548c3b691fde250f7 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 28 Nov 2023 18:34:50 +0100 Subject: [PATCH 038/186] [ci skip] fix otfm tests --- tests/neural/conftest.py | 18 +++++++++++++++--- tests/neural/otfm_test.py | 8 ++++---- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 1aa567d8f..c6d25b128 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -23,10 +23,16 @@ def data_loader_gaussian_conditional(): source_1 = rng.normal(size=(100, 2)) target_1 = rng.normal(size=(100, 2)) - 2.0 dl0 = OTDataLoader( - 16, source_lin=source_0, target_lin=target_0, source_conditions=np.zeros_like(source_0) * 0.0 + 16, + source_lin=source_0, + target_lin=target_0, + source_conditions=np.zeros_like(source_0) * 0.0 ) dl1 = OTDataLoader( - 16, source_lin=source_1, target_lin=target_1, source_conditions=np.ones_like(source_1) * 1.0 + 16, + source_lin=source_1, + target_lin=target_1, + source_conditions=np.ones_like(source_1) * 1.0 ) return ConditionalDataLoader({"0": dl0, "1": dl1}, np.array([0.5, 0.5])) @@ 
-40,7 +46,13 @@ def data_loader_gaussian_with_conditions(): target = rng.normal(size=(100, 2)) + 1.0 source_conditions = rng.normal(size=(100, 1)) target_conditions = rng.normal(size=(100, 1)) - 1.0 - return OTDataLoader(16, source_lin=source, target_lin=target, source_conditions=source_conditions, target_conditions=target_conditions) + return OTDataLoader( + 16, + source_lin=source, + target_lin=target, + source_conditions=source_conditions, + target_conditions=target_conditions + ) @pytest.fixture(scope="module") diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 4346b6be8..04413914d 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -68,7 +68,7 @@ def test_flow_matching(self, data_loader_gaussian, flow: Type[BaseFlow]): result_backward = fm.transport( batch["target_lin"], - condition=batch["target_conditions"], + condition=batch["source_conditions"], forward=False ) assert isinstance(result_backward, jnp.ndarray) @@ -116,7 +116,7 @@ def test_flow_matching_with_conditions( result_backward = fm.transport( batch["target_lin"], - condition=batch["target_conditions"], + condition=batch["source_conditions"], forward=False ) assert isinstance(result_backward, jnp.ndarray) @@ -161,7 +161,7 @@ def test_flow_matching_conditional( result_backward = fm.transport( batch["target_lin"], - condition=batch["target_conditions"], + condition=batch["source_conditions"], forward=False ) assert isinstance(result_backward, jnp.ndarray) @@ -214,7 +214,7 @@ def test_flow_matching_learn_rescaling( assert jnp.sum(jnp.isnan(result_eta)) == 0 result_xi = fm.evaluate_xi( - batch["target_lin"], condition=batch["target_conditions"] + batch["target_lin"], condition=batch["source_conditions"] ) assert isinstance(result_xi, jnp.ndarray) assert jnp.sum(jnp.isnan(result_xi)) == 0 From 384e8fcc15fadd1feba75301c74a457c1c6f0bf7 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 28 Nov 2023 18:59:24 +0100 Subject: [PATCH 039/186] add scale cost to otfm 
--- src/ott/neural/solvers/otfm.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index 2afc94e6d..83b2ceaa6 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -14,7 +14,17 @@ import functools import types from collections import defaultdict -from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type +from typing import ( + Any, + Callable, + Dict, + Literal, + Mapping, + Optional, + Tuple, + Type, + Union, +) import diffrax import jax @@ -27,13 +37,13 @@ from ott.geometry import costs from ott.neural.models.models import BaseNeuralVectorField from ott.neural.solvers.base_solver import ( - BaseNeuralSolver, - ResampleMixin, - UnbalancednessMixin, + BaseNeuralSolver, + ResampleMixin, + UnbalancednessMixin, ) from ott.neural.solvers.flows import ( - BaseFlow, - BaseTimeSampler, + BaseFlow, + BaseTimeSampler, ) from ott.solvers import was_solver @@ -56,6 +66,7 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): checkpoint_manager: Checkpoint manager. epsilon: Entropy regularization term of the OT OT problem solved by the `ot_solver`. cost_fn: Cost function for the OT problem solved by the `ot_solver`. + scale_cost: How to scale the cost matrix for the OT problem solved by the `ot_solver`. tau_a: If :math:`<1`, defines how much unbalanced the problem is on the first marginal. 
tau_b: If :math:`< 1`, defines how much unbalanced the problem is @@ -85,6 +96,9 @@ def __init__( checkpoint_manager: Type[checkpoint.CheckpointManager] = None, epsilon: float = 1e-2, cost_fn: Type[costs.CostFn] = costs.SqEuclidean(), + scale_cost: Union[bool, int, float, + Literal["mean", "max_norm", "max_bound", "max_cost", + "median"]] = "mean", tau_a: float = 1.0, tau_b: float = 1.0, mlp_eta: Callable[[jnp.ndarray], float] = None, @@ -123,6 +137,7 @@ def __init__( self.optimizer = optimizer self.epsilon = epsilon self.cost_fn = cost_fn + self.scale_cost = scale_cost self.callback_fn = callback_fn self.checkpoint_manager = checkpoint_manager self.rng = rng From ef204e69a11725b1e2ce060e1ebf3a500a13d504 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 29 Nov 2023 10:37:33 +0100 Subject: [PATCH 040/186] incorporate feedback partially --- src/ott/datasets.py | 22 ++++++----------- src/ott/neural/models/models.py | 4 ++-- src/ott/neural/solvers/base_solver.py | 34 +++++++++++++-------------- src/ott/neural/solvers/genot.py | 4 ++-- src/ott/neural/solvers/otfm.py | 28 +++++++++++----------- tests/geometry/scaling_cost_test.py | 4 ++-- tests/neural/genot_test.py | 21 +++++++---------- tests/neural/otfm_test.py | 6 ++--- 8 files changed, 56 insertions(+), 67 deletions(-) diff --git a/src/ott/datasets.py b/src/ott/datasets.py index 07bd87fb9..9ddc0435a 100644 --- a/src/ott/datasets.py +++ b/src/ott/datasets.py @@ -51,13 +51,13 @@ class GaussianMixture: rectangle batch_size: batch size of the samples - init_rng: initial PRNG key + rng: initial PRNG key scale: scale of the Gaussian means std: the standard deviation of the individual Gaussian samples """ name: Name_t batch_size: int - init_rng: jnp.ndarray + rng: jnp.ndarray scale: float = 5.0 std: float = 0.5 @@ -96,7 +96,7 @@ def __iter__(self) -> Iterator[jnp.array]: return self._create_sample_generators() def _create_sample_generators(self) -> Iterator[jnp.array]: - rng = self.init_rng + rng = self.rng while 
True: rng1, rng2, rng = jax.random.split(rng, 3) means = jax.random.choice(rng1, self.centers, (self.batch_size,)) @@ -128,26 +128,18 @@ def create_gaussian_mixture_samplers( rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) train_dataset = Dataset( source_iter=iter( - GaussianMixture( - name_source, batch_size=train_batch_size, init_rng=rng1 - ) + GaussianMixture(name_source, batch_size=train_batch_size, rng=rng1) ), target_iter=iter( - GaussianMixture( - name_target, batch_size=train_batch_size, init_rng=rng2 - ) + GaussianMixture(name_target, batch_size=train_batch_size, rng=rng2) ) ) valid_dataset = Dataset( source_iter=iter( - GaussianMixture( - name_source, batch_size=valid_batch_size, init_rng=rng3 - ) + GaussianMixture(name_source, batch_size=valid_batch_size, rng=rng3) ), target_iter=iter( - GaussianMixture( - name_target, batch_size=valid_batch_size, init_rng=rng4 - ) + GaussianMixture(name_target, batch_size=valid_batch_size, rng=rng4) ) ) dim_data = 2 diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index c8bcb00ce..f2ef76162 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -35,7 +35,7 @@ from ott.problems.linear import linear_problem __all__ = [ - "ICNN", "MLP", "MetaInitializer", "NeuralVectorField", "Rescaling_MLP" + "ICNN", "MLP", "MetaInitializer", "NeuralVectorField", "RescalingMLP" ] @@ -585,7 +585,7 @@ def create_train_state( ) -class Rescaling_MLP(BaseRescalingNet): +class RescalingMLP(BaseRescalingNet): """Network to learn distributional rescaling factors based on a MLP. The input is passed through a block consisting of ``num_layers_per_block`` with size ``hidden_dim``. 
diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index cc9a4c310..ce20d09c5 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -200,26 +200,26 @@ def _get_gromov_match_fn( fused_penalty: float, ) -> Callable: if isinstance(cost_fn, Mapping): - assert "x_cost_fn" in cost_fn - assert "y_cost_fn" in cost_fn - x_cost_fn = cost_fn["x_cost_fn"] - y_cost_fn = cost_fn["y_cost_fn"] + assert "cost_fn_xx" in cost_fn + assert "cost_fn_yy" in cost_fn + cost_fn_xx = cost_fn["cost_fn_xx"] + cost_fn_yy = cost_fn["cost_fn_yy"] if fused_penalty > 0: - assert "xy_cost_fn" in x_cost_fn - xy_cost_fn = cost_fn["xy_cost_fn"] + assert "cost_fn_xy" in cost_fn_xx + cost_fn_xy = cost_fn["cost_fn_xy"] else: - x_cost_fn = y_cost_fn = xy_cost_fn = cost_fn + cost_fn_xx = cost_fn_yy = cost_fn_xy = cost_fn if isinstance(scale_cost, Mapping): - assert "x_scale_cost" in scale_cost - assert "y_scale_cost" in scale_cost - x_scale_cost = scale_cost["x_scale_cost"] - y_scale_cost = scale_cost["y_scale_cost"] + assert "scale_cost_xx" in scale_cost + assert "scale_cost_yy" in scale_cost + scale_cost_xx = scale_cost["scale_cost_xx"] + scale_cost_yy = scale_cost["scale_cost_yy"] if fused_penalty > 0: - assert "xy_scale_cost" in scale_cost - xy_scale_cost = cost_fn["xy_scale_cost"] + assert "scale_cost_xy" in scale_cost + scale_cost_xy = cost_fn["scale_cost_xy"] else: - x_scale_cost = y_scale_cost = xy_scale_cost = scale_cost + scale_cost_xx = scale_cost_yy = scale_cost_xy = scale_cost def match_pairs( x_lin: Optional[jnp.ndarray], @@ -228,14 +228,14 @@ def match_pairs( y_quad: Tuple[jnp.ndarray, jnp.ndarray], ) -> Tuple[jnp.array, jnp.array]: geom_xx = pointcloud.PointCloud( - x=x_quad, y=x_quad, cost_fn=x_cost_fn, scale_cost=x_scale_cost + x=x_quad, y=x_quad, cost_fn=cost_fn_xx, scale_cost=scale_cost_xx ) geom_yy = pointcloud.PointCloud( - x=y_quad, y=y_quad, cost_fn=y_cost_fn, scale_cost=y_scale_cost + x=y_quad, 
y=y_quad, cost_fn=cost_fn_yy, scale_cost=scale_cost_yy ) if fused_penalty > 0: geom_xy = pointcloud.PointCloud( - x=x_lin, y=y_lin, cost_fn=xy_cost_fn, scale_cost=xy_scale_cost + x=x_lin, y=y_lin, cost_fn=cost_fn_xy, scale_cost=scale_cost_xy ) else: geom_xy = None diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index f6e7cd3e3..72adce17e 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -56,8 +56,8 @@ class GENOT(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): valid_freq: Frequency of validation. ot_solver: OT solver to match samples from the source and the target distribution. epsilon: Entropy regularization term of the OT problem solved by `ot_solver`. - cost_fn: Cost function for the OT problem solved by the `ot_solver`. In the linear case, this is always expected to be of type `str`. If the problem is of quadratic type and `cost_fn` is a string, the `cost_fn` is used for all terms, i.e. both quadratic terms and, if applicable, the linear temr. If of type :class:`dict`, the keys are expected to be `x_cost_fn`, `y_cost_fn`, and if applicable, `xy_cost_fn`. - scale_cost: How to scale the cost matrix for the OT problem solved by the `ot_solver`. In the linear case, this is always expected to be not a :class:`dict`. If the problem is of quadratic type and `scale_cost` is a string, the `scale_cost` argument is used for all terms, i.e. both quadratic terms and, if applicable, the linear temr. If of type :class:`dict`, the keys are expected to be `x_scale_cost`, `y_scale_cost`, and if applicable, `xy_scale_cost`. + cost_fn: Cost function for the OT problem solved by the `ot_solver`. In the linear case, this is always expected to be of type `str`. If the problem is of quadratic type and `cost_fn` is a string, the `cost_fn` is used for all terms, i.e. both quadratic terms and, if applicable, the linear temr. 
If of type :class:`dict`, the keys are expected to be `cost_fn_xx`, `cost_fn_yy`, and if applicable, `cost_fn_xy`. + scale_cost: How to scale the cost matrix for the OT problem solved by the `ot_solver`. In the linear case, this is always expected to be not a :class:`dict`. If the problem is of quadratic type and `scale_cost` is a string, the `scale_cost` argument is used for all terms, i.e. both quadratic terms and, if applicable, the linear temr. If of type :class:`dict`, the keys are expected to be `scale_cost_xx`, `scale_cost_yy`, and if applicable, `scale_cost_xy`. optimizer: Optimizer for `neural_vector_field`. flow: Flow between latent distribution and target distribution. time_sampler: Sampler for the time. diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index 83b2ceaa6..b69d7978e 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -15,15 +15,15 @@ import types from collections import defaultdict from typing import ( - Any, - Callable, - Dict, - Literal, - Mapping, - Optional, - Tuple, - Type, - Union, + Any, + Callable, + Dict, + Literal, + Mapping, + Optional, + Tuple, + Type, + Union, ) import diffrax @@ -37,13 +37,13 @@ from ott.geometry import costs from ott.neural.models.models import BaseNeuralVectorField from ott.neural.solvers.base_solver import ( - BaseNeuralSolver, - ResampleMixin, - UnbalancednessMixin, + BaseNeuralSolver, + ResampleMixin, + UnbalancednessMixin, ) from ott.neural.solvers.flows import ( - BaseFlow, - BaseTimeSampler, + BaseFlow, + BaseTimeSampler, ) from ott.solvers import was_solver diff --git a/tests/geometry/scaling_cost_test.py b/tests/geometry/scaling_cost_test.py index ce3f616ce..9f4ad1d57 100644 --- a/tests/geometry/scaling_cost_test.py +++ b/tests/geometry/scaling_cost_test.py @@ -188,7 +188,7 @@ def apply_sinkhorn(cost1, cost2, scale_cost): np.testing.assert_allclose(1.0, geom.cost_matrix.max(), rtol=1e-4) @pytest.mark.parametrize("batch_size", [5, 12]) - def 
test_max_scale_cost_low_rank_with_batch(self, batch_size: int): + def test_mascale_cost_xx_low_rank_with_batch(self, batch_size: int): """Test max_cost options for low rank with batch_size fixed.""" geom0 = low_rank.LRCGeometry( @@ -199,7 +199,7 @@ def test_max_scale_cost_low_rank_with_batch(self, batch_size: int): geom0.inv_scale_cost, 1.0 / jnp.max(self.cost_lr), rtol=1e-4 ) - def test_max_scale_cost_low_rank_large_array(self): + def test_mascale_cost_xx_low_rank_large_array(self): """Test max_cost options for large matrices.""" _, *rngs = jax.random.split(self.rng, 3) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 620c92138..5c7c1d431 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -18,7 +18,7 @@ import pytest from ott.geometry import costs -from ott.neural.models.models import NeuralVectorField, Rescaling_MLP +from ott.neural.models.models import NeuralVectorField, RescalingMLP from ott.neural.solvers.flows import OffsetUniformSampler, UniformSampler from ott.neural.solvers.genot import GENOT from ott.solvers.linear import sinkhorn @@ -90,9 +90,8 @@ def test_genot_quad_unconditional( ): None if solver_latent_to_data is None else sinkhorn.Sinkhorn() batch = next(genot_data_loader_quad) - source_lin, source_quad, target_lin, target_quad, source_condition = batch[ - "source_lin"], batch["source_quad"], batch["target_lin"], batch[ - "target_quad"], batch["source_conditions"] + source_quad, target_quad, source_condition = batch["source_quad"], batch[ + "target_quad"], batch["source_conditions"] source_dim = source_quad.shape[1] target_dim = target_quad.shape[1] @@ -231,9 +230,8 @@ def test_genot_quad_conditional( ): None if solver_latent_to_data is None else sinkhorn.Sinkhorn() batch = next(genot_data_loader_quad_conditional) - source_lin, source_quad, target_lin, target_quad, source_condition = batch[ - "source_lin"], batch["source_quad"], batch["target_lin"], batch[ - "target_quad"], 
batch["source_conditions"] + source_quad, target_quad, source_condition = batch["source_quad"], batch[ + "target_quad"], batch["source_conditions"] source_dim = source_quad.shape[1] target_dim = target_quad.shape[1] @@ -331,9 +329,8 @@ def test_genot_linear_learn_rescaling( data_loader = genot_data_loader_linear_conditional if conditional else genot_data_loader_linear batch = next(data_loader) - source_lin, source_quad, target_lin, target_quad, source_condition = batch[ - "source_lin"], batch["source_quad"], batch["target_lin"], batch[ - "target_quad"], batch["source_conditions"] + source_lin, target_lin, source_condition = batch["source_lin"], batch[ + "target_lin"], batch["source_conditions"] source_dim = source_lin.shape[1] target_dim = target_lin.shape[1] @@ -349,8 +346,8 @@ def test_genot_linear_learn_rescaling( optimizer = optax.adam(learning_rate=1e-3) tau_a = 0.9 tau_b = 0.2 - mlp_eta = Rescaling_MLP(hidden_dim=4, condition_dim=condition_dim) - mlp_xi = Rescaling_MLP(hidden_dim=4, condition_dim=condition_dim) + mlp_eta = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) + mlp_xi = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) genot = GENOT( neural_vf, input_dim=source_dim, diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 04413914d..d8deb1102 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -17,7 +17,7 @@ import optax import pytest -from ott.neural.models.models import NeuralVectorField, Rescaling_MLP +from ott.neural.models.models import NeuralVectorField, RescalingMLP from ott.neural.solvers.flows import ( BaseFlow, BrownianNoiseFlow, @@ -188,8 +188,8 @@ def test_flow_matching_learn_rescaling( tau_a = 0.9 tau_b = 0.2 - mlp_eta = Rescaling_MLP(hidden_dim=4, condition_dim=condition_dim) - mlp_xi = Rescaling_MLP(hidden_dim=4, condition_dim=condition_dim) + mlp_eta = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) + mlp_xi = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) fm = 
OTFlowMatching( neural_vf, input_dim=source_dim, From 2b1ab921258ba63bb473ab397cef1f481314e7a1 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 29 Nov 2023 11:18:50 +0100 Subject: [PATCH 041/186] resolve circular import errors --- src/ott/neural/__init__.py | 2 +- src/ott/neural/models/__init__.py | 2 +- src/ott/neural/{ => models}/base_models.py | 0 src/ott/neural/{ => models}/layers.py | 0 src/ott/neural/{ => models}/losses.py | 0 src/ott/neural/{ => models}/models.py | 3 ++- tests/neural/losses_test.py | 2 +- tests/neural/map_estimator_test.py | 2 +- 8 files changed, 6 insertions(+), 5 deletions(-) rename src/ott/neural/{ => models}/base_models.py (100%) rename src/ott/neural/{ => models}/layers.py (100%) rename src/ott/neural/{ => models}/losses.py (100%) rename src/ott/neural/{ => models}/models.py (99%) diff --git a/src/ott/neural/__init__.py b/src/ott/neural/__init__.py index aa1ca23fa..16f90c799 100644 --- a/src/ott/neural/__init__.py +++ b/src/ott/neural/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import layers, losses, models, solvers +from . import models, solvers, data diff --git a/src/ott/neural/models/__init__.py b/src/ott/neural/models/__init__.py index d2a583f34..5e6590cd1 100644 --- a/src/ott/neural/models/__init__.py +++ b/src/ott/neural/models/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import base_models, conjugate_solvers, layers, models +from . 
import base_models, models, losses, layers diff --git a/src/ott/neural/base_models.py b/src/ott/neural/models/base_models.py similarity index 100% rename from src/ott/neural/base_models.py rename to src/ott/neural/models/base_models.py diff --git a/src/ott/neural/layers.py b/src/ott/neural/models/layers.py similarity index 100% rename from src/ott/neural/layers.py rename to src/ott/neural/models/layers.py diff --git a/src/ott/neural/losses.py b/src/ott/neural/models/losses.py similarity index 100% rename from src/ott/neural/losses.py rename to src/ott/neural/models/losses.py diff --git a/src/ott/neural/models.py b/src/ott/neural/models/models.py similarity index 99% rename from src/ott/neural/models.py rename to src/ott/neural/models/models.py index 4087853dd..15fdcfcbc 100644 --- a/src/ott/neural/models.py +++ b/src/ott/neural/models/models.py @@ -26,9 +26,10 @@ from ott.geometry import geometry from ott.initializers.linear import initializers as lin_init from ott.math import matrix_square_root -from ott.neural import layers +from ott.neural.models import layers from ott.neural.solvers import neuraldual from ott.problems.linear import linear_problem +from ott.neural.models.base_models import BaseNeuralVectorField, BaseRescalingNet __all__ = [ "ICNN", "MLP", "MetaInitializer", "NeuralVectorField", "RescalingMLP" diff --git a/tests/neural/losses_test.py b/tests/neural/losses_test.py index a432f2dc5..4569b04d1 100644 --- a/tests/neural/losses_test.py +++ b/tests/neural/losses_test.py @@ -18,7 +18,7 @@ import pytest from ott.geometry import costs -from ott.neural import losses, models +from ott.neural.models import losses, models @pytest.mark.fast() diff --git a/tests/neural/map_estimator_test.py b/tests/neural/map_estimator_test.py index e0ec0b56b..96f9a9797 100644 --- a/tests/neural/map_estimator_test.py +++ b/tests/neural/map_estimator_test.py @@ -18,7 +18,7 @@ from ott import datasets from ott.geometry import pointcloud -from ott.neural import losses, models +from 
ott.neural.models import losses, models from ott.neural.solvers import map_estimator from ott.tools import sinkhorn_divergence From e1be6ca6c0c6d07490f976f92b75e8ebefd7b11f Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 29 Nov 2023 11:23:48 +0100 Subject: [PATCH 042/186] resolve a few pre-commit errors --- src/ott/neural/__init__.py | 2 +- src/ott/neural/models/__init__.py | 2 +- src/ott/neural/models/models.py | 5 ++++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/ott/neural/__init__.py b/src/ott/neural/__init__.py index 16f90c799..326fae432 100644 --- a/src/ott/neural/__init__.py +++ b/src/ott/neural/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import models, solvers, data +from . import data, models, solvers diff --git a/src/ott/neural/models/__init__.py b/src/ott/neural/models/__init__.py index 5e6590cd1..1e374d236 100644 --- a/src/ott/neural/models/__init__.py +++ b/src/ott/neural/models/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import base_models, models, losses, layers +from . 
import base_models, layers, losses, models diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 15fdcfcbc..5c7b3f30e 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -27,9 +27,12 @@ from ott.initializers.linear import initializers as lin_init from ott.math import matrix_square_root from ott.neural.models import layers +from ott.neural.models.base_models import ( + BaseNeuralVectorField, + BaseRescalingNet, +) from ott.neural.solvers import neuraldual from ott.problems.linear import linear_problem -from ott.neural.models.base_models import BaseNeuralVectorField, BaseRescalingNet __all__ = [ "ICNN", "MLP", "MetaInitializer", "NeuralVectorField", "RescalingMLP" From a307bf8791b479d3843cbe80dda71a99d8a4f1dd Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 29 Nov 2023 11:42:40 +0100 Subject: [PATCH 043/186] resolve pre-commit errors --- src/ott/neural/data/dataloaders.py | 6 ++- src/ott/neural/models/models.py | 25 +++++++++--- src/ott/neural/solvers/base_solver.py | 11 +++-- src/ott/neural/solvers/flows.py | 40 ++++++++++++++---- src/ott/neural/solvers/genot.py | 59 +++++++++++++++++++-------- src/ott/neural/solvers/otfm.py | 32 ++++++++++----- tests/neural/genot_test.py | 10 ++--- 7 files changed, 132 insertions(+), 51 deletions(-) diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py index 938dabb96..4fe8a9a8c 100644 --- a/src/ott/neural/data/dataloaders.py +++ b/src/ott/neural/data/dataloaders.py @@ -103,10 +103,12 @@ def __next__(self) -> Mapping[str, np.ndarray]: class ConditionalDataLoader: """Data loader for OT problems with conditions. - This data loader wraps several data loaders and samples from them according to their conditions. + This data loader wraps several data loaders and samples from them according + to their conditions. Args: - dataloaders: Dictionary of data loaders with keys corresponding to conditions. 
+ dataloaders: Dictionary of data loaders with keys corresponding to + conditions. p: Probability of sampling from each data loader. seed: Random seed. diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 5c7b3f30e..9b15cb803 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -427,8 +427,11 @@ class NeuralVectorField(BaseNeuralVectorField): """Parameterized neural vector field. Each of the input, condition, and time embeddings are passed through a block - consisting of ``num_layers_per_block`` layers of dimension ``latent_embed_dim``, ``condition_embed_dim``, and ``time_embed_dim``, respectively. - The output of each block is concatenated and passed through a final block of dimension ``joint_hidden_dim``. + consisting of ``num_layers_per_block`` layers of dimension + ``latent_embed_dim``, ``condition_embed_dim``, and ``time_embed_dim``, + respectively. + The output of each block is concatenated and passed through a final block of + dimension ``joint_hidden_dim``. Args: output_dim: Dimensionality of the neural vector field. @@ -586,9 +589,12 @@ def create_train_state( class RescalingMLP(BaseRescalingNet): """Network to learn distributional rescaling factors based on a MLP. - The input is passed through a block consisting of ``num_layers_per_block`` with size ``hidden_dim``. - If ``condition_dim`` is greater than 0, the conditioning vector is passed through a block of the same size. - Both outputs are concatenated and passed through another block of the same size. + The input is passed through a block consisting of ``num_layers_per_block`` + with size ``hidden_dim``. + If ``condition_dim`` is greater than 0, the conditioning vector is passed + through a block of the same size. + Both outputs are concatenated and passed through another block of the same + size. To ensure non-negativity of the output, the output is exponentiated. 
@@ -610,6 +616,15 @@ class RescalingMLP(BaseRescalingNet): def __call__( self, x: jnp.ndarray, condition: Optional[jnp.ndarray] ) -> jnp.ndarray: # noqa: D102 + """Forward pass through the rescaling network. + + Args: + x: Data. + condition: Condition. + + Returns: + Estimated rescaling factors. + """ x = Block( dim=self.hidden_dim, out_dim=self.hidden_dim, diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index ce20d09c5..5c291afaa 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -439,7 +439,10 @@ def step_fn( else: new_state_xi = xi_predictions = loss_b = None - return new_state_eta, new_state_xi, eta_predictions, xi_predictions, loss_a, loss_b + return ( + new_state_eta, new_state_xi, eta_predictions, xi_predictions, loss_a, + loss_b + ) return step_fn @@ -449,7 +452,8 @@ def evaluate_eta( """Evaluate the left learnt rescaling factor. Args: - source: Samples from the source distribution to evaluate rescaling function on. + source: Samples from the source distribution to evaluate rescaling + function on. condition: Condition belonging to the samples in the source distribution. Returns: @@ -467,7 +471,8 @@ def evaluate_xi( """Evaluate the right learnt rescaling factor. Args: - target: Samples from the target distribution to evaluate the rescaling function on. + target: Samples from the target distribution to evaluate the rescaling + function on. condition: Condition belonging to the samples in the target distribution. Returns: diff --git a/src/ott/neural/solvers/flows.py b/src/ott/neural/solvers/flows.py index b02981fc9..b61ff08d1 100644 --- a/src/ott/neural/solvers/flows.py +++ b/src/ott/neural/solvers/flows.py @@ -34,7 +34,10 @@ def __init__(self, sigma: float) -> None: @abc.abstractmethod def compute_mu_t(self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray): - """Compute the mean of the probablitiy path between :math:`x` and :math:`y` at time :math:`t`. 
+ """Compute the mean of the probablitiy path. + + Compute the mean of the probablitiy path between :math:`x` and :math:`y` + at time :math:`t`. Args: t: Time :math:`t`. @@ -56,7 +59,10 @@ def compute_sigma_t(self, t: jnp.ndarray): def compute_ut( self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray ) -> jnp.ndarray: - """Evaluate the conditional vector field defined between :math:`x_0` and :math:`x_1` at time :math:`t`. + """Evaluate the conditional vector field. + + Evaluate the conditional vector field defined between :math:`x_0` and + :math:`x_1` at time :math:`t`. Args: t: Time :math:`t`. @@ -69,7 +75,10 @@ def compute_xt( self, noise: jnp.ndarray, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray ) -> jnp.ndarray: - """Sample from the probability path between :math:`x_0` and :math:`x_1` at time :math:`t`. + """Sample from the probability path. + + Sample from the probability path between :math:`x_0` and :math:`x_1` at + time :math:`t`. Args: noise: Noise sampled from a standard normal distribution. @@ -78,7 +87,8 @@ def compute_xt( x_1: Sample from the target distribution. Returns: - Samples from the probability path between :math:`x_0` and :math:`x_1` at time :math:`t`. + Samples from the probability path between :math:`x_0` and :math:`x_1` + at time :math:`t`. """ mu_t = self.compute_mu_t(t, x_0, x_1) sigma_t = self.compute_sigma_t(t) @@ -91,7 +101,10 @@ class StraightFlow(BaseFlow, abc.ABC): def compute_mu_t( self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray ) -> jnp.ndarray: - """Compute the mean of the probablitiy path between :math:`x` and :math:`y` at time :math:`t`. + """Compute the mean of the probablitiy path. + + Compute the mean of the probablitiy path between :math:`x` and :math:`y` + at time :math:`t`. Args: t: Time :math:`t`. 
@@ -103,7 +116,10 @@ def compute_mu_t( def compute_ut( self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray ) -> jnp.ndarray: - """Evaluate the conditional vector field defined between :math:`x_0` and :math:`x_1` at time :math:`t`. + """Evaluate the conditional vector field. + + Evaluate the conditional vector field defined between :math:`x_0` and + :math:`x_1` at time :math:`t`. Args: t: Time :math:`t`. @@ -132,7 +148,12 @@ def compute_sigma_t(self, t: jnp.ndarray): class BrownianNoiseFlow(StraightFlow): - r"""Sampler for sampling noise implicitly defined by a Schroedinger Bridge problem with parameter `\sigma` such that :math:`\sigma_t = \sigma * \sqrt(t * (1-t))`.""" + r"""Brownian Bridge Flow. + + Sampler for sampling noise implicitly defined by a Schroedinger Bridge + problem with parameter `\sigma` such that + :math:`\sigma_t = \sigma * \sqrt(t * (1-t))`. + """ def compute_sigma_t(self, t: jnp.ndarray): """Compute the standard deviation of the probablity path at time :math:`t`. @@ -196,7 +217,10 @@ def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: class OffsetUniformSampler(BaseTimeSampler): - """Sample :math:`t` from a uniform distribution :math:`[low, high]` with offset `offset`. + """Sample the time :math:`t`. + + Sample :math:`t` from a uniform distribution :math:`[low, high]` with + offset `offset`. Args: offset: Offset of the uniform distribution. diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 72adce17e..2eb81be98 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -54,24 +54,44 @@ class GENOT(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): cond_dim: Dimension of the conditioning variable. iterations: Number of iterations. valid_freq: Frequency of validation. - ot_solver: OT solver to match samples from the source and the target distribution. - epsilon: Entropy regularization term of the OT problem solved by `ot_solver`. 
- cost_fn: Cost function for the OT problem solved by the `ot_solver`. In the linear case, this is always expected to be of type `str`. If the problem is of quadratic type and `cost_fn` is a string, the `cost_fn` is used for all terms, i.e. both quadratic terms and, if applicable, the linear temr. If of type :class:`dict`, the keys are expected to be `cost_fn_xx`, `cost_fn_yy`, and if applicable, `cost_fn_xy`. - scale_cost: How to scale the cost matrix for the OT problem solved by the `ot_solver`. In the linear case, this is always expected to be not a :class:`dict`. If the problem is of quadratic type and `scale_cost` is a string, the `scale_cost` argument is used for all terms, i.e. both quadratic terms and, if applicable, the linear temr. If of type :class:`dict`, the keys are expected to be `scale_cost_xx`, `scale_cost_yy`, and if applicable, `scale_cost_xy`. + ot_solver: OT solver to match samples from the source and the target + distribution. + epsilon: Entropy regularization term of the OT problem solved by + `ot_solver`. + cost_fn: Cost function for the OT problem solved by the `ot_solver`. + In the linear case, this is always expected to be of type `str`. + If the problem is of quadratic type and `cost_fn` is a string, + the `cost_fn` is used for all terms, i.e. both quadratic terms and, + if applicable, the linear temr. If of type :class:`dict`, the keys + are expected to be `cost_fn_xx`, `cost_fn_yy`, and if applicable, + `cost_fn_xy`. + scale_cost: How to scale the cost matrix for the OT problem solved by + the `ot_solver`. In the linear case, this is always expected to be + not a :class:`dict`. If the problem is of quadratic type and + `scale_cost` is a string, the `scale_cost` argument is used for all + terms, i.e. both quadratic terms and, if applicable, the linear temr. + If of type :class:`dict`, the keys are expected to be `scale_cost_xx`, + `scale_cost_yy`, and if applicable, `scale_cost_xy`. optimizer: Optimizer for `neural_vector_field`. 
flow: Flow between latent distribution and target distribution. time_sampler: Sampler for the time. checkpoint_manager: Checkpoint manager. - k_samples_per_x: Number of samples drawn from the conditional distribution of an input sample, see algorithm TODO. - solver_latent_to_data: Linear OT solver to match the latent distribution with the conditional distribution. Only applicable if `k_samples_per_x` is larger than :math:`1`. #TODO: adapt - kwargs_solver_latent_to_data: Keyword arguments for `solver_latent_to_data`. #TODO: adapt - fused_penalty: Fused penalty of the linear/fused term in the Fused Gromov-Wasserstein problem. + k_samples_per_x: Number of samples drawn from the conditional distribution + of an input sample, see algorithm TODO. + solver_latent_to_data: Linear OT solver to match the latent distribution + with the conditional distribution. + kwargs_solver_latent_to_data: Keyword arguments for `solver_latent_to_data`. + #TODO: adapt + fused_penalty: Fused penalty of the linear/fused term in the Fused + Gromov-Wasserstein problem. tau_a: If :math:`<1`, defines how much unbalanced the problem is on the first marginal. tau_b: If :math:`< 1`, defines how much unbalanced the problem is on the second marginal. - mlp_eta: Neural network to learn the left rescaling function. If `None`, the left rescaling factor is not learnt. - mlp_xi: Neural network to learn the right rescaling function. If `None`, the right rescaling factor is not learnt. + mlp_eta: Neural network to learn the left rescaling function. If `None`, + the left rescaling factor is not learnt. + mlp_xi: Neural network to learn the right rescaling function. If `None`, + the right rescaling factor is not learnt. unbalanced_kwargs: Keyword arguments for the unbalancedness solver. callback_fn: Callback function. rng: Random number generator. 
@@ -107,7 +127,7 @@ def __init__( tau_b: float = 1.0, mlp_eta: Callable[[jnp.ndarray], float] = None, mlp_xi: Callable[[jnp.ndarray], float] = None, - unbalanced_kwargs: Dict[str, Any] = {}, + unbalanced_kwargs: Dict[str, Any] = types.MappingProxyType({}), callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any]] = None, rng: random.PRNGKeyArray = random.PRNGKey(0), @@ -133,8 +153,9 @@ def __init__( ot_solver, gromov_wasserstein.GromovWasserstein ) and epsilon is not None: raise ValueError( - "If `ot_solver` is `GromovWasserstein`, `epsilon` must be `None`. This check is performed " - "to ensure that in the (fused) Gromov case the `epsilon` parameter is passed via the `ot_solver`." + "If `ot_solver` is `GromovWasserstein`, `epsilon` must be `None`. " + + "This check is performed to ensure that in the (fused) Gromov case " + + "the `epsilon` parameter is passed via the `ot_solver`." ) self.rng = rng @@ -356,8 +377,11 @@ def transport( ) -> Union[jnp.array, diffrax.Solution, Optional[jnp.ndarray]]: """Transport data with the learnt plan. - This method pushes-forward the `source` to its conditional distribution by solving the neural ODE parameterized by the :attr:`~ott.neural.solvers.GENOTg.neural_vector_field` from - :attr:`~ott.neural.flows.BaseTimeSampler.low` to :attr:`~ott.neural.flows.BaseTimeSampler.high`. + This method pushes-forward the `source` to its conditional distribution by + solving the neural ODE parameterized by the + :attr:`~ott.neural.solvers.GENOTg.neural_vector_field` from + :attr:`~ott.neural.flows.BaseTimeSampler.low` to + :attr:`~ott.neural.flows.BaseTimeSampler.high`. Args: source: Data to transport. @@ -367,7 +391,8 @@ def transport( diffeqsolve_kwargs: Keyword arguments for the ODE solver. Returns: - The push-forward or pull-back distribution defined by the learnt transport plan. + The push-forward or pull-back distribution defined by the learnt + transport plan. 
""" if not forward: @@ -411,7 +436,7 @@ def _valid_step(self, valid_loader, iter) -> None: @property def learn_rescaling(self) -> bool: - """Whether to learn at least one rescaling factor of the marginal distributions.""" + """Whether to learn at least one rescaling factor.""" return self.mlp_eta is not None or self.mlp_xi is not None def save(self, path: str) -> None: diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index b69d7978e..57720f139 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -51,7 +51,10 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): - """Flow matching as introduced in :cite:`lipman:22`, with extension to OT-FM (:cite`tong:23`, :cite:`pooladian:23`). + """(Optimal transport) flow matching class. + + Flow matching as introduced in :cite:`lipman:22`, with extension to OT-FM + (:cite`tong:23`, :cite:`pooladian:23`). Args: neural_vector_field: Neural vector field parameterized by a neural network. @@ -59,20 +62,26 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): cond_dim: Dimension of the conditioning variable. iterations: Number of iterations. valid_freq: Frequency of validation. - ot_solver: OT solver to match samples from the source and the target distribution as proposed in :cite:`tong:23`, :cite:`pooladian:23`. If `None`, no matching will be performed as proposed in :cite:`lipman:22`. + ot_solver: OT solver to match samples from the source and the target + distribution as proposed in :cite:`tong:23`, :cite:`pooladian:23`. + If `None`, no matching will be performed as proposed in :cite:`lipman:22`. flow: Flow between source and target distribution. time_sampler: Sampler for the time. optimizer: Optimizer for `neural_vector_field`. checkpoint_manager: Checkpoint manager. - epsilon: Entropy regularization term of the OT OT problem solved by the `ot_solver`. 
+ epsilon: Entropy regularization term of the OT OT problem solved by the + `ot_solver`. cost_fn: Cost function for the OT problem solved by the `ot_solver`. - scale_cost: How to scale the cost matrix for the OT problem solved by the `ot_solver`. + scale_cost: How to scale the cost matrix for the OT problem solved by the + `ot_solver`. tau_a: If :math:`<1`, defines how much unbalanced the problem is on the first marginal. tau_b: If :math:`< 1`, defines how much unbalanced the problem is on the second marginal. - mlp_eta: Neural network to learn the left rescaling function as suggested in :cite:`TODO`. If `None`, the left rescaling factor is not learnt. - mlp_xi: Neural network to learn the right rescaling function as suggested in :cite:`TODO`. If `None`, the right rescaling factor is not learnt. + mlp_eta: Neural network to learn the left rescaling function as suggested + in :cite:`TODO`. If `None`, the left rescaling factor is not learnt. + mlp_xi: Neural network to learn the right rescaling function as suggested + in :cite:`TODO`. If `None`, the right rescaling factor is not learnt. unbalanced_kwargs: Keyword arguments for the unbalancedness solver. callback_fn: Callback function. num_eval_samples: Number of samples to evaluate on during evaluation. @@ -271,8 +280,10 @@ def transport( ) -> diffrax.Solution: """Transport data with the learnt map. - This method solves the neural ODE parameterized by the :attr:`~ott.neural.solvers.OTFlowMatching.neural_vector_field` from - :attr:`~ott.neural.flows.BaseTimeSampler.low` to :attr:`~ott.neural.flows.BaseTimeSampler.high` if `forward` is `True`, + This method solves the neural ODE parameterized by the + :attr:`~ott.neural.solvers.OTFlowMatching.neural_vector_field` from + :attr:`~ott.neural.flows.BaseTimeSampler.low` to + :attr:`~ott.neural.flows.BaseTimeSampler.high` if `forward` is `True`, else the other way round. Args: @@ -282,7 +293,8 @@ def transport( diffeqsolve_kwargs: Keyword arguments for the ODE solver. 
Returns: - The push-forward or pull-back distribution defined by the learnt transport plan. + The push-forward or pull-back distribution defined by the learnt + transport plan. """ diffeqsolve_kwargs = dict(diffeqsolve_kwargs) @@ -320,7 +332,7 @@ def _valid_step(self, valid_loader, iter) -> None: @property def learn_rescaling(self) -> bool: - """Whether to learn at least one rescaling factor of the marginal distributions.""" + """Whether to learn at least one rescaling factor.""" return self.mlp_eta is not None or self.mlp_xi is not None def save(self, path: str) -> None: diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 5c7c1d431..d7db29817 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -37,9 +37,8 @@ def test_genot_linear_unconditional( solver_latent_to_data = None if solver_latent_to_data is None else sinkhorn.Sinkhorn( ) batch = next(genot_data_loader_linear) - source_lin, source_quad, target_lin, target_quad, source_condition = batch[ - "source_lin"], batch["source_quad"], batch["target_lin"], batch[ - "target_quad"], batch["source_conditions"] + source_lin, target_lin, source_condition = batch[ + "source_lin"], batch["target_lin"], batch["source_conditions"] source_dim = source_lin.shape[1] target_dim = target_lin.shape[1] @@ -72,9 +71,8 @@ def test_genot_linear_unconditional( genot(genot_data_loader_linear, genot_data_loader_linear) batch = next(genot_data_loader_linear) - source_lin, source_quad, target_lin, target_quad, source_condition = batch[ - "source_lin"], batch["source_quad"], batch["target_lin"], batch[ - "target_quad"], batch["source_conditions"] + source_lin, target_lin, source_condition = batch["source_lin"], batch[ + "target_lin"], batch["source_conditions"] result_forward = genot.transport( source_lin, condition=source_condition, forward=True From ffec70c58765e7e75377f8a220a07b5e22b2aa13 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 29 Nov 2023 13:35:53 +0100 Subject: [PATCH 
044/186] resolve pre-commit errors --- src/ott/neural/solvers/base_solver.py | 2 +- src/ott/neural/solvers/genot.py | 26 +++++++++++++++++--------- src/ott/neural/solvers/otfm.py | 21 ++++++++++++++------- tests/neural/genot_test.py | 14 +++++++++----- tests/neural/otfm_test.py | 5 ++++- 5 files changed, 45 insertions(+), 23 deletions(-) diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index 5c291afaa..6a9ee84ea 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -152,7 +152,7 @@ def _get_sinkhorn_match_fn( self, ot_solver: Any, epsilon: float = 1e-2, - cost_fn: costs.CostFn = costs.SqEuclidean(), + cost_fn: Optional[costs.CostFn] = None, scale_cost: Union[bool, int, float, Literal["mean", "max_norm", "max_bound", "max_cost", "median"]] = "mean", diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 2eb81be98..ae5953de4 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -24,6 +24,7 @@ from jax import random from orbax import checkpoint +from ott import utils from ott.geometry import costs from ott.neural.models.models import BaseNeuralVectorField from ott.neural.solvers.base_solver import ( @@ -130,7 +131,7 @@ def __init__( unbalanced_kwargs: Dict[str, Any] = types.MappingProxyType({}), callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any]] = None, - rng: random.PRNGKeyArray = random.PRNGKey(0), + rng: Optional[jnp.ndarray] = None, ) -> None: rng, rng_unbalanced = random.split(rng) BaseNeuralSolver.__init__( @@ -158,7 +159,7 @@ def __init__( "the `epsilon` parameter is passed via the `ot_solver`." 
) - self.rng = rng + self.rng = utils.default_prng_key(rng) self.neural_vector_field = neural_vector_field self.state_neural_vector_field: Optional[TrainState] = None self.flow = flow @@ -198,8 +199,10 @@ def setup(self) -> None: kwargs Keyword arguments for the setup function """ - self.state_neural_vector_field = self.neural_vector_field.create_train_state( - self.rng, self.optimizer, self.output_dim + self.state_neural_vector_field = ( + self.neural_vector_field.create_train_state( + self.rng, self.optimizer, self.output_dim + ) ) self.step_fn = self._get_step_fn() if self.solver_latent_to_data is not None: @@ -232,9 +235,10 @@ def __call__(self, train_loader, valid_loader) -> None: for iteration in range(self.iterations): batch = next(train_loader) - self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, rng_step_fn = jax.random.split( - self.rng, 6 - ) + ( + self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, + rng_step_fn + ) = jax.random.split(self.rng, 6) batch_size = len( batch["source_lin"] ) if batch["source_lin"] is not None else len(batch["source_quad"]) @@ -303,7 +307,10 @@ def __call__(self, train_loader, valid_loader) -> None: rng_step_fn, self.state_neural_vector_field, batch ) if self.learn_rescaling: - self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, loss_b = self.unbalancedness_step_fn( + ( + self.state_eta, self.state_xi, eta_predictions, xi_predictions, + loss_a, loss_b + ) = self.unbalancedness_step_fn( source=batch["source"], target=batch["target"], condition=batch["source_conditions"], @@ -371,7 +378,7 @@ def transport( self, source: jnp.ndarray, condition: Optional[jnp.ndarray], - rng: random.PRNGKeyArray = random.PRNGKey(0), + rng: Optional[jnp.ndarray] = None, forward: bool = True, diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}), ) -> Union[jnp.array, diffrax.Solution, Optional[jnp.ndarray]]: @@ -395,6 +402,7 @@ def transport( transport plan. 
""" + rng = utils.default_prng_key(rng) if not forward: raise NotImplementedError diffeqsolve_kwargs = dict(diffeqsolve_kwargs) diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index 57720f139..378d4f108 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -34,6 +34,7 @@ from jax import random from orbax import checkpoint +from ott import utils from ott.geometry import costs from ott.neural.models.models import BaseNeuralVectorField from ott.neural.solvers.base_solver import ( @@ -104,7 +105,7 @@ def __init__( optimizer: Type[optax.GradientTransformation], checkpoint_manager: Type[checkpoint.CheckpointManager] = None, epsilon: float = 1e-2, - cost_fn: Type[costs.CostFn] = costs.SqEuclidean(), + cost_fn: Optional[Type[costs.CostFn]] = None, scale_cost: Union[bool, int, float, Literal["mean", "max_norm", "max_bound", "max_cost", "median"]] = "mean", @@ -112,14 +113,15 @@ def __init__( tau_b: float = 1.0, mlp_eta: Callable[[jnp.ndarray], float] = None, mlp_xi: Callable[[jnp.ndarray], float] = None, - unbalanced_kwargs: Dict[str, Any] = {}, + unbalanced_kwargs: Dict[str, Any] = types.MappingProxyType({}), callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any]] = None, logging_freq: int = 100, valid_freq: int = 5000, num_eval_samples: int = 1000, - rng: random.PRNGKeyArray = random.PRNGKey(0), + rng: Optional[jnp.ndarray] = None, ) -> None: + rng = utils.default_prng_key(rng) rng, rng_unbalanced = random.split(rng) BaseNeuralSolver.__init__( self, iterations=iterations, valid_freq=valid_freq @@ -158,8 +160,10 @@ def __init__( def setup(self) -> None: """Setup :class:`OTFlowMatching`.""" - self.state_neural_vector_field = self.neural_vector_field.create_train_state( - self.rng, self.optimizer, self.input_dim + self.state_neural_vector_field = ( + self.neural_vector_field.create_train_state( + self.rng, self.optimizer, self.input_dim + ) ) self.step_fn = self._get_step_fn() @@ -250,7 +254,10 
@@ def __call__(self, train_loader, valid_loader) -> None: self._training_logs["loss"].append(curr_loss / self.logging_freq) curr_loss = 0.0 if self.learn_rescaling: - self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, loss_b = self.unbalancedness_step_fn( + ( + self.state_eta, self.state_xi, eta_predictions, xi_predictions, + loss_a, loss_b + ) = self.unbalancedness_step_fn( source=batch["source_lin"], target=batch["target_lin"], condition=batch["source_conditions"], @@ -293,7 +300,7 @@ def transport( diffeqsolve_kwargs: Keyword arguments for the ODE solver. Returns: - The push-forward or pull-back distribution defined by the learnt + The push-forward or pull-back distribution defined by the learnt transport plan. """ diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index d7db29817..794b5d44e 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -34,11 +34,12 @@ def test_genot_linear_unconditional( self, genot_data_loader_linear: Iterator, k_samples_per_x: int, solver_latent_to_data: Optional[str] ): - solver_latent_to_data = None if solver_latent_to_data is None else sinkhorn.Sinkhorn( + solver_latent_to_data = ( + None if solver_latent_to_data is None else sinkhorn.Sinkhorn() ) batch = next(genot_data_loader_linear) - source_lin, target_lin, source_condition = batch[ - "source_lin"], batch["target_lin"], batch["source_conditions"] + source_lin, target_lin, source_condition = batch["source_lin"], batch[ + "target_lin"], batch["source_conditions"] source_dim = source_lin.shape[1] target_dim = target_lin.shape[1] @@ -133,7 +134,7 @@ def test_genot_fused_unconditional( ): None if solver_latent_to_data is None else sinkhorn.Sinkhorn() batch = next(genot_data_loader_fused) - batch = source_lin, source_quad, target_lin, target_quad, source_condition = batch[ + source_lin, source_quad, target_lin, target_quad, source_condition = batch[ "source_lin"], batch["source_quad"], batch["target_lin"], batch[ 
"target_quad"], batch["source_conditions"] @@ -324,7 +325,10 @@ def test_genot_linear_learn_rescaling( genot_data_loader_linear_conditional: Iterator ): None if solver_latent_to_data is None else sinkhorn.Sinkhorn() - data_loader = genot_data_loader_linear_conditional if conditional else genot_data_loader_linear + data_loader = ( + genot_data_loader_linear_conditional + if conditional else genot_data_loader_linear + ) batch = next(data_loader) source_lin, target_lin, source_condition = batch["source_lin"], batch[ diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index d8deb1102..e77789938 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -172,7 +172,10 @@ def test_flow_matching_learn_rescaling( self, conditional: bool, data_loader_gaussian: Iterator, data_loader_gaussian_conditional: Iterator ): - data_loader = data_loader_gaussian_conditional if conditional else data_loader_gaussian + data_loader = ( + data_loader_gaussian_conditional + if conditional else data_loader_gaussian + ) batch = next(data_loader) source_dim = batch["source_lin"].shape[1] condition_dim = batch["source_conditions"].shape[1] if conditional else 0 From 10d70f24f60c0639c597c06010520b27e7a19c3b Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 29 Nov 2023 13:37:24 +0100 Subject: [PATCH 045/186] fix rng bug --- src/ott/neural/solvers/genot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index ae5953de4..fbfca23ce 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -133,6 +133,7 @@ def __init__( Any]] = None, rng: Optional[jnp.ndarray] = None, ) -> None: + rng = utils.default_prng_key(rng) rng, rng_unbalanced = random.split(rng) BaseNeuralSolver.__init__( self, iterations=iterations, valid_freq=valid_freq From 9fb308bb29abfeacde673c1debfdf299ade35d58 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: 
Wed, 29 Nov 2023 13:45:06 +0100 Subject: [PATCH 046/186] Update pre-commit --- .pre-commit-config.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 396cca399..d54c42330 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,24 +7,24 @@ default_stages: minimum_pre_commit_version: 3.0.0 repos: - repo: https://github.com/google/yapf - rev: v0.40.0 + rev: v0.40.2 hooks: - id: yapf additional_dependencies: [toml] - repo: https://github.com/nbQA-dev/nbQA - rev: 1.7.0 + rev: 1.7.1 hooks: - id: nbqa-pyupgrade args: [--py38-plus] - id: nbqa-black - id: nbqa-isort - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.10.0 + rev: v2.11.0 hooks: - id: pretty-format-yaml args: [--autofix, --indent, '2'] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: detect-private-key - id: check-ast @@ -38,12 +38,12 @@ repos: - id: check-case-conflict - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. 
- rev: v0.0.285 + rev: v0.1.6 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/rstcheck/rstcheck - rev: v6.1.2 + rev: v6.2.0 hooks: - id: rstcheck additional_dependencies: [tomli] From aa0bdc58a1e2cffbafe412ada21d6ba3285bcb79 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 29 Nov 2023 14:00:14 +0100 Subject: [PATCH 047/186] fix import error --- tests/neural/icnn_test.py | 2 +- tests/neural/meta_initializer_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/neural/icnn_test.py b/tests/neural/icnn_test.py index c52eac675..4d760557f 100644 --- a/tests/neural/icnn_test.py +++ b/tests/neural/icnn_test.py @@ -16,7 +16,7 @@ import numpy as np import pytest -from ott.neural import models +from ott.neural.models import models @pytest.mark.fast() diff --git a/tests/neural/meta_initializer_test.py b/tests/neural/meta_initializer_test.py index 98aa4f4d0..f978e8206 100644 --- a/tests/neural/meta_initializer_test.py +++ b/tests/neural/meta_initializer_test.py @@ -20,7 +20,7 @@ from ott.geometry import pointcloud from ott.initializers.linear import initializers as linear_init -from ott.neural import models as nn_init +from ott.neural.models import models as nn_init from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn From b48dfdc80386100dada0d8953c5a115afeececa8 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 29 Nov 2023 14:25:43 +0100 Subject: [PATCH 048/186] Run linter --- .pre-commit-config.yaml | 40 ++++++++++--------- docs/tutorials/MetaOT.ipynb | 2 + docs/tutorials/Monge_Gap.ipynb | 3 +- docs/tutorials/icnn_inits.ipynb | 1 + docs/tutorials/neural_dual.ipynb | 3 +- docs/tutorials/point_clouds.ipynb | 4 ++ docs/tutorials/soft_sort.ipynb | 7 ++-- .../sparse_monge_displacements.ipynb | 2 + docs/tutorials/tracking_progress.ipynb | 2 + pyproject.toml | 12 +++--- src/ott/math/__init__.py | 7 +--- 
src/ott/neural/models/base_models.py | 3 +- src/ott/neural/models/layers.py | 3 +- src/ott/neural/models/models.py | 5 ++- src/ott/neural/solvers/base_solver.py | 4 +- src/ott/neural/solvers/genot.py | 16 ++++---- src/ott/neural/solvers/map_estimator.py | 1 + src/ott/neural/solvers/neuraldual.py | 3 +- src/ott/neural/solvers/otfm.py | 27 ++++++------- src/ott/problems/linear/potentials.py | 10 +---- src/ott/solvers/linear/lineax_implicit.py | 5 ++- tests/conftest.py | 6 ++- tests/geometry/costs_test.py | 3 +- tests/geometry/geodesic_test.py | 10 +++-- tests/geometry/graph_test.py | 10 +++-- tests/geometry/low_rank_test.py | 3 +- tests/geometry/pointcloud_test.py | 3 +- tests/geometry/scaling_cost_test.py | 3 +- tests/geometry/subsetting_test.py | 3 +- .../initializers/linear/sinkhorn_init_test.py | 3 +- .../linear/sinkhorn_lr_init_test.py | 3 +- tests/initializers/quadratic/gw_init_test.py | 3 +- tests/math/lse_test.py | 3 +- tests/math/math_utils_test.py | 3 +- tests/math/matrix_square_root_test.py | 3 +- tests/neural/conftest.py | 3 +- tests/neural/genot_test.py | 4 +- tests/neural/icnn_test.py | 3 +- tests/neural/losses_test.py | 3 +- tests/neural/map_estimator_test.py | 3 +- tests/neural/meta_initializer_test.py | 4 +- tests/neural/neuraldual_test.py | 3 +- tests/neural/otfm_test.py | 4 +- tests/problems/linear/potentials_test.py | 6 ++- .../linear/continuous_barycenter_test.py | 3 +- .../linear/discrete_barycenter_test.py | 3 +- tests/solvers/linear/sinkhorn_diff_test.py | 3 +- tests/solvers/linear/sinkhorn_grid_test.py | 3 +- tests/solvers/linear/sinkhorn_lr_test.py | 3 +- tests/solvers/linear/sinkhorn_misc_test.py | 7 +++- tests/solvers/linear/sinkhorn_test.py | 3 +- tests/solvers/linear/univariate_test.py | 3 +- tests/solvers/quadratic/fgw_test.py | 3 +- tests/solvers/quadratic/gw_barycenter_test.py | 3 +- tests/solvers/quadratic/gw_test.py | 3 +- tests/solvers/quadratic/lower_bound_test.py | 3 +- .../gaussian_mixture/fit_gmm_pair_test.py | 3 +- 
tests/tools/gaussian_mixture/fit_gmm_test.py | 3 +- .../gaussian_mixture_pair_test.py | 3 +- .../gaussian_mixture/gaussian_mixture_test.py | 3 +- tests/tools/gaussian_mixture/gaussian_test.py | 3 +- tests/tools/gaussian_mixture/linalg_test.py | 3 +- .../gaussian_mixture/probabilities_test.py | 3 +- .../tools/gaussian_mixture/scale_tril_test.py | 3 +- tests/tools/k_means_test.py | 3 +- tests/tools/plot_test.py | 1 + tests/tools/segment_sinkhorn_test.py | 3 +- tests/tools/sinkhorn_divergence_test.py | 3 +- tests/tools/soft_sort_test.py | 3 +- 69 files changed, 200 insertions(+), 129 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d54c42330..1f84672bb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,23 +6,6 @@ default_stages: - push minimum_pre_commit_version: 3.0.0 repos: -- repo: https://github.com/google/yapf - rev: v0.40.2 - hooks: - - id: yapf - additional_dependencies: [toml] -- repo: https://github.com/nbQA-dev/nbQA - rev: 1.7.1 - hooks: - - id: nbqa-pyupgrade - args: [--py38-plus] - - id: nbqa-black - - id: nbqa-isort -- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.11.0 - hooks: - - id: pretty-format-yaml - args: [--autofix, --indent, '2'] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: @@ -37,11 +20,32 @@ repos: - id: trailing-whitespace - id: check-case-conflict - repo: https://github.com/charliermarsh/ruff-pre-commit - # Ruff version. 
rev: v0.1.6 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] +- repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + name: isort +- repo: https://github.com/google/yapf + rev: v0.40.2 + hooks: + - id: yapf + additional_dependencies: [toml] +- repo: https://github.com/nbQA-dev/nbQA + rev: 1.7.1 + hooks: + - id: nbqa-pyupgrade + args: [--py38-plus] + - id: nbqa-black + - id: nbqa-isort +- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks + rev: v2.11.0 + hooks: + - id: pretty-format-yaml + args: [--autofix, --indent, '2'] - repo: https://github.com/rstcheck/rstcheck rev: v6.2.0 hooks: diff --git a/docs/tutorials/MetaOT.ipynb b/docs/tutorials/MetaOT.ipynb index 79503617d..9349786a1 100644 --- a/docs/tutorials/MetaOT.ipynb +++ b/docs/tutorials/MetaOT.ipynb @@ -81,6 +81,8 @@ "outputs": [], "source": [ "# Obtain the MNIST dataset and flatten the images into discrete measures.\n", + "\n", + "\n", "def get_mnist_flat(train):\n", " dataset = torchvision.datasets.MNIST(\n", " \"/tmp/mnist/\",\n", diff --git a/docs/tutorials/Monge_Gap.ipynb b/docs/tutorials/Monge_Gap.ipynb index 53bc670dc..2fde4f923 100644 --- a/docs/tutorials/Monge_Gap.ipynb +++ b/docs/tutorials/Monge_Gap.ipynb @@ -31,8 +31,9 @@ "\n", "import jax\n", "import jax.numpy as jnp\n", - "import optax\n", "import sklearn.datasets\n", + "\n", + "import optax\n", "from flax import linen as nn\n", "\n", "from matplotlib import pyplot as plt\n", diff --git a/docs/tutorials/icnn_inits.ipynb b/docs/tutorials/icnn_inits.ipynb index 8d8444507..24ca43f6f 100644 --- a/docs/tutorials/icnn_inits.ipynb +++ b/docs/tutorials/icnn_inits.ipynb @@ -33,6 +33,7 @@ "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", + "\n", "import optax\n", "\n", "import matplotlib.pyplot as plt\n", diff --git a/docs/tutorials/neural_dual.ipynb b/docs/tutorials/neural_dual.ipynb index c1d9461d2..3fadb58ca 100644 --- a/docs/tutorials/neural_dual.ipynb +++ 
b/docs/tutorials/neural_dual.ipynb @@ -49,9 +49,10 @@ "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", - "import optax\n", "from torch.utils.data import DataLoader, IterableDataset\n", "\n", + "import optax\n", + "\n", "import matplotlib.pyplot as plt\n", "from IPython.display import clear_output, display\n", "\n", diff --git a/docs/tutorials/point_clouds.ipynb b/docs/tutorials/point_clouds.ipynb index fd20ffc9a..c01b51cfd 100644 --- a/docs/tutorials/point_clouds.ipynb +++ b/docs/tutorials/point_clouds.ipynb @@ -279,6 +279,8 @@ "outputs": [], "source": [ "# Helper function to plot successively the optimal transports\n", + "\n", + "\n", "def plot_ots(ots):\n", " fig = plt.figure(figsize=(8, 5))\n", " plott = ott.tools.plot.Plot(fig=fig)\n", @@ -366973,6 +366975,8 @@ "outputs": [], "source": [ "# Plotting utility\n", + "\n", + "\n", "def plot_map(x, y, z, forward: bool = True):\n", " plt.figure(figsize=(10, 8))\n", " marker_t = \"o\" if forward else \"X\"\n", diff --git a/docs/tutorials/soft_sort.ipynb b/docs/tutorials/soft_sort.ipynb index cf0f751ac..880506731 100644 --- a/docs/tutorials/soft_sort.ipynb +++ b/docs/tutorials/soft_sort.ipynb @@ -37,16 +37,17 @@ "\n", "from tqdm.notebook import tqdm\n", "\n", - "import flax.linen as nn\n", "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", - "import optax\n", "import torchvision\n", - "from flax import struct\n", "from scipy import ndimage\n", "from torch.utils import data\n", "\n", + "import flax.linen as nn\n", + "import optax\n", + "from flax import struct\n", + "\n", "import matplotlib.pyplot as plt\n", "\n", "from ott.tools import soft_sort" diff --git a/docs/tutorials/sparse_monge_displacements.ipynb b/docs/tutorials/sparse_monge_displacements.ipynb index a21213703..8b735d9f7 100644 --- a/docs/tutorials/sparse_monge_displacements.ipynb +++ b/docs/tutorials/sparse_monge_displacements.ipynb @@ -114,6 +114,8 @@ "outputs": [], "source": [ "# Plotting utility\n", + "\n", + "\n", 
"def plot_map(x, y, x_new=None, z=None, ax=None, title=None):\n", " if ax is None:\n", " f, ax = plt.subplots(figsize=(10, 8))\n", diff --git a/docs/tutorials/tracking_progress.ipynb b/docs/tutorials/tracking_progress.ipynb index b8a230da6..cd358252b 100644 --- a/docs/tutorials/tracking_progress.ipynb +++ b/docs/tutorials/tracking_progress.ipynb @@ -373,6 +373,8 @@ "outputs": [], "source": [ "# Samples spiral\n", + "\n", + "\n", "def sample_spiral(\n", " n, min_radius, max_radius, key, min_angle=0, max_angle=10, noise=1.0\n", "):\n", diff --git a/pyproject.toml b/pyproject.toml index 530c55113..1961a5971 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,11 +103,14 @@ include = '\.ipynb$' [tool.isort] profile = "black" +line_length = 80 include_trailing_comma = true multi_line_output = 3 -sections = ["FUTURE", "STDLIB", "THIRDPARTY", "NUMERIC", "PLOTTING", "FIRSTPARTY", "LOCALFOLDER"] -# also contains what we import in notebooks -known_numeric = ["numpy", "scipy", "jax", "flax", "optax", "jaxopt", "torch", "ot", "torchvision", "pandas", "sklearn"] +sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TEST", "NUMERIC", "NEURAL", "PLOTTING", "FIRSTPARTY", "LOCALFOLDER"] +# also contains what we import in notebooks/tests +known_neural = ["flax", "optax", "diffrax", "orbax"] +known_numeric = ["numpy", "scipy", "jax", "flax", "optax", "jaxopt", "torch", "ot", "torchvision", "pandas", "sklearn", "tslearn"] +known_test = ["pytest"] known_plotting = ["IPython", "matplotlib", "mpl_toolkits", "seaborn"] [tool.pytest.ini_options] @@ -286,7 +289,6 @@ ignore = [ line-length = 80 select = [ "D", # flake8-docstrings - "I", # isort "E", # pycodestyle "F", # pyflakes "W", # pycodestyle @@ -302,7 +304,7 @@ select = [ "T20", # flake8-print "RET", # flake8-raise ] -unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"] +unfixable = ["I", "B", "UP", "C4", "BLE", "T20", "RET"] target-version = "py38" [tool.ruff.per-file-ignores] # TODO(michalk8): PO004 - remove `self.initialize` diff --git 
a/src/ott/math/__init__.py b/src/ott/math/__init__.py index 64bc1c07b..ce2a09a73 100644 --- a/src/ott/math/__init__.py +++ b/src/ott/math/__init__.py @@ -11,9 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import ( - fixed_point_loop, - matrix_square_root, - unbalanced_functions, - utils, -) +from . import fixed_point_loop, matrix_square_root, unbalanced_functions, utils diff --git a/src/ott/neural/models/base_models.py b/src/ott/neural/models/base_models.py index 8b5dc126a..d3ac7526a 100644 --- a/src/ott/neural/models/base_models.py +++ b/src/ott/neural/models/base_models.py @@ -14,9 +14,10 @@ import abc from typing import Optional -import flax.linen as nn import jax.numpy as jnp +import flax.linen as nn + __all__ = ["BaseNeuralVectorField", "BaseRescalingNet"] diff --git a/src/ott/neural/models/layers.py b/src/ott/neural/models/layers.py index 153087141..79e6394bc 100644 --- a/src/ott/neural/models/layers.py +++ b/src/ott/neural/models/layers.py @@ -15,7 +15,8 @@ import jax import jax.numpy as jnp -from flax import linen as nn + +import flax.linen as nn __all__ = ["PositiveDense", "PosDefPotentials"] diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 9b15cb803..c65cbbaf3 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -14,13 +14,14 @@ import functools from typing import Any, Callable, Dict, Optional, Sequence, Tuple -import flax.linen as nn import jax import jax.numpy as jnp +from jax.nn import initializers + +import flax.linen as nn import optax from flax.core import frozen_dict from flax.training import train_state -from jax.nn import initializers from ott import utils from ott.geometry import geometry diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index 6a9ee84ea..bde81e9da 100644 --- 
a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -18,9 +18,9 @@ import jax import jax.numpy as jnp + import optax from flax.training import train_state -from jax import random from ott.geometry import costs, pointcloud from ott.geometry.pointcloud import PointCloud @@ -89,7 +89,7 @@ def _resample_data( ) -> Tuple[jnp.ndarray, ...]: """Resample a batch according to coupling `tmat`.""" tmat_flattened = tmat.flatten() - indices = random.choice(key, len(tmat_flattened), shape=[tmat.shape[0]]) + indices = jax.random.choice(key, len(tmat_flattened), shape=[tmat.shape[0]]) indices_source = indices // tmat.shape[1] indices_target = indices % tmat.shape[1] return tuple( diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index fbfca23ce..0613ae53c 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -15,13 +15,13 @@ import types from typing import Any, Callable, Dict, Literal, Optional, Type, Union -import diffrax import jax import jax.numpy as jnp + +import diffrax import optax from flax.training import train_state from flax.training.train_state import TrainState -from jax import random from orbax import checkpoint from ott import utils @@ -134,7 +134,7 @@ def __init__( rng: Optional[jnp.ndarray] = None, ) -> None: rng = utils.default_prng_key(rng) - rng, rng_unbalanced = random.split(rng) + rng, rng_unbalanced = jax.random.split(rng) BaseNeuralSolver.__init__( self, iterations=iterations, valid_freq=valid_freq ) @@ -343,7 +343,7 @@ def step_fn( def loss_fn( params: jnp.ndarray, batch: Dict[str, jnp.array], - keys_model: random.PRNGKeyArray + keys_model: jax.random.PRNGKeyArray ): x_t = self.flow.compute_xt( batch["noise"], batch["time"], batch["latent"], batch["target"] @@ -366,7 +366,7 @@ def loss_fn( ) return jnp.mean((v_t - u_t) ** 2) - keys_model = random.split(key, len(batch["noise"])) + keys_model = jax.random.split(key, len(batch["noise"])) grad_fn = 
jax.value_and_grad(loss_fn, has_aux=False) loss, grads = grad_fn(state_neural_vector_field.params, batch, keys_model) @@ -472,7 +472,9 @@ def training_logs(self) -> Dict[str, Any]: """Logs of the training.""" raise NotImplementedError - def sample_noise(self, key: random.PRNGKey, batch_size: int) -> jnp.ndarray: + def sample_noise( + self, key: jax.random.PRNGKey, batch_size: int + ) -> jnp.ndarray: """Sample noise from a standard-normal distribution. Args: @@ -482,4 +484,4 @@ def sample_noise(self, key: random.PRNGKey, batch_size: int) -> jnp.ndarray: Returns: Samples from the standard normal distribution. """ - return random.normal(key, shape=(batch_size, self.output_dim)) + return jax.random.normal(key, shape=(batch_size, self.output_dim)) diff --git a/src/ott/neural/solvers/map_estimator.py b/src/ott/neural/solvers/map_estimator.py index b97f673b0..7eaffdfc8 100644 --- a/src/ott/neural/solvers/map_estimator.py +++ b/src/ott/neural/solvers/map_estimator.py @@ -26,6 +26,7 @@ import jax import jax.numpy as jnp + import optax from flax.core import frozen_dict from flax.training import train_state diff --git a/src/ott/neural/solvers/neuraldual.py b/src/ott/neural/solvers/neuraldual.py index e78666ec6..a7da8c3e7 100644 --- a/src/ott/neural/solvers/neuraldual.py +++ b/src/ott/neural/solvers/neuraldual.py @@ -27,9 +27,10 @@ import jax import jax.numpy as jnp + +import flax.linen as nn import optax from flax import core, struct -from flax import linen as nn from flax.core import frozen_dict from flax.training import train_state diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index 378d4f108..fb054e30a 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -26,12 +26,12 @@ Union, ) -import diffrax import jax import jax.numpy as jnp + +import diffrax import optax from flax.training import train_state -from jax import random from orbax import checkpoint from ott import utils @@ -42,10 +42,7 @@ ResampleMixin, 
UnbalancednessMixin, ) -from ott.neural.solvers.flows import ( - BaseFlow, - BaseTimeSampler, -) +from ott.neural.solvers.flows import BaseFlow, BaseTimeSampler from ott.solvers import was_solver __all__ = ["OTFlowMatching"] @@ -122,7 +119,7 @@ def __init__( rng: Optional[jnp.ndarray] = None, ) -> None: rng = utils.default_prng_key(rng) - rng, rng_unbalanced = random.split(rng) + rng, rng_unbalanced = jax.random.split(rng) BaseNeuralSolver.__init__( self, iterations=iterations, valid_freq=valid_freq ) @@ -183,14 +180,14 @@ def _get_step_fn(self) -> Callable: @jax.jit def step_fn( - key: random.PRNGKeyArray, + key: jax.random.PRNGKeyArray, state_neural_vector_field: train_state.TrainState, batch: Dict[str, jnp.ndarray], ) -> Tuple[Any, Any]: def loss_fn( params: jnp.ndarray, t: jnp.ndarray, noise: jnp.ndarray, - batch: Dict[str, jnp.ndarray], keys_model: random.PRNGKeyArray + batch: Dict[str, jnp.ndarray], keys_model: jax.random.PRNGKeyArray ) -> jnp.ndarray: x_t = self.flow.compute_xt( @@ -209,8 +206,8 @@ def loss_fn( return jnp.mean((v_t - u_t) ** 2) batch_size = len(batch["source_lin"]) - key_noise, key_t, key_model = random.split(key, 3) - keys_model = random.split(key_model, batch_size) + key_noise, key_t, key_model = jax.random.split(key, 3) + keys_model = jax.random.split(key_model, batch_size) t = self.time_sampler(key_t, batch_size) noise = self.sample_noise(key_noise, batch_size) grad_fn = jax.value_and_grad(loss_fn) @@ -235,7 +232,7 @@ def __call__(self, train_loader, valid_loader) -> None: curr_loss = 0.0 for iter in range(self.iterations): - rng_resample, rng_step_fn, self.rng = random.split(self.rng, 3) + rng_resample, rng_step_fn, self.rng = jax.random.split(self.rng, 3) batch = next(train_loader) if self.ot_solver is not None: tmat = self.match_fn(batch["source_lin"], batch["target_lin"]) @@ -366,7 +363,9 @@ def training_logs(self) -> Dict[str, Any]: """Logs of the training.""" raise NotImplementedError - def sample_noise(self, key: random.PRNGKey, 
batch_size: int) -> jnp.ndarray: + def sample_noise( + self, key: jax.random.PRNGKey, batch_size: int + ) -> jnp.ndarray: """Sample noise from a standard-normal distribution. Args: @@ -376,4 +375,4 @@ def sample_noise(self, key: random.PRNGKey, batch_size: int) -> jnp.ndarray: Returns: Samples from the standard normal distribution. """ - return random.normal(key, shape=(batch_size, self.input_dim)) + return jax.random.normal(key, shape=(batch_size, self.input_dim)) diff --git a/src/ott/problems/linear/potentials.py b/src/ott/problems/linear/potentials.py index 7ab226072..a91cf5038 100644 --- a/src/ott/problems/linear/potentials.py +++ b/src/ott/problems/linear/potentials.py @@ -11,15 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import ( - Any, - Callable, - Dict, - Literal, - Optional, - Sequence, - Tuple, -) +from typing import Any, Callable, Dict, Literal, Optional, Sequence, Tuple import jax import jax.numpy as jnp diff --git a/src/ott/solvers/linear/lineax_implicit.py b/src/ott/solvers/linear/lineax_implicit.py index 79b9e7c95..30200b073 100644 --- a/src/ott/solvers/linear/lineax_implicit.py +++ b/src/ott/solvers/linear/lineax_implicit.py @@ -14,11 +14,12 @@ from typing import Any, Callable, Optional, TypeVar import equinox as eqx +import lineax as lx +from jaxtyping import Array, Float, PyTree + import jax import jax.numpy as jnp import jax.tree_util as jtu -import lineax as lx -from jaxtyping import Array, Float, PyTree _T = TypeVar("_T") _FlatPyTree = tuple[list[_T], jtu.PyTreeDef] diff --git a/tests/conftest.py b/tests/conftest.py index bc4570343..da7e6a3dc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,11 +15,13 @@ import itertools from typing import Any, Mapping, Optional, Sequence +from _pytest.python import Metafunc + +import pytest + import jax import jax.experimental import jax.numpy 
as jnp -import pytest -from _pytest.python import Metafunc def pytest_generate_tests(metafunc: Metafunc) -> None: diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index b23e79071..02d9976da 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -13,10 +13,11 @@ # limitations under the License. from typing import Type +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, pointcloud from ott.solvers import linear diff --git a/tests/geometry/geodesic_test.py b/tests/geometry/geodesic_test.py index 3891ac144..986246dfd 100644 --- a/tests/geometry/geodesic_test.py +++ b/tests/geometry/geodesic_test.py @@ -13,14 +13,16 @@ # limitations under the License. from typing import Optional, Union -import jax -import jax.numpy as jnp import networkx as nx -import numpy as np -import pytest from networkx.algorithms import shortest_paths from networkx.generators import balanced_tree, random_graphs +import pytest + +import jax +import jax.numpy as jnp +import numpy as np + from ott.geometry import geodesic, geometry, graph from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py index 35dde4c4b..9c79d2b42 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -13,14 +13,16 @@ # limitations under the License. 
from typing import Literal, Optional, Tuple, Union +import networkx as nx +from networkx.algorithms import shortest_paths +from networkx.generators import balanced_tree, random_graphs + +import pytest + import jax import jax.numpy as jnp -import networkx as nx import numpy as np -import pytest from jax.experimental import sparse -from networkx.algorithms import shortest_paths -from networkx.generators import balanced_tree, random_graphs from ott.geometry import geometry, graph from ott.problems.linear import linear_problem diff --git a/tests/geometry/low_rank_test.py b/tests/geometry/low_rank_test.py index b3cda89cf..3e068f8e2 100644 --- a/tests/geometry/low_rank_test.py +++ b/tests/geometry/low_rank_test.py @@ -13,10 +13,11 @@ # limitations under the License. from typing import Callable, Optional, Tuple, Union +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, geometry, grid, low_rank, pointcloud diff --git a/tests/geometry/pointcloud_test.py b/tests/geometry/pointcloud_test.py index ff32789fe..1a952132f 100644 --- a/tests/geometry/pointcloud_test.py +++ b/tests/geometry/pointcloud_test.py @@ -13,10 +13,11 @@ # limitations under the License. from typing import Union +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, geometry, pointcloud diff --git a/tests/geometry/scaling_cost_test.py b/tests/geometry/scaling_cost_test.py index 9f4ad1d57..e321b8524 100644 --- a/tests/geometry/scaling_cost_test.py +++ b/tests/geometry/scaling_cost_test.py @@ -13,10 +13,11 @@ # limitations under the License. 
from typing import Optional, Union +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import geometry, low_rank, pointcloud from ott.problems.linear import linear_problem diff --git a/tests/geometry/subsetting_test.py b/tests/geometry/subsetting_test.py index 5d7306682..7298001cc 100644 --- a/tests/geometry/subsetting_test.py +++ b/tests/geometry/subsetting_test.py @@ -13,10 +13,11 @@ # limitations under the License. from typing import Optional, Sequence, Tuple, Type, Union +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import geometry, low_rank, pointcloud diff --git a/tests/initializers/linear/sinkhorn_init_test.py b/tests/initializers/linear/sinkhorn_init_test.py index 7686ddfa9..8cc20f4c0 100644 --- a/tests/initializers/linear/sinkhorn_init_test.py +++ b/tests/initializers/linear/sinkhorn_init_test.py @@ -13,10 +13,11 @@ # limitations under the License. from typing import Literal, Optional +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import geometry, pointcloud from ott.initializers.linear import initializers as linear_init diff --git a/tests/initializers/linear/sinkhorn_lr_init_test.py b/tests/initializers/linear/sinkhorn_lr_init_test.py index e954fec76..0b67d2286 100644 --- a/tests/initializers/linear/sinkhorn_lr_init_test.py +++ b/tests/initializers/linear/sinkhorn_lr_init_test.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import geometry, pointcloud from ott.initializers.linear import initializers_lr diff --git a/tests/initializers/quadratic/gw_init_test.py b/tests/initializers/quadratic/gw_init_test.py index e680e9c01..8ab6cc4e5 100644 --- a/tests/initializers/quadratic/gw_init_test.py +++ b/tests/initializers/quadratic/gw_init_test.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import pointcloud from ott.initializers.linear import initializers as lin_init diff --git a/tests/math/lse_test.py b/tests/math/lse_test.py index 36e7eba7f..3ff28eada 100644 --- a/tests/math/lse_test.py +++ b/tests/math/lse_test.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.math import utils as mu diff --git a/tests/math/math_utils_test.py b/tests/math/math_utils_test.py index 7848bc2a9..3bd4c8114 100644 --- a/tests/math/math_utils_test.py +++ b/tests/math/math_utils_test.py @@ -13,10 +13,11 @@ # limitations under the License. import functools +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.math import utils as mu diff --git a/tests/math/matrix_square_root_test.py b/tests/math/matrix_square_root_test.py index 7c8a1e7d5..3f4aee25b 100644 --- a/tests/math/matrix_square_root_test.py +++ b/tests/math/matrix_square_root_test.py @@ -13,10 +13,11 @@ # limitations under the License. 
from typing import Any, Callable +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.math import matrix_square_root diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index c6d25b128..74d66dea3 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -1,6 +1,7 @@ -import numpy as np import pytest +import numpy as np + from ott.neural.data.dataloaders import ConditionalDataLoader, OTDataLoader diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 794b5d44e..fddc4fc3c 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -13,9 +13,11 @@ # limitations under the License. from typing import Iterator, Optional +import pytest + import jax.numpy as jnp + import optax -import pytest from ott.geometry import costs from ott.neural.models.models import NeuralVectorField, RescalingMLP diff --git a/tests/neural/icnn_test.py b/tests/neural/icnn_test.py index c52eac675..f710bdcbc 100644 --- a/tests/neural/icnn_test.py +++ b/tests/neural/icnn_test.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.neural import models diff --git a/tests/neural/losses_test.py b/tests/neural/losses_test.py index 4569b04d1..8e4a2f96c 100644 --- a/tests/neural/losses_test.py +++ b/tests/neural/losses_test.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs from ott.neural.models import losses, models diff --git a/tests/neural/map_estimator_test.py b/tests/neural/map_estimator_test.py index 96f9a9797..b5df51170 100644 --- a/tests/neural/map_estimator_test.py +++ b/tests/neural/map_estimator_test.py @@ -13,9 +13,10 @@ # limitations under the License. from typing import Optional -import jax.numpy as jnp import pytest +import jax.numpy as jnp + from ott import datasets from ott.geometry import pointcloud from ott.neural.models import losses, models diff --git a/tests/neural/meta_initializer_test.py b/tests/neural/meta_initializer_test.py index 98aa4f4d0..92f0c0b40 100644 --- a/tests/neural/meta_initializer_test.py +++ b/tests/neural/meta_initializer_test.py @@ -13,9 +13,11 @@ # limitations under the License. from typing import Optional +import pytest + import jax import jax.numpy as jnp -import pytest + from flax import linen as nn from ott.geometry import pointcloud diff --git a/tests/neural/neuraldual_test.py b/tests/neural/neuraldual_test.py index 1b7818163..b31ba9b6a 100644 --- a/tests/neural/neuraldual_test.py +++ b/tests/neural/neuraldual_test.py @@ -13,9 +13,10 @@ # limitations under the License. from typing import Optional, Sequence, Tuple +import pytest + import jax import numpy as np -import pytest from ott import datasets from ott.neural import models diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index e77789938..a57588a43 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -13,9 +13,11 @@ # limitations under the License. 
from typing import Iterator, Type +import pytest + import jax.numpy as jnp + import optax -import pytest from ott.neural.models.models import NeuralVectorField, RescalingMLP from ott.neural.solvers.flows import ( diff --git a/tests/problems/linear/potentials_test.py b/tests/problems/linear/potentials_test.py index dd5d4bbd6..a13211119 100644 --- a/tests/problems/linear/potentials_test.py +++ b/tests/problems/linear/potentials_test.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import jax import jax.numpy as jnp -import matplotlib.pyplot as plt import numpy as np -import pytest + +import matplotlib.pyplot as plt from ott.geometry import costs, pointcloud from ott.problems.linear import linear_problem, potentials diff --git a/tests/solvers/linear/continuous_barycenter_test.py b/tests/solvers/linear/continuous_barycenter_test.py index 4989cc1db..730b529d3 100644 --- a/tests/solvers/linear/continuous_barycenter_test.py +++ b/tests/solvers/linear/continuous_barycenter_test.py @@ -14,10 +14,11 @@ import functools from typing import Tuple +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, segment from ott.problems.linear import barycenter_problem diff --git a/tests/solvers/linear/discrete_barycenter_test.py b/tests/solvers/linear/discrete_barycenter_test.py index dc90e15c0..56784fb07 100644 --- a/tests/solvers/linear/discrete_barycenter_test.py +++ b/tests/solvers/linear/discrete_barycenter_test.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import jax.numpy as jnp import pytest +import jax.numpy as jnp + from ott.geometry import grid, pointcloud from ott.problems.linear import barycenter_problem as bp from ott.solvers.linear import discrete_barycenter as db diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index 944534e14..04de4dca9 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -14,10 +14,11 @@ import functools from typing import Callable, List, Optional, Tuple +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, geometry, grid, pointcloud from ott.problems.linear import linear_problem diff --git a/tests/solvers/linear/sinkhorn_grid_test.py b/tests/solvers/linear/sinkhorn_grid_test.py index dd22f63b7..e7c116c8d 100644 --- a/tests/solvers/linear/sinkhorn_grid_test.py +++ b/tests/solvers/linear/sinkhorn_grid_test.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import grid, pointcloud from ott.problems.linear import linear_problem diff --git a/tests/solvers/linear/sinkhorn_lr_test.py b/tests/solvers/linear/sinkhorn_lr_test.py index 90b149ea8..0ce5a2307 100644 --- a/tests/solvers/linear/sinkhorn_lr_test.py +++ b/tests/solvers/linear/sinkhorn_lr_test.py @@ -13,10 +13,11 @@ # limitations under the License. 
from typing import Any, Tuple +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import low_rank, pointcloud from ott.problems.linear import linear_problem diff --git a/tests/solvers/linear/sinkhorn_misc_test.py b/tests/solvers/linear/sinkhorn_misc_test.py index e97a34228..d9d6d616c 100644 --- a/tests/solvers/linear/sinkhorn_misc_test.py +++ b/tests/solvers/linear/sinkhorn_misc_test.py @@ -14,16 +14,19 @@ from typing import Optional import chex + +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, geometry, pointcloud from ott.problems.linear import linear_problem from ott.solvers import linear -from ott.solvers.linear import acceleration, sinkhorn +from ott.solvers.linear import acceleration from ott.solvers.linear import implicit_differentiation as implicit_lib +from ott.solvers.linear import sinkhorn class TestSinkhornAnderson: diff --git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index c7475c4f3..0437a4efa 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -15,10 +15,11 @@ import sys from typing import Optional, Tuple +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott import utils from ott.geometry import costs, epsilon_scheduler, geometry, grid, pointcloud diff --git a/tests/solvers/linear/univariate_test.py b/tests/solvers/linear/univariate_test.py index 166da36bc..a002882fb 100644 --- a/tests/solvers/linear/univariate_test.py +++ b/tests/solvers/linear/univariate_test.py @@ -13,10 +13,11 @@ # limitations under the License. 
import functools +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest import scipy as sp from ott.geometry import costs, pointcloud diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index 10361d088..f998e802e 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -13,10 +13,11 @@ # limitations under the License. from typing import Literal, Tuple, Union +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import geometry, low_rank, pointcloud from ott.problems.quadratic import quadratic_problem diff --git a/tests/solvers/quadratic/gw_barycenter_test.py b/tests/solvers/quadratic/gw_barycenter_test.py index a157f27e5..eba4e3054 100644 --- a/tests/solvers/quadratic/gw_barycenter_test.py +++ b/tests/solvers/quadratic/gw_barycenter_test.py @@ -13,10 +13,11 @@ # limitations under the License. from typing import Any, Optional, Sequence, Tuple +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import pointcloud from ott.problems.quadratic import gw_barycenter as gwb diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py index e7ef7b558..816f7fcd6 100644 --- a/tests/solvers/quadratic/gw_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ -13,10 +13,11 @@ # limitations under the License. 
from typing import Tuple, Union +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import geometry, low_rank, pointcloud from ott.problems.quadratic import quadratic_problem diff --git a/tests/solvers/quadratic/lower_bound_test.py b/tests/solvers/quadratic/lower_bound_test.py index ad65be477..2e30a1bbe 100644 --- a/tests/solvers/quadratic/lower_bound_test.py +++ b/tests/solvers/quadratic/lower_bound_test.py @@ -15,10 +15,11 @@ import functools from typing import Callable +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, distrib_costs, pointcloud from ott.initializers.linear import initializers diff --git a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py index 75fc3bef5..20fe4ef4a 100644 --- a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py +++ b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import jax import jax.numpy as jnp -import pytest from ott.tools.gaussian_mixture import ( fit_gmm, diff --git a/tests/tools/gaussian_mixture/fit_gmm_test.py b/tests/tools/gaussian_mixture/fit_gmm_test.py index 1cfb4f95e..648e9a287 100644 --- a/tests/tools/gaussian_mixture/fit_gmm_test.py +++ b/tests/tools/gaussian_mixture/fit_gmm_test.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest + import jax import jax.numpy as jnp import jax.test_util -import pytest from ott.tools.gaussian_mixture import fit_gmm, gaussian_mixture diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py index bf2b01699..b11431d8c 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.tools.gaussian_mixture import gaussian_mixture, gaussian_mixture_pair diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_test.py index fd7675d51..540ebe980 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_test.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.tools.gaussian_mixture import gaussian_mixture, linalg diff --git a/tests/tools/gaussian_mixture/gaussian_test.py b/tests/tools/gaussian_mixture/gaussian_test.py index 0eac630e3..23deff00d 100644 --- a/tests/tools/gaussian_mixture/gaussian_test.py +++ b/tests/tools/gaussian_mixture/gaussian_test.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.tools.gaussian_mixture import gaussian, scale_tril diff --git a/tests/tools/gaussian_mixture/linalg_test.py b/tests/tools/gaussian_mixture/linalg_test.py index 6fedb13ae..4529364dc 100644 --- a/tests/tools/gaussian_mixture/linalg_test.py +++ b/tests/tools/gaussian_mixture/linalg_test.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.tools.gaussian_mixture import linalg diff --git a/tests/tools/gaussian_mixture/probabilities_test.py b/tests/tools/gaussian_mixture/probabilities_test.py index 5d28a52aa..4fce8186f 100644 --- a/tests/tools/gaussian_mixture/probabilities_test.py +++ b/tests/tools/gaussian_mixture/probabilities_test.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.tools.gaussian_mixture import probabilities diff --git a/tests/tools/gaussian_mixture/scale_tril_test.py b/tests/tools/gaussian_mixture/scale_tril_test.py index 36643b6d7..f7bbe9293 100644 --- a/tests/tools/gaussian_mixture/scale_tril_test.py +++ b/tests/tools/gaussian_mixture/scale_tril_test.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.math import matrix_square_root from ott.tools.gaussian_mixture import scale_tril diff --git a/tests/tools/k_means_test.py b/tests/tools/k_means_test.py index 9b504a82d..c00288cec 100644 --- a/tests/tools/k_means_test.py +++ b/tests/tools/k_means_test.py @@ -15,10 +15,11 @@ import sys from typing import Any, Literal, Optional, Tuple, Union +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from sklearn import datasets from sklearn.cluster import KMeans, kmeans_plusplus from sklearn.cluster._k_means_common import _is_same_clustering diff --git a/tests/tools/plot_test.py b/tests/tools/plot_test.py index 8c8b81a1c..1f9f9ba01 100644 --- a/tests/tools/plot_test.py +++ b/tests/tools/plot_test.py @@ -13,6 +13,7 @@ # limitations under the License. import jax + import matplotlib.pyplot as plt import ott diff --git a/tests/tools/segment_sinkhorn_test.py b/tests/tools/segment_sinkhorn_test.py index 6e8a8fb8c..f98c164bf 100644 --- a/tests/tools/segment_sinkhorn_test.py +++ b/tests/tools/segment_sinkhorn_test.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, pointcloud from ott.problems.linear import linear_problem diff --git a/tests/tools/sinkhorn_divergence_test.py b/tests/tools/sinkhorn_divergence_test.py index 0f3e56bfc..e3eab9912 100644 --- a/tests/tools/sinkhorn_divergence_test.py +++ b/tests/tools/sinkhorn_divergence_test.py @@ -13,10 +13,11 @@ # limitations under the License. 
from typing import Any, Dict, Optional +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, geometry, pointcloud from ott.solvers import linear diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py index 2432a2dee..c84680e9e 100644 --- a/tests/tools/soft_sort_test.py +++ b/tests/tools/soft_sort_test.py @@ -14,10 +14,11 @@ import functools from typing import Tuple +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.solvers.linear import acceleration from ott.solvers.linear import implicit_differentiation as implicit_lib From 4371e74949e3d252de1839afe10639bf01fcdaf0 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 29 Nov 2023 14:40:42 +0100 Subject: [PATCH 049/186] replace rng jnp.ndarray type by jax.array --- docs/tutorials/GWLRSinkhorn.ipynb | 2 +- src/ott/datasets.py | 2 +- .../initializers/linear/initializers_lr.py | 28 ++++++++-------- src/ott/neural/solvers/base_solver.py | 2 +- src/ott/neural/solvers/flows.py | 6 ++-- src/ott/neural/solvers/neuraldual.py | 4 +-- .../solvers/linear/continuous_barycenter.py | 2 +- .../solvers/quadratic/gromov_wasserstein.py | 4 +-- src/ott/solvers/quadratic/gw_barycenter.py | 2 +- src/ott/tools/gaussian_mixture/fit_gmm.py | 6 ++-- src/ott/tools/gaussian_mixture/gaussian.py | 4 +-- .../gaussian_mixture/gaussian_mixture.py | 4 +-- src/ott/tools/gaussian_mixture/linalg.py | 4 +-- .../tools/gaussian_mixture/probabilities.py | 4 +-- src/ott/tools/gaussian_mixture/scale_tril.py | 2 +- src/ott/tools/k_means.py | 10 +++--- tests/geometry/costs_test.py | 20 ++++++------ tests/geometry/graph_test.py | 10 +++--- tests/geometry/low_rank_test.py | 26 +++++++-------- tests/geometry/pointcloud_test.py | 10 +++--- tests/geometry/scaling_cost_test.py | 2 +- tests/geometry/subsetting_test.py | 8 ++--- .../initializers/linear/sinkhorn_init_test.py | 18 +++++------ .../linear/sinkhorn_lr_init_test.py | 8 ++--- 
tests/initializers/quadratic/gw_init_test.py | 3 +- tests/math/lse_test.py | 2 +- tests/math/math_utils_test.py | 2 +- tests/math/matrix_square_root_test.py | 6 ++-- tests/neural/icnn_test.py | 4 +-- tests/neural/losses_test.py | 7 ++-- tests/neural/meta_initializer_test.py | 4 +-- tests/problems/linear/potentials_test.py | 14 ++++---- .../linear/continuous_barycenter_test.py | 8 ++--- tests/solvers/linear/sinkhorn_diff_test.py | 24 +++++++------- tests/solvers/linear/sinkhorn_grid_test.py | 8 ++--- tests/solvers/linear/sinkhorn_lr_test.py | 2 +- tests/solvers/linear/sinkhorn_misc_test.py | 10 +++--- tests/solvers/linear/sinkhorn_test.py | 2 +- tests/solvers/linear/univariate_test.py | 4 +-- tests/solvers/quadratic/fgw_test.py | 6 ++-- tests/solvers/quadratic/gw_barycenter_test.py | 6 ++-- tests/solvers/quadratic/gw_test.py | 12 +++---- tests/solvers/quadratic/lower_bound_test.py | 4 +-- .../gaussian_mixture/fit_gmm_pair_test.py | 2 +- tests/tools/gaussian_mixture/fit_gmm_test.py | 2 +- .../gaussian_mixture_pair_test.py | 2 +- .../gaussian_mixture/gaussian_mixture_test.py | 16 +++++----- tests/tools/gaussian_mixture/gaussian_test.py | 18 +++++------ tests/tools/gaussian_mixture/linalg_test.py | 20 ++++++------ .../gaussian_mixture/probabilities_test.py | 4 +-- .../tools/gaussian_mixture/scale_tril_test.py | 12 +++---- tests/tools/k_means_test.py | 32 +++++++++---------- tests/tools/segment_sinkhorn_test.py | 2 +- tests/tools/sinkhorn_divergence_test.py | 4 +-- tests/tools/soft_sort_test.py | 24 +++++++------- 55 files changed, 222 insertions(+), 232 deletions(-) diff --git a/docs/tutorials/GWLRSinkhorn.ipynb b/docs/tutorials/GWLRSinkhorn.ipynb index 590671428..ace06be8f 100644 --- a/docs/tutorials/GWLRSinkhorn.ipynb +++ b/docs/tutorials/GWLRSinkhorn.ipynb @@ -66,7 +66,7 @@ }, "outputs": [], "source": [ - "def create_points(rng: jnp.ndarray, n: int, m: int, d1: int, d2: int):\n", + "def create_points(rng: jax.Array, n: int, m: int, d1: int, d2: int):\n", " rngs = 
jax.random.split(rng, 5)\n", " x = jax.random.uniform(rngs[0], (n, d1))\n", " y = jax.random.uniform(rngs[1], (m, d2))\n", diff --git a/src/ott/datasets.py b/src/ott/datasets.py index 1946bdcdd..3507c3418 100644 --- a/src/ott/datasets.py +++ b/src/ott/datasets.py @@ -57,7 +57,7 @@ class GaussianMixture: """ name: Name_t batch_size: int - rng: jnp.ndarray + rng: jax.Array scale: float = 5.0 std: float = 0.5 diff --git a/src/ott/initializers/linear/initializers_lr.py b/src/ott/initializers/linear/initializers_lr.py index b1f70d912..a3f615846 100644 --- a/src/ott/initializers/linear/initializers_lr.py +++ b/src/ott/initializers/linear/initializers_lr.py @@ -67,7 +67,7 @@ def __init__(self, rank: int, **kwargs: Any): def init_q( self, ot_prob: Problem_t, - rng: jnp.ndarray, + rng: jax.Array, *, init_g: jnp.ndarray, **kwargs: Any, @@ -88,7 +88,7 @@ def init_q( def init_r( self, ot_prob: Problem_t, - rng: jnp.ndarray, + rng: jax.Array, *, init_g: jnp.ndarray, **kwargs: Any, @@ -109,7 +109,7 @@ def init_r( def init_g( self, ot_prob: Problem_t, - rng: jnp.ndarray, + rng: jax.Array, **kwargs: Any, ) -> jnp.ndarray: """Initialize the low-rank factor :math:`g`. 
@@ -232,7 +232,7 @@ class RandomInitializer(LRInitializer): def init_q( # noqa: D102 self, ot_prob: Problem_t, - rng: jnp.ndarray, + rng: jax.Array, *, init_g: jnp.ndarray, **kwargs: Any, @@ -245,7 +245,7 @@ def init_q( # noqa: D102 def init_r( # noqa: D102 self, ot_prob: Problem_t, - rng: jnp.ndarray, + rng: jax.Array, *, init_g: jnp.ndarray, **kwargs: Any, @@ -258,7 +258,7 @@ def init_r( # noqa: D102 def init_g( # noqa: D102 self, ot_prob: Problem_t, - rng: jnp.ndarray, + rng: jax.Array, **kwargs: Any, ) -> jnp.ndarray: del kwargs @@ -305,7 +305,7 @@ def _compute_factor( def init_q( # noqa: D102 self, ot_prob: Problem_t, - rng: jnp.ndarray, + rng: jax.Array, *, init_g: jnp.ndarray, **kwargs: Any, @@ -316,7 +316,7 @@ def init_q( # noqa: D102 def init_r( # noqa: D102 self, ot_prob: Problem_t, - rng: jnp.ndarray, + rng: jax.Array, *, init_g: jnp.ndarray, **kwargs: Any, @@ -327,7 +327,7 @@ def init_r( # noqa: D102 def init_g( # noqa: D102 self, ot_prob: Problem_t, - rng: jnp.ndarray, + rng: jax.Array, **kwargs: Any, ) -> jnp.ndarray: del rng, kwargs @@ -376,7 +376,7 @@ def _extract_array(geom: geometry.Geometry, *, first: bool) -> jnp.ndarray: def _compute_factor( self, ot_prob: Problem_t, - rng: jnp.ndarray, + rng: jax.Array, *, init_g: jnp.ndarray, which: Literal["q", "r"], @@ -418,7 +418,7 @@ def _compute_factor( def init_q( # noqa: D102 self, ot_prob: Problem_t, - rng: jnp.ndarray, + rng: jax.Array, *, init_g: jnp.ndarray, **kwargs: Any, @@ -430,7 +430,7 @@ def init_q( # noqa: D102 def init_r( # noqa: D102 self, ot_prob: Problem_t, - rng: jnp.ndarray, + rng: jax.Array, *, init_g: jnp.ndarray, **kwargs: Any, @@ -442,7 +442,7 @@ def init_r( # noqa: D102 def init_g( # noqa: D102 self, ot_prob: Problem_t, - rng: jnp.ndarray, + rng: jax.Array, **kwargs: Any, ) -> jnp.ndarray: del rng, kwargs @@ -511,7 +511,7 @@ class State(NamedTuple): # noqa: D106 def _compute_factor( self, ot_prob: Problem_t, - rng: jnp.ndarray, + rng: jax.Array, *, init_g: jnp.ndarray, which: 
Literal["q", "r"], diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index bde81e9da..fe0ea6f3d 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -258,7 +258,7 @@ class UnbalancednessMixin: def __init__( self, - rng: jnp.ndarray, + rng: jax.Array, source_dim: int, target_dim: int, cond_dim: Optional[int], diff --git a/src/ott/neural/solvers/flows.py b/src/ott/neural/solvers/flows.py index b61ff08d1..47be01fc5 100644 --- a/src/ott/neural/solvers/flows.py +++ b/src/ott/neural/solvers/flows.py @@ -180,7 +180,7 @@ def __init__(self, low: float, high: float) -> None: self.high = high @abc.abstractmethod - def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: + def __call__(self, rng: jax.Array, num_samples: int) -> jnp.ndarray: """Generate `num_samples` samples of the time `math`:t:. Args: @@ -201,7 +201,7 @@ class UniformSampler(BaseTimeSampler): def __init__(self, low: float = 0.0, high: float = 1.0) -> None: super().__init__(low=low, high=high) - def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: + def __call__(self, rng: jax.Array, num_samples: int) -> jnp.ndarray: """Generate `num_samples` samples of the time `math`:t:. Args: @@ -234,7 +234,7 @@ def __init__( super().__init__(low=low, high=high) self.offset = offset - def __call__(self, rng: jnp.ndarray, num_samples: int) -> jnp.ndarray: + def __call__(self, rng: jax.Array, num_samples: int) -> jnp.ndarray: """Generate `num_samples` samples of the time `math`:t:. 
Args: diff --git a/src/ott/neural/solvers/neuraldual.py b/src/ott/neural/solvers/neuraldual.py index a7da8c3e7..0d9e215bb 100644 --- a/src/ott/neural/solvers/neuraldual.py +++ b/src/ott/neural/solvers/neuraldual.py @@ -149,7 +149,7 @@ def potential_gradient_fn( def create_train_state( self, - rng: jnp.ndarray, + rng: jax.Array, optimizer: optax.OptState, input: Union[int, Tuple[int, ...]], **kwargs: Any, @@ -289,7 +289,7 @@ def __init__( def setup( self, - rng: jnp.ndarray, + rng: jax.Array, neural_f: BaseW2NeuralDual, neural_g: BaseW2NeuralDual, dim_data: int, diff --git a/src/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py index b93c14032..e1477e60f 100644 --- a/src/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -196,7 +196,7 @@ def output_from_state( # noqa: D102 def iterations( solver: FreeWassersteinBarycenter, bar_size: int, bar_prob: barycenter_problem.FreeBarycenterProblem, x_init: jnp.ndarray, - rng: jnp.ndarray + rng: jax.Array ) -> FreeBarycenterState: """Jittable Wasserstein barycenter outer loop.""" diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 862b91999..6180db73f 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -272,7 +272,7 @@ def init_state( self, prob: quadratic_problem.QuadraticProblem, init: linear_problem.LinearProblem, - rng: jnp.ndarray, + rng: jax.Array, ) -> GWState: """Initialize the state of the Gromov-Wasserstein iterations. 
@@ -361,7 +361,7 @@ def iterations( solver: GromovWasserstein, prob: quadratic_problem.QuadraticProblem, init: linear_problem.LinearProblem, - rng: jnp.ndarray, + rng: jax.Array, ) -> GWOutput: """Jittable Gromov-Wasserstein outer loop.""" diff --git a/src/ott/solvers/quadratic/gw_barycenter.py b/src/ott/solvers/quadratic/gw_barycenter.py index 8816c5ada..ea14880fe 100644 --- a/src/ott/solvers/quadratic/gw_barycenter.py +++ b/src/ott/solvers/quadratic/gw_barycenter.py @@ -282,7 +282,7 @@ def tree_unflatten( # noqa: D102 @partial(jax.vmap, in_axes=[None, 0, None, 0, None]) def init_transports( - solver, rng: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray, + solver, rng: jax.Array, a: jnp.ndarray, b: jnp.ndarray, epsilon: Optional[float] ) -> jnp.ndarray: """Initialize random 2D point cloud and solve the linear OT problem. diff --git a/src/ott/tools/gaussian_mixture/fit_gmm.py b/src/ott/tools/gaussian_mixture/fit_gmm.py index 0e3fbc4e8..4c62bded7 100644 --- a/src/ott/tools/gaussian_mixture/fit_gmm.py +++ b/src/ott/tools/gaussian_mixture/fit_gmm.py @@ -195,7 +195,7 @@ def _dist_sq_one_loc(points: jnp.ndarray, loc: jnp.ndarray) -> jnp.ndarray: def _get_locs( - rng: jnp.ndarray, points: jnp.ndarray, n_components: int + rng: jax.Array, points: jnp.ndarray, n_components: int ) -> jnp.ndarray: """Get the initial component means. 
@@ -229,7 +229,7 @@ def _get_locs( def from_kmeans_plusplus( - rng: jnp.ndarray, + rng: jax.Array, points: jnp.ndarray, point_weights: Optional[jnp.ndarray], n_components: int, @@ -265,7 +265,7 @@ def from_kmeans_plusplus( def initialize( - rng: jnp.ndarray, + rng: jax.Array, points: jnp.ndarray, point_weights: Optional[jnp.ndarray], n_components: int, diff --git a/src/ott/tools/gaussian_mixture/gaussian.py b/src/ott/tools/gaussian_mixture/gaussian.py index 70ac505f2..6e0a8ccb7 100644 --- a/src/ott/tools/gaussian_mixture/gaussian.py +++ b/src/ott/tools/gaussian_mixture/gaussian.py @@ -63,7 +63,7 @@ def from_samples( @classmethod def from_random( cls, - rng: jnp.ndarray, + rng: jax.Array, n_dimensions: int, stdev_mean: float = 0.1, stdev_cov: float = 0.1, @@ -138,7 +138,7 @@ def log_prob( -0.5 * (d * LOG2PI + log_det[None] + jnp.sum(z ** 2., axis=-1)) ) # (?, k) - def sample(self, rng: jnp.ndarray, size: int) -> jnp.ndarray: + def sample(self, rng: jax.Array, size: int) -> jnp.ndarray: """Generate samples from the distribution.""" std_samples_t = jax.random.normal(key=rng, shape=(self.n_dimensions, size)) return self.loc[None] + ( diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture.py b/src/ott/tools/gaussian_mixture/gaussian_mixture.py index 5d40a870d..313689939 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture.py @@ -78,7 +78,7 @@ def __init__( @classmethod def from_random( cls, - rng: jnp.ndarray, + rng: jax.Array, n_components: int, n_dimensions: int, stdev_mean: float = 0.1, @@ -219,7 +219,7 @@ def components(self) -> List[gaussian.Gaussian]: """List of all GMM components.""" return [self.get_component(i) for i in range(self.n_components)] - def sample(self, rng: jnp.ndarray, size: int) -> jnp.ndarray: + def sample(self, rng: jax.Array, size: int) -> jnp.ndarray: """Generate samples from the distribution.""" subrng0, subrng1 = jax.random.split(rng) component = 
self.component_weight_ob.sample(rng=subrng0, size=size) diff --git a/src/ott/tools/gaussian_mixture/linalg.py b/src/ott/tools/gaussian_mixture/linalg.py index 9c88df0cc..8e71369f3 100644 --- a/src/ott/tools/gaussian_mixture/linalg.py +++ b/src/ott/tools/gaussian_mixture/linalg.py @@ -132,9 +132,7 @@ def invmatvectril( def get_random_orthogonal( - rng: jnp.ndarray, - dim: int, - dtype: Optional[jnp.dtype] = None + rng: jax.Array, dim: int, dtype: Optional[jnp.dtype] = None ) -> jnp.ndarray: """Get a random orthogonal matrix with the specified dimension.""" m = jax.random.normal(key=rng, shape=[dim, dim], dtype=dtype) diff --git a/src/ott/tools/gaussian_mixture/probabilities.py b/src/ott/tools/gaussian_mixture/probabilities.py index 66a90c1a7..6df3bb023 100644 --- a/src/ott/tools/gaussian_mixture/probabilities.py +++ b/src/ott/tools/gaussian_mixture/probabilities.py @@ -35,7 +35,7 @@ def __init__(self, params): @classmethod def from_random( cls, - rng: jnp.ndarray, + rng: jax.Array, n_dimensions: int, stdev: Optional[float] = 0.1, dtype: Optional[jnp.dtype] = None @@ -76,7 +76,7 @@ def probs(self) -> jnp.ndarray: """Get the probabilities.""" return jax.nn.softmax(self.unnormalized_log_probs()) - def sample(self, rng: jnp.ndarray, size: int) -> jnp.ndarray: + def sample(self, rng: jax.Array, size: int) -> jnp.ndarray: """Sample from the distribution.""" return jax.random.categorical( key=rng, logits=self.unnormalized_log_probs(), shape=(size,) diff --git a/src/ott/tools/gaussian_mixture/scale_tril.py b/src/ott/tools/gaussian_mixture/scale_tril.py index 95b812d99..b286cc74e 100644 --- a/src/ott/tools/gaussian_mixture/scale_tril.py +++ b/src/ott/tools/gaussian_mixture/scale_tril.py @@ -44,7 +44,7 @@ def from_points_and_weights( @classmethod def from_random( cls, - rng: jnp.ndarray, + rng: jax.Array, n_dimensions: int, stdev: Optional[float] = 0.1, dtype: jnp.dtype = jnp.float32, diff --git a/src/ott/tools/k_means.py b/src/ott/tools/k_means.py index 9175abe2c..abbe99f34 
100644 --- a/src/ott/tools/k_means.py +++ b/src/ott/tools/k_means.py @@ -29,7 +29,7 @@ class KPPState(NamedTuple): # noqa: D101 - rng: jnp.ndarray + rng: jax.Array centroids: jnp.ndarray centroid_dists: jnp.ndarray @@ -109,7 +109,7 @@ def _from_state( def _random_init( - geom: pointcloud.PointCloud, k: int, rng: jnp.ndarray + geom: pointcloud.PointCloud, k: int, rng: jax.Array ) -> jnp.ndarray: ixs = jnp.arange(geom.shape[0]) ixs = jax.random.choice(rng, ixs, shape=(k,), replace=False) @@ -119,11 +119,11 @@ def _random_init( def _k_means_plus_plus( geom: pointcloud.PointCloud, k: int, - rng: jnp.ndarray, + rng: jax.Array, n_local_trials: Optional[int] = None, ) -> jnp.ndarray: - def init_fn(geom: pointcloud.PointCloud, rng: jnp.ndarray) -> KPPState: + def init_fn(geom: pointcloud.PointCloud, rng: jax.Array) -> KPPState: rng, next_rng = jax.random.split(rng, 2) ix = jax.random.choice(rng, jnp.arange(geom.shape[0]), shape=()) centroids = jnp.full((k, geom.cost_rank), jnp.inf).at[0].set(geom.x[ix]) @@ -224,7 +224,7 @@ def _update_centroids( @functools.partial(jax.vmap, in_axes=[0] + [None] * 9) def _k_means( - rng: jnp.ndarray, + rng: jax.Array, geom: pointcloud.PointCloud, k: int, weights: Optional[jnp.ndarray] = None, diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index 02d9976da..0a7bead17 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -36,7 +36,7 @@ def _proj(matrix: jnp.ndarray) -> jnp.ndarray: @pytest.mark.fast() class TestCostFn: - def test_cosine(self, rng: jnp.ndarray): + def test_cosine(self, rng: jax.Array): """Test the cosine cost function.""" x = jnp.array([0, 0]) y = jnp.array([0, 0]) @@ -85,7 +85,7 @@ def test_cosine(self, rng: jnp.ndarray): @pytest.mark.fast() class TestBuresBarycenter: - def test_bures(self, rng: jnp.ndarray): + def test_bures(self, rng: jax.Array): d = 3 r = jnp.array([1.2036, 0.2825, 0.013]) Sigma1 = r * jnp.eye(d) @@ -142,7 +142,7 @@ class TestRegTICost: ) def 
test_reg_cost_legendre( self, - rng: jnp.ndarray, + rng: jax.Array, scaling_reg: float, cost_fn_t: Type[costs.RegTICost], use_mat: bool, @@ -164,7 +164,7 @@ def test_reg_cost_legendre( @pytest.mark.parametrize("k", [1, 3, 10]) @pytest.mark.parametrize("d", [10, 50]) - def test_elastic_sq_k_overlap(self, rng: jnp.ndarray, k: int, d: int): + def test_elastic_sq_k_overlap(self, rng: jax.Array, k: int, d: int): expected = jax.random.normal(rng, (d,)) cost_fn = costs.ElasticSqKOverlap(k=k, scaling_reg=1e-2) @@ -179,9 +179,7 @@ def test_elastic_sq_k_overlap(self, rng: jnp.ndarray, k: int, d: int): costs.ElasticSqKOverlap(k=3, scaling_reg=17) ] ) - def test_sparse_displacement( - self, rng: jnp.ndarray, cost_fn: costs.RegTICost - ): + def test_sparse_displacement(self, rng: jax.Array, cost_fn: costs.RegTICost): frac_sparse = 0.7 rng1, rng2 = jax.random.split(rng, 2) d = 17 @@ -197,7 +195,7 @@ def test_sparse_displacement( @pytest.mark.parametrize("cost_type_t", [costs.ElasticL1, costs.ElasticSTVS]) def test_stronger_regularization_increases_sparsity( - self, rng: jnp.ndarray, cost_type_t: Type[costs.RegTICost] + self, rng: jax.Array, cost_type_t: Type[costs.RegTICost] ): d, rngs = 17, jax.random.split(rng, 4) x = jax.random.normal(rngs[0], (50, d)) @@ -226,7 +224,7 @@ class TestSoftDTW: @pytest.mark.parametrize("n", [7, 10]) @pytest.mark.parametrize("m", [9, 10]) @pytest.mark.parametrize("gamma", [1e-3, 5]) - def test_soft_dtw(self, rng: jnp.ndarray, n: int, m: int, gamma: float): + def test_soft_dtw(self, rng: jax.Array, n: int, m: int, gamma: float): rng1, rng2 = jax.random.split(rng, 2) t1 = jax.random.normal(rng1, (n,)) t2 = jax.random.normal(rng2, (m,)) @@ -239,7 +237,7 @@ def test_soft_dtw(self, rng: jnp.ndarray, n: int, m: int, gamma: float): @pytest.mark.parametrize(("debiased", "jit"), [(False, True), (True, False)]) def test_soft_dtw_debiased( self, - rng: jnp.ndarray, + rng: jax.Array, debiased: bool, jit: bool, ): @@ -266,7 +264,7 @@ def 
test_soft_dtw_debiased( @pytest.mark.parametrize(("debiased", "jit"), [(False, False), (True, True)]) @pytest.mark.parametrize("gamma", [1e-2, 1]) def test_soft_dtw_grad( - self, rng: jnp.ndarray, debiased: bool, jit: bool, gamma: float + self, rng: jax.Array, debiased: bool, jit: bool, gamma: float ): rngs = jax.random.split(rng, 4) eps, tol = 1e-3, 1e-5 diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py index 9c79d2b42..a68179db3 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -78,7 +78,7 @@ def gt_geometry( class TestGraph: - def test_kernel_is_symmetric_positive_definite(self, rng: jnp.ndarray): + def test_kernel_is_symmetric_positive_definite(self, rng: jax.Array): n, tol = 65, 0.02 x = jax.random.normal(rng, (n,)) geom = graph.Graph.from_graph(random_graph(n), t=1e-3) @@ -115,7 +115,7 @@ def test_automatic_t(self): ) def test_approximates_ground_truth( self, - rng: jnp.ndarray, + rng: jax.Array, numerical_scheme: Literal["backward_euler", "crank_nicolson"], ): eps, n_steps = 1e-5, 20 @@ -209,7 +209,7 @@ def laplacian(G: jnp.ndarray) -> jnp.ndarray: np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) @pytest.mark.fast.with_args(jit=[False, True], only_fast=0) - def test_graph_sinkhorn(self, rng: jnp.ndarray, jit: bool): + def test_graph_sinkhorn(self, rng: jax.Array, jit: bool): def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: solver = sinkhorn.Sinkhorn(lse_mode=False) @@ -252,7 +252,7 @@ def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: ids=["not-implicit", "implicit"], ) def test_dense_graph_differentiability( - self, rng: jnp.ndarray, implicit_diff: bool + self, rng: jax.Array, implicit_diff: bool ): def callback( @@ -287,7 +287,7 @@ def callback( actual = 2 * jnp.vdot(v_w, grad_w) np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4) - def test_tolerance_hilbert_metric(self, rng: jnp.ndarray): + def test_tolerance_hilbert_metric(self, rng: 
jax.Array): n, n_steps, t, tol = 256, 1000, 1e-4, 3e-4 G = random_graph(n, p=0.15) x = jnp.abs(jax.random.normal(rng, (n,))) diff --git a/tests/geometry/low_rank_test.py b/tests/geometry/low_rank_test.py index 3e068f8e2..507042f68 100644 --- a/tests/geometry/low_rank_test.py +++ b/tests/geometry/low_rank_test.py @@ -25,7 +25,7 @@ @pytest.mark.fast() class TestLRGeometry: - def test_apply(self, rng: jnp.ndarray): + def test_apply(self, rng: jax.Array): """Test application of cost to vec or matrix.""" n, m, r = 17, 11, 7 rngs = jax.random.split(rng, 5) @@ -46,7 +46,7 @@ def test_apply(self, rng: jnp.ndarray): @pytest.mark.parametrize("scale_cost", ["mean", "max_cost", "max_bound", 42.]) def test_conversion_pointcloud( - self, rng: jnp.ndarray, scale_cost: Union[str, float] + self, rng: jax.Array, scale_cost: Union[str, float] ): """Test conversion from PointCloud to LRCGeometry.""" n, m, d = 17, 11, 3 @@ -70,7 +70,7 @@ def test_conversion_pointcloud( rtol=1e-4 ) - def test_apply_squared(self, rng: jnp.ndarray): + def test_apply_squared(self, rng: jax.Array): """Test application of squared cost to vec or matrix.""" n, m = 27, 25 rngs = jax.random.split(rng, 5) @@ -95,7 +95,7 @@ def test_apply_squared(self, rng: jnp.ndarray): @pytest.mark.parametrize("bias", [(0, 0), (4, 5)]) @pytest.mark.parametrize("scale_factor", [(1, 1), (2, 3)]) def test_add_lr_geoms( - self, rng: jnp.ndarray, bias: Tuple[float, float], + self, rng: jax.Array, bias: Tuple[float, float], scale_factor: Tuple[float, float] ): """Test application of cost to vec or matrix.""" @@ -134,7 +134,7 @@ def test_add_lr_geoms( @pytest.mark.parametrize(("scale", "scale_cost", "epsilon"), [(0.1, "mean", None), (0.9, "max_cost", 1e-2)]) def test_add_lr_geoms_scale_factor( - self, rng: jnp.ndarray, scale: float, scale_cost: str, + self, rng: jax.Array, scale: float, scale_cost: str, epsilon: Optional[float] ): n, d = 71, 2 @@ -161,8 +161,7 @@ def test_add_lr_geoms_scale_factor( @pytest.mark.parametrize("axis", [0, 
1]) @pytest.mark.parametrize("fn", [lambda x: x + 10, lambda x: x * 2]) def test_apply_affine_function_efficient( - self, rng: jnp.ndarray, fn: Callable[[jnp.ndarray], jnp.ndarray], - axis: int + self, rng: jax.Array, fn: Callable[[jnp.ndarray], jnp.ndarray], axis: int ): n, m, d = 21, 13, 3 rngs = jax.random.split(rng, 3) @@ -182,7 +181,7 @@ def test_apply_affine_function_efficient( np.testing.assert_allclose(res_ineff, res_eff, rtol=1e-4, atol=1e-4) @pytest.mark.parametrize("rank", [5, 1000]) - def test_point_cloud_to_lr(self, rng: jnp.ndarray, rank: int): + def test_point_cloud_to_lr(self, rng: jax.Array, rank: int): n, m = 1500, 1000 scale = 2.0 rngs = jax.random.split(rng, 2) @@ -222,7 +221,7 @@ def assert_upper_bound( assert lhs <= rhs @pytest.mark.fast.with_args(rank=[2, 3], tol=[5e-1, 1e-2], only_fast=0) - def test_geometry_to_lr(self, rng: jnp.ndarray, rank: int, tol: float): + def test_geometry_to_lr(self, rng: jax.Array, rank: int, tol: float): rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, shape=(370, 3)) y = jax.random.normal(rng2, shape=(460, 3)) @@ -243,8 +242,7 @@ def test_geometry_to_lr(self, rng: jnp.ndarray, rank: int, tol: float): only_fast=1 ) def test_point_cloud_to_lr( - self, rng: jnp.ndarray, batch_size: Optional[int], - scale_cost: Optional[str] + self, rng: jax.Array, batch_size: Optional[int], scale_cost: Optional[str] ): rank, tol = 7, 1e-1 rng1, rng2 = jax.random.split(rng, 2) @@ -268,7 +266,7 @@ def test_point_cloud_to_lr( assert geom_lr.cost_rank == rank self.assert_upper_bound(geom, geom_lr, rank=rank, tol=tol) - def test_to_lrc_geometry_noop(self, rng: jnp.ndarray): + def test_to_lrc_geometry_noop(self, rng: jax.Array): rng1, rng2 = jax.random.split(rng, 2) cost1 = jax.random.normal(rng1, shape=(32, 2)) cost2 = jax.random.normal(rng2, shape=(23, 2)) @@ -290,7 +288,7 @@ def test_apply_transport_from_potentials(self): np.testing.assert_allclose(res, 1.1253539e-07, rtol=1e-6, atol=1e-6) 
@pytest.mark.limit_memory("190 MB") - def test_large_scale_factorization(self, rng: jnp.ndarray): + def test_large_scale_factorization(self, rng: jax.Array): rank, tol = 4, 1e-2 rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, shape=(10_000, 7)) @@ -321,7 +319,7 @@ def test_conversion_grid(self): cost_matrix, cost_matrix_lrc, rtol=1e-5, atol=1e-5 ) - def test_full_to_lrc_geometry(self, rng: jnp.ndarray): + def test_full_to_lrc_geometry(self, rng: jax.Array): rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, shape=(13, 7)) y = jax.random.normal(rng2, shape=(29, 7)) diff --git a/tests/geometry/pointcloud_test.py b/tests/geometry/pointcloud_test.py index 1a952132f..cd7cb671c 100644 --- a/tests/geometry/pointcloud_test.py +++ b/tests/geometry/pointcloud_test.py @@ -25,7 +25,7 @@ @pytest.mark.fast() class TestPointCloudApply: - def test_apply_cost_and_kernel(self, rng: jnp.ndarray): + def test_apply_cost_and_kernel(self, rng: jax.Array): """Test consistency of cost/kernel apply to vec.""" n, m, p, b = 5, 8, 10, 7 rngs = jax.random.split(rng, 5) @@ -69,7 +69,7 @@ def test_apply_cost_and_kernel(self, rng: jnp.ndarray): np.testing.assert_allclose(prod0_online, prod0, rtol=1e-03, atol=1e-02) np.testing.assert_allclose(prod1_online, prod1, rtol=1e-03, atol=1e-02) - def test_general_cost_fn(self, rng: jnp.ndarray): + def test_general_cost_fn(self, rng: jax.Array): """Test non-vec cost apply to vec.""" n, m, p, b = 5, 8, 10, 7 rngs = jax.random.split(rng, 5) @@ -98,7 +98,7 @@ def test_correct_shape(self): np.testing.assert_array_equal(pc.shape, (n, m)) @pytest.mark.parametrize("axis", [0, 1]) - def test_apply_cost_without_norm(self, rng: jnp.ndarray, axis: 1): + def test_apply_cost_without_norm(self, rng: jax.Array, axis: 1): rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, shape=(17, 3)) y = jax.random.normal(rng2, shape=(12, 3)) @@ -123,7 +123,7 @@ class TestPointCloudCosineConversion: "scale_cost", ["mean", "median", 
"max_cost", "max_norm", 41] ) def test_cosine_to_sqeucl_conversion( - self, rng: jnp.ndarray, scale_cost: Union[str, float] + self, rng: jax.Array, scale_cost: Union[str, float] ): rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, shape=(101, 4)) @@ -156,7 +156,7 @@ def test_cosine_to_sqeucl_conversion( ) @pytest.mark.parametrize("axis", [0, 1]) def test_apply_cost_cosine_to_sqeucl( - self, rng: jnp.ndarray, axis: int, scale_cost: Union[str, float] + self, rng: jax.Array, axis: int, scale_cost: Union[str, float] ): rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, shape=(17, 5)) diff --git a/tests/geometry/scaling_cost_test.py b/tests/geometry/scaling_cost_test.py index e321b8524..6cd5dcaa9 100644 --- a/tests/geometry/scaling_cost_test.py +++ b/tests/geometry/scaling_cost_test.py @@ -27,7 +27,7 @@ class TestScaleCost: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): self.dim = 4 self.n = 7 self.m = 9 diff --git a/tests/geometry/subsetting_test.py b/tests/geometry/subsetting_test.py index 7298001cc..ebaa6d4ac 100644 --- a/tests/geometry/subsetting_test.py +++ b/tests/geometry/subsetting_test.py @@ -26,7 +26,7 @@ @pytest.fixture() def pc_masked( - rng: jnp.ndarray + rng: jax.Array ) -> Tuple[pointcloud.PointCloud, pointcloud.PointCloud]: n, m = 20, 30 rng1, rng2 = jax.random.split(rng, 2) @@ -67,7 +67,7 @@ class TestMaskPointCloud: "clazz", [geometry.Geometry, pointcloud.PointCloud, low_rank.LRCGeometry] ) def test_mask( - self, rng: jnp.ndarray, clazz: Type[geometry.Geometry], + self, rng: jax.Array, clazz: Type[geometry.Geometry], src_ixs: Optional[Union[int, Sequence[int]]], tgt_ixs: Optional[Union[int, Sequence[int]]] ): @@ -141,7 +141,7 @@ def test_masked_summary( ) def test_mask_permutation( - self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], rng: jnp.ndarray + self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], rng: jax.Array ): rng1, rng2 = 
jax.random.split(rng) geom, _ = geom_masked @@ -163,7 +163,7 @@ def test_mask_permutation( ) def test_boolean_mask( - self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], rng: jnp.ndarray + self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], rng: jax.Array ): rng1, rng2 = jax.random.split(rng) p = jnp.array([0.5, 0.5]) diff --git a/tests/initializers/linear/sinkhorn_init_test.py b/tests/initializers/linear/sinkhorn_init_test.py index 8cc20f4c0..5af512a4a 100644 --- a/tests/initializers/linear/sinkhorn_init_test.py +++ b/tests/initializers/linear/sinkhorn_init_test.py @@ -26,7 +26,7 @@ def create_sorting_problem( - rng: jnp.ndarray, + rng: jax.Array, n: int, epsilon: float = 1e-2, batch_size: Optional[int] = None @@ -56,7 +56,7 @@ def create_sorting_problem( def create_ot_problem( - rng: jnp.ndarray, + rng: jax.Array, n: int, m: int, d: int, @@ -133,9 +133,7 @@ def test_create_initializer(self, init: str): @pytest.mark.parametrize(("vector_min", "lse_mode"), [(True, True), (True, False), (False, True)]) - def test_sorting_init( - self, vector_min: bool, lse_mode: bool, rng: jnp.ndarray - ): + def test_sorting_init(self, vector_min: bool, lse_mode: bool, rng: jax.Array): """Tests sorting dual initializer.""" n = 50 epsilon = 1e-2 @@ -169,7 +167,7 @@ def test_sorting_init( assert sink_out_init.converged assert sink_out_base.n_iters > sink_out_init.n_iters - def test_sorting_init_online(self, rng: jnp.ndarray): + def test_sorting_init_online(self, rng: jax.Array): n = 10 epsilon = 1e-2 @@ -180,7 +178,7 @@ def test_sorting_init_online(self, rng: jnp.ndarray): with pytest.raises(AssertionError, match=r"online"): sort_init.init_dual_a(ot_problem, lse_mode=True) - def test_sorting_init_square_cost(self, rng: jnp.ndarray): + def test_sorting_init_square_cost(self, rng: jax.Array): n, m, d = 10, 15, 1 epsilon = 1e-2 @@ -189,7 +187,7 @@ def test_sorting_init_square_cost(self, rng: jnp.ndarray): with pytest.raises(AssertionError, match=r"square"): 
sort_init.init_dual_a(ot_problem, lse_mode=True) - def test_default_initializer(self, rng: jnp.ndarray): + def test_default_initializer(self, rng: jax.Array): """Tests default initializer""" n, m, d = 20, 20, 2 epsilon = 1e-2 @@ -207,7 +205,7 @@ def test_default_initializer(self, rng: jnp.ndarray): np.testing.assert_array_equal(0., default_potential_a) np.testing.assert_array_equal(0., default_potential_b) - def test_gauss_pointcloud_geom(self, rng: jnp.ndarray): + def test_gauss_pointcloud_geom(self, rng: jax.Array): n, m, d = 20, 20, 2 epsilon = 1e-2 @@ -228,7 +226,7 @@ def test_gauss_pointcloud_geom(self, rng: jnp.ndarray): @pytest.mark.parametrize("jit", [False, True]) @pytest.mark.parametrize("initializer", ["sorting", "gaussian", "subsample"]) def test_initializer_n_iter( - self, rng: jnp.ndarray, lse_mode: bool, jit: bool, + self, rng: jax.Array, lse_mode: bool, jit: bool, initializer: Literal["sorting", "gaussian", "subsample"] ): """Tests Gaussian initializer""" diff --git a/tests/initializers/linear/sinkhorn_lr_init_test.py b/tests/initializers/linear/sinkhorn_lr_init_test.py index 0b67d2286..1d2a0e01b 100644 --- a/tests/initializers/linear/sinkhorn_lr_init_test.py +++ b/tests/initializers/linear/sinkhorn_lr_init_test.py @@ -37,7 +37,7 @@ def test_explicit_initializer(self): ) @pytest.mark.parametrize("partial_init", ["q", "r", "g"]) def test_partial_initialization( - self, rng: jnp.ndarray, initializer: str, partial_init: str + self, rng: jax.Array, initializer: str, partial_init: str ): n, d, rank = 27, 5, 6 rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) @@ -65,7 +65,7 @@ def test_partial_initialization( @pytest.mark.fast.with_args("rank", [2, 4, 10, 13], only_fast=True) def test_generalized_k_means_has_correct_rank( - self, rng: jnp.ndarray, rank: int + self, rng: jax.Array, rank: int ): n, d = 27, 5 x = jax.random.normal(rng, (n, d)) @@ -82,7 +82,7 @@ def test_generalized_k_means_has_correct_rank( assert jnp.linalg.matrix_rank(q) == rank assert 
jnp.linalg.matrix_rank(r) == rank - def test_generalized_k_means_matches_k_means(self, rng: jnp.ndarray): + def test_generalized_k_means_matches_k_means(self, rng: jax.Array): n, d, rank = 27, 7, 5 eps = 1e-1 rng1, rng2 = jax.random.split(rng, 2) @@ -112,7 +112,7 @@ def test_generalized_k_means_matches_k_means(self, rng: jnp.ndarray): ) @pytest.mark.parametrize("epsilon", [0., 1e-1]) - def test_better_initialization_helps(self, rng: jnp.ndarray, epsilon: float): + def test_better_initialization_helps(self, rng: jax.Array, epsilon: float): n, d, rank = 81, 13, 3 rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, (n, d)) diff --git a/tests/initializers/quadratic/gw_init_test.py b/tests/initializers/quadratic/gw_init_test.py index 8ab6cc4e5..ea630f4a2 100644 --- a/tests/initializers/quadratic/gw_init_test.py +++ b/tests/initializers/quadratic/gw_init_test.py @@ -14,7 +14,6 @@ import pytest import jax -import jax.numpy as jnp import numpy as np from ott.geometry import pointcloud @@ -51,7 +50,7 @@ def test_explicit_initializer_lr(self): assert solver.initializer.rank == rank @pytest.mark.parametrize("eps", [0., 1e-2]) - def test_gw_better_initialization_helps(self, rng: jnp.ndarray, eps: float): + def test_gw_better_initialization_helps(self, rng: jax.Array, eps: float): n, m, d1, d2, rank = 83, 84, 8, 6, 4 rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) diff --git a/tests/math/lse_test.py b/tests/math/lse_test.py index 3ff28eada..7a1c469be 100644 --- a/tests/math/lse_test.py +++ b/tests/math/lse_test.py @@ -23,7 +23,7 @@ @pytest.mark.fast() class TestGeometryLse: - def test_lse(self, rng: jnp.ndarray): + def test_lse(self, rng: jax.Array): """Test consistency of custom lse's jvp.""" n, m = 12, 8 rngs = jax.random.split(rng, 5) diff --git a/tests/math/math_utils_test.py b/tests/math/math_utils_test.py index 3bd4c8114..a3afb0dca 100644 --- a/tests/math/math_utils_test.py +++ b/tests/math/math_utils_test.py @@ -27,7 +27,7 @@ class TestNorm: 
@pytest.mark.parametrize("ord", [1.1, 2.0, jnp.inf]) def test_norm( self, - rng: jnp.ndarray, + rng: jax.Array, ord, ): d = 5 diff --git a/tests/math/matrix_square_root_test.py b/tests/math/matrix_square_root_test.py index 3f4aee25b..ddb25458b 100644 --- a/tests/math/matrix_square_root_test.py +++ b/tests/math/matrix_square_root_test.py @@ -22,7 +22,7 @@ from ott.math import matrix_square_root -def _get_random_spd_matrix(dim: int, rng: jnp.ndarray): +def _get_random_spd_matrix(dim: int, rng: jax.Array): # Get a random symmetric, positive definite matrix of a specified size. rng, subrng0, subrng1 = jax.random.split(rng, num=3) @@ -38,7 +38,7 @@ def _get_random_spd_matrix(dim: int, rng: jnp.ndarray): def _get_test_fn( - fn: Callable[[jnp.ndarray], jnp.ndarray], dim: int, rng: jnp.ndarray, + fn: Callable[[jnp.ndarray], jnp.ndarray], dim: int, rng: jax.Array, **kwargs: Any ) -> Callable[[jnp.ndarray], jnp.ndarray]: # We want to test gradients of a function fn that maps positive definite @@ -72,7 +72,7 @@ def _sqrt_plus_inv_sqrt(x: jnp.ndarray) -> jnp.ndarray: class TestMatrixSquareRoot: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): self.dim = 13 self.batch = 3 # Values for testing the Sylvester solver diff --git a/tests/neural/icnn_test.py b/tests/neural/icnn_test.py index d0d4e92b8..dba2f7b7c 100644 --- a/tests/neural/icnn_test.py +++ b/tests/neural/icnn_test.py @@ -23,7 +23,7 @@ @pytest.mark.fast() class TestICNN: - def test_icnn_convexity(self, rng: jnp.ndarray): + def test_icnn_convexity(self, rng: jax.Array): """Tests convexity of ICNN.""" n_samples, n_features = 10, 2 dim_hidden = (64, 64) @@ -49,7 +49,7 @@ def test_icnn_convexity(self, rng: jnp.ndarray): np.testing.assert_array_equal(jnp.asarray(out) >= 0, True) - def test_icnn_hessian(self, rng: jnp.ndarray): + def test_icnn_hessian(self, rng: jax.Array): """Tests if Hessian of ICNN is positive-semidefinite.""" # define icnn model diff --git 
a/tests/neural/losses_test.py b/tests/neural/losses_test.py index 8e4a2f96c..f18681c7a 100644 --- a/tests/neural/losses_test.py +++ b/tests/neural/losses_test.py @@ -15,7 +15,6 @@ import pytest import jax -import jax.numpy as jnp import numpy as np from ott.geometry import costs @@ -28,7 +27,7 @@ class TestMongeGap: @pytest.mark.parametrize("n_samples", [5, 25]) @pytest.mark.parametrize("n_features", [10, 50, 100]) def test_monge_gap_non_negativity( - self, rng: jnp.ndarray, n_samples: int, n_features: int + self, rng: jax.Array, n_samples: int, n_features: int ): # generate data @@ -54,7 +53,7 @@ def test_monge_gap_non_negativity( np.testing.assert_array_equal(monge_gap_value, monge_gap_from_samples_value) - def test_monge_gap_jit(self, rng: jnp.ndarray): + def test_monge_gap_jit(self, rng: jax.Array): n_samples, n_features = 31, 17 # generate data rng1, rng2 = jax.random.split(rng, 2) @@ -86,7 +85,7 @@ def test_monge_gap_jit(self, rng: jnp.ndarray): ], ) def test_monge_gap_from_samples_different_cost( - self, rng: jnp.ndarray, cost_fn: costs.CostFn, n_samples: int, + self, rng: jax.Array, cost_fn: costs.CostFn, n_samples: int, n_features: int ): """Test that the Monge gap for different costs. 
diff --git a/tests/neural/meta_initializer_test.py b/tests/neural/meta_initializer_test.py index 442fe9272..e84554940 100644 --- a/tests/neural/meta_initializer_test.py +++ b/tests/neural/meta_initializer_test.py @@ -42,7 +42,7 @@ def __call__(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: def create_ot_problem( - rng: jnp.ndarray, + rng: jax.Array, n: int, m: int, d: int, @@ -88,7 +88,7 @@ def run_sinkhorn( class TestMetaInitializer: @pytest.mark.parametrize("lse_mode", [True, False]) - def test_meta_initializer(self, rng: jnp.ndarray, lse_mode: bool): + def test_meta_initializer(self, rng: jax.Array, lse_mode: bool): """Tests Meta initializer""" n, m, d = 20, 20, 2 epsilon = 1e-2 diff --git a/tests/problems/linear/potentials_test.py b/tests/problems/linear/potentials_test.py index a13211119..aa492c628 100644 --- a/tests/problems/linear/potentials_test.py +++ b/tests/problems/linear/potentials_test.py @@ -38,7 +38,7 @@ def test_device_put(self): class TestEntropicPotentials: - def test_device_put(self, rng: jnp.ndarray): + def test_device_put(self, rng: jax.Array): n = 10 device = jax.devices()[0] rngs = jax.random.split(rng, 5) @@ -55,7 +55,7 @@ def test_device_put(self, rng: jnp.ndarray): _ = jax.device_put(pot, device) @pytest.mark.fast.with_args(eps=[5e-2, 1e-1], only_fast=0) - def test_entropic_potentials_dist(self, rng: jnp.ndarray, eps: float): + def test_entropic_potentials_dist(self, rng: jax.Array, eps: float): n1, n2, d = 64, 96, 2 rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) @@ -93,7 +93,7 @@ def test_entropic_potentials_dist(self, rng: jnp.ndarray, eps: float): @pytest.mark.fast.with_args(forward=[False, True], only_fast=0) def test_entropic_potentials_displacement( - self, rng: jnp.ndarray, forward: bool, monkeypatch + self, rng: jax.Array, forward: bool, monkeypatch ): """Tests entropic displacements, as well as their plots.""" n1, n2, d = 96, 128, 2 @@ -136,7 +136,7 @@ def test_entropic_potentials_displacement( p=[1.3, 2.2, 1.0], 
forward=[False, True], only_fast=0 ) def test_entropic_potentials_sqpnorm( - self, rng: jnp.ndarray, p: float, forward: bool + self, rng: jax.Array, p: float, forward: bool ): epsilon = None cost_fn = costs.SqPNorm(p=p) @@ -176,7 +176,7 @@ def test_entropic_potentials_sqpnorm( p=[1.45, 2.2, 1.0], forward=[False, True], only_fast=0 ) def test_entropic_potentials_pnorm( - self, rng: jnp.ndarray, p: float, forward: bool + self, rng: jax.Array, p: float, forward: bool ): epsilon = None cost_fn = costs.PNormP(p=p) @@ -218,7 +218,7 @@ def test_entropic_potentials_pnorm( assert div < .1 * div_0 @pytest.mark.parametrize("jit", [False, True]) - def test_distance_differentiability(self, rng: jnp.ndarray, jit: bool): + def test_distance_differentiability(self, rng: jax.Array, jit: bool): rng1, rng2, rng3 = jax.random.split(rng, 3) n, m, d = 18, 36, 5 @@ -240,7 +240,7 @@ def test_distance_differentiability(self, rng: jnp.ndarray, jit: bool): np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4) @pytest.mark.parametrize("eps", [None, 1e-1, 1e1, 1e2, 1e3]) - def test_potentials_sinkhorn_divergence(self, rng: jnp.ndarray, eps: float): + def test_potentials_sinkhorn_divergence(self, rng: jax.Array, eps: float): rng1, rng2, rng3 = jax.random.split(rng, 3) n, m, d = 32, 36, 4 fwd = True diff --git a/tests/solvers/linear/continuous_barycenter_test.py b/tests/solvers/linear/continuous_barycenter_test.py index 730b529d3..48b9e7e0d 100644 --- a/tests/solvers/linear/continuous_barycenter_test.py +++ b/tests/solvers/linear/continuous_barycenter_test.py @@ -51,7 +51,7 @@ class TestBarycenter: }, ) def test_euclidean_barycenter( - self, rng: jnp.ndarray, rank: int, epsilon: float, init_random: bool, + self, rng: jax.Array, rank: int, epsilon: float, init_random: bool, jit: bool ): rngs = jax.random.split(rng, 20) @@ -116,7 +116,7 @@ def test_euclidean_barycenter( assert jnp.all(out.x.ravel() > .7) @pytest.mark.parametrize("segment_before", [False, True]) - def 
test_barycenter_jit(self, rng: jnp.ndarray, segment_before: bool): + def test_barycenter_jit(self, rng: jax.Array, segment_before: bool): @functools.partial(jax.jit, static_argnums=(2, 3)) def barycenter( @@ -171,7 +171,7 @@ def barycenter( @pytest.mark.fast() def test_bures_barycenter( self, - rng: jnp.ndarray, + rng: jax.Array, ): lse_mode = True, epsilon = 1e-1 @@ -257,7 +257,7 @@ def test_bures_barycenter( @pytest.mark.fast() def test_bures_barycenter_different_number_of_components( self, - rng: jnp.ndarray, + rng: jax.Array, ): alpha = 5. epsilon = 0.01 diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index 04de4dca9..69c01f9ad 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -30,7 +30,7 @@ class TestSinkhornImplicit: """Check implicit and autodiff match for Sinkhorn.""" @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): self.dim = 3 self.n = 38 self.m = 73 @@ -138,7 +138,7 @@ class TestSinkhornJacobian: only_fast=0, ) def test_autograd_sinkhorn( - self, rng: jnp.ndarray, lse_mode: bool, shape_data: Tuple[int, int] + self, rng: jax.Array, lse_mode: bool, shape_data: Tuple[int, int] ): """Test gradient w.r.t. probability weights.""" n, m = shape_data @@ -181,7 +181,7 @@ def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float: @pytest.mark.parametrize(("lse_mode", "shape_data"), [(True, (7, 9)), (False, (11, 5))]) def test_gradient_sinkhorn_geometry( - self, rng: jnp.ndarray, lse_mode: bool, shape_data: Tuple[int, int] + self, rng: jax.Array, lse_mode: bool, shape_data: Tuple[int, int] ): """Test gradient w.r.t. 
cost matrix.""" n, m = shape_data @@ -244,7 +244,7 @@ def loss_fn(cm: jnp.ndarray): only_fast=[0, 1], ) def test_gradient_sinkhorn_euclidean( - self, rng: jnp.ndarray, lse_mode: bool, implicit: bool, min_iter: int, + self, rng: jax.Array, lse_mode: bool, implicit: bool, min_iter: int, max_iter: int, epsilon: float, cost_fn: costs.CostFn ): """Test gradient w.r.t. locations x of reg-ot-cost.""" @@ -318,7 +318,7 @@ def loss_fn(x: jnp.ndarray, ) np.testing.assert_array_equal(jnp.isnan(custom_grad), False) - def test_autoepsilon_differentiability(self, rng: jnp.ndarray): + def test_autoepsilon_differentiability(self, rng: jax.Array): cost = jax.random.uniform(rng, (15, 17)) def reg_ot_cost(c: jnp.ndarray) -> float: @@ -330,7 +330,7 @@ def reg_ot_cost(c: jnp.ndarray) -> float: np.testing.assert_array_equal(jnp.isnan(gradient), False) @pytest.mark.fast() - def test_differentiability_with_jit(self, rng: jnp.ndarray): + def test_differentiability_with_jit(self, rng: jax.Array): def reg_ot_cost(c: jnp.ndarray) -> float: geom = geometry.Geometry(c, epsilon=1e-2) @@ -348,7 +348,7 @@ def reg_ot_cost(c: jnp.ndarray) -> float: only_fast=0 ) def test_apply_transport_jacobian( - self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float, + self, rng: jax.Array, lse_mode: bool, tau_a: float, tau_b: float, arg: int, axis: int ): """Tests Jacobian of application of OT to vector, w.r.t. @@ -460,7 +460,7 @@ def apply_ot(a: jnp.ndarray, x: jnp.ndarray, implicit: bool) -> jnp.ndarray: only_fast=0, ) def test_potential_jacobian_sinkhorn( - self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float, + self, rng: jax.Array, lse_mode: bool, tau_a: float, tau_b: float, shape: Tuple[int, int], arg: int ): """Test Jacobian of optimal potential w.r.t. 
weights and locations.""" @@ -542,7 +542,7 @@ class TestSinkhornGradGrid: @pytest.mark.parametrize("lse_mode", [False, True]) def test_diff_sinkhorn_x_grid_x_perturbation( - self, rng: jnp.ndarray, lse_mode: bool + self, rng: jax.Array, lse_mode: bool ): """Test gradient w.r.t. probability weights.""" eps = 1e-3 # perturbation magnitude @@ -587,7 +587,7 @@ def reg_ot(x: List[jnp.ndarray]) -> float: @pytest.mark.parametrize("lse_mode", [False, True]) def test_diff_sinkhorn_x_grid_weights_perturbation( - self, rng: jnp.ndarray, lse_mode: bool + self, rng: jax.Array, lse_mode: bool ): """Test gradient w.r.t. probability weights.""" eps = 1e-4 # perturbation magnitude @@ -638,7 +638,7 @@ class TestSinkhornJacobianPreconditioning: only_fast=[0, -1], ) def test_potential_jacobian_sinkhorn_precond( - self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float, + self, rng: jax.Array, lse_mode: bool, tau_a: float, tau_b: float, shape: Tuple[int, int], arg: int ): """Test Jacobian of optimal potential works across 2 precond_fun.""" @@ -741,7 +741,7 @@ class TestSinkhornHessian: only_fast=-1 ) def test_hessian_sinkhorn( - self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float, + self, rng: jax.Array, lse_mode: bool, tau_a: float, tau_b: float, arg: int, lineax_ridge: float ): """Test hessian w.r.t. 
weights and locations.""" diff --git a/tests/solvers/linear/sinkhorn_grid_test.py b/tests/solvers/linear/sinkhorn_grid_test.py index e7c116c8d..d73bc124b 100644 --- a/tests/solvers/linear/sinkhorn_grid_test.py +++ b/tests/solvers/linear/sinkhorn_grid_test.py @@ -26,7 +26,7 @@ class TestSinkhornGrid: @pytest.mark.parametrize("lse_mode", [False, True]) - def test_separable_grid(self, rng: jnp.ndarray, lse_mode: bool): + def test_separable_grid(self, rng: jax.Array, lse_mode: bool): """Two histograms in a grid of size 5 x 6 x 7 in the hypercube^3.""" grid_size = (5, 6, 7) rngs = jax.random.split(rng, 2) @@ -47,7 +47,7 @@ def test_separable_grid(self, rng: jnp.ndarray, lse_mode: bool): assert threshold > err @pytest.mark.fast.with_args("lse_mode", [False, True], only_fast=0) - def test_grid_vs_euclidean(self, rng: jnp.ndarray, lse_mode: bool): + def test_grid_vs_euclidean(self, rng: jax.Array, lse_mode: bool): grid_size = (5, 6, 7) rngs = jax.random.split(rng, 2) a = jax.random.uniform(rngs[0], grid_size) @@ -70,7 +70,7 @@ def test_grid_vs_euclidean(self, rng: jnp.ndarray, lse_mode: bool): ) @pytest.mark.fast.with_args("lse_mode", [False, True], only_fast=1) - def test_apply_transport_grid(self, rng: jnp.ndarray, lse_mode: bool): + def test_apply_transport_grid(self, rng: jax.Array, lse_mode: bool): grid_size = (5, 6, 7) rngs = jax.random.split(rng, 4) a = jax.random.uniform(rngs[0], grid_size) @@ -119,7 +119,7 @@ def test_apply_transport_grid(self, rng: jnp.ndarray, lse_mode: bool): np.testing.assert_array_equal(jnp.isnan(mat_transport_t_vec_a), False) @pytest.mark.fast() - def test_apply_cost(self, rng: jnp.ndarray): + def test_apply_cost(self, rng: jax.Array): grid_size = (5, 6, 7) geom_grid = grid.Grid(grid_size=grid_size, epsilon=0.1) diff --git a/tests/solvers/linear/sinkhorn_lr_test.py b/tests/solvers/linear/sinkhorn_lr_test.py index 0ce5a2307..1bfbd4843 100644 --- a/tests/solvers/linear/sinkhorn_lr_test.py +++ b/tests/solvers/linear/sinkhorn_lr_test.py @@ -27,7 
+27,7 @@ class TestLRSinkhorn: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): self.dim = 4 self.n = 23 self.m = 27 diff --git a/tests/solvers/linear/sinkhorn_misc_test.py b/tests/solvers/linear/sinkhorn_misc_test.py index d9d6d616c..9d45c518c 100644 --- a/tests/solvers/linear/sinkhorn_misc_test.py +++ b/tests/solvers/linear/sinkhorn_misc_test.py @@ -39,7 +39,7 @@ class TestSinkhornAnderson: only_fast=0, ) def test_anderson( - self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float + self, rng: jax.Array, lse_mode: bool, tau_a: float, tau_b: float ): """Test efficiency of Anderson acceleration. @@ -131,7 +131,7 @@ def initialize(self): @pytest.mark.parametrize(("unbalanced", "thresh"), [(False, 1e-3), (True, 1e-4)]) def test_bures_point_cloud( - self, rng: jnp.ndarray, lse_mode: bool, unbalanced: bool, thresh: float + self, rng: jax.Array, lse_mode: bool, unbalanced: bool, thresh: float ): """Two point clouds of Gaussians, tested with various parameters.""" if unbalanced: @@ -172,7 +172,7 @@ def test_regularized_unbalanced_bures_cost(self): class TestSinkhornOnline: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): self.dim = 3 self.n = 100 self.m = 42 @@ -237,7 +237,7 @@ def callback(epsilon: float, batch_size: int) -> sinkhorn.SinkhornOutput: class TestSinkhornUnbalanced: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): self.dim = 4 self.n = 17 self.m = 23 @@ -318,7 +318,7 @@ class TestSinkhornJIT: """Check jitted and non jit match for Sinkhorn, and that everything jits.""" @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): self.dim = 3 self.n = 10 self.m = 11 diff --git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index 0437a4efa..2676e74af 100644 --- 
a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -31,7 +31,7 @@ class TestSinkhorn: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): self.rng = rng self.dim = 4 self.n = 17 diff --git a/tests/solvers/linear/univariate_test.py b/tests/solvers/linear/univariate_test.py index a002882fb..6e0263611 100644 --- a/tests/solvers/linear/univariate_test.py +++ b/tests/solvers/linear/univariate_test.py @@ -29,7 +29,7 @@ class TestUnivariate: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): self.rng = rng self.n = 7 self.m = 5 @@ -120,7 +120,7 @@ def test_cdf_distance_and_scipy(self): @pytest.mark.fast() def test_univariate_grad( self, - rng: jnp.ndarray, + rng: jax.Array, ): # TODO: Once a `check_grad` function is implemented, replace the code # blocks before with `check_grad`'s. diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index f998e802e..47810b16a 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -30,7 +30,7 @@ class TestFusedGromovWasserstein: # TODO(michalk8): refactor me in the future @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): d_x = 2 d_y = 3 d_xy = 4 @@ -217,7 +217,7 @@ def reg_gw( @pytest.mark.limit_memory("200 MB") @pytest.mark.parametrize("jit", [False, True]) - def test_fgw_lr_memory(self, rng: jnp.ndarray, jit: bool): + def test_fgw_lr_memory(self, rng: jax.Array, jit: bool): rngs = jax.random.split(rng, 4) n, m, d1, d2 = 5_000, 2_500, 1, 2 x = jax.random.uniform(rngs[0], (n, d1)) @@ -244,7 +244,7 @@ def test_fgw_lr_memory(self, rng: jnp.ndarray, jit: bool): @pytest.mark.parametrize("cost_rank", [4, (2, 3, 4)]) def test_fgw_lr_generic_cost_matrix( - self, rng: jnp.ndarray, cost_rank: Union[int, Tuple[int, int, int]] + self, rng: jax.Array, cost_rank: 
Union[int, Tuple[int, int, int]] ): n, m = 20, 30 rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) diff --git a/tests/solvers/quadratic/gw_barycenter_test.py b/tests/solvers/quadratic/gw_barycenter_test.py index eba4e3054..02ecc953b 100644 --- a/tests/solvers/quadratic/gw_barycenter_test.py +++ b/tests/solvers/quadratic/gw_barycenter_test.py @@ -32,7 +32,7 @@ class TestGWBarycenter: def random_pc( n: int, d: int, - rng: jnp.ndarray, + rng: jax.Array, m: Optional[int] = None, **kwargs: Any ) -> pointcloud.PointCloud: @@ -66,7 +66,7 @@ def pad_cost_matrices( [("sqeucl", 17, None)] # , ("kl", 22, 1e-2)] ) def test_gw_barycenter( - self, rng: jnp.ndarray, gw_loss: str, bar_size: int, + self, rng: jax.Array, gw_loss: str, bar_size: int, epsilon: Optional[float] ): tol = 1e-3 if gw_loss == "sqeucl" else 1e-1 @@ -127,7 +127,7 @@ def test_gw_barycenter( ) def test_fgw_barycenter( self, - rng: jnp.ndarray, + rng: jax.Array, jit: bool, fused_penalty: float, scale_cost: str, diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py index 816f7fcd6..2e5573fbe 100644 --- a/tests/solvers/quadratic/gw_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ -32,7 +32,7 @@ class TestQuadraticProblem: @pytest.mark.parametrize("as_pc", [False, True]) @pytest.mark.parametrize("rank", [-1, 5, (1, 2, 3), (2, 3, 5)]) def test_quad_to_low_rank( - self, rng: jnp.ndarray, as_pc: bool, rank: Union[int, Tuple[int, ...]] + self, rng: jax.Array, as_pc: bool, rank: Union[int, Tuple[int, ...]] ): n, m, d1, d2, d = 100, 120, 4, 6, 10 rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) @@ -88,7 +88,7 @@ def test_quad_to_low_rank( assert lr_prob._is_low_rank_convertible assert lr_prob.to_low_rank() is lr_prob - def test_gw_implicit_conversion_mixed_input(self, rng: jnp.ndarray): + def test_gw_implicit_conversion_mixed_input(self, rng: jax.Array): n, m, d1, d2 = 13, 77, 3, 4 rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, (n, d1)) @@ -108,7 +108,7 @@ def 
test_gw_implicit_conversion_mixed_input(self, rng: jnp.ndarray): class TestGromovWasserstein: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): d_x = 2 d_y = 3 self.n, self.m = 6, 7 @@ -311,7 +311,7 @@ def loss_thre(threshold: float) -> float: assert loss_thre(1e-3) >= loss_thre(1e-5) @pytest.mark.fast() - def test_gw_lr(self, rng: jnp.ndarray): + def test_gw_lr(self, rng: jax.Array): """Checking LR and Entropic have similar outputs on same problem.""" rngs = jax.random.split(rng, 4) n, m, d1, d2 = 24, 17, 2, 3 @@ -335,7 +335,7 @@ def test_gw_lr(self, rng: jnp.ndarray): ot_gwlr.primal_cost, ot_gw.primal_cost, rtol=5e-2 ) - def test_gw_lr_matches_fused(self, rng: jnp.ndarray): + def test_gw_lr_matches_fused(self, rng: jax.Array): """Checking LR and Entropic have similar outputs on same fused problem.""" rngs = jax.random.split(rng, 5) n, m, d1, d2 = 24, 17, 2, 3 @@ -386,7 +386,7 @@ def test_gw_lr_apply(self, axis: int): @pytest.mark.parametrize("scale_cost", [1.0, "mean"]) def test_relative_epsilon( self, - rng: jnp.ndarray, + rng: jax.Array, scale_cost: Union[float, str], ): eps = 1e-2 diff --git a/tests/solvers/quadratic/lower_bound_test.py b/tests/solvers/quadratic/lower_bound_test.py index 2e30a1bbe..37bf2a8b3 100644 --- a/tests/solvers/quadratic/lower_bound_test.py +++ b/tests/solvers/quadratic/lower_bound_test.py @@ -32,7 +32,7 @@ class TestLowerBoundSolver: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): d_x = 2 d_y = 3 self.n, self.m = 13, 15 @@ -95,7 +95,7 @@ def test_lb_pointcloud(self, ground_cost: costs.TICost): ] ) def test_lb_grad( - self, rng: jnp.ndarray, sort_fn: Callable[[jnp.ndarray], jnp.ndarray], + self, rng: jax.Array, sort_fn: Callable[[jnp.ndarray], jnp.ndarray], method: str ): diff --git a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py index 20fe4ef4a..06103cf7f 
100644 --- a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py +++ b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py @@ -30,7 +30,7 @@ class TestFitGmmPair: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): mean_generator0 = jnp.array([[2., -1.], [-2., 0.], [4., 3.]]) cov_generator0 = jnp.array([[[0.2, 0.], [0., 0.1]], [[0.6, 0.], [0., 0.3]], [[0.5, 0.4], [0.4, 0.5]]]) diff --git a/tests/tools/gaussian_mixture/fit_gmm_test.py b/tests/tools/gaussian_mixture/fit_gmm_test.py index 648e9a287..82bbe3ec6 100644 --- a/tests/tools/gaussian_mixture/fit_gmm_test.py +++ b/tests/tools/gaussian_mixture/fit_gmm_test.py @@ -24,7 +24,7 @@ class TestFitGmm: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): mean_generator = jnp.array([[2., -1.], [-2., 0.], [4., 3.]]) cov_generator = jnp.array([[[0.2, 0.], [0., 0.1]], [[0.6, 0.], [0., 0.3]], [[0.5, 0.4], [0.4, 0.5]]]) diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py index b11431d8c..690f07e33 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py @@ -23,7 +23,7 @@ class TestGaussianMixturePair: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): self.n_components = 3 self.n_dimensions = 2 self.epsilon = 1.e-3 diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_test.py index 540ebe980..864e11efc 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_test.py @@ -24,7 +24,7 @@ class TestGaussianMixture: def test_get_summary_stats_from_points_and_assignment_probs( - self, rng: jnp.ndarray + self, rng: jax.Array ): n = 50 rng, subrng0, subrng1 = jax.random.split(rng, 
num=3) @@ -57,7 +57,7 @@ def test_get_summary_stats_from_points_and_assignment_probs( np.testing.assert_allclose(expected_cov, cov, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(expected_wt, comp_wt, atol=1e-4, rtol=1e-4) - def test_from_random(self, rng: jnp.ndarray): + def test_from_random(self, rng: jax.Array): gmm = gaussian_mixture.GaussianMixture.from_random( rng=rng, n_components=3, n_dimensions=2 ) @@ -80,7 +80,7 @@ def test_from_mean_cov_component_weights(self,): comp_wts, gmm.component_weights, atol=1e-4, rtol=1e-4 ) - def test_covariance(self, rng: jnp.ndarray): + def test_covariance(self, rng: jax.Array): gmm = gaussian_mixture.GaussianMixture.from_random( rng=rng, n_components=3, n_dimensions=2 ) @@ -90,7 +90,7 @@ def test_covariance(self, rng: jnp.ndarray): cov[i], component.covariance(), atol=1e-4, rtol=1e-4 ) - def test_sample(self, rng: jnp.ndarray): + def test_sample(self, rng: jax.Array): gmm = gaussian_mixture.GaussianMixture.from_mean_cov_component_weights( mean=jnp.array([[-1., 0.], [1., 0.]]), cov=jnp.array([[[0.01, 0.], [0., 0.01]], [[0.01, 0.], [0., 0.01]]]), @@ -112,7 +112,7 @@ def test_sample(self, rng: jnp.ndarray): atol=1.e-1 ) - def test_log_prob(self, rng: jnp.ndarray): + def test_log_prob(self, rng: jax.Array): n_components = 3 size = 100 subrng0, subrng1 = jax.random.split(rng, num=2) @@ -136,7 +136,7 @@ def test_log_prob(self, rng: jnp.ndarray): np.testing.assert_allclose(expected, actual, atol=1e-4, rtol=1e-4) - def test_log_component_posterior(self, rng: jnp.ndarray): + def test_log_component_posterior(self, rng: jax.Array): gmm = gaussian_mixture.GaussianMixture.from_random( rng=rng, n_components=3, n_dimensions=2 ) @@ -150,7 +150,7 @@ def test_log_component_posterior(self, rng: jnp.ndarray): expected, gmm.get_log_component_posterior(x), atol=1e-4, rtol=1e-4 ) - def test_flatten_unflatten(self, rng: jnp.ndarray): + def test_flatten_unflatten(self, rng: jax.Array): gmm = gaussian_mixture.GaussianMixture.from_random( rng=rng, 
n_components=3, n_dimensions=2 ) @@ -159,7 +159,7 @@ def test_flatten_unflatten(self, rng: jnp.ndarray): assert gmm == gmm_new - def test_pytree_mapping(self, rng: jnp.ndarray): + def test_pytree_mapping(self, rng: jax.Array): gmm = gaussian_mixture.GaussianMixture.from_random( rng=rng, n_components=3, n_dimensions=2 ) diff --git a/tests/tools/gaussian_mixture/gaussian_test.py b/tests/tools/gaussian_mixture/gaussian_test.py index 23deff00d..b731c2a8f 100644 --- a/tests/tools/gaussian_mixture/gaussian_test.py +++ b/tests/tools/gaussian_mixture/gaussian_test.py @@ -23,7 +23,7 @@ @pytest.mark.fast() class TestGaussian: - def test_from_random(self, rng: jnp.ndarray): + def test_from_random(self, rng: jax.Array): g = gaussian.Gaussian.from_random(rng=rng, n_dimensions=3) np.testing.assert_array_equal(g.loc.shape, (3,)) @@ -37,7 +37,7 @@ def test_from_mean_and_cov(self): np.testing.assert_array_equal(mean, g.loc) np.testing.assert_allclose(cov, g.covariance(), atol=1e-4, rtol=1e-4) - def test_to_z(self, rng: jnp.ndarray): + def test_to_z(self, rng: jax.Array): g = gaussian.Gaussian( loc=jnp.array([1., 2.]), scale=scale_tril.ScaleTriL( @@ -53,7 +53,7 @@ def test_to_z(self, rng: jnp.ndarray): np.testing.assert_allclose(sample_mean, jnp.zeros(2), atol=0.1) np.testing.assert_allclose(sample_cov, jnp.eye(2), atol=0.1) - def test_from_z(self, rng: jnp.ndarray): + def test_from_z(self, rng: jax.Array): g = gaussian.Gaussian( loc=jnp.array([0., 0.]), scale=scale_tril.ScaleTriL( @@ -65,7 +65,7 @@ def test_from_z(self, rng: jnp.ndarray): xnew = g.from_z(z) np.testing.assert_allclose(x, xnew, atol=1e-4, rtol=1e-4) - def test_log_prob(self, rng: jnp.ndarray): + def test_log_prob(self, rng: jax.Array): g = gaussian.Gaussian( loc=jnp.array([0., 0.]), scale=scale_tril.ScaleTriL( @@ -79,7 +79,7 @@ def test_log_prob(self, rng: jnp.ndarray): ) np.testing.assert_allclose(expected, actual, atol=1e-5, rtol=1e-5) - def test_sample(self, rng: jnp.ndarray): + def test_sample(self, rng: 
jax.Array): mean = jnp.array([1., 2.]) cov = jnp.diag(jnp.array([1., 4.])) g = gaussian.Gaussian.from_mean_and_cov(mean, cov) @@ -90,7 +90,7 @@ def test_sample(self, rng: jnp.ndarray): np.testing.assert_allclose(sample_mean, mean, atol=3. * 2. / 100.) np.testing.assert_allclose(sample_cov, cov, atol=2e-1) - def test_w2_dist(self, rng: jnp.ndarray): + def test_w2_dist(self, rng: jax.Array): # make sure distance between a random normal and itself is 0 rng, subrng = jax.random.split(rng) n = gaussian.Gaussian.from_random(rng=subrng, n_dimensions=3) @@ -119,7 +119,7 @@ def test_w2_dist(self, rng: jnp.ndarray): expected = delta_mean + delta_sigma np.testing.assert_allclose(expected, w2, rtol=1e-6, atol=1e-6) - def test_transport(self, rng: jnp.ndarray): + def test_transport(self, rng: jax.Array): diag0 = jnp.array([1.]) diag1 = jnp.array([4.]) g0 = gaussian.Gaussian( @@ -135,14 +135,14 @@ def test_transport(self, rng: jnp.ndarray): expected = 2. * points + 1. np.testing.assert_allclose(expected, actual, atol=1e-5, rtol=1e-5) - def test_flatten_unflatten(self, rng: jnp.ndarray): + def test_flatten_unflatten(self, rng: jax.Array): g = gaussian.Gaussian.from_random(rng, n_dimensions=3) children, aux_data = jax.tree_util.tree_flatten(g) g_new = jax.tree_util.tree_unflatten(aux_data, children) assert g == g_new - def test_pytree_mapping(self, rng: jnp.ndarray): + def test_pytree_mapping(self, rng: jax.Array): g = gaussian.Gaussian.from_random(rng, n_dimensions=3) g_x_2 = jax.tree_map(lambda x: 2 * x, g) diff --git a/tests/tools/gaussian_mixture/linalg_test.py b/tests/tools/gaussian_mixture/linalg_test.py index 4529364dc..345f6bfa8 100644 --- a/tests/tools/gaussian_mixture/linalg_test.py +++ b/tests/tools/gaussian_mixture/linalg_test.py @@ -23,7 +23,7 @@ @pytest.mark.fast() class TestLinalg: - def test_get_mean_and_var(self, rng: jnp.ndarray): + def test_get_mean_and_var(self, rng: jax.Array): points = jax.random.normal(key=rng, shape=(10, 2)) weights = jnp.ones(10) 
expected_mean = jnp.mean(points, axis=0) @@ -34,7 +34,7 @@ def test_get_mean_and_var(self, rng: jnp.ndarray): np.testing.assert_allclose(expected_mean, actual_mean, atol=1E-5, rtol=1E-5) np.testing.assert_allclose(expected_var, actual_var, atol=1E-5, rtol=1E-5) - def test_get_mean_and_var_nonuniform_weights(self, rng: jnp.ndarray): + def test_get_mean_and_var_nonuniform_weights(self, rng: jax.Array): points = jax.random.normal(key=rng, shape=(10, 2)) weights = jnp.concatenate([jnp.ones(5), jnp.zeros(5)], axis=-1) expected_mean = jnp.mean(points[:5], axis=0) @@ -45,7 +45,7 @@ def test_get_mean_and_var_nonuniform_weights(self, rng: jnp.ndarray): np.testing.assert_allclose(expected_mean, actual_mean, rtol=1e-6, atol=1e-6) np.testing.assert_allclose(expected_var, actual_var, rtol=1e-6, atol=1e-6) - def test_get_mean_and_cov(self, rng: jnp.ndarray): + def test_get_mean_and_cov(self, rng: jax.Array): points = jax.random.normal(key=rng, shape=(10, 2)) weights = jnp.ones(10) expected_mean = jnp.mean(points, axis=0) @@ -56,7 +56,7 @@ def test_get_mean_and_cov(self, rng: jnp.ndarray): np.testing.assert_allclose(expected_mean, actual_mean, atol=1e-5, rtol=1e-5) np.testing.assert_allclose(expected_cov, actual_cov, atol=1e-5, rtol=1e-5) - def test_get_mean_and_cov_nonuniform_weights(self, rng: jnp.ndarray): + def test_get_mean_and_cov_nonuniform_weights(self, rng: jax.Array): points = jax.random.normal(key=rng, shape=(10, 2)) weights = jnp.concatenate([jnp.ones(5), jnp.zeros(5)], axis=-1) expected_mean = jnp.mean(points[:5], axis=0) @@ -67,7 +67,7 @@ def test_get_mean_and_cov_nonuniform_weights(self, rng: jnp.ndarray): np.testing.assert_allclose(expected_mean, actual_mean, rtol=1e-6, atol=1e-6) np.testing.assert_allclose(expected_cov, actual_cov, rtol=1e-6, atol=1e-6) - def test_flat_to_tril(self, rng: jnp.ndarray): + def test_flat_to_tril(self, rng: jax.Array): size = 3 x = jax.random.normal(key=rng, shape=(5, 4, size * (size + 1) // 2)) m = linalg.flat_to_tril(x, size) @@ 
-87,7 +87,7 @@ def test_flat_to_tril(self, rng: jnp.ndarray): actual = linalg.tril_to_flat(m) np.testing.assert_allclose(x, actual) - def test_tril_to_flat(self, rng: jnp.ndarray): + def test_tril_to_flat(self, rng: jax.Array): size = 3 m = jax.random.normal(key=rng, shape=(5, 4, size, size)) for i in range(size): @@ -104,7 +104,7 @@ def test_tril_to_flat(self, rng: jnp.ndarray): inverted = linalg.flat_to_tril(flat, size) np.testing.assert_allclose(m, inverted) - def test_apply_to_diag(self, rng: jnp.ndarray): + def test_apply_to_diag(self, rng: jax.Array): size = 3 m = jax.random.normal(key=rng, shape=(5, 4, size, size)) mnew = linalg.apply_to_diag(m, jnp.exp) @@ -115,7 +115,7 @@ def test_apply_to_diag(self, rng: jnp.ndarray): else: np.testing.assert_allclose(jnp.exp(m[..., i, j]), mnew[..., i, j]) - def test_matrix_powers(self, rng: jnp.ndarray): + def test_matrix_powers(self, rng: jax.Array): rng, subrng = jax.random.split(rng) m = jax.random.normal(key=subrng, shape=(4, 4)) m += jnp.swapaxes(m, axis1=-2, axis2=-1) # symmetric @@ -126,7 +126,7 @@ def test_matrix_powers(self, rng: jnp.ndarray): np.testing.assert_allclose(m, actual[0], rtol=1.e-5) np.testing.assert_allclose(inv_m, actual[1], rtol=1.e-4) - def test_invmatvectril(self, rng: jnp.ndarray): + def test_invmatvectril(self, rng: jax.Array): rng, subrng = jax.random.split(rng) m = jax.random.normal(key=subrng, shape=(2, 2)) m += jnp.swapaxes(m, axis1=-2, axis2=-1) # symmetric @@ -139,7 +139,7 @@ def test_invmatvectril(self, rng: jnp.ndarray): actual = linalg.invmatvectril(m=cholesky, x=x, lower=True) np.testing.assert_allclose(expected, actual, atol=1e-4, rtol=1.e-4) - def test_get_random_orthogonal(self, rng: jnp.ndarray): + def test_get_random_orthogonal(self, rng: jax.Array): rng, subrng = jax.random.split(rng) q = linalg.get_random_orthogonal(rng=subrng, dim=3) qt = jnp.transpose(q) diff --git a/tests/tools/gaussian_mixture/probabilities_test.py b/tests/tools/gaussian_mixture/probabilities_test.py 
index 4fce8186f..fa0753c9f 100644 --- a/tests/tools/gaussian_mixture/probabilities_test.py +++ b/tests/tools/gaussian_mixture/probabilities_test.py @@ -40,7 +40,7 @@ def test_log_probs(self): np.testing.assert_allclose(jnp.sum(probs), 1.0, rtol=1e-6, atol=1e-6) np.testing.assert_array_equal(probs > 0., True) - def test_from_random(self, rng: jnp.ndarray): + def test_from_random(self, rng: jax.Array): n_dimensions = 4 pp = probabilities.Probabilities.from_random( rng=rng, n_dimensions=n_dimensions, stdev=0.1 @@ -52,7 +52,7 @@ def test_from_probs(self): pp = probabilities.Probabilities.from_probs(probs) np.testing.assert_allclose(probs, pp.probs(), rtol=1e-6, atol=1e-6) - def test_sample(self, rng: jnp.ndarray): + def test_sample(self, rng: jax.Array): p = 0.4 probs = jnp.array([p, 1. - p]) pp = probabilities.Probabilities.from_probs(probs) diff --git a/tests/tools/gaussian_mixture/scale_tril_test.py b/tests/tools/gaussian_mixture/scale_tril_test.py index f7bbe9293..e8244590b 100644 --- a/tests/tools/gaussian_mixture/scale_tril_test.py +++ b/tests/tools/gaussian_mixture/scale_tril_test.py @@ -48,7 +48,7 @@ def test_log_det_covariance(self, chol: scale_tril.ScaleTriL): actual = chol.log_det_covariance() np.testing.assert_almost_equal(actual, expected) - def test_from_random(self, rng: jnp.ndarray): + def test_from_random(self, rng: jax.Array): n_dimensions = 4 cov = scale_tril.ScaleTriL.from_random( rng=rng, n_dimensions=n_dimensions, stdev=0.1 @@ -57,7 +57,7 @@ def test_from_random(self, rng: jnp.ndarray): cov.cholesky().shape, (n_dimensions, n_dimensions) ) - def test_from_cholesky(self, rng: jnp.ndarray): + def test_from_cholesky(self, rng: jax.Array): n_dimensions = 4 cholesky = scale_tril.ScaleTriL.from_random( rng=rng, n_dimensions=n_dimensions, stdev=1. 
@@ -65,7 +65,7 @@ def test_from_cholesky(self, rng: jnp.ndarray): scale = scale_tril.ScaleTriL.from_cholesky(cholesky) np.testing.assert_allclose(cholesky, scale.cholesky(), atol=1e-4, rtol=1e-4) - def test_w2_dist(self, rng: jnp.ndarray): + def test_w2_dist(self, rng: jax.Array): # make sure distance between a random normal and itself is 0 rng, subrng = jax.random.split(rng) s = scale_tril.ScaleTriL.from_random(rng=subrng, n_dimensions=3) @@ -86,7 +86,7 @@ def test_w2_dist(self, rng: jnp.ndarray): delta_sigma = jnp.sum((jnp.sqrt(diag0) - jnp.sqrt(diag1)) ** 2.) np.testing.assert_allclose(delta_sigma, w2, atol=1e-4, rtol=1e-4) - def test_transport(self, rng: jnp.ndarray): + def test_transport(self, rng: jax.Array): size = 4 rng, subrng0, subrng1 = jax.random.split(rng, num=3) diag0 = jnp.exp(jax.random.normal(key=subrng0, shape=(size,))) @@ -100,14 +100,14 @@ def test_transport(self, rng: jnp.ndarray): expected = x * jnp.sqrt(diag1)[None] / jnp.sqrt(diag0)[None] np.testing.assert_allclose(expected, transported, atol=1e-4, rtol=1e-4) - def test_flatten_unflatten(self, rng: jnp.ndarray): + def test_flatten_unflatten(self, rng: jax.Array): scale = scale_tril.ScaleTriL.from_random(rng=rng, n_dimensions=3) children, aux_data = jax.tree_util.tree_flatten(scale) scale_new = jax.tree_util.tree_unflatten(aux_data, children) np.testing.assert_array_equal(scale.params, scale_new.params) assert scale == scale_new - def test_pytree_mapping(self, rng: jnp.ndarray): + def test_pytree_mapping(self, rng: jax.Array): scale = scale_tril.ScaleTriL.from_random(rng=rng, n_dimensions=3) scale_x_2 = jax.tree_map(lambda x: 2 * x, scale) np.testing.assert_allclose(2. 
* scale.params, scale_x_2.params) diff --git a/tests/tools/k_means_test.py b/tests/tools/k_means_test.py index c00288cec..6fc0fd403 100644 --- a/tests/tools/k_means_test.py +++ b/tests/tools/k_means_test.py @@ -64,7 +64,7 @@ def compute_assignment( class TestKmeansPlusPlus: @pytest.mark.fast.with_args("n_local_trials", [None, 3], only_fast=-1) - def test_n_local_trials(self, rng: jnp.ndarray, n_local_trials): + def test_n_local_trials(self, rng: jax.Array, n_local_trials): n, k = 100, 4 rng1, rng2 = jax.random.split(rng) geom, _, c = make_blobs( @@ -79,7 +79,7 @@ def test_n_local_trials(self, rng: jnp.ndarray, n_local_trials): assert shift1 > shift2 @pytest.mark.fast.with_args("k", [3, 5], only_fast=0) - def test_matches_sklearn(self, rng: jnp.ndarray, k: int): + def test_matches_sklearn(self, rng: jax.Array, k: int): ndim = 2 geom, _, _ = make_blobs( n_samples=100, @@ -103,7 +103,7 @@ def test_matches_sklearn(self, rng: jnp.ndarray, k: int): ) assert jnp.abs(pred_inertia - gt_inertia) <= 200 - def test_initialization_differentiable(self, rng: jnp.ndarray): + def test_initialization_differentiable(self, rng: jax.Array): def callback(x: jnp.ndarray) -> float: geom = pointcloud.PointCloud(x) @@ -123,7 +123,7 @@ class TestKmeans: @pytest.mark.fast() @pytest.mark.parametrize("k", [1, 6]) - def test_k_means_output(self, rng: jnp.ndarray, k: int): + def test_k_means_output(self, rng: jax.Array, k: int): max_iter, ndim = 10, 4 geom, gt_assignment, _ = make_blobs( n_samples=50, n_features=ndim, centers=k, random_state=42 @@ -161,7 +161,7 @@ def test_k_means_simple_example(self): ["k-means++", "random", "callable", "wrong-callable"], only_fast=1, ) - def test_init_method(self, rng: jnp.ndarray, init: str): + def test_init_method(self, rng: jax.Array, init: str): if init == "callable": init_fn = lambda geom, k, _: geom.x[:k] elif init == "wrong-callable": @@ -177,7 +177,7 @@ def test_init_method(self, rng: jnp.ndarray, init: str): else: _ = k_means.k_means(geom, k, 
init=init_fn) - def test_k_means_plus_plus_better_than_random(self, rng: jnp.ndarray): + def test_k_means_plus_plus_better_than_random(self, rng: jax.Array): k = 5 rng1, rng2 = jax.random.split(rng, 2) geom, _, _ = make_blobs(n_samples=50, centers=k, random_state=10) @@ -190,7 +190,7 @@ def test_k_means_plus_plus_better_than_random(self, rng: jnp.ndarray): assert res_kpp.iteration < res_random.iteration assert res_kpp.error <= res_random.error - def test_larger_n_init_helps(self, rng: jnp.ndarray): + def test_larger_n_init_helps(self, rng: jax.Array): k = 10 geom, _, _ = make_blobs(n_samples=150, centers=k, random_state=0) @@ -200,7 +200,7 @@ def test_larger_n_init_helps(self, rng: jnp.ndarray): assert res_larger_n_init.error < res.error @pytest.mark.parametrize("max_iter", [8, 16]) - def test_store_inner_errors(self, rng: jnp.ndarray, max_iter: int): + def test_store_inner_errors(self, rng: jax.Array, max_iter: int): ndim, k = 10, 4 geom, _, _ = make_blobs( n_samples=40, n_features=ndim, centers=k, random_state=43 @@ -216,7 +216,7 @@ def test_store_inner_errors(self, rng: jnp.ndarray, max_iter: int): # check if error is decreasing np.testing.assert_array_equal(jnp.diff(errors[::-1]) >= 0., True) - def test_strict_tolerance(self, rng: jnp.ndarray): + def test_strict_tolerance(self, rng: jax.Array): k = 11 geom, _, _ = make_blobs(n_samples=200, centers=k, random_state=39) @@ -230,7 +230,7 @@ def test_strict_tolerance(self, rng: jnp.ndarray): @pytest.mark.parametrize( "tol", [1e-3, 0.], ids=["weak-convergence", "strict-convergence"] ) - def test_convergence_force_scan(self, rng: jnp.ndarray, tol: float): + def test_convergence_force_scan(self, rng: jax.Array, tol: float): k, n_iter = 9, 20 geom, _, _ = make_blobs(n_samples=100, centers=k, random_state=37) @@ -248,7 +248,7 @@ def test_convergence_force_scan(self, rng: jnp.ndarray, tol: float): assert res.iteration == n_iter np.testing.assert_array_equal(res.inner_errors == -1, False) - def 
test_k_means_min_iterations(self, rng: jnp.ndarray): + def test_k_means_min_iterations(self, rng: jax.Array): k, min_iter = 8, 12 geom, _, _ = make_blobs(n_samples=160, centers=k, random_state=38) @@ -265,7 +265,7 @@ def test_k_means_min_iterations(self, rng: jnp.ndarray): assert res.converged assert jnp.sum(res.inner_errors != -1) >= min_iter - def test_weight_scaling_effects_only_inertia(self, rng: jnp.ndarray): + def test_weight_scaling_effects_only_inertia(self, rng: jax.Array): k = 10 rng1, rng2 = jax.random.split(rng) geom, _, _ = make_blobs(n_samples=130, centers=k, random_state=3) @@ -286,7 +286,7 @@ def test_weight_scaling_effects_only_inertia(self, rng: jnp.ndarray): ) @pytest.mark.fast() - def test_empty_weights(self, rng: jnp.ndarray): + def test_empty_weights(self, rng: jax.Array): n, ndim, k, d = 20, 2, 3, 5. gen = np.random.RandomState(0) x = gen.normal(size=(n, ndim)) @@ -334,7 +334,7 @@ def test_cosine_cost_fn(self): @pytest.mark.fast.with_args("init", ["k-means++", "random"], only_fast=0) def test_k_means_jitting( - self, rng: jnp.ndarray, init: Literal["k-means++", "random"] + self, rng: jax.Array, init: Literal["k-means++", "random"] ): def callback(x: jnp.ndarray) -> k_means.KMeansOutput: @@ -366,7 +366,7 @@ def callback(x: jnp.ndarray) -> k_means.KMeansOutput: (False, True)], ids=["jit-while-loop", "nojit-for-loop"]) def test_k_means_differentiability( - self, rng: jnp.ndarray, jit: bool, force_scan: bool + self, rng: jax.Array, jit: bool, force_scan: bool ): def inertia(x: jnp.ndarray, w: jnp.ndarray) -> float: @@ -405,7 +405,7 @@ def inertia(x: jnp.ndarray, w: jnp.ndarray) -> float: @pytest.mark.parametrize("tol", [1e-3, 0.]) @pytest.mark.parametrize(("n", "k"), [(37, 4), (128, 6)]) def test_clustering_matches_sklearn( - self, rng: jnp.ndarray, n: int, k: int, tol: float + self, rng: jax.Array, n: int, k: int, tol: float ): x, _, _ = make_blobs(n_samples=n, centers=k, random_state=41) diff --git a/tests/tools/segment_sinkhorn_test.py 
b/tests/tools/segment_sinkhorn_test.py index f98c164bf..53fb4ae85 100644 --- a/tests/tools/segment_sinkhorn_test.py +++ b/tests/tools/segment_sinkhorn_test.py @@ -27,7 +27,7 @@ class TestSegmentSinkhorn: @pytest.fixture(autouse=True) - def setUp(self, rng: jnp.ndarray): + def setUp(self, rng: jax.Array): self._dim = 4 self._num_points = 13, 17 self._max_measure_size = 20 diff --git a/tests/tools/sinkhorn_divergence_test.py b/tests/tools/sinkhorn_divergence_test.py index e3eab9912..040a04e00 100644 --- a/tests/tools/sinkhorn_divergence_test.py +++ b/tests/tools/sinkhorn_divergence_test.py @@ -29,7 +29,7 @@ class TestSinkhornDivergence: @pytest.fixture(autouse=True) - def setUp(self, rng: jnp.ndarray): + def setUp(self, rng: jax.Array): self._dim = 4 self._num_points = 13, 17 self.rng, *rngs = jax.random.split(rng, 3) @@ -390,7 +390,7 @@ def test_euclidean_momentum_params( class TestSinkhornDivergenceGrad: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.Array): self._dim = 3 self._num_points = 13, 12 self.rng, *rngs = jax.random.split(rng, 3) diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py index c84680e9e..b4fa68ddf 100644 --- a/tests/tools/soft_sort_test.py +++ b/tests/tools/soft_sort_test.py @@ -28,14 +28,14 @@ class TestSoftSort: @pytest.mark.parametrize("shape", [(20,), (20, 1)]) - def test_sort_one_array(self, rng: jnp.ndarray, shape: Tuple[int, ...]): + def test_sort_one_array(self, rng: jax.Array, shape: Tuple[int, ...]): x = jax.random.uniform(rng, shape) xs = soft_sort.sort(x, axis=0) np.testing.assert_array_equal(x.shape, xs.shape) np.testing.assert_array_equal(jnp.diff(xs, axis=0) >= 0.0, True) - def test_sort_array_squashing_momentum(self, rng: jnp.ndarray): + def test_sort_array_squashing_momentum(self, rng: jax.Array): shape = (33, 1) x = jax.random.uniform(rng, shape) xs_lin = soft_sort.sort( @@ -62,7 +62,7 @@ def test_sort_array_squashing_momentum(self, rng: 
jnp.ndarray): @pytest.mark.fast() @pytest.mark.parametrize("k", [-1, 4, 100]) - def test_topk_one_array(self, rng: jnp.ndarray, k: int): + def test_topk_one_array(self, rng: jax.Array, k: int): n = 20 x = jax.random.uniform(rng, (n,)) axis = 0 @@ -76,7 +76,7 @@ def test_topk_one_array(self, rng: jnp.ndarray, k: int): np.testing.assert_allclose(xs, jnp.sort(x, axis=axis)[-outsize:], atol=0.01) @pytest.mark.fast.with_args("topk", [-1, 2, 11], only_fast=-1) - def test_sort_batch(self, rng: jnp.ndarray, topk: int): + def test_sort_batch(self, rng: jax.Array, topk: int): x = jax.random.uniform(rng, (32, 10, 6, 4)) axis = 1 xs = soft_sort.sort(x, axis=axis, topk=topk) @@ -86,7 +86,7 @@ def test_sort_batch(self, rng: jnp.ndarray, topk: int): np.testing.assert_array_equal(xs.shape, expected_shape) np.testing.assert_array_equal(jnp.diff(xs, axis=axis) >= 0.0, True) - def test_multivariate_cdf_quantiles(self, rng: jnp.ndarray): + def test_multivariate_cdf_quantiles(self, rng: jax.Array): n, d = 512, 3 key1, key2, key3 = jax.random.split(rng, 3) @@ -129,7 +129,7 @@ def mv_c_q(inputs, num_target_samples, rng, epsilon): np.testing.assert_allclose(z, qua(q), atol=atol) @pytest.mark.fast.with_args("axis,jit", [(0, False), (1, True)], only_fast=0) - def test_ranks(self, axis, rng: jnp.ndarray, jit: bool): + def test_ranks(self, axis, rng: jax.Array, jit: bool): rng1, rng2 = jax.random.split(rng, 2) num_targets = 13 x = jax.random.uniform(rng1, (8, 5, 2)) @@ -164,7 +164,7 @@ def test_ranks(self, axis, rng: jnp.ndarray, jit: bool): np.testing.assert_allclose(ranks, expected_ranks, atol=0.3, rtol=0.1) @pytest.mark.fast.with_args("axis,jit", [(0, False), (1, True)], only_fast=0) - def test_topk_mask(self, axis, rng: jnp.ndarray, jit: bool): + def test_topk_mask(self, axis, rng: jax.Array, jit: bool): def boolean_topk_mask(u, k): return u >= jnp.flip(jax.numpy.sort(u))[k - 1] @@ -195,7 +195,7 @@ def test_quantile(self, q: float): np.testing.assert_allclose(x_q, q, atol=1e-3, rtol=1e-2) 
- def test_quantile_on_several_axes(self, rng: jnp.ndarray): + def test_quantile_on_several_axes(self, rng: jax.Array): batch, height, width, channels = 4, 47, 45, 3 x = jax.random.uniform(rng, shape=(batch, height, width, channels)) q = soft_sort.quantile( @@ -209,7 +209,7 @@ def test_quantile_on_several_axes(self, rng: jnp.ndarray): @pytest.mark.fast() @pytest.mark.parametrize("jit", [False, True]) - def test_quantiles(self, rng: jnp.ndarray, jit: bool): + def test_quantiles(self, rng: jax.Array, jit: bool): inputs = jax.random.uniform(rng, (100, 2, 3)) q = jnp.array([.1, .8, .4]) quantile_fn = soft_sort.quantile @@ -221,7 +221,7 @@ def test_quantiles(self, rng: jnp.ndarray, jit: bool): np.testing.assert_allclose(m1.mean(axis=[1, 2]), q, atol=5e-2) @pytest.mark.parametrize("jit", [False, True]) - def test_soft_quantile_normalization(self, rng: jnp.ndarray, jit: bool): + def test_soft_quantile_normalization(self, rng: jax.Array, jit: bool): rngs = jax.random.split(rng, 2) x = jax.random.uniform(rngs[0], shape=(100,)) mu, sigma = 2.0, 1.2 @@ -238,7 +238,7 @@ def test_soft_quantile_normalization(self, rng: jnp.ndarray, jit: bool): [mu_target, sigma_target], rtol=0.05) - def test_sort_with(self, rng: jnp.ndarray): + def test_sort_with(self, rng: jax.Array): n, d = 20, 4 inputs = jax.random.uniform(rng, shape=(n, d)) criterion = jnp.linspace(0.1, 1.2, n) @@ -270,7 +270,7 @@ def test_quantize(self, jit: bool): np.testing.assert_allclose(min_distances, min_distances, atol=0.05) @pytest.mark.parametrize("implicit", [False, True]) - def test_soft_sort_jacobian(self, rng: jnp.ndarray, implicit: bool): + def test_soft_sort_jacobian(self, rng: jax.Array, implicit: bool): # Add a ridge when using JAX solvers. 
try: from ott.solvers.linear import lineax_implicit # noqa: F401 From 2bc683a2bf39f80708603848ec4be3bd20ff2290 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 29 Nov 2023 16:00:43 +0100 Subject: [PATCH 050/186] replace rng jnp.ndarray type by jax.array --- docs/tutorials/Monge_Gap.ipynb | 6 +++--- src/ott/datasets.py | 2 +- src/ott/geometry/geometry.py | 2 +- src/ott/geometry/low_rank.py | 2 +- src/ott/initializers/linear/initializers.py | 16 ++++++++-------- src/ott/initializers/linear/initializers_lr.py | 2 +- src/ott/neural/models/models.py | 4 ++-- src/ott/neural/solvers/genot.py | 4 ++-- src/ott/neural/solvers/map_estimator.py | 2 +- src/ott/neural/solvers/neuraldual.py | 2 +- src/ott/neural/solvers/otfm.py | 2 +- src/ott/problems/quadratic/quadratic_problem.py | 2 +- src/ott/solvers/linear/continuous_barycenter.py | 4 ++-- src/ott/solvers/linear/sinkhorn.py | 2 +- src/ott/solvers/linear/sinkhorn_lr.py | 2 +- src/ott/solvers/quadratic/gromov_wasserstein.py | 2 +- .../solvers/quadratic/gromov_wasserstein_lr.py | 2 +- src/ott/solvers/quadratic/gw_barycenter.py | 2 +- src/ott/tools/k_means.py | 2 +- src/ott/tools/soft_sort.py | 2 +- src/ott/utils.py | 2 +- tests/neural/neuraldual_test.py | 2 +- 22 files changed, 34 insertions(+), 34 deletions(-) diff --git a/docs/tutorials/Monge_Gap.ipynb b/docs/tutorials/Monge_Gap.ipynb index 2fde4f923..78b4ce602 100644 --- a/docs/tutorials/Monge_Gap.ipynb +++ b/docs/tutorials/Monge_Gap.ipynb @@ -99,7 +99,7 @@ " noise: float = 0.01\n", " scale: float = 1.0\n", " batch_size: int = 1024\n", - " rng: Optional[jnp.ndarray] = (None,)\n", + " rng: Optional[jax.Array] = (None,)\n", "\n", " def __iter__(self) -> Iterator[jnp.ndarray]:\n", " \"\"\"Random sample generator from Gaussian mixture.\n", @@ -152,7 +152,7 @@ " target_kwargs: Mapping[str, Any] = MappingProxyType({}),\n", " train_batch_size: int = 256,\n", " valid_batch_size: int = 256,\n", - " rng: Optional[jnp.ndarray] = None,\n", + " rng: Optional[jax.Array] = None,\n", 
") -> Tuple[dataset.Dataset, dataset.Dataset, int]:\n", " \"\"\"Samplers from ``SklearnDistribution``.\"\"\"\n", " rng = jax.random.PRNGKey(0) if rng is None else rng\n", @@ -203,7 +203,7 @@ " num_points: Optional[int] = None,\n", " title: Optional[str] = None,\n", " figsize: Tuple[int, int] = (8, 6),\n", - " rng: Optional[jnp.ndarray] = None,\n", + " rng: Optional[jax.Array] = None,\n", "):\n", " \"\"\"Plot samples from the source and target measures.\n", "\n", diff --git a/src/ott/datasets.py b/src/ott/datasets.py index 3507c3418..e5077c87c 100644 --- a/src/ott/datasets.py +++ b/src/ott/datasets.py @@ -110,7 +110,7 @@ def create_gaussian_mixture_samplers( name_target: Name_t, train_batch_size: int = 2048, valid_batch_size: int = 2048, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> Tuple[Dataset, Dataset, int]: """Gaussian samplers. diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index 6894176a6..766c5e618 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -625,7 +625,7 @@ def to_LRCGeometry( self, rank: int = 0, tol: float = 1e-2, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, scale: float = 1. ) -> "low_rank.LRCGeometry": r"""Factorize the cost matrix using either SVD (full) or :cite:`indyk:19`. 
diff --git a/src/ott/geometry/low_rank.py b/src/ott/geometry/low_rank.py index 1bfaeae0a..966db28d4 100644 --- a/src/ott/geometry/low_rank.py +++ b/src/ott/geometry/low_rank.py @@ -229,7 +229,7 @@ def to_LRCGeometry( self, rank: int = 0, tol: float = 1e-2, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, scale: float = 1.0, ) -> "LRCGeometry": """Return self.""" diff --git a/src/ott/initializers/linear/initializers.py b/src/ott/initializers/linear/initializers.py index bc4871841..f3ba93321 100644 --- a/src/ott/initializers/linear/initializers.py +++ b/src/ott/initializers/linear/initializers.py @@ -36,7 +36,7 @@ def init_dual_a( self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> jnp.ndarray: """Initialize Sinkhorn potential/scaling f_u. @@ -54,7 +54,7 @@ def init_dual_b( self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> jnp.ndarray: """Initialize Sinkhorn potential/scaling g_v. @@ -73,7 +73,7 @@ def __call__( a: Optional[jnp.ndarray], b: Optional[jnp.ndarray], lse_mode: bool, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Initialize Sinkhorn potentials/scalings f_u and g_v. 
@@ -128,7 +128,7 @@ def init_dual_a( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> jnp.ndarray: del rng return jnp.zeros_like(ot_prob.a) if lse_mode else jnp.ones_like(ot_prob.a) @@ -137,7 +137,7 @@ def init_dual_b( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> jnp.ndarray: del rng return jnp.zeros_like(ot_prob.b) if lse_mode else jnp.ones_like(ot_prob.b) @@ -158,7 +158,7 @@ def init_dual_a( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> jnp.ndarray: # import Gaussian here due to circular imports from ott.tools.gaussian_mixture import gaussian @@ -245,7 +245,7 @@ def init_dual_a( self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, init_f: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: """Apply DualSort algorithm. @@ -324,7 +324,7 @@ def init_dual_a( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> jnp.ndarray: from ott.solvers import linear diff --git a/src/ott/initializers/linear/initializers_lr.py b/src/ott/initializers/linear/initializers_lr.py index a3f615846..5c2302156 100644 --- a/src/ott/initializers/linear/initializers_lr.py +++ b/src/ott/initializers/linear/initializers_lr.py @@ -169,7 +169,7 @@ def __call__( r: Optional[jnp.ndarray] = None, g: Optional[jnp.ndarray] = None, *, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, **kwargs: Any ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Initialize the factors :math:`Q`, :math:`R` and :math:`g`. 
diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index c65cbbaf3..8e2562e97 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -268,7 +268,7 @@ def __init__( meta_model: nn.Module, opt: Optional[optax.GradientTransformation ] = optax.adam(learning_rate=1e-3), # noqa: B008 - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, state: Optional[train_state.TrainState] = None ): self.geom = geom @@ -334,7 +334,7 @@ def init_dual_a( # noqa: D102 self, ot_prob: "linear_problem.LinearProblem", lse_mode: bool, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> jnp.ndarray: del rng # Detect if the problem is batched. diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 0613ae53c..61b368a67 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -131,7 +131,7 @@ def __init__( unbalanced_kwargs: Dict[str, Any] = types.MappingProxyType({}), callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any]] = None, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> None: rng = utils.default_prng_key(rng) rng, rng_unbalanced = jax.random.split(rng) @@ -379,7 +379,7 @@ def transport( self, source: jnp.ndarray, condition: Optional[jnp.ndarray], - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, forward: bool = True, diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}), ) -> Union[jnp.array, diffrax.Solution, Optional[jnp.ndarray]]: diff --git a/src/ott/neural/solvers/map_estimator.py b/src/ott/neural/solvers/map_estimator.py index 7eaffdfc8..fb65917c7 100644 --- a/src/ott/neural/solvers/map_estimator.py +++ b/src/ott/neural/solvers/map_estimator.py @@ -88,7 +88,7 @@ def __init__( num_train_iters: int = 10_000, logging: bool = False, valid_freq: int = 500, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ): 
self._fitting_loss = fitting_loss self._regularizer = regularizer diff --git a/src/ott/neural/solvers/neuraldual.py b/src/ott/neural/solvers/neuraldual.py index 0d9e215bb..019fb836f 100644 --- a/src/ott/neural/solvers/neuraldual.py +++ b/src/ott/neural/solvers/neuraldual.py @@ -243,7 +243,7 @@ def __init__( valid_freq: int = 1000, log_freq: int = 1000, logging: bool = False, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, pos_weights: bool = True, beta: float = 1.0, conjugate_solver: Optional[conjugate.FenchelConjugateSolver diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index fb054e30a..d145c4128 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -116,7 +116,7 @@ def __init__( logging_freq: int = 100, valid_freq: int = 5000, num_eval_samples: int = 1000, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> None: rng = utils.default_prng_key(rng) rng, rng_unbalanced = jax.random.split(rng) diff --git a/src/ott/problems/quadratic/quadratic_problem.py b/src/ott/problems/quadratic/quadratic_problem.py index a17aaf9fb..5deb4558c 100644 --- a/src/ott/problems/quadratic/quadratic_problem.py +++ b/src/ott/problems/quadratic/quadratic_problem.py @@ -382,7 +382,7 @@ def convertible(geom: geometry.Geometry) -> bool: def to_low_rank( self, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> "QuadraticProblem": """Convert geometries to low-rank. 
diff --git a/src/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py index e1477e60f..2d89a74ea 100644 --- a/src/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -130,7 +130,7 @@ def __call__( # noqa: D102 bar_prob: barycenter_problem.FreeBarycenterProblem, bar_size: int = 100, x_init: Optional[jnp.ndarray] = None, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> FreeBarycenterState: # TODO(michalk8): no reason for iterations to be outside this class rng = utils.default_prng_key(rng) @@ -141,7 +141,7 @@ def init_state( bar_prob: barycenter_problem.FreeBarycenterProblem, bar_size: int, x_init: Optional[jnp.ndarray] = None, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> FreeBarycenterState: """Initialize the state of the Wasserstein barycenter iterations. diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index 56c718c1f..d9ab53f3a 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -843,7 +843,7 @@ def __call__( self, ot_prob: linear_problem.LinearProblem, init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]] = (None, None), - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> SinkhornOutput: """Run Sinkhorn algorithm. diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index ba83aeb99..db948cf8b 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -343,7 +343,7 @@ def __call__( ot_prob: linear_problem.LinearProblem, init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray]] = (None, None, None), - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, **kwargs: Any, ) -> LRSinkhornOutput: """Run low-rank Sinkhorn. 
diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 6180db73f..5e23d88e6 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -213,7 +213,7 @@ def __call__( self, prob: quadratic_problem.QuadraticProblem, init: Optional[linear_problem.LinearProblem] = None, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, **kwargs: Any, ) -> GWOutput: """Run the Gromov-Wasserstein solver. diff --git a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py index 710d8f617..214853f4c 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py @@ -336,7 +336,7 @@ def __call__( ot_prob: quadratic_problem.QuadraticProblem, init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray]] = (None, None, None), - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, **kwargs: Any, ) -> LRGWOutput: """Run low-rank Gromov-Wasserstein solver. diff --git a/src/ott/solvers/quadratic/gw_barycenter.py b/src/ott/solvers/quadratic/gw_barycenter.py index ea14880fe..f0d350b08 100644 --- a/src/ott/solvers/quadratic/gw_barycenter.py +++ b/src/ott/solvers/quadratic/gw_barycenter.py @@ -136,7 +136,7 @@ def init_state( bar_init: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, a: Optional[jnp.ndarray] = None, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> GWBarycenterState: """Initialize the (fused) Gromov-Wasserstein barycenter state. 
diff --git a/src/ott/tools/k_means.py b/src/ott/tools/k_means.py index abbe99f34..986b919d0 100644 --- a/src/ott/tools/k_means.py +++ b/src/ott/tools/k_means.py @@ -352,7 +352,7 @@ def k_means( min_iterations: int = 0, max_iterations: int = 300, store_inner_errors: bool = False, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, ) -> KMeansOutput: r"""K-means clustering using Lloyd's algorithm :cite:`lloyd:82`. diff --git a/src/ott/tools/soft_sort.py b/src/ott/tools/soft_sort.py index beb88365f..ccde3bd2c 100644 --- a/src/ott/tools/soft_sort.py +++ b/src/ott/tools/soft_sort.py @@ -459,7 +459,7 @@ def multivariate_cdf_quantile_maps( inputs: jnp.ndarray, target_sampler: Optional[Callable[[jnp.ndarray, Tuple[int, int]], jnp.ndarray]] = None, - rng: Optional[jnp.ndarray] = None, + rng: Optional[jax.Array] = None, num_target_samples: Optional[int] = None, cost_fn: Optional[costs.CostFn] = None, epsilon: Optional[float] = None, diff --git a/src/ott/utils.py b/src/ott/utils.py index 2acfd8420..63a36f2b4 100644 --- a/src/ott/utils.py +++ b/src/ott/utils.py @@ -69,7 +69,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return functools.wraps(func)(wrapper) -def default_prng_key(rng: Optional[jnp.ndarray] = None) -> jnp.ndarray: +def default_prng_key(rng: Optional[jax.Array] = None) -> jnp.ndarray: """Get the default PRNG key. 
Args: diff --git a/tests/neural/neuraldual_test.py b/tests/neural/neuraldual_test.py index b31ba9b6a..8a362affa 100644 --- a/tests/neural/neuraldual_test.py +++ b/tests/neural/neuraldual_test.py @@ -19,7 +19,7 @@ import numpy as np from ott import datasets -from ott.neural import models +from ott.neural.models import models from ott.neural.solvers import conjugate, neuraldual ModelPair_t = Tuple[neuraldual.BaseW2NeuralDual, neuraldual.BaseW2NeuralDual] From 542dd0a7b6d509622afdcf2b6bb455198f1669d8 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 29 Nov 2023 17:28:40 +0100 Subject: [PATCH 051/186] fix import error --- src/ott/neural/solvers/neuraldual.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ott/neural/solvers/neuraldual.py b/src/ott/neural/solvers/neuraldual.py index 019fb836f..455d6b50e 100644 --- a/src/ott/neural/solvers/neuraldual.py +++ b/src/ott/neural/solvers/neuraldual.py @@ -36,7 +36,7 @@ from ott import utils from ott.geometry import costs -from ott.neural import models +from ott.neural.models import models from ott.neural.solvers import conjugate from ott.problems.linear import potentials From f585c247d1fe5a07d12cd9d52247ebb7157d1e5f Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 1 Dec 2023 18:01:55 +0100 Subject: [PATCH 052/186] [ci skip] start to incorporate feedback --- docs/tutorials/point_clouds.ipynb | 2 +- src/ott/geometry/geometry.py | 2 +- src/ott/geometry/low_rank.py | 2 +- src/ott/geometry/pointcloud.py | 2 +- src/ott/neural/data/dataloaders.py | 9 ++-- src/ott/neural/models/base_models.py | 61 --------------------------- src/ott/neural/models/layers.py | 18 +++++++- src/ott/neural/models/models.py | 46 ++++++-------------- src/ott/neural/solvers/base_solver.py | 27 ++++-------- src/ott/neural/solvers/flows.py | 40 ++++++------------ src/ott/neural/solvers/genot.py | 15 ++++--- src/ott/neural/solvers/neuraldual.py | 8 +--- src/ott/neural/solvers/otfm.py | 15 ++++--- tests/neural/genot_test.py | 16 
+++---- tests/neural/otfm_test.py | 10 ++--- 15 files changed, 89 insertions(+), 184 deletions(-) delete mode 100644 src/ott/neural/models/base_models.py diff --git a/docs/tutorials/point_clouds.ipynb b/docs/tutorials/point_clouds.ipynb index c01b51cfd..156bafaa9 100644 --- a/docs/tutorials/point_clouds.ipynb +++ b/docs/tutorials/point_clouds.ipynb @@ -64,7 +64,7 @@ }, "outputs": [], "source": [ - "def create_points(rng: jax.random.PRNGKeyArray, n: int, m: int, d: int):\n", + "def create_points(rng: jax.Array, n: int, m: int, d: int):\n", " rngs = jax.random.split(rng, 3)\n", " x = jax.random.normal(rngs[0], (n, d)) + 1\n", " y = jax.random.uniform(rngs[1], (m, d))\n", diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index 766c5e618..f953bf38c 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -201,7 +201,7 @@ def is_symmetric(self) -> bool: @property def inv_scale_cost(self) -> float: """Compute and return inverse of scaling factor for cost matrix.""" - if isinstance(self._scale_cost, (int, float, np.number, jnp.ndarray)): + if isinstance(self._scale_cost, (int, float, np.number, jax.Array)): return 1.0 / self._scale_cost self = self._masked_geom(mask_value=jnp.nan) if self._scale_cost == "max_cost": diff --git a/src/ott/geometry/low_rank.py b/src/ott/geometry/low_rank.py index 966db28d4..e759b4cb9 100644 --- a/src/ott/geometry/low_rank.py +++ b/src/ott/geometry/low_rank.py @@ -107,7 +107,7 @@ def is_symmetric(self) -> bool: # noqa: D102 @property def inv_scale_cost(self) -> float: # noqa: D102 - if isinstance(self._scale_cost, (int, float, jnp.ndarray)): + if isinstance(self._scale_cost, (int, float, jax.Array)): return 1.0 / self._scale_cost self = self._masked_geom() if self._scale_cost == "max_bound": diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index 2050e1562..e7f46a020 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -141,7 +141,7 @@ def 
cost_rank(self) -> int: # noqa: D102 @property def inv_scale_cost(self) -> float: # noqa: D102 - if isinstance(self._scale_cost, (int, float, jnp.ndarray)): + if isinstance(self._scale_cost, (int, float, jax.Array)): return 1.0 / self._scale_cost self = self._masked_geom() if self._scale_cost == "max_cost": diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py index 4fe8a9a8c..832117013 100644 --- a/src/ott/neural/data/dataloaders.py +++ b/src/ott/neural/data/dataloaders.py @@ -43,7 +43,7 @@ def __init__( source_conditions: Optional[np.ndarray] = None, target_conditions: Optional[np.ndarray] = None, seed: int = 0, - ) -> None: + ): super().__init__() if source_lin is not None: if source_quad is not None: @@ -115,11 +115,8 @@ class ConditionalDataLoader: """ def __init__( - self, - dataloaders: Dict[str, Iterator], - p: np.ndarray, - seed: int = 0 - ) -> None: + self, dataloaders: Dict[str, Iterator], p: np.ndarray, seed: int = 0 + ): super().__init__() self.dataloaders = dataloaders self.conditions = list(dataloaders.keys()) diff --git a/src/ott/neural/models/base_models.py b/src/ott/neural/models/base_models.py deleted file mode 100644 index d3ac7526a..000000000 --- a/src/ott/neural/models/base_models.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import abc -from typing import Optional - -import jax.numpy as jnp - -import flax.linen as nn - -__all__ = ["BaseNeuralVectorField", "BaseRescalingNet"] - - -class BaseNeuralVectorField(nn.Module, abc.ABC): - """Base class for neural vector field models.""" - - @abc.abstractmethod - def __call__( - self, - t: jnp.ndarray, - x: jnp.ndarray, - condition: Optional[jnp.ndarray] = None, - keys_model: Optional[jnp.ndarray] = None - ) -> jnp.ndarray: - """"Evaluate the vector field. - - Args: - t: Time. - x: Input data. - condition: Condition. - keys_model: Random keys for the model. - """ - pass - - -class BaseRescalingNet(nn.Module, abc.ABC): - """Base class for models to learn distributional rescaling factors.""" - - @abc.abstractmethod - def __call__( - self, - x: jnp.ndarray, - condition: Optional[jnp.ndarray] = None - ) -> jnp.ndarray: - """Evaluate the model. - - Args: - x: Input data. - condition: Condition. - """ - pass diff --git a/src/ott/neural/models/layers.py b/src/ott/neural/models/layers.py index 79e6394bc..50c2c6301 100644 --- a/src/ott/neural/models/layers.py +++ b/src/ott/neural/models/layers.py @@ -18,14 +18,28 @@ import flax.linen as nn -__all__ = ["PositiveDense", "PosDefPotentials"] +__all__ = ["PositiveDense", "PosDefPotentials", "MLPBlock"] -PRNGKey = jnp.ndarray +PRNGKey = jax.Array Shape = Tuple[int, ...] Dtype = Any Array = Any +class MLPBlock(nn.Module): + dim: int = 128 + num_layers: int = 3 + act_fn: Any = nn.silu + out_dim: int = 32 + + @nn.compact + def __call__(self, x): + for _ in range(self.num_layers): + x = nn.Dense(self.dim)(x) + x = self.act_fn(x) + return nn.Dense(self.out_dim)(x) + + class PositiveDense(nn.Module): """A linear transformation using a weight matrix with all entries positive. 
diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 8e2562e97..0fc7d4f30 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -28,16 +28,10 @@ from ott.initializers.linear import initializers as lin_init from ott.math import matrix_square_root from ott.neural.models import layers -from ott.neural.models.base_models import ( - BaseNeuralVectorField, - BaseRescalingNet, -) from ott.neural.solvers import neuraldual from ott.problems.linear import linear_problem -__all__ = [ - "ICNN", "MLP", "MetaInitializer", "NeuralVectorField", "RescalingMLP" -] +__all__ = ["ICNN", "MLP", "MetaInitializer", "VelocityField", "RescalingMLP"] class ICNN(neuraldual.BaseW2NeuralDual): @@ -76,7 +70,7 @@ class ICNN(neuraldual.BaseW2NeuralDual): def is_potential(self) -> bool: # noqa: D102 return True - def setup(self) -> None: # noqa: D102 + def setup(self): # noqa: D102 self.num_hidden = len(self.dim_hidden) if self.pos_weights: @@ -410,21 +404,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 } -class Block(nn.Module): - dim: int = 128 - num_layers: int = 3 - act_fn: Any = nn.silu - out_dim: int = 32 - - @nn.compact - def __call__(self, x): - for _ in range(self.num_layers): - x = nn.Dense(self.dim)(x) - x = self.act_fn(x) - return nn.Dense(self.out_dim)(x) - - -class NeuralVectorField(BaseNeuralVectorField): +class VelocityField(nn.Module): """Parameterized neural vector field. Each of the input, condition, and time embeddings are passed through a block @@ -515,7 +495,7 @@ def __call__( Output of the neural vector field. 
""" t = self.time_encoder(t) - t = Block( + t = layers.MLPBlock( dim=self.t_embed_dim, out_dim=self.t_embed_dim, num_layers=self.num_layers_per_block, @@ -524,7 +504,7 @@ def __call__( t ) - x = Block( + x = layers.MLPBlock( dim=self.latent_embed_dim, out_dim=self.latent_embed_dim, num_layers=self.num_layers_per_block, @@ -534,7 +514,7 @@ def __call__( ) if self.condition_dim > 0: - condition = Block( + condition = layers.MLPBlock( dim=self.condition_embed_dim, out_dim=self.condition_embed_dim, num_layers=self.num_layers_per_block, @@ -546,7 +526,7 @@ def __call__( else: concatenated = jnp.concatenate((t, x), axis=-1) - out = Block( + out = layers.MLPBlock( dim=self.joint_hidden_dim, out_dim=self.joint_hidden_dim, num_layers=self.num_layers_per_block, @@ -564,7 +544,7 @@ def __call__( def create_train_state( self, - rng: jax.random.PRNGKeyArray, + rng: jax.Array, optimizer: optax.OptState, input_dim: int, ) -> train_state.TrainState: @@ -587,7 +567,7 @@ def create_train_state( ) -class RescalingMLP(BaseRescalingNet): +class RescalingMLP(nn.Module): """Network to learn distributional rescaling factors based on a MLP. The input is passed through a block consisting of ``num_layers_per_block`` @@ -626,7 +606,7 @@ def __call__( Returns: Estimated rescaling factors. 
""" - x = Block( + x = layers.MLPBlock( dim=self.hidden_dim, out_dim=self.hidden_dim, num_layers=self.num_layers_per_block, @@ -636,7 +616,7 @@ def __call__( ) if self.condition_dim > 0: condition = jnp.atleast_1d(condition) - condition = Block( + condition = layers.MLPBlock( dim=self.hidden_dim, out_dim=self.hidden_dim, num_layers=self.num_layers_per_block, @@ -648,7 +628,7 @@ def __call__( else: concatenated = x - out = Block( + out = layers.MLPBlock( dim=self.hidden_dim, out_dim=self.hidden_dim, num_layers=self.num_layers_per_block, @@ -661,7 +641,7 @@ def __call__( def create_train_state( self, - rng: jax.random.PRNGKeyArray, + rng: jax.Array, optimizer: optax.OptState, input_dim: int, ) -> train_state.TrainState: diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/solvers/base_solver.py index fe0ea6f3d..780bf61ad 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/solvers/base_solver.py @@ -23,8 +23,6 @@ from flax.training import train_state from ott.geometry import costs, pointcloud -from ott.geometry.pointcloud import PointCloud -from ott.neural.models import models from ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import sinkhorn @@ -38,48 +36,39 @@ class BaseNeuralSolver(ABC): valid_freq: Frequency at which to run validation. 
""" - def __init__(self, iterations: int, valid_freq: int, **_: Any) -> None: + def __init__(self, iterations: int, valid_freq: int, **_: Any): self.iterations = iterations self.valid_freq = valid_freq @abstractmethod - def setup(self, *args: Any, **kwargs: Any) -> None: + def setup(self, *args: Any, **kwargs: Any): """Setup the model.""" - pass @abstractmethod - def __call__(self, *args: Any, **kwargs: Any) -> None: + def __call__(self, *args: Any, **kwargs: Any): """Train the model.""" - pass @abstractmethod def transport(self, *args: Any, forward: bool, **kwargs: Any) -> Any: """Transport.""" - pass @abstractmethod def save(self, path: Path): """Save the model.""" - pass @abstractmethod def load(self, path: Path): """Load the model.""" - pass @property @abstractmethod def training_logs(self) -> Dict[str, Any]: """Return the training logs.""" - pass class ResampleMixin: """Mixin class for mini-batch OT in neural optimal transport solvers.""" - def __init__(*args, **kwargs): - pass - def _resample_data( self, key: jax.random.KeyArray, @@ -264,8 +253,10 @@ def __init__( cond_dim: Optional[int], tau_a: float = 1.0, tau_b: float = 1.0, - mlp_eta: Optional[models.BaseRescalingNet] = None, - mlp_xi: Optional[models.BaseRescalingNet] = None, + mlp_eta: Optional[Callable[[jnp.ndarray, Optional[jnp.ndarray]], + jnp.ndarray]] = None, + mlp_xi: Optional[Callable[[jnp.ndarray, Optional[jnp.ndarray]], + jnp.ndarray]] = None, seed: Optional[int] = None, opt_eta: Optional[optax.GradientTransformation] = None, opt_xi: Optional[optax.GradientTransformation] = None, @@ -274,7 +265,7 @@ def __init__( "median"]] = "mean", sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), **_: Any, - ) -> None: + ): self.rng_unbalanced = rng self.source_dim = source_dim self.target_dim = target_dim @@ -313,7 +304,7 @@ def _get_compute_unbalanced_marginals( def compute_unbalanced_marginals( batch_source: jnp.ndarray, batch_target: jnp.ndarray ) -> Tuple[jnp.ndarray, jnp.ndarray]: - geom = 
PointCloud( + geom = pointcloud.PointCloud( batch_source, batch_target, epsilon=resample_epsilon, diff --git a/src/ott/neural/solvers/flows.py b/src/ott/neural/solvers/flows.py index 47be01fc5..0ff81a560 100644 --- a/src/ott/neural/solvers/flows.py +++ b/src/ott/neural/solvers/flows.py @@ -29,11 +29,13 @@ class BaseFlow(abc.ABC): sigma: Constant noise used for computing time-dependent noise schedule. """ - def __init__(self, sigma: float) -> None: + def __init__(self, sigma: float): self.sigma = sigma @abc.abstractmethod - def compute_mu_t(self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray): + def compute_mu_t( + self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray + ) -> jnp.ndarray: """Compute the mean of the probablitiy path. Compute the mean of the probablitiy path between :math:`x` and :math:`y` @@ -44,7 +46,6 @@ def compute_mu_t(self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray): x_0: Sample from the source distribution. x_1: Sample from the target distribution. """ - pass @abc.abstractmethod def compute_sigma_t(self, t: jnp.ndarray): @@ -53,7 +54,6 @@ def compute_sigma_t(self, t: jnp.ndarray): Args: t: Time :math:`t`. """ - pass @abc.abstractmethod def compute_ut( @@ -61,15 +61,14 @@ def compute_ut( ) -> jnp.ndarray: """Evaluate the conditional vector field. - Evaluate the conditional vector field defined between :math:`x_0` and - :math:`x_1` at time :math:`t`. + Evaluate the conditional vector field defined between :math:`x_0` and + :math:`x_1` at time :math:`t`. Args: t: Time :math:`t`. x_0: Sample from the source distribution. x_1: Sample from the target distribution. """ - pass def compute_xt( self, noise: jnp.ndarray, t: jnp.ndarray, x_0: jnp.ndarray, @@ -77,8 +76,8 @@ def compute_xt( ) -> jnp.ndarray: """Sample from the probability path. - Sample from the probability path between :math:`x_0` and :math:`x_1` at - time :math:`t`. + Sample from the probability path between :math:`x_0` and :math:`x_1` at + time :math:`t`. 
Args: noise: Noise sampled from a standard normal distribution. @@ -88,7 +87,7 @@ def compute_xt( Returns: Samples from the probability path between :math:`x_0` and :math:`x_1` - at time :math:`t`. + at time :math:`t`. """ mu_t = self.compute_mu_t(t, x_0, x_1) sigma_t = self.compute_sigma_t(t) @@ -101,16 +100,6 @@ class StraightFlow(BaseFlow, abc.ABC): def compute_mu_t( self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray ) -> jnp.ndarray: - """Compute the mean of the probablitiy path. - - Compute the mean of the probablitiy path between :math:`x` and :math:`y` - at time :math:`t`. - - Args: - t: Time :math:`t`. - x_0: Sample from the source distribution. - x_1: Sample from the target distribution. - """ return t * x_0 + (1 - t) * x_1 def compute_ut( @@ -119,7 +108,7 @@ def compute_ut( """Evaluate the conditional vector field. Evaluate the conditional vector field defined between :math:`x_0` and - :math:`x_1` at time :math:`t`. + :math:`x_1` at time :math:`t`. Args: t: Time :math:`t`. @@ -175,7 +164,7 @@ class BaseTimeSampler(abc.ABC): high: Upper bound of the distribution to sample from . """ - def __init__(self, low: float, high: float) -> None: + def __init__(self, low: float, high: float): self.low = low self.high = high @@ -187,7 +176,6 @@ def __call__(self, rng: jax.Array, num_samples: int) -> jnp.ndarray: rng: Random number generator. num_samples: Number of samples to generate. """ - pass class UniformSampler(BaseTimeSampler): @@ -198,7 +186,7 @@ class UniformSampler(BaseTimeSampler): high: Upper bound of the uniform distribution. """ - def __init__(self, low: float = 0.0, high: float = 1.0) -> None: + def __init__(self, low: float = 0.0, high: float = 1.0): super().__init__(low=low, high=high) def __call__(self, rng: jax.Array, num_samples: int) -> jnp.ndarray: @@ -228,9 +216,7 @@ class OffsetUniformSampler(BaseTimeSampler): high: Upper bound of the uniform distribution. 
""" - def __init__( - self, offset: float, low: float = 0.0, high: float = 1.0 - ) -> None: + def __init__(self, offset: float, low: float = 0.0, high: float = 1.0): super().__init__(low=low, high=high) self.offset = offset diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/solvers/genot.py index 61b368a67..fb76ded77 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/solvers/genot.py @@ -26,7 +26,6 @@ from ott import utils from ott.geometry import costs -from ott.neural.models.models import BaseNeuralVectorField from ott.neural.solvers.base_solver import ( BaseNeuralSolver, ResampleMixin, @@ -100,7 +99,9 @@ class GENOT(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): def __init__( self, - neural_vector_field: Type[BaseNeuralVectorField], + neural_vector_field: Callable[[ + jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], Optional[jnp.ndarray] + ], jnp.ndarray], input_dim: int, output_dim: int, cond_dim: int, @@ -132,7 +133,7 @@ def __init__( callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any]] = None, rng: Optional[jax.Array] = None, - ) -> None: + ): rng = utils.default_prng_key(rng) rng, rng_unbalanced = jax.random.split(rng) BaseNeuralSolver.__init__( @@ -192,7 +193,7 @@ def __init__( self.callback_fn = callback_fn self.setup() - def setup(self) -> None: + def setup(self): """Set up the model. 
Parameters @@ -230,7 +231,7 @@ def setup(self) -> None: self.fused_penalty ) - def __call__(self, train_loader, valid_loader) -> None: + def __call__(self, train_loader, valid_loader): """Train GENOT.""" batch: Dict[str, jnp.array] = {} for iteration in range(self.iterations): @@ -439,7 +440,7 @@ def solve_ode(input: jnp.ndarray, cond: jnp.ndarray): return jax.vmap(solve_ode)(latent_batch, cond_input) - def _valid_step(self, valid_loader, iter) -> None: + def _valid_step(self, valid_loader, iter): """TODO.""" next(valid_loader) @@ -448,7 +449,7 @@ def learn_rescaling(self) -> bool: """Whether to learn at least one rescaling factor.""" return self.mlp_eta is not None or self.mlp_xi is not None - def save(self, path: str) -> None: + def save(self, path: str): """Save the model. Args: diff --git a/src/ott/neural/solvers/neuraldual.py b/src/ott/neural/solvers/neuraldual.py index 455d6b50e..ade11a085 100644 --- a/src/ott/neural/solvers/neuraldual.py +++ b/src/ott/neural/solvers/neuraldual.py @@ -69,10 +69,6 @@ class W2NeuralTrainState(train_state.TrainState): ) -class BaseNeuralVectorField(nn.Module): - pass - - class BaseW2NeuralDual(abc.ABC, nn.Module): """Base class for the neural solver models.""" @@ -295,7 +291,7 @@ def setup( dim_data: int, optimizer_f: optax.OptState, optimizer_g: optax.OptState, - ) -> None: + ): """Setup all components required to train the network.""" # split random number generator rng, rng_f, rng_g = jax.random.split(rng, 3) @@ -700,7 +696,7 @@ def _update_logs( loss_f: jnp.ndarray, loss_g: jnp.ndarray, w_dist: jnp.ndarray, - ) -> None: + ): logs["loss_f"].append(float(loss_f)) logs["loss_g"].append(float(loss_g)) logs["w_dist"].append(float(w_dist)) diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/solvers/otfm.py index d145c4128..b7885c1d5 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/solvers/otfm.py @@ -36,7 +36,6 @@ from ott import utils from ott.geometry import costs -from ott.neural.models.models import 
BaseNeuralVectorField from ott.neural.solvers.base_solver import ( BaseNeuralSolver, ResampleMixin, @@ -92,7 +91,9 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): def __init__( self, - neural_vector_field: Type[BaseNeuralVectorField], + neural_vector_field: Callable[[ + jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], Optional[jnp.ndarray] + ], jnp.ndarray], input_dim: int, cond_dim: int, iterations: int, @@ -117,7 +118,7 @@ def __init__( valid_freq: int = 5000, num_eval_samples: int = 1000, rng: Optional[jax.Array] = None, - ) -> None: + ): rng = utils.default_prng_key(rng) rng, rng_unbalanced = jax.random.split(rng) BaseNeuralSolver.__init__( @@ -155,7 +156,7 @@ def __init__( self.setup() - def setup(self) -> None: + def setup(self): """Setup :class:`OTFlowMatching`.""" self.state_neural_vector_field = ( self.neural_vector_field.create_train_state( @@ -218,7 +219,7 @@ def loss_fn( return step_fn - def __call__(self, train_loader, valid_loader) -> None: + def __call__(self, train_loader, valid_loader): """Train :class:`OTFlowMatching`. Args; @@ -330,7 +331,7 @@ def solve_ode(input: jnp.ndarray, cond: jnp.ndarray): return jax.vmap(solve_ode)(data, condition) - def _valid_step(self, valid_loader, iter) -> None: + def _valid_step(self, valid_loader, iter): next(valid_loader) # TODO: add callback and logging @@ -339,7 +340,7 @@ def learn_rescaling(self) -> bool: """Whether to learn at least one rescaling factor.""" return self.mlp_eta is not None or self.mlp_xi is not None - def save(self, path: str) -> None: + def save(self, path: str): """Save the model. 
Args: diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index fddc4fc3c..92a929154 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -20,7 +20,7 @@ import optax from ott.geometry import costs -from ott.neural.models.models import NeuralVectorField, RescalingMLP +from ott.neural.models.models import RescalingMLP, VelocityField from ott.neural.solvers.flows import OffsetUniformSampler, UniformSampler from ott.neural.solvers.genot import GENOT from ott.solvers.linear import sinkhorn @@ -47,7 +47,7 @@ def test_genot_linear_unconditional( target_dim = target_lin.shape[1] condition_dim = 0 - neural_vf = NeuralVectorField( + neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, latent_embed_dim=5, @@ -97,7 +97,7 @@ def test_genot_quad_unconditional( source_dim = source_quad.shape[1] target_dim = target_quad.shape[1] condition_dim = 0 - neural_vf = NeuralVectorField( + neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, latent_embed_dim=5, @@ -143,7 +143,7 @@ def test_genot_fused_unconditional( source_dim = source_lin.shape[1] + source_quad.shape[1] target_dim = target_lin.shape[1] + target_quad.shape[1] condition_dim = 0 - neural_vf = NeuralVectorField( + neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, latent_embed_dim=5, @@ -190,7 +190,7 @@ def test_genot_linear_conditional( target_dim = target_lin.shape[1] condition_dim = source_condition.shape[1] - neural_vf = NeuralVectorField( + neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, latent_embed_dim=5, @@ -237,7 +237,7 @@ def test_genot_quad_conditional( source_dim = source_quad.shape[1] target_dim = target_quad.shape[1] condition_dim = source_condition.shape[1] - neural_vf = NeuralVectorField( + neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, latent_embed_dim=5, @@ -284,7 
+284,7 @@ def test_genot_fused_conditional( source_dim = source_lin.shape[1] + source_quad.shape[1] target_dim = target_lin.shape[1] + target_quad.shape[1] condition_dim = source_condition.shape[1] - neural_vf = NeuralVectorField( + neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, latent_embed_dim=5, @@ -340,7 +340,7 @@ def test_genot_linear_learn_rescaling( target_dim = target_lin.shape[1] condition_dim = source_condition.shape[1] if conditional else 0 - neural_vf = NeuralVectorField( + neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, latent_embed_dim=5, diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index a57588a43..b38fceb74 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -19,7 +19,7 @@ import optax -from ott.neural.models.models import NeuralVectorField, RescalingMLP +from ott.neural.models.models import RescalingMLP, VelocityField from ott.neural.solvers.flows import ( BaseFlow, BrownianNoiseFlow, @@ -40,7 +40,7 @@ class TestOTFlowMatching: BrownianNoiseFlow(0.2)] ) def test_flow_matching(self, data_loader_gaussian, flow: Type[BaseFlow]): - neural_vf = NeuralVectorField( + neural_vf = VelocityField( output_dim=2, condition_dim=0, latent_embed_dim=5, @@ -85,7 +85,7 @@ def test_flow_matching(self, data_loader_gaussian, flow: Type[BaseFlow]): def test_flow_matching_with_conditions( self, data_loader_gaussian_with_conditions, flow: Type[BaseFlow] ): - neural_vf = NeuralVectorField( + neural_vf = VelocityField( output_dim=2, condition_dim=1, latent_embed_dim=5, @@ -133,7 +133,7 @@ def test_flow_matching_with_conditions( def test_flow_matching_conditional( self, data_loader_gaussian_conditional, flow: Type[BaseFlow] ): - neural_vf = NeuralVectorField( + neural_vf = VelocityField( output_dim=2, condition_dim=0, latent_embed_dim=5, @@ -181,7 +181,7 @@ def test_flow_matching_learn_rescaling( batch = next(data_loader) source_dim = 
batch["source_lin"].shape[1] condition_dim = batch["source_conditions"].shape[1] if conditional else 0 - neural_vf = NeuralVectorField( + neural_vf = VelocityField( output_dim=2, condition_dim=0, latent_embed_dim=5, From 3c0700973881763f25aa8d8f712c3954f8389d40 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Mon, 4 Dec 2023 11:02:17 +0100 Subject: [PATCH 053/186] restructure neural module --- src/ott/neural/__init__.py | 2 +- src/ott/neural/data/dataloaders.py | 1 - src/ott/neural/duality/__init__.py | 14 + .../neural/{solvers => duality}/conjugate.py | 0 src/ott/neural/duality/layers.py | 140 +++++ src/ott/neural/duality/models.py | 362 +++++++++++++ .../neural/{solvers => duality}/neuraldual.py | 2 +- src/ott/neural/{solvers => flows}/__init__.py | 2 +- src/ott/neural/{solvers => flows}/flows.py | 2 +- src/ott/neural/{solvers => flows}/genot.py | 12 +- src/ott/neural/flows/models.py | 188 +++++++ src/ott/neural/{solvers => flows}/otfm.py | 4 +- src/ott/neural/gaps/__init__.py | 14 + .../neural/{solvers => gaps}/map_estimator.py | 2 +- .../{models/losses.py => gaps/monge_gap.py} | 0 src/ott/neural/models/__init__.py | 2 +- .../neural/{solvers => models}/base_solver.py | 2 + src/ott/neural/models/layers.py | 128 +---- src/ott/neural/models/models.py | 506 +----------------- tests/neural/genot_test.py | 7 +- tests/neural/losses_test.py | 15 +- tests/neural/map_estimator_test.py | 7 +- tests/neural/neuraldual_test.py | 2 +- tests/neural/otfm_test.py | 7 +- 24 files changed, 768 insertions(+), 653 deletions(-) create mode 100644 src/ott/neural/duality/__init__.py rename src/ott/neural/{solvers => duality}/conjugate.py (100%) create mode 100644 src/ott/neural/duality/layers.py create mode 100644 src/ott/neural/duality/models.py rename src/ott/neural/{solvers => duality}/neuraldual.py (99%) rename src/ott/neural/{solvers => flows}/__init__.py (91%) rename src/ott/neural/{solvers => flows}/flows.py (99%) rename src/ott/neural/{solvers => flows}/genot.py (99%) create 
mode 100644 src/ott/neural/flows/models.py rename src/ott/neural/{solvers => flows}/otfm.py (99%) create mode 100644 src/ott/neural/gaps/__init__.py rename src/ott/neural/{solvers => gaps}/map_estimator.py (99%) rename src/ott/neural/{models/losses.py => gaps/monge_gap.py} (100%) rename src/ott/neural/{solvers => models}/base_solver.py (99%) diff --git a/src/ott/neural/__init__.py b/src/ott/neural/__init__.py index 326fae432..2a61ca021 100644 --- a/src/ott/neural/__init__.py +++ b/src/ott/neural/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import data, models, solvers +from . import data, duality, flows, gaps, models diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py index 832117013..9c09ce08c 100644 --- a/src/ott/neural/data/dataloaders.py +++ b/src/ott/neural/data/dataloaders.py @@ -111,7 +111,6 @@ class ConditionalDataLoader: conditions. p: Probability of sampling from each data loader. seed: Random seed. - """ def __init__( diff --git a/src/ott/neural/duality/__init__.py b/src/ott/neural/duality/__init__.py new file mode 100644 index 000000000..ef76b42fa --- /dev/null +++ b/src/ott/neural/duality/__init__.py @@ -0,0 +1,14 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from . 
import conjugate, layers, models, neuraldual diff --git a/src/ott/neural/solvers/conjugate.py b/src/ott/neural/duality/conjugate.py similarity index 100% rename from src/ott/neural/solvers/conjugate.py rename to src/ott/neural/duality/conjugate.py diff --git a/src/ott/neural/duality/layers.py b/src/ott/neural/duality/layers.py new file mode 100644 index 000000000..4b85972f3 --- /dev/null +++ b/src/ott/neural/duality/layers.py @@ -0,0 +1,140 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional, Tuple + +import jax +import jax.numpy as jnp + +import flax.linen as nn + +__all__ = ["PositiveDense", "PosDefPotentials"] + +PRNGKey = jax.Array +Shape = Tuple[int, ...] +Dtype = Any +Array = Any + + +class PositiveDense(nn.Module): + """A linear transformation using a weight matrix with all entries positive. + + Args: + dim_hidden: the number of output dim_hidden. + rectifier_fn: choice of rectifier function (default: softplus function). + inv_rectifier_fn: choice of inverse rectifier function + (default: inverse softplus function). + dtype: the dtype of the computation (default: float32). + precision: numerical precision of computation see `jax.lax.Precision` + for details. + kernel_init: initializer function for the weight matrix. + bias_init: initializer function for the bias. 
+ """ + dim_hidden: int + rectifier_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.softplus + inv_rectifier_fn: Callable[[jnp.ndarray], + jnp.ndarray] = lambda x: jnp.log(jnp.exp(x) - 1) + use_bias: bool = True + dtype: Any = jnp.float32 + precision: Any = None + kernel_init: Optional[Callable[[PRNGKey, Shape, Dtype], Array]] = None, + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros + + @nn.compact + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + """Applies a linear transformation to inputs along the last dimension. + + Args: + inputs: Array to be transformed. + + Returns: + The transformed input. + """ + kernel_init = nn.initializers.lecun_normal( + ) if self.kernel_init is None else self.kernel_init + + inputs = jnp.asarray(inputs, self.dtype) + kernel = self.param( + "kernel", kernel_init, (inputs.shape[-1], self.dim_hidden) + ) + kernel = self.rectifier_fn(kernel) + kernel = jnp.asarray(kernel, self.dtype) + y = jax.lax.dot_general( + inputs, + kernel, (((inputs.ndim - 1,), (0,)), ((), ())), + precision=self.precision + ) + if self.use_bias: + bias = self.param("bias", self.bias_init, (self.dim_hidden,)) + bias = jnp.asarray(bias, self.dtype) + return y + bias + return y + + +class PosDefPotentials(nn.Module): + r"""A layer to output :math:`\frac{1}{2} ||A_i^T (x - b_i)||^2_i` potentials. + + Args: + use_bias: whether to add a bias to the output. + dtype: the dtype of the computation. + precision: numerical precision of computation see `jax.lax.Precision` + for details. + kernel_init: initializer function for the weight matrix. + bias_init: initializer function for the bias. 
+ """ + dim_data: int + num_potentials: int + use_bias: bool = True + dtype: Any = jnp.float32 + precision: Any = None + kernel_init: Optional[Callable[[PRNGKey, Shape, Dtype], Array]] = None + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros + + @nn.compact + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + """Apply a few quadratic forms. + + Args: + inputs: Array to be transformed (possibly batched). + + Returns: + The transformed input. + """ + kernel_init = nn.initializers.lecun_normal( + ) if self.kernel_init is None else self.kernel_init + inputs = jnp.asarray(inputs, self.dtype) + kernel = self.param( + "kernel", kernel_init, + (self.num_potentials, inputs.shape[-1], inputs.shape[-1]) + ) + + if self.use_bias: + bias = self.param( + "bias", self.bias_init, (self.num_potentials, self.dim_data) + ) + bias = jnp.asarray(bias, self.dtype) + + y = inputs.reshape((-1, inputs.shape[-1])) if inputs.ndim == 1 else inputs + y = y[..., None] - bias.T[None, ...] + y = jax.lax.dot_general( + y, kernel, (((1,), (1,)), ((2,), (0,))), precision=self.precision + ) + else: + y = jax.lax.dot_general( + inputs, + kernel, (((inputs.ndim - 1,), (0,)), ((), ())), + precision=self.precision + ) + + y = 0.5 * y * y + return jnp.sum(y.reshape((-1, self.num_potentials, self.dim_data)), axis=2) diff --git a/src/ott/neural/duality/models.py b/src/ott/neural/duality/models.py new file mode 100644 index 000000000..2b51c60cf --- /dev/null +++ b/src/ott/neural/duality/models.py @@ -0,0 +1,362 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +from typing import Any, Callable, Dict, Optional, Sequence, Tuple + +import jax +import jax.numpy as jnp +from jax.nn import initializers + +import flax.linen as nn +import optax +from flax.core import frozen_dict +from flax.training import train_state + +from ott import utils +from ott.geometry import geometry +from ott.initializers.linear import initializers as lin_init +from ott.math import matrix_square_root +from ott.neural.duality import neuraldual +from ott.neural.models import layers +from ott.problems.linear import linear_problem + +__all__ = ["ICNN", "MetaInitializer"] + + +class ICNN(neuraldual.BaseW2NeuralDual): + """Input convex neural network (ICNN) architecture with initialization. + + Implementation of input convex neural networks as introduced in + :cite:`amos:17` with initialization schemes proposed by :cite:`bunne:22`. + + Args: + dim_data: data dimensionality. + dim_hidden: sequence specifying size of hidden dimensions. The + output dimension of the last layer is 1 by default. + init_std: value of standard deviation of weight initialization method. + init_fn: choice of initialization method for weight matrices (default: + :func:`jax.nn.initializers.normal`). + act_fn: choice of activation function used in network architecture + (needs to be convex, default: :obj:`jax.nn.relu`). + pos_weights: Enforce positive weights with a projection. + If ``False``, the positive weights should be enforced with clipping + or regularization in the loss. 
+ gaussian_map_samples: Tuple of source and target points, used to initialize + the ICNN to mimic the linear Bures map that morphs the (Gaussian + approximation) of the input measure to that of the target measure. If + ``None``, the identity initialization is used, and ICNN mimics half the + squared Euclidean norm. + """ + dim_data: int + dim_hidden: Sequence[int] + init_std: float = 1e-2 + init_fn: Callable = jax.nn.initializers.normal + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + pos_weights: bool = True + gaussian_map_samples: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None + + @property + def is_potential(self) -> bool: # noqa: D102 + return True + + def setup(self): # noqa: D102 + self.num_hidden = len(self.dim_hidden) + + if self.pos_weights: + hid_dense = layers.PositiveDense + # this function needs to be the inverse map of function + # used in PositiveDense layers + rescale = hid_dense.inv_rectifier_fn + else: + hid_dense = nn.Dense + rescale = lambda x: x + self.use_init = False + # check if Gaussian map was provided + if self.gaussian_map_samples is not None: + factor, mean = self._compute_gaussian_map_params( + self.gaussian_map_samples + ) + else: + factor, mean = self._compute_identity_map_params(self.dim_data) + + w_zs = [] + # keep track of previous size to normalize accordingly + normalization = 1 + + for i in range(1, self.num_hidden): + w_zs.append( + hid_dense( + self.dim_hidden[i], + kernel_init=initializers.constant(rescale(1.0 / normalization)), + use_bias=False, + ) + ) + normalization = self.dim_hidden[i] + # final layer computes average, still with normalized rescaling + w_zs.append( + hid_dense( + 1, + kernel_init=initializers.constant(rescale(1.0 / normalization)), + use_bias=False, + ) + ) + self.w_zs = w_zs + + # positive definite potential (the identity mapping or linear OT) + self.pos_def_potential = layers.PosDefPotentials( + self.dim_data, + num_potentials=1, + kernel_init=lambda *_: factor, + bias_init=lambda *_: 
mean, + use_bias=True, + ) + + # subsequent layers re-injected into convex functions + w_xs = [] + for i in range(self.num_hidden): + w_xs.append( + nn.Dense( + self.dim_hidden[i], + kernel_init=self.init_fn(self.init_std), + bias_init=initializers.constant(0.), + use_bias=True, + ) + ) + # final layer, to output number + w_xs.append( + nn.Dense( + 1, + kernel_init=self.init_fn(self.init_std), + bias_init=initializers.constant(0.), + use_bias=True, + ) + ) + self.w_xs = w_xs + + @staticmethod + def _compute_gaussian_map_params( + samples: Tuple[jnp.ndarray, jnp.ndarray] + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + from ott.tools.gaussian_mixture import gaussian + source, target = samples + g_s = gaussian.Gaussian.from_samples(source) + g_t = gaussian.Gaussian.from_samples(target) + lin_op = g_s.scale.gaussian_map(g_t.scale) + b = jnp.squeeze(g_t.loc) - jnp.linalg.solve(lin_op, jnp.squeeze(g_t.loc)) + lin_op = matrix_square_root.sqrtm_only(lin_op) + return jnp.expand_dims(lin_op, 0), jnp.expand_dims(b, 0) + + @staticmethod + def _compute_identity_map_params( + input_dim: int + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + A = jnp.eye(input_dim).reshape((1, input_dim, input_dim)) + b = jnp.zeros((1, input_dim)) + return A, b + + @nn.compact + def __call__(self, x: jnp.ndarray) -> float: # noqa: D102 + z = self.act_fn(self.w_xs[0](x)) + for i in range(self.num_hidden): + z = jnp.add(self.w_zs[i](z), self.w_xs[i + 1](x)) + z = self.act_fn(z) + z += self.pos_def_potential(x) + return z.squeeze() + + +@jax.tree_util.register_pytree_node_class +class MetaInitializer(lin_init.DefaultInitializer): + """Meta OT Initializer with a fixed geometry :cite:`amos:22`. + + This initializer consists of a predictive model that outputs the + :math:`f` duals to solve the entropy-regularized OT problem given + input probability weights ``a`` and ``b``, and a given (assumed to be + fixed) geometry ``geom``. 
+ + The model's parameters are learned using a training set of OT + instances (multiple pairs of probability weights), that assume the + **same** geometry ``geom`` is used throughout, both for training and + evaluation. + + Args: + geom: The fixed geometry of the problem instances. + meta_model: The model to predict the potential :math:`f` from the measures. + TODO(marcocuturi): add explanation here what arguments to expect. + opt: The optimizer to update the parameters. If ``None``, use + :func:`optax.adam` with :math:`0.001` learning rate. + rng: The PRNG key to use for initializing the model. + state: The training state of the model to start from. + + Examples: + The following code shows a simple + example of using ``update`` to train the model, where + ``a`` and ``b`` are the weights of the measures and + ``geom`` is the fixed geometry. + + .. code-block:: python + + meta_initializer = init_lib.MetaInitializer(geom) + while training(): + a, b = sample_batch() + loss, init_f, meta_initializer.state = meta_initializer.update( + meta_initializer.state, a=a, b=b + ) + """ + + def __init__( + self, + geom: geometry.Geometry, + meta_model: nn.Module, + opt: Optional[optax.GradientTransformation + ] = optax.adam(learning_rate=1e-3), # noqa: B008 + rng: Optional[jax.Array] = None, + state: Optional[train_state.TrainState] = None + ): + self.geom = geom + self.dtype = geom.x.dtype + self.opt = opt + self.rng = utils.default_prng_key(rng) + + na, nb = geom.shape + # TODO(michalk8): add again some default MLP + self.meta_model = meta_model + + if state is None: + # Initialize the model's training state. 
+ a_placeholder = jnp.zeros(na, dtype=self.dtype) + b_placeholder = jnp.zeros(nb, dtype=self.dtype) + params = self.meta_model.init(self.rng, a_placeholder, + b_placeholder)["params"] + self.state = train_state.TrainState.create( + apply_fn=self.meta_model.apply, params=params, tx=opt + ) + else: + self.state = state + + self.update_impl = self._get_update_fn() + + def update( + self, state: train_state.TrainState, a: jnp.ndarray, b: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray, train_state.TrainState]: + r"""Update the meta model with the dual objective. + + The goal is for the model to match the optimal duals, i.e., + :math:`\hat f_\theta \approx f^\star`. + This can be done by training the predictions of :math:`\hat f_\theta` + to optimize the dual objective, which :math:`f^\star` also optimizes for. + The overall learning setup can thus be written as: + + .. math:: + \min_\theta\; {\mathbb E}_{(\alpha,\beta)\sim{\mathcal{D}}}\; + J(\hat f_\theta(a, b); \alpha, \beta), + + where :math:`a,b` are the probabilities of the measures :math:`\alpha,\beta`, + :math:`\mathcal{D}` is a meta distribution of optimal transport problems, + + .. math:: + -J(f; \alpha, \beta, c) := \langle f, a\rangle + \langle g, b \rangle - + \varepsilon\left\langle \exp\{f/\varepsilon\}, K\exp\{g/\varepsilon\} + \right\rangle + + is the entropic dual objective, + and :math:`K_{i,j} := \exp(-C_{i,j}/\varepsilon)` is the *Gibbs kernel*. + + Args: + state: Optimizer state of the meta model. + a: Probabilities of the :math:`\alpha` measure's atoms. + b: Probabilities of the :math:`\beta` measure's atoms. + + Returns: + The training loss, :math:`f`, and updated state. + """ + return self.update_impl(state, a, b) + + def init_dual_a( # noqa: D102 + self, + ot_prob: "linear_problem.LinearProblem", + lse_mode: bool, + rng: Optional[jax.Array] = None, + ) -> jnp.ndarray: + del rng + # Detect if the problem is batched.
+ assert ot_prob.a.ndim in (1, 2) + assert ot_prob.b.ndim in (1, 2) + vmap_a_val = 0 if ot_prob.a.ndim == 2 else None + vmap_b_val = 0 if ot_prob.b.ndim == 2 else None + + if vmap_a_val is not None or vmap_b_val is not None: + compute_f_maybe_batch = jax.vmap( + self._compute_f, in_axes=(vmap_a_val, vmap_b_val, None) + ) + else: + compute_f_maybe_batch = self._compute_f + + init_f = compute_f_maybe_batch(ot_prob.a, ot_prob.b, self.state.params) + return init_f if lse_mode else ot_prob.geom.scaling_from_potential(init_f) + + def _get_update_fn(self): + """Return the implementation (and jitted) update function.""" + from ott.problems.linear import linear_problem + from ott.solvers.linear import sinkhorn + + def dual_obj_loss_single(params, a, b): + f_pred = self._compute_f(a, b, params) + g_pred = self.geom.update_potential( + f_pred, jnp.zeros_like(b), jnp.log(b), 0, axis=0 + ) + g_pred = jnp.where(jnp.isfinite(g_pred), g_pred, 0.) + + ot_prob = linear_problem.LinearProblem(geom=self.geom, a=a, b=b) + dual_obj = sinkhorn.compute_kl_reg_cost( + f_pred, g_pred, ot_prob, lse_mode=True + ) + loss = -dual_obj + return loss, f_pred + + def loss_batch(params, a, b): + loss_fn = functools.partial(dual_obj_loss_single, params=params) + loss, f_pred = jax.vmap(loss_fn)(a=a, b=b) + return jnp.mean(loss), f_pred + + @jax.jit + def update(state, a, b): + a = jnp.atleast_2d(a) + b = jnp.atleast_2d(b) + grad_fn = jax.value_and_grad(loss_batch, has_aux=True) + (loss, init_f), grads = grad_fn(state.params, a, b) + return loss, init_f, state.apply_gradients(grads=grads) + + return update + + def _compute_f( + self, a: jnp.ndarray, b: jnp.ndarray, + params: frozen_dict.FrozenDict[str, jnp.ndarray] + ) -> jnp.ndarray: + r"""Predict the optimal :math:`f` potential. + + Args: + a: Probabilities of the :math:`\alpha` measure's atoms. + b: Probabilities of the :math:`\beta` measure's atoms. + params: The parameters of the Meta model. + + Returns: + The :math:`f` potential. 
+ """ + return self.meta_model.apply({"params": params}, a, b) + + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 + return [self.geom, self.meta_model, self.opt], { + "rng": self.rng, + "state": self.state + } diff --git a/src/ott/neural/solvers/neuraldual.py b/src/ott/neural/duality/neuraldual.py similarity index 99% rename from src/ott/neural/solvers/neuraldual.py rename to src/ott/neural/duality/neuraldual.py index ade11a085..1d1aaa85b 100644 --- a/src/ott/neural/solvers/neuraldual.py +++ b/src/ott/neural/duality/neuraldual.py @@ -36,8 +36,8 @@ from ott import utils from ott.geometry import costs +from ott.neural.duality import conjugate from ott.neural.models import models -from ott.neural.solvers import conjugate from ott.problems.linear import potentials __all__ = ["W2NeuralTrainState", "BaseW2NeuralDual", "W2NeuralDual"] diff --git a/src/ott/neural/solvers/__init__.py b/src/ott/neural/flows/__init__.py similarity index 91% rename from src/ott/neural/solvers/__init__.py rename to src/ott/neural/flows/__init__.py index b09d8c60b..695cbbe3c 100644 --- a/src/ott/neural/solvers/__init__.py +++ b/src/ott/neural/flows/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import conjugate, map_estimator, neuraldual +from . 
import flows, genot, models, otfm diff --git a/src/ott/neural/solvers/flows.py b/src/ott/neural/flows/flows.py similarity index 99% rename from src/ott/neural/solvers/flows.py rename to src/ott/neural/flows/flows.py index 0ff81a560..93f471b9d 100644 --- a/src/ott/neural/solvers/flows.py +++ b/src/ott/neural/flows/flows.py @@ -97,7 +97,7 @@ def compute_xt( class StraightFlow(BaseFlow, abc.ABC): """Base class for flows with straight paths.""" - def compute_mu_t( + def compute_mu_t( # noqa: D102 self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray ) -> jnp.ndarray: return t * x_0 + (1 - t) * x_1 diff --git a/src/ott/neural/solvers/genot.py b/src/ott/neural/flows/genot.py similarity index 99% rename from src/ott/neural/solvers/genot.py rename to src/ott/neural/flows/genot.py index fb76ded77..fa5ada781 100644 --- a/src/ott/neural/solvers/genot.py +++ b/src/ott/neural/flows/genot.py @@ -26,17 +26,17 @@ from ott import utils from ott.geometry import costs -from ott.neural.solvers.base_solver import ( - BaseNeuralSolver, - ResampleMixin, - UnbalancednessMixin, -) -from ott.neural.solvers.flows import ( +from ott.neural.flows.flows import ( BaseFlow, BaseTimeSampler, ConstantNoiseFlow, UniformSampler, ) +from ott.neural.models.base_solver import ( + BaseNeuralSolver, + ResampleMixin, + UnbalancednessMixin, +) from ott.solvers import was_solver from ott.solvers.linear import sinkhorn from ott.solvers.quadratic import gromov_wasserstein diff --git a/src/ott/neural/flows/models.py b/src/ott/neural/flows/models.py new file mode 100644 index 000000000..4cf671a19 --- /dev/null +++ b/src/ott/neural/flows/models.py @@ -0,0 +1,188 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional + +import jax +import jax.numpy as jnp + +import flax.linen as nn +import optax +from flax.training import train_state + +from ott.neural.models import layers + +__all__ = ["VelocityField"] + + +class VelocityField(nn.Module): + """Parameterized neural vector field. + + Each of the input, condition, and time embeddings is passed through a block + consisting of ``num_layers_per_block`` layers of dimension + ``latent_embed_dim``, ``condition_embed_dim``, and ``t_embed_dim``, + respectively. + The output of each block is concatenated and passed through a final block of + dimension ``joint_hidden_dim``. + + Args: + output_dim: Dimensionality of the neural vector field. + condition_dim: Dimensionality of the conditioning vector. + latent_embed_dim: Dimensionality of the embedding of the data. + condition_embed_dim: Dimensionality of the embedding of the condition. + If ``None``, set to ``latent_embed_dim``. + t_embed_dim: Dimensionality of the time embedding. + If ``None``, set to ``latent_embed_dim``. + joint_hidden_dim: Dimensionality of the hidden layers of the joint network. + If ``None``, set to ``latent_embed_dim + condition_embed_dim + + t_embed_dim``. + num_layers_per_block: Number of layers per block. + act_fn: Activation function. + n_frequencies: Number of frequencies to use for the time embedding.
+ + """ + output_dim: int + condition_dim: int + latent_embed_dim: int + condition_embed_dim: Optional[int] = None + t_embed_dim: Optional[int] = None + joint_hidden_dim: Optional[int] = None + num_layers_per_block: int = 3 + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu + n_frequencies: int = 128 + + def time_encoder(self, t: jnp.ndarray) -> jnp.array: + """Encode the time. + + Args: + t: Time. + + Returns: + Encoded time. + """ + freq = 2 * jnp.arange(self.n_frequencies) * jnp.pi + t = freq * t + return jnp.concatenate((jnp.cos(t), jnp.sin(t)), axis=-1) + + def __post_init__(self): + + # set embedded dim from latent embedded dim + if self.condition_embed_dim is None: + self.condition_embed_dim = self.latent_embed_dim + if self.t_embed_dim is None: + self.t_embed_dim = self.latent_embed_dim + + # set joint hidden dim from all embedded dim + concat_embed_dim = ( + self.latent_embed_dim + self.condition_embed_dim + self.t_embed_dim + ) + if self.joint_hidden_dim is not None: + assert (self.joint_hidden_dim >= concat_embed_dim), ( + "joint_hidden_dim must be greater than or equal to the sum of " + "all embedded dimensions. " + ) + self.joint_hidden_dim = self.latent_embed_dim + else: + self.joint_hidden_dim = concat_embed_dim + super().__post_init__() + + @nn.compact + def __call__( + self, + t: jnp.ndarray, + x: jnp.ndarray, + condition: Optional[jnp.ndarray], + keys_model: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: + """Forward pass through the neural vector field. + + Args: + t: Time. + x: Data. + condition: Conditioning vector. + keys_model: Random number generator. + + Returns: + Output of the neural vector field. 
+ """ + t = self.time_encoder(t) + t = layers.MLPBlock( + dim=self.t_embed_dim, + out_dim=self.t_embed_dim, + num_layers=self.num_layers_per_block, + act_fn=self.act_fn, + )( + t + ) + + x = layers.MLPBlock( + dim=self.latent_embed_dim, + out_dim=self.latent_embed_dim, + num_layers=self.num_layers_per_block, + act_fn=self.act_fn + )( + x + ) + + if self.condition_dim > 0: + condition = layers.MLPBlock( + dim=self.condition_embed_dim, + out_dim=self.condition_embed_dim, + num_layers=self.num_layers_per_block, + act_fn=self.act_fn + )( + condition + ) + concatenated = jnp.concatenate((t, x, condition), axis=-1) + else: + concatenated = jnp.concatenate((t, x), axis=-1) + + out = layers.MLPBlock( + dim=self.joint_hidden_dim, + out_dim=self.joint_hidden_dim, + num_layers=self.num_layers_per_block, + act_fn=self.act_fn, + )( + concatenated + ) + + return nn.Dense( + self.output_dim, + use_bias=True, + )( + out + ) + + def create_train_state( + self, + rng: jax.Array, + optimizer: optax.OptState, + input_dim: int, + ) -> train_state.TrainState: + """Create the training state. + + Args: + rng: Random number generator. + optimizer: Optimizer. + input_dim: Dimensionality of the input. + + Returns: + Training state. 
+ """ + params = self.init( + rng, jnp.ones((1, 1)), jnp.ones((1, input_dim)), + jnp.ones((1, self.condition_dim)) + )["params"] + return train_state.TrainState.create( + apply_fn=self.apply, params=params, tx=optimizer + ) diff --git a/src/ott/neural/solvers/otfm.py b/src/ott/neural/flows/otfm.py similarity index 99% rename from src/ott/neural/solvers/otfm.py rename to src/ott/neural/flows/otfm.py index b7885c1d5..ed0114f6d 100644 --- a/src/ott/neural/solvers/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -36,12 +36,12 @@ from ott import utils from ott.geometry import costs -from ott.neural.solvers.base_solver import ( +from ott.neural.flows.flows import BaseFlow, BaseTimeSampler +from ott.neural.models.base_solver import ( BaseNeuralSolver, ResampleMixin, UnbalancednessMixin, ) -from ott.neural.solvers.flows import BaseFlow, BaseTimeSampler from ott.solvers import was_solver __all__ = ["OTFlowMatching"] diff --git a/src/ott/neural/gaps/__init__.py b/src/ott/neural/gaps/__init__.py new file mode 100644 index 000000000..0ba36da05 --- /dev/null +++ b/src/ott/neural/gaps/__init__.py @@ -0,0 +1,14 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from . 
import map_estimator, monge_gap diff --git a/src/ott/neural/solvers/map_estimator.py b/src/ott/neural/gaps/map_estimator.py similarity index 99% rename from src/ott/neural/solvers/map_estimator.py rename to src/ott/neural/gaps/map_estimator.py index fb65917c7..cfcc8cb86 100644 --- a/src/ott/neural/solvers/map_estimator.py +++ b/src/ott/neural/gaps/map_estimator.py @@ -32,7 +32,7 @@ from flax.training import train_state from ott import utils -from ott.neural.solvers import neuraldual +from ott.neural.duality import neuraldual __all__ = ["MapEstimator"] diff --git a/src/ott/neural/models/losses.py b/src/ott/neural/gaps/monge_gap.py similarity index 100% rename from src/ott/neural/models/losses.py rename to src/ott/neural/gaps/monge_gap.py diff --git a/src/ott/neural/models/__init__.py b/src/ott/neural/models/__init__.py index 1e374d236..5c2ac3b2b 100644 --- a/src/ott/neural/models/__init__.py +++ b/src/ott/neural/models/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import base_models, layers, losses, models +from . import base_solver, layers, models diff --git a/src/ott/neural/solvers/base_solver.py b/src/ott/neural/models/base_solver.py similarity index 99% rename from src/ott/neural/solvers/base_solver.py rename to src/ott/neural/models/base_solver.py index 780bf61ad..e60d25766 100644 --- a/src/ott/neural/solvers/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -27,6 +27,8 @@ from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import sinkhorn +__all__ = ["BaseNeuralSolver", "ResampleMixin", "UnbalancednessMixin"] + class BaseNeuralSolver(ABC): """Base class for neural solvers. 
diff --git a/src/ott/neural/models/layers.py b/src/ott/neural/models/layers.py index 50c2c6301..db8b24ae9 100644 --- a/src/ott/neural/models/layers.py +++ b/src/ott/neural/models/layers.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Tuple +from typing import Any, Tuple import jax -import jax.numpy as jnp import flax.linen as nn -__all__ = ["PositiveDense", "PosDefPotentials", "MLPBlock"] +__all__ = ["MLPBlock"] PRNGKey = jax.Array Shape = Tuple[int, ...] @@ -27,128 +26,23 @@ class MLPBlock(nn.Module): + """A simple MLP block.""" dim: int = 128 num_layers: int = 3 act_fn: Any = nn.silu - out_dim: int = 32 + out_dim: int = 128 @nn.compact def __call__(self, x): - for _ in range(self.num_layers): - x = nn.Dense(self.dim)(x) - x = self.act_fn(x) - return nn.Dense(self.out_dim)(x) - - -class PositiveDense(nn.Module): - """A linear transformation using a weight matrix with all entries positive. - - Args: - dim_hidden: the number of output dim_hidden. - rectifier_fn: choice of rectifier function (default: softplus function). - inv_rectifier_fn: choice of inverse rectifier function - (default: inverse softplus function). - dtype: the dtype of the computation (default: float32). - precision: numerical precision of computation see `jax.lax.Precision` - for details. - kernel_init: initializer function for the weight matrix. - bias_init: initializer function for the bias. 
- """ - dim_hidden: int - rectifier_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.softplus - inv_rectifier_fn: Callable[[jnp.ndarray], - jnp.ndarray] = lambda x: jnp.log(jnp.exp(x) - 1) - use_bias: bool = True - dtype: Any = jnp.float32 - precision: Any = None - kernel_init: Optional[Callable[[PRNGKey, Shape, Dtype], Array]] = None, - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros - - @nn.compact - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: - """Applies a linear transformation to inputs along the last dimension. + """Apply the MLP block. Args: - inputs: Array to be transformed. + x: Input data of shape (batch_size, dim) Returns: - The transformed input. + Output data of shape (batch_size, out_dim). """ - kernel_init = nn.initializers.lecun_normal( - ) if self.kernel_init is None else self.kernel_init - - inputs = jnp.asarray(inputs, self.dtype) - kernel = self.param( - "kernel", kernel_init, (inputs.shape[-1], self.dim_hidden) - ) - kernel = self.rectifier_fn(kernel) - kernel = jnp.asarray(kernel, self.dtype) - y = jax.lax.dot_general( - inputs, - kernel, (((inputs.ndim - 1,), (0,)), ((), ())), - precision=self.precision - ) - if self.use_bias: - bias = self.param("bias", self.bias_init, (self.dim_hidden,)) - bias = jnp.asarray(bias, self.dtype) - return y + bias - return y - - -class PosDefPotentials(nn.Module): - r"""A layer to output :math:`\frac{1}{2} ||A_i^T (x - b_i)||^2_i` potentials. - - Args: - use_bias: whether to add a bias to the output. - dtype: the dtype of the computation. - precision: numerical precision of computation see `jax.lax.Precision` - for details. - kernel_init: initializer function for the weight matrix. - bias_init: initializer function for the bias. 
- """ - dim_data: int - num_potentials: int - use_bias: bool = True - dtype: Any = jnp.float32 - precision: Any = None - kernel_init: Optional[Callable[[PRNGKey, Shape, Dtype], Array]] = None - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros - - @nn.compact - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: - """Apply a few quadratic forms. - - Args: - inputs: Array to be transformed (possibly batched). - - Returns: - The transformed input. - """ - kernel_init = nn.initializers.lecun_normal( - ) if self.kernel_init is None else self.kernel_init - inputs = jnp.asarray(inputs, self.dtype) - kernel = self.param( - "kernel", kernel_init, - (self.num_potentials, inputs.shape[-1], inputs.shape[-1]) - ) - - if self.use_bias: - bias = self.param( - "bias", self.bias_init, (self.num_potentials, self.dim_data) - ) - bias = jnp.asarray(bias, self.dtype) - - y = inputs.reshape((-1, inputs.shape[-1])) if inputs.ndim == 1 else inputs - y = y[..., None] - bias.T[None, ...] - y = jax.lax.dot_general( - y, kernel, (((1,), (1,)), ((2,), (0,))), precision=self.precision - ) - else: - y = jax.lax.dot_general( - inputs, - kernel, (((inputs.ndim - 1,), (0,)), ((), ())), - precision=self.precision - ) - - y = 0.5 * y * y - return jnp.sum(y.reshape((-1, self.num_potentials, self.dim_data)), axis=2) + for _ in range(self.num_layers): + x = nn.Dense(self.dim)(x) + x = self.act_fn(x) + return nn.Dense(self.out_dim)(x) diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 0fc7d4f30..78fd3d173 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -11,171 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import functools -from typing import Any, Callable, Dict, Optional, Sequence, Tuple +from typing import Callable, Optional, Sequence import jax import jax.numpy as jnp -from jax.nn import initializers import flax.linen as nn import optax -from flax.core import frozen_dict from flax.training import train_state -from ott import utils -from ott.geometry import geometry -from ott.initializers.linear import initializers as lin_init -from ott.math import matrix_square_root from ott.neural.models import layers -from ott.neural.solvers import neuraldual -from ott.problems.linear import linear_problem -__all__ = ["ICNN", "MLP", "MetaInitializer", "VelocityField", "RescalingMLP"] +__all__ = ["MLP", "RescalingMLP"] -class ICNN(neuraldual.BaseW2NeuralDual): - """Input convex neural network (ICNN) architecture with initialization. - - Implementation of input convex neural networks as introduced in - :cite:`amos:17` with initialization schemes proposed by :cite:`bunne:22`. - - Args: - dim_data: data dimensionality. - dim_hidden: sequence specifying size of hidden dimensions. The - output dimension of the last layer is 1 by default. - init_std: value of standard deviation of weight initialization method. - init_fn: choice of initialization method for weight matrices (default: - :func:`jax.nn.initializers.normal`). - act_fn: choice of activation function used in network architecture - (needs to be convex, default: :obj:`jax.nn.relu`). - pos_weights: Enforce positive weights with a projection. - If ``False``, the positive weights should be enforced with clipping - or regularization in the loss. - gaussian_map_samples: Tuple of source and target points, used to initialize - the ICNN to mimic the linear Bures map that morphs the (Gaussian - approximation) of the input measure to that of the target measure. If - ``None``, the identity initialization is used, and ICNN mimics half the - squared Euclidean norm. 
- """ - dim_data: int - dim_hidden: Sequence[int] - init_std: float = 1e-2 - init_fn: Callable = jax.nn.initializers.normal - act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu - pos_weights: bool = True - gaussian_map_samples: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None - - @property - def is_potential(self) -> bool: # noqa: D102 - return True - - def setup(self): # noqa: D102 - self.num_hidden = len(self.dim_hidden) - - if self.pos_weights: - hid_dense = layers.PositiveDense - # this function needs to be the inverse map of function - # used in PositiveDense layers - rescale = hid_dense.inv_rectifier_fn - else: - hid_dense = nn.Dense - rescale = lambda x: x - self.use_init = False - # check if Gaussian map was provided - if self.gaussian_map_samples is not None: - factor, mean = self._compute_gaussian_map_params( - self.gaussian_map_samples - ) - else: - factor, mean = self._compute_identity_map_params(self.dim_data) - - w_zs = [] - # keep track of previous size to normalize accordingly - normalization = 1 - - for i in range(1, self.num_hidden): - w_zs.append( - hid_dense( - self.dim_hidden[i], - kernel_init=initializers.constant(rescale(1.0 / normalization)), - use_bias=False, - ) - ) - normalization = self.dim_hidden[i] - # final layer computes average, still with normalized rescaling - w_zs.append( - hid_dense( - 1, - kernel_init=initializers.constant(rescale(1.0 / normalization)), - use_bias=False, - ) - ) - self.w_zs = w_zs - - # positive definite potential (the identity mapping or linear OT) - self.pos_def_potential = layers.PosDefPotentials( - self.dim_data, - num_potentials=1, - kernel_init=lambda *_: factor, - bias_init=lambda *_: mean, - use_bias=True, - ) - - # subsequent layers re-injected into convex functions - w_xs = [] - for i in range(self.num_hidden): - w_xs.append( - nn.Dense( - self.dim_hidden[i], - kernel_init=self.init_fn(self.init_std), - bias_init=initializers.constant(0.), - use_bias=True, - ) - ) - # final layer, to output number 
- w_xs.append( - nn.Dense( - 1, - kernel_init=self.init_fn(self.init_std), - bias_init=initializers.constant(0.), - use_bias=True, - ) - ) - self.w_xs = w_xs - - @staticmethod - def _compute_gaussian_map_params( - samples: Tuple[jnp.ndarray, jnp.ndarray] - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - from ott.tools.gaussian_mixture import gaussian - source, target = samples - g_s = gaussian.Gaussian.from_samples(source) - g_t = gaussian.Gaussian.from_samples(target) - lin_op = g_s.scale.gaussian_map(g_t.scale) - b = jnp.squeeze(g_t.loc) - jnp.linalg.solve(lin_op, jnp.squeeze(g_t.loc)) - lin_op = matrix_square_root.sqrtm_only(lin_op) - return jnp.expand_dims(lin_op, 0), jnp.expand_dims(b, 0) - - @staticmethod - def _compute_identity_map_params( - input_dim: int - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - A = jnp.eye(input_dim).reshape((1, input_dim, input_dim)) - b = jnp.zeros((1, input_dim)) - return A, b - - @nn.compact - def __call__(self, x: jnp.ndarray) -> float: # noqa: D102 - z = self.act_fn(self.w_xs[0](x)) - for i in range(self.num_hidden): - z = jnp.add(self.w_zs[i](z), self.w_xs[i + 1](x)) - z = self.act_fn(z) - z += self.pos_def_potential(x) - return z.squeeze() - - -class MLP(neuraldual.BaseW2NeuralDual): +class MLP(nn.Module): """A generic, not-convex MLP. Args: @@ -217,356 +67,6 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # noqa: D102 return z.squeeze(0) if squeeze else z -@jax.tree_util.register_pytree_node_class -class MetaInitializer(lin_init.DefaultInitializer): - """Meta OT Initializer with a fixed geometry :cite:`amos:22`. - - This initializer consists of a predictive model that outputs the - :math:`f` duals to solve the entropy-regularized OT problem given - input probability weights ``a`` and ``b``, and a given (assumed to be - fixed) geometry ``geom``. 
- - The model's parameters are learned using a training set of OT - instances (multiple pairs of probability weights), that assume the - **same** geometry ``geom`` is used throughout, both for training and - evaluation. - - Args: - geom: The fixed geometry of the problem instances. - meta_model: The model to predict the potential :math:`f` from the measures. - TODO(marcocuturi): add explanation here what arguments to expect. - opt: The optimizer to update the parameters. If ``None``, use - :func:`optax.adam` with :math:`0.001` learning rate. - rng: The PRNG key to use for initializing the model. - state: The training state of the model to start from. - - Examples: - The following code shows a simple - example of using ``update`` to train the model, where - ``a`` and ``b`` are the weights of the measures and - ``geom`` is the fixed geometry. - - .. code-block:: python - - meta_initializer = init_lib.MetaInitializer(geom) - while training(): - a, b = sample_batch() - loss, init_f, meta_initializer.state = meta_initializer.update( - meta_initializer.state, a=a, b=b - ) - """ - - def __init__( - self, - geom: geometry.Geometry, - meta_model: nn.Module, - opt: Optional[optax.GradientTransformation - ] = optax.adam(learning_rate=1e-3), # noqa: B008 - rng: Optional[jax.Array] = None, - state: Optional[train_state.TrainState] = None - ): - self.geom = geom - self.dtype = geom.x.dtype - self.opt = opt - self.rng = utils.default_prng_key(rng) - - na, nb = geom.shape - # TODO(michalk8): add again some default MLP - self.meta_model = meta_model - - if state is None: - # Initialize the model's training state. 
- a_placeholder = jnp.zeros(na, dtype=self.dtype) - b_placeholder = jnp.zeros(nb, dtype=self.dtype) - params = self.meta_model.init(self.rng, a_placeholder, - b_placeholder)["params"] - self.state = train_state.TrainState.create( - apply_fn=self.meta_model.apply, params=params, tx=opt - ) - else: - self.state = state - - self.update_impl = self._get_update_fn() - - def update( - self, state: train_state.TrainState, a: jnp.ndarray, b: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray, train_state.TrainState]: - r"""Update the meta model with the dual objective. - - The goal is for the model to match the optimal duals, i.e., - :math:`\hat f_\theta \approx f^\star`. - This can be done by training the predictions of :math:`\hat f_\theta` - to optimize the dual objective, which :math:`f^\star` also optimizes for. - The overall learning setup can thus be written as: - - .. math:: - \min_\theta\; {\mathbb E}_{(\alpha,\beta)\sim{\mathcal{D}}}\; - J(\hat f_\theta(a, b); \alpha, \beta), - - where :math:`a,b` are the probabilities of the measures :math:`\alpha,\beta` - ,:math:`\mathcal{D}` is a meta distribution of optimal transport problems, - - .. math:: - -J(f; \alpha, \beta, c) := \langle f, a\rangle + \langle g, b \rangle - - \varepsilon\left\langle \exp\{f/\varepsilon\}, K\exp\{g/\varepsilon\} - \right\rangle - - is the entropic dual objective, - and :math:`K_{i,j} := -C_{i,j}/\varepsilon` is the *Gibbs kernel*. - - Args: - state: Optimizer state of the meta model. - a: Probabilities of the :math:`\alpha` measure's atoms. - b: Probabilities of the :math:`\beta` measure's atoms. - - Returns: - The training loss, :math:`f`, and updated state. - """ - return self.update_impl(state, a, b) - - def init_dual_a( # noqa: D102 - self, - ot_prob: "linear_problem.LinearProblem", - lse_mode: bool, - rng: Optional[jax.Array] = None, - ) -> jnp.ndarray: - del rng - # Detect if the problem is batched. 
- assert ot_prob.a.ndim in (1, 2) - assert ot_prob.b.ndim in (1, 2) - vmap_a_val = 0 if ot_prob.a.ndim == 2 else None - vmap_b_val = 0 if ot_prob.b.ndim == 2 else None - - if vmap_a_val is not None or vmap_b_val is not None: - compute_f_maybe_batch = jax.vmap( - self._compute_f, in_axes=(vmap_a_val, vmap_b_val, None) - ) - else: - compute_f_maybe_batch = self._compute_f - - init_f = compute_f_maybe_batch(ot_prob.a, ot_prob.b, self.state.params) - return init_f if lse_mode else ot_prob.geom.scaling_from_potential(init_f) - - def _get_update_fn(self): - """Return the implementation (and jitted) update function.""" - from ott.problems.linear import linear_problem - from ott.solvers.linear import sinkhorn - - def dual_obj_loss_single(params, a, b): - f_pred = self._compute_f(a, b, params) - g_pred = self.geom.update_potential( - f_pred, jnp.zeros_like(b), jnp.log(b), 0, axis=0 - ) - g_pred = jnp.where(jnp.isfinite(g_pred), g_pred, 0.) - - ot_prob = linear_problem.LinearProblem(geom=self.geom, a=a, b=b) - dual_obj = sinkhorn.compute_kl_reg_cost( - f_pred, g_pred, ot_prob, lse_mode=True - ) - loss = -dual_obj - return loss, f_pred - - def loss_batch(params, a, b): - loss_fn = functools.partial(dual_obj_loss_single, params=params) - loss, f_pred = jax.vmap(loss_fn)(a=a, b=b) - return jnp.mean(loss), f_pred - - @jax.jit - def update(state, a, b): - a = jnp.atleast_2d(a) - b = jnp.atleast_2d(b) - grad_fn = jax.value_and_grad(loss_batch, has_aux=True) - (loss, init_f), grads = grad_fn(state.params, a, b) - return loss, init_f, state.apply_gradients(grads=grads) - - return update - - def _compute_f( - self, a: jnp.ndarray, b: jnp.ndarray, - params: frozen_dict.FrozenDict[str, jnp.ndarray] - ) -> jnp.ndarray: - r"""Predict the optimal :math:`f` potential. - - Args: - a: Probabilities of the :math:`\alpha` measure's atoms. - b: Probabilities of the :math:`\beta` measure's atoms. - params: The parameters of the Meta model. - - Returns: - The :math:`f` potential. 
- """ - return self.meta_model.apply({"params": params}, a, b) - - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 - return [self.geom, self.meta_model, self.opt], { - "rng": self.rng, - "state": self.state - } - - -class VelocityField(nn.Module): - """Parameterized neural vector field. - - Each of the input, condition, and time embeddings are passed through a block - consisting of ``num_layers_per_block`` layers of dimension - ``latent_embed_dim``, ``condition_embed_dim``, and ``time_embed_dim``, - respectively. - The output of each block is concatenated and passed through a final block of - dimension ``joint_hidden_dim``. - - Args: - output_dim: Dimensionality of the neural vector field. - condition_dim: Dimensionality of the conditioning vector. - latent_embed_dim: Dimensionality of the embedding of the data. - condition_embed_dim: Dimensionality of the embedding of the condition. - If ``None``, set to ``latent_embed_dim``. - t_embed_dim: Dimensionality of the time embedding. - If ``None``, set to ``latent_embed_dim``. - joint_hidden_dim: Dimensionality of the hidden layers of the joint network. - If ``None``, set to ``latent_embed_dim + condition_embed_dim + - t_embed_dim``. - num_layers_per_block: Number of layers per block. - act_fn: Activation function. - n_frequencies: Number of frequencies to use for the time embedding. - - """ - output_dim: int - condition_dim: int - latent_embed_dim: int - condition_embed_dim: Optional[int] = None - t_embed_dim: Optional[int] = None - joint_hidden_dim: Optional[int] = None - num_layers_per_block: int = 3 - act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu - n_frequencies: int = 128 - - def time_encoder(self, t: jnp.ndarray) -> jnp.array: - """Encode the time. - - Args: - t: Time. - - Returns: - Encoded time. 
- """ - freq = 2 * jnp.arange(self.n_frequencies) * jnp.pi - t = freq * t - return jnp.concatenate((jnp.cos(t), jnp.sin(t)), axis=-1) - - def __post_init__(self): - - # set embedded dim from latent embedded dim - if self.condition_embed_dim is None: - self.condition_embed_dim = self.latent_embed_dim - if self.t_embed_dim is None: - self.t_embed_dim = self.latent_embed_dim - - # set joint hidden dim from all embedded dim - concat_embed_dim = ( - self.latent_embed_dim + self.condition_embed_dim + self.t_embed_dim - ) - if self.joint_hidden_dim is not None: - assert (self.joint_hidden_dim >= concat_embed_dim), ( - "joint_hidden_dim must be greater than or equal to the sum of " - "all embedded dimensions. " - ) - self.joint_hidden_dim = self.latent_embed_dim - else: - self.joint_hidden_dim = concat_embed_dim - super().__post_init__() - - @nn.compact - def __call__( - self, - t: jnp.ndarray, - x: jnp.ndarray, - condition: Optional[jnp.ndarray], - keys_model: Optional[jnp.ndarray] = None, - ) -> jnp.ndarray: - """Forward pass through the neural vector field. - - Args: - t: Time. - x: Data. - condition: Conditioning vector. - keys_model: Random number generator. - - Returns: - Output of the neural vector field. 
- """ - t = self.time_encoder(t) - t = layers.MLPBlock( - dim=self.t_embed_dim, - out_dim=self.t_embed_dim, - num_layers=self.num_layers_per_block, - act_fn=self.act_fn, - )( - t - ) - - x = layers.MLPBlock( - dim=self.latent_embed_dim, - out_dim=self.latent_embed_dim, - num_layers=self.num_layers_per_block, - act_fn=self.act_fn - )( - x - ) - - if self.condition_dim > 0: - condition = layers.MLPBlock( - dim=self.condition_embed_dim, - out_dim=self.condition_embed_dim, - num_layers=self.num_layers_per_block, - act_fn=self.act_fn - )( - condition - ) - concatenated = jnp.concatenate((t, x, condition), axis=-1) - else: - concatenated = jnp.concatenate((t, x), axis=-1) - - out = layers.MLPBlock( - dim=self.joint_hidden_dim, - out_dim=self.joint_hidden_dim, - num_layers=self.num_layers_per_block, - act_fn=self.act_fn, - )( - concatenated - ) - - return nn.Dense( - self.output_dim, - use_bias=True, - )( - out - ) - - def create_train_state( - self, - rng: jax.Array, - optimizer: optax.OptState, - input_dim: int, - ) -> train_state.TrainState: - """Create the training state. - - Args: - rng: Random number generator. - optimizer: Optimizer. - input_dim: Dimensionality of the input. - - Returns: - Training state. - """ - params = self.init( - rng, jnp.ones((1, 1)), jnp.ones((1, input_dim)), - jnp.ones((1, self.condition_dim)) - )["params"] - return train_state.TrainState.create( - apply_fn=self.apply, params=params, tx=optimizer - ) - - class RescalingMLP(nn.Module): """Network to learn distributional rescaling factors based on a MLP. 
diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 92a929154..a962afca3 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -20,9 +20,10 @@ import optax from ott.geometry import costs -from ott.neural.models.models import RescalingMLP, VelocityField -from ott.neural.solvers.flows import OffsetUniformSampler, UniformSampler -from ott.neural.solvers.genot import GENOT +from ott.neural.flows.flows import OffsetUniformSampler, UniformSampler +from ott.neural.flows.genot import GENOT +from ott.neural.flows.models import VelocityField +from ott.neural.models.models import RescalingMLP from ott.solvers.linear import sinkhorn from ott.solvers.quadratic import gromov_wasserstein diff --git a/tests/neural/losses_test.py b/tests/neural/losses_test.py index f18681c7a..d6c9334cd 100644 --- a/tests/neural/losses_test.py +++ b/tests/neural/losses_test.py @@ -18,7 +18,8 @@ import numpy as np from ott.geometry import costs -from ott.neural.models import losses, models +from ott.neural import models +from ott.neural.gaps import monge_gap @pytest.mark.fast() @@ -39,13 +40,13 @@ def test_monge_gap_non_negativity( target = model.apply(params, reference_points) # compute the Monge gap based on samples - monge_gap_from_samples_value = losses.monge_gap_from_samples( + monge_gap_from_samples_value = monge_gap.monge_gap_from_samples( source=reference_points, target=target ) np.testing.assert_array_equal(monge_gap_from_samples_value >= 0, True) # Compute the Monge gap using model directly - monge_gap_value = losses.monge_gap( + monge_gap_value = monge_gap.monge_gap( map_fn=lambda x: model.apply(params, x), reference_points=reference_points ) @@ -60,10 +61,10 @@ def test_monge_gap_jit(self, rng: jax.Array): source = jax.random.normal(rng1, (n_samples, n_features)) target = jax.random.normal(rng2, (n_samples, n_features)) # define jitted monge gap - jit_monge_gap = jax.jit(losses.monge_gap_from_samples) + jit_monge_gap = 
jax.jit(monge_gap.monge_gap_from_samples) # compute the Monge gaps for different costs - monge_gap_value = losses.monge_gap_from_samples( + monge_gap_value = monge_gap.monge_gap_from_samples( source=source, target=target ) jit_monge_gap_value = jit_monge_gap(source, target) @@ -101,10 +102,10 @@ def test_monge_gap_from_samples_different_cost( target = jax.random.normal(rng2, (n_samples, n_features)) * .1 + 3. # compute the Monge gaps for the euclidean cost - monge_gap_from_samples_value_eucl = losses.monge_gap_from_samples( + monge_gap_from_samples_value_eucl = monge_gap.monge_gap_from_samples( source=source, target=target, cost_fn=costs.Euclidean() ) - monge_gap_from_samples_value_cost_fn = losses.monge_gap_from_samples( + monge_gap_from_samples_value_cost_fn = monge_gap.monge_gap_from_samples( source=source, target=target, cost_fn=cost_fn ) diff --git a/tests/neural/map_estimator_test.py b/tests/neural/map_estimator_test.py index b5df51170..f19c63b32 100644 --- a/tests/neural/map_estimator_test.py +++ b/tests/neural/map_estimator_test.py @@ -19,8 +19,7 @@ from ott import datasets from ott.geometry import pointcloud -from ott.neural.models import losses, models -from ott.neural.solvers import map_estimator +from ott.neural.gaps import map_estimator, monge_gap from ott.tools import sinkhorn_divergence @@ -47,11 +46,11 @@ def fitting_loss( return (div, None) def regularizer(x, y): - gap, out = losses.monge_gap_from_samples(x, y, return_output=True) + gap, out = monge_gap.monge_gap_from_samples(x, y, return_output=True) return gap, out.n_iters # define the model - model = models.MLP(dim_hidden=[16, 8], is_potential=False) + model = monge_gap.MLP(dim_hidden=[16, 8], is_potential=False) # generate data train_dataset, valid_dataset, dim_data = ( diff --git a/tests/neural/neuraldual_test.py b/tests/neural/neuraldual_test.py index 8a362affa..fc107ec75 100644 --- a/tests/neural/neuraldual_test.py +++ b/tests/neural/neuraldual_test.py @@ -19,8 +19,8 @@ import numpy as np 
from ott import datasets +from ott.neural.duality import conjugate, neuraldual from ott.neural.models import models -from ott.neural.solvers import conjugate, neuraldual ModelPair_t = Tuple[neuraldual.BaseW2NeuralDual, neuraldual.BaseW2NeuralDual] DatasetPair_t = Tuple[datasets.Dataset, datasets.Dataset] diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index b38fceb74..0af948705 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -19,15 +19,16 @@ import optax -from ott.neural.models.models import RescalingMLP, VelocityField -from ott.neural.solvers.flows import ( +from ott.neural.flows.flows import ( BaseFlow, BrownianNoiseFlow, ConstantNoiseFlow, OffsetUniformSampler, UniformSampler, ) -from ott.neural.solvers.otfm import OTFlowMatching +from ott.neural.flows.models import VelocityField +from ott.neural.flows.otfm import OTFlowMatching +from ott.neural.models.models import RescalingMLP from ott.solvers.linear import sinkhorn From 8f404f8284a40323a169ceee96248325e8db7361 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Mon, 4 Dec 2023 11:31:38 +0100 Subject: [PATCH 054/186] fix import errors --- src/ott/neural/duality/models.py | 47 +++++++++++++++++++++++++-- src/ott/neural/duality/neuraldual.py | 5 ++- src/ott/neural/models/models.py | 21 ++++++++++++ tests/neural/icnn_test.py | 2 +- tests/neural/losses_test.py | 2 +- tests/neural/map_estimator_test.py | 3 +- tests/neural/meta_initializer_test.py | 2 +- tests/neural/neuraldual_test.py | 10 +++--- 8 files changed, 77 insertions(+), 15 deletions(-) diff --git a/src/ott/neural/duality/models.py b/src/ott/neural/duality/models.py index 2b51c60cf..d10e09f55 100644 --- a/src/ott/neural/duality/models.py +++ b/src/ott/neural/duality/models.py @@ -27,11 +27,10 @@ from ott.geometry import geometry from ott.initializers.linear import initializers as lin_init from ott.math import matrix_square_root -from ott.neural.duality import neuraldual -from ott.neural.models import layers 
+from ott.neural.duality import layers, neuraldual from ott.problems.linear import linear_problem -__all__ = ["ICNN", "MetaInitializer"] +__all__ = ["ICNN", "PotentialMLP", "MetaInitializer"] class ICNN(neuraldual.BaseW2NeuralDual): @@ -175,6 +174,48 @@ def __call__(self, x: jnp.ndarray) -> float: # noqa: D102 return z.squeeze() +class PotentialMLP(neuraldual.BaseW2NeuralDual): + """A generic, not-convex MLP. + + Args: + dim_hidden: sequence specifying size of hidden dimensions. The output + dimension of the last layer is automatically set to 1 if + :attr:`is_potential` is ``True``, or the dimension of the input otherwise + is_potential: Model the potential if ``True``, otherwise + model the gradient of the potential + act_fn: Activation function + """ + + dim_hidden: Sequence[int] + is_potential: bool = True + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.leaky_relu + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # noqa: D102 + squeeze = x.ndim == 1 + if squeeze: + x = jnp.expand_dims(x, 0) + assert x.ndim == 2, x.ndim + n_input = x.shape[-1] + + z = x + for n_hidden in self.dim_hidden: + Wx = nn.Dense(n_hidden, use_bias=True) + z = self.act_fn(Wx(z)) + + if self.is_potential: + Wx = nn.Dense(1, use_bias=True) + z = Wx(z).squeeze(-1) + + quad_term = 0.5 * jax.vmap(jnp.dot)(x, x) + z += quad_term + else: + Wx = nn.Dense(n_input, use_bias=True) + z = x + Wx(z) + + return z.squeeze(0) if squeeze else z + + @jax.tree_util.register_pytree_node_class class MetaInitializer(lin_init.DefaultInitializer): """Meta OT Initializer with a fixed geometry :cite:`amos:22`. 
diff --git a/src/ott/neural/duality/neuraldual.py b/src/ott/neural/duality/neuraldual.py index 1d1aaa85b..a8f5fd273 100644 --- a/src/ott/neural/duality/neuraldual.py +++ b/src/ott/neural/duality/neuraldual.py @@ -36,8 +36,7 @@ from ott import utils from ott.geometry import costs -from ott.neural.duality import conjugate -from ott.neural.models import models +from ott.neural.duality import conjugate, models from ott.problems.linear import potentials __all__ = ["W2NeuralTrainState", "BaseW2NeuralDual", "W2NeuralDual"] @@ -326,7 +325,7 @@ def setup( # default to using back_and_forth with the non-convex models if self.back_and_forth is None: - self.back_and_forth = isinstance(neural_f, models.MLP) + self.back_and_forth = isinstance(neural_f, models.PotentialMLP) if self.num_inner_iters == 1 and self.parallel_updates: self.train_step_parallel = self.get_step_fn( diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 78fd3d173..df4a0e14e 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -66,6 +66,27 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # noqa: D102 return z.squeeze(0) if squeeze else z + def create_train_state( + self, + rng: jax.Array, + optimizer: optax.OptState, + input_dim: int, + ) -> train_state.TrainState: + """Create the training state. + + Args: + rng: Random number generator. + optimizer: Optimizer. + input_dim: Dimensionality of the input. + + Returns: + Training state. + """ + params = self.init(rng, jnp.ones(input_dim))["params"] + return train_state.TrainState.create( + apply_fn=self.apply, params=params, tx=optimizer + ) + class RescalingMLP(nn.Module): """Network to learn distributional rescaling factors based on a MLP. 
diff --git a/tests/neural/icnn_test.py b/tests/neural/icnn_test.py index dba2f7b7c..541ecff38 100644 --- a/tests/neural/icnn_test.py +++ b/tests/neural/icnn_test.py @@ -17,7 +17,7 @@ import jax.numpy as jnp import numpy as np -from ott.neural.models import models +from ott.neural.duality import models @pytest.mark.fast() diff --git a/tests/neural/losses_test.py b/tests/neural/losses_test.py index d6c9334cd..6379b9dfa 100644 --- a/tests/neural/losses_test.py +++ b/tests/neural/losses_test.py @@ -18,8 +18,8 @@ import numpy as np from ott.geometry import costs -from ott.neural import models from ott.neural.gaps import monge_gap +from ott.neural.models import models @pytest.mark.fast() diff --git a/tests/neural/map_estimator_test.py b/tests/neural/map_estimator_test.py index f19c63b32..508143465 100644 --- a/tests/neural/map_estimator_test.py +++ b/tests/neural/map_estimator_test.py @@ -20,6 +20,7 @@ from ott import datasets from ott.geometry import pointcloud from ott.neural.gaps import map_estimator, monge_gap +from ott.neural.models import models from ott.tools import sinkhorn_divergence @@ -50,7 +51,7 @@ def regularizer(x, y): return gap, out.n_iters # define the model - model = monge_gap.MLP(dim_hidden=[16, 8], is_potential=False) + model = models.MLP(dim_hidden=[16, 8], is_potential=False) # generate data train_dataset, valid_dataset, dim_data = ( diff --git a/tests/neural/meta_initializer_test.py b/tests/neural/meta_initializer_test.py index e84554940..a083d6560 100644 --- a/tests/neural/meta_initializer_test.py +++ b/tests/neural/meta_initializer_test.py @@ -22,7 +22,7 @@ from ott.geometry import pointcloud from ott.initializers.linear import initializers as linear_init -from ott.neural.models import models as nn_init +from ott.neural.duality import models as nn_init from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn diff --git a/tests/neural/neuraldual_test.py b/tests/neural/neuraldual_test.py index fc107ec75..5aef77aba 
100644 --- a/tests/neural/neuraldual_test.py +++ b/tests/neural/neuraldual_test.py @@ -19,8 +19,7 @@ import numpy as np from ott import datasets -from ott.neural.duality import conjugate, neuraldual -from ott.neural.models import models +from ott.neural.duality import conjugate, models, neuraldual ModelPair_t = Tuple[neuraldual.BaseW2NeuralDual, neuraldual.BaseW2NeuralDual] DatasetPair_t = Tuple[datasets.Dataset, datasets.Dataset] @@ -42,11 +41,12 @@ def neural_models(request: str) -> ModelPair_t: dim_hidden=[32]), models.ICNN(dim_data=2, dim_hidden=[32]) ) if request.param == "mlps": - return models.MLP(dim_hidden=[32]), models.MLP(dim_hidden=[32]), + return models.PotentialMLP(dim_hidden=[32] + ), models.PotentialMLP(dim_hidden=[32]), if request.param == "mlps-grad": return ( - models.MLP(dim_hidden=[32]), - models.MLP(is_potential=False, dim_hidden=[128]) + models.PotentialMLP(dim_hidden=[32]), + models.PotentialMLP(is_potential=False, dim_hidden=[128]) ) raise ValueError(f"Invalid request: {request.param}") From 0b81135c0f9d25b2d432b4d11ebd10dc523d4429 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Dec 2023 14:03:54 +0100 Subject: [PATCH 055/186] incorporate feedback partially --- src/ott/neural/data/dataloaders.py | 1 - src/ott/neural/models/base_solver.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py index 9c09ce08c..68da7de6e 100644 --- a/src/ott/neural/data/dataloaders.py +++ b/src/ott/neural/data/dataloaders.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- from typing import Dict, Iterator, Mapping, Optional import numpy as np diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index e60d25766..8ee71a9c6 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -153,6 +153,7 @@ def _get_sinkhorn_match_fn( filter_input: bool = False, ) -> Callable: + @jax.jit def match_pairs( x: jnp.ndarray, y: jnp.ndarray ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: From fccdeef659660081811973f48c3b9f4f56f03eaa Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Dec 2023 14:21:19 +0100 Subject: [PATCH 056/186] make time encoder a layer --- src/ott/neural/flows/__init__.py | 2 +- src/ott/neural/flows/layers.py | 47 ++++++++++++++++++++++++++++++++ src/ott/neural/flows/models.py | 16 ++--------- src/ott/neural/models/layers.py | 5 ++-- 4 files changed, 53 insertions(+), 17 deletions(-) create mode 100644 src/ott/neural/flows/layers.py diff --git a/src/ott/neural/flows/__init__.py b/src/ott/neural/flows/__init__.py index 695cbbe3c..af3ceb125 100644 --- a/src/ott/neural/flows/__init__.py +++ b/src/ott/neural/flows/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import flows, genot, models, otfm +from . import flows, genot, layers, models, otfm diff --git a/src/ott/neural/flows/layers.py b/src/ott/neural/flows/layers.py new file mode 100644 index 000000000..84a526b1f --- /dev/null +++ b/src/ott/neural/flows/layers.py @@ -0,0 +1,47 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc + +import jax.numpy as jnp + +import flax.linen as nn + +__all__ = ["TimeEncoder", "CyclicalTimeEncoder"] + + +class TimeEncoder(nn.Module, abc.ABC): + """A time encoder.""" + + @abc.abstractmethod + def __call__(self, t: jnp.ndarray) -> jnp.ndarray: + """Encode the time. + + Args: + t: Input time of shape (batch_size, 1). + + Returns: + The encoded time. + """ + pass + + +class CyclicalTimeEncoder(nn.Module): + """A cyclical time encoder.""" + n_frequencies: int = 128 + + @nn.compact + def __call__(self, t: jnp.ndarray) -> jnp.ndarray: # noqa: D102 + freq = 2 * jnp.arange(self.n_frequencies) * jnp.pi + t = freq * t + return jnp.concatenate((jnp.cos(t), jnp.sin(t)), axis=-1) diff --git a/src/ott/neural/flows/models.py b/src/ott/neural/flows/models.py index 4cf671a19..be73ac09d 100644 --- a/src/ott/neural/flows/models.py +++ b/src/ott/neural/flows/models.py @@ -20,6 +20,7 @@ import optax from flax.training import train_state +import ott.neural.flows.layers as flow_layers from ott.neural.models import layers __all__ = ["VelocityField"] @@ -61,19 +62,6 @@ class VelocityField(nn.Module): act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu n_frequencies: int = 128 - def time_encoder(self, t: jnp.ndarray) -> jnp.array: - """Encode the time. - - Args: - t: Time. - - Returns: - Encoded time. 
- """ - freq = 2 * jnp.arange(self.n_frequencies) * jnp.pi - t = freq * t - return jnp.concatenate((jnp.cos(t), jnp.sin(t)), axis=-1) - def __post_init__(self): # set embedded dim from latent embedded dim @@ -115,7 +103,7 @@ def __call__( Returns: Output of the neural vector field. """ - t = self.time_encoder(t) + t = flow_layers.CyclicalTimeEncoder(n_frequencies=self.n_frequencies)(t) t = layers.MLPBlock( dim=self.t_embed_dim, out_dim=self.t_embed_dim, diff --git a/src/ott/neural/models/layers.py b/src/ott/neural/models/layers.py index db8b24ae9..952cc9d24 100644 --- a/src/ott/neural/models/layers.py +++ b/src/ott/neural/models/layers.py @@ -14,6 +14,7 @@ from typing import Any, Tuple import jax +import jax.numpy as jnp import flax.linen as nn @@ -26,14 +27,14 @@ class MLPBlock(nn.Module): - """A simple MLP block.""" + """An MLP block.""" dim: int = 128 num_layers: int = 3 act_fn: Any = nn.silu out_dim: int = 128 @nn.compact - def __call__(self, x): + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: """Apply the MLP block. 
Args: From 2a279c1d76594ee3937a70fae932e4096e738287 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Dec 2023 14:50:30 +0100 Subject: [PATCH 057/186] make conditions Optional and minor feedback --- src/ott/neural/duality/models.py | 4 ++-- src/ott/neural/flows/flows.py | 36 ++++++++++++++-------------- src/ott/neural/flows/genot.py | 10 ++++---- src/ott/neural/flows/models.py | 32 ++++++++++++------------- src/ott/neural/flows/otfm.py | 9 +++---- src/ott/neural/models/base_solver.py | 8 +++++-- src/ott/neural/models/models.py | 11 +++++---- tests/neural/conftest.py | 13 ++++++++++ 8 files changed, 71 insertions(+), 52 deletions(-) diff --git a/src/ott/neural/duality/models.py b/src/ott/neural/duality/models.py index d10e09f55..baa0386c8 100644 --- a/src/ott/neural/duality/models.py +++ b/src/ott/neural/duality/models.py @@ -54,7 +54,7 @@ class ICNN(neuraldual.BaseW2NeuralDual): gaussian_map_samples: Tuple of source and target points, used to initialize the ICNN to mimic the linear Bures map that morphs the (Gaussian approximation) of the input measure to that of the target measure. If - ``None``, the identity initialization is used, and ICNN mimics half the + :obj:`None`, the identity initialization is used, and ICNN mimics half the squared Euclidean norm. """ dim_data: int @@ -234,7 +234,7 @@ class MetaInitializer(lin_init.DefaultInitializer): geom: The fixed geometry of the problem instances. meta_model: The model to predict the potential :math:`f` from the measures. TODO(marcocuturi): add explanation here what arguments to expect. - opt: The optimizer to update the parameters. If ``None``, use + opt: The optimizer to update the parameters. If :obj:`None`, use :func:`optax.adam` with :math:`0.001` learning rate. rng: The PRNG key to use for initializing the model. state: The training state of the model to start from. 
diff --git a/src/ott/neural/flows/flows.py b/src/ott/neural/flows/flows.py index 93f471b9d..83b23eb42 100644 --- a/src/ott/neural/flows/flows.py +++ b/src/ott/neural/flows/flows.py @@ -34,7 +34,7 @@ def __init__(self, sigma: float): @abc.abstractmethod def compute_mu_t( - self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray + self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray ) -> jnp.ndarray: """Compute the mean of the probablitiy path. @@ -43,12 +43,12 @@ def compute_mu_t( Args: t: Time :math:`t`. - x_0: Sample from the source distribution. - x_1: Sample from the target distribution. + src: Sample from the source distribution. + tgt: Sample from the target distribution. """ @abc.abstractmethod - def compute_sigma_t(self, t: jnp.ndarray): + def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: """Compute the standard deviation of the probablity path at time :math:`t`. Args: @@ -57,7 +57,7 @@ def compute_sigma_t(self, t: jnp.ndarray): @abc.abstractmethod def compute_ut( - self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray + self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray ) -> jnp.ndarray: """Evaluate the conditional vector field. @@ -66,13 +66,13 @@ def compute_ut( Args: t: Time :math:`t`. - x_0: Sample from the source distribution. - x_1: Sample from the target distribution. + src: Sample from the source distribution. + tgt: Sample from the target distribution. """ def compute_xt( - self, noise: jnp.ndarray, t: jnp.ndarray, x_0: jnp.ndarray, - x_1: jnp.ndarray + self, noise: jnp.ndarray, t: jnp.ndarray, src: jnp.ndarray, + tgt: jnp.ndarray ) -> jnp.ndarray: """Sample from the probability path. @@ -82,14 +82,14 @@ def compute_xt( Args: noise: Noise sampled from a standard normal distribution. t: Time :math:`t`. - x_0: Sample from the source distribution. - x_1: Sample from the target distribution. + src: Sample from the source distribution. + tgt: Sample from the target distribution. 
Returns: Samples from the probability path between :math:`x_0` and :math:`x_1` at time :math:`t`. """ - mu_t = self.compute_mu_t(t, x_0, x_1) + mu_t = self.compute_mu_t(t, src, tgt) sigma_t = self.compute_sigma_t(t) return mu_t + sigma_t * noise @@ -103,7 +103,7 @@ def compute_mu_t( # noqa: D102 return t * x_0 + (1 - t) * x_1 def compute_ut( - self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray + self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray ) -> jnp.ndarray: """Evaluate the conditional vector field. @@ -112,19 +112,19 @@ def compute_ut( Args: t: Time :math:`t`. - x_0: Sample from the source distribution. - x_1: Sample from the target distribution. + src: Sample from the source distribution. + tgt: Sample from the target distribution. Returns: Conditional vector field evaluated at time :math:`t`. """ - return x_1 - x_0 + return tgt - src class ConstantNoiseFlow(StraightFlow): r"""Flow with straight paths and constant flow noise :math:`\sigma`.""" - def compute_sigma_t(self, t: jnp.ndarray): + def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: r"""Compute noise of the flow at time :math:`t`. Args: @@ -144,7 +144,7 @@ class BrownianNoiseFlow(StraightFlow): :math:`\sigma_t = \sigma * \sqrt(t * (1-t))`. """ - def compute_sigma_t(self, t: jnp.ndarray): + def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: """Compute the standard deviation of the probablity path at time :math:`t`. Args: diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index fa5ada781..4384f8e7f 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -88,10 +88,10 @@ class GENOT(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): on the first marginal. tau_b: If :math:`< 1`, defines how much unbalanced the problem is on the second marginal. - mlp_eta: Neural network to learn the left rescaling function. If `None`, - the left rescaling factor is not learnt. - mlp_xi: Neural network to learn the right rescaling function. 
If `None`, - the right rescaling factor is not learnt. + mlp_eta: Neural network to learn the left rescaling function. If + :obj:`None`, the left rescaling factor is not learnt. + mlp_xi: Neural network to learn the right rescaling function. If + :obj:`None`, the right rescaling factor is not learnt. unbalanced_kwargs: Keyword arguments for the unbalancedness solver. callback_fn: Callback function. rng: Random number generator. @@ -379,7 +379,7 @@ def loss_fn( def transport( self, source: jnp.ndarray, - condition: Optional[jnp.ndarray], + condition: Optional[jnp.ndarray] = None, rng: Optional[jax.Array] = None, forward: bool = True, diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}), diff --git a/src/ott/neural/flows/models.py b/src/ott/neural/flows/models.py index be73ac09d..9a5ce13af 100644 --- a/src/ott/neural/flows/models.py +++ b/src/ott/neural/flows/models.py @@ -38,14 +38,14 @@ class VelocityField(nn.Module): Args: output_dim: Dimensionality of the neural vector field. - condition_dim: Dimensionality of the conditioning vector. latent_embed_dim: Dimensionality of the embedding of the data. + condition_dim: Dimensionality of the conditioning vector. condition_embed_dim: Dimensionality of the embedding of the condition. - If ``None``, set to ``latent_embed_dim``. + If :obj:`None`, set to ``latent_embed_dim``. t_embed_dim: Dimensionality of the time embedding. - If ``None``, set to ``latent_embed_dim``. + If :obj:`None`, set to ``latent_embed_dim``. joint_hidden_dim: Dimensionality of the hidden layers of the joint network. - If ``None``, set to ``latent_embed_dim + condition_embed_dim + + If :obj:`None`, set to ``latent_embed_dim + condition_embed_dim + t_embed_dim``. num_layers_per_block: Number of layers per block. act_fn: Activation function. 
@@ -53,8 +53,8 @@ class VelocityField(nn.Module): """ output_dim: int - condition_dim: int latent_embed_dim: int + condition_dim: Optional[int] = None condition_embed_dim: Optional[int] = None t_embed_dim: Optional[int] = None joint_hidden_dim: Optional[int] = None @@ -89,26 +89,29 @@ def __call__( self, t: jnp.ndarray, x: jnp.ndarray, - condition: Optional[jnp.ndarray], + condition: Optional[jnp.ndarray] = None, keys_model: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: """Forward pass through the neural vector field. Args: - t: Time. - x: Data. + t: Time of shape (batch_size, 1). + x: Data of shape (batch_size, output_dim). condition: Conditioning vector. keys_model: Random number generator. Returns: Output of the neural vector field. """ + if self.condition_dim is None: + assert condition is None + t = flow_layers.CyclicalTimeEncoder(n_frequencies=self.n_frequencies)(t) t = layers.MLPBlock( dim=self.t_embed_dim, out_dim=self.t_embed_dim, num_layers=self.num_layers_per_block, - act_fn=self.act_fn, + act_fn=self.act_fn )( t ) @@ -122,7 +125,7 @@ def __call__( x ) - if self.condition_dim > 0: + if self.condition_dim is not None: condition = layers.MLPBlock( dim=self.condition_embed_dim, out_dim=self.condition_embed_dim, @@ -139,17 +142,12 @@ def __call__( dim=self.joint_hidden_dim, out_dim=self.joint_hidden_dim, num_layers=self.num_layers_per_block, - act_fn=self.act_fn, + act_fn=self.act_fn )( concatenated ) - return nn.Dense( - self.output_dim, - use_bias=True, - )( - out - ) + return nn.Dense(self.output_dim, use_bias=True)(out) def create_train_state( self, diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index ed0114f6d..657c5fe82 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -61,7 +61,8 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): valid_freq: Frequency of validation. 
ot_solver: OT solver to match samples from the source and the target distribution as proposed in :cite:`tong:23`, :cite:`pooladian:23`. - If `None`, no matching will be performed as proposed in :cite:`lipman:22`. + If :obj:`None`, no matching will be performed as proposed in + :cite:`lipman:22`. flow: Flow between source and target distribution. time_sampler: Sampler for the time. optimizer: Optimizer for `neural_vector_field`. @@ -76,9 +77,9 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): tau_b: If :math:`< 1`, defines how much unbalanced the problem is on the second marginal. mlp_eta: Neural network to learn the left rescaling function as suggested - in :cite:`TODO`. If `None`, the left rescaling factor is not learnt. + in :cite:`TODO`. If :obj:`None`, the left rescaling factor is not learnt. mlp_xi: Neural network to learn the right rescaling function as suggested - in :cite:`TODO`. If `None`, the right rescaling factor is not learnt. + in :cite:`TODO`. If :obj:`None`, the right rescaling factor is not learnt. unbalanced_kwargs: Keyword arguments for the unbalancedness solver. callback_fn: Callback function. num_eval_samples: Number of samples to evaluate on during evaluation. @@ -279,7 +280,7 @@ def __call__(self, train_loader, valid_loader): def transport( self, data: jnp.array, - condition: Optional[jnp.ndarray], + condition: Optional[jnp.ndarray] = None, forward: bool = True, diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}) ) -> diffrax.Solution: diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index 8ee71a9c6..e2236a294 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -441,7 +441,9 @@ def step_fn( return step_fn def evaluate_eta( - self, source: jnp.ndarray, condition: Optional[jnp.ndarray] + self, + source: jnp.ndarray, + condition: Optional[jnp.ndarray] = None ) -> jnp.ndarray: """Evaluate the left learnt rescaling factor. 
@@ -460,7 +462,9 @@ def evaluate_eta( condition=condition) def evaluate_xi( - self, target: jnp.ndarray, condition: Optional[jnp.ndarray] + self, + target: jnp.ndarray, + condition: Optional[jnp.ndarray] = None ) -> jnp.ndarray: """Evaluate the right learnt rescaling factor. diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index df4a0e14e..4250ff9f8 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -110,13 +110,15 @@ class RescalingMLP(nn.Module): Rescaling factors. """ hidden_dim: int - condition_dim: int + condition_dim: Optional[int] = None num_layers_per_block: int = 3 act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.selu @nn.compact def __call__( - self, x: jnp.ndarray, condition: Optional[jnp.ndarray] + self, + x: jnp.ndarray, + condition: Optional[jnp.ndarray] = None ) -> jnp.ndarray: # noqa: D102 """Forward pass through the rescaling network. @@ -127,6 +129,8 @@ def __call__( Returns: Estimated rescaling factors. """ + if self.condition_dim is None: + assert condition is None x = layers.MLPBlock( dim=self.hidden_dim, out_dim=self.hidden_dim, @@ -135,8 +139,7 @@ def __call__( )( x ) - if self.condition_dim > 0: - condition = jnp.atleast_1d(condition) + if self.condition_dim is not None: condition = layers.MLPBlock( dim=self.hidden_dim, out_dim=self.hidden_dim, diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 74d66dea3..723d25393 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -1,3 +1,16 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest import numpy as np From e6f0049bd26d59f393b8ad97cfc61d4cd5570771 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Dec 2023 15:44:14 +0100 Subject: [PATCH 058/186] revert faulty jax.array / jnp.ndarray conversions --- src/ott/neural/flows/flows.py | 4 +- src/ott/neural/flows/otfm.py | 7 - src/ott/neural/models/base_solver.py | 18 +-- src/ott/neural/models/layers.py | 2 +- .../solvers/quadratic/gromov_wasserstein.py | 2 +- src/ott/tools/soft_sort.py | 2 +- src/ott/utils.py | 3 +- tests/geometry/scaling_cost_test.py | 4 +- tests/neural/genot_test.py | 1 - tests/solvers/quadratic/lower_bound_test.py | 137 ------------------ tests/tools/soft_sort_test.py | 2 +- 11 files changed, 18 insertions(+), 164 deletions(-) delete mode 100644 tests/solvers/quadratic/lower_bound_test.py diff --git a/src/ott/neural/flows/flows.py b/src/ott/neural/flows/flows.py index 83b23eb42..eb1723883 100644 --- a/src/ott/neural/flows/flows.py +++ b/src/ott/neural/flows/flows.py @@ -98,9 +98,9 @@ class StraightFlow(BaseFlow, abc.ABC): """Base class for flows with straight paths.""" def compute_mu_t( # noqa: D102 - self, t: jnp.ndarray, x_0: jnp.ndarray, x_1: jnp.ndarray + self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray ) -> jnp.ndarray: - return t * x_0 + (1 - t) * x_1 + return t * src + (1 - t) * tgt def compute_ut( self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index 657c5fe82..ec5eb9821 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -84,10 
+84,6 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): callback_fn: Callback function. num_eval_samples: Number of samples to evaluate on during evaluation. rng: Random number generator. - - Returns: - None - """ def __init__( @@ -226,9 +222,6 @@ def __call__(self, train_loader, valid_loader): Args; train_loader: Dataloader for the training data. valid_loader: Dataloader for the validation data. - - Returns: - None """ batch: Mapping[str, jnp.ndarray] = {} curr_loss = 0.0 diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index e2236a294..3f807c4cc 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, abstractmethod +import abc from pathlib import Path from types import MappingProxyType from typing import Any, Callable, Dict, Literal, Mapping, Optional, Tuple, Union @@ -30,7 +30,7 @@ __all__ = ["BaseNeuralSolver", "ResampleMixin", "UnbalancednessMixin"] -class BaseNeuralSolver(ABC): +class BaseNeuralSolver(abc.ABC): """Base class for neural solvers. 
Args: @@ -42,28 +42,28 @@ def __init__(self, iterations: int, valid_freq: int, **_: Any): self.iterations = iterations self.valid_freq = valid_freq - @abstractmethod + @abc.abstractmethod def setup(self, *args: Any, **kwargs: Any): """Setup the model.""" - @abstractmethod + @abc.abstractmethod def __call__(self, *args: Any, **kwargs: Any): """Train the model.""" - @abstractmethod + @abc.abstractmethod def transport(self, *args: Any, forward: bool, **kwargs: Any) -> Any: """Transport.""" - @abstractmethod + @abc.abstractmethod def save(self, path: Path): """Save the model.""" - @abstractmethod + @abc.abstractmethod def load(self, path: Path): """Load the model.""" @property - @abstractmethod + @abc.abstractmethod def training_logs(self) -> Dict[str, Any]: """Return the training logs.""" @@ -327,7 +327,7 @@ def _resample_unbalanced( batch: Tuple[jnp.ndarray, ...], marginals: jnp.ndarray, ) -> Tuple[jnp.ndarray, ...]: - """Resample a batch based upon marginals.""" + """Resample a batch based on marginals.""" indices = jax.random.choice( key, a=len(marginals), p=jnp.squeeze(marginals), shape=[len(marginals)] ) diff --git a/src/ott/neural/models/layers.py b/src/ott/neural/models/layers.py index 952cc9d24..46313b0e2 100644 --- a/src/ott/neural/models/layers.py +++ b/src/ott/neural/models/layers.py @@ -38,7 +38,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: """Apply the MLP block. Args: - x: Input data of shape (batch_size, dim) + x: Input data of shape (batch_size, dim). Returns: Output data of shape (batch_size, out_dim). 
diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 5e23d88e6..a7890e1c9 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -129,7 +129,7 @@ class GWState(NamedTuple): linear_state: LinearOutput linear_pb: linear_problem.LinearProblem old_transport_mass: float - rngs: Optional[jnp.ndarray] = None + rngs: Optional[jax.Array] = None errors: Optional[jnp.ndarray] = None def set(self, **kwargs: Any) -> "GWState": diff --git a/src/ott/tools/soft_sort.py b/src/ott/tools/soft_sort.py index ccde3bd2c..1a30359ee 100644 --- a/src/ott/tools/soft_sort.py +++ b/src/ott/tools/soft_sort.py @@ -458,7 +458,7 @@ def _quantile( def multivariate_cdf_quantile_maps( inputs: jnp.ndarray, target_sampler: Optional[Callable[[jnp.ndarray, Tuple[int, int]], - jnp.ndarray]] = None, + jax.Array]] = None, rng: Optional[jax.Array] = None, num_target_samples: Optional[int] = None, cost_fn: Optional[costs.CostFn] = None, diff --git a/src/ott/utils.py b/src/ott/utils.py index 63a36f2b4..558f4ba1c 100644 --- a/src/ott/utils.py +++ b/src/ott/utils.py @@ -18,7 +18,6 @@ from typing import Any, Callable, NamedTuple, Optional, Tuple import jax -import jax.numpy as jnp import numpy as np try: @@ -69,7 +68,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return functools.wraps(func)(wrapper) -def default_prng_key(rng: Optional[jax.Array] = None) -> jnp.ndarray: +def default_prng_key(rng: Optional[jax.Array] = None) -> jax.Array: """Get the default PRNG key. 
Args: diff --git a/tests/geometry/scaling_cost_test.py b/tests/geometry/scaling_cost_test.py index 6cd5dcaa9..3dbe4bf31 100644 --- a/tests/geometry/scaling_cost_test.py +++ b/tests/geometry/scaling_cost_test.py @@ -189,7 +189,7 @@ def apply_sinkhorn(cost1, cost2, scale_cost): np.testing.assert_allclose(1.0, geom.cost_matrix.max(), rtol=1e-4) @pytest.mark.parametrize("batch_size", [5, 12]) - def test_mascale_cost_xx_low_rank_with_batch(self, batch_size: int): + def test_max_scale_cost_low_rank_with_batch(self, batch_size: int): """Test max_cost options for low rank with batch_size fixed.""" geom0 = low_rank.LRCGeometry( @@ -200,7 +200,7 @@ def test_mascale_cost_xx_low_rank_with_batch(self, batch_size: int): geom0.inv_scale_cost, 1.0 / jnp.max(self.cost_lr), rtol=1e-4 ) - def test_mascale_cost_xx_low_rank_large_array(self): + def test_max_scale_cost_low_rank_large_array(self): """Test max_cost options for large matrices.""" _, *rngs = jax.random.split(self.rng, 3) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index a962afca3..44ad21428 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -29,7 +29,6 @@ class TestGENOT: - #TODO: add tests for unbalancedness @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) diff --git a/tests/solvers/quadratic/lower_bound_test.py b/tests/solvers/quadratic/lower_bound_test.py deleted file mode 100644 index 37bf2a8b3..000000000 --- a/tests/solvers/quadratic/lower_bound_test.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -from typing import Callable - -import pytest - -import jax -import jax.numpy as jnp -import numpy as np - -from ott.geometry import costs, distrib_costs, pointcloud -from ott.initializers.linear import initializers -from ott.problems.quadratic import quadratic_problem -from ott.solvers.linear import implicit_differentiation as implicit_lib -from ott.solvers.quadratic import lower_bound -from ott.tools import soft_sort - - -class TestLowerBoundSolver: - - @pytest.fixture(autouse=True) - def initialize(self, rng: jax.Array): - d_x = 2 - d_y = 3 - self.n, self.m = 13, 15 - rngs = jax.random.split(rng, 4) - self.x = jax.random.uniform(rngs[0], (self.n, d_x)) - self.y = jax.random.uniform(rngs[1], (self.m, d_y)) - # Currently the Lower Bound only supports uniform distributions: - a = jnp.ones(self.n) - b = jnp.ones(self.m) - self.a = a / jnp.sum(a) - self.b = b / jnp.sum(b) - self.cx = jax.random.uniform(rngs[2], (self.n, self.n)) - self.cy = jax.random.uniform(rngs[3], (self.m, self.m)) - - @pytest.mark.fast.with_args( - "ground_cost", - [costs.SqEuclidean(), costs.PNormP(1.5)], - only_fast=0, - ) - def test_lb_pointcloud(self, ground_cost: costs.TICost): - x, y = self.x, self.y - - geom_x = pointcloud.PointCloud(x) - geom_y = pointcloud.PointCloud(y) - prob = quadratic_problem.QuadraticProblem( - geom_x, geom_y, a=self.a, b=self.b - ) - distrib_cost = distrib_costs.UnivariateWasserstein(ground_cost=ground_cost) - solver = lower_bound.LowerBoundSolver( - epsilon=1e-1, distrib_cost=distrib_cost - ) - - out = jax.jit(solver)(prob) - - assert 
not jnp.isnan(out.reg_ot_cost) - - @pytest.mark.parametrize("method", ["subsample", "quantile", "equal"]) - @pytest.mark.parametrize( - "sort_fn", - [ - None, - functools.partial( - soft_sort.sort, - epsilon=1e-3, - implicit_diff=False, - # soft sort uses `sorting` initializer, which uses while loop - # which is not reverse-mode diff. - initializer=initializers.DefaultInitializer(), - min_iterations=10, - max_iterations=10, - ), - functools.partial( - soft_sort.sort, - epsilon=1e-1, - implicit_diff=implicit_lib.ImplicitDiff(), - initializer=initializers.DefaultInitializer(), - min_iterations=0, - max_iterations=100, - ) - ] - ) - def test_lb_grad( - self, rng: jax.Array, sort_fn: Callable[[jnp.ndarray], jnp.ndarray], - method: str - ): - - def fn(x: jnp.ndarray, y: jnp.ndarray) -> float: - geom_x = pointcloud.PointCloud(x) - geom_y = pointcloud.PointCloud(y) - prob = quadratic_problem.QuadraticProblem(geom_x, geom_y) - - solver = lower_bound.LowerBoundSolver( - epsilon=5e-2, - sort_fn=sort_fn, - cost_fn=costs.SqEuclidean(), - method=method, - n_subsamples=n_sub, - ) - return solver(prob).reg_ot_cost - - rng1, rng2 = jax.random.split(rng) - eps, tol = 1e-4, 1e-3 - - n_sub = min(self.x.shape[0], self.y.shape[0]) - if method == "equal": - x, y = self.x[:n_sub], self.y[:n_sub] - else: - x, y = self.x, self.y - - grad_x, grad_y = jax.jit(jax.grad(fn, (0, 1)))(x, y) - - v_x = jax.random.normal(rng1, shape=x.shape) - v_x = (v_x / jnp.linalg.norm(v_x, axis=-1, keepdims=True)) * eps - expected = fn(x + v_x, y) - fn(x - v_x, y) - actual = 2.0 * jnp.vdot(v_x, grad_x) - np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol) - - v_y = jax.random.normal(rng2, shape=y.shape) - v_y = (v_y / jnp.linalg.norm(v_y, axis=-1, keepdims=True)) * eps - expected = (fn(x, y + v_y) - fn(x, y - v_y)) - actual = 2.0 * jnp.vdot(v_y, grad_y) - np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol) diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py index 
b4fa68ddf..3d66d43c7 100644 --- a/tests/tools/soft_sort_test.py +++ b/tests/tools/soft_sort_test.py @@ -109,7 +109,7 @@ def test_multivariate_cdf_quantiles(self, rng: jax.Array): # Check passing custom sampler, must be still symmetric / centered on {.5}^d # Check passing custom epsilon also works. - def ball_sampler(k: jnp.ndarray, s: Tuple[int, int]) -> jnp.ndarray: + def ball_sampler(k: jax.Array, s: Tuple[int, int]) -> jnp.ndarray: return 0.5 * (jax.random.ball(k, d=s[1], p=4, shape=(s[0],)) + 1.) num_target_samples = 473 From f23497f6ef0e402cbac5817b72d07f56eec30c55 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Dec 2023 16:04:04 +0100 Subject: [PATCH 059/186] make formatting in neural nets nicer --- src/ott/neural/flows/models.py | 21 ++++++++------------- src/ott/neural/models/models.py | 19 +++++++++---------- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/src/ott/neural/flows/models.py b/src/ott/neural/flows/models.py index 9a5ce13af..c9c2df44a 100644 --- a/src/ott/neural/flows/models.py +++ b/src/ott/neural/flows/models.py @@ -107,46 +107,41 @@ def __call__( assert condition is None t = flow_layers.CyclicalTimeEncoder(n_frequencies=self.n_frequencies)(t) - t = layers.MLPBlock( + t_layer = layers.MLPBlock( dim=self.t_embed_dim, out_dim=self.t_embed_dim, num_layers=self.num_layers_per_block, act_fn=self.act_fn - )( - t ) + t = t_layer(t) - x = layers.MLPBlock( + x_layer = layers.MLPBlock( dim=self.latent_embed_dim, out_dim=self.latent_embed_dim, num_layers=self.num_layers_per_block, act_fn=self.act_fn - )( - x ) + x = x_layer(x) if self.condition_dim is not None: - condition = layers.MLPBlock( + condition_layer = layers.MLPBlock( dim=self.condition_embed_dim, out_dim=self.condition_embed_dim, num_layers=self.num_layers_per_block, act_fn=self.act_fn - )( - condition ) + condition = condition_layer(condition) concatenated = jnp.concatenate((t, x, condition), axis=-1) else: concatenated = jnp.concatenate((t, x), axis=-1) - out = 
layers.MLPBlock( + out_layer = layers.MLPBlock( dim=self.joint_hidden_dim, out_dim=self.joint_hidden_dim, num_layers=self.num_layers_per_block, act_fn=self.act_fn - )( - concatenated ) - + out = out_layer(concatenated) return nn.Dense(self.output_dim, use_bias=True)(out) def create_train_state( diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 4250ff9f8..5afc809d7 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -131,36 +131,35 @@ def __call__( """ if self.condition_dim is None: assert condition is None - x = layers.MLPBlock( + x_layer = layers.MLPBlock( dim=self.hidden_dim, out_dim=self.hidden_dim, num_layers=self.num_layers_per_block, act_fn=self.act_fn - )( - x ) + x = x_layer(x) + if self.condition_dim is not None: - condition = layers.MLPBlock( + condition_layer = layers.MLPBlock( dim=self.hidden_dim, out_dim=self.hidden_dim, num_layers=self.num_layers_per_block, act_fn=self.act_fn - )( - condition ) + + condition = condition_layer(condition) concatenated = jnp.concatenate((x, condition), axis=-1) else: concatenated = x - out = layers.MLPBlock( + out_layer = layers.MLPBlock( dim=self.hidden_dim, out_dim=self.hidden_dim, num_layers=self.num_layers_per_block, - act_fn=self.act_fn, - )( - concatenated + act_fn=self.act_fn ) + out = out_layer(concatenated) return jnp.exp(out) def create_train_state( From 9f96583ddd967fc061ac1bae3b32ab47562f1f6f Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Dec 2023 16:20:42 +0100 Subject: [PATCH 060/186] add description to Velocity Field --- src/ott/neural/flows/models.py | 10 +++++++++- src/ott/neural/models/models.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/ott/neural/flows/models.py b/src/ott/neural/flows/models.py index c9c2df44a..0177383c9 100644 --- a/src/ott/neural/flows/models.py +++ b/src/ott/neural/flows/models.py @@ -27,7 +27,15 @@ class VelocityField(nn.Module): - """Parameterized neural vector field. 
+ r"""Parameterized neural vector field. + + The `VelocityField` learns a map + :math:`v: \mathbb{R}\times \mathbb{R}^d\rightarrow \mathbb{R}^d` solving + the ODE :math:`\frac{dx}{dt} = v(t, x)`. Given a source distribution at time + :math:`t_0`, the `VelocityField` can be used to transport the source + distribution given at :math:`t_0` to a target distribution given at + :math:`t_1` by integrating :math:`v(t, x)` from :math:`t=t_0` to + :math:`t=t_1`. Each of the input, condition, and time embeddings are passed through a block consisting of ``num_layers_per_block`` layers of dimension diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 5afc809d7..e84e3560a 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -107,7 +107,7 @@ class RescalingMLP(nn.Module): act_fn: Activation function. Returns: - Rescaling factors. + Non-negative rescaling factors. """ hidden_dim: int condition_dim: Optional[int] = None From 86fe8864098340931f5898086b58e5f9d13b01c9 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Dec 2023 16:38:03 +0100 Subject: [PATCH 061/186] replace time sampler class by function --- src/ott/neural/flows/__init__.py | 2 +- src/ott/neural/flows/flows.py | 87 ++------------------------------ src/ott/neural/flows/genot.py | 10 ++-- src/ott/neural/flows/otfm.py | 4 +- src/ott/neural/flows/samplers.py | 50 ++++++++++++++++++ 5 files changed, 60 insertions(+), 93 deletions(-) create mode 100644 src/ott/neural/flows/samplers.py diff --git a/src/ott/neural/flows/__init__.py b/src/ott/neural/flows/__init__.py index af3ceb125..cc2c4bfdb 100644 --- a/src/ott/neural/flows/__init__.py +++ b/src/ott/neural/flows/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . 
import flows, genot, layers, models, otfm, samplers diff --git a/src/ott/neural/flows/flows.py b/src/ott/neural/flows/flows.py index eb1723883..0dce912aa 100644 --- a/src/ott/neural/flows/flows.py +++ b/src/ott/neural/flows/flows.py @@ -13,12 +13,13 @@ # limitations under the License. import abc -import jax import jax.numpy as jnp __all__ = [ - "BaseFlow", "StraightFlow", "ConstantNoiseFlow", "BrownianNoiseFlow", - "BaseTimeSampler", "UniformSampler", "OffsetUniformSampler" + "BaseFlow", + "StraightFlow", + "ConstantNoiseFlow", + "BrownianNoiseFlow", ] @@ -154,83 +155,3 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: Standard deviation of the probablity path at time :math:`t`. """ return jnp.sqrt(self.sigma * t * (1 - t)) - - -class BaseTimeSampler(abc.ABC): - """Base class for time samplers. - - Args: - low: Lower bound of the distribution to sample from. - high: Upper bound of the distribution to sample from . - """ - - def __init__(self, low: float, high: float): - self.low = low - self.high = high - - @abc.abstractmethod - def __call__(self, rng: jax.Array, num_samples: int) -> jnp.ndarray: - """Generate `num_samples` samples of the time `math`:t:. - - Args: - rng: Random number generator. - num_samples: Number of samples to generate. - """ - - -class UniformSampler(BaseTimeSampler): - """Sample :math:`t` from a uniform distribution :math:`[low, high]`. - - Args: - low: Lower bound of the uniform distribution. - high: Upper bound of the uniform distribution. - """ - - def __init__(self, low: float = 0.0, high: float = 1.0): - super().__init__(low=low, high=high) - - def __call__(self, rng: jax.Array, num_samples: int) -> jnp.ndarray: - """Generate `num_samples` samples of the time `math`:t:. - - Args: - rng: Random number generator. - num_samples: Number of samples to generate. - - Returns: - `num_samples` samples of the time :math:`t``. 
- """ - return jax.random.uniform( - rng, (num_samples, 1), minval=self.low, maxval=self.high - ) - - -class OffsetUniformSampler(BaseTimeSampler): - """Sample the time :math:`t`. - - Sample :math:`t` from a uniform distribution :math:`[low, high]` with - offset `offset`. - - Args: - offset: Offset of the uniform distribution. - low: Lower bound of the uniform distribution. - high: Upper bound of the uniform distribution. - """ - - def __init__(self, offset: float, low: float = 0.0, high: float = 1.0): - super().__init__(low=low, high=high) - self.offset = offset - - def __call__(self, rng: jax.Array, num_samples: int) -> jnp.ndarray: - """Generate `num_samples` samples of the time `math`:t:. - - Args: - rng: Random number generator. - num_samples: Number of samples to generate. - - Returns: - An array with `num_samples` samples of the time `math`:t:. - """ - return ( - jax.random.uniform(rng, (1, 1), minval=self.low, maxval=self.high) + - jnp.arange(num_samples)[:, None] / num_samples - ) % ((self.high - self.low) - self.offset) diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index 4384f8e7f..813be6649 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -26,12 +26,8 @@ from ott import utils from ott.geometry import costs -from ott.neural.flows.flows import ( - BaseFlow, - BaseTimeSampler, - ConstantNoiseFlow, - UniformSampler, -) +from ott.neural.flows.flows import BaseFlow, ConstantNoiseFlow +from ott.neural.flows.samplers import sample_uniformly from ott.neural.models.base_solver import ( BaseNeuralSolver, ResampleMixin, @@ -118,7 +114,7 @@ def __init__( "max_cost", "median"]]]], optimizer: Type[optax.GradientTransformation], flow: Type[BaseFlow] = ConstantNoiseFlow(0.0), - time_sampler: Type[BaseTimeSampler] = UniformSampler(), + time_sampler: Callable[[jax.Array, int], jnp.ndarray] = sample_uniformly, checkpoint_manager: Type[checkpoint.CheckpointManager] = None, k_samples_per_x: int = 1, 
solver_latent_to_data: Optional[Type[was_solver.WassersteinSolver] diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index ec5eb9821..f7a973eb4 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -36,7 +36,7 @@ from ott import utils from ott.geometry import costs -from ott.neural.flows.flows import BaseFlow, BaseTimeSampler +from ott.neural.flows.flows import BaseFlow from ott.neural.models.base_solver import ( BaseNeuralSolver, ResampleMixin, @@ -96,7 +96,7 @@ def __init__( iterations: int, ot_solver: Optional[Type[was_solver.WassersteinSolver]], flow: Type[BaseFlow], - time_sampler: Type[BaseTimeSampler], + time_sampler: Callable[[jax.Array, int], jnp.ndarray], optimizer: Type[optax.GradientTransformation], checkpoint_manager: Type[checkpoint.CheckpointManager] = None, epsilon: float = 1e-2, diff --git a/src/ott/neural/flows/samplers.py b/src/ott/neural/flows/samplers.py new file mode 100644 index 000000000..f5d0e0d17 --- /dev/null +++ b/src/ott/neural/flows/samplers.py @@ -0,0 +1,50 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import jax +import jax.numpy as jnp + +__all__ = ["sample_uniformly"] + + +def sample_uniformly( + rng: jax.Array, + num_samples: int, + low: float = 0.0, + high: float = 1.0, + offset: Optional[float] = None +): + """Sample from a uniform distribution. 
+ + Sample :math:`t` from a uniform distribution :math:`[low, high]` with + offset `offset`. + + Args: + rng: Random number generator. + num_samples: Number of samples to generate. + low: Lower bound of the uniform distribution. + high: Upper bound of the uniform distribution. + offset: Offset of the uniform distribution. If :obj:`None`, no offset is + used. + + Returns: + An array with `num_samples` samples of the time :math:`t`. + """ + if offset is None: + return jax.random.uniform(rng, (num_samples, 1), minval=low, maxval=high) + return ( + jax.random.uniform(rng, (1, 1), minval=low, maxval=high) + + jnp.arange(num_samples)[:, None] / num_samples + ) % ((high - low) - offset) From 58e3d29c12429d7bbabc979545de05b10b6acb99 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Dec 2023 16:52:02 +0100 Subject: [PATCH 062/186] add citations --- docs/references.bib | 20 ++++++++++++++++++++ src/ott/neural/flows/genot.py | 14 +++++++------- src/ott/neural/flows/otfm.py | 24 +++++++++++++----------- src/ott/neural/models/base_solver.py | 20 ++++++++++---------- tests/neural/genot_test.py | 8 ++++---- tests/neural/otfm_test.py | 8 ++++---- 6 files changed, 58 insertions(+), 36 deletions(-) diff --git a/docs/references.bib b/docs/references.bib index c5d4c4678..e0c83a6ee 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -814,3 +814,23 @@ @misc{huguet:2023 title = {Geodesic Sinkhorn for Fast and Accurate Optimal Transport on Manifolds}, year = {2023}, } + +@misc{eyring:23, + author={Eyring, Luca and Klein, Dominik and Uscidda, Th{\'e}o and Palla, Giovanni and Kilbertus, Niki and Akata, Zeynep and Theis, Fabian}, + doi = {10.48550/arXiv.2311.15100}, + eprint = {2311.15100}, + eprintclass = {stat.ML}, + eprinttype = {arXiv}, + title={Unbalancedness in Neural Monge Maps Improves Unpaired Domain Translation}, + year={2023} +} + +@misc{klein_uscidda:23, + author={Dominik Klein and Théo Uscidda and Fabian Theis and Marco Cuturi}, + doi = 
{10.48550/arXiv.2310.09254}, + eprint={2310.09254}, + eprintclass = {stat.ML}, + eprinttype = {arXiv}, + title={Generative Entropic Neural Optimal Transport To Map Within and Across Spaces}, + year={2023}, +} diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index 813be6649..dedbf20ec 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -84,9 +84,9 @@ class GENOT(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): on the first marginal. tau_b: If :math:`< 1`, defines how much unbalanced the problem is on the second marginal. - mlp_eta: Neural network to learn the left rescaling function. If + rescaling_a: Neural network to learn the left rescaling function. If :obj:`None`, the left rescaling factor is not learnt. - mlp_xi: Neural network to learn the right rescaling function. If + rescaling_b: Neural network to learn the right rescaling function. If :obj:`None`, the right rescaling factor is not learnt. unbalanced_kwargs: Keyword arguments for the unbalancedness solver. callback_fn: Callback function. 
@@ -123,8 +123,8 @@ def __init__( fused_penalty: float = 0.0, tau_a: float = 1.0, tau_b: float = 1.0, - mlp_eta: Callable[[jnp.ndarray], float] = None, - mlp_xi: Callable[[jnp.ndarray], float] = None, + rescaling_a: Callable[[jnp.ndarray], float] = None, + rescaling_b: Callable[[jnp.ndarray], float] = None, unbalanced_kwargs: Dict[str, Any] = types.MappingProxyType({}), callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any]] = None, @@ -144,8 +144,8 @@ def __init__( cond_dim=cond_dim, tau_a=tau_a, tau_b=tau_b, - mlp_eta=mlp_eta, - mlp_xi=mlp_xi, + rescaling_a=rescaling_a, + rescaling_b=rescaling_b, unbalanced_kwargs=unbalanced_kwargs, ) if isinstance( @@ -443,7 +443,7 @@ def _valid_step(self, valid_loader, iter): @property def learn_rescaling(self) -> bool: """Whether to learn at least one rescaling factor.""" - return self.mlp_eta is not None or self.mlp_xi is not None + return self.rescaling_a is not None or self.rescaling_b is not None def save(self, path: str): """Save the model. diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index f7a973eb4..d08032bad 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -73,13 +73,15 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): scale_cost: How to scale the cost matrix for the OT problem solved by the `ot_solver`. tau_a: If :math:`<1`, defines how much unbalanced the problem is - on the first marginal. + on the first marginal. tau_b: If :math:`< 1`, defines how much unbalanced the problem is - on the second marginal. - mlp_eta: Neural network to learn the left rescaling function as suggested - in :cite:`TODO`. If :obj:`None`, the left rescaling factor is not learnt. - mlp_xi: Neural network to learn the right rescaling function as suggested - in :cite:`TODO`. If :obj:`None`, the right rescaling factor is not learnt. + on the second marginal. 
+ rescaling_a: Neural network to learn the left rescaling function as + suggested in :cite:`eyring:23`. If :obj:`None`, the left rescaling factor + is not learnt. + rescaling_b: Neural network to learn the right rescaling function as + suggested in :cite:`eyring:23`. If :obj:`None`, the right rescaling factor + is not learnt. unbalanced_kwargs: Keyword arguments for the unbalancedness solver. callback_fn: Callback function. num_eval_samples: Number of samples to evaluate on during evaluation. @@ -106,8 +108,8 @@ def __init__( "median"]] = "mean", tau_a: float = 1.0, tau_b: float = 1.0, - mlp_eta: Callable[[jnp.ndarray], float] = None, - mlp_xi: Callable[[jnp.ndarray], float] = None, + rescaling_a: Callable[[jnp.ndarray], float] = None, + rescaling_b: Callable[[jnp.ndarray], float] = None, unbalanced_kwargs: Dict[str, Any] = types.MappingProxyType({}), callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any]] = None, @@ -130,8 +132,8 @@ def __init__( cond_dim=cond_dim, tau_a=tau_a, tau_b=tau_b, - mlp_eta=mlp_eta, - mlp_xi=mlp_xi, + rescaling_a=rescaling_a, + rescaling_b=rescaling_b, unbalanced_kwargs=unbalanced_kwargs, ) @@ -332,7 +334,7 @@ def _valid_step(self, valid_loader, iter): @property def learn_rescaling(self) -> bool: """Whether to learn at least one rescaling factor.""" - return self.mlp_eta is not None or self.mlp_xi is not None + return self.rescaling_a is not None or self.rescaling_b is not None def save(self, path: str): """Save the model. 
diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index 3f807c4cc..2a82ab610 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -256,10 +256,10 @@ def __init__( cond_dim: Optional[int], tau_a: float = 1.0, tau_b: float = 1.0, - mlp_eta: Optional[Callable[[jnp.ndarray, Optional[jnp.ndarray]], - jnp.ndarray]] = None, - mlp_xi: Optional[Callable[[jnp.ndarray, Optional[jnp.ndarray]], - jnp.ndarray]] = None, + rescaling_a: Optional[Callable[[jnp.ndarray, Optional[jnp.ndarray]], + jnp.ndarray]] = None, + rescaling_b: Optional[Callable[[jnp.ndarray, Optional[jnp.ndarray]], + jnp.ndarray]] = None, seed: Optional[int] = None, opt_eta: Optional[optax.GradientTransformation] = None, opt_xi: Optional[optax.GradientTransformation] = None, @@ -275,8 +275,8 @@ def __init__( self.cond_dim = cond_dim self.tau_a = tau_a self.tau_b = tau_b - self.mlp_eta = mlp_eta - self.mlp_xi = mlp_xi + self.rescaling_a = rescaling_a + self.rescaling_b = rescaling_b self.seed = seed self.opt_eta = opt_eta self.opt_xi = opt_xi @@ -338,20 +338,20 @@ def _setup(self, source_dim: int, target_dim: int, cond_dim: int): self.rng_unbalanced, 3 ) self.unbalancedness_step_fn = self._get_rescaling_step_fn() - if self.mlp_eta is not None: + if self.rescaling_a is not None: self.opt_eta = ( self.opt_eta if self.opt_eta is not None else optax.adamw(learning_rate=1e-4, weight_decay=1e-10) ) - self.state_eta = self.mlp_eta.create_train_state( + self.state_eta = self.rescaling_a.create_train_state( rng_eta, self.opt_eta, source_dim ) - if self.mlp_xi is not None: + if self.rescaling_b is not None: self.opt_xi = ( self.opt_xi if self.opt_xi is not None else optax.adamw(learning_rate=1e-4, weight_decay=1e-10) ) - self.state_xi = self.mlp_xi.create_train_state( + self.state_xi = self.rescaling_b.create_train_state( rng_xi, self.opt_xi, target_dim ) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 
44ad21428..a4c221d40 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -350,8 +350,8 @@ def test_genot_linear_learn_rescaling( optimizer = optax.adam(learning_rate=1e-3) tau_a = 0.9 tau_b = 0.2 - mlp_eta = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) - mlp_xi = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) + rescaling_a = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) + rescaling_b = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) genot = GENOT( neural_vf, input_dim=source_dim, @@ -367,8 +367,8 @@ def test_genot_linear_learn_rescaling( time_sampler=time_sampler, tau_a=tau_a, tau_b=tau_b, - mlp_eta=mlp_eta, - mlp_xi=mlp_xi, + rescaling_a=rescaling_a, + rescaling_b=rescaling_b, ) genot(data_loader, data_loader) diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 0af948705..74f483654 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -194,8 +194,8 @@ def test_flow_matching_learn_rescaling( tau_a = 0.9 tau_b = 0.2 - mlp_eta = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) - mlp_xi = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) + rescaling_a = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) + rescaling_b = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) fm = OTFlowMatching( neural_vf, input_dim=source_dim, @@ -208,8 +208,8 @@ def test_flow_matching_learn_rescaling( optimizer=optimizer, tau_a=tau_a, tau_b=tau_b, - mlp_eta=mlp_eta, - mlp_xi=mlp_xi, + rescaling_a=rescaling_a, + rescaling_b=rescaling_b, ) fm(data_loader, data_loader) From 2f5fa52517a8a97f0cff6735cffbf0c44c5812a9 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Dec 2023 17:11:15 +0100 Subject: [PATCH 063/186] add more references --- docs/references.bib | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/docs/references.bib b/docs/references.bib index e0c83a6ee..f161c1570 100644 --- a/docs/references.bib +++ b/docs/references.bib 
@@ -834,3 +834,33 @@ @misc{klein_uscidda:23 title={Generative Entropic Neural Optimal Transport To Map Within and Across Spaces}, year={2023}, } + +@misc{lipman:22, + author={Lipman, Yaron and Chen, Ricky TQ and Ben-Hamu, Heli and Nickel, Maximilian and Le, Matt}, + doi = {10.48550/arXiv.2210.02747}, + eprint={2210.02747}, + eprintclass = {stat.ML}, + eprinttype = {arXiv}, + title={Flow matching for generative modeling}, + year={2022}, +} + +@misc{tong:23, + author={Tong, Alexander and Malkin, Nikolay and Huguet, Guillaume and Zhang, Yanlei and {Rector-Brooks}, Jarrid and Fatras, Kilian and Wolf, Guy and Bengio, Yoshua}, + doi={10.48550/arXiv.2302.00482}, + eprint={2302.00482}, + eprintclass = {stat.ML}, + eprinttype = {arXiv}, + title={Improving and Generalizing Flow-Based Generative Models with Minibatch Optimal Transport}, + year={2023}, +} + +@misc{pooladian:23, + author={Pooladian, Aram-Alexandre and Ben-Hamu, Heli and Domingo-Enrich, Carles and Amos, Brandon and Lipman, Yaron and Chen, Ricky}, + doi={10.48550/arXiv.2304.14772}, + eprint={2304.14772}, + eprintclass = {stat.ML}, + eprinttype = {arXiv}, + title={Multisample flow matching: Straightening flows with minibatch couplings}, + year={2023} +} From 9ad992431c6ec560471e098426791f594b54f72a Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Dec 2023 17:22:23 +0100 Subject: [PATCH 064/186] rename keys_model to rng --- src/ott/neural/flows/genot.py | 7 +++---- src/ott/neural/flows/models.py | 4 ++-- src/ott/neural/flows/otfm.py | 7 ++----- src/ott/neural/models/base_solver.py | 12 ++++++------ 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index dedbf20ec..e23d9ba06 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -340,7 +340,7 @@ def step_fn( def loss_fn( params: jnp.ndarray, batch: Dict[str, jnp.array], - keys_model: jax.random.PRNGKeyArray + rng: jax.random.PRNGKeyArray ): x_t = 
self.flow.compute_xt( batch["noise"], batch["time"], batch["latent"], batch["target"] @@ -355,9 +355,8 @@ def loss_fn( if batch[el] is not None ], axis=1) - v_t = jax.vmap(apply_fn)( - t=batch["time"], x=x_t, condition=cond_input, keys_model=keys_model - ) + v_t = jax.vmap(apply_fn + )(t=batch["time"], x=x_t, condition=cond_input, rng=rng) u_t = self.flow.compute_ut( batch["time"], batch["latent"], batch["target"] ) diff --git a/src/ott/neural/flows/models.py b/src/ott/neural/flows/models.py index 0177383c9..badc91232 100644 --- a/src/ott/neural/flows/models.py +++ b/src/ott/neural/flows/models.py @@ -98,7 +98,7 @@ def __call__( t: jnp.ndarray, x: jnp.ndarray, condition: Optional[jnp.ndarray] = None, - keys_model: Optional[jnp.ndarray] = None, + rng: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: """Forward pass through the neural vector field. @@ -106,7 +106,7 @@ def __call__( t: Time of shape (batch_size, 1). x: Data of shape (batch_size, output_dim). condition: Conditioning vector. - keys_model: Random number generator. + rng: Random number generator. Returns: Output of the neural vector field. 
diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index d08032bad..e27ff2582 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -187,7 +187,7 @@ def step_fn( def loss_fn( params: jnp.ndarray, t: jnp.ndarray, noise: jnp.ndarray, - batch: Dict[str, jnp.ndarray], keys_model: jax.random.PRNGKeyArray + batch: Dict[str, jnp.ndarray], rng: jax.random.PRNGKeyArray ) -> jnp.ndarray: x_t = self.flow.compute_xt( @@ -197,10 +197,7 @@ def loss_fn( state_neural_vector_field.apply_fn, {"params": params} ) v_t = jax.vmap(apply_fn)( - t=t, - x=x_t, - condition=batch["source_conditions"], - keys_model=keys_model + t=t, x=x_t, condition=batch["source_conditions"], rng=rng ) u_t = self.flow.compute_ut(t, batch["source_lin"], batch["target_lin"]) return jnp.mean((v_t - u_t) ** 2) diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index 2a82ab610..a5521e1e6 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import abc -from pathlib import Path -from types import MappingProxyType +import pathlib +import types from typing import Any, Callable, Dict, Literal, Mapping, Optional, Tuple, Union import jax @@ -55,11 +55,11 @@ def transport(self, *args: Any, forward: bool, **kwargs: Any) -> Any: """Transport.""" @abc.abstractmethod - def save(self, path: Path): + def save(self, path: pathlib.Path): """Save the model.""" @abc.abstractmethod - def load(self, path: Path): + def load(self, path: pathlib.Path): """Load the model.""" @property @@ -266,7 +266,7 @@ def __init__( resample_epsilon: float = 1e-2, scale_cost: Union[bool, int, float, Literal["mean", "max_cost", "median"]] = "mean", - sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), + sinkhorn_kwargs: Mapping[str, Any] = types.MappingProxyType({}), **_: Any, ): self.rng_unbalanced = rng @@ -299,7 +299,7 @@ def _get_compute_unbalanced_marginals( resample_epsilon: float, scale_cost: Union[bool, int, float, Literal["mean", "max_cost", "median"]] = "mean", - sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}), + sinkhorn_kwargs: Dict[str, Any] = types.MappingProxyType({}), ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Compute the unbalanced source and target marginals for a batch.""" From 0addc7a6044af468f2e015c8346215a90d84650d Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Dec 2023 17:27:42 +0100 Subject: [PATCH 065/186] fix tests regarding time sampling --- tests/neural/genot_test.py | 16 ++++++++-------- tests/neural/otfm_test.py | 12 ++++++------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index a4c221d40..148cc935d 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import functools from typing import Iterator, Optional import pytest @@ -20,9 +21,9 @@ import optax from ott.geometry import costs -from ott.neural.flows.flows import OffsetUniformSampler, UniformSampler from ott.neural.flows.genot import GENOT from ott.neural.flows.models import VelocityField +from ott.neural.flows.samplers import sample_uniformly from ott.neural.models.models import RescalingMLP from ott.solvers.linear import sinkhorn from ott.solvers.quadratic import gromov_wasserstein @@ -53,7 +54,7 @@ def test_genot_linear_unconditional( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - time_sampler = UniformSampler() + time_sampler = sample_uniformly optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -103,7 +104,7 @@ def test_genot_quad_unconditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) - time_sampler = OffsetUniformSampler(1e-3) + time_sampler = functools.patial(sample_uniformly, offset=1e-2) optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -149,7 +150,6 @@ def test_genot_fused_unconditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) - UniformSampler() optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -196,7 +196,7 @@ def test_genot_linear_conditional( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - time_sampler = UniformSampler() + time_sampler = sample_uniformly optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -243,7 +243,7 @@ def test_genot_quad_conditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) - time_sampler = UniformSampler() + time_sampler = sample_uniformly optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -290,7 +290,7 @@ def test_genot_fused_conditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) - time_sampler = UniformSampler() + time_sampler = 
sample_uniformly optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -346,7 +346,7 @@ def test_genot_linear_learn_rescaling( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - time_sampler = UniformSampler() + time_sampler = sample_uniformly optimizer = optax.adam(learning_rate=1e-3) tau_a = 0.9 tau_b = 0.2 diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 74f483654..1230a638b 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools from typing import Iterator, Type import pytest @@ -23,11 +24,10 @@ BaseFlow, BrownianNoiseFlow, ConstantNoiseFlow, - OffsetUniformSampler, - UniformSampler, ) from ott.neural.flows.models import VelocityField from ott.neural.flows.otfm import OTFlowMatching +from ott.neural.flows.samplers import sample_uniformly from ott.neural.models.models import RescalingMLP from ott.solvers.linear import sinkhorn @@ -47,7 +47,7 @@ def test_flow_matching(self, data_loader_gaussian, flow: Type[BaseFlow]): latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - time_sampler = UniformSampler() + time_sampler = sample_uniformly optimizer = optax.adam(learning_rate=1e-3) fm = OTFlowMatching( neural_vf, @@ -92,7 +92,7 @@ def test_flow_matching_with_conditions( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - time_sampler = OffsetUniformSampler(1e-6) + time_sampler = functools.partial(sample_uniformly, offset=1e-5) optimizer = optax.adam(learning_rate=1e-3) fm = OTFlowMatching( neural_vf, @@ -140,7 +140,7 @@ def test_flow_matching_conditional( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - time_sampler = UniformSampler() + time_sampler = sample_uniformly optimizer = optax.adam(learning_rate=1e-3) fm = OTFlowMatching( neural_vf, @@ -188,7 +188,7 @@ def 
test_flow_matching_learn_rescaling( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - time_sampler = UniformSampler() + time_sampler = sample_uniformly flow = ConstantNoiseFlow(1.0) optimizer = optax.adam(learning_rate=1e-3) From be68393644d35d7918dbea111fbb8417fd0e2608 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Dec 2023 17:31:34 +0100 Subject: [PATCH 066/186] fix typo in tests --- tests/neural/genot_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 148cc935d..7098e7419 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -104,7 +104,7 @@ def test_genot_quad_unconditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) - time_sampler = functools.patial(sample_uniformly, offset=1e-2) + time_sampler = functools.partial(sample_uniformly, offset=1e-2) optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, From b5bdc4a3d91530d2e6430034625476f32a3a35a2 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Dec 2023 17:43:35 +0100 Subject: [PATCH 067/186] rename neural_vector_field to velocity_field everywhere --- src/ott/neural/flows/genot.py | 38 +++++++++++++--------------- src/ott/neural/flows/otfm.py | 34 ++++++++++++------------- src/ott/neural/models/base_solver.py | 2 ++ 3 files changed, 35 insertions(+), 39 deletions(-) diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index e23d9ba06..622456a26 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -44,7 +44,7 @@ class GENOT(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): """The GENOT training class as introduced in :cite:`klein_uscidda:23`. Args: - neural_vector_field: Neural vector field parameterized by a neural network. + velocity_field: Neural vector field parameterized by a neural network. input_dim: Dimension of the data in the source distribution. 
output_dim: Dimension of the data in the target distribution. cond_dim: Dimension of the conditioning variable. @@ -68,7 +68,7 @@ class GENOT(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): terms, i.e. both quadratic terms and, if applicable, the linear temr. If of type :class:`dict`, the keys are expected to be `scale_cost_xx`, `scale_cost_yy`, and if applicable, `scale_cost_xy`. - optimizer: Optimizer for `neural_vector_field`. + optimizer: Optimizer for `velocity_field`. flow: Flow between latent distribution and target distribution. time_sampler: Sampler for the time. checkpoint_manager: Checkpoint manager. @@ -95,7 +95,7 @@ class GENOT(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): def __init__( self, - neural_vector_field: Callable[[ + velocity_field: Callable[[ jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], Optional[jnp.ndarray] ], jnp.ndarray], input_dim: int, @@ -158,8 +158,8 @@ def __init__( ) self.rng = utils.default_prng_key(rng) - self.neural_vector_field = neural_vector_field - self.state_neural_vector_field: Optional[TrainState] = None + self.velocity_field = velocity_field + self.state_velocity_field: Optional[TrainState] = None self.flow = flow self.time_sampler = time_sampler self.optimizer = optimizer @@ -197,8 +197,8 @@ def setup(self): kwargs Keyword arguments for the setup function """ - self.state_neural_vector_field = ( - self.neural_vector_field.create_train_state( + self.state_velocity_field = ( + self.velocity_field.create_train_state( self.rng, self.optimizer, self.output_dim ) ) @@ -301,8 +301,8 @@ def __call__(self, train_loader, valid_loader): for key, arr in batch.items() } - self.state_neural_vector_field, loss = self.step_fn( - rng_step_fn, self.state_neural_vector_field, batch + self.state_velocity_field, loss = self.step_fn( + rng_step_fn, self.state_velocity_field, batch ) if self.learn_rescaling: ( @@ -320,9 +320,7 @@ def __call__(self, train_loader, valid_loader): if iteration % self.valid_freq == 0: 
self._valid_step(valid_loader, iteration) if self.checkpoint_manager is not None: - states_to_save = { - "state_neural_vector_field": self.state_neural_vector_field - } + states_to_save = {"state_velocity_field": self.state_velocity_field} if self.state_eta is not None: states_to_save["state_eta"] = self.state_eta if self.state_xi is not None: @@ -334,7 +332,7 @@ def _get_step_fn(self) -> Callable: @jax.jit def step_fn( key: jax.random.PRNGKeyArray, - state_neural_vector_field: train_state.TrainState, + state_velocity_field: train_state.TrainState, batch: Dict[str, jnp.array], ): @@ -346,7 +344,7 @@ def loss_fn( batch["noise"], batch["time"], batch["latent"], batch["target"] ) apply_fn = functools.partial( - state_neural_vector_field.apply_fn, {"params": params} + state_velocity_field.apply_fn, {"params": params} ) cond_input = jnp.concatenate([ @@ -365,9 +363,9 @@ def loss_fn( keys_model = jax.random.split(key, len(batch["noise"])) grad_fn = jax.value_and_grad(loss_fn, has_aux=False) - loss, grads = grad_fn(state_neural_vector_field.params, batch, keys_model) + loss, grads = grad_fn(state_velocity_field.params, batch, keys_model) - return state_neural_vector_field.apply_gradients(grads=grads), loss + return state_velocity_field.apply_gradients(grads=grads), loss return step_fn @@ -383,9 +381,7 @@ def transport( This method pushes-forward the `source` to its conditional distribution by solving the neural ODE parameterized by the - :attr:`~ott.neural.solvers.GENOTg.neural_vector_field` from - :attr:`~ott.neural.flows.BaseTimeSampler.low` to - :attr:`~ott.neural.flows.BaseTimeSampler.high`. + :attr:`~ott.neural.flows.genot.velocity_field` Args: source: Data to transport. @@ -415,8 +411,8 @@ def transport( def solve_ode(input: jnp.ndarray, cond: jnp.ndarray): return diffrax.diffeqsolve( diffrax.ODETerm( - lambda t, x, args: self.state_neural_vector_field. - apply_fn({"params": self.state_neural_vector_field.params}, + lambda t, x, args: self.state_velocity_field. 
+ apply_fn({"params": self.state_velocity_field.params}, t=t, x=x, condition=cond) diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index e27ff2582..ad43736c6 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -54,7 +54,7 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): (:cite`tong:23`, :cite:`pooladian:23`). Args: - neural_vector_field: Neural vector field parameterized by a neural network. + velocity_field: Neural vector field parameterized by a neural network. input_dim: Dimension of the input data. cond_dim: Dimension of the conditioning variable. iterations: Number of iterations. @@ -65,7 +65,7 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): :cite:`lipman:22`. flow: Flow between source and target distribution. time_sampler: Sampler for the time. - optimizer: Optimizer for `neural_vector_field`. + optimizer: Optimizer for `velocity_field`. checkpoint_manager: Checkpoint manager. epsilon: Entropy regularization term of the OT OT problem solved by the `ot_solver`. 
@@ -90,7 +90,7 @@ class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): def __init__( self, - neural_vector_field: Callable[[ + velocity_field: Callable[[ jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], Optional[jnp.ndarray] ], jnp.ndarray], input_dim: int, @@ -137,7 +137,7 @@ def __init__( unbalanced_kwargs=unbalanced_kwargs, ) - self.neural_vector_field = neural_vector_field + self.velocity_field = velocity_field self.input_dim = input_dim self.ot_solver = ot_solver self.flow = flow @@ -157,8 +157,8 @@ def __init__( def setup(self): """Setup :class:`OTFlowMatching`.""" - self.state_neural_vector_field = ( - self.neural_vector_field.create_train_state( + self.state_velocity_field = ( + self.velocity_field.create_train_state( self.rng, self.optimizer, self.input_dim ) ) @@ -181,7 +181,7 @@ def _get_step_fn(self) -> Callable: @jax.jit def step_fn( key: jax.random.PRNGKeyArray, - state_neural_vector_field: train_state.TrainState, + state_velocity_field: train_state.TrainState, batch: Dict[str, jnp.ndarray], ) -> Tuple[Any, Any]: @@ -194,7 +194,7 @@ def loss_fn( noise, t, batch["source_lin"], batch["target_lin"] ) apply_fn = functools.partial( - state_neural_vector_field.apply_fn, {"params": params} + state_velocity_field.apply_fn, {"params": params} ) v_t = jax.vmap(apply_fn)( t=t, x=x_t, condition=batch["source_conditions"], rng=rng @@ -209,9 +209,9 @@ def loss_fn( noise = self.sample_noise(key_noise, batch_size) grad_fn = jax.value_and_grad(loss_fn) loss, grads = grad_fn( - state_neural_vector_field.params, t, noise, batch, keys_model + state_velocity_field.params, t, noise, batch, keys_model ) - return state_neural_vector_field.apply_gradients(grads=grads), loss + return state_velocity_field.apply_gradients(grads=grads), loss return step_fn @@ -237,8 +237,8 @@ def __call__(self, train_loader, valid_loader): (batch["source_lin"], batch["source_conditions"]), (batch["target_lin"], batch["target_conditions"]) ) - self.state_neural_vector_field, 
loss = self.step_fn( - rng_step_fn, self.state_neural_vector_field, batch + self.state_velocity_field, loss = self.step_fn( + rng_step_fn, self.state_velocity_field, batch ) curr_loss += loss if iter % self.logging_freq == 0: @@ -260,9 +260,7 @@ def __call__(self, train_loader, valid_loader): if iter % self.valid_freq == 0: self._valid_step(valid_loader, iter) if self.checkpoint_manager is not None: - states_to_save = { - "state_neural_vector_field": self.state_neural_vector_field - } + states_to_save = {"state_velocity_field": self.state_velocity_field} if self.state_eta is not None: states_to_save["state_eta"] = self.state_eta if self.state_xi is not None: @@ -279,7 +277,7 @@ def transport( """Transport data with the learnt map. This method solves the neural ODE parameterized by the - :attr:`~ott.neural.solvers.OTFlowMatching.neural_vector_field` from + :attr:`~ott.neural.solvers.OTFlowMatching.velocity_field` from :attr:`~ott.neural.flows.BaseTimeSampler.low` to :attr:`~ott.neural.flows.BaseTimeSampler.high` if `forward` is `True`, else the other way round. @@ -304,8 +302,8 @@ def transport( def solve_ode(input: jnp.ndarray, cond: jnp.ndarray): return diffrax.diffeqsolve( diffrax.ODETerm( - lambda t, x, args: self.state_neural_vector_field. - apply_fn({"params": self.state_neural_vector_field.params}, + lambda t, x, args: self.state_velocity_field. 
+ apply_fn({"params": self.state_velocity_field.params}, t=t, x=x, condition=cond) diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index a5521e1e6..98272c028 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -164,6 +164,7 @@ def match_pairs( linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) ).matrix + @jax.jit def match_pairs_filtered( x_lin: jnp.ndarray, x_quad: jnp.ndarray, y_lin: jnp.ndarray, y_quad: jnp.ndarray @@ -213,6 +214,7 @@ def _get_gromov_match_fn( else: scale_cost_xx = scale_cost_yy = scale_cost_xy = scale_cost + @jax.jit def match_pairs( x_lin: Optional[jnp.ndarray], x_quad: Tuple[jnp.ndarray, jnp.ndarray], From bebbbd0db409ce2001160da8e9ad6b2afbe931f6 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Dec 2023 18:02:09 +0100 Subject: [PATCH 068/186] fix OTFlowMatching.transport --- src/ott/neural/flows/models.py | 7 ++----- src/ott/neural/flows/otfm.py | 15 ++++++++------- src/ott/neural/models/models.py | 4 +--- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/src/ott/neural/flows/models.py b/src/ott/neural/flows/models.py index badc91232..bdc1ab4aa 100644 --- a/src/ott/neural/flows/models.py +++ b/src/ott/neural/flows/models.py @@ -62,7 +62,7 @@ class VelocityField(nn.Module): """ output_dim: int latent_embed_dim: int - condition_dim: Optional[int] = None + condition_dim: int = 0 condition_embed_dim: Optional[int] = None t_embed_dim: Optional[int] = None joint_hidden_dim: Optional[int] = None @@ -111,9 +111,6 @@ def __call__( Returns: Output of the neural vector field. 
""" - if self.condition_dim is None: - assert condition is None - t = flow_layers.CyclicalTimeEncoder(n_frequencies=self.n_frequencies)(t) t_layer = layers.MLPBlock( dim=self.t_embed_dim, @@ -131,7 +128,7 @@ def __call__( ) x = x_layer(x) - if self.condition_dim is not None: + if self.condition_dim > 0: condition_layer = layers.MLPBlock( dim=self.condition_embed_dim, out_dim=self.condition_embed_dim, diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index ad43736c6..1be3bdc16 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -272,20 +272,22 @@ def transport( data: jnp.array, condition: Optional[jnp.ndarray] = None, forward: bool = True, + t_0: float = 0.0, + t_1: float = 1.0, diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}) ) -> diffrax.Solution: """Transport data with the learnt map. - This method solves the neural ODE parameterized by the - :attr:`~ott.neural.solvers.OTFlowMatching.velocity_field` from - :attr:`~ott.neural.flows.BaseTimeSampler.low` to - :attr:`~ott.neural.flows.BaseTimeSampler.high` if `forward` is `True`, - else the other way round. + This method pushes-forward the `source` by + solving the neural ODE parameterized by the + :attr:`~ott.neural.flows.OTFlowMatching.velocity_field`. Args: data: Initial condition of the ODE. condition: Condition of the input data. forward: If `True` integrates forward, otherwise backwards. + t_0: Starting point of integration. + t_1: End point of integration. diffeqsolve_kwargs: Keyword arguments for the ODE solver. 
Returns: @@ -295,8 +297,7 @@ def transport( """ diffeqsolve_kwargs = dict(diffeqsolve_kwargs) - t0, t1 = (self.time_sampler.low, self.time_sampler.high - ) if forward else (self.time_sampler.high, self.time_sampler.low) + t0, t1 = (t_0, t_1) if forward else (t_1, t_0) @jax.jit def solve_ode(input: jnp.ndarray, cond: jnp.ndarray): diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index e84e3560a..56982270f 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -107,7 +107,7 @@ class RescalingMLP(nn.Module): act_fn: Activation function. Returns: - Non-negative rescaling factors. + Non-negative escaling factors. """ hidden_dim: int condition_dim: Optional[int] = None @@ -129,8 +129,6 @@ def __call__( Returns: Estimated rescaling factors. """ - if self.condition_dim is None: - assert condition is None x_layer = layers.MLPBlock( dim=self.hidden_dim, out_dim=self.hidden_dim, From f4c05c488978d842de988f967ab297e925b0d1ac Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Dec 2023 18:07:26 +0100 Subject: [PATCH 069/186] fix rescaling networks --- src/ott/neural/models/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 56982270f..93af5b58d 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -110,7 +110,7 @@ class RescalingMLP(nn.Module): Non-negative escaling factors. 
""" hidden_dim: int - condition_dim: Optional[int] = None + condition_dim: int = 0 num_layers_per_block: int = 3 act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.selu @@ -137,7 +137,7 @@ def __call__( ) x = x_layer(x) - if self.condition_dim is not None: + if self.condition_dim > 0: condition_layer = layers.MLPBlock( dim=self.hidden_dim, out_dim=self.hidden_dim, From 4d9992e61cdadb9d5d2633fd31456e918b0ab29b Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 5 Jan 2024 15:12:05 +0100 Subject: [PATCH 070/186] Update src/ott/neural/flows/flows.py Co-authored-by: nvesseron <96598529+nvesseron@users.noreply.github.com> --- src/ott/neural/flows/flows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ott/neural/flows/flows.py b/src/ott/neural/flows/flows.py index 0dce912aa..5ddaa56b8 100644 --- a/src/ott/neural/flows/flows.py +++ b/src/ott/neural/flows/flows.py @@ -101,7 +101,7 @@ class StraightFlow(BaseFlow, abc.ABC): def compute_mu_t( # noqa: D102 self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray ) -> jnp.ndarray: - return t * src + (1 - t) * tgt + return (1 - t) * src + t * tgt def compute_ut( self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray From 51221ddaa53061c1f8cc08f100e9616a26ba06e3 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 5 Jan 2024 15:12:24 +0100 Subject: [PATCH 071/186] Update src/ott/neural/flows/flows.py Co-authored-by: nvesseron <96598529+nvesseron@users.noreply.github.com> --- src/ott/neural/flows/flows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ott/neural/flows/flows.py b/src/ott/neural/flows/flows.py index 5ddaa56b8..572abba91 100644 --- a/src/ott/neural/flows/flows.py +++ b/src/ott/neural/flows/flows.py @@ -154,4 +154,4 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: Returns: Standard deviation of the probablity path at time :math:`t`. 
""" - return jnp.sqrt(self.sigma * t * (1 - t)) + return self.sigma * jnp.sqrt(t * (1 - t)) From 6c56dfe4c9809caf3f2f87f8d70254e3b14cd8f7 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Mon, 8 Jan 2024 19:13:36 +0100 Subject: [PATCH 072/186] test for scale_cost --- tests/neural/genot_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 7098e7419..e0a27ea72 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Iterator, Optional +from typing import Iterator, Optional, Union, Literal import pytest @@ -31,10 +31,11 @@ class TestGENOT: + @pytest.mark.parameterize("scale_cost", ["mean", 2.0]) @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) def test_genot_linear_unconditional( - self, genot_data_loader_linear: Iterator, k_samples_per_x: int, + self, genot_data_loader_linear: Iterator, scale_cost: Union[float, Literal["mean"]], k_samples_per_x: int, solver_latent_to_data: Optional[str] ): solver_latent_to_data = ( From cc045fa1ab741d8ce7bcfb3ab59d445a1b97b350 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 9 Jan 2024 11:20:13 +0100 Subject: [PATCH 073/186] update test for scale_cost --- tests/neural/genot_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index e0a27ea72..0ba932587 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -31,7 +31,7 @@ class TestGENOT: - @pytest.mark.parameterize("scale_cost", ["mean", 2.0]) + @pytest.mark.parametrize("scale_cost", ["mean", 2.0]) @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) def 
test_genot_linear_unconditional( @@ -67,7 +67,7 @@ def test_genot_linear_unconditional( ot_solver=ot_solver, epsilon=0.1, cost_fn=costs.SqEuclidean(), - scale_cost=1.0, + scale_cost=scale_cost, optimizer=optimizer, time_sampler=time_sampler, k_samples_per_x=k_samples_per_x, From f4de3394c7be6f67e91df605aaa156184385edf9 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 9 Jan 2024 15:01:24 +0100 Subject: [PATCH 074/186] fix bug for scale_cost --- src/ott/neural/flows/genot.py | 27 ++++++++++++++++----------- tests/neural/genot_test.py | 5 +++-- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index 622456a26..94f609c99 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -195,7 +195,7 @@ def setup(self): Parameters ---------- kwargs - Keyword arguments for the setup function + Keyword arguments for the setup function. """ self.state_velocity_field = ( self.velocity_field.create_train_state( @@ -205,7 +205,7 @@ def setup(self): self.step_fn = self._get_step_fn() if self.solver_latent_to_data is not None: self.match_latent_to_data_fn = self._get_sinkhorn_match_fn( - self.solver_latent_to_data, **self.kwargs_solver_latent_to_data + ot_solver=self.solver_latent_to_data, **self.kwargs_solver_latent_to_data ) else: self.match_latent_to_data_fn = lambda key, x, y, **_: (x, y) @@ -213,22 +213,27 @@ def setup(self): # TODO: add graph construction function if isinstance(self.ot_solver, sinkhorn.Sinkhorn): self.match_fn = self._get_sinkhorn_match_fn( - self.ot_solver, - self.epsilon, - self.cost_fn, - self.tau_a, - self.tau_b, - self.scale_cost, + ot_solver=self.ot_solver, + epsilon=self.epsilon, + cost_fn=self.cost_fn, + scale_cost=self.scale_cost, + tau_a=self.tau_a, + tau_b=self.tau_b, filter_input=True ) else: self.match_fn = self._get_gromov_match_fn( - self.ot_solver, self.cost_fn, self.tau_a, self.tau_b, self.scale_cost, - self.fused_penalty + 
ot_solver=self.ot_solver, cost_fn=self.cost_fn, scale_cost=self.scale_cost, tau_a=self.tau_a, tau_b=self.tau_b, + fused_penalty=self.fused_penalty ) def __call__(self, train_loader, valid_loader): - """Train GENOT.""" + """Train GENOT. + + Args: + train_loader: Data loader for the training data. + valid_loader: Data loader for the validation data. + """ batch: Dict[str, jnp.array] = {} for iteration in range(self.iterations): batch = next(train_loader) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 0ba932587..b3de698ad 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Iterator, Optional, Union, Literal +from typing import Iterator, Literal, Optional, Union import pytest @@ -35,7 +35,8 @@ class TestGENOT: @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) def test_genot_linear_unconditional( - self, genot_data_loader_linear: Iterator, scale_cost: Union[float, Literal["mean"]], k_samples_per_x: int, + self, genot_data_loader_linear: Iterator, + scale_cost: Union[float, Literal["mean"]], k_samples_per_x: int, solver_latent_to_data: Optional[str] ): solver_latent_to_data = ( From 5db4c730a385d9900c99518ef59b193b68f606d1 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 9 Jan 2024 15:01:59 +0100 Subject: [PATCH 075/186] fix bug for scale_cost --- src/ott/neural/flows/genot.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index 94f609c99..4efc4a9ad 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -205,7 +205,8 @@ def setup(self): self.step_fn = self._get_step_fn() if self.solver_latent_to_data is not None: self.match_latent_to_data_fn = self._get_sinkhorn_match_fn( 
- ot_solver=self.solver_latent_to_data, **self.kwargs_solver_latent_to_data + ot_solver=self.solver_latent_to_data, + **self.kwargs_solver_latent_to_data ) else: self.match_latent_to_data_fn = lambda key, x, y, **_: (x, y) @@ -223,13 +224,17 @@ def setup(self): ) else: self.match_fn = self._get_gromov_match_fn( - ot_solver=self.ot_solver, cost_fn=self.cost_fn, scale_cost=self.scale_cost, tau_a=self.tau_a, tau_b=self.tau_b, + ot_solver=self.ot_solver, + cost_fn=self.cost_fn, + scale_cost=self.scale_cost, + tau_a=self.tau_a, + tau_b=self.tau_b, fused_penalty=self.fused_penalty ) def __call__(self, train_loader, valid_loader): """Train GENOT. - + Args: train_loader: Data loader for the training data. valid_loader: Data loader for the validation data. From 72885ac75491d98e55b4f039c583134d3c2ba7b7 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 10 Jan 2024 17:52:27 +0100 Subject: [PATCH 076/186] jit solve_ode in genot --- src/ott/neural/flows/genot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index 4efc4a9ad..c12026251 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -418,6 +418,7 @@ def transport( axis=-1) t0, t1 = (0.0, 1.0) + @jax.jit def solve_ode(input: jnp.ndarray, cond: jnp.ndarray): return diffrax.diffeqsolve( diffrax.ODETerm( From 937fffcce22a3b200df826bf266879b0a3afb53a Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 7 Feb 2024 15:10:51 +0100 Subject: [PATCH 077/186] incorporate changes partially --- docs/references.bib | 52 ++++++------- pyproject.toml | 4 +- src/ott/__init__.py | 11 +-- src/ott/datasets.py | 2 +- src/ott/neural/duality/models.py | 2 +- src/ott/neural/duality/neuraldual.py | 16 +--- src/ott/neural/flows/flows.py | 18 +++-- src/ott/neural/flows/genot.py | 54 ++++++------- src/ott/neural/flows/models.py | 5 +- src/ott/neural/flows/otfm.py | 47 +++++------- src/ott/neural/flows/samplers.py | 6 +- 
src/ott/neural/gaps/map_estimator.py | 11 +-- src/ott/neural/models/base_solver.py | 51 ++++++------- src/ott/neural/models/layers.py | 17 +++-- src/ott/neural/models/models.py | 73 ++---------------- src/ott/solvers/linear/sinkhorn_lr.py | 11 +-- src/ott/solvers/quadratic/__init__.py | 7 +- .../quadratic/gromov_wasserstein_lr.py | 11 +-- .../tools/gaussian_mixture/fit_gmm_pair.py | 6 +- .../gaussian_mixture/gaussian_mixture.py | 7 +- tests/neural/conftest.py | 27 ++++--- tests/neural/genot_test.py | 14 ++-- tests/neural/otfm_test.py | 75 +++++++++---------- 23 files changed, 191 insertions(+), 336 deletions(-) diff --git a/docs/references.bib b/docs/references.bib index f161c1570..799e827d1 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -816,51 +816,51 @@ @misc{huguet:2023 } @misc{eyring:23, - author={Eyring, Luca and Klein, Dominik and Uscidda, Th{\'e}o and Palla, Giovanni and Kilbertus, Niki and Akata, Zeynep and Theis, Fabian}, - doi = {10.48550/arXiv.2311.15100}, + author = {Eyring, Luca and Klein, Dominik and Uscidda, Théo and Palla, Giovanni and Kilbertus, Niki and Akata, Zeynep and Theis, Fabian}, + doi = {10.48550/arXiv.2311.15100}, eprint = {2311.15100}, eprintclass = {stat.ML}, eprinttype = {arXiv}, - title={Unbalancedness in Neural Monge Maps Improves Unpaired Domain Translation}, - year={2023} + title = {Unbalancedness in Neural Monge Maps Improves Unpaired Domain Translation}, + year = {2023}, } @misc{klein_uscidda:23, - author={Dominik Klein and Théo Uscidda and Fabian Theis and Marco Cuturi}, - doi = {10.48550/arXiv.2310.09254}, - eprint={2310.09254}, - eprintclass = {stat.ML}, - eprinttype = {arXiv}, - title={Generative Entropic Neural Optimal Transport To Map Within and Across Spaces}, - year={2023}, + author = {Klein, Dominik and Uscidda, Théo and Theis, Fabian and Cuturi, Marco}, + doi = {10.48550/arXiv.2310.09254}, + eprint = {2310.09254}, + eprintclass = {stat.ML}, + eprinttype = {arXiv}, + title = {Generative Entropic Neural 
Optimal Transport To Map Within and Across Spaces}, + year = {2023}, } @misc{lipman:22, - author={Lipman, Yaron and Chen, Ricky TQ and Ben-Hamu, Heli and Nickel, Maximilian and Le, Matt}, - doi = {10.48550/arXiv.2210.02747}, - eprint={2210.02747}, + author = {Lipman, Yaron and Chen, Ricky TQ and Ben-Hamu, Heli and Nickel, Maximilian and Le, Matt}, + doi = {10.48550/arXiv.2210.02747}, + eprint = {2210.02747}, eprintclass = {stat.ML}, eprinttype = {arXiv}, - title={Flow matching for generative modeling}, - year={2022}, + title = {Flow matching for generative modeling}, + year = {2022}, } @misc{tong:23, - author={Tong, Alexander and Malkin, Nikolay and Huguet, Guillaume and Zhang, Yanlei and {Rector-Brooks}, Jarrid and Fatras, Kilian and Wolf, Guy and Bengio, Yoshua}, - doi={10.48550/arXiv.2302.00482}, - eprint={2302.00482}, + author = {Tong, Alexander and Malkin, Nikolay and Huguet, Guillaume and Zhang, Yanlei and {Rector-Brooks}, Jarrid and Fatras, Kilian and Wolf, Guy and Bengio, Yoshua}, + doi = {10.48550/arXiv.2302.00482}, + eprint = {2302.00482}, eprintclass = {stat.ML}, eprinttype = {arXiv}, - title={Improving and Generalizing Flow-Based Generative Models with Minibatch Optimal Transport}, - year={2023}, + title = {Improving and Generalizing Flow-Based Generative Models with Minibatch Optimal Transport}, + year = {2023}, } @misc{pooladian:23, - author={Pooladian, Aram-Alexandre and Ben-Hamu, Heli and Domingo-Enrich, Carles and Amos, Brandon and Lipman, Yaron and Chen, Ricky}, - doi={10.48550/arXiv.2304.14772}, - eprint={2304.14772}, + author = {Pooladian, Aram-Alexandre and Ben-Hamu, Heli and Domingo-Enrich, Carles and Amos, Brandon and Lipman, Yaron and Chen, Ricky}, + doi = {10.48550/arXiv.2304.14772}, + eprint = {2304.14772}, eprintclass = {stat.ML}, eprinttype = {arXiv}, - title={Multisample flow matching: Straightening flows with minibatch couplings}, - year={2023} + title = {Multisample flow matching: Straightening flows with minibatch couplings}, + year 
= {2023}, } diff --git a/pyproject.toml b/pyproject.toml index 1961a5971..5c128e7c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,7 +103,6 @@ include = '\.ipynb$' [tool.isort] profile = "black" -line_length = 80 include_trailing_comma = true multi_line_output = 3 sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TEST", "NUMERIC", "NEURAL", "PLOTTING", "FIRSTPARTY", "LOCALFOLDER"] @@ -289,6 +288,7 @@ ignore = [ line-length = 80 select = [ "D", # flake8-docstrings + "I", # isort "E", # pycodestyle "F", # pyflakes "W", # pycodestyle @@ -304,7 +304,7 @@ select = [ "T20", # flake8-print "RET", # flake8-raise ] -unfixable = ["I", "B", "UP", "C4", "BLE", "T20", "RET"] +unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"] target-version = "py38" [tool.ruff.per-file-ignores] # TODO(michalk8): PO004 - remove `self.initialize` diff --git a/src/ott/__init__.py b/src/ott/__init__.py index dac0eb854..8d2f007c5 100644 --- a/src/ott/__init__.py +++ b/src/ott/__init__.py @@ -13,16 +13,7 @@ # limitations under the License. import contextlib -from . import ( - datasets, - geometry, - initializers, - math, - problems, - solvers, - tools, - utils, -) +from . 
import datasets, geometry, initializers, math, problems, solvers, tools, utils with contextlib.suppress(ImportError): # TODO(michalk8): add warning that neural module is not imported diff --git a/src/ott/datasets.py b/src/ott/datasets.py index e5077c87c..36ac6b561 100644 --- a/src/ott/datasets.py +++ b/src/ott/datasets.py @@ -61,7 +61,7 @@ class GaussianMixture: scale: float = 5.0 std: float = 0.5 - def __post_init__(self): + def __post_init__(self) -> None: gaussian_centers = { "simple": np.array([[0, 0]]), diff --git a/src/ott/neural/duality/models.py b/src/ott/neural/duality/models.py index baa0386c8..5c18f5eb0 100644 --- a/src/ott/neural/duality/models.py +++ b/src/ott/neural/duality/models.py @@ -69,7 +69,7 @@ class ICNN(neuraldual.BaseW2NeuralDual): def is_potential(self) -> bool: # noqa: D102 return True - def setup(self): # noqa: D102 + def setup(self) -> None: # noqa: D102 self.num_hidden = len(self.dim_hidden) if self.pos_weights: diff --git a/src/ott/neural/duality/neuraldual.py b/src/ott/neural/duality/neuraldual.py index a8f5fd273..573e7b2bb 100644 --- a/src/ott/neural/duality/neuraldual.py +++ b/src/ott/neural/duality/neuraldual.py @@ -13,17 +13,7 @@ # limitations under the License. 
import abc import warnings -from typing import ( - Any, - Callable, - Dict, - Iterator, - List, - Literal, - Optional, - Tuple, - Union, -) +from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -290,7 +280,7 @@ def setup( dim_data: int, optimizer_f: optax.OptState, optimizer_g: optax.OptState, - ): + ) -> None: """Setup all components required to train the network.""" # split random number generator rng, rng_f, rng_g = jax.random.split(rng, 3) @@ -695,7 +685,7 @@ def _update_logs( loss_f: jnp.ndarray, loss_g: jnp.ndarray, w_dist: jnp.ndarray, - ): + ) -> None: logs["loss_f"].append(float(loss_f)) logs["loss_g"].append(float(loss_g)) logs["w_dist"].append(float(w_dist)) diff --git a/src/ott/neural/flows/flows.py b/src/ott/neural/flows/flows.py index 572abba91..c379dcbc3 100644 --- a/src/ott/neural/flows/flows.py +++ b/src/ott/neural/flows/flows.py @@ -39,7 +39,7 @@ def compute_mu_t( ) -> jnp.ndarray: """Compute the mean of the probablitiy path. - Compute the mean of the probablitiy path between :math:`x` and :math:`y` + Compute the mean of the probablitiy path between :math:`x_0` and :math:`x_1` at time :math:`t`. Args: @@ -69,6 +69,9 @@ def compute_ut( t: Time :math:`t`. src: Sample from the source distribution. tgt: Sample from the target distribution. + + Returns: + Conditional vector field evaluated at time :math:`t`. """ def compute_xt( @@ -101,7 +104,7 @@ class StraightFlow(BaseFlow, abc.ABC): def compute_mu_t( # noqa: D102 self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray ) -> jnp.ndarray: - return (1 - t) * src + t * tgt + return (1.0 - t) * src + t * tgt def compute_ut( self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray @@ -119,6 +122,7 @@ def compute_ut( Returns: Conditional vector field evaluated at time :math:`t`. 
""" + del t return tgt - src @@ -134,15 +138,19 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: Returns: Constant, time-independent standard deviation :math:`\sigma`. """ - return self.sigma + return jnp.full_like(t, fill_value=self.sigma) class BrownianNoiseFlow(StraightFlow): r"""Brownian Bridge Flow. Sampler for sampling noise implicitly defined by a Schroedinger Bridge - problem with parameter `\sigma` such that + problem with parameter :math:`\sigma` such that :math:`\sigma_t = \sigma * \sqrt(t * (1-t))`. + + Returns: + Samples from the probability path between :math:`x_0` and :math:`x_1` + at time :math:`t`. """ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: @@ -154,4 +162,4 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: Returns: Standard deviation of the probablity path at time :math:`t`. """ - return self.sigma * jnp.sqrt(t * (1 - t)) + return self.sigma * jnp.sqrt(t * (1.0 - t)) diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index c12026251..b11b77b20 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -21,18 +21,12 @@ import diffrax import optax from flax.training import train_state -from flax.training.train_state import TrainState from orbax import checkpoint from ott import utils from ott.geometry import costs -from ott.neural.flows.flows import BaseFlow, ConstantNoiseFlow -from ott.neural.flows.samplers import sample_uniformly -from ott.neural.models.base_solver import ( - BaseNeuralSolver, - ResampleMixin, - UnbalancednessMixin, -) +from ott.neural.flows import flows, samplers +from ott.neural.models import base_solver from ott.solvers import was_solver from ott.solvers.linear import sinkhorn from ott.solvers.quadratic import gromov_wasserstein @@ -40,7 +34,10 @@ __all__ = ["GENOT"] -class GENOT(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): +class GENOT( + base_solver.UnbalancednessMixin, base_solver.ResampleMixin, + base_solver.BaseNeuralSolver 
+): """The GENOT training class as introduced in :cite:`klein_uscidda:23`. Args: @@ -81,15 +78,15 @@ class GENOT(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): fused_penalty: Fused penalty of the linear/fused term in the Fused Gromov-Wasserstein problem. tau_a: If :math:`<1`, defines how much unbalanced the problem is - on the first marginal. + on the first marginal. tau_b: If :math:`< 1`, defines how much unbalanced the problem is - on the second marginal. + on the second marginal. rescaling_a: Neural network to learn the left rescaling function. If :obj:`None`, the left rescaling factor is not learnt. rescaling_b: Neural network to learn the right rescaling function. If :obj:`None`, the right rescaling factor is not learnt. unbalanced_kwargs: Keyword arguments for the unbalancedness solver. - callback_fn: Callback function. + callback_fn: Callback function. rng: Random number generator. """ @@ -103,7 +100,7 @@ def __init__( cond_dim: int, iterations: int, valid_freq: int, - ot_solver: Type[was_solver.WassersteinSolver], + ot_solver: was_solver.WassersteinSolver, epsilon: float, cost_fn: Union[costs.CostFn, Dict[str, costs.CostFn]], scale_cost: Union[Union[bool, int, float, @@ -112,9 +109,10 @@ def __init__( Dict[str, Union[bool, int, float, Literal["mean", "max_norm", "max_bound", "max_cost", "median"]]]], - optimizer: Type[optax.GradientTransformation], - flow: Type[BaseFlow] = ConstantNoiseFlow(0.0), - time_sampler: Callable[[jax.Array, int], jnp.ndarray] = sample_uniformly, + optimizer: optax.GradientTransformation, + flow: Type[flows.BaseFlow] = flows.ConstantNoiseFlow(0.0), # noqa: B008 + time_sampler: Callable[[jax.Array, int], + jnp.ndarray] = samplers.uniform_sampler, checkpoint_manager: Type[checkpoint.CheckpointManager] = None, k_samples_per_x: int = 1, solver_latent_to_data: Optional[Type[was_solver.WassersteinSolver] @@ -132,11 +130,11 @@ def __init__( ): rng = utils.default_prng_key(rng) rng, rng_unbalanced = jax.random.split(rng) - 
BaseNeuralSolver.__init__( + base_solver.BaseNeuralSolver.__init__( self, iterations=iterations, valid_freq=valid_freq ) - ResampleMixin.__init__(self) - UnbalancednessMixin.__init__( + base_solver.ResampleMixin.__init__(self) + base_solver.UnbalancednessMixin.__init__( self, rng=rng_unbalanced, source_dim=input_dim, @@ -159,7 +157,7 @@ def __init__( self.rng = utils.default_prng_key(rng) self.velocity_field = velocity_field - self.state_velocity_field: Optional[TrainState] = None + self.state_velocity_field: Optional[train_state.TrainState] = None self.flow = flow self.time_sampler = time_sampler self.optimizer = optimizer @@ -189,14 +187,8 @@ def __init__( self.callback_fn = callback_fn self.setup() - def setup(self): - """Set up the model. - - Parameters - ---------- - kwargs - Keyword arguments for the setup function. - """ + def setup(self) -> None: + """Set up the model.""" self.state_velocity_field = ( self.velocity_field.create_train_state( self.rng, self.optimizer, self.output_dim @@ -341,7 +333,7 @@ def _get_step_fn(self) -> Callable: @jax.jit def step_fn( - key: jax.random.PRNGKeyArray, + rng: jax.Array, state_velocity_field: train_state.TrainState, batch: Dict[str, jnp.array], ): @@ -370,7 +362,7 @@ def loss_fn( ) return jnp.mean((v_t - u_t) ** 2) - keys_model = jax.random.split(key, len(batch["noise"])) + keys_model = jax.random.split(rng, len(batch["noise"])) grad_fn = jax.value_and_grad(loss_fn, has_aux=False) loss, grads = grad_fn(state_velocity_field.params, batch, keys_model) @@ -419,7 +411,7 @@ def transport( t0, t1 = (0.0, 1.0) @jax.jit - def solve_ode(input: jnp.ndarray, cond: jnp.ndarray): + def solve_ode(input: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: return diffrax.diffeqsolve( diffrax.ODETerm( lambda t, x, args: self.state_velocity_field. 
diff --git a/src/ott/neural/flows/models.py b/src/ott/neural/flows/models.py index bdc1ab4aa..6970e8368 100644 --- a/src/ott/neural/flows/models.py +++ b/src/ott/neural/flows/models.py @@ -70,15 +70,12 @@ class VelocityField(nn.Module): act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu n_frequencies: int = 128 - def __post_init__(self): - - # set embedded dim from latent embedded dim + def __post_init__(self) -> None: if self.condition_embed_dim is None: self.condition_embed_dim = self.latent_embed_dim if self.t_embed_dim is None: self.t_embed_dim = self.latent_embed_dim - # set joint hidden dim from all embedded dim concat_embed_dim = ( self.latent_embed_dim + self.condition_embed_dim + self.t_embed_dim ) diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index 1be3bdc16..2ec4707c6 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -11,20 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import collections import functools import types -from collections import defaultdict -from typing import ( - Any, - Callable, - Dict, - Literal, - Mapping, - Optional, - Tuple, - Type, - Union, -) +from typing import Any, Callable, Dict, Literal, Mapping, Optional, Tuple, Type, Union import jax import jax.numpy as jnp @@ -36,18 +26,17 @@ from ott import utils from ott.geometry import costs -from ott.neural.flows.flows import BaseFlow -from ott.neural.models.base_solver import ( - BaseNeuralSolver, - ResampleMixin, - UnbalancednessMixin, -) +from ott.neural.flows import flows +from ott.neural.models import base_solver from ott.solvers import was_solver __all__ = ["OTFlowMatching"] -class OTFlowMatching(UnbalancednessMixin, ResampleMixin, BaseNeuralSolver): +class OTFlowMatching( + base_solver.UnbalancednessMixin, base_solver.ResampleMixin, + base_solver.BaseNeuralSolver +): """(Optimal transport) flow matching class. Flow matching as introduced in :cite:`lipman:22`, with extension to OT-FM @@ -97,9 +86,9 @@ def __init__( cond_dim: int, iterations: int, ot_solver: Optional[Type[was_solver.WassersteinSolver]], - flow: Type[BaseFlow], + flow: Type[flows.BaseFlow], time_sampler: Callable[[jax.Array, int], jnp.ndarray], - optimizer: Type[optax.GradientTransformation], + optimizer: optax.GradientTransformation, checkpoint_manager: Type[checkpoint.CheckpointManager] = None, epsilon: float = 1e-2, cost_fn: Optional[Type[costs.CostFn]] = None, @@ -120,11 +109,11 @@ def __init__( ): rng = utils.default_prng_key(rng) rng, rng_unbalanced = jax.random.split(rng) - BaseNeuralSolver.__init__( + base_solver.BaseNeuralSolver.__init__( self, iterations=iterations, valid_freq=valid_freq ) - ResampleMixin.__init__(self) - UnbalancednessMixin.__init__( + base_solver.ResampleMixin.__init__(self) + base_solver.UnbalancednessMixin.__init__( self, rng=rng_unbalanced, source_dim=input_dim, @@ -151,11 +140,11 @@ def __init__( self.rng = rng self.logging_freq = logging_freq 
self.num_eval_samples = num_eval_samples - self._training_logs: Mapping[str, Any] = defaultdict(list) + self._training_logs: Mapping[str, Any] = collections.defaultdict(list) self.setup() - def setup(self): + def setup(self) -> None: """Setup :class:`OTFlowMatching`.""" self.state_velocity_field = ( self.velocity_field.create_train_state( @@ -180,7 +169,7 @@ def _get_step_fn(self) -> Callable: @jax.jit def step_fn( - key: jax.random.PRNGKeyArray, + rng: jax.Array, state_velocity_field: train_state.TrainState, batch: Dict[str, jnp.ndarray], ) -> Tuple[Any, Any]: @@ -203,7 +192,7 @@ def loss_fn( return jnp.mean((v_t - u_t) ** 2) batch_size = len(batch["source_lin"]) - key_noise, key_t, key_model = jax.random.split(key, 3) + key_noise, key_t, key_model = jax.random.split(rng, 3) keys_model = jax.random.split(key_model, batch_size) t = self.time_sampler(key_t, batch_size) noise = self.sample_noise(key_noise, batch_size) @@ -300,7 +289,7 @@ def transport( t0, t1 = (t_0, t_1) if forward else (t_1, t_0) @jax.jit - def solve_ode(input: jnp.ndarray, cond: jnp.ndarray): + def solve_ode(input: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: return diffrax.diffeqsolve( diffrax.ODETerm( lambda t, x, args: self.state_velocity_field. diff --git a/src/ott/neural/flows/samplers.py b/src/ott/neural/flows/samplers.py index f5d0e0d17..1bfee16b4 100644 --- a/src/ott/neural/flows/samplers.py +++ b/src/ott/neural/flows/samplers.py @@ -16,16 +16,16 @@ import jax import jax.numpy as jnp -__all__ = ["sample_uniformly"] +__all__ = ["uniform_sampler"] -def sample_uniformly( +def uniform_sampler( rng: jax.Array, num_samples: int, low: float = 0.0, high: float = 1.0, offset: Optional[float] = None -): +) -> jnp.ndarray: """Sample from a uniform distribution. 
Sample :math:`t` from a uniform distribution :math:`[low, high]` with diff --git a/src/ott/neural/gaps/map_estimator.py b/src/ott/neural/gaps/map_estimator.py index cfcc8cb86..13dbc4ef4 100644 --- a/src/ott/neural/gaps/map_estimator.py +++ b/src/ott/neural/gaps/map_estimator.py @@ -13,16 +13,7 @@ # limitations under the License. import collections import functools -from typing import ( - Any, - Callable, - Dict, - Iterator, - Optional, - Sequence, - Tuple, - Union, -) +from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Union import jax import jax.numpy as jnp diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index 98272c028..071dc6e07 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -13,7 +13,6 @@ # limitations under the License. import abc import pathlib -import types from typing import Any, Callable, Dict, Literal, Mapping, Optional, Tuple, Union import jax @@ -43,7 +42,7 @@ def __init__(self, iterations: int, valid_freq: int, **_: Any): self.valid_freq = valid_freq @abc.abstractmethod - def setup(self, *args: Any, **kwargs: Any): + def setup(self, *args: Any, **kwargs: Any) -> None: """Setup the model.""" @abc.abstractmethod @@ -73,14 +72,14 @@ class ResampleMixin: def _resample_data( self, - key: jax.random.KeyArray, + rng: jax.Array, tmat: jnp.ndarray, source_arrays: Tuple[jnp.ndarray, ...], target_arrays: Tuple[jnp.ndarray, ...], ) -> Tuple[jnp.ndarray, ...]: """Resample a batch according to coupling `tmat`.""" tmat_flattened = tmat.flatten() - indices = jax.random.choice(key, len(tmat_flattened), shape=[tmat.shape[0]]) + indices = jax.random.choice(rng, len(tmat_flattened), shape=[tmat.shape[0]]) indices_source = indices // tmat.shape[1] indices_target = indices % tmat.shape[1] return tuple( @@ -91,7 +90,7 @@ def _resample_data( def _sample_conditional_indices_from_tmap( self, - key: jax.random.PRNGKeyArray, + rng: jax.Array, tmat: jnp.ndarray, 
k_samples_per_x: Union[int, jnp.ndarray], source_arrays: Tuple[jnp.ndarray, ...], @@ -102,22 +101,19 @@ def _sample_conditional_indices_from_tmap( batch_size = tmat.shape[0] left_marginals = tmat.sum(axis=1) if not source_is_balanced: - key, key2 = jax.random.split(key, 2) + rng, key2 = jax.random.split(rng, 2) indices = jax.random.choice( key=key2, a=jnp.arange(len(left_marginals)), p=left_marginals, shape=(len(left_marginals),) ) + tmat_adapted = tmat[indices] else: - indices = jnp.arange(batch_size) - tmat_adapted = tmat[indices] + tmat_adapted = tmat indices_per_row = jax.vmap( - lambda tmat_adapted: jax.random.choice( - key=key, - a=jnp.arange(batch_size), - p=tmat_adapted, - shape=(k_samples_per_x,) + lambda row: jax.random.choice( + key=rng, a=jnp.arange(batch_size), p=row, shape=(k_samples_per_x,) ), in_axes=0, out_axes=0, @@ -134,8 +130,8 @@ def _sample_conditional_indices_from_tmap( -1)) if b is not None else None for b in source_arrays ), tuple( - jnp.reshape(b[indices_target, :], (k_samples_per_x, batch_size, - -1)) if b is not None else None + jnp.reshape(b[indices_target], (k_samples_per_x, batch_size, + -1)) if b is not None else None for b in target_arrays ) @@ -154,9 +150,7 @@ def _get_sinkhorn_match_fn( ) -> Callable: @jax.jit - def match_pairs( - x: jnp.ndarray, y: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + def match_pairs(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: geom = pointcloud.PointCloud( x, y, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn ) @@ -168,7 +162,7 @@ def match_pairs( def match_pairs_filtered( x_lin: jnp.ndarray, x_quad: jnp.ndarray, y_lin: jnp.ndarray, y_quad: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + ) -> jnp.ndarray: geom = pointcloud.PointCloud( x_lin, y_lin, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn ) @@ -220,7 +214,7 @@ def match_pairs( x_quad: Tuple[jnp.ndarray, jnp.ndarray], y_lin: Optional[jnp.ndarray], y_quad: 
Tuple[jnp.ndarray, jnp.ndarray], - ) -> Tuple[jnp.array, jnp.array]: + ) -> jnp.ndarray: geom_xx = pointcloud.PointCloud( x=x_quad, y=x_quad, cost_fn=cost_fn_xx, scale_cost=scale_cost_xx ) @@ -262,14 +256,12 @@ def __init__( jnp.ndarray]] = None, rescaling_b: Optional[Callable[[jnp.ndarray, Optional[jnp.ndarray]], jnp.ndarray]] = None, - seed: Optional[int] = None, opt_eta: Optional[optax.GradientTransformation] = None, opt_xi: Optional[optax.GradientTransformation] = None, resample_epsilon: float = 1e-2, scale_cost: Union[bool, int, float, Literal["mean", "max_cost", "median"]] = "mean", - sinkhorn_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - **_: Any, + **kwargs: Mapping[str, Any], ): self.rng_unbalanced = rng self.source_dim = source_dim @@ -279,7 +271,6 @@ def __init__( self.tau_b = tau_b self.rescaling_a = rescaling_a self.rescaling_b = rescaling_b - self.seed = seed self.opt_eta = opt_eta self.opt_xi = opt_xi self.resample_epsilon = resample_epsilon @@ -290,7 +281,7 @@ def __init__( tau_b=tau_b, resample_epsilon=resample_epsilon, scale_cost=scale_cost, - sinkhorn_kwargs=sinkhorn_kwargs + sinkorn_kwargs=kwargs ) self._setup(source_dim=source_dim, target_dim=target_dim, cond_dim=cond_dim) @@ -301,7 +292,7 @@ def _get_compute_unbalanced_marginals( resample_epsilon: float, scale_cost: Union[bool, int, float, Literal["mean", "max_cost", "median"]] = "mean", - sinkhorn_kwargs: Dict[str, Any] = types.MappingProxyType({}), + **kwargs: Dict[str, Any], ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Compute the unbalanced source and target marginals for a batch.""" @@ -315,23 +306,23 @@ def compute_unbalanced_marginals( epsilon=resample_epsilon, scale_cost=scale_cost ) - out = sinkhorn.Sinkhorn(**sinkhorn_kwargs)( + out = sinkhorn.Sinkhorn(**kwargs)( linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) ) - return out.matrix.sum(axis=1), out.matrix.sum(axis=0) + return out.marginal(axis=1), out.marginal(axis=0) return compute_unbalanced_marginals @jax.jit 
def _resample_unbalanced( self, - key: jax.random.KeyArray, + rng: jax.Array, batch: Tuple[jnp.ndarray, ...], marginals: jnp.ndarray, ) -> Tuple[jnp.ndarray, ...]: """Resample a batch based on marginals.""" indices = jax.random.choice( - key, a=len(marginals), p=jnp.squeeze(marginals), shape=[len(marginals)] + rng, a=len(marginals), p=jnp.squeeze(marginals), shape=[len(marginals)] ) return tuple(b[indices] if b is not None else None for b in batch) diff --git a/src/ott/neural/models/layers.py b/src/ott/neural/models/layers.py index 46313b0e2..d0352ff05 100644 --- a/src/ott/neural/models/layers.py +++ b/src/ott/neural/models/layers.py @@ -11,23 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Tuple +from typing import Any -import jax import jax.numpy as jnp import flax.linen as nn __all__ = ["MLPBlock"] -PRNGKey = jax.Array -Shape = Tuple[int, ...] -Dtype = Any -Array = Any - class MLPBlock(nn.Module): - """An MLP block.""" + """An MLP block. + + Args: + dim: Dimensionality of the input data. + num_layers: Number of layers in the MLP block. + act_fn: Activation function. + out_dim: Dimensionality of the output data. + """ dim: int = 128 num_layers: int = 3 act_fn: Any = nn.silu diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/models.py index 93af5b58d..0acb7daae 100644 --- a/src/ott/neural/models/models.py +++ b/src/ott/neural/models/models.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Optional, Sequence +from typing import Callable, Optional import jax import jax.numpy as jnp @@ -22,70 +22,7 @@ from ott.neural.models import layers -__all__ = ["MLP", "RescalingMLP"] - - -class MLP(nn.Module): - """A generic, not-convex MLP. - - Args: - dim_hidden: sequence specifying size of hidden dimensions. The output - dimension of the last layer is automatically set to 1 if - :attr:`is_potential` is ``True``, or the dimension of the input otherwise - is_potential: Model the potential if ``True``, otherwise - model the gradient of the potential - act_fn: Activation function - """ - - dim_hidden: Sequence[int] - is_potential: bool = True - act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.leaky_relu - - @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # noqa: D102 - squeeze = x.ndim == 1 - if squeeze: - x = jnp.expand_dims(x, 0) - assert x.ndim == 2, x.ndim - n_input = x.shape[-1] - - z = x - for n_hidden in self.dim_hidden: - Wx = nn.Dense(n_hidden, use_bias=True) - z = self.act_fn(Wx(z)) - - if self.is_potential: - Wx = nn.Dense(1, use_bias=True) - z = Wx(z).squeeze(-1) - - quad_term = 0.5 * jax.vmap(jnp.dot)(x, x) - z += quad_term - else: - Wx = nn.Dense(n_input, use_bias=True) - z = x + Wx(z) - - return z.squeeze(0) if squeeze else z - - def create_train_state( - self, - rng: jax.Array, - optimizer: optax.OptState, - input_dim: int, - ) -> train_state.TrainState: - """Create the training state. - - Args: - rng: Random number generator. - optimizer: Optimizer. - input_dim: Dimensionality of the input. - - Returns: - Training state. 
- """ - params = self.init(rng, jnp.ones(input_dim))["params"] - return train_state.TrainState.create( - apply_fn=self.apply, params=params, tx=optimizer - ) +__all__ = ["RescalingMLP"] class RescalingMLP(nn.Module): @@ -119,12 +56,12 @@ def __call__( self, x: jnp.ndarray, condition: Optional[jnp.ndarray] = None - ) -> jnp.ndarray: # noqa: D102 + ) -> jnp.ndarray: """Forward pass through the rescaling network. Args: - x: Data. - condition: Condition. + x: Data of shape ``[n, ...]``. + condition: Condition of shape ``[n, condition_dim]``. Returns: Estimated rescaling factors. diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index db948cf8b..45b4e4721 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -11,16 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import ( - Any, - Callable, - Literal, - Mapping, - NamedTuple, - Optional, - Tuple, - Union, -) +from typing import Any, Callable, Literal, Mapping, NamedTuple, Optional, Tuple, Union import jax import jax.experimental diff --git a/src/ott/solvers/quadratic/__init__.py b/src/ott/solvers/quadratic/__init__.py index 560ac3ddd..507812971 100644 --- a/src/ott/solvers/quadratic/__init__.py +++ b/src/ott/solvers/quadratic/__init__.py @@ -11,10 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import ( - gromov_wasserstein, - gromov_wasserstein_lr, - gw_barycenter, - lower_bound, -) +from . 
import gromov_wasserstein, gromov_wasserstein_lr, gw_barycenter, lower_bound from ._solve import solve diff --git a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py index 214853f4c..ad8c4130a 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py @@ -12,16 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """A Jax implementation of the unbalanced low-rank GW algorithm.""" -from typing import ( - Any, - Callable, - Literal, - Mapping, - NamedTuple, - Optional, - Tuple, - Union, -) +from typing import Any, Callable, Literal, Mapping, NamedTuple, Optional, Tuple, Union import jax import jax.experimental diff --git a/src/ott/tools/gaussian_mixture/fit_gmm_pair.py b/src/ott/tools/gaussian_mixture/fit_gmm_pair.py index 7ecde263c..0c3c78ba3 100644 --- a/src/ott/tools/gaussian_mixture/fit_gmm_pair.py +++ b/src/ott/tools/gaussian_mixture/fit_gmm_pair.py @@ -84,11 +84,7 @@ import jax import jax.numpy as jnp -from ott.tools.gaussian_mixture import ( - fit_gmm, - gaussian_mixture, - gaussian_mixture_pair, -) +from ott.tools.gaussian_mixture import fit_gmm, gaussian_mixture, gaussian_mixture_pair __all__ = ["get_fit_model_em_fn"] diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture.py b/src/ott/tools/gaussian_mixture/gaussian_mixture.py index 313689939..576d937c8 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture.py @@ -16,12 +16,7 @@ import jax import jax.numpy as jnp -from ott.tools.gaussian_mixture import ( - gaussian, - linalg, - probabilities, - scale_tril, -) +from ott.tools.gaussian_mixture import gaussian, linalg, probabilities, scale_tril __all__ = ["GaussianMixture"] diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 723d25393..05cd38af1 100644 --- a/tests/neural/conftest.py +++ 
b/tests/neural/conftest.py @@ -15,7 +15,7 @@ import numpy as np -from ott.neural.data.dataloaders import ConditionalDataLoader, OTDataLoader +from ott.neural.data import dataloaders @pytest.fixture(scope="module") @@ -24,7 +24,7 @@ def data_loader_gaussian(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 - return OTDataLoader(16, source_lin=source, target_lin=target) + return dataloaders.OTDataLoader(16, source_lin=source, target_lin=target) @pytest.fixture(scope="module") @@ -36,20 +36,23 @@ def data_loader_gaussian_conditional(): source_1 = rng.normal(size=(100, 2)) target_1 = rng.normal(size=(100, 2)) - 2.0 - dl0 = OTDataLoader( + dl0 = dataloaders.OTDataLoader( 16, source_lin=source_0, target_lin=target_0, source_conditions=np.zeros_like(source_0) * 0.0 ) - dl1 = OTDataLoader( + dl1 = dataloaders.OTDataLoader( 16, source_lin=source_1, target_lin=target_1, source_conditions=np.ones_like(source_1) * 1.0 ) - return ConditionalDataLoader({"0": dl0, "1": dl1}, np.array([0.5, 0.5])) + return dataloaders.ConditionalDataLoader({ + "0": dl0, + "1": dl1 + }, np.array([0.5, 0.5])) @pytest.fixture(scope="module") @@ -60,7 +63,7 @@ def data_loader_gaussian_with_conditions(): target = rng.normal(size=(100, 2)) + 1.0 source_conditions = rng.normal(size=(100, 1)) target_conditions = rng.normal(size=(100, 1)) - 1.0 - return OTDataLoader( + return dataloaders.OTDataLoader( 16, source_lin=source, target_lin=target, @@ -75,7 +78,7 @@ def genot_data_loader_linear(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 - return OTDataLoader(16, source_lin=source, target_lin=target) + return dataloaders.OTDataLoader(16, source_lin=source, target_lin=target) @pytest.fixture(scope="module") @@ -85,7 +88,7 @@ def genot_data_loader_linear_conditional(): source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 source_conditions = 
rng.normal(size=(100, 4)) - return OTDataLoader( + return dataloaders.OTDataLoader( 16, source_lin=source, target_lin=target, @@ -99,7 +102,7 @@ def genot_data_loader_quad(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 1)) + 1.0 - return OTDataLoader(16, source_quad=source, target_quad=target) + return dataloaders.OTDataLoader(16, source_quad=source, target_quad=target) @pytest.fixture(scope="module") @@ -109,7 +112,7 @@ def genot_data_loader_quad_conditional(): source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 1)) + 1.0 source_conditions = rng.normal(size=(100, 7)) - return OTDataLoader( + return dataloaders.OTDataLoader( 16, source_quad=source, target_quad=target, @@ -125,7 +128,7 @@ def genot_data_loader_fused(): target_q = rng.normal(size=(100, 1)) + 1.0 source_lin = rng.normal(size=(100, 2)) target_lin = rng.normal(size=(100, 2)) + 1.0 - return OTDataLoader( + return dataloaders.OTDataLoader( 16, source_lin=source_lin, source_quad=source_q, @@ -143,7 +146,7 @@ def genot_data_loader_fused_conditional(): source_lin = rng.normal(size=(100, 2)) target_lin = rng.normal(size=(100, 2)) + 1.0 source_conditions = rng.normal(size=(100, 7)) - return OTDataLoader( + return dataloaders.OTDataLoader( 16, source_lin=source_lin, source_quad=source_q, diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index b3de698ad..5a5b9a847 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -23,7 +23,7 @@ from ott.geometry import costs from ott.neural.flows.genot import GENOT from ott.neural.flows.models import VelocityField -from ott.neural.flows.samplers import sample_uniformly +from ott.neural.flows.samplers import uniform_sampler from ott.neural.models.models import RescalingMLP from ott.solvers.linear import sinkhorn from ott.solvers.quadratic import gromov_wasserstein @@ -56,7 +56,7 @@ def test_genot_linear_unconditional( latent_embed_dim=5, ) ot_solver = 
sinkhorn.Sinkhorn() - time_sampler = sample_uniformly + time_sampler = uniform_sampler optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -106,7 +106,7 @@ def test_genot_quad_unconditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) - time_sampler = functools.partial(sample_uniformly, offset=1e-2) + time_sampler = functools.partial(uniform_sampler, offset=1e-2) optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -198,7 +198,7 @@ def test_genot_linear_conditional( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - time_sampler = sample_uniformly + time_sampler = uniform_sampler optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -245,7 +245,7 @@ def test_genot_quad_conditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) - time_sampler = sample_uniformly + time_sampler = uniform_sampler optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -292,7 +292,7 @@ def test_genot_fused_conditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) - time_sampler = sample_uniformly + time_sampler = uniform_sampler optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -348,7 +348,7 @@ def test_genot_linear_learn_rescaling( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - time_sampler = sample_uniformly + time_sampler = uniform_sampler optimizer = optax.adam(learning_rate=1e-3) tau_a = 0.9 tau_b = 0.2 diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 1230a638b..7f6a1a8dc 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -20,36 +20,31 @@ import optax -from ott.neural.flows.flows import ( - BaseFlow, - BrownianNoiseFlow, - ConstantNoiseFlow, -) -from ott.neural.flows.models import VelocityField -from ott.neural.flows.otfm import OTFlowMatching -from ott.neural.flows.samplers import sample_uniformly -from 
ott.neural.models.models import RescalingMLP +from ott.neural.flows import flows, models, otfm, samplers from ott.solvers.linear import sinkhorn class TestOTFlowMatching: @pytest.mark.parametrize( - "flow", - [ConstantNoiseFlow(0.0), - ConstantNoiseFlow(1.0), - BrownianNoiseFlow(0.2)] + "flow", [ + flows.ConstantNoiseFlow(0.0), + flows.ConstantNoiseFlow(1.0), + flows.BrownianNoiseFlow(0.2) + ] ) - def test_flow_matching(self, data_loader_gaussian, flow: Type[BaseFlow]): - neural_vf = VelocityField( + def test_flow_matching( + self, data_loader_gaussian, flow: Type[flows.BaseFlow] + ): + neural_vf = models.VelocityField( output_dim=2, condition_dim=0, latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - time_sampler = sample_uniformly + time_sampler = samplers.uniform_sampler optimizer = optax.adam(learning_rate=1e-3) - fm = OTFlowMatching( + fm = otfm.OTFlowMatching( neural_vf, input_dim=2, cond_dim=0, @@ -78,23 +73,24 @@ def test_flow_matching(self, data_loader_gaussian, flow: Type[BaseFlow]): assert jnp.sum(jnp.isnan(result_backward)) == 0 @pytest.mark.parametrize( - "flow", - [ConstantNoiseFlow(0.0), - ConstantNoiseFlow(1.0), - BrownianNoiseFlow(0.2)] + "flow", [ + flows.ConstantNoiseFlow(0.0), + flows.ConstantNoiseFlow(1.0), + flows.BrownianNoiseFlow(0.2) + ] ) def test_flow_matching_with_conditions( - self, data_loader_gaussian_with_conditions, flow: Type[BaseFlow] + self, data_loader_gaussian_with_conditions, flow: Type[flows.BaseFlow] ): - neural_vf = VelocityField( + neural_vf = models.VelocityField( output_dim=2, condition_dim=1, latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - time_sampler = functools.partial(sample_uniformly, offset=1e-5) + time_sampler = functools.partial(samplers.uniform_sampler, offset=1e-5) optimizer = optax.adam(learning_rate=1e-3) - fm = OTFlowMatching( + fm = otfm.OTFlowMatching( neural_vf, input_dim=2, cond_dim=1, @@ -126,23 +122,24 @@ def test_flow_matching_with_conditions( assert jnp.sum(jnp.isnan(result_backward)) == 
0 @pytest.mark.parametrize( - "flow", - [ConstantNoiseFlow(0.0), - ConstantNoiseFlow(1.0), - BrownianNoiseFlow(0.2)] + "flow", [ + flows.ConstantNoiseFlow(0.0), + flows.ConstantNoiseFlow(1.0), + flows.BrownianNoiseFlow(0.2) + ] ) def test_flow_matching_conditional( - self, data_loader_gaussian_conditional, flow: Type[BaseFlow] + self, data_loader_gaussian_conditional, flow: Type[flows.BaseFlow] ): - neural_vf = VelocityField( + neural_vf = models.VelocityField( output_dim=2, condition_dim=0, latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - time_sampler = sample_uniformly + time_sampler = samplers.uniform_sampler optimizer = optax.adam(learning_rate=1e-3) - fm = OTFlowMatching( + fm = otfm.OTFlowMatching( neural_vf, input_dim=2, cond_dim=0, @@ -182,21 +179,21 @@ def test_flow_matching_learn_rescaling( batch = next(data_loader) source_dim = batch["source_lin"].shape[1] condition_dim = batch["source_conditions"].shape[1] if conditional else 0 - neural_vf = VelocityField( + neural_vf = models.VelocityField( output_dim=2, condition_dim=0, latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - time_sampler = sample_uniformly - flow = ConstantNoiseFlow(1.0) + time_sampler = samplers.uniform_sampler + flow = flows.ConstantNoiseFlow(1.0) optimizer = optax.adam(learning_rate=1e-3) tau_a = 0.9 tau_b = 0.2 - rescaling_a = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) - rescaling_b = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) - fm = OTFlowMatching( + rescaling_a = models.RescalingMLP(hidden_dim=4, condition_dim=condition_dim) + rescaling_b = models.RescalingMLP(hidden_dim=4, condition_dim=condition_dim) + fm = otfm.OTFlowMatching( neural_vf, input_dim=source_dim, cond_dim=condition_dim, From a94b585b53e83c23e9498abd8bae56b11e584c3e Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Feb 2024 11:17:37 +0100 Subject: [PATCH 078/186] [ci skip] intermediate save --- src/ott/neural/models/base_solver.py | 2 +- tests/neural/genot_test.py | 7 
+++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index 071dc6e07..71825aa27 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -121,7 +121,7 @@ def _sample_conditional_indices_from_tmap( tmat_adapted ) - indices_source = jnp.repeat(indices, k_samples_per_x) + indices_source = jnp.repeat(indices_per_row, k_samples_per_x) indices_target = jnp.reshape( indices_per_row % tmat.shape[1], (batch_size * k_samples_per_x,) ) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 5a5b9a847..4960c1bec 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -50,6 +50,10 @@ def test_genot_linear_unconditional( target_dim = target_lin.shape[1] condition_dim = 0 + print("source dim is ", source_dim) + print("target dim is ", target_dim) + print("condition dim is ", condition_dim) + neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, @@ -294,6 +298,9 @@ def test_genot_fused_conditional( ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) time_sampler = uniform_sampler optimizer = optax.adam(learning_rate=1e-3) + print("source dim is ", source_dim) + print("target dim is ", target_dim) + print("condition dim is ", condition_dim) genot = GENOT( neural_vf, input_dim=source_dim, From 78b5e10f875816de268823c9cac48eb13b0f49a7 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Feb 2024 11:31:35 +0100 Subject: [PATCH 079/186] [ci skip] neural base solver update --- src/ott/neural/models/base_solver.py | 29 ++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index 71825aa27..b078393d3 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -42,7 +42,7 @@ def __init__(self, iterations: int, valid_freq: 
int, **_: Any): self.valid_freq = valid_freq @abc.abstractmethod - def setup(self, *args: Any, **kwargs: Any) -> None: + def setup(self, *args: Any, **kwargs: Any): """Setup the model.""" @abc.abstractmethod @@ -90,7 +90,7 @@ def _resample_data( def _sample_conditional_indices_from_tmap( self, - rng: jax.Array, + key: jax.random.PRNGKeyArray, tmat: jnp.ndarray, k_samples_per_x: Union[int, jnp.ndarray], source_arrays: Tuple[jnp.ndarray, ...], @@ -101,19 +101,22 @@ def _sample_conditional_indices_from_tmap( batch_size = tmat.shape[0] left_marginals = tmat.sum(axis=1) if not source_is_balanced: - rng, key2 = jax.random.split(rng, 2) + key, key2 = jax.random.split(key, 2) indices = jax.random.choice( key=key2, a=jnp.arange(len(left_marginals)), p=left_marginals, shape=(len(left_marginals),) ) - tmat_adapted = tmat[indices] else: - tmat_adapted = tmat + indices = jnp.arange(batch_size) + tmat_adapted = tmat[indices] indices_per_row = jax.vmap( - lambda row: jax.random.choice( - key=rng, a=jnp.arange(batch_size), p=row, shape=(k_samples_per_x,) + lambda tmat_adapted: jax.random.choice( + key=key, + a=jnp.arange(batch_size), + p=tmat_adapted, + shape=(k_samples_per_x,) ), in_axes=0, out_axes=0, @@ -121,7 +124,7 @@ def _sample_conditional_indices_from_tmap( tmat_adapted ) - indices_source = jnp.repeat(indices_per_row, k_samples_per_x) + indices_source = jnp.repeat(indices, k_samples_per_x) indices_target = jnp.reshape( indices_per_row % tmat.shape[1], (batch_size * k_samples_per_x,) ) @@ -130,8 +133,8 @@ def _sample_conditional_indices_from_tmap( -1)) if b is not None else None for b in source_arrays ), tuple( - jnp.reshape(b[indices_target], (k_samples_per_x, batch_size, - -1)) if b is not None else None + jnp.reshape(b[indices_target, :], (k_samples_per_x, batch_size, + -1)) if b is not None else None for b in target_arrays ) @@ -256,6 +259,7 @@ def __init__( jnp.ndarray]] = None, rescaling_b: Optional[Callable[[jnp.ndarray, Optional[jnp.ndarray]], jnp.ndarray]] = None, 
+ seed: Optional[int] = None, opt_eta: Optional[optax.GradientTransformation] = None, opt_xi: Optional[optax.GradientTransformation] = None, resample_epsilon: float = 1e-2, @@ -271,6 +275,7 @@ def __init__( self.tau_b = tau_b self.rescaling_a = rescaling_a self.rescaling_b = rescaling_b + self.seed = seed self.opt_eta = opt_eta self.opt_xi = opt_xi self.resample_epsilon = resample_epsilon @@ -281,7 +286,7 @@ def __init__( tau_b=tau_b, resample_epsilon=resample_epsilon, scale_cost=scale_cost, - sinkorn_kwargs=kwargs + **kwargs ) self._setup(source_dim=source_dim, target_dim=target_dim, cond_dim=cond_dim) @@ -292,7 +297,7 @@ def _get_compute_unbalanced_marginals( resample_epsilon: float, scale_cost: Union[bool, int, float, Literal["mean", "max_cost", "median"]] = "mean", - **kwargs: Dict[str, Any], + **kwargs: Mapping[str, Any], ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Compute the unbalanced source and target marginals for a batch.""" From 592564fd3b47f10fb677fc8b9f2e4262b72b45cd Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Feb 2024 14:17:48 +0100 Subject: [PATCH 080/186] make resamlpemixin a class --- src/ott/neural/flows/genot.py | 73 ++-- src/ott/neural/flows/otfm.py | 67 +--- src/ott/neural/models/__init__.py | 2 +- src/ott/neural/models/base_solver.py | 370 ++++++++++--------- src/ott/neural/models/{models.py => nets.py} | 0 tests/neural/genot_test.py | 69 +++- tests/neural/losses_test.py | 4 +- tests/neural/map_estimator_test.py | 4 +- tests/neural/otfm_test.py | 64 +++- 9 files changed, 336 insertions(+), 317 deletions(-) rename src/ott/neural/models/{models.py => nets.py} (100%) diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index b11b77b20..736d9e268 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -21,7 +21,6 @@ import diffrax import optax from flax.training import train_state -from orbax import checkpoint from ott import utils from ott.geometry import costs @@ -35,8 +34,7 @@ class 
GENOT( - base_solver.UnbalancednessMixin, base_solver.ResampleMixin, - base_solver.BaseNeuralSolver + base_solver.ResampleMixin, ): """The GENOT training class as introduced in :cite:`klein_uscidda:23`. @@ -68,7 +66,7 @@ class GENOT( optimizer: Optimizer for `velocity_field`. flow: Flow between latent distribution and target distribution. time_sampler: Sampler for the time. - checkpoint_manager: Checkpoint manager. + unbalancedness_handler: Handler for unbalancedness. k_samples_per_x: Number of samples drawn from the conditional distribution of an input sample, see algorithm TODO. solver_latent_to_data: Linear OT solver to match the latent distribution @@ -77,15 +75,6 @@ class GENOT( #TODO: adapt fused_penalty: Fused penalty of the linear/fused term in the Fused Gromov-Wasserstein problem. - tau_a: If :math:`<1`, defines how much unbalanced the problem is - on the first marginal. - tau_b: If :math:`< 1`, defines how much unbalanced the problem is - on the second marginal. - rescaling_a: Neural network to learn the left rescaling function. If - :obj:`None`, the left rescaling factor is not learnt. - rescaling_b: Neural network to learn the right rescaling function. If - :obj:`None`, the right rescaling factor is not learnt. - unbalanced_kwargs: Keyword arguments for the unbalancedness solver. callback_fn: Callback function. rng: Random number generator. 
""" @@ -109,43 +98,23 @@ def __init__( Dict[str, Union[bool, int, float, Literal["mean", "max_norm", "max_bound", "max_cost", "median"]]]], + unbalancedness_handler: base_solver.UnbalancednessHandler, optimizer: optax.GradientTransformation, flow: Type[flows.BaseFlow] = flows.ConstantNoiseFlow(0.0), # noqa: B008 time_sampler: Callable[[jax.Array, int], jnp.ndarray] = samplers.uniform_sampler, - checkpoint_manager: Type[checkpoint.CheckpointManager] = None, k_samples_per_x: int = 1, solver_latent_to_data: Optional[Type[was_solver.WassersteinSolver] ] = None, kwargs_solver_latent_to_data: Dict[str, Any] = types.MappingProxyType({}), fused_penalty: float = 0.0, - tau_a: float = 1.0, - tau_b: float = 1.0, - rescaling_a: Callable[[jnp.ndarray], float] = None, - rescaling_b: Callable[[jnp.ndarray], float] = None, - unbalanced_kwargs: Dict[str, Any] = types.MappingProxyType({}), callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any]] = None, rng: Optional[jax.Array] = None, ): rng = utils.default_prng_key(rng) - rng, rng_unbalanced = jax.random.split(rng) - base_solver.BaseNeuralSolver.__init__( - self, iterations=iterations, valid_freq=valid_freq - ) base_solver.ResampleMixin.__init__(self) - base_solver.UnbalancednessMixin.__init__( - self, - rng=rng_unbalanced, - source_dim=input_dim, - target_dim=input_dim, - cond_dim=cond_dim, - tau_a=tau_a, - tau_b=tau_b, - rescaling_a=rescaling_a, - rescaling_b=rescaling_b, - unbalanced_kwargs=unbalanced_kwargs, - ) + if isinstance( ot_solver, gromov_wasserstein.GromovWasserstein ) and epsilon is not None: @@ -156,12 +125,13 @@ def __init__( ) self.rng = utils.default_prng_key(rng) + self.iterations = iterations + self.valid_freq = valid_freq self.velocity_field = velocity_field self.state_velocity_field: Optional[train_state.TrainState] = None self.flow = flow self.time_sampler = time_sampler self.optimizer = optimizer - self.checkpoint_manager = checkpoint_manager self.latent_noise_fn = 
jax.tree_util.Partial( jax.random.multivariate_normal, mean=jnp.zeros((output_dim,)), @@ -172,6 +142,9 @@ def __init__( self.cond_dim = cond_dim self.k_samples_per_x = k_samples_per_x + # unbalancedness + self.unbalancedness_handler = unbalancedness_handler + # OT data-data matching parameters self.ot_solver = ot_solver self.epsilon = epsilon @@ -210,8 +183,8 @@ def setup(self) -> None: epsilon=self.epsilon, cost_fn=self.cost_fn, scale_cost=self.scale_cost, - tau_a=self.tau_a, - tau_b=self.tau_b, + tau_a=self.unbalancedness_handler.tau_a, + tau_b=self.unbalancedness_handler.tau_b, filter_input=True ) else: @@ -219,8 +192,8 @@ def setup(self) -> None: ot_solver=self.ot_solver, cost_fn=self.cost_fn, scale_cost=self.scale_cost, - tau_a=self.tau_a, - tau_b=self.tau_b, + tau_a=self.unbalancedness_handler.tau_a, + tau_b=self.unbalancedness_handler.tau_b, fused_penalty=self.fused_penalty ) @@ -278,7 +251,7 @@ def __call__(self, train_loader, valid_loader): tmat, self.k_samples_per_x, (batch["source"], batch["source_conditions"]), (batch["target"],), - source_is_balanced=(self.tau_a == 1.0) + source_is_balanced=(self.unbalancedness_handler.tau_a == 1.0) ) jax.random.split(rng_noise, batch_size * self.k_samples_per_x) @@ -310,24 +283,17 @@ def __call__(self, train_loader, valid_loader): ( self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, loss_b - ) = self.unbalancedness_step_fn( + ) = self.unbalancedness_handler.step_fn( source=batch["source"], target=batch["target"], condition=batch["source_conditions"], a=tmat.sum(axis=1), b=tmat.sum(axis=0), - state_eta=self.state_eta, - state_xi=self.state_xi, + state_eta=self.unbalancedness_handler.state_eta, + state_xi=self.unbalancedness_handler.state_xi, ) if iteration % self.valid_freq == 0: self._valid_step(valid_loader, iteration) - if self.checkpoint_manager is not None: - states_to_save = {"state_velocity_field": self.state_velocity_field} - if self.state_eta is not None: - states_to_save["state_eta"] = 
self.state_eta - if self.state_xi is not None: - states_to_save["state_xi"] = self.state_xi - self.checkpoint_manager.save(iteration, states_to_save) def _get_step_fn(self) -> Callable: @@ -441,7 +407,10 @@ def _valid_step(self, valid_loader, iter): @property def learn_rescaling(self) -> bool: """Whether to learn at least one rescaling factor.""" - return self.rescaling_a is not None or self.rescaling_b is not None + return ( + self.unbalancedness_handler.rescaling_a is not None or + self.unbalancedness_handler.rescaling_b is not None + ) def save(self, path: str): """Save the model. diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index 2ec4707c6..ef55e1dc5 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -34,8 +34,7 @@ class OTFlowMatching( - base_solver.UnbalancednessMixin, base_solver.ResampleMixin, - base_solver.BaseNeuralSolver + base_solver.ResampleMixin, ): """(Optimal transport) flow matching class. @@ -61,17 +60,6 @@ class OTFlowMatching( cost_fn: Cost function for the OT problem solved by the `ot_solver`. scale_cost: How to scale the cost matrix for the OT problem solved by the `ot_solver`. - tau_a: If :math:`<1`, defines how much unbalanced the problem is - on the first marginal. - tau_b: If :math:`< 1`, defines how much unbalanced the problem is - on the second marginal. - rescaling_a: Neural network to learn the left rescaling function as - suggested in :cite:`eyring:23`. If :obj:`None`, the left rescaling factor - is not learnt. - rescaling_b: Neural network to learn the right rescaling function as - suggested in :cite:`eyring:23`. If :obj:`None`, the right rescaling factor - is not learnt. - unbalanced_kwargs: Keyword arguments for the unbalancedness solver. callback_fn: Callback function. num_eval_samples: Number of samples to evaluate on during evaluation. rng: Random number generator. 
@@ -89,17 +77,13 @@ def __init__( flow: Type[flows.BaseFlow], time_sampler: Callable[[jax.Array, int], jnp.ndarray], optimizer: optax.GradientTransformation, + unbalancedness_handler: base_solver.UnbalancednessHandler, checkpoint_manager: Type[checkpoint.CheckpointManager] = None, epsilon: float = 1e-2, cost_fn: Optional[Type[costs.CostFn]] = None, scale_cost: Union[bool, int, float, Literal["mean", "max_norm", "max_bound", "max_cost", "median"]] = "mean", - tau_a: float = 1.0, - tau_b: float = 1.0, - rescaling_a: Callable[[jnp.ndarray], float] = None, - rescaling_b: Callable[[jnp.ndarray], float] = None, - unbalanced_kwargs: Dict[str, Any] = types.MappingProxyType({}), callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any]] = None, logging_freq: int = 100, @@ -108,24 +92,10 @@ def __init__( rng: Optional[jax.Array] = None, ): rng = utils.default_prng_key(rng) - rng, rng_unbalanced = jax.random.split(rng) - base_solver.BaseNeuralSolver.__init__( - self, iterations=iterations, valid_freq=valid_freq - ) base_solver.ResampleMixin.__init__(self) - base_solver.UnbalancednessMixin.__init__( - self, - rng=rng_unbalanced, - source_dim=input_dim, - target_dim=input_dim, - cond_dim=cond_dim, - tau_a=tau_a, - tau_b=tau_b, - rescaling_a=rescaling_a, - rescaling_b=rescaling_b, - unbalanced_kwargs=unbalanced_kwargs, - ) - + self.unbalancedness_handler = unbalancedness_handler + self.iterations = iterations + self.valid_freq = valid_freq self.velocity_field = velocity_field self.input_dim = input_dim self.ot_solver = ot_solver @@ -159,8 +129,8 @@ def setup(self) -> None: epsilon=self.epsilon, cost_fn=self.cost_fn, scale_cost=self.scale_cost, - tau_a=self.tau_a, - tau_b=self.tau_b, + tau_a=self.unbalancedness_handler.tau_a, + tau_b=self.unbalancedness_handler.tau_b, ) else: self.match_fn = None @@ -235,26 +205,20 @@ def __call__(self, train_loader, valid_loader): curr_loss = 0.0 if self.learn_rescaling: ( - self.state_eta, self.state_xi, eta_predictions, 
xi_predictions, - loss_a, loss_b - ) = self.unbalancedness_step_fn( + self.unbalancedness_handler.state_eta, + self.unbalancedness_handler.state_xi, eta_predictions, + xi_predictions, loss_a, loss_b + ) = self.unbalancedness_handler.step_fn( source=batch["source_lin"], target=batch["target_lin"], condition=batch["source_conditions"], a=tmat.sum(axis=1), b=tmat.sum(axis=0), - state_eta=self.state_eta, - state_xi=self.state_xi, + state_eta=self.unbalancedness_handler.state_eta, + state_xi=self.unbalancedness_handler.state_xi, ) if iter % self.valid_freq == 0: self._valid_step(valid_loader, iter) - if self.checkpoint_manager is not None: - states_to_save = {"state_velocity_field": self.state_velocity_field} - if self.state_eta is not None: - states_to_save["state_eta"] = self.state_eta - if self.state_xi is not None: - states_to_save["state_xi"] = self.state_xi - self.checkpoint_manager.save(iter, states_to_save) def transport( self, @@ -319,7 +283,10 @@ def _valid_step(self, valid_loader, iter): @property def learn_rescaling(self) -> bool: """Whether to learn at least one rescaling factor.""" - return self.rescaling_a is not None or self.rescaling_b is not None + return ( + self.unbalancedness_handler.rescaling_a is not None or + self.unbalancedness_handler.rescaling_b is not None + ) def save(self, path: str): """Save the model. diff --git a/src/ott/neural/models/__init__.py b/src/ott/neural/models/__init__.py index 5c2ac3b2b..ba39ae8b4 100644 --- a/src/ott/neural/models/__init__.py +++ b/src/ott/neural/models/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import base_solver, layers, models +from . 
import base_solver, layers, nets diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index b078393d3..1bc541ec7 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import abc -import pathlib from typing import Any, Callable, Dict, Literal, Mapping, Optional, Tuple, Union import jax @@ -24,47 +22,114 @@ from ott.geometry import costs, pointcloud from ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_problem +from ott.solvers import was_solver from ott.solvers.linear import sinkhorn +from ott.solvers.quadratic import gromov_wasserstein -__all__ = ["BaseNeuralSolver", "ResampleMixin", "UnbalancednessMixin"] +__all__ = ["ResampleMixin", "UnbalancednessHandler"] -class BaseNeuralSolver(abc.ABC): - """Base class for neural solvers. +def _get_sinkhorn_match_fn( + ot_solver: Any, + epsilon: float = 1e-2, + cost_fn: Optional[costs.CostFn] = None, + scale_cost: Union[bool, int, float, Literal["mean", "max_norm", "max_bound", + "max_cost", "median"]] = "mean", + tau_a: float = 1.0, + tau_b: float = 1.0, + *, + filter_input: bool = False, +) -> Callable: - Args: - iterations: Number of iterations to train for. - valid_freq: Frequency at which to run validation. 
- """ - - def __init__(self, iterations: int, valid_freq: int, **_: Any): - self.iterations = iterations - self.valid_freq = valid_freq - - @abc.abstractmethod - def setup(self, *args: Any, **kwargs: Any): - """Setup the model.""" - - @abc.abstractmethod - def __call__(self, *args: Any, **kwargs: Any): - """Train the model.""" + @jax.jit + def match_pairs(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + geom = pointcloud.PointCloud( + x, y, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn + ) + return ot_solver( + linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) + ) - @abc.abstractmethod - def transport(self, *args: Any, forward: bool, **kwargs: Any) -> Any: - """Transport.""" + @jax.jit + def match_pairs_filtered( + x_lin: jnp.ndarray, x_quad: jnp.ndarray, y_lin: jnp.ndarray, + y_quad: jnp.ndarray + ) -> jnp.ndarray: + geom = pointcloud.PointCloud( + x_lin, y_lin, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn + ) + return ot_solver( + linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) + ) - @abc.abstractmethod - def save(self, path: pathlib.Path): - """Save the model.""" + return match_pairs_filtered if filter_input else match_pairs + + +def _get_gromov_match_fn( + ot_solver: Any, + cost_fn: Union[Any, Mapping[str, Any]], + scale_cost: Union[Union[bool, int, float, + Literal["mean", "max_norm", "max_bound", "max_cost", + "median"]], + Dict[str, Union[bool, int, float, + Literal["mean", "max_norm", "max_bound", + "max_cost", "median"]]]], + tau_a: float, + tau_b: float, + fused_penalty: float, +) -> Callable: + if isinstance(cost_fn, Mapping): + assert "cost_fn_xx" in cost_fn + assert "cost_fn_yy" in cost_fn + cost_fn_xx = cost_fn["cost_fn_xx"] + cost_fn_yy = cost_fn["cost_fn_yy"] + if fused_penalty > 0: + assert "cost_fn_xy" in cost_fn_xx + cost_fn_xy = cost_fn["cost_fn_xy"] + else: + cost_fn_xx = cost_fn_yy = cost_fn_xy = cost_fn + + if isinstance(scale_cost, Mapping): + assert "scale_cost_xx" in scale_cost + assert 
"scale_cost_yy" in scale_cost + scale_cost_xx = scale_cost["scale_cost_xx"] + scale_cost_yy = scale_cost["scale_cost_yy"] + if fused_penalty > 0: + assert "scale_cost_xy" in scale_cost + scale_cost_xy = cost_fn["scale_cost_xy"] + else: + scale_cost_xx = scale_cost_yy = scale_cost_xy = scale_cost - @abc.abstractmethod - def load(self, path: pathlib.Path): - """Load the model.""" + @jax.jit + def match_pairs( + x_lin: Optional[jnp.ndarray], + x_quad: Tuple[jnp.ndarray, jnp.ndarray], + y_lin: Optional[jnp.ndarray], + y_quad: Tuple[jnp.ndarray, jnp.ndarray], + ) -> jnp.ndarray: + geom_xx = pointcloud.PointCloud( + x=x_quad, y=x_quad, cost_fn=cost_fn_xx, scale_cost=scale_cost_xx + ) + geom_yy = pointcloud.PointCloud( + x=y_quad, y=y_quad, cost_fn=cost_fn_yy, scale_cost=scale_cost_yy + ) + if fused_penalty > 0: + geom_xy = pointcloud.PointCloud( + x=x_lin, y=y_lin, cost_fn=cost_fn_xy, scale_cost=scale_cost_xy + ) + else: + geom_xy = None + prob = quadratic_problem.QuadraticProblem( + geom_xx, + geom_yy, + geom_xy, + fused_penalty=fused_penalty, + tau_a=tau_a, + tau_b=tau_b + ) + return ot_solver(prob) - @property - @abc.abstractmethod - def training_logs(self) -> Dict[str, Any]: - """Return the training logs.""" + return match_pairs class ResampleMixin: @@ -83,14 +148,14 @@ def _resample_data( indices_source = indices // tmat.shape[1] indices_target = indices % tmat.shape[1] return tuple( - b[indices_source, :] if b is not None else None for b in source_arrays + b[indices_source] if b is not None else None for b in source_arrays ), tuple( - b[indices_target, :] if b is not None else None for b in target_arrays + b[indices_target] if b is not None else None for b in target_arrays ) def _sample_conditional_indices_from_tmap( self, - key: jax.random.PRNGKeyArray, + rng: jax.Array, tmat: jnp.ndarray, k_samples_per_x: Union[int, jnp.ndarray], source_arrays: Tuple[jnp.ndarray, ...], @@ -101,9 +166,9 @@ def _sample_conditional_indices_from_tmap( batch_size = tmat.shape[0] 
left_marginals = tmat.sum(axis=1) if not source_is_balanced: - key, key2 = jax.random.split(key, 2) + rng, rng_2 = jax.random.split(rng, 2) indices = jax.random.choice( - key=key2, + key=rng_2, a=jnp.arange(len(left_marginals)), p=left_marginals, shape=(len(left_marginals),) @@ -112,11 +177,8 @@ def _sample_conditional_indices_from_tmap( indices = jnp.arange(batch_size) tmat_adapted = tmat[indices] indices_per_row = jax.vmap( - lambda tmat_adapted: jax.random.choice( - key=key, - a=jnp.arange(batch_size), - p=tmat_adapted, - shape=(k_samples_per_x,) + lambda row: jax.random.choice( + key=rng, a=jnp.arange(batch_size), p=row, shape=(k_samples_per_x,) ), in_axes=0, out_axes=0, @@ -133,119 +195,60 @@ def _sample_conditional_indices_from_tmap( -1)) if b is not None else None for b in source_arrays ), tuple( - jnp.reshape(b[indices_target, :], (k_samples_per_x, batch_size, - -1)) if b is not None else None + jnp.reshape(b[indices_target], (k_samples_per_x, batch_size, + -1)) if b is not None else None for b in target_arrays ) - def _get_sinkhorn_match_fn( - self, - ot_solver: Any, - epsilon: float = 1e-2, - cost_fn: Optional[costs.CostFn] = None, - scale_cost: Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", "max_cost", - "median"]] = "mean", - tau_a: float = 1.0, - tau_b: float = 1.0, - *, - filter_input: bool = False, - ) -> Callable: + def _get_sinkhorn_match_fn(self, *args, **kwargs) -> jnp.ndarray: + fn = _get_sinkhorn_match_fn(*args, **kwargs) @jax.jit - def match_pairs(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: - geom = pointcloud.PointCloud( - x, y, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn - ) - return ot_solver( - linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) - ).matrix + def match_pairs(*args, **kwargs): + return fn(*args, **kwargs).matrix - @jax.jit - def match_pairs_filtered( - x_lin: jnp.ndarray, x_quad: jnp.ndarray, y_lin: jnp.ndarray, - y_quad: jnp.ndarray - ) -> jnp.ndarray: - geom = 
pointcloud.PointCloud( - x_lin, y_lin, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn - ) - return ot_solver( - linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) - ).matrix + return match_pairs - return match_pairs_filtered if filter_input else match_pairs - - def _get_gromov_match_fn( - self, - ot_solver: Any, - cost_fn: Union[Any, Mapping[str, Any]], - scale_cost: Union[Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", - "max_cost", "median"]], - Dict[str, Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", - "max_cost", "median"]]]], - tau_a: float, - tau_b: float, - fused_penalty: float, - ) -> Callable: - if isinstance(cost_fn, Mapping): - assert "cost_fn_xx" in cost_fn - assert "cost_fn_yy" in cost_fn - cost_fn_xx = cost_fn["cost_fn_xx"] - cost_fn_yy = cost_fn["cost_fn_yy"] - if fused_penalty > 0: - assert "cost_fn_xy" in cost_fn_xx - cost_fn_xy = cost_fn["cost_fn_xy"] - else: - cost_fn_xx = cost_fn_yy = cost_fn_xy = cost_fn - - if isinstance(scale_cost, Mapping): - assert "scale_cost_xx" in scale_cost - assert "scale_cost_yy" in scale_cost - scale_cost_xx = scale_cost["scale_cost_xx"] - scale_cost_yy = scale_cost["scale_cost_yy"] - if fused_penalty > 0: - assert "scale_cost_xy" in scale_cost - scale_cost_xy = cost_fn["scale_cost_xy"] - else: - scale_cost_xx = scale_cost_yy = scale_cost_xy = scale_cost + def _get_gromov_match_fn(self, *args, **kwargs) -> jnp.ndarray: + fn = _get_gromov_match_fn(*args, **kwargs) @jax.jit - def match_pairs( - x_lin: Optional[jnp.ndarray], - x_quad: Tuple[jnp.ndarray, jnp.ndarray], - y_lin: Optional[jnp.ndarray], - y_quad: Tuple[jnp.ndarray, jnp.ndarray], - ) -> jnp.ndarray: - geom_xx = pointcloud.PointCloud( - x=x_quad, y=x_quad, cost_fn=cost_fn_xx, scale_cost=scale_cost_xx - ) - geom_yy = pointcloud.PointCloud( - x=y_quad, y=y_quad, cost_fn=cost_fn_yy, scale_cost=scale_cost_yy - ) - if fused_penalty > 0: - geom_xy = pointcloud.PointCloud( - x=x_lin, y=y_lin, 
cost_fn=cost_fn_xy, scale_cost=scale_cost_xy - ) - else: - geom_xy = None - prob = quadratic_problem.QuadraticProblem( - geom_xx, - geom_yy, - geom_xy, - fused_penalty=fused_penalty, - tau_a=tau_a, - tau_b=tau_b - ) - out = ot_solver(prob) - return out.matrix + def match_pairs(*args, **kwargs): + return fn(*args, **kwargs).matrix return match_pairs -class UnbalancednessMixin: - """Mixin class to incorporate unbalancedness into neural OT models.""" +class UnbalancednessHandler: + """Class to incorporate unbalancedness into neural OT models. + + This class implements the concepts introduced in :cite:`eyring:23` + in the Monge Map scenario and :cite:`klein:23` for the entropic OT case + for linear and quadratic cases. + + Args: + rng: Random number generator. + source_dim: Dimension of the source domain. + target_dim: Dimension of the target domain. + cond_dim: Dimension of the conditioning variable. + If :obj:`None`, no conditioning is used. + tau_a: Unbalancedness parameter for the source distribution. + tau_b: Unbalancedness parameter for the target distribution. + rescaling_a: Rescaling function for the source distribution. + If :obj:`None`, the left rescaling factor is not learnt. + rescaling_b: Rescaling function for the target distribution. + If :obj:`None`, the right rescaling factor is not learnt. + opt_eta: Optimizer for the left rescaling function. + opt_xi: Optimzier for the right rescaling function. + resample_epsilon: Epsilon for resampling. + scale_cost: Scaling of the cost matrix for estimating the rescaling factors. + ot_solver: Solver to compute unbalanced marginals. If `ot_solver` is `None`, + the method + :meth:`ott.neural.models.base_solver.UnbalancednessHandler.compute_unbalanced_marginals` + is not available, and hence the unbalanced marginals must be computed by the neural solver. + kwargs: Additional keyword arguments. 
+ + """ def __init__( self, @@ -259,12 +262,12 @@ def __init__( jnp.ndarray]] = None, rescaling_b: Optional[Callable[[jnp.ndarray, Optional[jnp.ndarray]], jnp.ndarray]] = None, - seed: Optional[int] = None, opt_eta: Optional[optax.GradientTransformation] = None, opt_xi: Optional[optax.GradientTransformation] = None, resample_epsilon: float = 1e-2, scale_cost: Union[bool, int, float, Literal["mean", "max_cost", "median"]] = "mean", + ot_solver: Optional[was_solver.WassersteinSolver] = None, **kwargs: Mapping[str, Any], ): self.rng_unbalanced = rng @@ -275,48 +278,51 @@ def __init__( self.tau_b = tau_b self.rescaling_a = rescaling_a self.rescaling_b = rescaling_b - self.seed = seed self.opt_eta = opt_eta self.opt_xi = opt_xi self.resample_epsilon = resample_epsilon self.scale_cost = scale_cost + self.ot_solver = ot_solver + + if isinstance(ot_solver, sinkhorn.Sinkhorn): + self.compute_unbalanced_marginals = ( + self._get_compute_unbalanced_marginals_lin( + tau_a=tau_a, + tau_b=tau_b, + resample_epsilon=resample_epsilon, + scale_cost=scale_cost, + **kwargs + ) + ) + elif isinstance(ot_solver, gromov_wasserstein.GromovWasserstein): + self.compute_unbalanced_marginals = self._get_compute_unbalanced_marginals_quad + self.setup(source_dim=source_dim, target_dim=target_dim, cond_dim=cond_dim) - self._compute_unbalanced_marginals = self._get_compute_unbalanced_marginals( - tau_a=tau_a, - tau_b=tau_b, - resample_epsilon=resample_epsilon, - scale_cost=scale_cost, - **kwargs - ) - self._setup(source_dim=source_dim, target_dim=target_dim, cond_dim=cond_dim) + def _get_compute_unbalanced_marginals_lin( + self, *args: Any, **kwargs: Mapping[str, Any] + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Compute the unbalanced source and target marginals for a batch.""" + fn = _get_sinkhorn_match_fn(*args, **kwargs) - def _get_compute_unbalanced_marginals( - self, - tau_a: float, - tau_b: float, - resample_epsilon: float, - scale_cost: Union[bool, int, float, Literal["mean", "max_cost", - 
"median"]] = "mean", - **kwargs: Mapping[str, Any], + @jax.jit + def compute_unbalanced_marginals_lin(*args, **kwargs): + out = fn(*args, **kwargs) + return out.marginals(axis=1), out.marginals(axis=0) + + return compute_unbalanced_marginals_lin + + def _get_compute_unbalanced_marginals_quad( + self, *args: Any, **kwargs: Mapping[str, Any] ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Compute the unbalanced source and target marginals for a batch.""" + fn = _get_sinkhorn_match_fn(*args, **kwargs) @jax.jit - def compute_unbalanced_marginals( - batch_source: jnp.ndarray, batch_target: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - geom = pointcloud.PointCloud( - batch_source, - batch_target, - epsilon=resample_epsilon, - scale_cost=scale_cost - ) - out = sinkhorn.Sinkhorn(**kwargs)( - linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) - ) - return out.marginal(axis=1), out.marginal(axis=0) + def compute_unbalanced_marginals_quad(*args, **kwargs): + out = fn(*args, **kwargs) + return out.marginals(axis=1), out.marginals(axis=0) - return compute_unbalanced_marginals + return compute_unbalanced_marginals_quad @jax.jit def _resample_unbalanced( @@ -331,11 +337,19 @@ def _resample_unbalanced( ) return tuple(b[indices] if b is not None else None for b in batch) - def _setup(self, source_dim: int, target_dim: int, cond_dim: int): + def setup(self, source_dim: int, target_dim: int, cond_dim: int): + """Setup the model. + + Args: + source_dim: Dimension of the source domain. + target_dim: Dimension of the target domain. + cond_dim: Dimension of the conditioning variable. + If :obj:`None`, no conditioning is used. 
+ """ self.rng_unbalanced, rng_eta, rng_xi = jax.random.split( self.rng_unbalanced, 3 ) - self.unbalancedness_step_fn = self._get_rescaling_step_fn() + self.step_fn = self._get_rescaling_step_fn() if self.rescaling_a is not None: self.opt_eta = ( self.opt_eta if self.opt_eta is not None else diff --git a/src/ott/neural/models/models.py b/src/ott/neural/models/nets.py similarity index 100% rename from src/ott/neural/models/models.py rename to src/ott/neural/models/nets.py diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 4960c1bec..191b04d2a 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -17,6 +17,7 @@ import pytest import jax.numpy as jnp +from jax import random import optax @@ -24,7 +25,8 @@ from ott.neural.flows.genot import GENOT from ott.neural.flows.models import VelocityField from ott.neural.flows.samplers import uniform_sampler -from ott.neural.models.models import RescalingMLP +from ott.neural.models import base_solver +from ott.neural.models.nets import RescalingMLP from ott.solvers.linear import sinkhorn from ott.solvers.quadratic import gromov_wasserstein @@ -32,7 +34,7 @@ class TestGENOT: @pytest.mark.parametrize("scale_cost", ["mean", 2.0]) - @pytest.mark.parametrize("k_samples_per_x", [1, 2]) + @pytest.mark.parametrize("k_samples_per_x", [1, 3]) @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) def test_genot_linear_unconditional( self, genot_data_loader_linear: Iterator, @@ -50,16 +52,15 @@ def test_genot_linear_unconditional( target_dim = target_lin.shape[1] condition_dim = 0 - print("source dim is ", source_dim) - print("target dim is ", target_dim) - print("condition dim is ", condition_dim) - neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() + unbalancedness_handler = base_solver.UnbalancednessHandler( + random.PRNGKey(0), source_dim, target_dim, condition_dim + ) time_sampler = 
uniform_sampler optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( @@ -75,6 +76,7 @@ def test_genot_linear_unconditional( scale_cost=scale_cost, optimizer=optimizer, time_sampler=time_sampler, + unbalancedness_handler=unbalancedness_handler, k_samples_per_x=k_samples_per_x, solver_latent_to_data=solver_latent_to_data, ) @@ -110,6 +112,11 @@ def test_genot_quad_unconditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + + unbalancedness_handler = base_solver.UnbalancednessHandler( + random.PRNGKey(0), source_dim, target_dim, condition_dim + ) + time_sampler = functools.partial(uniform_sampler, offset=1e-2) optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( @@ -122,6 +129,7 @@ def test_genot_quad_unconditional( ot_solver=ot_solver, epsilon=None, cost_fn=costs.SqEuclidean(), + unbalancedness_handler=unbalancedness_handler, scale_cost=1.0, optimizer=optimizer, time_sampler=time_sampler, @@ -156,6 +164,10 @@ def test_genot_fused_unconditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + unbalancedness_handler = base_solver.UnbalancednessHandler( + random.PRNGKey(0), source_dim, target_dim, condition_dim + ) + optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -168,6 +180,7 @@ def test_genot_fused_unconditional( ot_solver=ot_solver, cost_fn=costs.SqEuclidean(), scale_cost=1.0, + unbalancedness_handler=unbalancedness_handler, optimizer=optimizer, fused_penalty=0.5, k_samples_per_x=k_samples_per_x, @@ -203,6 +216,10 @@ def test_genot_linear_conditional( ) ot_solver = sinkhorn.Sinkhorn() time_sampler = uniform_sampler + unbalancedness_handler = base_solver.UnbalancednessHandler( + random.PRNGKey(0), source_dim, target_dim, condition_dim + ) + optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -215,6 +232,7 @@ def test_genot_linear_conditional( epsilon=0.1, cost_fn=costs.SqEuclidean(), scale_cost=1.0, + 
unbalancedness_handler=unbalancedness_handler, optimizer=optimizer, time_sampler=time_sampler, k_samples_per_x=k_samples_per_x, @@ -250,6 +268,10 @@ def test_genot_quad_conditional( ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) time_sampler = uniform_sampler + unbalancedness_handler = base_solver.UnbalancednessHandler( + random.PRNGKey(0), source_dim, target_dim, condition_dim + ) + optimizer = optax.adam(learning_rate=1e-3) genot = GENOT( neural_vf, @@ -262,6 +284,7 @@ def test_genot_quad_conditional( epsilon=None, cost_fn=costs.SqEuclidean(), scale_cost=1.0, + unbalancedness_handler=unbalancedness_handler, optimizer=optimizer, time_sampler=time_sampler, k_samples_per_x=k_samples_per_x, @@ -298,9 +321,10 @@ def test_genot_fused_conditional( ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) time_sampler = uniform_sampler optimizer = optax.adam(learning_rate=1e-3) - print("source dim is ", source_dim) - print("target dim is ", target_dim) - print("condition dim is ", condition_dim) + unbalancedness_handler = base_solver.UnbalancednessHandler( + random.PRNGKey(0), source_dim, target_dim, condition_dim + ) + genot = GENOT( neural_vf, input_dim=source_dim, @@ -312,6 +336,7 @@ def test_genot_fused_conditional( epsilon=None, cost_fn=costs.SqEuclidean(), scale_cost=1.0, + unbalancedness_handler=unbalancedness_handler, optimizer=optimizer, time_sampler=time_sampler, k_samples_per_x=k_samples_per_x, @@ -357,10 +382,23 @@ def test_genot_linear_learn_rescaling( ot_solver = sinkhorn.Sinkhorn() time_sampler = uniform_sampler optimizer = optax.adam(learning_rate=1e-3) + tau_a = 0.9 tau_b = 0.2 rescaling_a = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) rescaling_b = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) + + unbalancedness_handler = base_solver.UnbalancednessHandler( + random.PRNGKey(0), + source_dim, + target_dim, + condition_dim, + tau_a=tau_a, + tau_b=tau_b, + rescaling_a=rescaling_a, + rescaling_b=rescaling_b + ) + 
genot = GENOT( neural_vf, input_dim=source_dim, @@ -374,18 +412,19 @@ def test_genot_linear_learn_rescaling( scale_cost=1.0, optimizer=optimizer, time_sampler=time_sampler, - tau_a=tau_a, - tau_b=tau_b, - rescaling_a=rescaling_a, - rescaling_b=rescaling_b, + unbalancedness_handler=unbalancedness_handler, ) genot(data_loader, data_loader) - result_eta = genot.evaluate_eta(source_lin, condition=source_condition) + result_eta = genot.unbalancedness_handler.evaluate_eta( + source_lin, condition=source_condition + ) assert isinstance(result_eta, jnp.ndarray) assert jnp.sum(jnp.isnan(result_eta)) == 0 - result_xi = genot.evaluate_xi(target_lin, condition=source_condition) + result_xi = genot.unbalancedness_handler.evaluate_xi( + target_lin, condition=source_condition + ) assert isinstance(result_xi, jnp.ndarray) assert jnp.sum(jnp.isnan(result_xi)) == 0 diff --git a/tests/neural/losses_test.py b/tests/neural/losses_test.py index 6379b9dfa..733a8c2b3 100644 --- a/tests/neural/losses_test.py +++ b/tests/neural/losses_test.py @@ -19,7 +19,7 @@ from ott.geometry import costs from ott.neural.gaps import monge_gap -from ott.neural.models import models +from ott.neural.models import nets @pytest.mark.fast() @@ -35,7 +35,7 @@ def test_monge_gap_non_negativity( rng1, rng2 = jax.random.split(rng, 2) reference_points = jax.random.normal(rng1, (n_samples, n_features)) - model = models.MLP(dim_hidden=[8, 8], is_potential=False) + model = nets.MLP(dim_hidden=[8, 8], is_potential=False) params = model.init(rng2, x=reference_points[0]) target = model.apply(params, reference_points) diff --git a/tests/neural/map_estimator_test.py b/tests/neural/map_estimator_test.py index 508143465..c42f31daa 100644 --- a/tests/neural/map_estimator_test.py +++ b/tests/neural/map_estimator_test.py @@ -20,7 +20,7 @@ from ott import datasets from ott.geometry import pointcloud from ott.neural.gaps import map_estimator, monge_gap -from ott.neural.models import models +from ott.neural.models import nets from 
ott.tools import sinkhorn_divergence @@ -51,7 +51,7 @@ def regularizer(x, y): return gap, out.n_iters # define the model - model = models.MLP(dim_hidden=[16, 8], is_potential=False) + model = nets.MLP(dim_hidden=[16, 8], is_potential=False) # generate data train_dataset, valid_dataset, dim_data = ( diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 7f6a1a8dc..e4d22a789 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -17,10 +17,12 @@ import pytest import jax.numpy as jnp +from jax import random import optax from ott.neural.flows import flows, models, otfm, samplers +from ott.neural.models import base_solver, nets from ott.solvers.linear import sinkhorn @@ -36,6 +38,8 @@ class TestOTFlowMatching: def test_flow_matching( self, data_loader_gaussian, flow: Type[flows.BaseFlow] ): + input_dim = 2 + condition_dim = 0 neural_vf = models.VelocityField( output_dim=2, condition_dim=0, @@ -44,16 +48,20 @@ def test_flow_matching( ot_solver = sinkhorn.Sinkhorn() time_sampler = samplers.uniform_sampler optimizer = optax.adam(learning_rate=1e-3) + unbalancedness_handler = base_solver.UnbalancednessHandler( + random.PRNGKey(0), input_dim, input_dim, condition_dim + ) fm = otfm.OTFlowMatching( neural_vf, - input_dim=2, - cond_dim=0, + input_dim=input_dim, + cond_dim=condition_dim, iterations=3, valid_freq=2, ot_solver=ot_solver, flow=flow, time_sampler=time_sampler, - optimizer=optimizer + optimizer=optimizer, + unbalancedness_handler=unbalancedness_handler ) fm(data_loader_gaussian, data_loader_gaussian) @@ -82,6 +90,8 @@ def test_flow_matching( def test_flow_matching_with_conditions( self, data_loader_gaussian_with_conditions, flow: Type[flows.BaseFlow] ): + input_dim = 2 + condition_dim = 1 neural_vf = models.VelocityField( output_dim=2, condition_dim=1, @@ -90,6 +100,10 @@ def test_flow_matching_with_conditions( ot_solver = sinkhorn.Sinkhorn() time_sampler = functools.partial(samplers.uniform_sampler, offset=1e-5) optimizer = 
optax.adam(learning_rate=1e-3) + unbalancedness_handler = base_solver.UnbalancednessHandler( + random.PRNGKey(0), input_dim, input_dim, condition_dim + ) + fm = otfm.OTFlowMatching( neural_vf, input_dim=2, @@ -99,7 +113,8 @@ def test_flow_matching_with_conditions( ot_solver=ot_solver, flow=flow, time_sampler=time_sampler, - optimizer=optimizer + optimizer=optimizer, + unbalancedness_handler=unbalancedness_handler ) fm( data_loader_gaussian_with_conditions, @@ -131,24 +146,31 @@ def test_flow_matching_with_conditions( def test_flow_matching_conditional( self, data_loader_gaussian_conditional, flow: Type[flows.BaseFlow] ): + dim = 2 + condition_dim = 0 neural_vf = models.VelocityField( - output_dim=2, - condition_dim=0, + output_dim=dim, + condition_dim=condition_dim, latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() time_sampler = samplers.uniform_sampler optimizer = optax.adam(learning_rate=1e-3) + unbalancedness_handler = base_solver.UnbalancednessHandler( + random.PRNGKey(0), dim, dim, condition_dim + ) + fm = otfm.OTFlowMatching( neural_vf, - input_dim=2, - cond_dim=0, + input_dim=dim, + cond_dim=condition_dim, iterations=3, valid_freq=2, ot_solver=ot_solver, flow=flow, time_sampler=time_sampler, - optimizer=optimizer + optimizer=optimizer, + unbalancedness_handler=unbalancedness_handler ) fm(data_loader_gaussian_conditional, data_loader_gaussian_conditional) @@ -191,8 +213,19 @@ def test_flow_matching_learn_rescaling( tau_a = 0.9 tau_b = 0.2 - rescaling_a = models.RescalingMLP(hidden_dim=4, condition_dim=condition_dim) - rescaling_b = models.RescalingMLP(hidden_dim=4, condition_dim=condition_dim) + rescaling_a = nets.RescalingMLP(hidden_dim=4, condition_dim=condition_dim) + rescaling_b = nets.RescalingMLP(hidden_dim=4, condition_dim=condition_dim) + unbalancedness_handler = base_solver.UnbalancednessHandler( + random.PRNGKey(0), + source_dim, + source_dim, + condition_dim, + tau_a=tau_a, + tau_b=tau_b, + rescaling_a=rescaling_a, + rescaling_b=rescaling_b + 
) + fm = otfm.OTFlowMatching( neural_vf, input_dim=source_dim, @@ -203,20 +236,17 @@ def test_flow_matching_learn_rescaling( flow=flow, time_sampler=time_sampler, optimizer=optimizer, - tau_a=tau_a, - tau_b=tau_b, - rescaling_a=rescaling_a, - rescaling_b=rescaling_b, + unbalancedness_handler=unbalancedness_handler, ) fm(data_loader, data_loader) - result_eta = fm.evaluate_eta( + result_eta = fm.unbalancedness_handler.evaluate_eta( batch["source_lin"], condition=batch["source_conditions"] ) assert isinstance(result_eta, jnp.ndarray) assert jnp.sum(jnp.isnan(result_eta)) == 0 - result_xi = fm.evaluate_xi( + result_xi = fm.unbalancedness_handler.evaluate_xi( batch["target_lin"], condition=batch["source_conditions"] ) assert isinstance(result_xi, jnp.ndarray) From 5e05bfc69e6f9b3fa492e2877759712bb873b23d Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Feb 2024 14:43:54 +0100 Subject: [PATCH 081/186] incorporate more changes --- src/ott/neural/flows/layers.py | 35 ++++++------ src/ott/neural/flows/samplers.py | 15 ++--- src/ott/tools/soft_sort.py | 2 +- tests/solvers/quadratic/lower_bound_test.py | 63 +++++++++++++++++++++ 4 files changed, 89 insertions(+), 26 deletions(-) create mode 100644 tests/solvers/quadratic/lower_bound_test.py diff --git a/src/ott/neural/flows/layers.py b/src/ott/neural/flows/layers.py index 84a526b1f..d18980c38 100644 --- a/src/ott/neural/flows/layers.py +++ b/src/ott/neural/flows/layers.py @@ -11,37 +11,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import abc import jax.numpy as jnp import flax.linen as nn -__all__ = ["TimeEncoder", "CyclicalTimeEncoder"] +__all__ = ["CyclicalTimeEncoder"] -class TimeEncoder(nn.Module, abc.ABC): - """A time encoder.""" - - @abc.abstractmethod - def __call__(self, t: jnp.ndarray) -> jnp.ndarray: - """Encode the time. 
- - Args: - t: Input time of shape (batch_size, 1). - - Returns: - The encoded time. - """ - pass +class CyclicalTimeEncoder(nn.Module): + r"""A cyclical time encoder. + Encodes time :math:`t` as + :math:`cos(\tilde{t})` and :math:`sin(\tilde{t})` + where :math:`\tilde{t} = [2\\pi t, 2\\pi 2 t,\\ldots, 2\\pi n_frequencies t]` -class CyclicalTimeEncoder(nn.Module): - """A cyclical time encoder.""" + Args: + n_frequencies: Frequency of cyclical encoding. + """ n_frequencies: int = 128 @nn.compact def __call__(self, t: jnp.ndarray) -> jnp.ndarray: # noqa: D102 + """Encode time :math:`t` into a cyclical representation. + + Args: + t: Time of shape ``[n, 1]``. + + Returns: + Encoded time of shape ``[n, 2 * n_frequencies]`` + """ freq = 2 * jnp.arange(self.n_frequencies) * jnp.pi t = freq * t return jnp.concatenate((jnp.cos(t), jnp.sin(t)), axis=-1) diff --git a/src/ott/neural/flows/samplers.py b/src/ott/neural/flows/samplers.py index 1bfee16b4..30373380a 100644 --- a/src/ott/neural/flows/samplers.py +++ b/src/ott/neural/flows/samplers.py @@ -26,10 +26,12 @@ def uniform_sampler( high: float = 1.0, offset: Optional[float] = None ) -> jnp.ndarray: - """Sample from a uniform distribution. + r"""Sample from a uniform distribution. - Sample :math:`t` from a uniform distribution :math:`[low, high]` with - offset `offset`. + Sample :math:`t` from a uniform distribution :math:`[low, high]`. + If `offset` is not :obj:`None`, one element :math:`t` is sampled from + :math:`[low, high]` and the K samples are constructed via + :math:`(t + k)/K \mod (high - low - offset) + low`. Args: rng: Random number generator. 
@@ -44,7 +46,6 @@ def uniform_sampler( """ if offset is None: return jax.random.uniform(rng, (num_samples, 1), minval=low, maxval=high) - return ( - jax.random.uniform(rng, (1, 1), minval=low, maxval=high) + - jnp.arange(num_samples)[:, None] / num_samples - ) % ((high - low) - offset) + t = jax.random.uniform(rng, (1, 1), minval=low, maxval=high) + mod_term = ((high - low) - offset) + return (t + jnp.arange(num_samples)[:, None] / num_samples) % mod_term diff --git a/src/ott/tools/soft_sort.py b/src/ott/tools/soft_sort.py index 1a30359ee..9e2e0c5d0 100644 --- a/src/ott/tools/soft_sort.py +++ b/src/ott/tools/soft_sort.py @@ -457,7 +457,7 @@ def _quantile( def multivariate_cdf_quantile_maps( inputs: jnp.ndarray, - target_sampler: Optional[Callable[[jnp.ndarray, Tuple[int, int]], + target_sampler: Optional[Callable[[jax.Array, Tuple[int, int]], jax.Array]] = None, rng: Optional[jax.Array] = None, num_target_samples: Optional[int] = None, diff --git a/tests/solvers/quadratic/lower_bound_test.py b/tests/solvers/quadratic/lower_bound_test.py new file mode 100644 index 000000000..08353f711 --- /dev/null +++ b/tests/solvers/quadratic/lower_bound_test.py @@ -0,0 +1,63 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +import jax +import jax.numpy as jnp + +from ott.geometry import costs, distrib_costs, pointcloud +from ott.problems.quadratic import quadratic_problem +from ott.solvers.quadratic import lower_bound + + +class TestLowerBoundSolver: + + @pytest.fixture(autouse=True) + def initialize(self, rng: jax.Array): + d_x = 2 + d_y = 3 + self.n, self.m = 13, 15 + rngs = jax.random.split(rng, 4) + self.x = jax.random.uniform(rngs[0], (self.n, d_x)) + self.y = jax.random.uniform(rngs[1], (self.m, d_y)) + # Currently the Lower Bound only supports uniform distributions: + a = jnp.ones(self.n) + b = jnp.ones(self.m) + self.a = a / jnp.sum(a) + self.b = b / jnp.sum(b) + self.cx = jax.random.uniform(rngs[2], (self.n, self.n)) + self.cy = jax.random.uniform(rngs[3], (self.m, self.m)) + + @pytest.mark.fast.with_args( + "ground_cost", + [costs.SqEuclidean(), costs.PNormP(1.5)], + only_fast=0, + ) + def test_lb_pointcloud(self, ground_cost: costs.TICost): + x, y = self.x, self.y + + geom_x = pointcloud.PointCloud(x) + geom_y = pointcloud.PointCloud(y) + prob = quadratic_problem.QuadraticProblem( + geom_x, geom_y, a=self.a, b=self.b + ) + distrib_cost = distrib_costs.UnivariateWasserstein(ground_cost=ground_cost) + solver = lower_bound.LowerBoundSolver( + epsilon=1e-1, distrib_cost=distrib_cost + ) + + out = jax.jit(solver)(prob) + + assert not jnp.isnan(out.reg_ot_cost) From 831d3ea4a8b15eb9a7914fd780b65e86cd2fc305 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Feb 2024 15:05:37 +0100 Subject: [PATCH 082/186] move noise sampling to flows --- src/ott/neural/flows/flows.py | 6 ++++-- src/ott/neural/flows/genot.py | 25 ++++--------------------- src/ott/neural/flows/models.py | 2 -- src/ott/neural/flows/otfm.py | 27 +++++++-------------------- 4 files changed, 15 insertions(+), 45 deletions(-) diff --git a/src/ott/neural/flows/flows.py b/src/ott/neural/flows/flows.py index c379dcbc3..51e19fb5c 100644 --- a/src/ott/neural/flows/flows.py +++ 
b/src/ott/neural/flows/flows.py @@ -14,6 +14,7 @@ import abc import jax.numpy as jnp +import jax __all__ = [ "BaseFlow", @@ -75,7 +76,7 @@ def compute_ut( """ def compute_xt( - self, noise: jnp.ndarray, t: jnp.ndarray, src: jnp.ndarray, + self, rng: jax.Array, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray ) -> jnp.ndarray: """Sample from the probability path. @@ -84,7 +85,7 @@ def compute_xt( time :math:`t`. Args: - noise: Noise sampled from a standard normal distribution. + rng: Random number generator. t: Time :math:`t`. src: Sample from the source distribution. tgt: Sample from the target distribution. @@ -93,6 +94,7 @@ def compute_xt( Samples from the probability path between :math:`x_0` and :math:`x_1` at time :math:`t`. """ + noise = jax.random.normal(rng, shape=(src.shape)) mu_t = self.compute_mu_t(t, src, tgt) sigma_t = self.compute_sigma_t(t) return mu_t + sigma_t * noise diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index 736d9e268..981f958dc 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -217,7 +217,6 @@ def __call__(self, train_loader, valid_loader): ) if batch["source_lin"] is not None else len(batch["source_quad"]) n_samples = batch_size * self.k_samples_per_x batch["time"] = self.time_sampler(rng_time, n_samples) - batch["noise"] = self.sample_noise(rng_noise, n_samples) batch["latent"] = self.latent_noise_fn( rng_noise, shape=(self.k_samples_per_x, batch_size) ) @@ -309,7 +308,7 @@ def loss_fn( rng: jax.random.PRNGKeyArray ): x_t = self.flow.compute_xt( - batch["noise"], batch["time"], batch["latent"], batch["target"] + rng, batch["time"], batch["latent"], batch["target"] ) apply_fn = functools.partial( state_velocity_field.apply_fn, {"params": params} @@ -322,16 +321,14 @@ def loss_fn( ], axis=1) v_t = jax.vmap(apply_fn - )(t=batch["time"], x=x_t, condition=cond_input, rng=rng) + )(t=batch["time"], x=x_t, condition=cond_input) u_t = self.flow.compute_ut( batch["time"], 
batch["latent"], batch["target"] ) return jnp.mean((v_t - u_t) ** 2) - keys_model = jax.random.split(rng, len(batch["noise"])) - grad_fn = jax.value_and_grad(loss_fn, has_aux=False) - loss, grads = grad_fn(state_velocity_field.params, batch, keys_model) + loss, grads = grad_fn(state_velocity_field.params, batch, rng) return state_velocity_field.apply_gradients(grads=grads), loss @@ -434,18 +431,4 @@ def load(self, path: str) -> "GENOT": @property def training_logs(self) -> Dict[str, Any]: """Logs of the training.""" - raise NotImplementedError - - def sample_noise( - self, key: jax.random.PRNGKey, batch_size: int - ) -> jnp.ndarray: - """Sample noise from a standard-normal distribution. - - Args: - key: Random key for seeding. - batch_size: Number of samples to draw. - - Returns: - Samples from the standard normal distribution. - """ - return jax.random.normal(key, shape=(batch_size, self.output_dim)) + raise NotImplementedError \ No newline at end of file diff --git a/src/ott/neural/flows/models.py b/src/ott/neural/flows/models.py index 6970e8368..bf365e772 100644 --- a/src/ott/neural/flows/models.py +++ b/src/ott/neural/flows/models.py @@ -95,7 +95,6 @@ def __call__( t: jnp.ndarray, x: jnp.ndarray, condition: Optional[jnp.ndarray] = None, - rng: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: """Forward pass through the neural vector field. @@ -103,7 +102,6 @@ def __call__( t: Time of shape (batch_size, 1). x: Data of shape (batch_size, output_dim). condition: Conditioning vector. - rng: Random number generator. Returns: Output of the neural vector field. 
diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index ef55e1dc5..e5bdea711 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -33,6 +33,7 @@ __all__ = ["OTFlowMatching"] + class OTFlowMatching( base_solver.ResampleMixin, ): @@ -145,30 +146,29 @@ def step_fn( ) -> Tuple[Any, Any]: def loss_fn( - params: jnp.ndarray, t: jnp.ndarray, noise: jnp.ndarray, - batch: Dict[str, jnp.ndarray], rng: jax.random.PRNGKeyArray + params: jnp.ndarray, t: jnp.ndarray, + batch: Dict[str, jnp.ndarray], rng: jax.Array ) -> jnp.ndarray: x_t = self.flow.compute_xt( - noise, t, batch["source_lin"], batch["target_lin"] + rng, t, batch["source_lin"], batch["target_lin"] ) apply_fn = functools.partial( state_velocity_field.apply_fn, {"params": params} ) v_t = jax.vmap(apply_fn)( - t=t, x=x_t, condition=batch["source_conditions"], rng=rng + t=t, x=x_t, condition=batch["source_conditions"] ) u_t = self.flow.compute_ut(t, batch["source_lin"], batch["target_lin"]) return jnp.mean((v_t - u_t) ** 2) batch_size = len(batch["source_lin"]) - key_noise, key_t, key_model = jax.random.split(rng, 3) + key_t, key_model = jax.random.split(rng, 2) keys_model = jax.random.split(key_model, batch_size) t = self.time_sampler(key_t, batch_size) - noise = self.sample_noise(key_noise, batch_size) grad_fn = jax.value_and_grad(loss_fn) loss, grads = grad_fn( - state_velocity_field.params, t, noise, batch, keys_model + state_velocity_field.params, t, batch, keys_model ) return state_velocity_field.apply_gradients(grads=grads), loss @@ -312,16 +312,3 @@ def training_logs(self) -> Dict[str, Any]: """Logs of the training.""" raise NotImplementedError - def sample_noise( - self, key: jax.random.PRNGKey, batch_size: int - ) -> jnp.ndarray: - """Sample noise from a standard-normal distribution. - - Args: - key: Random key for seeding. - batch_size: Number of samples to draw. - - Returns: - Samples from the standard normal distribution. 
- """ - return jax.random.normal(key, shape=(batch_size, self.input_dim)) From c18c461019db2368d32c42c730eb8006b2961444 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Feb 2024 15:14:13 +0100 Subject: [PATCH 083/186] fix bug in passing rngs in otfm --- src/ott/neural/flows/flows.py | 7 +++---- src/ott/neural/flows/genot.py | 5 ++--- src/ott/neural/flows/otfm.py | 16 +++++----------- src/ott/neural/flows/samplers.py | 2 +- 4 files changed, 11 insertions(+), 19 deletions(-) diff --git a/src/ott/neural/flows/flows.py b/src/ott/neural/flows/flows.py index 51e19fb5c..65f697d89 100644 --- a/src/ott/neural/flows/flows.py +++ b/src/ott/neural/flows/flows.py @@ -13,8 +13,8 @@ # limitations under the License. import abc -import jax.numpy as jnp import jax +import jax.numpy as jnp __all__ = [ "BaseFlow", @@ -76,8 +76,7 @@ def compute_ut( """ def compute_xt( - self, rng: jax.Array, t: jnp.ndarray, src: jnp.ndarray, - tgt: jnp.ndarray + self, rng: jax.Array, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray ) -> jnp.ndarray: """Sample from the probability path. @@ -94,7 +93,7 @@ def compute_xt( Samples from the probability path between :math:`x_0` and :math:`x_1` at time :math:`t`. 
""" - noise = jax.random.normal(rng, shape=(src.shape)) + noise = jax.random.normal(rng, shape=src.shape) mu_t = self.compute_mu_t(t, src, tgt) sigma_t = self.compute_sigma_t(t) return mu_t + sigma_t * noise diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index 981f958dc..73a4cb1bc 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -320,8 +320,7 @@ def loss_fn( if batch[el] is not None ], axis=1) - v_t = jax.vmap(apply_fn - )(t=batch["time"], x=x_t, condition=cond_input) + v_t = jax.vmap(apply_fn)(t=batch["time"], x=x_t, condition=cond_input) u_t = self.flow.compute_ut( batch["time"], batch["latent"], batch["target"] ) @@ -431,4 +430,4 @@ def load(self, path: str) -> "GENOT": @property def training_logs(self) -> Dict[str, Any]: """Logs of the training.""" - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index e5bdea711..e8233153e 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -33,7 +33,6 @@ __all__ = ["OTFlowMatching"] - class OTFlowMatching( base_solver.ResampleMixin, ): @@ -146,8 +145,8 @@ def step_fn( ) -> Tuple[Any, Any]: def loss_fn( - params: jnp.ndarray, t: jnp.ndarray, - batch: Dict[str, jnp.ndarray], rng: jax.Array + params: jnp.ndarray, t: jnp.ndarray, batch: Dict[str, jnp.ndarray], + rng: jax.Array ) -> jnp.ndarray: x_t = self.flow.compute_xt( @@ -156,20 +155,16 @@ def loss_fn( apply_fn = functools.partial( state_velocity_field.apply_fn, {"params": params} ) - v_t = jax.vmap(apply_fn)( - t=t, x=x_t, condition=batch["source_conditions"] - ) + v_t = jax.vmap(apply_fn + )(t=t, x=x_t, condition=batch["source_conditions"]) u_t = self.flow.compute_ut(t, batch["source_lin"], batch["target_lin"]) return jnp.mean((v_t - u_t) ** 2) batch_size = len(batch["source_lin"]) key_t, key_model = jax.random.split(rng, 2) - keys_model = jax.random.split(key_model, batch_size) t 
= self.time_sampler(key_t, batch_size) grad_fn = jax.value_and_grad(loss_fn) - loss, grads = grad_fn( - state_velocity_field.params, t, batch, keys_model - ) + loss, grads = grad_fn(state_velocity_field.params, t, batch, key_model) return state_velocity_field.apply_gradients(grads=grads), loss return step_fn @@ -311,4 +306,3 @@ def load(self, path: str) -> "OTFlowMatching": def training_logs(self) -> Dict[str, Any]: """Logs of the training.""" raise NotImplementedError - diff --git a/src/ott/neural/flows/samplers.py b/src/ott/neural/flows/samplers.py index 30373380a..34a28c2d2 100644 --- a/src/ott/neural/flows/samplers.py +++ b/src/ott/neural/flows/samplers.py @@ -29,7 +29,7 @@ def uniform_sampler( r"""Sample from a uniform distribution. Sample :math:`t` from a uniform distribution :math:`[low, high]`. - If `offset` is not :obj:`None`, one element :math:`t` is sampled from + If `offset` is not :obj:`None`, one element :math:`t` is sampled from :math:`[low, high]` and the K samples are constructed via :math:`(t + k)/K \mod (high - low - offset) + low`. From 83418215c9c23fa1ebd817fb681e4e823cb8d565 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Feb 2024 15:41:57 +0100 Subject: [PATCH 084/186] introduce otmatcher in otfm --- src/ott/neural/flows/genot.py | 19 ++++------ src/ott/neural/flows/otfm.py | 53 ++++++++-------------------- src/ott/neural/models/base_solver.py | 47 ++++++++++++++++++++---- tests/neural/otfm_test.py | 16 ++++++--- 4 files changed, 73 insertions(+), 62 deletions(-) diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index 73a4cb1bc..02a203a24 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -33,9 +33,7 @@ __all__ = ["GENOT"] -class GENOT( - base_solver.ResampleMixin, -): +class GENOT: """The GENOT training class as introduced in :cite:`klein_uscidda:23`. 
Args: @@ -113,7 +111,6 @@ def __init__( rng: Optional[jax.Array] = None, ): rng = utils.default_prng_key(rng) - base_solver.ResampleMixin.__init__(self) if isinstance( ot_solver, gromov_wasserstein.GromovWasserstein @@ -252,7 +249,6 @@ def __call__(self, train_loader, valid_loader): (batch["target"],), source_is_balanced=(self.unbalancedness_handler.tau_a == 1.0) ) - jax.random.split(rng_noise, batch_size * self.k_samples_per_x) if self.solver_latent_to_data is not None: tmats_latent_data = jnp.array( @@ -339,7 +335,7 @@ def transport( condition: Optional[jnp.ndarray] = None, rng: Optional[jax.Array] = None, forward: bool = True, - diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}), + **kwargs: Any, ) -> Union[jnp.array, diffrax.Solution, Optional[jnp.ndarray]]: """Transport data with the learnt plan. @@ -352,7 +348,7 @@ def transport( condition: Condition of the input data. rng: random seed for sampling from the latent distribution. forward: If `True` integrates forward, otherwise backwards. - diffeqsolve_kwargs: Keyword arguments for the ODE solver. + kwargs: Keyword arguments for the ODE solver. 
Returns: The push-forward or pull-back distribution defined by the learnt @@ -362,7 +358,6 @@ def transport( rng = utils.default_prng_key(rng) if not forward: raise NotImplementedError - diffeqsolve_kwargs = dict(diffeqsolve_kwargs) assert len(source) == len(condition) if condition is not None else True latent_batch = self.latent_noise_fn(rng, shape=(len(source),)) @@ -382,16 +377,16 @@ def solve_ode(input: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: x=x, condition=cond) ), - diffeqsolve_kwargs.pop("solver", diffrax.Tsit5()), + kwargs.pop("solver", diffrax.Tsit5()), t0=t0, t1=t1, - dt0=diffeqsolve_kwargs.pop("dt0", None), + dt0=kwargs.pop("dt0", None), y0=input, - stepsize_controller=diffeqsolve_kwargs.pop( + stepsize_controller=kwargs.pop( "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5) ), - **diffeqsolve_kwargs, + **kwargs, ).ys[0] return jax.vmap(solve_ode)(latent_batch, cond_input) diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index e8233153e..de0c8c5e5 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -13,7 +13,6 @@ # limitations under the License. import collections import functools -import types from typing import Any, Callable, Dict, Literal, Mapping, Optional, Tuple, Type, Union import jax @@ -28,14 +27,11 @@ from ott.geometry import costs from ott.neural.flows import flows from ott.neural.models import base_solver -from ott.solvers import was_solver __all__ = ["OTFlowMatching"] -class OTFlowMatching( - base_solver.ResampleMixin, -): +class OTFlowMatching: """(Optimal transport) flow matching class. Flow matching as introduced in :cite:`lipman:22`, with extension to OT-FM @@ -47,19 +43,10 @@ class OTFlowMatching( cond_dim: Dimension of the conditioning variable. iterations: Number of iterations. valid_freq: Frequency of validation. - ot_solver: OT solver to match samples from the source and the target - distribution as proposed in :cite:`tong:23`, :cite:`pooladian:23`. 
- If :obj:`None`, no matching will be performed as proposed in - :cite:`lipman:22`. flow: Flow between source and target distribution. time_sampler: Sampler for the time. optimizer: Optimizer for `velocity_field`. checkpoint_manager: Checkpoint manager. - epsilon: Entropy regularization term of the OT OT problem solved by the - `ot_solver`. - cost_fn: Cost function for the OT problem solved by the `ot_solver`. - scale_cost: How to scale the cost matrix for the OT problem solved by the - `ot_solver`. callback_fn: Callback function. num_eval_samples: Number of samples to evaluate on during evaluation. rng: Random number generator. @@ -73,10 +60,10 @@ def __init__( input_dim: int, cond_dim: int, iterations: int, - ot_solver: Optional[Type[was_solver.WassersteinSolver]], flow: Type[flows.BaseFlow], time_sampler: Callable[[jax.Array, int], jnp.ndarray], optimizer: optax.GradientTransformation, + ot_matcher: base_solver.OTMatcher, unbalancedness_handler: base_solver.UnbalancednessHandler, checkpoint_manager: Type[checkpoint.CheckpointManager] = None, epsilon: float = 1e-2, @@ -92,13 +79,12 @@ def __init__( rng: Optional[jax.Array] = None, ): rng = utils.default_prng_key(rng) - base_solver.ResampleMixin.__init__(self) self.unbalancedness_handler = unbalancedness_handler self.iterations = iterations self.valid_freq = valid_freq self.velocity_field = velocity_field self.input_dim = input_dim - self.ot_solver = ot_solver + self.ot_matcher = ot_matcher self.flow = flow self.time_sampler = time_sampler self.optimizer = optimizer @@ -123,17 +109,6 @@ def setup(self) -> None: ) self.step_fn = self._get_step_fn() - if self.ot_solver is not None: - self.match_fn = self._get_sinkhorn_match_fn( - self.ot_solver, - epsilon=self.epsilon, - cost_fn=self.cost_fn, - scale_cost=self.scale_cost, - tau_a=self.unbalancedness_handler.tau_a, - tau_b=self.unbalancedness_handler.tau_b, - ) - else: - self.match_fn = None def _get_step_fn(self) -> Callable: @@ -182,11 +157,13 @@ def __call__(self, 
train_loader, valid_loader): for iter in range(self.iterations): rng_resample, rng_step_fn, self.rng = jax.random.split(self.rng, 3) batch = next(train_loader) - if self.ot_solver is not None: - tmat = self.match_fn(batch["source_lin"], batch["target_lin"]) + if self.ot_matcher is not None: + tmat = self.ot_matcher.match_fn( + batch["source_lin"], batch["target_lin"] + ) (batch["source_lin"], batch["source_conditions"] ), (batch["target_lin"], - batch["target_conditions"]) = self._resample_data( + batch["target_conditions"]) = self.ot_matcher._resample_data( rng_resample, tmat, (batch["source_lin"], batch["source_conditions"]), (batch["target_lin"], batch["target_conditions"]) @@ -222,7 +199,7 @@ def transport( forward: bool = True, t_0: float = 0.0, t_1: float = 1.0, - diffeqsolve_kwargs: Dict[str, Any] = types.MappingProxyType({}) + **kwargs: Any, ) -> diffrax.Solution: """Transport data with the learnt map. @@ -236,15 +213,13 @@ def transport( forward: If `True` integrates forward, otherwise backwards. t_0: Starting point of integration. t_1: End point of integration. - diffeqsolve_kwargs: Keyword arguments for the ODE solver. + kwargs: Keyword arguments for the ODE solver. Returns: The push-forward or pull-back distribution defined by the learnt transport plan. 
""" - diffeqsolve_kwargs = dict(diffeqsolve_kwargs) - t0, t1 = (t_0, t_1) if forward else (t_1, t_0) @jax.jit @@ -257,16 +232,16 @@ def solve_ode(input: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: x=x, condition=cond) ), - diffeqsolve_kwargs.pop("solver", diffrax.Tsit5()), + kwargs.pop("solver", diffrax.Tsit5()), t0=t0, t1=t1, - dt0=diffeqsolve_kwargs.pop("dt0", None), + dt0=kwargs.pop("dt0", None), y0=input, - stepsize_controller=diffeqsolve_kwargs.pop( + stepsize_controller=kwargs.pop( "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5) ), - **diffeqsolve_kwargs, + **kwargs, ).ys[0] return jax.vmap(solve_ode)(data, condition) diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index 1bc541ec7..43675d4fc 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -26,7 +26,7 @@ from ott.solvers.linear import sinkhorn from ott.solvers.quadratic import gromov_wasserstein -__all__ = ["ResampleMixin", "UnbalancednessHandler"] +__all__ = ["OTMatcher", "UnbalancednessHandler"] def _get_sinkhorn_match_fn( @@ -132,8 +132,37 @@ def match_pairs( return match_pairs -class ResampleMixin: - """Mixin class for mini-batch OT in neural optimal transport solvers.""" +class OTMatcher: + """Class for mini-batch OT in neural optimal transport solvers. + + Args: + ot_solver: OT solver to match samples from the source and the target + distribution as proposed in :cite:`tong:23`, :cite:`pooladian:23`. + If :obj:`None`, no matching will be performed as proposed in + :cite:`lipman:22`. 
+ """ + + def __init__( + self, + ot_solver: was_solver.WassersteinSolver, + epsilon: float = 1e-2, + cost_fn: Optional[costs.CostFn] = None, + scale_cost: Union[bool, int, float, + Literal["mean", "max_norm", "max_bound", "max_cost", + "median"]] = "mean", + tau_a: float = 1.0, + tau_b: float = 1.0 + ) -> None: + self.ot_solver = ot_solver + self.epsilon = epsilon + self.cost_fn = cost_fn + self.scale_cost = scale_cost + self.tau_a = tau_a + self.tau_b = tau_b + self.match_fn = self._get_sinkhorn_match_fn( + self.ot_solver, self.epsilon, self.cost_fn, self.scale_cost, self.tau_a, + self.tau_b + ) def _resample_data( self, @@ -233,19 +262,21 @@ class UnbalancednessHandler: cond_dim: Dimension of the conditioning variable. If :obj:`None`, no conditioning is used. tau_a: Unbalancedness parameter for the source distribution. + Only used if `ot_solver` is not :obj:`None`. tau_b: Unbalancedness parameter for the target distribution. + Only used if `ot_solver` is not :obj:`None`. rescaling_a: Rescaling function for the source distribution. - If :obj:`None`, the left rescaling factor is not learnt. + If :obj:`None`, the left rescaling factor is not learnt. rescaling_b: Rescaling function for the target distribution. - If :obj:`None`, the right rescaling factor is not learnt. + If :obj:`None`, the right rescaling factor is not learnt. opt_eta: Optimizer for the left rescaling function. opt_xi: Optimzier for the right rescaling function. resample_epsilon: Epsilon for resampling. scale_cost: Scaling of the cost matrix for estimating the rescaling factors. ot_solver: Solver to compute unbalanced marginals. If `ot_solver` is `None`, - the method + the method :meth:`ott.neural.models.base_solver.UnbalancednessHandler.compute_unbalanced_marginals` - is not available, and hence the unbalanced marginals must be computed by the neural solver. + is not available, and hence the unbalanced marginals must be computed by the neural solver. kwargs: Additional keyword arguments. 
""" @@ -296,6 +327,8 @@ def __init__( ) elif isinstance(ot_solver, gromov_wasserstein.GromovWasserstein): self.compute_unbalanced_marginals = self._get_compute_unbalanced_marginals_quad + else: + self.compute_unbalanced_marginals = None self.setup(source_dim=source_dim, target_dim=target_dim, cond_dim=cond_dim) def _get_compute_unbalanced_marginals_lin( diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index e4d22a789..c3675b820 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -46,6 +46,7 @@ def test_flow_matching( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() + ot_matcher = base_solver.OTMatcher(ot_solver) time_sampler = samplers.uniform_sampler optimizer = optax.adam(learning_rate=1e-3) unbalancedness_handler = base_solver.UnbalancednessHandler( @@ -57,7 +58,7 @@ def test_flow_matching( cond_dim=condition_dim, iterations=3, valid_freq=2, - ot_solver=ot_solver, + ot_matcher=ot_matcher, flow=flow, time_sampler=time_sampler, optimizer=optimizer, @@ -98,6 +99,7 @@ def test_flow_matching_with_conditions( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() + ot_matcher = base_solver.OTMatcher(ot_solver) time_sampler = functools.partial(samplers.uniform_sampler, offset=1e-5) optimizer = optax.adam(learning_rate=1e-3) unbalancedness_handler = base_solver.UnbalancednessHandler( @@ -110,7 +112,7 @@ def test_flow_matching_with_conditions( cond_dim=1, iterations=3, valid_freq=2, - ot_solver=ot_solver, + ot_matcher=ot_matcher, flow=flow, time_sampler=time_sampler, optimizer=optimizer, @@ -154,6 +156,7 @@ def test_flow_matching_conditional( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() + ot_matcher = base_solver.OTMatcher(ot_solver) time_sampler = samplers.uniform_sampler optimizer = optax.adam(learning_rate=1e-3) unbalancedness_handler = base_solver.UnbalancednessHandler( @@ -166,7 +169,7 @@ def test_flow_matching_conditional( cond_dim=condition_dim, iterations=3, valid_freq=2, - ot_solver=ot_solver, + 
ot_matcher=ot_matcher, flow=flow, time_sampler=time_sampler, optimizer=optimizer, @@ -215,6 +218,11 @@ def test_flow_matching_learn_rescaling( tau_b = 0.2 rescaling_a = nets.RescalingMLP(hidden_dim=4, condition_dim=condition_dim) rescaling_b = nets.RescalingMLP(hidden_dim=4, condition_dim=condition_dim) + ot_matcher = base_solver.OTMatcher( + ot_solver, + tau_a=tau_a, + tau_b=tau_b, + ) unbalancedness_handler = base_solver.UnbalancednessHandler( random.PRNGKey(0), source_dim, @@ -232,7 +240,7 @@ def test_flow_matching_learn_rescaling( cond_dim=condition_dim, iterations=3, valid_freq=2, - ot_solver=ot_solver, + ot_matcher=ot_matcher, flow=flow, time_sampler=time_sampler, optimizer=optimizer, From 3cae628bdae6546f28e38cf3506efe8c44804360 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Feb 2024 16:36:25 +0100 Subject: [PATCH 085/186] [ci skip] split GENOT into GENOTLin and GENOTQuad --- src/ott/neural/flows/genot.py | 346 ++++++++++++++------------- src/ott/neural/flows/otfm.py | 2 +- src/ott/neural/models/base_solver.py | 118 ++++++--- tests/neural/genot_test.py | 340 +++++++++++++------------- tests/neural/otfm_test.py | 8 +- 5 files changed, 445 insertions(+), 369 deletions(-) diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index 02a203a24..bd825ad40 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -13,7 +13,7 @@ # limitations under the License. 
import functools import types -from typing import Any, Callable, Dict, Literal, Optional, Type, Union +from typing import Any, Callable, Dict, Optional, Type, Union import jax import jax.numpy as jnp @@ -23,17 +23,13 @@ from flax.training import train_state from ott import utils -from ott.geometry import costs from ott.neural.flows import flows, samplers from ott.neural.models import base_solver -from ott.solvers import was_solver -from ott.solvers.linear import sinkhorn -from ott.solvers.quadratic import gromov_wasserstein -__all__ = ["GENOT"] +__all__ = ["GENOTBase", "GENOTLin", "GENOTQuad"] -class GENOT: +class GENOTBase: """The GENOT training class as introduced in :cite:`klein_uscidda:23`. Args: @@ -87,23 +83,14 @@ def __init__( cond_dim: int, iterations: int, valid_freq: int, - ot_solver: was_solver.WassersteinSolver, - epsilon: float, - cost_fn: Union[costs.CostFn, Dict[str, costs.CostFn]], - scale_cost: Union[Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", - "max_cost", "median"]], - Dict[str, Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", - "max_cost", "median"]]]], + ot_matcher: base_solver.BaseOTMatcher, unbalancedness_handler: base_solver.UnbalancednessHandler, optimizer: optax.GradientTransformation, flow: Type[flows.BaseFlow] = flows.ConstantNoiseFlow(0.0), # noqa: B008 time_sampler: Callable[[jax.Array, int], jnp.ndarray] = samplers.uniform_sampler, k_samples_per_x: int = 1, - solver_latent_to_data: Optional[Type[was_solver.WassersteinSolver] - ] = None, + matcher_latent_to_data: Optional[base_solver.OTMatcherLinear] = None, kwargs_solver_latent_to_data: Dict[str, Any] = types.MappingProxyType({}), fused_penalty: float = 0.0, callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], @@ -112,15 +99,6 @@ def __init__( ): rng = utils.default_prng_key(rng) - if isinstance( - ot_solver, gromov_wasserstein.GromovWasserstein - ) and epsilon is not None: - raise ValueError( - "If `ot_solver` is 
`GromovWasserstein`, `epsilon` must be `None`. " + - "This check is performed to ensure that in the (fused) Gromov case " + - "the `epsilon` parameter is passed via the `ot_solver`." - ) - self.rng = utils.default_prng_key(rng) self.iterations = iterations self.valid_freq = valid_freq @@ -129,6 +107,7 @@ def __init__( self.flow = flow self.time_sampler = time_sampler self.optimizer = optimizer + self.ot_matcher = ot_matcher self.latent_noise_fn = jax.tree_util.Partial( jax.random.multivariate_normal, mean=jnp.zeros((output_dim,)), @@ -143,14 +122,11 @@ def __init__( self.unbalancedness_handler = unbalancedness_handler # OT data-data matching parameters - self.ot_solver = ot_solver - self.epsilon = epsilon - self.cost_fn = cost_fn - self.scale_cost = scale_cost + self.fused_penalty = fused_penalty # OT latent-data matching parameters - self.solver_latent_to_data = solver_latent_to_data + self.matcher_latent_to_data = matcher_latent_to_data self.kwargs_solver_latent_to_data = kwargs_solver_latent_to_data # callback parameteres @@ -165,130 +141,6 @@ def setup(self) -> None: ) ) self.step_fn = self._get_step_fn() - if self.solver_latent_to_data is not None: - self.match_latent_to_data_fn = self._get_sinkhorn_match_fn( - ot_solver=self.solver_latent_to_data, - **self.kwargs_solver_latent_to_data - ) - else: - self.match_latent_to_data_fn = lambda key, x, y, **_: (x, y) - - # TODO: add graph construction function - if isinstance(self.ot_solver, sinkhorn.Sinkhorn): - self.match_fn = self._get_sinkhorn_match_fn( - ot_solver=self.ot_solver, - epsilon=self.epsilon, - cost_fn=self.cost_fn, - scale_cost=self.scale_cost, - tau_a=self.unbalancedness_handler.tau_a, - tau_b=self.unbalancedness_handler.tau_b, - filter_input=True - ) - else: - self.match_fn = self._get_gromov_match_fn( - ot_solver=self.ot_solver, - cost_fn=self.cost_fn, - scale_cost=self.scale_cost, - tau_a=self.unbalancedness_handler.tau_a, - tau_b=self.unbalancedness_handler.tau_b, - 
fused_penalty=self.fused_penalty - ) - - def __call__(self, train_loader, valid_loader): - """Train GENOT. - - Args: - train_loader: Data loader for the training data. - valid_loader: Data loader for the validation data. - """ - batch: Dict[str, jnp.array] = {} - for iteration in range(self.iterations): - batch = next(train_loader) - - ( - self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, - rng_step_fn - ) = jax.random.split(self.rng, 6) - batch_size = len( - batch["source_lin"] - ) if batch["source_lin"] is not None else len(batch["source_quad"]) - n_samples = batch_size * self.k_samples_per_x - batch["time"] = self.time_sampler(rng_time, n_samples) - batch["latent"] = self.latent_noise_fn( - rng_noise, shape=(self.k_samples_per_x, batch_size) - ) - - tmat = self.match_fn( - batch["source_lin"], batch["source_quad"], batch["target_lin"], - batch["target_quad"] - ) - - batch["source"] = jnp.concatenate([ - batch[el] - for el in ["source_lin", "source_quad"] - if batch[el] is not None - ], - axis=1) - batch["target"] = jnp.concatenate([ - batch[el] - for el in ["target_lin", "target_quad"] - if batch[el] is not None - ], - axis=1) - - batch = { - k: v for k, v in batch.items() if k in - ["source", "target", "source_conditions", "time", "noise", "latent"] - } - - (batch["source"], batch["source_conditions"] - ), (batch["target"],) = self._sample_conditional_indices_from_tmap( - rng_resample, - tmat, - self.k_samples_per_x, (batch["source"], batch["source_conditions"]), - (batch["target"],), - source_is_balanced=(self.unbalancedness_handler.tau_a == 1.0) - ) - - if self.solver_latent_to_data is not None: - tmats_latent_data = jnp.array( - jax.vmap(self.match_latent_to_data_fn, 0, - 0)(x=batch["latent"], y=batch["target"]) - ) - - rng_latent_data_match = jax.random.split( - rng_latent_data_match, self.k_samples_per_x - ) - (batch["source"], batch["source_conditions"] - ), (batch["target"],) = jax.vmap(self._resample_data, 0, 0)( - 
rng_latent_data_match, tmats_latent_data, - (batch["source"], batch["source_conditions"]), (batch["target"],) - ) - batch = { - key: - jnp.reshape(arr, (batch_size * self.k_samples_per_x, - -1)) if arr is not None else None - for key, arr in batch.items() - } - - self.state_velocity_field, loss = self.step_fn( - rng_step_fn, self.state_velocity_field, batch - ) - if self.learn_rescaling: - ( - self.state_eta, self.state_xi, eta_predictions, xi_predictions, - loss_a, loss_b - ) = self.unbalancedness_handler.step_fn( - source=batch["source"], - target=batch["target"], - condition=batch["source_conditions"], - a=tmat.sum(axis=1), - b=tmat.sum(axis=0), - state_eta=self.unbalancedness_handler.state_eta, - state_xi=self.unbalancedness_handler.state_xi, - ) - if iteration % self.valid_freq == 0: - self._valid_step(valid_loader, iteration) def _get_step_fn(self) -> Callable: @@ -403,26 +255,178 @@ def learn_rescaling(self) -> bool: self.unbalancedness_handler.rescaling_b is not None ) - def save(self, path: str): - """Save the model. + +class GENOTLin(GENOTBase): + + def __call__(self, train_loader, valid_loader): + """Train GENOT. Args: - path: Where to save the model to. + train_loader: Data loader for the training data. + valid_loader: Data loader for the validation data. """ - raise NotImplementedError + batch: Dict[str, jnp.array] = {} + for iteration in range(self.iterations): + batch = next(train_loader) + + ( + self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, + rng_step_fn + ) = jax.random.split(self.rng, 6) + batch_size = len(batch["source_lin"]) + n_samples = batch_size * self.k_samples_per_x + batch["time"] = self.time_sampler(rng_time, n_samples) + batch["latent"] = self.latent_noise_fn( + rng_noise, shape=(self.k_samples_per_x, batch_size) + ) - def load(self, path: str) -> "GENOT": - """Load a model. + tmat = self.ot_matcher.match_fn( + batch["source_lin"], + batch["target_lin"], + ) - Args: - path: Where to load the model from. 
+ batch["source"] = batch["source_lin"] + batch["target"] = batch["target_lin"] - Returns: - An instance of :class:`ott.neural.solvers.OTFlowMatching`. + (batch["source"], batch["source_conditions"]), ( + batch["target"], + ) = self.ot_matcher._sample_conditional_indices_from_tmap( + rng_resample, + tmat, + self.k_samples_per_x, (batch["source"], batch["source_conditions"]), + (batch["target"],), + source_is_balanced=(self.unbalancedness_handler.tau_a == 1.0) + ) + + if self.matcher_latent_to_data.match_fn is not None: + tmats_latent_data = jnp.array( + jax.vmap(self.matcher_latent_to_data.match_fn, 0, + 0)(x=batch["latent"], y=batch["target"]) + ) + + rng_latent_data_match = jax.random.split( + rng_latent_data_match, self.k_samples_per_x + ) + (batch["source"], batch["source_conditions"] + ), (batch["target"],) = jax.vmap(self.ot_matcher._resample_data, 0, 0)( + rng_latent_data_match, tmats_latent_data, + (batch["source"], batch["source_conditions"]), (batch["target"],) + ) + batch = { + key: + jnp.reshape(arr, (batch_size * self.k_samples_per_x, + -1)) if arr is not None else None + for key, arr in batch.items() + } + + self.state_velocity_field, loss = self.step_fn( + rng_step_fn, self.state_velocity_field, batch + ) + if self.learn_rescaling: + ( + self.state_eta, self.state_xi, eta_predictions, xi_predictions, + loss_a, loss_b + ) = self.unbalancedness_handler.step_fn( + source=batch["source"], + target=batch["target"], + condition=batch["source_conditions"], + a=tmat.sum(axis=1), + b=tmat.sum(axis=0), + state_eta=self.unbalancedness_handler.state_eta, + state_xi=self.unbalancedness_handler.state_xi, + ) + if iteration % self.valid_freq == 0: + self._valid_step(valid_loader, iteration) + + +class GENOTQuad(GENOTBase): + + def __call__(self, train_loader, valid_loader): + """Train GENOT. + + Args: + train_loader: Data loader for the training data. + valid_loader: Data loader for the validation data. 
""" - raise NotImplementedError + batch: Dict[str, jnp.array] = {} + for iteration in range(self.iterations): + batch = next(train_loader) - @property - def training_logs(self) -> Dict[str, Any]: - """Logs of the training.""" - raise NotImplementedError + ( + self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, + rng_step_fn + ) = jax.random.split(self.rng, 6) + batch_size = len( + batch["source_lin"] + ) if batch["source_lin"] is not None else len(batch["source_quad"]) + n_samples = batch_size * self.k_samples_per_x + batch["time"] = self.time_sampler(rng_time, n_samples) + batch["latent"] = self.latent_noise_fn( + rng_noise, shape=(self.k_samples_per_x, batch_size) + ) + + tmat = self.ot_matcher.match_fn( + batch["source_lin"], batch["source_quad"], batch["target_lin"], + batch["target_quad"] + ) + + if self.ot_matcher.fused_penalty > 0.0: + batch["source"] = jnp.concatenate( + (batch["source_lin"], batch["source_quad"]), axis=1 + ) + batch["target"] = jnp.concatenate( + (batch["target_lin"], batch["target_quad"]), axis=1 + ) + else: + batch["source"] = batch["source_quad"] + batch["target"] = batch["target_quad"] + + (batch["source"], batch["source_conditions"]), ( + batch["target"], + ) = self.ot_matcher._sample_conditional_indices_from_tmap( + rng_resample, + tmat, + self.k_samples_per_x, (batch["source"], batch["source_conditions"]), + (batch["target"],), + source_is_balanced=(self.unbalancedness_handler.tau_a == 1.0) + ) + + if self.matcher_latent_to_data.match_fn is not None: + tmats_latent_data = jnp.array( + jax.vmap(self.matcher_latent_to_data.match_fn, 0, + 0)(x=batch["latent"], y=batch["target"]) + ) + + rng_latent_data_match = jax.random.split( + rng_latent_data_match, self.k_samples_per_x + ) + (batch["source"], batch["source_conditions"] + ), (batch["target"],) = jax.vmap(self.ot_matcher._resample_data, 0, 0)( + rng_latent_data_match, tmats_latent_data, + (batch["source"], batch["source_conditions"]), (batch["target"],) + ) + batch = 
{ + key: + jnp.reshape(arr, (batch_size * self.k_samples_per_x, + -1)) if arr is not None else None + for key, arr in batch.items() + } + + self.state_velocity_field, loss = self.step_fn( + rng_step_fn, self.state_velocity_field, batch + ) + if self.learn_rescaling: + ( + self.state_eta, self.state_xi, eta_predictions, xi_predictions, + loss_a, loss_b + ) = self.unbalancedness_handler.step_fn( + source=batch["source"], + target=batch["target"], + condition=batch["source_conditions"], + a=tmat.sum(axis=1), + b=tmat.sum(axis=0), + state_eta=self.unbalancedness_handler.state_eta, + state_xi=self.unbalancedness_handler.state_xi, + ) + if iteration % self.valid_freq == 0: + self._valid_step(valid_loader, iteration) diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index de0c8c5e5..84fcd5e96 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -63,7 +63,7 @@ def __init__( flow: Type[flows.BaseFlow], time_sampler: Callable[[jax.Array, int], jnp.ndarray], optimizer: optax.GradientTransformation, - ot_matcher: base_solver.OTMatcher, + ot_matcher: base_solver.OTMatcherLinear, unbalancedness_handler: base_solver.UnbalancednessHandler, checkpoint_manager: Type[checkpoint.CheckpointManager] = None, epsilon: float = 1e-2, diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index 43675d4fc..a4238cd05 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -26,7 +26,9 @@ from ott.solvers.linear import sinkhorn from ott.solvers.quadratic import gromov_wasserstein -__all__ = ["OTMatcher", "UnbalancednessHandler"] +__all__ = [ + "BaseOTMatcher", "OTMatcherLinear", "OTMatcherQuad", "UnbalancednessHandler" +] def _get_sinkhorn_match_fn( @@ -132,37 +134,8 @@ def match_pairs( return match_pairs -class OTMatcher: - """Class for mini-batch OT in neural optimal transport solvers. 
- - Args: - ot_solver: OT solver to match samples from the source and the target - distribution as proposed in :cite:`tong:23`, :cite:`pooladian:23`. - If :obj:`None`, no matching will be performed as proposed in - :cite:`lipman:22`. - """ - - def __init__( - self, - ot_solver: was_solver.WassersteinSolver, - epsilon: float = 1e-2, - cost_fn: Optional[costs.CostFn] = None, - scale_cost: Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", "max_cost", - "median"]] = "mean", - tau_a: float = 1.0, - tau_b: float = 1.0 - ) -> None: - self.ot_solver = ot_solver - self.epsilon = epsilon - self.cost_fn = cost_fn - self.scale_cost = scale_cost - self.tau_a = tau_a - self.tau_b = tau_b - self.match_fn = self._get_sinkhorn_match_fn( - self.ot_solver, self.epsilon, self.cost_fn, self.scale_cost, self.tau_a, - self.tau_b - ) +class BaseOTMatcher: + """Base class for mini-batch neural OT matching classes.""" def _resample_data( self, @@ -229,6 +202,48 @@ def _sample_conditional_indices_from_tmap( for b in target_arrays ) + +class OTMatcherLinear(BaseOTMatcher): + """Class for mini-batch OT in neural optimal transport solvers. + + Args: + ot_solver: OT solver to match samples from the source and the target + distribution as proposed in :cite:`tong:23`, :cite:`pooladian:23`. + If :obj:`None`, no matching will be performed as proposed in + :cite:`lipman:22`. + """ + + def __init__( + self, + ot_solver: was_solver.WassersteinSolver, + epsilon: float = 1e-2, + cost_fn: Optional[costs.CostFn] = None, + scale_cost: Union[bool, int, float, + Literal["mean", "max_norm", "max_bound", "max_cost", + "median"]] = "mean", + tau_a: float = 1.0, + tau_b: float = 1.0, + ) -> None: + + if isinstance( + ot_solver, gromov_wasserstein.GromovWasserstein + ) and epsilon is not None: + raise ValueError( + "If `ot_solver` is `GromovWasserstein`, `epsilon` must be `None`. 
" + + "This check is performed to ensure that in the (fused) Gromov case " + + "the `epsilon` parameter is passed via the `ot_solver`." + ) + self.ot_solver = ot_solver + self.epsilon = epsilon + self.cost_fn = cost_fn + self.scale_cost = scale_cost + self.tau_a = tau_a + self.tau_b = tau_b + self.match_fn = None if ot_solver is None else self._get_sinkhorn_match_fn( + self.ot_solver, self.epsilon, self.cost_fn, self.scale_cost, self.tau_a, + self.tau_b + ) + def _get_sinkhorn_match_fn(self, *args, **kwargs) -> jnp.ndarray: fn = _get_sinkhorn_match_fn(*args, **kwargs) @@ -238,6 +253,43 @@ def match_pairs(*args, **kwargs): return match_pairs + +class OTMatcherQuad(BaseOTMatcher): + """Class for mini-batch OT in neural optimal transport solvers. + + Args: + ot_solver: OT solver to match samples from the source and the target + distribution as proposed in :cite:`tong:23`, :cite:`pooladian:23`. + If :obj:`None`, no matching will be performed as proposed in + :cite:`lipman:22`. + """ + + def __init__( + self, + ot_solver: was_solver.WassersteinSolver, + cost_fn: Optional[costs.CostFn] = None, + scale_cost: Union[bool, int, float, + Literal["mean", "max_norm", "max_bound", "max_cost", + "median"]] = "mean", + tau_a: float = 1.0, + tau_b: float = 1.0, + fused_penalty: float = 0.0, + ) -> None: + self.ot_solver = ot_solver + self.cost_fn = cost_fn + self.scale_cost = scale_cost + self.tau_a = tau_a + self.tau_b = tau_b + self.fused_penalty = fused_penalty + self.match_fn = self._get_gromov_match_fn( + self.ot_solver, + self.cost_fn, + self.scale_cost, + self.tau_a, + self.tau_b, + fused_penalty=self.fused_penalty + ) + def _get_gromov_match_fn(self, *args, **kwargs) -> jnp.ndarray: fn = _get_gromov_match_fn(*args, **kwargs) @@ -326,7 +378,7 @@ def __init__( ) ) elif isinstance(ot_solver, gromov_wasserstein.GromovWasserstein): - self.compute_unbalanced_marginals = self._get_compute_unbalanced_marginals_quad + raise NotImplementedError else: self.compute_unbalanced_marginals 
= None self.setup(source_dim=source_dim, target_dim=target_dim, cond_dim=cond_dim) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 191b04d2a..7ab2a957c 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -22,7 +22,7 @@ import optax from ott.geometry import costs -from ott.neural.flows.genot import GENOT +from ott.neural.flows.genot import GENOTLin, GENOTQuad from ott.neural.flows.models import VelocityField from ott.neural.flows.samplers import uniform_sampler from ott.neural.models import base_solver @@ -31,7 +31,7 @@ from ott.solvers.quadratic import gromov_wasserstein -class TestGENOT: +class TestGENOTLin: @pytest.mark.parametrize("scale_cost", ["mean", 2.0]) @pytest.mark.parametrize("k_samples_per_x", [1, 3]) @@ -41,7 +41,7 @@ def test_genot_linear_unconditional( scale_cost: Union[float, Literal["mean"]], k_samples_per_x: int, solver_latent_to_data: Optional[str] ): - solver_latent_to_data = ( + matcher_latent_to_data = base_solver.OTMatcherLinear( None if solver_latent_to_data is None else sinkhorn.Sinkhorn() ) batch = next(genot_data_loader_linear) @@ -58,27 +58,27 @@ def test_genot_linear_unconditional( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() + ot_matcher = base_solver.OTMatcherLinear( + ot_solver, cost_fn=costs.SqEuclidean(), scale_cost=scale_cost + ) unbalancedness_handler = base_solver.UnbalancednessHandler( random.PRNGKey(0), source_dim, target_dim, condition_dim ) time_sampler = uniform_sampler optimizer = optax.adam(learning_rate=1e-3) - genot = GENOT( + genot = GENOTLin( neural_vf, input_dim=source_dim, output_dim=target_dim, cond_dim=condition_dim, iterations=3, valid_freq=2, - ot_solver=ot_solver, - epsilon=0.1, - cost_fn=costs.SqEuclidean(), - scale_cost=scale_cost, + ot_matcher=ot_matcher, optimizer=optimizer, time_sampler=time_sampler, unbalancedness_handler=unbalancedness_handler, k_samples_per_x=k_samples_per_x, - solver_latent_to_data=solver_latent_to_data, + 
matcher_latent_to_data=matcher_latent_to_data, ) genot(genot_data_loader_linear, genot_data_loader_linear) @@ -94,69 +94,157 @@ def test_genot_linear_unconditional( @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) - def test_genot_quad_unconditional( - self, genot_data_loader_quad: Iterator, k_samples_per_x: int, - solver_latent_to_data: Optional[str] + def test_genot_linear_conditional( + self, genot_data_loader_linear_conditional: Iterator, + k_samples_per_x: int, solver_latent_to_data: Optional[str] ): - None if solver_latent_to_data is None else sinkhorn.Sinkhorn() - batch = next(genot_data_loader_quad) - source_quad, target_quad, source_condition = batch["source_quad"], batch[ - "target_quad"], batch["source_conditions"] + matcher_latent_to_data = base_solver.OTMatcherLinear( + None if solver_latent_to_data is None else sinkhorn.Sinkhorn() + ) + batch = next(genot_data_loader_linear_conditional) + source_lin, target_lin, source_condition = batch["source_lin"], batch[ + "target_lin"], batch["source_conditions"] + source_dim = source_lin.shape[1] + target_dim = target_lin.shape[1] + condition_dim = source_condition.shape[1] - source_dim = source_quad.shape[1] - target_dim = target_quad.shape[1] - condition_dim = 0 neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, latent_embed_dim=5, ) - ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) - + ot_solver = sinkhorn.Sinkhorn() + ot_matcher = base_solver.OTMatcherLinear( + ot_solver, cost_fn=costs.SqEuclidean() + ) + time_sampler = uniform_sampler unbalancedness_handler = base_solver.UnbalancednessHandler( random.PRNGKey(0), source_dim, target_dim, condition_dim ) - time_sampler = functools.partial(uniform_sampler, offset=1e-2) optimizer = optax.adam(learning_rate=1e-3) - genot = GENOT( + genot = GENOTLin( neural_vf, input_dim=source_dim, output_dim=target_dim, cond_dim=condition_dim, 
iterations=3, valid_freq=2, - ot_solver=ot_solver, - epsilon=None, - cost_fn=costs.SqEuclidean(), + ot_matcher=ot_matcher, unbalancedness_handler=unbalancedness_handler, - scale_cost=1.0, optimizer=optimizer, time_sampler=time_sampler, k_samples_per_x=k_samples_per_x, + matcher_latent_to_data=matcher_latent_to_data, + ) + genot( + genot_data_loader_linear_conditional, + genot_data_loader_linear_conditional ) - genot(genot_data_loader_quad, genot_data_loader_quad) - result_forward = genot.transport( - source_quad, condition=source_condition, forward=True + source_lin, condition=source_condition, forward=True ) assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 + @pytest.mark.parametrize("conditional", [False, True]) + @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) + def test_genot_linear_learn_rescaling( + self, conditional: bool, genot_data_loader_linear: Iterator, + solver_latent_to_data: Optional[str], + genot_data_loader_linear_conditional: Iterator + ): + matcher_latent_to_data = base_solver.OTMatcherLinear( + None if solver_latent_to_data is None else sinkhorn.Sinkhorn() + ) + data_loader = ( + genot_data_loader_linear_conditional + if conditional else genot_data_loader_linear + ) + + batch = next(data_loader) + source_lin, target_lin, source_condition = batch["source_lin"], batch[ + "target_lin"], batch["source_conditions"] + + source_dim = source_lin.shape[1] + target_dim = target_lin.shape[1] + condition_dim = source_condition.shape[1] if conditional else 0 + + neural_vf = VelocityField( + output_dim=target_dim, + condition_dim=source_dim + condition_dim, + latent_embed_dim=5, + ) + ot_solver = sinkhorn.Sinkhorn() + ot_matcher = base_solver.OTMatcherLinear( + ot_solver, + cost_fn=costs.SqEuclidean(), + ) + time_sampler = uniform_sampler + optimizer = optax.adam(learning_rate=1e-3) + + tau_a = 0.9 + tau_b = 0.2 + rescaling_a = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) + 
rescaling_b = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) + + unbalancedness_handler = base_solver.UnbalancednessHandler( + random.PRNGKey(0), + source_dim, + target_dim, + condition_dim, + tau_a=tau_a, + tau_b=tau_b, + rescaling_a=rescaling_a, + rescaling_b=rescaling_b + ) + + genot = GENOTLin( + neural_vf, + input_dim=source_dim, + output_dim=target_dim, + cond_dim=condition_dim, + iterations=3, + valid_freq=2, + ot_matcher=ot_matcher, + optimizer=optimizer, + time_sampler=time_sampler, + unbalancedness_handler=unbalancedness_handler, + matcher_latent_to_data=matcher_latent_to_data, + ) + + genot(data_loader, data_loader) + + result_eta = genot.unbalancedness_handler.evaluate_eta( + source_lin, condition=source_condition + ) + assert isinstance(result_eta, jnp.ndarray) + assert jnp.sum(jnp.isnan(result_eta)) == 0 + + result_xi = genot.unbalancedness_handler.evaluate_xi( + target_lin, condition=source_condition + ) + assert isinstance(result_xi, jnp.ndarray) + assert jnp.sum(jnp.isnan(result_xi)) == 0 + + +class TestGENOTQuad: + @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) - def test_genot_fused_unconditional( - self, genot_data_loader_fused: Iterator, k_samples_per_x: int, + def test_genot_quad_unconditional( + self, genot_data_loader_quad: Iterator, k_samples_per_x: int, solver_latent_to_data: Optional[str] ): - None if solver_latent_to_data is None else sinkhorn.Sinkhorn() - batch = next(genot_data_loader_fused) - source_lin, source_quad, target_lin, target_quad, source_condition = batch[ - "source_lin"], batch["source_quad"], batch["target_lin"], batch[ - "target_quad"], batch["source_conditions"] + matcher_latent_to_data = base_solver.OTMatcherLinear( + None if solver_latent_to_data is None else sinkhorn.Sinkhorn() + ) + batch = next(genot_data_loader_quad) + source_quad, target_quad, source_condition = batch["source_quad"], batch[ + "target_quad"], 
batch["source_conditions"] - source_dim = source_lin.shape[1] + source_quad.shape[1] - target_dim = target_lin.shape[1] + target_quad.shape[1] + source_dim = source_quad.shape[1] + target_dim = target_quad.shape[1] condition_dim = 0 neural_vf = VelocityField( output_dim=target_dim, @@ -164,85 +252,89 @@ def test_genot_fused_unconditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + ot_matcher = base_solver.OTMatcherQuad( + ot_solver, cost_fn=costs.SqEuclidean() + ) + unbalancedness_handler = base_solver.UnbalancednessHandler( random.PRNGKey(0), source_dim, target_dim, condition_dim ) + time_sampler = functools.partial(uniform_sampler, offset=1e-2) optimizer = optax.adam(learning_rate=1e-3) - genot = GENOT( + genot = GENOTQuad( neural_vf, input_dim=source_dim, output_dim=target_dim, cond_dim=condition_dim, - epsilon=None, iterations=3, valid_freq=2, - ot_solver=ot_solver, - cost_fn=costs.SqEuclidean(), - scale_cost=1.0, + ot_matcher=ot_matcher, unbalancedness_handler=unbalancedness_handler, optimizer=optimizer, - fused_penalty=0.5, + time_sampler=time_sampler, k_samples_per_x=k_samples_per_x, + matcher_latent_to_data=matcher_latent_to_data, ) - genot(genot_data_loader_fused, genot_data_loader_fused) + genot(genot_data_loader_quad, genot_data_loader_quad) result_forward = genot.transport( - jnp.concatenate((source_lin, source_quad), axis=1), - condition=source_condition, - forward=True + source_quad, condition=source_condition, forward=True ) assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) - def test_genot_linear_conditional( - self, genot_data_loader_linear_conditional: Iterator, - k_samples_per_x: int, solver_latent_to_data: Optional[str] + def test_genot_fused_unconditional( + self, genot_data_loader_fused: Iterator, k_samples_per_x: int, + 
solver_latent_to_data: Optional[str] ): - None if solver_latent_to_data is None else sinkhorn.Sinkhorn() - batch = next(genot_data_loader_linear_conditional) - source_lin, target_lin, source_condition = batch["source_lin"], batch[ - "target_lin"], batch["source_conditions"] - source_dim = source_lin.shape[1] - target_dim = target_lin.shape[1] - condition_dim = source_condition.shape[1] + matcher_latent_to_data = base_solver.OTMatcherLinear( + None if solver_latent_to_data is None else sinkhorn.Sinkhorn() + ) + batch = next(genot_data_loader_fused) + source_lin, source_quad, target_lin, target_quad, source_condition = batch[ + "source_lin"], batch["source_quad"], batch["target_lin"], batch[ + "target_quad"], batch["source_conditions"] + source_dim = source_lin.shape[1] + source_quad.shape[1] + target_dim = target_lin.shape[1] + target_quad.shape[1] + condition_dim = 0 neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, latent_embed_dim=5, ) - ot_solver = sinkhorn.Sinkhorn() - time_sampler = uniform_sampler + ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + ot_matcher = base_solver.OTMatcherQuad( + ot_solver, cost_fn=costs.SqEuclidean(), fused_penalty=0.5 + ) + unbalancedness_handler = base_solver.UnbalancednessHandler( random.PRNGKey(0), source_dim, target_dim, condition_dim ) optimizer = optax.adam(learning_rate=1e-3) - genot = GENOT( + genot = GENOTQuad( neural_vf, input_dim=source_dim, output_dim=target_dim, cond_dim=condition_dim, iterations=3, valid_freq=2, - ot_solver=ot_solver, - epsilon=0.1, - cost_fn=costs.SqEuclidean(), - scale_cost=1.0, + ot_matcher=ot_matcher, unbalancedness_handler=unbalancedness_handler, optimizer=optimizer, - time_sampler=time_sampler, k_samples_per_x=k_samples_per_x, + matcher_latent_to_data=matcher_latent_to_data, ) - genot( - genot_data_loader_linear_conditional, - genot_data_loader_linear_conditional - ) + genot(genot_data_loader_fused, genot_data_loader_fused) + 
result_forward = genot.transport( - source_lin, condition=source_condition, forward=True + jnp.concatenate((source_lin, source_quad), axis=1), + condition=source_condition, + forward=True ) assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @@ -253,7 +345,9 @@ def test_genot_quad_conditional( self, genot_data_loader_quad_conditional: Iterator, k_samples_per_x: int, solver_latent_to_data: Optional[str] ): - None if solver_latent_to_data is None else sinkhorn.Sinkhorn() + matcher_latent_to_data = base_solver.OTMatcherLinear( + None if solver_latent_to_data is None else sinkhorn.Sinkhorn() + ) batch = next(genot_data_loader_quad_conditional) source_quad, target_quad, source_condition = batch["source_quad"], batch[ "target_quad"], batch["source_conditions"] @@ -267,27 +361,28 @@ def test_genot_quad_conditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + ot_matcher = base_solver.OTMatcherQuad( + ot_solver, cost_fn=costs.SqEuclidean() + ) time_sampler = uniform_sampler unbalancedness_handler = base_solver.UnbalancednessHandler( random.PRNGKey(0), source_dim, target_dim, condition_dim ) optimizer = optax.adam(learning_rate=1e-3) - genot = GENOT( + genot = GENOTQuad( neural_vf, input_dim=source_dim, output_dim=target_dim, cond_dim=condition_dim, iterations=3, valid_freq=2, - ot_solver=ot_solver, - epsilon=None, - cost_fn=costs.SqEuclidean(), - scale_cost=1.0, + ot_matcher=ot_matcher, unbalancedness_handler=unbalancedness_handler, optimizer=optimizer, time_sampler=time_sampler, k_samples_per_x=k_samples_per_x, + matcher_latent_to_data=matcher_latent_to_data, ) genot( genot_data_loader_quad_conditional, genot_data_loader_quad_conditional @@ -305,7 +400,7 @@ def test_genot_fused_conditional( self, genot_data_loader_fused_conditional: Iterator, k_samples_per_x: int, solver_latent_to_data: Optional[str] ): - None if solver_latent_to_data is None else sinkhorn.Sinkhorn() + 
matcher_latent_to_data = base_solver.OTMatcherLinear(solver_latent_to_data) batch = next(genot_data_loader_fused_conditional) source_lin, source_quad, target_lin, target_quad, source_condition = batch[ "source_lin"], batch["source_quad"], batch["target_lin"], batch[ @@ -319,27 +414,28 @@ def test_genot_fused_conditional( latent_embed_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + ot_matcher = base_solver.OTMatcherQuad( + ot_solver, cost_fn=costs.SqEuclidean(), fused_penalty=0.5 + ) time_sampler = uniform_sampler optimizer = optax.adam(learning_rate=1e-3) unbalancedness_handler = base_solver.UnbalancednessHandler( random.PRNGKey(0), source_dim, target_dim, condition_dim ) - genot = GENOT( + genot = GENOTQuad( neural_vf, input_dim=source_dim, output_dim=target_dim, cond_dim=condition_dim, iterations=3, valid_freq=2, - ot_solver=ot_solver, - epsilon=None, - cost_fn=costs.SqEuclidean(), - scale_cost=1.0, + ot_matcher=ot_matcher, unbalancedness_handler=unbalancedness_handler, optimizer=optimizer, time_sampler=time_sampler, k_samples_per_x=k_samples_per_x, + matcher_latent_to_data=matcher_latent_to_data, ) genot( genot_data_loader_fused_conditional, genot_data_loader_fused_conditional @@ -352,79 +448,3 @@ def test_genot_fused_conditional( ) assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 - - @pytest.mark.parametrize("conditional", [False, True]) - @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) - def test_genot_linear_learn_rescaling( - self, conditional: bool, genot_data_loader_linear: Iterator, - solver_latent_to_data: Optional[str], - genot_data_loader_linear_conditional: Iterator - ): - None if solver_latent_to_data is None else sinkhorn.Sinkhorn() - data_loader = ( - genot_data_loader_linear_conditional - if conditional else genot_data_loader_linear - ) - - batch = next(data_loader) - source_lin, target_lin, source_condition = batch["source_lin"], batch[ - "target_lin"], 
batch["source_conditions"] - - source_dim = source_lin.shape[1] - target_dim = target_lin.shape[1] - condition_dim = source_condition.shape[1] if conditional else 0 - - neural_vf = VelocityField( - output_dim=target_dim, - condition_dim=source_dim + condition_dim, - latent_embed_dim=5, - ) - ot_solver = sinkhorn.Sinkhorn() - time_sampler = uniform_sampler - optimizer = optax.adam(learning_rate=1e-3) - - tau_a = 0.9 - tau_b = 0.2 - rescaling_a = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) - rescaling_b = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) - - unbalancedness_handler = base_solver.UnbalancednessHandler( - random.PRNGKey(0), - source_dim, - target_dim, - condition_dim, - tau_a=tau_a, - tau_b=tau_b, - rescaling_a=rescaling_a, - rescaling_b=rescaling_b - ) - - genot = GENOT( - neural_vf, - input_dim=source_dim, - output_dim=target_dim, - cond_dim=condition_dim, - iterations=3, - valid_freq=2, - ot_solver=ot_solver, - epsilon=0.1, - cost_fn=costs.SqEuclidean(), - scale_cost=1.0, - optimizer=optimizer, - time_sampler=time_sampler, - unbalancedness_handler=unbalancedness_handler, - ) - - genot(data_loader, data_loader) - - result_eta = genot.unbalancedness_handler.evaluate_eta( - source_lin, condition=source_condition - ) - assert isinstance(result_eta, jnp.ndarray) - assert jnp.sum(jnp.isnan(result_eta)) == 0 - - result_xi = genot.unbalancedness_handler.evaluate_xi( - target_lin, condition=source_condition - ) - assert isinstance(result_xi, jnp.ndarray) - assert jnp.sum(jnp.isnan(result_xi)) == 0 diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index c3675b820..9452c2faa 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -46,7 +46,7 @@ def test_flow_matching( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - ot_matcher = base_solver.OTMatcher(ot_solver) + ot_matcher = base_solver.OTMatcherLinear(ot_solver) time_sampler = samplers.uniform_sampler optimizer = optax.adam(learning_rate=1e-3) 
unbalancedness_handler = base_solver.UnbalancednessHandler( @@ -99,7 +99,7 @@ def test_flow_matching_with_conditions( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - ot_matcher = base_solver.OTMatcher(ot_solver) + ot_matcher = base_solver.OTMatcherLinear(ot_solver) time_sampler = functools.partial(samplers.uniform_sampler, offset=1e-5) optimizer = optax.adam(learning_rate=1e-3) unbalancedness_handler = base_solver.UnbalancednessHandler( @@ -156,7 +156,7 @@ def test_flow_matching_conditional( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() - ot_matcher = base_solver.OTMatcher(ot_solver) + ot_matcher = base_solver.OTMatcherLinear(ot_solver) time_sampler = samplers.uniform_sampler optimizer = optax.adam(learning_rate=1e-3) unbalancedness_handler = base_solver.UnbalancednessHandler( @@ -218,7 +218,7 @@ def test_flow_matching_learn_rescaling( tau_b = 0.2 rescaling_a = nets.RescalingMLP(hidden_dim=4, condition_dim=condition_dim) rescaling_b = nets.RescalingMLP(hidden_dim=4, condition_dim=condition_dim) - ot_matcher = base_solver.OTMatcher( + ot_matcher = base_solver.OTMatcherLinear( ot_solver, tau_a=tau_a, tau_b=tau_b, From 20fbbb86633c39dce2d6647a894afa778745d57c Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 11 Feb 2024 13:06:09 +0100 Subject: [PATCH 086/186] remove dictionaries in OTFM and GENOT classes --- src/ott/neural/flows/genot.py | 190 +++++++++++++++------------ src/ott/neural/flows/otfm.py | 50 ++++--- src/ott/neural/models/base_solver.py | 2 +- tests/neural/genot_test.py | 3 + 4 files changed, 133 insertions(+), 112 deletions(-) diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index bd825ad40..7a2989b0e 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -13,7 +13,7 @@ # limitations under the License. 
import functools import types -from typing import Any, Callable, Dict, Optional, Type, Union +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union import jax import jax.numpy as jnp @@ -30,7 +30,13 @@ class GENOTBase: - """The GENOT training class as introduced in :cite:`klein_uscidda:23`. + """Base class for GENOT models (:cite:`klein_uscidda:23`). + + GENOT (Generative Entropic Neural Optimal Transport) is a neural solver + for entropic OT prooblems, in the linear + (:class:`ott.neural.flows.genot.GENOTLin`), the Gromov-Wasserstein, and + the Fused Gromov-Wasserstein ((:class:`ott.neural.flows.genot.GENOTQUad`)) + setting. Args: velocity_field: Neural vector field parameterized by a neural network. @@ -148,34 +154,35 @@ def _get_step_fn(self) -> Callable: def step_fn( rng: jax.Array, state_velocity_field: train_state.TrainState, - batch: Dict[str, jnp.array], + time: jnp.ndarray, + source: jnp.ndarray, + target: jnp.ndarray, + latent: jnp.ndarray, + source_conditions: Optional[jnp.ndarray], ): def loss_fn( - params: jnp.ndarray, batch: Dict[str, jnp.array], - rng: jax.random.PRNGKeyArray + params: jnp.ndarray, time: jnp.ndarray, source: jnp.ndarray, + target: jnp.ndarray, latent: jnp.ndarray, + source_conditions: Optional[jnp.ndarray], rng: jax.random.PRNGKeyArray ): - x_t = self.flow.compute_xt( - rng, batch["time"], batch["latent"], batch["target"] - ) + x_t = self.flow.compute_xt(rng, time, latent, target) apply_fn = functools.partial( state_velocity_field.apply_fn, {"params": params} ) cond_input = jnp.concatenate([ - batch[el] - for el in ["source", "source_conditions"] - if batch[el] is not None - ], - axis=1) - v_t = jax.vmap(apply_fn)(t=batch["time"], x=x_t, condition=cond_input) - u_t = self.flow.compute_ut( - batch["time"], batch["latent"], batch["target"] - ) + source, source_conditions + ], axis=1) if source_conditions is not None else source + v_t = jax.vmap(apply_fn)(t=time, x=x_t, condition=cond_input) + u_t = 
self.flow.compute_ut(time, latent, target) return jnp.mean((v_t - u_t) ** 2) grad_fn = jax.value_and_grad(loss_fn, has_aux=False) - loss, grads = grad_fn(state_velocity_field.params, batch, rng) + loss, grads = grad_fn( + state_velocity_field.params, time, source, target, latent, + source_conditions, rng + ) return state_velocity_field.apply_gradients(grads=grads), loss @@ -255,8 +262,20 @@ def learn_rescaling(self) -> bool: self.unbalancedness_handler.rescaling_b is not None ) + def _reshape_samples(self, arrays: Tuple[jnp.ndarray, ...], + batch_size: int) -> Tuple[jnp.ndarray, ...]: + return tuple( + jnp.reshape(arr, (batch_size * self.k_samples_per_x, + -1)) if arr is not None else None for arr in arrays + ) + class GENOTLin(GENOTBase): + """Implementation of GENOT-L (:cite:`klein:23`). + + GENOT-L (Generative Entropic Neural Optimal Transport, linear) solves the + entropic (linear) OT problem. + """ def __call__(self, train_loader, valid_loader): """Train GENOT. @@ -273,63 +292,61 @@ def __call__(self, train_loader, valid_loader): self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, rng_step_fn ) = jax.random.split(self.rng, 6) + source, source_conditions, target = batch["source_lin"], batch[ + "source_conditions"], batch["target_lin"] + batch_size = len(batch["source_lin"]) n_samples = batch_size * self.k_samples_per_x - batch["time"] = self.time_sampler(rng_time, n_samples) - batch["latent"] = self.latent_noise_fn( + time = self.time_sampler(rng_time, n_samples) + latent = self.latent_noise_fn( rng_noise, shape=(self.k_samples_per_x, batch_size) ) tmat = self.ot_matcher.match_fn( - batch["source_lin"], - batch["target_lin"], + source, + target, ) - batch["source"] = batch["source_lin"] - batch["target"] = batch["target_lin"] - - (batch["source"], batch["source_conditions"]), ( - batch["target"], - ) = self.ot_matcher._sample_conditional_indices_from_tmap( - rng_resample, - tmat, - self.k_samples_per_x, (batch["source"], 
batch["source_conditions"]), - (batch["target"],), + (source, source_conditions + ), (target,) = self.ot_matcher._sample_conditional_indices_from_tmap( + rng=rng_resample, + tmat=tmat, + k_samples_per_x=self.k_samples_per_x, + source_arrays=(source, source_conditions), + target_arrays=(target,), source_is_balanced=(self.unbalancedness_handler.tau_a == 1.0) ) if self.matcher_latent_to_data.match_fn is not None: tmats_latent_data = jnp.array( jax.vmap(self.matcher_latent_to_data.match_fn, 0, - 0)(x=batch["latent"], y=batch["target"]) + 0)(x=latent, y=target) ) rng_latent_data_match = jax.random.split( rng_latent_data_match, self.k_samples_per_x ) - (batch["source"], batch["source_conditions"] - ), (batch["target"],) = jax.vmap(self.ot_matcher._resample_data, 0, 0)( + (source, source_conditions + ), (target,) = jax.vmap(self.ot_matcher._resample_data, 0, 0)( rng_latent_data_match, tmats_latent_data, - (batch["source"], batch["source_conditions"]), (batch["target"],) + (source, source_conditions), (target,) ) - batch = { - key: - jnp.reshape(arr, (batch_size * self.k_samples_per_x, - -1)) if arr is not None else None - for key, arr in batch.items() - } + source, source_conditions, target, latent = self._reshape_samples( + (source, source_conditions, target, latent), batch_size + ) self.state_velocity_field, loss = self.step_fn( - rng_step_fn, self.state_velocity_field, batch + rng_step_fn, self.state_velocity_field, time, source, target, latent, + source_conditions ) if self.learn_rescaling: ( self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, loss_b ) = self.unbalancedness_handler.step_fn( - source=batch["source"], - target=batch["target"], - condition=batch["source_conditions"], + source=source, + target=target, + condition=source_conditions, a=tmat.sum(axis=1), b=tmat.sum(axis=0), state_eta=self.unbalancedness_handler.state_eta, @@ -340,6 +357,13 @@ def __call__(self, train_loader, valid_loader): class GENOTQuad(GENOTBase): + """Implementation 
of GENOT-Q and GENOT-F (:cite:`klein:23`). + + GENOT-Q (Generative Entropic Neural Optimal Transport, quadratic) and + GENOT-F (Generative Entropic Neural Optimal Transport, fused) solve the + entropic Gromov-Wasserstein and the entropic Fused Gromov-Wasserstein problem, + respectively. + """ def __call__(self, train_loader, valid_loader): """Train GENOT. @@ -356,73 +380,71 @@ def __call__(self, train_loader, valid_loader): self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, rng_step_fn ) = jax.random.split(self.rng, 6) - batch_size = len( - batch["source_lin"] - ) if batch["source_lin"] is not None else len(batch["source_quad"]) + (source_lin, source_quad, source_conditions, target_lin, target_quad) = ( + batch["source_lin"], batch["source_quad"], batch["source_conditions"], + batch["target_lin"], batch["target_quad"] + ) + batch_size = len(source_quad) n_samples = batch_size * self.k_samples_per_x - batch["time"] = self.time_sampler(rng_time, n_samples) - batch["latent"] = self.latent_noise_fn( + time = self.time_sampler(rng_time, n_samples) + latent = self.latent_noise_fn( rng_noise, shape=(self.k_samples_per_x, batch_size) ) tmat = self.ot_matcher.match_fn( - batch["source_lin"], batch["source_quad"], batch["target_lin"], - batch["target_quad"] + source_lin, source_quad, target_lin, target_quad ) if self.ot_matcher.fused_penalty > 0.0: - batch["source"] = jnp.concatenate( - (batch["source_lin"], batch["source_quad"]), axis=1 - ) - batch["target"] = jnp.concatenate( - (batch["target_lin"], batch["target_quad"]), axis=1 - ) + source = jnp.concatenate((source_lin, source_quad), axis=1) + target = jnp.concatenate((target_lin, target_quad), axis=1) else: - batch["source"] = batch["source_quad"] - batch["target"] = batch["target_quad"] - - (batch["source"], batch["source_conditions"]), ( - batch["target"], - ) = self.ot_matcher._sample_conditional_indices_from_tmap( - rng_resample, - tmat, - self.k_samples_per_x, (batch["source"], 
batch["source_conditions"]), - (batch["target"],), - source_is_balanced=(self.unbalancedness_handler.tau_a == 1.0) + source = source_quad + target = target_quad + + (source, source_conditions), (target,) = ( + self.ot_matcher._sample_conditional_indices_from_tmap( + rng=rng_resample, + tmat=tmat, + k_samples_per_x=self.k_samples_per_x, + source_arrays=(source, source_conditions), + target_arrays=(target,), + source_is_balanced=(self.unbalancedness_handler.tau_a == 1.0) + ) ) if self.matcher_latent_to_data.match_fn is not None: tmats_latent_data = jnp.array( jax.vmap(self.matcher_latent_to_data.match_fn, 0, - 0)(x=batch["latent"], y=batch["target"]) + 0)(x=latent, y=target) ) rng_latent_data_match = jax.random.split( rng_latent_data_match, self.k_samples_per_x ) - (batch["source"], batch["source_conditions"] - ), (batch["target"],) = jax.vmap(self.ot_matcher._resample_data, 0, 0)( + + (source, source_conditions + ), (target,) = jax.vmap(self.ot_matcher._resample_data, 0, 0)( rng_latent_data_match, tmats_latent_data, - (batch["source"], batch["source_conditions"]), (batch["target"],) + (source, source_conditions), (target,) ) - batch = { - key: - jnp.reshape(arr, (batch_size * self.k_samples_per_x, - -1)) if arr is not None else None - for key, arr in batch.items() - } + + source, source_conditions, target, latent = self._reshape_samples( + (source, source_conditions, target, latent), batch_size + ) self.state_velocity_field, loss = self.step_fn( - rng_step_fn, self.state_velocity_field, batch + rng_step_fn, self.state_velocity_field, time, source, target, latent, + source_conditions ) if self.learn_rescaling: ( self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, loss_b ) = self.unbalancedness_handler.step_fn( - source=batch["source"], - target=batch["target"], - condition=batch["source_conditions"], + source=source, + target=target, + condition=source_conditions, a=tmat.sum(axis=1), b=tmat.sum(axis=0), 
state_eta=self.unbalancedness_handler.state_eta, diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index 84fcd5e96..0027bf345 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -21,7 +21,6 @@ import diffrax import optax from flax.training import train_state -from orbax import checkpoint from ott import utils from ott.geometry import costs @@ -46,7 +45,6 @@ class OTFlowMatching: flow: Flow between source and target distribution. time_sampler: Sampler for the time. optimizer: Optimizer for `velocity_field`. - checkpoint_manager: Checkpoint manager. callback_fn: Callback function. num_eval_samples: Number of samples to evaluate on during evaluation. rng: Random number generator. @@ -65,7 +63,6 @@ def __init__( optimizer: optax.GradientTransformation, ot_matcher: base_solver.OTMatcherLinear, unbalancedness_handler: base_solver.UnbalancednessHandler, - checkpoint_manager: Type[checkpoint.CheckpointManager] = None, epsilon: float = 1e-2, cost_fn: Optional[Type[costs.CostFn]] = None, scale_cost: Union[bool, int, float, @@ -92,7 +89,6 @@ def __init__( self.cost_fn = cost_fn self.scale_cost = scale_cost self.callback_fn = callback_fn - self.checkpoint_manager = checkpoint_manager self.rng = rng self.logging_freq = logging_freq self.num_eval_samples = num_eval_samples @@ -116,30 +112,33 @@ def _get_step_fn(self) -> Callable: def step_fn( rng: jax.Array, state_velocity_field: train_state.TrainState, - batch: Dict[str, jnp.ndarray], + source: jnp.ndarray, + target: jnp.ndarray, + source_conditions: Optional[jnp.ndarray], ) -> Tuple[Any, Any]: def loss_fn( - params: jnp.ndarray, t: jnp.ndarray, batch: Dict[str, jnp.ndarray], + params: jnp.ndarray, t: jnp.ndarray, source: jnp.ndarray, + target: jnp.ndarray, source_conditions: Optional[jnp.ndarray], rng: jax.Array ) -> jnp.ndarray: - x_t = self.flow.compute_xt( - rng, t, batch["source_lin"], batch["target_lin"] - ) + x_t = self.flow.compute_xt(rng, t, source, target) apply_fn = 
functools.partial( state_velocity_field.apply_fn, {"params": params} ) - v_t = jax.vmap(apply_fn - )(t=t, x=x_t, condition=batch["source_conditions"]) - u_t = self.flow.compute_ut(t, batch["source_lin"], batch["target_lin"]) + v_t = jax.vmap(apply_fn)(t=t, x=x_t, condition=source_conditions) + u_t = self.flow.compute_ut(t, source, target) return jnp.mean((v_t - u_t) ** 2) - batch_size = len(batch["source_lin"]) + batch_size = len(source) key_t, key_model = jax.random.split(rng, 2) t = self.time_sampler(key_t, batch_size) grad_fn = jax.value_and_grad(loss_fn) - loss, grads = grad_fn(state_velocity_field.params, t, batch, key_model) + loss, grads = grad_fn( + state_velocity_field.params, t, source, target, source_conditions, + key_model + ) return state_velocity_field.apply_gradients(grads=grads), loss return step_fn @@ -157,19 +156,16 @@ def __call__(self, train_loader, valid_loader): for iter in range(self.iterations): rng_resample, rng_step_fn, self.rng = jax.random.split(self.rng, 3) batch = next(train_loader) + source, source_conditions, target = batch["source_lin"], batch[ + "source_conditions"], batch["target_lin"] if self.ot_matcher is not None: - tmat = self.ot_matcher.match_fn( - batch["source_lin"], batch["target_lin"] + tmat = self.ot_matcher.match_fn(source, target) + (source, source_conditions), (target,) = self.ot_matcher._resample_data( + rng_resample, tmat, (source, source_conditions), (target,) ) - (batch["source_lin"], batch["source_conditions"] - ), (batch["target_lin"], - batch["target_conditions"]) = self.ot_matcher._resample_data( - rng_resample, tmat, - (batch["source_lin"], batch["source_conditions"]), - (batch["target_lin"], batch["target_conditions"]) - ) self.state_velocity_field, loss = self.step_fn( - rng_step_fn, self.state_velocity_field, batch + rng_step_fn, self.state_velocity_field, source, target, + source_conditions ) curr_loss += loss if iter % self.logging_freq == 0: @@ -181,9 +177,9 @@ def __call__(self, train_loader, 
valid_loader): self.unbalancedness_handler.state_xi, eta_predictions, xi_predictions, loss_a, loss_b ) = self.unbalancedness_handler.step_fn( - source=batch["source_lin"], - target=batch["target_lin"], - condition=batch["source_conditions"], + source=source, + target=target, + condition=source_conditions, a=tmat.sum(axis=1), b=tmat.sum(axis=0), state_eta=self.unbalancedness_handler.state_eta, diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index a4238cd05..a55370403 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -159,10 +159,10 @@ def _sample_conditional_indices_from_tmap( self, rng: jax.Array, tmat: jnp.ndarray, + *, k_samples_per_x: Union[int, jnp.ndarray], source_arrays: Tuple[jnp.ndarray, ...], target_arrays: Tuple[jnp.ndarray, ...], - *, source_is_balanced: bool, ) -> Tuple[jnp.array, jnp.array]: batch_size = tmat.shape[0] diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 7ab2a957c..9880f3f2b 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -400,6 +400,9 @@ def test_genot_fused_conditional( self, genot_data_loader_fused_conditional: Iterator, k_samples_per_x: int, solver_latent_to_data: Optional[str] ): + solver_latent_to_data = ( + None if solver_latent_to_data is None else sinkhorn.Sinkhorn() + ) matcher_latent_to_data = base_solver.OTMatcherLinear(solver_latent_to_data) batch = next(genot_data_loader_fused_conditional) source_lin, source_quad, target_lin, target_quad, source_condition = batch[ From 525ef64a14b550a0e39ad2640f3f7916dfffb46b Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 11 Feb 2024 13:13:24 +0100 Subject: [PATCH 087/186] change logic in match_latent_to_data in genot --- src/ott/neural/flows/genot.py | 8 +++---- src/ott/neural/flows/otfm.py | 2 +- tests/neural/genot_test.py | 40 +++++++++++++++++++++++------------ 3 files changed, 32 insertions(+), 18 deletions(-) diff --git 
a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index 7a2989b0e..267818cc3 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -34,8 +34,8 @@ class GENOTBase: GENOT (Generative Entropic Neural Optimal Transport) is a neural solver for entropic OT prooblems, in the linear - (:class:`ott.neural.flows.genot.GENOTLin`), the Gromov-Wasserstein, and - the Fused Gromov-Wasserstein ((:class:`ott.neural.flows.genot.GENOTQUad`)) + (:class:`ott.neural.flows.genot.GENOTLin`), the Gromov-Wasserstein, and + the Fused Gromov-Wasserstein ((:class:`ott.neural.flows.genot.GENOTQUad`)) setting. Args: @@ -317,7 +317,7 @@ def __call__(self, train_loader, valid_loader): source_is_balanced=(self.unbalancedness_handler.tau_a == 1.0) ) - if self.matcher_latent_to_data.match_fn is not None: + if self.matcher_latent_to_data is not None: tmats_latent_data = jnp.array( jax.vmap(self.matcher_latent_to_data.match_fn, 0, 0)(x=latent, y=target) @@ -413,7 +413,7 @@ def __call__(self, train_loader, valid_loader): ) ) - if self.matcher_latent_to_data.match_fn is not None: + if self.matcher_latent_to_data is not None: tmats_latent_data = jnp.array( jax.vmap(self.matcher_latent_to_data.match_fn, 0, 0)(x=latent, y=target) diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index 0027bf345..0e8f616f9 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -61,7 +61,7 @@ def __init__( flow: Type[flows.BaseFlow], time_sampler: Callable[[jax.Array, int], jnp.ndarray], optimizer: optax.GradientTransformation, - ot_matcher: base_solver.OTMatcherLinear, + ot_matcher: Optional[base_solver.OTMatcherLinear], unbalancedness_handler: base_solver.UnbalancednessHandler, epsilon: float = 1e-2, cost_fn: Optional[Type[costs.CostFn]] = None, diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 9880f3f2b..00e0b2d46 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -41,8 +41,9 @@ def 
test_genot_linear_unconditional( scale_cost: Union[float, Literal["mean"]], k_samples_per_x: int, solver_latent_to_data: Optional[str] ): - matcher_latent_to_data = base_solver.OTMatcherLinear( - None if solver_latent_to_data is None else sinkhorn.Sinkhorn() + matcher_latent_to_data = ( + None if solver_latent_to_data is None else + base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) ) batch = next(genot_data_loader_linear) source_lin, target_lin, source_condition = batch["source_lin"], batch[ @@ -98,9 +99,11 @@ def test_genot_linear_conditional( self, genot_data_loader_linear_conditional: Iterator, k_samples_per_x: int, solver_latent_to_data: Optional[str] ): - matcher_latent_to_data = base_solver.OTMatcherLinear( - None if solver_latent_to_data is None else sinkhorn.Sinkhorn() + matcher_latent_to_data = ( + None if solver_latent_to_data is None else + base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) ) + batch = next(genot_data_loader_linear_conditional) source_lin, target_lin, source_condition = batch["source_lin"], batch[ "target_lin"], batch["source_conditions"] @@ -154,9 +157,11 @@ def test_genot_linear_learn_rescaling( solver_latent_to_data: Optional[str], genot_data_loader_linear_conditional: Iterator ): - matcher_latent_to_data = base_solver.OTMatcherLinear( - None if solver_latent_to_data is None else sinkhorn.Sinkhorn() + matcher_latent_to_data = ( + None if solver_latent_to_data is None else + base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) ) + data_loader = ( genot_data_loader_linear_conditional if conditional else genot_data_loader_linear @@ -236,9 +241,11 @@ def test_genot_quad_unconditional( self, genot_data_loader_quad: Iterator, k_samples_per_x: int, solver_latent_to_data: Optional[str] ): - matcher_latent_to_data = base_solver.OTMatcherLinear( - None if solver_latent_to_data is None else sinkhorn.Sinkhorn() + matcher_latent_to_data = ( + None if solver_latent_to_data is None else + base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) ) + batch = 
next(genot_data_loader_quad) source_quad, target_quad, source_condition = batch["source_quad"], batch[ "target_quad"], batch["source_conditions"] @@ -290,9 +297,11 @@ def test_genot_fused_unconditional( self, genot_data_loader_fused: Iterator, k_samples_per_x: int, solver_latent_to_data: Optional[str] ): - matcher_latent_to_data = base_solver.OTMatcherLinear( - None if solver_latent_to_data is None else sinkhorn.Sinkhorn() + matcher_latent_to_data = ( + None if solver_latent_to_data is None else + base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) ) + batch = next(genot_data_loader_fused) source_lin, source_quad, target_lin, target_quad, source_condition = batch[ "source_lin"], batch["source_quad"], batch["target_lin"], batch[ @@ -345,9 +354,11 @@ def test_genot_quad_conditional( self, genot_data_loader_quad_conditional: Iterator, k_samples_per_x: int, solver_latent_to_data: Optional[str] ): - matcher_latent_to_data = base_solver.OTMatcherLinear( - None if solver_latent_to_data is None else sinkhorn.Sinkhorn() + matcher_latent_to_data = ( + None if solver_latent_to_data is None else + base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) ) + batch = next(genot_data_loader_quad_conditional) source_quad, target_quad, source_condition = batch["source_quad"], batch[ "target_quad"], batch["source_conditions"] @@ -403,7 +414,10 @@ def test_genot_fused_conditional( solver_latent_to_data = ( None if solver_latent_to_data is None else sinkhorn.Sinkhorn() ) - matcher_latent_to_data = base_solver.OTMatcherLinear(solver_latent_to_data) + matcher_latent_to_data = ( + None if solver_latent_to_data is None else + base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) + ) batch = next(genot_data_loader_fused_conditional) source_lin, source_quad, target_lin, target_quad, source_condition = batch[ "source_lin"], batch["source_quad"], batch["target_lin"], batch[ From 1b30c115d27c51b7cf5ad66ab98e8dfee82892ca Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 11 Feb 2024 16:05:06 +0100 
Subject: [PATCH 088/186] change data loaders / data sets --- src/ott/neural/data/dataloaders.py | 72 ++++++-------- src/ott/neural/flows/genot.py | 148 ++++++++++++++++------------- src/ott/neural/flows/otfm.py | 81 ++++++++-------- tests/neural/conftest.py | 46 +++++---- tests/neural/genot_test.py | 95 ++++++++++-------- tests/neural/otfm_test.py | 57 +++++++---- 6 files changed, 271 insertions(+), 228 deletions(-) diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py index 68da7de6e..e063deefd 100644 --- a/src/ott/neural/data/dataloaders.py +++ b/src/ott/neural/data/dataloaders.py @@ -11,39 +11,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Iterator, Mapping, Optional +from typing import Any, List, Mapping, Optional import numpy as np -__all__ = ["OTDataLoader", "ConditionalDataLoader"] +__all__ = ["OTDataSet", "ConditionalOTDataLoader"] -class OTDataLoader: - """Data loader for OT problems. +class OTDataSet: + """Data set for OT problems. Args: - batch_size: Number of samples per batch. source_lin: Linear part of the source measure. source_quad: Quadratic part of the source measure. target_lin: Linear part of the target measure. target_quad: Quadratic part of the target measure. source_conditions: Conditions of the source measure. target_conditions: Conditions of the target measure. - seed: Random seed. 
""" def __init__( self, - batch_size: int = 64, source_lin: Optional[np.ndarray] = None, source_quad: Optional[np.ndarray] = None, target_lin: Optional[np.ndarray] = None, target_quad: Optional[np.ndarray] = None, source_conditions: Optional[np.ndarray] = None, target_conditions: Optional[np.ndarray] = None, - seed: int = 0, ): - super().__init__() if source_lin is not None: if source_quad is not None: assert len(source_lin) == len(source_quad) @@ -71,60 +66,53 @@ def __init__( self.target_quad = target_quad self.source_conditions = source_conditions self.target_conditions = target_conditions - self.batch_size = batch_size - self.rng = np.random.default_rng(seed=seed) - def __next__(self) -> Mapping[str, np.ndarray]: - inds_source = self.rng.choice(self.n_source, size=[self.batch_size]) - inds_target = self.rng.choice(self.n_target, size=[self.batch_size]) + def __getitem__(self, idx: np.ndarray) -> Mapping[str, np.ndarray]: return { "source_lin": - self.source_lin[inds_source, :] - if self.source_lin is not None else None, + self.source_lin[idx] if self.source_lin is not None else [], "source_quad": - self.source_quad[inds_source, :] - if self.source_quad is not None else None, + self.source_quad[idx] if self.source_quad is not None else [], "target_lin": - self.target_lin[inds_target, :] - if self.target_lin is not None else None, + self.target_lin[idx] if self.target_lin is not None else [], "target_quad": - self.target_quad[inds_target, :] - if self.target_quad is not None else None, + self.target_quad[idx] if self.target_quad is not None else [], "source_conditions": - self.source_conditions[inds_source, :] - if self.source_conditions is not None else None, + self.source_conditions[idx] + if self.source_conditions is not None else [], "target_conditions": - self.target_conditions[inds_target, :] - if self.target_conditions is not None else None, + self.target_conditions[idx] + if self.target_conditions is not None else [], } + def __len__(self): + return 
len(self.source_lin + ) if self.source_lin is not None else len(self.source_quad) -class ConditionalDataLoader: + +class ConditionalOTDataLoader: """Data loader for OT problems with conditions. - This data loader wraps several data loaders and samples from them according - to their conditions. + This data loader wraps several data loaders and samples from them. Args: - dataloaders: Dictionary of data loaders with keys corresponding to - conditions. - p: Probability of sampling from each data loader. + dataloaders: List of data loaders. seed: Random seed. """ def __init__( - self, dataloaders: Dict[str, Iterator], p: np.ndarray, seed: int = 0 + self, + dataloaders: List[Any], + seed: int = 0 # dataloader should subclass torch dataloader ): super().__init__() self.dataloaders = dataloaders - self.conditions = list(dataloaders.keys()) - self.p = p + self.conditions = list(dataloaders) self.rng = np.random.default_rng(seed=seed) - def __next__(self, cond: str = None) -> Mapping[str, np.ndarray]: - if cond is not None: - if cond not in self.conditions: - raise ValueError(f"Condition {cond} not in {self.conditions}") - return next(self.dataloaders[cond]) - idx = self.rng.choice(len(self.conditions), p=self.p) - return next(self.dataloaders[self.conditions[idx]]) + def __next__(self) -> Mapping[str, np.ndarray]: + idx = self.rng.choice(len(self.conditions)) + return next(iter(self.dataloaders[idx])) + + def __iter__(self) -> "ConditionalOTDataLoader": + return self diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flows/genot.py index 267818cc3..0ab1b4a87 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flows/genot.py @@ -251,8 +251,7 @@ def solve_ode(input: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: return jax.vmap(solve_ode)(latent_batch, cond_input) def _valid_step(self, valid_loader, iter): - """TODO.""" - next(valid_loader) + pass @property def learn_rescaling(self) -> bool: @@ -284,76 +283,85 @@ def __call__(self, train_loader, 
valid_loader): train_loader: Data loader for the training data. valid_loader: Data loader for the validation data. """ - batch: Dict[str, jnp.array] = {} - for iteration in range(self.iterations): - batch = next(train_loader) - - ( - self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, - rng_step_fn - ) = jax.random.split(self.rng, 6) - source, source_conditions, target = batch["source_lin"], batch[ - "source_conditions"], batch["target_lin"] - - batch_size = len(batch["source_lin"]) - n_samples = batch_size * self.k_samples_per_x - time = self.time_sampler(rng_time, n_samples) - latent = self.latent_noise_fn( - rng_noise, shape=(self.k_samples_per_x, batch_size) - ) - - tmat = self.ot_matcher.match_fn( - source, - target, - ) - - (source, source_conditions - ), (target,) = self.ot_matcher._sample_conditional_indices_from_tmap( - rng=rng_resample, - tmat=tmat, - k_samples_per_x=self.k_samples_per_x, - source_arrays=(source, source_conditions), - target_arrays=(target,), - source_is_balanced=(self.unbalancedness_handler.tau_a == 1.0) - ) - - if self.matcher_latent_to_data is not None: - tmats_latent_data = jnp.array( - jax.vmap(self.matcher_latent_to_data.match_fn, 0, - 0)(x=latent, y=target) + iter = -1 + while True: + for batch in train_loader: + iter += 1 + if iter >= self.iterations: + stop = True + break + ( + self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, + rng_step_fn + ) = jax.random.split(self.rng, 6) + source, source_conditions, target = jnp.array( + batch["source_lin"] + ), jnp.array(batch["source_conditions"] + ) if len(batch["source_conditions"]) else None, jnp.array( + batch["target_lin"] + ) + + batch_size = len(source) + n_samples = batch_size * self.k_samples_per_x + time = self.time_sampler(rng_time, n_samples) + latent = self.latent_noise_fn( + rng_noise, shape=(self.k_samples_per_x, batch_size) ) - rng_latent_data_match = jax.random.split( - rng_latent_data_match, self.k_samples_per_x + tmat = 
self.ot_matcher.match_fn( + source, + target, ) + (source, source_conditions - ), (target,) = jax.vmap(self.ot_matcher._resample_data, 0, 0)( - rng_latent_data_match, tmats_latent_data, - (source, source_conditions), (target,) + ), (target,) = self.ot_matcher._sample_conditional_indices_from_tmap( + rng=rng_resample, + tmat=tmat, + k_samples_per_x=self.k_samples_per_x, + source_arrays=(source, source_conditions), + target_arrays=(target,), + source_is_balanced=(self.unbalancedness_handler.tau_a == 1.0) ) - source, source_conditions, target, latent = self._reshape_samples( - (source, source_conditions, target, latent), batch_size - ) - self.state_velocity_field, loss = self.step_fn( - rng_step_fn, self.state_velocity_field, time, source, target, latent, - source_conditions - ) - if self.learn_rescaling: - ( - self.state_eta, self.state_xi, eta_predictions, xi_predictions, - loss_a, loss_b - ) = self.unbalancedness_handler.step_fn( - source=source, - target=target, - condition=source_conditions, - a=tmat.sum(axis=1), - b=tmat.sum(axis=0), - state_eta=self.unbalancedness_handler.state_eta, - state_xi=self.unbalancedness_handler.state_xi, + if self.matcher_latent_to_data is not None: + tmats_latent_data = jnp.array( + jax.vmap(self.matcher_latent_to_data.match_fn, 0, + 0)(x=latent, y=target) + ) + + rng_latent_data_match = jax.random.split( + rng_latent_data_match, self.k_samples_per_x + ) + (source, source_conditions + ), (target,) = jax.vmap(self.ot_matcher._resample_data, 0, 0)( + rng_latent_data_match, tmats_latent_data, + (source, source_conditions), (target,) + ) + + source, source_conditions, target, latent = self._reshape_samples( + (source, source_conditions, target, latent), batch_size ) - if iteration % self.valid_freq == 0: - self._valid_step(valid_loader, iteration) + self.state_velocity_field, loss = self.step_fn( + rng_step_fn, self.state_velocity_field, time, source, target, + latent, source_conditions + ) + if self.learn_rescaling: + ( + 
self.state_eta, self.state_xi, eta_predictions, xi_predictions, + loss_a, loss_b + ) = self.unbalancedness_handler.step_fn( + source=source, + target=target, + condition=source_conditions, + a=tmat.sum(axis=1), + b=tmat.sum(axis=0), + state_eta=self.unbalancedness_handler.state_eta, + state_xi=self.unbalancedness_handler.state_xi, + ) + if iter % self.valid_freq == 0: + self._valid_step(valid_loader, iter) + if stop: + break class GENOTQuad(GENOTBase): @@ -374,15 +382,19 @@ def __call__(self, train_loader, valid_loader): """ batch: Dict[str, jnp.array] = {} for iteration in range(self.iterations): - batch = next(train_loader) + batch = next(iter(train_loader)) ( self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, rng_step_fn ) = jax.random.split(self.rng, 6) (source_lin, source_quad, source_conditions, target_lin, target_quad) = ( - batch["source_lin"], batch["source_quad"], batch["source_conditions"], - batch["target_lin"], batch["target_quad"] + jnp.array(batch["source_lin"]) if len(batch["source_lin"]) else None, + jnp.array(batch["source_quad"]), + jnp.array(batch["source_conditions"]) + if len(batch["source_conditions"]) else None, + jnp.array(batch["target_lin"]) if len(batch["target_lin"]) else None, + jnp.array(batch["target_quad"]) ) batch_size = len(source_quad) n_samples = batch_size * self.k_samples_per_x diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index 0e8f616f9..69bad0e69 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -151,42 +151,50 @@ def __call__(self, train_loader, valid_loader): valid_loader: Dataloader for the validation data. 
""" batch: Mapping[str, jnp.ndarray] = {} - curr_loss = 0.0 - - for iter in range(self.iterations): - rng_resample, rng_step_fn, self.rng = jax.random.split(self.rng, 3) - batch = next(train_loader) - source, source_conditions, target = batch["source_lin"], batch[ - "source_conditions"], batch["target_lin"] - if self.ot_matcher is not None: - tmat = self.ot_matcher.match_fn(source, target) - (source, source_conditions), (target,) = self.ot_matcher._resample_data( - rng_resample, tmat, (source, source_conditions), (target,) - ) - self.state_velocity_field, loss = self.step_fn( - rng_step_fn, self.state_velocity_field, source, target, - source_conditions - ) - curr_loss += loss - if iter % self.logging_freq == 0: - self._training_logs["loss"].append(curr_loss / self.logging_freq) - curr_loss = 0.0 - if self.learn_rescaling: - ( - self.unbalancedness_handler.state_eta, - self.unbalancedness_handler.state_xi, eta_predictions, - xi_predictions, loss_a, loss_b - ) = self.unbalancedness_handler.step_fn( - source=source, - target=target, - condition=source_conditions, - a=tmat.sum(axis=1), - b=tmat.sum(axis=0), - state_eta=self.unbalancedness_handler.state_eta, - state_xi=self.unbalancedness_handler.state_xi, + + iter = -1 + while True: + for batch in train_loader: + iter += 1 + if iter >= self.iterations: + stop = True + break + rng_resample, rng_step_fn, self.rng = jax.random.split(self.rng, 3) + source, source_conditions, target = jnp.array( + batch["source_lin"] + ), jnp.array(batch["source_conditions"] + ) if batch["source_conditions"] else None, jnp.array( + batch["target_lin"] + ) + if self.ot_matcher is not None: + tmat = self.ot_matcher.match_fn(source, target) + (source, + source_conditions), (target,) = self.ot_matcher._resample_data( + rng_resample, tmat, (source, source_conditions), (target,) + ) + self.state_velocity_field, loss = self.step_fn( + rng_step_fn, self.state_velocity_field, source, target, + source_conditions ) - if iter % self.valid_freq == 0: - 
self._valid_step(valid_loader, iter) + self._training_logs["loss"].append(loss) + if self.learn_rescaling: + ( + self.unbalancedness_handler.state_eta, + self.unbalancedness_handler.state_xi, eta_predictions, + xi_predictions, loss_a, loss_b + ) = self.unbalancedness_handler.step_fn( + source=source, + target=target, + condition=source_conditions, + a=tmat.sum(axis=1), + b=tmat.sum(axis=0), + state_eta=self.unbalancedness_handler.state_eta, + state_xi=self.unbalancedness_handler.state_xi, + ) + if iter % self.valid_freq == 0: + self._valid_step(valid_loader, iter) + if stop: + break def transport( self, @@ -243,8 +251,7 @@ def solve_ode(input: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: return jax.vmap(solve_ode)(data, condition) def _valid_step(self, valid_loader, iter): - next(valid_loader) - # TODO: add callback and logging + pass @property def learn_rescaling(self) -> bool: diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 05cd38af1..504de7e52 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -14,6 +14,8 @@ import pytest import numpy as np +import torch +from torch.utils.data import DataLoader as Torch_loader from ott.neural.data import dataloaders @@ -24,7 +26,8 @@ def data_loader_gaussian(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 - return dataloaders.OTDataLoader(16, source_lin=source, target_lin=target) + dataset = dataloaders.OTDataSet(source_lin=source, target_lin=target) + return Torch_loader(dataset, batch_size=16, shuffle=True) @pytest.fixture(scope="module") @@ -36,23 +39,22 @@ def data_loader_gaussian_conditional(): source_1 = rng.normal(size=(100, 2)) target_1 = rng.normal(size=(100, 2)) - 2.0 - dl0 = dataloaders.OTDataLoader( - 16, + ds0 = dataloaders.OTDataSet( source_lin=source_0, target_lin=target_0, source_conditions=np.zeros_like(source_0) * 0.0 ) - dl1 = dataloaders.OTDataLoader( - 16, + ds1 = dataloaders.OTDataSet( 
source_lin=source_1, target_lin=target_1, source_conditions=np.ones_like(source_1) * 1.0 ) + sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) + sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) + dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0) + dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1) - return dataloaders.ConditionalDataLoader({ - "0": dl0, - "1": dl1 - }, np.array([0.5, 0.5])) + return dataloaders.ConditionalOTDataLoader((dl0, dl1)) @pytest.fixture(scope="module") @@ -63,13 +65,14 @@ def data_loader_gaussian_with_conditions(): target = rng.normal(size=(100, 2)) + 1.0 source_conditions = rng.normal(size=(100, 1)) target_conditions = rng.normal(size=(100, 1)) - 1.0 - return dataloaders.OTDataLoader( - 16, + + dataset = dataloaders.OTDataSet( source_lin=source, target_lin=target, source_conditions=source_conditions, target_conditions=target_conditions ) + return Torch_loader(dataset, batch_size=16, shuffle=True) @pytest.fixture(scope="module") @@ -78,7 +81,8 @@ def genot_data_loader_linear(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 - return dataloaders.OTDataLoader(16, source_lin=source, target_lin=target) + dataset = dataloaders.OTDataSet(source_lin=source, target_lin=target) + return Torch_loader(dataset, batch_size=16, shuffle=True) @pytest.fixture(scope="module") @@ -88,8 +92,7 @@ def genot_data_loader_linear_conditional(): source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 source_conditions = rng.normal(size=(100, 4)) - return dataloaders.OTDataLoader( - 16, + return dataloaders.OTDataSet( source_lin=source, target_lin=target, source_conditions=source_conditions, @@ -102,7 +105,8 @@ def genot_data_loader_quad(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 1)) + 1.0 - return dataloaders.OTDataLoader(16, source_quad=source, target_quad=target) + dataset = 
dataloaders.OTDataSet(source_quad=source, target_quad=target) + return Torch_loader(dataset, batch_size=16, shuffle=True) @pytest.fixture(scope="module") @@ -112,12 +116,12 @@ def genot_data_loader_quad_conditional(): source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 1)) + 1.0 source_conditions = rng.normal(size=(100, 7)) - return dataloaders.OTDataLoader( - 16, + dataset = dataloaders.OTDataSet( source_quad=source, target_quad=target, source_conditions=source_conditions, ) + return Torch_loader(dataset, batch_size=16, shuffle=True) @pytest.fixture(scope="module") @@ -128,13 +132,13 @@ def genot_data_loader_fused(): target_q = rng.normal(size=(100, 1)) + 1.0 source_lin = rng.normal(size=(100, 2)) target_lin = rng.normal(size=(100, 2)) + 1.0 - return dataloaders.OTDataLoader( - 16, + dataset = dataloaders.OTDataSet( source_lin=source_lin, source_quad=source_q, target_lin=target_lin, target_quad=target_q ) + return Torch_loader(dataset, batch_size=16, shuffle=True) @pytest.fixture(scope="module") @@ -146,11 +150,11 @@ def genot_data_loader_fused_conditional(): source_lin = rng.normal(size=(100, 2)) target_lin = rng.normal(size=(100, 2)) + 1.0 source_conditions = rng.normal(size=(100, 7)) - return dataloaders.OTDataLoader( - 16, + dataset = dataloaders.OTDataSet( source_lin=source_lin, source_quad=source_q, target_lin=target_lin, target_quad=target_q, source_conditions=source_conditions, ) + return Torch_loader(dataset, batch_size=16, shuffle=True) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 00e0b2d46..0a030521c 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -45,9 +45,13 @@ def test_genot_linear_unconditional( None if solver_latent_to_data is None else base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) ) - batch = next(genot_data_loader_linear) - source_lin, target_lin, source_condition = batch["source_lin"], batch[ - "target_lin"], batch["source_conditions"] + batch = 
next(iter(genot_data_loader_linear)) + source_lin, source_conditions, target_lin = jnp.array( + batch["source_lin"] + ), jnp.array(batch["source_conditions"]) if len(batch["source_conditions"] + ) else None, jnp.array( + batch["target_lin"] + ) source_dim = source_lin.shape[1] target_dim = target_lin.shape[1] @@ -83,12 +87,9 @@ def test_genot_linear_unconditional( ) genot(genot_data_loader_linear, genot_data_loader_linear) - batch = next(genot_data_loader_linear) - source_lin, target_lin, source_condition = batch["source_lin"], batch[ - "target_lin"], batch["source_conditions"] - + batch = next(iter(genot_data_loader_linear)) result_forward = genot.transport( - source_lin, condition=source_condition, forward=True + source_lin, condition=source_conditions, forward=True ) assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @@ -104,12 +105,16 @@ def test_genot_linear_conditional( base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) ) - batch = next(genot_data_loader_linear_conditional) - source_lin, target_lin, source_condition = batch["source_lin"], batch[ - "target_lin"], batch["source_conditions"] + batch = next(iter(genot_data_loader_linear_conditional)) + source_lin, source_conditions, target_lin = jnp.array( + batch["source_lin"] + ), jnp.array(batch["source_conditions"]) if len(batch["source_conditions"] + ) else None, jnp.array( + batch["target_lin"] + ) source_dim = source_lin.shape[1] target_dim = target_lin.shape[1] - condition_dim = source_condition.shape[1] + condition_dim = source_conditions.shape[1] neural_vf = VelocityField( output_dim=target_dim, @@ -145,7 +150,7 @@ def test_genot_linear_conditional( genot_data_loader_linear_conditional ) result_forward = genot.transport( - source_lin, condition=source_condition, forward=True + source_lin, condition=source_conditions, forward=True ) assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @@ -167,9 +172,10 @@ def 
test_genot_linear_learn_rescaling( if conditional else genot_data_loader_linear ) - batch = next(data_loader) - source_lin, target_lin, source_condition = batch["source_lin"], batch[ - "target_lin"], batch["source_conditions"] + batch = next(iter(data_loader)) + source_lin, target_lin, source_condition = jnp.array( + batch["source_lin"] + ), jnp.array(batch["target_lin"]), jnp.array(batch["source_conditions"]) source_dim = source_lin.shape[1] target_dim = target_lin.shape[1] @@ -246,10 +252,12 @@ def test_genot_quad_unconditional( base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) ) - batch = next(genot_data_loader_quad) - source_quad, target_quad, source_condition = batch["source_quad"], batch[ - "target_quad"], batch["source_conditions"] - + batch = next(iter(genot_data_loader_quad)) + (source_quad, source_conditions, target_quad) = ( + jnp.array(batch["source_quad"]), jnp.array(batch["source_conditions"]) + if len(batch["source_conditions"]) else None, + jnp.array(batch["target_quad"]) + ) source_dim = source_quad.shape[1] target_dim = target_quad.shape[1] condition_dim = 0 @@ -286,7 +294,7 @@ def test_genot_quad_unconditional( genot(genot_data_loader_quad, genot_data_loader_quad) result_forward = genot.transport( - source_quad, condition=source_condition, forward=True + source_quad, condition=source_conditions, forward=True ) assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @@ -302,11 +310,14 @@ def test_genot_fused_unconditional( base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) ) - batch = next(genot_data_loader_fused) - source_lin, source_quad, target_lin, target_quad, source_condition = batch[ - "source_lin"], batch["source_quad"], batch["target_lin"], batch[ - "target_quad"], batch["source_conditions"] - + batch = next(iter(genot_data_loader_fused)) + (source_lin, source_quad, source_conditions, target_lin, target_quad) = ( + jnp.array(batch["source_lin"]) if len(batch["source_lin"]) else None, + 
jnp.array(batch["source_quad"]), jnp.array(batch["source_conditions"]) + if len(batch["source_conditions"]) else None, + jnp.array(batch["target_lin"]) if len(batch["target_lin"]) else None, + jnp.array(batch["target_quad"]) + ) source_dim = source_lin.shape[1] + source_quad.shape[1] target_dim = target_lin.shape[1] + target_quad.shape[1] condition_dim = 0 @@ -342,7 +353,7 @@ def test_genot_fused_unconditional( result_forward = genot.transport( jnp.concatenate((source_lin, source_quad), axis=1), - condition=source_condition, + condition=source_conditions, forward=True ) assert isinstance(result_forward, jnp.ndarray) @@ -359,13 +370,15 @@ def test_genot_quad_conditional( base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) ) - batch = next(genot_data_loader_quad_conditional) - source_quad, target_quad, source_condition = batch["source_quad"], batch[ - "target_quad"], batch["source_conditions"] - + batch = next(iter(genot_data_loader_quad_conditional)) + (source_quad, source_conditions, target_quad) = ( + jnp.array(batch["source_quad"]), jnp.array(batch["source_conditions"]) + if len(batch["source_conditions"]) else None, + jnp.array(batch["target_quad"]) + ) source_dim = source_quad.shape[1] target_dim = target_quad.shape[1] - condition_dim = source_condition.shape[1] + condition_dim = source_conditions.shape[1] neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, @@ -400,7 +413,7 @@ def test_genot_quad_conditional( ) result_forward = genot.transport( - source_quad, condition=source_condition, forward=True + source_quad, condition=source_conditions, forward=True ) assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 @@ -418,13 +431,17 @@ def test_genot_fused_conditional( None if solver_latent_to_data is None else base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) ) - batch = next(genot_data_loader_fused_conditional) - source_lin, source_quad, target_lin, target_quad, source_condition = batch[ 
- "source_lin"], batch["source_quad"], batch["target_lin"], batch[ - "target_quad"], batch["source_conditions"] + batch = next(iter(genot_data_loader_fused_conditional)) + (source_lin, source_quad, source_conditions, target_lin, target_quad) = ( + jnp.array(batch["source_lin"]) if len(batch["source_lin"]) else None, + jnp.array(batch["source_quad"]), jnp.array(batch["source_conditions"]) + if len(batch["source_conditions"]) else None, + jnp.array(batch["target_lin"]) if len(batch["target_lin"]) else None, + jnp.array(batch["target_quad"]) + ) source_dim = source_lin.shape[1] + source_quad.shape[1] target_dim = target_lin.shape[1] + target_quad.shape[1] - condition_dim = source_condition.shape[1] + condition_dim = source_conditions.shape[1] neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, @@ -460,7 +477,7 @@ def test_genot_fused_conditional( result_forward = genot.transport( jnp.concatenate((source_lin, source_quad), axis=1), - condition=source_condition, + condition=source_conditions, forward=True ) assert isinstance(result_forward, jnp.ndarray) diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 9452c2faa..d660ec33f 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -35,7 +35,7 @@ class TestOTFlowMatching: flows.BrownianNoiseFlow(0.2) ] ) - def test_flow_matching( + def test_flow_matching_unconditional( self, data_loader_gaussian, flow: Type[flows.BaseFlow] ): input_dim = 2 @@ -66,17 +66,20 @@ def test_flow_matching( ) fm(data_loader_gaussian, data_loader_gaussian) - batch = next(data_loader_gaussian) + batch = next(iter(data_loader_gaussian)) + source = jnp.asarray(batch["source_lin"]) + target = jnp.asarray(batch["target_lin"]) + source_conditions = jnp.asarray(batch["source_conditions"]) if len( + batch["source_conditions"] + ) > 0 else None result_forward = fm.transport( - batch["source_lin"], condition=batch["source_conditions"], forward=True + source, 
condition=source_conditions, forward=True ) assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 result_backward = fm.transport( - batch["target_lin"], - condition=batch["source_conditions"], - forward=False + target, condition=source_conditions, forward=False ) assert isinstance(result_backward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_backward)) == 0 @@ -123,17 +126,20 @@ def test_flow_matching_with_conditions( data_loader_gaussian_with_conditions ) - batch = next(data_loader_gaussian_with_conditions) + batch = next(iter(data_loader_gaussian_with_conditions)) + source = jnp.asarray(batch["source_lin"]) + target = jnp.asarray(batch["target_lin"]) + source_conditions = jnp.asarray(batch["source_conditions"]) if len( + batch["source_conditions"] + ) > 0 else None result_forward = fm.transport( - batch["source_lin"], condition=batch["source_conditions"], forward=True + source, condition=source_conditions, forward=True ) assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 result_backward = fm.transport( - batch["target_lin"], - condition=batch["source_conditions"], - forward=False + target, condition=source_conditions, forward=False ) assert isinstance(result_backward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_backward)) == 0 @@ -177,17 +183,20 @@ def test_flow_matching_conditional( ) fm(data_loader_gaussian_conditional, data_loader_gaussian_conditional) - batch = next(data_loader_gaussian_conditional) + batch = next(iter(data_loader_gaussian_conditional)) + source = jnp.asarray(batch["source_lin"]) + target = jnp.asarray(batch["target_lin"]) + source_conditions = jnp.asarray(batch["source_conditions"]) if len( + batch["source_conditions"] + ) > 0 else None result_forward = fm.transport( - batch["source_lin"], condition=batch["source_conditions"], forward=True + source, condition=source_conditions, forward=True ) assert isinstance(result_forward, jnp.ndarray) assert 
jnp.sum(jnp.isnan(result_forward)) == 0 result_backward = fm.transport( - batch["target_lin"], - condition=batch["source_conditions"], - forward=False + target, condition=source_conditions, forward=False ) assert isinstance(result_backward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_backward)) == 0 @@ -201,9 +210,15 @@ def test_flow_matching_learn_rescaling( data_loader_gaussian_conditional if conditional else data_loader_gaussian ) - batch = next(data_loader) - source_dim = batch["source_lin"].shape[1] - condition_dim = batch["source_conditions"].shape[1] if conditional else 0 + batch = next(iter(data_loader)) + source = jnp.asarray(batch["source_lin"]) + target = jnp.asarray(batch["target_lin"]) + source_conditions = jnp.asarray(batch["source_conditions"]) if len( + batch["source_conditions"] + ) > 0 else None + + source_dim = source.shape[1] + condition_dim = source_conditions.shape[1] if conditional else 0 neural_vf = models.VelocityField( output_dim=2, condition_dim=0, @@ -249,13 +264,13 @@ def test_flow_matching_learn_rescaling( fm(data_loader, data_loader) result_eta = fm.unbalancedness_handler.evaluate_eta( - batch["source_lin"], condition=batch["source_conditions"] + source, condition=source_conditions ) assert isinstance(result_eta, jnp.ndarray) assert jnp.sum(jnp.isnan(result_eta)) == 0 result_xi = fm.unbalancedness_handler.evaluate_xi( - batch["target_lin"], condition=batch["source_conditions"] + target, condition=source_conditions ) assert isinstance(result_xi, jnp.ndarray) assert jnp.sum(jnp.isnan(result_xi)) == 0 From e2ebb19ef79514bafad063f931f7f0d47a946d91 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 11 Feb 2024 17:54:30 +0100 Subject: [PATCH 089/186] finish data loader refactoring --- tests/neural/conftest.py | 93 +++++++++++++++++++++++++++----------- tests/neural/genot_test.py | 1 + 2 files changed, 67 insertions(+), 27 deletions(-) diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 504de7e52..f33252f07 100644 
--- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -89,14 +89,26 @@ def genot_data_loader_linear(): def genot_data_loader_linear_conditional(): """Returns a data loader for a simple Gaussian mixture.""" rng = np.random.default_rng(seed=0) - source = rng.normal(size=(100, 2)) - target = rng.normal(size=(100, 2)) + 1.0 - source_conditions = rng.normal(size=(100, 4)) - return dataloaders.OTDataSet( - source_lin=source, - target_lin=target, - source_conditions=source_conditions, + source_0 = rng.normal(size=(100, 2)) + target_0 = rng.normal(size=(100, 2)) + 1.0 + source_1 = rng.normal(size=(100, 2)) + target_1 = rng.normal(size=(100, 2)) + 1.0 + ds0 = dataloaders.OTDataSet( + source_lin=source_0, + target_lin=target_0, + source_conditions=np.zeros_like(source_0) * 0.0 + ) + ds1 = dataloaders.OTDataSet( + source_lin=source_1, + target_lin=target_1, + source_conditions=np.ones_like(source_1) * 1.0 ) + sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) + sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) + dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0) + dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1) + + return dataloaders.ConditionalOTDataLoader((dl0, dl1)) @pytest.fixture(scope="module") @@ -113,15 +125,26 @@ def genot_data_loader_quad(): def genot_data_loader_quad_conditional(): """Returns a data loader for a simple Gaussian mixture.""" rng = np.random.default_rng(seed=0) - source = rng.normal(size=(100, 2)) - target = rng.normal(size=(100, 1)) + 1.0 - source_conditions = rng.normal(size=(100, 7)) - dataset = dataloaders.OTDataSet( - source_quad=source, - target_quad=target, - source_conditions=source_conditions, + source_0 = rng.normal(size=(100, 2)) + target_0 = rng.normal(size=(100, 1)) + 1.0 + source_1 = rng.normal(size=(100, 2)) + target_1 = rng.normal(size=(100, 1)) + 1.0 + ds0 = dataloaders.OTDataSet( + source_quad=source_0, + target_quad=target_0, + source_conditions=np.zeros_like(source_0) * 0.0 ) - 
return Torch_loader(dataset, batch_size=16, shuffle=True) + ds1 = dataloaders.OTDataSet( + source_quad=source_1, + target_quad=target_1, + source_conditions=np.ones_like(source_1) * 1.0 + ) + sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) + sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) + dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0) + dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1) + + return dataloaders.ConditionalOTDataLoader((dl0, dl1)) @pytest.fixture(scope="module") @@ -145,16 +168,32 @@ def genot_data_loader_fused(): def genot_data_loader_fused_conditional(): """Returns a data loader for a simple Gaussian mixture.""" rng = np.random.default_rng(seed=0) - source_q = rng.normal(size=(100, 2)) - target_q = rng.normal(size=(100, 1)) + 1.0 - source_lin = rng.normal(size=(100, 2)) - target_lin = rng.normal(size=(100, 2)) + 1.0 - source_conditions = rng.normal(size=(100, 7)) - dataset = dataloaders.OTDataSet( - source_lin=source_lin, - source_quad=source_q, - target_lin=target_lin, - target_quad=target_q, - source_conditions=source_conditions, + source_q_0 = rng.normal(size=(100, 2)) + target_q_0 = rng.normal(size=(100, 1)) + 1.0 + source_lin_0 = rng.normal(size=(100, 2)) + target_lin_0 = rng.normal(size=(100, 2)) + 1.0 + + source_q_1 = 2 * rng.normal(size=(100, 2)) + target_q_1 = 2 * rng.normal(size=(100, 1)) + 1.0 + source_lin_1 = 2 * rng.normal(size=(100, 2)) + target_lin_1 = 2 * rng.normal(size=(100, 2)) + 1.0 + + ds0 = dataloaders.OTDataSet( + source_lin=source_lin_0, + target_lin=target_lin_0, + source_quad=source_q_0, + target_quad=target_q_0, + source_conditions=np.zeros_like(source_lin_0) * 0.0 ) - return Torch_loader(dataset, batch_size=16, shuffle=True) + ds1 = dataloaders.OTDataSet( + source_lin=source_lin_1, + target_lin=target_lin_1, + source_quad=source_q_1, + target_quad=target_q_1, + source_conditions=np.ones_like(source_lin_1) * 1.0 + ) + sampler0 = torch.utils.data.RandomSampler(ds0, 
replacement=True) + sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) + dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0) + dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1) + return dataloaders.ConditionalOTDataLoader((dl0, dl1)) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 0a030521c..d44db4476 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -376,6 +376,7 @@ def test_genot_quad_conditional( if len(batch["source_conditions"]) else None, jnp.array(batch["target_quad"]) ) + source_dim = source_quad.shape[1] target_dim = target_quad.shape[1] condition_dim = source_conditions.shape[1] From 8644fd93afbaa1281cb8e623ec0b84a466c259d1 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 12 Feb 2024 17:53:44 +0100 Subject: [PATCH 090/186] Update linter --- .pre-commit-config.yaml | 6 +++--- docs/tutorials/MetaOT.ipynb | 1 + pyproject.toml | 21 +++++++++++-------- src/ott/__init__.py | 11 +++++++++- src/ott/neural/duality/models.py | 1 - src/ott/neural/duality/neuraldual.py | 12 ++++++++++- src/ott/neural/flows/otfm.py | 12 ++++++++++- src/ott/neural/gaps/map_estimator.py | 11 +++++++++- src/ott/neural/models/base_solver.py | 9 ++++---- src/ott/solvers/linear/sinkhorn_lr.py | 11 +++++++++- src/ott/solvers/quadratic/__init__.py | 7 ++++++- .../quadratic/gromov_wasserstein_lr.py | 11 +++++++++- .../tools/gaussian_mixture/fit_gmm_pair.py | 6 +++++- .../gaussian_mixture/gaussian_mixture.py | 7 ++++++- tests/conftest.py | 3 +-- tests/geometry/lr_kernel_test.py | 4 +++- 16 files changed, 103 insertions(+), 30 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1f84672bb..ec54873a3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,12 +20,12 @@ repos: - id: trailing-whitespace - id: check-case-conflict - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.1.6 + rev: v0.2.1 hooks: - id: 
ruff args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort name: isort @@ -42,7 +42,7 @@ repos: - id: nbqa-black - id: nbqa-isort - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.11.0 + rev: v2.12.0 hooks: - id: pretty-format-yaml args: [--autofix, --indent, '2'] diff --git a/docs/tutorials/MetaOT.ipynb b/docs/tutorials/MetaOT.ipynb index 1ef687b28..172024733 100644 --- a/docs/tutorials/MetaOT.ipynb +++ b/docs/tutorials/MetaOT.ipynb @@ -63,6 +63,7 @@ "import jax.numpy as jnp\n", "import numpy as np\n", "import torchvision\n", + "\n", "from flax import linen as nn\n", "\n", "import matplotlib.pyplot as plt\n", diff --git a/pyproject.toml b/pyproject.toml index 6128854ab..7306525b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,13 +103,14 @@ include = '\.ipynb$' [tool.isort] profile = "black" +line_length = 80 include_trailing_comma = true multi_line_output = 3 sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TEST", "NUMERIC", "NEURAL", "PLOTTING", "FIRSTPARTY", "LOCALFOLDER"] # also contains what we import in notebooks/tests known_neural = ["flax", "optax", "diffrax", "orbax"] known_numeric = ["numpy", "scipy", "jax", "flax", "optax", "jaxopt", "torch", "ot", "torchvision", "pandas", "sklearn", "tslearn"] -known_test = ["pytest"] +known_test = ["_pytest", "pytest"] known_plotting = ["IPython", "matplotlib", "mpl_toolkits", "seaborn"] [tool.pytest.ini_options] @@ -274,6 +275,10 @@ exclude = [ "docs/_build", "dist" ] +line-length = 80 +target-version = "py38" + +[tool.ruff.lint] ignore = [ # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient "E731", @@ -288,10 +293,8 @@ ignore = [ # Missing docstring in magic method "D105", ] -line-length = 80 select = [ "D", # flake8-docstrings - "I", # isort "E", # pycodestyle "F", # pyflakes "W", # pycodestyle @@ -308,20 +311,20 @@ select = [ "RET", # flake8-raise ] unfixable = 
["B", "UP", "C4", "BLE", "T20", "RET"] -target-version = "py38" -[tool.ruff.per-file-ignores] + +[tool.ruff.lint.per-file-ignores] # TODO(michalk8): PO004 - remove `self.initialize` "tests/*" = ["D", "PT004", "E402"] "*/__init__.py" = ["F401"] "docs/*" = ["D"] "src/ott/types.py" = ["D102"] -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "google" -[tool.ruff.pyupgrade] +[tool.ruff.lint.pyupgrade] # Preserve types, even if a file imports `from __future__ import annotations`. keep-runtime-typing = true -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] # Disallow all relative imports. ban-relative-imports = "parents" -[tool.ruff.flake8-quotes] +[tool.ruff.lint.flake8-quotes] inline-quotes = "double" diff --git a/src/ott/__init__.py b/src/ott/__init__.py index 8d2f007c5..dac0eb854 100644 --- a/src/ott/__init__.py +++ b/src/ott/__init__.py @@ -13,7 +13,16 @@ # limitations under the License. import contextlib -from . import datasets, geometry, initializers, math, problems, solvers, tools, utils +from . import ( + datasets, + geometry, + initializers, + math, + problems, + solvers, + tools, + utils, +) with contextlib.suppress(ImportError): # TODO(michalk8): add warning that neural module is not imported diff --git a/src/ott/neural/duality/models.py b/src/ott/neural/duality/models.py index d498a8d4e..b3ce94c35 100644 --- a/src/ott/neural/duality/models.py +++ b/src/ott/neural/duality/models.py @@ -16,7 +16,6 @@ import jax import jax.numpy as jnp -from jax.nn import initializers import flax.linen as nn import optax diff --git a/src/ott/neural/duality/neuraldual.py b/src/ott/neural/duality/neuraldual.py index 09f51dc80..3ea88f74a 100644 --- a/src/ott/neural/duality/neuraldual.py +++ b/src/ott/neural/duality/neuraldual.py @@ -13,7 +13,17 @@ # limitations under the License. 
import abc import warnings -from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Literal, + Optional, + Tuple, + Union, +) import jax import jax.numpy as jnp diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flows/otfm.py index 69bad0e69..42ffa422e 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flows/otfm.py @@ -13,7 +13,17 @@ # limitations under the License. import collections import functools -from typing import Any, Callable, Dict, Literal, Mapping, Optional, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + Literal, + Mapping, + Optional, + Tuple, + Type, + Union, +) import jax import jax.numpy as jnp diff --git a/src/ott/neural/gaps/map_estimator.py b/src/ott/neural/gaps/map_estimator.py index be8834458..61c24f0c3 100644 --- a/src/ott/neural/gaps/map_estimator.py +++ b/src/ott/neural/gaps/map_estimator.py @@ -13,7 +13,16 @@ # limitations under the License. import collections import functools -from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterator, + Optional, + Sequence, + Tuple, + Union, +) import jax import jax.numpy as jnp diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index a55370403..b0587d3f0 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -326,12 +326,11 @@ class UnbalancednessHandler: resample_epsilon: Epsilon for resampling. scale_cost: Scaling of the cost matrix for estimating the rescaling factors. ot_solver: Solver to compute unbalanced marginals. If `ot_solver` is `None`, - the method - :meth:`ott.neural.models.base_solver.UnbalancednessHandler.compute_unbalanced_marginals` - is not available, and hence the unbalanced marginals must be computed by the neural solver. 
+ the method :meth:`ott.neural.models.base_solver.UnbalancednessHandler.compute_unbalanced_marginals` + is not available, and hence the unbalanced marginals must be computed + by the neural solver. kwargs: Additional keyword arguments. - - """ + """ # noqa: E501 # TODO(MUCDK): fix me def __init__( self, diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index f6e6216fe..da949da0d 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -11,7 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Literal, Mapping, NamedTuple, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Literal, + Mapping, + NamedTuple, + Optional, + Tuple, + Union, +) import jax import jax.experimental diff --git a/src/ott/solvers/quadratic/__init__.py b/src/ott/solvers/quadratic/__init__.py index 507812971..560ac3ddd 100644 --- a/src/ott/solvers/quadratic/__init__.py +++ b/src/ott/solvers/quadratic/__init__.py @@ -11,5 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import gromov_wasserstein, gromov_wasserstein_lr, gw_barycenter, lower_bound +from . import ( + gromov_wasserstein, + gromov_wasserstein_lr, + gw_barycenter, + lower_bound, +) from ._solve import solve diff --git a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py index df2237477..cb12911bf 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py @@ -12,7 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""A Jax implementation of the unbalanced low-rank GW algorithm.""" -from typing import Any, Callable, Literal, Mapping, NamedTuple, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Literal, + Mapping, + NamedTuple, + Optional, + Tuple, + Union, +) import jax import jax.experimental diff --git a/src/ott/tools/gaussian_mixture/fit_gmm_pair.py b/src/ott/tools/gaussian_mixture/fit_gmm_pair.py index 0c3c78ba3..7ecde263c 100644 --- a/src/ott/tools/gaussian_mixture/fit_gmm_pair.py +++ b/src/ott/tools/gaussian_mixture/fit_gmm_pair.py @@ -84,7 +84,11 @@ import jax import jax.numpy as jnp -from ott.tools.gaussian_mixture import fit_gmm, gaussian_mixture, gaussian_mixture_pair +from ott.tools.gaussian_mixture import ( + fit_gmm, + gaussian_mixture, + gaussian_mixture_pair, +) __all__ = ["get_fit_model_em_fn"] diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture.py b/src/ott/tools/gaussian_mixture/gaussian_mixture.py index c3f04c8fa..27a568989 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture.py @@ -16,7 +16,12 @@ import jax import jax.numpy as jnp -from ott.tools.gaussian_mixture import gaussian, linalg, probabilities, scale_tril +from ott.tools.gaussian_mixture import ( + gaussian, + linalg, + probabilities, + scale_tril, +) __all__ = ["GaussianMixture"] diff --git a/tests/conftest.py b/tests/conftest.py index da7e6a3dc..8fe7166aa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,9 +15,8 @@ import itertools from typing import Any, Mapping, Optional, Sequence -from _pytest.python import Metafunc - import pytest +from _pytest.python import Metafunc import jax import jax.experimental diff --git a/tests/geometry/lr_kernel_test.py b/tests/geometry/lr_kernel_test.py index 1f0a42e7d..6db247179 100644 --- a/tests/geometry/lr_kernel_test.py +++ b/tests/geometry/lr_kernel_test.py @@ -1,9 +1,11 @@ from typing import Literal, Optional +import pytest + import jax import jax.numpy as 
jnp import numpy as np -import pytest + from ott.geometry import costs, low_rank, pointcloud from ott.solvers import linear From 460bf901f29676e13568a63f833475a77c760f52 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 14 Feb 2024 09:05:26 +0100 Subject: [PATCH 091/186] fix bug in _resample_data` --- src/ott/neural/models/base_solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index b0587d3f0..733fb6477 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -146,7 +146,7 @@ def _resample_data( ) -> Tuple[jnp.ndarray, ...]: """Resample a batch according to coupling `tmat`.""" tmat_flattened = tmat.flatten() - indices = jax.random.choice(rng, len(tmat_flattened), shape=[tmat.shape[0]]) + indices = jax.random.choice(rng, len(tmat_flattened), p=tmat_flattened, shape=[tmat.shape[0]]) indices_source = indices // tmat.shape[1] indices_target = indices % tmat.shape[1] return tuple( From ce42c1a1f80edb39174f3490e19e8de10304e236 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 16 Feb 2024 01:10:10 +0100 Subject: [PATCH 092/186] incorporate more changes --- src/ott/neural/__init__.py | 2 +- .../neural/{flows => flow_models}/__init__.py | 0 .../neural/{flows => flow_models}/flows.py | 32 +-- .../neural/{flows => flow_models}/genot.py | 182 +++++++++--------- .../neural/{flows => flow_models}/layers.py | 0 .../neural/{flows => flow_models}/models.py | 2 +- src/ott/neural/{flows => flow_models}/otfm.py | 16 +- .../neural/{flows => flow_models}/samplers.py | 0 src/ott/neural/models/base_solver.py | 142 ++++++++------ tests/neural/genot_test.py | 6 +- tests/neural/otfm_test.py | 2 +- 11 files changed, 208 insertions(+), 176 deletions(-) rename src/ott/neural/{flows => flow_models}/__init__.py (100%) rename src/ott/neural/{flows => flow_models}/flows.py (79%) rename src/ott/neural/{flows => flow_models}/genot.py (76%) rename 
src/ott/neural/{flows => flow_models}/layers.py (100%) rename src/ott/neural/{flows => flow_models}/models.py (99%) rename src/ott/neural/{flows => flow_models}/otfm.py (95%) rename src/ott/neural/{flows => flow_models}/samplers.py (100%) diff --git a/src/ott/neural/__init__.py b/src/ott/neural/__init__.py index 2a61ca021..678919a8c 100644 --- a/src/ott/neural/__init__.py +++ b/src/ott/neural/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import data, duality, flows, gaps, models +from . import data, duality, flow_models, gaps, models diff --git a/src/ott/neural/flows/__init__.py b/src/ott/neural/flow_models/__init__.py similarity index 100% rename from src/ott/neural/flows/__init__.py rename to src/ott/neural/flow_models/__init__.py diff --git a/src/ott/neural/flows/flows.py b/src/ott/neural/flow_models/flows.py similarity index 79% rename from src/ott/neural/flows/flows.py rename to src/ott/neural/flow_models/flows.py index 65f697d89..fd1009cef 100644 --- a/src/ott/neural/flows/flows.py +++ b/src/ott/neural/flow_models/flows.py @@ -44,9 +44,9 @@ def compute_mu_t( at time :math:`t`. Args: - t: Time :math:`t`. - src: Sample from the source distribution. - tgt: Sample from the target distribution. + t: Time :math:`t` of shape `(batch_size, 1)`. + src: Sample from the source distribution of shape `(batch_size, ...)`. + tgt: Sample from the target distribution of shape `(batch_size, ...)`. """ @abc.abstractmethod @@ -54,7 +54,7 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: """Compute the standard deviation of the probablity path at time :math:`t`. Args: - t: Time :math:`t`. + t: Time :math:`t` of shape `(batch_size, 1)`. """ @abc.abstractmethod @@ -67,9 +67,9 @@ def compute_ut( :math:`x_1` at time :math:`t`. Args: - t: Time :math:`t`. - src: Sample from the source distribution. 
- tgt: Sample from the target distribution. + t: Time :math:`t` of shape `(batch_size, 1)`. + src: Sample from the source distribution of shape `(batch_size, ...)`. + tgt: Sample from the target distribution of shape `(batch_size, ...)`. Returns: Conditional vector field evaluated at time :math:`t`. @@ -85,9 +85,9 @@ def compute_xt( Args: rng: Random number generator. - t: Time :math:`t`. - src: Sample from the source distribution. - tgt: Sample from the target distribution. + t: Time :math:`t` of shape `(batch_size, 1)`. + src: Sample from the source distribution of shape `(batch_size, ...)`. + tgt: Sample from the target distribution of shape `(batch_size, ...)`. Returns: Samples from the probability path between :math:`x_0` and :math:`x_1` @@ -116,9 +116,9 @@ def compute_ut( :math:`x_1` at time :math:`t`. Args: - t: Time :math:`t`. - src: Sample from the source distribution. - tgt: Sample from the target distribution. + t: Time :math:`t` of shape `(batch_size, 1)`. + src: Sample from the source distribution of shape `(batch_size, ...)`. + tgt: Sample from the target distribution of shape `(batch_size, ...)`. Returns: Conditional vector field evaluated at time :math:`t`. @@ -134,7 +134,7 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: r"""Compute noise of the flow at time :math:`t`. Args: - t: Time :math:`t`. + t: Time :math:`t` of shape `(batch_size, 1)`. Returns: Constant, time-independent standard deviation :math:`\sigma`. @@ -147,7 +147,7 @@ class BrownianNoiseFlow(StraightFlow): Sampler for sampling noise implicitly defined by a Schroedinger Bridge problem with parameter :math:`\sigma` such that - :math:`\sigma_t = \sigma * \sqrt(t * (1-t))`. + :math:`\sigma_t = \sigma * \sqrt(t * (1-t))` (:cite:`tong:23`). Returns: Samples from the probability path between :math:`x_0` and :math:`x_1` @@ -158,7 +158,7 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: """Compute the standard deviation of the probablity path at time :math:`t`.
Args: - t: Time :math:`t`. + t: Time :math:`t` of shape `(batch_size, 1)`.. Returns: Standard deviation of the probablity path at time :math:`t`. diff --git a/src/ott/neural/flows/genot.py b/src/ott/neural/flow_models/genot.py similarity index 76% rename from src/ott/neural/flows/genot.py rename to src/ott/neural/flow_models/genot.py index 0ab1b4a87..c900d1268 100644 --- a/src/ott/neural/flows/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -23,7 +23,7 @@ from flax.training import train_state from ott import utils -from ott.neural.flows import flows, samplers +from ott.neural.flow_models import flows, samplers from ott.neural.models import base_solver __all__ = ["GENOTBase", "GENOTLin", "GENOTQuad"] @@ -164,7 +164,7 @@ def step_fn( def loss_fn( params: jnp.ndarray, time: jnp.ndarray, source: jnp.ndarray, target: jnp.ndarray, latent: jnp.ndarray, - source_conditions: Optional[jnp.ndarray], rng: jax.random.PRNGKeyArray + source_conditions: Optional[jnp.ndarray], rng: jax.Array ): x_t = self.flow.compute_xt(rng, time, latent, target) apply_fn = functools.partial( @@ -263,17 +263,17 @@ def learn_rescaling(self) -> bool: def _reshape_samples(self, arrays: Tuple[jnp.ndarray, ...], batch_size: int) -> Tuple[jnp.ndarray, ...]: - return tuple( - jnp.reshape(arr, (batch_size * self.k_samples_per_x, - -1)) if arr is not None else None for arr in arrays + return jax.tree_util.tree_map( + lambda x: jnp.reshape(x, (batch_size * self.k_samples_per_x, -1)) + if x is not None else None, arrays ) class GENOTLin(GENOTBase): """Implementation of GENOT-L (:cite:`klein:23`). - GENOT-L (Generative Entropic Neural Optimal Transport, linear) solves the - entropic (linear) OT problem. + GENOT-L (Generative Entropic Neural Optimal Transport, linear) is a + neural solver for entropic (linear) OT problems. 
""" def __call__(self, train_loader, valid_loader): @@ -314,9 +314,9 @@ def __call__(self, train_loader, valid_loader): ) (source, source_conditions - ), (target,) = self.ot_matcher._sample_conditional_indices_from_tmap( + ), (target,) = self.ot_matcher.sample_conditional_indices_from_tmap( rng=rng_resample, - tmat=tmat, + conditional_distributions=tmat, k_samples_per_x=self.k_samples_per_x, source_arrays=(source, source_conditions), target_arrays=(target,), @@ -333,7 +333,7 @@ def __call__(self, train_loader, valid_loader): rng_latent_data_match, self.k_samples_per_x ) (source, source_conditions - ), (target,) = jax.vmap(self.ot_matcher._resample_data, 0, 0)( + ), (target,) = jax.vmap(self.ot_matcher.sample_joint, 0, 0)( rng_latent_data_match, tmats_latent_data, (source, source_conditions), (target,) ) @@ -368,9 +368,9 @@ class GENOTQuad(GENOTBase): """Implementation of GENOT-Q and GENOT-F (:cite:`klein:23`). GENOT-Q (Generative Entropic Neural Optimal Transport, quadratic) and - GENOT-F (Generative Entropic Neural Optimal Transport, fused) solve the - entropic Gromov-Wasserstein and the entropic Fused Gromov-Wasserstein problem, - respectively. + GENOT-F (Generative Entropic Neural Optimal Transport, fused) are neural + solver for entropic Gromov-Wasserstein and entropic Fused Gromov-Wasserstein + problems, respectively. """ def __call__(self, train_loader, valid_loader): @@ -381,86 +381,94 @@ def __call__(self, train_loader, valid_loader): valid_loader: Data loader for the validation data. 
""" batch: Dict[str, jnp.array] = {} - for iteration in range(self.iterations): - batch = next(iter(train_loader)) - - ( - self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, - rng_step_fn - ) = jax.random.split(self.rng, 6) - (source_lin, source_quad, source_conditions, target_lin, target_quad) = ( - jnp.array(batch["source_lin"]) if len(batch["source_lin"]) else None, - jnp.array(batch["source_quad"]), - jnp.array(batch["source_conditions"]) - if len(batch["source_conditions"]) else None, - jnp.array(batch["target_lin"]) if len(batch["target_lin"]) else None, - jnp.array(batch["target_quad"]) - ) - batch_size = len(source_quad) - n_samples = batch_size * self.k_samples_per_x - time = self.time_sampler(rng_time, n_samples) - latent = self.latent_noise_fn( - rng_noise, shape=(self.k_samples_per_x, batch_size) - ) - - tmat = self.ot_matcher.match_fn( - source_lin, source_quad, target_lin, target_quad - ) - - if self.ot_matcher.fused_penalty > 0.0: - source = jnp.concatenate((source_lin, source_quad), axis=1) - target = jnp.concatenate((target_lin, target_quad), axis=1) - else: - source = source_quad - target = target_quad - - (source, source_conditions), (target,) = ( - self.ot_matcher._sample_conditional_indices_from_tmap( - rng=rng_resample, - tmat=tmat, - k_samples_per_x=self.k_samples_per_x, - source_arrays=(source, source_conditions), - target_arrays=(target,), - source_is_balanced=(self.unbalancedness_handler.tau_a == 1.0) - ) - ) + iter = -1 + while True: + for batch in train_loader: + iter += 1 + if iter >= self.iterations: + stop = True + break - if self.matcher_latent_to_data is not None: - tmats_latent_data = jnp.array( - jax.vmap(self.matcher_latent_to_data.match_fn, 0, - 0)(x=latent, y=target) + ( + self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, + rng_step_fn + ) = jax.random.split(self.rng, 6) + (source_lin, source_quad, source_conditions, target_lin, + target_quad) = ( + jnp.array(batch["source_lin"]) if 
len(batch["source_lin"]) else + None, jnp.array(batch["source_quad"]), + jnp.array(batch["source_conditions"]) + if len(batch["source_conditions"]) else None, + jnp.array(batch["target_lin"]) if len(batch["target_lin"]) else + None, jnp.array(batch["target_quad"]) + ) + batch_size = len(source_quad) + n_samples = batch_size * self.k_samples_per_x + time = self.time_sampler(rng_time, n_samples) + latent = self.latent_noise_fn( + rng_noise, shape=(self.k_samples_per_x, batch_size) ) - rng_latent_data_match = jax.random.split( - rng_latent_data_match, self.k_samples_per_x + tmat = self.ot_matcher.match_fn( + source_quad, target_quad, source_lin, target_lin ) - (source, source_conditions - ), (target,) = jax.vmap(self.ot_matcher._resample_data, 0, 0)( - rng_latent_data_match, tmats_latent_data, - (source, source_conditions), (target,) + if self.ot_matcher.fused_penalty > 0.0: + source = jnp.concatenate((source_lin, source_quad), axis=1) + target = jnp.concatenate((target_lin, target_quad), axis=1) + else: + source = source_quad + target = target_quad + + (source, source_conditions), (target,) = ( + self.ot_matcher.sample_conditional_indices_from_tmap( + rng=rng_resample, + conditional_distributions=tmat, + k_samples_per_x=self.k_samples_per_x, + source_arrays=(source, source_conditions), + target_arrays=(target,), + source_is_balanced=(self.unbalancedness_handler.tau_a == 1.0) + ) ) - source, source_conditions, target, latent = self._reshape_samples( - (source, source_conditions, target, latent), batch_size - ) + if self.matcher_latent_to_data is not None: + tmats_latent_data = jnp.array( + jax.vmap(self.matcher_latent_to_data.match_fn, 0, + 0)(x=latent, y=target) + ) - self.state_velocity_field, loss = self.step_fn( - rng_step_fn, self.state_velocity_field, time, source, target, latent, - source_conditions - ) - if self.learn_rescaling: - ( - self.state_eta, self.state_xi, eta_predictions, xi_predictions, - loss_a, loss_b - ) = self.unbalancedness_handler.step_fn( - 
source=source, - target=target, - condition=source_conditions, - a=tmat.sum(axis=1), - b=tmat.sum(axis=0), - state_eta=self.unbalancedness_handler.state_eta, - state_xi=self.unbalancedness_handler.state_xi, + rng_latent_data_match = jax.random.split( + rng_latent_data_match, self.k_samples_per_x + ) + + (source, source_conditions + ), (target,) = jax.vmap(self.ot_matcher.sample_joint, 0, 0)( + rng_latent_data_match, tmats_latent_data, + (source, source_conditions), (target,) + ) + + source, source_conditions, target, latent = self._reshape_samples( + (source, source_conditions, target, latent), batch_size + ) + + self.state_velocity_field, loss = self.step_fn( + rng_step_fn, self.state_velocity_field, time, source, target, + latent, source_conditions ) - if iteration % self.valid_freq == 0: - self._valid_step(valid_loader, iteration) + if self.learn_rescaling: + ( + self.state_eta, self.state_xi, eta_predictions, xi_predictions, + loss_a, loss_b + ) = self.unbalancedness_handler.step_fn( + source=source, + target=target, + condition=source_conditions, + a=tmat.sum(axis=1), + b=tmat.sum(axis=0), + state_eta=self.unbalancedness_handler.state_eta, + state_xi=self.unbalancedness_handler.state_xi, + ) + if iter % self.valid_freq == 0: + self._valid_step(valid_loader, iter) + if stop: + break diff --git a/src/ott/neural/flows/layers.py b/src/ott/neural/flow_models/layers.py similarity index 100% rename from src/ott/neural/flows/layers.py rename to src/ott/neural/flow_models/layers.py diff --git a/src/ott/neural/flows/models.py b/src/ott/neural/flow_models/models.py similarity index 99% rename from src/ott/neural/flows/models.py rename to src/ott/neural/flow_models/models.py index bf365e772..ebb29aa99 100644 --- a/src/ott/neural/flows/models.py +++ b/src/ott/neural/flow_models/models.py @@ -20,7 +20,7 @@ import optax from flax.training import train_state -import ott.neural.flows.layers as flow_layers +import ott.neural.flow_models.layers as flow_layers from 
ott.neural.models import layers __all__ = ["VelocityField"] diff --git a/src/ott/neural/flows/otfm.py b/src/ott/neural/flow_models/otfm.py similarity index 95% rename from src/ott/neural/flows/otfm.py rename to src/ott/neural/flow_models/otfm.py index 42ffa422e..dca7bea60 100644 --- a/src/ott/neural/flows/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -34,7 +34,7 @@ from ott import utils from ott.geometry import costs -from ott.neural.flows import flows +from ott.neural.flow_models import flows from ott.neural.models import base_solver __all__ = ["OTFlowMatching"] @@ -172,16 +172,14 @@ def __call__(self, train_loader, valid_loader): rng_resample, rng_step_fn, self.rng = jax.random.split(self.rng, 3) source, source_conditions, target = jnp.array( batch["source_lin"] - ), jnp.array(batch["source_conditions"] - ) if batch["source_conditions"] else None, jnp.array( - batch["target_lin"] - ) + ), jnp.array(batch["source_conditions"]) if len( + batch["source_conditions"] + ) > 0 else None, jnp.array(batch["target_lin"]) if self.ot_matcher is not None: tmat = self.ot_matcher.match_fn(source, target) - (source, - source_conditions), (target,) = self.ot_matcher._resample_data( - rng_resample, tmat, (source, source_conditions), (target,) - ) + (source, source_conditions), (target,) = self.ot_matcher.sample_joint( + rng_resample, tmat, (source, source_conditions), (target,) + ) self.state_velocity_field, loss = self.step_fn( rng_step_fn, self.state_velocity_field, source, target, source_conditions diff --git a/src/ott/neural/flows/samplers.py b/src/ott/neural/flow_models/samplers.py similarity index 100% rename from src/ott/neural/flows/samplers.py rename to src/ott/neural/flow_models/samplers.py diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index 733fb6477..c0da4db6a 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -15,6 +15,7 @@ import jax import jax.numpy as jnp +from jax import 
tree_util import optax from flax.training import train_state @@ -39,8 +40,6 @@ def _get_sinkhorn_match_fn( "max_cost", "median"]] = "mean", tau_a: float = 1.0, tau_b: float = 1.0, - *, - filter_input: bool = False, ) -> Callable: @jax.jit @@ -52,19 +51,7 @@ def match_pairs(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) ) - @jax.jit - def match_pairs_filtered( - x_lin: jnp.ndarray, x_quad: jnp.ndarray, y_lin: jnp.ndarray, - y_quad: jnp.ndarray - ) -> jnp.ndarray: - geom = pointcloud.PointCloud( - x_lin, y_lin, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn - ) - return ot_solver( - linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) - ) - - return match_pairs_filtered if filter_input else match_pairs + return match_pairs def _get_gromov_match_fn( @@ -104,10 +91,10 @@ def _get_gromov_match_fn( @jax.jit def match_pairs( - x_lin: Optional[jnp.ndarray], x_quad: Tuple[jnp.ndarray, jnp.ndarray], - y_lin: Optional[jnp.ndarray], y_quad: Tuple[jnp.ndarray, jnp.ndarray], + x_lin: Optional[jnp.ndarray], + y_lin: Optional[jnp.ndarray], ) -> jnp.ndarray: geom_xx = pointcloud.PointCloud( x=x_quad, y=x_quad, cost_fn=cost_fn_xx, scale_cost=scale_cost_xx @@ -137,36 +124,67 @@ def match_pairs( class BaseOTMatcher: """Base class for mini-batch neural OT matching classes.""" - def _resample_data( + def sample_joint( self, rng: jax.Array, - tmat: jnp.ndarray, - source_arrays: Tuple[jnp.ndarray, ...], - target_arrays: Tuple[jnp.ndarray, ...], + joint_dist: jnp.ndarray, + source_arrays: Tuple[Optional[jnp.ndarray], ...], + target_arrays: Tuple[Optional[jnp.ndarray], ...], ) -> Tuple[jnp.ndarray, ...]: - """Resample a batch according to coupling `tmat`.""" - tmat_flattened = tmat.flatten() - indices = jax.random.choice(rng, len(tmat_flattened), p=tmat_flattened, shape=[tmat.shape[0]]) - indices_source = indices // tmat.shape[1] - indices_target = indices % tmat.shape[1] - return tuple( - b[indices_source] if b 
is not None else None for b in source_arrays - ), tuple( - b[indices_target] if b is not None else None for b in target_arrays + """Resample from arrays according to discrete joint distribution. + + Args: + rng: Random number generator. + joint_dist: Joint distribution between source and target to sample from. + source_arrays: Arrays corresponding to source distribution to sample + from. + target_arrays: Arrays corresponding to target distribution to sample from. + + Returns: + Resampled source and target arrays. + """ + _, n_tgt = joint_dist.shape + tmat_flattened = joint_dist.flatten() + indices = jax.random.choice( + rng, len(tmat_flattened), p=tmat_flattened, shape=[joint_dist.shape[0]] + ) + indices_source = indices // n_tgt + indices_target = indices % n_tgt + return tree_util.tree_map( + lambda b: b[indices_source] if b is not None else b, source_arrays + ), tree_util.tree_map( + lambda b: b[indices_target] if b is not None else b, target_arrays ) - def _sample_conditional_indices_from_tmap( + def sample_conditional_indices_from_tmap( self, rng: jax.Array, - tmat: jnp.ndarray, + conditional_distributions: jnp.ndarray, *, k_samples_per_x: Union[int, jnp.ndarray], - source_arrays: Tuple[jnp.ndarray, ...], - target_arrays: Tuple[jnp.ndarray, ...], + source_arrays: Tuple[Optional[jnp.ndarray], ...], + target_arrays: Tuple[Optional[jnp.ndarray], ...], source_is_balanced: bool, - ) -> Tuple[jnp.array, jnp.array]: - batch_size = tmat.shape[0] - left_marginals = tmat.sum(axis=1) + ) -> Tuple[jnp.ndarray, ...]: + """Sample from arrays according to discrete conditional distributions. + + Args: + rng: Random number generator. + conditional_distributions: Conditional distributions to sample from. + k_samples_per_x: Expectation of number of samples to draw from each + conditional distribution. + source_arrays: Arrays corresponding to source distribution to sample + from. + target_arrays: Arrays corresponding to target distribution to sample from.
+ source_is_balanced: Whether the source distribution is balanced. + If :obj:`False`, the number of samples drawn from each conditional + distribution `k_samples_per_x` is proportional to the left marginals. + + Returns: + Resampled source and target arrays. + """ + n_src, n_tgt = conditional_distributions.shape + left_marginals = conditional_distributions.sum(axis=1) if not source_is_balanced: rng, rng_2 = jax.random.split(rng, 2) indices = jax.random.choice( @@ -176,12 +194,11 @@ def _sample_conditional_indices_from_tmap( shape=(len(left_marginals),) ) else: - indices = jnp.arange(batch_size) - tmat_adapted = tmat[indices] + indices = jnp.arange(n_src) + tmat_adapted = conditional_distributions[indices] indices_per_row = jax.vmap( - lambda row: jax.random.choice( - key=rng, a=jnp.arange(batch_size), p=row, shape=(k_samples_per_x,) - ), + lambda row: jax.random. + choice(key=rng, a=jnp.arange(n_tgt), p=row, shape=(k_samples_per_x,)), in_axes=0, out_axes=0, )( @@ -190,16 +207,16 @@ def _sample_conditional_indices_from_tmap( indices_source = jnp.repeat(indices, k_samples_per_x) indices_target = jnp.reshape( - indices_per_row % tmat.shape[1], (batch_size * k_samples_per_x,) + indices_per_row % n_tgt, (n_src * k_samples_per_x,) ) - return tuple( - jnp.reshape(b[indices_source], (k_samples_per_x, batch_size, - -1)) if b is not None else None - for b in source_arrays - ), tuple( - jnp.reshape(b[indices_target], (k_samples_per_x, batch_size, - -1)) if b is not None else None - for b in target_arrays + return tree_util.tree_map( + lambda b: jnp. + reshape(b[indices_source], (k_samples_per_x, n_src, *b.shape[1:])) + if b is not None else None, source_arrays + ), tree_util.tree_map( + lambda b: jnp. 
+ reshape(b[indices_target], (k_samples_per_x, n_src, *b.shape[1:])) + if b is not None else b, target_arrays ) @@ -409,17 +426,26 @@ def compute_unbalanced_marginals_quad(*args, **kwargs): return compute_unbalanced_marginals_quad @jax.jit - def _resample_unbalanced( + def resample_unbalanced( self, rng: jax.Array, - batch: Tuple[jnp.ndarray, ...], - marginals: jnp.ndarray, + arrays: Tuple[jnp.ndarray, ...], + p: jnp.ndarray, ) -> Tuple[jnp.ndarray, ...]: - """Resample a batch based on marginals.""" - indices = jax.random.choice( - rng, a=len(marginals), p=jnp.squeeze(marginals), shape=[len(marginals)] + """Resample a batch based on marginals. + + Args: + rng: Random number generator. + arrays: Arrays to resample from. + p: Probabilities according to which `arrays` are resampled. + + Returns: + Resampled arrays. + """ + indices = jax.random.choice(rng, a=len(p), p=jnp.squeeze(p), shape=[len(p)]) + return tree_util.tree_map( + lambda b: b[indices] if b is not None else b, arrays ) - return tuple(b[indices] if b is not None else None for b in batch) def setup(self, source_dim: int, target_dim: int, cond_dim: int): """Setup the model. 
diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index d44db4476..9480eb3cd 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -22,9 +22,9 @@ import optax from ott.geometry import costs -from ott.neural.flows.genot import GENOTLin, GENOTQuad -from ott.neural.flows.models import VelocityField -from ott.neural.flows.samplers import uniform_sampler +from ott.neural.flow_models.genot import GENOTLin, GENOTQuad +from ott.neural.flow_models.models import VelocityField +from ott.neural.flow_models.samplers import uniform_sampler from ott.neural.models import base_solver from ott.neural.models.nets import RescalingMLP from ott.solvers.linear import sinkhorn diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index d660ec33f..d66ea1611 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -21,7 +21,7 @@ import optax -from ott.neural.flows import flows, models, otfm, samplers +from ott.neural.flow_models import flows, models, otfm, samplers from ott.neural.models import base_solver, nets from ott.solvers.linear import sinkhorn From 1e21afb26016a1f250c0b4a8edf445e8c4694404 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 16 Feb 2024 01:34:50 +0100 Subject: [PATCH 093/186] add docs --- docs/neural/data.rst | 21 ++++++++++++++++ docs/neural/duality.rst | 37 ++++++++++++++++++++++++++++ docs/neural/flow_models.rst | 48 +++++++++++++++++++++++++++++++++++++ docs/neural/gap_models.rst | 26 ++++++++++++++++++++ docs/neural/index.rst | 31 ++++-------------------- docs/neural/models.rst | 0 6 files changed, 137 insertions(+), 26 deletions(-) create mode 100644 docs/neural/data.rst create mode 100644 docs/neural/duality.rst create mode 100644 docs/neural/flow_models.rst create mode 100644 docs/neural/gap_models.rst create mode 100644 docs/neural/models.rst diff --git a/docs/neural/data.rst b/docs/neural/data.rst new file mode 100644 index 000000000..8e13e7631 --- /dev/null +++ b/docs/neural/data.rst 
@@ -0,0 +1,21 @@ +ott.neural.data +=============== +.. module:: ott.neural.data +.. currentmodule:: ott.neural.data + +The :mod:`ott.neural.data` contains data sets and data loaders needed +for solving (conditional) neural optimal transport problems. + +Datasets +-------- +.. autosummary:: + :toctree: _autosummary + + dataloaders.OTDataset + +Dataloaders +----------- +.. autosummary:: + :toctree: _autosummary + + dataloaders.ConditionalOTDataLoader diff --git a/docs/neural/duality.rst b/docs/neural/duality.rst new file mode 100644 index 000000000..ea3f67bdf --- /dev/null +++ b/docs/neural/duality.rst @@ -0,0 +1,37 @@ +ott.neural.duality +================== +.. module:: ott.neural.duality +.. currentmodule:: ott.neural.duality + +This module implements various solvers to estimate optimal transport between +two probability measures, through samples, parameterized as neural networks. +These solvers build upon the dual formulation of the optimal transport problem. + +Solvers +------- +.. autosummary:: + :toctree: _autosummary + + neuraldual.W2NeuralDual + neuraldual.BaseW2NeuralDual + +Conjugate Solvers +----------------- +.. autosummary:: + :toctree: _autosummary + + conjugate.FenchelConjugateLBFGS + conjugate.FenchelConjugateSolver + conjugate.ConjugateResults + +Models +------ +.. autosummary:: + :toctree: _autosummary + + neuraldual.W2NeuralTrainState + neuraldual.BaseW2NeuralDual + neuraldual.W2NeuralDual + models.ICNN + models.PotentialMLP + models.MetaInitializer diff --git a/docs/neural/flow_models.rst b/docs/neural/flow_models.rst new file mode 100644 index 000000000..5d9d1f594 --- /dev/null +++ b/docs/neural/flow_models.rst @@ -0,0 +1,48 @@ +ott.neural.flow_models +====================== +.. module:: ott.neural.flow_models +.. currentmodule:: ott.neural.flow_models + +This module implements various solvers building upon flow matching +:cite:`lipman:22` to match distributions. + +Flows +----- +..
autosummary:: + :toctree: _autosummary + + flows.BaseFlow + flows.StraightFlow + flows.ConstantNoiseFlow + flows.BrownianNoiseFlow + +Optimal Transport Flow Matching +------------------------------- +.. autosummary:: + :toctree: _autosummary + + otfm.OTFlowMatching + +GENOT +----- +.. autosummary:: + :toctree: _autosummary + + genot.GENOTBase + genot.GENOTLin + genot.GENOTQuad + +Models +------ +.. autosummary:: + :toctree: _autosummary + + models.VelocityField + +Utils +----- +.. autosummary:: + :toctree: _autosummary + + layers.CyclicalTimeEncoder + samplers.uniform_sampler diff --git a/docs/neural/gap_models.rst b/docs/neural/gap_models.rst new file mode 100644 index 000000000..bacc93c71 --- /dev/null +++ b/docs/neural/gap_models.rst @@ -0,0 +1,26 @@ +ott.neural.models +================= +.. module:: ott.neural.models +.. currentmodule:: ott.neural.models + +This module implements models, network architectures and helper +functions which apply to various neural optimal transport solvers. + +Utils +----- +.. autosummary:: + :toctree: _autosummary + + base_solver.BaseOTMatcher + base_solver.OTMatcherLinear + base_solver.OTMatcherQuad + base_solver.UnbalancednessHandler + + +Neural networks +--------------- +.. autosummary:: + :toctree: _autosummary + + layers.MLPBlock + nets.RescalingMLP diff --git a/docs/neural/index.rst b/docs/neural/index.rst index d0315edae..06d9fd97b 100644 --- a/docs/neural/index.rst +++ b/docs/neural/index.rst @@ -13,29 +13,8 @@ and solvers to estimate such neural networks. .. toctree:: :maxdepth: 2 - solvers - -Models ------- -.. autosummary:: - :toctree: _autosummary - - models.ICNN - models.MLP - models.MetaInitializer - -Losses ------- -.. autosummary:: - :toctree: _autosummary - - losses.monge_gap - losses.monge_gap_from_samples - -Layers ------- -.. 
autosummary:: - :toctree: _autosummary - - layers.PositiveDense - layers.PosDefPotentials + data + duality + flow_models + gaps + models diff --git a/docs/neural/models.rst b/docs/neural/models.rst new file mode 100644 index 000000000..e69de29bb From 1afb922ca24a4bb91e1f95bde65c11d01fe90b30 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 16 Feb 2024 12:43:52 +0100 Subject: [PATCH 094/186] incorporate more changes --- docs/neural/data.rst | 2 +- docs/neural/gap_models.rst | 26 ------- docs/neural/gaps.rst | 15 ++++ docs/neural/models.rst | 26 +++++++ docs/neural/solvers.rst | 28 ------- pyproject.toml | 1 + src/ott/neural/duality/neuraldual.py | 8 +- src/ott/neural/flow_models/genot.py | 56 +++++++------- src/ott/neural/models/base_solver.py | 105 ++++++++++++--------------- tests/neural/genot_test.py | 64 ++++++++++++---- tests/neural/otfm_test.py | 22 +++--- 11 files changed, 185 insertions(+), 168 deletions(-) delete mode 100644 docs/neural/gap_models.rst create mode 100644 docs/neural/gaps.rst delete mode 100644 docs/neural/solvers.rst diff --git a/docs/neural/data.rst b/docs/neural/data.rst index 8e13e7631..970499ff5 100644 --- a/docs/neural/data.rst +++ b/docs/neural/data.rst @@ -11,7 +11,7 @@ Datasets .. autosummary:: :toctree: _autosummary - dataloaders.OTDataset + dataloaders.OTDataSet Dataloaders ----------- diff --git a/docs/neural/gap_models.rst b/docs/neural/gap_models.rst deleted file mode 100644 index bacc93c71..000000000 --- a/docs/neural/gap_models.rst +++ /dev/null @@ -1,26 +0,0 @@ -ott.neural.models -================= -.. module:: ott.neural.models -.. currentmodule:: ott.neural.models - -This module implements models, network architectures and helper -functions which apply to various neural optimal transport solvers. - -Utils ------ -.. autosummary:: - :toctree: _autosummary - - base_solver.BaseOTMatcher - base_solver.OTMatcherLinear - base_solver.OTMatcherQuad - base_solver.UnbalancednessHandler - - -Neural networks ---------------- -.. 
autosummary:: - :toctree: _autosummary - - layers.MLPBlock - nets.RescalingMLP diff --git a/docs/neural/gaps.rst b/docs/neural/gaps.rst new file mode 100644 index 000000000..abf621e24 --- /dev/null +++ b/docs/neural/gaps.rst @@ -0,0 +1,15 @@ +ott.neural.gaps +=============== +.. module:: ott.neural.gaps +.. currentmodule:: ott.neural.gaps + +This module implements gap models. + +Monge gap +--------- +.. autosummary:: + :toctree: _autosummary + + map_estimator.MapEstimator + monge_gap.monge_gap + monge_gap.monge_gap_from_samples diff --git a/docs/neural/models.rst b/docs/neural/models.rst index e69de29bb..bacc93c71 100644 --- a/docs/neural/models.rst +++ b/docs/neural/models.rst @@ -0,0 +1,26 @@ +ott.neural.models +================= +.. module:: ott.neural.models +.. currentmodule:: ott.neural.models + +This module implements models, network architectures and helper +functions which apply to various neural optimal transport solvers. + +Utils +----- +.. autosummary:: + :toctree: _autosummary + + base_solver.BaseOTMatcher + base_solver.OTMatcherLinear + base_solver.OTMatcherQuad + base_solver.UnbalancednessHandler + + +Neural networks +--------------- +.. autosummary:: + :toctree: _autosummary + + layers.MLPBlock + nets.RescalingMLP diff --git a/docs/neural/solvers.rst b/docs/neural/solvers.rst deleted file mode 100644 index c405d89ba..000000000 --- a/docs/neural/solvers.rst +++ /dev/null @@ -1,28 +0,0 @@ -ott.neural.solvers -================== -.. module:: ott.neural.solvers -.. currentmodule:: ott.neural.solvers - -This module implements various solvers to estimate optimal transport between -two probability measures, through samples, parameterized as neural networks. -These neural networks are described in :mod:`ott.neural.models`, borrowing -lower-level components from :mod:`ott.neural.layers` using -`flax `__. - -Solvers -------- -.. 
autosummary:: - :toctree: _autosummary - - map_estimator.MapEstimator - neuraldual.W2NeuralDual - neuraldual.BaseW2NeuralDual - -Conjugate Solvers ------------------ -.. autosummary:: - :toctree: _autosummary - - conjugate.FenchelConjugateLBFGS - conjugate.FenchelConjugateSolver - conjugate.ConjugateResults diff --git a/pyproject.toml b/pyproject.toml index 7306525b1..1c71241f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,7 @@ test = [ "tslearn>=0.5; python_version < '3.12'", "lineax; python_version >= '3.9'", "matplotlib", + "torch" ] docs = [ "sphinx>=4.0", diff --git a/src/ott/neural/duality/neuraldual.py b/src/ott/neural/duality/neuraldual.py index 3ea88f74a..c00acb76c 100644 --- a/src/ott/neural/duality/neuraldual.py +++ b/src/ott/neural/duality/neuraldual.py @@ -53,7 +53,7 @@ class W2NeuralTrainState(train_state.TrainState): This extends :class:`~flax.training.train_state.TrainState` to include the potential methods from the - :class:`~ott.neural.solvers.neuraldual.BaseW2NeuralDual` used during training. + :class:`~ott.neural.duality.neuraldual.BaseW2NeuralDual` used during training. Args: potential_value_fn: the potential's value function @@ -186,10 +186,10 @@ class W2NeuralDual: transport map from :math:`\beta` to :math:`\alpha`. This solver estimates the conjugate :math:`f^\star` with a neural approximation :math:`g` that is fine-tuned - with :class:`~ott.neural.solvers.conjugate.FenchelConjugateSolver`, + with :class:`~ott.neural.duality.conjugate.FenchelConjugateSolver`, which is a combination further described in :cite:`amos:23`. - The :class:`~ott.neural.solvers.neuraldual.BaseW2NeuralDual` potentials for + The :class:`~ott.neural.duality.neuraldual.BaseW2NeuralDual` potentials for ``neural_f`` and ``neural_g`` can 1. both provide the values of the potentials :math:`f` and :math:`g`, or @@ -198,7 +198,7 @@ class W2NeuralDual: via the Fenchel conjugate as discussed in :cite:`amos:23`. 
The potential's value or gradient mapping is specified via - :attr:`~ott.neural.solvers.neuraldual.BaseW2NeuralDual.is_potential`. + :attr:`~ott.neural.duality.neuraldual.BaseW2NeuralDual.is_potential`. Args: dim_data: input dimensionality of data required for network init diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index c900d1268..ff9b12c82 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -84,6 +84,7 @@ def __init__( velocity_field: Callable[[ jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], Optional[jnp.ndarray] ], jnp.ndarray], + *, input_dim: int, output_dim: int, cond_dim: int, @@ -194,6 +195,8 @@ def transport( condition: Optional[jnp.ndarray] = None, rng: Optional[jax.Array] = None, forward: bool = True, + t_0: float = 0.0, + t_1: float = 1.0, **kwargs: Any, ) -> Union[jnp.array, diffrax.Solution, Optional[jnp.ndarray]]: """Transport data with the learnt plan. @@ -207,6 +210,8 @@ def transport( condition: Condition of the input data. rng: random seed for sampling from the latent distribution. forward: If `True` integrates forward, otherwise backwards. + t_0: Starting time of integration of neural ODE. + t_1: End time of integration of neural ODE. kwargs: Keyword arguments for the ODE solver. 
Returns: @@ -217,36 +222,37 @@ def transport( rng = utils.default_prng_key(rng) if not forward: raise NotImplementedError - assert len(source) == len(condition) if condition is not None else True - + if condition is not None: + assert len(source) == len(condition), (len(source), len(condition)) latent_batch = self.latent_noise_fn(rng, shape=(len(source),)) - cond_input = source if condition is None else jnp.concatenate([ - source, condition - ], - axis=-1) - t0, t1 = (0.0, 1.0) + cond_input = source if condition is None else ( + jnp.concatenate([source, condition], axis=-1) + ) @jax.jit def solve_ode(input: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: - return diffrax.diffeqsolve( - diffrax.ODETerm( - lambda t, x, args: self.state_velocity_field. - apply_fn({"params": self.state_velocity_field.params}, - t=t, - x=x, - condition=cond) - ), - kwargs.pop("solver", diffrax.Tsit5()), - t0=t0, - t1=t1, + ode_term = diffrax.ODETerm( + lambda t, x, args: self.state_velocity_field. + apply_fn({"params": self.state_velocity_field.params}, + t=t, + x=x, + condition=cond) + ), + solver = kwargs.pop("solver", diffrax.Tsit5()) + stepsize_controller = kwargs.pop( + "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5) + ) + sol = diffrax.diffeqsolve( + ode_term, + solver, + t0=t_0, + t1=t_1, dt0=kwargs.pop("dt0", None), y0=input, - stepsize_controller=kwargs.pop( - "stepsize_controller", - diffrax.PIDController(rtol=1e-5, atol=1e-5) - ), + stepsize_controller=stepsize_controller, **kwargs, - ).ys[0] + ) + return sol.ys[0] return jax.vmap(solve_ode)(latent_batch, cond_input) @@ -264,8 +270,8 @@ def learn_rescaling(self) -> bool: def _reshape_samples(self, arrays: Tuple[jnp.ndarray, ...], batch_size: int) -> Tuple[jnp.ndarray, ...]: return jax.tree_util.tree_map( - lambda x: jnp.reshape(x, (batch_size * self.k_samples_per_x, -1)) - if x is not None else None, arrays + lambda x: jnp.reshape(x, (batch_size * self.k_samples_per_x, -1)), + arrays ) diff --git 
a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index c0da4db6a..042e9fd0c 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -25,7 +25,17 @@ from ott.problems.quadratic import quadratic_problem from ott.solvers import was_solver from ott.solvers.linear import sinkhorn -from ott.solvers.quadratic import gromov_wasserstein +from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr + +Scale_cost_lin_t = Union[bool, int, float, Literal["mean", "max_cost", + "median"]] +Scale_cost_quad_t = Union[Union[bool, int, float, + Literal["mean", "max_norm", "max_bound", + "max_cost", "median"]], + Dict[str, + Union[bool, int, float, + Literal["mean", "max_norm", "max_bound", + "max_cost", "median"]]]], __all__ = [ "BaseOTMatcher", "OTMatcherLinear", "OTMatcherQuad", "UnbalancednessHandler" @@ -57,12 +67,7 @@ def match_pairs(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: def _get_gromov_match_fn( ot_solver: Any, cost_fn: Union[Any, Mapping[str, Any]], - scale_cost: Union[Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", "max_cost", - "median"]], - Dict[str, Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", - "max_cost", "median"]]]], + scale_cost: Scale_cost_quad_t, tau_a: float, tau_b: float, fused_penalty: float, @@ -150,18 +155,17 @@ def sample_joint( ) indices_source = indices // n_tgt indices_target = indices % n_tgt - return tree_util.tree_map( - lambda b: b[indices_source] if b is not None else b, source_arrays - ), tree_util.tree_map( - lambda b: b[indices_target] if b is not None else b, target_arrays - ) + return tree_util.tree_map(lambda b: b[indices_source], + source_arrays), tree_util.tree_map( + lambda b: b[indices_target], target_arrays + ) def sample_conditional_indices_from_tmap( self, rng: jax.Array, conditional_distributions: jnp.ndarray, *, - k_samples_per_x: Union[int, jnp.ndarray], + k_samples_per_x: int, source_arrays: 
Tuple[Optional[jnp.ndarray], ...], target_arrays: Tuple[Optional[jnp.ndarray], ...], source_is_balanced: bool, @@ -198,7 +202,7 @@ def sample_conditional_indices_from_tmap( tmat_adapted = conditional_distributions[indices] indices_per_row = jax.vmap( lambda row: jax.random. - choice(key=rng, a=jnp.arange(n_tgt), p=row, shape=(k_samples_per_x,)), + choice(key=rng, a=n_tgt, p=row, shape=(k_samples_per_x,)), in_axes=0, out_axes=0, )( @@ -211,12 +215,12 @@ def sample_conditional_indices_from_tmap( ) return tree_util.tree_map( lambda b: jnp. - reshape(b[indices_source], (k_samples_per_x, n_src, *b.shape[1:])) - if b is not None else None, source_arrays + reshape(b[indices_source], + (k_samples_per_x, n_src, *b.shape[1:])), source_arrays ), tree_util.tree_map( lambda b: jnp. - reshape(b[indices_target], (k_samples_per_x, n_src, *b.shape[1:])) - if b is not None else b, target_arrays + reshape(b[indices_target], + (k_samples_per_x, n_src, *b.shape[1:])), target_arrays ) @@ -232,12 +236,12 @@ class OTMatcherLinear(BaseOTMatcher): def __init__( self, - ot_solver: was_solver.WassersteinSolver, + ot_solver: sinkhorn.Sinkhorn, epsilon: float = 1e-2, cost_fn: Optional[costs.CostFn] = None, scale_cost: Union[bool, int, float, Literal["mean", "max_norm", "max_bound", "max_cost", - "median"]] = "mean", + "median"]] = 1.0, tau_a: float = 1.0, tau_b: float = 1.0, ) -> None: @@ -283,11 +287,10 @@ class OTMatcherQuad(BaseOTMatcher): def __init__( self, - ot_solver: was_solver.WassersteinSolver, + ot_solver: Union[gromov_wasserstein.GromovWasserstein, + gromov_wasserstein_lr.LRGromovWasserstein], cost_fn: Optional[costs.CostFn] = None, - scale_cost: Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", "max_cost", - "median"]] = "mean", + scale_cost: Scale_cost_quad_t = 1.0, tau_a: float = 1.0, tau_b: float = 1.0, fused_penalty: float = 0.0, @@ -347,7 +350,8 @@ class UnbalancednessHandler: is not available, and hence the unbalanced marginals must be computed by the neural 
solver. kwargs: Additional keyword arguments. - """ # noqa: E501 # TODO(MUCDK): fix me + + """ # noqa: E501 def __init__( self, @@ -364,8 +368,7 @@ def __init__( opt_eta: Optional[optax.GradientTransformation] = None, opt_xi: Optional[optax.GradientTransformation] = None, resample_epsilon: float = 1e-2, - scale_cost: Union[bool, int, float, Literal["mean", "max_cost", - "median"]] = "mean", + scale_cost: Union[Scale_cost_lin_t, Scale_cost_quad_t] = 1.0, ot_solver: Optional[was_solver.WassersteinSolver] = None, **kwargs: Mapping[str, Any], ): @@ -443,9 +446,7 @@ def resample_unbalanced( Resampled arrays. """ indices = jax.random.choice(rng, a=len(p), p=jnp.squeeze(p), shape=[len(p)]) - return tree_util.tree_map( - lambda b: b[indices] if b is not None else b, arrays - ) + return tree_util.tree_map(lambda b: b[indices], arrays) def setup(self, source_dim: int, target_dim: int, cond_dim: int): """Setup the model. @@ -479,37 +480,19 @@ def setup(self, source_dim: int, target_dim: int, cond_dim: int): def _get_rescaling_step_fn(self) -> Callable: # type:ignore[type-arg] - def loss_a_fn( - params_eta: Optional[jnp.ndarray], - apply_fn_eta: Callable[[Dict[str, jnp.ndarray], jnp.ndarray], - jnp.ndarray], + def loss_marginal_fn( + params: jnp.ndarray, + apply_fn: Callable[[Dict[str, jnp.ndarray], jnp.ndarray], + Optional[jnp.ndarray]], x: jnp.ndarray, condition: Optional[jnp.ndarray], - a: jnp.ndarray, + true_marginals: jnp.ndarray, expectation_reweighting: float, ) -> Tuple[float, jnp.ndarray]: - eta_predictions = apply_fn_eta({"params": params_eta}, x, condition) - return ( - optax.l2_loss(eta_predictions[:, 0], a).mean() + - optax.l2_loss(jnp.mean(eta_predictions) - expectation_reweighting), - eta_predictions, - ) - - def loss_b_fn( - params_xi: Optional[jnp.ndarray], - apply_fn_xi: Callable[[Dict[str, jnp.ndarray], jnp.ndarray], - jnp.ndarray], - x: jnp.ndarray, - condition: Optional[jnp.ndarray], - b: jnp.ndarray, - expectation_reweighting: float, - ) -> Tuple[float, 
jnp.ndarray]: - xi_predictions = apply_fn_xi({"params": params_xi}, x, condition) - return ( - optax.l2_loss(xi_predictions[:, 0], b).mean() + - optax.l2_loss(jnp.mean(xi_predictions) - expectation_reweighting), - xi_predictions, - ) + predictions = apply_fn({"params": params}, x, condition) + pred_loss = optax.l2_loss(jnp.squeeze(predictions), true_marginals).mean() + exp_loss = optax.l2_loss(jnp.mean(predictions) - expectation_reweighting) + return (pred_loss + exp_loss, predictions) @jax.jit def step_fn( @@ -524,7 +507,9 @@ def step_fn( is_training: bool = True, ): if state_eta is not None: - grad_a_fn = jax.value_and_grad(loss_a_fn, argnums=0, has_aux=True) + grad_a_fn = jax.value_and_grad( + loss_marginal_fn, argnums=0, has_aux=True + ) (loss_a, eta_predictions), grads_eta = grad_a_fn( state_eta.params, state_eta.apply_fn, @@ -540,7 +525,9 @@ def step_fn( else: new_state_eta = eta_predictions = loss_a = None if state_xi is not None: - grad_b_fn = jax.value_and_grad(loss_b_fn, argnums=0, has_aux=True) + grad_b_fn = jax.value_and_grad( + loss_marginal_fn, argnums=0, has_aux=True + ) (loss_b, xi_predictions), grads_xi = grad_b_fn( state_xi.params, state_xi.apply_fn, diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 9480eb3cd..e924edd9a 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -27,8 +27,8 @@ from ott.neural.flow_models.samplers import uniform_sampler from ott.neural.models import base_solver from ott.neural.models.nets import RescalingMLP -from ott.solvers.linear import sinkhorn -from ott.solvers.quadratic import gromov_wasserstein +from ott.solvers.linear import sinkhorn, sinkhorn_lr +from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr class TestGENOTLin: @@ -36,10 +36,14 @@ class TestGENOTLin: @pytest.mark.parametrize("scale_cost", ["mean", 2.0]) @pytest.mark.parametrize("k_samples_per_x", [1, 3]) @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) + 
@pytest.mark.parametrize("solver", ["sinkhorn", "lr_sinkhorn"]) def test_genot_linear_unconditional( - self, genot_data_loader_linear: Iterator, - scale_cost: Union[float, Literal["mean"]], k_samples_per_x: int, - solver_latent_to_data: Optional[str] + self, + genot_data_loader_linear: Iterator, + scale_cost: Union[float, Literal["mean"]], + k_samples_per_x: int, + solver_latent_to_data: Optional[str], + solver: Literal["sinkhorn", "lr_sinkhorn"], ): matcher_latent_to_data = ( None if solver_latent_to_data is None else @@ -62,7 +66,8 @@ def test_genot_linear_unconditional( condition_dim=source_dim + condition_dim, latent_embed_dim=5, ) - ot_solver = sinkhorn.Sinkhorn() + ot_solver = sinkhorn.Sinkhorn( + ) if solver == "sinkhorn" else sinkhorn_lr.LRSinkhorn(rank=3) ot_matcher = base_solver.OTMatcherLinear( ot_solver, cost_fn=costs.SqEuclidean(), scale_cost=scale_cost ) @@ -96,9 +101,11 @@ def test_genot_linear_unconditional( @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) + @pytest.mark.parametrize("solver", ["sinkhorn", "lr_sinkhorn"]) def test_genot_linear_conditional( self, genot_data_loader_linear_conditional: Iterator, - k_samples_per_x: int, solver_latent_to_data: Optional[str] + k_samples_per_x: int, solver_latent_to_data: Optional[str], + solver: Literal["sinkhorn", "lr_sinkhorn"] ): matcher_latent_to_data = ( None if solver_latent_to_data is None else @@ -121,7 +128,8 @@ def test_genot_linear_conditional( condition_dim=source_dim + condition_dim, latent_embed_dim=5, ) - ot_solver = sinkhorn.Sinkhorn() + ot_solver = sinkhorn.Sinkhorn( + ) if solver == "sinkhorn" else sinkhorn_lr.LRSinkhorn(rank=3) ot_matcher = base_solver.OTMatcherLinear( ot_solver, cost_fn=costs.SqEuclidean() ) @@ -243,9 +251,11 @@ class TestGENOTQuad: @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) + @pytest.mark.parametrize("solver", 
["gromov", "gromov_lr"]) def test_genot_quad_unconditional( self, genot_data_loader_quad: Iterator, k_samples_per_x: int, - solver_latent_to_data: Optional[str] + solver_latent_to_data: Optional[str], solver: Literal["gromov", + "gromov_lr"] ): matcher_latent_to_data = ( None if solver_latent_to_data is None else @@ -266,7 +276,11 @@ def test_genot_quad_unconditional( condition_dim=source_dim + condition_dim, latent_embed_dim=5, ) - ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + ot_solver = gromov_wasserstein.GromovWasserstein( + epsilon=1e-2 + ) if solver == "gromov" else gromov_wasserstein_lr.LRGromovWasserstein( + rank=3, epsilon=1e-2 + ) ot_matcher = base_solver.OTMatcherQuad( ot_solver, cost_fn=costs.SqEuclidean() ) @@ -301,9 +315,11 @@ def test_genot_quad_unconditional( @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) + @pytest.mark.parametrize("solver", ["gromov", "gromov_lr"]) def test_genot_fused_unconditional( self, genot_data_loader_fused: Iterator, k_samples_per_x: int, - solver_latent_to_data: Optional[str] + solver_latent_to_data: Optional[str], solver: Literal["gromov", + "gromov_lr"] ): matcher_latent_to_data = ( None if solver_latent_to_data is None else @@ -326,7 +342,11 @@ def test_genot_fused_unconditional( condition_dim=source_dim + condition_dim, latent_embed_dim=5, ) - ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + ot_solver = gromov_wasserstein.GromovWasserstein( + epsilon=1e-2 + ) if solver == "gromov" else gromov_wasserstein_lr.LRGromovWasserstein( + rank=3, epsilon=1e-2 + ) ot_matcher = base_solver.OTMatcherQuad( ot_solver, cost_fn=costs.SqEuclidean(), fused_penalty=0.5 ) @@ -361,9 +381,11 @@ def test_genot_fused_unconditional( @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) + @pytest.mark.parametrize("solver", ["gromov", "gromov_lr"]) def 
test_genot_quad_conditional( self, genot_data_loader_quad_conditional: Iterator, k_samples_per_x: int, - solver_latent_to_data: Optional[str] + solver_latent_to_data: Optional[str], solver: Literal["gromov", + "gromov_lr"] ): matcher_latent_to_data = ( None if solver_latent_to_data is None else @@ -385,7 +407,11 @@ def test_genot_quad_conditional( condition_dim=source_dim + condition_dim, latent_embed_dim=5, ) - ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + ot_solver = gromov_wasserstein.GromovWasserstein( + epsilon=1e-2 + ) if solver == "gromov" else gromov_wasserstein_lr.LRGromovWasserstein( + rank=3, epsilon=1e-2 + ) ot_matcher = base_solver.OTMatcherQuad( ot_solver, cost_fn=costs.SqEuclidean() ) @@ -421,9 +447,11 @@ def test_genot_quad_conditional( @pytest.mark.parametrize("k_samples_per_x", [1, 2]) @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) + @pytest.mark.parametrize("solver", ["gromov", "gromov_lr"]) def test_genot_fused_conditional( self, genot_data_loader_fused_conditional: Iterator, k_samples_per_x: int, - solver_latent_to_data: Optional[str] + solver_latent_to_data: Optional[str], solver: Literal["gromov", + "gromov_lr"] ): solver_latent_to_data = ( None if solver_latent_to_data is None else sinkhorn.Sinkhorn() @@ -448,7 +476,11 @@ def test_genot_fused_conditional( condition_dim=source_dim + condition_dim, latent_embed_dim=5, ) - ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-2) + ot_solver = gromov_wasserstein.GromovWasserstein( + epsilon=1e-2 + ) if solver == "gromov" else gromov_wasserstein_lr.LRGromovWasserstein( + rank=3, epsilon=1e-2 + ) ot_matcher = base_solver.OTMatcherQuad( ot_solver, cost_fn=costs.SqEuclidean(), fused_penalty=0.5 ) diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index d66ea1611..4e68315ff 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # 
limitations under the License. import functools -from typing import Iterator, Type +from typing import Iterator, Literal, Type import pytest @@ -87,8 +87,8 @@ def test_flow_matching_unconditional( @pytest.mark.parametrize( "flow", [ flows.ConstantNoiseFlow(0.0), - flows.ConstantNoiseFlow(1.0), - flows.BrownianNoiseFlow(0.2) + flows.ConstantNoiseFlow(1.1), + flows.BrownianNoiseFlow(2.2) ] ) def test_flow_matching_with_conditions( @@ -145,14 +145,17 @@ def test_flow_matching_with_conditions( assert jnp.sum(jnp.isnan(result_backward)) == 0 @pytest.mark.parametrize( - "flow", [ + "flow", + [ flows.ConstantNoiseFlow(0.0), - flows.ConstantNoiseFlow(1.0), - flows.BrownianNoiseFlow(0.2) - ] + flows.ConstantNoiseFlow(13.0), + flows.BrownianNoiseFlow(0.12) + ], ) + @pytest.mark.parametrize("solver", ["sinkhorn", "lr_sinkhorn"]) def test_flow_matching_conditional( - self, data_loader_gaussian_conditional, flow: Type[flows.BaseFlow] + self, data_loader_gaussian_conditional, flow: Type[flows.BaseFlow], + solver: Literal["sinkhorn", "lr_sinkhorn"] ): dim = 2 condition_dim = 0 @@ -161,7 +164,8 @@ def test_flow_matching_conditional( condition_dim=condition_dim, latent_embed_dim=5, ) - ot_solver = sinkhorn.Sinkhorn() + ot_solver = sinkhorn.Sinkhorn( + ) if solver == "sinkhorn" else sinkhorn.LRSinkhorn() ot_matcher = base_solver.OTMatcherLinear(ot_solver) time_sampler = samplers.uniform_sampler optimizer = optax.adam(learning_rate=1e-3) From dc436f4daff37315ff7ecd313e4f5fb696061ce3 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 16 Feb 2024 12:48:05 +0100 Subject: [PATCH 095/186] problem with custom type --- src/ott/neural/models/base_solver.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index 042e9fd0c..1742a2a2b 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -27,15 +27,13 @@ from ott.solvers.linear 
import sinkhorn from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr -Scale_cost_lin_t = Union[bool, int, float, Literal["mean", "max_cost", - "median"]] -Scale_cost_quad_t = Union[Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", - "max_cost", "median"]], - Dict[str, - Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", - "max_cost", "median"]]]], +ScaleCostLin_t = Union[bool, int, float, Literal["mean", "max_cost", "median"]] +ScaleCostQuad_t = Union[Union[bool, int, float, + Literal["mean", "max_norm", "max_bound", + "max_cost", "median"]], + Dict[str, Union[bool, int, float, + Literal["mean", "max_norm", "max_bound", + "max_cost", "median"]]]], __all__ = [ "BaseOTMatcher", "OTMatcherLinear", "OTMatcherQuad", "UnbalancednessHandler" @@ -67,7 +65,7 @@ def match_pairs(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: def _get_gromov_match_fn( ot_solver: Any, cost_fn: Union[Any, Mapping[str, Any]], - scale_cost: Scale_cost_quad_t, + scale_cost: ScaleCostQuad_t, tau_a: float, tau_b: float, fused_penalty: float, @@ -290,7 +288,7 @@ def __init__( ot_solver: Union[gromov_wasserstein.GromovWasserstein, gromov_wasserstein_lr.LRGromovWasserstein], cost_fn: Optional[costs.CostFn] = None, - scale_cost: Scale_cost_quad_t = 1.0, + scale_cost: ScaleCostQuad_t = 1.0, tau_a: float = 1.0, tau_b: float = 1.0, fused_penalty: float = 0.0, @@ -368,7 +366,7 @@ def __init__( opt_eta: Optional[optax.GradientTransformation] = None, opt_xi: Optional[optax.GradientTransformation] = None, resample_epsilon: float = 1e-2, - scale_cost: Union[Scale_cost_lin_t, Scale_cost_quad_t] = 1.0, + scale_cost: Union[ScaleCostLin_t, ScaleCostQuad_t] = 1.0, ot_solver: Optional[was_solver.WassersteinSolver] = None, **kwargs: Mapping[str, Any], ): From 8bfe1a34e930e5eb18d9ab35bc1be5fc39a2ee7f Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 16 Feb 2024 13:29:16 +0100 Subject: [PATCH 096/186] fix scale cost bug --- 
src/ott/neural/models/base_solver.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index 1742a2a2b..5bfa8caa0 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -27,13 +27,8 @@ from ott.solvers.linear import sinkhorn from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr -ScaleCostLin_t = Union[bool, int, float, Literal["mean", "max_cost", "median"]] -ScaleCostQuad_t = Union[Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", - "max_cost", "median"]], - Dict[str, Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", - "max_cost", "median"]]]], +ScaleCost_t = Union[int, float, Literal["mean", "max_cost", "median"]] +ScaleCostQuad_t = Union[ScaleCost_t, Dict[str, ScaleCost_t]] __all__ = [ "BaseOTMatcher", "OTMatcherLinear", "OTMatcherQuad", "UnbalancednessHandler" @@ -44,8 +39,7 @@ def _get_sinkhorn_match_fn( ot_solver: Any, epsilon: float = 1e-2, cost_fn: Optional[costs.CostFn] = None, - scale_cost: Union[bool, int, float, Literal["mean", "max_norm", "max_bound", - "max_cost", "median"]] = "mean", + scale_cost: ScaleCost_t = 1.0, tau_a: float = 1.0, tau_b: float = 1.0, ) -> Callable: @@ -237,9 +231,7 @@ def __init__( ot_solver: sinkhorn.Sinkhorn, epsilon: float = 1e-2, cost_fn: Optional[costs.CostFn] = None, - scale_cost: Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", "max_cost", - "median"]] = 1.0, + scale_cost: ScaleCost_t = 1.0, tau_a: float = 1.0, tau_b: float = 1.0, ) -> None: @@ -366,7 +358,7 @@ def __init__( opt_eta: Optional[optax.GradientTransformation] = None, opt_xi: Optional[optax.GradientTransformation] = None, resample_epsilon: float = 1e-2, - scale_cost: Union[ScaleCostLin_t, ScaleCostQuad_t] = 1.0, + scale_cost: Union[ScaleCost_t, ScaleCostQuad_t] = 1.0, ot_solver: Optional[was_solver.WassersteinSolver] = None, **kwargs: 
Mapping[str, Any], ): From 2a1f23addb2567cccce007df6be8c8efe8111713 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 16 Feb 2024 14:14:35 +0100 Subject: [PATCH 097/186] fix bugs --- src/ott/neural/flow_models/genot.py | 2 ++ src/ott/neural/models/base_solver.py | 4 ++-- tests/neural/otfm_test.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index ff9b12c82..101c32978 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -290,6 +290,7 @@ def __call__(self, train_loader, valid_loader): valid_loader: Data loader for the validation data. """ iter = -1 + stop = False while True: for batch in train_loader: iter += 1 @@ -388,6 +389,7 @@ def __call__(self, train_loader, valid_loader): """ batch: Dict[str, jnp.array] = {} iter = -1 + stop = False while True: for batch in train_loader: iter += 1 diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index 5bfa8caa0..e9be6baa8 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -88,8 +88,8 @@ def _get_gromov_match_fn( @jax.jit def match_pairs( - x_quad: Tuple[jnp.ndarray, jnp.ndarray], - y_quad: Tuple[jnp.ndarray, jnp.ndarray], + x_quad: jnp.ndarray, + y_quad: jnp.ndarray, x_lin: Optional[jnp.ndarray], y_lin: Optional[jnp.ndarray], ) -> jnp.ndarray: diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 4e68315ff..6f8d14879 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -23,7 +23,7 @@ from ott.neural.flow_models import flows, models, otfm, samplers from ott.neural.models import base_solver, nets -from ott.solvers.linear import sinkhorn +from ott.solvers.linear import sinkhorn, sinkhorn_lr class TestOTFlowMatching: @@ -165,7 +165,7 @@ def test_flow_matching_conditional( latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn( - ) if solver == "sinkhorn" else 
sinkhorn.LRSinkhorn() + ) if solver == "sinkhorn" else sinkhorn_lr.LRSinkhorn() ot_matcher = base_solver.OTMatcherLinear(ot_solver) time_sampler = samplers.uniform_sampler optimizer = optax.adam(learning_rate=1e-3) From a46405c96f90d2aa5740f1c4e9cbbade00019520 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 18 Feb 2024 18:48:12 +0100 Subject: [PATCH 098/186] fux bug in unbalancedness/rescalingMlp --- src/ott/neural/models/nets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ott/neural/models/nets.py b/src/ott/neural/models/nets.py index 0acb7daae..cad4e84c2 100644 --- a/src/ott/neural/models/nets.py +++ b/src/ott/neural/models/nets.py @@ -89,7 +89,7 @@ def __call__( out_layer = layers.MLPBlock( dim=self.hidden_dim, - out_dim=self.hidden_dim, + out_dim=1, num_layers=self.num_layers_per_block, act_fn=self.act_fn ) From 7afcac456a47a62330a4fc6b1a550d6c4dd360d3 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 18 Feb 2024 18:56:40 +0100 Subject: [PATCH 099/186] unify unbalancedness step in GENOT --- src/ott/neural/flow_models/genot.py | 52 +++++++++++++++++------------ 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index 101c32978..4c5db7fa2 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -274,6 +274,25 @@ def _reshape_samples(self, arrays: Tuple[jnp.ndarray, ...], arrays ) + def _learn_rescaling( + self, source: jnp.ndarray, target: jnp.ndarray, + source_conditions: Optional[jnp.ndarray], tmat: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray, float, float]: + + ( + self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, + loss_b + ) = self.unbalancedness_handler.step_fn( + source=source, + target=target, + condition=source_conditions, + a=tmat.sum(axis=1), + b=tmat.sum(axis=0), + state_eta=self.unbalancedness_handler.state_eta, + state_xi=self.unbalancedness_handler.state_xi, + 
) + return eta_predictions, xi_predictions, float(loss_a), float(loss_b) + class GENOTLin(GENOTBase): """Implementation of GENOT-L (:cite:`klein:23`). @@ -304,7 +323,7 @@ def __call__(self, train_loader, valid_loader): source, source_conditions, target = jnp.array( batch["source_lin"] ), jnp.array(batch["source_conditions"] - ) if len(batch["source_conditions"]) else None, jnp.array( + ) if "source_conditions" in batch else None, jnp.array( batch["target_lin"] ) @@ -353,18 +372,13 @@ def __call__(self, train_loader, valid_loader): latent, source_conditions ) if self.learn_rescaling: - ( - self.state_eta, self.state_xi, eta_predictions, xi_predictions, - loss_a, loss_b - ) = self.unbalancedness_handler.step_fn( + eta_preds, xi_preds, loss_a, loss_b = self._learn_rescaling( source=source, target=target, condition=source_conditions, - a=tmat.sum(axis=1), - b=tmat.sum(axis=0), - state_eta=self.unbalancedness_handler.state_eta, - state_xi=self.unbalancedness_handler.state_xi, + tmat=tmat ) + if iter % self.valid_freq == 0: self._valid_step(valid_loader, iter) if stop: @@ -403,12 +417,12 @@ def __call__(self, train_loader, valid_loader): ) = jax.random.split(self.rng, 6) (source_lin, source_quad, source_conditions, target_lin, target_quad) = ( - jnp.array(batch["source_lin"]) if len(batch["source_lin"]) else - None, jnp.array(batch["source_quad"]), + jnp.array(batch["source_lin"]) if "source_lin" in batch else None, + jnp.array(batch["source_quad"]), jnp.array(batch["source_conditions"]) - if len(batch["source_conditions"]) else None, - jnp.array(batch["target_lin"]) if len(batch["target_lin"]) else - None, jnp.array(batch["target_quad"]) + if "source_conditions" in batch else None, + jnp.array(batch["target_lin"]) if "target_lin" in batch else None, + jnp.array(batch["target_quad"]) ) batch_size = len(source_quad) n_samples = batch_size * self.k_samples_per_x @@ -464,17 +478,11 @@ def __call__(self, train_loader, valid_loader): latent, source_conditions ) if 
self.learn_rescaling: - ( - self.state_eta, self.state_xi, eta_predictions, xi_predictions, - loss_a, loss_b - ) = self.unbalancedness_handler.step_fn( + eta_preds, xi_preds, loss_a, loss_b = self._learn_rescaling( source=source, target=target, condition=source_conditions, - a=tmat.sum(axis=1), - b=tmat.sum(axis=0), - state_eta=self.unbalancedness_handler.state_eta, - state_xi=self.unbalancedness_handler.state_xi, + tmat=tmat ) if iter % self.valid_freq == 0: self._valid_step(valid_loader, iter) From 4fc8fe625c876209f24f9f281036fd4ff75b5b34 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 18 Feb 2024 19:20:43 +0100 Subject: [PATCH 100/186] change OTDataSet and OTFlowMatching to 4 data loaderes --- src/ott/neural/data/dataloaders.py | 83 ++++++++++------------------- src/ott/neural/flow_models/genot.py | 5 ++ src/ott/neural/flow_models/otfm.py | 24 +++++---- tests/neural/conftest.py | 59 ++++++++++---------- tests/neural/otfm_test.py | 20 ++++--- 5 files changed, 90 insertions(+), 101 deletions(-) diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py index e063deefd..8083a744c 100644 --- a/src/ott/neural/data/dataloaders.py +++ b/src/ott/neural/data/dataloaders.py @@ -14,6 +14,7 @@ from typing import Any, List, Mapping, Optional import numpy as np +from jax import tree_util __all__ = ["OTDataSet", "ConditionalOTDataLoader"] @@ -22,72 +23,44 @@ class OTDataSet: """Data set for OT problems. Args: - source_lin: Linear part of the source measure. - source_quad: Quadratic part of the source measure. - target_lin: Linear part of the target measure. - target_quad: Quadratic part of the target measure. - source_conditions: Conditions of the source measure. - target_conditions: Conditions of the target measure. + lin: Linear part of the measure. + quad: Quadratic part of the measure. + conditions: Conditions of the source measure. 
""" def __init__( self, - source_lin: Optional[np.ndarray] = None, - source_quad: Optional[np.ndarray] = None, - target_lin: Optional[np.ndarray] = None, - target_quad: Optional[np.ndarray] = None, - source_conditions: Optional[np.ndarray] = None, - target_conditions: Optional[np.ndarray] = None, + lin: Optional[np.ndarray] = None, + quad: Optional[np.ndarray] = None, + conditions: Optional[np.ndarray] = None, ): - if source_lin is not None: - if source_quad is not None: - assert len(source_lin) == len(source_quad) - self.n_source = len(source_lin) + if lin is not None: + if quad is not None: + assert len(lin) == len(quad) + self.n_samples = len(lin) else: - self.n_source = len(source_lin) + self.n_samples = len(lin) else: - self.n_source = len(source_quad) - if source_conditions is not None: - assert len(source_conditions) == self.n_source - if target_lin is not None: - if target_quad is not None: - assert len(target_lin) == len(target_quad) - self.n_target = len(target_lin) - else: - self.n_target = len(target_lin) - else: - self.n_target = len(target_quad) - if target_conditions is not None: - assert len(target_conditions) == self.n_target - - self.source_lin = source_lin - self.target_lin = target_lin - self.source_quad = source_quad - self.target_quad = target_quad - self.source_conditions = source_conditions - self.target_conditions = target_conditions + self.n_samples = len(quad) + if conditions is not None: + assert len(conditions) == self.n_samples + + self.lin = lin + self.quad = quad + self.conditions = conditions + self._tree = {} + if lin is not None: + self._tree["lin"] = lin + if quad is not None: + self._tree["quad"] = quad + if conditions is not None: + self._tree["conditions"] = conditions def __getitem__(self, idx: np.ndarray) -> Mapping[str, np.ndarray]: - return { - "source_lin": - self.source_lin[idx] if self.source_lin is not None else [], - "source_quad": - self.source_quad[idx] if self.source_quad is not None else [], - "target_lin": - 
self.target_lin[idx] if self.target_lin is not None else [], - "target_quad": - self.target_quad[idx] if self.target_quad is not None else [], - "source_conditions": - self.source_conditions[idx] - if self.source_conditions is not None else [], - "target_conditions": - self.target_conditions[idx] - if self.target_conditions is not None else [], - } + return tree_util.tree_map(lambda x: x[idx], self._tree) def __len__(self): - return len(self.source_lin - ) if self.source_lin is not None else len(self.source_quad) + return self.n_samples class ConditionalOTDataLoader: diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index 4c5db7fa2..d466f015d 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -339,6 +339,11 @@ def __call__(self, train_loader, valid_loader): target, ) + jax.debug.print("source.shape {x}", x=source.shape) + jax.debug.print( + "source_conditions.shape {x}", x=source_conditions.shape + ) + jax.debug.print("target.shape {x}", x=target.shape) (source, source_conditions ), (target,) = self.ot_matcher.sample_conditional_indices_from_tmap( rng=rng_resample, diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index dca7bea60..6c519f4da 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -153,28 +153,30 @@ def loss_fn( return step_fn - def __call__(self, train_loader, valid_loader): + def __call__( + self, train_loader_source, train_loader_target, valid_loader_source, + valid_loader_target + ): """Train :class:`OTFlowMatching`. Args; train_loader: Dataloader for the training data. valid_loader: Dataloader for the validation data. 
""" - batch: Mapping[str, jnp.ndarray] = {} - iter = -1 while True: - for batch in train_loader: + for batch_source, batch_target in zip( + train_loader_source, train_loader_target + ): iter += 1 if iter >= self.iterations: stop = True break rng_resample, rng_step_fn, self.rng = jax.random.split(self.rng, 3) - source, source_conditions, target = jnp.array( - batch["source_lin"] - ), jnp.array(batch["source_conditions"]) if len( - batch["source_conditions"] - ) > 0 else None, jnp.array(batch["target_lin"]) + source, source_conditions = jnp.array(batch_source["lin"]), jnp.array( + batch_source["conditions"] + ) if "conditions" in batch_source else None + target = jnp.array(batch_target["lin"]) if self.ot_matcher is not None: tmat = self.ot_matcher.match_fn(source, target) (source, source_conditions), (target,) = self.ot_matcher.sample_joint( @@ -200,7 +202,7 @@ def __call__(self, train_loader, valid_loader): state_xi=self.unbalancedness_handler.state_xi, ) if iter % self.valid_freq == 0: - self._valid_step(valid_loader, iter) + self._valid_step(valid_loader_source, valid_loader_target, iter) if stop: break @@ -258,7 +260,7 @@ def solve_ode(input: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: return jax.vmap(solve_ode)(data, condition) - def _valid_step(self, valid_loader, iter): + def _valid_step(self, valid_loader_source, valid_loader_target, iter): pass @property diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index f33252f07..e40f93c16 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Tuple + import pytest import numpy as np @@ -21,13 +23,16 @@ @pytest.fixture(scope="module") -def data_loader_gaussian(): +def data_loaders_gaussian() -> Tuple[Torch_loader, Torch_loader]: """Returns a data loader for a simple Gaussian mixture.""" rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 - dataset = dataloaders.OTDataSet(source_lin=source, target_lin=target) - return Torch_loader(dataset, batch_size=16, shuffle=True) + src_dataset = dataloaders.OTDataSet(lin=source) + tgt_dataset = dataloaders.OTDataSet(lin=target) + loader_src = Torch_loader(src_dataset, batch_size=16, shuffle=True) + loader_tgt = Torch_loader(tgt_dataset, batch_size=16, shuffle=True) + return loader_src, loader_tgt @pytest.fixture(scope="module") @@ -40,14 +45,14 @@ def data_loader_gaussian_conditional(): source_1 = rng.normal(size=(100, 2)) target_1 = rng.normal(size=(100, 2)) - 2.0 ds0 = dataloaders.OTDataSet( - source_lin=source_0, + lin=source_0, target_lin=target_0, - source_conditions=np.zeros_like(source_0) * 0.0 + conditions=np.zeros_like(source_0) * 0.0 ) ds1 = dataloaders.OTDataSet( - source_lin=source_1, + lin=source_1, target_lin=target_1, - source_conditions=np.ones_like(source_1) * 1.0 + conditions=np.ones_like(source_1) * 1.0 ) sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) @@ -67,9 +72,9 @@ def data_loader_gaussian_with_conditions(): target_conditions = rng.normal(size=(100, 1)) - 1.0 dataset = dataloaders.OTDataSet( - source_lin=source, + lin=source, target_lin=target, - source_conditions=source_conditions, + conditions=source_conditions, target_conditions=target_conditions ) return Torch_loader(dataset, batch_size=16, shuffle=True) @@ -81,7 +86,7 @@ def genot_data_loader_linear(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 - dataset = 
dataloaders.OTDataSet(source_lin=source, target_lin=target) + dataset = dataloaders.OTDataSet(lin=source, target_lin=target) return Torch_loader(dataset, batch_size=16, shuffle=True) @@ -94,14 +99,14 @@ def genot_data_loader_linear_conditional(): source_1 = rng.normal(size=(100, 2)) target_1 = rng.normal(size=(100, 2)) + 1.0 ds0 = dataloaders.OTDataSet( - source_lin=source_0, + lin=source_0, target_lin=target_0, - source_conditions=np.zeros_like(source_0) * 0.0 + conditions=np.zeros_like(source_0) * 0.0 ) ds1 = dataloaders.OTDataSet( - source_lin=source_1, + lin=source_1, target_lin=target_1, - source_conditions=np.ones_like(source_1) * 1.0 + conditions=np.ones_like(source_1) * 1.0 ) sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) @@ -117,7 +122,7 @@ def genot_data_loader_quad(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 1)) + 1.0 - dataset = dataloaders.OTDataSet(source_quad=source, target_quad=target) + dataset = dataloaders.OTDataSet(quad=source, target_quad=target) return Torch_loader(dataset, batch_size=16, shuffle=True) @@ -130,14 +135,14 @@ def genot_data_loader_quad_conditional(): source_1 = rng.normal(size=(100, 2)) target_1 = rng.normal(size=(100, 1)) + 1.0 ds0 = dataloaders.OTDataSet( - source_quad=source_0, + quad=source_0, target_quad=target_0, - source_conditions=np.zeros_like(source_0) * 0.0 + conditions=np.zeros_like(source_0) * 0.0 ) ds1 = dataloaders.OTDataSet( - source_quad=source_1, + quad=source_1, target_quad=target_1, - source_conditions=np.ones_like(source_1) * 1.0 + conditions=np.ones_like(source_1) * 1.0 ) sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) @@ -156,8 +161,8 @@ def genot_data_loader_fused(): source_lin = rng.normal(size=(100, 2)) target_lin = rng.normal(size=(100, 2)) + 1.0 dataset = dataloaders.OTDataSet( - 
source_lin=source_lin, - source_quad=source_q, + lin=source_lin, + quad=source_q, target_lin=target_lin, target_quad=target_q ) @@ -179,18 +184,18 @@ def genot_data_loader_fused_conditional(): target_lin_1 = 2 * rng.normal(size=(100, 2)) + 1.0 ds0 = dataloaders.OTDataSet( - source_lin=source_lin_0, + lin=source_lin_0, target_lin=target_lin_0, - source_quad=source_q_0, + quad=source_q_0, target_quad=target_q_0, - source_conditions=np.zeros_like(source_lin_0) * 0.0 + conditions=np.zeros_like(source_lin_0) * 0.0 ) ds1 = dataloaders.OTDataSet( - source_lin=source_lin_1, + lin=source_lin_1, target_lin=target_lin_1, - source_quad=source_q_1, + quad=source_q_1, target_quad=target_q_1, - source_conditions=np.ones_like(source_lin_1) * 1.0 + conditions=np.ones_like(source_lin_1) * 1.0 ) sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 6f8d14879..e57fce89c 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -36,7 +36,7 @@ class TestOTFlowMatching: ] ) def test_flow_matching_unconditional( - self, data_loader_gaussian, flow: Type[flows.BaseFlow] + self, data_loaders_gaussian, flow: Type[flows.BaseFlow] ): input_dim = 2 condition_dim = 0 @@ -64,14 +64,18 @@ def test_flow_matching_unconditional( optimizer=optimizer, unbalancedness_handler=unbalancedness_handler ) - fm(data_loader_gaussian, data_loader_gaussian) + fm( + data_loaders_gaussian[0], data_loaders_gaussian[1], + data_loaders_gaussian[0], data_loaders_gaussian[1] + ) - batch = next(iter(data_loader_gaussian)) - source = jnp.asarray(batch["source_lin"]) - target = jnp.asarray(batch["target_lin"]) - source_conditions = jnp.asarray(batch["source_conditions"]) if len( - batch["source_conditions"] - ) > 0 else None + batch_src = next(iter(data_loaders_gaussian[0])) + source = jnp.asarray(batch_src["lin"]) + batch_tgt = 
next(iter(data_loaders_gaussian[1])) + target = jnp.asarray(batch_tgt["lin"]) + source_conditions = jnp.asarray( + batch_src["conditions"] + ) if "conditions" in batch_src else None result_forward = fm.transport( source, condition=source_conditions, forward=True ) From 43d37f7eda3d4d08e820fa6ece1955593f6a2256 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 19 Feb 2024 13:26:39 +0100 Subject: [PATCH 101/186] Fix bug in the `ConditionalOTDataset` --- docs/neural/data.rst | 10 +--- src/ott/neural/data/__init__.py | 2 +- src/ott/neural/data/dataloaders.py | 91 ------------------------------ src/ott/neural/data/datasets.py | 87 ++++++++++++++++++++++++++++ tests/neural/conftest.py | 70 +++++++++++------------ 5 files changed, 125 insertions(+), 135 deletions(-) delete mode 100644 src/ott/neural/data/dataloaders.py create mode 100644 src/ott/neural/data/datasets.py diff --git a/docs/neural/data.rst b/docs/neural/data.rst index 970499ff5..95f05f93f 100644 --- a/docs/neural/data.rst +++ b/docs/neural/data.rst @@ -11,11 +11,5 @@ Datasets .. autosummary:: :toctree: _autosummary - dataloaders.OTDataSet - -Dataloaders ------------ -.. autosummary:: - :toctree: _autosummary - - dataloaders.ConditionalOTDataLoader + datasets.OTDataset + datasets.ConditionalOTDataset diff --git a/src/ott/neural/data/__init__.py b/src/ott/neural/data/__init__.py index 51f8dd2af..785604b21 100644 --- a/src/ott/neural/data/__init__.py +++ b/src/ott/neural/data/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import dataloaders +from . 
import datasets diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py deleted file mode 100644 index 8083a744c..000000000 --- a/src/ott/neural/data/dataloaders.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, List, Mapping, Optional - -import numpy as np -from jax import tree_util - -__all__ = ["OTDataSet", "ConditionalOTDataLoader"] - - -class OTDataSet: - """Data set for OT problems. - - Args: - lin: Linear part of the measure. - quad: Quadratic part of the measure. - conditions: Conditions of the source measure. 
- """ - - def __init__( - self, - lin: Optional[np.ndarray] = None, - quad: Optional[np.ndarray] = None, - conditions: Optional[np.ndarray] = None, - ): - if lin is not None: - if quad is not None: - assert len(lin) == len(quad) - self.n_samples = len(lin) - else: - self.n_samples = len(lin) - else: - self.n_samples = len(quad) - if conditions is not None: - assert len(conditions) == self.n_samples - - self.lin = lin - self.quad = quad - self.conditions = conditions - self._tree = {} - if lin is not None: - self._tree["lin"] = lin - if quad is not None: - self._tree["quad"] = quad - if conditions is not None: - self._tree["conditions"] = conditions - - def __getitem__(self, idx: np.ndarray) -> Mapping[str, np.ndarray]: - return tree_util.tree_map(lambda x: x[idx], self._tree) - - def __len__(self): - return self.n_samples - - -class ConditionalOTDataLoader: - """Data loader for OT problems with conditions. - - This data loader wraps several data loaders and samples from them. - - Args: - dataloaders: List of data loaders. - seed: Random seed. - """ - - def __init__( - self, - dataloaders: List[Any], - seed: int = 0 # dataloader should subclass torch dataloader - ): - super().__init__() - self.dataloaders = dataloaders - self.conditions = list(dataloaders) - self.rng = np.random.default_rng(seed=seed) - - def __next__(self) -> Mapping[str, np.ndarray]: - idx = self.rng.choice(len(self.conditions)) - return next(iter(self.dataloaders[idx])) - - def __iter__(self) -> "ConditionalOTDataLoader": - return self diff --git a/src/ott/neural/data/datasets.py b/src/ott/neural/data/datasets.py new file mode 100644 index 000000000..990c27a2a --- /dev/null +++ b/src/ott/neural/data/datasets.py @@ -0,0 +1,87 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Optional + +import jax.tree_util as jtu +import numpy as np + +__all__ = ["OTDataset", "ConditionalOTDataset"] + + +class OTDataset: + """Dataset for Optimal transport problems. + + Args: + lin: Linear part of the measure. + quad: Quadratic part of the measure. + conditions: Conditions of the source measure. + """ + + def __init__( + self, + lin: Optional[np.ndarray] = None, + quad: Optional[np.ndarray] = None, + conditions: Optional[np.ndarray] = None, + ): + self.data = {} + if lin is not None: + self.data["lin"] = lin + if quad is not None: + self.data["quad"] = quad + if conditions is not None: + self.data["conditions"] = conditions + self._check_sizes() + + def _check_sizes(self) -> None: + sizes = {k: len(v) for k, v in self.data.items()} + if not len(set(sizes.values())) == 1: + raise ValueError(f"Not all arrays have the same size: {sizes}.") + + def __getitem__(self, idx: np.ndarray) -> Dict[str, np.ndarray]: + return jtu.tree_map(lambda x: x[idx], self.data) + + def __len__(self) -> int: + for v in self.data.values(): + return len(v) + return 0 + + +# TODO(michalk8): rename +class ConditionalOTDataset: + """Dataset for OT problems with conditions. + + This data loader wraps several data loaders and samples from them. + + Args: + datasets: Datasets to sample from. + seed: Random seed.
+ """ + + def __init__( + self, + # TODO(michalk8): allow for dict with weights + datasets: List[OTDataset], + seed: Optional[int] = None, + ): + self.datasets = tuple(datasets) + self._rng = np.random.default_rng(seed=seed) + self._iterators = () + + def __next__(self) -> Dict[str, np.ndarray]: + idx = self._rng.choice(len(self._iterators)) + return next(self._iterators[idx]) + + def __iter__(self) -> "ConditionalOTDataset": + self._iterators = tuple(iter(ds) for ds in self.datasets) + return self diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index e40f93c16..f5c48e924 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -17,21 +17,21 @@ import numpy as np import torch -from torch.utils.data import DataLoader as Torch_loader +from torch.utils.data import DataLoader -from ott.neural.data import dataloaders +from ott.neural.data import datasets @pytest.fixture(scope="module") -def data_loaders_gaussian() -> Tuple[Torch_loader, Torch_loader]: +def data_loaders_gaussian() -> Tuple[DataLoader, DataLoader]: """Returns a data loader for a simple Gaussian mixture.""" rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 - src_dataset = dataloaders.OTDataSet(lin=source) - tgt_dataset = dataloaders.OTDataSet(lin=target) - loader_src = Torch_loader(src_dataset, batch_size=16, shuffle=True) - loader_tgt = Torch_loader(tgt_dataset, batch_size=16, shuffle=True) + src_dataset = datasets.OTDataset(lin=source) + tgt_dataset = datasets.OTDataset(lin=target) + loader_src = DataLoader(src_dataset, batch_size=16, shuffle=True) + loader_tgt = DataLoader(tgt_dataset, batch_size=16, shuffle=True) return loader_src, loader_tgt @@ -44,22 +44,22 @@ def data_loader_gaussian_conditional(): source_1 = rng.normal(size=(100, 2)) target_1 = rng.normal(size=(100, 2)) - 2.0 - ds0 = dataloaders.OTDataSet( + ds0 = datasets.OTDataset( lin=source_0, target_lin=target_0, conditions=np.zeros_like(source_0) * 
0.0 ) - ds1 = dataloaders.OTDataSet( + ds1 = datasets.OTDataset( lin=source_1, target_lin=target_1, conditions=np.ones_like(source_1) * 1.0 ) sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) - dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0) - dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1) + dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0) + dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1) - return dataloaders.ConditionalOTDataLoader((dl0, dl1)) + return datasets.ConditionalOTDataset((dl0, dl1)) @pytest.fixture(scope="module") @@ -71,13 +71,13 @@ def data_loader_gaussian_with_conditions(): source_conditions = rng.normal(size=(100, 1)) target_conditions = rng.normal(size=(100, 1)) - 1.0 - dataset = dataloaders.OTDataSet( + dataset = datasets.OTDataset( lin=source, target_lin=target, conditions=source_conditions, target_conditions=target_conditions ) - return Torch_loader(dataset, batch_size=16, shuffle=True) + return DataLoader(dataset, batch_size=16, shuffle=True) @pytest.fixture(scope="module") @@ -86,8 +86,8 @@ def genot_data_loader_linear(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 - dataset = dataloaders.OTDataSet(lin=source, target_lin=target) - return Torch_loader(dataset, batch_size=16, shuffle=True) + dataset = datasets.OTDataset(lin=source, target_lin=target) + return DataLoader(dataset, batch_size=16, shuffle=True) @pytest.fixture(scope="module") @@ -98,22 +98,22 @@ def genot_data_loader_linear_conditional(): target_0 = rng.normal(size=(100, 2)) + 1.0 source_1 = rng.normal(size=(100, 2)) target_1 = rng.normal(size=(100, 2)) + 1.0 - ds0 = dataloaders.OTDataSet( + ds0 = datasets.OTDataset( lin=source_0, target_lin=target_0, conditions=np.zeros_like(source_0) * 0.0 ) - ds1 = dataloaders.OTDataSet( + ds1 = datasets.OTDataset( lin=source_1, target_lin=target_1, 
conditions=np.ones_like(source_1) * 1.0 ) sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) - dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0) - dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1) + dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0) + dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1) - return dataloaders.ConditionalOTDataLoader((dl0, dl1)) + return datasets.ConditionalOTDataset((dl0, dl1)) @pytest.fixture(scope="module") @@ -122,8 +122,8 @@ def genot_data_loader_quad(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 1)) + 1.0 - dataset = dataloaders.OTDataSet(quad=source, target_quad=target) - return Torch_loader(dataset, batch_size=16, shuffle=True) + dataset = datasets.OTDataset(quad=source, target_quad=target) + return DataLoader(dataset, batch_size=16, shuffle=True) @pytest.fixture(scope="module") @@ -134,22 +134,22 @@ def genot_data_loader_quad_conditional(): target_0 = rng.normal(size=(100, 1)) + 1.0 source_1 = rng.normal(size=(100, 2)) target_1 = rng.normal(size=(100, 1)) + 1.0 - ds0 = dataloaders.OTDataSet( + ds0 = datasets.OTDataset( quad=source_0, target_quad=target_0, conditions=np.zeros_like(source_0) * 0.0 ) - ds1 = dataloaders.OTDataSet( + ds1 = datasets.OTDataset( quad=source_1, target_quad=target_1, conditions=np.ones_like(source_1) * 1.0 ) sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) - dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0) - dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1) + dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0) + dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1) - return dataloaders.ConditionalOTDataLoader((dl0, dl1)) + return datasets.ConditionalOTDataset((dl0, dl1)) @pytest.fixture(scope="module") @@ -160,13 +160,13 @@ def 
genot_data_loader_fused(): target_q = rng.normal(size=(100, 1)) + 1.0 source_lin = rng.normal(size=(100, 2)) target_lin = rng.normal(size=(100, 2)) + 1.0 - dataset = dataloaders.OTDataSet( + dataset = datasets.OTDataset( lin=source_lin, quad=source_q, target_lin=target_lin, target_quad=target_q ) - return Torch_loader(dataset, batch_size=16, shuffle=True) + return DataLoader(dataset, batch_size=16, shuffle=True) @pytest.fixture(scope="module") @@ -183,14 +183,14 @@ def genot_data_loader_fused_conditional(): source_lin_1 = 2 * rng.normal(size=(100, 2)) target_lin_1 = 2 * rng.normal(size=(100, 2)) + 1.0 - ds0 = dataloaders.OTDataSet( + ds0 = datasets.OTDataset( lin=source_lin_0, target_lin=target_lin_0, quad=source_q_0, target_quad=target_q_0, conditions=np.zeros_like(source_lin_0) * 0.0 ) - ds1 = dataloaders.OTDataSet( + ds1 = datasets.OTDataset( lin=source_lin_1, target_lin=target_lin_1, quad=source_q_1, @@ -199,6 +199,6 @@ def genot_data_loader_fused_conditional(): ) sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) - dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0) - dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1) - return dataloaders.ConditionalOTDataLoader((dl0, dl1)) + dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0) + dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1) + return datasets.ConditionalOTDataset((dl0, dl1)) From 86f6e7a49b687eb2d1fa0c1a03d7781e569aa3fb Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 19 Feb 2024 13:38:14 +0100 Subject: [PATCH 102/186] Polish docs in the `flows.py` --- src/ott/neural/flow_models/flows.py | 38 ++++++++--------------------- 1 file changed, 10 insertions(+), 28 deletions(-) diff --git a/src/ott/neural/flow_models/flows.py b/src/ott/neural/flow_models/flows.py index fd1009cef..b2e4970bf 100644 --- a/src/ott/neural/flow_models/flows.py +++ 
b/src/ott/neural/flow_models/flows.py @@ -38,9 +38,9 @@ def __init__(self, sigma: float): def compute_mu_t( self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray ) -> jnp.ndarray: - """Compute the mean of the probablitiy path. + """Compute the mean of the probability path. - Compute the mean of the probablitiy path between :math:`x_0` and :math:`x_1` + Compute the mean of the probability path between :math:`x_0` and :math:`x_1` at time :math:`t`. Args: @@ -51,10 +51,13 @@ def compute_mu_t( @abc.abstractmethod def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: - """Compute the standard deviation of the probablity path at time :math:`t`. + """Compute the standard deviation of the probability path at time :math:`t`. Args: t: Time :math:`t` of shape `(batch_size, 1)`. + + Returns: + Standard deviation of the probability path at time :math:`t`. """ @abc.abstractmethod @@ -67,7 +70,7 @@ def compute_ut( :math:`x_1` at time :math:`t`. Args: - t: Time :math:`t` of shape `(batch_size, 1)`.. + t: Time :math:`t` of shape `(batch_size, 1)`. src: Sample from the source distribution of shape `(batch_size, ...)`. tgt: Sample from the target distribution of shape `(batch_size, ...)`. @@ -107,22 +110,9 @@ def compute_mu_t( # noqa: D102 ) -> jnp.ndarray: return (1.0 - t) * src + t * tgt - def compute_ut( + def compute_ut( # noqa: D102 self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray ) -> jnp.ndarray: - """Evaluate the conditional vector field. - - Evaluate the conditional vector field defined between :math:`x_0` and - :math:`x_1` at time :math:`t`. - - Args: - t: Time :math:`t` of shape `(batch_size, 1)`. - src: Sample from the source distribution of shape `(batch_size, ...)`. - tgt: Sample from the target distribution of shape `(batch_size, ...)`.. - - Returns: - Conditional vector field evaluated at time :math:`t`. 
- """ del t return tgt - src @@ -134,7 +124,7 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: r"""Compute noise of the flow at time :math:`t`. Args: - t: Time :math:`t` of shape `(batch_size, 1)`.. + t: Time :math:`t` of shape `(batch_size, 1)`. Returns: Constant, time-independent standard deviation :math:`\sigma`. @@ -154,13 +144,5 @@ class BrownianNoiseFlow(StraightFlow): at time :math:`t`. """ - def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: - """Compute the standard deviation of the probablity path at time :math:`t`. - - Args: - t: Time :math:`t` of shape `(batch_size, 1)`.. - - Returns: - Standard deviation of the probablity path at time :math:`t`. - """ + def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: # noqa: D102 return self.sigma * jnp.sqrt(t * (1.0 - t)) From ae37132e494766fcd58d1372e4e55b533ebf8355 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 19 Feb 2024 14:47:20 +0100 Subject: [PATCH 103/186] Update `OTFM` --- src/ott/neural/flow_models/otfm.py | 232 +++++++++---------------- src/ott/neural/flow_models/samplers.py | 3 +- 2 files changed, 87 insertions(+), 148 deletions(-) diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index 6c519f4da..3e6d821a1 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -11,106 +11,66 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import collections import functools -from typing import ( - Any, - Callable, - Dict, - Literal, - Mapping, - Optional, - Tuple, - Type, - Union, -) +from typing import Any, Callable, Dict, Optional, Tuple import jax import jax.numpy as jnp +import jax.tree_util as jtu import diffrax import optax from flax.training import train_state from ott import utils -from ott.geometry import costs -from ott.neural.flow_models import flows +from ott.neural.flow_models import flows, models from ott.neural.models import base_solver __all__ = ["OTFlowMatching"] class OTFlowMatching: - """(Optimal transport) flow matching class. + """(Optimal transport) flow matching :cite:`lipman:22`. - Flow matching as introduced in :cite:`lipman:22`, with extension to OT-FM - (:cite`tong:23`, :cite:`pooladian:23`). + Includes an extension to OT-FM :cite`tong:23`, :cite:`pooladian:23`. Args: - velocity_field: Neural vector field parameterized by a neural network. input_dim: Dimension of the input data. - cond_dim: Dimension of the conditioning variable. - iterations: Number of iterations. - valid_freq: Frequency of validation. + velocity_field: Neural vector field parameterized by a neural network. flow: Flow between source and target distribution. time_sampler: Sampler for the time. - optimizer: Optimizer for `velocity_field`. - callback_fn: Callback function. - num_eval_samples: Number of samples to evaluate on during evaluation. + optimizer: Optimizer for the ``velocity_field``. + ot_matcher: TODO. + unbalancedness_handler: TODO. rng: Random number generator. 
""" + # TODO(michalk8): in the future, `input_dim`, `optimizer` and `rng` will be + # in a separate function def __init__( self, - velocity_field: Callable[[ - jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], Optional[jnp.ndarray] - ], jnp.ndarray], input_dim: int, - cond_dim: int, - iterations: int, - flow: Type[flows.BaseFlow], + velocity_field: models.VelocityField, + flow: flows.BaseFlow, time_sampler: Callable[[jax.Array, int], jnp.ndarray], optimizer: optax.GradientTransformation, - ot_matcher: Optional[base_solver.OTMatcherLinear], - unbalancedness_handler: base_solver.UnbalancednessHandler, - epsilon: float = 1e-2, - cost_fn: Optional[Type[costs.CostFn]] = None, - scale_cost: Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", "max_cost", - "median"]] = "mean", - callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], - Any]] = None, - logging_freq: int = 100, - valid_freq: int = 5000, - num_eval_samples: int = 1000, + ot_matcher: Optional[base_solver.OTMatcherLinear] = None, + unbalancedness_handler: Optional[base_solver.UnbalancednessHandler + ] = None, rng: Optional[jax.Array] = None, ): - rng = utils.default_prng_key(rng) - self.unbalancedness_handler = unbalancedness_handler - self.iterations = iterations - self.valid_freq = valid_freq - self.velocity_field = velocity_field self.input_dim = input_dim - self.ot_matcher = ot_matcher + self.velocity_field = velocity_field self.flow = flow self.time_sampler = time_sampler + self.unbalancedness_handler = unbalancedness_handler + self.ot_matcher = ot_matcher self.optimizer = optimizer - self.epsilon = epsilon - self.cost_fn = cost_fn - self.scale_cost = scale_cost - self.callback_fn = callback_fn - self.rng = rng - self.logging_freq = logging_freq - self.num_eval_samples = num_eval_samples - self._training_logs: Mapping[str, Any] = collections.defaultdict(list) - - self.setup() - - def setup(self) -> None: - """Setup :class:`OTFlowMatching`.""" + + rng = 
utils.default_prng_key(rng) self.state_velocity_field = ( self.velocity_field.create_train_state( - self.rng, self.optimizer, self.input_dim + rng, self.optimizer, self.input_dim ) ) @@ -153,41 +113,46 @@ def loss_fn( return step_fn - def __call__( - self, train_loader_source, train_loader_target, valid_loader_source, - valid_loader_target - ): - """Train :class:`OTFlowMatching`. + # TODO(michalk8): refactor in the future PR to just do one step + def __call__( # noqa: D102 + self, + n_iters: int, + train_source, + train_target, + valid_source, + valid_target, + valid_freq: int = 5000, + rng: Optional[jax.Array] = None, + ) -> Dict[str, Any]: + rng = utils.default_prng_key(rng) + training_logs = {"loss": []} + + for it in range(n_iters): + for batch_source, batch_target in zip(train_source, train_target): + rng, rng_resample, rng_step_fn = jax.random.split(rng, 3) + + batch_source = jtu.tree_map(jnp.asarray, batch_source) + batch_target = jtu.tree_map(jnp.asarray, batch_target) + + source = batch_source["lin"] + source_conditions = batch_source.get("conditions", None) + target = batch_target["lin"] - Args; - train_loader: Dataloader for the training data. - valid_loader: Dataloader for the validation data. 
- """ - iter = -1 - while True: - for batch_source, batch_target in zip( - train_loader_source, train_loader_target - ): - iter += 1 - if iter >= self.iterations: - stop = True - break - rng_resample, rng_step_fn, self.rng = jax.random.split(self.rng, 3) - source, source_conditions = jnp.array(batch_source["lin"]), jnp.array( - batch_source["conditions"] - ) if "conditions" in batch_source else None - target = jnp.array(batch_target["lin"]) if self.ot_matcher is not None: tmat = self.ot_matcher.match_fn(source, target) (source, source_conditions), (target,) = self.ot_matcher.sample_joint( rng_resample, tmat, (source, source_conditions), (target,) ) + else: + tmat = None + self.state_velocity_field, loss = self.step_fn( rng_step_fn, self.state_velocity_field, source, target, source_conditions ) - self._training_logs["loss"].append(loss) - if self.learn_rescaling: + training_logs["loss"].append(loss) + + if self.unbalancedness_handler is not None and tmat is not None: ( self.unbalancedness_handler.state_eta, self.unbalancedness_handler.state_xi, eta_predictions, @@ -201,23 +166,24 @@ def __call__( state_eta=self.unbalancedness_handler.state_eta, state_xi=self.unbalancedness_handler.state_xi, ) - if iter % self.valid_freq == 0: - self._valid_step(valid_loader_source, valid_loader_target, iter) - if stop: - break + + if it % valid_freq == 0: + self._valid_step(valid_source, valid_target, it) + + return training_logs def transport( self, data: jnp.array, condition: Optional[jnp.ndarray] = None, forward: bool = True, - t_0: float = 0.0, - t_1: float = 1.0, + t0: float = 0.0, + t1: float = 1.0, **kwargs: Any, - ) -> diffrax.Solution: + ) -> jnp.ndarray: """Transport data with the learnt map. - This method pushes-forward the `source` by + This method pushes-forward the ``data`` by solving the neural ODE parameterized by the :attr:`~ott.neural.flows.OTFlowMatching.velocity_field`. @@ -225,72 +191,44 @@ def transport( data: Initial condition of the ODE. 
condition: Condition of the input data. forward: If `True` integrates forward, otherwise backwards. - t_0: Starting point of integration. - t_1: End point of integration. + t0: Starting point of integration. + t1: End point of integration. kwargs: Keyword arguments for the ODE solver. Returns: The push-forward or pull-back distribution defined by the learnt transport plan. - """ - t0, t1 = (t_0, t_1) if forward else (t_1, t_0) - @jax.jit def solve_ode(input: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: - return diffrax.diffeqsolve( - diffrax.ODETerm( - lambda t, x, args: self.state_velocity_field. - apply_fn({"params": self.state_velocity_field.params}, - t=t, - x=x, - condition=cond) - ), - kwargs.pop("solver", diffrax.Tsit5()), + ode_term = diffrax.ODETerm( + lambda t, x, args: self.state_velocity_field. + apply_fn({"params": self.state_velocity_field.params}, + t=t, + x=x, + condition=cond) + ) + + result = diffrax.diffeqsolve( + ode_term, + solver, t0=t0, t1=t1, dt0=kwargs.pop("dt0", None), y0=input, - stepsize_controller=kwargs.pop( - "stepsize_controller", - diffrax.PIDController(rtol=1e-5, atol=1e-5) - ), + stepsize_controller=stepsize_controller, **kwargs, - ).ys[0] - - return jax.vmap(solve_ode)(data, condition) - - def _valid_step(self, valid_loader_source, valid_loader_target, iter): - pass + ) + return result.ys[0] - @property - def learn_rescaling(self) -> bool: - """Whether to learn at least one rescaling factor.""" - return ( - self.unbalancedness_handler.rescaling_a is not None or - self.unbalancedness_handler.rescaling_b is not None + if not forward: + t0, t1 = t1, t0 + solver = kwargs.pop("solver", diffrax.Tsit5()), + stepsize_controller = kwargs.pop( + "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5) ) - def save(self, path: str): - """Save the model. + return jax.jit(jax.vmap(solve_ode))(data, condition) - Args: - path: Where to save the model to. 
- """ - raise NotImplementedError - - def load(self, path: str) -> "OTFlowMatching": - """Load a model. - - Args: - path: Where to load the model from. - - Returns: - An instance of :class:`ott.neural.solvers.OTFlowMatching`. - """ - raise NotImplementedError - - @property - def training_logs(self) -> Dict[str, Any]: - """Logs of the training.""" - raise NotImplementedError + def _valid_step(self, it: int, valid_source, valid_target) -> None: + pass diff --git a/src/ott/neural/flow_models/samplers.py b/src/ott/neural/flow_models/samplers.py index 34a28c2d2..9bd85d8b0 100644 --- a/src/ott/neural/flow_models/samplers.py +++ b/src/ott/neural/flow_models/samplers.py @@ -42,10 +42,11 @@ def uniform_sampler( used. Returns: - An array with `num_samples` samples of the time `math`:t:. + An array with `num_samples` samples of the time :math:`t`. """ if offset is None: return jax.random.uniform(rng, (num_samples, 1), minval=low, maxval=high) + t = jax.random.uniform(rng, (1, 1), minval=low, maxval=high) mod_term = ((high - low) - offset) return (t + jnp.arange(num_samples)[:, None] / num_samples) % mod_term From de323d2ecda7b08f1d612d983b93c54417cd7015 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 19 Feb 2024 16:25:48 +0100 Subject: [PATCH 104/186] Fix small bugs in `OTFM` --- docs/neural/duality.rst | 2 +- src/ott/neural/data/datasets.py | 2 +- src/ott/neural/flow_models/otfm.py | 65 +++++++++++++----------------- 3 files changed, 31 insertions(+), 38 deletions(-) diff --git a/docs/neural/duality.rst b/docs/neural/duality.rst index ea3f67bdf..25dc89daa 100644 --- a/docs/neural/duality.rst +++ b/docs/neural/duality.rst @@ -5,7 +5,7 @@ ott.neural.duality This module implements various solvers to estimate optimal transport between two probability measures, through samples, parameterized as neural networks. -These solvers build uponn dual formulation of the optimal transport problem. 
+These solvers build upon dual formulation of the optimal transport problem. Solvers ------- diff --git a/src/ott/neural/data/datasets.py b/src/ott/neural/data/datasets.py index 990c27a2a..5a12ed2c0 100644 --- a/src/ott/neural/data/datasets.py +++ b/src/ott/neural/data/datasets.py @@ -49,7 +49,7 @@ def _check_sizes(self) -> None: raise ValueError(f"Not all arrays have the same size: {sizes}.") def __getitem__(self, idx: np.ndarray) -> Dict[str, np.ndarray]: - return jtu.tree_map(lambda x: x[idx], self.data)["lin"] + return jtu.tree_map(lambda x: x[idx], self.data) def __len__(self) -> int: for v in self.data.values(): diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index 3e6d821a1..53788fc37 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -59,21 +59,19 @@ def __init__( ] = None, rng: Optional[jax.Array] = None, ): + rng = utils.default_prng_key(rng) + self.input_dim = input_dim - self.velocity_field = velocity_field + self.vf = velocity_field self.flow = flow self.time_sampler = time_sampler self.unbalancedness_handler = unbalancedness_handler self.ot_matcher = ot_matcher self.optimizer = optimizer - rng = utils.default_prng_key(rng) - self.state_velocity_field = ( - self.velocity_field.create_train_state( - rng, self.optimizer, self.input_dim - ) + self.vf_state = self.vf.create_train_state( + rng, self.optimizer, self.input_dim ) - self.step_fn = self._get_step_fn() def _get_step_fn(self) -> Callable: @@ -146,11 +144,10 @@ def __call__( # noqa: D102 else: tmat = None - self.state_velocity_field, loss = self.step_fn( - rng_step_fn, self.state_velocity_field, source, target, - source_conditions + self.vf_state, loss = self.step_fn( + rng_step_fn, self.vf_state, source, target, source_conditions ) - training_logs["loss"].append(loss) + training_logs["loss"].append(float(loss)) if self.unbalancedness_handler is not None and tmat is not None: ( @@ -174,23 +171,20 @@ def __call__( # noqa: 
D102 def transport( self, - data: jnp.array, + x: jnp.ndarray, condition: Optional[jnp.ndarray] = None, - forward: bool = True, t0: float = 0.0, t1: float = 1.0, **kwargs: Any, ) -> jnp.ndarray: """Transport data with the learnt map. - This method pushes-forward the ``data`` by - solving the neural ODE parameterized by the - :attr:`~ott.neural.flows.OTFlowMatching.velocity_field`. + This method pushes-forward the data by solving the neural ODE + parameterized by the velocity field. Args: - data: Initial condition of the ODE. - condition: Condition of the input data. - forward: If `True` integrates forward, otherwise backwards. + x: Initial condition of the ODE of shape `(batch_size, ...)`. + condition: Condition of the input data of shape `(batch_size, ...)`. t0: Starting point of integration. t1: End point of integration. kwargs: Keyword arguments for the ODE solver. @@ -200,35 +194,34 @@ def transport( transport plan. """ - def solve_ode(input: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: - ode_term = diffrax.ODETerm( - lambda t, x, args: self.state_velocity_field. 
- apply_fn({"params": self.state_velocity_field.params}, - t=t, - x=x, - condition=cond) - ) + def vf( + t: jnp.ndarray, x: jnp.ndarray, cond: Optional[jnp.ndarray] + ) -> jnp.ndarray: + return self.vf_state.apply_fn({"params": self.vf_state.params}, + t=t, + x=x, + condition=cond) + def solve_ode(x: jnp.ndarray, cond: Optional[jnp.ndarray]) -> jnp.ndarray: + ode_term = diffrax.ODETerm(vf) result = diffrax.diffeqsolve( ode_term, - solver, t0=t0, t1=t1, - dt0=kwargs.pop("dt0", None), - y0=input, - stepsize_controller=stepsize_controller, + y0=x, + args=cond, **kwargs, ) return result.ys[0] - if not forward: - t0, t1 = t1, t0 - solver = kwargs.pop("solver", diffrax.Tsit5()), - stepsize_controller = kwargs.pop( + kwargs.setdefault("dt0", None) + kwargs.setdefault("solver", diffrax.Tsit5()) + kwargs.setdefault( "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5) ) - return jax.jit(jax.vmap(solve_ode))(data, condition) + in_axes = [0, None if condition is None else 0] + return jax.jit(jax.vmap(solve_ode, in_axes))(x, condition) def _valid_step(self, it: int, valid_source, valid_target) -> None: pass From 4408cc236edd60adacfabdf8fe5e106fea3b9840 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 19 Feb 2024 16:36:21 +0100 Subject: [PATCH 105/186] Polish layers --- src/ott/neural/flow_models/layers.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/ott/neural/flow_models/layers.py b/src/ott/neural/flow_models/layers.py index d18980c38..6f04b4e54 100644 --- a/src/ott/neural/flow_models/layers.py +++ b/src/ott/neural/flow_models/layers.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import jax.numpy as jnp import flax.linen as nn @@ -22,9 +21,8 @@ class CyclicalTimeEncoder(nn.Module): r"""A cyclical time encoder. 
- Encodes time :math:`t` as - :math:`cos(\tilde{t})` and :math:`sin(\tilde{t})` - where :math:`\tilde{t} = [2\\pi t, 2\\pi 2 t,\\ldots, 2\\pi n_frequencies t]` + Encodes time :math:`t` as :math:`cos(\tilde{t})` and :math:`sin(\tilde{t})` + where :math:`\tilde{t} = [2\pi t, 2\pi 2 t,\ldots, 2\pi n_frequencies t]`. Args: n_frequencies: Frequency of cyclical encoding. @@ -39,8 +37,8 @@ def __call__(self, t: jnp.ndarray) -> jnp.ndarray: # noqa: D102 t: Time of shape ``[n, 1]``. Returns: - Encoded time of shape ``[n, 2 * n_frequencies]`` + Encoded time of shape ``[n, 2 * n_frequencies]``. """ freq = 2 * jnp.arange(self.n_frequencies) * jnp.pi t = freq * t - return jnp.concatenate((jnp.cos(t), jnp.sin(t)), axis=-1) + return jnp.concatenate([jnp.cos(t), jnp.sin(t)], axis=-1) From 451f069df5507d67c15dc013f1b5752e65b1f8db Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 19 Feb 2024 17:40:18 +0100 Subject: [PATCH 106/186] Fix typo in citation --- src/ott/neural/flow_models/flows.py | 2 +- src/ott/neural/flow_models/layers.py | 10 +++---- src/ott/neural/flow_models/models.py | 39 +++++++++++++--------------- src/ott/neural/flow_models/otfm.py | 4 +-- 4 files changed, 26 insertions(+), 29 deletions(-) diff --git a/src/ott/neural/flow_models/flows.py b/src/ott/neural/flow_models/flows.py index b2e4970bf..d434e91f1 100644 --- a/src/ott/neural/flow_models/flows.py +++ b/src/ott/neural/flow_models/flows.py @@ -88,7 +88,7 @@ def compute_xt( Args: rng: Random number generator. - t: Time :math:`t` of shape `(batch_size, 1)`.. + t: Time :math:`t` of shape `(batch_size, 1)`. src: Sample from the source distribution of shape `(batch_size, ...)`. tgt: Sample from the target distribution of shape `(batch_size, ...)`. 
diff --git a/src/ott/neural/flow_models/layers.py b/src/ott/neural/flow_models/layers.py index 6f04b4e54..9ec703c4c 100644 --- a/src/ott/neural/flow_models/layers.py +++ b/src/ott/neural/flow_models/layers.py @@ -22,12 +22,12 @@ class CyclicalTimeEncoder(nn.Module): r"""A cyclical time encoder. Encodes time :math:`t` as :math:`cos(\tilde{t})` and :math:`sin(\tilde{t})` - where :math:`\tilde{t} = [2\pi t, 2\pi 2 t,\ldots, 2\pi n_frequencies t]`. + where :math:`\hat{t} = [2\pi t, 2\pi 2 t,\dots, 2\pi n_f t]`. Args: - n_frequencies: Frequency of cyclical encoding. + n_freqs: Frequency :math:`n_f` of the cyclical encoding. """ - n_frequencies: int = 128 + n_freqs: int = 128 @nn.compact def __call__(self, t: jnp.ndarray) -> jnp.ndarray: # noqa: D102 @@ -37,8 +37,8 @@ def __call__(self, t: jnp.ndarray) -> jnp.ndarray: # noqa: D102 t: Time of shape ``[n, 1]``. Returns: - Encoded time of shape ``[n, 2 * n_frequencies]``. + Encoded time of shape ``[n, 2 * n_freqs]``. """ - freq = 2 * jnp.arange(self.n_frequencies) * jnp.pi + freq = 2 * jnp.arange(self.n_freqs) * jnp.pi t = freq * t return jnp.concatenate([jnp.cos(t), jnp.sin(t)], axis=-1) diff --git a/src/ott/neural/flow_models/models.py b/src/ott/neural/flow_models/models.py index ebb29aa99..c71fff2c2 100644 --- a/src/ott/neural/flow_models/models.py +++ b/src/ott/neural/flow_models/models.py @@ -29,20 +29,18 @@ class VelocityField(nn.Module): r"""Parameterized neural vector field. - The `VelocityField` learns a map - :math:`v: \\mathbb{R}\times \\mathbb{R}^d\rightarrow \\mathbb{R}^d` solving - the ODE :math:`\frac{dx}{dt} = v(t, x)`. Given a source distribution at time - :math:`t=0`, the `VelocityField` can be used to transport the source - distribution given at :math:`t_0` to a target distribution given at - :math:`t_1` by integrating :math:`v(t, x)` from :math:`t=t_0` to - :math:`t=t_1`. 
+ The `VelocityField` learns a map :math:`v: \mathbb{R}\times \mathbb{R}^d + \rightarrow \mathbb{R}^d` solving the ODE :math:`\frac{dx}{dt} = v(t, x)`. + Given a source distribution at time :math:`t_0`, the velocity field can be + used to transport the source distribution given at :math:`t_0` to + a target distribution given at :math:`t_1` by integrating :math:`v(t, x)` + from :math:`t=t_0` to :math:`t=t_1`. Each of the input, condition, and time embeddings are passed through a block consisting of ``num_layers_per_block`` layers of dimension ``latent_embed_dim``, ``condition_embed_dim``, and ``time_embed_dim``, - respectively. - The output of each block is concatenated and passed through a final block of - dimension ``joint_hidden_dim``. + respectively. The output of each block is concatenated and passed through + a final block of dimension ``joint_hidden_dim``. Args: output_dim: Dimensionality of the neural vector field. @@ -57,8 +55,7 @@ class VelocityField(nn.Module): t_embed_dim``. num_layers_per_block: Number of layers per block. act_fn: Activation function. - n_frequencies: Number of frequencies to use for the time embedding. - + n_freqs: Number of frequencies to use for the time embedding. """ output_dim: int latent_embed_dim: int @@ -68,7 +65,7 @@ class VelocityField(nn.Module): joint_hidden_dim: Optional[int] = None num_layers_per_block: int = 3 act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu - n_frequencies: int = 128 + n_freqs: int = 128 def __post_init__(self) -> None: if self.condition_embed_dim is None: @@ -81,8 +78,8 @@ def __post_init__(self) -> None: ) if self.joint_hidden_dim is not None: assert (self.joint_hidden_dim >= concat_embed_dim), ( - "joint_hidden_dim must be greater than or equal to the sum of " - "all embedded dimensions. " + "joint_hidden_dim must be greater than or equal to the sum of" + " all embedded dimensions." 
) self.joint_hidden_dim = self.latent_embed_dim else: @@ -99,14 +96,14 @@ def __call__( """Forward pass through the neural vector field. Args: - t: Time of shape (batch_size, 1). - x: Data of shape (batch_size, output_dim). + t: Time of shape `(batch_size, 1)`. + x: Data of shape `(batch_size, output_dim)`. condition: Conditioning vector. Returns: Output of the neural vector field. """ - t = flow_layers.CyclicalTimeEncoder(n_frequencies=self.n_frequencies)(t) + t = flow_layers.CyclicalTimeEncoder(self.n_freqs)(t) t_layer = layers.MLPBlock( dim=self.t_embed_dim, out_dim=self.t_embed_dim, @@ -131,9 +128,9 @@ def __call__( act_fn=self.act_fn ) condition = condition_layer(condition) - concatenated = jnp.concatenate((t, x, condition), axis=-1) + concatenated = jnp.concatenate([t, x, condition], axis=-1) else: - concatenated = jnp.concatenate((t, x), axis=-1) + concatenated = jnp.concatenate([t, x], axis=-1) out_layer = layers.MLPBlock( dim=self.joint_hidden_dim, @@ -158,7 +155,7 @@ def create_train_state( input_dim: Dimensionality of the input. Returns: - Training state. + The training state. """ params = self.init( rng, jnp.ones((1, 1)), jnp.ones((1, input_dim)), diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index 53788fc37..3cd4d2cee 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -32,14 +32,14 @@ class OTFlowMatching: """(Optimal transport) flow matching :cite:`lipman:22`. - Includes an extension to OT-FM :cite`tong:23`, :cite:`pooladian:23`. + With an extension to OT-FM :cite:`tong:23`, :cite:`pooladian:23`. Args: input_dim: Dimension of the input data. velocity_field: Neural vector field parameterized by a neural network. flow: Flow between source and target distribution. time_sampler: Sampler for the time. - optimizer: Optimizer for the ``velocity_field``. + optimizer: Optimizer for the velocity field's parameters. ot_matcher: TODO. unbalancedness_handler: TODO. 
rng: Random number generator. From 5e10d3a5598425a37d6cbe664e8d46f11e7be0a4 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 19 Feb 2024 18:04:12 +0100 Subject: [PATCH 107/186] More polish for the docs --- docs/neural/flow_models.rst | 4 ++-- src/ott/neural/flow_models/flows.py | 34 ++++++++++++++++++++-------- src/ott/neural/flow_models/layers.py | 2 +- 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/docs/neural/flow_models.rst b/docs/neural/flow_models.rst index 5d9d1f594..5f9799292 100644 --- a/docs/neural/flow_models.rst +++ b/docs/neural/flow_models.rst @@ -16,8 +16,8 @@ Flows flows.ConstantNoiseFlow flows.BrownianNoiseFlow -Optimal Transport Flow Matching -------------------------------- +OT Flow Matching +---------------- .. autosummary:: :toctree: _autosummary diff --git a/src/ott/neural/flow_models/flows.py b/src/ott/neural/flow_models/flows.py index d434e91f1..150e2086e 100644 --- a/src/ott/neural/flow_models/flows.py +++ b/src/ott/neural/flow_models/flows.py @@ -28,7 +28,7 @@ class BaseFlow(abc.ABC): """Base class for all flows. Args: - sigma: Constant noise used for computing time-dependent noise schedule. + sigma: Noise used for computing time-dependent noise schedule. """ def __init__(self, sigma: float): @@ -103,7 +103,11 @@ def compute_xt( class StraightFlow(BaseFlow, abc.ABC): - """Base class for flows with straight paths.""" + """Base class for flows with straight paths. + + Args: + sigma: Noise used for computing time-dependent noise schedule. + """ def compute_mu_t( # noqa: D102 self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray @@ -118,7 +122,11 @@ def compute_ut( # noqa: D102 class ConstantNoiseFlow(StraightFlow): - r"""Flow with straight paths and constant flow noise :math:`\sigma`.""" + r"""Flow with straight paths and constant flow noise :math:`\sigma`. + + Args: + sigma: Constant noise used for computing time-independent noise schedule. 
+ """ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: r"""Compute noise of the flow at time :math:`t`. @@ -135,14 +143,22 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: class BrownianNoiseFlow(StraightFlow): r"""Brownian Bridge Flow. - Sampler for sampling noise implicitly defined by a Schroedinger Bridge + Sampler for sampling noise implicitly defined by a Schrödinger Bridge problem with parameter :math:`\sigma` such that - :math:`\sigma_t = \sigma * \sqrt(t * (1-t))` (:cite:`tong:23`). + :math:`\sigma_t = \sigma \cdot \sqrt{t \cdot (1 - t)}` :cite:`tong:23`. - Returns: - Samples from the probability path between :math:`x_0` and :math:`x_1` - at time :math:`t`. + Args: + sigma: Noise used for computing time-dependent noise schedule. """ - def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: # noqa: D102 + def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: + r"""Compute noise of the flow at time :math:`t`. + + Args: + t: Time :math:`t` of shape `(batch_size, 1)`. + + Returns: + Samples from the probability path between :math:`x_0` and :math:`x_1` + at time :math:`t`. + """ return self.sigma * jnp.sqrt(t * (1.0 - t)) diff --git a/src/ott/neural/flow_models/layers.py b/src/ott/neural/flow_models/layers.py index 9ec703c4c..2f87f6cfc 100644 --- a/src/ott/neural/flow_models/layers.py +++ b/src/ott/neural/flow_models/layers.py @@ -21,7 +21,7 @@ class CyclicalTimeEncoder(nn.Module): r"""A cyclical time encoder. - Encodes time :math:`t` as :math:`cos(\tilde{t})` and :math:`sin(\tilde{t})` + Encodes time :math:`t` as :math:`cos(\hat{t})` and :math:`sin(\hat{t})` where :math:`\hat{t} = [2\pi t, 2\pi 2 t,\dots, 2\pi n_f t]`. 
Args: From 5edc66d4f595aca1047360990e611597b16b9174 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 6 Mar 2024 09:47:12 +0100 Subject: [PATCH 108/186] remove print statements and unbalancednesshandler --- docs/neural/models.rst | 1 - src/ott/neural/flow_models/genot.py | 34 +--- src/ott/neural/flow_models/otfm.py | 19 -- src/ott/neural/models/base_solver.py | 280 +-------------------------- tests/neural/genot_test.py | 59 +----- tests/neural/otfm_test.py | 46 +---- 6 files changed, 13 insertions(+), 426 deletions(-) diff --git a/docs/neural/models.rst b/docs/neural/models.rst index bacc93c71..af6d4e33a 100644 --- a/docs/neural/models.rst +++ b/docs/neural/models.rst @@ -14,7 +14,6 @@ Utils base_solver.BaseOTMatcher base_solver.OTMatcherLinear base_solver.OTMatcherQuad - base_solver.UnbalancednessHandler Neural networks diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index d466f015d..ba2c0f6a0 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -66,7 +66,6 @@ class GENOTBase: optimizer: Optimizer for `velocity_field`. flow: Flow between latent distribution and target distribution. time_sampler: Sampler for the time. - unbalancedness_handler: Handler for unbalancedness. k_samples_per_x: Number of samples drawn from the conditional distribution of an input sample, see algorithm TODO. 
solver_latent_to_data: Linear OT solver to match the latent distribution @@ -91,7 +90,6 @@ def __init__( iterations: int, valid_freq: int, ot_matcher: base_solver.BaseOTMatcher, - unbalancedness_handler: base_solver.UnbalancednessHandler, optimizer: optax.GradientTransformation, flow: Type[flows.BaseFlow] = flows.ConstantNoiseFlow(0.0), # noqa: B008 time_sampler: Callable[[jax.Array, int], @@ -125,9 +123,6 @@ def __init__( self.cond_dim = cond_dim self.k_samples_per_x = k_samples_per_x - # unbalancedness - self.unbalancedness_handler = unbalancedness_handler - # OT data-data matching parameters self.fused_penalty = fused_penalty @@ -262,10 +257,7 @@ def _valid_step(self, valid_loader, iter): @property def learn_rescaling(self) -> bool: """Whether to learn at least one rescaling factor.""" - return ( - self.unbalancedness_handler.rescaling_a is not None or - self.unbalancedness_handler.rescaling_b is not None - ) + return False def _reshape_samples(self, arrays: Tuple[jnp.ndarray, ...], batch_size: int) -> Tuple[jnp.ndarray, ...]: @@ -278,20 +270,7 @@ def _learn_rescaling( self, source: jnp.ndarray, target: jnp.ndarray, source_conditions: Optional[jnp.ndarray], tmat: jnp.ndarray ) -> Tuple[jnp.ndarray, jnp.ndarray, float, float]: - - ( - self.state_eta, self.state_xi, eta_predictions, xi_predictions, loss_a, - loss_b - ) = self.unbalancedness_handler.step_fn( - source=source, - target=target, - condition=source_conditions, - a=tmat.sum(axis=1), - b=tmat.sum(axis=0), - state_eta=self.unbalancedness_handler.state_eta, - state_xi=self.unbalancedness_handler.state_xi, - ) - return eta_predictions, xi_predictions, float(loss_a), float(loss_b) + raise NotImplementedError class GENOTLin(GENOTBase): @@ -339,11 +318,6 @@ def __call__(self, train_loader, valid_loader): target, ) - jax.debug.print("source.shape {x}", x=source.shape) - jax.debug.print( - "source_conditions.shape {x}", x=source_conditions.shape - ) - jax.debug.print("target.shape {x}", x=target.shape) (source, 
source_conditions ), (target,) = self.ot_matcher.sample_conditional_indices_from_tmap( rng=rng_resample, @@ -351,7 +325,7 @@ def __call__(self, train_loader, valid_loader): k_samples_per_x=self.k_samples_per_x, source_arrays=(source, source_conditions), target_arrays=(target,), - source_is_balanced=(self.unbalancedness_handler.tau_a == 1.0) + source_is_balanced=(self.ot_matcher.tau_a == 1.0) ) if self.matcher_latent_to_data is not None: @@ -454,7 +428,7 @@ def __call__(self, train_loader, valid_loader): k_samples_per_x=self.k_samples_per_x, source_arrays=(source, source_conditions), target_arrays=(target,), - source_is_balanced=(self.unbalancedness_handler.tau_a == 1.0) + source_is_balanced=(self.ot_matcher.tau_a == 1.0) ) ) diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index 3cd4d2cee..e426a6f5c 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -41,7 +41,6 @@ class OTFlowMatching: time_sampler: Sampler for the time. optimizer: Optimizer for the velocity field's parameters. ot_matcher: TODO. - unbalancedness_handler: TODO. rng: Random number generator. 
""" @@ -55,8 +54,6 @@ def __init__( time_sampler: Callable[[jax.Array, int], jnp.ndarray], optimizer: optax.GradientTransformation, ot_matcher: Optional[base_solver.OTMatcherLinear] = None, - unbalancedness_handler: Optional[base_solver.UnbalancednessHandler - ] = None, rng: Optional[jax.Array] = None, ): rng = utils.default_prng_key(rng) @@ -65,7 +62,6 @@ def __init__( self.vf = velocity_field self.flow = flow self.time_sampler = time_sampler - self.unbalancedness_handler = unbalancedness_handler self.ot_matcher = ot_matcher self.optimizer = optimizer @@ -149,21 +145,6 @@ def __call__( # noqa: D102 ) training_logs["loss"].append(float(loss)) - if self.unbalancedness_handler is not None and tmat is not None: - ( - self.unbalancedness_handler.state_eta, - self.unbalancedness_handler.state_xi, eta_predictions, - xi_predictions, loss_a, loss_b - ) = self.unbalancedness_handler.step_fn( - source=source, - target=target, - condition=source_conditions, - a=tmat.sum(axis=1), - b=tmat.sum(axis=0), - state_eta=self.unbalancedness_handler.state_eta, - state_xi=self.unbalancedness_handler.state_xi, - ) - if it % valid_freq == 0: self._valid_step(valid_source, valid_target, it) diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py index e9be6baa8..5ddfd5ef5 100644 --- a/src/ott/neural/models/base_solver.py +++ b/src/ott/neural/models/base_solver.py @@ -17,13 +17,9 @@ import jax.numpy as jnp from jax import tree_util -import optax -from flax.training import train_state - from ott.geometry import costs, pointcloud from ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_problem -from ott.solvers import was_solver from ott.solvers.linear import sinkhorn from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr @@ -31,7 +27,9 @@ ScaleCostQuad_t = Union[ScaleCost_t, Dict[str, ScaleCost_t]] __all__ = [ - "BaseOTMatcher", "OTMatcherLinear", "OTMatcherQuad", "UnbalancednessHandler" + 
"BaseOTMatcher", + "OTMatcherLinear", + "OTMatcherQuad", ] @@ -308,275 +306,3 @@ def match_pairs(*args, **kwargs): return fn(*args, **kwargs).matrix return match_pairs - - -class UnbalancednessHandler: - """Class to incorporate unbalancedness into neural OT models. - - This class implements the concepts introduced in :cite:`eyring:23` - in the Monge Map scenario and :cite:`klein:23` for the entropic OT case - for linear and quadratic cases. - - Args: - rng: Random number generator. - source_dim: Dimension of the source domain. - target_dim: Dimension of the target domain. - cond_dim: Dimension of the conditioning variable. - If :obj:`None`, no conditioning is used. - tau_a: Unbalancedness parameter for the source distribution. - Only used if `ot_solver` is not :obj:`None`. - tau_b: Unbalancedness parameter for the target distribution. - Only used if `ot_solver` is not :obj:`None`. - rescaling_a: Rescaling function for the source distribution. - If :obj:`None`, the left rescaling factor is not learnt. - rescaling_b: Rescaling function for the target distribution. - If :obj:`None`, the right rescaling factor is not learnt. - opt_eta: Optimizer for the left rescaling function. - opt_xi: Optimzier for the right rescaling function. - resample_epsilon: Epsilon for resampling. - scale_cost: Scaling of the cost matrix for estimating the rescaling factors. - ot_solver: Solver to compute unbalanced marginals. If `ot_solver` is `None`, - the method :meth:`ott.neural.models.base_solver.UnbalancednessHandler.compute_unbalanced_marginals` - is not available, and hence the unbalanced marginals must be computed - by the neural solver. - kwargs: Additional keyword arguments. 
- - """ # noqa: E501 - - def __init__( - self, - rng: jax.Array, - source_dim: int, - target_dim: int, - cond_dim: Optional[int], - tau_a: float = 1.0, - tau_b: float = 1.0, - rescaling_a: Optional[Callable[[jnp.ndarray, Optional[jnp.ndarray]], - jnp.ndarray]] = None, - rescaling_b: Optional[Callable[[jnp.ndarray, Optional[jnp.ndarray]], - jnp.ndarray]] = None, - opt_eta: Optional[optax.GradientTransformation] = None, - opt_xi: Optional[optax.GradientTransformation] = None, - resample_epsilon: float = 1e-2, - scale_cost: Union[ScaleCost_t, ScaleCostQuad_t] = 1.0, - ot_solver: Optional[was_solver.WassersteinSolver] = None, - **kwargs: Mapping[str, Any], - ): - self.rng_unbalanced = rng - self.source_dim = source_dim - self.target_dim = target_dim - self.cond_dim = cond_dim - self.tau_a = tau_a - self.tau_b = tau_b - self.rescaling_a = rescaling_a - self.rescaling_b = rescaling_b - self.opt_eta = opt_eta - self.opt_xi = opt_xi - self.resample_epsilon = resample_epsilon - self.scale_cost = scale_cost - self.ot_solver = ot_solver - - if isinstance(ot_solver, sinkhorn.Sinkhorn): - self.compute_unbalanced_marginals = ( - self._get_compute_unbalanced_marginals_lin( - tau_a=tau_a, - tau_b=tau_b, - resample_epsilon=resample_epsilon, - scale_cost=scale_cost, - **kwargs - ) - ) - elif isinstance(ot_solver, gromov_wasserstein.GromovWasserstein): - raise NotImplementedError - else: - self.compute_unbalanced_marginals = None - self.setup(source_dim=source_dim, target_dim=target_dim, cond_dim=cond_dim) - - def _get_compute_unbalanced_marginals_lin( - self, *args: Any, **kwargs: Mapping[str, Any] - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Compute the unbalanced source and target marginals for a batch.""" - fn = _get_sinkhorn_match_fn(*args, **kwargs) - - @jax.jit - def compute_unbalanced_marginals_lin(*args, **kwargs): - out = fn(*args, **kwargs) - return out.marginals(axis=1), out.marginals(axis=0) - - return compute_unbalanced_marginals_lin - - def 
_get_compute_unbalanced_marginals_quad( - self, *args: Any, **kwargs: Mapping[str, Any] - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Compute the unbalanced source and target marginals for a batch.""" - fn = _get_sinkhorn_match_fn(*args, **kwargs) - - @jax.jit - def compute_unbalanced_marginals_quad(*args, **kwargs): - out = fn(*args, **kwargs) - return out.marginals(axis=1), out.marginals(axis=0) - - return compute_unbalanced_marginals_quad - - @jax.jit - def resample_unbalanced( - self, - rng: jax.Array, - arrays: Tuple[jnp.ndarray, ...], - p: jnp.ndarray, - ) -> Tuple[jnp.ndarray, ...]: - """Resample a batch based on marginals. - - Args: - rng: Random number generator. - arrays: Arrays to resample from. - p: Probabilities according to which `arrays` are resampled. - - Returns: - Resampled arrays. - """ - indices = jax.random.choice(rng, a=len(p), p=jnp.squeeze(p), shape=[len(p)]) - return tree_util.tree_map(lambda b: b[indices], arrays) - - def setup(self, source_dim: int, target_dim: int, cond_dim: int): - """Setup the model. - - Args: - source_dim: Dimension of the source domain. - target_dim: Dimension of the target domain. - cond_dim: Dimension of the conditioning variable. - If :obj:`None`, no conditioning is used. 
- """ - self.rng_unbalanced, rng_eta, rng_xi = jax.random.split( - self.rng_unbalanced, 3 - ) - self.step_fn = self._get_rescaling_step_fn() - if self.rescaling_a is not None: - self.opt_eta = ( - self.opt_eta if self.opt_eta is not None else - optax.adamw(learning_rate=1e-4, weight_decay=1e-10) - ) - self.state_eta = self.rescaling_a.create_train_state( - rng_eta, self.opt_eta, source_dim - ) - if self.rescaling_b is not None: - self.opt_xi = ( - self.opt_xi if self.opt_xi is not None else - optax.adamw(learning_rate=1e-4, weight_decay=1e-10) - ) - self.state_xi = self.rescaling_b.create_train_state( - rng_xi, self.opt_xi, target_dim - ) - - def _get_rescaling_step_fn(self) -> Callable: # type:ignore[type-arg] - - def loss_marginal_fn( - params: jnp.ndarray, - apply_fn: Callable[[Dict[str, jnp.ndarray], jnp.ndarray], - Optional[jnp.ndarray]], - x: jnp.ndarray, - condition: Optional[jnp.ndarray], - true_marginals: jnp.ndarray, - expectation_reweighting: float, - ) -> Tuple[float, jnp.ndarray]: - predictions = apply_fn({"params": params}, x, condition) - pred_loss = optax.l2_loss(jnp.squeeze(predictions), true_marginals).mean() - exp_loss = optax.l2_loss(jnp.mean(predictions) - expectation_reweighting) - return (pred_loss + exp_loss, predictions) - - @jax.jit - def step_fn( - source: jnp.ndarray, - target: jnp.ndarray, - condition: Optional[jnp.ndarray], - a: jnp.ndarray, - b: jnp.ndarray, - state_eta: Optional[train_state.TrainState] = None, - state_xi: Optional[train_state.TrainState] = None, - *, - is_training: bool = True, - ): - if state_eta is not None: - grad_a_fn = jax.value_and_grad( - loss_marginal_fn, argnums=0, has_aux=True - ) - (loss_a, eta_predictions), grads_eta = grad_a_fn( - state_eta.params, - state_eta.apply_fn, - source, - condition, - a * len(a), - jnp.sum(b), - ) - new_state_eta = state_eta.apply_gradients( - grads=grads_eta - ) if is_training else None - - else: - new_state_eta = eta_predictions = loss_a = None - if state_xi is not None: - 
grad_b_fn = jax.value_and_grad( - loss_marginal_fn, argnums=0, has_aux=True - ) - (loss_b, xi_predictions), grads_xi = grad_b_fn( - state_xi.params, - state_xi.apply_fn, - target, - condition, - b * len(b), - jnp.sum(a), - ) - new_state_xi = state_xi.apply_gradients( - grads=grads_xi - ) if is_training else None - else: - new_state_xi = xi_predictions = loss_b = None - - return ( - new_state_eta, new_state_xi, eta_predictions, xi_predictions, loss_a, - loss_b - ) - - return step_fn - - def evaluate_eta( - self, - source: jnp.ndarray, - condition: Optional[jnp.ndarray] = None - ) -> jnp.ndarray: - """Evaluate the left learnt rescaling factor. - - Args: - source: Samples from the source distribution to evaluate rescaling - function on. - condition: Condition belonging to the samples in the source distribution. - - Returns: - Learnt left rescaling factors. - """ - if self.state_eta is None: - raise ValueError("The left rescaling factor was not parameterized.") - return self.state_eta.apply_fn({"params": self.state_eta.params}, - x=source, - condition=condition) - - def evaluate_xi( - self, - target: jnp.ndarray, - condition: Optional[jnp.ndarray] = None - ) -> jnp.ndarray: - """Evaluate the right learnt rescaling factor. - - Args: - target: Samples from the target distribution to evaluate the rescaling - function on. - condition: Condition belonging to the samples in the target distribution. - - Returns: - Learnt right rescaling factors. 
- """ - if self.state_xi is None: - raise ValueError("The right rescaling factor was not parameterized.") - return self.state_xi.apply_fn({"params": self.state_xi.params}, - x=target, - condition=condition) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index e924edd9a..7156dec3d 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -17,7 +17,6 @@ import pytest import jax.numpy as jnp -from jax import random import optax @@ -26,7 +25,6 @@ from ott.neural.flow_models.models import VelocityField from ott.neural.flow_models.samplers import uniform_sampler from ott.neural.models import base_solver -from ott.neural.models.nets import RescalingMLP from ott.solvers.linear import sinkhorn, sinkhorn_lr from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr @@ -71,9 +69,6 @@ def test_genot_linear_unconditional( ot_matcher = base_solver.OTMatcherLinear( ot_solver, cost_fn=costs.SqEuclidean(), scale_cost=scale_cost ) - unbalancedness_handler = base_solver.UnbalancednessHandler( - random.PRNGKey(0), source_dim, target_dim, condition_dim - ) time_sampler = uniform_sampler optimizer = optax.adam(learning_rate=1e-3) genot = GENOTLin( @@ -86,7 +81,6 @@ def test_genot_linear_unconditional( ot_matcher=ot_matcher, optimizer=optimizer, time_sampler=time_sampler, - unbalancedness_handler=unbalancedness_handler, k_samples_per_x=k_samples_per_x, matcher_latent_to_data=matcher_latent_to_data, ) @@ -134,9 +128,6 @@ def test_genot_linear_conditional( ot_solver, cost_fn=costs.SqEuclidean() ) time_sampler = uniform_sampler - unbalancedness_handler = base_solver.UnbalancednessHandler( - random.PRNGKey(0), source_dim, target_dim, condition_dim - ) optimizer = optax.adam(learning_rate=1e-3) genot = GENOTLin( @@ -147,7 +138,6 @@ def test_genot_linear_conditional( iterations=3, valid_freq=2, ot_matcher=ot_matcher, - unbalancedness_handler=unbalancedness_handler, optimizer=optimizer, time_sampler=time_sampler, 
k_samples_per_x=k_samples_per_x, @@ -198,26 +188,12 @@ def test_genot_linear_learn_rescaling( ot_matcher = base_solver.OTMatcherLinear( ot_solver, cost_fn=costs.SqEuclidean(), + tau_a=0.2, + tau_b=0.9, ) time_sampler = uniform_sampler optimizer = optax.adam(learning_rate=1e-3) - tau_a = 0.9 - tau_b = 0.2 - rescaling_a = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) - rescaling_b = RescalingMLP(hidden_dim=4, condition_dim=condition_dim) - - unbalancedness_handler = base_solver.UnbalancednessHandler( - random.PRNGKey(0), - source_dim, - target_dim, - condition_dim, - tau_a=tau_a, - tau_b=tau_b, - rescaling_a=rescaling_a, - rescaling_b=rescaling_b - ) - genot = GENOTLin( neural_vf, input_dim=source_dim, @@ -228,24 +204,11 @@ def test_genot_linear_learn_rescaling( ot_matcher=ot_matcher, optimizer=optimizer, time_sampler=time_sampler, - unbalancedness_handler=unbalancedness_handler, matcher_latent_to_data=matcher_latent_to_data, ) genot(data_loader, data_loader) - result_eta = genot.unbalancedness_handler.evaluate_eta( - source_lin, condition=source_condition - ) - assert isinstance(result_eta, jnp.ndarray) - assert jnp.sum(jnp.isnan(result_eta)) == 0 - - result_xi = genot.unbalancedness_handler.evaluate_xi( - target_lin, condition=source_condition - ) - assert isinstance(result_xi, jnp.ndarray) - assert jnp.sum(jnp.isnan(result_xi)) == 0 - class TestGENOTQuad: @@ -285,10 +248,6 @@ def test_genot_quad_unconditional( ot_solver, cost_fn=costs.SqEuclidean() ) - unbalancedness_handler = base_solver.UnbalancednessHandler( - random.PRNGKey(0), source_dim, target_dim, condition_dim - ) - time_sampler = functools.partial(uniform_sampler, offset=1e-2) optimizer = optax.adam(learning_rate=1e-3) genot = GENOTQuad( @@ -299,7 +258,6 @@ def test_genot_quad_unconditional( iterations=3, valid_freq=2, ot_matcher=ot_matcher, - unbalancedness_handler=unbalancedness_handler, optimizer=optimizer, time_sampler=time_sampler, k_samples_per_x=k_samples_per_x, @@ -351,10 +309,6 @@ def 
test_genot_fused_unconditional( ot_solver, cost_fn=costs.SqEuclidean(), fused_penalty=0.5 ) - unbalancedness_handler = base_solver.UnbalancednessHandler( - random.PRNGKey(0), source_dim, target_dim, condition_dim - ) - optimizer = optax.adam(learning_rate=1e-3) genot = GENOTQuad( neural_vf, @@ -364,7 +318,6 @@ def test_genot_fused_unconditional( iterations=3, valid_freq=2, ot_matcher=ot_matcher, - unbalancedness_handler=unbalancedness_handler, optimizer=optimizer, k_samples_per_x=k_samples_per_x, matcher_latent_to_data=matcher_latent_to_data, @@ -416,9 +369,6 @@ def test_genot_quad_conditional( ot_solver, cost_fn=costs.SqEuclidean() ) time_sampler = uniform_sampler - unbalancedness_handler = base_solver.UnbalancednessHandler( - random.PRNGKey(0), source_dim, target_dim, condition_dim - ) optimizer = optax.adam(learning_rate=1e-3) genot = GENOTQuad( @@ -429,7 +379,6 @@ def test_genot_quad_conditional( iterations=3, valid_freq=2, ot_matcher=ot_matcher, - unbalancedness_handler=unbalancedness_handler, optimizer=optimizer, time_sampler=time_sampler, k_samples_per_x=k_samples_per_x, @@ -486,9 +435,6 @@ def test_genot_fused_conditional( ) time_sampler = uniform_sampler optimizer = optax.adam(learning_rate=1e-3) - unbalancedness_handler = base_solver.UnbalancednessHandler( - random.PRNGKey(0), source_dim, target_dim, condition_dim - ) genot = GENOTQuad( neural_vf, @@ -498,7 +444,6 @@ def test_genot_fused_conditional( iterations=3, valid_freq=2, ot_matcher=ot_matcher, - unbalancedness_handler=unbalancedness_handler, optimizer=optimizer, time_sampler=time_sampler, k_samples_per_x=k_samples_per_x, diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index e57fce89c..5c53db325 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -17,12 +17,11 @@ import pytest import jax.numpy as jnp -from jax import random import optax from ott.neural.flow_models import flows, models, otfm, samplers -from ott.neural.models import base_solver, nets +from 
ott.neural.models import base_solver from ott.solvers.linear import sinkhorn, sinkhorn_lr @@ -49,9 +48,7 @@ def test_flow_matching_unconditional( ot_matcher = base_solver.OTMatcherLinear(ot_solver) time_sampler = samplers.uniform_sampler optimizer = optax.adam(learning_rate=1e-3) - unbalancedness_handler = base_solver.UnbalancednessHandler( - random.PRNGKey(0), input_dim, input_dim, condition_dim - ) + fm = otfm.OTFlowMatching( neural_vf, input_dim=input_dim, @@ -62,7 +59,6 @@ def test_flow_matching_unconditional( flow=flow, time_sampler=time_sampler, optimizer=optimizer, - unbalancedness_handler=unbalancedness_handler ) fm( data_loaders_gaussian[0], data_loaders_gaussian[1], @@ -101,17 +97,14 @@ def test_flow_matching_with_conditions( input_dim = 2 condition_dim = 1 neural_vf = models.VelocityField( - output_dim=2, - condition_dim=1, + output_dim=input_dim, + condition_dim=condition_dim, latent_embed_dim=5, ) ot_solver = sinkhorn.Sinkhorn() ot_matcher = base_solver.OTMatcherLinear(ot_solver) time_sampler = functools.partial(samplers.uniform_sampler, offset=1e-5) optimizer = optax.adam(learning_rate=1e-3) - unbalancedness_handler = base_solver.UnbalancednessHandler( - random.PRNGKey(0), input_dim, input_dim, condition_dim - ) fm = otfm.OTFlowMatching( neural_vf, @@ -123,7 +116,6 @@ def test_flow_matching_with_conditions( flow=flow, time_sampler=time_sampler, optimizer=optimizer, - unbalancedness_handler=unbalancedness_handler ) fm( data_loader_gaussian_with_conditions, @@ -173,9 +165,6 @@ def test_flow_matching_conditional( ot_matcher = base_solver.OTMatcherLinear(ot_solver) time_sampler = samplers.uniform_sampler optimizer = optax.adam(learning_rate=1e-3) - unbalancedness_handler = base_solver.UnbalancednessHandler( - random.PRNGKey(0), dim, dim, condition_dim - ) fm = otfm.OTFlowMatching( neural_vf, @@ -187,7 +176,6 @@ def test_flow_matching_conditional( flow=flow, time_sampler=time_sampler, optimizer=optimizer, - unbalancedness_handler=unbalancedness_handler ) 
fm(data_loader_gaussian_conditional, data_loader_gaussian_conditional) @@ -220,7 +208,6 @@ def test_flow_matching_learn_rescaling( ) batch = next(iter(data_loader)) source = jnp.asarray(batch["source_lin"]) - target = jnp.asarray(batch["target_lin"]) source_conditions = jnp.asarray(batch["source_conditions"]) if len( batch["source_conditions"] ) > 0 else None @@ -239,23 +226,11 @@ def test_flow_matching_learn_rescaling( tau_a = 0.9 tau_b = 0.2 - rescaling_a = nets.RescalingMLP(hidden_dim=4, condition_dim=condition_dim) - rescaling_b = nets.RescalingMLP(hidden_dim=4, condition_dim=condition_dim) ot_matcher = base_solver.OTMatcherLinear( ot_solver, tau_a=tau_a, tau_b=tau_b, ) - unbalancedness_handler = base_solver.UnbalancednessHandler( - random.PRNGKey(0), - source_dim, - source_dim, - condition_dim, - tau_a=tau_a, - tau_b=tau_b, - rescaling_a=rescaling_a, - rescaling_b=rescaling_b - ) fm = otfm.OTFlowMatching( neural_vf, @@ -267,18 +242,5 @@ def test_flow_matching_learn_rescaling( flow=flow, time_sampler=time_sampler, optimizer=optimizer, - unbalancedness_handler=unbalancedness_handler, ) fm(data_loader, data_loader) - - result_eta = fm.unbalancedness_handler.evaluate_eta( - source, condition=source_conditions - ) - assert isinstance(result_eta, jnp.ndarray) - assert jnp.sum(jnp.isnan(result_eta)) == 0 - - result_xi = fm.unbalancedness_handler.evaluate_xi( - target, condition=source_conditions - ) - assert isinstance(result_xi, jnp.ndarray) - assert jnp.sum(jnp.isnan(result_xi)) == 0 From 23eca2caa762e4207f55e0f300a81a048f4e4fe5 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 6 Mar 2024 09:50:22 +0100 Subject: [PATCH 109/186] remove tests --- tests/neural/genot_test.py | 56 -------------------------------------- tests/neural/otfm_test.py | 50 +--------------------------------- 2 files changed, 1 insertion(+), 105 deletions(-) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 7156dec3d..c8d4a48f7 100644 --- 
a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -153,62 +153,6 @@ def test_genot_linear_conditional( assert isinstance(result_forward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_forward)) == 0 - @pytest.mark.parametrize("conditional", [False, True]) - @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) - def test_genot_linear_learn_rescaling( - self, conditional: bool, genot_data_loader_linear: Iterator, - solver_latent_to_data: Optional[str], - genot_data_loader_linear_conditional: Iterator - ): - matcher_latent_to_data = ( - None if solver_latent_to_data is None else - base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) - ) - - data_loader = ( - genot_data_loader_linear_conditional - if conditional else genot_data_loader_linear - ) - - batch = next(iter(data_loader)) - source_lin, target_lin, source_condition = jnp.array( - batch["source_lin"] - ), jnp.array(batch["target_lin"]), jnp.array(batch["source_conditions"]) - - source_dim = source_lin.shape[1] - target_dim = target_lin.shape[1] - condition_dim = source_condition.shape[1] if conditional else 0 - - neural_vf = VelocityField( - output_dim=target_dim, - condition_dim=source_dim + condition_dim, - latent_embed_dim=5, - ) - ot_solver = sinkhorn.Sinkhorn() - ot_matcher = base_solver.OTMatcherLinear( - ot_solver, - cost_fn=costs.SqEuclidean(), - tau_a=0.2, - tau_b=0.9, - ) - time_sampler = uniform_sampler - optimizer = optax.adam(learning_rate=1e-3) - - genot = GENOTLin( - neural_vf, - input_dim=source_dim, - output_dim=target_dim, - cond_dim=condition_dim, - iterations=3, - valid_freq=2, - ot_matcher=ot_matcher, - optimizer=optimizer, - time_sampler=time_sampler, - matcher_latent_to_data=matcher_latent_to_data, - ) - - genot(data_loader, data_loader) - class TestGENOTQuad: diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 5c53db325..f43403054 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -12,7 +12,7 @@ # See the License for 
the specific language governing permissions and # limitations under the License. import functools -from typing import Iterator, Literal, Type +from typing import Literal, Type import pytest @@ -196,51 +196,3 @@ def test_flow_matching_conditional( ) assert isinstance(result_backward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_backward)) == 0 - - @pytest.mark.parametrize("conditional", [False, True]) - def test_flow_matching_learn_rescaling( - self, conditional: bool, data_loader_gaussian: Iterator, - data_loader_gaussian_conditional: Iterator - ): - data_loader = ( - data_loader_gaussian_conditional - if conditional else data_loader_gaussian - ) - batch = next(iter(data_loader)) - source = jnp.asarray(batch["source_lin"]) - source_conditions = jnp.asarray(batch["source_conditions"]) if len( - batch["source_conditions"] - ) > 0 else None - - source_dim = source.shape[1] - condition_dim = source_conditions.shape[1] if conditional else 0 - neural_vf = models.VelocityField( - output_dim=2, - condition_dim=0, - latent_embed_dim=5, - ) - ot_solver = sinkhorn.Sinkhorn() - time_sampler = samplers.uniform_sampler - flow = flows.ConstantNoiseFlow(1.0) - optimizer = optax.adam(learning_rate=1e-3) - - tau_a = 0.9 - tau_b = 0.2 - ot_matcher = base_solver.OTMatcherLinear( - ot_solver, - tau_a=tau_a, - tau_b=tau_b, - ) - - fm = otfm.OTFlowMatching( - neural_vf, - input_dim=source_dim, - cond_dim=condition_dim, - iterations=3, - valid_freq=2, - ot_matcher=ot_matcher, - flow=flow, - time_sampler=time_sampler, - optimizer=optimizer, - ) - fm(data_loader, data_loader) From 85427bac1eda6e85c45b6e6347b4f3676c7cc3e7 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 6 Mar 2024 10:06:43 +0100 Subject: [PATCH 110/186] make genot training loops more similar to otfm training loop --- src/ott/neural/flow_models/genot.py | 120 ++++++++++++++-------------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/src/ott/neural/flow_models/genot.py 
b/src/ott/neural/flow_models/genot.py index ba2c0f6a0..7291b8fdc 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp +import jax.tree_util as jtu import diffrax import optax @@ -43,7 +44,6 @@ class GENOTBase: input_dim: Dimension of the data in the source distribution. output_dim: Dimension of the data in the target distribution. cond_dim: Dimension of the conditioning variable. - iterations: Number of iterations. valid_freq: Frequency of validation. ot_solver: OT solver to match samples from the source and the target distribution. @@ -87,7 +87,6 @@ def __init__( input_dim: int, output_dim: int, cond_dim: int, - iterations: int, valid_freq: int, ot_matcher: base_solver.BaseOTMatcher, optimizer: optax.GradientTransformation, @@ -105,7 +104,6 @@ def __init__( rng = utils.default_prng_key(rng) self.rng = utils.default_prng_key(rng) - self.iterations = iterations self.valid_freq = valid_freq self.velocity_field = velocity_field self.state_velocity_field: Optional[train_state.TrainState] = None @@ -280,31 +278,33 @@ class GENOTLin(GENOTBase): neural solver for entropic (linear) OT problems. """ - def __call__(self, train_loader, valid_loader): - """Train GENOT. + def __call__( + self, + n_iters: int, + train_source, + train_target, + valid_source, + valid_target, + valid_freq: int = 5000, + rng: Optional[jax.Array] = None, + ): + """Train GENOTLin.""" + rng = utils.default_prng_key(rng) + training_logs = {"loss": []} - Args: - train_loader: Data loader for the training data. - valid_loader: Data loader for the validation data. 
- """ - iter = -1 - stop = False - while True: - for batch in train_loader: - iter += 1 - if iter >= self.iterations: - stop = True - break + for it in range(n_iters): + for batch_source, batch_target in zip(train_source, train_target): ( - self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, + rng, rng_resample, rng_noise, rng_time, rng_latent_data_match, rng_step_fn - ) = jax.random.split(self.rng, 6) - source, source_conditions, target = jnp.array( - batch["source_lin"] - ), jnp.array(batch["source_conditions"] - ) if "source_conditions" in batch else None, jnp.array( - batch["target_lin"] - ) + ) = jax.random.split(rng, 6) + + batch_source = jtu.tree_map(jnp.asarray, batch_source) + batch_target = jtu.tree_map(jnp.asarray, batch_target) + + source = batch_source["lin"] + source_conditions = batch_source.get("conditions", None) + target = batch_target["lin"] batch_size = len(source) n_samples = batch_size * self.k_samples_per_x @@ -358,10 +358,10 @@ def __call__(self, train_loader, valid_loader): tmat=tmat ) - if iter % self.valid_freq == 0: - self._valid_step(valid_loader, iter) - if stop: - break + training_logs["loss"].append(float(loss)) + + if it % valid_freq == 0: + self._valid_step(valid_source, valid_target, it) class GENOTQuad(GENOTBase): @@ -373,36 +373,36 @@ class GENOTQuad(GENOTBase): problems, respectively. """ - def __call__(self, train_loader, valid_loader): - """Train GENOT. - - Args: - train_loader: Data loader for the training data. - valid_loader: Data loader for the validation data. 
- """ - batch: Dict[str, jnp.array] = {} - iter = -1 - stop = False - while True: - for batch in train_loader: - iter += 1 - if iter >= self.iterations: - stop = True - break + def __call__( + self, + n_iters: int, + train_source, + train_target, + valid_source, + valid_target, + valid_freq: int = 5000, + rng: Optional[jax.Array] = None, + ): + """Train GENOTQuad.""" + rng = utils.default_prng_key(rng) + training_logs = {"loss": []} + for it in range(n_iters): + for batch_source, batch_target in zip(train_source, train_target): ( - self.rng, rng_time, rng_resample, rng_noise, rng_latent_data_match, + rng, rng_resample, rng_noise, rng_time, rng_latent_data_match, rng_step_fn - ) = jax.random.split(self.rng, 6) - (source_lin, source_quad, source_conditions, target_lin, - target_quad) = ( - jnp.array(batch["source_lin"]) if "source_lin" in batch else None, - jnp.array(batch["source_quad"]), - jnp.array(batch["source_conditions"]) - if "source_conditions" in batch else None, - jnp.array(batch["target_lin"]) if "target_lin" in batch else None, - jnp.array(batch["target_quad"]) - ) + ) = jax.random.split(rng, 6) + + batch_source = jtu.tree_map(jnp.asarray, batch_source) + batch_target = jtu.tree_map(jnp.asarray, batch_target) + + source_lin = batch_source.get("lin", None) + source_quad = batch_source["quad"] + source_conditions = batch_source.get("conditions", None) + target_lin = batch_target.get("lin", None) + target_quad = batch_target["quad"] + batch_size = len(source_quad) n_samples = batch_size * self.k_samples_per_x time = self.time_sampler(rng_time, n_samples) @@ -463,7 +463,7 @@ def __call__(self, train_loader, valid_loader): condition=source_conditions, tmat=tmat ) - if iter % self.valid_freq == 0: - self._valid_step(valid_loader, iter) - if stop: - break + training_logs["loss"].append(float(loss)) + + if it % valid_freq == 0: + self._valid_step(valid_source, valid_target, it) From 5a2424abaf3cee8d7b663a1ca782d49f7c8e30bc Mon Sep 17 00:00:00 2001 From: Dominik 
Klein Date: Wed, 6 Mar 2024 10:23:39 +0100 Subject: [PATCH 111/186] adapt tests to the extent possible --- tests/neural/genot_test.py | 44 ++++++++++++++++++++++---------------- tests/neural/otfm_test.py | 25 +++++++++++++--------- 2 files changed, 41 insertions(+), 28 deletions(-) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index c8d4a48f7..f938728c1 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -76,15 +76,18 @@ def test_genot_linear_unconditional( input_dim=source_dim, output_dim=target_dim, cond_dim=condition_dim, - iterations=3, - valid_freq=2, ot_matcher=ot_matcher, optimizer=optimizer, time_sampler=time_sampler, k_samples_per_x=k_samples_per_x, matcher_latent_to_data=matcher_latent_to_data, ) - genot(genot_data_loader_linear, genot_data_loader_linear) + genot( + genot_data_loader_linear, + genot_data_loader_linear, + n_iters=2, + valid_freq=3 + ) batch = next(iter(genot_data_loader_linear)) result_forward = genot.transport( @@ -135,8 +138,6 @@ def test_genot_linear_conditional( input_dim=source_dim, output_dim=target_dim, cond_dim=condition_dim, - iterations=3, - valid_freq=2, ot_matcher=ot_matcher, optimizer=optimizer, time_sampler=time_sampler, @@ -145,7 +146,9 @@ def test_genot_linear_conditional( ) genot( genot_data_loader_linear_conditional, - genot_data_loader_linear_conditional + genot_data_loader_linear_conditional, + n_iters=2, + valid_freq=3 ) result_forward = genot.transport( source_lin, condition=source_conditions, forward=True @@ -199,15 +202,15 @@ def test_genot_quad_unconditional( input_dim=source_dim, output_dim=target_dim, cond_dim=condition_dim, - iterations=3, - valid_freq=2, ot_matcher=ot_matcher, optimizer=optimizer, time_sampler=time_sampler, k_samples_per_x=k_samples_per_x, matcher_latent_to_data=matcher_latent_to_data, ) - genot(genot_data_loader_quad, genot_data_loader_quad) + genot( + genot_data_loader_quad, genot_data_loader_quad, n_iters=2, valid_freq=3 + ) result_forward = 
genot.transport( source_quad, condition=source_conditions, forward=True @@ -259,14 +262,17 @@ def test_genot_fused_unconditional( input_dim=source_dim, output_dim=target_dim, cond_dim=condition_dim, - iterations=3, - valid_freq=2, ot_matcher=ot_matcher, optimizer=optimizer, k_samples_per_x=k_samples_per_x, matcher_latent_to_data=matcher_latent_to_data, ) - genot(genot_data_loader_fused, genot_data_loader_fused) + genot( + genot_data_loader_fused, + genot_data_loader_fused, + n_iters=2, + valid_freq=3 + ) result_forward = genot.transport( jnp.concatenate((source_lin, source_quad), axis=1), @@ -320,8 +326,6 @@ def test_genot_quad_conditional( input_dim=source_dim, output_dim=target_dim, cond_dim=condition_dim, - iterations=3, - valid_freq=2, ot_matcher=ot_matcher, optimizer=optimizer, time_sampler=time_sampler, @@ -329,7 +333,10 @@ def test_genot_quad_conditional( matcher_latent_to_data=matcher_latent_to_data, ) genot( - genot_data_loader_quad_conditional, genot_data_loader_quad_conditional + genot_data_loader_quad_conditional, + genot_data_loader_quad_conditional, + n_iters=2, + valid_freq=3 ) result_forward = genot.transport( @@ -385,8 +392,6 @@ def test_genot_fused_conditional( input_dim=source_dim, output_dim=target_dim, cond_dim=condition_dim, - iterations=3, - valid_freq=2, ot_matcher=ot_matcher, optimizer=optimizer, time_sampler=time_sampler, @@ -394,7 +399,10 @@ def test_genot_fused_conditional( matcher_latent_to_data=matcher_latent_to_data, ) genot( - genot_data_loader_fused_conditional, genot_data_loader_fused_conditional + genot_data_loader_fused_conditional, + genot_data_loader_fused_conditional, + n_iters=2, + valid_freq=3 ) result_forward = genot.transport( diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index f43403054..14af037db 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -53,16 +53,18 @@ def test_flow_matching_unconditional( neural_vf, input_dim=input_dim, cond_dim=condition_dim, - iterations=3, - 
valid_freq=2, ot_matcher=ot_matcher, flow=flow, time_sampler=time_sampler, optimizer=optimizer, ) fm( - data_loaders_gaussian[0], data_loaders_gaussian[1], - data_loaders_gaussian[0], data_loaders_gaussian[1] + data_loaders_gaussian[0], + data_loaders_gaussian[1], + data_loaders_gaussian[0], + data_loaders_gaussian[1], + n_iters=2, + valid_freq=3 ) batch_src = next(iter(data_loaders_gaussian[0])) @@ -110,8 +112,6 @@ def test_flow_matching_with_conditions( neural_vf, input_dim=2, cond_dim=1, - iterations=3, - valid_freq=2, ot_matcher=ot_matcher, flow=flow, time_sampler=time_sampler, @@ -119,7 +119,9 @@ def test_flow_matching_with_conditions( ) fm( data_loader_gaussian_with_conditions, - data_loader_gaussian_with_conditions + data_loader_gaussian_with_conditions, + n_iters=2, + valid_freq=3 ) batch = next(iter(data_loader_gaussian_with_conditions)) @@ -170,14 +172,17 @@ def test_flow_matching_conditional( neural_vf, input_dim=dim, cond_dim=condition_dim, - iterations=3, - valid_freq=2, ot_matcher=ot_matcher, flow=flow, time_sampler=time_sampler, optimizer=optimizer, ) - fm(data_loader_gaussian_conditional, data_loader_gaussian_conditional) + fm( + data_loader_gaussian_conditional, + data_loader_gaussian_conditional, + n_iters=2, + valid_freq=3 + ) batch = next(iter(data_loader_gaussian_conditional)) source = jnp.asarray(batch["source_lin"]) From c4a187e014b9cfd2af087499efef50c96d24c028 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 11 Mar 2024 19:28:25 +0100 Subject: [PATCH 112/186] Add weights to sampling --- src/ott/neural/data/datasets.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/ott/neural/data/datasets.py b/src/ott/neural/data/datasets.py index 5a12ed2c0..d13237cc9 100644 --- a/src/ott/neural/data/datasets.py +++ b/src/ott/neural/data/datasets.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from typing import Dict, Iterable, Optional import jax.tree_util as jtu import numpy as np @@ -65,21 +65,29 @@ class ConditionalOTDataset: Args: datasets: Datasets to sample from. + weights: TODO. seed: Random seed. """ def __init__( self, - # TODO(michalk8): allow for dict with weights - datasets: List[OTDataset], + datasets: Iterable[OTDataset], + weights: Iterable[float] = None, seed: Optional[int] = None, ): self.datasets = tuple(datasets) - self._rng = np.random.default_rng(seed=seed) + + if weights is None: + weights = np.ones(len(self.datasets)) + weights = np.asarray(weights) + self.weights = weights / np.sum(weights) + assert len(self.weights) == len(self.datasets), "TODO" + + self._rng = np.random.default_rng(seed) self._iterators = () def __next__(self) -> Dict[str, np.ndarray]: - idx = self._rng.choice(len(self._iterators)) + idx = self._rng.choice(len(self._iterators), p=self.weights) return next(self._iterators[idx]) def __iter__(self) -> "ConditionalOTDataset": From 30f2324ed251d52b3f5a94308bc16ab2865fdfa0 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 11 Mar 2024 19:58:41 +0100 Subject: [PATCH 113/186] Start cleaning matchers --- src/ott/neural/flow_models/otfm.py | 33 ++++++-------- src/ott/neural/flow_models/utils.py | 36 ++++++++++++++++ tests/neural/otfm_test.py | 67 +++++++++++------------------ 3 files changed, 73 insertions(+), 63 deletions(-) create mode 100644 src/ott/neural/flow_models/utils.py diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index e426a6f5c..699bf2960 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -24,7 +24,7 @@ from ott import utils from ott.neural.flow_models import flows, models -from ott.neural.models import base_solver +from 
ott.neural.flow_models.utils import sample_joint __all__ = ["OTFlowMatching"] @@ -40,7 +40,7 @@ class OTFlowMatching: flow: Flow between source and target distribution. time_sampler: Sampler for the time. optimizer: Optimizer for the velocity field's parameters. - ot_matcher: TODO. + match_fn: TODO. rng: Random number generator. """ @@ -53,7 +53,8 @@ def __init__( flow: flows.BaseFlow, time_sampler: Callable[[jax.Array, int], jnp.ndarray], optimizer: optax.GradientTransformation, - ot_matcher: Optional[base_solver.OTMatcherLinear] = None, + match_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], + jnp.ndarray]] = None, rng: Optional[jax.Array] = None, ): rng = utils.default_prng_key(rng) @@ -62,7 +63,7 @@ def __init__( self.vf = velocity_field self.flow = flow self.time_sampler = time_sampler - self.ot_matcher = ot_matcher + self.match_fn = match_fn self.optimizer = optimizer self.vf_state = self.vf.create_train_state( @@ -113,15 +114,12 @@ def __call__( # noqa: D102 n_iters: int, train_source, train_target, - valid_source, - valid_target, - valid_freq: int = 5000, rng: Optional[jax.Array] = None, ) -> Dict[str, Any]: rng = utils.default_prng_key(rng) training_logs = {"loss": []} - for it in range(n_iters): + for _ in range(n_iters): for batch_source, batch_target in zip(train_source, train_target): rng, rng_resample, rng_step_fn = jax.random.split(rng, 3) @@ -132,22 +130,18 @@ def __call__( # noqa: D102 source_conditions = batch_source.get("conditions", None) target = batch_target["lin"] - if self.ot_matcher is not None: - tmat = self.ot_matcher.match_fn(source, target) - (source, source_conditions), (target,) = self.ot_matcher.sample_joint( - rng_resample, tmat, (source, source_conditions), (target,) - ) - else: - tmat = None + if self.match_fn is not None: + tmat = self.match_fn(source, target) + src_ixs, tgt_ixs = sample_joint(rng_resample, tmat) + source, target = source[src_ixs], target[tgt_ixs] + if source_conditions is not None: + source_conditions = 
source_conditions[src_ixs] self.vf_state, loss = self.step_fn( rng_step_fn, self.vf_state, source, target, source_conditions ) training_logs["loss"].append(float(loss)) - if it % valid_freq == 0: - self._valid_step(valid_source, valid_target, it) - return training_logs def transport( @@ -203,6 +197,3 @@ def solve_ode(x: jnp.ndarray, cond: Optional[jnp.ndarray]) -> jnp.ndarray: in_axes = [0, None if condition is None else 0] return jax.jit(jax.vmap(solve_ode, in_axes))(x, condition) - - def _valid_step(self, it: int, valid_source, valid_target) -> None: - pass diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py new file mode 100644 index 000000000..797370dce --- /dev/null +++ b/src/ott/neural/flow_models/utils.py @@ -0,0 +1,36 @@ +from typing import Any, Optional, Tuple + +import jax +import jax.numpy as jnp + +from ott.geometry import costs, pointcloud +from ott.solvers import linear + +__all__ = ["match_linear", "sample_joint"] + + +def match_linear( + x: jnp.ndarray, + y: jnp.ndarray, + cost_fn: Optional[costs.CostFn] = None, + epsilon: Optional[float] = None, + # TODO(michalk8): expose rest of the geom arguments + **kwargs: Any +) -> jnp.ndarray: + """TODO.""" + geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=epsilon) + out = linear.solve(geom, **kwargs) + return out.matrix + + +def sample_joint(rng: jax.Array, + tmat: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + """TODO.""" + n, m = tmat.shape + tmat_flattened = tmat.flatten() + indices = jax.random.choice( + rng, len(tmat_flattened), p=tmat_flattened, shape=[n] + ) + src_ixs = indices // m + tgt_ixs = indices % m + return src_ixs, tgt_ixs diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 14af037db..01042ec39 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -20,7 +20,7 @@ import optax -from ott.neural.flow_models import flows, models, otfm, samplers +from ott.neural.flow_models import flows, models, otfm, 
samplers, utils from ott.neural.models import base_solver from ott.solvers.linear import sinkhorn, sinkhorn_lr @@ -31,42 +31,35 @@ class TestOTFlowMatching: "flow", [ flows.ConstantNoiseFlow(0.0), flows.ConstantNoiseFlow(1.0), - flows.BrownianNoiseFlow(0.2) + flows.BrownianNoiseFlow(0.2), ] ) def test_flow_matching_unconditional( - self, data_loaders_gaussian, flow: Type[flows.BaseFlow] + self, data_loaders_gaussian, flow: flows.BaseFlow ): input_dim = 2 - condition_dim = 0 neural_vf = models.VelocityField( output_dim=2, condition_dim=0, latent_embed_dim=5, ) - ot_solver = sinkhorn.Sinkhorn() - ot_matcher = base_solver.OTMatcherLinear(ot_solver) - time_sampler = samplers.uniform_sampler optimizer = optax.adam(learning_rate=1e-3) fm = otfm.OTFlowMatching( + input_dim, neural_vf, - input_dim=input_dim, - cond_dim=condition_dim, - ot_matcher=ot_matcher, flow=flow, - time_sampler=time_sampler, + time_sampler=samplers.uniform_sampler, + match_fn=utils.match_linear, optimizer=optimizer, ) - fm( - data_loaders_gaussian[0], - data_loaders_gaussian[1], - data_loaders_gaussian[0], - data_loaders_gaussian[1], + _ = fm( n_iters=2, - valid_freq=3 + train_source=data_loaders_gaussian[0], + train_target=data_loaders_gaussian[1], ) + # TODO(michalk8): nicer batch_src = next(iter(data_loaders_gaussian[0])) source = jnp.asarray(batch_src["lin"]) batch_tgt = next(iter(data_loaders_gaussian[1])) @@ -74,16 +67,14 @@ def test_flow_matching_unconditional( source_conditions = jnp.asarray( batch_src["conditions"] ) if "conditions" in batch_src else None - result_forward = fm.transport( - source, condition=source_conditions, forward=True - ) - assert isinstance(result_forward, jnp.ndarray) + + result_forward = fm.transport(source, condition=source_conditions) + # TODO(michalk8): better condition assert jnp.sum(jnp.isnan(result_forward)) == 0 result_backward = fm.transport( - target, condition=source_conditions, forward=False + target, condition=source_conditions, t0=1.0, t1=0.0 ) - assert 
isinstance(result_backward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_backward)) == 0 @pytest.mark.parametrize( @@ -94,34 +85,29 @@ def test_flow_matching_unconditional( ] ) def test_flow_matching_with_conditions( - self, data_loader_gaussian_with_conditions, flow: Type[flows.BaseFlow] + self, data_loader_gaussian_with_conditions, flow: flows.BaseFlow ): - input_dim = 2 - condition_dim = 1 + input_dim, cond_dim = 2, 1 neural_vf = models.VelocityField( output_dim=input_dim, - condition_dim=condition_dim, + condition_dim=cond_dim, latent_embed_dim=5, ) - ot_solver = sinkhorn.Sinkhorn() - ot_matcher = base_solver.OTMatcherLinear(ot_solver) time_sampler = functools.partial(samplers.uniform_sampler, offset=1e-5) optimizer = optax.adam(learning_rate=1e-3) fm = otfm.OTFlowMatching( + 2, neural_vf, - input_dim=2, - cond_dim=1, - ot_matcher=ot_matcher, + match_fn=utils.match_linear, flow=flow, time_sampler=time_sampler, optimizer=optimizer, ) - fm( - data_loader_gaussian_with_conditions, - data_loader_gaussian_with_conditions, + _ = fm( n_iters=2, - valid_freq=3 + train_source=data_loader_gaussian_with_conditions, + train_target=data_loader_gaussian_with_conditions, ) batch = next(iter(data_loader_gaussian_with_conditions)) @@ -130,16 +116,13 @@ def test_flow_matching_with_conditions( source_conditions = jnp.asarray(batch["source_conditions"]) if len( batch["source_conditions"] ) > 0 else None - result_forward = fm.transport( - source, condition=source_conditions, forward=True - ) - assert isinstance(result_forward, jnp.ndarray) + + result_forward = fm.transport(source, condition=source_conditions) assert jnp.sum(jnp.isnan(result_forward)) == 0 result_backward = fm.transport( - target, condition=source_conditions, forward=False + target, condition=source_conditions, t0=1.0, t1=0.0 ) - assert isinstance(result_backward, jnp.ndarray) assert jnp.sum(jnp.isnan(result_backward)) == 0 @pytest.mark.parametrize( From 82bc7e687eb23cc98a8ce0e3d1f95494db6b7fb5 Mon Sep 17 00:00:00 
2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 11 Mar 2024 20:47:18 +0100 Subject: [PATCH 114/186] Add conditional sampling + resampling --- src/ott/neural/flow_models/otfm.py | 9 +++--- src/ott/neural/flow_models/utils.py | 44 ++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index 699bf2960..a1f60f849 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -24,7 +24,7 @@ from ott import utils from ott.neural.flow_models import flows, models -from ott.neural.flow_models.utils import sample_joint +from ott.neural.flow_models.utils import resample_data, sample_joint __all__ = ["OTFlowMatching"] @@ -133,9 +133,10 @@ def __call__( # noqa: D102 if self.match_fn is not None: tmat = self.match_fn(source, target) src_ixs, tgt_ixs = sample_joint(rng_resample, tmat) - source, target = source[src_ixs], target[tgt_ixs] - if source_conditions is not None: - source_conditions = source_conditions[src_ixs] + source, source_conditions = resample_data( + source, source_conditions, ixs=src_ixs + ) + target = resample_data(target, ixs=tgt_ixs) self.vf_state, loss = self.step_fn( rng_step_fn, self.vf_state, source, target, source_conditions diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py index 797370dce..25598c62e 100644 --- a/src/ott/neural/flow_models/utils.py +++ b/src/ott/neural/flow_models/utils.py @@ -2,11 +2,14 @@ import jax import jax.numpy as jnp +import jax.tree_util as jtu from ott.geometry import costs, pointcloud from ott.solvers import linear -__all__ = ["match_linear", "sample_joint"] +__all__ = [ + "match_linear", "sample_joint", "sample_conditional", "resample_data" +] def match_linear( @@ -34,3 +37,42 @@ def sample_joint(rng: jax.Array, src_ixs = indices // m tgt_ixs = indices % m return src_ixs, tgt_ixs + + +def sample_conditional( + rng: 
jax.Array, + tmat: jnp.ndarray, + *, + k: int = 1, + uniform_marginals: bool = False, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """TODO.""" + assert k > 0, "Number of samples per row must be positive." + n, m = tmat.shape + + if uniform_marginals: + indices = jnp.arange(n) + else: + src_marginals = tmat.sum(axis=1) + rng, rng_ixs = jax.random.split(rng, 2) + indices = jax.random.choice( + rng_ixs, a=n, p=src_marginals, shape=(len(src_marginals),) + ) + tmat = tmat[indices] + + tgt_ixs = jax.vmap( + lambda row: jax.random.choice(rng, a=m, p=row, shape=(k,)) + )(tmat) # (m, k) + + src_ixs = jnp.repeat(indices[:, None], k, axis=1) # (n, k) + return src_ixs, tgt_ixs + + +def resample_data(*data: Optional[jnp.ndarray], + ixs: jnp.ndarray) -> Tuple[Optional[jnp.ndarray], ...]: + """TODO.""" + if ixs.ndim == 2: + ixs = ixs.reshape(-1) + assert ixs.ndim == 1, ixs.shape + data = jtu.tree_map(lambda arr: None if arr is None else arr[ixs], data) + return data[0] if len(data) == 1 else data From f430c2937112ef86e50438465cf8d41ad74c5ed1 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 11 Mar 2024 20:54:07 +0100 Subject: [PATCH 115/186] Add initial quad matcher --- src/ott/neural/flow_models/utils.py | 36 +++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py index 25598c62e..6e1c01f17 100644 --- a/src/ott/neural/flow_models/utils.py +++ b/src/ott/neural/flow_models/utils.py @@ -5,10 +5,14 @@ import jax.tree_util as jtu from ott.geometry import costs, pointcloud -from ott.solvers import linear +from ott.solvers import linear, quadratic __all__ = [ - "match_linear", "sample_joint", "sample_conditional", "resample_data" + "match_linear", + "match_quadratic", + "sample_joint", + "sample_conditional", + "resample_data", ] @@ -17,15 +21,39 @@ def match_linear( y: jnp.ndarray, cost_fn: Optional[costs.CostFn] = None, epsilon: 
Optional[float] = None, - # TODO(michalk8): expose rest of the geom arguments + # TODO(michalk8): type this correctly + scale_cost: float = 1.0, **kwargs: Any ) -> jnp.ndarray: """TODO.""" - geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=epsilon) + geom = pointcloud.PointCloud( + x, y, cost_fn=cost_fn, epsilon=epsilon, scale_cost=scale_cost + ) out = linear.solve(geom, **kwargs) return out.matrix +def match_quadratic( + xx: jnp.ndarray, + yy: jnp.ndarray, + xy: Optional[jnp.ndarray] = None, + # TODO(michalk8): expose for all the costs + scale_cost: float = 1.0, + cost_fn: Optional[costs.CostFn] = None, + **kwargs: Any +) -> jnp.ndarray: + """TODO.""" + geom_xx = pointcloud.PointCloud(xx, cost_fn=cost_fn, scale_cost=scale_cost) + geom_yy = pointcloud.PointCloud(yy, cost_fn=cost_fn, scale_cost=scale_cost) + if xy is None: + geom_xy = None + else: + geom_xy = pointcloud.PointCloud(xy, cost_fn=cost_fn, scale_cost=scale_cost) + + out = quadratic.solve(geom_xx, geom_yy, geom_xy, **kwargs) + return out.matrix + + def sample_joint(rng: jax.Array, tmat: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: """TODO.""" From 4b41f0ccbc7d60e9bdd96bc540d957de38d05712 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 11 Mar 2024 20:54:54 +0100 Subject: [PATCH 116/186] Improve typing --- src/ott/neural/flow_models/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py index 6e1c01f17..d6d023266 100644 --- a/src/ott/neural/flow_models/utils.py +++ b/src/ott/neural/flow_models/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Tuple +from typing import Any, Literal, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -15,14 +15,15 @@ "resample_data", ] +ScaleCost_t = Union[float, Literal["mean", "max_cost", "median"]] + def match_linear( x: jnp.ndarray, y: jnp.ndarray, cost_fn: Optional[costs.CostFn] = 
None, epsilon: Optional[float] = None, - # TODO(michalk8): type this correctly - scale_cost: float = 1.0, + scale_cost: ScaleCost_t = 1.0, **kwargs: Any ) -> jnp.ndarray: """TODO.""" @@ -38,7 +39,7 @@ def match_quadratic( yy: jnp.ndarray, xy: Optional[jnp.ndarray] = None, # TODO(michalk8): expose for all the costs - scale_cost: float = 1.0, + scale_cost: ScaleCost_t = 1.0, cost_fn: Optional[costs.CostFn] = None, **kwargs: Any ) -> jnp.ndarray: From cc2746b5de1377028f83bec3c69c3663ff0aaba7 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 11 Mar 2024 21:02:06 +0100 Subject: [PATCH 117/186] Remove `base_solver.py` --- src/ott/neural/flow_models/genot.py | 7 +- src/ott/neural/models/__init__.py | 2 +- src/ott/neural/models/base_solver.py | 308 --------------------------- tests/neural/otfm_test.py | 7 +- 4 files changed, 7 insertions(+), 317 deletions(-) delete mode 100644 src/ott/neural/models/base_solver.py diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index 7291b8fdc..88be0da38 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -25,7 +25,6 @@ from ott import utils from ott.neural.flow_models import flows, samplers -from ott.neural.models import base_solver __all__ = ["GENOTBase", "GENOTLin", "GENOTQuad"] @@ -88,13 +87,15 @@ def __init__( output_dim: int, cond_dim: int, valid_freq: int, - ot_matcher: base_solver.BaseOTMatcher, + # TODO(michalk8) + ot_matcher: Any, optimizer: optax.GradientTransformation, flow: Type[flows.BaseFlow] = flows.ConstantNoiseFlow(0.0), # noqa: B008 time_sampler: Callable[[jax.Array, int], jnp.ndarray] = samplers.uniform_sampler, k_samples_per_x: int = 1, - matcher_latent_to_data: Optional[base_solver.OTMatcherLinear] = None, + # TODO(michalk8) + matcher_latent_to_data: Optional[Callable] = None, kwargs_solver_latent_to_data: Dict[str, Any] = types.MappingProxyType({}), fused_penalty: float = 0.0, callback_fn: 
Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], diff --git a/src/ott/neural/models/__init__.py b/src/ott/neural/models/__init__.py index ba39ae8b4..83287aec5 100644 --- a/src/ott/neural/models/__init__.py +++ b/src/ott/neural/models/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import base_solver, layers, nets +from . import layers, nets diff --git a/src/ott/neural/models/base_solver.py b/src/ott/neural/models/base_solver.py deleted file mode 100644 index 5ddfd5ef5..000000000 --- a/src/ott/neural/models/base_solver.py +++ /dev/null @@ -1,308 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Any, Callable, Dict, Literal, Mapping, Optional, Tuple, Union - -import jax -import jax.numpy as jnp -from jax import tree_util - -from ott.geometry import costs, pointcloud -from ott.problems.linear import linear_problem -from ott.problems.quadratic import quadratic_problem -from ott.solvers.linear import sinkhorn -from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr - -ScaleCost_t = Union[int, float, Literal["mean", "max_cost", "median"]] -ScaleCostQuad_t = Union[ScaleCost_t, Dict[str, ScaleCost_t]] - -__all__ = [ - "BaseOTMatcher", - "OTMatcherLinear", - "OTMatcherQuad", -] - - -def _get_sinkhorn_match_fn( - ot_solver: Any, - epsilon: float = 1e-2, - cost_fn: Optional[costs.CostFn] = None, - scale_cost: ScaleCost_t = 1.0, - tau_a: float = 1.0, - tau_b: float = 1.0, -) -> Callable: - - @jax.jit - def match_pairs(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: - geom = pointcloud.PointCloud( - x, y, epsilon=epsilon, scale_cost=scale_cost, cost_fn=cost_fn - ) - return ot_solver( - linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b) - ) - - return match_pairs - - -def _get_gromov_match_fn( - ot_solver: Any, - cost_fn: Union[Any, Mapping[str, Any]], - scale_cost: ScaleCostQuad_t, - tau_a: float, - tau_b: float, - fused_penalty: float, -) -> Callable: - if isinstance(cost_fn, Mapping): - assert "cost_fn_xx" in cost_fn - assert "cost_fn_yy" in cost_fn - cost_fn_xx = cost_fn["cost_fn_xx"] - cost_fn_yy = cost_fn["cost_fn_yy"] - if fused_penalty > 0: - assert "cost_fn_xy" in cost_fn_xx - cost_fn_xy = cost_fn["cost_fn_xy"] - else: - cost_fn_xx = cost_fn_yy = cost_fn_xy = cost_fn - - if isinstance(scale_cost, Mapping): - assert "scale_cost_xx" in scale_cost - assert "scale_cost_yy" in scale_cost - scale_cost_xx = scale_cost["scale_cost_xx"] - scale_cost_yy = scale_cost["scale_cost_yy"] - if fused_penalty > 0: - assert "scale_cost_xy" in scale_cost - scale_cost_xy = cost_fn["scale_cost_xy"] - else: - scale_cost_xx = 
scale_cost_yy = scale_cost_xy = scale_cost - - @jax.jit - def match_pairs( - x_quad: jnp.ndarray, - y_quad: jnp.ndarray, - x_lin: Optional[jnp.ndarray], - y_lin: Optional[jnp.ndarray], - ) -> jnp.ndarray: - geom_xx = pointcloud.PointCloud( - x=x_quad, y=x_quad, cost_fn=cost_fn_xx, scale_cost=scale_cost_xx - ) - geom_yy = pointcloud.PointCloud( - x=y_quad, y=y_quad, cost_fn=cost_fn_yy, scale_cost=scale_cost_yy - ) - if fused_penalty > 0: - geom_xy = pointcloud.PointCloud( - x=x_lin, y=y_lin, cost_fn=cost_fn_xy, scale_cost=scale_cost_xy - ) - else: - geom_xy = None - prob = quadratic_problem.QuadraticProblem( - geom_xx, - geom_yy, - geom_xy, - fused_penalty=fused_penalty, - tau_a=tau_a, - tau_b=tau_b - ) - return ot_solver(prob) - - return match_pairs - - -class BaseOTMatcher: - """Base class for mini-batch neural OT matching classes.""" - - def sample_joint( - self, - rng: jax.Array, - joint_dist: jnp.ndarray, - source_arrays: Tuple[Optional[jnp.ndarray], ...], - target_arrays: Tuple[Optional[jnp.ndarray], ...], - ) -> Tuple[jnp.ndarray, ...]: - """Resample from arrays according to discrete joint distribution. - - Args: - rng: Random number generator. - joint_dist: Joint distribution between source and target to sample from. - source_arrays: Arrays corresponding to source distriubution to sample - from. - target_arrays: Arrays corresponding to target arrays to sample from. - - Returns: - Resampled source and target arrays. 
- """ - _, n_tgt = joint_dist.shape - tmat_flattened = joint_dist.flatten() - indices = jax.random.choice( - rng, len(tmat_flattened), p=tmat_flattened, shape=[joint_dist.shape[0]] - ) - indices_source = indices // n_tgt - indices_target = indices % n_tgt - return tree_util.tree_map(lambda b: b[indices_source], - source_arrays), tree_util.tree_map( - lambda b: b[indices_target], target_arrays - ) - - def sample_conditional_indices_from_tmap( - self, - rng: jax.Array, - conditional_distributions: jnp.ndarray, - *, - k_samples_per_x: int, - source_arrays: Tuple[Optional[jnp.ndarray], ...], - target_arrays: Tuple[Optional[jnp.ndarray], ...], - source_is_balanced: bool, - ) -> Tuple[jnp.ndarray, ...]: - """Sample from arrays according to discrete conditional distributions. - - Args: - rng: Random number generator. - conditional_distributions: Conditional distributions to sample from. - k_samples_per_x: Expectation of number of samples to draw from each - conditional distribution. - source_arrays: Arrays corresponding to source distriubution to sample - from. - target_arrays: Arrays corresponding to target arrays to sample from. - source_is_balanced: Whether the source distribution is balanced. - If :obj:`False`, the number of samples drawn from each conditional - distribution `k_samples_per_x` is proportional to the left marginals. - - Returns: - Resampled source and target arrays. - """ - n_src, n_tgt = conditional_distributions.shape - left_marginals = conditional_distributions.sum(axis=1) - if not source_is_balanced: - rng, rng_2 = jax.random.split(rng, 2) - indices = jax.random.choice( - key=rng_2, - a=jnp.arange(len(left_marginals)), - p=left_marginals, - shape=(len(left_marginals),) - ) - else: - indices = jnp.arange(n_src) - tmat_adapted = conditional_distributions[indices] - indices_per_row = jax.vmap( - lambda row: jax.random. 
- choice(key=rng, a=n_tgt, p=row, shape=(k_samples_per_x,)), - in_axes=0, - out_axes=0, - )( - tmat_adapted - ) - - indices_source = jnp.repeat(indices, k_samples_per_x) - indices_target = jnp.reshape( - indices_per_row % n_tgt, (n_src * k_samples_per_x,) - ) - return tree_util.tree_map( - lambda b: jnp. - reshape(b[indices_source], - (k_samples_per_x, n_src, *b.shape[1:])), source_arrays - ), tree_util.tree_map( - lambda b: jnp. - reshape(b[indices_target], - (k_samples_per_x, n_src, *b.shape[1:])), target_arrays - ) - - -class OTMatcherLinear(BaseOTMatcher): - """Class for mini-batch OT in neural optimal transport solvers. - - Args: - ot_solver: OT solver to match samples from the source and the target - distribution as proposed in :cite:`tong:23`, :cite:`pooladian:23`. - If :obj:`None`, no matching will be performed as proposed in - :cite:`lipman:22`. - """ - - def __init__( - self, - ot_solver: sinkhorn.Sinkhorn, - epsilon: float = 1e-2, - cost_fn: Optional[costs.CostFn] = None, - scale_cost: ScaleCost_t = 1.0, - tau_a: float = 1.0, - tau_b: float = 1.0, - ) -> None: - - if isinstance( - ot_solver, gromov_wasserstein.GromovWasserstein - ) and epsilon is not None: - raise ValueError( - "If `ot_solver` is `GromovWasserstein`, `epsilon` must be `None`. " + - "This check is performed to ensure that in the (fused) Gromov case " + - "the `epsilon` parameter is passed via the `ot_solver`." 
- ) - self.ot_solver = ot_solver - self.epsilon = epsilon - self.cost_fn = cost_fn - self.scale_cost = scale_cost - self.tau_a = tau_a - self.tau_b = tau_b - self.match_fn = None if ot_solver is None else self._get_sinkhorn_match_fn( - self.ot_solver, self.epsilon, self.cost_fn, self.scale_cost, self.tau_a, - self.tau_b - ) - - def _get_sinkhorn_match_fn(self, *args, **kwargs) -> jnp.ndarray: - fn = _get_sinkhorn_match_fn(*args, **kwargs) - - @jax.jit - def match_pairs(*args, **kwargs): - return fn(*args, **kwargs).matrix - - return match_pairs - - -class OTMatcherQuad(BaseOTMatcher): - """Class for mini-batch OT in neural optimal transport solvers. - - Args: - ot_solver: OT solver to match samples from the source and the target - distribution as proposed in :cite:`tong:23`, :cite:`pooladian:23`. - If :obj:`None`, no matching will be performed as proposed in - :cite:`lipman:22`. - """ - - def __init__( - self, - ot_solver: Union[gromov_wasserstein.GromovWasserstein, - gromov_wasserstein_lr.LRGromovWasserstein], - cost_fn: Optional[costs.CostFn] = None, - scale_cost: ScaleCostQuad_t = 1.0, - tau_a: float = 1.0, - tau_b: float = 1.0, - fused_penalty: float = 0.0, - ) -> None: - self.ot_solver = ot_solver - self.cost_fn = cost_fn - self.scale_cost = scale_cost - self.tau_a = tau_a - self.tau_b = tau_b - self.fused_penalty = fused_penalty - self.match_fn = self._get_gromov_match_fn( - self.ot_solver, - self.cost_fn, - self.scale_cost, - self.tau_a, - self.tau_b, - fused_penalty=self.fused_penalty - ) - - def _get_gromov_match_fn(self, *args, **kwargs) -> jnp.ndarray: - fn = _get_gromov_match_fn(*args, **kwargs) - - @jax.jit - def match_pairs(*args, **kwargs): - return fn(*args, **kwargs).matrix - - return match_pairs diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 01042ec39..f4c0fc1d6 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -21,8 +21,6 @@ import optax from ott.neural.flow_models import flows, models, otfm, 
samplers, utils -from ott.neural.models import base_solver -from ott.solvers.linear import sinkhorn, sinkhorn_lr class TestOTFlowMatching: @@ -145,9 +143,8 @@ def test_flow_matching_conditional( condition_dim=condition_dim, latent_embed_dim=5, ) - ot_solver = sinkhorn.Sinkhorn( - ) if solver == "sinkhorn" else sinkhorn_lr.LRSinkhorn() - ot_matcher = base_solver.OTMatcherLinear(ot_solver) + # TODO(michalk8): check for LR + ot_matcher = utils.match_linear time_sampler = samplers.uniform_sampler optimizer = optax.adam(learning_rate=1e-3) From 1068410a562dfa05501a9e2d2362795e462e301b Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 11 Mar 2024 21:08:21 +0100 Subject: [PATCH 118/186] Add TODO --- src/ott/neural/data/datasets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ott/neural/data/datasets.py b/src/ott/neural/data/datasets.py index d13237cc9..9a73677c1 100644 --- a/src/ott/neural/data/datasets.py +++ b/src/ott/neural/data/datasets.py @@ -71,6 +71,7 @@ class ConditionalOTDataset: def __init__( self, + # TODO(michalk8): generalize the type datasets: Iterable[OTDataset], weights: Iterable[float] = None, seed: Optional[int] = None, From e5597402e1bcfc047063db906f53ac2e87ac2a6b Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 13 Mar 2024 14:04:38 +0100 Subject: [PATCH 119/186] Update datasets, fix OTFM tests --- docs/neural/data.rst | 2 +- src/ott/neural/data/datasets.py | 157 +++++++++++++++-------- src/ott/neural/flow_models/otfm.py | 48 +++---- tests/neural/conftest.py | 199 ++++++++++++++--------------- tests/neural/genot_test.py | 1 - tests/neural/otfm_test.py | 186 +++++++++------------------ 6 files changed, 289 insertions(+), 304 deletions(-) diff --git a/docs/neural/data.rst b/docs/neural/data.rst index 95f05f93f..25172dcd3 100644 --- a/docs/neural/data.rst +++ b/docs/neural/data.rst @@ -12,4 +12,4 @@ Datasets :toctree: _autosummary 
datasets.OTDataset - datasets.ConditionalOTDataset + datasets.ConditionalLoader diff --git a/src/ott/neural/data/datasets.py b/src/ott/neural/data/datasets.py index 9a73677c1..0a2067f25 100644 --- a/src/ott/neural/data/datasets.py +++ b/src/ott/neural/data/datasets.py @@ -11,86 +11,139 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Iterable, Optional +import collections +import dataclasses +from typing import Any, Dict, Iterable, Optional, Sequence -import jax.tree_util as jtu import numpy as np -__all__ = ["OTDataset", "ConditionalOTDataset"] +__all__ = ["OTData", "OTDataset", "ConditionalLoader"] +Item_t = Dict[str, np.ndarray] -class OTDataset: - """Dataset for Optimal transport problems. - Args: - lin: Linear part of the measure. - quad: Quadratic part of the measure. - conditions: Conditions of the source measure. - """ +@dataclasses.dataclass(repr=False, frozen=True) +class OTData: + """TODO.""" + lin: Optional[np.ndarray] = None + quad: Optional[np.ndarray] = None + condition: Optional[np.ndarray] = None + + def __getitem__(self, ix: int) -> Item_t: + return {k: v[ix] for k, v in self.__dict__.items() if v is not None} + + def __len__(self) -> int: + if self.lin is not None: + return len(self.lin) + if self.quad is not None: + return len(self.quad) + return 0 + + +class OTDataset: + """TODO.""" + SRC_PREFIX = "src" + TGT_PREFIX = "tgt" def __init__( self, - lin: Optional[np.ndarray] = None, - quad: Optional[np.ndarray] = None, - conditions: Optional[np.ndarray] = None, + src_data: OTData, + tgt_data: OTData, + src_conditions: Optional[Sequence[Any]] = None, + tgt_conditions: Optional[Sequence[Any]] = None, + is_aligned: bool = False, + seed: Optional[int] = None ): - self.data = {} - if lin is not None: - self.data["lin"] = lin - if quad is not None: - self.data["quad"] = quad - if conditions is not 
None: - self.data["conditions"] = conditions - self._check_sizes() - - def _check_sizes(self) -> None: - sizes = {k: len(v) for k, v in self.data.items()} - if not len(set(sizes.values())) == 1: - raise ValueError(f"Not all arrays have the same size: {sizes}.") - - def __getitem__(self, idx: np.ndarray) -> Dict[str, np.ndarray]: - return jtu.tree_map(lambda x: x[idx], self.data) + self.src_data = src_data + self.tgt_data = tgt_data + + if src_conditions is None: + src_conditions = [None] * len(src_data) + self.src_conditions = list(src_conditions) + if tgt_conditions is None: + tgt_conditions = [None] * len(tgt_data) + self.tgt_conditions = list(tgt_conditions) + + self._tgt_cond_to_ix = collections.defaultdict(list) + for ix, cond in enumerate(tgt_conditions): + self._tgt_cond_to_ix[cond].append(ix) + + self.is_aligned = is_aligned + self._rng = np.random.default_rng(seed) + + self._verify_integriy() + + def _verify_integriy(self) -> None: + assert len(self.src_data) == len(self.src_conditions) + assert len(self.src_data) == len(self.tgt_conditions) + + if self.is_aligned: + assert len(self.src_data) == len(self.tgt_data) + assert self.src_conditions == self.tgt_conditions + else: + sym_diff = set(self.src_conditions + ).symmetric_difference(self.tgt_conditions) + assert not sym_diff, sym_diff + + def _sample_from_target(self, src_ix: int) -> Item_t: + src_cond = self.src_conditions[src_ix] + tgt_ixs = self._tgt_cond_to_ix[src_cond] + ix = self._rng.choice(tgt_ixs) + return self.src_data[ix] + + def __getitem__(self, ix: int) -> Item_t: + src = self.src_data[ix] + src = {f"{self.SRC_PREFIX}_{k}": v for k, v in src.items()} + + tgt = self.src_data[ix] if self.is_aligned else self._sample_from_target(ix) + tgt = {f"{self.TGT_PREFIX}_{k}": v for k, v in tgt.items()} + + return {**src, **tgt} def __len__(self) -> int: - for v in self.data.values(): - return len(v) - return 0 + return len(self.src_data) -# TODO(michalk8): rename -class ConditionalOTDataset: +class 
ConditionalLoader: """Dataset for OT problems with conditions. This data loader wraps several data loaders and samples from them. Args: datasets: Datasets to sample from. - weights: TODO. seed: Random seed. """ def __init__( self, - # TODO(michalk8): generalize the type datasets: Iterable[OTDataset], - weights: Iterable[float] = None, seed: Optional[int] = None, ): self.datasets = tuple(datasets) - - if weights is None: - weights = np.ones(len(self.datasets)) - weights = np.asarray(weights) - self.weights = weights / np.sum(weights) - assert len(self.weights) == len(self.datasets), "TODO" - self._rng = np.random.default_rng(seed) - self._iterators = () - - def __next__(self) -> Dict[str, np.ndarray]: - idx = self._rng.choice(len(self._iterators), p=self.weights) - return next(self._iterators[idx]) - - def __iter__(self) -> "ConditionalOTDataset": - self._iterators = tuple(iter(ds) for ds in self.datasets) + self._iterators = [] + self._it = 0 + + def __next__(self) -> Item_t: + if self._it == len(self): + raise StopIteration + + ix = self._rng.choice(len(self._iterators)) + iterator = self._iterators[ix] + try: + data = next(iterator) + # TODO(michalk8): improve the logic a bit + self._it += 1 + return data + except StopIteration: + self._iterators[ix] = iter(self.datasets[ix]) + if not self._iterators: + raise + + def __iter__(self) -> "ConditionalLoader": + self._iterators = [iter(ds) for ds in self.datasets] + self._it = 0 return self + + def __len__(self) -> int: + return max((len(ds) for ds in self.datasets), default=0) diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index a1f60f849..83d709a96 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -111,37 +111,39 @@ def loss_fn( # TODO(michalk8): refactor in the future PR to just do one step def __call__( # noqa: D102 self, + loader: Any, # TODO(michalk8): type it correctly + *, n_iters: int, - train_source, - train_target, rng: 
Optional[jax.Array] = None, ) -> Dict[str, Any]: rng = utils.default_prng_key(rng) training_logs = {"loss": []} - for _ in range(n_iters): - for batch_source, batch_target in zip(train_source, train_target): - rng, rng_resample, rng_step_fn = jax.random.split(rng, 3) + for batch in loader: + rng, rng_resample, rng_step_fn = jax.random.split(rng, 3) - batch_source = jtu.tree_map(jnp.asarray, batch_source) - batch_target = jtu.tree_map(jnp.asarray, batch_target) + batch = jtu.tree_map(jnp.asarray, batch) - source = batch_source["lin"] - source_conditions = batch_source.get("conditions", None) - target = batch_target["lin"] + src, tgt = batch["src_lin"], batch["tgt_lin"] + src_conds = batch.get("src_condition", None) - if self.match_fn is not None: - tmat = self.match_fn(source, target) - src_ixs, tgt_ixs = sample_joint(rng_resample, tmat) - source, source_conditions = resample_data( - source, source_conditions, ixs=src_ixs - ) - target = resample_data(target, ixs=tgt_ixs) + if self.match_fn is not None: + tmat = self.match_fn(src, tgt) + src_ixs, tgt_ixs = sample_joint(rng_resample, tmat) + src, src_conds = resample_data(src, src_conds, ixs=src_ixs) + tgt = resample_data(tgt, ixs=tgt_ixs) - self.vf_state, loss = self.step_fn( - rng_step_fn, self.vf_state, source, target, source_conditions - ) - training_logs["loss"].append(float(loss)) + self.vf_state, loss = self.step_fn( + rng_step_fn, + self.vf_state, + src, + tgt, + src_conds, + ) + + training_logs["loss"].append(float(loss)) + if len(training_logs["loss"]) >= n_iters: + break return training_logs @@ -159,8 +161,8 @@ def transport( parameterized by the velocity field. Args: - x: Initial condition of the ODE of shape `(batch_size, ...)`. - condition: Condition of the input data of shape `(batch_size, ...)`. + x: Initial condition of the ODE of shape ``[batch_size, ...]``. + condition: Condition of the input data of shape ``[batch_size, ...]``. t0: Starting point of integration. t1: End point of integration. 
kwargs: Keyword arguments for the ODE solver. diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index f5c48e924..04f9917a8 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Optional, Union import pytest @@ -22,71 +22,73 @@ from ott.neural.data import datasets -@pytest.fixture(scope="module") -def data_loaders_gaussian() -> Tuple[DataLoader, DataLoader]: +def _ot_data( + rng: np.random.Generator, + *, + n: int = 100, + dim: int = 2, + condition: Optional[Union[float, np.ndarray]] = None, + cond_dim: Optional[int] = None, + offset: float = 0.0 +) -> datasets.OTData: + data = rng.normal(size=(n, dim)) + offset + + if isinstance(condition, float): + cond_dim = dim if cond_dim is None else cond_dim + condition = np.full((n, cond_dim), fill_value=condition) + + return datasets.OTData(lin=data, condition=condition) + + +@pytest.fixture() +def lin_dl() -> DataLoader: """Returns a data loader for a simple Gaussian mixture.""" - rng = np.random.default_rng(seed=0) - source = rng.normal(size=(100, 2)) - target = rng.normal(size=(100, 2)) + 1.0 - src_dataset = datasets.OTDataset(lin=source) - tgt_dataset = datasets.OTDataset(lin=target) - loader_src = DataLoader(src_dataset, batch_size=16, shuffle=True) - loader_tgt = DataLoader(tgt_dataset, batch_size=16, shuffle=True) - return loader_src, loader_tgt + n, d = 100, 2 + rng = np.random.default_rng(0) + src, tgt = _ot_data(rng, n=n, dim=d), _ot_data(rng, n=n, dim=d, offset=1.0) + ds = datasets.OTDataset(src, tgt) + return DataLoader(ds, batch_size=16, shuffle=True) -@pytest.fixture(scope="module") -def data_loader_gaussian_conditional(): - """Returns a data loader for Gaussian mixtures with conditions.""" - rng = np.random.default_rng(seed=0) - source_0 = 
rng.normal(size=(100, 2)) - target_0 = rng.normal(size=(100, 2)) + 2.0 +@pytest.fixture() +def lin_dl_with_conds() -> DataLoader: + n, d = 100, 2 + rng = np.random.default_rng(13) - source_1 = rng.normal(size=(100, 2)) - target_1 = rng.normal(size=(100, 2)) - 2.0 - ds0 = datasets.OTDataset( - lin=source_0, - target_lin=target_0, - conditions=np.zeros_like(source_0) * 0.0 - ) - ds1 = datasets.OTDataset( - lin=source_1, - target_lin=target_1, - conditions=np.ones_like(source_1) * 1.0 - ) - sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) - sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) - dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0) - dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1) + src_cond, tgt_cond = rng.normal(size=(n, 1)), rng.normal(size=(n, 1)) + src = _ot_data(rng, n=n, dim=d, condition=src_cond) + tgt = _ot_data(rng, n=n, dim=d, condition=tgt_cond) - return datasets.ConditionalOTDataset((dl0, dl1)) + ds = datasets.OTDataset(src, tgt) + return DataLoader(ds, batch_size=16, shuffle=True) -@pytest.fixture(scope="module") -def data_loader_gaussian_with_conditions(): - """Returns a data loader for a simple Gaussian mixture with conditions.""" - rng = np.random.default_rng(seed=0) - source = rng.normal(size=(100, 2)) - target = rng.normal(size=(100, 2)) + 1.0 - source_conditions = rng.normal(size=(100, 1)) - target_conditions = rng.normal(size=(100, 1)) - 1.0 +@pytest.fixture() +def conditional_lin_dl() -> datasets.ConditionalLoader: + rng = np.random.default_rng(42) - dataset = datasets.OTDataset( - lin=source, - target_lin=target, - conditions=source_conditions, - target_conditions=target_conditions - ) - return DataLoader(dataset, batch_size=16, shuffle=True) + src0, tgt0 = _ot_data(rng, condition=0.0), _ot_data(rng, offset=2.0) + src1, tgt1 = _ot_data(rng, condition=1.0), _ot_data(rng, offset=-2.0) + + src_ds = datasets.OTDataset(src0, tgt0) + tgt_ds = datasets.OTDataset(src1, tgt1) + + src_dl = 
DataLoader(src_ds, batch_size=16, shuffle=True) + tgt_dl = DataLoader(tgt_ds, batch_size=16, shuffle=True) + + return datasets.ConditionalLoader([src_dl, tgt_dl]) + + +# TODO(michalk8): refactor the below for GENOT @pytest.fixture(scope="module") def genot_data_loader_linear(): """Returns a data loader for a simple Gaussian mixture.""" rng = np.random.default_rng(seed=0) - source = rng.normal(size=(100, 2)) - target = rng.normal(size=(100, 2)) + 1.0 - dataset = datasets.OTDataset(lin=source, target_lin=target) + src = rng.normal(size=(100, 2)) + tgt = rng.normal(size=(100, 2)) + 1.0 + dataset = datasets.OTDataset(lin=src, tgt_lin=tgt) return DataLoader(dataset, batch_size=16, shuffle=True) @@ -94,35 +96,31 @@ def genot_data_loader_linear(): def genot_data_loader_linear_conditional(): """Returns a data loader for a simple Gaussian mixture.""" rng = np.random.default_rng(seed=0) - source_0 = rng.normal(size=(100, 2)) - target_0 = rng.normal(size=(100, 2)) + 1.0 - source_1 = rng.normal(size=(100, 2)) - target_1 = rng.normal(size=(100, 2)) + 1.0 + src_0 = rng.normal(size=(100, 2)) + tgt_0 = rng.normal(size=(100, 2)) + 1.0 + src_1 = rng.normal(size=(100, 2)) + tgt_1 = rng.normal(size=(100, 2)) + 1.0 ds0 = datasets.OTDataset( - lin=source_0, - target_lin=target_0, - conditions=np.zeros_like(source_0) * 0.0 + lin=src_0, tgt_lin=tgt_0, conditions=np.zeros_like(src_0) * 0.0 ) ds1 = datasets.OTDataset( - lin=source_1, - target_lin=target_1, - conditions=np.ones_like(source_1) * 1.0 + lin=src_1, tgt_lin=tgt_1, conditions=np.ones_like(src_1) * 1.0 ) sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0) dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1) - return datasets.ConditionalOTDataset((dl0, dl1)) + return datasets.ConditionalLoader((dl0, dl1)) @pytest.fixture(scope="module") def genot_data_loader_quad(): """Returns a data loader for a 
simple Gaussian mixture.""" rng = np.random.default_rng(seed=0) - source = rng.normal(size=(100, 2)) - target = rng.normal(size=(100, 1)) + 1.0 - dataset = datasets.OTDataset(quad=source, target_quad=target) + src = rng.normal(size=(100, 2)) + tgt = rng.normal(size=(100, 1)) + 1.0 + dataset = datasets.OTDataset(quad=src, tgt_quad=tgt) return DataLoader(dataset, batch_size=16, shuffle=True) @@ -130,41 +128,34 @@ def genot_data_loader_quad(): def genot_data_loader_quad_conditional(): """Returns a data loader for a simple Gaussian mixture.""" rng = np.random.default_rng(seed=0) - source_0 = rng.normal(size=(100, 2)) - target_0 = rng.normal(size=(100, 1)) + 1.0 - source_1 = rng.normal(size=(100, 2)) - target_1 = rng.normal(size=(100, 1)) + 1.0 + src_0 = rng.normal(size=(100, 2)) + tgt_0 = rng.normal(size=(100, 1)) + 1.0 + src_1 = rng.normal(size=(100, 2)) + tgt_1 = rng.normal(size=(100, 1)) + 1.0 ds0 = datasets.OTDataset( - quad=source_0, - target_quad=target_0, - conditions=np.zeros_like(source_0) * 0.0 + quad=src_0, tgt_quad=tgt_0, conditions=np.zeros_like(src_0) * 0.0 ) ds1 = datasets.OTDataset( - quad=source_1, - target_quad=target_1, - conditions=np.ones_like(source_1) * 1.0 + quad=src_1, tgt_quad=tgt_1, conditions=np.ones_like(src_1) * 1.0 ) sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0) dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1) - return datasets.ConditionalOTDataset((dl0, dl1)) + return datasets.ConditionalLoader((dl0, dl1)) @pytest.fixture(scope="module") def genot_data_loader_fused(): """Returns a data loader for a simple Gaussian mixture.""" rng = np.random.default_rng(seed=0) - source_q = rng.normal(size=(100, 2)) - target_q = rng.normal(size=(100, 1)) + 1.0 - source_lin = rng.normal(size=(100, 2)) - target_lin = rng.normal(size=(100, 2)) + 1.0 + src_q = rng.normal(size=(100, 2)) + tgt_q = 
rng.normal(size=(100, 1)) + 1.0 + src_lin = rng.normal(size=(100, 2)) + tgt_lin = rng.normal(size=(100, 2)) + 1.0 dataset = datasets.OTDataset( - lin=source_lin, - quad=source_q, - target_lin=target_lin, - target_quad=target_q + lin=src_lin, quad=src_q, tgt_lin=tgt_lin, tgt_quad=tgt_q ) return DataLoader(dataset, batch_size=16, shuffle=True) @@ -173,32 +164,32 @@ def genot_data_loader_fused(): def genot_data_loader_fused_conditional(): """Returns a data loader for a simple Gaussian mixture.""" rng = np.random.default_rng(seed=0) - source_q_0 = rng.normal(size=(100, 2)) - target_q_0 = rng.normal(size=(100, 1)) + 1.0 - source_lin_0 = rng.normal(size=(100, 2)) - target_lin_0 = rng.normal(size=(100, 2)) + 1.0 + src_q_0 = rng.normal(size=(100, 2)) + tgt_q_0 = rng.normal(size=(100, 1)) + 1.0 + src_lin_0 = rng.normal(size=(100, 2)) + tgt_lin_0 = rng.normal(size=(100, 2)) + 1.0 - source_q_1 = 2 * rng.normal(size=(100, 2)) - target_q_1 = 2 * rng.normal(size=(100, 1)) + 1.0 - source_lin_1 = 2 * rng.normal(size=(100, 2)) - target_lin_1 = 2 * rng.normal(size=(100, 2)) + 1.0 + src_q_1 = 2 * rng.normal(size=(100, 2)) + tgt_q_1 = 2 * rng.normal(size=(100, 1)) + 1.0 + src_lin_1 = 2 * rng.normal(size=(100, 2)) + tgt_lin_1 = 2 * rng.normal(size=(100, 2)) + 1.0 ds0 = datasets.OTDataset( - lin=source_lin_0, - target_lin=target_lin_0, - quad=source_q_0, - target_quad=target_q_0, - conditions=np.zeros_like(source_lin_0) * 0.0 + lin=src_lin_0, + tgt_lin=tgt_lin_0, + quad=src_q_0, + tgt_quad=tgt_q_0, + conditions=np.zeros_like(src_lin_0) * 0.0 ) ds1 = datasets.OTDataset( - lin=source_lin_1, - target_lin=target_lin_1, - quad=source_q_1, - target_quad=target_q_1, - conditions=np.ones_like(source_lin_1) * 1.0 + lin=src_lin_1, + tgt_lin=tgt_lin_1, + quad=src_q_1, + tgt_quad=tgt_q_1, + conditions=np.ones_like(src_lin_1) * 1.0 ) sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) dl0 = DataLoader(ds0, batch_size=16, 
sampler=sampler0) dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1) - return datasets.ConditionalOTDataset((dl0, dl1)) + return datasets.ConditionalLoader((dl0, dl1)) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index f938728c1..c60b2e064 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -24,7 +24,6 @@ from ott.neural.flow_models.genot import GENOTLin, GENOTQuad from ott.neural.flow_models.models import VelocityField from ott.neural.flow_models.samplers import uniform_sampler -from ott.neural.models import base_solver from ott.solvers.linear import sinkhorn, sinkhorn_lr from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index f4c0fc1d6..df2fb4bdb 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -12,172 +12,112 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Literal, Type import pytest +import jax import jax.numpy as jnp +from torch.utils.data import DataLoader import optax +from ott.neural.data import datasets from ott.neural.flow_models import flows, models, otfm, samplers, utils class TestOTFlowMatching: - @pytest.mark.parametrize( - "flow", [ - flows.ConstantNoiseFlow(0.0), - flows.ConstantNoiseFlow(1.0), - flows.BrownianNoiseFlow(0.2), - ] - ) - def test_flow_matching_unconditional( - self, data_loaders_gaussian, flow: flows.BaseFlow - ): + def test_fm(self, lin_dl: DataLoader): input_dim = 2 neural_vf = models.VelocityField( output_dim=2, condition_dim=0, latent_embed_dim=5, ) - optimizer = optax.adam(learning_rate=1e-3) - fm = otfm.OTFlowMatching( input_dim, neural_vf, - flow=flow, + flow=flows.ConstantNoiseFlow(0.0), time_sampler=samplers.uniform_sampler, - match_fn=utils.match_linear, - optimizer=optimizer, - ) - _ = fm( - n_iters=2, - train_source=data_loaders_gaussian[0], - 
train_target=data_loaders_gaussian[1], + match_fn=jax.jit(utils.match_linear), + optimizer=optax.adam(learning_rate=1e-3), ) - # TODO(michalk8): nicer - batch_src = next(iter(data_loaders_gaussian[0])) - source = jnp.asarray(batch_src["lin"]) - batch_tgt = next(iter(data_loaders_gaussian[1])) - target = jnp.asarray(batch_tgt["lin"]) - source_conditions = jnp.asarray( - batch_src["conditions"] - ) if "conditions" in batch_src else None - - result_forward = fm.transport(source, condition=source_conditions) - # TODO(michalk8): better condition - assert jnp.sum(jnp.isnan(result_forward)) == 0 - - result_backward = fm.transport( - target, condition=source_conditions, t0=1.0, t1=0.0 - ) - assert jnp.sum(jnp.isnan(result_backward)) == 0 - - @pytest.mark.parametrize( - "flow", [ - flows.ConstantNoiseFlow(0.0), - flows.ConstantNoiseFlow(1.1), - flows.BrownianNoiseFlow(2.2) - ] - ) - def test_flow_matching_with_conditions( - self, data_loader_gaussian_with_conditions, flow: flows.BaseFlow - ): + _logs = fm(lin_dl, n_iters=2) + + for batch in lin_dl: + src = jnp.asarray(batch["src_lin"]) + tgt = jnp.asarray(batch["tgt_lin"]) + break + + res_fwd = fm.transport(src) + res_bwd = fm.transport(tgt, t0=1.0, t1=0.0) + + # TODO(michalk8): better assertions + assert jnp.sum(jnp.isnan(res_fwd)) == 0 + assert jnp.sum(jnp.isnan(res_bwd)) == 0 + + def test_fm_with_conds(self, lin_dl_with_conds: DataLoader): input_dim, cond_dim = 2, 1 neural_vf = models.VelocityField( output_dim=input_dim, condition_dim=cond_dim, latent_embed_dim=5, ) - time_sampler = functools.partial(samplers.uniform_sampler, offset=1e-5) - optimizer = optax.adam(learning_rate=1e-3) - fm = otfm.OTFlowMatching( 2, neural_vf, - match_fn=utils.match_linear, - flow=flow, - time_sampler=time_sampler, - optimizer=optimizer, - ) - _ = fm( - n_iters=2, - train_source=data_loader_gaussian_with_conditions, - train_target=data_loader_gaussian_with_conditions, + flow=flows.BrownianNoiseFlow(0.12), + 
time_sampler=functools.partial(samplers.uniform_sampler, offset=1e-5), + match_fn=jax.jit(utils.match_linear), + optimizer=optax.adam(learning_rate=1e-3), ) - batch = next(iter(data_loader_gaussian_with_conditions)) - source = jnp.asarray(batch["source_lin"]) - target = jnp.asarray(batch["target_lin"]) - source_conditions = jnp.asarray(batch["source_conditions"]) if len( - batch["source_conditions"] - ) > 0 else None + _logs = fm(lin_dl_with_conds, n_iters=2) - result_forward = fm.transport(source, condition=source_conditions) - assert jnp.sum(jnp.isnan(result_forward)) == 0 + for batch in lin_dl_with_conds: + src = jnp.asarray(batch["src_lin"]) + tgt = jnp.asarray(batch["tgt_lin"]) + src_cond = jnp.asarray(batch["src_condition"]) + break - result_backward = fm.transport( - target, condition=source_conditions, t0=1.0, t1=0.0 - ) - assert jnp.sum(jnp.isnan(result_backward)) == 0 - - @pytest.mark.parametrize( - "flow", - [ - flows.ConstantNoiseFlow(0.0), - flows.ConstantNoiseFlow(13.0), - flows.BrownianNoiseFlow(0.12) - ], - ) - @pytest.mark.parametrize("solver", ["sinkhorn", "lr_sinkhorn"]) - def test_flow_matching_conditional( - self, data_loader_gaussian_conditional, flow: Type[flows.BaseFlow], - solver: Literal["sinkhorn", "lr_sinkhorn"] + res_fwd = fm.transport(src, condition=src_cond) + res_bwd = fm.transport(tgt, condition=src_cond, t0=1.0, t1=0.0) + + # TODO(michalk8): better assertions + assert jnp.sum(jnp.isnan(res_fwd)) == 0 + assert jnp.sum(jnp.isnan(res_bwd)) == 0 + + @pytest.mark.parametrize("rank", [-1, 10]) + def test_fm_conditional_loader( + self, rank: int, conditional_lin_dl: datasets.ConditionalLoader ): - dim = 2 - condition_dim = 0 + input_dim, cond_dim = 2, 0 neural_vf = models.VelocityField( - output_dim=dim, - condition_dim=condition_dim, + output_dim=input_dim, + condition_dim=cond_dim, latent_embed_dim=5, ) - # TODO(michalk8): check for LR - ot_matcher = utils.match_linear - time_sampler = samplers.uniform_sampler - optimizer = 
optax.adam(learning_rate=1e-3) - fm = otfm.OTFlowMatching( + input_dim, neural_vf, - input_dim=dim, - cond_dim=condition_dim, - ot_matcher=ot_matcher, - flow=flow, - time_sampler=time_sampler, - optimizer=optimizer, - ) - fm( - data_loader_gaussian_conditional, - data_loader_gaussian_conditional, - n_iters=2, - valid_freq=3 + flow=flows.ConstantNoiseFlow(13.0), + time_sampler=samplers.uniform_sampler, + match_fn=jax.jit(functools.partial(utils.match_linear, rank=rank)), + optimizer=optax.adam(learning_rate=1e-3), ) - batch = next(iter(data_loader_gaussian_conditional)) - source = jnp.asarray(batch["source_lin"]) - target = jnp.asarray(batch["target_lin"]) - source_conditions = jnp.asarray(batch["source_conditions"]) if len( - batch["source_conditions"] - ) > 0 else None - result_forward = fm.transport( - source, condition=source_conditions, forward=True - ) - assert isinstance(result_forward, jnp.ndarray) - assert jnp.sum(jnp.isnan(result_forward)) == 0 + _logs = fm(conditional_lin_dl, n_iters=2) - result_backward = fm.transport( - target, condition=source_conditions, forward=False - ) - assert isinstance(result_backward, jnp.ndarray) - assert jnp.sum(jnp.isnan(result_backward)) == 0 + for batch in conditional_lin_dl: + src = jnp.asarray(batch["src_lin"]) + tgt = jnp.asarray(batch["tgt_lin"]) + src_cond = jnp.asarray(batch["src_condition"]) + break + + res_fwd = fm.transport(src, condition=src_cond) + res_bwd = fm.transport(tgt, condition=src_cond, t0=1.0, t1=0.0) + + # TODO(michalk8): better assertions + assert jnp.sum(jnp.isnan(res_fwd)) == 0 + assert jnp.sum(jnp.isnan(res_bwd)) == 0 From a9fe6181ec83fe718e632f0c51e70b3a7d07e2ad Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Thu, 14 Mar 2024 16:13:57 +0100 Subject: [PATCH 120/186] Start cleaning GENOT --- src/ott/neural/flow_models/genot.py | 214 +++++++--------------------- src/ott/neural/flow_models/otfm.py | 33 ++--- 2 files changed, 59 insertions(+), 188 
deletions(-) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index 88be0da38..fbdd7687c 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -12,135 +12,63 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -import types -from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from typing import Any, Callable, Optional, Tuple, Union import jax import jax.numpy as jnp import jax.tree_util as jtu import diffrax -import optax from flax.training import train_state from ott import utils -from ott.neural.flow_models import flows, samplers +from ott.neural.flow_models import flows, models __all__ = ["GENOTBase", "GENOTLin", "GENOTQuad"] +# TODO(michalk8): remove the base class? class GENOTBase: - """Base class for GENOT models (:cite:`klein_uscidda:23`). - - GENOT (Generative Entropic Neural Optimal Transport) is a neural solver - for entropic OT prooblems, in the linear - (:class:`ott.neural.flows.genot.GENOTLin`), the Gromov-Wasserstein, and - the Fused Gromov-Wasserstein ((:class:`ott.neural.flows.genot.GENOTQUad`)) - setting. + """TODO :cite:`klein_uscidda:23`. Args: velocity_field: Neural vector field parameterized by a neural network. - input_dim: Dimension of the data in the source distribution. - output_dim: Dimension of the data in the target distribution. - cond_dim: Dimension of the conditioning variable. - valid_freq: Frequency of validation. - ot_solver: OT solver to match samples from the source and the target - distribution. - epsilon: Entropy regularization term of the OT problem solved by - `ot_solver`. - cost_fn: Cost function for the OT problem solved by the `ot_solver`. - In the linear case, this is always expected to be of type `str`. - If the problem is of quadratic type and `cost_fn` is a string, - the `cost_fn` is used for all terms, i.e. 
both quadratic terms and, - if applicable, the linear temr. If of type :class:`dict`, the keys - are expected to be `cost_fn_xx`, `cost_fn_yy`, and if applicable, - `cost_fn_xy`. - scale_cost: How to scale the cost matrix for the OT problem solved by - the `ot_solver`. In the linear case, this is always expected to be - not a :class:`dict`. If the problem is of quadratic type and - `scale_cost` is a string, the `scale_cost` argument is used for all - terms, i.e. both quadratic terms and, if applicable, the linear temr. - If of type :class:`dict`, the keys are expected to be `scale_cost_xx`, - `scale_cost_yy`, and if applicable, `scale_cost_xy`. - optimizer: Optimizer for `velocity_field`. flow: Flow between latent distribution and target distribution. time_sampler: Sampler for the time. - k_samples_per_x: Number of samples drawn from the conditional distribution of an input sample, see algorithm TODO. - solver_latent_to_data: Linear OT solver to match the latent distribution + data_match_fn: Linear OT solver to match the latent distribution with the conditional distribution. - kwargs_solver_latent_to_data: Keyword arguments for `solver_latent_to_data`. - #TODO: adapt - fused_penalty: Fused penalty of the linear/fused term in the Fused - Gromov-Wasserstein problem. - callback_fn: Callback function. - rng: Random number generator. + latent_match_fn: TODO. + latent_noise_fn: TODO. + k_samples_per_x: Number of samples drawn from the conditional distribution + kwargs: TODO. 
""" def __init__( self, - velocity_field: Callable[[ - jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], Optional[jnp.ndarray] - ], jnp.ndarray], - *, - input_dim: int, - output_dim: int, - cond_dim: int, - valid_freq: int, - # TODO(michalk8) - ot_matcher: Any, - optimizer: optax.GradientTransformation, - flow: Type[flows.BaseFlow] = flows.ConstantNoiseFlow(0.0), # noqa: B008 - time_sampler: Callable[[jax.Array, int], - jnp.ndarray] = samplers.uniform_sampler, + velocity_field: models.VelocityField, + flow: flows.BaseFlow, + time_sampler: Callable[[jax.Array, int], jnp.ndarray], + data_match_fn: Any, + latent_match_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], + jnp.ndarray]] = None, + # TODO(michalk8): add a default for this? + latent_noise_fn: Optional[Callable[[jax.Array, Tuple[int, ...]], + jnp.ndarray]] = None, k_samples_per_x: int = 1, - # TODO(michalk8) - matcher_latent_to_data: Optional[Callable] = None, - kwargs_solver_latent_to_data: Dict[str, Any] = types.MappingProxyType({}), - fused_penalty: float = 0.0, - callback_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], - Any]] = None, - rng: Optional[jax.Array] = None, + **kwargs: Any, ): - rng = utils.default_prng_key(rng) - - self.rng = utils.default_prng_key(rng) - self.valid_freq = valid_freq - self.velocity_field = velocity_field - self.state_velocity_field: Optional[train_state.TrainState] = None + self.vf = velocity_field self.flow = flow self.time_sampler = time_sampler - self.optimizer = optimizer - self.ot_matcher = ot_matcher - self.latent_noise_fn = jax.tree_util.Partial( - jax.random.multivariate_normal, - mean=jnp.zeros((output_dim,)), - cov=jnp.diag(jnp.ones((output_dim,))) - ) - self.input_dim = input_dim - self.output_dim = output_dim - self.cond_dim = cond_dim + self.ot_matcher = data_match_fn + if latent_match_fn is not None: + latent_match_fn = jax.jit(jax.vmap(latent_match_fn, 0, 0)) + self.latent_match_fn = latent_match_fn + self.latent_noise_fn = latent_noise_fn 
self.k_samples_per_x = k_samples_per_x - # OT data-data matching parameters - - self.fused_penalty = fused_penalty - - # OT latent-data matching parameters - self.matcher_latent_to_data = matcher_latent_to_data - self.kwargs_solver_latent_to_data = kwargs_solver_latent_to_data - - # callback parameteres - self.callback_fn = callback_fn - self.setup() - - def setup(self) -> None: - """Set up the model.""" - self.state_velocity_field = ( - self.velocity_field.create_train_state( - self.rng, self.optimizer, self.output_dim - ) - ) + self.vf_state = self.vf.create_train_state(**kwargs) self.step_fn = self._get_step_fn() def _get_step_fn(self) -> Callable: @@ -148,7 +76,7 @@ def _get_step_fn(self) -> Callable: @jax.jit def step_fn( rng: jax.Array, - state_velocity_field: train_state.TrainState, + vf_state: train_state.TrainState, time: jnp.ndarray, source: jnp.ndarray, target: jnp.ndarray, @@ -162,9 +90,7 @@ def loss_fn( source_conditions: Optional[jnp.ndarray], rng: jax.Array ): x_t = self.flow.compute_xt(rng, time, latent, target) - apply_fn = functools.partial( - state_velocity_field.apply_fn, {"params": params} - ) + apply_fn = functools.partial(vf_state.apply_fn, {"params": params}) cond_input = jnp.concatenate([ source, source_conditions @@ -175,11 +101,10 @@ def loss_fn( grad_fn = jax.value_and_grad(loss_fn, has_aux=False) loss, grads = grad_fn( - state_velocity_field.params, time, source, target, latent, - source_conditions, rng + vf_state.params, time, source, target, latent, source_conditions, rng ) - return state_velocity_field.apply_gradients(grads=grads), loss + return vf_state.apply_gradients(grads=grads), loss return step_fn @@ -218,7 +143,7 @@ def transport( raise NotImplementedError if condition is not None: assert len(source) == len(condition), (len(source), len(condition)) - latent_batch = self.latent_noise_fn(rng, shape=(len(source),)) + latent_batch = self.latent_noise_fn(rng, (len(source),)) cond_input = source if condition is None else ( 
jnp.concatenate([source, condition], axis=-1) ) @@ -226,11 +151,8 @@ def transport( @jax.jit def solve_ode(input: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: ode_term = diffrax.ODETerm( - lambda t, x, args: self.state_velocity_field. - apply_fn({"params": self.state_velocity_field.params}, - t=t, - x=x, - condition=cond) + lambda t, x, args: self.vf_state. + apply_fn({"params": self.vf_state.params}, t=t, x=x, condition=cond) ), solver = kwargs.pop("solver", diffrax.Tsit5()) stepsize_controller = kwargs.pop( @@ -250,14 +172,6 @@ def solve_ode(input: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: return jax.vmap(solve_ode)(latent_batch, cond_input) - def _valid_step(self, valid_loader, iter): - pass - - @property - def learn_rescaling(self) -> bool: - """Whether to learn at least one rescaling factor.""" - return False - def _reshape_samples(self, arrays: Tuple[jnp.ndarray, ...], batch_size: int) -> Tuple[jnp.ndarray, ...]: return jax.tree_util.tree_map( @@ -265,12 +179,6 @@ def _reshape_samples(self, arrays: Tuple[jnp.ndarray, ...], arrays ) - def _learn_rescaling( - self, source: jnp.ndarray, target: jnp.ndarray, - source_conditions: Optional[jnp.ndarray], tmat: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray, float, float]: - raise NotImplementedError - class GENOTLin(GENOTBase): """Implementation of GENOT-L (:cite:`klein:23`). 
@@ -293,7 +201,7 @@ def __call__( rng = utils.default_prng_key(rng) training_logs = {"loss": []} - for it in range(n_iters): + for _ in range(n_iters): for batch_source, batch_target in zip(train_source, train_target): ( rng, rng_resample, rng_noise, rng_time, rng_latent_data_match, @@ -311,7 +219,7 @@ def __call__( n_samples = batch_size * self.k_samples_per_x time = self.time_sampler(rng_time, n_samples) latent = self.latent_noise_fn( - rng_noise, shape=(self.k_samples_per_x, batch_size) + rng_noise, (self.k_samples_per_x, batch_size) ) tmat = self.ot_matcher.match_fn( @@ -329,11 +237,9 @@ def __call__( source_is_balanced=(self.ot_matcher.tau_a == 1.0) ) - if self.matcher_latent_to_data is not None: - tmats_latent_data = jnp.array( - jax.vmap(self.matcher_latent_to_data.match_fn, 0, - 0)(x=latent, y=target) - ) + if self.latent_match_fn is not None: + # already vmapped + tmats_latent_data = self.latent_match_fn(latent, target) rng_latent_data_match = jax.random.split( rng_latent_data_match, self.k_samples_per_x @@ -347,23 +253,13 @@ def __call__( source, source_conditions, target, latent = self._reshape_samples( (source, source_conditions, target, latent), batch_size ) - self.state_velocity_field, loss = self.step_fn( - rng_step_fn, self.state_velocity_field, time, source, target, - latent, source_conditions + self.vf_state, loss = self.step_fn( + rng_step_fn, self.vf_state, time, source, target, latent, + source_conditions ) - if self.learn_rescaling: - eta_preds, xi_preds, loss_a, loss_b = self._learn_rescaling( - source=source, - target=target, - condition=source_conditions, - tmat=tmat - ) training_logs["loss"].append(float(loss)) - if it % valid_freq == 0: - self._valid_step(valid_source, valid_target, it) - class GENOTQuad(GENOTBase): """Implementation of GENOT-Q and GENOT-F (:cite:`klein:23`). 
@@ -388,7 +284,7 @@ def __call__( rng = utils.default_prng_key(rng) training_logs = {"loss": []} - for it in range(n_iters): + for _ in range(n_iters): for batch_source, batch_target in zip(train_source, train_target): ( rng, rng_resample, rng_noise, rng_time, rng_latent_data_match, @@ -408,7 +304,7 @@ def __call__( n_samples = batch_size * self.k_samples_per_x time = self.time_sampler(rng_time, n_samples) latent = self.latent_noise_fn( - rng_noise, shape=(self.k_samples_per_x, batch_size) + rng_noise, (self.k_samples_per_x, batch_size) ) tmat = self.ot_matcher.match_fn( @@ -433,11 +329,9 @@ def __call__( ) ) - if self.matcher_latent_to_data is not None: - tmats_latent_data = jnp.array( - jax.vmap(self.matcher_latent_to_data.match_fn, 0, - 0)(x=latent, y=target) - ) + if self.latent_match_fn is not None: + # already vmapped + tmats_latent_data = self.latent_match_fn(latent, target) rng_latent_data_match = jax.random.split( rng_latent_data_match, self.k_samples_per_x @@ -453,18 +347,8 @@ def __call__( (source, source_conditions, target, latent), batch_size ) - self.state_velocity_field, loss = self.step_fn( - rng_step_fn, self.state_velocity_field, time, source, target, - latent, source_conditions + self.vf_state, loss = self.step_fn( + rng_step_fn, self.vf_state, time, source, target, latent, + source_conditions ) - if self.learn_rescaling: - eta_preds, xi_preds, loss_a, loss_b = self._learn_rescaling( - source=source, - target=target, - condition=source_conditions, - tmat=tmat - ) training_logs["loss"].append(float(loss)) - - if it % valid_freq == 0: - self._valid_step(valid_source, valid_target, it) diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index 83d709a96..79ffa016c 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -19,11 +19,10 @@ import jax.tree_util as jtu import diffrax -import optax from flax.training import train_state from ott import utils -from ott.neural.flow_models 
import flows, models +from ott.neural.flow_models import flows, models, samplers from ott.neural.flow_models.utils import resample_data, sample_joint __all__ = ["OTFlowMatching"] @@ -35,40 +34,31 @@ class OTFlowMatching: With an extension to OT-FM :cite:`tong:23`, :cite:`pooladian:23`. Args: - input_dim: Dimension of the input data. velocity_field: Neural vector field parameterized by a neural network. flow: Flow between source and target distribution. time_sampler: Sampler for the time. - optimizer: Optimizer for the velocity field's parameters. match_fn: TODO. - rng: Random number generator. + kwargs: TODO. """ # TODO(michalk8): in the future, `input_dim`, `optimizer` and `rng` will be # in a separate function def __init__( self, - input_dim: int, velocity_field: models.VelocityField, flow: flows.BaseFlow, - time_sampler: Callable[[jax.Array, int], jnp.ndarray], - optimizer: optax.GradientTransformation, + time_sampler: Callable[[jax.Array, int], + jnp.ndarray] = samplers.uniform_sampler, match_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None, - rng: Optional[jax.Array] = None, + **kwargs: Any, ): - rng = utils.default_prng_key(rng) - - self.input_dim = input_dim self.vf = velocity_field self.flow = flow self.time_sampler = time_sampler self.match_fn = match_fn - self.optimizer = optimizer - self.vf_state = self.vf.create_train_state( - rng, self.optimizer, self.input_dim - ) + self.vf_state = self.vf.create_train_state(**kwargs) self.step_fn = self._get_step_fn() def _get_step_fn(self) -> Callable: @@ -76,7 +66,7 @@ def _get_step_fn(self) -> Callable: @jax.jit def step_fn( rng: jax.Array, - state_velocity_field: train_state.TrainState, + vf_state: train_state.TrainState, source: jnp.ndarray, target: jnp.ndarray, source_conditions: Optional[jnp.ndarray], @@ -89,9 +79,7 @@ def loss_fn( ) -> jnp.ndarray: x_t = self.flow.compute_xt(rng, t, source, target) - apply_fn = functools.partial( - state_velocity_field.apply_fn, {"params": params} - ) + 
apply_fn = functools.partial(vf_state.apply_fn, {"params": params}) v_t = jax.vmap(apply_fn)(t=t, x=x_t, condition=source_conditions) u_t = self.flow.compute_ut(t, source, target) return jnp.mean((v_t - u_t) ** 2) @@ -101,10 +89,9 @@ def loss_fn( t = self.time_sampler(key_t, batch_size) grad_fn = jax.value_and_grad(loss_fn) loss, grads = grad_fn( - state_velocity_field.params, t, source, target, source_conditions, - key_model + vf_state.params, t, source, target, source_conditions, key_model ) - return state_velocity_field.apply_gradients(grads=grads), loss + return vf_state.apply_gradients(grads=grads), loss return step_fn From abca4f72e7a0da25f0d1f351e5f6bcfc136737f8 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 02:25:11 +0100 Subject: [PATCH 121/186] Update GENOT --- src/ott/neural/data/datasets.py | 2 +- src/ott/neural/flow_models/genot.py | 181 ++++++++++++++++++++-------- src/ott/neural/flow_models/otfm.py | 23 ++-- src/ott/neural/flow_models/utils.py | 12 -- 4 files changed, 143 insertions(+), 75 deletions(-) diff --git a/src/ott/neural/data/datasets.py b/src/ott/neural/data/datasets.py index 0a2067f25..eca6f1e51 100644 --- a/src/ott/neural/data/datasets.py +++ b/src/ott/neural/data/datasets.py @@ -75,7 +75,7 @@ def __init__( def _verify_integriy(self) -> None: assert len(self.src_data) == len(self.src_conditions) - assert len(self.src_data) == len(self.tgt_conditions) + assert len(self.tgt_data) == len(self.tgt_conditions) if self.is_aligned: assert len(self.src_data) == len(self.tgt_data) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index fbdd7687c..a4878c892 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import functools -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple import jax import jax.numpy as jnp @@ -23,12 +23,13 @@ from ott import utils from ott.neural.flow_models import flows, models +from ott.neural.flow_models import utils as flow_utils -__all__ = ["GENOTBase", "GENOTLin", "GENOTQuad"] +__all__ = ["GENOT", "GENOTLin", "GENOTQuad"] # TODO(michalk8): remove the base class? -class GENOTBase: +class GENOT: """TODO :cite:`klein_uscidda:23`. Args: @@ -49,7 +50,9 @@ def __init__( velocity_field: models.VelocityField, flow: flows.BaseFlow, time_sampler: Callable[[jax.Array, int], jnp.ndarray], - data_match_fn: Any, + # TODO(mcihalk8): all args are optional + data_match_fn: Callable[ + [jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray], latent_match_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None, # TODO(michalk8): add a default for this? @@ -61,9 +64,7 @@ def __init__( self.vf = velocity_field self.flow = flow self.time_sampler = time_sampler - self.ot_matcher = data_match_fn - if latent_match_fn is not None: - latent_match_fn = jax.jit(jax.vmap(latent_match_fn, 0, 0)) + self.data_match_fn = data_match_fn self.latent_match_fn = latent_match_fn self.latent_noise_fn = latent_noise_fn self.k_samples_per_x = k_samples_per_x @@ -108,16 +109,110 @@ def loss_fn( return step_fn + def __call__( + self, + loader: Any, + n_iters: int, + rng: Optional[jax.Array] = None + ) -> Dict[str, List[float]]: + """TODO.""" + + def prepare_data( + batch: Dict[str, jnp.ndarray] + ) -> Tuple[Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray], Tuple[ + jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]]: + src_lin, src_quad = batch.get("src_lin"), batch.get("src_quad") + tgt_lin, tgt_quad = batch.get("tgt_lin"), batch.get("tgt_quad") + arrs = src_lin, tgt_lin, src_quad, tgt_quad + + if src_quad is None and tgt_quad is None: # lin + src, tgt = src_lin, tgt_lin + elif src_lin 
is None and tgt_lin is None: # quad + src, tgt = src_quad, tgt_quad + elif all(arr is not None for arr in arrs): # fused quad + src = jnp.concatenate([src_lin, src_quad], axis=1) + tgt = jnp.concatenate([tgt_lin, tgt_quad], axis=1) + else: + raise RuntimeError("TODO") + + # TODO(michalk8): filter `None` from the `arrs`? + return (src, batch.get("src_condition"), tgt), arrs + + rng = utils.default_prng_key(rng) + training_logs = {"loss": []} + for batch in loader: + rng = jax.random.split(rng, 6) + rng, rng_resample, rng_noise, rng_time, rng_latent, rng_step_fn = rng + + batch = jtu.tree_map(jnp.asarray, batch) + (src, src_cond, tgt), data = prepare_data(batch) + + time = self.time_sampler(rng_time, len(src) * self.k_samples_per_x) + latent = self.latent_noise_fn(rng_noise, (self.k_samples_per_x, len(src))) + + tmat = self.data_match_fn(*data) # (n, m) + src_ixs, tgt_ixs = flow_utils.sample_conditional( # (n, k), (m, k) + rng_resample, + tmat, + k=self.k_samples_per_x, + uniform_marginals=True, # TODO(michalk8): expose + ) + + src = src[src_ixs].swapaxes(0, 1) # (k, n, ...) + tgt = tgt[tgt_ixs].swapaxes(0, 1) # (k, m, ...) + if src_cond is not None: + src_cond = src_cond[src_ixs].swapaxes(0, 1) # (k, n, ...) + + if self.latent_match_fn is not None: + src, src_cond, tgt = self._match_latent(rng, src, src_cond, latent, tgt) + + src = src.reshape(-1, *src.shape[2:]) # (k * bs, ...) 
+ tgt = tgt.reshape(-1, *tgt.shape[2:]) + latent = latent.reshape(-1, *latent.shape[2:]) + if src_cond is not None: + src_cond = src_cond.reshape(-1, *src_cond.shape[2:]) + + self.vf_state, loss = self.step_fn( + rng_step_fn, self.vf_state, time, src, tgt, latent, src_cond + ) + training_logs["loss"].append(float(loss)) + + return training_logs + + def _match_latent( + self, rng: jax.Array, src: jnp.ndarray, src_cond: Optional[jnp.ndarray], + latent: jnp.ndarray, tgt: jnp.ndarray + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]: + + def resample( + rng: jax.Array, src: jnp.ndarray, src_cond: Optional[jnp.ndarray], + tgt: jnp.ndarray, latent: jnp.ndarray + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]: + tmat = self.latent_match_fn(latent, tgt) # (n, k) + + src_ixs, tgt_ixs = flow_utils.sample_joint(rng, tmat) # (n,), (m,) + src, tgt = src[src_ixs], tgt[tgt_ixs] + if src_cond is not None: + src_cond = src_cond[src_ixs] + + return src, src_cond, tgt + + cond_axis = None if src_cond is None else 0 + in_axes, out_axes = (0, 0, cond_axis, 0, 0), (0, None, 0) + resample_fn = jax.jit(jax.vmap(resample, in_axes, out_axes)) + + rngs = jax.random.split(rng, self.k_samples_per_x) + return resample_fn(rngs, src, src_cond, tgt, latent) + def transport( self, source: jnp.ndarray, condition: Optional[jnp.ndarray] = None, + t0: float = 0.0, + t1: float = 1.0, rng: Optional[jax.Array] = None, - forward: bool = True, - t_0: float = 0.0, - t_1: float = 1.0, **kwargs: Any, - ) -> Union[jnp.array, diffrax.Solution, Optional[jnp.ndarray]]: + ) -> jnp.ndarray: """Transport data with the learnt plan. This method pushes-forward the `source` to its conditional distribution by @@ -127,60 +222,48 @@ def transport( Args: source: Data to transport. condition: Condition of the input data. + t0: Starting time of integration of neural ODE. + t1: End time of integration of neural ODE. rng: random seed for sampling from the latent distribution. 
- forward: If `True` integrates forward, otherwise backwards. - t_0: Starting time of integration of neural ODE. - t_1: End time of integration of neural ODE. kwargs: Keyword arguments for the ODE solver. Returns: The push-forward or pull-back distribution defined by the learnt transport plan. - """ - rng = utils.default_prng_key(rng) - if not forward: - raise NotImplementedError - if condition is not None: - assert len(source) == len(condition), (len(source), len(condition)) - latent_batch = self.latent_noise_fn(rng, (len(source),)) - cond_input = source if condition is None else ( - jnp.concatenate([source, condition], axis=-1) - ) - @jax.jit - def solve_ode(input: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: - ode_term = diffrax.ODETerm( - lambda t, x, args: self.vf_state. - apply_fn({"params": self.vf_state.params}, t=t, x=x, condition=cond) - ), - solver = kwargs.pop("solver", diffrax.Tsit5()) - stepsize_controller = kwargs.pop( - "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5) - ) + def vf(t: jnp.ndarray, x: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: + params = self.vf_state.params + return self.vf_state.apply_fn({"params": params}, t, x, cond) + + def solve_ode(x: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: + ode_term = diffrax.ODETerm(vf) sol = diffrax.diffeqsolve( ode_term, - solver, - t0=t_0, - t1=t_1, - dt0=kwargs.pop("dt0", None), - y0=input, - stepsize_controller=stepsize_controller, + t0=t0, + t1=t1, + y0=x, + args=cond, **kwargs, ) return sol.ys[0] - return jax.vmap(solve_ode)(latent_batch, cond_input) - - def _reshape_samples(self, arrays: Tuple[jnp.ndarray, ...], - batch_size: int) -> Tuple[jnp.ndarray, ...]: - return jax.tree_util.tree_map( - lambda x: jnp.reshape(x, (batch_size * self.k_samples_per_x, -1)), - arrays + kwargs.setdefault("dt0", None) + kwargs.setdefault("solver", diffrax.Tsit5()) + kwargs.setdefault( + "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5) ) + rng = 
utils.default_prng_key(rng) + latent = self.latent_noise_fn(rng, (len(source),)) + + if condition is not None: + source = jnp.concatenate([source, condition], axis=-1) + + return jax.jit(jax.vmap(solve_ode))(latent, source) + -class GENOTLin(GENOTBase): +class GENOTLin(GENOT): """Implementation of GENOT-L (:cite:`klein:23`). GENOT-L (Generative Entropic Neural Optimal Transport, linear) is a @@ -261,7 +344,7 @@ def __call__( training_logs["loss"].append(float(loss)) -class GENOTQuad(GENOTBase): +class GENOTQuad(GENOT): """Implementation of GENOT-Q and GENOT-F (:cite:`klein:23`). GENOT-Q (Generative Entropic Neural Optimal Transport, quadratic) and diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index 79ffa016c..b793a17c3 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import jax import jax.numpy as jnp @@ -23,7 +23,7 @@ from ott import utils from ott.neural.flow_models import flows, models, samplers -from ott.neural.flow_models.utils import resample_data, sample_joint +from ott.neural.flow_models.utils import sample_joint __all__ = ["OTFlowMatching"] @@ -80,7 +80,7 @@ def loss_fn( x_t = self.flow.compute_xt(rng, t, source, target) apply_fn = functools.partial(vf_state.apply_fn, {"params": params}) - v_t = jax.vmap(apply_fn)(t=t, x=x_t, condition=source_conditions) + v_t = jax.vmap(apply_fn)(t, x_t, source_conditions) u_t = self.flow.compute_ut(t, source, target) return jnp.mean((v_t - u_t) ** 2) @@ -102,30 +102,29 @@ def __call__( # noqa: D102 *, n_iters: int, rng: Optional[jax.Array] = None, - ) -> Dict[str, Any]: + ) -> Dict[str, List[float]]: rng = utils.default_prng_key(rng) training_logs = {"loss": []} - for batch in loader: rng, 
rng_resample, rng_step_fn = jax.random.split(rng, 3) batch = jtu.tree_map(jnp.asarray, batch) src, tgt = batch["src_lin"], batch["tgt_lin"] - src_conds = batch.get("src_condition", None) + src_cond = batch.get("src_condition") if self.match_fn is not None: tmat = self.match_fn(src, tgt) src_ixs, tgt_ixs = sample_joint(rng_resample, tmat) - src, src_conds = resample_data(src, src_conds, ixs=src_ixs) - tgt = resample_data(tgt, ixs=tgt_ixs) + src, tgt = src[src_ixs], tgt[tgt_ixs] + src_cond = None if src_cond is None else src_cond[src_ixs] self.vf_state, loss = self.step_fn( rng_step_fn, self.vf_state, src, tgt, - src_conds, + src_cond, ) training_logs["loss"].append(float(loss)) @@ -162,10 +161,8 @@ def transport( def vf( t: jnp.ndarray, x: jnp.ndarray, cond: Optional[jnp.ndarray] ) -> jnp.ndarray: - return self.vf_state.apply_fn({"params": self.vf_state.params}, - t=t, - x=x, - condition=cond) + params = self.vf_state.params + return self.vf_state.apply_fn({"params": params}, t, x, cond) def solve_ode(x: jnp.ndarray, cond: Optional[jnp.ndarray]) -> jnp.ndarray: ode_term = diffrax.ODETerm(vf) diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py index d6d023266..9b8386ea9 100644 --- a/src/ott/neural/flow_models/utils.py +++ b/src/ott/neural/flow_models/utils.py @@ -2,7 +2,6 @@ import jax import jax.numpy as jnp -import jax.tree_util as jtu from ott.geometry import costs, pointcloud from ott.solvers import linear, quadratic @@ -12,7 +11,6 @@ "match_quadratic", "sample_joint", "sample_conditional", - "resample_data", ] ScaleCost_t = Union[float, Literal["mean", "max_cost", "median"]] @@ -95,13 +93,3 @@ def sample_conditional( src_ixs = jnp.repeat(indices[:, None], k, axis=1) # (n, k) return src_ixs, tgt_ixs - - -def resample_data(*data: Optional[jnp.ndarray], - ixs: jnp.ndarray) -> Tuple[Optional[jnp.ndarray], ...]: - """TODO.""" - if ixs.ndim == 2: - ixs = ixs.reshape(-1) - assert ixs.ndim == 1, ixs.shape - data = jtu.tree_map(lambda 
arr: None if arr is None else arr[ixs], data) - return data[0] if len(data) == 1 else data From f2c20a47e5d34c90dfb7ec1d1159ca936fb22ed7 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 02:25:50 +0100 Subject: [PATCH 122/186] Remove old GENOTLin/GENOTQuad --- src/ott/neural/flow_models/genot.py | 176 +--------------------------- 1 file changed, 1 insertion(+), 175 deletions(-) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index a4878c892..f6b9c080e 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -25,7 +25,7 @@ from ott.neural.flow_models import flows, models from ott.neural.flow_models import utils as flow_utils -__all__ = ["GENOT", "GENOTLin", "GENOTQuad"] +__all__ = ["GENOT"] # TODO(michalk8): remove the base class? @@ -261,177 +261,3 @@ def solve_ode(x: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: source = jnp.concatenate([source, condition], axis=-1) return jax.jit(jax.vmap(solve_ode))(latent, source) - - -class GENOTLin(GENOT): - """Implementation of GENOT-L (:cite:`klein:23`). - - GENOT-L (Generative Entropic Neural Optimal Transport, linear) is a - neural solver for entropic (linear) OT problems. 
- """ - - def __call__( - self, - n_iters: int, - train_source, - train_target, - valid_source, - valid_target, - valid_freq: int = 5000, - rng: Optional[jax.Array] = None, - ): - """Train GENOTLin.""" - rng = utils.default_prng_key(rng) - training_logs = {"loss": []} - - for _ in range(n_iters): - for batch_source, batch_target in zip(train_source, train_target): - ( - rng, rng_resample, rng_noise, rng_time, rng_latent_data_match, - rng_step_fn - ) = jax.random.split(rng, 6) - - batch_source = jtu.tree_map(jnp.asarray, batch_source) - batch_target = jtu.tree_map(jnp.asarray, batch_target) - - source = batch_source["lin"] - source_conditions = batch_source.get("conditions", None) - target = batch_target["lin"] - - batch_size = len(source) - n_samples = batch_size * self.k_samples_per_x - time = self.time_sampler(rng_time, n_samples) - latent = self.latent_noise_fn( - rng_noise, (self.k_samples_per_x, batch_size) - ) - - tmat = self.ot_matcher.match_fn( - source, - target, - ) - - (source, source_conditions - ), (target,) = self.ot_matcher.sample_conditional_indices_from_tmap( - rng=rng_resample, - conditional_distributions=tmat, - k_samples_per_x=self.k_samples_per_x, - source_arrays=(source, source_conditions), - target_arrays=(target,), - source_is_balanced=(self.ot_matcher.tau_a == 1.0) - ) - - if self.latent_match_fn is not None: - # already vmapped - tmats_latent_data = self.latent_match_fn(latent, target) - - rng_latent_data_match = jax.random.split( - rng_latent_data_match, self.k_samples_per_x - ) - (source, source_conditions - ), (target,) = jax.vmap(self.ot_matcher.sample_joint, 0, 0)( - rng_latent_data_match, tmats_latent_data, - (source, source_conditions), (target,) - ) - - source, source_conditions, target, latent = self._reshape_samples( - (source, source_conditions, target, latent), batch_size - ) - self.vf_state, loss = self.step_fn( - rng_step_fn, self.vf_state, time, source, target, latent, - source_conditions - ) - - 
training_logs["loss"].append(float(loss)) - - -class GENOTQuad(GENOT): - """Implementation of GENOT-Q and GENOT-F (:cite:`klein:23`). - - GENOT-Q (Generative Entropic Neural Optimal Transport, quadratic) and - GENOT-F (Generative Entropic Neural Optimal Transport, fused) are neural - solver for entropic Gromov-Wasserstein and entropic Fused Gromov-Wasserstein - problems, respectively. - """ - - def __call__( - self, - n_iters: int, - train_source, - train_target, - valid_source, - valid_target, - valid_freq: int = 5000, - rng: Optional[jax.Array] = None, - ): - """Train GENOTQuad.""" - rng = utils.default_prng_key(rng) - training_logs = {"loss": []} - - for _ in range(n_iters): - for batch_source, batch_target in zip(train_source, train_target): - ( - rng, rng_resample, rng_noise, rng_time, rng_latent_data_match, - rng_step_fn - ) = jax.random.split(rng, 6) - - batch_source = jtu.tree_map(jnp.asarray, batch_source) - batch_target = jtu.tree_map(jnp.asarray, batch_target) - - source_lin = batch_source.get("lin", None) - source_quad = batch_source["quad"] - source_conditions = batch_source.get("conditions", None) - target_lin = batch_target.get("lin", None) - target_quad = batch_target["quad"] - - batch_size = len(source_quad) - n_samples = batch_size * self.k_samples_per_x - time = self.time_sampler(rng_time, n_samples) - latent = self.latent_noise_fn( - rng_noise, (self.k_samples_per_x, batch_size) - ) - - tmat = self.ot_matcher.match_fn( - source_quad, target_quad, source_lin, target_lin - ) - - if self.ot_matcher.fused_penalty > 0.0: - source = jnp.concatenate((source_lin, source_quad), axis=1) - target = jnp.concatenate((target_lin, target_quad), axis=1) - else: - source = source_quad - target = target_quad - - (source, source_conditions), (target,) = ( - self.ot_matcher.sample_conditional_indices_from_tmap( - rng=rng_resample, - conditional_distributions=tmat, - k_samples_per_x=self.k_samples_per_x, - source_arrays=(source, source_conditions), - 
target_arrays=(target,), - source_is_balanced=(self.ot_matcher.tau_a == 1.0) - ) - ) - - if self.latent_match_fn is not None: - # already vmapped - tmats_latent_data = self.latent_match_fn(latent, target) - - rng_latent_data_match = jax.random.split( - rng_latent_data_match, self.k_samples_per_x - ) - - (source, source_conditions - ), (target,) = jax.vmap(self.ot_matcher.sample_joint, 0, 0)( - rng_latent_data_match, tmats_latent_data, - (source, source_conditions), (target,) - ) - - source, source_conditions, target, latent = self._reshape_samples( - (source, source_conditions, target, latent), batch_size - ) - - self.vf_state, loss = self.step_fn( - rng_step_fn, self.vf_state, time, source, target, latent, - source_conditions - ) - training_logs["loss"].append(float(loss)) From 693ecc4e4fa9cd03f9e8eb3cfa9f2c0244521095 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 02:34:55 +0100 Subject: [PATCH 123/186] Remove axis swapping --- src/ott/neural/flow_models/genot.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index f6b9c080e..b2f5965c2 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -145,12 +145,13 @@ def prepare_data( rng, rng_resample, rng_noise, rng_time, rng_latent, rng_step_fn = rng batch = jtu.tree_map(jnp.asarray, batch) - (src, src_cond, tgt), data = prepare_data(batch) + (src, src_cond, tgt), matching_data = prepare_data(batch) - time = self.time_sampler(rng_time, len(src) * self.k_samples_per_x) - latent = self.latent_noise_fn(rng_noise, (self.k_samples_per_x, len(src))) + n = src.shape[0] + time = self.time_sampler(rng_time, n * self.k_samples_per_x) + latent = self.latent_noise_fn(rng_noise, (n, self.k_samples_per_x)) - tmat = self.data_match_fn(*data) # (n, m) + tmat = self.data_match_fn(*matching_data) # (n, m) src_ixs, 
tgt_ixs = flow_utils.sample_conditional( # (n, k), (m, k) rng_resample, tmat, @@ -158,16 +159,15 @@ def prepare_data( uniform_marginals=True, # TODO(michalk8): expose ) - src = src[src_ixs].swapaxes(0, 1) # (k, n, ...) - tgt = tgt[tgt_ixs].swapaxes(0, 1) # (k, m, ...) + src, tgt = src[src_ixs], tgt[tgt_ixs] # (n, k, ...), # (m, k, ...) if src_cond is not None: - src_cond = src_cond[src_ixs].swapaxes(0, 1) # (k, n, ...) + src_cond = src_cond[src_ixs] if self.latent_match_fn is not None: src, src_cond, tgt = self._match_latent(rng, src, src_cond, latent, tgt) - src = src.reshape(-1, *src.shape[2:]) # (k * bs, ...) - tgt = tgt.reshape(-1, *tgt.shape[2:]) + src = src.reshape(-1, *src.shape[2:]) # (n * k, ...) + tgt = tgt.reshape(-1, *tgt.shape[2:]) # (m * k, ...) latent = latent.reshape(-1, *latent.shape[2:]) if src_cond is not None: src_cond = src_cond.reshape(-1, *src_cond.shape[2:]) @@ -197,8 +197,8 @@ def resample( return src, src_cond, tgt - cond_axis = None if src_cond is None else 0 - in_axes, out_axes = (0, 0, cond_axis, 0, 0), (0, None, 0) + cond_axis = None if src_cond is None else 1 + in_axes, out_axes = (0, 1, cond_axis, 1, 1), (1, cond_axis, 1) resample_fn = jax.jit(jax.vmap(resample, in_axes, out_axes)) rngs = jax.random.split(rng, self.k_samples_per_x) From 3d9c70278187e330bad60dca9d6c3294bc25e212 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 02:39:49 +0100 Subject: [PATCH 124/186] Remove old todo --- src/ott/neural/flow_models/genot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index b2f5965c2..2ac511611 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -28,7 +28,6 @@ __all__ = ["GENOT"] -# TODO(michalk8): remove the base class? class GENOT: """TODO :cite:`klein_uscidda:23`. 
From f27d209e1fa8abe3e4bb69e6104a79b68743615b Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 14:39:01 +0100 Subject: [PATCH 125/186] Fix OTFM tests --- src/ott/neural/flow_models/models.py | 20 ++---- src/ott/neural/flow_models/otfm.py | 7 +- tests/neural/conftest.py | 12 ++-- tests/neural/genot_test.py | 1 - tests/neural/otfm_test.py | 102 ++++++--------------------- 5 files changed, 37 insertions(+), 105 deletions(-) diff --git a/src/ott/neural/flow_models/models.py b/src/ott/neural/flow_models/models.py index c71fff2c2..26cc31915 100644 --- a/src/ott/neural/flow_models/models.py +++ b/src/ott/neural/flow_models/models.py @@ -50,9 +50,6 @@ class VelocityField(nn.Module): If :obj:`None`, set to ``latent_embed_dim``. t_embed_dim: Dimensionality of the time embedding. If :obj:`None`, set to ``latent_embed_dim``. - joint_hidden_dim: Dimensionality of the hidden layers of the joint network. - If :obj:`None`, set to ``latent_embed_dim + condition_embed_dim + - t_embed_dim``. num_layers_per_block: Number of layers per block. act_fn: Activation function. n_freqs: Number of frequencies to use for the time embedding. @@ -62,7 +59,6 @@ class VelocityField(nn.Module): condition_dim: int = 0 condition_embed_dim: Optional[int] = None t_embed_dim: Optional[int] = None - joint_hidden_dim: Optional[int] = None num_layers_per_block: int = 3 act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu n_freqs: int = 128 @@ -72,18 +68,9 @@ def __post_init__(self) -> None: self.condition_embed_dim = self.latent_embed_dim if self.t_embed_dim is None: self.t_embed_dim = self.latent_embed_dim - - concat_embed_dim = ( + self.joint_hidden_dim = ( self.latent_embed_dim + self.condition_embed_dim + self.t_embed_dim ) - if self.joint_hidden_dim is not None: - assert (self.joint_hidden_dim >= concat_embed_dim), ( - "joint_hidden_dim must be greater than or equal to the sum of" - " all embedded dimensions." 
- ) - self.joint_hidden_dim = self.latent_embed_dim - else: - self.joint_hidden_dim = concat_embed_dim super().__post_init__() @nn.compact @@ -121,8 +108,11 @@ def __call__( x = x_layer(x) if self.condition_dim > 0: + assert condition is not None, \ + "Condition must be specified when `condition_dim > 0`." condition_layer = layers.MLPBlock( - dim=self.condition_embed_dim, + # TODO(michalk8): doesn't fail with `condition_embed_dim` + dim=self.condition_dim, out_dim=self.condition_embed_dim, num_layers=self.num_layers_per_block, act_fn=self.act_fn diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index b793a17c3..d80d8e6b8 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools from typing import Any, Callable, Dict, List, Optional, Tuple import jax @@ -79,8 +78,10 @@ def loss_fn( ) -> jnp.ndarray: x_t = self.flow.compute_xt(rng, t, source, target) - apply_fn = functools.partial(vf_state.apply_fn, {"params": params}) - v_t = jax.vmap(apply_fn)(t, x_t, source_conditions) + v_t = vf_state.apply_fn({"params": params}, t, x_t, source_conditions) + # TODO(michalk8): should be removed + # apply_fn = functools.partial(vf_state.apply_fn, {"params": params}) + # v_t = jax.vmap(apply_fn)(t, x_t, source_conditions) u_t = self.flow.compute_ut(t, source, target) return jnp.mean((v_t - u_t) ** 2) diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 04f9917a8..c3cd11ce7 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -52,10 +52,11 @@ def lin_dl() -> DataLoader: @pytest.fixture() def lin_dl_with_conds() -> DataLoader: - n, d = 100, 2 + n, d, cond_dim = 100, 2, 3 rng = np.random.default_rng(13) - src_cond, tgt_cond = rng.normal(size=(n, 1)), rng.normal(size=(n, 1)) 
+ src_cond = rng.normal(size=(n, cond_dim)) + tgt_cond = rng.normal(size=(n, cond_dim)) src = _ot_data(rng, n=n, dim=d, condition=src_cond) tgt = _ot_data(rng, n=n, dim=d, condition=tgt_cond) @@ -65,10 +66,13 @@ def lin_dl_with_conds() -> DataLoader: @pytest.fixture() def conditional_lin_dl() -> datasets.ConditionalLoader: + cond_dim = 4 rng = np.random.default_rng(42) - src0, tgt0 = _ot_data(rng, condition=0.0), _ot_data(rng, offset=2.0) - src1, tgt1 = _ot_data(rng, condition=1.0), _ot_data(rng, offset=-2.0) + src0 = _ot_data(rng, condition=0.0, cond_dim=cond_dim) + tgt0 = _ot_data(rng, offset=2.0) + src1 = _ot_data(rng, condition=1.0, cond_dim=cond_dim) + tgt1 = _ot_data(rng, offset=-2.0) src_ds = datasets.OTDataset(src0, tgt0) tgt_ds = datasets.OTDataset(src1, tgt1) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index c60b2e064..1c9c985a9 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -21,7 +21,6 @@ import optax from ott.geometry import costs -from ott.neural.flow_models.genot import GENOTLin, GENOTQuad from ott.neural.flow_models.models import VelocityField from ott.neural.flow_models.samplers import uniform_sampler from ott.solvers.linear import sinkhorn, sinkhorn_lr diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index df2fb4bdb..ccca15214 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -11,112 +11,50 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import functools - import pytest import jax import jax.numpy as jnp -from torch.utils.data import DataLoader import optax -from ott.neural.data import datasets -from ott.neural.flow_models import flows, models, otfm, samplers, utils +from ott.neural.flow_models import flows, models, otfm, utils class TestOTFlowMatching: - def test_fm(self, lin_dl: DataLoader): - input_dim = 2 - neural_vf = models.VelocityField( - output_dim=2, - condition_dim=0, - latent_embed_dim=5, - ) - fm = otfm.OTFlowMatching( - input_dim, - neural_vf, - flow=flows.ConstantNoiseFlow(0.0), - time_sampler=samplers.uniform_sampler, - match_fn=jax.jit(utils.match_linear), - optimizer=optax.adam(learning_rate=1e-3), - ) - - _logs = fm(lin_dl, n_iters=2) - - for batch in lin_dl: - src = jnp.asarray(batch["src_lin"]) - tgt = jnp.asarray(batch["tgt_lin"]) - break - - res_fwd = fm.transport(src) - res_bwd = fm.transport(tgt, t0=1.0, t1=0.0) - - # TODO(michalk8): better assertions - assert jnp.sum(jnp.isnan(res_fwd)) == 0 - assert jnp.sum(jnp.isnan(res_bwd)) == 0 + @pytest.mark.parametrize(("cond_dim", "dl"), [(0, "lin_dl"), + (3, "lin_dl_with_conds"), + (4, "conditional_lin_dl")]) + def test_fm(self, rng: jax.Array, cond_dim: int, dl: str, request): + input_dim, output_dim, latent_dim = 2, 2, 5 + dl = request.getfixturevalue(dl) - def test_fm_with_conds(self, lin_dl_with_conds: DataLoader): - input_dim, cond_dim = 2, 1 neural_vf = models.VelocityField( - output_dim=input_dim, + output_dim=output_dim, condition_dim=cond_dim, - latent_embed_dim=5, + latent_embed_dim=latent_dim, ) fm = otfm.OTFlowMatching( - 2, neural_vf, - flow=flows.BrownianNoiseFlow(0.12), - time_sampler=functools.partial(samplers.uniform_sampler, offset=1e-5), + flows.ConstantNoiseFlow(0.0), match_fn=jax.jit(utils.match_linear), + rng=rng, optimizer=optax.adam(learning_rate=1e-3), + input_dim=input_dim, ) - _logs = fm(lin_dl_with_conds, n_iters=2) - - for batch in lin_dl_with_conds: - src = jnp.asarray(batch["src_lin"]) - tgt = 
jnp.asarray(batch["tgt_lin"]) - src_cond = jnp.asarray(batch["src_condition"]) - break - - res_fwd = fm.transport(src, condition=src_cond) - res_bwd = fm.transport(tgt, condition=src_cond, t0=1.0, t1=0.0) - - # TODO(michalk8): better assertions - assert jnp.sum(jnp.isnan(res_fwd)) == 0 - assert jnp.sum(jnp.isnan(res_bwd)) == 0 - - @pytest.mark.parametrize("rank", [-1, 10]) - def test_fm_conditional_loader( - self, rank: int, conditional_lin_dl: datasets.ConditionalLoader - ): - input_dim, cond_dim = 2, 0 - neural_vf = models.VelocityField( - output_dim=input_dim, - condition_dim=cond_dim, - latent_embed_dim=5, - ) - fm = otfm.OTFlowMatching( - input_dim, - neural_vf, - flow=flows.ConstantNoiseFlow(13.0), - time_sampler=samplers.uniform_sampler, - match_fn=jax.jit(functools.partial(utils.match_linear, rank=rank)), - optimizer=optax.adam(learning_rate=1e-3), - ) - - _logs = fm(conditional_lin_dl, n_iters=2) + _logs = fm(dl, n_iters=3) - for batch in conditional_lin_dl: - src = jnp.asarray(batch["src_lin"]) - tgt = jnp.asarray(batch["tgt_lin"]) - src_cond = jnp.asarray(batch["src_condition"]) - break + batch = next(iter(dl)) + src = jnp.asarray(batch["src_lin"]) + tgt = jnp.asarray(batch["tgt_lin"]) + src_cond = batch.get("src_condition") + if src_cond is not None: + src_cond = jnp.asarray(src_cond) res_fwd = fm.transport(src, condition=src_cond) - res_bwd = fm.transport(tgt, condition=src_cond, t0=1.0, t1=0.0) + res_bwd = fm.transport(tgt, t0=1.0, t1=0.0, condition=src_cond) # TODO(michalk8): better assertions assert jnp.sum(jnp.isnan(res_fwd)) == 0 From 4688998c4ce0b842d0784b74a33374a0fee47e32 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 16:41:46 +0100 Subject: [PATCH 126/186] Remove `MLPBlock` and `RescalingMLP` --- src/ott/neural/__init__.py | 2 +- src/ott/neural/duality/layers.py | 5 +- src/ott/neural/flow_models/models.py | 64 +++++--------- src/ott/neural/models/__init__.py | 14 ---- 
src/ott/neural/models/layers.py | 50 ----------- src/ott/neural/models/nets.py | 121 --------------------------- tests/neural/losses_test.py | 5 +- tests/neural/map_estimator_test.py | 6 +- 8 files changed, 28 insertions(+), 239 deletions(-) delete mode 100644 src/ott/neural/models/__init__.py delete mode 100644 src/ott/neural/models/layers.py delete mode 100644 src/ott/neural/models/nets.py diff --git a/src/ott/neural/__init__.py b/src/ott/neural/__init__.py index 678919a8c..10dac222c 100644 --- a/src/ott/neural/__init__.py +++ b/src/ott/neural/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import data, duality, flow_models, gaps, models +from . import data, duality, flow_models, gaps diff --git a/src/ott/neural/duality/layers.py b/src/ott/neural/duality/layers.py index e0d755d0e..6ed857452 100644 --- a/src/ott/neural/duality/layers.py +++ b/src/ott/neural/duality/layers.py @@ -79,7 +79,8 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: class PosDefPotentials(nn.Module): - r""":math:`\frac{1}{2} x^T (A_i A_i^T + \text{Diag}(d_i)) x + b_i^T x^2 + c_i` potentials. + r""":math:`\frac{1}{2} x^T (A_i A_i^T + \text{Diag}(d_i)) x + b_i^T x^2 + c_i` + potentials. This class implements a layer that takes (batched) ``d``-dimensional vectors ``x`` in, to output a ``num_potentials``-dimensional vector. Each of the @@ -111,7 +112,7 @@ class PosDefPotentials(nn.Module): bias_init: Initializer for the bias. The default is :func:`~flax.linen.initializers.zeros`. precision: Numerical precision of the computation. 
- """ # noqa: E501 + """ # noqa: D205,E501 num_potentials: int rank: int = 0 diff --git a/src/ott/neural/flow_models/models.py b/src/ott/neural/flow_models/models.py index 26cc31915..d0a07d66e 100644 --- a/src/ott/neural/flow_models/models.py +++ b/src/ott/neural/flow_models/models.py @@ -21,7 +21,6 @@ from flax.training import train_state import ott.neural.flow_models.layers as flow_layers -from ott.neural.models import layers __all__ = ["VelocityField"] @@ -37,7 +36,7 @@ class VelocityField(nn.Module): from :math:`t=t_0` to :math:`t=t_1`. Each of the input, condition, and time embeddings are passed through a block - consisting of ``num_layers_per_block`` layers of dimension + consisting of ``num_layers`` layers of dimension ``latent_embed_dim``, ``condition_embed_dim``, and ``time_embed_dim``, respectively. The output of each block is concatenated and passed through a final block of dimension ``joint_hidden_dim``. @@ -50,7 +49,7 @@ class VelocityField(nn.Module): If :obj:`None`, set to ``latent_embed_dim``. t_embed_dim: Dimensionality of the time embedding. If :obj:`None`, set to ``latent_embed_dim``. - num_layers_per_block: Number of layers per block. + num_layers: Number of layers. act_fn: Activation function. n_freqs: Number of frequencies to use for the time embedding. """ @@ -59,7 +58,7 @@ class VelocityField(nn.Module): condition_dim: int = 0 condition_embed_dim: Optional[int] = None t_embed_dim: Optional[int] = None - num_layers_per_block: int = 3 + num_layers: int = 3 act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu n_freqs: int = 128 @@ -83,52 +82,27 @@ def __call__( """Forward pass through the neural vector field. Args: - t: Time of shape `(batch_size, 1)`. - x: Data of shape `(batch_size, output_dim)`. - condition: Conditioning vector. + t: Time of shape ``[batch, 1]``. + x: Data of shape ``[batch, ...]``. + condition: Conditioning vector of shape ``[batch, cond_dim]``. Returns: - Output of the neural vector field. 
+ Output of the neural vector field of shape ``[batch, output_dim]``. """ t = flow_layers.CyclicalTimeEncoder(self.n_freqs)(t) - t_layer = layers.MLPBlock( - dim=self.t_embed_dim, - out_dim=self.t_embed_dim, - num_layers=self.num_layers_per_block, - act_fn=self.act_fn - ) - t = t_layer(t) - x_layer = layers.MLPBlock( - dim=self.latent_embed_dim, - out_dim=self.latent_embed_dim, - num_layers=self.num_layers_per_block, - act_fn=self.act_fn - ) - x = x_layer(x) - - if self.condition_dim > 0: - assert condition is not None, \ - "Condition must be specified when `condition_dim > 0`." - condition_layer = layers.MLPBlock( - # TODO(michalk8): doesn't fail with `condition_embed_dim` - dim=self.condition_dim, - out_dim=self.condition_embed_dim, - num_layers=self.num_layers_per_block, - act_fn=self.act_fn - ) - condition = condition_layer(condition) - concatenated = jnp.concatenate([t, x, condition], axis=-1) - else: - concatenated = jnp.concatenate([t, x], axis=-1) - - out_layer = layers.MLPBlock( - dim=self.joint_hidden_dim, - out_dim=self.joint_hidden_dim, - num_layers=self.num_layers_per_block, - act_fn=self.act_fn - ) - out = out_layer(concatenated) + for _ in range(self.num_layers): + t = self.act_fn(nn.Dense(self.t_embed_dim)(t)) + x = self.act_fn(nn.Dense(self.latent_embed_dim)(x)) + if self.condition_dim > 0: + assert condition is not None, "TODO." 
+ condition = self.act_fn(nn.Dense(self.condition_embed_dim)(condition)) + + arrs = [t, x] + ([] if condition is None else [condition]) + out = jnp.concatenate(arrs, axis=-1) + + for _ in range(self.num_layers): + out = self.act_fn(nn.Dense(self.joint_hidden_dim)(out)) return nn.Dense(self.output_dim, use_bias=True)(out) def create_train_state( diff --git a/src/ott/neural/models/__init__.py b/src/ott/neural/models/__init__.py deleted file mode 100644 index 83287aec5..000000000 --- a/src/ott/neural/models/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from . import layers, nets diff --git a/src/ott/neural/models/layers.py b/src/ott/neural/models/layers.py deleted file mode 100644 index d0352ff05..000000000 --- a/src/ott/neural/models/layers.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Any - -import jax.numpy as jnp - -import flax.linen as nn - -__all__ = ["MLPBlock"] - - -class MLPBlock(nn.Module): - """An MLP block. - - Args: - dim: Dimensionality of the input data. - num_layers: Number of layers in the MLP block. - act_fn: Activation function. - out_dim: Dimensionality of the output data. - """ - dim: int = 128 - num_layers: int = 3 - act_fn: Any = nn.silu - out_dim: int = 128 - - @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - """Apply the MLP block. - - Args: - x: Input data of shape (batch_size, dim). - - Returns: - Output data of shape (batch_size, out_dim). - """ - for _ in range(self.num_layers): - x = nn.Dense(self.dim)(x) - x = self.act_fn(x) - return nn.Dense(self.out_dim)(x) diff --git a/src/ott/neural/models/nets.py b/src/ott/neural/models/nets.py deleted file mode 100644 index cad4e84c2..000000000 --- a/src/ott/neural/models/nets.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Callable, Optional - -import jax -import jax.numpy as jnp - -import flax.linen as nn -import optax -from flax.training import train_state - -from ott.neural.models import layers - -__all__ = ["RescalingMLP"] - - -class RescalingMLP(nn.Module): - """Network to learn distributional rescaling factors based on a MLP. - - The input is passed through a block consisting of ``num_layers_per_block`` - with size ``hidden_dim``. 
- If ``condition_dim`` is greater than 0, the conditioning vector is passed - through a block of the same size. - Both outputs are concatenated and passed through another block of the same - size. - - To ensure non-negativity of the output, the output is exponentiated. - - Args: - hidden_dim: Dimensionality of the hidden layers. - condition_dim: Dimensionality of the conditioning vector. - num_layers_per_block: Number of layers per block. - act_fn: Activation function. - - Returns: - Non-negative escaling factors. - """ - hidden_dim: int - condition_dim: int = 0 - num_layers_per_block: int = 3 - act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.selu - - @nn.compact - def __call__( - self, - x: jnp.ndarray, - condition: Optional[jnp.ndarray] = None - ) -> jnp.ndarray: - """Forward pass through the rescaling network. - - Args: - x: Data of shape ``[n, ...]``. - condition: Condition of shape ``[n, condition_dim]``. - - Returns: - Estimated rescaling factors. - """ - x_layer = layers.MLPBlock( - dim=self.hidden_dim, - out_dim=self.hidden_dim, - num_layers=self.num_layers_per_block, - act_fn=self.act_fn - ) - x = x_layer(x) - - if self.condition_dim > 0: - condition_layer = layers.MLPBlock( - dim=self.hidden_dim, - out_dim=self.hidden_dim, - num_layers=self.num_layers_per_block, - act_fn=self.act_fn - ) - - condition = condition_layer(condition) - concatenated = jnp.concatenate((x, condition), axis=-1) - else: - concatenated = x - - out_layer = layers.MLPBlock( - dim=self.hidden_dim, - out_dim=1, - num_layers=self.num_layers_per_block, - act_fn=self.act_fn - ) - - out = out_layer(concatenated) - return jnp.exp(out) - - def create_train_state( - self, - rng: jax.Array, - optimizer: optax.OptState, - input_dim: int, - ) -> train_state.TrainState: - """Create the training state. - - Args: - rng: Random number generator. - optimizer: Optimizer. - input_dim: Dimensionality of the input. - - Returns: - Training state. 
- """ - params = self.init( - rng, jnp.ones((1, input_dim)), jnp.ones((1, self.condition_dim)) - )["params"] - return train_state.TrainState.create( - apply_fn=self.apply, params=params, tx=optimizer - ) diff --git a/tests/neural/losses_test.py b/tests/neural/losses_test.py index e26d8227b..e1e13f193 100644 --- a/tests/neural/losses_test.py +++ b/tests/neural/losses_test.py @@ -11,15 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import pytest import jax import numpy as np from ott.geometry import costs +from ott.neural.duality import models from ott.neural.gaps import monge_gap -from ott.neural.models import nets @pytest.mark.fast() @@ -35,7 +34,7 @@ def test_monge_gap_non_negativity( rng1, rng2 = jax.random.split(rng, 2) reference_points = jax.random.normal(rng1, (n_samples, n_features)) - model = nets.MLP(dim_hidden=[8, 8], is_potential=False) + model = models.PotentialMLP(dim_hidden=[8, 8], is_potential=False) params = model.init(rng2, x=reference_points[0]) target = model.apply(params, reference_points) diff --git a/tests/neural/map_estimator_test.py b/tests/neural/map_estimator_test.py index 399dff39d..cee66e40e 100644 --- a/tests/neural/map_estimator_test.py +++ b/tests/neural/map_estimator_test.py @@ -19,8 +19,8 @@ from ott import datasets from ott.geometry import pointcloud +from ott.neural.duality import models from ott.neural.gaps import map_estimator, monge_gap -from ott.neural.models import nets from ott.tools import sinkhorn_divergence @@ -44,14 +44,14 @@ def fitting_loss( x=samples, y=mapped_samples, ).divergence - return (div, None) + return div, None def regularizer(x, y): gap, out = monge_gap.monge_gap_from_samples(x, y, return_output=True) return gap, out.n_iters # define the model - model = nets.MLP(dim_hidden=[16, 8], is_potential=False) + model = models.PotentialMLP(dim_hidden=[16, 8], 
is_potential=False) # generate data train_dataset, valid_dataset, dim_data = ( From 52c5de985f4cd02c2536a35a27e0f5ab3773f4ad Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 16:43:02 +0100 Subject: [PATCH 127/186] Add forgotten license --- src/ott/neural/flow_models/utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py index 9b8386ea9..21f91b350 100644 --- a/src/ott/neural/flow_models/utils.py +++ b/src/ott/neural/flow_models/utils.py @@ -1,3 +1,16 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import Any, Literal, Optional, Tuple, Union import jax From 0b417d76299aeeff18636b9d60deb07238418906 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 17:13:58 +0100 Subject: [PATCH 128/186] Remove `__post_init__` from `VF` --- src/ott/neural/flow_models/models.py | 61 ++++++++++++---------------- tests/neural/genot_test.py | 12 +++--- tests/neural/otfm_test.py | 11 ++--- 3 files changed, 39 insertions(+), 45 deletions(-) diff --git a/src/ott/neural/flow_models/models.py b/src/ott/neural/flow_models/models.py index d0a07d66e..ff3183016 100644 --- a/src/ott/neural/flow_models/models.py +++ b/src/ott/neural/flow_models/models.py @@ -37,41 +37,29 @@ class VelocityField(nn.Module): Each of the input, condition, and time embeddings are passed through a block consisting of ``num_layers`` layers of dimension - ``latent_embed_dim``, ``condition_embed_dim``, and ``time_embed_dim``, + ``hidden_dim``, ``condition_dim``, and ``time_embed_dim``, respectively. The output of each block is concatenated and passed through a final block of dimension ``joint_hidden_dim``. Args: output_dim: Dimensionality of the neural vector field. - latent_embed_dim: Dimensionality of the embedding of the data. - condition_dim: Dimensionality of the conditioning vector. - condition_embed_dim: Dimensionality of the embedding of the condition. - If :obj:`None`, set to ``latent_embed_dim``. - t_embed_dim: Dimensionality of the time embedding. - If :obj:`None`, set to ``latent_embed_dim``. + hidden_dim: Dimensionality of the embedding of the data. num_layers: Number of layers. + condition_dim: Dimensionality of the embedding of the condition. + If :obj:`None`, TODO. + time_dim: Dimensionality of the time embedding. + If :obj:`None`, set to ``hidden_dim``. act_fn: Activation function. n_freqs: Number of frequencies to use for the time embedding. 
""" output_dim: int - latent_embed_dim: int - condition_dim: int = 0 - condition_embed_dim: Optional[int] = None - t_embed_dim: Optional[int] = None + hidden_dim: int num_layers: int = 3 + condition_dim: Optional[int] = None + time_dim: Optional[int] = None act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu n_freqs: int = 128 - def __post_init__(self) -> None: - if self.condition_embed_dim is None: - self.condition_embed_dim = self.latent_embed_dim - if self.t_embed_dim is None: - self.t_embed_dim = self.latent_embed_dim - self.joint_hidden_dim = ( - self.latent_embed_dim + self.condition_embed_dim + self.t_embed_dim - ) - super().__post_init__() - @nn.compact def __call__( self, @@ -84,32 +72,36 @@ def __call__( Args: t: Time of shape ``[batch, 1]``. x: Data of shape ``[batch, ...]``. - condition: Conditioning vector of shape ``[batch, cond_dim]``. + condition: Conditioning vector of shape ``[batch, ...]``. Returns: Output of the neural vector field of shape ``[batch, output_dim]``. """ + time_dim = self.hidden_dim if self.time_dim is None else self.time_dim t = flow_layers.CyclicalTimeEncoder(self.n_freqs)(t) for _ in range(self.num_layers): - t = self.act_fn(nn.Dense(self.t_embed_dim)(t)) - x = self.act_fn(nn.Dense(self.latent_embed_dim)(x)) - if self.condition_dim > 0: + t = self.act_fn(nn.Dense(time_dim)(t)) + x = self.act_fn(nn.Dense(self.hidden_dim)(x)) + if self.condition_dim is not None: assert condition is not None, "TODO." 
- condition = self.act_fn(nn.Dense(self.condition_embed_dim)(condition)) + condition = self.act_fn(nn.Dense(self.condition_dim)(condition)) - arrs = [t, x] + ([] if condition is None else [condition]) - out = jnp.concatenate(arrs, axis=-1) + feats = [t, x] + ([] if condition is None else [condition]) + feats = jnp.concatenate(feats, axis=-1) + joint_dim = feats.shape[-1] for _ in range(self.num_layers): - out = self.act_fn(nn.Dense(self.joint_hidden_dim)(out)) - return nn.Dense(self.output_dim, use_bias=True)(out) + feats = self.act_fn(nn.Dense(joint_dim)(feats)) + + return nn.Dense(self.output_dim, use_bias=True)(feats) def create_train_state( self, rng: jax.Array, optimizer: optax.OptState, input_dim: int, + cond_dim: Optional[int] = None, ) -> train_state.TrainState: """Create the training state. @@ -117,14 +109,15 @@ def create_train_state( rng: Random number generator. optimizer: Optimizer. input_dim: Dimensionality of the input. + cond_dim: TODO. Returns: The training state. """ - params = self.init( - rng, jnp.ones((1, 1)), jnp.ones((1, input_dim)), - jnp.ones((1, self.condition_dim)) - )["params"] + t, x = jnp.ones((1, 1)), jnp.ones((1, input_dim)) + cond = jnp.ones((1, cond_dim)) if self.condition_dim is not None else None + + params = self.init(rng, t, x, cond)["params"] return train_state.TrainState.create( apply_fn=self.apply, params=params, tx=optimizer ) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 1c9c985a9..50e7dd504 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -60,7 +60,7 @@ def test_genot_linear_unconditional( neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, - latent_embed_dim=5, + hidden_dim=5, ) ot_solver = sinkhorn.Sinkhorn( ) if solver == "sinkhorn" else sinkhorn_lr.LRSinkhorn(rank=3) @@ -121,7 +121,7 @@ def test_genot_linear_conditional( neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, - 
latent_embed_dim=5, + hidden_dim=5, ) ot_solver = sinkhorn.Sinkhorn( ) if solver == "sinkhorn" else sinkhorn_lr.LRSinkhorn(rank=3) @@ -182,7 +182,7 @@ def test_genot_quad_unconditional( neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, - latent_embed_dim=5, + hidden_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein( epsilon=1e-2 @@ -243,7 +243,7 @@ def test_genot_fused_unconditional( neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, - latent_embed_dim=5, + hidden_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein( epsilon=1e-2 @@ -306,7 +306,7 @@ def test_genot_quad_conditional( neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, - latent_embed_dim=5, + hidden_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein( epsilon=1e-2 @@ -372,7 +372,7 @@ def test_genot_fused_conditional( neural_vf = VelocityField( output_dim=target_dim, condition_dim=source_dim + condition_dim, - latent_embed_dim=5, + hidden_dim=5, ) ot_solver = gromov_wasserstein.GromovWasserstein( epsilon=1e-2 diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index ccca15214..60044293d 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -27,13 +27,13 @@ class TestOTFlowMatching: (3, "lin_dl_with_conds"), (4, "conditional_lin_dl")]) def test_fm(self, rng: jax.Array, cond_dim: int, dl: str, request): - input_dim, output_dim, latent_dim = 2, 2, 5 + output_dim, hidden_dim = 2, 5 dl = request.getfixturevalue(dl) neural_vf = models.VelocityField( - output_dim=output_dim, - condition_dim=cond_dim, - latent_embed_dim=latent_dim, + output_dim, + hidden_dim, + condition_dim=hidden_dim if cond_dim > 0 else None, ) fm = otfm.OTFlowMatching( neural_vf, @@ -41,7 +41,8 @@ def test_fm(self, rng: jax.Array, cond_dim: int, dl: str, request): match_fn=jax.jit(utils.match_linear), rng=rng, optimizer=optax.adam(learning_rate=1e-3), - 
input_dim=input_dim, + input_dim=2, # all dataloaders have dim `2` + cond_dim=cond_dim, ) _logs = fm(dl, n_iters=3) From fe74a57f0a63cb5c3f0dbf9a01e7e49c5d0419de Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 17:23:54 +0100 Subject: [PATCH 129/186] Move cyclical time encoder --- src/ott/neural/flow_models/__init__.py | 2 +- src/ott/neural/flow_models/layers.py | 44 -------------------------- src/ott/neural/flow_models/models.py | 13 ++++---- src/ott/neural/flow_models/utils.py | 19 +++++++++++ tests/neural/otfm_test.py | 2 +- 5 files changed, 28 insertions(+), 52 deletions(-) delete mode 100644 src/ott/neural/flow_models/layers.py diff --git a/src/ott/neural/flow_models/__init__.py b/src/ott/neural/flow_models/__init__.py index cc2c4bfdb..2d6fca4b5 100644 --- a/src/ott/neural/flow_models/__init__.py +++ b/src/ott/neural/flow_models/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import flows, genot, layers, models, otfm, samplers +from . import flows, genot, models, otfm, samplers diff --git a/src/ott/neural/flow_models/layers.py b/src/ott/neural/flow_models/layers.py deleted file mode 100644 index 2f87f6cfc..000000000 --- a/src/ott/neural/flow_models/layers.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -import jax.numpy as jnp - -import flax.linen as nn - -__all__ = ["CyclicalTimeEncoder"] - - -class CyclicalTimeEncoder(nn.Module): - r"""A cyclical time encoder. - - Encodes time :math:`t` as :math:`cos(\hat{t})` and :math:`sin(\hat{t})` - where :math:`\hat{t} = [2\pi t, 2\pi 2 t,\dots, 2\pi n_f t]`. - - Args: - n_freqs: Frequency :math:`n_f` of the cyclical encoding. - """ - n_freqs: int = 128 - - @nn.compact - def __call__(self, t: jnp.ndarray) -> jnp.ndarray: # noqa: D102 - """Encode time :math:`t` into a cyclical representation. - - Args: - t: Time of shape ``[n, 1]``. - - Returns: - Encoded time of shape ``[n, 2 * n_freqs]``. - """ - freq = 2 * jnp.arange(self.n_freqs) * jnp.pi - t = freq * t - return jnp.concatenate([jnp.cos(t), jnp.sin(t)], axis=-1) diff --git a/src/ott/neural/flow_models/models.py b/src/ott/neural/flow_models/models.py index ff3183016..5590113e8 100644 --- a/src/ott/neural/flow_models/models.py +++ b/src/ott/neural/flow_models/models.py @@ -20,7 +20,7 @@ import optax from flax.training import train_state -import ott.neural.flow_models.layers as flow_layers +from ott.neural.flow_models import utils __all__ = ["VelocityField"] @@ -42,23 +42,24 @@ class VelocityField(nn.Module): a final block of dimension ``joint_hidden_dim``. Args: - output_dim: Dimensionality of the neural vector field. hidden_dim: Dimensionality of the embedding of the data. + output_dim: Dimensionality of the neural vector field. num_layers: Number of layers. condition_dim: Dimensionality of the embedding of the condition. If :obj:`None`, TODO. time_dim: Dimensionality of the time embedding. If :obj:`None`, set to ``hidden_dim``. + time_encoder: TODO. act_fn: Activation function. - n_freqs: Number of frequencies to use for the time embedding. 
""" - output_dim: int hidden_dim: int + output_dim: int num_layers: int = 3 condition_dim: Optional[int] = None time_dim: Optional[int] = None + time_encoder: Callable[[jnp.ndarray], + jnp.ndarray] = utils.cyclical_time_encoder act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu - n_freqs: int = 128 @nn.compact def __call__( @@ -78,8 +79,8 @@ def __call__( Output of the neural vector field of shape ``[batch, output_dim]``. """ time_dim = self.hidden_dim if self.time_dim is None else self.time_dim - t = flow_layers.CyclicalTimeEncoder(self.n_freqs)(t) + t = self.time_encoder(t) for _ in range(self.num_layers): t = self.act_fn(nn.Dense(time_dim)(t)) x = self.act_fn(nn.Dense(self.hidden_dim)(x)) diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py index 21f91b350..d385edb2e 100644 --- a/src/ott/neural/flow_models/utils.py +++ b/src/ott/neural/flow_models/utils.py @@ -24,6 +24,7 @@ "match_quadratic", "sample_joint", "sample_conditional", + "cyclical_time_encoder", ] ScaleCost_t = Union[float, Literal["mean", "max_cost", "median"]] @@ -106,3 +107,21 @@ def sample_conditional( src_ixs = jnp.repeat(indices[:, None], k, axis=1) # (n, k) return src_ixs, tgt_ixs + + +def cyclical_time_encoder(t: jnp.ndarray, n_freqs: int = 128) -> jnp.ndarray: + r"""Encode time :math:`t` into a cyclical representation. + + Time :math:`t` is encoded as :math:`cos(\hat{t})` and :math:`sin(\hat{t})` + where :math:`\hat{t} = [2\pi t, 2\pi 2 t,\dots, 2\pi n_f t]`. + + Args: + t: Time of shape ``[n, 1]``. + n_freqs: Frequency :math:`n_f` of the cyclical encoding. + + Returns: + Encoded time of shape ``[n, 2 * n_freqs]``. 
+ """ + freq = 2 * jnp.arange(n_freqs) * jnp.pi + t = freq * t + return jnp.concatenate([jnp.cos(t), jnp.sin(t)], axis=-1) diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 60044293d..30e38dba6 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -31,8 +31,8 @@ def test_fm(self, rng: jax.Array, cond_dim: int, dl: str, request): dl = request.getfixturevalue(dl) neural_vf = models.VelocityField( - output_dim, hidden_dim, + output_dim, condition_dim=hidden_dim if cond_dim > 0 else None, ) fm = otfm.OTFlowMatching( From 4affc14375518b4f064c5e2b067373e7c0eac14a Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 17:27:58 +0100 Subject: [PATCH 130/186] Move more stuff to `utils` --- src/ott/neural/flow_models/__init__.py | 2 +- src/ott/neural/flow_models/otfm.py | 8 ++-- src/ott/neural/flow_models/samplers.py | 52 -------------------------- src/ott/neural/flow_models/utils.py | 34 +++++++++++++++++ tests/neural/genot_test.py | 2 +- 5 files changed, 40 insertions(+), 58 deletions(-) diff --git a/src/ott/neural/flow_models/__init__.py b/src/ott/neural/flow_models/__init__.py index 2d6fca4b5..a6239fa06 100644 --- a/src/ott/neural/flow_models/__init__.py +++ b/src/ott/neural/flow_models/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import flows, genot, models, otfm, samplers +from . 
import flows, genot, models, otfm, utils diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index d80d8e6b8..d436fd9a0 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -21,8 +21,8 @@ from flax.training import train_state from ott import utils -from ott.neural.flow_models import flows, models, samplers -from ott.neural.flow_models.utils import sample_joint +from ott.neural.flow_models import flows, models +from ott.neural.flow_models import utils as flow_utils __all__ = ["OTFlowMatching"] @@ -47,7 +47,7 @@ def __init__( velocity_field: models.VelocityField, flow: flows.BaseFlow, time_sampler: Callable[[jax.Array, int], - jnp.ndarray] = samplers.uniform_sampler, + jnp.ndarray] = flow_utils.uniform_sampler, match_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None, **kwargs: Any, @@ -116,7 +116,7 @@ def __call__( # noqa: D102 if self.match_fn is not None: tmat = self.match_fn(src, tgt) - src_ixs, tgt_ixs = sample_joint(rng_resample, tmat) + src_ixs, tgt_ixs = flow_utils.sample_joint(rng_resample, tmat) src, tgt = src[src_ixs], tgt[tgt_ixs] src_cond = None if src_cond is None else src_cond[src_ixs] diff --git a/src/ott/neural/flow_models/samplers.py b/src/ott/neural/flow_models/samplers.py index 9bd85d8b0..e69de29bb 100644 --- a/src/ott/neural/flow_models/samplers.py +++ b/src/ott/neural/flow_models/samplers.py @@ -1,52 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -import jax -import jax.numpy as jnp - -__all__ = ["uniform_sampler"] - - -def uniform_sampler( - rng: jax.Array, - num_samples: int, - low: float = 0.0, - high: float = 1.0, - offset: Optional[float] = None -) -> jnp.ndarray: - r"""Sample from a uniform distribution. - - Sample :math:`t` from a uniform distribution :math:`[low, high]`. - If `offset` is not :obj:`None`, one element :math:`t` is sampled from - :math:`[low, high]` and the K samples are constructed via - :math:`(t + k)/K \mod (high - low - offset) + low`. - - Args: - rng: Random number generator. - num_samples: Number of samples to generate. - low: Lower bound of the uniform distribution. - high: Upper bound of the uniform distribution. - offset: Offset of the uniform distribution. If :obj:`None`, no offset is - used. - - Returns: - An array with `num_samples` samples of the time :math:`t`. 
- """ - if offset is None: - return jax.random.uniform(rng, (num_samples, 1), minval=low, maxval=high) - - t = jax.random.uniform(rng, (1, 1), minval=low, maxval=high) - mod_term = ((high - low) - offset) - return (t + jnp.arange(num_samples)[:, None] / num_samples) % mod_term diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py index d385edb2e..516342ed6 100644 --- a/src/ott/neural/flow_models/utils.py +++ b/src/ott/neural/flow_models/utils.py @@ -25,6 +25,7 @@ "sample_joint", "sample_conditional", "cyclical_time_encoder", + "uniform_sampler", ] ScaleCost_t = Union[float, Literal["mean", "max_cost", "median"]] @@ -125,3 +126,36 @@ def cyclical_time_encoder(t: jnp.ndarray, n_freqs: int = 128) -> jnp.ndarray: freq = 2 * jnp.arange(n_freqs) * jnp.pi t = freq * t return jnp.concatenate([jnp.cos(t), jnp.sin(t)], axis=-1) + + +def uniform_sampler( + rng: jax.Array, + num_samples: int, + low: float = 0.0, + high: float = 1.0, + offset: Optional[float] = None +) -> jnp.ndarray: + r"""Sample from a uniform distribution. + + Sample :math:`t` from a uniform distribution :math:`[low, high]`. + If `offset` is not :obj:`None`, one element :math:`t` is sampled from + :math:`[low, high]` and the K samples are constructed via + :math:`(t + k)/K \mod (high - low - offset) + low`. + + Args: + rng: Random number generator. + num_samples: Number of samples to generate. + low: Lower bound of the uniform distribution. + high: Upper bound of the uniform distribution. + offset: Offset of the uniform distribution. If :obj:`None`, no offset is + used. + + Returns: + An array with `num_samples` samples of the time :math:`t`. 
+ """ + if offset is None: + return jax.random.uniform(rng, (num_samples, 1), minval=low, maxval=high) + + t = jax.random.uniform(rng, (1, 1), minval=low, maxval=high) + mod_term = ((high - low) - offset) + return (t + jnp.arange(num_samples)[:, None] / num_samples) % mod_term diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 50e7dd504..abd8aad94 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -22,7 +22,7 @@ from ott.geometry import costs from ott.neural.flow_models.models import VelocityField -from ott.neural.flow_models.samplers import uniform_sampler +from ott.neural.flow_models.utils import uniform_sampler from ott.solvers.linear import sinkhorn, sinkhorn_lr from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr From 21ce5230646a0794e7918e1bd86e762906a5e8c8 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 17:29:08 +0100 Subject: [PATCH 131/186] Remove `samplers.py` --- src/ott/neural/flow_models/samplers.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/ott/neural/flow_models/samplers.py diff --git a/src/ott/neural/flow_models/samplers.py b/src/ott/neural/flow_models/samplers.py deleted file mode 100644 index e69de29bb..000000000 From aa636ef11d7fd80104d26aab29449347fb064167 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 17:39:54 +0100 Subject: [PATCH 132/186] Rename `cond_dim` -> `condition_dim` --- src/ott/neural/flow_models/genot.py | 2 +- src/ott/neural/flow_models/models.py | 8 ++++---- tests/neural/otfm_test.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index 2ac511611..b27360d1c 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -49,7 +49,7 @@ def __init__( velocity_field: 
models.VelocityField, flow: flows.BaseFlow, time_sampler: Callable[[jax.Array, int], jnp.ndarray], - # TODO(mcihalk8): all args are optional + # TODO(michalk8): all args are optional data_match_fn: Callable[ [jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray], latent_match_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], diff --git a/src/ott/neural/flow_models/models.py b/src/ott/neural/flow_models/models.py index 5590113e8..a3b3261b0 100644 --- a/src/ott/neural/flow_models/models.py +++ b/src/ott/neural/flow_models/models.py @@ -28,7 +28,7 @@ class VelocityField(nn.Module): r"""Parameterized neural vector field. - The `VelocityField` learns a map :math:`v: \mathbb{R}\times \mathbb{R}^d + This class learns a map :math:`v: \mathbb{R}\times \mathbb{R}^d \rightarrow \mathbb{R}^d` solving the ODE :math:`\frac{dx}{dt} = v(t, x)`. Given a source distribution at time :math:`t_0`, the velocity field can be used to transport the source distribution given at :math:`t_0` to @@ -102,7 +102,7 @@ def create_train_state( rng: jax.Array, optimizer: optax.OptState, input_dim: int, - cond_dim: Optional[int] = None, + condition_dim: Optional[int] = None, ) -> train_state.TrainState: """Create the training state. @@ -110,13 +110,13 @@ def create_train_state( rng: Random number generator. optimizer: Optimizer. input_dim: Dimensionality of the input. - cond_dim: TODO. + condition_dim: TODO. Returns: The training state. 
""" t, x = jnp.ones((1, 1)), jnp.ones((1, input_dim)) - cond = jnp.ones((1, cond_dim)) if self.condition_dim is not None else None + cond = None if self.condition_dim is None else jnp.ones((1, condition_dim)) params = self.init(rng, t, x, cond)["params"] return train_state.TrainState.create( diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 30e38dba6..a9b799d4a 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -42,7 +42,7 @@ def test_fm(self, rng: jax.Array, cond_dim: int, dl: str, request): rng=rng, optimizer=optax.adam(learning_rate=1e-3), input_dim=2, # all dataloaders have dim `2` - cond_dim=cond_dim, + condition_dim=cond_dim, ) _logs = fm(dl, n_iters=3) From da0ef92c3a626068f6db3158af533ea82d8413c8 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 17:56:08 +0100 Subject: [PATCH 133/186] Nicer formatting --- src/ott/neural/flow_models/flows.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/ott/neural/flow_models/flows.py b/src/ott/neural/flow_models/flows.py index 150e2086e..2cde34833 100644 --- a/src/ott/neural/flow_models/flows.py +++ b/src/ott/neural/flow_models/flows.py @@ -44,9 +44,9 @@ def compute_mu_t( at time :math:`t`. Args: - t: Time :math:`t` of shape `(batch_size, 1)`. - src: Sample from the source distribution of shape `(batch_size, ...)`. - tgt: Sample from the target distribution of shape `(batch_size, ...)`. + t: Time :math:`t` of shape ``[batch, 1]``. + src: Sample from the source distribution of shape ``[batch, ...]``. + tgt: Sample from the target distribution of shape ``[batch, ...]``. """ @abc.abstractmethod @@ -54,7 +54,7 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: """Compute the standard deviation of the probability path at time :math:`t`. Args: - t: Time :math:`t` of shape `(batch_size, 1)`. + t: Time :math:`t` of shape ``[batch, 1]``. 
Returns: Standard deviation of the probability path at time :math:`t`. @@ -70,9 +70,9 @@ def compute_ut( :math:`x_1` at time :math:`t`. Args: - t: Time :math:`t` of shape `(batch_size, 1)`. - src: Sample from the source distribution of shape `(batch_size, ...)`. - tgt: Sample from the target distribution of shape `(batch_size, ...)`. + t: Time :math:`t` of shape ``[batch, 1]``. + src: Sample from the source distribution of shape ``[batch, ...]``. + tgt: Sample from the target distribution of shape ``[batch, ...]``. Returns: Conditional vector field evaluated at time :math:`t`. @@ -88,9 +88,9 @@ def compute_xt( Args: rng: Random number generator. - t: Time :math:`t` of shape `(batch_size, 1)`. - src: Sample from the source distribution of shape `(batch_size, ...)`. - tgt: Sample from the target distribution of shape `(batch_size, ...)`. + t: Time :math:`t` of shape ``[batch, 1]``. + src: Sample from the source distribution of shape ``[batch, ...]``. + tgt: Sample from the target distribution of shape ``[batch, ...]``. Returns: Samples from the probability path between :math:`x_0` and :math:`x_1` @@ -132,7 +132,7 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: r"""Compute noise of the flow at time :math:`t`. Args: - t: Time :math:`t` of shape `(batch_size, 1)`. + t: Time :math:`t` of shape ``[batch, 1]``. Returns: Constant, time-independent standard deviation :math:`\sigma`. @@ -155,7 +155,7 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: r"""Compute noise of the flow at time :math:`t`. Args: - t: Time :math:`t` of shape `(batch_size, 1)`. + t: Time :math:`t` of shape ``[batch, 1]``. 
Returns: Samples from the probability path between :math:`x_0` and :math:`x_1` From de1c2646e8033094f9dc05a2335ff16ecba8b824 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 17:57:19 +0100 Subject: [PATCH 134/186] Fix bug when sampling from the target --- src/ott/neural/data/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ott/neural/data/datasets.py b/src/ott/neural/data/datasets.py index eca6f1e51..fb1bc345b 100644 --- a/src/ott/neural/data/datasets.py +++ b/src/ott/neural/data/datasets.py @@ -89,7 +89,7 @@ def _sample_from_target(self, src_ix: int) -> Item_t: src_cond = self.src_conditions[src_ix] tgt_ixs = self._tgt_cond_to_ix[src_cond] ix = self._rng.choice(tgt_ixs) - return self.src_data[ix] + return self.tgt_data[ix] def __getitem__(self, ix: int) -> Item_t: src = self.src_data[ix] From ce763f043befb5457c1975c3db0c03ec3b93977e Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 17:57:42 +0100 Subject: [PATCH 135/186] Fix another bug when sampling from the data --- src/ott/neural/data/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ott/neural/data/datasets.py b/src/ott/neural/data/datasets.py index fb1bc345b..63215e61e 100644 --- a/src/ott/neural/data/datasets.py +++ b/src/ott/neural/data/datasets.py @@ -95,7 +95,7 @@ def __getitem__(self, ix: int) -> Item_t: src = self.src_data[ix] src = {f"{self.SRC_PREFIX}_{k}": v for k, v in src.items()} - tgt = self.src_data[ix] if self.is_aligned else self._sample_from_target(ix) + tgt = self.tgt_data[ix] if self.is_aligned else self._sample_from_target(ix) tgt = {f"{self.TGT_PREFIX}_{k}": v for k, v in tgt.items()} return {**src, **tgt} From f9db2db98332f82fbd948e6ae1afd30e7ad5be19 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 19:04:25 +0100 Subject: [PATCH 136/186] 
Add initial test for GW --- src/ott/neural/flow_models/genot.py | 32 ++++++++---- src/ott/neural/flow_models/otfm.py | 10 ++-- src/ott/neural/flow_models/utils.py | 13 +++++ tests/neural/conftest.py | 80 +++++++++++++---------------- tests/neural/genot_test.py | 56 ++++++++++++++++++-- tests/neural/otfm_test.py | 8 +-- 6 files changed, 130 insertions(+), 69 deletions(-) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index b27360d1c..34a207b50 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -34,10 +34,10 @@ class GENOT: Args: velocity_field: Neural vector field parameterized by a neural network. flow: Flow between latent distribution and target distribution. - time_sampler: Sampler for the time. - of an input sample, see algorithm TODO. data_match_fn: Linear OT solver to match the latent distribution with the conditional distribution. + time_sampler: Sampler for the time. + of an input sample, see algorithm TODO. latent_match_fn: TODO. latent_noise_fn: TODO. k_samples_per_x: Number of samples drawn from the conditional distribution @@ -48,23 +48,28 @@ def __init__( self, velocity_field: models.VelocityField, flow: flows.BaseFlow, - time_sampler: Callable[[jax.Array, int], jnp.ndarray], - # TODO(michalk8): all args are optional data_match_fn: Callable[ [jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray], + time_sampler: Callable[[jax.Array, int], + jnp.ndarray] = flow_utils.uniform_sampler, latent_match_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None, - # TODO(michalk8): add a default for this? 
latent_noise_fn: Optional[Callable[[jax.Array, Tuple[int, ...]], jnp.ndarray]] = None, + # TODO(michalk8): rename, too descriptive k_samples_per_x: int = 1, **kwargs: Any, ): self.vf = velocity_field self.flow = flow - self.time_sampler = time_sampler self.data_match_fn = data_match_fn + self.time_sampler = time_sampler self.latent_match_fn = latent_match_fn + if latent_noise_fn is None: + dim = kwargs["input_dim"] + latent_noise_fn = functools.partial( + flow_utils.multivariate_normal, dim=dim + ) self.latent_noise_fn = latent_noise_fn self.k_samples_per_x = k_samples_per_x @@ -90,13 +95,14 @@ def loss_fn( source_conditions: Optional[jnp.ndarray], rng: jax.Array ): x_t = self.flow.compute_xt(rng, time, latent, target) - apply_fn = functools.partial(vf_state.apply_fn, {"params": params}) + cond = ( + source if source_conditions is None else + jnp.concatenate([source, source_conditions], axis=-1) + ) - cond_input = jnp.concatenate([ - source, source_conditions - ], axis=1) if source_conditions is not None else source - v_t = jax.vmap(apply_fn)(t=time, x=x_t, condition=cond_input) + v_t = vf_state.apply_fn({"params": params}, time, x_t, cond) u_t = self.flow.compute_ut(time, latent, target) + return jnp.mean((v_t - u_t) ** 2) grad_fn = jax.value_and_grad(loss_fn, has_aux=False) @@ -104,6 +110,7 @@ def loss_fn( vf_state.params, time, source, target, latent, source_conditions, rng ) + # TODO(michalk8): follow the convention with loss being first return vf_state.apply_gradients(grads=grads), loss return step_fn @@ -174,7 +181,10 @@ def prepare_data( self.vf_state, loss = self.step_fn( rng_step_fn, self.vf_state, time, src, tgt, latent, src_cond ) + training_logs["loss"].append(float(loss)) + if len(training_logs["loss"]) >= n_iters: + break return training_logs diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index d436fd9a0..a8b85a44e 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -35,8 
+35,8 @@ class OTFlowMatching: Args: velocity_field: Neural vector field parameterized by a neural network. flow: Flow between source and target distribution. - time_sampler: Sampler for the time. match_fn: TODO. + time_sampler: Sampler for the time. kwargs: TODO. """ @@ -46,10 +46,10 @@ def __init__( self, velocity_field: models.VelocityField, flow: flows.BaseFlow, - time_sampler: Callable[[jax.Array, int], - jnp.ndarray] = flow_utils.uniform_sampler, match_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None, + time_sampler: Callable[[jax.Array, int], + jnp.ndarray] = flow_utils.uniform_sampler, **kwargs: Any, ): self.vf = velocity_field @@ -79,10 +79,8 @@ def loss_fn( x_t = self.flow.compute_xt(rng, t, source, target) v_t = vf_state.apply_fn({"params": params}, t, x_t, source_conditions) - # TODO(michalk8): should be removed - # apply_fn = functools.partial(vf_state.apply_fn, {"params": params}) - # v_t = jax.vmap(apply_fn)(t, x_t, source_conditions) u_t = self.flow.compute_ut(t, source, target) + return jnp.mean((v_t - u_t) ** 2) batch_size = len(source) diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py index 516342ed6..a656f440f 100644 --- a/src/ott/neural/flow_models/utils.py +++ b/src/ott/neural/flow_models/utils.py @@ -26,6 +26,7 @@ "sample_conditional", "cyclical_time_encoder", "uniform_sampler", + "multivariate_normal", ] ScaleCost_t = Union[float, Literal["mean", "max_cost", "median"]] @@ -159,3 +160,15 @@ def uniform_sampler( t = jax.random.uniform(rng, (1, 1), minval=low, maxval=high) mod_term = ((high - low) - offset) return (t + jnp.arange(num_samples)[:, None] / num_samples) % mod_term + + +def multivariate_normal( + rng: jax.Array, + shape: Tuple[int, ...], + dim: int, + mean: float = 0.0, + cov: float = 1.0 +) -> jnp.ndarray: + mean = jnp.full(dim, fill_value=mean) + cov = jnp.diag(jnp.full(dim, fill_value=cov)) + return jax.random.multivariate_normal(rng, mean=mean, cov=cov, shape=shape) diff 
--git a/tests/neural/conftest.py b/tests/neural/conftest.py index c3cd11ce7..e7de132e8 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -26,39 +26,50 @@ def _ot_data( rng: np.random.Generator, *, n: int = 100, - dim: int = 2, + lin_dim: Optional[int] = None, + quad_dim: Optional[int] = None, condition: Optional[Union[float, np.ndarray]] = None, cond_dim: Optional[int] = None, offset: float = 0.0 ) -> datasets.OTData: - data = rng.normal(size=(n, dim)) + offset + assert lin_dim or quad_dim, "TODO" + + lin_data = None if lin_dim is None else ( + rng.normal(size=(n, lin_dim)) + offset + ) + quad_data = None if quad_dim is None else ( + rng.normal(size=(n, quad_dim)) + offset + ) if isinstance(condition, float): - cond_dim = dim if cond_dim is None else cond_dim + cond_dim = lin_dim if cond_dim is None else cond_dim condition = np.full((n, cond_dim), fill_value=condition) - return datasets.OTData(lin=data, condition=condition) + return datasets.OTData(lin=lin_data, quad=quad_data, condition=condition) @pytest.fixture() def lin_dl() -> DataLoader: """Returns a data loader for a simple Gaussian mixture.""" - n, d = 100, 2 + n, d = 128, 2 rng = np.random.default_rng(0) - src, tgt = _ot_data(rng, n=n, dim=d), _ot_data(rng, n=n, dim=d, offset=1.0) + + src = _ot_data(rng, n=n, lin_dim=d) + tgt = _ot_data(rng, n=n, lin_dim=d, offset=1.0) ds = datasets.OTDataset(src, tgt) + return DataLoader(ds, batch_size=16, shuffle=True) @pytest.fixture() def lin_dl_with_conds() -> DataLoader: - n, d, cond_dim = 100, 2, 3 + n, d, cond_dim = 128, 2, 3 rng = np.random.default_rng(13) src_cond = rng.normal(size=(n, cond_dim)) tgt_cond = rng.normal(size=(n, cond_dim)) - src = _ot_data(rng, n=n, dim=d, condition=src_cond) - tgt = _ot_data(rng, n=n, dim=d, condition=tgt_cond) + src = _ot_data(rng, n=n, lin_dim=d, condition=src_cond) + tgt = _ot_data(rng, n=n, lin_dim=d, condition=tgt_cond) ds = datasets.OTDataset(src, tgt) return DataLoader(ds, batch_size=16, shuffle=True) 
@@ -66,12 +77,12 @@ def lin_dl_with_conds() -> DataLoader: @pytest.fixture() def conditional_lin_dl() -> datasets.ConditionalLoader: - cond_dim = 4 + d, cond_dim = 2, 4 rng = np.random.default_rng(42) - src0 = _ot_data(rng, condition=0.0, cond_dim=cond_dim) + src0 = _ot_data(rng, condition=0.0, lin_dim=d, cond_dim=cond_dim) tgt0 = _ot_data(rng, offset=2.0) - src1 = _ot_data(rng, condition=1.0, cond_dim=cond_dim) + src1 = _ot_data(rng, condition=1.0, lin_dim=d, cond_dim=cond_dim) tgt1 = _ot_data(rng, offset=-2.0) src_ds = datasets.OTDataset(src0, tgt0) @@ -83,39 +94,22 @@ def conditional_lin_dl() -> datasets.ConditionalLoader: return datasets.ConditionalLoader([src_dl, tgt_dl]) -# TODO(michalk8): refactor the below for GENOT - - -@pytest.fixture(scope="module") -def genot_data_loader_linear(): - """Returns a data loader for a simple Gaussian mixture.""" - rng = np.random.default_rng(seed=0) - src = rng.normal(size=(100, 2)) - tgt = rng.normal(size=(100, 2)) + 1.0 - dataset = datasets.OTDataset(lin=src, tgt_lin=tgt) - return DataLoader(dataset, batch_size=16, shuffle=True) - - -@pytest.fixture(scope="module") -def genot_data_loader_linear_conditional(): - """Returns a data loader for a simple Gaussian mixture.""" - rng = np.random.default_rng(seed=0) - src_0 = rng.normal(size=(100, 2)) - tgt_0 = rng.normal(size=(100, 2)) + 1.0 - src_1 = rng.normal(size=(100, 2)) - tgt_1 = rng.normal(size=(100, 2)) + 1.0 - ds0 = datasets.OTDataset( - lin=src_0, tgt_lin=tgt_0, conditions=np.zeros_like(src_0) * 0.0 - ) - ds1 = datasets.OTDataset( - lin=src_1, tgt_lin=tgt_1, conditions=np.ones_like(src_1) * 1.0 +@pytest.fixture() +def quad_dl(): + n, d = 128, 2 + rng = np.random.default_rng(11) + src, tgt = _ot_data( + rng, n=n, quad_dim=d + ), _ot_data( + rng, n=n, quad_dim=d, offset=1.0 ) - sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) - sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) - dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0) - dl1 = 
DataLoader(ds1, batch_size=16, sampler=sampler1) + ds = datasets.OTDataset(src, tgt) + return DataLoader(ds, batch_size=16, shuffle=True) - return datasets.ConditionalLoader((dl0, dl1)) + +@pytest.fixture() +def quad_dl_with_conds(): + pass @pytest.fixture(scope="module") diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index abd8aad94..a8c02f07b 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -16,15 +16,61 @@ import pytest +import jax import jax.numpy as jnp import optax -from ott.geometry import costs -from ott.neural.flow_models.models import VelocityField -from ott.neural.flow_models.utils import uniform_sampler -from ott.solvers.linear import sinkhorn, sinkhorn_lr -from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr +from ott.neural.flow_models import flows, genot, models, utils + + +def data_match_fn( + src_lin: jnp.ndarray, tgt_lin: jnp.ndarray, src_quad: jnp.ndarray, + tgt_quad: jnp.ndarray +): + # TODO(michalk8): extend for GW/FGW + return utils.match_linear(src_lin, tgt_lin) + + +class TestGENOT: + + # TODO(michalk8): test gw/fgw, k, etc. 
+ @pytest.mark.parametrize(("cond_dim", "dl"), [(2, "lin_dl")]) + def test_genot2(self, rng: jax.Array, cond_dim: int, dl: str, request): + rng_init, rng_call = jax.random.split(rng) + input_dim, hidden_dim = 2, 7 + dl = request.getfixturevalue(dl) + + vf = models.VelocityField( + hidden_dim=hidden_dim, + output_dim=input_dim, + # TODO(michalk8): the source is the condition + condition_dim=cond_dim, + ) + + model = genot.GENOT( + vf, + flow=flows.ConstantNoiseFlow(0.0), + data_match_fn=data_match_fn, + rng=rng_init, + optimizer=optax.adam(learning_rate=1e-3), + input_dim=input_dim, + condition_dim=cond_dim, + ) + + _logs = model(dl, n_iters=3, rng=rng_call) + + # TODO(michalk8): generalize for gw/fgw + batch = next(iter(dl)) + src = jnp.asarray(batch["src_lin"]) + tgt = jnp.asarray(batch["tgt_lin"]) + src_cond = batch.get("src_condition") + if src_cond is not None: + src_cond = jnp.asarray(src_cond) + + res = model.transport(src, condition=src_cond) + + assert jnp.sum(jnp.isnan(res)) == 0 class TestGENOTLin: diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index a9b799d4a..a4db65fa5 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -27,12 +27,12 @@ class TestOTFlowMatching: (3, "lin_dl_with_conds"), (4, "conditional_lin_dl")]) def test_fm(self, rng: jax.Array, cond_dim: int, dl: str, request): - output_dim, hidden_dim = 2, 5 + input_dim, hidden_dim = 2, 5 dl = request.getfixturevalue(dl) neural_vf = models.VelocityField( - hidden_dim, - output_dim, + hidden_dim=hidden_dim, + output_dim=input_dim, condition_dim=hidden_dim if cond_dim > 0 else None, ) fm = otfm.OTFlowMatching( @@ -41,7 +41,7 @@ def test_fm(self, rng: jax.Array, cond_dim: int, dl: str, request): match_fn=jax.jit(utils.match_linear), rng=rng, optimizer=optax.adam(learning_rate=1e-3), - input_dim=2, # all dataloaders have dim `2` + input_dim=input_dim, condition_dim=cond_dim, ) From 8bc9b104b611bbaec95fda9f0efcee0975aa6a88 Mon Sep 17 00:00:00 2001 From: Michal 
Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 19:05:29 +0100 Subject: [PATCH 137/186] Remove old GENOT tests --- src/ott/neural/flow_models/utils.py | 1 + tests/neural/genot_test.py | 388 ---------------------------- 2 files changed, 1 insertion(+), 388 deletions(-) diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py index a656f440f..1181de1cb 100644 --- a/src/ott/neural/flow_models/utils.py +++ b/src/ott/neural/flow_models/utils.py @@ -169,6 +169,7 @@ def multivariate_normal( mean: float = 0.0, cov: float = 1.0 ) -> jnp.ndarray: + """TODO.""" mean = jnp.full(dim, fill_value=mean) cov = jnp.diag(jnp.full(dim, fill_value=cov)) return jax.random.multivariate_normal(rng, mean=mean, cov=cov, shape=shape) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index a8c02f07b..063b770db 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import functools -from typing import Iterator, Literal, Optional, Union import pytest @@ -63,7 +61,6 @@ def test_genot2(self, rng: jax.Array, cond_dim: int, dl: str, request): # TODO(michalk8): generalize for gw/fgw batch = next(iter(dl)) src = jnp.asarray(batch["src_lin"]) - tgt = jnp.asarray(batch["tgt_lin"]) src_cond = batch.get("src_condition") if src_cond is not None: src_cond = jnp.asarray(src_cond) @@ -71,388 +68,3 @@ def test_genot2(self, rng: jax.Array, cond_dim: int, dl: str, request): res = model.transport(src, condition=src_cond) assert jnp.sum(jnp.isnan(res)) == 0 - - -class TestGENOTLin: - - @pytest.mark.parametrize("scale_cost", ["mean", 2.0]) - @pytest.mark.parametrize("k_samples_per_x", [1, 3]) - @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) - @pytest.mark.parametrize("solver", ["sinkhorn", "lr_sinkhorn"]) - def test_genot_linear_unconditional( - self, - genot_data_loader_linear: Iterator, - scale_cost: Union[float, Literal["mean"]], - k_samples_per_x: int, - solver_latent_to_data: Optional[str], - solver: Literal["sinkhorn", "lr_sinkhorn"], - ): - matcher_latent_to_data = ( - None if solver_latent_to_data is None else - base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) - ) - batch = next(iter(genot_data_loader_linear)) - source_lin, source_conditions, target_lin = jnp.array( - batch["source_lin"] - ), jnp.array(batch["source_conditions"]) if len(batch["source_conditions"] - ) else None, jnp.array( - batch["target_lin"] - ) - - source_dim = source_lin.shape[1] - target_dim = target_lin.shape[1] - condition_dim = 0 - - neural_vf = VelocityField( - output_dim=target_dim, - condition_dim=source_dim + condition_dim, - hidden_dim=5, - ) - ot_solver = sinkhorn.Sinkhorn( - ) if solver == "sinkhorn" else sinkhorn_lr.LRSinkhorn(rank=3) - ot_matcher = base_solver.OTMatcherLinear( - ot_solver, cost_fn=costs.SqEuclidean(), scale_cost=scale_cost - ) - time_sampler = uniform_sampler - optimizer = optax.adam(learning_rate=1e-3) - genot = 
GENOTLin( - neural_vf, - input_dim=source_dim, - output_dim=target_dim, - cond_dim=condition_dim, - ot_matcher=ot_matcher, - optimizer=optimizer, - time_sampler=time_sampler, - k_samples_per_x=k_samples_per_x, - matcher_latent_to_data=matcher_latent_to_data, - ) - genot( - genot_data_loader_linear, - genot_data_loader_linear, - n_iters=2, - valid_freq=3 - ) - - batch = next(iter(genot_data_loader_linear)) - result_forward = genot.transport( - source_lin, condition=source_conditions, forward=True - ) - assert isinstance(result_forward, jnp.ndarray) - assert jnp.sum(jnp.isnan(result_forward)) == 0 - - @pytest.mark.parametrize("k_samples_per_x", [1, 2]) - @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) - @pytest.mark.parametrize("solver", ["sinkhorn", "lr_sinkhorn"]) - def test_genot_linear_conditional( - self, genot_data_loader_linear_conditional: Iterator, - k_samples_per_x: int, solver_latent_to_data: Optional[str], - solver: Literal["sinkhorn", "lr_sinkhorn"] - ): - matcher_latent_to_data = ( - None if solver_latent_to_data is None else - base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) - ) - - batch = next(iter(genot_data_loader_linear_conditional)) - source_lin, source_conditions, target_lin = jnp.array( - batch["source_lin"] - ), jnp.array(batch["source_conditions"]) if len(batch["source_conditions"] - ) else None, jnp.array( - batch["target_lin"] - ) - source_dim = source_lin.shape[1] - target_dim = target_lin.shape[1] - condition_dim = source_conditions.shape[1] - - neural_vf = VelocityField( - output_dim=target_dim, - condition_dim=source_dim + condition_dim, - hidden_dim=5, - ) - ot_solver = sinkhorn.Sinkhorn( - ) if solver == "sinkhorn" else sinkhorn_lr.LRSinkhorn(rank=3) - ot_matcher = base_solver.OTMatcherLinear( - ot_solver, cost_fn=costs.SqEuclidean() - ) - time_sampler = uniform_sampler - - optimizer = optax.adam(learning_rate=1e-3) - genot = GENOTLin( - neural_vf, - input_dim=source_dim, - output_dim=target_dim, - 
cond_dim=condition_dim, - ot_matcher=ot_matcher, - optimizer=optimizer, - time_sampler=time_sampler, - k_samples_per_x=k_samples_per_x, - matcher_latent_to_data=matcher_latent_to_data, - ) - genot( - genot_data_loader_linear_conditional, - genot_data_loader_linear_conditional, - n_iters=2, - valid_freq=3 - ) - result_forward = genot.transport( - source_lin, condition=source_conditions, forward=True - ) - assert isinstance(result_forward, jnp.ndarray) - assert jnp.sum(jnp.isnan(result_forward)) == 0 - - -class TestGENOTQuad: - - @pytest.mark.parametrize("k_samples_per_x", [1, 2]) - @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) - @pytest.mark.parametrize("solver", ["gromov", "gromov_lr"]) - def test_genot_quad_unconditional( - self, genot_data_loader_quad: Iterator, k_samples_per_x: int, - solver_latent_to_data: Optional[str], solver: Literal["gromov", - "gromov_lr"] - ): - matcher_latent_to_data = ( - None if solver_latent_to_data is None else - base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) - ) - - batch = next(iter(genot_data_loader_quad)) - (source_quad, source_conditions, target_quad) = ( - jnp.array(batch["source_quad"]), jnp.array(batch["source_conditions"]) - if len(batch["source_conditions"]) else None, - jnp.array(batch["target_quad"]) - ) - source_dim = source_quad.shape[1] - target_dim = target_quad.shape[1] - condition_dim = 0 - neural_vf = VelocityField( - output_dim=target_dim, - condition_dim=source_dim + condition_dim, - hidden_dim=5, - ) - ot_solver = gromov_wasserstein.GromovWasserstein( - epsilon=1e-2 - ) if solver == "gromov" else gromov_wasserstein_lr.LRGromovWasserstein( - rank=3, epsilon=1e-2 - ) - ot_matcher = base_solver.OTMatcherQuad( - ot_solver, cost_fn=costs.SqEuclidean() - ) - - time_sampler = functools.partial(uniform_sampler, offset=1e-2) - optimizer = optax.adam(learning_rate=1e-3) - genot = GENOTQuad( - neural_vf, - input_dim=source_dim, - output_dim=target_dim, - cond_dim=condition_dim, - 
ot_matcher=ot_matcher, - optimizer=optimizer, - time_sampler=time_sampler, - k_samples_per_x=k_samples_per_x, - matcher_latent_to_data=matcher_latent_to_data, - ) - genot( - genot_data_loader_quad, genot_data_loader_quad, n_iters=2, valid_freq=3 - ) - - result_forward = genot.transport( - source_quad, condition=source_conditions, forward=True - ) - assert isinstance(result_forward, jnp.ndarray) - assert jnp.sum(jnp.isnan(result_forward)) == 0 - - @pytest.mark.parametrize("k_samples_per_x", [1, 2]) - @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) - @pytest.mark.parametrize("solver", ["gromov", "gromov_lr"]) - def test_genot_fused_unconditional( - self, genot_data_loader_fused: Iterator, k_samples_per_x: int, - solver_latent_to_data: Optional[str], solver: Literal["gromov", - "gromov_lr"] - ): - matcher_latent_to_data = ( - None if solver_latent_to_data is None else - base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) - ) - - batch = next(iter(genot_data_loader_fused)) - (source_lin, source_quad, source_conditions, target_lin, target_quad) = ( - jnp.array(batch["source_lin"]) if len(batch["source_lin"]) else None, - jnp.array(batch["source_quad"]), jnp.array(batch["source_conditions"]) - if len(batch["source_conditions"]) else None, - jnp.array(batch["target_lin"]) if len(batch["target_lin"]) else None, - jnp.array(batch["target_quad"]) - ) - source_dim = source_lin.shape[1] + source_quad.shape[1] - target_dim = target_lin.shape[1] + target_quad.shape[1] - condition_dim = 0 - neural_vf = VelocityField( - output_dim=target_dim, - condition_dim=source_dim + condition_dim, - hidden_dim=5, - ) - ot_solver = gromov_wasserstein.GromovWasserstein( - epsilon=1e-2 - ) if solver == "gromov" else gromov_wasserstein_lr.LRGromovWasserstein( - rank=3, epsilon=1e-2 - ) - ot_matcher = base_solver.OTMatcherQuad( - ot_solver, cost_fn=costs.SqEuclidean(), fused_penalty=0.5 - ) - - optimizer = optax.adam(learning_rate=1e-3) - genot = GENOTQuad( - neural_vf, - 
input_dim=source_dim, - output_dim=target_dim, - cond_dim=condition_dim, - ot_matcher=ot_matcher, - optimizer=optimizer, - k_samples_per_x=k_samples_per_x, - matcher_latent_to_data=matcher_latent_to_data, - ) - genot( - genot_data_loader_fused, - genot_data_loader_fused, - n_iters=2, - valid_freq=3 - ) - - result_forward = genot.transport( - jnp.concatenate((source_lin, source_quad), axis=1), - condition=source_conditions, - forward=True - ) - assert isinstance(result_forward, jnp.ndarray) - assert jnp.sum(jnp.isnan(result_forward)) == 0 - - @pytest.mark.parametrize("k_samples_per_x", [1, 2]) - @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) - @pytest.mark.parametrize("solver", ["gromov", "gromov_lr"]) - def test_genot_quad_conditional( - self, genot_data_loader_quad_conditional: Iterator, k_samples_per_x: int, - solver_latent_to_data: Optional[str], solver: Literal["gromov", - "gromov_lr"] - ): - matcher_latent_to_data = ( - None if solver_latent_to_data is None else - base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) - ) - - batch = next(iter(genot_data_loader_quad_conditional)) - (source_quad, source_conditions, target_quad) = ( - jnp.array(batch["source_quad"]), jnp.array(batch["source_conditions"]) - if len(batch["source_conditions"]) else None, - jnp.array(batch["target_quad"]) - ) - - source_dim = source_quad.shape[1] - target_dim = target_quad.shape[1] - condition_dim = source_conditions.shape[1] - neural_vf = VelocityField( - output_dim=target_dim, - condition_dim=source_dim + condition_dim, - hidden_dim=5, - ) - ot_solver = gromov_wasserstein.GromovWasserstein( - epsilon=1e-2 - ) if solver == "gromov" else gromov_wasserstein_lr.LRGromovWasserstein( - rank=3, epsilon=1e-2 - ) - ot_matcher = base_solver.OTMatcherQuad( - ot_solver, cost_fn=costs.SqEuclidean() - ) - time_sampler = uniform_sampler - - optimizer = optax.adam(learning_rate=1e-3) - genot = GENOTQuad( - neural_vf, - input_dim=source_dim, - output_dim=target_dim, - 
cond_dim=condition_dim, - ot_matcher=ot_matcher, - optimizer=optimizer, - time_sampler=time_sampler, - k_samples_per_x=k_samples_per_x, - matcher_latent_to_data=matcher_latent_to_data, - ) - genot( - genot_data_loader_quad_conditional, - genot_data_loader_quad_conditional, - n_iters=2, - valid_freq=3 - ) - - result_forward = genot.transport( - source_quad, condition=source_conditions, forward=True - ) - assert isinstance(result_forward, jnp.ndarray) - assert jnp.sum(jnp.isnan(result_forward)) == 0 - - @pytest.mark.parametrize("k_samples_per_x", [1, 2]) - @pytest.mark.parametrize("solver_latent_to_data", [None, "sinkhorn"]) - @pytest.mark.parametrize("solver", ["gromov", "gromov_lr"]) - def test_genot_fused_conditional( - self, genot_data_loader_fused_conditional: Iterator, k_samples_per_x: int, - solver_latent_to_data: Optional[str], solver: Literal["gromov", - "gromov_lr"] - ): - solver_latent_to_data = ( - None if solver_latent_to_data is None else sinkhorn.Sinkhorn() - ) - matcher_latent_to_data = ( - None if solver_latent_to_data is None else - base_solver.OTMatcherLinear(sinkhorn.Sinkhorn()) - ) - batch = next(iter(genot_data_loader_fused_conditional)) - (source_lin, source_quad, source_conditions, target_lin, target_quad) = ( - jnp.array(batch["source_lin"]) if len(batch["source_lin"]) else None, - jnp.array(batch["source_quad"]), jnp.array(batch["source_conditions"]) - if len(batch["source_conditions"]) else None, - jnp.array(batch["target_lin"]) if len(batch["target_lin"]) else None, - jnp.array(batch["target_quad"]) - ) - source_dim = source_lin.shape[1] + source_quad.shape[1] - target_dim = target_lin.shape[1] + target_quad.shape[1] - condition_dim = source_conditions.shape[1] - neural_vf = VelocityField( - output_dim=target_dim, - condition_dim=source_dim + condition_dim, - hidden_dim=5, - ) - ot_solver = gromov_wasserstein.GromovWasserstein( - epsilon=1e-2 - ) if solver == "gromov" else gromov_wasserstein_lr.LRGromovWasserstein( - rank=3, epsilon=1e-2 - 
) - ot_matcher = base_solver.OTMatcherQuad( - ot_solver, cost_fn=costs.SqEuclidean(), fused_penalty=0.5 - ) - time_sampler = uniform_sampler - optimizer = optax.adam(learning_rate=1e-3) - - genot = GENOTQuad( - neural_vf, - input_dim=source_dim, - output_dim=target_dim, - cond_dim=condition_dim, - ot_matcher=ot_matcher, - optimizer=optimizer, - time_sampler=time_sampler, - k_samples_per_x=k_samples_per_x, - matcher_latent_to_data=matcher_latent_to_data, - ) - genot( - genot_data_loader_fused_conditional, - genot_data_loader_fused_conditional, - n_iters=2, - valid_freq=3 - ) - - result_forward = genot.transport( - jnp.concatenate((source_lin, source_quad), axis=1), - condition=source_conditions, - forward=True - ) - assert isinstance(result_forward, jnp.ndarray) - assert jnp.sum(jnp.isnan(result_forward)) == 0 From 6f4f8640daa6a77180168ed2ffc9d2febb50d7b3 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 15 Mar 2024 19:06:21 +0100 Subject: [PATCH 138/186] Remove old dataloaders --- tests/neural/conftest.py | 82 ---------------------------------------- 1 file changed, 82 deletions(-) diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index e7de132e8..a3d89d959 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -16,7 +16,6 @@ import pytest import numpy as np -import torch from torch.utils.data import DataLoader from ott.neural.data import datasets @@ -110,84 +109,3 @@ def quad_dl(): @pytest.fixture() def quad_dl_with_conds(): pass - - -@pytest.fixture(scope="module") -def genot_data_loader_quad(): - """Returns a data loader for a simple Gaussian mixture.""" - rng = np.random.default_rng(seed=0) - src = rng.normal(size=(100, 2)) - tgt = rng.normal(size=(100, 1)) + 1.0 - dataset = datasets.OTDataset(quad=src, tgt_quad=tgt) - return DataLoader(dataset, batch_size=16, shuffle=True) - - -@pytest.fixture(scope="module") -def genot_data_loader_quad_conditional(): - """Returns a data loader 
for a simple Gaussian mixture.""" - rng = np.random.default_rng(seed=0) - src_0 = rng.normal(size=(100, 2)) - tgt_0 = rng.normal(size=(100, 1)) + 1.0 - src_1 = rng.normal(size=(100, 2)) - tgt_1 = rng.normal(size=(100, 1)) + 1.0 - ds0 = datasets.OTDataset( - quad=src_0, tgt_quad=tgt_0, conditions=np.zeros_like(src_0) * 0.0 - ) - ds1 = datasets.OTDataset( - quad=src_1, tgt_quad=tgt_1, conditions=np.ones_like(src_1) * 1.0 - ) - sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) - sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) - dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0) - dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1) - - return datasets.ConditionalLoader((dl0, dl1)) - - -@pytest.fixture(scope="module") -def genot_data_loader_fused(): - """Returns a data loader for a simple Gaussian mixture.""" - rng = np.random.default_rng(seed=0) - src_q = rng.normal(size=(100, 2)) - tgt_q = rng.normal(size=(100, 1)) + 1.0 - src_lin = rng.normal(size=(100, 2)) - tgt_lin = rng.normal(size=(100, 2)) + 1.0 - dataset = datasets.OTDataset( - lin=src_lin, quad=src_q, tgt_lin=tgt_lin, tgt_quad=tgt_q - ) - return DataLoader(dataset, batch_size=16, shuffle=True) - - -@pytest.fixture(scope="module") -def genot_data_loader_fused_conditional(): - """Returns a data loader for a simple Gaussian mixture.""" - rng = np.random.default_rng(seed=0) - src_q_0 = rng.normal(size=(100, 2)) - tgt_q_0 = rng.normal(size=(100, 1)) + 1.0 - src_lin_0 = rng.normal(size=(100, 2)) - tgt_lin_0 = rng.normal(size=(100, 2)) + 1.0 - - src_q_1 = 2 * rng.normal(size=(100, 2)) - tgt_q_1 = 2 * rng.normal(size=(100, 1)) + 1.0 - src_lin_1 = 2 * rng.normal(size=(100, 2)) - tgt_lin_1 = 2 * rng.normal(size=(100, 2)) + 1.0 - - ds0 = datasets.OTDataset( - lin=src_lin_0, - tgt_lin=tgt_lin_0, - quad=src_q_0, - tgt_quad=tgt_q_0, - conditions=np.zeros_like(src_lin_0) * 0.0 - ) - ds1 = datasets.OTDataset( - lin=src_lin_1, - tgt_lin=tgt_lin_1, - quad=src_q_1, - tgt_quad=tgt_q_1, - 
conditions=np.ones_like(src_lin_1) * 1.0 - ) - sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) - sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) - dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0) - dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1) - return datasets.ConditionalLoader((dl0, dl1)) From 11911c4081c9d4fc7317bf92fd1232d4b70c9fa8 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Sun, 17 Mar 2024 23:48:33 +0100 Subject: [PATCH 139/186] Add more todos --- src/ott/neural/flow_models/genot.py | 11 +++++++---- src/ott/neural/flow_models/otfm.py | 6 ++++-- tests/neural/conftest.py | 9 ++++----- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index 34a207b50..fc550cd93 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple import jax import jax.numpy as jnp import jax.tree_util as jtu +import numpy as np import diffrax from flax.training import train_state @@ -40,7 +41,9 @@ class GENOT: of an input sample, see algorithm TODO. latent_match_fn: TODO. latent_noise_fn: TODO. + # TODO(michalk8): rename k_samples_per_x: Number of samples drawn from the conditional distribution + # TODO(michalk8): expose all args for the train state? kwargs: TODO. 
""" @@ -48,6 +51,7 @@ def __init__( self, velocity_field: models.VelocityField, flow: flows.BaseFlow, + # TODO(michalk8): all of these can be optional, explain in the docs data_match_fn: Callable[ [jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray], time_sampler: Callable[[jax.Array, int], @@ -66,9 +70,8 @@ def __init__( self.time_sampler = time_sampler self.latent_match_fn = latent_match_fn if latent_noise_fn is None: - dim = kwargs["input_dim"] latent_noise_fn = functools.partial( - flow_utils.multivariate_normal, dim=dim + flow_utils.multivariate_normal, dim=kwargs["input_dim"] ) self.latent_noise_fn = latent_noise_fn self.k_samples_per_x = k_samples_per_x @@ -117,7 +120,7 @@ def loss_fn( def __call__( self, - loader: Any, + loader: Iterable[Dict[str, np.ndarray]], n_iters: int, rng: Optional[jax.Array] = None ) -> Dict[str, List[float]]: diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index a8b85a44e..e1ad5aaab 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple import jax import jax.numpy as jnp import jax.tree_util as jtu +import numpy as np import diffrax from flax.training import train_state @@ -37,6 +38,7 @@ class OTFlowMatching: flow: Flow between source and target distribution. match_fn: TODO. time_sampler: Sampler for the time. + # TODO(michalk8): expose all args for the train state? kwargs: TODO. 
""" @@ -97,7 +99,7 @@ def loss_fn( # TODO(michalk8): refactor in the future PR to just do one step def __call__( # noqa: D102 self, - loader: Any, # TODO(michalk8): type it correctly + loader: Iterable[Dict[str, np.ndarray]], *, n_iters: int, rng: Optional[jax.Array] = None, diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index a3d89d959..d0bc11e7e 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -97,12 +97,11 @@ def conditional_lin_dl() -> datasets.ConditionalLoader: def quad_dl(): n, d = 128, 2 rng = np.random.default_rng(11) - src, tgt = _ot_data( - rng, n=n, quad_dim=d - ), _ot_data( - rng, n=n, quad_dim=d, offset=1.0 - ) + + src = _ot_data(rng, n=n, quad_dim=d) + tgt = _ot_data(rng, n=n, quad_dim=d, offset=1.0) ds = datasets.OTDataset(src, tgt) + return DataLoader(ds, batch_size=16, shuffle=True) From a8de2ea17ff920ca536265a333ce843fe5a00842 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 19 Mar 2024 10:08:54 +0100 Subject: [PATCH 140/186] add docs to dataloader --- src/ott/neural/data/datasets.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/ott/neural/data/datasets.py b/src/ott/neural/data/datasets.py index 63215e61e..0ec8edfe9 100644 --- a/src/ott/neural/data/datasets.py +++ b/src/ott/neural/data/datasets.py @@ -24,7 +24,13 @@ @dataclasses.dataclass(repr=False, frozen=True) class OTData: - """TODO.""" + """Distribution data for (conditional) optimal transport problems. + + Args: + lin: Linear (living in the shared space) part of the samples. + quad: Quadratic (living in the incomparable subspace) part of the samples. + condition: Condition corresponding to the data distribution. + """ lin: Optional[np.ndarray] = None quad: Optional[np.ndarray] = None condition: Optional[np.ndarray] = None @@ -41,7 +47,16 @@ def __len__(self) -> int: class OTDataset: - """TODO.""" + """Dataset for (conditional) optimal transport problems. 
+ + Args: + src_data: Samples from the source distribution. + tgt_data: Samples from the target distribution. + src_conditions: Conditions for the source data. + tgt_conditions: Conditions for the target data. + is_aligned: Whether the samples from `src_data` and `tgt_data` are paired. + seed: Random seed. + """ SRC_PREFIX = "src" TGT_PREFIX = "tgt" @@ -71,9 +86,9 @@ def __init__( self.is_aligned = is_aligned self._rng = np.random.default_rng(seed) - self._verify_integriy() + self._verify_integrity() - def _verify_integriy(self) -> None: + def _verify_integrity(self) -> None: assert len(self.src_data) == len(self.src_conditions) assert len(self.tgt_data) == len(self.tgt_conditions) From dfaf0428334e5d54d374c845e185cbf3d8f5f60c Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 19 Mar 2024 10:36:52 +0100 Subject: [PATCH 141/186] expose args in GENOT --- src/ott/neural/flow_models/genot.py | 68 ++++++++++++++++++++-------- src/ott/neural/flow_models/models.py | 5 +- 2 files changed, 53 insertions(+), 20 deletions(-) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index fc550cd93..cd10784a2 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -30,21 +30,37 @@ class GENOT: - """TODO :cite:`klein_uscidda:23`. + """GENOT for entropic neural optimal transport :cite:`klein_uscidda:23`. + + GENOT (Generative Entropic Neural Optimal Transport) is a framework for + learning neural optimal transport plans between two distributions. It + allows for learning linear and quadratic (Fused) Gromov-Wasserstein couplings, + in both the balanced and the unbalanced setting. + Args: velocity_field: Neural vector field parameterized by a neural network. flow: Flow between latent distribution and target distribution. - data_match_fn: Linear OT solver to match the latent distribution - with the conditional distribution. - time_sampler: Sampler for the time. - of an input sample, see algorithm TODO. 
- latent_match_fn: TODO. - latent_noise_fn: TODO. + data_match_fn: OT solver to match the source and the target distribution. + source_dim: Dimension of the source space. + target_dim: Dimension of the target space. + condition_dim: Dimension of the conditions. + time_sampler: Sampler for the time to learn the neural ODE. If :obj:`None`, + the time is uniformly sampled. # TODO(michalk8): rename k_samples_per_x: Number of samples drawn from the conditional distribution + per single source sample. + latent_match_fn: Linear OT matcher to optimally pair the latent + distribution with the `k_samples_per_x` samples of the conditional + distribution (corresponding to one sample). If :obj:`None`, samples + from the latent distribution are randomly paired with the samples from + the conditional distribution. + latent_noise_fn: Function to sample from the latent distribution in the + target space. If :obj:`None`, the latent distribution is sampled from a + multivariate normal distribution. # TODO(michalk8): expose all args for the train state? - kwargs: TODO. + kwargs: Keyword arguments for + :meth:`ott.neural.flow_models.models.VelocityField.create_train_state`. 
""" def __init__( @@ -54,14 +70,17 @@ def __init__( # TODO(michalk8): all of these can be optional, explain in the docs data_match_fn: Callable[ [jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray], + source_dim: int, + target_dim: int, + condition_dim: int, time_sampler: Callable[[jax.Array, int], jnp.ndarray] = flow_utils.uniform_sampler, + # TODO(michalk8): rename, too descriptive + k_samples_per_x: int = 1, latent_match_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None, latent_noise_fn: Optional[Callable[[jax.Array, Tuple[int, ...]], jnp.ndarray]] = None, - # TODO(michalk8): rename, too descriptive - k_samples_per_x: int = 1, **kwargs: Any, ): self.vf = velocity_field @@ -71,12 +90,16 @@ def __init__( self.latent_match_fn = latent_match_fn if latent_noise_fn is None: latent_noise_fn = functools.partial( - flow_utils.multivariate_normal, dim=kwargs["input_dim"] + flow_utils.multivariate_normal, dim=target_dim ) self.latent_noise_fn = latent_noise_fn self.k_samples_per_x = k_samples_per_x - self.vf_state = self.vf.create_train_state(**kwargs) + self.vf_state = self.vf.create_train_state( + input_dim=target_dim, + condition_dim=source_dim + condition_dim, + **kwargs + ) self.step_fn = self._get_step_fn() def _get_step_fn(self) -> Callable: @@ -113,8 +136,7 @@ def loss_fn( vf_state.params, time, source, target, latent, source_conditions, rng ) - # TODO(michalk8): follow the convention with loss being first - return vf_state.apply_gradients(grads=grads), loss + return loss, vf_state.apply_gradients(grads=grads) return step_fn @@ -124,7 +146,17 @@ def __call__( n_iters: int, rng: Optional[jax.Array] = None ) -> Dict[str, List[float]]: - """TODO.""" + """Train the GENOT model. + + Args: + loader: Data loader returning a dictionary with possible keys + `src_lin`, `tgt_lin`, `src_quad`, `tgt_quad`, `src_conditions`. + n_iters: Number of iterations to train the model. + rng: Random seed. + + Returns: + Training logs. 
+ """ def prepare_data( batch: Dict[str, jnp.ndarray] @@ -150,8 +182,8 @@ def prepare_data( rng = utils.default_prng_key(rng) training_logs = {"loss": []} for batch in loader: - rng = jax.random.split(rng, 6) - rng, rng_resample, rng_noise, rng_time, rng_latent, rng_step_fn = rng + rng = jax.random.split(rng, 5) + rng, rng_resample, rng_noise, rng_time, rng_step_fn = rng batch = jtu.tree_map(jnp.asarray, batch) (src, src_cond, tgt), matching_data = prepare_data(batch) @@ -181,7 +213,7 @@ def prepare_data( if src_cond is not None: src_cond = src_cond.reshape(-1, *src_cond.shape[2:]) - self.vf_state, loss = self.step_fn( + loss, self.vf_state = self.step_fn( rng_step_fn, self.vf_state, time, src, tgt, latent, src_cond ) diff --git a/src/ott/neural/flow_models/models.py b/src/ott/neural/flow_models/models.py index a3b3261b0..eb164e12c 100644 --- a/src/ott/neural/flow_models/models.py +++ b/src/ott/neural/flow_models/models.py @@ -109,8 +109,9 @@ def create_train_state( Args: rng: Random number generator. optimizer: Optimizer. - input_dim: Dimensionality of the input. - condition_dim: TODO. + input_dim: Dimensionality of the velocity field. + condition_dim: Dimensionsanilty of the condition + to the velocity field. Returns: The training state. 
From 2734c605c5c175fe3db5d5b6be2f4af4b7a436b9 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 19 Mar 2024 11:08:11 +0100 Subject: [PATCH 142/186] add docs and adapt data_match_fn --- src/ott/neural/flow_models/utils.py | 38 +++++++++++-- tests/neural/conftest.py | 2 +- tests/neural/genot_test.py | 83 +++++++++++++++++++++++------ 3 files changed, 102 insertions(+), 21 deletions(-) diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py index 1181de1cb..00c246018 100644 --- a/src/ott/neural/flow_models/utils.py +++ b/src/ott/neural/flow_models/utils.py @@ -40,7 +40,19 @@ def match_linear( scale_cost: ScaleCost_t = 1.0, **kwargs: Any ) -> jnp.ndarray: - """TODO.""" + """Compute solution to a linear OT problem. + + Args: + x: Linear term of the source point cloud. + y: Linear term of the target point cloud. + cost_fn: Cost function. + epsilon: Regularization parameter. + scale_cost: Scaling of the cost matrix. + kwargs: Additional arguments for :func:`ott.solvers.linear.solve`. + + Returns: + Optimal transport matrix. + """ geom = pointcloud.PointCloud( x, y, cost_fn=cost_fn, epsilon=epsilon, scale_cost=scale_cost ) @@ -51,19 +63,35 @@ def match_linear( def match_quadratic( xx: jnp.ndarray, yy: jnp.ndarray, - xy: Optional[jnp.ndarray] = None, + x: Optional[jnp.ndarray] = None, + y: Optional[jnp.ndarray] = None, # TODO(michalk8): expose for all the costs scale_cost: ScaleCost_t = 1.0, cost_fn: Optional[costs.CostFn] = None, **kwargs: Any ) -> jnp.ndarray: - """TODO.""" + """Compute solution to a quadratic OT problem. + + Args: + xx: Quadratic (incomparable) term of the source point cloud. + yy: Quadratic (incomparable) term of the target point cloud. + x: Linear (fused) term of the source point cloud. + y: Linear (fused) term of the target point cloud. + scale_cost: Scaling of the cost matrix. + cost_fn: Cost function. + kwargs: Additional arguments for :func:`ott.solvers.quadratic.solve`. + + Returns: + Optimal transport matrix. 
+ """ geom_xx = pointcloud.PointCloud(xx, cost_fn=cost_fn, scale_cost=scale_cost) geom_yy = pointcloud.PointCloud(yy, cost_fn=cost_fn, scale_cost=scale_cost) - if xy is None: + if x is None: geom_xy = None else: - geom_xy = pointcloud.PointCloud(xy, cost_fn=cost_fn, scale_cost=scale_cost) + geom_xy = pointcloud.PointCloud( + x, y, cost_fn=cost_fn, scale_cost=scale_cost + ) out = quadratic.solve(geom_xx, geom_yy, geom_xy, **kwargs) return out.matrix diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index d0bc11e7e..ede2d4dc4 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -99,7 +99,7 @@ def quad_dl(): rng = np.random.default_rng(11) src = _ot_data(rng, n=n, quad_dim=d) - tgt = _ot_data(rng, n=n, quad_dim=d, offset=1.0) + tgt = _ot_data(rng, n=n, quad_dim=d + 2, offset=1.0) ds = datasets.OTDataset(src, tgt) return DataLoader(ds, batch_size=16, shuffle=True) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 063b770db..b88864f6a 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools +from typing import Literal import pytest @@ -24,47 +26,98 @@ def data_match_fn( src_lin: jnp.ndarray, tgt_lin: jnp.ndarray, src_quad: jnp.ndarray, - tgt_quad: jnp.ndarray + tgt_quad: jnp.ndarray, *, type: Literal["linear", "quadratic", "fused"] ): - # TODO(michalk8): extend for GW/FGW - return utils.match_linear(src_lin, tgt_lin) + if type == "linear": + return utils.match_linear(x=src_lin, y=tgt_lin) + if type == "quadratic": + return utils.match_quadratic(xx=src_quad, yy=tgt_quad) + if type == "fused": + return utils.match_quadratic(xx=src_quad, yy=tgt_quad, x=src_lin, y=tgt_lin) + raise NotImplementedError(f"Unknown type: {type}") class TestGENOT: # TODO(michalk8): test gw/fgw, k, etc. 
- @pytest.mark.parametrize(("cond_dim", "dl"), [(2, "lin_dl")]) - def test_genot2(self, rng: jax.Array, cond_dim: int, dl: str, request): + @pytest.mark.parametrize("dl", ["lin_dl", "conditional_lin_dl"]) + def test_genot_linear(self, rng: jax.Array, dl: str, request): rng_init, rng_call = jax.random.split(rng) - input_dim, hidden_dim = 2, 7 + hidden_dim = 7 dl = request.getfixturevalue(dl) + batch = next(iter(dl)) + src = jnp.asarray(batch["src_lin"]) + tgt = jnp.asarray(batch["tgt_lin"]) + src_cond = batch.get("src_condition") + if src_cond is not None: + src_cond = jnp.asarray(src_cond) + src_dim = src.shape[-1] + tgt_dim = tgt.shape[-1] + cond_dim = src_cond.shape[-1] if src_cond is not None else 0 + vf = models.VelocityField( hidden_dim=hidden_dim, - output_dim=input_dim, - # TODO(michalk8): the source is the condition - condition_dim=cond_dim, + output_dim=tgt_dim, + condition_dim=src_dim + cond_dim, ) + data_mfn = functools.partial(data_match_fn, type="linear") + model = genot.GENOT( vf, flow=flows.ConstantNoiseFlow(0.0), - data_match_fn=data_match_fn, - rng=rng_init, - optimizer=optax.adam(learning_rate=1e-3), - input_dim=input_dim, + data_match_fn=data_mfn, + source_dim=src_dim, + target_dim=tgt_dim, condition_dim=cond_dim, + rng=rng_init, + optimizer=optax.adam(learning_rate=1e-4), ) _logs = model(dl, n_iters=3, rng=rng_call) + res = model.transport(src, condition=src_cond) + + assert jnp.sum(jnp.isnan(res)) == 0 + assert res.shape[-1] == tgt_dim + + @pytest.mark.parametrize("dl", ["quad_dl", "conditional_quad_dl"]) + def test_genot_quad(self, rng: jax.Array, dl: str, request): + rng_init, rng_call = jax.random.split(rng) + hidden_dim = 7 + dl = request.getfixturevalue(dl) - # TODO(michalk8): generalize for gw/fgw batch = next(iter(dl)) - src = jnp.asarray(batch["src_lin"]) + src = jnp.asarray(batch["src_quad"]) + tgt = jnp.asarray(batch["tgt_quad"]) src_cond = batch.get("src_condition") if src_cond is not None: src_cond = jnp.asarray(src_cond) + src_dim 
= src.shape[-1] + tgt_dim = tgt.shape[-1] + cond_dim = src_cond.shape[-1] if src_cond is not None else 0 + + vf = models.VelocityField( + hidden_dim=hidden_dim, + output_dim=tgt_dim, + condition_dim=src_dim + cond_dim, + ) + data_mfn = functools.partial(data_match_fn, type="quadratic") + + model = genot.GENOT( + vf, + flow=flows.ConstantNoiseFlow(0.0), + data_match_fn=data_mfn, + source_dim=src_dim, + target_dim=tgt_dim, + condition_dim=cond_dim, + rng=rng_init, + optimizer=optax.adam(learning_rate=1e-4), + ) + + _logs = model(dl, n_iters=3, rng=rng_call) res = model.transport(src, condition=src_cond) assert jnp.sum(jnp.isnan(res)) == 0 + assert res.shape[-1] == tgt_dim From 08e24d81c99852c69f3ff81ea1f9d7e93b5eccda Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 19 Mar 2024 12:08:32 +0100 Subject: [PATCH 143/186] fix linting --- tests/geometry/geodesic_test.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/geometry/geodesic_test.py b/tests/geometry/geodesic_test.py index aa1ac37c5..4cf7aa44b 100644 --- a/tests/geometry/geodesic_test.py +++ b/tests/geometry/geodesic_test.py @@ -13,16 +13,15 @@ # limitations under the License. 
from typing import Optional, Union - -import jax -import jax.experimental.sparse as jesp -import jax.numpy as jnp - import networkx as nx from networkx.algorithms import shortest_paths from networkx.generators import balanced_tree, random_graphs import pytest + +import jax +import jax.experimental.sparse as jesp +import jax.numpy as jnp import numpy as np from ott.geometry import geodesic, geometry, graph From 7d7da3a986fd9f79ed99f03f5c64e5df797569c4 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 19 Mar 2024 12:40:40 +0100 Subject: [PATCH 144/186] fix data loading and add genot fused tests --- tests/neural/conftest.py | 56 ++++++++++++++++++++++++++++++++++---- tests/neural/genot_test.py | 44 ++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 5 deletions(-) diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index ede2d4dc4..581796c0c 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -41,7 +41,8 @@ def _ot_data( ) if isinstance(condition, float): - cond_dim = lin_dim if cond_dim is None else cond_dim + _dim = lin_dim if lin_dim is not None else quad_dim + cond_dim = _dim if cond_dim is None else cond_dim condition = np.full((n, cond_dim), fill_value=condition) return datasets.OTData(lin=lin_data, quad=quad_data, condition=condition) @@ -80,9 +81,9 @@ def conditional_lin_dl() -> datasets.ConditionalLoader: rng = np.random.default_rng(42) src0 = _ot_data(rng, condition=0.0, lin_dim=d, cond_dim=cond_dim) - tgt0 = _ot_data(rng, offset=2.0) + tgt0 = _ot_data(rng, lin_dim=d, offset=2.0) src1 = _ot_data(rng, condition=1.0, lin_dim=d, cond_dim=cond_dim) - tgt1 = _ot_data(rng, offset=-2.0) + tgt1 = _ot_data(rng, lin_dim=d, offset=-2.0) src_ds = datasets.OTDataset(src0, tgt0) tgt_ds = datasets.OTDataset(src1, tgt1) @@ -106,5 +107,50 @@ def quad_dl(): @pytest.fixture() -def quad_dl_with_conds(): - pass +def conditional_quad_dl() -> datasets.ConditionalLoader: + n, d, cond_dim = 128, 2, 5 + rng = 
np.random.default_rng(11) + + src0 = _ot_data(rng, n=n, condition=0.0, cond_dim=cond_dim, quad_dim=d) + tgt0 = _ot_data(rng, n=n, quad_dim=d, cond_dim=cond_dim, offset=2.0) + src1 = _ot_data(rng, n=n, condition=1.0, quad_dim=d + 2) + tgt1 = _ot_data(rng, n=n, quad_dim=d + 2, offset=-2.0) + + src_ds = datasets.OTDataset(src0, tgt0) + tgt_ds = datasets.OTDataset(src1, tgt1) + + src_dl = DataLoader(src_ds, batch_size=16, shuffle=True) + tgt_dl = DataLoader(tgt_ds, batch_size=16, shuffle=True) + + return datasets.ConditionalLoader([src_dl, tgt_dl]) + + +@pytest.fixture() +def fused_dl(): + n, lin_dim, d = 128, 2 + rng = np.random.default_rng(11) + + src = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=d) + tgt = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=d + 2, offset=1.0) + ds = datasets.OTDataset(src, tgt) + + return DataLoader(ds, batch_size=16, shuffle=True) + + +@pytest.fixture() +def conditional_fused_dl() -> datasets.ConditionalLoader: + n, lin_dim, d = 128, 3, 2 + rng = np.random.default_rng(11) + + src0 = _ot_data(rng, n=n, condition=0.0, lin_dim=lin_dim, quad_dim=d) + tgt0 = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=d + 2, offset=2.0) + src1 = _ot_data(rng, n=n, condition=1.0, lin_dim=lin_dim, quad_dim=d) + tgt1 = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=d + 2, offset=-2.0) + + src_ds = datasets.OTDataset(src0, tgt0) + tgt_ds = datasets.OTDataset(src1, tgt1) + + src_dl = DataLoader(src_ds, batch_size=16, shuffle=True) + tgt_dl = DataLoader(tgt_ds, batch_size=16, shuffle=True) + + return datasets.ConditionalLoader([src_dl, tgt_dl]) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index b88864f6a..d68ae9511 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -121,3 +121,47 @@ def test_genot_quad(self, rng: jax.Array, dl: str, request): assert jnp.sum(jnp.isnan(res)) == 0 assert res.shape[-1] == tgt_dim + + @pytest.mark.parametrize("dl", ["fused_dl", "conditional_fused_dl"]) + def test_genot_fused(self, rng: 
jax.Array, dl: str, request): + rng_init, rng_call = jax.random.split(rng) + hidden_dim = 7 + dl = request.getfixturevalue(dl) + + batch = next(iter(dl)) + src_lin = jnp.asarray(batch["src_lin"]) + tgt_lin = jnp.asarray(batch["tgt_lin"]) + src_quad = jnp.asarray(batch["src_quad"]) + tgt_quad = jnp.asarray(batch["tgt_quad"]) + src_cond = batch.get("src_condition") + if src_cond is not None: + src_cond = jnp.asarray(src_cond) + src_dim = src_lin.shape[-1] + src_quad.shape[-1] + tgt_dim = tgt_lin.shape[-1] + tgt_quad.shape[-1] + cond_dim = src_cond.shape[-1] if src_cond is not None else 0 + + vf = models.VelocityField( + hidden_dim=hidden_dim, + output_dim=tgt_dim, + condition_dim=src_dim + cond_dim, + ) + + data_mfn = functools.partial(data_match_fn, type="fused") + + model = genot.GENOT( + vf, + flow=flows.ConstantNoiseFlow(0.0), + data_match_fn=data_mfn, + source_dim=src_dim, + target_dim=tgt_dim, + condition_dim=cond_dim, + rng=rng_init, + optimizer=optax.adam(learning_rate=1e-4), + ) + + _logs = model(dl, n_iters=3, rng=rng_call) + src = jnp.concatenate([src_lin, src_quad], axis=-1) + res = model.transport(src, condition=src_cond) + + assert jnp.sum(jnp.isnan(res)) == 0 + assert res.shape[-1] == tgt_dim From 4c8477a124b96f23dcaf95fcb24cff1b0d303cd3 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 19 Mar 2024 13:44:59 +0100 Subject: [PATCH 145/186] genot tests passing --- tests/neural/conftest.py | 56 ++++++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 581796c0c..b19962608 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -96,11 +96,12 @@ def conditional_lin_dl() -> datasets.ConditionalLoader: @pytest.fixture() def quad_dl(): - n, d = 128, 2 + n = 128 + quad_dim_src, quad_dim_tgt = 2, 4 rng = np.random.default_rng(11) - src = _ot_data(rng, n=n, quad_dim=d) - tgt = _ot_data(rng, n=n, quad_dim=d + 2, offset=1.0) + src = 
_ot_data(rng, n=n, quad_dim=quad_dim_src) + tgt = _ot_data(rng, n=n, quad_dim=quad_dim_tgt, offset=1.0) ds = datasets.OTDataset(src, tgt) return DataLoader(ds, batch_size=16, shuffle=True) @@ -108,13 +109,20 @@ def quad_dl(): @pytest.fixture() def conditional_quad_dl() -> datasets.ConditionalLoader: - n, d, cond_dim = 128, 2, 5 + n, cond_dim = 128, 5 + quad_dim_src, quad_dim_tgt = 2, 4 rng = np.random.default_rng(11) - src0 = _ot_data(rng, n=n, condition=0.0, cond_dim=cond_dim, quad_dim=d) - tgt0 = _ot_data(rng, n=n, quad_dim=d, cond_dim=cond_dim, offset=2.0) - src1 = _ot_data(rng, n=n, condition=1.0, quad_dim=d + 2) - tgt1 = _ot_data(rng, n=n, quad_dim=d + 2, offset=-2.0) + src0 = _ot_data( + rng, n=n, condition=0.0, cond_dim=cond_dim, quad_dim=quad_dim_src + ) + tgt0 = _ot_data( + rng, n=n, quad_dim=quad_dim_tgt, cond_dim=cond_dim, offset=2.0 + ) + src1 = _ot_data( + rng, n=n, condition=1.0, cond_dim=cond_dim, quad_dim=quad_dim_src + ) + tgt1 = _ot_data(rng, n=n, quad_dim=quad_dim_tgt, offset=-2.0) src_ds = datasets.OTDataset(src0, tgt0) tgt_ds = datasets.OTDataset(src1, tgt1) @@ -127,11 +135,12 @@ def conditional_quad_dl() -> datasets.ConditionalLoader: @pytest.fixture() def fused_dl(): - n, lin_dim, d = 128, 2 + n, lin_dim = 128, 6 + quad_dim_src, quad_dim_tgt = 2, 4 rng = np.random.default_rng(11) - src = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=d) - tgt = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=d + 2, offset=1.0) + src = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=quad_dim_src) + tgt = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=quad_dim_tgt, offset=1.0) ds = datasets.OTDataset(src, tgt) return DataLoader(ds, batch_size=16, shuffle=True) @@ -139,13 +148,28 @@ def fused_dl(): @pytest.fixture() def conditional_fused_dl() -> datasets.ConditionalLoader: - n, lin_dim, d = 128, 3, 2 + n, lin_dim, cond_dim = 128, 3, 7 + quad_dim_src, quad_dim_tgt = 2, 4 rng = np.random.default_rng(11) - src0 = _ot_data(rng, n=n, condition=0.0, lin_dim=lin_dim, quad_dim=d) 
- tgt0 = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=d + 2, offset=2.0) - src1 = _ot_data(rng, n=n, condition=1.0, lin_dim=lin_dim, quad_dim=d) - tgt1 = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=d + 2, offset=-2.0) + src0 = _ot_data( + rng, + n=n, + condition=0.0, + cond_dim=cond_dim, + lin_dim=lin_dim, + quad_dim=quad_dim_src + ) + tgt0 = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=quad_dim_tgt, offset=2.0) + src1 = _ot_data( + rng, + n=n, + condition=1.0, + cond_dim=cond_dim, + lin_dim=lin_dim, + quad_dim=quad_dim_src + ) + tgt1 = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=quad_dim_tgt, offset=-2.0) src_ds = datasets.OTDataset(src0, tgt0) tgt_ds = datasets.OTDataset(src1, tgt1) From 001d21dcd675e9f6ea2532ef8dd2d1d04df1d719 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 19 Mar 2024 13:54:36 +0100 Subject: [PATCH 146/186] adapt docs --- docs/neural/flow_models.rst | 13 ++++++++----- src/ott/neural/flow_models/genot.py | 1 - src/ott/neural/flow_models/models.py | 5 +++-- src/ott/neural/flow_models/otfm.py | 3 ++- src/ott/neural/flow_models/utils.py | 14 +++++++++++++- 5 files changed, 26 insertions(+), 10 deletions(-) diff --git a/docs/neural/flow_models.rst b/docs/neural/flow_models.rst index 5f9799292..32fe6e7f3 100644 --- a/docs/neural/flow_models.rst +++ b/docs/neural/flow_models.rst @@ -28,9 +28,7 @@ GENOT .. autosummary:: :toctree: _autosummary - genot.GENOTBase - genot.GENOTLin - genot.GENOTQuad + genot.GENOT Models ------ @@ -44,5 +42,10 @@ Utils .. 
autosummary:: :toctree: _autosummary - layers.CyclicalTimeEncoder - samplers.uniform_sampler + utils.match_linear + utils.match_quadratic + utils.sample_joint + utils.sample_conditional + utils.cyclical_time_encoder + utils.uniform_sampler + utils.multivariate_normal diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index cd10784a2..51e218b8d 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -67,7 +67,6 @@ def __init__( self, velocity_field: models.VelocityField, flow: flows.BaseFlow, - # TODO(michalk8): all of these can be optional, explain in the docs data_match_fn: Callable[ [jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray], source_dim: int, diff --git a/src/ott/neural/flow_models/models.py b/src/ott/neural/flow_models/models.py index eb164e12c..5cee1b05f 100644 --- a/src/ott/neural/flow_models/models.py +++ b/src/ott/neural/flow_models/models.py @@ -46,10 +46,11 @@ class VelocityField(nn.Module): output_dim: Dimensionality of the neural vector field. num_layers: Number of layers. condition_dim: Dimensionality of the embedding of the condition. - If :obj:`None`, TODO. + If :obj:`None`, the velocity field has no conditions. time_dim: Dimensionality of the time embedding. If :obj:`None`, set to ``hidden_dim``. - time_encoder: TODO. + time_encoder: Function to encode the time input to the time-dependent + velocity field. act_fn: Activation function. """ hidden_dim: int diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index e1ad5aaab..56fff2e2e 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -36,7 +36,8 @@ class OTFlowMatching: Args: velocity_field: Neural vector field parameterized by a neural network. flow: Flow between source and target distribution. - match_fn: TODO. + match_fn: Function to match data points from the source distribution and + the target distribution. 
time_sampler: Sampler for the time. # TODO(michalk8): expose all args for the train state? kwargs: TODO. diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py index 00c246018..3149106bb 100644 --- a/src/ott/neural/flow_models/utils.py +++ b/src/ott/neural/flow_models/utils.py @@ -117,7 +117,19 @@ def sample_conditional( k: int = 1, uniform_marginals: bool = False, ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """TODO.""" + """Sample indices from a transport matrix. + + Args: + rng: Random number generator. + tmat: Transport matrix. + k: Expected number of samples to sample per row. + uniform_marginals: If :obj:`True`, sample exactly `k` samples + per row, otherwise sample proportionally to the sums of the + rows of the transport matrix. + + Returns: + Source and target indices sampled from the transport matrix. + """ assert k > 0, "Number of samples per row must be positive." n, m = tmat.shape From 52d8466d7651b62cb9801ece391a09443321af1a Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 19 Mar 2024 13:58:56 +0100 Subject: [PATCH 147/186] adapt docs --- docs/neural/index.rst | 1 - docs/neural/models.rst | 25 ------------------------- src/ott/neural/flow_models/genot.py | 2 +- src/ott/neural/flow_models/otfm.py | 11 +++++++++++ 4 files changed, 12 insertions(+), 27 deletions(-) delete mode 100644 docs/neural/models.rst diff --git a/docs/neural/index.rst b/docs/neural/index.rst index 06d9fd97b..9de1781f6 100644 --- a/docs/neural/index.rst +++ b/docs/neural/index.rst @@ -17,4 +17,3 @@ and solvers to estimate such neural networks. duality flow_models gaps - models diff --git a/docs/neural/models.rst b/docs/neural/models.rst deleted file mode 100644 index af6d4e33a..000000000 --- a/docs/neural/models.rst +++ /dev/null @@ -1,25 +0,0 @@ -ott.neural.models -================= -.. module:: ott.neural.models -.. 
currentmodule:: ott.neural.models - -This module implements models, network architectures and helper -functions which apply to various neural optimal transport solvers. - -Utils ------ -.. autosummary:: - :toctree: _autosummary - - base_solver.BaseOTMatcher - base_solver.OTMatcherLinear - base_solver.OTMatcherQuad - - -Neural networks ---------------- -.. autosummary:: - :toctree: _autosummary - - layers.MLPBlock - nets.RescalingMLP diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index 51e218b8d..29f59f858 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -151,7 +151,7 @@ def __call__( loader: Data loader returning a dictionary with possible keys `src_lin`, `tgt_lin`, `src_quad`, `tgt_quad`, `src_conditions`. n_iters: Number of iterations to train the model. - rng: Random seed. + rng: Random number generator. Returns: Training logs. diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index 56fff2e2e..d048581b4 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -105,6 +105,17 @@ def __call__( # noqa: D102 n_iters: int, rng: Optional[jax.Array] = None, ) -> Dict[str, List[float]]: + """Train the OTFlowMatching model. + + Args: + loader: Data loader returning a dictionary with possible keys + `src_lin`, `tgt_lin`, `src_condition`. + n_iters: Number of iterations to train the model. + rng: Random number generator. + + Returns: + Training logs. 
+ """ rng = utils.default_prng_key(rng) training_logs = {"loss": []} for batch in loader: From 9f230c8530d5453ff8973c02d3f34850807d61a2 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 19 Mar 2024 14:03:36 +0100 Subject: [PATCH 148/186] add error message --- src/ott/neural/flow_models/genot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index 29f59f858..5c64a5481 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -173,7 +173,7 @@ def prepare_data( src = jnp.concatenate([src_lin, src_quad], axis=1) tgt = jnp.concatenate([tgt_lin, tgt_quad], axis=1) else: - raise RuntimeError("TODO") + raise RuntimeError("Cannot infer OT problem type from data.") # TODO(michalk8): filter `None` from the `arrs`? return (src, batch.get("src_condition"), tgt), arrs From 6c816788cd28edc33e17d78a106fb9736d91361b Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 19 Mar 2024 15:29:16 +0100 Subject: [PATCH 149/186] clean docs --- docs/spelling/technical.txt | 1 + src/ott/neural/flow_models/genot.py | 4 ++-- src/ott/neural/flow_models/models.py | 2 +- src/ott/neural/flow_models/otfm.py | 4 ++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/spelling/technical.txt b/docs/spelling/technical.txt index 7c7ba4ae9..f5a2ffb57 100644 --- a/docs/spelling/technical.txt +++ b/docs/spelling/technical.txt @@ -25,6 +25,7 @@ McCann Monge Moreau SGD +Schrödinger Schur Seidel Sinkhorn diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index 5c64a5481..f115fe7a4 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -256,7 +256,7 @@ def transport( rng: Optional[jax.Array] = None, **kwargs: Any, ) -> jnp.ndarray: - """Transport data with the learnt plan. + """Transport data with the learned plan. 
This method pushes-forward the `source` to its conditional distribution by solving the neural ODE parameterized by the @@ -271,7 +271,7 @@ def transport( kwargs: Keyword arguments for the ODE solver. Returns: - The push-forward or pull-back distribution defined by the learnt + The push-forward or pull-back distribution defined by the learned transport plan. """ diff --git a/src/ott/neural/flow_models/models.py b/src/ott/neural/flow_models/models.py index 5cee1b05f..abf3ec8f7 100644 --- a/src/ott/neural/flow_models/models.py +++ b/src/ott/neural/flow_models/models.py @@ -111,7 +111,7 @@ def create_train_state( rng: Random number generator. optimizer: Optimizer. input_dim: Dimensionality of the velocity field. - condition_dim: Dimensionsanilty of the condition + condition_dim: Dimensionality of the condition to the velocity field. Returns: diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index d048581b4..ad1a70522 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -154,7 +154,7 @@ def transport( t1: float = 1.0, **kwargs: Any, ) -> jnp.ndarray: - """Transport data with the learnt map. + """Transport data with the learned map. This method pushes-forward the data by solving the neural ODE parameterized by the velocity field. @@ -167,7 +167,7 @@ def transport( kwargs: Keyword arguments for the ODE solver. Returns: - The push-forward or pull-back distribution defined by the learnt + The push-forward or pull-back distribution defined by the learned transport plan. 
""" From e77cc34f523d74ff3390a27e788a91819e7ceda2 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 19 Mar 2024 15:46:59 +0100 Subject: [PATCH 150/186] comprise genot tests --- docs/spelling/misc.txt | 1 + docs/spelling/technical.txt | 2 + tests/neural/genot_test.py | 127 ++++++++++-------------------------- 3 files changed, 39 insertions(+), 91 deletions(-) diff --git a/docs/spelling/misc.txt b/docs/spelling/misc.txt index 26bc961ce..4be10fe05 100644 --- a/docs/spelling/misc.txt +++ b/docs/spelling/misc.txt @@ -1,4 +1,5 @@ Eulerian +Utils alg arg args diff --git a/docs/spelling/technical.txt b/docs/spelling/technical.txt index f5a2ffb57..f7997b48c 100644 --- a/docs/spelling/technical.txt +++ b/docs/spelling/technical.txt @@ -47,6 +47,7 @@ chromatin collinear covariance covariances +dataclass dataloaders dataset datasets @@ -111,6 +112,7 @@ preprocess preprocessing proteome prox +pytree quantile quantiles quantizes diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index d68ae9511..a5f061335 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -39,62 +39,40 @@ def data_match_fn( class TestGENOT: - # TODO(michalk8): test gw/fgw, k, etc. 
- @pytest.mark.parametrize("dl", ["lin_dl", "conditional_lin_dl"]) - def test_genot_linear(self, rng: jax.Array, dl: str, request): + @pytest.mark.parametrize( + "dl", [ + "lin_dl", "conditional_lin_dl", "quad_dl", "conditional_quad_dl", + "fused_dl", "conditional_fused_dl" + ] + ) + def test_genot(self, rng: jax.Array, dl: str, request): rng_init, rng_call = jax.random.split(rng) hidden_dim = 7 dl = request.getfixturevalue(dl) batch = next(iter(dl)) - src = jnp.asarray(batch["src_lin"]) - tgt = jnp.asarray(batch["tgt_lin"]) + src_lin = batch.get("src_lin") + if src_lin is not None: + src_lin = jnp.asarray(src_lin) + src_quad = batch.get("src_quad") + if src_quad is not None: + src_quad = jnp.asarray(src_quad) + tgt_lin = batch.get("tgt_lin") + if tgt_lin is not None: + tgt_lin = jnp.asarray(batch["tgt_lin"]) + tgt_quad = batch.get("tgt_quad") + if tgt_quad is not None: + tgt_quad = jnp.asarray(batch["tgt_quad"]) src_cond = batch.get("src_condition") if src_cond is not None: src_cond = jnp.asarray(src_cond) - src_dim = src.shape[-1] - tgt_dim = tgt.shape[-1] - cond_dim = src_cond.shape[-1] if src_cond is not None else 0 - vf = models.VelocityField( - hidden_dim=hidden_dim, - output_dim=tgt_dim, - condition_dim=src_dim + cond_dim, - ) - - data_mfn = functools.partial(data_match_fn, type="linear") - - model = genot.GENOT( - vf, - flow=flows.ConstantNoiseFlow(0.0), - data_match_fn=data_mfn, - source_dim=src_dim, - target_dim=tgt_dim, - condition_dim=cond_dim, - rng=rng_init, - optimizer=optax.adam(learning_rate=1e-4), - ) - - _logs = model(dl, n_iters=3, rng=rng_call) - res = model.transport(src, condition=src_cond) - - assert jnp.sum(jnp.isnan(res)) == 0 - assert res.shape[-1] == tgt_dim - - @pytest.mark.parametrize("dl", ["quad_dl", "conditional_quad_dl"]) - def test_genot_quad(self, rng: jax.Array, dl: str, request): - rng_init, rng_call = jax.random.split(rng) - hidden_dim = 7 - dl = request.getfixturevalue(dl) - - batch = next(iter(dl)) - src = 
jnp.asarray(batch["src_quad"]) - tgt = jnp.asarray(batch["tgt_quad"]) - src_cond = batch.get("src_condition") - if src_cond is not None: - src_cond = jnp.asarray(src_cond) - src_dim = src.shape[-1] - tgt_dim = tgt.shape[-1] + src_lin_dim = src_lin.shape[-1] if src_lin is not None else 0 + src_quad_dim = src_quad.shape[-1] if src_quad is not None else 0 + tgt_lin_shape = tgt_lin.shape[-1] if tgt_lin is not None else 0 + tgt_quad_shape = tgt_quad.shape[-1] if tgt_quad is not None else 0 + src_dim = src_lin_dim + src_quad_dim + tgt_dim = tgt_lin_shape + tgt_quad_shape cond_dim = src_cond.shape[-1] if src_cond is not None else 0 vf = models.VelocityField( @@ -103,50 +81,16 @@ def test_genot_quad(self, rng: jax.Array, dl: str, request): condition_dim=src_dim + cond_dim, ) - data_mfn = functools.partial(data_match_fn, type="quadratic") - - model = genot.GENOT( - vf, - flow=flows.ConstantNoiseFlow(0.0), - data_match_fn=data_mfn, - source_dim=src_dim, - target_dim=tgt_dim, - condition_dim=cond_dim, - rng=rng_init, - optimizer=optax.adam(learning_rate=1e-4), - ) - - _logs = model(dl, n_iters=3, rng=rng_call) - res = model.transport(src, condition=src_cond) - - assert jnp.sum(jnp.isnan(res)) == 0 - assert res.shape[-1] == tgt_dim - - @pytest.mark.parametrize("dl", ["fused_dl", "conditional_fused_dl"]) - def test_genot_fused(self, rng: jax.Array, dl: str, request): - rng_init, rng_call = jax.random.split(rng) - hidden_dim = 7 - dl = request.getfixturevalue(dl) - - batch = next(iter(dl)) - src_lin = jnp.asarray(batch["src_lin"]) - tgt_lin = jnp.asarray(batch["tgt_lin"]) - src_quad = jnp.asarray(batch["src_quad"]) - tgt_quad = jnp.asarray(batch["tgt_quad"]) - src_cond = batch.get("src_condition") - if src_cond is not None: - src_cond = jnp.asarray(src_cond) - src_dim = src_lin.shape[-1] + src_quad.shape[-1] - tgt_dim = tgt_lin.shape[-1] + tgt_quad.shape[-1] - cond_dim = src_cond.shape[-1] if src_cond is not None else 0 - - vf = models.VelocityField( - hidden_dim=hidden_dim, - 
output_dim=tgt_dim, - condition_dim=src_dim + cond_dim, - ) + if src_lin_dim > 0 and src_quad_dim == 0: + problem_type = "linear" + elif src_lin_dim == 0 and src_quad_dim > 0: + problem_type = "quadratic" + elif src_lin_dim > 0 and src_quad_dim > 0: + problem_type = "fused" + else: + raise ValueError("Unknown problem type") - data_mfn = functools.partial(data_match_fn, type="fused") + data_mfn = functools.partial(data_match_fn, type=problem_type) model = genot.GENOT( vf, @@ -160,7 +104,8 @@ def test_genot_fused(self, rng: jax.Array, dl: str, request): ) _logs = model(dl, n_iters=3, rng=rng_call) - src = jnp.concatenate([src_lin, src_quad], axis=-1) + src_terms = [term for term in [src_lin, src_quad] if term is not None] + src = jnp.concatenate(src_terms, axis=-1) res = model.transport(src, condition=src_cond) assert jnp.sum(jnp.isnan(res)) == 0 From d8603f76f55ce9f0e3240a9ee86f7d5d02e1c3d8 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 19 Mar 2024 16:36:00 +0100 Subject: [PATCH 151/186] change reference for GENOT --- docs/references.bib | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/references.bib b/docs/references.bib index a53f5a8a1..d07643e8c 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -842,7 +842,7 @@ @misc{klein_uscidda:23 eprint = {2310.09254}, eprintclass = {stat.ML}, eprinttype = {arXiv}, - title = {Generative Entropic Neural Optimal Transport To Map Within and Across Spaces}, + title = {Entropic (Gromov) Wasserstein Flow Matching with GENOT}, year = {2023}, } From 7813f833cfe9e22d32af9c9934ae87b1e1b7c1e6 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 20 Mar 2024 16:04:51 +0100 Subject: [PATCH 152/186] add missing docstring --- src/ott/neural/flow_models/utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py index 3149106bb..a189b8f0f 100644 --- a/src/ott/neural/flow_models/utils.py +++ 
b/src/ott/neural/flow_models/utils.py @@ -99,7 +99,15 @@ def match_quadratic( def sample_joint(rng: jax.Array, tmat: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: - """TODO.""" + """Sample from a transport matrix. + + Args: + rng: Random number generator. + tmat: Transport matrix. + + Returns: + Source and target indices sampled from the transport matrix. + """ n, m = tmat.shape tmat_flattened = tmat.flatten() indices = jax.random.choice( From 212ee012b830a1ad808fb1336dc90b74cdd3d5d2 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 25 Mar 2024 19:21:10 +0100 Subject: [PATCH 153/186] Modify behaviour of `ConditionalLoader` --- src/ott/neural/data/datasets.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/ott/neural/data/datasets.py b/src/ott/neural/data/datasets.py index 0ec8edfe9..9cf1c085f 100644 --- a/src/ott/neural/data/datasets.py +++ b/src/ott/neural/data/datasets.py @@ -142,22 +142,20 @@ def __init__( def __next__(self) -> Item_t: if self._it == len(self): raise StopIteration + self._it += 1 ix = self._rng.choice(len(self._iterators)) iterator = self._iterators[ix] try: - data = next(iterator) - # TODO(michalk8): improve the logic a bit - self._it += 1 - return data + return next(iterator) except StopIteration: - self._iterators[ix] = iter(self.datasets[ix]) - if not self._iterators: - raise + # reset the consumed iterator and return it's first element + self._iterators[ix] = iterator = iter(self.datasets[ix]) + return next(iterator) def __iter__(self) -> "ConditionalLoader": - self._iterators = [iter(ds) for ds in self.datasets] self._it = 0 + self._iterators = [iter(ds) for ds in self.datasets] return self def __len__(self) -> int: From 95c71420c40e7a70c0b55e8e304a561d02cca5a9 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 25 Mar 2024 19:22:49 +0100 Subject: [PATCH 154/186] Update docstring --- 
src/ott/neural/data/datasets.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/ott/neural/data/datasets.py b/src/ott/neural/data/datasets.py index 9cf1c085f..dcea1fd17 100644 --- a/src/ott/neural/data/datasets.py +++ b/src/ott/neural/data/datasets.py @@ -54,7 +54,8 @@ class OTDataset: tgt_data: Samples from the target distribution. src_conditions: Conditions for the source data. tgt_conditions: Conditions for the target data. - is_aligned: Whether the samples from `src_data` and `tgt_data` are paired. + is_aligned: Whether the samples from the source and the target data + are paired. If yes, the source and the target conditions must match. seed: Random seed. """ SRC_PREFIX = "src" @@ -67,7 +68,7 @@ def __init__( src_conditions: Optional[Sequence[Any]] = None, tgt_conditions: Optional[Sequence[Any]] = None, is_aligned: bool = False, - seed: Optional[int] = None + seed: Optional[int] = None, ): self.src_data = src_data self.tgt_data = tgt_data From 52a54d38b02192a4b8ffcad86ca67944f5579fe4 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 25 Mar 2024 20:08:43 +0100 Subject: [PATCH 155/186] Clean GENOT docs --- docs/neural/flow_models.rst | 1 - src/ott/neural/data/datasets.py | 2 +- src/ott/neural/flow_models/genot.py | 96 ++++++++++++++++------------ src/ott/neural/flow_models/models.py | 2 +- src/ott/neural/flow_models/otfm.py | 2 +- src/ott/neural/flow_models/utils.py | 20 +----- 6 files changed, 62 insertions(+), 61 deletions(-) diff --git a/docs/neural/flow_models.rst b/docs/neural/flow_models.rst index 32fe6e7f3..273f145f3 100644 --- a/docs/neural/flow_models.rst +++ b/docs/neural/flow_models.rst @@ -48,4 +48,3 @@ Utils utils.sample_conditional utils.cyclical_time_encoder utils.uniform_sampler - utils.multivariate_normal diff --git a/src/ott/neural/data/datasets.py b/src/ott/neural/data/datasets.py index dcea1fd17..28c49a5af 100644 --- a/src/ott/neural/data/datasets.py +++ 
b/src/ott/neural/data/datasets.py @@ -56,7 +56,7 @@ class OTDataset: tgt_conditions: Conditions for the target data. is_aligned: Whether the samples from the source and the target data are paired. If yes, the source and the target conditions must match. - seed: Random seed. + seed: Random seed used to match source and target when not aligned. """ SRC_PREFIX = "src" TGT_PREFIX = "tgt" diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index f115fe7a4..d2e579f23 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -28,37 +28,45 @@ __all__ = ["GENOT"] +# input: (src_lin, tgt_lin, src_quad, tgt_quad), output: (len(src), len(tgt)) +# all are optional because the problem can be linear/quadratic/fused +DataMatchFn_t = Callable[[ + Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray], + Optional[jnp.ndarray] +], jnp.ndarray] -class GENOT: - """GENOT for entropic neural optimal transport :cite:`klein_uscidda:23`. - GENOT (Generative Entropic Neural Optimal Transport) is a framework for - learning neural optimal transport plans between two distributions. It - allows for learning linear and quadratic (Fused) Gromov-Wasserstein couplings, - in both the balanced and the unbalanced setting. +class GENOT: + """Generative Entropic Neural Optimal Transport :cite:`klein_uscidda:23`. + GENOT is a framework for learning neural optimal transport plans between + two distributions. It allows for learning linear and quadratic + (Fused) Gromov-Wasserstein couplings, in both the balanced and + the unbalanced setting. Args: - velocity_field: Neural vector field parameterized by a neural network. + velocity_field: Vector field parameterized by a neural network. flow: Flow between latent distribution and target distribution. - data_match_fn: OT solver to matching the source and the target distribution. + data_match_fn: Function to match source and target distributions. 
+ The function accepts a 4-tuple ``(src_lin, tgt_lin, src_quad, tgt_quad)`` + and return the transport matrix of shape ``(len(src), len(tgt))``. + Either linear, quadratic or both linear and quadratic source and target + arrays are passed, corresponding to the linear, quadratic and + fused GW couplings, respectively. source_dim: Dimension of the source space. target_dim: Dimension of the target space. - condition_dim: Dimension of the conditions. + condition_dim: Dimension of the conditions. If :obj:`None`, the underlying + velocity field has no conditions. + n_samples_per_src: Number of samples drawn from the conditional distribution + per one source sample. time_sampler: Sampler for the time to learn the neural ODE. If :obj:`None`, the time is uniformly sampled. - # TODO(michalk8): rename - k_samples_per_x: Number of samples drawn from the conditional distribution - per single source sample. - latent_match_fn: Linear OT matcher to optimally pair the latent - distribution with the `k_samples_per_x` samples of the conditional - distribution (corresponding to one sample). If :obj:`None`, samples - from the latent distribution are randomly paired with the samples from - the conditional distribution. latent_noise_fn: Function to sample from the latent distribution in the target space. If :obj:`None`, the latent distribution is sampled from a multivariate normal distribution. - # TODO(michalk8): expose all args for the train state? + latent_match_fn: Function to pair the latent distribution with + the ``n_samples_per_src`` samples of the conditional distribution. + If :obj:`None`, no matching is performed. kwargs: Keyword arguments for :meth:`ott.neural.flow_models.models.VelocityField.create_train_state`. 
""" @@ -67,36 +75,33 @@ def __init__( self, velocity_field: models.VelocityField, flow: flows.BaseFlow, - data_match_fn: Callable[ - [jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray], + data_match_fn: DataMatchFn_t, + *, source_dim: int, target_dim: int, - condition_dim: int, + condition_dim: Optional[int] = None, + n_samples_per_src: int = 1, time_sampler: Callable[[jax.Array, int], jnp.ndarray] = flow_utils.uniform_sampler, - # TODO(michalk8): rename, too descriptive - k_samples_per_x: int = 1, - latent_match_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], - jnp.ndarray]] = None, latent_noise_fn: Optional[Callable[[jax.Array, Tuple[int, ...]], jnp.ndarray]] = None, + latent_match_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], + jnp.ndarray]] = None, **kwargs: Any, ): self.vf = velocity_field self.flow = flow self.data_match_fn = data_match_fn self.time_sampler = time_sampler - self.latent_match_fn = latent_match_fn if latent_noise_fn is None: - latent_noise_fn = functools.partial( - flow_utils.multivariate_normal, dim=target_dim - ) + latent_noise_fn = functools.partial(multivariate_normal, dim=target_dim) self.latent_noise_fn = latent_noise_fn - self.k_samples_per_x = k_samples_per_x + self.latent_match_fn = latent_match_fn + self.n_samples_per_src = n_samples_per_src self.vf_state = self.vf.create_train_state( input_dim=target_dim, - condition_dim=source_dim + condition_dim, + condition_dim=source_dim + (condition_dim or 0), **kwargs ) self.step_fn = self._get_step_fn() @@ -120,10 +125,10 @@ def loss_fn( source_conditions: Optional[jnp.ndarray], rng: jax.Array ): x_t = self.flow.compute_xt(rng, time, latent, target) - cond = ( - source if source_conditions is None else - jnp.concatenate([source, source_conditions], axis=-1) - ) + if source_conditions is None: + cond = source + else: + cond = jnp.concatenate([source, source_conditions], axis=-1) v_t = vf_state.apply_fn({"params": params}, time, x_t, cond) u_t = self.flow.compute_ut(time, 
latent, target) @@ -151,7 +156,7 @@ def __call__( loader: Data loader returning a dictionary with possible keys `src_lin`, `tgt_lin`, `src_quad`, `tgt_quad`, `src_conditions`. n_iters: Number of iterations to train the model. - rng: Random number generator. + rng: Random key for seeding. Returns: Training logs. @@ -175,7 +180,6 @@ def prepare_data( else: raise RuntimeError("Cannot infer OT problem type from data.") - # TODO(michalk8): filter `None` from the `arrs`? return (src, batch.get("src_condition"), tgt), arrs rng = utils.default_prng_key(rng) @@ -188,14 +192,14 @@ def prepare_data( (src, src_cond, tgt), matching_data = prepare_data(batch) n = src.shape[0] - time = self.time_sampler(rng_time, n * self.k_samples_per_x) - latent = self.latent_noise_fn(rng_noise, (n, self.k_samples_per_x)) + time = self.time_sampler(rng_time, n * self.n_samples_per_src) + latent = self.latent_noise_fn(rng_noise, (n, self.n_samples_per_src)) tmat = self.data_match_fn(*matching_data) # (n, m) src_ixs, tgt_ixs = flow_utils.sample_conditional( # (n, k), (m, k) rng_resample, tmat, - k=self.k_samples_per_x, + k=self.n_samples_per_src, uniform_marginals=True, # TODO(michalk8): expose ) @@ -304,3 +308,15 @@ def solve_ode(x: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: source = jnp.concatenate([source, condition], axis=-1) return jax.jit(jax.vmap(solve_ode))(latent, source) + + +def multivariate_normal( + rng: jax.Array, + shape: Tuple[int, ...], + dim: int, + mean: float = 0.0, + cov: float = 1.0 +) -> jnp.ndarray: + mean = jnp.full(dim, fill_value=mean) + cov = jnp.diag(jnp.full(dim, fill_value=cov)) + return jax.random.multivariate_normal(rng, mean=mean, cov=cov, shape=shape) diff --git a/src/ott/neural/flow_models/models.py b/src/ott/neural/flow_models/models.py index abf3ec8f7..f88bcbfdf 100644 --- a/src/ott/neural/flow_models/models.py +++ b/src/ott/neural/flow_models/models.py @@ -86,7 +86,7 @@ def __call__( t = self.act_fn(nn.Dense(time_dim)(t)) x = 
self.act_fn(nn.Dense(self.hidden_dim)(x)) if self.condition_dim is not None: - assert condition is not None, "TODO." + assert condition is not None, "No condition was specified." condition = self.act_fn(nn.Dense(self.condition_dim)(condition)) feats = [t, x] + ([] if condition is None else [condition]) diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index ad1a70522..3e4381b65 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -34,7 +34,7 @@ class OTFlowMatching: With an extension to OT-FM :cite:`tong:23`, :cite:`pooladian:23`. Args: - velocity_field: Neural vector field parameterized by a neural network. + velocity_field: Vector field parameterized by a neural network. flow: Flow between source and target distribution. match_fn: Function to match data points from the source distribution and the target distribution. diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py index a189b8f0f..f85ef159a 100644 --- a/src/ott/neural/flow_models/utils.py +++ b/src/ott/neural/flow_models/utils.py @@ -26,7 +26,6 @@ "sample_conditional", "cyclical_time_encoder", "uniform_sampler", - "multivariate_normal", ] ScaleCost_t = Union[float, Literal["mean", "max_cost", "median"]] @@ -100,7 +99,7 @@ def match_quadratic( def sample_joint(rng: jax.Array, tmat: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: """Sample from a transport matrix. - + Args: rng: Random number generator. tmat: Transport matrix. @@ -130,7 +129,7 @@ def sample_conditional( Args: rng: Random number generator. tmat: Transport matrix. - k: Expected number of samples to sample per row. + k: Expected number of samples to sample per source. uniform_marginals: If :obj:`True`, sample exactly `k` samples per row, otherwise sample proportionally to the sums of the rows of the transport matrix. @@ -138,7 +137,7 @@ def sample_conditional( Returns: Source and target indices sampled from the transport matrix. 
""" - assert k > 0, "Number of samples per row must be positive." + assert k > 0, "Number of samples per source must be positive." n, m = tmat.shape if uniform_marginals: @@ -208,16 +207,3 @@ def uniform_sampler( t = jax.random.uniform(rng, (1, 1), minval=low, maxval=high) mod_term = ((high - low) - offset) return (t + jnp.arange(num_samples)[:, None] / num_samples) % mod_term - - -def multivariate_normal( - rng: jax.Array, - shape: Tuple[int, ...], - dim: int, - mean: float = 0.0, - cov: float = 1.0 -) -> jnp.ndarray: - """TODO.""" - mean = jnp.full(dim, fill_value=mean) - cov = jnp.diag(jnp.full(dim, fill_value=cov)) - return jax.random.multivariate_normal(rng, mean=mean, cov=cov, shape=shape) From de2e4ac29e7efd5d452e74cab7362049c2d532f3 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 26 Mar 2024 18:03:40 +0100 Subject: [PATCH 156/186] Improve VF --- src/ott/neural/flow_models/models.py | 59 +++++++++++++++------------- tests/neural/genot_test.py | 9 ++--- tests/neural/otfm_test.py | 10 ++--- 3 files changed, 40 insertions(+), 38 deletions(-) diff --git a/src/ott/neural/flow_models/models.py b/src/ott/neural/flow_models/models.py index f88bcbfdf..cd182d20e 100644 --- a/src/ott/neural/flow_models/models.py +++ b/src/ott/neural/flow_models/models.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Callable, Optional, Sequence import jax import jax.numpy as jnp @@ -35,29 +35,20 @@ class VelocityField(nn.Module): a target distribution given at :math:`t_1` by integrating :math:`v(t, x)` from :math:`t=t_0` to :math:`t=t_1`. 
- Each of the input, condition, and time embeddings are passed through a block - consisting of ``num_layers`` layers of dimension - ``hidden_dim``, ``condition_dim``, and ``time_embed_dim``, - respectively. The output of each block is concatenated and passed through - a final block of dimension ``joint_hidden_dim``. - Args: - hidden_dim: Dimensionality of the embedding of the data. - output_dim: Dimensionality of the neural vector field. - num_layers: Number of layers. - condition_dim: Dimensionality of the embedding of the condition. + hidden_dims: Dimensionality of the embedding of the data. + condition_dims: Dimensionality of the embedding of the condition. If :obj:`None`, the velocity field has no conditions. - time_dim: Dimensionality of the time embedding. - If :obj:`None`, set to ``hidden_dim``. + time_dims: Dimensionality of the time embedding. + If :obj:`None`, ``hidden_dims`` will be used. time_encoder: Function to encode the time input to the time-dependent velocity field. act_fn: Activation function. """ - hidden_dim: int output_dim: int - num_layers: int = 3 - condition_dim: Optional[int] = None - time_dim: Optional[int] = None + hidden_dims: Sequence[int] = (128, 128, 128) + condition_dims: Optional[Sequence[int]] = None + time_dims: Optional[Sequence[int]] = None time_encoder: Callable[[jnp.ndarray], jnp.ndarray] = utils.cyclical_time_encoder act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu @@ -79,24 +70,33 @@ def __call__( Returns: Output of the neural vector field of shape ``[batch, output_dim]``. 
""" - time_dim = self.hidden_dim if self.time_dim is None else self.time_dim + if self.condition_dims is None: + cond_dims = [None] * len(self.hidden_dims) + else: + cond_dims = self.condition_dims + time_dims = self.hidden_dims if self.time_dims is None else self.time_dims + + assert len(self.hidden_dims) == len(cond_dims), "TODO" + assert len(self.hidden_dims) == len(time_dims), "TODO" t = self.time_encoder(t) - for _ in range(self.num_layers): + for time_dim, cond_dim, hidden_dim in zip( + time_dims, cond_dims, self.hidden_dims + ): t = self.act_fn(nn.Dense(time_dim)(t)) - x = self.act_fn(nn.Dense(self.hidden_dim)(x)) - if self.condition_dim is not None: + x = self.act_fn(nn.Dense(hidden_dim)(x)) + if self.condition_dims is not None: assert condition is not None, "No condition was specified." - condition = self.act_fn(nn.Dense(self.condition_dim)(condition)) + condition = self.act_fn(nn.Dense(cond_dim)(condition)) - feats = [t, x] + ([] if condition is None else [condition]) + feats = [t, x] + ([] if self.condition_dims is None else [condition]) feats = jnp.concatenate(feats, axis=-1) joint_dim = feats.shape[-1] - for _ in range(self.num_layers): + for _ in range(len(self.hidden_dims)): feats = self.act_fn(nn.Dense(joint_dim)(feats)) - return nn.Dense(self.output_dim, use_bias=True)(feats) + return nn.Dense(self.output_dim)(feats) def create_train_state( self, @@ -111,14 +111,17 @@ def create_train_state( rng: Random number generator. optimizer: Optimizer. input_dim: Dimensionality of the velocity field. - condition_dim: Dimensionality of the condition - to the velocity field. + condition_dim: Dimensionality of the condition of the velocity field. Returns: The training state. 
""" t, x = jnp.ones((1, 1)), jnp.ones((1, input_dim)) - cond = None if self.condition_dim is None else jnp.ones((1, condition_dim)) + if self.condition_dims is not None: + assert condition_dim is not None, "TODO" + cond = jnp.ones((1, condition_dim)) + else: + cond = None params = self.init(rng, t, x, cond)["params"] return train_state.TrainState.create( diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index a5f061335..22bce3e39 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -46,8 +46,7 @@ class TestGENOT: ] ) def test_genot(self, rng: jax.Array, dl: str, request): - rng_init, rng_call = jax.random.split(rng) - hidden_dim = 7 + rng_init, rng_call = jax.random.split(rng, 2) dl = request.getfixturevalue(dl) batch = next(iter(dl)) @@ -76,9 +75,9 @@ def test_genot(self, rng: jax.Array, dl: str, request): cond_dim = src_cond.shape[-1] if src_cond is not None else 0 vf = models.VelocityField( - hidden_dim=hidden_dim, - output_dim=tgt_dim, - condition_dim=src_dim + cond_dim, + tgt_dim, + hidden_dims=[7, 7, 7], + condition_dims=[7, 7, 7], ) if src_lin_dim > 0 and src_quad_dim == 0: diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index a4db65fa5..078b5c50a 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -27,13 +27,13 @@ class TestOTFlowMatching: (3, "lin_dl_with_conds"), (4, "conditional_lin_dl")]) def test_fm(self, rng: jax.Array, cond_dim: int, dl: str, request): - input_dim, hidden_dim = 2, 5 + dim = 2 # all dataloaders have this dim dl = request.getfixturevalue(dl) neural_vf = models.VelocityField( - hidden_dim=hidden_dim, - output_dim=input_dim, - condition_dim=hidden_dim if cond_dim > 0 else None, + dim, + hidden_dims=[5, 5, 5], + condition_dims=[5, 5, 5] if cond_dim > 0 else None, ) fm = otfm.OTFlowMatching( neural_vf, @@ -41,7 +41,7 @@ def test_fm(self, rng: jax.Array, cond_dim: int, dl: str, request): match_fn=jax.jit(utils.match_linear), rng=rng, 
optimizer=optax.adam(learning_rate=1e-3), - input_dim=input_dim, + input_dim=dim, condition_dim=cond_dim, ) From 9b89fd7f8a934135f09f8ffad658b26a10bb36a0 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 26 Mar 2024 18:23:53 +0100 Subject: [PATCH 157/186] Simplify GENOT test --- tests/neural/genot_test.py | 73 ++++++++++++-------------------------- 1 file changed, 22 insertions(+), 51 deletions(-) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 22bce3e39..bd857aef8 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Literal +from typing import Literal, Optional import pytest import jax import jax.numpy as jnp +import jax.tree_util as jtu import optax @@ -25,54 +26,35 @@ def data_match_fn( - src_lin: jnp.ndarray, tgt_lin: jnp.ndarray, src_quad: jnp.ndarray, - tgt_quad: jnp.ndarray, *, type: Literal["linear", "quadratic", "fused"] -): - if type == "linear": + src_lin: Optional[jnp.ndarray], tgt_lin: Optional[jnp.ndarray], + src_quad: Optional[jnp.ndarray], tgt_quad: Optional[jnp.ndarray], *, + typ: Literal["lin", "quad", "fused"] +) -> jnp.ndarray: + if typ == "lin": return utils.match_linear(x=src_lin, y=tgt_lin) - if type == "quadratic": + if typ == "quad": return utils.match_quadratic(xx=src_quad, yy=tgt_quad) - if type == "fused": + if typ == "fused": return utils.match_quadratic(xx=src_quad, yy=tgt_quad, x=src_lin, y=tgt_lin) - raise NotImplementedError(f"Unknown type: {type}") + raise NotImplementedError(f"Unknown type: {typ}.") class TestGENOT: - @pytest.mark.parametrize( - "dl", [ - "lin_dl", "conditional_lin_dl", "quad_dl", "conditional_quad_dl", - "fused_dl", "conditional_fused_dl" - ] - ) + # TODO(michalk8): add conds + @pytest.mark.parametrize("dl", ["lin_dl", "quad_dl", "fused_dl"]) def 
test_genot(self, rng: jax.Array, dl: str, request): - rng_init, rng_call = jax.random.split(rng, 2) + rng_init, rng_call, rng_data = jax.random.split(rng, 3) + problem_type = dl.split("_")[0] dl = request.getfixturevalue(dl) batch = next(iter(dl)) - src_lin = batch.get("src_lin") - if src_lin is not None: - src_lin = jnp.asarray(src_lin) - src_quad = batch.get("src_quad") - if src_quad is not None: - src_quad = jnp.asarray(src_quad) - tgt_lin = batch.get("tgt_lin") - if tgt_lin is not None: - tgt_lin = jnp.asarray(batch["tgt_lin"]) - tgt_quad = batch.get("tgt_quad") - if tgt_quad is not None: - tgt_quad = jnp.asarray(batch["tgt_quad"]) + batch = jtu.tree_map(jnp.asarray, batch) src_cond = batch.get("src_condition") - if src_cond is not None: - src_cond = jnp.asarray(src_cond) - src_lin_dim = src_lin.shape[-1] if src_lin is not None else 0 - src_quad_dim = src_quad.shape[-1] if src_quad is not None else 0 - tgt_lin_shape = tgt_lin.shape[-1] if tgt_lin is not None else 0 - tgt_quad_shape = tgt_quad.shape[-1] if tgt_quad is not None else 0 - src_dim = src_lin_dim + src_quad_dim - tgt_dim = tgt_lin_shape + tgt_quad_shape - cond_dim = src_cond.shape[-1] if src_cond is not None else 0 + dims = jtu.tree_map(lambda x: x.shape[-1], batch) + src_dim = dims.get("src_lin", 0) + dims.get("src_quad", 0) + tgt_dim = dims.get("tgt_lin", 0) + dims.get("tgt_quad", 0) vf = models.VelocityField( tgt_dim, @@ -80,31 +62,20 @@ def test_genot(self, rng: jax.Array, dl: str, request): condition_dims=[7, 7, 7], ) - if src_lin_dim > 0 and src_quad_dim == 0: - problem_type = "linear" - elif src_lin_dim == 0 and src_quad_dim > 0: - problem_type = "quadratic" - elif src_lin_dim > 0 and src_quad_dim > 0: - problem_type = "fused" - else: - raise ValueError("Unknown problem type") - - data_mfn = functools.partial(data_match_fn, type=problem_type) - model = genot.GENOT( vf, flow=flows.ConstantNoiseFlow(0.0), - data_match_fn=data_mfn, + data_match_fn=functools.partial(data_match_fn, 
typ=problem_type), source_dim=src_dim, target_dim=tgt_dim, - condition_dim=cond_dim, + condition_dim=None if src_cond is None else src_cond.shape[-1], rng=rng_init, optimizer=optax.adam(learning_rate=1e-4), ) _logs = model(dl, n_iters=3, rng=rng_call) - src_terms = [term for term in [src_lin, src_quad] if term is not None] - src = jnp.concatenate(src_terms, axis=-1) + + src = jax.random.normal(rng_data, (3, src_dim)) res = model.transport(src, condition=src_cond) assert jnp.sum(jnp.isnan(res)) == 0 From 433da0cdd50039716c06895244377748b0253d41 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 26 Mar 2024 18:58:41 +0100 Subject: [PATCH 158/186] Better metadata wrapper in tests --- tests/neural/conftest.py | 81 ++++++++++++++++++++------------------ tests/neural/genot_test.py | 21 +++++----- tests/neural/otfm_test.py | 24 +++++------ 3 files changed, 62 insertions(+), 64 deletions(-) diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index b19962608..8653e7d63 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import NamedTuple, Optional, Union import pytest @@ -21,6 +21,14 @@ from ott.neural.data import datasets +class OTLoader(NamedTuple): + loader: DataLoader + lin_dim: int = 0 + quad_src_dim: int = 0 + quad_tgt_dim: int = 0 + cond_dim: Optional[int] = None + + def _ot_data( rng: np.random.Generator, *, @@ -31,7 +39,8 @@ def _ot_data( cond_dim: Optional[int] = None, offset: float = 0.0 ) -> datasets.OTData: - assert lin_dim or quad_dim, "TODO" + assert lin_dim or quad_dim, \ + "Either linear or quadratic dimension has to be specified." 
lin_data = None if lin_dim is None else ( rng.normal(size=(n, lin_dim)) + offset @@ -50,7 +59,6 @@ def _ot_data( @pytest.fixture() def lin_dl() -> DataLoader: - """Returns a data loader for a simple Gaussian mixture.""" n, d = 128, 2 rng = np.random.default_rng(0) @@ -58,11 +66,14 @@ def lin_dl() -> DataLoader: tgt = _ot_data(rng, n=n, lin_dim=d, offset=1.0) ds = datasets.OTDataset(src, tgt) - return DataLoader(ds, batch_size=16, shuffle=True) + return OTLoader( + DataLoader(ds, batch_size=16, shuffle=True), + lin_dim=d, + ) @pytest.fixture() -def lin_dl_with_conds() -> DataLoader: +def lin_cond_dl() -> DataLoader: n, d, cond_dim = 128, 2, 3 rng = np.random.default_rng(13) @@ -72,39 +83,44 @@ def lin_dl_with_conds() -> DataLoader: tgt = _ot_data(rng, n=n, lin_dim=d, condition=tgt_cond) ds = datasets.OTDataset(src, tgt) - return DataLoader(ds, batch_size=16, shuffle=True) + return OTLoader( + DataLoader(ds, batch_size=16, shuffle=True), + lin_dim=d, + cond_dim=cond_dim, + ) @pytest.fixture() -def conditional_lin_dl() -> datasets.ConditionalLoader: - d, cond_dim = 2, 4 - rng = np.random.default_rng(42) - - src0 = _ot_data(rng, condition=0.0, lin_dim=d, cond_dim=cond_dim) - tgt0 = _ot_data(rng, lin_dim=d, offset=2.0) - src1 = _ot_data(rng, condition=1.0, lin_dim=d, cond_dim=cond_dim) - tgt1 = _ot_data(rng, lin_dim=d, offset=-2.0) - - src_ds = datasets.OTDataset(src0, tgt0) - tgt_ds = datasets.OTDataset(src1, tgt1) +def quad_dl(): + n, quad_src_dim, quad_tgt_dim = 128, 2, 4 + rng = np.random.default_rng(11) - src_dl = DataLoader(src_ds, batch_size=16, shuffle=True) - tgt_dl = DataLoader(tgt_ds, batch_size=16, shuffle=True) + src = _ot_data(rng, n=n, quad_dim=quad_src_dim) + tgt = _ot_data(rng, n=n, quad_dim=quad_tgt_dim, offset=1.0) + ds = datasets.OTDataset(src, tgt) - return datasets.ConditionalLoader([src_dl, tgt_dl]) + return OTLoader( + DataLoader(ds, batch_size=16, shuffle=True), + quad_src_dim=quad_src_dim, + quad_tgt_dim=quad_tgt_dim, + ) @pytest.fixture() -def 
quad_dl(): - n = 128 - quad_dim_src, quad_dim_tgt = 2, 4 +def fused_dl(): + n, lin_dim, quad_src_dim, quad_tgt_dim = 128, 6, 2, 4 rng = np.random.default_rng(11) - src = _ot_data(rng, n=n, quad_dim=quad_dim_src) - tgt = _ot_data(rng, n=n, quad_dim=quad_dim_tgt, offset=1.0) + src = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=quad_src_dim) + tgt = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=quad_tgt_dim, offset=1.0) ds = datasets.OTDataset(src, tgt) - return DataLoader(ds, batch_size=16, shuffle=True) + return OTLoader( + DataLoader(ds, batch_size=16, shuffle=True), + lin_dim=lin_dim, + quad_src_dim=quad_src_dim, + quad_tgt_dim=quad_tgt_dim, + ) @pytest.fixture() @@ -133,19 +149,6 @@ def conditional_quad_dl() -> datasets.ConditionalLoader: return datasets.ConditionalLoader([src_dl, tgt_dl]) -@pytest.fixture() -def fused_dl(): - n, lin_dim = 128, 6 - quad_dim_src, quad_dim_tgt = 2, 4 - rng = np.random.default_rng(11) - - src = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=quad_dim_src) - tgt = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=quad_dim_tgt, offset=1.0) - ds = datasets.OTDataset(src, tgt) - - return DataLoader(ds, batch_size=16, shuffle=True) - - @pytest.fixture() def conditional_fused_dl() -> datasets.ConditionalLoader: n, lin_dim, cond_dim = 128, 3, 7 diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index bd857aef8..35a2c5135 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -48,34 +48,33 @@ def test_genot(self, rng: jax.Array, dl: str, request): problem_type = dl.split("_")[0] dl = request.getfixturevalue(dl) - batch = next(iter(dl)) - batch = jtu.tree_map(jnp.asarray, batch) - src_cond = batch.get("src_condition") - - dims = jtu.tree_map(lambda x: x.shape[-1], batch) - src_dim = dims.get("src_lin", 0) + dims.get("src_quad", 0) - tgt_dim = dims.get("tgt_lin", 0) + dims.get("tgt_quad", 0) + src_dim = dl.lin_dim + dl.quad_src_dim + tgt_dim = dl.lin_dim + dl.quad_tgt_dim + cond_dim = dl.cond_dim vf = 
models.VelocityField( tgt_dim, hidden_dims=[7, 7, 7], - condition_dims=[7, 7, 7], + condition_dims=None if dl.cond_dim is None else [1, 3, 2], ) - model = genot.GENOT( vf, flow=flows.ConstantNoiseFlow(0.0), data_match_fn=functools.partial(data_match_fn, typ=problem_type), source_dim=src_dim, target_dim=tgt_dim, - condition_dim=None if src_cond is None else src_cond.shape[-1], + condition_dim=cond_dim, rng=rng_init, optimizer=optax.adam(learning_rate=1e-4), ) - _logs = model(dl, n_iters=3, rng=rng_call) + _logs = model(dl.loader, n_iters=3, rng=rng_call) + batch = next(iter(dl.loader)) + batch = jtu.tree_map(jnp.asarray, batch) src = jax.random.normal(rng_data, (3, src_dim)) + src_cond = batch.get("src_condition") + res = model.transport(src, condition=src_cond) assert jnp.sum(jnp.isnan(res)) == 0 diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 078b5c50a..8e0b4aff7 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -15,6 +15,7 @@ import jax import jax.numpy as jnp +import jax.tree_util as jtu import optax @@ -23,17 +24,15 @@ class TestOTFlowMatching: - @pytest.mark.parametrize(("cond_dim", "dl"), [(0, "lin_dl"), - (3, "lin_dl_with_conds"), - (4, "conditional_lin_dl")]) - def test_fm(self, rng: jax.Array, cond_dim: int, dl: str, request): - dim = 2 # all dataloaders have this dim + @pytest.mark.parametrize("dl", ["lin_dl", "lin_cond_dl"]) + def test_fm(self, rng: jax.Array, dl: str, request): dl = request.getfixturevalue(dl) + dim, cond_dim = dl.lin_dim, dl.cond_dim neural_vf = models.VelocityField( dim, hidden_dims=[5, 5, 5], - condition_dims=[5, 5, 5] if cond_dim > 0 else None, + condition_dims=None if cond_dim is None else [4, 3, 2], ) fm = otfm.OTFlowMatching( neural_vf, @@ -45,17 +44,14 @@ def test_fm(self, rng: jax.Array, cond_dim: int, dl: str, request): condition_dim=cond_dim, ) - _logs = fm(dl, n_iters=3) + _logs = fm(dl.loader, n_iters=3) - batch = next(iter(dl)) - src = jnp.asarray(batch["src_lin"]) - tgt = 
jnp.asarray(batch["tgt_lin"]) + batch = next(iter(dl.loader)) + batch = jtu.tree_map(jnp.asarray, batch) src_cond = batch.get("src_condition") - if src_cond is not None: - src_cond = jnp.asarray(src_cond) - res_fwd = fm.transport(src, condition=src_cond) - res_bwd = fm.transport(tgt, t0=1.0, t1=0.0, condition=src_cond) + res_fwd = fm.transport(batch["src_lin"], condition=src_cond) + res_bwd = fm.transport(batch["tgt_lin"], t0=1.0, t1=0.0, condition=src_cond) # TODO(michalk8): better assertions assert jnp.sum(jnp.isnan(res_fwd)) == 0 From f8fcba7e8edcb65a3b66fa2099f71eae1db2e4af Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 26 Mar 2024 19:04:06 +0100 Subject: [PATCH 159/186] Fix condition in GENOT test --- tests/neural/genot_test.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 35a2c5135..3f4de227f 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -41,8 +41,9 @@ def data_match_fn( class TestGENOT: - # TODO(michalk8): add conds - @pytest.mark.parametrize("dl", ["lin_dl", "quad_dl", "fused_dl"]) + @pytest.mark.parametrize( + "dl", ["lin_dl", "quad_dl", "fused_dl", "lin_cond_dl"] + ) def test_genot(self, rng: jax.Array, dl: str, request): rng_init, rng_call, rng_data = jax.random.split(rng, 3) problem_type = dl.split("_")[0] @@ -50,12 +51,12 @@ def test_genot(self, rng: jax.Array, dl: str, request): src_dim = dl.lin_dim + dl.quad_src_dim tgt_dim = dl.lin_dim + dl.quad_tgt_dim - cond_dim = dl.cond_dim + cond_dim = dl.cnd_dim vf = models.VelocityField( tgt_dim, hidden_dims=[7, 7, 7], - condition_dims=None if dl.cond_dim is None else [1, 3, 2], + condition_dims=None if cond_dim is None else [1, 3, 2], ) model = genot.GENOT( vf, @@ -72,8 +73,9 @@ def test_genot(self, rng: jax.Array, dl: str, request): batch = next(iter(dl.loader)) batch = jtu.tree_map(jnp.asarray, batch) - src = 
jax.random.normal(rng_data, (3, src_dim)) src_cond = batch.get("src_condition") + batch_size = 4 if src_cond is None else src_cond.shape[0] + src = jax.random.normal(rng_data, (batch_size, src_dim)) res = model.transport(src, condition=src_cond) From 49a07a00ad544bbf25e9f3f3fe235f486fbf6eaa Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 26 Mar 2024 19:09:10 +0100 Subject: [PATCH 160/186] Add quad cond dl --- tests/neural/conftest.py | 19 +++++++++++++++++++ tests/neural/genot_test.py | 4 ++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 8653e7d63..3493115c5 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -106,6 +106,25 @@ def quad_dl(): ) +@pytest.fixture() +def quad_cond_dl(): + n, quad_src_dim, quad_tgt_dim, cond_dim = 128, 2, 4, 5 + rng = np.random.default_rng(414) + + src_cond = rng.normal(size=(n, cond_dim)) + tgt_cond = rng.normal(size=(n, cond_dim)) + src = _ot_data(rng, n=n, quad_dim=quad_src_dim, condition=src_cond) + tgt = _ot_data(rng, n=n, quad_dim=quad_tgt_dim, offset=1.0, cond_dim=tgt_cond) + ds = datasets.OTDataset(src, tgt) + + return OTLoader( + DataLoader(ds, batch_size=16, shuffle=True), + quad_src_dim=quad_src_dim, + quad_tgt_dim=quad_tgt_dim, + cond_dim=cond_dim, + ) + + @pytest.fixture() def fused_dl(): n, lin_dim, quad_src_dim, quad_tgt_dim = 128, 6, 2, 4 diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 3f4de227f..59a738e0a 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -42,7 +42,7 @@ def data_match_fn( class TestGENOT: @pytest.mark.parametrize( - "dl", ["lin_dl", "quad_dl", "fused_dl", "lin_cond_dl"] + "dl", ["lin_dl", "quad_dl", "fused_dl", "lin_cond_dl", "quad_cond_dl"] ) def test_genot(self, rng: jax.Array, dl: str, request): rng_init, rng_call, rng_data = jax.random.split(rng, 3) @@ -51,7 +51,7 @@ def test_genot(self, rng: jax.Array, 
dl: str, request): src_dim = dl.lin_dim + dl.quad_src_dim tgt_dim = dl.lin_dim + dl.quad_tgt_dim - cond_dim = dl.cnd_dim + cond_dim = dl.cond_dim vf = models.VelocityField( tgt_dim, From d1ae1de237dac8edece99fbcb58394378a3d108b Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 26 Mar 2024 22:49:07 +0100 Subject: [PATCH 161/186] Add conf fused DL --- tests/neural/conftest.py | 64 ++++++++++---------------------------- tests/neural/genot_test.py | 5 ++- 2 files changed, 20 insertions(+), 49 deletions(-) diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 3493115c5..92c23f6a6 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -143,60 +143,28 @@ def fused_dl(): @pytest.fixture() -def conditional_quad_dl() -> datasets.ConditionalLoader: - n, cond_dim = 128, 5 - quad_dim_src, quad_dim_tgt = 2, 4 +def fused_cond_dl(): + n, lin_dim, quad_src_dim, quad_tgt_dim, cond_dim = 128, 6, 2, 4, 7 rng = np.random.default_rng(11) - src0 = _ot_data( - rng, n=n, condition=0.0, cond_dim=cond_dim, quad_dim=quad_dim_src - ) - tgt0 = _ot_data( - rng, n=n, quad_dim=quad_dim_tgt, cond_dim=cond_dim, offset=2.0 - ) - src1 = _ot_data( - rng, n=n, condition=1.0, cond_dim=cond_dim, quad_dim=quad_dim_src + src_cond = rng.normal(size=(n, cond_dim)) + tgt_cond = rng.normal(size=(n, cond_dim)) + src = _ot_data( + rng, n=n, lin_dim=lin_dim, quad_dim=quad_src_dim, condition=src_cond ) - tgt1 = _ot_data(rng, n=n, quad_dim=quad_dim_tgt, offset=-2.0) - - src_ds = datasets.OTDataset(src0, tgt0) - tgt_ds = datasets.OTDataset(src1, tgt1) - - src_dl = DataLoader(src_ds, batch_size=16, shuffle=True) - tgt_dl = DataLoader(tgt_ds, batch_size=16, shuffle=True) - - return datasets.ConditionalLoader([src_dl, tgt_dl]) - - -@pytest.fixture() -def conditional_fused_dl() -> datasets.ConditionalLoader: - n, lin_dim, cond_dim = 128, 3, 7 - quad_dim_src, quad_dim_tgt = 2, 4 - rng = np.random.default_rng(11) - - src0 = _ot_data( + tgt 
= _ot_data( rng, n=n, - condition=0.0, - cond_dim=cond_dim, lin_dim=lin_dim, - quad_dim=quad_dim_src + quad_dim=quad_tgt_dim, + offset=1.0, + condition=tgt_cond ) - tgt0 = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=quad_dim_tgt, offset=2.0) - src1 = _ot_data( - rng, - n=n, - condition=1.0, - cond_dim=cond_dim, + ds = datasets.OTDataset(src, tgt) + + return OTLoader( + DataLoader(ds, batch_size=16, shuffle=True), lin_dim=lin_dim, - quad_dim=quad_dim_src + quad_src_dim=quad_src_dim, + quad_tgt_dim=quad_tgt_dim, ) - tgt1 = _ot_data(rng, n=n, lin_dim=lin_dim, quad_dim=quad_dim_tgt, offset=-2.0) - - src_ds = datasets.OTDataset(src0, tgt0) - tgt_ds = datasets.OTDataset(src1, tgt1) - - src_dl = DataLoader(src_ds, batch_size=16, shuffle=True) - tgt_dl = DataLoader(tgt_ds, batch_size=16, shuffle=True) - - return datasets.ConditionalLoader([src_dl, tgt_dl]) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 59a738e0a..086cc82ea 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -42,7 +42,10 @@ def data_match_fn( class TestGENOT: @pytest.mark.parametrize( - "dl", ["lin_dl", "quad_dl", "fused_dl", "lin_cond_dl", "quad_cond_dl"] + "dl", [ + "lin_dl", "quad_dl", "fused_dl", "lin_cond_dl", "quad_cond_dl", + "fused_cond_dl" + ] ) def test_genot(self, rng: jax.Array, dl: str, request): rng_init, rng_call, rng_data = jax.random.split(rng, 3) From f6c69bdf4d26a49830eb36dddb9cea4520a77d92 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 27 Mar 2024 00:18:36 +0100 Subject: [PATCH 162/186] Polish docs --- src/ott/neural/flow_models/genot.py | 35 +++++++++++++---------------- src/ott/neural/flow_models/otfm.py | 21 +++++++++-------- tests/neural/otfm_test.py | 3 +-- 3 files changed, 27 insertions(+), 32 deletions(-) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index d2e579f23..08e011c04 100644 --- a/src/ott/neural/flow_models/genot.py +++ 
b/src/ott/neural/flow_models/genot.py @@ -46,30 +46,27 @@ class GENOT: Args: velocity_field: Vector field parameterized by a neural network. - flow: Flow between latent distribution and target distribution. - data_match_fn: Function to match source and target distributions. - The function accepts a 4-tuple ``(src_lin, tgt_lin, src_quad, tgt_quad)`` - and return the transport matrix of shape ``(len(src), len(tgt))``. - Either linear, quadratic or both linear and quadratic source and target - arrays are passed, corresponding to the linear, quadratic and - fused GW couplings, respectively. - source_dim: Dimension of the source space. - target_dim: Dimension of the target space. + flow: Flow between the latent and the target distributions. + data_match_fn: Function to match samples from the source and the target + distributions with a ``(src_lin, tgt_lin, src_quad, tgt_quad) -> matching`` + signature. + source_dim: Dimensionality of the source distribution. + target_dim: Dimensionality of the target distribution. condition_dim: Dimension of the conditions. If :obj:`None`, the underlying velocity field has no conditions. + time_sampler: Time sampler with a ``(rng, n_samples) -> time`` signature. + latent_noise_fn: Function to sample from the latent distribution in the + target space with a ``(rng, shape) -> noise`` signature. + If :obj:`None`, multivariate normal distribution is used. + latent_match_fn: Function to match samples from the latent distribution + and the samples from the conditional distribution with a + ``(latent, samples) -> matching`` signature. If :obj:`None`, no matching + is performed. n_samples_per_src: Number of samples drawn from the conditional distribution per one source sample. - time_sampler: Sampler for the time to learn the neural ODE. If :obj:`None`, - the time is uniformly sampled. - latent_noise_fn: Function to sample from the latent distribution in the - target space. 
If :obj:`None`, the latent distribution is sampled from a - multivariate normal distribution. - latent_match_fn: Function to pair the latent distribution with - the ``n_samples_per_src`` samples of the conditional distribution. - If :obj:`None`, no matching is performed. kwargs: Keyword arguments for :meth:`ott.neural.flow_models.models.VelocityField.create_train_state`. - """ + """ # noqa: E501 def __init__( self, @@ -80,13 +77,13 @@ def __init__( source_dim: int, target_dim: int, condition_dim: Optional[int] = None, - n_samples_per_src: int = 1, time_sampler: Callable[[jax.Array, int], jnp.ndarray] = flow_utils.uniform_sampler, latent_noise_fn: Optional[Callable[[jax.Array, Tuple[int, ...]], jnp.ndarray]] = None, latent_match_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None, + n_samples_per_src: int = 1, **kwargs: Any, ): self.vf = velocity_field diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index 3e4381b65..2c6be25a4 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -31,20 +31,18 @@ class OTFlowMatching: """(Optimal transport) flow matching :cite:`lipman:22`. - With an extension to OT-FM :cite:`tong:23`, :cite:`pooladian:23`. + With an extension to OT-FM :cite:`tong:23,pooladian:23`. Args: velocity_field: Vector field parameterized by a neural network. - flow: Flow between source and target distribution. - match_fn: Function to match data points from the source distribution and - the target distribution. - time_sampler: Sampler for the time. - # TODO(michalk8): expose all args for the train state? - kwargs: TODO. + flow: Flow between the source and the target distributions. + match_fn: Function to match samples from the source and the target + distributions. It has a ``(src, tgt) -> matching`` signature. + time_sampler: Time sampler with a ``(rng, n_samples) -> time`` signature. 
+ kwargs: Keyword arguments for + :meth:`~ott.neural.flow_models.models.VelocityField.create_train_state`. """ - # TODO(michalk8): in the future, `input_dim`, `optimizer` and `rng` will be - # in a separate function def __init__( self, velocity_field: models.VelocityField, @@ -60,7 +58,9 @@ def __init__( self.time_sampler = time_sampler self.match_fn = match_fn - self.vf_state = self.vf.create_train_state(**kwargs) + self.vf_state = self.vf.create_train_state( + input_dim=self.vf.output_dim, **kwargs + ) self.step_fn = self._get_step_fn() def _get_step_fn(self) -> Callable: @@ -97,7 +97,6 @@ def loss_fn( return step_fn - # TODO(michalk8): refactor in the future PR to just do one step def __call__( # noqa: D102 self, loader: Iterable[Dict[str, np.ndarray]], diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 8e0b4aff7..7f08c55dd 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -25,7 +25,7 @@ class TestOTFlowMatching: @pytest.mark.parametrize("dl", ["lin_dl", "lin_cond_dl"]) - def test_fm(self, rng: jax.Array, dl: str, request): + def test_otfm(self, rng: jax.Array, dl: str, request): dl = request.getfixturevalue(dl) dim, cond_dim = dl.lin_dim, dl.cond_dim @@ -40,7 +40,6 @@ def test_fm(self, rng: jax.Array, dl: str, request): match_fn=jax.jit(utils.match_linear), rng=rng, optimizer=optax.adam(learning_rate=1e-3), - input_dim=dim, condition_dim=cond_dim, ) From 3b69c0f1ac56d9fdea2c015317d0ce422bb7d561 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 27 Mar 2024 00:22:51 +0100 Subject: [PATCH 163/186] Remove conditional loader --- docs/neural/data.rst | 2 +- src/ott/neural/data/datasets.py | 49 ++------------------------------- 2 files changed, 4 insertions(+), 47 deletions(-) diff --git a/docs/neural/data.rst b/docs/neural/data.rst index 25172dcd3..79a602746 100644 --- a/docs/neural/data.rst +++ b/docs/neural/data.rst @@ -11,5 +11,5 @@ Datasets .. 
autosummary:: :toctree: _autosummary + datasets.OTData datasets.OTDataset - datasets.ConditionalLoader diff --git a/src/ott/neural/data/datasets.py b/src/ott/neural/data/datasets.py index 28c49a5af..c5661aecc 100644 --- a/src/ott/neural/data/datasets.py +++ b/src/ott/neural/data/datasets.py @@ -13,11 +13,11 @@ # limitations under the License. import collections import dataclasses -from typing import Any, Dict, Iterable, Optional, Sequence +from typing import Any, Dict, Optional, Sequence import numpy as np -__all__ = ["OTData", "OTDataset", "ConditionalLoader"] +__all__ = ["OTData", "OTDataset"] Item_t = Dict[str, np.ndarray] @@ -47,7 +47,7 @@ def __len__(self) -> int: class OTDataset: - """Dataset for (conditional) optimal transport problems. + """Dataset for optimal transport problems. Args: src_data: Samples from the source distribution. @@ -118,46 +118,3 @@ def __getitem__(self, ix: int) -> Item_t: def __len__(self) -> int: return len(self.src_data) - - -class ConditionalLoader: - """Dataset for OT problems with conditions. - - This data loader wraps several data loaders and samples from them. - - Args: - datasets: Datasets to sample from. - seed: Random seed. 
- """ - - def __init__( - self, - datasets: Iterable[OTDataset], - seed: Optional[int] = None, - ): - self.datasets = tuple(datasets) - self._rng = np.random.default_rng(seed) - self._iterators = [] - self._it = 0 - - def __next__(self) -> Item_t: - if self._it == len(self): - raise StopIteration - self._it += 1 - - ix = self._rng.choice(len(self._iterators)) - iterator = self._iterators[ix] - try: - return next(iterator) - except StopIteration: - # reset the consumed iterator and return it's first element - self._iterators[ix] = iterator = iter(self.datasets[ix]) - return next(iterator) - - def __iter__(self) -> "ConditionalLoader": - self._it = 0 - self._iterators = [iter(ds) for ds in self.datasets] - return self - - def __len__(self) -> int: - return max((len(ds) for ds in self.datasets), default=0) From 0ff3ad64dd3cc50502124210b10ea8b956b0722b Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 27 Mar 2024 00:28:47 +0100 Subject: [PATCH 164/186] Fix link in the docs --- docs/neural/data.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/neural/data.rst b/docs/neural/data.rst index 79a602746..3ae0e4c53 100644 --- a/docs/neural/data.rst +++ b/docs/neural/data.rst @@ -3,7 +3,7 @@ ott.neural.data .. module:: ott.neural.data .. currentmodule:: ott.neural.data -The :mod:`ott.problems.data` contains data sets and data loaders needed +The :mod:`ott.neural.data` contains data sets and data loaders needed for solving (conditional) neural optimal transport problems. 
Datasets From c3ce78649f0ad550bfa964d52e68f66f9ae8bb8a Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 27 Mar 2024 00:42:28 +0100 Subject: [PATCH 165/186] Improve VF --- src/ott/neural/flow_models/models.py | 48 +++++++++++----------------- src/ott/neural/flow_models/otfm.py | 2 +- tests/neural/otfm_test.py | 2 +- 3 files changed, 21 insertions(+), 31 deletions(-) diff --git a/src/ott/neural/flow_models/models.py b/src/ott/neural/flow_models/models.py index cd182d20e..a1b8a5291 100644 --- a/src/ott/neural/flow_models/models.py +++ b/src/ott/neural/flow_models/models.py @@ -26,7 +26,7 @@ class VelocityField(nn.Module): - r"""Parameterized neural vector field. + r"""Neural vector field. This class learns a map :math:`v: \mathbb{R}\times \mathbb{R}^d \rightarrow \mathbb{R}^d` solving the ODE :math:`\frac{dx}{dt} = v(t, x)`. @@ -36,16 +36,16 @@ class VelocityField(nn.Module): from :math:`t=t_0` to :math:`t=t_1`. Args: + output_dims: TODO. hidden_dims: Dimensionality of the embedding of the data. condition_dims: Dimensionality of the embedding of the condition. If :obj:`None`, the velocity field has no conditions. time_dims: Dimensionality of the time embedding. - If :obj:`None`, ``hidden_dims`` will be used. - time_encoder: Function to encode the time input to the time-dependent - velocity field. + If :obj:`None`, ``hidden_dims`` is used. + time_encoder: Time encoder for the velocity field. act_fn: Activation function. """ - output_dim: int + output_dims: Sequence[int] hidden_dims: Sequence[int] = (128, 128, 128) condition_dims: Optional[Sequence[int]] = None time_dims: Optional[Sequence[int]] = None @@ -70,33 +70,27 @@ def __call__( Returns: Output of the neural vector field of shape ``[batch, output_dim]``. 
""" - if self.condition_dims is None: - cond_dims = [None] * len(self.hidden_dims) - else: - cond_dims = self.condition_dims time_dims = self.hidden_dims if self.time_dims is None else self.time_dims - assert len(self.hidden_dims) == len(cond_dims), "TODO" - assert len(self.hidden_dims) == len(time_dims), "TODO" - t = self.time_encoder(t) - for time_dim, cond_dim, hidden_dim in zip( - time_dims, cond_dims, self.hidden_dims - ): + for time_dim in time_dims: t = self.act_fn(nn.Dense(time_dim)(t)) + + for hidden_dim in self.hidden_dims: x = self.act_fn(nn.Dense(hidden_dim)(x)) - if self.condition_dims is not None: - assert condition is not None, "No condition was specified." - condition = self.act_fn(nn.Dense(cond_dim)(condition)) - feats = [t, x] + ([] if self.condition_dims is None else [condition]) - feats = jnp.concatenate(feats, axis=-1) - joint_dim = feats.shape[-1] + if self.condition_dims is not None: + assert condition is not None, "No condition was passed." + for cond_dim in self.condition_dims: + condition = self.act_fn(nn.Dense(cond_dim)(condition)) + feats = jnp.concatenate([t, x, condition], axis=-1) + else: + feats = jnp.concatenate([t, x], axis=-1) - for _ in range(len(self.hidden_dims)): - feats = self.act_fn(nn.Dense(joint_dim)(feats)) + for output_dim in self.output_dims: + feats = self.act_fn(nn.Dense(output_dim)(feats)) - return nn.Dense(self.output_dim)(feats) + return feats def create_train_state( self, @@ -117,11 +111,7 @@ def create_train_state( The training state. 
""" t, x = jnp.ones((1, 1)), jnp.ones((1, input_dim)) - if self.condition_dims is not None: - assert condition_dim is not None, "TODO" - cond = jnp.ones((1, condition_dim)) - else: - cond = None + cond = None if self.condition_dims is None else jnp.ones((1, condition_dim)) params = self.init(rng, t, x, cond)["params"] return train_state.TrainState.create( diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/flow_models/otfm.py index 2c6be25a4..f6ccd6e1b 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/flow_models/otfm.py @@ -59,7 +59,7 @@ def __init__( self.match_fn = match_fn self.vf_state = self.vf.create_train_state( - input_dim=self.vf.output_dim, **kwargs + input_dim=self.vf.output_dims[-1], **kwargs ) self.step_fn = self._get_step_fn() diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 7f08c55dd..8d746dd88 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -30,7 +30,7 @@ def test_otfm(self, rng: jax.Array, dl: str, request): dim, cond_dim = dl.lin_dim, dl.cond_dim neural_vf = models.VelocityField( - dim, + output_dims=[7, dim], hidden_dims=[5, 5, 5], condition_dims=None if cond_dim is None else [4, 3, 2], ) From 161dd4a498f4c525b9eb844edf0ce1acb469f479 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 27 Mar 2024 00:42:56 +0100 Subject: [PATCH 166/186] Fix GENOT test --- tests/neural/genot_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 086cc82ea..0005d56ba 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -57,7 +57,7 @@ def test_genot(self, rng: jax.Array, dl: str, request): cond_dim = dl.cond_dim vf = models.VelocityField( - tgt_dim, + output_dims=[15, tgt_dim], hidden_dims=[7, 7, 7], condition_dims=None if cond_dim is None else [1, 3, 2], ) From 69c3a4d474bd1ba93a879af39498c5c2d381f855 Mon Sep 17 00:00:00 2001 From: 
Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 27 Mar 2024 00:50:20 +0100 Subject: [PATCH 167/186] Polish docs --- src/ott/neural/flow_models/models.py | 4 ++-- src/ott/neural/flow_models/utils.py | 19 +++++++++---------- tests/neural/genot_test.py | 2 +- tests/neural/otfm_test.py | 2 +- 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/ott/neural/flow_models/models.py b/src/ott/neural/flow_models/models.py index a1b8a5291..a770b1fdd 100644 --- a/src/ott/neural/flow_models/models.py +++ b/src/ott/neural/flow_models/models.py @@ -36,8 +36,8 @@ class VelocityField(nn.Module): from :math:`t=t_0` to :math:`t=t_1`. Args: - output_dims: TODO. hidden_dims: Dimensionality of the embedding of the data. + output_dims: Dimensionality of the embedding of the output. condition_dims: Dimensionality of the embedding of the condition. If :obj:`None`, the velocity field has no conditions. time_dims: Dimensionality of the time embedding. @@ -45,8 +45,8 @@ class VelocityField(nn.Module): time_encoder: Time encoder for the velocity field. act_fn: Activation function. """ + hidden_dims: Sequence[int] output_dims: Sequence[int] - hidden_dims: Sequence[int] = (128, 128, 128) condition_dims: Optional[Sequence[int]] = None time_dims: Optional[Sequence[int]] = None time_encoder: Callable[[jnp.ndarray], diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py index f85ef159a..dfbbe5c76 100644 --- a/src/ott/neural/flow_models/utils.py +++ b/src/ott/neural/flow_models/utils.py @@ -64,7 +64,6 @@ def match_quadratic( yy: jnp.ndarray, x: Optional[jnp.ndarray] = None, y: Optional[jnp.ndarray] = None, - # TODO(michalk8): expose for all the costs scale_cost: ScaleCost_t = 1.0, cost_fn: Optional[costs.CostFn] = None, **kwargs: Any @@ -98,14 +97,14 @@ def match_quadratic( def sample_joint(rng: jax.Array, tmat: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Sample from a transport matrix. 
+ """Sample jointly from a transport matrix. Args: rng: Random number generator. - tmat: Transport matrix. + tmat: Transport matrix of shape ``[n, m]``. Returns: - Source and target indices sampled from the transport matrix. + Source and target indices of shape ``[n,]`` and ``[m,]``, respectively. """ n, m = tmat.shape tmat_flattened = tmat.flatten() @@ -124,18 +123,18 @@ def sample_conditional( k: int = 1, uniform_marginals: bool = False, ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Sample indices from a transport matrix. + """Sample conditionally from a transport matrix. Args: rng: Random number generator. - tmat: Transport matrix. + tmat: Transport matrix of shape ``[n, m]``. k: Expected number of samples to sample per source. uniform_marginals: If :obj:`True`, sample exactly `k` samples per row, otherwise sample proportionally to the sums of the rows of the transport matrix. Returns: - Source and target indices sampled from the transport matrix. + Source and target indices of shape ``[n, k]`` and ``[m, k]``, respectively. """ assert k > 0, "Number of samples per source must be positive." n, m = tmat.shape @@ -195,11 +194,11 @@ def uniform_sampler( num_samples: Number of samples to generate. low: Lower bound of the uniform distribution. high: Upper bound of the uniform distribution. - offset: Offset of the uniform distribution. If :obj:`None`, no offset is - used. + offset: Offset of the uniform distribution. + If :obj:`None`, no offset is used. Returns: - An array with `num_samples` samples of the time :math:`t`. + An array of shape ``[num_samples, 1]``. 
""" if offset is None: return jax.random.uniform(rng, (num_samples, 1), minval=low, maxval=high) diff --git a/tests/neural/genot_test.py b/tests/neural/genot_test.py index 0005d56ba..c37d91563 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/genot_test.py @@ -57,8 +57,8 @@ def test_genot(self, rng: jax.Array, dl: str, request): cond_dim = dl.cond_dim vf = models.VelocityField( - output_dims=[15, tgt_dim], hidden_dims=[7, 7, 7], + output_dims=[15, tgt_dim], condition_dims=None if cond_dim is None else [1, 3, 2], ) model = genot.GENOT( diff --git a/tests/neural/otfm_test.py b/tests/neural/otfm_test.py index 8d746dd88..00619dc57 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/otfm_test.py @@ -30,8 +30,8 @@ def test_otfm(self, rng: jax.Array, dl: str, request): dim, cond_dim = dl.lin_dim, dl.cond_dim neural_vf = models.VelocityField( - output_dims=[7, dim], hidden_dims=[5, 5, 5], + output_dims=[7, dim], condition_dims=None if cond_dim is None else [4, 3, 2], ) fm = otfm.OTFlowMatching( From 65f2ab3d6e144e4dd60451308f9090eb7df8ce1e Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 27 Mar 2024 17:20:11 +0100 Subject: [PATCH 168/186] Remove `uniform_marginals` argument --- src/ott/neural/flow_models/genot.py | 1 - src/ott/neural/flow_models/utils.py | 25 +++++++++---------------- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index 08e011c04..3f2929e71 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -197,7 +197,6 @@ def prepare_data( rng_resample, tmat, k=self.n_samples_per_src, - uniform_marginals=True, # TODO(michalk8): expose ) src, tgt = src[src_ixs], tgt[tgt_ixs] # (n, k, ...), # (m, k, ...) 
diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/neural/flow_models/utils.py index dfbbe5c76..6c67e45f0 100644 --- a/src/ott/neural/flow_models/utils.py +++ b/src/ott/neural/flow_models/utils.py @@ -121,17 +121,13 @@ def sample_conditional( tmat: jnp.ndarray, *, k: int = 1, - uniform_marginals: bool = False, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Sample conditionally from a transport matrix. Args: rng: Random number generator. tmat: Transport matrix of shape ``[n, m]``. - k: Expected number of samples to sample per source. - uniform_marginals: If :obj:`True`, sample exactly `k` samples - per row, otherwise sample proportionally to the sums of the - rows of the transport matrix. + k: Expected number of samples to sample per source sample. Returns: Source and target indices of shape ``[n, k]`` and ``[m, k]``, respectively. @@ -139,19 +135,16 @@ def sample_conditional( assert k > 0, "Number of samples per source must be positive." n, m = tmat.shape - if uniform_marginals: - indices = jnp.arange(n) - else: - src_marginals = tmat.sum(axis=1) - rng, rng_ixs = jax.random.split(rng, 2) - indices = jax.random.choice( - rng_ixs, a=n, p=src_marginals, shape=(len(src_marginals),) - ) - tmat = tmat[indices] + src_marginals = tmat.sum(axis=1) + rng, rng_ixs = jax.random.split(rng, 2) + indices = jax.random.choice(rng_ixs, a=n, p=src_marginals, shape=(n,)) + tmat = tmat[indices] + rngs = jax.random.split(rng, n) tgt_ixs = jax.vmap( - lambda row: jax.random.choice(rng, a=m, p=row, shape=(k,)) - )(tmat) # (m, k) + lambda rng, row: jax.random.choice(rng, a=m, p=row, shape=(k,)), + in_axes=[0, 0], + )(rngs, tmat) # (m, k) src_ixs = jnp.repeat(indices[:, None], k, axis=1) # (n, k) return src_ixs, tgt_ixs From ba64056baeca5e3dc68395f0f6ff7a66037ceae4 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 27 Mar 2024 17:28:35 +0100 Subject: [PATCH 169/186] Fix undefined variable --- src/ott/neural/flow_models/genot.py | 6 +++--- 1 
file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index 3f2929e71..712d26ebb 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -120,7 +120,7 @@ def loss_fn( params: jnp.ndarray, time: jnp.ndarray, source: jnp.ndarray, target: jnp.ndarray, latent: jnp.ndarray, source_conditions: Optional[jnp.ndarray], rng: jax.Array - ): + ) -> jnp.ndarray: x_t = self.flow.compute_xt(rng, time, latent, target) if source_conditions is None: cond = source @@ -132,7 +132,7 @@ def loss_fn( return jnp.mean((v_t - u_t) ** 2) - grad_fn = jax.value_and_grad(loss_fn, has_aux=False) + grad_fn = jax.value_and_grad(loss_fn) loss, grads = grad_fn( vf_state.params, time, source, target, latent, source_conditions, rng ) @@ -244,7 +244,7 @@ def resample( in_axes, out_axes = (0, 1, cond_axis, 1, 1), (1, cond_axis, 1) resample_fn = jax.jit(jax.vmap(resample, in_axes, out_axes)) - rngs = jax.random.split(rng, self.k_samples_per_x) + rngs = jax.random.split(rng, self.n_samples_per_src) return resample_fn(rngs, src, src_cond, tgt, latent) def transport( From 80d292413cb1a6c1b731de0e8e7e930628574350 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 27 Mar 2024 17:30:14 +0100 Subject: [PATCH 170/186] Update `GENOT.transport` docs --- src/ott/neural/flow_models/genot.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/flow_models/genot.py index 712d26ebb..a6fb30c91 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/flow_models/genot.py @@ -258,21 +258,19 @@ def transport( ) -> jnp.ndarray: """Transport data with the learned plan. 
- This method pushes-forward the `source` to its conditional distribution by - solving the neural ODE parameterized by the - :attr:`~ott.neural.flows.genot.velocity_field` + This function pushes forward the source distribution to its conditional + distribution by solving the neural ODE. Args: source: Data to transport. condition: Condition of the input data. t0: Starting time of integration of neural ODE. t1: End time of integration of neural ODE. - rng: random seed for sampling from the latent distribution. - kwargs: Keyword arguments for the ODE solver. + rng: Random generate used to sample from the latent distribution. + kwargs: Keyword arguments for :func:`~diffrax.odesolve`. Returns: - The push-forward or pull-back distribution defined by the learned - transport plan. + The push-forward defined by the learned transport plan. """ def vf(t: jnp.ndarray, x: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray: From e4aae7f5177e2c886e45ad947aafcc2fff0780b5 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 27 Mar 2024 17:34:27 +0100 Subject: [PATCH 171/186] Add `diffrax` to `conf.py` --- docs/conf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 5832efbda..571fc0cfd 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -63,13 +63,13 @@ "python": ("https://docs.python.org/3", None), "numpy": ("https://numpy.org/doc/stable/", None), "jax": ("https://jax.readthedocs.io/en/latest/", None), + "jaxopt": ("https://jaxopt.github.io/stable", None), "lineax": ("https://docs.kidger.site/lineax/", None), "flax": ("https://flax.readthedocs.io/en/latest/", None), - "scikit-sparse": ("https://scikit-sparse.readthedocs.io/en/latest/", None), + "optax": ("https://optax.readthedocs.io/en/latest/", None), + "diffrax": ("https://docs.kidger.site/diffrax/", None), "scipy": ("https://docs.scipy.org/doc/scipy/", None), "pot": ("https://pythonot.github.io/", None), - "jaxopt": 
("https://jaxopt.github.io/stable", None), - "optax": ("https://optax.readthedocs.io/en/latest/", None), "matplotlib": ("https://matplotlib.org/stable/", None), } From 1d96fac1c7ceb683d52ba450c65375a1ddbc68a1 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 29 Mar 2024 19:05:56 +0100 Subject: [PATCH 172/186] Restructure files --- src/ott/__init__.py | 1 - src/ott/initializers/__init__.py | 7 + .../data => initializers/neural}/__init__.py | 2 +- .../neural/meta_initializer.py} | 195 +----------------- src/ott/neural/__init__.py | 2 +- src/ott/neural/{data => }/datasets.py | 0 src/ott/neural/gaps/monge_gap.py | 148 ------------- src/ott/neural/{gaps => methods}/__init__.py | 2 +- .../flows}/__init__.py | 2 +- .../flows.py => methods/flows/dynamics.py} | 6 +- .../{flow_models => methods/flows}/genot.py | 23 ++- .../{flow_models => methods/flows}/otfm.py | 17 +- .../map_estimator.py => methods/monge_gap.py} | 139 ++++++++++++- .../neural/{duality => methods}/neuraldual.py | 158 ++------------ src/ott/neural/networks/__init__.py | 14 ++ src/ott/neural/networks/icnn.py | 160 ++++++++++++++ .../{duality => networks/layers}/__init__.py | 2 +- .../{duality => networks/layers}/conjugate.py | 0 .../layers.py => networks/layers/posdef.py} | 3 +- .../neural/networks/layers/time_encoder.py | 34 +++ src/ott/neural/networks/potentials.py | 185 +++++++++++++++++ .../models.py => networks/velocity_field.py} | 6 +- src/ott/solvers/__init__.py | 2 +- .../{neural/flow_models => solvers}/utils.py | 19 -- tests/__init__.py | 0 tests/geometry/graph_test.py | 6 +- .../neural/meta_initializer_test.py | 6 +- tests/neural/__init__.py | 15 +- tests/neural/conftest.py | 2 +- tests/neural/map_estimator_test.py | 88 -------- tests/neural/{ => methods}/genot_test.py | 16 +- .../monge_gap_test.py} | 77 ++++++- tests/neural/{ => methods}/neuraldual_test.py | 22 +- tests/neural/{ => methods}/otfm_test.py | 12 +- tests/neural/{ => 
networks}/icnn_test.py | 6 +- tests/tools/plot_test.py | 5 +- 36 files changed, 721 insertions(+), 661 deletions(-) rename src/ott/{neural/data => initializers/neural}/__init__.py (94%) rename src/ott/{neural/duality/models.py => initializers/neural/meta_initializer.py} (51%) rename src/ott/neural/{data => }/datasets.py (100%) delete mode 100644 src/ott/neural/gaps/monge_gap.py rename src/ott/neural/{gaps => methods}/__init__.py (93%) rename src/ott/neural/{flow_models => methods/flows}/__init__.py (92%) rename src/ott/neural/{flow_models/flows.py => methods/flows/dynamics.py} (98%) rename src/ott/neural/{flow_models => methods/flows}/genot.py (94%) rename src/ott/neural/{flow_models => methods/flows}/otfm.py (93%) rename src/ott/neural/{gaps/map_estimator.py => methods/monge_gap.py} (63%) rename src/ott/neural/{duality => methods}/neuraldual.py (80%) create mode 100644 src/ott/neural/networks/__init__.py create mode 100644 src/ott/neural/networks/icnn.py rename src/ott/neural/{duality => networks/layers}/__init__.py (91%) rename src/ott/neural/{duality => networks/layers}/conjugate.py (100%) rename src/ott/neural/{duality/layers.py => networks/layers/posdef.py} (99%) create mode 100644 src/ott/neural/networks/layers/time_encoder.py create mode 100644 src/ott/neural/networks/potentials.py rename src/ott/neural/{flow_models/models.py => networks/velocity_field.py} (96%) rename src/ott/{neural/flow_models => solvers}/utils.py (90%) create mode 100644 tests/__init__.py rename tests/{ => initializers}/neural/meta_initializer_test.py (95%) delete mode 100644 tests/neural/map_estimator_test.py rename tests/neural/{ => methods}/genot_test.py (83%) rename tests/neural/{losses_test.py => methods/monge_gap_test.py} (64%) rename tests/neural/{ => methods}/neuraldual_test.py (86%) rename tests/neural/{ => methods}/otfm_test.py (85%) rename tests/neural/{ => networks}/icnn_test.py (93%) diff --git a/src/ott/__init__.py b/src/ott/__init__.py index dac0eb854..c40402511 100644 --- 
a/src/ott/__init__.py +++ b/src/ott/__init__.py @@ -25,7 +25,6 @@ ) with contextlib.suppress(ImportError): - # TODO(michalk8): add warning that neural module is not imported from . import neural from ._version import __version__ diff --git a/src/ott/initializers/__init__.py b/src/ott/initializers/__init__.py index 5406247dc..0fad8c3ff 100644 --- a/src/ott/initializers/__init__.py +++ b/src/ott/initializers/__init__.py @@ -11,4 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib + from . import linear, quadratic + +with contextlib.suppress(ImportError): + from . import neural + +del contextlib diff --git a/src/ott/neural/data/__init__.py b/src/ott/initializers/neural/__init__.py similarity index 94% rename from src/ott/neural/data/__init__.py rename to src/ott/initializers/neural/__init__.py index 785604b21..77e74d166 100644 --- a/src/ott/neural/data/__init__.py +++ b/src/ott/initializers/neural/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import datasets +from . import meta_initializer diff --git a/src/ott/neural/duality/models.py b/src/ott/initializers/neural/meta_initializer.py similarity index 51% rename from src/ott/neural/duality/models.py rename to src/ott/initializers/neural/meta_initializer.py index b3ce94c35..be1f87909 100644 --- a/src/ott/neural/duality/models.py +++ b/src/ott/initializers/neural/meta_initializer.py @@ -12,206 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import functools -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple import jax import jax.numpy as jnp -import flax.linen as nn import optax +from flax import linen as nn from flax.core import frozen_dict from flax.training import train_state from ott import utils from ott.geometry import geometry -from ott.initializers.linear import initializers as lin_init -from ott.neural.duality import layers, neuraldual +from ott.initializers.linear import initializers from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn -__all__ = ["ICNN", "PotentialMLP", "MetaInitializer"] - -# wrap to silence docs linter -DEFAULT_KERNEL_INIT = lambda *a, **k: nn.initializers.normal()(*a, **k) -DEFAULT_RECTIFIER = nn.activation.relu -DEFAULT_ACTIVATION = nn.activation.relu - - -class ICNN(neuraldual.BaseW2NeuralDual): - """Input convex neural network (ICNN). - - Implementation of input convex neural networks as introduced in - :cite:`amos:17` with initialization schemes proposed by :cite:`bunne:22`. - - Args: - dim_data: data dimensionality. - dim_hidden: sequence specifying size of hidden dimensions. The - output dimension of the last layer is 1 by default. - ranks: ranks of the matrices :math:`A_i` used as low-rank factors - for the quadratic potentials. If a sequence is passed, it must contain - ``len(dim_hidden) + 2`` elements, where the last 2 elements correspond - to the ranks of the final layer with dimension 1 and the potentials, - respectively. - init_fn: Initializer for the kernel weight matrices. - The default is :func:`~flax.linen.initializers.normal`. - act_fn: choice of activation function used in network architecture, - needs to be convex. The default is :func:`~flax.linen.activation.relu`. - pos_weights: Enforce positive weights with a projection. - If :obj:`False`, the positive weights should be enforced with clipping - or regularization in the loss. 
- rectifier_fn: function to ensure the non negativity of the weights. - The default is :func:`~flax.linen.activation.relu`. - gaussian_map_samples: Tuple of source and target points, used to initialize - the ICNN to mimic the linear Bures map that morphs the (Gaussian - approximation) of the input measure to that of the target measure. If - :obj:`None`, the identity initialization is used, and ICNN mimics half the - squared Euclidean norm. - """ - - dim_data: int - dim_hidden: Sequence[int] - ranks: Union[int, Tuple[int, ...]] = 1 - init_fn: Callable[[jax.Array, Tuple[int, ...], Any], - jnp.ndarray] = DEFAULT_KERNEL_INIT - act_fn: Callable[[jnp.ndarray], jnp.ndarray] = DEFAULT_ACTIVATION - pos_weights: bool = False - rectifier_fn: Callable[[jnp.ndarray], jnp.ndarray] = DEFAULT_RECTIFIER - gaussian_map_samples: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None - - def setup(self) -> None: # noqa: D102 - dim_hidden = list(self.dim_hidden) + [1] - *ranks, pos_def_rank = self._normalize_ranks() - - # final layer computes average, still with normalized rescaling - self.w_zs = [self._get_wz(dim) for dim in dim_hidden[1:]] - # subsequent layers re-injected into convex functions - self.w_xs = [ - self._get_wx(dim, rank) for dim, rank in zip(dim_hidden, ranks) - ] - self.pos_def_potentials = self._get_pos_def_potentials(pos_def_rank) - - @nn.compact - def __call__(self, x: jnp.ndarray) -> float: # noqa: D102 - w_x, *w_xs = self.w_xs - assert len(self.w_zs) == len(w_xs), (len(self.w_zs), len(w_xs)) - - z = self.act_fn(w_x(x)) - for w_z, w_x in zip(self.w_zs, w_xs): - z = self.act_fn(w_z(z) + w_x(x)) - z = z + self.pos_def_potentials(x) - - return z.squeeze() - - def _get_wz(self, dim: int) -> nn.Module: - if self.pos_weights: - return layers.PositiveDense( - dim, - kernel_init=self.init_fn, - use_bias=False, - rectifier_fn=self.rectifier_fn, - ) - - return nn.Dense( - dim, - kernel_init=self.init_fn, - use_bias=False, - ) - - def _get_wx(self, dim: int, rank: int) -> 
nn.Module: - return layers.PosDefPotentials( - rank=rank, - num_potentials=dim, - use_linear=True, - use_bias=True, - kernel_diag_init=nn.initializers.zeros, - kernel_lr_init=self.init_fn, - kernel_linear_init=self.init_fn, - bias_init=nn.initializers.zeros, - ) - - def _get_pos_def_potentials(self, rank: int) -> layers.PosDefPotentials: - kwargs = { - "num_potentials": 1, - "use_linear": True, - "use_bias": True, - "bias_init": nn.initializers.zeros - } - - if self.gaussian_map_samples is None: - return layers.PosDefPotentials( - rank=rank, - kernel_diag_init=nn.initializers.ones, - kernel_lr_init=nn.initializers.zeros, - kernel_linear_init=nn.initializers.zeros, - **kwargs, - ) - - source, target = self.gaussian_map_samples - return layers.PosDefPotentials.init_from_samples( - source, - target, - rank=self.dim_data, - kernel_diag_init=nn.initializers.zeros, - **kwargs, - ) - - def _normalize_ranks(self) -> Tuple[int, ...]: - # +2 for the newly added layer with 1 + the final potentials - n_ranks = len(self.dim_hidden) + 2 - if isinstance(self.ranks, int): - return (self.ranks,) * n_ranks - - assert len(self.ranks) == n_ranks, (len(self.ranks), n_ranks) - return tuple(self.ranks) - - @property - def is_potential(self) -> bool: # noqa: D102 - return True - - -class PotentialMLP(neuraldual.BaseW2NeuralDual): - """A generic, not-convex MLP. - - Args: - dim_hidden: sequence specifying size of hidden dimensions. 
The output - dimension of the last layer is automatically set to 1 if - :attr:`is_potential` is ``True``, or the dimension of the input otherwise - is_potential: Model the potential if ``True``, otherwise - model the gradient of the potential - act_fn: Activation function - """ - - dim_hidden: Sequence[int] - is_potential: bool = True - act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.leaky_relu - - @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # noqa: D102 - squeeze = x.ndim == 1 - if squeeze: - x = jnp.expand_dims(x, 0) - assert x.ndim == 2, x.ndim - n_input = x.shape[-1] - - z = x - for n_hidden in self.dim_hidden: - Wx = nn.Dense(n_hidden, use_bias=True) - z = self.act_fn(Wx(z)) - - if self.is_potential: - Wx = nn.Dense(1, use_bias=True) - z = Wx(z).squeeze(-1) - - quad_term = 0.5 * jax.vmap(jnp.dot)(x, x) - z += quad_term - else: - Wx = nn.Dense(n_input, use_bias=True) - z = x + Wx(z) - - return z.squeeze(0) if squeeze else z +__all__ = ["MetaInitializer"] @jax.tree_util.register_pytree_node_class -class MetaInitializer(lin_init.DefaultInitializer): +class MetaInitializer(initializers.DefaultInitializer): """Meta OT Initializer with a fixed geometry :cite:`amos:22`. 
This initializer consists of a predictive model that outputs the @@ -314,7 +135,7 @@ def update( def init_dual_a( # noqa: D102 self, - ot_prob: "linear_problem.LinearProblem", + ot_prob: linear_problem.LinearProblem, lse_mode: bool, rng: Optional[jax.Array] = None, ) -> jnp.ndarray: @@ -337,8 +158,6 @@ def init_dual_a( # noqa: D102 def _get_update_fn(self): """Return the implementation (and jitted) update function.""" - from ott.problems.linear import linear_problem - from ott.solvers.linear import sinkhorn def dual_obj_loss_single(params, a, b): f_pred = self._compute_f(a, b, params) diff --git a/src/ott/neural/__init__.py b/src/ott/neural/__init__.py index 10dac222c..3af88e56b 100644 --- a/src/ott/neural/__init__.py +++ b/src/ott/neural/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import data, duality, flow_models, gaps +from . import datasets, methods, networks diff --git a/src/ott/neural/data/datasets.py b/src/ott/neural/datasets.py similarity index 100% rename from src/ott/neural/data/datasets.py rename to src/ott/neural/datasets.py diff --git a/src/ott/neural/gaps/monge_gap.py b/src/ott/neural/gaps/monge_gap.py deleted file mode 100644 index f6136bf07..000000000 --- a/src/ott/neural/gaps/monge_gap.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Callable, Literal, Optional, Tuple, Union - -import jax -import jax.numpy as jnp - -from ott.geometry import costs, pointcloud -from ott.solvers import linear -from ott.solvers.linear import sinkhorn - -__all__ = ["monge_gap", "monge_gap_from_samples"] - - -def monge_gap( - map_fn: Callable[[jnp.ndarray], jnp.ndarray], - reference_points: jnp.ndarray, - cost_fn: Optional[costs.CostFn] = None, - epsilon: Optional[float] = None, - relative_epsilon: Optional[bool] = None, - scale_cost: Union[int, float, Literal["mean", "max_cost", "median"]] = 1.0, - return_output: bool = False, - **kwargs: Any -) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]: - r"""Monge gap regularizer :cite:`uscidda:23`. - - For a cost function :math:`c` and empirical reference measure - :math:`\hat{\rho}_n=\frac{1}{n}\sum_{i=1}^n \delta_{x_i}`, the - (entropic) Monge gap of a map function - :math:`T:\mathbb{R}^d\rightarrow\mathbb{R}^d` is defined as: - - .. math:: - \mathcal{M}^c_{\hat{\rho}_n, \varepsilon} (T) - = \frac{1}{n} \sum_{i=1}^n c(x_i, T(x_i)) - - W_{c, \varepsilon}(\hat{\rho}_n, T \sharp \hat{\rho}_n) - - See :cite:`uscidda:23` Eq. (8). This function is a thin wrapper that calls - :func:`~ott.neural.losses.monge_gap_from_samples`. - - Args: - map_fn: Callable corresponding to map :math:`T` in definition above. The - callable should be vectorized (e.g. using :func:`jax.vmap`), i.e, - able to process a *batch* of vectors of size `d`, namely - ``map_fn`` applied to an array returns an array of the same shape. - reference_points: Array of `[n,d]` points, :math:`\hat\rho_n` in paper - cost_fn: An object of class :class:`~ott.geometry.costs.CostFn`. - epsilon: Regularization parameter. 
See - :class:`~ott.geometry.pointcloud.PointCloud` - relative_epsilon: when `False`, the parameter ``epsilon`` specifies the - value of the entropic regularization parameter. When `True`, ``epsilon`` - refers to a fraction of the - :attr:`~ott.geometry.pointcloud.PointCloud.mean_cost_matrix`, which is - computed adaptively using ``source`` and ``target`` points. - scale_cost: option to rescale the cost matrix. Implemented scalings are - 'median', 'mean' and 'max_cost'. Alternatively, a float factor can be - given to rescale the cost such that ``cost_matrix /= scale_cost``. - return_output: boolean to also return the - :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`. - kwargs: holds the kwargs to instantiate the or - :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver to - compute the regularized OT cost. - - Returns: - The Monge gap value and optionally the - :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput` - """ - target = map_fn(reference_points) - return monge_gap_from_samples( - source=reference_points, - target=target, - cost_fn=cost_fn, - epsilon=epsilon, - relative_epsilon=relative_epsilon, - scale_cost=scale_cost, - return_output=return_output, - **kwargs - ) - - -def monge_gap_from_samples( - source: jnp.ndarray, - target: jnp.ndarray, - cost_fn: Optional[costs.CostFn] = None, - epsilon: Optional[float] = None, - relative_epsilon: Optional[bool] = None, - scale_cost: Union[int, float, Literal["mean", "max_cost", "median"]] = 1.0, - return_output: bool = False, - **kwargs: Any -) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]: - r"""Monge gap, instantiated in terms of samples before / after applying map. - - .. math:: - \frac{1}{n} \sum_{i=1}^n c(x_i, y_i)) - - W_{c, \varepsilon}(\frac{1}{n}\sum_i \delta_{x_i}, - \frac{1}{n}\sum_i \delta_{y_i}) - - where :math:`W_{c, \varepsilon}` is an entropy-regularized optimal transport - cost, the :attr:`~ott.solvers.linear.sinkhorn.SinkhornOutput.ent_reg_cost`. 
- - Args: - source: samples from first measure, array of shape ``[n, d]``. - target: samples from second measure, array of shape ``[n, d]``. - cost_fn: a cost function between two points in dimension :math:`d`. - If :obj:`None`, :class:`~ott.geometry.costs.SqEuclidean` is used. - epsilon: Regularization parameter. See - :class:`~ott.geometry.pointcloud.PointCloud` - relative_epsilon: when `False`, the parameter ``epsilon`` specifies the - value of the entropic regularization parameter. When `True`, ``epsilon`` - refers to a fraction of the - :attr:`~ott.geometry.pointcloud.PointCloud.mean_cost_matrix`, which is - computed adaptively using ``source`` and ``target`` points. - scale_cost: option to rescale the cost matrix. Implemented scalings are - 'median', 'mean' and 'max_cost'. Alternatively, a float factor can be - given to rescale the cost such that ``cost_matrix /= scale_cost``. - return_output: boolean to also return the - :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`. - kwargs: holds the kwargs to instantiate the or - :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver to - compute the regularized OT cost. 
- - Returns: - The Monge gap value and optionally the - :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput` - """ - cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn - geom = pointcloud.PointCloud( - x=source, - y=target, - cost_fn=cost_fn, - epsilon=epsilon, - relative_epsilon=relative_epsilon, - scale_cost=scale_cost, - ) - gt_displacement_cost = jnp.mean(jax.vmap(cost_fn)(source, target)) - out = linear.solve(geom=geom, **kwargs) - loss = gt_displacement_cost - out.ent_reg_cost - return (loss, out) if return_output else loss diff --git a/src/ott/neural/gaps/__init__.py b/src/ott/neural/methods/__init__.py similarity index 93% rename from src/ott/neural/gaps/__init__.py rename to src/ott/neural/methods/__init__.py index 0ba36da05..a5836f921 100644 --- a/src/ott/neural/gaps/__init__.py +++ b/src/ott/neural/methods/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import map_estimator, monge_gap +from . import monge_gap, neuraldual diff --git a/src/ott/neural/flow_models/__init__.py b/src/ott/neural/methods/flows/__init__.py similarity index 92% rename from src/ott/neural/flow_models/__init__.py rename to src/ott/neural/methods/flows/__init__.py index a6239fa06..f5bba4cc5 100644 --- a/src/ott/neural/flow_models/__init__.py +++ b/src/ott/neural/methods/flows/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import flows, genot, models, otfm, utils +from . 
import dynamics, genot, otfm diff --git a/src/ott/neural/flow_models/flows.py b/src/ott/neural/methods/flows/dynamics.py similarity index 98% rename from src/ott/neural/flow_models/flows.py rename to src/ott/neural/methods/flows/dynamics.py index 2cde34833..3ca60168c 100644 --- a/src/ott/neural/flow_models/flows.py +++ b/src/ott/neural/methods/flows/dynamics.py @@ -20,7 +20,7 @@ "BaseFlow", "StraightFlow", "ConstantNoiseFlow", - "BrownianNoiseFlow", + "BrownianBridge", ] @@ -140,8 +140,8 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: return jnp.full_like(t, fill_value=self.sigma) -class BrownianNoiseFlow(StraightFlow): - r"""Brownian Bridge Flow. +class BrownianBridge(StraightFlow): + r"""Brownian Bridge. Sampler for sampling noise implicitly defined by a Schrödinger Bridge problem with parameter :math:`\sigma` such that diff --git a/src/ott/neural/flow_models/genot.py b/src/ott/neural/methods/flows/genot.py similarity index 94% rename from src/ott/neural/flow_models/genot.py rename to src/ott/neural/methods/flows/genot.py index a6fb30c91..a3bad5902 100644 --- a/src/ott/neural/flow_models/genot.py +++ b/src/ott/neural/methods/flows/genot.py @@ -23,8 +23,9 @@ from flax.training import train_state from ott import utils -from ott.neural.flow_models import flows, models -from ott.neural.flow_models import utils as flow_utils +from ott.neural.methods.flows import dynamics +from ott.neural.networks import velocity_field +from ott.solvers import utils as solver_utils __all__ = ["GENOT"] @@ -45,7 +46,7 @@ class GENOT: the unbalanced setting. Args: - velocity_field: Vector field parameterized by a neural network. + vf: Vector field parameterized by a neural network. flow: Flow between the latent and the target distributions. 
data_match_fn: Function to match samples from the source and the target distributions with a ``(src_lin, tgt_lin, src_quad, tgt_quad) -> matching`` @@ -70,15 +71,15 @@ class GENOT: def __init__( self, - velocity_field: models.VelocityField, - flow: flows.BaseFlow, + vf: velocity_field.VelocityField, + flow: dynamics.BaseFlow, data_match_fn: DataMatchFn_t, *, source_dim: int, target_dim: int, condition_dim: Optional[int] = None, time_sampler: Callable[[jax.Array, int], - jnp.ndarray] = flow_utils.uniform_sampler, + jnp.ndarray] = solver_utils.uniform_sampler, latent_noise_fn: Optional[Callable[[jax.Array, Tuple[int, ...]], jnp.ndarray]] = None, latent_match_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], @@ -86,12 +87,12 @@ def __init__( n_samples_per_src: int = 1, **kwargs: Any, ): - self.vf = velocity_field + self.vf = vf self.flow = flow self.data_match_fn = data_match_fn self.time_sampler = time_sampler if latent_noise_fn is None: - latent_noise_fn = functools.partial(multivariate_normal, dim=target_dim) + latent_noise_fn = functools.partial(_multivariate_normal, dim=target_dim) self.latent_noise_fn = latent_noise_fn self.latent_match_fn = latent_match_fn self.n_samples_per_src = n_samples_per_src @@ -193,7 +194,7 @@ def prepare_data( latent = self.latent_noise_fn(rng_noise, (n, self.n_samples_per_src)) tmat = self.data_match_fn(*matching_data) # (n, m) - src_ixs, tgt_ixs = flow_utils.sample_conditional( # (n, k), (m, k) + src_ixs, tgt_ixs = solver_utils.sample_conditional( # (n, k), (m, k) rng_resample, tmat, k=self.n_samples_per_src, @@ -233,7 +234,7 @@ def resample( ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]: tmat = self.latent_match_fn(latent, tgt) # (n, k) - src_ixs, tgt_ixs = flow_utils.sample_joint(rng, tmat) # (n,), (m,) + src_ixs, tgt_ixs = solver_utils.sample_joint(rng, tmat) # (n,), (m,) src, tgt = src[src_ixs], tgt[tgt_ixs] if src_cond is not None: src_cond = src_cond[src_ixs] @@ -304,7 +305,7 @@ def solve_ode(x: jnp.ndarray, cond: 
jnp.ndarray) -> jnp.ndarray: return jax.jit(jax.vmap(solve_ode))(latent, source) -def multivariate_normal( +def _multivariate_normal( rng: jax.Array, shape: Tuple[int, ...], dim: int, diff --git a/src/ott/neural/flow_models/otfm.py b/src/ott/neural/methods/flows/otfm.py similarity index 93% rename from src/ott/neural/flow_models/otfm.py rename to src/ott/neural/methods/flows/otfm.py index f6ccd6e1b..ebeb138b5 100644 --- a/src/ott/neural/flow_models/otfm.py +++ b/src/ott/neural/methods/flows/otfm.py @@ -22,8 +22,9 @@ from flax.training import train_state from ott import utils -from ott.neural.flow_models import flows, models -from ott.neural.flow_models import utils as flow_utils +from ott.neural.methods.flows import dynamics +from ott.neural.networks import velocity_field +from ott.solvers import utils as solver_utils __all__ = ["OTFlowMatching"] @@ -34,7 +35,7 @@ class OTFlowMatching: With an extension to OT-FM :cite:`tong:23,pooladian:23`. Args: - velocity_field: Vector field parameterized by a neural network. + vf: Vector field parameterized by a neural network. flow: Flow between the source and the target distributions. match_fn: Function to match samples from the source and the target distributions. It has a ``(src, tgt) -> matching`` signature. 
@@ -45,15 +46,15 @@ class OTFlowMatching: def __init__( self, - velocity_field: models.VelocityField, - flow: flows.BaseFlow, + vf: velocity_field.VelocityField, + flow: dynamics.BaseFlow, match_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None, time_sampler: Callable[[jax.Array, int], - jnp.ndarray] = flow_utils.uniform_sampler, + jnp.ndarray] = solver_utils.uniform_sampler, **kwargs: Any, ): - self.vf = velocity_field + self.vf = vf self.flow = flow self.time_sampler = time_sampler self.match_fn = match_fn @@ -127,7 +128,7 @@ def __call__( # noqa: D102 if self.match_fn is not None: tmat = self.match_fn(src, tgt) - src_ixs, tgt_ixs = flow_utils.sample_joint(rng_resample, tmat) + src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat) src, tgt = src[src_ixs], tgt[tgt_ixs] src_cond = None if src_cond is None else src_cond[src_ixs] diff --git a/src/ott/neural/gaps/map_estimator.py b/src/ott/neural/methods/monge_gap.py similarity index 63% rename from src/ott/neural/gaps/map_estimator.py rename to src/ott/neural/methods/monge_gap.py index 61c24f0c3..c108a3509 100644 --- a/src/ott/neural/gaps/map_estimator.py +++ b/src/ott/neural/methods/monge_gap.py @@ -18,6 +18,7 @@ Callable, Dict, Iterator, + Literal, Optional, Sequence, Tuple, @@ -32,12 +33,140 @@ from flax.training import train_state from ott import utils -from ott.neural.duality import neuraldual +from ott.geometry import costs, pointcloud +from ott.neural.networks import potentials +from ott.solvers import linear +from ott.solvers.linear import sinkhorn + +__all__ = ["monge_gap", "monge_gap_from_samples", "MongeGapEstimator"] + + +def monge_gap( + map_fn: Callable[[jnp.ndarray], jnp.ndarray], + reference_points: jnp.ndarray, + cost_fn: Optional[costs.CostFn] = None, + epsilon: Optional[float] = None, + relative_epsilon: Optional[bool] = None, + scale_cost: Union[int, float, Literal["mean", "max_cost", "median"]] = 1.0, + return_output: bool = False, + **kwargs: Any +) -> 
+ Union[float, Tuple[float, sinkhorn.SinkhornOutput]]: + r"""Monge gap regularizer :cite:`uscidda:23`. + + For a cost function :math:`c` and empirical reference measure + :math:`\hat{\rho}_n=\frac{1}{n}\sum_{i=1}^n \delta_{x_i}`, the + (entropic) Monge gap of a map function + :math:`T:\mathbb{R}^d\rightarrow\mathbb{R}^d` is defined as: -__all__ = ["MapEstimator"] + .. math:: + \mathcal{M}^c_{\hat{\rho}_n, \varepsilon} (T) + = \frac{1}{n} \sum_{i=1}^n c(x_i, T(x_i)) - + W_{c, \varepsilon}(\hat{\rho}_n, T \sharp \hat{\rho}_n) + See :cite:`uscidda:23` Eq. (8). This function is a thin wrapper that calls + :func:`~ott.neural.methods.monge_gap.monge_gap_from_samples`. -class MapEstimator: + Args: + map_fn: Callable corresponding to map :math:`T` in definition above. The + callable should be vectorized (e.g. using :func:`jax.vmap`), i.e., + able to process a *batch* of vectors of size `d`, namely + ``map_fn`` applied to an array returns an array of the same shape. + reference_points: Array of `[n,d]` points, :math:`\hat\rho_n` in paper + cost_fn: An object of class :class:`~ott.geometry.costs.CostFn`. + epsilon: Regularization parameter. See + :class:`~ott.geometry.pointcloud.PointCloud` + relative_epsilon: when `False`, the parameter ``epsilon`` specifies the + value of the entropic regularization parameter. When `True`, ``epsilon`` + refers to a fraction of the + :attr:`~ott.geometry.pointcloud.PointCloud.mean_cost_matrix`, which is + computed adaptively using ``source`` and ``target`` points. + scale_cost: option to rescale the cost matrix. Implemented scalings are + 'median', 'mean' and 'max_cost'. Alternatively, a float factor can be + given to rescale the cost such that ``cost_matrix /= scale_cost``. + return_output: boolean to also return the + :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`. + kwargs: holds the kwargs to instantiate the + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver to + compute the regularized OT cost.
+ + Returns: + The Monge gap value and optionally the + :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput` + """ + target = map_fn(reference_points) + return monge_gap_from_samples( + source=reference_points, + target=target, + cost_fn=cost_fn, + epsilon=epsilon, + relative_epsilon=relative_epsilon, + scale_cost=scale_cost, + return_output=return_output, + **kwargs + ) + + +def monge_gap_from_samples( + source: jnp.ndarray, + target: jnp.ndarray, + cost_fn: Optional[costs.CostFn] = None, + epsilon: Optional[float] = None, + relative_epsilon: Optional[bool] = None, + scale_cost: Union[int, float, Literal["mean", "max_cost", "median"]] = 1.0, + return_output: bool = False, + **kwargs: Any +) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]: + r"""Monge gap, instantiated in terms of samples before / after applying map. + + .. math:: + \frac{1}{n} \sum_{i=1}^n c(x_i, y_i) - + W_{c, \varepsilon}(\frac{1}{n}\sum_i \delta_{x_i}, + \frac{1}{n}\sum_i \delta_{y_i}) + + where :math:`W_{c, \varepsilon}` is an entropy-regularized optimal transport + cost, the :attr:`~ott.solvers.linear.sinkhorn.SinkhornOutput.ent_reg_cost`. + + Args: + source: samples from first measure, array of shape ``[n, d]``. + target: samples from second measure, array of shape ``[n, d]``. + cost_fn: a cost function between two points in dimension :math:`d`. + If :obj:`None`, :class:`~ott.geometry.costs.SqEuclidean` is used. + epsilon: Regularization parameter. See + :class:`~ott.geometry.pointcloud.PointCloud` + relative_epsilon: when `False`, the parameter ``epsilon`` specifies the + value of the entropic regularization parameter. When `True`, ``epsilon`` + refers to a fraction of the + :attr:`~ott.geometry.pointcloud.PointCloud.mean_cost_matrix`, which is + computed adaptively using ``source`` and ``target`` points. + scale_cost: option to rescale the cost matrix. Implemented scalings are + 'median', 'mean' and 'max_cost'.
Alternatively, a float factor can be + given to rescale the cost such that ``cost_matrix /= scale_cost``. + return_output: boolean to also return the + :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`. + kwargs: holds the kwargs to instantiate the + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver to + compute the regularized OT cost. + + Returns: + The Monge gap value and optionally the + :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput` + """ + cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn + geom = pointcloud.PointCloud( + x=source, + y=target, + cost_fn=cost_fn, + epsilon=epsilon, + relative_epsilon=relative_epsilon, + scale_cost=scale_cost, + ) + gt_displacement_cost = jnp.mean(jax.vmap(cost_fn)(source, target)) + out = linear.solve(geom=geom, **kwargs) + loss = gt_displacement_cost - out.ent_reg_cost + return (loss, out) if return_output else loss + + +class MongeGapEstimator: r"""Mapping estimator between probability measures. It estimates a map :math:`T` by minimizing the loss: @@ -78,7 +207,7 @@ class MapEstimator: def __init__( self, dim_data: int, - model: neuraldual.BaseW2NeuralDual, + model: potentials.BasePotential, optimizer: Optional[optax.OptState] = None, fitting_loss: Optional[Callable[[jnp.ndarray, jnp.ndarray], Tuple[float, Optional[Any]]]] = None, @@ -114,7 +243,7 @@ def __init__( def setup( self, dim_data: int, - neural_net: neuraldual.BaseW2NeuralDual, + neural_net: potentials.BasePotential, optimizer: optax.OptState, ): """Setup all components required to train the network.""" diff --git a/src/ott/neural/duality/neuraldual.py b/src/ott/neural/methods/neuraldual.py similarity index 80% rename from src/ott/neural/duality/neuraldual.py rename to src/ott/neural/methods/neuraldual.py index c00acb76c..6845224f4 100644 --- a/src/ott/neural/duality/neuraldual.py +++ b/src/ott/neural/methods/neuraldual.py @@ -11,10 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. -import abc import warnings from typing import ( - Any, Callable, Dict, Iterator, @@ -28,138 +26,18 @@ import jax import jax.numpy as jnp -import flax.linen as nn import optax -from flax import struct -from flax.core import frozen_dict -from flax.training import train_state from ott import utils from ott.geometry import costs -from ott.neural.duality import conjugate, models -from ott.problems.linear import potentials +from ott.neural.networks import icnn, potentials +from ott.neural.networks.layers import conjugate +from ott.problems.linear import potentials as dual_potentials -__all__ = ["W2NeuralTrainState", "BaseW2NeuralDual", "W2NeuralDual"] +__all__ = ["W2NeuralDual"] Train_t = Dict[Literal["train_logs", "valid_logs"], Dict[str, List[float]]] -Callback_t = Callable[[int, potentials.DualPotentials], None] - -PotentialValueFn_t = Callable[[jnp.ndarray], jnp.ndarray] -PotentialGradientFn_t = Callable[[jnp.ndarray], jnp.ndarray] - - -class W2NeuralTrainState(train_state.TrainState): - """Adds information about the model's value and gradient to the state. - - This extends :class:`~flax.training.train_state.TrainState` to include - the potential methods from the - :class:`~ott.neural.duality.neuraldual.BaseW2NeuralDual` used during training. 
- - Args: - potential_value_fn: the potential's value function - potential_gradient_fn: the potential's gradient function - """ - potential_value_fn: Callable[ - [frozen_dict.FrozenDict[str, jnp.ndarray], Optional[PotentialValueFn_t]], - PotentialValueFn_t] = struct.field(pytree_node=False) - potential_gradient_fn: Callable[[frozen_dict.FrozenDict[str, jnp.ndarray]], - PotentialGradientFn_t] = struct.field( - pytree_node=False - ) - - -class BaseW2NeuralDual(abc.ABC, nn.Module): - """Base class for the neural solver models.""" - - @property - @abc.abstractmethod - def is_potential(self) -> bool: - """Indicates if the module implements a potential value or a vector field. - - Returns: - ``True`` if the module defines a potential, ``False`` if it defines a - vector field. - """ - - def potential_value_fn( - self, - params: frozen_dict.FrozenDict[str, jnp.ndarray], - other_potential_value_fn: Optional[PotentialValueFn_t] = None, - ) -> PotentialValueFn_t: - r"""Return a function giving the value of the potential. - - Applies the module if :attr:`is_potential` is ``True``, otherwise - constructs the value of the potential from the gradient with - - .. math:: - - g(y) = -f(\nabla_y g(y)) + y^T \nabla_y g(y) - - where :math:`\nabla_y g(y)` is detached for the envelope theorem - :cite:`danskin:67,bertsekas:71` - to give the appropriate first derivatives of this construction. - - Args: - params: parameters of the module - other_potential_value_fn: function giving the value of the other - potential. Only needed when :attr:`is_potential` is ``False``. - - Returns: - A function that can be evaluated to obtain a potential value, or a linear - interpolation of a potential. - """ - if self.is_potential: - return lambda x: self.apply({"params": params}, x) - - assert other_potential_value_fn is not None, \ - "The value of the gradient-based potential depends " \ - "on the value of the other potential." 
- - def value_fn(x: jnp.ndarray) -> jnp.ndarray: - squeeze = x.ndim == 1 - if squeeze: - x = jnp.expand_dims(x, 0) - grad_g_x = jax.lax.stop_gradient(self.apply({"params": params}, x)) - value = -other_potential_value_fn(grad_g_x) + \ - jax.vmap(jnp.dot)(grad_g_x, x) - return value.squeeze(0) if squeeze else value - - return value_fn - - def potential_gradient_fn( - self, - params: frozen_dict.FrozenDict[str, jnp.ndarray], - ) -> PotentialGradientFn_t: - """Return a function returning a vector or the gradient of the potential. - - Args: - params: parameters of the module - - Returns: - A function that can be evaluated to obtain the potential's gradient - """ - if self.is_potential: - return jax.vmap(jax.grad(self.potential_value_fn(params))) - return lambda x: self.apply({"params": params}, x) - - def create_train_state( - self, - rng: jax.Array, - optimizer: optax.OptState, - input: Union[int, Tuple[int, ...]], - **kwargs: Any, - ) -> W2NeuralTrainState: - """Create initial training state.""" - params = self.init(rng, jnp.ones(input))["params"] - - return W2NeuralTrainState.create( - apply_fn=self.apply, - params=params, - tx=optimizer, - potential_value_fn=self.potential_value_fn, - potential_gradient_fn=self.potential_gradient_fn, - **kwargs, - ) +Callback_t = Callable[[int, dual_potentials.DualPotentials], None] class W2NeuralDual: @@ -228,8 +106,8 @@ class W2NeuralDual: def __init__( self, dim_data: int, - neural_f: Optional[BaseW2NeuralDual] = None, - neural_g: Optional[BaseW2NeuralDual] = None, + neural_f: Optional[potentials.BasePotential] = None, + neural_g: Optional[potentials.BasePotential] = None, optimizer_f: Optional[optax.OptState] = None, optimizer_g: Optional[optax.OptState] = None, num_train_iters: int = 20000, @@ -266,9 +144,9 @@ def __init__( # set default neural architectures if neural_f is None: - neural_f = models.ICNN(dim_data=dim_data, dim_hidden=[64, 64, 64, 64]) + neural_f = icnn.ICNN(dim_data=dim_data, dim_hidden=[64, 64, 64, 64]) if 
neural_g is None: - neural_g = models.ICNN(dim_data=dim_data, dim_hidden=[64, 64, 64, 64]) + neural_g = icnn.ICNN(dim_data=dim_data, dim_hidden=[64, 64, 64, 64]) self.neural_f = neural_f self.neural_g = neural_g @@ -285,8 +163,8 @@ def __init__( def setup( self, rng: jax.Array, - neural_f: BaseW2NeuralDual, - neural_g: BaseW2NeuralDual, + neural_f: potentials.BasePotential, + neural_g: potentials.BasePotential, dim_data: int, optimizer_f: optax.OptState, optimizer_g: optax.OptState, @@ -301,13 +179,13 @@ def setup( f"the `W2NeuralDual` setting, with positive weights " \ f"being {self.pos_weights}." if isinstance( - neural_f, models.ICNN + neural_f, icnn.ICNN ) and neural_f.pos_weights is not self.pos_weights: warnings.warn(warn_str, stacklevel=2) neural_f.pos_weights = self.pos_weights if isinstance( - neural_g, models.ICNN + neural_g, icnn.ICNN ) and neural_g.pos_weights is not self.pos_weights: warnings.warn(warn_str, stacklevel=2) neural_g.pos_weights = self.pos_weights @@ -325,7 +203,7 @@ def setup( # default to using back_and_forth with the non-convex models if self.back_and_forth is None: - self.back_and_forth = isinstance(neural_f, models.PotentialMLP) + self.back_and_forth = isinstance(neural_f, potentials.PotentialMLP) if self.num_inner_iters == 1 and self.parallel_updates: self.train_step_parallel = self.get_step_fn( @@ -359,8 +237,8 @@ def __call__( # noqa: D102 validloader_source: Iterator[jnp.ndarray], validloader_target: Iterator[jnp.ndarray], callback: Optional[Callback_t] = None, - ) -> Union[potentials.DualPotentials, Tuple[potentials.DualPotentials, - Train_t]]: + ) -> Union[dual_potentials.DualPotentials, + Tuple[dual_potentials.DualPotentials, Train_t]]: logs = self.train_fn( trainloader_source, trainloader_target, @@ -643,7 +521,7 @@ def step_fn(state_f, state_g, batch): def to_dual_potentials( self, finetune_g: bool = True - ) -> potentials.DualPotentials: + ) -> dual_potentials.DualPotentials: r"""Return the Kantorovich dual potentials from 
the trained potentials. Args: @@ -664,7 +542,7 @@ def g_value_finetuned(y: jnp.ndarray) -> jnp.ndarray: ) return -f_value(grad_g_y) + jnp.dot(grad_g_y, y) - return potentials.DualPotentials( + return dual_potentials.DualPotentials( f=f_value, g=g_value_prediction if not finetune_g or self.conjugate_solver is None else g_value_finetuned, diff --git a/src/ott/neural/networks/__init__.py b/src/ott/neural/networks/__init__.py new file mode 100644 index 000000000..5f2fd8636 --- /dev/null +++ b/src/ott/neural/networks/__init__.py @@ -0,0 +1,14 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from . import icnn, layers, potentials, velocity_field diff --git a/src/ott/neural/networks/icnn.py b/src/ott/neural/networks/icnn.py new file mode 100644 index 000000000..c6896dac4 --- /dev/null +++ b/src/ott/neural/networks/icnn.py @@ -0,0 +1,160 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Callable, Optional, Sequence, Tuple, Union + +import jax +import jax.numpy as jnp + +from flax import linen as nn + +from ott.neural.networks import potentials +from ott.neural.networks.layers import posdef + +__all__ = ["ICNN"] + +DEFAULT_KERNEL_INIT = lambda *a, **k: nn.initializers.normal()(*a, **k) +DEFAULT_RECTIFIER = nn.activation.relu +DEFAULT_ACTIVATION = nn.activation.relu + + +class ICNN(potentials.BasePotential): + """Input convex neural network (ICNN). + + Implementation of input convex neural networks as introduced in + :cite:`amos:17` with initialization schemes proposed by :cite:`bunne:22`. + + Args: + dim_data: data dimensionality. + dim_hidden: sequence specifying size of hidden dimensions. The + output dimension of the last layer is 1 by default. + ranks: ranks of the matrices :math:`A_i` used as low-rank factors + for the quadratic potentials. If a sequence is passed, it must contain + ``len(dim_hidden) + 2`` elements, where the last 2 elements correspond + to the ranks of the final layer with dimension 1 and the potentials, + respectively. + init_fn: Initializer for the kernel weight matrices. + The default is :func:`~flax.linen.initializers.normal`. + act_fn: choice of activation function used in network architecture, + needs to be convex. The default is :func:`~flax.linen.activation.relu`. + pos_weights: Enforce positive weights with a projection. + If :obj:`False`, the positive weights should be enforced with clipping + or regularization in the loss. + rectifier_fn: function to ensure the non negativity of the weights. + The default is :func:`~flax.linen.activation.relu`. + gaussian_map_samples: Tuple of source and target points, used to initialize + the ICNN to mimic the linear Bures map that morphs the (Gaussian + approximation) of the input measure to that of the target measure. If + :obj:`None`, the identity initialization is used, and ICNN mimics half the + squared Euclidean norm. 
+ """ + + dim_data: int + dim_hidden: Sequence[int] + ranks: Union[int, Tuple[int, ...]] = 1 + init_fn: Callable[[jax.Array, Tuple[int, ...], Any], + jnp.ndarray] = DEFAULT_KERNEL_INIT + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = DEFAULT_ACTIVATION + pos_weights: bool = False + rectifier_fn: Callable[[jnp.ndarray], jnp.ndarray] = DEFAULT_RECTIFIER + gaussian_map_samples: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None + + def setup(self) -> None: # noqa: D102 + dim_hidden = list(self.dim_hidden) + [1] + *ranks, pos_def_rank = self._normalize_ranks() + + # final layer computes average, still with normalized rescaling + self.w_zs = [self._get_wz(dim) for dim in dim_hidden[1:]] + # subsequent layers re-injected into convex functions + self.w_xs = [ + self._get_wx(dim, rank) for dim, rank in zip(dim_hidden, ranks) + ] + self.pos_def_potentials = self._get_pos_def_potentials(pos_def_rank) + + @nn.compact + def __call__(self, x: jnp.ndarray) -> float: # noqa: D102 + w_x, *w_xs = self.w_xs + assert len(self.w_zs) == len(w_xs), (len(self.w_zs), len(w_xs)) + + z = self.act_fn(w_x(x)) + for w_z, w_x in zip(self.w_zs, w_xs): + z = self.act_fn(w_z(z) + w_x(x)) + z = z + self.pos_def_potentials(x) + + return z.squeeze() + + def _get_wz(self, dim: int) -> nn.Module: + if self.pos_weights: + return posdef.PositiveDense( + dim, + kernel_init=self.init_fn, + use_bias=False, + rectifier_fn=self.rectifier_fn, + ) + + return nn.Dense( + dim, + kernel_init=self.init_fn, + use_bias=False, + ) + + def _get_wx(self, dim: int, rank: int) -> nn.Module: + return posdef.PosDefPotentials( + rank=rank, + num_potentials=dim, + use_linear=True, + use_bias=True, + kernel_diag_init=nn.initializers.zeros, + kernel_lr_init=self.init_fn, + kernel_linear_init=self.init_fn, + bias_init=nn.initializers.zeros, + ) + + def _get_pos_def_potentials(self, rank: int) -> posdef.PosDefPotentials: + kwargs = { + "num_potentials": 1, + "use_linear": True, + "use_bias": True, + "bias_init": 
nn.initializers.zeros + } + + if self.gaussian_map_samples is None: + return posdef.PosDefPotentials( + rank=rank, + kernel_diag_init=nn.initializers.ones, + kernel_lr_init=nn.initializers.zeros, + kernel_linear_init=nn.initializers.zeros, + **kwargs, + ) + + source, target = self.gaussian_map_samples + return posdef.PosDefPotentials.init_from_samples( + source, + target, + rank=self.dim_data, + kernel_diag_init=nn.initializers.zeros, + **kwargs, + ) + + def _normalize_ranks(self) -> Tuple[int, ...]: + # +2 for the newly added layer with 1 + the final potentials + n_ranks = len(self.dim_hidden) + 2 + if isinstance(self.ranks, int): + return (self.ranks,) * n_ranks + + assert len(self.ranks) == n_ranks, (len(self.ranks), n_ranks) + return tuple(self.ranks) + + @property + def is_potential(self) -> bool: # noqa: D102 + return True diff --git a/src/ott/neural/duality/__init__.py b/src/ott/neural/networks/layers/__init__.py similarity index 91% rename from src/ott/neural/duality/__init__.py rename to src/ott/neural/networks/layers/__init__.py index ef76b42fa..237c5f275 100644 --- a/src/ott/neural/duality/__init__.py +++ b/src/ott/neural/networks/layers/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import conjugate, layers, models, neuraldual +from . 
import conjugate, posdef, time_encoder diff --git a/src/ott/neural/duality/conjugate.py b/src/ott/neural/networks/layers/conjugate.py similarity index 100% rename from src/ott/neural/duality/conjugate.py rename to src/ott/neural/networks/layers/conjugate.py diff --git a/src/ott/neural/duality/layers.py b/src/ott/neural/networks/layers/posdef.py similarity index 99% rename from src/ott/neural/duality/layers.py rename to src/ott/neural/networks/layers/posdef.py index 6ed857452..41663ffe3 100644 --- a/src/ott/neural/duality/layers.py +++ b/src/ott/neural/networks/layers/posdef.py @@ -16,7 +16,7 @@ import jax import jax.numpy as jnp -import flax.linen as nn +from flax import linen as nn __all__ = ["PositiveDense", "PosDefPotentials"] @@ -25,7 +25,6 @@ Dtype = Any Array = jnp.ndarray -# wrap to silence docs linter DEFAULT_KERNEL_INIT = lambda *a, **k: nn.initializers.lecun_normal()(*a, **k) DEFAULT_BIAS_INIT = nn.initializers.zeros DEFAULT_RECTIFIER = nn.activation.relu diff --git a/src/ott/neural/networks/layers/time_encoder.py b/src/ott/neural/networks/layers/time_encoder.py new file mode 100644 index 000000000..b02bd125c --- /dev/null +++ b/src/ott/neural/networks/layers/time_encoder.py @@ -0,0 +1,34 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import jax.numpy as jnp + +__all__ = ["cyclical_time_encoder"] + + +def cyclical_time_encoder(t: jnp.ndarray, n_freqs: int = 128) -> jnp.ndarray: + r"""Encode time :math:`t` into a cyclical representation. + + Time :math:`t` is encoded as :math:`cos(\hat{t})` and :math:`sin(\hat{t})` + where :math:`\hat{t} = [0, 2\pi t, 2\pi 2 t,\dots, 2\pi (n_f - 1) t]`. + + Args: + t: Time of shape ``[n, 1]``. + n_freqs: Number of frequencies :math:`n_f` of the cyclical encoding. + + Returns: + Encoded time of shape ``[n, 2 * n_freqs]``. + """ + freq = 2 * jnp.arange(n_freqs) * jnp.pi + t = freq * t + return jnp.concatenate([jnp.cos(t), jnp.sin(t)], axis=-1) diff --git a/src/ott/neural/networks/potentials.py b/src/ott/neural/networks/potentials.py new file mode 100644 index 000000000..6a08e0048 --- /dev/null +++ b/src/ott/neural/networks/potentials.py @@ -0,0 +1,185 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import abc +from typing import Any, Callable, Optional, Sequence, Tuple, Union + +import jax +import jax.numpy as jnp + +import optax +from flax import linen as nn +from flax import struct +from flax.core import frozen_dict +from flax.training import train_state + +__all__ = ["PotentialTrainState", "BasePotential", "PotentialMLP"] + +PotentialValueFn_t = Callable[[jnp.ndarray], jnp.ndarray] +PotentialGradientFn_t = Callable[[jnp.ndarray], jnp.ndarray] + + +class PotentialTrainState(train_state.TrainState): + """Adds information about the model's value and gradient to the state. + + This extends :class:`~flax.training.train_state.TrainState` to include + the potential methods from the + :class:`~ott.neural.networks.potentials.BasePotential` used during training. + + Args: + potential_value_fn: the potential's value function + potential_gradient_fn: the potential's gradient function + """ + potential_value_fn: Callable[ + [frozen_dict.FrozenDict[str, jnp.ndarray], Optional[PotentialValueFn_t]], + PotentialValueFn_t] = struct.field(pytree_node=False) + potential_gradient_fn: Callable[[frozen_dict.FrozenDict[str, jnp.ndarray]], + PotentialGradientFn_t] = struct.field( + pytree_node=False + ) + + +class BasePotential(abc.ABC, nn.Module): + """Base class for the neural solver models.""" + + @property + @abc.abstractmethod + def is_potential(self) -> bool: + """Indicates if the module implements a potential value or a vector field. + + Returns: + ``True`` if the module defines a potential, ``False`` if it defines a + vector field. + """ + + def potential_value_fn( + self, + params: frozen_dict.FrozenDict[str, jnp.ndarray], + other_potential_value_fn: Optional[PotentialValueFn_t] = None, + ) -> PotentialValueFn_t: + r"""Return a function giving the value of the potential. + + Applies the module if :attr:`is_potential` is ``True``, otherwise + constructs the value of the potential from the gradient with + + .. 
math:: + + g(y) = -f(\nabla_y g(y)) + y^T \nabla_y g(y) + + where :math:`\nabla_y g(y)` is detached for the envelope theorem + :cite:`danskin:67,bertsekas:71` + to give the appropriate first derivatives of this construction. + + Args: + params: parameters of the module + other_potential_value_fn: function giving the value of the other + potential. Only needed when :attr:`is_potential` is ``False``. + + Returns: + A function that can be evaluated to obtain a potential value, or a linear + interpolation of a potential. + """ + if self.is_potential: + return lambda x: self.apply({"params": params}, x) + + assert other_potential_value_fn is not None, \ + "The value of the gradient-based potential depends " \ + "on the value of the other potential." + + def value_fn(x: jnp.ndarray) -> jnp.ndarray: + squeeze = x.ndim == 1 + if squeeze: + x = jnp.expand_dims(x, 0) + grad_g_x = jax.lax.stop_gradient(self.apply({"params": params}, x)) + value = -other_potential_value_fn(grad_g_x) + \ + jax.vmap(jnp.dot)(grad_g_x, x) + return value.squeeze(0) if squeeze else value + + return value_fn + + def potential_gradient_fn( + self, + params: frozen_dict.FrozenDict[str, jnp.ndarray], + ) -> PotentialGradientFn_t: + """Return a function returning a vector or the gradient of the potential. 
+ + Args: + params: parameters of the module + + Returns: + A function that can be evaluated to obtain the potential's gradient + """ + if self.is_potential: + return jax.vmap(jax.grad(self.potential_value_fn(params))) + return lambda x: self.apply({"params": params}, x) + + def create_train_state( + self, + rng: jax.Array, + optimizer: optax.OptState, + input: Union[int, Tuple[int, ...]], + **kwargs: Any, + ) -> PotentialTrainState: + """Create initial training state.""" + params = self.init(rng, jnp.ones(input))["params"] + + return PotentialTrainState.create( + apply_fn=self.apply, + params=params, + tx=optimizer, + potential_value_fn=self.potential_value_fn, + potential_gradient_fn=self.potential_gradient_fn, + **kwargs, + ) + + +class PotentialMLP(BasePotential): + """Potential MLP. + + Args: + dim_hidden: sequence specifying size of hidden dimensions. The output + dimension of the last layer is automatically set to 1 if + :attr:`is_potential` is ``True``, or the dimension of the input otherwise. + is_potential: Model the potential if ``True``, otherwise + model the gradient of the potential. + act_fn: Activation function. 
+ """ + + dim_hidden: Sequence[int] + is_potential: bool = True + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.leaky_relu + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # noqa: D102 + squeeze = x.ndim == 1 + if squeeze: + x = jnp.expand_dims(x, 0) + assert x.ndim == 2, x.ndim + n_input = x.shape[-1] + + z = x + for n_hidden in self.dim_hidden: + Wx = nn.Dense(n_hidden, use_bias=True) + z = self.act_fn(Wx(z)) + + if self.is_potential: + Wx = nn.Dense(1, use_bias=True) + z = Wx(z).squeeze(-1) + + quad_term = 0.5 * jax.vmap(jnp.dot)(x, x) + z += quad_term + else: + Wx = nn.Dense(n_input, use_bias=True) + z = x + Wx(z) + + return z.squeeze(0) if squeeze else z diff --git a/src/ott/neural/flow_models/models.py b/src/ott/neural/networks/velocity_field.py similarity index 96% rename from src/ott/neural/flow_models/models.py rename to src/ott/neural/networks/velocity_field.py index a770b1fdd..55bbfabfc 100644 --- a/src/ott/neural/flow_models/models.py +++ b/src/ott/neural/networks/velocity_field.py @@ -16,11 +16,11 @@ import jax import jax.numpy as jnp -import flax.linen as nn import optax +from flax import linen as nn from flax.training import train_state -from ott.neural.flow_models import utils +from ott.neural.networks.layers import time_encoder __all__ = ["VelocityField"] @@ -50,7 +50,7 @@ class VelocityField(nn.Module): condition_dims: Optional[Sequence[int]] = None time_dims: Optional[Sequence[int]] = None time_encoder: Callable[[jnp.ndarray], - jnp.ndarray] = utils.cyclical_time_encoder + jnp.ndarray] = time_encoder.cyclical_time_encoder act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu @nn.compact diff --git a/src/ott/solvers/__init__.py b/src/ott/solvers/__init__.py index 1303312f9..283fca465 100644 --- a/src/ott/solvers/__init__.py +++ b/src/ott/solvers/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from . import linear, quadratic, was_solver +from . import linear, quadratic, utils, was_solver diff --git a/src/ott/neural/flow_models/utils.py b/src/ott/solvers/utils.py similarity index 90% rename from src/ott/neural/flow_models/utils.py rename to src/ott/solvers/utils.py index 6c67e45f0..6c48a2577 100644 --- a/src/ott/neural/flow_models/utils.py +++ b/src/ott/solvers/utils.py @@ -24,7 +24,6 @@ "match_quadratic", "sample_joint", "sample_conditional", - "cyclical_time_encoder", "uniform_sampler", ] @@ -150,24 +149,6 @@ def sample_conditional( return src_ixs, tgt_ixs -def cyclical_time_encoder(t: jnp.ndarray, n_freqs: int = 128) -> jnp.ndarray: - r"""Encode time :math:`t` into a cyclical representation. - - Time :math:`t` is encoded as :math:`cos(\hat{t})` and :math:`sin(\hat{t})` - where :math:`\hat{t} = [2\pi t, 2\pi 2 t,\dots, 2\pi n_f t]`. - - Args: - t: Time of shape ``[n, 1]``. - n_freqs: Frequency :math:`n_f` of the cyclical encoding. - - Returns: - Encoded time of shape ``[n, 2 * n_freqs]``. 
- """ - freq = 2 * jnp.arange(n_freqs) * jnp.pi - t = freq * t - return jnp.concatenate([jnp.cos(t), jnp.sin(t)], axis=-1) - - def uniform_sampler( rng: jax.Array, num_samples: int, diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py index 46e4f825b..14485c3b6 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -20,9 +20,9 @@ import pytest import jax +import jax.experimental.sparse as jesp import jax.numpy as jnp import numpy as np -from jax.experimental import sparse from ott.geometry import geometry, graph from ott.problems.linear import linear_problem @@ -259,7 +259,7 @@ def callback( data: jnp.ndarray, rows: jnp.ndarray, cols: jnp.ndarray, shape: Tuple[int, int] ) -> float: - G = sparse.BCOO((data, jnp.c_[rows, cols]), shape=shape).todense() + G = jesp.BCOO((data, jnp.c_[rows, cols]), shape=shape).todense() geom = graph.Graph.from_graph(G, t=1.0) solver = sinkhorn.Sinkhorn(lse_mode=False, **kwargs) @@ -274,7 +274,7 @@ def callback( eps = 1e-3 G = random_graph(20, p=0.5) - G = sparse.BCOO.fromdense(G) + G = jesp.BCOO.fromdense(G) w, rows, cols = G.data, G.indices[:, 0], G.indices[:, 1] v_w = jax.random.normal(rng, shape=w.shape) diff --git a/tests/neural/meta_initializer_test.py b/tests/initializers/neural/meta_initializer_test.py similarity index 95% rename from tests/neural/meta_initializer_test.py rename to tests/initializers/neural/meta_initializer_test.py index 117ca6b22..3e04556f9 100644 --- a/tests/neural/meta_initializer_test.py +++ b/tests/initializers/neural/meta_initializer_test.py @@ -18,11 +18,11 @@ import jax import jax.numpy as jnp -import flax.linen as nn +from flax import linen as nn from ott.geometry import pointcloud from ott.initializers.linear import initializers as linear_init -from ott.neural.duality import models as nn_init +from ott.initializers.neural import meta_initializer as meta_init 
from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn @@ -109,7 +109,7 @@ def test_meta_initializer(self, rng: jax.Array, lse_mode: bool): # overfit the initializer to the problem. meta_model = MetaMLP(n) - meta_initializer = nn_init.MetaInitializer(geom, meta_model) + meta_initializer = meta_init.MetaInitializer(geom, meta_model) for _ in range(50): _, _, meta_initializer.state = meta_initializer.update( meta_initializer.state, a=a, b=b diff --git a/tests/neural/__init__.py b/tests/neural/__init__.py index f642d8b21..278074b14 100644 --- a/tests/neural/__init__.py +++ b/tests/neural/__init__.py @@ -1,3 +1,16 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import pytest -_ = pytest.importorskip("flax") +_ = pytest.importorskip("ott.neural") diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index 92c23f6a6..f4c25c514 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -18,7 +18,7 @@ import numpy as np from torch.utils.data import DataLoader -from ott.neural.data import datasets +from ott.neural import datasets class OTLoader(NamedTuple): diff --git a/tests/neural/map_estimator_test.py b/tests/neural/map_estimator_test.py deleted file mode 100644 index cee66e40e..000000000 --- a/tests/neural/map_estimator_test.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -import pytest - -import jax.numpy as jnp - -from ott import datasets -from ott.geometry import pointcloud -from ott.neural.duality import models -from ott.neural.gaps import map_estimator, monge_gap -from ott.tools import sinkhorn_divergence - - -@pytest.mark.fast() -class TestMapEstimator: - - def test_map_estimator_convergence(self): - """Tests convergence of a simple - map estimator with Sinkhorn divergence fitting loss - and Monge (coupling) gap regularizer. 
- """ - - # define the fitting loss and the regularizer - def fitting_loss( - samples: jnp.ndarray, - mapped_samples: jnp.ndarray, - ) -> Optional[float]: - r"""Sinkhorn divergence fitting loss.""" - div = sinkhorn_divergence.sinkhorn_divergence( - pointcloud.PointCloud, - x=samples, - y=mapped_samples, - ).divergence - return div, None - - def regularizer(x, y): - gap, out = monge_gap.monge_gap_from_samples(x, y, return_output=True) - return gap, out.n_iters - - # define the model - model = models.PotentialMLP(dim_hidden=[16, 8], is_potential=False) - - # generate data - train_dataset, valid_dataset, dim_data = ( - datasets.create_gaussian_mixture_samplers( - name_source="simple", - name_target="circle", - train_batch_size=30, - valid_batch_size=30, - ) - ) - - # fit the map - solver = map_estimator.MapEstimator( - dim_data=dim_data, - fitting_loss=fitting_loss, - regularizer=regularizer, - model=model, - regularizer_strength=1.0, - num_train_iters=15, - logging=True, - valid_freq=5, - ) - neural_state, logs = solver.train_map_estimator( - *train_dataset, *valid_dataset - ) - - # check if the loss has decreased during training - assert logs["train"]["total_loss"][0] > logs["train"]["total_loss"][-1] - - # check dimensionality of the mapped source - source = next(train_dataset.source_iter) - mapped_source = neural_state.apply_fn({"params": neural_state.params}, - source) - assert mapped_source.shape[1] == dim_data diff --git a/tests/neural/genot_test.py b/tests/neural/methods/genot_test.py similarity index 83% rename from tests/neural/genot_test.py rename to tests/neural/methods/genot_test.py index c37d91563..086ea7a80 100644 --- a/tests/neural/genot_test.py +++ b/tests/neural/methods/genot_test.py @@ -22,7 +22,9 @@ import optax -from ott.neural.flow_models import flows, genot, models, utils +from ott.neural.methods.flows import dynamics, genot +from ott.neural.networks import velocity_field +from ott.solvers import utils as solver_utils def data_match_fn( @@ 
-31,11 +33,13 @@ def data_match_fn( typ: Literal["lin", "quad", "fused"] ) -> jnp.ndarray: if typ == "lin": - return utils.match_linear(x=src_lin, y=tgt_lin) + return solver_utils.match_linear(x=src_lin, y=tgt_lin) if typ == "quad": - return utils.match_quadratic(xx=src_quad, yy=tgt_quad) + return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad) if typ == "fused": - return utils.match_quadratic(xx=src_quad, yy=tgt_quad, x=src_lin, y=tgt_lin) + return solver_utils.match_quadratic( + xx=src_quad, yy=tgt_quad, x=src_lin, y=tgt_lin + ) raise NotImplementedError(f"Unknown type: {typ}.") @@ -56,14 +60,14 @@ def test_genot(self, rng: jax.Array, dl: str, request): tgt_dim = dl.lin_dim + dl.quad_tgt_dim cond_dim = dl.cond_dim - vf = models.VelocityField( + vf = velocity_field.VelocityField( hidden_dims=[7, 7, 7], output_dims=[15, tgt_dim], condition_dims=None if cond_dim is None else [1, 3, 2], ) model = genot.GENOT( vf, - flow=flows.ConstantNoiseFlow(0.0), + flow=dynamics.ConstantNoiseFlow(0.0), data_match_fn=functools.partial(data_match_fn, typ=problem_type), source_dim=src_dim, target_dim=tgt_dim, diff --git a/tests/neural/losses_test.py b/tests/neural/methods/monge_gap_test.py similarity index 64% rename from tests/neural/losses_test.py rename to tests/neural/methods/monge_gap_test.py index e1e13f193..68d885537 100644 --- a/tests/neural/losses_test.py +++ b/tests/neural/methods/monge_gap_test.py @@ -11,14 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Optional + import pytest import jax +import jax.numpy as jnp import numpy as np -from ott.geometry import costs -from ott.neural.duality import models -from ott.neural.gaps import monge_gap +from ott import datasets +from ott.geometry import costs, pointcloud +from ott.neural.methods import monge_gap +from ott.neural.networks import potentials +from ott.tools import sinkhorn_divergence @pytest.mark.fast() @@ -34,7 +39,7 @@ def test_monge_gap_non_negativity( rng1, rng2 = jax.random.split(rng, 2) reference_points = jax.random.normal(rng1, (n_samples, n_features)) - model = models.PotentialMLP(dim_hidden=[8, 8], is_potential=False) + model = potentials.PotentialMLP(dim_hidden=[8, 8], is_potential=False) params = model.init(rng2, x=reference_points[0]) target = model.apply(params, reference_points) @@ -122,3 +127,67 @@ def test_monge_gap_from_samples_different_cost( np.testing.assert_array_equal( np.isfinite(monge_gap_from_samples_value_cost_fn), True ) + + +@pytest.mark.fast() +class TestMongeGapEstimator: + + def test_map_estimator_convergence(self): + """Tests convergence of a simple + map estimator with Sinkhorn divergence fitting loss + and Monge (coupling) gap regularizer. 
+ """ + + # define the fitting loss and the regularizer + def fitting_loss( + samples: jnp.ndarray, + mapped_samples: jnp.ndarray, + ) -> Optional[float]: + r"""Sinkhorn divergence fitting loss.""" + div = sinkhorn_divergence.sinkhorn_divergence( + pointcloud.PointCloud, + x=samples, + y=mapped_samples, + ).divergence + return div, None + + def regularizer(x, y): + gap, out = monge_gap.monge_gap_from_samples(x, y, return_output=True) + return gap, out.n_iters + + # define the model + model = potentials.PotentialMLP(dim_hidden=[16, 8], is_potential=False) + + # generate data + train_dataset, valid_dataset, dim_data = ( + datasets.create_gaussian_mixture_samplers( + name_source="simple", + name_target="circle", + train_batch_size=30, + valid_batch_size=30, + ) + ) + + # fit the map + solver = monge_gap.MongeGapEstimator( + dim_data=dim_data, + fitting_loss=fitting_loss, + regularizer=regularizer, + model=model, + regularizer_strength=1.0, + num_train_iters=15, + logging=True, + valid_freq=5, + ) + neural_state, logs = solver.train_map_estimator( + *train_dataset, *valid_dataset + ) + + # check if the loss has decreased during training + assert logs["train"]["total_loss"][0] > logs["train"]["total_loss"][-1] + + # check dimensionality of the mapped source + source = next(train_dataset.source_iter) + mapped_source = neural_state.apply_fn({"params": neural_state.params}, + source) + assert mapped_source.shape[1] == dim_data diff --git a/tests/neural/neuraldual_test.py b/tests/neural/methods/neuraldual_test.py similarity index 86% rename from tests/neural/neuraldual_test.py rename to tests/neural/methods/neuraldual_test.py index 5aef77aba..b0d847abb 100644 --- a/tests/neural/neuraldual_test.py +++ b/tests/neural/methods/neuraldual_test.py @@ -19,9 +19,11 @@ import numpy as np from ott import datasets -from ott.neural.duality import conjugate, models, neuraldual +from ott.neural.methods import neuraldual +from ott.neural.networks import icnn, potentials +from 
ott.neural.networks.layers import conjugate -ModelPair_t = Tuple[neuraldual.BaseW2NeuralDual, neuraldual.BaseW2NeuralDual] +ModelPair_t = Tuple[potentials.BasePotential, potentials.BasePotential] DatasetPair_t = Tuple[datasets.Dataset, datasets.Dataset] @@ -37,16 +39,16 @@ def ds(request: Tuple[str, str]) -> DatasetPair_t: def neural_models(request: str) -> ModelPair_t: if request.param == "icnns": return ( - models.ICNN(dim_data=2, - dim_hidden=[32]), models.ICNN(dim_data=2, dim_hidden=[32]) + icnn.ICNN(dim_data=2, + dim_hidden=[32]), icnn.ICNN(dim_data=2, dim_hidden=[32]) ) if request.param == "mlps": - return models.PotentialMLP(dim_hidden=[32] - ), models.PotentialMLP(dim_hidden=[32]), + return potentials.PotentialMLP(dim_hidden=[32] + ), potentials.PotentialMLP(dim_hidden=[32]), if request.param == "mlps-grad": return ( - models.PotentialMLP(dim_hidden=[32]), - models.PotentialMLP(is_potential=False, dim_hidden=[128]) + potentials.PotentialMLP(dim_hidden=[32]), + potentials.PotentialMLP(is_potential=False, dim_hidden=[128]) ) raise ValueError(f"Invalid request: {request.param}") @@ -82,7 +84,7 @@ def decreasing(losses: Sequence[float]) -> bool: train_dataset, valid_dataset = ds if test_gaussian_init: - neural_f = models.ICNN( + neural_f = icnn.ICNN( dim_data=2, dim_hidden=[32], gaussian_map_samples=[ @@ -90,7 +92,7 @@ def decreasing(losses: Sequence[float]) -> bool: next(train_dataset.target_iter) ] ) - neural_g = models.ICNN( + neural_g = icnn.ICNN( dim_data=2, dim_hidden=[32], gaussian_map_samples=[ diff --git a/tests/neural/otfm_test.py b/tests/neural/methods/otfm_test.py similarity index 85% rename from tests/neural/otfm_test.py rename to tests/neural/methods/otfm_test.py index 00619dc57..0eb311fa6 100644 --- a/tests/neural/otfm_test.py +++ b/tests/neural/methods/otfm_test.py @@ -19,7 +19,9 @@ import optax -from ott.neural.flow_models import flows, models, otfm, utils +from ott.neural.methods.flows import dynamics, otfm +from ott.neural.networks import 
velocity_field +from ott.solvers import utils as solver_utils class TestOTFlowMatching: @@ -29,15 +31,15 @@ def test_otfm(self, rng: jax.Array, dl: str, request): dl = request.getfixturevalue(dl) dim, cond_dim = dl.lin_dim, dl.cond_dim - neural_vf = models.VelocityField( + vf = velocity_field.VelocityField( hidden_dims=[5, 5, 5], output_dims=[7, dim], condition_dims=None if cond_dim is None else [4, 3, 2], ) fm = otfm.OTFlowMatching( - neural_vf, - flows.ConstantNoiseFlow(0.0), - match_fn=jax.jit(utils.match_linear), + vf, + dynamics.ConstantNoiseFlow(0.0), + match_fn=jax.jit(solver_utils.match_linear), rng=rng, optimizer=optax.adam(learning_rate=1e-3), condition_dim=cond_dim, diff --git a/tests/neural/icnn_test.py b/tests/neural/networks/icnn_test.py similarity index 93% rename from tests/neural/icnn_test.py rename to tests/neural/networks/icnn_test.py index a60682a06..b07e4994f 100644 --- a/tests/neural/icnn_test.py +++ b/tests/neural/networks/icnn_test.py @@ -17,7 +17,7 @@ import jax.numpy as jnp import numpy as np -from ott.neural.duality import models +from ott.neural.networks import icnn @pytest.mark.fast() @@ -29,7 +29,7 @@ def test_icnn_convexity(self, rng: jax.Array): dim_hidden = (64, 64) # define icnn model - model = models.ICNN(n_features, dim_hidden=dim_hidden) + model = icnn.ICNN(n_features, dim_hidden=dim_hidden) # initialize model rng1, rng2 = jax.random.split(rng, 2) @@ -55,7 +55,7 @@ def test_icnn_hessian(self, rng: jax.Array): # define icnn model n_features = 2 dim_hidden = (64, 64) - model = models.ICNN(n_features, dim_hidden=dim_hidden) + model = icnn.ICNN(n_features, dim_hidden=dim_hidden) # initialize model rng1, rng2 = jax.random.split(rng) diff --git a/tests/tools/plot_test.py b/tests/tools/plot_test.py index 1f9f9ba01..2d8ba55ac 100644 --- a/tests/tools/plot_test.py +++ b/tests/tools/plot_test.py @@ -16,14 +16,13 @@ import matplotlib.pyplot as plt -import ott from ott.geometry import pointcloud from ott.problems.linear import 
linear_problem from ott.solvers.linear import sinkhorn from ott.tools import plot -class TestSoftSort: +class TestPlotting: def test_plot(self, monkeypatch): monkeypatch.setattr(plt, "show", lambda: None) @@ -44,5 +43,5 @@ def test_plot(self, monkeypatch): plott = plot.Plot() _ = plott(ots[0]) fig = plt.figure(figsize=(8, 5)) - plott = ott.tools.plot.Plot(fig=fig, title="test") + plott = plot.Plot(fig=fig, title="test") plott.animate(ots, frame_rate=2, titles=["test1", "test2"]) From ef6afd1f779ca03e6c6c70f49e33de6cf1cc21e5 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 29 Mar 2024 19:15:41 +0100 Subject: [PATCH 173/186] Fix neural init tests import --- tests/initializers/neural/__init__.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 tests/initializers/neural/__init__.py diff --git a/tests/initializers/neural/__init__.py b/tests/initializers/neural/__init__.py new file mode 100644 index 000000000..8c23e4ba8 --- /dev/null +++ b/tests/initializers/neural/__init__.py @@ -0,0 +1,16 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest + +_ = pytest.importorskip("ott.initializers.neural") From 73c2527dfd2288164eea266002a90bfe62eff931 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 29 Mar 2024 20:12:20 +0100 Subject: [PATCH 174/186] Update `docs/` --- docs/neural/data.rst | 15 ----------- docs/neural/datasets.rst | 15 +++++++++++ docs/neural/duality.rst | 37 --------------------------- docs/neural/flow_models.rst | 50 ------------------------------------- docs/neural/gaps.rst | 15 ----------- docs/neural/index.rst | 8 +++--- docs/neural/methods.rst | 37 +++++++++++++++++++++++++++ docs/neural/networks.rst | 33 ++++++++++++++++++++++++ docs/solvers/index.rst | 11 ++++++++ 9 files changed, 99 insertions(+), 122 deletions(-) delete mode 100644 docs/neural/data.rst create mode 100644 docs/neural/datasets.rst delete mode 100644 docs/neural/duality.rst delete mode 100644 docs/neural/flow_models.rst delete mode 100644 docs/neural/gaps.rst create mode 100644 docs/neural/methods.rst create mode 100644 docs/neural/networks.rst diff --git a/docs/neural/data.rst b/docs/neural/data.rst deleted file mode 100644 index 3ae0e4c53..000000000 --- a/docs/neural/data.rst +++ /dev/null @@ -1,15 +0,0 @@ -ott.neural.data -=============== -.. module:: ott.neural.data -.. currentmodule:: ott.neural.data - -The :mod:`ott.neural.data` contains data sets and data loaders needed -for solving (conditional) neural optimal transport problems. - -Datasets --------- -.. autosummary:: - :toctree: _autosummary - - datasets.OTData - datasets.OTDataset diff --git a/docs/neural/datasets.rst b/docs/neural/datasets.rst new file mode 100644 index 000000000..67d5e3b6b --- /dev/null +++ b/docs/neural/datasets.rst @@ -0,0 +1,15 @@ +ott.neural.datasets +=================== +.. module:: ott.neural.datasets +.. currentmodule:: ott.neural + +The :mod:`ott.neural.datasets` contains datasets needed for solving +(conditional) neural optimal transport problems.
+ +Datasets +-------- +.. autosummary:: + :toctree: _autosummary + + datasets.OTData + datasets.OTDataset diff --git a/docs/neural/duality.rst b/docs/neural/duality.rst deleted file mode 100644 index 25dc89daa..000000000 --- a/docs/neural/duality.rst +++ /dev/null @@ -1,37 +0,0 @@ -ott.neural.duality -================== -.. module:: ott.neural.duality -.. currentmodule:: ott.neural.duality - -This module implements various solvers to estimate optimal transport between -two probability measures, through samples, parameterized as neural networks. -These solvers build upon dual formulation of the optimal transport problem. - -Solvers -------- -.. autosummary:: - :toctree: _autosummary - - neuraldual.W2NeuralDual - neuraldual.BaseW2NeuralDual - -Conjugate Solvers ------------------ -.. autosummary:: - :toctree: _autosummary - - conjugate.FenchelConjugateLBFGS - conjugate.FenchelConjugateSolver - conjugate.ConjugateResults - -Models ------- -.. autosummary:: - :toctree: _autosummary - - neuraldual.W2NeuralTrainState - neuraldual.BaseW2NeuralDual - neuraldual.W2NeuralDual - models.ICNN - models.PotentialMLP - models.MetaInitializer diff --git a/docs/neural/flow_models.rst b/docs/neural/flow_models.rst deleted file mode 100644 index 273f145f3..000000000 --- a/docs/neural/flow_models.rst +++ /dev/null @@ -1,50 +0,0 @@ -ott.neural.flow_models -====================== -.. module:: ott.neural.flow_models -.. currentmodule:: ott.neural.flow_models - -This module implements various solvers building upon flow matching -:cite:`lipman:22` to match distributions. - -Flows ------ -.. autosummary:: - :toctree: _autosummary - - flows.BaseFlow - flows.StraightFlow - flows.ConstantNoiseFlow - flows.BrownianNoiseFlow - -OT Flow Matching ----------------- -.. autosummary:: - :toctree: _autosummary - - otfm.OTFlowMatching - -GENOT ------ -.. autosummary:: - :toctree: _autosummary - - genot.GENOT - -Models ------- -.. 
autosummary:: - :toctree: _autosummary - - models.VelocityField - -Utils ------ -.. autosummary:: - :toctree: _autosummary - - utils.match_linear - utils.match_quadratic - utils.sample_joint - utils.sample_conditional - utils.cyclical_time_encoder - utils.uniform_sampler diff --git a/docs/neural/gaps.rst b/docs/neural/gaps.rst deleted file mode 100644 index abf621e24..000000000 --- a/docs/neural/gaps.rst +++ /dev/null @@ -1,15 +0,0 @@ -ott.neural.gaps -=============== -.. module:: ott.neural.gaps -.. currentmodule:: ott.neural.gaps - -This module implements gap models. - -Monge gap ---------- -.. autosummary:: - :toctree: _autosummary - - map_estimator.MapEstimator - monge_gap.monge_gap - monge_gap.monge_gap_from_samples diff --git a/docs/neural/index.rst b/docs/neural/index.rst index 9de1781f6..5cf025cdc 100644 --- a/docs/neural/index.rst +++ b/docs/neural/index.rst @@ -1,7 +1,6 @@ ott.neural ========== .. module:: ott.neural -.. currentmodule:: ott.neural In contrast to most methods presented in :mod:`ott.solvers`, which output vectors or matrices, the goal of the :mod:`ott.neural` module is to parameterize @@ -13,7 +12,6 @@ and solvers to estimate such neural networks. .. toctree:: :maxdepth: 2 - data - duality - flow_models - gaps + datasets + methods + networks diff --git a/docs/neural/methods.rst b/docs/neural/methods.rst new file mode 100644 index 000000000..028651a34 --- /dev/null +++ b/docs/neural/methods.rst @@ -0,0 +1,37 @@ +ott.neural.methods +================== +.. module:: ott.neural.methods +.. currentmodule:: ott.neural.methods + +Monge Gap +--------- +.. autosummary:: + :toctree: _autosummary + + monge_gap.monge_gap + monge_gap.monge_gap_from_samples + monge_gap.MongeGapEstimator + +Neural Dual +----------- +.. autosummary:: + :toctree: _autosummary + + neuraldual.W2NeuralDual + +ott.neural.methods.flows +======================== +.. module:: ott.neural.methods.flows +.. currentmodule:: ott.neural.methods.flows + +Flows +----- +.. 
autosummary:: + :toctree: _autosummary + + otfm.OTFlowMatching + genot.GENOT + dynamics.BaseFlow + dynamics.StraightFlow + dynamics.ConstantNoiseFlow + dynamics.BrownianBridge diff --git a/docs/neural/networks.rst b/docs/neural/networks.rst new file mode 100644 index 000000000..647243192 --- /dev/null +++ b/docs/neural/networks.rst @@ -0,0 +1,33 @@ +ott.neural.networks +=================== +.. module:: ott.neural.networks +.. currentmodule:: ott.neural.networks + +Networks +-------- +.. autosummary:: + :toctree: _autosummary + + icnn.ICNN + velocity_field.VelocityField + potentials.BasePotential + potentials.PotentialMLP + potentials.PotentialTrainState + + +ott.neural.networks.layers +========================== +.. module:: ott.neural.networks.layers +.. currentmodule:: ott.neural.networks.layers + +Layers +------ +.. autosummary:: + :toctree: _autosummary + + conjugate.FenchelConjugateSolver + conjugate.FenchelConjugateLBFGS + conjugate.ConjugateResults + posdef.PositiveDense + posdef.PosDefPotentials + time_encoder.cyclical_time_encoder diff --git a/docs/solvers/index.rst b/docs/solvers/index.rst index ddfbc9230..d23b4cdac 100644 --- a/docs/solvers/index.rst +++ b/docs/solvers/index.rst @@ -23,3 +23,14 @@ Wasserstein Solver :toctree: _autosummary was_solver.WassersteinSolver + +Utilities +--------- +.. 
autosummary:: + :toctree: _autosummary + + utils.match_linear + utils.match_quadratic + utils.sample_joint + utils.sample_conditional + utils.uniform_sampler From 0418e788bfa1bb7b8e2cb4db8f9b08b55dde425f Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 2 Apr 2024 09:43:19 +0200 Subject: [PATCH 175/186] Update Monge Gap --- docs/tutorials/Monge_Gap.ipynb | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/tutorials/Monge_Gap.ipynb b/docs/tutorials/Monge_Gap.ipynb index 99fbc17ca..be9098a09 100644 --- a/docs/tutorials/Monge_Gap.ipynb +++ b/docs/tutorials/Monge_Gap.ipynb @@ -40,8 +40,8 @@ "\n", "from ott import datasets\n", "from ott.geometry import costs, pointcloud\n", - "from ott.neural import losses, models\n", - "from ott.neural.solvers import map_estimator\n", + "from ott.neural.methods import monge_gap\n", + "from ott.neural.networks import potentials\n", "from ott.solvers.linear import acceleration\n", "from ott.tools import sinkhorn_divergence" ] @@ -58,7 +58,7 @@ "T^\\star \\in \\arg\\min_{\\substack{T:\\mathbb{R}^d \\rightarrow \\mathbb{R}^d \\\\ T \\sharp \\mu = \\nu}} \\int c(x,T(x)) \\mathrm{d}\\mu(x)\n", "$$\n", "\n", - "We show how to use the {func}`~ott.neural.losses.monge_gap`, a regularizer proposed by {cite}`uscidda:23` to do so. Computing an OT map can be split into two goals: move mass efficiently from $\\mu$ to $T\\sharp\\mu$ (this is the objective), while, at the same time, making sure $T\\sharp\\mu$ \"lands\" on $\\nu$ (the constraint).\n", + "We show how to use the {func}`~ott.neural.methods.monge_gap.monge_gap`, a regularizer proposed by {cite}`uscidda:23` to do so. 
Computing an OT map can be split into two goals: move mass efficiently from $\\mu$ to $T\\sharp\\mu$ (this is the objective), while, at the same time, making sure $T\\sharp\\mu$ \"lands\" on $\\nu$ (the constraint).\n", "\n", "The first requirement (efficiency) can be quantified with the **Monge gap** $\\mathcal{M}_\\mu^c$, a non-negative regularizer defined through $\\mu$ and $c$, and which takes as an argument any map $T : \\mathbb{R}^d \\rightarrow \\mathbb{R}^d$. The value $\\mathcal{M}_\\mu^c(T)$ quantifies how $T$ moves mass efficiently between $\\mu$ and $T \\sharp \\mu$, and only cancels $\\mathcal{M}_\\mu^c(T) = 0$ i.f.f. $T$ is optimal between $\\mu$ and $T \\sharp \\mu$ for the cost $c$.\n", "\n", @@ -68,7 +68,7 @@ "\\min_{T:\\mathbb{R}^d \\rightarrow \\mathbb{R}^d} \\Delta(T\\sharp \\mu, \\nu) + \\lambda_\\mathrm{MG} \\mathcal{M}_\\mu^c(T)\n", "$$\n", "\n", - "We parameterize maps $T$ as neural networks $\\{T_\\theta\\}_{\\theta \\in \\mathbb{R}^d}$, using the {class}`~ott.neural.solvers.map_estimator.MapEstimator` solver. For the squared-Euclidean cost, this method provides a simple alternative to the {class}`~ott.neural.solvers.neuraldual.W2NeuralDual` solver, but one that does not require parameterizing networks as gradients of convex functions." + "We parameterize maps $T$ as neural networks $\\{T_\\theta\\}_{\\theta \\in \\mathbb{R}^d}$, using the {class}`~ott.neural.methods.monge_gap.MongeGapEstimator` solver. For the squared Euclidean cost, this method provides a simple alternative to the {class}`~ott.neural.methods.neuraldual.W2NeuralDual` solver, but one that does not require parameterizing networks as gradients of convex functions." ] }, { @@ -293,7 +293,7 @@ "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -336,7 +336,7 @@ "$$\n", "\\min_{T:\\mathbb{R}^d \\rightarrow \\mathbb{R}^d} \\Delta(T\\sharp \\mu, \\nu) + \\lambda_\\mathrm{MG} \\mathcal{M}_\\mu^c(T)\n", "$$\n", - "For all fittings, we use $\\Delta = S_{\\varepsilon, \\ell_2^2}$, the {func}`~ott.tools.sinkhorn_divergence.sinkhorn_divergence` with the {class}`squared-Euclidean cost `\n", + "For all fittings, we use $\\Delta = S_{\\varepsilon, \\ell_2^2}$, the {func}`~ott.tools.sinkhorn_divergence.sinkhorn_divergence` with the {class}`squared Euclidean cost `\n", "The function considers a ground cost function `cost_fn` (corresponding to $c$), as well as the `epsilon` regularization parameters to compute approximated Wasserstein distances, both for fitting and regularizer." ] }, @@ -355,7 +355,7 @@ "):\n", " dim_data = 2\n", " # define the neural map\n", - " model = models.MLP(\n", + " model = potentials.PotentialMLP(\n", " dim_hidden=[32, 64, 32], is_potential=False, act_fn=nn.gelu\n", " )\n", "\n", @@ -388,7 +388,7 @@ " print(\"Selected `epsilon_regularizer`:\", epsilon_regularizer)\n", "\n", " def regularizer(x, y):\n", - " gap, out = losses.monge_gap_from_samples(\n", + " gap, out = monge_gap.monge_gap_from_samples(\n", " x,\n", " y,\n", " cost_fn=cost_fn,\n", @@ -398,7 +398,7 @@ " return gap, out.n_iters\n", "\n", " # define solver\n", - " solver = map_estimator.MapEstimator(\n", + " solver = monge_gap.MongeGapEstimator(\n", " dim_data=dim_data,\n", " fitting_loss=fitting_loss,\n", " regularizer=regularizer,\n", From b34b886ff45196b31809de3cd7439b78c3fc64fc Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 2 Apr 2024 11:20:33 +0200 Subject: [PATCH 176/186] Update MetaOT and NeuralDual --- docs/tutorials/MetaOT.ipynb | 30 +++++++++--------- docs/tutorials/neural_dual.ipynb | 53 +++++++++++++------------------- 2 files changed, 37 insertions(+), 46 deletions(-) diff --git a/docs/tutorials/MetaOT.ipynb b/docs/tutorials/MetaOT.ipynb index 
172024733..ad7426a0f 100644 --- a/docs/tutorials/MetaOT.ipynb +++ b/docs/tutorials/MetaOT.ipynb @@ -23,8 +23,7 @@ "\n", "We will cover:\n", "\n", - "- {class}`~ott.neural.models.MetaInitializer`: The main class for the Meta OT initializer\n", - "- {class}`~ott.neural.models.MLP`: A Meta MLP to predict the dual potentials from the weights of the measures\n", + "- {class}`~ott.initializers.neural.meta_initializer.MetaInitializer`: The main class for the Meta OT initializer\n", "- {class}`~ott.initializers.linear.initializers.GaussianInitializer`: The main initialization class for the Gaussian initializer" ] }, @@ -46,8 +45,8 @@ "import sys\n", "\n", "if \"google.colab\" in sys.modules:\n", - " !pip install -q git+https://github.com/ott-jax/ott@main\n", - " !pip install -q torch torchvision" + " %pip install -q git+https://github.com/ott-jax/ott@main\n", + " %pip install -q torch torchvision" ] }, { @@ -71,7 +70,7 @@ "\n", "from ott.geometry import pointcloud\n", "from ott.initializers.linear import initializers\n", - "from ott.neural import models\n", + "from ott.initializers.neural import meta_initializer\n", "from ott.problems.linear import linear_problem\n", "from ott.solvers.linear import sinkhorn" ] @@ -216,7 +215,7 @@ "This tutorial shows how to train a meta OT model to predict\n", "the optimal Sinkhorn potentials from the image pairs.\n", "We will reproduce their results using \n", - "{class}`~ott.neural.models.MetaInitializer`,\n", + "{class}`~ott.initializers.neural.meta_initializer.MetaInitializer`,\n", "which provides an easy-to-use interface\n", "for training and using Meta OT models.\n", "\n", @@ -239,7 +238,7 @@ "outputs": [ { "data": { - "image/png": "", + "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAhUAAACLCAYAAADWF2tkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAAxOAAAMTgF/d4wjAACHNElEQVR4nO29d5Bc53Xm/XTOOafpmenJAwzCIAMkAIIgJSYFkkqUqJW0kssu21t2OZVd1tbu6tutXZd37S2v6bhaywq2JJIixSAQBAgQOcwAmMHk1DPdPd3TOefw/QG/L7uRw8x0D3B/VSqRnHRv33vfe95znvMcVqVSqYCBgYGBgYGB4QFh1/sAGBgYGBgYGB4OmKCCgYGBgYGBYVlgggoGBgYGBgaGZYEJKhgYGBgYGBiWBSaoYGBgYGBgYFgWmKCCgYGBgYGBYVlgggoGBgYGBgaGZYFbrz8sEAig0+nq9efrRiAQYM77EYI570cL5rwfLR7l887lcjf9Wt2CCp1OB7fbXa8/XzesVitz3o8QzHk/WjDn/WjxKJ/3rWDKHwwMDAwMDAzLAhNUMDAwMDAwMCwLTFDBwMDAwMDAsCzUTVPRCJBZan6/H16vF+VyGaVSCSqVCjabDVwuF2w2GywWq85HysDAwMDA0Pg80kEFAJRKJQwMDOBf/uVfUCgUkMlksHnzZvz2b/82RCIReDweE1QwMDAwMDDcBY9cUFEul1EoFFAul5FMJpHJZDA2Nga3241cLodMJgOz2YxQKASlUgmpVMpkLBgYGBqKSqWCSqWCYrGIbDaLYrGIRCKBYrGIQqFAs7AAIBaLIZPJwOVywefzwWazweVywWKxmDWNYdl55IKKUqmEUCiEeDyOo0ePYmZmBmfPnsWlS5dQqVRQLpfB4XBw/PhxNDc3Y8OGDZBIJODxeOBwOPU+fAYGBgaUy2UUi0VEo1FcvnwZLpcLR44cwdLSErxeL9LpNACAxWJh586deOKJJ2Cz2dDb21sTZJDvYWBYLh76oKJcLlOtRKFQQD6fh9frRSQSwdzcHBYWFuDz+ZDNZunPRCIRzM/Po1QqwWazoVwuQy6XN0xQQRYUcl4AIBQKG+b4lhuyKyuVSiiVSvT8SRBYvSu7GSTLRP5H/p3D4dAdG5u9NjXL5LMhnwWbzX5o7wOGT9azbDaLZDKJpaUlTE5OwufzYXp6GktLSwgGg8hkMgCuBQw6nQ52ux2ZTAZKpRIymQwcDgd8Ph98Pp8GF2sRct+Xy2XkcjkUi0WIRCIIBIJ6H9ojy9q9m+6STCaDSCSCpaUlfPzxxwiFQhgeHkYoFILb7UYikaAPIGFxcRH/7//9P2g0GkxNTcFqteKll16C2Wyu01nUkkqlsLCwQHcpAoEAL774IlQqVb0PbUUgKd1QKITx8XH4/X5cunQJ8Xgcfr8fiUTilj9LnFsFAgFkMhn4fD5MJhO0Wi3a29vR3t4OHo8HkUi0JgOLfD6PYrGIXC6HZDIJqVQKpVK5Js+F4fZUKhWk02lkMhmcO3cOb7zxBsLhMMbGxpDNZhEOh+mzAnySgRgfH8fS0hL4fD6USiWMRiOee+45OBwO9Pf3Q61W1/O0HohSqYR4PI5EIoGf/OQnGB0dxW/91m9h69at9T60R5aHLqggu1aye8vlcohEInC73RgZGUEoFMLFixdpMEF2+tUpwFwuB5fLhUgkgpaWFmSzWaRSKZTL5brXISuVCgqFAvx+PxYXFzE6OgqhUIh4PA6JRAIOh/PQ6T9I3TgajWJiYgI+nw8DAwPw+/2Yn5+nqd6bwePxYDabIRaLodVqIRaLEQqFYDAYIBAIYDAYIJFIIBAI6n5t7wZyf5OMTT6fp/dnNBpFuVyGWCy+abbi+vviVv/M0DiQ601245lMBrFYDFNTUzh37hwSiQQCgQDK5TK4XC7NvgGfXNNKpYJwOIxyuQyn0wmtVouenh6w2WysW7eubue2HFQqFfqZjIyM4MSJE3j55ZcbYq2
+E+QdBVy7vtd/rfqfr39WqzOujcZDFVSQF26pVEI4HEY0GsWFCxdw9OjRmog+Go1SsSZwTcgkEonoCzmXyyEWiyGXy+HKlSvw+XwYHBwEh8OBVquFXC6vy/kVCgUUCgXMz8/jjTfeQDAYxNDQEDgcDrLZLEwmE1588UW0tLSAy+Wu6bRmNU6nE4cPH8b8/DzOnTuHeDyOpaUl5PN55PP52/5sqVRCMBgEl8tFMBgEm83G1NQUpFIpBgYG8MEHH6CzsxMvv/wyZDIZFeY2IiTFm8vlcPLkSczNzSEejyObzSIejyMajUKpVKKlpQU8Ho/+HFmUtmzZApvNBjabTbuaSBmIBKQMjQMJGguFAiYmJuD3+3H+/Hlato1GoxAKhdi8eTNUKhU6Ojogk8mgUqloOZTNZiORSCAajcLj8eD9998Hh8OB3+8Hj8e75fyGtUKhUMDY2BimpqYwOzuLcDiMiYkJbNiwAQqFAnK5vCFfvADoZiAWi2FhYYFuEAqFAqLRKF3fisUixGIxpFIpRCIRVCoVFAoFurq6IBAIaCNBo9CYq+cDQLQTgUAA8/PzOH/+PN5++216wW4GSY2TxTaVSiEej6NUKsHv9yOTyWBubg4WiwVisbhuQQU5t1AohKGhIXi9XrjdbvriVCqV2LBhA2w2G31ZNOoDdbdUKhX4fD6cPn0ac3NzuHTpEs0uAbgrPQQpb1XvDFgsFiYnJyESibBv3z7s378fwLUAs1EhnUuJRAIff/wxhoaGEIvF6E4tHo9DLpejra3thsCI3AtCoRB8Pp9mZthsNvh8PoRCYUPsfK7fld3v7wDWfvaFdHdkMhkMDQ1hcnISx48fx9DQEHg8HrhcLhQKBex2O7RaLXbu3AmtVgudTgexWEyvbyKRoGXf06dP05dZJBKhpZK1SqlUwuzsLGZmZhAIBJDJZBAIBOD3+8Hlcuu2Vt8N+XwesVgMHo8Hp06dojqZQqGAYDCIdDqNVCqFTCYDjUYDtVoNqVQKvV4Pq9UKu91e05nYKPf7QxVUFAoFDA4OYmZmBqOjo/B4PBgfH6cCnushL96DBw9ix44dNHU4OzuLH//4x4jH47RV6/Lly0gkEvjc5z4HrVYLNpu96tEheVG0tbXhG9/4Br0ZA4EAFhcXEQqF8OGHHyKRSECtVkOpVEKpVMJkMoHH49EXx1qCxWLBYDBgx44dUKlUWFpaQqFQAIvFgkQiweOPPw69Xn/H30MyPMFgEHNzc1haWgJwbbewsLCAkydPoq2tDSqVCnw+f6VP664gQVA6ncbS0hLi8ThGR0fh8/lw6dIlzMzMIJ/Po1Qq0QxGpVLB9PR0zXUmL2qBQIDJyUl6n5OAQiKRYNOmTVAoFGhvb6cB9kreK5lMBgsLC2CxWDCbzfS583q9sNvtMBqN4HK5NPgh50EEyteLd/P5PJLJJPL5PJaWllAul6HT6WoykOQZIHA4nFuWihoBcs8GAgGcOXMG09PTiMViUKlU2LdvH/bs2UOfb6FQCIPBQANEIkAGrpUAC4UCZDIZFTaSz3atrQe3ovqFKhAIIBaLa7J1jUKlUkE2m0U+n8eVK1dw9OhRBAIB+izncjmUy2W6VpdKJbDZbKTTabrJnZ6ehkqlQigUglwuh1gsBp/Px7p162A2myEUCmlQWQ8euqDinXfewalTpzA5OQm/33/b72exWOByudi+fTu+9KUvgcPhgMPhYGxsDB999BGKxSIikQhNN4+OjmLdunVYv349ANQlqOByuTCZTHjhhRfg9/tRKBTgcrloUPGrX/0KV69ehV6vh06nQ2dnJz71qU9RkeJaXETILozH42FkZASZTAZcLhdGoxFf+cpXYLfb7/g7CoUCTpw4AafTCQ6HQ8sGmUwGHo8HZ86cgd/vx/bt21fhjO4O8uKMxWI4ffo0FhYWcPToUfh8PiwsLNS0DRJyuRzi8fgNvwcA5ufna/47h8OBQCCAXC7H448/DpPJhG9961sQiUQrnlJ
Np9M4evQouFwuPv3pT0MikeDQoUM4ffo0nnvuOTzxxBOQy+U3qPiLxSKKxSINJjKZDDKZDOLxOObm5hAIBHD+/Hnkcjn09/fDarVCKBRCIBBAoVBArVbT8yLdD40aVOTzeQwNDWF6ehofffQR5ubmqMh43759+OpXv1rTwXQrWCwWMpkMZDIZ1eJUe1Y8TLBYrJrAqtGoVCq05HH06FG89tpryGQySKfT9Dnl8/loaWmBQqGg76RUKkVL+ouLiwCAI0eOgMfj0Q6eb3/72/jUpz4Fo9FY14xr433q9wFJEUajUQQCgRtaRG8Gh8OB2WyGWq2G1WqFRCKhKSSVSoXu7m4IhUIMDQ0hlUohnU6DzWZjcXERkUgEEokEUqm0LtFg9ctg3bp1kMlkuHTpEhKJBLLZLO1TJ/9OOh42b95MRYlk99YIKe87IRKJYDKZsGXLFqqZYbFYkMvlsNlsUCqVd/wdhUIBHR0dEIlE8Hg8mJ6epostUFsaqTf5fB7lchmJRAKRSAQzMzO4ePEiQqEQFhcXkUgkakpAACCXy6FSqcDhcMDj8cDj8eiiBFw7v7m5Ofj9/htac0l5L5/PIxQK0d3OSizKuVwO6XQabrcbk5OTKJVK9FkaGRmB0+nE4OAg2Gw2VCoVdDodJBIJzUYtLi4imUwimUzSNH4ikUAymaRfm52dRT6fh0AggNvtpsG4XC6HUqmEXC6H2WyGRCKheipSTmik9uJisYiFhQW4XC5ks1mw2Wz09/eju7sbXV1ddxRlk4xOJBLB4OAgxsbGUKlUIJfLsX79erS1tUEqla7yWa0sRJyfSqUgkUjqfTiUUqlEmwPGx8cxMzODyclJuv4Q75C+vj4olUq0t7dDLpfT61ssFmm5xO12I51Ow+PxIJvNIhKJIJvNYmxsDAKBALt27aJrQT1K4A9FUFEoFOB0OuHxeDA2Ngan03nDons9fD4fW7duRVNTEzo6OmoeLp1Ohx07dsBoNMLpdCKZTNKa9eTkJMbHx2Gz2WggstpUq7x37dqFzs5ODA4OIhKJIB6PIxAI0Bckl8vFL3/5SzgcDvyH//Af0NTUhKamJkilUrpTW4469koilUohkUhgsViwefPmmq/d7UNDTM2ampqwuLgIp9NJO3waCaJmT6VSuHz5Mo4dO4bZ2Vl8+OGHVLRFUtjVWCwW9Pf3g8/n0xel3W6HSCQCcG1R++CDD/Dhhx9S59hKpULFYIODg5iamsJTTz0Fk8kEnU5XUypYLuLxOEZGRnD58mW89957iMfjeO+998BmsxEMBpFKpTA7O4tf/OIX4PF4kEqlaG9vxxe+8AXw+Xx88MEH8Hg88Hg8CIfDVGdULBaRTCZpSQQArly5QoME8rwIBAK0tLTgqaeegk6nQ6lUglwuh8lkosFFowQV2WwWQ0NDGBkZQSwWA5fLxYEDB/CZz3wGKpXqjhkKEjxOT0/j7//+7xGLxVAsFtHa2opnn30WTU1NK3KN60F11188HkcoFIJMJqvzUX1CoVDA6OgonE4n3nzzTZw+fRqpVIoGP2q1GuvXr8d3vvMdmM1mOBwOCIXCmtJfuVymwUQwGMSbb75J9Wbk3w8dOoRvf/vb6OjogEAgoM//arKmgwrSk01e9nNzc3ShudWuk8PhQK1WQ6FQwOFwoKOjA0qlsubh5PF4UKvViEQidLdW3cpH6rr1hOhBxGIxisUi1q1bBx6Ph+npaXg8HsTjccRiMfpZ+P1+jI2N0bS/TCaDRqOh9XPy+0hKtJFEntUipHtNVZOOiUKhAK/Xi1AoBJ/PRy2NyUtYp9NBo9HU/YVCNBThcBhOpxNutxsejweZTIbqglgsFn0B2mw2GAwGtLa2oqOjg3puCIVCOhSP/F6v1wsAdMbN9c+IRCKhL5qVKglkMhlMTU3B4/FQm/xisQg2m41KpUKD3Fwuh3w+j0wmA6FQiLGxMfB4PCwuLsLtdiMYDCKZTNJsS6lUqslOslismnZxUurM5/O
IRCIIh8NUf6JSqWh5sJFKIURDks/n6UslGo0iFApBIBBAIpHcNFNBMhQk0zU1NQW/3w82m42uri40NzdDKpU2VAB1vxCNAtHNEe1Qo2kqKpUKotEofD4fQqEQ0uk0zbYajUasX78edrsdFosFGo0GAoGgRttVbXKnVqvB5XLR29sLpVKJkZERmp0non2fzwelUkmz0qvJmg4qMpkMnE4nJicn8T/+x//AwsICXSxuhVwux5NPPgmr1YpvfOMbaGpquuHm43K5aG1tBYC6RHp3C4fDoW2Qv/Vbv4VCoYCjR4/i9OnTuHz5Mj7++GP6UnW5XHjttddompfH42Hv3r1Yt24d3cFZLBasX78eIpEIcrm8IWuS90o+n8fExAS8Xi9+9KMfYWBggGZ0SJ29q6sLBw4cgMViqbsTX7lcxszMDM6fP4+zZ8/i8OHDNENB4HA4sNls0Ov1+M3f/E3s3r2b6gaAT16i179sNm3adNuAGwAte6zUy2ZxcRE/+9nP4PV6adscCWh1Oh2kUiltoySGTtFolOpBMpkMbQcnQcOdSlfk6+T3BoNBjI6Ogs/n49SpUxCLxfjt3/5trF+/vqGCCkL1nI8jR47A7Xbjc5/7HPbt20fLXdU7WpK5OXr0KN566y3Mzs5ienoaGzduxDe/+U3YbDZotdoaEexapVwuw+v1wuv1IpfLgcvlwmq1orW1taHW7lKphKtXr1KhdalUgtFohMVioZo+lUoFg8FAyxbXP4OVSoWWgg0GA1paWhCPxzE9PY2JiQla0rxw4QL+4R/+AZs3b8aLL77IBBV3C3l4QqEQXC4XFhYWqKL/ZhANgsFggM1mg91uh1qtvqmghYh9Gl0dTRZj4FqJoFwuw2630wWb1Jaz2SzK5TLi8ThtSwRAW2RJUJFOp6HT6WjakNTmSeZiLegvgNpFOJfLYXFxkS6sTqeTvnDFYjGsVit9SOvpREnS9tXGZj6fD/F4HBwOB0KhkJYDhEIh2traoNFo0NzcDI1GU2N61MiUSiXaJkd21OR6iEQiqNVqxONxJJNJ+v3E/Ox2997d6mGu7xgJBoMQCARIpVJ3DLhWGzabTQN8Pp+PXC6HQCAAkUiE+fl5JJNJ8Pn8GqOn6u6ChYUFTE5OUk8XsnHQ6XR3LJ2sFUi5kFgAkLZpskNvlHMk97dYLIbJZEKlUkFzczPMZjNaW1upCd/tZkxVm14BoJkMspkg9240GoXX64Xf76dr4Wp+Do2/Ct0Esij4fD68/fbbcLvdSKVSt/2ZdevW4Rvf+AZsNhv6+voglUobqua2HLBYLHR3d6OlpQWf+tSn8M1vfhMulwtnz55FLBbD5OQkotEoXC4XkskkLl++jNnZWXqjisVivPnmm+Dz+RCJRBCJRHjyySfR09ODpqYm6HS6hp8tQZwHU6kUZmZm4PF48KMf/Qhzc3OYnp5GoVCA3W5HU1MT+vv78alPfQparRYOh4MGUfWgWCxifn4eoVAI77//Pj766CNEo1EA17pfent70dzcjJdeeonOb+DxeDAajTU71UZHrVZj+/bttITBYrGgVCohEomwYcMG6PV6uN1u2mK3EshkMvT09IDD4WBubo6+lKsN8RoBsViMxx57DFarFblcDvPz81haWqJdbfPz89BqtVSALBKJUKlU4HK5EA6HaRdcpVKBWCyG0WiEw+GAXC5vqNLAg1AsFuF2uzE+Po5UKgUWiwWZTNZwtvtklALp/Ein09BoNJDL5VSovFxt3ESsbLPZUCgUaKCyWmvEfQcV8Xgc/+f//B+88cYbmJiYQKFQQG9vL373d38XX/nKVwAAg4ODtK1renp62VLLpL6YSCRoluJ6H4rqzgYOhwOr1Yr+/n7odDoYjUbw+fw1sxDfLcS7gdykOp0OOp0O4XCYmqmIxWJqepNKpai4Dbj2uY6OjtLfJxQKodVqweVyoVQqodFoANy7rmE1ITvbZDIJp9OJ2dlZDA8P0/Qon8+HVquF1WpFW1s
bOjs7qRC0notQuVxGKBTC3Nwc5ufnsbCwgEqlQoM9i8UCh8OB9evX09IU0Qnc6bhvt/te7WdAIBDAbDbTtuBisUhr4BKJhL4Qbpb+XS54PB4dECgUCqkmo9GCCg6Hg+bmZuq5EQwG6QtpbGyMptDz+TyEQiEUCgUqlQqcTicikQhmZ2cRj8chFAohk8noZ/ywiDOBa887EdETT4fVfoneDRwOhxpYGQwGOvhsJUo0RGdIMtSrnX27r6Digw8+wNe+9rUbfCAGBgbwyiuvoFKp4JVXXsF/+S//BQDwx3/8x8taqyZBBbEzvb7Fjs1mY8eOHejo6EBTUxPa29ths9nQ2tp6gzHMwwqXy4VIJILZbMbzzz+PXC6HUCiEXC5HVfakB56468ViMczPz9MArVAo4NixY5iamkI+n4dEIoFcLm9I61vSajoxMYHDhw8jEAhgYmIC6XQaSqUSWq0WGzduhE6nQ09PD3UhJKWDep9PoVDA8ePHMTIygpmZGVQqFdryuGPHDnz5y1+mJZpqi+27CShSqRQKhQKt1ZKUKJvNpkHJaj0TBoMBL730EoaHh3Hq1Clq6pVIJHD69GnIZDJ6LxIzr+UmmUxiZmYGLBYLU1NTyOVy+OCDDzAxMYFnn30Wvb29y/437weBQIC+vj60tbXBZDLB5XLh0KFDOHHiBLLZLEZHR2kGjpRrASAcDiOTydBysN1ux65du9Df39/QG4L7pdrOmsvlIhqNIp1O07JuI0A2AOS5q1QqK3YtuFwuNcSqR8n6noOKX/ziF3jppZdQKpWwf/9+/H//3/+Hrq4uDA0N4dVXX8XCwgL+5//8n1i/fj3eeust2O12fOtb31rWg66umZN+9WrYbDY6Ojqwbds29Pf3o7Ozk4oT6+GEWQ/IC4TP50MqlaJSqaCpqYkK3EqlEkwmEwQCAWKxGFUMezweGqCVSiWMjIxgdHQUnZ2d2LJlCxWH1vslXA2pk+fzeVy9ehU/+MEPaF1RIpFg69atMBqNePzxx+FwOGC1WqFQKBqqlFMqlTA+Po5Tp04hFAqhUqlQIanNZsP69eshlUrvWc1Nas6JRIJadJN7gMfjQSwW0wVuNa4pyaSVSiVotVrEYjEEAgHk83la7qk+9pUgm83C5XKhXC7D5/Mhn8/j4sWLcLvd6O/vb5iggsfj0bZXqVSK3t5ezMzMUE+aYDAIAJibm6M/c/1nxmKxYDKZ0NfXh87Ozody7SOGaOS9QAzRRCJRQ2WkV2utYbPZtHukHvbd9xRUeL1evPrqqyiVSti9ezcOHTpEa3N79+7F9773Pbz66qu4evUqvve976FSqeBP/uRPlt32mFyc5uZmfP3rX6cPGNlhczgc7NmzB62trdBoNDRiu1uPdFJaCYfDa94bv5pqwysOh0MHEJEhax6PB01NTQiFQjhz5kzNIk9MV4RCIYxGY12Ov9qTIJFIoFAoIJ1OI5fL0fbLwcFBxONxWrtXqVTYs2cPDAYD2traqK1tI2Qnqql22iNaAtJ2NjU1hY8++ggqlYqWo+6WQqGAS5cuwe/3Qy6XQyqV0lZskUgEo9FI0+MCgQBNTU2QSCTUSI4MqCKGa8s1Z4BoKDQaDc6fP0+da68vY1bP8bhTkFGpVKjOhMPhIBAI3HSCLemGAkDtj+/ka1NPiGCTy+XihRdegN1ux8DAAA4fPgyBQAC9Xo90Oo2ZmRlqplSN0+mkXgZ6vR4KhQI6na6hXrj3C4/HQ29vL4rFIj1/kuF72DPSt+p6UqlU6Onpgc1mo+tcwwo1v/e97yGRSIDD4dD2xGq2bNkC4Fo66uc//zlaWlrwjW98Y/mO9t8gu3Cz2Ywvf/nL1K6X1EOrxV/345BXLpcRCAQeqqCiulOEYLPZYLPZqA4hHo+jubmZur2RoIKMTp6YmIBara6bQr5YLFKn0MnJSYTDYUxPTyMajeLkyZO4evUqDZy6urrw/PPPo7m5Gf39/TVzTxpxoSmXywiHwwiHwwA
+sVbOZDI4e/Ys0uk0pFIpLdfcLblcjrYhEmFYoVCg5SyHwwGxWAy1Wg25XI5XXnkFDoeDlpHa29uxc+dOatBDjNceFLFYjK1bt1JhZqlUQjQarQkqrl8079T9QXwu2traIBKJcPHixZsGFZlMhgYVJC1dPVOk0SCdSpVKBTt37sS2bdtgt9sxPT0NjUaDtrY2xGIxeL3emwYV8/PzcLvd6OnpgVKphMPhwL59+2qGUa1VeDwe1q9fD6FQSDdGZOPwsFPtXVGN0WhEb28vOjo66jLB9K6DinQ6jR/+8IcAgFdeeYXOv6im2i65UqngT//0T1e0zY0MRCJuidUfLolS7+eByefzdLwwGQ1MRsyS+nwjWcDeL9WfTblcRiaTwfT0NObn52kLH1l0dToduru7oVarV30RIkZI4XAYU1NTCIVCmJiYQDKZxNLSEtLpNPx+P/L5PBUztra2oqWl5QbNRKMuoGw2m/atRyKRmpcDsV4XCoVIJpP3HFSQdslMJkNNoYgocXFxEQKBAMlkElKpFCdPnoTf76eD+YgeQ6fToaOjg1pmP2itmsfjobOzExKJBOfPn6d18Hw+D6vViubmZjruOZvNYnZ2lgptSYBMsicCgQBCoZC+TBwOB7hcLtRqNXw+H8bHx+HxeG55LCRlTpxzs9ks9X9oJKo3SHa7HQcPHqTtgvl8nt7nJKNEMlMk6xqNRjE+Po5EIoHW1lbodDoolco1LdysvhcaedOw3JDnmdyv1ZBntG6Oz3f7jSdOnKCDil5++eWbfk/1S72trQ1f+9rXHvDwbk/1rulmtcT7/UCTySROnjyJyclJxGIx6nwmlUrhcDjQ0tLyUKQOCcSN0Ov14v3334fT6aQ7ZlKb6+npwa5du+py3tFolKZw/+Ef/gGJRII6YpLdJXlJ7tmzB9/5zncgl8uh0WiojoYEno0Kl8tFT08PSqUSTp8+DbfbTb8Wi8WQSCTue1dJsm3JZLJmcFEymUQoFKopaRw5cgQcDofWqYma3mg04qmnnkJTUxNeffXVBw4qxGIxNmzYgJaWFnz88ceIx+M023jw4EF89rOfhVarpUHWm2++iVgsBolEQkXIfD6fBvlqtRo2m41OIi2Xy1hYWEAoFMI//uM/4ic/+ckNI9GrP0viPutyuRCNRhs2fV6difvd3/1d+P1+fPTRRygUCrS0qVAoIJFIsHHjRlitVkxMTOD48ePw+Xz48Y9/DK1Wi3w+j5aWFjz33HN1K2cuF+Q90GjXaiUhc2G8Xi/V1hCIW7ROp2vsoOLkyZMArr1knnzyyZt+T/WL/bvf/e6qmPEs5w6UlAFIN0QqlaI97AqFAiaTiXruP0yCJ6JRiEQitG2temgXme5IpuGt9o1KRnsT7UwqlUI2m71pqjqdTiMYDCKTydCdLfF0IP/fiClfNpuNlpYW5HI5eL1eahBF7LSvb3UkO1SZTEZnBLDZbDogjNThrz/P2513uVymO5/rv4/L5cLj8YDFYsHv99e82O8HkgHj8/mwWq1IJBK03bm1tZU+a6TTqLOzE7FYjNovk2CX+M3IZDI6gInL5aJcLkMmk6FQKNQ4jVZ/dtUQUTNpR2zk55tcazIaOxgM0vEEXC4XDocDFosFra2tsNlsKBaLcDqdiMfj8Pv9iMfjcLvdNSO1G/GZuBfW8rHfD6VSCUtLS5idnUUikUClUqHvJZFIRE3yGjqoGBsbAwA66ORmDAwMALjm9EW8KtYSqVQKHo+H+hqQATw8Hg87duxAb28venp6aH3+YbmRo9EoTp06hZGREbhcLsRiMRQKBfD5fOzatQsOhwPr1q2r201KXBjT6TTS6fQtAwoAeP/993HhwgX6kpBIJGhra4NWq8W/+3f/Dna7nXZBNBJCoRDPP/88Dh48CIPBgKamJly5cgXDw8M3fG91S+jmzZvR1dVFjbsikQjGx8chFArR1dVFDZGAOy+8hUIBv/rVrzA5OXnD90ciEXz44YeQy+XI5XJ
oamrCF77wBbS0tDzQeQsEAnz+859HMpmkQk0y/4AIigUCAZ555pmaYKD6/8lLtlqUVi6XIRQK6VTe28FisaBWq2GxWGC326FSqeqimr8XstksAoEALl26hH/6p39CMBhENBqFQqHAb/zGb2Dnzp30xeJyudDX14exsTF8//vfRzwexzvvvAOVSoW9e/dCr9fT4YIMjQ159lOpFH76059iaGgIMzMzAED9iaxWKx05UI/g+K6DCp/PB+CaI97NKBaL+IM/+AMAuK3VaCNCUuikt9vr9SKbzdIBNWQImdVqhVQqXVPndj2kFZecMxmNvLCwAL/fj2w2Sz0NBAIBdDod7HY7HTxWD0iLFBEUkhT+zXbwxJ6Y+OALhUKkUilYrVY6ZAeo3e01wsuDzWbTF2Brays8Hg9tiyXdGsAnU1lJWaKpqQl2u53u3qVSKTKZDPh8Pux2OyQSyV2LHXO5HOx2O0KhEHUmJTM4iMlQPp+n+oY7udje7XkrlUqIxWJqky2Xy2s2Lvejb6iutd/p+pKsCZnq2Mh25+RaEo3R4uIivF4vEokE9SdoamqCwWCggSbRwxAjLJKJJZbwqVSK6tPWIsSnopE7eJYLsn6TcrXT6aT6K6KlUCgUdR2Od9dPD3kwY7HYTb/+l3/5l5iYmADQ2I6LN4MItM6ePYt/+qd/opPkyuUyNBoNlEolNm/ejB07dkCr1db7cB+IYrGI8+fPY2RkBHNzc/B6vYhEIpibm0MikUA6nQafz4fD4YDJZMJnP/tZbNq0qWY0/Gqj0WjQ398Ph8OBvXv3IpPJ0BbEUChUY+fsdDoxPj5Op3yWy2XMz8/D6/Xie9/7HgwGAx577DH09/fDbDbDbDY31FRWDoeDrVu3orOzE9FoFJFIBC6XC8PDw2Cz2VCpVBAKhWhtbaUTDknan81mI5fL0XkhxHnzbjMV5XIZ+/fvp/eD2+3G2NgY3nrrLfoZ5/N5XLp0CbOzs/j617/+wOdLHEOrMyprbf1YTYjJ09WrV/HDH/4QLpcLuVyOalTsdjsV0pLrLZfLsXHjRohEIhw6dAhzc3NYXFykQ8eCwSCeeuopdHV11fns7p1CoYChoSFcuXKF6sAeZsgsI6fTifn5eUSjUboJfPrpp/Hkk0+ir6+vbsZXwD0EFa2trTh58iRGRkYwOztLp3gCwNDQEP70T/+U/jvZ5a9UxP+gbV/VHzQZwBMKhTA2Nobjx4/TnRkxByKpUeJ5sVYhJlEjIyM4deoULly4QAPBaogXhdlsRnNzM/R6fR2OtvZ4hEIhvRak/ZUMhCIdOuR70+k0BAIB/H4/VfWXSiU6/pnP51NnVb1eT3e0jQDR78hkMhiNRhQKBTrCnM1mw2QyQSQS0e4IogMgP1sul6HX68Fise55JkilUoFSqUShUIDZbKZuq9UBV7lcrtG1LMf5NkKXBcn+NEJgeSuqTd6cTifOnj2LSCSCUqkEiUQCm80Gi8VCbc4JpLSh1+uh1+uRSCTg9/uRy+UwOzuLcrmM/v7+Op7Z/VMsFuH3++FyuZblfmx0CoUCAoEA3G53TecHmVy8adMmaLXaum6S7vqt/+yzz+IHP/gByuUyXnzxRfzd3/0dHA4Hjhw5gt/8zd9EJpPBc889h3feeQf5fB5/9md/ht/5nd+hEdNyQKyYiQPkvUBq7EKhEDqdDjwej6bQ3W43PvroI4yPjyOfz9OUukgkwrZt22CxWGA2m294WNcS+XwebrcbgUAA586dw4ULF26wWSdmSAaDAY8//jiampqgUCjqdMQ3QkoyJK1bLpehUqlq0p5WqxW7du2iQ3WWlpZw5MgRLC4uYnFxEalUCgMDA4jFYnA6nUgmkzAYDHSgWCO8WKq7MdhsNgwGA/bs2UPbBYlIsjoYqtYYLEcwT9xqV8oqu5HgcDhob29HR0fHLcu7jUClUoHH48H09DQGBwfh8/kgEAjQ29uLtrY2vPLKK7BarVCpVDU/R4zDBAIBbDYbgGs
unJlMBj6fDxwOB4lEoh6ndN+Uy2UqSo7FYnTj0MgC2weBlHojkQjeeecdLCwsIB6Pg81mw2KxQKvVorOzEwaDoe4twne9+rz44ovYv38/PvroI1y+fBnbtm2r+fpTTz2Fn/3sZ9i4cSMmJibwx3/8x/jjP/5jjI6Ooru7+4EPlNSSCoUCPB4PBgYG7mmxEwqFUKlU0Ol0UKvVdHdEWs/OnTuHqakpavcKXBORORwOdHR0QKVSNYyP/P1QKBRw9epVTE1NYWBggGYoql+gIpEILS0tsFgsNJiqZ9njeu6m7kucMwuFArZs2QK/3w+/309dIlOpFCYmJjA5OQmv14tCoYD169ejqampYSzcr5/rQTpY7vZnlyPwJV1QD4v52+0gHiEdHR2Qy+X1PpxbQiaQHjt2DKOjo4hEIjAajejq6kJHRwd27dpF7eevh8Vigc/nw2Kx0MmVRE9Fur9We0T2g0CCCtIVRjxY1nIm+XYQjVMoFMKRI0fg8XiojqapqQk2m61G+1bP63jXQQWHw8F7772HP//zP8e//uu/YnZ2FtlsFmazGV/5ylfwve99D1wuF6+//jq+8Y1vYGBgAEKhEJ2dnQ90gET5T0b/+nw+jIyMYHx8/J5+DxExyeVyLC4uQi6XU3vejz76CFNTUwgEAgCumXitW7cONpsNu3btgtVqXXNmVySyJe2Y4XAYQ0NDcDqdN8xYkMvldITygQMHYDKZaM2+kUVrt4MEIGq1Gs8//zxcLheMRiMWFhYwPT0Nn8+HcDiM0dFRFAoFukttamq6afBIDKOIQHI1HlpiWkSuITFf4/F4kEgky1I2INk68izMzMzA7XZjeHgY4+PjmJ2dpS6XpFTR3NwMk8nUUFms6ymVSnA6nZicnITH47lhA1KtM2GxWFCpVGhqaqr7Lu9mkOtDxnzPz88jGAzSa3HgwAE4HA6qo8jlciiXyzVTmkmbbTKZpAMYSdeLzWZryCGBt6NYLCISidASJ+mGqvcLdaXIZDIIh8OYm5tDKBSidgcCgQAbN27EunXrYDQaG+L87+mNIRQK8Sd/8if4kz/5k1t+T29vL86fP//AB0YolUo0bf/6669jcHAQLpcLi4uL9/R7yMNFhGHVplnEE4D8u9FoxAsvvIC2tjbs3r2bzrpfS5CAIpPJwOv1wuVy4dSpUxgfH7+hdKTT6bBt2zb09fXhq1/9Ku1CaIQb9H7hcDgQiUQQCoU4cOAACoUCurq6MDY2hg8++AAffvghAoEAPvjgA4yNjaFSqcDhcNzS2IlYhBOjndUog3k8Hhw9ehTxeBxLS0tQKpXYtWsX1Go1bSN9EKo7aPL5PLLZLN5++218+OGHVMRLumiAT+yid+/ejebm5oYWLZdKJXz88cc4duwYrl69WvO1m9l/m81m2Gy2hsrMVZPP55HL5TA9PY2zZ88im82Cx+Ohq6sLL774Yo1nSCaToYJr0tVCrMhDoRBtlWez2bDb7Whtba1xQ14LEF0JmURcLpfB4/EeOg8hQiQSwfnz53Hx4kX4fD7aeSUQCHDgwAFs374dCoWiIcrzDbsNJbuzTCaD8fFxzM3NweVywefz0foZ8EkNmdQRSavUrX4nqS9Wvyzz+XzNIkPqybFYDJlMhtbxG+GC3QnSKppOp2n3wNWrV7GwsACfz0eDJy6XC6PRCL1ej+7ubvT19VFfg4dhJgCh2mpcr9ejUqnA6XQiEAhgYWGBtmT5/X7ajkmCzurFqVwuI5VKgc/nr/pcAeKfEolEIBAIoNVqqWMouVakpfReKJVKdCCb2+1GJBLBzMwMNd+qfo4EAgEMBgMMBgO6u7vR0dHRsC9gQrUd+Z1o5F0uMSWLxWIIh8N0l0pabblcLjWlI+6yuVyOCphVKhVEIhHdZBDbdg6HA71eT9Pmaw1SEiBZCp1OB41GsybP5VaQoD4UCuHq1at0Tg6Hw6kxZKzeKJMNJXmnVSoVqhUkujHy38n3kHISALqhvN85Pw0bVOR
yOdr69Dd/8zcYHh6mwrHqnRNps3vsscfAYrHw8ccf3yBAJJAPsVolTB7EapLJJKanp5FOp9Hd3Y1cLge9Xg+RSLRyJ7xMkCmeU1NTOHz4MJxOJ95//32kUim6mBAnwmeeeQaPPfYYHA4H2trawOfzIRKJGnZxfRA4HA6am5tpqre9vR0nTpzAP/3TPyEej+P8+fNIJBJU/CaRSGrqs+TFK5fLbxDCreQxi8VixGIxHD16lBqxKRQKzMzMoKWlBTabjabuTSbTPV23bDaL4eFhzM/P4x//8R/pPJXq4XwElUqFT33qU7Db7fjCF74AnU63JkpjNzOxqv73au1Ko97zxWIRly5dwvDwMIaGhhCNRqnonAQVRO1Pgl+SkYhGo+jt7aXuooFAAIuLi7QNsa+vD08++eSat+rm8/nYuXMnmpub1/y5ECqVCg0OL1y4gO9///v0HSiTybB161bY7Xbqn0Tu33w+T718yuUyzXCRgZEKhQKFQoEG3CTYdDqd4HA46O3thVwuh1qtvq93XsOuCuTlH4vF4Pf7EQqF6NcUCgUtSfB4PJhMJrS0tIDFYmFubg4sFguJRII6L1YHDXcj7iTGMnw+Hy6Xi7aXkkivUcVAxMArFothYWEBCwsLcLvd8Pv9KBaLEAqFEIlEdFplU1MTWlpaoNPpanbnjbq4Pihk8VUqlWhpacHExAQ1h0qn01RFnslkbniYyA5wNYWLQqEQer0ecrkcPB4P+Xye9qW73W6Uy2Vks1k6NK3am+BuIFNe5+fnMTs7C6/XC6DWxpq8vLRaLW0zlkqla0K0fH3wQFhLgkTg2vFGo1GEw+GaLG11RrH6fEgGI5vNIhgM0vbrVCqFRCKBVCpFO6g0Gg0UCkVDtPU+CCwWC0KhcM2bExKqnTMjkQgWFxcRDodvmlUns4GAa59DMplEJBKhmZxsNovJyUkkEgmqoyGmdiRDkUwmMTc3RzMgKpUKEonk4QsqcrkcTc+Wy2XqvPfpT38a+/fvh1gshkwmg1Qqhc1mQ7lcRl9fH2ZmZvDRRx9hYGCApgLvhXg8jjNnzkAoFGJ4eBhKpRIbN26EzWZraJOYcrmMQ4cO4a233oLb7cbMzAx1yJRKpXTU9J49e+BwONDe3k59Goj4cC0ttvcDi8WCXq+HSqVCMpnE8ePHEQgE4PF44HK5MDQ0hFKphA0bNtRdtGcwGKBSqaDRaOByueB2u3H+/HmkUin86le/ohM6SWblVnXx6qxDdUknk8kgFAohm81iaWkJQG1KlMViwWAwYPv27bDb7Xj66aeh0Wjq/rk8alQqFYTDYQQCASSTSQCg2cbrrwWbzYZarYZEIoHL5aIWzgAwPDyM0dFRxONxavzW3d0Nk8n0ULyIiX5krZ8L6XQk5mQffPABRkZGasYTpNNpXL16FXNzc0gmk9BoNPTniWEeEV8TG4ZSqQSxWAyBQEDLI6R8RLIVbDYb77zzDjQaDb773e/el39JQwcVJNIiJ05eflqtFi0tLdTbXyQSQalUolwuw2w2I5VKQS6X02mL1VSb3FRPsKxWwJPoLx6PI5FIUOOldDqN7du31+kTuTnkJUBuwunpaZw5cwbxeJxOWCW+BmazGUajET09PbBardBoNKuuD2gEiGGUUqmEXC6nu79cLodkMol4PN4Qlr/EtEin08FkMqFUKkEoFFInUbLjvl0geH2QcKegkTwbpKZKjN+sVivVcjyMQrhGJ5fLUVNB4Nq9cb2dOVkzyfXL5XKIxWI0aCZ23vl8nrbXy+Xyhs283i3k/iZZtYfh/iQmZ/Pz87h06RKWlpZqNgflcplO9r1y5QokEgl9dgOBAJaWlm7I0gOfrAE3+xoAOmQuEAggnU7f17E3bFBxPZVKBfF4HOl0GufPn0c2m4VYLKY+51KpFNlsFidPnsTCwgLm5uaooAn4pPtDIBBQccvu3bthNpuxuLhIXcquXr1K5xqQbAnp5ybDlBqFQqGARCKBeDyO119/HdP
T0zh37hyCwSD129Bqtdi8eTOam5vx9a9/ne7Sq50YHzWq64iRSIS22PF4PBgMBhiNxob6bLRaLb7whS/A5XKBxWLB4/Hg8uXLiEajdzXX4mb/fDM4HA66urqoe6xSqURPTw+eeOIJSCQSqNVqKgpkWD1IWZNoogBg06ZNOHjwIPr7+8HhcFAqlZBMJpFKpXDixAlMTU1hZGQEV69exejoKC5fvoylpSVkMhkoFAo8+eST6OzsbGizr7uF+FN0d3dj165dq6Z5WikKhQLGx8exsLCAS5cuYX5+/ga30HK5jEwmg0KhQDdFcrkcYrEYzc3N6OzshEAggEwmo8aPXC6Xli6J8Pd6yDwhtVoNu91+X8ffsKvDzRZAIlqZnJxEPp+HTCaDUqmk35vNZnHixImbCjWrhwaZTCYYjUbs3LkTbW1tmJycxNzcHLhcLpxOJ43Wqm1xySjtu1GSrxYkWvV6vXjjjTdw+vTpG3ajEokELS0t6OzsRE9Pz0OljL5fSPaLBI8kUGSz2dQiu5FSqGTWh1Qqpc6fk5OTiMfjAG58Vm6lG7pTgEGMdIgPhcViwbp169Da2vrQtuqtFYjRExlyaDAYsGXLFjq7hnhQRCIRnD17FhcuXIDH40EwGASXy8XY2BjNwhL/oJ6enjUhPr8bOBwOzcSu9fJcuVzG7OwshoeHMTc3d4OvEIGUMIiZGXAtW6NSqaDVaiESiah7tFgsBo/Hg16vh1Qqhdfrvaktg0AgQF9fHxQKxX370DRsUEFaHsvlMpqamqhIiaT05ubm6IdFyOfztOZIUkFkZ6XX67Ft2zZotVrs2bMHBoMBzc3NkMlkUKlU6O3txdatW7F582Y4nU784he/oL3eYrEYzzzzDDZs2PDAo56Xg+oulitXrmB6ehrhcLjmRWGxWOBwONDb24uXXnoJOp1uzac5l4NKpYJkMolgMEg9O8rlMiwWCzo7O+lApkbKVHA4HAiFQhgMBnzuc59DLBZDb28vHaBUfd2j0ShmZ2fpxN3qYWtElGc0GmGxWG4InDgcDjZt2gSTyUSHfJHe97WmtWGz2Whubqafk9vtrvch3TeVSoXO6yDX0+l04qOPPsK2bdtgt9trxMZLS0twOp1IpVJU2EdakQ8ePAiLxYK+vj4YjcY1vyaw2Wxa6iaau7Ue/LJYLOh0OthsNvT399e4vJJJxAKBAGq1mk6lVSgU1KuEzEricrkQCoU103qJZ0lzc/NNZ6WQz5Note6Hhg0qOBwOtFotyuUyjEYjdDodFV3G43G6S7sVpCZMPuSOjg488cQTcDgc2LFjR00wQkQubW1t2LBhAyYnJzEyMgK/3w+FQgG1Wo2DBw+ir6+vIVTvJIOSSqVw5swZOJ1ORCKRmvkPNpsN27Ztw9atW7Ft27a6jsJtNCKRCPU+CYVCkEql6OzspCOjb2V1XC/IoiAUCrFhwwaUSiV0dHQgnU7f8LL3+Xw4evQootEo5ufna+qiNpsNer0e69atQ39//w0/y2az10zr9J0gz0A8HsfU1FS9D+eBIIFwOBymWTWi1BcKhXjiiSdqOgU8Hg/dhRJ9BTH127dvHxwOB5qammraENcqxGWWiJUfBqEmsUlobm5GLperMZmTSqV0jWptbYVEIoHJZIJQKGyYa9mwQQWxWVYoFFT7MDc3B5/PR7+HOA3m83nEYjFwOBy0tbVBo9HQEdFqtRoajQYWiwWbNm2CQqG4ZU2YPKRmsxkvvPACUqkUbcO0WCw1xiH1pFQq0XZbj8cDp9OJXC5H06JkVPju3btht9sfGjMrYmZUKpVoCxVpISP9+reCiFlLpRI8Hg8uXrwIl8uFSqVCxb/k3mj0z4nFYt3go0Fgs9nYs2cP0uk0vF5vTaZCp9NRgZ5CobhpULHWF2QCi8WCRqNBe3s7HA4Henp6EAqFarpcyPcRXxeDwYANGzY0XImQxWJBLpdDp9NRvxnSXjw3N4fLly+jVCrh3Llz8Pl8tDWYOMq2tbW
hr6+P7nxVKlVDZeLuB9JW7fP5wOVyIZPJ6K68EdboB4HNZkOj0dCsBMm+A9c0DyRTQeZRNVomsWGDCmKzzOfz8YUvfIE6a87OztLvmZ+fx8DAAJaWljA+Pg6hUIhnnnkG7e3t1BTIYrHAYDDUdH3c6qYjLyaxWIxvf/vbNxxPo7RcFgoFhEIhzM/P49y5c5ifn6cT+jZt2oSuri4888wz2LFjR0276FqHCFOj0SguXLiAYrGI/v5+KJVKKiK8HUSTc/r0afzLv/wLbbMio95JOrjRFyU2m31LN0ulUgmr1XpL5Tf5/1ud48NwnwDXnlciyCYbjpGRERw+fJh2exHy+TyOHTuGubk5yGQyNDU11fHIb4QMPGtvb0c4HEY4HEY6nYbf78eJEyfgcrmQyWRw9epV2oIPXLsXDAYDnn76aXzta1+DTCajniaNspbdL+l0GleuXEEwGKQD98jLdi2fF3Dt3iXutTezlK/+50a8jg0bVACfLH4kAr1elU+Mi+LxOJqamsDn89Hb24vm5maoVCpIpVLIZLJ7qhtePyGyESkUCggGg1hcXKxpMyMeDCRbU+209zBAREmpVIoOtBOJRNBqtTCZTLedMFkul+mUUrfbjXg8jlwuBw6HA4lEArPZfFOdQaNyq2v6sFzr5YAIs81mM3p7e5HNZjE6OlpjDkRIp9O0u4zY+TfKGsBisWAymWC1WjE5OYmFhQXamZZKpeDz+VAul6kHAam12+12aDQadHV1QalUPjQ7eeATW+lKpQK5XE4zlg/L/b+Wr1FDBxUAaFAhEAhoKpOwdetWPP/887Q/GwDdad4pK7GWiUQiePfddzE/P08HywDXFtFt27bhs5/9LKRS6ZpPcV4PsSBeXFzEL37xC3g8HmpVbLPZbjvgqlwuw+/3I5FIwOv1IhQK0WmfHR0d+PSnPw2dTtcQmhmGB4e0GfL5fKxbtw5dXV3o6ekBj8eDy+XC+++/X6NJIXqEUCiETCZDA5JGeElxuVzs2LEDDocDfr8f09PTKBQKiEajSKfTSKfTUCgU2Lt3LzQaDfbt2wer1QqVSkXLAsR+/2FaD4nVuN1uh1qtXvNdHw8LDR9UAJ9EbWtlF7nSkKmShULhhhS3XC6nY7EbYUFcTojXCJ/Pp7XEZDJJjdGqAyzSgldtbBaNRmlqWCgU0nSw0WikC/CtFl02mw2BQLDm1fKPEuRakuBCqVRCqVQiHA7XaIyI3ob4lkSjUYhEIjqois/n1/VZYrFYkMlkqFQqsFqtaG1tpWZYpFyr1Wqp0Nhut8NgMNByABGtP2yQwE+hUECpVD6U57gWYa7CGkQmk2H79u2QSqX48MMP6304qwZJZSsUCvyn//SfsLS0hKtXr9JRztXp7FAohMHBwRonQhKEbdq0CZ2dnejs7MSePXvoXIvbTeUTCARobW1dM9NqGT6BBAREiH19YFipVBAKhaiJ3MTEBB3S5nA4sHv37rrugqv9U37zN38TX//61wGAumeSYJt4EYhEItox9DAItG+GQCCAxWIBi8WivgqPojtwI8IEFWsQPp8PvV4Pv98PmUxGd+ukFPCwUt1v3dnZCYvFgnK5jKWlJaRSqZq+62KxWDO5EfhkAJPZbEZzczP6+vrQ3d0NkUhEd3S3guwIH9ZF+lGAZB2ImyxptQRqjfVisRjMZjMV8u7cubOux01KOQCooPhRh8PhQK1W004J4lXBUH+YoGINIhKJ0NXVBavVio6ODprm53K56OnpeWj8728Fl8ulNtIKhQL5fB6lUqkmU5HJZPDNb36TugiSr1UvQjKZjM6yuJuZGCQNzgQVaxOlUoldu3bBaDQiEAjA7/djcnISqVQKmUyGzn+JRqOw2+2w2+2w2WwP9bO0VpHJZDh48CBtr+ZyuQ/1hmotwQQVa5DqGrHVaq334aw6JNULYNU8BR42kdujiFAohM1mA5fLRUdHBxQKBfx+f43/SaFQoEJNMieHofEQCoWwWCz1PgyGm8A
EFQwMDI8ExCSJzWbjueeeQyAQAAAsLi5icnISkUgEu3btQl9fH8xmMxwOBzQaDRNMMjDcA0xQwcDA8EjA4/FoG7Fer0c6nUYkEsHk5CQAYHZ2Fo8//jheeuklWuYi3g4MDAx3BxNUMDAwPFIQTQyXy6XGUCqVCps2bcKWLVsgk8lqvofR0DAw3D1MUMHAwPBIwufzsWnTJlQqFTz11FPU76Q6M8EIcxkY7g0mqGBgYHgkYbFYjGESA8Myw6pcb8m4SggEAuh0unr86boSCASY836EYM770YI570eLR/W8PR7PDW7OhLqF6TqdDm63u15/vm5YrVbmvB8hmPN+tGDO+9HiUT7vW8H0SjEwMDAwMDAsC0xBkYHhIaBSqdyQjmT8FRgYGFYbJqhgYFjDkGBifn4eJ0+eRLlcRqlUgslkwp49e5h5JQwMDKsKE1QwMKxhKpUKyuUyxsfH8Vd/9VfI5/MoFovYuHEjNmzYAIFAwLRFMjAwrBpMUMHAsEapVCqIx+OIxWIYHx/H0tISHZ62tLQEr9cLNpsNlUoFgUBQ78NlYGB4BGCCCgaGNQope1y6dAmDg4Nwu91UVyEUCjE4OIh0Ok0zFgwMDAwrDRNUMDDcBblcDslkkk5I5XA4ddMqVCoVWuZwuVwYHR3F4uJijViTzWZDqVRCKpU+UoLNUqmETCaDYrGISCSCbDZ7w/dIpVIYjUbGgnsNQO7pbDaLfD6PdDqNZDIJFot1g/MpmSQsk8kgEAjA4XDoOHTmOq8eTFDBwHAHKpUKAoEA3n77bej1euzduxdisbhuw6ZKpRKCwSAikQjeffdd/OxnP0Mmk6np/pBKpWhra0NTUxP4fP6qH2O9yOfzuHr1Kubn5/HTn/4UFy9eBFD7Unn66afx3//7f4dEImECiwanXC6jXC5jbm4OY2NjOHfuHI4cOQKBQACNRkOfPxaLBblcDqFQiP3792Pjxo1Qq9XQaDQ3BCAMK8sjEVRc32r3sCwi5XIZ+Xy+5vx4PN4jZT1Mzn2lrmmhUEChUEAsFsPS0hLYbDbK5fKK/K27pVKpIJ1OIxqNIhKJIBqN0s9BIBBAqVTCarVCIpGAz+c/EpmKSqWCUqmEfD4Pl8uF2dlZjI+PY2FhAUDt/TE/Pw+fzweNRgOpVAoOhwMul/tIfE5rCSJCLhaLCIVCmJ6exsLCAubm5sDn8xEKheh15XA4UKlUkEgkmJ6ehkqlAgDIZLK6ZhUfRR76t0+pVEK5XKY3KIvFAp/PfyhusEwmg4GBAcRiMQDXAor+/v5HyjaWXFuS+lxuPB4PBgcHsbCwALfbTTMUAoGgbi+hcrkMp9OJ4eFhLC4u1gQ57e3t+OIXv4iOjg4YjUbaUvqwk8vlEA6H4Xa78eMf/xgjIyPwer03fc4HBgbwe7/3ezCbzXjllVdgMplgsVggkUjqcOQMt4IEz6lUCu+//z7+5V/+BYlEgpY/EokE/V4WiwW32w0Oh4PJyUn87Gc/wzPPPIOvfvWrUKlUMBqNTLZilXhogwrysikWiygUCjSoIBErh8O558CCfH+jBCTFYhHz8/MIBAIArgUV3d3ddT6qlYXsyEulEr2mJKgg13U5XqLk78TjcUxOTiIajdKXNwlg6nEfkHRwMBhEIBBAMpms+bpSqURnZydsNhv4fP4js5CWy2XE43EEg0FMTk5ifn4epVLpptcoHo/j7NmzsNls2L17NwQCAQwGQx2O+v6pzk6Sf64OLokWgUxevdfsJcn8AJ+sd6t9z5NjqM4UVioVei7k+SdrAfkcyCaru7sbgUAAHA4HWq2WtlY3yvr9sPJQBhWVSgV+vx9LS0uYm5vD0NAQCoUCMpkM5HI5tm/fDqVSCZPJBLFYjEwmg1wud8vfVy3+If+rJ+RBSqVSmJycxOLiIrLZLDgcDvbu3VvXY1tJyuUyCoUC0uk0Ll26hFQqBa1WCz6fj1QqhUwmg66uLjQ1NT3QwlGpVFAoFFAsFjE2NobDhw+
ju7sbn/nMZ2CxWOqW6SKLayQSwenTpzE4OAiPxwMA4HK54PP5MBqNaGtrq6k3PwrE43H86le/gtPpRDgcpi+am0GeHb/fj0uXLiESicBms0GhUKzyUd8/hUIBuVwOpVIJ6XQauVwOLpcLqVSK/vvc3Bx8Ph9eeOEFHDhw4J6C7VQqhTNnziAej8NqtUIsFsNut0Mul6/gWdXCZrMhFovB4/Hw5JNPgsfjIZvN1ohvU6kULly4gHg8jlQqhWKxSL92/vx5lMtlOBwOvPTSS1Cr1bBYLHVfvx92HtqgIhwOY3x8HOfPn8eRI0eQyWSQSqWg0+lQLBZhMpmwZcsWqNVqRKPRmlTa9bDZbJhMJlqfq3f5hETo+Xwefr8fi4uLiEQi4HA4SKfTdTuulYYEU/F4HKdOnUIsFkNbWxvkcjm8Xi+i0SiUSiWampoe+G+RHZLH48HQ0BBaW1vR2dkJpVJ5X1mu5aBUKiEWi8Hv92Nubg7j4+P0enO5XAgEAkgkEqhUKshkskei7EFIp9MYHR2Fy+W6QbR6PaR7JpVKYXFxEWw2G9lsFpVKpeF3seS8isUiMpkMstksYrEYzb5EIhEkEglks1mcP38eo6OjUKvV2L9//z3dD9lsFidOnEA4HEZPTw8MBgN0Ot2qBhUsFgs8Hg8cDgcOhwPZbBbxeBzxeJx+D9FaFAoFZLPZmqDC4/EgFAph06ZN6Ovrg91uh9FoXLXjfxDWsg7woQkqyO5yfn4eoVAIR44cwejoKBVlFYtF2oZ36tQpKJVKjIyMQCKRIJ1OI5PJ3PJ3s1gsqFQqiMVi9PX10ZeLVqsFh8NZ9ZcM2bF6vV7Mzs5ibm4OAoEAMpls1Y6hHpB2wVAohKtXr8Lv9yOVSkEul0Mul0OtVkMkEt33768Who2OjmJiYgLT09NQq9V0URWJRHV7WReLRczNzWFqago+n4+2TgJAW1sbtm3bhi1btkCpVFInzYedVCqFSCRCr5XX60U+n7+rny0UCpidnUU6ncaxY8cQDodhNBqhUqlokFbvzUOxWKSBdCaTgc/nQyAQgN/vh9frRS6XQyKRQDqdhsfjQbFYRFtbG4xGIz7zmc/gs5/9LJ588sm7vmdJuZhsykirciAQwMaNG1f2hG8COW6bzQaZTEbFuOSz8fl8GBkZQTqdRiKRqMk4k3KoVqtFb28vdDpdQ4vYyfqTy+Vw4sQJzMzM0ECuvb0dTU1NtPRaLBYRj8eRy+WQy+VqgikCj8eDRCKh/79a5/5Af6VYLOJHP/oRfvSjH+HSpUuIRqMQi8Xo7u7GF7/4RfzWb/3WqpwIuRiFQgGnTp3ClStXcOzYMQwPD98waCmZTMLv9wP45Ia92TCm6yEX85lnnsGePXuwefNmyOVyGkmvJoVCAYFAAPPz85icnITP54PZbIZUKr3jeaxlSqUSEokElpaWcO7cOSwtLSGRSECv12Pr1q1oamqCUCh8oL9Bgs/jx4/jww8/RC6Xg8FggF6vh0KhoH3v9SCfz2NoaAizs7Pw+/01aeCuri48/fTTaGtrW9UFpN7E43GMj49jcHAQExMTiMfjdx1U5PN5qkmSy+UYHx+n7YgSiaTuafJqTZjT6YTP58PRo0epQNfr9dIyCMmykFZio9GIzZs3w2azQS6X33VwVCqV6Bo5MDAAn88H4NpL/XbZ3JWCBAZqtRpqtZr+d5K11Ov10Ol0WFxcvOGer/atsNvttNOnUSEaklQqhb/6q7/C+++/j+7ubjQ3N+OrX/0qzGYz1YwVi0V4PB6aubxeWwWAZm2lUimEQmHjBxWRSATPPfccTp8+XfPf4/E4zp07h3PnzuHEiRN44403HvggbwXZVZJ6YigUwqVLlzA7O4tQKHTHYIEImAQCAfh8PsRiMTQaDQ028vk8wuEwcrkc0uk0CoUC3G43hoeHUSgUIJVKoVarV90LoFAoIBgMwu/3UxHqckCCM2D1RVm3I5vNUnX/zMwMnE4n7HY7dDod9u7
dC5PJhN7eXhiNRiiVynv+/aSURK5vOBzG/Pw84vE4Ojo60Nraig0bNtS124MsNrOzs5iZmUEqlQIAKjyTyWSwWCxQKBQNc91WAiJUzeVyNNNw4sQJzM3N1WRuquFyueDxeDAajejt7UUkEsHIyAjVzaRSKfrzYrEYhUIBHR0dq/4SIp1qJLBNp9NYWFhAIpHA4OAgAoEApqenqYZKLBbTFmKhUAir1QqZTIbdu3ejtbUVBoMBMpnsrsq15G+HQiEMDAxgZGQEqVQKHA4HLS0taG5ubojumGKxiGKxSD8b0pUVDodpMEnW9ba2NmzZsgX9/f3g8XgN/1yUy2Vks9majItOp4PNZkM2m4XL5UKhUKCZqcuXLyORSCAWi91UEygSiWA0GqHT6bB//34oFAqIRKIV3xjdV1BRqVTw4osv4vTp0xAIBPj93/99vPzyy7DZbHA6nfijP/ojfPDBB3jzzTcxPDyM9evXL/dxAwC9uVwuF/7yL/8Sc3NzuHr1KiKRyG2FWsC1xZg8lHq9HlqtFna7HZ2dnfTlkU6nMTY2hmAwiJmZGYTDYYyNjWF4eBitra2YmZnBunXr8O///b9f1Zs2nU7j6tWrmJ2dvetd2d1ASkgA6q4bqSYajWJubg7T09M4e/YsisUi1q9fD5VKhVdeeQVGoxFCoZBeg3s97mKxiGg0inA4jO9///twOp1U8PfFL34RX/rSl+pa9iiVSsjlcgiFQjh69Cimp6epMp/sxoxGI1pbWx/qFtLqcgDZqf/yl7/ED37wA2QyGaTT6Zs+82KxGGKxGPv27cMrr7yCiYkJ/N3f/R38fj99GQ0MDIDFYtEs58svv4y2trZVDSrIBikWi8HtdsPpdOKNN96Ay+XC3NwckskkCoUCyuUytFotTCYTmpqasH79emg0GuzatQsKhQJWqxUikeiengXyt4eHh/Hd734XgUAAoVAICoUCmzZtwubNm2syBfUin88jFothcnISf/u3fwuPx0NF29UW9WKxGI8//ji+/e1vQ6PRUIfNRqZUKiEUCmFpaQnpdBpsNhutra3YvHkz4vE4Xn/9dbjdbhw/fhzJZBLxeJxqv27mncPhcCAQCGgnWGdnJ9ra2hozqPjZz36Gjz76CADw4x//GJ///Ofp11QqFf7xH/8RNpsNADA5ObnsQQVpI0yn03A6nZidncXi4iJ8Ph8KhQIVUxL1sFgsrqmXZzIZsNls9PT0QK/XQ6PRQKFQwGg0wuFw0EU5m82Cy+UiHo/DaDTSPniv10sXNpVKhXA4jEqlApFItKIpJvLQEIEmCZ44HA6USiX0ev19ZUyqa3kk3alSqcDj8SAQCOqa8geu+XGQzz0cDqNcLtPrm0wmkc1m78vdslrwSurUPp8Pi4uLUCqVMBgMMJvNK35d7wRJScdisRt246QbSSwW1/04V5pyuYx0Ok13qaTmn0wm6WfCZrNpK61arYZYLIbRaIRarUZnZycMBgNyuRw6OjogFouRTCZpFwVw7ZlPpVK0pLCS4s3qdYzsUOPxOKLRKKampuDxeOByuRAIBFAsFqlgXCqVoqWlBTabDVqtFt3d3VCr1dDpdLRb4l6fhWw2C7/fj/n5eYTDYWSzWUilUqodUyqVdX8GyuUyYrEYZmdnMTo6CrfbjaWlJboR0mg09HprNBq0trZCrVZDIpE0dKBdvf66XC5MT08jHo+jUqkgGo3Srp5EIgG/349AIIBCoQAWiwUul0uvC9EMkv+xWCwUCgVaGlmtTeJ93SWvvfYaAOBTn/pUTUBBqE6TrYR4sFAoIJ/PY3h4GH/xF38Bj8eDiYkJ5PN5KBQKyOVyOvegvb0dHR0dKBaLdMGYmJgAl8vFr/3ar6GnpwdSqZTuRKvT/pVKBZ/+9KdRLpcRjUaRTqdx9OhRHD16FG63G6dPn4bX64XdbkdbWxt27969omJJUnMLBoM4e/YsPB4PstksBAIBNm7ciObm5vtK/5O67NzcHP78z/8cpVI
J+/btg81mw9atW+u6Q6lUKvB4PDh06BBCoRAmJyfpHA65XA6LxYKenh5s27btnvUU5Ly9Xi/efPNNeDwenDhxAolEAt/5znewc+dObNiwAWKxuK797alUCsPDw7hy5UqNjoLFYkGj0UCv16OpqYkuno2SYVpuisUihoaGMD4+jsOHD+P06dO0fRK4FlAQzwmNRoOnnnoKzc3NcDgc0Ol01LZZp9MhlUphZmYGi4uLNVqBau8D8tJfic+TZAULhQIOHTqEs2fPwuv1wuVyIZFIIBQKIZ/PI5FIgMViwWAwQKFQ4Omnn8aGDRvQ3t4Ou90OHo8HHo8HNptNBeP3E1xPT0/jjTfewOTkJEKhEAQCAXp7e2G329Hb2wubzfbAeqX7pVKpIJfLIZvN4tixY/ibv/kbBAIBLCws0BKWQCDAgQMHYLfbsXHjRjgcDphMJhiNxhUzxlsuSJbI7Xbjf//v/43R0VEsLCygVCrhgw8+wIkTJ+j9SLpchEIhOjs7IZPJqNbL5XJRnRnZ6JIshkqlgkKhWJVszT0HFalUCqdOnQIAfOELX7jp91y6dAnAtUVvORXD1Tv1SCRCdyuRSATpdBocDgd6vR5qtRparRYikQjNzc1oa2ujpZJEIoF8Pg82m00jfaFQeNsdPjGQkclkaG9vpy1MV65cQSQSgcvlAofDwcaNGyEQCFbM8pekulKpFILBIGKxGDX4EYvFkMlk97WbIGllstDmcjm0t7eDx+Mta3nlfo6rUqnQ8w0Gg3RnSVJ/i4uLVBysUqnoPVJt300WW3JNyEuDtOMFAgF4PB4sLS0hlUqhXC5DrVbDbDY3RDkhn8/T9rjrNQMSiQRarRYSiaRura4rSXWGkdSV5+fnqQcDgc/nQy6XQywWo6OjA2q1Gi0tLWhra4PJZIJSqaS7eKFQCLVajVAodMPzcn1QQVx4l/tzLZfLSCaTSCaTmJ6extzcHBYWFuDxeOjfJ+sZn89Ha2srlEolHA4HWltbYTQaoVAoaDBxP1T7sfj9fiwsLMDr9aJcLkMgEKC5uRlWqxVyubyuDrIA6PUnuhLSsVc9QE8mk0Gv18NqtcJkMkEul6+J2S6kqy0ajdKsOwmUyT1CtH88Hg9SqRQymQwOh4Ou+zwej2oy8vk8vWf5fD4kEglEItGqXcN7fgNduHCBppv27dt30+8hmYzNmzdDr9ff/9FdR6lUQqlUwsWLF/GjH/0I8/PzcLlcKJVK4PP50Ov1+O3f/m2sX7+eDpcRCAQQiUT0BUUuIHAtXXY3sxFYLBYNPLZt24auri4cPXoUAwMDyOVyOHToEL2JSWbkfjIGdyIajcLpdOL8+fM0JUZcQiUSyX2nKEl7Enm5kjQci8W6rSnYSkPSeEQcm06naSBF/vfWW29BpVIhl8thz549yOfzyGQyNKKXyWTo7OysefhImWNqagrnz5+H3+/HyZMnkcvlIBaLIZfL0dvbSzUK9SYSieDo0aNwuVw1PiRsNhsdHR3o7e2FxWKp4xGuDCSYSKfTGBoagsvlwk9/+lNcvXoVoVCo5nubm5vxwgsvwGq1UlGaVCqls3Cqd/K3I5fLIRqNIhQKIRaLQSKRQCqVLvuLKZ1O48/+7M8wNDSEmZkZRKNR+iLo6enBrl27YDKZsHv37pqShlQqpRuX+y1HVAdqFy9exKVLl3DmzBkcPXoU5XIZYrEYmzdvxu/93u9R8TOXy62rJiGZTGJpaYl68lwvUGez2dBoNLBarbBYLNDr9WsmyE4kEjh9+jSuXr1as7GpxuFwoL+/n3b1KJVKmj0i+rdDhw7h+PHjmJiYwNLSEhQKBTZv3oy2tjbYbDZa0l5p7vmuJFkIjUaDlpaWG77+93//9/j5z38OAPjd3/3dBzy8TyABQbFYxMzMDD7++GPqokaGyZB+5I6OjtsK98gFu5cdCPGjkMlkkEgkcDgckMvliEajWFhYQCwWw9TUFNhsNqxW64oEFel0Gl6vF4FAoCbtS4K
e+60dkh1LJpOh7qLELKxew7PIwkeib2L0Q3rUyddJym9mZgYmk4nWHol2RqFQUM0MaREkwdnY2BguXrxIO2lYLBaampqg1+uhUqkgEokaQtxFdujEh4BAJjMS/4w7Ud0NdTNznUazoSfp3mQyifHxcczPz2N4eBjz8/P0e0hqW6fTobu7G62trXA4HHf8PMjzTEoH1ZuOfD5PnwE+n78ibdqFQgGDg4M4evQo/f1isZial7W1tcHhcKCrq2tF/DJIq+rExAQGBwcxMjKCpaUl2gGn0+nQ0tLSMC6j5HnO5/O3FOFX7+bXQoaCkMvlsLCwQNvEib4HAG0hNRqNaG5uRktLCzZv3gyJRAK1Wl0zloA4n5J7n5QCq7U2DZmpGBwcBABs2bIFwLUHPxKJ4OLFi/iHf/gHGlC8+OKL+PKXv7wsB0kedp/PB7/fj5GREdpOScoYr7zyCpqbm2G32+nL4FZBw4MsniSNbrPZ8K1vfQtzc3P4+c9/jmAwiHfffReXL19GZ2cnzGbzA5/39WQyGSwsLCASidS0fgqFQnR1dWH9+vWQSqUP/HdIkKbX6+sm0iQiJDabjS1btuDVV1+F1+vFpUuXkEgksLi4iGKxSEs3w8PD9IEslUqQSCQwGo0IBoP427/9W5TLZVgsFkilUrhcLpqtmJ+fh1gsxt69e6HT6fD000/DarWio6Nj1R7CO1Gd1qxeTFksFkwmEzZu3EinMt4MsggTESDJSpF7SCgUwmaz0ZcaWZTrLfqMRCI4c+YMnE4nDh8+TIW61WzYsAG7du1Cd3c39u/fT7MTt0MgEFCXyCtXrsBisWB4eBh+vx+5XI62Vvr9flQqFcjl8mUPLrlcLnp7e5FOpzE1NYVgMIh8Po9yuQy3242pqSmUy2V0dnZCIpHclxD5VuRyOQwMDMDpdOLIkSO4ePEi1Sjt2LEDn/nMZ2gw0yiQbMmmTZtw4MABLC4uYmRkhL6Ac7kcDh8+jLGxMYyMjKCvrw/t7e1oaWmpi0HhvSASidDV1YV8Pk+zDiTLfvDgQXR0dGDdunXo6+uDWCymwQSPx6OaM9IZduLECfp+4HK5kMvlNEO7WhukBw4qnn32WfzqV7+q+Z5f//Vfx//6X/9r2S4i2ZWSgMLj8SAej9N6uV6vx5NPPgmLxQKlUnnHDogHOS4SqCiVSjz++ONQKpV48803kUqlMDAwgNHRUfz6r//6ff/+W1GpVJDNZhEKhWo8Cog4zWq1wmAwLEsQwOFwIJFIoFAo6vpiIYuByWTC448/jtHRUTpYiCjigWsv3bm5OcRiMXpPkIxDPB7H22+/jVAoBKPRCJlMhng8jnQ6TV+2IpGICt+eeOIJ6qjYCFkK4JNM0vVZI5LyNRqNkEgkN72vq3fgJEND1PPk98lkMhw4cABqtRp6vZ7+rnoHFdFoFMeOHYPb7aYC2utpa2vD3r174XA4YDab72qHyufzodFoaDBeqVQwPz8Pv99PW1aJ2I2UTpcbktEkzrhEL0PGfHu9XtqdQua6LNf9WCwWcebMGYyMjODs2bM0sCbGWU8//TR9ETUC5CXL5XLhcDjQ2dkJgUCAiYkJGlSQzM/Q0BCCwSDcbjcOHjxI5wA1yrN8M3g8HnQ6HXQ6HQ1+iG/S+vXrceDAAVitVio6rT6XQqEAr9eLq1evYmRkBBMTE/RrHA6HtteuZlB1T6tGJpOhB02CiosXL97wfT//+c+xceNGfOc733ngAyRtZJlMhppqjY2NAbjm8rZr1y50dHSgqamJzmVYDUjrmlAoXJWLRdzyWlpaEAgEqC++Uqmkoq37bSlks9kQiUR0aFo6nYbb7abtbvVGqVSip6cHZrMZnZ2dCAQCGBoaQjQaxZUrVxAOh8Hj8VAsFmkLHEkXBgIBGtGnUina8ZHP56HVatHU1ISOjg4cPHgQBoOBGh41QobiVhBhLhFpElvuaqoNvchOZmBgABMTE4hGozUj08ViMSKRCKRSKQ3MOzo6YLV
aIZVK78mRcTlIpVKIx+OYmprC5OQkPB5PjWCYzWZjw4YNaGtrw+OPP46+vj4qWryX4+RwOGhtbUWxWMSFCxfof69UKkgmk1hcXIRUKl2REiCPx8OBAwfoTApi5DQzMwM+n4/R0VGEQiEUCgWo1Wp6jmazmQoQ7/alTwLLZDKJoaEhuN1uXLhwAbOzs9REa//+/ejv76d+FI3kUwOAruvNzc34/Oc/D5/Ph56eHvj9fhw5cgTRaJS2W8/Pz9Ndv1AohNlsRnd393372Kw0XC4XCoUCWq2WBsVyuZy2xTY1NdFZPtcfO8lAkpEF1ZDGgvb29lXdINzTX7py5QqNDElQMTMzg0QigdnZWRw9ehSvvfYalpaW8Gu/9mvgcrn45je/+UAHSNo5w+EwTpw4gddff53uHLq6uvDSSy/BbrfDZDKtarqOxWJBIBBAKBSu2guIiA7n5+fpDWYwGGC1WulskvuBw+FAJBJRO9dUKoXR0VFEIpEbbtR6QFJ4drsdGzZsQDabxc6dOxEIBPCTn/wEk5OTCIfDSKVSUCgUMJlMsNls6OjogEQioXVxMhKZoFar0d/fj3Xr1mHPnj1U9NTIAQXwidJdpVLRIOD6YLpcLlN9ya9+9SuMjo7i/PnzuHr16k133iTrpdVqIZPJ8Mwzz+DAgQNoa2uDTCZb1YU4EolgfHwcZ8+exblz55BOp28IKnbt2oW9e/eir68Pra2tNd09dwPZvdrtdgiFQhw+fLjm64lEAm63GwqFYkUyFQKBABs2bEC5XEZbWxt8Ph8+/vhjiMViLCws4NKlSyiVSnjvvfcgkUjw5JNPwmw246WXXkJ3dzetkd8N1eXjv/iLv8D09DRmZmbo1GaZTIbHHnsMX/3qV6mgudFevNVZS6PRiHw+j8ceewwejwfBYBDDw8Pw+XxUn+ByubC4uIjZ2Vls27YNra2tYLPZdc++3QwulwulUlmjkSATVR0OB81Q3CoTGQqF4PP5agatAdc2Y+vXr4fJZGrcoIKUPoxGI6xWKwDQYU4WiwWPPfYYfu3Xfg0bN27E0tISXnvttQcOKorFIlwuF513QLoduFwu1Go1mpuba6y1V5PVfvBIK5xarYZQKKRjj0lJoHoyJXkISUbldsdanXWpNrtqpDki1cdPzL64XC527tyJ5uZmpFIpZLNZagIkEomQSqUQjUZvudOMRqNwu93g8/kYHx+HRqOpaz/+vUBS4mSxqPZWIVM4XS4XgsEgZmdnsbCwQI3DCNW6lWrzp0qlQruMuFwu7HY71Vis1D1f3d7o8Xhw7tw5zM3N1QxL4vP5aG5upqJMh8NBbcnv97hu9bPEdEipVK6YWJn8bblcjkqlQoMMYkEfiUTgdDoBgM75IOPIjUYjDAYD3QwQa+qbnUsikcDCwgJGRkbgdrsRiUTAYrEgEonQ398Pm82G3t7eewpU6gU5Pw6HA7FYDK1Wi+3bt8NsNuPq1atUyB6LxZBKpeB2u6HX6zE+Pg6tVtuwo8+rrxsxveLz+QiHw0in0+Dz+RAIBLTdmKz9ZNJuLBajWWWSdTabzdBoNJBKpav6frynoIJ0fvT399/ye4xGI55++mn84Ac/QDQafaCDA6716L/55pu4dOkSLXuQNlGbzYaurq5lrTc2MiKRCBaLhfbhk7pvuVzGsWPHqHqbCHTILkStVt82NUxeGCqVCkqlEslksqE/Tx6PR4d82Ww2qrmpDoL8fj8+/PBDLCws3LItljjyaTQaTE9Pw+Fw4I/+6I8aPqggwQAR712/IBUKBaolmZmZwYkTJ+DxeGpU5QDofcJms6mqnpQeDh06RCd39vX1QSqVrkhrZfVxJxIJRKNRHDp0CH/3d39H3TOBa4GvXC6nIsL9+/fDbDavmLHR0tIS7cy42TyR5YLNZtPnTq/XY/PmzdSWfmZmhpqyjY+PI5vN4uzZsxAKhdi0aRPWrVtHBaqk9fVmz+3ExAT+63/9r/B4PBgbG6NCZrlcji996Ut
4/PHHodPp6PVttCzF9ZD7XyaTQSgU4lvf+haSySQOHTqEyclJfPzxxxgcHEQkEqGi5Gw2i9bWVvzO7/xOQwYV1ZCNdCAQwOXLl2GxWGAymaDT6WgwkUgkMDAwAI/HgytXrmBmZoZmKjQaDfr6+tDT0wOLxbLqTrv3lakgpY9b4Xa7AQDd3d33eVifQKK2akMQuVxOrVjrqey92cCylTwOItKRSCQwmUzgcDhYWlqiNtbVfhoqlQoajaZGKUxS+9drBqq7Ycg/k4FGxGSL/HcSMddz4blZ+2N1mynZWS4tLSEajUIgENA6NIfDQTKZRCaToda4yWQSbrcbcrmcWjY3sjMledGRbpfqe5AYwy0tLcHr9cLr9SKRSNBOKdIWrVarIZPJYLVaqd15LpejM27IUKt4PE7FgittBEY8RsiQJNLdQ8TIYrEYFosFra2tkEqlK5o54fP5UCgU1E11JSGfKRGYVyoVmEwmsNlsdHd3Q6fTYXx8HNFolLoCe71eKqRuaWmhayKfz6ddS8Sxc2FhAfPz84hEIqhUKhAIBGhvb4fZbEZLSwsVtzd62a8ack3IfcnhcNDW1oZyuQyPxwO32410Oo1kMkk1YmKxGLFY7L6tzFcKUooTCoXQ6/WIRqP0mSVdSAqFgmYpSPA9PT0Nn89H12hSIiTrP+laWu3zvOugolAo4OrVqwCATZs23fL7FhcXcezYMQDAM88882BHB1DXxLm5Obowbtu2DXv27MHOnTvrGlAQAxkCeXmv1EUkN4jRaMT+/fsxPz+Pt99+G8FgEP/6r/9KFwZyk/J4PLS1tVEjnV27dkEul0OlUt02Wi+XywgGg4hGo/jnf/5nnDx5EkKhEEKhENu2bcOWLVsa5oGsHl8PfGKY5PP58N577yGZTFLnOVIWOXPmDIaHh2l2I5VKUY+RQCAAqVTa0KlgUh9PJpMIBoPU55/L5SIQCOCv//qv4fF4cPz4cUQiEWr2JhQKIRKJ8MQTT+DZZ5+FzWajA/RI8PDaa6/h8OHDiMVidKzy0NAQmpubV7QbiGQqfD4fnT1Brimfz4fBYEBHRweeeOIJtLa2rvgAv6amJjzzzDPo7e1dtfuArB1CoZC6Z/b09FDzL6/XizNnzuDy5cvUm0UgEOCnP/0pdDodduzYAb1ej927d0Oj0eD8+fMYGBjA5OQkpqamAIC6DP/hH/4hNUVq9NkYt4OUcQQCAXbu3IktW7bAYrHAYDBgZGSEDt8aGhpCJBLBiRMn0NnZie7ubsjl8nofPoBPAiOj0YinnnoK7e3tOHz4MAKBAC5cuIB4PI6XX34ZJpMJ8XgcAwMDmJ2dxU9+8hMaPBHfEQAwGAzo6upCW1sbDbxX8x151yvEyMgIjYR+8IMf4IUXXrjhQEulEr71rW9Rm+OvfvWry3KQxPiIiLHkcjmamprqPuaZ7ASIlW91NmAlIOdKxhwTx0gyop2kt6t3rplMBjqdjr5cC4UC+Hx+jQEY+f5kMkmH0pDOgfn5eRQKBdqa1NTU1NA7+WKxiGQyScVL5XIZ69evh0KhgMFggFgsxuLiIhYXF2lNknQYxeNxJBIJZDKZhhikdjtIpoIo3okugozxdjqdCIVCSKfT9LkhGQrSRaPT6aDX68FmsyGVSpFOp2EymaDVaum0zEQiAa/XC6VSueIam2w2i0gkgmw2W6NjIL4pZPd1N0Zf98LNRJ5CoRAajQYqlWrV7vPq1kfS3cHhcCCXy6kImQz8IhqZTCYDv98Pv98PrVaLaDQKi8WCYrGIyclJjI2NYX5+Hrlcjp6T1WpFc3MzzGYzbdVcq5AXJimDCQQCWCwW2Gw2BAIBcDgcKlomGW+5XE5Hud+rwHelzoEEk1arFblcDhKJBOFwmA6x9Hg8yGQyiMVi1AjP7XbD7/ff8PuEQiF1lK1HOeuu7yZS+gCAN954Ay+88AL+6I/+CD09PYjH47h8+TL+23/7bzh37hwA4K//+q+XxYip5mD/rfZvNpu
xfv16qNXqur3YstksnE4nFhYWUCwWqZLcZDLVDFRbCdRqNT7zmc8gEAhAIpFgcXERhw8fhs/no0EBIRAI4MiRIxAKhTh//jyEQiHtQSemOvl8HtlsFsFgEJOTk8hkMjRAGRoawsTEBO10kcvl2L17N0QiUV3HgV9PLpdDoVDAhQsX8MMf/hButxuBQACtra341re+BZvNRmuLe/fuxezsLE6ePIkf/ehH9GeJL4LP58O+ffsaXltRKpUwNTWF4eFhyGQyyOVyXL16FaOjowgEAnQGgEgkglAoxEsvvYQDBw6gubmZ7vbJS4y8XJ588kmoVCocO3YMv/zlLzEzM4Nf/OIXcLvd2LZt24qVvsrlMiYmJvDhhx9icnKy5mvEE6alpWVZny2ymJMXOJfLpeUzcq+v9LN8p+Mjx9XW1kZnmSQSCQwPD+PMmTMIBoO4fPkyNbRis9m4evUqJBIJbSUulUqQyWTYsGEDvvGNb1Bnxodtqi25ni0tLdQE7dy5c7RcEIlEcOTIEYyMjEAoFNLpris5BPJeEIlEeOqpp+gcFuBaJ5Tf70cmk8HFixepczPJJF4Pi8WCSqWiE1rr8X6856CCy+WiWCzinXfewTvvvHPD90kkEvzVX/0VvvjFLy7fUf4b1TVhMiCoXkFFoVCA3+9HNBqlQQXpPFhpIRApRQiFQmzcuBEymQzDw8P0Jqu2YiYWzwAwNTVVs9tUKpWQyWTIZDJIJpPUBr0a4mAoEAjA5/MRCoWQyWSoULARIC1z2WwWIyMjeOutt5DP55HP5yGTyehQKZIKFIvFMBgMCAaDtIuGWDMvLCyAzWZj+/bt9T6tGm5W5qtUKggGg3A6ndBoNFRbEwqF6MA14Nq1k0gkaG9vx6ZNm+h1r/595LMxmUzo6+vD8PAwACAWi2FoaAh6vR7FYpF2X63Ec7e0tISFhYUb5noQr4GmpqZlfwlWd0lxOJyaQXQrNRjwbqnOXJANGvHN4PP5iMViWFhYwOLiIsLhMObm5pDJZGpszAHQuUBWq5VOHW7k8t79Qu5J4gpLpjYTW4JcLoeJiQkEAgHMzMxAq9XSFtpGgMvl0vcHsdeem5tDMBjEwMAApqamqAbsZms1yUoQDZlIJGrsoIJ0fvzBH/wB5HI5fvKTn2B6ehq5XA5yuRxdXV04ePAgvv3tb6/IcKNqV0FSDqmHipe08/h8Phw5cgQulwuZTAYcDgdWq5XuiFcDgUCA/v5+dHd3Y926dQiFQgiFQkgkElSoF4lEaPaBuPaRhbPa1rr6vxM4HA66u7tputRut2PLli10ME2jZCkqlQrcbjcmJiYwPj6OcrmMnp4ePPPMM2hra4NKpaqxbRcKhdBqtdi0aROeffZZuFwunDx5krZiVtfzGwG1Wo29e/fC5XLh3XffRTKZBHBtd3/hwgVEo1Fq+uV0OqnIC7imRyBtgxs3boRGo7mlHoGkgqszGGSgl8/nw/T0NEwmE/R6/Yo8exqNBhaLBUtLSzUvxnw+j0AgALlcfkMHy4NA7gWpVAqJRAKxWEyzdE6nE0ePHkUwGLythmy1IQGdyWTC008/jUAgAKvVioWFBfzzP/9zzeROco0PHDiAr33ta7BYLDCbzQ99txwJkNva2vD8889jbm4Ob7zxBjKZDB1GNjs7S8uBOp2u3ocMAPTZI105u3btwgcffICBgYEap1cyVPF6yPOr0+nQ1tZWN63MXQUV5XIZV65cAQBs3LgRL7/8Mv7wD/9wRQ/sZhCley6XQzabhUQiQaVSWdVorFoIeOTIEcTjcRQKBeqAtpomXHw+n84YaWlpoa55qVQK8/PzGBwchNvtRi6Xo7tXIuoDQJ0lq7l+99rd3Y2Ojg5s374dvb29UCgUt7SErheVSgUulwsff/wxZmdnUSqV4HA48JWvfIVmtKoXUYFAAIFAAJvNhq1bt0Imk+H8+fM0cCUdII2CRCLB5s2bIRKJcPTo0ZqgYmRkBJOTk/TeI3V2cvxcLhf
t7e00OLxTSfL67iAy1XdpaQljY2MoFAp3FPreD8T6XqvV3jDEqlAoIBKJ3HT0+4P+TT6fD5FIBLFYDKFQSLUqfr8fR48eBZvNXtGW0nul2ttCLpfDYDBALpdjenoab775Zs04eLI2rl+/Hi+88MJDVeq4HaRMbjAY8Pjjj0MsFuPdd9+lGqp8Pg+32w2JRIK+vr56Hy6lutzV19eHrq4uqvGKxWIIBALI5XI37TokP0/8m4g7Zz24q786MTFBnRXXr1+/ogd0N5C652rtJiuVCh2n7fV6MTc3h48//pi2uIpEIigUCrS1tWHjxo3LriW5G8hLXiAQ0IFnQqEQsVgMra2tiEQiGBoaQiwWQzAYRCqVoiI/DocDgUCAXC6HeDxOU9xisRgbN27E7t27YTKZoFKp7jhXZTWpNkyanZ2lbXd6vR46nQ5yufy2JTI+nw+9Xg+v1wsOhwM+n4/Ozk50dnbW5RreCpFIhJaWFiQSiRsWiupJnj6fj7bKEsGa2WzG1q1b0dfXd8uhY6VSibbZEk0GqemSZ4xY/loslmVdrMgmIZ/PY3Z2FnNzc1haWgIA+sJvaWnBvn370NLSsqyj6El5MJVK0dZDEmQbDAbs2bMH69ata+iXcTabxdDQEKampqgo9/p1MRAI0JZppVLZMBnGlaJUKtFBl4ODg5iZmaGZOwDUmbOrq6thOkCuh6xHW7ZsgUKhQCwWQzQapULNRCKBwcFBJJNJRKPRmvOrN3f1tJDSh1AoRHt7+4oe0N1AsgWrMZabtI6Gw2EEg0G8/fbbeOONN6jTGVHO6/V6bN++HV1dXXURd5GFgojyZDIZLBYLyuUydu7ciWw2i8uXL8Pr9WJiYgLhcBiRSASJRIK68iWTSUxMTKBQKFAzrKeeegobNmyoqaM3SpaCvJBSqRRGRkZw9OhRmM1mOguGjDu/1fGSF6/X6wWfz6d2xe3t7Q1TZwWuHWdrayvy+fwNGQJSvopEIohGozVjtHfs2AGbzYYnnngCVqv1lp9DsVikL/Nf/vKXuHjxIhYXF2teTgqFAu3t7dRIbbkgyvx4PI7R0VGcOnWKZmIEAgF0Oh16enrw6U9/etlnkJTLZSSTSfrZVQvfOjo68LnPfW5FdBzLSSwWw6FDh+ByuejnRgIL8lk5nU58+OGH6OrqwrZt2xpqY7ASkMFsLpcLH330UY3HEQDqaUHmnDQaJIMGAOvWrUNvby8KhUKNd5DP58P//b//F7Ozs5icnFwWo8nl4q6eFiLS7OrqWvVaHJklbzab6dAYr9eLgYEB+gLn8XjLMoOj2gKVjCEm6bLJyUksLCxgdnYW4XCYigCVSiV27NgBs9lMRTb1rldeb2ZFFhkyzVIkElGDlVQqRR1Kg8EgdWP0+Xx0ABdpmW20HQ7RQFTvMI1GI7q6utDU1HRHQSGp1cfjcSpukslkkEgkdb+G1ZBMkkwmg81mQyaTQTgcvmF3Uh0ElEolxGIxSCQSLC0t0XOqFiMWi0UkEgmk02lcuXIFXq8Xi4uLiEQidBEmglxipLPcAWV1toBkz8h5kXuOlGOW6/4jz3kul4PL5aKjxwHQv0PEjfUUg9+OTCaDeDyO+fl5OuWU6ELa29uhUqkwNTVFn+OxsTHweDxs2LABAFbc52O1qDa9y2QytL0+HA5jcHAQPp8PsViMDk3U6/XUiVculze8WPV6Hx4yVVkgENB7eDU21/fCPQUV9Sh9cLlcbNq0CXw+H8ePH4fT6cS7776LY8eOUWGowWBAa2vrA0fghUKBDmGamZlBIBDAxYsXEQqFcOnSJeptkM1moVQq0dLSgs2bN+P3f//3odVqaRtTIz2s1T3QDocDlUoFfX19N9TlWCwW/H4/xGIxXC4XPX9isnS9LqERIBkk0nIlFAqxefNmfOUrX6FjhG9HNBrFqVOn4Pf7oVKp0NzcDJPJBKVS2VDXkMvlQiqVwmg0YufOnTCZTDh27NhNe9QJ6XQag4ODmJ2dhUKhQEt
LCwwGA9RqNTXKITqUaDSKyclJJJNJJBKJGp8IqVQKq9UKvV5fI3ZdLkqlEoLBILxeL/x+P+Lx+IqXNUm5JxAI4Mc//jEuX75Myz2k5KJUKmkJrdGCaeBa9uH111/H9PQ0zp49S903rVYrfuM3fgNtbW34yU9+gjfeeANLS0v4/ve/jwMHDmDHjh3QaDRQKBQN/0K9G8jGIp/PY3BwEGNjYzh9+jROnz6NVCqFSCRCjQA1Gg0+//nPw263Y+PGjdSjZS1Q7YxJrACIBfn1urh6c1dBxeXLlwFcS8WsNmw2GxaLBfF4nA6RymazSKfTdJdRKBSg1WpptqJ6h052O2R3cjNIDY4YpEQiEVoiWFhYgN/vp05/AoGATsLs6elBS0sLNBoNzZg06k1KREC3g+x0iDkWMfe63gq6kSCdK9WtVGSIzq1efiSgymaz1CCKGCs14jWs7lppbm4GAFitVupCmc1mb/gZogNisVhYXFwEAMTjcahUKmpu5vV6MT09XROYkWeBZCg0Gg26u7vpnI3lhkyPXVxcRDKZrNl1kdkkQqFwWQIZct65XA5LS0vw+Xzw+Xzw+/30M5TJZHQEgEgkargR4KTsG4lEsLCwgKWlJeTzeYjFYthsNlgsFhiNRiiVStjtdjQ3NyOfz2NmZoaKtUUiUcNqCe5EtZ6OtJKTacqTk5OYm5vD7Ows3G43XfMFAgGam5thsVjQ0tKC1tZWmrlrpGt7O66f70NKv2R9BkANCus9RuGuggriVVAP+Hw+9u3bh02bNiEQCCAUCiEWiyEej2NoaAh/+Zd/CaPRiM2bN0OpVNLZ82QUOBm0k0gkqNV3NZVKBTMzM/B6vdS9LJVK0QiQ9ASz2WwolUps2rQJPT09WL9+PQ4cOACJRAKFQrGmbtC7gTyQZBFr1KAC+MTwpbm5Ge3t7XQuyq1egkTc6XK5cOrUKWi1Wmzbto2OwW5UpFIpvvSlLyGZTKKjowOzs7M4dOgQLl26dEPmibSfpVIpHDp0iDqE8ng8ej2JayYp95E6PJvNpuPjH3vsMbz66qtQKBQr0vdeKBTwwQcf0KFI1RDr6Z6engfKkpHPxev14sSJE3C73RgYGEAwGMTQ0BCSySRdF/bu3YuDBw9i/fr1MJvNt72PVptKpUK7YE6ePIkjR46gUChALBbD4XDgS1/6Es22CQQCbNu2DXq9Hu+88w6GhoYQDAZx4cIFtLe3Q6PRrEltBemEyOfzSKVSCAaD+OEPfwiPxwOn04l4PI5YLFazzvf09OC73/0ujEYjmpqa6DTmtbheVyoV6v4bCAQQiURQKpXAYrHgcDjQ3d2Ntra2uh5j4yqQ/g1iy83j8Wg9jLS5pdNpTE5OIhAIgMfjQalUIpVKQS6XQ6/X0xn1uVwOkUgEY2NjN7SHlctlTE1Nwel0wufzwe12o1gs1tSUiWhRIpHAarWitbUVXV1d1FCpUS2rH4TrMz2NSqlUojNhyCTNOy0YxCiL6ErIhEgyxKlR4XK5UCgUEAqF6OzsBIfDwfDwMG31vD5gJvqg682kbgbREhCPCo1Gg6amJrS0tMBqta7YTJtyuYxAIACn00lnlBBIKaJ65POdsk83C37Jf49GoxgbG4Pb7ca5c+douYf8XmJg197eDp1O13BDtsicGr/fj6WlJYTDYZrNIUGgwWCgxy2Xy2tKV7lcDoFAAAqFoqFapgnVwwDJdaz+/EkgTGbDxGIxzM3N4eLFi5ifn6d6HHIfkwGCVquVBlJksOBahQwTJBk38j5jsVhQKBTQarXLLmi+Vxr+0yVpbS6Xiy9/+cvYu3cvTp06hYGBASwsLGBsbIzahPN4PAwPD9OJhsRaWigUIpPJIBAI3PAwEbe1VCpF/S/IZES5XI4nn3wSer0eGzduhN1uh16vp+n1hzWgID3eMpkMZrOZnmujQIRZ+XweAwMDGBkZweLiYk1aFLixS4V8jdShp6en0dPTg+7ubjz77LNQKpUNPxaZlLHWrVu
H1tZWsFgsNDU14dKlSzhz5sx9ZZR4PB6sVisUCgV6enqg1Wqxfft2bNy4EUqlsqasuBLnQ9Lx1R4LwLVpxx9++CHcbjesViudunuzaxSJRHD16lU6Frr6OSf3y9TUFE6dOoV4PI5QKIRyuQyRSAQOh0MX4+bmZmg0mmVtXV0uKpUKjh8/jp///OfweDzU2OyVV16BzWZDa2srtVsnzrF8Ph92ux1yuRyFQgFDQ0NIJBJ4+umn6306FBL8xmIxTE5OIh6PY2ZmBiwWC5s3b4ZcLkcymUQ2m8XAwAAuXLiAQqFAd+zz8/PIZrMwGAyQSCTo6OiA0WiEwWCAw+FAc3MzneDaSEHicsJisWCxWNDZ2QmNRlPXY2mcN8UtIIYeZLYGMXsi7XPj4+PI5/M3LEgP8vfIwi2RSNDZ2YnW1lbs2bOH7gIa/cXzoJCAitToGlFnQOqp8/PzmJ6eRjqdBvCJPuZmL0HycnG73Th79izK5TL0ej2sViusVuuamYXA4XCgUCggk8nQ19eHYrGISCSCc+fO3XWpqvqz4fF4MBqN0Ov16OrqQktLC7Zv347W1taVPI2avy8UCm/IhMRiMYyOjiKXy+HSpUuwWCxob2+/qYfIwsICPvroIyQSCYTD4ZoWQnLdXS4XZmdn6efD4XAgFoshEono+RPr5kbTUgDXNkBjY2N49913IZFIaNv4nj176P1Q/Rny+Xzw+XyqDymVSpiZmaEmX40CyUJFo1FcuHABS0tLuHz5Mu1UstlsWFxcRCwWw4ULF/DLX/6yJqsBXNsIKZVKGI1GtLW1ob29HV1dXWhvb6fmZg/jBrAaiUSyKrOn7kTjr6D/RrUXf1dXF2QyGRYXF7Fx40aEw2EagVcLV4Br5jDZbBY8Hu+m7WEsFgtGoxFarbYm3apQKKBUKrF3717qUMbn89fES+dBIWlGkk68Xdq5npDjJCOzY7EYjhw5AoFAAKPRCIfDQa9XsVikqW+Spejv78dTTz0Fs9lMTcPWEiwWi46utlgs6OrqQiwWg8fjoUFGMpmk963FYqGGWHa7nZY8hEIhWlpaoFAooNfrIZPJoNVqV+Uc+Hw+nn/+eaxbtw7vvvsuzp8/TzVTJNXr9/tx7NgxyGQyaDSamwb14XAYo6OjNZNnqyFus5VKBXw+nxpBHTx4EDqdDp2dnTAYDLBardBqtQ3VGVEul5FIJKinBgBs2LAB+/btw8aNG6kp3a2eUYVCgW3bttGZFzfL2NaTUCgEp9OJoaEhHD9+nArkyZovk8mQTqeRTqcxMzNTEzjzeDyYzWZotVq8/PLL6OnpgcFggFKphFQqpa3QD2tAUT0wsKOjA5s2barroE1gDQUVAOiDbjabYTabkU6nsXHjRvh8Prz++uvUorg6Co/H4wgGg3SI1PUfNpvNRltbG+x2O7UJFolE0Ol0dEbEWhQ0PQjVVujEr6NcLjdcSylwLWiMx+NYWlpCIBDAu+++i/HxcWzevBmvvvoqncOSTCbxgx/8AKdPn0Y6nUYul8PBgwfx+OOP0/rrWlt0WCwW1Go11Go1rFYr+vv7EQqFcPHiRczNzWFychL5fB4SiQQCgQAbNmxAZ2cnNm3ahP7+ftqmxuVyacvwcvpB3A0khd/d3Q23241gMIjZ2VkaVJCW0/fee2/Z/ibRZ7W2tuK5555Dc3PzXVmY14tyuUzHm8diMWq9/fLLL9POs9s9m3K5HN3d3RAKhRgcHITf72+YoKJSqcDn8+H999/H+Pg43nvvvRptkNPpvO3PCwQCtLW1oampCZ/97GfR0tLSUOLalYbFYtGsVVdXF+x2+7K3fd8rayqouB4iXGOxWNi/fz/i8fgNZiDEgpfP50OhUNxws5FaFKkdi0Qi8Hi8Gs3Eo0a5XEY6nabWzUT81EiQXUxbWxsymQxSqRRCoRBtF5yZmcEHH3xAd7XZbBYLCwvIZDK0Q4SYua3lXQwxNiNeJCqVCr29vbTllBh7cblctLa2wmazwWg
0QiqV0vOuNpha7c+BXEeRSIQtW7aAw+FgcHAQuVyOTmS81+4jcg6kbEd8GSQSCVQqFTQaDTo6OqDX69HS0kJb1RuVanHe9e3dd3O9iDCZZB0bzSyJmLndzqOEBP5isRhSqRQKhQJ2ux0KhQLr16+HwWCg6/tafZYflJXwkbkfGvdJugv4fD6d6ka0FtdDREBk4bzVB369BXWjWVKvJuVyGfF4HPF4HJFIBNlstqEWXaLU5/P52LNnDxwOB4LBIGZmZmia1OVy4ciRIzdoB3g8Hvr7+/Hcc8+hqampIfUi9wqx9SUlPiLU2rVrV80iTc7zZh099brPq4999+7d6O/vxzvvvINIJAKfz4e5uTna8nq3kGedZGj6+/vhcDhgMploQNnT00M7uxphIb4dxFMlHo/f4ElyN8FWsVikgwYbqTWcCKfj8ThcLhe8Xu8tAx6xWAyxWIzm5ma0tLSgpaUFBw8ehFwuh91up1q3RsymrjTkeV7tLOOtaJw3xX1yp7bH6mDiUbzh7pXq1rzbtek1AqSNqlQqoaWlBevXr6dBBfEaAUBNjFpbW6HX69HW1ga9Xt9w01YflOtfjmvlfq/OLLBYLLS2tqK/v5+2P5IXaj6fh9/vrxFhXg8ZEieRSNDe3g65XI62tjaaoSGOqWT8dyMswneCpLhVKhUt50UiEczMzCCfz9OWexIgk/uAZCai0SiCwSAymQzMZjNaWloaapNAhLpkhACXy6UjDyQSCbhcLhWbGgwGWCwW2Gw2ep1JJm4tXMuVgM/nU0F9I9A4d9YK8Sinw+4HkmputBTpzSDTBvV6Pb72ta9h586d8Hg88Hg8mJ2dxVtvvQUWi4W2tjaYzWa88sor6O7uhs1mo9MamXujMajOWGzbtg2bNm1COBzG5OQkYrEY5ufnEQwG8a//+q+Yn5+/5e9RKpX47Gc/C5vNhueee67GCK263LOWNDQcDgcWiwVqtRp6vR4AcPz4cczNzaGvrw9f/OIXodFo4HA46GfIZrMRjUbh8/lw8uRJfPjhhzCbzfjc5z6HlpaWuncIAJ8EkzKZjA5uCwaDaGpqwgsvvACj0YiWlhZIpVIqNhaJRLQN+Prrulau53JCyntNTU004Kw3D31QATyaJYz7gQzlIeOzG53qdmO1Wg273Q6JRAI+n0+zF6RbSKPRoKWlhQ5Va5SonuETqu3IiTW31Wql3Q0KhQJdXV23/R1WqxXNzc1obW2ldfa1DrnPhUIhFAoFVCoVbY3WaDSYmppCKpWCQqGgLbIcDgfBYBButxuBQADFYhECgQA2mw12u72hslgSiQRGoxEA0N3dDYPBgI6ODmg0GpjN5hoRMY/He+SE87eDuAmTZ6QReCSCCoa7I51O4+LFi5ienkYikWiohed2kC4IhUIBh8OBzZs3I5/P41vf+haAWu0NEXwxND5k1km5XEZ7eztKpRIOHDhwW31F9XVuRAOr+4H45rDZbOzfv5+Oqh8YGMD09DRee+01CAQC6rGh1+shEongdruxsLAAPp+PrVu3oq+vD0899VRDmbyxWCy0tLTg1VdfRbFYpO3/xHOj2mDwUc1G3A6BQIAtW7Zg8+bNMBgM9T4cAExQwVAFcbWLx+Pg8Xh0XsRaeJBvtkiS3Q/D2oQIKatRq9V1Opr6QvQCer0ePT09yGazkMvliMVimJ2dpS2iHA4HZrMZcrkci4uLCIVCaGtrw5YtW2CxWKDT6Wgmo1EgJQ2Ge4fD4UCn08FqtTbM3CImqGCgGAwG/Of//J+RyWSQz+fB4XDQ1dW15v3yGRgeBlgsFrUR37x5M5577jnE43E6ZTUSiVAzuGw2i/7+fvB4PPT09GDXrl10yOJa0pMwrD2YNwUDRSKRYOvWrfU+DAYGhpvAYrEgl8vpwMS2tjYUCgUkEgnkcjn4/X6kUikMDw8jHA7DaDRCp9Nh/fr16OjoeGS7IxhWFyaoYGBgYFhjkPIFCTTIcLRCoQCdTodsNguxWAyhUFh
322aGRwsmqGBgYGBYY1SPqie1dLlcDgCw2Ww138sEFAyrCRNUMDAwMDwEMMHDww+ZUNzU1IQXX3wRiUQC/f39MBqNjFCTgYGBgYGB4e4hjrFarRb/8T/+RwCoyyDA28EEFQwMDAwMDGuA6rEUjdqRx6rUabCDQCCATqerx5+uKx6PBxaLpd6Hseow5/1owZz3owVz3o8WgUDgljN46hZUMDAwMDAwMDxcNEYRhoGBgYGBgWHNwwQVDAwMDAwMDMsCE1QwMDAwMDAwLAtMUMHAwMDAwMCwLDBBBQMDAwMDA8OywAQVDAwMDAwMDMvC/w/JI8gw+SgdQAAAAABJRU5ErkJggg==", "text/plain": [ "
" ] @@ -316,7 +315,7 @@ "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -384,10 +383,10 @@ "in the meta distribution $\\mathcal{D}$ during training\n", "\n", "The following instantiates\n", - "{class}`~ott.neural.models.MetaInitializer`,\n", + "{class}`~ott.neural.initializers.meta_initializer.MetaInitializer`,\n", "which provides an implementation for training and deploying Meta OT models.\n", "The default meta potential model for $f_\\theta$ is a standard multi-layer MLP\n", - "defined in {class}`~ott.neural.models.MLP`\n", + "defined by the ``MetaMLP`` below\n", "and it is optimized with {func}`~optax.adam` by default.\n", "\n", "**Custom model and optimizers**.\n", @@ -438,7 +437,9 @@ "outputs": [], "source": [ "meta_mlp = MetaMLP(potential_size=geom.shape[0])\n", - "meta_initializer = models.MetaInitializer(geom=geom, meta_model=meta_mlp)" + "meta_initializer = meta_initializer.MetaInitializer(\n", + " geom=geom, meta_model=meta_mlp\n", + ")" ] }, { @@ -451,7 +452,8 @@ "Meta OT models have a preliminary training phase where they are\n", "given samples of OT problems from the meta distribution.\n", "The Meta OTT initializer internally stores the training state\n", - "of the model, and {meth}`~ott.neural.models.MetaInitializer.update` will update the initialization\n", + "of the model, and {meth}`~ott.neural.initializers.meta_initializer.MetaInitializer.update`\n", + "will update the initialization\n", "on a batch of problems to improve the next prediction.\n", "While we show here a separate training phase, the update\n", "can also be done in-tandem with deployment where the\n", @@ -501,7 +503,7 @@ "Now that we have trained the model, we can next deploy it anytime we\n", "want to make a rough prediction for new instances of the problems.\n", "While in practice, the model can be continued to be updated in deployment\n", - "by calling {meth}`~ott.neural.models.MetaInitializer.update`,\n", + "by calling {meth}`~ott.neural.initializers.meta_initializer.MetaInitializer.update`,\n", "here we will keep the model fixed so we can 
evaluate it on test instances." ] }, @@ -516,7 +518,7 @@ "prediction of the solution to the transport problems from above,\n", "which are sampled from testing pairs of MNIST digits that\n", "the model was not trained on.\n", - "The initializer uses the Meta OT model in {meth}`~ott.neural.models.MetaInitializer.init_dual_a`.\n", + "The initializer uses the Meta OT model in {meth}`~ott.neural.initializers.meta_initializer.MetaInitializer.init_dual_a`.\n", "This shows that the initialization is extremely close to the ground-truth coupling." ] }, diff --git a/docs/tutorials/neural_dual.ipynb b/docs/tutorials/neural_dual.ipynb index e268485c1..2021eebfb 100644 --- a/docs/tutorials/neural_dual.ipynb +++ b/docs/tutorials/neural_dual.ipynb @@ -7,12 +7,12 @@ "# Neural Dual Solver \n", "\n", "This tutorial shows how to use `OTT` to compute the Wasserstein-2 optimal transport map between continuous measures in Euclidean space that are accessible via sampling.\n", - "{class}`~ott.neural.solvers.neuraldual.W2NeuralDual` solves this\n", + "{class}`~ott.neural.methods.neuraldual.W2NeuralDual` solves this\n", "problem by optimizing parameterized Kantorovich dual potential functions\n", "and returning a {class}`~ott.problems.linear.potentials.DualPotentials`\n", "object that can be used to transport unseen source data samples to its target distribution (or vice-versa) or compute the corresponding distance between new source and target distribution.\n", "\n", - "The dual potentials can be specified as non-convex neural networks ({class}`~ott.neural.models.MLP`) or an input-convex neural network ({class}`~ott.neural.models.ICNN`) {cite}`amos:17`. {class}`~ott.neural.solvers.neuraldual.W2NeuralDual` implements the method developed by {cite}`makkuva:20` along with the improvements and fine-tuning of the conjugate computation from {cite}`amos:23`. For more insights on the approach itself, we refer the user to the original sources." 
+ "The dual potentials can be specified as non-convex neural networks {class}`~ott.neural.networks.potentials.PotentialMLP` or an input-convex neural network {class}`~ott.neural.networks.icnn.ICNN` {cite}`amos:17`. {class}`~ott.neural.methods.neuraldual.W2NeuralDual` implements the method developed by {cite}`makkuva:20` along with the improvements and fine-tuning of the conjugate computation from {cite}`amos:23`. For more insights on the approach itself, we refer the user to the original sources." ] }, { @@ -24,7 +24,7 @@ "import sys\n", "\n", "if \"google.colab\" in sys.modules:\n", - " !pip install -q git+https://github.com/ott-jax/ott@main" + " %pip install -q git+https://github.com/ott-jax/ott@main" ] }, { @@ -47,8 +47,8 @@ "\n", "from ott import datasets\n", "from ott.geometry import pointcloud\n", - "from ott.neural import models\n", - "from ott.neural.solvers import neuraldual\n", + "from ott.neural.methods import neuraldual\n", + "from ott.neural.networks import potentials\n", "from ott.tools import sinkhorn_divergence" ] }, @@ -58,7 +58,7 @@ "source": [ "## Setup training and validation datasets\n", "\n", - "We apply the {class}`~ott.neural.solvers.neuraldual.W2NeuralDual` to compute the transport between toy datasets.\n", + "We apply the {class}`~ott.neural.methods.neuraldual.W2NeuralDual` to compute the transport between toy datasets.\n", "Here, we aim at computing the map between two toy datasets representing both, source and target distribution using the\n", "datasets `simple` (data clustered in one center) and `circle` (two-dimensional Gaussians arranged on a circle) from {class}`~ott.datasets.create_gaussian_mixture_samplers`.\n", "\n", @@ -95,18 +95,7 @@ "outputs": [ { "data": { - "text/plain": [ - "(
,\n", - " )" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -149,16 +138,16 @@ "eval_data_source = next(valid_dataloaders.source_iter)\n", "eval_data_target = next(valid_dataloaders.target_iter)\n", "\n", - "plot_samples(eval_data_source, eval_data_target)" + "_ = plot_samples(eval_data_source, eval_data_target)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Next, we define the architectures parameterizing the dual potentials $f$ and $g$. We first parameterize $f$ with an {class}`~ott.neural.models.ICNN` and $\\nabla g$ as a non-convex {class}`~ott.neural.models.MLP`. You can adapt the size of the ICNNs by passing a sequence containing hidden layer sizes. While ICNNs are by default containing partially positive weights, we can run the {class}`~ott.neural.solvers.neuraldual.W2NeuralDual` using approximations to this positivity constraint (via weight clipping and a weight penalization).\n", - "For this, set `pos_weights` to `True` in {class}`~ott.neural.models.ICNN` and {class}`~ott.neural.solvers.neuraldual.W2NeuralDual`.\n", - "For more details on how to customize {class}`~ott.neural.models.ICNN`,\n", + "Next, we define the architectures parameterizing the dual potentials $f$ and $g$. We first parameterize $f$ with an {class}`~ott.neural.networks.icnn.ICNN` and $\\nabla g$ as a non-convex {class}`~ott.neural.networks.potentials.PotentialMLP`. You can adapt the size of the ICNNs by passing a sequence containing hidden layer sizes. While ICNNs are by default containing partially positive weights, we can run the {class}`~ott.neural.methods.neuraldual.W2NeuralDual` using approximations to this positivity constraint (via weight clipping and a weight penalization).\n", + "For this, set `pos_weights` to `True` in {class}`~ott.neural.networks.icnn.ICNN` and {class}`~ott.neural.methods.neuraldual.W2NeuralDual`.\n", + "For more details on how to customize {class}`~ott.neural.networks.icnn.ICNN`,\n", "we refer you to the documentation." 
] }, @@ -171,7 +160,7 @@ "# initialize models and optimizers\n", "num_train_iters = 5001\n", "\n", - "neural_f = models.ICNN(\n", + "neural_f = icnn.ICNN(\n", " dim_data=2,\n", " dim_hidden=[64, 64, 64, 64],\n", " pos_weights=True,\n", @@ -181,7 +170,7 @@ " ), # initialize the ICNN with source and target samples\n", ")\n", "\n", - "neural_g = models.MLP(\n", + "neural_g = potentials.PotentialMLP(\n", " dim_hidden=[64, 64, 64, 64],\n", " is_potential=False, # returns the gradient of the potential.\n", ")\n", @@ -198,7 +187,7 @@ "source": [ "## Train Neural Dual\n", "\n", - "We then initialize the {class}`~ott.neural.solvers.neuraldual.W2NeuralDual` by passing two {class}`~ott.neural.models.ICNN` models parameterizing $f$ and $g$, as well as by specifying the input dimensions of the data and the number of training iterations to execute. Once the {class}`~ott.neural.solvers.neuraldual.W2NeuralDual` is initialized, we can obtain the neural {class}`~ott.problems.linear.potentials.DualPotentials` by passing the corresponding dataloaders to it.\n", + "We then initialize the {class}`~ott.neural.methods.neuraldual.W2NeuralDual` by passing two {class}`~ott.neural.networks.icnn.ICNN` models parameterizing $f$ and $g$, as well as by specifying the input dimensions of the data and the number of training iterations to execute. Once the {class}`~ott.neural.methods.neuraldual.W2NeuralDual` is initialized, we can obtain the neural {class}`~ott.problems.linear.potentials.DualPotentials` by passing the corresponding dataloaders to it.\n", "\n", "Execution of the following cell will probably take a few minutes, depending on your system and the number of training iterations." ] @@ -259,7 +248,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The output of the solver, `learned_potentials`, is an instance of {class}`~ott.problems.linear.potentials.DualPotentials`. 
This gives us access to the learned potentials and provides functions to compute and plot the forward and inverse OT maps between the measures." + "The output of the solver, `learned_potentials`, is an instance of {class}`~ott.problems.linear.potentials.DualPotentials`. This gives us access to the learned potentials and provides functions to compute and plot the forward and inverse OT maps between the measures." ] }, { @@ -520,7 +509,7 @@ "source": [ "## Solving a harder problem\n", "\n", - "We next set up a harder OT problem to transport from a mixture of five Gaussians to a mixture of four Gaussians and solve it by using the non-convex {class}`~ott.neural.models.MLP` potentials to model $f$ and $g$." + "We next set up a harder OT problem to transport from a mixture of five Gaussians to a mixture of four Gaussians and solve it by using the non-convex {class}`~ott.neural.networks.potentials.PotentialMLP` potentials to model $f$ and $g$." ] }, { @@ -578,8 +567,8 @@ "source": [ "num_train_iters = 20001\n", "\n", - "neural_f = models.MLP(dim_hidden=[64, 64, 64, 64])\n", - "neural_g = models.MLP(dim_hidden=[64, 64, 64, 64])\n", + "neural_f = potentials.PotentialMLP(dim_hidden=[64, 64, 64, 64])\n", + "neural_g = potentials.PotentialMLP(dim_hidden=[64, 64, 64, 64])\n", "\n", "lr_schedule = optax.cosine_decay_schedule(\n", " init_value=5e-4, decay_steps=num_train_iters, alpha=1e-2\n", @@ -721,8 +710,8 @@ "\n", " input_dim = 2\n", "\n", - " neural_f = models.MLP(dim_hidden=[64, 64, 64, 64])\n", - " neural_g = models.MLP(dim_hidden=[64, 64, 64, 64])\n", + " neural_f = potentials.PotentialMLP(dim_hidden=[64, 64, 64, 64])\n", + " neural_g = potentials.PotentialMLP(dim_hidden=[64, 64, 64, 64])\n", "\n", " lr_schedule = optax.cosine_decay_schedule(\n", " init_value=5e-4, decay_steps=num_train_iters, alpha=1e-2\n", @@ -804,7 +793,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.10.6" }, "vscode": { 
"interpreter": { From 67202c2ee259aa8a9cf9be4873cd65c97340992e Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 2 Apr 2024 11:31:58 +0200 Subject: [PATCH 177/186] Update ICNN inits --- docs/tutorials/icnn_inits.ipynb | 44 ++++++++++++++++----------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/docs/tutorials/icnn_inits.ipynb b/docs/tutorials/icnn_inits.ipynb index 5741e44e5..1f9d01b3c 100644 --- a/docs/tutorials/icnn_inits.ipynb +++ b/docs/tutorials/icnn_inits.ipynb @@ -8,7 +8,7 @@ "\n", "As input convex neural networks (ICNN) are notoriously difficult to train {cite}`richter-powell:21`, {cite}`bunne:22` propose to use closed-form solutions between Gaussian approximations to derive relevant parameter initializations for ICNNs: given two measures $\\mu$ and $\\nu$, one can initialize ICNN parameters so that its gradient can map approximately $\\mu$ into $\\nu$. These initializations rely on closed-form solutions available for Gaussian measures {cite}`gelbrich:90`.\n", "\n", - "In this notebook, we introduce the *identity* and *Gaussian approximation*-based initialization schemes, and illustrate how they can be used within the `OTT` library when using {class}`~ott.neural.models.ICNN`-based potentials with the {class}`~ott.neural.solvers.neuraldual.W2NeuralDual` solver." + "In this notebook, we introduce the *identity* and *Gaussian approximation*-based initialization schemes, and illustrate how they can be used within the `OTT` library when using {class}`~ott.neural.networks.icnn.ICNN`-based potentials with the {class}`~ott.neural.methods.neuraldual.W2NeuralDual` solver." 
] }, { @@ -20,7 +20,7 @@ "import sys\n", "\n", "if \"google.colab\" in sys.modules:\n", - " !pip install -q git+https://github.com/ott-jax/ott@main" + " %pip install -q git+https://github.com/ott-jax/ott@main" ] }, { @@ -39,8 +39,8 @@ "\n", "from ott import datasets\n", "from ott.geometry import pointcloud\n", - "from ott.neural import models\n", - "from ott.neural.solvers import neuraldual\n", + "from ott.neural.methods import neuraldual\n", + "from ott.neural.networks import icnn\n", "from ott.tools import plot" ] }, @@ -50,9 +50,9 @@ "source": [ "## Setup training and validation datasets\n", "\n", - "To test the ICNN initialization methods, we choose the {class}`~ott.neural.solvers.neuraldual.W2NeuralDual` of the `OTT` library as an example. Here, we aim at computing the map between two toy datasets representing both, source and target distribution using the\n", + "To test the ICNN initialization methods, we choose the {class}`~ott.neural.methods.neuraldual.W2NeuralDual` of the `OTT` library as an example. Here, we aim at computing the map between two toy datasets representing both, source and target distribution using the\n", "datasets `simple` (data clustered in one center) and `circle` (two-dimensional Gaussians arranged on a circle) from {class}`~ott.datasets.create_gaussian_mixture_samplers`.\n", - "For more details on the execution of the {class}`~ott.neural.solvers.neuraldual.W2NeuralDual`, we refer the reader to {doc}`neural_dual` notebook.\n", + "For more details on the execution of the {class}`~ott.neural.methods.neuraldual.W2NeuralDual`, we refer the reader to {doc}`neural_dual` notebook.\n", "\n", "## Experimental setup \n", "\n", @@ -114,8 +114,8 @@ "### Identity initialization method\n", "\n", "Next, we define the architectures parameterizing the dual potentials $f$ and $g$. These need to be parameterized by ICNNs. You can adapt the size of the ICNNs by passing a sequence containing hidden layer sizes. 
While ICNNs are by default containing partially positive weights, we can solve the problem using approximations to this positivity constraint (via weight clipping and a weight penalization).\n", - "For this, set `pos_weights` to `True` in {class}`~ott.neural.models.ICNN` and {class}`~ott.neural.solvers.neuraldual.W2NeuralDual`.\n", - "For more details on how to customize {class}`~ott.neural.models.ICNN`,\n", + "For this, set `pos_weights` to `True` in {class}`~ott.neural.networks.icnn.ICNN` and {class}`~ott.neural.methods.neuraldual.W2NeuralDual`.\n", + "For more details on how to customize {class}`~ott.neural.networks.icnn.ICNN`,\n", "we refer you to the documentation.\n", "\n", "We first explore the `identity` initialization method. This initialization method is the default choice of the current ICNN and data independent, thus no further arguments need to be passed to the ICNN architecture." @@ -128,8 +128,8 @@ "outputs": [], "source": [ "# initialize models using identity initialization (default)\n", - "neural_f = models.ICNN(dim_hidden=[64, 64, 64, 64], dim_data=2)\n", - "neural_g = models.ICNN(dim_hidden=[64, 64, 64, 64], dim_data=2)" + "neural_f = icnn.ICNN(dim_hidden=[64, 64, 64, 64], dim_data=2)\n", + "neural_g = icnn.ICNN(dim_hidden=[64, 64, 64, 64], dim_data=2)" ] }, { @@ -141,14 +141,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/michal/projects/nott/src/ott/neural/solvers/neuraldual.py:276: UserWarning: Setting of ICNN and the positive weights setting of the `W2NeuralDual` are not consistent. Proceeding with the `W2NeuralDual` setting, with positive weights being True.\n", + "/Users/michal/Projects/dott/src/ott/neural/methods/neuraldual.py:154: UserWarning: Setting of ICNN and the positive weights setting of the `W2NeuralDual` are not consistent. 
Proceeding with the `W2NeuralDual` setting, with positive weights being True.\n", " self.setup(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "243d6aa24b1d45cba5ba10522373dc3a", + "model_id": "62abc21c2f8b47c09c328cb9ef44efd1", "version_major": 2, "version_minor": 0 }, @@ -191,7 +191,7 @@ }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -221,7 +221,7 @@ }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -264,7 +264,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To use the Gaussian initialization, the samples of source and target (`samples_source` and `samples_target`) need to be passed to the {class}`~ott.neural.models.ICNN` definition via the `gaussian_map_samples` argument. Note that ICNN $f$ maps source to target (`gaussian_map_samples=(samples_source, samples_target)`), and $g$ maps target to source cells (`gaussian_map_samples=(samples_target, samples_source)`)." + "To use the Gaussian initialization, the samples of source and target (`samples_source` and `samples_target`) need to be passed to the {class}`~ott.neural.networks.icnn.ICNN` definition via the `gaussian_map_samples` argument. Note that ICNN $f$ maps source to target (`gaussian_map_samples=(samples_source, samples_target)`), and $g$ maps target to source cells (`gaussian_map_samples=(samples_target, samples_source)`)." ] }, { @@ -274,12 +274,12 @@ "outputs": [], "source": [ "# initialize models using Gaussian initialization\n", - "neural_f = models.ICNN(\n", + "neural_f = icnn.ICNN(\n", " dim_hidden=[64, 64, 64, 64],\n", " dim_data=2,\n", " gaussian_map_samples=(samples_source, samples_target),\n", ")\n", - "neural_g = models.ICNN(\n", + "neural_g = icnn.ICNN(\n", " dim_hidden=[64, 64, 64, 64],\n", " dim_data=2,\n", " gaussian_map_samples=(samples_target, samples_source),\n", @@ -295,14 +295,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/michal/projects/nott/src/ott/neural/solvers/neuraldual.py:276: UserWarning: Setting of ICNN and the positive weights setting of the `W2NeuralDual` are not consistent. Proceeding with the `W2NeuralDual` setting, with positive weights being True.\n", + "/Users/michal/Projects/dott/src/ott/neural/methods/neuraldual.py:154: UserWarning: Setting of ICNN and the positive weights setting of the `W2NeuralDual` are not consistent. 
Proceeding with the `W2NeuralDual` setting, with positive weights being True.\n", " self.setup(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c4e2a1cdac674c588497d0803d003ec2", + "model_id": "fdf9e1aeda2b473c93d15d4815247286", "version_major": 2, "version_minor": 0 }, @@ -345,7 +345,7 @@ }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -377,7 +377,7 @@ }, { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAG/CAYAAABlpLwqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9d5AkWX7fCX5ch47UOkuL1nJ6erqnR2EUBsTOgVyC5C2wIGlY2i45xILkmd0St2Zn2Ns9gmu0u4OR3DWQXBx5WAIEDjyCAEEOgJE9qrundVer0iIrtQgtXd0fv3rlnlGZ1dWi9PuYhWVkhLuHcn/v+37SiOM4RqPRaDQajeYmYN7sN6DRaDQajebuRQsRjUaj0Wg0Nw0tRDQajUaj0dw0tBDRaDQajUZz09BCRKPRaDQazU1DCxGNRqPRaDQ3DS1ENBqNRqPR3DS0ENFoNBqNRnPT0EJEo9FoNBrNTUMLEY1Go9FoNDeN6y5EFhcX+fmf/3lGR0fJZrM8+OCDvPTSS9f7ZTUajUaj0dwG2Nfz4JVKhU9+8pN87nOf4+tf/zrj4+OcPHmS4eHha9o/iiKWlpYoFosYhnE936pGo9FoNJqPiDiOaTQazMzMYJpXt3kY17Pp3d//+3+fH/7wh3z/+9//QPtfvHiR+fn5j/hdaTQajUajuREsLCwwNzd31W2uqxC57777+PKXv8zFixd59tlnmZ2d5W/9rb/F3/gbf2PH7Xu9Hr1e7/L/tVqNPXv2sLCwQKlUul5vU6PRaDQazUdIvV5nfn6earVKuVy+6rbXVYhkMhkA/t7f+3v87M/+LC+++CK//Mu/zG/8xm/wV//qX71i+1/91V/lf/gf/ocrHq/ValqIaDQajUZzm1Cv1ymXy9c0f19XIeK6Lh/72Mf40Y9+dPmx//a//W958cUXee65567YftAiohSVFiIajUaj0dw+vB8hcl2zZqanp7nvvvu2PXbvvfdy4cKFHbf3PI9SqbTtptFoNBqN5s7lugqRT37ykxw/fnzbYydOnGDv3r3X82U1Go1Go9HcJlxXIfJ3/+7f5fnnn+cf/IN/wKlTp/id3/kd/vk//+d87Wtfu54vq9FoNBqN5jbhugqRJ554gj/4gz/g3/ybf8MDDzzA//g//o/8+q//Oj/3cz93PV9Wo9FoNBrNbcJ1DVb9sLyfYBeNRqPRaDS3BrdMsKpGo9FoNBrN1dBCRKPRaDQazU1DCxGNRqPRaDQ3jeva9E6j0dwdRFHE4tI6zVaHQj7L7Mz4tkZX7/W8RqO5e9FCRKPRfChOnlrg6994jlNnFuj1+niey6ED83zli09x+ND8ez6v0WjubrQQ0Wg0H5iTpxb4zd/6I7YqdWamx8hlPdqdHm++fZql5XW+8Nkn+OZ3X9z1+V/8ha9y8MCstpZoNHcxOn1Xo9F8IKIo4p/+s3/Lm2+f5siheQzDuPxcHMccP3mBZqtNoZDj6KE9Vzx/4tQCM1NjDA+XOH324jVbS7SbR6O59Xk/87e2iGg0mg/E4tI6p84sMDM9hmEYxLEJsY1h9jEMg1Ixx5tvn+bpJx/CMAzCYAjDiDCtOoZhkM14fPO7L7Jv7zSHDsztaC0ZFCPazaPR3HnoZYRGo/lANFsder0+uawHQBiMEwQTxHFMrd6k3mjR6fYwTLGERFGBKCwCYhFZXF6j0+0xNzNOsZDDsiyKhRxHDs2zVanzJ998jiiKLr+ecgO9+fZpRoZLHNw/y8hwiTffPs1v/tYfcfLUwo3/EjQazYdGW0Q0Gs0HopDP4nku7U6PYiEHsU2z2eLEmbfZqtRpNFu0Wh3efPs0ruNQKsQoP3C90WZ9o0ohn8V1HaIoC8SYZhfDMJiZGuXk6QUWl9aZn5skiiK+/o3n2KrUt7mBlHA5cWqBP/nmcxw8MKvdNBrNbYa+YjUazQdidmacQwfmW
VreII5j6s06p88usbZeIZNx8VyXoaEi6+sVXnntXZrN9uV9e/0+zWabUjFPr+fTqOcIg5HLz2dzGXp9n2arA1zpBorCImFQALhCuGg0mtsLLUQ0Gs0HwjRNvvLFpxgZLnH85AXOnrtAr9+nkM/SbHbI57M88dh9jI8Ps75R4cLCCmEY0Wi0OHHqAr2+z8ZmlRdeeot3T5znxKkFNrdqAHTaXTzXoZDPAle6gaKwRBQNXX4vg8JFo9HcPmghotFoPjCHD83zi7/wVfbMTbGxuUHf96k3+0xPPMCjDz7Bwf2zPPbwUeZmJ6jWm1y4uMr5hVX6fZ9yuYBpmQwPFXBch1qtyauvn2Bjs8rSyiaHD84zOzMObHcDAcQYQJLwNyhcNBrN7YMWIhqN5kNx+NA8H3vsHgwzwjJNwsig041ZWOixsVlldKTMU08+yPTUGM889TAH9s0yMzXOp59+lFIhT6XaxCAmn8+wVa3z7A9fxXFsvvT5Jy/Hewy6gQBUMnAcx1cIF41Gc/ugg1U1Gs2H4uSpBf7jn/6QRrOGaQwR9GOazQ71eou3TrzFkx+7n2w2g2lMUK83OXVmgbHRIUZHSjz68IMsLvrU610azRabWzUs02Rjs8qffvMFiCGXy9BsdXjkocMsLq9J/ZGJcdyMSafRYmllk5GREj/5had0oKpGcxuihYhGo/nAqGyWza0aURTQ6/l4Xg7Pc4ijiPpSi289+xK2ZTAz+TFOnVnk3MIbjI6UmZ+b4sjBjzE8BCur5wmCAMOAMIpYWd3gm9/5MX/6reeZGBvGdW08z2WoXGR6YoJGq0OvUsW0Gjx4/0F+8gu6johGc7uihYhGo7lmVFXTeqNFs9Wh2Wzz+rETtFodDDPEsmwMY5QgCIhj8H2fIIgwDBgudzEMi3anR39lg3qjTatVJgxser0+hhEThhFRKK9xcWmDjOdgWyZPP/kgnW6fpeUNhocmefrJh5ifd5mecnRlVY3mNkcLEY1Gc02oqqavvn6c8wsrdDpdTNOgUm3i+z6O4zJaKlEq7mVj8xStziZBIAXJ4hhMw8I0TEzTIIxi2p0Ola0Gtp3Fsix8P8BzSuTKRRrNFbq9PoYBm1s1VtYqeK7D5MQwGxs+Z88v8eUvPk2hoAWIRnO7o4WIRqN5T1RV0wsLK6yuVwjDkFIxz1alTrPVJopi/CAkjEJs2yOfG6HZ2by8/3B5D0OlPdQaS5iGRRgHBEFEGMUYUQSGQRSHTE3eh+vY1BrLuI6N7wdUa/Diy29j2xa2bTExepTTZy+yubVOoTCpe89oNLc5WohoNJqrcjkOpFLDDwLCMGRstEy706PRbBGGYvXw/YBut4NpWDhuZtsxxkePYNtZ9s4+SRB1uXDxxUvZLyZgQBwzNryPbKaEH7aIQ4iNGM8tM1TaRz7XIpcL8IOQTteg2VrnnRNn6ff7uveMRnObo4WIRqO5KqqqaamQ5+y5JSZGD+AHPRaXFghDG9uyCcKAQn6cMPLFyoGNaSbDS6k4RT43hmnaFHMTWJZFEITYloVhmFiWgetlcN0SQQdMK0u7U8UrZojjmDgOMU0bzzXJuBlq9SZ/9q0XcGyLSrXBzPTYNTXN02g0tx7afqnRaK6KqmpqWiZBEJLNjtPr5uh1DcZH78N1Szh2ltHhA2QyQwBkM6OMDu9PDhKDYVqASRzHRFGEaRoYpoVlWkRRTKvVwfdDwjAgCPoAOHYG0zBotcG19wHQ7fqMDJV4861TLC6vc+TQ/DU1zdNoNLcmWohoNJqroqqaRmGEbVs0mm2abSmlbgAZr4xtiyvGMi0AbNslKTkGMRGunSOTKRPHBlEUY1kWhmFgGAZhGBFGARgWUQxh6APgOFkcx8ax89QbbTa2qriew+joENV6k5HhEoZhEPSniCJ5D7r3jEZze6GFiEajuSqqqmm92WKoXKRebxJHEBs+YDI2cpCJsSNYpotpOhhGem/5J45jLCt5zjAgjsVa4bkZHMcijiLMS
0LGMAw8t4DnZrFME9vK0Ov2GSoVOLhvFvPScUqFnBwfmygcuvyquveMRnP7oGNENBrNVVHN7ZaW19nYrBJFISA1QkaHD4AB/a0WI8P78dw8ANElkeHaWfywgwSkinqwbZtsxsP3A2zLYmpqjFbbodU0sUwbyzSZm3kEy7QxTYPh4RKmadDt9Tl6dD/5fJaVtSWGygUs2yKOZRgzjN7l96x7z2g0tw/aIqLRaN4T1dzugXsPYtsGlmURhiGmaeE5eSDGthwcWyb+OApxrCxzM4/j2B6lwhiZTB4D2TeKY6I4xnUd4sig1w0oFefIZcpgxDi2g21ZmKZJt9cHYhzbxvctllc2GB/L8oknHmR5ZZMolI68htmV19a9ZzSa2wptEdFoNNfE4UPz/Fd/9ausb/xHjLhAbFzEsS0wjEtxHB6WaWEaFoYBlmXj2BksywXDRrlpbFvKtQdBSL/vU6k2yWWnKRcmMEwb4ogogiD2yeeylIp5arUmmYxDFDrsmZ/iiz9xgFwu5Dd/6484v1BjfCRLJt+m1Wrr3jMazW2GFiIajeaamZ+b5Oihac6e6/D0kx9jdTlHrdElDEMM08K0XLJemXa3QkyM43ipNN4YDAloDfwAg5i+79Pr+RTz9iXRENH3+4RhQBTHdDp9Go0mszMTfPaZRzh6dBzPHebQIRPXhV/8ha/yh398juWVTZY3LuK5ju49o9HcZmghotForhnTNPniTzzC7/zea1RrFSYnhgmCTTzXlkDTOCImJo5iDGIs28U01ONgGCaeO042lyeb8TAA0zDp9n3Khk0cxUSRT3zp9cIwoN5sY61vsVWt0u9P4DoxK6urtNpSSfWnvvQ0W1t1Rscf0ZVVNZrbEC1ENBrN++K+e2f5wudcXnvjDVZXIjAgjMC1PRzHwjTAtCyiKMRzypiWRxTHEMeYhoNtO0xPTNLrNwn6FqZpY5kOlpUhiqWOiGEYZDwbMLFtk61KjX/zb/+YH79yGMe0COKzlzvy7pl5kscfPcq9R/fd7K9Go9F8ALQQ0Wg07wvThJnpcWZnPseJE02+/9xbbFY8MhmPOA6J4/hymm6xMMHE6AEc27z0uCEBrm6eQn6ESsXAsm2cS3VH4ijCMGNGhsYJAp9mu05MjOc5bG41WVpaw7ZsesEqH3v0XrJZlwsXV9iqVJmbe1i7YzSa2xBtv9RoNB8Qk0w2xsDg3qOH8DyHIOyjUnUty8RzM3hujiiKL5c3i6KYQn6WQm4Ky7bwPId8oSC1RYwQz3PIuqNY1hAGFlEYEvghnV4Tvx8xNFQgDEPOnl8kn8szPTVGrbGpK6lqNLcpWohoNJoPRBwDRpcgDBkqjzI6UmR8fALHyWLZNq7r4NguYyMHME3rcgEzA4MgiGk0O5QKefK5wiWhYhKFAf2ejx9GlAtTjI0evFQa3qXbbar+eAyVhul28jTqYmUZH3V1JVWN5jZFu2Y0Gs37xrgkCEaGbWzLwu8bxHGH0I8xDQ/bcumEMYbh4Dg5MpkMSVlV8P2QoXKe4aEiMESznsEwDfygRxhGWKaJaVlYhoHjZClkp7GszCUxE5DLTdLrufiXaohkcjG9dV1JVaO5HdFCRKPRvG+UENm/b5iM57JwcRWMCgZZcpkstuVg2Q5gXKorYl6qA2LgOjZjY8P0/S5nLyzjWBGF/BxZ07yUtiuVWw0sMGxsI4NpmpfiSCCMOjhmHssEx5YS791OU1dS1WhuU7QQ0Wg075tMBtptWFpeot3uAi69Xo1iIYdpmPiRCbGF7TiMF0fxPAeLHJ1eSDbOsbhcp9NtYdsWnmPh2h6mYWLEIaYBQRSR8UpgmPR6W1imiedJQGu/1yQMMpTLRXLZ7OVKqg/ef1BXUtVobkN0jIhGo3nfZDIQRRHf+M5LRMTMTk9RKNgEYYBpmtimhWFI35h8LotJTK/fx7VthspFxseGePjBQ5SLeQlOzWWxbYuQkEw2IzVGTAODGNN0cFyb8dERbMtgcWUR0
zSZmhim2+mxvLKuK6lqNLcx2iKi0WjeN5kMbG3VOH9hg6H8XhzHY8/cKL1eHsIcUQyG2SQMQ2qNHo5tkst6ZNw81fomUxOjRJGPaQ6T8bKYho1h2mS8DNUaFPNZMhmXKDIZHR0jjiLCyKFULjI2McFIeZhmu4tj2+zdM8VX/9wndOquRnObooWIRqN532Qy0On1CXwLy7GIiTDiPLZZJsbBAkZGRmh3YmamhxgZytPtZqg3fMJImt01210MI4PrFDCwsAyHfLaEbXt0ewG5XIxhxowOleh2fe675yif+eTjfPWn9/LCCw36fg/X8ZibK7Nvn7aEaDS3K1qIaDSa943jQNZz8bwCYRBg2xExkmarmtvFETiWRblUIJfL0utJOXfLtOj3Q2zTwTRDSd01TcAEI6aQz5PP5QkCn8D3CXyHmekxnnziPqanxoljGBsbxnHA96FQuJnfhEaj+bBoIaLRaN43hgEjI2Wmp6ZYXm4zVM5cCjgzLm/T64fkclmyGe/yY7Ztkc3m2Kg0mBwvkc1E+EGAYVhAjB+EFIsFZqYmqNXFqvLwgwcYH8szMSGv0G7LseJLDWny+RvykTUazXVC2zM1Gs0HwjRNHn/kHjKeTbVWxe/3iWOIEXFh2w7D5SJqmJEyIiYT4yPkc1lW1yoU8llM0yQIQsIowjINSoUc9UaHbMbhgXv3US4Vcd3kGJ1LpUJUEdVM5oZ/dM0dQhRFLFxc5Z3j51i4uKor894ktEVEo9F8IAwDpqfGeepJi2NvrVCtVjENH8tycByXsdEinpfZtj1APpfj0598lGNvnWBjAzKui2na0pnXs8EwKBeHmJqYZGwsTxSBlxhVLltE1JxhJEYYjeaaOXlqga9/4zlOnVmg1+vjeS6HDszzlS8+pQOfbzBaiGg0mg/Fvr0jeO4QjUabxUWLMLIwDZtMxrzsPlFi4VITXiYnRpmceJJz59rU6xGVSoZuz2Rs1GV8PKbTyWMYJrYN/T5kL9Upy2YTi4hG80E5eWqB3/ytP2KrUmdmeoxc1qPd6fHm26dZWl7nF3/hq1qM3EC0a0aj0XwgHEf+5vNgGCa5nASleq6LYWwfWixr+74iUEyKxQLlcgnbdrEtm1KpQKFQJAi27+9KUVUymSQ2BLQ1RPP+iaKIr3/jObYqdY4cmqdYyGFZFsVCjiOH5tmq1HUDxRuMtohoLhNFEYtL6zRbHQr5LLMz47pAlGZXMhmxVqgYjSCQv2nrx+X2MkZyX43vaddKWlyAZMM4TvK4EjKD8SBKoGjufD6q8WlxaZ1TZxaYmR7DMAyiKEscZbDsCoZhMDM1ermB4vzc5HX4JJpBtBDRANpfqnn/ZLNQr4N9aRQJQzDN7ULENOWvEhLKNaPu74Z6blC0KCESRXJsnTFzd/BRjk/NVoder08u6xEGo0RRFoixqACQzWXorW3pBoo3EC1ENNpfqvlApEUBiEXEsra7S0xTnk9bRBSDlu+drCaD2w6KFy1E7nw+6vGpkM/ieR7t5iSZbAaIsZ3ly8932l3dQPEGo+3udzk7+UtN09P+Us17ojJZej35G0ViHVFWkLRrRlnQ00IiDK88pnLTKPGi/iq3T7ebvBYkQayaO5PrEc8xOjLOnpkn2Ko2MIwejruIYcj+qoHi4YPzuoHiDeSGCZF/+A//IYZh8Hf+zt+5US+puQYG/aUAgT9J0J8FtvtLNZo0SlwocWBZIkQGXTNRtF2ADLpbdjqmYtAColJ3d9tec2eRHp/AJAzGiGNVU+b9j08bG3DhgtS/yec7nD73Mo1GiyAMaTRanDi1oBso3gRuiGvmxRdf5J/9s3/GQw89dCNeTvM+SPtLFYYREsc2gT9LJnuR3uomp85c1EGsmh1Rlg/TlABTZRFJWzUgsXao2051QAbdN4OBryp1dzALR3NnosanbCZP4M8ABqZlAnLyXGs8RxzDmTMSBA3w9NPjzM1/Nok7WdvCcx0evP8gP/kFHRd3o7nuQqTZb
PJzP/dz/It/8S/4n/6n/+l6v5zmfSL+Upd2p0exkAPAdlYIgyGiqMDGWoGFxQb/+ve+jm1ZOohVcwVKJChxkI7z2EmIpPdL35SYUZYUdV+lCUMykewUc6K58yjks2QyQ3Rao2SyYFkVDCO4/Py1xHMEAZw6JfcNA44ckb+HD81z8MCszhS8Bbju3/jXvvY1/tyf+3N84QtfuN4vpfkAzM6Mc+jAPEvLG8SpWcKyq9QaZzlxegHXnqWYn2F8dIjhoSJvvn2a3/ytP+LkqYVreg1dRvnOJR0DYlnbU26VwAjD7fEiivRpoJ5XYkUJEUiESPrYkGTraO4MBseJIAioVg3K+f0sr25immuYVuvy9tcSz9FoJCKkXIajRweDqU3m5ya59+g+5ucmtQi5SVzXS/l3f/d3eeWVV3jxxRevafter0dPRb4B9Xr9er01zSVM0+QrX3yKpeV1TpxaYGZqlGwuQ6vZ4dvfe55G0+fQ3idYWOiB4WPZCxzcP8vmlgSJHTwwe9WLV6cF39ko4eB5EiuSDlRVpN0wg7Eig8XJoihJzVWka4j0+7LCtW3I5a7f59LcWAbHiX4/gGiUTHaEIAg5fuo5Tp6NePC+g8zMjNNpd1la2bxqPMfSkqSXA8zN6S7NtzLXTf4tLCzwy7/8y/z2b/82mWvsSvVrv/ZrlMvly7f5eT1R3QgOH5rnF3/hqzxw30G2qg3eOHaKb373RdbWKvS6Lc5ceIE4cvDcEUJ/jtfeOE02475nkJhKu3vz7dOMDJc4uH+WkeHS+7aoaG5d1KXteSIYBjNhlAgZFCaQPJ52zcSxCA1lYYHtVVUhcc/o1N07g8FxYqhcolEfYW29z+ZGldHRKo8+fACAl18/zutvnmKr2uDB+w/yi//llam7UQTHjyci5NAhLUJuda6bEHn55ZdZW1vjsccew7ZtbNvm2Wef5R//43+MbduEO+Tu/cqv/Aq1Wu3ybWFBT1Q3isOH5vnb//Vf5C/+zOfJ57MEYYBlWQwPlchkMqxvncbvWwyV9pLLHGZ5tUG31981SEyXUb47UOmznpek2e5mDUnHiAzGh0BSDE0VRlMoQTIYoKotIrc/g+NEPpdlfb0Iscn4WJm1rdc5c26J+dkJvvLFp9i3Z5qD++f4u1/7K3z1pz5FEIbb3L3dLpw4kVjp7rlHu/BuB67bT/T5z3+eY8eObXvsr//1v84999zDf/ff/XdYO4S9e56Hl26zqbkq16Mk+2tvnKDd6ZLxPLJZH8N0KBWmyWbGaLRWiekzOnyQzQoQ+7sGiQ2mBceRB8QYZl+XUb6DUFaKdHXVQbGhHlekLSJRJM+FYZJx0+ttr9Carsqq9ku/pub2ZXt6rkWjNkm7dQ7HDWm0jzMz+TDtboN6o0WpmGdyfJh33j3L//rP/39sVmt0Oj1yuQwP3X+YZz7xafK5aQAmJmBk5OZ+Ns21c90u5WKxyAMPPLDtsXw+z+jo6BWPa94/1yP2Qg0KI8MlLi6uMVyeIOPNYpo+URgRxxHV2gqeB547w9RYnigcZeHi6hViaDAtOAjytNsGfrCB47bI6TLKdwRKiCixEIbbM1oG40XSDBrDlDsmCLYHFEaRPK7qlei1yp1Dkp5boN89QhRF9PwWhrlIqXCvBDs3K6xvVHj3xHmWljdYWtngjTdP4nkuhXwW13WobBU5e+ZH/B/+3Kd5+ulxPC8Jfj1zbgmAA/tmdEDqLYpeU9yGXK+S7GpQGB8dwrYtTBeIpHaDYfiEkQ8xLK8sU8jVGR35KX7v919jdWOBZvvdbWIonRbc7/ucPnucOJwhDGOCqAZsMDxc0mWUb3OUVUKJDd/fOR5kJzGihIiyoigx0+8nrp70a6RriOjU3TuDQj5LNjNCp3kEx/EwzQ3anQXGR+8BoNPZpFo/S/ukQxiF1GpNgiDAdTzCMKLXi5ifeYQojDlzfpHv/vDf86lP/XWe/f4b/Kvf/o+8+c5pur0+jmUxOlLmqScf4r/8K
1/RgfK3GDdUiHz3u9+9kS93RzLoU1XVUFXsxYlTC9eUzbITSjxYtsXIcImz55bYqlxgavwRIJTB3zCJohjTsnnznefJetMMD88wMlxiq/oOP3rhDU6cPM/X/uu/yKED8zz342NUqnW63T6FQofRofsIQo+N9Zj1cJ12p/uRf0eaG8dgb5jdrB/KUjIYG6LSe9P9aNS26vRVr9HvJ/9rIXJn4NjjTI59io3NCuOjbfLFTWanH6RWa2I7VdY2zxOGIYFpks9luNBew7ZtisU8I+UjmEaJTqePl+mzsvYm3/iOxcLiCj947g3anc6lIGgDx7HpdHv8pz/9IevrFf7eL/0XWozcQmgb1W3GTi2sw2CcODY+UMnjNKqmyPLKJqPDZSrVBn7gs7z+Kt3eJjGSu2+aDtValdNnL7K6fpoLi4ucvxCysOCyudXnxy+/zf/tH/4mRw7NU2+0WF+vUMhncByDreoxup0uIyPjzEw8xp9+8wWiKNK1Ru4QlIAYFChpoTKYNaP6yFhWIkzUsdRxlIhRz+mMmZvPh71mz56FY8dgamKYnn+WE2ffoN3MMzUxTLN9mlNn3pFFj2mR8Rw63RyzU49TKpbZP/9FpiefxHGznL/4Ou+efJlao8XC4hrf/O6LtNsdwMCyTCAmCELanR6NZps33j7N17/xo8vvV489Nx/tmrnNuLIku0EUeUTRDJZVJZMNqJxb4o23pIrP+wlgVTVFLi6t8r0fvkYQhhiA7/sEYUinUyWbHSIIfCzLxjTBD31q1Q5x3GF4eA+5XJ589hynz57k9/7dN8llM8zNTtBqd+l0Y0aH7sGy1zmw72Ecy+P8+Tbf++HrvPHmSV1r5DZFiQQlQpQQGawfMji+DwafWpbsq4RJOhhVHa/fl0wdnTFzc/kwMWpxLALk3eM1Wcxs/JhuF1pNh3dPXcCxl5mZLjMyksX3fc6eX2Zq4mEs02Gr0mVu5mco5qcJgg7nL/6YenOTUiFPFEZEUQwk/kDbyrNv3ydY3zzNVvUsnbhHpVLjtWMnWFxap9vt6zpHtwBaiNxmDJZkN8022DFhMEKt5nFhIeT0uSV+9/f/jD/95vM7XlTvlW3j+yGbm1V8X2aEXHaUYm4Mw7QIo4g4DgmCCNvO0Wp1KeUnsCyHfq9HsTSFabpgFNjYvECr1eETH7uferMDWIyU9pHNHcW2awRhwOJqm3/370/T7Z9jZnr4I4t30dwcLGt7lku6Pkg6gDX9VwmRdDVVZQ1R+2Qy0vBOCRFtEbl5fJgYtSiCV1+F02dqvHHsOBvVF5gaP8L4qEOt1uDEmR/iuln+2s/9Z3zyqYf48csn+K3ffo0YA8ccxjR8splROr1Nzl14jnpjVc4vDHp9//LrjJT3cfTQlwCDOI5YXjmGAYRhRLPVYXF5g7fePct3v/fyRx5rp3n/aCFym6HcJ2++ffpyjIhpdqg03+HMaag3uhza/xT3HC2xuXVRYjZOneeX/pu/xNHDe6+6kgH4zd/6I5aW1yVY1ZC1xdjIQcLYwDE9iDoEYUQUhRiGDXFI4F9gZGgPnjuKYw4R2QblgoWbGeVHL/wnvvuDV3FdB9u2GBupcWj/ExQKw3Q7a6yunmZ68hD79z6B465jGP5HEu+iuXG4rqTcQlLmXTFYxj19P53am7aoKJdNWoioY6p9VJEzzY3lw8So+T68/DI0GhFnz59lo/oce+cfotXqcfb0SRaXjxEEIUvL6/zG//vfMVQe58I5i+WVLUzGyWYMHHuIxeVjLK28RK2xShhFTI4fJZsZplJ9EcMwOXrwixza9xlMw6HTq3Jx6RXqrdXL7yOOYrY2a3z9z37E5laNhx84dPm96rHn5qCFyG3GTiXZM1mPt945zcXFNWan9zAyvIdTp9s0mx7rm1XOnLvIr/6D/42//vM/zbeffWnbCqDV7vLjl9/izXdOk89laTRamKZBu9MlujTo9/pNxkcO02itYJkOQdgnmymTy44QE9HurLJVv
YAfBOTyJSxziEzGIAjyTI49TKd/nuGhPH4QsbK2Sr3xbR598CfY2PTxvBGmpiIMw6Df3YdpdbCdi7rWyG1EJiNCxLYTi8hO5doHA1l3KmgGVzbKS++ra4fcXK6oDxQbELsYZu+q12ynAy++KBatOG6ytPo8U+P30GqZXLhYZ3n1NJmMh2NbeK7N2hr8v/7pt9nabOK60+Szk8REnL/4HOcWXiAIe5imzZ7Zj12KlYtw3SKPPvCXGBs+iOvkaDSXOXH2O/K+Jh9mafV1AKI4ZmOryvd/+CrFYp5+3+fQgQOMDOcwjFiPPTcBfVnfhqiS7MqyUTm3xNp6hfm5SWamS5w++woZby+e57J37mNU60ucOn2Cf/Tr/zvjY8M89shRDMNgc6vGqTMX2dqqsbFVo9Ptkct41JttwGV0eJat6ln6/Sat9jpjI4epVM9j2x5B0MO2pYjEcGk/G5XTRI0len6TffMPk3FHafqbzM3cT6s9QqV2CtvpMlQusLFZ5bmX/hMP3PNZhsojeO4UlrVOFO0j9EeJIw/HO33NLb41N5dMBmq1RIioeiD9vlg7lDAZFB7prJntjciS+0p4KGGiLSE3l3SMWhxbBP4UhuFjm2sAO16z1apYQuIYxseh268TR6OEYZELC2eoNyoUChZ9P6Db7TM6/CDdjs/aWodsZoZSYZJev835iz/m7MKP5HUyw4yPHsY0TTqdGpbp8KmP/02ymREc26Pfb7BZOQPExLHUQEoTBCE932cqnyPwZzh7FqDL6Ii36+fQXD+0ELlNSbewfuOtU/zu7/8Z99+7n1deP0EYZcjnPUzTI459yqUZiIssLL1GLifiYXOrwdmzDq1Wjky2Rz7fZ6siabbEYFoO5dIM+dwo65snCMMAA4ODez/NZuUcm9XT9PtNRob20epsMjZ8mGrtPJ3OFksrrzI9+QiuM4RjF8h6Y2S8Oc5f/CErq2fJ5zLYtsVnPj3GD3+0Rr8XYJrjOO5Z/N4horBAv3uUvv/ae7b41tx80kXNlOhIC4/BGBHFoDBRj6VLwadri4C4aHTq7s3jcoxaOyLjTQFgmklH3E67u+2aXVqCd96R33N+XoKMX3ipR6tZ5MLF16nX66ysvQIG2JbL7PTH2disU8zvIZ/Lks0M4YchQdCn1lgGYHL8XjJeCcv0sG2P8eGjTE7ci+fksZ0sldp5wtCn060SxxGN5iqV2vltnyOKYgI/Q+jP4hUz9Ht9Tp1ZZGT4XgzDuOJzaK4vWojcxqgW1gB/+s3nWd+osVWpk830iWMfYk8G9MgBw2R26mHiuEWt3uLUmQVCf5bxkXkw5lhd+yEABoBp0O+3sEyH4dG9uE6eXGYYy3KI4pix0QOUS9OsrB6j12/g2nnCyKdQmKbba7BZ2aDZ/BF75h4n4xXI58pks3lKhZ9iafVV6q23GRstc8+RvSwurfPu8QvMzx4iDCZwvNP4/b1EYZZ6/RCH9m/t2uJbc2ugKp3atlgz0jVDduszoxi0kChXzqArxveT57J6brhpzM6Mc2DfQc6c8ZmeirGdCqYpVoM4jlla2eTB+w8yOzPOyZNw4YL8lvfeC80mHD+xzmuvblJvnqPbrbO89jL9vo/njnJw35ewTZMoNshkRrBMhygK6XQ2MU2brFdmz+zHMQyDcmGGTHaIsaED2E6WrFfGMKDV2SKKQs5ffIG+32Jp9Q0cO8Oe2Y+ztvEuvX4dy7Q5uO/T+H6fVrtLrXGBQkHywuuNNqVibtvn0Fx/dBTOHYAKYF1cXsf3AxzHpNk+RRi1yXoz9AMoF8pkMzlMo0irOU6t1qUXnCaM2vh+xIG5r3Bwz6eIgVx2CNfJslk5w+bWGcqlWVw3TxD2qTUWLtUSsZmdeZx2t0ouO0QxP0khN04uO8TI0AE6vSaLS68TRpuEURfLzJLJDLN/zzOMlB9la6tNIZ/lK198ilKpx4WLb9LtdOn3R+n3TrCxeYFcNs/0xOfpdPRpeiujXCnppnWws
xsmfX+wM28YJp131WNK5KjUXdAZMzeTTsfk4fs+SSGf4/zC67RaGwRhSKPR4sSpBUZGSnz580/xxhsmFy7IPo8+Co0GtNsRL7+yQr11nsOHyvT983S6PqPD93LkwBcxsHHdUfK5CVxbfvhmaxXTcqnUFji0/7Pks6NMjt3LxNg9HJh7Gs8rkssM4Qcd2t06W5VznDz7LRqtFS4svsjk2D1MjB2lVJhieuJBxkeO8OiD/0dGynvw3AxLKy+zsXWRSrVO3w+oVOuXP8dPfuEpHah6g9AWkTsAFcB64tR5zpy7SKvlkstlqDcu0Gz2KJVmmJ2a4cLiO/iXuqMWcofIZALa3fP0egYZ53EmJ44wPLSXpbU3GC5l6HQrLK29ThQHlIuz5LJDhJHP+sZxhsp7ME2bqfH7CPwuhhmTyZTxvAKN5grjo0eoVM+yun6GuRkPbAtiG/AYH3mIfLdMu21y7z1JvMvpM6eI4ykcu8C+fSYP3LsPgzLPPw+PPQbDwzf7m9ZcDeWi2anB3W7/Dz4XhmIFSTe8U9aUXk9eQwuRm0O9Lq6W6elxfu4vP8y3vieW1d7aFp7r8OD9B/nSTzxFrTpPvS6/40MPweqqCJGNzSZrm6cYGelRLBocPvBxxobH6Pdb5HJj5DJDBKH0COj7LVqdCnEc0e832Tv/CWr1pUuiYoZ8fpQ4DMlmh4miANt0OX/xBda3TrK28S4Ae2afwLZchof2sVU9y+zUw0RRQD43zPL6a5y9+D1c18LzXIkFiaHT6fHwg4f5yS/oOiI3Ei1E7hAOH5rnl/6bv8Sv/t//BafPLtLr+9h2m5Fhn5npaXK5LuXCHmyrSrvTk1ofOGS9Q2xsHuPM5r/loXt/GtctMDPxCGEYsL51kimvzNLqG0RRQKdToVyew8sUabY3sEybUmEax8limS5+0CabGYY4otFeZ2T4AK32GpXKFpbZolicJwx9XGeI4SGX11/zGBneHu9Sq3epVkqMDJcpFk2CAM6ckWC3+++H6emb/U1rdkMtHtPBp4NVVAdJP66sHoaRxIakm3SrFGHd9O7GU63Cyorc378fPG+Oe+/9i9vqEY2PjfPyyyadjojFw4dFhKyvy/+O06HVPs3ExCS97jy2WcBzG5QLc0CMaXkEQZ9mew0Dg26vyWh5D9nsEJ5XJjNWxnEyEMvJZVnOpZMnptpcJpcdZmPrNMPlPbhunkJ+kij0aXcqzM88jmU6WJbN0tqrnF94DtuycF2X6alRlle3+MSTD/F//uWf043xbgJaiNxBHD28l1/9v/wN/slv/H/Z2KoxOz3G+LjU6zh/8Syjw0f42KMPc+58la3Nl2k1DRzXZ276ITxnhdfe/COymTH2zH4aDIM9Mx9nceU1ZiYfotlao1JboB902Df/FL7foe+32KqeY2LsCJbtUihMEARdrPw4luXR69eAmHanjesGGK0lysU5grCNaxVYW495/nnxHx8+LPEu88jkdOIEtFoyKd1/P7z1ltw6HThw4GZ/05pB0jEg6n91U2m86nHFTsKk399enRWSuBMlSnSw6o1lc1PEBMDBg0lNl3SMWqsFP/qR/M7DwzA1JfssLcHEhOwzOgael6HT3EMUz9LprlDMTxETYWDS7TaoNhbJeCWiyOfg3mdotTcpF+fAMHCdLN1uHdvOEBGCadDrNWg0V2m21ljbPMHk+L0YwPDQPirV88xMPkQYhdhWhmZrlbXN46ysvwFxgOs6uI5No9GmWMjxl//C59m7R690bgZaiNxhHD2yl1/+W3/lcmrv2fPLeK7DA/fNc++Rcd548yT1doex0f1UazXanVXAIJMpMTf9BAvLb/Dqm/+Gew5+BZBiZq5boFI9RxB2OX/xecZHD+M5BfK5MXq9OvXmKkPFacKgi+cWiKKAGHDsDLadxzDqOI5FGHXBqGEYDlHYplIpUS5FvP66SbcrgkOlf95zD5w6JTED1ar4mV99Vawj7TY88MDN/JY1g6SDVNMiA
hIhMhgnov6mM2pU513X3e7iAZ26ezNYXYVKRe4fOrRzHZfNTXjtNfkdJyYkmLhSkX2np+X//fthY2Oc6fHPUqmaQBfPGSYmwrayGEA/WCaTKZFxCpRLs/T9DuPDBzFME8t06fsdMpkycRQRhRHV2iLEMe1OlWZnkyDsMTV+P77foXPZCuKBEXLu4vNsVs5Qb1zEc10sy8QwIJvxcFyHTz39MM889fAN/GY1abQQuQNJuzqU2bTd6fIv//c/lmJmU5OMjhTJFwpcvGhQqS0xVBrGNJvMTNxPEPZ49/TXKRVmuP/IT+PaWfbNP0Wv16BUmGFh6WVGhveRz45TKk5img6dXpNur85QeR7bcslnR2l3qxSdHLGRxbIiTFw2KsfJ58YoFUfo+S3a7TJR5PHuu2J6f+CBJCvi0CFZUdXrsLYGjz8uLpqVFdn2scf06vhWIZORVXG6gJliMCh10BKSrr4ahrK/be8sRPTvfeNQ1x6ImyXtJlMsLIj1Mo5FbPi+nAebmzAzA2NjUjvk3Dl4/XWTsZF9bGxcpN2OiAHPLRIEHeI4xrI8Rob2AzGGYZHPjmA74q4Jwj6uk5fiZQRUqucxDKjWLmLaDt1OnYfu/Qu0OxVsyyOKAjyvwGb1LBeXXmF1/XV6/d4lt5/Jnvkp5mbGCcKI2Zlxfv4vf0W7Y24iWojcoaTNplEU8U//2b8dKMu8TC5nEUZVxkfnyWUzlMsHqVZ61NuncZwM6xtnuLj8Kntnn6TbreO6eSbGjuI6eVqdTRZXXiOKH6BcmMbzivT9FpXaBcrFGVxi8tkRsYJgEYQ+URwwPnqUfKGP60CjsUUQSC2CSkUmpGYTHn4YRkbkc8zMyCS3tiYC5PHH4Y03ZPsXXoAnnth5gNTcWHYSIoNumMGaIXBlOXdV7t22rywB77o6PuRGceGCWB4Bjhy5UlwCvP02LC/Lb3PPPXLt1usiRmZnYW5O6oa8+aZsW6+3WFhsXHK/Obh2jl6/iW1naDbXGR3aTxSHOHYW4hjHydDzW7hOFgO5yDvdKs3WKr1+gzDq0+lVGc7s5f6jPw3EFPPjWKZLTMzF5Vc5cfab1OoXGB4qYncNTNPA81zCKKJYzHPk0B4dmHoLoIXIXcAVZZkjB4yQfrDJxtbbzExZNJse5dI4ET2GiochLtDvB3humfOLPyafG2Vq4n7AYGriflqdCqPD+3n31J8QBn2ymRKjI4fw/Tat9gZxNo9pjpHLFjEsIGqBCZZhgtEnClqYZouRkfhyNkStloiRBx+Effvk/Y+MyAS0sCBi5L774ORJ2e5HP4JPfGJ7fxPNjUdlzERRUk8kjkUkqmyYnWJC0uIknfprWUkXXhXE6nm66+6N4MyZJFX66NGdC9G98oq4TEFcqrWaxIRkszA6KrEkYSgWzLNnodFocf5Cg2arheMUsG2bdquGYViEQZfR4X2EUYBjZTBNE8Ow8C+5elWg0Vb1ImHUpVq9QCYzxPrWWY4e+hJB2MdzcvSDLoZh0u3XWd88yY9f+/9gWy4zk/cTRGv0+z5HDu3l4IE5tip1fv6vfIUnHrtXW0JuAbQQuQtIl2UGCMMCcZzDtTNY9jssr71GoxHT7tXxnBny2SlMI8/8zOMsLb+BYRgEUZ8LCy+wd8/TNJqrZDIlspkSD97zM9SbF9msniMI+4yPHsK2M2QykCt0CH0b18qAWQQupT1EBp1ei2Jhnrm5DLWarL5MU8SF78NLL8n9++6Tx/N5CVI9cwY2NmDPHjEd12rwve/BU0/pSepmkq6u6jjbV9AqYHWn8T4dI6JESxQlsSZq/25XhIhO3b2+nDiRfP9HjlwpQoIguTY9T67JalWuxbExeezIEbkuX3lFrtUwjFhcatL3e7hOiTj2ieM+lgWulSeOTUzTxjBMHDuDH3SxLBPHyWJgEAQ+rc4a3d4mPb+DlymTyQxx9NCXMYCsVyaMAwzDoFq/yKtv/h5b1TNMjt1DNjNENutycfk8uVyW+
+87wFC5QLvTpVTMaxFyi6B/hbuAy2WZOyIELKsJSL+GybFH2Nzs02iu0+qeIIovsrrxKqbhASYT4/cwXN5HFAZMTtwr2TL9BsQRrdYGI0N7GRs+yv65T9BonSKfbzA8bJDJDDE6MoJh1un5DWL6QJY4Nun5PrZZIpvJc+pUj0wmYmhIJhnXFSFSq4k599VXk9WZ64qvGqQuwfAwTF7qR/Xcc7KP5uagLFLKGgLJJDYYqArb76uJz7blvrKKpN0w/qUO77qq6vUhjuH4cfn+LWtnS0ivBz/8oYiQYlHcppWKiJCpKSiXxUVz4QI8+6yIkFYLqlWffj/AtYv4fhvbyhCHEYX8NK6buyQ4wLGzRFGI5+axLEdKrffr+H6dMKjgunmi0Gdy/B7y2VFymTKOncM0bcIo4PzF5/n2D/9nOt0t9sx+nEymjO2YrKy/SS7n8NQTDzA2Utbl229BtEXkLkBVXn3z7dMSI2L6OO4iYVCCGIbLB3FdG89dIYgXCeKQ5bUGWW+S0ZGDEMUMlR6h22vgB1VGhvZQqV1geGieenORrDdCoTDG2MgMhdIWGC2q1Sam8QizM7NsbK7T6VQwKQEOBnnCqMnmZptKpc3SMuyZLzMzU8AwJNANYGtLBsZWCx55BEqlZJA8cUJWZ44De/fC+fPS3fOhhyRyX3NzcF2xXsD2lN50QOqgCFEWEZWRoVwyysoCiRjVC9iPHiVCQMTf/v1XblOvy/UVx+IqzWZlMaCCUsfH5XEVDxKGYimxbRVrksGwAxwnT2xYeJkxiGMMw8QyDeLYAuJLjTRj4jik1a4Q06PXa1DIjxGEUCxITJnn5jFMA8s26XS3ePXN32N57XX27/k4lmkThiE9fwPDanJg/wT333OAsdHyFWXoNbcGWojcBajKq0vL65w4tcDM1CjZXIbV9TOcOv8G0xOPkM/niKNJDMPD771FvXeSZmudTm+T+ZlHsEwHjAxhGLNROUchN0a3VyMKAxwnhxXFZLPj+L08YVQl5nXmZh/E9wu4zhyG0aNW67C55WPEJhlvGIMeMTHdTpdTp1fxfZO9e3OEoQxymYwMZlEk/z/2mKQDGoaIERVQFwRiDj5xQgJZDx8WcaK58Xhe4mKBna0hcOXzqmNvury76yYiRafuXh9UzR4Qi+T8DjGbq6siMOJYrj/LEmHS7UpQ6vy8CI4XX5QFQRiKddJx5Lo1TROMGAMHx3GIQqkbEhNgmQ5+EGBfsoBggIFPv9ek290ikyliGBaeN4Yd9iC2iSI5QdqdLZqtRV469i8o5se45/CnKJcKFApZCoUNmi2TdqfEwX0z5PJZGo0WSyubunz7LYgWIncJhw8lpdRVWeZut08+l2F2tkuhkCUKRvDscYJoL+fOv81G9Q1WVl8jm7HIZWYpFefoRnWKuTFa3Sp0IsZHDuD763juNLXmIllvlDDMY1uPA31mZqBWM1la8mi2KgRhk6w3BPhABoMA143p+SaLyxU8L8PUlInniQnYMGTQ6/fF/XLvvWL+NQyJE1lfl1VZqyXxJG+/LYGs7bZsq7lxGIaIRyUmVGxIOmNm0NyfLnSm6o+ox1w3CVL1PJ26+1EThnKtgLhVdqpafOaMBJvGsQSPB4EIk0xGLI8HD8p1+txzch32eklRulpNLCfttoVj2YShh2lGxLFJGAWYpk0cg2MrH1yEYXaxjAgv0yPTyxPHMFTeA1GAbXqXAp+71NsbvPzG/06zdYG984/hOC6e62DbFUZGe/zif/kXAC6Pdyvrlctl6HWWzK2HFiJ3EYP1Rd5+9yy/8Zv/jjffPo1hGNi2xczkPUyMPkgmM8WB+Rky7muY1jKmNUarvYbjFHDdDGHkkc8ViNjANkt0ulUwQkzDwjQdsvYIr77aY26uxZNP5un125y90CPjFcDogBFDZCJhSllcJ8LvQ63mE8ce5bLEgDSbMmB2uxIncOyY3H/wQVmFjY/LJKUCV5UYWVyUKqyPPqonsBuJ6hOjR
MduabsKVewsHVei9rOs7am7OxXT0nwwgkAKBoK4VHZyZ77+usR5gAiOXk+uK+WaOXpULCCvvirPqZojvi//iwiBMDQxzSJh2Jd6IaZNFIWAKVl8QBT1iOMOnh0AHpZpUixMSO0PDGJsICYIG3T9BU6d/1327NlD6M8RRhG2FVIoVnnskaPbhMZgPaXZmXFtCbkF0Zf2XYaqL3Ly1AI/fuktwjAkDCPGx8oEQcSFxbdYWj1JuXA/+dwse+c+zux0jm5/jSAwiaI8cWgTx3vYv2+CE6cWqVQ2McyQYmGeMN6imJ/FMloEQY7l5YgXX4wYm+jS6y+S8SaIoyGIYsQq4gEBBgVi2vR6SWl3x4GhIZmIajX5u7EhE1e7LSIjl5PYEdeVokm1mlhMTp6UGJMXXoCPf1zHFtwIXFcmICU81He+WyGzwcdU0CokVVrjOIkF0hkzHw2+D6dPy/3xcUm3TRNFknZbq8lvuGePXJOrqyJYSiVxxxw7Bu++K7/P1haXrZiqIJ1aPACYpoXrOoShRRiGl3pdARZk7IBs1qPVbtHrWcRxgGWO4jgxjlMGQny/ix/V6Uff5fCRMn/+q/89nuvQ7vbI5SsMD3mUivkrhEa6npLm1kULkbuQKIr4+jeeo1Jt8OTH7ue1N05SqTYp5LMMDxVY3aiwvv4sw0OTPPPkV4njHJ6zn2y2CvRYWdlgcuIApWKebCZDdnoE08xgWiG2OU2Mj2lEWBj4gcH6eo5Go4zrjBGFq9hOn9AfuzTTtAGXmD6Wkb0kdmQQLBbF9TI5mZh6LSspftZoSEGz0VExFR86JKu8RkMGz+Xl7bVG9Ir6+pLJiBBRFg7T3N5nZqesmbRrRmVspMu7x7FMaMWiTs/+KOj1xNUCkukyNLT9ed+HH/9YrImZjFglGw0RGtPTIkSKRdnmwgUuLxw8T+K5cjnZV6Vcp910Mt1EWJcqEJomFAoGhYKD70O5PMnaWgyERJF1yUprks9HjE9E3P9ATLPxVYqFMqZpptxJUzfo29NcL/Q68S4kXeBsbHSIRx9+lAN7nwZKVKpNHMtmaKjInvkSa1vP4QdniOKQXjvHxrpBvgCf+PgE2VyPvt8gkynjeS62VQRkSRvHWQyzRhT3yOX69PsWGWeWXn8e02pgu+uYZuPS9n36fhfLAsuyLseGtFoymC0tyeQ0MSGDoyp+trIiNUTU6s62xVwM4qceHZVBs9uV7VT3Vs31QWW5qDLtahK6mnsmnU3j+0kdEpU5o6wsoC0iH5ZOJxEhs7NXipBOB77/fflbKIgLptkUgTE7KwHgliWpuefPi0BptxOLZS4n12wQJJlTKh07itRvagIWmYxFuWwxNGTi+3KMtTUT27YA91KAqsnICPzsz5p8+Ut5LGMf5dIwpmly8KDuxH0nodeIdyGDBc5GhjMU8xPSeyEIMaxV1jdX+c//s89x/NQFTp05Rb9/kkL2XsbG9nFw/yOUimVsuwrGGlFoYzlFjBhiowtxDowecZgHWpRKMZ5nYpoeYThGs5HBdeqY9hoxWfp+CdvKUCrZhKF5Odre82RycpxkpaxWaCpupN8XM3KrJX1qVMM8VR2yXJb9KxX4wQ/gySdlkNV89KSFiMqASZMOVh1segfbhYgSKKaZpO7qUv4fnFZLKhODuFUGRV2lIgXI4lgESqEgrhjHkQn/4EFZELz8sgjDrS15rtuV38p1RZQoV4yybikRojBNESz5fOJ2azSS88L3ZUFRLMqi4tOfloWGOkcmJ2UM0NxZaCFyF5IucFYs5DCMAMddwrLGiaIs3Y5NMTfM0SP7+eLnP74t2CufG+f4cZNqFeK4xMTEGBcXzzA1sYeIPMRZoAuxR7fXJZfLEIYZxsfBslxcN6RaLROELt2+QxieZmTEZLg0ThR5WJYMbmolbFkyMFWr4pv2fRmIVCn4MBT3TTpuRFV8XF5OVmqeJxaUF16QmiSDfnHNh0cVINvJIgJXio9BK4kSIiptF7bXI
tF8MOp1EREgVo3BonBLS/DOO/I9T07K9ba4mAiSQ4ckfff4cRGFlYr81rWaiE/fl9+810t+r0G3HMg+mYwcs9uVbTc2tpfz9zyxxPzMz8hxVUCtbcs1rWO97ky0ELkLuaLAmWFIiW17HSMy2Ko67JmfptMeZ2nJZG52clvmySc+IRaHc+dM9s3fS60asbJ6gZGhEpY9SRQ4dHodHMdkZGiGbNa8LAjy+SyuG9Fs2gRBiWx2jv37LUolETcrK0l11SCQwarRkMGrXk8GvrEx2a5alYFMFT9rNiVuRPmPs1k5pmVJ3MiFCxLlf++9Ym7WfHSkM1+U9eJqQmIwpTcI5BiqSquKNdFddz84lYpYNkAKlQ02DTxxQiwlcSyFyaJIhMn4uFxDU1Mi3i9eTFyljrPdFaOsHmmhOdg5WVlBTFN+ZxXUCiJuXFesIA88INfvuXPJvjtZcDR3FlqI3IXsVuCs0+5eLvjz5S88jWGYNJuyEkqbRA1DTLVzc/D222UM7ufU2SE2Kxfx/bO4TpbhoXmmJ+fIZTOXVz+ql8zoqIllmQSBDGhnzsixDx6UAef8eRmsCgWxcriu7KeCGF1XahZksxI3oppvVSoyqLVa0qV3bk5Wda4rAiQMkyqs77wj2x05crN+hTsXJRYVu3XdTf9NN8aTWCE5B1THXb0Sfv9sbGxPv003hoxjEeSVivw/NZW4XKam5GZZ8J3vyDWqtgsC+W1dV66ffj8JNE6X9FdYllzT+bwcPwwTEeL7sk8uJxbKn/kZeS0lQgoFWSxoEXrno4XIXcpOBc62F/wRc8Hmprg+Vlflljbtep64Qubny0xNFdnY2EPf9xkfM5icLLG2ZtJoJKveel32bTa3Z0B0OvIa3a6sfvbvlwF0a0tWSZ2OmGZ7vaQCp+or4/siNup1eZ12WwbHH/1Iaorcf7+81sGD4msOgsQyoiqzPvLIDf3q73g8T34zJSzURJKeUAbdNarzrprQVAaOsojojJn3x+pqIh4OHdqeMRZFkvXSbMq1NDYmoqLZFKuIKhT40kvyG2xuJq4Yz0sCT1XszmCVXEUmIwJExYq0WslzSsyUSlIx+b77xOqiOHBAV9O9m9BC5C5msMDZTgV/RkfFZ7uwIJP2+fMyURw6lJjfx8bgk580OXGiwMWLMhhVq7Lf8LAMapWKiArfT0zuzaZsoyqnVioy8A0NyWBYKsnrqmDV9H75fCJsfF+OY5rixun3YW1NtldFzRwnKQMfhuK2WV4WwaNqjeiV14dHiQj1e6nHBlN3d8ugUSb9bFZ+R1UYS5vmr53FRbkOQM75tDVJVSj2/SQ9d2NDtpmdlUXAO+/IddLpJOJDpeaqlgpwZQ0YhWEkVpB+P+kXpawgqpv22Bj8+T8v71eJkLExuWnuLrQQucu5loI/qpx6ECQR7CdPbjedqmyVvXvhrbdk4FL1H1SsRq0mA6TrJr1k2m15TmVcdDoiIjodKSl94IAMVPW6CJNmU0RFq5W4ajxPVm1qcNvclMFPFT9rNkVo5PMSia/KWk9Pi0hqNKSr6FNP6cyMD4tpyu9pWVdaQNL/D/agUX+DIKmk2uvJLZ/XFpFrRVn54MoOuq0WPP+8fL+5nAj+pSWJBSmV5Np9+WURBdVqIhpUbJYKDk+7YuJ4uwhR1it1bXY6SdCqcquWShJnNj+fCJDBxY3m7kJ7XjXXjKrToRpjqfgRZQIGmYQ+9jEJOlPNsVZXZXDau1eEgloRqch605TBU9X98Dw55vHjYrWYn5d4jyhK9rMsWW2p6o2qpkGrJfEs2awcZ2tLBudvf1sEjmHIKrFQkONMTsogq2qNpGMbNO+fTEaE4uBklWanrJm06yAIkjReVY9Cm+nfmzNndhchGxuJCCmV5La8LNfj5KTcvv1tuVZWVxNroooBSbvalPAYTM3N5UTUgPyGyl2qgs7zebmOf/7n5TxZX5dtZ2bk/WoRcveihYjmfZPPi/VDmVBXV6XUc6eTbDM1B
Z/6lAxwcSzColaTffbskdWYbSfm3kxG9leDWT4vk9D6ulhhstnE121Z8r8qJa18z7WaHGNzU4RGJiMTWL0uVpXvfS/pNDo3J5kBIO8pn5eB9XvfSwZzzftnsOmdIj1hwZVZFent1MSVDq7UXB2VWmsYSVNIxYUL0jdG1QhxHLlmJydF5McxfOMbco2urMjz1ar8Ve6xIEiEyKAVxLZF2ORySX2fRiP5LVVK7k/+JHzmM4kVJJMRAVIq3chvSnMrol0zmg/M2JhYMS5cEAFw/rxMQgcPJpkPDz4o7pU335TBqduVAWt+Pik0VizKgNXpyKBVrydVHbtdERadjjy2d69YOdbWZABrteQ47bYMiKoQWhDIcR0niUFZXk4KKD36qLx3z5OBcXg4scQ895wE0OnCSe8fFcisTPjpxncKNZGlH1NxDKaZiBTH0TVE3os4FnGtso0OH97+/FtvibhQIkRlqk1Py4Lg5Em5NZtyXSgRks3KtZXOblIWzDSeJ6I/CBI3KCRF7QoFWZR85SviBlKxK/v2Je5YjUYLEc2HwjBEHPi+mIajSAa2YlFMrsok++STMuEfPy4DXKcj26hgua2tpE+Fav+u3CvKNKxiR2ZnZd+FBdlWrdZUaqGqVaBqUqi4kXTxs1YrqbK6f7+Uvs7lZJ/NTfGV33+/LiP9ftmpzHvapL+TWyadVaPSdnUNkfcmjuV6ArlO9u/f/txLLyXNIicmEhfq7KwsBF59VcTBxob8RmGYVDZVImSwF5BCVUhVWTQqNVdt43myUPj852VbVVBtaEiEiUaTRgsRzUeC44iZVZWSbjRkkEw31pqbk//feksEQbMpA9b0dGKNUCJCCRKVWaPEy+amWElGR2XgXVqSFZyqOQKybS6XPB5Fsr1qU765mcSVPP20iKHDh0VAZTLyeltb8j47HbHoaK4NFeuhVsTpku7pbrxKlKTrTxiG7G8YSWl/z9OBqjsRRYmbMZ9P4rZAvvvnnkuql05OiiumUJBzfXJS6oM0GiLuVU0QZUlUAcPpGJ+0ZcpxkusqDJN+M+o3V+/nM5+R61xZxw4e1I0nNTujTwvNR4qKH1HFlFZW5KZMsbYNDz8souDNN2UQ6/WSsuyrqzJ45fMiAhxHJiXPSyapZlMe63QSobO4KBYUtY8ahBsN2bbfl8HTsmSbalX+fvvbEly7d68IqRMn5DgjIyJYVADgAw/c5C/2NiOK5PdS8SJKeKSLXqXdM+pxlaqt0j51DZErCcMk8yvpQCt0u1JDR/V6GRsTsTEyIjFRUQR/9mdy7qt+MdVqksGWLjC3kysmm5VbFCV9ZpTQ8Dx5P1/6kvx+Khh1py6/Gk0aLUQ014XB+JFz57bHj5RKki577pxM9qqB3fCwiJDNTRncgiBJ8W235XnVYGt9PYkdOXxYYlRUtUY1yamiWKr4kiolDfKaFy/KPtUqPPSQiBFV2XFyMhFSnY4IFu0muDaUCEmvqHfrNaO2UQGqti2/VxAkgcsaIQiS/isjI+JyUdTr8OKLSbzI0JCIkIkJcZNeuCACplJJAstVbJVyxaSF4WBAqmpUp6wg6v2oWJADB+QaUcUGdX8YzbWihYjmunEt8SP794vP+s03ZYVWr8sApqo7rq5ur66qovAdJwmwU9aR6Wl5bG1NBkZVSlpl5dRqIoDiOHG/+H5S/KzVkj4X+/YllSlnZ5PmeT/6kYgnPbBeHbU6VjEicGXQqYobUSg3gGqMpjJATPPK/ih3K/2+XEcg1o1048bVVbmGVA0W15Xze3parrVjx+Q8Xl6W71Q1lFSxVXBlQLEik5Fb2gqiREomI4Lny19OCqCBXL/akqW5VvSQqrnuqPgR5cdW8SOqR4zrSpbK448nrpiNDXn88GEZ6NRgqDq0qgqs5bIMpBsbYhGJ48Tqks0mk5myuNTrIjhUto4SNRsb4pb51rdEwExOyiBumjKQm6YMtM8+mwzcmt1RrjTY3RqSXnGr1bhtJ7VIVPaMt
kLJ+a5EyPT0dhFy+rSIkChKYjdaLTlvJybg+9+XuK2LF5PaPqplgqqbk47VURiGHE/9Fqo4mQoEL5fFivilLyUWlmJRXLNahGjeD9oiorlhqPiR9XVxvQzGjwwPS+2RkyfFjKy6dE5OysS2sZFUVVXZGK4rg2qzmaTptloielRpeddNBkrVqlzFIHie3NptWc11u7LtJz4hA77risCZmZHVZBiKGPnkJ3X64W64bhIPlGZQjAw+ptJ9s9mkrbwmSY2HJGNM8dprci0FgVhJtrbk+1Pn7re/LdeG6lKtKhurANOdfguQ6yyf3zkjJpuVRcBP/qQcRxWd0/1hNB8UbRHR3HDGx8VCoibyc+dEfKigtyNHRJAoN0q1mkTdq3of+Xzin+73ZWBUvWw2N+WYnicix7ZlZQeJdcT3RbioYk3KOtLriQh69lnJmlGF1EAGd1Vk64c/TMzQmu1kMknBuXQMCCR/d0rpVfeVlQq0NaTZTETI/HwiQuJYMmM2NuT8V+m5Q0OyXasl8SLr63L9BEEi7JQo361Gi0rLVVYQ5b4xTTn+44/Dpz+dxImMj8sCQ4sQzQdFW0Q0NwXDEJGg4kdUJkCplKTzPvGEDKRvvinPVyoiKHI5edyyZDBUDdZcN3msXpdj5/MSp7K0lPS/iOMkqLVWk+OpwMhaTQb21VXZttGQ96Ea5k1OSkxJvy91Gh56KKnQqhHS9V92a3C3U9l3SFJA4zhJ371bqdeT+hvpAmBBIELY9+X+xIRYQlS59uPHZb/lZfkuu10597vd7b1iBlGpt1Ek53e6G7LKJPv85xNBo/vDaD4qtBDR3FRU/EizKT7sel1u09MiDMbHpR7BO+/IwKpcKqrC6spKknqorCPlsqzk2m35v9+X4xQKIiI8T15PuWhUV1Dl6lHpv+vr8nirJa6Ye+6RjIWJCVllNhpSOvvIEQnO0wiq++5OabtwpQUknTWjitIpt9ndmjFTqYgYhu0uj05HgqZVD5dyWc7FqSkRCi+8INfP2lrisrRt2W8nF4wi3TG510vSch1HrDCPPCLXkBIhgy4ijebDoIWI5pagUNgeP6Ii/NVK8P77ZUB+/fXE5w3irtnYkPvdbpKloawjrVZiHclmk26+uVyyfa8ng7Xvy+CqXDnttkwIqnfGpz8tK8DFxeQ1VJBrqwX33nujv7VbE2XFSBcxGySdNTPYk+ZuryGiavCAnN/KTVWpwCuvbK9e2m5L/JLnSVBqtZoU7kuf37C7K0aJPSXG1fFzOQmKfeaZ7fEhe/Zol5nmo0ULEc0txfi4mJjPn5eB9Nw5ERQHDsgg+IlPiNn5nXeSgmXj4zJoKutIs5ms7oaGZLButxNT9vi4CIe0m0ZVk6zXk34pxaI81u3K+/mzPxM3zcGDSb+bqSkRJouLcsyPfexmfnu3BumiZSoANf142g2jHk9PbCo4slBIfou7BZU2DpIxptwei4vSWDIM5TtR6bcqZV2l56oy7emicLvhOEk/H+WKUS6xYlEy2UZHk2Po/jCa64UWIppbjqvFj8zMyG1yUgbfjY0koO7QoaSnjKpDAUlmTdo6ksslsSO2LUJFuWgMIwnQUxNhvy9C5/nnZdX52GOyCl1YkOOcPSuPq1ojesUopC0i6aqd6VoUg9sq14xl3V01WxYXk6ZwR44kn/3dd+U53xdXTKsl18L4uAj1lRV5Xp3Hqnrw1c7BTCZxhanePoYh18X4uAhqlfWkysJrNNcLLUQ0tyzvFT/yyCNJnAbI/UxGetqsrYmYULEgcSz7tNsySPu+3KanxbqhXAIqbiQIkhWlqhzZ7cpx031qDhwQsbR/f1KR9dlnJetHB/HtLiSUEBnsRaNiS+62rrsXLiRZKEePJuesalzX7YqlsNUSK8XoqDRmrFZFfEMiptV5vNN3aFlJvIlyxSgrSKkkGTGq67TuD6O5UdxF6w3N7YqKHxkZkf+Xl2WV2OuJCfmZZ8SMrXrR9HpipRgd3d6jBmQQV
kWf6nVZSRqGBKCqLA2VVdDvi+BQroJ8PqlY+e678Cd/IhPE4cMyAezfL/8HAXz3u3d3HQxVSG6wsqr6O9jZFWSSVL1mBt01dzKnT18pQqIoSRFvtcTF2GrJeZrNimVueTkJqFYdqFWBsp1EiEqL9v0kbR3ECjI/D5/9bCJCpqflvWgRorkRXFch8mu/9ms88cQTFItFJiYm+Jmf+RmOq77VGs37ZGJCTNYqGPLsWcliCUMRHp/5jAykmUySJbBvn6z0CoXtMQmlUlLorFaTleWePbKvOr7qBNztimWl05FJQJWXP3sWvv51mRCOHpXtDx1K4k1+8APZ7m5ElWZXPWdgu2smXVlV3VdVVdXkp4I072SOH0/cgffck1jkvvtdOd+aTRHbvZ7EI7XbUsRsYSHJCoMk02Un1G8x6IpxHBHrn/qUxD7ZdmKFLJdv1Deg0VxnIfLss8/yta99jeeff55vfOMb+L7Pl770JVqt1vV8Wc0djLI8HDiQDKwnT4oYsG0xLT/xRFLds9cTk/bEhKz8VLCd6muSzyf1QhYX5bGhoaRXCshg3+vJNq2WHFtV/1xcFFeMSuPN5ZIiamEo6ZTKdH43kc0momK3SqqDLhjLSiZMx7mzM2biWESIKmt/9Kg83mxK9osqQFYoyPcxNSWxICdPighRDR2VcNkN1URQxU0FgeyTz8t5+rnPJZ1x9+wRV8zdYonS3DpcV8Pbn/zJn2z7/1/9q3/FxMQEL7/8Mp/+9Kev50tr7nBcd3v8SK0mNxU/8ulPi8n73DlZ/fm+DLxbWzLQquqSpimDfa8nj6mmerOzEg+iLCKqroJy0TiO7NfpyASh3DhPPy2vAUkTv1dfldXu3NzN+rZuPKo3UNrFMmgRGRQlppkIF9e9c2uIKBEC8h3t2yf319fhjTeSKqaZjHwHw8OSJba5mcQzpW+7odxcKhZEPVYuS7C1csOoIHCN5mZxQz2AtUs1sUeUs3+AXq9HL+VYr6uEeI1mF1T8yNqaDNKq/sj+/eIm2bNHTNkqdsTzZNBVIkMF+KmVY6uVZG0UiyIw1OoTZEBvtZIeNSqVcmtLKsA2GrLKVO4d15XXevddES2HD9+0r+qGkhYiu1VWTdcRURVVQe7ncnemRSSKpO4MyLmrxOn58+Jm7PXknFKp564r4mRlJRHKcPXGiyrbSJ23qlhcLiev99BDibUvXadEo7lZ3LBg1SiK+Dt/5+/wyU9+kgceeGDHbX7t136Ncrl8+Tav2rVqNO/BbvEjtg0f/7hk2Ch3Tb8vFo+RkaS9OcjgrAqatdsiLlRVVtURFmSibLfFGqPqOSgRc/Ik/If/II/v2yeTzeysbH/+vFhH7gZUv5m0RSQ9eQ66ZdLZM76fFKS7kwjDRISUy4kIOXZMztVmMxEhY2MiPN58UzJqVH8k1f12N2xbbmo75foZHZW08kcekfNc9YfRIkRzK2DE8Y1JlPubf/Nv8vWvf50f/OAHzO1io97JIjI/P0+tVqNUKt2It6m5A+j3RYioM7tcFpdNHCc1GVSrdNcVk3evl3QRVdUlVS0G1Rm41RJBka5UaVmy0lQDuvLbT05KvMqhQzLJ+L78LRbF3P6JT9z5vvh33pHGa0tL8puoInOmKROjasamrFFzczIBx7Hcv+eem/0JPjqCQH5/kM8+Pi6f84UXkkrBpZJ8R2NjUthseTnpefReAkRlKcH20vkqZun++5NtDh26u+qzaG4O9Xqdcrl8TfP3DXHN/O2//bf54z/+Y773ve/tKkIAPM/Du5u7XGk+ElT8iApAVfEjMzNShn3/fimVrcpfDw/LwK1iPSARGP2+CJSNjaTWQqOR1BhRgiaTSYJY+32ZRJ57TqwqTz4plpKjR2VyjmP43vckW+FOnhCUNWRQcO3U82RwuztJpPX7UmsGxHI3MiLi4gc/kPOvXk9EyNCQnCsbG4lQgyTGYydUZpLKhlGPqU65Khh1bi7pQq3R3EpcVyESxzG/9Eu/x
B/8wR/w3e9+l/3791/Pl9NotlEsyqpalc1eWpLb/v0SVLqyAm+9JYN3rSYDdaWSlIRXmTUqYFUF/Y2MJF16lXBRJeRzuaRuycZG0qfmc5+ToNr77kvM89/5jgTV3snm8bR7JS080um7O4mQO0Wg9XpinYMkkLrXExGiYpTK5aRj8enTcp6o2jXKunE10taSnawgqk7InSTuNHcW11WIfO1rX+N3fud3+MM//EOKxSIrKysAlMtlsndbEwnNTWNyMimHrdw2jiOCZGJCggFVKnA2KwJmbS0pTmbbMpj3enLb3EwKo7VaiavG9xM/v6pg2WiIO6hehy9+Udw9R4/KezBNsYw89dSdGZipBMVOAatqglUuhHRmjWneGRkznY7EBUHSrbZeF3eVEhqlkjzeasn5ubSUtBe4WkaMqkKbblKnLHZPPpl0xt2/P4mb0mhuVa5rjIixiwT/l//yX/LX/tpfe8/934+PSaO5FtJmchCz9dSUTASvvCIThCoLr4qdqVRftfJUAaqqOZhy1ahVqWmKCHHdpOS2acrrPP20vObqqkw6vZ5sl06nvFM4flzqq6TLlyvRMTws1qc4Ttxas7NiHXBdePjhxKVwO6LSykEyt1RTxrfeSjK1Mhn5vGtr4spTbQnS1qLdMM3tFqV8XkTHfffJYyMjIrI1mpvFLRMjcoPiYDWaa8Z1xV2j4keqVbnNzEjMxoUL4qMPQ7GITE6KBaTbFUGh/PG9nlhAVLde100qXaqOvVGUmNx9X4THd78rJvP775f3s7Ulx3/lFXlsauomfjkfMao42U6kJ9p0d94oSgrN3a7U6/JbQ9Kx9tQpsY7Uakn2VqEglrGVFTkf09aNndipJouygjzzTJIufejQnZdxpLmz0Z0ENHclV4sfmZ2VNFvTFEExNCQD/+amWEdMMwlKVT07bFsmlnRWjWqtrtJYTVOKVr38clL8DGTfxUVJ1Wy3pWrsnYCydOzUryRd1CzdayYMZZ/bNW6mUpFzCuR3dF0RmaoYWbGYVO09c0Z+d3WevJcVRAk1dT+fF9GhqrKqGBSN5nZDCxHNXc1u8SOPPSbm9ZdflomjWpW0y2YziQuxbZk8lRhpt2XytSwRMKoSqyqApmo81OtSO6JahS99Sd6H6yZBrK0WPPjgTfpCPkLSfXtgu+hIFzODJJbEtm/fmIaNDblB0rX2Bz9IatIMD8tnU12cl5aSZnXvRVqEqOqon/qUnDeeJ5YXHYyquV3RQkRz12OasnpV8SO+L6JgaEiyXc6cEYGismby+SS7Rlk8er1kZWtZiXXE9xNXjeru67qy/blz8Id/CD/xEzKJ3HuvpPeqlOBPfOJmfzMfjt2qq+5U4l116lVi5HZjZUWEJUj13CiCb39bfn8lYnM5sYitrIgwuRYriIoFUSKuUBALyMGD8vzevXL+aTS3M3dIkpxG8+FR8SOq70a1KhkvY2OSZjs8LPeVX35kZHspd9dNMh46HZl4lFsGZFJqt+V5VdxraQn+9E9FmAwNwQMPiMWkVpOMmts5zMrzRIgogQE71xBRmTKOc3uu6lWsEUh1315PfjsV7Kwq8y4siKjd2Lg2EQLb64KMjsIXviAipFyWc1WLEM2dgBYiGs0ApZIM8iprY2lJ6js8+qi4bEolmRQ8T4RJoZCkU6qJQblqQCwoKngwihJLiQo+XFsTE/7p03K8+++XbVotWVVfraLmrYyqYZGuCaIESHoSVmXIlWC5nVKZz5+XeB8QS0W1KoXsGg0RIqpS7IULsm2zeW1Bqen7xaJ0lP7850XsHjwo8SAazZ3CbWgE1WhuDFNTkgKp3DJnzshE8NnPSmaNacpkE0Wy4lWTj+PIZKPcMmEoz/t+UgBNuXFUiq/q0lutwic/KW6a06fFpfPss/JYJnMzv40PhuoAe7XmdyBCRFmKbpeMmdOn5TcFESGLi2JB29qS33V4WH7XSkXSc1Vq7nuR/k5GRsQaZ1lJVVaN5k5DCxGN5iqYp
qxAVYXMfl9EyPCwZNi89JIIhEolqR/SaCQWElVfRLkfslkRF0qgdLuJ+6LbTZqf/cRPyOtevCj///CHsiq+3crpDMZ8DKafKmxbJvV8/vYQIsePJ3EbR4/C22+LEFlfl9/I80SQrK2JuLyWgFSFsgo99pi4CS1LzoU7pdqsRjOIFiIazTXgeeKuUTUiKhW5PfSQWEGOHRPRsbUl2zYa4l5JW0dUV1TVAVitkJUVxTTlscVF+JM/kYqrc3Myua2uSkXOBx+8vQpVGcbuZd7Tj6VTd133xr2/90scSyCzcicdPCiN66pV+Z3GxuSzqAJlymJ2rdi2HOOZZ+Q48/O3hzDTaD4MWohoNO+DUkluKkticVEef+YZWSU7jogQNVGpnjXKOgIiStIWk/RjKmBzdVXiQx5+WIIdPU+ybF5/XQIi9+69GZ/+/WNZV2bBpDNmFCrb6GpF0G42cSy/MYiY3LNHglJbLREhMzNi/VpakvNCuW2uBWUFeeIJSSnP50WE3qrfhUbzUaKFiEbzARiMHzl7VoJWDx4Ud00+L9kRnpcEniqLiIqFUE31VAG0KJL7SoyoviSHDknMiOOIW+j4cTnmvffe3O/gWlB1VQZJuyrSBc1uVWtIFCV1XgoF+e2/8x1xm6nKvJWKWMRUVsy1YttyPj31VJJKfqt+DxrN9UALEY3mA7JT/MjCgrhwgkACF9ttqapp2zJpBUHSsl31FHEc+asmL9+XSVkFwx4/LlaWJ5/cXmuk2ZQV9K3MYPquIm0VUcW6VP+eW40wFAEIkklVKMD3vy+iw/clg2VlRX7nev3aU64NQ2KGnn5asrBGR8X6pdHcbWghotF8SHaKHwH42MdkAstkZJIyTZm4Gg0RJqryqsoocZzEnK+ec5xE4HQ6smq+/35pnhZFMiE+88yta8IfrK6qGJyslQvnVssMCgIJIAYRCnEMzz8v8R+uK266ixfFlabcbNeCbUv8xxNPyHmh+8No7mZ0HLZG8xExWH9keVlqQDz+uJjyVSv4UkkmXNtOrCGQuGzSKGESBDL5fec74gpQQbLV6q1da0QJkcHPNShEVGXVW6mGSL+fiJCJCRGQr70m1o9iMQlKXV5+fyIkn4cvfhE+/nE5J44e1SJEc3ejLSIazUeMih85cyYREPv3i+A4d07Ew8ZGEswKSQaNSglVGTSQBLIGgVTqfOEFmbxUrZF6XWqNPPPMrden5VqsHOmusrdKhki3K78ViOvl3DkRHKurktVSq4lrZmvr2l0xKg330UflO9m799a1ZGk0NxItRDSa64Ayt6v4ERWQet99UmGzWJRMC9OU53q9JJDVMJIsEiVGlNUkDMUS8vbbMhnef79Mjpub4qb5xCckhuFWQvWbSZPuIpsOVr0VSpa321IJFcRi8cYbIjhWV+X/paWkE/O1ks9LIbxCQRrU3WouKI3mZqKFiEZzHVHxI7WarKhrNXHdzM9LMGuxKBPcoHUEdq4/oR7r9UTQdDpSW8Tz5DjPPSeFsEZHb8jHuyZUwOpuKLfMrZC622xKzAdIJsyPfywxP/W6pNVeuCAi5FpTcy1Lfv8HHpDffWrqur11jea2RQsRjeYGUC7LTYmRRkNW16p778aGTHD9vogMVY1VkS6RHkWJ1WR1NREje/bIRPnyy2J5mZu7OZ91kGz26jEiKpX5ZqesKrEIIhh+/GOxWvm+WDIuXBDLyLUWKCsUpEJuNivWsduxq7BGcyPQl4ZGcwOZnk7qj6jg1L17xWpQLouwqNWSipwqCHUwDkH9r9rMv/yyTHaHDkmA5ZtvSu2So0dv6MfbEdPcfRJW8TA3O2OmUpHvHqSfy4svJpaqKBIriWpu915YllhA7rlHrCjDw9fvfWs0dwJaiGg0NxjL2h4/EkVSP0JVUFWTYq8nQZOq5shuhKGIjnffFReCCmI9eVIef+yxG/fZdsLzdhci6QZvNys+ZGNDbiCxHK+8IpkxhYI8ruqFXAulklhB8nkpTKb7w
2g0740WIhrNTWIwfgQkLkGVkV9eFnfNTq6aQaJIrCiq3shDD0mmx4ULEnvyyU/evPiLTGb3QmXKInKzuu6qUv0gouiNN+R7Hx6WoNRa7dqqpFoWPPKICEzdH0ajeX9oIaLR3GQG40eKxcSKMDQkfUva7USQXI1+X9KFf/xjyajZ3JTJ9jvfkayNm7FCv5oQUTjOja8hsriYuFvqdfl/c1Pex/nz1+6KGRqCz39e/t4qcTkaze2EFiIazS2Cih85c0YCN2dmJIsjlxNxsbIiYiQMr24dCQJZ5b/yiqSK2ra4GL79bfjMZ258GXVVvG0nlEXkRhczUxlHIN/r6mrSrPD8+WsrUGZZUj133z7dH0aj+TBoIaLR3EJYFhw+nBTUKhRkgnYccRdcvCgio9+/ussgiiQ+5PRpETdTUxLr8J3vSOGzGznpe97VJ2nDEKFyo6qLnj4tMR9hKFYQFY/TaokYuRZXzMiIxIJMTkqBM41G88HRQkSjuQXJZLbHj0xNiTgxDHEBLC3Jit73rx7I2u3K/t2urNqrVWld/+STNy6bQzV32w3bvnHWhOPH5fvq98USsrQk/9frYn16ryqpliX9fubmdH8YjeajQgsRjeYWplxOAldBSsWvr4ulZGVF7vf7V3fV+H4S9HrokEy6zz0nwZUzMzfkY+wqROJYJvPrXZo+jkWEgAi45WURIYYhbq9rccVMTEiczdycxPFoNJqPBi1ENJpbHMMQwTA5KfEj4+NJJ9jhYcmUaTSuHsgahmJdeecdiWloNOCllySg9eDB6/8ZrmYRud6BqmkRUq+LCFlfF2F2LQXKLAs+9Smp96L7w2g0Hz1aiGg0twmD8SMHD0pcg2mKyFhZkeeu5l5ot6Xg2fi4bPf663KMhx66vu99NyFimuKWKZWuz+tGEZw4IfdXV+VWrYob5lqyYqanRYTs36/7w2g01wstRDSa2wwVP1Ktivg4elTcDKWSBLPWale3jqj4iGJRBMm778rE/PTT1+895/Pby9Sncd3r06gvDKWoG4hw29iQz1mpvLcrxrIkGPXgQbFEaTSa64cWIhrNbcrQUFJ/BCSGIYoko2NxUWIhdrOOhKEIGd+X/S5ehG99Sybf6+F6KBQGhUgMRMQxRMRksxbw0b1wEIjlJwzlb6UiFpBK5b0DUufnpQDcwYO6P4xGcyPQl9lHRBRFLC6t02x1KOSzzM6MY+r6zprrTDp+5PRpaXZXr8sEurUlVoCrpaO2WiJaxsfl/z/7MynO9VFOwHEcE+FjmBYMxGNEccxWs8nb57vsmy4zVPjwdd77fYml8X35TjY2xErUbF59P9uGL34RjhwRkafRaG4MWoh8BJw8tcDXv/Ecp84s0Ov18TyXQwfm+coXn+Lwofmb/fY0dwGWJROoih958EEJYlUZN1ebhJWrptcTa8E3vgGf+9xHExNRa3VZXG+wttUHJoDt4twwY4LI5/xaldVakwf2TTA79sEDRtTnb7dFjGxtya3Xu/p++/eLFeTAAR2MqtHcaLQQ+ZCcPLXA//Zbf8ji0jrDw0XGRoexbZM33z7N0vI6v/gLX+XwoXltMdHcENLxI4Yh9UdsW/5fW9vdOhKGYjnodGB2VsTIpz8trp8PQhzHrFSanFmq4IcR+ayLbUO4Q+yK50HGten2At46t04h41IuZC4fp9X1CcII2zLJZxyMXZRCuy29dapVESOVinymq2XFWBZ85SvSKPBGl5jXaDSCFiIfgiiK+Ne/+3VefPltDMPg4uIatm0xMlzi4P45Nrdq/Mk3nyOKIv70Wy9oi4nmhqHiR5aWJCNmaytxU9Tru+/XaskkPjcnJeGffvragzWVaKg02ixvNFmttgiiCNc2iaIYw44YtIhghARxn27fx7ZMOn2fMysVHjk4Rb3dY3G9Qa3dJYpiTNOgnMswM1bAtqxt4qTVMrh4Uaw/CwvyeWu1q7/fQ4ekLsjs7LV9Po1Gc33QQuRD8IPn3uBbz75ETMzIU
ImRofsIQ4NOp8e5czaz0/fyzjttTp/+AX2/zcT4NNksdDqtKywmGs1HjWHIJBuGIkKGh8VdsbgodTR2y6zp96XfyuSkVGF94gmpPXI1as0uZ1YqrFVbNNo9wlAiQi0TAgMa7R4QMDjkGGaE6fbo9EMsM8S2LBrtHiuVJhfX6/T6IbmMjWWZhGHEaqXJwkaNrGtjmSamaeDEOQy/yMqiy8qKiK1ud/f3atvw0z8tVhDdH0ajufloIfIBiaKIb3z7BTrdHnvmJjBNi15/lYw3S7GYpdHosFmp0Wy2KZeKHDxwWLIGIsh4sH/PIZZXNvjDPz7HT//kLIZh4rqXzNQZ+as6sGo0HwYVP9LpJOLk7belpsZu1pEwFGtKqyVxI80mPPDAztsubtR569w67W6fnh8Sk+S/RIAfSnaM6frA9sAT04qwPVFEYQRhFNLu+ixtNOj1Q0p597Irxo9i+kFItx9gGgYTQxnqdYPl1Yi1iz5h32Jj/eo11++7TzKDRkev7bvTaK432m2vhcgHZnFpnaWVdQr5LH4QUciN4dglgqBGGPWIaHLq7DqdbovHH7mfWqOHa+3Hy5QwsYhjm5FyibW1FufONRkfL2HbUmXyaueg6laaFiyZjKzsdJCd5mpksxI/UqnIObO+LhVHV1d3t47UaiJgokjSX596KnkujmOWtxq8cXqVrh8QX8qLNS+l6cYAMcTEhBG4XkB74PimHeO421+80elJT51CIlriOKbe7hFEERlX3DKVLZNOw2HlQoZmw6DTvPSCO6QBOw589asiRO6yMV5zC7NTosPB/XM8+vARxseG7xphooXIB6TZ6mCZJuPjQ2xs1CgVPCwrJwNvHOA6ZRw7j1eErc0iG+t9stl1ykUYGS6SyThYjonf7OMHfUxThEQYSoS/7yc3FWxnWTKgOk7SKGy3pltKlCgrS1q4aCvL3c3wsMSQFItSc+TECTh7dvdKo/2+FAYLw5hKNeSpZwK6vs9mtc35tRqtbh/TMAjjmJjtdTriGC55aTCdK9WOZYUY1vZo0jCCarNHPxDRMVTIYBomPT/EsUwMA6obFrV+zMr5DO2Wjd9VA3UEbL8oHnpI0nKvR9E0jeaDcvLUAr/5W3/EVqXOzPQYuazH4vIGf/gfn+X3//232DM3ydjo0HvGE94JFhU9JX1ACvksmYzH7PQErWaXpZXj2JZFuxPQ60MYWhg4ZDNFSsUcxWKeIGyysrbFVtXh4P4ZbMvENLsUi/IzdLuJ+NgtuyGOLw3uYVKwKgjk5vvJNkq0KCuL48hjO52faSvLoGtIW1nuTJSLZmpKzo29e+HYMSlstpN1JIpCTp2C0pDPxbUeU4dXMQ2DIAwxTYMoircJEINLFpFLFLY2yJ05y5bxU9uOO7l4nOJUk+b49ojYGOj2fXp+QKsbMJT3iOIYI4LNdZtG1aGyVKDftrgiAPYSrgt/6S/plFzNrUcURXz9G8+xValz5NA8hmGwuVXnxMkLhGEEMfh+wPBQ8arxhLtbVI4yPjZ02wgTLUTeB2nlmct6HNw/x1vvnOGRhw5z7O3TnD23hB+EOI5Nt9vDcWyC0KZ66gT79kyTzWaIjZjltTphdIZyucCD9x/lsUdLBIGsPHu97d1UB6tAhmEiVpT4UAJkcDu1fxgmPUiUcAnDpNKlYSRiRYkX277S2qIGc8e5UrA4zkf7XWtuDOn4kVxOsmXeekvcNwmJKq5XXbpd6PUnmD+yTp+IKLqyVOmgCPn4v/vXrFvzvLpvuxAptzb41O//Gy7c9zBvffYr256LInBs6AcBW40Q0zDY2MzRqLpUl4YJ+7tHmj76KHz5y9e/q69G814MzhsAp84sc+ytBWam9hBHRaLY4My503S6PcZGy/R6PpVqgziOOXJonhOnFviTbz7HwQOzl0WFsqhsbtUoFnIElsnq2havvPYuv/8H32LP/BRjo+XbIkNTC5FrZCflOVwuYhoGFy6u0uv2yeczFPI5mq0OcRwzM
zWGY9s0GkXCYIiMW8a2bVy7R7PZYW5mD4888Ay1mnlZBORy8ncnt4sSFUqspIVLWngoomi7WFH3B0VOHMsxgkAmpPTrRdGVq0llZUlbWwbdPYYht0HB4nnaR38rouJHJifFSnLsmGTaBMGVIqPfdamu2oSBwej+ZRw3INuoYft9AselUywnJ00c8/gf/S5ev0vJ2Ro4UkwuEH/Qnrdfp5/JcvITn728X6ZRw+73KPR7+G6Gre4ElW6O+mqJONxZ+Vp2xF/8Sz5zcyau6/BRlo3XaAZRY2q/v/2vsmovLa/z8qvvsry6SaPZo9XsYxg2rpeh1ZrFNoaIRkcxDBOTKnPTs9i2TSEXsVXZYnVNVgTTkyOcPL3A4tI683OTly0q5xdWqDdaLC6u0mx18IMQw4CM51Iq5Tl8cO62yNDUQuQa2MmX1+70OH7yAqvrW3Q7PdY2KpiGQb8XMDd7mOnxEYaHCoBBNtPHNEaIQoOQGMfOUy4FPPLAQ5RLE9RqIjquxXyc3kaJgEJBhEvaqqFEjLK0DAoXZX5PCxd1UaXFy6C/3zDkeVWFU1lV1HaD7h/T3NnaorZRn8e2twuWTEYe0yb168+2omGewZ59YLkwN2/y2qsmm5tX7hOFJrW1ImGzy1O1/8DepdfINeoYxISmxbl7HuLkM58nu7VBrt3EADLhYKhqTMnfuCwVDrzyPCc/9gyFepXJ428y+84xMt0WAKeHPsba+GepeTlgZxFSnF5n7z0VVttZKmctSlmPkXKWrOu8ZzE0zd2LWuClRYQaLwcXbunFnLIsK9f4oPV6a6vG68fO0ekZ2PYUUVDDsnyiKCDoh8RxyGZllUZzg7GxPL1ek2Ihi9/3aTR6rG+t0u6YLC37lMtlPKfMqdMx7RZsbNZ4592IbnuSTqtFsTCKH1yg21uCOKYV9jh1eoFDB+Z2tajcStyVQuT9BPfs5MsD6Pd9KtU6tWqDkZESEwyTybr0uj69fo1Cdoxao02j0aLXD4iiNRzbIpfLMjM1Srlc4MCBUTwvuRAGT+wgkBNeneCGIRO9uin3ibqpgNfdUM8ZxpXCRf1vmlzhJlIXJCTvRb3PdFBt+mJMb6eEjxIsKobFNBMBpt77oJUlLarU+x8ULNrK8sFRJdhr7S69fkCnH2AAnmvjlm0O3J/DvWCxerFAFFx5cjUbYxzrPEk5WKGIVBAzo5DDb7/KgbdfpZPLXxYabnRlcY/h/url+xZw+PlnmX3nDbJ+73KcyaniY7xR+jwXzUMQuhISkn4rps/+x87j5nxiDDKuTRTHnF+rcXa1Qj7jknFtyrkMs+NFyvmPoHa95pZjUFCkxy4lHtKu6cHxNgyvXoU3jRqLbVvGn3x++yIsDCNefv0Ezc4mE2NDnF9YodVpkc9liKOQWrOG7/ewLJMghM2tPpZl0u70qDda9Po+w+UpSiWJsG7WezSNHhsbbUaHoVL1qVYCgiCD52XIZi2CoM9QcRrTtAmjiFpjmbfeOcOeuUlmpka3WVRuNe46IXLy1AL/6c9+xLG3TtFqd8nnMjx4/yF+6ktP72i2Wlxa59SZBWamxzAMgzjyiCKbs+cW6HZMRodHaXd6ZDMlCvksI0OwvrHF2Y3nqNakwYdlSqR/NuuxvB5xfsHliz/xNPfdWyaKrnSdKN6rSyjIhaOOkb6g1H3Y3vX0akJmt2BWdQzYbuFIu4/SYkpl/ajH4/hKS0t6paFQbqcwTG7qPSvRko5dSbuIVBZR2oqiBonBWJY7fVF8rWXRa60uJy5u0uuHWJZB1w8JI6n5QT8g69rkhhtMWD5jEz4XThepbaasESGAwUr+fr5plXlq84+4p/7C5adNIN9uXf7fIMaIe8SG+MmtqEs5JUQADr7x4uXQ0xiDE8XHeXH0p1jPHuCy+kglxpRnV5g8UME0Lj0bQ98Pafd84ji+5F6M8WyLrUaHVq/PkblRLUZuYZRYGLRQqL9qzFPjyaCouFYxAcnix7KSEghxLOONc
k2nEwfSY6a6rxZR6e23KnXW1rcYKhTodg2CfpZiPodhBpiWw1CxQLfbxXE9iC1M06GQtak3WpSLIQYGpmVgGiYxYGVjXMfmxIkA145YW7fJZffi2DGWZWMaNp5TZKtymjAKME2LKPRZXm1Sq7coFLL01rZotjq7fhc3k7tKiJw8tcD/45/8DidOnZfI5EucObfIO8fP8X/6pf/iCjHSbHXo9fqXg4zCsMTaWpNKxcWxZ4ljKOULDJWnZdDP5xkpBWxs1YjCCNOUwdOyTGzbIQoD/DCk0ypx7Fi8zRqh/npekna7U8bLoDVi8H76wrmamFEXTVo8KAGwW12JtFVmJyGzk4tJZeRks4lQSFtH1EpEPb6Tiyj9ftT77vW2D0DyPSeiZXCwSFtblIBRomanWJbdUqNvZdIWjnRZ9EFLQBzHLK5L0bBizmGz3iGMIjKODQb0+iHtXsBoMUOj3SMubjB3b53MUo7VsyMQbB86qpk5vjn1CzSdET62+XXgyswZgC9c/Od8f+Jn6dl5DjVeIRttd9ckhdBM3il/gh+N/QxN98oVnONXuOfRUwSjBS7Hyl6KS2r1+oRRjOdYRLEUQYuBUt6l3uqzuN6glPO0m+Y6MygoBi0Vg9aIQUvw+0XFpaUtw4aRWELUGKD+qvtp0aEWdZ1OYhFOj0Fp8ZMewwbfb6+XxbMexsCBKGK43MOyTaKof2lBG9LNtCnmc/T8CL8f4jgulukDBoZpYhomYeQTBj5RHGEasLIWUMg3cF0P1y4Qhj05VlDFD7osLr+CH/UuzW8xnutQqdSxTAPPdSjkP3x36+vBXSNEVF+Y14+dwHUdSsU8jmPh+yH1RovXj53gt3/vT/i//sovbnPTFPJZPM+l3elRLOSoNc5w7J1zrG82KOYLGIZNu7OFYfgYhks/KOPYLr7fxXM94tjAc4uUS3NiGTHBwCCKQo4dazEyUto2eaYnUXWRqAvsvdwu6iJU+6cFjhI26rFULOG2C0z9HXSxXPl9bh9AlPtlUAjtxE7iJS0a1CChSFtH1GurgUUNNOlsovT7j6LkvalBJj1oDFpa1OCkviclXlw3GcAGXUO3ipUlbeFIl0XfyRLQ6vrU2l1yGZsgjFM1OuSDOLZJzw9o9wIMEwI/ArvL8FyH0e5Z+m9GnCo9AUbyQ/lWge+P/jRW3OfRrW8BV4aK3t98ifuarxACDlfONgYQGhanC4/wnYmfxbdHrtjmMxf/Gfe3f8yPH/4F+nFSHMSyTBzLIgxjHFs+i2lAEEaS+msY5DI2tXaXVtenkNX13a8FZdHcSUx0u9tFRHox817jQHqMUSJAxZ2lrbiDLty0FUJdd2nLclpApMWNEg/qddRrwvbxcKf3lx5P3mtsBIhiC9PMEmOAGWIYYqEzDYcoDgmjEMu0sV0bx43pWhG5nMWFi6uEfkAQBcRxiO93aLVXsGzwAxElfX+SL33+MSrP/ZATJ88AITEB7U6NTr99ubCgYUDfD3j35HkmKsM89eSDzM6Mv/cPfhO4a4TIwsVVnnvxTUzLZGy0jGUVyGX2ADA6HFOtNXjneMyPflRlYmKISrVGr9clm3HYO/cQZ85eIOsVubCwShQZuI5J3+/S7YZEcY9Od4Nmq0GMj2WZNJsdysU8xWKOmcn7yYcmYWRg2RZxFNJotrHdEuPjpcsXRNoaoS5w9Xj6Qt1ptZBeAaRXBVcTOGkLxm7xJYMCJ21NUG4aZSVJB6+ms20GL9hBEaNWILuthNRx0+9pJxGTzSZFq9LHVgNietAatMao+2r7RmNnn/FO1hYl9JR7SH0/O8WyXC8rS9rCkS6LbtoWpbx5hSUgCCXt1rJM+n5IFMc4RiLATdPAD0Lq7R5xxOUU3RiY6Zzi6PK3Odh8jR+M/wVaaYuFVaBtFokxMK6wh4jQMIh2qfwh1J1xltx9V4iQTH+N/+rMf48T94kAu5/EnMRAHMXYrkkURQRhRBBEqHdtGollMuoFBOEOJ9odjhIUO
4mJbje5JgfdHbsdC5JxKX1LX/9q7EgvLNQ6b3DMSwsGtWBIX5/p46v90+9np3Fm8P2m/x90Q6v3OzhGpq/5wXi8wbEoGUtNjr21zMbWJvmcw2ZthXanRTHvERNSq9coFh0mp2a4uLTM2EiRxdUNjr1zXBYB/YB2p0mrs4ll2QSBZGIahsHqZpF68xiTk8O89Op5/CAgjmLCHQZP27JY36gSBCH3Htl3Swaqwl0kRM6cW6JaazA+OgwYZNxJspm5y8+bhk+n0+X7P1qGeJlqrU4QBsQROE4exzrE4mKMbRxl//y91OoNLLOEabnYpkEQBYRBSBSHokjjGMexKBaz2LbM0pfqTuIHEZZpk81kL1+Igxdd2mSYDlZN39Q2g7e0OVJN+mmrQfq1BlcJV2MnsaNeL32BXotlZ3DgUBexCphVx0y7UNT3MegT7vWkJ8rgwHM1K4V6f8rSkTaxpgParmZ+VdaWZnP3AXsni4uyTCnRksmIiEoLlvdbsj9t4Rh0OexkCZBiegZhGGEaBqZhEMUx1qV9oygmjGLavR7hQJ2Q0LFx4oD76s9zoHWMl4a/xEujXyI2czj+JrOtd68QITG7J9IOPlfurzLWu4AEg5hAzFfO/SPu6by1bb/A3V4kJIhiOpfcMpdXrYBjycDuOpZ8XtPAtm7NAflaSK/+lajodkXMdzo7uzx2Ir3ASd/fSUyoa1ZdE+pvOshz8Kbea7q0wE7XUtrCkLaEwJXXwKBl+Grj0KDrOC2K0uPS4DWatsAoV3J6sTU4Vgw+JjeT8bEiZ89dYGVV4jKi0KFWj3Ecl3JxntGREhubIcOleyC2MOIC8zNDhGF42d1YqZ4lCH2qtQXqzWVMEzIZj7fePcOZ8xkR3cHOg08cg+c5HDk0j23bvHPiHD/x2Y/dkmLkrhEilzHkF+711zBNj37fp9ls0e706fX7VOtnKOSzjI0NU2+0aDQ69P0Qx7bJZ6VGiOt69Ho+cVwn45UwDKnuaBoWhmlhWw6mYWOYNoHv4DoZLFsthWP8fpfx0RyeO8LKSocgCLFti2zWwzC2nyTpCXU3oZC+OAdNlrD9gkkfU11sgxfV4HHUKn7Q95oerNSknLY0pC/Qa/55dhE7gwJnUNgMxp0oMbdTUFl6+7TgGGS361V9PnX8waC2wQEq7feOoqtnN6WD55RQymYTF1AuJxH6g40R0xaOnRi0BOQzDuVchq1Gh2LOwXMsOpeayakYkb4fXiFCiGM6+dLleNFM2OKZjT/gwdr36Bg5HAJK/mC9EOi5Hl6/t6sYSfflNYm5p/kKE6f+e6I4YiJYxRwQNiEQOFdWK/PD7duZQBTHrFRaEryKwUgxSz5z8yvwKUGRFhPttvxNxz8NxkjB7iJiUBTstNBJ34cr3Q47Ta47uSzS7CQY0vfT/ytL6qB4GLxW0zd1HHVfXXdpy+/gmKH2SS/sruU3UYu3D8PWVo1z55fwPBfTMuj3QwzDpN8PCIII27EBk/HRMhPjoxx76yS9nlSZjGPodCr0gw5xbGCaFp5XZMIrYVk2lumQzYBl2MxM7b/c6TqNaUCru8b4uMnHH7+PVqujs2ZuBQ7sm2GoVKBWa+GNuwRhk7WN11he3cD3A4IwJAxChoaLrG/FvHKsiWkaOI6NbZn0fUlrDIKQbNaj2+vjOg7dnk/Wm6JUnMWxHWzbI+N5xFFEPwgBg0I+y/BwiV63T6vdI5NxefCBh7iw+DobmzWCIMSyLMbGytx3dB+TEyPbJstB02TairHTY4MDziAflXtAXfSDg9Ru7pVBMTQ4MOwmWgb3U+wUP3MtsTSD/++0b/rvoHDYSQQpV1U6zmTQBJ0ODk5PLoNCZidLUXrVl/7e1GtHsU2tPUQmA5ksOG5MJheSzUQ4XgxmgGEklgDDMJgdL9Lq9Wm0fbKeQz+IpHldBH4YEAyIkMLWBvNvvcrcm68yeAqV/U3K7FBwBPCB17/4VWbfPcbE2ZM4UXJSdl2Ptz79ZYqba
xx59fnLQsUmZMJf2vF48jw8+u/+NW/8ub9Ic2Jq1+0sy8QywQ8i1mptZseKzI4Xr7AafRjS8Ulpy4QKeBysNbGTIBi0Ygy6R9KCIi0QDGO763EnF8TVGDzf1f30+JM+lwetD7tN/unrJp2av9PiYPDcfq9reCcGF1OD7uzB8VLtMzh2pRd+adGzm+V5cPyR40ScPLNEpxcwMTEOGPh+QBTGmJZJvdGhVCxw+MBestks1WqTXi+D53o49ihBEFMq9vD9NoZh4QedS4u9CD/oXHpvJp5j0/ODy9b2OIZa/SJ+0Ma2HfLZPN12haWVTSzLZGl5g29990X27pnmwL4Z5ucmbxnriBHH73Wqfnj+l//lf+Ef/aN/xMrKCg8//DD/5J/8Ez7+8Y+/5371ep1yuUytVqNUKn2o9xBFEb/6D/43vvHtF/A8h0I+x/pGhUarIz0zgpBcboiJsSEWl9axrRJjo3vJZjL4gU+73SMmwrEsYiMmDALCKCaKQor5SbKZMvncKLlsmYxXJIrEBNwPAohikahRhGkZeJ4LxDiWST6fxTAlirrT6+HaNnv3TFMq5Xe1dAxeKIOWip1cHzudb1ezVqQn/sELeafn0o+lzbKDZtbBAe+9nlPWlsGBZCe/cfr/3T7D4OfeSeQMDjBXuw1+54PfbfqYg6SPkw7STQ/w6ntMr2AH0xTlNWJ6fZ8ginFsMC4PmLF8h3GI61iUCx62bWBZYmmJTR8/ahPbPUyrR2j6mGZIL/Txw+iyMMhvbXDPD7/F2MLZq8Z3DBIDFw/fx7Ev/GcAZGtVhpcXsMKA5vAYlek5ME0KWxs88e9/h0y3/b5qocbA8sGjvPblP598r4BtGfL9GGCZBjLMGTx2ZJr9U8M7HksJACUmlHWi09keV7GTqEw3qkwLip0mREhExE5uid2uVdh9Vb+bRWKnc3bQ3ZB2XajzT1ks0oJhp/e000Jpt8fV507vm574d1skpL8Ttb3aVv2fHmsGXUHp9zF43e8kSq4m5q7FwhuGIZ1OB9MyMS7njkWXbhATEYYR5VIRx3GoVutU6jWIYqJLBw+CnsSFAH6/Rc+XUhDVxhJxFOE6Do5j02jV6XVrhHGEgcHS6ht0ezVs28JxxM4wMlSi1mjS7fRwXJtcNsPkxChPf/xBfv6vfOW6VVt9P/P3dRciv/d7v8cv/MIv8Bu/8Rs8+eST/Pqv/zq///u/z/Hjx5mYmLjqvh+lEAFJ3/1//pPf4fip83R7fTY3a1iWhW1bjI3MUMjvIQhCGo0WxcI42cwwruviWAXi2CRG0iGl1oIEDkVRiHJym4ZFuTSJZUk0vus4BGFIFEY4jkPGc/Fch1arQ88PcBybXC6DZSTry57fx3MzTIwPQWrIfy/T6LVuo9hNEFzrPoOPfZgFZvoY6YEJdh4cBh/bSVSo/wctGYOvqe4PTgq7HXdw5TX4/G4iKP1auw2mO73nwf0GH9v+/kP8ywGa6iAxcSTCxLFMTNPaNgnJ/YgojjCIMS+tWtu9HmEcYZlgWCFjGxcYWztH0d+i1N+g5K9TCCp4UWdX4RABlelZ3vrMV2iOjO2yVUJha4PH/8PvkWs1rlmMqK/kwr0P89bnpFdNHBnEoQGhQ+BbmLFLFFp02ibjxTKulaHbTQI1VdBm2oWWdtcN/rbXyk7n8bVY7NILjUGXQ9pCprZJ75cWNDu997TFIP3eribad/scg48NPj547N0eu9pY9UFmp2sZiz5Cg9jl4yXfdYjv9wfcpIkQgRg/CCkUcoRBSK1eJwwltlAsGzFB0CG8ZDkMgz6dXlWejwPAwDCg3dkiCAM2t84QxWJeVbEkaVzXod/3JRbMNDAtSxbjuSwfe+y+HctWfBTcUkLkySef5IknnuCf/tN/CohlYn5+nl/6pV/i7//9v3/VfT9qIQJJQbMfPv8Gx0+co1QuMDUxwvTkGO+e3KJaaUnjOtsmjGKmJo5QyM1Lym0MpmmTzw5jmBbxJQVrXDoxV
BEawzCxLety8zsMg5HhEvalznD1WvtSoGCM69oU8kXUxBHHEb4fMjU5huO421bFisGVRPqxQXZ77Grb7jQZ7jQJXm3lcC2vdzV22menSfy9jvF+Xm8QtdIafH43wbLbcXcatK9lUvtgV+bVciZ38sm9R47l5c0iuBRsDWDGIWYcsq/+Oj+9/BvbhEMMdGyHlXse4uIDj16TCEl2jslurHHft/4jo1tr23zHNXuIN4c+w1LmKG23hG9m6BsOoekQGzah7cI1SZjbsEDMbcTVFizXss973R8U4Ds9v5PleCc3yuDiZNDS+V7vYfD4AP1+jwsLyziOnRIj8eW/YRjQ9wMO7J/hnXfPUm+08Pt9wjgiTt2iS3/7/RZB2IUoIiImiiMsw6DWWKTeXL0kUGLiOKTWuEi0y+Diug6OYxH4IaZp4jgWIyNl/vJf+AJ/+7/+2Y/cTfN+5u/rGiPS7/d5+eWX+ZVf+ZXLj5mmyRe+8AWee+656/nSu3L40Dy/dOBn+cTHH+B//Rf/lrGRIaanRgE4fe4ildoCqhCUZWcIgnlW1t+EWK1I8oTlWbLZDL1ujGXntg19hmHi2DZOLkOr06Lf87Fsm0o1IOOVMUwTCUByMI2YOLIJYxvrkvXDADD8y71c4NompPfa5lontd1W8h/kWFfb9v18pvdzjPdafX3Y13yv/W4+FokpWGGy8wR9jSLk8jESIsMmMuBU6RG6q1myUVKxMQZe/6n/nMrs3ve/9DQMOuOTvP1Tf4Fnfvuf///Z+/MgSdLzvBP8+RUed2RG3ldlHVlHV/V9ooEGARBoHKIIiVxqOBrNUNRyOKMZaWdtqVkZsWM2szSbNZpM1I5WMhuZtMOVSF3DkUiKIIXGQRAkCKDR6G50V3d1dR1ZZ95XZMYdfu8fb33lkVGZ1VXddZc/ZlFZEeHhEeHh/n3P977P+7wQyffwMfgP43+LSu7Qze0vwR3HThGS3aJ5O11vHxQd/CB8EBG6VdGQ3fdjkbJG8YPwSmoGuiOUQRBiWzpLiya2NcXQAFcEp9GVlIyDF7R3HVPEI1Oe3C7jimg014mi7UTE81u0O5voms765mlWN05IBAao1Zocf/fsXRex3lYisr6+ThAEjIxs/4IjIyOcOnXqmu0dx8FxnKv3a7Xabflcuq7z3NOP8MyTj3Di5DnGGEDTNA7sneTipSXaHQfXC8gYHstrb1GrNTEMAytl0mp2KOQzRJpOf+ERIkTl3H1OpuwUmbRNu+MQhREjQ2XCyKblWLgOtNsexWIG25ZtprOT2BkxmvIcFy/wmJl5gmIxtWsIc7fwZvdjvcR4t9fstjq/3nvsdJHsJlLt/VzX28f1og27pUN6/9/7eXrDzTs9/kGplRv5zjdKjK63j97tPhzB0fjgVf+tYk4hNWuIjHN5+7uHwUca8U3PResaUCMialbp+i8KuMFgR8SNRU4SfFR8mIXIjaZt7l3ogCXmfEEkUXJUOiy6cuaJhsk0swCYhnzRlJUliqTVQhh6O+38amQ4igKCqzm26Mp+Urheq/shgrCIaWbQopCOU6WydZYgaGGZJq7nUdms3XXr93uqaubXf/3X+bVf+7U78l66rvOll19kcWmNM7NzjI8OMD4+yMTYMBcuLUh9dhBgBgHpdArf92m12oRRSKPVIghC6vUfoWma+IQgjfAkmqKTSVuUSgXsVIqWU8CPJC0ThiHnLy8QaRF7JkZwXI/9+x8nnc0RRRGXFuZ47NgBfvqnX0TXdy+pU7cPKrnrfWy3Mj+1fa/Arlepr8777vLcXjFYrzfJB5GGm829fxB223/v59nt8d222enz3wiJ6UXv9+0WG3e/tldc1/v9Pho+zEGPrv7ViEgFLZ5f/0MG3KVrt/qIn9G3Utt2YRHywua3+eHQlwn0LL0RmptDyP2cnrmZFf9Oourd9nEj+1b3P2y0ohe7ncsfdd+3Sify4SMyBp7n0mp1cD0Xzw+IwghN17BMA8NMYZo6jUadi
Aj/qsJW/nH9No7biB/TJBKiaxoREWEY0mxVcL02UShp0yiK8IKO6BaJcL0WzdZ6/LmAamMex22hXdGLRKH0sLnb1u+3lYgMDg5iGAYrK9sbW62srDA6em3J3Ve+8hV+5Vd+5er9Wq3G1NTtUfSCpGl+6Re+zCvfepXZ83M4qxXGxwexLIMgFGWz47jU6y0q7Q5RFJHLZmi1pY8HRJT79tFXHKFUHMPQTXTdAE3H0DUO7J+iXm/TbLbJ50voiA3pcNmn3XbRDZ1CLoup9VPfDGm1XUbKz1LITPL7vy8D7a2epG8ldopefBiCtFtE40bft1evstv+bkRsu1teubesUFU9qPfoFjjuVi1xvcjK7cJO30PdwsgnZYdYqQg7G5HNhuTyIaHeodZpoJk+hhlimD4v/vt/QXl95fpv1oVUdesjzVZtM4WPhtVFR56vfIPnK9/YcXsfne9/4a+yPrIPz7EIOilc1yTwLIwwhRZZBJ6O64Gpm1fcYq8lnLudhztphW41dtMd9N66t91JJ9F7/3pRiQ/6Pmqb7n32HosbITu7aTF6bzvtM4p2fr36XLtdw7uND93YSeDb/doPg/i1KdLpgOWVKuCTTluYhkmEh+O28H2DiI6Y8HWph+V89HHd5rb9ZtIpcvks1WoDPwiwrDSmkabeXMEPXETuKREXP/BotjZYXD6OHzjb9iML5RSe56MbOo8ePXDXrd9vKxFJpVI888wzfPvb3+Yv/+W/DEhE4Nvf/jZ/+2//7Wu2t20b277WpOh24uDMFAf2T7CwuEaj2SafE6LxjT9+jdnzc1Q2a1yeWyFlWwyWSywtb1BvxI26NE1HNzIYRu5K2aBJEIUEIVQqTfpKBXxPIwxzGKaBrmvYRAR+B9u2SNtpmk0PXdcolYpMjA1SyOevXmDX8/zoHhA+aDXePQnudKGqv7035RNyvVWRem/1/50iJTt95pudnG9E8NZNJLo//07HQZGF7maBu5UD7zZI9Q6eu233Qd9vtwml28BN2et3N+ZTjRGV8Vk2GxugZTLx/7u7EgM02gHvXljFtgwsc/tJVqm1cSvdoVqNE5/6Ip/83d+6cu/60IEnvvctDr3xfd74S3+VxsDNDXIf/53fpLSxdlOVM246TXPfBBndJ5P3gTa6BrZlkLJMhvukHN7xAh7bN0I+c/2ISBTFpbyOI+65jYZY/7dacutujNbbsG034q32fb3zYafzZqfX9F4P1xNe9m7ffZ3s9Fl2er16vJsY7EQ6bnQSv52kfCdy1GuB0Gt/0G162Pvc9Qjg7kQw5N2Tc7hejWIph+fVcUMfQ9fJ5gw2N6uYps7W+hr1apOQ4CqZiK5ERjRNu/J5NNLpMm3HZLMmkgdTN+i4HbZqyzsKVFfXT11DQgCiKKTjuJiBwb694/wnP/u5u+4ncttTM7/yK7/CX//rf51nn32W559/nn/4D/8hzWaTv/E3/sbtfusbhq7r1wh1Dh6YYmFxjdnz8/zr3/k6E+ND5HMZvvntd/C8PGCg6wYZu4TnttmqzqNrJinbxjLEyKHWCAiCDgP9fbQ6NVot/2qduGEYPPXkMfpKearVOmjQ31egWMjuOOF2T4Y7pRh6J/VeK+befV1vn9eb8He6EHvNiNTfbvfEnYyLdnI63Sm6ANubanU3seu2mt7NsGgnrcxu2G1F3P1472qt+3j0ukQqi/ruzsq9nT97n0+nryUSqgPzrUK3s2oxFze7I4KOK7lpXZP7IdAYGWPpwGHGzp2+of1rQLbT4sX/459z/Is/w+q+gzf0up/4l/8ruXrtpkhICJz8zJeuMbvQNY10SqrfgjDE9cIbdlXVtNiGH2DoIy4Y1bncTW6Uk2qjsZ3cqHLiMNxOcHa65ne6htX7fZgozk4TdS8p7j13d7u+d1pk7PTY9SIWCr1Oqb3jxW5kr/ex6y0qul93K+D5Ae1mmbRdptnw8X0fPxCfHNM0sMw8ERGPzOxncXkNx3EIw5CIkCj08YMOm
iYRjlw+QzaTwg9CLNPHMoEoJGvkabdreL5zpcomIopCgqANmouh69f0oJFrXaOYz/F//+/+c7LZNO+fvkg+l2FifOiukJLbTkR+/ud/nrW1Nf7H//F/ZHl5mSeffJKvf/3r1whY7zUoctJotjEMnVw2TbXaQouGKRUllFYsjFPMjwAykGvoaLp+pQuohWmZ6JqB49oMlIv4niiVPS+gr1TE90ucOrVFq+MQBRG64ZFONxnoL5HNZnYMU+406e802avBYqeVwE5M/trvf+379HoY9Nq4dw8M3Z0w1eDbPah2mz5dT7vRjRtdRfbe32lVp75j92qo28xJkYVUavtzO7ks9jbAU/brvURC/Sb3ArqdVWtN92q33nbHo+MFWKZOGEYEXQfy7S/8DN6fvsLUyePXKDQaZh9Zv4beoz0xo5Bj3/mPtEp/jUapzOiZk+x5+0cU65uEmsbW0Bg//sLPQCbD8KkT5Oo3J1D3TJN3Xv5LOxKdMIqu9p1pOT65tHXLXVVvFJoWn1OZW5SOV0RFeaE4TkxwFMlpNoXc7NSJtvf6g2uvwe7rcidCpLATgd8pitBrvtg7dvW2ceg1VdtpwbPbGKge6x1He6McvYuvbmLXnWbtXuzsNOb1kp56w2VxeR3XFbdiwzQIgwaO5+B0PLGDMAx8v49iwWDV8a58Lg3DSBFGbaIoIJ22SaUM6o0WmXQaXTchAt00CUOPgf49hIDv+xKt0sAyTfK5IcLQIYzkuVZ7Ez9YIp/L0lfKky9kef3HJ/naN3+A47jYdoqZ/VN86eUXb5vJ2W64I86qHxa3w0dkN4RhuC09o5jh3PwKf///86/QNI2T71/g3IVlDMPC9wMmxp7EThWJIv9qzbedKlIqDKET4YUiUEKLSJnmlf4CkMumeeTwXjY2qrQdR3QiV+q7680WGdvmyccPMThY+sALrfvx3olUXdDdEYPuNET3wKTMnbodIndqpd3r6LnbKqwbN3uG7TRZ7xRF6U5ZdDew26np1W7HTb1fb1O67qiEIhLqfe7CPHbbUG12WFirU22J7ikIQ5ptj4xtsF5tX2PzDoDr8vH//Tfpa1RpGQUWszOspSY4VXgWM2jz8/P/gFQUN+uIgPXxKfqWFrCia0+QCKgODJHdWCPV9fgfjP4y65k9HK69wccqX8OM4iqClm5y/It/mc09+3dldxriqprPpJgaLjE5VKSUS3+4A/UQQI0TKiKjiI36v+sKqVFEpzcltZOb7AdVrvU+t9uColuT1Y3eBdZOJKd7AdXbnLN7wda9/90it9eL6na/1jShXq/yb/7d12h1GhQLNmHo0+k0WFlbw3FaOG4LQw/55CceY229wuz5BUrFIq4f4LsBmuEShTprG1s4HR9NM9A06WlmWhZpO0M2m8ayLAI/ot7ooGk6lpUik87gOD6GLn3QTNPE81d46slJyv0FOo7Ln//gONNToxw5NE02Y9NqOywurVPuL/JLv/Dlj0xG7hkfkfsFZ2fnYsFqDzM8sH+CUjHP17/1qvSk8Tv4focIWFh6G8tME0YBYegThgGGbjE5fgzT1HFcB89z8X336sk5PNSHYWb57g9/SD5n8/ixfWi6BJjDMCKT17lwcY255XVeeOEL+L5+dWDoHSC6SUR3yqJX89Bb4QIfTB7Uc70pCLg2J9qdT1Xpg26ioJ7rJky9f7tf0xvN2aknRTe6BxWVDumOSnQTCZUuSiAo5dIUszbNjocfhLhewOziBrWmu+trMp0WegDnCk/RMAq8Xv4CdTvuZP3nQ/8JP7n6b6924NWAwcW563+OjbVt9xtajvP9nwA0Xk/v4Z3yT/LluX/EZOccAGYUsTl94LqsUEiqweE9g+wd6bsrkZD7CSqdalnSWPGjQo0tqkN299ilFjzdj/WOYd09d3YjOOp9eiMy3eNa9/i302e8HsHpfb6b8CiC022P3012Oo5FKX+IXDYgbaeIopBKsMTIwCSmZeB5AYahk89MUd5vYenzR
BFMTY4QBiGVrTqLy+v0F+WLXMmUXv0euhbSdlYJgjaaGZDN6DiOQy5rYlk+tfoGGiF22iKMYGqyzPTUOBBy/N2zBIHPwQNTFPJSQlzIZzk0M8WZ2Tm+/sevcmD/xB1L0zz0ROTs7By/+dtfpbJZY3xs8CozPHHyHAtLqzx+bIYfvn6CzS0JGatupIau4QctIiJMI4VlZcik+8hnh8hmiuSyRTw/QsxtUhTyecIwIp/PU+7vZ2Fxk1w2T6cxQBilAAswIdIYHYDKKvzbfxuiafqOF9lOF9ZOudbeiwh2Ttl0T/Y7hUe77/dip2jDTkRit3mgl0gYxvaoRC+R6BVeJvho0DSJGgBEUcTCeo3FDdEt2ZaO43X5eUTQ3MwxZx/hUuZxThefBj27bX9Nq3Dte1zv/eEaZw/PSNFdZuuYffy7ff8Dj1b+hE+u/z5m0CRTr9Iu9u2636xtkUun6M9nEhJyF9CtmbpVNQg3ErXpjtR0E5rex1R0d6eUS+977qRz8bzdo726Dp5nUMhP4Qf+ldJdKOUzkIcIDV3TCMI2hEWiyGJ0OEe1VqfVSpHNWrRbbbLpfoLQJwp9wsjF9300TdKOETrZTD8pW66/drtDvdFmYnSI0dEBFhbXqDdaWJaBbdsc2DtO4GdpNdsE3hRHDhwlm96H52qY1jKaFqBpGuOjA3e8U+9DTUTCMOSVb71KZbPGoZmpq4NVIZ9loFziuz94iz965fv4nkY200cUGZgmDPbvI22XyOcGyeWG0dHQdINsph8QXxHZk4muW2L3nkqhYxISEQUmg+UxNE3DCy0MLOgqU5TXBvj+tSv4nbQg3emYnZpVadr2RlCadq04dTcycT0CobATkTDN3YnErRZeJrh10DSNQtaGiKvXg66Jw3unkcF3TELH5weDP0PDEn1UN1JhnU+u/N7VaEgvTheeZ8mapG7mSWnwhZXflvft2S4deRysvc7Z4se6HjU4UX6Z84Wn+dTSv8Jwd4/a6BoMlTIEUdxxOMH9j9sRtVGkRKWnlbZNEZudBPIqyrJTo0P1t90GoxZimCa+H+L7AaaVuzJf6Oi6RhimQbcIAwPbsshnDexMinarQypVJmWJR0gQOmxVL7K4+jar6+8TBiFBGJJJpxgZGSCTtsllMwwNpTl4cJyNyjyG1aTtVAhDm+mpfeTyAZ3OJnOL60Say8TEILqmoekdun2FMtk0zmrljpqcPdRERKpi5hgfG0TTNKLIwPdGaTRaXLzQIpM6hFWW9uloIjK1UyWKhVHJ05k2XDmpABynDmgEoU4U+uh6Fl030WwLIwgI0fB8B123aLUb+L7DwEAO09LRtBAI0PUQz3VxfZenH3mMQiG/LRS5U0lsFH04AqGwG5FQKY6diMS9JLxMcGuRTVsYhn41wua1U7QaNmGgUVvL06rkIGWynWuE5PxNfubSP6DfX9txvyE6Pxj+MlupOPf80tq/Jxe2rtk2EzT4yeV/jR12ONH3CSRiKGhZA7yy57+lb3Gdcr6Jlb7WgTKdMnG8iIHijVXJJHg40S0iTqfhVkgR43SUxf/2L97g7LkV9kxMsL7ZZu7yKpl0nkw2j9MJJEI+MIqGhuu4oPsce+Qop89Uubw4j6lbWKYl84hh0XG2UJ3dgxDSaZtHjx5geLCf5dUNHj82w3/zX/4sS8sbNJpt1tY3eev4Gc5dmOfi3DJ2ymL//n7QOtjpFcxU/ZrP3251sFPWHTU5e6iJSKPZxnFcspkrccPIRMNjcXmdZrtJGHVw3Bau28GyMvQVp3A9h82tSwSBRxj6V5dy8qPpOE5AJp3BNFL4IeQyWQYHBnB8U9i0H1EoDuL7mzRaHYaGxjEM5SIp7nj1VpvBgX4KV0p5u1MiO5GL3sqNnVIcvUTiQRReJrg1KGZSZFImjWaE08gR+BFuw6S2XsBtCflGp6tVjUe/s8LPXP5/U/I3tu2rO+Wylp6ibm2vg50tPMcT1T/b8XNkgxovrf470kGTE30v0
TFK8ZOGxdbWGLW3PIb2r1AcaqDrchWZOoRhhGFod61KJsHDC11Xpd86f/FLT/Cbv/1Vzl9+i7GRMpp1gYsLGxiGQTabZnLPYdLZNaIo4uIVV+2/8nPP8R+/sc5v/c4/Iwylc7tlSurU8x00TbtakptJ2+SzaVZWKwyUS3zxcy9imubVlMojh/fy0otPbCvEGBsd4H/9//4eJ06e25YJAEnNLi5v8NixO2ty9lATkXwug22naLUdCvksmu7Q7Fzg8sI7EMHy6hquF1DuO0A2M4CdypCyCoRRiGmYVwY7g5SVIp2xicJIWjdfLcvysUwDTbMA7YqQKE3KMjBNjUzapNmqo5PFTGn4nsNWvUYum+LpJ0cZHdWvekvsRCQS4WWC2wHLSOHU+qhX27huRGsrS209B2HPcGEAgcNI+zJ/+fL/QjZs7rg/haX0fkJt+z4u545sIyIhQlzU0JgJGjy3/kfYQZMTfZ+gmprY9vrQs1g5PUFna5P+8SrZoqzmTF1nz3ApqZJJcFfR696dSqVA0zAMnUMH9lAq5ajXmywub1AuF/ni515E13W+8NkXGBnuZ35xDT3wr5b16rpOyjJptR0MQ6eQz+B6Po8dO8AXP7dz2e1OPlm97U0y2TTtVueaz3Gn8FATkYnxIWb2T21jhq7n4fuBmMf4AYXcKP2lYRzXw3E75LIDBEEougoDTMNgaKhEvdHADx2yGZN6rYkWhQyUc1QqaywsNwgCB9s26O8fZGHlPUp9af7iS09w6uwFzl2Yw3E97JTFwQNTV06ogbt9eBI8ZAgCuHABqg0P34d2PUNtI4vTSNGrBQFAi8jnW3zM+yqpsLkt+hEBjm3j2FmKtU0A1tN7rlGObFrbB8hasR8vX6C8ePlqN5h02OLxzT/BCDocH3uZqjHW+0GorpRp14ocOFajVHKx7JC+/N3tn5EgAVzr3t2dLjl/cRE7ZV1DJFKpFH/z//yz/L9+41/QajsEYYRpGoRhiOf5mKbBsUf289/9Nz/PsSP7btqIbKf2Jjt9jjuFh5qI7NT4ztB1oiiiWmuRTtvUm8sYZp0wCNkMLAw9h+M5hKGHho9la2SLMzzzzGGefPwQQ4P9rK1v8dbx05y7cAnNqtLZ2gRNoy/TRxDpHH1k6uqP/fJnn9nRvyRBgjuFKILFRbEvj6KI+SWHrfUU1ZUsbmeXc1EP6RtsMv7IKqfDL3HphRfJVdbpX5oHXaM+MMTyvsMcePs17OOvYzsd2maeXllqx+qPPwewfPhRzj/zIpnqJiPnz1CsrBHoBtXRCZw9BxkKDYJLHRpb10Y63LbJqR+X2VxtcvBQSOZwog1JcG+gOyqxU7pkp3H/F/7aTwHwj//Zv2NtbRPHEXF2Jm3z0see4Cv//S9+JMKwU3uTuzX/JIZmXOsjcuHSEr4fMDUxxBtvnSKKIJ1JYWga9UYbXdcYHRlg//Q4XhDw3/7yz/Hc049s+wG7DdKUBqXVdhKykeCewsYGrF3RlrouVOser78ZUN802a3oVjNCRiYaPPpsnbbr0XF9HM/HC64dSvKVdWbe+B4Dl87z1fH/G/PZI3RHV6ywzd8+/TeJgMroBO99+ks0yoNXn9c1SFkGtik9nAzdZHkuzdZyifpmiijYufyqVIJPf9pgehr6+3fcJEGC+wKu6/LKt37IqTMXKRXzfP4nn2fv9Pg9P4fczPydEJEr6CYOa+tb/Mevf4/KZo1Lc0ssLm9AFOF6ovnYt3eCx47uZ6NS47FjB/hb/9XP3fMnRYIE3Wg0YH5e/h9FQkIWFuD9UwGNa4X0V2FYIcWRCtNHNklZxhU/nZAgjMTwL4jo9cfLV9YZOXGSHzpfomJO0k1E9Mjlvz39X7Gx5wBnXvz0VRKSSUkzPsOQTta+H9J0PExDR48MwmaZufMZqjUNv53iWhhYFhw9CkeOwIEDoqlKkCDBnUHirPohsD10BmMjA7zyrVfxfJ9avYXruoyNDnBoZo904V2p3BVRT4IEH
wWOAxcvbu+/0W7Dm2/C+ro8vyM0SNk+g1Nb9E2skzItDF3DC0I0pMutpkHkBRCEmLqOF4REQLM8yPtPfpH6e4PgbY+yhJh87z/9Zdp9/VdLuAppCzSNlGWgoRFFER3PJ4quNADLmgwM+9gph0uzNvWGS7OqyIiOiuR4HrzzDszNiT/E8DCMj9/a45kgQYKPjoSI7ILu/Nl7py7w+pvvsba+heN6bFYbd03UkyDBh0EQwPnzsbGdrsv/z54VgerWliInOmyLaUTouoad8xjes8ngRI0oMgkjsHQNWzNw/ABT1+nvT9NoueQyKUb6cpxf3mJtSyppnHaagJ3SPTrN0gC6FtvBB1GEqWs47pXme1GI54foV8hJMZtCN2FsSnLm85dSmIZLtZK6Zv9RBJUKfOMbcPiwREf27IFslgQJEtwjSIjIdaCiJFOTI3z+J5+/J0Q9CRLcDKJIUi6NhtzXNJmEV1fhxz+GWi1+7soW8f+MEA2dQsljcM86peEmlmlQyKaot9yrRME0NNquj67r5LM24wMFFjfqaGhkUhZBFOI5JlG4s55DC0xSqYCUIaQmCCMsU8M2dTpegOsJe8pnUgwUM9gpGbZ0HcanXNBg/rJBLpNifU3STL3odODECUlHffzjYly1f/d+eQkSJLiDSIjIDWKnWuwECe5lrK/LTWFoSISpb7whE3KtJg6Q18IgZQeEoU6+5LDn8BZ6tkHWTlHK2dgpk5RpUGu5OJ5PGIb4YUQhm+LAeD+L6w0cN6C/YBNGIW3HRw9NdiwBBrTIImvrBFFEIZUCIjpuSNY2KaZMLEOn5Xj05dJY1nYyo+kwNNbGcS1SYQY7ZbC5KRGeXgQBbG7Ct78NMzPy3YeHoVz+kAc4QYIEtwQJEUmQ4AFDvS5REIWhISEd586JZqLRgGp159cqV0hdNxifinjxEybpQoFzix4527pKBOyUyZBl4vqBtC0PQg5PDoCmUW11yKZNdF2nmLVxvRDH0dml/QyhaxNETUxDp1zIYJo61YbD/rF++gsZsrbJqcsbVOptiqZ+jRNk2/GZmbEwHJN8VhyDs1kpSb7mvUJoteDkSXn+hRckOpSIWRMkuHtIiEiCBA8IHEf0HgrFokzIy8uSllhakkm409n59bYtqRzDgMlJ+MIXNPr7U0SRxVa9cy0R0MAyddqOz0AhSy6Totp0rjgOS/TDTpkU7ByEu8/yjU2b0UmHYlaiLZ4fkLIM+guZq12BJ4YKNB2XWtMlmzYxDJ0gCGl1fOyUweRwgWJWPlcqJaRr3z64fDnWxXTD9yU68qd/KtsBFAowMXHttgkSJLi9SIhIggT3OYJAJl7VCNGyYGpKSMnly/Duu1IZEwtSr0UuJ1Um6TRMT8PLL8vEDGIt/UFEQPV0MQ3pKhoEIbop0RO3baNHpjSW3OG9Qy/DYMm/0ngyotXxKRe2N6sr5dIcmhxgYa1OtdUhdHx0XaNcyDAxVLhq5X7kCJw+Hf/dvz+umtnpuLVacOaMkLXnn5do0tTUrensmiBBghtDQkQSJLhPEUWi9WheafGiabK6r1SEmLz7rmhC2u14m15omhAO14V8Xl7/2c9Cpscd/UaJQC5tUcqmJXqSk+hJbcvY1i26F52WQQT4fnANsen9DMWsTbPj4QchpqGTS1vXbHf4sJCLo0fh/feFWK2s7KwbUR4q1Sp873tCQkCiQgcOJGLWBAnuBBIikiDBfYi1NXFFVZiakknz/Hl5/ORJiXBsbu6cmgAhHqYpE3GpJBPvpz4lqY2dcCNEYKfoSaejEwSi54ircgxU+17X0Wm03B2JTS80TbuarrkeDh0SMvboo3IsRkcl9bSysvP2nheXOK+uwrPPCplJxKwJEtx+JEQkQYL7CL1C1JER6OuTNEyrJVqQzU3Ri9Rqu6dihoclXeF5MDAgJOQTnxBicj3cCBHojZ602yFBiKRsQuBqOzsDiNCBY3uHd41wfFgcOCDHRZGR/n5JPS0s7FwtFIZyT
HwffvhDIS+PPSbEZP/+3QlaggQJPhoSIpIgwX2ATkccURVKJZkoq1XRQqytSRoiDOWx3RxSdV1ISLMpJGVkREpZn39e0hG3Cip60mh7rJ7T2DJ1AkMn8HtJhobnGZRyBreIf2yDEqweOyYRDsOQVM3ioqSsehFFQs7qdTmGm5vw5JMSKcnnRcSbIEGCW4uEiCRIcA/D92USVBqLVAr27pX7Z87ENuaNhqRYtrZ2T8VkMlJJo0jI6KjoKZ55httCAjRNQwtTRAFo0e56CxWJ6NWl3Crs2SNamsOH5VgqMe/6uuhpdoKKmGxswOuvw+AgPPEEnDqViFkTJLjVSIhIggT3IKJIqj1aLbmvaZIesCwpw61WZVWvSMrWlkzmu6ViBq80tFVRgPFxEXM+/vjtISEKGxvyntd7jyi6vUQEJJKxtCTHcGFBUlC6LlGO+Xl2FNMqMlKryedrNCTNA/LamZlEzJogwa1AQkQSJLjHsLq6faWueqO02yLAdN24JLfTkYnS83bel6qkUZESXY+1D0eP3v7vsrYmKY4w3J2MKCJyuzE2FhOHtbU4FaWErbsdQ8+Tz762Jrb4/f0SHTlzRsziBgZu/2dPkOBBRkJEEiS4R1CrbXcDHRmRSS+K4olybk5W8L4vK/RWazebdokw7N0rHhlRJNGUoSGpCNm//458paufLwx3j9ZEkWgyxsZu/+cZGYnJSColBGNrSyIdp0/HEaheuK4QFxV56nQk1QNCUBIxa4IEHx4JEUmQ4C6jV4ja1ydRC5CJb3lZtnn3XZkQXVdISLu9++Te3y8T+9KSbJPJyMr9Yx+7s+6hKspwPR+RKNrZ4+N2YWgoJiOWJWRkc1MiROfOyf93QhDI9/A8+U0cR47zY49JiiyXkxTQ7Ux1JUjwICIhIgkS3CX4vkx8ikzYtlR06Lo8Nzsrz126JOkaxxFC0mzuXhUDMhmmUnF0pVgUL4wXX5SKmTsFRT407fpEBO4sEQEhZbouviIjV3pZVqsSQcpmhcDt9JlVGimdlvSZio4cOCDPnz4txz+fv2NfJUGC+x4JEUmQ4A4jiqSkVAlHu4WoEItRWy0pyXUciYIoErJbKiadlv3UavHEXi5LBOCFF2T1fifRaMhnvZGy4DtNRECOh67L8R4fl9+h2RRhb6EgUQ7X3fm1nY4IXn1fXu84cqyPHpXUmaaJmPVWlkQnSPCgIiEiCRLcQewmRAUhHpcvx5qQrS0hK54nk/r1qmL6+mTiUxUgUSQEZHRUNCHF4u3+Ztdic1MmctOUCXm3smKQ73c3UCoJGVlYkJTVyop8llRKBKknT+5uj6+0L9lsXB3Ubos4eHAQzp6Vv6piKUGCBDsjISIJEtwBVKuyclYYHRXyADKZXbgghKPZlJRMqyX3lUHZbitzEF+LsTFZwadS8rqhIXn86adjonOnsb4eR310/frpmd0m+zuBQkHSKfPz8rtsbEhUqdkUo7fjx3f3GwlDIS62LURLGaUNDMAjj8gxWF9PxKwJElwPCRFJkOA2ot0WjYdCtxAVJGqwsiIT2uysTGpqUlb/320CN03RJhQK8tp0OiYh+/bJit62b9tX+0C0WnEUR9ev77nheXKzrN23uZ3I5yU6dfmyRDAMQ27LyxJRevddiWbtFtVxHHmuUBDS4jjy/Q8flmjU+fNCCKemEjFrggS9SIhIggS3Ab1C1HRaJjo1GXuePA+y+r58WUpYPU+2qVR2tiBXKBaFbESRTHLpKz3iBgeFnDz++N2b1BVcN640+SAEgZCWu/mZs1kRC1+6JPqRTEZIw9ycHM9z5+R32i065fvyG9p2rB3pdCQ6cviwEJNEzJogwbVIiEiCBLcQqspFGXRpmhCD7mZyi4tCPsJQJrd6XSYpw4ib1e0mSAWJqExPi4ZkY0NC/pYlE+mRI9JX5W6LJNXn1zQhV7tpWxSUzXuhcPs/2/WQyQjBu3BByJ3qanz5sqRXsln5za6nG/F9+U2KRfl9XFd+3
6NH5fWJmDVBgu1IiEiCBLcIKyvbPSimp7fblisxKsh2y8ui/1AT1+qqbLPbpK1pojvo65PVdrstk2Q2K6vwo0fldi/YjjcasYma+v/1UhKKiNwLsG0hHefPy2feu1ceX1yUtFc+L5GNSmX338p15bctFOT3XV6W32twUJxcDUPErAMDss8ECR5mJEQkQYKPiF4h6tiYVGMohKFMaqrK4uJFiWa0WkJAokgmOdfdfWLL5cQ4yzTl9cprpK9PVu6PPCK3e0V/sLkp39W24+/U/dl6xatRdP1U1J1GKiWRrHPn5DjPzMRmZ9msNAo8cULI5266Ec+TcyOVEkKysRELkh9/XB7f2JDbvn13V8+TIMHdREJEEiT4kOgVovb3x+ZYCkqMCjLhrK3FfV+yWblfq8lEfL2GdY8+Kmmbs2eFlNRq8n4qHXPkyG35ih8am5vbxaea9sERkeuZtN0NWJYQkNlZidYcOiT9ZTodiYY8+6yU9+6kG9E0+T2DIHbA7euT76iccstlePJJee2FC4mYNcHDi4SIJEhwk/A8iXAo4pDJiBC1ewLpFqP6vugCKhWJgmSz8vzCgkxI1xNzTk/Lvlstec9SSTQl5bKsso8evXN9Y24GqiGfKlm9kcn1bnmJXA+mCQcPCgFsNoXwzc7K95uflyhVoQCnTm2P6KhUlDpH1O9s25LaUdqRV1+NhcVKzDoxcfe1MgkS3EkkRCRBghtEGMrq93pCVBCCUa/L/ysViXpsbsqkVCyKFkQJUneLgti2REFsWybouTkhIZ2OREKUT8Xk5O37vh8FyvpcHRtdv76hGUga416EYcRkpF6XKMmlS/K7nD8v91MpiZZUqzGx7P1tXTdOzw0OCrFZWYHXXpPf8/nnZf8LC4mYNcHDhYSIJEhwA1he3m5DvndvXDKr0C1G9TzRfWxsyOOFgkxACwsygV0vClIqiSV7qyXvub4uBEb5VIyOSjlotx/JvYTuNIVKtyhNi0JvhCSKhJzdqzCMODVTr0uUSumCTp+W51IpiZbs5DeinGXDUAiIakSYzQpZ9X340z+NtSNRJMSnXL6z/YESJLgbSIhIggTXgep+qzA+fq1dercYFWRVrIhLFMlkoqIiyi11N4yMwCc+IZPZ2pqE+2077hK7Z49MegMDt/qb3jrU63FqQnXf7a3k2SlVowjavVD1sxN0XQjg6dPyWUdGJOKj6/KY8m85e1bSNt2ELAji3jQgJNP35bcdGREStrYm0ZFyGT75SSGxlYrcEjFrggcZCRFJkGAHdEc3YPeVaaUipAFk4llelgmk2RRxoqZJFESRkl7tgIJpSij+yBEJ18/NyWNRFEdT9u0TEtJdkXMvQqVYUqnYb0N9l92gqmYcZ3vJ870GTYvJSKsl54WKjF28KGLTxx4THcjZs9t1I74vpEXThJioVA3IbxxFcXTkm9+U/ZTLQnqUr8n0dCJmTfDg4R5deyRIcHfgeTLJKBKSzcrE00tCPE8EioqE1OsSFVlakglmaEgm4fPnP5iE5HKSipmclP1dvCirX9XRNQxFo/DII/c+CYG4Csi24+iPYXywqZnn3TteIteDIiMQR6xmZoRMKH+Xgwelz0+xuD3CE4ZyS6VkP2Eo545qcDg6KsdhfR1+9CO5qRRcpyPn5r2cwkqQ4MMgiYgkSIBMCJcuxZqG3YSosF2MGgRxRUyzKSkTw5DHlEC1uzS3dzIeHISXXpKJp14XElIqyb6GhuTzHDkiE9u9HCnohko7KK0D3Fi6xffvDyICcn4cOSLEwHXlN3/sMelJs7kpv9vEhJCU996T31fpRqJIXqMaFEZR3OQQ5Pd33bjM+z/+RxEuT05K+mZxUQhPImZN8KAgiYgkeOixtCQiREVC9u6VFW8vCWk2JQqiSEi9LuH35WWZRMfHZR+zs5KeUavfnULpmibb//RPyyTV6UgUpq9PJqWhIZmYjh2TCe9+ISEg38V1hXwYhnxXy9pOwnY6JveSu+qN4vDhO
NWytSXeIiqaNTcnv+Mzz0jKprf7riIwSvuhDNAqFdnf8LBsU6nAG2+ImHViQo6rErOqiFyCBPczkohIgocW3WZjsLMQFeKeMGpFaxiSclFRkKEhmWjn5uJJBOKJt1ecmk7Lavb554UANZuy0i0WZSIeHJTXHD0qmpCdojL3KlS3XVWmqtJRHyRWjSLZfreGcvcyDh8W8un7cj698IKITpWIed8+EbGm0xJ169WNhKEQTceJUzUqOtLXJ9ursuCvfjU+L9T5VqnsXMWVIMH9gvtoiEuQ4NagV4h6vX4f3WJUkAlBNa0zTVnptloyEalISW86pht9feKmOToqJGRzU4hIJiOTUqkUh/1VT5L7CapiBmKflJ2IyE5QZOR+xMyMCEodR9Jyn/gEvP66PHfunJjOHTkiv+/Jk3LOKMIahkI2stmYyHU6cQO9XE7Oz60t2f+PfywE5wtfkMcaDUnpJWLWBPcrktRMgocGuwlRdyIhrrtdjGrbspo9f17SLuWy3BYXhYSoShGVitmJhAwPy+SRTstqVglbTVNu6bTcDh2SSet+IyEQCylV6a6qClHpBIXuybL7//dSv5mbxb59ck6BEIOPfUwErOm0nDeGITqP55+Xc6c30tVqyWOZjBwT3xcivLkpx2VgQNI7tZqcg3/wB5IWVM66iZg1wf2KhIgkeOChQuTnzslkqOsi/uy1ZVdYWJDtQVIu1Sq89VbsJ7Jvn7zuwoW4WV03+ehd1VuWvNfP/ZyQj1otLs8NAin1BHFM3b8/1h3cj1A27ZYlE2kQxCWrN4JW6/Z9tjuBPXvi3/PcOdGMqMaEly/LuTEwIOmbsbGddSOuKwRGRZEaDSEeW1uyn8HBWMx6/Dj84R9KWlFVdi0uCiHpNpBLkOBeRkJEEjywUF1tz5yJtQfKi2OnaEOvGDWTEUGgKsEdGpKV7MqKbFepxBOsEiz2pmRU19yf+Rl45x2JEszPy8o5CGR/nY5MStPT9zcJgTiiYVky6QZBTNJ2S7t0Hy9V+ns/Y3Iy1hrNzsJTTwl5SKXkt+904pLt6WkhF91NAX1fyEc+L8dR1+X83diITfEGBuL+NJUK/NEfybmqzu0okvfu1kAlSHCvItGIJHgg0avtuF4jsV4xaqEgq9eTJ2VCsG2JVDiOREFUi3uVblCiTNg+iQ4MiJfEzAy8+aZMICsrMQkZGJCJd+9e0YxMT9+WQ3HHoEzJwjAmeuo4+f7u6Rj1/ygSEuj7cdfe+xXj43IMNjeFIDz2mEQpFhflHCiXhaw8/7ycD6rDrzqfwlDOjUxGbsp1dmtLSEm5LEQlk5HtPE/KhOfm4HOfi8XTm5tyS8SsCe5l3LaIyMWLF/mlX/ol9u3bRyaT4cCBA/xP/9P/hHs/yuIT3DdQUQ1FQgYHRW+xGwnZ2JCIiVq5p9PiBXHxoqRkBgeFMFQqst+Njfi1asLojYLouhCLL35R/r77rkzQa2uyfxWer9XEq2Rq6v4nIRCTELW6VxoIy9pubb9TxYx6rNm8/0p4d8PISGzFr6IVyptma0vOJd8Xj5CnnxYhq2FsF/aqLsZ9ffFzrZac35WK3FfRkWZT9vu1r0nK5vDh+Ly/eFFI9P0ebUrwYOK2RUROnTpFGIb803/6T5mZmeHEiRP88i//Ms1mk9/4jd+4XW+b4CGF624faHM5CZHvluZw3VgHAkI4Tp8WHYiqYjlwQLa7eFEmDc+TyUBNtkr/0J1yyGRif5Bz52SyqNdl8lAGX+WykJzDh2Xb3Sp27jeoihlFLNSxVz1WrlcRo7ZVXXt3I473G4aG5BxZW5Oox8yMkNH33pNzwnUlLTc2Jt/57bfFV0ZFkJQVfK0mREXZ4LuuRDpcV7QhhULcqdn3hTQvLsJnPyvPnzsnrzt9Wt7rfnDoTfDwQIuiO8eR//7f//v8k3/yTzjfPQNcB7VajVKpRLVapbiTwUOChx5hKARE+S4YhqRRdqs4iaK4Ay7EEYrTp4UcdDoSxUinh
ZCoyAjEpENFQXon1r4+IRef/rT4SHieEBG1nWlKOL1eF7v2iQkRqD4oOHs2dqdNpeR7Li3Jir3RkFW8InOp1PYKGRU9yedFT7Nnz935DrcL3Z41Bw7IsXnrrfg8mpqS5zxPyIiqqFJkBOT/uZycm93nVTot515fX5ze8jw5xtmsRNs+/nGJlqjPcD3n4AQJbgVuZv6+o6dhtVqlXC7v+rzjODjK3hL5IgkS7AQlRFXCUvjgDqXNpuTNFUZH4cQJiYK0WjLIT03JRHrpkkycrithb9XCXaVweknI8LA4aD7+OHz/+/L82lpMiLLZePI9elQm2gdl1a+giEW3cFelaLp1NCpa0q0NUbifbN5vBv39QmSXliQ6sW+fiFVff12Ow+JiXEXz0kvSY2ZuLnaoVZEmFfEYGpIonfIcWV+XbUdGRHviunK+12qxNuUzn5H0kOoUPTsr5EX1skmQ4G7hjlXNzM7O8o//8T/mv/6v/+tdt/n1X/91SqXS1duUWiYkSNCFSkUGV0VCJidFB7IbCQlD0YEoEqImhVdflSqGZlMG45ERmUxPn45XpIqERFGcilErVJCJY3wc/sJfkJJgRUIqle2rfNOU/R0+LJPQg0ZCQI6dMi9TKQVlYd5NRGA7EVH34drtHiSUSnKugETxTFO8RpSIdG0tPm6f/KTY+xcKsp2yywchHpWKnMfZbCwG3twUkbXrSopQ6UqUdfw3vwnf+55EQlTEaWtL0jj3s39LgvsfN01EfvVXfxVN0657O3Xq1LbXLCws8MUvfpG/8lf+Cr/8y7+8676/8pWvUK1Wr97mupevCR56NBo7C1GVb8NOUGJUpesYGxOicfy4rCJtW8iBZQkpUU3L1EpeeWGoCEi3N0OhIGmgn/952f7HP5Zt1et1PS7jjCIhKgcOxKZXDxLUytww5BYEcdv73j4z3WLMbjLSW4H0IKJYFOIMkvYDqZwpFOKuu54nJPvpp+XWbX6mqolU2k/XYx2KpgmpXliQaySblffLZOJoyrlz8O//vWx35Eh8fiqzvgf52Ce4d3HTGpG1tTU2uksHdsD+/ftJXXHqWVxc5NOf/jQf+9jH+Bf/4l+g34jX8xUkGpEEcK2wNJ8XfcX1/DZ6XzM2JkTj0qW43HFkRCYA35dS3WpVHrft2J7c9+OJtTsdo0jQJz8p5GhtTd5TCTZVft5xZMW7Z4+QkPu9LHU3NBqii8lkJM3V6cQr/P37hfxtbMh925aJVZWrQnyMUyn4qZ+SSpIHGd1tBqam5Bw5fjyOpPX3y3EqFiU6p1KISsTaTYjTaSErqsOvOmdzOblOVDRORRA1Tc7NsTHRM/m+EBSF0VGJpiRI8FFwWzUiQ0NDDN2gzH9hYYHPfOYzPPPMM/zzf/7Pb4qEJEgQBBLCVoOuaUpa43rW51EUp1sgDlG/9ZasEh1HBvd9+yQcvbws2yvxaioVk44wjAdxBcOQFehzz0nPmB/8QCZUNflGkQzySuyqcvAHD96flu03CkXA1HdUKQHLkmO4U1qm241W6UlU35UHHUpEeumSpAwnJ8X47J13hMBtbMSVLQMD4tB6/Licq563XXfTbgsJKZfl//V63Dzv0iUhNXv2SNSk3ZbzW/Wn+Xf/Dj71KSHVSlC7vCx/EzFrgjuF23aaLSws8OlPf5rp6Wl+4zd+g7W1tavPjSbqqATXQW9lC8iqutcOuxeNhgzUCtPTkpZ57z2Jgvh+vNoLgrhHTKcjq0+VGlCrTk3bTkKUvfbnPy/RlO98R17T7X1RKMjnrFZlxVkui0DwQefgarWtjpdtxxEO09xOOrrdaLuh0lmtVmyE9iAjkxGjsYsX5bwdH4cnnoijH5ub20m4aqQ3N7ddjxNFQrDX1yVaODoqr48iOZa+L9tPTcnv0enEv1etBt/+tpzPn/60XBvnzwvZmZ0VMjQ2dneOT4KHB7eNiHzrW99idnaW2dlZJlVS9AruYMVwgvsMGxuyIlSYnLy+BgRk0jp7N
p7oBgZk4H71VRmcHSceUFsteY/z5+OVYzodh7RVFERpQxT6++WzvPyybPPnfx7n3dXkq8SBlUosSL3fLdtvFKrYrbvJnXJY7dV9KIKxU3NATZPfxXFkon7QkU4LyT5/XipbwlDSUpmMRAOVo2p/v5CJT31KynvPnImfs205XkqU6nlCOlTVl+pLozxLRkbkOHc68luoc/h3f1dSjQcOxKmjalVu09MPx++R4O7gtq05fvEXf5Eoina8JUjQCyVEVSRkaOiDhaggROPMmTi0v3+/rC5/8AMZiH1fBt+JCdn+4kV4/30ZmJWfhSIhIPddd7t3w/CwrFR/+qdlBfnGG3FIXJGQgYGYhBw8KOmfI0ceDhICMql1f1eV2lLaj2732e6Kme4SXvV4o/FglvDuhlRKJn+IIyEHDojXjIpqbG3J8xsbkhZ8+mk5LxXRS6XiY6k8W/r75abSK82mXBtnzsjjpZKQGBVZqdUkyvfNb8rjR47E6aFLl+KmkQkS3GokGcAEdxWqf4tCoSAh6g+awHvFqBMTMnl997tCBjxPyMDYmBCGWk0G4FotLm9UglTlDQLb9QmplKRinn9e/EGUW2UYyj7a7VhYqKpllJ38vn237hjd61C6DmXlDrH/hSIiEEdAdkq5dOtGmk3Z34Nk9vZBsCxxXVWN6sJQzmnbFm2I0jyVSkJWjh4VncmJE3JOdvfnCQIhL0Eg18DERHzeKo3I++9LhG9wMI6O2La8j4qOfPzjss3QkJAQzxPR8cjIw/XbJLj9SIhIgruCIBAioSYpy5J8+QcJOnvFqEr9/847MkCrdMvIiAyyQSB6k5UVGZwNI3ZTDYK4isN1ry3N7e+HL3xB9v/DH8p7+r5s2+nIfkolIVONhqxg+/ri8syHBY1GTPw8L/bCUIRDHVe1mla6BkVMekt4Pe/hiogomKZE086ejSuOhoclAvL223L+6bpECZeWYs2Haco5ro6ZKpfutoHfu1euj3Y7Lg++cEHO16NH5fpwHNmHaQph/+53JdL32c9KinFrKxayKjHrg1oFluDOIiEiCe4oeoWomibRgw8SosK1YtR9+yT68ad/Ggv7SiUR66lKlvfek0G33Rbi0G22pWlCQjqda0tz9+wRJ0rblnC1ErF2OjJgFwpxeW67LZGQgQEhQA8blPBRCX6VwVZ3RKRXI9LdHA/i9ADEeoeHEYYRkxFl4z46GlfNqGOdzUr6pVSSyMVbb0n6RAl9VRm1qp7xfTmvXVdIjhJZB4H430xPyzmtoiPZbNwL5/d+D158UYhPqSQExnUlSlIsxiZtCRJ8WCREJMEdw/q63BSmpsTr4IOgKly6xailkgzMq6vx4DwyImFk35cV2+KiDLZhKKtIFc0IgriktNWK38c0ZbB+9FFJxziOrApVNU2jIa9V4WwlBDx0SFauqtPqw4ZGI05xgRwbzxNiYZrb7d4VUenWiyhSAnGk6mGGYcg5deaMRCHCUCb7Z56Rc75alceyWYlcWJbYwheLcp0oUp7JxF2PVcfeUklI9uKiHHPHkeN99qwQnkceiatmFFGv1cSRtb8fPvc50WG127EnT60m+3wQjfoS3BkkRCTBbUe9LlEQheFhKWu9EXSTF02TPPrSEvzZn8mAG4axFsR1ZQA9flwGx3Y7Dl2rKIiqkulddWcyMtC++KJYa6+uSv5dTZyVivx/dDQmNGEon+dh72baG71QpbtqIlOC3utZvCs8DO6qNwJdl3TI6dNyLodh7DXy7rsiWlVkBOR8fe45IfanTgkRVy0JlBdOux2nbPbtkzSLItetllyj1apEZFQJe6cj15fSVv3e7wlJ37tXooDLy0KWLl+W99m//8Evu05w65EQkQS3Db1CVEUYbqSSpPe1ExOy8nvzTRl0Gw0Z8FQUxHGELFy8KFEQz4sHac+LV3gqZN1bmlsuS2nuyMj2vjSGEYey9+yJxXyplEwMNxrVeZChKmYUIbHtOCKSTm+POoEc0+4qGbi24
uZhJyIgx0SRkUZDIhDT07HXiOpmbFlyTOfm4LHH5Ji/955EPVREpDvVoq6H4WFJxywvx1GTIJDXTk2JduT06bhjcjothOTVV8WJ+OWXhZgPDkqaxvfl2knErAluFgkRSXDLEQQyMCndhWXJCuxGVkpRJAOqmryyWZnwL12Kw84gIruRERn8fF8Gz2pViIJlCTkIQyEoYRgL+Lq79RqGpFP27ZPQdi4nNuXKJdQ0hYToumyzuRk7sw4NyapQNSx7WKHMslRjO4h1IaYpxE89vlP5bvfj3XDda4WsDyO6yUi7LeR83z4hHO+/L2RjeVkIgWEIEZ+ZiR1tFxfjMvNsVv6vdCNRJNfRvn0x8VbE5fx5IfZPPimER9eFcPb3y3XmuvD7vy/pogMH5DNWqxKtTMSsCW4WCRFJcMvQW9FyM0JUuDaFs2+fTGg/+pEQgmYz1nGoKEijIflt5T2Rzcrg5/tCZjRNyEKns90lNZUSEvL44yIE1DTRg6hVfRTJe9q2rA4rFRnEh4Zk8L4Rp9eHAcoMq69Pfh9FHHxfjk86vb0aSaE7EqKEqt2CYVXdkRzjmIycOSPn/OyskI2jR+Vcv3RpOxm5dEk0Jfm8RE4uX477K2WzsU5K6accR661tbVYk9JqyTXw2mtCuB95RFI+rivXYC4nxPz114UQff7zkp4sFoUMOY4sRgqF2MMnQYLdkBCRBLcEqj+Gws2kLHrFqIODQhLOnpVBtVKRAVb1bVFmWWfOyHP1ugzWhYK8Xk1ipikTWau1XcdQLMog/dJLkudutyXcrKIgzabccjl5v0pF7k9MyOtmZpIeHAq1mvzNZGLTLVUKnUrJbaeIR7dviEJ3BET17kmIiKA7MuL78vfwYdFzpNNyf3lZiHIqJZGQgQF44QU5Z8+dE62V0n6o49rpyGMqVZPLyWJA6UPCUPZdrYqJ2rlzQj4bDXmvSkVIx3/4D/L8zEzcx+nSJbk2T51KUpgJro9kOE3wkdAbxbjZ/HA3gVFiVEUMNjaEAKjoxdCQDI6djmhFWi25ZTJxEzAVeras2C5crbR1XbQgQ0PSV2N4WAbn48dlEsznJaSsQtD9/bLqq9dlVZjJPBx9Y24GSgCpKmJULxNNi8Py3VUxsL1qphe97qpJ0+3tOHxYSHs3GZmaksjdiRNyPQ0NybHf2JDJ/5lnJBJy6lTcwTcI5LlGQ+7XavJYf7+kVObn45SO8uKp1YT4jI/LIqHVkt87l5PrqDs6kskIyV9ZkWtobk4WEwcOJNdPgmuREJEEHwqdjoRgFZR/x43m9HvFqJOTseJ/bk4GL9MU4jAyEkdBLlyQnLUyF8vn406vyvApk4mrBBRSKRlkDx6Ej31M3uv8+fgzlEoSwlaW8Om0rAIbjTjn/rD0jbkZqLSLOi6GEUeoetMtCjs93puq6W4kmGA7ZmbkvHUcuV4OHxZS/fTTMRkZGJBzttkUwqBErKq7b7stxzef396RtztVU6nExmqqs7QSsr7wglTvNJsSCRsbi3vb/MEfiKD28GG5dgcGJJISBBLFvJmquQQPBxIikuCm4PsygauJJJWSaMGNrnJ6xai5nJCQeh2+//1Yi2HbcfRCuXW+9VY8sFqWrJZVPltZXKtUTLceJJcTovH00zJAmqaYOFUqMunl8zKwR5FUJUSRrP6aTSEupimRkISEXAsV/VC/vxKq6ro8rmz0e7FTrxmI96MiXwl2xr59kvpot+PISF+fiEvfe0+iIeWykPIoEiJw+LAc35MnJVLRbMbNBR1HfgtlguZ5srDIZCQ64vuyTRTJ9b+xIfoqkGunVpNtMxkhL2+9JaTjc5+Tx7rFrKurckt0VgkUEiKS4IbQSyBUg7mbUcX3pnH275dJ/t1342ZfKnKhOoSGobzv/HysG8jn5X3VKi6K4s6gKtSs0N8v23/iE0Imokg656qBN5ORAR1ksKzXY/fUgwfl88zMfLRj96Ci2+peHXPbjnUIsLM52U5lu+qvipaoq
pkEu2N6Wq6LRiMmI/m8kG1FRkolISSeJ+mUAwdkm+7Ioyq1zmTic1+R/3JZrtOFhbgsXnWpfuMN+QwvvSSp0nZb9rdnjxAOx4E//EMxCDx6NBazXrok73P+fCJmTSBIiEiCD8TqqkQPFG7WRXEnMergoOzzxIm4LFaZig0OxsZLb78di+MMQwYyw5BBz3FirwrHudYltVyWVd1LL8X6ku9/X/Zt27Ld/LyQmmPH5P9BIIPtgQMSSdmz5yMfvgcWirSpFvUQ+4moKIkiIr1mZqqnUK9gVRERz0uIyI1gclKEqbWakJFDh+R6ePxxiXysrQlxmJiQ6+PcOTmnn3lGSIDqa6OOeaEQ66pUia/nSdRT9ZpR0RGQqEe1KpVnq6tCWNbXZXxIpyXy8s47cv2//LJc43v3xqndRMyaACCRDSXYFbWaDBKKhIyMiADtZkjI2poMdkrQePCghJCPH5cV1dKSbNfXJwNUf79su7go4retLfkcti0kRNfjklHlF9JubychmYwQj6NHpWnd0JCQHUVC+vpkIF1akm0ff1z0IUEgA/D+/bJNQkKuD+XJ0h3a766SUT1nerGbo2p3Jc1Or0uwM8bH5XwFIQZBINfGo48KEW82JQqhrtvLl+UYP/64aEfGxmR7XZfrqFiU+5Yl19rmplyPhYJcE6mUbOs4ch0uLko/plYLPvlJuSaDQMaN/ftlu81NiY6cOCGfIZ2WsUQJ2+fm5LPvpClK8OAjiYgkuAa9QlRVNnsz2EmMms8LMXnvPSEYnieDVl+fREFAJqB33pFVlioNLRRk8HPd7ZOfpsXW0wqFgtyee04GWcOQgVeRobExGZQbDdnuyJHYfM2y5HkVsUlwfSjNgK5vFwb7vkSkVPpsJ2HqTmREPa5s3hPcOJSHyMaGnOsHD8r9Y8fk91lclKjEgQPyWy0vy3V37Jj8TkrkqppFqnSbpsXl8J4nUcaZmbgaJghi0nj6tJCe55+XBcz6urxPoSD7W1yU9zl/XrQj2WzcJXt2Ni7JHxp6ePs2PaxIiEiCq/B9mZTVita2JQd8M+V2USQTv5qYlBjV90XAtrYmA1gmE7ujqtLbSkXK/1xXBjTblpthCHFwnNitUw2Y3W6e/f3yfi+9FOs6jh+Pe9VMTckA2enIYLdnT2xNXSjI4Dc6Gq8uE1wfKu2iBIemGWsLDCPuOdNt2d6tBekmJN1pHJWeUa9Nyj1vDENDcqxUFFL53Rw9GlvAz87KeR8Eshhot8WszDCEBKysCLlX5CKXixsaqlSNSvUUCnE6UwmLL1yQ6/iRR0SX9frrsfHgzIx8hmoV/uiPRNPyxBPy3ocPy/suLsrnX1tLxKwPExIi8hAgDEMWFtdoNNvkcxkmxofQu0b3XvLwYYSosLMYVZkrKVOkIJABrFSSiV815Dp5MvbsUJ1Ds1n5v3pdKiWP1+vbV+C2LSRkfFya1g0OygT2wx/G7qojI7Ia830hJH19EvXxPPkcxWI8uCa4Mah0jLLPT6djMzP1+6lOsYow9kZDdvIXUY85TuyWm+DGMDAgx3FlJY6AqNLzdFoeu3xZIn8qvdItdJ2djU0ElS9PqSTXXBDEfiOuK2W4e/cKGXEcuRmGXMdvvy37+NSnJAWr0jvFolyv8/Oy6Lh4ET77WXnvYlGuv24xaz4v12VSsfZgIyEiDzjOzs7xyrdeZfb8HI7jYtspZvZP8aWXX+TgzNTVEKvC9HRcgXKj6BWjqtCq40iZrLKOzmZjt9LuKMiJE/GqSQlSLUsGQdXcLpuVv9VqLJRTpbeFgqz6nn461it8//uyf9uW9zx5Uj6fGphVQzDlJpm0Mb85dDry+2QyccRC/WbK2My25f+9otPdJhUVFVHPK6+L5He5OfT3y7FcWpKIn1oQTE/Lb3LqlDw3MhJbvp8+LemcXE5us7Ny3UKcxmy1ZL/Kw8d1JVVz4IBcx6ursbi105H3rtWkpPjAAYlOtttCao4ckeiJ58HXvibv/dRT8
tvv3RundlVFkErtJngwkRCRBxhnZ+f4zd/+KpXNGuNjg2QzNq22w4mT51haavP5z/wFxsaGgA+fkuiuqNF1Cb/qeiw+UzoPVbpXLstgqAaqbj+DdFoGvCiKNSQqCuI420PGui77zGbFXEmFn6tVWYFFkTwfhjKo6rqEix0nriQYH5fPsm9fXEWT4MbQXTGjwvLptJwLyipfVcZ0R0RgdyKi9CHKi0SVaye4eZRKcjwXFyWyoBo0jo7KNXXqlFx7atFQr0s6Z98+0Y2oDr7K16fZlMdUawPHiT1igkBITT4v0QzPk9/OtuX963XZ7+c+Bz/4gVyjly/LuZNKyf9Pn5a/P/mTMgbYtpAVNb7Mz8di9yRV9+AhISIPKMIw5JVvvUpls8ahmSm0K6N/PtvHvj0zLC1v8ObxU/xnhwcYH7/5K3s3MWq7HXsYbG3J6iqdloHKsmRC2toSQarryqCkIhsqtF+rxdGM7lRMdzffclkGrI9/XAY5kEHv/fdlcJyakkFsaUn288QTQkBUqmBiQvaTdAj9cFB6gXQ6FhBbVhytUlUYvWLV66VmlEOueixxV/1oUFVm8/OSAlHRznJZhNynTsWlu+PjMuFfuCDX8iOPCOlQ5fVKFK4iluoaVdeTStVMT4tAVXmOmKZENc6ckceee07unzghf0EWEcoU7ZVXJILzzDNyHgwPC1FSEdczZ64Vk39Q6jnBvY+EiDygWFhcY/b8HONjg2iaRhRpBN44ERqaBuV+m8sLrxOEh4CRG95vFMU5XIhzuBBbptdqMgD29QlZKJclctFqiYZEvV6tmrLZuGlasymDXT4fp2K6tSvZrAyE09PiXaDU9SdPxqXABw/KCksRoaeekoFYdRSdmJD3UJUFCW4eyudDtYfv7roL8cp5p7RMr6FZt2C1e3s1mSX48MjnhZTPzcl1p1KQhYKU9546JYuGIJBran1diMvQUNzS4MIFIReqqWGjIde0ilDW6/I7qVTN9HTsoqqqanQ9FrLOzEhZ/fe+J/u4cCHEsloEkcvaWo4wtFhY0Pn0pyWyo8Ss1WrIuyeqzC24ZNIpnn2mxNz8wnVTzwnuDyRE5AFFo9nGcVyyGck5RJEtJIQIw1oiY3g4ay6NZvsD9hRDqdoVVO652RQiUKnIYFUsyuOqIkYRihMnhFRUqzLh5HIyIPq+RC9Ux1aVium2atd12W82K6u5J56I7atfe00GR00TQ6d33pHPVCoJCVHty8NQVnu6njSv+6hQjewsKyalYRiXfHYTvF4ystNxVyLV7gqb7rLsBB8euZyQg0uXJP2hopeZjKRhTp2Kr7/9+4WYqH40hw/HJfGXLglRSaXkGu7Wjah+NUEgi4PBQXmP8+fld1QN8jY3JWpZr0tVzY9+tMEbP27QbLYJgoAgWmRz6xEGB8t84xtZ9u6VKMrsuVjrFgXDpFJ5vvXtDJvVdTru3DWp58WlNX7pF76ckJH7BAkReUCRz2Ww7RSttkMhn0XTHExrCU2TGsl2q4OdssjnPliZ6nkhP3q9SqvtkrYt+ko+ZqrF8kqGdmuIuTmdWi12M83nY2v1RkNy0ap9eaMhk1c6LaSi2YzD/Co6Uq3K5KbKOU1Toiu2LQ3rVLmh54koVRGYqSkRx3Y6kgs/elRICci+9uyJSUiiwv/wiKLY1Tafl+OviEevk2qvJ4gq21Xo9g7ptXvfrU9NgpuHcjS9eFEiHuPj8YJBeY0sL8t1OjMjCwqVUjl0SK7LbFae39yU36Vel/2qTtcq1arOj6EhSZuqlKg6T5Q+bHGxycX5N6lsnWZk4FOYVhbfm2Zt/TxtZwHPe4IwTHPqVIOTs9+hsrlwhXBENJtbnHx/Gcf1eezo4+SyHXS9QyGf5dDMFGdm5/j6H7/Kgf0TSZrmPkBCRB5QTIwPMbN/ihMnz13RiADILBFFEYvLGzx27AAT40PX3c/rbyzzZ39+lqWVDWr1Opfm3gAtYmR4itHBp+jvm2RkaIqpyRyGE
UdBbFsGH7X6UWF225ZIhWnKyqvdlsEpm43FpiqFouvx9v39YpS0d698rkZDIiFKlJrJCAnxfVn97d0rKn0F5Ydy6NDtONoPF1Q1SyYTO6nadhzN0PU46tFNSBR60zG9zyskEZFbi3RaIh7nz0tkU107vcZnp07FxmdhKPcPHZIFgG1LZLNaleta6UaUNXwQxKRDpWomJmQhcelS3KQynQ5ZWoogOEopX+KtE/+GbOYx+kuH0bUclYqHZbzOQPnjXLpcI/KfZe/EUVKZ00Jc6VCpnsDUB1ha6SOfn8IwHAxzHU3TGB8d4Oy5ORYW15iavPHUc4K7g4QqPqDQdZ0vvfwi5f4iZ2bnqNeb+EFAvd7kzOwc5XKRL37uxV1XC50OfOc7a/z+V9/k8vwyQbjE3MLrtNodsulHSBmPQTTM8soWp84eZ3V9jogVXG+DKApZWRGiUKnILQhk9dzXJ/9fWpKBLpWSx4NAVlrNZjyZFYtCQGZmRHGvSMjyckxCJidlIDx5UvZx8KCs9pSVtGXJNradkJBbBSUezmbjCEg6HQtTDUOIZnfpbrehWXe3XvWY2lalfFSapvu1CT46UikhGSDXoCrd1zQhGsp88MwZeVyVzJ45E2s1Hn9cUi/lckwoGw25VpUJYaMhqdyVFSEefX3yWlWd1m5HuJ5DENp02sP0Fb7A/OJ7nJr9P/D8BhEGa+t5NrcuUWueJ2P34/tDdBrPEwQ5XM8jCELMVJ25xddpdzqEYex+lsmmcVzvplLPCe4ekojIA4yDM1P80i98ORZzrVawUxaPHTvAFz+3s5hLiVFbrZA33z5Ftb7O3mmd1388RxQVOHrwc2gUCII084sXME2d5ZWznDj5KqOjOcr9I5T7HqevOA5R5irZKBRkUNvYkIlMdb7tTsUoYaJpCglJp8UbRJUTggyIc3Py/yNHJMx8+bIQjiNH5LVnzsj3UCmipMPnrYVqimZZQhxBdAiNRkxEUqm4bHenxna96ZluEmIY24mIKuNOcGtgWULuZ2eFKISh6Do0Ta4h25aoydycGJ8NDcl1e+6cRBYPH45L9FVkU+k/SqW4ckp18QUhN/39spioVGBxMSAMTYgCdM1kZPARMnaeC3N/xhvv/HOeeORnMc0R5hYciPKY9nmiYJIwzOB1ZkgZg5jmKaIootVus7L+NiODAxStDJqm3VTqOcHdR0JEHnAcnJniwP6JGypv6xajVjarXF78MUNDWeqNkE47x/Tk80RhRNtxaHc26LS36Lib1FurZO0BOq006eGnqNcinJZLNmtSLFr098tqeXFRBifLEgJimpKLVlbtSvzYrQc5ciSupnjjjbjcV3UXXVkRQvPEE/L5lbOrqtjp75d0UYJbB6Xd0PXYJyablYkoDK/tM9NNKnpTMt0aEQWV2uk2x0qIyK2FacZkRJXwDl3J0ipfndlZiZqoSrPlZVmkjI1JdDGfFwJqmtt1I6qyTbkkb23J305HoijDwyHz8xfwgzK6ZpLJlvHcOvncGJPjz4Nm8u6pr5PP9WPbP0U23YfvlDBTFXTTIfRHSZmwf+qneefkK1TrVU68d47Z9By5bIbxsUEazTbPP3P0A1PPCe4NJETkIYCu69fNk/Y6ow4PQ6RVcZwW2Uw/m5sZcpks4HJ54T00rYjr1HDcDs32BuW+cYYHjzBYPojnGuSyfXh+B9dz6e8foNnU2dqKfSfSaRn4lEYEYp1IPi8D4jPPxN1vfR9efTXuNfPoo/Dmm0JK8nlR1av245omg10mI9+jXL6th/ahhBKWGkbsIaLswpWrqmXJ76aIiEJ3WqY3UtKdvlHW/0qPUizeme/2MME0JZV59qxci2EYk/bxcfkdz54VAhKGQlDm54WcdDqyTS4n1+2pU3IuKKM7w5Brr1oVMqn0I54H6XSHllOlWl8imz5ARi9imhk8r4ll9jE29BxRaFOtX+br3/lf+MSzv0B/3x76+/owDA8zNU+zaROFFhOjn2B0uEWt+
Rpb1Tpz8yucPH2Bgf4S+/dOcO78QlI5cx8gISIPObot3g1D8se6Ds1WhrSdp9UYwTIjgnCDtfUKTicgpEKjuYLvRwwP7GVs5Elsq0gmXUbTDAIcTLNOq9NgdbUPz9OvpmIyGVlFt9uxGLE7FXP4sJTnqvbgrZaQEJVqmZoSfUizKdu88IIMlpubsh+Vp1ZVAQluLVSUQgkUl5djQqFKOJWrqiqZ7g6+qWjHThoRRVK6yUqnk5ia3U4YhkQ3zpyJu+mOj8tzg4NCKGdnpbw3DGV8mJ+XbdttSbUcOiS/+XvvxdERkOu8v19IiCrx9X1otXWiKE9fHywtncK2ivT17SWKdIr5ccJgkb17Ps7GxhCmmePPX/uXDA8e4NHDX2R4aATbnmBl5V2WV+eZmniOjlPBMj/DVvXPMQyHlKFTKuVYWlnnN3/7q0kZ732ARKz6kKLTkVWMGjSmpmL75CiCMBhiavwpKls1slmdfC5DveEQhAFrGyfYrF5maGCG8dGnMY0M+fwIYRTieg3CoIKGgalPUKuFRJGDnXbIZMKrZYGuG6di+vuFoDz3nBALRULW1mISMjoqA+Mbb8jrx8akwZ36DqmUPG/b8l0SEnJ70GzGQtV8Pm58p5rdKXt309zuqtrbfbf3bzdUeieKhIi2E73hbUV3NVmtJkRDoVSS9OjAgEQcz5yJiUqnI+W8th0vIAYHZVulIapU4h5TmcwV6/6mhWmU8dwSmmZQa65y6fKrOG6LMPTp75sgZWYY6N/P9OTz7Jv6OJ1Og1ff/GfML55jc6uGro8yPfk4qfQCjteBSGdmz6d4ZObLZNJp2m2H0eEBKps1vv7HrxJ215AnuOeQEJEHFGEYMje/wvunLzI3v3L1QowicTi8eFG2KxRkoMnl5H67LYNLu63zzJOHyecsFharGKbBVm2JSws/ot1pMj7yBOXSNJn0AIX8MJoG9cYq9eZldKNAEPTjBxptZ5PK1jJLS+vMzbWo14OrbeKzWRm0slkxN3r88VhVf+6ceIBEkQxyYSjluI4jJYhPPSXP12ryGtXJd+/e+LskuPVQoXflHwFx7yCVilGeIr4fN7HrJiKGcW1ERP1VRFiVBbdad+67PczQdbnOQIj+pUvxc7mcVNQMDwuxOHVK0qfqNzp9WrY7fFgWM9PTsq1aaNRq8n/pvBsS4aNTQNcKlPunyGX78COfpbUTbNbn8X0Py0yTyQyQzw6wb89LTI2/QF/pCO+c+h3ml75LKgX9fcM47TFq1XUuL30bK2WQTufZP/XT1Oo+6xtb28p4E9y7SFIzDyB267j7Ex9/ibQ9fnW77j4rvdbt2SyMjQ3xxKNH+ePvvMbxk99lo7KGphlMjDzF0MBhRgYfwbQyBL7DRvUCnt9mYvQoBDnaTpswdMSqnSIRFkFgAwGWFVEomFfeQ4SmU1Px5/jxj2M7aWXPfu5cvHLbt080Ir4v6ZxyWZ5TTq8Jbh+UANGyYr8X1a25195d+Yp8UMWM2laRlG5diXqPBLcfmiZkQhYismBRfZxsW8iI6up7+rQ8l83KtXr2rFx/hw7J4iablZROoxHrfMLQod5Yp97U0DCxUyU8z6JUMHFdl1Z7i3p9hXptiaHBQ/heh0x6gEZ7nbGRx0nbRfxwiMWVUyysvM+xQz+LoRcZHnwE04o4d/mPKJcOkUmPMth3mOpWjj1TNo5boVZvMje/kvSjuUeREJEHDDt23G15zM5qLC+9xcufsXjiiSH6+uJGUYaeJfAH0TS5MG37Sl+YpVV+8PoPqLfOcvjwBMXCIIXcHtL2fsqlvZhWmo5TpdncIGXnGSzPYBg2jXYN16lTLKbQGQRsNCwMI8APPMLIIZvt4+hRnWPH4q6/QQA//GFscvb001IZMzcXh39VekZ9TlV2ODMTT4AJbh+6LdhVtMK24y6sEJOI3VIz3VqQ7u3V491RdEVuEtwZKDJy5oxEH2dn5doCub4UGVlYEPKxZ48IXFdWpOR3akoWF
9msLBJOnpTXra871Go1/CBFu72EZRYIw4CUlUfXDPpL06SsAs3WGkEUsLB8HMdpMtC/h0y6H89vkc8Pk7JGSaUK1BqLfO9H/4wnj75MJn2YwfKjpFJ5Lsx9A9d5j4nRJ/A9k61KH567yu999TusrW8l/WjuUSRD9wOEnTruBn4faTvH2CgsLq3w7qk/ZXzief7V77zG7Pk5CMexrAxjIwN84sWD9PeN4jiwsLjK//Yv/wmnz54jlUoxPvwU/aUiA32Pkbb7CIKAytYlWu11hgaOkMsMABHN1hae18I0TQK/iK/ZaIaJoflAgGE08DyPmYMazz7b32VwJC3CVQ+aY8ck6rG2JvefeEIGybfflr/pdNzwLukbc2ehCIOqmFGlmiAEUqVmuu3dFRFRBKS36Z26meZ28pE0vbvz6I6M+L78VWkbXZdUrqYJGTl3Tkp7VWO9ubm4Ws22lYg1ZG5+jTBKE4QOaXuIRmMVtAb5LERAPj+MYVgYhkWztY4fdOg4W5w5P8vYyGOkUwVyuQE8z6dcfJS01U+7FbKw+g6Bf5w9Yy9jGHn2TvxFzl3+Jm3nLBEjLCyl8LwSF86H9JdNJsYGaXfcpB/NPYaEiDxA6O24G/hFwjAPRFjWOgMDVd46vsKZsxfxgzSjg4+SslO4rsuFyxepbNZ4+TPPk8sH/Na//ae89c57ZOw+irkjWMYk5f7DWGYGx21RrS3S6VSZHH8Ky8riBw6+36bdqZFN95POlEilCkRRSOh3wPQwNB/NcKnVX2N45FlsW1SplQq89ZZMVoOD4u74gx9IbrlYFGv3rS1ZcakOvErQqjqEJrj9CMO4tXuhIKtgkNSMEj3repzuU4LTbvR23+2G0ofs5D+idEUJ7hwOH5aoRxBsJyPK+Mw0hXgsLMjvtm+fpFFXV+Py3oMHYX2jyo/eXMTQ+3EcCz8IyGUHcb06W7U5UlaBIPBJ20UMI0UQOERRiFY0yWb6WF47Sbm0h83qOcZGjmGGJvn8JOMjJn4wQqM1z4/f+1fsGX+JYmGC/VOfZ6t2gbMXXiWzdI6JkSfpdCwqWwYXLp7n8KGRpB/NPYaEiDxA6O24q+k+hlZDN8R1Kp2xuTS3wv49z3Ng/54rK9smup5jcnyKpeU13nn/z9ioVHjv1DkG+44wUJ6ikJ0ikx3FMm3qjSVqjRVSqSzjY09gGjae18T1mnheh1JhHMvKkrKyBIEHkUPHqWGYUCo18cNZ0C5RLHySMAx588dbnDuvYVsWTz6ZxTD0q+W5g4NSRXPxouSlQSo1SqW4025CQu4cGg2ZYFTFzKVLcvxTqe1VUCrKpYhIb/lu91+4tnxXObJCnO5xHHnfBHcWBw/GHXRPndpO/Gdm5Le+dCnuXTMzI9vXapK6m5mB8kCVVucizeYihMOk7RJuGICmk0730Wxt4HpNgsAhnS6Rz46gaQaWmabj2FhWhrWNc9Rql/CDFtOTR0nbU2TSRdLpEab3HMCySpy/9DrV+hn2Tn6GQm4vjx6a4r0zf8D88usU85Nk0sOEYZZvfPuHvPjco0k/mnsICRF5gNDbcVfXt5ccrK0F5LOHGSj3o2ugGZuEQRmIMIwW5fImx0+scGmuwtjI02TtCXKZEWx7EMdtsFo/Q725xtDAQYr5EaIownEbdJwamqZTKkyQThfR0HG9DmHYIaSFpocsrb5OOhewsnaZx44doNXu8A/+4eusr0f4gU8QneLcxUcYHnwcO1VgakpMzU6flhWW8q3I52WyU/0yEtw5qB4z5bKky5SwOYpiTxjVqBCu7aqrHuu9dT+noiLdr4+imAAluPPYv1/Ihqqo6yYjU1Pye1+4IIuFMBTycemSpHVOnZJxqe2cZX7Jx9AHmRx9hggLIogiCX15fpuwFRFFIZZpU8gP025voWk6mmsyNHCAtJ1nZfV10naKemOFseGnCEKPdmuAcukw9kyGS/Pv8ua7/4ZD+75IJt3HIwe/zPrGKdrOefxgE8sq0Gi0e
PX1E3zuM88l/WjuESTxqAcIquPu4tI6UVdMPIo0PHec6lYKO2XR3+9gphbQMIiikFb7HJWtC3h+QLuZJ2PNMDZ0FF3XsKwirXaFdmeLVmeT4YGDFPLDBIGP77dZWz9DFEXkskNkMn1EYYjj1mX79hZhELC2+edcXvweZ8+/T7lc5MjBvfz2v7zE8opPJp1icGievvx+VldLnD27RKGwwbPPitBtdVUGur4+ISHZbEJC7hYajTg1o1IoqnTXceS+ZcV9gbpJhYpw9BqWQTypKdKiyIeqoHGcxEvkbmN6Oi6LP316e8pteFiik/39kq47c0bSq4o4NhtDjI2MsLJ+hsrW+yyuvkYYtmh11vH8No3mKqaZJgjbVOvzNNsVPK9NOl0kmx2Qv+k++ooTlEp7mb34fTZr8xT6ljEMF9etYRh9pKwZSoUjaKQ4cfqrnDn/bUn3DhxhqPw4ESGut0Uhn6XVavPuidmkH809giQi8oAgDKUK5tDBPZyZvcTp2ctMjA5ipwdxWlkqWxsUi1n86AId18BKZdmszjN7fp7KZo0w0OgrHSFljZHPyZK24zTY3PwRppnGMFPks0Pouonvd3CcOtXaAoXcMLnsALlMP77n0O5sEmk+YRDQcbaYX36DpdW3iKKQY0ee4md++mX+6GtVOo5Pud8mlZ7Fc/ZiGGMU8h7rmyc5e9Fj6PjPUK3q5HIyAKZSkpIZG7vLB/ohRvfko8pqbVtIguPEJlaqz4zSeJjmzl13YffoiPIjCYK4/DPB3cXUlKRgajUhI90i8f5++d3PnxeBeRBIJCWTgY0NnScf+wzf/cE7rG+s0e40gYh8ZopWe5Nifpyt6hyF3DDtThXPb5PLlDH8DLZdIJMq4nRq2HaRdGqQybFnaXXO43oLNDsV6nWdYn4/um4wMfoEnt/i7IXvslVb4Mfv/hsOTP8EQdghbQ3Q8dbRDR1N11leqfCJF59I+tHcA0iIyAOAXt8Qx/VwOgGrehHTCrHMNgf223zx84/zjT/2OXHyHAPlIm+/c5Z2x7miA5kkDMtEYYDrBiyunCEMU2i6gWmlrqRafLhiXNbubJLLDTE0cAjDsnG9Nq3OJhohELG6/h6V2glqzYu02w6DAyWymUFefz3F5maTQt4nZS/gOvsJvSE0vYOZWqS/r8L8/BTzYy3GxvLk8zKRDQ7KLcHdh6bFXXfT6Tgaoqzf1eSkIhu95bm9hERFTrpv3RU0yho8wd3H+Lj8XltbEvk4eDAWEefzMTlZWZFzYnpaXhNF+/jUx/8C33vt26yszjF78YdMjUkFjeNsYqey1FtrRKFPGGhsVRfJ5YZotTexrSzpdAHfbzM++hj15hJRDU6euoimd4iCKTYry4wOP4lupZkafx7Pc7kw932iKGT24p+StotYlkEmHREREQYhpmXy7NNHE6HqPYCEiNzn2Nk3JM3GBlimyQvPPcKLHyszOSGsf2W1womT5/jzV48ThXBg78cIgyyel0HTXMZGBjh7/gREGWxLKm5crw1RRMrKUqst0upUGRk8TLlvGt2w8NwWHbdOGHigwblL32Fp5W06zgaappNKWZRLB7h4ocSifY5qfZaZgTyee5AoNNH0FpZ9Ed1oYRqPEgYOhtmhWMyj62LvrrxGEtwd+L7cbFsmnEZDHs9m5XHPi6MfSmDa3XG3u3x3NygSol6jTNPa7cTU7F7C6Kj8jpWKVNV0kxHVL0rXJXri+5Km2bdP5zOffAbfDzh/+X0qlUu0O+eIwgDHadLfN0Exn8d1Q1y/DUTUGksUcqNUG8tkM2U03ULXDFJmkXLfYdYrZ9A0k7XKa+haP223RT4zTD4/ymD5ACvrp2i1N9CAjlPD93V838QwDUrFHEcf2cexI/vu5qFMcAUJEbmP0esbAga+N0ba1hgfC7k09w6rGz6TEz/HufMLV6Mm6xubtJpQ7jtIu1kiCAIiOtSby1S2lijl+khZGRzXxfM6bNXmGejfz1Z9AYjIZfsZ6
NtHEHo028tXPkxERMDJ03/I0uo7+IFDPpchm00zNf4xsulxms0Wa5WTdFyfiZHPk0qZmNYmhjkPURa3M4PvBmhag/5+A12XQSyfv5tHOQHEFTOZjPweqlw3l5PVcRDEJbaKlKgoxk7Oqt2GZt0RlG5YljyW6EPuPQwPy++2vi5kpNtQ0LJiQevCgghZOx345CeHgOd58+0SC8sTXJp7m8Xlc2gMYFoR2cx+0CPymX42a8uYepp6c4m0XaLd3sQwTILIx7YLOJ0qE2NPsb5xioE+mwtzr1Orr6DrJpaVZXT4GEMDM9Qaebaql9CATMYmCCKGB/uZHB/mmSePJGmZewQJEbmPca1vSAHQ0DQP01phdCTD2XNzfO/Vd/jaN74vUZPRQbLpfYTBOobWz2ZtkU5nC90wydr9GFoOx7Nw3SqtTpWOU8d1W5w88zVm9n6K4cEjZDNlHK9Bx6mi6xYQUa0vsrT6Oro5Rz5vMVAeBDT68k9jGWVc1+P47B8wMXqIkcHH2KzWGR/3sFIr+H4fvrOHCNisLzA5UWBwoMT0dGwfnuDuQnVMHhgQItLpyETTTUQUuVBCVs+7try6Vw/SnbrpjqBA3MvEce7IV0xwkxgclN98dVUcWLtbRhjGduOzuTkhpp/61BCjowO8d+oCgW+Rzb5LuVxkdbVBrTlHyhxhvTpHf980tcYihBGOUwfNoNWp4LhNCrlhyn2jOF6L4aHH2KicZ2r8GVbXT9NxRGNy4fL3mRp/BsvMXCkh19B1Qzp09xWZnh7ji597MUnL3CNIiMh9jF7fEF1voxtNwKNWb9LuOGxu1vjmt39IZbPGwQP7CfwRQt8mndJZXH6ftY1Z2s4mB/Z+gkJ+kDCKaLUb1BqLhGHAxuZ5tAgePfJlSoVx7FSJdmcLx21gXzEsm1t4g4WVt2g7Cxw6sIcgCLFMk06nQBDomIbDxcVvMzn2OLn0OI7rU2++x/lLVcaHj2EYkwSex2ZtgVxe59knj7B/v361DDTB3Ue7LSW6ui6TjBKPqo6qvY3sVCqnW/cBOwtTFYHpTeWoSa3bnTXxjbm3oPo8LS+Ly2p3vydlfKb60yjjs4MH4WvfkBa/xw5/Giu1xJ6JkFrNY23dpdGM2KpdJp8bod2p0ulUSaVypKwMQeCxsXUBO2Vi20Vq9TXy+WGymT50zWBj6zz1xgpEAZcXfoSm6ZimQTaTvmJpoPHcM0f5T/9PLyeOqvcQEiJyH6PXN0TTXTYq1auVMJ2Og+N4zC2s8vQTnyDwRxFDZZ+NzVkuL7yOH7iMDh2lvziN5zt0nCa1xgpRFLGw9BajI48yNnyMjN2H4zRotirksmUs0wYi3j/zCktr7+J6NUZHytjpFKm2yXqlihZ5hNG7OG6NseHnyGVGaXeaLMx/j4+9sJex4ZdYX7PwgwaaXmFqqp9nnjjMJz85lPSNucfQnTZRJbWqOqaboKh0inJh7UZvbxm13+7oSPfzakJTmhPPS5oa3ovo64s1IefPi8Nq9yLi0CG5v7Ag26ytV7m08Dbl8h5Ao1Yt4/oXSKV0Gs3LbG5pFPMj+F4Vy0yx3lzl8tnXOTLzBTLpEo3mKstr75PPlvACDTsoYZopxkefoNneoN5Y4erpGoWMDA3ywnOPsra+xaNH9/Orv/ILmMkAc08h+TXuYyjfkBMnz3FoZorKZo23jp+h3XHI5zN4nk9/fx5Tm2JpsU023aCvv4Eb1FirvE0QeOi6Tr25wuraSUwzj+M26Tg1VtfPMDX+DLnsoLT6Pvct0ukiQwOHAOh0qsxe+i4Ly8cJQ4fJiRHGRwfJZTM0Gy0ajRa5bIjvR4wOPUsuM4rjVJm98CdMjJfQogPsn36Sg/sgk2uTy1gMDJQ4dEhPrLzvYWharNlIpyUyoqplFBGB7RqRbqioiYqKKL8QFTVRhEZFRDQtJiLtdkJE7lUUi/Lbzc+LJqQ3rTo9Lb/d/DycuwCEk3ScNc6eO4fnFqhszYLm4Hs+lxcX2
DP+DGguhlHi8uLrBIHLe6f/kOnJF4GQMNRxg3XaLYOOVieXHyYVhZQKo6xtnLn6vrqhMzRUZnOrzt49o/z8z76ckJB7EEmC7D6Grut86eUXKfcXOX32Mu+9f55mu0M+l6bRaFPum+SxR36CXC5LrbHBmfPfBzxSqZC0nULXNTRNo9XeZLO2gOO22Ng6x9rGKVJWFk3TqNbmefvE79BorWJZGTQNVtZPMXvpu1xeeJ0gcCjkc+zfO04+n2VkuEwqZaFpGpqWo69wCA2NZmuV9868gpWCJx/7EkQDgMfMTIHpqWEGB/s5ciQhIfciPE/IQDotk0uzKUQhnY5TNiAEQple+f61Nu2wu6MqxNERZWTWa4aWeInc28jnxWsExFlVdWcOw5C5+RW2ahfJZNcZKIOuDXHxgsna+hpt9zyFgkE6beN6PpqmcWn+DTY2L1KrLzA19hR2ShTrl+ZfpdFcJWWZmEaBan0JP3RpNZfRcMlmLY4e+hS5TP5Kui9icWmN1bUKqYTF3rNIqOF9joMzU/zSL3yZ//3ff4sTJ8+hmzodx2NibD/79jxKLpdmcfk9tlYv4bgpavUmxUKO4aF+VlYr+EFAJm2D1mBp7S00TUZ/12uyuPwOQdDBsvJMjD5BsTBBtTpPpXqJjc0LaJpGJmNTLOY4e+4yhbxc/K22S19pkr7iBEEYcmn+NdqdOvl8gU997GexjUF8vcpAWb9qdpX0jbl3Ua/HFTOFAlSr8ngqda2ZWT4vxKS7X4z6Czs3vOvWiCj/kCAQMqNMzYIgISL3A3I5iX5cugSXL4PrLfCn3/v+VY8j206xd2qGra0GujbA5NiTtJ05XK+CndIZGS6zslrBCV2WVs9SKlYZ7D/A2PCjrG2cpdneoFa/TCrlkw/30Fccp95cp5AfpO2sYpoGuq4zMvw46xuzBFGdT774JP39BZaWN/jN3/5q0nH3HsQdISKO4/DCCy9w/Phx3nrrLZ588sk78bYPDQ7OTPGzf+nTnDl3mdGRATJpm3xuECIN3Vhk394BKptrrFe22Nqqk8tl2LtnjNOzcwTtDpquXZkcQgJfiEgUhRhGgGFa9Bf2kM8Ns1E5x8r6+7Q7WwBYlomhG3TaDkND/Tz71BHSaRuNEVZXO0RhSLMzSy5vsHf6EMcOfY50eoD1jTkmJvoYHi5iGCJeS3DvolYTEtDfL0SjuwFhsykRE9eVCEk+L/4SqnkdbDcp69WJqHSMMjVT/iGdjhARlZpxHBLd0H2CTAb27oVXX13jW995m0p1hdGRongctR1+/M6PuXhxg+Ghx0jZk9jWHly/jtNpsLnVYKBcolZv0O641BtrdJwGo0OPkc2WCKMGtm3RaG5Qb2wyOf4UI9lJ/CCg0VzFceu0Wi0G+g8yOXYUy/bJ5tIUCzkK+WzScfcexR25tP/u3/27jI+Pc/z48Tvxdg8lioUc/X0FcllRh0P7yg0GykWOHJrm/TMRzXaH8xcXsVMWX/zsx3jz7VOsV7bodBzCUFwHU5aJH4RYlsn0xHO4XkBl6wLzS28TBO4VHwidtJ0ibac4eGCKmQOTlPtLBN4Eh2fA0ue4uPA6k2NlJsb30ld8nMC1WNu4RC5n8MyTB8lmdfbuvZtHLcGNwHGEGOi6kATHEYKQzQoRUYZmuVysEemugukVoe5WvqueU71sfD8WxLZacdonwb2PVCrknZN/TqMZMD35GKa1ia63KOSzTI4P8f7pi7j+eXQ9RxgWaNRNVtbWAMikbeyUhWmaVxY7OmF0kUMHRxkdmSaTsdnaarC5VcM0G4wNP8XmZpOtWo6FpbcJwiZtd5ax0RcAjZQpTW80TUs67t6juO1E5JVXXuGb3/wmv/u7v8srr7xyu9/uoUWvcFXrGt2jKKLdcfiLX/wEf+mnfoJWW8zGJsaHODs7x+/87h/zozffY2F5jXqthWGIG+rw4KMEgVi2V7beZXrPEKPDZVzHY3xiiEajxfBgmbHRAUDD9yYADV0PG
Z9oUWsVGB/bT+DvY2vLB22DyYk+nnnyMIcPDTExcbeOVoKbQa/RmEqRZLOxZ4jvxz1lYHv0YqfS3d00IhBrRJQI1vPkPRN31fsHC4trnLt4gXK5jKYNEfhlMCN0vU3qSqO5er3KU4+1aTVz1JvD6JqJYW0QhRGWZRKGIY7r8cyTRxgeKlMsiG4tiiI2N+u89OITRBG89/5pDh96lGZziMnxEd59/1VKRZ3VjeNMjO2nUCgCwooz2TTOaiXpuHuP4bYSkZWVFX75l3+Z//Af/gPZG+jh7TgOTpd7Ua1Wu50f74GCEq4uLq1xZnaO8dEBMtk07VaHxeUNyuUiX3r540zv2d417vChaX7+5z7H6nqFfD5LPpeh0YhotUrUqg02a8vkcy0+++nnGCyXWFqpUC4X+fQnn+H3v/odRob7AajV2vheE9Ny8Pwlzp6fo1IxGerfT8oyyWU7PP7YEY4e2cvgoM7w8N04Sgk+ClSVi0qTqBJNJUztTseoqAZs77yr9tObpul2WVXRFNXNt9VK9CH3G656HGUNdH0Z3xslCk3QJXo7NNTHhYtLOF6HuaU3IJpgZGiSKBri0sJbjI4McGDvBH/6vR9zZvYyI0P9BGF4zXgGsLS8zrmLJxgfmaRYHKC/dIBWe51MxmN6Tx5dj0VK7VYn6bh7D+K2EZEoivjFX/xF/ubf/Js8++yzXLx48QNf8+u//uv82q/92u36SA88lHD1agO81Qp2yuKxYwf44ude3FGgFYYh3/jj13Bdn2efOkIY9hP4OdrtDuuVM7x7sknKTuH5AZvVxtV9pdMpvvaN77OwtM7S8rp08A3fxfPaNFsdBsuH2DP+FP39RQxti0q1wo/f7rB/X5rh4fG7cHQSfBh0OkIKMhkhHp2OEASVJumOUuh67B2iohq92/RGQ3p9RNTzqvw3sXm/P7Hd48jAtBavPqdpGhOjQyyvVDh3YYH19S3y+RaZ4BiddkB/3xQz+/sZKBd5/pmjvH/6IgtL6ximseN41j3mdZwlLLNEITfMoZln6O/fuvq+URSxuLzBY8cOJNbu9xhumoj86q/+Kn/v7/29627z/vvv881vfpN6vc5XvvKVG973V77yFX7lV37l6v1arcbUVKJuvhkcnJniwP4JFhbXaDTbV1Mwuwmzem3ioyCFpkGxtEmpb4ChwWdYWFrnr/7c55nZP3l1X2EY0lcq8K0/eQ3btsjnc1imztzCFsX8EUr5w6RSKbKZKppuMJYZ5NL8Cb7/WpMnn/i5RCh2n6BelwiIqphptYQY2LY87nlCFtTP2V3Kq4SovdbtO2lEuslJd7WNZW0v5VUurgnubVybKo6fk1Sxy2c/9SxBEPLN+ddAA8d5m/GRw0xOTjBQlqlpfHyItuPyV3/u84yODOw4nvWOeatrm3z9m3PUG20sc4B0boNOu8nC8jp2KsXBmT0sLK5dd1xMcGdx00Tk7/ydv8Mv/uIvXneb/fv38yd/8ie8+uqr2D0+3c8++yx/7a/9NX7rt37rmtfZtn3N9gluHrqu37AQq9cm3rAqgH914MjmMhimwejIwI779P2AKIpIpSzCwGCg7wmKhWk0YKv2LmhPAWBZG4wO24lQNTcrcAAAMaRJREFU7D5DvS7RCFUxs7IijysPkW4zM02LiYhK46i/vfqOnTQicG0URVk/dHuJJKLVex83kir+z3/+S6RSJssrG2SyNv2l4lUdiEK71SFtp5jZP3ndMaN7zHvk8F7GRwf5o1feY36hQ6Vq02qt0my38Wyf3//qd/jaN77PzP4pvvTyzpHiBHcWN01EhoaGGBr64LDWP/pH/4j/+X/+n6/eX1xc5Atf+AK/8zu/wwsvvHCzb5vgNuEam3htux3mbjnV7736Dj98/QSmadJstanWWuyZ+BiF3CSWEVFvnqTtrtLuTFMsBmialwjF7kMooaimSRREmVQpV1XXFT2IaQppUG6qKoqhynK7sZOXSLfTKsR/1bpEmaIlROT+w
Y2kisMw5PFHD3Li5DmKk9tJyEdJpRycmeL/+rcmmD23zptvVnj9xzqpVJHh4fBqGfGJk+dYXFpLfEXuAdw2jciePXu23c9f6eV+4MABJicnb9fbJrhJfFC1zU4DwdnZOf7173ydjUqV0eEBRkfKOK5Ff2EfjcYWLecstu3itwJcbxlN6wMSodj9iN5ohutuL931vJhwdBtXKuLQTS52Ss3A9mZ56j3V6zIZ+b8iOIlo9f7CB6WKbyRy8mG75Oq6zsyBQV751nfxg4ipiQPoGuj6AoV8lkMzU4mvyD2CxCLoIcfNDgRhGPLKt16l2WrTX8qj6Rp2qp9SYRyiGgsrP8D1GkxNjmCaBqmUnGKJUOz+gyIQqpIFRBcCQkS2tuS+Kt3N5WLiEgSxXXs3AdlprO/ef/f7gpAbZWoWRQkRuR/xQaniDyOyv1EsLK5x7sIlyuUiuj5AFBlEURpN6yS+IvcQ7hgR2bt3L1FiBHBP4mYGAiVuPbBvAtf1WFuvMtgnVTDN9ln6+iwWl3zmF1fZNzVKNpumXm9+5NVNgjuPdjuOSiijMscR4tDtIaLSM91iVhXBUKJVhRvxEel+zDTl/ZRpWnLqPJi4WZH9jSIuI7YxjCXCMI2muVefT9LF9waSiEgC4MYHAnVh57JpZvZPUq+32NhcBG1TdAK2hWVZDPSX6OsvcuHS0i1b3SS4s1AVM9mskAzXjStoVOWKIgjK3r3VkscUEemOduyWmlH7iqK434wiMEp/EgRJ990HHTcjsr9R9GrgdH17SC1JF98bSIhIgqu4kYGg+8IeKJd46olDzJ6fp7LZulpBMzLcz//w3/8N9u+buKWrmwR3FqrZXakU95UBIR0KKiJiGHHb9yCItSPK2KzXVbWXiHSLVXV9OzlRniRdXocJEtwQPowGLsGdR0JEEtwUei/sgXKJcn+RWr2J63rML67x/DNH+eQnnkyIx30O35coRF+fkI9KRYhBOr29zFa5qqoKF12Py3i7HVZ3IiGw3RdEeYeo13XbvCemZgluFrdTDJvg1iE5+gluCurCLvcXOTM7R73eJAhDdE2jsllncmKYL7388eTCfkDQTSJURMQ0JToRBNtt27uNy1TVjJTvhkRRgO+7tFotoijcRka6e9RA7EsShhKRMc2490z3Z0qQ4EagNHCPHj1AZavO+YuLVLbqPHbsAL/0XySlu/cCkohIgpvG7VS5J7g3sFPFjNJ9ZLOxmZkSk2paTBRA/q9psFWtsbUFQWBTqzfYPLnIwnKbp5/YDwwCcYO87qZ3iuy0WjERUe/vunH0JUGCG8HtEsMmuDVIiEiCD4Xkwn6w0WwKMchm49SJMjbL5YQgKCKinu92VQ0CaDZbzC2cg3APppXBwMDEYmHpMlvVCoN9P4lhpLaV+EJIp9PBcSJcN4XrGliWnFOKqHQ6CRFJcPO4HWLYBLcGCRFJ8KGRXNgPLup1IRqqxwzEHh7ZrOhFXFeiIbouhMSNqyIJw5DVtS06jksha1193DANRob7WV5ZJm21KBZMDEO/Ql48Fpc2aXXWiCKdlDlEs91mfLQPz8tQ2Wyzvu5TKpUolRLCmyDBg4KEiCRIkOAaqB4zqmLG9yVVkk7HVu6OE1Kttmi2IpqtCMfJAzq6Do1Gh7bjkM9mr+wxAkLQQAPKfXlcz8P3AwxDx/dd2m0XP3RIpww0zSaKDLa2OswvHGeg/1EqW4ucPHOG8bEiX/6pg0kKMEGCBwQJEUmQIME1UELRUgnS6ZCzsxusrZn09+tEUYHFxQ1ee32NrWqLKLI4dXaF989YPPX4o4yPDeF4IVEIhmUSegARaKH8JcJKmTidANfzqNVaNFshUWRiWyl0qwWhjudrBEFErbZBfykkmykQaVkuXd7gN3/7/aRHSIIEDwgSIpIgQYIdEQSwuLTGN7/zHS5cdNDZB9omf/7aEtWtASxjgny2gG6ahGHIxUtrbGz8iJc/8zymkUfXdQIvRAMCAiLfR4tcOk6Ly
/NzDPfP0NDafONPXidlPEo2XULiJeLx4HsBYQR2OsIPAnBDTCtibHSQC5dnkx4hCRI8IEiISIIECbbB90PW16tcuFjjvVNvUm/NMjH6FIZexA8a/Pmrb5PPPsLjRw9imGk0zcPORUxkhlhc2uDN46eZGnuRtJ1hbWOJbMZD16DZalJvrrC2fgbXazE2ZGJbFl6g43U8fCsAPUTHww8hiPJo6NQaW5QKLo7rsbqxTCajUyjYSY+QBAkeECREJEGCBFdxdnaO//j1t5mbD1hablOrLZPJNRgZDLFsnQgPXTPQdY1qrcHwQAm0DprWATKU+wosLa8z2OdiGAaVSpXUsIdhGASBT6vZptFsY1ghlqUTEBKGAZoWEkURTscjZUMYNCAo4/s+jWaDMHTQtIgwjNiobDG/dJGhgb6kR0iCBA8AEiKSIEECQEjIb/72V6nXc+SzwwTBOqmUx9r6Jimzwv7pEfywgabZ5DJpHMfFC5rYpouuu/gBWLaFV62ztLLB0nIDPwgR38SIIPDxAg9D1zA0jXa7g2XqXJ5fYnRwP2EIrhfguhG6YRCGAUEQ0mxs4Xouvt9ma6uF4zZZXl+WZmaZpI43QYL7HUlyNUGCBIRhyCvfepXKZo2piWlMM4vve5iWQ7m/iOO0WK9UMM02tl1A0y2CMCAMPQyzBviAhue4eK7P3NwyrueSsiyMq/arEVEUgg6tToeO0yKMAnI5G02LRBfiB2xu1llbrxGGIWEU0XKaRFGAaaTQdZ0wDOl0AlbXKly4tHQ3D1uCBAluARIikiBBAhYW15g9P8f42CCgYeji79HuVNE0A9s2WVk7AZrLQP8QjhMRRf5Voag0rIvY2KyxtVWn3XEIo5AwgjCKiKKAKAyJwgDP9fGDAN/vEIYum5s1PGXLqulsVOrUai38UHm6h4Shj++7BEGArmlk0gV8P+CPv/MjQuUtnyBBgvsSCRFJkCABjWa7K9Whkcn6pOwG9UYT08hgmgbtTg3PC5gcmyKKDHzPx/U6BEFAvd5kaXkdz/OoNTqEUQRRiK4ZuG6LRnMD1+sIoQiFlBBB4Id0HBeiAI0IQ9fRNIMwDAkCcUjTdZ1qbZEg8nHcNpquUSr0oxs65y7Os7C4dncPXoIECT4SEo1IggQJyOcy2HaKVssjbYNhOOyZyrJRsWk0QMMlDBusb5hUt3RMYxA0WFhcpuMskC+scezwIywsrYBmYhqiC3HcNl7gEF6JbgRRCFGEH/j4oY/ndyBqx9GPK9GYIHBw3CaGYWGnsrSdCmHoMr/0FtMTj2NZWVIpiygiEawmSHCfI4mIJEiQgInxIWb2T7Gy6hBFEZrWZqBc5KknDjM0OEqj2aLRqvP+mUsATO+Z4uD+SSbHhygVc+SyBfZOj+G4DvlMFl3XCAIhF2Eo+pEIJEVDBFFEGIZ4vsPy2inCSKVXNCJ0qaBx67Q7VUDH99o4Xo0wdPF8H8vKk89myGXT5HOZu3HIEiRIcIuQREQSJEiArut86eUXWV76EUvL65T7t8jkUqQsg0KhRH/fAH7os7lV59ihI4RBHoBCIUWpb5KLlzd558Q5NCKKxSJRCI7ToO1UKQPRlaZ2QjgiNE0D5P+maYleJPLRNB3DsAjDEN938H2HVrtCJtNHKpXHsrLomk7atkFL88SjB5kYH7pLRy1BggS3AklEJEGCBIB0VH75J3+CPZOjVKpVzl1YYG5xlf5ikcePHQBgZKhMvdnCcR2kd4yPRkA+X2BhcZW206Gv1I9pGtSby7Tam6Bp+L5DGPgEgXiKGLpGFIqVfKPZod5YwQs8NHR0zSAIfQLfIYoCbLuI53VA08lli2iGhqbpHD64hy+9/PHEWTVBgvscSUQkQYIEVzE2OsRPfXGAtUqOb/2JREe2ag2++4O3+fE7b1EsFBkasND1ZbKZNIVclnpzka1qFd/PYhgebmeLvmIBz+9ciXxoeH6b85e+S6RFmLpGJwpBu2L/HmoEV8Srmq6BpolQV
dOwDJtcpoTjtImiENNIk0lbPPPkEf6L/+xY0msmQYIHAAkRSZAgwTYsr2zwjT/5AZXNGpPjw7hOhkuXL9FotNHIkM92iEKX6laDy16HtrvC8OAIqZRJ4Pu0Wz4aTTzPwTAsAKIoZH3rAu3OJlPjTxEGPlEUAaBf0YT4fofISKGh43kdNN1As3IYRgbdbJMyLWb27+Pln/wELzz7KAdnjLt5mBIkSHCLkBCRBAkSAOB5Ymz21jtvU9mscWhmiihMcXmuguc3KBZzBH6Kdtsll8thGDYdp0GzVaVa7cPQI9AChso2mqYRhh6WmQZkv4HvEIUhQRAQXDErQwNdN4iiED9w0XUTTdOpt1bJM0DG7idlZWm11/HDAD0yOXHyFJVKHSs1zaGDSUQkQYL7HUlyNUGCBADU61CpVLk8f57xsUE0TaPeiGg0WqRSAbquYVkZgiDA81w0TUzPWq0qrU4bPwhIWTaGIeub8Ep5ru93qDWWaLY3ut5NKmcANE3HcRt4XhsQ91Xf71CtL4qexEhhmGBoOqZpUCgYXJ5f5v/3L1/h7OzcnT9QCRIkuKVIiEiCBAkAqNWg7bi0O5tXe7i4fl2Ih1/H9wOCYAPPq+P7PrqWAsAPXNY3TrNWOU274+C4oucIwoBqfYl2ZwvHqcmbaFI+E4YRYRQAGrpucOHyD+i4NRynieu1rmhLQNMMNA2CwMFOW7iOTxS5jI0OUq16fP2PX02cVRMkuM+REJEECRIA4DiQsVPYtkWr7QCQskKqjVM4jkcYRuh6QNuZI5N1MU0N0zTRNZ0o8mk0N/GDgHZ7i3pjhTD0MQ2LRmuN1fUzAGgoghHiOg06bgPHbdBqb9BorhKEV9xUrziydjpVIjTAJ5tJg2bg+g00TWNkaISz5+YSZ9UECe5zJEQkQYIEgGRKBgZKzOyfYnFpnSiKKBaylPvztDsOmgaO65HJ2Nh2Az+o4PuSspF+MhGr66eo1heJCAmjgDD0r0RHhNiEUUAURbheC93Q8Nzm1ffvdKq4ntzXdZOl1XdxvCYaUC6naTsVWu11HKdFFEXYdgHH9RJn1QQJ7nMkRCRBggRXYRhibFbuL3Jmdo5Go8W+6XFs28JxpOlcqZSXJnUamKYB2lW5Bx2nRqu9QRRdqZTZPE+1tkAYBgAEgcvi8nFct4LjrtF2ttA0qX6pNZavOKmKgFVB03QWlxZZXLzM1tY67743y9lzc2xstLFTVuKsmiDBfY6EiCRIkABHAhYUCmJs9ku/8GUePXqAyladrVqD6akxjhzaR6mYZ32jSq3eYqBc5NDMFJm0vW1f65vneP/sKywuHwciXK95lYgA+IGDYRq4Xgvgqh6k3dliZfUUrXYFXY8Fr0EYsLK6ghf42KkcAFvVOmfOXabcV0ycVRMkuM+RlO8mSJCA2hUtaaEgfw/OTHFg/wQLi2s0mm3yuQxjowP84LV3rxqdZdM5LMvgC599gW//2RvU6y3QwHE8WlcqZGy9gAaEkQ+ArmtoGmhAu90mZcVEZG3jLI3mGoXCMPqVKEkQumL5HgUYmk4Q6qyvbZJNi9A1IrqDRylBggS3AwkRSZAgAfW6/M3l4sd0XWdqcmTbdj/xiad46cUnmF9Y4/QZKBUsnn76p/m1X/9Nvvq179JqddB1TVxSNQ1DD9B1nTBsk7KMK66pIZ7nYZo6mqahXQ3MRoShh+vWr6RrNFy3RcrKYRgpgijEcxzRqxAxs3+Kza3zLCyuXfM5EyRIcP8gISIJEiTAk0a5qrr2utB1nbHREVpNKBbBNOEX/upfYG19k9feeI8IsCwT3/fpOA3mlt4gnbY4MDnJRqVKtd4kk04TRQG6pmOZ5tX3jiKftrNFs1kBIiEl+WHSqTwdd5VWc418Lo1paWSyaRqVRKyaIMH9jkQjkiBBAqLoxkiIgtKU2FfkIQdnpvg7/5e/xk9/6ZMUchlcx0PXdQqFHAf2j/P5z7zAx55/DNMyOTwzzUsvPs5gOU8mk8K0r
CsfAjzPodXaZLM2h6ZBvbmC06nhhw6Os46dMQmCkHa7TrvVwU6lE7FqggT3OZKISIIEDzHCMGR+YY25BcjnUoRh6Ya62XY68jedjh87ODPF//P/8V/yk596ht/8rT+k3myxf3qM4eEynbbDqbOXsUyTvXtGuXh5mSDYYnxkkPHRQTY2LxMEIRCxVjkNgG2naHfW2Ni6SBS4mGYOQ9dxghAvaNFotjn6yL5ErJogwX2OhIgkSPCQ4uzsHK9861XOnV8jCkcwzCavv5XhSy+/+IFdbXsjIgq6rvPpTz7DxNgwr3zr/9/evYdFVecPHH+fGWAYrgGiArIo4iULAyHZ1N1M+WnWY6uWdjHNS7ilrdrF1DStViVD20xbs0yrtfWWq3kLI/PSmmmmkHgBUUEcEBEUkOswM78/WCZHUCCZhsHP63nmeZwz5/KZ78E5n/O9nO9+0s5kcjYjG42TI3d3CUajcSTrwiVKy8rxbeEFSikeHu608PEkP78QfaUBB7W6qmnHYKCiopjsi0k4OjhUzdRrMKDXV2IwlOLi4swfAoPQZeUS4O9brwRKCNH0SCIixG3oVFomn3y+mfzLhfi37oCTow8V+hKSj58mKzuXsSMfuWkyUp2IONzgF+RGo25iF3zOlq+/p2VLL6pbghRFoXVLH4qLy1CpqmblNRiM/3t0u4KrixYXF2fKyisoLStH4+SIu5uasvIKdu5OYtf3OwkJDqxXAiWEaHokERHiNmM0Gvk6Yb95hl1DpTcmkwo3NzUdQwJJTcsk/tv9tA8OuGEtQ3Xn1pupbdRN94gubE/YR1FRCQoqHByucLU4h5LScvz9fNE6azifdRFFUTCZjGicqh45Xz1I1+sOD1r5euPt7UGrFt5onFWU6yvqnUAJIZoeSUSEuM3osnJJO5NpnmEXqmfBBVDwb+1jnsOltmGxRqORS3kFlJRW4OpGg5pFutzZjs4d2pJ/uYDikjKKik/ioFbT0teLkOBAHB1U+Pn58MCfIjmboeNi7mWuFFxFpVIR0i4AfaWBCzl5dAwJpFKvQcGEu5tLvRMoIUTTI4mIELeZq8WllJdXmGfYVTtcwmRyMn+udXGm/GJ+rcNiq/uVZGQ4otdXsiX+XIOaRQL8fQm/pxNHj6fRuqUP+spKnBwd8XB3ASA1LZOw0I48NawfgEXTjtFoZOHif1smUErVE1sVpe4ESgjRNEkiIsRtxs1Vi0bjRElpOe5uLihKJYpSaf68alhszTlcLPqVtIxAo3GgvPJKg5pFVKqquWyysnPJuZiPf2ufqueBXC0h60Ie3t4ePBh9n7lG49qE4kRK+nUJVB6K8msb0c0SKCFE0yX1l0LcZgL8fS1m2L2WyWQi60IeHdoHWgyLvb5fibNWU/WckP81i+RfLiT+2/3/62B6c9fPZXMmPYv8K0WE3tWesSNunMxcm0ABqFRlKMqvc9jcKIESQjRtUiMixG3m2lqJ1LRMc61EaUlZrbUSUHu/EkX166R1DW0WqW1UTV19TaoTqOTjp+kYEmieowZ+TaBC72ovzxURws5IIiLEbai6VqL6WR/lF/PRODkSeld7Hoyu2d/j+n4lDo4XgV9rP35Ls0hto2rqWr+hCZQQoumTRESI21RDaiVq9iuxHL/7ezWLNDSBEkI0fZKICHEbq2+tRFNqFvktzTpCiKZLEhEhRJ2aWrNIQ5t1hBBNlyQiQoh6kWYRIYQ1SCIihKg3aRYRQjQ2SUSEEA0izSJCiMYktzFCCCGEsBmrJiLbtm0jKioKrVaLl5cXgwYNsubhhBBCCGFnrNY0s2HDBmJiYpg3bx59+vShsrKS5ORkax1OCCGEEHbIKolIZWUlkyZNIi4ujrFjx5qXd+nSxRqHE0IIIYSdskrTzOHDh9HpdKhUKsLDw/Hz82PAgAF11oiUl5dTWFho8RJCCCFE82WVROTMmTMAvPHGG8ycOZOtW7fi5eVF7969yc/Pv+F2sbGxeHp6ml+BgfJcAiGEE
KI5a1AiMm3aNBRFuenr5MmT5qnAZ8yYwaOPPkpERAQrV65EURTWr19/w/1Pnz6dgoIC8yszM/PWvp0QQgghmrQG9RF5+eWXGTVq1E3XCQ4OJjs7G7DsE6LRaAgODubcuXM33Faj0aDRaBoSkhBCCCHsWIMSEV9fX3x9657UKiIiAo1GQ0pKCr169QJAr9eTnp5OUFBQvY9nMpkApK+IEEIIYUeqr9vV1/GbscqoGQ8PD5577jlmz55NYGAgQUFBxMXFATB06NB676eoqAhA+ooIIYQQdqioqAhPT8+brmO154jExcXh4ODAiBEjKC0tJSoqiu+++w4vL69678Pf35/jx4/TpUsXMjMz8fDwsFa4zV5hYSGBgYFSjrdAyrBxSDk2DinHxiHleOtqK0OTyURRURH+/v51bq+Y6lNvYkOFhYV4enpSUFAgfyS3QMrx1kkZNg4px8Yh5dg4pBxv3a2Wocw1I4QQQgibkURECCGEEDbT5BMRjUbD7NmzZVjvLZJyvHVSho1DyrFxSDk2DinHW3erZdjk+4gIIYQQovlq8jUiQgghhGi+JBERQgghhM1IIiKEEEIIm5FERAghhBA2Y3eJyLZt24iKikKr1eLl5cWgQYNsHZLdKi8vJywsDEVRSExMtHU4diU9PZ2xY8fSrl07tFot7du3Z/bs2VRUVNg6tCbvgw8+oG3btjg7OxMVFcXBgwdtHZJdiY2N5d5778Xd3Z2WLVsyaNAgUlJSbB2WXXv77bdRFIXJkyfbOhS7o9PpePrpp/Hx8UGr1RIaGsqhQ4catA+7SkQ2bNjAiBEjGD16NElJSezbt4+nnnrK1mHZrVdffbVej98VNZ08eRKj0ciyZcs4duwY//jHP/jwww957bXXbB1ak7Z27VpeeuklZs+ezeHDh7nnnnvo378/Fy9etHVodmPPnj1MmDCBH3/8kYSEBPR6Pf369aO4uNjWodmln376iWXLltG1a1dbh2J3Ll++TM+ePXF0dOTrr7/m+PHjLFy4sEFTuQBgshN6vd4UEBBgWr58ua1DaRa2b99u6ty5s+nYsWMmwHTkyBFbh2T33nnnHVO7du1sHUaT1r17d9OECRPM7w0Gg8nf398UGxtrw6js28WLF02Aac+ePbYOxe4UFRWZOnToYEpISDDdf//9pkmTJtk6JLsydepUU69evW55P3ZTI3L48GF0Oh0qlYrw8HD8/PwYMGAAycnJtg7N7uTk5BATE8O//vUvXFxcbB1Os1FQUIC3t7etw2iyKioq+Pnnn4mOjjYvU6lUREdHs3//fhtGZt8KCgoA5G/vN5gwYQIPP/ywxd+kqL/NmzcTGRnJ0KFDadmyJeHh4Xz88ccN3o/dJCJnzpwB4I033mDmzJls3boVLy8vevfuTX5+vo2jsx8mk4lRo0bx3HPPERkZaetwmo20tDQWL17MX//6V1uH0mRdunQJg8FAq1atLJa3atWKCxcu2Cgq+2Y0Gpk8eTI9e/bk7rvvtnU4dmXNmjUcPnyY2NhYW4dit86cOcPSpUvp0KEDO3bs4Pnnn2fixIl89tlnDdqPzRORadOmoSjKTV/V7fEAM2bM4NFHHyUiIoKVK1eiKArr16+38bewvfqW4+LFiykqKmL69Om2DrlJqm85Xkun0/Hggw8ydOhQYmJibBS5uB1NmDCB5ORk1qxZY+tQ7EpmZiaTJk3iiy++wNnZ2dbh2C2j0Ui3bt2YN28e4eHhjBs3jpiYGD788MMG7cfBSvHV28svv8yoUaNuuk5wcDDZ2dkAdOnSxbxco9EQHBzMuXPnrBmiXahvOX733Xfs37+/xpwAkZGRDB8+vMGZbHNT33KslpWVxQMPPECPHj346KOPrBydfWvRogVqtZqcnByL5Tk5ObRu3dpGUdmvF154ga1bt7J3717atGlj63Dsys8//8zFixfp1q2beZnBYGDv3r0sWbKE8vJy1Gq1DSO0D35+fhbXZIA777yTDRs2NGg/Nk9Ef
H198fX1rXO9iIgINBoNKSkp9OrVCwC9Xk96ejpBQUHWDrPJq285vv/++8yZM8f8Pisri/79+7N27VqioqKsGaJdqG85QlVNyAMPPGCunVOpbF7B2KQ5OTkRERHBzp07zcPujUYjO3fu5IUXXrBtcHbEZDLxt7/9jY0bN7J7927atWtn65DsTt++fTl69KjFstGjR9O5c2emTp0qSUg99ezZs8bQ8dTU1AZfk22eiNSXh4cHzz33HLNnzyYwMJCgoCDi4uIAGDp0qI2jsx9/+MMfLN67ubkB0L59e7mragCdTkfv3r0JCgpiwYIF5Obmmj+Tu/sbe+mll3jmmWeIjIyke/fuvPfeexQXFzN69Ghbh2Y3JkyYwL///W+++uor3N3dzf1rPD090Wq1No7OPri7u9foU+Pq6oqPj4/0tWmAF198kR49ejBv3jyGDRvGwYMH+eijjxpcO2w3iQhAXFwcDg4OjBgxgtLSUqKiovjuu+8aPmZZiFuUkJBAWloaaWlpNRI4k0xofUOPP/44ubm5zJo1iwsXLhAWFkZ8fHyNDqzixpYuXQpA7969LZavXLmyzmZFIRrTvffey8aNG5k+fTpvvfUW7dq147333mP48OEN2o9ikl9NIYQQQtiINGoLIYQQwmYkERFCCCGEzUgiIoQQQgibsavOqqL+ioqKyM7ONj8ITgghRMOpVCr8/Pxwd3e3dSjNliQizYzRaCQ2NpaNGzfaOhQhhGg2Bg8ezPTp0+V5QVYgiUgzExsby6ZNm5g4cSLh4eE4OjraOiQhhLBber2eI0eOsHjxYqBqmhHRuGT4bjNSWFhInz59mDhxIiNHjrR1OEII0Wx8/vnnvP/+++zatUuaaRqZ1DE1I9VPWAwPD7dxJEII0bxU/65Wz3smGo8kIs1IdcdUaY4RQojGVf27KgMAGp8kIkIIIYSwGUlEhBBCCGEzkogIIYQQwmYkERFCiNtQWVlZna+G9IfIy8ujZcuWpKenWy9oK3jiiSdYuHChrcO4rUkiIoRoFL1792by5Mm2DsOsqcXT1HTr1g2tVnvDl4uLC+fOnav3/ubOnctf/vIX2rZta1724osvMmTIECtE/9tdH9PMmTOZO3cuBQUFNozq9iaJiKiVyWTiamkFV66WcbW0gub0uBmDwdDse74bjUYyz+dwIiWdzPM5dvN9KyoqbB2CzZiMRowXczFkZGK8mIvJyufs2WefxcPDg9TUVM6ePWvx6tu3L9HR0RZJxc2UlJTwySefMHbsWIvlBw8eJDIy8pZjraysvOV9VLs+prvvvpv27duzatWqRjuGaBhJREQNBcVlnMi4xNGzORxLv8jRszmcyLhEQXGZ1Y755ZdfEhoailarxcfHh+joaIqLizEajbz11lu0adMGjUZDWFgY8fHx5u12796NoihcuXLFvCwxMRFFUcxVxJ9++il33HEHmzdvpkuXLmg0Gs6dO0d5eTlTp04lMDAQjUZDSEgIn3zyiXk/ycnJDBgwADc3N1q1asWIESO4dOmS1cqgsZxKy2TJsi+JW7SK9z5YTdyiVSxZ9iWn0jKtdsxRo0axZ88eFi1ahKIoKIrC6dOnGTt2LO3atUOr1dKpUycWLVpUY7tBgwYxd+5c/P396dSpEwA//PADYWFhODs7ExkZyaZNm1AUhcTERPO2Nzs/tcXTlJsMDOd1VGzaQvnaLyn/ciPla7+kYtMWDOd1VjvmyJEjKSsr4+jRo7Rt29b8cnNzY+/evTz77LPmdQ8cOECvXr3QarWEhYWxd+9eFEUhOTkZgO3bt6PRaPjjH/8IVCWUjo6O/PDDD8yYMQNFUcyfzZ49m9DQUFxdXWnVqhXPP/88er3efKz09HQURWHdunX86U9/QqPRsHnz5nrFAXDu3DmeeuopvLy88Pb2Zvjw4Vy+fPmmMQ0cOJA1a9ZYrazFzUkiIiwUFJeRej6P/KJSNI5q3Fyc0DiqyS8qJfV8nlWSkezsbJ588knGjBnDi
RMn2L17N0OGDMFkMrFo0SIWLlzIggUL+OWXX+jfvz+PPPIIp06datAxSkpKmD9/PsuXL+fYsWO0bNmSkSNHsnr1at5//31OnDjBsmXLcHNzA+DKlSv06dOH8PBwDh06RHx8PDk5OQwbNqzRv39jOpWWySefbyb5+Gm8vTxo3y4Aby8Pko+f5pPPN1stGVm0aBH33XcfMTExZGdnk52dTZs2bWjTpg3r16/n+PHjzJo1i9dee41169ZZbLtz505SUlJISEhg69atFBYWMnDgQEJDQzl8+DB///vfmTp1qsU2dZ2f2uIJDAy0yne/VYbzOvTbd2A4m4Hi7o4qwA/F3R3D2Yyq5VZKRlq0aMGgQYNYsWKFxfJVq1bh6enJoEGDgKqEr2/fvvTu3ZsjR47w+uuvM3ToUDQaDZ07dwbg+++/JyIiwrwPBwcH9u3bB1TdGGRnZxMfH4/JZMJkMrFs2TKOHz/Op59+yoYNG1i+fLl526SkJADi4uKYNWsWx44do2/fvvWKIy0tjYiICEJCQvjxxx9JSEggLS2NKVOm3DAmgO7du3Pw4EHKy8utUNKiLjLXjDAzmUzocosorzDg4eqEoigAqBzUeLiqKCyuQJdbhIeLxvxZY8jOzqayspIhQ4YQFBQEQGhoKAALFixg6tSpPPHEEwDMnz+fXbt28d577/HBBx/U+xh6vZ5//vOf3HPPPQCkpqaybt06EhISiI6OBiA4ONi8/pIlSwgPD2fevHnmZStWrCAwMJDU1FQ6dux4a1/aCoxGI18n7Cf/ciEdQwLN58jdzYWOIYGkpmUS/+1+2gcHNPrEXZ6enjg5OeHi4kLr1q3Ny998803zv9u1a8f+/ftZt26dRULn6urK8uXLcXJyAuDDDz9EURQ+/vhjnJ2d6dKlCzqdjpiYGPM29Tk/tcXT1JiMRioPHsJYdBVVG/9f/1+5aFFp/TGez6Ly4M+o/P1QrDDZWkxMDA8++CBZWVn4+/sDsHLlSkaOHGk+HxMnTuSRRx5hzpw5AHTu3JnPPvuM8+fP4+BQdQnJyMgwbw9VM9ZmZWXh4+Nj/j9X7a233jL/OygoiOjoaFJSUszLEhMTcXV1Zf369RZNQ4MHD64zjvHjxzN+/HiLv7tXX32VKVOm3DQmf39/KioquHDhgvk3SPx+JBERZsVlegpKynBxdqiRaCiKgouzAwUlZRSX6XHTOjXace+55x769u1LaGgo/fv3p1+/fjz22GOo1WqysrLo2bOnxfo9e/Y03zXVl5OTE127djW/T0xMRK1Wc//999e6flJSErt27TLXkFzr9OnTTTIR0WXlknYmE3+/FrWeP//WPpw6nYkuK5fANq1+l5g++OADVqxYwblz5ygtLaWiooKwsDCLdUJDQ80XPYCUlBS6du2Ks7OzeVn37t0ttrHH81Mb06U8jLosVD7etZ4zlY8XRp0O06U8lJa+jX78vn37EhQUxGeffcb06dP5+eef+eWXX8zNFBkZGezatcui6QNAo9FYXMxLS0stzhfAkSNHalzwMzIyeOedd9izZw86nQ69Xk9ZWRlvv/22eZ2kpCQeeeQRiySkPnFkZGSQkJDAf//7X4tRMAaDwVwbVltMAFqtFqiqORW/P0lEhFmlwYjRaEKtrv3OS61WYSyvpNLQuJ3o1Go1CQkJ/PDDD3zzzTcsXryYGTNmkJCQUOe21Xf213amvba9uZpWq7X4oa/+4bmRq1evMnDgQObPn1/jMz8/vzrjsoWrxaWUl1fgotXU+rnWxZnyi/lcLS79XeJZs2YNr7zyCgsXLuS+++7D3d2duLg4Dhw4YLGeq6trg/dtj+enNqbSMkwVehTn2s8ZGg2m/MuYSq3TP0tRFMaMGcPKlSuZPn06K1asoEePHtx5551AVcLu5OTEXXfdZbHdiRMnLPqQtGjRgsuXL1usk5iYaHHRz83N5d5776VPnz68++67BAQEYDAYiIyMtFgvMTGRadOm1dhXXXEkJSXh7
e1d4+8Lfv3/fn1M1fLz8wHw9W38ZE/UTRIRYeagVqFSKRgMRlQO6hqfGwxGVCoFhxskKrdCURR69uxJz549mTVrFkFBQezcuRN/f3/27dtnUXOxb98+8x1y9Q9HdnY2Xl5eABYdGm8kNDQUo9HInj17zE0z1+rWrRsbNmygbdu25mrfps7NVYtG40RJaTnubi41Pi8tKUPj5Iib682TsN/KyckJg8Fgfr9v3z569OjB+PHjzctOnz5d5346derEqlWrKC8vR6OpukD/9NNPFuvU5/xcH09TpGidUZwcoawcXGo5L+XlKI6OKFrnmp81ktGjRzN79my+/fZbVq9ezbvvvmv+TK1WU1lZSVlZmbnGY+fOnRw7dszigh4eHl5j1MnRo0d59NFHze+3bNmCwWBg9erV5puCJUuWoNfrzbVkhYWFpKen15i4sz5xODo6UlRUhL+/Py4uNf/+a4upWnJyMm3atKFFixb1KjPRuKSzqjBzdXbE08WZkrLKGsN1TSYTJWWVeLo44+rcuJPqHThwgHnz5nHo0CHOnTvHf/7zH3Jzc7nzzjuZMmUK8+fPZ+3ataSkpDBt2jQSExOZNGkSACEhIQQGBvLGG29w6tQptm3bVq+HE7Vt25ZnnnmGMWPGsGnTJs6ePcvu3bvNHSknTJhAfn4+Tz75JD/99BOnT59mx44djB49usle3AL8fQkJDiQr+1Kt5y/rQh4d2gcS4G+du762bdty4MAB0tPTuXTpEh06dODQoUPs2LGD1NRUXn/99RoJRW2eeuopjEYj48aN48SJE+zYsYMFCxYAmC9g9Tk/18fTFIcwKy18UAX4Y8zLr/WcGfMuowoIQGnhY7UY/P39eeihhxgzZgwGg8Gi/05ERASOjo5MmTKFM2fOsGXLFsaNGwdgkYj079+fY8eOWdSKGI1GUlJSyMrKoqCgAB8fHwoLC9m8eTOnTp3i3Xff5c033yQgIMB8Q5GUlIRarTb3EWtIHFFRUXh4eDBy5EiSkpJIS0sjPj7e4lky18dU7fvvv6dfv36NVKKioSQREWaKohDg647GSU1hcQX6SgNGkwl9pYHC4go0TmoCfN0btaMqgIeHB3v37uWhhx6iY8eOzJw5k4ULFzJgwAAmTpzISy+9xMsvv0xoaCjx8fFs3ryZDh06AFV3QatXr+bkyZN07dqV+fPnmzuz1WXp0qU89thjjB8/ns6dOxMTE0NxcTGAuSbGYDDQr18/QkNDmTx5MnfccUejd/RsLCqVigH/dx/eXh6kpmVSVFRMpcFAUVExqWmZeHt78GD0fVaL/5VXXkGtVtOlSxd8fX3p378/Q4YM4fHHHycqKoq8vDyL2pEb8fDwYMuWLSQmJhIWFsaMGTOYNWsWgPluuD7n5/p4GvJwrt+LolLh0D0SlbsbxvNZmEpKMBkMmEpKMJ7PQuXuhkP3CKt0VL3WuHHjyMrKYvjw4Ra1CX5+fqxYsYKvvvqKrl27snLlSp555hlCQkLw9vY2rxcaGkq3bt0sRkTNmTOHTz/9lICAAObMmcPAgQMZO3YsI0aMoFevXuh0OoYNG2bRZygpKYlOnTrV6G9Snzi8vb3Zvn07eXl5/PnPf6Zbt27MmDHDohP69TFB1RNmN23aZNEZWvy+FFNzelLVbe7kyZM8/fTTrFq1yjyc7bcoKC5Dl1tEQUkZRqMJlUrB08WZAF93PF2tV0UsGseptEy+TthP2plMyiv0aJwc6dA+kAej76NDSNMcwlqXL774gtGjR1NQUFBn/x57ZDivqxo9o8vCpNejODqiCgjAoXsE6jYBtg7PzGg00rt3b3r16mUxYglg27ZtTJkyheTkZKsn6zeLo6GWLl3Kxo0b+eabb266XmP9voqa7KPxW/yuPF2d8XDRUFymp9JgxEGtwtXZsdFrQoR1dAgJpH1wALqsXK4Wl+LmqiXA37fJ1uTU5vPPPyc4OJiAgACSkpKYOnUqw4YNa5ZJCIC6TQAqfz9Ml/Iwl
ZZV9R1p4WP1mpC67N27l9zcXMLDw7l06RJxcXFkZGSwadOmGus+/PDDnDp1Cp1O1+jPbGlIHA3l6OjI4sWLbz1I8ZtJIiJqpShKow7RFb8vlUr1uw3RtYYLFy4wa9YsLly4gJ+fH0OHDmXu3Lm2DsuqFJXKKkN0b0VOTg7Tpk1Dp9PRqlUroqOjOXjwoEWzzLWsNbdPQ+NoiGtH/wjbkKaZZkSqDoUQwjrk99V67KeuVgghhBDNjiQiQgghhLAZSUSEEEIIYTOSiDQj1aMianvEuRBCiN+u+nfVnkaf2Qsp0WakepbRI0eO2DgSIYRoXqp/V+1pLiN7IcN3mxEPDw8GDx5sHhMfHh6Oo2PjPo5dCCFuJ3q9niNHjrB48WIGDx6Mu7u7rUNqdmT4bjNjNBqJjY1l48aNtg5FCCGajcGDBzN9+nRpmrECSUSaqaKiIrKzs5vkRF9CCGEvVCoVfn5+UhNiRZKICCGEEMJmpI5JCCGEEDYjiYgQQgghbEYSESGEEELYjCQiQgghhLAZSUSEEEIIYTOSiAghhBDCZiQREUIIIYTN/D9DH+fQlwn9igAAAABJRU5ErkJggg==", + "image/png": "", "text/plain": [ "
" ] @@ -414,7 +414,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.10.6" }, "vscode": { "interpreter": { From 982d20b6c36d81f4bb6a42fd81577034a49e6b41 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 2 Apr 2024 14:48:59 +0200 Subject: [PATCH 178/186] Fix links to neural in the docs --- src/ott/neural/datasets.py | 4 ++-- src/ott/neural/methods/flows/genot.py | 2 +- src/ott/neural/methods/flows/otfm.py | 2 +- src/ott/neural/methods/monge_gap.py | 12 ++++++------ src/ott/neural/methods/neuraldual.py | 9 +++++---- src/ott/neural/networks/potentials.py | 2 +- 6 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/ott/neural/datasets.py b/src/ott/neural/datasets.py index c5661aecc..89453b2ce 100644 --- a/src/ott/neural/datasets.py +++ b/src/ott/neural/datasets.py @@ -27,8 +27,8 @@ class OTData: """Distribution data for (conditional) optimal transport problems. Args: - lin: Linear (living in the shared space) part of the samples. - quad: Quadratic (living in the incomparable subspace) part of the samples. + lin: Linear term of the samples. + quad: Quadratic term of the samples. condition: Condition corresponding to the data distribution. """ lin: Optional[np.ndarray] = None diff --git a/src/ott/neural/methods/flows/genot.py b/src/ott/neural/methods/flows/genot.py index a3bad5902..ce200d376 100644 --- a/src/ott/neural/methods/flows/genot.py +++ b/src/ott/neural/methods/flows/genot.py @@ -66,7 +66,7 @@ class GENOT: n_samples_per_src: Number of samples drawn from the conditional distribution per one source sample. kwargs: Keyword arguments for - :meth:`ott.neural.flow_models.models.VelocityField.create_train_state`. + :meth:`~ott.neural.networks.velocity_field.VelocityField.create_train_state`. 
""" # noqa: E501 def __init__( diff --git a/src/ott/neural/methods/flows/otfm.py b/src/ott/neural/methods/flows/otfm.py index ebeb138b5..65d6a149d 100644 --- a/src/ott/neural/methods/flows/otfm.py +++ b/src/ott/neural/methods/flows/otfm.py @@ -41,7 +41,7 @@ class OTFlowMatching: distributions. It has a ``(src, tgt) -> matching`` signature. time_sampler: Time sampler with a ``(rng, n_samples) -> time`` signature. kwargs: Keyword arguments for - :meth:`~ott.neural.flow_models.models.VelocityField.create_train_state`. + :meth:`~ott.neural.networks.velocity_field.VelocityField.create_train_state`. """ def __init__( diff --git a/src/ott/neural/methods/monge_gap.py b/src/ott/neural/methods/monge_gap.py index c108a3509..140fad4a1 100644 --- a/src/ott/neural/methods/monge_gap.py +++ b/src/ott/neural/methods/monge_gap.py @@ -64,14 +64,14 @@ def monge_gap( W_{c, \varepsilon}(\hat{\rho}_n, T \sharp \hat{\rho}_n) See :cite:`uscidda:23` Eq. (8). This function is a thin wrapper that calls - :func:`~ott.neural.losses.monge_gap_from_samples`. + :func:`~ott.neural.methods.monge_gap.monge_gap_from_samples`. Args: map_fn: Callable corresponding to map :math:`T` in definition above. The - callable should be vectorized (e.g. using :func:`jax.vmap`), i.e, + callable should be vectorized (e.g. using :func:`~jax.vmap`), i.e, able to process a *batch* of vectors of size `d`, namely ``map_fn`` applied to an array returns an array of the same shape. - reference_points: Array of `[n,d]` points, :math:`\hat\rho_n` in paper + reference_points: Array of `[n,d]` points, :math:`\hat\rho_n`. cost_fn: An object of class :class:`~ott.geometry.costs.CostFn`. epsilon: Regularization parameter. 
See :class:`~ott.geometry.pointcloud.PointCloud` @@ -184,7 +184,7 @@ class MongeGapEstimator: For instance, :math:`\Delta` can be the :func:`~ott.tools.sinkhorn_divergence.sinkhorn_divergence` - and :math:`R` the :func:`~ott.neural.losses.monge_gap_from_samples` + and :math:`R` the :func:`~ott.neural.methods.monge_gap.monge_gap_from_samples` :cite:`uscidda:23` for a given cost function :math:`c`. In that case, it estimates a :math:`c`-OT map, i.e. a map :math:`T` optimal for the Monge problem induced by :math:`c`. @@ -259,11 +259,11 @@ def setup( def regularizer(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]: """Regularizer added to the fitting loss. - Can be, e.g. the :func:`~ott.neural.losses.monge_gap_from_samples`. + Can be, e.g. the :func:`~ott.neural.methods.monge_gap.monge_gap_from_samples`. If no regularizer is passed for solver instantiation, or regularization weight :attr:`regularizer_strength` is 0, return 0 by default along with an empty set of log values. - """ + """ # noqa: E501 if self._regularizer is not None: return self._regularizer return lambda *_, **__: (0.0, None) diff --git a/src/ott/neural/methods/neuraldual.py b/src/ott/neural/methods/neuraldual.py index 6845224f4..30fd08d4e 100644 --- a/src/ott/neural/methods/neuraldual.py +++ b/src/ott/neural/methods/neuraldual.py @@ -48,7 +48,8 @@ class W2NeuralDual: denoted source and target, respectively. This is achieved by parameterizing a Kantorovich potential :math:`f_\theta: \mathbb{R}^n\rightarrow\mathbb{R}` associated with the :math:`\alpha` measure with an - :class:`~ott.neural.models.ICNN` or :class:`~ott.neural.models.MLP`, where + :class:`~ott.neural.networks.icnn.ICNN` or a + :class:`~ott.neural.networks.potentials.PotentialMLP`, where :math:`\nabla f` transports source to target cells. This potential is learned by optimizing the dual form associated with the negative inner product cost @@ -64,10 +65,10 @@ class W2NeuralDual: transport map from :math:`\beta` to :math:`\alpha`. 
This solver estimates the conjugate :math:`f^\star` with a neural approximation :math:`g` that is fine-tuned - with :class:`~ott.neural.duality.conjugate.FenchelConjugateSolver`, + with :class:`~ott.neural.networks.layers.conjugate.FenchelConjugateSolver`, which is a combination further described in :cite:`amos:23`. - The :class:`~ott.neural.duality.neuraldual.BaseW2NeuralDual` potentials for + The :class:`~ott.neural.networks.potentials.BasePotential` potentials for ``neural_f`` and ``neural_g`` can 1. both provide the values of the potentials :math:`f` and :math:`g`, or @@ -76,7 +77,7 @@ class W2NeuralDual: via the Fenchel conjugate as discussed in :cite:`amos:23`. The potential's value or gradient mapping is specified via - :attr:`~ott.neural.duality.neuraldual.BaseW2NeuralDual.is_potential`. + :attr:`~ott.neural.networks.potentials.BasePotential.is_potential`. Args: dim_data: input dimensionality of data required for network init diff --git a/src/ott/neural/networks/potentials.py b/src/ott/neural/networks/potentials.py index 6a08e0048..563f4537c 100644 --- a/src/ott/neural/networks/potentials.py +++ b/src/ott/neural/networks/potentials.py @@ -34,7 +34,7 @@ class PotentialTrainState(train_state.TrainState): This extends :class:`~flax.training.train_state.TrainState` to include the potential methods from the - :class:`~ott.neural.duality.neuraldual.BaseW2NeuralDual` used during training. + :class:`~ott.neural.networks.potentials.BasePotential` used during training. 
Args: potential_value_fn: the potential's value function From 7b61e054ba8d0b242ac430ab24a17220c915f633 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 2 Apr 2024 17:00:16 +0200 Subject: [PATCH 179/186] Check for condition dim in VF --- src/ott/neural/networks/velocity_field.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/ott/neural/networks/velocity_field.py b/src/ott/neural/networks/velocity_field.py index 55bbfabfc..589f3d33d 100644 --- a/src/ott/neural/networks/velocity_field.py +++ b/src/ott/neural/networks/velocity_field.py @@ -111,7 +111,11 @@ def create_train_state( The training state. """ t, x = jnp.ones((1, 1)), jnp.ones((1, input_dim)) - cond = None if self.condition_dims is None else jnp.ones((1, condition_dim)) + if self.condition_dims is None: + cond = None + else: + assert condition_dim > 0, "Condition dimension must be positive." + cond = jnp.ones((1, condition_dim)) params = self.init(rng, t, x, cond)["params"] return train_state.TrainState.create( From 8819d5e87e5bb356dc6bd97a366cb50602de4d5a Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 2 Apr 2024 19:28:20 +0200 Subject: [PATCH 180/186] Don't use activation fn in the last layer of VF --- src/ott/neural/networks/velocity_field.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/ott/neural/networks/velocity_field.py b/src/ott/neural/networks/velocity_field.py index 589f3d33d..39c7d98da 100644 --- a/src/ott/neural/networks/velocity_field.py +++ b/src/ott/neural/networks/velocity_field.py @@ -87,10 +87,11 @@ def __call__( else: feats = jnp.concatenate([t, x], axis=-1) - for output_dim in self.output_dims: + for output_dim in self.output_dims[:-1]: feats = self.act_fn(nn.Dense(output_dim)(feats)) - return feats + # no activation function for the final layer + return nn.Dense(self.output_dims[-1])(feats) def create_train_state( self, From 
6f9cbcc848fc4a89f93d59f35ec2aec8403ba9dd Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 2 Apr 2024 19:35:00 +0200 Subject: [PATCH 181/186] Update assertions --- tests/neural/methods/genot_test.py | 8 +++++--- tests/neural/methods/otfm_test.py | 10 +++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/neural/methods/genot_test.py b/tests/neural/methods/genot_test.py index 086ea7a80..2c746596c 100644 --- a/tests/neural/methods/genot_test.py +++ b/tests/neural/methods/genot_test.py @@ -19,6 +19,7 @@ import jax import jax.numpy as jnp import jax.tree_util as jtu +import numpy as np import optax @@ -76,7 +77,7 @@ def test_genot(self, rng: jax.Array, dl: str, request): optimizer=optax.adam(learning_rate=1e-4), ) - _logs = model(dl.loader, n_iters=3, rng=rng_call) + _logs = model(dl.loader, n_iters=2, rng=rng_call) batch = next(iter(dl.loader)) batch = jtu.tree_map(jnp.asarray, batch) @@ -86,5 +87,6 @@ def test_genot(self, rng: jax.Array, dl: str, request): res = model.transport(src, condition=src_cond) - assert jnp.sum(jnp.isnan(res)) == 0 - assert res.shape[-1] == tgt_dim + assert len(_logs["loss"]) == 2 + np.testing.assert_array_equal(jnp.isfinite(res), True) + assert res.shape == (batch_size, tgt_dim) diff --git a/tests/neural/methods/otfm_test.py b/tests/neural/methods/otfm_test.py index 0eb311fa6..a7c14758c 100644 --- a/tests/neural/methods/otfm_test.py +++ b/tests/neural/methods/otfm_test.py @@ -16,6 +16,7 @@ import jax import jax.numpy as jnp import jax.tree_util as jtu +import numpy as np import optax @@ -54,6 +55,9 @@ def test_otfm(self, rng: jax.Array, dl: str, request): res_fwd = fm.transport(batch["src_lin"], condition=src_cond) res_bwd = fm.transport(batch["tgt_lin"], t0=1.0, t1=0.0, condition=src_cond) - # TODO(michalk8): better assertions - assert jnp.sum(jnp.isnan(res_fwd)) == 0 - assert jnp.sum(jnp.isnan(res_bwd)) == 0 + assert len(_logs["loss"]) == 3 + + assert res_fwd.shape 
== batch["src_lin"].shape + assert res_bwd.shape == batch["tgt_lin"].shape + np.testing.assert_array_equal(jnp.isfinite(res_fwd), True) + np.testing.assert_array_equal(jnp.isfinite(res_bwd), True) From 9e1499b93e7980f2e00cf55fae1d401522619217 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 2 Apr 2024 21:06:31 +0200 Subject: [PATCH 182/186] Try skipping OTFM/GENOT tests temporarily --- tests/neural/methods/genot_test.py | 2 +- tests/neural/methods/otfm_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/neural/methods/genot_test.py b/tests/neural/methods/genot_test.py index 2c746596c..2bc4ba9fb 100644 --- a/tests/neural/methods/genot_test.py +++ b/tests/neural/methods/genot_test.py @@ -52,7 +52,7 @@ class TestGENOT: "fused_cond_dl" ] ) - def test_genot(self, rng: jax.Array, dl: str, request): + def skip_test_genot(self, rng: jax.Array, dl: str, request): rng_init, rng_call, rng_data = jax.random.split(rng, 3) problem_type = dl.split("_")[0] dl = request.getfixturevalue(dl) diff --git a/tests/neural/methods/otfm_test.py b/tests/neural/methods/otfm_test.py index a7c14758c..76759d4e8 100644 --- a/tests/neural/methods/otfm_test.py +++ b/tests/neural/methods/otfm_test.py @@ -28,7 +28,7 @@ class TestOTFlowMatching: @pytest.mark.parametrize("dl", ["lin_dl", "lin_cond_dl"]) - def test_otfm(self, rng: jax.Array, dl: str, request): + def skip_test_otfm(self, rng: jax.Array, dl: str, request): dl = request.getfixturevalue(dl) dim, cond_dim = dl.lin_dim, dl.cond_dim From b37da2a0c627b8a251889989f372d88bb4358718 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 2 Apr 2024 22:26:43 +0200 Subject: [PATCH 183/186] Be extra verbose when intalling packages --- .github/workflows/tests.yml | 2 +- tests/neural/methods/genot_test.py | 2 +- tests/neural/methods/otfm_test.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git 
a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a06447d9f..d8ead3517 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -97,7 +97,7 @@ jobs: - name: Setup environment run: | - tox -e py${{ matrix.python-version }} --notest -v + tox -e py${{ matrix.python-version }} --notest -vv - name: Run tests run: | diff --git a/tests/neural/methods/genot_test.py b/tests/neural/methods/genot_test.py index 2bc4ba9fb..2c746596c 100644 --- a/tests/neural/methods/genot_test.py +++ b/tests/neural/methods/genot_test.py @@ -52,7 +52,7 @@ class TestGENOT: "fused_cond_dl" ] ) - def skip_test_genot(self, rng: jax.Array, dl: str, request): + def test_genot(self, rng: jax.Array, dl: str, request): rng_init, rng_call, rng_data = jax.random.split(rng, 3) problem_type = dl.split("_")[0] dl = request.getfixturevalue(dl) diff --git a/tests/neural/methods/otfm_test.py b/tests/neural/methods/otfm_test.py index 76759d4e8..a7c14758c 100644 --- a/tests/neural/methods/otfm_test.py +++ b/tests/neural/methods/otfm_test.py @@ -28,7 +28,7 @@ class TestOTFlowMatching: @pytest.mark.parametrize("dl", ["lin_dl", "lin_cond_dl"]) - def skip_test_otfm(self, rng: jax.Array, dl: str, request): + def test_otfm(self, rng: jax.Array, dl: str, request): dl = request.getfixturevalue(dl) dim, cond_dim = dl.lin_dim, dl.cond_dim From 9c561a5e01d513f0f5e2c46732a775c6c8aa1a57 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 2 Apr 2024 22:54:23 +0200 Subject: [PATCH 184/186] Remove `torch` dependency --- .github/workflows/tests.yml | 2 +- tests/neural/conftest.py | 57 +++++++++++++++++++++++-------- tests/neural/methods/otfm_test.py | 2 +- 3 files changed, 44 insertions(+), 17 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d8ead3517..a06447d9f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -97,7 +97,7 @@ jobs: - name: Setup environment run: | - tox -e py${{ 
matrix.python-version }} --notest -vv + tox -e py${{ matrix.python-version }} --notest -v - name: Run tests run: | diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index f4c25c514..41b5ea71a 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -11,18 +11,45 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import NamedTuple, Optional, Union +from collections import defaultdict +from typing import Dict, NamedTuple, Optional, Union import pytest +import jax.numpy as jnp import numpy as np -from torch.utils.data import DataLoader from ott.neural import datasets +class SimpleDataLoader: + + def __init__( + self, + dataset: datasets.OTDataset, + batch_size: int, + seed: Optional[int] = None + ): + self.dataset = dataset + self.batch_size = batch_size + self.seed = seed + + def __iter__(self): + self._rng = np.random.default_rng(self.seed) + return self + + def __next__(self) -> Dict[str, jnp.ndarray]: + data = defaultdict(list) + for _ in range(self.batch_size): + ix = self._rng.integers(0, len(self.dataset)) + for k, v in self.dataset[ix].items(): + data[k].append(v) + + return {k: jnp.vstack(v) for k, v in data.items()} + + class OTLoader(NamedTuple): - loader: DataLoader + loader: SimpleDataLoader lin_dim: int = 0 quad_src_dim: int = 0 quad_tgt_dim: int = 0 @@ -58,7 +85,7 @@ def _ot_data( @pytest.fixture() -def lin_dl() -> DataLoader: +def lin_dl() -> OTLoader: n, d = 128, 2 rng = np.random.default_rng(0) @@ -67,13 +94,13 @@ def lin_dl() -> DataLoader: ds = datasets.OTDataset(src, tgt) return OTLoader( - DataLoader(ds, batch_size=16, shuffle=True), + SimpleDataLoader(ds, batch_size=13), lin_dim=d, ) @pytest.fixture() -def lin_cond_dl() -> DataLoader: +def lin_cond_dl() -> OTLoader: n, d, cond_dim = 128, 2, 3 rng = np.random.default_rng(13) @@ -84,14 +111,14 @@ def lin_cond_dl() -> 
DataLoader: ds = datasets.OTDataset(src, tgt) return OTLoader( - DataLoader(ds, batch_size=16, shuffle=True), + SimpleDataLoader(ds, batch_size=14), lin_dim=d, cond_dim=cond_dim, ) @pytest.fixture() -def quad_dl(): +def quad_dl() -> OTLoader: n, quad_src_dim, quad_tgt_dim = 128, 2, 4 rng = np.random.default_rng(11) @@ -100,14 +127,14 @@ def quad_dl(): ds = datasets.OTDataset(src, tgt) return OTLoader( - DataLoader(ds, batch_size=16, shuffle=True), + SimpleDataLoader(ds, batch_size=15), quad_src_dim=quad_src_dim, quad_tgt_dim=quad_tgt_dim, ) @pytest.fixture() -def quad_cond_dl(): +def quad_cond_dl() -> OTLoader: n, quad_src_dim, quad_tgt_dim, cond_dim = 128, 2, 4, 5 rng = np.random.default_rng(414) @@ -118,7 +145,7 @@ def quad_cond_dl(): ds = datasets.OTDataset(src, tgt) return OTLoader( - DataLoader(ds, batch_size=16, shuffle=True), + SimpleDataLoader(ds, batch_size=16), quad_src_dim=quad_src_dim, quad_tgt_dim=quad_tgt_dim, cond_dim=cond_dim, @@ -126,7 +153,7 @@ def quad_cond_dl(): @pytest.fixture() -def fused_dl(): +def fused_dl() -> OTLoader: n, lin_dim, quad_src_dim, quad_tgt_dim = 128, 6, 2, 4 rng = np.random.default_rng(11) @@ -135,7 +162,7 @@ def fused_dl(): ds = datasets.OTDataset(src, tgt) return OTLoader( - DataLoader(ds, batch_size=16, shuffle=True), + SimpleDataLoader(ds, batch_size=17), lin_dim=lin_dim, quad_src_dim=quad_src_dim, quad_tgt_dim=quad_tgt_dim, @@ -143,7 +170,7 @@ def fused_dl(): @pytest.fixture() -def fused_cond_dl(): +def fused_cond_dl() -> OTLoader: n, lin_dim, quad_src_dim, quad_tgt_dim, cond_dim = 128, 6, 2, 4, 7 rng = np.random.default_rng(11) @@ -163,7 +190,7 @@ def fused_cond_dl(): ds = datasets.OTDataset(src, tgt) return OTLoader( - DataLoader(ds, batch_size=16, shuffle=True), + SimpleDataLoader(ds, batch_size=18), lin_dim=lin_dim, quad_src_dim=quad_src_dim, quad_tgt_dim=quad_tgt_dim, diff --git a/tests/neural/methods/otfm_test.py b/tests/neural/methods/otfm_test.py index a7c14758c..f1ccae767 100644 --- 
a/tests/neural/methods/otfm_test.py +++ b/tests/neural/methods/otfm_test.py @@ -25,7 +25,7 @@ from ott.solvers import utils as solver_utils -class TestOTFlowMatching: +class TestOTFM: @pytest.mark.parametrize("dl", ["lin_dl", "lin_cond_dl"]) def test_otfm(self, rng: jax.Array, dl: str, request): From f227d54ba30bbc920fbdb095a0a19c36edd21fc5 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 2 Apr 2024 22:54:40 +0200 Subject: [PATCH 185/186] Remove `torch` from tests in `pyproject.toml` --- pyproject.toml | 147 ++++++++++++++++++++++++------------------------- 1 file changed, 73 insertions(+), 74 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1c71241f1..3bb8351be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,6 @@ test = [ "tslearn>=0.5; python_version < '3.12'", "lineax; python_version >= '3.9'", "matplotlib", - "torch" ] docs = [ "sphinx>=4.0", @@ -110,7 +109,7 @@ multi_line_output = 3 sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TEST", "NUMERIC", "NEURAL", "PLOTTING", "FIRSTPARTY", "LOCALFOLDER"] # also contains what we import in notebooks/tests known_neural = ["flax", "optax", "diffrax", "orbax"] -known_numeric = ["numpy", "scipy", "jax", "flax", "optax", "jaxopt", "torch", "ot", "torchvision", "pandas", "sklearn", "tslearn"] +known_numeric = ["numpy", "scipy", "jax", "flax", "optax", "jaxopt", "ot", "torch", "torchvision", "pandas", "sklearn", "tslearn"] known_test = ["_pytest", "pytest"] known_plotting = ["IPython", "matplotlib", "mpl_toolkits", "seaborn"] @@ -187,85 +186,85 @@ ignore_path = ["docs/**/_autosummary", "docs/contributing.rst"] [tool.tox] legacy_tox_ini = """ - [tox] - min_version = 4.0 - env_list = lint-code,py{3.8,3.9,3.10,3.11,3.12},py3.9-jax-default - skip_missing_interpreters = true +[tox] +min_version = 4.0 +env_list = lint-code,py{3.8,3.9,3.10,3.11,3.12},py3.9-jax-default +skip_missing_interpreters = true - [testenv] - extras = - test - # 
https://github.com/google/flax/issues/3329 - py{3.9,3.10,3.11,3.12},py3.9-jax-default: neural - pass_env = CUDA_*,PYTEST_*,CI - commands_pre = - gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - jax-latest: python -I -m pip install 'git+https://github.com/google/jax@main' - commands = - python -m pytest {tty:--color=yes} {posargs: \ - --cov={env_site_packages_dir}{/}ott --cov-config={tox_root}{/}pyproject.toml \ - --no-cov-on-fail --cov-report=xml --cov-report=term-missing:skip-covered} +[testenv] +extras = + test + # https://github.com/google/flax/issues/3329 + py{3.9,3.10,3.11,3.12},py3.9-jax-default: neural +pass_env = CUDA_*,PYTEST_*,CI +commands_pre = + gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + jax-latest: python -I -m pip install 'git+https://github.com/google/jax@main' +commands = + python -m pytest {tty:--color=yes} {posargs: \ + --cov={env_site_packages_dir}{/}ott --cov-config={tox_root}{/}pyproject.toml \ + --no-cov-on-fail --cov-report=xml --cov-report=term-missing:skip-covered} - [testenv:lint-code] - description = Lint the code. - deps = pre-commit>=2.16.0 - skip_install = true - commands = - pre-commit run --all-files --show-diff-on-failure +[testenv:lint-code] +description = Lint the code. +deps = pre-commit>=2.16.0 +skip_install = true +commands = + pre-commit run --all-files --show-diff-on-failure - [testenv:lint-docs] - description = Lint the documentation. - deps = - extras = docs,neural - ignore_errors = true - allowlist_externals = make - pass_env = PYENCHANT_LIBRARY_PATH - set_env = SPHINXOPTS = -W -q --keep-going - changedir = {tox_root}{/}docs - commands = - make linkcheck {posargs} - make spelling {posargs} +[testenv:lint-docs] +description = Lint the documentation. 
+deps = +extras = docs,neural +ignore_errors = true +allowlist_externals = make +pass_env = PYENCHANT_LIBRARY_PATH +set_env = SPHINXOPTS = -W -q --keep-going +changedir = {tox_root}{/}docs +commands = + make linkcheck {posargs} + make spelling {posargs} - [testenv:build-docs] - description = Build the documentation. - use_develop = true - deps = - extras = docs,neural - allowlist_externals = make - changedir = {tox_root}{/}docs - commands = - make html {posargs} - commands_post = - python -c 'import pathlib; print("Documentation is under:", pathlib.Path("{tox_root}") / "docs" / "_build" / "html" / "index.html")' +[testenv:build-docs] +description = Build the documentation. +use_develop = true +deps = +extras = docs,neural +allowlist_externals = make +changedir = {tox_root}{/}docs +commands = + make html {posargs} +commands_post = + python -c 'import pathlib; print("Documentation is under:", pathlib.Path("{tox_root}") / "docs" / "_build" / "html" / "index.html")' - [testenv:clean-docs] - description = Remove the documentation. - deps = - skip_install = true - changedir = {tox_root}{/}docs - allowlist_externals = make - commands = - make clean +[testenv:clean-docs] +description = Remove the documentation. +deps = +skip_install = true +changedir = {tox_root}{/}docs +allowlist_externals = make +commands = + make clean - [testenv:build-package] - description = Build the package. - deps = - build - twine - commands = - python -m build --sdist --wheel --outdir {tox_root}{/}dist{/} {posargs:} - twine check {tox_root}{/}dist{/}* - commands_post = - python -c 'import pathlib; print(f"Package is under:", pathlib.Path("{tox_root}") / "dist")' +[testenv:build-package] +description = Build the package. 
+deps = + build + twine +commands = + python -m build --sdist --wheel --outdir {tox_root}{/}dist{/} {posargs:} + twine check {tox_root}{/}dist{/}* +commands_post = + python -c 'import pathlib; print(f"Package is under:", pathlib.Path("{tox_root}") / "dist")' - [testenv:format-references] - description = Format references.bib. - skip_install = true - allowlist_externals = biber - commands = biber --tool --output_file={tox_root}{/}docs{/}references.bib --nolog \ - --output_align --output_indent=2 --output_fieldcase=lower \ - --output_legacy_dates --output-field-replace=journaltitle:journal,thesis:phdthesis,institution:school \ - {tox_root}{/}docs{/}references.bib +[testenv:format-references] +description = Format references.bib. +skip_install = true +allowlist_externals = biber +commands = biber --tool --output_file={tox_root}{/}docs{/}references.bib --nolog \ + --output_align --output_indent=2 --output_fieldcase=lower \ + --output_legacy_dates --output-field-replace=journaltitle:journal,thesis:phdthesis,institution:school \ + {tox_root}{/}docs{/}references.bib """ [tool.ruff] From 6f9a77c52eb7bd78704b7b84ae75b32e338fc774 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Wed, 3 Apr 2024 14:19:14 +0200 Subject: [PATCH 186/186] [ci skip] Update docstrings --- src/ott/solvers/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ott/solvers/utils.py b/src/ott/solvers/utils.py index 6c48a2577..f7bdae63a 100644 --- a/src/ott/solvers/utils.py +++ b/src/ott/solvers/utils.py @@ -32,7 +32,7 @@ def match_linear( x: jnp.ndarray, - y: jnp.ndarray, + y: Optional[jnp.ndarray], cost_fn: Optional[costs.CostFn] = None, epsilon: Optional[float] = None, scale_cost: ScaleCost_t = 1.0, @@ -41,8 +41,8 @@ def match_linear( """Compute solution to a linear OT problem. Args: - x: Linear term of the source point cloud. - y: Linear term of the target point cloud. + x: Source point cloud of shape ``[n, d]``. 
+ y: Target point cloud of shape ``[m, d]``. cost_fn: Cost function. epsilon: Regularization parameter. scale_cost: Scaling of the cost matrix. @@ -70,8 +70,8 @@ def match_quadratic( """Compute solution to a quadratic OT problem. Args: - xx: Quadratic (incomparable) term of the source point cloud. - yy: Quadratic (incomparable) term of the target point cloud. + xx: Source point cloud of shape ``[n, d1]``. + yy: Target point cloud of shape ``[m, d2]``. x: Linear (fused) term of the source point cloud. y: Linear (fused) term of the target point cloud. scale_cost: Scaling of the cost matrix.