Skip to content

Commit

Permalink
Adds test for HilbertSort class (#1126)
Browse files Browse the repository at this point in the history
* Adds test for HilbertSort class

* Basic test that verifies that the distance between the particles
within a block are reduced by sorting
* Modifies tests/common.py::hilbert_sort to image particles the same way
that the C++ does for consistency, rather than just putting things above
zero
  • Loading branch information
badisa authored Aug 17, 2023
1 parent 1b628a3 commit b853c2d
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 9 deletions.
9 changes: 5 additions & 4 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,12 @@ def prepare_nb_system(
return params, potential


def hilbert_sort(conf, D):
hc = HilbertCurve(HILBERT_GRID_DIM, D)
def hilbert_sort(conf, box):
hc = HilbertCurve(HILBERT_GRID_DIM, conf.shape[1])

# hc assumes non-negative coordinates
conf = np.array(conf - np.min(conf))
box_diag = np.diagonal(box)
# hc assumes non-negative coordinates, re-image coordinates into home box
conf = conf - box_diag * np.floor(conf / box_diag)
assert (conf >= 0.0).all()

int_confs = (conf * 1000).astype(np.int64)
Expand Down
44 changes: 44 additions & 0 deletions tests/test_hilbert_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Intended to test the custom_ops.HilbertSort. Note that this does not compare against the
the tests/common.py::hilbert_sort function which produces a different permutation. This simply verifies that
sorting will produce more compact blocks.
"""
import numpy as np
import pytest
from numpy.typing import NDArray

from timemachine.lib import custom_ops
from timemachine.potentials.jax_utils import delta_r
from timemachine.testsystems.dhfr import setup_dhfr

pytestmark = [pytest.mark.memcheck]


def get_max_block_distances(coords: NDArray, box: NDArray, block_size: int) -> NDArray:
"""Compute the max distance between all particles within each block"""
# Act on a copy to avoid modifying original coordinates
coords = coords.copy()
N = coords.shape[0]
num_blocks = (N + block_size - 1) // block_size
block_distances = []
for bidx in range(num_blocks):
start_idx = bidx * block_size
end_idx = min((bidx + 1) * block_size, N)
block = coords[start_idx:end_idx]
block_distances.append(np.max(delta_r(block[:, None], block[None, :], box=box)))
return np.array(block_distances)


@pytest.mark.parametrize("block_size", [8, 16, 32])
def test_hilbert_sort_dhfr(block_size):
_, _, coords, box = setup_dhfr()
distances = get_max_block_distances(coords, box, block_size)
unsorted_mean_dist = np.mean(distances)

sorter = custom_ops.HilbertSort(coords.shape[0])
perm = sorter.sort(coords, box)
sorted_coords = coords[perm]
sorted_dists = get_max_block_distances(sorted_coords, box, block_size)
sorted_mean_dist = np.mean(sorted_dists)
# On average the max distance within blocks generated by hilbert curve sorting should be smaller than half
# of the unsorted distances
assert sorted_mean_dist < (unsorted_mean_dist * 0.5)
8 changes: 3 additions & 5 deletions tests/test_nblist.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_block_bounds_dhfr(precision, atol, rtol, sort):
nblist = custom_ops.Neighborlist_f64(coords.shape[0])

if sort:
perm = hilbert_sort(coords, coords.shape[1])
perm = hilbert_sort(coords, box)
coords = coords[perm]

block_size = 32
Expand Down Expand Up @@ -252,12 +252,11 @@ def test_neighborlist():
diag = np.amax(coords, axis=0) - np.amin(coords, axis=0) + padding
box = np.eye(3) * diag

D = 3
cutoff = 1.0

sort = True
if sort:
perm = hilbert_sort(coords, D)
perm = hilbert_sort(coords, box)
coords = coords[perm]

ref_ixn_list = build_reference_ixn_list(coords, box, cutoff)
Expand Down Expand Up @@ -325,7 +324,6 @@ def test_neighborlist_on_subset_of_system():
coords = np.concatenate([host_coords, ligand_coords])
N = coords.shape[0]

D = 3
cutoff = 1.0
padding = 0.1

Expand All @@ -336,7 +334,7 @@ def test_neighborlist_on_subset_of_system():
atom_idxs = np.arange(num_host_atoms, N, dtype=np.uint32)
sort = True
if sort:
perm = hilbert_sort(coords, D)
perm = hilbert_sort(coords, box)
coords = coords[perm]
# Get the new idxs of the ligand atoms
atom_idxs = np.isin(perm, atom_idxs).nonzero()[0]
Expand Down
22 changes: 22 additions & 0 deletions timemachine/cpp/src/hilbert_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,26 @@ void HilbertSort::sort_device(
gpuErrchk(cudaPeekAtLastError());
}

std::vector<unsigned int> HilbertSort::sort_host(const int N, const double *h_coords, const double *h_box) {

std::vector<unsigned int> h_atom_idxs(N);
std::iota(h_atom_idxs.begin(), h_atom_idxs.end(), 0);

DeviceBuffer<double> d_coords(N * 3);
DeviceBuffer<double> d_box(3 * 3);
DeviceBuffer<unsigned int> d_atom_idxs(N);
DeviceBuffer<unsigned int> d_perm(N);

d_coords.copy_from(h_coords);
d_box.copy_from(h_box);
d_atom_idxs.copy_from(&h_atom_idxs[0]);

cudaStream_t stream = static_cast<cudaStream_t>(0);
this->sort_device(N, d_atom_idxs.data, d_coords.data, d_box.data, d_perm.data, stream);
gpuErrchk(cudaStreamSynchronize(stream));

d_perm.copy_to(&h_atom_idxs[0]);
return h_atom_idxs;
}

} // namespace timemachine
4 changes: 4 additions & 0 deletions timemachine/cpp/src/hilbert_sort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "device_buffer.hpp"
#include "math_utils.cuh"
#include <memory>
#include <numeric>
#include <vector>

namespace timemachine {

Expand Down Expand Up @@ -32,6 +34,8 @@ class HilbertSort {
const double *d_box,
unsigned int *d_output_perm,
cudaStream_t stream);

std::vector<unsigned int> sort_host(const int N, const double *h_coords, const double *h_box);
};

} // namespace timemachine
25 changes: 25 additions & 0 deletions timemachine/cpp/src/wrap_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,29 @@ template <typename RealType> void declare_neighborlist(py::module &m, const char
.def("resize", &timemachine::Neighborlist<RealType>::resize, py::arg("size"));
}

void declare_hilbert_sort(py::module &m) {

using Class = timemachine::HilbertSort;
std::string pyclass_name = std::string("HilbertSort");
py::class_<Class, std::shared_ptr<Class>>(m, pyclass_name.c_str(), py::buffer_protocol(), py::dynamic_attr())
.def(py::init([](const int N) { return new timemachine::HilbertSort(N); }), py::arg("size"))
.def(
"sort",
[](timemachine::HilbertSort &sorter,
const py::array_t<double, py::array::c_style> &coords,
const py::array_t<double, py::array::c_style> &box) -> const py::array_t<uint32_t, py::array::c_style> {
const int N = coords.shape()[0];
verify_coords_and_box(coords, box);

std::vector<unsigned int> sort_perm = sorter.sort_host(N, coords.data(), box.data());
py::array_t<uint32_t, py::array::c_style> output_perm(sort_perm.size());
std::memcpy(output_perm.mutable_data(), sort_perm.data(), sort_perm.size() * sizeof(unsigned int));
return output_perm;
},
py::arg("coords"),
py::arg("box"));
}

void declare_context(py::module &m) {

using Class = timemachine::Context;
Expand Down Expand Up @@ -1313,6 +1336,8 @@ PYBIND11_MODULE(custom_ops, m) {
declare_neighborlist<double>(m, "f64");
declare_neighborlist<float>(m, "f32");

declare_hilbert_sort(m);

declare_centroid_restraint<double>(m, "f64");
declare_centroid_restraint<float>(m, "f32");

Expand Down
4 changes: 4 additions & 0 deletions timemachine/lib/custom_ops.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ class HarmonicBond_f32(Potential):
class HarmonicBond_f64(Potential):
def __init__(self, bond_idxs: numpy.typing.NDArray[numpy.int32]) -> None: ...

class HilbertSort:
def __init__(self, size: int) -> None: ...
def sort(self, coords: numpy.typing.NDArray[numpy.float64], box: numpy.typing.NDArray[numpy.float64]) -> numpy.typing.NDArray[numpy.uint32]: ...

class Integrator:
def __init__(self, *args, **kwargs) -> None: ...

Expand Down

0 comments on commit b853c2d

Please sign in to comment.