From 66b8493cd3b228838ea8d50474e824f718286c1a Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Thu, 6 Apr 2023 10:13:57 +0200 Subject: [PATCH] CAGRA (#1375) This PR adds CAGRA, a graph based method for nearest neighbor search. Authors: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) - Ben Frederickson (https://github.com/benfred) Approvers: - Ben Frederickson (https://github.com/benfred) - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/raft/pull/1375 --- cpp/include/raft/neighbors/cagra.cuh | 226 ++++ .../raft/neighbors/cagra_serialize.cuh | 154 +++ cpp/include/raft/neighbors/cagra_types.hpp | 199 +++ .../raft/neighbors/detail/cagra/bitonic.hpp | 226 ++++ .../neighbors/detail/cagra/cagra_build.cuh | 236 ++++ .../neighbors/detail/cagra/cagra_search.cuh | 100 ++ .../detail/cagra/cagra_serialize.cuh | 123 ++ .../detail/cagra/compute_distance.hpp | 253 ++++ .../neighbors/detail/cagra/device_common.hpp | 76 ++ .../raft/neighbors/detail/cagra/factory.cuh | 90 ++ .../raft/neighbors/detail/cagra/fragment.hpp | 212 +++ .../neighbors/detail/cagra/graph_core.cuh | 809 ++++++++++++ .../raft/neighbors/detail/cagra/hashmap.hpp | 86 ++ .../detail/cagra/search_multi_cta.cuh | 632 +++++++++ .../detail/cagra/search_multi_kernel.cuh | 721 ++++++++++ .../neighbors/detail/cagra/search_plan.cuh | 334 +++++ .../detail/cagra/search_single_cta.cuh | 1157 +++++++++++++++++ .../detail/cagra/topk_for_cagra/topk.h | 57 + .../detail/cagra/topk_for_cagra/topk_core.cuh | 926 +++++++++++++ .../raft/neighbors/detail/cagra/utils.hpp | 143 ++ cpp/include/raft/util/cache_util.cuh | 4 +- cpp/test/CMakeLists.txt | 1 + cpp/test/neighbors/ann_cagra.cuh | 313 +++++ .../ann_cagra/test_float_uint32_t.cu | 32 + cpp/test/neighbors/ann_utils.cuh | 49 + 25 files changed, 7157 insertions(+), 2 deletions(-) create mode 100644 cpp/include/raft/neighbors/cagra.cuh create mode 100644 cpp/include/raft/neighbors/cagra_serialize.cuh create 
mode 100644 cpp/include/raft/neighbors/cagra_types.hpp create mode 100644 cpp/include/raft/neighbors/detail/cagra/bitonic.hpp create mode 100644 cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh create mode 100644 cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh create mode 100644 cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh create mode 100644 cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp create mode 100644 cpp/include/raft/neighbors/detail/cagra/device_common.hpp create mode 100644 cpp/include/raft/neighbors/detail/cagra/factory.cuh create mode 100644 cpp/include/raft/neighbors/detail/cagra/fragment.hpp create mode 100644 cpp/include/raft/neighbors/detail/cagra/graph_core.cuh create mode 100644 cpp/include/raft/neighbors/detail/cagra/hashmap.hpp create mode 100644 cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh create mode 100644 cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh create mode 100644 cpp/include/raft/neighbors/detail/cagra/search_plan.cuh create mode 100644 cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh create mode 100644 cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h create mode 100644 cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh create mode 100644 cpp/include/raft/neighbors/detail/cagra/utils.hpp create mode 100644 cpp/test/neighbors/ann_cagra.cuh create mode 100644 cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh new file mode 100644 index 0000000000..90728efd70 --- /dev/null +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -0,0 +1,226 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "detail/cagra/cagra_build.cuh" +#include "detail/cagra/cagra_search.cuh" +#include "detail/cagra/graph_core.cuh" + +#include +#include +#include +#include +#include +#include + +namespace raft::neighbors::experimental::cagra { + +/** + * @defgroup cagra CUDA ANN Graph-based nearest neighbor search + * @{ + */ + +/** + * @brief Build a kNN graph. + * + * The kNN graph is the first building block for the CAGRA index. + * This function uses the IVF-PQ method to build a kNN graph. + * + * The output is a dense matrix that stores the neighbor indices for each point in the dataset. + * Each point has the same number of neighbors. + * + * See [cagra::build](#cagra::build) for an alternative method. 
+ * + * The following distance metrics are supported: + * - L2Expanded + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * ivf_pq::index_params build_params; + * ivf_pq::search_params search_params; + * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); + * // create knn graph + * cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params); + * auto pruned_graph = raft::make_host_matrix(dataset.extent(0), 64); + * cagra::prune(res, dataset, knn_graph.view(), pruned_graph.view()); + * // Construct an index from dataset and pruned knn_graph + * auto index = cagra::index(res, build_params.metric, dataset, pruned_graph.view()); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res raft resources + * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] + * @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree] + * @param[in] refine_rate refinement rate for ivf-pq search + * @param[in] build_params (optional) ivf_pq index building parameters for knn graph + * @param[in] search_params (optional) ivf_pq search parameters + */ +template +void build_knn_graph(raft::device_resources const& res, + mdspan, row_major, accessor> dataset, + raft::host_matrix_view knn_graph, + std::optional refine_rate = std::nullopt, + std::optional build_params = std::nullopt, + std::optional search_params = std::nullopt) +{ + detail::build_knn_graph(res, dataset, knn_graph, refine_rate, build_params, search_params); +} + +/** + * @brief Prune a KNN graph. + * + * Decrease the number of neighbors for each node. 
+ * + * See [cagra::build_knn_graph](#cagra::build_knn_graph) for usage example + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res raft resources + * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] + * @param[in] knn_graph a matrix view (host or device) of the input knn graph [n_rows, + * knn_graph_degree] + * @param[out] new_graph a host matrix view of the pruned knn graph [n_rows, graph_degree] + */ +template , memory_type::device>, + typename g_accessor = + host_device_accessor, memory_type::host>> +void prune(raft::device_resources const& res, + mdspan, row_major, d_accessor> dataset, + mdspan, row_major, g_accessor> knn_graph, + raft::host_matrix_view new_graph) +{ + detail::graph::prune(res, dataset, knn_graph, new_graph); +} + +/** + * @brief Build the index from the dataset for efficient search. + * + * The build consists of two steps: build an intermediate knn-graph, and prune it to + * create the final graph. The index_params struct controls the node degree of these + * graphs. + * + * It is required that the dataset and the pruned graph fit in GPU memory. + * + * To customize the parameters for knn-graph building and pruning, and to reuse the + * intermediate results, you could build the index in two steps using + * [cagra::build_knn_graph](#cagra::build_knn_graph) and [cagra::prune](#cagra::prune). 
+ * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * cagra::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = cagra::build(res, index_params, dataset); + * // use default search parameters + * cagra::search_params search_params; + * // search K nearest neighbours + * auto neighbors = raft::make_device_matrix(res, n_queries, k); + * auto distances = raft::make_device_matrix(res, n_queries, k); + * cagra::search(res, search_params, index, queries, neighbors.view(), distances.view()); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res raft resources + * @param[in] params parameters for building the index + * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] + * + * @return the constructed cagra index + */ +template , memory_type::host>> +index build(raft::device_resources const& res, + const index_params& params, + mdspan, row_major, Accessor> dataset) +{ + size_t degree = params.intermediate_graph_degree; + if (degree >= dataset.extent(0)) { + RAFT_LOG_WARN( + "Intermediate graph degree cannot be larger than dataset size, reducing it to %lu", + dataset.extent(0)); + degree = dataset.extent(0) - 1; + } + RAFT_EXPECTS(degree >= params.graph_degree, + "Intermediate graph degree cannot be smaller than final graph degree"); + + auto knn_graph = raft::make_host_matrix(dataset.extent(0), degree); + + build_knn_graph(res, dataset, knn_graph.view()); + + auto cagra_graph = raft::make_host_matrix(dataset.extent(0), params.graph_degree); + + prune(res, dataset, knn_graph.view(), cagra_graph.view()); + + // Construct an index from dataset and pruned knn graph. + return index(res, params.metric, dataset, cagra_graph.view()); +} + +/** + * @brief Search ANN using the constructed index. 
+ * + * See the [cagra::build](#cagra::build) documentation for a usage example. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] params configure the search + * @param[in] idx cagra index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search(raft::device_resources const& res, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must equal k"); + + RAFT_EXPECTS(queries.extent(1) == idx.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + detail::search_main(res, params, idx, queries, neighbors, distances); +} +/** @} */ // end group cagra + +} // namespace raft::neighbors::experimental::cagra diff --git a/cpp/include/raft/neighbors/cagra_serialize.cuh b/cpp/include/raft/neighbors/cagra_serialize.cuh new file mode 100644 index 0000000000..befd5e9c07 --- /dev/null +++ b/cpp/include/raft/neighbors/cagra_serialize.cuh @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "detail/cagra/cagra_serialize.cuh" + +namespace raft::neighbors::experimental::cagra { + +/** + * \defgroup cagra_serialize CAGRA Serialize + * @{ + */ + +/** + * Write the index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = cagra::build(...);` + * raft::serialize(handle, os, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index CAGRA index + * + */ +template +void serialize(raft::device_resources const& handle, std::ostream& os, const index& index) +{ + detail::serialize(handle, os, index); +} + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = cagra::build(...);` + * raft::serialize(handle, filename, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index CAGRA index + * + */ +template +void serialize(raft::device_resources const& handle, + const std::string& filename, + const index& index) +{ + detail::serialize(handle, filename, index); +} + +/** + * Load index from input stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * using T = float; // data element type + * using IdxT = uint32_t; // type of the indices + * auto index = raft::deserialize(handle, is); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] is input stream + * + * @return raft::neighbors::experimental::cagra::index + */ +template +index deserialize(raft::device_resources const& handle, std::istream& is) +{ + return detail::deserialize(handle, is); +} + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using T = float; // data element type + * using IdxT = uint32_t; // type of the indices + * auto index = raft::deserialize(handle, filename); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * + * @return raft::neighbors::experimental::cagra::index + */ +template +index deserialize(raft::device_resources const& handle, const std::string& filename) +{ + return detail::deserialize(handle, filename); +} + +/**@}*/ + +} // namespace raft::neighbors::experimental::cagra diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp new file mode 100644 index 0000000000..bd9b3b586b --- /dev/null +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -0,0 +1,199 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ann_types.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace raft::neighbors::experimental::cagra { +/** + * @ingroup cagra + * @{ + */ + +struct index_params : ann::index_params { + size_t intermediate_graph_degree = 128; // Degree of input graph for pruning. 
+ size_t graph_degree = 64; // Degree of output graph. +}; + +enum class search_algo { + SINGLE_CTA, // for large batch + MULTI_CTA, // for small batch + MULTI_KERNEL, + AUTO +}; + +enum class hash_mode { HASH, SMALL, AUTO }; + +struct search_params : ann::search_params { + /** Maximum number of queries to search at the same time (batch size). */ + size_t max_queries = 1; + + /** Number of intermediate search results retained during the search. + * + * This is the main knob to adjust the trade-off between accuracy and search speed. + * Higher values improve the search accuracy. + */ + size_t itopk_size = 64; + + /** Upper limit of search iterations. Auto select when 0.*/ + size_t max_iterations = 0; + + // In the following we list additional search parameters for fine tuning. + // Reasonable default values are automatically chosen. + + /** Which search implementation to use. */ + search_algo algo = search_algo::AUTO; + + /** Number of threads used to calculate a single distance. 4, 8, 16, or 32. */ + size_t team_size = 0; + + /** Number of graph nodes to select as the starting point for the search in each iteration. aka + * search width?*/ + size_t num_parents = 1; + /** Lower limit of search iterations. */ + size_t min_iterations = 0; + + /** Bit length for reading the dataset vectors. 0, 64 or 128. Auto selection when 0. */ + size_t load_bit_length = 0; + /** Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0. */ + size_t thread_block_size = 0; + /** Hashmap type. Auto selection when AUTO. */ + hash_mode hashmap_mode = hash_mode::AUTO; + /** Lower limit of hashmap bit length. More than 8. */ + size_t hashmap_min_bitlen = 0; + /** Upper limit of hashmap fill rate. More than 0.1, less than 0.9.*/ + float hashmap_max_fill_rate = 0.5; + + /** Number of iterations of initial random seed node selection. 1 or more. */ + uint32_t num_random_samplings = 1; + /** Bit mask used for initial random seed node selection. 
*/ + uint64_t rand_xor_mask = 0x128394; +}; + +static_assert(std::is_aggregate_v); +static_assert(std::is_aggregate_v); + +/** + * @brief CAGRA index. + * + * The index stores the dataset and a kNN graph in device memory. + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + */ +template +struct index : ann::index { + static_assert(!raft::is_narrowing_v, + "IdxT must be able to represent all values of uint32_t"); + + public: + /** Distance metric associated with the index. */ + [[nodiscard]] constexpr inline auto metric() const noexcept -> raft::distance::DistanceType + { + return metric_; + } + + /// Total length of the index (number of dataset rows). + [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return dataset_.extent(0); } + + /** Dimensionality of the data. */ + [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t + { + return dataset_.extent(1); + } + /** Graph degree */ + [[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t + { + return graph_.extent(1); + } + + /** Dataset [size, dim] */ + [[nodiscard]] inline auto dataset() const noexcept -> device_matrix_view + { + return dataset_.view(); + } + + /** neighborhood graph [size, graph-degree] */ + inline auto graph() noexcept -> device_matrix_view + { + return graph_.view(); + } + + [[nodiscard]] inline auto graph() const noexcept + -> device_matrix_view + { + return graph_.view(); + } + + // Don't allow copying the index for performance reasons (try avoiding copying data) + index(const index&) = delete; + index(index&&) = default; + auto operator=(const index&) -> index& = delete; + auto operator=(index&&) -> index& = default; + ~index() = default; + + /** Construct an empty index. 
*/ + index(raft::device_resources const& res) + : ann::index(), + metric_(raft::distance::DistanceType::L2Expanded), + dataset_(make_device_matrix(res, 0, 0)), + graph_(make_device_matrix(res, 0, 0)) + { + } + + /** Construct an index from dataset and knn_graph arrays */ + template + index(raft::device_resources const& res, + raft::distance::DistanceType metric, + mdspan, row_major, data_accessor> dataset, + mdspan, row_major, graph_accessor> knn_graph) + : ann::index(), + metric_(metric), + dataset_(make_device_matrix(res, dataset.extent(0), dataset.extent(1))), + graph_(make_device_matrix(res, knn_graph.extent(0), knn_graph.extent(1))) + { + RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), + "Dataset and knn_graph must have equal number of rows"); + raft::copy(dataset_.data_handle(), dataset.data_handle(), dataset.size(), res.get_stream()); + raft::copy(graph_.data_handle(), knn_graph.data_handle(), knn_graph.size(), res.get_stream()); + res.sync_stream(); + } + + private: + raft::distance::DistanceType metric_; + raft::device_matrix dataset_; + raft::device_matrix graph_; +}; + +/** @} */ + +} // namespace raft::neighbors::experimental::cagra diff --git a/cpp/include/raft/neighbors/detail/cagra/bitonic.hpp b/cpp/include/raft/neighbors/detail/cagra/bitonic.hpp new file mode 100644 index 0000000000..45aff99421 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/bitonic.hpp @@ -0,0 +1,226 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace raft::neighbors::experimental::cagra::detail { +namespace bitonic { + +namespace detail { + +template +_RAFT_DEVICE inline void swap_if_needed(K& k0, V& v0, K& k1, V& v1, const bool asc) +{ + if ((k0 != k1) && ((k0 < k1) != asc)) { + const auto tmp_k = k0; + k0 = k1; + k1 = tmp_k; + const auto tmp_v = v0; + v0 = v1; + v1 = tmp_v; + } +} + +template +_RAFT_DEVICE inline void swap_if_needed(K& k0, V& v0, const unsigned lane_offset, const bool asc) +{ + auto k1 = __shfl_xor_sync(~0u, k0, lane_offset); + auto v1 = __shfl_xor_sync(~0u, v0, lane_offset); + if ((k0 != k1) && ((k0 < k1) != asc)) { + k0 = k1; + v0 = v1; + } +} + +template +struct warp_merge_core { + _RAFT_DEVICE inline void operator()(K k[N], V v[N], const std::uint32_t range, const bool asc) + { + const auto lane_id = threadIdx.x % warp_size; + + if (range == 1) { + for (std::uint32_t b = 2; b <= N; b <<= 1) { + for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { +#pragma unroll + for (std::uint32_t i = 0; i < N; i++) { + std::uint32_t j = i ^ c; + if (i >= j) continue; + const auto line_id = i + (N * lane_id); + const auto p = static_cast(line_id & b) == static_cast(line_id & c); + swap_if_needed(k[i], v[i], k[j], v[j], p); + } + } + } + return; + } + + const std::uint32_t b = range; + for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { + const auto p = static_cast(lane_id & b) == static_cast(lane_id & c); +#pragma unroll + for (std::uint32_t i = 0; i < N; i++) { + swap_if_needed(k[i], v[i], c, p); + } + } + const auto p = ((lane_id & b) == 0); + for (std::uint32_t c = N / 2; c >= 1; c >>= 1) { +#pragma unroll + for (std::uint32_t i = 0; i < N; i++) { + std::uint32_t j = i ^ c; + if (i >= j) continue; + swap_if_needed(k[i], v[i], k[j], v[j], p); + } + } + } +}; + +template +struct warp_merge_core { + _RAFT_DEVICE inline void operator()(K k[6], V v[6], 
const std::uint32_t range, const bool asc) + { + constexpr unsigned N = 6; + const auto lane_id = threadIdx.x % warp_size; + + if (range == 1) { + for (std::uint32_t i = 0; i < N; i += 3) { + const auto p = (i == 0); + swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); + swap_if_needed(k[1 + i], v[1 + i], k[2 + i], v[2 + i], p); + swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); + } + const auto p = ((lane_id & 1) == 0); + for (std::uint32_t i = 0; i < 3; i++) { + std::uint32_t j = i + 3; + swap_if_needed(k[i], v[i], k[j], v[j], p); + } + for (std::uint32_t i = 0; i < N; i += 3) { + swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); + swap_if_needed(k[1 + i], v[1 + i], k[2 + i], v[2 + i], p); + swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); + } + return; + } + + const std::uint32_t b = range; + for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { + const auto p = static_cast(lane_id & b) == static_cast(lane_id & c); +#pragma unroll + for (std::uint32_t i = 0; i < N; i++) { + swap_if_needed(k[i], v[i], c, p); + } + } + const auto p = ((lane_id & b) == 0); + for (std::uint32_t i = 0; i < 3; i++) { + std::uint32_t j = i + 3; + swap_if_needed(k[i], v[i], k[j], v[j], p); + } + for (std::uint32_t i = 0; i < N; i += N / 2) { + swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); + swap_if_needed(k[1 + i], v[1 + i], k[2 + i], v[2 + i], p); + swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); + } + } +}; + +template +struct warp_merge_core { + _RAFT_DEVICE inline void operator()(K k[3], V v[3], const std::uint32_t range, const bool asc) + { + constexpr unsigned N = 3; + const auto lane_id = threadIdx.x % warp_size; + + if (range == 1) { + const auto p = ((lane_id & 1) == 0); + swap_if_needed(k[0], v[0], k[1], v[1], p); + swap_if_needed(k[1], v[1], k[2], v[2], p); + swap_if_needed(k[0], v[0], k[1], v[1], p); + return; + } + + const std::uint32_t b = range; + for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { + const auto p = 
static_cast(lane_id & b) == static_cast(lane_id & c); +#pragma unroll + for (std::uint32_t i = 0; i < N; i++) { + swap_if_needed(k[i], v[i], c, p); + } + } + const auto p = ((lane_id & b) == 0); + swap_if_needed(k[0], v[0], k[1], v[1], p); + swap_if_needed(k[1], v[1], k[2], v[2], p); + swap_if_needed(k[0], v[0], k[1], v[1], p); + } +}; + +template +struct warp_merge_core { + _RAFT_DEVICE inline void operator()(K k[2], V v[2], const std::uint32_t range, const bool asc) + { + constexpr unsigned N = 2; + const auto lane_id = threadIdx.x % warp_size; + + if (range == 1) { + const auto p = ((lane_id & 1) == 0); + swap_if_needed(k[0], v[0], k[1], v[1], p); + return; + } + + const std::uint32_t b = range; + for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { + const auto p = static_cast(lane_id & b) == static_cast(lane_id & c); +#pragma unroll + for (std::uint32_t i = 0; i < N; i++) { + swap_if_needed(k[i], v[i], c, p); + } + } + const auto p = ((lane_id & b) == 0); + swap_if_needed(k[0], v[0], k[1], v[1], p); + } +}; + +template +struct warp_merge_core { + _RAFT_DEVICE inline void operator()(K k[1], V v[1], const std::uint32_t range, const bool asc) + { + const auto lane_id = threadIdx.x % warp_size; + const std::uint32_t b = range; + for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { + const auto p = static_cast(lane_id & b) == static_cast(lane_id & c); + swap_if_needed(k[0], v[0], c, p); + } + } +}; + +} // namespace detail + +template +__device__ void warp_merge(K k[N], V v[N], unsigned range, const bool asc = true) +{ + detail::warp_merge_core{}(k, v, range, asc); +} + +template +__device__ void warp_sort(K k[N], V v[N], const bool asc = true) +{ + for (std::uint32_t range = 1; range <= warp_size; range <<= 1) { + warp_merge(k, v, range, asc); + } +} + +} // namespace bitonic +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh new file 
mode 100644 index 0000000000..4d63fb7999 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -0,0 +1,236 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../../cagra_types.hpp" +#include "graph_core.cuh" +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace raft::neighbors::experimental::cagra::detail { + +using INDEX_T = std::uint32_t; + +template +void build_knn_graph(raft::device_resources const& res, + mdspan, row_major, accessor> dataset, + raft::host_matrix_view knn_graph, + std::optional refine_rate = std::nullopt, + std::optional build_params = std::nullopt, + std::optional search_params = std::nullopt) +{ + RAFT_EXPECTS( + dataset.extent(1) * sizeof(DataT) % 8 == 0, + "Dataset rows are expected to have at least 8 bytes alignment. Try padding feature dims."); + + RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded, + "Currently only L2Expanded metric is supported"); + + uint32_t node_degree = knn_graph.extent(1); + common::nvtx::range fun_scope("cagra::build_graph(%zu, %zu, %u)", + size_t(dataset.extent(0)), + size_t(dataset.extent(1)), + node_degree); + + if (!build_params) { + build_params = ivf_pq::index_params{}; + build_params->n_lists = dataset.extent(0) < 4 * 2500 ? 
4 : (uint32_t)(dataset.extent(0) / 2500); + build_params->pq_dim = raft::Pow2<8>::roundUp(dataset.extent(1) / 2); + build_params->pq_bits = 8; + build_params->kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 1 : 10; + build_params->kmeans_n_iters = 25; + build_params->add_data_on_build = true; + } + + // Make model name + const std::string model_name = [&]() { + char model_name[1024]; + snprintf(model_name, sizeof(model_name), + "%s-%lux%lu.cluster_%u.pq_%u.%ubit.itr_%u.metric_%u.pqcenter_%u", + "IVF-PQ", + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1)), + build_params->n_lists, + build_params->pq_dim, + build_params->pq_bits, + build_params->kmeans_n_iters, + build_params->metric, + static_cast(build_params->codebook_kind)); + return std::string(model_name); + }(); + + RAFT_LOG_DEBUG("# Building IVF-PQ index %s", model_name.c_str()); + auto index = ivf_pq::build( + res, *build_params, dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + + // + // search top (k + 1) neighbors + // + if (!search_params) { + search_params = ivf_pq::search_params{}; + search_params->n_probes = std::min(dataset.extent(1) * 2, build_params->n_lists); + search_params->lut_dtype = CUDA_R_8U; + search_params->internal_distance_dtype = CUDA_R_32F; + } + const auto top_k = node_degree + 1; + uint32_t gpu_top_k = node_degree * refine_rate.value_or(2.0f); + gpu_top_k = std::min(std::max(gpu_top_k, top_k), dataset.extent(0)); + const auto num_queries = dataset.extent(0); + const auto max_batch_size = 1024; + RAFT_LOG_DEBUG( + "IVF-PQ search node_degree: %u, top_k: %u, gpu_top_k: %u, max_batch_size: %d, n_probes: %u", + node_degree, + top_k, + gpu_top_k, + max_batch_size, + search_params->n_probes); + + // TODO(tfeher): shall we use uint32_t? 
+ auto distances = raft::make_device_matrix(res, max_batch_size, gpu_top_k); + auto neighbors = raft::make_device_matrix(res, max_batch_size, gpu_top_k); + auto refined_distances = raft::make_device_matrix(res, max_batch_size, top_k); + auto refined_neighbors = raft::make_device_matrix(res, max_batch_size, top_k); + auto neighbors_host = raft::make_host_matrix(max_batch_size, gpu_top_k); + auto queries_host = raft::make_host_matrix(max_batch_size, dataset.extent(1)); + auto refined_neighbors_host = raft::make_host_matrix(max_batch_size, top_k); + auto refined_distances_host = raft::make_host_matrix(max_batch_size, top_k); + + // TODO(tfeher): batched search with multiple GPUs + std::size_t num_self_included = 0; + bool first = true; + const auto start_clock = std::chrono::steady_clock::now(); + + rmm::mr::device_memory_resource* device_memory = nullptr; + auto pool_guard = raft::get_pool_memory_resource(device_memory, 1024 * 1024); + if (pool_guard) { + RAFT_LOG_DEBUG("ivf_pq using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); + } + + raft::spatial::knn::detail::utils::batch_load_iterator vec_batches(dataset.data_handle(), + dataset.extent(0), + dataset.extent(1), + max_batch_size, + res.get_stream(), + device_memory); + + for (const auto& batch : vec_batches) { + auto queries_view = raft::make_device_matrix_view( + batch.data(), batch.size(), batch.row_width()); + auto neighbors_view = make_device_matrix_view( + neighbors.data_handle(), batch.size(), neighbors.extent(1)); + auto distances_view = make_device_matrix_view( + distances.data_handle(), batch.size(), distances.extent(1)); + + ivf_pq::search(res, *search_params, index, queries_view, neighbors_view, distances_view); + + if constexpr (is_host_mdspan_v) { + raft::copy(neighbors_host.data_handle(), + neighbors.data_handle(), + neighbors_view.size(), + res.get_stream()); + raft::copy(queries_host.data_handle(), batch.data(), queries_view.size(), res.get_stream()); + auto 
queries_host_view = make_host_matrix_view( + queries_host.data_handle(), batch.size(), batch.row_width()); + auto neighbors_host_view = make_host_matrix_view( + neighbors_host.data_handle(), batch.size(), neighbors.extent(1)); + auto refined_neighbors_host_view = make_host_matrix_view( + refined_neighbors_host.data_handle(), batch.size(), top_k); + auto refined_distances_host_view = make_host_matrix_view( + refined_distances_host.data_handle(), batch.size(), top_k); + res.sync_stream(); + + raft::neighbors::detail::refine_host( // res, + dataset, + queries_host_view, + neighbors_host_view, + refined_neighbors_host_view, + refined_distances_host_view, + build_params->metric); + } else { + auto neighbor_candidates_view = make_device_matrix_view( + neighbors.data_handle(), batch.size(), gpu_top_k); + auto refined_neighbors_view = make_device_matrix_view( + refined_neighbors.data_handle(), batch.size(), top_k); + auto refined_distances_view = make_device_matrix_view( + refined_distances.data_handle(), batch.size(), top_k); + + auto dataset_view = make_device_matrix_view( + dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + raft::neighbors::detail::refine_device( + res, + dataset_view, + queries_view, + neighbor_candidates_view, + refined_neighbors_view, + refined_distances_view, + build_params->metric); + raft::copy(refined_neighbors_host.data_handle(), + refined_neighbors_view.data_handle(), + refined_neighbors_view.size(), + res.get_stream()); + res.sync_stream(); + } + // omit itself & write out + // TODO(tfeher): do this in parallel with GPU processing of next batch + for (std::size_t i = 0; i < batch.size(); i++) { + size_t vec_idx = i + batch.offset(); + for (std::size_t j = 0, num_added = 0; j < top_k && num_added < node_degree; j++) { + const auto v = refined_neighbors_host(i, j); + if (static_cast(v) == vec_idx) { + num_self_included++; + continue; + } + knn_graph(vec_idx, num_added) = v; + num_added++; + } + } + + size_t num_queries_done = 
batch.offset() + batch.size(); + const auto end_clock = std::chrono::steady_clock::now(); + const auto time = + std::chrono::duration_cast(end_clock - start_clock).count() * 1e-6; + const auto throughput = num_queries_done / time; + RAFT_LOG_DEBUG( + "# Search %12lu / %12lu (%3.2f %%), %e queries/sec, %.2f minutes ETA, self included = " + "%3.2f %% \r", + num_queries_done, + dataset.extent(0), + num_queries_done / static_cast(dataset.extent(0)) * 100, + throughput, + (num_queries - num_queries_done) / throughput / 60, + static_cast(num_self_included) / num_queries_done * 100.); + first = false; + } + if (!first) RAFT_LOG_DEBUG("# Finished building kNN graph"); +} + +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh new file mode 100644 index 0000000000..79cbb6198f --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include "factory.cuh" +#include "search_multi_cta.cuh" +#include "search_multi_kernel.cuh" +#include "search_plan.cuh" +#include "search_single_cta.cuh" + +namespace raft::neighbors::experimental::cagra::detail { + +/** + * @brief Search ANN using the constructed index. 
+ * + * See the [build](#build) documentation for a usage example. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft device resources + * @param[in] params configure the search + * @param[in] index CAGRA index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ + +template +void search_main(raft::device_resources const& res, + search_params params, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n", + static_cast(index.dataset().extent(0)), + static_cast(index.dataset().extent(1))); + RAFT_LOG_DEBUG("# query size = %lu, dim = %lu\n", + static_cast(queries.extent(0)), + static_cast(queries.extent(1))); + RAFT_EXPECTS(queries.extent(1) == index.dim(), "Queries and index dim must match"); + uint32_t topk = neighbors.extent(1); + + std::unique_ptr> plan = + factory::create(res, params, index.dim(), index.graph_degree(), topk); + + plan->check(neighbors.extent(1)); + + RAFT_LOG_DEBUG("Cagra search"); + uint32_t max_queries = plan->max_queries; + uint32_t query_dim = queries.extent(1); + + for (unsigned qid = 0; qid < queries.extent(0); qid += max_queries) { + const uint32_t n_queries = std::min(max_queries, queries.extent(0) - qid); + IdxT* _topk_indices_ptr = neighbors.data_handle() + (topk * qid); + DistanceT* _topk_distances_ptr = distances.data_handle() + (topk * qid); + // todo(tfeher): one could keep distances optional and pass nullptr + const T* _query_ptr = queries.data_handle() + (query_dim * qid); + const IdxT* _seed_ptr = + plan->num_seeds > 0 ? 
plan->dev_seed.data() + (plan->num_seeds * qid) : nullptr; + uint32_t* _num_executed_iterations = nullptr; + + (*plan)(res, + index.dataset(), + index.graph(), + _topk_indices_ptr, + _topk_distances_ptr, + _query_ptr, + n_queries, + _seed_ptr, + _num_executed_iterations, + topk); + } +} +/** @} */ // end group cagra + +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh new file mode 100644 index 0000000000..171f261cf3 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include + +namespace raft::neighbors::experimental::cagra::detail { + +// Serialization version 1. +constexpr int serialization_version = 1; + +// NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error +// message. +template +struct check_index_layout { + static_assert(RealSize == ExpectedSize, + "The size of the index struct has changed since the last update; " + "paste in the new size and consider updating the serialization logic"); +}; + +template struct check_index_layout), 136>; + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @param[in] res the raft resource handle + * @param[in] os output stream to write the index to + * @param[in] index_ CAGRA index + * + */ +template +void serialize(raft::device_resources const& res, std::ostream& os, const index& index_) +{ + RAFT_LOG_DEBUG( + "Saving CAGRA index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); + + serialize_scalar(res, os, serialization_version); + serialize_scalar(res, os, index_.size()); + serialize_scalar(res, os, index_.dim()); + serialize_scalar(res, os, index_.graph_degree()); + serialize_scalar(res, os, index_.metric()); + serialize_mdspan(res, os, index_.dataset()); + serialize_mdspan(res, os, index_.graph()); +} + +template +void serialize(raft::device_resources const& res, + const std::string& filename, + const index& index_) +{ + std::ofstream of(filename, std::ios::out | std::ios::binary); + if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + detail::serialize(res, of, index_); + + of.close(); + if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } +} + +/** Load an index from file. + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @param[in] res the raft resource handle + * @param[in] is input stream that stores the index + * @return the deserialized CAGRA index + * + */ +template +auto deserialize(raft::device_resources const& res, std::istream& is) -> index +{ + auto ver = deserialize_scalar(res, is); + if (ver != serialization_version) { + RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); + } + auto n_rows = deserialize_scalar(res, is); + auto dim = deserialize_scalar(res, is); + auto graph_degree = deserialize_scalar(res, is); + auto metric = deserialize_scalar(res, is); + + auto dataset = raft::make_host_matrix(n_rows, dim); + auto graph = raft::make_host_matrix(n_rows, graph_degree); + + deserialize_mdspan(res, is, dataset.view()); + deserialize_mdspan(res, is, graph.view()); + + return index(res, metric, raft::make_const_mdspan(dataset.view()), graph.view()); +} + +template +auto deserialize(raft::device_resources const& res, const std::string& filename) -> index +{ + std::ifstream is(filename, std::ios::in | std::ios::binary); + + if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + auto index = detail::deserialize(res, is); + + is.close(); + + return index; +} +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp new file mode 100644 index 0000000000..a05c714700 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp @@ -0,0 +1,253 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "device_common.hpp" +#include "hashmap.hpp" +#include "utils.hpp" +#include + +namespace raft::neighbors::experimental::cagra::detail { +namespace device { + +// using LOAD_256BIT_T = ulonglong4; +using LOAD_128BIT_T = uint4; +using LOAD_64BIT_T = uint64_t; + +template +_RAFT_DEVICE constexpr unsigned get_vlen() +{ + return utils::size_of() / utils::size_of(); +} + +template +struct data_load_t { + union { + LOAD_T load; + DATA_T data[VLEN]; + }; +}; + +template +_RAFT_DEVICE void compute_distance_to_random_nodes( + INDEX_T* const result_indices_ptr, // [num_pickup] + DISTANCE_T* const result_distances_ptr, // [num_pickup] + const float* const query_buffer, + const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] + const std::size_t dataset_dim, + const std::size_t dataset_size, + const std::size_t num_pickup, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const INDEX_T* seed_ptr, // [num_seeds] + const uint32_t num_seeds, + uint32_t* const visited_hash_ptr, + const uint32_t hash_bitlen, + const uint32_t block_id = 0, + const uint32_t num_blocks = 1) +{ + const unsigned lane_id = threadIdx.x % TEAM_SIZE; + constexpr unsigned vlen = get_vlen(); + constexpr unsigned nelem = (MAX_DATASET_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen); + struct data_load_t dl_buff[nelem]; + uint32_t max_i = num_pickup; + if (max_i % (32 / TEAM_SIZE)) { max_i += (32 / TEAM_SIZE) - (max_i % (32 / TEAM_SIZE)); } + for (uint32_t i = threadIdx.x / TEAM_SIZE; i < max_i; i += blockDim.x / TEAM_SIZE) { + const 
bool valid_i = (i < num_pickup); + + INDEX_T best_index_team_local; + DISTANCE_T best_norm2_team_local = utils::get_max_value(); + for (uint32_t j = 0; j < num_distilation; j++) { + // Select a node randomly and compute the distance to it + uint32_t seed_index; + DISTANCE_T norm2 = 0.0; + if (valid_i) { + // uint32_t gid = i + (num_pickup * (j + (num_distilation * block_id))); + uint32_t gid = block_id + (num_blocks * (i + (num_pickup * j))); + if (seed_ptr && (gid < num_seeds)) { + seed_index = seed_ptr[gid]; + } else { + seed_index = device::xorshift64(gid ^ rand_xor_mask) % dataset_size; + } +#pragma unroll + for (uint32_t e = 0; e < nelem; e++) { + const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen; + if (k >= dataset_dim) break; + dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_dim * seed_index)))[0]; + } +#pragma unroll + for (uint32_t e = 0; e < nelem; e++) { + const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen; + if (k >= dataset_dim) break; +#pragma unroll + for (uint32_t v = 0; v < vlen; v++) { + const uint32_t kv = k + v; + // if (kv >= dataset_dim) break; + DISTANCE_T diff = query_buffer[device::swizzling(kv)]; + diff -= static_cast(dl_buff[e].data[v]) * device::fragment_scale(); + norm2 += diff * diff; + } + } + } + for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) { + norm2 += __shfl_xor_sync(0xffffffff, norm2, offset); + } + + if (valid_i && (norm2 < best_norm2_team_local)) { + best_norm2_team_local = norm2; + best_index_team_local = seed_index; + } + } + + if (valid_i && (threadIdx.x % TEAM_SIZE == 0)) { + if (hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) { + result_distances_ptr[i] = best_norm2_team_local; + result_indices_ptr[i] = best_index_team_local; + } else { + result_distances_ptr[i] = utils::get_max_value(); + result_indices_ptr[i] = utils::get_max_value(); + } + } + } +} + +template +_RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_indices_ptr, + DISTANCE_T* 
const result_child_distances_ptr, + // query + const float* const query_buffer, + // [dataset_size, dataset_dim] + const DATA_T* const dataset_ptr, + const std::size_t dataset_dim, + // [dataset_size, knn_k] + const INDEX_T* const knn_graph, + const std::uint32_t knn_k, + // hashmap + std::uint32_t* const visited_hashmap_ptr, + const std::uint32_t hash_bitlen, + const INDEX_T* const parent_indices, + const std::uint32_t num_parents) +{ + const INDEX_T invalid_index = utils::get_max_value(); + + // Read child indices of parents from knn graph and check if the distance + // computation is necessary. + for (uint32_t i = threadIdx.x; i < knn_k * num_parents; i += BLOCK_SIZE) { + const INDEX_T parent_id = parent_indices[i / knn_k]; + INDEX_T child_id = invalid_index; + if (parent_id != invalid_index) { + child_id = knn_graph[(i % knn_k) + ((uint64_t)knn_k * parent_id)]; + } + if (child_id != invalid_index) { + if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) { + child_id = invalid_index; + } + } + result_child_indices_ptr[i] = child_id; + } + + constexpr unsigned vlen = get_vlen(); + constexpr unsigned nelem = (MAX_DATASET_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen); + const unsigned lane_id = threadIdx.x % TEAM_SIZE; + + // [Notice] + // Loading the query vector here from shared memory into registers reduces + // shared memory traffic. However, register usage increases. The + // MAX_N_FRAGS below is used as the threshold to enable or disable this, + // but the appropriate value should be discussed. + constexpr unsigned N_FRAGS = (MAX_DATASET_DIM + TEAM_SIZE - 1) / TEAM_SIZE; + float query_frags[N_FRAGS]; + if (N_FRAGS <= MAX_N_FRAGS) { + // Pre-load query vectors into registers when register usage is not too large. 
+#pragma unroll + for (unsigned e = 0; e < nelem; e++) { + const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; + // if (k >= dataset_dim) break; +#pragma unroll + for (unsigned v = 0; v < vlen; v++) { + const unsigned kv = k + v; + const unsigned ev = (vlen * e) + v; + query_frags[ev] = query_buffer[device::swizzling(kv)]; + } + } + } + __syncthreads(); + + // Compute the distance to child nodes + std::uint32_t max_i = knn_k * num_parents; + if (max_i % (32 / TEAM_SIZE)) { max_i += (32 / TEAM_SIZE) - (max_i % (32 / TEAM_SIZE)); } + for (std::uint32_t i = threadIdx.x / TEAM_SIZE; i < max_i; i += BLOCK_SIZE / TEAM_SIZE) { + const bool valid_i = (i < (knn_k * num_parents)); + INDEX_T child_id = invalid_index; + if (valid_i) { child_id = result_child_indices_ptr[i]; } + + DISTANCE_T norm2 = 0.0; + struct data_load_t dl_buff[nelem]; + if (child_id != invalid_index) { +#pragma unroll + for (unsigned e = 0; e < nelem; e++) { + const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; + if (k >= dataset_dim) break; + dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_dim * child_id)))[0]; + } +#pragma unroll + for (unsigned e = 0; e < nelem; e++) { + const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; + if (k >= dataset_dim) break; +#pragma unroll + for (unsigned v = 0; v < vlen; v++) { + DISTANCE_T diff; + if (N_FRAGS <= MAX_N_FRAGS) { + const unsigned ev = (vlen * e) + v; + diff = query_frags[ev]; + } else { + const unsigned kv = k + v; + diff = query_buffer[device::swizzling(kv)]; + } + diff -= static_cast(dl_buff[e].data[v]) * device::fragment_scale(); + norm2 += diff * diff; + } + } + } + for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) { + norm2 += __shfl_xor_sync(0xffffffff, norm2, offset); + } + + // Store the distance + if (valid_i && (threadIdx.x % TEAM_SIZE == 0)) { + if (child_id != invalid_index) { + result_child_distances_ptr[i] = norm2; + } else { + result_child_distances_ptr[i] = utils::get_max_value(); + } + } + } +} + +} // 
namespace device +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/device_common.hpp b/cpp/include/raft/neighbors/detail/cagra/device_common.hpp new file mode 100644 index 0000000000..20f30d9f11 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/device_common.hpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "utils.hpp" +#include +#include +#include +#include + +namespace raft::neighbors::experimental::cagra::detail { +namespace device { + +// warpSize for compile time calculation +constexpr unsigned warp_size = 32; + +// scaling factor for distance computation +template +_RAFT_HOST_DEVICE constexpr float fragment_scale(); +template <> +_RAFT_HOST_DEVICE constexpr float fragment_scale() +{ + return 1.0; +}; +template <> +_RAFT_HOST_DEVICE constexpr float fragment_scale() +{ + return 1.0; +}; +template <> +_RAFT_HOST_DEVICE constexpr float fragment_scale() +{ + return 1.0 / 256.0; +}; +template <> +_RAFT_HOST_DEVICE constexpr float fragment_scale() +{ + return 1.0 / 128.0; +}; + +/** Xorshift random number generator (xorshift followed by a constant multiply). + * + * See https://en.wikipedia.org/wiki/Xorshift#xorshift for reference. 
+ */ +_RAFT_HOST_DEVICE inline uint64_t xorshift64(uint64_t u) +{ + u ^= u >> 12; + u ^= u << 25; + u ^= u >> 27; + return u * 0x2545F4914F6CDD1DULL; +} + +template +_RAFT_DEVICE inline T swizzling(T x) +{ + // Address swizzling reduces bank conflicts in shared memory, but increases + // the number of operations. + // return x; + return x ^ (x >> 5); // "x" must be less than 1024 +} + +} // namespace device +} // namespace raft::neighbors::experimental::cagra::detail \ No newline at end of file diff --git a/cpp/include/raft/neighbors/detail/cagra/factory.cuh b/cpp/include/raft/neighbors/detail/cagra/factory.cuh new file mode 100644 index 0000000000..beeebc605c --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/factory.cuh @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "search_multi_cta.cuh" +#include "search_multi_kernel.cuh" +#include "search_plan.cuh" +#include "search_single_cta.cuh" + +namespace raft::neighbors::experimental::cagra::detail { + +template +class factory { + public: + /** + * Create a search structure for dataset with dim features. 
+ */ + static std::unique_ptr> create( + raft::device_resources const& res, + search_params const& params, + int64_t dim, + int64_t graph_degree, + uint32_t topk) + { + search_plan_impl_base plan(params, dim, graph_degree, topk); + switch (plan.max_dim) { + case 128: + switch (plan.team_size) { + case 8: return dispatch_kernel<128, 8>(res, plan); break; + default: THROW("Incorrect team size %lu", plan.team_size); + } + break; + case 256: + switch (plan.team_size) { + case 16: return dispatch_kernel<256, 16>(res, plan); break; + default: THROW("Incorrect team size %lu", plan.team_size); + } + break; + case 512: + switch (plan.team_size) { + case 32: return dispatch_kernel<512, 32>(res, plan); break; + default: THROW("Incorrect team size %lu", plan.team_size); + } + break; + case 1024: + switch (plan.team_size) { + case 32: return dispatch_kernel<1024, 32>(res, plan); break; + default: THROW("Incorrect team size %lu", plan.team_size); + } + break; + default: THROW("Incorrect max_dim (%lu)", plan.max_dim); + } + return std::unique_ptr>(); + } + + private: + template + static std::unique_ptr> dispatch_kernel( + raft::device_resources const& res, search_plan_impl_base& plan) + { + if (plan.algo == search_algo::SINGLE_CTA) { + return std::unique_ptr>( + new single_cta_search::search( + res, plan, plan.dim, plan.graph_degree, plan.topk)); + } else if (plan.algo == search_algo::MULTI_CTA) { + return std::unique_ptr>( + new multi_cta_search::search( + res, plan, plan.dim, plan.graph_degree, plan.topk)); + } else { + return std::unique_ptr>( + new multi_kernel_search::search( + res, plan, plan.dim, plan.graph_degree, plan.topk)); + } + } +}; +}; // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/fragment.hpp b/cpp/include/raft/neighbors/detail/cagra/fragment.hpp new file mode 100644 index 0000000000..d5ec2207e7 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/fragment.hpp @@ -0,0 +1,212 @@ +/* 
+ * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "device_common.hpp" +#include "utils.hpp" +#include +#include + +namespace raft::neighbors::experimental::cagra::detail { +namespace device { + +namespace detail { +template +struct load_unit_t { + using type = uint4; +}; +template <> +struct load_unit_t<8> { + using type = std::uint64_t; +}; +template <> +struct load_unit_t<4> { + using type = std::uint32_t; +}; +template <> +struct load_unit_t<2> { + using type = std::uint16_t; +}; +template <> +struct load_unit_t<1> { + using type = std::uint8_t; +}; +} // namespace detail + +// One dataset or query vector is distributed within a warp and stored as `fragment`. 
+template +struct fragment_base { +}; +template +struct fragment + : fragment_base()) == 0>::type> { + static constexpr unsigned num_elements = DIM / TEAM_SIZE; + using block_t = typename detail::load_unit_t()>::type; + static constexpr unsigned num_load_blocks = + num_elements * utils::size_of() / utils::size_of(); + + union { + T x[num_elements]; + block_t load_block[num_load_blocks]; + }; +}; + +// Load a vector from device/shared memory +template +_RAFT_DEVICE void load_vector_sync(device::fragment& frag, + const INPUT_T* const input_vector_ptr, + const unsigned input_vector_length, + const bool sync = true) +{ + const auto lane_id = threadIdx.x % TEAM_SIZE; + if (DIM == input_vector_length) { + for (unsigned i = 0; i < frag.num_load_blocks; i++) { + const auto vector_index = i * TEAM_SIZE + lane_id; + frag.load_block[i] = + reinterpret_cast::block_t*>( + input_vector_ptr)[vector_index]; + } + } else { + for (unsigned i = 0; i < frag.num_elements; i++) { + const auto vector_index = i * TEAM_SIZE + lane_id; + + INPUT_T v; + if (vector_index < input_vector_length) { + v = static_cast(input_vector_ptr[vector_index]); + } else { + v = static_cast(0); + } + + frag.x[i] = v; + } + } + if (sync) { __syncwarp(); } +} + +// Compute the square of the L2 norm of two vectors +template +_RAFT_DEVICE COMPUTE_T norm2(const device::fragment& a, + const device::fragment& b) +{ + COMPUTE_T sum = 0; + + // Compute the thread-local norm2 + for (unsigned i = 0; i < a.num_elements; i++) { + const auto diff = static_cast(a.x[i]) - static_cast(b.x[i]); + sum += diff * diff; + } + + // Compute the result norm2 summing up the thread-local norm2s. 
+ for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) + sum += __shfl_xor_sync(0xffffffff, sum, offset); + + return sum; +} + +template +_RAFT_DEVICE COMPUTE_T norm2(const device::fragment& a, + const device::fragment& b, + const float scale) +{ + COMPUTE_T sum = 0; + + // Compute the thread-local norm2 + for (unsigned i = 0; i < a.num_elements; i++) { + const auto diff = + static_cast((static_cast(a.x[i]) - static_cast(b.x[i])) * scale); + sum += diff * diff; + } + + // Compute the result norm2 summing up the thread-local norm2s. + for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) + sum += __shfl_xor_sync(0xffffffff, sum, offset); + + return sum; +} + +template +_RAFT_DEVICE COMPUTE_T norm2(const device::fragment& a, + const T* b, // [DIM] + const float scale) +{ + COMPUTE_T sum = 0; + + // Compute the thread-local norm2 + const unsigned chunk_size = a.num_elements / a.num_load_blocks; + const unsigned lane_id = threadIdx.x % TEAM_SIZE; + for (unsigned i = 0; i < a.num_elements; i++) { + unsigned j = (i % chunk_size) + chunk_size * (lane_id + TEAM_SIZE * (i / chunk_size)); + const auto diff = static_cast(a.x[i] * scale) - static_cast(b[j] * scale); + sum += diff * diff; + } + + // Compute the result norm2 summing up the thread-local norm2s. 
+ for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) + sum += __shfl_xor_sync(0xffffffff, sum, offset); + + return sum; +} + +template +_RAFT_DEVICE inline COMPUTE_T norm2x(const device::fragment& a, + const COMPUTE_T* b, // [dim] + const uint32_t dim, + const float scale) +{ + // Compute the thread-local norm2 + COMPUTE_T sum = 0; + const unsigned lane_id = threadIdx.x % TEAM_SIZE; + if (dim == DIM) { + const unsigned chunk_size = a.num_elements / a.num_load_blocks; + for (unsigned i = 0; i < a.num_elements; i++) { + unsigned j = (i % chunk_size) + chunk_size * (lane_id + TEAM_SIZE * (i / chunk_size)); + const auto diff = static_cast(a.x[i] * scale) - b[j]; + sum += diff * diff; + } + } else { + for (unsigned i = 0; i < a.num_elements; i++) { + unsigned j = lane_id + (TEAM_SIZE * i); + if (j >= dim) break; + const auto diff = static_cast(a.x[i] * scale) - b[j]; + sum += diff * diff; + } + } + + // Compute the result norm2 summing up the thread-local norm2s. + for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) + sum += __shfl_xor_sync(0xffffffff, sum, offset); + + return sum; +} + +template +_RAFT_DEVICE void print_fragment(const device::fragment& a) +{ + for (unsigned i = 0; i < TEAM_SIZE; i++) { + if ((threadIdx.x % TEAM_SIZE) == i) { + for (unsigned j = 0; j < a.num_elements; j++) { + RAFT_LOG_DEBUG("%+e ", static_cast(a.x[j])); + } + } + __syncwarp(); + } +} + +} // namespace device +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh new file mode 100644 index 0000000000..568ad0826c --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -0,0 +1,809 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::neighbors::experimental::cagra::detail { +namespace graph { + +template +__host__ __device__ float compute_norm2(const T* a, + const T* b, + const std::size_t dim, + const float scale) +{ + float sum = 0.f; + for (std::size_t j = 0; j < dim; j++) { + const auto diff = a[j] * scale - b[j] * scale; + sum += diff * diff; + } + return sum; +} + +inline double cur_time(void) +{ + struct timeval tv; + gettimeofday(&tv, NULL); + return ((double)tv.tv_sec + (double)tv.tv_usec * 1e-6); +} + +template +__device__ inline void swap(T& val1, T& val2) +{ + T val0 = val1; + val1 = val2; + val2 = val0; +} + +template +__device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2, bool ascending) +{ + if (key1 == key2) { return false; } + if ((key1 > key2) == ascending) { + swap(key1, key2); + swap(val1, val2); + return true; + } + return false; +} + +template +__global__ void kern_sort( + DATA_T** dataset, // [num_gpus][dataset_chunk_size, dataset_dim] + uint32_t dataset_size, + uint32_t dataset_chunk_size, // (*) num_gpus * dataset_chunk_size >= dataset_size + uint32_t dataset_dim, + float scale, + uint32_t** knn_graph, // [num_gpus][graph_chunk_size, graph_degree] + uint32_t graph_size, + uint32_t graph_chunk_size, // (*) num_gpus * graph_chunk_size >= graph_size + uint32_t graph_degree, + int dev_id) +{ + __shared__ float smem_keys[blockDim_x * 
numElementsPerThread]; + __shared__ uint32_t smem_vals[blockDim_x * numElementsPerThread]; + + uint64_t srcNode = blockIdx.x + ((uint64_t)graph_chunk_size * dev_id); + uint64_t srcNode_dev = srcNode / graph_chunk_size; + uint64_t srcNode_loc = srcNode % graph_chunk_size; + if (srcNode >= graph_size) { return; } + + const uint32_t num_warps = blockDim_x / 32; + const uint32_t warp_id = threadIdx.x / 32; + const uint32_t lane_id = threadIdx.x % 32; + + // Compute distance from a src node to its neighbors + for (int k = warp_id; k < graph_degree; k += num_warps) { + uint64_t dstNode = knn_graph[srcNode_dev][k + ((uint64_t)graph_degree * srcNode_loc)]; + uint64_t dstNode_dev = dstNode / graph_chunk_size; + uint64_t dstNode_loc = dstNode % graph_chunk_size; + float dist = 0.0; + for (int d = lane_id; d < dataset_dim; d += 32) { + float diff = + (float)(dataset[srcNode_dev][d + ((uint64_t)dataset_dim * srcNode_loc)]) * scale - + (float)(dataset[dstNode_dev][d + ((uint64_t)dataset_dim * dstNode_loc)]) * scale; + dist += diff * diff; + } + dist += __shfl_xor_sync(0xffffffff, dist, 1); + dist += __shfl_xor_sync(0xffffffff, dist, 2); + dist += __shfl_xor_sync(0xffffffff, dist, 4); + dist += __shfl_xor_sync(0xffffffff, dist, 8); + dist += __shfl_xor_sync(0xffffffff, dist, 16); + if (lane_id == 0) { + smem_keys[k] = dist; + smem_vals[k] = dstNode; + } + } + __syncthreads(); + + float my_keys[numElementsPerThread]; + uint32_t my_vals[numElementsPerThread]; + for (int i = 0; i < numElementsPerThread; i++) { + int k = i + (numElementsPerThread * threadIdx.x); + if (k < graph_degree) { + my_keys[i] = smem_keys[k]; + my_vals[i] = smem_vals[k]; + } else { + my_keys[i] = FLT_MAX; + my_vals[i] = 0xffffffffU; + } + } + __syncthreads(); + + // Sorting by thread + uint32_t mask = 1; + bool ascending = ((threadIdx.x & mask) == 0); + for (int j = 0; j < numElementsPerThread; j += 2) { +#pragma unroll + for (int i = 0; i < numElementsPerThread; i += 2) { + swap_if_needed( + my_keys[i], 
my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); + } +#pragma unroll + for (int i = 1; i < numElementsPerThread - 1; i += 2) { + swap_if_needed( + my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); + } + } + + // Bitonic Sorting + while (mask < blockDim_x) { + uint32_t next_mask = mask << 1; + + for (uint32_t curr_mask = mask; curr_mask > 0; curr_mask >>= 1) { + bool ascending = ((threadIdx.x & curr_mask) == 0) == ((threadIdx.x & next_mask) == 0); + if (mask >= 32) { + // inter warp + __syncthreads(); +#pragma unroll + for (int i = 0; i < numElementsPerThread; i++) { + smem_keys[threadIdx.x + (blockDim_x * i)] = my_keys[i]; + smem_vals[threadIdx.x + (blockDim_x * i)] = my_vals[i]; + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < numElementsPerThread; i++) { + float opp_key = smem_keys[(threadIdx.x ^ curr_mask) + (blockDim_x * i)]; + uint32_t opp_val = smem_vals[(threadIdx.x ^ curr_mask) + (blockDim_x * i)]; + swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); + } + } else { +// intra warp +#pragma unroll + for (int i = 0; i < numElementsPerThread; i++) { + float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask); + uint32_t opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask); + swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); + } + } + } + + bool ascending = ((threadIdx.x & next_mask) == 0); +#pragma unroll + for (uint32_t curr_mask = numElementsPerThread / 2; curr_mask > 0; curr_mask >>= 1) { +#pragma unroll + for (int i = 0; i < numElementsPerThread; i++) { + int j = i ^ curr_mask; + if (i > j) continue; + swap_if_needed(my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending); + } + } + mask = next_mask; + } + + // Update knn_graph + for (int i = 0; i < numElementsPerThread; i++) { + int k = i + (numElementsPerThread * threadIdx.x); + if (k < graph_degree) { + knn_graph[srcNode_dev][k + ((uint64_t)graph_degree * srcNode_loc)] = my_vals[i]; + } + } +} + +template 
+__global__ void kern_prune( + uint32_t** knn_graph, // [num_gpus][graph_chunk_size, graph_degree] + uint32_t graph_size, + uint32_t graph_chunk_size, // (*) num_gpus * graph_chunk_size >= graph_size + uint32_t graph_degree, + uint32_t degree, + int dev_id, + uint32_t batch_size, + uint32_t batch_id, + uint8_t** detour_count, // [num_gpus][graph_chunk_size, graph_degree] + uint32_t** num_no_detour_edges, // [num_gpus][graph_size] + uint64_t* stats) +{ + __shared__ uint32_t smem_num_detour[MAX_DEGREE]; + uint64_t* num_retain = stats; + uint64_t* num_full = stats + 1; + + uint64_t nid = blockIdx.x + (batch_size * batch_id); + if (nid >= graph_chunk_size) { return; } + for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { + smem_num_detour[k] = 0; + } + __syncthreads(); + + uint64_t iA = nid + ((uint64_t)graph_chunk_size * dev_id); + uint64_t iA_dev = iA / graph_chunk_size; + uint64_t iA_loc = iA % graph_chunk_size; + if (iA >= graph_size) { return; } + + // count number of detours (A->D->B) + for (uint32_t kAD = 0; kAD < graph_degree - 1; kAD++) { + uint64_t iD = knn_graph[iA_dev][kAD + (graph_degree * iA_loc)]; + uint64_t iD_dev = iD / graph_chunk_size; + uint64_t iD_loc = iD % graph_chunk_size; + for (uint32_t kDB = threadIdx.x; kDB < graph_degree; kDB += blockDim.x) { + uint64_t iB_candidate = knn_graph[iD_dev][kDB + ((uint64_t)graph_degree * iD_loc)]; + for (uint32_t kAB = kAD + 1; kAB < graph_degree; kAB++) { + // if ( kDB < kAB ) + { + uint64_t iB = knn_graph[iA_dev][kAB + (graph_degree * iA_loc)]; + if (iB == iB_candidate) { + atomicAdd(smem_num_detour + kAB, 1); + break; + } + } + } + } + __syncthreads(); + } + + uint32_t num_edges_no_detour = 0; + for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { + detour_count[iA_dev][k + (graph_degree * iA_loc)] = min(smem_num_detour[k], (uint32_t)255); + if (smem_num_detour[k] == 0) { num_edges_no_detour++; } + } + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); 
+ num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 2); + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 4); + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 8); + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 16); + num_edges_no_detour = min(num_edges_no_detour, degree); + + if (threadIdx.x == 0) { + num_no_detour_edges[iA_dev][iA_loc] = num_edges_no_detour; + atomicAdd((unsigned long long int*)num_retain, (unsigned long long int)num_edges_no_detour); + if (num_edges_no_detour >= degree) { atomicAdd((unsigned long long int*)num_full, 1); } + } +} + +// unnamed namespace to avoid multiple definition error +namespace { +__global__ void kern_make_rev_graph(const uint32_t i_gpu, + const uint32_t* dest_nodes, // [global_graph_size] + const uint32_t global_graph_size, + uint32_t* rev_graph, // [graph_size, degree] + uint32_t* rev_graph_count, // [graph_size] + const uint32_t graph_size, + const uint32_t degree) +{ + const uint32_t tid = threadIdx.x + (blockDim.x * blockIdx.x); + const uint32_t tnum = blockDim.x * gridDim.x; + + for (uint32_t gl_src_id = tid; gl_src_id < global_graph_size; gl_src_id += tnum) { + uint32_t gl_dest_id = dest_nodes[gl_src_id]; + if (gl_dest_id < graph_size * i_gpu) continue; + if (gl_dest_id >= graph_size * (i_gpu + 1)) continue; + if (gl_dest_id >= global_graph_size) continue; + + uint32_t dest_id = gl_dest_id - (graph_size * i_gpu); + uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); + if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = gl_src_id; } + } +} +} // namespace +template +T*** mgpu_alloc(int n_gpus, uint32_t chunk, uint32_t nelems) +{ + T** arrays; // [n_gpus][chunk, nelems] + arrays = (T**)malloc(sizeof(T*) * n_gpus); /* h1 */ + size_t bsize = sizeof(T) * chunk * nelems; + // RAFT_LOG_DEBUG("[%s, %s, %d] n_gpus: %d, chunk: %u, nelems: %u, bsize: %lu (%lu MiB)\n", + // __FILE__, __func__, __LINE__, n_gpus, 
chunk, nelems, bsize, bsize / 1024 / 1024);
+  for (int i_gpu = 0; i_gpu < n_gpus; i_gpu++) {
+    RAFT_CUDA_TRY(cudaSetDevice(i_gpu));
+    RAFT_CUDA_TRY(cudaMalloc(&(arrays[i_gpu]), bsize)); /* d1 */
+  }
+  T*** d_arrays; // [n_gpus+1][n_gpus][chunk, nelems]
+  d_arrays = (T***)malloc(sizeof(T**) * (n_gpus + 1)); /* h2 */
+  bsize = sizeof(T*) * n_gpus;
+  for (int i_gpu = 0; i_gpu < n_gpus; i_gpu++) {
+    RAFT_CUDA_TRY(cudaSetDevice(i_gpu));
+    RAFT_CUDA_TRY(cudaMalloc(&(d_arrays[i_gpu]), bsize)); /* d2 */
+    RAFT_CUDA_TRY(cudaMemcpy(d_arrays[i_gpu], arrays, bsize, cudaMemcpyDefault));
+  }
+  RAFT_CUDA_TRY(cudaSetDevice(0));
+  d_arrays[n_gpus] = arrays;
+  return d_arrays;
+}
+
+/**
+ * Release the buffers reachable from a `d_arrays` table.
+ *
+ * For every GPU the data buffer `d_arrays[n_gpus][i_gpu]` (d1) and the device-side pointer
+ * table `d_arrays[i_gpu]` (d2) are freed with that GPU selected as the current device; then
+ * the host pointer array `d_arrays[n_gpus]` (h1) and the table itself (h2) are freed.
+ * Device 0 is the current device on return.
+ *
+ * @param d_arrays table of n_gpus + 1 entries; the pointer is dangling after the call
+ * @param n_gpus   number of GPUs the table was built for
+ */
+template
+void mgpu_free(T*** d_arrays, int n_gpus)
+{
+  for (int i_gpu = 0; i_gpu < n_gpus; i_gpu++) {
+    RAFT_CUDA_TRY(cudaSetDevice(i_gpu));
+    RAFT_CUDA_TRY(cudaFree(d_arrays[n_gpus][i_gpu])); /* d1 */
+    RAFT_CUDA_TRY(cudaFree(d_arrays[i_gpu])); /* d2 */
+  }
+  RAFT_CUDA_TRY(cudaSetDevice(0));
+  free(d_arrays[n_gpus]); /* h1 */
+  free(d_arrays); /* h2 */
+}
+
+/**
+ * Scatter a host array over the per-GPU buffers, `chunk` rows per GPU.
+ *
+ * One OpenMP thread is started per GPU; thread `i_gpu` copies host rows
+ * [chunk * i_gpu, min(size, chunk * (i_gpu + 1))) into `d_arrays[n_gpus][i_gpu]`.
+ * A GPU whose first row is at or beyond `size` receives nothing.
+ *
+ * @param d_arrays table whose entry `d_arrays[n_gpus][i_gpu]` is the destination on GPU i_gpu
+ * @param h_array  source, `size` rows of `nelems` elements each
+ * @param n_gpus   number of GPUs / OpenMP threads
+ * @param size     total number of rows in `h_array`
+ * @param chunk    rows assigned to each GPU; n_gpus * chunk >= size
+ * @param nelems   elements per row
+ */
+template
+void mgpu_H2D(T*** d_arrays, // [n_gpus+1][n_gpus][chunk, nelems]
+              const T* h_array, // [size, nelems]
+              int n_gpus,
+              uint32_t size,
+              uint32_t chunk, // (*) n_gpus * chunk >= size
+              uint32_t nelems)
+{
+#pragma omp parallel num_threads(n_gpus)
+  {
+    int i_gpu = omp_get_thread_num();
+    RAFT_CUDA_TRY(cudaSetDevice(i_gpu));
+    // Rows left for this GPU. Because chunk is rounded up, chunk * i_gpu can be larger than
+    // size for trailing GPUs; the unsigned difference would then wrap around and a full
+    // chunk would be read from beyond the end of h_array. Clamp to zero instead.
+    const uint64_t first_row = (uint64_t)chunk * i_gpu;
+    const uint64_t rows_left = (first_row < size) ? (size - first_row) : 0;
+    uint32_t _chunk = (rows_left < chunk) ? (uint32_t)rows_left : chunk;
+    size_t bsize = sizeof(T) * _chunk * nelems;
+    if (bsize > 0) {
+      RAFT_CUDA_TRY(cudaMemcpy(d_arrays[n_gpus][i_gpu],
+                               h_array + ((uint64_t)chunk * nelems * i_gpu),
+                               bsize,
+                               cudaMemcpyDefault));
+    }
+  }
+  RAFT_CUDA_TRY(cudaDeviceSynchronize());
+  RAFT_CUDA_TRY(cudaSetDevice(0));
+}
+
+template
+void mgpu_D2H(T*** d_arrays, // [n_gpus+1][n_gpus][chunk, nelems]
+              T* h_array, // [size, nelems]
+              int n_gpus,
+              uint32_t size,
+              uint32_t chunk, // (*) n_gpus * chunk >= size
+              uint32_t nelems)
+{
+#pragma omp parallel num_threads(n_gpus)
+  {
+    int i_gpu = omp_get_thread_num();
+
RAFT_CUDA_TRY(cudaSetDevice(i_gpu));
+    // Rows held by this GPU. Because chunk is rounded up, chunk * i_gpu can be larger than
+    // size for trailing GPUs; the unsigned difference would then wrap around and a full
+    // chunk would be written beyond the end of h_array. Clamp to zero instead.
+    const uint64_t first_row = (uint64_t)chunk * i_gpu;
+    const uint64_t rows_left = (first_row < size) ? (size - first_row) : 0;
+    uint32_t _chunk = (rows_left < chunk) ? (uint32_t)rows_left : chunk;
+    size_t bsize = sizeof(T) * _chunk * nelems;
+    if (bsize > 0) {
+      RAFT_CUDA_TRY(cudaMemcpy(h_array + ((uint64_t)chunk * nelems * i_gpu),
+                               d_arrays[n_gpus][i_gpu],
+                               bsize,
+                               cudaMemcpyDefault));
+    }
+  }
+  RAFT_CUDA_TRY(cudaDeviceSynchronize());
+  RAFT_CUDA_TRY(cudaSetDevice(0));
+}
+
+/**
+ * Linear search for `val` in the first `num` elements of `array`.
+ *
+ * @return index of the first match, or `num` when `val` is not present
+ */
+template
+uint64_t pos_in_array(T val, const T* array, uint64_t num)
+{
+  for (uint64_t i = 0; i < num; i++) {
+    if (val == array[i]) { return i; }
+  }
+  return num;
+}
+
+/**
+ * Move array[0 .. num-1] one slot towards the end, i.e. into array[1 .. num].
+ *
+ * array[0] keeps its old value and array[num] is overwritten, so the caller must make
+ * sure that index `num` is valid.
+ */
+template
+void shift_array(T* array, uint64_t num)
+{
+  for (uint64_t i = num; i > 0; i--) {
+    array[i] = array[i - 1];
+  }
+}
+
+/** Input arrays can be both host and device*/
+template , memory_type::device>,
+          typename g_accessor =
+            host_device_accessor, memory_type::host>>
+void prune(raft::device_resources const& res,
+           mdspan, row_major, d_accessor> dataset,
+           mdspan, row_major, g_accessor> knn_graph,
+           raft::host_matrix_view new_graph)
+{
+  RAFT_LOG_DEBUG(
+    "# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1));
+
+  RAFT_EXPECTS(
+    dataset.extent(0) == knn_graph.extent(0) && knn_graph.extent(0) == new_graph.extent(0),
+    "Each input array is expected to have the same number of rows");
+  RAFT_EXPECTS(new_graph.extent(1) <= knn_graph.extent(1),
+               "output graph cannot have more columns than input graph");
+  const uint32_t dataset_size = dataset.extent(0);
+  const uint32_t dataset_dim = dataset.extent(1);
+  const uint32_t input_graph_degree = knn_graph.extent(1);
+  const uint32_t output_graph_degree = new_graph.extent(1);
+  const DATA_T* dataset_ptr = dataset.data_handle();
+  uint32_t* input_graph_ptr = (uint32_t*)knn_graph.data_handle();
+  uint32_t* output_graph_ptr = new_graph.data_handle();
+  float scale = 1.0f / raft::spatial::knn::detail::utils::config::kDivisor;
+  const std::size_t graph_size = dataset_size;
+  size_t array_size;
+
+  // Setup GPUs
+  int num_gpus = 0;
+
+  // Setup GPUs
+
RAFT_CUDA_TRY(cudaGetDeviceCount(&num_gpus)); + RAFT_LOG_DEBUG("# num_gpus: %d\n", num_gpus); + for (int self = 0; self < num_gpus; self++) { + RAFT_CUDA_TRY(cudaSetDevice(self)); + for (int peer = 0; peer < num_gpus; peer++) { + if (self == peer) { continue; } + RAFT_CUDA_TRY(cudaDeviceEnablePeerAccess(peer, 0)); + } + } + RAFT_CUDA_TRY(cudaSetDevice(0)); + + uint32_t graph_chunk_size = graph_size; + uint32_t*** d_input_graph_ptr = NULL; // [...][num_gpus][graph_chunk_size, input_graph_degree] + graph_chunk_size = (graph_size + num_gpus - 1) / num_gpus; + d_input_graph_ptr = mgpu_alloc(num_gpus, graph_chunk_size, input_graph_degree); + + uint32_t dataset_chunk_size = dataset_size; + DATA_T*** d_dataset_ptr = NULL; // [num_gpus+1][...][...] + dataset_chunk_size = (dataset_size + num_gpus - 1) / num_gpus; + assert(dataset_chunk_size == graph_chunk_size); + d_dataset_ptr = mgpu_alloc(num_gpus, dataset_chunk_size, dataset_dim); + + mgpu_H2D( + d_dataset_ptr, dataset_ptr, num_gpus, dataset_size, dataset_chunk_size, dataset_dim); + + // + // Sorting kNN graph + // + double time_sort_start = cur_time(); + RAFT_LOG_DEBUG("# Sorting kNN Graph on GPUs "); + mgpu_H2D( + d_input_graph_ptr, input_graph_ptr, num_gpus, graph_size, graph_chunk_size, input_graph_degree); + void (*kernel_sort)( + DATA_T**, uint32_t, uint32_t, uint32_t, float, uint32_t**, uint32_t, uint32_t, uint32_t, int); + constexpr int numElementsPerThread = 4; + dim3 threads_sort(1, 1, 1); + if (input_graph_degree <= numElementsPerThread * 32) { + constexpr int blockDim_x = 32; + kernel_sort = kern_sort; + threads_sort.x = blockDim_x; + } else if (input_graph_degree <= numElementsPerThread * 64) { + constexpr int blockDim_x = 64; + kernel_sort = kern_sort; + threads_sort.x = blockDim_x; + } else if (input_graph_degree <= numElementsPerThread * 128) { + constexpr int blockDim_x = 128; + kernel_sort = kern_sort; + threads_sort.x = blockDim_x; + } else if (input_graph_degree <= numElementsPerThread * 256) { + 
constexpr int blockDim_x = 256; + kernel_sort = kern_sort; + threads_sort.x = blockDim_x; + } else { + fprintf(stderr, + "[ERROR] The degree of input knn graph is too large (%u). " + "It must be equal to or small than %d.\n", + input_graph_degree, + numElementsPerThread * 256); + exit(-1); + } + dim3 blocks_sort(graph_chunk_size, 1, 1); + for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) { + RAFT_LOG_DEBUG("."); + RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); + kernel_sort<<>>(d_dataset_ptr[i_gpu], + dataset_size, + dataset_chunk_size, + dataset_dim, + scale, + d_input_graph_ptr[i_gpu], + graph_size, + graph_chunk_size, + input_graph_degree, + i_gpu); + } + RAFT_CUDA_TRY(cudaSetDevice(0)); + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + RAFT_LOG_DEBUG("."); + mgpu_D2H( + d_input_graph_ptr, input_graph_ptr, num_gpus, graph_size, graph_chunk_size, input_graph_degree); + RAFT_LOG_DEBUG("\n"); + double time_sort_end = cur_time(); + RAFT_LOG_DEBUG("# Sorting kNN graph time: %.1lf sec\n", time_sort_end - time_sort_start); + + mgpu_free(d_dataset_ptr, num_gpus); + + // + uint8_t* detour_count; // [graph_size, input_graph_degree] + array_size = sizeof(uint8_t) * graph_size * input_graph_degree; + detour_count = (uint8_t*)malloc(array_size); + memset(detour_count, 0xff, array_size); + + uint8_t*** d_detour_count = NULL; // [...][num_gpus][graph_chunk_size, input_graph_degree] + d_detour_count = mgpu_alloc(num_gpus, graph_chunk_size, input_graph_degree); + mgpu_H2D( + d_detour_count, detour_count, num_gpus, graph_size, graph_chunk_size, input_graph_degree); + + // + uint32_t* num_no_detour_edges; // [graph_size] + array_size = sizeof(uint32_t) * graph_size; + num_no_detour_edges = (uint32_t*)malloc(array_size); + memset(num_no_detour_edges, 0, array_size); + + uint32_t*** d_num_no_detour_edges = NULL; // [...][num_gpus][graph_chunk_size] + d_num_no_detour_edges = mgpu_alloc(num_gpus, graph_chunk_size, 1); + mgpu_H2D( + d_num_no_detour_edges, num_no_detour_edges, num_gpus, graph_size, 
graph_chunk_size, 1); + + // + uint64_t** dev_stats = NULL; // [num_gpus][2] + uint64_t** host_stats = NULL; // [num_gpus][2] + dev_stats = (uint64_t**)malloc(sizeof(uint64_t*) * num_gpus); + host_stats = (uint64_t**)malloc(sizeof(uint64_t*) * num_gpus); + array_size = sizeof(uint64_t) * 2; + for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) { + RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); + RAFT_CUDA_TRY(cudaMalloc(&(dev_stats[i_gpu]), array_size)); + host_stats[i_gpu] = (uint64_t*)malloc(array_size); + } + RAFT_CUDA_TRY(cudaSetDevice(0)); + + // + // Prune unimportant edges. + // + // The edge to be retained is determined without explicitly considering + // distance or angle. Suppose the edge is the k-th edge of some node-A to + // node-B (A->B). Among the edges originating at node-A, there are k-1 edges + // shorter than the edge A->B. Each of these k-1 edges are connected to a + // different k-1 nodes. Among these k-1 nodes, count the number of nodes with + // edges to node-B, which is the number of 2-hop detours for the edge A->B. + // Once the number of 2-hop detours has been counted for all edges, the + // specified number of edges are picked up for each node, starting with the + // edge with the lowest number of 2-hop detours. + // + double time_prune_start = cur_time(); + uint64_t num_keep = 0; + uint64_t num_full = 0; + RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); + mgpu_H2D( + d_input_graph_ptr, input_graph_ptr, num_gpus, graph_size, graph_chunk_size, input_graph_degree); + void (*kernel_prune)(uint32_t**, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + int, + uint32_t, + uint32_t, + uint8_t**, + uint32_t**, + uint64_t*); + if (input_graph_degree <= 1024) { + constexpr int MAX_DEGREE = 1024; + kernel_prune = kern_prune; + } else { + fprintf(stderr, + "[ERROR] The degree of input knn graph is too large (%u). 
" + "It must be equal to or small than %d.\n", + input_graph_degree, + 1024); + exit(-1); + } + uint32_t batch_size = std::min(graph_chunk_size, (uint32_t)256 * 1024); + uint32_t num_batch = (graph_chunk_size + batch_size - 1) / batch_size; + dim3 threads_prune(32, 1, 1); + dim3 blocks_prune(batch_size, 1, 1); + for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) { + RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); + RAFT_CUDA_TRY(cudaMemset(dev_stats[i_gpu], 0, sizeof(uint64_t) * 2)); + } + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) { + RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); + kernel_prune<<>>(d_input_graph_ptr[i_gpu], + graph_size, + graph_chunk_size, + input_graph_degree, + output_graph_degree, + i_gpu, + batch_size, + i_batch, + d_detour_count[i_gpu], + d_num_no_detour_edges[i_gpu], + dev_stats[i_gpu]); + } + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + fprintf( + stderr, + "# Pruning kNN Graph on GPUs (%.1lf %%)\r", + (double)std::min((i_batch + 1) * batch_size, graph_chunk_size) / graph_chunk_size * 100); + } + for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) { + RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); + RAFT_CUDA_TRY( + cudaMemcpy(host_stats[i_gpu], dev_stats[i_gpu], sizeof(uint64_t) * 2, cudaMemcpyDefault)); + num_keep += host_stats[i_gpu][0]; + num_full += host_stats[i_gpu][1]; + } + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + RAFT_CUDA_TRY(cudaSetDevice(0)); + RAFT_LOG_DEBUG("\n"); + + mgpu_D2H( + d_detour_count, detour_count, num_gpus, graph_size, graph_chunk_size, input_graph_degree); + mgpu_D2H( + d_num_no_detour_edges, num_no_detour_edges, num_gpus, graph_size, graph_chunk_size, 1); + + mgpu_free(d_input_graph_ptr, num_gpus); + mgpu_free(d_detour_count, num_gpus); + mgpu_free(d_num_no_detour_edges, num_gpus); + + // Create pruned kNN graph + array_size = sizeof(uint32_t) * graph_size * output_graph_degree; + uint32_t* pruned_graph_ptr = (uint32_t*)malloc(array_size); + uint32_t max_detour = 0; +#pragma omp 
parallel for reduction(max : max_detour) + for (uint64_t i = 0; i < graph_size; i++) { + uint64_t pk = 0; + for (uint32_t num_detour = 0; num_detour < output_graph_degree; num_detour++) { + if (max_detour < num_detour) { max_detour = num_detour; /* stats */ } + for (uint64_t k = 0; k < input_graph_degree; k++) { + if (detour_count[k + (input_graph_degree * i)] != num_detour) { continue; } + pruned_graph_ptr[pk + (output_graph_degree * i)] = + input_graph_ptr[k + (input_graph_degree * i)]; + pk += 1; + if (pk >= output_graph_degree) break; + } + if (pk >= output_graph_degree) break; + } + assert(pk == output_graph_degree); + } + // RAFT_LOG_DEBUG("# max_detour: %u\n", max_detour); + + double time_prune_end = cur_time(); + fprintf(stderr, + "# Pruning time: %.1lf sec, " + "avg_no_detour_edges_per_node: %.2lf/%u, " + "nodes_with_no_detour_at_all_edges: %.1lf%%\n", + time_prune_end - time_prune_start, + (double)num_keep / graph_size, + output_graph_degree, + (double)num_full / graph_size * 100); + + // + // Make reverse graph + // + double time_make_start = cur_time(); + + array_size = sizeof(uint32_t) * graph_size * output_graph_degree; + uint32_t* rev_graph_ptr = (uint32_t*)malloc(array_size); + memset(rev_graph_ptr, 0xff, array_size); + + uint32_t*** d_rev_graph_ptr; // [...][num_gpus][graph_chunk_size, output_graph_degree] + d_rev_graph_ptr = mgpu_alloc(num_gpus, graph_chunk_size, output_graph_degree); + mgpu_H2D( + d_rev_graph_ptr, rev_graph_ptr, num_gpus, graph_size, graph_chunk_size, output_graph_degree); + + array_size = sizeof(uint32_t) * graph_size; + uint32_t* rev_graph_count = (uint32_t*)malloc(array_size); + memset(rev_graph_count, 0, array_size); + + uint32_t*** d_rev_graph_count; // [...][num_gpus][graph_chunk_size, 1] + d_rev_graph_count = mgpu_alloc(num_gpus, graph_chunk_size, 1); + mgpu_H2D(d_rev_graph_count, rev_graph_count, num_gpus, graph_size, graph_chunk_size, 1); + + uint32_t* dest_nodes; // [graph_size] + dest_nodes = 
(uint32_t*)malloc(sizeof(uint32_t) * graph_size); + uint32_t** d_dest_nodes; // [num_gpus][graph_size] + d_dest_nodes = (uint32_t**)malloc(sizeof(uint32_t*) * num_gpus); + for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) { + RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); + RAFT_CUDA_TRY(cudaMalloc(&(d_dest_nodes[i_gpu]), sizeof(uint32_t) * graph_size)); + } + + for (uint64_t k = 0; k < output_graph_degree; k++) { +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + dest_nodes[i] = pruned_graph_ptr[k + (output_graph_degree * i)]; + } + RAFT_CUDA_TRY(cudaDeviceSynchronize()); +#pragma omp parallel num_threads(num_gpus) + { + int i_gpu = omp_get_thread_num(); + RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); + RAFT_CUDA_TRY(cudaMemcpy( + d_dest_nodes[i_gpu], dest_nodes, sizeof(uint32_t) * graph_size, cudaMemcpyHostToDevice)); + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + kern_make_rev_graph<<>>(i_gpu, + d_dest_nodes[i_gpu], + graph_size, + d_rev_graph_ptr[num_gpus][i_gpu], + d_rev_graph_count[num_gpus][i_gpu], + graph_chunk_size, + output_graph_degree); + } + RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); + } + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + RAFT_CUDA_TRY(cudaSetDevice(0)); + RAFT_LOG_DEBUG("\n"); + + mgpu_D2H( + d_rev_graph_ptr, rev_graph_ptr, num_gpus, graph_size, graph_chunk_size, output_graph_degree); + mgpu_D2H(d_rev_graph_count, rev_graph_count, num_gpus, graph_size, graph_chunk_size, 1); + mgpu_free(d_rev_graph_ptr, num_gpus); + mgpu_free(d_rev_graph_count, num_gpus); + + double time_make_end = cur_time(); + RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf sec", time_make_end - time_make_start); + + // + // Replace some edges with reverse edges + // + double time_replace_start = cur_time(); + + uint64_t num_protected_edges = output_graph_degree / 2; + RAFT_LOG_DEBUG("# num_protected_edges: %lu", num_protected_edges); + + array_size = sizeof(uint32_t) * graph_size * output_graph_degree; + 
memcpy(output_graph_ptr, pruned_graph_ptr, array_size);
+
+  constexpr int _omp_chunk = 1024;
+#pragma omp parallel for schedule(dynamic, _omp_chunk)
+  for (uint64_t j = 0; j < graph_size; j++) {
+    // kern_make_rev_graph increments rev_graph_count[j] for every incoming edge but stores
+    // only the first output_graph_degree of them, so the counter may exceed the row width.
+    // Clamp it so that only entries inside row j of rev_graph_ptr are visited.
+    const uint64_t num_rev_edges =
+      (rev_graph_count[j] < output_graph_degree) ? rev_graph_count[j] : output_graph_degree;
+    // Walk the stored reverse edges from last to first.
+    for (uint64_t _k = 0; _k < num_rev_edges; _k++) {
+      uint64_t k = num_rev_edges - 1 - _k;
+      uint64_t i = rev_graph_ptr[k + (output_graph_degree * j)];
+
+      uint64_t pos = pos_in_array(
+        i, output_graph_ptr + (output_graph_degree * j), output_graph_degree);
+      // Already among the protected leading edges: nothing to do.
+      if (pos < num_protected_edges) { continue; }
+      // Open a slot right after the protected edges; when i is not yet in the row the
+      // last edge of the row is dropped.
+      uint64_t num_shift = pos - num_protected_edges;
+      if (pos == output_graph_degree) { num_shift = output_graph_degree - num_protected_edges - 1; }
+      shift_array(output_graph_ptr + num_protected_edges + (output_graph_degree * j),
+                  num_shift);
+      output_graph_ptr[num_protected_edges + (output_graph_degree * j)] = i;
+    }
+    if ((omp_get_thread_num() == 0) && ((j % _omp_chunk) == 0)) {
+      RAFT_LOG_DEBUG("# Replacing reverse edges: %lu / %lu ", j, graph_size);
+    }
+  }
+  RAFT_LOG_DEBUG("\n");
+  free(rev_graph_ptr);
+  free(rev_graph_count);
+
+  double time_replace_end = cur_time();
+  RAFT_LOG_DEBUG("# Replacing edges time: %.1lf sec", time_replace_end - time_replace_start);
+
+  /* stats */
+  uint64_t num_replaced_edges = 0;
+#pragma omp parallel for reduction(+ : num_replaced_edges)
+  for (uint64_t i = 0; i < graph_size; i++) {
+    for (uint64_t k = 0; k < output_graph_degree; k++) {
+      uint64_t j = pruned_graph_ptr[k + (output_graph_degree * i)];
+      uint64_t pos = pos_in_array(
+        j, output_graph_ptr + (output_graph_degree * i), output_graph_degree);
+      if (pos == output_graph_degree) { num_replaced_edges += 1; }
+    }
+  }
+  fprintf(stderr,
+          "# Average number of replaced edges per node: %.2f",
+          (double)num_replaced_edges / graph_size);
+
+  // Release the work buffers of this function that are not freed anywhere above.
+  free(pruned_graph_ptr);
+  free(detour_count);
+  free(num_no_detour_edges);
+  free(dest_nodes);
+  for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) {
+    RAFT_CUDA_TRY(cudaSetDevice(i_gpu));
+    RAFT_CUDA_TRY(cudaFree(d_dest_nodes[i_gpu]));
+    RAFT_CUDA_TRY(cudaFree(dev_stats[i_gpu]));
+    free(host_stats[i_gpu]);
+  }
+  RAFT_CUDA_TRY(cudaSetDevice(0));
+  free(d_dest_nodes);
+  free(dev_stats);
+  free(host_stats);
+}
+
+} // namespace graph
+} // namespace raft::neighbors::experimental::cagra::detail
diff --git a/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp b/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp
new file mode 100644
index
0000000000..18f4006367 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "utils.hpp" +#include +#include + +// #pragma GCC diagnostic push +// #pragma GCC diagnostic ignored +// #pragma GCC diagnostic pop +namespace raft::neighbors::experimental::cagra::detail { +namespace hashmap { + +_RAFT_HOST_DEVICE inline uint32_t get_size(const uint32_t bitlen) { return 1U << bitlen; } + +template +_RAFT_DEVICE inline void init(uint32_t* table, const uint32_t bitlen) +{ + if (threadIdx.x < FIRST_TID) return; + for (unsigned i = threadIdx.x - FIRST_TID; i < get_size(bitlen); i += blockDim.x - FIRST_TID) { + table[i] = utils::get_max_value(); + } +} + +template +_RAFT_DEVICE inline void init(uint32_t* table, const uint32_t bitlen) +{ + if ((FIRST_TID > 0 && threadIdx.x < FIRST_TID) || threadIdx.x >= LAST_TID) return; + for (unsigned i = threadIdx.x - FIRST_TID; i < get_size(bitlen); i += LAST_TID - FIRST_TID) { + table[i] = utils::get_max_value(); + } +} + +_RAFT_DEVICE inline uint32_t insert(uint32_t* table, const uint32_t bitlen, const uint32_t key) +{ + // Open addressing is used for collision resolution + const uint32_t size = get_size(bitlen); + const uint32_t bit_mask = size - 1; +#if 1 + // Linear probing + uint32_t index = (key ^ (key >> bitlen)) & bit_mask; + constexpr uint32_t stride = 1; +#else + 
// Double hashing + uint32_t index = key & bit_mask; + const uint32_t stride = (key >> bitlen) * 2 + 1; +#endif + for (unsigned i = 0; i < size; i++) { + const uint32_t old = atomicCAS(&table[index], ~0u, key); + if (old == ~0u) { + return 1; + } else if (old == key) { + return 0; + } + index = (index + stride) & bit_mask; + } + return 0; +} + +template +_RAFT_DEVICE inline uint32_t insert(uint32_t* table, const uint32_t bitlen, const uint32_t key) +{ + uint32_t ret = 0; + if (threadIdx.x % TEAM_SIZE == 0) { ret = insert(table, bitlen, key); } + for (unsigned offset = 1; offset < TEAM_SIZE; offset *= 2) { + ret |= __shfl_xor_sync(0xffffffff, ret, offset); + } + return ret; +} + +} // namespace hashmap +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh new file mode 100644 index 0000000000..2c0ac98417 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -0,0 +1,632 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "bitonic.hpp" +#include "compute_distance.hpp" +#include "device_common.hpp" +#include "hashmap.hpp" +#include "search_plan.cuh" +#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk if possible +#include "utils.hpp" +#include +#include +#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp + +namespace raft::neighbors::experimental::cagra::detail { +namespace multi_cta_search { + +// #define _CLK_BREAKDOWN + +template +__device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num_parents] + const uint32_t num_parents, + INDEX_T* const itopk_indices, // [num_itopk] + const size_t num_itopk, + uint32_t* const terminate_flag) +{ + const unsigned warp_id = threadIdx.x / 32; + if (warp_id > 0) { return; } + const unsigned lane_id = threadIdx.x % 32; + for (uint32_t i = lane_id; i < num_parents; i += 32) { + next_parent_indices[i] = utils::get_max_value(); + } + uint32_t max_itopk = num_itopk; + if (max_itopk % 32) { max_itopk += 32 - (max_itopk % 32); } + uint32_t num_new_parents = 0; + for (uint32_t j = lane_id; j < max_itopk; j += 32) { + INDEX_T index; + int new_parent = 0; + if (j < num_itopk) { + index = itopk_indices[j]; + if ((index & 0x80000000) == 0) { // check if most significant bit is set + new_parent = 1; + } + } + const uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); + if (new_parent) { + const auto i = __popc(ballot_mask & ((1 << lane_id) - 1)) + num_new_parents; + if (i < num_parents) { + next_parent_indices[i] = index; + itopk_indices[j] |= 0x80000000; // set most significant bit as used node + } + } + num_new_parents += __popc(ballot_mask); + if (num_new_parents >= num_parents) { break; } + } + if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } +} + +template +__device__ inline void topk_by_bitonic_sort(float* 
distances, // [num_elements] + uint32_t* indices, // [num_elements] + const uint32_t num_elements, + const uint32_t num_itopk // num_itopk <= num_elements +) +{ + const unsigned warp_id = threadIdx.x / 32; + if (warp_id > 0) { return; } + const unsigned lane_id = threadIdx.x % 32; + constexpr unsigned N = (MAX_ELEMENTS + 31) / 32; + float key[N]; + uint32_t val[N]; + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (32 * i); + if (j < num_elements) { + key[i] = distances[j]; + val[i] = indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Warp Sort */ + bitonic::warp_sort(key, val); + /* Store itopk sorted results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_itopk) { + distances[j] = key[i]; + indices[j] = val[i]; + } + } +} + +// +// multiple CTAs per single query +// +template +__launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( + INDEX_T* const result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] + DISTANCE_T* const result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] + const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] + const size_t dataset_dim, + const size_t dataset_size, + const DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const INDEX_T* const knn_graph, // [dataset_size, graph_degree] + const uint32_t graph_degree, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const uint32_t hash_bitlen, + const uint32_t itopk_size, + const uint32_t num_parents, + const uint32_t min_iteration, + const uint32_t max_iteration, + uint32_t* const num_executed_iterations /* stats */ +) +{ + assert(blockDim.x == BLOCK_SIZE); + assert(dataset_dim <= MAX_DATASET_DIM); + + // const auto num_queries = gridDim.y; + const 
auto query_id = blockIdx.y; + const auto num_cta_per_query = gridDim.x; + const auto cta_id = blockIdx.x; // local CTA ID + +#ifdef _CLK_BREAKDOWN + uint64_t clk_init = 0; + uint64_t clk_compute_1st_distance = 0; + uint64_t clk_topk = 0; + uint64_t clk_pickup_parents = 0; + uint64_t clk_compute_distance = 0; + uint64_t clk_start; +#define _CLK_START() clk_start = clock64() +#define _CLK_REC(V) V += clock64() - clk_start; +#else +#define _CLK_START() +#define _CLK_REC(V) +#endif + _CLK_START(); + + extern __shared__ uint32_t smem[]; + + // Layout of result_buffer + // +----------------+------------------------------+---------+ + // | internal_top_k | neighbors of parent nodes | padding | + // | | | upto 32 | + // +----------------+------------------------------+---------+ + // |<--- result_buffer_size --->| + uint32_t result_buffer_size = itopk_size + (num_parents * graph_degree); + uint32_t result_buffer_size_32 = result_buffer_size; + if (result_buffer_size % 32) { result_buffer_size_32 += 32 - (result_buffer_size % 32); } + assert(result_buffer_size_32 <= MAX_ELEMENTS); + + auto query_buffer = reinterpret_cast(smem); + auto result_indices_buffer = reinterpret_cast(query_buffer + MAX_DATASET_DIM); + auto result_distances_buffer = + reinterpret_cast(result_indices_buffer + result_buffer_size_32); + auto parent_indices_buffer = + reinterpret_cast(result_distances_buffer + result_buffer_size_32); + auto terminate_flag = reinterpret_cast(parent_indices_buffer + num_parents); + +#if 0 + /* debug */ + for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += BLOCK_SIZE) { + result_indices_buffer[i] = utils::get_max_value(); + result_distances_buffer[i] = utils::get_max_value(); + } +#endif + + const DATA_T* const query_ptr = queries_ptr + (dataset_dim * query_id); + for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += BLOCK_SIZE) { + unsigned j = device::swizzling(i); + if (i < dataset_dim) { + query_buffer[j] = static_cast(query_ptr[i]) * 
device::fragment_scale(); + } else { + query_buffer[j] = 0.0; + } + } + if (threadIdx.x == 0) { terminate_flag[0] = 0; } + uint32_t* local_visited_hashmap_ptr = + visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); + __syncthreads(); + _CLK_REC(clk_init); + + // compute distance to randomly selecting nodes + _CLK_START(); + const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; + device::compute_distance_to_random_nodes( + result_indices_buffer, + result_distances_buffer, + query_buffer, + dataset_ptr, + dataset_dim, + dataset_size, + result_buffer_size, + num_distilation, + rand_xor_mask, + local_seed_ptr, + num_seeds, + local_visited_hashmap_ptr, + hash_bitlen, + cta_id, + num_cta_per_query); + __syncthreads(); + _CLK_REC(clk_compute_1st_distance); + + uint32_t iter = 0; + while (1) { + // topk with bitonic sort + _CLK_START(); + topk_by_bitonic_sort(result_distances_buffer, + result_indices_buffer, + itopk_size + (num_parents * graph_degree), + itopk_size); + _CLK_REC(clk_topk); + + if (iter + 1 == max_iteration) { + __syncthreads(); + break; + } + + // pick up next parents + _CLK_START(); + pickup_next_parents( + parent_indices_buffer, num_parents, result_indices_buffer, itopk_size, terminate_flag); + _CLK_REC(clk_pickup_parents); + + __syncthreads(); + if (*terminate_flag && iter >= min_iteration) { break; } + + // compute the norms between child nodes and query node + _CLK_START(); + // constexpr unsigned max_n_frags = 16; + constexpr unsigned max_n_frags = 0; + device:: + compute_distance_to_child_nodes( + result_indices_buffer + itopk_size, + result_distances_buffer + itopk_size, + query_buffer, + dataset_ptr, + dataset_dim, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + hash_bitlen, + parent_indices_buffer, + num_parents); + _CLK_REC(clk_compute_distance); + __syncthreads(); + + iter++; + } + + for (uint32_t i = threadIdx.x; i < itopk_size; i += BLOCK_SIZE) { + uint32_t j = i + (itopk_size 
* (cta_id + (num_cta_per_query * query_id))); + if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; } + result_indices_ptr[j] = result_indices_buffer[i] & ~0x80000000; // clear most significant bit + } + + if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) { + num_executed_iterations[query_id] = iter + 1; + } + +#ifdef _CLK_BREAKDOWN + if ((threadIdx.x == 0 || threadIdx.x == BLOCK_SIZE - 1) && (blockIdx.x == 0) && + ((query_id * 3) % gridDim.y < 3)) { + RAFT_LOG_DEBUG( + "query, %d, thread, %d" + ", init, %d" + ", 1st_distance, %lu" + ", topk, %lu" + ", pickup_parents, %lu" + ", distance, %lu" + "\n", + query_id, + threadIdx.x, + clk_init, + clk_compute_1st_distance, + clk_topk, + clk_pickup_parents, + clk_compute_distance); + } +#endif +} + +#define SET_MC_KERNEL_3(BLOCK_SIZE, BLOCK_COUNT, MAX_ELEMENTS, LOAD_T) \ + kernel = search_kernel; + +#define SET_MC_KERNEL_2(BLOCK_SIZE, BLOCK_COUNT, MAX_ELEMENTS) \ + if (load_bit_length == 128) { \ + SET_MC_KERNEL_3(BLOCK_SIZE, BLOCK_COUNT, MAX_ELEMENTS, device::LOAD_128BIT_T) \ + } else if (load_bit_length == 64) { \ + SET_MC_KERNEL_3(BLOCK_SIZE, BLOCK_COUNT, MAX_ELEMENTS, device::LOAD_64BIT_T) \ + } + +#define SET_MC_KERNEL_1(MAX_ELEMENTS) \ + /* if ( block_size == 32 ) { \ + SET_MC_KERNEL_2( 32, 32, MAX_ELEMENTS ) \ + } else */ \ + if (block_size == 64) { \ + SET_MC_KERNEL_2(64, 16, MAX_ELEMENTS) \ + } else if (block_size == 128) { \ + SET_MC_KERNEL_2(128, 8, MAX_ELEMENTS) \ + } else if (block_size == 256) { \ + SET_MC_KERNEL_2(256, 4, MAX_ELEMENTS) \ + } else if (block_size == 512) { \ + SET_MC_KERNEL_2(512, 2, MAX_ELEMENTS) \ + } else { \ + SET_MC_KERNEL_2(1024, 1, MAX_ELEMENTS) \ + } + +#define SET_MC_KERNEL \ + typedef void (*search_kernel_t)(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const DATA_T* const queries_ptr, 
\ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + uint32_t* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); \ + search_kernel_t kernel; \ + if (result_buffer_size <= 64) { \ + SET_MC_KERNEL_1(64) \ + } else if (result_buffer_size <= 128) { \ + SET_MC_KERNEL_1(128) \ + } else if (result_buffer_size <= 256) { \ + SET_MC_KERNEL_1(256) \ + } + +template +__global__ void set_value_batch_kernel(T* const dev_ptr, + const std::size_t ld, + const T val, + const std::size_t count, + const std::size_t batch_size) +{ + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= count * batch_size) { return; } + const auto batch_id = tid / count; + const auto elem_id = tid % count; + dev_ptr[elem_id + ld * batch_id] = val; +} + +template +void set_value_batch(T* const dev_ptr, + const std::size_t ld, + const T val, + const std::size_t count, + const std::size_t batch_size, + cudaStream_t cuda_stream) +{ + constexpr std::uint32_t block_size = 256; + const auto grid_size = (count * batch_size + block_size - 1) / block_size; + set_value_batch_kernel + <<>>(dev_ptr, ld, val, count, batch_size); +} + +template + +struct search : public search_plan_impl { + using search_plan_impl::max_queries; + using search_plan_impl::itopk_size; + using search_plan_impl::algo; + using search_plan_impl::team_size; + using search_plan_impl::num_parents; + using search_plan_impl::min_iterations; + using search_plan_impl::max_iterations; + using search_plan_impl::load_bit_length; + using search_plan_impl::thread_block_size; + using search_plan_impl::hashmap_mode; + using search_plan_impl::hashmap_min_bitlen; + using search_plan_impl::hashmap_max_fill_rate; + 
using search_plan_impl::num_random_samplings; + using search_plan_impl::rand_xor_mask; + + using search_plan_impl::max_dim; + using search_plan_impl::dim; + using search_plan_impl::graph_degree; + using search_plan_impl::topk; + + using search_plan_impl::hash_bitlen; + + using search_plan_impl::small_hash_bitlen; + using search_plan_impl::small_hash_reset_interval; + using search_plan_impl::hashmap_size; + using search_plan_impl::dataset_size; + using search_plan_impl::result_buffer_size; + + using search_plan_impl::smem_size; + using search_plan_impl::load_bit_lenght; + + using search_plan_impl::hashmap; + using search_plan_impl::num_executed_iterations; + using search_plan_impl::dev_seed; + using search_plan_impl::num_seeds; + + uint32_t num_cta_per_query; + rmm::device_uvector intermediate_indices; + rmm::device_uvector intermediate_distances; + size_t topk_workspace_size; + rmm::device_uvector topk_workspace; + + search(raft::device_resources const& res, + search_params params, + int64_t dim, + int64_t graph_degree, + uint32_t topk) + : search_plan_impl(res, params, dim, graph_degree, topk), + intermediate_indices(0, res.get_stream()), + intermediate_distances(0, res.get_stream()), + topk_workspace(0, res.get_stream()) + + { + set_params(res); + } + + void set_params(raft::device_resources const& res) + { + this->itopk_size = 32; + num_parents = 1; + num_cta_per_query = max(num_parents, itopk_size / 32); + result_buffer_size = itopk_size + num_parents * graph_degree; + typedef raft::Pow2<32> AlignBytes; + unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); + // constexpr unsigned max_result_buffer_size = 256; + RAFT_EXPECTS(result_buffer_size_32 <= 256, "Result buffer size cannot exceed 256"); + + smem_size = sizeof(float) * max_dim + + (sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 + + sizeof(uint32_t) * num_parents + sizeof(uint32_t); + RAFT_LOG_DEBUG("# smem_size: %u", smem_size); + + // + // Determine the thread block 
size + // + constexpr unsigned min_block_size = 64; + constexpr unsigned max_block_size = 1024; + uint32_t block_size = thread_block_size; + if (block_size == 0) { + block_size = min_block_size; + + // Increase block size according to shared memory requirements. + // If block size is 32, upper limit of shared memory size per + // thread block is set to 4096. This is GPU generation dependent. + constexpr unsigned ulimit_smem_size_cta32 = 4096; + while (smem_size > ulimit_smem_size_cta32 / 32 * block_size) { + block_size *= 2; + } + + // Increase block size to improve GPU occupancy when total number of + // CTAs (= num_cta_per_query * max_queries) is small. + cudaDeviceProp deviceProp = res.get_device_properties(); + RAFT_LOG_DEBUG("# multiProcessorCount: %d", deviceProp.multiProcessorCount); + while ((block_size < max_block_size) && + (graph_degree * num_parents * team_size >= block_size * 2) && + (num_cta_per_query * max_queries <= + (1024 / (block_size * 2)) * deviceProp.multiProcessorCount)) { + block_size *= 2; + } + } + RAFT_LOG_DEBUG("# thread_block_size: %u", block_size); + RAFT_EXPECTS(block_size >= min_block_size, + "block_size cannot be smaller than min_block size, %u", + min_block_size); + RAFT_EXPECTS(block_size <= max_block_size, + "block_size cannot be larger than max_block size %u", + max_block_size); + thread_block_size = block_size; + + // + // Determine load bit length + // + const uint32_t total_bit_length = dim * sizeof(DATA_T) * 8; + if (load_bit_length == 0) { + load_bit_length = 128; + while (total_bit_length % load_bit_length) { + load_bit_length /= 2; + } + } + RAFT_LOG_DEBUG("# load_bit_length: %u (%u loads per vector)", + load_bit_length, + total_bit_length / load_bit_length); + RAFT_EXPECTS(total_bit_length % load_bit_length == 0, + "load_bit_length must be a divisor of dim*sizeof(data_t)*8=%u", + total_bit_length); + RAFT_EXPECTS(load_bit_length >= 64, "load_bit_lenght cannot be less than 64"); + + // + // Allocate memory for 
intermediate buffer and workspace. + // + uint32_t num_intermediate_results = num_cta_per_query * itopk_size; + intermediate_indices.resize(num_intermediate_results, res.get_stream()); + intermediate_distances.resize(num_intermediate_results, res.get_stream()); + + hashmap.resize(hashmap_size, res.get_stream()); + + topk_workspace_size = _cuann_find_topk_bufferSize( + topk, max_queries, num_intermediate_results, utils::get_cuda_data_type()); + RAFT_LOG_DEBUG("# topk_workspace_size: %lu", topk_workspace_size); + topk_workspace.resize(topk_workspace_size, res.get_stream()); + } + + ~search() {} + + void operator()(raft::device_resources const& res, + raft::device_matrix_view dataset, + raft::device_matrix_view graph, + INDEX_T* const topk_indices_ptr, // [num_queries, topk] + DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] + const DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const uint32_t num_queries, + const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* const num_executed_iterations, // [num_queries,] + uint32_t topk) + { + cudaStream_t stream = res.get_stream(); + uint32_t block_size = thread_block_size; + + SET_MC_KERNEL; + RAFT_CUDA_TRY( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + // Initialize hash table + const uint32_t hash_size = hashmap::get_size(hash_bitlen); + set_value_batch( + hashmap.data(), hash_size, utils::get_max_value(), hash_size, num_queries, stream); + + dim3 block_dims(block_size, 1, 1); + dim3 grid_dims(num_cta_per_query, num_queries, 1); + RAFT_LOG_DEBUG("Launching kernel with %u threads, (%u, %u) blocks %lu smem", + block_size, + num_cta_per_query, + num_queries, + smem_size); + kernel<<>>(intermediate_indices.data(), + intermediate_distances.data(), + dataset.data_handle(), + dataset.extent(1), + dataset.extent(0), + queries_ptr, + graph.data_handle(), + graph.extent(1), + num_random_samplings, + rand_xor_mask, + dev_seed_ptr, + num_seeds, + 
hashmap.data(), + hash_bitlen, + itopk_size, + num_parents, + min_iterations, + max_iterations, + num_executed_iterations); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + + // Select the top-k results from the intermediate results + const uint32_t num_intermediate_results = num_cta_per_query * itopk_size; + _cuann_find_topk(topk, + num_queries, + num_intermediate_results, + intermediate_distances.data(), + num_intermediate_results, + intermediate_indices.data(), + num_intermediate_results, + topk_distances_ptr, + topk, + topk_indices_ptr, + topk, + topk_workspace.data(), + true, + NULL, + stream); + } +}; + +} // namespace multi_cta_search +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh new file mode 100644 index 0000000000..f688941239 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -0,0 +1,721 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "compute_distance.hpp" +#include "device_common.hpp" +#include "fragment.hpp" +#include "hashmap.hpp" +#include "search_plan.cuh" +#include "topk_for_cagra/topk_core.cuh" //todo replace with raft kernel +#include "utils.hpp" +#include +#include +#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp + +namespace raft::neighbors::experimental::cagra::detail { +namespace multi_kernel_search { + +template +__global__ void set_value_kernel(T* const dev_ptr, const T val) +{ + *dev_ptr = val; +} + +template +__global__ void set_value_kernel(T* const dev_ptr, const T val, const std::size_t count) +{ + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= count) { return; } + dev_ptr[tid] = val; +} + +template +void set_value(T* const dev_ptr, const T val, cudaStream_t cuda_stream) +{ + set_value_kernel<<<1, 1, 0, cuda_stream>>>(dev_ptr, val); +} + +template +void set_value(T* const dev_ptr, const T val, const std::size_t count, cudaStream_t cuda_stream) +{ + constexpr std::uint32_t block_size = 256; + const auto grid_size = (count + block_size - 1) / block_size; + set_value_kernel<<>>(dev_ptr, val, count); +} + +template +__global__ void get_value_kernel(T* const host_ptr, const T* const dev_ptr) +{ + *host_ptr = *dev_ptr; +} + +template +void get_value(T* const host_ptr, const T* const dev_ptr, cudaStream_t cuda_stream) +{ + get_value_kernel<<<1, 1, 0, cuda_stream>>>(host_ptr, dev_ptr); +} + +// MAX_DATASET_DIM : must equal to or greater than dataset_dim +template +__global__ void random_pickup_kernel( + const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] + const std::size_t dataset_dim, + const std::size_t dataset_size, + const DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const std::size_t num_pickup, + const unsigned num_distilation, + const uint64_t 
rand_xor_mask, + const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + INDEX_T* const result_indices_ptr, // [num_queries, ldr] + DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] + const std::uint32_t ldr, // (*) ldr >= num_pickup + std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] + const std::uint32_t hash_bitlen) +{ + const auto ldb = hashmap::get_size(hash_bitlen); + const auto global_team_index = (blockIdx.x * blockDim.x + threadIdx.x) / TEAM_SIZE; + const uint32_t query_id = blockIdx.y; + if (global_team_index >= num_pickup) { return; } + // Load a query + device::fragment query_frag; + device::load_vector_sync(query_frag, queries_ptr + query_id * dataset_dim, dataset_dim); + + INDEX_T best_index_team_local; + DISTANCE_T best_norm2_team_local = utils::get_max_value(); + for (unsigned i = 0; i < num_distilation; i++) { + INDEX_T seed_index; + if (seed_ptr && (global_team_index < num_seeds)) { + seed_index = seed_ptr[global_team_index + (num_seeds * query_id)]; + } else { + // Chose a seed node randomly + seed_index = device::xorshift64((global_team_index ^ rand_xor_mask) * (i + 1)) % dataset_size; + } + device::fragment random_data_frag; + device::load_vector_sync( + random_data_frag, dataset_ptr + (dataset_dim * seed_index), dataset_dim); + + // Compute the norm of two data + const auto norm2 = + device::norm2(query_frag, random_data_frag, device::fragment_scale() + /*, scale*/ + ); + + if (norm2 < best_norm2_team_local) { + best_norm2_team_local = norm2; + best_index_team_local = seed_index; + } + } + + const auto store_gmem_index = global_team_index + (ldr * query_id); + if (threadIdx.x % TEAM_SIZE == 0) { + if (hashmap::insert( + visited_hashmap_ptr + (ldb * query_id), hash_bitlen, best_index_team_local)) { + result_distances_ptr[store_gmem_index] = best_norm2_team_local; + result_indices_ptr[store_gmem_index] = best_index_team_local; + } else { + result_distances_ptr[store_gmem_index] = 
utils::get_max_value(); + result_indices_ptr[store_gmem_index] = utils::get_max_value(); + } + } +} + +// MAX_DATASET_DIM : must be equal to or greater than dataset_dim +template +void random_pickup(const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] + const std::size_t dataset_dim, + const std::size_t dataset_size, + const DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const std::size_t num_queries, + const std::size_t num_pickup, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + INDEX_T* const result_indices_ptr, // [num_queries, ldr] + DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] + const std::size_t ldr, // (*) ldr >= num_pickup + std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] + const std::uint32_t hash_bitlen, + cudaStream_t const cuda_stream = 0) +{ + const auto block_size = 256u; + const auto num_teams_per_threadblock = block_size / TEAM_SIZE; + const dim3 grid_size((num_pickup + num_teams_per_threadblock - 1) / num_teams_per_threadblock, + num_queries); + + random_pickup_kernel + <<>>(dataset_ptr, + dataset_dim, + dataset_size, + queries_ptr, + num_pickup, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + result_indices_ptr, + result_distances_ptr, + ldr, + visited_hashmap_ptr, + hash_bitlen); +} + +template +__global__ void pickup_next_parents_kernel( + INDEX_T* const parent_candidates_ptr, // [num_queries, lds] + const std::size_t lds, // (*) lds >= parent_candidates_size + const std::uint32_t parent_candidates_size, // + std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::size_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + INDEX_T* const parent_list_ptr, // [num_queries, ldd] + const std::size_t ldd, // (*) ldd >= parent_list_size + const std::size_t parent_list_size, // + std::uint32_t* const terminate_flag) +{ + const 
std::size_t ldb = hashmap::get_size(hash_bitlen); + const uint32_t query_id = blockIdx.x; + if (threadIdx.x < 32) { + // pickup next parents with single warp + for (std::uint32_t i = threadIdx.x; i < parent_list_size; i += 32) { + parent_list_ptr[i + (ldd * query_id)] = utils::get_max_value(); + } + std::uint32_t parent_candidates_size_max = parent_candidates_size; + if (parent_candidates_size % 32) { + parent_candidates_size_max += 32 - (parent_candidates_size % 32); + } + std::uint32_t num_new_parents = 0; + for (std::uint32_t j = threadIdx.x; j < parent_candidates_size_max; j += 32) { + INDEX_T index; + int new_parent = 0; + if (j < parent_candidates_size) { + index = parent_candidates_ptr[j + (lds * query_id)]; + if ((index & 0x80000000) == 0) { // check most significant bit + new_parent = 1; + } + } + const std::uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); + if (new_parent) { + const auto i = __popc(ballot_mask & ((1 << threadIdx.x) - 1)) + num_new_parents; + if (i < parent_list_size) { + parent_list_ptr[i + (ldd * query_id)] = index; + parent_candidates_ptr[j + (lds * query_id)] |= + 0x80000000; // set most significant bit as used node + } + } + num_new_parents += __popc(ballot_mask); + if (num_new_parents >= parent_list_size) { break; } + } + if ((num_new_parents > 0) && (threadIdx.x == 0)) { *terminate_flag = 0; } + } else if (small_hash_bitlen) { + // reset small-hash + hashmap::init<32>(visited_hashmap_ptr + (ldb * query_id), hash_bitlen); + } + + if (small_hash_bitlen) { + __syncthreads(); + // insert internal-topk indices into small-hash + for (unsigned i = threadIdx.x; i < parent_candidates_size; i += blockDim.x) { + auto key = + parent_candidates_ptr[i + (lds * query_id)] & ~0x80000000; // clear most significant bit + hashmap::insert(visited_hashmap_ptr + (ldb * query_id), hash_bitlen, key); + } + } +} + +template +void pickup_next_parents( + INDEX_T* const parent_candidates_ptr, // [num_queries, lds] + const std::size_t lds, // (*) 
lds >= parent_candidates_size + const std::size_t parent_candidates_size, // + const std::size_t num_queries, + std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::size_t hash_bitlen, + const std::size_t small_hash_bitlen, + INDEX_T* const parent_list_ptr, // [num_queries, ldd] + const std::size_t ldd, // (*) ldd >= parent_list_size + const std::size_t parent_list_size, // + std::uint32_t* const terminate_flag, + cudaStream_t cuda_stream = 0) +{ + std::uint32_t block_size = 32; + if (small_hash_bitlen) { + block_size = 128; + while (parent_candidates_size > block_size) { + block_size *= 2; + } + block_size = min(block_size, (uint32_t)512); + } + pickup_next_parents_kernel + <<>>(parent_candidates_ptr, + lds, + parent_candidates_size, + visited_hashmap_ptr, + hash_bitlen, + small_hash_bitlen, + parent_list_ptr, + ldd, + parent_list_size, + terminate_flag); +} + +template +__global__ void compute_distance_to_child_nodes_kernel( + const INDEX_T* const parent_node_list, // [num_queries, num_parents] + const std::uint32_t num_parents, + const DATA_T* const dataset_ptr, // [dataset_size, data_dim] + const std::uint32_t data_dim, + const std::uint32_t dataset_size, + const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] + const std::uint32_t graph_degree, + const DATA_T* query_ptr, // [num_queries, data_dim] + std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t hash_bitlen, + INDEX_T* const result_indices_ptr, // [num_queries, ldd] + DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] + const std::uint32_t ldd // (*) ldd >= num_parents * graph_degree +) +{ + const uint32_t ldb = hashmap::get_size(hash_bitlen); + const auto tid = threadIdx.x + blockDim.x * blockIdx.x; + const auto global_team_id = tid / TEAM_SIZE; + if (global_team_id >= num_parents * graph_degree) { return; } + + const std::size_t parent_index = + parent_node_list[global_team_id / graph_degree 
+ (num_parents * blockIdx.y)]; + if (parent_index == utils::get_max_value()) { + result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); + return; + } + const auto neighbor_list_head_ptr = neighbor_graph_ptr + (graph_degree * parent_index); + + const std::size_t child_id = neighbor_list_head_ptr[global_team_id % graph_degree]; + + if (hashmap::insert(visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id)) { + device::fragment frag_target; + device::load_vector_sync(frag_target, dataset_ptr + (data_dim * child_id), data_dim); + + device::fragment frag_query; + device::load_vector_sync(frag_query, query_ptr + blockIdx.y * data_dim, data_dim); + + const auto norm2 = + device::norm2(frag_target, frag_query, device::fragment_scale()); + + if (threadIdx.x % TEAM_SIZE == 0) { + result_indices_ptr[ldd * blockIdx.y + global_team_id] = child_id; + result_distances_ptr[ldd * blockIdx.y + global_team_id] = norm2; + } + } else { + if (threadIdx.x % TEAM_SIZE == 0) { + result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); + } + } +} + +template +void compute_distance_to_child_nodes( + const INDEX_T* const parent_node_list, // [num_queries, num_parents] + const uint32_t num_parents, + const DATA_T* const dataset_ptr, // [dataset_size, data_dim] + const std::uint32_t data_dim, + const std::uint32_t dataset_size, + const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] + const std::uint32_t graph_degree, + const DATA_T* query_ptr, // [num_queries, data_dim] + const std::uint32_t num_queries, + std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t hash_bitlen, + INDEX_T* const result_indices_ptr, // [num_queries, ldd] + DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] + const std::uint32_t ldd, // (*) ldd >= num_parants * graph_degree + cudaStream_t cuda_stream = 0) +{ + const auto block_size = 128; + const dim3 grid_size( + (num_parents * 
graph_degree + (block_size / TEAM_SIZE) - 1) / (block_size / TEAM_SIZE), + num_queries); + compute_distance_to_child_nodes_kernel + <<>>(parent_node_list, + num_parents, + dataset_ptr, + data_dim, + dataset_size, + neighbor_graph_ptr, + graph_degree, + query_ptr, + visited_hashmap_ptr, + hash_bitlen, + result_indices_ptr, + result_distances_ptr, + ldd); +} + +template +__global__ void remove_parent_bit_kernel(const std::uint32_t num_queries, + const std::uint32_t num_topk, + INDEX_T* const topk_indices_ptr, // [ld, num_queries] + const std::uint32_t ld) +{ + uint32_t i_query = blockIdx.x; + if (i_query >= num_queries) return; + + for (unsigned i = threadIdx.x; i < num_topk; i += blockDim.x) { + topk_indices_ptr[i + (ld * i_query)] &= ~0x80000000; // clear most significant bit + } +} + +template +void remove_parent_bit(const std::uint32_t num_queries, + const std::uint32_t num_topk, + INDEX_T* const topk_indices_ptr, // [ld, num_queries] + const std::uint32_t ld, + cudaStream_t cuda_stream = 0) +{ + const std::size_t grid_size = num_queries; + const std::size_t block_size = 256; + remove_parent_bit_kernel<<>>( + num_queries, num_topk, topk_indices_ptr, ld); +} + +template +__global__ void batched_memcpy_kernel(T* const dst, // [batch_size, ld_dst] + const uint64_t ld_dst, + const T* const src, // [batch_size, ld_src] + const uint64_t ld_src, + const uint64_t count, + const uint64_t batch_size) +{ + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= count * batch_size) { return; } + const auto i = tid % count; + const auto j = tid / count; + dst[i + (ld_dst * j)] = src[i + (ld_src * j)]; +} + +template +void batched_memcpy(T* const dst, // [batch_size, ld_dst] + const uint64_t ld_dst, + const T* const src, // [batch_size, ld_src] + const uint64_t ld_src, + const uint64_t count, + const uint64_t batch_size, + cudaStream_t cuda_stream) +{ + assert(ld_dst >= count); + assert(ld_src >= count); + constexpr uint32_t block_size = 256; + const auto 
grid_size = (batch_size * count + block_size - 1) / block_size; + batched_memcpy_kernel + <<>>(dst, ld_dst, src, ld_src, count, batch_size); +} + +template +__global__ void set_value_batch_kernel(T* const dev_ptr, + const std::size_t ld, + const T val, + const std::size_t count, + const std::size_t batch_size) +{ + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= count * batch_size) { return; } + const auto batch_id = tid / count; + const auto elem_id = tid % count; + dev_ptr[elem_id + ld * batch_id] = val; +} + +template +void set_value_batch(T* const dev_ptr, + const std::size_t ld, + const T val, + const std::size_t count, + const std::size_t batch_size, + cudaStream_t cuda_stream) +{ + constexpr std::uint32_t block_size = 256; + const auto grid_size = (count * batch_size + block_size - 1) / block_size; + set_value_batch_kernel + <<>>(dev_ptr, ld, val, count, batch_size); +} + +// result_buffer (work buffer) for "multi-kernel" +// +--------------------+------------------------------+-------------------+ +// | internal_top_k (A) | neighbors of internal_top_k | internal_topk (B) | +// | | | | +// +--------------------+------------------------------+-------------------+ +// |<--- result_buffer_allocation_size --->| +// |<--- result_buffer_size --->| // Double buffer (A) +// |<--- result_buffer_size --->| // Double buffer (B) +template +struct search : search_plan_impl { + using search_plan_impl::max_queries; + using search_plan_impl::itopk_size; + using search_plan_impl::algo; + using search_plan_impl::team_size; + using search_plan_impl::num_parents; + using search_plan_impl::min_iterations; + using search_plan_impl::max_iterations; + using search_plan_impl::load_bit_length; + using search_plan_impl::thread_block_size; + using search_plan_impl::hashmap_mode; + using search_plan_impl::hashmap_min_bitlen; + using search_plan_impl::hashmap_max_fill_rate; + using search_plan_impl::num_random_samplings; + using search_plan_impl::rand_xor_mask; + + 
using search_plan_impl::max_dim; + using search_plan_impl::dim; + using search_plan_impl::graph_degree; + using search_plan_impl::topk; + + using search_plan_impl::hash_bitlen; + + using search_plan_impl::small_hash_bitlen; + using search_plan_impl::small_hash_reset_interval; + using search_plan_impl::hashmap_size; + using search_plan_impl::dataset_size; + using search_plan_impl::result_buffer_size; + + using search_plan_impl::smem_size; + using search_plan_impl::load_bit_lenght; + + using search_plan_impl::hashmap; + using search_plan_impl::num_executed_iterations; + using search_plan_impl::dev_seed; + using search_plan_impl::num_seeds; + + size_t result_buffer_allocation_size; + rmm::device_uvector result_indices; // results_indices_buffer + rmm::device_uvector result_distances; // result_distances_buffer + rmm::device_uvector parent_node_list; + rmm::device_uvector topk_hint; + rmm::device_scalar terminate_flag; // dev_terminate_flag, host_terminate_flag.; + rmm::device_uvector topk_workspace; + + search(raft::device_resources const& res, + search_params params, + int64_t dim, + int64_t graph_degree, + uint32_t topk) + : search_plan_impl(res, params, dim, graph_degree, topk), + result_indices(0, res.get_stream()), + result_distances(0, res.get_stream()), + parent_node_list(0, res.get_stream()), + topk_hint(0, res.get_stream()), + topk_workspace(0, res.get_stream()), + terminate_flag(res.get_stream()) + { + set_params(res); + } + + void set_params(raft::device_resources const& res) + { + // + // Allocate memory for intermediate buffer and workspace. 
+ // + result_buffer_size = itopk_size + (num_parents * graph_degree); + result_buffer_allocation_size = result_buffer_size + itopk_size; + result_indices.resize(result_buffer_allocation_size * max_queries, res.get_stream()); + result_distances.resize(result_buffer_allocation_size * max_queries, res.get_stream()); + + parent_node_list.resize(max_queries * num_parents, res.get_stream()); + topk_hint.resize(max_queries, res.get_stream()); + + size_t topk_workspace_size = _cuann_find_topk_bufferSize( + itopk_size, max_queries, result_buffer_size, utils::get_cuda_data_type()); + RAFT_LOG_DEBUG("# topk_workspace_size: %lu", topk_workspace_size); + topk_workspace.resize(topk_workspace_size, res.get_stream()); + + hashmap.resize(hashmap_size, res.get_stream()); + } + + ~search() {} + + void operator()(raft::device_resources const& res, + raft::device_matrix_view dataset, + raft::device_matrix_view graph, + INDEX_T* const topk_indices_ptr, // [num_queries, topk] + DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] + const DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const uint32_t num_queries, + const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* const num_executed_iterations, // [num_queries,] + uint32_t topk) + { + // Init hashmap + cudaStream_t stream = res.get_stream(); + const uint32_t hash_size = hashmap::get_size(hash_bitlen); + set_value_batch( + hashmap.data(), hash_size, utils::get_max_value(), hash_size, num_queries, stream); + // Init topk_hint + if (topk_hint.size() > 0) { set_value(topk_hint.data(), 0xffffffffu, num_queries, stream); } + + // Choose initial entry point candidates at random + random_pickup( + dataset.data_handle(), + dataset.extent(1), + dataset.extent(0), + queries_ptr, + num_queries, + result_buffer_size, + num_random_samplings, + rand_xor_mask, + dev_seed_ptr, + num_seeds, + result_indices.data(), + result_distances.data(), + result_buffer_allocation_size, + hashmap.data(), + hash_bitlen, + stream); 
+ + unsigned iter = 0; + while (1) { + // Make an index list of internal top-k nodes + _cuann_find_topk(itopk_size, + num_queries, + result_buffer_size, + result_distances.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_indices.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size, + result_buffer_allocation_size, + result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, + result_buffer_allocation_size, + topk_workspace.data(), + true, + topk_hint.data(), + stream); + + // termination (1) + if ((iter + 1 == max_iterations)) { + iter++; + break; + } + + if (iter + 1 >= min_iterations) { set_value(terminate_flag.data(), 1, stream); } + + // pickup parent nodes + uint32_t _small_hash_bitlen = 0; + if ((iter + 1) % small_hash_reset_interval == 0) { _small_hash_bitlen = small_hash_bitlen; } + pickup_next_parents(result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, + result_buffer_allocation_size, + itopk_size, + num_queries, + hashmap.data(), + hash_bitlen, + _small_hash_bitlen, + parent_node_list.data(), + num_parents, + num_parents, + terminate_flag.data(), + stream); + + // termination (2) + if (iter + 1 >= min_iterations && terminate_flag.value(stream)) { + iter++; + break; + } + + // Compute distance to child nodes that are adjacent to the parent node + compute_distance_to_child_nodes( + parent_node_list.data(), + num_parents, + dataset.data_handle(), + dataset.extent(1), + dataset.extent(0), + graph.data_handle(), + graph.extent(1), + queries_ptr, + num_queries, + hashmap.data(), + hash_bitlen, + result_indices.data() + itopk_size, + result_distances.data() + itopk_size, + result_buffer_allocation_size, + stream); + + iter++; + } // while ( 1 ) + + // Remove parent bit in search results + remove_parent_bit(num_queries, + itopk_size, + result_indices.data() + (iter & 0x1) * result_buffer_size, + result_buffer_allocation_size, + 
stream); + + // Copy results from working buffer to final buffer (`topk` entries per query; destination rows use `topk` as their stride) + batched_memcpy(topk_indices_ptr, + topk, + result_indices.data() + (iter & 0x1) * result_buffer_size, + result_buffer_allocation_size, + topk, + num_queries, + stream); + if (topk_distances_ptr) { // distances are copied only when an output buffer was supplied + batched_memcpy(topk_distances_ptr, + topk, + result_distances.data() + (iter & 0x1) * result_buffer_size, + result_buffer_allocation_size, + topk, + num_queries, + stream); + } + + if (num_executed_iterations) { // NOTE(review): filled by the host-side loop below - confirm callers never pass device-only memory here + for (std::uint32_t i = 0; i < num_queries; i++) { + num_executed_iterations[i] = iter; + } + } + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } +}; + +} // namespace multi_kernel_search +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh new file mode 100644 index 0000000000..d9613b345c --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -0,0 +1,334 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "hashmap.hpp" +// #include "search_single_cta.cuh" +// #include "topk_for_cagra/topk_core.cuh" + +#include +#include +#include +#include + +namespace raft::neighbors::experimental::cagra::detail { + +struct search_plan_impl_base : public search_params { // search_params plus the problem shape (dim, graph_degree, topk) and the derived max_dim / team_size + int64_t max_dim; // power-of-two bucket for `dim`, starting at 128; ends up as 2048 when `dim` exceeds 1024 + int64_t dim; + int64_t graph_degree; + uint32_t topk; + search_plan_impl_base(search_params params, int64_t dim, int64_t graph_degree, uint32_t topk) + : search_params(params), dim(dim), graph_degree(graph_degree), topk(topk) + { + set_max_dim_team(dim); + if (algo == search_algo::AUTO) { // resolve AUTO from the internal top-k size + if (itopk_size <= 512) { + algo = search_algo::SINGLE_CTA; + RAFT_LOG_DEBUG("Auto strategy: selecting single-cta"); + } else { + algo = search_algo::MULTI_KERNEL; + RAFT_LOG_DEBUG("Auto strategy: selecting multi-kernel"); + } + } + } + + void set_max_dim_team(int64_t dim) // derives max_dim from `dim` and overwrites team_size accordingly + { + max_dim = 128; + while (max_dim < dim && max_dim <= 1024) + max_dim *= 2; + if (team_size != 0) { RAFT_LOG_WARN("Overriding team size parameter."); } + // To keep binary size in check we limit only one team size specialization for each max_dim. + // TODO(tfeher): revise this decision. + switch (max_dim) { + case 128: team_size = 8; break; + case 256: team_size = 16; break; + case 512: team_size = 32; break; + case 1024: team_size = 32; break; + default: RAFT_LOG_DEBUG("Dataset dimension is too large (%lu)\n", dim); // reached when dim > 1024; team_size is left unchanged + } + } +}; + +template +struct search_plan_impl : public search_plan_impl_base { + int64_t hash_bitlen; + + size_t small_hash_bitlen; + size_t small_hash_reset_interval; + size_t hashmap_size; + uint32_t dataset_size; + uint32_t result_buffer_size; + + uint32_t smem_size; + uint32_t load_bit_lenght; + // `topk` is intentionally not redeclared here: it is inherited (already initialized) from search_plan_impl_base; a second member would shadow it and stay uninitialized. + uint32_t num_seeds; + + rmm::device_uvector hashmap; + rmm::device_uvector num_executed_iterations; // device or managed? 
+ rmm::device_uvector dev_seed; // IdxT + + search_plan_impl(raft::device_resources const& res, + search_params params, + int64_t dim, + int64_t graph_degree, + uint32_t topk) + : search_plan_impl_base(params, dim, graph_degree, topk), + hashmap(0, res.get_stream()), + num_executed_iterations(0, res.get_stream()), + dev_seed(0, res.get_stream()), + num_seeds(0) + { + adjust_search_params(); + check_params(); + calc_hashmap_params(res); + // set_max_dim_team(dim) is not repeated here: the search_plan_impl_base constructor already ran it, and a second call only re-emitted the "Overriding team size parameter." warning. + num_executed_iterations.resize(max_queries, res.get_stream()); + RAFT_LOG_DEBUG("# algo = %d", static_cast(algo)); + } + + virtual ~search_plan_impl() {} + + virtual void operator()(raft::device_resources const& res, + raft::device_matrix_view dataset, + raft::device_matrix_view graph, + INDEX_T* const result_indices_ptr, // [num_queries, topk] + DISTANCE_T* const result_distances_ptr, // [num_queries, topk] + const DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const std::uint32_t num_queries, + const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] + std::uint32_t* const num_executed_iterations, // [num_queries] + uint32_t topk){}; + + void adjust_search_params() // resolves max_iterations (0 means "choose automatically") and rounds itopk_size up to a multiple of 32 + { + uint32_t _max_iterations = max_iterations; + if (max_iterations == 0) { + if (algo == search_algo::MULTI_CTA) { + _max_iterations = 1 + std::min(32 * 1.1, 32 + 10.0); // TODO(anaruse) + } else { + _max_iterations = + 1 + std::min((itopk_size / num_parents) * 1.1, (itopk_size / num_parents) + 10.0); + } + } + if (_max_iterations < min_iterations) { _max_iterations = min_iterations; } // compare the resolved value so an auto-selected default is only raised to min_iterations, never replaced by it + if (max_iterations < _max_iterations) { + RAFT_LOG_DEBUG( + "# max_iterations is increased from %u to %u.", max_iterations, _max_iterations); + max_iterations = _max_iterations; + } + if (itopk_size % 32) { + uint32_t itopk32 = itopk_size; + itopk32 += 32 - (itopk_size % 32); + RAFT_LOG_DEBUG("# internal_topk is increased from %u to %u, as it must be multiple of 32.", + itopk_size, + itopk32); + itopk_size = itopk32; + } + } + + // defines 
hash_bitlen, small_hash_bitlen, small_hash_reset_interval, hashmap_size + inline void calc_hashmap_params(raft::device_resources const& res) + { + // for multiple CTA search + uint32_t mc_num_cta_per_query = 0; + uint32_t mc_num_parents = 0; + uint32_t mc_itopk_size = 0; + if (algo == search_algo::MULTI_CTA) { + mc_itopk_size = 32; + mc_num_parents = 1; + mc_num_cta_per_query = max(num_parents, itopk_size / 32); + RAFT_LOG_DEBUG("# mc_itopk_size: %u", mc_itopk_size); + RAFT_LOG_DEBUG("# mc_num_parents: %u", mc_num_parents); + RAFT_LOG_DEBUG("# mc_num_cta_per_query: %u", mc_num_cta_per_query); + } + + // Determine hash size (bit length) + hashmap_size = 0; + hash_bitlen = 0; + small_hash_bitlen = 0; + small_hash_reset_interval = 1024 * 1024; + float max_fill_rate = hashmap_max_fill_rate; + while (hashmap_mode == hash_mode::AUTO || hashmap_mode == hash_mode::SMALL) { // used as a breakable block: every path through the body leaves the loop + // + // The small-hash reduces hash table size by initializing the hash table + // for each iteration and re-registering only the nodes that should not be + // re-visited in that iteration. Therefore, the size of small-hash should + // be determined based on the internal topk size and the number of nodes + // visited per iteration. + // + const auto max_visited_nodes = itopk_size + (num_parents * graph_degree * 1); + unsigned min_bitlen = 8; // 256 + unsigned max_bitlen = 13; // 8K + if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } + hash_bitlen = min_bitlen; + while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { + hash_bitlen += 1; + } + if (hash_bitlen > max_bitlen) { + // Switch to normal hash if hashmap_mode is AUTO, otherwise exit. 
+ if (hashmap_mode == hash_mode::AUTO) { + hash_bitlen = 0; + break; + } else { + RAFT_LOG_DEBUG( + "[CAGRA Error]" + "small-hash cannot be used because the required hash size exceeds the limit (%u)", + hashmap::get_size(max_bitlen)); + exit(-1); + } + } + small_hash_bitlen = hash_bitlen; + // + // Sincc the hash table size is limited to a power of 2, the requirement, + // the maximum fill rate, may be satisfied even if the frequency of hash + // table reset is reduced to once every 2 or more iterations without + // changing the hash table size. In that case, reduce the reset frequency. + // + small_hash_reset_interval = 1; + while (1) { + const auto max_visited_nodes = + itopk_size + (num_parents * graph_degree * (small_hash_reset_interval + 1)); + if (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { break; } + small_hash_reset_interval += 1; + } + break; + } + if (hash_bitlen == 0) { + // + // The size of hash table is determined based on the maximum number of + // nodes that may be visited before the search is completed and the + // maximum fill rate of the hash table. 
+ // + uint32_t max_visited_nodes = itopk_size + (num_parents * graph_degree * max_iterations); + if (algo == search_algo::MULTI_CTA) { + max_visited_nodes = mc_itopk_size + (mc_num_parents * graph_degree * max_iterations); + max_visited_nodes *= mc_num_cta_per_query; + } + unsigned min_bitlen = 11; // 2K + if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } + hash_bitlen = min_bitlen; + while (hash_bitlen <= 20 && max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { // stop at 21 so the size computation never shifts past the supported range; the check below then reports the failure + hash_bitlen += 1; + } + RAFT_EXPECTS(hash_bitlen <= 20, "hash_bitlen cannot be larger than 20 (1M)"); + } + + RAFT_LOG_DEBUG("# internal topK = %lu", itopk_size); + RAFT_LOG_DEBUG("# parent size = %lu", num_parents); + RAFT_LOG_DEBUG("# min_iterations = %lu", min_iterations); + RAFT_LOG_DEBUG("# max_iterations = %lu", max_iterations); + RAFT_LOG_DEBUG("# max_queries = %lu", max_queries); + RAFT_LOG_DEBUG("# hashmap mode = %s%s-%u", + (small_hash_bitlen > 0 ? "small-" : ""), + "hash", + hashmap::get_size(hash_bitlen)); + if (small_hash_bitlen > 0) { + RAFT_LOG_DEBUG("# small_hash_reset_interval = %lu", small_hash_reset_interval); + } + hashmap_size = sizeof(std::uint32_t) * max_queries * hashmap::get_size(hash_bitlen); + RAFT_LOG_DEBUG("# hashmap size: %lu", hashmap_size); + if (hashmap_size >= 1024 * 1024 * 1024) { + RAFT_LOG_DEBUG(" (%.2f GiB)", (double)hashmap_size / (1024 * 1024 * 1024)); + } else if (hashmap_size >= 1024 * 1024) { + RAFT_LOG_DEBUG(" (%.2f MiB)", (double)hashmap_size / (1024 * 1024)); + } else if (hashmap_size >= 1024) { + RAFT_LOG_DEBUG(" (%.2f KiB)", (double)hashmap_size / (1024)); + } + } + + void check(uint32_t topk) // validates a requested `topk` against the plan's internal top-k size + { + RAFT_EXPECTS(topk <= itopk_size, "topk must be smaller than or equal to itopk_size = %lu", itopk_size); + if (algo == search_algo::MULTI_CTA) { + uint32_t mc_num_cta_per_query = max(num_parents, itopk_size / 32); + RAFT_EXPECTS(mc_num_cta_per_query * 32 >= topk, + "`mc_num_cta_per_query` (%u) * 32 must be equal to or greater than " + "`topk` (%u) when 
'search_mode' is \"multi-cta\"", + mc_num_cta_per_query, + topk); + } + } + + inline void check_params() + { + std::string error_message = ""; + + if (itopk_size > 1024) { + if (algo == search_algo::MULTI_CTA) { + } else { + error_message += std::string("- `internal_topk` (" + std::to_string(itopk_size) + + ") must be smaller or equal to 1024"); + } + } + if (algo != search_algo::SINGLE_CTA && algo != search_algo::MULTI_CTA && + algo != search_algo::MULTI_KERNEL) { + error_message += "An invalid kernel mode has been given: " + std::to_string((int)algo) + ""; + } + if (team_size != 0 && team_size != 4 && team_size != 8 && team_size != 16 && team_size != 32) { + error_message += + "`team_size` must be 0, 4, 8, 16 or 32. " + std::to_string(team_size) + " has been given."; + } + if (load_bit_length != 0 && load_bit_length != 64 && load_bit_length != 128) { + error_message += "`load_bit_length` must be 0, 64 or 128. " + + std::to_string(load_bit_length) + " has been given."; + } + if (thread_block_size != 0 && thread_block_size != 64 && thread_block_size != 128 && + thread_block_size != 256 && thread_block_size != 512 && thread_block_size != 1024) { + error_message += "`thread_block_size` must be 0, 64, 128, 256 or 512. " + + std::to_string(load_bit_length) + " has been given."; + } + if (hashmap_min_bitlen > 20) { + error_message += "`hashmap_min_bitlen` must be equal to or smaller than 20. " + + std::to_string(hashmap_min_bitlen) + " has been given."; + } + if (hashmap_max_fill_rate < 0.1 || hashmap_max_fill_rate >= 0.9) { + error_message += + "`hashmap_max_fill_rate` must be equal to or greater than 0.1 and smaller than 0.9. 
" + + std::to_string(hashmap_max_fill_rate) + " has been given."; + } + if (algo == search_algo::MULTI_CTA) { + if (hashmap_mode == hash_mode::SMALL) { + error_message += "`small_hash` is not available when 'search_mode' is \"multi-cta\""; + } else { + hashmap_mode = hash_mode::HASH; + } + } + + if (error_message.length() != 0) { THROW("[CAGRA Error] %s", error_message.c_str()); } + } +}; + +// template +// struct search_plan { +// search_plan(raft::device_resources const& res, +// search_params param, +// int64_t dim, +// int64_t graph_degree) +// : plan(res, param, dim, graph_degree) +// { +// } +// void check(uint32_t topk) { plan.check(topk); } + +// // private: +// detail::search_plan_impl plan; +// }; +/** @} */ // end group cagra + +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh new file mode 100644 index 0000000000..acd7ac321f --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -0,0 +1,1157 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bitonic.hpp" +#include "compute_distance.hpp" +#include "device_common.hpp" +#include "hashmap.hpp" +#include "search_plan.cuh" +#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk +#include "utils.hpp" +#include +#include +#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp + +namespace raft::neighbors::experimental::cagra::detail { +namespace single_cta_search { + +// #define _CLK_BREAKDOWN + +template +__device__ void pickup_next_parents(std::uint32_t* const terminate_flag, + INDEX_T* const next_parent_indices, + INDEX_T* const internal_topk_indices, + const std::size_t internal_topk_size, + const std::size_t dataset_size, + const std::uint32_t num_parents) +{ + // if (threadIdx.x >= 32) return; (NOTE(review): the stride-32 loops and the lane mask below use threadIdx.x as a warp lane id, so this presumably must be called by the first warp only - confirm at the call site) + + for (std::uint32_t i = threadIdx.x; i < num_parents; i += 32) { + next_parent_indices[i] = utils::get_max_value(); + } + std::uint32_t itopk_max = internal_topk_size; + if (itopk_max % 32) { itopk_max += 32 - (itopk_max % 32); } + std::uint32_t num_new_parents = 0; + for (std::uint32_t j = threadIdx.x; j < itopk_max; j += 32) { + std::uint32_t jj = j; + if (TOPK_BY_BITONIC_SORT) { jj = device::swizzling(j); } + INDEX_T index; + int new_parent = 0; + if (j < internal_topk_size) { + index = internal_topk_indices[jj]; + if ((index & 0x80000000) == 0) { // most significant bit clear: this entry has not been used as a parent yet + new_parent = 1; + } + } + const std::uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); + if (new_parent) { + const auto i = __popc(ballot_mask & ((1u << threadIdx.x) - 1u)) + num_new_parents; // unsigned shift: the signed form overflowed for lane 31; counts the lower lanes that also found a new parent + if (i < num_parents) { + next_parent_indices[i] = index; + // set most significant bit as used node + internal_topk_indices[jj] |= 0x80000000; + } + } + num_new_parents += __popc(ballot_mask); + if (num_new_parents >= num_parents) { break; } + } + if (threadIdx.x == 
0)) { *terminate_flag = 1; } +} + +template +struct topk_by_radix_sort_base { + static constexpr std::uint32_t smem_size = MAX_INTERNAL_TOPK * 2 + 2048 + 8; + static constexpr std::uint32_t state_bit_lenght = 0; + static constexpr std::uint32_t vecLen = 2; // TODO +}; +template +struct topk_by_radix_sort : topk_by_radix_sort_base { +}; + +template +struct topk_by_radix_sort> + : topk_by_radix_sort_base { + __device__ void operator()(uint32_t topk, + uint32_t batch_size, + uint32_t len_x, + const uint32_t* _x, + const uint32_t* _in_vals, + uint32_t* _y, + uint32_t* _out_vals, + uint32_t* work, + uint32_t* _hints, + bool sort, + uint32_t* _smem) + { + std::uint8_t* state = (std::uint8_t*)work; + topk_cta_11_core::state_bit_lenght, + topk_by_radix_sort_base::vecLen, + 64, + 32>(topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); + } +}; + +#define TOP_FUNC_PARTIAL_SPECIALIZATION(V) \ + template \ + struct topk_by_radix_sort< \ + MAX_INTERNAL_TOPK, \ + BLOCK_SIZE, \ + std::enable_if_t<((MAX_INTERNAL_TOPK <= V) && (2 * MAX_INTERNAL_TOPK > V))>> \ + : topk_by_radix_sort_base { \ + __device__ void operator()(uint32_t topk, \ + uint32_t batch_size, \ + uint32_t len_x, \ + const uint32_t* _x, \ + const uint32_t* _in_vals, \ + uint32_t* _y, \ + uint32_t* _out_vals, \ + uint32_t* work, \ + uint32_t* _hints, \ + bool sort, \ + uint32_t* _smem) \ + { \ + assert(BLOCK_SIZE >= V / 4); \ + std::uint8_t* state = (std::uint8_t*)work; \ + topk_cta_11_core::state_bit_lenght, \ + topk_by_radix_sort_base::vecLen, \ + V, \ + V / 4>( \ + topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); \ + } \ + }; +TOP_FUNC_PARTIAL_SPECIALIZATION(128); +TOP_FUNC_PARTIAL_SPECIALIZATION(256); +TOP_FUNC_PARTIAL_SPECIALIZATION(512); +TOP_FUNC_PARTIAL_SPECIALIZATION(1024); + +template +__device__ inline void topk_by_bitonic_sort_1st( + float* candidate_distances, // [num_candidates] + std::uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t 
num_candidates, + const std::uint32_t num_itopk) +{ + const unsigned lane_id = threadIdx.x % 32; + const unsigned warp_id = threadIdx.x / 32; + if (MULTI_WARPS == 0) { + if (warp_id > 0) { return; } + constexpr unsigned N = (MAX_CANDIDATES + 31) / 32; + float key[N]; + std::uint32_t val[N]; + /* Candidates -> Reg */ + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (32 * i); + if (j < num_candidates) { + key[i] = candidate_distances[j]; + val[i] = candidate_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Sort */ + bitonic::warp_sort(key, val); + /* Reg -> Temp_itopk */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_candidates && j < num_itopk) { + candidate_distances[device::swizzling(j)] = key[i]; + candidate_indices[device::swizzling(j)] = val[i]; + } + } + } else { + // Use two warps (64 threads) + constexpr unsigned max_candidates_per_warp = (MAX_CANDIDATES + 1) / 2; + constexpr unsigned N = (max_candidates_per_warp + 31) / 32; + float key[N]; + std::uint32_t val[N]; + if (warp_id < 2) { + /* Candidates -> Reg */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = lane_id + (32 * i); + unsigned j = jl + (max_candidates_per_warp * warp_id); + if (j < num_candidates) { + key[i] = candidate_distances[j]; + val[i] = candidate_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Sort */ + bitonic::warp_sort(key, val); + /* Reg -> Temp_candidates */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = (N * lane_id) + i; + unsigned j = jl + (max_candidates_per_warp * warp_id); + if (j < num_candidates && jl < num_itopk) { + candidate_distances[device::swizzling(j)] = key[i]; + candidate_indices[device::swizzling(j)] = val[i]; + } + } + } + __syncthreads(); + + unsigned num_warps_used = (num_itopk + max_candidates_per_warp - 1) / max_candidates_per_warp; + if (warp_id < num_warps_used) { + /* Temp_candidates 
-> Reg */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = (N * lane_id) + i; + unsigned kl = max_candidates_per_warp - 1 - jl; + unsigned j = jl + (max_candidates_per_warp * warp_id); + unsigned k = MAX_CANDIDATES - 1 - j; + if (j >= num_candidates || k >= num_candidates || kl >= num_itopk) continue; + float temp_key = candidate_distances[device::swizzling(k)]; + if (key[i] == temp_key) continue; + if ((warp_id == 0) == (key[i] > temp_key)) { + key[i] = temp_key; + val[i] = candidate_indices[device::swizzling(k)]; + } + } + } + if (num_warps_used > 1) { __syncthreads(); } + if (warp_id < num_warps_used) { + /* Merge */ + bitonic::warp_merge(key, val, 32); + /* Reg -> Temp_itopk */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = (N * lane_id) + i; + unsigned j = jl + (max_candidates_per_warp * warp_id); + if (j < num_candidates && j < num_itopk) { + candidate_distances[device::swizzling(j)] = key[i]; + candidate_indices[device::swizzling(j)] = val[i]; + } + } + } + if (num_warps_used > 1) { __syncthreads(); } + } +} + +template +__device__ inline void topk_by_bitonic_sort_2nd( + float* itopk_distances, // [num_itopk] + std::uint32_t* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + std::uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + const unsigned lane_id = threadIdx.x % 32; + const unsigned warp_id = threadIdx.x / 32; + if (MULTI_WARPS == 0) { + if (warp_id > 0) { return; } + constexpr unsigned N = (MAX_ITOPK + 31) / 32; + float key[N]; + std::uint32_t val[N]; + if (first) { + /* Load itopk results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (32 * i); + if (j < num_itopk) { + key[i] = itopk_distances[j]; + val[i] = itopk_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Warp Sort */ + bitonic::warp_sort(key, val); + } else { 
+ /* Load itopk results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_itopk) { + key[i] = itopk_distances[device::swizzling(j)]; + val[i] = itopk_indices[device::swizzling(j)]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + } + /* Merge candidates */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; // [0:MAX_ITOPK-1] + unsigned k = MAX_ITOPK - 1 - j; + if (k >= num_itopk || k >= num_candidates) continue; + float candidate_key = candidate_distances[device::swizzling(k)]; + if (key[i] > candidate_key) { + key[i] = candidate_key; + val[i] = candidate_indices[device::swizzling(k)]; + } + } + /* Warp Merge */ + bitonic::warp_merge(key, val, 32); + /* Store new itopk results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_itopk) { + itopk_distances[device::swizzling(j)] = key[i]; + itopk_indices[device::swizzling(j)] = val[i]; + } + } + } else { + // Use two warps (64 threads) or more + constexpr unsigned max_itopk_per_warp = (MAX_ITOPK + 1) / 2; + constexpr unsigned N = (max_itopk_per_warp + 31) / 32; + float key[N]; + std::uint32_t val[N]; + if (first) { + /* Load itopk results (not sorted) */ + if (warp_id < 2) { + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (32 * i) + (max_itopk_per_warp * warp_id); + if (j < num_itopk) { + key[i] = itopk_distances[j]; + val[i] = itopk_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Warp Sort */ + bitonic::warp_sort(key, val); + /* Store intermediate results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * threadIdx.x) + i; + if (j >= num_itopk) continue; + itopk_distances[device::swizzling(j)] = key[i]; + itopk_indices[device::swizzling(j)] = val[i]; + } + } + __syncthreads(); + if (warp_id < 2) { + /* Load intermediate results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * threadIdx.x) + 
i; + unsigned k = MAX_ITOPK - 1 - j; + if (k >= num_itopk) continue; + float temp_key = itopk_distances[device::swizzling(k)]; + if (key[i] == temp_key) continue; + if ((warp_id == 0) == (key[i] > temp_key)) { + key[i] = temp_key; + val[i] = itopk_indices[device::swizzling(k)]; + } + } + /* Warp Merge */ + bitonic::warp_merge(key, val, 32); + } + __syncthreads(); + /* Store itopk results (sorted) */ + if (warp_id < 2) { + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * threadIdx.x) + i; + if (j >= num_itopk) continue; + itopk_distances[device::swizzling(j)] = key[i]; + itopk_indices[device::swizzling(j)] = val[i]; + } + } + } + const uint32_t num_itopk_div2 = num_itopk / 2; + if (threadIdx.x < 3) { + // work_buf is used to obtain turning points in 1st and 2nd half of itopk after merge: its three slots start at num_itopk_div2 and are lowered with atomicMin by the merges below. + work_buf[threadIdx.x] = num_itopk_div2; + } + __syncthreads(); + + // Merge candidates (using whole threads) + for (unsigned k = threadIdx.x; k < min(num_candidates, num_itopk); k += blockDim.x) { + const unsigned j = num_itopk - 1 - k; + const float itopk_key = itopk_distances[device::swizzling(j)]; + const float candidate_key = candidate_distances[device::swizzling(k)]; + if (itopk_key > candidate_key) { + itopk_distances[device::swizzling(j)] = candidate_key; + itopk_indices[device::swizzling(j)] = candidate_indices[device::swizzling(k)]; + if (j < num_itopk_div2) { + atomicMin(work_buf + 2, j); + } else { + atomicMin(work_buf + 1, j - num_itopk_div2); + } + } + } + __syncthreads(); + + // Merge 1st and 2nd half of itopk (using whole threads) + for (unsigned j = threadIdx.x; j < num_itopk_div2; j += blockDim.x) { + const unsigned k = j + num_itopk_div2; + float key_0 = itopk_distances[device::swizzling(j)]; + float key_1 = itopk_distances[device::swizzling(k)]; + if (key_0 > key_1) { + itopk_distances[device::swizzling(j)] = key_1; + itopk_distances[device::swizzling(k)] = key_0; + std::uint32_t val_0 = itopk_indices[device::swizzling(j)]; + std::uint32_t val_1 = 
itopk_indices[device::swizzling(k)]; + itopk_indices[device::swizzling(j)] = val_1; + itopk_indices[device::swizzling(k)] = val_0; + atomicMin(work_buf + 0, j); + } + } + if (threadIdx.x == blockDim.x - 1) { + if (work_buf[2] < num_itopk_div2) { work_buf[1] = work_buf[2]; } + } + __syncthreads(); + // if ((blockIdx.x == 0) && (threadIdx.x == 0)) { + // RAFT_LOG_DEBUG( "work_buf: %u, %u, %u\n", work_buf[0], work_buf[1], work_buf[2] ); + // } + + // Warp-0 merges 1st half of itopk, warp-1 does 2nd half. + if (warp_id < 2) { + // Load intermedidate itopk results + const uint32_t turning_point = work_buf[warp_id]; // turning_point <= num_itopk_div2 + for (unsigned i = 0; i < N; i++) { + unsigned k = num_itopk; + unsigned j = (N * lane_id) + i; + if (j < turning_point) { + k = j + (num_itopk_div2 * warp_id); + } else if (j >= (MAX_ITOPK / 2 - num_itopk_div2)) { + j -= (MAX_ITOPK / 2 - num_itopk_div2); + if ((turning_point <= j) && (j < num_itopk_div2)) { k = j + (num_itopk_div2 * warp_id); } + } + if (k < num_itopk) { + key[i] = itopk_distances[device::swizzling(k)]; + val[i] = itopk_indices[device::swizzling(k)]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Warp Merge */ + bitonic::warp_merge(key, val, 32); + /* Store new itopk results */ + for (unsigned i = 0; i < N; i++) { + const unsigned j = (N * lane_id) + i; + if (j < num_itopk_div2) { + unsigned k = j + (num_itopk_div2 * warp_id); + itopk_distances[device::swizzling(k)] = key[i]; + itopk_indices[device::swizzling(k)] = val[i]; + } + } + } + } +} + +template +__device__ void topk_by_bitonic_sort(float* itopk_distances, // [num_itopk] + std::uint32_t* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + std::uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + // The results in candidate_distances/indices are sorted by 
bitonic sort. + topk_by_bitonic_sort_1st( + candidate_distances, candidate_indices, num_candidates, num_itopk); + + // The results sorted above are merged with the internal intermediate top-k + // results so far using bitonic merge. + topk_by_bitonic_sort_2nd(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); +} + +template +__device__ inline void hashmap_restore(uint32_t* hashmap_ptr, + const size_t hashmap_bitlen, + const INDEX_T* itopk_indices, + uint32_t itopk_size) +{ + if (threadIdx.x < FIRST_TID || threadIdx.x >= LAST_TID) return; + for (unsigned i = threadIdx.x - FIRST_TID; i < itopk_size; i += LAST_TID - FIRST_TID) { + auto key = itopk_indices[i] & ~0x80000000; // clear most significant bit + hashmap::insert(hashmap_ptr, hashmap_bitlen, key); + } +} + +template +__device__ inline void set_value_device(T* const ptr, const T fill, const std::uint32_t count) +{ + for (std::uint32_t i = threadIdx.x; i < count; i += BLOCK_SIZE) { + ptr[i] = fill; + } +} + +// One query one thread block +template +__launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ + void search_kernel(INDEX_T* const result_indices_ptr, // [num_queries, top_k] + DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] + const std::uint32_t top_k, + const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] + const std::size_t dataset_dim, + const std::size_t dataset_size, + const DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const INDEX_T* const knn_graph, // [dataset_size, graph_degree] + const std::uint32_t graph_degree, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t internal_topk, + const std::uint32_t num_parents, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, 
+ std::uint32_t* const num_executed_iterations, // [num_queries] + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval) +{ + const auto query_id = blockIdx.y; + +#ifdef _CLK_BREAKDOWN + std::uint64_t clk_init = 0; + std::uint64_t clk_compute_1st_distance = 0; + std::uint64_t clk_topk = 0; + std::uint64_t clk_reset_hash = 0; + std::uint64_t clk_pickup_parents = 0; + std::uint64_t clk_restore_hash = 0; + std::uint64_t clk_compute_distance = 0; + std::uint64_t clk_start; +#define _CLK_START() clk_start = clock64() +#define _CLK_REC(V) V += clock64() - clk_start; +#else +#define _CLK_START() +#define _CLK_REC(V) +#endif + _CLK_START(); + + extern __shared__ std::uint32_t smem[]; + + // Layout of result_buffer + // +----------------------+------------------------------+---------+ + // | internal_top_k | neighbors of internal_top_k | padding | + // | | | upto 32 | + // +----------------------+------------------------------+---------+ + // |<--- result_buffer_size --->| + std::uint32_t result_buffer_size = internal_topk + (num_parents * graph_degree); + std::uint32_t result_buffer_size_32 = result_buffer_size; + if (result_buffer_size % 32) { result_buffer_size_32 += 32 - (result_buffer_size % 32); } + const auto small_hash_size = hashmap::get_size(small_hash_bitlen); + auto query_buffer = reinterpret_cast(smem); + auto result_indices_buffer = reinterpret_cast(query_buffer + MAX_DATASET_DIM); + auto result_distances_buffer = + reinterpret_cast(result_indices_buffer + result_buffer_size_32); + auto visited_hash_buffer = + reinterpret_cast(result_distances_buffer + result_buffer_size_32); + auto parent_list_buffer = reinterpret_cast(visited_hash_buffer + small_hash_size); + auto topk_ws = reinterpret_cast(parent_list_buffer + num_parents); + auto terminate_flag = reinterpret_cast(topk_ws + 3); + auto smem_working_ptr = reinterpret_cast(terminate_flag + 1); + + const DATA_T* const query_ptr = queries_ptr + 
query_id * dataset_dim; + for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += BLOCK_SIZE) { + unsigned j = device::swizzling(i); + if (i < dataset_dim) { + query_buffer[j] = static_cast(query_ptr[i]) * device::fragment_scale(); + } else { + query_buffer[j] = 0.0; + } + } + if (threadIdx.x == 0) { + terminate_flag[0] = 0; + topk_ws[0] = ~0u; + } + + // Init hashmap + uint32_t* local_visited_hashmap_ptr; + if (small_hash_bitlen) { + local_visited_hashmap_ptr = visited_hash_buffer; + } else { + local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); + } + hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + __syncthreads(); + _CLK_REC(clk_init); + + // compute distance to randomly selected nodes + _CLK_START(); + const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; + device::compute_distance_to_random_nodes( + result_indices_buffer, + result_distances_buffer, + query_buffer, + dataset_ptr, + dataset_dim, + dataset_size, + result_buffer_size, + num_distilation, + rand_xor_mask, + local_seed_ptr, + num_seeds, + local_visited_hashmap_ptr, + hash_bitlen); + __syncthreads(); + _CLK_REC(clk_compute_1st_distance); + + std::uint32_t iter = 0; + while (1) { + // sort + if (TOPK_BY_BITONIC_SORT) { + // [Notice] + // It is good to use multiple warps in topk_by_bitonic_sort() when + // batch size is small (short-latency), but it might not always be good + // when batch size is large (high-throughput). + // topk_by_bitonic_sort() consists of two operations: + // if MAX_CANDIDATES is greater than 128, the first operation uses two warps; + // if MAX_ITOPK is greater than 256, the second operation uses two warps. + constexpr unsigned multi_warps_1 = ((BLOCK_SIZE >= 64) && (MAX_CANDIDATES > 128)) ? 1 : 0; + constexpr unsigned multi_warps_2 = ((BLOCK_SIZE >= 64) && (MAX_ITOPK > 256)) ? 1 : 0; + + // reset small-hash table. 
+ if ((iter + 1) % small_hash_reset_interval == 0) { + // Depending on the block size and the number of warps used in + // topk_by_bitonic_sort(), determine which warps are used to reset + // the small hash and whether they are performed in overlap with + // topk_by_bitonic_sort(). + _CLK_START(); + if (BLOCK_SIZE == 32) { + hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + } else if (BLOCK_SIZE == 64) { + if (multi_warps_1 || multi_warps_2) { + hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + } else { + hashmap::init<32, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + } + } else { + if (multi_warps_1 || multi_warps_2) { + hashmap::init<64, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + } else { + hashmap::init<32, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + } + } + _CLK_REC(clk_reset_hash); + } + + // topk with bitonic sort + _CLK_START(); + topk_by_bitonic_sort( + result_distances_buffer, + result_indices_buffer, + internal_topk, + result_distances_buffer + internal_topk, + result_indices_buffer + internal_topk, + num_parents * graph_degree, + topk_ws, + (iter == 0)); + _CLK_REC(clk_topk); + + } else { + _CLK_START(); + // topk with radix block sort + topk_by_radix_sort{}( + internal_topk, + gridDim.x, + result_buffer_size, + reinterpret_cast(result_distances_buffer), + result_indices_buffer, + reinterpret_cast(result_distances_buffer), + result_indices_buffer, + nullptr, + topk_ws, + true, + reinterpret_cast(smem_working_ptr)); + _CLK_REC(clk_topk); + + // reset small-hash table + if ((iter + 1) % small_hash_reset_interval == 0) { + _CLK_START(); + hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + _CLK_REC(clk_reset_hash); + } + } + __syncthreads(); + + if (iter + 1 == max_iteration) { break; } + + // pick up next parents + if (threadIdx.x < 32) { + _CLK_START(); + pickup_next_parents(terminate_flag, + parent_list_buffer, + result_indices_buffer, + internal_topk, + 
dataset_size, + num_parents); + _CLK_REC(clk_pickup_parents); + } + + // restore small-hash table by putting internal-topk indices in it + if ((iter + 1) % small_hash_reset_interval == 0) { + constexpr unsigned first_tid = ((BLOCK_SIZE <= 32) ? 0 : 32); + _CLK_START(); + hashmap_restore( + local_visited_hashmap_ptr, hash_bitlen, result_indices_buffer, internal_topk); + _CLK_REC(clk_restore_hash); + } + __syncthreads(); + + if (*terminate_flag && iter >= min_iteration) { break; } + + // compute the norms between child nodes and query node + _CLK_START(); + constexpr unsigned max_n_frags = 16; + device:: + compute_distance_to_child_nodes( + result_indices_buffer + internal_topk, + result_distances_buffer + internal_topk, + query_buffer, + dataset_ptr, + dataset_dim, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + hash_bitlen, + parent_list_buffer, + num_parents); + __syncthreads(); + _CLK_REC(clk_compute_distance); + + iter++; + } + for (std::uint32_t i = threadIdx.x; i < top_k; i += BLOCK_SIZE) { + unsigned j = i + (top_k * query_id); + unsigned ii = i; + if (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); } + if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[ii]; } + result_indices_ptr[j] = result_indices_buffer[ii] & ~0x80000000; // clear most significant bit + } + if (threadIdx.x == 0 && num_executed_iterations != nullptr) { + num_executed_iterations[query_id] = iter + 1; + } +#ifdef _CLK_BREAKDOWN + if ((threadIdx.x == 0 || threadIdx.x == BLOCK_SIZE - 1) && ((query_id * 3) % gridDim.y < 3)) { + RAFT_LOG_DEBUG( + "query, %d, thread, %d" + ", init, %lu" + ", 1st_distance, %lu" + ", topk, %lu" + ", reset_hash, %lu" + ", pickup_parents, %lu" + ", restore_hash, %lu" + ", distance, %lu" + "\n", + query_id, + threadIdx.x, + clk_init, + clk_compute_1st_distance, + clk_topk, + clk_reset_hash, + clk_pickup_parents, + clk_restore_hash, + clk_compute_distance); + } +#endif +} + +#define SET_KERNEL_3( \ + BLOCK_SIZE, 
BLOCK_COUNT, MAX_ITOPK, MAX_CANDIDATES, TOPK_BY_BITONIC_SORT, LOAD_T) \ + kernel = search_kernel; + +#define SET_KERNEL_2(BLOCK_SIZE, BLOCK_COUNT, MAX_ITOPK, MAX_CANDIDATES, TOPK_BY_BITONIC_SORT) \ + if (load_bit_length == 128) { \ + SET_KERNEL_3(BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + device::LOAD_128BIT_T) \ + } else if (load_bit_length == 64) { \ + SET_KERNEL_3(BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + device::LOAD_64BIT_T) \ + } + +#define SET_KERNEL_1B(MAX_ITOPK, MAX_CANDIDATES) \ + /* if ( block_size == 32 ) { \ + SET_KERNEL_2( 32, 20, MAX_ITOPK, MAX_CANDIDATES, 1 ) \ + } else */ \ + if (block_size == 64) { \ + SET_KERNEL_2(64, 16 /*20*/, MAX_ITOPK, MAX_CANDIDATES, 1) \ + } else if (block_size == 128) { \ + SET_KERNEL_2(128, 8, MAX_ITOPK, MAX_CANDIDATES, 1) \ + } else if (block_size == 256) { \ + SET_KERNEL_2(256, 4, MAX_ITOPK, MAX_CANDIDATES, 1) \ + } else if (block_size == 512) { \ + SET_KERNEL_2(512, 2, MAX_ITOPK, MAX_CANDIDATES, 1) \ + } else { \ + SET_KERNEL_2(1024, 1, MAX_ITOPK, MAX_CANDIDATES, 1) \ + } + +#define SET_KERNEL_1R(MAX_ITOPK, MAX_CANDIDATES) \ + if (block_size == 256) { \ + SET_KERNEL_2(256, 4, MAX_ITOPK, MAX_CANDIDATES, 0) \ + } else if (block_size == 512) { \ + SET_KERNEL_2(512, 2, MAX_ITOPK, MAX_CANDIDATES, 0) \ + } else { \ + SET_KERNEL_2(1024, 1, MAX_ITOPK, MAX_CANDIDATES, 0) \ + } + +#define SET_KERNEL \ + typedef void (*search_kernel_t)(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + std::uint32_t* const 
visited_hashmap_ptr, \ + const std::uint32_t itopk_size, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); \ + search_kernel_t kernel; \ + if (num_itopk_candidates <= 64) { \ + constexpr unsigned max_candidates = 64; \ + if (itopk_size <= 64) { \ + SET_KERNEL_1B(64, max_candidates) \ + } else if (itopk_size <= 128) { \ + SET_KERNEL_1B(128, max_candidates) \ + } else if (itopk_size <= 256) { \ + SET_KERNEL_1B(256, max_candidates) \ + } else if (itopk_size <= 512) { \ + SET_KERNEL_1B(512, max_candidates) \ + } \ + } else if (num_itopk_candidates <= 128) { \ + constexpr unsigned max_candidates = 128; \ + if (itopk_size <= 64) { \ + SET_KERNEL_1B(64, max_candidates) \ + } else if (itopk_size <= 128) { \ + SET_KERNEL_1B(128, max_candidates) \ + } else if (itopk_size <= 256) { \ + SET_KERNEL_1B(256, max_candidates) \ + } else if (itopk_size <= 512) { \ + SET_KERNEL_1B(512, max_candidates) \ + } \ + } else if (num_itopk_candidates <= 256) { \ + constexpr unsigned max_candidates = 256; \ + if (itopk_size <= 64) { \ + SET_KERNEL_1B(64, max_candidates) \ + } else if (itopk_size <= 128) { \ + SET_KERNEL_1B(128, max_candidates) \ + } else if (itopk_size <= 256) { \ + SET_KERNEL_1B(256, max_candidates) \ + } else if (itopk_size <= 512) { \ + SET_KERNEL_1B(512, max_candidates) \ + } \ + } else { \ + /* Radix-based topk is used */ \ + if (itopk_size <= 256) { \ + SET_KERNEL_1R(256, /*to avoid build failure*/ 32) \ + } else if (itopk_size <= 512) { \ + SET_KERNEL_1R(512, /*to avoid build failure*/ 32) \ + } \ + } + +template +struct search : search_plan_impl { + using search_plan_impl::max_queries; + using search_plan_impl::itopk_size; + using search_plan_impl::algo; + using search_plan_impl::team_size; + using 
search_plan_impl::num_parents; + using search_plan_impl::min_iterations; + using search_plan_impl::max_iterations; + using search_plan_impl::load_bit_length; + using search_plan_impl::thread_block_size; + using search_plan_impl::hashmap_mode; + using search_plan_impl::hashmap_min_bitlen; + using search_plan_impl::hashmap_max_fill_rate; + using search_plan_impl::num_random_samplings; + using search_plan_impl::rand_xor_mask; + + using search_plan_impl::max_dim; + using search_plan_impl::dim; + using search_plan_impl::graph_degree; + using search_plan_impl::topk; + + using search_plan_impl::hash_bitlen; + + using search_plan_impl::small_hash_bitlen; + using search_plan_impl::small_hash_reset_interval; + using search_plan_impl::hashmap_size; + using search_plan_impl::dataset_size; + using search_plan_impl::result_buffer_size; + + using search_plan_impl::smem_size; + using search_plan_impl::load_bit_lenght; + + using search_plan_impl::hashmap; + using search_plan_impl::num_executed_iterations; + using search_plan_impl::dev_seed; + using search_plan_impl::num_seeds; + + uint32_t num_itopk_candidates; + + search(raft::device_resources const& res, + search_params params, + int64_t dim, + int64_t graph_degree, + uint32_t topk) + : search_plan_impl(res, params, dim, graph_degree, topk) + { + set_params(res); + } + + ~search() {} + + inline void set_params(raft::device_resources const& res) + { + num_itopk_candidates = num_parents * graph_degree; + result_buffer_size = itopk_size + num_itopk_candidates; + + typedef raft::Pow2<32> AlignBytes; + unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); + + constexpr unsigned max_itopk = 512; + RAFT_EXPECTS(itopk_size <= max_itopk, "itopk_size cannot be larger than %u", max_itopk); + + RAFT_LOG_DEBUG("# num_itopk_candidates: %u", num_itopk_candidates); + RAFT_LOG_DEBUG("# num_itopk: %u", itopk_size); + // + // Determine the thread block size + // + constexpr unsigned min_block_size = 64; // 32 or 64 + constexpr 
unsigned min_block_size_radix = 256; + constexpr unsigned max_block_size = 1024; + // + const std::uint32_t topk_ws_size = 3; + const std::uint32_t base_smem_size = + sizeof(float) * max_dim + (sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 + + sizeof(std::uint32_t) * hashmap::get_size(small_hash_bitlen) + + sizeof(std::uint32_t) * num_parents + sizeof(std::uint32_t) * topk_ws_size + + sizeof(std::uint32_t); + smem_size = base_smem_size; + if (num_itopk_candidates > 256) { + // Tentatively calculate the required shared memory size when radix + // sort based topk is used, assuming the block size is the maximum. + if (itopk_size <= 256) { + smem_size += topk_by_radix_sort<256, max_block_size>::smem_size * sizeof(std::uint32_t); + } else { + smem_size += topk_by_radix_sort<512, max_block_size>::smem_size * sizeof(std::uint32_t); + } + } + + uint32_t block_size = thread_block_size; + if (block_size == 0) { + block_size = min_block_size; + + if (num_itopk_candidates > 256) { + // radix-based topk is used. + block_size = min_block_size_radix; + + // Internal topk values per thread must be equal to or less than 4 + // when radix-sort block_topk is used. + while ((block_size < max_block_size) && (max_itopk / block_size > 4)) { + block_size *= 2; + } + } + + // Increase block size according to shared memory requirements. + // If block size is 32, upper limit of shared memory size per + // thread block is set to 4096. This is GPU generation dependent. + constexpr unsigned ulimit_smem_size_cta32 = 4096; + while (smem_size > ulimit_smem_size_cta32 / 32 * block_size) { + block_size *= 2; + } + + // Increase block size to improve GPU occupancy when batch size + // is small, that is, number of queries is low. 
+ cudaDeviceProp deviceProp = res.get_device_properties(); + RAFT_LOG_DEBUG("# multiProcessorCount: %d", deviceProp.multiProcessorCount); + while ((block_size < max_block_size) && + (graph_degree * num_parents * team_size >= block_size * 2) && + (max_queries <= (1024 / (block_size * 2)) * deviceProp.multiProcessorCount)) { + block_size *= 2; + } + } + RAFT_LOG_DEBUG("# thread_block_size: %u", block_size); + RAFT_EXPECTS(block_size >= min_block_size, + "block_size cannot be smaller than min_block size, %u", + min_block_size); + RAFT_EXPECTS(block_size <= max_block_size, + "block_size cannot be larger than max_block size %u", + max_block_size); + thread_block_size = block_size; + + // Determine load bit length + const uint32_t total_bit_length = dim * sizeof(DATA_T) * 8; + if (load_bit_length == 0) { + load_bit_length = 128; + while (total_bit_length % load_bit_length) { + load_bit_length /= 2; + } + } + RAFT_LOG_DEBUG("# load_bit_length: %u (%u loads per vector)", + load_bit_length, + total_bit_length / load_bit_length); + RAFT_EXPECTS(total_bit_length % load_bit_length == 0, + "load_bit_length must be a divisor of dim*sizeof(data_t)*8=%u", + total_bit_length); + RAFT_EXPECTS(load_bit_length >= 64, "load_bit_lenght cannot be less than 64"); + + if (num_itopk_candidates <= 256) { + RAFT_LOG_DEBUG("# bitonic-sort based topk routine is used"); + } else { + RAFT_LOG_DEBUG("# radix-sort based topk routine is used"); + smem_size = base_smem_size; + if (itopk_size <= 256) { + constexpr unsigned MAX_ITOPK = 256; + if (block_size == 256) { + constexpr unsigned BLOCK_SIZE = 256; + smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + } else if (block_size == 512) { + constexpr unsigned BLOCK_SIZE = 512; + smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + } else { + constexpr unsigned BLOCK_SIZE = 1024; + smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + } + } else { + constexpr unsigned MAX_ITOPK = 512; + if (block_size 
== 256) { + constexpr unsigned BLOCK_SIZE = 256; + smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + } else if (block_size == 512) { + constexpr unsigned BLOCK_SIZE = 512; + smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + } else { + constexpr unsigned BLOCK_SIZE = 1024; + smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + } + } + } + RAFT_LOG_DEBUG("# smem_size: %u", smem_size); + hashmap_size = 0; + if (small_hash_bitlen == 0) { + hashmap_size = sizeof(uint32_t) * max_queries * hashmap::get_size(hash_bitlen); + hashmap.resize(hashmap_size, res.get_stream()); + } + RAFT_LOG_DEBUG("# hashmap_size: %lu", hashmap_size); + } + + void operator()(raft::device_resources const& res, + raft::device_matrix_view dataset, + raft::device_matrix_view graph, + INDEX_T* const result_indices_ptr, // [num_queries, topk] + DISTANCE_T* const result_distances_ptr, // [num_queries, topk] + const DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const std::uint32_t num_queries, + const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] + std::uint32_t* const num_executed_iterations, // [num_queries] + uint32_t topk) + { + cudaStream_t stream = res.get_stream(); + uint32_t block_size = thread_block_size; + SET_KERNEL; + RAFT_CUDA_TRY( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + dim3 thread_dims(block_size, 1, 1); + dim3 block_dims(1, num_queries, 1); + RAFT_LOG_DEBUG( + "Launching kernel with %u threads, %u block %lu smem", block_size, num_queries, smem_size); + kernel<<>>(result_indices_ptr, + result_distances_ptr, + topk, + dataset.data_handle(), + dataset.extent(1), + dataset.extent(0), + queries_ptr, + graph.data_handle(), + graph.extent(1), + num_random_samplings, + rand_xor_mask, + dev_seed_ptr, + num_seeds, + hashmap.data(), + itopk_size, + num_parents, + min_iterations, + max_iterations, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + 
small_hash_reset_interval); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } +}; + +} // namespace single_cta_search +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h new file mode 100644 index 0000000000..ccb65fd0ea --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include + +namespace raft::neighbors::experimental::cagra::detail { + +// Declaration only; NOTE(review): presumably returns the workspace size required by _cuann_find_topk - confirm against the definition. +size_t _cuann_find_topk_bufferSize(uint32_t topK, + uint32_t sizeBatch, + uint32_t numElements, + cudaDataType_t sampleDtype = CUDA_R_32F); + +// Declaration only; NOTE(review): presumably a batched top-k selection over inputKeys/inputVals using the caller-provided workspace - confirm against the definition. +void _cuann_find_topk(uint32_t topK, + uint32_t sizeBatch, + uint32_t numElements, + const float* inputKeys, // [sizeBatch, ldIK,] + uint32_t ldIK, // (*) ldIK >= numElements + const uint32_t* inputVals, // [sizeBatch, ldIV,] + uint32_t ldIV, // (*) ldIV >= numElements + float* outputKeys, // [sizeBatch, ldOK,] + uint32_t ldOK, // (*) ldOK >= topK + uint32_t* outputVals, // [sizeBatch, ldOV,] + uint32_t ldOV, // (*) ldOV >= topK + void* workspace, + bool sort = false, + uint32_t* hint = NULL, + cudaStream_t stream = 0); + +#ifdef __CUDA_ARCH__ +#define CUDA_DEVICE_HOST_FUNC __device__ +#else +#define CUDA_DEVICE_HOST_FUNC +#endif +// Rounds size up to the next multiple of unit; a size that is already a multiple is returned unchanged. +CUDA_DEVICE_HOST_FUNC inline size_t _cuann_aligned(size_t size, size_t unit = 128) +{ + if (size % unit) { size += unit - (size % unit); } + return size; +} +} // namespace raft::neighbors::experimental::cagra::detail \ No newline at end of file diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh new file mode 100644 index 0000000000..d09478d1db --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh @@ -0,0 +1,926 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "topk.h" +#include +#include +#include +#include +#include + +namespace raft::neighbors::experimental::cagra::detail { +using namespace cub; + +// +__device__ inline uint32_t convert(uint32_t x) +{ + if (x & 0x80000000) { + return x ^ 0xffffffff; + } else { + return x ^ 0x80000000; + } +} + +// +__device__ inline uint16_t convert(uint16_t x) +{ + if (x & 0x8000) { + return x ^ 0xffff; + } else { + return x ^ 0x8000; + } +} + +// +struct u32_vector { + uint1 x1; + uint2 x2; + uint4 x4; + ulonglong4 x8; +}; + +// +struct u16_vector { + ushort1 x1; + ushort2 x2; + ushort4 x4; + uint4 x8; +}; + +// +template +__device__ inline void load_u32_vector(struct u32_vector& vec, const uint32_t* x, int i) +{ + if (vecLen == 1) { + vec.x1 = ((uint1*)(x + i))[0]; + } else if (vecLen == 2) { + vec.x2 = ((uint2*)(x + i))[0]; + } else if (vecLen == 4) { + vec.x4 = ((uint4*)(x + i))[0]; + } else if (vecLen == 8) { + vec.x8 = ((ulonglong4*)(x + i))[0]; + } +} + +// +template +__device__ inline void load_u16_vector(struct u16_vector& vec, const uint16_t* x, int i) +{ + if (vecLen == 1) { + vec.x1 = ((ushort1*)(x + i))[0]; + } else if (vecLen == 2) { + vec.x2 = ((ushort2*)(x + i))[0]; + } else if (vecLen == 4) { + vec.x4 = ((ushort4*)(x + i))[0]; + } else if (vecLen == 8) { + vec.x8 = ((uint4*)(x + i))[0]; + } +} + +// +template +__device__ inline uint32_t get_element_from_u32_vector(struct u32_vector& vec, int i) +{ + uint32_t xi; + if (vecLen == 1) { + xi = convert(vec.x1.x); + } else if (vecLen == 2) { + if (i == 0) + xi = convert(vec.x2.x); + else + xi = convert(vec.x2.y); + } else if (vecLen == 4) { + if (i == 0) + xi = convert(vec.x4.x); + else if (i == 1) + xi = convert(vec.x4.y); + else if (i == 2) + xi = convert(vec.x4.z); + else + xi = convert(vec.x4.w); + } else if (vecLen == 8) { + if (i == 0) + xi = convert((uint32_t)(vec.x8.x & 
0xffffffff)); + else if (i == 1) + xi = convert((uint32_t)(vec.x8.x >> 32)); + else if (i == 2) + xi = convert((uint32_t)(vec.x8.y & 0xffffffff)); + else if (i == 3) + xi = convert((uint32_t)(vec.x8.y >> 32)); + else if (i == 4) + xi = convert((uint32_t)(vec.x8.z & 0xffffffff)); + else if (i == 5) + xi = convert((uint32_t)(vec.x8.z >> 32)); + else if (i == 6) + xi = convert((uint32_t)(vec.x8.w & 0xffffffff)); + else + xi = convert((uint32_t)(vec.x8.w >> 32)); + } + return xi; +} + +// +template +__device__ inline uint16_t get_element_from_u16_vector(struct u16_vector& vec, int i) +{ + uint16_t xi; + if (vecLen == 1) { + xi = convert(vec.x1.x); + } else if (vecLen == 2) { + if (i == 0) + xi = convert(vec.x2.x); + else + xi = convert(vec.x2.y); + } else if (vecLen == 4) { + if (i == 0) + xi = convert(vec.x4.x); + else if (i == 1) + xi = convert(vec.x4.y); + else if (i == 2) + xi = convert(vec.x4.z); + else + xi = convert(vec.x4.w); + } else if (vecLen == 8) { + if (i == 0) + xi = convert((uint16_t)(vec.x8.x & 0xffff)); + else if (i == 1) + xi = convert((uint16_t)(vec.x8.x >> 16)); + else if (i == 2) + xi = convert((uint16_t)(vec.x8.y & 0xffff)); + else if (i == 3) + xi = convert((uint16_t)(vec.x8.y >> 16)); + else if (i == 4) + xi = convert((uint16_t)(vec.x8.z & 0xffff)); + else if (i == 5) + xi = convert((uint16_t)(vec.x8.z >> 16)); + else if (i == 6) + xi = convert((uint16_t)(vec.x8.w & 0xffff)); + else + xi = convert((uint16_t)(vec.x8.w >> 16)); + } + return xi; +} + +// +template +__device__ inline void update_histogram(int itr, + uint32_t thread_id, + uint32_t num_threads, + uint32_t hint, + uint32_t threshold, + uint32_t& num_bins, + uint32_t& shift, + const T* x, // [nx,] + uint32_t nx, + uint32_t* hist, // [num_bins] + uint8_t* state, + uint32_t* output, // [topk] + uint32_t* output_count) +{ + if (sizeof(T) == 4) { + // 32-bit (uint32_t) + // itr:0, calculate histogram with 11 bits from bit-21 to bit-31 + // itr:1, calculate histogram with 11 bits from 
bit-10 to bit-20 + // itr:2, calculate histogram with 10 bits from bit-0 to bit-9 + if (itr == 0) { + shift = 21; + num_bins = 2048; + } else if (itr == 1) { + shift = 10; + num_bins = 2048; + } else { + shift = 0; + num_bins = 1024; + } + } else if (sizeof(T) == 2) { + // 16-bit (uint16_t) + // itr:0, calculate histogram with 8 bits from bit-8 to bit-15 + // itr:1, calculate histogram with 8 bits from bit-0 to bit-7 + if (itr == 0) { + shift = 8; + num_bins = 256; + } else { + shift = 0; + num_bins = 256; + } + } else { + return; + } + if (itr > 0) { + for (int i = threadIdx.x; i < num_bins; i += blockDim_x) { + hist[i] = 0; + } + __syncthreads(); + } + + // (*) Note that 'thread_id' may be different from 'threadIdx.x', + // and 'num_threads' may be different from 'blockDim.x' + int ii = 0; + for (int i = thread_id * vecLen; i < nx; i += num_threads * max(vecLen, stateBitLen), ii++) { + uint8_t iState = 0; + if ((stateBitLen == 8) && (itr > 0)) { + iState = state[thread_id + (num_threads * ii)]; + if (iState == (uint8_t)0xff) continue; + } +#pragma unroll + for (int v = 0; v < max(vecLen, stateBitLen); v += vecLen) { + int iv = i + (num_threads * v); + if (iv >= nx) break; + + struct u32_vector x_u32_vec; + struct u16_vector x_u16_vec; + if (sizeof(T) == 4) { + load_u32_vector(x_u32_vec, (const uint32_t*)x, iv); + } else { + load_u16_vector(x_u16_vec, (const uint16_t*)x, iv); + } +#pragma unroll + for (int u = 0; u < vecLen; u++) { + int ivu = iv + u; + if (ivu >= nx) break; + + uint8_t mask = (uint8_t)0x1 << (v + u); + if ((stateBitLen == 8) && (iState & mask)) continue; + + uint32_t xi; + if (sizeof(T) == 4) { + xi = get_element_from_u32_vector(x_u32_vec, u); + } else { + xi = get_element_from_u16_vector(x_u16_vec, u); + } + if ((xi > hint) && (itr == 0)) { + if (stateBitLen == 8) { iState |= mask; } + } else if (xi < threshold) { + if (stateBitLen == 8) { + // If the condition is already met, record the index. 
+ output[atomicAdd(output_count, 1)] = ivu; + iState |= mask; + } + } else { + uint32_t k = (xi - threshold) >> shift; // 0 <= k + if (k >= num_bins) { + if (stateBitLen == 8) { iState |= mask; } + } else if (k + 1 < num_bins) { + // Update histogram + atomicAdd(&(hist[k + 1]), 1); + } + } + } + } + if (stateBitLen == 8) { state[thread_id + (num_threads * ii)] = iState; } + } + __syncthreads(); +} + +// +template +__device__ inline void select_best_index_for_next_threshold(uint32_t topk, + uint32_t threshold, + uint32_t max_threshold, + uint32_t nx_below_threshold, + uint32_t num_bins, + uint32_t shift, + const uint32_t* hist, // [num_bins] + uint32_t* best_index, + uint32_t* best_csum) +{ + // Scan the histogram ('hist') and compute csum. Then, find the largest + // index under the condition that the sum of the number of elements found + // so far ('nx_below_threshold') and the csum value does not exceed the + // topk value. + typedef BlockScan BlockScanT; + __shared__ typename BlockScanT::TempStorage temp_storage; + + uint32_t my_index = 0xffffffff; + uint32_t my_csum = 0; + if (num_bins <= blockDim_x) { + uint32_t csum = 0; + if (threadIdx.x < num_bins) { csum = hist[threadIdx.x]; } + BlockScanT(temp_storage).InclusiveSum(csum, csum); + if (threadIdx.x < num_bins) { + uint32_t index = threadIdx.x; + if ((nx_below_threshold + csum <= topk) && (threshold + (index << shift) <= max_threshold)) { + my_index = index; + my_csum = csum; + } + } + } else { + if (num_bins == 2048) { + constexpr int n_data = 2048 / blockDim_x; + uint32_t csum[n_data]; + for (int i = 0; i < n_data; i++) { + csum[i] = hist[i + (n_data * threadIdx.x)]; + } + BlockScanT(temp_storage).InclusiveSum(csum, csum); + for (int i = n_data - 1; i >= 0; i--) { + if (nx_below_threshold + csum[i] > topk) continue; + uint32_t index = i + (n_data * threadIdx.x); + if (threshold + (index << shift) > max_threshold) continue; + my_index = index; + my_csum = csum[i]; + break; + } + } else if (num_bins == 1024) 
{ + constexpr int n_data = 1024 / blockDim_x; + uint32_t csum[n_data]; + for (int i = 0; i < n_data; i++) { + csum[i] = hist[i + (n_data * threadIdx.x)]; + } + BlockScanT(temp_storage).InclusiveSum(csum, csum); + for (int i = n_data - 1; i >= 0; i--) { + if (nx_below_threshold + csum[i] > topk) continue; + uint32_t index = i + (n_data * threadIdx.x); + if (threshold + (index << shift) > max_threshold) continue; + my_index = index; + my_csum = csum[i]; + break; + } + } + } + if (threadIdx.x < num_bins) { + int laneid = 31 - __clz(__ballot_sync(0xffffffff, (my_index != 0xffffffff))); + if ((threadIdx.x & 0x1f) == laneid) { + uint32_t old_index = atomicMax(best_index, my_index); + if (old_index < my_index) { atomicMax(best_csum, my_csum); } + } + } + __syncthreads(); +} + +// +template +__device__ inline void output_index_below_threshold(uint32_t topk, + uint32_t thread_id, + uint32_t num_threads, + uint32_t threshold, + uint32_t nx_below_threshold, + const T* x, // [nx,] + uint32_t nx, + const uint8_t* state, + uint32_t* output, // [topk] + uint32_t* output_count, + uint32_t* output_count_eq) +{ + int ii = 0; + for (int i = thread_id * vecLen; i < nx; i += num_threads * max(vecLen, stateBitLen), ii++) { + uint8_t iState = 0; + if (stateBitLen == 8) { + iState = state[thread_id + (num_threads * ii)]; + if (iState == (uint8_t)0xff) continue; + } +#pragma unroll + for (int v = 0; v < max(vecLen, stateBitLen); v += vecLen) { + int iv = i + (num_threads * v); + if (iv >= nx) break; + + struct u32_vector u32_vec; + struct u16_vector u16_vec; + if (sizeof(T) == 4) { + load_u32_vector(u32_vec, (const uint32_t*)x, iv); + } else { + load_u16_vector(u16_vec, (const uint16_t*)x, iv); + } +#pragma unroll + for (int u = 0; u < vecLen; u++) { + int ivu = iv + u; + if (ivu >= nx) break; + + uint8_t mask = (uint8_t)0x1 << (v + u); + if ((stateBitLen == 8) && (iState & mask)) continue; + + uint32_t xi; + if (sizeof(T) == 4) { + xi = get_element_from_u32_vector(u32_vec, u); + } else { 
+ xi = get_element_from_u16_vector(u16_vec, u); + } + if (xi < threshold) { + output[atomicAdd(output_count, 1)] = ivu; + } else if (xi == threshold) { + // (*) If the value is equal to the threshold, the index + // processed first is recorded. Cause of non-determinism. + if (nx_below_threshold + atomicAdd(output_count_eq, 1) < topk) { + output[atomicAdd(output_count, 1)] = ivu; + } + } + } + } + } +} + +// +template +__device__ inline void swap(T& val1, T& val2) +{ + T val0 = val1; + val1 = val2; + val2 = val0; +} + +// +template +__device__ inline bool swap_if_needed(K& key1, K& key2) +{ + if (key1 > key2) { + swap(key1, key2); + return true; + } + return false; +} + +// +template +__device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2) +{ + if (key1 > key2) { + swap(key1, key2); + swap(val1, val2); + return true; + } + return false; +} + +// +template +__device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2, bool ascending) +{ + if (key1 == key2) { return false; } + if ((key1 > key2) == ascending) { + swap(key1, key2); + swap(val1, val2); + return true; + } + return false; +} + +// +template +__device__ inline T max_value_of(); +template <> +__device__ inline float max_value_of() +{ + return FLT_MAX; +} +template <> +__device__ inline uint32_t max_value_of() +{ + return ~0u; +} + +template +__device__ __host__ inline uint32_t get_state_size(uint32_t len_x) +{ + const uint32_t num_threads = blockDim_x; + if (stateBitLen == 8) { + uint32_t numElements_perThread = (len_x + num_threads - 1) / num_threads; + uint32_t numState_perThread = (numElements_perThread + stateBitLen - 1) / stateBitLen; + return numState_perThread * num_threads; + } + return 0; +} + +// +template +__device__ inline void topk_cta_11_core(uint32_t topk, + uint32_t len_x, + const uint32_t* _x, // [size_batch, ld_x,] + const uint32_t* _in_vals, // [size_batch, ld_iv,] + uint32_t* _y, // [size_batch, ld_y,] + uint32_t* _out_vals, // [size_batch, ld_ov,] + uint8_t* 
_state, // [size_batch, ...,] + uint32_t* _hint, + bool sort, + uint32_t* _smem) +{ + uint32_t* smem_out_vals = _smem; + uint32_t* hist = &(_smem[2 * maxTopk]); + uint32_t* best_index = &(_smem[2 * maxTopk + 2048]); + uint32_t* best_csum = &(_smem[2 * maxTopk + 2048 + 3]); + + const uint32_t num_threads = blockDim_x; + const uint32_t thread_id = threadIdx.x; + uint32_t nx = len_x; + const uint32_t* x = _x; + const uint32_t* in_vals = NULL; + if (_in_vals) { in_vals = _in_vals; } + uint32_t* y = NULL; + if (_y) { y = _y; } + uint32_t* out_vals = NULL; + if (_out_vals) { out_vals = _out_vals; } + uint8_t* state = _state; + uint32_t hint = (_hint == NULL ? ~0u : *_hint); + + // Initialize shared memory + for (int i = 2 * maxTopk + thread_id; i < 2 * maxTopk + 2048 + 8; i += num_threads) { + _smem[i] = 0; + } + uint32_t* output_count = &(_smem[2 * maxTopk + 2048 + 6]); + uint32_t* output_count_eq = &(_smem[2 * maxTopk + 2048 + 7]); + uint32_t threshold = 0; + uint32_t nx_below_threshold = 0; + __syncthreads(); + + // + // Search for the maximum threshold that satisfies "(x < threshold).sum() <= topk". + // +#pragma unroll + for (int j = 0; j < 3; j += 1) { + uint32_t num_bins; + uint32_t shift; + update_histogram(j, + thread_id, + num_threads, + hint, + threshold, + num_bins, + shift, + x, + nx, + hist, + state, + smem_out_vals, + output_count); + + select_best_index_for_next_threshold(topk, + threshold, + hint, + nx_below_threshold, + num_bins, + shift, + hist, + best_index + j, + best_csum + j); + + threshold += (best_index[j] << shift); + nx_below_threshold += best_csum[j]; + if (nx_below_threshold == topk) break; + } + + if ((_hint != NULL) && (thread_id == 0)) { *_hint = min(threshold, hint); } + + // + // Output index that satisfies "x[i] < threshold". 
+ // + output_index_below_threshold(topk, + thread_id, + num_threads, + threshold, + nx_below_threshold, + x, + nx, + state, + smem_out_vals, + output_count, + output_count_eq); + __syncthreads(); + +#ifdef CUANN_DEBUG + if (thread_id == 0 && output_count[0] < topk) { + RAFT_LOG_DEBUG( + "# i_batch:%d, topk:%d, output_count:%d, nx_below_threshold:%d, threshold:%08x\n", + i_batch, + topk, + output_count[0], + nx_below_threshold, + threshold); + } +#endif + + if (!sort) { + for (int k = thread_id; k < topk; k += blockDim_x) { + uint32_t i = smem_out_vals[k]; + if (y) { y[k] = x[i]; } + if (out_vals) { + if (in_vals) { + out_vals[k] = in_vals[i]; + } else { + out_vals[k] = i; + } + } + } + return; + } + + constexpr int numTopkPerThread = maxTopk / numSortThreads; + float my_keys[numTopkPerThread]; + uint32_t my_vals[numTopkPerThread]; + + // Read keys and values to registers + if (thread_id < numSortThreads) { + for (int i = 0; i < numTopkPerThread; i++) { + int k = thread_id + (numSortThreads * i); + if (k < topk) { + int j = smem_out_vals[k]; + my_keys[i] = ((float*)x)[j]; + if (in_vals) { + my_vals[i] = in_vals[j]; + } else { + my_vals[i] = j; + } + } else { + my_keys[i] = FLT_MAX; + my_vals[i] = 0xffffffffU; + } + } + } + + uint32_t mask = 1; + + // Sorting by thread + if (thread_id < numSortThreads) { + bool ascending = ((thread_id & mask) == 0); + if (numTopkPerThread == 3) { + swap_if_needed(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending); + swap_if_needed(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending); + swap_if_needed(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending); + } else { + for (int j = 0; j < numTopkPerThread / 2; j += 1) { +#pragma unroll + for (int i = 0; i < numTopkPerThread; i += 2) { + swap_if_needed( + my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); + } +#pragma unroll + for (int i = 1; i < numTopkPerThread - 1; i += 2) { + swap_if_needed( + my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 
1], ascending); + } + } + } + } + + // Bitonic Sorting + while (mask < numSortThreads) { + uint32_t next_mask = mask << 1; + + for (uint32_t curr_mask = mask; curr_mask > 0; curr_mask >>= 1) { + bool ascending = ((thread_id & curr_mask) == 0) == ((thread_id & next_mask) == 0); + if (curr_mask >= 32) { + // inter warp + uint32_t* smem_vals = _smem; // [numTopkPerThread, numSortThreads] + float* smem_keys = (float*)(_smem + numTopkPerThread * numSortThreads); + __syncthreads(); + if (thread_id < numSortThreads) { +#pragma unroll + for (int i = 0; i < numTopkPerThread; i++) { + smem_keys[thread_id + (numSortThreads * i)] = my_keys[i]; + smem_vals[thread_id + (numSortThreads * i)] = my_vals[i]; + } + } + __syncthreads(); + if (thread_id < numSortThreads) { +#pragma unroll + for (int i = 0; i < numTopkPerThread; i++) { + float opp_key = smem_keys[(thread_id ^ curr_mask) + (numSortThreads * i)]; + uint32_t opp_val = smem_vals[(thread_id ^ curr_mask) + (numSortThreads * i)]; + swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); + } + } + } else { + // intra warp + if (thread_id < numSortThreads) { +#pragma unroll + for (int i = 0; i < numTopkPerThread; i++) { + float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask); + uint32_t opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask); + swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); + } + } + } + } + + if (thread_id < numSortThreads) { + bool ascending = ((thread_id & next_mask) == 0); + if (numTopkPerThread == 3) { + swap_if_needed(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending); + swap_if_needed(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending); + swap_if_needed(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending); + } else { +#pragma unroll + for (uint32_t curr_mask = numTopkPerThread / 2; curr_mask > 0; curr_mask >>= 1) { +#pragma unroll + for (int i = 0; i < numTopkPerThread; i++) { + int j = i ^ curr_mask; + if (i > j) continue; + 
swap_if_needed( + my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending); + } + } + } + } + mask = next_mask; + } + + // Write sorted keys and values + if (thread_id < numSortThreads) { + for (int i = 0; i < numTopkPerThread; i++) { + int k = i + (numTopkPerThread * thread_id); + if (k < topk) { + if (y) { y[k] = ((uint32_t*)my_keys)[i]; } + if (out_vals) { out_vals[k] = my_vals[i]; } + } + } + } +} + +namespace { + +// +constexpr std::uint32_t NUM_THREADS = 1024; // DO NOT CHANGE +constexpr std::uint32_t STATE_BIT_LENGTH = 8; // 0: state not used, 8: state used +constexpr std::uint32_t MAX_VEC_LENGTH = 4; // 1, 2, 4 or 8 + +// +// +int _get_vecLen(uint32_t maxSamples, int maxVecLen = MAX_VEC_LENGTH) +{ + int vecLen = min(maxVecLen, (int)MAX_VEC_LENGTH); + while ((maxSamples % vecLen) != 0) { + vecLen /= 2; + } + return vecLen; +} +} // unnamed namespace + +template +__launch_bounds__(1024, 1) __global__ + void kern_topk_cta_11(uint32_t topk, + uint32_t size_batch, + uint32_t len_x, + const uint32_t* _x, // [size_batch, ld_x,] + uint32_t ld_x, + const uint32_t* _in_vals, // [size_batch, ld_iv,] + uint32_t ld_iv, + uint32_t* _y, // [size_batch, ld_y,] + uint32_t ld_y, + uint32_t* _out_vals, // [size_batch, ld_ov,] + uint32_t ld_ov, + uint8_t* _state, // [size_batch, ...,] + uint32_t* _hints, // [size_batch,] + bool sort) +{ + uint32_t i_batch = blockIdx.x; + if (i_batch >= size_batch) return; + __shared__ uint32_t _smem[2 * maxTopk + 2048 + 8]; + + topk_cta_11_core( + topk, + len_x, + (_x == NULL ? NULL : _x + i_batch * ld_x), + (_in_vals == NULL ? NULL : _in_vals + i_batch * ld_iv), + (_y == NULL ? NULL : _y + i_batch * ld_y), + (_out_vals == NULL ? NULL : _out_vals + i_batch * ld_ov), + (_state == NULL ? NULL : _state + i_batch * get_state_size(len_x)), + (_hints == NULL ? 
NULL : _hints + i_batch), + sort, + _smem); +} + +// +size_t inline _cuann_find_topk_bufferSize(uint32_t topK, + uint32_t sizeBatch, + uint32_t numElements, + cudaDataType_t sampleDtype) +{ + constexpr int numThreads = NUM_THREADS; + constexpr int stateBitLen = STATE_BIT_LENGTH; + assert(stateBitLen == 0 || stateBitLen == 8); + + size_t workspaceSize = 1; + // state + if (stateBitLen == 8) { + workspaceSize = _cuann_aligned( + sizeof(uint8_t) * get_state_size(numElements) * sizeBatch); + } + + return workspaceSize; +} + +inline void _cuann_find_topk(uint32_t topK, + uint32_t sizeBatch, + uint32_t numElements, + const float* inputKeys, // [sizeBatch, ldIK,] + uint32_t ldIK, // (*) ldIK >= numElements + const uint32_t* inputVals, // [sizeBatch, ldIV,] + uint32_t ldIV, // (*) ldIV >= numElements + float* outputKeys, // [sizeBatch, ldOK,] + uint32_t ldOK, // (*) ldOK >= topK + uint32_t* outputVals, // [sizeBatch, ldOV,] + uint32_t ldOV, // (*) ldOV >= topK + void* workspace, + bool sort, + uint32_t* hints, + cudaStream_t stream) +{ + assert(ldIK >= numElements); + assert(ldIV >= numElements); + assert(ldOK >= topK); + assert(ldOV >= topK); + + constexpr int numThreads = NUM_THREADS; + constexpr int stateBitLen = STATE_BIT_LENGTH; + assert(stateBitLen == 0 || stateBitLen == 8); + + uint8_t* state = NULL; + if (stateBitLen == 8) { state = (uint8_t*)workspace; } + + dim3 threads(numThreads, 1, 1); + dim3 blocks(sizeBatch, 1, 1); + + void (*cta_kernel)(uint32_t, + uint32_t, + uint32_t, + const uint32_t*, + uint32_t, + const uint32_t*, + uint32_t, + uint32_t*, + uint32_t, + uint32_t*, + uint32_t, + uint8_t*, + uint32_t*, + bool) = nullptr; + + // V:vecLen, K:maxTopk, T:numSortThreads +#define SET_KERNEL_VKT(V, K, T) \ + do { \ + assert(numThreads >= T); \ + assert((K % T) == 0); \ + assert((K / T) <= 4); \ + cta_kernel = kern_topk_cta_11; \ + } while (0) + + // V: vecLen +#define SET_KERNEL_V(V) \ + do { \ + if (topK <= 32) { \ + SET_KERNEL_VKT(V, 32, 32); \ + } else if 
(topK <= 64) { \ + SET_KERNEL_VKT(V, 64, 32); \ + } else if (topK <= 96) { \ + SET_KERNEL_VKT(V, 96, 32); \ + } else if (topK <= 128) { \ + SET_KERNEL_VKT(V, 128, 32); \ + } else if (topK <= 192) { \ + SET_KERNEL_VKT(V, 192, 64); \ + } else if (topK <= 256) { \ + SET_KERNEL_VKT(V, 256, 64); \ + } else if (topK <= 384) { \ + SET_KERNEL_VKT(V, 384, 128); \ + } else if (topK <= 512) { \ + SET_KERNEL_VKT(V, 512, 128); \ + } else if (topK <= 768) { \ + SET_KERNEL_VKT(V, 768, 256); \ + } else if (topK <= 1024) { \ + SET_KERNEL_VKT(V, 1024, 256); \ + } \ + /* else if (topK <= 1536) { SET_KERNEL_VKT(V, 1536, 512); } */ \ + /* else if (topK <= 2048) { SET_KERNEL_VKT(V, 2048, 512); } */ \ + /* else if (topK <= 3072) { SET_KERNEL_VKT(V, 3072, 1024); } */ \ + /* else if (topK <= 4096) { SET_KERNEL_VKT(V, 4096, 1024); } */ \ + else { \ + RAFT_LOG_DEBUG( \ + "[ERROR] (%s, %d) topk must be lower than or equla to 1024.\n", __func__, __LINE__); \ + exit(-1); \ + } \ + } while (0) + + int _vecLen = _get_vecLen(ldIK, 2); + if (_vecLen == 2) { + SET_KERNEL_V(2); + } else if (_vecLen == 1) { + SET_KERNEL_V(1); + } + + cta_kernel<<>>(topK, + sizeBatch, + numElements, + (const uint32_t*)inputKeys, + ldIK, + inputVals, + ldIV, + (uint32_t*)outputKeys, + ldOK, + outputVals, + ldOV, + state, + hints, + sort); + + return; +} +} // namespace raft::neighbors::experimental::cagra::detail \ No newline at end of file diff --git a/cpp/include/raft/neighbors/detail/cagra/utils.hpp b/cpp/include/raft/neighbors/detail/cagra/utils.hpp new file mode 100644 index 0000000000..3e329c9239 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/utils.hpp @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft::neighbors::experimental::cagra::detail { +namespace utils { +template +inline cudaDataType_t get_cuda_data_type(); +template <> +inline cudaDataType_t get_cuda_data_type() +{ + return CUDA_R_32F; +} +template <> +inline cudaDataType_t get_cuda_data_type() +{ + return CUDA_R_16F; +} +template <> +inline cudaDataType_t get_cuda_data_type() +{ + return CUDA_R_8I; +} +template <> +inline cudaDataType_t get_cuda_data_type() +{ + return CUDA_R_8U; +} +template <> +inline cudaDataType_t get_cuda_data_type() +{ + return CUDA_R_32U; +} +template <> +inline cudaDataType_t get_cuda_data_type() +{ + return CUDA_R_64U; +} + +template +constexpr unsigned size_of(); +template <> +_RAFT_HOST_DEVICE constexpr unsigned size_of() +{ + return 1; +} +template <> +_RAFT_HOST_DEVICE constexpr unsigned size_of() +{ + return 1; +} +template <> +_RAFT_HOST_DEVICE constexpr unsigned size_of() +{ + return 2; +} +template <> +_RAFT_HOST_DEVICE constexpr unsigned size_of() +{ + return 4; +} +template <> +_RAFT_HOST_DEVICE constexpr unsigned size_of() +{ + return 8; +} +template <> +_RAFT_HOST_DEVICE constexpr unsigned size_of() +{ + return 16; +} +template <> +_RAFT_HOST_DEVICE constexpr unsigned size_of() +{ + return 32; +} +template <> +_RAFT_HOST_DEVICE constexpr unsigned size_of() +{ + return 4; +} +template <> +_RAFT_HOST_DEVICE constexpr unsigned size_of() +{ + return 2; +} + +// max values for data types +template +union fp_conv { + BS_T bs; + FP_T fp; +}; +template 
+_RAFT_HOST_DEVICE inline T get_max_value(); +template <> +_RAFT_HOST_DEVICE inline float get_max_value() +{ + return FLT_MAX; +}; +template <> +_RAFT_HOST_DEVICE inline half get_max_value() +{ + return fp_conv{.bs = 0x7aff}.fp; +}; +template <> +_RAFT_HOST_DEVICE inline std::uint32_t get_max_value() +{ + return 0xffffffffu; +}; + +template +struct constexpr_max { + static const int value = A; +}; + +template +struct constexpr_max A), bool>> { + static const int value = B; +}; +} // namespace utils + +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/util/cache_util.cuh b/cpp/include/raft/util/cache_util.cuh index 4200be96e8..413e7522b1 100644 --- a/cpp/include/raft/util/cache_util.cuh +++ b/cpp/include/raft/util/cache_util.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -50,7 +50,7 @@ __global__ void get_vecs( if (tid < n_vec * n) { size_t out_col = tid / n_vec; // col idx size_t cache_col = cache_idx[out_col]; - if (cache_idx[out_col] >= 0) { + if (!std::is_signed::value || cache_idx[out_col] >= 0) { if (row + out_col * n_vec < (size_t)n_vec * n) { out[tid] = cache[row + cache_col * n_vec]; } } } diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 91050461ae..e805f53712 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -258,6 +258,7 @@ if(BUILD_TESTS) NAME NEIGHBORS_TEST PATH + test/neighbors/ann_cagra/test_float_uint32_t.cu test/neighbors/ann_ivf_flat/test_float_int64_t.cu test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh new file mode 100644 index 0000000000..385e9a80c0 --- /dev/null +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -0,0 +1,313 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "../test_utils.cuh" +#include "ann_utils.cuh" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include + +#include +#include +#include +#include + +namespace raft::neighbors::experimental::cagra { + +struct AnnCagraInputs { + int n_queries; + int n_rows; + int dim; + int k; + search_algo algo; + int max_queries; + int team_size; + int itopk_size; + int num_parents; + raft::distance::DistanceType metric; + bool host_dataset; + // std::optional + double min_recall; // = std::nullopt; +}; + +inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) +{ + std::vector algo = {"single-cta", "multi_cta", "multi_kernel", "auto"}; + os << "{n_queries=" << p.n_queries << ", dataset shape=" << p.n_rows << "x" << p.dim + << ", k=" << p.k << ", " << algo.at((int)p.algo) << ", max_queries=" << p.max_queries + << ", itopk_size=" << p.itopk_size << ", num_parents=" << p.num_parents + << ", metric=" << static_cast(p.metric) << (p.host_dataset ? 
", host" : ", device") << '}' + << std::endl; + return os; +} + +template +class AnnCagraTest : public ::testing::TestWithParam { + public: + AnnCagraTest() + : stream_(handle_.get_stream()), + ps(::testing::TestWithParam::GetParam()), + database(0, stream_), + search_queries(0, stream_) + { + } + + protected: + void testCagra() + { + size_t queries_size = ps.n_queries * ps.k; + std::vector indices_Cagra(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_Cagra(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + naive_knn(distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database.data(), + ps.n_queries, + ps.n_rows, + ps.dim, + ps.k, + ps.metric, + stream_); + update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + handle_.sync_stream(stream_); + } + + { + rmm::device_uvector distances_dev(queries_size, stream_); + rmm::device_uvector indices_dev(queries_size, stream_); + + { + cagra::index_params index_params; + index_params.metric = ps.metric; // Note: currently ony the cagra::index_params metric is + // not used for knn_graph building. 
+ cagra::search_params search_params; + search_params.algo = ps.algo; + search_params.max_queries = ps.max_queries; + search_params.team_size = ps.team_size; + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.n_rows, ps.dim); + + { + cagra::index index(handle_); + if (ps.host_dataset) { + auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + auto database_host_view = raft::make_host_matrix_view( + (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + index = cagra::build(handle_, index_params, database_host_view); + } else { + index = cagra::build(handle_, index_params, database_view); + }; + cagra::serialize(handle_, "cagra_index", index); + } + auto index = cagra::deserialize(handle_, "cagra_index"); + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.n_queries, ps.dim); + auto indices_out_view = + raft::make_device_matrix_view(indices_dev.data(), ps.n_queries, ps.k); + auto dists_out_view = + raft::make_device_matrix_view(distances_dev.data(), ps.n_queries, ps.k); + + cagra::search( + handle_, search_params, index, search_queries_view, indices_out_view, dists_out_view); + + update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_); + update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); + handle_.sync_stream(stream_); + } + // for (int i = 0; i < ps.n_queries; i++) { + // // std::cout << "query " << i << std::end; + // print_vector("T", indices_naive.data() + i * ps.k, ps.k, std::cout); + // print_vector("C", indices_Cagra.data() + i * ps.k, ps.k, std::cout); + // print_vector("T", distances_naive.data() + i * ps.k, ps.k, std::cout); + // print_vector("C", distances_Cagra.data() + i * ps.k, ps.k, std::cout); + // } + double min_recall = ps.min_recall; + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_Cagra, + distances_naive, 
+ distances_Cagra, + ps.n_queries, + ps.k, + 0.001, + min_recall)); + ASSERT_TRUE(eval_distances(handle_, + database.data(), + search_queries.data(), + indices_dev.data(), + distances_dev.data(), + ps.n_rows, + ps.dim, + ps.n_queries, + ps.k, + ps.metric, + 1.0e-4)); + } + } + + void SetUp() override + { + std::cout << "Resizing database: " << ps.n_rows * ps.dim << std::endl; + database.resize(((size_t)ps.n_rows) * ps.dim, stream_); + std::cout << "Done.\nResizing queries" << std::endl; + search_queries.resize(ps.n_queries * ps.dim, stream_); + std::cout << "Done.\nRuning rng" << std::endl; + raft::random::Rng r(1234ULL); + if constexpr (std::is_same{}) { + r.uniform(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_); + r.uniform(search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0), stream_); + } else { + r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_); + r.uniformInt(search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20), stream_); + } + handle_.sync_stream(stream_); + } + + void TearDown() override + { + handle_.sync_stream(stream_); + database.resize(0, stream_); + search_queries.resize(0, stream_); + } + + private: + raft::device_resources handle_; + rmm::cuda_stream_view stream_; + AnnCagraInputs ps; + rmm::device_uvector database; + rmm::device_uvector search_queries; +}; + +inline std::vector generate_inputs() +{ + // Todo(tfeher): MULTI_CTA tests a bug, consider disabling that mode. 
+ std::vector inputs = raft::util::itertools::product( + {100}, + {1000}, + {8}, + {1, 16, 33}, // k + {search_algo::SINGLE_CTA, search_algo::MULTI_KERNEL}, + {1, 10, 100}, // query size + {0}, + {64}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false}, + {0.995}); + + auto inputs2 = + raft::util::itertools::product({100}, + {1000}, + {2, 4, 8, 64, 128, 196, 256, 512, 1024}, // dim + {16}, + {search_algo::AUTO}, + {10}, + {0}, + {64}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false}, + {0.995}); + inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); + inputs2 = + raft::util::itertools::product({100}, + {1000}, + {64}, + {16}, + {search_algo::AUTO}, + {10}, + {0, 4, 8, 16, 32}, // team_size + {64}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false}, + {0.995}); + inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); + + inputs2 = + raft::util::itertools::product({100}, + {1000}, + {64}, + {16}, + {search_algo::AUTO}, + {10}, + {0}, // team_size + {32, 64, 128, 256, 512, 768}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false}, + {0.995}); + inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); + + inputs2 = + raft::util::itertools::product({100}, + {10000, 20000}, + {30}, + {10}, + {search_algo::AUTO}, + {10}, + {0}, // team_size + {64}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false, true}, + {0.995}); + inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); + + inputs2 = + raft::util::itertools::product({100}, + {10000, 20000}, + {30}, + {10}, + {search_algo::AUTO}, + {10}, + {0}, // team_size + {64}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false, true}, + {0.995}); + inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); + + return inputs; +} + +const std::vector inputs = generate_inputs(); + +} // namespace raft::neighbors::experimental::cagra \ No newline at end of file diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu 
b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu new file mode 100644 index 0000000000..71a83e2cca --- /dev/null +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_cagra.cuh" + +// #if defined RAFT_DISTANCE_COMPILED +// #include +// #endif + +namespace raft::neighbors::experimental::cagra { + +typedef AnnCagraTest AnnCagraTestF; +TEST_P(AnnCagraTestF, AnnCagra) { this->testCagra(); } + +INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::experimental::cagra diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 4b07db32f4..fc448f014f 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -25,8 +26,11 @@ #include #include +#include + #include "../test_utils.cuh" #include +#include namespace raft::neighbors { @@ -164,4 +168,49 @@ auto eval_neighbours(const std::vector& expected_idx, return testing::AssertionSuccess(); } +template +auto eval_distances(raft::device_resources const& handle, + const T* x, // dataset, n_rows * n_cols + const T* queries, // n_queries * n_cols + const IdxT* neighbors, // n_queries * k + const DistT* distances, // n_queries *k + size_t n_rows, + size_t n_cols, + size_t n_queries, + 
uint32_t k, + raft::distance::DistanceType metric, + double eps) -> testing::AssertionResult +{ + // for each vector, we calculate the actual distance to the k neighbors + + for (size_t i = 0; i < n_queries; i++) { + auto y = raft::make_device_matrix(handle, k, n_cols); + auto naive_dist = raft::make_device_matrix(handle, 1, k); + + raft::matrix::copyRows( + x, k, n_cols, y.data_handle(), neighbors + i * k, k, handle.get_stream(), true); + + dim3 block_dim(16, 32, 1); + auto grid_y = + static_cast(std::min(raft::ceildiv(k, block_dim.y), 32768)); + dim3 grid_dim(raft::ceildiv(n_rows, block_dim.x), grid_y, 1); + + naive_distance_kernel<<>>( + naive_dist.data_handle(), queries + i * n_cols, y.data_handle(), 1, k, n_cols, metric); + + if (!devArrMatch(distances + i * k, + naive_dist.data_handle(), + naive_dist.size(), + CompareApprox(eps))) { + std::cout << n_rows << "x" << n_cols << ", " << k << std::endl; + std::cout << "query " << i << std::endl; + print_vector(" indices", neighbors + i * k, k, std::cout); + print_vector("n dist", distances + i * k, k, std::cout); + print_vector("c dist", naive_dist.data_handle(), naive_dist.size(), std::cout); + + return testing::AssertionFailure(); + } + } + return testing::AssertionSuccess(); +} } // namespace raft::neighbors