Skip to content

Commit

Permalink
rename DictableStrEnum->LabelEnum, add ability to set key labels and …
Browse files Browse the repository at this point in the history
…add labels to all Keys
  • Loading branch information
janosh committed Feb 6, 2024
1 parent cdf840d commit 80a494c
Showing 1 changed file with 66 additions and 51 deletions.
117 changes: 66 additions & 51 deletions matbench_discovery/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,75 +6,90 @@
from pymatviz.utils import styled_html_tag


class DictableStrEnum(StrEnum):
"""StrEnum with optional description attributes and dict() method."""
class LabelEnum(StrEnum):
"""StrEnum with optional label and description attributes plus dict() method."""

def __new__(cls, val: str, desc: str | None = None) -> Self:
def __new__(
cls, val: str, label: str | None = None, desc: str | None = None
) -> Self:
"""Create a new class."""
member = str.__new__(cls, val)
member._value_ = val
member.__dict__["desc"] = desc
member.__dict__ |= dict(label=label, desc=desc)
return member

@property
def label(self) -> str:
"""Make label read-only."""
return self.__dict__["label"]

@property
def description(self) -> str:
"""Make description read-only."""
return self.__dict__["desc"]

@classmethod
def dict(cls) -> dict[str, str]:
"""Return the enum as a dictionary."""
"""Return the Enum as dictionary."""
return {key: str(val) for key, val in cls.__members__.items()}


@unique
class Key(DictableStrEnum):
class Key(LabelEnum):
"""Keys used to access dataframes columns."""

arity = "arity"
bandgap_pbe = "bandgap_pbe"
chem_sys = "chemical_system"
composition = "composition"
cse = "computed_structure_entry"
dft_energy = "uncorrected_energy"
e_form = "e_form_per_atom_mp2020_corrected"
e_form_pred = "e_form_per_atom_pred"
e_form_raw = "e_form_per_atom_uncorrected"
e_form_wbm = "e_form_per_atom_wbm"
each = "energy_above_hull" # as returned by MP API
each_pred = "e_above_hull_pred"
each_true = "e_above_hull_mp2020_corrected_ppd_mp"
each_wbm = "e_above_hull_wbm"
final_struct = "relaxed_structure"
forces = "forces"
form_energy = "formation_energy_per_atom"
formula = "formula"
init_struct = "initial_structure"
magmoms = "magmoms"
mat_id = "material_id"
model_mean_each = "Mean prediction all models"
model_mean_err = "Mean error all models"
model_std_each = "Std. dev. over models"
n_sites = "n_sites"
site_nums = "site_nums"
spacegroup = "spacegroup"
stress = "stress"
stress_trace = "stress_trace"
struct = "structure"
task_id = "task_id"
task_type = "task_type"
arity = "arity", "Arity"
bandgap_pbe = "bandgap_pbe", "PBE Band Gap"
chem_sys = "chemical_system", "Chemical System"
composition = "composition", "Composition"
cse = "computed_structure_entry", "Computed Structure Entry"
dft_energy = "uncorrected_energy", "DFT Energy"
e_form = "e_form_per_atom_mp2020_corrected", "DFT E_form"
e_form_pred = "e_form_per_atom_pred", "Predicted E_form"
e_form_raw = "e_form_per_atom_uncorrected", "DFT E_form raw"
e_form_wbm = "e_form_per_atom_wbm", "WBM E_form"
each = "energy_above_hull", "E<sub>hull dist</sub>"
each_pred = "e_above_hull_pred", "Predicted E<sub>hull dist</sub>"
each_true = "e_above_hull_mp2020_corrected_ppd_mp", "E<sub>MP hull dist</sub>"
each_wbm = "e_above_hull_wbm", "E<sub>WBM hull dist</sub>"
final_struct = "relaxed_structure", "Relaxed Structure"
forces = "forces", "Forces"
form_energy = "formation_energy_per_atom", "Formation Energy (eV/atom)"
formula = "formula", "Formula"
init_struct = "initial_structure", "Initial Structure"
magmoms = "magmoms", "Magnetic Moments"
mat_id = "material_id", "Material ID"
model_mean_each = "mean_pred_models", "Mean prediction all models"
model_mean_err = "each_err_models", "Mean E<sub>hull dist</sub> error all models"
model_std_each = "each_std_models", "Std. dev. all models"
n_sites = "n_sites", "Number of Sites"
site_nums = "site_nums", "Site Numbers", "Atomic numbers for each crystal site"
spacegroup = "spacegroup", "Spacegroup"
stress = "stress", "Stress"
stress_trace = "stress_trace", "Stress Trace"
struct = "structure", "Structure"
task_id = "task_id", "Task ID"
task_type = "task_type", "Task Type"
train_task = "train_task", "Training Task"
test_task = "test_task", "Test Task"
targets = "targets", "Targets"
# lowest WBM structures for a given prototype that isn't already in MP
uniq_proto = "unique_prototype"
volume = "volume"
wyckoff = "wyckoff_spglib" # relaxed structure Aflow label
init_wyckoff = "wyckoff_spglib_initial_structure" # initial structure Aflow label
train_size = "Training Size" # number of structures in the training set
model_params = "Model Params" # number of parameters in the model
openness = "Openness"
uniq_proto = "unique_prototype", "Unique Prototype"
volume = "volume", "Volume (ų)"
wyckoff = "wyckoff_spglib", "Aflow-Wyckoff Label" # relaxed structure Aflow label
init_wyckoff = (
"wyckoff_spglib_initial_structure",
"Aflow-Wyckoff Label Initial Structure",
)
# number of structures in a model's training set
train_size = "train_size", "Training Size"
model_params = "model_params", "Model Parameters" # model's parameter count
model_type = "model_type", "Model Type" # number of parameters in the model
openness = "openness", "Openness" # openness of data and code for a model


@unique
class Task(DictableStrEnum):
class Task(LabelEnum):
"""Thermodynamic stability prediction task types."""

IS2RE = "IS2RE", "initial structure to relaxed energy"
Expand All @@ -90,7 +105,7 @@ class Task(DictableStrEnum):


@unique
class Targets(DictableStrEnum):
class Targets(LabelEnum):
"""Thermodynamic stability prediction task types."""

E = "E", "energy"
Expand All @@ -99,7 +114,7 @@ class Targets(DictableStrEnum):


@unique
class ModelType(DictableStrEnum):
class ModelType(LabelEnum):
"""Model types."""

GNN = "GNN", "graph neural network"
Expand All @@ -111,7 +126,7 @@ class ModelType(DictableStrEnum):


@unique
class Open(DictableStrEnum):
class Open(LabelEnum):
"""Openness of data and code for a model."""

OSOD = "OSOD", "open source, open data"
Expand All @@ -125,7 +140,7 @@ class Open(DictableStrEnum):
)


class Quantity(DictableStrEnum):
class Quantity(LabelEnum):
"""Quantity labels for plots."""

n_atoms = "Atom Count"
Expand All @@ -151,7 +166,7 @@ class Quantity(DictableStrEnum):
stress_trace = "1/3 Tr(σ) (eV/ų)" # noqa: RUF001


class Model(DictableStrEnum):
class Model(LabelEnum):
"""Model labels for plots."""

alignn_ff = "ALIGNN FF"
Expand Down

0 comments on commit 80a494c

Please sign in to comment.