From c35a5f77b7885a3dabe089774da063324e03bb1b Mon Sep 17 00:00:00 2001 From: Forrest York Date: Thu, 3 Aug 2023 08:13:39 -0600 Subject: [PATCH] Speed up k_find_block_bounds (#1105) * Speed up k_find_block_bounds * TODO: Rewrite with a parallel reduction to avoid running a tight loop of 32 threads --- tests/test_nblist.py | 106 +++++++++----- .../cpp/src/kernels/k_neighborlist.cuh | 137 ++++++++++-------- timemachine/cpp/src/neighborlist.cu | 27 ++-- 3 files changed, 169 insertions(+), 101 deletions(-) diff --git a/tests/test_nblist.py b/tests/test_nblist.py index 4c0744ee1..4745659a6 100644 --- a/tests/test_nblist.py +++ b/tests/test_nblist.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Tuple import numpy as np import pytest @@ -9,6 +9,7 @@ from timemachine.ff import Forcefield from timemachine.lib import custom_ops from timemachine.md.builders import build_water_system +from timemachine.testsystems.dhfr import setup_dhfr from timemachine.testsystems.relative import get_hif2a_ligand_pair_single_topology pytestmark = [pytest.mark.memcheck] @@ -19,47 +20,84 @@ def test_empty_neighborlist(): custom_ops.Neighborlist_f32(0) -def test_block_bounds(): - np.random.seed(2020) - sizes = [128, 156, 298] - max_size = max(sizes) - nblist = custom_ops.Neighborlist_f64(max_size) - for N in sizes: - nblist.resize(N) - block_size = 32 +def reference_block_bounds(coords: NDArray, box: NDArray, block_size: int) -> Tuple[NDArray, NDArray]: + # Make a copy to avoid modifying the coordinates that are used later by the Neighborlist + coords = coords.copy() + N = coords.shape[0] + box_diag = np.diagonal(box) + num_blocks = (N + block_size - 1) // block_size - D = 3 + ref_ctrs = [] + ref_exts = [] + + for bidx in range(num_blocks): + start_idx = bidx * block_size + end_idx = min((bidx + 1) * block_size, N) + block_coords = coords[start_idx:end_idx] + min_coords = block_coords[0] + max_coords = block_coords[0] + for new_coords in block_coords[1:]: + center = 0.5 * 
(max_coords + min_coords) + new_coords -= box_diag * np.floor((new_coords - center) / box_diag + 0.5) + min_coords = np.minimum(min_coords, new_coords) + max_coords = np.maximum(max_coords, new_coords) + + ref_ctrs.append((max_coords + min_coords) / 2) + ref_exts.append((max_coords - min_coords) / 2) + + ref_ctrs = np.array(ref_ctrs) + ref_exts = np.array(ref_exts) + return ref_ctrs, ref_exts + + +@pytest.mark.parametrize("precision,atol,rtol", [(np.float32, 1e-6, 1e-6), (np.float64, 1e-7, 1e-7)]) +@pytest.mark.parametrize("sort", [True, False]) +def test_block_bounds_dhfr(precision, atol, rtol, sort): + _, _, coords, box = setup_dhfr() + + if precision == np.float32: + nblist = custom_ops.Neighborlist_f32(coords.shape[0]) + else: + nblist = custom_ops.Neighborlist_f64(coords.shape[0]) - coords = np.random.randn(N, D) - box_diag = np.random.rand(3) + 1 - box = np.eye(3) * box_diag - num_blocks = (N + block_size - 1) // block_size + if sort: + perm = hilbert_sort(coords + np.argmin(coords), coords.shape[1]) + coords = coords[perm] + + block_size = 32 + ref_ctrs, ref_exts = reference_block_bounds(coords, box, block_size) + + test_ctrs, test_exts = nblist.compute_block_bounds(coords, box, block_size) + + for i, (ref_ctr, test_ctr) in enumerate(zip(ref_ctrs, test_ctrs)): + np.testing.assert_allclose(ref_ctr, test_ctr, atol=atol, rtol=rtol, err_msg=f"Center {i} has mismatch") + for i, (ref_ext, test_ext) in enumerate(zip(ref_exts, test_exts)): + np.testing.assert_allclose(ref_ext, test_ext, atol=atol, rtol=rtol, err_msg=f"Extent {i} has mismatch") + + +@pytest.mark.parametrize("precision,atol,rtol", [(np.float32, 1e-6, 1e-6), (np.float64, 1e-7, 1e-7)]) +@pytest.mark.parametrize("size", [12, 128, 156, 298]) +def test_block_bounds(precision, atol, rtol, size): + np.random.seed(2020) + block_size = 32 + D = 3 - ref_ctrs = [] - ref_exts = [] + if precision == np.float32: + nblist = custom_ops.Neighborlist_f32(size) + else: + nblist = custom_ops.Neighborlist_f64(size) - for 
bidx in range(num_blocks): - start_idx = bidx * block_size - end_idx = min((bidx + 1) * block_size, N) - block_coords = coords[start_idx:end_idx] - min_coords = block_coords[0] - max_coords = block_coords[0] - for new_coords in block_coords[1:]: - center = 0.5 * (max_coords + min_coords) - new_coords -= box_diag * np.floor((new_coords - center) / box_diag + 0.5) - min_coords = np.minimum(min_coords, new_coords) - max_coords = np.maximum(max_coords, new_coords) + coords = np.random.randn(size, D) - ref_ctrs.append((max_coords + min_coords) / 2) - ref_exts.append((max_coords - min_coords) / 2) + box_diag = np.random.rand(3) + 1 + box = np.eye(3) * box_diag - ref_ctrs = np.array(ref_ctrs) - ref_exts = np.array(ref_exts) + ref_ctrs, ref_exts = reference_block_bounds(coords, box, block_size) - test_ctrs, test_exts = nblist.compute_block_bounds(coords, box, block_size) + test_ctrs, test_exts = nblist.compute_block_bounds(coords, box, block_size) - np.testing.assert_almost_equal(ref_ctrs, test_ctrs) - np.testing.assert_almost_equal(ref_exts, test_exts) + np.testing.assert_allclose(ref_ctrs, test_ctrs, atol=atol, rtol=rtol) + np.testing.assert_allclose(ref_exts, test_exts, atol=atol, rtol=rtol) def get_water_coords(D, sort=False): diff --git a/timemachine/cpp/src/kernels/k_neighborlist.cuh b/timemachine/cpp/src/kernels/k_neighborlist.cuh index 70b9982a0..29096060c 100644 --- a/timemachine/cpp/src/kernels/k_neighborlist.cuh +++ b/timemachine/cpp/src/kernels/k_neighborlist.cuh @@ -8,7 +8,6 @@ static const int TILE_SIZE = WARP_SIZE; template void __global__ k_find_block_bounds( - const int N, // Number of atoms const int num_tiles, // Number of tiles const int num_indices, // Number of indices const unsigned int *__restrict__ row_idxs, // [num_indices] @@ -22,71 +21,95 @@ void __global__ k_find_block_bounds( // Algorithm taken from https://github.com/openmm/openmm/blob/master/platforms/cuda/src/kernels/findInteractingBlocks.cu#L7 // Computes smaller bounding boxes than 
simpler form by accounting for periodic box conditions - // each thread processes one tile - const int tile_idx = blockIdx.x * blockDim.x + threadIdx.x; + // each warp processes one tile + int tile_idx = (blockIdx.x * blockDim.x + threadIdx.x) / WARP_SIZE; + if (tile_idx >= num_tiles) { return; } - const int row_idx = tile_idx * TILE_SIZE; - if (row_idx >= num_indices) { - return; + + RealType pos_x; + RealType pos_y; + RealType pos_z; + + RealType min_pos_x; + RealType min_pos_y; + RealType min_pos_z; + + RealType max_pos_x; + RealType max_pos_y; + RealType max_pos_z; + + RealType imaged_pos; + + const RealType box_x = box[0 * 3 + 0]; + const RealType box_y = box[1 * 3 + 1]; + const RealType box_z = box[2 * 3 + 2]; + + const RealType inv_bx = 1 / box_x; + const RealType inv_by = 1 / box_y; + const RealType inv_bz = 1 / box_z; + + int row_idx = tile_idx * TILE_SIZE + (threadIdx.x % WARP_SIZE); + // Reset the ixn count + if (row_idx == 0) { + ixn_count[0] = 0; } - int atom_idx = row_idxs[row_idx]; - - const RealType bx = box[0 * 3 + 0]; - const RealType by = box[1 * 3 + 1]; - const RealType bz = box[2 * 3 + 2]; - - const RealType inv_bx = 1 / bx; - const RealType inv_by = 1 / by; - const RealType inv_bz = 1 / bz; - RealType pos_x = coords[atom_idx * 3 + 0]; - RealType pos_y = coords[atom_idx * 3 + 1]; - RealType pos_z = coords[atom_idx * 3 + 2]; - - RealType minPos_x = pos_x; - RealType minPos_y = pos_y; - RealType minPos_z = pos_z; - - RealType maxPos_x = pos_x; - RealType maxPos_y = pos_y; - RealType maxPos_z = pos_z; - - const int last = min(row_idx + TILE_SIZE, num_indices); - for (int i = row_idx + 1; i < last; i++) { - atom_idx = row_idxs[i]; + + if (row_idx < num_indices) { + int atom_idx = row_idxs[row_idx]; + pos_x = coords[atom_idx * 3 + 0]; pos_y = coords[atom_idx * 3 + 1]; pos_z = coords[atom_idx * 3 + 2]; - // Build up center over time, and recenter before computing - // min and max, to reduce overall size of box thanks to accounting - // for periodic 
boundary conditions - RealType center_x = static_cast(0.5) * (maxPos_x + minPos_x); - RealType center_y = static_cast(0.5) * (maxPos_y + minPos_y); - RealType center_z = static_cast(0.5) * (maxPos_z + minPos_z); - pos_x -= bx * nearbyint((pos_x - center_x) * inv_bx); - pos_y -= by * nearbyint((pos_y - center_y) * inv_by); - pos_z -= bz * nearbyint((pos_z - center_z) * inv_bz); - minPos_x = min(minPos_x, pos_x); - minPos_y = min(minPos_y, pos_y); - minPos_z = min(minPos_z, pos_z); - - maxPos_x = max(maxPos_x, pos_x); - maxPos_y = max(maxPos_y, pos_y); - maxPos_z = max(maxPos_z, pos_z); - } - block_bounds_ctr[tile_idx * 3 + 0] = static_cast(0.5) * (maxPos_x + minPos_x); - block_bounds_ctr[tile_idx * 3 + 1] = static_cast(0.5) * (maxPos_y + minPos_y); - block_bounds_ctr[tile_idx * 3 + 2] = static_cast(0.5) * (maxPos_z + minPos_z); + min_pos_x = pos_x; + min_pos_y = pos_y; + min_pos_z = pos_z; - block_bounds_ext[tile_idx * 3 + 0] = static_cast(0.5) * (maxPos_x - minPos_x); - block_bounds_ext[tile_idx * 3 + 1] = static_cast(0.5) * (maxPos_y - minPos_y); - block_bounds_ext[tile_idx * 3 + 2] = static_cast(0.5) * (maxPos_z - minPos_z); + max_pos_x = min_pos_x; + max_pos_y = min_pos_y; + max_pos_z = min_pos_z; + } - // Reset the ixn count - if (tile_idx == 0) { - ixn_count[0] = 0; + // Only the first thread in each warp computes the min/max of the bounding box + bool compute_bounds = threadIdx.x % WARP_SIZE == 0; + + // Build up center over time, and recenter before computing + // min and max, to reduce overall size of box thanks to accounting + // for periodic boundary conditions + const int src_lane = (threadIdx.x + 1) % WARP_SIZE; + for (int i = 0; i < WARP_SIZE; i++) { + row_idx = __shfl_sync(0xffffffff, row_idx, src_lane); + pos_x = __shfl_sync(0xffffffff, pos_x, src_lane); + pos_y = __shfl_sync(0xffffffff, pos_y, src_lane); + pos_z = __shfl_sync(0xffffffff, pos_z, src_lane); + // Only evaluate for the first thread and when the row idx is valid + if (compute_bounds && 
row_idx < num_indices) { + imaged_pos = + pos_x - box_x * nearbyint((pos_x - static_cast(0.5) * (max_pos_x + min_pos_x)) * inv_bx); + min_pos_x = min(min_pos_x, imaged_pos); + max_pos_x = max(max_pos_x, imaged_pos); + + imaged_pos = + pos_y - box_y * nearbyint((pos_y - static_cast(0.5) * (max_pos_y + min_pos_y)) * inv_by); + min_pos_y = min(min_pos_y, imaged_pos); + max_pos_y = max(max_pos_y, imaged_pos); + + imaged_pos = + pos_z - box_z * nearbyint((pos_z - static_cast(0.5) * (max_pos_z + min_pos_z)) * inv_bz); + min_pos_z = min(min_pos_z, imaged_pos); + max_pos_z = max(max_pos_z, imaged_pos); + } + } + if (threadIdx.x % WARP_SIZE == 0) { + block_bounds_ctr[tile_idx * 3 + 0] = static_cast(0.5) * (max_pos_x + min_pos_x); + block_bounds_ctr[tile_idx * 3 + 1] = static_cast(0.5) * (max_pos_y + min_pos_y); + block_bounds_ctr[tile_idx * 3 + 2] = static_cast(0.5) * (max_pos_z + min_pos_z); + + block_bounds_ext[tile_idx * 3 + 0] = static_cast(0.5) * (max_pos_x - min_pos_x); + block_bounds_ext[tile_idx * 3 + 1] = static_cast(0.5) * (max_pos_y - min_pos_y); + block_bounds_ext[tile_idx * 3 + 2] = static_cast(0.5) * (max_pos_z - min_pos_z); } } @@ -181,7 +204,7 @@ void __global__ k_find_blocks_with_ixns( const RealType *__restrict__ column_bb_ctr, // [N * 3] block centers const RealType *__restrict__ column_bb_ext, // [N * 3] block extents const RealType *__restrict__ row_bb_ctr, // [N * 3] block centers - const RealType *__restrict__ row_bb_ext, // [N * 3] block extants + const RealType *__restrict__ row_bb_ext, // [N * 3] block extents const double *__restrict__ coords, // [N * 3] const double *__restrict__ box, unsigned int *__restrict__ interactionCount, // number of tiles that have interactions diff --git a/timemachine/cpp/src/neighborlist.cu b/timemachine/cpp/src/neighborlist.cu index cb983eb12..021d2aa7f 100644 --- a/timemachine/cpp/src/neighborlist.cu +++ b/timemachine/cpp/src/neighborlist.cu @@ -64,6 +64,9 @@ void Neighborlist::compute_block_bounds_host( DeviceBuffer 
d_coords(N * D); DeviceBuffer d_box(D * D); + std::vector h_block_bounds_centers(this->num_column_blocks() * 3); + std::vector h_block_bounds_extents(this->num_column_blocks() * 3); + d_coords.copy_from(h_coords); d_box.copy_from(h_box); @@ -71,15 +74,23 @@ void Neighborlist::compute_block_bounds_host( gpuErrchk(cudaDeviceSynchronize()); gpuErrchk(cudaMemcpy( - h_bb_ctrs, + &h_block_bounds_centers[0], d_column_block_bounds_ctr_, this->num_column_blocks() * 3 * sizeof(*d_column_block_bounds_ctr_), cudaMemcpyDeviceToHost)); gpuErrchk(cudaMemcpy( - h_bb_exts, + &h_block_bounds_extents[0], d_column_block_bounds_ext_, this->num_column_blocks() * 3 * sizeof(*d_column_block_bounds_ext_), cudaMemcpyDeviceToHost)); + + // Handle the float -> double conversion element-wise; a direct byte copy between buffers of different precision results in garbage values + for (size_t i = 0; i < h_block_bounds_centers.size(); i++) { + h_bb_ctrs[i] = h_block_bounds_centers[i]; + } + for (size_t i = 0; i < h_block_bounds_extents.size(); i++) { + h_bb_exts[i] = h_block_bounds_extents[i]; + } } template @@ -199,11 +210,9 @@ void Neighborlist::compute_block_bounds_device( } const int tpb = DEFAULT_THREADS_PER_BLOCK; - const int column_blocks = this->num_column_blocks(); // total number of blocks we need to process - k_find_block_bounds<<>>( - N, - column_blocks, + k_find_block_bounds<<>>( + this->num_column_blocks(), NC_, d_column_idxs_, d_coords, @@ -215,10 +224,8 @@ void Neighborlist::compute_block_bounds_device( // In the case of upper triangle of the matrix, the column and row indices are the same, so only compute block ixns for both // when they are different if (!this->compute_upper_triangular()) { - const int row_blocks = this->num_row_blocks(); - k_find_block_bounds<<>>( - N, - row_blocks, + k_find_block_bounds<<>>( + this->num_row_blocks(), NR_, d_row_idxs_, d_coords,