Skip to content

Commit

Permalink
Speed up k_find_block_bounds (#1105)
Browse files Browse the repository at this point in the history
* Speed up k_find_block_bounds

* TODO: Rewrite with a parallel reduction to avoid running a tight loop of 32 threads
  • Loading branch information
badisa authored Aug 3, 2023
1 parent a61612e commit c35a5f7
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 101 deletions.
106 changes: 72 additions & 34 deletions tests/test_nblist.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Tuple

import numpy as np
import pytest
Expand All @@ -9,6 +9,7 @@
from timemachine.ff import Forcefield
from timemachine.lib import custom_ops
from timemachine.md.builders import build_water_system
from timemachine.testsystems.dhfr import setup_dhfr
from timemachine.testsystems.relative import get_hif2a_ligand_pair_single_topology

pytestmark = [pytest.mark.memcheck]
Expand All @@ -19,47 +20,84 @@ def test_empty_neighborlist():
custom_ops.Neighborlist_f32(0)


# NOTE(review): a parametrized function with the same name, test_block_bounds(precision, atol, rtol, size),
# is defined further down and would shadow this one; this zero-argument variant looks like the
# pre-change revision kept by the diff view - confirm before relying on it.
def test_block_bounds():
np.random.seed(2020)
sizes = [128, 156, 298]
max_size = max(sizes)
# One float64 neighborlist is allocated for the largest size and resized for each N below
nblist = custom_ops.Neighborlist_f64(max_size)
for N in sizes:
nblist.resize(N)
block_size = 32
def reference_block_bounds(coords: NDArray, box: NDArray, block_size: int) -> Tuple[NDArray, NDArray]:
    """Compute reference bounding-box centers and half-extents for consecutive blocks of atoms.

    Atoms are grouped, in order, into blocks of ``block_size`` (the last block may be shorter).
    Within a block, every atom after the first is re-imaged along the box diagonal so that it
    lies within half a box length of the running box center before the running min/max are
    updated. This keeps the bounding box small under periodic boundary conditions.

    Args:
        coords: Atom coordinates of shape (N, D). The input array is never modified.
        box: Periodic box matrix; only its diagonal is used.
        block_size: Number of atoms per block.

    Returns:
        A ``(centers, extents)`` pair, each of shape (num_blocks, D) with
        ``num_blocks = ceil(N / block_size)``. Both have shape (0, D) when N is 0.
    """
    # The re-imaging below is done in place, so work on a private copy to avoid modifying the
    # coordinates that are later passed to the Neighborlist.
    coords = np.array(coords, copy=True)
    if not np.issubdtype(coords.dtype, np.floating):
        # In-place float arithmetic on an integer array cannot be cast back and would raise
        coords = coords.astype(np.float64)

    num_atoms = coords.shape[0]
    box_diag = np.diagonal(box)
    num_blocks = (num_atoms + block_size - 1) // block_size

    if num_blocks == 0:
        # np.array([]) would be 1-D; keep the (num_blocks, D) layout callers compare against
        empty = np.zeros((0,) + coords.shape[1:], dtype=coords.dtype)
        return empty, empty.copy()

    ref_ctrs = []
    ref_exts = []

    for bidx in range(num_blocks):
        start_idx = bidx * block_size
        end_idx = min(start_idx + block_size, num_atoms)
        block_coords = coords[start_idx:end_idx]

        # The first atom seeds the box; np.minimum/np.maximum return new arrays, so these
        # names never alias a row that is re-imaged later.
        min_coords = block_coords[0]
        max_coords = block_coords[0]
        for new_coords in block_coords[1:]:
            # Shift the atom by whole box lengths towards the current center of the box
            center = 0.5 * (max_coords + min_coords)
            new_coords -= box_diag * np.floor((new_coords - center) / box_diag + 0.5)
            min_coords = np.minimum(min_coords, new_coords)
            max_coords = np.maximum(max_coords, new_coords)

        ref_ctrs.append((max_coords + min_coords) / 2)
        ref_exts.append((max_coords - min_coords) / 2)

    return np.array(ref_ctrs), np.array(ref_exts)


# Compares nblist.compute_block_bounds against reference_block_bounds on the coordinates and box
# returned by setup_dhfr(), for float32 and float64 neighborlists, with and without reordering
# the atoms through hilbert_sort.
@pytest.mark.parametrize("precision,atol,rtol", [(np.float32, 1e-6, 1e-6), (np.float64, 1e-7, 1e-7)])
@pytest.mark.parametrize("sort", [True, False])
def test_block_bounds_dhfr(precision, atol, rtol, sort):
_, _, coords, box = setup_dhfr()

# Pick the neighborlist implementation matching the requested precision
if precision == np.float32:
nblist = custom_ops.Neighborlist_f32(coords.shape[0])
else:
nblist = custom_ops.Neighborlist_f64(coords.shape[0])

# NOTE(review): N and D are not defined in this function and block_size is only assigned further
# down; the next four lines look like leftovers of the previous test revision interleaved by the
# diff view - confirm and drop them.
coords = np.random.randn(N, D)
box_diag = np.random.rand(3) + 1
box = np.eye(3) * box_diag
num_blocks = (N + block_size - 1) // block_size
if sort:
# NOTE(review): np.argmin returns the flat index of the smallest element, not its value; adding
# an index to the coordinates looks unintended (np.min may have been meant) - confirm what
# hilbert_sort expects.
perm = hilbert_sort(coords + np.argmin(coords), coords.shape[1])
coords = coords[perm]

block_size = 32
ref_ctrs, ref_exts = reference_block_bounds(coords, box, block_size)

test_ctrs, test_exts = nblist.compute_block_bounds(coords, box, block_size)

# Compare block by block so that a failure message names the offending block index
for i, (ref_ctr, test_ctr) in enumerate(zip(ref_ctrs, test_ctrs)):
np.testing.assert_allclose(ref_ctr, test_ctr, atol=atol, rtol=rtol, err_msg=f"Center {i} has mismatch")
for i, (ref_ext, test_ext) in enumerate(zip(ref_exts, test_exts)):
np.testing.assert_allclose(ref_ext, test_ext, atol=atol, rtol=rtol, err_msg=f"Extent {i} has mismatch")


# Compares nblist.compute_block_bounds against reference_block_bounds on random coordinates in a
# random diagonal box, for float32 and float64 neighborlists. Size 12 is smaller than block_size,
# so it exercises a single partially filled block.
# NOTE(review): this view interleaves the previous inline reference implementation with the new
# body: the "for bidx" loop uses num_blocks, N, coords and box_diag before they are assigned here,
# compute_block_bounds is called twice, and both assert_almost_equal and assert_allclose checks
# appear. Treat the loop, the duplicate call and the assert_almost_equal pair as removed lines -
# confirm against the repository.
@pytest.mark.parametrize("precision,atol,rtol", [(np.float32, 1e-6, 1e-6), (np.float64, 1e-7, 1e-7)])
@pytest.mark.parametrize("size", [12, 128, 156, 298])
def test_block_bounds(precision, atol, rtol, size):
# Fixed seed so the random coordinates and box are reproducible
np.random.seed(2020)
block_size = 32
D = 3

ref_ctrs = []
ref_exts = []
if precision == np.float32:
nblist = custom_ops.Neighborlist_f32(size)
else:
nblist = custom_ops.Neighborlist_f64(size)

for bidx in range(num_blocks):
start_idx = bidx * block_size
end_idx = min((bidx + 1) * block_size, N)
block_coords = coords[start_idx:end_idx]
min_coords = block_coords[0]
max_coords = block_coords[0]
for new_coords in block_coords[1:]:
center = 0.5 * (max_coords + min_coords)
new_coords -= box_diag * np.floor((new_coords - center) / box_diag + 0.5)
min_coords = np.minimum(min_coords, new_coords)
max_coords = np.maximum(max_coords, new_coords)
coords = np.random.randn(size, D)

ref_ctrs.append((max_coords + min_coords) / 2)
ref_exts.append((max_coords - min_coords) / 2)
# Diagonal box whose edge lengths are drawn from [1, 2)
box_diag = np.random.rand(3) + 1
box = np.eye(3) * box_diag

ref_ctrs = np.array(ref_ctrs)
ref_exts = np.array(ref_exts)
ref_ctrs, ref_exts = reference_block_bounds(coords, box, block_size)

test_ctrs, test_exts = nblist.compute_block_bounds(coords, box, block_size)
test_ctrs, test_exts = nblist.compute_block_bounds(coords, box, block_size)

np.testing.assert_almost_equal(ref_ctrs, test_ctrs)
np.testing.assert_almost_equal(ref_exts, test_exts)
np.testing.assert_allclose(ref_ctrs, test_ctrs, atol=atol, rtol=rtol)
np.testing.assert_allclose(ref_exts, test_exts, atol=atol, rtol=rtol)


def get_water_coords(D, sort=False):
Expand Down
137 changes: 80 additions & 57 deletions timemachine/cpp/src/kernels/k_neighborlist.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ static const int TILE_SIZE = WARP_SIZE;

// For every tile of TILE_SIZE consecutive entries of row_idxs, computes the center and half-extent of a
// bounding box around the referenced atom coordinates and writes them to block_bounds_ctr and
// block_bounds_ext; ixn_count[0] is reset to zero as a side effect.
// NOTE(review): this view interleaves the previous one-thread-per-tile body (minPos_*/maxPos_* names)
// with the new one-warp-per-tile body (min_pos_*/max_pos_* names), and part of the parameter list is
// hidden behind the hunk header below - read it as a diff, not as compilable code.
template <typename RealType>
void __global__ k_find_block_bounds(
const int N, // Number of atoms
const int num_tiles, // Number of tiles
const int num_indices, // Number of indices
const unsigned int *__restrict__ row_idxs, // [num_indices]
Expand All @@ -22,71 +21,95 @@ void __global__ k_find_block_bounds(
// Algorithm taken from https://github.com/openmm/openmm/blob/master/platforms/cuda/src/kernels/findInteractingBlocks.cu#L7
// Computes smaller bounding boxes than simpler form by accounting for periodic box conditions

// each thread processes one tile
const int tile_idx = blockIdx.x * blockDim.x + threadIdx.x;
// each warp processes one tile
int tile_idx = (blockIdx.x * blockDim.x + threadIdx.x) / WARP_SIZE;

if (tile_idx >= num_tiles) {
return;
}
const int row_idx = tile_idx * TILE_SIZE;
if (row_idx >= num_indices) {
return;

RealType pos_x;
RealType pos_y;
RealType pos_z;

RealType min_pos_x;
RealType min_pos_y;
RealType min_pos_z;

RealType max_pos_x;
RealType max_pos_y;
RealType max_pos_z;

RealType imaged_pos;

const RealType box_x = box[0 * 3 + 0];
const RealType box_y = box[1 * 3 + 1];
const RealType box_z = box[2 * 3 + 2];

const RealType inv_bx = 1 / box_x;
const RealType inv_by = 1 / box_y;
const RealType inv_bz = 1 / box_z;

// Each lane of the warp is assigned one entry of the tile: entry tile_idx * TILE_SIZE + lane
int row_idx = tile_idx * TILE_SIZE + (threadIdx.x % WARP_SIZE);
// Reset the ixn count
if (row_idx == 0) {
ixn_count[0] = 0;
}
int atom_idx = row_idxs[row_idx];

const RealType bx = box[0 * 3 + 0];
const RealType by = box[1 * 3 + 1];
const RealType bz = box[2 * 3 + 2];

const RealType inv_bx = 1 / bx;
const RealType inv_by = 1 / by;
const RealType inv_bz = 1 / bz;
RealType pos_x = coords[atom_idx * 3 + 0];
RealType pos_y = coords[atom_idx * 3 + 1];
RealType pos_z = coords[atom_idx * 3 + 2];

RealType minPos_x = pos_x;
RealType minPos_y = pos_y;
RealType minPos_z = pos_z;

RealType maxPos_x = pos_x;
RealType maxPos_y = pos_y;
RealType maxPos_z = pos_z;

const int last = min(row_idx + TILE_SIZE, num_indices);
for (int i = row_idx + 1; i < last; i++) {
atom_idx = row_idxs[i];

if (row_idx < num_indices) {
int atom_idx = row_idxs[row_idx];

pos_x = coords[atom_idx * 3 + 0];
pos_y = coords[atom_idx * 3 + 1];
pos_z = coords[atom_idx * 3 + 2];
// Build up center over time, and recenter before computing
// min and max, to reduce overall size of box thanks to accounting
// for periodic boundary conditions
RealType center_x = static_cast<RealType>(0.5) * (maxPos_x + minPos_x);
RealType center_y = static_cast<RealType>(0.5) * (maxPos_y + minPos_y);
RealType center_z = static_cast<RealType>(0.5) * (maxPos_z + minPos_z);
pos_x -= bx * nearbyint((pos_x - center_x) * inv_bx);
pos_y -= by * nearbyint((pos_y - center_y) * inv_by);
pos_z -= bz * nearbyint((pos_z - center_z) * inv_bz);
minPos_x = min(minPos_x, pos_x);
minPos_y = min(minPos_y, pos_y);
minPos_z = min(minPos_z, pos_z);

maxPos_x = max(maxPos_x, pos_x);
maxPos_y = max(maxPos_y, pos_y);
maxPos_z = max(maxPos_z, pos_z);
}

block_bounds_ctr[tile_idx * 3 + 0] = static_cast<RealType>(0.5) * (maxPos_x + minPos_x);
block_bounds_ctr[tile_idx * 3 + 1] = static_cast<RealType>(0.5) * (maxPos_y + minPos_y);
block_bounds_ctr[tile_idx * 3 + 2] = static_cast<RealType>(0.5) * (maxPos_z + minPos_z);
min_pos_x = pos_x;
min_pos_y = pos_y;
min_pos_z = pos_z;

block_bounds_ext[tile_idx * 3 + 0] = static_cast<RealType>(0.5) * (maxPos_x - minPos_x);
block_bounds_ext[tile_idx * 3 + 1] = static_cast<RealType>(0.5) * (maxPos_y - minPos_y);
block_bounds_ext[tile_idx * 3 + 2] = static_cast<RealType>(0.5) * (maxPos_z - minPos_z);
max_pos_x = min_pos_x;
max_pos_y = min_pos_y;
max_pos_z = min_pos_z;
}

// Reset the ixn count
if (tile_idx == 0) {
ixn_count[0] = 0;
// Only the first thread in each warp computes the min/max of the bounding box
bool compute_bounds = threadIdx.x % WARP_SIZE == 0;

// Build up center over time, and recenter before computing
// min and max, to reduce overall size of box thanks to accounting
// for periodic boundary conditions
// NOTE(review): `%` binds tighter than `+`, so this evaluates to threadIdx.x + (1 % WARP_SIZE) rather
// than (threadIdx.x + 1) % WARP_SIZE. The shuffle presumably still behaves as intended if __shfl_sync
// reduces an out-of-range source lane modulo the warp width - confirm against the CUDA documentation
// and add the parentheses.
const int src_lane = threadIdx.x + 1 % WARP_SIZE;
// Every iteration each lane replaces its row_idx/pos_* with the values held by src_lane, so the first
// lane is handed a different lane's entry on each pass
for (int i = 0; i < WARP_SIZE; i++) {
row_idx = __shfl_sync(0xffffffff, row_idx, src_lane);
pos_x = __shfl_sync(0xffffffff, pos_x, src_lane);
pos_y = __shfl_sync(0xffffffff, pos_y, src_lane);
pos_z = __shfl_sync(0xffffffff, pos_z, src_lane);
// Only evaluate for the first thread and when the row idx is valid
if (compute_bounds && row_idx < num_indices) {
imaged_pos =
pos_x - box_x * nearbyint((pos_x - static_cast<RealType>(0.5) * (max_pos_x + min_pos_x)) * inv_bx);
min_pos_x = min(min_pos_x, imaged_pos);
max_pos_x = max(max_pos_x, imaged_pos);

imaged_pos =
pos_y - box_y * nearbyint((pos_y - static_cast<RealType>(0.5) * (max_pos_y + min_pos_y)) * inv_by);
min_pos_y = min(min_pos_y, imaged_pos);
max_pos_y = max(max_pos_y, imaged_pos);

imaged_pos =
pos_z - box_z * nearbyint((pos_z - static_cast<RealType>(0.5) * (max_pos_z + min_pos_z)) * inv_bz);
min_pos_z = min(min_pos_z, imaged_pos);
max_pos_z = max(max_pos_z, imaged_pos);
}
}
if (threadIdx.x % WARP_SIZE == 0) {
block_bounds_ctr[tile_idx * 3 + 0] = static_cast<RealType>(0.5) * (max_pos_x + min_pos_x);
block_bounds_ctr[tile_idx * 3 + 1] = static_cast<RealType>(0.5) * (max_pos_y + min_pos_y);
block_bounds_ctr[tile_idx * 3 + 2] = static_cast<RealType>(0.5) * (max_pos_z + min_pos_z);

block_bounds_ext[tile_idx * 3 + 0] = static_cast<RealType>(0.5) * (max_pos_x - min_pos_x);
block_bounds_ext[tile_idx * 3 + 1] = static_cast<RealType>(0.5) * (max_pos_y - min_pos_y);
block_bounds_ext[tile_idx * 3 + 2] = static_cast<RealType>(0.5) * (max_pos_z - min_pos_z);
}
}

Expand Down Expand Up @@ -181,7 +204,7 @@ void __global__ k_find_blocks_with_ixns(
const RealType *__restrict__ column_bb_ctr, // [N * 3] block centers
const RealType *__restrict__ column_bb_ext, // [N * 3] block extents
const RealType *__restrict__ row_bb_ctr, // [N * 3] block centers
const RealType *__restrict__ row_bb_ext, // [N * 3] block extants
const RealType *__restrict__ row_bb_ext, // [N * 3] block extents
const double *__restrict__ coords, // [N * 3]
const double *__restrict__ box,
unsigned int *__restrict__ interactionCount, // number of tiles that have interactions
Expand Down
27 changes: 17 additions & 10 deletions timemachine/cpp/src/neighborlist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,22 +64,33 @@ void Neighborlist<RealType>::compute_block_bounds_host(
DeviceBuffer<double> d_coords(N * D);
DeviceBuffer<double> d_box(D * D);

std::vector<RealType> h_block_bounds_centers(this->num_column_blocks() * 3);
std::vector<RealType> h_block_bounds_extents(this->num_column_blocks() * 3);

d_coords.copy_from(h_coords);
d_box.copy_from(h_box);

this->compute_block_bounds_device(N, D, d_coords.data, d_box.data, static_cast<cudaStream_t>(0));
gpuErrchk(cudaDeviceSynchronize());

gpuErrchk(cudaMemcpy(
h_bb_ctrs,
&h_block_bounds_centers[0],
d_column_block_bounds_ctr_,
this->num_column_blocks() * 3 * sizeof(*d_column_block_bounds_ctr_),
cudaMemcpyDeviceToHost));
gpuErrchk(cudaMemcpy(
h_bb_exts,
&h_block_bounds_extents[0],
d_column_block_bounds_ext_,
this->num_column_blocks() * 3 * sizeof(*d_column_block_bounds_ext_),
cudaMemcpyDeviceToHost));

// Handle the float -> double, doing a direct copy from a double buffer to a float buffer results in garbage values
for (auto i = 0; i < h_block_bounds_centers.size(); i++) {
h_bb_ctrs[i] = h_block_bounds_centers[i];
}
for (auto i = 0; i < h_block_bounds_extents.size(); i++) {
h_bb_exts[i] = h_block_bounds_extents[i];
}
}

template <typename RealType>
Expand Down Expand Up @@ -199,11 +210,9 @@ void Neighborlist<RealType>::compute_block_bounds_device(
}

const int tpb = DEFAULT_THREADS_PER_BLOCK;
const int column_blocks = this->num_column_blocks(); // total number of blocks we need to process

k_find_block_bounds<RealType><<<column_blocks, tpb, 0, stream>>>(
N,
column_blocks,
k_find_block_bounds<RealType><<<ceil_divide(NC_, tpb), tpb, 0, stream>>>(
this->num_column_blocks(),
NC_,
d_column_idxs_,
d_coords,
Expand All @@ -215,10 +224,8 @@ void Neighborlist<RealType>::compute_block_bounds_device(
// In the case of upper triangle of the matrix, the column and row indices are the same, so only compute block ixns for both
// when they are different
if (!this->compute_upper_triangular()) {
const int row_blocks = this->num_row_blocks();
k_find_block_bounds<RealType><<<row_blocks, tpb, 0, stream>>>(
N,
row_blocks,
k_find_block_bounds<RealType><<<ceil_divide(NR_, tpb), tpb, 0, stream>>>(
this->num_row_blocks(),
NR_,
d_row_idxs_,
d_coords,
Expand Down

0 comments on commit c35a5f7

Please sign in to comment.