Skip to content

Commit

Permalink
move mb_discovery/{energy/__init__.py -> energy.py} and add tests/tes…
Browse files Browse the repository at this point in the history
…t_energy.py
  • Loading branch information
janosh committed Jun 20, 2023
1 parent 8d9e346 commit aec9933
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 19 deletions.
4 changes: 1 addition & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ __pycache__
*.csv.bz2
*.pkl.gz
data/**/raw
data/**/202*

# checkpoint files of trained models
pretrained/
Expand All @@ -22,9 +23,6 @@ job-logs/
# slurm logs
slurm-*out
models/**/*.csv
mb_discovery/energy/**/*.csv
mb_discovery/energy/**/*.json
mb_discovery/energy/**/*.gzip

# temporary ignore rule
paper
Expand Down
64 changes: 48 additions & 16 deletions mb_discovery/energy/__init__.py → mb_discovery/energy.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,48 @@
import itertools
from collections.abc import Sequence

import pandas as pd
from pymatgen.analysis.phase_diagram import Entry, PDEntry
from pymatgen.core import Composition
from pymatgen.util.typing import EntryLike
from tqdm import tqdm

from mb_discovery import ROOT


def get_elemental_ref_entries(
entries: list[Entry], verbose: bool = False
entries: Sequence[EntryLike], verbose: bool = False
) -> dict[str, Entry]:
"""Get the lowest energy entry for each element in a list of entries.
Args:
entries (Sequence[Entry]): pymatgen Entries (PDEntry, ComputedEntry or
ComputedStructureEntry) to find elemental reference entries for.
verbose (bool, optional): _description_. Defaults to False.
Raises:
ValueError: If some elements are missing terminal reference entries.
ValueError: If there are more terminal entries than dimensions. Should never
happen.
Returns:
dict[str, Entry]: Map from element symbol to its lowest energy entry.
"""
entries = [PDEntry.from_dict(e) if isinstance(e, dict) else e for e in entries]
elements = {elems for entry in entries for elems in entry.composition.elements}
dim = len(elements)

if verbose:
print(f"Sorting {len(entries)} entries with {dim} dimensions...")

entries = sorted(entries, key=lambda e: e.composition.reduced_composition)

elemental_ref_entries = {}
if verbose:
print("Finding elemental reference entries...", flush=True)
for composition, group in tqdm(
itertools.groupby(entries, key=lambda e: e.composition.reduced_composition)
for composition, entry_group in tqdm(
itertools.groupby(entries, key=lambda e: e.composition.reduced_composition),
disable=not verbose,
):
min_entry = min(group, key=lambda e: e.energy_per_atom)
min_entry = min(entry_group, key=lambda e: e.energy_per_atom)
if composition.is_element:
elem_symb = str(composition.elements[0])
elemental_ref_entries[elem_symb] = min_entry
Expand Down Expand Up @@ -53,14 +71,16 @@ def get_elemental_ref_entries(


def get_e_form_per_atom(
entry: Entry, elemental_ref_entries: dict[str, Entry] = None
entry: EntryLike,
elemental_ref_entries: dict[str, EntryLike] = None,
) -> float:
"""Get the formation energy of a composition from a list of entries and elemental
reference energies.
Args:
entry (Entry): pymatgen Entry (PDEntry, ComputedEntry or ComputedStructureEntry)
to compute formation energy of.
entry: Entry | dict[str, float | str | Composition]: pymatgen Entry (PDEntry,
ComputedEntry or ComputedStructureEntry) or dict with energy and composition
keys to compute formation energy of.
elemental_ref_entries (dict[str, Entry], optional): Must be a complete set of
terminal (i.e. elemental) reference entries containing the lowest energy
phase for each element present in entry. Defaults to MP elemental reference
Expand All @@ -76,13 +96,25 @@ def get_e_form_per_atom(
f"Couldn't load {mp_elem_refs_path=}, you must pass "
f"{elemental_ref_entries=} explicitly."
)

elemental_ref_entries = mp_elem_reference_entries

comp = entry.composition
form_energy = entry.uncorrected_energy - sum(
comp[el] * elemental_ref_entries[str(el)].energy_per_atom
for el in entry.composition.elements
)
if isinstance(entry, dict):
energy = entry["energy"]
comp = Composition(entry["composition"]) # is idempotent if already Composition
elif isinstance(entry, Entry):
energy = entry.energy
comp = entry.composition
else:
raise TypeError(
f"{entry=} must be Entry (or subclass like ComputedEntry) or dict"
)

refs = {str(el): elemental_ref_entries[str(el)] for el in comp}

for key, ref_entry in refs.items():
if isinstance(ref_entry, dict):
refs[key] = PDEntry.from_dict(ref_entry)

form_energy = energy - sum(comp[el] * refs[str(el)].energy_per_atom for el in comp)

return form_energy / entry.composition.num_atoms
return form_energy / comp.num_atoms
74 changes: 74 additions & 0 deletions tests/test_energy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations

from typing import Any, Callable

import pytest
from pymatgen.analysis.phase_diagram import PDEntry
from pymatgen.core import Lattice, Structure
from pymatgen.entries.computed_entries import (
ComputedEntry,
ComputedStructureEntry,
Entry,
)

from mb_discovery.energy import get_e_form_per_atom, get_elemental_ref_entries

dummy_struct = Structure(
lattice=Lattice.cubic(5),
species=("Fe", "O"),
coords=((0, 0, 0), (0.5, 0.5, 0.5)),
)


@pytest.mark.parametrize(
"constructor", [PDEntry, ComputedEntry, ComputedStructureEntry, lambda **x: x]
)
def test_get_e_form_per_atom(
constructor: Callable[..., Entry | dict[str, Any]]
) -> None:
"""Test that the formation energy of a composition is computed correctly."""

entry = {"composition": {"Fe": 1, "O": 1}, "energy": -2.5}
elemental_ref_entries = {
"Fe": {"composition": {"Fe": 1}, "energy": -1.0},
"O": {"composition": {"O": 1}, "energy": -1.0},
}
if constructor == ComputedStructureEntry:
entry["structure"] = dummy_struct
entry.pop("composition")

entry = constructor(**entry)

# don't use ComputedStructureEntry for elemental ref entries, would need many
# dummy structures
if constructor == ComputedStructureEntry:
constructor = ComputedEntry
elemental_ref_entries = {
k: constructor(**v) for k, v in elemental_ref_entries.items()
}
assert get_e_form_per_atom(entry, elemental_ref_entries) == -0.25


@pytest.mark.parametrize("constructor", [PDEntry, ComputedEntry, lambda **x: x])
@pytest.mark.parametrize("verbose", [True, False])
def test_get_elemental_ref_entries(
constructor: Callable[..., Entry | dict[str, Any]], verbose: bool
) -> None:
"""Test that the elemental reference entries are correctly identified."""
entries = [
("Fe1 O1", -2.5),
("Fe1", -1.0),
("Fe1", -2.0),
("O1", -1.0),
("O3", -2.0),
]
elemental_ref_entries = get_elemental_ref_entries(
[constructor(composition=comp, energy=energy) for comp, energy in entries],
verbose=verbose,
)
if constructor.__name__ == "<lambda>":
expected = {"Fe": PDEntry(*entries[2]), "O": PDEntry(*entries[3])}
else:
expected = {"Fe": constructor(*entries[2]), "O": constructor(*entries[3])}

assert elemental_ref_entries == expected

0 comments on commit aec9933

Please sign in to comment.