Skip to content

Commit

Permalink
add LRU cache to structure matcher (#4036)
Browse files Browse the repository at this point in the history
* add LRU cache to  _get_reduced_structure computations

* pre-commit auto-fixes

* make structure hashable via as_dict

* make structure hash recursive as_dict, change structure test to check hashability

* precommit

* moved computation using lru_cache out of class method to avoid memory leakage issue

* pre-commit auto-fixes

* fix structure matcher caching, fix a few tests (mcsqs wrong file destination and missing pytest approx in TestBSPlot)

* precommit

* add suggested SiteOrderedIStructure from @kbuma

* pre-commit auto-fixes

* add cast in eq for SiteOrderedIStructure to make mypy happy

* pre-commit auto-fixes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: esoteric-ephemera <[email protected]>
  • Loading branch information
3 people authored Sep 6, 2024
1 parent 54cdebc commit 51ea7de
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 8 deletions.
46 changes: 41 additions & 5 deletions src/pymatgen/analysis/structure_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

import abc
import itertools
from typing import TYPE_CHECKING
from functools import lru_cache
from typing import TYPE_CHECKING, cast

import numpy as np
from monty.json import MSONable

from pymatgen.core import Composition, Lattice, Structure, get_el_sp
from pymatgen.core import SETTINGS, Composition, IStructure, Lattice, Structure, get_el_sp
from pymatgen.optimization.linear_assignment import LinearAssignment
from pymatgen.util.coord import lattice_points_in_supercell
from pymatgen.util.coord_cython import is_coord_subset_pbc, pbc_shortest_vectors
Expand All @@ -29,6 +30,31 @@
__email__ = "[email protected]"
__status__ = "Production"
__date__ = "Dec 3, 2012"
LRU_CACHE_SIZE = SETTINGS.get("STRUCTURE_MATCHER_CACHE_SIZE", 300)


class SiteOrderedIStructure(IStructure):
"""
Imutable structure where the order of sites matters.
In caching reduced structures (see `StructureMatcher._get_reduced_structure`)
the order of input sites can be important.
In general, the order of sites in a structure does not matter, but when
a method like `StructureMatcher.get_s2_like_s1` tries to put s2's sites in
the same order as s1, the site order matters.
"""

def __eq__(self, other: object) -> bool:
"""Check for IStructure equality and same site order."""
if not super().__eq__(other):
return False
other = cast(SiteOrderedIStructure, other) # make mypy happy

return list(self.sites) == list(other.sites)

def __hash__(self) -> int:
"""Use the composition hash for now."""
return super().__hash__()


class AbstractComparator(MSONable, abc.ABC):
Expand Down Expand Up @@ -939,16 +965,26 @@ def _anonymous_match(
break
return matches

@classmethod
def _get_reduced_structure(cls, struct: Structure, primitive_cell: bool = True, niggli: bool = True) -> Structure:
"""Helper method to find a reduced structure."""
@staticmethod
@lru_cache(maxsize=LRU_CACHE_SIZE)
def _get_reduced_istructure(
struct: SiteOrderedIStructure, primitive_cell: bool = True, niggli: bool = True
) -> SiteOrderedIStructure:
"""Helper method to find a reduced imutable structure."""
reduced = struct.copy()
if niggli:
reduced = reduced.get_reduced_structure(reduction_algo="niggli")
if primitive_cell:
reduced = reduced.get_primitive_structure()
return reduced

@classmethod
def _get_reduced_structure(cls, struct: Structure, primitive_cell: bool = True, niggli: bool = True) -> Structure:
"""Helper method to find a reduced structure."""
return Structure.from_sites(
cls._get_reduced_istructure(SiteOrderedIStructure.from_sites(struct), primitive_cell, niggli)
)

def get_rms_anonymous(self, struct1, struct2):
"""
Performs an anonymous fitting, which allows distinct species in one
Expand Down
2 changes: 1 addition & 1 deletion tests/electronic_structure/test_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def test_bs_plot_data(self):

def test_get_ticks(self):
assert self.plotter.get_ticks()["label"][5] == "K", "wrong tick label"
assert self.plotter.get_ticks()["distance"][5] == 2.406607625322699, "wrong tick distance"
assert self.plotter.get_ticks()["distance"][5] == pytest.approx(2.406607625322699), "wrong tick distance"

# Minimal baseline testing for get_plot. not a true test. Just checks that
# it can actually execute.
Expand Down
4 changes: 2 additions & 2 deletions tests/transformations/test_advanced_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def test_apply_transformation(self):
@pytest.mark.skipif(not mcsqs_cmd, reason="mcsqs not present.")
class TestSQSTransformation(PymatgenTest):
def test_apply_transformation(self):
pzt_structs = loadfn(f"{TEST_FILES_DIR}/mcsqs/pzt-structs.json")
pzt_structs = loadfn(f"{TEST_FILES_DIR}/io/atat/mcsqs/pzt-structs.json")
trans = SQSTransformation(scaling=[2, 1, 1], search_time=0.01, instances=1, wd=0)
# nonsensical example just for testing purposes
struct = self.get_structure("Pb2TiZrO6").copy()
Expand All @@ -605,7 +605,7 @@ def test_apply_transformation(self):

def test_return_ranked_list(self):
# list of structures
pzt_structs_2 = loadfn(f"{TEST_FILES_DIR}/mcsqs/pzt-structs-2.json")
pzt_structs_2 = loadfn(f"{TEST_FILES_DIR}/io/atat/mcsqs/pzt-structs-2.json")

n_structs_expected = 1
sqs_kwargs = {"scaling": 2, "search_time": 0.01, "instances": 8, "wd": 0}
Expand Down

0 comments on commit 51ea7de

Please sign in to comment.