Commit 9a88599

Merge branch 'read-before-write' into develop

tlimbacher committed Oct 22, 2020
2 parents c2966d3 + 4849f29 commit 9a88599

Showing 6 changed files with 239 additions and 20 deletions.
139 changes: 128 additions & 11 deletions babi_task_single.py
@@ -9,14 +9,13 @@
 
 import numpy as np
 import tensorflow as tf
-from tensorflow.keras import Model
-from tensorflow.keras.layers import TimeDistributed
 
 from data.babi_data import download, load_task, tasks, vectorize_data
 from layers.encoding import Encoding
 from layers.extracting import Extracting
 from layers.reading import ReadingCell
 from layers.writing import WritingCell
+from tensorflow.keras import Model
+from tensorflow.keras.layers import TimeDistributed
 from utils.logger import MyCSVLogger
 
 strategy = tf.distribute.MirroredStrategy()
@@ -26,7 +25,7 @@
 parser.add_argument('--max_num_sentences', type=int, default=-1)
 parser.add_argument('--training_set_size', type=str, default='10k')
 
-parser.add_argument('--epochs', type=int, default=100)
+parser.add_argument('--epochs', type=int, default=250)
 parser.add_argument('--learning_rate', type=float, default=0.003)
 parser.add_argument('--batch_size_per_replica', type=int, default=128)
 parser.add_argument('--random_state', type=int, default=None)
@@ -44,6 +43,7 @@
 
 parser.add_argument('--verbose', type=int, default=1)
 parser.add_argument('--logging', type=int, default=0)
+parser.add_argument('--make_plots', type=int, default=0)
 args = parser.parse_args()
 
 batch_size = args.batch_size_per_replica * strategy.num_replicas_in_sync
@@ -168,12 +168,22 @@
                                                  w_assoc_max=args.w_assoc_max),
                                      name='entity_writing')(entities)
 
-queried_value = tf.keras.layers.RNN(ReadingCell(units=args.memory_size,
-                                                use_bias=False,
-                                                activation='relu',
-                                                kernel_initializer='he_uniform',
-                                                kernel_regularizer=tf.keras.regularizers.l2(1e-3)),
-                                    name='entity_reading')(query_encoded, constants=[memory_matrix])
+# queried_value = tf.keras.layers.RNN(ReadingCell(units=args.memory_size,
+#                                                 use_bias=False,
+#                                                 activation='relu',
+#                                                 kernel_initializer='he_uniform',
+#                                                 kernel_regularizer=tf.keras.regularizers.l2(1e-3)),
+#                                     name='entity_reading')(query_encoded, constants=[memory_matrix])
+
+k, queried_values = tf.keras.layers.RNN(ReadingCell(units=args.memory_size,
+                                                    use_bias=False,
+                                                    activation='relu',
+                                                    kernel_initializer='he_uniform',
+                                                    kernel_regularizer=tf.keras.regularizers.l2(1e-3)),
+                                        return_sequences=True, name='entity_reading')(
+    query_encoded, constants=[memory_matrix])
+
+queried_value = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(queried_values)
 
 outputs = tf.keras.layers.Dense(vocab_size,
                                 use_bias=False,
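The reading path now emits the per-hop keys together with the per-hop retrieved values (hence return_sequences=True), and the Lambda keeps only the final hop's value for the classifier. For reference, a minimal sketch of the mechanism this relies on: a tf.keras RNN cell whose per-step output is a [k, v] pair that the wrapping RNN stacks over time. The cell below is a hypothetical stand-in, not the repo's ReadingCell:

import tensorflow as tf

class TwoStreamCell(tf.keras.layers.Layer):
    """Hypothetical cell emitting a [key, value] pair per step."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = tf.TensorShape([units])
        # A nested output_size lets RNN(return_sequences=True) stack both streams.
        self.output_size = [tf.TensorShape([units]), tf.TensorShape([units])]

    def call(self, inputs, states):
        k = tf.tanh(inputs)   # stand-in for the learned key projection
        v = k + states[0]     # stand-in for the memory read-out
        return [k, v], states

ks, vs = tf.keras.layers.RNN(TwoStreamCell(8), return_sequences=True)(
    tf.zeros([2, 3, 8]))      # ks, vs: (batch=2, hops=3, units=8)
last_v = vs[:, -1, :]         # same selection as the Lambda above

With return_sequences=True both streams come back as (batch, hops, units) tensors, so slicing [:, -1, :] recovers exactly what the commented-out single-output version used to return.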
@@ -193,7 +203,13 @@
 
 # Train and evaluate.
 def lr_scheduler(epoch):
-    return args.learning_rate * 0.85**tf.math.floor(epoch / 20)
+    # return args.learning_rate
+    # return args.learning_rate * 0.75**tf.math.floor(epoch / 100)
+    if epoch < 150:
+        return args.learning_rate
+    else:
+        # return args.learning_rate * 0.85**tf.math.floor(epoch / 50)
+        return args.learning_rate * tf.math.exp(0.01 * (150 - epoch))
 
 
 callbacks = []
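The new schedule replaces the old staircase decay (factor 0.85 every 20 epochs) with a flat phase followed by smooth exponential decay: the base rate is held for the first 150 epochs, then shrinks by roughly 1% per epoch (exp(-0.01) is about 0.990). A quick standalone check, assuming the default learning_rate of 0.003:

import math

def lr(epoch, base=0.003):
    # Same rule as lr_scheduler above, without the TF dependency.
    return base if epoch < 150 else base * math.exp(0.01 * (150 - epoch))

for epoch in (0, 149, 150, 175, 249):
    print(epoch, round(lr(epoch), 5))
# 0 0.003, 149 0.003, 150 0.003, 175 ~0.00234, 249 ~0.00111

This fits the raised --epochs default of 250: the decay phase covers the last 100 epochs.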
@@ -211,3 +227,104 @@ def lr_scheduler(epoch):
         args.task_id, args.training_set_size, args.encodings_type, args.hops, args.random_state))))
 
 model.evaluate(x=x_test, y=y_test, callbacks=callbacks, verbose=2)
+
+if args.make_plots:
+
+    import utils.configure_seaborn as cs
+    import seaborn as sns
+    import matplotlib as mpl
+    mpl.use('Agg')
+    import matplotlib.pyplot as plt
+    from scipy import spatial
+
+    sns.set(context='paper', style='ticks', rc=cs.rc_params)
+
+    examples = range(10)  # [0, 1, 2, 3] # 2 # range(1)
+    for example in examples:
+        x = [testS[example][np.newaxis, :, :], testQ[example][np.newaxis, :, :]]
+
+        extracting_layer = Model(inputs=model.input, outputs=model.get_layer('entity_extracting').output)
+        entities = extracting_layer.predict(x)
+        keys_story, values_story = tf.split(entities, 2, axis=-1)
+
+        reading_layer = Model(inputs=model.input, outputs=model.get_layer('entity_reading').output)
+        keys_query, queried_values = reading_layer.predict(x)
+
+        print(test[example])
+
+        print(' '.join(test[example][0][0]))
+
+        cosine_sim_keys = np.zeros((x[0].shape[1], args.hops))
+        for i, key_story in enumerate(keys_story[0]):
+            for j, key_query in enumerate(keys_query[0]):
+                cosine_sim_keys[i, j] = 1 - spatial.distance.cosine(key_story, key_query)
+        print(cosine_sim_keys)
+
+        cosine_sim_values = np.zeros((x[0].shape[1], args.hops))
+        for i, value_story in enumerate(values_story[0]):
+            for j, queried_value in enumerate(queried_values[0]):
+                cosine_sim_values[i, j] = 1 - spatial.distance.cosine(value_story, queried_value)
+        print(cosine_sim_values)
+
+        fig, ax = plt.subplots(nrows=3, ncols=1, sharex=True)
+        # plt.suptitle('entities writing')
+        for i, t in enumerate(test[example][0]):
+            ax[0].text(0.5+i*1, 0.0, ' '.join(t), {'ha': 'center', 'va': 'bottom'},
+                       fontsize=7, rotation=90)
+        ax[0].set_frame_on(False)
+        ax[0].axes.get_yaxis().set_visible(False)
+        ax[0].axes.get_xaxis().set_visible(False)
+
+        vmin_k = np.minimum(np.min(keys_story[0]), np.min(keys_query[0]))
+        vmax_k = np.maximum(np.max(keys_story[0]), np.max(keys_query[0]))
+        print(vmin_k, vmax_k)
+
+        vmin_v = np.minimum(np.min(values_story[0]), np.min(queried_values[0]))
+        vmax_v = np.maximum(np.max(values_story[0]), np.max(queried_values[0]))
+        print(vmin_v, vmax_v)
+
+        vmin = np.minimum(vmin_k, vmin_v)
+        vmax = np.maximum(vmax_k, vmax_v)
+
+        ax[1].pcolormesh(tf.transpose(keys_story[0]), cmap='Blues')  # , vmin=vmin_k, vmax=vmax_k)
+        ax[2].pcolormesh(tf.transpose(values_story[0]), cmap='Oranges')  # , vmin=vmin_v, vmax=vmax_v)
+        ax[1].set_ylabel('keys story')
+        ax[2].set_ylabel('values story')
+        ax[2].set_xlim([0, 10])
+        fig.savefig('entities-writing-{0}.pdf'.format(example), dpi=fig.dpi)
+
+        fig, ax = plt.subplots(nrows=3, ncols=1, sharex=True)
+        # plt.suptitle('entities reading')
+        for i in range(args.hops):
+            ax[0].text(0.5+i*1, 0.0, ' '.join(test[example][1]), {'ha': 'center', 'va': 'bottom'},
+                       fontsize=7, rotation=90)
+        ax[0].set_frame_on(False)
+        ax[0].axes.get_yaxis().set_visible(False)
+        ax[0].axes.get_xaxis().set_visible(False)
+
+        ax[1].pcolormesh(tf.transpose(keys_query[0]), cmap='Blues')  # , vmin=vmin_k, vmax=vmax_k)
+        ax[2].pcolormesh(tf.transpose(queried_values[0]), cmap='Oranges')  # , vmin=vmin_v, vmax=vmax_v)
+        ax[1].set_ylabel('key query')
+        ax[2].set_ylabel('queried value')
+        fig.savefig('entities-reading-{0}.pdf'.format(example), dpi=fig.dpi)
+
+        # my_cmap = sns.color_palette("crest", as_cmap=True)
+        my_cmap = sns.color_palette("ch:start=.2,rot=-.3", as_cmap=True)
+
+        fig, ax = plt.subplots(nrows=1, ncols=1)
+        plt.suptitle('cosine sim keys')
+        cax = ax.matshow(cosine_sim_keys[:10, :], cmap=my_cmap)  # , vmin=-1, vmax=1)
+        fig.colorbar(cax)
+        ax.set_ylabel('keys story')
+        ax.set_xlabel('keys query')
+        fig.savefig('cosine-sim-keys-{0}.pdf'.format(example), dpi=fig.dpi)
+
+        fig, ax = plt.subplots(nrows=1, ncols=1)
+        plt.suptitle('cosine sim values')
+        cax = ax.matshow(cosine_sim_values[:10, :], cmap=my_cmap)  # , vmin=-1, vmax=1)
+        fig.colorbar(cax)
+        ax.set_ylabel('values story')
+        ax.set_xlabel('queried values')
+        fig.savefig('cosine-sim-values-{0}.pdf'.format(example), dpi=fig.dpi)
+
+        # plt.show()
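Aside on the plotting block above: the nested cosine-similarity loops can be collapsed with scipy's cdist. A sketch under the same shape assumptions (keys_story[0] is (num_sentences, d), keys_query[0] is (hops, d)):

import numpy as np
from scipy.spatial.distance import cdist

# Cosine similarity is 1 minus the cosine distance; rows index story
# sentences, columns index query hops, exactly as in the loop version.
# np.asarray covers the case where tf.split returned tensors, not arrays.
cosine_sim_keys = 1 - cdist(np.asarray(keys_story[0]), np.asarray(keys_query[0]), metric='cosine')
cosine_sim_values = 1 - cdist(np.asarray(values_story[0]), np.asarray(queried_values[0]), metric='cosine')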
1 change: 1 addition & 0 deletions environment.yml
@@ -5,4 +5,5 @@ channels:
 
 dependencies:
   - python=3.7
+  - matplotlib=3.1
   - tensorflow-gpu=2.1
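Note that the new plotting path in babi_task_single.py also imports seaborn and scipy, which this commit does not add to the environment. If they are not already pulled in some other way, the dependency list would additionally need something like the following (an assumption on my part, not part of this commit):

  - seaborn
  - scipy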
44 changes: 37 additions & 7 deletions image_association_task.py
@@ -7,6 +7,7 @@
 import numpy as np
 import tensorflow as tf
 from tensorflow.keras.layers import TimeDistributed
+from tensorflow.keras import Model
 
 from data.image_association_data import load_data
 from layers.extracting import Extracting
@@ -38,6 +39,7 @@
 parser.add_argument('--w_assoc_max', type=float, default=1.0)
 
 parser.add_argument('--verbose', type=int, default=1)
+parser.add_argument('--make_plots', type=int, default=0)
 args = parser.parse_args()
 
 batch_size = args.batch_size_per_replica * strategy.num_replicas_in_sync
@@ -145,19 +147,19 @@ def dataset_generator(x, y, seed):
                                 learn_gamma_neg=False),
                      name='entity_writing')(entities)
 
-queried_value = Reading(units=args.memory_size,
-                        use_bias=False,
-                        activation='relu',
-                        kernel_initializer='he_uniform',
-                        kernel_regularizer=tf.keras.regularizers.l2(1e-3),
-                        name='entity_reading')(features_b, constants=[memory_matrix])
+_, queried_value = Reading(units=args.memory_size,
+                           use_bias=False,
+                           activation='relu',
+                           kernel_initializer='he_uniform',
+                           kernel_regularizer=tf.keras.regularizers.l2(1e-3),
+                           name='entity_reading')(features_b, constants=[memory_matrix])
 
 outputs = tf.keras.layers.Dense(10,
                                 use_bias=False,
                                 kernel_initializer='he_uniform',
                                 name='output')(queried_value)
 
-model = tf.keras.Model(inputs=[input_a, input_b], outputs=outputs)
+model = Model(inputs=[input_a, input_b], outputs=outputs)
 
 # Compile the model.
 optimizer_kwargs = {'clipnorm': args.max_grad_norm} if args.max_grad_norm else {}
@@ -186,3 +188,31 @@ def lr_scheduler(epoch):
                     verbose=args.verbose)
 
 model.evaluate(test_dataset, steps=np.ceil(num_test/batch_size), verbose=2)
+
+if args.make_plots:
+
+    import matplotlib as mpl
+    mpl.use('Agg')
+    import matplotlib.pyplot as plt
+    from scipy import spatial
+
+    x = test_dataset.take(1)
+
+    extracting_layer = Model(inputs=model.input, outputs=model.get_layer('entity_extracting').output)
+    entities = extracting_layer.predict(x)[0]
+    keys_story, values_story = tf.split(entities, 2, axis=-1)
+
+    reading_layer = Model(inputs=model.input, outputs=model.get_layer('entity_reading').output)
+    key_query, queried_value = reading_layer.predict(x)
+    key_query = key_query[0]
+    queried_value = queried_value[0]
+
+    cosine_sim_keys = np.zeros((args.timesteps, 1))
+    for i, key_story in enumerate(keys_story):
+        cosine_sim_keys[i, 0] = 1 - spatial.distance.cosine(key_story, key_query)
+    print(cosine_sim_keys)
+
+    cosine_sim_values = np.zeros((args.timesteps, 1))
+    for i, value_story in enumerate(values_story):
+        cosine_sim_values[i, 0] = 1 - spatial.distance.cosine(value_story, queried_value)
+    print(cosine_sim_values)
4 changes: 2 additions & 2 deletions layers/reading.py
@@ -38,7 +38,7 @@ def call(self, inputs, constants):
 
         v = K.batch_dot(k, memory_matrix)
 
-        return v
+        return k, v
 
     def compute_mask(self, inputs, mask=None):
         return mask
@@ -82,7 +82,7 @@ def call(self, inputs, states, constants):
 
         v = K.batch_dot(k, memory_matrix)
 
-        return v, v
+        return [k, v], v
 
     def compute_mask(self, inputs, mask=None):
        return mask
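Both reading layers now expose the query key k alongside the retrieved value v = k·M, which is what lets image_association_task.py unpack "_, queried_value" and the bAbI plotting code unpack keys_query and queried_values from the entity_reading layer. For the cell variant, a two-element per-step output generally needs to be reflected in the cell's output_size so that RNN(..., return_sequences=True) can stack both streams over hops. A sketch of what that property could look like (an assumption; the hunk above does not show this part of the class):

@property
def output_size(self):
    # One (units,)-shaped output per stream: the key and the retrieved value.
    return [tf.TensorShape([self.units]), tf.TensorShape([self.units])]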
21 changes: 21 additions & 0 deletions layers/writing.py
@@ -2,6 +2,7 @@
 
 import tensorflow as tf
 from tensorflow.keras.layers import Layer
+import tensorflow.keras.backend as K
 
 
 class Writing(Layer):
@@ -59,6 +60,14 @@ def __init__(self,
         self.learn_gamma_pos = learn_gamma_pos
         self.learn_gamma_neg = learn_gamma_neg
 
+        self.dense = tf.keras.layers.Dense(units=self.units,
+                                           use_bias=False,
+                                           kernel_initializer='he_uniform',
+                                           kernel_regularizer=tf.keras.regularizers.l2(1e-3))
+
+        self.ln1 = tf.keras.layers.LayerNormalization()
+        self.ln2 = tf.keras.layers.LayerNormalization()
+
     @property
     def state_size(self):
         return tf.TensorShape((self.units, self.units))
@@ -77,10 +86,22 @@ def call(self, inputs, states, mask=None):
         memory_matrix = states[0]
         k, v = tf.split(inputs, 2, axis=-1)
 
+        k = self.ln1(k)  # TODO layer norm for v here?
+        # v = self.ln2(v)
+
+        v_h = K.batch_dot(k, memory_matrix)
+
+        v = self.dense(tf.concat([v, v_h], axis=1))
+        v = self.ln2(v)
+        # k = self.ln1(k)  # TODO layer norm for v here?
+
         k = tf.expand_dims(k, 2)
         v = tf.expand_dims(v, 1)
 
         hebb = self.gamma_pos * (self.w_max - memory_matrix) * k * v - self.gamma_neg * memory_matrix * k**2
+        # hebb = self.gamma_pos * (self.w_max - memory_matrix) * k * v - self.gamma_neg * memory_matrix * k
+        # hebb = self.gamma_pos * (self.w_max - memory_matrix) * k * v
 
         memory_matrix = hebb + memory_matrix
 
         return memory_matrix, memory_matrix
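This hunk is the core of the 'read-before-write' branch: before writing, the cell layer-normalizes the key, reads the value v_h that this key currently retrieves from the memory matrix, mixes it with the incoming value through a dense layer plus layer normalization, and only then applies the Hebbian update. A minimal single-sample NumPy sketch of one write step (names, default rates, and the dense stand-in are assumptions for illustration, not the repo's API):

import numpy as np

def read_before_write(M, k, v, dense, gamma_pos=0.01, gamma_neg=0.01, w_max=1.0):
    """One write step: M is (units, units); k and v are (units,) vectors."""
    v_h = k @ M                                   # read what k currently retrieves
    v = dense(np.concatenate([v, v_h]))           # mix new and retrieved values
    hebb = (gamma_pos * (w_max - M) * np.outer(k, v)
            - gamma_neg * M * (k ** 2)[:, None])  # decay scales row i by k[i]**2
    return M + hebb

# Example with a fixed random matrix standing in for the trained Dense layer:
rng = np.random.default_rng(0)
units = 4
W = rng.standard_normal((2 * units, units))
M = np.zeros((units, units))
M = read_before_write(M, rng.standard_normal(units), rng.standard_normal(units),
                      dense=lambda x: x @ W)

The (k ** 2)[:, None] term mirrors the broadcasting in the TensorFlow code, where k is expanded along axis 2 and v along axis 1: the potentiation term is the outer product k vᵀ, and the depression term decays each row of the memory in proportion to the squared key entry.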
50 changes: 50 additions & 0 deletions utils/configure_seaborn.py
@@ -0,0 +1,50 @@
+"""Configure seaborn for plotting."""
+import matplotlib
+
+PAD_INCHES = 0.00
+
+rc_params = {
+    # FONT
+    'font.size': 9,
+    'font.sans-serif': 'Myriad Pro',
+    'font.stretch': 'condensed',
+    'font.weight': 'normal',
+    'mathtext.fontset': 'custom',
+    'mathtext.fallback_to_cm': True,
+    'mathtext.rm': 'Minion Pro',
+    'mathtext.it': 'Minion Pro:italic',
+    'mathtext.bf': 'Minion Pro:bold:italic',
+    'mathtext.cal': 'Minion Pro:italic',  # TODO find calligraphy font
+    'mathtext.tt': 'monospace',
+    # AXES
+    'axes.linewidth': 0.5,
+    'axes.spines.top': False,
+    'axes.spines.right': False,
+    'axes.labelsize': 9,
+    # TICKS
+    'xtick.major.size': 3.5,
+    'xtick.major.width': 0.8,
+    'ytick.major.size': 3.5,
+    'xtick.labelsize': 8,
+    'ytick.labelsize': 8,
+    'ytick.major.width': 0.8,
+    # LEGEND
+    'legend.fontsize': 9,
+    # SAVING FIGURES
+    'savefig.bbox': 'tight',
+    'savefig.pad_inches': PAD_INCHES,
+    'pdf.fonttype': 42,
+    'savefig.dpi': 300}
+
+
+def mm2inch(*tupl, pad):
+    inch = 25.4
+    if isinstance(tupl[0], tuple):
+        return tuple(i/inch - pad for i in tupl[0])
+    else:
+        return tuple(i/inch - pad for i in tupl)
+
+
+def set_figure_size(*size):
+    params = {'figure.figsize': mm2inch(size, pad=PAD_INCHES)}
+    matplotlib.rcParams.update(params)

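Usage example for utils/configure_seaborn.py, mirroring how babi_task_single.py applies these settings (the 90 by 60 mm figure size is only an illustration):

import seaborn as sns
import utils.configure_seaborn as cs

sns.set(context='paper', style='ticks', rc=cs.rc_params)
cs.set_figure_size(90, 60)  # width and height in millimetres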