-
Notifications
You must be signed in to change notification settings - Fork 121
/
ddpg_with_vae.py
146 lines (120 loc) · 5.53 KB
/
ddpg_with_vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Copyright (c) 2018 Roma Sokolkov
# MIT License
"""
DDPGWithVAE inherits from the stable-baselines DDPG class
and reimplements its learn method.
"""
import time
import numpy as np
from mpi4py import MPI
from stable_baselines import logger
from stable_baselines.ddpg.ddpg import DDPG
class DDPGWithVAE(DDPG):
    """
    DDPG agent whose ``learn`` loop is modified relative to stable-baselines.

    - Stop rollout on episode done.
    - More verbosity.
    - Add VAE optimization step.
    """

    @staticmethod
    def _as_scalar(scalar):
        """
        Check and return the input if it is a scalar, otherwise raise ValueError.

        :param scalar: (Any) the object to check
        :return: (Number) the scalar if ``scalar`` is a scalar or a
            single-element numpy array
        :raises ValueError: if ``scalar`` is neither a scalar nor a
            single-element numpy array
        """
        if isinstance(scalar, np.ndarray):
            if scalar.size != 1:
                raise ValueError('expected scalar, got %s' % scalar)
            # Flatten first so 0-d and multi-dimensional single-element
            # arrays are handled the same way as shape (1,) arrays.
            return scalar.reshape(-1)[0]
        if np.isscalar(scalar):
            return scalar
        raise ValueError('expected scalar, got %s' % scalar)

    def learn(self, total_timesteps, callback=None, vae=None, skip_episodes=5):
        """
        Run episode-by-episode rollouts, training the VAE and DDPG in between.

        Each outer iteration resets the environment, rolls out one episode,
        then optimizes the VAE (when given) and, once more than
        ``skip_episodes`` episodes have finished, runs ``self.nb_train_steps``
        DDPG training steps and logs statistics.

        :param total_timesteps: (int) number of environment steps after which
            the method returns; the check happens before every step, so the
            method may return in the middle of an episode without a final
            training phase
        :param callback: (callable or None) if not None, called after every
            environment step as ``callback(locals(), globals())``
        :param vae: (object or None) object exposing ``optimize()``, called
            once after each rollout; when None the VAE step is skipped
        :param skip_episodes: (int) number of initial episodes during which
            transitions are not stored and DDPG is not trained
        :return: (DDPGWithVAE) ``self``
        """
        rank = MPI.COMM_WORLD.Get_rank()
        # we assume symmetric actions.
        assert np.all(np.abs(self.env.action_space.low) == self.env.action_space.high)

        self.episode_reward = np.zeros((1,))
        with self.sess.as_default(), self.graph.as_default():
            # Prepare everything.
            # NOTE: local names below are deliberately left unchanged (even the
            # ones this method never reads back, e.g. ``q_value`` and
            # ``episode_step``) because ``callback`` receives ``locals()``.
            self._reset()
            episode_reward = 0.
            episode_step = 0
            episodes = 0
            step = 0
            total_steps = 0

            start_time = time.time()

            actor_losses = []
            critic_losses = []

            while True:
                obs = self.env.reset()
                # Rollout one episode.
                while True:
                    if total_steps >= total_timesteps:
                        return self

                    # Predict next action.
                    action, q_value = self._policy(obs, apply_noise=True, compute_q=True)
                    print(action)
                    assert action.shape == self.env.action_space.shape

                    # Execute next action.
                    # Rendering (rank 0 only) happens both before and after
                    # the environment step.
                    if rank == 0 and self.render:
                        self.env.render()
                    # The policy output is scaled by |action_space.low| before
                    # being sent to the environment.
                    new_obs, reward, done, _ = self.env.step(action * np.abs(self.action_space.low))

                    step += 1
                    total_steps += 1
                    if rank == 0 and self.render:
                        self.env.render()
                    episode_reward += reward
                    episode_step += 1

                    # Book-keeping.
                    # Do not record observations, while we skip DDPG training.
                    if (episodes + 1) > skip_episodes:
                        self._store_transition(obs, action, reward, new_obs, done)
                    obs = new_obs
                    if callback is not None:
                        callback(locals(), globals())

                    if done:
                        print("episode finished. Reward: ", episode_reward)
                        # Episode done.
                        episode_reward = 0.
                        episode_step = 0
                        episodes += 1

                        self._reset()
                        # NOTE(review): the outer loop resets the environment
                        # again before the next rollout, so it is reset twice
                        # per episode boundary. Kept as-is; confirm whether
                        # the reset before training is intentional.
                        obs = self.env.reset()
                        # Finish rollout on episode finish.
                        break

                print("rollout finished")

                # Train VAE.
                # ``vae`` defaults to None, so only optimize when one was given.
                if vae is not None:
                    train_start = time.time()
                    vae.optimize()
                    print("VAE training duration:", time.time() - train_start)

                # Train DDPG.
                actor_losses = []
                critic_losses = []
                train_start = time.time()
                if episodes > skip_episodes:
                    for t_train in range(self.nb_train_steps):
                        critic_loss, actor_loss = self._train_step(0, None, log=t_train == 0)
                        critic_losses.append(critic_loss)
                        actor_losses.append(actor_loss)
                        self._update_target_net()
                    print("DDPG training duration:", time.time() - train_start)

                    mpi_size = MPI.COMM_WORLD.Get_size()
                    # Log stats.
                    # XXX shouldn't call np.mean on variable length lists
                    duration = time.time() - start_time
                    stats = self._get_stats()
                    combined_stats = stats.copy()
                    combined_stats['train/loss_actor'] = np.mean(actor_losses)
                    combined_stats['train/loss_critic'] = np.mean(critic_losses)
                    combined_stats['total/duration'] = duration
                    combined_stats['total/steps_per_second'] = float(step) / float(duration)
                    combined_stats['total/episodes'] = episodes

                    # Reduce every statistic across MPI workers, then divide
                    # by the worker count to obtain the mean.
                    combined_stats_sums = MPI.COMM_WORLD.allreduce(
                        np.array([self._as_scalar(x) for x in combined_stats.values()]))
                    combined_stats = {k: v / mpi_size
                                      for (k, v) in zip(combined_stats.keys(), combined_stats_sums)}

                    # Total statistics.
                    combined_stats['total/steps'] = step

                    for key in sorted(combined_stats):
                        logger.record_tabular(key, combined_stats[key])
                    logger.dump_tabular()
                    logger.info('')