diff --git a/crysnet/__init__.py b/crysnet/__init__.py index 172b467..ec14872 100644 --- a/crysnet/__init__.py +++ b/crysnet/__init__.py @@ -6,4 +6,4 @@ @author: huzongxiang """ -__version__ = "0.1.2" \ No newline at end of file +__version__ = "0.1.5" \ No newline at end of file diff --git a/crysnet/activations/activations.py b/crysnet/activations/activations.py index 1c4fc13..69576b1 100644 --- a/crysnet/activations/activations.py +++ b/crysnet/activations/activations.py @@ -39,4 +39,4 @@ def shifted_softplus(x): """ - return tf.nn.softplus(x) - tf.log(2.0) \ No newline at end of file + return tf.nn.softplus(x) - tf.math.log(2.0) \ No newline at end of file diff --git a/crysnet/layers/__init__.py b/crysnet/layers/__init__.py index 46507cb..6e2963e 100644 --- a/crysnet/layers/__init__.py +++ b/crysnet/layers/__init__.py @@ -5,9 +5,10 @@ @author: huzongxiang """ -from .graphnetworklayer import MessagePassing +from .graphnetworklayer import MessagePassing, NewMessagePassing from .edgenetworklayer import SphericalBasisLayer, AzimuthLayer, ConcatLayer, EdgeAggragate, EdgeMessagePassing -from .crystalgraphlayer import CrystalGraphConvolution +from .crystalgraphlayer import CrystalGraphConvolution, GNConvolution from .graphtransformer import EdgesAugmentedLayer, GraphTransformerEncoder +from .graphormer import GraphormerEncoder, ConvGraphormerEncoder from .partitionpaddinglayer import PartitionPadding, PartitionPaddingPair from .readout import Set2Set \ No newline at end of file diff --git a/crysnet/layers/crystalgraphlayer.py b/crysnet/layers/crystalgraphlayer.py index c9dd7dc..b6e32d0 100644 --- a/crysnet/layers/crystalgraphlayer.py +++ b/crysnet/layers/crystalgraphlayer.py @@ -9,7 +9,6 @@ import tensorflow as tf from tensorflow.keras import layers from tensorflow.keras import activations, initializers, regularizers, constraints -from crysnet.activations import shifted_softplus class CrystalGraphConvolution(layers.Layer): @@ -18,7 +17,7 @@ class CrystalGraphConvolution(layers.Layer): Xie et al. PHYSICAL REVIEW LETTERS 120, 145301 (2018) """ def __init__(self, - steps, + steps=1, kernel_initializer="glorot_uniform", kernel_regularizer=None, kernel_constraint=None, @@ -45,24 +44,6 @@ def __init__(self, def build(self, input_shape): self.atom_dim = input_shape[0][-1] self.edge_dim = input_shape[1][-1] - self.state_dim = input_shape[2][-1] - - self.update_nodes = layers.GRUCell(self.atom_dim, - kernel_regularizer=self.recurrent_regularizer, - recurrent_regularizer=self.recurrent_regularizer, - name='update_nodes' - ) - - self.update_edges = layers.GRUCell(self.edge_dim, - kernel_regularizer=self.recurrent_regularizer, - recurrent_regularizer=self.recurrent_regularizer, - name='update_edges' - ) - - self.update_states = layers.GRUCell(self.state_dim, - kernel_regularizer=self.recurrent_regularizer, - recurrent_regularizer=self.recurrent_regularizer, - name='update_states') with tf.name_scope("nodes_aggregate"): # weight for updating atom_features by bond_features @@ -101,8 +82,6 @@ def build(self, input_shape): constraint=self.bias_constraint, name='nodes_bias_g', ) - - self.softplus = shifted_softplus self.built = True @@ -120,11 +99,11 @@ def aggregate_nodes(self, inputs: Sequence): DESCRIPTION. 
""" - atom_features, edges_sph_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs + atom_features, edges_features, pair_indices = inputs # concat state_attrs with atom_features to get merged atom_merge_state_features atom_features_gather = tf.gather(atom_features, pair_indices) - atom_merge_features = tf.concat([atom_features_gather[:,0], atom_features_gather[:,1], edges_sph_features], axis=-1) + atom_merge_features = tf.concat([atom_features_gather[:,0], atom_features_gather[:,1], edges_features], axis=-1) transformed_features_s = tf.matmul(atom_merge_features, self.kernel_s) + self.bias_s transformed_features_g = tf.matmul(atom_merge_features, self.kernel_g) + self.bias_g @@ -151,14 +130,155 @@ def call(self, inputs: Sequence) -> Sequence: DESCRIPTION. """ - atom_features, edges_sph_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs + atom_features, edges_features, pair_indices = inputs + + atom_features_updated = atom_features + + for i in range(self.steps): + atom_features_updated = self.aggregate_nodes([atom_features_updated, edges_features, pair_indices]) + + return atom_features_updated + + + def get_config(self): + config = super().get_config() + config.update({"steps": self.steps}) + return config + + +class GNConvolution(layers.Layer): + """ + The CGCNN graph implementation as described in the paper + Xie et al. PHYSICAL REVIEW LETTERS 120, 145301 (2018) + """ + def __init__(self, + steps=1, + kernel_initializer="glorot_uniform", + kernel_regularizer=None, + kernel_constraint=None, + bias_initializer="zeros", + bias_regularizer=None, + recurrent_regularizer=None, + bias_constraint=None, + activation=None, + **kwargs): + super().__init__(**kwargs) + self.steps = steps + self.activation = activations.get(activation) + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.recurrent_regularizer = regularizers.get(recurrent_regularizer) + + self.kernel_constraint = constraints.get(kernel_constraint) + self.bias_constraint = constraints.get(bias_constraint) + + + def build(self, input_shape): + self.atom_dim = input_shape[0][-1] + self.edge_dim = input_shape[1][-1] + self.state_dim = input_shape[2][-1] + + + # weight for updating atom_features by bond_features + self.kernel_s = self.add_weight( + shape=(self.atom_dim * 2 + self.edge_dim + self.state_dim, + self.atom_dim), + trainable=True, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + name='nodes_kernel_s', + ) + self.bias_s = self.add_weight( + shape=(self.atom_dim,), + trainable=True, + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + name='nodes_bias_s', + ) + + self.kernel_g = self.add_weight( + shape=(self.atom_dim * 2 + self.edge_dim + self.state_dim, + self.atom_dim), + trainable=True, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + name='nodes_kernel_g', + ) + self.bias_g = self.add_weight( + shape=(self.atom_dim,), + trainable=True, + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + name='nodes_bias_g', + ) + + self.built = True + + + def aggregate_nodes(self, inputs: Sequence): + """ + Parameters + 
---------- + inputs : Sequence + DESCRIPTION. + + Returns + ------- + atom_features_aggregated : TYPE + DESCRIPTION. + + """ + atom_features, edges_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs + + # each bond in pair_indices, concatenate thier atom features by atom indexes + # then concatenate atom features with bond features + atom_features_gather = tf.gather(atom_features, pair_indices) + edges_merge_atom_features = tf.concat([atom_features_gather[:,0], atom_features_gather[:,1]], axis=-1) + + # repeat state attributes by bond_graph_indices, then concatenate to bond_merge_atom_features + state_attrs_repeat = tf.gather(state_attrs, bond_graph_indices) + edges_features_concated = tf.concat([edges_merge_atom_features, state_attrs_repeat, edges_features], axis=-1) + + transformed_features_s = tf.matmul(edges_features_concated, self.kernel_s) + self.bias_s + transformed_features_g = tf.matmul(edges_features_concated, self.kernel_g) + self.bias_g + + transformed_features = tf.sigmoid(transformed_features_s) * tf.nn.softplus(transformed_features_g) + atom_features_aggregated = tf.math.segment_sum(transformed_features, pair_indices[:,0]) + + atom_features_updated = atom_features + atom_features_aggregated + atom_features_updated = tf.nn.softplus(atom_features_updated) + + return atom_features_updated + + + def call(self, inputs: Sequence) -> Sequence: + """ + Parameters + ---------- + inputs : Sequence + DESCRIPTION. + + Returns + ------- + Sequence + DESCRIPTION. + + """ + atom_features, bond_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs atom_features_updated = atom_features for i in range(self.steps): - atom_features_updated = self.aggregate_nodes([atom_features_updated, edges_sph_features, - state_attrs, pair_indices, - atom_graph_indices, bond_graph_indices]) + atom_features_updated = self.aggregate_nodes([atom_features_updated, bond_features, + state_attrs, pair_indices, + atom_graph_indices, bond_graph_indices]) return atom_features_updated diff --git a/crysnet/layers/graphnetworklayer.py b/crysnet/layers/graphnetworklayer.py index a42237d..1bc49cd 100644 --- a/crysnet/layers/graphnetworklayer.py +++ b/crysnet/layers/graphnetworklayer.py @@ -9,7 +9,6 @@ import tensorflow as tf from tensorflow.keras import layers from tensorflow.keras import activations, initializers, regularizers, constraints -from crysnet.activations import shifted_softplus class MessagePassing(layers.Layer): @@ -96,7 +95,7 @@ def concat_edges(self, inputs: Sequence): DESCRIPTION. """ - atom_features, edges_sph_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs + atom_features, edges_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs # each bond in pair_indices, concatenate thier atom features by atom indexes # then concatenate atom features with bond features @@ -105,7 +104,7 @@ def concat_edges(self, inputs: Sequence): # repeat state attributes by bond_graph_indices, then concatenate to bond_merge_atom_features state_attrs_repeat = tf.gather(state_attrs, bond_graph_indices) - edges_features_concated = tf.concat([edges_merge_atom_features, state_attrs_repeat, edges_sph_features], axis=-1) + edges_features_concated = tf.concat([edges_merge_atom_features, state_attrs_repeat, edges_features], axis=-1) return edges_features_concated @@ -123,7 +122,7 @@ def aggregate_nodes(self, inputs: Sequence): DESCRIPTION. 
""" - atom_features, edges_sph_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs + atom_features, edges_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs # concat state_attrs with atom_features to get merged atom_merge_state_features state_attrs_repeat = tf.gather(state_attrs, atom_graph_indices) @@ -137,7 +136,7 @@ def aggregate_nodes(self, inputs: Sequence): # so num_bonds of bond features need num_bonds of bond matrix, so a matrix with shape # (num_bonds,(atom_dim,atom_dim,bond_dim)) to transfer bond_features to shape (num_bonds,(atom_dim,atom_dim)) # finally, apply this bond_matrix to adjacent atoms, get bond_features updated atom_features_neighbors - edges_weights = tf.matmul(edges_sph_features, self.kernel) + self.bias + edges_weights = tf.matmul(edges_features, self.kernel) + self.bias edges_weights = tf.reshape(edges_weights, (-1, self.atom_dim + self.state_dim, self.atom_dim + self.state_dim)) atom_features_neighbors = tf.gather(atom_merge_state_features, pair_indices[:, 1]) atom_features_neighbors = tf.expand_dims(atom_features_neighbors, axis=-1) @@ -166,10 +165,10 @@ def concat_states(self, inputs: Sequence): DESCRIPTION. """ - atom_features, edges_sph_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs + atom_features, edges_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs # concat state_attrs with bond_updated and atom_features_aggregated - edges_features_sum = tf.math.segment_sum(edges_sph_features, bond_graph_indices) + edges_features_sum = tf.math.segment_sum(edges_features, bond_graph_indices) atom_features_sum = tf.math.segment_sum(atom_features, atom_graph_indices) state_attrs_concated = tf.concat([atom_features_sum, edges_features_sum, state_attrs], axis=-1) @@ -189,10 +188,10 @@ def call(self, inputs: Sequence) -> Sequence: DESCRIPTION. """ - atom_features, edges_sph_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs + atom_features, edges_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs atom_features_updated = atom_features - edges_features_updated = edges_sph_features + edges_features_updated = edges_features state_attrs_updated = state_attrs # Perform a number of steps of message passing @@ -235,7 +234,7 @@ def get_config(self): return config -class MessagePassingSotfplus(layers.Layer): +class NewMessagePassing(layers.Layer): """ Introducing a kernel sigmoid(node_features) * softplus(node_features) from Xie et al. PHYSICAL REVIEW LETTERS 120, 145301 (2018) """ @@ -323,8 +322,6 @@ def build(self, input_shape): constraint=self.bias_constraint, name='nodes_bias_g', ) - - self.softplus = shifted_softplus self.built = True @@ -342,7 +339,7 @@ def concat_edges(self, inputs: Sequence): DESCRIPTION. 
""" - atom_features, edges_sph_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs + atom_features, edges_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs # each bond in pair_indices, concatenate thier atom features by atom indexes # then concatenate atom features with bond features @@ -351,7 +348,7 @@ def concat_edges(self, inputs: Sequence): # repeat state attributes by bond_graph_indices, then concatenate to bond_merge_atom_features state_attrs_repeat = tf.gather(state_attrs, bond_graph_indices) - edges_features_concated = tf.concat([edges_merge_atom_features, state_attrs_repeat, edges_sph_features], axis=-1) + edges_features_concated = tf.concat([edges_merge_atom_features, state_attrs_repeat, edges_features], axis=-1) return edges_features_concated @@ -369,7 +366,7 @@ def aggregate_nodes(self, inputs: Sequence): DESCRIPTION. """ - atom_features, edges_sph_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs + atom_features, edges_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs # concat state_attrs with atom_features to get merged atom_merge_state_features state_attrs_repeat = tf.gather(state_attrs, atom_graph_indices) @@ -383,10 +380,10 @@ def aggregate_nodes(self, inputs: Sequence): # so num_bonds of bond features need num_bonds of bond matrix, so a matrix with shape # (num_bonds,(atom_dim,atom_dim,bond_dim)) to transfer bond_features to shape (num_bonds,(atom_dim,atom_dim)) # finally, apply this bond_matrix to adjacent atoms, get bond_features updated atom_features_neighbors - edges_weights_s = tf.matmul(edges_sph_features, self.kernel_s) + self.bias_s + edges_weights_s = tf.matmul(edges_features, self.kernel_s) + self.bias_s edges_weights_s = tf.reshape(edges_weights_s, (-1, self.atom_dim + self.state_dim, self.atom_dim + self.state_dim)) - edges_weights_g = tf.matmul(edges_sph_features, self.kernel_g) + self.bias_g + edges_weights_g = tf.matmul(edges_features, self.kernel_g) + self.bias_g edges_weights_g = tf.reshape(edges_weights_g, (-1, self.atom_dim + self.state_dim, self.atom_dim + self.state_dim)) atom_features_neighbors = tf.gather(atom_merge_state_features, pair_indices[:, 1]) @@ -402,7 +399,7 @@ def aggregate_nodes(self, inputs: Sequence): # first tf.gather end features using end atom index pair_indices[:,1] to atom_features_neighbors # then using bond matrix updates atom_features_neighbors, get transformed_features # finally tf.segment_sum calculates sum of updated neighbors feature by start atom index pair_indices[:,0] - transformed_features = tf.sigmod(transformed_features_s) * self.softplus(transformed_features_g) + transformed_features = tf.sigmoid(transformed_features_s) * tf.nn.softplus(transformed_features_g) atom_features_aggregated = tf.math.segment_sum(transformed_features, pair_indices[:,0]) return atom_features_aggregated @@ -421,10 +418,10 @@ def concat_states(self, inputs: Sequence): DESCRIPTION. 
""" - atom_features, edges_sph_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs + atom_features, edges_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs # concat state_attrs with bond_updated and atom_features_aggregated - edges_features_sum = tf.math.segment_sum(edges_sph_features, bond_graph_indices) + edges_features_sum = tf.math.segment_sum(edges_features, bond_graph_indices) atom_features_sum = tf.math.segment_sum(atom_features, atom_graph_indices) state_attrs_concated = tf.concat([atom_features_sum, edges_features_sum, state_attrs], axis=-1) @@ -444,10 +441,10 @@ def call(self, inputs: Sequence) -> Sequence: DESCRIPTION. """ - atom_features, edges_sph_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs + atom_features, edges_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs atom_features_updated = atom_features - edges_features_updated = edges_sph_features + edges_features_updated = edges_features state_attrs_updated = state_attrs # Perform a number of steps of message passing diff --git a/crysnet/layers/graphormer.py b/crysnet/layers/graphormer.py new file mode 100644 index 0000000..bb4a74c --- /dev/null +++ b/crysnet/layers/graphormer.py @@ -0,0 +1,369 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Oct 13 14:47:13 2021 + +@author: huzongxiang +""" + +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import layers +from .graphnetworklayer import MessagePassing +from .crystalgraphlayer import GNConvolution + + +class ConvMultiHeadAttention(layers.Layer): + def __init__(self, num_heads=8, + steps=1, + kernel_initializer="glorot_uniform", + kernel_regularizer=None, + kernel_constraint=None, + use_bias=True, + bias_initializer="zeros", + bias_regularizer=None, + bias_constraint=None, + **kwargs): + super().__init__(**kwargs) + self.num_heads = num_heads + self.steps = steps + self.kernel_initializer = kernel_initializer + self.kernel_regularizer = kernel_regularizer + self.kernel_constraint = kernel_constraint + self.use_bias = use_bias + self.bias_initializer = bias_initializer + self.bias_regularizer = bias_regularizer + self.bias_constraint = bias_constraint + self.supports_masking = True + + + def build(self, input_shape): + self.atom_dim = input_shape[0][-1] + self.bond_dim = input_shape[1][-1] + self.state_dim = input_shape[2][-1] + + self.query_conv = GNConvolution(steps=self.steps, + kernel_initializer=self.kernel_initializer, + kernel_regularizer=self.kernel_regularizer, + kernel_constraint=self.kernel_constraint , + bias_initializer=self.bias_initializer, + bias_regularizer=self.bias_regularizer, + recurrent_regularizer=None, + bias_constraint=None, + activation=None, + name='query', + ) + + self.key_conv = GNConvolution(steps=self.steps, + kernel_initializer=self.kernel_initializer, + kernel_regularizer=self.kernel_regularizer, + kernel_constraint=self.kernel_constraint , + bias_initializer=self.bias_initializer, + bias_regularizer=self.bias_regularizer, + recurrent_regularizer=None, + bias_constraint=None, + activation=None, + name='key', + ) + + self.value_conv = GNConvolution(steps=self.steps, + kernel_initializer=self.kernel_initializer, + kernel_regularizer=self.kernel_regularizer, + kernel_constraint=self.kernel_constraint , + bias_initializer=self.bias_initializer, + bias_regularizer=self.bias_regularizer, + recurrent_regularizer=None, + bias_constraint=None, + activation=None, + name='value', + ) + + + 
def call(self, inputs, mask=None):
+        """
+        Parameters
+        ----------
+        inputs : Sequence
+            atom_features, bond_features, state_attrs, pair_indices,
+            atom_graph_indices and bond_graph_indices for the batch.
+        mask : TYPE, optional
+            DESCRIPTION. The default is None.
+
+        Returns
+        -------
+        atom_features_output : TYPE
+            attention-weighted atom features.
+
+        """
+        atom_features, bond_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs
+
+        # calculate Q=WqIn, K=WkIn, V=WvIn, In=Inputs
+        atom_features_q = self.query_conv([atom_features, bond_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices])
+        atom_features_k = self.key_conv([atom_features, bond_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices])
+        atom_features_v = self.value_conv([atom_features, bond_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices])
+
+        atom_features_k = tf.transpose(atom_features_k, perm=(1, 0))
+
+        # calculate Attention(Q, K, V) = softmax((QK^T)/sqrt(d_k))V
+        atom_features_scores = tf.matmul(atom_features_q, atom_features_k)/self.atom_dim**0.5
+        atom_features_attention = tf.nn.softmax(atom_features_scores)
+        atom_features_output = tf.matmul(atom_features_attention, atom_features_v)
+
+        return atom_features_output
+
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"num_heads": self.num_heads})
+        config.update({"steps": self.steps})
+        return config
+
+
+class ConvGraphormerEncoder(layers.Layer):
+    def __init__(self, num_heads=8,
+                 steps=1,
+                 embed_dim=16,
+                 edge_embed_dim=64,
+                 state_embed_dim=16,
+                 dense_dim=32,
+                 kernel_initializer="glorot_uniform",
+                 kernel_regularizer=None,
+                 kernel_constraint=None,
+                 use_bias=True,
+                 bias_initializer="zeros",
+                 bias_regularizer=None,
+                 bias_constraint=None,
+                 activation="relu",
+                 **kwargs):
+        super().__init__(**kwargs)
+
+        # store hyperparameters so get_config can serialize them
+        self.num_heads = num_heads
+        self.embed_dim = embed_dim
+        self.edge_embed_dim = edge_embed_dim
+        self.state_embed_dim = state_embed_dim
+        self.dense_dim = dense_dim
+
+        self.attention = ConvMultiHeadAttention(num_heads,
+                                                steps,
+                                                kernel_initializer=kernel_initializer,
+                                                kernel_regularizer=kernel_regularizer,
+                                                kernel_constraint=kernel_constraint,
+                                                use_bias=use_bias,
+                                                bias_initializer=bias_initializer,
+                                                bias_regularizer=bias_regularizer,
+                                                bias_constraint=bias_constraint,
+                                                )
+
+        self.dense_proj = keras.Sequential(
+            [layers.Dense(dense_dim, activation=activation, name='act_proj'), layers.Dense(embed_dim, name='proj'),]
+        )
+
+        self.layernorm_1 = layers.LayerNormalization()
+        self.layernorm_2 = layers.LayerNormalization()
+
+        self.supports_masking = True
+
+
+    def call(self, inputs, mask=None):
+        atom_features_output = self.attention(inputs, mask=mask)
+        proj_input = self.layernorm_1(inputs[0] + atom_features_output)
+        return self.layernorm_2(atom_features_output + self.dense_proj(proj_input))
+
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"num_heads": self.num_heads})
+        config.update({"embed_dim": self.embed_dim})
+        config.update({"edge_embed_dim": self.edge_embed_dim})
+        config.update({"state_embed_dim": self.state_embed_dim})
+        config.update({"dense_dim": self.dense_dim})
+        return config
+
+
+class MpnnMultiHeadAttention(layers.Layer):
+    def __init__(self, steps,
+                 kernel_initializer="glorot_uniform",
+                 kernel_regularizer=None,
+                 kernel_constraint=None,
+                 use_bias=True,
+                 bias_initializer="zeros",
+                 bias_regularizer=None,
+                 bias_constraint=None,
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.steps = steps
+        self.kernel_initializer = kernel_initializer
+        self.kernel_regularizer = kernel_regularizer
+        self.kernel_constraint = kernel_constraint
+        self.use_bias = 
use_bias + self.bias_initializer = bias_initializer + self.bias_regularizer = bias_regularizer + self.bias_constraint = bias_constraint + self.supports_masking = True + + + def build(self, input_shape): + self.atom_dim = input_shape[0][-1] + self.bond_dim = input_shape[1][-1] + self.state_dim = input_shape[2][-1] + + self.query_mpnn = MessagePassing(steps=self.steps, + kernel_initializer=self.kernel_initializer, + kernel_regularizer=self.kernel_regularizer, + kernel_constraint=self.kernel_constraint , + bias_initializer=self.bias_initializer, + bias_regularizer=self.bias_regularizer, + recurrent_regularizer=None, + bias_constraint=None, + activation=None, + name='query', + ) + + self.key_mpnn = MessagePassing(steps=self.steps, + kernel_initializer=self.kernel_initializer, + kernel_regularizer=self.kernel_regularizer, + kernel_constraint=self.kernel_constraint , + bias_initializer=self.bias_initializer, + bias_regularizer=self.bias_regularizer, + recurrent_regularizer=None, + bias_constraint=None, + activation=None, + name='key', + ) + + self.value_mpnn = MessagePassing(steps=self.steps, + kernel_initializer=self.kernel_initializer, + kernel_regularizer=self.kernel_regularizer, + kernel_constraint=self.kernel_constraint , + bias_initializer=self.bias_initializer, + bias_regularizer=self.bias_regularizer, + recurrent_regularizer=None, + bias_constraint=None, + activation=None, + name='value', + ) + + + def call(self, inputs, mask=None): + """ + Parameters + ---------- + query : TYPE + DESCRIPTION. + key : TYPE + DESCRIPTION. + value : TYPE + DESCRIPTION. + attn_bias : TYPE, optional + DESCRIPTION. The default is None. + attention_mask : TYPE, optional + DESCRIPTION. The default is None. + + Returns + ------- + attention_output : TYPE + DESCRIPTION. + + """ + atom_features, bond_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs + + # calculate Q=WqIn, K=WkIn, V=WvIn, In=Inputs + atom_features_q, bond_features_q, state_attrs_q = self.query_mpnn([atom_features, bond_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices]) + atom_features_k, bond_features_k, state_attrs_k = self.key_mpnn([atom_features, bond_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices]) + atom_features_v, bond_features_v, state_attrs_v = self.value_mpnn([atom_features, bond_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices]) + + # calculate Attention(Q, K, V) = softmax((QK^T)/sqrt(d_k))V + atom_features_k = tf.transpose(atom_features_k, perm=(1, 0)) + bond_features_k = tf.transpose(bond_features_k, perm=(1, 0)) + state_attrs_k = tf.transpose(state_attrs_k, perm=(1, 0)) + + atom_features_scores = tf.matmul(atom_features_q, atom_features_k)/self.atom_dim**0.5 + atom_features_attention = tf.nn.softmax(atom_features_scores) + atom_features_output = tf.matmul(atom_features_attention, atom_features_v) + + bond_features_scores = tf.matmul(bond_features_q, bond_features_k)/self.bond_dim**0.5 + bond_features_attention = tf.nn.softmax(bond_features_scores) + bond_features_output = tf.matmul(bond_features_attention, bond_features_v) + + state_attrs_scores = tf.matmul(state_attrs_q, state_attrs_k)/self.state_dim**0.5 + state_attrs_attention = tf.nn.softmax(state_attrs_scores) + state_attrs_output = tf.matmul(state_attrs_attention, state_attrs_v) + + return atom_features_output, bond_features_output, state_attrs_output + + + def get_config(self): + config = super().get_config() + config.update({"num_heads": self.num_heads}) + 
config.update({"steps": self.steps}) + return config + + +class GraphormerEncoder(layers.Layer): + def __init__(self, num_heads=8, + embed_dim=16, + edge_embed_dim=64, + state_embed_dim=16, + dense_dim=32, + kernel_initializer="glorot_uniform", + kernel_regularizer=None, + kernel_constraint=None, + use_bias=True, + bias_initializer="zeros", + bias_regularizer=None, + bias_constraint=None, + activation="relu", + **kwargs): + super().__init__(**kwargs) + + self.attention = MpnnMultiHeadAttention(num_heads, + kernel_initializer=kernel_initializer, + kernel_regularizer=kernel_regularizer, + kernel_constraint=kernel_constraint, + use_bias=use_bias, + bias_initializer=bias_initializer, + bias_regularizer=bias_regularizer, + bias_constraint=bias_constraint, + ) + + self.dense_proj = keras.Sequential( + [layers.Dense(dense_dim, activation=activation, name='act_proj'), layers.Dense(embed_dim, name='proj'),] + ) + + self.dense_proj_edge = keras.Sequential( + [layers.Dense(dense_dim, activation=activation, name='act_edge_proj'), layers.Dense(edge_embed_dim, name='edge_proj'),] + ) + + self.dense_proj_state = keras.Sequential( + [layers.Dense(dense_dim, activation=activation, name='act_state_proj'), layers.Dense(state_embed_dim, name='state_proj'),] + ) + + self.layernorm_1 = layers.LayerNormalization() + self.layernorm_2 = layers.LayerNormalization() + + self.layernorm_edge_1 = layers.LayerNormalization() + self.layernorm_edge_2 = layers.LayerNormalization() + + self.layernorm_state_1 = layers.LayerNormalization() + self.layernorm_state_2 = layers.LayerNormalization() + + self.supports_masking = True + + + def call(self, inputs, mask=None): + atom_features_output, bond_features_output, state_attrs_output = self.attention(inputs, mask=mask) + proj_input = self.layernorm_1(inputs[0] + atom_features_output) + proj_input_edge = self.layernorm_edge_1(inputs[1] + bond_features_output) + proj_input_state = self.layernorm_state_1(inputs[2] + state_attrs_output) + return (self.layernorm_2(proj_input + self.dense_proj(proj_input)), + self.layernorm_edge_2(proj_input_edge + self.dense_proj_edge(proj_input_edge)), + self.layernorm_state_2(proj_input_state + self.dense_proj_state(proj_input_state)) + ) + + + def get_config(self): + config = super().get_config() + config.update({"num_heads": self.num_heads}) + config.update({"embed_dim": self.embed_dim}) + config.update({"edge_embed_dim": self.edge_embed_dim}) + config.update({"state_embed_dim": self.state_embed_dim}) + config.update({"dense_dim": self.dense_dim}) + return config \ No newline at end of file diff --git a/crysnet/models/GraphModel_MpnnTransformer.py b/crysnet/models/GraphModel_MpnnTransformer.py index b58d850..f8fc796 100644 --- a/crysnet/models/GraphModel_MpnnTransformer.py +++ b/crysnet/models/GraphModel_MpnnTransformer.py @@ -15,7 +15,7 @@ from Readout import Set2Set -def GraphModel( +def GraphformerModel( bond_dim, atom_dim=16, num_atom=118, @@ -59,13 +59,7 @@ def GraphModel( pair_indices_per_graph = layers.Input((2), dtype="int32", name="pair_indices_per_graph") - edge_features = EdgeMessagePassing(units, - edge_steps, - kernel_regularizer=l2(reg0), - sph=spherical_harmonics - )([bond_features, local_env, pair_indices]) - - x_nodes_, x_edges_, x_state = MpnnTransformerEncoder()([atom_features_, edge_features, state_attrs_, pair_indices, atom_graph_indices, bond_graph_indices]) + x_nodes_, x_edges_, x_state = MpnnTransformerEncoder()([atom_features_, bond_features, state_attrs_, pair_indices, atom_graph_indices, bond_graph_indices]) x_nodes_b = 
PartitionPadding(batch_size)([x_nodes_, atom_graph_indices]) x_edges_b = PartitionPadding(batch_size)([x_edges_, bond_graph_indices]) diff --git a/crysnet/models/__init__.py b/crysnet/models/__init__.py index 0d07d9b..e30e2cc 100644 --- a/crysnet/models/__init__.py +++ b/crysnet/models/__init__.py @@ -5,4 +5,4 @@ @author: huzongxiang """ -from .gnnmodel import GNN \ No newline at end of file +from .gnnframework import GNN \ No newline at end of file diff --git a/crysnet/models/gnnframework.py b/crysnet/models/gnnframework.py new file mode 100644 index 0000000..fc33d4a --- /dev/null +++ b/crysnet/models/gnnframework.py @@ -0,0 +1,347 @@ +# -*- coding: utf-8 -*- +""" +Created on Mon Dec 20 15:15:31 2021 + +@author: huzongxiang +""" + + +import numpy as np +from pathlib import Path +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras.models import Model +from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping +from crysnet.callbacks.cosineannealing import WarmUpCosineDecayScheduler +import matplotlib.pyplot as plt +from scipy import interp +from sklearn.metrics import roc_curve +from sklearn.metrics import auc + + +ModulePath = Path(__file__).parent.absolute() + + +class GNN: + def __init__(self, + model: Model, + atom_dim=16, + bond_dim=32, + num_atom=118, + state_dim=16, + sp_dim=230, + batch_size=16, + regression=True, + ntarget=1, + multiclassification=None, + optimizer='Adam', + **kwargs, + ): + self.model = model + self.atom_dim = atom_dim + self.bond_dim = bond_dim + self.num_atom = num_atom + self.state_dim = state_dim + self.sp_dim = sp_dim + self.batch_size = batch_size + self.regression = regression + self.ntarget = ntarget + self.multiclassification = multiclassification + self.optimizer = optimizer + + self.gnn = model(atom_dim=atom_dim, + bond_dim=bond_dim, + num_atom=num_atom, + state_dim=state_dim, + sp_dim=sp_dim, + batch_size=batch_size, + regression=regression, + multiclassification=multiclassification, + **kwargs) + + + def __getattr__(self, attr): + return getattr(self.gnn, attr) + + + def train(self, train_data, valid_data=None, test_data=None, epochs=200, lr=1e-3, warm_up=True, warmrestart=None, load_weights=False, verbose=1, checkpoints=None, save_weights_only=True, workdir=None): + + gnn = self.gnn + if self.regression: + gnn.compile( + loss=keras.losses.MeanAbsoluteError(), + optimizer=self.optimizer, + metrics=[keras.metrics.MeanAbsoluteError(name='mae')], + ) + elif self.multiclassification: + gnn.compile( + loss=tf.keras.losses.CategoricalCrossentropy(), + optimizer=self.optimizer, + metrics=[tf.keras.metrics.AUC(name="AUC")], + ) + else: + gnn.compile( + loss=tf.keras.losses.BinaryCrossentropy(), + optimizer=self.optimizer, + metrics=[tf.keras.metrics.AUC(name="AUC")], + ) + + print(gnn.summary()) + keras.utils.plot_model(gnn, Path(workdir/"gnn_arch.png"), show_dtype=True, show_shapes=True) + + if load_weights: + print('load weights') + path = train_data.task_type + ".hdf5" + if load_weights == 'default': + best_checkpoint = Path(ModulePath/"model"/path) + elif load_weights == 'custom': + best_checkpoint = Path(workdir/"model"/path) + else: + raise ValueError('load_weights should be "default" or "custom"') + gnn.load_weights(best_checkpoint) + print(train_data.task_type) + Path(workdir/"model").mkdir(exist_ok=True) + Path(workdir/"model"/train_data.task_type).mkdir(exist_ok=True) + if checkpoints is None: + if self.regression: + filepath = 
Path(workdir/"model"/train_data.task_type/"gnn_{epoch:02d}-{val_mae:.3f}.hdf5") + checkpoint = ModelCheckpoint(filepath, monitor='val_mae', save_best_only=True, save_weights_only=save_weights_only, verbose=verbose, mode='min') + else: + filepath = Path(workdir/"model"/train_data.task_type/"gnn_{epoch:02d}-{val_AUC:.3f}.hdf5") + checkpoint = ModelCheckpoint(filepath, monitor='val_AUC', save_best_only=True, save_weights_only=save_weights_only, verbose=verbose, mode='max') + + earlystop = EarlyStopping(monitor='val_loss', patience=200, verbose=verbose, mode='min') + + if warm_up: + sample_count = train_data.data_size + warmup_epoch = 5 + train_per_epoch = sample_count / self.batch_size + warmup_steps = warmup_epoch * train_per_epoch + restart_epoches = warmrestart + + warm_up_lr = WarmUpCosineDecayScheduler(epochs=epochs, + restart_epoches=restart_epoches, + train_per_epoch=train_per_epoch, + learning_rate_base=lr, + warmup_learning_rate=2e-6, + warmup_steps=warmup_steps, + hold_base_rate_steps=5, + ) + + checkpoints = [checkpoint, warm_up_lr, earlystop] + else: + reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=50, verbose=1, min_lr=1e-6, mode='min') + checkpoints = [checkpoint, reduce_lr, earlystop] + + + if valid_data: + steps_per_train = int(np.ceil(train_data.data_size / self.batch_size)) + steps_per_val = int(np.ceil(valid_data.data_size / self.batch_size)) + else: + steps_per_train = None + steps_per_val = None + + print('gnn fit') + history = gnn.fit( + train_data, + validation_data=valid_data, + steps_per_epoch=steps_per_train, + validation_steps=steps_per_val, + epochs=epochs, + verbose=verbose, + callbacks=checkpoints, + ) + + Path(workdir/"results").mkdir(exist_ok=True) + if self.regression: + plot_train_regression(history, train_data.task_type, workdir) + if test_data: + plot_mae(gnn, test_data, workdir, name='test') + else: + plot_train(history, train_data.task_type, workdir) + if test_data: + if self.multiclassification: + plot_auc_multiclassification(gnn, test_data, self.multiclassification, workdir, name='test') + else: + plot_auc(gnn, test_data, workdir, name='test') + if warm_up: + total_steps = int(epochs * sample_count / self.batch_size) + plot_warm_up_lr(warm_up_lr, total_steps, lr, workdir) + + return gnn + + + def predict_datas(self, test_data, workdir=None, load_weights='default'): + print('load weights and predict...') + save_file = test_data.task_type + ".hdf5" + if load_weights: + best_checkpoint = Path(ModulePath/"model"/save_file) + else: + best_checkpoint = Path(workdir/"model"/save_file) + gnn = self.gnn() + gnn.load_weights(best_checkpoint) + Path(workdir/"results").mkdir(exist_ok=True) + if self.regression: + plot_mae(gnn, test_data, name='test') + else: + if self.multiclassification: + plot_auc_multiclassification(gnn, test_data, self.multiclassification, workdir, name='test') + else: + plot_auc(gnn, test_data, name='test') + + + def predict(self, data, workdir=None, load_weights='default'): + print('load weights and predict...') + save_file = data.task_type + ".hdf5" + if load_weights: + best_checkpoint = Path(ModulePath/"model"/save_file) + else: + best_checkpoint = Path(workdir/"model"/save_file) + gnn = self.gnn() + gnn.load_weights(best_checkpoint) + y_pred_keras = gnn.predict(data).ravel() + return y_pred_keras + + +def plot_train(history, name, path): + print('plot curve of training') + plt.figure(figsize=(10, 12)) + plt.subplot(2,1,1) + plt.plot(history.history["AUC"], label="train AUC") + plt.plot(history.history["val_AUC"], 
label="valid AUC") + plt.xlabel("Epochs", fontsize=16) + plt.ylabel("AUC", fontsize=16) + plt.legend(fontsize=16) + plt.subplot(2,1,2) + plt.plot(history.history["loss"], label="train loss") + plt.plot(history.history["val_loss"], label="valid loss") + plt.xlabel("Epochs", fontsize=16) + plt.ylabel("loss", fontsize=16) + plt.legend(fontsize=16) + save_path = name + "_train.png" + plt.savefig(path/"results"/save_path) + + +def plot_train_regression(history, name, path): + print('plot curve of training') + plt.figure(figsize=(10, 12)) + plt.subplot(2,1,1) + plt.plot(history.history["mae"], label="train mae") + plt.plot(history.history["val_mae"], label="valid mae") + plt.xlabel("Epochs", fontsize=16) + plt.ylabel("mae", fontsize=16) + plt.legend(fontsize=16) + plt.subplot(2,1,2) + plt.plot(history.history["loss"], label="train loss") + plt.plot(history.history["val_loss"], label="valid loss") + plt.xlabel("Epochs", fontsize=16) + plt.ylabel("loss", fontsize=16) + plt.legend(fontsize=16) + save_path = name + "_train.png" + plt.savefig(path/"results"/save_path) + + +def plot_auc(gnn, test_data, path, name='test'): + print('predict') + name = test_data.task_type + '_' + name + y_pred_keras = gnn.predict(test_data).ravel() + fpr_keras, tpr_keras, _ = roc_curve(test_data.labels, y_pred_keras) + auc_keras = auc(fpr_keras, tpr_keras) + plt.figure(figsize=(10, 6)) + plt.plot([0, 1], [0, 1], 'k--') + plt.plot(fpr_keras, tpr_keras, label='Keras (area = {:.3f})'.format(auc_keras)) + plt.xlabel('False positive rate') + plt.ylabel('True positive rate') + plt.title('ROC curve test') + plt.legend(loc='best') + save_path = name + "_predict" + ".png" + plt.savefig(Path(path/"results"/save_path)) + + +def plot_auc_multiclassification(gnn, test_data, n_classes, path, name='test'): + print('predict') + name = test_data.task_type + '_' + name + y_pred_keras = gnn.predict(test_data) + + fpr = dict() + tpr = dict() + roc_auc = dict() + for i in range(n_classes): + fpr[i], tpr[i], _ = roc_curve(test_data.labels[:, i], y_pred_keras[:, i]) + roc_auc[i] = auc(fpr[i], tpr[i]) + + # Compute micro-average ROC curve and ROC area + fpr["micro"], tpr["micro"], _ = roc_curve(np.array(test_data.labels)[:, i], y_pred_keras[:, i]) + roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) + + # Compute macro-average ROC curve and ROC area + + # First aggregate all false positive rates + all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)])) + + # Then interpolate all ROC curves at this points + mean_tpr = np.zeros_like(all_fpr) + for i in range(n_classes): + mean_tpr += interp(all_fpr, fpr[i], tpr[i]) + + # Finally average it and compute AUC + mean_tpr /= n_classes + + fpr["macro"] = all_fpr + tpr["macro"] = mean_tpr + roc_auc["macro"] = auc(fpr["macro"], tpr["macro"]) + + # Plot all ROC curves + plt.figure(figsize=(10, 6)) + plt.plot(fpr["micro"], tpr["micro"], + label='micro-average ROC curve (area = {0:0.2f})' + ''.format(roc_auc["micro"]), + color='deeppink', linestyle=':', linewidth=4) + + plt.plot(fpr["macro"], tpr["macro"], + label='macro-average ROC curve (area = {0:0.2f})' + ''.format(roc_auc["macro"]), + color='navy', linestyle=':', linewidth=4) + + for i in range(n_classes): + plt.plot(fpr[i], tpr[i], lw=2, + label='ROC curve of class {0} (area = {1:0.2f})' + ''.format(i, roc_auc[i])) + + plt.plot([0, 1], [0, 1], 'k--', lw=2) + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title('Some extension of Receiver operating characteristic 
+
+
+def plot_mae(gnn, test_data, path, name='test'):
+    print('predict')
+    name = test_data.task_type + '_' + name
+    y_pred_keras = gnn.predict(test_data).ravel()
+    plt.figure(figsize=(10, 6))
+    plt.scatter(test_data.labels, y_pred_keras)
+    plt.plot([-4, 4], [-4, 4], 'k--')
+    plt.xlim(-4, 4)
+    plt.ylim(-4, 4)
+    plt.xlabel("experimental", fontsize=16)
+    plt.ylabel("pred", fontsize=16)
+    plt.title('predicted')
+    save_path = name + "_predict" + ".png"
+    plt.savefig(Path(path/"results"/save_path))
+
+
+def plot_warm_up_lr(warm_up_lr, total_steps, lr, path):
+    plt.plot(warm_up_lr.learning_rates)
+    plt.xlabel('Step', fontsize=20)
+    plt.ylabel('lr', fontsize=20)
+    plt.axis([0, total_steps, 0, lr*1.1])
+    # plt.xticks(np.arange(0, epochs, 1))
+    plt.grid()
+    plt.title('Cosine decay with warmup', fontsize=20)
+    plt.savefig(Path(path/"results"/"cosine_decay.png"))
diff --git a/crysnet/models/graphmodel.py b/crysnet/models/graphmodel.py
index d9cb2b0..75d526d 100644
--- a/crysnet/models/graphmodel.py
+++ b/crysnet/models/graphmodel.py
@@ -8,12 +8,13 @@
 from tensorflow.keras.models import Model
 from tensorflow.keras.regularizers import l2
 from tensorflow.keras import layers
-from crysnet.layers import MessagePassing
+from crysnet.layers import MessagePassing, NewMessagePassing
 from crysnet.layers import SphericalBasisLayer, AzimuthLayer, ConcatLayer, EdgeMessagePassing
 from crysnet.layers import PartitionPadding, PartitionPaddingPair
 from crysnet.layers import EdgesAugmentedLayer, GraphTransformerEncoder
-from crysnet.layers import Set2Set
+from crysnet.layers import GraphormerEncoder, ConvGraphormerEncoder
 from crysnet.layers import CrystalGraphConvolution
+from crysnet.layers import Set2Set
 
 
 def GraphModel(
@@ -333,6 +334,104 @@ def MpnnBaseModel(
     return model
 
 
+def NewMpnnBaseModel(
+        bond_dim,
+        atom_dim=16,
+        num_atom=118,
+        state_dim=16,
+        sp_dim=230,
+        units=32,
+        edge_steps=1,
+        message_steps=1,
+        transform_steps=1,
+        num_attention_heads=8,
+        dense_units=32,
+        output_dim=32,
+        readout_units=128,
+        dropout=0.0,
+        reg0=0.0,
+        reg1=0.0,
+        reg2=0.0,
+        reg3=0.0,
+        reg_rec=0.0,
+        batch_size=16,
+        spherical_harmonics=False,
+        regression=False,
+        ntarget=1,
+        multiclassification=None,
+        ):
+    atom_features = layers.Input((), dtype="int32", name="atom_features_input")
+    atom_features_ = layers.Embedding(num_atom, atom_dim, dtype="float32", name="atom_features")(atom_features)
+    bond_features = layers.Input((bond_dim), dtype="float32", name="bond_features")
+    local_env = layers.Input((6), dtype="float32", name="local_env")
+    state_attrs = layers.Input((), dtype="int32", name="state_attrs_input")
+    state_attrs_ = layers.Embedding(sp_dim, state_dim, dtype="float32", name="state_attrs")(state_attrs)
+
+    pair_indices = layers.Input((2), dtype="int32", name="pair_indices")
+
+    atom_graph_indices = layers.Input(
+        (), dtype="int32", name="atom_graph_indices"
+    )
+
+    bond_graph_indices = layers.Input(
+        (), dtype="int32", name="bond_graph_indices"
+    )
+
+    pair_indices_per_graph = layers.Input((2), dtype="int32", name="pair_indices_per_graph")
+
+    x = NewMessagePassing(message_steps, kernel_regularizer=l2(reg0))(
+        [atom_features_, bond_features, state_attrs_, pair_indices,
+         atom_graph_indices, bond_graph_indices]
+    )
+
+    x_nodes_ = x[0]
+    x_edges_ = x[1]
+    x_state = x[2]
+
+    x_nodes_b = PartitionPadding(batch_size)([x_nodes_, atom_graph_indices])
+    x_edges_b = 
PartitionPadding(batch_size)([x_edges_, bond_graph_indices]) + + x_nodes = layers.BatchNormalization()(x_nodes_b) + x_edges = layers.BatchNormalization()(x_edges_b) + + x_nodes = layers.Masking(mask_value=0.)(x_nodes) + x_edges = layers.Masking(mask_value=0.)(x_edges) + + x_node = Set2Set(output_dim, kernel_regularizer=l2(reg2), recurrent_regularizer=l2(reg_rec))(x_nodes) + x_edge = Set2Set(output_dim, kernel_regularizer=l2(reg2), recurrent_regularizer=l2(reg_rec))(x_edges) + + x = layers.Concatenate(axis=-1, name='concat')([x_node, x_edge, x_state]) + + x = layers.Dense(readout_units, activation="relu", kernel_regularizer=l2(reg3), name='readout0')(x) + + if dropout: + x = layers.Dropout(dropout, name='dropout0')(x) + + x = layers.Dense(readout_units//2, activation="relu", kernel_regularizer=l2(reg3), name='readout1')(x) + + if dropout: + x = layers.Dropout(dropout, name='dropout1')(x) + + x = layers.Dense(readout_units//4, activation="relu", kernel_regularizer=l2(reg3), name='readout2')(x) + + if dropout: + x = layers.Dropout(dropout, name='dropout')(x) + + if regression: + x = layers.Dense(ntarget, name='final')(x) + elif multiclassification is not None: + x = layers.Dense(multiclassification, activation="softmax", name='final_softmax')(x) + else: + x = layers.Dense(1, activation="sigmoid", name='final')(x) + + model = Model( + inputs=[atom_features, bond_features, local_env, state_attrs, pair_indices, atom_graph_indices, + bond_graph_indices, pair_indices_per_graph], + outputs=[x], + ) + return model + + def DirectionalMpnnModel( bond_dim, atom_dim=16, @@ -819,8 +918,7 @@ def CgcnnModel( pair_indices_per_graph = layers.Input((2), dtype="int32", name="pair_indices_per_graph") x = CrystalGraphConvolution(message_steps, kernel_regularizer=l2(reg0))( - [atom_features_, bond_features, state_attrs_, pair_indices, - atom_graph_indices, bond_graph_indices] + [atom_features_, bond_features, pair_indices] ) x = PartitionPadding(batch_size)([x, atom_graph_indices]) @@ -851,6 +949,176 @@ def CgcnnModel( else: x = layers.Dense(1, activation="sigmoid", name='final')(x) + model = Model( + inputs=[atom_features, bond_features, local_env, state_attrs, pair_indices, atom_graph_indices, + bond_graph_indices, pair_indices_per_graph], + outputs=[x], + ) + return model + + +def GraphormerModel( + bond_dim, + atom_dim=16, + num_atom=118, + state_dim=16, + sp_dim=230, + units=32, + edge_steps=1, + message_steps=1, + transform_steps=1, + num_attention_heads=8, + dense_units=32, + output_dim=32, + readout_units=128, + dropout=0.0, + reg0=0.0, + reg1=0.0, + reg2=0.0, + reg3=0.0, + reg_rec=0.0, + batch_size=16, + spherical_harmonics=False, + regression=False, + multiclassification=None, + ): + atom_features = layers.Input((), dtype="int32", name="atom_features_input") + atom_features_ = layers.Embedding(num_atom, atom_dim, dtype="float32", name="atom_features")(atom_features) + bond_features = layers.Input((bond_dim), dtype="float32", name="bond_features") + local_env = layers.Input((6), dtype="float32", name="local_env") + state_attrs = layers.Input((), dtype="int32", name="state_attrs_input") + state_attrs_ = layers.Embedding(sp_dim, state_dim, dtype="float32", name="state_attrs")(state_attrs) + + pair_indices = layers.Input((2), dtype="int32", name="pair_indices") + + atom_graph_indices = layers.Input( + (), dtype="int32", name="atom_graph_indices" + ) + + bond_graph_indices = layers.Input( + (), dtype="int32", name="bond_graph_indices" + ) + + pair_indices_per_graph = layers.Input((2), dtype="int32", 
name="pair_indices_per_graph") + + x_nodes_, x_edges_, x_state = GraphormerEncoder()([atom_features_, bond_features, state_attrs_, pair_indices, atom_graph_indices, bond_graph_indices]) + + x_nodes_b = PartitionPadding(batch_size)([x_nodes_, atom_graph_indices]) + x_edges_b = PartitionPadding(batch_size)([x_edges_, bond_graph_indices]) + + x_nodes = layers.BatchNormalization()(x_nodes_b) + x_edges = layers.BatchNormalization()(x_edges_b) + + x_node = Set2Set(output_dim, kernel_regularizer=l2(reg2), recurrent_regularizer=l2(reg_rec))(x_nodes) + x_edge = Set2Set(output_dim, kernel_regularizer=l2(reg2), recurrent_regularizer=l2(reg_rec))(x_edges) + + x = layers.Concatenate(axis=-1, name='concat')([x_node, x_edge, x_state]) + + x = layers.Dense(readout_units, activation="relu", kernel_regularizer=l2(reg3), name='readout0')(x) + + if dropout: + x = layers.Dropout(dropout, name='dropout0')(x) + + x = layers.Dense(readout_units//2, activation="relu", kernel_regularizer=l2(reg3), name='readout1')(x) + + if dropout: + x = layers.Dropout(dropout, name='dropout1')(x) + + x = layers.Dense(readout_units//4, activation="relu", kernel_regularizer=l2(reg3), name='readout2')(x) + + if dropout: + x = layers.Dropout(dropout, name='dropout')(x) + + if regression: + x = layers.Dense(1, name='final')(x) + elif multiclassification is not None: + x = layers.Dense(multiclassification, activation="softmax", name='final_softmax')(x) + else: + x = layers.Dense(1, activation="sigmoid", name='final')(x) + + model = Model( + inputs=[atom_features, bond_features, local_env, state_attrs, pair_indices, atom_graph_indices, + bond_graph_indices, pair_indices_per_graph], + outputs=[x], + ) + return model + + +def ConvGraphormerModel( + bond_dim, + atom_dim=16, + num_atom=118, + state_dim=16, + sp_dim=230, + units=32, + edge_steps=1, + message_steps=1, + transform_steps=1, + num_attention_heads=8, + dense_units=32, + output_dim=32, + readout_units=128, + dropout=0.0, + reg0=0.0, + reg1=0.0, + reg2=0.0, + reg3=0.0, + reg_rec=0.0, + batch_size=16, + spherical_harmonics=False, + regression=False, + multiclassification=None, + ): + atom_features = layers.Input((), dtype="int32", name="atom_features_input") + atom_features_ = layers.Embedding(num_atom, atom_dim, dtype="float32", name="atom_features")(atom_features) + bond_features = layers.Input((bond_dim), dtype="float32", name="bond_features") + local_env = layers.Input((6), dtype="float32", name="local_env") + state_attrs = layers.Input((), dtype="int32", name="state_attrs_input") + state_attrs_ = layers.Embedding(sp_dim, state_dim, dtype="float32", name="state_attrs")(state_attrs) + + pair_indices = layers.Input((2), dtype="int32", name="pair_indices") + + atom_graph_indices = layers.Input( + (), dtype="int32", name="atom_graph_indices" + ) + + bond_graph_indices = layers.Input( + (), dtype="int32", name="bond_graph_indices" + ) + + pair_indices_per_graph = layers.Input((2), dtype="int32", name="pair_indices_per_graph") + + x_nodes_ = ConvGraphormerEncoder()([atom_features_, bond_features, + state_attrs_, pair_indices, atom_graph_indices, bond_graph_indices]) + + x_nodes_b = PartitionPadding(batch_size)([x_nodes_, atom_graph_indices]) + + x_nodes = layers.BatchNormalization()(x_nodes_b) + + x = Set2Set(output_dim, kernel_regularizer=l2(reg2), recurrent_regularizer=l2(reg_rec))(x_nodes) + + x = layers.Dense(readout_units, activation="relu", kernel_regularizer=l2(reg3), name='readout0')(x) + + if dropout: + x = layers.Dropout(dropout, name='dropout0')(x) + + x = 
layers.Dense(readout_units//2, activation="relu", kernel_regularizer=l2(reg3), name='readout1')(x) + + if dropout: + x = layers.Dropout(dropout, name='dropout1')(x) + + x = layers.Dense(readout_units//4, activation="relu", kernel_regularizer=l2(reg3), name='readout2')(x) + + if dropout: + x = layers.Dropout(dropout, name='dropout')(x) + + if regression: + x = layers.Dense(1, name='final')(x) + elif multiclassification is not None: + x = layers.Dense(multiclassification, activation="softmax", name='final_softmax')(x) + else: + x = layers.Dense(1, activation="sigmoid", name='final')(x) + model = Model( inputs=[atom_features, bond_features, local_env, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices, pair_indices_per_graph],
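Both CrystalGraphConvolution and GNConvolution in this patch aggregate neighbor messages with the CGCNN-style gate sigmoid(z_s) * softplus(z_g), followed by a segment sum over each atom's bonds and a softplus-activated residual update. A self-contained sketch of that update on a toy graph (illustrative only: random weights stand in for the trained kernel_s/kernel_g, biases are omitted, and pair_indices are sorted by source atom as tf.math.segment_sum requires):

```python
import tensorflow as tf

# Toy graph: 3 atoms with 8-d features, 4 directed bonds with 8-d features.
atom = tf.random.normal((3, 8))
bond = tf.random.normal((4, 8))
pair_indices = tf.constant([[0, 1], [0, 2], [1, 0], [2, 0]])  # sorted by source

# Per bond, concatenate source atom, target atom and bond features.
gathered = tf.gather(atom, pair_indices)                        # (4, 2, 8)
z = tf.concat([gathered[:, 0], gathered[:, 1], bond], axis=-1)  # (4, 24)

# Gate: the sigmoid branch selects, the softplus branch transforms.
w_s, w_g = tf.random.normal((24, 8)), tf.random.normal((24, 8))
messages = tf.sigmoid(z @ w_s) * tf.nn.softplus(z @ w_g)        # (4, 8)

# Sum each atom's incoming messages, add the residual, apply softplus.
aggregated = tf.math.segment_sum(messages, pair_indices[:, 0])  # (3, 8)
atom_updated = tf.nn.softplus(atom + aggregated)
```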