From fa0d6075c4b429e2b22b0303e7f337ad8641f1dd Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Tue, 7 Jun 2022 09:38:09 -0700 Subject: [PATCH] Prepare for upcoming keras initializer change. PiperOrigin-RevId: 453457229 Change-Id: Id8cd4c6409876d4d27fec03da8ab88f68ad17c26 --- tf_agents/networks/utils.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tf_agents/networks/utils.py b/tf_agents/networks/utils.py index 638e65965..e832d4dfa 100644 --- a/tf_agents/networks/utils.py +++ b/tf_agents/networks/utils.py @@ -164,7 +164,7 @@ def mlp_layers(conv_layer_params=None, kernel_size=kernel_size, strides=strides, activation=activation_fn, - kernel_initializer=kernel_initializer, + kernel_initializer=clone_initializer(kernel_initializer), name='/'.join([name, 'conv2d']) if name else None) for (filters, kernel_size, strides) in conv_layer_params ]) @@ -195,7 +195,7 @@ def mlp_layers(conv_layer_params=None, layers.append(tf.keras.layers.Dense( num_units, activation=activation_fn, - kernel_initializer=kernel_initializer, + kernel_initializer=clone_initializer(kernel_initializer), kernel_regularizer=kernel_regularizer, name='/'.join([name, 'dense']) if name else None)) if not isinstance(dropout_params, dict): @@ -207,3 +207,14 @@ def mlp_layers(conv_layer_params=None, **dropout_params)) return layers + + +def clone_initializer(initializer): + # Keras initializer is going to be stateless, which mean reusing the same + # initializer will produce same init value when the shapes are the same. + if isinstance(initializer, tf.keras.initializers.Initializer): + return initializer.__class__.from_config(initializer.get_config()) + # When the input is string/dict or other serialized configs, caller will + # create a new keras Initializer instance based on that, and we don't need to + # do anything + return initializer