diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index 8962c27d52..414c1dc1ce 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -294,63 +294,63 @@ void brute_force_knn_impl( cudaStream_t stream = raft::select_stream(userStream, internalStreams, n_int_streams, i); // // TODO: Enable this once we figure out why it's causing pytest failures in cuml. - if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && - (metric == raft::distance::DistanceType::L2Unexpanded || - metric == raft::distance::DistanceType::L2SqrtUnexpanded || - metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded)) { - fusedL2Knn(D, - out_i_ptr, - out_d_ptr, - input[i], - search_items, - sizes[i], - n, - k, - rowMajorIndex, - rowMajorQuery, - stream, - metric); - } else { - switch (metric) { - case raft::distance::DistanceType::Haversine: - - ASSERT(D == 2, - "Haversine distance requires 2 dimensions " - "(latitude / longitude)."); - - haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); - break; - default: - faiss::MetricType m = build_faiss_metric(metric); - - faiss::gpu::StandardGpuResources gpu_res; - - gpu_res.noTempMemory(); - gpu_res.setDefaultStream(device, stream); - - faiss::gpu::GpuDistanceParams args; - args.metric = m; - args.metricArg = metricArg; - args.k = k; - args.dims = D; - args.vectors = input[i]; - args.vectorsRowMajor = rowMajorIndex; - args.numVectors = sizes[i]; - args.queries = search_items; - args.queriesRowMajor = rowMajorQuery; - args.numQueries = n; - args.outDistances = out_d_ptr; - args.outIndices = out_i_ptr; - - /** - * @todo: Until FAISS supports pluggable allocation strategies, - * we will not reap the benefits of the pool allocator for - * avoiding device-wide synchronizations from cudaMalloc/cudaFree - */ - bfKnn(&gpu_res, args); - } + // if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && + // (metric == raft::distance::DistanceType::L2Unexpanded || + // metric == raft::distance::DistanceType::L2SqrtUnexpanded || + // metric == raft::distance::DistanceType::L2Expanded || + // metric == raft::distance::DistanceType::L2SqrtExpanded)) { + // fusedL2Knn(D, + // out_i_ptr, + // out_d_ptr, + // input[i], + // search_items, + // sizes[i], + // n, + // k, + // rowMajorIndex, + // rowMajorQuery, + // stream, + // metric); + // } else { + switch (metric) { + case raft::distance::DistanceType::Haversine: + + ASSERT(D == 2, + "Haversine distance requires 2 dimensions " + "(latitude / longitude)."); + + haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); + break; + default: + faiss::MetricType m = build_faiss_metric(metric); + + faiss::gpu::StandardGpuResources gpu_res; + + gpu_res.noTempMemory(); + gpu_res.setDefaultStream(device, stream); + + faiss::gpu::GpuDistanceParams args; + args.metric = m; + args.metricArg = metricArg; + args.k = k; + args.dims = D; + args.vectors = input[i]; + args.vectorsRowMajor = rowMajorIndex; + args.numVectors = sizes[i]; + args.queries = search_items; + args.queriesRowMajor = rowMajorQuery; + args.numQueries = n; + args.outDistances = out_d_ptr; + args.outIndices = out_i_ptr; + + /** + * @todo: Until FAISS supports pluggable allocation strategies, + * we will not reap the benefits of the pool allocator for + * avoiding device-wide synchronizations from cudaMalloc/cudaFree + */ + bfKnn(&gpu_res, args); } + // } RAFT_CUDA_TRY(cudaPeekAtLastError()); }