diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index eea6ae256d..fc5a8116ab 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -252,14 +252,21 @@ struct index : ann::index { * Replace the content of the index with new uninitialized mdarrays to hold the indicated amount * of data. */ - void allocate(const handle_t& handle, IdxT index_size, bool allocate_center_norms) + void allocate(const handle_t& handle, IdxT index_size) { data_ = make_device_mdarray(handle, make_extents(index_size, dim())); indices_ = make_device_mdarray(handle, make_extents(index_size)); - center_norms_ = - allocate_center_norms - ? std::optional(make_device_mdarray(handle, make_extents(n_lists()))) - : std::nullopt; + + switch (metric_) { + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2Unexpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + center_norms_ = make_device_mdarray(handle, make_extents(n_lists())); + break; + default: center_norms_ = std::nullopt; + } + check_consistency(); } diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 10f781d817..975f1a0f89 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -94,6 +94,18 @@ void approx_knn_ivfsq_build_index(knnIndex* index, const IVFSQParam& params, Int index->gpu_res.get(), D, params.nlist, faiss_qtype, faiss_metric, params.encodeResidual)); } +inline bool ivf_flat_supported_metric(raft::distance::DistanceType metric) +{ + switch (metric) { + case raft::distance::DistanceType::L2Unexpanded: + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + case raft::distance::DistanceType::InnerProduct: return true; + default: return false; + } +} + template void approx_knn_build_index(const handle_t& handle, knnIndex* index, @@ -120,9 +132,7 @@ void approx_knn_build_index(const handle_t& handle, } if constexpr (std::is_same_v) { index->metric_processor->preprocess(index_array); } - if (ivf_ft_pams && (metric == raft::distance::DistanceType::L2Unexpanded || - metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::InnerProduct)) { + if (ivf_ft_pams && ivf_flat_supported_metric(metric)) { auto new_params = from_legacy_index_params(*ivf_ft_pams, metric, metricArg); index->ivf_flat() = std::make_unique>( ivf_flat::build(handle, new_params, index_array, int64_t(n), D)); diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index e951d8fe5d..ed2c6bae49 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -191,8 +191,7 @@ inline auto extend(const handle_t& handle, update_host(&index_size, list_offsets_ptr + n_lists, 1, stream); handle.sync_stream(stream); - ext_index.allocate( - handle, index_size, ext_index.metric() == raft::distance::DistanceType::L2Expanded); + ext_index.allocate(handle, index_size); // Populate index with the old data if (orig_index.size() > 0) { @@ -359,8 +358,7 @@ inline void fill_refinement_index(const handle_t& handle, stream); IdxT index_size = n_roundup * n_lists; - refinement_index->allocate( - handle, index_size, refinement_index->metric() == raft::distance::DistanceType::L2Expanded); + refinement_index->allocate(handle, index_size); RAFT_CUDA_TRY(cudaMemsetAsync(list_sizes_ptr, 0, n_lists * sizeof(uint32_t), stream)); @@ -454,7 +452,7 @@ auto load(const handle_t& handle, const std::string& filename) -> index index index_ = raft::spatial::knn::ivf_flat::index(handle, metric, n_lists, adaptive_centers, dim); - index_.allocate(handle, n_rows, metric == raft::distance::DistanceType::L2Expanded); + index_.allocate(handle, n_rows); auto data = index_.data(); read_mdspan(handle, infile, data); read_mdspan(handle, infile, index_.indices()); diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index 8ed71864fd..fac8519a03 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -663,9 +663,11 @@ template + typename Lambda, + typename PostLambda> __global__ void __launch_bounds__(kThreadsPerBlock) interleaved_scan_kernel(Lambda compute_dist, + PostLambda post_process, const uint32_t query_smem_elems, const T* query, const uint32_t* coarse_index, @@ -777,7 +779,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) // finalize and store selected neighbours queue.done(); - queue.store(distances, neighbors); + queue.store(distances, neighbors, post_process); } /** @@ -805,8 +807,10 @@ template + typename Lambda, + typename PostLambda> void launch_kernel(Lambda lambda, + PostLambda post_process, const ivf_flat::index& index, const T* queries, const uint32_t* coarse_index, @@ -821,7 +825,7 @@ void launch_kernel(Lambda lambda, RAFT_EXPECTS(Veclen == index.veclen(), "Configured Veclen does not match the index interleaving pattern."); constexpr auto kKernel = - interleaved_scan_kernel; + interleaved_scan_kernel; const int max_query_smem = 16384; int query_smem_elems = std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); @@ -851,6 +855,7 @@ void launch_kernel(Lambda lambda, n_probes, smem_size); kKernel<<>>(lambda, + post_process, query_smem_elems, queries, coarse_index, @@ -941,7 +946,18 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg T, AccT, IdxT, - euclidean_dist>({}, std::forward(args)...); + euclidean_dist, + raft::identity_op>({}, {}, std::forward(args)...); + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + return launch_kernel, + raft::sqrt_op>({}, {}, std::forward(args)...); case raft::distance::DistanceType::InnerProduct: return launch_kernel>({}, std::forward(args)...); + inner_prod_dist, + raft::identity_op>({}, {}, std::forward(args)...); // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); } @@ -1105,28 +1122,33 @@ void search_impl(const handle_t& handle, float beta = 0.0f; // todo(lsugy): raft distance? (if performance is similar/better than gemm) - if (index.metric() == raft::distance::DistanceType::L2Expanded) { - alpha = -2.0f; - beta = 1.0f; - raft::linalg::rowNorm(query_norm_dev.data(), - converted_queries_ptr, - static_cast(index.dim()), - static_cast(n_queries), - raft::linalg::L2Norm, - true, - stream, - raft::sqrt_op()); - utils::outer_add(query_norm_dev.data(), - (IdxT)n_queries, - index.center_norms()->data_handle(), - (IdxT)index.n_lists(), - distance_buffer_dev.data(), - stream); - RAFT_LOG_TRACE_VEC(index.center_norms()->data_handle(), std::min(20, index.dim())); - RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); - } else { - alpha = 1.0f; - beta = 0.0f; + switch (index.metric()) { + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: { + alpha = -2.0f; + beta = 1.0f; + raft::linalg::rowNorm(query_norm_dev.data(), + converted_queries_ptr, + static_cast(index.dim()), + static_cast(n_queries), + raft::linalg::L2Norm, + true, + stream, + raft::sqrt_op()); + utils::outer_add(query_norm_dev.data(), + (IdxT)n_queries, + index.center_norms()->data_handle(), + (IdxT)index.n_lists(), + distance_buffer_dev.data(), + stream); + RAFT_LOG_TRACE_VEC(index.center_norms()->data_handle(), std::min(20, index.dim())); + RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); + break; + } + default: { + alpha = 1.0f; + beta = 0.0f; + } } linalg::gemm(handle, diff --git a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh b/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh index cbe9f36e97..c06aa04aea 100644 --- a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh +++ b/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -212,12 +212,13 @@ class warp_sort { * device pointer to a contiguous array, unique per-subwarp of size `kWarpWidth` * (length: k <= kWarpWidth * kMaxArrLen). */ - __device__ void store(T* out, IdxT* out_idx) const + template + __device__ void store(T* out, IdxT* out_idx, Lambda post_process = raft::identity_op()) const { int idx = Pow2::mod(laneId()); #pragma unroll kMaxArrLen for (int i = 0; i < kMaxArrLen && idx < k; i++, idx += kWarpWidth) { - out[idx] = val_arr_[i]; + out[idx] = post_process(val_arr_[i]); out_idx[idx] = idx_arr_[i]; } } @@ -591,9 +592,10 @@ class block_sort { } /** Save the content by the pointer location. */ - __device__ void store(T* out, IdxT* out_idx) const + template + __device__ void store(T* out, IdxT* out_idx, Lambda post_process = raft::identity_op()) const { - if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx); } + if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx, post_process); } } private: diff --git a/cpp/test/neighbors/ann_ivf_flat.cu b/cpp/test/neighbors/ann_ivf_flat.cu index 3285bc3496..86a62bb487 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cu +++ b/cpp/test/neighbors/ann_ivf_flat.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -294,6 +294,8 @@ const std::vector> inputs = { {1000, 10000, 4, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false}, {1000, 10000, 5, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, false}, {1000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 5, 16, 40, 1024, raft::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::L2SqrtExpanded, true}, // test dims that do not fit into kernel shared memory limits {1000, 10000, 2048, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false},