Skip to content

Commit

Permalink
enable ShadowVariable look up for safe_embedding_lookup_sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
jq committed Apr 9, 2024
1 parent 2a15cbf commit 84d1c40
Showing 1 changed file with 39 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -855,15 +855,22 @@ def embedding_lookup_sparse(

ids = sp_ids.values
ids, idx = array_ops.unique(ids)

embeddings, trainable_ = embedding_lookup(
params,
ids,
name=name + '/embedding_lookup',
partition_strategy=partition_strategy,
max_norm=max_norm,
return_trainable=True,
)
if isinstance(params, de.shadow_ops.ShadowVariable):
embeddings = de.shadow_ops.embedding_lookup(
params,
ids,
name=name + '/embedding_lookup',
)
trainable_ = params
else:
embeddings, trainable_ = embedding_lookup(
params,
ids,
name=name + '/embedding_lookup',
partition_strategy=partition_strategy,
max_norm=max_norm,
return_trainable=True,
)
if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
embeddings = math_ops.cast(embeddings, dtypes.float32)
if not ignore_weights:
Expand Down Expand Up @@ -928,6 +935,24 @@ def embedding_lookup_sparse(
return (embeddings, trainable_) if return_trainable else embeddings


def _verify_embedding_weights(embedding_weights,
sparse_ids,
sparse_weights=None):
if embedding_weights is None:
raise ValueError("Missing embedding_weights %s." % embedding_weights)

if embedding_weights.key_dtype != sparse_ids.dtype:
raise TypeError(
"embedding_weights.key_dtype should be same with sparse_ids.dtype: "
"{} vs. {}".format(embedding_weights.key_dtype, sparse_ids.dtype))

weights_dtype = sparse_weights.dtype if sparse_weights is not None else None
if weights_dtype and embedding_weights.value_dtype != weights_dtype:
raise TypeError(
"embedding_weights.value_dtype should be same with sparse_weights.dtype"
": {} vs. {}".format(embedding_weights.value_dtype, weights_dtype))


def safe_embedding_lookup_sparse(
embedding_weights,
sparse_ids,
Expand Down Expand Up @@ -980,19 +1005,11 @@ def safe_embedding_lookup_sparse(
Raises:
ValueError: if `embedding_weights` is empty.
"""
if embedding_weights is None:
raise ValueError("Missing embedding_weights %s." % embedding_weights)

if embedding_weights.key_dtype != sparse_ids.dtype:
raise TypeError(
"embedding_weights.key_dtype should be same with sparse_ids.dtype: "
"{} vs. {}".format(embedding_weights.key_dtype, sparse_ids.dtype))

weights_dtype = sparse_weights.dtype if sparse_weights is not None else None
if weights_dtype and embedding_weights.value_dtype != weights_dtype:
raise TypeError(
"embedding_weights.value_dtype should be same with sparse_weights.dtype"
": {} vs. {}".format(embedding_weights.value_dtype, weights_dtype))
if isinstance(embedding_weights, de.shadow_ops.ShadowVariable):
_verify_embedding_weights(embedding_weights.params, sparse_ids,
sparse_weights)
else:
_verify_embedding_weights(embedding_weights, sparse_ids, sparse_weights)

scope = variable_scope.get_variable_scope()
full_name = scope.name + "/" + name if scope.name else name
Expand Down

0 comments on commit 84d1c40

Please sign in to comment.