-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathmodule.py
261 lines (219 loc) · 10.7 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
from torch.nn.functional import mse_loss, l1_loss
from pytorch_lightning import LightningModule
from torchmdnet.models.model import create_model, load_model
class LNNP(LightningModule):
    """LightningModule that trains and evaluates a TorchMD-Net model.

    What this class does:

    * builds the wrapped model with ``create_model`` or loads it with
      ``load_model``, depending on ``hparams``;
    * configures AdamW with a cosine or reduce-on-plateau LR schedule, plus a
      linear LR warm-up applied in ``optimizer_step``;
    * computes a weighted sum of energy (``y``), derivative (``dy``) and
      denoising (``pos``) losses, with optional exponential smoothing (EMA)
      of the ``y``/``dy`` losses for the train and val stages;
    * collects detached per-batch losses and logs their epoch means.
    """

    def __init__(self, hparams, prior_model=None, mean=None, std=None):
        """Build or load the wrapped model and initialise loss bookkeeping.

        Args:
            hparams: hyperparameters; stored via ``save_hyperparameters`` and
                read back through ``self.hparams``.
            prior_model: optional prior forwarded to ``create_model``. It is
                only used when neither ``load_model`` nor
                ``pretrained_model`` is set.
            mean: optional target statistic forwarded to ``create_model``
                (fresh model) or to ``load_model`` (``pretrained_model``).
            std: same as ``mean``.
        """
        super().__init__()
        self.save_hyperparameters(hparams)

        if self.hparams.load_model:
            # NOTE: mean/std are not forwarded in this branch.
            self.model = load_model(self.hparams.load_model, args=self.hparams)
        elif self.hparams.pretrained_model:
            self.model = load_model(
                self.hparams.pretrained_model, args=self.hparams, mean=mean, std=std
            )
        else:
            self.model = create_model(self.hparams, prior_model, mean, std)

        # Running EMA state of the y / dy losses, keyed by "<stage>_<component>".
        self.ema = None
        self._reset_ema_dict()

        # Detached per-batch losses, keyed by stage and component; aggregated
        # and cleared in validation_epoch_end.
        self.losses = None
        self._reset_losses_dict()

    def configure_optimizers(self):
        """Create the AdamW optimizer and its LR scheduler.

        ``hparams.lr_schedule`` selects the scheduler:

        * ``"cosine"``: ``CosineAnnealingLR`` with ``T_max=lr_cosine_length``,
          stepped after every optimizer step;
        * ``"reduce_on_plateau"``: ``ReduceLROnPlateau`` in ``"min"`` mode,
          monitoring ``val_loss`` and stepped once per epoch.

        Returns:
            ``([optimizer], [lr_scheduler_config])`` in Lightning's format.

        Raises:
            ValueError: if ``lr_schedule`` is not one of the values above.
        """
        optimizer = AdamW(
            self.model.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        if self.hparams.lr_schedule == 'cosine':
            scheduler = CosineAnnealingLR(optimizer, self.hparams.lr_cosine_length)
            lr_scheduler = {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            }
        elif self.hparams.lr_schedule == 'reduce_on_plateau':
            scheduler = ReduceLROnPlateau(
                optimizer,
                "min",
                factor=self.hparams.lr_factor,
                patience=self.hparams.lr_patience,
                min_lr=self.hparams.lr_min,
            )
            lr_scheduler = {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "interval": "epoch",
                "frequency": 1,
            }
        else:
            raise ValueError(f"Unknown lr_schedule: {self.hparams.lr_schedule}")
        return [optimizer], [lr_scheduler]

    def forward(self, z, pos, batch=None):
        """Delegate to the wrapped model.

        ``step`` unpacks the result as ``(pred, noise_pred, deriv)``.
        """
        return self.model(z, pos, batch=batch)

    def training_step(self, batch, batch_idx):
        """Compute the MSE-based training loss for one batch."""
        return self.step(batch, mse_loss, "train")

    def validation_step(self, batch, batch_idx, *args):
        """Evaluate a batch from a validation dataloader.

        When an extra positional argument (the dataloader index) is present
        and non-zero, the batch is treated as a test batch (L1 loss).
        Otherwise it is a validation batch (MSE loss).
        """
        if not args or args[0] == 0:
            # validation step
            return self.step(batch, mse_loss, "val")
        # test step
        return self.step(batch, l1_loss, "test")

    def test_step(self, batch, batch_idx):
        """Compute the L1 test loss for one batch."""
        return self.step(batch, l1_loss, "test")

    def step(self, batch, loss_fn, stage):
        """Run the model on ``batch`` and compute the weighted total loss.

        Args:
            batch: batch object exposing ``z``, ``pos`` and ``batch``, and
                optionally ``y``, ``dy`` and ``pos_target`` (looks like a
                PyG-style batch — confirm against the datamodule).
            loss_fn: loss function called as ``loss_fn(prediction, target)``.
            stage: one of ``"train"``, ``"val"`` or ``"test"``.

        Returns:
            ``energy_weight * loss_y + force_weight * loss_dy
            + denoising_weight * loss_pos``.

        NOTE(review): if the batch has no ``y``, ``derivative`` is off and
        denoising is inactive, every component stays the int ``0``, and
        ``loss.detach()`` below raises ``AttributeError``.
        """
        # Gradients stay enabled outside training when hparams.derivative is
        # set, presumably because the model obtains `deriv` through autograd.
        with torch.set_grad_enabled(stage == "train" or self.hparams.derivative):
            # TODO: the model doesn't necessarily need to return a derivative once
            # Union typing works under TorchScript (https://github.com/pytorch/pytorch/pull/53180)
            pred, noise_pred, deriv = self(batch.z, batch.pos, batch.batch)

        denoising_is_on = ("pos_target" in batch) and (self.hparams.denoising_weight > 0) and (noise_pred is not None)

        loss_y, loss_dy, loss_pos = 0, 0, 0
        if self.hparams.derivative:
            if "y" not in batch:
                # "use" both outputs of the model's forward function but discard the first
                # to only use the derivative and avoid 'Expected to have finished reduction
                # in the prior iteration before starting a new one.', which otherwise gets
                # thrown because of setting 'find_unused_parameters=False' in the DDPPlugin
                deriv = deriv + pred.sum() * 0

            # force/derivative loss
            loss_dy = loss_fn(deriv, batch.dy)

            if stage in ["train", "val"] and self.hparams.ema_alpha_dy < 1:
                if self.ema[stage + "_dy"] is None:
                    self.ema[stage + "_dy"] = loss_dy.detach()
                # apply exponential smoothing over batches to dy
                loss_dy = (
                    self.hparams.ema_alpha_dy * loss_dy
                    + (1 - self.hparams.ema_alpha_dy) * self.ema[stage + "_dy"]
                )
                self.ema[stage + "_dy"] = loss_dy.detach()

            if self.hparams.force_weight > 0:
                self.losses[stage + "_dy"].append(loss_dy.detach())

        if "y" in batch:
            if (noise_pred is not None) and not denoising_is_on:
                # "use" both outputs of the model's forward (see comment above).
                pred = pred + noise_pred.sum() * 0

            if batch.y.ndim == 1:
                # Give targets a trailing feature axis before computing the loss.
                batch.y = batch.y.unsqueeze(1)

            # energy/prediction loss
            loss_y = loss_fn(pred, batch.y)

            if stage in ["train", "val"] and self.hparams.ema_alpha_y < 1:
                if self.ema[stage + "_y"] is None:
                    self.ema[stage + "_y"] = loss_y.detach()
                # apply exponential smoothing over batches to y
                loss_y = (
                    self.hparams.ema_alpha_y * loss_y
                    + (1 - self.hparams.ema_alpha_y) * self.ema[stage + "_y"]
                )
                self.ema[stage + "_y"] = loss_y.detach()

            if self.hparams.energy_weight > 0:
                self.losses[stage + "_y"].append(loss_y.detach())

        if denoising_is_on:
            if "y" not in batch:
                # "use" both outputs of the model's forward (see comment above).
                noise_pred = noise_pred + pred.sum() * 0

            normalized_pos_target = self.model.pos_normalizer(batch.pos_target)
            loss_pos = loss_fn(noise_pred, normalized_pos_target)
            self.losses[stage + "_pos"].append(loss_pos.detach())

        # total loss
        loss = loss_y * self.hparams.energy_weight + loss_dy * self.hparams.force_weight + loss_pos * self.hparams.denoising_weight

        self.losses[stage].append(loss.detach())

        # Frequent per-batch logging for training
        if stage == 'train':
            # Most recent value of every non-empty train_* buffer.
            train_metrics = {k + "_per_step": v[-1] for k, v in self.losses.items() if (k.startswith("train") and len(v) > 0)}
            train_metrics['lr_per_step'] = self.trainer.optimizers[0].param_groups[0]["lr"]
            train_metrics['step'] = self.trainer.global_step
            train_metrics['batch_pos_mean'] = batch.pos.mean().item()
            self.log_dict(train_metrics, sync_dist=True)

        return loss

    def optimizer_step(self, *args, **kwargs):
        """Apply linear LR warm-up, take the optimizer step, then zero the grads.

        For the first ``lr_warmup_steps`` global steps, every param group's
        LR is overwritten with ``hparams.lr * (global_step + 1) / lr_warmup_steps``
        (capped at ``hparams.lr``).
        """
        # The optimizer arrives either as a keyword or as the third positional
        # argument of Lightning's optimizer_step hook.
        optimizer = kwargs["optimizer"] if "optimizer" in kwargs else args[2]
        if self.trainer.global_step < self.hparams.lr_warmup_steps:
            lr_scale = min(
                1.0,
                float(self.trainer.global_step + 1)
                / float(self.hparams.lr_warmup_steps),
            )
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams.lr
        super().optimizer_step(*args, **kwargs)
        optimizer.zero_grad()

    def training_epoch_end(self, training_step_outputs):
        """Reset the validation dataloaders around test epochs.

        When the datamodule has a non-empty ``test_dataset``, the dataloaders
        are reset at epochs where ``current_epoch`` or ``current_epoch - 1``
        is a multiple of ``hparams.test_interval``. Presumably the datamodule
        adds the test loader as an extra validation dataloader only on those
        epochs — confirm.
        """
        dm = self.trainer.datamodule
        # getattr also covers a missing datamodule or a test_dataset that is None.
        test_dataset = getattr(dm, "test_dataset", None)
        if test_dataset is not None and len(test_dataset) > 0:
            should_reset = (
                self.current_epoch % self.hparams.test_interval == 0
                or (self.current_epoch - 1) % self.hparams.test_interval == 0
            )
            if should_reset:
                # reset validation dataloaders before and after testing epoch, which is faster
                # than skipping test validation steps by returning None
                self.trainer.reset_val_dataloader(self)

    def validation_epoch_end(self, validation_step_outputs):
        """Log the epoch means of all collected losses, then clear the buffers.

        The logged dict contains ``epoch`` and ``lr``, plus:

        * ``<stage>_loss`` for every stage (``train``/``val``/``test``) with
          at least one recorded total loss;
        * ``<stage>_loss_<component>`` for every non-empty component buffer
          (``y``, ``dy``, ``pos``).

        Each component is logged on its own, so derivative-only training
        still reports its ``dy`` losses. Nothing is logged during the sanity
        check, but the buffers are always cleared so that sanity-check losses
        do not leak into the first epoch.
        """
        if not self.trainer.running_sanity_check:
            result_dict = {
                "epoch": self.current_epoch,
                "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
            }
            for stage in ("train", "val", "test"):
                # torch.stack rejects empty lists, so skip buffers with no
                # entries (e.g. no test run this epoch).
                if self.losses[stage]:
                    result_dict[f"{stage}_loss"] = torch.stack(self.losses[stage]).mean()
                for component in ("y", "dy", "pos"):
                    values = self.losses[f"{stage}_{component}"]
                    if values:
                        result_dict[f"{stage}_loss_{component}"] = torch.stack(values).mean()

            self.log_dict(result_dict, sync_dist=True)
        self._reset_losses_dict()

    def _reset_losses_dict(self):
        """Replace every per-stage / per-component loss buffer with an empty list."""
        self.losses = {
            "train": [],
            "val": [],
            "test": [],
            "train_y": [],
            "val_y": [],
            "test_y": [],
            "train_dy": [],
            "val_dy": [],
            "test_dy": [],
            "train_pos": [],
            "val_pos": [],
            "test_pos": [],
        }

    def _reset_ema_dict(self):
        """Clear the EMA state.

        Only train/val keys exist because ``step`` applies smoothing only to
        those stages.
        """
        self.ema = {"train_y": None, "val_y": None, "train_dy": None, "val_dy": None}