Skip to content

Commit

Permalink
fix raggedtensor should pass name in the eager mode
Browse files Browse the repository at this point in the history
  • Loading branch information
jq committed Jun 4, 2024
1 parent 2c79a0d commit 8f590f6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Dynamic Embedding is designed for Large-scale Sparse Weights Training.
See [Sparse Domain Isolation](https://github.com/tensorflow/community/pull/237)
"""
from tensorflow.python.ops.variables import VariableAggregation

from tensorflow_recommenders_addons import dynamic_embedding as de
from tensorflow_recommenders_addons.utils.resource_loader import get_tf_version_triple
Expand Down Expand Up @@ -62,7 +63,6 @@
except:
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import compat, dispatch
from tensorflow.python.util.tf_export import tf_export

from tensorflow.python.keras.utils import tf_utils
try: # tf version >= 2.14.0
Expand Down Expand Up @@ -649,7 +649,7 @@ def _create_or_get_trainable(trainable_name):

with ops.colocate_with(ids, ignore_existing=True):
if distribute_ctx.has_strategy():
trainable_ = _distribute_trainable_store.get(name, None)
trainable_ = params._distribute_trainable_store.get(name, None)
if trainable_ is None:
strategy_devices = distribute_ctx.get_strategy(
).extended.worker_devices
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ def _embedding_lookup_sparse_impl(

ids, idx = array_ops.unique(ids)
if isinstance(params, de.shadow_ops.ShadowVariable):
embeddings = de.shadow_ops.embedding_lookup(params, ids)
embeddings = de.shadow_ops.embedding_lookup(params, ids, name=name)
else:
embeddings = de.embedding_lookup(params, ids)
embeddings = de.embedding_lookup(params, ids, name=name)

if not ignore_weights:
if segment_ids.dtype != dtypes.int32:
Expand Down

0 comments on commit 8f590f6

Please sign in to comment.