Skip to content

Commit

Permalink
Refactor model preds loading to use new required pred_col key in YAML metadata files (#134)
Browse files Browse the repository at this point in the history

* consistent join_(model)_preds.py file names

* remove figshare/* from setuptools.package-data

* fix max possible DAF in CaptionedMetricsTable.svelte

now counts unique WBM prototypes only (excluding duplicates and ones already in MP)

* rename "Show proprietary models" toggle to "Show non-compliant models"

* refactor model preds loading to use new required pred_col key in yaml metadata files
  • Loading branch information
janosh committed Sep 4, 2024
1 parent a2e4c94 commit a614864
Show file tree
Hide file tree
Showing 45 changed files with 163 additions and 172 deletions.
1 change: 0 additions & 1 deletion matbench_discovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
SITE_LIB = f"{ROOT}/site/src/lib"
SCRIPTS = f"{ROOT}/scripts" # model and date analysis scripts
PDF_FIGS = f"{ROOT}/paper/figs" # directory for light-themed PDF figures
FIGSHARE_DIR = f"{PKG_DIR}/figshare"

for directory in (SITE_FIGS, SITE_LIB, PDF_FIGS):
os.makedirs(directory, exist_ok=True)
Expand Down
8 changes: 1 addition & 7 deletions matbench_discovery/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,13 @@
from pymatviz.enums import Key
from tqdm import tqdm

from matbench_discovery import DATA_DIR, FIGSHARE_DIR, pkg_is_editable
from matbench_discovery import DATA_DIR, pkg_is_editable

# ruff: noqa: T201
T = TypeVar("T", bound="Files")

# repo URL to raw files on GitHub
RAW_REPO_URL = "https://github.com/janosh/matbench-discovery/raw"
figshare_versions = sorted(
x.split(os.path.sep)[-1].split(".json")[0] for x in glob(f"{FIGSHARE_DIR}/*.json")
)
# directory to cache downloaded data files
DEFAULT_CACHE_DIR = os.getenv(
"MATBENCH_DISCOVERY_CACHE_DIR",
Expand Down Expand Up @@ -207,9 +204,6 @@ def __new__(
file from and directory where to save it to.
"""
obj = str.__new__(cls)
if url is not None and len(url) == 33:
# looks like a Google Drive ID, turn into direct download link
url = f"https://drive.usercontent.google.com/download?id={url}&confirm=t"
obj._value_ = file_path.split("/")[-1] # use file name as enum value

obj._rel_path = file_path # type: ignore[attr-defined] # noqa: SLF001
Expand Down
161 changes: 65 additions & 96 deletions matbench_discovery/preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pandas as pd
import plotly.express as px
import yaml
from pymatviz.enums import Key
from tqdm import tqdm

Expand All @@ -24,64 +25,56 @@ class Model(Files, base_dir=f"{ROOT}/models"):
See https://janosh.github.io/matbench-discovery/contribute for data descriptions.
"""

alignn = "alignn/2023-06-02-alignn-wbm-IS2RE.csv.gz", None, "ALIGNN"
# alignn_pretrained = "alignn/2023-06-03-mp-e-form-alignn-wbm-IS2RE.csv.gz", None, "ALIGNN Pretrained"
# alignn_ff = "alignn_ff/2023-07-11-alignn-ff-wbm-IS2RE.csv.gz", None, "ALIGNN FF"
alignn = "alignn/2023-06-02-alignn-wbm-IS2RE.csv.gz", "alignn/alignn.yml", "ALIGNN"
# alignn_pretrained = "alignn/2023-06-03-mp-e-form-alignn-wbm-IS2RE.csv.gz", "alignn/alignn.yml", "ALIGNN Pretrained"
# alignn_ff = "alignn_ff/2023-07-11-alignn-ff-wbm-IS2RE.csv.gz", "alignn/alignn-ff.yml", "ALIGNN FF"

# BOWSR optimizer coupled with original megnet
bowsr_megnet = "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv.gz", None, "BOWSR"
bowsr_megnet = "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv.gz", "bowsr/bowsr+megnet.yml", "BOWSR" # fmt: skip

# default CHGNet model from publication with 400,438 params
chgnet = "chgnet/2023-12-21-chgnet-0.3.0-wbm-IS2RE.csv.gz", None, "CHGNet"
chgnet = "chgnet/2023-12-21-chgnet-0.3.0-wbm-IS2RE.csv.gz", "chgnet/chgnet.yml", "CHGNet" # fmt: skip
# chgnet_no_relax = "chgnet/2023-12-05-chgnet-0.3.0-wbm-IS2RE-no-relax.csv.gz", None, "CHGNet No Relax"

# CGCNN 10-member ensemble
cgcnn = "cgcnn/2023-01-26-cgcnn-ens=10-wbm-IS2RE.csv.gz", None, "CGCNN"
cgcnn = "cgcnn/2023-01-26-cgcnn-ens=10-wbm-IS2RE.csv.gz", "cgcnn/cgcnn.yml", "CGCNN"

# CGCNN 10-member ensemble with 5-fold training set perturbations
cgcnn_p = "cgcnn/2023-02-05-cgcnn-perturb=5-wbm-IS2RE.csv.gz", None, "CGCNN+P"
cgcnn_p = "cgcnn/2023-02-05-cgcnn-perturb=5-wbm-IS2RE.csv.gz", "cgcnn/cgcnn+p.yml", "CGCNN+P" # fmt: skip

# original M3GNet straight from publication, not re-trained
m3gnet = "m3gnet/2023-12-28-m3gnet-wbm-IS2RE.csv.gz", None, "M3GNet"
m3gnet = "m3gnet/2023-12-28-m3gnet-wbm-IS2RE.csv.gz", "m3gnet/m3gnet.yml", "M3GNet"
# m3gnet_direct = "m3gnet/2023-05-30-m3gnet-direct-wbm-IS2RE.csv.gz", None, "M3GNet DIRECT"
# m3gnet_ms = "m3gnet/2023-06-01-m3gnet-manual-sampling-wbm-IS2RE.csv.gz", None, "M3GNet MS"

# MACE-MP as published in https://arxiv.org/abs/2401.00096 trained on MPtrj
mace = "mace/2023-12-11-mace-wbm-IS2RE-FIRE.csv.gz", None, "MACE"
mace = "mace/2023-12-11-mace-wbm-IS2RE-FIRE.csv.gz", "mace/mace.yml", "MACE"
# mace_alex = "mace/2024-08-09-mace-wbm-IS2RE-FIRE.csv.gz", None, "MACE Alex"
# https://github.com/ACEsuit/mace-mp/releases/tag/mace_mp_0b
# mace_0b = "mace/2024-07-20-mace-wbm-IS2RE-FIRE.csv.gz", None, "MACE 0b"

# original MEGNet straight from publication, not re-trained
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE.csv.gz", None, "MEGNet"
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE.csv.gz", "megnet/megnet.yml", "MEGNet"

# SevenNet trained on MPtrj
sevennet = "sevennet/2024-07-11-sevennet-preds.csv.gz", None, "SevenNet"
sevennet = "sevennet/2024-07-11-sevennet-preds.csv.gz", "sevennet/sevennet.yml", "SevenNet" # fmt: skip

# Magpie composition+Voronoi tessellation structure features + sklearn random forest
voronoi_rf = (
"voronoi_rf/2022-11-27-train-test/e-form-preds-IS2RE.csv.gz",
None,
"Voronoi RF",
)
voronoi_rf = "voronoi_rf/2022-11-27-train-test/e-form-preds-IS2RE.csv.gz", "voronoi_rf/voronoi_rf.yml", "Voronoi RF" # fmt: skip

# wrenformer 10-member ensemble
wrenformer = (
"wrenformer/2022-11-15-wrenformer-ens=10-IS2RE-preds.csv.gz",
None,
"Wrenformer",
)
wrenformer = "wrenformer/2022-11-15-wrenformer-ens=10-IS2RE-preds.csv.gz", "wrenformer/wrenformer.yml", "Wrenformer" # fmt: skip

# --- Proprietary Models
# GNoMe
gnome = "gnome/2023-11-01-gnome-preds-50076332.csv.gz", None, "GNoME"
# GNoME
gnome = "gnome/2023-11-01-gnome-preds-50076332.csv.gz", "gnome/gnome.yml", "GNoME"

# MatterSim
mattersim = "mattersim/mattersim-wbm-IS2RE.csv.gz", None, "MatterSim"
mattersim = "mattersim/mattersim-wbm-IS2RE.csv.gz", "mattersim/mattersim.yml", "MatterSim" # fmt: skip

# ORB
orb = "orb/orbff-v1-20240827.csv.gz", None, "ORB"
orb_mptrj = "orb/orbff-mptrj-only-v1-20240827.csv.gz", None, "ORB-MPtrj"
orb = "orb/orbff-v1-20240827.csv.gz", "orb/orb.yml", "ORB"
orb_mptrj = "orb/orbff-mptrj-only-v1-20240827.csv.gz", "orb/orb-mptrj.yml", "ORB MPtrj" # fmt: skip

# --- Model Combos
# # CHGNet-relaxed structures fed into MEGNet for formation energy prediction
Expand Down Expand Up @@ -129,84 +122,60 @@ def load_df_wbm_with_preds(
Returns:
pd.DataFrame: WBM summary dataframe with model predictions.
"""
valid_pred_files = {model.name for model in Model}
valid_models = {model.name for model in Model}
if models == ():
models = tuple(valid_pred_files)
models = tuple(valid_models)
inv_label_map = {v: k for k, v in Model.label_map.items()}
# map pretty model names back to Model enum keys
models = {inv_label_map.get(model, model) for model in models}
if mismatch := ", ".join(models - valid_pred_files):
raise ValueError(
f"Unknown models: {mismatch}, expected subset of {valid_pred_files}"
)

dfs: dict[str, pd.DataFrame] = {}
model: str = ""
try:
for model in (bar := tqdm(models, disable=not pbar, desc="Loading preds")):
bar.set_postfix_str(model)
pred_file = Model[model]
df_preds = glob_to_df(pred_file.path, pbar=False, **kwargs)
dfs[pred_file.label] = df_preds.set_index(id_col)
except Exception as exc:
exc.add_note(f"Failed to load {model=}")
raise
if unknown_models := ", ".join(models - valid_models):
raise ValueError(f"{unknown_models=}, expected subset of {valid_models}")

model_name: str = ""
from matbench_discovery.data import df_wbm

df_out = df_wbm.copy()
for model_name, df_preds in dfs.items():
model_key = model_name.lower().replace("→", "_").replace(" ", "_")

cols = [
col
for col in df_preds
if col.startswith((f"e_form_per_atom_{model_key}", f"e_{model_key}_"))
]
if model_key == "mace_0b":
df_out[model_name] = df_preds["e_form_per_atom_mace"]

elif model_key == "orb-mptrj":
df_out[model_name] = df_preds["e_form_per_atom_orb"]

elif cols:
if len(cols) > 1:
print(
f"Warning: multiple pred cols for {model_name=}, using {cols[0]!r} "
f"out of {cols=}"

try:
prog_bar = tqdm(models, disable=not pbar, desc="Loading preds")
for model_name in prog_bar:
prog_bar.set_postfix_str(model_name)
pred_file = Model[model_name]
df_preds = glob_to_df(pred_file.path, pbar=False, **kwargs)

# Get prediction column name from metadata
model_key = getattr(Model, model_name)
model_label = model_key.label
model_yaml_path = f"{ROOT}/models/{model_key.url}"
with open(model_yaml_path) as file:
model_data = yaml.safe_load(file)

pred_col = model_data.get("pred_col")
if not pred_col:
raise ValueError(
f"pred_col not specified for {model_name} in {model_yaml_path}"
)
df_out[model_name] = df_preds[cols[0]]

elif pred_cols := list(df_preds.filter(like="_pred_ens")):
if len(pred_cols) != 1:
raise ValueError(f"{len(pred_cols)=}, expected 1")
df_out[model_name] = df_preds[pred_cols[0]]
if std_cols := list(df_preds.filter(like="_std_ens")):
df_out[f"{model_name}_std"] = df_preds[std_cols[0]]

elif pred_cols := list(df_preds.filter(like=r"_pred_")):
# make sure we average the expected number of ensemble member predictions
if len(pred_cols) != 10:
raise ValueError(f"{len(pred_cols)=}, expected 10")
df_out[model_name] = df_preds[pred_cols].mean(axis=1)

else:
cols = list(df_preds)
msg = f"No pred col for {model_name=}, available {cols=}"
if model_name != model_key:
msg = msg.replace(", ", f" ({model_key=}), ")
raise ValueError(msg)

if max_error_threshold is not None:
if max_error_threshold < 0:
raise ValueError("max_error_threshold must be a positive number")
# Apply centralized model prediction cleaning criterion (see doc string)
bad_mask = (
abs(df_out[model_name] - df_out[MbdKey.e_form_dft])
) > max_error_threshold
df_out.loc[bad_mask, model_name] = pd.NA
n_preds, n_bad = len(df_out[model_name].dropna()), sum(bad_mask)
if n_bad > 0:
print(f"{n_bad:,} of {n_preds:,} unrealistic preds for {model_name}")

if pred_col not in df_preds:
raise ValueError(f"{pred_col=} not found in {pred_file.path}")

df_out[model_label] = df_preds.set_index(id_col)[pred_col]
if max_error_threshold is not None:
if max_error_threshold < 0:
raise ValueError("max_error_threshold must be a positive number")
# Apply centralized model prediction cleaning criterion (see doc string)
bad_mask = (
abs(df_out[model_label] - df_out[MbdKey.e_form_dft])
) > max_error_threshold
df_out.loc[bad_mask, model_label] = pd.NA
n_preds, n_bad = len(df_out[model_label].dropna()), sum(bad_mask)
if n_bad > 0:
print(
f"{n_bad:,} of {n_preds:,} unrealistic preds for {model_name}"
)
except Exception as exc:
exc.add_note(f"Failed to load {model_name=}")
raise

if subset == "uniq_protos":
df_out = df_out.query(Key.uniq_proto)
Expand Down
1 change: 1 addition & 0 deletions models/alignn/alignn.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ requirements:
scikit-learn: 1.2.2
torch: 1.9.0+cu111

pred_col: e_form_per_atom_alignn
openness: OSOD
trained_for_benchmark: true
model_type: GNN
Expand Down
8 changes: 2 additions & 6 deletions models/alignn_ff/alignn-ff.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ requirements:
openness: OSOD
model_type: UIP
trained_for_benchmark: true
# for hyperparams, see alignn-config.json

training_set:
title: MPtrj
url: https://figshare.com/articles/dataset/23713842
n_structures: 1_580_395
n_materials: 145_923
# for other hyperparams, see alignn-config.json
training_set: MPtrj
1 change: 1 addition & 0 deletions models/bowsr/bowsr.yml → models/bowsr/bowsr+megnet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ requirements:
numpy: 1.24.0
pandas: 1.5.1

pred_col: e_form_per_atom_bowsr_megnet
openness: OSOD
trained_for_benchmark: false
train_task: RS2RE
Expand Down
File renamed without changes.
Binary file modified models/cgcnn/2023-01-26-cgcnn-ens=10-wbm-IS2RE.csv.gz
Binary file not shown.
1 change: 1 addition & 0 deletions models/cgcnn/cgcnn+p.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ requirements:
numpy: 1.24.0
pandas: 1.5.1

pred_col: e_form_per_atom_cgcnn_pred_ens
openness: OSOD
trained_for_benchmark: true
train_task: S2RE
Expand Down
1 change: 1 addition & 0 deletions models/cgcnn/cgcnn.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ requirements:
numpy: 1.24.0
pandas: 1.5.1

pred_col: e_form_per_atom_mp2020_corrected_pred_ens
openness: OSOD
trained_for_benchmark: true
train_task: RS2RE
Expand Down
7 changes: 2 additions & 5 deletions models/chgnet/chgnet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ requirements:
pymatgen: 2022.10.22
numpy: 1.24.0

pred_col: e_form_per_atom_chgnet
openness: OSOD
trained_for_benchmark: false
train_task: S2EFSM
Expand All @@ -48,11 +49,7 @@ model_type: UIP
model_params: 412_525
n_estimators: 1

training_set:
title: MPtrj
url: https://figshare.com/articles/dataset/23713842
n_structures: 1_580_395
n_materials: 145_923
training_set: MPtrj

hyperparams:
max_steps: 2000
Expand Down
File renamed without changes.
1 change: 1 addition & 0 deletions models/gnome/gnome.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ requirements:
numpy: 1.26.2
pymatgen: 2023.11.12

pred_col: e_gnome_after_relax
openness: OSCD
trained_for_benchmark: false
train_task: S2EFS
Expand Down
File renamed without changes.
1 change: 1 addition & 0 deletions models/m3gnet/m3gnet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ requirements:
numpy: 1.24.0
pandas: 1.5.1

pred_col: e_form_per_atom_m3gnet
openness: OSOD
trained_for_benchmark: false
train_task: S2EFS
Expand Down
File renamed without changes.
7 changes: 2 additions & 5 deletions models/mace/mace.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ requirements:
pymatgen: 2023.7.14
numpy: 1.25.0

pred_col: e_form_per_atom_mace
openness: OSOD
trained_for_benchmark: true
train_task: S2EFS
Expand All @@ -52,11 +53,7 @@ model_type: UIP
model_params: 4_688_656 # 2023-12-03-mace-128-L1: https://tinyurl.com/y7uhwpje
n_estimators: 1

training_set:
title: MPtrj
url: https://figshare.com/articles/dataset/23713842
n_structures: 1_580_395
n_materials: 145_923
training_set: MPtrj

hyperparams:
max_force: 0.05
Expand Down
1 change: 1 addition & 0 deletions models/mattersim/mattersim.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ requirements:
numpy: 1.26.2
pymatgen: 2024.5.1

pred_col: e_form_per_atom_mattersim
openness: CSCD
trained_for_benchmark: false
train_task: S2EFS
Expand Down
1 change: 1 addition & 0 deletions models/megnet/megnet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ requirements:
numpy: 1.24.0
pandas: 1.5.1

pred_col: e_form_per_atom_megnet
openness: OSOD
trained_for_benchmark: false
train_task: RS2RE
Expand Down
File renamed without changes.
3 changes: 2 additions & 1 deletion models/orb/orb-mptrj.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
model_name: ORB-MPtrj
model_name: ORB MPtrj
model_version: v1
matbench_discovery_version: 1.2.0
date_added: "2024-09-02" # required
Expand Down Expand Up @@ -30,6 +30,7 @@ url: "#" # placeholder
doi: "#" # placeholder
paper: "#" # placeholder

pred_col: e_form_per_atom_orb
openness: OSOD
trained_for_benchmark: true
train_task: S2EFS
Expand Down
1 change: 1 addition & 0 deletions models/orb/orb.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ url: "#" # placeholder
doi: "#" # placeholder
paper: "#" # placeholder

pred_col: e_form_per_atom_orb
openness: OSCD
trained_for_benchmark: false
train_task: S2EFS
Expand Down
File renamed without changes.
Loading

0 comments on commit a614864

Please sign in to comment.