Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
huzongxiang authored Jan 13, 2022
1 parent 3c010e9 commit e4954e3
Show file tree
Hide file tree
Showing 10 changed files with 1,163 additions and 67 deletions.
2 changes: 1 addition & 1 deletion crysnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
@author: huzongxiang
"""
__version__ = "0.1.2"
__version__ = "0.1.5"
2 changes: 1 addition & 1 deletion crysnet/activations/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ def shifted_softplus(x):
"""

return tf.nn.softplus(x) - tf.log(2.0)
return tf.nn.softplus(x) - tf.math.log(2.0)
5 changes: 3 additions & 2 deletions crysnet/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
@author: huzongxiang
"""

from .graphnetworklayer import MessagePassing
from .graphnetworklayer import MessagePassing, NewMessagePassing
from .edgenetworklayer import SphericalBasisLayer, AzimuthLayer, ConcatLayer, EdgeAggragate, EdgeMessagePassing
from .crystalgraphlayer import CrystalGraphConvolution
from .crystalgraphlayer import CrystalGraphConvolution, GNConvolution
from .graphtransformer import EdgesAugmentedLayer, GraphTransformerEncoder
from .graphormer import GraphormerEncoder, ConvGraphormerEncoder
from .partitionpaddinglayer import PartitionPadding, PartitionPaddingPair
from .readout import Set2Set
176 changes: 148 additions & 28 deletions crysnet/layers/crystalgraphlayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import activations, initializers, regularizers, constraints
from crysnet.activations import shifted_softplus


class CrystalGraphConvolution(layers.Layer):
Expand All @@ -18,7 +17,7 @@ class CrystalGraphConvolution(layers.Layer):
Xie et al. PHYSICAL REVIEW LETTERS 120, 145301 (2018)
"""
def __init__(self,
steps,
steps=1,
kernel_initializer="glorot_uniform",
kernel_regularizer=None,
kernel_constraint=None,
Expand All @@ -45,24 +44,6 @@ def __init__(self,
def build(self, input_shape):
self.atom_dim = input_shape[0][-1]
self.edge_dim = input_shape[1][-1]
self.state_dim = input_shape[2][-1]

self.update_nodes = layers.GRUCell(self.atom_dim,
kernel_regularizer=self.recurrent_regularizer,
recurrent_regularizer=self.recurrent_regularizer,
name='update_nodes'
)

self.update_edges = layers.GRUCell(self.edge_dim,
kernel_regularizer=self.recurrent_regularizer,
recurrent_regularizer=self.recurrent_regularizer,
name='update_edges'
)

self.update_states = layers.GRUCell(self.state_dim,
kernel_regularizer=self.recurrent_regularizer,
recurrent_regularizer=self.recurrent_regularizer,
name='update_states')

with tf.name_scope("nodes_aggregate"):
# weight for updating atom_features by bond_features
Expand Down Expand Up @@ -101,8 +82,6 @@ def build(self, input_shape):
constraint=self.bias_constraint,
name='nodes_bias_g',
)

self.softplus = shifted_softplus

self.built = True

Expand All @@ -120,11 +99,11 @@ def aggregate_nodes(self, inputs: Sequence):
DESCRIPTION.
"""
atom_features, edges_sph_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs
atom_features, edges_features, pair_indices = inputs

# concat state_attrs with atom_features to get merged atom_merge_state_features
atom_features_gather = tf.gather(atom_features, pair_indices)
atom_merge_features = tf.concat([atom_features_gather[:,0], atom_features_gather[:,1], edges_sph_features], axis=-1)
atom_merge_features = tf.concat([atom_features_gather[:,0], atom_features_gather[:,1], edges_features], axis=-1)

transformed_features_s = tf.matmul(atom_merge_features, self.kernel_s) + self.bias_s
transformed_features_g = tf.matmul(atom_merge_features, self.kernel_g) + self.bias_g
Expand All @@ -151,14 +130,155 @@ def call(self, inputs: Sequence) -> Sequence:
DESCRIPTION.
"""
atom_features, edges_sph_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs
atom_features, edges_features, pair_indices = inputs

atom_features_updated = atom_features

for i in range(self.steps):
atom_features_updated = self.aggregate_nodes([atom_features_updated, edges_features, pair_indices])

return atom_features_updated


def get_config(self):
config = super().get_config()
config.update({"steps": self.steps})
return config


class GNConvolution(layers.Layer):
"""
The CGCNN graph implementation as described in the paper
Xie et al. PHYSICAL REVIEW LETTERS 120, 145301 (2018)
"""
def __init__(self,
steps=1,
kernel_initializer="glorot_uniform",
kernel_regularizer=None,
kernel_constraint=None,
bias_initializer="zeros",
bias_regularizer=None,
recurrent_regularizer=None,
bias_constraint=None,
activation=None,
**kwargs):
super().__init__(**kwargs)
self.steps = steps
self.activation = activations.get(activation)
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)

self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.recurrent_regularizer = regularizers.get(recurrent_regularizer)

self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)


def build(self, input_shape):
self.atom_dim = input_shape[0][-1]
self.edge_dim = input_shape[1][-1]
self.state_dim = input_shape[2][-1]


# weight for updating atom_features by bond_features
self.kernel_s = self.add_weight(
shape=(self.atom_dim * 2 + self.edge_dim + self.state_dim,
self.atom_dim),
trainable=True,
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
name='nodes_kernel_s',
)
self.bias_s = self.add_weight(
shape=(self.atom_dim,),
trainable=True,
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
name='nodes_bias_s',
)

self.kernel_g = self.add_weight(
shape=(self.atom_dim * 2 + self.edge_dim + self.state_dim,
self.atom_dim),
trainable=True,
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
name='nodes_kernel_g',
)
self.bias_g = self.add_weight(
shape=(self.atom_dim,),
trainable=True,
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
name='nodes_bias_g',
)

self.built = True


def aggregate_nodes(self, inputs: Sequence):
"""
Parameters
----------
inputs : Sequence
DESCRIPTION.
Returns
-------
atom_features_aggregated : TYPE
DESCRIPTION.
"""
atom_features, edges_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs

# each bond in pair_indices, concatenate thier atom features by atom indexes
# then concatenate atom features with bond features
atom_features_gather = tf.gather(atom_features, pair_indices)
edges_merge_atom_features = tf.concat([atom_features_gather[:,0], atom_features_gather[:,1]], axis=-1)

# repeat state attributes by bond_graph_indices, then concatenate to bond_merge_atom_features
state_attrs_repeat = tf.gather(state_attrs, bond_graph_indices)
edges_features_concated = tf.concat([edges_merge_atom_features, state_attrs_repeat, edges_features], axis=-1)

transformed_features_s = tf.matmul(edges_features_concated, self.kernel_s) + self.bias_s
transformed_features_g = tf.matmul(edges_features_concated, self.kernel_g) + self.bias_g

transformed_features = tf.sigmoid(transformed_features_s) * tf.nn.softplus(transformed_features_g)
atom_features_aggregated = tf.math.segment_sum(transformed_features, pair_indices[:,0])

atom_features_updated = atom_features + atom_features_aggregated
atom_features_updated = tf.nn.softplus(atom_features_updated)

return atom_features_updated


def call(self, inputs: Sequence) -> Sequence:
"""
Parameters
----------
inputs : Sequence
DESCRIPTION.
Returns
-------
Sequence
DESCRIPTION.
"""
atom_features, bond_features, state_attrs, pair_indices, atom_graph_indices, bond_graph_indices = inputs

atom_features_updated = atom_features

for i in range(self.steps):
atom_features_updated = self.aggregate_nodes([atom_features_updated, edges_sph_features,
state_attrs, pair_indices,
atom_graph_indices, bond_graph_indices])
atom_features_updated = self.aggregate_nodes([atom_features_updated, bond_features,
state_attrs, pair_indices,
atom_graph_indices, bond_graph_indices])

return atom_features_updated

Expand Down
Loading

0 comments on commit e4954e3

Please sign in to comment.