From c62794724ed0e428e733b04202271b76f8dc602d Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 10 Jan 2023 14:36:35 -0800 Subject: [PATCH 1/3] Add L2SqrtExpanded support to ivf_flat ANN indices --- cpp/include/raft/neighbors/ivf_flat_types.hpp | 7 ++++++- .../raft/spatial/knn/detail/ivf_flat_build.cuh | 8 +++----- .../raft/spatial/knn/detail/ivf_flat_search.cuh | 12 +++++++++++- cpp/test/neighbors/ann_ivf_flat.cu | 2 ++ 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index eea6ae256d..c56b6a0e65 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -252,8 +252,13 @@ struct index : ann::index { * Replace the content of the index with new uninitialized mdarrays to hold the indicated amount * of data. */ - void allocate(const handle_t& handle, IdxT index_size, bool allocate_center_norms) + void allocate(const handle_t& handle, IdxT index_size) { + bool allocate_center_norms = ((metric_ == raft::distance::DistanceType::L2Expanded) || + (metric_ == raft::distance::DistanceType::L2SqrtExpanded) || + (metric_ == raft::distance::DistanceType::L2Unexpanded) || + (metric_ == raft::distance::DistanceType::L2SqrtUnexpanded)); + data_ = make_device_mdarray(handle, make_extents(index_size, dim())); indices_ = make_device_mdarray(handle, make_extents(index_size)); center_norms_ = diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index e951d8fe5d..a8c3ed0082 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -191,8 +191,7 @@ inline auto extend(const handle_t& handle, update_host(&index_size, list_offsets_ptr + n_lists, 1, stream); handle.sync_stream(stream); - ext_index.allocate( - handle, index_size, ext_index.metric() == raft::distance::DistanceType::L2Expanded); + ext_index.allocate(handle, index_size); // Populate index with the old data if (orig_index.size() > 0) { @@ -359,8 +358,7 @@ inline void fill_refinement_index(const handle_t& handle, stream); IdxT index_size = n_roundup * n_lists; - refinement_index->allocate( - handle, index_size, refinement_index->metric() == raft::distance::DistanceType::L2Expanded); + refinement_index->allocate(handle, index_size); RAFT_CUDA_TRY(cudaMemsetAsync(list_sizes_ptr, 0, n_lists * sizeof(uint32_t), stream)); @@ -454,7 +452,7 @@ auto load(const handle_t& handle, const std::string& filename) -> index index index_ = raft::spatial::knn::ivf_flat::index(handle, metric, n_lists, adaptive_centers, dim); - index_.allocate(handle, n_rows, metric == raft::distance::DistanceType::L2Expanded); + index_.allocate(handle, n_rows); auto data = index_.data(); read_mdspan(handle, infile, data); read_mdspan(handle, infile, index_.indices()); diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index 628b83a23c..f523afa22d 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -935,6 +935,8 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg switch (metric) { case raft::distance::DistanceType::L2Expanded: case raft::distance::DistanceType::L2Unexpanded: + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: return launch_kernel( + distances, distances, n_queries * k, raft::sqrt_op(), handle.get_stream()); + } } /** diff --git a/cpp/test/neighbors/ann_ivf_flat.cu b/cpp/test/neighbors/ann_ivf_flat.cu index 3285bc3496..4e52f79983 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cu +++ b/cpp/test/neighbors/ann_ivf_flat.cu @@ -294,6 +294,8 @@ const std::vector> inputs = { {1000, 10000, 4, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false}, {1000, 10000, 5, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, false}, {1000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 5, 16, 40, 1024, raft::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::L2SqrtExpanded, true}, // test dims that do not fit into kernel shared memory limits {1000, 10000, 2048, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false}, From e9e8faffd5d7675777533a060eb5c03f216ff47d Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 11 Jan 2023 10:01:15 -0800 Subject: [PATCH 2/3] Fix --- cpp/include/raft/neighbors/ivf_flat_types.hpp | 2 +- cpp/include/raft/spatial/knn/detail/ann_quantized.cuh | 4 +++- cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh | 2 +- cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh | 2 +- cpp/test/neighbors/ann_ivf_flat.cu | 2 +- 5 files changed, 7 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index c56b6a0e65..05f36c6fd0 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 10f781d817..2b8950d156 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -122,6 +122,8 @@ void approx_knn_build_index(const handle_t& handle, if (ivf_ft_pams && (metric == raft::distance::DistanceType::L2Unexpanded || metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::L2SqrtUnexpanded || metric == raft::distance::DistanceType::InnerProduct)) { auto new_params = from_legacy_index_params(*ivf_ft_pams, metric, metricArg); index->ivf_flat() = std::make_unique>( diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index a8c3ed0082..ed2c6bae49 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index f523afa22d..64e0e15430 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/neighbors/ann_ivf_flat.cu b/cpp/test/neighbors/ann_ivf_flat.cu index 4e52f79983..86a62bb487 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cu +++ b/cpp/test/neighbors/ann_ivf_flat.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From 5145ff724d73a4d6a976c981b0d086b738773bbf Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 12 Jan 2023 16:43:20 -0800 Subject: [PATCH 3/3] updates from code review --- cpp/include/raft/neighbors/ivf_flat_types.hpp | 20 +++-- .../raft/spatial/knn/detail/ann_quantized.cuh | 18 ++-- .../spatial/knn/detail/ivf_flat_search.cuh | 84 +++++++++++-------- .../spatial/knn/detail/topk/warpsort_topk.cuh | 12 +-- 4 files changed, 79 insertions(+), 55 deletions(-) diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index 05f36c6fd0..fc5a8116ab 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -254,17 +254,19 @@ struct index : ann::index { */ void allocate(const handle_t& handle, IdxT index_size) { - bool allocate_center_norms = ((metric_ == raft::distance::DistanceType::L2Expanded) || - (metric_ == raft::distance::DistanceType::L2SqrtExpanded) || - (metric_ == raft::distance::DistanceType::L2Unexpanded) || - (metric_ == raft::distance::DistanceType::L2SqrtUnexpanded)); - data_ = make_device_mdarray(handle, make_extents(index_size, dim())); indices_ = make_device_mdarray(handle, make_extents(index_size)); - center_norms_ = - allocate_center_norms - ? std::optional(make_device_mdarray(handle, make_extents(n_lists()))) - : std::nullopt; + + switch (metric_) { + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2Unexpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + center_norms_ = make_device_mdarray(handle, make_extents(n_lists())); + break; + default: center_norms_ = std::nullopt; + } + check_consistency(); } diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 2b8950d156..975f1a0f89 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -94,6 +94,18 @@ void approx_knn_ivfsq_build_index(knnIndex* index, const IVFSQParam& params, Int index->gpu_res.get(), D, params.nlist, faiss_qtype, faiss_metric, params.encodeResidual)); } +inline bool ivf_flat_supported_metric(raft::distance::DistanceType metric) +{ + switch (metric) { + case raft::distance::DistanceType::L2Unexpanded: + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + case raft::distance::DistanceType::InnerProduct: return true; + default: return false; + } +} + template void approx_knn_build_index(const handle_t& handle, knnIndex* index, @@ -120,11 +132,7 @@ void approx_knn_build_index(const handle_t& handle, } if constexpr (std::is_same_v) { index->metric_processor->preprocess(index_array); } - if (ivf_ft_pams && (metric == raft::distance::DistanceType::L2Unexpanded || - metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded || - metric == raft::distance::DistanceType::L2SqrtUnexpanded || - metric == raft::distance::DistanceType::InnerProduct)) { + if (ivf_ft_pams && ivf_flat_supported_metric(metric)) { auto new_params = from_legacy_index_params(*ivf_ft_pams, metric, metricArg); index->ivf_flat() = std::make_unique>( ivf_flat::build(handle, new_params, index_array, int64_t(n), D)); diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index 64e0e15430..fac8519a03 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -663,9 +663,11 @@ template + typename Lambda, + typename PostLambda> __global__ void __launch_bounds__(kThreadsPerBlock) interleaved_scan_kernel(Lambda compute_dist, + PostLambda post_process, const uint32_t query_smem_elems, const T* query, const uint32_t* coarse_index, @@ -777,7 +779,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) // finalize and store selected neighbours queue.done(); - queue.store(distances, neighbors); + queue.store(distances, neighbors, post_process); } /** @@ -805,8 +807,10 @@ template + typename Lambda, + typename PostLambda> void launch_kernel(Lambda lambda, + PostLambda post_process, const ivf_flat::index& index, const T* queries, const uint32_t* coarse_index, @@ -821,7 +825,7 @@ void launch_kernel(Lambda lambda, RAFT_EXPECTS(Veclen == index.veclen(), "Configured Veclen does not match the index interleaving pattern."); constexpr auto kKernel = - interleaved_scan_kernel; + interleaved_scan_kernel; const int max_query_smem = 16384; int query_smem_elems = std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); @@ -851,6 +855,7 @@ void launch_kernel(Lambda lambda, n_probes, smem_size); kKernel<<>>(lambda, + post_process, query_smem_elems, queries, coarse_index, @@ -935,6 +940,14 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg switch (metric) { case raft::distance::DistanceType::L2Expanded: case raft::distance::DistanceType::L2Unexpanded: + return launch_kernel, + raft::identity_op>({}, {}, std::forward(args)...); case raft::distance::DistanceType::L2SqrtExpanded: case raft::distance::DistanceType::L2SqrtUnexpanded: return launch_kernel>({}, std::forward(args)...); + euclidean_dist, + raft::sqrt_op>({}, {}, std::forward(args)...); case raft::distance::DistanceType::InnerProduct: return launch_kernel>({}, std::forward(args)...); + inner_prod_dist, + raft::identity_op>({}, {}, std::forward(args)...); // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); } @@ -1107,29 +1122,33 @@ void search_impl(const handle_t& handle, float beta = 0.0f; // todo(lsugy): raft distance? (if performance is similar/better than gemm) - if ((index.metric() == raft::distance::DistanceType::L2Expanded) || - (index.metric() == raft::distance::DistanceType::L2SqrtExpanded)) { - alpha = -2.0f; - beta = 1.0f; - raft::linalg::rowNorm(query_norm_dev.data(), - converted_queries_ptr, - static_cast(index.dim()), - static_cast(n_queries), - raft::linalg::L2Norm, - true, - stream, - raft::sqrt_op()); - utils::outer_add(query_norm_dev.data(), - (IdxT)n_queries, - index.center_norms()->data_handle(), - (IdxT)index.n_lists(), - distance_buffer_dev.data(), - stream); - RAFT_LOG_TRACE_VEC(index.center_norms()->data_handle(), std::min(20, index.dim())); - RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); - } else { - alpha = 1.0f; - beta = 0.0f; + switch (index.metric()) { + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: { + alpha = -2.0f; + beta = 1.0f; + raft::linalg::rowNorm(query_norm_dev.data(), + converted_queries_ptr, + static_cast(index.dim()), + static_cast(n_queries), + raft::linalg::L2Norm, + true, + stream, + raft::sqrt_op()); + utils::outer_add(query_norm_dev.data(), + (IdxT)n_queries, + index.center_norms()->data_handle(), + (IdxT)index.n_lists(), + distance_buffer_dev.data(), + stream); + RAFT_LOG_TRACE_VEC(index.center_norms()->data_handle(), std::min(20, index.dim())); + RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); + break; + } + default: { + alpha = 1.0f; + beta = 0.0f; + } } linalg::gemm(handle, @@ -1218,13 +1237,6 @@ void search_impl(const handle_t& handle, stream, search_mr); } - - // post-process - if (index.metric() == raft::distance::DistanceType::L2SqrtExpanded || - index.metric() == raft::distance::DistanceType::L2SqrtUnexpanded) { - raft::linalg::unaryOp( - distances, distances, n_queries * k, raft::sqrt_op(), handle.get_stream()); - } } /** diff --git a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh b/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh index cbe9f36e97..c06aa04aea 100644 --- a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh +++ b/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -212,12 +212,13 @@ class warp_sort { * device pointer to a contiguous array, unique per-subwarp of size `kWarpWidth` * (length: k <= kWarpWidth * kMaxArrLen). */ - __device__ void store(T* out, IdxT* out_idx) const + template + __device__ void store(T* out, IdxT* out_idx, Lambda post_process = raft::identity_op()) const { int idx = Pow2::mod(laneId()); #pragma unroll kMaxArrLen for (int i = 0; i < kMaxArrLen && idx < k; i++, idx += kWarpWidth) { - out[idx] = val_arr_[i]; + out[idx] = post_process(val_arr_[i]); out_idx[idx] = idx_arr_[i]; } } @@ -591,9 +592,10 @@ class block_sort { } /** Save the content by the pointer location. */ - __device__ void store(T* out, IdxT* out_idx) const + template + __device__ void store(T* out, IdxT* out_idx, Lambda post_process = raft::identity_op()) const { - if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx); } + if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx, post_process); } } private: