From 5d48391922b9c442ea6190207750f7af26e15227 Mon Sep 17 00:00:00 2001 From: michalk8 <46717574+michalk8@users.noreply.github.com> Date: Mon, 11 Jul 2022 21:56:21 +0000 Subject: [PATCH] Fix/axis norm (#103) * Fix norms in PointCloud when none is defined * Add test for cosine cost application --- ott/geometry/pointcloud.py | 16 ++++++---------- .../geometry_pointcloud_apply_test.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/ott/geometry/pointcloud.py b/ott/geometry/pointcloud.py index 89afd3baf..ee4e1c0e8 100644 --- a/ott/geometry/pointcloud.py +++ b/ott/geometry/pointcloud.py @@ -82,18 +82,16 @@ def __init__( self._scale_cost = "mean" if scale_cost is True else scale_cost @property - def _norm_x(self) -> jnp.ndarray: + def _norm_x(self) -> Union[float, jnp.ndarray]: if self._axis_norm == 0: return self._cost_fn.norm(self.x) - elif self._axis_norm is None: - return jnp.zeros(self.x.shape[0]) + return 0. @property - def _norm_y(self) -> jnp.ndarray: + def _norm_y(self) -> Union[float, jnp.ndarray]: if self._axis_norm == 0: return self._cost_fn.norm(self.y) - elif self._axis_norm is None: - return jnp.zeros(self.y.shape[0]) + return 0. @property def cost_matrix(self) -> Optional[jnp.ndarray]: @@ -158,10 +156,7 @@ def inv_scale_cost(self) -> float: ) elif self._scale_cost == 'max_norm': if self._cost_fn.norm is not None: - return 1.0 / jnp.maximum( - self._cost_fn.norm(self.x).max(), - self._cost_fn.norm(self.y).max() - ) + return 1.0 / jnp.maximum(self._norm_x.max(), self._norm_y.max()) else: return 1.0 elif self._scale_cost == 'max_bound': @@ -421,6 +416,7 @@ def vec_apply_cost( Returns: A jnp.ndarray, [num_b, p] if axis=0 or [num_a, p] if axis=1 """ + assert self.is_squared_euclidean, "Cost matrix is not a squared Euclidean." rank = arr.ndim x, y = (self.x, self.y) if axis == 0 else (self.y, self.x) nx, ny = jnp.asarray(self._norm_x), jnp.asarray(self._norm_y) diff --git a/tests/geometry/geometry_pointcloud_apply_test.py b/tests/geometry/geometry_pointcloud_apply_test.py index 7114cc1e1..bb48b7daf 100644 --- a/tests/geometry/geometry_pointcloud_apply_test.py +++ b/tests/geometry/geometry_pointcloud_apply_test.py @@ -95,3 +95,22 @@ def test_correct_shape(self): y = jnp.zeros((m, d)) pc = pointcloud.PointCloud(x=x, y=y) np.testing.assert_array_equal(pc.shape, (n, m)) + + @pytest.mark.parametrize("axis", [0, 1]) + def test_apply_cost_without_norm(self, rng: jnp.ndarray, axis: 1): + key1, key2 = jax.random.split(rng, 2) + x = jax.random.normal(key1, shape=(17, 3)) + y = jax.random.normal(key2, shape=(12, 3)) + pc = pointcloud.PointCloud(x, y, cost_fn=costs.Cosine()) + arr = jnp.ones((pc.shape[0],)) if axis == 0 else jnp.ones((pc.shape[1],)) + + assert pc._cost_fn.norm is None + with pytest.raises( + AssertionError, match=r"Cost matrix is not a squared Euclidean\." + ): + _ = pc.vec_apply_cost(arr, axis=axis) + + expected = pc.cost_matrix @ arr if axis == 1 else pc.cost_matrix.T @ arr + actual = pc.apply_cost(arr, axis=axis).squeeze() + + np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6)