diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 2da18d2a74..81a6d76507 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -340,6 +340,21 @@ if(RAFT_COMPILE_NN_LIBRARY) src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu + src/nn/specializations/detail/ivfpq_compute_similarity_float_fast.cu + src/nn/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu + src/nn/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu + src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu + src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu + src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu + src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu + src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu + src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu + src/nn/specializations/detail/ivfpq_compute_similarity_half_fast.cu + src/nn/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu + src/nn/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu + src/nn/specializations/detail/ivfpq_search_float_int64_t.cu + src/nn/specializations/detail/ivfpq_search_float_uint32_t.cu + src/nn/specializations/detail/ivfpq_search_float_uint64_t.cu src/nn/specializations/fused_l2_knn_long_float_true.cu src/nn/specializations/fused_l2_knn_long_float_false.cu src/nn/specializations/fused_l2_knn_int_float_true.cu diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 266571d4f3..ff457faef7 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -33,7 +33,16 @@ add_executable(${RAFT_CPP_BENCH_TARGET} bench/random/rng.cu bench/sparse/convert_csr.cu bench/spatial/fused_l2_nn.cu - bench/spatial/knn.cu + bench/spatial/knn/brute_force_float_int64_t.cu + bench/spatial/knn/brute_force_float_uint32_t.cu + bench/spatial/knn/ivf_flat_float_int64_t.cu + bench/spatial/knn/ivf_flat_float_uint32_t.cu + bench/spatial/knn/ivf_flat_int8_t_int64_t.cu + bench/spatial/knn/ivf_flat_uint8_t_uint32_t.cu + bench/spatial/knn/ivf_pq_float_int64_t.cu + bench/spatial/knn/ivf_pq_float_uint32_t.cu + bench/spatial/knn/ivf_pq_int8_t_int64_t.cu + bench/spatial/knn/ivf_pq_uint8_t_uint32_t.cu bench/spatial/selection.cu bench/main.cpp ) diff --git a/cpp/bench/spatial/knn.cu b/cpp/bench/spatial/knn.cuh similarity index 87% rename from cpp/bench/spatial/knn.cu rename to cpp/bench/spatial/knn.cuh index 6b08c7ee33..dedc9ea702 100644 --- a/cpp/bench/spatial/knn.cu +++ b/cpp/bench/spatial/knn.cuh @@ -14,11 +14,14 @@ * limitations under the License. 
*/ +#pragma once + #include #include #include +#include #include #if defined RAFT_NN_COMPILED #include @@ -45,16 +48,16 @@ struct params { size_t k; }; -auto operator<<(std::ostream& os, const params& p) -> std::ostream& +inline auto operator<<(std::ostream& os, const params& p) -> std::ostream& { os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k; return os; } -enum class TransferStrategy { NO_COPY, COPY_PLAIN, COPY_PINNED, MAP_PINNED, MANAGED }; -enum class Scope { BUILD, SEARCH, BUILD_SEARCH }; +enum class TransferStrategy { NO_COPY, COPY_PLAIN, COPY_PINNED, MAP_PINNED, MANAGED }; // NOLINT +enum class Scope { BUILD, SEARCH, BUILD_SEARCH }; // NOLINT -auto operator<<(std::ostream& os, const TransferStrategy& ts) -> std::ostream& +inline auto operator<<(std::ostream& os, const TransferStrategy& ts) -> std::ostream& { switch (ts) { case TransferStrategy::NO_COPY: os << "NO_COPY"; break; @@ -67,7 +70,7 @@ auto operator<<(std::ostream& os, const TransferStrategy& ts) -> std::ostream& return os; } -auto operator<<(std::ostream& os, const Scope& s) -> std::ostream& +inline auto operator<<(std::ostream& os, const Scope& s) -> std::ostream& { switch (s) { case Scope::BUILD: os << "BUILD"; break; @@ -156,6 +159,34 @@ struct ivf_flat_knn { } }; +template +struct ivf_pq_knn { + using dist_t = float; + + std::optional> index; + raft::spatial::knn::ivf_pq::index_params index_params; + raft::spatial::knn::ivf_pq::search_params search_params; + params ps; + + ivf_pq_knn(const raft::handle_t& handle, const params& ps, const ValT* data) : ps(ps) + { + index_params.n_lists = 4096; + index_params.metric = raft::distance::DistanceType::L2Expanded; + index.emplace(raft::spatial::knn::ivf_pq::build( + handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); + } + + void search(const raft::handle_t& handle, + const ValT* search_items, + dist_t* out_dists, + IdxT* out_idxs) + { + search_params.n_probes = 20; + raft::spatial::knn::ivf_pq::search( + handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists); + } +}; + template struct brute_force_knn { using dist_t = ValT; @@ -217,7 +248,7 @@ struct knn : public fixture { } template - void gen_data(raft::random::RngState& state, + void gen_data(raft::random::RngState& state, // NOLINT rmm::device_uvector& vec, size_t n, rmm::cuda_stream_view stream) @@ -338,15 +369,15 @@ struct knn : public fixture { rmm::device_uvector out_idxs_; }; -const std::vector kInputs{ +inline const std::vector kInputs{ {2000000, 128, 1000, 32}, {10000000, 128, 1000, 32}, {10000, 8192, 1000, 32}}; -const std::vector kAllStrategies{ +inline const std::vector kAllStrategies{ TransferStrategy::NO_COPY, TransferStrategy::MAP_PINNED, TransferStrategy::MANAGED}; -const std::vector kNoCopyOnly{TransferStrategy::NO_COPY}; +inline const std::vector kNoCopyOnly{TransferStrategy::NO_COPY}; -const std::vector kScopeFull{Scope::BUILD_SEARCH}; -const std::vector kAllScopes{Scope::BUILD_SEARCH, Scope::SEARCH, Scope::BUILD}; +inline const std::vector kScopeFull{Scope::BUILD_SEARCH}; +inline const std::vector kAllScopes{Scope::BUILD_SEARCH, Scope::SEARCH, Scope::BUILD}; #define KNN_REGISTER(ValT, IdxT, ImplT, inputs, strats, scope) \ namespace BENCHMARK_PRIVATE_NAME(knn) \ @@ -355,14 +386,4 @@ const std::vector kAllScopes{Scope::BUILD_SEARCH, Scope::SEARCH, Scope::B RAFT_BENCH_REGISTER(KNN, #ValT "/" #IdxT "/" #ImplT, inputs, strats, scope); \ } -KNN_REGISTER(float, int64_t, brute_force_knn, kInputs, kAllStrategies, kScopeFull); 
-KNN_REGISTER(float, int64_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes); -KNN_REGISTER(int8_t, int64_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes); -KNN_REGISTER(uint8_t, int64_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes); - -KNN_REGISTER(float, uint32_t, brute_force_knn, kInputs, kNoCopyOnly, kScopeFull); -KNN_REGISTER(float, uint32_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes); -KNN_REGISTER(int8_t, uint32_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes); -KNN_REGISTER(uint8_t, uint32_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes); - } // namespace raft::bench::spatial diff --git a/cpp/bench/spatial/knn/brute_force_float_int64_t.cu b/cpp/bench/spatial/knn/brute_force_float_int64_t.cu new file mode 100644 index 0000000000..d981104e20 --- /dev/null +++ b/cpp/bench/spatial/knn/brute_force_float_int64_t.cu @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../knn.cuh" + +namespace raft::bench::spatial { + +KNN_REGISTER(float, int64_t, brute_force_knn, kInputs, kAllStrategies, kScopeFull); + +} // namespace raft::bench::spatial diff --git a/cpp/bench/spatial/knn/brute_force_float_uint32_t.cu b/cpp/bench/spatial/knn/brute_force_float_uint32_t.cu new file mode 100644 index 0000000000..60f7edae96 --- /dev/null +++ b/cpp/bench/spatial/knn/brute_force_float_uint32_t.cu @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../knn.cuh" + +namespace raft::bench::spatial { + +KNN_REGISTER(float, uint32_t, brute_force_knn, kInputs, kAllStrategies, kScopeFull); + +} // namespace raft::bench::spatial diff --git a/cpp/bench/spatial/knn/ivf_flat_float_int64_t.cu b/cpp/bench/spatial/knn/ivf_flat_float_int64_t.cu new file mode 100644 index 0000000000..594d4d16d2 --- /dev/null +++ b/cpp/bench/spatial/knn/ivf_flat_float_int64_t.cu @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../knn.cuh" + +namespace raft::bench::spatial { + +KNN_REGISTER(float, int64_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes); + +} // namespace raft::bench::spatial diff --git a/cpp/bench/spatial/knn/ivf_flat_float_uint32_t.cu b/cpp/bench/spatial/knn/ivf_flat_float_uint32_t.cu new file mode 100644 index 0000000000..595ad2b922 --- /dev/null +++ b/cpp/bench/spatial/knn/ivf_flat_float_uint32_t.cu @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../knn.cuh" + +namespace raft::bench::spatial { + +KNN_REGISTER(float, uint32_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes); + +} // namespace raft::bench::spatial diff --git a/cpp/bench/spatial/knn/ivf_flat_int8_t_int64_t.cu b/cpp/bench/spatial/knn/ivf_flat_int8_t_int64_t.cu new file mode 100644 index 0000000000..bd268f036c --- /dev/null +++ b/cpp/bench/spatial/knn/ivf_flat_int8_t_int64_t.cu @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../knn.cuh" + +namespace raft::bench::spatial { + +KNN_REGISTER(int8_t, int64_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes); + +} // namespace raft::bench::spatial diff --git a/cpp/bench/spatial/knn/ivf_flat_uint8_t_uint32_t.cu b/cpp/bench/spatial/knn/ivf_flat_uint8_t_uint32_t.cu new file mode 100644 index 0000000000..9d8b982c3e --- /dev/null +++ b/cpp/bench/spatial/knn/ivf_flat_uint8_t_uint32_t.cu @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../knn.cuh" + +namespace raft::bench::spatial { + +KNN_REGISTER(uint8_t, uint32_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes); + +} // namespace raft::bench::spatial diff --git a/cpp/bench/spatial/knn/ivf_pq_float_int64_t.cu b/cpp/bench/spatial/knn/ivf_pq_float_int64_t.cu new file mode 100644 index 0000000000..18d8cd8ad6 --- /dev/null +++ b/cpp/bench/spatial/knn/ivf_pq_float_int64_t.cu @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../knn.cuh" + +namespace raft::bench::spatial { + +KNN_REGISTER(float, int64_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes); + +} // namespace raft::bench::spatial diff --git a/cpp/bench/spatial/knn/ivf_pq_float_uint32_t.cu b/cpp/bench/spatial/knn/ivf_pq_float_uint32_t.cu new file mode 100644 index 0000000000..81621674bf --- /dev/null +++ b/cpp/bench/spatial/knn/ivf_pq_float_uint32_t.cu @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../knn.cuh" + +namespace raft::bench::spatial { + +KNN_REGISTER(float, uint32_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes); + +} // namespace raft::bench::spatial diff --git a/cpp/bench/spatial/knn/ivf_pq_int8_t_int64_t.cu b/cpp/bench/spatial/knn/ivf_pq_int8_t_int64_t.cu new file mode 100644 index 0000000000..cc28eee67c --- /dev/null +++ b/cpp/bench/spatial/knn/ivf_pq_int8_t_int64_t.cu @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../knn.cuh" + +namespace raft::bench::spatial { + +KNN_REGISTER(int8_t, int64_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes); + +} // namespace raft::bench::spatial diff --git a/cpp/bench/spatial/knn/ivf_pq_uint8_t_uint32_t.cu b/cpp/bench/spatial/knn/ivf_pq_uint8_t_uint32_t.cu new file mode 100644 index 0000000000..b4759cbac1 --- /dev/null +++ b/cpp/bench/spatial/knn/ivf_pq_uint8_t_uint32_t.cu @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../knn.cuh" + +namespace raft::bench::spatial { + +KNN_REGISTER(uint8_t, uint32_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes); + +} // namespace raft::bench::spatial diff --git a/cpp/include/raft/linalg/detail/qr.cuh b/cpp/include/raft/linalg/detail/qr.cuh index 4aa843081e..74e9c3e1aa 100644 --- a/cpp/include/raft/linalg/detail/qr.cuh +++ b/cpp/include/raft/linalg/detail/qr.cuh @@ -28,6 +28,60 @@ namespace raft { namespace linalg { namespace detail { +/** + * @brief Calculate the QR decomposition and get matrix Q in place of the input. + * + * Subject to the algorithm constraint `n_rows >= n_cols`. + * + * @param handle + * @param[inout] Q device pointer to input matrix and the output matrix Q, + * both column-major and of size [n_rows, n_cols]. 
+ * @param n_rows + * @param n_cols + * @param stream + */ +template +void qrGetQ_inplace( + const raft::handle_t& handle, math_t* Q, int n_rows, int n_cols, cudaStream_t stream) +{ + RAFT_EXPECTS(n_rows >= n_cols, "QR decomposition expects n_rows >= n_cols."); + cusolverDnHandle_t cusolver = handle.get_cusolver_dn_handle(); + + rmm::device_uvector tau(n_cols, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(tau.data(), 0, sizeof(math_t) * n_cols, stream)); + + rmm::device_scalar dev_info(stream); + int ws_size; + + RAFT_CUSOLVER_TRY(cusolverDngeqrf_bufferSize(cusolver, n_rows, n_cols, Q, n_rows, &ws_size)); + rmm::device_uvector workspace(ws_size, stream); + RAFT_CUSOLVER_TRY(cusolverDngeqrf(cusolver, + n_rows, + n_cols, + Q, + n_rows, + tau.data(), + workspace.data(), + ws_size, + dev_info.data(), + stream)); + + RAFT_CUSOLVER_TRY( + cusolverDnorgqr_bufferSize(cusolver, n_rows, n_cols, n_cols, Q, n_rows, tau.data(), &ws_size)); + workspace.resize(ws_size, stream); + RAFT_CUSOLVER_TRY(cusolverDnorgqr(cusolver, + n_rows, + n_cols, + n_cols, + Q, + n_rows, + tau.data(), + workspace.data(), + ws_size, + dev_info.data(), + stream)); +} + template void qrGetQ(const raft::handle_t& handle, const math_t* M, @@ -36,27 +90,8 @@ void qrGetQ(const raft::handle_t& handle, int n_cols, cudaStream_t stream) { - cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); - - int m = n_rows, n = n_cols; - int k = std::min(m, n); - RAFT_CUDA_TRY(cudaMemcpyAsync(Q, M, sizeof(math_t) * m * n, cudaMemcpyDeviceToDevice, stream)); - - rmm::device_uvector tau(k, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(tau.data(), 0, sizeof(math_t) * k, stream)); - - rmm::device_scalar devInfo(stream); - int Lwork; - - RAFT_CUSOLVER_TRY(cusolverDngeqrf_bufferSize(cusolverH, m, n, Q, m, &Lwork)); - rmm::device_uvector workspace(Lwork, stream); - RAFT_CUSOLVER_TRY(cusolverDngeqrf( - cusolverH, m, n, Q, m, tau.data(), workspace.data(), Lwork, devInfo.data(), stream)); - - RAFT_CUSOLVER_TRY(cusolverDnorgqr_bufferSize(cusolverH, m, n, k, Q, m, tau.data(), &Lwork)); - workspace.resize(Lwork, stream); - RAFT_CUSOLVER_TRY(cusolverDnorgqr( - cusolverH, m, n, k, Q, m, tau.data(), workspace.data(), Lwork, devInfo.data(), stream)); + raft::copy(Q, M, n_rows * n_cols, stream); + qrGetQ_inplace(handle, Q, n_rows, n_cols, stream); } template diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index e55758711a..a48fad2737 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -567,4 +567,5 @@ void copy_selected(uint64_t n_rows, default: RAFT_FAIL("All pointers must reside on the same side, host or device."); } } + } // namespace raft::spatial::knn::detail::utils diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index 201cca5afe..770530b77c 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -18,7 +18,7 @@ #include "../ivf_flat_types.hpp" #include "ann_utils.cuh" -#include "topk/radix_topk.cuh" +#include "topk.cuh" #include "topk/warpsort_topk.cuh" #include @@ -1133,29 +1133,16 @@ void search_impl(const handle_t& handle, stream); RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); - if (n_probes <= raft::spatial::knn::detail::topk::kMaxCapacity) { - topk::warp_sort_topk(distance_buffer_dev.data(), - nullptr, - n_queries, - index.n_lists(), 
- n_probes, - coarse_distances_dev.data(), - coarse_indices_dev.data(), - select_min, - stream, - search_mr); - } else { - topk::radix_topk(distance_buffer_dev.data(), - nullptr, - n_queries, - index.n_lists(), - n_probes, - coarse_distances_dev.data(), - coarse_indices_dev.data(), - select_min, - stream, - search_mr); - } + select_topk(distance_buffer_dev.data(), + nullptr, + n_queries, + index.n_lists(), + n_probes, + coarse_distances_dev.data(), + coarse_indices_dev.data(), + select_min, + stream, + search_mr); RAFT_LOG_TRACE_VEC(coarse_indices_dev.data(), n_probes); RAFT_LOG_TRACE_VEC(coarse_distances_dev.data(), n_probes); @@ -1204,31 +1191,16 @@ void search_impl(const handle_t& handle, // Merge topk values from different blocks if (grid_dim_x > 1) { - if (k <= raft::spatial::knn::detail::topk::kMaxCapacity) { - topk::warp_sort_topk(refined_distances_dev.data(), - refined_indices_dev.data(), - n_queries, - k * grid_dim_x, - k, - distances, - neighbors, - select_min, - stream, - search_mr); - } else { - // NB: this branch can only be triggered once `ivfflat_interleaved_scan` above supports larger - // `k` values (kMaxCapacity limit as a dependency of topk::block_sort) - topk::radix_topk(refined_distances_dev.data(), - refined_indices_dev.data(), - n_queries, - k * grid_dim_x, - k, - distances, - neighbors, - select_min, - stream, - search_mr); - } + select_topk(refined_distances_dev.data(), + refined_indices_dev.data(), + n_queries, + k * grid_dim_x, + k, + distances, + neighbors, + select_min, + stream, + search_mr); } } diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh new file mode 100644 index 0000000000..5a146c18fe --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh @@ -0,0 +1,1074 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../ivf_pq_types.hpp" +#include "ann_kmeans_balanced.cuh" +#include "ann_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace raft::spatial::knn::ivf_pq::detail { + +using namespace raft::spatial::knn::detail; // NOLINT + +namespace { + +/** + * This type mimics the `uint8_t&` for the indexing operator of `bitfield_view_t`. + * + * @tparam Bits number of bits comprising the value. 
+ */ +template +struct bitfield_ref_t { + static_assert(Bits <= 8 && Bits > 0, "Bit code must fit one byte"); + constexpr static uint8_t kMask = static_cast((1u << Bits) - 1u); + uint8_t* ptr; + uint32_t offset; + + constexpr operator uint8_t() // NOLINT + { + auto pair = static_cast(ptr[0]); + if (offset + Bits > 8) { pair |= static_cast(ptr[1]) << 8; } + return static_cast((pair >> offset) & kMask); + } + + constexpr auto operator=(uint8_t code) -> bitfield_ref_t& + { + if (offset + Bits > 8) { + auto pair = static_cast(ptr[0]); + pair |= static_cast(ptr[1]) << 8; + pair &= ~(static_cast(kMask) << offset); + pair |= static_cast(code) << offset; + ptr[0] = static_cast(Pow2<256>::mod(pair)); + ptr[1] = static_cast(Pow2<256>::div(pair)); + } else { + ptr[0] = (ptr[0] & ~(kMask << offset)) | (code << offset); + } + return *this; + } +}; + +/** + * View a byte array as an array of unsigned integers of custom small bit size. + * + * @tparam Bits number of bits comprising a single element of the array. + */ +template +struct bitfield_view_t { + static_assert(Bits <= 8 && Bits > 0, "Bit code must fit one byte"); + uint8_t* raw; + + constexpr auto operator[](uint32_t i) -> bitfield_ref_t + { + uint32_t bit_offset = i * Bits; + return bitfield_ref_t{raw + Pow2<8>::div(bit_offset), Pow2<8>::mod(bit_offset)}; + } +}; + +/* + NB: label type is uint32_t although it can only contain values up to `1 << pq_bits`. + We keep it this way to not force one more overload for kmeans::predict. + */ +template +HDI void ivfpq_encode_core(uint32_t n_rows, uint32_t pq_dim, const uint32_t* label, uint8_t* output) +{ + bitfield_view_t out{output}; + for (uint32_t j = 0; j < pq_dim; j++, label += n_rows) { + out[j] = static_cast(*label); + } +} + +template +__launch_bounds__(BlockDim) __global__ + void ivfpq_encode_kernel(uint32_t n_rows, + uint32_t pq_dim, + const uint32_t* label, // [pq_dim, n_rows] + uint8_t* output // [n_rows, pq_dim] + ) +{ + uint32_t i = threadIdx.x + BlockDim * blockIdx.x; + if (i >= n_rows) return; + ivfpq_encode_core(n_rows, pq_dim, label + i, output + (pq_dim * PqBits / 8) * i); +} +} // namespace + +inline void ivfpq_encode(uint32_t n_rows, + uint32_t pq_dim, + uint32_t pq_bits, // 4 <= pq_bits <= 8 + const uint32_t* label, // [pq_dim, n_rows] + uint8_t* output, // [n_rows, pq_dim] + rmm::cuda_stream_view stream) +{ + constexpr uint32_t kBlockDim = 128; + dim3 threads(kBlockDim, 1, 1); + dim3 blocks(raft::ceildiv(n_rows, kBlockDim), 1, 1); + switch (pq_bits) { + case 4: + return ivfpq_encode_kernel + <<>>(n_rows, pq_dim, label, output); + case 5: + return ivfpq_encode_kernel + <<>>(n_rows, pq_dim, label, output); + case 6: + return ivfpq_encode_kernel + <<>>(n_rows, pq_dim, label, output); + case 7: + return ivfpq_encode_kernel + <<>>(n_rows, pq_dim, label, output); + case 8: + return ivfpq_encode_kernel + <<>>(n_rows, pq_dim, label, output); + default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); + } +} + +/** + * @brief Fill-in a random orthogonal transformation matrix. + * + * @param handle + * @param force_random_rotation + * @param n_rows + * @param n_cols + * @param[out] rotation_matrix device pointer to a row-major matrix of size [n_rows, n_cols]. 
+ * @param rng random number generator state + */ +inline void make_rotation_matrix(const handle_t& handle, + bool force_random_rotation, + uint32_t n_rows, + uint32_t n_cols, + float* rotation_matrix, + raft::random::Rng rng = raft::random::Rng(7ULL)) +{ + common::nvtx::range fun_scope( + "ivf_pq::make_rotation_matrix(%u * %u)", n_rows, n_cols); + auto stream = handle.get_stream(); + bool inplace = n_rows == n_cols; + uint32_t n = std::max(n_rows, n_cols); + if (force_random_rotation || !inplace) { + rmm::device_uvector buf(inplace ? 0 : n * n, stream); + float* mat = inplace ? rotation_matrix : buf.data(); + rng.normal(mat, n * n, 0.0f, 1.0f, stream); + linalg::detail::qrGetQ_inplace(handle, mat, n, n, stream); + if (!inplace) { + RAFT_CUDA_TRY(cudaMemcpy2DAsync(rotation_matrix, + sizeof(float) * n_cols, + mat, + sizeof(float) * n, + sizeof(float) * n_cols, + n_rows, + cudaMemcpyDefault, + stream)); + } + } else { + uint32_t stride = n + 1; + auto f = [stride] __device__(float* out, uint32_t i) -> void { *out = float(i % stride == 0); }; + linalg::writeOnlyUnaryOp(rotation_matrix, n * n, f, stream); + } +} + +/** + * @brief Compute residual vectors from the source dataset given by selected indices. + * + * The residual has the form `rotation_matrix %* (dataset[row_ids, :] - center)` + * + */ +template +void select_residuals(const handle_t& handle, + float* residuals, + IdxT n_rows, + uint32_t dim, + uint32_t rot_dim, + const float* rotation_matrix, // [rot_dim, dim] + const float* center, // [dim] + const T* dataset, // [.., dim] + const IdxT* row_ids, // [n_rows] + rmm::mr::device_memory_resource* device_memory + +) +{ + auto stream = handle.get_stream(); + rmm::device_uvector tmp(n_rows * dim, stream, device_memory); + utils::copy_selected(n_rows, dim, dataset, row_ids, dim, tmp.data(), dim, stream); + + raft::matrix::linewiseOp( + tmp.data(), + tmp.data(), + IdxT(dim), + n_rows, + true, + [] __device__(float a, float b) { return a - b; }, + stream, + center); + + float alpha = 1.0; + float beta = 0.0; + linalg::gemm(handle, + true, + false, + rot_dim, + n_rows, + dim, + &alpha, + rotation_matrix, + dim, + tmp.data(), + dim, + &beta, + residuals, + rot_dim, + stream); +} + +/** + * @param handle, + * @param n_rows + * @param data_dim + * @param rot_dim + * @param pq_dim + * @param pq_len + * @param pq_bits + * @param n_clusters + * @param codebook_kind + * @param max_cluster_size + * @param cluster_centers // [n_clusters, data_dim] + * @param rotation_matrix // [rot_dim, data_dim] + * @param dataset // [n_rows] + * @param data_indices + * tells which indices to select in the dataset for each cluster [n_rows]; + * it should be partitioned by the clusters by now. + * @param cluster_sizes // [n_clusters] + * @param cluster_offsets // [n_clusters + 1] + * @param pq_centers // [...] 
+ * @param pq_dataset // [n_rows, pq_dim * pq_bits / 8] + * @param device_memory + */ +template +void compute_pq_codes(const handle_t& handle, + IdxT n_rows, + uint32_t data_dim, + uint32_t rot_dim, + uint32_t pq_dim, + uint32_t pq_len, + uint32_t pq_bits, + uint32_t n_clusters, + codebook_gen codebook_kind, + uint32_t max_cluster_size, + float* cluster_centers, + const float* rotation_matrix, + const T* dataset, + const IdxT* data_indices, + const uint32_t* cluster_sizes, + const IdxT* cluster_offsets, + const float* pq_centers, + uint8_t* pq_dataset, + rmm::mr::device_memory_resource* device_memory) +{ + common::nvtx::range fun_scope( + "ivf_pq::compute_pq_codes(n_rows = %zu, data_dim = %u, rot_dim = %u (%u * %u), n_clusters = " + "%u)", + size_t(n_rows), + data_dim, + rot_dim, + pq_dim, + pq_len, + n_clusters); + auto stream = handle.get_stream(); + + // + // Compute PQ code + // + utils::memzero(pq_dataset, n_rows * pq_dim * pq_bits / 8, stream); + + rmm::device_uvector rot_vectors(max_cluster_size * rot_dim, stream, device_memory); + rmm::device_uvector sub_vectors(max_cluster_size * pq_dim * pq_len, stream, device_memory); + rmm::device_uvector sub_vector_labels(max_cluster_size * pq_dim, stream, device_memory); + rmm::device_uvector my_pq_dataset( + max_cluster_size * pq_dim * pq_bits / 8 /* NB: pq_dim * bitPQ % 8 == 0 */, + stream, + device_memory); + + for (uint32_t l = 0; l < n_clusters; l++) { + auto cluster_size = cluster_sizes[l]; + common::nvtx::range cluster_scope( + "ivf_pq::compute_pq_codes::cluster[%u](size = %u)", l, cluster_size); + if (cluster_size == 0) continue; + + select_residuals(handle, + rot_vectors.data(), + IdxT(cluster_size), + data_dim, + rot_dim, + rotation_matrix, + cluster_centers + uint64_t(l) * data_dim, + dataset, + data_indices + cluster_offsets[l], + device_memory); + + // + // Change the order of the vector data to facilitate processing in + // each vector subspace. + // input: rot_vectors[cluster_size, rot_dim] = [cluster_size, pq_dim, pq_len] + // output: sub_vectors[pq_dim, cluster_size, pq_len] + // + for (uint32_t i = 0; i < pq_dim; i++) { + RAFT_CUDA_TRY(cudaMemcpy2DAsync(sub_vectors.data() + i * pq_len * cluster_size, + sizeof(float) * pq_len, + rot_vectors.data() + i * pq_len, + sizeof(float) * rot_dim, + sizeof(float) * pq_len, + cluster_size, + cudaMemcpyDefault, + stream)); + } + + // + // Find a label (cluster ID) for each vector subspace. 
+ // + for (uint32_t j = 0; j < pq_dim; j++) { + const float* sub_pq_centers = nullptr; + switch (codebook_kind) { + case codebook_gen::PER_SUBSPACE: + sub_pq_centers = pq_centers + ((1 << pq_bits) * pq_len) * j; + break; + case codebook_gen::PER_CLUSTER: + sub_pq_centers = pq_centers + ((1 << pq_bits) * pq_len) * l; + break; + default: RAFT_FAIL("Unreachable code"); + } + kmeans::predict(handle, + sub_pq_centers, + (1 << pq_bits), + pq_len, + sub_vectors.data() + j * (cluster_size * pq_len), + cluster_size, + sub_vector_labels.data() + j * cluster_size, + raft::distance::DistanceType::L2Expanded, + stream, + device_memory); + } + + // + // PQ encoding + // + ivfpq_encode( + cluster_size, pq_dim, pq_bits, sub_vector_labels.data(), my_pq_dataset.data(), stream); + copy(pq_dataset + cluster_offsets[l] * uint64_t{pq_dim * pq_bits / 8}, + my_pq_dataset.data(), + cluster_size * pq_dim * pq_bits / 8, + stream); + } +} + +template +__launch_bounds__(BlockDim) __global__ void fill_indices_kernel(IdxT n_rows, + IdxT* data_indices, + IdxT* data_offsets, + const uint32_t* labels) +{ + const auto i = BlockDim * IdxT(blockIdx.x) + IdxT(threadIdx.x); + if (i >= n_rows) { return; } + data_indices[atomicAdd(data_offsets + labels[i], 1)] = i; +} + +/** + * @brief Calculate cluster offsets and arrange data indices into clusters. + * + * @param n_rows + * @param n_lists + * @param[in] labels output of k-means prediction [n_rows] + * @param[in] cluster_sizes [n_lists] + * @param[out] cluster_offsets [n_lists+1] + * @param[out] data_indices [n_rows] + * + * @return size of the largest cluster + */ +template +auto calculate_offsets_and_indices(IdxT n_rows, + uint32_t n_lists, + const uint32_t* labels, + const uint32_t* cluster_sizes, + IdxT* cluster_offsets, + IdxT* data_indices, + rmm::cuda_stream_view stream) -> uint32_t +{ + auto exec_policy = rmm::exec_policy(stream); + uint32_t max_cluster_size = 0; + rmm::device_scalar max_cluster_size_dev_buf(stream); + auto max_cluster_size_dev = max_cluster_size_dev_buf.data(); + update_device(max_cluster_size_dev, &max_cluster_size, 1, stream); + // Calculate the offsets + IdxT cumsum = 0; + update_device(cluster_offsets, &cumsum, 1, stream); + thrust::inclusive_scan(exec_policy, + cluster_sizes, + cluster_sizes + n_lists, + cluster_offsets + 1, + [max_cluster_size_dev] __device__(IdxT s, uint32_t l) { + atomicMax(max_cluster_size_dev, l); + return s + l; + }); + update_host(&cumsum, cluster_offsets + n_lists, 1, stream); + update_host(&max_cluster_size, max_cluster_size_dev, 1, stream); + stream.synchronize(); + RAFT_EXPECTS(cumsum == n_rows, "cluster sizes do not add up."); + rmm::device_uvector data_offsets_buf(n_lists, stream); + auto data_offsets = data_offsets_buf.data(); + copy(data_offsets, cluster_offsets, n_lists, stream); + constexpr uint32_t n_threads = 128; // NOLINT + const IdxT n_blocks = raft::div_rounding_up_unsafe(n_rows, n_threads); + fill_indices_kernel + <<>>(n_rows, data_indices, data_offsets, labels); + return max_cluster_size; +} + +template +void train_per_subset(const handle_t& handle, + index& index, + IdxT n_rows, + const float* trainset, // [n_rows, dim] + const uint32_t* labels, // [n_rows] + uint32_t kmeans_n_iters, + rmm::mr::device_memory_resource* managed_memory, + rmm::mr::device_memory_resource* device_memory) +{ + auto stream = handle.get_stream(); + + rmm::device_uvector sub_trainset(n_rows * index.pq_len(), stream, device_memory); + rmm::device_uvector sub_labels(n_rows, stream, device_memory); + + rmm::device_uvector 
pq_cluster_sizes(index.pq_book_size(), stream, device_memory); + + for (uint32_t j = 0; j < index.pq_dim(); j++) { + common::nvtx::range pq_per_subspace_scope( + "ivf_pq::build::per_subspace[%u]", j); + + // Get the rotated cluster centers for each training vector. + // This will be subtracted from the input vectors afterwards. + utils::copy_selected(n_rows, + index.pq_len(), + index.centers_rot().data_handle() + index.pq_len() * j, + labels, + index.rot_dim(), + sub_trainset.data(), + index.pq_len(), + stream); + + // sub_trainset is the slice of: rotate(trainset) - centers_rot + float alpha = 1.0; + float beta = -1.0; + linalg::gemm(handle, + true, + false, + index.pq_len(), + n_rows, + index.dim(), + &alpha, + index.rotation_matrix().data_handle() + index.dim() * index.pq_len() * j, + index.dim(), + trainset, + index.dim(), + &beta, + sub_trainset.data(), + index.pq_len(), + stream); + + // train PQ codebook for this subspace + kmeans::build_clusters( + handle, + kmeans_n_iters, + index.pq_len(), + sub_trainset.data(), + n_rows, + index.pq_book_size(), + index.pq_centers().data_handle() + (index.pq_book_size() * index.pq_len()) * j, + sub_labels.data(), + pq_cluster_sizes.data(), + raft::distance::DistanceType::L2Expanded, + stream, + device_memory); + } +} + +template +void train_per_cluster(const handle_t& handle, + index& index, + IdxT n_rows, + const float* trainset, // [n_rows, dim] + const uint32_t* labels, // [n_rows] + uint32_t kmeans_n_iters, + rmm::mr::device_memory_resource* managed_memory, + rmm::mr::device_memory_resource* device_memory) +{ + auto stream = handle.get_stream(); + rmm::device_uvector cluster_sizes(index.n_lists(), stream, managed_memory); + rmm::device_uvector indices_buf(n_rows, stream, device_memory); + rmm::device_uvector offsets_buf(index.list_offsets().size(), stream, managed_memory); + + raft::stats::histogram(raft::stats::HistTypeAuto, + reinterpret_cast(cluster_sizes.data()), + IdxT(index.n_lists()), + labels, + n_rows, + 1, + stream); + + auto cluster_offsets = offsets_buf.data(); + auto indices = indices_buf.data(); + uint32_t max_cluster_size = calculate_offsets_and_indices( + n_rows, index.n_lists(), labels, cluster_sizes.data(), cluster_offsets, indices, stream); + + rmm::device_uvector pq_labels(max_cluster_size * index.pq_dim(), stream, device_memory); + rmm::device_uvector pq_cluster_sizes(index.pq_book_size(), stream, device_memory); + rmm::device_uvector rot_vectors(max_cluster_size * index.rot_dim(), stream, device_memory); + + handle.sync_stream(); // make sure cluster offsets are up-to-date + for (uint32_t l = 0; l < index.n_lists(); l++) { + auto cluster_size = cluster_sizes.data()[l]; + if (cluster_size == 0) continue; + common::nvtx::range pq_per_cluster_scope( + "ivf_pq::build::per_cluster[%u](size = %u)", l, cluster_size); + + select_residuals(handle, + rot_vectors.data(), + IdxT(cluster_size), + index.dim(), + index.rot_dim(), + index.rotation_matrix().data_handle(), + index.centers().data_handle() + uint64_t(l) * index.dim_ext(), + trainset, + indices + cluster_offsets[l], + device_memory); + + // limit the cluster size to bound the training time. 
+ // [sic] we interpret the data as pq_len-dimensional + size_t big_enough = 256 * std::max(index.pq_book_size(), index.pq_dim()); + size_t available_rows = cluster_size * index.pq_dim(); + auto pq_n_rows = uint32_t(std::min(big_enough, available_rows)); + // train PQ codebook for this cluster + kmeans::build_clusters( + handle, + kmeans_n_iters, + index.pq_len(), + rot_vectors.data(), + pq_n_rows, + index.pq_book_size(), + index.pq_centers().data_handle() + index.pq_book_size() * index.pq_len() * l, + pq_labels.data(), + pq_cluster_sizes.data(), + raft::distance::DistanceType::L2Expanded, + stream, + device_memory); + } +} + +/** See raft::spatial::knn::ivf_pq::extend docs */ +template +inline auto extend(const handle_t& handle, + const index& orig_index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index +{ + common::nvtx::range fun_scope( + "ivf_pq::extend(%zu, %u)", size_t(n_rows), orig_index.dim()); + auto stream = handle.get_stream(); + + RAFT_EXPECTS(new_indices != nullptr || orig_index.size() == 0, + "You must pass data indices when the index is non-empty."); + + static_assert(std::is_same_v || std::is_same_v || std::is_same_v, + "Unsupported data type"); + + rmm::mr::device_memory_resource* device_memory = nullptr; + auto pool_guard = raft::get_pool_memory_resource(device_memory, 1024 * 1024); + if (pool_guard) { + RAFT_LOG_DEBUG("ivf_pq::extend: using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); + } + + rmm::mr::managed_memory_resource managed_memory_upstream; + rmm::mr::pool_memory_resource managed_memory( + &managed_memory_upstream, 1024 * 1024); + + // + // The cluster_centers stored in index contain data other than cluster + // centroids to speed up the search. Here, only the cluster centroids + // are extracted. + // + const auto n_clusters = orig_index.n_lists(); + + rmm::device_uvector cluster_centers(n_clusters * orig_index.dim(), stream, device_memory); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(cluster_centers.data(), + sizeof(float) * orig_index.dim(), + orig_index.centers().data_handle(), + sizeof(float) * orig_index.dim_ext(), + sizeof(float) * orig_index.dim(), + n_clusters, + cudaMemcpyDefault, + stream)); + + // + // Use the existing cluster centroids to find the label (cluster ID) + // of the vector to be added. 
+ // + + rmm::device_uvector new_data_labels(n_rows, stream, device_memory); + utils::memzero(new_data_labels.data(), n_rows, stream); + rmm::device_uvector new_cluster_sizes_buf(n_clusters, stream, &managed_memory); + auto new_cluster_sizes = new_cluster_sizes_buf.data(); + utils::memzero(new_cluster_sizes, n_clusters, stream); + + kmeans::predict(handle, + cluster_centers.data(), + n_clusters, + orig_index.dim(), + new_vectors, + n_rows, + new_data_labels.data(), + orig_index.metric(), + stream); + raft::stats::histogram(raft::stats::HistTypeAuto, + reinterpret_cast(new_cluster_sizes), + IdxT(n_clusters), + new_data_labels.data(), + n_rows, + 1, + stream); + + // + // Make new_cluster_offsets, new_data_indices + // + rmm::device_uvector new_data_indices(n_rows, stream, &managed_memory); + rmm::device_uvector new_cluster_offsets(n_clusters + 1, stream, &managed_memory); + uint32_t new_max_cluster_size = calculate_offsets_and_indices(n_rows, + n_clusters, + new_data_labels.data(), + new_cluster_sizes, + new_cluster_offsets.data(), + new_data_indices.data(), + stream); + + // + // Compute PQ code for new vectors + // + rmm::device_uvector new_pq_codes( + n_rows * orig_index.pq_dim() * orig_index.pq_bits() / 8, stream, device_memory); + compute_pq_codes(handle, + n_rows, + orig_index.dim(), + orig_index.rot_dim(), + orig_index.pq_dim(), + orig_index.pq_len(), + orig_index.pq_bits(), + n_clusters, + orig_index.codebook_kind(), + new_max_cluster_size, + cluster_centers.data(), + orig_index.rotation_matrix().data_handle(), + new_vectors, + new_data_indices.data(), + new_cluster_sizes, + new_cluster_offsets.data(), + orig_index.pq_centers().data_handle(), + new_pq_codes.data(), + device_memory); + + // Get the combined cluster sizes and sort the clusters in decreasing order + // (this makes it easy to estimate the max number of samples during search). 
+ rmm::device_uvector old_cluster_sizes_buf(n_clusters, stream, &managed_memory); + rmm::device_uvector ext_cluster_sizes_buf(n_clusters, stream, &managed_memory); + rmm::device_uvector old_cluster_offsets_buf(n_clusters + 1, stream, &managed_memory); + rmm::device_uvector ext_cluster_offsets_buf(n_clusters + 1, stream, &managed_memory); + rmm::device_uvector cluster_ordering(n_clusters, stream, &managed_memory); + auto old_cluster_sizes = old_cluster_sizes_buf.data(); + auto ext_cluster_sizes = ext_cluster_sizes_buf.data(); + auto old_cluster_offsets = old_cluster_offsets_buf.data(); + auto ext_cluster_offsets = ext_cluster_offsets_buf.data(); + copy(old_cluster_offsets, + orig_index.list_offsets().data_handle(), + orig_index.list_offsets().size(), + stream); + + uint32_t n_nonempty_lists = 0; + { + rmm::device_uvector ext_cluster_sizes_buf_in(n_clusters, stream, device_memory); + rmm::device_uvector cluster_ordering_in(n_clusters, stream, device_memory); + auto ext_cluster_sizes_in = ext_cluster_sizes_buf_in.data(); + linalg::writeOnlyUnaryOp( + old_cluster_sizes, + n_clusters, + [ext_cluster_sizes_in, new_cluster_sizes, old_cluster_offsets] __device__(uint32_t * out, + size_t i) { + auto old_size = old_cluster_offsets[i + 1] - old_cluster_offsets[i]; + ext_cluster_sizes_in[i] = old_size + new_cluster_sizes[i]; + *out = old_size; + }, + stream); + + thrust::sequence(handle.get_thrust_policy(), + cluster_ordering_in.data(), + cluster_ordering_in.data() + n_clusters); + + int begin_bit = 0; + int end_bit = sizeof(uint32_t) * 8; + size_t cub_workspace_size = 0; + cub::DeviceRadixSort::SortPairsDescending(nullptr, + cub_workspace_size, + ext_cluster_sizes_in, + ext_cluster_sizes, + cluster_ordering_in.data(), + cluster_ordering.data(), + n_clusters, + begin_bit, + end_bit, + stream); + rmm::device_buffer cub_workspace(cub_workspace_size, stream, device_memory); + cub::DeviceRadixSort::SortPairsDescending(cub_workspace.data(), + cub_workspace_size, + ext_cluster_sizes_in, + ext_cluster_sizes, + cluster_ordering_in.data(), + cluster_ordering.data(), + n_clusters, + begin_bit, + end_bit, + stream); + + n_nonempty_lists = thrust::lower_bound(handle.get_thrust_policy(), + ext_cluster_sizes, + ext_cluster_sizes + n_clusters, + 0, + thrust::greater()) - + ext_cluster_sizes; + } + + // Assemble the extended index + ivf_pq::index ext_index(handle, + orig_index.metric(), + orig_index.codebook_kind(), + n_clusters, + orig_index.dim(), + orig_index.pq_bits(), + orig_index.pq_dim(), + n_nonempty_lists); + ext_index.allocate(handle, orig_index.size() + n_rows); + + // Copy the unchanged parts + copy(ext_index.rotation_matrix().data_handle(), + orig_index.rotation_matrix().data_handle(), + orig_index.rotation_matrix().size(), + stream); + + // calculate extended cluster offsets + auto ext_indices = ext_index.indices().data_handle(); + { + IdxT zero = 0; + update_device(ext_cluster_offsets, &zero, 1, stream); + thrust::inclusive_scan(handle.get_thrust_policy(), + ext_cluster_sizes, + ext_cluster_sizes + n_clusters, + ext_cluster_offsets + 1, + [] __device__(IdxT s, uint32_t l) { return s + l; }); + copy(ext_index.list_offsets().data_handle(), + ext_cluster_offsets, + ext_index.list_offsets().size(), + stream); + } + + // copy cluster-ordering-dependent data + utils::copy_selected(n_clusters, + ext_index.dim_ext(), + orig_index.centers().data_handle(), + cluster_ordering.data(), + orig_index.dim_ext(), + ext_index.centers().data_handle(), + ext_index.dim_ext(), + stream); + utils::copy_selected(n_clusters, + 
ext_index.rot_dim(), + orig_index.centers_rot().data_handle(), + cluster_ordering.data(), + orig_index.rot_dim(), + ext_index.centers_rot().data_handle(), + ext_index.rot_dim(), + stream); + switch (orig_index.codebook_kind()) { + case codebook_gen::PER_SUBSPACE: { + copy(ext_index.pq_centers().data_handle(), + orig_index.pq_centers().data_handle(), + orig_index.pq_centers().size(), + stream); + } break; + case codebook_gen::PER_CLUSTER: { + auto d = orig_index.pq_book_size() * orig_index.pq_len(); + utils::copy_selected(n_clusters, + d, + orig_index.pq_centers().data_handle(), + cluster_ordering.data(), + d, + ext_index.pq_centers().data_handle(), + d, + stream); + } break; + default: RAFT_FAIL("Unreachable code"); + } + + // Make ext_indices + handle.sync_stream(); // make sure cluster sizes are up-to-date + for (uint32_t l = 0; l < ext_index.n_lists(); l++) { + auto k = cluster_ordering.data()[l]; + auto old_cluster_size = old_cluster_sizes[k]; + auto new_cluster_size = new_cluster_sizes[k]; + if (old_cluster_size > 0) { + copy(ext_indices + ext_cluster_offsets[l], + orig_index.indices().data_handle() + old_cluster_offsets[k], + old_cluster_size, + stream); + } + if (new_cluster_size > 0) { + if (new_indices == nullptr) { + // implies the orig index is empty + copy(ext_indices + ext_cluster_offsets[l] + old_cluster_size, + new_data_indices.data() + new_cluster_offsets.data()[k], + new_cluster_size, + stream); + } else { + utils::copy_selected(new_cluster_size, + 1, + new_indices, + new_data_indices.data() + new_cluster_offsets.data()[k], + 1, + ext_indices + ext_cluster_offsets[l] + old_cluster_size, + 1, + stream); + } + } + } + + /* Extend the pq_dataset */ + auto ext_pq_dataset = ext_index.pq_dataset().data_handle(); + size_t pq_dataset_unit = ext_index.pq_dim() * ext_index.pq_bits() / 8; + for (uint32_t l = 0; l < ext_index.n_lists(); l++) { + auto k = cluster_ordering.data()[l]; + auto old_cluster_size = old_cluster_sizes[k]; + copy(ext_pq_dataset + pq_dataset_unit * ext_cluster_offsets[l], + orig_index.pq_dataset().data_handle() + pq_dataset_unit * old_cluster_offsets[k], + pq_dataset_unit * old_cluster_size, + stream); + copy(ext_pq_dataset + pq_dataset_unit * (ext_cluster_offsets[l] + old_cluster_size), + new_pq_codes.data() + pq_dataset_unit * new_cluster_offsets.data()[k], + pq_dataset_unit * new_cluster_sizes[k], + stream); + } + + return ext_index; +} + +/** See raft::spatial::knn::ivf_pq::build docs */ +template +inline auto build( + const handle_t& handle, const index_params& params, const T* dataset, IdxT n_rows, uint32_t dim) + -> index +{ + common::nvtx::range fun_scope( + "ivf_pq::build(%zu, %u)", size_t(n_rows), dim); + static_assert(std::is_same_v || std::is_same_v || std::is_same_v, + "Unsupported data type"); + + RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset"); + + auto stream = handle.get_stream(); + + ivf_pq::index index(handle, params, dim); + utils::memzero(index.list_offsets().data_handle(), index.list_offsets().size(), stream); + + auto trainset_ratio = std::max( + 1, n_rows / std::max(params.kmeans_trainset_fraction * n_rows, index.n_lists())); + auto n_rows_train = n_rows / trainset_ratio; + + rmm::mr::device_memory_resource* device_memory = nullptr; + auto pool_guard = raft::get_pool_memory_resource(device_memory, 1024 * 1024); + if (pool_guard) { + RAFT_LOG_DEBUG("ivf_pq::build: using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); + } + + rmm::mr::managed_memory_resource managed_memory_upstream; + 
rmm::mr::pool_memory_resource managed_memory( + &managed_memory_upstream, 1024 * 1024); + + // Besides just sampling, we transform the input dataset into floats to make it easier + // to use gemm operations from cublas. + rmm::device_uvector trainset(n_rows_train * index.dim(), stream, device_memory); + // TODO: a proper sampling + if constexpr (std::is_same_v) { + RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(), + sizeof(T) * index.dim(), + dataset, + sizeof(T) * index.dim() * trainset_ratio, + sizeof(T) * index.dim(), + n_rows_train, + cudaMemcpyDefault, + stream)); + } else { + auto dim = index.dim(); + linalg::writeOnlyUnaryOp( + trainset.data(), + index.dim() * n_rows_train, + [dataset, trainset_ratio, dim] __device__(float* out, size_t i) { + auto col = i % dim; + *out = utils::mapping{}(dataset[(i - col) * trainset_ratio + col]); + }, + stream); + } + + // NB: here cluster_centers is used as if it is [n_clusters, data_dim] not [n_clusters, dim_ext]! + rmm::device_uvector cluster_centers_buf( + index.n_lists() * index.dim(), stream, device_memory); + auto cluster_centers = cluster_centers_buf.data(); + + // Train balanced hierarchical kmeans clustering + kmeans::build_hierarchical(handle, + params.kmeans_n_iters, + index.dim(), + trainset.data(), + n_rows_train, + cluster_centers, + index.n_lists(), + index.metric(), + stream); + + // Trainset labels are needed for training PQ codebooks + rmm::device_uvector labels(n_rows_train, stream, device_memory); + kmeans::predict(handle, + cluster_centers, + index.n_lists(), + index.dim(), + trainset.data(), + n_rows_train, + labels.data(), + index.metric(), + stream, + device_memory); + + { + // combine cluster_centers and their norms + RAFT_CUDA_TRY(cudaMemcpy2DAsync(index.centers().data_handle(), + sizeof(float) * index.dim_ext(), + cluster_centers, + sizeof(float) * index.dim(), + sizeof(float) * index.dim(), + index.n_lists(), + cudaMemcpyDefault, + stream)); + + rmm::device_uvector center_norms(index.n_lists(), stream, device_memory); + utils::dots_along_rows( + index.n_lists(), index.dim(), cluster_centers, center_norms.data(), stream); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(index.centers().data_handle() + index.dim(), + sizeof(float) * index.dim_ext(), + center_norms.data(), + sizeof(float), + sizeof(float), + index.n_lists(), + cudaMemcpyDefault, + stream)); + } + + // Make rotation matrix + make_rotation_matrix(handle, + params.force_random_rotation, + index.rot_dim(), + index.dim(), + index.rotation_matrix().data_handle()); + + // Rotate cluster_centers + float alpha = 1.0; + float beta = 0.0; + linalg::gemm(handle, + true, + false, + index.rot_dim(), + index.n_lists(), + index.dim(), + &alpha, + index.rotation_matrix().data_handle(), + index.dim(), + cluster_centers, + index.dim(), + &beta, + index.centers_rot().data_handle(), + index.rot_dim(), + stream); + + // Train PQ codebooks + switch (index.codebook_kind()) { + case codebook_gen::PER_SUBSPACE: + train_per_subset(handle, + index, + n_rows_train, + trainset.data(), + labels.data(), + params.kmeans_n_iters, + &managed_memory, + device_memory); + break; + case codebook_gen::PER_CLUSTER: + train_per_cluster(handle, + index, + n_rows_train, + trainset.data(), + labels.data(), + params.kmeans_n_iters, + &managed_memory, + device_memory); + break; + default: RAFT_FAIL("Unreachable code"); + } + + // add the data if necessary + if (params.add_data_on_build) { + return detail::extend(handle, index, dataset, nullptr, n_rows); + } else { + return index; + } +} + +} // namespace 
raft::spatial::knn::ivf_pq::detail diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh new file mode 100644 index 0000000000..73030ea53f --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh @@ -0,0 +1,1374 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../ivf_pq_types.hpp" +#include "ann_utils.cuh" +#include "topk.cuh" +#include "topk/warpsort_topk.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include + +namespace raft::spatial::knn::ivf_pq::detail { + +/** + * Maximum value of k for the fused calculate & select in ivfpq. + * + * If runtime value of k is larger than this, the main search operation + * is split into two kernels (per batch, first calculate distance, then select top-k). + */ +static constexpr int kMaxCapacity = 128; +static_assert((kMaxCapacity >= 32) && !(kMaxCapacity & (kMaxCapacity - 1)), + "kMaxCapacity must be a power of two, not smaller than the WarpSize."); + +using namespace raft::spatial::knn::detail; // NOLINT + +/** 8-bit floating-point storage type. + * + * This is a custom type for the current IVF-PQ implementation. No arithmetic operations defined + * only conversion to and from fp32. This type is unrelated to the proposed FP8 specification. + */ +template +struct fp_8bit { + static_assert(ExpBits + uint8_t{Signed} <= 8, "The type does not fit in 8 bits."); + constexpr static uint32_t ExpMask = (1u << (ExpBits - 1u)) - 1u; // NOLINT + constexpr static uint32_t ValBits = 8u - ExpBits; // NOLINT + + public: + uint8_t bitstring; + + HDI explicit fp_8bit(uint8_t bs) : bitstring(bs) {} + HDI explicit fp_8bit(float fp) : fp_8bit(float2fp_8bit(fp).bitstring) {} + HDI auto operator=(float fp) -> fp_8bit& + { + bitstring = float2fp_8bit(fp).bitstring; + return *this; + } + HDI explicit operator float() const { return fp_8bit2float(*this); } + + private: + static constexpr float kMin = 1.0f / float(1u << ExpMask); + static constexpr float kMax = float(1u << (ExpMask + 1)) * (2.0f - 1.0f / float(1u << ValBits)); + + static HDI auto float2fp_8bit(float v) -> fp_8bit + { + if constexpr (Signed) { + auto u = fp_8bit(std::abs(v)).bitstring; + u = (u & 0xfeu) | uint8_t{v < 0}; // set the sign bit + return fp_8bit(u); + } else { + // sic! all small and negative numbers are truncated to zero. 
+ if (v < kMin) { return fp_8bit{static_cast(0)}; } + // protect from overflow + if (v >= kMax) { return fp_8bit{static_cast(0xffu)}; } + // the rest of possible float values should be within the normalized range + return fp_8bit{static_cast( + (*reinterpret_cast(&v) + (ExpMask << 23u) - 0x3f800000u) >> (15u + ExpBits))}; + } + } + + static HDI auto fp_8bit2float(const fp_8bit& v) -> float + { + uint32_t u = v.bitstring; + if constexpr (Signed) { + u &= ~1; // zero the sign bit + } + float r; + *reinterpret_cast(&r) = + ((u << (15u + ExpBits)) + (0x3f800000u | (0x00400000u >> ValBits)) - (ExpMask << 23)); + if constexpr (Signed) { // recover the sign bit + if (v.bitstring & 1) { r = -r; } + } + return r; + } +}; + +/** + * Select the clusters to probe and, as a side-effect, translate the queries type `T -> float` + * + * Assuming the number of clusters is not that big (a few thousands), we do a plain GEMM + * followed by select_topk to select the clusters to probe. There's no need to return the similarity + * scores here. + */ +template +void select_clusters(const handle_t& handle, + uint32_t* clusters_to_probe, // [n_queries, n_probes] + float* float_queries, // [n_queries, dim_ext] + uint32_t n_queries, + uint32_t n_probes, + uint32_t n_lists, + uint32_t dim, + uint32_t dim_ext, + raft::distance::DistanceType metric, + const T* queries, // [n_queries, dim] + const float* cluster_centers, // [n_lists, dim_ext] + rmm::mr::device_memory_resource* mr) +{ + auto stream = handle.get_stream(); + rmm::device_uvector qc_distances(n_queries * n_lists, stream, mr); + /* NOTE[qc_distances] + + We compute query-center distances to choose the clusters to probe. + We accomplish that with just one GEMM operation thanks to some preprocessing: + + L2 distance: + cluster_centers[i, dim()] contains the squared norm of the center vector i; + we extend the dimension K of the GEMM to compute it together with all the dot products: + + `cq_distances[i, j] = |cluster_centers[j]|^2 - 2 * (queries[i], cluster_centers[j])` + + This is a monotonous mapping of the proper L2 distance. + + IP distance: + `cq_distances[i, j] = - (queries[i], cluster_centers[j])` + + This is a negative inner-product distance. We minimize it to find the similar clusters. + + NB: cq_distances is NOT used further in ivfpq_search. + */ + float norm_factor; + switch (metric) { + case raft::distance::DistanceType::L2Expanded: norm_factor = 1.0 / -2.0; break; + case raft::distance::DistanceType::InnerProduct: norm_factor = 0.0; break; + default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); + } + linalg::writeOnlyUnaryOp( + float_queries, + dim_ext * n_queries, + [queries, dim, dim_ext, norm_factor] __device__(float* out, uint32_t ix) { + uint32_t col = ix % dim_ext; + uint32_t row = ix / dim_ext; + *out = col < dim ? 
utils::mapping{}(queries[col + dim * row]) : norm_factor; + }, + stream); + + float alpha; + float beta; + uint32_t gemm_k = dim; + switch (metric) { + case raft::distance::DistanceType::L2Expanded: { + alpha = -2.0; + beta = 0.0; + gemm_k = dim + 1; + RAFT_EXPECTS(gemm_k <= dim_ext, "unexpected gemm_k or dim_ext"); + } break; + case raft::distance::DistanceType::InnerProduct: { + alpha = -1.0; + beta = 0.0; + } break; + default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); + } + linalg::gemm(handle, + true, + false, + n_lists, + n_queries, + gemm_k, + &alpha, + cluster_centers, + dim_ext, + float_queries, + dim_ext, + &beta, + qc_distances.data(), + n_lists, + stream); + + // Select neighbor clusters for each query. + rmm::device_uvector cluster_dists(n_queries * n_probes, stream, mr); + select_topk(qc_distances.data(), + nullptr, + n_queries, + n_lists, + n_probes, + cluster_dists.data(), + clusters_to_probe, + true, + stream, + mr); +} + +/** + * For each query, we calculate a cumulative sum of the cluster sizes that we probe, and return that + * in chunk_indices. Essentially this is a segmented inclusive scan of the cluster sizes. The total + * number of samples per query (sum of the cluster sizes that we probe) is returned in n_samples. + */ +template +__launch_bounds__(BlockDim) __global__ + void calc_chunk_indices_kernel(uint32_t n_probes, + const IdxT* cluster_offsets, // [n_clusters + 1] + const uint32_t* clusters_to_probe, // [n_queries, n_probes] + uint32_t* chunk_indices, // [n_queries, n_probes] + uint32_t* n_samples // [n_queries] + ) +{ + using block_scan = cub::BlockScan; + __shared__ typename block_scan::TempStorage shm; + + // locate the query data + clusters_to_probe += n_probes * blockIdx.x; + chunk_indices += n_probes * blockIdx.x; + + // block scan + const uint32_t n_probes_aligned = Pow2::roundUp(n_probes); + uint32_t total = 0; + for (uint32_t probe_ix = threadIdx.x; probe_ix < n_probes_aligned; probe_ix += BlockDim) { + auto label = probe_ix < n_probes ? clusters_to_probe[probe_ix] : 0u; + auto chunk = probe_ix < n_probes + ? static_cast(cluster_offsets[label + 1] - cluster_offsets[label]) + : 0u; + if (threadIdx.x == 0) { chunk += total; } + block_scan(shm).InclusiveSum(chunk, chunk, total); + __syncthreads(); + if (probe_ix < n_probes) { chunk_indices[probe_ix] = chunk; } + } + // save the total size + if (threadIdx.x == 0) { n_samples[blockIdx.x] = total; } +} + +template +struct calc_chunk_indices { + public: + using kernel_t = void (*)(uint32_t, const IdxT*, const uint32_t*, uint32_t*, uint32_t*); + + struct configured { + kernel_t kernel; + uint32_t block_dim; + uint32_t n_probes; + uint32_t n_queries; + + void operator()(const IdxT* cluster_offsets, + const uint32_t* clusters_to_probe, + uint32_t* chunk_indices, + uint32_t* n_samples, + rmm::cuda_stream_view stream) + { + kernel<<>>( + n_probes, cluster_offsets, clusters_to_probe, chunk_indices, n_samples); + } + }; + + static auto configure(uint32_t n_probes, uint32_t n_queries) -> configured + { + return try_block_dim<1024>(n_probes, n_queries); + } + + private: + template + static auto try_block_dim(uint32_t n_probes, uint32_t n_queries) -> configured + { + if constexpr (BlockDim >= WarpSize * 2) { + if (BlockDim >= n_probes * 2) { return try_block_dim<(BlockDim / 2)>(n_probes, n_queries); } + } + return {calc_chunk_indices_kernel, BlockDim, n_probes, n_queries}; + } +}; + +/** + * Look up the dataset index that corresponds to a sample index. 
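The chunk bookkeeping above is easiest to see on the host: for one query it is just an inclusive prefix sum over the sizes of the probed clusters. A plain C++ sketch of the same computation (the helper name is made up and not part of the patch):

```cpp
#include <cstdint>

// Host-side reference for one query: chunk_indices[j] = total size of the first j+1 probed
// clusters; the return value is n_samples for this query.
template <typename IdxT>
uint32_t calc_chunk_indices_reference(uint32_t n_probes,
                                      const IdxT* cluster_offsets,        // [n_clusters + 1]
                                      const uint32_t* clusters_to_probe,  // [n_probes]
                                      uint32_t* chunk_indices)            // [n_probes]
{
  uint32_t total = 0;
  for (uint32_t j = 0; j < n_probes; j++) {
    uint32_t label = clusters_to_probe[j];
    total += static_cast<uint32_t>(cluster_offsets[label + 1] - cluster_offsets[label]);
    chunk_indices[j] = total;  // inclusive scan
  }
  return total;
}
```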
+ * + * Each query vector was compared to all the vectors from n_probes clusters, and sample_ix is one of + * such vectors. This function looks up which cluster sample_ix belongs to, and returns the original + * dataset index for that vector. + * + * @return whether the input index is in a valid range + * (the opposite can happen if there is not enough data to output in the selected clusters). + */ +template +__device__ auto find_db_row(IdxT& x, // NOLINT + uint32_t n_probes, + const IdxT* cluster_offsets, // [n_clusters + 1,] + const uint32_t* cluster_labels, // [n_probes,] + const uint32_t* chunk_indices // [n_probes,] + ) -> bool +{ + uint32_t ix_min = 0; + uint32_t ix_max = n_probes; + do { + uint32_t i = (ix_min + ix_max) / 2; + if (IdxT(chunk_indices[i]) < x) { + ix_min = i + 1; + } else { + ix_max = i; + } + } while (ix_min < ix_max); + if (ix_min == n_probes) { return false; } + if (ix_min > 0) { x -= chunk_indices[ix_min - 1]; } + x += cluster_offsets[cluster_labels[ix_min]]; + return true; +} + +template +__launch_bounds__(BlockDim) __global__ + void postprocess_neighbors_kernel(IdxT* neighbors, // [n_queries, topk] + const IdxT* db_indices, // [n_rows] + const IdxT* cluster_offsets, // [n_clusters + 1] + const uint32_t* clusters_to_probe, // [n_queries, n_probes] + const uint32_t* chunk_indices, // [n_queries, n_probes] + uint32_t n_queries, + uint32_t n_probes, + uint32_t topk) +{ + uint64_t i = threadIdx.x + BlockDim * uint64_t(blockIdx.x); + uint32_t query_ix = i / uint64_t(topk); + if (query_ix >= n_queries) { return; } + uint32_t k = i % uint64_t(topk); + neighbors += query_ix * topk; + IdxT data_ix = neighbors[k]; + // backtrace the index if we don't have local top-k + bool valid = true; + if (n_probes > 0) { + valid = find_db_row(data_ix, + n_probes, + cluster_offsets, + clusters_to_probe + n_probes * query_ix, + chunk_indices + n_probes * query_ix); + } + neighbors[k] = valid ? db_indices[data_ix] : std::numeric_limits::max(); +} + +/** + * Transform found neighbor indices into the corresponding database indices + * (as stored in index.indices()). + * + * When the main kernel runs with a fused top-k (`manage_local_topk == true`), this function simply + * fetches the index values by the returned row ids. Otherwise, the found neighbors require extra + * pre-processing (performed by `find_db_row`). + */ +template +void postprocess_neighbors(IdxT* neighbors, // [n_queries, topk] + bool manage_local_topk, + const IdxT* db_indices, // [n_rows] + const IdxT* cluster_offsets, // [n_clusters + 1] + const uint32_t* clusters_to_probe, // [n_queries, n_probes] + const uint32_t* chunk_indices, // [n_queries, n_probes] + uint32_t n_queries, + uint32_t n_probes, + uint32_t topk, + rmm::cuda_stream_view stream) +{ + constexpr int kPNThreads = 256; + const int pn_blocks = raft::div_rounding_up_unsafe(n_queries * topk, kPNThreads); + postprocess_neighbors_kernel + <<>>(neighbors, + db_indices, + cluster_offsets, + clusters_to_probe, + chunk_indices, + n_queries, + manage_local_topk ? 0u : n_probes, + topk); +} + +/** + * Post-process the scores depending on the metric type; + * translate the element type if necessary.
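The back-translation performed by `find_db_row` above mirrors the following host-side logic (a reference sketch under the same data layout; not part of the patch):

```cpp
#include <cstdint>

template <typename IdxT>
bool find_db_row_reference(IdxT& x,                        // in: sample index; out: dataset row
                           uint32_t n_probes,
                           const IdxT* cluster_offsets,     // [n_clusters + 1]
                           const uint32_t* cluster_labels,  // [n_probes]
                           const uint32_t* chunk_indices)   // [n_probes], inclusive size prefixes
{
  // Binary search for the probe whose cumulative size covers x.
  uint32_t lo = 0, hi = n_probes;
  while (lo < hi) {
    uint32_t mid = (lo + hi) / 2;
    if (IdxT(chunk_indices[mid]) < x) { lo = mid + 1; } else { hi = mid; }
  }
  if (lo == n_probes) { return false; }        // not enough candidates in the probed clusters
  if (lo > 0) { x -= chunk_indices[lo - 1]; }  // position within the found cluster
  x += cluster_offsets[cluster_labels[lo]];    // global row id in the (permuted) dataset
  return true;
}
```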
+ */ +template +void postprocess_distances(float* out, // [n_queries, topk] + const ScoreT* in, // [n_queries, topk] + distance::DistanceType metric, + uint32_t n_queries, + uint32_t topk, + rmm::cuda_stream_view stream) +{ + size_t len = size_t(n_queries) * size_t(topk); + switch (metric) { + case distance::DistanceType::L2Unexpanded: + case distance::DistanceType::L2Expanded: { + linalg::unaryOp( + out, in, len, [] __device__(ScoreT x) -> float { return float(x); }, stream); + } break; + case distance::DistanceType::L2SqrtUnexpanded: + case distance::DistanceType::L2SqrtExpanded: { + linalg::unaryOp( + out, in, len, [] __device__(ScoreT x) -> float { return sqrtf(float(x)); }, stream); + } break; + case distance::DistanceType::InnerProduct: { + linalg::unaryOp( + out, in, len, [] __device__(ScoreT x) -> float { return -float(x); }, stream); + } break; + default: RAFT_FAIL("Unexpected metric."); + } +} + +/** + * @brief Compute the similarity score between a vector from `pq_dataset` and a query vector. + * + * @tparam OpT an unsigned integer type that is used for bit operations on multiple PQ codes + * at once; it's selected to maximize throughput while matching criteria: + * 1. `pq_bits * vec_len % 8 * sizeof(OpT) == 0`. + * 2. `pq_dim % vec_len == 0` + * + * @tparam LutT type of the elements in the lookup table. + * + * @param pq_bits The bit length of an encoded vector element after compression by PQ + * @param vec_len == 8 * sizeof(OpT) / gcd(8 * sizeof(OpT), pq_bits) + * @param pq_dim + * @param[in] pq_code_ptr + * a device pointer to the dataset at the indexed position (`pq_dim * pq_bits` bits-wide) + * @param[in] lut_scores + * a device or shared memory pointer to the lookup table [pq_dim, pq_book_size] + * + * @return the score for the entry `data_ix` in the `pq_dataset`. + */ +template +__device__ auto ivfpq_compute_score( + uint32_t pq_bits, uint32_t vec_len, uint32_t pq_dim, const OpT* pq_head, const LutT* lut_scores) + -> float +{ + float score = 0.0; + constexpr uint32_t kBitsTotal = 8 * sizeof(OpT); + for (; pq_dim > 0; pq_dim -= vec_len) { + OpT pq_code = pq_head[0]; + pq_head++; + auto bits_left = kBitsTotal; + for (uint32_t k = 0; k < vec_len; k++) { + uint8_t code = pq_code; + if (bits_left > pq_bits) { + pq_code >>= pq_bits; + bits_left -= pq_bits; + } else { + if (k < vec_len - 1) { + pq_code = pq_head[0]; + pq_head++; + } + code |= (pq_code << bits_left); + pq_code >>= (pq_bits - bits_left); + bits_left += (kBitsTotal - pq_bits); + } + code &= (1 << pq_bits) - 1; + score += float(lut_scores[code]); + lut_scores += (1 << pq_bits); + } + } + return score; +} + +template +struct dummy_block_sort_t { + using queue_t = topk::warp_sort_immediate; + __device__ dummy_block_sort_t(int k, uint8_t* smem_buf){}; +}; + +template +struct pq_block_sort { + using type = topk::block_sort; +}; + +template +struct pq_block_sort<0, T, IdxT> : dummy_block_sort_t { + using type = dummy_block_sort_t; +}; + +template +using block_sort_t = typename pq_block_sort::type; + +/** + * The main kernel that computes similarity scores across multiple queries and probes. + * When `Capacity > 0`, it also selects top K candidates for each query and probe + * (which need to be merged across probes afterwards). + * + * Each block processes a (query, probe) pair: it calculates the distance between the single query + * vector and all the dataset vector in the cluster that we are probing. 
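Bit-packing aside, the per-vector score computed by `ivfpq_compute_score` is a gather-and-sum over the lookup table. A simplified host-side equivalent for the byte-aligned case `pq_bits == 8` (a sketch; the device version unpacks arbitrary `pq_bits` from wider `OpT` words):

```cpp
#include <cstdint>

inline float pq_score_reference_8bit(uint32_t pq_dim,
                                     const uint8_t* pq_codes,  // [pq_dim], one code per subspace
                                     const float* lut_scores)  // [pq_dim, 256], row-major
{
  float score = 0.0f;
  for (uint32_t d = 0; d < pq_dim; d++) {
    score += lut_scores[(d << 8) + pq_codes[d]];  // gather one LUT entry per subspace
  }
  return score;
}
```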
+ * + * @tparam OpT is a carrier integer type selected to maximize throughput; + * Used solely in `ivfpq_compute_score`; + * @tparam IdxT + * The type of data indices + * @tparam OutT + * The output type - distances. + * @tparam LutT + * The lookup table element type (lut_scores). + * @tparam Capacity + * Power-of-two; the maximum possible `k` in top-k. Value zero disables fused top-k search. + * @tparam PrecompBaseDiff + * Defines whether we should precompute part of the distance and keep it in shared memory + * before the main part (score calculation) to increase memory usage efficiency in the latter. + * For L2, this is the distance between the query and the cluster center. + * @tparam EnableSMemLut + * Defines whether to use the shared memory for the lookup table (`lut_scores`). + * Setting this to `false` allows to reduce the shared memory usage (and maximum data dim) + * at the cost of reducing global memory reading throughput. + * + * @param n_rows the number of records in the dataset + * @param dim the dimensionality of the data (NB: after rotation transform, i.e. `index.rot_dim()`). + * @param n_probes the number of clusters to search for each query + * @param pq_bits the bit length of an encoded vector element after compression by PQ + * (NB: pq_book_size = 1 << pq_bits). + * @param pq_dim + * The dimensionality of an encoded vector after compression by PQ. + * @param n_queries the number of queries. + * @param metric the distance type. + * @param codebook_kind Defines the way PQ codebooks have been trained. + * @param topk the `k` in the select top-k. + * @param cluster_centers + * The device pointer to the cluster centers in the original space (NB: after rotation) + * [n_clusters, dim]. + * @param pq_centers + * The device pointer to the cluster centers in the PQ space + * [pq_dim, pq_book_size, pq_len] or [n_clusters, pq_book_size, pq_len,]. + * @param pq_dataset + * The device pointer to the PQ index (data) [n_rows, pq_dim * pq_bits / 8]. + * @param cluster_offsets + * The device pointer to the cluster offsets [n_clusters + 1]. + * @param cluster_labels + * The device pointer to the labels (clusters) for each query and probe [n_queries, n_probes]. + * @param _chunk_indices + * The device pointer to the data offsets for each query and probe [n_queries, n_probes]. + * @param queries + * The device pointer to the queries (NB: after rotation) [n_queries, dim]. + * @param index_list + * An optional device pointer to the enforced order of search [n_queries, n_probes]. + * One can pass reordered indices here to try to improve data reading locality. + * @param lut_scores + * The device pointer for storing the lookup table globally [gridDim.x, pq_dim << pq_bits]. + * Ignored when `EnableSMemLut == true`. + * @param _out_scores + * The device pointer to the output scores + * [n_queries, max_samples] or [n_queries, n_probes, topk]. + * @param _out_indices + * The device pointer to the output indices [n_queries, n_probes, topk]. + * Ignored when `Capacity == 0`. 
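To give a sense of the shared-memory budget these switches control, a worked example with assumed values (pq_dim = 64, pq_bits = 8, LutT = half, rot_dim = 128; none of these numbers come from the patch):

```cpp
// Assumed example values, not taken from the patch.
constexpr size_t kLutBytes      = size_t(64) * 256 * 2;         // LUT in smem: 32 KiB of half
constexpr size_t kBaseDiffBytes = size_t(128) * sizeof(float);  // precomputed base_diff: 512 B
// EnableSMemLut reserves kLutBytes of dynamic shared memory; PrecompBaseDiff adds kBaseDiffBytes.
// If the sum does not fit the device limit, the selection logic below falls back first to the
// no-basediff variant and then to the no-smem-lut variant.
```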
+ */ +template +__launch_bounds__(1024) __global__ + void ivfpq_compute_similarity_kernel(uint32_t n_rows, + uint32_t dim, + uint32_t n_probes, + uint32_t pq_bits, + uint32_t pq_dim, + uint32_t n_queries, + distance::DistanceType metric, + codebook_gen codebook_kind, + uint32_t topk, + const float* cluster_centers, + const float* pq_centers, + const uint8_t* pq_dataset, + const IdxT* cluster_offsets, + const uint32_t* cluster_labels, + const uint32_t* _chunk_indices, + const float* queries, + const uint32_t* index_list, + LutT* lut_scores, + OutT* _out_scores, + IdxT* _out_indices) +{ + /* Shared memory: + + * lut_scores: lookup table (LUT) of size = `pq_dim << pq_bits` (when EnableSMemLut) + * base_diff: size = dim (which is equal to `pq_dim * pq_len`) + * topk::block_sort: some amount of shared memory, but overlaps with the rest: + block_sort only needs shared memory for `.done()` operation, which can come very last. + */ + extern __shared__ __align__(256) uint8_t smem_buf[]; // NOLINT + constexpr bool kManageLocalTopK = Capacity > 0; + constexpr uint32_t kOpBits = 8 * sizeof(OpT); + + const uint32_t pq_len = dim / pq_dim; + const uint32_t vec_len = kOpBits / gcd(kOpBits, pq_bits); + + if constexpr (EnableSMemLut) { + lut_scores = reinterpret_cast(smem_buf); + } else { + lut_scores += (pq_dim << pq_bits) * blockIdx.x; + } + + float* base_diff = nullptr; + if constexpr (PrecompBaseDiff) { + if constexpr (EnableSMemLut) { + base_diff = reinterpret_cast(lut_scores + (pq_dim << pq_bits)); + } else { + base_diff = reinterpret_cast(smem_buf); + } + } + + for (int ib = blockIdx.x; ib < n_queries * n_probes; ib += gridDim.x) { + uint32_t query_ix; + uint32_t probe_ix; + if (index_list == nullptr) { + query_ix = ib % n_queries; + probe_ix = ib / n_queries; + } else { + query_ix = index_list[ib] / n_probes; + probe_ix = index_list[ib] % n_probes; + } + if (query_ix >= n_queries || probe_ix >= n_probes) continue; + + const uint32_t* chunk_indices = _chunk_indices + (n_probes * query_ix); + const float* query = queries + (dim * query_ix); + OutT* out_scores; + IdxT* out_indices = nullptr; + if constexpr (kManageLocalTopK) { + // Store topk calculated distances to out_scores (and its indices to out_indices) + out_scores = _out_scores + topk * (probe_ix + (n_probes * query_ix)); + out_indices = _out_indices + topk * (probe_ix + (n_probes * query_ix)); + } else { + // Store all calculated distances to out_scores + auto max_samples = cluster_offsets[n_probes]; + out_scores = _out_scores + max_samples * query_ix; + } + uint32_t label = cluster_labels[n_probes * query_ix + probe_ix]; + const float* cluster_center = cluster_centers + (dim * label); + const float* pq_center; + if (codebook_kind == codebook_gen::PER_SUBSPACE) { + pq_center = pq_centers; + } else { + pq_center = pq_centers + (pq_len << pq_bits) * label; + } + + if constexpr (PrecompBaseDiff) { + // Reduce computational complexity by pre-computing the difference + // between the cluster centroid and the query. + for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { + base_diff[i] = query[i] - cluster_center[i]; + } + __syncthreads(); + } + + // Create a lookup table + // For each subspace, the lookup table stores the distance between the actual query vector + // (projected into the subspace) and all possible pq vectors in that subspace. + for (uint32_t i = threadIdx.x; i < (pq_dim << pq_bits); i += blockDim.x) { + uint32_t i_pq = i >> pq_bits; + uint32_t i_code = codebook_kind == codebook_gen::PER_CLUSTER ? 
i & ((1 << pq_bits) - 1) : i; + float score = 0.0; + switch (metric) { + case distance::DistanceType::L2Expanded: { + for (uint32_t j = 0; j < pq_len; j++) { + uint32_t k = j + (pq_len * i_pq); + float diff; + if constexpr (PrecompBaseDiff) { + diff = base_diff[k]; + } else { + diff = query[k] - cluster_center[k]; + } + diff -= pq_center[j + pq_len * i_code]; + score += diff * diff; + } + } break; + case distance::DistanceType::InnerProduct: { + for (uint32_t j = 0; j < pq_len; j++) { + uint32_t k = j + (pq_len * i_pq); + score += query[k] * (cluster_center[k] + pq_center[j + pq_len * i_code]); + } + } break; + } + lut_scores[i] = LutT(score); + } + + uint32_t sample_offset = 0; + if (probe_ix > 0) { sample_offset = chunk_indices[probe_ix - 1]; } + uint32_t n_samples = chunk_indices[probe_ix] - sample_offset; + uint32_t n_samples32 = Pow2<32>::roundUp(n_samples); + IdxT cluster_offset = cluster_offsets[label]; + + using local_topk_t = block_sort_t; + local_topk_t block_topk(topk, smem_buf); + + // Ensure lut_scores is written by all threads before using it in ivfpq_compute_score + __threadfence_block(); + __syncthreads(); + + // Compute a distance for each sample + const uint32_t pq_line_width = pq_dim * pq_bits / 8; + for (uint32_t i = threadIdx.x; i < n_samples32; i += blockDim.x) { + OutT score = local_topk_t::queue_t::kDummy; + if (i < n_samples) { + auto pq_ptr = + reinterpret_cast(pq_dataset + uint64_t(pq_line_width) * (cluster_offset + i)); + float fscore = ivfpq_compute_score(pq_bits, vec_len, pq_dim, pq_ptr, lut_scores); + switch (metric) { + // For similarity metrics, + // we negate the scores as we hardcoded select-topk to always take the minimum + case distance::DistanceType::InnerProduct: fscore = -fscore; break; + default: break; + } + if (fscore < float(score)) { score = OutT{fscore}; } + } + if constexpr (kManageLocalTopK) { + block_topk.add(score, cluster_offset + i); + } else { + if (i < n_samples) { out_scores[i + sample_offset] = score; } + } + } + __syncthreads(); + if constexpr (kManageLocalTopK) { + // sync threads before and after the topk merging operation, because we reuse smem_buf + block_topk.done(); + block_topk.store(out_scores, out_indices); + __syncthreads(); + } else { + // fill in the rest of the out_scores with dummy values + uint32_t max_samples = uint32_t(cluster_offsets[n_probes]); + if (probe_ix + 1 == n_probes) { + for (uint32_t i = threadIdx.x + sample_offset + n_samples; i < max_samples; + i += blockDim.x) { + out_scores[i] = local_topk_t::queue_t::kDummy; + } + } + } + } +} + +/** + * This structure selects configurable template parameters (instance) based on + * the search/index parameters at runtime. + * + * This is done by means of recusively iterating through a small set of possible + * values for every parameter. + */ +template +struct ivfpq_compute_similarity { + using kernel_t = void (*)(uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + distance::DistanceType, + codebook_gen, + uint32_t, + const float*, + const float*, + const uint8_t*, + const IdxT*, + const uint32_t*, + const uint32_t*, + const float*, + const uint32_t*, + LutT*, + OutT*, + IdxT*); + + template + struct configured { + public: + /** + * Select a proper kernel instance based on the runtime parameters. 
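The capacity dispatch below follows a common pattern: recurse over a small set of compile-time values until the runtime `k` fits, then return the matching instantiation. A minimal standalone illustration of the same pattern (toy names, not the actual kernels):

```cpp
#include <cstdint>
#include <cstdio>

template <int Capacity>
void example_instance(uint32_t k) { std::printf("Capacity = %d (k = %u)\n", Capacity, k); }

using example_fun_t = void (*)(uint32_t);

template <int Capacity>
example_fun_t select_capacity(uint32_t k)
{
  if constexpr (Capacity > 32) {
    // Shrink while the runtime k still fits into the smaller compile-time capacity.
    if (k * 2 <= Capacity) { return select_capacity<Capacity / 2>(k); }
  }
  return example_instance<Capacity>;
}

int main()
{
  select_capacity<128>(20)(20);  // resolves to the Capacity = 32 instance
}
```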
+ * + * @param pq_bits + * @param pq_dim + * @param k_max + */ + static auto kernel(uint32_t pq_bits, uint32_t pq_dim, uint32_t k_max) -> kernel_t + { + return kernel_base(pq_bits, pq_dim, k_max); + } + + private: + template + static auto kernel_try_capacity(uint32_t k_max) -> kernel_t + { + if constexpr (Capacity > 0) { + if (k_max == 0 || k_max > Capacity) { return kernel_try_capacity(k_max); } + } + if constexpr (Capacity > 32) { + if (k_max * 2 <= Capacity) { return kernel_try_capacity(k_max); } + } + return ivfpq_compute_similarity_kernel; + } + + static auto kernel_base(uint32_t pq_bits, uint32_t pq_dim, uint32_t k_max) -> kernel_t + { + switch (gcd(pq_bits * pq_dim, 64)) { + case 64: return kernel_try_capacity(k_max); + case 32: return kernel_try_capacity(k_max); + case 16: return kernel_try_capacity(k_max); + case 8: return kernel_try_capacity(k_max); + default: + RAFT_FAIL("`pq_bits * pq_dim` must be a multiple of 8 (pq_bits = %u, pq_dim = %u).", + pq_bits, + pq_dim); + } + } + }; + + struct selected { + kernel_t kernel; + uint32_t n_blocks; + uint32_t n_threads; + size_t smem_size; + size_t device_lut_size; + + template + void operator()(rmm::cuda_stream_view stream, Args&&... args) + { + kernel<<>>(std::forward(args)...); + } + }; + + /** + * Use heuristics to choose an optimal instance of the search kernel. + * It selects among a few kernel variants (with/out using shared mem for + * lookup tables / precomputed distances) and tries to choose the block size + * to maximize kernel occupancy. + * + * @param manage_local_topk + * whether use the fused calculate+select or just calculate the distances for each + * query and probed cluster. + * + */ + static inline auto select(bool manage_local_topk, + uint32_t pq_bits, + uint32_t pq_dim, + uint32_t rot_dim, + uint32_t preferred_thread_block_size, + uint32_t n_queries, + uint32_t n_probes, + uint32_t topk) -> selected + { + using conf_fast = configured; + using conf_no_basediff = configured; + using conf_no_smem_lut = configured; + + kernel_t kernel_fast = conf_fast::kernel(pq_bits, pq_dim, manage_local_topk ? topk : 0u); + kernel_t kernel_no_basediff = + conf_no_basediff::kernel(pq_bits, pq_dim, manage_local_topk ? topk : 0u); + kernel_t kernel_no_smem_lut = + conf_no_smem_lut::kernel(pq_bits, pq_dim, manage_local_topk ? topk : 0u); + + const size_t smem_threshold = 48 * 1024; + size_t smem_size = sizeof(LutT) * (pq_dim << pq_bits); + size_t smem_size_base_diff = sizeof(float) * rot_dim; + + uint32_t n_blocks = n_queries * n_probes; + uint32_t n_threads = 1024; + // preferred_thread_block_size == 0 means using auto thread block size calculation mode + if (preferred_thread_block_size == 0) { + const uint32_t thread_min = 256; + int cur_dev; + cudaDeviceProp dev_props; + RAFT_CUDA_TRY(cudaGetDevice(&cur_dev)); + RAFT_CUDA_TRY(cudaGetDeviceProperties(&dev_props, cur_dev)); + while (n_threads > thread_min) { + if (n_blocks < uint32_t(getMultiProcessorCount() * (1024 / (n_threads / 2)))) { break; } + if (dev_props.sharedMemPerMultiprocessor * 2 / 3 < smem_size * (1024 / (n_threads / 2))) { + break; + } + n_threads /= 2; + } + } else { + n_threads = preferred_thread_block_size; + } + size_t smem_size_local_topk = + manage_local_topk + ? 
topk::template calc_smem_size_for_block_wide(n_threads / WarpSize, topk) + : 0; + smem_size = max(smem_size, smem_size_local_topk); + + kernel_t kernel = kernel_no_basediff; + + bool kernel_no_basediff_available = true; + bool use_smem_lut = true; + if (smem_size > smem_threshold) { + cudaError_t cuda_status = cudaFuncSetAttribute( + kernel_no_basediff, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (cuda_status != cudaSuccess) { + RAFT_EXPECTS( + cuda_status == cudaGetLastError(), + "Tried to reset the expected cuda error code, but it didn't match the expectation"); + kernel_no_basediff_available = false; + + // Use "kernel_no_smem_lut" which just uses small amount of shared memory. + RAFT_LOG_DEBUG( + "Non-shared-mem look-up table kernel is selected, because it wouldn't fit shmem " + "required: " + "%zu bytes)", + smem_size); + kernel = kernel_no_smem_lut; + use_smem_lut = false; + n_threads = 1024; + smem_size_local_topk = + manage_local_topk + ? topk::template calc_smem_size_for_block_wide(n_threads / WarpSize, topk) + : 0; + smem_size = max(smem_size_base_diff, smem_size_local_topk); + n_blocks = getMultiProcessorCount(); + } + } + if (kernel_no_basediff_available) { + bool kernel_fast_available = true; + if (smem_size + smem_size_base_diff > smem_threshold) { + cudaError_t cuda_status = cudaFuncSetAttribute(kernel_fast, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size + smem_size_base_diff); + if (cuda_status != cudaSuccess) { + RAFT_EXPECTS( + cuda_status == cudaGetLastError(), + "Tried to reset the expected cuda error code, but it didn't match the expectation"); + kernel_fast_available = false; + RAFT_LOG_DEBUG( + "No-precomputed-basediff kernel is selected, because the basediff wouldn't fit (shmem " + "required: %zu bytes)", + smem_size + smem_size_base_diff); + } + } + if (kernel_fast_available) { + int kernel_no_basediff_n_blocks = 0; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &kernel_no_basediff_n_blocks, kernel_no_basediff, n_threads, smem_size)); + + int kernel_fast_n_blocks = 0; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &kernel_fast_n_blocks, kernel_fast, n_threads, smem_size + smem_size_base_diff)); + + // Use "kernel_fast" only if GPU occupancy does not drop + if (kernel_no_basediff_n_blocks == kernel_fast_n_blocks) { + kernel = kernel_fast; + smem_size += smem_size_base_diff; + } + } + } + + uint32_t device_lut_size = use_smem_lut ? 0u : n_blocks * (pq_dim << pq_bits); + return {kernel, n_blocks, n_threads, smem_size, device_lut_size}; + } +}; + +/** + * The "main part" of the search, which assumes that outer-level `search` has already: + * + * 1. computed the closest clusters to probe (`clusters_to_probe`); + * 2. transformed input queries into the rotated space (rot_dim); + * 3. split the query batch into smaller chunks, so that the device workspace + * is guaranteed to fit into GPU memory. 
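The variant selection above boils down to a standard CUDA recipe: opt a kernel in to a larger dynamic shared-memory carve-out, then keep it only if its occupancy is no worse than the cheaper variant's. A self-contained sketch of that recipe with dummy kernels (not the ones from this patch):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

__global__ void variant_cheap(float* out) { out[threadIdx.x] = 0.f; }
__global__ void variant_fast(float* out) { out[threadIdx.x] = 1.f; }

int main()
{
  const int n_threads = 512;
  const size_t smem = 40 * 1024, extra = 16 * 1024;
  // Opt the bigger variant in to more than 48 KiB of dynamic shared memory (may fail).
  bool fast_available = cudaFuncSetAttribute(variant_fast,
                                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                                             static_cast<int>(smem + extra)) == cudaSuccess;
  int blocks_cheap = 0, blocks_fast = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_cheap, variant_cheap, n_threads, smem);
  if (fast_available) {
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_fast, variant_fast, n_threads, smem + extra);
  }
  // Keep the fast variant only if asking for more shared memory does not cost occupancy.
  std::printf("use fast variant: %s\n", fast_available && blocks_fast >= blocks_cheap ? "yes" : "no");
}
```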
+ */ +template +void ivfpq_search_worker(const handle_t& handle, + const index& index, + uint32_t max_samples, + uint32_t n_probes, + uint32_t max_batch_size, + uint32_t topK, + uint32_t preferred_thread_block_size, + uint32_t n_queries, + const uint32_t* clusters_to_probe, // [n_queries, n_probes] + const float* query, // [n_queries, rot_dim] + IdxT* neighbors, // [n_queries, topK] + float* distances, // [n_queries, topK] + rmm::mr::device_memory_resource* mr) +{ + RAFT_EXPECTS(n_queries <= max_batch_size, + "number of queries (%u) must not be larger than the max batch size (%u)", + n_queries, + max_batch_size); + auto stream = handle.get_stream(); + + auto pq_centers = index.pq_centers().data_handle(); + auto pq_dataset = index.pq_dataset().data_handle(); + auto data_indices = index.indices().data_handle(); + auto cluster_centers = index.centers_rot().data_handle(); + auto cluster_offsets = index.list_offsets().data_handle(); + + bool manage_local_topk = + topK <= kMaxCapacity // depth is not too large + && n_probes >= 16 // not too few clusters looked up + && max_batch_size * n_probes >= 256 // overall amount of work is not too small + ; + auto topk_len = manage_local_topk ? n_probes * topK : max_samples; + if (manage_local_topk) { + RAFT_LOG_DEBUG("Fused version of the search kernel is selected (manage_local_topk == true)"); + } else { + RAFT_LOG_DEBUG( + "Non-fused version of the search kernel is selected (manage_local_topk == false)"); + } + + rmm::device_uvector index_list_sorted_buf(0, stream, mr); + uint32_t* index_list_sorted = nullptr; + rmm::device_uvector num_samples(max_batch_size, stream, mr); + rmm::device_uvector chunk_index(max_batch_size * n_probes, stream, mr); + // [maxBatchSize, max_samples] or [maxBatchSize, n_probes, topk] + rmm::device_uvector distances_buf(max_batch_size * topk_len, stream, mr); + rmm::device_uvector neighbors_buf(0, stream, mr); + IdxT* neighbors_ptr = nullptr; + if (manage_local_topk) { + neighbors_buf.resize(max_batch_size * topk_len, stream); + neighbors_ptr = neighbors_buf.data(); + } + + calc_chunk_indices::configure(n_probes, n_queries)( + cluster_offsets, clusters_to_probe, chunk_index.data(), num_samples.data(), stream); + + if (n_queries * n_probes > 256) { + // Sorting index by cluster number (label). + // The goal is to increase the L2 cache hit rate to read the vectors + // of a cluster by processing the cluster at the same time as much as + // possible.
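The reordering right below uses the usual two-call CUB idiom: the first `SortPairs` call with a null workspace only reports the required temporary storage size, the second performs the sort. A self-contained sketch of that idiom (toy key/value arrays; the patch sorts cluster labels paired with flat query-probe indices):

```cpp
#include <cstdint>
#include <cub/cub.cuh>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_buffer.hpp>

void sort_pairs_example(const uint32_t* keys_in, const uint32_t* values_in,
                        uint32_t* keys_out, uint32_t* values_out,
                        int n, rmm::cuda_stream_view stream)
{
  size_t workspace_bytes = 0;
  // First call: null workspace pointer -> only computes the required temporary storage size.
  cub::DeviceRadixSort::SortPairs(
    nullptr, workspace_bytes, keys_in, keys_out, values_in, values_out, n, 0, 32, stream);
  rmm::device_buffer workspace(workspace_bytes, stream);
  // Second call: performs the actual key-value radix sort on the given stream.
  cub::DeviceRadixSort::SortPairs(
    workspace.data(), workspace_bytes, keys_in, keys_out, values_in, values_out, n, 0, 32, stream);
}
```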
+ index_list_sorted_buf.resize(max_batch_size * n_probes, stream); + rmm::device_uvector index_list_buf(max_batch_size * n_probes, stream, mr); + rmm::device_uvector cluster_labels_out(max_batch_size * n_probes, stream, mr); + auto index_list = index_list_buf.data(); + index_list_sorted = index_list_sorted_buf.data(); + thrust::sequence(handle.get_thrust_policy(), + thrust::device_pointer_cast(index_list), + thrust::device_pointer_cast(index_list + n_queries * n_probes)); + + int begin_bit = 0; + int end_bit = sizeof(uint32_t) * 8; + size_t cub_workspace_size = 0; + cub::DeviceRadixSort::SortPairs(nullptr, + cub_workspace_size, + clusters_to_probe, + cluster_labels_out.data(), + index_list, + index_list_sorted, + n_queries * n_probes, + begin_bit, + end_bit, + stream); + rmm::device_buffer cub_workspace(cub_workspace_size, stream, mr); + cub::DeviceRadixSort::SortPairs(cub_workspace.data(), + cub_workspace_size, + clusters_to_probe, + cluster_labels_out.data(), + index_list, + index_list_sorted, + n_queries * n_probes, + begin_bit, + end_bit, + stream); + } + + // select and run the main search kernel + auto search_instance = + ivfpq_compute_similarity::select(manage_local_topk, + index.pq_bits(), + index.pq_dim(), + index.rot_dim(), + preferred_thread_block_size, + n_queries, + n_probes, + topK); + + rmm::device_uvector device_lut(search_instance.device_lut_size, stream, mr); + search_instance(stream, + index.size(), + index.rot_dim(), + n_probes, + index.pq_bits(), + index.pq_dim(), + n_queries, + index.metric(), + index.codebook_kind(), + topK, + cluster_centers, + pq_centers, + pq_dataset, + cluster_offsets, + clusters_to_probe, + chunk_index.data(), + query, + index_list_sorted, + device_lut.data(), + distances_buf.data(), + neighbors_ptr); + + // Select topk vectors for each query + rmm::device_uvector topk_dists(n_queries * topK, stream, mr); + select_topk(distances_buf.data(), + neighbors_ptr, + n_queries, + topk_len, + topK, + topk_dists.data(), + neighbors, + true, + stream, + mr); + + // Postprocessing + postprocess_distances(distances, topk_dists.data(), index.metric(), n_queries, topK, stream); + postprocess_neighbors(neighbors, + manage_local_topk, + data_indices, + cluster_offsets, + clusters_to_probe, + chunk_index.data(), + n_queries, + n_probes, + topK, + stream); +} + +/** + * This structure helps selecting a proper instance of the worker search function, + * which contains a few template parameters. + */ +template +struct ivfpq_search { + public: + using fun_t = void (*)(const handle_t&, + const ivf_pq::index&, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + const uint32_t*, + const float*, + IdxT*, + float*, + rmm::mr::device_memory_resource*); + + /** + * Select an instance of the ivf-pq search function based on search tuning parameters, + * such as the look-up data type or the internal score type. 
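In caller terms, the two dtype knobs select which `ivfpq_search_worker` instantiation ends up running. A hedged usage fragment (field names follow `ivf_pq::search_params` as used in this file; the exact `fp_8bit` template arguments are whatever this header defines):

```cpp
// Usage sketch (assumed surrounding code; not part of the patch).
raft::spatial::knn::ivf_pq::search_params sp;
sp.n_probes                = 50;
sp.internal_distance_dtype = CUDA_R_16F;  // ScoreT = half: accumulate distances in fp16
sp.lut_dtype               = CUDA_R_8U;   // LutT = fp_8bit<...>: 8-bit lookup-table entries
// ivfpq_search<IdxT>::fun(sp, metric) then resolves to the matching
// ivfpq_search_worker<half, fp_8bit<...>, IdxT> (the signed fp_8bit variant for InnerProduct).
```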
+ */ + static auto fun(const search_params& params, distance::DistanceType metric) -> fun_t + { + return fun_try_score_t(params, metric); + } + + private: + template + static auto fun_try_lut_t(const search_params& params, distance::DistanceType metric) -> fun_t + { + bool signed_metric = false; + switch (metric) { + case raft::distance::DistanceType::InnerProduct: signed_metric = true; break; + default: break; + } + + switch (params.lut_dtype) { + case CUDA_R_32F: return ivfpq_search_worker; + case CUDA_R_16F: return ivfpq_search_worker; + case CUDA_R_8U: + case CUDA_R_8I: + if (signed_metric) { + return ivfpq_search_worker, IdxT>; + } else { + return ivfpq_search_worker, IdxT>; + } + default: RAFT_FAIL("Unexpected lut_dtype (%d)", int(params.lut_dtype)); + } + } + + static auto fun_try_score_t(const search_params& params, distance::DistanceType metric) -> fun_t + { + switch (params.internal_distance_dtype) { + case CUDA_R_32F: return fun_try_lut_t(params, metric); + case CUDA_R_16F: return fun_try_lut_t(params, metric); + default: + RAFT_FAIL("Unexpected internal_distance_dtype (%d)", int(params.internal_distance_dtype)); + } + } +}; + +/** + * A heuristic for bounding the number of queries per batch, to improve GPU utilization. + * (based on the number of SMs and the work size). + * + * @param n_queries number of queries hoped to be processed at once. + * (maximum value for the returned batch size) + * + * @return maximum recommended batch size. + */ +inline auto get_max_batch_size(uint32_t n_queries) -> uint32_t +{ + uint32_t max_batch_size = n_queries; + uint32_t n_ctas_total = getMultiProcessorCount() * 2; + uint32_t n_ctas_total_per_batch = n_ctas_total / max_batch_size; + float utilization = float(n_ctas_total_per_batch * max_batch_size) / n_ctas_total; + if (n_ctas_total_per_batch > 1 || (n_ctas_total_per_batch == 1 && utilization < 0.6)) { + uint32_t n_ctas_total_per_batch_1 = n_ctas_total_per_batch + 1; + uint32_t max_batch_size_1 = n_ctas_total / n_ctas_total_per_batch_1; + float utilization_1 = float(n_ctas_total_per_batch_1 * max_batch_size_1) / n_ctas_total; + if (utilization < utilization_1) { max_batch_size = max_batch_size_1; } + } + return max_batch_size; +} + +/** See raft::spatial::knn::ivf_pq::search docs */ +template +inline void search(const handle_t& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) +{ + static_assert(std::is_same_v || std::is_same_v || std::is_same_v, + "Unsupported element type."); + common::nvtx::range fun_scope( + "ivf_pq::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); + + RAFT_EXPECTS( + params.internal_distance_dtype == CUDA_R_16F || params.internal_distance_dtype == CUDA_R_32F, + "internal_distance_dtype must be either CUDA_R_16F or CUDA_R_32F"); + RAFT_EXPECTS(params.lut_dtype == CUDA_R_16F || params.lut_dtype == CUDA_R_32F || + params.lut_dtype == CUDA_R_8U, + "lut_dtype must be CUDA_R_16F, CUDA_R_32F or CUDA_R_8U"); + RAFT_EXPECTS( + params.preferred_thread_block_size == 256 || params.preferred_thread_block_size == 512 || + params.preferred_thread_block_size == 1024 || params.preferred_thread_block_size == 0, + "preferred_thread_block_size must be 0, 256, 512 or 1024, but %u is given.", + params.preferred_thread_block_size); + RAFT_EXPECTS(k > 0, "parameter `k` in top-k must be positive."); + RAFT_EXPECTS( + k <= index.size(), + "parameter `k` (%u) in top-k must 
not be larger that the total size of the index (%zu)", + k, + static_cast(index.size())); + RAFT_EXPECTS(params.n_probes > 0, + "n_probes (number of clusters to probe in the search) must be positive."); + + switch (utils::check_pointer_residency(queries, neighbors, distances)) { + case utils::pointer_residency::device_only: + case utils::pointer_residency::host_and_device: break; + default: RAFT_FAIL("all pointers must be accessible from the device."); + } + + auto stream = handle.get_stream(); + + auto dim = index.dim(); + auto dim_ext = index.dim_ext(); + auto n_probes = std::min(params.n_probes, index.n_lists()); + + IdxT max_samples = 0; + { + IdxT offset_worst_case = 0; + auto cluster_offsets = index.list_offsets().data_handle(); + copy(&max_samples, cluster_offsets + n_probes, 1, stream); + if (n_probes < index.n_nonempty_lists()) { + copy(&offset_worst_case, cluster_offsets + index.n_nonempty_lists() - n_probes, 1, stream); + } + handle.sync_stream(); + max_samples = Pow2<128>::roundUp(max_samples); + IdxT min_samples = index.size() - offset_worst_case; + if (IdxT{k} > min_samples) { + RAFT_LOG_WARN( + "n_probes is too small to get top-k results reliably (n_probes: %u, k: %u, n_samples " + "(worst_case): %zu).", + n_probes, + k, + static_cast(min_samples)); + } + RAFT_EXPECTS(max_samples <= IdxT(std::numeric_limits::max()), + "The maximum sample size is too big."); + } + + auto pool_guard = raft::get_pool_memory_resource(mr, n_queries * n_probes * k * 16); + if (pool_guard) { + RAFT_LOG_DEBUG("ivf_pq::search: using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); + } + + // Maximum number of query vectors to search at the same time. + const auto max_queries = std::min(std::max(n_queries, 1), 4096); + auto max_batch_size = get_max_batch_size(max_queries); + + rmm::device_uvector float_queries(max_queries * dim_ext, stream, mr); + rmm::device_uvector rot_queries(max_queries * index.rot_dim(), stream, mr); + rmm::device_uvector clusters_to_probe(max_queries * params.n_probes, stream, mr); + + auto search_instance = ivfpq_search::fun(params, index.metric()); + + for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { + uint32_t queries_batch = min(max_queries, n_queries - offset_q); + + select_clusters(handle, + clusters_to_probe.data(), + float_queries.data(), + n_queries, + params.n_probes, + index.n_lists(), + dim, + dim_ext, + index.metric(), + queries + static_cast(dim) * offset_q, + index.centers().data_handle(), + mr); + + // Rotate queries + float alpha = 1.0; + float beta = 0.0; + linalg::gemm(handle, + true, + false, + index.rot_dim(), + queries_batch, + dim, + &alpha, + index.rotation_matrix().data_handle(), + dim, + float_queries.data(), + dim_ext, + &beta, + rot_queries.data(), + index.rot_dim(), + stream); + + for (uint32_t offset_b = 0; offset_b < queries_batch; offset_b += max_batch_size) { + uint32_t batch_size = min(max_batch_size, queries_batch - offset_b); + /* The distance calculation is done in the rotated/transformed space; + as long as `index.rotation_matrix()` is orthogonal, the distances and thus results are + preserved. 
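Putting the batching above into numbers, a worked trace with assumed inputs (not from the patch):

```cpp
// Assumed: n_queries = 10000, k = 32.
// max_queries    = min(max(10000, 1), 4096) = 4096 queries per preprocessing pass
//                  (cluster selection and query rotation are done for the whole pass at once);
// max_batch_size = get_max_batch_size(4096) -> at most 4096, possibly lowered when the batch
//                  would map to too few thread blocks to keep all SMs busy;
// the inner loop then feeds the search kernel chunks of batch_size <= max_batch_size queries,
// so the workspace buffers sized with max_batch_size always fit.
```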
+ */ + search_instance(handle, + index, + max_samples, + params.n_probes, + max_batch_size, + k, + params.preferred_thread_block_size, + batch_size, + clusters_to_probe.data() + uint64_t(params.n_probes) * offset_b, + rot_queries.data() + uint64_t(index.rot_dim()) * offset_b, + neighbors + uint64_t(k) * (offset_q + offset_b), + distances + uint64_t(k) * (offset_q + offset_b), + mr); + } + } +} + +} // namespace raft::spatial::knn::ivf_pq::detail diff --git a/cpp/include/raft/spatial/knn/detail/topk.cuh b/cpp/include/raft/spatial/knn/detail/topk.cuh new file mode 100644 index 0000000000..5adf6df472 --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/topk.cuh @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "topk/radix_topk.cuh" +#include "topk/warpsort_topk.cuh" + +#include +#include + +namespace raft::spatial::knn::detail { + +/** + * Select k smallest or largest key/values from each row in the input data. + * + * If you think of the input data `in_keys` as a row-major matrix with len columns and + * batch_size rows, then this function selects k smallest/largest values in each row and fills + * in the row-major matrix `out` of size (batch_size, k). + * + * @tparam T + * the type of the keys (what is being compared). + * @tparam IdxT + * the index type (what is being selected together with the keys). + * + * @param[in] in + * contiguous device array of inputs of size (len * batch_size); + * these are compared and selected. + * @param[in] in_idx + * contiguous device array of inputs of size (len * batch_size); + * typically, these are indices of the corresponding in_keys. + * @param batch_size + * number of input rows, i.e. the batch size. + * @param len + * length of a single input array (row); also sometimes referred as n_cols. + * Invariant: len >= k. + * @param k + * the number of outputs to select in each input row. + * @param[out] out + * contiguous device array of outputs of size (k * batch_size); + * the k smallest/largest values from each row of the `in_keys`. + * @param[out] out_idx + * contiguous device array of outputs of size (k * batch_size); + * the payload selected together with `out`. + * @param select_min + * whether to select k smallest (true) or largest (false) keys. + * @param stream + * @param mr an optional memory resource to use across the calls (you can provide a large enough + * memory pool here to avoid memory allocations within the call). + */ +template +void select_topk(const T* in, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out, + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) +{ + if (k <= raft::spatial::knn::detail::topk::kMaxCapacity) { + topk::warp_sort_topk( + in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); + } else { + topk::radix_topk= 4 ? 
11 : 8), 512>( + in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); + } +} + +} // namespace raft::spatial::knn::detail diff --git a/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh b/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh index 4cbad8e906..9c0f20b706 100644 --- a/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh +++ b/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh @@ -19,6 +19,7 @@ #include #include #include +#include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh b/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh index dfbe8a735d..84cc072620 100644 --- a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh +++ b/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh @@ -506,8 +506,8 @@ template