diff --git a/matsciml/datasets/transforms/pbc.py b/matsciml/datasets/transforms/pbc.py index eedaacbf..a363ce91 100644 --- a/matsciml/datasets/transforms/pbc.py +++ b/matsciml/datasets/transforms/pbc.py @@ -187,9 +187,12 @@ def __call__(self, data: DataDict) -> DataDict: raise RuntimeError(f"Requested backend f{self.backend} not available.") data.update(graph_props) if not self.allow_self_loops: - mask = data["src_nodes"] == data["dst_nodes"] + # this looks for src and dst nodes that are the same, i.e. self-loops + loop_mask = data["src_nodes"] == data["dst_nodes"] # only mask out self-loops within the same image - mask &= data["unit_offsets"].sum(dim=-1) == 0 + image_mask = data["images"].sum(dim=-1) == 0 + # we negate the mask because we want to *exclude* what we've found + mask = ~torch.logical_and(loop_mask, image_mask) # apply mask to each of the tensors that depend on edges for key in ["src_nodes", "dst_nodes", "images", "unit_offsets", "offsets"]: data[key] = data[key][mask] diff --git a/matsciml/datasets/transforms/tests/test_pbc.py b/matsciml/datasets/transforms/tests/test_pbc.py index 2878f614..90bbaa3e 100644 --- a/matsciml/datasets/transforms/tests/test_pbc.py +++ b/matsciml/datasets/transforms/tests/test_pbc.py @@ -56,9 +56,7 @@ ) @pytest.mark.parametrize("self_loops", [True, False]) @pytest.mark.parametrize("backend", ["pymatgen", "ase"]) -@pytest.mark.parametrize( - "cutoff_radius", [6.0, 9.0, 15.0] -) # TODO figure out why pmg fails on 3 +@pytest.mark.parametrize("cutoff_radius", [6.0, 9.0, 15.0]) def test_periodic_generation( coords: np.ndarray, cell: np.ndarray, @@ -84,4 +82,31 @@ def test_periodic_generation( counts = Counter(src_nodes) for index, count in counts.items(): if not self_loops: - assert count < 10, print(f"Node {index} has too many counts. {src_nodes}") + # TODO pymatgen backend fails this check at cutoff radius = 15 + # and I don't know why + assert count <= 10, f"Node {index} has too many counts. 
def test_self_loop_condition():
    """Check that ``allow_self_loops`` toggles same-image self-loop edges.

    The module-level ``alumina`` object supplies coordinates and a cell that
    are pushed through ``PeriodicPropertiesTransform`` (``ase`` backend,
    cutoff radius of 6.0) twice. With self-loops disallowed, no edge may
    pair identical source/destination nodes with a zero image sum; with
    self-loops allowed, at least one such edge is expected.
    """

    def count_same_image_loops(graph):
        # an edge is counted when src == dst *and* its image entries sum to 0
        # NOTE(review): offsets that cancel, e.g. (1, -1, 0), would also sum
        # to zero (assuming ``images`` holds signed per-axis offsets) --
        # confirm whether an all-zero check is what is actually intended
        node_match = graph["src_nodes"] == graph["dst_nodes"]
        zero_image = graph["images"].sum(dim=-1) == 0
        return torch.sum(torch.logical_and(node_match, zero_image))

    positions = torch.FloatTensor(alumina.cart_coords)
    lattice = torch.FloatTensor(alumina.lattice.matrix)
    # one placeholder atomic number (a float 1.0) per coordinate row
    species = torch.ones(positions.size(0))
    packed_data = {"pos": positions, "cell": lattice, "atomic_numbers": species}

    # NOTE(review): the same dictionary is handed to both transforms below;
    # presumably the transform updates its input in place, so the second call
    # may see keys written by the first -- verify this is harmless
    strict_transform = PeriodicPropertiesTransform(
        cutoff_radius=6.0, backend="ase", allow_self_loops=False
    )
    strict_graph = strict_transform(packed_data)
    # self-loops were excluded, so nothing should be counted here
    assert count_same_image_loops(strict_graph) == 0

    permissive_transform = PeriodicPropertiesTransform(
        cutoff_radius=6.0, backend="ase", allow_self_loops=True
    )
    permissive_graph = permissive_transform(packed_data)
    # with self-loops kept, the graph is expected to contain at least one
    assert count_same_image_loops(permissive_graph) > 0