Commit 78631b8: Cleanup

tlimbacher committed Oct 22, 2020
1 parent 9a88599 commit 78631b8
Showing 7 changed files with 54 additions and 187 deletions.

babi_task_single.py (153 changes: 23 additions & 130 deletions)

@@ -23,9 +23,9 @@
parser = argparse.ArgumentParser()
parser.add_argument('--task_id', type=int, default=1)
parser.add_argument('--max_num_sentences', type=int, default=-1)
parser.add_argument('--training_set_size', type=str, default='10k')
parser.add_argument('--training_set_size', type=str, default='10k', help='`1k` or `10k`')

parser.add_argument('--epochs', type=int, default=250)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--learning_rate', type=float, default=0.003)
parser.add_argument('--batch_size_per_replica', type=int, default=128)
parser.add_argument('--random_state', type=int, default=None)
@@ -35,15 +35,15 @@
parser.add_argument('--hops', type=int, default=3)
parser.add_argument('--memory_size', type=int, default=100)
parser.add_argument('--embeddings_size', type=int, default=80)
parser.add_argument('--read_before_write', type=int, default=0)
parser.add_argument('--gamma_pos', type=float, default=0.01)
parser.add_argument('--gamma_neg', type=float, default=0.01)
parser.add_argument('--w_assoc_max', type=float, default=1.0)
parser.add_argument('--encodings_type', type=str, default='learned_encoding')
parser.add_argument('--encodings_constraint', type=str, default='mask_time_word')
parser.add_argument('--encodings_type', type=str, default='learned_encoding',
help='`identity_encoding`, `position_encoding` or `learned_encoding`')

parser.add_argument('--verbose', type=int, default=1)
parser.add_argument('--logging', type=int, default=0)
parser.add_argument('--make_plots', type=int, default=0)
args = parser.parse_args()

batch_size = args.batch_size_per_replica * strategy.num_replicas_in_sync
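
A side note on the line above: the effective batch size scales with the replica count reported by the distribution strategy. A hypothetical illustration:

# With the default --batch_size_per_replica 128 under a 4-GPU MirroredStrategy
# (strategy.num_replicas_in_sync == 4), each optimizer step sees:
batch_size = 128 * 4  # 512 samples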
@@ -148,7 +148,7 @@
story_embedded = TimeDistributed(embedding, name='story_embedding')(story_input)
query_embedded = TimeDistributed(embedding, name='query_embedding')(query_input)

encoding = Encoding(args.encodings_type, args.encodings_constraint, name='encoding')
encoding = Encoding(args.encodings_type, name='encoding')
story_encoded = TimeDistributed(encoding, name='story_encoding')(story_embedded)
query_encoded = TimeDistributed(encoding, name='query_encoding')(query_embedded)
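
A note on `position_encoding` (an assumption, not stated in this diff): it presumably refers to the positional weighting from end-to-end memory networks (Sukhbaatar et al., 2015), where word j of J is scaled per embedding dimension k of d before the sentence is summed into a single vector. A minimal numpy sketch under that assumption:

import numpy as np

def position_encoding(J, d):
    # l[j, k] = (1 - j/J) - (k/d) * (1 - 2*j/J), with 1-based positions j and dims k
    j = np.arange(1, J + 1)[:, None] / J    # word positions, shape (J, 1)
    k = np.arange(1, d + 1)[None, :] / d    # embedding dimensions, shape (1, d)
    return (1.0 - j) - k * (1.0 - 2.0 * j)  # shape (J, d)

emb = np.random.rand(6, 80)  # toy sentence: 6 words, 80-dim embeddings (the embeddings_size default)
sentence_vec = (emb * position_encoding(6, 80)).sum(axis=0)  # position-weighted bag of words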

@@ -163,27 +163,21 @@
name='entity_extracting')(story_encoded)

memory_matrix = tf.keras.layers.RNN(WritingCell(units=args.memory_size,
read_before_write=args.read_before_write,
use_bias=False,
gamma_pos=args.gamma_pos,
gamma_neg=args.gamma_neg,
w_assoc_max=args.w_assoc_max),
w_assoc_max=args.w_assoc_max,
kernel_initializer='he_uniform',
kernel_regularizer=tf.keras.regularizers.l2(1e-3)),
name='entity_writing')(entities)

# queried_value = tf.keras.layers.RNN(ReadingCell(units=args.memory_size,
# use_bias=False,
# activation='relu',
# kernel_initializer='he_uniform',
# kernel_regularizer=tf.keras.regularizers.l2(1e-3)),
# name='entity_reading')(query_encoded, constants=[memory_matrix])

k, queried_values = tf.keras.layers.RNN(ReadingCell(units=args.memory_size,
use_bias=False,
activation='relu',
kernel_initializer='he_uniform',
kernel_regularizer=tf.keras.regularizers.l2(1e-3)),
return_sequences=True, name='entity_reading')(
query_encoded, constants=[memory_matrix])

queried_value = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(queried_values)
queried_value = tf.keras.layers.RNN(ReadingCell(units=args.memory_size,
use_bias=False,
activation='relu',
kernel_initializer='he_uniform',
kernel_regularizer=tf.keras.regularizers.l2(1e-3)),
name='entity_reading')(query_encoded, constants=[memory_matrix])
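
The replaced block above (together with the matching change in layers/reading.py) drops the unused key output and the `return_sequences=True` plus `Lambda` slice of the last timestep; an RNN layer without `return_sequences` already returns only the final step's output. A toy sketch of that equivalence, using a stock cell rather than the repo's `ReadingCell`:

import tensorflow as tf

x = tf.random.normal((2, 5, 8))          # (batch, timesteps, features)
cell = tf.keras.layers.SimpleRNNCell(4)  # weights live on the cell, so both RNNs below share them

seq = tf.keras.layers.RNN(cell, return_sequences=True)(x)  # (2, 5, 4): outputs at every step
last = tf.keras.layers.Lambda(lambda t: t[:, -1, :])(seq)  # old pattern: slice the last step
final = tf.keras.layers.RNN(cell)(x)                       # new pattern: final output directly
# tf.reduce_max(tf.abs(last - final)) == 0 for this deterministic shared cell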

outputs = tf.keras.layers.Dense(vocab_size,
use_bias=False,
@@ -203,13 +197,13 @@

# Train and evaluate.
def lr_scheduler(epoch):
# return args.learning_rate
# return args.learning_rate * 0.75**tf.math.floor(epoch / 100)
if epoch < 150:
return args.learning_rate
if args.read_before_write:
if epoch < 150:
return args.learning_rate
else:
return args.learning_rate * tf.math.exp(0.01 * (150 - epoch))
else:
# return args.learning_rate * 0.85**tf.math.floor(epoch / 50)
return args.learning_rate * tf.math.exp(0.01 * (150 - epoch))
return args.learning_rate * 0.85**tf.math.floor(epoch / 20)
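
For reference, the rewritten scheduler keeps the constant-then-exponential decay only for `--read_before_write 1` and otherwise steps the rate down by a factor of 0.85 every 20 epochs. A quick hypothetical check with the default `--learning_rate 0.003`:

import math

lr0 = 0.003
step = lambda e: lr0 * 0.85 ** math.floor(e / 20)                      # read_before_write == 0
expo = lambda e: lr0 if e < 150 else lr0 * math.exp(0.01 * (150 - e))  # read_before_write == 1

print(step(0), step(40), step(100))  # 0.003, ~0.00217, ~0.00133
print(expo(100), expo(200))          # 0.003, ~0.00182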


callbacks = []
@@ -227,104 +221,3 @@ def lr_scheduler(epoch):
args.task_id, args.training_set_size, args.encodings_type, args.hops, args.random_state))))

model.evaluate(x=x_test, y=y_test, callbacks=callbacks, verbose=2)

if args.make_plots:

import utils.configure_seaborn as cs
import seaborn as sns
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from scipy import spatial

sns.set(context='paper', style='ticks', rc=cs.rc_params)

examples = range(10) # [0, 1, 2, 3] # 2 # range(1)
for example in examples:
x = [testS[example][np.newaxis, :, :], testQ[example][np.newaxis, :, :]]

extracting_layer = Model(inputs=model.input, outputs=model.get_layer('entity_extracting').output)
entities = extracting_layer.predict(x)
keys_story, values_story = tf.split(entities, 2, axis=-1)

reading_layer = Model(inputs=model.input, outputs=model.get_layer('entity_reading').output)
keys_query, queried_values = reading_layer.predict(x)

print(test[example])

print(' '.join(test[example][0][0]))

cosine_sim_keys = np.zeros((x[0].shape[1], args.hops))
for i, key_story in enumerate(keys_story[0]):
for j, key_query in enumerate(keys_query[0]):
cosine_sim_keys[i, j] = 1 - spatial.distance.cosine(key_story, key_query)
print(cosine_sim_keys)

cosine_sim_values = np.zeros((x[0].shape[1], args.hops))
for i, value_story in enumerate(values_story[0]):
for j, queried_value in enumerate(queried_values[0]):
cosine_sim_values[i, j] = 1 - spatial.distance.cosine(value_story, queried_value)
print(cosine_sim_values)

fig, ax = plt.subplots(nrows=3, ncols=1, sharex=True)
# plt.suptitle('entities writing')
for i, t in enumerate(test[example][0]):
ax[0].text(0.5+i*1, 0.0, ' '.join(t), {'ha': 'center', 'va': 'bottom'},
fontsize=7, rotation=90)
ax[0].set_frame_on(False)
ax[0].axes.get_yaxis().set_visible(False)
ax[0].axes.get_xaxis().set_visible(False)

vmin_k = np.minimum(np.min(keys_story[0]), np.min(keys_query[0]))
vmax_k = np.maximum(np.max(keys_story[0]), np.max(keys_query[0]))
print(vmin_k, vmax_k)

vmin_v = np.minimum(np.min(values_story[0]), np.min(queried_values[0]))
vmax_v = np.maximum(np.max(values_story[0]), np.max(queried_values[0]))
print(vmin_v, vmax_v)

vmin = np.minimum(vmin_k, vmin_v)
vmax = np.maximum(vmax_k, vmax_v)

ax[1].pcolormesh(tf.transpose(keys_story[0]), cmap='Blues') # , vmin=vmin_k, vmax=vmax_k)
ax[2].pcolormesh(tf.transpose(values_story[0]), cmap='Oranges') # , vmin=vmin_v, vmax=vmax_v)
ax[1].set_ylabel('keys story')
ax[2].set_ylabel('values story')
ax[2].set_xlim([0, 10])
fig.savefig('entities-writing-{0}.pdf'.format(example), dpi=fig.dpi)

fig, ax = plt.subplots(nrows=3, ncols=1, sharex=True)
# plt.suptitle('entities reading')
for i in range(args.hops):
ax[0].text(0.5+i*1, 0.0, ' '.join(test[example][1]), {'ha': 'center', 'va': 'bottom'},
fontsize=7, rotation=90)
ax[0].set_frame_on(False)
ax[0].axes.get_yaxis().set_visible(False)
ax[0].axes.get_xaxis().set_visible(False)

ax[1].pcolormesh(tf.transpose(keys_query[0]), cmap='Blues') # , vmin=vmin_k, vmax=vmax_k)
ax[2].pcolormesh(tf.transpose(queried_values[0]), cmap='Oranges') # , vmin=vmin_v, vmax=vmax_v)
ax[1].set_ylabel('key query')
ax[2].set_ylabel('queried value')
fig.savefig('entities-reading-{0}.pdf'.format(example), dpi=fig.dpi)

# my_cmap = sns.color_palette("crest", as_cmap=True)
my_cmap = sns.color_palette("ch:start=.2,rot=-.3", as_cmap=True)

fig, ax = plt.subplots(nrows=1, ncols=1)
plt.suptitle('cosine sim keys')
cax = ax.matshow(cosine_sim_keys[:10, :], cmap=my_cmap) # , vmin=-1, vmax=1)
fig.colorbar(cax)
ax.set_ylabel('keys story')
ax.set_xlabel('keys query')
fig.savefig('cosine-sim-keys-{0}.pdf'.format(example), dpi=fig.dpi)

fig, ax = plt.subplots(nrows=1, ncols=1)
plt.suptitle('cosine sim values')
cax = ax.matshow(cosine_sim_values[:10, :], cmap=my_cmap) # , vmin=-1, vmax=1)
fig.colorbar(cax)
ax.set_ylabel('values story')
ax.set_xlabel('queried values')
fig.savefig('cosine-sim-values-{0}.pdf'.format(example), dpi=fig.dpi)

# plt.show()
environment.yml (1 change: 0 additions & 1 deletion)

@@ -5,5 +5,4 @@ channels:

dependencies:
- python=3.7
- matplotlib=3.1
- tensorflow-gpu=2.1

image_association_task.py (43 changes: 7 additions & 36 deletions)

@@ -20,7 +20,7 @@
parser = argparse.ArgumentParser()
parser.add_argument('--delay', type=int, default=0)
parser.add_argument('--timesteps', type=int, default=3)
parser.add_argument('--delay_padding', type=str, default='random')
parser.add_argument('--delay_padding', type=str, default='random', help='`zeros` or `random`')

parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--learning_rate', type=float, default=0.001)
@@ -39,7 +39,6 @@
parser.add_argument('--w_assoc_max', type=float, default=1.0)

parser.add_argument('--verbose', type=int, default=1)
parser.add_argument('--make_plots', type=int, default=0)
args = parser.parse_args()

batch_size = args.batch_size_per_replica * strategy.num_replicas_in_sync
@@ -147,12 +146,12 @@ def dataset_generator(x, y, seed):
learn_gamma_neg=False),
name='entity_writing')(entities)

_, queried_value = Reading(units=args.memory_size,
use_bias=False,
activation='relu',
kernel_initializer='he_uniform',
kernel_regularizer=tf.keras.regularizers.l2(1e-3),
name='entity_reading')(features_b, constants=[memory_matrix])
queried_value = Reading(units=args.memory_size,
use_bias=False,
activation='relu',
kernel_initializer='he_uniform',
kernel_regularizer=tf.keras.regularizers.l2(1e-3),
name='entity_reading')(features_b, constants=[memory_matrix])

outputs = tf.keras.layers.Dense(10,
use_bias=False,
@@ -188,31 +187,3 @@ def lr_scheduler(epoch):
verbose=args.verbose)

model.evaluate(test_dataset, steps=np.ceil(num_test/batch_size), verbose=2)

if args.make_plots:

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from scipy import spatial

x = test_dataset.take(1)

extracting_layer = Model(inputs=model.input, outputs=model.get_layer('entity_extracting').output)
entities = extracting_layer.predict(x)[0]
keys_story, values_story = tf.split(entities, 2, axis=-1)

reading_layer = Model(inputs=model.input, outputs=model.get_layer('entity_reading').output)
key_query, queried_value = reading_layer.predict(x)
key_query = key_query[0]
queried_value = queried_value[0]

cosine_sim_keys = np.zeros((args.timesteps, 1))
for i, key_story in enumerate(keys_story):
cosine_sim_keys[i, 0] = 1 - spatial.distance.cosine(key_story, key_query)
print(cosine_sim_keys)

cosine_sim_values = np.zeros((args.timesteps, 1))
for i, value_story in enumerate(values_story):
cosine_sim_values[i, 0] = 1 - spatial.distance.cosine(value_story, queried_value)
print(cosine_sim_values)

image_association_task_lstm.py (2 changes: 1 addition & 1 deletion)

@@ -15,7 +15,7 @@
parser = argparse.ArgumentParser()
parser.add_argument('--delay', type=int, default=0)
parser.add_argument('--timesteps', type=int, default=3)
parser.add_argument('--delay_padding', type=str, default='random')
parser.add_argument('--delay_padding', type=str, default='random', help='`zeros` or `random`')

parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--batch_size_per_replica', type=int, default=32)

layers/encoding.py (3 changes: 1 addition & 2 deletions)

@@ -8,11 +8,10 @@


class Encoding(Layer):
"""TODO"""

def __init__(self,
encodings_type,
encodings_constraint,
encodings_constraint='mask_time_word',
**kwargs):
super().__init__(**kwargs)


layers/reading.py (4 changes: 2 additions & 2 deletions)

@@ -38,7 +38,7 @@ def call(self, inputs, constants):

v = K.batch_dot(k, memory_matrix)

return k, v
return v

def compute_mask(self, inputs, mask=None):
return mask
@@ -82,7 +82,7 @@ def call(self, inputs, states, constants):

v = K.batch_dot(k, memory_matrix)

return [k, v], v
return v, v

def compute_mask(self, inputs, mask=None):
return mask
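
Both reading paths (the standalone `Reading` layer and the recurrent `ReadingCell`) retrieve the value by multiplying the key into the memory matrix, v = k·M. A toy numpy sketch of what `K.batch_dot(k, memory_matrix)` computes here, with shapes assumed from `memory_size`:

import numpy as np

batch, units = 2, 4
k = np.random.rand(batch, units)         # one query key per batch element
M = np.random.rand(batch, units, units)  # per-example associative memory matrix

v = np.einsum('bi,bij->bj', k, M)        # v[b, j] = sum_i k[b, i] * M[b, i, j]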

layers/writing.py (35 changes: 20 additions & 15 deletions)

@@ -48,6 +48,10 @@ def __init__(self,
gamma_pos,
gamma_neg,
w_assoc_max,
use_bias=False,
read_before_write=False,
kernel_initializer=None,
kernel_regularizer=None,
learn_gamma_pos=False,
learn_gamma_neg=False,
**kwargs):
@@ -57,16 +61,21 @@
self.w_max = w_assoc_max
self._gamma_pos = gamma_pos
self._gamma_neg = gamma_neg
self.use_bias = use_bias
self.read_before_write = read_before_write
self.kernel_initializer = kernel_initializer
self.kernel_regularizer = kernel_regularizer
self.learn_gamma_pos = learn_gamma_pos
self.learn_gamma_neg = learn_gamma_neg

self.dense = tf.keras.layers.Dense(units=self.units,
use_bias=False,
kernel_initializer='he_uniform',
kernel_regularizer=tf.keras.regularizers.l2(1e-3))
if self.read_before_write:
self.dense = tf.keras.layers.Dense(units=self.units,
use_bias=self.use_bias,
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.kernel_regularizer)

self.ln1 = tf.keras.layers.LayerNormalization()
self.ln2 = tf.keras.layers.LayerNormalization()
self.ln1 = tf.keras.layers.LayerNormalization()
self.ln2 = tf.keras.layers.LayerNormalization()

@property
def state_size(self):
@@ -86,21 +95,17 @@ def call(self, inputs, states, mask=None):
memory_matrix = states[0]
k, v = tf.split(inputs, 2, axis=-1)

k = self.ln1(k) # TODO layer norm for v here?
# v = self.ln2(v)
if self.read_before_write:
k = self.ln1(k)
v_h = K.batch_dot(k, memory_matrix)

v_h = K.batch_dot(k, memory_matrix)

v = self.dense(tf.concat([v, v_h], axis=1))
v = self.ln2(v)
# k = self.ln1(k) # TODO layer norm for v here?
v = self.dense(tf.concat([v, v_h], axis=1))
v = self.ln2(v)

k = tf.expand_dims(k, 2)
v = tf.expand_dims(v, 1)

hebb = self.gamma_pos * (self.w_max - memory_matrix) * k * v - self.gamma_neg * memory_matrix * k**2
# hebb = self.gamma_pos * (self.w_max - memory_matrix) * k * v - self.gamma_neg * memory_matrix * k
# hebb = self.gamma_pos * (self.w_max - memory_matrix) * k * v

memory_matrix = hebb + memory_matrix
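
The write rule itself is unchanged by this commit: each step adds a soft-bounded Hebbian association between key and value, ΔM = γ⁺ (w_max − M) ∘ (k vᵀ) − γ⁻ M ∘ k², where k vᵀ is the outer product produced by the two expand_dims calls; the new read_before_write flag only controls whether the value is first mixed with a readout k·M through the dense layer. A toy numpy sketch of one write step (single example, hypothetical values):

import numpy as np

units = 4
gamma_pos, gamma_neg, w_max = 0.01, 0.01, 1.0  # the argparse defaults above
M = np.zeros((units, units))  # associative memory for one example
k = np.random.rand(units, 1)  # key as a column vector
v = np.random.rand(1, units)  # value as a row vector

hebb = gamma_pos * (w_max - M) * (k * v) - gamma_neg * M * k**2
M = M + hebb  # associations saturate at w_max; rows with active keys also decay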
