-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathsingle_task_devset.py
76 lines (69 loc) · 2.03 KB
/
single_task_devset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from __future__ import annotations
import lightning.pytorch as pl
from torch.nn import LayerNorm, SiLU
from matsciml.datasets.transforms import (
PointCloudToGraphTransform,
PeriodicPropertiesTransform,
)
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.models import PLEGNNBackbone
from matsciml.models.base import ScalarRegressionTask
# Build a simple PLEGNNBackbone for testing. The keyword arguments are kept in
# ``model_args`` so the exact configuration can still be inspected after the
# model has been constructed.
model_args = dict(
    # embedding stack (embed_* options)
    embed_in_dim=128,
    embed_hidden_dim=32,
    embed_out_dim=128,
    embed_depth=5,
    embed_feat_dims=[128, 128, 128],
    embed_message_dims=[128, 128, 128],
    embed_position_dims=[64, 64],
    embed_edge_attributes_dim=0,
    embed_activation="relu",
    embed_residual=True,
    embed_normalize=True,
    embed_tanh=True,
    embed_activate_last=False,
    embed_k_linears=1,
    embed_use_attention=False,
    embed_attention_norm="sigmoid",
    # readout and node projection
    readout="sum",
    node_projection_depth=3,
    node_projection_hidden_dim=128,
    node_projection_activation="relu",
    # prediction MLP options
    prediction_out_dim=1,
    prediction_depth=3,
    prediction_hidden_dim=128,
    prediction_activation="relu",
    # NOTE(review): presumably makes the backbone return embeddings only, so the
    # task's output head does the actual prediction; confirm in PLEGNNBackbone.
    encoder_only=True,
)
model = PLEGNNBackbone(**model_args)
# Wrap the backbone in a scalar regression task whose target is the
# "stability" key, trained with a learning rate of 1e-3. The output head uses
# SiLU activations, a LayerNorm over 128 features and an explicit input_dim of
# 128 with lazy=False.
# NOTE(review): input_dim presumably has to agree with the backbone's
# embed_out_dim (also 128); confirm before changing either value.
task = ScalarRegressionTask(
    model,
    lr=1e-3,
    task_keys=["stability"],
    output_kwargs=dict(
        norm=LayerNorm(128),
        hidden_dim=128,
        activation=SiLU,
        lazy=False,
        input_dim=128,
    ),
)
# Load the OQMD development set through MatSciMLDataModule.from_devset.
# Two per-sample transforms are configured (presumably applied in list order):
#   1. PeriodicPropertiesTransform with cutoff_radius=6.0 and adaptive_cutoff=True;
#   2. PointCloudToGraphTransform, building a "dgl" graph whose node data are
#      taken from the "pos" and "atomic_numbers" keys.
# NOTE(review): the graph transform's cutoff_dist (20.0) is much larger than the
# periodic cutoff (6.0); confirm which value actually determines the edges.
# num_workers=0 is presumably forwarded to the DataLoader(s), keeping loading
# in the main process.
dm = MatSciMLDataModule.from_devset(
    "OQMDDataset",
    dset_kwargs={
        "transforms": [
            PeriodicPropertiesTransform(cutoff_radius=6.0, adaptive_cutoff=True),
            PointCloudToGraphTransform(
                "dgl",
                cutoff_dist=20.0,
                node_keys=["pos", "atomic_numbers"],
            ),
        ],
    },
    num_workers=0,
)
# Quick smoke test: fast_dev_run=10 limits the run to 10 batches per stage.
# Checkpointing and the logger are both switched off for this throwaway run.
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=False,
    fast_dev_run=10,
)
trainer.fit(task, datamodule=dm)