From 6e3190df0c71e0113aa29ff733171024f3310e31 Mon Sep 17 00:00:00 2001 From: alionkun Date: Wed, 28 Sep 2022 23:00:40 +0800 Subject: [PATCH 1/2] [fix] fix shadow variable lookup race condition --- .../dynamic_embedding/python/ops/shadow_embedding_ops.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py index 6fd8ad613..6d90e9e0a 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py @@ -247,5 +247,8 @@ def embedding_lookup( ' {},'.format(ids.dtype, shadow.ids.dtype)) with ops.name_scope(name, "shadow_embedding_lookup"): - with ops.control_dependencies([shadow._reset_ids(ids)]): - return shadow.read_value(do_prefetch=True) + if de.ModelMode.CURRENT_SETTING == 'train': + with ops.control_dependencies([shadow._reset_ids(ids)]): + return shadow.read_value(do_prefetch=True) + else: + return shadow.params.lookup(ids) From 8001133e5ac5304d0b0b530630f947ffd431ff96 Mon Sep 17 00:00:00 2001 From: alionkun Date: Thu, 29 Sep 2022 14:21:48 +0800 Subject: [PATCH 2/2] use constant instead of string literal --- .../dynamic_embedding/python/ops/shadow_embedding_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py index 6d90e9e0a..5e3de393c 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py @@ -247,7 +247,7 @@ def embedding_lookup( ' {},'.format(ids.dtype, shadow.ids.dtype)) with ops.name_scope(name, "shadow_embedding_lookup"): - if de.ModelMode.CURRENT_SETTING == 'train': + if de.ModelMode.CURRENT_SETTING == de.ModelMode.TRAIN: with ops.control_dependencies([shadow._reset_ids(ids)]): return shadow.read_value(do_prefetch=True) else: