Skip to content

Commit

Permalink
Improve row-major meanvar kernel via minimizing atomicCAS locks (#489)
Browse files Browse the repository at this point in the history
Map the row-major kernel grid onto the data more efficiently. In particular, make sure there is only one `atomicCAS` lock per thread block to avoid any possible deadlocks caused by branch divergence within the critical section.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #489
  • Loading branch information
achirkin authored Feb 7, 2022
1 parent 0996d4a commit 57a23f3
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 57 deletions.
114 changes: 76 additions & 38 deletions cpp/include/raft/stats/detail/meanvar.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -93,46 +93,79 @@ class mean_var {
/*
NB: current implementation here is not optimal, especially the rowmajor version;
leaving this for further work (perhaps, as a more generic "linewiseReduce").
Possible improvements:
1. (romajor) Process input by the warps, not by blocks (thus reduce the iteration workset),
then aggregate output partially within blocks.
2. (both) Use vectorized loads to utilize memory better
3. (rowmajor) Scale the grid size better to utilize more the GPU (like in linewise_op).
Vectorized loads/stores could speed things up a lot.
*/
/**
* meanvar kernel - row-major version
*
* Assumptions:
*
* 1. blockDim.x == WarpSize
* 2. Dimension X goes along columns (D)
* 3. Dimension Y goes along rows (N)
*
*
* @tparam T element type
* @tparam I indexing type
* @tparam BlockSize must be equal to blockDim.x * blockDim.y * blockDim.z
* @param data input data
* @param mvs meanvars -- output
* @param locks guards for updating meanvars
* @param len total length of input data (N * D)
* @param D number of columns in the input data.
*/
template <typename T, typename I, int BlockSize>
__global__ void meanvar_kernel_rowmajor(
const T* data, volatile mean_var<T>* mvs, int* locks, I len, I D)
__global__ void __launch_bounds__(BlockSize)
meanvar_kernel_rowmajor(const T* data, volatile mean_var<T>* mvs, int* locks, I len, I D)
{
const I thread_idx = threadIdx.x + BlockSize * blockIdx.x;
// read the data
const I col = threadIdx.x + blockDim.x * blockIdx.x;
mean_var<T> thread_data;
{
const I grid_size = BlockSize * gridDim.x;
for (I i = thread_idx; i < len; i += grid_size) {
if (col < D) {
const I step = D * blockDim.y * gridDim.y;
for (I i = col + D * (threadIdx.y + blockDim.y * blockIdx.y); i < len; i += step) {
thread_data += mean_var<T>(data[i]);
}
}

{
const I col = thread_idx % D;
int* lock = locks + col;
while (atomicCAS(lock, 0, 1) == 1) {
// aggregate within block
if (blockDim.y > 1) {
__shared__ uint8_t shm_bytes[BlockSize * sizeof(mean_var<T>)];
auto shm = (mean_var<T>*)shm_bytes;
int tid = threadIdx.x + threadIdx.y * blockDim.x;
shm[tid] = thread_data;
for (int bs = BlockSize >> 1; bs >= blockDim.x; bs = bs >> 1) {
__syncthreads();
if (tid < bs) { shm[tid] += shm[tid + bs]; }
}
thread_data = shm[tid];
}

// aggregate across blocks
if (threadIdx.y == 0) {
int* lock = locks + blockIdx.x;
if (threadIdx.x == 0 && col < D) {
while (atomicCAS(lock, 0, 1) == 1) {
__threadfence();
}
}
__syncthreads();
if (col < D) {
__threadfence();
mean_var<T> global_data;
global_data.load(mvs + col);
global_data += thread_data;
global_data.store(mvs + col);
__threadfence();
}
__threadfence();
mean_var<T> global_data;
global_data.load(mvs + col);
global_data += thread_data;
global_data.store(mvs + col);
__threadfence();
__stwt(lock, 0);
__syncthreads();
if (threadIdx.x == 0 && col < D) { __stwt(lock, 0); }
}
}

template <typename T, typename I, int BlockSize>
__global__ void meanvar_kernel_colmajor(T* mean, T* var, const T* data, I D, I N, bool sample)
__global__ void __launch_bounds__(BlockSize)
meanvar_kernel_colmajor(T* mean, T* var, const T* data, I D, I N, bool sample)
{
using BlockReduce = cub::BlockReduce<mean_var<T>, BlockSize>;
__shared__ typename BlockReduce::TempStorage shm;
Expand Down Expand Up @@ -164,21 +197,26 @@ void meanvar(
T* mean, T* var, const T* data, I D, I N, bool sample, bool rowMajor, cudaStream_t stream)
{
if (rowMajor) {
const uint64_t len = uint64_t(D) * uint64_t(N);
ASSERT(len <= uint64_t(std::numeric_limits<I>::max()), "N * D does not fit the indexing type");
// lcm(row width, block size):
// this way, each thread processes the same column on each iteration.
const uint64_t expected_grid_size =
(uint64_t(N) / raft::gcd<uint64_t>(uint64_t(N), uint64_t(BlockSize))) * uint64_t(BlockSize);
const uint gs =
uint(min(expected_grid_size, raft::ceildiv<uint64_t>(len, uint64_t(BlockSize))));

rmm::device_buffer buf((sizeof(mean_var<T>) + sizeof(int)) * D, stream);
static_assert(BlockSize >= WarpSize, "Block size must be not smaller than the warp size.");
const dim3 bs(WarpSize, BlockSize / WarpSize, 1);
dim3 gs(raft::ceildiv<typeof(bs.x)>(D, bs.x), raft::ceildiv<typeof(bs.y)>(N, bs.y), 1);

// Don't create more blocks than necessary to occupy the GPU
int occupancy;
RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, meanvar_kernel_rowmajor<T, I, BlockSize>, BlockSize, 0));
gs.y = min(gs.y, raft::ceildiv<typeof(gs.y)>(occupancy * getMultiProcessorCount(), gs.x));

// Global memory: one mean_var<T> for each column
// one lock per all blocks working on the same set of columns
rmm::device_buffer buf(sizeof(mean_var<T>) * D + sizeof(int) * gs.x, stream);
RAFT_CUDA_TRY(cudaMemsetAsync(buf.data(), 0, buf.size(), stream));
mean_var<T>* mvs = static_cast<mean_var<T>*>(buf.data());
int* locks = static_cast<int*>(static_cast<void*>(mvs + D));
meanvar_kernel_rowmajor<T, I, BlockSize>
<<<gs, BlockSize, 0, stream>>>(data, mvs, locks, len, D);

const uint64_t len = uint64_t(D) * uint64_t(N);
ASSERT(len <= uint64_t(std::numeric_limits<I>::max()), "N * D does not fit the indexing type");
meanvar_kernel_rowmajor<T, I, BlockSize><<<gs, bs, 0, stream>>>(data, mvs, locks, len, D);
meanvar_kernel_fill<T, I>
<<<raft::ceildiv<I>(D, BlockSize), BlockSize, 0, stream>>>(mean, var, mvs, D, sample);
} else {
Expand Down
31 changes: 12 additions & 19 deletions cpp/test/stats/meanvar.cu
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,18 @@ class MeanVarTest : public ::testing::TestWithParam<MeanVarInputs<T>> {
rmm::device_uvector<T> data, mean_act, vars_act;
};

const std::vector<MeanVarInputs<float>> inputsf = {{1.f, 2.f, 1024, 32, true, false, 1234ULL},
{1.f, 2.f, 1024, 64, true, false, 1234ULL},
{1.f, 2.f, 1024, 128, true, false, 1234ULL},
{1.f, 2.f, 1024, 256, true, false, 1234ULL},
{-1.f, 2.f, 1024, 32, false, false, 1234ULL},
{-1.f, 2.f, 1024, 64, false, false, 1234ULL},
{-1.f, 2.f, 1024, 128, false, false, 1234ULL},
{-1.f, 2.f, 1024, 256, false, false, 1234ULL},
{-1.f, 2.f, 1024, 256, false, false, 1234ULL},
{-1.f, 2.f, 1024, 257, false, false, 1234ULL},
{1.f, 2.f, 1024, 32, true, true, 1234ULL},
{1.f, 2.f, 1024, 64, true, true, 1234ULL},
{1.f, 2.f, 1024, 128, true, true, 1234ULL},
{1.f, 2.f, 1024, 256, true, true, 1234ULL},
{-1.f, 2.f, 1024, 32, false, true, 1234ULL},
{-1.f, 2.f, 1024, 64, false, true, 1234ULL},
{-1.f, 2.f, 1024, 128, false, true, 1234ULL},
{-1.f, 2.f, 1024, 256, false, true, 1234ULL},
{-1.f, 2.f, 1024, 257, false, true, 1234ULL}};
const std::vector<MeanVarInputs<float>> inputsf = {
{1.f, 2.f, 1024, 32, true, false, 1234ULL}, {1.f, 2.f, 1024, 64, true, false, 1234ULL},
{1.f, 2.f, 1024, 128, true, false, 1234ULL}, {1.f, 2.f, 1024, 256, true, false, 1234ULL},
{-1.f, 2.f, 1024, 32, false, false, 1234ULL}, {-1.f, 2.f, 1024, 64, false, false, 1234ULL},
{-1.f, 2.f, 1024, 128, false, false, 1234ULL}, {-1.f, 2.f, 1024, 256, false, false, 1234ULL},
{-1.f, 2.f, 1024, 256, false, false, 1234ULL}, {-1.f, 2.f, 1024, 257, false, false, 1234ULL},
{1.f, 2.f, 1024, 32, true, true, 1234ULL}, {1.f, 2.f, 1024, 64, true, true, 1234ULL},
{1.f, 2.f, 1024, 128, true, true, 1234ULL}, {1.f, 2.f, 1024, 256, true, true, 1234ULL},
{-1.f, 2.f, 1024, 32, false, true, 1234ULL}, {-1.f, 2.f, 1024, 64, false, true, 1234ULL},
{-1.f, 2.f, 1024, 128, false, true, 1234ULL}, {-1.f, 2.f, 1024, 256, false, true, 1234ULL},
{-1.f, 2.f, 1024, 257, false, true, 1234ULL}, {-1.f, 2.f, 700, 13, false, true, 1234ULL},
{10.f, 2.f, 500000, 811, false, true, 1234ULL}};

const std::vector<MeanVarInputs<double>> inputsd = {{1.0, 2.0, 1024, 32, true, false, 1234ULL},
{1.0, 2.0, 1024, 64, true, false, 1234ULL},
Expand Down

0 comments on commit 57a23f3

Please sign in to comment.