Skip to content

Commit

Permalink
Merge pull request #23 from michalk8/feature/apply-lse-batch
Browse files Browse the repository at this point in the history
Batch `apply_lse_kernel` for `online=True`
  • Loading branch information
marcocuturi authored Mar 2, 2022
2 parents ef31e76 + 9764b0d commit f4fceab
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 20 deletions.
94 changes: 74 additions & 20 deletions ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

# Lint as: python3
"""A geometry defined using 2 point clouds and a cost function between them."""
from typing import Optional
from typing import Optional, Union

import math

import jax
import jax.numpy as jnp
Expand All @@ -34,7 +36,7 @@ def __init__(self,
y: Optional[jnp.ndarray] = None,
cost_fn: Optional[costs.CostFn] = None,
power: float = 2.0,
online: bool = False,
online: Optional[Union[bool, int]] = None,
**kwargs):
"""Creates a geometry from two point clouds, using CostFn.
Expand All @@ -54,7 +56,8 @@ def __init__(self,
power: a power to raise (norm(x) + norm(y) + cost(x,y)) **
online: whether to run the online version of the computation or not. The
online computation is particularly useful for big point clouds such that
their cost matrix does not fit in memory.
their cost matrix does not fit in memory. This is done by batching
:meth:`apply_lse_kernel`. If `True`, batch size of 1024 is used.
**kwargs: other optional parameters to be passed on to superclass
initializer, notably those related to epsilon regularization.
"""
Expand All @@ -63,11 +66,22 @@ def __init__(self,
self._axis_norm = 0 if callable(self._cost_fn.norm) else None

self.x = x
self.y = y if y is not None else x
self.y = self.x if y is None else y

if online is True:
online = 1024
if online:
assert online > 0, f"`online={online}` must be positive."
n, m = self.shape
self._bs = min(online, online, *(() + ((n,) if n else ()) + ((m,) if m else ())))
# use `floor` instead of `ceil` and handle the rest seperately
self._x_nsplit = int(math.floor(n / self._bs))
self._y_nsplit = int(math.floor(m / self._bs))
else:
self._bs = self._x_nsplit = self._y_nsplit = None

self.power = power
self._online = online

self.power = power
super().__init__(**kwargs)

@property
Expand Down Expand Up @@ -101,8 +115,13 @@ def kernel_matrix(self):

@property
def shape(self):
return (self.x.shape[0] if self.x is not None else 0,
self.y.shape[0] if self.y is not None else 0)
# in the process of flattening/unflattening in vmap, `__init__` can be called with dummy objects
# we optionally access `shape` in order to get the batch size
try:
return (self.x.shape[0] if self.x is not None else 0,
self.y.shape[0] if self.y is not None else 0)
except AttributeError:
return 0, 0

@property
def is_symmetric(self):
Expand All @@ -115,32 +134,67 @@ def is_squared_euclidean(self):

@property
def is_online(self) -> bool:
return self._online
return self._online is not None

def apply_lse_kernel(self,
f: jnp.ndarray,
g: jnp.ndarray,
eps: float,
vec: jnp.ndarray = None,
axis: int = 0) -> jnp.ndarray:
def body0(carry, i: int):
f, g, eps, vec = carry
y = jax.lax.dynamic_slice(self.y, (i * self._bs, 0), (self._bs, self.y.shape[1]))
g_ = jax.lax.dynamic_slice(g, (i * self._bs,), (self._bs,))
if self._axis_norm is None:
norm_y = self._norm_y
else:
norm_y = jax.lax.dynamic_slice(self._norm_y, (i * self._bs,), (self._bs,))
h_res, h_sgn = app(self.x, y, self._norm_x, norm_y, f, g_, eps, vec, self._cost_fn, self.power)
return carry, (h_res, h_sgn)

def body1(carry, i: int):
f, g, eps, vec = carry
x = jax.lax.dynamic_slice(self.x, (i * self._bs, 0), (self._bs, self.x.shape[1]))
f_ = jax.lax.dynamic_slice(f, (i * self._bs,), (self._bs,))
if self._axis_norm is None:
norm_x = self._norm_x
else:
norm_x = jax.lax.dynamic_slice(self._norm_x, (i * self._bs,), (self._bs,))
h_res, h_sgn = app(self.y, x, self._norm_y, norm_x, g, f_, eps, vec, self._cost_fn, self.power)
return carry, (h_res, h_sgn)

def finalize(i: int):
if axis == 0:
norm_y = self._norm_y if self._axis_norm is None else self._norm_y[i:]
return app(self.x, self.y[i:], self._norm_x, norm_y, f, g[i:], eps, vec, self._cost_fn, self.power)
norm_x = self._norm_x if self._axis_norm is None else self._norm_x[i:]
return app(self.y, self.x[i:], self._norm_y, norm_x, g, f[i:], eps, vec, self._cost_fn, self.power)

if not self._online:
return super().apply_lse_kernel(f, g, eps, vec, axis)

app = jax.vmap(
_apply_lse_kernel_xy,
in_axes=[
None, 0, None, self._axis_norm, None, 0, None, None, None, None
])
in_axes=[None, 0, None, self._axis_norm, None, 0, None, None, None, None]
)

if axis == 0:
h_res, h_sgn = app(self.x, self.y, self._norm_x, self._norm_y, f, g, eps,
vec, self._cost_fn, self.power)
h_res = eps * h_res - jnp.where(jnp.isfinite(g), g, 0)
if axis == 1:
h_res, h_sgn = app(self.y, self.x, self._norm_y, self._norm_x, g, f, eps,
vec, self._cost_fn, self.power)
h_res = eps * h_res - jnp.where(jnp.isfinite(f), f, 0)
return h_res, h_sgn
fun, size = body0, self.shape[1]
v, n = g, self._y_nsplit
elif axis == 1:
fun, size = body1, self.shape[0]
v, n = f, self._x_nsplit
else:
raise ValueError(axis)

_, (h_res, h_sign) = jax.lax.scan(fun, init=(f, g, eps, vec), xs=jnp.arange(n))
h_res, h_sign = jnp.concatenate(h_res), jnp.concatenate(h_sign)
h_res_rest, h_sign_rest = finalize(n * self._bs)
h_res = jnp.concatenate([h_res, h_res_rest])
h_sign = jnp.concatenate([h_sign, h_sign_rest])

return eps * h_res - jnp.where(jnp.isfinite(v), v, 0), h_sign

def apply_kernel(self,
scaling: jnp.ndarray,
Expand Down
71 changes: 71 additions & 0 deletions tests/core/sinkhorn_online_large_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@

# Lint as: python3
"""Tests Online option for PointCloud geometry."""
from functools import partial
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import jax.test_util
import numpy as np
from ott.core import sinkhorn
from ott.core.sinkhorn import SinkhornOutput
from ott.geometry import pointcloud


Expand Down Expand Up @@ -59,6 +62,74 @@ def test_euclidean_point_cloud(self, lse_mode):
err = errors[errors > -1][-1]
self.assertGreater(threshold, err)

@parameterized.parameters([1], [13], [402], [4000])
def test_online_matches_offline_size(self, online: int):
threshold, rtol, atol = 1e-1, 1e-6, 1e-6
geom_offline = pointcloud.PointCloud(self.x, self.y, epsilon=1, online=False)
geom_online = pointcloud.PointCloud(self.x, self.y, epsilon=1, online=online)

sol_online = sinkhorn.sinkhorn(
geom_online,
a=self.a,
b=self.b,
threshold=threshold,
lse_mode=True,
implicit_differentiation=True
)
errors_online = sol_online.errors
err_online = errors_online[errors_online > -1][-1]

sol_offline = sinkhorn.sinkhorn(
geom_offline,
a=self.a,
b=self.b,
threshold=threshold,
lse_mode=True,
implicit_differentiation=True
)

self.assertGreater(threshold, err_online)
np.testing.assert_allclose(sol_online.matrix, sol_offline.matrix, rtol=rtol, atol=atol)
np.testing.assert_allclose(sol_online.a, sol_offline.a, rtol=rtol, atol=atol)
np.testing.assert_allclose(sol_online.b, sol_offline.b, rtol=rtol, atol=atol)

def test_online_sinkhorn_jit(self):
threshold = 1e-1
geom = pointcloud.PointCloud(self.x, self.y, epsilon=1, online=512)
errors = sinkhorn.sinkhorn(
geom,
a=self.a,
b=self.b,
threshold=threshold,
jit=True,
lse_mode=True,
implicit_differentiation=True
).errors
err = errors[errors > -1][-1]

self.assertGreater(threshold, err)

def test_online_external_jit(self):
@partial(jax.jit, static_argnums=1)
def callback(epsilon: float, online: int) -> SinkhornOutput:
geom = pointcloud.PointCloud(self.x, self.y, epsilon=epsilon, online=online)
return sinkhorn.sinkhorn(
geom,
a=self.a,
b=self.b,
threshold=threshold,
jit=True,
lse_mode=True,
implicit_differentiation=True
)

threshold = 1e-1
sol = callback(epsilon=1, online=42)
errors = sol.errors
err = errors[errors > -1][-1]

self.assertGreater(threshold, err)


if __name__ == '__main__':
absltest.main()

0 comments on commit f4fceab

Please sign in to comment.