abstract.py

from abc import ABC, abstractmethod

import numpy as np
from gymnasium import spaces

from rl_agents.agents.common.abstract import AbstractStochasticAgent
from rl_agents.agents.common.exploration.abstract import exploration_factory
from rl_agents.agents.common.memory import ReplayMemory, Transition


class AbstractDQNAgent(AbstractStochasticAgent, ABC):
    def __init__(self, env, config=None):
        super(AbstractDQNAgent, self).__init__(config)
        self.env = env
        assert isinstance(env.action_space, spaces.Discrete) or isinstance(env.action_space, spaces.Tuple), \
            "Only compatible with Discrete action spaces."
        self.memory = ReplayMemory(self.config)
        self.exploration_policy = exploration_factory(self.config["exploration"], self.env.action_space)
        self.training = True
        self.previous_state = None

    @classmethod
    def default_config(cls):
        return dict(model=dict(type="DuelingNetwork"),
                    optimizer=dict(type="ADAM",
                                   lr=5e-4,
                                   weight_decay=0,
                                   k=5),
                    loss_function="l2",
                    memory_capacity=50000,
                    batch_size=100,
                    gamma=0.99,
                    device="cuda:best",
                    exploration=dict(method="EpsilonGreedy"),
                    target_update=1,
                    double=True)
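
    # Usage note (added for illustration, not part of the original source): these are
    # defaults, and any subset can be overridden by the config dict passed to the
    # constructor, which the configuration machinery merges with default_config().
    # For example, with a concrete subclass named DQNAgent:
    #
    #     agent = DQNAgent(env, config={"batch_size": 32, "gamma": 0.95})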

    def record(self, state, action, reward, next_state, done, info):
        """
        Record a transition by performing a Deep Q-Network iteration:

        - push the transition into memory
        - sample a minibatch
        - compute the Bellman residual loss over the minibatch
        - perform one gradient descent step
        - slowly track the policy network with the target network

        :param state: a state
        :param action: an action
        :param reward: a reward
        :param next_state: a next state
        :param done: whether the transition is terminal
        :param info: the environment info dict
        """
        if not self.training:
            return
        if isinstance(state, tuple) and isinstance(action, tuple):  # Multi-agent setting
            for agent_state, agent_action, agent_next_state in zip(state, action, next_state):
                self.memory.push(agent_state, agent_action, reward, agent_next_state, done, info)
        else:  # Single-agent setting
            self.memory.push(state, action, reward, next_state, done, info)
        batch = self.sample_minibatch()
        if batch:
            loss, _, _ = self.compute_bellman_residual(batch)
            self.step_optimizer(loss)
            self.update_target_network()

    def act(self, state, step_exploration_time=True):
        """
        Act according to the state-action value model and an exploration policy

        :param state: current state
        :param step_exploration_time: step the exploration schedule
        :return: an action
        """
        self.previous_state = state
        if step_exploration_time:
            self.exploration_policy.step_time()
        # Handle multi-agent observations
        # TODO: it would be more efficient to forward a batch of states
        if isinstance(state, tuple):
            return tuple(self.act(agent_state, step_exploration_time=False) for agent_state in state)

        # Single-agent setting
        values = self.get_state_action_values(state)
        self.exploration_policy.update(values)
        return self.exploration_policy.sample()

    def sample_minibatch(self):
        if len(self.memory) < self.config["batch_size"]:
            return None
        transitions = self.memory.sample(self.config["batch_size"])
        return Transition(*zip(*transitions))
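
    # Note (added for clarity): Transition is a namedtuple, so zip(*transitions)
    # above transposes a list of per-step transitions into a single Transition
    # whose fields hold batched values. Illustratively, with two samples t1, t2:
    #
    #     batch = Transition(*zip(*[t1, t2]))
    #     batch.state  == (t1.state, t2.state)
    #     batch.reward == (t1.reward, t2.reward)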

    def update_target_network(self):
        # self.steps, self.value_net and self.target_net are expected to be
        # defined by concrete subclasses (e.g. the PyTorch implementation)
        self.steps += 1
        if self.steps % self.config["target_update"] == 0:
            self.target_net.load_state_dict(self.value_net.state_dict())

    @abstractmethod
    def compute_bellman_residual(self, batch, target_state_action_value=None):
        """
        Compute the Bellman residual loss over a batch

        :param batch: batch of transitions
        :param target_state_action_value: if provided, acts as a target (s,a)-value;
                                          if not, it will be computed from the batch and the model (Double DQN target)
        :return: the loss over the batch, and the computed target
        """
        raise NotImplementedError
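
    # Sketch of the quantity implementations typically compute here (assuming the
    # standard Double DQN target; notation only, not tied to a specific backend):
    #
    #     target = r + gamma * (1 - done) * Q_target(s', argmax_a Q_value(s', a))
    #     loss   = mean over the batch of l2(Q_value(s, a), target)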

    @abstractmethod
    def get_batch_state_values(self, states):
        """
        Get the state values of several states

        :param states: [s1; ...; sN] an array of states
        :return: values, actions:
                 - [V1; ...; VN] the array of state values for each state
                 - [a1*; ...; aN*] the array of corresponding optimal action indexes for each state
        """
        raise NotImplementedError

    @abstractmethod
    def get_batch_state_action_values(self, states):
        """
        Get the state-action values of several states

        :param states: [s1; ...; sN] an array of states
        :return: values: [[Q11, ..., Q1n]; ...] the array of all action values for each state
        """
        raise NotImplementedError

    def get_state_value(self, state):
        """
        :param state: s, an environment state
        :return: V, its state-value, and the corresponding optimal action index
        """
        values, actions = self.get_batch_state_values([state])
        return values[0], actions[0]

    def get_state_action_values(self, state):
        """
        :param state: s, an environment state
        :return: [Q(a1,s), ..., Q(an,s)] the array of its action-values for each action
        """
        return self.get_batch_state_action_values([state])[0]

    def step_optimizer(self, loss):
        raise NotImplementedError

    def seed(self, seed=None):
        return self.exploration_policy.seed(seed)

    def reset(self):
        pass

    def set_writer(self, writer):
        super().set_writer(writer)
        try:
            self.exploration_policy.set_writer(writer)
        except AttributeError:
            pass

    def action_distribution(self, state):
        self.previous_state = state
        values = self.get_state_action_values(state)
        self.exploration_policy.update(values)
        return self.exploration_policy.get_distribution()

    def set_time(self, time):
        self.exploration_policy.set_time(time)

    def eval(self):
        self.training = False
        self.config['exploration']['method'] = "Greedy"
        self.exploration_policy = exploration_factory(self.config["exploration"], self.env.action_space)
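

# Minimal usage sketch (added for illustration; assumes the concrete PyTorch
# subclass shipped with rl-agents and a gymnasium environment with a Discrete
# action space). This is a rough training loop, not the package's own evaluation
# script.
if __name__ == "__main__":
    import gymnasium as gym
    from rl_agents.agents.deep_q_network.pytorch import DQNAgent  # concrete AbstractDQNAgent

    env = gym.make("CartPole-v1")
    agent = DQNAgent(env, config={"batch_size": 32, "device": "cpu"})

    state, _ = env.reset()
    for _ in range(1_000):
        action = agent.act(state)  # exploration policy over the value network's Q-values
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        agent.record(state, action, reward, next_state, done, info={})  # one DQN iteration
        state = next_state
        if done:
            state, _ = env.reset()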