# Patch summary (commentary only; text before the first "diff --git" header is not part of any hunk).
#
# 1. matbench_discovery/__init__.py: removes the in-module definitions of DictableStrEnum, Key, Task,
#    Targets, ModelType, Open, ev_per_atom, Quantity and Model, and drops the imports they needed
#    (StrEnum, unique, Self, styled_html_tag). It instead imports Key, Model, ModelType, Quantity and
#    Task from matbench_discovery.enums (marked noqa: F401). Quantity and Model are still used by the
#    unchanged context line `px.defaults.labels = Quantity.dict() | Model.dict()`.
# 2. matbench_discovery/enums.py: new file that receives the removed definitions with the same
#    members and values as the lines deleted from __init__.py.
# 3. matbench_discovery/preds.py: PredFiles.voronoi_rf now points at
#    "voronoi_rf/2022-11-27-train-test/..." instead of "voronoi/2022-11-27-train-test/...".
#
# NOTE(review): after this patch Targets, Open, DictableStrEnum and ev_per_atom are no longer bound in
#   the package root namespace (only five names are imported back) - confirm no caller does
#   `from matbench_discovery import Targets` / `Open` before merging.
# NOTE(review): the Targets docstring ("Thermodynamic stability prediction task types.") is identical
#   to the Task docstring; it is carried over unchanged by the move and looks like a copy-paste.
# NOTE(review): DictableStrEnum.description is annotated `-> str`, but `desc` defaults to None in
#   __new__, so members declared without a description return None.
# NOTE(review): Quantity and Model are not decorated with @unique; Quantity.e_form and
#   Quantity.e_form_per_atom_mp2020_corrected share the same value string.
diff --git a/matbench_discovery/__init__.py b/matbench_discovery/__init__.py index dba41a7f..82b6094d 100644 --- a/matbench_discovery/__init__.py +++ b/matbench_discovery/__init__.py @@ -6,14 +6,13 @@ import os import warnings from datetime import datetime -from enum import StrEnum, unique from importlib.metadata import Distribution -from typing import Self import matplotlib.pyplot as plt import plotly.express as px import plotly.io as pio -from pymatviz.utils import styled_html_tag + +from matbench_discovery.enums import Key, Model, ModelType, Quantity, Task # noqa: F401 pkg_name = "matbench-discovery" direct_url = Distribution.from_name(pkg_name).read_text("direct_url.json") or "{}" @@ -57,177 +56,10 @@ warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen") -class DictableStrEnum(StrEnum): - """StrEnum with optional description attributes and dict() method.""" - - def __new__(cls, val: str, desc: str | None = None) -> Self: - """Create a new class.""" - member = str.__new__(cls, val) - member._value_ = val - member.__dict__["desc"] = desc - return member - - @property - def description(self) -> str: - """Make description read-only.""" - return self.__dict__["desc"] - - @classmethod - def dict(cls) -> dict[str, str]: - """Return the enum as a dictionary.""" - return {key: str(val) for key, val in cls.__members__.items()} - - -@unique -class Key(DictableStrEnum): - """Keys used to access dataframes columns.""" - - arity = "arity" - bandgap_pbe = "bandgap_pbe" - chem_sys = "chemical_system" - composition = "composition" - cse = "computed_structure_entry" - dft_energy = "uncorrected_energy" - e_form = "e_form_per_atom_mp2020_corrected" - e_form_pred = "e_form_per_atom_pred" - e_form_raw = "e_form_per_atom_uncorrected" - e_form_wbm = "e_form_per_atom_wbm" - each = "energy_above_hull" # as returned by MP API - each_pred = "e_above_hull_pred" - each_true = "e_above_hull_mp2020_corrected_ppd_mp" - each_wbm = "e_above_hull_wbm" - final_struct = 
"relaxed_structure" - forces = "forces" - form_energy = "formation_energy_per_atom" - formula = "formula" - init_struct = "initial_structure" - magmoms = "magmoms" - mat_id = "material_id" - model_mean_each = "Mean prediction all models" - model_mean_err = "Mean error all models" - model_std_each = "Std. dev. over models" - n_sites = "n_sites" - site_nums = "site_nums" - spacegroup = "spacegroup" - stress = "stress" - stress_trace = "stress_trace" - struct = "structure" - task_id = "task_id" - # lowest WBM structures for a given prototype that isn't already in MP - uniq_proto = "unique_prototype" - volume = "volume" - wyckoff = "wyckoff_spglib" # relaxed structure Aflow label - init_wyckoff = "wyckoff_spglib_initial_structure" # initial structure Aflow label - - -@unique -class Task(DictableStrEnum): - """Thermodynamic stability prediction task types.""" - - IS2RE = "IS2RE", "initial structure to relaxed energy" - RS2RE = "RS2RE", "relaxed structure to relaxed energy" - S2EFSM = "S2EFSM", "structure to energy force stress magmom" - S2EFS = "S2EFS", "structure to energy force stress" - # S2RE is for models that learned a discrete version of PES like CGCNN+P - S2RE = "S2RE", "structure to relaxed energy" - RP2RE = "RP2RE", "relaxed prototype to relaxed energy" - IP2RE = "IP2RE", "initial prototype to relaxed energy" - IS2E = "IS2E", "initial structure to energy" - IS2RE_SR = "IS2RE-SR", "initial structure to relaxed energy after ML relaxation" - - -@unique -class Targets(DictableStrEnum): - """Thermodynamic stability prediction task types.""" - - E = "E", "energy" - EFS = "EFS", "energy forces stress" - EFSM = "EFSM", "energy forces stress magmoms" - - -@unique -class ModelType(DictableStrEnum): - """Model types.""" - - GNN = "GNN", "graph neural network" - UIP = "UIP-GNN", "universal interatomic potential" - BO_GNN = "BO-GNN", "GNN in a Bayesian optimization loop" - Fingerprint = "Fingerprint", "models with structural fingerprint features" # ex. 
RF - Transformer = "Transformer", "transformer-based models" # ex. Wrenformer - RF = "RF", "random forest" - - -@unique -class Open(DictableStrEnum): - """Openness of data and code for a model.""" - - OSOD = "OSOD", "open source, open data" - CSOD = "CSOD", "closed source, open data" - OSCD = "OSCD", "open source, closed data" - CSCD = "CSCD", "closed source, closed data" - - with open(f"{FIGSHARE_DIR}/1.0.0.json") as file: FIGSHARE_URLS = json.load(file) # --- start global plot settings -ev_per_atom = styled_html_tag( - "(eV/atom)", tag="span", style="font-size: 0.8em; font-weight: lighter;" -) - - -class Quantity(DictableStrEnum): - """Quantity labels for plots.""" - - n_atoms = "Atom Count" - n_elems = "Element Count" - crystal_sys = "Crystal system" - spg_num = "Space group" - n_wyckoff = "Number of Wyckoff positions" - n_sites = "Number of atoms" - energy_per_atom = f"Energy {ev_per_atom}" - e_form = f"DFT Eform {ev_per_atom}" - e_above_hull = f"Ehull dist {ev_per_atom}" - e_above_hull_mp2020_corrected_ppd_mp = f"DFT Ehull dist {ev_per_atom}" - e_above_hull_pred = f"Predicted Ehull dist {ev_per_atom}" - e_above_hull_mp = f"Eabove MP hull {ev_per_atom}" - e_above_hull_error = f"Error in Ehull dist {ev_per_atom}" - vol_diff = "Volume difference (A^3)" - e_form_per_atom_mp2020_corrected = f"DFT Eform {ev_per_atom}" - e_form_per_atom_pred = f"Predicted Eform {ev_per_atom}" - material_id = "Material ID" - band_gap = "Band gap (eV)" - formula = "Formula" - stress = "σ (eV/ų)" # noqa: RUF001 - stress_trace = "1/3 Tr(σ) (eV/ų)" # noqa: RUF001 - - -class Model(DictableStrEnum): - """Model labels for plots.""" - - alignn_ff = "ALIGNN FF" - alignn_pretrained = "ALIGNN Pretrained" - alignn = "ALIGNN" - bowsr_megnet = "BOWSR" - cgcnn_p = "CGCNN+P" - cgcnn = "CGCNN" - chgnet_megnet = "CHGNet→MEGNet" - chgnet = "CHGNet" - dft = "DFT" - gnome = "GNoME" - m3gnet_direct = "M3GNet DIRECT" - m3gnet_megnet = "M3GNet→MEGNet" - m3gnet_ms = "M3GNet MS" - m3gnet = "M3GNet" - mace = 
"MACE" - megnet_rs2re = "MEGNet RS2RE" - megnet = "MEGNet" - pfp = "PFP" - voronoi_rf = "Voronoi RF" - wbm = "WBM" - wrenformer = "Wrenformer" - - px.defaults.labels = Quantity.dict() | Model.dict() diff --git a/matbench_discovery/enums.py b/matbench_discovery/enums.py new file mode 100644 index 00000000..4123e3f9 --- /dev/null +++ b/matbench_discovery/enums.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from enum import StrEnum, unique +from typing import Self + +from pymatviz.utils import styled_html_tag + + +class DictableStrEnum(StrEnum): + """StrEnum with optional description attributes and dict() method.""" + + def __new__(cls, val: str, desc: str | None = None) -> Self: + """Create a new class.""" + member = str.__new__(cls, val) + member._value_ = val + member.__dict__["desc"] = desc + return member + + @property + def description(self) -> str: + """Make description read-only.""" + return self.__dict__["desc"] + + @classmethod + def dict(cls) -> dict[str, str]: + """Return the enum as a dictionary.""" + return {key: str(val) for key, val in cls.__members__.items()} + + +@unique +class Key(DictableStrEnum): + """Keys used to access dataframes columns.""" + + arity = "arity" + bandgap_pbe = "bandgap_pbe" + chem_sys = "chemical_system" + composition = "composition" + cse = "computed_structure_entry" + dft_energy = "uncorrected_energy" + e_form = "e_form_per_atom_mp2020_corrected" + e_form_pred = "e_form_per_atom_pred" + e_form_raw = "e_form_per_atom_uncorrected" + e_form_wbm = "e_form_per_atom_wbm" + each = "energy_above_hull" # as returned by MP API + each_pred = "e_above_hull_pred" + each_true = "e_above_hull_mp2020_corrected_ppd_mp" + each_wbm = "e_above_hull_wbm" + final_struct = "relaxed_structure" + forces = "forces" + form_energy = "formation_energy_per_atom" + formula = "formula" + init_struct = "initial_structure" + magmoms = "magmoms" + mat_id = "material_id" + model_mean_each = "Mean prediction all models" + model_mean_err = "Mean error 
all models" + model_std_each = "Std. dev. over models" + n_sites = "n_sites" + site_nums = "site_nums" + spacegroup = "spacegroup" + stress = "stress" + stress_trace = "stress_trace" + struct = "structure" + task_id = "task_id" + # lowest WBM structures for a given prototype that isn't already in MP + uniq_proto = "unique_prototype" + volume = "volume" + wyckoff = "wyckoff_spglib" # relaxed structure Aflow label + init_wyckoff = "wyckoff_spglib_initial_structure" # initial structure Aflow label + + +@unique +class Task(DictableStrEnum): + """Thermodynamic stability prediction task types.""" + + IS2RE = "IS2RE", "initial structure to relaxed energy" + RS2RE = "RS2RE", "relaxed structure to relaxed energy" + S2EFSM = "S2EFSM", "structure to energy force stress magmom" + S2EFS = "S2EFS", "structure to energy force stress" + # S2RE is for models that learned a discrete version of PES like CGCNN+P + S2RE = "S2RE", "structure to relaxed energy" + RP2RE = "RP2RE", "relaxed prototype to relaxed energy" + IP2RE = "IP2RE", "initial prototype to relaxed energy" + IS2E = "IS2E", "initial structure to energy" + IS2RE_SR = "IS2RE-SR", "initial structure to relaxed energy after ML relaxation" + + +@unique +class Targets(DictableStrEnum): + """Thermodynamic stability prediction task types.""" + + E = "E", "energy" + EFS = "EFS", "energy forces stress" + EFSM = "EFSM", "energy forces stress magmoms" + + +@unique +class ModelType(DictableStrEnum): + """Model types.""" + + GNN = "GNN", "graph neural network" + UIP = "UIP-GNN", "universal interatomic potential" + BO_GNN = "BO-GNN", "GNN in a Bayesian optimization loop" + Fingerprint = "Fingerprint", "models with structural fingerprint features" # ex. RF + Transformer = "Transformer", "transformer-based models" # ex. 
Wrenformer + RF = "RF", "random forest" + + +@unique +class Open(DictableStrEnum): + """Openness of data and code for a model.""" + + OSOD = "OSOD", "open source, open data" + CSOD = "CSOD", "closed source, open data" + OSCD = "OSCD", "open source, closed data" + CSCD = "CSCD", "closed source, closed data" + + +ev_per_atom = styled_html_tag( + "(eV/atom)", tag="span", style="font-size: 0.8em; font-weight: lighter;" +) + + +class Quantity(DictableStrEnum): + """Quantity labels for plots.""" + + n_atoms = "Atom Count" + n_elems = "Element Count" + crystal_sys = "Crystal system" + spg_num = "Space group" + n_wyckoff = "Number of Wyckoff positions" + n_sites = "Number of atoms" + energy_per_atom = f"Energy {ev_per_atom}" + e_form = f"DFT Eform {ev_per_atom}" + e_above_hull = f"Ehull dist {ev_per_atom}" + e_above_hull_mp2020_corrected_ppd_mp = f"DFT Ehull dist {ev_per_atom}" + e_above_hull_pred = f"Predicted Ehull dist {ev_per_atom}" + e_above_hull_mp = f"Eabove MP hull {ev_per_atom}" + e_above_hull_error = f"Error in Ehull dist {ev_per_atom}" + vol_diff = "Volume difference (A^3)" + e_form_per_atom_mp2020_corrected = f"DFT Eform {ev_per_atom}" + e_form_per_atom_pred = f"Predicted Eform {ev_per_atom}" + material_id = "Material ID" + band_gap = "Band gap (eV)" + formula = "Formula" + stress = "σ (eV/ų)" # noqa: RUF001 + stress_trace = "1/3 Tr(σ) (eV/ų)" # noqa: RUF001 + + +class Model(DictableStrEnum): + """Model labels for plots.""" + + alignn_ff = "ALIGNN FF" + alignn_pretrained = "ALIGNN Pretrained" + alignn = "ALIGNN" + bowsr_megnet = "BOWSR" + cgcnn_p = "CGCNN+P" + cgcnn = "CGCNN" + chgnet_megnet = "CHGNet→MEGNet" + chgnet = "CHGNet" + dft = "DFT" + gnome = "GNoME" + m3gnet_direct = "M3GNet DIRECT" + m3gnet_megnet = "M3GNet→MEGNet" + m3gnet_ms = "M3GNet MS" + m3gnet = "M3GNet" + mace = "MACE" + megnet_rs2re = "MEGNet RS2RE" + megnet = "MEGNet" + pfp = "PFP" + voronoi_rf = "Voronoi RF" + wbm = "WBM" + wrenformer = "Wrenformer" diff --git a/matbench_discovery/preds.py 
b/matbench_discovery/preds.py index 2c31f18f..5e4c2d6e 100644 --- a/matbench_discovery/preds.py +++ b/matbench_discovery/preds.py @@ -52,7 +52,7 @@ class PredFiles(Files): # megnet_rs2re = "megnet/2023-08-23-megnet-wbm-RS2RE.csv.gz" # Magpie composition+Voronoi tessellation structure features + sklearn random forest - voronoi_rf = "voronoi/2022-11-27-train-test/e-form-preds-IS2RE.csv.gz" + voronoi_rf = "voronoi_rf/2022-11-27-train-test/e-form-preds-IS2RE.csv.gz" # wrenformer 10-member ensemble wrenformer = "wrenformer/2022-11-15-wrenformer-ens=10-IS2RE-preds.csv.gz"