
Feature/initializers as literals (#148)
* Allow passing `initializer` as a `literal` in `Sinkhorn`

* Add test
michalk8 authored Oct 10, 2022
1 parent 468c2d5 commit 6c79d96
Showing 2 changed files with 50 additions and 7 deletions.
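
For context, a minimal sketch of the dispatch behavior this commit adds. The module path `ott.core.initializers` for the `init_lib` alias is an assumption inferred from the imports visible in the diff; the class and method names all appear below:

    # Sketch only: `ott.core.initializers` as the home of `init_lib` is assumed.
    from ott.core import initializers as init_lib
    from ott.core import sinkhorn

    # A string literal now selects a built-in initializer ...
    solver = sinkhorn.Sinkhorn(initializer="gaussian")
    assert isinstance(solver.create_initializer(), init_lib.GaussianInitializer)

    # ... while passing an instance keeps working and is returned unchanged.
    custom = init_lib.DefaultInitializer()
    assert sinkhorn.Sinkhorn(initializer=custom).create_initializer() is custom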
28 changes: 24 additions & 4 deletions ott/core/sinkhorn.py
@@ -14,11 +14,12 @@
 
 # Lint as: python3
 """A Jax implementation of the Sinkhorn algorithm."""
-from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple
+from typing import Any, Callable, Mapping, NamedTuple, Optional, Sequence, Tuple, Union
 
 import jax
 import jax.numpy as jnp
 import numpy as np
+from typing_extensions import Literal
 
 from ott.core import anderson as anderson_lib
 from ott.core import fixed_point_loop
@@ -335,11 +336,12 @@ class Sinkhorn:
       gradients have been stopped. This is useful when carrying out first order
       differentiation, and is only valid (as with ``implicit_differentiation``)
       when the algorithm has converged with a low tolerance.
+    initializer: how to compute the initial potentials/scalings.
+    kwargs_init: keyword arguments when creating the initializer.
     jit: if True, automatically jits the function upon first call.
       Should be set to False when used in a function that is jitted by the user,
       or when computing gradients (in which case the gradient function
       should be jitted by the user)
-    initializer: how to compute the initial potentials/scalings.
   """
 
   def __init__(
@@ -356,7 +358,9 @@ def __init__(
       use_danskin: Optional[bool] = None,
       implicit_diff: Optional[implicit_lib.ImplicitDiff
                              ] = implicit_lib.ImplicitDiff(),  # noqa: E124
-      initializer: init_lib.SinkhornInitializer = init_lib.DefaultInitializer(),
+      initializer: Union[Literal["default", "gaussian", "sorting"],
+                         init_lib.SinkhornInitializer] = "default",
+      kwargs_init: Optional[Mapping[str, Any]] = None,
       jit: bool = True
   ):
     self.lse_mode = lse_mode
@@ -377,6 +381,7 @@ def __init__(
     self.implicit_diff = implicit_diff
     self.parallel_dual_updates = parallel_dual_updates
     self.initializer = initializer
+    self.kwargs_init = {} if kwargs_init is None else kwargs_init
     self.jit = jit
 
     # Force implicit_differentiation to True when using Anderson acceleration,
@@ -408,7 +413,8 @@ def __call__(
     Returns:
       The Sinkhorn output.
     """
-    init_dual_a, init_dual_b = self.initializer(
+    initializer = self.create_initializer()
+    init_dual_a, init_dual_b = initializer(
         ot_prob, *init, lse_mode=self.lse_mode
     )
     run_fn = jax.jit(run) if self.jit else run
@@ -578,6 +584,20 @@ def norm_error(self) -> Tuple[int, ...]:
       return self._norm_error, 1
     return self._norm_error,
 
+  # TODO(michalk8): in the future, enforce this (+ in GW) via abstract method
+  def create_initializer(self) -> init_lib.SinkhornInitializer:
+    if isinstance(self.initializer, init_lib.SinkhornInitializer):
+      return self.initializer
+    if self.initializer == "default":
+      return init_lib.DefaultInitializer()
+    if self.initializer == "gaussian":
+      return init_lib.GaussianInitializer()
+    if self.initializer == "sorting":
+      return init_lib.SortingInitializer(**self.kwargs_init)
+    raise NotImplementedError(
+        f"Initializer `{self.initializer}` is not yet implemented."
+    )
+
   def tree_flatten(self):
     aux = vars(self).copy()
     aux['norm_error'] = aux.pop('_norm_error')
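Of the three literals, only "sorting" consumes `kwargs_init` in `create_initializer`; "default" and "gaussian" are constructed with defaults, and any other string raises `NotImplementedError` at dispatch time rather than mid-solve. A hedged sketch of both behaviors (`vectorized_update` is an assumed keyword of `SortingInitializer`, not confirmed by this diff):

    import pytest

    from ott.core import sinkhorn

    # kwargs_init is forwarded as SortingInitializer(**kwargs_init);
    # `vectorized_update` is an assumed constructor keyword, for illustration.
    solver = sinkhorn.Sinkhorn(
        initializer="sorting", kwargs_init={"vectorized_update": False}
    )
    _ = solver.create_initializer()

    # Unknown literals fail loudly when the initializer is created.
    with pytest.raises(NotImplementedError):
        sinkhorn.Sinkhorn(initializer="non-existent").create_initializer()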
29 changes: 26 additions & 3 deletions tests/core/initializers_test.py
@@ -133,6 +133,30 @@ def init_gaus():
     _ = init_gaus()
     _ = init_sort()
 
+  @pytest.mark.parametrize(
+      "init", [
+          "default", "gaussian", "sorting",
+          init_lib.DefaultInitializer(), "non-existent"
+      ]
+  )
+  def test_create_initializer(self, init: str):
+    solver = sinkhorn.Sinkhorn(initializer=init)
+    expected_types = {
+        "default": init_lib.DefaultInitializer,
+        "gaussian": init_lib.GaussianInitializer,
+        "sorting": init_lib.SortingInitializer,
+    }
+
+    if isinstance(init, init_lib.SinkhornInitializer):
+      assert solver.create_initializer() is init
+    elif init == "non-existent":
+      with pytest.raises(NotImplementedError, match=r""):
+        _ = solver.create_initializer()
+    else:
+      actual = solver.create_initializer()
+      expected_type = expected_types[init]
+      assert isinstance(actual, expected_type)
+
   @pytest.mark.parametrize(
       "vector_min, lse_mode", [(True, True), (True, False), (False, True)]
   )
@@ -267,10 +291,9 @@ def test_gauss_initializer(self, lse_mode, rng: jnp.ndarray):
     assert base_num_iter >= gaus_num_iter
 
 
-# TODO(michalk8): mark tests as fast
 class TestLRInitializers:
 
-  @pytest.mark.parametrize("kind", ["pc", "lrc", "geom"])
+  @pytest.mark.fast.with_args("kind", ["pc", "lrc", "geom"], only_fast=0)
   def test_create_default_initializer(self, rng: jnp.ndarray, kind: str):
     n, d, rank = 110, 2, 3
     x = jax.random.normal(rng, (n, d))
@@ -340,7 +363,7 @@ def test_partial_initialization(
     else:
       raise NotImplementedError(partial_init)
 
-  @pytest.mark.parametrize("rank", [2, 4, 10, 13])
+  @pytest.mark.fast.with_args("rank", [2, 4, 10, 13], only_fast=True)
   def test_generalized_k_means_has_correct_rank(
       self, rng: jnp.ndarray, rank: int
   ):
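Finally, an end-to-end sketch of solving a small problem with a named initializer; `pointcloud.PointCloud` and `linear_problems.LinearProblem` are assumed to be the relevant entry points at this commit, and the point-cloud setup is illustrative:

    import jax

    from ott.core import linear_problems, sinkhorn
    from ott.geometry import pointcloud

    key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(key_x, (30, 2))
    y = jax.random.normal(key_y, (40, 2)) + 1.0

    # Build a point-cloud geometry and the associated linear OT problem.
    prob = linear_problems.LinearProblem(pointcloud.PointCloud(x, y, epsilon=1e-1))

    # The Gaussian initializer is requested by name instead of an instance.
    out = sinkhorn.Sinkhorn(initializer="gaussian")(prob)
    print(out.converged, out.reg_ot_cost)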
