From 02f3b7a311bfabfd872cdadc9131f187a8ccb707 Mon Sep 17 00:00:00 2001
From: Thomas Limbacher
Date: Fri, 14 Aug 2020 13:05:17 +0200
Subject: [PATCH 1/2] First version solving bAbI tasks 3 and 16

---
 babi_task_single.py       | 130 ++++++++++++++++++++++++++++++++++----
 environment.yml           |   1 +
 image_association_task.py |  44 +++++++++++--
 layers/reading.py         |   8 ++-
 layers/writing.py         |  18 ++++++
 5 files changed, 180 insertions(+), 21 deletions(-)

diff --git a/babi_task_single.py b/babi_task_single.py
index aa598d9..2ca0cae 100644
--- a/babi_task_single.py
+++ b/babi_task_single.py
@@ -9,14 +9,13 @@
 import numpy as np
 import tensorflow as tf
-from tensorflow.keras import Model
-from tensorflow.keras.layers import TimeDistributed
-
 from data.babi_data import download, load_task, tasks, vectorize_data
 from layers.encoding import Encoding
 from layers.extracting import Extracting
 from layers.reading import ReadingCell
 from layers.writing import WritingCell
+from tensorflow.keras import Model
+from tensorflow.keras.layers import TimeDistributed
 from utils.logger import MyCSVLogger
 
 strategy = tf.distribute.MirroredStrategy()
 
@@ -26,9 +25,9 @@
 parser.add_argument('--max_num_sentences', type=int, default=-1)
 parser.add_argument('--training_set_size', type=str, default='10k')
-parser.add_argument('--epochs', type=int, default=100)
+parser.add_argument('--epochs', type=int, default=200)
 parser.add_argument('--learning_rate', type=float, default=0.003)
-parser.add_argument('--batch_size_per_replica', type=int, default=128)
+parser.add_argument('--batch_size_per_replica', type=int, default=32)
 parser.add_argument('--random_state', type=int, default=None)
 parser.add_argument('--max_grad_norm', type=float, default=20.0)
 parser.add_argument('--validation_split', type=float, default=0.1)
 
@@ -44,6 +43,7 @@
 parser.add_argument('--verbose', type=int, default=1)
 parser.add_argument('--logging', type=int, default=0)
+parser.add_argument('--make_plots', type=int, default=0)
 
 args = parser.parse_args()
 
 batch_size = args.batch_size_per_replica * strategy.num_replicas_in_sync
@@ -168,12 +168,34 @@
                                          w_assoc_max=args.w_assoc_max),
                              name='entity_writing')(entities)
 
-    queried_value = tf.keras.layers.RNN(ReadingCell(units=args.memory_size,
-                                                    use_bias=False,
-                                                    activation='relu',
-                                                    kernel_initializer='he_uniform',
-                                                    kernel_regularizer=tf.keras.regularizers.l2(1e-3)),
-                                        name='entity_reading')(query_encoded, constants=[memory_matrix])
+    # queried_value = tf.keras.layers.RNN(ReadingCell(units=args.memory_size,
+    #                                                 use_bias=False,
+    #                                                 activation='relu',
+    #                                                 kernel_initializer='he_uniform',
+    #                                                 kernel_regularizer=tf.keras.regularizers.l2(1e-3)),
+    #                                     name='entity_reading')(query_encoded, constants=[memory_matrix])
+
+    k, queried_values = tf.keras.layers.RNN(ReadingCell(units=args.memory_size,
+                                                        use_bias=False,
+                                                        activation='relu',
+                                                        kernel_initializer='he_uniform',
+                                                        kernel_regularizer=tf.keras.regularizers.l2(1e-3)),
+                                            return_sequences=True, name='entity_reading')(
+        query_encoded, constants=[memory_matrix])
+
+    # queried_value = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(queried_values)
+
+    # queried_value = tf.keras.layers.LSTM(100,
+    #                                      kernel_regularizer=tf.keras.regularizers.l2(1e-3))(queried_values)
+
+    # queried_value = tf.keras.layers.SimpleRNN(100, kernel_regularizer=tf.keras.regularizers.l2(1e-3),
+    #                                           )(queried_values)
+
+    queried_value = tf.keras.layers.Lambda(lambda x: tf.keras.backend.sum(x, axis=1))(queried_values)
+
+    # queried_value = tf.keras.layers.Attention()([k, queried_values])
+
+    # queried_value = tf.keras.layers.GlobalAveragePooling1D()(queried_value)
 
     outputs = tf.keras.layers.Dense(vocab_size,
                                     use_bias=False,
                                     kernel_initializer='he_uniform',
                                     name='output')(queried_value)
@@ -193,7 +215,8 @@
 
 # Train and evaluate.
 def lr_scheduler(epoch):
-    return args.learning_rate * 0.85**tf.math.floor(epoch / 20)
+    # return args.learning_rate * 0.85**tf.math.floor(epoch / 20)
+    return args.learning_rate
 
 
 callbacks = []
@@ -211,3 +234,86 @@ def lr_scheduler(epoch):
                                      args.task_id, args.training_set_size, args.encodings_type, args.hops, args.random_state))))
 
 model.evaluate(x=x_test, y=y_test, callbacks=callbacks, verbose=2)
+
+if args.make_plots:
+
+    import matplotlib as mpl
+    mpl.use('Agg')
+    import matplotlib.pyplot as plt
+    from scipy import spatial
+
+    examples = range(1)
+    for example in examples:
+        x = [testS[example][np.newaxis, :, :], testQ[example][np.newaxis, :, :]]
+
+        extracting_layer = Model(inputs=model.input, outputs=model.get_layer('entity_extracting').output)
+        entities = extracting_layer.predict(x)
+        keys_story, values_story = tf.split(entities, 2, axis=-1)
+
+        reading_layer = Model(inputs=model.input, outputs=model.get_layer('entity_reading').output)
+        keys_query, queried_values = reading_layer.predict(x)
+
+        print(test[example])
+
+        print(' '.join(test[example][0][0]))
+
+        cosine_sim_keys = np.zeros((x[0].shape[1], args.hops))
+        for i, key_story in enumerate(keys_story[0]):
+            for j, key_query in enumerate(keys_query[0]):
+                cosine_sim_keys[i, j] = 1 - spatial.distance.cosine(key_story, key_query)
+        print(cosine_sim_keys)
+
+        cosine_sim_values = np.zeros((x[0].shape[1], args.hops))
+        for i, value_story in enumerate(values_story[0]):
+            for j, queried_value in enumerate(queried_values[0]):
+                cosine_sim_values[i, j] = 1 - spatial.distance.cosine(value_story, queried_value)
+        print(cosine_sim_values)
+
+        fig, ax = plt.subplots(nrows=3, ncols=1, sharex=True)
+        # plt.suptitle('entities writing')
+        for i, t in enumerate(test[example][0]):
+            ax[0].text(0.5+i*1, 0.0, ' '.join(t), {'ha': 'center', 'va': 'bottom'},
+                       fontsize=7, rotation=90)
+        ax[0].set_frame_on(False)
+        ax[0].axes.get_yaxis().set_visible(False)
+        ax[0].axes.get_xaxis().set_visible(False)
+
+        ax[1].pcolormesh(tf.transpose(keys_story[0]), cmap='coolwarm')
+        ax[2].pcolormesh(tf.transpose(values_story[0]), cmap='coolwarm')
+        ax[1].set_ylabel('keys story')
+        ax[2].set_ylabel('values story')
+        ax[2].set_xlim([0, 10])
+        fig.savefig('entities-writing-{0}.png'.format(example), dpi=fig.dpi)
+
+        fig, ax = plt.subplots(nrows=3, ncols=1, sharex=True)
+        # plt.suptitle('entities reading')
+        for i in range(args.hops):
+            ax[0].text(0.5+i*1, 0.0, ' '.join(test[example][1]), {'ha': 'center', 'va': 'bottom'},
+                       fontsize=7, rotation=90)
+        ax[0].set_frame_on(False)
+        ax[0].axes.get_yaxis().set_visible(False)
+        ax[0].axes.get_xaxis().set_visible(False)
+
+        ax[1].pcolormesh(tf.transpose(keys_query[0]), cmap='coolwarm')
+        ax[2].pcolormesh(tf.transpose(queried_values[0]), cmap='coolwarm')
+        ax[1].set_ylabel('keys query')
+        ax[2].set_ylabel('queried value')
+        fig.savefig('entities-reading-{0}.png'.format(example), dpi=fig.dpi)
+
+        fig, ax = plt.subplots(nrows=1, ncols=1)
+        plt.suptitle('cosine sim keys')
+        cax = ax.matshow(cosine_sim_keys[:10, :], cmap='coolwarm')
+        fig.colorbar(cax)
+        ax.set_ylabel('keys story')
+        ax.set_xlabel('keys query')
+        fig.savefig('cosine-sim-keys-{0}.png'.format(example), dpi=fig.dpi)
+
+        fig, ax = plt.subplots(nrows=1, ncols=1)
+        plt.suptitle('cosine sim values')
+        cax = ax.matshow(cosine_sim_values[:10, :], cmap='coolwarm')
+        fig.colorbar(cax)
+        ax.set_ylabel('values story')
+        ax.set_xlabel('queried values')
+        fig.savefig('cosine-sim-values-{0}.png'.format(example), dpi=fig.dpi)
+
+    # plt.show()
diff --git a/environment.yml b/environment.yml
index 5a07a1f..1b31907 100644
--- a/environment.yml
+++ b/environment.yml
@@ -5,4 +5,5 @@ channels:
 dependencies:
   - python=3.7
+  - matplotlib=3.1
   - tensorflow-gpu=2.1
diff --git a/image_association_task.py b/image_association_task.py
index fc4bec3..38f3576 100644
--- a/image_association_task.py
+++ b/image_association_task.py
@@ -7,6 +7,7 @@
 import numpy as np
 import tensorflow as tf
 from tensorflow.keras.layers import TimeDistributed
+from tensorflow.keras import Model
 
 from data.image_association_data import load_data
 from layers.extracting import Extracting
@@ -38,6 +39,7 @@
 parser.add_argument('--w_assoc_max', type=float, default=1.0)
 parser.add_argument('--verbose', type=int, default=1)
+parser.add_argument('--make_plots', type=int, default=0)
 
 args = parser.parse_args()
 
 batch_size = args.batch_size_per_replica * strategy.num_replicas_in_sync
@@ -145,19 +147,19 @@ def dataset_generator(x, y, seed):
                                    learn_gamma_neg=False),
                              name='entity_writing')(entities)
 
-    queried_value = Reading(units=args.memory_size,
-                            use_bias=False,
-                            activation='relu',
-                            kernel_initializer='he_uniform',
-                            kernel_regularizer=tf.keras.regularizers.l2(1e-3),
-                            name='entity_reading')(features_b, constants=[memory_matrix])
+    _, queried_value = Reading(units=args.memory_size,
+                               use_bias=False,
+                               activation='relu',
+                               kernel_initializer='he_uniform',
+                               kernel_regularizer=tf.keras.regularizers.l2(1e-3),
+                               name='entity_reading')(features_b, constants=[memory_matrix])
 
     outputs = tf.keras.layers.Dense(10,
                                     use_bias=False,
                                     kernel_initializer='he_uniform',
                                     name='output')(queried_value)
 
-    model = tf.keras.Model(inputs=[input_a, input_b], outputs=outputs)
+    model = Model(inputs=[input_a, input_b], outputs=outputs)
 
     # Compile the model.
     optimizer_kwargs = {'clipnorm': args.max_grad_norm} if args.max_grad_norm else {}
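
Note: patch 1 changes `Reading`/`ReadingCell` to return the generated key together with the retrieved value (`return k, v`), which is why the call site above now unpacks `_, queried_value = Reading(...)`. A minimal standalone sketch of this content-based read, with illustrative shapes (the dense projection here stands in for the layer's kernel and is not the actual implementation):

    import tensorflow as tf

    batch, units = 2, 8
    query = tf.random.normal((batch, units))                 # encoded query
    memory_matrix = tf.random.normal((batch, units, units))  # associative memory

    # A key is generated from the query and then addresses the memory:
    # batch_dot contracts the key with each sample's memory matrix.
    dense = tf.keras.layers.Dense(units, use_bias=False, activation='relu')
    k = dense(query)                                  # (batch, units)
    v = tf.keras.backend.batch_dot(k, memory_matrix)  # (batch, units)
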
@@ -186,3 +188,31 @@ def lr_scheduler(epoch):
               verbose=args.verbose)
 
 model.evaluate(test_dataset, steps=np.ceil(num_test/batch_size), verbose=2)
+
+if args.make_plots:
+
+    import matplotlib as mpl
+    mpl.use('Agg')
+    import matplotlib.pyplot as plt
+    from scipy import spatial
+
+    x = test_dataset.take(1)
+
+    extracting_layer = Model(inputs=model.input, outputs=model.get_layer('entity_extracting').output)
+    entities = extracting_layer.predict(x)[0]
+    keys_story, values_story = tf.split(entities, 2, axis=-1)
+
+    reading_layer = Model(inputs=model.input, outputs=model.get_layer('entity_reading').output)
+    key_query, queried_value = reading_layer.predict(x)
+    key_query = key_query[0]
+    queried_value = queried_value[0]
+
+    cosine_sim_keys = np.zeros((args.timesteps, 1))
+    for i, key_story in enumerate(keys_story):
+        cosine_sim_keys[i, 0] = 1 - spatial.distance.cosine(key_story, key_query)
+    print(cosine_sim_keys)
+
+    cosine_sim_values = np.zeros((args.timesteps, 1))
+    for i, value_story in enumerate(values_story):
+        cosine_sim_values[i, 0] = 1 - spatial.distance.cosine(value_story, queried_value)
+    print(cosine_sim_values)
diff --git a/layers/reading.py b/layers/reading.py
index 69079fd..c8f5cc1 100644
--- a/layers/reading.py
+++ b/layers/reading.py
@@ -39,7 +39,7 @@ def call(self, inputs, constants):
 
         v = K.batch_dot(k, memory_matrix)
 
-        return v
+        return k, v
 
     def compute_mask(self, inputs, mask=None):
         return mask
@@ -69,6 +69,8 @@ def __init__(self,
                                  kernel_initializer=self.kernel_initializer,
                                  kernel_regularizer=self.kernel_regularizer)
 
+        self.ln1 = tf.keras.layers.LayerNormalization()
+
     @property
     def state_size(self):
         return self.units
@@ -82,9 +84,11 @@ def call(self, inputs, states, constants):
 
         k = self.dense(tf.concat([inputs, v], axis=1))
 
+        k = self.ln1(k)
+
         v = K.batch_dot(k, memory_matrix)
 
-        return v, v
+        return [k, v], v
 
     def compute_mask(self, inputs, mask=None):
         return mask
diff --git a/layers/writing.py b/layers/writing.py
index 10ee3a0..7bdc665 100644
--- a/layers/writing.py
+++ b/layers/writing.py
@@ -2,6 +2,7 @@
 import tensorflow as tf
 from tensorflow.keras.layers import Layer
+import tensorflow.keras.backend as K
 
 
 class Writing(Layer):
@@ -73,16 +74,33 @@ def build(self, input_shape):
                                          initializer=tf.keras.initializers.Constant(self._gamma_neg),
                                          dtype=self.dtype, name='gamma_neg')
 
+        self.dense1 = tf.keras.layers.Dense(units=self.units)
+
+        self.ln1 = tf.keras.layers.LayerNormalization()
+        self.ln2 = tf.keras.layers.LayerNormalization()
+        self.ln3 = tf.keras.layers.LayerNormalization()
+
         super().build(input_shape)
 
     def call(self, inputs, states, mask=None):
         memory_matrix = states[0]
 
         k, v = tf.split(inputs, 2, axis=-1)
 
+        k = self.ln1(k)
+        v = self.ln2(v)
+
+        v_h = K.batch_dot(k, memory_matrix)
+
+        v = self.dense1(tf.concat([v, v_h], axis=-1))
+
+        v = self.ln3(v)
+
         k = tf.expand_dims(k, 2)
         v = tf.expand_dims(v, 1)
 
         hebb = self.gamma_pos * (self.w_max - memory_matrix) * k * v - self.gamma_neg * memory_matrix * k**2
+        # hebb = self.gamma_pos * (self.w_max - memory_matrix) * k * v
+
         memory_matrix = hebb + memory_matrix
 
         return memory_matrix, memory_matrix

From 4849f29176e697f8081158cb607292998163e3c6 Mon Sep 17 00:00:00 2001
From: Thomas Limbacher
Date: Thu, 22 Oct 2020 13:56:15 +0200
Subject: [PATCH 2/2] First version solving all bAbI tasks

---
 babi_task_single.py        | 69 ++++++++++++++++++++++----------
 layers/reading.py          |  4 ---
 layers/writing.py          | 25 ++++++------
 utils/configure_seaborn.py | 50 +++++++++++++++++++++++
 4 files changed, 104 insertions(+), 44 deletions(-)
 create mode 100644 utils/configure_seaborn.py

diff --git a/babi_task_single.py b/babi_task_single.py
index 2ca0cae..aca9d7b 100644
--- a/babi_task_single.py
+++ b/babi_task_single.py
@@ -25,9 +25,9 @@
 parser.add_argument('--max_num_sentences', type=int, default=-1)
 parser.add_argument('--training_set_size', type=str, default='10k')
-parser.add_argument('--epochs', type=int, default=200)
+parser.add_argument('--epochs', type=int, default=250)
 parser.add_argument('--learning_rate', type=float, default=0.003)
-parser.add_argument('--batch_size_per_replica', type=int, default=32)
+parser.add_argument('--batch_size_per_replica', type=int, default=128)
 parser.add_argument('--random_state', type=int, default=None)
 parser.add_argument('--max_grad_norm', type=float, default=20.0)
 parser.add_argument('--validation_split', type=float, default=0.1)
 
@@ -183,19 +183,7 @@
                                             return_sequences=True, name='entity_reading')(
         query_encoded, constants=[memory_matrix])
 
-    # queried_value = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(queried_values)
-
-    # queried_value = tf.keras.layers.LSTM(100,
-    #                                      kernel_regularizer=tf.keras.regularizers.l2(1e-3))(queried_values)
-
-    # queried_value = tf.keras.layers.SimpleRNN(100, kernel_regularizer=tf.keras.regularizers.l2(1e-3),
-    #                                           )(queried_values)
-
-    queried_value = tf.keras.layers.Lambda(lambda x: tf.keras.backend.sum(x, axis=1))(queried_values)
-
-    # queried_value = tf.keras.layers.Attention()([k, queried_values])
-
-    # queried_value = tf.keras.layers.GlobalAveragePooling1D()(queried_value)
+    queried_value = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(queried_values)
 
     outputs = tf.keras.layers.Dense(vocab_size,
                                     use_bias=False,
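
Note: the reworked scheduler in the next hunk holds the learning rate flat for the first 150 epochs and then decays it exponentially by roughly 1% per epoch (exp(-0.01) ~ 0.99). A standalone sketch of the same schedule; the 150-epoch knee and the 0.01 rate are taken from the diff, and `base_lr` mirrors the `--learning_rate` default:

    import math

    def lr_schedule(epoch, base_lr=0.003):
        # Flat warm phase, then ~1% exponential decay per epoch.
        if epoch < 150:
            return base_lr
        return base_lr * math.exp(0.01 * (150 - epoch))

    # Spot checks: epoch 149 -> 0.003, epoch 250 -> 0.003 * e**-1.0 ~ 0.0011
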
@@ -215,8 +203,13 @@
 
 # Train and evaluate.
 def lr_scheduler(epoch):
-    # return args.learning_rate * 0.85**tf.math.floor(epoch / 20)
-    return args.learning_rate
+    # return args.learning_rate
+    # return args.learning_rate * 0.75**tf.math.floor(epoch / 100)
+    if epoch < 150:
+        return args.learning_rate
+    else:
+        # return args.learning_rate * 0.85**tf.math.floor(epoch / 50)
+        return args.learning_rate * tf.math.exp(0.01 * (150 - epoch))
 
 
 callbacks = []
@@ -237,12 +230,16 @@ def lr_scheduler(epoch):
 
 if args.make_plots:
 
+    import utils.configure_seaborn as cs
+    import seaborn as sns
     import matplotlib as mpl
     mpl.use('Agg')
     import matplotlib.pyplot as plt
     from scipy import spatial
 
-    examples = range(1)
+    sns.set(context='paper', style='ticks', rc=cs.rc_params)
+
+    examples = range(10)  # [0, 1, 2, 3]  # 2  # range(1)
     for example in examples:
         x = [testS[example][np.newaxis, :, :], testQ[example][np.newaxis, :, :]]
 
@@ -278,12 +275,23 @@ def lr_scheduler(epoch):
         ax[0].axes.get_yaxis().set_visible(False)
         ax[0].axes.get_xaxis().set_visible(False)
 
-        ax[1].pcolormesh(tf.transpose(keys_story[0]), cmap='coolwarm')
-        ax[2].pcolormesh(tf.transpose(values_story[0]), cmap='coolwarm')
+        vmin_k = np.minimum(np.min(keys_story[0]), np.min(keys_query[0]))
+        vmax_k = np.maximum(np.max(keys_story[0]), np.max(keys_query[0]))
+        print(vmin_k, vmax_k)
+
+        vmin_v = np.minimum(np.min(values_story[0]), np.min(queried_values[0]))
+        vmax_v = np.maximum(np.max(values_story[0]), np.max(queried_values[0]))
+        print(vmin_v, vmax_v)
+
+        vmin = np.minimum(vmin_k, vmin_v)
+        vmax = np.maximum(vmax_k, vmax_v)
+
+        ax[1].pcolormesh(tf.transpose(keys_story[0]), cmap='Blues')  # , vmin=vmin_k, vmax=vmax_k)
+        ax[2].pcolormesh(tf.transpose(values_story[0]), cmap='Oranges')  # , vmin=vmin_v, vmax=vmax_v)
         ax[1].set_ylabel('keys story')
         ax[2].set_ylabel('values story')
         ax[2].set_xlim([0, 10])
-        fig.savefig('entities-writing-{0}.png'.format(example), dpi=fig.dpi)
+        fig.savefig('entities-writing-{0}.pdf'.format(example), dpi=fig.dpi)
 
         fig, ax = plt.subplots(nrows=3, ncols=1, sharex=True)
         # plt.suptitle('entities reading')
@@ -294,26 +302,29 @@ def lr_scheduler(epoch):
         ax[0].axes.get_yaxis().set_visible(False)
         ax[0].axes.get_xaxis().set_visible(False)
 
-        ax[1].pcolormesh(tf.transpose(keys_query[0]), cmap='coolwarm')
-        ax[2].pcolormesh(tf.transpose(queried_values[0]), cmap='coolwarm')
-        ax[1].set_ylabel('keys query')
+        ax[1].pcolormesh(tf.transpose(keys_query[0]), cmap='Blues')  # , vmin=vmin_k, vmax=vmax_k)
+        ax[2].pcolormesh(tf.transpose(queried_values[0]), cmap='Oranges')  # , vmin=vmin_v, vmax=vmax_v)
+        ax[1].set_ylabel('key query')
         ax[2].set_ylabel('queried value')
-        fig.savefig('entities-reading-{0}.png'.format(example), dpi=fig.dpi)
+        fig.savefig('entities-reading-{0}.pdf'.format(example), dpi=fig.dpi)
+
+        # my_cmap = sns.color_palette("crest", as_cmap=True)
+        my_cmap = sns.color_palette("ch:start=.2,rot=-.3", as_cmap=True)
 
         fig, ax = plt.subplots(nrows=1, ncols=1)
         plt.suptitle('cosine sim keys')
-        cax = ax.matshow(cosine_sim_keys[:10, :], cmap='coolwarm')
+        cax = ax.matshow(cosine_sim_keys[:10, :], cmap=my_cmap)  # , vmin=-1, vmax=1)
         fig.colorbar(cax)
         ax.set_ylabel('keys story')
         ax.set_xlabel('keys query')
-        fig.savefig('cosine-sim-keys-{0}.png'.format(example), dpi=fig.dpi)
+        fig.savefig('cosine-sim-keys-{0}.pdf'.format(example), dpi=fig.dpi)
 
         fig, ax = plt.subplots(nrows=1, ncols=1)
         plt.suptitle('cosine sim values')
-        cax = ax.matshow(cosine_sim_values[:10, :], cmap='coolwarm')
+        cax = ax.matshow(cosine_sim_values[:10, :], cmap=my_cmap)  # , vmin=-1, vmax=1)
         fig.colorbar(cax)
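+        # NOTE: cosine similarities lie in [-1, 1]; re-enabling the
+        # commented-out vmin/vmax above would pin the colour scale
+        # across examples instead of autoscaling each plot.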
         ax.set_ylabel('values story')
         ax.set_xlabel('queried values')
-        fig.savefig('cosine-sim-values-{0}.png'.format(example), dpi=fig.dpi)
+        fig.savefig('cosine-sim-values-{0}.pdf'.format(example), dpi=fig.dpi)
 
     # plt.show()
diff --git a/layers/reading.py b/layers/reading.py
index c8f5cc1..6cd0123 100644
--- a/layers/reading.py
+++ b/layers/reading.py
@@ -69,8 +69,6 @@ def __init__(self,
                                  kernel_initializer=self.kernel_initializer,
                                  kernel_regularizer=self.kernel_regularizer)
 
-        self.ln1 = tf.keras.layers.LayerNormalization()
-
     @property
     def state_size(self):
         return self.units
@@ -82,8 +82,6 @@ def call(self, inputs, states, constants):
 
         k = self.dense(tf.concat([inputs, v], axis=1))
 
-        k = self.ln1(k)
-
         v = K.batch_dot(k, memory_matrix)
 
         return [k, v], v
diff --git a/layers/writing.py b/layers/writing.py
index 7bdc665..0aa7a6f 100644
--- a/layers/writing.py
+++ b/layers/writing.py
@@ -62,6 +62,14 @@ def __init__(self,
         self.learn_gamma_pos = learn_gamma_pos
         self.learn_gamma_neg = learn_gamma_neg
 
+        self.dense = tf.keras.layers.Dense(units=self.units,
+                                           use_bias=False,
+                                           kernel_initializer='he_uniform',
+                                           kernel_regularizer=tf.keras.regularizers.l2(1e-3))
+
+        self.ln1 = tf.keras.layers.LayerNormalization()
+        self.ln2 = tf.keras.layers.LayerNormalization()
+
     @property
     def state_size(self):
         return tf.TensorShape((self.units, self.units))
@@ -74,31 +82,26 @@ def build(self, input_shape):
                                          initializer=tf.keras.initializers.Constant(self._gamma_neg),
                                          dtype=self.dtype, name='gamma_neg')
 
-        self.dense1 = tf.keras.layers.Dense(units=self.units)
-
-        self.ln1 = tf.keras.layers.LayerNormalization()
-        self.ln2 = tf.keras.layers.LayerNormalization()
-        self.ln3 = tf.keras.layers.LayerNormalization()
-
         super().build(input_shape)
 
     def call(self, inputs, states, mask=None):
         memory_matrix = states[0]
 
         k, v = tf.split(inputs, 2, axis=-1)
 
-        k = self.ln1(k)
-        v = self.ln2(v)
+        k = self.ln1(k)  # TODO layer norm for v here?
+        # v = self.ln2(v)
 
         v_h = K.batch_dot(k, memory_matrix)
 
-        v = self.dense1(tf.concat([v, v_h], axis=-1))
-
-        v = self.ln3(v)
+        v = self.dense(tf.concat([v, v_h], axis=1))
+        v = self.ln2(v)
+        # k = self.ln1(k)  # TODO layer norm for v here?
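+        # The update below is Hebbian on the outer product of key and
+        # value: after the expand_dims, k has shape (batch, units, 1)
+        # and v (batch, 1, units), so k * v broadcasts to the outer
+        # product k v^T. Growth saturates at w_max through the
+        # (w_max - memory_matrix) gate, while the anti-Hebbian term
+        # -gamma_neg * memory_matrix * k**2 decays entries addressed
+        # by the current key.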
         k = tf.expand_dims(k, 2)
         v = tf.expand_dims(v, 1)
 
         hebb = self.gamma_pos * (self.w_max - memory_matrix) * k * v - self.gamma_neg * memory_matrix * k**2
+        # hebb = self.gamma_pos * (self.w_max - memory_matrix) * k * v - self.gamma_neg * memory_matrix * k
         # hebb = self.gamma_pos * (self.w_max - memory_matrix) * k * v
 
         memory_matrix = hebb + memory_matrix
diff --git a/utils/configure_seaborn.py b/utils/configure_seaborn.py
new file mode 100644
index 0000000..daed6b5
--- /dev/null
+++ b/utils/configure_seaborn.py
@@ -0,0 +1,50 @@
+"""Configure seaborn for plotting."""
+import matplotlib
+
+PAD_INCHES = 0.00
+
+rc_params = {
+    # FONT
+    'font.size': 9,
+    'font.sans-serif': 'Myriad Pro',
+    'font.stretch': 'condensed',
+    'font.weight': 'normal',
+    'mathtext.fontset': 'custom',
+    'mathtext.fallback_to_cm': True,
+    'mathtext.rm': 'Minion Pro',
+    'mathtext.it': 'Minion Pro:italic',
+    'mathtext.bf': 'Minion Pro:bold:italic',
+    'mathtext.cal': 'Minion Pro:italic',  # TODO find calligraphy font
+    'mathtext.tt': 'monospace',
+    # AXES
+    'axes.linewidth': 0.5,
+    'axes.spines.top': False,
+    'axes.spines.right': False,
+    'axes.labelsize': 9,
+    # TICKS
+    'xtick.major.size': 3.5,
+    'xtick.major.width': 0.8,
+    'ytick.major.size': 3.5,
+    'xtick.labelsize': 8,
+    'ytick.labelsize': 8,
+    'ytick.major.width': 0.8,
+    # LEGEND
+    'legend.fontsize': 9,
+    # SAVING FIGURES
+    'savefig.bbox': 'tight',
+    'savefig.pad_inches': PAD_INCHES,
+    'pdf.fonttype': 42,
+    'savefig.dpi': 300}
+
+
+def mm2inch(*tupl, pad):
+    inch = 25.4
+    if isinstance(tupl[0], tuple):
+        return tuple(i/inch - pad for i in tupl[0])
+    else:
+        return tuple(i/inch - pad for i in tupl)
+
+
+def set_figure_size(*size):
+    params = {'figure.figsize': mm2inch(size, pad=PAD_INCHES)}
+    matplotlib.rcParams.update(params)
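
Two closing notes on patch 2. First, the plotting path now imports seaborn and scipy alongside matplotlib, so environment.yml presumably needs matching entries next to `matplotlib=3.1`. Second, `rc_params` assumes the proprietary Myriad Pro and Minion Pro fonts are installed; matplotlib falls back to its defaults (with findfont warnings) when they are missing. A minimal usage sketch of the new helper (the 85 mm x 55 mm size is illustrative; `set_figure_size` expects millimetres):

    import seaborn as sns

    import utils.configure_seaborn as cs

    # Apply the shared rc parameters, then size figures in millimetres;
    # mm2inch converts to inches and subtracts the PAD_INCHES padding.
    sns.set(context='paper', style='ticks', rc=cs.rc_params)
    cs.set_figure_size(85, 55)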