-
Notifications
You must be signed in to change notification settings - Fork 1
/
ReadWriteVectorGeneration.py
57 lines (37 loc) · 1.43 KB
/
ReadWriteVectorGeneration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np
import tensorflow as tf
def ReadVector(M_t, w_t):
    '''
    Produce the read vector as a weighted sum of the memory rows.

        r_t[j] = sum_i w_t[i] * M_t[i, j]

    M_t: Memory of size (N,M) at time t.
    w_t: (N,), Weighting generated by READ HEAD at time t; its entries must
        sum to 1 within an absolute tolerance of 0.01.
    RETURNS:
    r_t: The Read Vector, flattened to a 1-D tf.Tensor of length M.
    RAISES:
    AssertionError if the weighting does not sum to ~1 or the product has an
    unexpected shape (assert-based, so skipped under `python -O`).
    '''
    tolerance = 0.01
    total_weight = np.sum(w_t)
    assert 1.0 - tolerance <= total_weight <= 1.0 + tolerance

    num_locations = M_t.shape[0]
    weights_row = tf.reshape(w_t, (1, num_locations))  # (1, N)
    # NOTE(review): np.dot pulls concrete values out of its arguments, so this
    # presumably only works with eager tensors / NumPy arrays, not inside a
    # traced tf.function -- confirm with callers.
    product = np.dot(weights_row, M_t)  # (1, M)
    assert product.shape == (1, M_t.shape[1])
    return tf.reshape(product, (-1,))
def WriteOnMemory(M_prev, w_t, e_t, a_t, tol=0.01):
    '''
    Write to memory with an erase step followed by an add step.

        M_hat_t = M_prev * (1 - outer(w_t, e_t))   # erase
        M_t     = M_hat_t + outer(w_t, a_t)        # add

    M_prev: Memory Matrix at the previous time step of size (N,M)
    w_t: (N,), Weighting generated by WRITE HEAD at time t for Writing to the
        memory locations. Must sum to 1 within `tol`.
    e_t: (M,), Erase vector generated by WRITE HEAD. Every entry must lie in
        the closed interval [0, 1].
    a_t: (M,), Add vector generated by WRITE HEAD.
    tol: Allowed absolute deviation of sum(w_t) from 1.0. Defaults to 0.01,
        the same tolerance ReadVector applies to its weighting.
    RETURNS:
    M_t: New Memory Matrix of size (N,M) after Erasing/Adding new instances,
        returned as a tf.Tensor (the erase step goes through tf.multiply).
    RAISES:
    AssertionError if a shape or value check fails. These are assert-based
    sanity checks and are skipped under `python -O`.
    '''
    (N, M) = M_prev.shape
    # Validate shapes first so a mismatch is reported clearly instead of
    # surfacing as a confusing error from the value checks below.
    assert w_t.shape == (N,), "w_t must have shape (N,)"
    assert e_t.shape == (M,), "e_t must have shape (M,)"
    assert a_t.shape == (M,), "a_t must have shape (M,)"

    # Softmax-style weightings rarely sum to exactly 1.0 in floating point,
    # so compare against a tolerance rather than with `==`.
    assert abs(np.sum(w_t) - 1.0) <= tol, "w_t must sum to 1 (within tol)"

    # Range-check e_t directly on a NumPy view. Masking and then comparing the
    # shorter result with `==` raised a broadcasting error on bad input, and
    # for M == 1 compared an empty array and passed vacuously; boolean-mask
    # indexing also does not work on tf.Tensor. The interval is closed because
    # a saturated sigmoid can yield exactly 0.0 or 1.0 in float32.
    e_np = np.asarray(e_t)
    assert np.all((e_np >= 0.0) & (e_np <= 1.0)), "e_t entries must lie in [0, 1]"

    # Erase: cell (i, j) is scaled by 1 - w_t[i] * e_t[j].
    M_hat_t = tf.multiply(M_prev, 1 - np.outer(w_t, e_t))
    assert M_hat_t.shape == M_prev.shape

    # Add: cell (i, j) gains w_t[i] * a_t[j].
    M_t = M_hat_t + np.outer(w_t, a_t)
    assert M_t.shape == M_prev.shape
    return M_t