# parallel_runner.py
from envs import REGISTRY as env_REGISTRY
from functools import partial
from components.episode_buffer import EpisodeBatch
from multiprocessing import Pipe, Process
from runners.episode_runner import EpisodeRunner
import numpy as np
import torch as th
import copy


# Based (very) heavily on SubprocVecEnv from OpenAI Baselines
# https://github.com/openai/baselines/blob/master/baselines/common/vec_env/subproc_vec_env.py
class ParallelRunner:
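    """
    Runs a batch of environments in parallel, one per subprocess, and collects
    their transitions into a single EpisodeBatch per call to run().
    """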
    def __init__(self, args, logger):
        self.args = args
        self.logger = logger
        self.batch_size = self.args.batch_size_run

        # Make subprocesses for the envs
        self.parent_conns, self.worker_conns = zip(*[Pipe() for _ in range(self.batch_size)])
        env_fn = env_REGISTRY[self.args.env]
        env_args = [self.args.env_args.copy() for _ in range(self.batch_size)]
        for i in range(self.batch_size):
            env_args[i]["seed"] += i

        self.ps = [Process(target=env_worker, args=(worker_conn, CloudpickleWrapper(partial(env_fn, **env_arg))))
                   for env_arg, worker_conn in zip(env_args, self.worker_conns)]

        for p in self.ps:
            p.daemon = True
            p.start()

        # The envs differ only in their seed, so query the first worker for the env info
        self.parent_conns[0].send(("get_env_info", None))
        self.env_info = self.parent_conns[0].recv()
        self.episode_limit = self.env_info["episode_limit"]

        self.t = 0
        self.t_env = 0

        self.train_returns = []
        self.test_returns = []
        self.train_stats = {}
        self.test_stats = {}

        self.log_train_stats_t = -100000

        # Also initialize an episode runner, used only for the animations
        dummy_args = copy.copy(args)
        dummy_args.batch_size_run = 1  # the episode runner always runs a single env
        self.ep_runner = EpisodeRunner(dummy_args, logger=None)

    def setup(self, scheme, groups, preprocess, mac):
        self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1,
                                 preprocess=preprocess, device=self.args.device)
        self.mac = mac
        self.scheme = scheme
        self.groups = groups
        self.preprocess = preprocess

        # Also set up the episode runner
        self.ep_runner.setup(scheme, groups, preprocess, mac)

    def get_env_info(self):
        return self.env_info

    def save_replay(self, path):
        self.ep_runner.save_replay(path)

    def save_animation(self, path):
        self.ep_runner.save_animation(path)

    def close_env(self):
        for parent_conn in self.parent_conns:
            parent_conn.send(("close", None))

    def reset(self):
        self.batch = self.new_batch()

        # Reset the envs
        for parent_conn in self.parent_conns:
            parent_conn.send(("reset", None))

        pre_transition_data = {
            "state": [],
            "avail_actions": [],
            "obs": []
        }
        # Get the obs, state and avail_actions back
        for parent_conn in self.parent_conns:
            data = parent_conn.recv()
            pre_transition_data["state"].append(data["state"])
            pre_transition_data["avail_actions"].append(data["avail_actions"])
            pre_transition_data["obs"].append(data["obs"])

        self.batch.update(pre_transition_data, ts=0)

        self.t = 0
        self.env_steps_this_run = 0

    def run(self, test_mode=False, render=False, save_animation=False, benchmark_mode=False):
        # Rendering, animation saving and benchmarking are delegated to the single-env episode runner
        if test_mode and (render or save_animation or benchmark_mode):
            return self.ep_runner.run(test_mode=True, render=render, save_animation=save_animation, benchmark_mode=benchmark_mode)

        self.reset()

        all_terminated = False
        episode_returns = [0 for _ in range(self.batch_size)]
        episode_lengths = [0 for _ in range(self.batch_size)]
        self.mac.init_hidden(batch_size=self.batch_size)
        terminated = [False for _ in range(self.batch_size)]
        envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed]
        final_env_infos = []  # may store extra stats like battle won. this is filled in ORDER OF TERMINATION

        while True:

            # Pass the entire batch of experiences up till now to the agents
            # Receive the actions for each agent at this timestep in a batch for each un-terminated env
            actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, bs=envs_not_terminated, test_mode=test_mode)
            cpu_actions = actions.to("cpu").numpy()

            # Update the actions taken
            actions_chosen = {
                "actions": actions.unsqueeze(1)
            }
            self.batch.update(actions_chosen, bs=envs_not_terminated, ts=self.t, mark_filled=False)

            # Send actions to each env
            action_idx = 0
            for idx, parent_conn in enumerate(self.parent_conns):
                if idx in envs_not_terminated:  # We produced actions for this env
                    if not terminated[idx]:  # Only send the actions to the env if it hasn't terminated
                        parent_conn.send(("step", cpu_actions[action_idx]))
                    action_idx += 1  # actions is not a list over every env

            # Update envs_not_terminated
            envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed]
            all_terminated = all(terminated)
            if all_terminated:
                break

            # Post step data we will insert for the current timestep
            post_transition_data = {
                "reward": [],
                "terminated": []
            }
            # Data for the next step we will insert in order to select an action
            pre_transition_data = {
                "state": [],
                "avail_actions": [],
                "obs": []
            }

            # Receive data back for each unterminated env
            for idx, parent_conn in enumerate(self.parent_conns):
                if not terminated[idx]:
                    data = parent_conn.recv()
                    # Remaining data for this current timestep
                    post_transition_data["reward"].append((data["reward"],))

                    episode_returns[idx] += data["reward"]
                    episode_lengths[idx] += 1
                    if not test_mode:
                        self.env_steps_this_run += 1

                    env_terminated = False
                    if data["terminated"]:
                        final_env_infos.append(data["info"])
                    if data["terminated"] and not data["info"].get("episode_limit", False):
                        env_terminated = True
                    terminated[idx] = data["terminated"]
                    post_transition_data["terminated"].append((env_terminated,))

                    # Data for the next timestep needed to select an action
                    pre_transition_data["state"].append(data["state"])
                    pre_transition_data["avail_actions"].append(data["avail_actions"])
                    pre_transition_data["obs"].append(data["obs"])

            # Add post_transition data into the batch
            self.batch.update(post_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=False)

            # Move onto the next timestep
            self.t += 1

            # Add the pre-transition data
            self.batch.update(pre_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=True)

        if not test_mode:
            self.t_env += self.env_steps_this_run

        # Get stats back for each env
        for parent_conn in self.parent_conns:
            parent_conn.send(("get_stats", None))

        env_stats = []
        for parent_conn in self.parent_conns:
            env_stat = parent_conn.recv()
            env_stats.append(env_stat)

        cur_stats = self.test_stats if test_mode else self.train_stats
        cur_returns = self.test_returns if test_mode else self.train_returns
        log_prefix = "test_" if test_mode else ""
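        # Merge the infos returned by terminating envs into the running stats,
        # summing each key (e.g. battles won) across episodes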
        infos = [cur_stats] + final_env_infos
        cur_stats.update({k: sum(d.get(k, 0) for d in infos) for k in set.union(*[set(d) for d in infos])})
        cur_stats["n_episodes"] = self.batch_size + cur_stats.get("n_episodes", 0)
        cur_stats["ep_length"] = sum(episode_lengths) + cur_stats.get("ep_length", 0)

        cur_returns.extend(episode_returns)

        n_test_runs = max(1, self.args.test_nepisode // self.batch_size) * self.batch_size
        if test_mode and (len(self.test_returns) == n_test_runs):
            self._log(cur_returns, cur_stats, log_prefix)
        elif self.t_env - self.log_train_stats_t >= self.args.runner_log_interval:
            self._log(cur_returns, cur_stats, log_prefix)
            if hasattr(self.mac.action_selector, "epsilon"):
                self.logger.log_stat("epsilon", self.mac.action_selector.epsilon, self.t_env)
            self.log_train_stats_t = self.t_env

        return self.batch

    def _log(self, returns, stats, prefix):
        self.logger.log_stat(prefix + "return_mean", np.mean(returns), self.t_env)
        self.logger.log_stat(prefix + "return_std", np.std(returns), self.t_env)
        returns.clear()

        for k, v in stats.items():
            if k != "n_episodes":
                self.logger.log_stat(prefix + k + "_mean", v / stats["n_episodes"], self.t_env)
        stats.clear()

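# Each worker process owns one environment instance and serves the commands
# ("step", "reset", "close", "get_env_info", "get_stats") it receives over its end of the Pipe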
def env_worker(remote, env_fn):
    # Make environment
    env = env_fn.x()
    while True:
        cmd, data = remote.recv()
        if cmd == "step":
            actions = data
            # Take a step in the environment
            reward, terminated, env_info = env.step(actions)
            # Return the observations, avail_actions and state to make the next action
            state = env.get_state()
            avail_actions = env.get_avail_actions()
            obs = env.get_obs()
            remote.send({
                # Data for the next timestep needed to pick an action
                "state": state,
                "avail_actions": avail_actions,
                "obs": obs,
                # Rest of the data for the current timestep
                "reward": reward,
                "terminated": terminated,
                "info": env_info
            })
        elif cmd == "reset":
            env.reset()
            remote.send({
                "state": env.get_state(),
                "avail_actions": env.get_avail_actions(),
                "obs": env.get_obs()
            })
        elif cmd == "close":
            env.close()
            remote.close()
            break
        elif cmd == "get_env_info":
            remote.send(env.get_env_info())
        elif cmd == "get_stats":
            remote.send(env.get_stats())
        else:
            raise NotImplementedError


class CloudpickleWrapper():
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
    """
    def __init__(self, x):
        self.x = x

    def __getstate__(self):
        import cloudpickle
        return cloudpickle.dumps(self.x)

    def __setstate__(self, ob):
        import pickle
        self.x = pickle.loads(ob)