Skip to content

Commit

Permalink
Extend model metadata schema (#86)
Browse files Browse the repository at this point in the history
* convert model-metadata-schema from .json to .yaml

make keys targets, test_task, train_task required

* add train_size model_params openness to Key enum

* fix loading multiple metadata files per model dir in models.py

* move TS defs into separate src/lib/types.ts file

* package.json add dev dep json-schema-to-typescript to auto-generate model-metadata.d.ts

* mv model-metadata-schema.y(a->'')ml

* fix load() util for model checkpoint loading

discovered broken for alignn_checkpoint

* add keys train_task test_task targets model_type model_params to all model metadata files

* add model_params to wandb-collected run params in all test scripts

* rename DictableStrEnum->LabelEnum, add ability to set key labels and add labels to all Keys

* set additionalProperties: false on model_metadata_schema.yml

fix typo in mace.yml + gnome.yml
remove non-parametric option for model_params

* add model params col to metrics table

refactor make_metrics_tables.py to load all metadata from model yaml files

* change MP v2022.10.28 model training set titles

* fix model openness toggle, ModelCard remove DOI from model links
  • Loading branch information
janosh authored Feb 6, 2024
1 parent 4d40cd6 commit e7a5abe
Show file tree
Hide file tree
Showing 39 changed files with 655 additions and 517 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,5 @@ repos:
hooks:
- id: check-jsonschema
files: ^models/(.+)/\1.*\.yml$
args: [--schemafile, tests/model-metadata-schema.json]
args: [--schemafile, tests/model-metadata-schema.yml]
- id: check-github-actions
1 change: 1 addition & 0 deletions matbench_discovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Key,
Model,
ModelType,
Open,
Quantity,
Targets,
Task,
Expand Down
2 changes: 2 additions & 0 deletions matbench_discovery/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def load(
if ".pkl" in file_path: # handle key='mp_patched_phase_diagram' separately
with gzip.open(cache_path, "rb") as zip_file:
return pickle.load(zip_file)
if ".pth" in file_path: # handle model checkpoints (e.g. key='alignn_checkpoint')
return cache_path

csv_ext = (".csv", ".csv.gz", ".csv.bz2")
reader = pd.read_csv if file_path.endswith(csv_ext) else pd.read_json
Expand Down
113 changes: 66 additions & 47 deletions matbench_discovery/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,71 +6,90 @@
from pymatviz.utils import styled_html_tag


class DictableStrEnum(StrEnum):
"""StrEnum with optional description attributes and dict() method."""
class LabelEnum(StrEnum):
"""StrEnum with optional label and description attributes plus dict() method."""

def __new__(cls, val: str, desc: str | None = None) -> Self:
def __new__(
cls, val: str, label: str | None = None, desc: str | None = None
) -> Self:
"""Create a new class."""
member = str.__new__(cls, val)
member._value_ = val
member.__dict__["desc"] = desc
member.__dict__ |= dict(label=label, desc=desc)
return member

@property
def label(self) -> str:
"""Make label read-only."""
return self.__dict__["label"]

@property
def description(self) -> str:
"""Make description read-only."""
return self.__dict__["desc"]

@classmethod
def dict(cls) -> dict[str, str]:
"""Return the enum as a dictionary."""
"""Return the Enum as dictionary."""
return {key: str(val) for key, val in cls.__members__.items()}


@unique
class Key(DictableStrEnum):
class Key(LabelEnum):
"""Keys used to access dataframes columns."""

arity = "arity"
bandgap_pbe = "bandgap_pbe"
chem_sys = "chemical_system"
composition = "composition"
cse = "computed_structure_entry"
dft_energy = "uncorrected_energy"
e_form = "e_form_per_atom_mp2020_corrected"
e_form_pred = "e_form_per_atom_pred"
e_form_raw = "e_form_per_atom_uncorrected"
e_form_wbm = "e_form_per_atom_wbm"
each = "energy_above_hull" # as returned by MP API
each_pred = "e_above_hull_pred"
each_true = "e_above_hull_mp2020_corrected_ppd_mp"
each_wbm = "e_above_hull_wbm"
final_struct = "relaxed_structure"
forces = "forces"
form_energy = "formation_energy_per_atom"
formula = "formula"
init_struct = "initial_structure"
magmoms = "magmoms"
mat_id = "material_id"
model_mean_each = "Mean prediction all models"
model_mean_err = "Mean error all models"
model_std_each = "Std. dev. over models"
n_sites = "n_sites"
site_nums = "site_nums"
spacegroup = "spacegroup"
stress = "stress"
stress_trace = "stress_trace"
struct = "structure"
task_id = "task_id"
arity = "arity", "Arity"
bandgap_pbe = "bandgap_pbe", "PBE Band Gap"
chem_sys = "chemical_system", "Chemical System"
composition = "composition", "Composition"
cse = "computed_structure_entry", "Computed Structure Entry"
dft_energy = "uncorrected_energy", "DFT Energy"
e_form = "e_form_per_atom_mp2020_corrected", "DFT E_form"
e_form_pred = "e_form_per_atom_pred", "Predicted E_form"
e_form_raw = "e_form_per_atom_uncorrected", "DFT E_form raw"
e_form_wbm = "e_form_per_atom_wbm", "WBM E_form"
each = "energy_above_hull", "E<sub>hull dist</sub>"
each_pred = "e_above_hull_pred", "Predicted E<sub>hull dist</sub>"
each_true = "e_above_hull_mp2020_corrected_ppd_mp", "E<sub>MP hull dist</sub>"
each_wbm = "e_above_hull_wbm", "E<sub>WBM hull dist</sub>"
final_struct = "relaxed_structure", "Relaxed Structure"
forces = "forces", "Forces"
form_energy = "formation_energy_per_atom", "Formation Energy (eV/atom)"
formula = "formula", "Formula"
init_struct = "initial_structure", "Initial Structure"
magmoms = "magmoms", "Magnetic Moments"
mat_id = "material_id", "Material ID"
model_mean_each = "mean_pred_models", "Mean prediction all models"
model_mean_err = "each_err_models", "Mean E<sub>hull dist</sub> error all models"
model_std_each = "each_std_models", "Std. dev. all models"
n_sites = "n_sites", "Number of Sites"
site_nums = "site_nums", "Site Numbers", "Atomic numbers for each crystal site"
spacegroup = "spacegroup", "Spacegroup"
stress = "stress", "Stress"
stress_trace = "stress_trace", "Stress Trace"
struct = "structure", "Structure"
task_id = "task_id", "Task ID"
task_type = "task_type", "Task Type"
train_task = "train_task", "Training Task"
test_task = "test_task", "Test Task"
targets = "targets", "Targets"
# lowest WBM structures for a given prototype that isn't already in MP
uniq_proto = "unique_prototype"
volume = "volume"
wyckoff = "wyckoff_spglib" # relaxed structure Aflow label
init_wyckoff = "wyckoff_spglib_initial_structure" # initial structure Aflow label
uniq_proto = "unique_prototype", "Unique Prototype"
volume = "volume", "Volume (ų)"
wyckoff = "wyckoff_spglib", "Aflow-Wyckoff Label" # relaxed structure Aflow label
init_wyckoff = (
"wyckoff_spglib_initial_structure",
"Aflow-Wyckoff Label Initial Structure",
)
# number of structures in a model's training set
train_size = "train_size", "Training Size"
model_params = "model_params", "Model Params" # model's parameter count
model_type = "model_type", "Model Type" # architecture category of the model (see ModelType enum)
openness = "openness", "Openness" # openness of data and code for a model


@unique
class Task(DictableStrEnum):
class Task(LabelEnum):
"""Thermodynamic stability prediction task types."""

IS2RE = "IS2RE", "initial structure to relaxed energy"
Expand All @@ -86,7 +105,7 @@ class Task(DictableStrEnum):


@unique
class Targets(DictableStrEnum):
class Targets(LabelEnum):
"""Target quantities predicted by a model (e.g. E = energy)."""

E = "E", "energy"
Expand All @@ -95,7 +114,7 @@ class Targets(DictableStrEnum):


@unique
class ModelType(DictableStrEnum):
class ModelType(LabelEnum):
"""Model types."""

GNN = "GNN", "graph neural network"
Expand All @@ -107,7 +126,7 @@ class ModelType(DictableStrEnum):


@unique
class Open(DictableStrEnum):
class Open(LabelEnum):
"""Openness of data and code for a model."""

OSOD = "OSOD", "open source, open data"
Expand All @@ -121,7 +140,7 @@ class Open(DictableStrEnum):
)


class Quantity(DictableStrEnum):
class Quantity(LabelEnum):
"""Quantity labels for plots."""

n_atoms = "Atom Count"
Expand All @@ -147,7 +166,7 @@ class Quantity(DictableStrEnum):
stress_trace = "1/3 Tr(σ) (eV/ų)" # noqa: RUF001


class Model(DictableStrEnum):
class Model(LabelEnum):
"""Model labels for plots."""

alignn_ff = "ALIGNN FF"
Expand Down
24 changes: 12 additions & 12 deletions matbench_discovery/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@
md_files = glob(f"{model_dir}*.yml")
if not 1 <= len(md_files) <= 2:
raise RuntimeError(f"expected 1 metadata file, got {md_files=} in {model_dir=}")
md_file = md_files[0]
if md_file.endswith("aborted.yml"):
continue
# make sure all required keys are non-empty
with open(md_file) as yml_file:
models = yaml.full_load(yml_file)
for md_file in md_files:
if md_file.endswith("aborted.yml"):
continue
# make sure all required keys are non-empty
with open(md_file) as yml_file:
models = yaml.full_load(yml_file)

# some metadata files contain a single model, some have multiple
if not isinstance(models, list):
models = [models]
for model in models:
model["model_dir"] = model_dir
MODEL_METADATA[model["model_name"]] = model
# some metadata files contain a single model, some have multiple
if not isinstance(models, list):
models = [models]
for model in models:
model["model_dir"] = model_dir
MODEL_METADATA[model["model_name"]] = model
9 changes: 7 additions & 2 deletions models/alignn/alignn.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,15 @@ requirements:
scikit-learn: 1.2.2
torch: 1.9.0+cu111
trained_for_benchmark: true
# hyperparams: see align-config.json
model_type: GNN
train_task: RS2RE
test_task: IS2RE
targets: E
model_params: 4_026_753 # pre-trained 'mp_e_form_alignn' and our custom MBD checkpoint have the same size
# other hyperparams: see align-config.json

training_set: # model trained specifically for MBD
title: MP Computed Structure Entries
title: MP v2022.10.28
url: https://figshare.com/ndownloader/files/40344436
n_structures: 154_719
# training_set: # NIST pre-trained model
Expand Down
26 changes: 14 additions & 12 deletions models/alignn/test_alignn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@


# %%
model_name = "mp_e_form_alignn" # pre-trained by NIST
# TODO fix this to load checkpoint from figshare
# model_name = f"{module_dir}/data-train-result/best-model.pth"
# model_name = "mp_e_form_alignn" # pre-trained by NIST (not used for MBD submission)
model_name = DATA_FILES.alignn_checkpoint # trained by Philipp Benner
task_type = Task.IS2RE
target_col = Key.e_form
input_col = Key.init_struct
Expand Down Expand Up @@ -93,15 +92,18 @@


# %%
run_params = dict(
data_path=data_path,
versions={dep: version(dep) for dep in ("megnet", "numpy")},
model_name=model_name,
task_type=task_type,
target_col=target_col,
df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),
slurm_vars=slurm_vars,
)
run_params = {
"data_path": data_path,
"versions": {dep: version(dep) for dep in ("megnet", "numpy")},
"model_name": model_name,
"task_type": task_type,
"target_col": target_col,
"df": {"shape": str(df_in.shape), "columns": ", ".join(df_in)},
"slurm_vars": slurm_vars,
Key.model_params: sum( # count trainable params
p.numel() for p in model.parameters() if p.requires_grad
),
}

wandb.init(project="matbench-discovery", name=job_name, config=run_params)

Expand Down
2 changes: 1 addition & 1 deletion models/alignn/train_alignn.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def df_to_loader(
**kwargs: Additional arguments to pass to the StructureDataset
Returns:
DataLoader: _description_
DataLoader: PyTorch data loader
"""
graphs = load_graphs(
df, neighbor_strategy=config.neighbor_strategy, use_canonize=config.use_canonize
Expand Down
5 changes: 5 additions & 0 deletions models/bowsr/bowsr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ requirements:
numpy: 1.24.0
pandas: 1.5.1
trained_for_benchmark: false
train_task: RS2RE
test_task: IS2RE-BO
targets: E
model_type: BO-GNN
model_params: 167_761

training_set:
title: Graphs of MP 2019
Expand Down
25 changes: 14 additions & 11 deletions models/bowsr/test_bowsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,26 @@
seed=42,
)
optimize_kwargs = dict(n_init=100, n_iter=100, alpha=0.026**2)
model = MEGNet()

run_params = dict(
bayes_optim_kwargs=bayes_optim_kwargs,
data_path=data_path,
df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),
energy_model=energy_model,
versions={dep: version(dep) for dep in ("maml", "numpy", energy_model)},
optimize_kwargs=optimize_kwargs,
task_type=task_type,
slurm_vars=slurm_vars,
)

# %%
run_params = {
"bayes_optim_kwargs": bayes_optim_kwargs,
"data_path": data_path,
"df": {"shape": str(df_in.shape), "columns": ", ".join(df_in)},
"energy_model": energy_model,
"versions": {dep: version(dep) for dep in ("maml", "numpy", energy_model)},
"optimize_kwargs": optimize_kwargs,
"task_type": task_type,
"slurm_vars": slurm_vars,
Key.model_params: sum(np.prod(p.shape) for p in model.model.trainable_weights),
}

wandb.init(project="matbench-discovery", name=job_name, config=run_params)


# %%
model = MEGNet()
relax_results: dict[str, dict[str, Any]] = {}
input_col = {Task.IS2RE: Key.init_struct, Task.RS2RE: Key.final_struct}[task_type]

Expand Down
9 changes: 7 additions & 2 deletions models/cgcnn/cgcnn+p.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@ requirements:
numpy: 1.24.0
pandas: 1.5.1
trained_for_benchmark: true
train_task: S2RE
test_task: IS2RE
targets: E
model_type: GNN
model_params: 128_450

training_set:
title: MP Computed Structure Entries
title: MP v2022.10.28
url: https://figshare.com/ndownloader/files/40344436
n_structures: 154719
n_structures: 154_719

hyperparams:
Ensemble Size: 10
Expand Down
7 changes: 6 additions & 1 deletion models/cgcnn/cgcnn.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,14 @@ requirements:
numpy: 1.24.0
pandas: 1.5.1
trained_for_benchmark: true
train_task: RS2RE
test_task: IS2E
targets: E
model_type: GNN
model_params: 128_450

training_set:
title: MP Computed Structure Entries
title: MP v2022.10.28
url: https://figshare.com/ndownloader/files/40344436
n_structures: 154_719

Expand Down
Loading

0 comments on commit e7a5abe

Please sign in to comment.