From bd93b2d1c8c3c464228bcd45fe49b301511f3ad5 Mon Sep 17 00:00:00 2001 From: Willis O'Leary <wolearyc@gmail.com> Date: Wed, 8 Jan 2025 23:41:48 +0000 Subject: [PATCH] fixed bug in which row data was not fully copied during dataset split --- .../core/preprocessing/atoms_to_graphs.py | 13 +++--- .../preprocessing/test_atoms_to_graphs.py | 43 ++++++++++++++++--- 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/src/fairchem/core/preprocessing/atoms_to_graphs.py b/src/fairchem/core/preprocessing/atoms_to_graphs.py index 27dc664a38..5453466aab 100644 --- a/src/fairchem/core/preprocessing/atoms_to_graphs.py +++ b/src/fairchem/core/preprocessing/atoms_to_graphs.py @@ -296,18 +296,19 @@ def convert_all( else: raise NotImplementedError - for atoms in tqdm( + for atoms_or_row in tqdm( atoms_iter, desc="converting ASE atoms collection to graphs", total=len(atoms_collection), unit=" systems", disable=disable_tqdm, ): - # check if atoms is an ASE Atoms object this for the ase.db case - data = self.convert( - atoms if isinstance(atoms, ase.atoms.Atoms) else atoms.toatoms() - ) - data_list.append(data) + if isinstance(atoms_or_row, ase.db.row.AtomsRow): + atoms = atoms_or_row.toatoms(add_additional_information=True) + atoms.info = atoms.info.get("data", {})  # "data" key may be absent + data_list.append(self.convert(atoms)) + else: + data_list.append(self.convert(atoms_or_row)) if collate_and_save: data, slices = collate(data_list) diff --git a/tests/core/preprocessing/test_atoms_to_graphs.py b/tests/core/preprocessing/test_atoms_to_graphs.py index 5c07a45243..0ec11b3951 100644 --- a/tests/core/preprocessing/test_atoms_to_graphs.py +++ b/tests/core/preprocessing/test_atoms_to_graphs.py @@ -11,11 +11,13 @@ import numpy as np import pytest +from ase import db from ase.io import read from ase.neighborlist import NeighborList, NewPrimitiveNeighborList -from fairchem.core.preprocessing import AtomsToGraphs from fairchem.core.modules.evaluator import min_diff +from fairchem.core.preprocessing import
AtomsToGraphs + @pytest.fixture(scope="class") def atoms_to_graphs_internals(request) -> None: @@ -44,7 +46,17 @@ def atoms_to_graphs_internals(request) -> None: r_distances=True, r_data_keys=["stiffness_tensor"], ) + test_object_only_stiffness = AtomsToGraphs( + max_neigh=200, + radius=6, + r_energy=False, + r_forces=False, + r_stress=False, + r_distances=False, + r_data_keys=["stiffness_tensor"], + ) request.cls.atg = test_object + request.cls.atg_only_stiffness = test_object_only_stiffness request.cls.atoms = atoms @@ -110,7 +122,9 @@ def test_convert(self) -> None: # positions act_positions = self.atoms.get_positions() positions = data.pos.numpy() - mindiff = min_diff(act_positions, positions, self.atoms.get_cell(), self.atoms.pbc) + mindiff = min_diff( + act_positions, positions, self.atoms.get_cell(), self.atoms.pbc + ) np.testing.assert_allclose(mindiff, 0, atol=1e-6) # check energy value act_energy = self.atoms.get_potential_energy(apply_constraint=False) @@ -130,9 +144,8 @@ def test_convert(self) -> None: self.atoms.info["stiffness_tensor"], stiffness_tensor ) - def test_convert_all(self) -> None: + def test_convert_all_atoms_list(self) -> None: # run convert_all on a list with one atoms object - # this does not test the atoms.db functionality atoms_list = [self.atoms] data_list = self.atg.convert_all(atoms_list) # check shape/values of features @@ -143,7 +156,9 @@ def test_convert_all(self) -> None: # positions act_positions = self.atoms.get_positions() positions = data_list[0].pos.numpy() - mindiff = min_diff(act_positions, positions, self.atoms.get_cell(), self.atoms.pbc) + mindiff = min_diff( + act_positions, positions, self.atoms.get_cell(), self.atoms.pbc + ) np.testing.assert_allclose(mindiff, 0, atol=1e-6) # check energy value act_energy = self.atoms.get_potential_energy(apply_constraint=False) @@ -162,3 +177,21 @@ def test_convert_all(self) -> None: np.testing.assert_allclose( self.atoms.info["stiffness_tensor"], stiffness_tensor ) + + def 
test_convert_all_ase_db(self, tmp_path_factory) -> None: + # Run convert_all on an ASE database rather than a list of Atoms objects. + + # A possible ASE bug makes this test awkward: the calculator attached by + # AtomsRow.toatoms() computes a stress tensor with the wrong shape, (9,), + # failing an assertion in atoms.get_stress(); hence the stiffness-only converter. + + tmp_path = tmp_path_factory.mktemp("convert_all_test") + with db.connect(tmp_path / "asedb.db") as database: + database.write(self.atoms, data=self.atoms.info) + data_list = self.atg_only_stiffness.convert_all(database) + + # additional data (i.e. stiffness_tensor) + stiffness_tensor = data_list[0].stiffness_tensor.numpy() + np.testing.assert_allclose( + self.atoms.info["stiffness_tensor"], stiffness_tensor + )