enable ShadowVariable lookup for safe_embedding_lookup_sparse and support RaggedTensor
jq committed Apr 17, 2024
1 parent 2a15cbf commit d1bbbec
Showing 8 changed files with 881 additions and 140 deletions.
8 changes: 4 additions & 4 deletions .gitignore
@@ -12,10 +12,10 @@ artifacts/
# File patterns to ignore; see `git help ignore` for more information.
# Lines that start with '#' are comments.

*.whl
-/bazel-bin/
-/bazel-out/
-/bazel-recommenders-addons/
-/bazel-testlogs/
+bazel-bin
+bazel-out
+bazel-recommenders-addons
+bazel-testlogs
/tensorflow_recommenders_addons/dynamic_embedding/core/*.so

bazel-genfiles
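A likely motivation for this hunk: in .gitignore, a leading slash anchors a pattern to the repository root, and a trailing slash makes it match only real directories, which a symlink is not. Bazel creates bazel-bin, bazel-out, and friends as convenience symlinks in the workspace root, so the bare form is the one that actually ignores them:

# anchored and directory-only: does not match the bazel-bin convenience symlink
/bazel-bin/
# bare name: matches a file, directory, or symlink named bazel-bin at any depth
bazel-bin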
@@ -18,16 +18,12 @@
See [Sparse Domain Isolation](https://github.com/tensorflow/community/pull/237)
"""

import pickle

import tensorflow as tf

from tensorflow.python.eager import context
from tensorflow.python.ops import init_ops
from tensorflow_recommenders_addons import dynamic_embedding as de
from tensorflow_recommenders_addons.dynamic_embedding.python.ops import dynamic_embedding_variable as devar

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.keras.utils import tf_utils
try: # tf version >= 2.14.0
  from tensorflow.python.distribute import distribute_lib as distribute_ctx
@@ -268,20 +264,9 @@ def call(self, ids):
    Returns:
      An embedding output with shape (shape(ids), embedding_size).
    """
-    ids = tf.convert_to_tensor(ids)
-    input_shape = tf.shape(ids)
-    embeddings_shape = tf.concat([input_shape, [self.embedding_size]], 0)
-    ids_flat = tf.reshape(ids, (-1,))
-    if self.with_unique:
-      with tf.name_scope(self.name + "/EmbeddingWithUnique"):
-        unique_ids, idx = tf.unique(ids_flat)
-        unique_embeddings = de.shadow_ops.embedding_lookup(
-            self.shadow, unique_ids)
-        embeddings_flat = tf.gather(unique_embeddings, idx)
-    else:
-      embeddings_flat = de.shadow_ops.embedding_lookup(self.shadow, ids_flat)
-    embeddings = tf.reshape(embeddings_flat, embeddings_shape)
-    return embeddings
+    return de.shadow_ops.embedding_lookup_unique(self.shadow, ids,
+                                                 self.embedding_size,
+                                                 self.with_unique, self.name)

def get_config(self):
_initializer = self.params.initializer
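The inline de-duplication removed above now lives behind de.shadow_ops.embedding_lookup_unique. A minimal sketch of such a helper, reconstructed from the removed lines: the name embedding_lookup_unique and the de.shadow_ops.embedding_lookup call come from this diff, while the exact signature and body are assumptions rather than TFRA's actual implementation.

import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

def embedding_lookup_unique_sketch(shadow, ids, embedding_size,
                                   with_unique=True, name=None):
  # Hypothetical stand-in for de.shadow_ops.embedding_lookup_unique,
  # mirroring the logic removed from call() above.
  ids = tf.convert_to_tensor(ids)
  input_shape = tf.shape(ids)
  embeddings_shape = tf.concat([input_shape, [embedding_size]], 0)
  ids_flat = tf.reshape(ids, (-1,))  # flatten so the lookup sees a 1-D id list
  if with_unique:
    with tf.name_scope((name or "shadow") + "/EmbeddingWithUnique"):
      # Probe the dynamic-embedding table once per distinct id, then
      # scatter the resulting rows back to their original positions.
      unique_ids, idx = tf.unique(ids_flat)
      unique_embeddings = de.shadow_ops.embedding_lookup(shadow, unique_ids)
      embeddings_flat = tf.gather(unique_embeddings, idx)
  else:
    embeddings_flat = de.shadow_ops.embedding_lookup(shadow, ids_flat)
  # Restore the input shape, with a trailing embedding dimension appended.
  return tf.reshape(embeddings_flat, embeddings_shape)

Under those assumptions, the new call() body corresponds to embedding_lookup_unique_sketch(self.shadow, ids, self.embedding_size, self.with_unique, self.name).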
