Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/refactor relative epsilon #602

Merged
merged 13 commits into from
Dec 2, 2024
2 changes: 0 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ jobs:
fast-tests:
name: Fast tests Python ${{ matrix.python-version }} ${{ matrix.jax-version }}
runs-on: ubuntu-latest
# allow tests using the latest JAX to fail
continue-on-error: ${{ matrix.jax-version == 'jax-latest' }}
strategy:
fail-fast: false
matrix:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ skip_missing_interpreters = true
extras =
test
# https://github.com/google/flax/issues/3329
py{3.9,3.10,3.11,3.12,3.13},py3.10-jax-default: neural
py{3.9,3.10,3.11,3.12},py3.10-jax-default: neural
pass_env = CUDA_*,PYTEST_*,CI
commands_pre =
gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Expand Down
100 changes: 34 additions & 66 deletions src/ott/geometry/epsilon_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,95 +11,63 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Optional

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

__all__ = ["Epsilon", "DEFAULT_SCALE"]

#: Scaling applied to statistic (mean/std) of cost to compute default epsilon.
DEFAULT_SCALE = 0.05
michalk8 marked this conversation as resolved.
Show resolved Hide resolved


@jax.tree_util.register_pytree_node_class
@jtu.register_pytree_node_class
class Epsilon:
"""Scheduler class for the regularization parameter epsilon.
r"""Scheduler class for the regularization parameter epsilon.

An epsilon scheduler outputs a regularization strength, to be used by in a
Sinkhorn-type algorithm, at any iteration count. That value is either the
final, targeted regularization, or one that is larger, obtained by
geometric decay of an initial value that is larger than the intended target.
Concretely, the value returned by such a scheduler will consider first
the max between ``target`` and ``init * target * decay ** iteration``.
If the ``scale_epsilon`` parameter is provided, that value is used to
multiply the max computed previously by ``scale_epsilon``.
An epsilon scheduler outputs a regularization strength, to be used by the
:term:`Sinkhorn algorithm` or variant, at any iteration count. That value is
either the final, targeted regularization, or one that is larger, obtained by
geometric decay of an initial multiplier.

Args:
target: the epsilon regularizer that is targeted. If :obj:`None`,
use :obj:`DEFAULT_SCALE`, currently set at :math:`0.05`.
scale_epsilon: if passed, used to multiply the regularizer, to rescale it.
If :obj:`None`, use :math:`1`.
init: initial value when using epsilon scheduling, understood as multiple
of target value. if passed, ``int * decay ** iteration`` will be used
to rescale target.
decay: geometric decay factor, :math:`<1`.
target: The epsilon regularizer that is targeted.
init: Initial value when using epsilon scheduling, understood as a multiple
of the ``target``, following :math:`\text{init} \text{decay}^{\text{it}}`.
decay: Geometric decay factor, :math:`\leq 1`.
"""

def __init__(
self,
target: Optional[float] = None,
scale_epsilon: Optional[float] = None,
init: float = 1.0,
decay: float = 1.0
):
self._target_init = target
self._scale_epsilon = scale_epsilon
self._init = init
self._decay = decay
def __init__(self, target: jnp.array, init: float = 1.0, decay: float = 1.0):
assert decay <= 1.0, f"Decay must be <= 1, found {decay}."
self.target = target
self.init = init
self.decay = decay

@property
def target(self) -> float:
"""Return the final regularizer value of scheduler."""
target = DEFAULT_SCALE if self._target_init is None else self._target_init
scale = 1.0 if self._scale_epsilon is None else self._scale_epsilon
return scale * target
def __call__(self, it: Optional[int]) -> jnp.array:
"""Intermediate regularizer value at a given iteration number.

def at(self, iteration: Optional[int] = 1) -> float:
"""Return (intermediate) regularizer value at a given iteration."""
if iteration is None:
Args:
it: Current iteration. If :obj:`None`, return :attr:`target`.

Returns:
The epsilon value at the iteration.
"""
if it is None:
return self.target
# check the decay is smaller than 1.0.
decay = jnp.minimum(self._decay, 1.0)
# the multiple is either 1.0 or a larger init value that is decayed.
multiple = jnp.maximum(self._init * (decay ** iteration), 1.0)
multiple = jnp.maximum(self.init * (self.decay ** it), 1.0)
return multiple * self.target

def done(self, eps: float) -> bool:
"""Return whether the scheduler is done at a given value."""
return eps == self.target

def done_at(self, iteration: Optional[int]) -> bool:
"""Return whether the scheduler is done at a given iteration."""
return self.done(self.at(iteration))

def set(self, **kwargs: Any) -> "Epsilon":
"""Return a copy of self, with potential overwrites."""
kwargs = {
"target": self._target_init,
"scale_epsilon": self._scale_epsilon,
"init": self._init,
"decay": self._decay,
**kwargs
}
return Epsilon(**kwargs)
def __repr__(self) -> str:
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
return (
f"{self.__class__.__name__}(target={self.target:.4f}, "
f"init={self.init:.4f}, decay={self.decay:.4f})"
)

def tree_flatten(self): # noqa: D102
return (
self._target_init, self._scale_epsilon, self._init, self._decay
), None
return (self.target,), {"init": self.init, "decay": self.decay}

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
del aux_data
return cls(*children)
return cls(*children, **aux_data)
138 changes: 60 additions & 78 deletions src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,53 +20,48 @@
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
import numpy as np

from ott import utils
from ott.geometry import epsilon_scheduler
from ott.geometry import epsilon_scheduler as eps_scheduler
from ott.math import utils as mu

__all__ = ["Geometry"]


@jax.tree_util.register_pytree_node_class
@jtu.register_pytree_node_class
class Geometry:
r"""Base class to define ground costs/kernels used in optimal transport.

Optimal transport problems are intrinsically geometric: they compute an
optimal way to transport mass from one configuration onto another. To define
what is meant by optimality of transport requires defining a cost, of moving
mass from one among several sources, towards one out of multiple targets.
These sources and targets can be provided as points in vectors spaces, grids,
or more generally exclusively described through a (dissimilarity) cost matrix,
or almost equivalently, a (similarity) kernel matrix.

Once that cost or kernel matrix is set, the ``Geometry`` class provides a
basic operations to be run with the Sinkhorn algorithm.
what is meant by optimality of transport requires defining a
:term:`ground cost`, which quantifies how costly it is to move mass from
one among several source locations, towards one out of multiple
target locations. These source and target locations can be described as
points in vectors spaces, grids, or more generally described
through a (dissimilarity) cost matrix, or almost equivalently, a
(similarity) kernel matrix. This class describes such a
geometry and several useful methods to exploit it.

Args:
cost_matrix: Cost matrix of shape ``[n, m]``.
kernel_matrix: Kernel matrix of shape ``[n, m]``.
epsilon: Regularization parameter. If ``None`` and either
``relative_epsilon = True`` or ``relative_epsilon = None`` or
``relative_epsilon = str`` where ``str`` can be either ``mean`` or ``std``
, this value defaults to a multiple of :attr:`std_cost_matrix`
(or :attr:`mean_cost_matrix` if ``str`` is ``mean``), where that multiple
is set as ``DEFAULT_SCALE`` in ``epsilon_scheduler.py```.
If passed as a
``float``, then the regularizer that is ultimately used is either that
``float`` value (if ``relative_epsilon = False`` or ``None``) or that
``float`` times the :attr:`std_cost_matrix` (if
``relative_epsilon = True`` or ``relative_epsilon = `std```) or
:attr:`mean_cost_matrix` (if ``relative_epsilon = `mean```). Look for
epsilon: Regularization parameter or scheduler. Look for
:class:`~ott.geometry.epsilon_scheduler.Epsilon` when passed as a
scheduler.
relative_epsilon: when :obj:`False`, the parameter ``epsilon`` specifies the
value of the entropic regularization parameter. When :obj:`True` or set
to a string, ``epsilon`` refers to a fraction of the
:attr:`std_cost_matrix` or :attr:`mean_cost_matrix`, which is computed
adaptively from data, depending on whether it is set to ``mean`` or
``std``.
scheduler directly. Otherwise, if :obj:`None` and
``relative_epsilon`` is :obj:`None` the regularizer value
defaults to a multiple of :attr:`std_cost_matrix`, that multiple
is set as :obj:`~ott.geometry.epsilon_scheduler.DEFAULT_SCALE`,
currently equal to `0.05`. If passed as
a ``float``, then the regularizer that is ultimately used is either
that ``float`` value (if ``relative_epsilon`` is :obj:`None`) or that
``float`` times the :attr:`std_cost_matrix` (if
``relative_epsilon`` is ``"std"``) or
:attr:`mean_cost_matrix` (if ``relative_epsilon`` is ``"mean"``).
relative_epsilon: Whether ``epsilon`` refers to a fraction of the
:attr:`mean_cost_matrix` or :attr:`std_cost_matrix`.
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean', 'std' and 'max_cost'. Alternatively, a float factor can
be given to rescale the cost such that ``cost_matrix /= scale_cost``.
Expand All @@ -87,22 +82,17 @@ def __init__(
self,
cost_matrix: Optional[jnp.ndarray] = None,
kernel_matrix: Optional[jnp.ndarray] = None,
epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None,
relative_epsilon: Optional[Union[bool, Literal["mean", "std"]]] = None,
epsilon: Optional[Union[float, eps_scheduler.Epsilon]] = None,
relative_epsilon: Optional[Literal["mean", "std"]] = None,
scale_cost: Union[float, Literal["mean", "max_cost", "median",
"std"]] = 1.0,
src_mask: Optional[jnp.ndarray] = None,
tgt_mask: Optional[jnp.ndarray] = None,
):
self._cost_matrix = cost_matrix
self._kernel_matrix = kernel_matrix

# needed for `copy_epsilon`, because of the `isinstance` check
self._epsilon_init = epsilon if isinstance(
epsilon, epsilon_scheduler.Epsilon
) else epsilon_scheduler.Epsilon(epsilon)
self._epsilon_init = epsilon
self._relative_epsilon = relative_epsilon

self._scale_cost = scale_cost

self._src_mask = src_mask
Expand Down Expand Up @@ -150,7 +140,7 @@ def std_cost_matrix(self) -> float:
to output :math:`\sigma`.
"""
tmp = self._masked_geom().apply_square_cost(self._n_normed_ones).squeeze()
tmp = jnp.sum(tmp * self._m_normed_ones) - (self.mean_cost_matrix) ** 2
tmp = jnp.sum(tmp * self._m_normed_ones) - (self.mean_cost_matrix ** 2)
return jnp.sqrt(jax.nn.relu(tmp))

@property
Expand All @@ -164,35 +154,36 @@ def kernel_matrix(self) -> jnp.ndarray:
return self._kernel_matrix ** self.inv_scale_cost

@property
def _epsilon(self) -> epsilon_scheduler.Epsilon:
(target, scale_eps, _, _), _ = self._epsilon_init.tree_flatten()
rel = self._relative_epsilon

# If nothing passed, default to STD
if rel is None and target is None and scale_eps is None:
scale_eps = jax.lax.stop_gradient(self.std_cost_matrix)
# If instructions passed change, otherwise (notably if False) skip.
elif rel is not None:
if rel == "mean" or rel is True: # Legacy option.
scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix)
elif rel == "std":
scale_eps = jax.lax.stop_gradient(self.std_cost_matrix)
# Avoid 0 std, since this would set epsilon to 0.0 and result in
# a division by 0.
scale_eps = jnp.where(scale_eps <= 0.0, 1.0, scale_eps)

if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
return self._epsilon_init.set(scale_epsilon=scale_eps)

return epsilon_scheduler.Epsilon(
target=epsilon_scheduler.DEFAULT_SCALE if target is None else target,
scale_epsilon=scale_eps
def epsilon_scheduler(self) -> eps_scheduler.Epsilon:
"""TODO."""
if isinstance(self._epsilon_init, eps_scheduler.Epsilon):
return self._epsilon_init

# no relative epsilon
if self._relative_epsilon is None:
if self._epsilon_init is not None:
return eps_scheduler.Epsilon(self._epsilon_init)
multiplier = eps_scheduler.DEFAULT_SCALE
scale = jax.lax.stop_gradient(self.std_cost_matrix)
return eps_scheduler.Epsilon(target=multiplier * scale)

if self._relative_epsilon == "std":
scale = jax.lax.stop_gradient(self.std_cost_matrix)
elif self._relative_epsilon == "mean":
scale = jax.lax.stop_gradient(self.mean_cost_matrix)
else:
raise ValueError(f"Invalid relative epsilon: {self._relative_epsilon}.")

multiplier = (
eps_scheduler.DEFAULT_SCALE
if self._epsilon_init is None else self._epsilon_init
)
return eps_scheduler.Epsilon(target=multiplier * scale)

@property
def epsilon(self) -> float:
"""Epsilon regularization value."""
return self._epsilon.target
return self.epsilon_scheduler.target

@property
def shape(self) -> Tuple[int, int]:
Expand Down Expand Up @@ -257,20 +248,11 @@ def set_scale_cost(self, scale_cost: Union[float, str]) -> "Geometry":

def copy_epsilon(self, other: "Geometry") -> "Geometry":
"""Copy the epsilon parameters from another geometry."""
other_epsilon = other._epsilon
children, aux_data = self.tree_flatten()

new_children = []
for child in children:
if isinstance(child, epsilon_scheduler.Epsilon):
child = child.set(
target=other_epsilon._target_init,
scale_epsilon=other_epsilon._scale_epsilon
)
new_children.append(child)

aux_data["relative_epsilon"] = False
return type(self).tree_unflatten(aux_data, new_children)
new_geom = type(self).tree_unflatten(aux_data, children)
new_geom._epsilon_init = other.epsilon_scheduler
new_geom._relative_epsilon = other._relative_epsilon # has no effect
return new_geom

# The functions below are at the core of Sinkhorn iterations, they
# are implemented here in their default form, either in lse (using directly
Expand Down Expand Up @@ -412,7 +394,7 @@ def update_potential(
Returns:
new potential value, g if axis=0, f if axis is 1.
"""
eps = self._epsilon.at(iteration)
eps = self.epsilon_scheduler(iteration)
app_lse = self.apply_lse_kernel(f, g, eps, axis=axis)[0]
return eps * log_marginal - jnp.where(jnp.isfinite(app_lse), app_lse, 0)

Expand All @@ -434,7 +416,7 @@ def update_scaling(
Returns:
new scaling vector, of size num_b if axis=0, num_a if axis is 1.
"""
eps = self._epsilon.at(iteration)
eps = self.epsilon_scheduler(iteration)
app_kernel = self.apply_kernel(scaling, eps, axis=axis)
return marginal / jnp.where(app_kernel > 0, app_kernel, 1.0)

Expand Down Expand Up @@ -931,7 +913,7 @@ def tree_flatten(self): # noqa: D102
self._src_mask, self._tgt_mask
), {
"scale_cost": self._scale_cost,
"relative_epsilon": self._relative_epsilon
"relative_epsilon": self._relative_epsilon,
}

@classmethod
Expand Down
Loading
Loading