From c855e02a78166abb0c90fa694d6966d2505d0e16 Mon Sep 17 00:00:00 2001 From: Laetitia Papaxanthos Date: Wed, 23 Mar 2022 14:26:43 +0000 Subject: [PATCH] Adding scaling factor to the cost matrix. PiperOrigin-RevId: 436732691 --- ott/core/quad_problems.py | 3 +- ott/core/sinkhorn_lr.py | 61 +++++----- ott/geometry/epsilon_scheduler.py | 15 +-- ott/geometry/geometry.py | 55 +++++++--- ott/geometry/low_rank.py | 51 +++++++-- ott/geometry/pointcloud.py | 165 ++++++++++++++++++++-------- tests/geometry/scaling_cost_test.py | 158 ++++++++++++++++++++++++++ 7 files changed, 402 insertions(+), 106 deletions(-) create mode 100644 tests/geometry/scaling_cost_test.py diff --git a/ott/core/quad_problems.py b/ott/core/quad_problems.py index 1a1ca3d93..0b13a2f4f 100644 --- a/ott/core/quad_problems.py +++ b/ott/core/quad_problems.py @@ -488,7 +488,8 @@ def update_lr_linearization( def update_epsilon_unbalanced(epsilon, transport_mass): updated_epsilon = epsilon_scheduler.Epsilon.make(epsilon) - updated_epsilon._scale = updated_epsilon._scale * transport_mass + updated_epsilon._scale_epsilon = ( + updated_epsilon._scale_epsilon * transport_mass) return updated_epsilon diff --git a/ott/core/sinkhorn_lr.py b/ott/core/sinkhorn_lr.py index b4c68923d..8f7379028 100644 --- a/ott/core/sinkhorn_lr.py +++ b/ott/core/sinkhorn_lr.py @@ -271,7 +271,7 @@ def lr_costs(self, ot_prob, state, iteration): diag_qcr = jnp.sum(state.q * ot_prob.geom.apply_cost(state.r, axis=1), axis=0) h = diag_qcr / state.g ** 2 - ( - self.epsilon - 1 / self.gamma) * jnp.log(state.g) + self.epsilon - 1 / self.gamma) * jnp.log(state.g) return c_q, c_r, h def dysktra_update(self, c_q, c_r, h, ot_prob, state, iteration, @@ -423,34 +423,35 @@ def run(ot_prob, solver, init) -> LRSinkhornOutput: out = out.set_cost(ot_prob, solver.lse_mode, solver.use_danskin) return out.set(ot_prob=ot_prob) + def make( - rank: int = 10, - gamma: float = 1.0, - epsilon: float = 1e-4, - lse_mode: bool = True, - threshold: float = 1e-3, - norm_error: int = 1, - inner_iterations: int = 1, - min_iterations: int = 0, - max_iterations: int = 2000, - use_danskin: bool = True, - implicit_diff: bool = False, - jit: bool = True, - rng_key: int = 0, - kwargs_dys: Any = None) -> LRSinkhorn: - + rank: int = 10, + gamma: float = 1.0, + epsilon: float = 1e-4, + lse_mode: bool = True, + threshold: float = 1e-3, + norm_error: int = 1, + inner_iterations: int = 1, + min_iterations: int = 0, + max_iterations: int = 2000, + use_danskin: bool = True, + implicit_diff: bool = False, + jit: bool = True, + rng_key: int = 0, + kwargs_dys: Any = None) -> LRSinkhorn: + return LRSinkhorn( - rank=rank, - gamma=gamma, - epsilon=epsilon, - lse_mode=lse_mode, - threshold=threshold, - norm_error=norm_error, - inner_iterations=inner_iterations, - min_iterations=min_iterations, - max_iterations=max_iterations, - use_danskin=use_danskin, - implicit_diff=implicit_diff, - jit=jit, - rng_key=rng_key, - kwargs_dys=kwargs_dys) + rank=rank, + gamma=gamma, + epsilon=epsilon, + lse_mode=lse_mode, + threshold=threshold, + norm_error=norm_error, + inner_iterations=inner_iterations, + min_iterations=min_iterations, + max_iterations=max_iterations, + use_danskin=use_danskin, + implicit_diff=implicit_diff, + jit=jit, + rng_key=rng_key, + kwargs_dys=kwargs_dys) diff --git a/ott/geometry/epsilon_scheduler.py b/ott/geometry/epsilon_scheduler.py index f3d59025e..268f166fb 100644 --- a/ott/geometry/epsilon_scheduler.py +++ b/ott/geometry/epsilon_scheduler.py @@ -27,7 +27,7 @@ class Epsilon: def __init__(self, target: Optional[float] = None, - scale: Optional[float] = None, + scale_epsilon: Optional[float] = None, init: Optional[float] = None, decay: Optional[float] = None): r"""Initializes a scheduler using possibly geometric decay. @@ -38,26 +38,26 @@ def __init__(self, geometric decay of an initial value that is larger than the intended target. Concretely, the value returned by such a scheduler will consider first the max between ``target`` and ``init * target * decay ** iteration``. - If the ``scale`` parameter is provided, that value is used to multiply the - max computed previously by ``scale``. + If the ``scale_epsilon`` parameter is provided, that value is used to multiply the + max computed previously by ``scale_epsilon``. Args: target: the epsilon regularizer that is targeted. - scale: if passed, used to multiply the regularizer, to rescale it. + scale_epsilon: if passed, used to multiply the regularizer, to rescale it. init: initial value when using epsilon scheduling, understood as multiple of target value. if passed, ``int * decay ** iteration`` will be used to rescale target. decay: geometric decay factor, smaller than 1. """ self._target_init = .01 if target is None else target - self._scale = 1.0 if scale is None else scale + self._scale_epsilon = 1.0 if scale_epsilon is None else scale_epsilon self._init = 1.0 if init is None else init self._decay = 1.0 if decay is None else decay @property def target(self): """Returns final regularizer value of scheduler.""" - return self._target_init * self._scale + return self._target_init * self._scale_epsilon def at(self, iteration: Optional[int] = 1) -> float: """Returns (intermediate) regularizer value at a given iteration.""" @@ -76,7 +76,8 @@ def done_at(self, iteration): return self.done(self.at(iteration)) def tree_flatten(self): - return (self._target_init, self._scale, self._init, self._decay), None + return (self._target_init, self._scale_epsilon, + self._init, self._decay), None @classmethod def tree_unflatten(cls, aux_data, children): diff --git a/ott/geometry/geometry.py b/ott/geometry/geometry.py index d36643720..14c29f7a8 100644 --- a/ott/geometry/geometry.py +++ b/ott/geometry/geometry.py @@ -51,7 +51,8 @@ def __init__(self, kernel_matrix: Optional[jnp.ndarray] = None, epsilon: Union[epsilon_scheduler.Epsilon, float, None] = None, relative_epsilon: Optional[bool] = None, - scale: Optional[float] = None, + scale_epsilon: Optional[float] = None, + scale_cost: Optional[Union[float, str]] = None, **kwargs): r"""Initializes a geometry by passing it a cost matrix or a kernel matrix. @@ -68,14 +69,18 @@ def __init__(self, the mean value of the ``cost_matrix``. relative_epsilon: whether epsilon is passed relative to scale of problem, here understood as mean value of ``cost_matrix``. - scale: the scale multiplier for epsilon. + scale_epsilon: the scale multiplier for epsilon. + scale_cost: option to rescale the cost matrix. Implemented scalings are + 'median', 'mean' and 'max_cost'. Alternatively, a float factor can be + given to rescale the cost such that ``cost_matrix /= factor``. **kwargs: additional kwargs to epsilon. """ self._cost_matrix = cost_matrix self._kernel_matrix = kernel_matrix self._epsilon_init = epsilon self._relative_epsilon = relative_epsilon - self._scale = scale + self._scale_epsilon = scale_epsilon + self._scale_cost = scale_cost # Define default dictionary and update it with user's values. self._kwargs = {**{'init': None, 'decay': None}, **kwargs} @@ -84,20 +89,20 @@ def cost_rank(self): return None @property - def scale(self) -> float: + def scale_epsilon(self) -> float: """Computes the scale of the epsilon, potentially based on data.""" if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon): return 1.0 rel = self._relative_epsilon - trigger = ((self._scale is None) and + trigger = ((self._scale_epsilon is None) and (rel or rel is None) and (self._epsilon_init is None or rel)) - if (self._scale is None) and (trigger is not None): # for dry run + if (self._scale_epsilon is None) and (trigger is not None): # for dry run return jnp.where( trigger, jax.lax.stop_gradient(self.mean_cost_matrix), 1.0) else: - return self._scale + return self._scale_epsilon @property def _epsilon(self): @@ -105,7 +110,8 @@ def _epsilon(self): if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon): return self._epsilon_init eps = 5e-2 if self._epsilon_init is None else self._epsilon_init - return epsilon_scheduler.Epsilon.make(eps, scale=self.scale, **self._kwargs) + return epsilon_scheduler.Epsilon.make( + eps, scale_epsilon=self.scale_epsilon, **self._kwargs) @property def cost_matrix(self): @@ -114,8 +120,9 @@ def cost_matrix(self): # If no epsilon was passed on to the geometry, then assume it is one by # default. cost = -jnp.log(self._kernel_matrix) + cost *= self.scale_cost return cost if self._epsilon_init is None else self.epsilon * cost - return self._cost_matrix + return self._cost_matrix * self.scale_cost @property def median_cost_matrix(self): @@ -132,7 +139,8 @@ def mean_cost_matrix(self): @property def kernel_matrix(self): if self._kernel_matrix is None: - return jnp.exp(-(self._cost_matrix / self.epsilon)) + return jnp.exp( + -(self._cost_matrix / self.epsilon))**(1.0 / self.scale_cost) return self._kernel_matrix @property @@ -141,7 +149,8 @@ def epsilon(self): @property def shape(self): - mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix + mat = (self._kernel_matrix if self._cost_matrix is None + else self._cost_matrix) if mat is not None: return mat.shape return (0, 0) @@ -160,12 +169,26 @@ def is_symmetric(self): return (mat.shape[0] == mat.shape[1] and jnp.all(mat == mat.T)) if mat is not None else False + @property + def scale_cost(self): + """Computes the factor to scale the cost matrix.""" + if isinstance(self._scale_cost, float): + return 1.0 / self._scale_cost + elif self._scale_cost == 'max_cost': + return jax.lax.stop_gradient(1.0 / jnp.max(self._cost_matrix)) + elif self._scale_cost == 'mean': + return jax.lax.stop_gradient(1.0 / jnp.mean(self._cost_matrix)) + elif self._scale_cost == 'median': + return jax.lax.stop_gradient(1.0 / jnp.median(self._cost_matrix)) + else: + return 1.0 + def copy_epsilon(self, other): """Copies the epsilon parameters from another geometry.""" scheduler = other._epsilon self._epsilon_init = scheduler._target_init self._relative_epsilon = False - self._scale = other.scale + self._scale_epsilon = other.scale_epsilon # The functions below are at the core of Sinkhorn iterations, they # are implemented here in their default form, either in lse (using directly @@ -441,7 +464,7 @@ def apply_cost(self, arr: jnp.ndarray, axis: int = 0, fn=None) -> jnp.ndarray: )( arr) - def rescale_cost(self, factor: float): + def rescale_cost_fn(self, factor: float): if self._cost_matrix is not None: self._cost_matrix *= factor if self._kernel_matrix is not None: @@ -486,12 +509,12 @@ def prepare_divergences(cls, *args, static_b: bool = False, **kwargs): def tree_flatten(self): return (self._cost_matrix, self._kernel_matrix, self._epsilon_init, - self._relative_epsilon, self._kwargs), None + self._relative_epsilon, + self._kwargs), {'scale_cost': self._scale_cost} @classmethod def tree_unflatten(cls, aux_data, children): - del aux_data - return cls(*children[:-1], **children[-1]) + return cls(*children[:-1], **children[-1], **aux_data) def is_affine(fn) -> bool: diff --git a/ott/geometry/low_rank.py b/ott/geometry/low_rank.py index 95460be5f..bc514bbc7 100644 --- a/ott/geometry/low_rank.py +++ b/ott/geometry/low_rank.py @@ -15,6 +15,7 @@ # Lint as: python3 """A class describing low-rank geometries.""" +from typing import Union, Optional import jax import jax.numpy as jnp from ott.geometry import geometry @@ -29,6 +30,7 @@ def __init__(self, cost_1: jnp.ndarray, cost_2: jnp.ndarray, bias: float = 0.0, + scale_cost: Optional[Union[float, str]] = None, **kwargs ): r"""Initializes a geometry by passing it low-rank factors. @@ -37,16 +39,32 @@ def __init__(self, cost_1: jnp.ndarray[num_a, r] cost_2: jnp.ndarray[num_b, r] bias: constant added to entire cost matrix. + scale_cost: option to rescale the cost matrix. Implemented scalings are + 'max_bound'. Alternatively, a float factor can be + given to rescale the cost such that ``cost_matrix /= factor``. **kwargs: additional kwargs to Geometry """ assert cost_1.shape[1] == cost_2.shape[1] - self.cost_1 = cost_1 - self.cost_2 = cost_2 - self.bias = bias + self._cost_1 = cost_1 + self._cost_2 = cost_2 + self._bias = bias + self._scale_cost = scale_cost self._kwargs = kwargs super().__init__(**kwargs) + @property + def cost_1(self): + return self._cost_1 * jnp.sqrt(self.scale_cost) + + @property + def cost_2(self): + return self._cost_2 * jnp.sqrt(self.scale_cost) + + @property + def bias(self): + return self._bias * self.scale_cost + @property def cost_rank(self): return self.cost_1.shape[1] @@ -54,7 +72,8 @@ def cost_rank(self): @property def cost_matrix(self): """Returns cost matrix if requested.""" - return jnp.matmul(self.cost_1, self.cost_2.T) + self.bias + return ( + jnp.matmul(self.cost_1, self.cost_2.T) + self.bias) * self.scale_cost @property def shape(self): @@ -65,6 +84,24 @@ def is_symmetric(self): return (self.cost_1.shape[0] == self.cost_2.shape[0] and jnp.all(self.cost_1 == self.cost_2)) + @property + def scale_cost(self): + if isinstance(self._scale_cost, float): + return self._scale_cost + elif self._scale_cost == 'max_bound': + return jax.lax.stop_gradient( + 1.0 / (jnp.max(jnp.abs(self.cost_1)) + * jnp.max(jnp.abs(self.cost_2)) + + jnp.abs(self.bias))) + elif self._scale_cost == 'mean': + # TODO(lpapaxanthos): implement memory efficient mean. + return 1.0 + elif self._scale_cost == 'max_cost': + # TODO(lpapaxanthos): implement memory efficient max. + return 1.0 + else: + return 1.0 + def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray: """Applies elementwise-square of cost matrix to array (vector or matrix).""" (n, m), r = self.shape, self.cost_rank @@ -114,12 +151,12 @@ def apply_cost_2(self, vec, axis=0): return jnp.dot(self.cost_2 if axis == 0 else self.cost_2.T, vec) def tree_flatten(self): - return (self.cost_1, self.cost_2, self._kwargs), None + return (self._cost_1, self._cost_2, self._kwargs), { + 'bias': self._bias, 'scale_cost': self._scale_cost} @classmethod def tree_unflatten(cls, aux_data, children): - del aux_data - return cls(*children[:-1], **children[-1]) + return cls(*children[:-1], **children[-1], **aux_data) def add_lrc_geom(geom1: LRCGeometry, geom2: LRCGeometry): diff --git a/ott/geometry/pointcloud.py b/ott/geometry/pointcloud.py index 25cc21930..967eb5174 100644 --- a/ott/geometry/pointcloud.py +++ b/ott/geometry/pointcloud.py @@ -37,6 +37,7 @@ def __init__(self, cost_fn: Optional[costs.CostFn] = None, power: float = 2.0, online: Optional[Union[bool, int]] = None, + scale_cost: Optional[Union[float, str]] = None, **kwargs): """Creates a geometry from two point clouds, using CostFn. @@ -58,6 +59,10 @@ def __init__(self, online computation is particularly useful for big point clouds such that their cost matrix does not fit in memory. This is done by batching :meth:`apply_lse_kernel`. If `True`, batch size of 1024 is used. + scale_cost: option to rescale the cost matrix. Implemented scalings are + 'median', 'mean', 'max_cost', 'max_norm' and 'max_bound'. + Alternatively, a float factor can be given to rescale the cost such + that ``cost_matrix /= factor``. **kwargs: other optional parameters to be passed on to superclass initializer, notably those related to epsilon regularization. """ @@ -73,7 +78,8 @@ def __init__(self, if online: assert online > 0, f"`online={online}` must be positive." n, m = self.shape - self._bs = min(online, online, *(() + ((n,) if n else ()) + ((m,) if m else ()))) + self._bs = min( + online, online, *(() + ((n,) if n else ()) + ((m,) if m else ()))) # use `floor` instead of `ceil` and handle the rest seperately self._x_nsplit = int(math.floor(n / self._bs)) self._y_nsplit = int(math.floor(m / self._bs)) @@ -83,6 +89,7 @@ def __init__(self, self._online = online self.power = power super().__init__(**kwargs) + self._scale_cost = scale_cost @property def _norm_x(self): @@ -102,10 +109,8 @@ def _norm_y(self): def cost_matrix(self): if self._online: return None - cost_matrix = self._cost_fn.all_pairs_pairwise(self.x, self.y) - if self._axis_norm is not None: - cost_matrix += self._norm_x[:, jnp.newaxis] + self._norm_y[jnp.newaxis, :] - return cost_matrix ** (0.5 * self.power) + cost_matrix = self.compute_cost_matrix() + return cost_matrix * self.scale_cost @property def kernel_matrix(self): @@ -136,6 +141,58 @@ def is_squared_euclidean(self): def is_online(self) -> bool: return self._online is not None + @property + def scale_cost(self): + """Computes the factor to scale the cost matrix.""" + if isinstance(self._scale_cost, float): + return 1.0 / self._scale_cost + elif self._scale_cost == 'max_cost': + if self.is_online: + # TODO(lpapaxanthos): implement memory efficient max. + return 1.0 + else: + return jax.lax.stop_gradient(1.0 / jnp.max(self.compute_cost_matrix())) + elif self._scale_cost == 'mean': + if self.is_online: + # TODO(lpapaxanthos): implement memory efficient mean. + return 1.0 + else: + if isinstance(self.shape[0], int) and (self.shape[0] > 0): + return jax.lax.stop_gradient( + 1.0 / jnp.mean(self.compute_cost_matrix())) + elif self._scale_cost == 'median': + if not self.is_online: + return jax.lax.stop_gradient( + 1.0 / jnp.median(self.compute_cost_matrix())) + else: + raise NotImplementedError("""Using the median as scaling factor for + the cost matrix with the online mode is not implemented.""") + elif self._scale_cost == 'max_norm': + if self._cost_fn.norm is not None: + return jax.lax.stop_gradient( + 1.0 / jnp.maximum(self._cost_fn.norm(self.x).max(), + self._cost_fn.norm(self.y).max())) + else: + return 1.0 + elif self._scale_cost == 'max_bound': + if self.is_squared_euclidean: + x_argmax = jnp.argmax(self._norm_x) + y_argmax = jnp.argmax(self._norm_y) + max_bound = (self._norm_x[x_argmax] + self._norm_y[y_argmax] + + 2 * jnp.sqrt( + self._norm_x[x_argmax] * self._norm_y[y_argmax])) + return jax.lax.stop_gradient(1.0 / max_bound) + else: + return 1.0 + else: + return 1.0 + + def compute_cost_matrix(self): + cost_matrix = self._cost_fn.all_pairs_pairwise(self.x, self.y) + if self._axis_norm is not None: + cost_matrix += self._norm_x[:, jnp.newaxis] + self._norm_y[jnp.newaxis, :] + return cost_matrix ** (0.5 * self.power) + def apply_lse_kernel(self, f: jnp.ndarray, g: jnp.ndarray, @@ -144,39 +201,48 @@ def apply_lse_kernel(self, axis: int = 0) -> jnp.ndarray: def body0(carry, i: int): f, g, eps, vec = carry - y = jax.lax.dynamic_slice(self.y, (i * self._bs, 0), (self._bs, self.y.shape[1])) + y = jax.lax.dynamic_slice( + self.y, (i * self._bs, 0), (self._bs, self.y.shape[1])) g_ = jax.lax.dynamic_slice(g, (i * self._bs,), (self._bs,)) if self._axis_norm is None: norm_y = self._norm_y else: - norm_y = jax.lax.dynamic_slice(self._norm_y, (i * self._bs,), (self._bs,)) - h_res, h_sgn = app(self.x, y, self._norm_x, norm_y, f, g_, eps, vec, self._cost_fn, self.power) + norm_y = jax.lax.dynamic_slice( + self._norm_y, (i * self._bs,), (self._bs,)) + h_res, h_sgn = app(self.x, y, self._norm_x, norm_y, f, g_, eps, vec, + self._cost_fn, self.power, self.scale_cost) return carry, (h_res, h_sgn) def body1(carry, i: int): f, g, eps, vec = carry - x = jax.lax.dynamic_slice(self.x, (i * self._bs, 0), (self._bs, self.x.shape[1])) + x = jax.lax.dynamic_slice( + self.x, (i * self._bs, 0), (self._bs, self.x.shape[1])) f_ = jax.lax.dynamic_slice(f, (i * self._bs,), (self._bs,)) if self._axis_norm is None: norm_x = self._norm_x else: - norm_x = jax.lax.dynamic_slice(self._norm_x, (i * self._bs,), (self._bs,)) - h_res, h_sgn = app(self.y, x, self._norm_y, norm_x, g, f_, eps, vec, self._cost_fn, self.power) + norm_x = jax.lax.dynamic_slice( + self._norm_x, (i * self._bs,), (self._bs,)) + h_res, h_sgn = app(self.y, x, self._norm_y, norm_x, g, f_, eps, vec, + self._cost_fn, self.power, self.scale_cost) return carry, (h_res, h_sgn) def finalize(i: int): if axis == 0: norm_y = self._norm_y if self._axis_norm is None else self._norm_y[i:] - return app(self.x, self.y[i:], self._norm_x, norm_y, f, g[i:], eps, vec, self._cost_fn, self.power) + return app(self.x, self.y[i:], self._norm_x, norm_y, f, g[i:], eps, + vec, self._cost_fn, self.power, self.scale_cost) norm_x = self._norm_x if self._axis_norm is None else self._norm_x[i:] - return app(self.y, self.x[i:], self._norm_y, norm_x, g, f[i:], eps, vec, self._cost_fn, self.power) + return app(self.y, self.x[i:], self._norm_y, norm_x, + g, f[i:], eps, vec, self._cost_fn, self.power, self.scale_cost) if not self._online: return super().apply_lse_kernel(f, g, eps, vec, axis) app = jax.vmap( _apply_lse_kernel_xy, - in_axes=[None, 0, None, self._axis_norm, None, 0, None, None, None, None] + in_axes=[None, 0, None, self._axis_norm, None, 0, None, None, None, + None, None] ) if axis == 0: @@ -188,7 +254,8 @@ def finalize(i: int): else: raise ValueError(axis) - _, (h_res, h_sign) = jax.lax.scan(fun, init=(f, g, eps, vec), xs=jnp.arange(n)) + _, (h_res, h_sign) = jax.lax.scan( + fun, init=(f, g, eps, vec), xs=jnp.arange(n)) h_res, h_sign = jnp.concatenate(h_res), jnp.concatenate(h_sign) h_res_rest, h_sign_rest = finalize(n * self._bs) h_res = jnp.concatenate([h_res, h_res_rest]) @@ -207,29 +274,29 @@ def apply_kernel(self, return super().apply_kernel(scaling, eps, axis) app = jax.vmap(_apply_kernel_xy, in_axes=[ - None, 0, None, self._axis_norm, None, None, None, None]) + None, 0, None, self._axis_norm, None, None, None, None, None]) if axis == 0: return app(self.x, self.y, self._norm_x, self._norm_y, scaling, eps, - self._cost_fn, self.power) + self._cost_fn, self.power, self.scale_cost) if axis == 1: return app(self.y, self.x, self._norm_y, self._norm_x, scaling, eps, - self._cost_fn, self.power) + self._cost_fn, self.power, self.scale_cost) def transport_from_potentials(self, f, g): if not self._online: return super().transport_from_potentials(f, g) transport = jax.vmap(_transport_from_potentials_xy, in_axes=[ - None, 0, None, self._axis_norm, None, 0, None, None, None]) + None, 0, None, self._axis_norm, None, 0, None, None, None, None]) return transport(self.y, self.x, self._norm_y, self._norm_x, g, f, - self.epsilon, self._cost_fn, self.power) + self.epsilon, self._cost_fn, self.power, self.scale_cost) def transport_from_scalings(self, u, v): if not self._online: return super().transport_from_scalings(u, v) transport = jax.vmap(_transport_from_scalings_xy, in_axes=[ - None, 0, None, self._axis_norm, None, 0, None, None, None]) + None, 0, None, self._axis_norm, None, 0, None, None, None, None]) return transport(self.y, self.x, self._norm_y, self._norm_x, v, u, - self.epsilon, self._cost_fn, self.power) + self.epsilon, self._cost_fn, self.power, self.scale_cost) def apply_cost(self, arr: jnp.ndarray, @@ -268,13 +335,13 @@ def _apply_cost(self, """See apply_cost.""" if self._online: app = jax.vmap(_apply_cost_xy, in_axes=[ - None, 0, None, self._axis_norm, None, None, None, None]) + None, 0, None, self._axis_norm, None, None, None, None, None]) if axis == 0: return app(self.x, self.y, self._norm_x, self._norm_y, arr, - self._cost_fn, self.power, fn) + self._cost_fn, self.power, self.scale_cost, fn) if axis == 1: return app(self.y, self.x, self._norm_y, self._norm_x, arr, - self._cost_fn, self.power, fn) + self._cost_fn, self.power, self.scale_cost, fn) else: return super().apply_cost(arr, axis, fn) @@ -311,7 +378,8 @@ def vec_apply_cost(self, applied_cost += ny.reshape(-1, 1) * jnp.sum(arr, axis=0).reshape(1, -1) cross_term = -2.0 * jnp.dot(y, jnp.dot(x.T, arr)) applied_cost += cross_term[:, None] if rank == 1 else cross_term - return fn(applied_cost) if fn else applied_cost + return (fn(applied_cost) * self.scale_cost if fn + else applied_cost * self.scale_cost) @classmethod def prepare_divergences(cls, *args, static_b: bool = False, **kwargs): @@ -322,7 +390,8 @@ def prepare_divergences(cls, *args, static_b: bool = False, **kwargs): def tree_flatten(self): return ((self.x, self.y, self._epsilon, self._cost_fn), - {'online': self._online, 'power': self.power}) + {'online': self._online, 'power': self.power, + 'scale_cost': self._scale_cost}) # Passing self.power in aux_data to be able to condition on it. @classmethod @@ -331,7 +400,7 @@ def tree_unflatten(cls, aux_data, children): return cls(*children[:2], epsilon=eps, cost_fn=fn, **aux_data) def to_LRCGeometry(self, scale=1.0): - """Convert sqEuc. PointCloud to LRCGeometry if useful, and rescale.""" + """Converts sqEuc. PointCloud to LRCGeometry if useful, and rescale.""" if self.is_squared_euclidean: (n, m), d = self.shape, self.x.shape[1] if n * m > (n + m) * d: # here apply_cost using LRCGeometry preferable. @@ -353,7 +422,7 @@ def to_LRCGeometry(self, scale=1.0): cost_2=cost_2, epsilon=self._epsilon_init, relative_epsilon=self._relative_epsilon, - scale=self._scale, + scale=self._scale_epsilon, **self._kwargs) else: self.x *= jnp.sqrt(scale) @@ -363,33 +432,38 @@ def to_LRCGeometry(self, scale=1.0): raise ValueError('Cannot turn non-sq-Euclidean geometry into low-rank') -def _apply_lse_kernel_xy(x, y, norm_x, norm_y, f, g, eps, - vec, cost_fn, cost_pow): - c = _cost(x, y, norm_x, norm_y, cost_fn, cost_pow) +def _apply_lse_kernel_xy( + x, y, norm_x, norm_y, f, g, eps, vec, cost_fn, cost_pow, scale_cost): + c = _cost(x, y, norm_x, norm_y, cost_fn, cost_pow, scale_cost) return ops.logsumexp((f + g - c) / eps, b=vec, return_sign=True, axis=-1) -def _transport_from_potentials_xy(x, y, norm_x, norm_y, f, g, eps, cost_fn, - cost_pow): - return jnp.exp((f + g - _cost(x, y, norm_x, norm_y, cost_fn, cost_pow)) / eps) +def _transport_from_potentials_xy( + x, y, norm_x, norm_y, f, g, eps, cost_fn, cost_pow, scale_cost): + return jnp.exp(( + f + g - _cost(x, y, norm_x, norm_y, cost_fn, cost_pow, scale_cost)) / eps) -def _apply_kernel_xy(x, y, norm_x, norm_y, vec, eps, cost_fn, cost_pow): - c = _cost(x, y, norm_x, norm_y, cost_fn, cost_pow) +def _apply_kernel_xy( + x, y, norm_x, norm_y, vec, eps, cost_fn, cost_pow, scale_cost): + c = _cost(x, y, norm_x, norm_y, cost_fn, cost_pow, scale_cost) return jnp.dot(jnp.exp(-c / eps), vec) -def _transport_from_scalings_xy(x, y, norm_x, norm_y, u, v, eps, cost_fn, - cost_pow): - return jnp.exp(- _cost(x, y, norm_x, norm_y, cost_fn, cost_pow) / eps) * u * v +def _transport_from_scalings_xy( + x, y, norm_x, norm_y, u, v, eps, cost_fn, cost_pow, scale_cost): + return jnp.exp(- _cost(x, y, norm_x, norm_y, cost_fn, cost_pow, scale_cost) + * scale_cost / eps) * u * v -def _cost(x, y, norm_x, norm_y, cost_fn, cost_pow): +def _cost(x, y, norm_x, norm_y, cost_fn, cost_pow, scale_cost): one_line_pairwise = jax.vmap(cost_fn.pairwise, in_axes=[0, None]) - return (norm_x + norm_y + one_line_pairwise(x, y)) ** (0.5 * cost_pow) + return ((norm_x + norm_y + one_line_pairwise(x, y)) ** (0.5 * cost_pow) + * scale_cost) -def _apply_cost_xy(x, y, norm_x, norm_y, vec, cost_fn, cost_pow, fn=None): +def _apply_cost_xy( + x, y, norm_x, norm_y, vec, cost_fn, cost_pow, scale_cost, fn=None): """Applies [num_b, num_a] fn(cost) matrix (or transpose) to vector. Applies [num_b, num_a] ([num_a, num_b] if axis=1 from `apply_cost`) @@ -403,13 +477,14 @@ def _apply_cost_xy(x, y, norm_x, norm_y, vec, cost_fn, cost_pow, fn=None): vec: jnp.ndarray [num_a,] ([num_b,] if axis=1 from `apply_cost`) vector cost_fn: a CostFn function between two points in dimension d. cost_pow: a power to raise (norm(x) + norm(y) + cost(x,y)) ** + scale_cost: scaling factor of the cost matrix. fn: function optionally applied to cost matrix element-wise, before the - apply + apply. Returns: A jnp.ndarray corresponding to cost x vector """ - c = _cost(x, y, norm_x, norm_y, cost_fn, cost_pow) + c = _cost(x, y, norm_x, norm_y, cost_fn, cost_pow, scale_cost) return jnp.dot(c, vec) if fn is None else jnp.dot(fn(c), vec) diff --git a/tests/geometry/scaling_cost_test.py b/tests/geometry/scaling_cost_test.py new file mode 100644 index 000000000..aa2614c7e --- /dev/null +++ b/tests/geometry/scaling_cost_test.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright 2022 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the option to scale the cost matrix.""" + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +import numpy as np +from ott.core import problems +from ott.core import sinkhorn +from ott.core import sinkhorn_lr +from ott.geometry import geometry +from ott.geometry import low_rank +from ott.geometry import pointcloud + + +class ScaleCostTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.rng = jax.random.PRNGKey(0) + self.dim = 4 + self.n = 7 + self.m = 9 + self.rng, *rngs = jax.random.split(self.rng, 8) + self.x = jax.random.uniform(rngs[0], (self.n, self.dim)) + self.y = jax.random.uniform(rngs[1], (self.m, self.dim)) + self.a = jax.random.uniform(rngs[2], (self.n,)) + self.b = jax.random.uniform(rngs[3], (self.m,)) + self.cost = ((self.x[:, None, :] - self.y[None, :, :])**2).sum(-1) + self.vec = jax.random.uniform(rngs[4], (self.m,)) + self.cost1 = jax.random.uniform(rngs[5], (self.n, 2)) + self.cost2 = jax.random.uniform(rngs[6], (self.m, 2)) + + @parameterized.parameters( + ['median', 'mean', 'max_cost', 'max_norm', 'max_bound', 100.]) + def test_scale_cost_pointcloud(self, scale): + """Test various scale cost options for pointcloud.""" + + def apply_sinkhorn(x, y, a, b, scale_cost): + geom = pointcloud.PointCloud( + x, y, epsilon=1e-2, scale_cost=scale_cost) + out = sinkhorn.sinkhorn(geom, a, b) + transport = geom.transport_from_potentials(out.f, out.g) + return geom, out, transport + + geom0, _, _ = apply_sinkhorn( + self.x, self.y, self.a, self.b, scale_cost=1.0) + + geom, out, transport = apply_sinkhorn( + self.x, self.y, self.a, self.b, scale_cost=scale) + + apply_cost_vec = geom.apply_cost(self.vec, axis=1) + apply_transport_vec = geom.apply_transport_from_potentials( + out.f, out.g, self.vec, axis=1) + + np.testing.assert_allclose( + jnp.matmul(transport, self.vec), apply_transport_vec, rtol=1e-4) + np.testing.assert_allclose( + geom0.apply_cost(self.vec, axis=1) * geom.scale_cost, + apply_cost_vec, rtol=1e-4) + + @parameterized.parameters(['max_norm', 'max_bound', 100.]) + def test_scale_cost_pointcloud_online(self, scale): + """Test various scale cost options for point cloud with online option.""" + + def apply_sinkhorn(x, y, a, b, scale_cost): + geom = pointcloud.PointCloud( + x, y, epsilon=1e-2, scale_cost=scale_cost, online=True) + out = sinkhorn.sinkhorn(geom, a, b) + transport = geom.transport_from_potentials(out.f, out.g) + return geom, out, transport + + geom0 = pointcloud.PointCloud( + self.x, self.y, epsilon=1e-2, scale_cost=1.0, online=True) + + geom, out, transport = apply_sinkhorn( + self.x, self.y, self.a, self.b, scale_cost=scale) + + apply_cost_vec = geom.apply_cost(self.vec, axis=1) + apply_transport_vec = geom.apply_transport_from_potentials( + out.f, out.g, self.vec, axis=1) + + np.testing.assert_allclose( + jnp.matmul(transport, self.vec), apply_transport_vec, rtol=1e-4) + np.testing.assert_allclose( + geom0.apply_cost(self.vec, axis=1) * geom.scale_cost, + apply_cost_vec, rtol=1e-4) + + @parameterized.parameters(['median', 'mean', 'max_cost', 100.]) + def test_scale_cost_geometry(self, scale): + """Test various scale cost options for geometry.""" + + def apply_sinkhorn(cost, a, b, scale_cost): + geom = geometry.Geometry(cost, epsilon=1e-2, scale_cost=scale_cost) + out = sinkhorn.sinkhorn(geom, a, b) + transport = geom.transport_from_potentials(out.f, out.g) + return geom, out, transport + + geom0 = geometry.Geometry(self.cost, epsilon=1e-2, scale_cost=1.0) + + geom, out, transport = apply_sinkhorn( + self.cost, self.a, self.b, scale_cost=scale) + + apply_cost_vec = geom.apply_cost(self.vec, axis=1) + apply_transport_vec = geom.apply_transport_from_potentials( + out.f, out.g, self.vec, axis=1) + + np.testing.assert_allclose( + jnp.matmul(transport, self.vec), apply_transport_vec, rtol=1e-4) + np.testing.assert_allclose( + geom0.apply_cost(self.vec, axis=1) * geom.scale_cost, + apply_cost_vec, rtol=1e-4) + + @parameterized.parameters(['max_bound', 100.]) + def test_scale_cost_low_rank(self, scale): + """Test various scale cost options for low rank.""" + + def apply_sinkhorn(cost1, cost2, scale_cost): + geom = low_rank.LRCGeometry(cost1, cost2, scale_cost=scale_cost) + ot_prob = problems.LinearProblem(geom, self.a, self.b) + solver = sinkhorn_lr.LRSinkhorn(threshold=1e-3, rank=10) + out = solver(ot_prob) + return geom, out + + geom0 = low_rank.LRCGeometry(self.cost1, self.cost2, scale_cost=1.0) + + geom, out = apply_sinkhorn( + self.cost1, self.cost2, scale_cost=scale) + + apply_cost_vec = geom._apply_cost_to_vec(self.vec, axis=1) + apply_transport_vec = out.apply(self.vec, axis=1) + transport = out.matrix + + np.testing.assert_allclose( + jnp.matmul(transport, self.vec), apply_transport_vec, rtol=1e-4) + np.testing.assert_allclose( + geom0._apply_cost_to_vec(self.vec, axis=1) * geom.scale_cost, + apply_cost_vec, rtol=1e-4) + + +if __name__ == '__main__': + absltest.main() +