Skip to content

Commit

Permalink
Merge pull request #901 from porta-logica:master
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 591437086
Change-Id: If29bd104009954c5773f242b84332c33ac8ec529
  • Loading branch information
copybara-github committed Dec 16, 2023
2 parents 4041bda + 00d037d commit d185fb8
Show file tree
Hide file tree
Showing 14 changed files with 1,455 additions and 25 deletions.
21 changes: 11 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,23 +175,24 @@ def get_required_packages():
"""Returns list of required packages."""

required_packages = [
'absl-py >= 0.6.1',
'cloudpickle >= 1.3',
'gin-config >= 0.4.0',
'gym >= 0.17.0, <=0.23.0',
'numpy >= 1.19.0',
'pillow',
'six >= 1.10.0',
'protobuf >= 3.11.3',
'wrapt >= 1.11.1',
'absl-py >= 2.0.0',
'cloudpickle >= 3.0.0',
'gin-config >= 0.5.0',
'gym >= 0.17.0, <= 0.23.1',
'gymnasium >= 0.29.0',
'numpy >= 1.26.2',
'pillow >= 10.1.0',
'six >= 1.16.0',
'protobuf >= 3.11.3, <= 4.23.4',
'wrapt >= 1.16.0',
# Using an older version to avoid this bug
# https://github.com/tensorflow/tensorflow/issues/62217
# while using tf 2.15.0
'typing-extensions == 4.5.0',
# Used by gym >= 0.22.0. Only installed as a dependency when gym[all] is
# installed or if gym[*] (where * is an environment which lists pygame as
# a dependency).
'pygame == 2.1.3',
'pygame == 2.5.2',
]
add_additional_packages(required_packages)
return required_packages
Expand Down
80 changes: 72 additions & 8 deletions tf_agents/agents/dqn/dqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,11 +328,14 @@ def _check_network_output(self, net, label):
net: A `Network`.
label: A label to print in case of a mismatch.
"""
network_utils.check_single_floating_network_output(
net.create_variables(),
expected_output_shape=(self._num_actions,),
label=label,
)
outputs = net.create_variables()
iterable = list(outputs) if isinstance(outputs, tuple) else [outputs]
for output in iterable:
network_utils.check_single_floating_network_output(
output,
expected_output_shape=(self._num_actions,),
label=label,
)

def _setup_policy(
self,
Expand Down Expand Up @@ -590,8 +593,10 @@ def _compute_q_values(self, time_steps, actions, training=False):
# param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
action_spec = cast(tensor_spec.BoundedTensorSpec, self._action_spec)
multi_dim_actions = action_spec.shape.rank > 0
# support for dueling networks
a_values = q_values[0] if isinstance(q_values, tuple) else q_values
return common.index_with_actions(
q_values,
a_values,
tf.cast(actions, dtype=tf.int32),
multi_dim_actions=multi_dim_actions,
)
Expand All @@ -614,9 +619,12 @@ def _compute_next_q_values(self, next_time_steps, info):
network_observation
)

next_target_q_values, _ = self._target_q_network(
q_next_target, _ = self._target_q_network(
network_observation, step_type=next_time_steps.step_type
)
next_target_q_values = (
q_next_target[0] if isinstance(q_next_target, tuple) else q_next_target
)
batch_size = (
next_target_q_values.shape[0] or tf.shape(next_target_q_values)[0]
)
Expand Down Expand Up @@ -668,9 +676,12 @@ def _compute_next_q_values(self, next_time_steps, info):
network_observation
)

next_target_q_values, _ = self._target_q_network(
q_next_target, _ = self._target_q_network(
network_observation, step_type=next_time_steps.step_type
)
next_target_q_values = (
q_next_target[0] if isinstance(q_next_target, tuple) else q_next_target
)
batch_size = (
next_target_q_values.shape[0] or tf.shape(next_target_q_values)[0]
)
Expand All @@ -687,3 +698,56 @@ def _compute_next_q_values(self, next_time_steps, info):
best_next_actions,
multi_dim_actions=multi_dim_actions,
)


@gin.configurable
class D3qnAgent(DqnAgent):
"""A Dueling DQN Agent.
Implements the Double Dueling DQN algorithm from
"Dueling Network Architectures for Deep Reinforcement Learning"
Wang et al., 2016
https://arxiv.org/abs/1511.06581
"""

def _compute_next_q_values(self, next_time_steps, info):
"""Compute the q value of the next state for TD error computation.
Args:
next_time_steps: A batch of next timesteps
info: PolicyStep.info that may be used by other agents inherited from
dqn_agent.
Returns:
A tensor of Q values for the given next state.
"""
del info
# TODO(b/117175589): Add binary tests for DDQN.
network_observation = next_time_steps.observation

if self._observation_and_action_constraint_splitter is not None:
network_observation, _ = self._observation_and_action_constraint_splitter(
network_observation
)

q_next_target, _ = self._target_q_network(
network_observation, step_type=next_time_steps.step_type
)
next_target_q_values = (
q_next_target[0] if isinstance(q_next_target, tuple) else q_next_target
)
q_next, _ = self._q_network(
network_observation, step_type=next_time_steps.step_type
)
next_q_values = q_next[1] if isinstance(q_next, tuple) else q_next
best_next_actions = tf.math.argmax(next_q_values, axis=1)

# Handle action_spec.shape=(), and shape=(1,) by using the multi_dim_actions
# param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
multi_dim_actions = tf.nest.flatten(self._action_spec)[0].shape.rank > 0
return common.index_with_actions(
next_target_q_values,
best_next_actions,
multi_dim_actions=multi_dim_actions,
)
28 changes: 25 additions & 3 deletions tf_agents/agents/dqn/dqn_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def testComputeTDTargets(self):


@parameterized.named_parameters(
('DqnAgent', dqn_agent.DqnAgent), ('DdqnAgent', dqn_agent.DdqnAgent)
('DqnAgent', dqn_agent.DqnAgent),
('DdqnAgent', dqn_agent.DdqnAgent),
('D3qnAgent', dqn_agent.D3qnAgent),
)
class DqnAgentTest(test_utils.TestCase):

Expand Down Expand Up @@ -216,6 +218,10 @@ def testLoss(self, agent_class):
self.assertAllClose(self.evaluate(loss), expected_loss)

def testLossWithChangedOptimalActions(self, agent_class):

# if 'D3qnAgent' in agent_class.__name__:
# self.skipTest('invalid for dueling networks')

q_net = DummyNet(self._observation_spec, self._action_spec)
agent = agent_class(
self._time_step_spec, self._action_spec, q_network=q_net, optimizer=None
Expand Down Expand Up @@ -475,6 +481,10 @@ def testLossNStepMidMidLastFirst(self, agent_class):
self.assertAllClose(self.evaluate(loss), expected_loss)

def testLossWithMaskedActions(self, agent_class):

# if 'D3qnAgent' in agent_class.__name__:
# self.skipTest('invalid for dueling networks')

# Observations are now a tuple of the usual observation and an action mask.
observation_spec_with_mask = (
self._observation_spec,
Expand Down Expand Up @@ -529,10 +539,22 @@ def testLossWithMaskedActions(self, agent_class):
# Target Q-value for second next_observation (only action 0 is valid):
# 2 * 7 + 1 * 8 + 1 = 23
# TD targets: 10 + 0.9 * 12 = 20.8 and 20 + 0.9 * 23 = 40.7
# TD errors: 20.8 - 5 = 15.8 and 40.7 - 8 = 32.7
# TD loss: 15.3 and 32.2 (Huber loss subtracts 0.5)
# TD errors: 20.8 - 5 = 20.3 and 40.7 - 8 = 32.7
# TD loss: 19.8 and 32.2 (Huber loss subtracts 0.5)
# Overall loss: (15.3 + 32.2) / 2 = 23.75
expected_loss = 23.75
if 'D3qnAgent' in agent_class.__name__:
# Using Q-values for next_observations only for D3qnAgent.
# Q-value for first next_observation/action pair:
# 2 * 5 + 1 * 6 + 1 = 17
# Q-value for second next_observation/action pair:
# 1 * 7 + 1 * 8 + 1 = 16
# TD targets: 10 + 0.9 * 17 = 25.3 and 20 + 0.9 * 23 = 40.7
# TD errors: 25.3 - 5 = 20.3 and 40.7 - 8 = 32.7
# TD loss: 19.8 and 32.2 (Huber loss subtracts 0.5)
# Overall loss: (19.8 + 32.2) / 2 = 26.0
expected_loss = 26.0

loss, _ = agent._loss(experience)

self.evaluate(tf.compat.v1.global_variables_initializer())
Expand Down
2 changes: 2 additions & 0 deletions tf_agents/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
# pylint: disable=g-import-not-at-top
try:
from tf_agents.environments import gym_wrapper
from tf_agents.environments import gymnasium_wrapper
from tf_agents.environments import suite_gym
from tf_agents.environments import suite_gymnasium
from tf_agents.environments import suite_atari
from tf_agents.environments import suite_dm_control
from tf_agents.environments import suite_mujoco
Expand Down
9 changes: 9 additions & 0 deletions tf_agents/environments/configs/suite_gymnasium.gin
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#-*-Python-*-
import tf_agents.environments.suite_gymnasium

## Configure Environment
ENVIRONMENT = @suite_gymnasium.load()
suite_gymnasium.load.environment_name = %ENVIRONMENT_NAME
# Note: The ENVIRONMENT_NAME can be overridden by passing the command line flag:
# --params="ENVIRONMENT_NAME='CartPole-v1'"
ENVIRONMENT_NAME = 'CartPole-v1'
Loading

0 comments on commit d185fb8

Please sign in to comment.