valuerl_learner.py
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os

import numpy as np
import tensorflow as tf

from learner import Learner
from valuerl import ValueRL
from worldmodel import DeterministicWorldModel


class ValueRLLearner(Learner):
    """
    ValueRL-specific training loop details.
    """

    def learner_name(self): return "valuerl"

    def make_loader_placeholders(self):
        # Placeholders for one replay batch: observations, next observations,
        # actions, rewards, done flags, and the current dataset size.
        self.obs_loader = tf.placeholder(tf.float32, [self.learner_config["batch_size"], np.prod(self.env_config["obs_dims"])])
        self.next_obs_loader = tf.placeholder(tf.float32, [self.learner_config["batch_size"], np.prod(self.env_config["obs_dims"])])
        self.action_loader = tf.placeholder(tf.float32, [self.learner_config["batch_size"], self.env_config["action_dim"]])
        self.reward_loader = tf.placeholder(tf.float32, [self.learner_config["batch_size"]])
        self.done_loader = tf.placeholder(tf.float32, [self.learner_config["batch_size"]])
        self.datasize_loader = tf.placeholder(tf.float64, [])
        return [self.obs_loader, self.next_obs_loader, self.action_loader, self.reward_loader, self.done_loader, self.datasize_loader]
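
    # Illustrative only, not part of the original file: these placeholders would be
    # fed with a sampled replay batch of matching shapes, e.g. via a feed_dict along
    # the lines of
    #   sess.run(train_ops, feed_dict={
    #       self.obs_loader: obs_batch,            # [batch_size, prod(obs_dims)]
    #       self.next_obs_loader: next_obs_batch,  # [batch_size, prod(obs_dims)]
    #       self.action_loader: action_batch,      # [batch_size, action_dim]
    #       self.reward_loader: reward_batch,      # [batch_size]
    #       self.done_loader: done_batch,          # [batch_size]
    #       self.datasize_loader: replay_size,     # scalar
    #   })
    # where obs_batch, action_batch, etc. are hypothetical numpy arrays; in this
    # codebase the feeding is presumably handled by the Learner base class.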
    def make_core_model(self):
        # Optionally build a learned world model; setting model_config to False disables it.
        if self.config["model_config"] is not False:
            self.worldmodel = DeterministicWorldModel(self.config["name"], self.env_config, self.config["model_config"])
        else:
            self.worldmodel = None

        # Build the actor-critic training graph on the current batch.
        valuerl = ValueRL(self.config["name"], self.env_config, self.learner_config)
        (policy_loss, Q_loss), inspect_losses = valuerl.build_training_graph(*self.current_batch, worldmodel=self.worldmodel)

        # Separate Adam optimizers for the policy and the Q-function.
        policy_optimizer = tf.train.AdamOptimizer(3e-4)
        policy_gvs = policy_optimizer.compute_gradients(policy_loss, var_list=valuerl.policy_params)
        capped_policy_gvs = policy_gvs  # gradients are applied unclipped
        policy_train_op = policy_optimizer.apply_gradients(capped_policy_gvs)

        Q_optimizer = tf.train.AdamOptimizer(3e-4)
        Q_gvs = Q_optimizer.compute_gradients(Q_loss, var_list=valuerl.Q_params)
        capped_Q_gvs = Q_gvs  # gradients are applied unclipped
        Q_train_op = Q_optimizer.apply_gradients(capped_Q_gvs)

        return valuerl, (policy_loss, Q_loss), (policy_train_op, Q_train_op), inspect_losses
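
    # Sketch only, not part of the original training setup: if gradient clipping were
    # desired, the pass-through capped_*_gvs assignments above are the natural hook,
    # e.g. inside make_core_model:
    #   capped_policy_gvs = [(tf.clip_by_norm(g, 10.0), v)
    #                        for g, v in policy_gvs if g is not None]
    # The clip threshold 10.0 is an arbitrary illustrative value.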
    ## Optional functions to override
    def initialize(self):
        # When training against a world model, block until its first checkpoint exists.
        if self.config["model_config"] is not False:
            while not self.load_worldmodel(): pass

    def resume_from_checkpoint(self, epoch):
        if self.config["model_config"] is not False:
            with self.bonus_kwargs["model_lock"]: self.worldmodel.load(self.sess, self.save_path, epoch)

    def checkpoint(self):
        # Copy current network parameters to their "old" copies and refresh the world model.
        self.core.copy_to_old(self.sess)
        if self.config["model_config"] is not False:
            self.load_worldmodel()

    def backup(self): pass

    # Other functions
    def load_worldmodel(self):
        # Returns False if no world-model checkpoint has been written yet.
        if not os.path.exists("%s/%s.params.index" % (self.save_path, self.worldmodel.saveid)): return False
        with self.bonus_kwargs["model_lock"]: self.worldmodel.load(self.sess, self.save_path)
        return True
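
# Illustrative example (hypothetical values, not from this file): with
# save_path = "./checkpoints/run0" and a world model whose saveid is "worldmodel",
# load_worldmodel() polls for "./checkpoints/run0/worldmodel.params.index" and, once
# that index file appears, loads the checkpoint while holding
# bonus_kwargs["model_lock"], presumably to avoid racing the process that writes it.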