diff --git a/matsciml/datasets/utils.py b/matsciml/datasets/utils.py index bbf311f5..83256da8 100644 --- a/matsciml/datasets/utils.py +++ b/matsciml/datasets/utils.py @@ -882,7 +882,9 @@ def _all_sites_have_neighbors(neighbors): "pos": coords, } - cell = rearrange(cell, "i j -> () i j") + # only do the reshape if we are missing a dimension + if cell.ndim == 2: + cell = rearrange(cell, "i j -> () i j") return_dict["offsets"] = einsum(return_dict["images"], cell, "v i, n i j -> v j") src, dst = return_dict["src_nodes"], return_dict["dst_nodes"] return_dict["unit_offsets"] = (