Skip to content

Commit

Permalink
replace defaultdict with specific type
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Apr 21, 2024
1 parent e8055cd commit 2954c80
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(self, lambda_table=None, alpha=-5):

# create Z and px
self.Z = 0
self._px: defaultdict = defaultdict(float)
self._px: dict[Species, float] = defaultdict(float)
for s1, s2 in itertools.product(self.species, repeat=2):
value = math.exp(self.get_lambda(s1, s2))
self._px[s1] += value / 2
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/io/pwscf.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def __init__(self, filename):
filename (str): Filename.
"""
self.filename = filename
self.data: defaultdict = defaultdict(list)
self.data: dict[str, list[float] | float] = defaultdict(list)
self.read_pattern(PWOutput.patterns)
for k, v in self.data.items():
if k == "energies":
Expand Down
8 changes: 5 additions & 3 deletions pymatgen/io/vasp/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3803,9 +3803,11 @@ def __init__(self, filename):
headers.pop(0)
headers.pop(-1)

data: defaultdict = defaultdict(lambda: np.zeros((n_kpoints, n_bands, n_ions, len(headers))))
data: dict[Spin, np.ndarray] = defaultdict(
lambda: np.zeros((n_kpoints, n_bands, n_ions, len(headers)))
)

phase_factors: defaultdict = defaultdict(
phase_factors: dict[Spin, np.ndarray] = defaultdict(
lambda: np.full((n_kpoints, n_bands, n_ions, len(headers)), np.nan, dtype=np.complex128)
)
elif expr.match(line):
Expand Down Expand Up @@ -4270,7 +4272,7 @@ def __init__(self, filename):
lines = list(clean_lines(file.readlines()))
self._nspecs, self._natoms, self._ndisps = map(int, lines[0].split())
self._masses = map(float, lines[1].split())
self.data: defaultdict = defaultdict(dict)
self.data: dict[int, dict] = defaultdict(dict)
atom, disp = None, None
for idx, line in enumerate(lines[2:]):
v = list(map(float, line.split()))
Expand Down

0 comments on commit 2954c80

Please sign in to comment.