From 2c7303de3e45bad0c2541f04cacbd131b9a669b4 Mon Sep 17 00:00:00 2001 From: Michal Klein Date: Mon, 21 Feb 2022 01:35:24 +0100 Subject: [PATCH] Remove custom vdot implementation --- ott/geometry/costs.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py index 3c04d1739..ffcea2b7a 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -18,19 +18,10 @@ import abc import jax -from jax.lib import xla_bridge import jax.numpy as jnp from ott.geometry import matrix_square_root -def dot(x: jnp.ndarray, y: jnp.ndarray): - """Accelerator dependent dot. Implemented to avoid OOMs with online mode.""" - platform = xla_bridge.get_backend().platform - return jnp.where(platform == 'gpu', - jnp.sum(x * y), - jnp.vdot(x, y)) - - @jax.tree_util.register_pytree_node_class class CostFn(abc.ABC): """A generic cost function, taking two vectors as input. @@ -93,7 +84,7 @@ def norm(self, x): return jnp.sum(x ** 2, axis=-1) def pairwise(self, x, y): - return -2 * dot(x, y) + return -2 * jnp.vdot(x, y) @jax.tree_util.register_pytree_node_class @@ -108,7 +99,7 @@ def pairwise(self, x, y): ridge = self._ridge x_norm = jnp.linalg.norm(x, axis=-1) y_norm = jnp.linalg.norm(y, axis=-1) - cosine_similarity = dot(x, y) / (x_norm * y_norm + ridge) + cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + ridge) cosine_distance = 1.0 - cosine_similarity return cosine_distance @@ -131,7 +122,7 @@ def norm(self, x): return norm def pairwise(self, x, y): - mean_dot_prod = dot(x[0:self._dimension], y[0:self._dimension]) + mean_dot_prod = jnp.vdot(x[0:self._dimension], y[0:self._dimension]) x_mat = jnp.reshape(x[self._dimension:], (self._dimension, self._dimension)) y_mat = jnp.reshape(y[self._dimension:],