Skip to content

Commit

Permalink
Prepare for upcoming keras initializer change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 453457229
Change-Id: Id8cd4c6409876d4d27fec03da8ab88f68ad17c26
  • Loading branch information
qlzh727 authored and copybara-github committed Jun 7, 2022
1 parent 340d52c commit fa0d607
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions tf_agents/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def mlp_layers(conv_layer_params=None,
kernel_size=kernel_size,
strides=strides,
activation=activation_fn,
kernel_initializer=kernel_initializer,
kernel_initializer=clone_initializer(kernel_initializer),
name='/'.join([name, 'conv2d']) if name else None)
for (filters, kernel_size, strides) in conv_layer_params
])
Expand Down Expand Up @@ -195,7 +195,7 @@ def mlp_layers(conv_layer_params=None,
layers.append(tf.keras.layers.Dense(
num_units,
activation=activation_fn,
kernel_initializer=kernel_initializer,
kernel_initializer=clone_initializer(kernel_initializer),
kernel_regularizer=kernel_regularizer,
name='/'.join([name, 'dense']) if name else None))
if not isinstance(dropout_params, dict):
Expand All @@ -207,3 +207,14 @@ def mlp_layers(conv_layer_params=None,
**dropout_params))

return layers


def clone_initializer(initializer):
# Keras initializer is going to be stateless, which mean reusing the same
# initializer will produce same init value when the shapes are the same.
if isinstance(initializer, tf.keras.initializers.Initializer):
return initializer.__class__.from_config(initializer.get_config())
# When the input is string/dict or other serialized configs, caller will
# create a new keras Initializer instance based on that, and we don't need to
# do anything
return initializer

0 comments on commit fa0d607

Please sign in to comment.