diff --git a/ipsuite/geometry/unwrap.py b/ipsuite/geometry/unwrap.py index 8fe80a27..8dfdc410 100644 --- a/ipsuite/geometry/unwrap.py +++ b/ipsuite/geometry/unwrap.py @@ -1,7 +1,9 @@ import ase +import networkx as nx import numpy as np from ase.calculators.singlepoint import SinglePointCalculator +from ipsuite.geometry import graphs from ipsuite.geometry.graphs import edges_from_atoms @@ -24,26 +26,18 @@ def sort_atomic_edges(edges, idx): def displace_neighbors(mol, edges): - for edge in edges: - dist = mol.get_distance(edge[0], edge[1], vector=True) - pdist = mol.get_distance(edge[0], edge[1], True, vector=True) + dist = mol.get_distance(edges[0], edges[1], vector=True) + pdist = mol.get_distance(edges[0], edges[1], True, vector=True) - displacement = dist - pdist - mol.positions[edge[1]] -= displacement + displacement = dist - pdist + mol.positions[edges[1]] -= displacement def unwrap(atoms, edges, idx): - # TODO this should probably be width first, not depth first - current_edges = sort_atomic_edges(edges, idx) - displace_neighbors(atoms, current_edges) - - next_idxs = current_edges[:, 1] - - mask = np.all(edges != idx, axis=1) - filtered_edges = edges[mask] - - for next_idx in next_idxs: - unwrap(atoms, filtered_edges, next_idx) + G = graphs.atoms_to_graph(atoms) + edges = nx.traversal.bfs_edges(G, idx) + for e in edges: + displace_neighbors(atoms, e) def unwrap_system(atoms: ase.Atoms, components: list[np.ndarray]) -> list[ase.Atom]: