# maddpg_torch_model.py
import ray
import gym
from gym.spaces import Discrete, Box
import numpy as np
from typing import List, Dict, Union
from ray.rllib.agents.ddpg.ddpg_torch_model import DDPGTorchModel
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import (
    TrainerConfigDict,
    LocalOptimizer,
    GradInfoDict,
)
from ray.rllib.agents.ddpg.noop_model import TorchNoopModel
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
torch, nn = try_import_torch()
def _make_continuous_space(space):
if isinstance(space, Box):
return space
elif isinstance(space, Discrete):
return Box(low=np.zeros((space.n,)), high=np.ones((space.n,)))
else:
raise UnsupportedSpaceException("Space {} is not supported.".format(space))
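# Example (illustrative comment, not part of the original source): a Discrete
# space is relaxed into a unit Box so discrete actions can be handled as
# continuous vectors by the centralized critic, e.g.
#     _make_continuous_space(Discrete(4))          # -> Box(0.0, 1.0, (4,))
#     _make_continuous_space(Box(-1.0, 1.0, (2,))) # -> returned unchanged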
def build_maddpg_models(
policy: Policy, obs_space: Box, action_space: Box, config: TrainerConfigDict
) -> ModelV2:
config["model"]["multiagent"] = config[
"multiagent"
] # Needed for critic obs_space and act_space
if policy.config["use_state_preprocessor"]:
default_model = None # catalog decides
num_outputs = 256 # arbitrary
config["model"]["no_final_linear"] = True
else:
default_model = TorchNoopModel
num_outputs = np.prod(obs_space.shape)
policy.model = ModelCatalog.get_model_v2(
obs_space=obs_space,
action_space=action_space,
num_outputs=num_outputs,
model_config=config["model"],
framework=config["framework"],
model_interface=MADDPGTorchModel,
default_model=default_model,
name="maddpg_model",
actor_hidden_activation=config["actor_hidden_activation"],
actor_hiddens=config["actor_hiddens"],
critic_hidden_activation=config["critic_hidden_activation"],
critic_hiddens=config["critic_hiddens"],
twin_q=config["twin_q"],
add_layer_norm=(
policy.config["exploration_config"].get("type") == "ParameterNoise"
),
)
policy.target_model = ModelCatalog.get_model_v2(
obs_space=obs_space,
action_space=action_space,
num_outputs=num_outputs,
model_config=config["model"],
framework=config["framework"],
model_interface=MADDPGTorchModel,
default_model=default_model,
name="target_maddpg_model",
actor_hidden_activation=config["actor_hidden_activation"],
actor_hiddens=config["actor_hiddens"],
critic_hidden_activation=config["critic_hidden_activation"],
critic_hiddens=config["critic_hiddens"],
twin_q=config["twin_q"],
add_layer_norm=(
policy.config["exploration_config"].get("type") == "ParameterNoise"
),
)
return policy.model
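# Usage note (a hedged sketch, not from the original file): this builder is
# meant to be plugged into an RLlib torch policy as its model factory, e.g.
# via the policy template's `make_model` hook in Ray 1.x
# (`build_policy_class(..., make_model=build_maddpg_models)`). The exact
# template API and keyword name vary across Ray versions, so treat this
# wiring as an assumption rather than a confirmed signature.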
class MADDPGTorchModel(TorchModelV2, nn.Module):
"""
Extension of TorchModelV2 for MADDPG
Note that the critic takes in the joint state and action over all agents
Data flow:
obs -> forward() -> model_out
model_out -> get_policy_output() -> pi(s)
model_out, actions -> get_q_values() -> Q(s, a)
model_out, actions -> get_twin_q_values() -> Q_twin(s, a)
Note that this class by itself is not a valid model unless you
implement forward() in a subclass.
"""
def __init__(
self,
observation_space: Box,
action_space: Box,
num_outputs: int,
model_config: ModelConfigDict,
name: str,
# Extra MADDPGActionModel args:
actor_hiddens: List[int] = [256, 256],
actor_hidden_activation: str = "relu",
critic_hiddens: List[int] = [256, 256],
critic_hidden_activation: str = "relu",
twin_q: bool = False,
add_layer_norm: bool = False,
):
nn.Module.__init__(self)
TorchModelV2.__init__(
self, observation_space, action_space, num_outputs, model_config, name
)
self.bounded = np.logical_and(self.action_space.bounded_above,
self.action_space.bounded_below).any()
self.action_dim = np.prod(self.action_space.shape)
# Build the policy network.
self.policy_model = nn.Sequential()
ins = int(np.prod(observation_space.shape))
self.obs_ins = ins
activation = get_activation_fn(actor_hidden_activation, framework="torch")
for i, n in enumerate(actor_hiddens):
self.policy_model.add_module(
"action_{}".format(i),
SlimFC(
ins,
n,
initializer=torch.nn.init.xavier_uniform_,
activation_fn=activation,
),
)
# Add LayerNorm after each Dense.
if add_layer_norm:
self.policy_model.add_module(
"LayerNorm_A_{}".format(i), nn.LayerNorm(n)
)
ins = n
self.policy_model.add_module(
"action_out",
SlimFC(
ins,
self.action_dim,
initializer=torch.nn.init.xavier_uniform_,
activation_fn=None,
),
)
# Use sigmoid to scale to [0,1], but also double magnitude of input to
# emulate behaviour of tanh activation used in DDPG and TD3 papers.
# After sigmoid squashing, re-scale to env action space bounds.
class _Lambda(nn.Module):
def __init__(self_):
super().__init__()
low_action = nn.Parameter(
torch.from_numpy(self.action_space.low).float())
low_action.requires_grad = False
self_.register_parameter("low_action", low_action)
action_range = nn.Parameter(
torch.from_numpy(self.action_space.high -
self.action_space.low).float())
action_range.requires_grad = False
self_.register_parameter("action_range", action_range)
def forward(self_, x):
sigmoid_out = nn.Sigmoid()(2.0 * x)
squashed = self_.action_range * sigmoid_out + self_.low_action
return squashed
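        # Equivalently (illustrative note, not in the original source): the
        # squash above computes  low + (high - low) * sigmoid(2 * x),
        # mapping the unbounded network output x into [low, high].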
# Only squash if we have bounded actions.
if self.bounded:
self.policy_model.add_module("action_out_squashed", _Lambda())
# Build MADDPG Critic and Target Critic
obs_space_n = [
_make_continuous_space(space)
for _, (_, space, _, _) in model_config["multiagent"]["policies"].items()
]
act_space_n = [
_make_continuous_space(space)
for _, (_, _, space, _) in model_config["multiagent"]["policies"].items()
]
self.critic_obs = np.sum([obs_space.shape[0] for obs_space in obs_space_n])
self.critic_act = np.sum([act_space.shape[0] for act_space in act_space_n])
# Build the Q-net(s), including target Q-net(s).
def build_q_net(name_):
activation = get_activation_fn(critic_hidden_activation, framework="torch")
q_net = nn.Sequential()
ins = self.critic_obs + self.critic_act
for i, n in enumerate(critic_hiddens):
q_net.add_module(
"{}_hidden_{}".format(name_, i),
SlimFC(
ins,
n,
initializer=nn.init.xavier_uniform_,
activation_fn=activation,
),
)
ins = n
q_net.add_module(
"{}_out".format(name_),
SlimFC(
ins,
1,
initializer=torch.nn.init.xavier_uniform_,
activation_fn=None,
),
)
return q_net
self.q_model = build_q_net("q")
if twin_q:
self.twin_q_model = build_q_net("twin_q")
else:
self.twin_q_model = None
self.view_requirements[SampleBatch.ACTIONS] = ViewRequirement(
SampleBatch.ACTIONS
)
self.view_requirements["new_actions"] = ViewRequirement("new_actions")
self.view_requirements["t"] = ViewRequirement("t")
self.view_requirements[SampleBatch.NEXT_OBS] = ViewRequirement(
data_col=SampleBatch.OBS, shift=1, space=self.obs_space
)
def get_q_values(
self, model_out_n: List[TensorType], act_n: List[TensorType]
) -> TensorType:
"""Return the Q estimates for the most recent forward pass.
This implements Q(s, a).
Args:
model_out_n (List[Tensor]): obs embeddings from the model layers of each agent,
of shape [BATCH_SIZE, num_outputs].
            act_n (List[Tensor]): Actions from each agent to return the
                Q-values for, each of shape [BATCH_SIZE, action_dim].
Returns:
            Tensor of shape [BATCH_SIZE, 1] with the joint Q-value estimates.
"""
model_out_n = torch.cat(model_out_n, -1)
act_n = torch.cat(act_n, dim=-1)
return self.q_model(torch.cat([model_out_n, act_n], -1))
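    # Shape note (added comment, not in the original file): with n agents whose
    # flattened obs dims are o_1..o_n and action dims a_1..a_n, the critic
    # input assembled above has shape [BATCH_SIZE, sum(o_i) + sum(a_i)], i.e.
    # exactly self.critic_obs + self.critic_act used to size the Q-net.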
    def get_twin_q_values(
        self, model_out_n: List[TensorType], act_n: List[TensorType]
    ) -> TensorType:
"""Same as get_q_values but using the twin Q net.
This implements the twin Q(s, a).
Args:
model_out_n (List[Tensor]): obs embeddings from the model layers of each agent,
of shape [BATCH_SIZE, num_outputs].
            act_n (List[Tensor]): Actions from each agent to return the
                Q-values for, each of shape [BATCH_SIZE, action_dim].
Returns:
            Tensor of shape [BATCH_SIZE, 1] with the twin Q-value estimates.
"""
model_out_n = torch.cat(model_out_n, -1)
act_n = torch.cat(act_n, dim=-1)
return self.twin_q_model(torch.cat([model_out_n, act_n], -1))
def get_policy_output(self, model_out: TensorType) -> TensorType:
"""Return the action output for the most recent forward pass.
        For the continuous (or continuously relaxed discrete) action spaces
        used here, this is the deterministic action output, squashed to the
        action-space bounds when the space is bounded.
Args:
model_out (Tensor): obs embeddings from the model layers, of shape
[BATCH_SIZE, num_outputs].
Returns:
tensor of shape [BATCH_SIZE, action_out_size]
"""
return self.policy_model(model_out)
def policy_variables(
self, as_dict: bool = False
) -> Union[List[TensorType], Dict[str, TensorType]]:
"""Return the list of variables for the policy net."""
if as_dict:
return self.policy_model.state_dict()
return list(self.policy_model.parameters())
def q_variables(
self, as_dict=False
) -> Union[List[TensorType], Dict[str, TensorType]]:
"""Return the list of variables for Q / twin Q nets."""
if as_dict:
return {
**self.q_model.state_dict(),
**(self.twin_q_model.state_dict() if self.twin_q_model else {}),
}
return list(self.q_model.parameters()) + (
list(self.twin_q_model.parameters()) if self.twin_q_model else []
)