From bd93b2d1c8c3c464228bcd45fe49b301511f3ad5 Mon Sep 17 00:00:00 2001
From: Willis O'Leary <wolearyc@gmail.com>
Date: Wed, 8 Jan 2025 23:41:48 +0000
Subject: [PATCH] fixed bug in which row data was not fully copied during
 dataset split

---
 .../core/preprocessing/atoms_to_graphs.py     | 13 +++---
 .../preprocessing/test_atoms_to_graphs.py     | 43 ++++++++++++++++---
 2 files changed, 45 insertions(+), 11 deletions(-)

diff --git a/src/fairchem/core/preprocessing/atoms_to_graphs.py b/src/fairchem/core/preprocessing/atoms_to_graphs.py
index 27dc664a38..5453466aab 100644
--- a/src/fairchem/core/preprocessing/atoms_to_graphs.py
+++ b/src/fairchem/core/preprocessing/atoms_to_graphs.py
@@ -296,18 +296,19 @@ def convert_all(
         else:
             raise NotImplementedError
 
-        for atoms in tqdm(
+        for atoms_or_row in tqdm(
             atoms_iter,
             desc="converting ASE atoms collection to graphs",
             total=len(atoms_collection),
             unit=" systems",
             disable=disable_tqdm,
         ):
-            # check if atoms is an ASE Atoms object this for the ase.db case
-            data = self.convert(
-                atoms if isinstance(atoms, ase.atoms.Atoms) else atoms.toatoms()
-            )
-            data_list.append(data)
+            if isinstance(atoms_or_row, ase.db.row.AtomsRow):
+                atoms = atoms_or_row.toatoms(add_additional_information=True)
+                atoms.info = atoms.info["data"]
+                data_list.append(self.convert(atoms))
+            else:
+                data_list.append(self.convert(atoms_or_row))
 
         if collate_and_save:
             data, slices = collate(data_list)
diff --git a/tests/core/preprocessing/test_atoms_to_graphs.py b/tests/core/preprocessing/test_atoms_to_graphs.py
index 5c07a45243..0ec11b3951 100644
--- a/tests/core/preprocessing/test_atoms_to_graphs.py
+++ b/tests/core/preprocessing/test_atoms_to_graphs.py
@@ -11,11 +11,13 @@
 
 import numpy as np
 import pytest
+from ase import db
 from ase.io import read
 from ase.neighborlist import NeighborList, NewPrimitiveNeighborList
 
-from fairchem.core.preprocessing import AtomsToGraphs
 from fairchem.core.modules.evaluator import min_diff
+from fairchem.core.preprocessing import AtomsToGraphs
+
 
 @pytest.fixture(scope="class")
 def atoms_to_graphs_internals(request) -> None:
@@ -44,7 +46,17 @@ def atoms_to_graphs_internals(request) -> None:
         r_distances=True,
         r_data_keys=["stiffness_tensor"],
     )
+    test_object_only_stiffness = AtomsToGraphs(
+        max_neigh=200,
+        radius=6,
+        r_energy=False,
+        r_forces=False,
+        r_stress=False,
+        r_distances=False,
+        r_data_keys=["stiffness_tensor"],
+    )
     request.cls.atg = test_object
+    request.cls.atg_only_stiffness = test_object_only_stiffness
     request.cls.atoms = atoms
 
 
@@ -110,7 +122,9 @@ def test_convert(self) -> None:
         # positions
         act_positions = self.atoms.get_positions()
         positions = data.pos.numpy()
-        mindiff = min_diff(act_positions, positions, self.atoms.get_cell(), self.atoms.pbc)        
+        mindiff = min_diff(
+            act_positions, positions, self.atoms.get_cell(), self.atoms.pbc
+        )
         np.testing.assert_allclose(mindiff, 0, atol=1e-6)
         # check energy value
         act_energy = self.atoms.get_potential_energy(apply_constraint=False)
@@ -130,9 +144,8 @@ def test_convert(self) -> None:
             self.atoms.info["stiffness_tensor"], stiffness_tensor
         )
 
-    def test_convert_all(self) -> None:
+    def test_convert_all_atoms_list(self) -> None:
         # run convert_all on a list with one atoms object
-        # this does not test the atoms.db functionality
         atoms_list = [self.atoms]
         data_list = self.atg.convert_all(atoms_list)
         # check shape/values of features
@@ -143,7 +156,9 @@ def test_convert_all(self) -> None:
         # positions
         act_positions = self.atoms.get_positions()
         positions = data_list[0].pos.numpy()
-        mindiff = min_diff(act_positions, positions, self.atoms.get_cell(), self.atoms.pbc)        
+        mindiff = min_diff(
+            act_positions, positions, self.atoms.get_cell(), self.atoms.pbc
+        )
         np.testing.assert_allclose(mindiff, 0, atol=1e-6)
         # check energy value
         act_energy = self.atoms.get_potential_energy(apply_constraint=False)
@@ -162,3 +177,21 @@ def test_convert_all(self) -> None:
         np.testing.assert_allclose(
             self.atoms.info["stiffness_tensor"], stiffness_tensor
         )
+
+    def test_convert_all_ase_db(self, tmp_path_factory) -> None:
+        # run convert_all on an ASE db object
+
+        # There is a possible bug in ASE which makes this test annoying to write.
+        # AtomsRow.toatoms() has a calculator attached that computes a stress tensor # with the wrong shape: (9,). This makes convert_all fail due to an assertion in
+        # atoms.get_stress().
+
+        tmp_path = tmp_path_factory.mktemp("convert_all_test")
+        with db.connect(tmp_path / "asedb.db") as database:
+            database.write(self.atoms, data=self.atoms.info)
+            data_list = self.atg_only_stiffness.convert_all(database)
+
+        # additional data (ie stiffness_tensor)
+        stiffness_tensor = data_list[0].stiffness_tensor.numpy()
+        np.testing.assert_allclose(
+            self.atoms.info["stiffness_tensor"], stiffness_tensor
+        )