forked from FAIR-Chem/fairchem
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #35 from RolnickLab/gnn-improvements
Gnn improvements
- Loading branch information
Showing
128 changed files
with
11,546 additions
and
964 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# more epochs, larger batch size, explore fanet: larger model & skip-co & mlp_rij | ||
job: | ||
mem: 24GB | ||
cpus: 4 | ||
gres: gpu:1 | ||
time: 30:00 | ||
partition: main | ||
|
||
default: | ||
wandb_project: ocp-debug | ||
config: schnet-qm9-all | ||
mode: train | ||
wandb_tags: qm9, debug | ||
optim: | ||
batch_size: 64 | ||
max_epochs: -1 | ||
max_steps: 1e3 | ||
note: | ||
model: name, num_gaussians, hidden_channels, num_filters, num_interactions, phys_embeds, pg_hidden_channels, phys_hidden_channels | ||
optim: batch_size, lr_initial | ||
_root_: frame_averaging, fa_frames | ||
|
||
runs: | ||
- model: | ||
hidden_channels: 128 | ||
- model: | ||
hidden_channels: 64 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# more epochs, larger batch size, explore fanet: larger model & skip-co & mlp_rij | ||
job: | ||
mem: 24GB | ||
cpus: 4 | ||
gres: gpu:16gb:1 | ||
time: 1:00:00 | ||
partition: long | ||
code_loc: /home/mila/s/schmidtv/ocp-project/ocp-drlab | ||
env: ocp-a100 | ||
|
||
default: | ||
wandb_project: ocp-qm | ||
config: schnet-qm9-all | ||
mode: train | ||
test_ri: true | ||
wandb_tags: qm9, orion-debug | ||
phys_hidden_channels: 0 | ||
phys_embeds: False | ||
energy_head: False | ||
pg_hidden_channels: 0 | ||
tag_hidden_channels: 0 | ||
frame_averaging: "" | ||
cp_data_to_tmpdir: true | ||
optim: | ||
batch_size: 64 | ||
warmup_steps: 3000 | ||
lr_initial: 0.0002 | ||
# parameters EMA | ||
ema_decay: 0.999 | ||
# exp. decay to 0.01 * lr_initial in 1000000 steps | ||
decay_steps: max_steps | ||
decay_rate: 0.05 # at the end of training, lr is decay_rate*lr_initial | ||
# max_epochs = ref_steps[3e6] / (n_train[110 000] / ref_batch_size[32]) | ||
max_epochs: -1 | ||
note: | ||
model: name, num_gaussians, hidden_channels, num_filters, num_interactions, phys_embeds, pg_hidden_channels, phys_hidden_channels | ||
optim: batch_size, lr_initial | ||
_root_: frame_averaging, fa_frames | ||
|
||
orion: | ||
# Remember to change the experiment name if you change anything in the search space | ||
n_jobs: 20 | ||
|
||
unique_exp_name: ocp-qm9-orion-debug-v1.0.1 | ||
|
||
space: | ||
optim/max_steps: fidelity(1e3, 1e4, base=3) | ||
optim/batch_size: uniform(32, 128, discrete=True) | ||
optim/lr_initial: loguniform(1e-5, 5e-3, precision=2) | ||
model/num_gaussians: uniform(16, 200, discrete=True) | ||
model/hidden_channels: uniform(32, 512, discrete=True) | ||
model/num_filters: uniform(32, 512, discrete=True) | ||
model/num_interactions: uniform(1, 7, discrete=True) | ||
model/phys_embeds: choices([True, False]) | ||
|
||
algorithms: | ||
asha: | ||
seed: 123 | ||
num_rungs: 4 | ||
num_brackets: 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
job: | ||
mem: 32GB | ||
cpus: 4 | ||
gres: gpu:rtx8000:1 | ||
partition: long | ||
time: 20:00:00 | ||
|
||
default: | ||
test_ri: True | ||
mode: train | ||
graph_rewiring: remove-tag-0 | ||
model: | ||
phys_embeds: True | ||
tag_hidden_channels: 64 | ||
pg_hidden_channels: 0 | ||
energy_head: False | ||
edge_embed_type: all_rij | ||
wandb_tags: 'mp-type' | ||
optim: | ||
max_epochs: 10 | ||
batch_size: 256 | ||
eval_batch_size: 256 | ||
cp_data_to_tmpdir: true | ||
|
||
runs: | ||
- config: fanet-is2re-all | ||
note: 'batch norm after propagate Interaction' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
model: | ||
mp_type: base | ||
graph_norm: True | ||
- config: fanet-is2re-all | ||
note: 'batch norm after propagate Interaction' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
model: | ||
mp_type: att | ||
graph_norm: True | ||
- config: fanet-is2re-all | ||
note: 'batch norm after propagate Interaction' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
model: | ||
mp_type: local_env | ||
graph_norm: True | ||
- config: fanet-is2re-all | ||
note: 'batch norm after propagate Interaction' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
model: | ||
mp_type: sfarinet | ||
graph_norm: True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
job: | ||
mem: 32GB | ||
cpus: 4 | ||
gres: gpu:rtx8000:1 | ||
partition: long | ||
time: 20:00:00 | ||
|
||
default: | ||
test_ri: True | ||
mode: train | ||
graph_rewiring: remove-tag-0 | ||
model: | ||
phys_embeds: True | ||
tag_hidden_channels: 64 | ||
pg_hidden_channels: 0 # shall have been 32 | ||
energy_head: 'weighted-av-initial-embeds' # False ? | ||
wandb_tags: 'edge-embed-test' | ||
optim: | ||
max_epochs: 15 | ||
batch_size: 256 | ||
eval_batch_size: 256 | ||
cp_data_to_tmpdir: true | ||
|
||
runs: | ||
- config: fanet-is2re-all # 2678275 | ||
note: 'all rij' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
model: | ||
edge_embed_type: all_rij | ||
mp_type: base | ||
- config: fanet-is2re-all # 2678276 | ||
note: 'all' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
model: | ||
edge_embed_type: all | ||
mp_type: base | ||
- config: sfarinet-is2re-all # 2678277 | ||
note: 'all rij sfarinet' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
model: | ||
edge_embed_type: all_rij | ||
mp_type: sfarinet | ||
- config: sfarinet-is2re-all # 2678278 | ||
note: 'sfarinet all' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
model: | ||
edge_embed_type: all | ||
mp_type: sfarinet | ||
- config: sfarinet-is2re-all # 2678279 | ||
note: 'sfarinet all' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
model: | ||
edge_embed_type: all | ||
mp_type: base | ||
skip_co: "concat" | ||
complex_mp: true | ||
graph_norm: true | ||
second_layer_mlp: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
job: | ||
mem: 32GB | ||
cpus: 4 | ||
gres: gpu:rtx8000:1 | ||
partition: long | ||
time: 30:00:00 | ||
|
||
default: | ||
test_ri: True | ||
mode: train | ||
graph_rewiring: remove-tag-0 | ||
cp_data_to_tmpdir: true | ||
model: | ||
phys_embeds: True | ||
tag_hidden_channels: 64 | ||
pg_hidden_channels: 0 # shall have been 32 | ||
energy_head: False # False ? | ||
regress_forces: direct_with_gradient_target | ||
wandb_tags: 's2ef-archi-tests' | ||
optim: | ||
max_epochs: 5 | ||
batch_size: 192 | ||
eval_batch_size: 192 | ||
|
||
runs: | ||
- config: sfarinet-s2ef-2M | ||
note: 'Sfarinet no sym' | ||
- config: sfarinet-s2ef-2M | ||
note: 'Sfarinet baseline sym' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
- config: sfarinet-s2ef-2M | ||
note: 'Sfarinet baseline sym' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
- config: sfarinet-s2ef-2M | ||
note: 'rij' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
model: | ||
edge_embed_type: rij | ||
- config: sfarinet-s2ef-2M | ||
note: 'sh' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
model: | ||
edge_embed_type: sh | ||
- config: sfarinet-s2ef-2M | ||
note: 'all rij' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
model: | ||
edge_embed_type: all_rij | ||
- config: sfarinet-s2ef-2M | ||
note: 'all' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
model: | ||
edge_embed_type: all | ||
- config: sfarinet-s2ef-2M | ||
note: 'all' | ||
frame_averaging: 2D | ||
fa_frames: se3-random | ||
model: | ||
edge_embed_type: all |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
job: | ||
mem: 32GB | ||
cpus: 4 | ||
gres: gpu:rtx8000:1 | ||
partition: long | ||
time: 30:00:00 | ||
|
||
default: | ||
test_ri: True | ||
mode: train | ||
graph_rewiring: remove-tag-0 | ||
model: | ||
phys_embeds: True | ||
tag_hidden_channels: 64 | ||
pg_hidden_channels: 0 # shall have been 32 | ||
energy_head: 'weighted-av-initial-embeds' # False ? | ||
wandb_tags: 'is2re-archi-tests' | ||
optim: | ||
max_epochs: 5 | ||
batch_size: 256 | ||
eval_batch_size: 256 | ||
|
||
runs: | ||
- config: schnet-is2re-all | ||
note: 'Schnet' | ||
- config: sfarinet-is2re-all | ||
note: 'Sfarinet test' | ||
frame_averaging: 2D | ||
fa_fames: se3-random | ||
- config: sfarinet-is2re-all | ||
note: 'Smaller lr' | ||
frame_averaging: 2D | ||
fa_fames: se3-random | ||
optim: | ||
lr_initial: 0.0005 | ||
- config: sfarinet-is2re-all | ||
note: 'Sfarinet test smaller lr' | ||
frame_averaging: 2D | ||
fa_fames: se3-random | ||
optim: | ||
lr_initial: 0.003 | ||
- config: sfarinet-is2re-all | ||
note: 'Bigger size' | ||
frame_averaging: 2D | ||
fa_fames: se3-random | ||
optim: | ||
lr_initial: 0.001 | ||
model: | ||
hidden_channels: 500 | ||
num_interactions: 4 | ||
num_filters: 200 | ||
num_gaussians: 200 | ||
- config: sfarinet-is2re-all | ||
note: 'Bigger size' | ||
frame_averaging: 2D | ||
fa_fames: se3-random | ||
optim: | ||
lr_initial: 0.0007 | ||
model: | ||
hidden_channels: 500 | ||
num_interactions: 4 | ||
num_filters: 200 | ||
num_gaussians: 200 | ||
- config: sfarinet-is2re-all | ||
note: 'Bigger size' | ||
frame_averaging: 2D | ||
fa_fames: se3-random | ||
optim: | ||
lr_initial: 0.001 | ||
model: | ||
num_interactions: 6 | ||
- config: sfarinet-is2re-all | ||
note: 'Bigger size and smaller lr' | ||
frame_averaging: 2D | ||
fa_fames: se3-random | ||
optim: | ||
lr_initial: 0.0007 | ||
model: | ||
num_interactions: 6 | ||
- config: sfarinet-is2re-all | ||
note: 'Bigger size and change warmup steps' | ||
frame_averaging: 2D | ||
fa_fames: se3-random | ||
optim: | ||
lr_initial: 0.001 | ||
lr_milestones: | ||
- 20981 | ||
- 26972 | ||
- 35963 | ||
warmup_steps: 10094 | ||
model: | ||
hidden_channels: 500 | ||
num_interactions: 4 | ||
num_filters: 200 | ||
num_gaussians: 200 | ||
- config: sfarinet-is2re-all | ||
note: 'Much Bigger size' | ||
frame_averaging: 2D | ||
fa_fames: se3-random | ||
optim: | ||
lr_initial: 0.0007 | ||
model: | ||
hidden_channels: 800 | ||
num_interactions: 4 | ||
num_filters: 284 | ||
num_gaussians: 284 | ||
- config: sfarinet-is2re-all | ||
note: 'Smaller size more interactions' | ||
frame_averaging: 2D | ||
fa_fames: se3-random | ||
optim: | ||
lr_initial: 0.001 | ||
model: | ||
hidden_channels: 128 | ||
num_interactions: 6 | ||
num_filters: 100 | ||
num_gaussians: 100 | ||
|
Oops, something went wrong.