diff --git a/ott/core/layers.py b/ott/core/layers.py index c289fff07..2e7bb3fa8 100644 --- a/ott/core/layers.py +++ b/ott/core/layers.py @@ -50,10 +50,14 @@ class PositiveDense(nn.Module): bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros def setup(self): - if round(self.inv_rectifier_fn(self.rectifier_fn(0.1)), 3) != 0.1: - raise RuntimeError( - "Make sure both rectifier and inverse are defined properly." - ) + try: + if round(self.inv_rectifier_fn(self.rectifier_fn(0.1)), 3) != 0.1: + raise RuntimeError( + "Make sure both rectifier and inverse are defined properly." + ) + except TypeError as e: + if "doesn't define __round__ method" not in str(e): + raise # not comparing tracer values, raise @nn.compact def __call__(self, inputs): diff --git a/ott/core/quad_problems.py b/ott/core/quad_problems.py index a475d8fc7..fe40b4da6 100644 --- a/ott/core/quad_problems.py +++ b/ott/core/quad_problems.py @@ -538,5 +538,4 @@ def update_epsilon_unbalanced(epsilon, transport_mass): def apply_cost( geom: geometry.Geometry, arr: jnp.ndarray, *, axis: int, fn: Loss ) -> jnp.ndarray: - # TODO(michalk8): handle PCs return geom.apply_cost(arr, axis=axis, fn=fn.func, is_linear=fn.is_linear) diff --git a/ott/geometry/geometry.py b/ott/geometry/geometry.py index fc7f2613f..da7cb18a2 100644 --- a/ott/geometry/geometry.py +++ b/ott/geometry/geometry.py @@ -15,7 +15,7 @@ # Lint as: python3 """A class describing operations used to instantiate and use a geometry.""" import functools -from typing import Any, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -534,7 +534,7 @@ def apply_cost( self, arr: jnp.ndarray, axis: int = 0, - fn=None, + fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, **kwargs: Any ) -> jnp.ndarray: """Apply cost matrix to array (vector or matrix). diff --git a/ott/geometry/low_rank.py b/ott/geometry/low_rank.py index 8e30b43c8..b99901691 100644 --- a/ott/geometry/low_rank.py +++ b/ott/geometry/low_rank.py @@ -122,10 +122,10 @@ def inv_scale_cost(self) -> float: def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray: """Apply elementwise-square of cost matrix to array (vector or matrix).""" (n, m), r = self.shape, self.cost_rank - # When applying square of a LRCgeometry, one can either elementwise square + # When applying square of a LRCGeometry, one can either elementwise square # the cost matrix, or instantiate an augmented (rank^2) LRCGeometry # and apply it. First is O(nm), the other is O((n+m)r^2). - if n * m < (n + m) * r ** 2: # better use regular apply + if n * m < (n + m) * r ** 2: # better use regular apply return super().apply_square_cost(arr, axis) else: new_cost_1 = self.cost_1[:, :, None] * self.cost_1[:, None, :] @@ -140,7 +140,7 @@ def _apply_cost_to_vec( vec: jnp.ndarray, axis: int = 0, fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, - is_linear: Optional[bool] = None, + is_linear: bool = False, ) -> jnp.ndarray: """Apply [num_a, num_b] fn(cost) (or transpose) to vector. @@ -149,9 +149,9 @@ def _apply_cost_to_vec( axis: axis on which the reduction is done. fn: function optionally applied to cost matrix element-wise, before the doc product - is_linear: Whether ``fn`` is a linear function. If yes, efficient - implementation is used. If ``None``, it will be determined by - :func:`ott.geometry.geometry.is_linear` at runtime. + is_linear: Whether ``fn`` is a linear function to enable efficient + implementation. See :func:`ott.geometry.geometry.is_linear` + for a heuristic to help determine if a function is linear. Returns: A jnp.ndarray corresponding to cost x vector @@ -168,18 +168,8 @@ def linear_apply( return out + bias * jnp.sum(vec) * jnp.ones_like(out) if fn is None or is_linear: - return linear_apply(vec, axis, fn) - - # TODO(michalk8): for bwd compatibility only, should be removed once - # same principle is used in `LRSinkhorn` and `PointCloud` - # yapf: disable - return jax.lax.cond( - geometry.is_linear(fn), - lambda _: linear_apply(vec, axis, fn), - lambda g: super(g.__class__, g)._apply_cost_to_vec(vec, axis, fn), - self - ) - # yapf: enable + return linear_apply(vec, axis, fn=fn) + return super()._apply_cost_to_vec(vec, axis, fn=fn) def compute_max_cost(self) -> float: """Compute the maximum of the cost matrix. diff --git a/ott/geometry/pointcloud.py b/ott/geometry/pointcloud.py index 621007e24..9b5343b8e 100644 --- a/ott/geometry/pointcloud.py +++ b/ott/geometry/pointcloud.py @@ -15,7 +15,7 @@ # Lint as: python3 """A geometry defined using 2 point clouds and a cost function between them.""" import math -from typing import Any, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -352,7 +352,11 @@ def transport_from_scalings( ) def apply_cost( - self, arr: jnp.ndarray, axis: int = 0, fn=None, **_: Any + self, + arr: jnp.ndarray, + axis: int = 0, + fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, + is_linear: bool = False, ) -> jnp.ndarray: """Apply cost matrix to array (vector or matrix). @@ -360,26 +364,27 @@ def apply_cost( output = C arr (if axis=1) output = C' arr (if axis=0) where C is [num_a, num_b] matrix resulting from the (optional) elementwise - application of fn to each entry of the `cost_matrix`. + application of fn to each entry of the :attr:`cost_matrix`. Args: arr: jnp.ndarray [num_a or num_b, batch], vector that will be multiplied by the cost matrix. - axis: standard cost matrix if axis=1, transpose if 0 + axis: standard cost matrix if axis=1, transpose if 0. fn: function optionally applied to cost matrix element-wise, before the - apply + apply. + is_linear: Whether ``fn`` is a linear function. + If true and :attr:`is_squared_euclidean` is ``True``, efficient + implementation is used. See :func:`ott.geometry.geometry.is_linear` + for a heuristic to help determine if a function is linear. Returns: A jnp.ndarray, [num_b, batch] if axis=0 or [num_a, batch] if axis=1 """ - if fn is None: - return self._apply_cost(arr, axis, fn=fn) - # Switch to efficient computation for the squared euclidean case. - return jax.lax.cond( - jnp.logical_and(self.is_squared_euclidean, geometry.is_affine(fn)), - lambda: self.vec_apply_cost(arr, axis, fn=fn), - lambda: self._apply_cost(arr, axis, fn=fn) - ) + # switch to efficient computation for the squared euclidean case. + if self.is_squared_euclidean and (fn is None or is_linear): + return self.vec_apply_cost(arr, axis, fn=fn) + + return self._apply_cost(arr, axis, fn=fn) def _apply_cost( self, arr: jnp.ndarray, axis: int = 0, fn=None @@ -430,19 +435,18 @@ def vec_apply_cost( Returns: A jnp.ndarray, [num_b, p] if axis=0 or [num_a, p] if axis=1 """ - rank = len(arr.shape) + rank = arr.ndim x, y = (self.x, self.y) if axis == 0 else (self.y, self.x) - nx, ny = jnp.array(self._norm_x), jnp.array(self._norm_y) + nx, ny = jnp.asarray(self._norm_x), jnp.asarray(self._norm_y) nx, ny = (nx, ny) if axis == 0 else (ny, nx) applied_cost = jnp.dot(nx, arr).reshape(1, -1) applied_cost += ny.reshape(-1, 1) * jnp.sum(arr, axis=0).reshape(1, -1) cross_term = -2.0 * jnp.dot(y, jnp.dot(x.T, arr)) applied_cost += cross_term[:, None] if rank == 1 else cross_term - return ( - fn(applied_cost) * self.inv_scale_cost if fn else applied_cost * - self.inv_scale_cost - ) + if fn is not None: + applied_cost = fn(applied_cost) + return self.inv_scale_cost * applied_cost def leading_slice(self, t: jnp.ndarray, i: int) -> jnp.ndarray: start_indices = [i * self._bs] + (t.ndim - 1) * [0] diff --git a/tests/geometry/geometry_lr_test.py b/tests/geometry/geometry_lr_test.py index 0dee19053..49ab470ba 100644 --- a/tests/geometry/geometry_lr_test.py +++ b/tests/geometry/geometry_lr_test.py @@ -125,6 +125,31 @@ def test_add_lr_geoms(self): rtol=1e-4 ) + @parameterized.product(fn=[lambda x: x + 10, lambda x: x * 2], axis=[0, 1]) + def test_apply_affine_function_efficient(self, fn, axis): + n, m, d = 21, 13, 3 + keys = jax.random.split(self.rng, 3) + x = jax.random.normal(keys[0], (n, d)) + y = jax.random.normal(keys[1], (m, d)) + vec = jax.random.normal(keys[2], (n if axis == 0 else m,)) + + geom = pointcloud.PointCloud(x, y) + + res_eff = geom.apply_cost(vec, axis=axis, fn=fn, is_linear=True) + res_ineff = geom.apply_cost(vec, axis=axis, fn=fn, is_linear=False) + + if fn(0.0) == 0.0: + np.testing.assert_allclose(res_eff, res_ineff, rtol=1e-4, atol=1e-4) + else: + self.assertRaises( + AssertionError, + np.testing.assert_allclose, + res_ineff, + res_eff, + rtol=1e-4, + atol=1e-4 + ) + if __name__ == '__main__': absltest.main()