From 1b501fca5f5fc09b6a055d4fde76b9515c844c32 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel)" Date: Sat, 4 Jan 2025 10:30:55 +0800 Subject: [PATCH] add some types for io.exciting inputs --- src/pymatgen/io/exciting/inputs.py | 93 +++++++++++++----------------- 1 file changed, 41 insertions(+), 52 deletions(-) diff --git a/src/pymatgen/io/exciting/inputs.py b/src/pymatgen/io/exciting/inputs.py index 7ef99482b03..46476994a84 100644 --- a/src/pymatgen/io/exciting/inputs.py +++ b/src/pymatgen/io/exciting/inputs.py @@ -16,10 +16,13 @@ from pymatgen.symmetry.bandstructure import HighSymmKpath if TYPE_CHECKING: - from pathlib import Path + from typing import Literal + from numpy.typing import ArrayLike from typing_extensions import Self + from pymatgen.util.typing import PathLike + __author__ = "Christian Vorwerk" __copyright__ = "Copyright 2016" __version__ = "1.0" @@ -37,10 +40,10 @@ class ExcitingInput(MSONable): Attributes: structure (Structure): Associated Structure. title (str): Optional title string. - lockxyz (numpy.ndarray): Lockxyz attribute for each site if available. A Nx3 array of booleans. + lockxyz (NDArray): Lockxyz attribute for each site if available. A Nx3 array of booleans. """ - def __init__(self, structure: Structure, title=None, lockxyz=None): + def __init__(self, structure: Structure, title: str | None = None, lockxyz: ArrayLike | None = None): """ Args: structure (Structure): Structure object. @@ -52,7 +55,7 @@ def __init__(self, structure: Structure, title=None, lockxyz=None): if structure.is_ordered: site_properties = {} if lockxyz: - site_properties["selective_dynamics"] = lockxyz + site_properties["selective_dynamics"] = np.asarray(lockxyz) self.structure = structure.copy(site_properties=site_properties) self.title = structure.formula if title is None else title else: @@ -164,7 +167,7 @@ def from_str(cls, data: str) -> Self: return cls(structure_in, title_in, lockxyz) @classmethod - def from_file(cls, filename: str | Path) -> Self: + def from_file(cls, filename: PathLike) -> Self: """ Args: filename: Filename. @@ -178,11 +181,11 @@ def from_file(cls, filename: str | Path) -> Self: def write_etree( self, - celltype, - cartesian=False, - bandstr=False, + celltype: Literal["unchanged", "conventional", "primitive"], + cartesian: bool = False, + bandstr: bool = False, symprec: float = 0.4, - angle_tolerance=5, + angle_tolerance: float = 5, **kwargs, ): """Write the exciting input parameters to an XML object. @@ -191,19 +194,14 @@ def write_etree( celltype (str): Choice of unit cell. Can be either the unit cell from self.structure ("unchanged"), the conventional cell ("conventional"), or the primitive unit cell ("primitive"). - cartesian (bool): Whether the atomic positions are provided in Cartesian or unit-cell coordinates. Default is False. - bandstr (bool): Whether the bandstructure path along the HighSymmKpath is included in the input file. Only supported if the celltype is set to "primitive". Default is False. - symprec (float): Tolerance for the symmetry finding. Default is 0.4. - angle_tolerance (float): Angle tolerance for the symmetry finding. - Default is 5. - + Default is 5. **kwargs: Additional parameters for the input file. Returns: @@ -297,32 +295,27 @@ def write_etree( def write_string( self, - celltype, - cartesian=False, - bandstr=False, + celltype: Literal["unchanged", "conventional", "primitive"], + cartesian: bool = False, + bandstr: bool = False, symprec: float = 0.4, - angle_tolerance=5, + angle_tolerance: float = 5, **kwargs, - ): - """Write exciting input.xml as a string. + ) -> str: + """Convert exciting input.xml to a string. Args: celltype (str): Choice of unit cell. Can be either the unit cell - from self.structure ("unchanged"), the conventional cell - ("conventional"), or the primitive unit cell ("primitive"). - + from self.structure ("unchanged"), the conventional cell + ("conventional"), or the primitive unit cell ("primitive"). cartesian (bool): Whether the atomic positions are provided in - Cartesian or unit-cell coordinates. Default is False. - + Cartesian or unit-cell coordinates. Default is False. bandstr (bool): Whether the bandstructure path along the - HighSymmKpath is included in the input file. Only supported if the - celltype is set to "primitive". Default is False. - + HighSymmKpath is included in the input file. Only supported if the + celltype is set to "primitive". Default is False. symprec (float): Tolerance for the symmetry finding. Default is 0.4. - angle_tolerance (float): Angle tolerance for the symmetry finding. - Default is 5. - + Default is 5. **kwargs: Additional parameters for the input file. Returns: @@ -333,41 +326,37 @@ def write_string( self._indent(root) # output should be a string not a bytes object string = ET.tostring(root).decode("UTF-8") + except Exception: raise ValueError("Incorrect celltype!") + return string def write_file( self, - celltype, - filename, - cartesian=False, - bandstr=False, + celltype: Literal["unchanged", "conventional", "primitive"], + filename: str, + cartesian: bool = False, + bandstr: bool = False, symprec: float = 0.4, - angle_tolerance=5, + angle_tolerance: float = 5, **kwargs, - ): + ) -> None: """Write exciting input file. Args: celltype (str): Choice of unit cell. Can be either the unit cell - from self.structure ("unchanged"), the conventional cell - ("conventional"), or the primitive unit cell ("primitive"). - + from self.structure ("unchanged"), the conventional cell + ("conventional"), or the primitive unit cell ("primitive"). filename (str): Filename for exciting input. - cartesian (bool): Whether the atomic positions are provided in - Cartesian or unit-cell coordinates. Default is False. - + Cartesian or unit-cell coordinates. Default is False. bandstr (bool): Whether the bandstructure path along the - HighSymmKpath is included in the input file. Only supported if the - celltype is set to "primitive". Default is False. - + HighSymmKpath is included in the input file. Only supported if the + celltype is set to "primitive". Default is False. symprec (float): Tolerance for the symmetry finding. Default is 0.4. - angle_tolerance (float): Angle tolerance for the symmetry finding. - Default is 5. - + Default is 5. **kwargs: Additional parameters for the input file. """ try: @@ -380,7 +369,7 @@ def write_file( # Missing PrettyPrint option in the current version of xml.etree.cElementTree @staticmethod - def _indent(elem, level=0): + def _indent(elem, level: int = 0) -> None: """ Helper method to indent elements. @@ -401,7 +390,7 @@ def _indent(elem, level=0): elif level and (not elem.tail or not elem.tail.strip()): elem.tail = i - def _dicttoxml(self, paramdict_, element): + def _dicttoxml(self, paramdict_, element) -> None: for key, value in paramdict_.items(): if isinstance(value, str) and key == "text()": element.text = value