diff --git a/pymatgen/io/ase.py b/pymatgen/io/ase.py index 4516ee04542..4f377668c58 100644 --- a/pymatgen/io/ase.py +++ b/pymatgen/io/ase.py @@ -3,7 +3,6 @@ Atoms object and pymatgen Structure objects. """ - from __future__ import annotations import warnings @@ -12,16 +11,19 @@ from typing import TYPE_CHECKING import numpy as np +from monty.json import MSONable from pymatgen.core.structure import Molecule, Structure if TYPE_CHECKING: + from typing import Any + from numpy.typing import ArrayLike from pymatgen.core.structure import SiteCollection try: - from ase import Atoms + from ase.atoms import Atoms from ase.calculators.singlepoint import SinglePointDFTCalculator from ase.constraints import FixAtoms from ase.spacegroup import Spacegroup @@ -38,18 +40,41 @@ __date__ = "Mar 8, 2012" +class MSONAtoms(Atoms, MSONable): + """A custom subclass of ASE Atoms that is MSONable, including `.as_dict()` and `.from_dict()` methods.""" + + def as_dict(s: Atoms) -> dict[str, Any]: + from ase.io.jsonio import encode + + # Normally, we would want to this to be a wrapper around atoms.todict() with @module and + # @class key-value pairs inserted. However, atoms.todict()/atoms.fromdict() is not meant + # to be used in a round-trip fashion and does not work properly with constraints. + # See ASE issue #1387. + return {"@module": "ase.atoms", "@class": "Atoms", "atoms_json": encode(s)} + + def from_dict(d: dict[str, Any]) -> Atoms: + from ase.io.jsonio import decode + + # Normally, we would want to this to be a wrapper around atoms.fromdict() with @module and + # @class key-value pairs inserted. However, atoms.todict()/atoms.fromdict() is not meant + # to be used in a round-trip fashion and does not work properly with constraints. + # See ASE issue #1387. + return decode(d["atoms_json"]) + + # NOTE: If making notable changes to this class, please ping @Andrew-S-Rosen on GitHub. # There are some subtleties in here, particularly related to spins/charges. class AseAtomsAdaptor: """Adaptor serves as a bridge between ASE Atoms and pymatgen objects.""" @staticmethod - def get_atoms(structure: SiteCollection, **kwargs) -> Atoms: + def get_atoms(structure: SiteCollection, msonable: bool = True, **kwargs) -> MSONAtoms | Atoms: """ Returns ASE Atoms object from pymatgen structure or molecule. Args: structure (SiteCollection): pymatgen Structure or Molecule + msonable (bool): Whether to return an MSONAtoms object, which is MSONable. **kwargs: passed to the ASE Atoms constructor Returns: @@ -72,6 +97,9 @@ def get_atoms(structure: SiteCollection, **kwargs) -> Atoms: atoms = Atoms(symbols=symbols, positions=positions, pbc=pbc, cell=cell, **kwargs) + if msonable: + atoms = MSONAtoms(atoms) + if "tags" in structure.site_properties: atoms.set_tags(structure.site_properties["tags"]) @@ -142,7 +170,13 @@ def get_atoms(structure: SiteCollection, **kwargs) -> Atoms: # Add any remaining site properties to the ASE Atoms object for prop in structure.site_properties: - if prop not in ["magmom", "charge", "final_magmom", "final_charge", "selective_dynamics"]: + if prop not in [ + "magmom", + "charge", + "final_magmom", + "final_charge", + "selective_dynamics", + ]: atoms.set_array(prop, np.array(structure.site_properties[prop])) if any(oxi_states): atoms.set_array("oxi_states", np.array(oxi_states)) @@ -154,7 +188,8 @@ def get_atoms(structure: SiteCollection, **kwargs) -> Atoms: # Regenerate Spacegroup object from `.todict()` representation if isinstance(atoms.info.get("spacegroup"), dict): atoms.info["spacegroup"] = Spacegroup( - atoms.info["spacegroup"]["number"], setting=atoms.info["spacegroup"].get("setting", 1) + atoms.info["spacegroup"]["number"], + setting=atoms.info["spacegroup"].get("setting", 1), ) # Atoms.calc <---> Structure.calc @@ -216,7 +251,10 @@ def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs) else: unsupported_constraint_type = True if unsupported_constraint_type: - warnings.warn("Only FixAtoms is supported by Pymatgen. Other constraints will not be set.", UserWarning) + warnings.warn( + "Only FixAtoms is supported by Pymatgen. Other constraints will not be set.", + UserWarning, + ) sel_dyn = [[False] * 3 if atom.index in constraint_indices else [True] * 3 for atom in atoms] else: sel_dyn = None @@ -232,7 +270,14 @@ def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs) if cls == Molecule: structure = cls(symbols, positions, properties=properties, **cls_kwargs) else: - structure = cls(lattice, symbols, positions, coords_are_cartesian=True, properties=properties, **cls_kwargs) + structure = cls( + lattice, + symbols, + positions, + coords_are_cartesian=True, + properties=properties, + **cls_kwargs, + ) # Atoms.calc <---> Structure.calc if calc := getattr(atoms, "calc", None): diff --git a/tests/io/test_ase.py b/tests/io/test_ase.py index 3a96b2d8422..78284982fff 100644 --- a/tests/io/test_ase.py +++ b/tests/io/test_ase.py @@ -8,7 +8,7 @@ from pymatgen.core import Composition, Lattice, Molecule, Structure from pymatgen.core.structure import StructureError -from pymatgen.io.ase import AseAtomsAdaptor +from pymatgen.io.ase import AseAtomsAdaptor, MSONAtoms from pymatgen.util.testing import TEST_FILES_DIR ase = pytest.importorskip("ase") @@ -147,7 +147,10 @@ def test_get_structure(): atoms = read(f"{TEST_FILES_DIR}/POSCAR_overlap") struct = AseAtomsAdaptor.get_structure(atoms) assert [s.species_string for s in struct] == atoms.get_chemical_symbols() - with pytest.raises(StructureError, match=f"sites are less than {struct.DISTANCE_TOLERANCE} Angstrom apart"): + with pytest.raises( + StructureError, + match=f"sites are less than {struct.DISTANCE_TOLERANCE} Angstrom apart", + ): struct = AseAtomsAdaptor.get_structure(atoms, validate_proximity=True) @@ -169,7 +172,12 @@ def test_get_structure_mag(): @pytest.mark.parametrize( "select_dyn", - [[True, True, True], [False, False, False], np.array([True, True, True]), np.array([False, False, False])], + [ + [True, True, True], + [False, False, False], + np.array([True, True, True]), + np.array([False, False, False]), + ], ) def test_get_structure_dyn(select_dyn): atoms = read(f"{TEST_FILES_DIR}/POSCAR") @@ -289,3 +297,25 @@ def test_back_forth_v4(): # test document can be jsanitized and decoded dct = jsanitize(molecule, strict=True, enum_values=True) MontyDecoder().process_decoded(dct) + + +def test_msonable_atoms(): + from ase.io.jsonio import encode + + atoms = read(f"{TEST_FILES_DIR}/OUTCAR") + ref = {"@module": "ase.atoms", "@class": "Atoms", "atoms_json": encode(atoms)} + msonable_atoms = MSONAtoms(atoms) + assert msonable_atoms.as_dict() == ref + assert MSONAtoms.from_dict(ref) == atoms + + +def test_msonable_atoms_v2(): + structure = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR") + + atoms = AseAtomsAdaptor.get_atoms(structure, msonable=True) + assert hasattr(atoms, "as_dict") + assert hasattr(atoms, "from_dict") + + atoms = AseAtomsAdaptor.get_atoms(structure, msonable=False) + assert not hasattr(atoms, "as_dict") + assert not hasattr(atoms, "from_dict")