diff --git a/tf_agents/networks/utils.py b/tf_agents/networks/utils.py index 638e65965..e832d4dfa 100644 --- a/tf_agents/networks/utils.py +++ b/tf_agents/networks/utils.py @@ -164,7 +164,7 @@ def mlp_layers(conv_layer_params=None, kernel_size=kernel_size, strides=strides, activation=activation_fn, - kernel_initializer=kernel_initializer, + kernel_initializer=clone_initializer(kernel_initializer), name='/'.join([name, 'conv2d']) if name else None) for (filters, kernel_size, strides) in conv_layer_params ]) @@ -195,7 +195,7 @@ def mlp_layers(conv_layer_params=None, layers.append(tf.keras.layers.Dense( num_units, activation=activation_fn, - kernel_initializer=kernel_initializer, + kernel_initializer=clone_initializer(kernel_initializer), kernel_regularizer=kernel_regularizer, name='/'.join([name, 'dense']) if name else None)) if not isinstance(dropout_params, dict): @@ -207,3 +207,14 @@ def mlp_layers(conv_layer_params=None, **dropout_params)) return layers + + +def clone_initializer(initializer): + # Keras initializer is going to be stateless, which means reusing the same + # initializer will produce the same init value when the shapes are the same. + if isinstance(initializer, tf.keras.initializers.Initializer): + return initializer.__class__.from_config(initializer.get_config()) + # When the input is string/dict or other serialized configs, caller will + # create a new Keras Initializer instance based on that, and we don't need to + # do anything. + return initializer