Skip to content

Commit

Permalink
Fix bug in producer-consumer buffer exchange which occurs in UMAP test on GV100 (#429)

Browse files Browse the repository at this point in the history

-- Fix incorrect output in the producer-consumer code detected by the UMAP test on GV100; the cause appears to be the missing volatile qualifiers and missing __syncthreads()/__threadfence() synchronization.
-- Enable fused L2 kNN usage, as all the known issues are now resolved with this PR.

Authors:
  - Mahesh Doijade (https://github.com/mdoijade)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #429
  • Loading branch information
mdoijade authored Dec 20, 2021
1 parent f48612d commit ccd5d75
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 106 deletions.
89 changes: 41 additions & 48 deletions cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ DI void storeWarpQShmem(myWarpSelect& heapArr,

template <typename Policy, typename Pair, typename myWarpSelect, typename IdxT, typename OutT>
DI void storeWarpQGmem(myWarpSelect& heapArr,
OutT* out_dists,
IdxT* out_inds,
volatile OutT* out_dists,
volatile IdxT* out_inds,
const IdxT m,
const unsigned int numOfNN,
const IdxT starty)
Expand All @@ -115,8 +115,8 @@ DI void storeWarpQGmem(myWarpSelect& heapArr,

template <typename Policy, typename Pair, typename myWarpSelect, typename IdxT, typename OutT>
DI void loadPrevTopKsGmemWarpQ(myWarpSelect& heapArr,
OutT* out_dists,
IdxT* out_inds,
volatile OutT* out_dists,
volatile IdxT* out_inds,
const IdxT m,
const unsigned int numOfNN,
const IdxT starty)
Expand Down Expand Up @@ -207,9 +207,9 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x
FinalLambda fin_op,
bool sqrt,
unsigned int numOfNN,
int* mutexes,
OutT* out_dists,
IdxT* out_inds)
volatile int* mutexes,
volatile OutT* out_dists,
volatile IdxT* out_inds)
{
extern __shared__ char smem[];

Expand All @@ -225,8 +225,6 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x
IdxT gridStrideY) {
if (gridDim.x == 1) { return; }

volatile int* mutex = mutexes;

Pair* shDumpKV = nullptr;
if (useNorms) {
shDumpKV = (Pair*)(&smem[Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT))]);
Expand All @@ -240,7 +238,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x
// 0 -> consumer done consuming the buffer.
// -1 -> consumer started consuming the buffer
// -2 -> producer done filling the buffer
// blockIdx.x -> prod started to fill the buffer
// 1 -> prod acquired to fill the buffer
if (blockIdx.x == 0) {
auto cta_processed = 0;
myWarpSelect heapArr1(identity, keyMax, numOfNN);
Expand All @@ -252,53 +250,52 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x

while (cta_processed < gridDim.x - 1) {
if (threadIdx.x == 0) {
int32_t old = -3;
while (old != -1) {
old = atomicCAS((int*)&mutex[gridStrideY / Policy::Mblk], -2, -1);
}
__threadfence();
while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2)
;
}
__threadfence();
__syncthreads();

#pragma unroll
for (int i = 0; i < Policy::AccRowsPerTh; ++i) {
const auto rowId = starty + i * Policy::AccThRows;
const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows;
const auto rowId = starty + i * Policy::AccThRows;
if (rowId < m) {
#pragma unroll
for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) {
Pair otherKV;
otherKV.value = identity;
otherKV.key = keyMax;
const auto idx = j * warpSize + lid;
if (idx < numOfNN && rowId < m) {
otherKV.value = out_dists[rowId * numOfNN + idx];
otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx];
shDumpKV[shMemRowId * numOfNN + idx] = otherKV;
for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) {
Pair otherKV;
otherKV.value = identity;
otherKV.key = keyMax;
const auto idx = j * warpSize + lid;
if (idx < numOfNN) {
otherKV.value = out_dists[rowId * numOfNN + idx];
otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx];
const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows;
shDumpKV[shMemRowId * numOfNN + idx] = otherKV;
}
}
}
}
__threadfence();
__syncthreads();

if (threadIdx.x == 0) {
mutex[gridStrideY / Policy::Mblk] = 0;
__threadfence();
}
if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); }
__threadfence();

// Perform merging of otherKV with topk's across warp.
__syncwarp();

#pragma unroll
for (int i = 0; i < Policy::AccRowsPerTh; ++i) {
const auto rowId = starty + i * Policy::AccThRows;
const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows;
const auto rowId = starty + i * Policy::AccThRows;
if (rowId < m) {
#pragma unroll
for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) {
Pair otherKV;
otherKV.value = identity;
otherKV.key = keyMax;
const auto idx = j * warpSize + lid;
if (idx < numOfNN) { otherKV = shDumpKV[shMemRowId * numOfNN + idx]; }
if (idx < numOfNN) {
const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows;
otherKV = shDumpKV[shMemRowId * numOfNN + idx];
}
heapArr[i]->add(otherKV.value, otherKV.key);
}
}
Expand All @@ -317,33 +314,29 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x
storeWarpQGmem<Policy, Pair>(heapArr, out_dists, out_inds, m, numOfNN, starty);
} else {
if (threadIdx.x == 0) {
int32_t old = -1;
int32_t blkIdX = (int32_t)blockIdx.x;
while (old != blkIdX) {
old = atomicCAS((int*)&mutex[gridStrideY / Policy::Mblk], 0, blkIdX);
}
__threadfence();
while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0)
;
}
__threadfence();
__syncthreads();

#pragma unroll
for (int i = 0; i < Policy::AccRowsPerTh; ++i) {
const auto rowId = starty + i * Policy::AccThRows;
const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows;
const auto rowId = starty + i * Policy::AccThRows;
if (rowId < m) {
for (int idx = lid; idx < numOfNN; idx += warpSize) {
Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx];
const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows;
Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx];
out_dists[rowId * numOfNN + idx] = KVPair.value;
out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key;
}
}
}
__threadfence();
__syncthreads();

if (threadIdx.x == 0) {
mutex[gridStrideY / Policy::Mblk] = -2;
__threadfence();
}
if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); }
__threadfence();
}
};

Expand Down
119 changes: 61 additions & 58 deletions cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -294,64 +294,63 @@ void brute_force_knn_impl(

auto stream = handle.get_next_usable_stream(i);

// // TODO: Enable this once we figure out why it's causing pytest failures in cuml.
// if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true &&
// (metric == raft::distance::DistanceType::L2Unexpanded ||
// metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
// metric == raft::distance::DistanceType::L2Expanded ||
// metric == raft::distance::DistanceType::L2SqrtExpanded)) {
// fusedL2Knn(D,
// out_i_ptr,
// out_d_ptr,
// input[i],
// search_items,
// sizes[i],
// n,
// k,
// rowMajorIndex,
// rowMajorQuery,
// stream,
// metric);
// } else {
switch (metric) {
case raft::distance::DistanceType::Haversine:

ASSERT(D == 2,
"Haversine distance requires 2 dimensions "
"(latitude / longitude).");

haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream);
break;
default:
faiss::MetricType m = build_faiss_metric(metric);

faiss::gpu::StandardGpuResources gpu_res;

gpu_res.noTempMemory();
gpu_res.setDefaultStream(device, stream);

faiss::gpu::GpuDistanceParams args;
args.metric = m;
args.metricArg = metricArg;
args.k = k;
args.dims = D;
args.vectors = input[i];
args.vectorsRowMajor = rowMajorIndex;
args.numVectors = sizes[i];
args.queries = search_items;
args.queriesRowMajor = rowMajorQuery;
args.numQueries = n;
args.outDistances = out_d_ptr;
args.outIndices = out_i_ptr;

/**
* @todo: Until FAISS supports pluggable allocation strategies,
* we will not reap the benefits of the pool allocator for
* avoiding device-wide synchronizations from cudaMalloc/cudaFree
*/
bfKnn(&gpu_res, args);
if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true &&
(metric == raft::distance::DistanceType::L2Unexpanded ||
metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded)) {
fusedL2Knn(D,
out_i_ptr,
out_d_ptr,
input[i],
search_items,
sizes[i],
n,
k,
rowMajorIndex,
rowMajorQuery,
stream,
metric);
} else {
switch (metric) {
case raft::distance::DistanceType::Haversine:

ASSERT(D == 2,
"Haversine distance requires 2 dimensions "
"(latitude / longitude).");

haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream);
break;
default:
faiss::MetricType m = build_faiss_metric(metric);

faiss::gpu::StandardGpuResources gpu_res;

gpu_res.noTempMemory();
gpu_res.setDefaultStream(device, stream);

faiss::gpu::GpuDistanceParams args;
args.metric = m;
args.metricArg = metricArg;
args.k = k;
args.dims = D;
args.vectors = input[i];
args.vectorsRowMajor = rowMajorIndex;
args.numVectors = sizes[i];
args.queries = search_items;
args.queriesRowMajor = rowMajorQuery;
args.numQueries = n;
args.outDistances = out_d_ptr;
args.outIndices = out_i_ptr;

/**
* @todo: Until FAISS supports pluggable allocation strategies,
* we will not reap the benefits of the pool allocator for
* avoiding device-wide synchronizations from cudaMalloc/cudaFree
*/
bfKnn(&gpu_res, args);
}
}
// }

RAFT_CUDA_TRY(cudaPeekAtLastError());
}
Expand All @@ -377,7 +376,11 @@ void brute_force_knn_impl(
float p = 0.5; // standard l2
if (metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg;
raft::linalg::unaryOp<float>(
res_D, res_D, n * k, [p] __device__(float input) { return powf(input, p); }, userStream);
res_D,
res_D,
n * k,
[p] __device__(float input) { return powf(fabsf(input), p); },
userStream);
}

query_metric_processor->revert(search_items);
Expand Down

0 comments on commit ccd5d75

Please sign in to comment.