[Example] Pre-train GNNs (#110)
* [Example] Pretrain_gnns

* Change pretrain_gnns location

* Create README.md

* Fix README.md

* Update README.md

* Update pretrain_masking.py

* Update pretrain_supervised.py

* Update pretrain_masking.py

* Update

* Update

* Fix

* Update

* Improve performance

* Fix

* Update

* Update pretrain_supervised.py

* Support custom datasets

* Update README.md

* Fix typo

* Update gin_predictor.py to support skipping readout

* Update pretrain_masking.py model

* Update

* Merge datasets

* Update

* Create classification.py

* Bugfix

* Bugfix

* Fix gin_predictor.py

Support skipping the top linear layer

* Update gin_predictor.py

Support skipping the readout and top linear layer

* Update utils.py

* Update pretrain_masking.py

* Update

* Update .gitignore

* Update utils.py

* Update pretrain_supervised.py

* Update pretrain_masking.py

* Update classification.py

* Update README.md

Update experiment results

* Update README.md

* Fix

* Update

* Update

* Update pretrain_masking.py

* Update pretrain_supervised.py

* Update README.md

* Update

* Update gin_predictor.py

* Update gin_predictor.py

* Delete model.py

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

Co-authored-by: Mufei Li <[email protected]>
wenx00 and mufeili authored Jan 10, 2021
1 parent 53ebccf commit 69c11a9
Showing 7 changed files with 831 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
# IDE
.idea
.vscode

# Byte-compiled
__pycache__/
56 changes: 56 additions & 0 deletions examples/property_prediction/pretrain_gnns/chem/README.md
@@ -0,0 +1,56 @@
# Strategies for Pre-training Graph Neural Networks

## Intro
This is a PyTorch-based DGL implementation of the following paper.

- [Strategies for Pre-training Graph Neural Networks.](https://arxiv.org/abs/1905.12265) W. Hu*, B. Liu*, J. Gomes, M. Zitnik, P. Liang, V. Pande, J. Leskovec. *International Conference on Learning Representations (ICLR)*, 2020.

## Datasets
- For node-level self-supervised pre-training, 2 million unlabeled molecules sampled from the ZINC15 database are used. Custom datasets are supported.
- For graph-level multi-task supervised pre-training, a preprocessed ChEMBL dataset is used, containing 456K molecules annotated with 1,310 diverse biochemical assays. Custom datasets are supported.
- For fine-tuning on downstream tasks, the BBBP, Tox21, ToxCast, SIDER, MUV, HIV and BACE datasets are supported (see the loading sketch below).
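
For reference, a minimal sketch (mirroring `classification.py` below) of loading one of the downstream datasets with the featurizers expected by the pre-trained GIN; Tox21 is used here only as an example:

``` python
from functools import partial

from dgllife.data import Tox21
from dgllife.utils import PretrainAtomFeaturizer, PretrainBondFeaturizer, smiles_to_bigraph

# Categorical atom features (atomic number, chirality) and bond features
# (bond type, bond direction) expected by the pre-trained GIN
atom_featurizer = PretrainAtomFeaturizer()
bond_featurizer = PretrainBondFeaturizer()
dataset = Tox21(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                node_featurizer=atom_featurizer,
                edge_featurizer=bond_featurizer)
```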

## Usage
**1. Self-supervised pre-training**

The paper proposes an attribute-masking pre-training method: input node/edge attributes are randomly masked by replacing them with special mask indicators, and the GNN is trained to predict the masked attributes from the neighboring structure.
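
Below is a minimal sketch of this objective for node attributes only. It is an illustration, not the actual code in `pretrain_masking.py`; it assumes a batched `DGLGraph` `bg` featurized with `PretrainAtomFeaturizer`/`PretrainBondFeaturizer`, a GIN encoder `gnn` from the DGL-LifeSci model zoo, a linear head `atom_classifier` over atom types, and that the last of the 119 atom-type embeddings is reserved as the mask indicator.

``` python
import torch
import torch.nn.functional as F

def attribute_masking_loss(gnn, atom_classifier, bg, mask_rate=0.15, mask_index=118):
    # Categorical node/edge features produced by the pre-training featurizers
    atom_types = bg.ndata['atomic_number']
    chirality = bg.ndata['chirality_type']
    edge_feats = [bg.edata['bond_type'], bg.edata['bond_direction_type']]

    # Sample nodes to mask and remember their true atom types
    masked = torch.rand(bg.num_nodes(), device=bg.device) < mask_rate
    targets = atom_types[masked]

    # Corrupt the graph by replacing the atom type of masked nodes with the mask index
    corrupted = atom_types.clone()
    corrupted[masked] = mask_index

    # Encode the corrupted graph and predict the original atom types of the masked nodes
    node_repr = gnn(bg, [corrupted, chirality], edge_feats)
    logits = atom_classifier(node_repr[masked])
    return F.cross_entropy(logits, targets)
```

To run the actual pre-training: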

``` bash
python pretrain_masking.py --output_model_file OUTPUT_MODEL_FILE
```
The self-supervised pre-trained model will be saved to `OUTPUT_MODEL_FILE` after training (default filename: `pretrain_masking.pth`).

To use a custom dataset, provide its path with `--dataset`. The custom dataset should be a text file whose first line is the header `smiles` and whose every subsequent line is a molecule SMILES string.
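
For example, such a file (the filename below is just a placeholder) could be written as follows:

``` python
# Write a custom dataset: a 'smiles' header followed by one SMILES string per line
smiles_list = ['CCO', 'c1ccccc1', 'CC(=O)Oc1ccccc1C(=O)O']
with open('my_unlabeled_molecules.txt', 'w') as f:
    f.write('smiles\n')
    for s in smiles_list:
        f.write(s + '\n')
```

It can then be used via `python pretrain_masking.py --dataset my_unlabeled_molecules.txt --output_model_file OUTPUT_MODEL_FILE`.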

**2. Supervised pre-training**
``` bash
python pretrain_supervised.py --input_model_file INPUT_MODEL_FILE --output_model_file OUTPUT_MODEL_FILE
```
The self-supervised pre-trained model can be loaded from `INPUT_MODEL_FILE`.

The supervised pre-trained model will be saved to `OUTPUT_MODEL_FILE` after training (default filename: `pretrain_supervised.pth`).

To use a custom dataset, provide its path with `--dataset`. The custom dataset should be a `.pkl` file pickled from a list of tuples. The first element of every tuple should be a molecule SMILES string (`str`), and the second element its corresponding label (`torch.Tensor`). Label values are in {-1, 0, 1}: "1" means positive, "-1" means negative, and "0" indicates the molecule is invalid.
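
For example, a dataset in this format (the filename and SMILES below are placeholders) could be produced as follows:

``` python
import pickle

import torch

# Each entry is (SMILES string, per-task labels); per the convention above,
# 1 = positive, -1 = negative, 0 = invalid molecule
data = [
    ('CCO', torch.tensor([1., -1., 1.])),
    ('c1ccccc1', torch.tensor([-1., 1., -1.])),
]
with open('my_supervised_data.pkl', 'wb') as f:
    pickle.dump(data, f)
```

It can then be used via `python pretrain_supervised.py --dataset my_supervised_data.pkl --input_model_file INPUT_MODEL_FILE --output_model_file OUTPUT_MODEL_FILE`.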

**3. Fine-tuning on downstream datasets**
``` bash
python classification.py --input_model_file INPUT_MODEL_FILE --output_model_file OUTPUT_MODEL_FILE --dataset DOWNSTREAM_DATASET
```

The supervised pre-trained model can be loaded from `INPUT_MODEL_FILE`.

The fine-tuned model will be saved to `OUTPUT_MODEL_FILE` after training (default filename: `pretrain_fine_tuning.pth`).
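
To reuse the fine-tuned checkpoint for inference later, it can be loaded back into a `GINPredictor` with the same architecture as in `classification.py`. A sketch, assuming the Tox21 setting (12 tasks) and the checkpoint format written by DGL-LifeSci's `EarlyStopping`:

``` python
import torch
from dgllife.model.model_zoo.gin_predictor import GINPredictor

# Same architecture as used for fine-tuning in classification.py
model = GINPredictor(num_node_emb_list=[119, 4],
                     num_edge_emb_list=[6, 3],
                     num_layers=5,
                     emb_dim=300,
                     JK='last',
                     dropout=0.2,
                     readout='mean',
                     n_tasks=12)  # 12 tasks for Tox21
checkpoint = torch.load('pretrain_fine_tuning.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```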

## Experiment Results

With the default parameters, the following downstream task results are obtained with GIN pre-trained by Attribute Masking (node-level) followed by Supervised (graph-level) pre-training.

| Dataset | ROC-AUC (%), this implementation | ROC-AUC (%), reported in the paper |
| :-----: | :-----: | :--------: |
| BBBP | 71.75 | 66.5 ± 2.5 |
| Tox21 | 72.67 | 77.9 ± 0.4 |
| ToxCast | 62.22 | 65.1 ± 0.3 |
| SIDER | 58.97 | 63.9 ± 0.9 |
| MUV | 79.44 | 81.2 ± 1.9 |
| HIV | 74.52 | 77.1 ± 1.2 |
| BACE | 77.34 | 80.3 ± 0.9 |
225 changes: 225 additions & 0 deletions examples/property_prediction/pretrain_gnns/chem/classification.py
@@ -0,0 +1,225 @@
# -*- coding: utf-8 -*-
#
# adapted from
# https://github.com/awslabs/dgl-lifesci/blob/master/examples/property_prediction/moleculenet/classification.py

import argparse
from functools import partial
import numpy as np
import torch
import torch.nn as nn

from dgllife.utils import PretrainAtomFeaturizer
from dgllife.utils import PretrainBondFeaturizer
from dgllife.utils import smiles_to_bigraph, Meter, EarlyStopping
from dgllife.model.model_zoo.gin_predictor import GINPredictor

from torch.utils.data import DataLoader

from utils import split_dataset, collate_molgraphs


def train(args, epoch, model, data_loader, loss_criterion, optimizer, device):
    model.train()
    train_meter = Meter()
    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, labels, masks = batch_data

        if len(smiles) == 1:
            # Avoid potential issues with batch normalization
            continue

        labels, masks = labels.to(device), masks.to(device)
        bg = bg.to(device)
        node_feats = [
            bg.ndata.pop('atomic_number').to(device),
            bg.ndata.pop('chirality_type').to(device)
        ]
        edge_feats = [
            bg.edata.pop('bond_type').to(device),
            bg.edata.pop('bond_direction_type').to(device)
        ]
        logits = model(bg, node_feats, edge_feats)
        # Mask non-existing labels
        loss = (loss_criterion(logits, labels) * (masks != 0).float()).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_meter.update(logits, labels, masks)

        if batch_id % args.print_every == 0:
            print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
                epoch + 1, args.num_epochs, batch_id + 1, len(data_loader), loss.item()))

    train_score = np.mean(train_meter.compute_metric(args.metric))
    print('epoch {:d}/{:d}, training {} {:.4f}'.format(
        epoch + 1, args.num_epochs, args.metric, train_score))


def evaluation(args, model, data_loader, device):
    model.eval()
    eval_meter = Meter()
    with torch.no_grad():
        for _, batch_data in enumerate(data_loader):
            _, bg, labels, masks = batch_data
            labels = labels.to(device)
            bg = bg.to(device)
            node_feats = [
                bg.ndata.pop('atomic_number').to(device),
                bg.ndata.pop('chirality_type').to(device)
            ]
            edge_feats = [
                bg.edata.pop('bond_type').to(device),
                bg.edata.pop('bond_direction_type').to(device)
            ]
            logits = model(bg, node_feats, edge_feats)
            eval_meter.update(logits, labels, masks)
    return np.mean(eval_meter.compute_metric(args.metric))


def main(args, dataset, device):
    train_set, val_set, test_set = split_dataset(args, dataset)
    train_loader = DataLoader(dataset=train_set, batch_size=32, shuffle=True,
                              collate_fn=collate_molgraphs, num_workers=args.num_workers)
    val_loader = DataLoader(dataset=val_set, batch_size=32,
                            collate_fn=collate_molgraphs, num_workers=args.num_workers)
    test_loader = DataLoader(dataset=test_set, batch_size=32,
                             collate_fn=collate_molgraphs, num_workers=args.num_workers)

    model = GINPredictor(num_node_emb_list=[119, 4],
                         num_edge_emb_list=[6, 3],
                         num_layers=5,
                         emb_dim=300,
                         JK='last',
                         dropout=0.2,
                         readout='mean',
                         n_tasks=dataset.n_tasks)
    if args.input_model_file != '':
        # Initialize the GIN encoder with the pre-trained weights
        model.gnn.load_state_dict(torch.load(args.input_model_file))
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
    criterion = nn.BCEWithLogitsLoss(reduction='none')
    stopper = EarlyStopping(filename=args.output_model_file)

    for epoch in range(0, args.num_epochs):
        train(args, epoch, model, train_loader, criterion, optimizer, device)
        val_score = evaluation(args, model, val_loader, device)
        early_stop = stopper.step(val_score, model)
        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
            epoch + 1, args.num_epochs, args.metric,
            val_score, args.metric, stopper.best_score))

        if early_stop:
            break

    # Evaluate the best checkpoint on the test set
    stopper.load_checkpoint(model)
    test_score = evaluation(args, model, test_loader, device)
    print('test {} {:.4f}'.format(args.metric, test_score))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='pretrain_downstream_classification_task')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any. (default: 0)')
    parser.add_argument('-d', '--dataset', choices=['MUV', 'BACE', 'BBBP', 'ClinTox', 'SIDER',
                                                    'ToxCast', 'HIV', 'PCBA', 'Tox21'],
                        help='Dataset to use')
    parser.add_argument('-s', '--split', choices=['scaffold', 'random'], default='scaffold',
                        help='Dataset splitting method (default: scaffold)')
    parser.add_argument('-sr', '--split-ratio', default='0.8,0.1,0.1', type=str,
                        help='Proportion of the dataset to use for training, validation and test '
                             '(default: 0.8,0.1,0.1)')
    parser.add_argument('-me', '--metric', choices=['roc_auc_score', 'pr_auc_score'],
                        default='roc_auc_score',
                        help='Metric for evaluation (default: roc_auc_score)')
    parser.add_argument('-n', '--num-epochs', type=int, default=1000,
                        help='Maximum number of epochs for training. '
                             'We set a large number by default as early stopping '
                             'will be performed. (default: 1000)')
    parser.add_argument('-nw', '--num-workers', type=int, default=0,
                        help='Number of processes for data loading (default: 0)')
    parser.add_argument('-pe', '--print-every', type=int, default=20,
                        help='Print the training progress every X mini-batches')
    parser.add_argument('--input_model_file', type=str, default='pretrain_supervised.pth',
                        help='filename to input the pre-trained model if there is any.'
                             ' (default: pretrain_supervised.pth)')
    parser.add_argument('--output_model_file', type=str, default='pretrain_fine_tuning.pth',
                        help='filename to output the pre-trained downstream task model.'
                             ' (default: pretrain_fine_tuning.pth)')
    args = parser.parse_args()
    print(args)

    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")

    atom_featurizer = PretrainAtomFeaturizer()
    bond_featurizer = PretrainBondFeaturizer()

    if args.dataset == 'MUV':
        from dgllife.data import MUV

        dataset = MUV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                      node_featurizer=atom_featurizer,
                      edge_featurizer=bond_featurizer,
                      n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'BACE':
        from dgllife.data import BACE

        dataset = BACE(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                       node_featurizer=atom_featurizer,
                       edge_featurizer=bond_featurizer,
                       n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'BBBP':
        from dgllife.data import BBBP

        dataset = BBBP(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                       node_featurizer=atom_featurizer,
                       edge_featurizer=bond_featurizer,
                       n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'ClinTox':
        from dgllife.data import ClinTox

        dataset = ClinTox(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                          node_featurizer=atom_featurizer,
                          edge_featurizer=bond_featurizer,
                          n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'SIDER':
        from dgllife.data import SIDER

        dataset = SIDER(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                        node_featurizer=atom_featurizer,
                        edge_featurizer=bond_featurizer,
                        n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'ToxCast':
        from dgllife.data import ToxCast

        dataset = ToxCast(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                          node_featurizer=atom_featurizer,
                          edge_featurizer=bond_featurizer,
                          n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'HIV':
        from dgllife.data import HIV

        dataset = HIV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                      node_featurizer=atom_featurizer,
                      edge_featurizer=bond_featurizer,
                      n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'PCBA':
        from dgllife.data import PCBA

        dataset = PCBA(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                       node_featurizer=atom_featurizer,
                       edge_featurizer=bond_featurizer,
                       n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'Tox21':
        from dgllife.data import Tox21

        dataset = Tox21(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                        node_featurizer=atom_featurizer,
                        edge_featurizer=bond_featurizer,
                        n_jobs=1 if args.num_workers == 0 else args.num_workers)
    else:
        raise ValueError('Unexpected dataset: {}'.format(args.dataset))

    main(args, dataset, device)