Skip to content

Commit

Permalink
Showing 11 changed files with 100 additions and 82 deletions.
2 changes: 1 addition & 1 deletion docs/_modules/pymatgen/analysis/defects/utils.html
Original file line number Diff line number Diff line change
@@ -642,7 +642,7 @@ <h1>Source code for pymatgen.analysis.defects.utils</h1><div class="highlight"><
<span class="sd"> determine if something are actually periodic boundary images of</span>
<span class="sd"> each other. Default is usually fine.</span>
<span class="sd"> max_cell_range (int): This is the range of periodic images to</span>
<span class="sd"> construct the Voronoi tesselation. A value of 1 means that we</span>
<span class="sd"> construct the Voronoi tessellation. A value of 1 means that we</span>
<span class="sd"> include all points from (x +- 1, y +- 1, z+- 1) in the</span>
<span class="sd"> voronoi construction. This is because the Voronoi poly</span>
<span class="sd"> extends beyond the standard unit cell because of PBC.</span>
2 changes: 1 addition & 1 deletion pymatgen/analysis/chemenv/utils/tests/test_graph_utils.py
Original file line number Diff line number Diff line change
@@ -696,7 +696,7 @@ def test_multigraph_cycle(self):
self.assertEqual(mgc.edge_indices, tuple([0, 1, 4, 0, 2, 2, 5, 3]))

# Testing all cases for a length-4 cycle
nodes_ref = tuple(FakeNodeWithEqLtMethods(inode) for inode in [0, 1, 2, 3])
nodes_ref = tuple(FakeNodeWithEqLtMethods(inode) for inode in range(4))
edges_ref = (3, 6, 9, 12)
for inodes, iedges in [
((0, 1, 2, 3), (3, 6, 9, 12)),
2 changes: 1 addition & 1 deletion pymatgen/core/composition.py
Original file line number Diff line number Diff line change
@@ -278,7 +278,7 @@ def almost_equals(self, other: Composition, rtol: float = 0.1, atol: float = 1e-
@property
def is_element(self) -> bool:
"""
True if composition is for an element.
True if composition is an element.
"""
return len(self) == 1

14 changes: 6 additions & 8 deletions pymatgen/core/tests/test_sites.py
Original file line number Diff line number Diff line change
@@ -122,10 +122,10 @@ def test_distance_from_point(self):

def test_distance_and_image(self):
other_site = PeriodicSite("Fe", np.array([1, 1, 1]), self.lattice)
(distance, image) = self.site.distance_and_image(other_site)
distance, image = self.site.distance_and_image(other_site)
assert round(abs(distance - 6.22494979899), 5) == 0
assert ([-1, -1, -1] == image).all()
(distance, image) = self.site.distance_and_image(other_site, [1, 0, 0])
distance, image = self.site.distance_and_image(other_site, [1, 0, 0])
assert round(abs(distance - 19.461500456028563), 5) == 0
# Test that old and new distance algo give the same ans for
# "standard lattices"
@@ -138,16 +138,16 @@ def test_distance_and_image(self):
site2 = PeriodicSite("Fe", np.array([0.99, 0.98, 0.97]), lattice)
assert get_distance_and_image_old(site1, site2)[0] > site1.distance_and_image(site2)[0]
site2 = PeriodicSite("Fe", np.random.rand(3), lattice)
(dist_old, jimage_old) = get_distance_and_image_old(site1, site2)
(dist_new, jimage_new) = site1.distance_and_image(site2)
dist_old, jimage_old = get_distance_and_image_old(site1, site2)
dist_new, jimage_new = site1.distance_and_image(site2)
assert dist_old - dist_new > -1e-8, "New distance algo should give smaller answers!"
assert (
not (abs(dist_old - dist_new) < 1e-8) ^ (jimage_old == jimage_new).all()
), "If old dist == new dist, images must be the same!"
latt = Lattice.from_parameters(3.0, 3.1, 10.0, 2.96, 2.0, 1.0)
site = PeriodicSite("Fe", [0.1, 0.1, 0.1], latt)
site2 = PeriodicSite("Fe", [0.99, 0.99, 0.99], latt)
(dist, img) = site.distance_and_image(site2)
dist, img = site.distance_and_image(site2)
assert round(abs(dist - 0.15495358379511573), 7) == 0
assert list(img) == [-11, 6, 0]

@@ -160,10 +160,8 @@ def test_is_periodic_image(self):
assert not self.site.is_periodic_image(other), "Different lattices should not be periodic images."

def test_equality(self):
other_site = PeriodicSite("Fe", np.array([1, 1, 1]), self.lattice)
assert self.site == self.site
assert not other_site == self.site
assert not self.site != self.site
other_site = PeriodicSite("Fe", np.array([1, 1, 1]), self.lattice)
assert other_site != self.site

def test_as_from_dict(self):
2 changes: 1 addition & 1 deletion pymatgen/core/tests/test_trajectory.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@ def setUp(self):
self.structures = xdatcar.structures

def _check_traj_equality(self, traj_1, traj_2):
if np.sum(np.square(np.subtract(traj_1.lattice, traj_2.lattice))) > 0.0001:
if not np.allclose(traj_1.lattice, traj_2.lattice):
return False

if traj_1.species != traj_2.species:
8 changes: 4 additions & 4 deletions pymatgen/core/trajectory.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@
import warnings
from fnmatch import fnmatch
from pathlib import Path
from typing import Any, Union
from typing import Any, Dict, List, Tuple, Union

import numpy as np
from monty.io import zopen
@@ -32,9 +32,9 @@
__version__ = "0.1"
__date__ = "Jun 29, 2022"

Vector3D = tuple[float, float, float]
Matrix3D = tuple[Vector3D, Vector3D, Vector3D]
SitePropsType = Union[list[dict[Any, list[Any]]], dict[Any, list[Any]]]
Vector3D = Tuple[float, float, float]
Matrix3D = Tuple[Vector3D, Vector3D, Vector3D]
SitePropsType = Union[List[Dict[Any, List[Any]]], Dict[Any, List[Any]]]


class Trajectory(MSONable):
37 changes: 13 additions & 24 deletions pymatgen/entries/compatibility.py
Original file line number Diff line number Diff line change
@@ -785,7 +785,9 @@ class MaterialsProjectCompatibility(CorrectionsList):
valid.
"""

def __init__(self, compat_type="Advanced", correct_peroxide=True, check_potcar_hash=False):
def __init__(
self, compat_type: str = "Advanced", correct_peroxide: bool = True, check_potcar_hash: bool = False
) -> None:
"""
Args:
compat_type: Two options, GGA or Advanced. GGA means all GGA+U
@@ -828,11 +830,11 @@ class MaterialsProject2020Compatibility(Compatibility):

def __init__(
self,
compat_type="Advanced",
correct_peroxide=True,
check_potcar_hash=False,
config_file=None,
):
compat_type: str = "Advanced",
correct_peroxide: bool = True,
check_potcar_hash: bool = False,
config_file: str = None,
) -> None:
"""
Args:
compat_type: Two options, GGA or Advanced. GGA means all GGA+U
@@ -883,12 +885,10 @@ def __init__(
# load corrections and uncertainties
if config_file:
if os.path.isfile(config_file):
self.config_file = config_file
self.config_file: str | None = config_file
c = loadfn(self.config_file)
else:
raise ValueError(
f"Custom MaterialsProject2020Compatibility config_file ({config_file}) does not exist."
)
raise ValueError(f"Custom MaterialsProject2020Compatibility {config_file=} does not exist.")
else:
self.config_file = None
c = loadfn(os.path.join(MODULE_DIR, "MP2020Compatibility.yaml"))
@@ -986,20 +986,9 @@ def get_adjustments(self, entry: AnyCompEntry):
"formulas, e.g., Li2O2."
)

common_peroxides = [
"Li2O2",
"Na2O2",
"K2O2",
"Cs2O2",
"Rb2O2",
"BeO2",
"MgO2",
"CaO2",
"SrO2",
"BaO2",
]
common_superoxides = ["LiO2", "NaO2", "KO2", "RbO2", "CsO2"]
ozonides = ["LiO3", "NaO3", "KO3", "NaO5"]
common_peroxides = "Li2O2 Na2O2 K2O2 Cs2O2 Rb2O2 BeO2 MgO2 CaO2 SrO2 BaO2".split()
common_superoxides = "LiO2 NaO2 KO2 RbO2 CsO2".split()
ozonides = "LiO3 NaO3 KO3 NaO5".split()

if rform in common_peroxides:
ox_type = "peroxide"
83 changes: 43 additions & 40 deletions pymatgen/io/res.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
Converting from and back to pymatgen objects is expected to be reversible, i.e. you
should get the same Structure or ComputedStructureEntry back. On the other hand, converting
from and back to a string/file is not garunteed to be reversible, i.e. a diff on the output
from and back to a string/file is not guaranteed to be reversible, i.e. a diff on the output
would not be empty. The difference should be limited to whitespace, float precision, and the
REM entries.
@@ -15,10 +15,11 @@
import re
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Callable, Literal
from typing import Any, Callable, Literal

import dateutil.parser # type: ignore
from monty.io import zopen
from monty.json import MSONable

from pymatgen.core.lattice import Lattice
from pymatgen.core.periodic_table import Element
@@ -41,16 +42,10 @@ class AirssTITL:
appearances: int

def __str__(self) -> str:
title_fmt = "TITL {:s} {:.2f} {:.4f} {:.5f} {:f} {:f} ({:s}) n - {:d}"
return title_fmt.format(
self.seed,
self.pressure,
self.volume,
self.energy,
self.integrated_spin_density,
self.integrated_absolute_spin_density,
self.spacegroup_label,
self.appearances,
return (
f"TITL {self.seed:s} {self.pressure:.2f} {self.volume:.4f} {self.energy:.5f} "
f"{self.integrated_spin_density:f} {self.integrated_absolute_spin_density:f} ({self.spacegroup_label:s}) "
f"n - {self.appearances:d}"
)


@@ -65,8 +60,10 @@ class ResCELL:
gamma: float

def __str__(self) -> str:
cell_fmt = "CELL {:.5f} {:.5f} {:.5f} {:.5f} {:.5f} {:.5f} {:.5f}"
return cell_fmt.format(self.unknown_field_1, self.a, self.b, self.c, self.alpha, self.beta, self.gamma)
return (
f"CELL {self.unknown_field_1:.5f} {self.a:.5f} {self.b:.5f} {self.c:.5f} "
f"{self.alpha:.5f} {self.beta:.5f} {self.gamma:.5f}"
)


@dataclass(frozen=True)
@@ -142,11 +139,11 @@ def __init__(self):
self.source: str = ""

def _parse_titl(self, line: str) -> AirssTITL | None:
"""Parses the TITL entry. Checks for airss values in the entry."""
"""Parses the TITL entry. Checks for AIRSS values in the entry."""
fields = line.split(maxsplit=6)
if len(fields) >= 6:
# this is probably an airss res file
seed, pressure, volume, energy, spin, absspin = fields[:6]
# this is probably an AIRSS res file
seed, pressure, volume, energy, spin, abs_spin = fields[:6]
spg, nap = "P1", "1"
if len(fields) == 7:
rest = fields[6]
@@ -157,10 +154,10 @@ def _parse_titl(self, line: str) -> AirssTITL | None:
nmin = rest.find("n -")
nap = rest[nmin + 4 :]
return AirssTITL(
seed, float(pressure), float(volume), float(energy), float(spin), float(absspin), spg, int(nap)
seed, float(pressure), float(volume), float(energy), float(spin), float(abs_spin), spg, int(nap)
)
else:
# there should at least be the first 6 fields if it's an airss res file
# there should at least be the first 6 fields if it's an AIRSS res file
# if it doesn't have them, then just stop looking
return None

@@ -243,7 +240,7 @@ def _parse_str(cls, source: str) -> Res:
return self._parse_txt()

@classmethod
def _parse_filename(cls, filename: str) -> Res:
def _parse_file(cls, filename: str) -> Res:
"""Parses the res file as a file."""
self = cls()
self.filename = filename
@@ -327,24 +324,24 @@ def write(self, filename: str) -> None:
return None


class ResProvider:
class ResProvider(MSONable):
"""
Provides access to elements of the res file in the form of familiar pymatgen objects.
"""

def __init__(self, res: Res):
def __init__(self, res: Res) -> None:
"""The :func:`from_str` and :func:`from_file` methods should be used instead of constructing this directly."""
self._res = res

@classmethod
def from_str(cls, string: str):
def from_str(cls, string: str) -> ResProvider:
"""Construct a Provider from a string."""
return cls(ResParser._parse_str(string))

@classmethod
def from_file(cls, filename: str):
def from_file(cls, filename: str) -> ResProvider:
"""Construct a Provider from a file."""
return cls(ResParser._parse_filename(filename))
return cls(ResParser._parse_file(filename))

@property
def rems(self) -> list[str]:
@@ -360,8 +357,8 @@ def lattice(self) -> Lattice:
@property
def sites(self) -> list[PeriodicSite]:
"""Construct a list of PeriodicSites from the res file."""
sfactag = self._res.SFAC
return [PeriodicSite(ion.specie, ion.pos, self.lattice) for ion in sfactag.ions]
sfac_tag = self._res.SFAC
return [PeriodicSite(ion.specie, ion.pos, self.lattice) for ion in sfac_tag.ions]

@property
def structure(self) -> Structure:
@@ -373,7 +370,7 @@ class AirssProvider(ResProvider):
"""
Provides access to the res file as does :class:`ResProvider`. This class additionally provides
access to fields in the TITL entry and various other fields found in the REM entries
that airss puts in the file. Values in the TITL entry that AIRSS could not get end up as 0.
that AIRSS puts in the file. Values in the TITL entry that AIRSS could not get end up as 0.
If the TITL entry is malformed, empty, or missing then attempting to construct this class
from a res file will raise a ResError.
@@ -403,14 +400,14 @@ def __init__(self, res: Res, parse_rems: Literal["gentle", "strict"] = "gentle")
self.parse_rems = parse_rems

@classmethod
def from_str(cls, string: str, parse_rems: Literal["gentle", "strict"] = "gentle"):
def from_str(cls, string: str, parse_rems: Literal["gentle", "strict"] = "gentle") -> AirssProvider:
"""Construct a Provider from a string."""
return cls(ResParser._parse_str(string), parse_rems)

@classmethod
def from_file(cls, filename: str, parse_rems: Literal["gentle", "strict"] = "gentle"):
def from_file(cls, filename: str, parse_rems: Literal["gentle", "strict"] = "gentle") -> AirssProvider:
"""Construct a Provider from a file."""
return cls(ResParser._parse_filename(filename), parse_rems)
return cls(ResParser._parse_file(filename), parse_rems)

@classmethod
def _parse_date(cls, string: str) -> datetime.date:
@@ -421,10 +418,10 @@ def _parse_date(cls, string: str) -> datetime.date:
date_string = match.group(0)
return dateutil.parser.parse(date_string) # type: ignore

def _raise_or_none(self, e: ParseError):
def _raise_or_none(self, err: ParseError) -> None:
if self.parse_rems != "strict":
return None
raise e
raise err

def get_run_start_info(self) -> tuple[datetime.date, str] | None:
"""
@@ -438,7 +435,7 @@ def get_run_start_info(self) -> tuple[datetime.date, str] | None:
date = self._parse_date(rem)
path = rem.split()[-1]
return date, path
return self._raise_or_none(ParseError("Could not find run started information."))
return self._raise_or_none(ParseError("Could not find run started information.")) # type: ignore

def get_castep_version(self) -> str | None:
"""
@@ -451,7 +448,7 @@ def get_castep_version(self) -> str | None:
if rem.strip().startswith("CASTEP"):
srem = rem.split()
return srem[1][:-1]
return self._raise_or_none(ParseError("Could not find castep version.")) # type: ignore
return self._raise_or_none(ParseError("Could not find CASTEP version.")) # type: ignore

def get_func_rel_disp(self) -> tuple[str, str, str] | None:
"""
@@ -468,7 +465,7 @@ def get_func_rel_disp(self) -> tuple[str, str, str] | None:

def get_cut_grid_gmax_fsbc(self) -> tuple[float, float, float, str] | None:
"""
Retirieves the cut-off energy, grid scale, Gmax, and finite basis set correction setting
Retrieves the cut-off energy, grid scale, Gmax, and finite basis set correction setting
from the REM entries.
Returns:
@@ -484,7 +481,7 @@ def get_mpgrid_offset_nkpts_spacing(
self,
) -> tuple[tuple[int, int, int], tuple[float, float, float], int, float] | None:
"""
Retrieves the MP grid, the grid offsets, number of kpoints, and maximim kpoint spacing.
Retrieves the MP grid, the grid offsets, number of kpoints, and maximum kpoint spacing.
Returns:
(MP grid), (offsets), No. kpts, max spacing)
@@ -509,7 +506,7 @@ def get_airss_version(self) -> tuple[str, datetime.date] | None:
date = self._parse_date(rem)
v = rem.split()[2]
return v, date
return self._raise_or_none(ParseError("Could not find line with airss version.")) # type: ignore
return self._raise_or_none(ParseError("Could not find line with AIRSS version.")) # type: ignore

def _get_compiler(self):
raise NotImplementedError()
@@ -558,12 +555,12 @@ def energy(self) -> float:

@property
def integrated_spin_density(self) -> float:
"""Corresponds to the last ``Integrated Spin Density`` in the castep file."""
"""Corresponds to the last ``Integrated Spin Density`` in the CASTEP file."""
return self._TITL.integrated_spin_density

@property
def integrated_absolute_spin_density(self) -> float:
"""Corresponds to the last ``Integrated |Spin Density|`` in the castep file."""
"""Corresponds to the last ``Integrated |Spin Density|`` in the CASTEP file."""
return self._TITL.integrated_absolute_spin_density

@property
@@ -590,6 +587,12 @@ def entry(self) -> ComputedStructureEntry:
"""
return ComputedStructureEntry(self.structure, self.energy, data={"rems": self.rems})

def as_dict(self, verbose: bool = True) -> dict[str, Any]:
"""Get dict with title fields, structure and rems of this AirssProvider."""
if verbose:
return super().as_dict()
return dict(**vars(self._res.TITL), structure=self.structure.as_dict(), rems=self.rems)


class ResIO:
"""
28 changes: 28 additions & 0 deletions pymatgen/io/tests/test_res.py
Original file line number Diff line number Diff line change
@@ -85,6 +85,34 @@ def test_raise(self, provider: AirssProvider):
with pytest.raises(ParseError):
prov.get_castep_version()

def test_as_dict(self, provider: AirssProvider):
verbose_dict = provider.as_dict(verbose=True)

assert sorted(verbose_dict) == ["@class", "@module", "@version", "parse_rems", "res"]

# test round-trip serialization/deserialization gives same dict
assert AirssProvider.from_dict(verbose_dict).as_dict() == verbose_dict

# non-verbose case
dct = provider.as_dict(verbose=False)
assert sorted(dct) == [
"appearances",
"energy",
"integrated_absolute_spin_density",
"integrated_spin_density",
"pressure",
"rems",
"seed",
"spacegroup_label",
"structure",
"volume",
]
assert dct["seed"] == "coc-115925-9326-14"
assert dct["energy"] == pytest.approx(-3904.2741)
assert dct["spacegroup_label"] == "R3"
assert dct["pressure"] == pytest.approx(15.0252)
assert dct["volume"] == pytest.approx(57.051984)


class TestStructureModule:
def test_structure_from_file(self):
2 changes: 1 addition & 1 deletion pymatgen/symmetry/analyzer.py
Original file line number Diff line number Diff line change
@@ -572,7 +572,7 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee
latt_type = self.get_lattice_type()
sorted_lengths = sorted(latt.abc)
sorted_dic = sorted(
({"vec": latt.matrix[i], "length": latt.abc[i], "orig_index": i} for i in [0, 1, 2]),
({"vec": latt.matrix[i], "length": latt.abc[i], "orig_index": i} for i in range(3)),
key=lambda k: k["length"],
)

2 changes: 1 addition & 1 deletion pymatgen/util/plotting.py
Original file line number Diff line number Diff line change
@@ -513,7 +513,7 @@ def get_ax_fig_plt(ax=None, **kwargs):

if ax is None:
fig = plt.figure(**kwargs)
ax = fig.add_subplot(1, 1, 1)
ax = fig.gca()
else:
fig = plt.gcf()

0 comments on commit 77f0845

Please sign in to comment.