Skip to content

Commit

Permalink
Scaling factor for the cost matrix in online mode.
Browse files Browse the repository at this point in the history
Adding scaling factors for the cost matrix for LR.

PiperOrigin-RevId: 436846816
  • Loading branch information
LaetitiaPapaxanthos committed Mar 23, 2022
1 parent c855e02 commit 256b2d2
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 28 deletions.
2 changes: 2 additions & 0 deletions ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def scale_cost(self):
return jax.lax.stop_gradient(1.0 / jnp.mean(self._cost_matrix))
elif self._scale_cost == 'median':
return jax.lax.stop_gradient(1.0 / jnp.median(self._cost_matrix))
elif isinstance(self._scale_cost, str):
raise ValueError(f'Scaling {self._scale_cost} not implemented.')
else:
return 1.0

Expand Down
30 changes: 17 additions & 13 deletions ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def __init__(self,
self._cost_1 = cost_1
self._cost_2 = cost_2
self._bias = bias
self._scale_cost = scale_cost
self._kwargs = kwargs

super().__init__(**kwargs)
self._scale_cost = scale_cost

@property
def cost_1(self):
Expand All @@ -67,38 +67,42 @@ def bias(self):

@property
def cost_rank(self):
return self.cost_1.shape[1]
return self._cost_1.shape[1]

@property
def cost_matrix(self):
"""Returns cost matrix if requested."""
return (
jnp.matmul(self.cost_1, self.cost_2.T) + self.bias) * self.scale_cost
return (jnp.matmul(self.cost_1, self.cost_2.T) + self.bias)

@property
def shape(self):
return (self.cost_1.shape[0], self.cost_2.shape[0])
return (self._cost_1.shape[0], self._cost_2.shape[0])

@property
def is_symmetric(self):
return (self.cost_1.shape[0] == self.cost_2.shape[0] and
jnp.all(self.cost_1 == self.cost_2))
return (self._cost_1.shape[0] == self._cost_2.shape[0] and
jnp.all(self._cost_1 == self._cost_2))

@property
def scale_cost(self):
if isinstance(self._scale_cost, float):
return self._scale_cost
elif self._scale_cost == 'max_bound':
return jax.lax.stop_gradient(
1.0 / (jnp.max(jnp.abs(self.cost_1))
* jnp.max(jnp.abs(self.cost_2))
+ jnp.abs(self.bias)))
1.0 / (jnp.max(jnp.abs(self._cost_1))
* jnp.max(jnp.abs(self._cost_2))
+ jnp.abs(self._bias)))
elif self._scale_cost == 'mean':
# TODO(lpapaxanthos): implement memory efficient mean.
return 1.0
factor1 = jnp.dot(jnp.ones(self.shape[0]), self._cost_1)
factor2 = jnp.dot(self._cost_2.T, jnp.ones(self.shape[1]))
mean = (jnp.dot(factor1, factor2) / (self.shape[0] * self.shape[1])
+ self._bias)
return 1.0 / mean
elif self._scale_cost == 'max_cost':
# TODO(lpapaxanthos): implement memory efficient max.
return 1.0
raise NotImplementedError(f'Scaling {self._scale_cost} not implemented.')
elif isinstance(self._scale_cost, str):
raise ValueError(f'Scaling {self._scale_cost} not implemented.')
else:
return 1.0

Expand Down
104 changes: 98 additions & 6 deletions ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,12 @@ def scale_cost(self):
return 1.0 / self._scale_cost
elif self._scale_cost == 'max_cost':
if self.is_online:
# TODO(lpapaxanthos): implement memory efficient max.
return 1.0
return self.compute_summary_online(self._scale_cost)
else:
return jax.lax.stop_gradient(1.0 / jnp.max(self.compute_cost_matrix()))
elif self._scale_cost == 'mean':
if self.is_online:
# TODO(lpapaxanthos): implement memory efficient mean.
return 1.0
return self.compute_summary_online(self._scale_cost)
else:
if isinstance(self.shape[0], int) and (self.shape[0] > 0):
return jax.lax.stop_gradient(
Expand Down Expand Up @@ -184,6 +182,8 @@ def scale_cost(self):
return jax.lax.stop_gradient(1.0 / max_bound)
else:
return 1.0
elif isinstance(self._scale_cost, str):
raise ValueError(f'Scaling {self._scale_cost} not implemented.')
else:
return 1.0

Expand Down Expand Up @@ -246,10 +246,10 @@ def finalize(i: int):
)

if axis == 0:
fun, size = body0, self.shape[1]
fun = body0
v, n = g, self._y_nsplit
elif axis == 1:
fun, size = body1, self.shape[0]
fun = body1
v, n = f, self._x_nsplit
else:
raise ValueError(axis)
Expand Down Expand Up @@ -381,6 +381,93 @@ def vec_apply_cost(self,
return (fn(applied_cost) * self.scale_cost if fn
else applied_cost * self.scale_cost)

def leading_slice(self, t: jnp.ndarray, i: int) -> jnp.ndarray:
start_indices = [i * self._bs] + (t.ndim - 1) * [0]
slice_sizes = [self._bs] + list(t.shape[1:])
return jax.lax.dynamic_slice(t, start_indices, slice_sizes)

def compute_summary_online(self, summary: str) -> float:
"""Compute mean or max of cost matrix online, i.e. without instantiating it.
Args:
summary: str, can be 'mean' or 'max_cost'
Returns:
summary statistics
"""
scale_cost = 1.0

def body0(carry, i: int):
vec, = carry
y = self.leading_slice(self.y, i)
if self._axis_norm is None:
norm_y = self._norm_y
else:
norm_y = self.leading_slice(self._norm_y, i)
h_res = app(
self.x, y, self._norm_x, norm_y, vec,
self._cost_fn, self.power, scale_cost)
return carry, (h_res,)

def body1(carry, i: int):
vec, = carry
x = self.leading_slice(self.x, i)
if self._axis_norm is None:
norm_x = self._norm_x
else:
norm_x = self.leading_slice(self._norm_x, i)
h_res = app(
self.y, x, self._norm_y, norm_x, vec,
self._cost_fn, self.power, scale_cost)
return carry, (h_res,)

def finalize(i: int):
if batch_for_y:
norm_y = self._norm_y if self._axis_norm is None else self._norm_y[i:]
return app(
self.x, self.y[i:], self._norm_x, norm_y, vec,
self._cost_fn, self.power, scale_cost)
norm_x = self._norm_x if self._axis_norm is None else self._norm_x[i:]
return app(
self.y, self.x[i:], self._norm_y, norm_x, vec,
self._cost_fn, self.power, scale_cost)

if summary == 'mean':
fn = _apply_cost_xy
elif summary == 'max_cost':
fn = _apply_max_xy
else:
raise ValueError(
f'Scaling method {summary} does not exist for online mode.')
app = jax.vmap(
fn,
in_axes=[None, 0, None, self._axis_norm, None, None, None, None]
)

batch_for_y = self.shape[0] < self.shape[1]
if batch_for_y:
fun = body0
n = self._y_nsplit
vec = jnp.ones(self.shape[0]) / (self.shape[1] * self.shape[0])
else:
fun = body1
n = self._x_nsplit
vec = jnp.ones(self.shape[1]) / (self.shape[1] * self.shape[0])

_, val = jax.lax.scan(
fun, init=(vec,), xs=jnp.arange(n))
val = jnp.concatenate(val).squeeze()
val_rest = finalize(n * self._bs)
val_res = jnp.concatenate([val, val_rest])

if summary == 'mean':
return 1.0 / jnp.sum(val_res)
elif summary == 'max_cost':
return 1.0 / jnp.max(val_res)
else:
raise ValueError(
f'Scaling method {summary} does not exist for online mode.')

@classmethod
def prepare_divergences(cls, *args, static_b: bool = False, **kwargs):
"""Instantiates the geometries used for a divergence computation."""
Expand Down Expand Up @@ -488,3 +575,8 @@ def _apply_cost_xy(
return jnp.dot(c, vec) if fn is None else jnp.dot(fn(c), vec)


def _apply_max_xy(
x, y, norm_x, norm_y, vec, cost_fn, cost_pow, scale_cost):
del vec
c = _cost(x, y, norm_x, norm_y, cost_fn, cost_pow, scale_cost)
return jnp.max(jnp.abs(c))
37 changes: 28 additions & 9 deletions tests/geometry/scaling_cost_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def setUp(self):
self.vec = jax.random.uniform(rngs[4], (self.m,))
self.cost1 = jax.random.uniform(rngs[5], (self.n, 2))
self.cost2 = jax.random.uniform(rngs[6], (self.m, 2))
self.eps = 5e-2

@parameterized.parameters(
['median', 'mean', 'max_cost', 'max_norm', 'max_bound', 100.])
Expand All @@ -53,15 +54,15 @@ def test_scale_cost_pointcloud(self, scale):

def apply_sinkhorn(x, y, a, b, scale_cost):
geom = pointcloud.PointCloud(
x, y, epsilon=1e-2, scale_cost=scale_cost)
x, y, epsilon=self.eps, scale_cost=scale_cost)
out = sinkhorn.sinkhorn(geom, a, b)
transport = geom.transport_from_potentials(out.f, out.g)
return geom, out, transport

geom0, _, _ = apply_sinkhorn(
self.x, self.y, self.a, self.b, scale_cost=1.0)

geom, out, transport = apply_sinkhorn(
geom, out, transport = jax.jit(apply_sinkhorn, static_argnums=4)(
self.x, self.y, self.a, self.b, scale_cost=scale)

apply_cost_vec = geom.apply_cost(self.vec, axis=1)
Expand All @@ -74,19 +75,19 @@ def apply_sinkhorn(x, y, a, b, scale_cost):
geom0.apply_cost(self.vec, axis=1) * geom.scale_cost,
apply_cost_vec, rtol=1e-4)

@parameterized.parameters(['max_norm', 'max_bound', 100.])
@parameterized.parameters(['mean', 'max_cost', 'max_norm', 'max_bound', 100.])
def test_scale_cost_pointcloud_online(self, scale):
"""Test various scale cost options for point cloud with online option."""

def apply_sinkhorn(x, y, a, b, scale_cost):
geom = pointcloud.PointCloud(
x, y, epsilon=1e-2, scale_cost=scale_cost, online=True)
x, y, epsilon=self.eps, scale_cost=scale_cost, online=True)
out = sinkhorn.sinkhorn(geom, a, b)
transport = geom.transport_from_potentials(out.f, out.g)
return geom, out, transport

geom0 = pointcloud.PointCloud(
self.x, self.y, epsilon=1e-2, scale_cost=1.0, online=True)
self.x, self.y, epsilon=self.eps, scale_cost=1.0, online=True)

geom, out, transport = apply_sinkhorn(
self.x, self.y, self.a, self.b, scale_cost=scale)
Expand All @@ -101,12 +102,27 @@ def apply_sinkhorn(x, y, a, b, scale_cost):
geom0.apply_cost(self.vec, axis=1) * geom.scale_cost,
apply_cost_vec, rtol=1e-4)

@parameterized.parameters(['mean', 'max_cost', 'max_norm', 'max_bound', 100.])
def test_online_matches_notonline_pointcloud(self, scale):
"""Tests that the scale factors for online matches the ones without."""
geom0 = pointcloud.PointCloud(
self.x, self.y, epsilon=self.eps, scale_cost=scale, online=True)
geom1 = pointcloud.PointCloud(
self.x, self.y, epsilon=self.eps, scale_cost=scale, online=None)
np.testing.assert_allclose(geom0.scale_cost, geom1.scale_cost, rtol=1e-4)
if scale == 'mean':
np.testing.assert_allclose(
1.0, geom1.cost_matrix.mean(), rtol=1e-4)
elif scale == 'max_cost':
np.testing.assert_allclose(
1.0, geom1.cost_matrix.max(), rtol=1e-4)

@parameterized.parameters(['median', 'mean', 'max_cost', 100.])
def test_scale_cost_geometry(self, scale):
"""Test various scale cost options for geometry."""

def apply_sinkhorn(cost, a, b, scale_cost):
geom = geometry.Geometry(cost, epsilon=1e-2, scale_cost=scale_cost)
geom = geometry.Geometry(cost, epsilon=self.eps, scale_cost=scale_cost)
out = sinkhorn.sinkhorn(geom, a, b)
transport = geom.transport_from_potentials(out.f, out.g)
return geom, out, transport
Expand All @@ -126,7 +142,7 @@ def apply_sinkhorn(cost, a, b, scale_cost):
geom0.apply_cost(self.vec, axis=1) * geom.scale_cost,
apply_cost_vec, rtol=1e-4)

@parameterized.parameters(['max_bound', 100.])
@parameterized.parameters(['mean', 'max_bound', 100.])
def test_scale_cost_low_rank(self, scale):
"""Test various scale cost options for low rank."""

Expand All @@ -139,8 +155,7 @@ def apply_sinkhorn(cost1, cost2, scale_cost):

geom0 = low_rank.LRCGeometry(self.cost1, self.cost2, scale_cost=1.0)

geom, out = apply_sinkhorn(
self.cost1, self.cost2, scale_cost=scale)
geom, out = apply_sinkhorn(self.cost1, self.cost2, scale_cost=scale)

apply_cost_vec = geom._apply_cost_to_vec(self.vec, axis=1)
apply_transport_vec = out.apply(self.vec, axis=1)
Expand All @@ -152,6 +167,10 @@ def apply_sinkhorn(cost1, cost2, scale_cost):
geom0._apply_cost_to_vec(self.vec, axis=1) * geom.scale_cost,
apply_cost_vec, rtol=1e-4)

if scale == 'mean':
np.testing.assert_allclose(
1.0, geom.cost_matrix.mean(), rtol=1e-4)


if __name__ == '__main__':
absltest.main()
Expand Down

0 comments on commit 256b2d2

Please sign in to comment.