add models/cgcnn/{slurm_train_cgcnn_ensemble,use_cgcnn_ensemble}.py
janosh committed Jun 20, 2023
1 parent bf36e42 commit a2e3f46
Showing 6 changed files with 242 additions and 43 deletions.
3 changes: 2 additions & 1 deletion models/bowsr/slurm_array_bowsr_wbm.py
@@ -57,8 +57,9 @@
 slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
 slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
 out_path = f"{out_dir}/{slurm_array_task_id}.json.gz"
+timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"

-print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
+print(f"Job started running {timestamp}")
 print(f"{slurm_job_id = }")
 print(f"{slurm_array_task_id = }")
 print(f"{data_path = }")
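Note: the new shared `timestamp` relies on Python forwarding the f-string format spec to `datetime.__format__` (strftime codes). A quick illustration of the output it produces (the fixed date here is only for demonstration):

from datetime import datetime

# e.g. 2023-06-20 14:05:33 -> "2023-06-20@14-05-33"
print(f"{datetime(2023, 6, 20, 14, 5, 33):%Y-%m-%d@%H-%M-%S}")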
116 changes: 116 additions & 0 deletions models/cgcnn/slurm_train_cgcnn_ensemble.py
@@ -0,0 +1,116 @@
# %%
import os
from datetime import datetime

import pandas as pd
from aviary import ROOT
from aviary.cgcnn.data import CrystalGraphData, collate_batch
from aviary.cgcnn.model import CrystalGraphConvNet
from aviary.core import TaskType
from aviary.train import df_train_test_split, train_model
from pymatgen.core import Structure
from torch.utils.data import DataLoader
from tqdm import tqdm

from matbench_discovery.slurm import slurm_submit_python

"""
Train a Wrenformer ensemble of size n_folds on target_col of data_path.
"""

__author__ = "Janosh Riebesell"
__date__ = "2022-06-13"


# %%
epochs = 300
target_col = "formation_energy_per_atom"
run_name = f"cgcnn-robust-{epochs=}-{target_col}"
print(f"{run_name=}")
robust = "robust" in run_name.lower()
n_folds = 10
today = f"{datetime.now():%Y-%m-%d}"
log_dir = f"{os.path.dirname(__file__)}/{today}-{run_name}"

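# submit a SLURM array job of n_folds tasks; each array task re-runs this script
# with SLURM_ARRAY_TASK_ID set, training one member of the ensemble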
slurm_submit_python(
    job_name=run_name,
    partition="ampere",
    account="LEE-SL3-GPU",
    time="8:0:0",
    array=f"1-{n_folds}",
    log_dir=log_dir,
    slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
    # prepend into sbatch script to source module command and load default env
    # for Ampere GPU partition before actual job command
    pre_cmd=". /etc/profile.d/modules.sh; module load rhel8/default-amp;",
)


# %%
optimizer = "AdamW"
learning_rate = 3e-4
batch_size = 128
swa_start = None
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
task_type: TaskType = "regression"


# %%
data_path = f"{ROOT}/datasets/2022-08-13-mp-energies.json.gz"
# data_path = f"{ROOT}/datasets/2022-08-13-mp-energies-1k-samples.json.gz"
print(f"{data_path=}")
df = pd.read_json(data_path).set_index("material_id", drop=False)
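# tqdm's disable=None shows the progress bar only when attached to a TTY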
df["structure"] = [Structure.from_dict(s) for s in tqdm(df.structure, disable=None)]
assert target_col in df

train_df, test_df = df_train_test_split(df, test_size=0.5)

train_data = CrystalGraphData(train_df, task_dict={target_col: task_type})
train_loader = DataLoader(
    train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch
)

test_data = CrystalGraphData(test_df, task_dict={target_col: task_type})
test_loader = DataLoader(
    test_data, batch_size=batch_size, shuffle=False, collate_fn=collate_batch
)

# 1 for regression, n_classes for classification
n_targets = [1 if task_type == "regression" else df[target_col].max() + 1]

model_params = dict(
    n_targets=n_targets,
    elem_emb_len=train_data.elem_emb_len,
    nbr_fea_len=train_data.nbr_fea_dim,
    task_dict={target_col: task_type},  # e.g. {'exfoliation_en': 'regression'}
    robust=robust,
)
model = CrystalGraphConvNet(**model_params)

run_params = dict(
    batch_size=batch_size,
    train_df=dict(shape=train_data.df.shape, columns=", ".join(train_df)),
    test_df=dict(shape=test_data.df.shape, columns=", ".join(test_df)),
)


# %%
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
print(f"Job started running {timestamp}")

train_model(
checkpoint="wandb", # None | 'local' | 'wandb',
epochs=epochs,
learning_rate=learning_rate,
model_params=model_params,
model=model,
optimizer=optimizer,
run_name=run_name,
swa_start=swa_start,
target_col=target_col,
task_type=task_type,
test_loader=test_loader,
timestamp=timestamp,
train_loader=train_loader,
wandb_path="janosh/matbench-discovery",
)
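With "robust" in the run name, the model is built with robust=True, i.e. it outputs a second value per target that models aleatoric uncertainty, and training uses a robust loss in which that uncertainty attenuates the error term. A minimal sketch of a robust L1 loss of this shape (illustrative only, not necessarily aviary's exact implementation):

import torch

def robust_l1_loss(pred: torch.Tensor, log_std: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # samples the model marks as uncertain (large log_std) contribute less to the
    # error term, while the +log_std term penalizes blanket high uncertainty
    return (2**0.5 * (pred - target).abs() * (-log_std).exp() + log_std).mean()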
71 changes: 71 additions & 0 deletions models/cgcnn/use_cgcnn_ensemble.py
@@ -0,0 +1,71 @@
# %%
from __future__ import annotations

import os
from datetime import datetime

import pandas as pd
import wandb
from aviary.cgcnn.data import CrystalGraphData, collate_batch
from aviary.cgcnn.model import CrystalGraphConvNet
from aviary.deploy import predict_from_wandb_checkpoints
from pymatgen.core import Structure
from torch.utils.data import DataLoader
from tqdm import tqdm

from matbench_discovery import ROOT
from matbench_discovery.plot_scripts import df_wbm

__author__ = "Janosh Riebesell"
__date__ = "2022-08-15"

"""
Script that downloads checkpoints for an ensemble of Wrenformer models trained on the MP
formation energies, then makes predictions on some dataset, prints ensemble metrics and
stores predictions to CSV.
"""

module_dir = os.path.dirname(__file__)
today = f"{datetime.now():%Y-%m-%d}"


# %%
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
df = pd.read_json(data_path).set_index("material_id", drop=False)
old_len = len(df)
df = df.dropna() # two missing initial structures
assert len(df) == old_len - 2

df["e_form_per_atom_mp2020_corrected"] = df_wbm.e_form_per_atom_mp2020_corrected

target_col = "e_form_per_atom_mp2020_corrected"
input_col = "initial_structure"
assert target_col in df, f"{target_col=} not in {list(df)}"
assert input_col in df, f"{input_col=} not in {list(df)}"

df[input_col] = [Structure.from_dict(x) for x in tqdm(df[input_col])]

wandb.login()
wandb_api = wandb.Api()
ensemble_id = "cgcnn-e_form-ensemble-1"
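# wandb's Api.runs accepts MongoDB-style query operators; $in matches runs tagged with ensemble_id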
runs = wandb_api.runs(
"janosh/matbench-discovery", filters={"tags": {"$in": [ensemble_id]}}
)

assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {ensemble_id=}"

cg_data = CrystalGraphData(
    df, task_dict={target_col: "regression"}, structure_col=input_col
)
data_loader = DataLoader(
    cg_data, batch_size=1024, shuffle=False, collate_fn=collate_batch
)
df, ensemble_metrics = predict_from_wandb_checkpoints(
    runs,
    df=df,
    target_col=target_col,
    model_class=CrystalGraphConvNet,
    data_loader=data_loader,
)

df.round(6).to_csv(f"{module_dir}/{today}-{ensemble_id}-preds-{target_col}.csv")
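predict_from_wandb_checkpoints downloads each run's checkpoint, predicts with it, and returns df with the predictions attached before they are saved to CSV above. A hedged sketch of aggregating such per-model predictions after reading the CSV back (the file name and prediction column names here are assumptions for illustration):

import pandas as pd

target_col = "e_form_per_atom_mp2020_corrected"
df_preds = pd.read_csv("preds.csv").set_index("material_id")  # hypothetical path

# hypothetical per-model prediction columns, e.g. pred_0 ... pred_9
pred_cols = [col for col in df_preds if col.startswith("pred_")]
df_preds["pred_ensemble"] = df_preds[pred_cols].mean(axis=1)  # mean over ensemble members

mae = (df_preds["pred_ensemble"] - df_preds[target_col]).abs().mean()
print(f"ensemble MAE = {mae:.4f} eV/atom")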
3 changes: 2 additions & 1 deletion models/m3gnet/slurm_array_m3gnet_wbm.py
@@ -51,8 +51,9 @@
 # %%
 slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
 slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
+timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"

-print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
+print(f"Job started running {timestamp}")
 print(f"{slurm_job_id = }")
 print(f"{slurm_array_task_id = }")
 print(f"{version('m3gnet') = }")
@@ -6,7 +6,9 @@

 import pandas as pd
 import wandb
-from aviary.wrenformer.deploy import deploy_wandb_checkpoints
+from aviary.deploy import predict_from_wandb_checkpoints
+from aviary.wrenformer.data import df_to_in_mem_dataloader
+from aviary.wrenformer.model import Wrenformer

 __author__ = "Janosh Riebesell"
 __date__ = "2022-08-15"
@@ -26,8 +28,10 @@
 data_path = "https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
 df = pd.read_csv(data_path).set_index("material_id")

 target_col = "e_form_per_atom"
 input_col = "wyckoff"
+assert target_col in df, f"{target_col=} not in {list(df)}"
+assert input_col in df, f"{input_col=} not in {list(df)}"

 wandb.login()
 wandb_api = wandb.Api()
@@ -38,8 +42,17 @@

 assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {ensemble_id=}"

-df, ensemble_metrics = deploy_wandb_checkpoints(
-    runs, df, input_col="wyckoff", target_col=target_col
-)
+data_loader = df_to_in_mem_dataloader(
+    df=df,
+    target_col=target_col,
+    batch_size=1024,
+    input_col=input_col,
+    embedding_type="wyckoff",
+    shuffle=False,  # False is default but best be explicit
+)
+
+df, ensemble_metrics = predict_from_wandb_checkpoints(
+    runs, data_loader, df=df, model_class=Wrenformer
+)

 df.round(6).to_csv(f"{module_dir}/{today}-{ensemble_id}-preds-{target_col}.csv")
@@ -2,42 +2,37 @@
 import os
 from datetime import datetime

-from aviary.wrenformer.train import train_wrenformer_on_df
+import pandas as pd
+from aviary import ROOT
+from aviary.train import df_train_test_split, train_wrenformer

-from matbench_discovery import ROOT
 from matbench_discovery.slurm import slurm_submit_python

 """
-Train a Wrenformer
-ensemble of size n_folds on target_col of df_or_path.
+Train a Wrenformer ensemble of size n_folds on target_col of data_path.
 """

 __author__ = "Janosh Riebesell"
 __date__ = "2022-08-13"


 # %%
-df_or_path = f"{ROOT}/data/2022-08-13-mp-energies.json.gz"
-target_col = "energy_per_atom"
-# df_or_path = f"{ROOT}/data/2022-08-25-m3gnet-trainset-mp-2021-struct-energy.json.gz"
-# target_col = "mp_energy_per_atom"
-
 epochs = 300
-job_name = f"wrenformer-robust-{epochs=}-{target_col}"
+target_col = "e_form"
+run_name = f"wrenformer-robust-mp+wbm-{epochs=}-{target_col}"
 n_folds = 10
 today = f"{datetime.now():%Y-%m-%d}"
 dataset = "mp"
 # dataset = 'm3gnet_train_set'
-log_dir = f"{os.path.dirname(__file__)}/{dataset}/{today}-{job_name}"
+log_dir = f"{os.path.dirname(__file__)}/{dataset}/{today}-{run_name}"

 slurm_submit_python(
-    job_name=job_name,
+    job_name=run_name,
     partition="ampere",
+    account="LEE-SL3-GPU",
     time="8:0:0",
     array=f"1-{n_folds}",
     log_dir=log_dir,
-    account="LEE-SL3-GPU",
-    slurm_flags=("--nodes 1", "--gpus-per-node 1"),
+    slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
     # prepend into sbatch script to source module command and load default env
     # for Ampere GPU partition before actual job command
     pre_cmd=". /etc/profile.d/modules.sh; module load rhel8/default-amp;",
@@ -48,39 +43,41 @@


 # %%
-n_attn_layers = 3
-embedding_aggregations = ("mean",)
 optimizer = "AdamW"
 learning_rate = 3e-4
-task_type = "regression"
-checkpoint = "wandb"  # None | 'local' | 'wandb'
+data_path = f"{ROOT}/data/2022-08-13-mp-energies.json.gz"
+target_col = "energy_per_atom"
+# data_path = f"{ROOT}/data/2022-08-25-m3gnet-trainset-mp-2021-struct-energy.json.gz"
+# target_col = "mp_energy_per_atom"
 batch_size = 128
 swa_start = None
-slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
+timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"

-print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
+slurm_job_id = os.environ.get("SLURM_JOB_ID")
+slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
+print(f"Job started running {timestamp}")
+print(f"{run_name=}")
+print(f"{data_path=}")

-print(f"{slurm_job_id=}")
-print(f"{slurm_array_task_id=}")
-print(f"{job_name=}")
-print(f"{df_or_path=}")
+df = pd.read_json(data_path).set_index("material_id", drop=False)
+assert target_col in df
+train_df, test_df = df_train_test_split(df, test_size=0.3)

+run_params = dict(
+    batch_size=batch_size,
+    train_df=dict(shape=train_df.shape, columns=", ".join(train_df)),
+    test_df=dict(shape=test_df.shape, columns=", ".join(test_df)),
+)

-train_wrenformer_on_df(
-    run_name=job_name,
+train_wrenformer(
+    run_name=run_name,
+    train_df=train_df,
+    test_df=test_df,
     target_col=target_col,
-    df_or_path=df_or_path,
-    task_type="regression",
     timestamp=timestamp,
-    test_size=0.05,
-    # folds=(n_folds, slurm_array_task_id),
     epochs=epochs,
-    n_attn_layers=n_attn_layers,
-    checkpoint=checkpoint,
     optimizer=optimizer,
+    checkpoint="wandb",  # None | 'local' | 'wandb'
     learning_rate=learning_rate,
-    embedding_aggregations=embedding_aggregations,
     batch_size=batch_size,
     swa_start=swa_start,
     wandb_path="janosh/matbench-discovery",
+    run_params=run_params,
 )
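A subtlety in the run_params dict above: iterating a pandas DataFrame yields its column labels, so ", ".join(train_df) records the column names, not cell values. A quick demonstration:

import pandas as pd

df_demo = pd.DataFrame({"material_id": ["mp-1", "mp-2"], "e_form": [0.1, -0.2]})
print(", ".join(df_demo))  # -> material_id, e_form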
