Skip to content

Commit

Permalink
add some types for io.exciting inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Jan 4, 2025
1 parent ddc94b7 commit 1b501fc
Showing 1 changed file with 41 additions and 52 deletions.
93 changes: 41 additions & 52 deletions src/pymatgen/io/exciting/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
from pymatgen.symmetry.bandstructure import HighSymmKpath

if TYPE_CHECKING:
from pathlib import Path
from typing import Literal

from numpy.typing import ArrayLike
from typing_extensions import Self

from pymatgen.util.typing import PathLike

__author__ = "Christian Vorwerk"
__copyright__ = "Copyright 2016"
__version__ = "1.0"
Expand All @@ -37,10 +40,10 @@ class ExcitingInput(MSONable):
Attributes:
structure (Structure): Associated Structure.
title (str): Optional title string.
lockxyz (numpy.ndarray): Lockxyz attribute for each site if available. A Nx3 array of booleans.
lockxyz (NDArray): Lockxyz attribute for each site if available. A Nx3 array of booleans.
"""

def __init__(self, structure: Structure, title=None, lockxyz=None):
def __init__(self, structure: Structure, title: str | None = None, lockxyz: ArrayLike | None = None):
"""
Args:
structure (Structure): Structure object.
Expand All @@ -52,7 +55,7 @@ def __init__(self, structure: Structure, title=None, lockxyz=None):
if structure.is_ordered:
site_properties = {}
if lockxyz:
site_properties["selective_dynamics"] = lockxyz
site_properties["selective_dynamics"] = np.asarray(lockxyz)
self.structure = structure.copy(site_properties=site_properties)
self.title = structure.formula if title is None else title
else:
Expand Down Expand Up @@ -164,7 +167,7 @@ def from_str(cls, data: str) -> Self:
return cls(structure_in, title_in, lockxyz)

@classmethod
def from_file(cls, filename: str | Path) -> Self:
def from_file(cls, filename: PathLike) -> Self:
"""
Args:
filename: Filename.
Expand All @@ -178,11 +181,11 @@ def from_file(cls, filename: str | Path) -> Self:

def write_etree(
self,
celltype,
cartesian=False,
bandstr=False,
celltype: Literal["unchanged", "conventional", "primitive"],
cartesian: bool = False,
bandstr: bool = False,
symprec: float = 0.4,
angle_tolerance=5,
angle_tolerance: float = 5,
**kwargs,
):
"""Write the exciting input parameters to an XML object.
Expand All @@ -191,19 +194,14 @@ def write_etree(
celltype (str): Choice of unit cell. Can be either the unit cell
from self.structure ("unchanged"), the conventional cell
("conventional"), or the primitive unit cell ("primitive").
cartesian (bool): Whether the atomic positions are provided in
Cartesian or unit-cell coordinates. Default is False.
bandstr (bool): Whether the bandstructure path along the
HighSymmKpath is included in the input file. Only supported if the
celltype is set to "primitive". Default is False.
symprec (float): Tolerance for the symmetry finding. Default is 0.4.
angle_tolerance (float): Angle tolerance for the symmetry finding.
Default is 5.
Default is 5.
**kwargs: Additional parameters for the input file.
Returns:
Expand Down Expand Up @@ -297,32 +295,27 @@ def write_etree(

def write_string(
self,
celltype,
cartesian=False,
bandstr=False,
celltype: Literal["unchanged", "conventional", "primitive"],
cartesian: bool = False,
bandstr: bool = False,
symprec: float = 0.4,
angle_tolerance=5,
angle_tolerance: float = 5,
**kwargs,
):
"""Write exciting input.xml as a string.
) -> str:
"""Convert exciting input.xml to a string.
Args:
celltype (str): Choice of unit cell. Can be either the unit cell
from self.structure ("unchanged"), the conventional cell
("conventional"), or the primitive unit cell ("primitive").
from self.structure ("unchanged"), the conventional cell
("conventional"), or the primitive unit cell ("primitive").
cartesian (bool): Whether the atomic positions are provided in
Cartesian or unit-cell coordinates. Default is False.
Cartesian or unit-cell coordinates. Default is False.
bandstr (bool): Whether the bandstructure path along the
HighSymmKpath is included in the input file. Only supported if the
celltype is set to "primitive". Default is False.
HighSymmKpath is included in the input file. Only supported if the
celltype is set to "primitive". Default is False.
symprec (float): Tolerance for the symmetry finding. Default is 0.4.
angle_tolerance (float): Angle tolerance for the symmetry finding.
Default is 5.
Default is 5.
**kwargs: Additional parameters for the input file.
Returns:
Expand All @@ -333,41 +326,37 @@ def write_string(
self._indent(root)
# output should be a string not a bytes object
string = ET.tostring(root).decode("UTF-8")

except Exception:
raise ValueError("Incorrect celltype!")

return string

def write_file(
self,
celltype,
filename,
cartesian=False,
bandstr=False,
celltype: Literal["unchanged", "conventional", "primitive"],
filename: str,
cartesian: bool = False,
bandstr: bool = False,
symprec: float = 0.4,
angle_tolerance=5,
angle_tolerance: float = 5,
**kwargs,
):
) -> None:
"""Write exciting input file.
Args:
celltype (str): Choice of unit cell. Can be either the unit cell
from self.structure ("unchanged"), the conventional cell
("conventional"), or the primitive unit cell ("primitive").
from self.structure ("unchanged"), the conventional cell
("conventional"), or the primitive unit cell ("primitive").
filename (str): Filename for exciting input.
cartesian (bool): Whether the atomic positions are provided in
Cartesian or unit-cell coordinates. Default is False.
Cartesian or unit-cell coordinates. Default is False.
bandstr (bool): Whether the bandstructure path along the
HighSymmKpath is included in the input file. Only supported if the
celltype is set to "primitive". Default is False.
HighSymmKpath is included in the input file. Only supported if the
celltype is set to "primitive". Default is False.
symprec (float): Tolerance for the symmetry finding. Default is 0.4.
angle_tolerance (float): Angle tolerance for the symmetry finding.
Default is 5.
Default is 5.
**kwargs: Additional parameters for the input file.
"""
try:
Expand All @@ -380,7 +369,7 @@ def write_file(

# Missing PrettyPrint option in the current version of xml.etree.cElementTree
@staticmethod
def _indent(elem, level=0):
def _indent(elem, level: int = 0) -> None:
"""
Helper method to indent elements.
Expand All @@ -401,7 +390,7 @@ def _indent(elem, level=0):
elif level and (not elem.tail or not elem.tail.strip()):
elem.tail = i

def _dicttoxml(self, paramdict_, element):
def _dicttoxml(self, paramdict_, element) -> None:
for key, value in paramdict_.items():
if isinstance(value, str) and key == "text()":
element.text = value
Expand Down

0 comments on commit 1b501fc

Please sign in to comment.