Skip to content

Commit

Permalink
enable ShadowVariable look up for safe_embedding_lookup_sparse and su…
Browse files Browse the repository at this point in the history
…pport raggedtensor
  • Loading branch information
jq committed Apr 17, 2024
1 parent 2a15cbf commit 4280d1d
Show file tree
Hide file tree
Showing 8 changed files with 830 additions and 133 deletions.
8 changes: 4 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ artifacts/# File patterns to ignore; see `git help ignore` for more information.
# Lines that start with '#' are comments.

*.whl
/bazel-bin/
/bazel-out/
/bazel-recommenders-addons/
/bazel-testlogs/
bazel-bin
bazel-out
bazel-recommenders-addons
bazel-testlogs
/tensorflow_recommenders_addons/dynamic_embedding/core/*.so

bazel-genfiles
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,20 +268,9 @@ def call(self, ids):
Returns:
A embedding output with shape (shape(ids), embedding_size).
"""
ids = tf.convert_to_tensor(ids)
input_shape = tf.shape(ids)
embeddings_shape = tf.concat([input_shape, [self.embedding_size]], 0)
ids_flat = tf.reshape(ids, (-1,))
if self.with_unique:
with tf.name_scope(self.name + "/EmbeddingWithUnique"):
unique_ids, idx = tf.unique(ids_flat)
unique_embeddings = de.shadow_ops.embedding_lookup(
self.shadow, unique_ids)
embeddings_flat = tf.gather(unique_embeddings, idx)
else:
embeddings_flat = de.shadow_ops.embedding_lookup(self.shadow, ids_flat)
embeddings = tf.reshape(embeddings_flat, embeddings_shape)
return embeddings
return de.shadow_ops.embedding_lookup_unify(self.shadow, ids,
self.embedding_size,
self.with_unique, self.name)

def get_config(self):
_initializer = self.params.initializer
Expand Down
Loading

0 comments on commit 4280d1d

Please sign in to comment.