Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in producer-consumer buffer exchange which occurs in UMAP test on GV100 #429

Merged
merged 3 commits into from
Dec 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 41 additions & 48 deletions cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ DI void storeWarpQShmem(myWarpSelect& heapArr,

template <typename Policy, typename Pair, typename myWarpSelect, typename IdxT, typename OutT>
DI void storeWarpQGmem(myWarpSelect& heapArr,
OutT* out_dists,
IdxT* out_inds,
volatile OutT* out_dists,
volatile IdxT* out_inds,
const IdxT m,
const unsigned int numOfNN,
const IdxT starty)
Expand All @@ -115,8 +115,8 @@ DI void storeWarpQGmem(myWarpSelect& heapArr,

template <typename Policy, typename Pair, typename myWarpSelect, typename IdxT, typename OutT>
DI void loadPrevTopKsGmemWarpQ(myWarpSelect& heapArr,
OutT* out_dists,
IdxT* out_inds,
volatile OutT* out_dists,
volatile IdxT* out_inds,
const IdxT m,
const unsigned int numOfNN,
const IdxT starty)
Expand Down Expand Up @@ -207,9 +207,9 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x
FinalLambda fin_op,
bool sqrt,
unsigned int numOfNN,
int* mutexes,
OutT* out_dists,
IdxT* out_inds)
volatile int* mutexes,
volatile OutT* out_dists,
volatile IdxT* out_inds)
{
extern __shared__ char smem[];

Expand All @@ -225,8 +225,6 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x
IdxT gridStrideY) {
if (gridDim.x == 1) { return; }

volatile int* mutex = mutexes;

Pair* shDumpKV = nullptr;
if (useNorms) {
shDumpKV = (Pair*)(&smem[Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT))]);
Expand All @@ -240,7 +238,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x
// 0 -> consumer done consuming the buffer.
// -1 -> consumer started consuming the buffer
// -2 -> producer done filling the buffer
// blockIdx.x -> prod started to fill the buffer
// 1 -> prod acquired to fill the buffer
if (blockIdx.x == 0) {
auto cta_processed = 0;
myWarpSelect heapArr1(identity, keyMax, numOfNN);
Expand All @@ -252,53 +250,52 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x

while (cta_processed < gridDim.x - 1) {
if (threadIdx.x == 0) {
int32_t old = -3;
while (old != -1) {
old = atomicCAS((int*)&mutex[gridStrideY / Policy::Mblk], -2, -1);
}
__threadfence();
while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2)
;
}
__threadfence();
__syncthreads();

#pragma unroll
for (int i = 0; i < Policy::AccRowsPerTh; ++i) {
const auto rowId = starty + i * Policy::AccThRows;
const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows;
const auto rowId = starty + i * Policy::AccThRows;
if (rowId < m) {
#pragma unroll
for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) {
Pair otherKV;
otherKV.value = identity;
otherKV.key = keyMax;
const auto idx = j * warpSize + lid;
if (idx < numOfNN && rowId < m) {
otherKV.value = out_dists[rowId * numOfNN + idx];
otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx];
shDumpKV[shMemRowId * numOfNN + idx] = otherKV;
for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) {
Pair otherKV;
otherKV.value = identity;
otherKV.key = keyMax;
const auto idx = j * warpSize + lid;
if (idx < numOfNN) {
otherKV.value = out_dists[rowId * numOfNN + idx];
otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx];
const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows;
shDumpKV[shMemRowId * numOfNN + idx] = otherKV;
}
}
}
}
__threadfence();
__syncthreads();

if (threadIdx.x == 0) {
mutex[gridStrideY / Policy::Mblk] = 0;
__threadfence();
}
if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); }
__threadfence();

// Perform merging of otherKV with topk's across warp.
__syncwarp();

#pragma unroll
for (int i = 0; i < Policy::AccRowsPerTh; ++i) {
const auto rowId = starty + i * Policy::AccThRows;
const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows;
const auto rowId = starty + i * Policy::AccThRows;
if (rowId < m) {
#pragma unroll
for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) {
Pair otherKV;
otherKV.value = identity;
otherKV.key = keyMax;
const auto idx = j * warpSize + lid;
if (idx < numOfNN) { otherKV = shDumpKV[shMemRowId * numOfNN + idx]; }
if (idx < numOfNN) {
const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows;
otherKV = shDumpKV[shMemRowId * numOfNN + idx];
}
heapArr[i]->add(otherKV.value, otherKV.key);
}
}
Expand All @@ -317,33 +314,29 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x
storeWarpQGmem<Policy, Pair>(heapArr, out_dists, out_inds, m, numOfNN, starty);
} else {
if (threadIdx.x == 0) {
int32_t old = -1;
int32_t blkIdX = (int32_t)blockIdx.x;
while (old != blkIdX) {
old = atomicCAS((int*)&mutex[gridStrideY / Policy::Mblk], 0, blkIdX);
}
__threadfence();
while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0)
;
}
__threadfence();
__syncthreads();

#pragma unroll
for (int i = 0; i < Policy::AccRowsPerTh; ++i) {
const auto rowId = starty + i * Policy::AccThRows;
const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows;
const auto rowId = starty + i * Policy::AccThRows;
if (rowId < m) {
for (int idx = lid; idx < numOfNN; idx += warpSize) {
Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx];
const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows;
Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx];
out_dists[rowId * numOfNN + idx] = KVPair.value;
out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key;
}
}
}
__threadfence();
__syncthreads();

if (threadIdx.x == 0) {
mutex[gridStrideY / Policy::Mblk] = -2;
__threadfence();
}
if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); }
__threadfence();
}
};

Expand Down
119 changes: 61 additions & 58 deletions cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -294,64 +294,63 @@ void brute_force_knn_impl(

auto stream = handle.get_next_usable_stream(i);

// // TODO: Enable this once we figure out why it's causing pytest failures in cuml.
// if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true &&
// (metric == raft::distance::DistanceType::L2Unexpanded ||
// metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
// metric == raft::distance::DistanceType::L2Expanded ||
// metric == raft::distance::DistanceType::L2SqrtExpanded)) {
// fusedL2Knn(D,
// out_i_ptr,
// out_d_ptr,
// input[i],
// search_items,
// sizes[i],
// n,
// k,
// rowMajorIndex,
// rowMajorQuery,
// stream,
// metric);
// } else {
switch (metric) {
case raft::distance::DistanceType::Haversine:

ASSERT(D == 2,
"Haversine distance requires 2 dimensions "
"(latitude / longitude).");

haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream);
break;
default:
faiss::MetricType m = build_faiss_metric(metric);

faiss::gpu::StandardGpuResources gpu_res;

gpu_res.noTempMemory();
gpu_res.setDefaultStream(device, stream);

faiss::gpu::GpuDistanceParams args;
args.metric = m;
args.metricArg = metricArg;
args.k = k;
args.dims = D;
args.vectors = input[i];
args.vectorsRowMajor = rowMajorIndex;
args.numVectors = sizes[i];
args.queries = search_items;
args.queriesRowMajor = rowMajorQuery;
args.numQueries = n;
args.outDistances = out_d_ptr;
args.outIndices = out_i_ptr;

/**
* @todo: Until FAISS supports pluggable allocation strategies,
* we will not reap the benefits of the pool allocator for
* avoiding device-wide synchronizations from cudaMalloc/cudaFree
*/
bfKnn(&gpu_res, args);
if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true &&
(metric == raft::distance::DistanceType::L2Unexpanded ||
metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded)) {
fusedL2Knn(D,
out_i_ptr,
out_d_ptr,
input[i],
search_items,
sizes[i],
n,
k,
rowMajorIndex,
rowMajorQuery,
stream,
metric);
} else {
switch (metric) {
case raft::distance::DistanceType::Haversine:

ASSERT(D == 2,
"Haversine distance requires 2 dimensions "
"(latitude / longitude).");

haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream);
break;
default:
faiss::MetricType m = build_faiss_metric(metric);

faiss::gpu::StandardGpuResources gpu_res;

gpu_res.noTempMemory();
gpu_res.setDefaultStream(device, stream);

faiss::gpu::GpuDistanceParams args;
args.metric = m;
args.metricArg = metricArg;
args.k = k;
args.dims = D;
args.vectors = input[i];
args.vectorsRowMajor = rowMajorIndex;
args.numVectors = sizes[i];
args.queries = search_items;
args.queriesRowMajor = rowMajorQuery;
args.numQueries = n;
args.outDistances = out_d_ptr;
args.outIndices = out_i_ptr;

/**
* @todo: Until FAISS supports pluggable allocation strategies,
* we will not reap the benefits of the pool allocator for
* avoiding device-wide synchronizations from cudaMalloc/cudaFree
*/
bfKnn(&gpu_res, args);
}
}
// }

RAFT_CUDA_TRY(cudaPeekAtLastError());
}
Expand All @@ -377,7 +376,11 @@ void brute_force_knn_impl(
float p = 0.5; // standard l2
if (metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg;
raft::linalg::unaryOp<float>(
res_D, res_D, n * k, [p] __device__(float input) { return powf(input, p); }, userStream);
res_D,
res_D,
n * k,
[p] __device__(float input) { return powf(fabsf(input), p); },
userStream);
}

query_metric_processor->revert(search_items);
Expand Down