Skip to content

Commit

Permalink
fixed bug in which row data was not fully copied during dataset split
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Jan 8, 2025
1 parent 2a4543f commit bd93b2d
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 11 deletions.
13 changes: 7 additions & 6 deletions src/fairchem/core/preprocessing/atoms_to_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,18 +296,19 @@ def convert_all(
else:
raise NotImplementedError

for atoms in tqdm(
for atoms_or_row in tqdm(
atoms_iter,
desc="converting ASE atoms collection to graphs",
total=len(atoms_collection),
unit=" systems",
disable=disable_tqdm,
):
# check if atoms is an ASE Atoms object this for the ase.db case
data = self.convert(
atoms if isinstance(atoms, ase.atoms.Atoms) else atoms.toatoms()
)
data_list.append(data)
if isinstance(atoms_or_row, ase.db.row.AtomsRow):
atoms = atoms_or_row.toatoms(add_additional_information=True)
atoms.info = atoms.info["data"]
data_list.append(self.convert(atoms))
else:
data_list.append(self.convert(atoms_or_row))

if collate_and_save:
data, slices = collate(data_list)
Expand Down
43 changes: 38 additions & 5 deletions tests/core/preprocessing/test_atoms_to_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@

import numpy as np
import pytest
from ase import db
from ase.io import read
from ase.neighborlist import NeighborList, NewPrimitiveNeighborList

from fairchem.core.preprocessing import AtomsToGraphs
from fairchem.core.modules.evaluator import min_diff
from fairchem.core.preprocessing import AtomsToGraphs


@pytest.fixture(scope="class")
def atoms_to_graphs_internals(request) -> None:
Expand Down Expand Up @@ -44,7 +46,17 @@ def atoms_to_graphs_internals(request) -> None:
r_distances=True,
r_data_keys=["stiffness_tensor"],
)
test_object_only_stiffness = AtomsToGraphs(
max_neigh=200,
radius=6,
r_energy=False,
r_forces=False,
r_stress=False,
r_distances=False,
r_data_keys=["stiffness_tensor"],
)
request.cls.atg = test_object
request.cls.atg_only_stiffness = test_object_only_stiffness
request.cls.atoms = atoms


Expand Down Expand Up @@ -110,7 +122,9 @@ def test_convert(self) -> None:
# positions
act_positions = self.atoms.get_positions()
positions = data.pos.numpy()
mindiff = min_diff(act_positions, positions, self.atoms.get_cell(), self.atoms.pbc)
mindiff = min_diff(
act_positions, positions, self.atoms.get_cell(), self.atoms.pbc
)
np.testing.assert_allclose(mindiff, 0, atol=1e-6)
# check energy value
act_energy = self.atoms.get_potential_energy(apply_constraint=False)
Expand All @@ -130,9 +144,8 @@ def test_convert(self) -> None:
self.atoms.info["stiffness_tensor"], stiffness_tensor
)

def test_convert_all(self) -> None:
def test_convert_all_atoms_list(self) -> None:
# run convert_all on a list with one atoms object
# this does not test the atoms.db functionality
atoms_list = [self.atoms]
data_list = self.atg.convert_all(atoms_list)
# check shape/values of features
Expand All @@ -143,7 +156,9 @@ def test_convert_all(self) -> None:
# positions
act_positions = self.atoms.get_positions()
positions = data_list[0].pos.numpy()
mindiff = min_diff(act_positions, positions, self.atoms.get_cell(), self.atoms.pbc)
mindiff = min_diff(
act_positions, positions, self.atoms.get_cell(), self.atoms.pbc
)
np.testing.assert_allclose(mindiff, 0, atol=1e-6)
# check energy value
act_energy = self.atoms.get_potential_energy(apply_constraint=False)
Expand All @@ -162,3 +177,21 @@ def test_convert_all(self) -> None:
np.testing.assert_allclose(
self.atoms.info["stiffness_tensor"], stiffness_tensor
)

def test_convert_all_ase_db(self, tmp_path_factory) -> None:
# run convert_all on an ASE db object

# There is a possible bug in ASE which makes this test annoying to write.
# AtomsRow.toatoms() has a calculator attached that computes a stress tensor # with the wrong shape: (9,). This makes convert_all fail due to an assertion in
# atoms.get_stress().

tmp_path = tmp_path_factory.mktemp("convert_all_test")
with db.connect(tmp_path / "asedb.db") as database:
database.write(self.atoms, data=self.atoms.info)
data_list = self.atg_only_stiffness.convert_all(database)

# additional data (ie stiffness_tensor)
stiffness_tensor = data_list[0].stiffness_tensor.numpy()
np.testing.assert_allclose(
self.atoms.info["stiffness_tensor"], stiffness_tensor
)

0 comments on commit bd93b2d

Please sign in to comment.