Skip to content

Commit

Permalink
Default check_stable to False in `PatchedPhaseDiagram.get_decomp_…
Browse files Browse the repository at this point in the history
…and_e_above_hull()` for speed (#2842)

* default check_stable to False in PatchedPhaseDiagram.get_decomp_and_e_above_hull() for speed

* clean up and types

* fix test_get_decomp_and_e_above_hull()

on pymatgen/analysis/tests/test_phase_diagram.py:766: AssertionError:
assert decomp_pd == decomp_ppd
E           assert {PDEntry : V4...370.4576: 1.0} == {PDEntry : V4...0000000000142}

* remove superfluous list comprehensions
  • Loading branch information
janosh authored Feb 12, 2023
1 parent 9a785ff commit 87a12cc
Show file tree
Hide file tree
Showing 12 changed files with 88 additions and 101 deletions.
50 changes: 23 additions & 27 deletions pymatgen/alchemy/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def __getattr__(self, name) -> Any:
def __len__(self) -> int:
return len(self.history)

def append_transformation(self, transformation, return_alternatives=False, clear_redo=True):
def append_transformation(
self, transformation, return_alternatives: bool = False, clear_redo: bool = True
) -> list[TransformedStructure] | None:
"""
Appends a transformation to the TransformedStructure.
Expand Down Expand Up @@ -129,30 +131,30 @@ def append_transformation(self, transformation, return_alternatives=False, clear
for x in ranked_list[1:]:
s = x.pop("structure")
actual_transformation = x.pop("transformation", transformation)
hdict = actual_transformation.as_dict()
hdict["input_structure"] = input_structure
hdict["output_parameters"] = x
h_dict = actual_transformation.as_dict()
h_dict["input_structure"] = input_structure
h_dict["output_parameters"] = x
self.final_structure = s
d = self.as_dict()
d["history"].append(hdict)
d["history"].append(h_dict)
d["final_structure"] = s.as_dict()
alts.append(TransformedStructure.from_dict(d))

x = ranked_list[0]
s = x.pop("structure")
actual_transformation = x.pop("transformation", transformation)
hdict = actual_transformation.as_dict()
hdict["input_structure"] = self.final_structure.as_dict()
hdict["output_parameters"] = x
self.history.append(hdict)
h_dict = actual_transformation.as_dict()
h_dict["input_structure"] = self.final_structure.as_dict()
h_dict["output_parameters"] = x
self.history.append(h_dict)
self.final_structure = s
return alts

s = transformation.apply_transformation(self.final_structure)
hdict = transformation.as_dict()
hdict["input_structure"] = self.final_structure.as_dict()
hdict["output_parameters"] = {}
self.history.append(hdict)
h_dict = transformation.as_dict()
h_dict["input_structure"] = self.final_structure.as_dict()
h_dict["output_parameters"] = {}
self.history.append(h_dict)
self.final_structure = s
return None

Expand All @@ -164,9 +166,9 @@ def append_filter(self, structure_filter: AbstractStructureFilter) -> None:
structure_filter (StructureFilter): A filter implementing the
AbstractStructureFilter API. Tells transmuter what structures to retain.
"""
hdict = structure_filter.as_dict()
hdict["input_structure"] = self.final_structure.as_dict()
self.history.append(hdict)
h_dict = structure_filter.as_dict()
h_dict["input_structure"] = self.final_structure.as_dict()
self.history.append(h_dict)

def extend_transformations(
self, transformations: list[AbstractTransformation], return_alternatives: bool = False
Expand All @@ -186,7 +188,7 @@ def extend_transformations(

def get_vasp_input(self, vasp_input_set: type[VaspInputSet] = MPRelaxSet, **kwargs) -> dict[str, Any]:
"""
Returns VASP input as a dict of vasp objects.
Returns VASP input as a dict of VASP objects.
Args:
vasp_input_set (pymatgen.io.vaspio_set.VaspInputSet): input set
Expand Down Expand Up @@ -219,13 +221,7 @@ def write_vasp_input(
json.dump(self.as_dict(), fp)

def __str__(self) -> str:
output = [
"Current structure",
"------------",
str(self.final_structure),
"\nHistory",
"------------",
]
output = ["Current structure", "------------", str(self.final_structure), "\nHistory", "------------"]
for h in self.history:
h.pop("input_structure", None)
output.append(str(h))
Expand Down Expand Up @@ -259,8 +255,8 @@ def structures(self) -> list[Structure]:
Copy of all structures in the TransformedStructure. A
structure is stored after every single transformation.
"""
hstructs = [Structure.from_dict(s["input_structure"]) for s in self.history if "input_structure" in s]
return hstructs + [self.final_structure]
h_structs = [Structure.from_dict(s["input_structure"]) for s in self.history if "input_structure" in s]
return h_structs + [self.final_structure]

@staticmethod
def from_cif_string(
Expand All @@ -274,7 +270,7 @@ def from_cif_string(
Args:
cif_string (str): Input cif string. Should contain only one
structure. For cifs containing multiple structures, please use
structure. For CIFs containing multiple structures, please use
CifTransmuter.
transformations (list[Transformation]): Sequence of transformations
to be applied to the input structure.
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/analysis/diffraction/tests/test_tem.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def test_TEM_dots(self):
points = c.generate_points(-2, 2)
structure = self.get_structure("Si")
dots = c.tem_dots(structure, points)
assert all([isinstance(x, tuple) for x in dots])
assert all(isinstance(x, tuple) for x in dots)

def test_get_pattern(self):
# All dependencies in get_pattern method are tested.
Expand Down
38 changes: 26 additions & 12 deletions pymatgen/analysis/phase_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -1722,77 +1722,91 @@ def get_equilibrium_reaction_energy(self, entry: Entry) -> float:

# NOTE the following functions are not implemented for PatchedPhaseDiagram

def get_decomp_and_e_above_hull(
self,
entry: PDEntry,
allow_negative: bool = False,
check_stable: bool = False,
on_error: Literal["raise", "warn", "ignore"] = "raise",
) -> tuple[dict[PDEntry, float], float] | tuple[None, None]:
"""Same as method on parent class PhaseDiagram except check_stable defaults to False
for speed. See https://github.com/materialsproject/pymatgen/issues/2840 for details.
"""
return super().get_decomp_and_e_above_hull(
entry=entry, allow_negative=allow_negative, check_stable=check_stable, on_error=on_error
)

def _get_facet_and_simplex(self):
"""
Not Implemented - See PhaseDiagram
"""
raise NotImplementedError("`_get_facet_and_simplex` not implemented for `PatchedPhaseDiagram`")
raise NotImplementedError("_get_facet_and_simplex() not implemented for PatchedPhaseDiagram")

def _get_all_facets_and_simplexes(self):
"""
Not Implemented - See PhaseDiagram
"""
raise NotImplementedError("`_get_all_facets_and_simplexes` not implemented for `PatchedPhaseDiagram`")
raise NotImplementedError("_get_all_facets_and_simplexes() not implemented for PatchedPhaseDiagram")

def _get_facet_chempots(self):
"""
Not Implemented - See PhaseDiagram
"""
raise NotImplementedError("`_get_facet_chempots` not implemented for `PatchedPhaseDiagram`")
raise NotImplementedError("_get_facet_chempots() not implemented for PatchedPhaseDiagram")

def _get_simplex_intersections(self):
"""
Not Implemented - See PhaseDiagram
"""
raise NotImplementedError("`_get_simplex_intersections` not implemented for `PatchedPhaseDiagram`")
raise NotImplementedError("_get_simplex_intersections() not implemented for PatchedPhaseDiagram")

def get_composition_chempots(self):
"""
Not Implemented - See PhaseDiagram
"""
raise NotImplementedError("`get_composition_chempots` not implemented for `PatchedPhaseDiagram`")
raise NotImplementedError("get_composition_chempots() not implemented for PatchedPhaseDiagram")

def get_all_chempots(self):
"""
Not Implemented - See PhaseDiagram
"""
raise NotImplementedError("`get_all_chempots` not implemented for `PatchedPhaseDiagram`")
raise NotImplementedError("get_all_chempots() not implemented for PatchedPhaseDiagram")

def get_transition_chempots(self):
"""
Not Implemented - See PhaseDiagram
"""
raise NotImplementedError("`get_transition_chempots` not implemented for `PatchedPhaseDiagram`")
raise NotImplementedError("get_transition_chempots() not implemented for PatchedPhaseDiagram")

def get_critical_compositions(self):
"""
Not Implemented - See PhaseDiagram
"""
raise NotImplementedError("`get_critical_compositions` not implemented for `PatchedPhaseDiagram`")
raise NotImplementedError("get_critical_compositions() not implemented for PatchedPhaseDiagram")

def get_element_profile(self):
"""
Not Implemented - See PhaseDiagram
"""
raise NotImplementedError("`get_element_profile` not implemented for `PatchedPhaseDiagram`")
raise NotImplementedError("get_element_profile() not implemented for PatchedPhaseDiagram")

def get_chempot_range_map(self):
"""
Not Implemented - See PhaseDiagram
"""
raise NotImplementedError("`get_chempot_range_map` not implemented for `PatchedPhaseDiagram`")
raise NotImplementedError("get_chempot_range_map() not implemented for PatchedPhaseDiagram")

def getmu_vertices_stability_phase(self):
"""
Not Implemented - See PhaseDiagram
"""
raise NotImplementedError("`getmu_vertices_stability_phase` not implemented for `PatchedPhaseDiagram`")
raise NotImplementedError("getmu_vertices_stability_phase() not implemented for PatchedPhaseDiagram")

def get_chempot_range_stability_phase(self):
"""
Not Implemented - See PhaseDiagram
"""
raise NotImplementedError("`get_chempot_range_stability_phase` not implemented for `PatchedPhaseDiagram`")
raise NotImplementedError("get_chempot_range_stability_phase() not implemented for PatchedPhaseDiagram")

def _get_pd_patch_for_space(self, space: frozenset[Element]) -> tuple[frozenset[Element], PhaseDiagram]:
"""
Expand Down
14 changes: 5 additions & 9 deletions pymatgen/analysis/tests/test_phase_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,10 @@ def test_all_entries_hulldata(self):
assert len(self.pd.all_entries_hulldata) == 490

def test_planar_inputs(self):
e1 = PDEntry("H", 0)
e2 = PDEntry("He", 0)
e3 = PDEntry("Li", 0)
e4 = PDEntry("Be", 0)
e5 = PDEntry("B", 0)
e6 = PDEntry("Rb", 0)
elems = ["H", "He", "Li", "Be", "B", "Rb"]
e1, e2, e3, e4, e5, e6 = (PDEntry(elem, 0) for elem in elems)

pd = PhaseDiagram([e1, e2, e3, e4, e5, e6], map(Element, ["Rb", "He", "B", "Be", "Li", "H"]))
pd = PhaseDiagram([e1, e2, e3, e4, e5, e6], map(Element, elems))

assert len(pd.facets) == 1

Expand All @@ -264,7 +260,7 @@ def test_get_e_above_hull(self):
for entry in self.pd.all_entries:
for entry in self.pd.stable_entries:
decomp, e_hull = self.pd.get_decomp_and_e_above_hull(entry)
assert e_hull < 1e-11, "Stable entries should have e above hull of zero!"
assert e_hull < 1e-11, "Stable entries should have e_above_hull of zero!"
assert decomp[entry] == 1, "Decomposition of stable entry should be itself."
else:
e_ah = self.pd.get_e_above_hull(entry)
Expand Down Expand Up @@ -762,7 +758,7 @@ def test_get_hull_energy(self):
def test_get_decomp_and_e_above_hull(self):
for entry in self.pd.stable_entries:
decomp_pd, e_above_hull_pd = self.pd.get_decomp_and_e_above_hull(entry)
decomp_ppd, e_above_hull_ppd = self.ppd.get_decomp_and_e_above_hull(entry)
decomp_ppd, e_above_hull_ppd = self.ppd.get_decomp_and_e_above_hull(entry, check_stable=True)
assert decomp_pd == decomp_ppd
assert np.isclose(e_above_hull_pd, e_above_hull_ppd)

Expand Down
4 changes: 2 additions & 2 deletions pymatgen/apps/borg/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_valid_paths(self, path):
paths, depending on what kind of data you are assimilating. For
example, if you are assimilating VASP runs, you are only interested in
directories containing vasprun.xml files. On the other hand, if you are
interested converting all POSCARs in a directory tree to cifs for
interested converting all POSCARs in a directory tree to CIFs for
example, you will want the file paths.
Args:
Expand All @@ -73,7 +73,7 @@ def get_valid_paths(self, path):

class VaspToComputedEntryDrone(AbstractDrone):
"""
VaspToEntryDrone assimilates directories containing vasp output to
VaspToEntryDrone assimilates directories containing VASP output to
ComputedEntry/ComputedStructureEntry objects. There are some restrictions
on the valid directory structures:
Expand Down
6 changes: 3 additions & 3 deletions pymatgen/core/tests/test_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ def test_get_symmetrically_distinct_miller_indices(self):
# Now try a trigonal system.
indices = get_symmetrically_distinct_miller_indices(self.trigBi, 2, return_hkil=True)
assert len(indices) == 17
assert all([len(hkl) == 4 for hkl in indices])
assert all(len(hkl) == 4 for hkl in indices)

def test_get_symmetrically_equivalent_miller_indices(self):
# Tests to see if the function obtains all equivalent hkl for cubic (100)
Expand All @@ -805,14 +805,14 @@ def test_get_symmetrically_equivalent_miller_indices(self):
(-1, 0, 0),
]
indices = get_symmetrically_equivalent_miller_indices(self.cscl, (1, 0, 0))
assert all([hkl in indices for hkl in indices001])
assert all(hkl in indices for hkl in indices001)

# Tests to see if it captures expanded Miller indices in the family e.g. (001) == (002)
hcp_indices_100 = get_symmetrically_equivalent_miller_indices(self.Mg, (1, 0, 0))
hcp_indices_200 = get_symmetrically_equivalent_miller_indices(self.Mg, (2, 0, 0))
assert len(hcp_indices_100) * 2 == len(hcp_indices_200)
assert len(hcp_indices_100) == 6
assert all([len(hkl) == 4 for hkl in hcp_indices_100])
assert all(len(hkl) == 4 for hkl in hcp_indices_100)

def test_generate_all_slabs(self):
slabs = generate_all_slabs(self.cscl, 1, 10, 10)
Expand Down
10 changes: 5 additions & 5 deletions pymatgen/core/tests/test_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,22 @@ def _get_lattice_species_and_coords(self):
return lattice, species, frac_coords

def test_single_index_slice(self):
assert all([self.traj[i] == self.structures[i] for i in range(0, len(self.structures), 19)])
assert all(self.traj[i] == self.structures[i] for i in range(0, len(self.structures), 19))

def test_slice(self):
sliced_traj = self.traj[2:99:3]
sliced_traj_from_structs = Trajectory.from_structures(self.structures[2:99:3])

if len(sliced_traj) == len(sliced_traj_from_structs):
assert all([sliced_traj[i] == sliced_traj_from_structs[i] for i in range(len(sliced_traj))])
assert all(sliced_traj[i] == sliced_traj_from_structs[i] for i in range(len(sliced_traj)))
else:
raise AssertionError

sliced_traj = self.traj[:-4:2]
sliced_traj_from_structs = Trajectory.from_structures(self.structures[:-4:2])

if len(sliced_traj) == len(sliced_traj_from_structs):
assert all([sliced_traj[i] == sliced_traj_from_structs[i] for i in range(len(sliced_traj))])
assert all(sliced_traj[i] == sliced_traj_from_structs[i] for i in range(len(sliced_traj)))
else:
raise AssertionError

Expand All @@ -65,7 +65,7 @@ def test_list_slice(self):
sliced_traj_from_structs = Trajectory.from_structures([self.structures[i] for i in [10, 30, 70]])

if len(sliced_traj) == len(sliced_traj_from_structs):
assert all([sliced_traj[i] == sliced_traj_from_structs[i] for i in range(len(sliced_traj))])
assert all(sliced_traj[i] == sliced_traj_from_structs[i] for i in range(len(sliced_traj)))
else:
raise AssertionError

Expand All @@ -74,7 +74,7 @@ def test_conversion(self):
self.traj.to_displacements()
self.traj.to_positions()

assert all([struct == self.structures[i] for i, struct in enumerate(self.traj)])
assert all(struct == self.structures[i] for i, struct in enumerate(self.traj))

def test_site_properties(self):
lattice, species, frac_coords = self._get_lattice_species_and_coords()
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/ext/tests/test_matproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def test_pourbaix_mpr_pipeline(self):
# Test against ion sets with multiple equivalent ions (Bi-V regression)
entries = self.rester.get_pourbaix_entries(["Bi", "V"])
pbx = PourbaixDiagram(entries, filter_solids=True, conc_dict={"Bi": 1e-8, "V": 1e-8})
assert all(["Bi" in entry.composition and "V" in entry.composition for entry in pbx.all_entries])
assert all("Bi" in entry.composition and "V" in entry.composition for entry in pbx.all_entries)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/io/tests/test_xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_from_string(self):
assert site.species_string == sp[i]
assert len(site.coords) == 3
if i == 0:
assert all([c == 0 for c in site.coords])
assert all(c == 0 for c in site.coords)

mol_str = """2
Random
Expand Down
Loading

0 comments on commit 87a12cc

Please sign in to comment.