From 699b31688576916845cd05293117a97b629eadbb Mon Sep 17 00:00:00 2001 From: Thomas Limbacher Date: Tue, 14 Jul 2020 14:03:24 +0200 Subject: [PATCH] Add layers and script for training on bAbI tasks --- babi_task_single.py | 213 ++++++++++++++++++++++++++++++++++++ data/babi_data.py | 235 ++++++++++++++++++++++++++++++++++++++++ layers/encoding.py | 70 ++++++++++++ layers/extracting.py | 52 +++++++++ layers/reading.py | 93 ++++++++++++++++ layers/writing.py | 94 ++++++++++++++++ results/.gitkeep | 0 tests/.gitkeep | 0 utils/logger.py | 32 ++++++ utils/word_encodings.py | 31 ++++++ 10 files changed, 820 insertions(+) create mode 100644 babi_task_single.py create mode 100644 data/babi_data.py create mode 100644 layers/encoding.py create mode 100644 layers/extracting.py create mode 100644 layers/reading.py create mode 100644 layers/writing.py create mode 100644 results/.gitkeep create mode 100644 tests/.gitkeep create mode 100644 utils/logger.py create mode 100644 utils/word_encodings.py diff --git a/babi_task_single.py b/babi_task_single.py new file mode 100644 index 0000000..a16fd6e --- /dev/null +++ b/babi_task_single.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python +"""Runs H-Mem on a single bAbI task.""" + +import argparse +import os +import random +from functools import reduce +from itertools import chain + +import numpy as np +import tensorflow as tf +from tensorflow.keras import Model +from tensorflow.keras.layers import TimeDistributed + +from data.babi_data import download, load_task, tasks, vectorize_data +from layers.encoding import Encoding +from layers.extracting import Extracting +from layers.reading import ReadingCell +from layers.writing import WritingCell +from utils.logger import MyCSVLogger + +strategy = tf.distribute.MirroredStrategy() + +parser = argparse.ArgumentParser() +parser.add_argument('--task_id', type=int, default=1) +parser.add_argument('--max_num_sentences', type=int, default=-1) +parser.add_argument('--training_set_size', type=str, default='10k') + +parser.add_argument('--epochs', type=int, default=100) +parser.add_argument('--learning_rate', type=float, default=0.003) +parser.add_argument('--batch_size_per_replica', type=int, default=128) +parser.add_argument('--random_state', type=int, default=None) +parser.add_argument('--max_grad_norm', type=float, default=20.0) +parser.add_argument('--validation_split', type=float, default=0.1) + +parser.add_argument('--hops', type=int, default=3) +parser.add_argument('--memory_size', type=int, default=100) +parser.add_argument('--embeddings_size', type=int, default=80) +parser.add_argument('--gamma_pos', type=float, default=0.01) +parser.add_argument('--gamma_neg', type=float, default=0.01) +parser.add_argument('--w_assoc_max', type=float, default=1.0) +parser.add_argument('--encodings_type', type=str, default='learned_encoding') +parser.add_argument('--encodings_constraint', type=str, default='mask_time_word') + +parser.add_argument('--verbose', type=int, default=1) +parser.add_argument('--logging', type=int, default=0) +args = parser.parse_args() + +batch_size = args.batch_size_per_replica * strategy.num_replicas_in_sync + +# Set random seeds. +np.random.seed(args.random_state) +random.seed(args.random_state) +tf.random.set_seed(args.random_state) + +if args.logging: + logdir = 'results/' + + if not os.path.exists(logdir): + os.makedirs(logdir) + +# Download bAbI data set. +data_dir = download() + +if args.verbose: + print('Extracting stories for the challenge: {0}, {1}'.format(args.task_id, tasks[args.task_id])) + +# Load the data. +train, test = load_task(data_dir, args.task_id, args.training_set_size) +data = train + test + +vocab = sorted(reduce(lambda x, y: x | y, (set(list(chain.from_iterable(s)) + q + a) for s, q, a in data))) +word_idx = dict((c, i + 1) for i, c in enumerate(vocab)) + +max_story_size = max(map(len, (s for s, _, _ in data))) + +max_num_sentences = max_story_size if args.max_num_sentences == -1 else min(args.max_num_sentences, + max_story_size) + +out_size = len(word_idx) + 1 # +1 for nil word. + +# Add time words/indexes +for i in range(max_num_sentences): + word_idx['time{}'.format(i+1)] = 'time{}'.format(i+1) + +vocab_size = len(word_idx) + 1 # +1 for nil word. +mean_story_size = int(np.mean([len(s) for s, _, _ in data])) +max_sentence_size = max(map(len, chain.from_iterable(s for s, _, _ in data))) + 1 # +1 for time word. +max_query_size = max(map(len, (q for _, q, _ in data))) + +if args.verbose: + print('-') + print('Vocab size:', vocab_size, 'unique words (including "nil" word and "time" words)') + print('Story max length:', max_story_size, 'sentences') + print('Story mean length:', mean_story_size, 'sentences') + print('Story max length:', max_sentence_size, 'words (including "time" word)') + print('Query max length:', max_query_size, 'words') + print('-') + print('Here\'s what a "story" tuple looks like (story, query, answer):') + print(data[0]) + print('-') + print('Vectorizing the stories...') + +# Vectorize the data. +max_words = max(max_sentence_size, max_query_size) +trainS, trainQ, trainA = vectorize_data(train, word_idx, max_num_sentences, max_words, max_words) +testS, testQ, testA = vectorize_data(test, word_idx, max_num_sentences, max_words, max_words) + +trainQ = np.repeat(np.expand_dims(trainQ, axis=1), args.hops, axis=1) +testQ = np.repeat(np.expand_dims(testQ, axis=1), args.hops, axis=1) + +story_shape = trainS.shape[1:] +query_shape = trainQ.shape[1:] + +x_train = [trainS, trainQ] +y_train = np.argmax(trainA, axis=1) + +x_test = [testS, testQ] +y_test = np.argmax(testA, axis=1) + +if args.verbose: + print('-') + print('Stories: integer tensor of shape (samples, max_length, max_words): {0}'.format(trainS.shape)) + print('Here\'s what a vectorized story looks like (sentence, word):') + print(trainS[0]) + print('-') + print('Queries: integer tensor of shape (samples, length): {0}'.format(trainQ.shape)) + print('Here\'s what a vectorized query looks like:') + print(trainQ[0]) + print('-') + print('Answers: binary tensor of shape (samples, vocab_size): {0}'.format(trainA.shape)) + print('Here\'s what a vectorized answer looks like:') + print(trainA[0]) + print('-') + print('Training...') + +with strategy.scope(): + # Build the model. + story_input = tf.keras.layers.Input(story_shape, name='story_input') + query_input = tf.keras.layers.Input(query_shape, name='query_input') + + embedding = tf.keras.layers.Embedding(input_dim=vocab_size, + output_dim=args.embeddings_size, + embeddings_initializer='he_uniform', + embeddings_regularizer=None, + mask_zero=True, + name='embedding') + story_embedded = TimeDistributed(embedding, name='story_embedding')(story_input) + query_embedded = TimeDistributed(embedding, name='query_embedding')(query_input) + + encoding = Encoding(args.encodings_type, args.encodings_constraint, name='encoding') + story_encoded = TimeDistributed(encoding, name='story_encoding')(story_embedded) + query_encoded = TimeDistributed(encoding, name='query_encoding')(query_embedded) + + story_encoded = tf.keras.layers.BatchNormalization(name='batch_norm_story')(story_encoded) + query_encoded = tf.keras.layers.BatchNormalization(name='batch_norm_query')(query_encoded) + + entities = Extracting(units=args.memory_size, + use_bias=False, + activation='relu', + kernel_initializer='he_uniform', + kernel_regularizer=tf.keras.regularizers.l1_l2(l2=1e-3), + name='entity_extracting')(story_encoded) + + memory_matrix = tf.keras.layers.RNN(WritingCell(units=args.memory_size, + gamma_pos=args.gamma_pos, + gamma_neg=args.gamma_neg, + w_assoc_max=args.w_assoc_max), + name='entity_writing')(entities) + + queried_value = tf.keras.layers.RNN(ReadingCell(units=args.memory_size, + use_bias=False, + activation='relu', + kernel_initializer='he_uniform', + kernel_regularizer=tf.keras.regularizers.l1_l2(l2=1e-3)), + name='entity_reading')(query_encoded, constants=[memory_matrix]) + + outputs = tf.keras.layers.Dense(vocab_size, + use_bias=False, + kernel_initializer='he_uniform', + name='output')(queried_value) + + model = Model(inputs=[story_input, query_input], outputs=outputs) + + # Compile the model. + optimizer_kwargs = {'clipnorm': args.max_grad_norm} if args.max_grad_norm else {} + model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=args.learning_rate, **optimizer_kwargs), + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=['accuracy']) + +model.summary() + + +# Train and evaluate. +def lr_scheduler(epoch): + return args.learning_rate * 0.85**tf.math.floor(epoch / 20) + + +callbacks = [] +callbacks.append(tf.keras.callbacks.LearningRateScheduler(lr_scheduler, verbose=0)) +if args.logging: + callbacks.append(tf.keras.callbacks.CSVLogger(os.path.join(logdir, '{0}_{1}_{2}_{3}-{4}.log'.format( + args.task_id, args.training_set_size, args.encodings_type, args.hops, args.random_state)))) + +model.fit(x=x_train, y=y_train, epochs=args.epochs, validation_split=args.validation_split, + batch_size=batch_size, callbacks=callbacks, verbose=args.verbose) + +callbacks = [] +if args.logging: + callbacks.append(MyCSVLogger(os.path.join(logdir, '{0}_{1}_{2}_{3}-{4}.log'.format( + args.task_id, args.training_set_size, args.encodings_type, args.hops, args.random_state)))) + +model.evaluate(x=x_test, y=y_test, callbacks=callbacks, verbose=2) diff --git a/data/babi_data.py b/data/babi_data.py new file mode 100644 index 0000000..29f4435 --- /dev/null +++ b/data/babi_data.py @@ -0,0 +1,235 @@ +"""Utilities for downloading and parsing bAbI task data. + +Modified from https://github.com/domluna/memn2n/blob/master/data_utils.py. + +""" + +import os +import re +import shutil +import urllib.request + +import numpy as np + +tasks = { + 1: 'single_supporting_fact', + 2: 'two_supporting_facts', + 3: 'three_supporting_facts', + 4: 'two_arg_relations', + 5: 'three_arg_relations', + 6: 'yes_no_questions', + 7: 'counting', + 8: 'lists_sets', + 9: 'simple_negation', + 10: 'indefinite_knowledge', + 11: 'basic_coreference', + 12: 'conjunction', + 13: 'compound_coreference', + 14: 'time_reasoning', + 15: 'basic_deduction', + 16: 'basic_induction', + 17: 'positional_reasoning', + 18: 'size_reasoning', + 19: 'path_finding', + 20: 'agents_motivations' +} + + +def download(extract=True): + """Downloads the data set. + + Arguments: + extract: boolean, whether to extract the downloaded archive (default=`True`). + + Returns: + data_dir: string, the data directory. + + """ + url = 'https://s3.amazonaws.com/text-datasets/' + file_name = 'babi_tasks_1-20_v1-2.tar.gz' + data_dir = 'data/' + file_path = data_dir + file_name + + if not os.path.exists(file_path): + print('Downloading ' + url + file_name + '...') + print('-') + with urllib.request.urlopen(url + file_name) as response, open(file_path, 'wb') as out_file: + shutil.copyfileobj(response, out_file) + shutil.unpack_archive(file_path, data_dir) + shutil.move(data_dir + 'tasks_1-20_v1-2', data_dir + 'babi_tasks_1-20_v1-2') + + return data_dir + 'babi_tasks_1-20_v1-2' + + +def load_task(data_dir, task_id, training_set_size='1k', only_supporting=False): + """Loads the nth task. There are 20 tasks in total. + + Arguments: + data_dir: string, the data directory. + task_id: int, the ID of the task (valid values are in `range(1, 21)`). + training_set_size: string, the size of the training set to load (`1k` or `10k`, default=`1k`). + only_supporting: boolean, if `True` only supporting facts are loaded (default=`False`). + + Returns: + A Python tuple containing the training and testing data for the task. + + """ + assert task_id > 0 and task_id < 21 + + data_dir = data_dir + '/en/' if training_set_size == '1k' else data_dir + '/en-10k/' + files = os.listdir(data_dir) + files = [os.path.join(data_dir, f) for f in files] + s = 'qa{}_'.format(task_id) + train_file = [f for f in files if s in f and 'train' in f][0] + test_file = [f for f in files if s in f and 'test' in f][0] + train_data = _get_stories(train_file, only_supporting) + test_data = _get_stories(test_file, only_supporting) + + return train_data, test_data + + +def vectorize_data(data, word_idx, max_num_sentences, sentence_size, query_size): + """Vectorize stories, queries and answers. + + If a sentence length < `sentence_size`, the sentence will be padded with `0`s. If a story length < + `max_num_sentences`, the story will be padded with empty sentences. Empty sentences are 1-D arrays of + length `sentence_size` filled with `0`s. The answer array is returned as a one-hot encoding. + + Arguments: + data: iterable, containing stories, queries and answers. + word_idx: dict, mapping words to unique integers. + max_num_sentences: int, the maximum number of sentences to extract. + sentence_size: int, the maximum number of words in a sentence. + query_size: int, the maximum number of words in a query. + + Returns: + A Python tuple containing vectorized stories, queries, and answers. + + """ + S = [] + Q = [] + A = [] + for story, query, answer in data: + if len(story) > max_num_sentences: + continue + + ss = [] + for i, sentence in enumerate(story, 1): + # Pad to sentence_size, i.e., add nil words, and add story. + ls = max(0, sentence_size - len(sentence)) + ss.append([word_idx[w] for w in sentence] + [0] * ls) + + # Make the last word of each sentence the time 'word' which corresponds to vector of lookup table. + for i in range(len(ss)): + ss[i][-1] = len(word_idx) - max_num_sentences - i + len(ss) + + # Pad stories to max_num_sentences (i.e., add empty stories). + ls = max(0, max_num_sentences - len(ss)) + for _ in range(ls): + ss.append([0] * sentence_size) + + # Pad queries to query_size (i.e., add nil words). + lq = max(0, query_size - len(query)) + q = [word_idx[w] for w in query] + [0] * lq + + y = np.zeros(len(word_idx) + 1) # 0 is reserved for nil word. + for a in answer: + y[word_idx[a]] = 1 + + S.append(ss) + Q.append(q) + A.append(y) + + return np.array(S), np.array(Q), np.array(A) + + +def _get_stories(f, only_supporting=False): + """Given a file name, read the file, retrieve the stories, and then convert the sentences into a single + story. + + If only_supporting is true, only the sentences that support the answer are kept. + + Arguments: + f: string, the file name. + only_supporting: boolean, if `True` only supporting facts are loaded (default=`False`). + + Returns: + A list of Python tuples containing stories, queries, and answers. + + """ + with open(f) as f: + data = _parse_stories(f.readlines(), only_supporting=only_supporting) + + return data + + +def _parse_stories(lines, only_supporting=False): + """Parse stories provided in the bAbI tasks format. + + If only_supporting is true, only the sentences that support the answer are kept. + + Arguments: + lines: iterable, containing the sentences of a full story (story, query, and answer). + only_supporting: boolean, if `True` only supporting facts are loaded (default=`False`). + + Returns: + A Python list containing the parsed stories. + + """ + data = [] + story = [] + for line in lines: + line = str.lower(line) + nid, line = line.split(' ', 1) + nid = int(nid) + if nid == 1: + story = [] + if '\t' in line: # Question + q, a, supporting = line.split('\t') + q = _tokenize(q) + a = [a] # Answer is one vocab word even ie it's actually multiple words. + substory = None + + # Remove question marks + if q[-1] == '?': + q = q[:-1] + + if only_supporting: + # Only select the related substory. + supporting = map(int, supporting.split()) + substory = [story[i - 1] for i in supporting] + else: + # Provide all the substories. + substory = [x for x in story if x] + + data.append((substory, q, a)) + story.append('') + else: # Regular sentence + sent = _tokenize(line) + # Remove periods + if sent[-1] == '.': + sent = sent[:-1] + story.append(sent) + + return data + + +def _tokenize(sent): + """Return the tokens of a sentence including punctuation. + + Arguments: + sent: iterable, containing the sentence. + + Returns: + A Python list containing the tokens in the sentence. + + Examples: + + ```python + tokenize('Bob dropped the apple. Where is the apple?') + + ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?'] + ``` + + """ + return [x.strip() for x in re.split(r'(\W+)+?', sent) if x.strip()] diff --git a/layers/encoding.py b/layers/encoding.py new file mode 100644 index 0000000..e4bf76e --- /dev/null +++ b/layers/encoding.py @@ -0,0 +1,70 @@ +"""Sentence encoding.""" + +import tensorflow as tf +from tensorflow.keras.constraints import Constraint +from tensorflow.keras.layers import Layer + +from utils.word_encodings import position_encoding + + +class Encoding(Layer): + """TODO""" + + def __init__(self, + encodings_type, + encodings_constraint, + **kwargs): + super().__init__(**kwargs) + + self.encodings_type = encodings_type.lower() + self.encodings_constraint = encodings_constraint.lower() + + if self.encodings_type not in ('identity_encoding', 'position_encoding', 'learned_encoding'): + raise ValueError('Could not interpret encodings type:', self.encodings_type) + + if self.encodings_constraint not in ('none', 'mask_time_word'): + raise ValueError('Could not interpret encodings constraint:', self.encodings_type) + + self.constraint = self.MaskTimeWord() if self.encodings_constraint == 'mask_time_word' else None + + def build(self, input_shape): + if self.encodings_type.lower() == 'identity_encoding': + self.encoding = tf.ones((input_shape[-2], input_shape[-1])) + if self.encodings_type.lower() == 'position_encoding': + self.encoding = position_encoding(input_shape[-2], input_shape[-1]) + if self.encodings_type.lower() == 'learned_encoding': + self.encoding = self.add_weight(shape=(input_shape[-2], input_shape[-1]), trainable=True, + initializer=tf.initializers.Ones(), + constraint=self.constraint, + dtype=self.dtype, name='encoding') + + super().build(input_shape) + + def call(self, inputs, mask=None): + mask = tf.cast(mask, dtype=self.dtype) + mask = tf.expand_dims(mask, axis=-1) + + return tf.reduce_sum(mask * self.encoding * inputs, axis=-2) + + def compute_mask(self, inputs, mask=None): + if mask is None: + return None + + return tf.reduce_any(mask, axis=-1) + + def compute_output_shape(self, input_shape): + return (input_shape[0], input_shape[-1]) + + class MaskTimeWord(Constraint): + """Make encoding of time words identity to avoid modifying them.""" + + def __init__(self, + **kwargs): + super().__init__(**kwargs) + + def __call__(self, w): + indices = [[w.shape[0]-1]] + updates = tf.ones((1, w.shape[1])) + new_w = tf.tensor_scatter_nd_update(w, indices, updates) + + return new_w diff --git a/layers/extracting.py b/layers/extracting.py new file mode 100644 index 0000000..b88aecd --- /dev/null +++ b/layers/extracting.py @@ -0,0 +1,52 @@ +"""Extracting layer that computes key and value.""" + +import tensorflow as tf +from tensorflow.keras.layers import Dense, Layer + + +class Extracting(Layer): + """TODO""" + + def __init__(self, + units, + use_bias, + activation, + kernel_initializer, + kernel_regularizer, + **kwargs): + super().__init__(**kwargs) + + self.units = units + self.use_bias = use_bias + self.activation = activation + self.kernel_initializer = kernel_initializer + self.kernel_regularizer = kernel_regularizer + + self.dense1 = Dense(units=self.units, + use_bias=self.use_bias, + activation=self.activation, + kernel_initializer=self.kernel_initializer, + kernel_regularizer=self.kernel_regularizer) + self.dense2 = Dense(units=self.units, + use_bias=self.use_bias, + activation=self.activation, + kernel_initializer=self.kernel_initializer, + kernel_regularizer=self.kernel_regularizer) + + def build(self, input_shape): + super().build(input_shape) + + def call(self, inputs, mask=None): + if mask is not None: + mask = tf.cast(mask, dtype=self.dtype) + mask = tf.expand_dims(mask, axis=-1) + else: + mask = 1.0 + + k = mask * self.dense1(inputs) + v = mask * self.dense2(inputs) + + return tf.concat([k, v], axis=-1) + + def compute_mask(self, inputs, mask=None): + return mask diff --git a/layers/reading.py b/layers/reading.py new file mode 100644 index 0000000..69079fd --- /dev/null +++ b/layers/reading.py @@ -0,0 +1,93 @@ +"""Reading layers that read from memory.""" + +import tensorflow as tf +import tensorflow.keras.backend as K +from tensorflow.keras.layers import Dense, Layer + + +class Reading(Layer): + """TODO""" + + def __init__(self, + units, + use_bias, + activation, + kernel_initializer, + kernel_regularizer, + **kwargs): + super().__init__(**kwargs) + + self.units = units + self.use_bias = use_bias + self.activation = activation + self.kernel_initializer = kernel_initializer + self.kernel_regularizer = kernel_regularizer + + self.dense = Dense(units=self.units, + use_bias=self.use_bias, + activation=self.activation, + kernel_initializer=self.kernel_initializer, + kernel_regularizer=self.kernel_regularizer) + + def build(self, input_shape): + super().build(input_shape) + + def call(self, inputs, constants): + memory_matrix = constants[0] + + k = self.dense(inputs) + + v = K.batch_dot(k, memory_matrix) + + return v + + def compute_mask(self, inputs, mask=None): + return mask + + +class ReadingCell(Layer): + """TODO""" + + def __init__(self, + units, + use_bias, + activation, + kernel_initializer, + kernel_regularizer, + **kwargs): + super().__init__(**kwargs) + + self.units = units + self.use_bias = use_bias + self.activation = activation + self.kernel_initializer = kernel_initializer + self.kernel_regularizer = kernel_regularizer + + self.dense = Dense(units=self.units, + use_bias=self.use_bias, + activation=self.activation, + kernel_initializer=self.kernel_initializer, + kernel_regularizer=self.kernel_regularizer) + + @property + def state_size(self): + return self.units + + def build(self, input_shape): + super().build(input_shape) + + def call(self, inputs, states, constants): + v = states[0] + memory_matrix = constants[0] + + k = self.dense(tf.concat([inputs, v], axis=1)) + + v = K.batch_dot(k, memory_matrix) + + return v, v + + def compute_mask(self, inputs, mask=None): + return mask + + def get_initial_state(self, inputs=None, batch_size=None, dtype=None): + return tf.zeros((batch_size, self.units)) diff --git a/layers/writing.py b/layers/writing.py new file mode 100644 index 0000000..10ee3a0 --- /dev/null +++ b/layers/writing.py @@ -0,0 +1,94 @@ +"""Writing layers that write to memory.""" + +import tensorflow as tf +from tensorflow.keras.layers import Layer + + +class Writing(Layer): + """TODO""" + + def __init__(self, + units, + gamma, + learn_gamma=False, + **kwargs): + super().__init__(**kwargs) + + self.units = units + self._gamma = gamma + self.learn_gamma = learn_gamma + + def build(self, input_shape): + self.gamma = self.add_weight(shape=(1,), trainable=self.learn_gamma, + initializer=tf.keras.initializers.Constant(self._gamma), + dtype=self.dtype, name='gamma') + + super().build(input_shape) + + def call(self, inputs, mask=None): + k, v = tf.split(inputs, 2, axis=-1) + + k = tf.expand_dims(k, 2) + v = tf.expand_dims(v, 1) + + hebb = self.gamma * k * v + + memory_matrix = hebb + + return memory_matrix + + def compute_mask(self, inputs, mask=None): + return mask + + +class WritingCell(Layer): + """TODO""" + + def __init__(self, + units, + gamma_pos, + gamma_neg, + w_assoc_max, + learn_gamma_pos=False, + learn_gamma_neg=False, + **kwargs): + super().__init__(**kwargs) + + self.units = units + self.w_max = w_assoc_max + self._gamma_pos = gamma_pos + self._gamma_neg = gamma_neg + self.learn_gamma_pos = learn_gamma_pos + self.learn_gamma_neg = learn_gamma_neg + + @property + def state_size(self): + return tf.TensorShape((self.units, self.units)) + + def build(self, input_shape): + self.gamma_pos = self.add_weight(shape=(1,), trainable=self.learn_gamma_pos, + initializer=tf.keras.initializers.Constant(self._gamma_pos), + dtype=self.dtype, name='gamma_pos') + self.gamma_neg = self.add_weight(shape=(1,), trainable=self.learn_gamma_neg, + initializer=tf.keras.initializers.Constant(self._gamma_neg), + dtype=self.dtype, name='gamma_neg') + + super().build(input_shape) + + def call(self, inputs, states, mask=None): + memory_matrix = states[0] + k, v = tf.split(inputs, 2, axis=-1) + + k = tf.expand_dims(k, 2) + v = tf.expand_dims(v, 1) + + hebb = self.gamma_pos * (self.w_max - memory_matrix) * k * v - self.gamma_neg * memory_matrix * k**2 + memory_matrix = hebb + memory_matrix + + return memory_matrix, memory_matrix + + def compute_mask(self, inputs, mask=None): + return mask + + def get_initial_state(self, inputs=None, batch_size=None, dtype=None): + return tf.zeros((batch_size, self.units, self.units)) diff --git a/results/.gitkeep b/results/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/tests/.gitkeep b/tests/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..b26d5e9 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,32 @@ +"""CSV logger""" + +import csv + +from tensorflow.keras.callbacks import Callback + + +class MyCSVLogger(Callback): + def __init__(self, filename): + self.filename = filename + + def on_test_begin(self, logs=None): + self.csv_file = open(self.filename, "a") + + class CustomDialect(csv.excel): + delimiter = ',' + + self.fieldnames = ['error [%]'] + self.writer = csv.DictWriter(self.csv_file, self.fieldnames, dialect=CustomDialect) + self.writer.writeheader() + + def on_test_batch_begin(self, batch, logs=None): + pass + + def on_test_batch_end(self, batch, logs=None): + logs = {'error [%]': 100.0 - logs['accuracy'] * 100.0} + self.writer.writerow(logs) + self.csv_file.flush() + + def on_test_end(self, logs=None): + self.csv_file.close() + self.writer = None diff --git a/utils/word_encodings.py b/utils/word_encodings.py new file mode 100644 index 0000000..2163b5c --- /dev/null +++ b/utils/word_encodings.py @@ -0,0 +1,31 @@ +"""Word encodings.""" + +import numpy as np + + +def position_encoding(sentence_size, embedding_size): + """Position Encoding. + + Encodes the position of words within the sentence (implementation based on + https://arxiv.org/pdf/1503.08895.pdf [1]). + + Arguments: + sentence_size: int, the size of the sentence (number of words). + embedding_size: int, the size of the word embedding. + + Returns: + A encoding matrix represented by a Numpy array with shape `[sentence_size, embedding_size]`. + + """ + encoding = np.ones((embedding_size, sentence_size), dtype=np.float32) + ls = sentence_size + 1 + le = embedding_size + 1 + for i in range(1, le): + for j in range(1, ls): + encoding[i - 1, j - 1] = (i - (embedding_size + 1) / 2) * (j - (sentence_size + 1) / 2) + encoding = 1 + 4 * encoding / embedding_size / sentence_size + + # Make position encoding of time words identity to avoid modifying them. + encoding[:, -1] = 1.0 + + return np.transpose(encoding)