Skip to content

Commit

Permalink
Add type annotations for io.vasp.inputs/optics (#3740)
Browse files Browse the repository at this point in the history
* some easy mypy fixes

* ruff check pymatgen/io/vasp --select ANN204 --unsafe-fixes --fix

* add type for io.vasp.help

* add timeout 60 sec for requests.get

* pre-commit auto-fixes

* add timeout 60 sec for requests.get

* fix default value of default_names

* finish poscar.from_str

* finish Poscar

* finish Incar

* temp save for potcarsingle

* put dunder methods close and to the top

* put properties close and to the top

* put properties close and to the top

* replace str with PathLike

* add types for optics

* suppress some overload

* remove None type from completely untyped classes

* pre-commit auto-fixes

* fix type error outside io.vasp

* check for None in Incar init

* ruff fix

* allow None

* fix types

* replace `defaultdict` with specific type

* revert accidental changes

* fix test

* fix tests

---------

Co-authored-by: Janosh Riebesell <[email protected]>
  • Loading branch information
DanielYang59 and janosh authored Apr 21, 2024
1 parent 666e1d7 commit ad6eafe
Show file tree
Hide file tree
Showing 36 changed files with 949 additions and 830 deletions.
2 changes: 1 addition & 1 deletion dev_scripts/update_pt_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def gen_iupac_ordering():
def add_electron_affinities():
"""Update the periodic table data file with electron affinities."""

req = requests.get("https://wikipedia.org/wiki/Electron_affinity_(data_page)")
req = requests.get("https://wikipedia.org/wiki/Electron_affinity_(data_page)", timeout=60)
soup = BeautifulSoup(req.text, "html.parser")
table = None
for table in soup.find_all("table"):
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/alchemy/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def __init__(self, existing_structures, structure_matcher=None, symprec=None):
structure matcher is used. A recommended value is 1e-5.
"""
self.symprec = symprec
self.structure_list = []
self.structure_list: list = []
self.existing_structures = existing_structures
if isinstance(structure_matcher, dict):
self.structure_matcher = StructureMatcher.from_dict(structure_matcher)
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/alchemy/transmuters.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def __init__(self, cif_string, transformations=None, primitive=True, extend_coll
"""
transformed_structures = []
lines = cif_string.split("\n")
structure_data = []
structure_data: list = []
read_data = False
for line in lines:
if re.match(r"^\s*data", line):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,7 @@ def __init__(self, permutations_safe_override=False, only_symbols=None):

self.minpoints = {}
self.maxpoints = {}
self.separations_cg = {}
self.separations_cg: dict[int, dict] = {}
for cn in range(6, 21):
for cg in self.get_implemented_geometries(coordination=cn):
if only_symbols is not None and cg.ce_symbol not in only_symbols:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2085,7 +2085,7 @@ def __init__(self, coord_geoms=None):
coord_geoms: coordination geometries to be added to the chemical environment.
"""
if coord_geoms is None:
self.coord_geoms = {}
self.coord_geoms: dict = {}
else:
raise NotImplementedError(
"Constructor for ChemicalEnvironments with the coord_geoms argument is not yet implemented"
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/analysis/ewald.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ def __init__(self, matrix, m_list, num_to_return=1, algo=ALGO_FAST):
if algo == EwaldMinimizer.ALGO_COMPLETE:
raise NotImplementedError("Complete algo not yet implemented for EwaldMinimizer")

self._output_lists = []
self._output_lists: list = []
# Tag that the recurse function looks at each level. If a method
# sets this to true it breaks the recursion and stops the search.
self._finished = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(self, lambda_table=None, alpha=-5):

# create Z and px
self.Z = 0
self._px = defaultdict(float)
self._px: dict[Species, float] = defaultdict(float)
for s1, s2 in itertools.product(self.species, repeat=2):
value = math.exp(self.get_lambda(s1, s2))
self._px[s1] += value / 2
Expand Down
4 changes: 2 additions & 2 deletions pymatgen/analysis/wulff.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def __init__(self, normal, e_surf, normal_pt, dual_pt, index, m_ind_orig, miller
self.index = index
self.m_ind_orig = m_ind_orig
self.miller = miller
self.points = []
self.outer_lines = []
self.points: list = []
self.outer_lines: list = []


class WulffShape:
Expand Down
43 changes: 24 additions & 19 deletions pymatgen/apps/battery/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import plotly.graph_objects as go

from pymatgen.util.plotting import pretty_plot

if TYPE_CHECKING:
from pymatgen.apps.battery.battery_abc import AbstractElectrode

__author__ = "Shyue Ping Ong"
__copyright__ = "Copyright 2012, The Materials Project"
__version__ = "0.1"
Expand All @@ -18,7 +23,7 @@
class VoltageProfilePlotter:
"""A plotter to make voltage profile plots for batteries."""

def __init__(self, xaxis="capacity", hide_negative=False):
def __init__(self, xaxis: str = "capacity", hide_negative: bool = False) -> None:
"""
Args:
xaxis: The quantity to use as the xaxis. Can be either
Expand All @@ -28,11 +33,11 @@ def __init__(self, xaxis="capacity", hide_negative=False):
- frac_x: the atomic fraction of the working ion
hide_negative: If True only plot the voltage steps above zero.
"""
self._electrodes = {}
self._electrodes: dict[str, AbstractElectrode] = {}
self.xaxis = xaxis
self.hide_negative = hide_negative

def add_electrode(self, electrode, label=None):
def add_electrode(self, electrode: AbstractElectrode, label: str | None = None) -> None:
"""Add an electrode to the plot.
Args:
Expand All @@ -41,11 +46,11 @@ def add_electrode(self, electrode, label=None):
label: A label for the electrode. If None, defaults to a counting
system, i.e. 'Electrode 1', 'Electrode 2', ...
"""
if not label:
if label is None:
label = f"Electrode {len(self._electrodes) + 1}"
self._electrodes[label] = electrode

def get_plot_data(self, electrode, term_zero=True):
def get_plot_data(self, electrode: AbstractElectrode, term_zero: bool = True) -> tuple[list, list]:
"""
Args:
electrode: Electrode object
Expand Down Expand Up @@ -82,7 +87,7 @@ def get_plot_data(self, electrode, term_zero=True):
y.append(0)
return x, y

def get_plot(self, width=8, height=8, term_zero=True, ax: plt.Axes = None):
def get_plot(self, width: float = 8, height: float = 8, term_zero: bool = True, ax: plt.Axes = None) -> plt.Axes:
"""Returns a plot object.
Args:
Expand Down Expand Up @@ -112,12 +117,12 @@ def get_plot(self, width=8, height=8, term_zero=True, ax: plt.Axes = None):

def get_plotly_figure(
self,
width=800,
height=600,
font_dict=None,
term_zero=True,
width: float = 800,
height: float = 600,
font_dict: dict | None = None,
term_zero: bool = True,
**kwargs,
):
) -> plt.Figure:
"""Return plotly Figure object.
Args:
Expand Down Expand Up @@ -163,28 +168,28 @@ def get_plotly_figure(
fig.update_layout(template="plotly_white", title_x=0.5)
return fig

def _choose_best_x_label(self, formula, work_ion_symbol):
def _choose_best_x_label(self, formula: set[str], work_ion_symbol: set[str]) -> str:
if self.xaxis in {"capacity", "capacity_grav"}:
return "Capacity (mAh/g)"
if self.xaxis == "capacity_vol":
return "Capacity (Ah/l)"

formula = formula.pop() if len(formula) == 1 else None
_formula: str | None = formula.pop() if len(formula) == 1 else None

work_ion_symbol = work_ion_symbol.pop() if len(work_ion_symbol) == 1 else None
_work_ion_symbol: str | None = work_ion_symbol.pop() if len(work_ion_symbol) == 1 else None

if self.xaxis == "x_form":
if formula and work_ion_symbol:
return f"x in {work_ion_symbol}<sub>x</sub>{formula}"
if _formula and _work_ion_symbol:
return f"x in {_work_ion_symbol}<sub>x</sub>{_formula}"
return "x Work Ion per Host F.U."

if self.xaxis == "frac_x":
if work_ion_symbol:
return f"Atomic Fraction of {work_ion_symbol}"
if _work_ion_symbol:
return f"Atomic Fraction of {_work_ion_symbol}"
return "Atomic Fraction of Working Ion"
raise RuntimeError("No xaxis label can be determined")

def show(self, width=8, height=6):
def show(self, width: float = 8, height: float = 6) -> None:
"""Show the voltage profile plot.
Args:
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/apps/borg/queen.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, drone, rootpath=None, number_of_drones=1):
"""
self._drone = drone
self._num_drones = number_of_drones
self._data = []
self._data: list = []

if rootpath:
if number_of_drones > 1:
Expand Down
6 changes: 3 additions & 3 deletions pymatgen/cli/pmg_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ def get_magnetizations(dir: str, ion_list: list[int]):
fullpath = os.path.join(parent, file)
outcar = Outcar(fullpath)
mags = outcar.magnetization
mags = [m["tot"] for m in mags]
all_ions = list(range(len(mags)))
_mags: list = [m["tot"] for m in mags]
all_ions = list(range(len(_mags)))
row.append(fullpath.lstrip("./"))
if ion_list:
all_ions = ion_list
for ion in all_ions:
row.append(str(mags[ion]))
row.append(str(_mags[ion]))
data.append(row)
if len(all_ions) > max_row:
max_row = len(all_ions)
Expand Down
4 changes: 2 additions & 2 deletions pymatgen/command_line/vampire_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ def __init__(

# Call Vampire
with subprocess.Popen(["vampire-serial"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) as process:
stdout, stderr = process.communicate()
stdout = stdout.decode()
_stdout, stderr = process.communicate()
stdout: str = _stdout.decode()

if stderr:
van_helsing = stderr.decode()
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def coincidents(self) -> list[Site]:
coincident_sites.append(self.sites[idx])
return coincident_sites

def __str__(self):
def __str__(self) -> str:
comp = self.composition
outs = [
f"Gb Summary ({comp.formula})",
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/core/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2860,7 +2860,7 @@ def from_str( # type: ignore[override]
elif fmt_low == "poscar":
from pymatgen.io.vasp import Poscar

struct = Poscar.from_str(input_string, default_names=False, read_velocities=False, **kwargs).structure
struct = Poscar.from_str(input_string, default_names=None, read_velocities=False, **kwargs).structure
elif fmt_low == "cssr":
from pymatgen.io.cssr import Cssr

Expand Down
4 changes: 2 additions & 2 deletions pymatgen/ext/cod.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_structure_by_id(self, cod_id, **kwargs):
Returns:
A Structure.
"""
response = requests.get(f"http://{self.url}/cod/{cod_id}.cif")
response = requests.get(f"http://{self.url}/cod/{cod_id}.cif", timeout=60)
return Structure.from_str(response.text, fmt="cif", **kwargs)

@requires(which("mysql"), "mysql must be installed to use this query.")
Expand All @@ -112,7 +112,7 @@ def get_structure_by_formula(self, formula: str, **kwargs) -> list[dict[str, str
for line in text:
if line.strip():
cod_id, sg = line.split("\t")
response = requests.get(f"http://www.crystallography.net/cod/{cod_id.strip()}.cif")
response = requests.get(f"http://www.crystallography.net/cod/{cod_id.strip()}.cif", timeout=60)
try:
struct = Structure.from_str(response.text, fmt="cif", **kwargs)
structures.append({"structure": struct, "cod_id": int(cod_id), "sg": sg})
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/ext/matproj_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1570,7 +1570,7 @@ def _check_get_download_info_url_by_task_id(self, prefix, task_ids) -> list[str]

@staticmethod
def _check_nomad_exist(url) -> bool:
response = requests.get(url=url)
response = requests.get(url=url, timeout=60)
if response.status_code != 200:
return False
content = json.loads(response.text)
Expand Down
4 changes: 2 additions & 2 deletions pymatgen/io/abinit/abitimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ def walk(cls, top=".", ext=".abo"):
def __init__(self):
"""Initialize object."""
# List of files that have been parsed.
self._filenames = []
self._filenames: list = []

# timers[filename][mpi_rank]
# contains the timer extracted from the file filename associated to the MPI rank mpi_rank.
self._timers = {}
self._timers: dict = {}

def __iter__(self):
return iter(self._timers)
Expand Down
8 changes: 4 additions & 4 deletions pymatgen/io/abinit/pseudos.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,10 +1033,10 @@ class PseudoParser:

def __init__(self):
# List of files that have been parsed successfully.
self._parsed_paths = []
self._parsed_paths: list = []

# List of files that could not been parsed.
self._wrong_paths = []
self._wrong_paths: list = []

def scan_directory(self, dirname, exclude_exts=(), exclude_fnames=()):
"""
Expand Down Expand Up @@ -1228,14 +1228,14 @@ def __init__(self, filepath):
# In this way, we know that only the first two bound states (with f and n attributes)
# should be used for constructing an initial guess for the wave functions.

self.valence_states = {}
self.valence_states: dict = {}
for node in root.find("valence_states"):
attrib = AttrDict(node.attrib)
assert attrib.id not in self.valence_states
self.valence_states[attrib.id] = attrib

# Parse the radial grids
self.rad_grids = {}
self.rad_grids: dict = {}
for node in root.findall("radial_grid"):
grid_params = node.attrib
gid = grid_params["id"]
Expand Down
14 changes: 7 additions & 7 deletions pymatgen/io/cp2k/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ def __init__(self, filename, verbose=False, auto_load=False):
# IO Info
self.filename = filename
self.dir = os.path.dirname(filename)
self.filenames = {}
self.filenames: dict = {}
self.parse_files()
self.data = {}
self.data: dict = {}

# Material properties/results
self.input = self.initial_structure = self.lattice = self.final_structure = self.composition = None
self.efermi = self.vbm = self.cbm = self.band_gap = None
self.structures = []
self.ionic_steps = []
self.structures: list = []
self.ionic_steps: list = []

# parse the basic run parameters always
self.parse_cp2k_params()
Expand Down Expand Up @@ -171,7 +171,7 @@ def calculation_type(self):
@property
def project_name(self) -> str:
"""What project name was used for this calculation."""
return self.data.get("global").get("project_name")
return self.data.get("global", {}).get("project_name")

@property
def spin_polarized(self) -> bool:
Expand Down Expand Up @@ -1259,12 +1259,12 @@ def parse_dos(self, dos_file=None, pdos_files=None, ldos_files=None):
self.data["cdos"] = CompleteDos(self.final_structure, total_dos=tdos, pdoss=_ldoss)

@property
def complete_dos(self) -> CompleteDos:
def complete_dos(self) -> CompleteDos | None:
"""Returns complete dos object if it has been parsed."""
return self.data.get("cdos")

@property
def band_structure(self) -> BandStructure:
def band_structure(self) -> BandStructure | None:
"""Returns band structure object if it has been parsed."""
return self.data.get("band_structure")

Expand Down
2 changes: 1 addition & 1 deletion pymatgen/io/pwscf.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def __init__(self, filename):
filename (str): Filename.
"""
self.filename = filename
self.data = defaultdict(list)
self.data: dict[str, list[float] | float] = defaultdict(list)
self.read_pattern(PWOutput.patterns)
for k, v in self.data.items():
if k == "energies":
Expand Down
4 changes: 2 additions & 2 deletions pymatgen/io/qchem/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ def get_str(self) -> str:
"""Return a string representation of an entire input file."""
return str(self)

def __str__(self):
combined_list = []
def __str__(self) -> str:
combined_list: list = []
# molecule section
combined_list.extend((self.molecule_template(self.molecule), "", self.rem_template(self.rem), ""))
# opt section
Expand Down
Loading

0 comments on commit ad6eafe

Please sign in to comment.