This repository has been archived by the owner on Sep 24, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 123
/
modules.py
352 lines (265 loc) · 10.6 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
class Retina:
    """A visual retina.

    Extracts a foveated glimpse `phi` around location `l`
    from an image `x`.

    Concretely, encodes the region around `l` at a
    high-resolution but uses a progressively lower
    resolution for pixels further from `l`, resulting
    in a compressed representation of the original
    image `x`.

    Args:
        g: size of the first square patch.
        k: number of patches to extract in the glimpse.
        s: scaling factor that controls the size of
            successive patches.
    """

    def __init__(self, g, k, s):
        self.g = g
        self.k = k
        self.s = s

    def foveate(self, x, l):
        """Extract `k` square patches centered at location `l`.

        The initial patch is a square of size `g`, and each
        subsequent patch is a square whose side is `s` times
        the size of the previous patch. Patches larger than
        `g` are average-pooled back down and everything is
        flattened into one feature vector per image.

        Args:
            x: a 4D Tensor of shape (B, C, H, W). The minibatch
                of images.
            l: a 2D Tensor of shape (B, 2). Contains normalized
                [x, y] coordinates in the range [-1, 1].

        Returns:
            phi: a 2D Tensor with one row per image. When every
                patch side is a multiple of `g` (e.g. integer `s`)
                the row length is k * C * g * g.
        """
        phi = []
        size = self.g

        # extract k patches of increasing size
        for _ in range(self.k):
            phi.append(self.extract_patch(x, l, size))
            size = int(self.s * size)

        # resize the larger patches down to squares of size g;
        # the first patch already has side g so it is skipped
        for i in range(1, len(phi)):
            k = phi[i].shape[-1] // self.g
            phi[i] = F.avg_pool2d(phi[i], k)

        # concatenate along the channel axis and flatten
        phi = torch.cat(phi, 1)
        phi = phi.view(phi.shape[0], -1)

        return phi

    def extract_patch(self, x, l, size):
        """Extract a single patch for each image in `x`.

        Args:
            x: a 4D Tensor of shape (B, C, H, W). The minibatch
                of images.
            l: a 2D Tensor of shape (B, 2). Normalized [x, y]
                coordinates; column 0 indexes the width axis and
                column 1 the height axis.
            size: a scalar defining the size of the extracted patch.

        Returns:
            patch: a 4D Tensor of shape (B, C, size, size).
        """
        _, _, H, W = x.shape

        # scale each coordinate by its own axis so that
        # non-square images are addressed correctly
        left = self.denormalize(W, l[:, 0]).tolist()
        top = self.denormalize(H, l[:, 1]).tolist()

        # pad with zeros; the trailing pad is `size - size // 2` so
        # that a patch starting at the far edge (l == 1) still fits
        # when `size` is odd. For even sizes this is the usual
        # symmetric `size // 2` padding.
        lead = size // 2
        trail = size - lead
        x = F.pad(x, (lead, trail, lead, trail))

        # slice one patch per image out of the padded batch
        patch = [
            img[:, r : r + size, c : c + size]
            for img, r, c in zip(x, top, left)
        ]
        return torch.stack(patch)

    def denormalize(self, T, coords):
        """Convert coordinates in the range [-1, 1] to
        coordinates in the range [0, T] where `T` is
        the size of the image.

        The result is truncated to an integer (long) tensor.
        """
        return (0.5 * ((coords + 1.0) * T)).long()

    def exceeds(self, from_x, to_x, from_y, to_y, T):
        """Check whether the extracted patch will exceed
        the boundaries of the image of size `T`.

        Returns:
            True if any edge of the patch falls outside [0, T],
            False otherwise.
        """
        return bool((from_x < 0) or (from_y < 0) or (to_x > T) or (to_y > T))
class GlimpseNetwork(nn.Module):
    """The glimpse network.

    Fuses the "what" and the "where" of a glimpse into a
    single feature vector `g_t`.

    - "what": the glimpse `phi` produced by the retina.
    - "where": the location at which that glimpse was taken.

    Each input goes through its own two-layer fc pathway
    (with a ReLU after the first layer); the two pathway
    outputs are added and rectified:

        `g_t = relu( fc( relu(fc(l)) ) + fc( relu(fc(phi)) ) )`

    Args:
        h_g: hidden layer size of the fc layer for `phi`.
        h_l: hidden layer size of the fc layer for `l`.
        g: size of the square patches in the glimpses extracted
            by the retina.
        k: number of patches to extract per glimpse.
        s: scaling factor that controls the size of successive patches.
        c: number of channels in each image.
        x: a 4D Tensor of shape (B, C, H, W). The minibatch
            of images.
        l_t_prev: a 2D tensor of shape (B, 2). Contains the glimpse
            coordinates [x, y] for the previous timestep `t-1`.

    Returns:
        g_t: a 2D tensor of shape (B, h_g + h_l). The glimpse
            representation for the current timestep `t`.
    """

    def __init__(self, h_g, h_l, g, k, s, c):
        super().__init__()

        self.retina = Retina(g, k, s)

        # both pathways project into the same fused width
        fused_size = h_g + h_l

        # first layer of the "what" pathway: flattened glimpse -> h_g
        self.fc1 = nn.Linear(k * g * g * c, h_g)

        # first layer of the "where" pathway: (x, y) -> h_l
        self.fc2 = nn.Linear(2, h_l)

        # second layers of the "what" and "where" pathways
        self.fc3 = nn.Linear(h_g, fused_size)
        self.fc4 = nn.Linear(h_l, fused_size)

    def forward(self, x, l_t_prev):
        # "what": foveated glimpse of x at the given location
        glimpse = self.retina.foveate(x, l_t_prev)

        # "where": location flattened to one row per image
        location = l_t_prev.view(l_t_prev.size(0), -1)

        what = self.fc3(F.relu(self.fc1(glimpse)))
        where = self.fc4(F.relu(self.fc2(location)))

        # fuse the two pathways
        return F.relu(what + where)
class CoreNetwork(nn.Module):
    """The core network.

    A recurrent cell that carries an internal state summarizing
    the history of past observations. The state vector `h_t`
    is refreshed at every time step `t`.

    At each step the glimpse representation `g_t` and the
    previous state `h_t_prev` are each passed through their own
    fc layer, summed, and rectified:

        `h_t = relu( fc(h_t_prev) + fc(g_t) )`

    Args:
        input_size: input size of the rnn.
        hidden_size: hidden size of the rnn.
        g_t: a 2D tensor of shape (B, input_size). The glimpse
            representation returned by the glimpse network for the
            current timestep `t`.
        h_t_prev: a 2D tensor of shape (B, hidden_size). The
            hidden state vector for the previous timestep `t-1`.

    Returns:
        h_t: a 2D tensor of shape (B, hidden_size). The hidden
            state vector for the current timestep `t`.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()

        self.input_size, self.hidden_size = input_size, hidden_size

        # input-to-hidden and hidden-to-hidden projections
        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)

    def forward(self, g_t, h_t_prev):
        pre_activation = self.i2h(g_t) + self.h2h(h_t_prev)
        return F.relu(pre_activation)
class ActionNetwork(nn.Module):
    """The action network.

    Maps the internal state `h_t` of the core network to the
    final classification output.

    The hidden state is passed through a single fc layer and a
    log-softmax over the class dimension, i.e. the network is a
    linear softmax classifier whose output is expressed as
    log-probabilities.

    Args:
        input_size: input size of the fc layer.
        output_size: output size of the fc layer.
        h_t: the hidden state vector of the core network
            for the current time step `t`.

    Returns:
        a_t: log-probability vector over the classes
            (log-softmax taken along dim 1).
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        self.fc = nn.Linear(input_size, output_size)

    def forward(self, h_t):
        logits = self.fc(h_t)
        return F.log_softmax(logits, dim=1)
class LocationNetwork(nn.Module):
    """The location network.

    Uses the internal state `h_t` of the core network to
    produce the location coordinates `l_t` for the next
    time step.

    The (detached) hidden state goes through a fc layer with
    a ReLU, then a second fc layer followed by a tanh that
    squashes the output between [-1, 1]. The result is the
    mean of a Gaussian with fixed standard deviation `std`,
    from which the location `l_t` is sampled.

    Hence, the location `l_t` is chosen stochastically
    from a distribution conditioned on a transformation
    of the hidden state vector `h_t`.

    Args:
        input_size: input size of the first fc layer.
        output_size: output size of the last fc layer.
        std: standard deviation of the normal distribution.
        h_t: the hidden state vector of the core network for
            the current time step `t`.

    Returns:
        log_pi: a 1D tensor of shape (B,). The log-probability
            of the sampled (unclamped) location, summed over
            dim 1.
        l_t: a 2D tensor of shape (B, output_size). The sampled
            location, clamped to [-1, 1] and detached from
            the graph.
    """

    def __init__(self, input_size, output_size, std):
        super().__init__()

        self.std = std

        # bottleneck to half the input width before the output layer
        self.fc = nn.Linear(input_size, input_size // 2)
        self.fc_lt = nn.Linear(input_size // 2, output_size)

    def forward(self, h_t):
        # mean of the location policy; h_t is detached so no
        # gradient reaches the core network through this path
        hidden = F.relu(self.fc(h_t.detach()))
        mu = torch.tanh(self.fc_lt(hidden))

        policy = Normal(mu, self.std)

        # draw a location and cut it out of the graph, so the
        # sample itself carries no gradient
        sampled = policy.rsample().detach()

        # the dimensions are treated as independent, so the joint
        # log-density is the sum of the per-dimension log-densities
        log_pi = policy.log_prob(sampled).sum(dim=1)

        # bound between [-1, 1]
        return log_pi, sampled.clamp(-1, 1)
class BaselineNetwork(nn.Module):
    """The baseline network.

    Regresses the baseline term of the reward function,
    which is used to reduce the variance of the gradient
    update.

    The hidden state is detached before the fc layer, so
    this head does not send gradients back into `h_t`.

    Args:
        input_size: input size of the fc layer.
        output_size: output size of the fc layer.
        h_t: the hidden state vector of the core network
            for the current time step `t`.

    Returns:
        b_t: a 2D tensor of shape (B, output_size). The
            baseline for the current time step `t`.
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        self.fc = nn.Linear(input_size, output_size)

    def forward(self, h_t):
        detached_state = h_t.detach()
        return self.fc(detached_state)