From 6e3aa48d184cf96b6bdfa6970520c5567f3947e8 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 2 Apr 2024 17:14:38 -0700 Subject: [PATCH] continue refactor --- .../emmet/builders/vasp/materials.py | 4 +- .../emmet/builders/vasp/task_validator.py | 6 ++- .../tests/test_corrected_entries_thermo.py | 2 +- emmet-core/emmet/core/tasks.py | 44 ++++++++++--------- emmet-core/emmet/core/vasp/calculation.py | 4 +- emmet-core/emmet/core/vasp/validation.py | 4 +- emmet-core/tests/vasp/test_vasp.py | 37 +++++++++------- 7 files changed, 54 insertions(+), 47 deletions(-) diff --git a/emmet-builders/emmet/builders/vasp/materials.py b/emmet-builders/emmet/builders/vasp/materials.py index b7b82673d5..f1a86e1c6c 100644 --- a/emmet-builders/emmet/builders/vasp/materials.py +++ b/emmet-builders/emmet/builders/vasp/materials.py @@ -229,7 +229,9 @@ def process_item(self, items: List[Dict]) -> List[Dict]: were processed """ - tasks = [TaskDoc(**task) for task in items] + tasks = [ + TaskDoc(**task) for task in items + ] # [TaskDoc(**task) for task in items] formula = tasks[0].formula_pretty task_ids = [task.task_id for task in tasks] diff --git a/emmet-builders/emmet/builders/vasp/task_validator.py b/emmet-builders/emmet/builders/vasp/task_validator.py index b601114783..9b16440a6c 100644 --- a/emmet-builders/emmet/builders/vasp/task_validator.py +++ b/emmet-builders/emmet/builders/vasp/task_validator.py @@ -4,8 +4,8 @@ from maggma.core import Store from emmet.builders.settings import EmmetBuildSettings -from emmet.core.tasks import TaskDoc from emmet.builders.utils import get_potcar_stats +from emmet.core.tasks import TaskDoc from emmet.core.vasp.calc_types.enums import CalcType from emmet.core.vasp.validation import DeprecationMessage, ValidationDoc @@ -66,7 +66,9 @@ def unary_function(self, item): item (dict): a (projection of a) task doc """ if not item["output"].get("energy"): - item["output"]["energy"] = -1e20 + # Default value required for pydantic typing. `TaskDoc.output.energy` + # must be float. + item["output"]["energy"] = 1e20 task_doc = TaskDoc(**item) validation_doc = ValidationDoc.from_task_doc( task_doc=task_doc, diff --git a/emmet-builders/tests/test_corrected_entries_thermo.py b/emmet-builders/tests/test_corrected_entries_thermo.py index 620e662125..f7f03af057 100644 --- a/emmet-builders/tests/test_corrected_entries_thermo.py +++ b/emmet-builders/tests/test_corrected_entries_thermo.py @@ -34,7 +34,7 @@ def thermo_store(): @pytest.fixture def phase_diagram_store(): - return MemoryStore(key="chemsys") + return MemoryStore(key="phase_diagram_id") def test_corrected_entries_builder(corrected_entries_store, materials_store): diff --git a/emmet-core/emmet/core/tasks.py b/emmet-core/emmet/core/tasks.py index 66bb650702..c18b96c981 100644 --- a/emmet-core/emmet/core/tasks.py +++ b/emmet-core/emmet/core/tasks.py @@ -9,6 +9,7 @@ import numpy as np from emmet.core.common import convert_datetime +from emmet.core.mpid import MPID from emmet.core.structure import StructureMetadata from emmet.core.vasp.calc_types import ( CalcType, @@ -28,7 +29,14 @@ from emmet.core.vasp.task_valid import TaskState from monty.json import MontyDecoder from monty.serialization import loadfn -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_validator, + PrivateAttr, +) from pymatgen.analysis.structure_analyzer import oxide_type from pymatgen.core.structure import Structure from pymatgen.core.trajectory import Trajectory @@ -349,7 +357,7 @@ class TaskDoc(StructureMetadata, extra="allow"): None, description="The type of calculation." ) - task_id: Optional[str] = Field( + task_id: Optional[Union[MPID, str]] = Field( None, description="The (task) ID of this calculation, used as a universal reference across property documents." "This comes in the form: mp-******.", @@ -419,18 +427,12 @@ class TaskDoc(StructureMetadata, extra="allow"): # can't find them, throws an AttributeError. It does this before looking to see if the # class has that attr defined on it. - private_calc_type: Optional[CalcType] = Field( - None, description="Private field used to store output of `TaskDoc.calc_type`." - ) - - private_run_type: Optional[RunType] = Field( - None, description="Private field used to store output of `TaskDoc.run_type`." - ) - - private_structure_entry: Optional[ComputedStructureEntry] = Field( - None, - description="Private field used to store output of `TaskDoc.structure_entry`.", - ) + # Private field used to store output of `TaskDoc.calc_type` + _calc_type: Optional[CalcType] = PrivateAttr(None) + # Private field used to store output of `TaskDoc.run_type` + _run_type: Optional[RunType] = PrivateAttr(None) + # Private field used to store output of `TaskDoc.structure_entry`. + _structure_entry: Optional[ComputedStructureEntry] = PrivateAttr(None) def model_post_init(self, __context: Any) -> None: # Needed for compatibility with TaskDocument @@ -641,7 +643,7 @@ def from_vasprun( @staticmethod def get_entry( - calcs_reversed: List[Calculation], task_id: Optional[str] = None + calcs_reversed: List[Calculation], task_id: Optional[Union[MPID, str]] = None ) -> ComputedEntry: """ Get a computed entry from a list of VASP calculation documents. @@ -700,16 +702,16 @@ def calc_type(self) -> CalcType: params = self.calcs_reversed[0].input.parameters incar = self.calcs_reversed[0].input.incar - self.private_calc_type = calc_type(inputs, {**params, **incar}) - return self.private_calc_type + self._calc_type = calc_type(inputs, {**params, **incar}) + return self._calc_type @property def run_type(self) -> RunType: params = self.calcs_reversed[0].input.parameters incar = self.calcs_reversed[0].input.incar - self.private_run_type = run_type({**params, **incar}) - return self.private_run_type + self._run_type = run_type({**params, **incar}) + return self._run_type @property def structure_entry(self) -> ComputedStructureEntry: @@ -721,7 +723,7 @@ def structure_entry(self) -> ComputedStructureEntry: ComputedStructureEntry The TaskDoc.entry with corresponding TaskDoc.structure added. """ - self.private_structure_entry = ComputedStructureEntry( + self._structure_entry = ComputedStructureEntry( structure=self.structure, energy=self.entry.energy, correction=self.entry.correction, @@ -731,7 +733,7 @@ def structure_entry(self) -> ComputedStructureEntry: data=self.entry.data, entry_id=self.entry.entry_id, ) - return self.private_structure_entry + return self._structure_entry class TrajectoryDoc(BaseModel): diff --git a/emmet-core/emmet/core/vasp/calculation.py b/emmet-core/emmet/core/vasp/calculation.py index 02a3fcda90..c7628bde7a 100644 --- a/emmet-core/emmet/core/vasp/calculation.py +++ b/emmet-core/emmet/core/vasp/calculation.py @@ -72,9 +72,7 @@ class CalculationBaseModel(BaseModel): """Wrapper around pydantic BaseModel with extra functionality.""" def get(self, key: Any, default_value: Optional[Any] = None) -> Any: - if hasattr(self, key): - return self.__getattribute__(key) - return default_value + return getattr(self, key, default_value) class PotcarSpec(BaseModel): diff --git a/emmet-core/emmet/core/vasp/validation.py b/emmet-core/emmet/core/vasp/validation.py index d989eee00a..6103f503fd 100644 --- a/emmet-core/emmet/core/vasp/validation.py +++ b/emmet-core/emmet/core/vasp/validation.py @@ -339,14 +339,14 @@ def _potcar_stats_check(task_doc, potcar_stats: dict): data_tol = 1.0e-6 try: - potcar_details = task_doc.calcs_reversed[0]["input"]["potcar_spec"] + potcar_details = task_doc.calcs_reversed[0].model_dump()["input"]["potcar_spec"] except KeyError: # Assume it is an old calculation without potcar_spec data and treat it as passing POTCAR hash check return False use_legacy_hash_check = False - if any(len(entry.get("summary_stats", {})) == 0 for entry in potcar_details): + if any(entry.get("summary_stats", None) is None for entry in potcar_details): # potcar_spec doesn't include summary_stats kwarg needed to check potcars # fall back to header hash checking use_legacy_hash_check = True diff --git a/emmet-core/tests/vasp/test_vasp.py b/emmet-core/tests/vasp/test_vasp.py index 0d24ea6f9f..d3a0e26aed 100644 --- a/emmet-core/tests/vasp/test_vasp.py +++ b/emmet-core/tests/vasp/test_vasp.py @@ -4,7 +4,7 @@ from monty.io import zopen from emmet.core.vasp.calc_types import RunType, TaskType, run_type, task_type -from emmet.core.vasp.task_valid import TaskDocument +from emmet.core.tasks import TaskDoc from emmet.core.vasp.validation import ValidationDoc, _potcar_stats_check @@ -45,7 +45,7 @@ def tasks(test_dir): with zopen(test_dir / "test_si_tasks.json.gz") as f: data = json.load(f) - return [TaskDocument(**d) for d in data] + return [TaskDoc(**d) for d in data] def test_validator(tasks): @@ -58,7 +58,7 @@ def test_validator(tasks): def test_validator_failed_symmetry(test_dir): with zopen(test_dir / "failed_elastic_task.json.gz", "r") as f: failed_task = json.load(f) - taskdoc = TaskDocument(**failed_task) + taskdoc = TaskDoc(**failed_task) validation = ValidationDoc.from_task_doc(taskdoc) assert any("SYMMETRY" in repr(reason) for reason in validation.reasons) @@ -74,7 +74,7 @@ def task_ldau(test_dir): with zopen(test_dir / "test_task.json") as f: data = json.load(f) - return TaskDocument(**data) + return TaskDoc(**data) def test_ldau(task_ldau): @@ -87,7 +87,7 @@ def test_ldau_validation(test_dir): with open(test_dir / "old_aflow_ggau_task.json") as f: data = json.load(f) - task = TaskDocument(**data) + task = TaskDoc(**data) assert task.run_type == "GGA+U" valid = ValidationDoc.from_task_doc(task) @@ -113,21 +113,24 @@ def test_potcar_stats_check(test_dir): < filename > ) I cannot rebuild the TaskDoc without excluding the `orig_inputs` key. """ - task_doc = TaskDocument(**{key: data[key] for key in data if key != "last_updated"}) + # task_doc = TaskDocument(**{key: data[key] for key in data if key != "last_updated"}) + task_doc = TaskDoc(**data) try: # First check: generate hashes from POTCARs in TaskDoc, check should pass calc_type = str(task_doc.calc_type) expected_hashes = {calc_type: {}} - for spec in task_doc.calcs_reversed[0]["input"]["potcar_spec"]: - symbol = spec["titel"].split(" ")[1] + for spec in task_doc.calcs_reversed[0].input.potcar_spec: + symbol = spec.titel.split(" ")[1] potcar = PotcarSingle.from_symbol_and_functional( symbol=symbol, functional="PBE" ) - expected_hashes[calc_type][symbol] = { - **potcar._summary_stats, - "hash": potcar.md5_header_hash, - "titel": potcar.TITEL, - } + expected_hashes[calc_type][symbol] = [ + { + **potcar._summary_stats, + "hash": potcar.md5_header_hash, + "titel": potcar.TITEL, + } + ] assert not _potcar_stats_check(task_doc, expected_hashes) @@ -141,8 +144,8 @@ def test_potcar_stats_check(test_dir): # Third check: change data in expected hashes, check should fail wrong_hashes = {calc_type: {**expected_hashes[calc_type]}} - for key in wrong_hashes[calc_type][first_element]["stats"]["data"]: - wrong_hashes[calc_type][first_element]["stats"]["data"][key] *= 1.1 + for key in wrong_hashes[calc_type][first_element][0]["stats"]["data"]: + wrong_hashes[calc_type][first_element][0]["stats"]["data"][key] *= 1.1 assert _potcar_stats_check(task_doc, wrong_hashes) @@ -159,7 +162,7 @@ def test_potcar_stats_check(test_dir): } for potcar in legacy_data["calcs_reversed"][0]["input"]["potcar_spec"] ] - legacy_task_doc = TaskDocument( + legacy_task_doc = TaskDoc( **{key: legacy_data[key] for key in legacy_data if key != "last_updated"} ) assert not _potcar_stats_check(legacy_task_doc, expected_hashes) @@ -180,7 +183,7 @@ def test_potcar_stats_check(test_dir): legacy_data["calcs_reversed"][0]["input"]["potcar_spec"][0][ "hash" ] = legacy_data["calcs_reversed"][0]["input"]["potcar_spec"][0]["hash"][:-1] - legacy_task_doc = TaskDocument( + legacy_task_doc = TaskDoc( **{key: legacy_data[key] for key in legacy_data if key != "last_updated"} ) assert _potcar_stats_check(legacy_task_doc, expected_hashes)