policy_network.py
import os.path

import numpy as np
import tensorflow as tf

# Size of the flattened observation vector fed to the network.
OBSERVATIONS_SIZE = 6400


class Network:
    def __init__(self, hidden_layer_size, learning_rate, checkpoints_dir):
        self.learning_rate = learning_rate

        self.sess = tf.InteractiveSession()

        self.observations = tf.placeholder(tf.float32,
                                           [None, OBSERVATIONS_SIZE])
        # +1 for up, -1 for down
        self.sampled_actions = tf.placeholder(tf.float32, [None, 1])
        self.advantage = tf.placeholder(
            tf.float32, [None, 1], name='advantage')

        h = tf.layers.dense(
            self.observations,
            units=hidden_layer_size,
            activation=tf.nn.relu,
            kernel_initializer=tf.contrib.layers.xavier_initializer())

        self.up_probability = tf.layers.dense(
            h,
            units=1,
            activation=tf.sigmoid,
            kernel_initializer=tf.contrib.layers.xavier_initializer())
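
        # Architecture summary: p(up | x) = sigmoid(w2 . relu(W1 x + b1) + b2),
        # where (W1, b1) and (w2, b2) are the weights and biases of the two
        # dense layers defined above.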

        # Train based on the log probability of the sampled action.
        #
        # The idea is to encourage actions taken in rounds where the agent
        # won, and discourage actions in rounds where the agent lost.
        # More specifically, we want to increase the log probability of
        # winning actions, and decrease the log probability of losing actions.
        #
        # Which direction to push the log probability in is controlled by
        # 'advantage', which is the reward for each action in each round.
        # Positive reward pushes the log probability of the chosen action up;
        # negative reward pushes the log probability of the chosen action down.
        self.loss = tf.losses.log_loss(
            labels=self.sampled_actions,
            predictions=self.up_probability,
            weights=self.advantage)
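
        # Concretely, tf.losses.log_loss with these arguments computes, per
        # example,
        #   -advantage * [a * log(p_up) + (1 - a) * log(1 - p_up)]
        # (assuming the sampled action is fed in as a = 1 for up, 0 for down),
        # so minimizing it performs gradient ascent on
        #   advantage * log(probability assigned to the sampled action),
        # i.e. the standard REINFORCE policy-gradient update.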

        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = optimizer.minimize(self.loss)

        tf.global_variables_initializer().run()

        self.saver = tf.train.Saver()
        self.checkpoint_file = os.path.join(checkpoints_dir,
                                            'policy_network.ckpt')

    def load_checkpoint(self):
        print("Loading checkpoint...")
        self.saver.restore(self.sess, self.checkpoint_file)

    def save_checkpoint(self):
        print("Saving checkpoint...")
        self.saver.save(self.sess, self.checkpoint_file)

    def forward_pass(self, observations):
        up_probability = self.sess.run(
            self.up_probability,
            feed_dict={self.observations: observations.reshape([1, -1])})
        return up_probability

    def train(self, state_action_reward_tuples):
        print("Training with %d (state, action, reward) tuples" %
              len(state_action_reward_tuples))

        states, actions, rewards = zip(*state_action_reward_tuples)
        states = np.vstack(states)
        actions = np.vstack(actions)
        rewards = np.vstack(rewards)

        feed_dict = {
            self.observations: states,
            self.sampled_actions: actions,
            self.advantage: rewards
        }
        self.sess.run(self.train_op, feed_dict)
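

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline):
# it builds the network with made-up hyperparameters, runs a forward pass on a
# blank observation, and takes one training step on two fabricated
# (state, action, reward) tuples. Real training would feed preprocessed game
# frames and discounted rewards instead.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    network = Network(hidden_layer_size=200,
                      learning_rate=1e-3,
                      checkpoints_dir='checkpoints')

    # Probability of moving up for an all-zero observation.
    blank_observation = np.zeros(OBSERVATIONS_SIZE, dtype=np.float32)
    print("P(up):", network.forward_pass(blank_observation))

    # One gradient step: reward the 'up' sample, penalize the 'down' sample.
    dummy_tuples = [
        (blank_observation, 1, +1.0),
        (blank_observation, 0, -1.0),
    ]
    network.train(dummy_tuples)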