diff --git a/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp b/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp index 176309c8ce..3fb7b3005b 100644 --- a/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp +++ b/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp @@ -126,7 +126,8 @@ inline dtype_t get_numpy_dtype() } #if defined(_RAFT_HAS_CUDA) -template , bool> = true> +template , half>, bool> = true> inline dtype_t get_numpy_dtype() { return {RAFT_NUMPY_HOST_ENDIAN_CHAR, 'e', sizeof(T)}; diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index b8258297e6..b7e362f704 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -279,62 +280,6 @@ index build(raft::resources const& res, return detail::build(res, params, dataset); } -/** - * @brief Search ANN using the constructed index. - * - * See the [cagra::build](#cagra::build) documentation for a usage example. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] res raft resources - * @param[in] params configure the search - * @param[in] idx cagra index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ -template -void search(raft::resources const& res, - const search_params& params, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must equal k"); - RAFT_EXPECTS(queries.extent(1) == idx.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - using internal_IdxT = typename std::make_unsigned::type; - auto queries_internal = raft::make_device_matrix_view( - queries.data_handle(), queries.extent(0), queries.extent(1)); - auto neighbors_internal = raft::make_device_matrix_view( - reinterpret_cast(neighbors.data_handle()), - neighbors.extent(0), - neighbors.extent(1)); - auto distances_internal = raft::make_device_matrix_view( - distances.data_handle(), distances.extent(0), distances.extent(1)); - - cagra::detail::search_main(res, - params, - idx, - queries_internal, - neighbors_internal, - distances_internal, - raft::neighbors::filtering::none_cagra_sample_filter()); -} - /** * @brief Search ANN using the constructed index with the given sample filter. * @@ -401,10 +346,40 @@ void search_with_filtering(raft::resources const& res, auto distances_internal = raft::make_device_matrix_view( distances.data_handle(), distances.extent(0), distances.extent(1)); - cagra::detail::search_main( + return cagra::detail::search_main( res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); } +/** + * @brief Search ANN using the constructed index. + * + * See the [cagra::build](#cagra::build) documentation for a usage example. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] params configure the search + * @param[in] idx cagra index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search(raft::resources const& res, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + using none_filter_type = raft::neighbors::filtering::none_cagra_sample_filter; + return cagra::search_with_filtering( + res, params, idx, queries, neighbors, distances, none_filter_type{}); +} + /** @} */ // end group cagra } // namespace raft::neighbors::cagra diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 0f574ae5bb..807f89fd65 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -17,6 +17,7 @@ #pragma once #include "ann_types.hpp" +#include "dataset.hpp" #include #include @@ -35,6 +36,7 @@ #include #include #include + namespace raft::neighbors::cagra { /** * @addtogroup cagra @@ -61,6 +63,12 @@ struct index_params : ann::index_params { graph_build_algo build_algo = graph_build_algo::IVF_PQ; /** Number of Iterations to run if building with NN_DESCENT */ size_t nn_descent_niter = 20; + /** + * Specify compression params if compression is desired. + * + * NOTE: this is experimental new API, consider it unsafe. + */ + std::optional compression = std::nullopt; }; enum class search_algo { @@ -145,25 +153,37 @@ struct index : ann::index { /** Total length of the index (number of vectors). */ [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { - return dataset_view_.extent(0) ? dataset_view_.extent(0) : graph_view_.extent(0); + auto data_rows = dataset_->n_rows(); + return data_rows > 0 ? data_rows : graph_view_.extent(0); } /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t - { - return dataset_view_.extent(1); - } + [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { return dataset_->dim(); } /** Graph degree */ [[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t { return graph_view_.extent(1); } - /** Dataset [size, dim] */ - [[nodiscard]] inline auto dataset() const noexcept + /** + * DEPRECATED: please use data() instead. + * If you need to query dataset dimensions, use the dim() and size() of the cagra index. + * The data_handle() is not always available: you need to do a dynamic_cast to the expected + * dataset type at runtime. + */ + [[nodiscard]] [[deprecated("Use data()")]] inline auto dataset() const noexcept -> device_matrix_view { - return dataset_view_; + auto p = dynamic_cast*>(dataset_.get()); + if (p != nullptr) { return p->view(); } + auto d = dataset_->dim(); + return make_device_strided_matrix_view(nullptr, 0, d, d); + } + + /** Dataset [size, dim] */ + [[nodiscard]] inline auto data() const noexcept -> const neighbors::dataset& + { + return *dataset_; } /** neighborhood graph [size, graph-degree] */ @@ -185,7 +205,7 @@ struct index : ann::index { raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded) : ann::index(), metric_(metric), - dataset_(make_device_matrix(res, 0, 0)), + dataset_(new neighbors::empty_dataset(0)), graph_(make_device_matrix(res, 0, 0)) { } @@ -251,12 +271,11 @@ struct index : ann::index { mdspan, row_major, graph_accessor> knn_graph) : ann::index(), metric_(metric), - dataset_(make_device_matrix(res, 0, 0)), + dataset_(make_aligned_dataset(res, dataset, 16)), graph_(make_device_matrix(res, 0, 0)) { RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), "Dataset and knn_graph must have equal number of rows"); - update_dataset(res, dataset); update_graph(res, knn_graph); resource::sync_stream(res); } @@ -271,21 +290,14 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - if (dataset.extent(1) * sizeof(T) % 16 != 0) { - RAFT_LOG_DEBUG("Creating a padded copy of CAGRA dataset in device memory"); - copy_padded(res, dataset); - } else { - dataset_view_ = make_device_strided_matrix_view( - dataset.data_handle(), dataset.extent(0), dataset.extent(1), dataset.extent(1)); - } + dataset_ = make_aligned_dataset(res, dataset, 16); } /** Set the dataset reference explicitly to a device matrix view with padding. */ - void update_dataset(raft::resources const&, + void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - RAFT_EXPECTS(dataset.stride(0) * sizeof(T) % 16 == 0, "Incorrect data padding."); - dataset_view_ = dataset; + dataset_ = make_aligned_dataset(res, dataset, 16); } /** @@ -296,8 +308,22 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::host_matrix_view dataset) { - RAFT_LOG_DEBUG("Copying CAGRA dataset from host to device"); - copy_padded(res, dataset); + dataset_ = make_aligned_dataset(res, dataset, 16); + } + + /** Replace the dataset with a new dataset. */ + template + auto update_dataset(raft::resources const& res, DatasetT&& dataset) + -> std::enable_if_t, DatasetT>> + { + dataset_ = std::make_unique(std::move(dataset)); + } + + template + auto update_dataset(raft::resources const& res, std::unique_ptr&& dataset) + -> std::enable_if_t, DatasetT>> + { + dataset_ = std::move(dataset); } /** @@ -334,26 +360,10 @@ struct index : ann::index { } private: - /** Create a device copy of the dataset, and pad it if necessary. */ - template - void copy_padded(raft::resources const& res, - mdspan, row_major, data_accessor> dataset) - { - detail::copy_with_padding(res, dataset_, dataset); - - dataset_view_ = make_device_strided_matrix_view( - dataset_.data_handle(), dataset_.extent(0), dataset.extent(1), dataset_.extent(1)); - RAFT_LOG_DEBUG("CAGRA dataset strided matrix view %zux%zu, stride %zu", - static_cast(dataset_view_.extent(0)), - static_cast(dataset_view_.extent(1)), - static_cast(dataset_view_.stride(0))); - } - raft::distance::DistanceType metric_; - raft::device_matrix dataset_; raft::device_matrix graph_; - raft::device_matrix_view dataset_view_; raft::device_matrix_view graph_view_; + std::unique_ptr> dataset_; }; /** @} */ diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp new file mode 100644 index 0000000000..e7a3ba97a4 --- /dev/null +++ b/cpp/include/raft/neighbors/dataset.hpp @@ -0,0 +1,330 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include // get_device_for_address +#include // rounding up + +#include +#include +#include + +#ifdef __cpp_lib_bitops +#include +#endif + +namespace raft::neighbors { + +/** Two-dimensional dataset; maybe owning, maybe compressed, maybe strided. */ +template +struct dataset { + using index_type = IdxT; + /** Size of the dataset. */ + [[nodiscard]] virtual auto n_rows() const noexcept -> index_type = 0; + /** Dimensionality of the dataset. */ + [[nodiscard]] virtual auto dim() const noexcept -> uint32_t = 0; + /** Whether the object owns the data. */ + [[nodiscard]] virtual auto is_owning() const noexcept -> bool = 0; + virtual ~dataset() noexcept = default; +}; + +template +struct empty_dataset : public dataset { + using index_type = IdxT; + uint32_t suggested_dim; + explicit empty_dataset(uint32_t dim) noexcept : suggested_dim(dim) {} + [[nodiscard]] auto n_rows() const noexcept -> index_type final { return 0; } + [[nodiscard]] auto dim() const noexcept -> uint32_t final { return suggested_dim; } + [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } +}; + +template +struct strided_dataset : public dataset { + using index_type = IdxT; + using value_type = DataT; + using view_type = device_matrix_view; + [[nodiscard]] auto n_rows() const noexcept -> index_type final { return view().extent(0); } + [[nodiscard]] auto dim() const noexcept -> uint32_t final + { + return static_cast(view().extent(1)); + } + /** Leading dimension of the dataset. */ + [[nodiscard]] constexpr auto stride() const noexcept -> uint32_t + { + auto v = view(); + return static_cast(v.stride(0) > 0 ? v.stride(0) : v.extent(1)); + } + /** Get the view of the data. */ + [[nodiscard]] virtual auto view() const noexcept -> view_type; +}; + +template +struct non_owning_dataset : public strided_dataset { + using index_type = IdxT; + using value_type = DataT; + using typename strided_dataset::view_type; + view_type data; + explicit non_owning_dataset(view_type v) noexcept : data(v) {} + [[nodiscard]] auto is_owning() const noexcept -> bool final { return false; } + [[nodiscard]] auto view() const noexcept -> view_type final { return data; }; +}; + +template +struct owning_dataset : public strided_dataset { + using index_type = IdxT; + using value_type = DataT; + using typename strided_dataset::view_type; + using storage_type = + mdarray, LayoutPolicy, ContainerPolicy>; + using mapping_type = typename view_type::mapping_type; + storage_type data; + mapping_type view_mapping; + owning_dataset(storage_type&& store, mapping_type view_mapping) noexcept + : data{std::move(store)}, view_mapping{view_mapping} + { + } + + [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } + [[nodiscard]] auto view() const noexcept -> view_type final + { + return view_type{data.data_handle(), view_mapping}; + }; +}; + +/** + * @brief Contstruct a strided matrix from any mdarray or mdspan. + * + * This function constructs a non-owning view if the input satisfied two conditions: + * + * 1) The data is accessible from the current device + * 2) The memory layout is the same as expected (row-major matrix with the required stride) + * + * Otherwise, this function constructs an owning device matrix and copies the data. + * When the data is copied, padding elements are filled with zeroes. + * + * @tparam SrcT the source mdarray or mdspan + * + * @param[in] res raft resources handle + * @param[in] src the source mdarray or mdspan + * @param[in] required_stride the leading dimension (in elements) + * @return maybe owning current-device-accessible strided matrix + */ +template +auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t required_stride) + -> std::unique_ptr> +{ + using extents_type = typename SrcT::extents_type; + using value_type = typename SrcT::value_type; + using index_type = typename SrcT::index_type; + using layout_type = typename SrcT::layout_type; + static_assert(extents_type::rank() == 2, "The input must be a matrix."); + static_assert(std::is_same_v || + std::is_same_v> || + std::is_same_v, + "The input must be row-major"); + RAFT_EXPECTS(src.extent(1) <= required_stride, + "The input row length must be not larger than the desired stride."); + cudaPointerAttributes ptr_attrs; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&ptr_attrs, src.data_handle())); + auto* device_ptr = reinterpret_cast(ptr_attrs.devicePointer); + const uint32_t src_stride = src.stride(0) > 0 ? src.stride(0) : src.extent(1); + const bool device_accessible = device_ptr != nullptr; + const bool row_major = src.stride(1) <= 1; + const bool stride_matches = required_stride == src_stride; + + if (device_accessible && row_major && stride_matches) { + // Everything matches: make a non-owning dataset + return std::make_unique>( + make_device_strided_matrix_view( + device_ptr, src.extent(0), src.extent(1), required_stride)); + } + // Something is wrong: have to make a copy and produce an owning dataset + auto out_layout = + make_strided_layout(src.extents(), std::array{required_stride, 1}); + auto out_array = make_device_matrix(res, src.extent(0), required_stride); + + using out_mdarray_type = decltype(out_array); + using out_layout_type = typename out_mdarray_type::layout_type; + using out_container_policy_type = typename out_mdarray_type::container_policy_type; + using out_owning_type = + owning_dataset; + + RAFT_CUDA_TRY(cudaMemsetAsync(out_array.data_handle(), + 0, + out_array.size() * sizeof(value_type), + resource::get_cuda_stream(res))); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(), + sizeof(value_type) * required_stride, + src.data_handle(), + sizeof(value_type) * src_stride, + sizeof(value_type) * src.extent(1), + src.extent(0), + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + + return std::make_unique(std::move(out_array), out_layout); +} + +/** + * @brief Contstruct a strided matrix from any mdarray or mdspan. + * + * A variant `make_strided_dataset` that allows specifying the byte alignment instead of the + * explicit stride length. + * + * @tparam SrcT the source mdarray or mdspan + * + * @param[in] res raft resources handle + * @param[in] src the source mdarray or mdspan + * @param[in] align_bytes the required byte alignment for the dataset rows. + * @return maybe owning current-device-accessible strided matrix + */ +template +auto make_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t align_bytes = 16) + -> std::unique_ptr> +{ + using value_type = typename SrcT::value_type; + constexpr size_t kSize = sizeof(value_type); + uint32_t required_stride = + raft::round_up_safe(src.extent(1) * kSize, std::lcm(align_bytes, kSize)) / kSize; + return make_strided_dataset(res, src, required_stride); +} + +/** Parameters for VPQ compression. */ +struct vpq_params { + /** + * The bit length of the vector element after compression by PQ. + * + * Possible values: [4, 5, 6, 7, 8]. + * + * Hint: the smaller the 'pq_bits', the smaller the index size and the better the search + * performance, but the lower the recall. + */ + uint32_t pq_bits = 8; + /** + * The dimensionality of the vector after compression by PQ. + * When zero, an optimal value is selected using a heuristic. + * + * TODO: at the moment `dim` must be a multiple `pq_dim`. + */ + uint32_t pq_dim = 0; + /** + * Vector Quantization (VQ) codebook size - number of "coarse cluster centers". + * When zero, an optimal value is selected using a heuristic. + */ + uint32_t vq_n_centers = 0; + /** The number of iterations searching for kmeans centers (both VQ & PQ phases). */ + uint32_t kmeans_n_iters = 25; + /** + * The fraction of data to use during iterative kmeans building (VQ phase). + * When zero, an optimal value is selected using a heuristic. + */ + double vq_kmeans_trainset_fraction = 0; + /** + * The fraction of data to use during iterative kmeans building (PQ phase). + * When zero, an optimal value is selected using a heuristic. + */ + double pq_kmeans_trainset_fraction = 0; +}; + +/** + * @brief VPQ compressed dataset. + * + * The dataset is compressed using two level quantization + * + * 1. Vector Quantization + * 2. Product Quantization of residuals + * + * @tparam MathT the type of elements in the codebooks + * @tparam IdxT type of the vector indices (represent dataset.extent(0)) + * + */ +template +struct vpq_dataset : public dataset { + /** Vector Quantization codebook - "coarse cluster centers". */ + device_matrix vq_code_book; + /** Product Quantization codebook - "fine cluster centers". */ + device_matrix pq_code_book; + /** Compressed dataset. */ + device_matrix data; + + vpq_dataset(device_matrix&& vq_code_book, + device_matrix&& pq_code_book, + device_matrix&& data) + : vq_code_book{std::move(vq_code_book)}, + pq_code_book{std::move(pq_code_book)}, + data{std::move(data)} + { + } + + [[nodiscard]] auto n_rows() const noexcept -> IdxT final { return data.extent(0); } + [[nodiscard]] auto dim() const noexcept -> uint32_t final { return vq_code_book.extent(1); } + [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } + + /** Row length of the encoded data in bytes. */ + [[nodiscard]] constexpr inline auto encoded_row_length() const noexcept -> uint32_t + { + return data.extent(1); + } + /** The number of "coarse cluster centers" */ + [[nodiscard]] constexpr inline auto vq_n_centers() const noexcept -> uint32_t + { + return vq_code_book.extent(0); + } + /** The bit length of an encoded vector element after compression by PQ. */ + [[nodiscard]] constexpr inline auto pq_bits() const noexcept -> uint32_t + { + /* + NOTE: pq_bits and the book size + + Normally, we'd store `pq_bits` as a part of the index. + However, we know there's an invariant `pq_n_centers = 1 << pq_bits`, i.e. the codebook size is + the same as the number of possible code values. Hence, we don't store the pq_bits and derive it + from the array dimensions instead. + */ + auto pq_width = pq_n_centers(); +#ifdef __cpp_lib_bitops + return std::countr_zero(pq_width); +#else + uint32_t pq_bits = 0; + while (pq_width > 1) { + pq_bits++; + pq_width >>= 1; + } + return pq_bits; +#endif + } + /** The dimensionality of an encoded vector after compression by PQ. */ + [[nodiscard]] constexpr inline auto pq_dim() const noexcept -> uint32_t + { + return raft::div_rounding_up_unsafe(dim(), pq_len()); + } + /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ + [[nodiscard]] constexpr inline auto pq_len() const noexcept -> uint32_t + { + return pq_code_book.extent(1); + } + /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ + [[nodiscard]] constexpr inline auto pq_n_centers() const noexcept -> uint32_t + { + return pq_code_book.extent(0); + } +}; + +} // namespace raft::neighbors diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 08cc2beaeb..d91e45257e 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -16,6 +16,7 @@ #pragma once #include "../../cagra_types.hpp" +#include "../../vpq_dataset.cuh" #include "graph_core.cuh" #include @@ -344,6 +345,15 @@ index build( RAFT_LOG_INFO("Graph optimized, creating index"); // Construct an index from dataset and optimized knn graph. if (construct_index_with_dataset) { + if (params.compression.has_value()) { + index idx(res, params.metric); + idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view())); + idx.update_dataset( + res, + // TODO: hardcoding codebook math to `half`, we can do runtime dispatching later + neighbors::vpq_build(res, *params.compression, dataset)); + return idx; + } return index(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view())); } else { // We just add the graph. User is expected to update dataset separately. This branch is used diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index d7bd27222b..600c8785e0 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -31,7 +32,7 @@ namespace raft::neighbors::cagra::detail { -constexpr int serialization_version = 3; +constexpr int serialization_version = 4; /** * Save the index to file. @@ -65,26 +66,14 @@ void serialize(raft::resources const& res, serialize_scalar(res, os, index_.metric()); serialize_mdspan(res, os, index_.graph()); - include_dataset &= (index_.dataset().extent(0) > 0); + include_dataset &= (index_.data().n_rows() > 0); serialize_scalar(res, os, include_dataset); if (include_dataset) { RAFT_LOG_INFO("Saving CAGRA index with dataset"); - auto dataset = index_.dataset(); - // Remove padding before saving the dataset - auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), - sizeof(T) * host_dataset.extent(1), - dataset.data_handle(), - sizeof(T) * dataset.stride(0), - sizeof(T) * host_dataset.extent(1), - dataset.extent(0), - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - resource::sync_stream(res); - serialize_mdspan(res, os, host_dataset.view()); + neighbors::detail::serialize(res, os, index_.data()); } else { - RAFT_LOG_INFO("Saving CAGRA index WITHOUT dataset"); + RAFT_LOG_DEBUG("Saving CAGRA index WITHOUT dataset"); } } @@ -256,19 +245,13 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index auto graph = raft::make_host_matrix(n_rows, graph_degree); deserialize_mdspan(res, is, graph.view()); + index idx(res, metric); + idx.update_graph(res, raft::make_const_mdspan(graph.view())); bool has_dataset = deserialize_scalar(res, is); if (has_dataset) { - auto dataset = raft::make_host_matrix(n_rows, dim); - deserialize_mdspan(res, is, dataset.view()); - return index( - res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view())); - } else { - // create a new index with no dataset - the user must supply via update_dataset themselves - // later (this avoids allocating GPU memory in the meantime) - index idx(res, metric); - idx.update_graph(res, raft::make_const_mdspan(graph.view())); - return idx; + idx.update_dataset(res, neighbors::detail::deserialize_dataset(res, is)); } + return idx; } template diff --git a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp new file mode 100644 index 0000000000..a6a6ae59a5 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../dataset.hpp" + +#include +#include +#include + +#include + +#include +#include + +namespace raft::neighbors::detail { + +using dataset_instance_tag = uint32_t; +constexpr dataset_instance_tag kSerializeEmptyDataset = 1; +constexpr dataset_instance_tag kSerializeStridedDataset = 2; +constexpr dataset_instance_tag kSerializeVPQDataset = 3; + +template +void serialize(const raft::resources& res, std::ostream& os, const empty_dataset& dataset) +{ + serialize_scalar(res, os, dataset.suggested_dim); +} + +template +void serialize(const raft::resources& res, + std::ostream& os, + const strided_dataset& dataset) +{ + auto n_rows = dataset.n_rows(); + auto dim = dataset.dim(); + auto stride = dataset.stride(); + serialize_scalar(res, os, n_rows); + serialize_scalar(res, os, dim); + serialize_scalar(res, os, stride); + // Remove padding before saving the dataset + auto src = dataset.view(); + auto dst = make_host_matrix(n_rows, dim); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(), + sizeof(DataT) * dim, + src.data_handle(), + sizeof(DataT) * stride, + sizeof(DataT) * dim, + n_rows, + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + resource::sync_stream(res); + serialize_mdspan(res, os, dst.view()); +} + +template +void serialize(const raft::resources& res, + std::ostream& os, + const vpq_dataset& dataset) +{ + serialize_scalar(res, os, dataset.n_rows()); + serialize_scalar(res, os, dataset.dim()); + serialize_scalar(res, os, dataset.vq_n_centers()); + serialize_scalar(res, os, dataset.pq_n_centers()); + serialize_scalar(res, os, dataset.pq_len()); + serialize_scalar(res, os, dataset.encoded_row_length()); + serialize_mdspan(res, os, make_const_mdspan(dataset.vq_code_book.view())); + serialize_mdspan(res, os, make_const_mdspan(dataset.pq_code_book.view())); + serialize_mdspan(res, os, make_const_mdspan(dataset.data.view())); +} + +template +void serialize(const raft::resources& res, std::ostream& os, const dataset& dataset) +{ + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeEmptyDataset); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_32F); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_16F); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_8I); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_8U); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeVPQDataset); + serialize_scalar(res, os, CUDA_R_32F); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeVPQDataset); + serialize_scalar(res, os, CUDA_R_16F); + return serialize(res, os, *x); + } + RAFT_FAIL("unsupported dataset type."); +} + +template +auto deserialize_empty(raft::resources const& res, std::istream& is) + -> std::unique_ptr> +{ + auto suggested_dim = deserialize_scalar(res, is); + return std::make_unique>(suggested_dim); +} + +template +auto deserialize_strided(raft::resources const& res, std::istream& is) + -> std::unique_ptr> +{ + auto n_rows = deserialize_scalar(res, is); + auto dim = deserialize_scalar(res, is); + auto stride = deserialize_scalar(res, is); + auto host_array = make_host_matrix(n_rows, dim); + deserialize_mdspan(res, is, host_array.view()); + return make_strided_dataset(res, host_array, stride); +} + +template +auto deserialize_vpq(raft::resources const& res, std::istream& is) + -> std::unique_ptr> +{ + auto n_rows = deserialize_scalar(res, is); + auto dim = deserialize_scalar(res, is); + auto vq_n_centers = deserialize_scalar(res, is); + auto pq_n_centers = deserialize_scalar(res, is); + auto pq_len = deserialize_scalar(res, is); + auto encoded_row_length = deserialize_scalar(res, is); + + auto vq_code_book = make_device_matrix(res, vq_n_centers, dim); + auto pq_code_book = make_device_matrix(res, pq_n_centers, pq_len); + auto data = make_device_matrix(res, n_rows, encoded_row_length); + + deserialize_mdspan(res, is, vq_code_book.view()); + deserialize_mdspan(res, is, pq_code_book.view()); + deserialize_mdspan(res, is, data.view()); + + return std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book), std::move(data)); +} + +template +auto deserialize_dataset(raft::resources const& res, std::istream& is) + -> std::unique_ptr> +{ + switch (deserialize_scalar(res, is)) { + case kSerializeEmptyDataset: return deserialize_empty(res, is); + case kSerializeStridedDataset: + switch (deserialize_scalar(res, is)) { + case CUDA_R_32F: return deserialize_strided(res, is); + case CUDA_R_16F: return deserialize_strided(res, is); + case CUDA_R_8I: return deserialize_strided(res, is); + case CUDA_R_8U: return deserialize_strided(res, is); + default: break; + } + case kSerializeVPQDataset: + switch (deserialize_scalar(res, is)) { + case CUDA_R_32F: return deserialize_vpq(res, is); + case CUDA_R_16F: return deserialize_vpq(res, is); + default: break; + } + default: break; + } + RAFT_FAIL("Failed to deserialize dataset: unsupported combination of instance tags."); +} + +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/vpq_dataset.cuh b/cpp/include/raft/neighbors/detail/vpq_dataset.cuh new file mode 100644 index 0000000000..f6cd2a1ceb --- /dev/null +++ b/cpp/include/raft/neighbors/detail/vpq_dataset.cuh @@ -0,0 +1,427 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../dataset.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include // pq_bits-bitfield +#include // utils::mapping etc +#include +#include + +// A temporary stub till https://github.com/rapidsai/raft/pull/2077 is re-merged +namespace raft::util { + +/** + * Subsample the dataset to create a training set. + * + * @tparam DatasetT a row-major mdspan or mdarray (device or host) + * + * @param res raft handle + * @param dataset input row-major mdspan or mdarray (device or host) + * @param n_samples the size of the output mdarray + * + * @return a newly allocated subset of the dataset. + */ +template +auto subsample(raft::resources const& res, + const DatasetT& dataset, + typename DatasetT::index_type n_samples) + -> raft::device_matrix +{ + using value_type = typename DatasetT::value_type; + using index_type = typename DatasetT::index_type; + static_assert(std::is_same_v, + "Only row-major layout is supported at the moment"); + RAFT_EXPECTS(n_samples <= dataset.extent(0), + "The number of samples must be smaller than the number of input rows in the current " + "implementation."); + size_t dim = dataset.extent(1); + size_t trainset_ratio = dataset.extent(0) / n_samples; + auto result = raft::make_device_matrix(res, n_samples, dataset.extent(1)); + + RAFT_CUDA_TRY(cudaMemcpy2DAsync(result.data_handle(), + sizeof(value_type) * dim, + dataset.data_handle(), + sizeof(value_type) * dim * trainset_ratio, + sizeof(value_type) * dim, + n_samples, + cudaMemcpyDefault, + raft::resource::get_cuda_stream(res))); + return result; +} + +} // namespace raft::util + +namespace raft::neighbors::detail { + +template +auto fill_missing_params_heuristics(const vpq_params& params, const DatasetT& dataset) -> vpq_params +{ + vpq_params r = params; + double n_rows = dataset.extent(0); + size_t dim = dataset.extent(1); + if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{4}); } + if (r.pq_bits == 0) { r.pq_bits = 8; } + if (r.vq_n_centers == 0) { r.vq_n_centers = raft::round_up_safe(std::sqrt(n_rows), 8); } + if (r.vq_kmeans_trainset_fraction == 0) { + double vq_trainset_size = 100.0 * r.vq_n_centers; + r.vq_kmeans_trainset_fraction = std::min(1.0, vq_trainset_size / n_rows); + } + if (r.pq_kmeans_trainset_fraction == 0) { + // NB: we'll have actually `pq_dim` times more samples than this + // (because the dataset is reinterpreted as `[n_rows * pq_dim, pq_len]`) + double pq_trainset_size = 1000.0 * (1u << r.pq_bits); + r.pq_kmeans_trainset_fraction = std::min(1.0, pq_trainset_size / n_rows); + } + return r; +} + +template +auto transform_data(const raft::resources& res, DatasetT dataset) + -> device_mdarray +{ + using index_type = typename DatasetT::index_type; + using extents_type = typename DatasetT::extents_type; + using layout_type = typename DatasetT::layout_type; + using out_mdarray_type = device_mdarray; + if constexpr (std::is_same_v>) { return dataset; } + + auto result = raft::make_device_mdarray(res, dataset.extents()); + + linalg::map(res, + result.view(), + spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(dataset.view())); + + return result; +} + +/** Fix the internal indexing type to avoid integer underflows/overflows */ +using ix_t = int64_t; + +template +auto train_vq(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) + -> device_matrix +{ + const ix_t n_rows = dataset.extent(0); + const ix_t vq_n_centers = params.vq_n_centers; + const ix_t dim = dataset.extent(1); + const ix_t n_rows_train = n_rows * params.vq_kmeans_trainset_fraction; + + // Subsample the dataset and transform into the required type if necessary + auto vq_trainset = raft::util::subsample(res, dataset, n_rows_train); + auto vq_centers = raft::make_device_matrix(res, vq_n_centers, dim); + + using kmeans_in_type = typename DatasetT::value_type; + raft::cluster::kmeans_balanced_params kmeans_params; + kmeans_params.n_iters = params.kmeans_n_iters; + kmeans_params.metric = raft::distance::DistanceType::L2Expanded; + auto vq_centers_view = + raft::make_device_matrix_view(vq_centers.data_handle(), vq_n_centers, dim); + auto vq_trainset_view = raft::make_device_matrix_view( + vq_trainset.data_handle(), n_rows_train, dim); + raft::cluster::kmeans_balanced::fit( + res, + kmeans_params, + vq_trainset_view, + vq_centers_view, + spatial::knn::detail::utils::mapping{}); + + return vq_centers; +} + +template +auto predict_vq(const raft::resources& res, const DatasetT& dataset, const VqCentersT& vq_centers) + -> device_vector +{ + using kmeans_data_type = typename DatasetT::value_type; + using kmeans_math_type = typename VqCentersT::value_type; + using index_type = typename DatasetT::index_type; + using label_type = LabelT; + + auto vq_labels = raft::make_device_vector(res, dataset.extent(0)); + + raft::cluster::kmeans_balanced_params kmeans_params; + kmeans_params.metric = raft::distance::DistanceType::L2Expanded; + + auto vq_centers_view = raft::make_device_matrix_view( + vq_centers.data_handle(), vq_centers.extent(0), vq_centers.extent(1)); + + auto vq_dataset_view = raft::make_device_matrix_view( + dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + + raft::cluster::kmeans_balanced:: + predict( + res, + kmeans_params, + vq_dataset_view, + vq_centers_view, + vq_labels.view(), + spatial::knn::detail::utils::mapping{}); + + return vq_labels; +} + +template +auto train_pq(const raft::resources& res, + const vpq_params& params, + const DatasetT& dataset, + const device_matrix_view& vq_centers) + -> device_matrix +{ + const ix_t n_rows = dataset.extent(0); + const ix_t dim = dataset.extent(1); + const ix_t pq_dim = params.pq_dim; + const ix_t pq_bits = params.pq_bits; + const ix_t pq_n_centers = ix_t{1} << pq_bits; + const ix_t pq_len = raft::div_rounding_up_safe(dim, pq_dim); + const ix_t n_rows_train = n_rows * params.pq_kmeans_trainset_fraction; + + // Subsample the dataset and transform into the required type if necessary + auto pq_trainset = transform_data(res, raft::util::subsample(res, dataset, n_rows_train)); + + // Subtract VQ centers + { + auto vq_labels = predict_vq(res, pq_trainset, vq_centers); + using index_type = typename DatasetT::index_type; + linalg::map_offset( + res, + pq_trainset.view(), + [labels = vq_labels.view(), centers = vq_centers, dim] __device__(index_type off, MathT x) { + index_type i = off / dim; + index_type j = off % dim; + return x - centers(labels(i), j); + }, + raft::make_const_mdspan(pq_trainset.view())); + } + + auto pq_centers = raft::make_device_matrix(res, pq_n_centers, pq_len); + + // Train PQ centers + { + raft::cluster::kmeans_balanced_params kmeans_params; + kmeans_params.n_iters = params.kmeans_n_iters; + kmeans_params.metric = raft::distance::DistanceType::L2Expanded; + + auto pq_centers_view = + raft::make_device_matrix_view(pq_centers.data_handle(), pq_n_centers, pq_len); + + auto pq_trainset_view = raft::make_device_matrix_view( + pq_trainset.data_handle(), n_rows_train * pq_dim, pq_len); + + raft::cluster::kmeans_balanced::fit( + res, kmeans_params, pq_trainset_view, pq_centers_view); + } + + return pq_centers; +} + +template +__device__ auto compute_code(device_matrix_view dataset, + device_matrix_view vq_centers, + device_matrix_view pq_centers, + IdxT i, + uint32_t j, + LabelT vq_label) -> uint8_t +{ + auto data_mapping = spatial::knn::detail::utils::mapping{}; + uint32_t lane_id = Pow2::mod(laneId()); + + const uint32_t pq_book_size = pq_centers.extent(0); + const uint32_t pq_len = pq_centers.extent(1); + float min_dist = std::numeric_limits::infinity(); + uint8_t code = 0; + // calculate the distance for each PQ cluster, find the minimum for each thread + for (uint32_t l = lane_id; l < pq_book_size; l += SubWarpSize) { + // NB: the L2 quantifiers on residuals are always trained on L2 metric. + float d = 0.0f; + for (uint32_t k = 0; k < pq_len; k++) { + auto jk = j * pq_len + k; + auto x = data_mapping(dataset(i, jk)) - vq_centers(vq_label, jk); + auto t = x - pq_centers(l, k); + d += t * t; + } + if (d < min_dist) { + min_dist = d; + code = uint8_t(l); + } + } + // reduce among threads +#pragma unroll + for (uint32_t stride = SubWarpSize >> 1; stride > 0; stride >>= 1) { + const auto other_dist = shfl_xor(min_dist, stride, SubWarpSize); + const auto other_code = shfl_xor(code, stride, SubWarpSize); + if (other_dist < min_dist) { + min_dist = other_dist; + code = other_code; + } + } + return code; +} + +template +__launch_bounds__(BlockSize) RAFT_KERNEL + process_and_fill_codes_kernel(device_matrix_view out_codes, + device_matrix_view dataset, + device_matrix_view vq_centers, + device_vector_view vq_labels, + device_matrix_view pq_centers) +{ + constexpr uint32_t kSubWarpSize = std::min(WarpSize, 1u << PqBits); + using subwarp_align = Pow2; + const IdxT row_ix = subwarp_align::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x}); + if (row_ix >= out_codes.extent(0)) { return; } + + const uint32_t pq_dim = raft::div_rounding_up_unsafe(vq_centers.extent(1), pq_centers.extent(1)); + + const uint32_t lane_id = Pow2::mod(threadIdx.x); + const LabelT vq_label = vq_labels(row_ix); + + // write label + auto* out_label_ptr = reinterpret_cast(&out_codes(row_ix, 0)); + if (lane_id == 0) { *out_label_ptr = vq_label; } + + auto* out_codes_ptr = reinterpret_cast(out_label_ptr + 1); + ivf_pq::detail::bitfield_view_t code_view{out_codes_ptr}; + for (uint32_t j = 0; j < pq_dim; j++) { + // find PQ label + uint8_t code = compute_code(dataset, vq_centers, pq_centers, row_ix, j, vq_label); + // TODO: this writes in global memory one byte per warp, which is very slow. + // It's better to keep the codes in the shared memory or registers and dump them at once. + if (lane_id == 0) { code_view[j] = code; } + } +} + +template +auto process_and_fill_codes(const raft::resources& res, + const vpq_params& params, + const DatasetT& dataset, + device_matrix_view vq_centers, + device_matrix_view pq_centers) + -> device_matrix +{ + using data_t = typename DatasetT::value_type; + using cdataset_t = vpq_dataset; + using label_t = uint32_t; + + const ix_t n_rows = dataset.extent(0); + const ix_t dim = dataset.extent(1); + const ix_t pq_dim = params.pq_dim; + const ix_t pq_bits = params.pq_bits; + const ix_t pq_n_centers = ix_t{1} << pq_bits; + // NB: codes must be aligned at least to sizeof(label_t) to be able to read labels. + const ix_t codes_rowlen = + sizeof(label_t) * (1 + raft::div_rounding_up_safe(pq_dim * pq_bits, 8 * sizeof(label_t))); + + auto codes = raft::make_device_matrix(res, n_rows, codes_rowlen); + + auto stream = raft::resource::get_cuda_stream(res); + + // TODO: with scaling workspace we could choose the batch size dynamically + constexpr ix_t kReasonableMaxBatchSize = 65536; + constexpr ix_t kBlockSize = 256; + const ix_t threads_per_vec = std::min(WarpSize, pq_n_centers); + dim3 threads(kBlockSize, 1, 1); + ix_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); + auto kernel = [](uint32_t pq_bits) { + switch (pq_bits) { + case 4: return process_and_fill_codes_kernel; + case 5: return process_and_fill_codes_kernel; + case 6: return process_and_fill_codes_kernel; + case 7: return process_and_fill_codes_kernel; + case 8: return process_and_fill_codes_kernel; + default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); + } + }(pq_bits); + for (const auto& batch : + spatial::knn::detail::utils::batch_load_iterator(dataset.data_handle(), + n_rows, + dim, + max_batch_size, + stream, + rmm::mr::get_current_device_resource())) { + auto batch_view = raft::make_device_matrix_view(batch.data(), ix_t(batch.size()), dim); + auto labels = predict_vq(res, batch_view, vq_centers); + dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize / threads_per_vec), 1, 1); + kernel<<>>( + make_device_matrix_view( + codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen), + batch_view, + vq_centers, + make_const_mdspan(labels.view()), + pq_centers); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + return codes; +} + +template +auto vpq_convert_math_type(const raft::resources& res, vpq_dataset&& src) + -> vpq_dataset +{ + auto vq_code_book = make_device_mdarray(res, src.vq_code_book.extents()); + auto pq_code_book = make_device_mdarray(res, src.pq_code_book.extents()); + + linalg::map(res, + vq_code_book.view(), + spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(src.vq_code_book.view())); + linalg::map(res, + pq_code_book.view(), + spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(src.pq_code_book.view())); + return vpq_dataset{ + std::move(vq_code_book), std::move(pq_code_book), std::move(src.data)}; +} + +template +auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) + -> vpq_dataset +{ + // Use a heuristic to impute missing parameters. + auto ps = fill_missing_params_heuristics(params, dataset); + + // Train codes + auto vq_code_book = train_vq(res, ps, dataset); + auto pq_code_book = + train_pq(res, ps, dataset, raft::make_const_mdspan(vq_code_book.view())); + + // Encode dataset + auto codes = process_and_fill_codes(res, + ps, + dataset, + raft::make_const_mdspan(vq_code_book.view()), + raft::make_const_mdspan(pq_code_book.view())); + + return vpq_dataset{ + std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; +} + +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/vpq_dataset.cuh b/cpp/include/raft/neighbors/vpq_dataset.cuh new file mode 100644 index 0000000000..73ee6c52ed --- /dev/null +++ b/cpp/include/raft/neighbors/vpq_dataset.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "dataset.hpp" +#include "detail/vpq_dataset.cuh" + +#include + +namespace raft::neighbors { + +/** + * @brief Compress a dataset for use in CAGRA-Q search in place of the original data. + * + * @tparam DatasetT a row-major mdspan or mdarray (device or host). + * @tparam MathT a type of the codebook elements and internal math ops. + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res + * @param[in] params VQ and PQ parameters for compressing the data + * @param[in] dataset a row-major mdspan or mdarray (device or host) [n_rows, dim]. + */ +template +auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) + -> vpq_dataset +{ + if constexpr (std::is_same_v) { + return detail::vpq_convert_math_type( + res, detail::vpq_build(res, params, dataset)); + } else { + return detail::vpq_build(res, params, dataset); + } +} + +} // namespace raft::neighbors