Add MACE (#48)
* add scripts/model_figs/update_all_model_figs.py

The goal is to be able to run a single script that updates all the model figures
(both the interactive versions for the leaderboard and the PDF versions for the paper).

* prepare MACE submission (readme.md, metadata.yml, test_mace.py)

* fix model name still set to chgnet in test_mace.py

* wandb: collect all dep versions in a single dict

* delete matbench_discovery.DEBUG global

* roc_prc_curves_models.py: fix n_rows x n_cols in output filename

* test_mace.py add relax trajectory recording

* add MACE + ALIGNN checkpoints figshare urls to class DataFiles

* rename BOWSR + MEGNet -> BOWSR

* ensure out_path matches glob_pattern in join scripts

* load_df_wbm_with_preds(): use first matching df column

* refactor df_to_pdf() from wkhtmltopdf to weasyprint

* improve Figshare description

* update most figures with MACE results

* extract scripts/model_figs/per_element_errors.py out of scripts/analyze_element_errors.py

so the former can be run as part of update_model_figs.py

* add MACE to site/src/figs/model-run-times-bar.svelte

* revert 'extract scripts/model_figs/per_element_errors.py out of scripts/analyze_element_errors.py'

* add weasyprint to df-pdf-export optional deps

* fix test_df_metrics failing because MACE's R^2 of -1.291 is below the -0.9 threshold

* fix df_to_pdf if weasyprint not installed
janosh authored Jul 26, 2023
1 parent 6696d22 commit a549532
Showing 78 changed files with 1,264 additions and 1,004 deletions.
8 changes: 8 additions & 0 deletions data/figshare/1.0.0.json
@@ -1,4 +1,12 @@
{
"alignn_checkpoint": [
"https://figshare.com/ndownloader/files/41233560",
"2023-06-02-pbenner-best-alignn-model.pth.zip"
],
"mace_checkpoint": [
"https://figshare.com/ndownloader/files/41565618",
"2023-07-14-mace-universal-2-big-128-6.model"
],
"mp_computed_structure_entries": [
"https://figshare.com/ndownloader/files/40344436",
"2023-02-07-mp-computed-structure-entries.json.gz"
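For reference, each entry in this JSON maps a key to a [download URL, file name] pair. Below is a minimal stdlib-only sketch of resolving such an entry to a locally cached file; the cache directory is a made-up placeholder and this is not the package's own download logic:

import json
import os
import urllib.request

with open("data/figshare/1.0.0.json") as file:
    figshare_files = json.load(file)  # key -> [download_url, file_name]

url, file_name = figshare_files["mace_checkpoint"]
cache_dir = os.path.expanduser("~/.cache/matbench-discovery")  # hypothetical cache location
os.makedirs(cache_dir, exist_ok=True)
file_path = f"{cache_dir}/{file_name}"

if not os.path.isfile(file_path):  # only download on first use
    urllib.request.urlretrieve(url, file_path)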
2 changes: 1 addition & 1 deletion data/mp/build_phase_diagram.py
@@ -43,7 +43,7 @@
df = pd.read_json(data_path).set_index("material_id")

# drop the structure, just load ComputedEntry, makes the PPD faster to build and load
mp_computed_entries = [ComputedEntry.from_dict(x) for x in tqdm(df.entry)]
mp_computed_entries = [ComputedEntry.from_dict(dct) for dct in tqdm(df.entry)]

print(f"{len(mp_computed_entries) = :,} on {today}")
# len(mp_computed_entries) = 146,323 on 2022-09-16
5 changes: 3 additions & 2 deletions data/wbm/compare_cse_vs_ce_mp_2020_corrections.py
@@ -28,10 +28,11 @@
)

cses = [
ComputedStructureEntry.from_dict(x) for x in tqdm(df_cse.computed_structure_entry)
ComputedStructureEntry.from_dict(dct)
for dct in tqdm(df_cse.computed_structure_entry)
]

ces = [ComputedEntry.from_dict(x) for x in tqdm(df_cse.computed_structure_entry)]
ces = [ComputedEntry.from_dict(dct) for dct in tqdm(df_cse.computed_structure_entry)]


warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen")
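As a reminder of the serialization pattern used above, pymatgen entries round-trip through plain dicts; a small sketch with toy values (the real dicts come from the WBM computed_structure_entry column):

from pymatgen.entries.computed_entries import ComputedEntry

entry = ComputedEntry("Fe2O3", energy=-67.2)  # toy composition and energy
entry_dict = entry.as_dict()  # JSON-serializable, like the dicts stored in the data files
assert ComputedEntry.from_dict(entry_dict).energy == entry.energy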
7 changes: 4 additions & 3 deletions data/wbm/fetch_process_wbm_dataset.py
@@ -502,7 +502,8 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
assert mat_id == cse["entry_id"], f"{mat_id} != {cse['entry_id']}"

df_wbm["cse"] = [
ComputedStructureEntry.from_dict(x) for x in tqdm(df_wbm.computed_structure_entry)
ComputedStructureEntry.from_dict(dct)
for dct in tqdm(df_wbm.computed_structure_entry)
]
# raw WBM ComputedStructureEntries have no energy corrections applied:
assert all(cse.uncorrected_energy == cse.energy for cse in df_wbm.cse)
@@ -640,6 +641,6 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
).set_index("material_id")

df_wbm["cse"] = [
ComputedStructureEntry.from_dict(x)
for x in tqdm(df_wbm.computed_structure_entry)
ComputedStructureEntry.from_dict(dct)
for dct in tqdm(df_wbm.computed_structure_entry)
]
5 changes: 0 additions & 5 deletions matbench_discovery/__init__.py
@@ -1,7 +1,6 @@
"""Global variables used all across the matbench_discovery package."""

import os
import sys
from datetime import datetime

ROOT = os.path.dirname(os.path.dirname(__file__)) # repo root directory
@@ -13,10 +12,6 @@
for directory in [FIGS, MODELS, FIGSHARE, PDF_FIGS]:
os.makedirs(directory, exist_ok=True)

# whether a currently running slurm job is in debug mode
DEBUG = "DEBUG" in os.environ or (
"slurm-submit" not in sys.argv and "SLURM_JOB_ID" not in os.environ
)
# directory to store model checkpoints downloaded from wandb cloud storage
CHECKPOINT_DIR = f"{ROOT}/wandb/checkpoints"
# wandb <entity>/<project name> to record new runs to
3 changes: 3 additions & 0 deletions matbench_discovery/data.py
@@ -232,6 +232,7 @@ def _on_not_found(self, key: str, msg: str) -> None: # type: ignore[override]
if answer == "y":
load(key) # download and cache data file

# TODO maybe set attrs to None and load file names from Figshare json
mp_computed_structure_entries = (
"mp/2023-02-07-mp-computed-structure-entries.json.gz"
)
@@ -246,6 +247,8 @@ def _on_not_found(self, key: str, msg: str) -> None: # type: ignore[override]
"wbm/2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2"
)
wbm_summary = "wbm/2022-10-19-wbm-summary.csv.gz"
alignn_checkpoint = "2023-06-02-pbenner-best-alignn-model.pth.zip"
mace_checkpoint = "2023-07-14-mace-universal-2-big-128-6.model"


# data files can be downloaded and cached with matbench_discovery.data.load()
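A hedged usage sketch of the new checkpoint attributes: it assumes load() accepts the attribute name as its key (as the _on_not_found handler above suggests), and the exact return value of load() is not shown in this diff:

from matbench_discovery.data import DATA_FILES, load

# path registered for the MACE checkpoint on the DataFiles class
print(DATA_FILES.mace_checkpoint)

# download on first access and cache locally (assumed key and behavior, see note above)
load("mace_checkpoint")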
91 changes: 68 additions & 23 deletions matbench_discovery/plots.py
@@ -4,6 +4,7 @@

import math
import os
import subprocess
from collections import defaultdict
from collections.abc import Sequence
from pathlib import Path
@@ -65,7 +66,7 @@ def unit(text: str) -> str:
model_labels = dict(
alignn="ALIGNN",
alignn_pretrained="ALIGNN Pretrained",
bowsr_megnet="BOWSR + MEGNet",
bowsr_megnet="BOWSR",
chgnet="CHGNet",
chgnet_megnet="CHGNet + MEGNet",
cgcnn_p="CGCNN+P",
@@ -74,6 +75,7 @@ def unit(text: str) -> str:
m3gnet="M3GNet",
m3gnet_direct="M3GNet DIRECT",
m3gnet_ms="M3GNet MS",
mace="MACE",
megnet="MEGNet",
voronoi_rf="Voronoi RF",
wrenformer="Wrenformer",
@@ -874,38 +876,81 @@ def df_to_svelte_table(
def df_to_pdf(
styler: Styler, file_path: str | Path, crop: bool = True, **kwargs: Any
) -> None:
"""Export a pandas Styler to PDF.
"""Export a pandas Styler to PDF with WeasyPrint.
Args:
styler (Styler): Styler object to export.
file_path (str): Path to save the PDF to. Requires pdfkit.
crop (bool): Whether to crop the PDF margins. Requires pdfCropMargins. Defaults
to True.
file_path (str): Path to save the PDF to. Requires WeasyPrint.
crop (bool): Whether to crop the PDF margins. Requires pdfCropMargins.
Defaults to True.
**kwargs: Keyword arguments passed to Styler.to_html().
"""
try:
# pdfkit used to export pandas Styler to PDF, requires:
# pip install pdfkit && brew install homebrew/cask/wkhtmltopdf
import pdfkit
from weasyprint import HTML
except ImportError as exc:
raise ImportError(
"pdfkit not installed\nrun pip install pdfkit && brew install "
"homebrew/cask/wkhtmltopdf\n(brew is macOS only, use apt on linux)"
) from exc

pdfkit.from_string(styler.to_html(**kwargs), file_path)
if not crop:
return
msg = "weasyprint not installed\nrun pip install weasyprint"
raise ImportError(msg) from exc

html_str = styler.to_html(**kwargs)

# CSS to adjust layout and margins
html_str = f"""
<style>
@page {{ size: landscape; margin: 1cm; }}
body {{ margin: 0; padding: 1em; }}
</style>
{html_str}
"""

html = HTML(string=html_str)

html.write_pdf(file_path)

if crop:
normalize_and_crop_pdf(file_path)


def normalize_and_crop_pdf(file_path: str | Path) -> None:
"""Normalize a PDF using Ghostscript and then crop it.
Without gs normalization, pdfCropMargins sometimes corrupts the PDF.
Args:
file_path (str | Path): Path to the PDF file.
"""
try:
# needed to auto-crop large white margins from PDF
# pip install pdfCropMargins
from pdfCropMargins import crop as crop_pdf
normalized_file_path = f"{file_path}_normalized.pdf"
from pdfCropMargins import crop

# Normalize the PDF with Ghostscript
subprocess.run(
[
"gs",
"-sDEVICE=pdfwrite",
"-dCompatibilityLevel=1.4",
"-dPDFSETTINGS=/default",
"-dNOPAUSE",
"-dQUIET",
"-dBATCH",
f"-sOutputFile={normalized_file_path}",
str(file_path),
]
)

# Remove PDF margins
cropped_file_path, _exit_code, _stdout, _stderr = crop_pdf(
["--percentRetain", "0", file_path]
# Crop the normalized PDF
cropped_file_path, exit_code, stdout, stderr = crop(
["--percentRetain", "0", normalized_file_path]
)
os.replace(cropped_file_path, file_path)

if stderr:
print(f"pdfCropMargins {stderr=}")
# something went wrong, remove the cropped PDF
os.remove(cropped_file_path)
else:
# replace the original PDF with the cropped one
os.replace(cropped_file_path, str(file_path))

os.remove(normalized_file_path)

except ImportError as exc:
msg = "pdfCropMargins not installed\nrun pip install pdfCropMargins"
raise ImportError(msg) from exc
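A minimal usage sketch of the refactored df_to_pdf (needs weasyprint; cropping additionally needs pdfCropMargins and a Ghostscript gs binary on PATH). The DataFrame holds placeholder values, not actual benchmark metrics:

import pandas as pd

from matbench_discovery.plots import df_to_pdf

df = pd.DataFrame({"model": ["MACE", "CHGNet"], "F1": [0.5, 0.6]})  # placeholder values
styler = df.style.format(precision=2)

# render styler.to_html() with WeasyPrint, then gs-normalize and crop the resulting PDF
df_to_pdf(styler, "model-metrics.pdf", crop=True)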
14 changes: 12 additions & 2 deletions matbench_discovery/preds.py
@@ -47,6 +47,9 @@ class PredFiles(Files):
# m3gnet_direct = "m3gnet/2023-05-30-m3gnet-direct-wbm-IS2RE.csv.gz"
# m3gnet_ms = "m3gnet/2023-06-01-m3gnet-manual-sampling-wbm-IS2RE.csv.gz"

# MACE trained on original M3GNet training set
mace = "mace/2023-07-23-mace-wbm-IS2RE-FIRE.csv.gz"

# original MEGNet straight from publication, not re-trained
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE.csv.gz"
# CHGNet-relaxed structures fed into MEGNet for formation energy prediction
@@ -106,8 +109,15 @@ def load_df_wbm_with_preds(
df_out = df_wbm.copy()
for model_name, df in dfs.items():
model_key = model_name.lower().replace(" + ", "_").replace(" ", "_")
if (col := f"e_form_per_atom_{model_key}") in df:
df_out[model_name] = df[col]

cols = [col for col in df if col.startswith(f"e_form_per_atom_{model_key}")]
if cols:
if len(cols) > 1:
print(
f"Warning: multiple pred cols for {model_name=}, using {cols[0]!r} "
f"out of {cols=}"
)
df_out[model_name] = df[cols[0]]

elif pred_cols := list(df.filter(like="_pred_ens")):
assert len(pred_cols) == 1
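The column-matching change above in isolation: take the first prediction column whose name starts with the expected prefix and warn when several match. A self-contained toy sketch:

import pandas as pd

df_preds = pd.DataFrame(
    {
        "e_form_per_atom_mace": [0.12, -0.34],
        "e_form_per_atom_mace_corrected": [0.11, -0.35],
    }
)  # toy frame with two candidate prediction columns

model_key = "mace"
cols = [col for col in df_preds if col.startswith(f"e_form_per_atom_{model_key}")]
if len(cols) > 1:
    print(f"Warning: multiple pred cols for {model_key=}, using {cols[0]!r} out of {cols=}")
preds = df_preds[cols[0]]  # first match wins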
10 changes: 5 additions & 5 deletions models/alignn/test_alignn.py
@@ -17,7 +17,7 @@
from sklearn.metrics import r2_score
from tqdm import tqdm

from matbench_discovery import DEBUG, today
from matbench_discovery import today
from matbench_discovery.data import DATA_FILES, df_wbm
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit
@@ -36,7 +36,7 @@
input_col = "initial_structure"
id_col = "material_id"
device = "cuda" if torch.cuda.is_available() else "cpu"
job_name = f"{model_name}-wbm-{task_type}{'-debug' if DEBUG else ''}"
job_name = f"{model_name}-wbm-{task_type}"
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")


@@ -85,15 +85,15 @@
assert input_col in df_in, f"{input_col=} not in {list(df_in)}"

df_in[input_col] = [
JarvisAtomsAdaptor.get_atoms(Structure.from_dict(x))
for x in tqdm(df_in[input_col], leave=False, desc="Converting to JARVIS atoms")
JarvisAtomsAdaptor.get_atoms(Structure.from_dict(dct))
for dct in tqdm(df_in[input_col], leave=False, desc="Converting to JARVIS atoms")
]


# %%
run_params = dict(
data_path=data_path,
**{f"{dep}_version": version(dep) for dep in ("megnet", "numpy")},
versions={dep: version(dep) for dep in ("megnet", "numpy")},
model_name=model_name,
task_type=task_type,
target_col=target_col,
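The "collect all dep versions in single dict" change replaces many flat {dep}_version keys with one nested mapping in the wandb run params; a standalone sketch:

from importlib.metadata import version

# before: **{f"{dep}_version": version(dep) for dep in ("numpy", "torch")}
# after: one nested dict logged as a single run param
versions = {dep: version(dep) for dep in ("numpy", "torch")}
print(versions)  # e.g. {'numpy': '1.26.4', 'torch': '2.1.0'}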
8 changes: 4 additions & 4 deletions models/alignn/train_alignn.py
@@ -18,7 +18,7 @@
from torch.utils.data import DataLoader
from tqdm import tqdm

from matbench_discovery import DEBUG, today
from matbench_discovery import today
from matbench_discovery.data import DATA_FILES
from matbench_discovery.slurm import slurm_submit

@@ -35,7 +35,7 @@
input_col = "atoms"
id_col = "material_id"
device = "cuda" if torch.cuda.is_available() else "cpu"
job_name = f"train-{model_name}{'-debug' if DEBUG else ''}"
job_name = f"train-{model_name}"


pred_col = "e_form_per_atom_alignn"
@@ -49,7 +49,7 @@
slurm_vars = slurm_submit(
job_name=job_name,
# partition="perlmuttter",
account="matgen_g",
account="matgen",
time="4:0:0",
out_dir=out_dir,
slurm_flags="--qos regular --constraint gpu --gpus 1",
@@ -79,7 +79,7 @@
# %%
run_params = dict(
data_path=DATA_FILES.mp_energies,
**{f"{dep}_version": version(dep) for dep in ("alignn", "numpy", "torch", "dgl")},
versions={dep: version(dep) for dep in ("alignn", "numpy", "torch", "dgl")},
model_name=model_name,
target_col=target_col,
df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),
3 changes: 1 addition & 2 deletions models/bowsr/join_bowsr_results.py
@@ -8,7 +8,6 @@
import pymatviz
from tqdm import tqdm

from matbench_discovery import today
from matbench_discovery.data import DATA_FILES

__author__ = "Janosh Riebesell"
@@ -66,7 +65,7 @@


# %%
out_path = f"{module_dir}/{today}-bowsr-megnet-wbm-{task_type}"
out_path = f"{module_dir}/{glob_pattern.split('/*')[0]}"
df_bowsr = df_bowsr.round(4)
# save energy and formation energy as fast-loading CSV
df_bowsr.select_dtypes("number").to_csv(f"{out_path}.csv")
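The out_path fix derives the output file name from glob_pattern so the two cannot drift apart. A toy illustration of the string handling; the glob_pattern value here is hypothetical (the real one is defined earlier in join_bowsr_results.py, outside this hunk):

module_dir = "models/bowsr"  # placeholder
glob_pattern = "2023-03-20-bowsr-megnet-wbm-IS2RE/*.json.gz"  # hypothetical value
out_path = f"{module_dir}/{glob_pattern.split('/*')[0]}"
print(out_path)  # models/bowsr/2023-03-20-bowsr-megnet-wbm-IS2RE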
2 changes: 1 addition & 1 deletion models/bowsr/metadata.yml
@@ -1,4 +1,4 @@
model_name: BOWSR + MEGNet
model_name: BOWSR
model_version: 2022.9.20
matbench_discovery_version: 1.0
date_added: "2022-11-17"