From 0276c4846f01478b3a3e4878cae9b455a740ec47 Mon Sep 17 00:00:00 2001 From: Andrew-S-Rosen Date: Mon, 12 Feb 2024 22:23:03 -0800 Subject: [PATCH 1/6] Make a (de)serializable `PMGAtoms` class --- pymatgen/io/ase.py | 119 +++++++++++++++++++++++++++++++++++-------- tests/io/test_ase.py | 41 +++++++++++++-- 2 files changed, 135 insertions(+), 25 deletions(-) diff --git a/pymatgen/io/ase.py b/pymatgen/io/ase.py index 4516ee04542..0460e032c0a 100644 --- a/pymatgen/io/ase.py +++ b/pymatgen/io/ase.py @@ -3,7 +3,6 @@ Atoms object and pymatgen Structure objects. """ - from __future__ import annotations import warnings @@ -12,16 +11,17 @@ from typing import TYPE_CHECKING import numpy as np - +from monty.json import MSONable from pymatgen.core.structure import Molecule, Structure if TYPE_CHECKING: - from numpy.typing import ArrayLike + from typing import Any + from numpy.typing import ArrayLike from pymatgen.core.structure import SiteCollection try: - from ase import Atoms + from ase.atoms import Atoms from ase.calculators.singlepoint import SinglePointDFTCalculator from ase.constraints import FixAtoms from ase.spacegroup import Spacegroup @@ -38,25 +38,52 @@ __date__ = "Mar 8, 2012" +class PMGAtoms(Atoms, MSONable): + """A custom subclass of ASE Atoms that is MSONable, including `.as_dict()` and `.from_dict()` methods.""" + + def as_dict(s: Atoms) -> dict[str, Any]: + from ase.io.jsonio import encode + + # Normally, we would want to this to be a wrapper around atoms.todict() with @module and + # @class key-value pairs inserted. However, atoms.todict()/atoms.fromdict() is not meant + # to be used in a round-trip fashion and does not work properly with constraints. + # See ASE issue #1387. + return {"@module": "ase.atoms", "@class": "Atoms", "atoms_json": encode(s)} + + def from_dict(d: dict[str, Any]) -> Atoms: + from ase.io.jsonio import decode + + # Normally, we would want to this to be a wrapper around atoms.fromdict() with @module and + # @class key-value pairs inserted. However, atoms.todict()/atoms.fromdict() is not meant + # to be used in a round-trip fashion and does not work properly with constraints. + # See ASE issue #1387. + return decode(d["atoms_json"]) + + # NOTE: If making notable changes to this class, please ping @Andrew-S-Rosen on GitHub. # There are some subtleties in here, particularly related to spins/charges. class AseAtomsAdaptor: """Adaptor serves as a bridge between ASE Atoms and pymatgen objects.""" @staticmethod - def get_atoms(structure: SiteCollection, **kwargs) -> Atoms: + def get_atoms( + structure: SiteCollection, msonable: bool = True, **kwargs + ) -> PMGAtoms | Atoms: """ Returns ASE Atoms object from pymatgen structure or molecule. Args: structure (SiteCollection): pymatgen Structure or Molecule + msonable (bool): Whether to return a PMGAtoms object, which is MSONable. **kwargs: passed to the ASE Atoms constructor Returns: Atoms: ASE Atoms object """ if not ase_loaded: - raise PackageNotFoundError("AseAtomsAdaptor requires the ASE package. Use `pip install ase`") + raise PackageNotFoundError( + "AseAtomsAdaptor requires the ASE package. Use `pip install ase`" + ) if not structure.is_ordered: raise ValueError("ASE Atoms only supports ordered structures") @@ -70,7 +97,12 @@ def get_atoms(structure: SiteCollection, **kwargs) -> Atoms: pbc = False cell = None - atoms = Atoms(symbols=symbols, positions=positions, pbc=pbc, cell=cell, **kwargs) + atoms = Atoms( + symbols=symbols, positions=positions, pbc=pbc, cell=cell, **kwargs + ) + + if msonable: + atoms = PMGAtoms(atoms) if "tags" in structure.site_properties: atoms.set_tags(structure.site_properties["tags"]) @@ -112,7 +144,9 @@ def get_atoms(structure: SiteCollection, **kwargs) -> Atoms: atoms.spin_multiplicity = structure.spin_multiplicity # Get the oxidation states from the structure - oxi_states: list[float | None] = [getattr(site.specie, "oxi_state", None) for site in structure] + oxi_states: list[float | None] = [ + getattr(site.specie, "oxi_state", None) for site in structure + ] # Read in selective dynamics if present. Note that the ASE FixAtoms class fixes (x,y,z), so # here we make sure that [False, False, False] or [True, True, True] is set for the site selective @@ -142,7 +176,13 @@ def get_atoms(structure: SiteCollection, **kwargs) -> Atoms: # Add any remaining site properties to the ASE Atoms object for prop in structure.site_properties: - if prop not in ["magmom", "charge", "final_magmom", "final_charge", "selective_dynamics"]: + if prop not in [ + "magmom", + "charge", + "final_magmom", + "final_charge", + "selective_dynamics", + ]: atoms.set_array(prop, np.array(structure.site_properties[prop])) if any(oxi_states): atoms.set_array("oxi_states", np.array(oxi_states)) @@ -154,7 +194,8 @@ def get_atoms(structure: SiteCollection, **kwargs) -> Atoms: # Regenerate Spacegroup object from `.todict()` representation if isinstance(atoms.info.get("spacegroup"), dict): atoms.info["spacegroup"] = Spacegroup( - atoms.info["spacegroup"]["number"], setting=atoms.info["spacegroup"].get("setting", 1) + atoms.info["spacegroup"]["number"], + setting=atoms.info["spacegroup"].get("setting", 1), ) # Atoms.calc <---> Structure.calc @@ -173,7 +214,9 @@ def get_atoms(structure: SiteCollection, **kwargs) -> Atoms: return atoms @staticmethod - def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs) -> Structure: + def get_structure( + atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs + ) -> Structure: """ Returns pymatgen structure from ASE Atoms. @@ -193,15 +236,24 @@ def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs) tags = atoms.get_tags() if atoms.has("tags") else None # Get the (final) site magmoms and charges from the ASE Atoms object. - if getattr(atoms, "calc", None) is not None and getattr(atoms.calc, "results", None) is not None: + if ( + getattr(atoms, "calc", None) is not None + and getattr(atoms.calc, "results", None) is not None + ): charges = atoms.calc.results.get("charges") magmoms = atoms.calc.results.get("magmoms") else: magmoms = charges = None # Get the initial magmoms and charges from the ASE Atoms object. - initial_charges = atoms.get_initial_charges() if atoms.has("initial_charges") else None - initial_magmoms = atoms.get_initial_magnetic_moments() if atoms.has("initial_magmoms") else None + initial_charges = ( + atoms.get_initial_charges() if atoms.has("initial_charges") else None + ) + initial_magmoms = ( + atoms.get_initial_magnetic_moments() + if atoms.has("initial_magmoms") + else None + ) oxi_states = atoms.get_array("oxi_states") if atoms.has("oxi_states") else None # If the ASE Atoms object has constraints, make sure that they are of the @@ -216,14 +268,22 @@ def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs) else: unsupported_constraint_type = True if unsupported_constraint_type: - warnings.warn("Only FixAtoms is supported by Pymatgen. Other constraints will not be set.", UserWarning) - sel_dyn = [[False] * 3 if atom.index in constraint_indices else [True] * 3 for atom in atoms] + warnings.warn( + "Only FixAtoms is supported by Pymatgen. Other constraints will not be set.", + UserWarning, + ) + sel_dyn = [ + [False] * 3 if atom.index in constraint_indices else [True] * 3 + for atom in atoms + ] else: sel_dyn = None # Atoms.info <---> Structure.properties # But first make sure `spacegroup` is JSON serializable - if atoms.info.get("spacegroup") and isinstance(atoms.info["spacegroup"], Spacegroup): + if atoms.info.get("spacegroup") and isinstance( + atoms.info["spacegroup"], Spacegroup + ): atoms.info["spacegroup"] = atoms.info["spacegroup"].todict() properties = getattr(atoms, "info", {}) @@ -232,7 +292,14 @@ def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs) if cls == Molecule: structure = cls(symbols, positions, properties=properties, **cls_kwargs) else: - structure = cls(lattice, symbols, positions, coords_are_cartesian=True, properties=properties, **cls_kwargs) + structure = cls( + lattice, + symbols, + positions, + coords_are_cartesian=True, + properties=properties, + **cls_kwargs, + ) # Atoms.calc <---> Structure.calc if calc := getattr(atoms, "calc", None): @@ -293,7 +360,9 @@ def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs) return structure @staticmethod - def get_molecule(atoms: Atoms, cls: type[Molecule] = Molecule, **cls_kwargs) -> Molecule: + def get_molecule( + atoms: Atoms, cls: type[Molecule] = Molecule, **cls_kwargs + ) -> Molecule: """ Returns pymatgen molecule from ASE Atoms. @@ -311,12 +380,20 @@ def get_molecule(atoms: Atoms, cls: type[Molecule] = Molecule, **cls_kwargs) -> try: charge = atoms.charge except AttributeError: - charge = round(np.sum(atoms.get_initial_charges())) if atoms.has("initial_charges") else 0 + charge = ( + round(np.sum(atoms.get_initial_charges())) + if atoms.has("initial_charges") + else 0 + ) try: spin_mult = atoms.spin_multiplicity except AttributeError: - spin_mult = round(np.sum(atoms.get_initial_magnetic_moments())) + 1 if atoms.has("initial_magmoms") else 1 + spin_mult = ( + round(np.sum(atoms.get_initial_magnetic_moments())) + 1 + if atoms.has("initial_magmoms") + else 1 + ) molecule.set_charge_and_spin(charge, spin_multiplicity=spin_mult) diff --git a/tests/io/test_ase.py b/tests/io/test_ase.py index 3a96b2d8422..7627ef88a37 100644 --- a/tests/io/test_ase.py +++ b/tests/io/test_ase.py @@ -8,7 +8,7 @@ from pymatgen.core import Composition, Lattice, Molecule, Structure from pymatgen.core.structure import StructureError -from pymatgen.io.ase import AseAtomsAdaptor +from pymatgen.io.ase import AseAtomsAdaptor, PMGAtoms from pymatgen.util.testing import TEST_FILES_DIR ase = pytest.importorskip("ase") @@ -147,7 +147,10 @@ def test_get_structure(): atoms = read(f"{TEST_FILES_DIR}/POSCAR_overlap") struct = AseAtomsAdaptor.get_structure(atoms) assert [s.species_string for s in struct] == atoms.get_chemical_symbols() - with pytest.raises(StructureError, match=f"sites are less than {struct.DISTANCE_TOLERANCE} Angstrom apart"): + with pytest.raises( + StructureError, + match=f"sites are less than {struct.DISTANCE_TOLERANCE} Angstrom apart", + ): struct = AseAtomsAdaptor.get_structure(atoms, validate_proximity=True) @@ -162,14 +165,22 @@ def test_get_structure_mag(): atoms = read(f"{TEST_FILES_DIR}/OUTCAR") structure = AseAtomsAdaptor.get_structure(atoms) - assert structure.site_properties["final_magmom"] == atoms.get_magnetic_moments().tolist() + assert ( + structure.site_properties["final_magmom"] + == atoms.get_magnetic_moments().tolist() + ) assert "magmom" not in structure.site_properties assert "initial_magmoms" not in structure.site_properties @pytest.mark.parametrize( "select_dyn", - [[True, True, True], [False, False, False], np.array([True, True, True]), np.array([False, False, False])], + [ + [True, True, True], + [False, False, False], + np.array([True, True, True]), + np.array([False, False, False]), + ], ) def test_get_structure_dyn(select_dyn): atoms = read(f"{TEST_FILES_DIR}/POSCAR") @@ -289,3 +300,25 @@ def test_back_forth_v4(): # test document can be jsanitized and decoded dct = jsanitize(molecule, strict=True, enum_values=True) MontyDecoder().process_decoded(dct) + + +def test_pmg_atoms(): + from ase.io.jsonio import encode + + atoms = read(f"{TEST_FILES_DIR}/OUTCAR") + ref = {"@module": "ase.atoms", "@class": "Atoms", "atoms_json": encode(atoms)} + pmgatoms = PMGAtoms(atoms) + assert pmgatoms.as_dict() == ref + assert PMGAtoms.from_dict(ref) == atoms + + +def test_pmg_atoms_v2(): + structure = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR") + + atoms = AseAtomsAdaptor.get_atoms(structure, msonable=True) + assert hasattr(atoms, "as_dict") + assert hasattr(atoms, "from_dict") + + atoms = AseAtomsAdaptor.get_atoms(structure, msonable=False) + assert not hasattr(atoms, "as_dict") + assert not hasattr(atoms, "from_dict") From 6a401a809d87e2c457f3ecbe1543c343f3146262 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Feb 2024 06:27:19 +0000 Subject: [PATCH 2/6] pre-commit auto-fixes --- pymatgen/io/ase.py | 62 +++++++++++--------------------------------- tests/io/test_ase.py | 5 +--- 2 files changed, 16 insertions(+), 51 deletions(-) diff --git a/pymatgen/io/ase.py b/pymatgen/io/ase.py index 0460e032c0a..34f08e5f35e 100644 --- a/pymatgen/io/ase.py +++ b/pymatgen/io/ase.py @@ -12,12 +12,14 @@ import numpy as np from monty.json import MSONable + from pymatgen.core.structure import Molecule, Structure if TYPE_CHECKING: from typing import Any from numpy.typing import ArrayLike + from pymatgen.core.structure import SiteCollection try: @@ -66,9 +68,7 @@ class AseAtomsAdaptor: """Adaptor serves as a bridge between ASE Atoms and pymatgen objects.""" @staticmethod - def get_atoms( - structure: SiteCollection, msonable: bool = True, **kwargs - ) -> PMGAtoms | Atoms: + def get_atoms(structure: SiteCollection, msonable: bool = True, **kwargs) -> PMGAtoms | Atoms: """ Returns ASE Atoms object from pymatgen structure or molecule. @@ -81,9 +81,7 @@ def get_atoms( Atoms: ASE Atoms object """ if not ase_loaded: - raise PackageNotFoundError( - "AseAtomsAdaptor requires the ASE package. Use `pip install ase`" - ) + raise PackageNotFoundError("AseAtomsAdaptor requires the ASE package. Use `pip install ase`") if not structure.is_ordered: raise ValueError("ASE Atoms only supports ordered structures") @@ -97,9 +95,7 @@ def get_atoms( pbc = False cell = None - atoms = Atoms( - symbols=symbols, positions=positions, pbc=pbc, cell=cell, **kwargs - ) + atoms = Atoms(symbols=symbols, positions=positions, pbc=pbc, cell=cell, **kwargs) if msonable: atoms = PMGAtoms(atoms) @@ -144,9 +140,7 @@ def get_atoms( atoms.spin_multiplicity = structure.spin_multiplicity # Get the oxidation states from the structure - oxi_states: list[float | None] = [ - getattr(site.specie, "oxi_state", None) for site in structure - ] + oxi_states: list[float | None] = [getattr(site.specie, "oxi_state", None) for site in structure] # Read in selective dynamics if present. Note that the ASE FixAtoms class fixes (x,y,z), so # here we make sure that [False, False, False] or [True, True, True] is set for the site selective @@ -214,9 +208,7 @@ def get_atoms( return atoms @staticmethod - def get_structure( - atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs - ) -> Structure: + def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs) -> Structure: """ Returns pymatgen structure from ASE Atoms. @@ -236,24 +228,15 @@ def get_structure( tags = atoms.get_tags() if atoms.has("tags") else None # Get the (final) site magmoms and charges from the ASE Atoms object. - if ( - getattr(atoms, "calc", None) is not None - and getattr(atoms.calc, "results", None) is not None - ): + if getattr(atoms, "calc", None) is not None and getattr(atoms.calc, "results", None) is not None: charges = atoms.calc.results.get("charges") magmoms = atoms.calc.results.get("magmoms") else: magmoms = charges = None # Get the initial magmoms and charges from the ASE Atoms object. - initial_charges = ( - atoms.get_initial_charges() if atoms.has("initial_charges") else None - ) - initial_magmoms = ( - atoms.get_initial_magnetic_moments() - if atoms.has("initial_magmoms") - else None - ) + initial_charges = atoms.get_initial_charges() if atoms.has("initial_charges") else None + initial_magmoms = atoms.get_initial_magnetic_moments() if atoms.has("initial_magmoms") else None oxi_states = atoms.get_array("oxi_states") if atoms.has("oxi_states") else None # If the ASE Atoms object has constraints, make sure that they are of the @@ -272,18 +255,13 @@ def get_structure( "Only FixAtoms is supported by Pymatgen. Other constraints will not be set.", UserWarning, ) - sel_dyn = [ - [False] * 3 if atom.index in constraint_indices else [True] * 3 - for atom in atoms - ] + sel_dyn = [[False] * 3 if atom.index in constraint_indices else [True] * 3 for atom in atoms] else: sel_dyn = None # Atoms.info <---> Structure.properties # But first make sure `spacegroup` is JSON serializable - if atoms.info.get("spacegroup") and isinstance( - atoms.info["spacegroup"], Spacegroup - ): + if atoms.info.get("spacegroup") and isinstance(atoms.info["spacegroup"], Spacegroup): atoms.info["spacegroup"] = atoms.info["spacegroup"].todict() properties = getattr(atoms, "info", {}) @@ -360,9 +338,7 @@ def get_structure( return structure @staticmethod - def get_molecule( - atoms: Atoms, cls: type[Molecule] = Molecule, **cls_kwargs - ) -> Molecule: + def get_molecule(atoms: Atoms, cls: type[Molecule] = Molecule, **cls_kwargs) -> Molecule: """ Returns pymatgen molecule from ASE Atoms. @@ -380,20 +356,12 @@ def get_molecule( try: charge = atoms.charge except AttributeError: - charge = ( - round(np.sum(atoms.get_initial_charges())) - if atoms.has("initial_charges") - else 0 - ) + charge = round(np.sum(atoms.get_initial_charges())) if atoms.has("initial_charges") else 0 try: spin_mult = atoms.spin_multiplicity except AttributeError: - spin_mult = ( - round(np.sum(atoms.get_initial_magnetic_moments())) + 1 - if atoms.has("initial_magmoms") - else 1 - ) + spin_mult = round(np.sum(atoms.get_initial_magnetic_moments())) + 1 if atoms.has("initial_magmoms") else 1 molecule.set_charge_and_spin(charge, spin_multiplicity=spin_mult) diff --git a/tests/io/test_ase.py b/tests/io/test_ase.py index 7627ef88a37..049f2d7a5fa 100644 --- a/tests/io/test_ase.py +++ b/tests/io/test_ase.py @@ -165,10 +165,7 @@ def test_get_structure_mag(): atoms = read(f"{TEST_FILES_DIR}/OUTCAR") structure = AseAtomsAdaptor.get_structure(atoms) - assert ( - structure.site_properties["final_magmom"] - == atoms.get_magnetic_moments().tolist() - ) + assert structure.site_properties["final_magmom"] == atoms.get_magnetic_moments().tolist() assert "magmom" not in structure.site_properties assert "initial_magmoms" not in structure.site_properties From 54bef0c93fd9f97e17772cd82b9334b39b1cbaf3 Mon Sep 17 00:00:00 2001 From: "Andrew S. Rosen" Date: Tue, 13 Feb 2024 08:22:22 -0800 Subject: [PATCH 3/6] Update ase.py Signed-off-by: Andrew S. Rosen --- pymatgen/io/ase.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymatgen/io/ase.py b/pymatgen/io/ase.py index 34f08e5f35e..3198417938a 100644 --- a/pymatgen/io/ase.py +++ b/pymatgen/io/ase.py @@ -40,7 +40,7 @@ __date__ = "Mar 8, 2012" -class PMGAtoms(Atoms, MSONable): +class MSONableAtoms(Atoms, MSONable): """A custom subclass of ASE Atoms that is MSONable, including `.as_dict()` and `.from_dict()` methods.""" def as_dict(s: Atoms) -> dict[str, Any]: @@ -68,13 +68,13 @@ class AseAtomsAdaptor: """Adaptor serves as a bridge between ASE Atoms and pymatgen objects.""" @staticmethod - def get_atoms(structure: SiteCollection, msonable: bool = True, **kwargs) -> PMGAtoms | Atoms: + def get_atoms(structure: SiteCollection, msonable: bool = True, **kwargs) -> MSONableAtoms | Atoms: """ Returns ASE Atoms object from pymatgen structure or molecule. Args: structure (SiteCollection): pymatgen Structure or Molecule - msonable (bool): Whether to return a PMGAtoms object, which is MSONable. + msonable (bool): Whether to return an MSONableAtoms object, which is MSONable. **kwargs: passed to the ASE Atoms constructor Returns: @@ -98,7 +98,7 @@ def get_atoms(structure: SiteCollection, msonable: bool = True, **kwargs) -> PMG atoms = Atoms(symbols=symbols, positions=positions, pbc=pbc, cell=cell, **kwargs) if msonable: - atoms = PMGAtoms(atoms) + atoms = MSONableAtoms(atoms) if "tags" in structure.site_properties: atoms.set_tags(structure.site_properties["tags"]) From c51f0f922e57f1cb34e8185284e8edda13aa37a4 Mon Sep 17 00:00:00 2001 From: "Andrew S. Rosen" Date: Tue, 13 Feb 2024 08:22:53 -0800 Subject: [PATCH 4/6] Update test_ase.py Signed-off-by: Andrew S. Rosen --- tests/io/test_ase.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/io/test_ase.py b/tests/io/test_ase.py index 049f2d7a5fa..9f388d554c5 100644 --- a/tests/io/test_ase.py +++ b/tests/io/test_ase.py @@ -8,7 +8,7 @@ from pymatgen.core import Composition, Lattice, Molecule, Structure from pymatgen.core.structure import StructureError -from pymatgen.io.ase import AseAtomsAdaptor, PMGAtoms +from pymatgen.io.ase import AseAtomsAdaptor, MSONableAtoms from pymatgen.util.testing import TEST_FILES_DIR ase = pytest.importorskip("ase") @@ -299,17 +299,17 @@ def test_back_forth_v4(): MontyDecoder().process_decoded(dct) -def test_pmg_atoms(): +def test_msonable_atoms(): from ase.io.jsonio import encode atoms = read(f"{TEST_FILES_DIR}/OUTCAR") ref = {"@module": "ase.atoms", "@class": "Atoms", "atoms_json": encode(atoms)} - pmgatoms = PMGAtoms(atoms) - assert pmgatoms.as_dict() == ref - assert PMGAtoms.from_dict(ref) == atoms + msonable_atoms = MSONableAtoms(atoms) + assert msonable_atoms.as_dict() == ref + assert MSONableAtoms.from_dict(ref) == atoms -def test_pmg_atoms_v2(): +def test_msonable_atoms_v2(): structure = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR") atoms = AseAtomsAdaptor.get_atoms(structure, msonable=True) From 59e62f4e49845ee78a4148c6afd1609f06d5d5ff Mon Sep 17 00:00:00 2001 From: "Andrew S. Rosen" Date: Tue, 13 Feb 2024 20:29:57 -0800 Subject: [PATCH 5/6] Update ase.py Signed-off-by: Andrew S. Rosen --- pymatgen/io/ase.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymatgen/io/ase.py b/pymatgen/io/ase.py index 3198417938a..4f377668c58 100644 --- a/pymatgen/io/ase.py +++ b/pymatgen/io/ase.py @@ -40,7 +40,7 @@ __date__ = "Mar 8, 2012" -class MSONableAtoms(Atoms, MSONable): +class MSONAtoms(Atoms, MSONable): """A custom subclass of ASE Atoms that is MSONable, including `.as_dict()` and `.from_dict()` methods.""" def as_dict(s: Atoms) -> dict[str, Any]: @@ -68,13 +68,13 @@ class AseAtomsAdaptor: """Adaptor serves as a bridge between ASE Atoms and pymatgen objects.""" @staticmethod - def get_atoms(structure: SiteCollection, msonable: bool = True, **kwargs) -> MSONableAtoms | Atoms: + def get_atoms(structure: SiteCollection, msonable: bool = True, **kwargs) -> MSONAtoms | Atoms: """ Returns ASE Atoms object from pymatgen structure or molecule. Args: structure (SiteCollection): pymatgen Structure or Molecule - msonable (bool): Whether to return an MSONableAtoms object, which is MSONable. + msonable (bool): Whether to return an MSONAtoms object, which is MSONable. **kwargs: passed to the ASE Atoms constructor Returns: @@ -98,7 +98,7 @@ def get_atoms(structure: SiteCollection, msonable: bool = True, **kwargs) -> MSO atoms = Atoms(symbols=symbols, positions=positions, pbc=pbc, cell=cell, **kwargs) if msonable: - atoms = MSONableAtoms(atoms) + atoms = MSONAtoms(atoms) if "tags" in structure.site_properties: atoms.set_tags(structure.site_properties["tags"]) From 04b644a2883a34400fa55ac0b81b345a277f8758 Mon Sep 17 00:00:00 2001 From: "Andrew S. Rosen" Date: Tue, 13 Feb 2024 20:30:15 -0800 Subject: [PATCH 6/6] Update test_ase.py Signed-off-by: Andrew S. Rosen --- tests/io/test_ase.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/io/test_ase.py b/tests/io/test_ase.py index 9f388d554c5..78284982fff 100644 --- a/tests/io/test_ase.py +++ b/tests/io/test_ase.py @@ -8,7 +8,7 @@ from pymatgen.core import Composition, Lattice, Molecule, Structure from pymatgen.core.structure import StructureError -from pymatgen.io.ase import AseAtomsAdaptor, MSONableAtoms +from pymatgen.io.ase import AseAtomsAdaptor, MSONAtoms from pymatgen.util.testing import TEST_FILES_DIR ase = pytest.importorskip("ase") @@ -304,9 +304,9 @@ def test_msonable_atoms(): atoms = read(f"{TEST_FILES_DIR}/OUTCAR") ref = {"@module": "ase.atoms", "@class": "Atoms", "atoms_json": encode(atoms)} - msonable_atoms = MSONableAtoms(atoms) + msonable_atoms = MSONAtoms(atoms) assert msonable_atoms.as_dict() == ref - assert MSONableAtoms.from_dict(ref) == atoms + assert MSONAtoms.from_dict(ref) == atoms def test_msonable_atoms_v2():