Skip to content

Commit

Permalink
[fix] fix an issue where tf_patch.py failed in TF >= 2.6.x. It cause …
Browse files Browse the repository at this point in the history
…by Keras detached from TF repo.
  • Loading branch information
MoFHeka authored and rhdong committed Mar 6, 2022
1 parent 6872e62 commit a4d5f92
Showing 1 changed file with 18 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@
kinit2 = None
pass # for compatible with TF < 2.3.x

try:
import tensorflow as tf
kinit_tf = tf.keras.initializers
except ImportError:
kinit_tf = None
pass # for compatible with TF >= 2.6.x

try:
import keras as K
kinit_K = K.initializers
except ImportError:
kinit_K = None
pass # for compatible with standalone Keras

from tensorflow.core.framework import node_def_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import device as pydev
Expand Down Expand Up @@ -351,3 +365,7 @@ def patch_on_tf():
kinit1.VarianceScaling.__call__ = __call__for_keras_init_v1
if kinit2 is not None:
kinit2.VarianceScaling.__call__ = __call__for_keras_init_v2
if kinit_tf is not None:
kinit_tf.VarianceScaling.__call__ = __call__for_keras_init_v2
if kinit_K is not None:
kinit_K.VarianceScaling.__call__ = __call__for_keras_init_v2

0 comments on commit a4d5f92

Please sign in to comment.