diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index f37bccaadb..5b3b2129f7 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -1074,9 +1074,9 @@ void search_impl(const handle_t& handle, rmm::device_uvector coarse_distances_dev(n_queries * n_probes, stream, search_mr); // The topk index of cluster(list) and queries rmm::device_uvector coarse_indices_dev(n_queries * n_probes, stream, search_mr); - // The topk distance value of candicate vectors from each cluster(list) + // The topk distance value of candidate vectors from each cluster(list) rmm::device_uvector refined_distances_dev(n_queries * n_probes * k, stream, search_mr); - // The topk index of candicate vectors from each cluster(list) + // The topk index of candidate vectors from each cluster(list) rmm::device_uvector refined_indices_dev(n_queries * n_probes * k, stream, search_mr); size_t float_query_size; diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh index d12af0dccf..b1f47a6c52 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh @@ -405,6 +405,7 @@ void postprocess_distances(float* out, // [n_queries, topk] distance::DistanceType metric, uint32_t n_queries, uint32_t topk, + float scaling_factor, rmm::cuda_stream_view stream) { size_t len = size_t(n_queries) * size_t(topk); @@ -412,16 +413,32 @@ void postprocess_distances(float* out, // [n_queries, topk] case distance::DistanceType::L2Unexpanded: case distance::DistanceType::L2Expanded: { linalg::unaryOp( - out, in, len, [] __device__(ScoreT x) -> float { return float(x); }, stream); + out, + in, + len, + [scaling_factor] __device__(ScoreT x) -> float { + return scaling_factor * scaling_factor * float(x); + }, + stream); } break; case distance::DistanceType::L2SqrtUnexpanded: case distance::DistanceType::L2SqrtExpanded: { linalg::unaryOp( - out, in, len, [] __device__(ScoreT x) -> float { return sqrtf(float(x)); }, stream); + out, + in, + len, + [scaling_factor] __device__(ScoreT x) -> float { return scaling_factor * sqrtf(float(x)); }, + stream); } break; case distance::DistanceType::InnerProduct: { linalg::unaryOp( - out, in, len, [] __device__(ScoreT x) -> float { return -float(x); }, stream); + out, + in, + len, + [scaling_factor] __device__(ScoreT x) -> float { + return -scaling_factor * scaling_factor * float(x); + }, + stream); } break; default: RAFT_FAIL("Unexpected metric."); } @@ -998,6 +1015,7 @@ void ivfpq_search_worker(const handle_t& handle, const float* query, // [n_queries, rot_dim] IdxT* neighbors, // [n_queries, topK] float* distances, // [n_queries, topK] + float scaling_factor, rmm::mr::device_memory_resource* mr) { auto stream = handle.get_stream(); @@ -1125,7 +1143,8 @@ void ivfpq_search_worker(const handle_t& handle, mr); // Postprocessing - postprocess_distances(distances, topk_dists.data(), index.metric(), n_queries, topK, stream); + postprocess_distances( + distances, topk_dists.data(), index.metric(), n_queries, topK, scaling_factor, stream); postprocess_neighbors(neighbors, manage_local_topk, data_indices, @@ -1156,6 +1175,7 @@ struct ivfpq_search { const float*, IdxT*, float*, + float, rmm::mr::device_memory_resource*); /** @@ -1366,6 +1386,7 @@ inline void search(const handle_t& handle, rot_queries.data() + uint64_t(index.rot_dim()) * offset_b, neighbors + uint64_t(k) * (offset_q + offset_b), distances + uint64_t(k) * (offset_q + offset_b), + utils::config::kDivisor / utils::config::kDivisor, mr); } } diff --git a/cpp/test/spatial/ann_utils.cuh b/cpp/test/spatial/ann_utils.cuh index 7fb040c913..faf6fad115 100644 --- a/cpp/test/spatial/ann_utils.cuh +++ b/cpp/test/spatial/ann_utils.cuh @@ -109,7 +109,6 @@ __global__ void naive_distance_kernel(EvalT* dist, IdxT k, raft::distance::DistanceType type) { - detail::utils::mapping f{}; IdxT midx = threadIdx.x + blockIdx.x * blockDim.x; if (midx >= m) return; for (IdxT nidx = threadIdx.y + blockIdx.y * blockDim.y; nidx < n; @@ -118,8 +117,8 @@ __global__ void naive_distance_kernel(EvalT* dist, for (IdxT i = 0; i < k; ++i) { IdxT xidx = i + midx * k; IdxT yidx = i + nidx * k; - EvalT xv = f(x[xidx]); - EvalT yv = f(y[yidx]); + EvalT xv = (EvalT)x[xidx]; + EvalT yv = (EvalT)y[yidx]; if (type == raft::distance::DistanceType::InnerProduct) { acc += xv * yv; } else {