Skip to content

Commit

Permalink
extract all StrEnums to new matbench_discovery/enums.py
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Feb 3, 2024
1 parent 1bd5a00 commit 6d88d3f
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 171 deletions.
172 changes: 2 additions & 170 deletions matbench_discovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
import os
import warnings
from datetime import datetime
from enum import StrEnum, unique
from importlib.metadata import Distribution
from typing import Self

import matplotlib.pyplot as plt
import plotly.express as px
import plotly.io as pio
from pymatviz.utils import styled_html_tag

from matbench_discovery.enums import Key, Model, ModelType, Quantity, Task # noqa: F401

pkg_name = "matbench-discovery"
direct_url = Distribution.from_name(pkg_name).read_text("direct_url.json") or "{}"
Expand Down Expand Up @@ -57,177 +56,10 @@
warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen")


class DictableStrEnum(StrEnum):
"""StrEnum with optional description attributes and dict() method."""

def __new__(cls, val: str, desc: str | None = None) -> Self:
"""Create a new class."""
member = str.__new__(cls, val)
member._value_ = val
member.__dict__["desc"] = desc
return member

@property
def description(self) -> str:
"""Make description read-only."""
return self.__dict__["desc"]

@classmethod
def dict(cls) -> dict[str, str]:
"""Return the enum as a dictionary."""
return {key: str(val) for key, val in cls.__members__.items()}


@unique
class Key(DictableStrEnum):
"""Keys used to access dataframes columns."""

arity = "arity"
bandgap_pbe = "bandgap_pbe"
chem_sys = "chemical_system"
composition = "composition"
cse = "computed_structure_entry"
dft_energy = "uncorrected_energy"
e_form = "e_form_per_atom_mp2020_corrected"
e_form_pred = "e_form_per_atom_pred"
e_form_raw = "e_form_per_atom_uncorrected"
e_form_wbm = "e_form_per_atom_wbm"
each = "energy_above_hull" # as returned by MP API
each_pred = "e_above_hull_pred"
each_true = "e_above_hull_mp2020_corrected_ppd_mp"
each_wbm = "e_above_hull_wbm"
final_struct = "relaxed_structure"
forces = "forces"
form_energy = "formation_energy_per_atom"
formula = "formula"
init_struct = "initial_structure"
magmoms = "magmoms"
mat_id = "material_id"
model_mean_each = "Mean prediction all models"
model_mean_err = "Mean error all models"
model_std_each = "Std. dev. over models"
n_sites = "n_sites"
site_nums = "site_nums"
spacegroup = "spacegroup"
stress = "stress"
stress_trace = "stress_trace"
struct = "structure"
task_id = "task_id"
# lowest WBM structures for a given prototype that isn't already in MP
uniq_proto = "unique_prototype"
volume = "volume"
wyckoff = "wyckoff_spglib" # relaxed structure Aflow label
init_wyckoff = "wyckoff_spglib_initial_structure" # initial structure Aflow label


@unique
class Task(DictableStrEnum):
"""Thermodynamic stability prediction task types."""

IS2RE = "IS2RE", "initial structure to relaxed energy"
RS2RE = "RS2RE", "relaxed structure to relaxed energy"
S2EFSM = "S2EFSM", "structure to energy force stress magmom"
S2EFS = "S2EFS", "structure to energy force stress"
# S2RE is for models that learned a discrete version of PES like CGCNN+P
S2RE = "S2RE", "structure to relaxed energy"
RP2RE = "RP2RE", "relaxed prototype to relaxed energy"
IP2RE = "IP2RE", "initial prototype to relaxed energy"
IS2E = "IS2E", "initial structure to energy"
IS2RE_SR = "IS2RE-SR", "initial structure to relaxed energy after ML relaxation"


@unique
class Targets(DictableStrEnum):
"""Thermodynamic stability prediction task types."""

E = "E", "energy"
EFS = "EFS", "energy forces stress"
EFSM = "EFSM", "energy forces stress magmoms"


@unique
class ModelType(DictableStrEnum):
"""Model types."""

GNN = "GNN", "graph neural network"
UIP = "UIP-GNN", "universal interatomic potential"
BO_GNN = "BO-GNN", "GNN in a Bayesian optimization loop"
Fingerprint = "Fingerprint", "models with structural fingerprint features" # ex. RF
Transformer = "Transformer", "transformer-based models" # ex. Wrenformer
RF = "RF", "random forest"


@unique
class Open(DictableStrEnum):
"""Openness of data and code for a model."""

OSOD = "OSOD", "open source, open data"
CSOD = "CSOD", "closed source, open data"
OSCD = "OSCD", "open source, closed data"
CSCD = "CSCD", "closed source, closed data"


with open(f"{FIGSHARE_DIR}/1.0.0.json") as file:
FIGSHARE_URLS = json.load(file)

# --- start global plot settings
ev_per_atom = styled_html_tag(
"(eV/atom)", tag="span", style="font-size: 0.8em; font-weight: lighter;"
)


class Quantity(DictableStrEnum):
"""Quantity labels for plots."""

n_atoms = "Atom Count"
n_elems = "Element Count"
crystal_sys = "Crystal system"
spg_num = "Space group"
n_wyckoff = "Number of Wyckoff positions"
n_sites = "Number of atoms"
energy_per_atom = f"Energy {ev_per_atom}"
e_form = f"DFT E<sub>form</sub> {ev_per_atom}"
e_above_hull = f"E<sub>hull dist</sub> {ev_per_atom}"
e_above_hull_mp2020_corrected_ppd_mp = f"DFT E<sub>hull dist</sub> {ev_per_atom}"
e_above_hull_pred = f"Predicted E<sub>hull dist</sub> {ev_per_atom}"
e_above_hull_mp = f"E<sub>above MP hull</sub> {ev_per_atom}"
e_above_hull_error = f"Error in E<sub>hull dist</sub> {ev_per_atom}"
vol_diff = "Volume difference (A^3)"
e_form_per_atom_mp2020_corrected = f"DFT E<sub>form</sub> {ev_per_atom}"
e_form_per_atom_pred = f"Predicted E<sub>form</sub> {ev_per_atom}"
material_id = "Material ID"
band_gap = "Band gap (eV)"
formula = "Formula"
stress = "σ (eV/ų)" # noqa: RUF001
stress_trace = "1/3 Tr(σ) (eV/ų)" # noqa: RUF001


class Model(DictableStrEnum):
"""Model labels for plots."""

alignn_ff = "ALIGNN FF"
alignn_pretrained = "ALIGNN Pretrained"
alignn = "ALIGNN"
bowsr_megnet = "BOWSR"
cgcnn_p = "CGCNN+P"
cgcnn = "CGCNN"
chgnet_megnet = "CHGNet→MEGNet"
chgnet = "CHGNet"
dft = "DFT"
gnome = "GNoME"
m3gnet_direct = "M3GNet DIRECT"
m3gnet_megnet = "M3GNet→MEGNet"
m3gnet_ms = "M3GNet MS"
m3gnet = "M3GNet"
mace = "MACE"
megnet_rs2re = "MEGNet RS2RE"
megnet = "MEGNet"
pfp = "PFP"
voronoi_rf = "Voronoi RF"
wbm = "WBM"
wrenformer = "Wrenformer"


px.defaults.labels = Quantity.dict() | Model.dict()


Expand Down
173 changes: 173 additions & 0 deletions matbench_discovery/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from __future__ import annotations

from enum import StrEnum, unique
from typing import Self

from pymatviz.utils import styled_html_tag


class DictableStrEnum(StrEnum):
"""StrEnum with optional description attributes and dict() method."""

def __new__(cls, val: str, desc: str | None = None) -> Self:
"""Create a new class."""
member = str.__new__(cls, val)
member._value_ = val
member.__dict__["desc"] = desc
return member

@property
def description(self) -> str:
"""Make description read-only."""
return self.__dict__["desc"]

@classmethod
def dict(cls) -> dict[str, str]:
"""Return the enum as a dictionary."""
return {key: str(val) for key, val in cls.__members__.items()}


@unique
class Key(DictableStrEnum):
"""Keys used to access dataframes columns."""

arity = "arity"
bandgap_pbe = "bandgap_pbe"
chem_sys = "chemical_system"
composition = "composition"
cse = "computed_structure_entry"
dft_energy = "uncorrected_energy"
e_form = "e_form_per_atom_mp2020_corrected"
e_form_pred = "e_form_per_atom_pred"
e_form_raw = "e_form_per_atom_uncorrected"
e_form_wbm = "e_form_per_atom_wbm"
each = "energy_above_hull" # as returned by MP API
each_pred = "e_above_hull_pred"
each_true = "e_above_hull_mp2020_corrected_ppd_mp"
each_wbm = "e_above_hull_wbm"
final_struct = "relaxed_structure"
forces = "forces"
form_energy = "formation_energy_per_atom"
formula = "formula"
init_struct = "initial_structure"
magmoms = "magmoms"
mat_id = "material_id"
model_mean_each = "Mean prediction all models"
model_mean_err = "Mean error all models"
model_std_each = "Std. dev. over models"
n_sites = "n_sites"
site_nums = "site_nums"
spacegroup = "spacegroup"
stress = "stress"
stress_trace = "stress_trace"
struct = "structure"
task_id = "task_id"
# lowest WBM structures for a given prototype that isn't already in MP
uniq_proto = "unique_prototype"
volume = "volume"
wyckoff = "wyckoff_spglib" # relaxed structure Aflow label
init_wyckoff = "wyckoff_spglib_initial_structure" # initial structure Aflow label


@unique
class Task(DictableStrEnum):
"""Thermodynamic stability prediction task types."""

IS2RE = "IS2RE", "initial structure to relaxed energy"
RS2RE = "RS2RE", "relaxed structure to relaxed energy"
S2EFSM = "S2EFSM", "structure to energy force stress magmom"
S2EFS = "S2EFS", "structure to energy force stress"
# S2RE is for models that learned a discrete version of PES like CGCNN+P
S2RE = "S2RE", "structure to relaxed energy"
RP2RE = "RP2RE", "relaxed prototype to relaxed energy"
IP2RE = "IP2RE", "initial prototype to relaxed energy"
IS2E = "IS2E", "initial structure to energy"
IS2RE_SR = "IS2RE-SR", "initial structure to relaxed energy after ML relaxation"


@unique
class Targets(DictableStrEnum):
"""Thermodynamic stability prediction task types."""

E = "E", "energy"
EFS = "EFS", "energy forces stress"
EFSM = "EFSM", "energy forces stress magmoms"


@unique
class ModelType(DictableStrEnum):
"""Model types."""

GNN = "GNN", "graph neural network"
UIP = "UIP-GNN", "universal interatomic potential"
BO_GNN = "BO-GNN", "GNN in a Bayesian optimization loop"
Fingerprint = "Fingerprint", "models with structural fingerprint features" # ex. RF
Transformer = "Transformer", "transformer-based models" # ex. Wrenformer
RF = "RF", "random forest"


@unique
class Open(DictableStrEnum):
"""Openness of data and code for a model."""

OSOD = "OSOD", "open source, open data"
CSOD = "CSOD", "closed source, open data"
OSCD = "OSCD", "open source, closed data"
CSCD = "CSCD", "closed source, closed data"


ev_per_atom = styled_html_tag(
"(eV/atom)", tag="span", style="font-size: 0.8em; font-weight: lighter;"
)


class Quantity(DictableStrEnum):
"""Quantity labels for plots."""

n_atoms = "Atom Count"
n_elems = "Element Count"
crystal_sys = "Crystal system"
spg_num = "Space group"
n_wyckoff = "Number of Wyckoff positions"
n_sites = "Number of atoms"
energy_per_atom = f"Energy {ev_per_atom}"
e_form = f"DFT E<sub>form</sub> {ev_per_atom}"
e_above_hull = f"E<sub>hull dist</sub> {ev_per_atom}"
e_above_hull_mp2020_corrected_ppd_mp = f"DFT E<sub>hull dist</sub> {ev_per_atom}"
e_above_hull_pred = f"Predicted E<sub>hull dist</sub> {ev_per_atom}"
e_above_hull_mp = f"E<sub>above MP hull</sub> {ev_per_atom}"
e_above_hull_error = f"Error in E<sub>hull dist</sub> {ev_per_atom}"
vol_diff = "Volume difference (A^3)"
e_form_per_atom_mp2020_corrected = f"DFT E<sub>form</sub> {ev_per_atom}"
e_form_per_atom_pred = f"Predicted E<sub>form</sub> {ev_per_atom}"
material_id = "Material ID"
band_gap = "Band gap (eV)"
formula = "Formula"
stress = "σ (eV/ų)" # noqa: RUF001
stress_trace = "1/3 Tr(σ) (eV/ų)" # noqa: RUF001


class Model(DictableStrEnum):
"""Model labels for plots."""

alignn_ff = "ALIGNN FF"
alignn_pretrained = "ALIGNN Pretrained"
alignn = "ALIGNN"
bowsr_megnet = "BOWSR"
cgcnn_p = "CGCNN+P"
cgcnn = "CGCNN"
chgnet_megnet = "CHGNet→MEGNet"
chgnet = "CHGNet"
dft = "DFT"
gnome = "GNoME"
m3gnet_direct = "M3GNet DIRECT"
m3gnet_megnet = "M3GNet→MEGNet"
m3gnet_ms = "M3GNet MS"
m3gnet = "M3GNet"
mace = "MACE"
megnet_rs2re = "MEGNet RS2RE"
megnet = "MEGNet"
pfp = "PFP"
voronoi_rf = "Voronoi RF"
wbm = "WBM"
wrenformer = "Wrenformer"
2 changes: 1 addition & 1 deletion matbench_discovery/preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class PredFiles(Files):
# megnet_rs2re = "megnet/2023-08-23-megnet-wbm-RS2RE.csv.gz"

# Magpie composition+Voronoi tessellation structure features + sklearn random forest
voronoi_rf = "voronoi/2022-11-27-train-test/e-form-preds-IS2RE.csv.gz"
voronoi_rf = "voronoi_rf/2022-11-27-train-test/e-form-preds-IS2RE.csv.gz"

# wrenformer 10-member ensemble
wrenformer = "wrenformer/2022-11-15-wrenformer-ens=10-IS2RE-preds.csv.gz"
Expand Down

0 comments on commit 6d88d3f

Please sign in to comment.