-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds test for HilbertSort class (#1126)
* Adds test for HilbertSort class * Basic test that verifies that the distance between the particles within a block are reduced by sorting * Modifies tests/common.py::hilbert_sort to image particles the same way that the C++ does for consistency, rather than just putting things above zero
- Loading branch information
Showing
7 changed files
with
107 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
"""Intended to test the custom_ops.HilbertSort. Note that this does not compare against the | ||
the tests/common.py::hilbert_sort function which produces a different permutation. This simply verifies that | ||
sorting will produce more compact blocks. | ||
""" | ||
import numpy as np | ||
import pytest | ||
from numpy.typing import NDArray | ||
|
||
from timemachine.lib import custom_ops | ||
from timemachine.potentials.jax_utils import delta_r | ||
from timemachine.testsystems.dhfr import setup_dhfr | ||
|
||
pytestmark = [pytest.mark.memcheck] | ||
|
||
|
||
def get_max_block_distances(coords: NDArray, box: NDArray, block_size: int) -> NDArray: | ||
"""Compute the max distance between all particles within each block""" | ||
# Act on a copy to avoid modifying original coordinates | ||
coords = coords.copy() | ||
N = coords.shape[0] | ||
num_blocks = (N + block_size - 1) // block_size | ||
block_distances = [] | ||
for bidx in range(num_blocks): | ||
start_idx = bidx * block_size | ||
end_idx = min((bidx + 1) * block_size, N) | ||
block = coords[start_idx:end_idx] | ||
block_distances.append(np.max(delta_r(block[:, None], block[None, :], box=box))) | ||
return np.array(block_distances) | ||
|
||
|
||
@pytest.mark.parametrize("block_size", [8, 16, 32]) | ||
def test_hilbert_sort_dhfr(block_size): | ||
_, _, coords, box = setup_dhfr() | ||
distances = get_max_block_distances(coords, box, block_size) | ||
unsorted_mean_dist = np.mean(distances) | ||
|
||
sorter = custom_ops.HilbertSort(coords.shape[0]) | ||
perm = sorter.sort(coords, box) | ||
sorted_coords = coords[perm] | ||
sorted_dists = get_max_block_distances(sorted_coords, box, block_size) | ||
sorted_mean_dist = np.mean(sorted_dists) | ||
# On average the max distance within blocks generated by hilbert curve sorting should be smaller than half | ||
# of the unsorted distances | ||
assert sorted_mean_dist < (unsorted_mean_dist * 0.5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters