Skip to content

Commit

Permalink
Merge pull request #10 from michalk8/fix/legacy-dot-product
Browse files Browse the repository at this point in the history
Remove custom vdot implementation
  • Loading branch information
LaetitiaPapaxanthos authored Feb 21, 2022
2 parents fe585f2 + cbd3291 commit aeb4f40
Showing 1 changed file with 3 additions and 12 deletions.
15 changes: 3 additions & 12 deletions ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,10 @@
import abc

import jax
from jax.lib import xla_bridge
import jax.numpy as jnp
from ott.geometry import matrix_square_root


def dot(x: jnp.ndarray, y: jnp.ndarray):
"""Accelerator dependent dot. Implemented to avoid OOMs with online mode."""
platform = xla_bridge.get_backend().platform
return jnp.where(platform == 'gpu',
jnp.sum(x * y),
jnp.vdot(x, y))


@jax.tree_util.register_pytree_node_class
class CostFn(abc.ABC):
"""A generic cost function, taking two vectors as input.
Expand Down Expand Up @@ -93,7 +84,7 @@ def norm(self, x):
return jnp.sum(x ** 2, axis=-1)

def pairwise(self, x, y):
return -2 * dot(x, y)
return -2 * jnp.vdot(x, y)


@jax.tree_util.register_pytree_node_class
Expand All @@ -108,7 +99,7 @@ def pairwise(self, x, y):
ridge = self._ridge
x_norm = jnp.linalg.norm(x, axis=-1)
y_norm = jnp.linalg.norm(y, axis=-1)
cosine_similarity = dot(x, y) / (x_norm * y_norm + ridge)
cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + ridge)
cosine_distance = 1.0 - cosine_similarity
return cosine_distance

Expand All @@ -131,7 +122,7 @@ def norm(self, x):
return norm

def pairwise(self, x, y):
mean_dot_prod = dot(x[0:self._dimension], y[0:self._dimension])
mean_dot_prod = jnp.vdot(x[0:self._dimension], y[0:self._dimension])
x_mat = jnp.reshape(x[self._dimension:],
(self._dimension, self._dimension))
y_mat = jnp.reshape(y[self._dimension:],
Expand Down

0 comments on commit aeb4f40

Please sign in to comment.