Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

eh features #1

Merged
merged 2 commits into from
Oct 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
![Tests](https://github.com/bayer-science-for-a-better-life/molnet-geometric-lightning/actions/workflows/python-package-conda.yml/badge.svg)
![Tests](https://github.com/bayer-science-for-a-better-life/eh-benchmark/actions/workflows/python-package-conda.yml/badge.svg)

# molnet-geometric-lightning
# eh-benchmark

This is a package for benchmarking the [MoleculeNet datasets](https://pubs.rsc.org/en/content/articlelanding/2018/sc/c7sc02664a) present in the [Open Graph Benchmark](https://ogb.stanford.edu/) on different [graph convolutional neural network](https://distill.pub/2021/gnn-intro/) architectures.
The neural networks are implemented using [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric) and [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning).
This package benchmarks how adding the results of simple quantum-mechanical (QM) calculations as features to the molecular graphs used for training graph neural networks affects training performance.
Large parts of this code are borrowed from PyTorch Geometric and OGB examples, therefore this package is available under the same license (MIT).

## Why?

The OGB library offers premade data objects compatible with PyTorch Geometric.
While convenient, this makes it difficult to implement different featurizations.
Furthermore, the PyTorch Lightning framework makes for easier-to-maintain code, with a nice command line interface and Tensorboard logging built-in.
The primary advantage of using a trained GNN for molecular property prediction is speed.
This precludes using expensive QM calculations to generate features.
However, a cheap calculation, such as extended Hückel, could be used.

## Installation

Expand All @@ -27,7 +26,7 @@ The following will train 5 models on the `bbbp` dataset with the default paramet
The models will be stored in `example_model/`, and the data will be downloaded to `datasets/`.

```shell script
python molnet_geometric_lightning/train.py --default_root_dir=example_model/ --dataset_name=bbbp --dataset_root=datasets/ --gpus=1 --max_epochs=100 --n_runs=5
python eh_benchmark/train.py --default_root_dir=example_model/ --dataset_name=bbbp --dataset_root=datasets/ --gpus=1 --max_epochs=100 --n_runs=5
```

Replace the directories to your liking, and `bbbp` with any name from MoleculeNet, for example `tox21`, `muv`, `hiv`, `pcba`, `bace`, `esol`.
Expand Down
File renamed without changes.
22 changes: 14 additions & 8 deletions molnet_geometric_lightning/model.py → eh_benchmark/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torch_geometric.utils import degree
from torch_geometric.nn.functional import bro, gini

from .molecule_net import MoleculeNetEH
from .mol_encoder import AtomEncoder, BondEncoder

cls_criterion = BCEWithLogitsLoss()
Expand All @@ -22,13 +23,17 @@ def __init__(
self,
root,
name,
eh_feat=False,
batch_size=32,
):
super().__init__()
self.name = name
self.root = root
self.batch_size = batch_size
self.dataset_class = MoleculeNet
if eh_feat:
self.dataset_class = MoleculeNetEH
else:
self.dataset_class = MoleculeNet
# only need the split idx from the ogb dataset
ogb_name = f'ogbg-mol{name}'
ogb_dataset = PygGraphPropPredDataset(name=ogb_name, root='/tmp/ogb')
Expand Down Expand Up @@ -68,14 +73,14 @@ def test_dataloader(self):


class GINConv(MessagePassing):
def __init__(self, emb_dim):
def __init__(self, emb_dim, eh_feat=False):
'''
emb_dim (int): node embedding dimensionality
'''
super(GINConv, self).__init__(aggr="add")
self.mlp = Sequential(Linear(emb_dim, 2*emb_dim), BatchNorm1d(2*emb_dim), ReLU(), Linear(2*emb_dim, emb_dim))
self.eps = Parameter(Tensor([0]))
self.bond_encoder = BondEncoder(emb_dim=emb_dim)
self.bond_encoder = BondEncoder(emb_dim=emb_dim, eh_feat=eh_feat)

def forward(self, x, edge_index, edge_attr):
edge_embedding = self.bond_encoder(edge_attr)
Expand All @@ -91,12 +96,12 @@ def update(self, aggr_out):


class GCNConv(MessagePassing):
def __init__(self, emb_dim):
def __init__(self, emb_dim, eh_feat=False):
super(GCNConv, self).__init__(aggr='add')

self.linear = Linear(emb_dim, emb_dim)
self.root_emb = Embedding(1, emb_dim)
self.bond_encoder = BondEncoder(emb_dim=emb_dim)
self.bond_encoder = BondEncoder(emb_dim=emb_dim, eh_feat=eh_feat)

def forward(self, x, edge_index, edge_attr):
x = self.linear(x)
Expand Down Expand Up @@ -160,7 +165,7 @@ def add_model_specific_args(parent_parser):
)
return parent_parser

def __init__(self, task_type, num_tasks, evaluator, conf):
def __init__(self, task_type, eh_feat, num_tasks, evaluator, conf):
super().__init__()
self.save_hyperparameters(conf)
self.n_conv_layers = self.hparams.n_conv_layers
Expand All @@ -174,11 +179,12 @@ def __init__(self, task_type, num_tasks, evaluator, conf):
self.BRO = self.hparams.BRO
self.gini = self.hparams.gini

self.eh_feat = eh_feat
self.task_type = task_type
self.num_tasks = num_tasks
self.evaluator = evaluator

self.atom_encoder = AtomEncoder(emb_dim=self.embedding_dim)
self.atom_encoder = AtomEncoder(emb_dim=self.embedding_dim, eh_feat=self.eh_feat)

if self.virtual_node:
self.virtualnode_embedding = Embedding(1, self.embedding_dim)
Expand All @@ -204,7 +210,7 @@ def __init__(self, task_type, num_tasks, evaluator, conf):
self.conv = GCNConv

for _ in range(self.n_conv_layers):
self.convs.append(self.conv(self.embedding_dim))
self.convs.append(self.conv(self.embedding_dim, self.eh_feat))
self.batch_norms.append(BatchNorm1d(self.embedding_dim))

self.graph_pred_linear = Linear(self.embedding_dim, self.num_tasks)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,29 @@ def get_bond_feature_dims():

class AtomEncoder(torch.nn.Module):

def __init__(self, emb_dim):
def __init__(self, emb_dim, eh_feat=False):
super(AtomEncoder, self).__init__()

full_atom_feature_dims = get_atom_feature_dims()

self.atom_embedding_list = torch.nn.ModuleList()

self.eh_feat = eh_feat
if eh_feat:
self.eh_embedding = torch.nn.Linear(2, emb_dim)
torch.nn.init.xavier_uniform_(self.eh_embedding.weight.data)
for i, dim in enumerate(full_atom_feature_dims):
emb = torch.nn.Embedding(dim, emb_dim)
torch.nn.init.xavier_uniform_(emb.weight.data)
self.atom_embedding_list.append(emb)

def forward(self, x):
x_embedding = 0
if self.eh_feat:
for i in range(x.shape[1] - 2):
x_embedding += self.atom_embedding_list[i](x[:, i].long())
x_embedding += self.eh_embedding(x[:, -2:])
return x_embedding

for i in range(x.shape[1]):
x_embedding += self.atom_embedding_list[i](x[:, i])

Expand All @@ -52,12 +61,16 @@ def forward(self, x):

class BondEncoder(torch.nn.Module):

def __init__(self, emb_dim):
def __init__(self, emb_dim, eh_feat=False):
super(BondEncoder, self).__init__()

full_bond_feature_dims = get_bond_feature_dims()

self.bond_embedding_list = torch.nn.ModuleList()
self.eh_feat = eh_feat
if eh_feat:
self.eh_embedding = torch.nn.Linear(1, emb_dim)
torch.nn.init.xavier_uniform_(self.eh_embedding.weight.data)

for i, dim in enumerate(full_bond_feature_dims):
emb = torch.nn.Embedding(dim, emb_dim)
Expand All @@ -66,6 +79,11 @@ def __init__(self, emb_dim):

def forward(self, edge_attr):
bond_embedding = 0
if self.eh_feat:
for i in range(edge_attr.shape[1] - 1):
bond_embedding += self.bond_embedding_list[i](edge_attr[:, i].long())
bond_embedding += self.eh_embedding(edge_attr[:, -1:])
return bond_embedding
for i in range(edge_attr.shape[1]):
bond_embedding += self.bond_embedding_list[i](edge_attr[:, i])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import os.path as osp
import re

import numpy as np
from rdkit.Chem import AllChem
from rdkit.Chem import rdEHTTools
from rdkit import Chem
import torch
from torch_geometric.data import (InMemoryDataset, Data, download_url,
extract_gz)
Expand Down Expand Up @@ -58,7 +62,71 @@
}


class MoleculeNetHBonds(InMemoryDataset):
def set_overlap_populations(m, rop, n_atoms):
    """Copy Mulliken populations from ``rop`` onto the bonds and atoms of ``m``.

    ``rop`` is indexed as a packed lower triangle of a symmetric matrix:
    element ``(row, col)`` with ``row >= col`` lives at
    ``row * (row + 1) // 2 + col``.

    Only atoms with index below ``n_atoms`` (and bonds whose both ends are
    below ``n_atoms``) are annotated.

    Args:
        m: molecule-like object exposing ``GetBonds()`` and ``GetAtoms()``.
        rop: indexable sequence of packed population values.
        n_atoms: number of leading atoms to annotate.

    Side effects:
        Sets the double property ``"MullikenOverlapPopulation"`` on each
        qualifying bond and ``"MullikenPopulation"`` on each qualifying atom.
    """

    def _packed_index(row, col):
        # position of (row, col), row >= col, in the packed lower triangle
        return (row * (row + 1)) // 2 + col

    for bond in m.GetBonds():
        begin_atom = bond.GetBeginAtom()
        end_atom = bond.GetEndAtom()
        begin_idx = begin_atom.GetIdx()
        end_idx = end_atom.GetIdx()
        if begin_idx >= n_atoms or end_idx >= n_atoms:
            continue
        # the matrix is symmetric, so always address it as (larger, smaller)
        row, col = max(begin_idx, end_idx), min(begin_idx, end_idx)
        bond.SetDoubleProp("MullikenOverlapPopulation", rop[_packed_index(row, col)])

    for atom in m.GetAtoms():
        atom_idx = atom.GetIdx()
        if atom_idx >= n_atoms:
            # stop at the first atom past the requested range
            break
        # diagonal element holds the atom's own population
        atom.SetDoubleProp("MullikenPopulation", rop[_packed_index(atom_idx, atom_idx)])


def get_eH_features(mol, random_seed=-1):
    """Annotate a hydrogen-added copy of ``mol`` with extended Hueckel results.

    Conformers are embedded and MMFF-optimized, the lowest-energy one is fed
    to ``rdEHTTools.RunMol``, and the resulting populations and charges are
    stored as double properties on the atoms/bonds of the returned molecule
    (``"eHCharge"`` here; the Mulliken properties via
    ``set_overlap_populations``).

    If any step raises ``ValueError`` (including an empty conformer list,
    for which ``np.argmin`` raises) or the calculation reports failure,
    fixed dummy values are used instead so the caller always gets features.

    Args:
        mol: RDKit molecule without explicit hydrogens added by this function.
        random_seed: forwarded to ``EmbedMultipleConfs`` as ``randomSeed``.
            The default ``-1`` keeps the original (unseeded) behaviour; pass
            a fixed value for reproducible features.

    Returns:
        Tuple ``(mh, homo, lumo, n_atoms)``: the hydrogen-added molecule
        carrying the properties, the HOMO and LUMO orbital energies, and the
        atom count of the input ``mol``.
    """
    # add hydrogens; only the first n_atoms atoms of mh are annotated below
    mh = Chem.AddHs(mol)
    n_atoms = mol.GetNumAtoms()

    try:
        AllChem.EmbedMultipleConfs(
            mh,
            numConfs=10,
            useRandomCoords=True,
            maxAttempts=100,
            randomSeed=random_seed,
        )
        opt_results = AllChem.MMFFOptimizeMoleculeConfs(mh)
        # raises ValueError when no conformer could be embedded (empty list)
        min_energy_conf = np.argmin([conf_result[1] for conf_result in opt_results])
        # this can throw a ValueError, which is handled by the fallback below
        passed, res = rdEHTTools.RunMol(
            mol=mh,
            confId=int(min_energy_conf),
            keepOverlapAndHamiltonianMatrices=True,
        )
        # RunMol reports success as a boolean, so `passed < 0` never fired
        if not passed:
            raise ValueError("extended Hueckel calculation failed")
        rop = res.GetReducedOverlapPopulationMatrix()
        charges = res.GetAtomicCharges()
        orbital_E = res.GetOrbitalEnergies()
        homo = orbital_E[res.numElectrons // 2 - 1]
        lumo = orbital_E[res.numElectrons // 2]
    except ValueError:
        # place dummy values
        rop = np.ones(mh.GetNumAtoms()**2)
        charges = 3. * np.ones(mh.GetNumAtoms())
        homo = -10.
        lumo = -2.

    # set bond and atom electron populations
    set_overlap_populations(mh, rop, n_atoms)

    # set atomic charges on the atoms that were present in the input molecule
    for atom in mh.GetAtoms():
        idx = atom.GetIdx()
        if idx >= n_atoms:
            break
        atom.SetDoubleProp("eHCharge", float(charges[idx]))

    return mh, homo, lumo, n_atoms


class MoleculeNetEH(InMemoryDataset):
r"""The `MoleculeNet <http://moleculenet.ai/datasets-1>`_ benchmark
collection from the `"MoleculeNet: A Benchmark for Molecular Machine
Learning" <https://arxiv.org/abs/1703.00564>`_ paper, containing datasets
Expand Down Expand Up @@ -115,7 +183,7 @@ def __init__(self, root, name, transform=None, pre_transform=None,
pre_filter=None):
self.name = name.lower()
assert self.name in self.names.keys()
super(MoleculeNetHBonds, self).__init__(root, transform, pre_transform, pre_filter)
super(MoleculeNetEH, self).__init__(root, transform, pre_transform, pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0])

@property
Expand All @@ -141,9 +209,7 @@ def download(self):
extract_gz(path, self.raw_dir)
os.unlink(path)

def process(self):
from rdkit import Chem

def process(self): # noqa
with open(self.raw_paths[0], 'r') as f:
dataset = f.read().split('\n')[1:-1]
dataset = [x for x in dataset if len(x) > 0] # Filter empty lines.
Expand All @@ -164,8 +230,12 @@ def process(self):
if mol is None:
continue

mh, eH_homo, eH_lumo, n_atoms = get_eH_features(mol)

xs = []
for atom in mol.GetAtoms():
for _i, atom in enumerate(mh.GetAtoms()):
if _i >= n_atoms:
break
x = []
x.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
x.append(x_map['chirality'].index(str(atom.GetChiralTag())))
Expand All @@ -178,26 +248,35 @@ def process(self):
str(atom.GetHybridization())))
x.append(x_map['is_aromatic'].index(atom.GetIsAromatic()))
x.append(x_map['is_in_ring'].index(atom.IsInRing()))
# add eH features
x.append(atom.GetDoubleProp("eHCharge"))
x.append(atom.GetDoubleProp("MullikenPopulation"))
xs.append(x)

x = torch.tensor(xs, dtype=torch.long).view(-1, 9)
x = torch.tensor(xs).view(-1, 11)

edge_indices, edge_attrs = [], []
for bond in mol.GetBonds():
for bond in mh.GetBonds():
if bond.GetBeginAtomIdx() >= n_atoms:
continue
if bond.GetEndAtomIdx() >= n_atoms:
continue
i = bond.GetBeginAtomIdx()
j = bond.GetEndAtomIdx()

e = []
e.append(e_map['bond_type'].index(str(bond.GetBondType())))
e.append(e_map['stereo'].index(str(bond.GetStereo())))
e.append(e_map['is_conjugated'].index(bond.GetIsConjugated()))
# add eH features
e.append(bond.GetDoubleProp('MullikenOverlapPopulation'))

edge_indices += [[i, j], [j, i]]
edge_attrs += [e, e]

edge_index = torch.tensor(edge_indices)
edge_index = edge_index.t().to(torch.long).view(2, -1)
edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)
edge_attr = torch.tensor(edge_attrs).view(-1, 4)

# Sort indices.
if edge_index.numel() > 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pytorch_lightning.callbacks import EarlyStopping
from ogb.graphproppred import Evaluator

from molnet_geometric_lightning.model import Net, MolData
from eh_benchmark.model import Net, MolData


def parse_args(args):
Expand All @@ -19,6 +19,9 @@ def parse_args(args):
parser.add_argument(
'--dataset_root', type=str, default='data'
)
parser.add_argument(
'--eh_feat', action='store_true', default=False,
)
parser.add_argument(
'--n_runs', type=int, default=1,
)
Expand All @@ -36,6 +39,7 @@ def train(args):
mol_data = MolData(
root=args.dataset_root,
name=args.dataset_name,
eh_feat=args.eh_feat,
batch_size=args.batch_size,
)

Expand All @@ -57,6 +61,7 @@ def train(args):
trainer.checkpoint_callback.save_top_k = 1
model = Net(
task_type=mol_data.task_type,
eh_feat=args.eh_feat,
num_tasks=mol_data.num_tasks,
evaluator=evaluator,
conf=args,
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from setuptools import setup, find_packages

setup(
name='molnet-geometric-lightning',
name='eh-benchmark',
version='0.1.0',
packages=find_packages(include=['molnet-geometric-lightning', 'molnet-geometric-lightning.*'])
packages=find_packages(include=['eh-benchmark', 'eh-benchmark.*'])
)
Loading