diff --git a/ott/geometry/pointcloud.py b/ott/geometry/pointcloud.py index f42827a62..25cc21930 100644 --- a/ott/geometry/pointcloud.py +++ b/ott/geometry/pointcloud.py @@ -15,7 +15,9 @@ # Lint as: python3 """A geometry defined using 2 point clouds and a cost function between them.""" -from typing import Optional +from typing import Optional, Union + +import math import jax import jax.numpy as jnp @@ -34,7 +36,7 @@ def __init__(self, y: Optional[jnp.ndarray] = None, cost_fn: Optional[costs.CostFn] = None, power: float = 2.0, - online: bool = False, + online: Optional[Union[bool, int]] = None, **kwargs): """Creates a geometry from two point clouds, using CostFn. @@ -54,7 +56,8 @@ def __init__(self, power: a power to raise (norm(x) + norm(y) + cost(x,y)) ** online: whether to run the online version of the computation or not. The online computation is particularly useful for big point clouds such that - their cost matrix does not fit in memory. + their cost matrix does not fit in memory. This is done by batching + :meth:`apply_lse_kernel`. If `True`, batch size of 1024 is used. **kwargs: other optional parameters to be passed on to superclass initializer, notably those related to epsilon regularization. """ @@ -63,11 +66,22 @@ def __init__(self, self._axis_norm = 0 if callable(self._cost_fn.norm) else None self.x = x - self.y = y if y is not None else x + self.y = self.x if y is None else y + + if online is True: + online = 1024 + if online: + assert online > 0, f"`online={online}` must be positive." + n, m = self.shape + self._bs = min(online, online, *(() + ((n,) if n else ()) + ((m,) if m else ()))) + # use `floor` instead of `ceil` and handle the rest seperately + self._x_nsplit = int(math.floor(n / self._bs)) + self._y_nsplit = int(math.floor(m / self._bs)) + else: + self._bs = self._x_nsplit = self._y_nsplit = None - self.power = power self._online = online - + self.power = power super().__init__(**kwargs) @property @@ -101,8 +115,13 @@ def kernel_matrix(self): @property def shape(self): - return (self.x.shape[0] if self.x is not None else 0, - self.y.shape[0] if self.y is not None else 0) + # in the process of flattening/unflattening in vmap, `__init__` can be called with dummy objects + # we optionally access `shape` in order to get the batch size + try: + return (self.x.shape[0] if self.x is not None else 0, + self.y.shape[0] if self.y is not None else 0) + except AttributeError: + return 0, 0 @property def is_symmetric(self): @@ -115,7 +134,7 @@ def is_squared_euclidean(self): @property def is_online(self) -> bool: - return self._online + return self._online is not None def apply_lse_kernel(self, f: jnp.ndarray, @@ -123,24 +142,59 @@ def apply_lse_kernel(self, eps: float, vec: jnp.ndarray = None, axis: int = 0) -> jnp.ndarray: + def body0(carry, i: int): + f, g, eps, vec = carry + y = jax.lax.dynamic_slice(self.y, (i * self._bs, 0), (self._bs, self.y.shape[1])) + g_ = jax.lax.dynamic_slice(g, (i * self._bs,), (self._bs,)) + if self._axis_norm is None: + norm_y = self._norm_y + else: + norm_y = jax.lax.dynamic_slice(self._norm_y, (i * self._bs,), (self._bs,)) + h_res, h_sgn = app(self.x, y, self._norm_x, norm_y, f, g_, eps, vec, self._cost_fn, self.power) + return carry, (h_res, h_sgn) + + def body1(carry, i: int): + f, g, eps, vec = carry + x = jax.lax.dynamic_slice(self.x, (i * self._bs, 0), (self._bs, self.x.shape[1])) + f_ = jax.lax.dynamic_slice(f, (i * self._bs,), (self._bs,)) + if self._axis_norm is None: + norm_x = self._norm_x + else: + norm_x = jax.lax.dynamic_slice(self._norm_x, (i * self._bs,), (self._bs,)) + h_res, h_sgn = app(self.y, x, self._norm_y, norm_x, g, f_, eps, vec, self._cost_fn, self.power) + return carry, (h_res, h_sgn) + + def finalize(i: int): + if axis == 0: + norm_y = self._norm_y if self._axis_norm is None else self._norm_y[i:] + return app(self.x, self.y[i:], self._norm_x, norm_y, f, g[i:], eps, vec, self._cost_fn, self.power) + norm_x = self._norm_x if self._axis_norm is None else self._norm_x[i:] + return app(self.y, self.x[i:], self._norm_y, norm_x, g, f[i:], eps, vec, self._cost_fn, self.power) + if not self._online: return super().apply_lse_kernel(f, g, eps, vec, axis) app = jax.vmap( _apply_lse_kernel_xy, - in_axes=[ - None, 0, None, self._axis_norm, None, 0, None, None, None, None - ]) + in_axes=[None, 0, None, self._axis_norm, None, 0, None, None, None, None] + ) if axis == 0: - h_res, h_sgn = app(self.x, self.y, self._norm_x, self._norm_y, f, g, eps, - vec, self._cost_fn, self.power) - h_res = eps * h_res - jnp.where(jnp.isfinite(g), g, 0) - if axis == 1: - h_res, h_sgn = app(self.y, self.x, self._norm_y, self._norm_x, g, f, eps, - vec, self._cost_fn, self.power) - h_res = eps * h_res - jnp.where(jnp.isfinite(f), f, 0) - return h_res, h_sgn + fun, size = body0, self.shape[1] + v, n = g, self._y_nsplit + elif axis == 1: + fun, size = body1, self.shape[0] + v, n = f, self._x_nsplit + else: + raise ValueError(axis) + + _, (h_res, h_sign) = jax.lax.scan(fun, init=(f, g, eps, vec), xs=jnp.arange(n)) + h_res, h_sign = jnp.concatenate(h_res), jnp.concatenate(h_sign) + h_res_rest, h_sign_rest = finalize(n * self._bs) + h_res = jnp.concatenate([h_res, h_res_rest]) + h_sign = jnp.concatenate([h_sign, h_sign_rest]) + + return eps * h_res - jnp.where(jnp.isfinite(v), v, 0), h_sign def apply_kernel(self, scaling: jnp.ndarray, diff --git a/tests/core/sinkhorn_online_large_test.py b/tests/core/sinkhorn_online_large_test.py index a5def3619..3ce7e5f00 100644 --- a/tests/core/sinkhorn_online_large_test.py +++ b/tests/core/sinkhorn_online_large_test.py @@ -15,12 +15,15 @@ # Lint as: python3 """Tests Online option for PointCloud geometry.""" +from functools import partial from absl.testing import absltest from absl.testing import parameterized import jax import jax.numpy as jnp import jax.test_util +import numpy as np from ott.core import sinkhorn +from ott.core.sinkhorn import SinkhornOutput from ott.geometry import pointcloud @@ -59,6 +62,74 @@ def test_euclidean_point_cloud(self, lse_mode): err = errors[errors > -1][-1] self.assertGreater(threshold, err) + @parameterized.parameters([1], [13], [402], [4000]) + def test_online_matches_offline_size(self, online: int): + threshold, rtol, atol = 1e-1, 1e-6, 1e-6 + geom_offline = pointcloud.PointCloud(self.x, self.y, epsilon=1, online=False) + geom_online = pointcloud.PointCloud(self.x, self.y, epsilon=1, online=online) + + sol_online = sinkhorn.sinkhorn( + geom_online, + a=self.a, + b=self.b, + threshold=threshold, + lse_mode=True, + implicit_differentiation=True + ) + errors_online = sol_online.errors + err_online = errors_online[errors_online > -1][-1] + + sol_offline = sinkhorn.sinkhorn( + geom_offline, + a=self.a, + b=self.b, + threshold=threshold, + lse_mode=True, + implicit_differentiation=True + ) + + self.assertGreater(threshold, err_online) + np.testing.assert_allclose(sol_online.matrix, sol_offline.matrix, rtol=rtol, atol=atol) + np.testing.assert_allclose(sol_online.a, sol_offline.a, rtol=rtol, atol=atol) + np.testing.assert_allclose(sol_online.b, sol_offline.b, rtol=rtol, atol=atol) + + def test_online_sinkhorn_jit(self): + threshold = 1e-1 + geom = pointcloud.PointCloud(self.x, self.y, epsilon=1, online=512) + errors = sinkhorn.sinkhorn( + geom, + a=self.a, + b=self.b, + threshold=threshold, + jit=True, + lse_mode=True, + implicit_differentiation=True + ).errors + err = errors[errors > -1][-1] + + self.assertGreater(threshold, err) + + def test_online_external_jit(self): + @partial(jax.jit, static_argnums=1) + def callback(epsilon: float, online: int) -> SinkhornOutput: + geom = pointcloud.PointCloud(self.x, self.y, epsilon=epsilon, online=online) + return sinkhorn.sinkhorn( + geom, + a=self.a, + b=self.b, + threshold=threshold, + jit=True, + lse_mode=True, + implicit_differentiation=True + ) + + threshold = 1e-1 + sol = callback(epsilon=1, online=42) + errors = sol.errors + err = errors[errors > -1][-1] + + self.assertGreater(threshold, err) + if __name__ == '__main__': absltest.main()