diff --git a/ott/core/sinkhorn.py b/ott/core/sinkhorn.py
index 9c0c38144..7cbc8172a 100644
--- a/ott/core/sinkhorn.py
+++ b/ott/core/sinkhorn.py
@@ -14,11 +14,12 @@
 # Lint as: python3
 """A Jax implementation of the Sinkhorn algorithm."""
-from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple
+from typing import Any, Callable, Mapping, NamedTuple, Optional, Sequence, Tuple, Union
 
 import jax
 import jax.numpy as jnp
 import numpy as np
+from typing_extensions import Literal
 
 from ott.core import anderson as anderson_lib
 from ott.core import fixed_point_loop
@@ -335,11 +336,12 @@ class Sinkhorn:
       gradients have been stopped. This is useful when carrying out first order
       differentiation, and is only valid (as with ``implicit_differentiation``)
       when the algorithm has converged with a low tolerance.
+    initializer: how to compute the initial potentials/scalings.
+    kwargs_init: keyword arguments when creating the initializer.
     jit: if True, automatically jits the function upon first call.
       Should be set to False when used in a function that is jitted by the user,
       or when computing gradients (in which case the gradient function
       should be jitted by the user)
-    initializer: how to compute the initial potentials/scalings.
   """
 
   def __init__(
@@ -356,7 +358,9 @@ def __init__(
       use_danskin: Optional[bool] = None,
       implicit_diff: Optional[implicit_lib.ImplicitDiff
                              ] = implicit_lib.ImplicitDiff(),  # noqa: E124
-      initializer: init_lib.SinkhornInitializer = init_lib.DefaultInitializer(),
+      initializer: Union[Literal["default", "gaussian", "sorting"],
+                         init_lib.SinkhornInitializer] = "default",
+      kwargs_init: Optional[Mapping[str, Any]] = None,
       jit: bool = True
   ):
     self.lse_mode = lse_mode
@@ -377,6 +381,7 @@ def __init__(
     self.implicit_diff = implicit_diff
     self.parallel_dual_updates = parallel_dual_updates
     self.initializer = initializer
+    self.kwargs_init = {} if kwargs_init is None else kwargs_init
     self.jit = jit
 
     # Force implicit_differentiation to True when using Anderson acceleration,
@@ -408,7 +413,8 @@ def __call__(
     Returns:
       The Sinkhorn output.
     """
-    init_dual_a, init_dual_b = self.initializer(
+    initializer = self.create_initializer()
+    init_dual_a, init_dual_b = initializer(
         ot_prob, *init, lse_mode=self.lse_mode
     )
     run_fn = jax.jit(run) if self.jit else run
@@ -578,6 +584,20 @@ def norm_error(self) -> Tuple[int, ...]:
       return self._norm_error, 1
     return self._norm_error,
 
+  # TODO(michalk8): in the future, enforce this (+ in GW) via abstract method
+  def create_initializer(self) -> init_lib.SinkhornInitializer:
+    if isinstance(self.initializer, init_lib.SinkhornInitializer):
+      return self.initializer
+    if self.initializer == "default":
+      return init_lib.DefaultInitializer()
+    if self.initializer == "gaussian":
+      return init_lib.GaussianInitializer()
+    if self.initializer == "sorting":
+      return init_lib.SortingInitializer(**self.kwargs_init)
+    raise NotImplementedError(
+        f"Initializer `{self.initializer}` is not yet implemented."
+    )
+
   def tree_flatten(self):
     aux = vars(self).copy()
     aux['norm_error'] = aux.pop('_norm_error')
diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py
index 6c8d9336a..15d8c4316 100644
--- a/tests/core/initializers_test.py
+++ b/tests/core/initializers_test.py
@@ -133,6 +133,30 @@ def init_gaus():
     _ = init_gaus()
     _ = init_sort()
 
+  @pytest.mark.parametrize(
+      "init", [
+          "default", "gaussian", "sorting",
+          init_lib.DefaultInitializer(), "non-existent"
+      ]
+  )
+  def test_create_initializer(self, init: str):
+    solver = sinkhorn.Sinkhorn(initializer=init)
+    expected_types = {
+        "default": init_lib.DefaultInitializer,
+        "gaussian": init_lib.GaussianInitializer,
+        "sorting": init_lib.SortingInitializer,
+    }
+
+    if isinstance(init, init_lib.SinkhornInitializer):
+      assert solver.create_initializer() is init
+    elif init == "non-existent":
+      with pytest.raises(NotImplementedError, match=r""):
+        _ = solver.create_initializer()
+    else:
+      actual = solver.create_initializer()
+      expected_type = expected_types[init]
+      assert isinstance(actual, expected_type)
+
   @pytest.mark.parametrize(
       "vector_min, lse_mode", [(True, True), (True, False), (False, True)]
   )
@@ -267,10 +291,9 @@ def test_gauss_initializer(self, lse_mode, rng: jnp.ndarray):
     assert base_num_iter >= gaus_num_iter
 
 
-# TODO(michalk8): mark tests as fast
 class TestLRInitializers:
 
-  @pytest.mark.parametrize("kind", ["pc", "lrc", "geom"])
+  @pytest.mark.fast.with_args("kind", ["pc", "lrc", "geom"], only_fast=0)
  def test_create_default_initializer(self, rng: jnp.ndarray, kind: str):
     n, d, rank = 110, 2, 3
     x = jax.random.normal(rng, (n, d))
@@ -340,7 +363,7 @@ def test_partial_initialization(
     else:
       raise NotImplementedError(partial_init)
 
-  @pytest.mark.parametrize("rank", [2, 4, 10, 13])
+  @pytest.mark.fast.with_args("rank", [2, 4, 10, 13], only_fast=True)
   def test_generalized_k_means_has_correct_rank(
       self, rng: jnp.ndarray, rank: int
   ):
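
For reference, a minimal usage sketch of the string-based initializer API added above. The import path `ott.core.initializers` is an assumption here (both patched files refer to the module as `init_lib`, but its import statement lies outside this diff); the calls themselves mirror `test_create_initializer`.

    from ott.core import initializers as init_lib
    from ott.core import sinkhorn

    # A string shorthand is resolved lazily by `create_initializer`;
    # constructor arguments for the chosen initializer go through `kwargs_init`.
    solver = sinkhorn.Sinkhorn(initializer="gaussian")
    assert isinstance(solver.create_initializer(), init_lib.GaussianInitializer)

    # An initializer instance is passed through unchanged.
    custom = init_lib.DefaultInitializer()
    assert sinkhorn.Sinkhorn(initializer=custom).create_initializer() is custom

    # Unknown names raise, as exercised by the "non-existent" test case above.
    sinkhorn.Sinkhorn(initializer="unknown").create_initializer()  # NotImplementedError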