Skip to content

Commit

Permalink
continue adding missing links between TaskDoc and TaskDocument
Browse files Browse the repository at this point in the history
  • Loading branch information
esoteric-ephemera committed Apr 3, 2024
1 parent 6e3aa48 commit d6ca6b0
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 14 deletions.
4 changes: 0 additions & 4 deletions emmet-builders/emmet/builders/vasp/task_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,6 @@ def unary_function(self, item):
Args:
item (dict): a (projection of a) task doc
"""
if not item["output"].get("energy"):
# Default value required for pydantic typing. `TaskDoc.output.energy`
# must be float.
item["output"]["energy"] = 1e20
task_doc = TaskDoc(**item)
validation_doc = ValidationDoc.from_task_doc(
task_doc=task_doc,
Expand Down
19 changes: 13 additions & 6 deletions emmet-core/emmet/core/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import re
from collections import OrderedDict
from datetime import datetime
from datetime import datetime, UTC
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

Expand Down Expand Up @@ -100,7 +100,7 @@ class OutputDoc(BaseModel):
)

density: Optional[float] = Field(None, description="Density of in units of g/cc.")
energy: float = Field(..., description="Total Energy in units of eV.")
energy: Optional[float] = Field(None, description="Total Energy in units of eV.")
forces: Optional[List[List[float]]] = Field(
None, description="The force on each atom in units of eV/A^2."
)
Expand Down Expand Up @@ -353,7 +353,7 @@ class TaskDoc(StructureMetadata, extra="allow"):
None, description="Final output structure from the task"
)

task_type: Optional[Union[CalcType, TaskType]] = Field(
task_type: Optional[Union[TaskType, CalcType]] = Field(
None, description="The type of calculation."
)

Expand Down Expand Up @@ -417,7 +417,7 @@ class TaskDoc(StructureMetadata, extra="allow"):
)

last_updated: Optional[datetime] = Field(
datetime.utcnow(),
datetime.now(UTC),
description="Timestamp for the most recent calculation for this task document",
)

Expand All @@ -438,6 +438,13 @@ def model_post_init(self, __context: Any) -> None:
# Needed for compatibility with TaskDocument
if self.task_type is None:
self.task_type = task_type(self.orig_inputs)

if isinstance(self.task_type,CalcType):
# For a while, the TaskDoc.task_type was allowed to be a CalcType or TaskType
# For backwards compatibility with TaskDocument, ensure that isinstance(TaskDoc.task_type, TaskType)
temp = str(self.task_type).split(" ")
self._run_type = RunType(temp[0])
self.task_type = TaskType(" ".join(temp[1:]))

if self.structure is None:
self.structure = self.calcs_reversed[0].output.structure
Expand All @@ -451,7 +458,7 @@ def last_updated_dict_ok(cls, v) -> datetime:

@model_validator(mode="after")
def set_entry(self) -> datetime:
if not self.entry and self.calcs_reversed:
if not self.entry and self.calcs_reversed and getattr(self.calcs_reversed[0].output,"structure",None):
self.entry = self.get_entry(self.calcs_reversed, self.task_id)
return self

Expand Down Expand Up @@ -679,7 +686,7 @@ def get_entry(
"data": {
"oxide_type": oxide_type(calcs_reversed[0].output.structure),
"aspherical": calcs_reversed[0].input.parameters.get("LASPH", False),
"last_updated": str(datetime.utcnow()),
"last_updated": str(datetime.now(UTC)),
},
}
return ComputedEntry.from_dict(entry_dict)
Expand Down
9 changes: 6 additions & 3 deletions emmet-core/emmet/core/vasp/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from pydantic import ConfigDict, Field, ImportString
from pymatgen.core.structure import Structure
from pymatgen.io.vasp.inputs import Kpoints
from pymatgen.io.vasp.sets import VaspInputSet

from emmet.core.settings import EmmetSettings
Expand Down Expand Up @@ -145,7 +146,7 @@ def from_task_doc(
# Checking K-Points
# Calculations that use KSPACING will not have a .kpoints attr

if task_type != task_type.NSCF_Line:
if task_type != TaskType.NSCF_Line:
# Not validating k-point data for line-mode calculations as constructing
# the k-path is too costly for the builder and the uniform input set is used.

Expand Down Expand Up @@ -257,7 +258,7 @@ def _scf_upward_check(calcs_reversed, inputs, data, max_allowed_scf_gradient, wa

def _u_value_checks(task_doc, valid_input_set, warnings):
# NOTE: Reverting to old method of just using input.hubbards which is wrong in many instances
input_hubbards = task_doc.input.hubbards
input_hubbards = {} if task_doc.input.hubbards is None else task_doc.input.hubbards

if valid_input_set.incar.get("LDAU", False) or len(input_hubbards) > 0:
# Assemble required input_set LDAU params into dictionary
Expand Down Expand Up @@ -303,8 +304,10 @@ def _kpoint_check(input_set, inputs, calcs_reversed, data, kpts_tolerance):
input_dict = inputs

kpoints = input_dict.get("kpoints", {})
if not isinstance(kpoints, dict):
if isinstance(kpoints, Kpoints):
kpoints = kpoints.as_dict()
elif kpoints is None:
kpoints = {}
num_kpts = kpoints.get("nkpoints", 0) or np.prod(kpoints.get("kpoints", [1, 1, 1]))

data["kpts_ratio"] = num_kpts / valid_num_kpts
Expand Down
2 changes: 1 addition & 1 deletion emmet-core/tests/vasp/test_vasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def task_ldau(test_dir):
def test_ldau(task_ldau):
task_ldau.input.is_hubbard = True
assert task_ldau.run_type == RunType.GGA_U
assert ValidationDoc.from_task_doc(task_ldau).valid is False
assert not ValidationDoc.from_task_doc(task_ldau).valid


def test_ldau_validation(test_dir):
Expand Down

0 comments on commit d6ca6b0

Please sign in to comment.