diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 65115b2ccb..385e16383e 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -90,8 +90,8 @@ DI void storeWarpQShmem(myWarpSelect& heapArr, template DI void storeWarpQGmem(myWarpSelect& heapArr, - OutT* out_dists, - IdxT* out_inds, + volatile OutT* out_dists, + volatile IdxT* out_inds, const IdxT m, const unsigned int numOfNN, const IdxT starty) @@ -115,8 +115,8 @@ DI void storeWarpQGmem(myWarpSelect& heapArr, template DI void loadPrevTopKsGmemWarpQ(myWarpSelect& heapArr, - OutT* out_dists, - IdxT* out_inds, + volatile OutT* out_dists, + volatile IdxT* out_inds, const IdxT m, const unsigned int numOfNN, const IdxT starty) @@ -207,9 +207,9 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x FinalLambda fin_op, bool sqrt, unsigned int numOfNN, - int* mutexes, - OutT* out_dists, - IdxT* out_inds) + volatile int* mutexes, + volatile OutT* out_dists, + volatile IdxT* out_inds) { extern __shared__ char smem[]; @@ -225,8 +225,6 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x IdxT gridStrideY) { if (gridDim.x == 1) { return; } - volatile int* mutex = mutexes; - Pair* shDumpKV = nullptr; if (useNorms) { shDumpKV = (Pair*)(&smem[Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT))]); @@ -240,7 +238,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x // 0 -> consumer done consuming the buffer. // -1 -> consumer started consuming the buffer // -2 -> producer done filling the buffer - // blockIdx.x -> prod started to fill the buffer + // 1 -> prod acquired to fill the buffer if (blockIdx.x == 0) { auto cta_processed = 0; myWarpSelect heapArr1(identity, keyMax, numOfNN); @@ -252,45 +250,41 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x while (cta_processed < gridDim.x - 1) { if (threadIdx.x == 0) { - int32_t old = -3; - while (old != -1) { - old = atomicCAS((int*)&mutex[gridStrideY / Policy::Mblk], -2, -1); - } - __threadfence(); + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) + ; } + __threadfence(); __syncthreads(); #pragma unroll for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { #pragma unroll - for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) { - Pair otherKV; - otherKV.value = identity; - otherKV.key = keyMax; - const auto idx = j * warpSize + lid; - if (idx < numOfNN && rowId < m) { - otherKV.value = out_dists[rowId * numOfNN + idx]; - otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; - shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + otherKV.value = out_dists[rowId * numOfNN + idx]; + otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + } } } } __threadfence(); + __syncthreads(); - if (threadIdx.x == 0) { - mutex[gridStrideY / Policy::Mblk] = 0; - __threadfence(); - } + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } + __threadfence(); // Perform merging of otherKV with topk's across warp. - __syncwarp(); - #pragma unroll for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + const auto rowId = starty + i * Policy::AccThRows; if (rowId < m) { #pragma unroll for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) { @@ -298,7 +292,10 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x otherKV.value = identity; otherKV.key = keyMax; const auto idx = j * warpSize + lid; - if (idx < numOfNN) { otherKV = shDumpKV[shMemRowId * numOfNN + idx]; } + if (idx < numOfNN) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + } heapArr[i]->add(otherKV.value, otherKV.key); } } @@ -317,33 +314,29 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); } else { if (threadIdx.x == 0) { - int32_t old = -1; - int32_t blkIdX = (int32_t)blockIdx.x; - while (old != blkIdX) { - old = atomicCAS((int*)&mutex[gridStrideY / Policy::Mblk], 0, blkIdX); - } - __threadfence(); + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) + ; } + __threadfence(); __syncthreads(); #pragma unroll for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + const auto rowId = starty + i * Policy::AccThRows; if (rowId < m) { for (int idx = lid; idx < numOfNN; idx += warpSize) { - Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; out_dists[rowId * numOfNN + idx] = KVPair.value; out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; } } } __threadfence(); + __syncthreads(); - if (threadIdx.x == 0) { - mutex[gridStrideY / Policy::Mblk] = -2; - __threadfence(); - } + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } + __threadfence(); } }; diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index 9aef395ad3..6e0ea1f538 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -294,64 +294,63 @@ void brute_force_knn_impl( auto stream = handle.get_next_usable_stream(i); - // // TODO: Enable this once we figure out why it's causing pytest failures in cuml. - // if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && - // (metric == raft::distance::DistanceType::L2Unexpanded || - // metric == raft::distance::DistanceType::L2SqrtUnexpanded || - // metric == raft::distance::DistanceType::L2Expanded || - // metric == raft::distance::DistanceType::L2SqrtExpanded)) { - // fusedL2Knn(D, - // out_i_ptr, - // out_d_ptr, - // input[i], - // search_items, - // sizes[i], - // n, - // k, - // rowMajorIndex, - // rowMajorQuery, - // stream, - // metric); - // } else { - switch (metric) { - case raft::distance::DistanceType::Haversine: - - ASSERT(D == 2, - "Haversine distance requires 2 dimensions " - "(latitude / longitude)."); - - haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); - break; - default: - faiss::MetricType m = build_faiss_metric(metric); - - faiss::gpu::StandardGpuResources gpu_res; - - gpu_res.noTempMemory(); - gpu_res.setDefaultStream(device, stream); - - faiss::gpu::GpuDistanceParams args; - args.metric = m; - args.metricArg = metricArg; - args.k = k; - args.dims = D; - args.vectors = input[i]; - args.vectorsRowMajor = rowMajorIndex; - args.numVectors = sizes[i]; - args.queries = search_items; - args.queriesRowMajor = rowMajorQuery; - args.numQueries = n; - args.outDistances = out_d_ptr; - args.outIndices = out_i_ptr; - - /** - * @todo: Until FAISS supports pluggable allocation strategies, - * we will not reap the benefits of the pool allocator for - * avoiding device-wide synchronizations from cudaMalloc/cudaFree - */ - bfKnn(&gpu_res, args); + if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && + (metric == raft::distance::DistanceType::L2Unexpanded || + metric == raft::distance::DistanceType::L2SqrtUnexpanded || + metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded)) { + fusedL2Knn(D, + out_i_ptr, + out_d_ptr, + input[i], + search_items, + sizes[i], + n, + k, + rowMajorIndex, + rowMajorQuery, + stream, + metric); + } else { + switch (metric) { + case raft::distance::DistanceType::Haversine: + + ASSERT(D == 2, + "Haversine distance requires 2 dimensions " + "(latitude / longitude)."); + + haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); + break; + default: + faiss::MetricType m = build_faiss_metric(metric); + + faiss::gpu::StandardGpuResources gpu_res; + + gpu_res.noTempMemory(); + gpu_res.setDefaultStream(device, stream); + + faiss::gpu::GpuDistanceParams args; + args.metric = m; + args.metricArg = metricArg; + args.k = k; + args.dims = D; + args.vectors = input[i]; + args.vectorsRowMajor = rowMajorIndex; + args.numVectors = sizes[i]; + args.queries = search_items; + args.queriesRowMajor = rowMajorQuery; + args.numQueries = n; + args.outDistances = out_d_ptr; + args.outIndices = out_i_ptr; + + /** + * @todo: Until FAISS supports pluggable allocation strategies, + * we will not reap the benefits of the pool allocator for + * avoiding device-wide synchronizations from cudaMalloc/cudaFree + */ + bfKnn(&gpu_res, args); + } } - // } RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -377,7 +376,11 @@ void brute_force_knn_impl( float p = 0.5; // standard l2 if (metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg; raft::linalg::unaryOp( - res_D, res_D, n * k, [p] __device__(float input) { return powf(input, p); }, userStream); + res_D, + res_D, + n * k, + [p] __device__(float input) { return powf(fabsf(input), p); }, + userStream); } query_metric_processor->revert(search_items);