import inspect
import chainer
from chainer import functions as F
from chainer import links as L
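
# Predictron model (cf. Silver et al., 2017, "The Predictron: End-to-End
# Learning and Planning"), written against the Chainer v1-era API
# (keyword-argument chain construction and explicit `test` flags).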


class Sequence(chainer.ChainList):
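    """Applies layers in order, forwarding `test` only to callables that accept it."""
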
def __init__(self, *layers):
self.layers = layers
links = [layer for layer in layers if isinstance(layer, chainer.Link)]
super().__init__(*links)

    def __call__(self, x, test):
h = x
for layer in self.layers:
            argnames = inspect.getfullargspec(layer).args
if 'test' in argnames:
h = layer(h, test=test)
else:
h = layer(h)
return h


class PredictronCore(chainer.Chain):
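    """One step of the abstract model: state -> (next state, reward, gamma, lambda)."""
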
def __init__(self, n_tasks, n_channels):
super().__init__(
state2hidden=Sequence(
L.Convolution2D(n_channels, n_channels, ksize=3, pad=1),
L.BatchNormalization(n_channels),
F.relu,
),
hidden2nextstate=Sequence(
L.Convolution2D(n_channels, n_channels, ksize=3, pad=1),
L.BatchNormalization(n_channels),
F.relu,
L.Convolution2D(n_channels, n_channels, ksize=3, pad=1),
L.BatchNormalization(n_channels),
F.relu,
),
hidden2reward=Sequence(
L.Linear(None, n_channels),
L.BatchNormalization(n_channels),
F.relu,
L.Linear(n_channels, n_tasks),
),
hidden2gamma=Sequence(
L.Linear(None, n_channels),
L.BatchNormalization(n_channels),
F.relu,
L.Linear(n_channels, n_tasks),
F.sigmoid,
),
hidden2lambda=Sequence(
L.Linear(None, n_channels),
L.BatchNormalization(n_channels),
F.relu,
L.Linear(n_channels, n_tasks),
F.sigmoid,
),
)

    def __call__(self, x, test):
hidden = self.state2hidden(x, test=test)
        # No skip connection: the next state is computed from the hidden
        # activation without adding the current state
nextstate = self.hidden2nextstate(hidden, test=test)
reward = self.hidden2reward(hidden, test=test)
gamma = self.hidden2gamma(hidden, test=test)
# lambda doesn't backprop errors to states
lmbda = self.hidden2lambda(
chainer.Variable(hidden.data), test=test)
return nextstate, reward, gamma, lmbda


class Predictron(chainer.Chain):
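    """Predictron: unrolls a learned model to form k-step preturns g^k and g^lambda."""
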
def __init__(self, n_tasks, n_channels, model_steps,
use_reward_gamma=True, use_lambda=True, usage_weighting=True):
self.model_steps = model_steps
self.use_reward_gamma = use_reward_gamma
self.use_lambda = use_lambda
self.usage_weighting = usage_weighting
super().__init__(
obs2state=Sequence(
L.Convolution2D(None, n_channels, ksize=3, pad=1),
L.BatchNormalization(n_channels),
F.relu,
L.Convolution2D(n_channels, n_channels, ksize=3, pad=1),
L.BatchNormalization(n_channels),
F.relu,
),
core=PredictronCore(n_tasks=n_tasks, n_channels=n_channels),
state2value=Sequence(
L.Linear(None, n_channels),
L.BatchNormalization(n_channels),
F.relu,
L.Linear(n_channels, n_tasks),
),
)

    def unroll(self, x, test):
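        """Unroll the core for `model_steps` steps; return g^0..g^K, g^lambda and weights w^k."""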
# Compute g^k and lambda^k for k=0,...,K
g_k = []
lambda_k = []
state = self.obs2state(x, test=test)
g_k.append(self.state2value(state, test=test)) # g^0 = v^0
reward_sum = 0
gamma_prod = 1
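
        # Each k-step preturn accumulates rewards discounted by the running
        # product of predicted gammas:
        #   g^{k+1} = r^1 + gamma^1 r^2 + ... + (gamma^1 ... gamma^k) r^{k+1}
        #             + (gamma^1 ... gamma^{k+1}) v^{k+1}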
for k in range(self.model_steps):
state, reward, gamma, lmbda = self.core(state, test=test)
if not self.use_reward_gamma:
reward = 0
gamma = 1
if not self.use_lambda:
lmbda = 1
lambda_k.append(lmbda) # lambda^k
v = self.state2value(state, test=test)
reward_sum += gamma_prod * reward
gamma_prod *= gamma
g_k.append(reward_sum + gamma_prod * v) # g^{k+1}
lambda_k.append(0) # lambda^K = 0

        # Compute g^lambda
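        # g^lambda = sum_k w^k g^k with w^k = (1 - lambda^k) * prod_{j<k} lambda^j;
        # because lambda^K = 0, the weights w^0..w^K sum to one.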
lambda_prod = 1
g_lambda = 0
w_k = []
for k in range(self.model_steps + 1):
w = (1 - lambda_k[k]) * lambda_prod
w_k.append(w)
lambda_prod *= lambda_k[k]
# g^lambda doesn't backprop errors to g^k
g_lambda += w * chainer.Variable(g_k[k].data)
return g_k, g_lambda, w_k

    def supervised_loss(self, x, t):
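        """Return (loss of the preturns g^k, loss of g^lambda) against targets t."""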
g_k, g_lambda, w_k = self.unroll(x, test=False)
if self.usage_weighting:
g_k_loss = sum(F.sum(w * (g - t) ** 2) / x.shape[0]
for g, w in zip(g_k, w_k))
else:
g_k_loss = sum(F.mean_squared_error(g, t) for g in g_k) / len(g_k)
g_lambda_loss = F.mean_squared_error(g_lambda, t)
return g_k_loss, g_lambda_loss

    def unsupervised_loss(self, x):
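        """Consistency loss: each g^k is regressed toward g^lambda, which is held fixed."""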
g_k, g_lambda, w_k = self.unroll(x, test=False)
        # Only update g^k: cut the graph so g^lambda acts as a fixed target
        g_lambda.creator = None
if self.usage_weighting:
return sum(F.sum(w * (g - g_lambda) ** 2) / x.shape[0]
for g, w in zip(g_k, w_k))
else:
return sum(F.mean_squared_error(g, g_lambda) for g in g_k) / len(g_k)
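

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The shapes, the Adam
# optimizer and the random data below are assumptions for illustration only;
# it targets the same pre-v2 Chainer API (`test` flags) as the code above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np

    model = Predictron(n_tasks=1, n_channels=32, model_steps=4)
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Random batch of 8 single-channel 20x20 observations and per-task targets.
    x = np.random.rand(8, 1, 20, 20).astype(np.float32)
    t = np.random.rand(8, 1).astype(np.float32)

    # Supervised preturn losses plus the self-consistency (unsupervised) loss.
    g_k_loss, g_lambda_loss = model.supervised_loss(x, t)
    consistency_loss = model.unsupervised_loss(x)
    loss = g_k_loss + g_lambda_loss + consistency_loss

    model.cleargrads()
    loss.backward()
    optimizer.update()
    print('loss:', float(loss.data))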