Skip to content

Commit

Permalink
Fix UnbalancedBures's norm
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Nov 22, 2022
1 parent 687e594 commit aea8912
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,13 +429,14 @@ def norm(self, x: jnp.ndarray) -> jnp.ndarray:
"""Compute norm of Gaussian for unbalanced Bures.
Args:
x: Array of shape ``[n_points + n_points + n_dim ** 2,]``
corresponding to the raveled mass, means and the covariance matrix.
x: Array of shape ``[n_points + n_points + n_dim ** 2,]``, potentially
batched, corresponding to the raveled mass, means and the covariance
matrix.
Returns:
The norm, array of shape ``[n_points,]``.
The norm, array of shape ``[]`` or ``[batch,]`` in the batched case.
"""
return self._gamma * x[:, 0]
return self._gamma * x[..., 0]

def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Compute dot-product for unbalanced Bures.
Expand Down

0 comments on commit aea8912

Please sign in to comment.