-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbaselines.py
301 lines (252 loc) · 12.2 KB
/
baselines.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
""" Wrappers for use with jaxmarl baselines. """
import jax
import jax.numpy as jnp
import chex
import numpy as np
from flax import struct
from functools import partial
# from gymnax.environments import environment, spaces
from gymnax.environments.spaces import Box as BoxGymnax, Discrete as DiscreteGymnax
from typing import Optional, List, Tuple, Union
from jaxmarl.environments.spaces import Box, Discrete, MultiDiscrete
from jaxmarl.environments.multi_agent_env import MultiAgentEnv, State
class JaxMARLWrapper:
    """Common base for JaxMARL environment wrappers.

    Stores the wrapped environment as ``self._env`` and forwards every
    attribute that the wrapper itself does not define to that environment.
    """

    def __init__(self, env: MultiAgentEnv):
        self._env = env

    def __getattr__(self, name: str):
        # Python only calls __getattr__ after normal lookup fails, so
        # attributes set on the wrapper take precedence over the env's.
        wrapped = self._env
        return getattr(wrapped, name)

    def _batchify_floats(self, x: dict):
        """Stack the per-agent entries of ``x`` following ``self._env.agents`` order."""
        ordered = tuple(x[agent] for agent in self._env.agents)
        return jnp.stack(ordered)
@struct.dataclass
class LogEnvState:
    """Wrapped environment state plus running and returned episode statistics.

    The statistic fields hold JAX arrays with one entry per agent
    (``LogWrapper.reset`` initialises them with ``jnp.zeros((num_agents,))``),
    so they are annotated as ``chex.Array`` rather than Python scalars.
    Field order matters: ``LogWrapper.reset`` builds instances positionally.
    """
    env_state: State
    # Return accumulated so far in the current episode.
    episode_returns: chex.Array
    # Number of steps taken so far in the current episode.
    episode_lengths: chex.Array
    # Return of the most recently finished episode.
    returned_episode_returns: chex.Array
    # Length of the most recently finished episode.
    returned_episode_lengths: chex.Array
class LogWrapper(JaxMARLWrapper):
    """Log per-agent episode returns and lengths.

    NOTE: for now only for envs where agents terminate at the same time;
    episode boundaries are taken from ``done["__all__"]``.

    Args:
        env: environment to wrap.
        replace_info: if True, ``step`` returns an info dict holding only the
            logging keys instead of adding them to the env's own info dict.
    """

    def __init__(self, env: MultiAgentEnv, replace_info: bool = False):
        super().__init__(env)
        self.replace_info = replace_info

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, LogEnvState]:
        """Reset the wrapped env and zero every per-agent statistic."""
        obs, env_state = self._env.reset(key)
        state = LogEnvState(
            env_state,
            jnp.zeros((self._env.num_agents,)),
            jnp.zeros((self._env.num_agents,)),
            jnp.zeros((self._env.num_agents,)),
            jnp.zeros((self._env.num_agents,)),
        )
        return obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: LogEnvState,
        action: Union[int, float],
    ) -> Tuple[chex.Array, LogEnvState, float, bool, dict]:
        """Step the wrapped env and update the episode statistics.

        When ``done["__all__"]`` is set, the running return/length are copied
        into the ``returned_*`` fields and the running counters restart at 0.
        Observations, rewards and dones are passed through unchanged.
        """
        obs, env_state, reward, done, info = self._env.step(
            key, state.env_state, action
        )
        ep_done = done["__all__"]
        # Stack the per-agent rewards and flatten to exactly one value per agent.
        # Rewards of shape (1, 1) per agent (stacked: (num_agents, 1, 1)) and plain
        # scalar rewards (stacked: (num_agents,)) both become (num_agents,);
        # squeezing axes (1, 2) would reject the scalar case.
        batch_reward = self._batchify_floats(reward)
        batch_reward = jnp.reshape(batch_reward, (batch_reward.shape[0],))
        new_episode_return = state.episode_returns + batch_reward
        new_episode_length = state.episode_lengths + 1
        # Arithmetic masking on ep_done keeps the update branch-free under jit.
        state = LogEnvState(
            env_state=env_state,
            episode_returns=new_episode_return * (1 - ep_done),
            episode_lengths=new_episode_length * (1 - ep_done),
            returned_episode_returns=state.returned_episode_returns * (1 - ep_done)
            + new_episode_return * ep_done,
            returned_episode_lengths=state.returned_episode_lengths * (1 - ep_done)
            + new_episode_length * ep_done,
        )
        if self.replace_info:
            info = {}
        info["returned_episode_returns"] = state.returned_episode_returns
        info["returned_episode_lengths"] = state.returned_episode_lengths
        info["returned_episode"] = jnp.full((self._env.num_agents,), ep_done)
        return obs, state, reward, done, info
class MPELogWrapper(LogWrapper):
    """LogWrapper variant that logs rewards multiplied by the number of agents,
    to match the on-policy codebase.

    Only the logged returns are scaled; the reward handed back to the caller
    is the wrapped env's unscaled reward.
    """

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: LogEnvState,
        action: Union[int, float],
    ) -> Tuple[chex.Array, LogEnvState, float, bool, dict]:
        """Step the wrapped env and update the (scaled) episode statistics."""
        obs, env_state, reward, done, info = self._env.step(
            key, state.env_state, action
        )
        # jax.tree_map is deprecated (and removed in newer JAX releases);
        # jax.tree_util.tree_map is the stable spelling.
        rewardlog = jax.tree_util.tree_map(lambda x: x * self._env.num_agents, reward)
        ep_done = done["__all__"]
        new_episode_return = state.episode_returns + self._batchify_floats(rewardlog)
        new_episode_length = state.episode_lengths + 1
        # Arithmetic masking on ep_done keeps the update branch-free under jit.
        state = LogEnvState(
            env_state=env_state,
            episode_returns=new_episode_return * (1 - ep_done),
            episode_lengths=new_episode_length * (1 - ep_done),
            returned_episode_returns=state.returned_episode_returns * (1 - ep_done)
            + new_episode_return * ep_done,
            returned_episode_lengths=state.returned_episode_lengths * (1 - ep_done)
            + new_episode_length * ep_done,
        )
        if self.replace_info:
            info = {}
        info["returned_episode_returns"] = state.returned_episode_returns
        info["returned_episode_lengths"] = state.returned_episode_lengths
        info["returned_episode"] = jnp.full((self._env.num_agents,), ep_done)
        return obs, state, reward, done, info
@struct.dataclass
class SMAXLogEnvState:
    """Wrapped SMAX env state plus episode statistics and win indicators.

    The statistic fields hold JAX arrays with one entry per agent
    (``SMAXLogWrapper.reset`` initialises them with
    ``jnp.zeros((num_agents,))``), so they are annotated as ``chex.Array``
    rather than Python scalars. Field order matters: ``SMAXLogWrapper.reset``
    builds instances positionally.
    """
    env_state: State
    # Return accumulated so far in the current episode.
    episode_returns: chex.Array
    # Number of steps taken so far in the current episode.
    episode_lengths: chex.Array
    # 1.0 where the latest step's reward was >= 1.0; reset to 0 when an episode ends.
    won_episode: chex.Array
    # Return of the most recently finished episode.
    returned_episode_returns: chex.Array
    # Length of the most recently finished episode.
    returned_episode_lengths: chex.Array
    # Win indicator recorded at the final step of the most recently finished episode.
    returned_won_episode: chex.Array
class SMAXLogWrapper(JaxMARLWrapper):
    """Log per-agent episode returns, lengths and a win indicator for SMAX.

    A step is marked as won for an agent when its reward is >= 1.0
    (NOTE(review): presumably SMAX only pays such a reward on victory —
    confirm). Episode boundaries are taken from ``done["__all__"]``.

    Args:
        env: environment to wrap.
        replace_info: if True, ``step`` returns an info dict holding only the
            logging keys instead of adding them to the env's own info dict.
    """

    def __init__(self, env: MultiAgentEnv, replace_info: bool = False):
        super().__init__(env)
        self.replace_info = replace_info

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, SMAXLogEnvState]:
        """Reset the wrapped env and zero every per-agent statistic."""
        obs, env_state = self._env.reset(key)
        state = SMAXLogEnvState(
            env_state,
            jnp.zeros((self._env.num_agents,)),
            jnp.zeros((self._env.num_agents,)),
            jnp.zeros((self._env.num_agents,)),
            jnp.zeros((self._env.num_agents,)),
            jnp.zeros((self._env.num_agents,)),
            jnp.zeros((self._env.num_agents,)),
        )
        return obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: SMAXLogEnvState,
        action: Union[int, float],
    ) -> Tuple[chex.Array, SMAXLogEnvState, float, bool, dict]:
        """Step the wrapped env and update returns, lengths and win flags.

        When ``done["__all__"]`` is set, the running values are copied into
        the ``returned_*`` fields and the running ones restart at 0.
        Observations, rewards and dones are passed through unchanged.
        """
        obs, env_state, reward, done, info = self._env.step(
            key, state.env_state, action
        )
        ep_done = done["__all__"]
        # Stack the per-agent rewards once and reuse them below.
        batch_reward = self._batchify_floats(reward)
        new_episode_return = state.episode_returns + batch_reward
        new_episode_length = state.episode_lengths + 1
        new_won_episode = (batch_reward >= 1.0).astype(jnp.float32)
        # Arithmetic masking on ep_done keeps the update branch-free under jit.
        state = SMAXLogEnvState(
            env_state=env_state,
            won_episode=new_won_episode * (1 - ep_done),
            episode_returns=new_episode_return * (1 - ep_done),
            episode_lengths=new_episode_length * (1 - ep_done),
            returned_episode_returns=state.returned_episode_returns * (1 - ep_done)
            + new_episode_return * ep_done,
            returned_episode_lengths=state.returned_episode_lengths * (1 - ep_done)
            + new_episode_length * ep_done,
            returned_won_episode=state.returned_won_episode * (1 - ep_done)
            + new_won_episode * ep_done,
        )
        if self.replace_info:
            info = {}
        info["returned_episode_returns"] = state.returned_episode_returns
        info["returned_episode_lengths"] = state.returned_episode_lengths
        info["returned_won_episode"] = state.returned_won_episode
        info["returned_episode"] = jnp.full((self._env.num_agents,), ep_done)
        return obs, state, reward, done, info
def get_space_dim(space):
    """Return the size used for ``space`` when building padded obs/action vectors.

    Discrete spaces (gymnax or jaxmarl) report their number of options ``n``.
    Box (gymnax or jaxmarl) and MultiDiscrete spaces report the product of
    their ``shape``.

    Raises:
        NotImplementedError: for any other space type. The offending space is
            included in the message instead of being printed to stdout.
    """
    if isinstance(space, (DiscreteGymnax, Discrete)):
        return space.n
    if isinstance(space, (BoxGymnax, Box, MultiDiscrete)):
        return np.prod(space.shape)
    raise NotImplementedError(
        'Current wrapper works only with Discrete/MultiDiscrete/Box action and obs spaces'
        f' (got {space!r})'
    )
class CTRolloutManager(JaxMARLWrapper):
    """
    Rollout Manager for Centralized Training with Parameter Sharing. Used by JaxMARL Q-Learning Baselines.

    - Batchifies multiple environments (the number of parallel envs is given by batch_size in __init__).
    - Adds a global state (obs["__all__"]) and a global reward (rewards["__all__"]) to the env.step returns.
    - Pads the observations of the agents so that they all have the same length.
    - Appends a one-hot encoded agent id to the observation vectors.
    Padding and the agent id are applied only when preprocess_obs is True.

    By default:
    - global_state is the concatenation of all agents' raw observations.
    - global_reward is the sum of the training agents' rewards.
    For environments whose name contains 'smax' or 'overcooked', __init__ replaces
    both with environment-specific versions.
    """
    def __init__(self, env: MultiAgentEnv, batch_size:int, training_agents:Optional[List]=None, preprocess_obs:bool=True):
        super().__init__(env)
        # Number of parallel environments handled by batch_reset/batch_step/batch_sample.
        self.batch_size = batch_size
        # Agents whose rewards feed the default global_reward; defaults to all agents.
        self.training_agents = self.agents if training_agents is None else training_agents
        self.preprocess_obs = preprocess_obs
        # Some envs expose empty per-agent space dicts; fall back to the single
        # observation_space()/action_space() accessors (self.observation_space is
        # forwarded to the wrapped env by JaxMARLWrapper.__getattr__).
        # NOTE(review): the fallback assumes the space is the same for every agent.
        if len(env.observation_spaces) == 0:
            self.observation_spaces = {agent:self.observation_space() for agent in self.agents}
        if len(env.action_spaces) == 0:
            self.action_spaces = {agent:env.action_space() for agent in self.agents}
        # One jitted, vmapped sampler per agent: a batch of PRNG keys -> a batch of actions.
        self.batch_samplers = {agent: jax.jit(jax.vmap(self.action_space(agent).sample, in_axes=0)) for agent in self.agents}
        # Largest observation / action size across agents (see get_space_dim);
        # observations are zero-padded up to max_obs_length in _preprocess_obs.
        self.max_obs_length = max(list(map(lambda x: get_space_dim(x), self.observation_spaces.values())))
        self.max_action_space = max(list(map(lambda x: get_space_dim(x), self.action_spaces.values())))
        # Length of a preprocessed observation: padded obs followed by the one-hot agent id.
        self.obs_size = self.max_obs_length + len(self.agents)
        self.agents_one_hot = {a:oh for a, oh in zip(self.agents, jnp.eye(len(self.agents)))}
        # Valid-action tables; these read u.n, so every action space must expose `.n`.
        self.valid_actions = {a:jnp.arange(u.n) for a, u in self.action_spaces.items()}
        # 1.0 for each of the agent's own actions, 0.0 padding up to max_action_space.
        self.valid_actions_oh ={a:jnp.concatenate((jnp.ones(u.n), jnp.zeros(self.max_action_space - u.n))) for a, u in self.action_spaces.items()}
        # Environment-specific global state/reward. These instance attributes shadow
        # the global_state/global_reward methods defined further down in this class.
        if 'smax' in env.name.lower():
            self.global_state = lambda obs, state: obs['world_state']
            self.global_reward = lambda rewards: rewards[self.training_agents[0]]
        elif 'overcooked' in env.name.lower():
            self.global_state = lambda obs, state: jnp.concatenate([obs[agent].ravel() for agent in self.agents], axis=-1)
            self.global_reward = lambda rewards: rewards[self.training_agents[0]]
    @partial(jax.jit, static_argnums=0)
    def batch_reset(self, key):
        """Reset batch_size environments, each with its own key split from `key`."""
        keys = jax.random.split(key, self.batch_size)
        return jax.vmap(self.wrapped_reset, in_axes=0)(keys)
    @partial(jax.jit, static_argnums=0)
    def batch_step(self, key, states, actions):
        """Step batch_size environments; states and actions are batched on axis 0."""
        keys = jax.random.split(key, self.batch_size)
        return jax.vmap(self.wrapped_step, in_axes=(0, 0, 0))(keys, states, actions)
    @partial(jax.jit, static_argnums=0)
    def wrapped_reset(self, key):
        """Reset one env and return (obs, state), with obs["__all__"] set to the global state."""
        obs_, state = self._env.reset(key)
        if self.preprocess_obs:
            # Flatten/pad each agent's obs and append its one-hot agent id.
            obs = jax.tree_util.tree_map(self._preprocess_obs, {agent:obs_[agent] for agent in self.agents}, self.agents_one_hot)
        else:
            obs = obs_
        # The global state is computed from the raw env observations, not the preprocessed ones.
        obs["__all__"] = self.global_state(obs_, state)
        return obs, state
    @partial(jax.jit, static_argnums=0)
    def wrapped_step(self, key, state, actions):
        """Step one env, adding obs["__all__"] (global state) and reward["__all__"] (global reward)."""
        if 'hanabi' in self._env.name.lower():
            # Give each action a leading axis of size 1.
            # NOTE(review): presumably required by the Hanabi env's step signature — confirm.
            actions = jax.tree_util.tree_map(lambda x:jnp.expand_dims(x, 0), actions)
        obs_, state, reward, done, infos = self._env.step(key, state, actions)
        if self.preprocess_obs:
            obs = jax.tree_util.tree_map(self._preprocess_obs, {agent:obs_[agent] for agent in self.agents}, self.agents_one_hot)
            # Zero out the preprocessed observation of every agent whose done flag is set.
            obs = jax.tree_util.tree_map(lambda d, o: jnp.where(d, 0., o), {agent:done[agent] for agent in self.agents}, obs)
        else:
            obs = obs_
        # The global state is computed from the raw env observations, not the preprocessed ones.
        obs["__all__"] = self.global_state(obs_, state)
        reward["__all__"] = self.global_reward(reward)
        return obs, state, reward, done, infos
    @partial(jax.jit, static_argnums=0)
    def global_state(self, obs, state):
        """Default global state: all agents' observations concatenated on the last axis."""
        return jnp.concatenate([obs[agent] for agent in self.agents], axis=-1)
    @partial(jax.jit, static_argnums=0)
    def global_reward(self, reward):
        """Default global reward: sum of the training agents' rewards."""
        return jnp.stack([reward[agent] for agent in self.training_agents]).sum(axis=0)
    def batch_sample(self, key, agent):
        """Sample batch_size random actions for `agent`, cast to integers."""
        return self.batch_samplers[agent](jax.random.split(key, self.batch_size)).astype(int)
    @partial(jax.jit, static_argnums=0)
    def _preprocess_obs(self, arr, extra_features):
        """Flatten `arr`, right-pad it with zeros to max_obs_length, then append `extra_features`."""
        arr = arr.flatten()
        # After flatten() arr is 1-D, so this is a single (0, n_pad) pair for the only axis.
        pad_width = [(0, 0)] * (arr.ndim - 1) + [(0, max(0, self.max_obs_length - arr.shape[-1]))]
        arr = jnp.pad(arr, pad_width, mode='constant', constant_values=0)
        arr = jnp.concatenate((arr, extra_features), axis=-1)
        return arr