Skip to content

Commit

Permalink
Merge pull request #38 from ott-jax/526517853C347DEEFE613233BE082B1B
Browse files Browse the repository at this point in the history
Adding scaling factor to the cost matrix.
  • Loading branch information
marcocuturi authored Mar 24, 2022
2 parents 2371723 + 1b5a059 commit 4772645
Show file tree
Hide file tree
Showing 7 changed files with 528 additions and 112 deletions.
3 changes: 2 additions & 1 deletion ott/core/quad_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ def update_lr_linearization(

def update_epsilon_unbalanced(epsilon, transport_mass):
updated_epsilon = epsilon_scheduler.Epsilon.make(epsilon)
updated_epsilon._scale = updated_epsilon._scale * transport_mass
updated_epsilon._scale_epsilon = (
updated_epsilon._scale_epsilon * transport_mass)
return updated_epsilon


Expand Down
61 changes: 31 additions & 30 deletions ott/core/sinkhorn_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def lr_costs(self, ot_prob, state, iteration):
diag_qcr = jnp.sum(state.q * ot_prob.geom.apply_cost(state.r, axis=1),
axis=0)
h = diag_qcr / state.g ** 2 - (
self.epsilon - 1 / self.gamma) * jnp.log(state.g)
self.epsilon - 1 / self.gamma) * jnp.log(state.g)
return c_q, c_r, h

def dysktra_update(self, c_q, c_r, h, ot_prob, state, iteration,
Expand Down Expand Up @@ -423,34 +423,35 @@ def run(ot_prob, solver, init) -> LRSinkhornOutput:
out = out.set_cost(ot_prob, solver.lse_mode, solver.use_danskin)
return out.set(ot_prob=ot_prob)


def make(
rank: int = 10,
gamma: float = 1.0,
epsilon: float = 1e-4,
lse_mode: bool = True,
threshold: float = 1e-3,
norm_error: int = 1,
inner_iterations: int = 1,
min_iterations: int = 0,
max_iterations: int = 2000,
use_danskin: bool = True,
implicit_diff: bool = False,
jit: bool = True,
rng_key: int = 0,
kwargs_dys: Any = None) -> LRSinkhorn:
rank: int = 10,
gamma: float = 1.0,
epsilon: float = 1e-4,
lse_mode: bool = True,
threshold: float = 1e-3,
norm_error: int = 1,
inner_iterations: int = 1,
min_iterations: int = 0,
max_iterations: int = 2000,
use_danskin: bool = True,
implicit_diff: bool = False,
jit: bool = True,
rng_key: int = 0,
kwargs_dys: Any = None) -> LRSinkhorn:

return LRSinkhorn(
rank=rank,
gamma=gamma,
epsilon=epsilon,
lse_mode=lse_mode,
threshold=threshold,
norm_error=norm_error,
inner_iterations=inner_iterations,
min_iterations=min_iterations,
max_iterations=max_iterations,
use_danskin=use_danskin,
implicit_diff=implicit_diff,
jit=jit,
rng_key=rng_key,
kwargs_dys=kwargs_dys)
rank=rank,
gamma=gamma,
epsilon=epsilon,
lse_mode=lse_mode,
threshold=threshold,
norm_error=norm_error,
inner_iterations=inner_iterations,
min_iterations=min_iterations,
max_iterations=max_iterations,
use_danskin=use_danskin,
implicit_diff=implicit_diff,
jit=jit,
rng_key=rng_key,
kwargs_dys=kwargs_dys)
15 changes: 8 additions & 7 deletions ott/geometry/epsilon_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Epsilon:

def __init__(self,
target: Optional[float] = None,
scale: Optional[float] = None,
scale_epsilon: Optional[float] = None,
init: Optional[float] = None,
decay: Optional[float] = None):
r"""Initializes a scheduler using possibly geometric decay.
Expand All @@ -38,26 +38,26 @@ def __init__(self,
geometric decay of an initial value that is larger than the intended target.
Concretely, the value returned by such a scheduler will consider first
the max between ``target`` and ``init * target * decay ** iteration``.
If the ``scale`` parameter is provided, that value is used to multiply the
max computed previously by ``scale``.
If the ``scale_epsilon`` parameter is provided, that value is used to multiply the
max computed previously by ``scale_epsilon``.
Args:
target: the epsilon regularizer that is targeted.
scale: if passed, used to multiply the regularizer, to rescale it.
scale_epsilon: if passed, used to multiply the regularizer, to rescale it.
init: initial value when using epsilon scheduling, understood as multiple
of target value. If passed, ``init * decay ** iteration`` will be used
to rescale target.
decay: geometric decay factor, smaller than 1.
"""
self._target_init = .01 if target is None else target
self._scale = 1.0 if scale is None else scale
self._scale_epsilon = 1.0 if scale_epsilon is None else scale_epsilon
self._init = 1.0 if init is None else init
self._decay = 1.0 if decay is None else decay

@property
def target(self):
"""Returns final regularizer value of scheduler."""
return self._target_init * self._scale
return self._target_init * self._scale_epsilon

def at(self, iteration: Optional[int] = 1) -> float:
"""Returns (intermediate) regularizer value at a given iteration."""
Expand All @@ -76,7 +76,8 @@ def done_at(self, iteration):
return self.done(self.at(iteration))

def tree_flatten(self):
return (self._target_init, self._scale, self._init, self._decay), None
return (self._target_init, self._scale_epsilon,
self._init, self._decay), None

@classmethod
def tree_unflatten(cls, aux_data, children):
Expand Down
57 changes: 41 additions & 16 deletions ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def __init__(self,
kernel_matrix: Optional[jnp.ndarray] = None,
epsilon: Union[epsilon_scheduler.Epsilon, float, None] = None,
relative_epsilon: Optional[bool] = None,
scale: Optional[float] = None,
scale_epsilon: Optional[float] = None,
scale_cost: Optional[Union[float, str]] = None,
**kwargs):
r"""Initializes a geometry by passing it a cost matrix or a kernel matrix.
Expand All @@ -68,14 +69,18 @@ def __init__(self,
the mean value of the ``cost_matrix``.
relative_epsilon: whether epsilon is passed relative to scale of problem,
here understood as mean value of ``cost_matrix``.
scale: the scale multiplier for epsilon.
scale_epsilon: the scale multiplier for epsilon.
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean' and 'max_cost'. Alternatively, a float factor can be
given to rescale the cost such that ``cost_matrix /= factor``.
**kwargs: additional kwargs to epsilon.
"""
self._cost_matrix = cost_matrix
self._kernel_matrix = kernel_matrix
self._epsilon_init = epsilon
self._relative_epsilon = relative_epsilon
self._scale = scale
self._scale_epsilon = scale_epsilon
self._scale_cost = scale_cost
# Define default dictionary and update it with user's values.
self._kwargs = {**{'init': None, 'decay': None}, **kwargs}

Expand All @@ -84,28 +89,29 @@ def cost_rank(self):
return None

@property
def scale(self) -> float:
def scale_epsilon(self) -> float:
"""Computes the scale of the epsilon, potentially based on data."""
if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
return 1.0

rel = self._relative_epsilon
trigger = ((self._scale is None) and
trigger = ((self._scale_epsilon is None) and
(rel or rel is None) and
(self._epsilon_init is None or rel))
if (self._scale is None) and (trigger is not None): # for dry run
if (self._scale_epsilon is None) and (trigger is not None): # for dry run
return jnp.where(
trigger, jax.lax.stop_gradient(self.mean_cost_matrix), 1.0)
else:
return self._scale
return self._scale_epsilon

@property
def _epsilon(self):
"""Returns epsilon scheduler, either passed directly or by building it."""
if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
return self._epsilon_init
eps = 5e-2 if self._epsilon_init is None else self._epsilon_init
return epsilon_scheduler.Epsilon.make(eps, scale=self.scale, **self._kwargs)
return epsilon_scheduler.Epsilon.make(
eps, scale_epsilon=self.scale_epsilon, **self._kwargs)

@property
def cost_matrix(self):
Expand All @@ -114,8 +120,9 @@ def cost_matrix(self):
# If no epsilon was passed on to the geometry, then assume it is one by
# default.
cost = -jnp.log(self._kernel_matrix)
cost *= self.scale_cost
return cost if self._epsilon_init is None else self.epsilon * cost
return self._cost_matrix
return self._cost_matrix * self.scale_cost

@property
def median_cost_matrix(self):
Expand All @@ -132,7 +139,8 @@ def mean_cost_matrix(self):
@property
def kernel_matrix(self):
if self._kernel_matrix is None:
return jnp.exp(-(self._cost_matrix / self.epsilon))
return jnp.exp(
-(self._cost_matrix / self.epsilon))**(1.0 / self.scale_cost)
return self._kernel_matrix

@property
Expand All @@ -141,7 +149,8 @@ def epsilon(self):

@property
def shape(self):
mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix
mat = (self._kernel_matrix if self._cost_matrix is None
else self._cost_matrix)
if mat is not None:
return mat.shape
return (0, 0)
Expand All @@ -160,12 +169,28 @@ def is_symmetric(self):
return (mat.shape[0] == mat.shape[1] and
jnp.all(mat == mat.T)) if mat is not None else False

@property
def scale_cost(self):
    """Computes the factor by which the cost matrix is multiplied.

    The value depends on the ``scale_cost`` option stored on the instance:

    * a number ``factor``: returns ``1.0 / factor``, so that multiplying the
      cost by the returned value amounts to ``cost_matrix /= factor``;
    * ``'max_cost'``, ``'mean'`` or ``'median'``: returns the inverse of the
      corresponding statistic of the stored cost matrix, wrapped in
      ``stop_gradient`` so the statistic is treated as a constant when
      differentiating;
    * ``None`` (or any other non-string value): returns ``1.0``, i.e. no
      rescaling.

    Raises:
      ValueError: if a string other than the three names above is given.
    """
    scale = self._scale_cost
    # Accept integer factors too: with a float-only check ``scale_cost=2``
    # silently fell through to the ``1.0`` default. ``bool`` is a subclass of
    # ``int`` and is deliberately left out, so it keeps its previous handling.
    if isinstance(scale, (int, float)) and not isinstance(scale, bool):
        return 1.0 / scale
    elif scale == 'max_cost':
        return jax.lax.stop_gradient(1.0 / jnp.max(self._cost_matrix))
    elif scale == 'mean':
        return jax.lax.stop_gradient(1.0 / jnp.mean(self._cost_matrix))
    elif scale == 'median':
        return jax.lax.stop_gradient(1.0 / jnp.median(self._cost_matrix))
    elif isinstance(scale, str):
        raise ValueError(f'Scaling {scale} not implemented.')
    else:
        return 1.0

def copy_epsilon(self, other):
"""Copies the epsilon parameters from another geometry."""
scheduler = other._epsilon
self._epsilon_init = scheduler._target_init
self._relative_epsilon = False
self._scale = other.scale
self._scale_epsilon = other.scale_epsilon

# The functions below are at the core of Sinkhorn iterations, they
# are implemented here in their default form, either in lse (using directly
Expand Down Expand Up @@ -441,7 +466,7 @@ def apply_cost(self, arr: jnp.ndarray, axis: int = 0, fn=None) -> jnp.ndarray:
)(
arr)

def rescale_cost(self, factor: float):
def rescale_cost_fn(self, factor: float):
if self._cost_matrix is not None:
self._cost_matrix *= factor
if self._kernel_matrix is not None:
Expand Down Expand Up @@ -486,12 +511,12 @@ def prepare_divergences(cls, *args, static_b: bool = False, **kwargs):

def tree_flatten(self):
return (self._cost_matrix, self._kernel_matrix, self._epsilon_init,
self._relative_epsilon, self._kwargs), None
self._relative_epsilon,
self._kwargs), {'scale_cost': self._scale_cost}

@classmethod
def tree_unflatten(cls, aux_data, children):
del aux_data
return cls(*children[:-1], **children[-1])
return cls(*children[:-1], **children[-1], **aux_data)


def is_affine(fn) -> bool:
Expand Down
63 changes: 52 additions & 11 deletions ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

# Lint as: python3
"""A class describing low-rank geometries."""
from typing import Union, Optional
import jax
import jax.numpy as jnp
from ott.geometry import geometry
Expand All @@ -29,6 +30,7 @@ def __init__(self,
cost_1: jnp.ndarray,
cost_2: jnp.ndarray,
bias: float = 0.0,
scale_cost: Optional[Union[float, str]] = None,
**kwargs
):
r"""Initializes a geometry by passing it low-rank factors.
Expand All @@ -37,33 +39,72 @@ def __init__(self,
cost_1: jnp.ndarray<float>[num_a, r]
cost_2: jnp.ndarray<float>[num_b, r]
bias: constant added to entire cost matrix.
scale_cost: option to rescale the cost matrix. Implemented scalings are
'max_bound' and 'mean'. Alternatively, a float factor can be
given to rescale the cost such that ``cost_matrix /= factor``.
**kwargs: additional kwargs to Geometry
"""
assert cost_1.shape[1] == cost_2.shape[1]
self.cost_1 = cost_1
self.cost_2 = cost_2
self.bias = bias
self._cost_1 = cost_1
self._cost_2 = cost_2
self._bias = bias
self._kwargs = kwargs

super().__init__(**kwargs)
self._scale_cost = scale_cost

@property
def cost_1(self):
    """Returns the first stored factor times the square root of ``scale_cost``."""
    root_scale = jnp.sqrt(self.scale_cost)
    return self._cost_1 * root_scale

@property
def cost_2(self):
    """Returns the second stored factor times the square root of ``scale_cost``."""
    root_scale = jnp.sqrt(self.scale_cost)
    return self._cost_2 * root_scale

@property
def bias(self):
    """Returns the stored constant offset multiplied by ``scale_cost``."""
    scaling = self.scale_cost
    return self._bias * scaling

@property
def cost_rank(self):
return self.cost_1.shape[1]
return self._cost_1.shape[1]

@property
def cost_matrix(self):
"""Returns cost matrix if requested."""
return jnp.matmul(self.cost_1, self.cost_2.T) + self.bias
return (jnp.matmul(self.cost_1, self.cost_2.T) + self.bias)

@property
def shape(self):
return (self.cost_1.shape[0], self.cost_2.shape[0])
return (self._cost_1.shape[0], self._cost_2.shape[0])

@property
def is_symmetric(self):
return (self.cost_1.shape[0] == self.cost_2.shape[0] and
jnp.all(self.cost_1 == self.cost_2))
return (self._cost_1.shape[0] == self._cost_2.shape[0] and
jnp.all(self._cost_1 == self._cost_2))

@property
def scale_cost(self):
    """Computes the factor by which the low-rank cost is multiplied.

    The value depends on the ``scale_cost`` option stored on the instance:

    * a number ``factor``: returns ``1.0 / factor``, so that multiplying the
      cost by the returned value amounts to ``cost_matrix /= factor``;
    * ``'max_bound'``: returns the inverse of
      ``max|cost_1| * max|cost_2| + |bias|``;
    * ``'mean'``: returns the inverse of the mean entry of
      ``cost_1 @ cost_2.T + bias``, computed from column sums of the factors
      so the full matrix is never formed;
    * ``None`` (or any other non-string value): returns ``1.0``, i.e. no
      rescaling.

    Data-dependent factors are wrapped in ``stop_gradient`` so they are
    treated as constants when differentiating.

    Raises:
      NotImplementedError: for ``'max_cost'``, which is not available yet.
      ValueError: for any other unrecognised string.
    """
    scale = self._scale_cost
    # A numeric factor must be inverted: the returned value multiplies the
    # cost, whereas the option is documented as dividing it. Returning the
    # factor itself scaled the cost in the wrong direction. Integers are
    # accepted as well (they used to be silently ignored); ``bool`` is
    # excluded so that it keeps its previous handling.
    if isinstance(scale, (int, float)) and not isinstance(scale, bool):
        return 1.0 / scale
    elif scale == 'max_bound':
        return jax.lax.stop_gradient(
            1.0 / (jnp.max(jnp.abs(self._cost_1))
                   * jnp.max(jnp.abs(self._cost_2))
                   + jnp.abs(self._bias)))
    elif scale == 'mean':
        # Sum of all entries of cost_1 @ cost_2.T equals the dot product of
        # the column sums of the two factors.
        factor1 = jnp.dot(jnp.ones(self.shape[0]), self._cost_1)
        factor2 = jnp.dot(self._cost_2.T, jnp.ones(self.shape[1]))
        mean = (jnp.dot(factor1, factor2) / (self.shape[0] * self.shape[1])
                + self._bias)
        return jax.lax.stop_gradient(1.0 / mean)
    elif scale == 'max_cost':
        # TODO(lpapaxanthos): implement memory efficient max.
        raise NotImplementedError(f'Scaling {scale} not implemented.')
    elif isinstance(scale, str):
        raise ValueError(f'Scaling {scale} not provided.')
    else:
        return 1.0

def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
"""Applies elementwise-square of cost matrix to array (vector or matrix)."""
Expand Down Expand Up @@ -114,12 +155,12 @@ def apply_cost_2(self, vec, axis=0):
return jnp.dot(self.cost_2 if axis == 0 else self.cost_2.T, vec)

def tree_flatten(self):
return (self.cost_1, self.cost_2, self._kwargs), None
return (self._cost_1, self._cost_2, self._kwargs), {
'bias': self._bias, 'scale_cost': self._scale_cost}

@classmethod
def tree_unflatten(cls, aux_data, children):
del aux_data
return cls(*children[:-1], **children[-1])
return cls(*children[:-1], **children[-1], **aux_data)


def add_lrc_geom(geom1: LRCGeometry, geom2: LRCGeometry):
Expand Down
Loading

0 comments on commit 4772645

Please sign in to comment.