tutorial_tensordb_atari_pong_trainer.py
#! /usr/bin/python
# -*- coding: utf8 -*-
"""
To understand Reinforcement Learning, we let computer to learn how to play
Pong game from the original screen inputs. Before we start, we highly recommend
you to go through a famous blog called “Deep Reinforcement Learning: Pong from
Pixels” which is a minimalistic implementation of deep reinforcement learning by
using python-numpy and OpenAI gym environment.
The code here is the reimplementation of Karpathy's Blog by using TensorLayer.
Link
-----
http://karpathy.github.io/2016/05/31/rl/
"""
import tensorflow as tf
import tensorlayer as tl
import gym
import numpy as np
import time
# hyperparameters
image_size = 80
D = image_size * image_size  # input dimensionality (flattened 80x80 frame)
H = 200  # number of hidden-layer units
batch_size = 10
learning_rate = 1e-4
gamma = 0.99  # reward discount factor
decay_rate = 0.99  # RMSProp decay
# render = False # display the game environment
# resume = False # load existing policy network
# model_file_name = "model_pong"
np.set_printoptions(threshold=np.inf)  # print full arrays; np.nan is rejected as a threshold by newer NumPy
from tensorlayer.db import TensorDB
# Initialize the connection to your MongoDB server.
# Note: make sure your MongoDB is reachable, then fill in your own host, database name and credentials below.
db = TensorDB(ip='IP_ADDRESS_OR_YOUR_MONGODB', port=27017, db_name='DATABASE_NAME', user_name=None, password=None, studyID='ANY_ID (e.g., mnist)')
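# This file is only the trainer half of the distributed setup: one or more
# generator processes play Pong, preprocess the frames and push each batch of
# episodes into the same TensorDB instance. A hedged sketch of the push that
# the training loop below expects (the generator script itself is not shown
# here), assuming the generator collects observations `epx`, actions `epy`
# and rewards `epr`:
#
#   db.save_params([epx, epy, epr], args={'type': 'train_data'}, lz4_comp=True)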
# def prepro(I):
# """ prepro 210x160x3 uint8 frame into 6400 (80x80) 1D float vector """
# I = I[35:195]
# I = I[::2,::2,0]
# I[I == 144] = 0
# I[I == 109] = 0
# I[I != 0] = 1
# return I.astype(np.float).ravel()
# env = gym.make("Pong-v0")
# observation = env.reset()
# prev_x = None
# running_reward = None
# reward_sum = 0
# episode_number = 0
# xs, ys, rs = [], [], []
# observation for training and inference
states_batch_pl = tf.placeholder(tf.float32, shape=[None, D])
# policy network
net = tl.layers.InputLayer(states_batch_pl, name='input')
net = tl.layers.DenseLayer(net, n_units=H, act=tf.nn.relu, name='relu1')
net = tl.layers.DenseLayer(net, n_units=3, act=tf.identity, name='output')
probs = net.outputs  # action logits
sampling_prob = tf.nn.softmax(probs)  # action probabilities; not used in this trainer script (the generator side samples actions from the softmax)
actions_batch_pl = tf.placeholder(tf.int32, shape=[None])
discount_rewards_batch_pl = tf.placeholder(tf.float32, shape=[None])
loss = tl.rein.cross_entropy_reward_loss(probs, actions_batch_pl,
                                         discount_rewards_batch_pl)
train_op = tf.train.RMSPropOptimizer(learning_rate, decay_rate).minimize(loss)
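# tl.rein.cross_entropy_reward_loss implements the REINFORCE-style surrogate
# objective: roughly the per-step cross-entropy between the policy logits and
# the actions actually taken, weighted by the discounted return of each step,
# so minimizing it increases the log-probability of actions that led to high
# reward. A hedged sketch of the equivalent raw-TensorFlow expression,
# assuming this reading of the TensorLayer helper:
#
#   ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
#       labels=actions_batch_pl, logits=probs)
#   loss = tf.reduce_sum(tf.multiply(ce, discount_rewards_batch_pl))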
with tf.Session() as sess:
    tl.layers.initialize_global_variables(sess)
    # if resume:
    #     load_params = tl.files.load_npz(name=model_file_name + '.npz')
    #     tl.files.assign_params(sess, load_params, net)
    net.print_params()
    net.print_layers()

    start_time = time.time()
    game_number = 0
    n = 0
    total_n_examples = 0
    while True:
        is_found = False
        while is_found is False:
            ## read one batch of training data pushed by the generator
            data, f_id = db.find_one_params(args={'type': 'train_data'}, lz4_decomp=True)
            if data is not False:
                epx, epy, epr = data
                db.del_params(args={'type': 'train_data', 'f_id': f_id})
                is_found = True
            else:
                # print("Waiting training data")
                time.sleep(0.5)
            ## read all
            # temp = db.find_all_params(args={'type': 'train'})
            # if (temp is not False):
            #     epx = temp[0][0]
            #     for i in range(1, len(temp[0])):
            #         epx = np.append(epx, temp[i][0], axis=0)
            #     epy = temp[0][1]
            #     for i in range(1, len(temp[1])):
            #         epy = np.append(epy, temp[i][1], axis=0)
            #     epr = temp[0][2]
            #     for i in range(1, len(temp[2])):
            #         epr = np.append(epr, temp[i][2], axis=0)
            #     is_found = True
            #     break
        disR = tl.rein.discount_episode_rewards(epr, gamma)
        disR -= np.mean(disR)
        disR /= np.std(disR)
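        # Standardizing the discounted returns (zero mean, unit variance) is the
        # usual variance-reduction step for policy gradients: it keeps the scale
        # of the reward weighting comparable from one pulled batch to the next.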
        sess.run(train_op, {
            states_batch_pl: epx,
            actions_batch_pl: epy,
            discount_rewards_batch_pl: disR
        })
        n_examples = epx.shape[0]
        total_n_examples += n_examples
        print("[*] Update {}: n_examples: {} / total averaged speed: {} examples/second".format(
            n, n_examples, round(total_n_examples / (time.time() - start_time), 2)))
        n += 1
        if n % 10 == 0:
            db.del_params(args={'type': 'network_parameters'})
            db.save_params(sess.run(net.all_params), args={'type': 'network_parameters'}, lz4_comp=True)  # , file_name='network_parameters'
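            # The parameters published here can be picked up by the generator
            # side with the mirrored call, e.g. (a sketch, following the
            # commented-out `resume` logic near the top of the session):
            #
            #   params, f_id = db.find_one_params(args={'type': 'network_parameters'}, lz4_decomp=True)
            #   if params is not False:
            #       tl.files.assign_params(sess, params, net)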