Skip to content

Commit

Permalink
Fix/axis norm (#103)
Browse files Browse the repository at this point in the history
* Fix norms in PointCloud when none is defined

* Add test for cosine cost application
  • Loading branch information
michalk8 authored Jul 11, 2022
1 parent b218701 commit 5d48391
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
16 changes: 6 additions & 10 deletions ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,16 @@ def __init__(
self._scale_cost = "mean" if scale_cost is True else scale_cost

@property
def _norm_x(self) -> jnp.ndarray:
def _norm_x(self) -> Union[float, jnp.ndarray]:
if self._axis_norm == 0:
return self._cost_fn.norm(self.x)
elif self._axis_norm is None:
return jnp.zeros(self.x.shape[0])
return 0.

@property
def _norm_y(self) -> jnp.ndarray:
def _norm_y(self) -> Union[float, jnp.ndarray]:
if self._axis_norm == 0:
return self._cost_fn.norm(self.y)
elif self._axis_norm is None:
return jnp.zeros(self.y.shape[0])
return 0.

@property
def cost_matrix(self) -> Optional[jnp.ndarray]:
Expand Down Expand Up @@ -158,10 +156,7 @@ def inv_scale_cost(self) -> float:
)
elif self._scale_cost == 'max_norm':
if self._cost_fn.norm is not None:
return 1.0 / jnp.maximum(
self._cost_fn.norm(self.x).max(),
self._cost_fn.norm(self.y).max()
)
return 1.0 / jnp.maximum(self._norm_x.max(), self._norm_y.max())
else:
return 1.0
elif self._scale_cost == 'max_bound':
Expand Down Expand Up @@ -421,6 +416,7 @@ def vec_apply_cost(
Returns:
A jnp.ndarray, [num_b, p] if axis=0 or [num_a, p] if axis=1
"""
assert self.is_squared_euclidean, "Cost matrix is not a squared Euclidean."
rank = arr.ndim
x, y = (self.x, self.y) if axis == 0 else (self.y, self.x)
nx, ny = jnp.asarray(self._norm_x), jnp.asarray(self._norm_y)
Expand Down
19 changes: 19 additions & 0 deletions tests/geometry/geometry_pointcloud_apply_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,22 @@ def test_correct_shape(self):
y = jnp.zeros((m, d))
pc = pointcloud.PointCloud(x=x, y=y)
np.testing.assert_array_equal(pc.shape, (n, m))

@pytest.mark.parametrize("axis", [0, 1])
def test_apply_cost_without_norm(self, rng: jnp.ndarray, axis: 1):
key1, key2 = jax.random.split(rng, 2)
x = jax.random.normal(key1, shape=(17, 3))
y = jax.random.normal(key2, shape=(12, 3))
pc = pointcloud.PointCloud(x, y, cost_fn=costs.Cosine())
arr = jnp.ones((pc.shape[0],)) if axis == 0 else jnp.ones((pc.shape[1],))

assert pc._cost_fn.norm is None
with pytest.raises(
AssertionError, match=r"Cost matrix is not a squared Euclidean\."
):
_ = pc.vec_apply_cost(arr, axis=axis)

expected = pc.cost_matrix @ arr if axis == 1 else pc.cost_matrix.T @ arr
actual = pc.apply_cost(arr, axis=axis).squeeze()

np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6)

0 comments on commit 5d48391

Please sign in to comment.