-
Notifications
You must be signed in to change notification settings - Fork 380
/
td3_bc.py
336 lines (317 loc) · 17.3 KB
/
td3_bc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
from typing import List, Dict, Any, Tuple, Union
from easydict import EasyDict
from collections import namedtuple
import torch
import torch.nn.functional as F
import copy
from ding.torch_utils import Adam, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn
from .ddpg import DDPGPolicy
@POLICY_REGISTRY.register('td3_bc')
class TD3BCPolicy(DDPGPolicy):
    r"""
    Overview:
        Policy class of the TD3+BC algorithm (https://arxiv.org/pdf/2106.06860.pdf).
        TD3+BC is TD3 with an extra behaviour-cloning (BC) term in the actor loss. Since DDPG and TD3 share many
        common things, this class derives from ``DDPGPolicy``. The TD3 parts come from the config
        (``model.twin_critic=True``, ``learn.actor_update_freq=2``) plus the target policy smoothing noise added
        in ``_forward_learn``. The BC term and its weight ``learn.alpha`` are also applied in ``_forward_learn``.
    Property:
        learn_mode, collect_mode, eval_mode
    Config:
        == ==================== ======== ================== ================================= =======================
        ID Symbol               Type     Default Value      Description                       Other(Shape)
        == ==================== ======== ================== ================================= =======================
        1  ``type``             str      td3_bc             | RL policy register name, refer  | this arg is optional,
                                                            | to registry ``POLICY_REGISTRY`` | a placeholder
        2  ``cuda``             bool     False              | Whether to use cuda for network |
        3  | ``random_``        int      25000              | Number of randomly collected    | Default to 25000 for
           | ``collect_size``                               | training samples in replay      | DDPG/TD3, 10000 for
           |                                                | buffer when training starts.    | SAC.
        4  | ``model.twin_``    bool     True               | Whether to use two critic       | Default True for TD3,
           | ``critic``                                     | networks or only one.           | Clipped Double
           |                                                |                                 | Q-learning method in
           |                                                |                                 | TD3 paper.
        5  | ``learn.learning`` float    1e-3               | Learning rate for actor         |
           | ``_rate_actor``                                | network (aka. policy).          |
        6  | ``learn.learning`` float    1e-3               | Learning rate for critic        |
           | ``_rate_critic``                               | network (aka. Q-network).       |
        7  | ``learn.actor_``   int      2                  | Number of critic updates per    | Default 2 for TD3, 1
           | ``update_freq``                                | actor update: the actor is      | for DDPG. Delayed
           |                                                | updated once every N learn      | Policy Updates method
           |                                                | iterations.                     | in TD3 paper.
        8  | ``learn.noise``    bool     True               | Whether to add noise on target  | Default True for TD3,
           |                                                | network's action. NOTE: this    | False for DDPG.
           |                                                | class's ``_forward_learn``      | Target Policy Smoo-
           |                                                | always adds smoothing noise.    | thing Regularization
           |                                                |                                 | in TD3 paper.
        9  | ``learn.noise_``   float    0.2                | Sigma of the Gaussian target    |
           | ``sigma``                                      | policy smoothing noise.         |
        10 | ``learn.noise_``   dict     | dict(min=-0.5,   | Limit for range of target       |
           | ``range``                   | max=0.5,)        | policy smoothing noise,         |
           |                                                | aka. noise_clip.                |
        11 | ``learn.-``        bool     False              | Determine whether to ignore     | Use ignore_done only
           | ``ignore_done``                                | done flag.                      | in halfcheetah env.
        12 | ``learn.-``        float    0.005              | Used for soft update of the     | aka. Interpolation
           | ``target_theta``                               | target network.                 | factor in polyak aver
           |                                                |                                 | aging for target
           |                                                |                                 | networks.
        13 | ``learn.alpha``    float    2.5                | BC trade-off: the Q term of the | Larger alpha puts
           |                                                | actor loss is scaled by         | more weight on Q,
           |                                                | alpha / mean(abs(Q)) before the | less on BC.
           |                                                | BC (MSE) term is added.         |
        14 | ``collect.-``      float    0.1                | Used for add noise during co-   | Sample noise from dis
           | ``noise_sigma``                                | llection, through controlling   | tribution, Ornstein-
           |                                                | the sigma of distribution       | Uhlenbeck process in
           |                                                |                                 | DDPG paper, Gaussian
           |                                                |                                 | process in ours.
        == ==================== ======== ================== ================================= =======================
    """

    # You can refer to DDPG's default config for more details.
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='td3_bc',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool type) on_policy: Determine whether on-policy or off-policy.
        # on-policy setting influences the behaviour of buffer.
        # Default False in TD3.
        on_policy=False,
        # (bool) Whether use priority(priority sample, IS weight, update priority)
        # Default False in TD3.
        priority=False,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (int) Number of training samples(randomly collected) in replay buffer when training starts.
        # Default 25000 in DDPG/TD3.
        random_collect_size=25000,
        # (bool) Whether to standardize rewards within each training batch (see ``_forward_learn``).
        reward_batch_norm=False,
        action_space='continuous',
        model=dict(
            # (bool) Whether to use two critic networks or only one.
            # Clipped Double Q-Learning for Actor-Critic in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
            # Default True for TD3, False for DDPG.
            twin_critic=True,
            # (str type) action_space: Use regression trick for continuous action
            action_space='regression',
            # (int) Hidden size for actor network head.
            actor_head_hidden_size=256,
            # (int) Hidden size for critic network head.
            critic_head_hidden_size=256,
        ),
        learn=dict(
            # How many updates(iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy-> collect data -> ...
            update_per_collect=1,
            # (int) Minibatch size for gradient descent.
            batch_size=256,
            # (float) Learning rates for actor network(aka. policy).
            learning_rate_actor=1e-3,
            # (float) Learning rates for critic network(aka. Q-network).
            learning_rate_critic=1e-3,
            # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum)
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000.
            # However, interaction with HalfCheetah always gets done with False,
            # Since we inplace done==True with done==False to keep
            # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``),
            # when the episode step is greater than max episode step.
            ignore_done=False,
            # (float type) target_theta: Used for soft update of the target network,
            # aka. Interpolation factor in polyak averaging for target networks.
            # Default to 0.005.
            target_theta=0.005,
            # (float) discount factor for the discounted sum of rewards, aka. gamma.
            discount_factor=0.99,
            # (int) Number of critic updates per actor update: the actor is updated once every
            # ``actor_update_freq`` learn iterations.
            # Delayed Policy Updates in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
            # Default 1 for DDPG, 2 for TD3.
            actor_update_freq=2,
            # (bool) Whether to add noise on target network's action.
            # Target Policy Smoothing Regularization in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
            # Default True for TD3, False for DDPG.
            # NOTE: TD3BCPolicy._forward_learn does not read this flag; it always adds smoothing noise.
            noise=True,
            # (float) Sigma for smoothing noise added to target policy.
            noise_sigma=0.2,
            # (dict) Limit for range of target policy smoothing noise, aka. noise_clip.
            noise_range=dict(
                min=-0.5,
                max=0.5,
            ),
            # (float) TD3+BC trade-off: the Q term of the actor loss is scaled by alpha / mean(|Q|)
            # before the behaviour-cloning MSE term is added.
            alpha=2.5,
        ),
        collect=dict(
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
            # (float) It is a must to add noise during collection. So here omits "noise" and only set "noise_sigma".
            noise_sigma=0.1,
        ),
        eval=dict(
            evaluator=dict(
                # (int) Evaluate every "eval_freq" training iterations.
                eval_freq=5000,
            ),
        ),
        other=dict(
            replay_buffer=dict(
                # (int) Maximum size of replay buffer.
                replay_buffer_size=1000000,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return the default model setting of this policy.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): Model name ``'continuous_qac'`` and the list of module \
                paths where it is defined (``ding.model.template.qac``).
        """
        return 'continuous_qac', ['ding.model.template.qac']

    def _init_learn(self) -> None:
        """
        Overview:
            Learn mode init method. Runs ``DDPGPolicy._init_learn`` first. It then builds the actor and critic Adam \
            optimizers, both with gradient clip-norm 1.0. Finally it caches the TD3+BC hyper-parameters \
            ``alpha``, ``noise_sigma`` and ``noise_range`` from ``cfg.learn``.
        """
        super(TD3BCPolicy, self)._init_learn()
        # Scale of the Q term relative to the BC term in the actor loss (see ``_forward_learn``).
        self._alpha = self._cfg.learn.alpha
        # actor and critic optimizer, both clipping the gradient norm to 1.0
        self._optimizer_actor = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_actor,
            grad_clip_type='clip_norm',
            clip_value=1.0,
        )
        self._optimizer_critic = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_critic,
            grad_clip_type='clip_norm',
            clip_value=1.0,
        )
        # Target policy smoothing noise parameters, used when computing the target Q value.
        self.noise_sigma = self._cfg.learn.noise_sigma
        self.noise_range = self._cfg.learn.noise_range

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
            Every call updates the critic(s) with a 1-step TD target computed from noise-smoothed target actions.
            With ``twin_critic``, the target uses the minimum of the two target critics.
            Every ``actor_update_freq`` calls it also updates the actor with the TD3+BC loss
            ``-lambda * Q(s, pi(s)).mean() + MSE(pi(s), a)``, where ``lambda = alpha / mean(|Q(s, pi(s))|)``.
            The target model is updated at the end of every call.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs', 'done']
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including at least actor and critic lr, different losses, \
                per-sample ``priority`` and mean ``td_error``.
        """
        loss_dict = {}
        data = default_preprocess_learn(
            data,
            use_priority=self._cfg.priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # critic learn forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        next_obs = data['next_obs']
        reward = data['reward']
        if self._reward_batch_norm:
            # Standardize rewards within the batch; 1e-8 guards against a zero std.
            reward = (reward - reward.mean()) / (reward.std() + 1e-8)
        # current q value
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
        q_value_dict = {}
        if self._twin_critic:
            q_value_dict['q_value'] = q_value[0].mean()
            q_value_dict['q_value_twin'] = q_value[1].mean()
        else:
            q_value_dict['q_value'] = q_value.mean()
        # target q value.
        with torch.no_grad():
            next_action = self._target_model.forward(next_obs, mode='compute_actor')['action']
            # Target policy smoothing: clipped Gaussian noise on the target action, then clamp to [-1, 1].
            # NOTE(review): ``cfg.learn.noise`` is not read here, so this noise is always added. If the base class
            # also wraps ``self._target_model`` with action noise when ``learn.noise`` is True, the target
            # action would be perturbed twice -- TODO confirm against DDPGPolicy._init_learn.
            noise = (torch.randn_like(next_action) *
                     self.noise_sigma).clamp(self.noise_range['min'], self.noise_range['max'])
            next_action = (next_action + noise).clamp(-1, 1)
            next_data = {'obs': next_obs, 'action': next_action}
            target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
        if self._twin_critic:
            # TD3: two critic networks
            target_q_value = torch.min(target_q_value[0], target_q_value[1])  # find min one as target q value
            # critic network1
            td_data = v_1step_td_data(q_value[0], target_q_value, reward, data['done'], data['weight'])
            critic_loss, td_error_per_sample1 = v_1step_td_error(td_data, self._gamma)
            loss_dict['critic_loss'] = critic_loss
            # critic network2(twin network)
            td_data_twin = v_1step_td_data(q_value[1], target_q_value, reward, data['done'], data['weight'])
            critic_twin_loss, td_error_per_sample2 = v_1step_td_error(td_data_twin, self._gamma)
            loss_dict['critic_twin_loss'] = critic_twin_loss
            td_error_per_sample = (td_error_per_sample1 + td_error_per_sample2) / 2
        else:
            # DDPG: single critic network
            td_data = v_1step_td_data(q_value, target_q_value, reward, data['done'], data['weight'])
            critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
            loss_dict['critic_loss'] = critic_loss
        # ================
        # critic update
        # ================
        self._optimizer_critic.zero_grad()
        # Back-propagate the summed critic losses in one pass. The gradients equal those from separate
        # backward() calls. Unlike separate calls, a single pass cannot fail with "backward through the graph
        # a second time" when the twin critic outputs share graph nodes from the same forward pass.
        critic_total_loss = sum(v for k, v in loss_dict.items() if 'critic' in k)
        critic_total_loss.backward()
        self._optimizer_critic.step()
        # ===============================
        # actor learn forward and update
        # ===============================
        # Delayed policy update: the actor is updated once every ``self._actor_update_freq`` iterations.
        if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
            actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
            actor_data['obs'] = data['obs']
            actor_q_value = self._learn_model.forward(actor_data, mode='compute_critic')['q_value']
            if self._twin_critic:
                # Only the first critic's Q value drives the actor objective.
                actor_q_value = actor_q_value[0]
            # TD3+BC weight lambda = alpha / mean(|Q|); detached so it acts as a constant normaliser.
            lmbda = self._alpha / actor_q_value.abs().mean().detach()
            bc_loss = F.mse_loss(actor_data['action'], data['action'])
            actor_loss = lmbda * (-actor_q_value.mean()) + bc_loss
            loss_dict['actor_loss'] = actor_loss
            # actor update
            self._optimizer_actor.zero_grad()
            actor_loss.backward()
            self._optimizer_actor.step()
        # =============
        # after update
        # =============
        loss_dict['total_loss'] = sum(loss_dict.values())
        self._forward_learn_cnt += 1
        # NOTE(review): the target model is updated on every iteration, not only on actor-update iterations
        # as in the TD3 / TD3+BC reference implementations -- confirm this is intended.
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr_actor': self._optimizer_actor.defaults['lr'],
            'cur_lr_critic': self._optimizer_critic.defaults['lr'],
            'action': data['action'].mean(),
            'priority': td_error_per_sample.abs().tolist(),
            'td_error': td_error_per_sample.abs().mean(),
            **loss_dict,
            **q_value_dict,
        }

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode. Collates the per-env observations into one batch and runs the eval \
            model in ``mode='compute_actor'`` under ``torch.no_grad``. It then splits the output back per env.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): Mapping from each env_id to that env's slice of the actor output, \
                used for the interaction with env.
        ReturnsKeys
            - necessary: ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, mode='compute_actor')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}