Skip to content

Commit

Permalink
Fix/point cloud apply cost fn (#93)
Browse files Browse the repository at this point in the history
* Fix not using efficient apply in squared Euclidean case

* Remove bwd compatible lax.cond in LRCGeometry

* Fix requiring affine fn instead of linear

* Fix function inversion check when tracing
  • Loading branch information
michalk8 authored Jul 1, 2022
1 parent fb00c9e commit 42890aa
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 44 deletions.
12 changes: 8 additions & 4 deletions ott/core/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,14 @@ class PositiveDense(nn.Module):
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros

def setup(self):
if round(self.inv_rectifier_fn(self.rectifier_fn(0.1)), 3) != 0.1:
raise RuntimeError(
"Make sure both rectifier and inverse are defined properly."
)
try:
if round(self.inv_rectifier_fn(self.rectifier_fn(0.1)), 3) != 0.1:
raise RuntimeError(
"Make sure both rectifier and inverse are defined properly."
)
except TypeError as e:
if "doesn't define __round__ method" not in str(e):
raise # not comparing tracer values, raise

@nn.compact
def __call__(self, inputs):
Expand Down
1 change: 0 additions & 1 deletion ott/core/quad_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,5 +538,4 @@ def update_epsilon_unbalanced(epsilon, transport_mass):
def apply_cost(
    geom: geometry.Geometry, arr: jnp.ndarray, *, axis: int, fn: Loss
) -> jnp.ndarray:
  """Apply the cost of ``geom`` to ``arr`` using the loss ``fn``.

  Args:
    geom: Geometry whose ``apply_cost`` method is invoked.
    arr: Array forwarded unchanged to ``geom.apply_cost``.
    axis: Axis forwarded to ``geom.apply_cost``.
    fn: Loss whose ``func`` and ``is_linear`` attributes are forwarded as the
      ``fn`` and ``is_linear`` keyword arguments.

  Returns:
    Whatever ``geom.apply_cost`` returns.
  """
  apply = geom.apply_cost
  return apply(arr, axis=axis, fn=fn.func, is_linear=fn.is_linear)
4 changes: 2 additions & 2 deletions ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Lint as: python3
"""A class describing operations used to instantiate and use a geometry."""
import functools
from typing import Any, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -534,7 +534,7 @@ def apply_cost(
self,
arr: jnp.ndarray,
axis: int = 0,
fn=None,
fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
**kwargs: Any
) -> jnp.ndarray:
"""Apply cost matrix to array (vector or matrix).
Expand Down
26 changes: 8 additions & 18 deletions ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ def inv_scale_cost(self) -> float:
def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
"""Apply elementwise-square of cost matrix to array (vector or matrix)."""
(n, m), r = self.shape, self.cost_rank
# When applying square of a LRCgeometry, one can either elementwise square
# When applying square of a LRCGeometry, one can either elementwise square
# the cost matrix, or instantiate an augmented (rank^2) LRCGeometry
# and apply it. First is O(nm), the other is O((n+m)r^2).
if n * m < (n + m) * r ** 2: # better use regular apply
if n * m < (n + m) * r ** 2: # better use regular apply
return super().apply_square_cost(arr, axis)
else:
new_cost_1 = self.cost_1[:, :, None] * self.cost_1[:, None, :]
Expand All @@ -140,7 +140,7 @@ def _apply_cost_to_vec(
vec: jnp.ndarray,
axis: int = 0,
fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
is_linear: Optional[bool] = None,
is_linear: bool = False,
) -> jnp.ndarray:
"""Apply [num_a, num_b] fn(cost) (or transpose) to vector.
Expand All @@ -149,9 +149,9 @@ def _apply_cost_to_vec(
axis: axis on which the reduction is done.
fn: function optionally applied to cost matrix element-wise, before the
dot product
is_linear: Whether ``fn`` is a linear function. If yes, efficient
implementation is used. If ``None``, it will be determined by
:func:`ott.geometry.geometry.is_linear` at runtime.
is_linear: Whether ``fn`` is a linear function to enable efficient
implementation. See :func:`ott.geometry.geometry.is_linear`
for a heuristic to help determine if a function is linear.
Returns:
A jnp.ndarray corresponding to cost x vector
Expand All @@ -168,18 +168,8 @@ def linear_apply(
return out + bias * jnp.sum(vec) * jnp.ones_like(out)

if fn is None or is_linear:
return linear_apply(vec, axis, fn)

# TODO(michalk8): for bwd compatibility only, should be removed once
# same principle is used in `LRSinkhorn` and `PointCloud`
# yapf: disable
return jax.lax.cond(
geometry.is_linear(fn),
lambda _: linear_apply(vec, axis, fn),
lambda g: super(g.__class__, g)._apply_cost_to_vec(vec, axis, fn),
self
)
# yapf: enable
return linear_apply(vec, axis, fn=fn)
return super()._apply_cost_to_vec(vec, axis, fn=fn)

def compute_max_cost(self) -> float:
"""Compute the maximum of the cost matrix.
Expand Down
42 changes: 23 additions & 19 deletions ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Lint as: python3
"""A geometry defined using 2 point clouds and a cost function between them."""
import math
from typing import Any, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -352,34 +352,39 @@ def transport_from_scalings(
)

def apply_cost(
self, arr: jnp.ndarray, axis: int = 0, fn=None, **_: Any
self,
arr: jnp.ndarray,
axis: int = 0,
fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
is_linear: bool = False,
) -> jnp.ndarray:
"""Apply cost matrix to array (vector or matrix).
This function applies the geometry's cost matrix, to perform either
output = C arr (if axis=1)
output = C' arr (if axis=0)
where C is [num_a, num_b] matrix resulting from the (optional) elementwise
application of fn to each entry of the `cost_matrix`.
application of fn to each entry of the :attr:`cost_matrix`.
Args:
arr: jnp.ndarray [num_a or num_b, batch], vector that will be multiplied
by the cost matrix.
axis: standard cost matrix if axis=1, transpose if 0
axis: standard cost matrix if axis=1, transpose if 0.
fn: function optionally applied to cost matrix element-wise, before the
apply
apply.
is_linear: Whether ``fn`` is a linear function.
If true and :attr:`is_squared_euclidean` is ``True``, efficient
implementation is used. See :func:`ott.geometry.geometry.is_linear`
for a heuristic to help determine if a function is linear.
Returns:
A jnp.ndarray, [num_b, batch] if axis=0 or [num_a, batch] if axis=1
"""
if fn is None:
return self._apply_cost(arr, axis, fn=fn)
# Switch to efficient computation for the squared euclidean case.
return jax.lax.cond(
jnp.logical_and(self.is_squared_euclidean, geometry.is_affine(fn)),
lambda: self.vec_apply_cost(arr, axis, fn=fn),
lambda: self._apply_cost(arr, axis, fn=fn)
)
# switch to efficient computation for the squared euclidean case.
if self.is_squared_euclidean and (fn is None or is_linear):
return self.vec_apply_cost(arr, axis, fn=fn)

return self._apply_cost(arr, axis, fn=fn)

def _apply_cost(
self, arr: jnp.ndarray, axis: int = 0, fn=None
Expand Down Expand Up @@ -430,19 +435,18 @@ def vec_apply_cost(
Returns:
A jnp.ndarray, [num_b, p] if axis=0 or [num_a, p] if axis=1
"""
rank = len(arr.shape)
rank = arr.ndim
x, y = (self.x, self.y) if axis == 0 else (self.y, self.x)
nx, ny = jnp.array(self._norm_x), jnp.array(self._norm_y)
nx, ny = jnp.asarray(self._norm_x), jnp.asarray(self._norm_y)
nx, ny = (nx, ny) if axis == 0 else (ny, nx)

applied_cost = jnp.dot(nx, arr).reshape(1, -1)
applied_cost += ny.reshape(-1, 1) * jnp.sum(arr, axis=0).reshape(1, -1)
cross_term = -2.0 * jnp.dot(y, jnp.dot(x.T, arr))
applied_cost += cross_term[:, None] if rank == 1 else cross_term
return (
fn(applied_cost) * self.inv_scale_cost if fn else applied_cost *
self.inv_scale_cost
)
if fn is not None:
applied_cost = fn(applied_cost)
return self.inv_scale_cost * applied_cost

def leading_slice(self, t: jnp.ndarray, i: int) -> jnp.ndarray:
start_indices = [i * self._bs] + (t.ndim - 1) * [0]
Expand Down
25 changes: 25 additions & 0 deletions tests/geometry/geometry_lr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,31 @@ def test_add_lr_geoms(self):
rtol=1e-4
)

@parameterized.product(fn=[lambda x: x + 10, lambda x: x * 2], axis=[0, 1])
def test_apply_affine_function_efficient(self, fn, axis):
  """Compare ``apply_cost`` results with ``is_linear=True`` and ``False``.

  When ``fn`` maps 0 to 0 (here ``x * 2``), both calls must agree within
  tolerance. Otherwise (here ``x + 10``), the two results are expected to
  differ, i.e. ``assert_allclose`` must raise.
  """
  num_a, num_b, dim = 21, 13, 3
  key_x, key_y, key_vec = jax.random.split(self.rng, 3)
  x = jax.random.normal(key_x, (num_a, dim))
  y = jax.random.normal(key_y, (num_b, dim))
  # The vector length must match the side of the cost matrix being reduced.
  vec_size = num_a if axis == 0 else num_b
  vec = jax.random.normal(key_vec, (vec_size,))

  geom = pointcloud.PointCloud(x, y)

  res_eff = geom.apply_cost(vec, axis=axis, fn=fn, is_linear=True)
  res_ineff = geom.apply_cost(vec, axis=axis, fn=fn, is_linear=False)

  maps_zero_to_zero = fn(0.0) == 0.0
  if maps_zero_to_zero:
    np.testing.assert_allclose(res_eff, res_ineff, rtol=1e-4, atol=1e-4)
    return

  # `fn` has an offset: flagging it as linear must give a different result.
  with self.assertRaises(AssertionError):
    np.testing.assert_allclose(res_ineff, res_eff, rtol=1e-4, atol=1e-4)


if __name__ == '__main__':
absltest.main()

0 comments on commit 42890aa

Please sign in to comment.