diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py index eabf9103c..7018b4775 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py @@ -96,6 +96,9 @@ def apply_grad_to_update_var(var, grad): else: return update_op else: + if not var.params.trainable: + return control_flow_ops.no_op() + with ops.colocate_with(None, ignore_existing=True): _slots = [self.get_slot(var, _s) for _s in self.get_slot_names()] var._track_optimizer_slots(_slots) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py index 5a482ce03..d31fc93b6 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py @@ -77,6 +77,9 @@ def update_op(self, optimizer, g): # pylint: disable=protected-access # for better convergence: + if not self._v.params.trainable: + return control_flow_ops.no_op() + with ops.colocate_with(None, ignore_existing=True): _slots = [ optimizer.get_slot(self._v, _s) for _s in optimizer.get_slot_names()