Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove custom vdot implementation #10

Merged
merged 1 commit into from
Feb 21, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,10 @@
import abc

import jax
from jax.lib import xla_bridge
import jax.numpy as jnp
from ott.geometry import matrix_square_root


def dot(x: jnp.ndarray, y: jnp.ndarray):
"""Accelerator dependent dot. Implemented to avoid OOMs with online mode."""
platform = xla_bridge.get_backend().platform
return jnp.where(platform == 'gpu',
jnp.sum(x * y),
jnp.vdot(x, y))


@jax.tree_util.register_pytree_node_class
class CostFn(abc.ABC):
"""A generic cost function, taking two vectors as input.
Expand Down Expand Up @@ -93,7 +84,7 @@ def norm(self, x):
return jnp.sum(x ** 2, axis=-1)

def pairwise(self, x, y):
return -2 * dot(x, y)
return -2 * jnp.vdot(x, y)


@jax.tree_util.register_pytree_node_class
Expand All @@ -108,7 +99,7 @@ def pairwise(self, x, y):
ridge = self._ridge
x_norm = jnp.linalg.norm(x, axis=-1)
y_norm = jnp.linalg.norm(y, axis=-1)
cosine_similarity = dot(x, y) / (x_norm * y_norm + ridge)
cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + ridge)
cosine_distance = 1.0 - cosine_similarity
return cosine_distance

Expand All @@ -131,7 +122,7 @@ def norm(self, x):
return norm

def pairwise(self, x, y):
mean_dot_prod = dot(x[0:self._dimension], y[0:self._dimension])
mean_dot_prod = jnp.vdot(x[0:self._dimension], y[0:self._dimension])
x_mat = jnp.reshape(x[self._dimension:],
(self._dimension, self._dimension))
y_mat = jnp.reshape(y[self._dimension:],
Expand Down