diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py index 42559f1ef..94c06112a 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -429,13 +429,14 @@ def norm(self, x: jnp.ndarray) -> jnp.ndarray: """Compute norm of Gaussian for unbalanced Bures. Args: - x: Array of shape ``[n_points + n_points + n_dim ** 2,]`` - corresponding to the raveled mass, means and the covariance matrix. + x: Array of shape ``[n_points + n_points + n_dim ** 2,]``, potentially + batched, corresponding to the raveled mass, means and the covariance + matrix. Returns: - The norm, array of shape ``[n_points,]``. + The norm, array of shape ``[]`` or ``[batch,]`` in the batched case. """ - return self._gamma * x[:, 0] + return self._gamma * x[..., 0] def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Compute dot-product for unbalanced Bures.