# main_training_DG.py
# Standard imports:
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split
from torch_geometric.data import DataLoader  # note: moved to torch_geometric.loader.DataLoader in PyG >= 2.0
from torch_geometric.transforms import Compose
from pathlib import Path
# Custom data loader and model:
from data import ProteinPairsSurfaces, PairData, CenterPairAtoms
from data import RandomRotationPairAtoms, NormalizeChemFeatures, iface_valid_filter
from model import dMaSIF
from data_iteration import iterate, iterate_surface_precompute
from helper import *
from Arguments import parser
# Parse the arguments, prepare the TensorBoard writer:
args = parser.parse_args()
writer = SummaryWriter("runs/{}".format(args.experiment_name))
model_path = "models/" + args.experiment_name
# DG add: output directory passed to iterate() below as save_path for the dumped predictions.
# Only models/ is created here, so preds/ is assumed to exist (or to be created downstream).
save_predictions_path = Path("preds/" + args.experiment_name)
# save_predictions_path = Path("preds_tmp/" + args.experiment_name)
if not Path("models/").exists():
    Path("models/").mkdir(exist_ok=False)
# Ensure reproducibility:
torch.backends.cudnn.deterministic = True
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
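# Note: deterministic cuDNN kernels can be slower, but they make CUDA runs repeatable.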
# Create the model, with a warm restart if applicable:
net = dMaSIF(args)
net = net.to(args.device)
# We load the train and test datasets.
# Random transforms, to ensure that no network/baseline overfits on pose parameters:
transformations = (
    Compose([NormalizeChemFeatures(), CenterPairAtoms(), RandomRotationPairAtoms()])
    if args.random_rotation
    else Compose([NormalizeChemFeatures()])
)
# PyTorch geometric expects an explicit list of "batched variables":
batch_vars = ["xyz_p1", "xyz_p2", "atom_coords_p1", "atom_coords_p2"]
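# follow_batch makes PyTorch Geometric build a separate batch-index vector for each of these
# tensors, so the two proteins of a pair stay distinguishable after collation.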
# Load the train dataset:
train_dataset = ProteinPairsSurfaces(
    "surface_data", ppi=args.search, train=True, transform=transformations
)
train_dataset = [data for data in train_dataset if iface_valid_filter(data)]
train_loader = DataLoader(
    train_dataset, batch_size=1, follow_batch=batch_vars, shuffle=True
)
print("Preprocessing training dataset")
train_dataset = iterate_surface_precompute(train_loader, net, args)
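# Precompute the surface point clouds once up front, so that the training epochs below
# can reuse them instead of regenerating the geometry at every pass.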
# Train/Validation split:
train_nsamples = len(train_dataset)
val_nsamples = int(train_nsamples * args.validation_fraction)
train_nsamples = train_nsamples - val_nsamples
train_dataset, val_dataset = random_split(
    train_dataset, [train_nsamples, val_nsamples]
)
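# The split is random, but reproducible thanks to the manual seeds set above.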
# Load the test dataset:
test_dataset = ProteinPairsSurfaces(
    "surface_data", ppi=args.search, train=False, transform=transformations
)
test_dataset = [data for data in test_dataset if iface_valid_filter(data)]
test_loader = DataLoader(
    test_dataset, batch_size=1, follow_batch=batch_vars, shuffle=True
)
print("Preprocessing testing dataset")
test_dataset = iterate_surface_precompute(test_loader, net, args)
# PyTorch Geometric data loaders:
train_loader = DataLoader(
    train_dataset, batch_size=1, follow_batch=batch_vars, shuffle=True
)
val_loader = DataLoader(val_dataset, batch_size=1, follow_batch=batch_vars)
test_loader = DataLoader(test_dataset, batch_size=1, follow_batch=batch_vars)
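# Only the training batches are shuffled; validation and test passes iterate in a fixed order.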
# DG add
training_pdb_ids = (
    np.load("surface_data/processed/training_pairs_data_ids.npy")
    if args.site
    else np.load("surface_data/processed/training_pairs_data_ids_ppi.npy")
)
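# These IDs are forwarded to iterate() below, presumably to label the saved predictions
# with the PDB identifier of each protein pair.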
# Baseline optimizer:
optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, amsgrad=True)
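# Fixed learning rate of 3e-4 with the AMSGrad variant of Adam; no LR scheduler is used in this script.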
best_loss = 1e10 # We save the "best model so far"
starting_epoch = 0
# Warm restart: the checkpoint keys must match those written at the end of the training loop below.
if args.restart_training != "":
    checkpoint = torch.load("models/" + args.restart_training)
    net.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    starting_epoch = checkpoint["epoch"]
    best_loss = checkpoint["best_loss"]
# Main training loop over the dataset (args.n_epochs passes, typically ~100):
for i in range(starting_epoch, args.n_epochs):
    # Train first, then evaluate on the validation and test splits:
    for dataset_type in ["Train", "Validation", "Test"]:
        test = dataset_type != "Train"  # evaluation mode for the Validation/Test passes
        suffix = dataset_type

        if dataset_type == "Train":
            dataloader = train_loader
        elif dataset_type == "Validation":
            dataloader = val_loader
        elif dataset_type == "Test":
            dataloader = test_loader

        # Perform one pass through the data:
        info = iterate(
            net,
            dataloader,
            optimizer,
            args,
            test=test,
            save_path=save_predictions_path,  # DG add
            pdb_ids=training_pdb_ids,  # DG add
            summary_writer=writer,
            epoch_number=i,
        )

        # Write down the results using a TensorBoard writer;
        # info maps metric names to per-batch lists of values.
        for key, val in info.items():
            if key in [
                "Loss",
                "ROC-AUC",
                "Distance/Positives",
                "Distance/Negatives",
                "Matching ROC-AUC",
            ]:
                writer.add_scalar(f"{key}/{suffix}", np.mean(val), i)

            if "R_values/" in key:
                val = np.array(val)
                writer.add_scalar(f"{key}/{suffix}", np.mean(val[val > 0]), i)

        if dataset_type == "Validation":  # Store validation loss for saving the model
            val_loss = np.mean(info["Loss"])

    # Save a checkpoint whenever the validation loss improves:
    if val_loss < best_loss:
        print("Validation loss {}, saving model".format(val_loss))
        torch.save(
            {
                "epoch": i,
                "model_state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "best_loss": val_loss,  # record the new best loss so restarts resume with the right threshold
            },
            model_path + "_epoch{}".format(i),
        )
        best_loss = val_loss
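# Example launch (a sketch: the exact flag spellings are defined in Arguments.py and may differ):
#   python main_training_DG.py --experiment_name dMaSIF_search --device cuda:0 \
#       --n_epochs 100 --seed 42 --random_rotation True
# To resume from a saved checkpoint, pass its file name (relative to models/) via --restart_training.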