Commit: [Example] Pretrain_gnns. Squashed history: adds the pretrain_gnns example (pretrain_masking.py, pretrain_supervised.py, classification.py, utils.py and a README.md), adds support for custom datasets, updates gin_predictor.py to support skipping the readout and the top linear layer, deletes model.py, and updates the experiment results in the README. Co-authored-by: Mufei Li <[email protected]>
Showing 7 changed files with 831 additions and 3 deletions.
.gitignore
@@ -1,5 +1,6 @@
# IDE
.idea
.vscode

# Byte-compiled
__pycache__/
README.md
@@ -0,0 +1,56 @@
# Strategies for Pre-training Graph Neural Networks

## Intro
This is a DGL implementation of the following paper based on PyTorch.

- [Strategies for Pre-training Graph Neural Networks.](https://arxiv.org/abs/1905.12265) W. Hu*, B. Liu*, J. Gomes, M. Zitnik, P. Liang, V. Pande, J. Leskovec. *International Conference on Learning Representations (ICLR)*, 2020.

## Datasets
- For node-level self-supervised pre-training, 2 million unlabeled molecules sampled from the ZINC15 database are used. Custom datasets are supported.
- For graph-level multi-task supervised pre-training, a preprocessed ChEMBL dataset is used, which contains 456K molecules with labels for 1,310 diverse biochemical assays. Custom datasets are supported.
- For fine-tuning on downstream tasks, the BBBP, Tox21, ToxCast, SIDER, MUV, HIV and BACE datasets are supported.

## Usage
**1. Self-supervised pre-training**

The paper proposes an attribute masking pre-training method: input node/edge attributes are randomly masked by replacing them with special mask indicators, and the GNN is trained to predict the masked attributes from the neighboring structure.

``` bash
python pretrain_masking.py --output_model_file OUTPUT_MODEL_FILE
```
The self-supervised pre-trained model will be saved to `OUTPUT_MODEL_FILE` after training (default filename: pretrain_masking.pth).
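
As a rough illustration of the masking objective (not the actual implementation in `pretrain_masking.py`; the mask rate, mask token value and function name below are assumptions), node-attribute masking can be sketched as follows:

``` python
import torch

def mask_atom_types(atom_types, mask_rate=0.15, mask_token=118):
    """Randomly replace a fraction of atom-type indices with a mask token.

    atom_types: LongTensor of shape (num_nodes,) holding the original indices.
    Returns the corrupted copy, the masked node indices, and the original
    values at those positions, which serve as prediction targets.
    """
    num_nodes = atom_types.size(0)
    num_mask = max(1, int(mask_rate * num_nodes))
    masked_idx = torch.randperm(num_nodes)[:num_mask]
    targets = atom_types[masked_idx].clone()
    corrupted = atom_types.clone()
    corrupted[masked_idx] = mask_token
    return corrupted, masked_idx, targets

# Conceptually, pre-training runs the GNN on the corrupted graph and minimizes
# torch.nn.functional.cross_entropy(head(node_repr[masked_idx]), targets),
# where `head` is a linear classifier over the attribute vocabulary.
```

The actual script works on batched DGL graphs and also covers edge attributes; see `pretrain_masking.py` for the full implementation.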

If a custom dataset is specified, its path needs to be provided with `--dataset`. The custom dataset should be a text file in which the first line is the header `smiles` and every subsequent line is a molecule SMILES string.
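
For instance, a small custom file could be written like this (the file name and molecules are made up for illustration):

``` python
# Write a toy custom dataset for pretrain_masking.py; names are hypothetical.
with open('custom_molecules.txt', 'w') as f:
    f.write('smiles\n')                      # required header line
    f.write('CCO\n')                         # ethanol
    f.write('c1ccccc1\n')                    # benzene
    f.write('CC(=O)Oc1ccccc1C(=O)O\n')       # aspirin
```

The file path would then be passed via `--dataset custom_molecules.txt`.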

**2. Supervised pre-training**
``` bash
python pretrain_supervised.py --input_model_file INPUT_MODEL_FILE --output_model_file OUTPUT_MODEL_FILE
```
The self-supervised pre-trained model can be loaded from `INPUT_MODEL_FILE`.

The supervised pre-trained model will be saved to `OUTPUT_MODEL_FILE` after training (default filename: pretrain_supervised.pth).

If a custom dataset is specified, its path needs to be provided with `--dataset`. The custom dataset should be a `.pkl` file containing a pickled list of tuples. The first element of each tuple should be a molecule SMILES of class `str`, and the second element its corresponding label of class `torch.Tensor` with values in {-1, 0, 1}: 1 means positive, -1 means negative, and 0 indicates the molecule is invalid.
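
As an illustration, such a file could be assembled as follows (the file name, molecules and labels are hypothetical):

``` python
import pickle

import torch

# Each entry is (SMILES, label tensor) with one value per task:
# 1 = positive, -1 = negative, 0 = the molecule is invalid.
data = [
    ('CCO', torch.tensor([1., -1.])),
    ('c1ccccc1', torch.tensor([-1., 0.])),
    ('CC(=O)Oc1ccccc1C(=O)O', torch.tensor([0., 1.])),
]

with open('custom_supervised.pkl', 'wb') as f:
    pickle.dump(data, f)
```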

**3. Fine-tuning on a downstream dataset**
``` bash
python classification.py --input_model_file INPUT_MODEL_FILE --output_model_file OUTPUT_MODEL_FILE --dataset DOWNSTREAM_DATASET
```

The supervised pre-trained model can be loaded from `INPUT_MODEL_FILE`.

The fine-tuned model will be saved to `OUTPUT_MODEL_FILE` after training (default filename: pretrain_fine_tuning.pth).

## Experiment Results

With the default parameters, the following downstream-task results are obtained using the Attribute Masking (node-level) and Supervised (graph-level) pre-training strategy with GIN.

| Dataset | ROC-AUC (%) | ROC-AUC reported (%) |
| :-----: | :---------: | :------------------: |
| BBBP    | 71.75       | 66.5 ± 2.5           |
| Tox21   | 72.67       | 77.9 ± 0.4           |
| ToxCast | 62.22       | 65.1 ± 0.3           |
| SIDER   | 58.97       | 63.9 ± 0.9           |
| MUV     | 79.44       | 81.2 ± 1.9           |
| HIV     | 74.52       | 77.1 ± 1.2           |
| BACE    | 77.34       | 80.3 ± 0.9           |
examples/property_prediction/pretrain_gnns/chem/classification.py (225 additions, 0 deletions)
@@ -0,0 +1,225 @@
# -*- coding: utf-8 -*-
#
# adapted from
# https://github.com/awslabs/dgl-lifesci/blob/master/examples/property_prediction/moleculenet/classification.py

import argparse
from functools import partial
import numpy as np
import torch
import torch.nn as nn

from dgllife.utils import PretrainAtomFeaturizer
from dgllife.utils import PretrainBondFeaturizer
from dgllife.utils import smiles_to_bigraph, Meter, EarlyStopping
from dgllife.model.model_zoo.gin_predictor import GINPredictor

from torch.utils.data import DataLoader

from utils import split_dataset, collate_molgraphs


def train(args, epoch, model, data_loader, loss_criterion, optimizer, device):
    model.train()
    train_meter = Meter()
    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, labels, masks = batch_data

        if len(smiles) == 1:
            # Avoid potential issues with batch normalization
            continue

        labels, masks = labels.to(device), masks.to(device)
        bg = bg.to(device)
        node_feats = [
            bg.ndata.pop('atomic_number').to(device),
            bg.ndata.pop('chirality_type').to(device)
        ]
        edge_feats = [
            bg.edata.pop('bond_type').to(device),
            bg.edata.pop('bond_direction_type').to(device)
        ]
        logits = model(bg, node_feats, edge_feats)
        # Mask non-existing labels
        loss = (loss_criterion(logits, labels) * (masks != 0).float()).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_meter.update(logits, labels, masks)

        if batch_id % args.print_every == 0:
            print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
                epoch + 1, args.num_epochs, batch_id + 1, len(data_loader), loss.item()))

    train_score = np.mean(train_meter.compute_metric(args.metric))
    print('epoch {:d}/{:d}, training {} {:.4f}'.format(
        epoch + 1, args.num_epochs, args.metric, train_score))


def evaluation(args, model, data_loader, device):
    model.eval()
    eval_meter = Meter()
    with torch.no_grad():
        for _, batch_data in enumerate(data_loader):
            _, bg, labels, masks = batch_data
            labels = labels.to(device)
            bg = bg.to(device)
            node_feats = [
                bg.ndata.pop('atomic_number').to(device),
                bg.ndata.pop('chirality_type').to(device)
            ]
            edge_feats = [
                bg.edata.pop('bond_type').to(device),
                bg.edata.pop('bond_direction_type').to(device)
            ]
            logits = model(bg, node_feats, edge_feats)
            eval_meter.update(logits, labels, masks)
    return np.mean(eval_meter.compute_metric(args.metric))


def main(args, dataset, device):
    train_set, val_set, test_set = split_dataset(args, dataset)
    train_loader = DataLoader(dataset=train_set, batch_size=32, shuffle=True,
                              collate_fn=collate_molgraphs, num_workers=args.num_workers)
    val_loader = DataLoader(dataset=val_set, batch_size=32,
                            collate_fn=collate_molgraphs, num_workers=args.num_workers)
    test_loader = DataLoader(dataset=test_set, batch_size=32,
                             collate_fn=collate_molgraphs, num_workers=args.num_workers)

    model = GINPredictor(num_node_emb_list=[119, 4],
                         num_edge_emb_list=[6, 3],
                         num_layers=5,
                         emb_dim=300,
                         JK='last',
                         dropout=0.2,
                         readout='mean',
                         n_tasks=dataset.n_tasks)
    if args.input_model_file != '':
        model.gnn.load_state_dict(torch.load(args.input_model_file))
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
    criterion = nn.BCEWithLogitsLoss(reduction='none')
    stopper = EarlyStopping(filename=args.output_model_file)

    for epoch in range(0, args.num_epochs):
        train(args, epoch, model, train_loader, criterion, optimizer, device)
        val_score = evaluation(args, model, val_loader, device)
        early_stop = stopper.step(val_score, model)
        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
            epoch + 1, args.num_epochs, args.metric,
            val_score, args.metric, stopper.best_score))

        if early_stop:
            break
    stopper.load_checkpoint(model)
    test_score = evaluation(args, model, test_loader, device)
    print('test {} {:.4f}'.format(args.metric, test_score))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='pretrain_downstream_classification_task')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any. (default: 0)')
    parser.add_argument('-d', '--dataset', choices=['MUV', 'BACE', 'BBBP', 'ClinTox', 'SIDER',
                                                    'ToxCast', 'HIV', 'PCBA', 'Tox21'],
                        help='Dataset to use')
    parser.add_argument('-s', '--split', choices=['scaffold', 'random'], default='scaffold',
                        help='Dataset splitting method (default: scaffold)')
    parser.add_argument('-sr', '--split-ratio', default='0.8,0.1,0.1', type=str,
                        help='Proportion of the dataset to use for training, validation and test, '
                             '(default: 0.8,0.1,0.1)')
    parser.add_argument('-me', '--metric', choices=['roc_auc_score', 'pr_auc_score'],
                        default='roc_auc_score',
                        help='Metric for evaluation (default: roc_auc_score)')
    parser.add_argument('-n', '--num-epochs', type=int, default=1000,
                        help='Maximum number of epochs for training. '
                             'We set a large number by default as early stopping '
                             'will be performed. (default: 1000)')
    parser.add_argument('-nw', '--num-workers', type=int, default=0,
                        help='Number of processes for data loading (default: 0)')
    parser.add_argument('-pe', '--print-every', type=int, default=20,
                        help='Print the training progress every X mini-batches')
    parser.add_argument('--input_model_file', type=str, default='pretrain_supervised.pth',
                        help='filename to input the pre-trained model if there is any.'
                             ' (default: pretrain_supervised.pth)')
    parser.add_argument('--output_model_file', type=str, default='pretrain_fine_tuning.pth',
                        help='filename to output the pre-trained downstream task model.'
                             ' (default: pretrain_fine_tuning.pth)')
    args = parser.parse_args()
    print(args)

    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")

    atom_featurizer = PretrainAtomFeaturizer()
    bond_featurizer = PretrainBondFeaturizer()

    if args.dataset == 'MUV':
        from dgllife.data import MUV

        dataset = MUV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                      node_featurizer=atom_featurizer,
                      edge_featurizer=bond_featurizer,
                      n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'BACE':
        from dgllife.data import BACE

        dataset = BACE(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                       node_featurizer=atom_featurizer,
                       edge_featurizer=bond_featurizer,
                       n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'BBBP':
        from dgllife.data import BBBP

        dataset = BBBP(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                       node_featurizer=atom_featurizer,
                       edge_featurizer=bond_featurizer,
                       n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'ClinTox':
        from dgllife.data import ClinTox

        dataset = ClinTox(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                          node_featurizer=atom_featurizer,
                          edge_featurizer=bond_featurizer,
                          n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'SIDER':
        from dgllife.data import SIDER

        dataset = SIDER(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                        node_featurizer=atom_featurizer,
                        edge_featurizer=bond_featurizer,
                        n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'ToxCast':
        from dgllife.data import ToxCast

        dataset = ToxCast(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                          node_featurizer=atom_featurizer,
                          edge_featurizer=bond_featurizer,
                          n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'HIV':
        from dgllife.data import HIV

        dataset = HIV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                      node_featurizer=atom_featurizer,
                      edge_featurizer=bond_featurizer,
                      n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'PCBA':
        from dgllife.data import PCBA

        dataset = PCBA(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                       node_featurizer=atom_featurizer,
                       edge_featurizer=bond_featurizer,
                       n_jobs=1 if args.num_workers == 0 else args.num_workers)
    elif args.dataset == 'Tox21':
        from dgllife.data import Tox21

        dataset = Tox21(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                        node_featurizer=atom_featurizer,
                        edge_featurizer=bond_featurizer,
                        n_jobs=1 if args.num_workers == 0 else args.num_workers)
    else:
        raise ValueError('Unexpected dataset: {}'.format(args.dataset))

    main(args, dataset, device)