Skip to content

Commit

Permalink
continue refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
esoteric-ephemera committed Apr 3, 2024
1 parent 0f99c88 commit 6e3aa48
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 47 deletions.
4 changes: 3 additions & 1 deletion emmet-builders/emmet/builders/vasp/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,9 @@ def process_item(self, items: List[Dict]) -> List[Dict]:
were processed
"""

tasks = [TaskDoc(**task) for task in items]
tasks = [
TaskDoc(**task) for task in items
] # [TaskDoc(**task) for task in items]
formula = tasks[0].formula_pretty
task_ids = [task.task_id for task in tasks]

Expand Down
6 changes: 4 additions & 2 deletions emmet-builders/emmet/builders/vasp/task_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from maggma.core import Store

from emmet.builders.settings import EmmetBuildSettings
from emmet.core.tasks import TaskDoc
from emmet.builders.utils import get_potcar_stats
from emmet.core.tasks import TaskDoc
from emmet.core.vasp.calc_types.enums import CalcType
from emmet.core.vasp.validation import DeprecationMessage, ValidationDoc

Expand Down Expand Up @@ -66,7 +66,9 @@ def unary_function(self, item):
item (dict): a (projection of a) task doc
"""
if not item["output"].get("energy"):
item["output"]["energy"] = -1e20
# Default value required for pydantic typing. `TaskDoc.output.energy`
# must be float.
item["output"]["energy"] = 1e20
task_doc = TaskDoc(**item)
validation_doc = ValidationDoc.from_task_doc(
task_doc=task_doc,
Expand Down
2 changes: 1 addition & 1 deletion emmet-builders/tests/test_corrected_entries_thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def thermo_store():

@pytest.fixture
def phase_diagram_store():
return MemoryStore(key="chemsys")
return MemoryStore(key="phase_diagram_id")


def test_corrected_entries_builder(corrected_entries_store, materials_store):
Expand Down
44 changes: 23 additions & 21 deletions emmet-core/emmet/core/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
from emmet.core.common import convert_datetime
from emmet.core.mpid import MPID
from emmet.core.structure import StructureMetadata
from emmet.core.vasp.calc_types import (
CalcType,
Expand All @@ -28,7 +29,14 @@
from emmet.core.vasp.task_valid import TaskState
from monty.json import MontyDecoder
from monty.serialization import loadfn
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
PrivateAttr,
)
from pymatgen.analysis.structure_analyzer import oxide_type
from pymatgen.core.structure import Structure
from pymatgen.core.trajectory import Trajectory
Expand Down Expand Up @@ -349,7 +357,7 @@ class TaskDoc(StructureMetadata, extra="allow"):
None, description="The type of calculation."
)

task_id: Optional[str] = Field(
task_id: Optional[Union[MPID, str]] = Field(
None,
description="The (task) ID of this calculation, used as a universal reference across property documents."
"This comes in the form: mp-******.",
Expand Down Expand Up @@ -419,18 +427,12 @@ class TaskDoc(StructureMetadata, extra="allow"):
# can't find them, throws an AttributeError. It does this before looking to see if the
# class has that attr defined on it.

private_calc_type: Optional[CalcType] = Field(
None, description="Private field used to store output of `TaskDoc.calc_type`."
)

private_run_type: Optional[RunType] = Field(
None, description="Private field used to store output of `TaskDoc.run_type`."
)

private_structure_entry: Optional[ComputedStructureEntry] = Field(
None,
description="Private field used to store output of `TaskDoc.structure_entry`.",
)
# Private field used to store output of `TaskDoc.calc_type`
_calc_type: Optional[CalcType] = PrivateAttr(None)
# Private field used to store output of `TaskDoc.run_type`
_run_type: Optional[RunType] = PrivateAttr(None)
# Private field used to store output of `TaskDoc.structure_entry`.
_structure_entry: Optional[ComputedStructureEntry] = PrivateAttr(None)

def model_post_init(self, __context: Any) -> None:
# Needed for compatibility with TaskDocument
Expand Down Expand Up @@ -641,7 +643,7 @@ def from_vasprun(

@staticmethod
def get_entry(
calcs_reversed: List[Calculation], task_id: Optional[str] = None
calcs_reversed: List[Calculation], task_id: Optional[Union[MPID, str]] = None
) -> ComputedEntry:
"""
Get a computed entry from a list of VASP calculation documents.
Expand Down Expand Up @@ -700,16 +702,16 @@ def calc_type(self) -> CalcType:
params = self.calcs_reversed[0].input.parameters
incar = self.calcs_reversed[0].input.incar

self.private_calc_type = calc_type(inputs, {**params, **incar})
return self.private_calc_type
self._calc_type = calc_type(inputs, {**params, **incar})
return self._calc_type

@property
def run_type(self) -> RunType:
params = self.calcs_reversed[0].input.parameters
incar = self.calcs_reversed[0].input.incar

self.private_run_type = run_type({**params, **incar})
return self.private_run_type
self._run_type = run_type({**params, **incar})
return self._run_type

@property
def structure_entry(self) -> ComputedStructureEntry:
Expand All @@ -721,7 +723,7 @@ def structure_entry(self) -> ComputedStructureEntry:
ComputedStructureEntry
The TaskDoc.entry with corresponding TaskDoc.structure added.
"""
self.private_structure_entry = ComputedStructureEntry(
self._structure_entry = ComputedStructureEntry(
structure=self.structure,
energy=self.entry.energy,
correction=self.entry.correction,
Expand All @@ -731,7 +733,7 @@ def structure_entry(self) -> ComputedStructureEntry:
data=self.entry.data,
entry_id=self.entry.entry_id,
)
return self.private_structure_entry
return self._structure_entry


class TrajectoryDoc(BaseModel):
Expand Down
4 changes: 1 addition & 3 deletions emmet-core/emmet/core/vasp/calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ class CalculationBaseModel(BaseModel):
"""Wrapper around pydantic BaseModel with extra functionality."""

def get(self, key: Any, default_value: Optional[Any] = None) -> Any:
if hasattr(self, key):
return self.__getattribute__(key)
return default_value
return getattr(self, key, default_value)


class PotcarSpec(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions emmet-core/emmet/core/vasp/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,14 +339,14 @@ def _potcar_stats_check(task_doc, potcar_stats: dict):
data_tol = 1.0e-6

try:
potcar_details = task_doc.calcs_reversed[0]["input"]["potcar_spec"]
potcar_details = task_doc.calcs_reversed[0].model_dump()["input"]["potcar_spec"]

except KeyError:
# Assume it is an old calculation without potcar_spec data and treat it as passing POTCAR hash check
return False

use_legacy_hash_check = False
if any(len(entry.get("summary_stats", {})) == 0 for entry in potcar_details):
if any(entry.get("summary_stats", None) is None for entry in potcar_details):
# potcar_spec doesn't include summary_stats kwarg needed to check potcars
# fall back to header hash checking
use_legacy_hash_check = True
Expand Down
37 changes: 20 additions & 17 deletions emmet-core/tests/vasp/test_vasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from monty.io import zopen

from emmet.core.vasp.calc_types import RunType, TaskType, run_type, task_type
from emmet.core.vasp.task_valid import TaskDocument
from emmet.core.tasks import TaskDoc
from emmet.core.vasp.validation import ValidationDoc, _potcar_stats_check


Expand Down Expand Up @@ -45,7 +45,7 @@ def tasks(test_dir):
with zopen(test_dir / "test_si_tasks.json.gz") as f:
data = json.load(f)

return [TaskDocument(**d) for d in data]
return [TaskDoc(**d) for d in data]


def test_validator(tasks):
Expand All @@ -58,7 +58,7 @@ def test_validator(tasks):
def test_validator_failed_symmetry(test_dir):
with zopen(test_dir / "failed_elastic_task.json.gz", "r") as f:
failed_task = json.load(f)
taskdoc = TaskDocument(**failed_task)
taskdoc = TaskDoc(**failed_task)
validation = ValidationDoc.from_task_doc(taskdoc)
assert any("SYMMETRY" in repr(reason) for reason in validation.reasons)

Expand All @@ -74,7 +74,7 @@ def task_ldau(test_dir):
with zopen(test_dir / "test_task.json") as f:
data = json.load(f)

return TaskDocument(**data)
return TaskDoc(**data)


def test_ldau(task_ldau):
Expand All @@ -87,7 +87,7 @@ def test_ldau_validation(test_dir):
with open(test_dir / "old_aflow_ggau_task.json") as f:
data = json.load(f)

task = TaskDocument(**data)
task = TaskDoc(**data)
assert task.run_type == "GGA+U"

valid = ValidationDoc.from_task_doc(task)
Expand All @@ -113,21 +113,24 @@ def test_potcar_stats_check(test_dir):
< filename > )
I cannot rebuild the TaskDoc without excluding the `orig_inputs` key.
"""
task_doc = TaskDocument(**{key: data[key] for key in data if key != "last_updated"})
# task_doc = TaskDocument(**{key: data[key] for key in data if key != "last_updated"})
task_doc = TaskDoc(**data)
try:
# First check: generate hashes from POTCARs in TaskDoc, check should pass
calc_type = str(task_doc.calc_type)
expected_hashes = {calc_type: {}}
for spec in task_doc.calcs_reversed[0]["input"]["potcar_spec"]:
symbol = spec["titel"].split(" ")[1]
for spec in task_doc.calcs_reversed[0].input.potcar_spec:
symbol = spec.titel.split(" ")[1]
potcar = PotcarSingle.from_symbol_and_functional(
symbol=symbol, functional="PBE"
)
expected_hashes[calc_type][symbol] = {
**potcar._summary_stats,
"hash": potcar.md5_header_hash,
"titel": potcar.TITEL,
}
expected_hashes[calc_type][symbol] = [
{
**potcar._summary_stats,
"hash": potcar.md5_header_hash,
"titel": potcar.TITEL,
}
]

assert not _potcar_stats_check(task_doc, expected_hashes)

Expand All @@ -141,8 +144,8 @@ def test_potcar_stats_check(test_dir):
# Third check: change data in expected hashes, check should fail

wrong_hashes = {calc_type: {**expected_hashes[calc_type]}}
for key in wrong_hashes[calc_type][first_element]["stats"]["data"]:
wrong_hashes[calc_type][first_element]["stats"]["data"][key] *= 1.1
for key in wrong_hashes[calc_type][first_element][0]["stats"]["data"]:
wrong_hashes[calc_type][first_element][0]["stats"]["data"][key] *= 1.1

assert _potcar_stats_check(task_doc, wrong_hashes)

Expand All @@ -159,7 +162,7 @@ def test_potcar_stats_check(test_dir):
}
for potcar in legacy_data["calcs_reversed"][0]["input"]["potcar_spec"]
]
legacy_task_doc = TaskDocument(
legacy_task_doc = TaskDoc(
**{key: legacy_data[key] for key in legacy_data if key != "last_updated"}
)
assert not _potcar_stats_check(legacy_task_doc, expected_hashes)
Expand All @@ -180,7 +183,7 @@ def test_potcar_stats_check(test_dir):
legacy_data["calcs_reversed"][0]["input"]["potcar_spec"][0][
"hash"
] = legacy_data["calcs_reversed"][0]["input"]["potcar_spec"][0]["hash"][:-1]
legacy_task_doc = TaskDocument(
legacy_task_doc = TaskDoc(
**{key: legacy_data[key] for key in legacy_data if key != "last_updated"}
)
assert _potcar_stats_check(legacy_task_doc, expected_hashes)
Expand Down

0 comments on commit 6e3aa48

Please sign in to comment.