Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/refactor initializers #599

Merged
merged 7 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions docs/tutorials/neural/400_MetaOT.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -582,15 +582,19 @@
" ot_problem = linear_problem.LinearProblem(geom, a=a, b=b)\n",
" solver = sinkhorn.Sinkhorn(**sink_kwargs)\n",
"\n",
" base_sink_out = solver(ot_problem, init=(None, None))\n",
" base_sink_out = solver(ot_problem, init=None)\n",
"\n",
" init_dual_a = meta_initializer.init_dual_a(ot_problem, lse_mode=True)\n",
" meta_sink_out = solver(ot_problem, init=(init_dual_a, None))\n",
" meta_sink_out = solver(\n",
" ot_problem, init=(init_dual_a, jnp.zeros_like(init_dual_a))\n",
" )\n",
"\n",
" init_dual_a = initializers.GaussianInitializer().init_dual_a(\n",
" ot_problem, lse_mode=True\n",
" )\n",
" gaus_sink_out = solver(ot_problem, init=(init_dual_a, None))\n",
" gaus_sink_out = solver(\n",
" ot_problem, init=(init_dual_a, jnp.zeros_like(init_dual_a))\n",
" )\n",
"\n",
" error_log[\"base\"].append(base_sink_out.errors)\n",
" error_log[\"meta_ot\"].append(meta_sink_out.errors)\n",
Expand Down
45 changes: 16 additions & 29 deletions src/ott/initializers/linear/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class SinkhornInitializer(abc.ABC):
"""Base class for Sinkhorn initializers."""

@abc.abstractmethod
def init_dual_a(
def init_fu(
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
Expand All @@ -50,7 +50,7 @@ def init_dual_a(
"""

@abc.abstractmethod
def init_dual_b(
def init_gv(
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
Expand All @@ -70,8 +70,6 @@ def init_dual_b(
def __call__(
self,
ot_prob: linear_problem.LinearProblem,
a: Optional[jnp.ndarray],
b: Optional[jnp.ndarray],
lse_mode: bool,
rng: Optional[jax.Array] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
Expand All @@ -90,25 +88,15 @@ def __call__(
The initial potentials/scalings.
"""
rng = utils.default_prng_key(rng)
rng_x, rng_y = jax.random.split(rng, 2)
n, m = ot_prob.geom.shape
if a is None:
a = self.init_dual_a(ot_prob, lse_mode=lse_mode, rng=rng_x)
if b is None:
b = self.init_dual_b(ot_prob, lse_mode=lse_mode, rng=rng_y)

assert a.shape == (
n,
), f"Expected `f_u` to have shape `{n,}`, found `{a.shape}`."
assert b.shape == (
m,
), f"Expected `g_v` to have shape `{m,}`, found `{b.shape}`."
rng_f, rng_g = jax.random.split(rng, 2)
fu = self.init_fu(ot_prob, lse_mode=lse_mode, rng=rng_f)
gv = self.init_gv(ot_prob, lse_mode=lse_mode, rng=rng_g)

# cancel dual variables for zero weights
a = jnp.where(ot_prob.a > 0.0, a, -jnp.inf if lse_mode else 0.0)
b = jnp.where(ot_prob.b > 0.0, b, -jnp.inf if lse_mode else 0.0)

return a, b
mask_value = -jnp.inf if lse_mode else 0.0
fu = jnp.where(ot_prob.a > 0.0, fu, mask_value)
gv = jnp.where(ot_prob.b > 0.0, gv, mask_value)
return fu, gv

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
return [], {}
Expand All @@ -124,7 +112,7 @@ def tree_unflatten( # noqa: D102
class DefaultInitializer(SinkhornInitializer):
"""Default initialization of Sinkhorn dual potentials/primal scalings."""

def init_dual_a( # noqa: D102
def init_fu( # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
Expand All @@ -133,7 +121,7 @@ def init_dual_a( # noqa: D102
del rng
return jnp.zeros_like(ot_prob.a) if lse_mode else jnp.ones_like(ot_prob.a)

def init_dual_b( # noqa: D102
def init_gv( # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
Expand All @@ -154,7 +142,7 @@ class GaussianInitializer(DefaultInitializer):
to initialize Sinkhorn potentials/scalings.
"""

def init_dual_a( # noqa: D102
def init_fu( # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
Expand Down Expand Up @@ -241,7 +229,7 @@ def cond_fn(state: Tuple[jnp.ndarray, float, int]) -> bool:

return f_potential

def init_dual_a(
def init_fu(
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
Expand Down Expand Up @@ -304,9 +292,8 @@ class SubsampleInitializer(DefaultInitializer):
:class:`~ott.geometry.pointcloud.PointCloud`.
subsample_n_y: number of points to subsample from the second measure in
:class:`~ott.geometry.pointcloud.PointCloud`.
If ``None``, use ``subsample_n_x``.
kwargs: Keyword arguments for
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
If :obj:`None`, use ``subsample_n_x``.
kwargs: Keyword arguments for :func:`~ott.solvers.linear.solve`.
"""

def __init__(
Expand All @@ -320,7 +307,7 @@ def __init__(
self.subsample_n_y = subsample_n_y or subsample_n_x
self.sinkhorn_kwargs = kwargs

def init_dual_a( # noqa: D102
def init_fu( # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
Expand Down
69 changes: 8 additions & 61 deletions src/ott/initializers/linear/initializers_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
if TYPE_CHECKING:
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein_lr
from ott.solvers.linear import sinkhorn

Problem_t = Union["linear_problem.LinearProblem",
"quadratic_problem.QuadraticProblem"]
Expand Down Expand Up @@ -96,7 +95,7 @@ def init_r(
"""Initialize the low-rank factor :math:`R`.

Args:
ot_prob: Linear OT problem.
ot_prob: OT problem.
rng: Random key for seeding.
init_g: Initial value for :math:`g` factor.
kwargs: Additional keyword arguments.
Expand All @@ -123,65 +122,16 @@ def init_g(
Array of shape ``[rank,]``.
"""

@classmethod
def from_solver(
cls,
solver: Union["sinkhorn_lr.LRSinkhorn",
"gromov_wasserstein_lr.LRGromovWasserstein"],
*,
kind: Literal["random", "rank2", "k-means", "generalized-k-means"],
**kwargs: Any,
) -> "LRInitializer":
"""Create a low-rank initializer from a linear or quadratic solver.

Args:
solver: Low-rank linear or quadratic solver.
kind: Which initializer to instantiate.
kwargs: Keyword arguments when creating the initializer.

Returns:
Low-rank initializer.
"""
rank = solver.rank
sinkhorn_kwargs = {
"norm_error": solver._norm_error,
"lse_mode": solver.lse_mode,
"implicit_diff": solver.implicit_diff,
"use_danskin": solver.use_danskin
}

if kind == "random":
return RandomInitializer(rank, **kwargs)
if kind == "rank2":
return Rank2Initializer(rank, **kwargs)
if kind == "k-means":
return KMeansInitializer(rank, sinkhorn_kwargs=sinkhorn_kwargs, **kwargs)
if kind == "generalized-k-means":
return GeneralizedKMeansInitializer(
rank, sinkhorn_kwargs=sinkhorn_kwargs, **kwargs
)
raise NotImplementedError(f"Initializer `{kind}` is not implemented.")

def __call__(
self,
ot_prob: Problem_t,
q: Optional[jnp.ndarray] = None,
r: Optional[jnp.ndarray] = None,
g: Optional[jnp.ndarray] = None,
*,
rng: Optional[jax.Array] = None,
**kwargs: Any
**kwargs: Any,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Initialize the factors :math:`Q`, :math:`R` and :math:`g`.

Args:
ot_prob: OT problem.
q: Factor of shape ``[n, rank]``. If `None`, it will be initialized
using :meth:`init_q`.
r: Factor of shape ``[m, rank]``. If `None`, it will be initialized
using :meth:`init_r`.
g: Factor of shape ``[rank,]``. If `None`, it will be initialized
using :meth:`init_g`.
rng: Random key for seeding.
kwargs: Additional keyword arguments for :meth:`init_q`, :meth:`init_r`
and :meth:`init_g`.
Expand All @@ -190,14 +140,11 @@ def __call__(
The factors :math:`Q`, :math:`R` and :math:`g`, respectively.
"""
rng = utils.default_prng_key(rng)
rng1, rng2, rng3 = jax.random.split(rng, 3)

if g is None:
g = self.init_g(ot_prob, rng1, **kwargs)
if q is None:
q = self.init_q(ot_prob, rng2, init_g=g, **kwargs)
if r is None:
r = self.init_r(ot_prob, rng3, init_g=g, **kwargs)
rng_g, rng_q, rng_r = jax.random.split(rng, 3)

g = self.init_g(ot_prob, rng_g, **kwargs)
q = self.init_q(ot_prob, rng_q, init_g=g, **kwargs)
r = self.init_r(ot_prob, rng_r, init_g=g, **kwargs)

assert g.shape == (self.rank,)
assert q.shape == (ot_prob.a.shape[0], self.rank)
Expand Down
5 changes: 3 additions & 2 deletions src/ott/initializers/neural/meta_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

import optax
from flax import linen as nn
Expand All @@ -31,7 +32,7 @@
__all__ = ["MetaInitializer"]


@jax.tree_util.register_pytree_node_class
@jtu.register_pytree_node_class
class MetaInitializer(initializers.DefaultInitializer):
"""Meta OT Initializer with a fixed geometry :cite:`amos:22`.

Expand Down Expand Up @@ -133,7 +134,7 @@ def update(
"""
return self.update_impl(state, a, b)

def init_dual_a( # noqa: D102
def init_fu( # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
Expand Down
30 changes: 11 additions & 19 deletions src/ott/initializers/quadratic/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import abc
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

from ott.geometry import geometry

Expand All @@ -26,16 +26,9 @@
__all__ = ["BaseQuadraticInitializer", "QuadraticInitializer"]


@jax.tree_util.register_pytree_node_class
@jtu.register_pytree_node_class
class BaseQuadraticInitializer(abc.ABC):
"""Base class for quadratic initializers.

Args:
kwargs: Keyword arguments.
"""

def __init__(self, **kwargs: Any):
self._kwargs = kwargs
"""Base class for quadratic initializers."""

def __call__(
self, quad_prob: "quadratic_problem.QuadraticProblem", **kwargs: Any
Expand All @@ -47,7 +40,7 @@ def __call__(
kwargs: Additional keyword arguments.

Returns:
Linear problem.
The linearized problem.
"""
from ott.problems.linear import linear_problem

Expand Down Expand Up @@ -80,7 +73,7 @@ def _create_geometry(
"""

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
return [], self._kwargs
return [], {}

@classmethod
def tree_unflatten( # noqa: D102
Expand All @@ -89,6 +82,7 @@ def tree_unflatten( # noqa: D102
return cls(*children, **aux_data)


@jtu.register_pytree_node_class
class QuadraticInitializer(BaseQuadraticInitializer):
r"""Initialize a linear problem locally around a selected coupling.

Expand Down Expand Up @@ -125,10 +119,8 @@ class QuadraticInitializer(BaseQuadraticInitializer):
defaults to the product coupling :math:`ab^T`.
"""

def __init__(
self, init_coupling: Optional[jnp.ndarray] = None, **kwargs: Any
):
super().__init__(**kwargs)
def __init__(self, init_coupling: Optional[jnp.ndarray] = None):
super().__init__()
self.init_coupling = init_coupling

def _create_geometry(
Expand All @@ -145,10 +137,10 @@ def _create_geometry(
quad_prob: Quadratic OT problem.
epsilon: Epsilon regularization.
relative_epsilon: Flag, use `relative_epsilon` or not in geometry.
kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`.
kwargs: Unused.

Returns:
The initial geometry used to initialize the linearized problem.
Geometry used to initialize the linearized problem.
"""
from ott.problems.quadratic import quadratic_problem

Expand Down Expand Up @@ -188,4 +180,4 @@ def _create_geometry(
)

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
return [self.init_coupling], self._kwargs
return [self.init_coupling], {}
Loading
Loading