Skip to content

Commit

Permalink
Disable Caching Molecular Graphs by Default (#129)
Browse files Browse the repository at this point in the history
* Update

* Update

* Update

* Update

* Update

* Update

* Fix

* Update

* Update

* Update

* Update
  • Loading branch information
mufeili authored Jan 10, 2021
1 parent 69c11a9 commit 89be5a3
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def main(rank, dev_id, args):
train_set, val_set = load_dataset(args)
get_center_subset(train_set, rank, args['num_devices'])
train_loader = DataLoader(train_set, batch_size=args['batch_size'],
collate_fn=collate_center, shuffle=True)
collate_fn=collate_center, shuffle=True,
num_workers=args['num_processes'])
val_loader = DataLoader(val_set, batch_size=args['batch_size'],
collate_fn=collate_center, shuffle=False)

Expand Down Expand Up @@ -107,10 +108,9 @@ def main(rank, dev_id, args):
if total_iter % args['decay_every'] == 0:
optimizer.decay_lr(args['lr_decay_factor'])
if total_iter % args['decay_every'] == 0 and rank == 0:
if epoch >= 1:
dur.append(time.time() - t0)
print('Training time per {:d} iterations: {:.4f}'.format(
rank_iter, np.mean(dur)))
dur.append(time.time() - t0)
print('Estimated training time per epoch: {:.4f}'.format(
np.mean(dur) / args['decay_every'] * len(train_loader)))
total_samples = total_iter * args['batch_size']
prediction_summary = 'total samples {:d}, (epoch {:d}/{:d}, iter {:d}/{:d}) '.format(
total_samples, epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader)) + \
Expand Down Expand Up @@ -161,7 +161,7 @@ def run(rank, dev_id, args):
parser.add_argument('--val-path', type=str, default=None,
help='Path to a new validation set. '
'If None, we will use the default validation set in USPTO.')
parser.add_argument('-np', '--num-processes', type=int, default=32,
parser.add_argument('-np', '--num-processes', type=int, default=4,
help='Number of processes to use for data pre-processing')
parser.add_argument('--master-ip', type=str, default='127.0.0.1',
help='master ip address')
Expand Down
73 changes: 46 additions & 27 deletions python/dgllife/data/uspto.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,8 @@ class WLNCenterDataset(object):
``reaction_validity_result_prefix + _valid_reactions.proc`` and
invalid ones in ``reaction_validity_result_prefix + _invalid_reactions.proc``.
Default to ``''``.
cache : bool
If True, construct and featurize all graphs at once.
"""
def __init__(self,
raw_file_path,
Expand All @@ -417,7 +419,8 @@ def __init__(self,
load=True,
num_processes=1,
check_reaction_validity=True,
reaction_validity_result_prefix='',
reaction_validity_result_prefix='',
cache=True,
**kwargs):
super(WLNCenterDataset, self).__init__()

Expand All @@ -426,6 +429,7 @@ def __init__(self,
self.atom_pair_labels = []
# Map number of nodes to a corresponding complete graph
self.complete_graphs = dict()
self.cache = cache

path_to_reaction_file = raw_file_path + '.proc'
built_in = kwargs.get('built_in', False)
Expand Down Expand Up @@ -455,32 +459,37 @@ def __init__(self,
t0 = time.time()
full_mols, full_reactions, full_graph_edits = \
self.load_reaction_data(path_to_reaction_file, num_processes)
self.mols = full_mols
self.reactions = full_reactions
self.graph_edits = full_graph_edits
print('Time spent', time.time() - t0)

if load and os.path.isfile(mol_graph_path):
print('Loading previously saved graphs...')
self.reactant_mol_graphs, _ = load_graphs(mol_graph_path)
else:
print('Constructing graphs from scratch...')
if num_processes == 1:
self.reactant_mol_graphs = []
for mol in full_mols:
self.reactant_mol_graphs.append(mol_to_graph(
mol, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer, canonical_atom_order=False))
if self.cache:
if load and os.path.isfile(mol_graph_path):
print('Loading previously saved graphs...')
self.reactant_mol_graphs, _ = load_graphs(mol_graph_path)
else:
torch.multiprocessing.set_sharing_strategy('file_system')
with Pool(processes=num_processes) as pool:
self.reactant_mol_graphs = pool.map(
partial(mol_to_graph, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer, canonical_atom_order=False),
full_mols)

save_graphs(mol_graph_path, self.reactant_mol_graphs)
print('Constructing graphs from scratch...')
if num_processes == 1:
self.reactant_mol_graphs = []
for mol in full_mols:
self.reactant_mol_graphs.append(mol_to_graph(
mol, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer, canonical_atom_order=False))
else:
torch.multiprocessing.set_sharing_strategy('file_system')
with Pool(processes=num_processes) as pool:
self.reactant_mol_graphs = pool.map(
partial(mol_to_graph, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer, canonical_atom_order=False),
full_mols)

save_graphs(mol_graph_path, self.reactant_mol_graphs)
else:
self.mol_to_graph = mol_to_graph
self.node_featurizer = node_featurizer
self.edge_featurizer = edge_featurizer

self.mols = full_mols
self.reactions = full_reactions
self.graph_edits = full_graph_edits
self.atom_pair_features.extend([None for _ in range(len(self.mols))])
self.atom_pair_labels.extend([None for _ in range(len(self.mols))])

Expand Down Expand Up @@ -570,9 +579,15 @@ def __getitem__(self, item):
if self.atom_pair_labels[item] is None:
self.atom_pair_labels[item] = get_pair_label(mol, self.graph_edits[item])

if self.cache:
mol_graph = self.reactant_mol_graphs[item]
else:
mol_graph = self.mol_to_graph(mol, node_featurizer=self.node_featurizer,
edge_featurizer=self.edge_featurizer,
canonical_atom_order=False)

return self.reactions[item], self.graph_edits[item], \
self.reactant_mol_graphs[item], \
self.complete_graphs[num_atoms], \
mol_graph, self.complete_graphs[num_atoms], \
self.atom_pair_features[item], \
self.atom_pair_labels[item]

Expand Down Expand Up @@ -618,6 +633,8 @@ class USPTOCenter(WLNCenterDataset):
featurization methods and need to preprocess from scratch. Default to True.
num_processes : int
Number of processes to use for data pre-processing. Default to 1.
cache : bool
If True, construct and featurize all graphs at once.
"""
def __init__(self,
subset,
Expand All @@ -626,7 +643,8 @@ def __init__(self,
edge_featurizer=default_edge_featurizer_center,
atom_pair_featurizer=default_atom_pair_featurizer,
load=True,
num_processes=1):
num_processes=1,
cache=False):
assert subset in ['train', 'val', 'test'], \
'Expect subset to be "train" or "val" or "test", got {}'.format(subset)
print('Preparing {} subset of USPTO for reaction center prediction.'.format(subset))
Expand All @@ -650,7 +668,8 @@ def __init__(self,
load=load,
num_processes=num_processes,
check_reaction_validity=False,
built_in=True)
built_in=True,
cache=cache)

@property
def subset(self):
Expand Down
3 changes: 2 additions & 1 deletion python/dgllife/model/model_zoo/gin_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ class GINPredictor(nn.Module):
Dropout to apply to the output of each GIN layer. Default to 0.5.
readout : str
Readout for computing graph representations out of node representations, which
can be ``'sum'``, ``'mean'``, ``'max'``, ``'attention'``, or ``'set2set'``. Default to 'mean'.
can be ``'sum'``, ``'mean'``, ``'max'``, ``'attention'``, or ``'set2set'``. Default
to 'mean'.
n_tasks : int
Number of tasks, which is also the output size. Default to 1.
"""
Expand Down
8 changes: 5 additions & 3 deletions python/dgllife/model/readout/sum_and_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@

import dgl
import torch
import torch.nn as nn

__all__ = ['SumAndMax']

# pylint: disable=W0221
class SumAndMax(nn.Module):
# pylint: disable=W0221, W0622
class SumAndMax(object):
r"""Apply sum and max pooling to the node
representations and concatenate the results.
"""
Expand Down Expand Up @@ -42,3 +41,6 @@ def forward(self, bg, feats):
h_g_max = dgl.max_nodes(bg, 'h')
h_g = torch.cat([h_g_sum, h_g_max], dim=1)
return h_g

def __call__(self, *input, **kwargs):
return self.forward(*input, **kwargs)
2 changes: 1 addition & 1 deletion tests/model/test_property_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def test_nf_predictor():
activation=[None, None],
batchnorm=[False, False],
dropout=[0.1, 0.1],
predictor_size=4,
predictor_hidden_size=4,
predictor_batchnorm=False,
predictor_dropout=0.1,
predictor_activation=None).to(device)
Expand Down
2 changes: 1 addition & 1 deletion tests/model/test_readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_sum_and_max():
g, node_feats = g.to(device), node_feats.to(device)
bg, batch_node_feats = test_graph2()
bg, batch_node_feats = bg.to(device), batch_node_feats.to(device)
model = SumAndMax().to(device)
model = SumAndMax()
assert model(g, node_feats).shape == torch.Size([1, 2])
assert model(bg, batch_node_feats).shape == torch.Size([2, 2])

Expand Down

0 comments on commit 89be5a3

Please sign in to comment.