Skip to content

Commit

Permalink
Add support for having an activation in the Q layer.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 452725885
Change-Id: I59f05d68c3247a808611ebb9f6aae46786ff59b3
  • Loading branch information
efiko authored and copybara-github committed Jun 3, 2022
1 parent d6f84de commit 4b2203f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tf_agents/networks/q_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self,
kernel_initializer=None,
batch_squash=True,
dtype=tf.float32,
q_layer_activation_fn=None,
name='QNetwork'):
"""Creates an instance of `QNetwork`.
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(self,
the batch dimension. This allows encoding networks to be used with
observations with shape [BxTx...].
dtype: The dtype to use by the convolution and fully connected layers.
q_layer_activation_fn: Activation function applied to the output of the
final Q-value layer. Defaults to `None`, i.e. no activation.
name: A string representing the name of the network.
Raises:
Expand All @@ -115,7 +117,7 @@ def __init__(self,

q_value_layer = tf.keras.layers.Dense(
num_actions,
activation=None,
activation=q_layer_activation_fn,
kernel_initializer=tf.random_uniform_initializer(
minval=-0.03, maxval=0.03),
bias_initializer=tf.constant_initializer(-0.2),
Expand Down
13 changes: 13 additions & 0 deletions tf_agents/networks/q_network_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,19 @@ def testPreprocessingLayersSingleObservations(self):
q_logits, _ = network(tf.ones((3, num_state_dims)))
self.assertAllEqual(q_logits.shape.as_list(), [3, 2])

def testQLayerActivation(self):
  """Checks that `q_layer_activation_fn` is applied to the Q-value output.

  Softplus is strictly positive, so with it as the Q-layer activation every
  value produced by the network must be greater than zero.
  """
  batch_size = 3
  state_dim = 5
  observation_spec = tensor_spec.TensorSpec([state_dim], tf.float32)
  # One integer action bounded to {0, 1}, i.e. two discrete actions.
  binary_action_spec = tensor_spec.BoundedTensorSpec([1], tf.int32, 0, 1)
  net = q_network.QNetwork(
      input_tensor_spec=observation_spec,
      action_spec=binary_action_spec,
      q_layer_activation_fn=tf.keras.activations.softplus)

  # Calling the network builds its variables; they are initialized below
  # before any value is inspected.
  q_values, _ = net(tf.ones((batch_size, state_dim)))

  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(tf.compat.v1.initializers.tables_initializer())
  self.assertAllEqual(q_values.shape.as_list(), [batch_size, 2])
  self.assertAllGreater(q_values, 0.0)


# Run this module's test cases through the TensorFlow test runner when the
# file is executed directly as a script (not when it is imported).
if __name__ == '__main__':
tf.test.main()

0 comments on commit 4b2203f

Please sign in to comment.