From cc4df9b1690e4181524e50aa2b804d93f611c477 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 22 Nov 2024 08:45:03 -0800 Subject: [PATCH 1/5] fix: correcting logic for edge calculation --- matsciml/datasets/transforms/pbc.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/matsciml/datasets/transforms/pbc.py b/matsciml/datasets/transforms/pbc.py index eedaacbf..a363ce91 100644 --- a/matsciml/datasets/transforms/pbc.py +++ b/matsciml/datasets/transforms/pbc.py @@ -187,9 +187,12 @@ def __call__(self, data: DataDict) -> DataDict: raise RuntimeError(f"Requested backend f{self.backend} not available.") data.update(graph_props) if not self.allow_self_loops: - mask = data["src_nodes"] == data["dst_nodes"] + # this looks for src and dst nodes that are the same, i.e. self-loops + loop_mask = data["src_nodes"] == data["dst_nodes"] # only mask out self-loops within the same image - mask &= data["unit_offsets"].sum(dim=-1) == 0 + image_mask = data["images"].sum(dim=-1) == 0 + # we negate the mask because we want to *exclude* what we've found + mask = ~torch.logical_and(loop_mask, image_mask) # apply mask to each of the tensors that depend on edges for key in ["src_nodes", "dst_nodes", "images", "unit_offsets", "offsets"]: data[key] = data[key][mask] From 0055c8ccf0f0b77bcb2e7403fba32ebb525f5684 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 22 Nov 2024 09:05:34 -0800 Subject: [PATCH 2/5] test: added unit test devoted to checking self-loop --- .../datasets/transforms/tests/test_pbc.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/matsciml/datasets/transforms/tests/test_pbc.py b/matsciml/datasets/transforms/tests/test_pbc.py index 2878f614..8c98a4e8 100644 --- a/matsciml/datasets/transforms/tests/test_pbc.py +++ b/matsciml/datasets/transforms/tests/test_pbc.py @@ -85,3 +85,28 @@ def test_periodic_generation( for index, count in counts.items(): if not self_loops: assert count < 10, print(f"Node {index} has too many counts. {src_nodes}") + + +def test_self_loop_condition(): + """Tests for whether the self-loops exclusion is behaving as intended""" + coords = torch.FloatTensor(alumina.cart_coords) + cell = torch.FloatTensor(alumina.lattice.matrix) + num_atoms = coords.size(0) + atomic_numbers = torch.ones(num_atoms) + packed_data = {"pos": coords, "cell": cell, "atomic_numbers": atomic_numbers} + no_loop_transform = PeriodicPropertiesTransform( + cutoff_radius=6.0, backend="ase", allow_self_loops=False + ) + no_loop_result = no_loop_transform(packed_data) + # since it's no self loops this sum should be zero + same_node = no_loop_result["src_nodes"] == no_loop_result["dst_nodes"] + same_image = no_loop_result["images"].sum(dim=-1) == 0 + assert torch.sum(torch.logical_and(same_node, same_image)) == 0 + allow_loop_transform = PeriodicPropertiesTransform( + cutoff_radius=6.0, backend="ase", allow_self_loops=True + ) + loop_result = allow_loop_transform(packed_data) + # there should be some self-loops in this graph + same_node = loop_result["src_nodes"] == loop_result["dst_nodes"] + same_image = loop_result["images"].sum(dim=-1) == 0 + assert torch.sum(torch.logical_and(same_node, same_image)) > 0 From 9f51c994feac06dc09fba2ed5295612c508eea92 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 22 Nov 2024 09:05:59 -0800 Subject: [PATCH 3/5] test: updated max neighbors check to be greater than or equal --- matsciml/datasets/transforms/tests/test_pbc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matsciml/datasets/transforms/tests/test_pbc.py b/matsciml/datasets/transforms/tests/test_pbc.py index 8c98a4e8..41c53c5f 100644 --- a/matsciml/datasets/transforms/tests/test_pbc.py +++ b/matsciml/datasets/transforms/tests/test_pbc.py @@ -84,7 +84,7 @@ def test_periodic_generation( counts = Counter(src_nodes) for index, count in counts.items(): if not self_loops: - assert count < 10, print(f"Node {index} has too many counts. {src_nodes}") + assert count <= 10, print(f"Node {index} has too many counts. {src_nodes}") def test_self_loop_condition(): From f7d7e4752255f7a121ee6160454496b981977b6c Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 22 Nov 2024 09:24:04 -0800 Subject: [PATCH 4/5] fix: correcting assert error message in test --- matsciml/datasets/transforms/tests/test_pbc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matsciml/datasets/transforms/tests/test_pbc.py b/matsciml/datasets/transforms/tests/test_pbc.py index 41c53c5f..bf8f8ff9 100644 --- a/matsciml/datasets/transforms/tests/test_pbc.py +++ b/matsciml/datasets/transforms/tests/test_pbc.py @@ -84,7 +84,7 @@ def test_periodic_generation( counts = Counter(src_nodes) for index, count in counts.items(): if not self_loops: - assert count <= 10, print(f"Node {index} has too many counts. {src_nodes}") + assert count <= 10, f"Node {index} has too many counts. {src_nodes}" def test_self_loop_condition(): From 5458b128c26379931e8c07b1a4b272c1e9f26279 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 22 Nov 2024 09:25:53 -0800 Subject: [PATCH 5/5] docs: added TODO for pymatgen max neighbors check --- matsciml/datasets/transforms/tests/test_pbc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/matsciml/datasets/transforms/tests/test_pbc.py b/matsciml/datasets/transforms/tests/test_pbc.py index bf8f8ff9..90bbaa3e 100644 --- a/matsciml/datasets/transforms/tests/test_pbc.py +++ b/matsciml/datasets/transforms/tests/test_pbc.py @@ -56,9 +56,7 @@ ) @pytest.mark.parametrize("self_loops", [True, False]) @pytest.mark.parametrize("backend", ["pymatgen", "ase"]) -@pytest.mark.parametrize( - "cutoff_radius", [6.0, 9.0, 15.0] -) # TODO figure out why pmg fails on 3 +@pytest.mark.parametrize("cutoff_radius", [6.0, 9.0, 15.0]) def test_periodic_generation( coords: np.ndarray, cell: np.ndarray, @@ -84,6 +82,8 @@ def test_periodic_generation( counts = Counter(src_nodes) for index, count in counts.items(): if not self_loops: + # TODO pymatgen backend fails this check at cutoff radius = 15 + # and I don't know why assert count <= 10, f"Node {index} has too many counts. {src_nodes}"