agent.py
import os
import random
import collections

import gym
import numpy as np

from deep_q_network import DQN

# Note: this code uses the classic gym API (env.reset() returns the state,
# env.step() returns a 4-tuple), i.e. gym < 0.26.
class Agent:
    def __init__(self, env):
        # DQN Env Variables
        self.env = env
        self.actions = self.env.action_space.n
        self.observations = self.env.observation_space.shape
        # DQN Agent Variables
        self.epsilon = 1.0            # exploration rate, decayed per stored transition
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.replay_buffer_size = 50000
        self.train_start = 2          # very small; looks like a debug value (~1000 is more typical)
        self.memory = collections.deque(maxlen=self.replay_buffer_size)
        self.gamma = 0.95             # discount factor
        # DQN Network Variables
        self.state_shape = self.observations
        self.learning_rate = 1e-3
        self.model = DQN(self.state_shape, self.actions, self.learning_rate)
        self.batch_size = 2           # very small; looks like a debug value (32 is more typical)
    def get_action(self, state):
        if np.random.rand() <= self.epsilon:
            return np.random.randint(self.actions)
        else:
            return np.argmax(self.model.predict(state))
    def train(self, num_episodes):
        best_total_reward = 0.0
        for episode in range(num_episodes):
            total_reward = 0.0
            state = self.env.reset()
            state = np.reshape(state, (1, state.shape[0]))
            while True:
                action = self.get_action(state)
                next_state, reward, done, _ = self.env.step(action)
                next_state = np.reshape(next_state, (1, next_state.shape[0]))
                # Penalize early failure; CartPole-v1 caps an episode at 500 steps
                if done and total_reward < 499:
                    reward = -100
                self.remember(state, action, reward, next_state, done)
                self.replay()
                total_reward += reward
                state = next_state
                if done:
                    # Undo the failure penalty so the logged reward equals the episode length
                    if total_reward != 500:
                        total_reward += 100
                    if total_reward > best_total_reward:
                        best_total_reward = total_reward
                        if not os.path.exists('models'):
                            os.makedirs('models')
                        self.model.save_model('models/policy_weights.h5')
                    print(f"Episode: {episode + 1}  Total Reward: {total_reward}  Epsilon: {self.epsilon:.3f}")
                    break
    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
    def replay(self):
        if len(self.memory) < self.train_start:
            return
        minibatch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, states_next, dones = zip(*minibatch)
        states = np.concatenate(states)
        states_next = np.concatenate(states_next)
        q_values = self.model.predict(states)
        q_values_next = self.model.predict(states_next)
        for i in range(self.batch_size):
            a = actions[i]
            done = dones[i]
            if done:
                q_values[i][a] = rewards[i]
            else:
                # Bellman target: r + gamma * max_a' Q(s', a')
                q_values[i][a] = rewards[i] + self.gamma * np.max(q_values_next[i])
        self.model.train(states, q_values)
    def play(self, num_episodes, render=True):
        self.model.load_model('models/policy_weights.h5')
        for episode in range(num_episodes):
            state = self.env.reset()
            state = np.reshape(state, (1, state.shape[0]))
            while True:
                # Act greedily during evaluation; epsilon is still high after a
                # short training run, so get_action() would be mostly random here
                action = np.argmax(self.model.predict(state))
                next_state, reward, done, _ = self.env.step(action)
                next_state = np.reshape(next_state, (1, next_state.shape[0]))
                state = next_state
                if render:
                    self.env.render()
                if done:
                    break
if __name__ == "__main__":
    env = gym.make("CartPole-v1")
    agent = Agent(env)
    agent.train(num_episodes=2)
    agent.play(10)
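
The agent imports DQN from deep_q_network, which is not part of this file. For reference, a minimal sketch of the interface the agent relies on (predict, train, save_model, load_model) might look like the following; the two-hidden-layer architecture and layer sizes are assumptions for illustration, not the repository's actual network.

# deep_q_network.py (hypothetical sketch, assuming a Keras MLP)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

class DQN:
    def __init__(self, state_shape, num_actions, learning_rate):
        # Small fully connected network mapping a state to one Q-value per action
        self.model = Sequential([
            Dense(24, input_shape=state_shape, activation="relu"),
            Dense(24, activation="relu"),
            Dense(num_actions, activation="linear"),
        ])
        self.model.compile(loss="mse", optimizer=Adam(learning_rate=learning_rate))

    def predict(self, states):
        # Returns an array of shape (batch, num_actions)
        return self.model.predict(states, verbose=0)

    def train(self, states, q_values):
        # One gradient step toward the Bellman targets built in Agent.replay()
        self.model.fit(states, q_values, verbose=0)

    def save_model(self, path):
        self.model.save_weights(path)

    def load_model(self, path):
        self.model.load_weights(path)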