diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py index b33d0f7ec..3eeca249f 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py @@ -27,6 +27,20 @@ kinit2 = None pass # for compatible with TF < 2.3.x +try: + import tensorflow as tf + kinit_tf = tf.keras.initializers +except ImportError: + kinit_tf = None + pass # for compatible with TF >= 2.6.x + +try: + import keras as K + kinit_K = K.initializers +except ImportError: + kinit_K = None + pass # for compatible with standalone Keras + from tensorflow.core.framework import node_def_pb2 from tensorflow.python.eager import context from tensorflow.python.framework import device as pydev @@ -351,3 +365,7 @@ def patch_on_tf(): kinit1.VarianceScaling.__call__ = __call__for_keras_init_v1 if kinit2 is not None: kinit2.VarianceScaling.__call__ = __call__for_keras_init_v2 + if kinit_tf is not None: + kinit_tf.VarianceScaling.__call__ = __call__for_keras_init_v2 + if kinit_K is not None: + kinit_K.VarianceScaling.__call__ = __call__for_keras_init_v2