From 33ca1e2676ff57189f20012ff62da2f4573ae30f Mon Sep 17 00:00:00 2001 From: hexagonrose Date: Tue, 3 Dec 2024 15:55:27 +0900 Subject: [PATCH] feat: add graph build test of ase and matscipy in pytest --- pyproject.toml | 10 +++ tests/unit_tests/test_data.py | 121 +++++++++++++++++++++++++++++++++- 2 files changed, 129 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 56cfa778..b0885244 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,3 +69,13 @@ exclude = ["tests*", "example_inputs*", ] log_cli = true log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" log_cli_date_format = "%Y-%m-%d %H:%M:%S" + +[tool.ruff] +line-length = 85 + +[tool.ruff.lint] +extend-select = ["E501"] + +[tool.ruff.format] +quote-style = "single" +docstring-code-format = true diff --git a/tests/unit_tests/test_data.py b/tests/unit_tests/test_data.py index c9a3ca51..6350dd36 100644 --- a/tests/unit_tests/test_data.py +++ b/tests/unit_tests/test_data.py @@ -10,28 +10,44 @@ import numpy as np import pytest import torch +from ase import Atoms from ase.build import bulk, molecule from torch_geometric.loader import DataLoader +import sevenn._keys as KEY import sevenn.train.dataload as dl import sevenn.train.graph_dataset as ds from sevenn._const import NUM_UNIV_ELEMENT from sevenn.atom_graph_data import AtomGraphData +from sevenn.util import model_from_checkpoint, pretrained_name_to_path cutoff = 4.0 +lattice_constant = 3.35 _samples = { 'bulk': bulk('NaCl', 'rocksalt', a=5.63), 'mol': molecule('H2O'), 'isolated': molecule('H'), + 'small_bulk': Atoms( + symbols='Cu', + positions=[ + (0, 0, 0), # Atom at the corner of the cube + ], + cell=[ + [lattice_constant, 0, 0], + [0, lattice_constant, 0], + [0, 0, lattice_constant], + ], + pbc=True, # Periodic boundary conditions + ), } -_nedges_c4 = {'bulk': 36, 'mol': 6, 'isolated': 0} +_nedges_c4 = {'bulk': 36, 'mol': 6, 'isolated': 0, 'small_bulk': 18} def get_atoms( - atoms_type: 
def _sort_edges_canonically(edge_src, edge_dst, edge_vec, shifts):
    """
    Reorder one neighbor-list result by (src, dst, shift_x, shift_y, shift_z).

    The sort keys are the discrete quantities this test compares with exact
    equality, so the resulting order does not depend on rounding noise in the
    float edge vectors.

    Returns the four arrays in the same order they were given, re-indexed.
    """
    # reshape keeps 2-D column indexing valid even if a backend hands back a
    # flat empty array when there are no edges (e.g. an isolated atom).
    edge_vec = np.asarray(edge_vec).reshape(-1, 3)
    shifts = np.asarray(shifts).reshape(-1, 3)
    edge_src = np.asarray(edge_src)
    edge_dst = np.asarray(edge_dst)

    # np.lexsort uses the LAST key as the primary key.
    order = np.lexsort(
        (shifts[:, 2], shifts[:, 1], shifts[:, 0], edge_dst, edge_src)
    )
    return edge_src[order], edge_dst[order], edge_vec[order], shifts[order]


def _predict_from_edges(model, atoms, pos, cell, edge_src, edge_dst, edge_vec, shifts):
    """
    Build an AtomGraphData from raw neighbor-list arrays and run ``model`` on it.

    Returns a tuple ``(total_energy, force, stress)`` taken from the model output.
    """
    atomic_numbers = atoms.get_atomic_numbers()
    vol = dl._correct_scalar(atoms.cell.volume)
    if vol == 0:
        # Non-periodic samples have a zero-volume cell; substitute machine
        # epsilon so the stored volume is never exactly zero.
        vol = np.array(np.finfo(float).eps)

    data = {
        KEY.NODE_FEATURE: atomic_numbers,
        KEY.ATOMIC_NUMBERS: atomic_numbers,
        KEY.POS: pos,
        KEY.EDGE_IDX: np.array([edge_src, edge_dst]),
        KEY.EDGE_VEC: edge_vec,
        KEY.CELL: cell,
        KEY.CELL_SHIFT: shifts,
        KEY.CELL_VOLUME: vol,
        KEY.NUM_ATOMS: dl._correct_scalar(len(atomic_numbers)),
    }
    data[KEY.INFO] = {}

    output = model(AtomGraphData.from_numpy_dict(data))
    return (
        output[KEY.PRED_TOTAL_ENERGY],
        output[KEY.PRED_FORCE],
        output[KEY.PRED_STRESS],
    )


@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated', 'small_bulk'])
def test_graph_build_ase_and_matscipy(atoms_type):
    """
    The ASE and matscipy graph builders must agree, edge for edge, and must
    lead to the same model predictions.

    Two things are checked for every sample structure:
      1. after putting both edge lists into a common order, sources,
         destinations and cell shifts are identical and edge vectors are
         numerically close;
      2. the pretrained '7net-0_11July2024' model gives matching energy,
         forces and stress when fed either (unsorted) graph.
    """
    atoms, _ = get_atoms(atoms_type, 'calc')
    atoms.rattle()
    pos = atoms.get_positions()
    cell = np.array(atoms.get_cell())
    pbc = atoms.get_pbc()

    # graph build check
    graph_ase = dl._graph_build_ase(cutoff, pbc, cell, pos)
    graph_matsci = dl._graph_build_matscipy(cutoff, pbc, cell, pos)

    src_ase, dst_ase, vec_ase, shift_ase = _sort_edges_canonically(*graph_ase)
    src_matsci, dst_matsci, vec_matsci, shift_matsci = _sort_edges_canonically(
        *graph_matsci
    )

    # Check the count first: a mismatch would otherwise surface as a
    # broadcasting error inside np.allclose instead of a readable failure.
    assert len(src_ase) == len(src_matsci), (
        f'{atoms_type}: ase found {len(src_ase)} edges, '
        f'matscipy found {len(src_matsci)}'
    )
    assert np.array_equal(src_ase, src_matsci), f'{atoms_type}: edge src differ'
    assert np.array_equal(dst_ase, dst_matsci), f'{atoms_type}: edge dst differ'
    assert np.array_equal(shift_ase, shift_matsci), (
        f'{atoms_type}: cell shifts differ'
    )
    assert np.allclose(vec_ase, vec_matsci), f'{atoms_type}: edge vectors differ'

    # energy test
    model, _ = model_from_checkpoint(pretrained_name_to_path('7net-0_11July2024'))
    model.eval()
    model.set_is_batch_data(False)

    # The model is fed each backend's edges in their native order, so this
    # also exercises invariance of the prediction to edge ordering.
    ase_energy, ase_force, ase_stress = _predict_from_edges(
        model, atoms, pos, cell, *graph_ase
    )
    matsci_energy, matsci_force, matsci_stress = _predict_from_edges(
        model, atoms, pos, cell, *graph_matsci
    )

    # Edge contributions are accumulated in a different order for each graph,
    # so bitwise-equal energies are not guaranteed; compare with a tolerance
    # like forces and stress.
    assert torch.allclose(ase_energy, matsci_energy, atol=1e-06), (
        f'{atoms_type}: energy {ase_energy} != {matsci_energy}'
    )
    assert torch.allclose(ase_force, matsci_force, atol=1e-06), (
        f'{atoms_type}: forces differ'
    )
    assert torch.allclose(ase_stress, matsci_stress), (
        f'{atoms_type}: stress differ'
    )