Skip to content

Commit

Permalink
Supports final layer activation in `create_feed_forward_common_tower_network`
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 680586249
Change-Id: I67b731672e0a2ead0133547a84c25b7fc0a6b0b0
  • Loading branch information
TF-Agents Team authored and copybara-github committed Sep 30, 2024
1 parent 488343a commit 9919a23
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tf_agents/bandits/networks/global_and_arm_feature_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def create_feed_forward_common_tower_network(
activation_fn: Callable[
[types.Tensor], types.Tensor
] = tf.keras.activations.relu,
final_activation_fn_for_q_network: (
Callable[[types.Tensor], types.Tensor] | None
) = None,
name: Optional[str] = None,
) -> types.Network:
"""Creates a common tower network with feedforward towers.
Expand All @@ -92,6 +95,9 @@ def create_feed_forward_common_tower_network(
arm_preprocessing_combiner: Preprocessing combiner for the arm features.
activation_fn: A keras activation, specifying the activation function used
in all layers. Defaults to relu.
final_activation_fn_for_q_network: Only used when `output_dim=1`, a Callable
specifying the activation function for the final layer of the common
tower. Defaults to None, which means no activation function.
name: The network name to use. Shows up in Tensorboard losses.
Returns:
Expand Down Expand Up @@ -145,6 +151,7 @@ def create_feed_forward_common_tower_network(
),
fc_layer_params=common_layers,
activation_fn=activation_fn,
q_layer_activation_fn=final_activation_fn_for_q_network,
)
else:
common_network = encoding_network.EncodingNetwork(
Expand Down
23 changes: 23 additions & 0 deletions tf_agents/bandits/networks/global_and_arm_feature_network_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,29 @@ def testCreateFeedForwardCommonTowerNetwork(
output = self.evaluate(output)
self.assertAllEqual(output.shape, (batch_size, num_actions))

@parameters
def testCreateFeedForwardCommonTowerNetworkWithFinalActivation(
    self, batch_size, feature_dim, num_actions
):
  """Checks the common tower network when a final activation is supplied.

  The network is built with `tf.exp` as
  `final_activation_fn_for_q_network`. The test then verifies that the
  evaluated output has shape `(batch_size, num_actions)` and that every
  entry is greater than or equal to zero.
  """
  observation_spec = bandit_spec_utils.create_per_arm_observation_spec(
      7, feature_dim, num_actions
  )
  network = gafn.create_feed_forward_common_tower_network(
      observation_spec,
      global_layers=(4, 3, 2),
      arm_layers=(6, 5, 4),
      common_layers=(7, 6, 5),
      final_activation_fn_for_q_network=tf.exp,
  )

  # Sample a batch of observations that conforms to the spec.
  sampled_observations = tensor_spec.sample_spec_nest(
      observation_spec, outer_dims=(batch_size,)
  )

  # The network is called before the initializer runs, matching the
  # original statement order.
  predictions, _ = network(sampled_observations)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  predictions = self.evaluate(predictions)

  expected_shape = (batch_size, num_actions)
  self.assertAllEqual(predictions.shape, expected_shape)
  self.assertAllGreaterEqual(predictions, 0.0)

@parameters
def testCreateFeedForwardCommonTowerNetworkWithEmptyLayers(
self, batch_size, feature_dim, num_actions
Expand Down

0 comments on commit 9919a23

Please sign in to comment.