From ac6b088d258c873c22407782770dca1d14613d92 Mon Sep 17 00:00:00 2001 From: achirkin Date: Tue, 5 Mar 2024 22:33:48 +0100 Subject: [PATCH 01/18] Add CAGRA-Q build (compression) --- .../neighbors/detail/cagra/cagra_build_q.cuh | 520 ++++++++++++++++++ 1 file changed, 520 insertions(+) create mode 100644 cpp/include/raft/neighbors/detail/cagra/cagra_build_q.cuh diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build_q.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build_q.cuh new file mode 100644 index 0000000000..36345ec1d7 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build_q.cuh @@ -0,0 +1,520 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../../cagra_types.hpp" + +// reuse helper code +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __cpp_lib_bitops +#include +#endif + +// A temporary stub till https://github.com/rapidsai/raft/pull/2077 is re-merged +namespace raft::util { + +/** + * Subsample the dataset to create a training set. + * + * @tparam DatasetT a row-major mdspan or mdarray (device or host) + * + * @param res raft handle + * @param dataset input row-major mdspan or mdarray (device or host) + * @param n_samples the size of the output mdarray + * + * @return a newly allocated subset of the dataset. + */ +template +auto subsample(raft::resources const& res, + const DatasetT& dataset, + typename DatasetT::index_type n_samples) + -> raft::device_matrix +{ + using value_type = typename DatasetT::value_type; + using index_type = typename DatasetT::index_type; + static_assert(std::is_same_v, + "Only row-major layout is supported at the moment"); + RAFT_EXPECTS(n_samples <= dataset.extent(0), + "The number of samples must be smaller than the number of input rows in the current " + "implementation."); + size_t dim = dataset.extent(1); + size_t trainset_ratio = dataset.extent(0) / n_samples; + auto result = raft::make_device_matrix(res, n_samples, dataset.extent(1)); + + RAFT_CUDA_TRY(cudaMemcpy2DAsync(result.data_handle(), + sizeof(value_type) * dim, + dataset.data_handle(), + sizeof(value_type) * dim * trainset_ratio, + sizeof(value_type) * dim, + n_samples, + cudaMemcpyDefault, + raft::resource::get_cuda_stream(res))); + return result; +} + +} // namespace raft::util + +namespace raft::neighbors::cagra::detail { + +struct compression_params { + /** + * The bit length of the vector element after compression by PQ. + * + * Possible values: [4, 5, 6, 7, 8]. + * + * Hint: the smaller the 'pq_bits', the smaller the index size and the better the search + * performance, but the lower the recall. + */ + uint32_t pq_bits = 8; + /** + * The dimensionality of the vector after compression by PQ. + * When zero, an optimal value is selected using a heuristic. + * + * NB: `pq_dim * pq_bits` must be a multiple of 8. + * + * Hint: a smaller 'pq_dim' results in a smaller index size and better search performance, but + * lower recall. If 'pq_bits' is 8, 'pq_dim' can be set to any number, but multiple of 8 are + * desirable for good performance. If 'pq_bits' is not 8, 'pq_dim' should be a multiple of 8. + * For good performance, it is desirable that 'pq_dim' is a multiple of 32. Ideally, 'pq_dim' + * should be also a divisor of the dataset dim. + */ + uint32_t pq_dim = 0; + /** + * Vector Quantization (VQ) codebook size - number of "coarse cluster centers". + * When zero, an optimal value is selected using a heuristic. + */ + uint32_t vq_n_centers = 0; + /** The number of iterations searching for kmeans centers (both VQ & PQ phases). */ + uint32_t kmeans_n_iters = 25; + /** + * The fraction of data to use during iterative kmeans building (VQ phase). + * When zero, an optimal value is selected using a heuristic. + */ + double vq_kmeans_trainset_fraction = 0; + /** + * The fraction of data to use during iterative kmeans building (PQ phase). + * When zero, an optimal value is selected using a heuristic. + */ + double pq_kmeans_trainset_fraction = 0; +}; + +template +auto fill_missing_params_heuristics(const compression_params& params, const DatasetT& dataset) + -> compression_params +{ + compression_params r = params; + double n_rows = dataset.extent(0); + size_t dim = dataset.extent(1); + if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{4}); } + if (r.pq_bits == 0) { r.pq_bits = 8; } + if (r.vq_n_centers == 0) { r.vq_n_centers = raft::round_up_safe(std::sqrt(n_rows), 8); } + if (r.vq_kmeans_trainset_fraction == 0) { + double vq_trainset_size = 100.0 * r.vq_n_centers; + r.vq_kmeans_trainset_fraction = std::min(1.0, vq_trainset_size / n_rows); + } + if (r.pq_kmeans_trainset_fraction == 0) { + // NB: we'll have actually `pq_dim` times more samples than this + // (because the dataset is reinterpreted as `[n_rows * pq_dim, pq_len]`) + double pq_trainset_size = 1000.0 * (1u << r.pq_bits); + r.pq_kmeans_trainset_fraction = std::min(1.0, pq_trainset_size / n_rows); + } + return r; +} + +template +struct compressed_dataset { + using raw_codes_type = uint8_t; + /** Vector Quantization codebook - "coarse cluster centers". */ + device_matrix vq_code_book; + /** Product Quantization codebook - "fine cluster centers". */ + device_matrix pq_code_book; + /** Compressed dataset. */ + device_matrix dataset; + + /** Total length of the index. */ + [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return dataset.extent(0); } + /** Dimensionality of the data. */ + [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t + { + return vq_code_book.extent(1); + } + /** The number of "coarse cluster centers" */ + [[nodiscard]] constexpr inline auto vq_n_centers() const noexcept -> uint32_t + { + return vq_code_book.extent(0); + } + /** The bit length of an encoded vector element after compression by PQ. */ + [[nodiscard]] constexpr inline auto pq_bits() const noexcept -> uint32_t + { + auto pq_width = pq_n_centers(); +#ifdef __cpp_lib_bitops + return std::countr_zero(pq_width); +#else + uint32_t pq_bits = 0; + while (pq_width > 1) { + pq_bits++; + pq_width >>= 1; + } + return pq_bits; +#endif + } + /** The dimensionality of an encoded vector after compression by PQ. */ + [[nodiscard]] constexpr inline auto pq_dim() const noexcept -> uint32_t + { + return raft::div_rounding_up_unsafe(dim(), pq_len()); + } + /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ + [[nodiscard]] constexpr inline auto pq_len() const noexcept -> uint32_t + { + return pq_code_book.extent(1); + } + /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ + [[nodiscard]] constexpr inline auto pq_n_centers() const noexcept -> uint32_t + { + return pq_code_book.extent(0); + } +}; + +/** Fix the internal indexing type to avoid integer underflows/overflows */ +using ix_t = int64_t; + +template +auto transform_data(const raft::resources& res, DatasetT dataset) + -> device_mdarray +{ + using index_type = typename DatasetT::index_type; + using extents_type = typename DatasetT::extents_type; + using layout_type = typename DatasetT::layout_type; + using out_mdarray_type = device_mdarray; + if constexpr (std::is_same_v>) { return dataset; } + + auto result = raft::make_device_mdarray(res, dataset.extents()); + + linalg::map(res, + result.view(), + spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(dataset.view())); + + return result; +} + +template +auto train_vq(const raft::resources& res, const compression_params& params, const DatasetT& dataset) + -> device_matrix +{ + const ix_t n_rows = dataset.extent(0); + const ix_t vq_n_centers = params.vq_n_centers; + const ix_t dim = dataset.extent(1); + const ix_t n_rows_train = n_rows * params.vq_kmeans_trainset_fraction; + + // Subsample the dataset and transform into the required type if necessary + auto vq_trainset = raft::util::subsample(res, dataset, n_rows_train); + auto vq_centers = raft::make_device_matrix(res, vq_n_centers, dim); + + using kmeans_in_type = typename DatasetT::value_type; + raft::cluster::kmeans_balanced_params kmeans_params; + kmeans_params.n_iters = params.kmeans_n_iters; + kmeans_params.metric = raft::distance::DistanceType::L2Expanded; + auto vq_centers_view = + raft::make_device_matrix_view(vq_centers.data_handle(), vq_n_centers, dim); + auto vq_trainset_view = raft::make_device_matrix_view( + vq_trainset.data_handle(), n_rows_train, dim); + raft::cluster::kmeans_balanced::fit( + res, + kmeans_params, + vq_trainset_view, + vq_centers_view, + spatial::knn::detail::utils::mapping{}); + + return vq_centers; +} + +template +auto predict_vq(const raft::resources& res, const DatasetT& dataset, const VqCentersT& vq_centers) + -> device_vector +{ + using kmeans_data_type = typename DatasetT::value_type; + using kmeans_math_type = typename VqCentersT::value_type; + using index_type = typename DatasetT::index_type; + using label_type = LabelT; + + auto vq_labels = raft::make_device_vector(res, dataset.extent(0)); + + raft::cluster::kmeans_balanced_params kmeans_params; + kmeans_params.metric = raft::distance::DistanceType::L2Expanded; + + auto vq_centers_view = raft::make_device_matrix_view( + vq_centers.data_handle(), vq_centers.extent(0), vq_centers.extent(1)); + + auto vq_dataset_view = raft::make_device_matrix_view( + dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + + raft::cluster::kmeans_balanced:: + predict( + res, + kmeans_params, + vq_dataset_view, + vq_centers_view, + vq_labels.view(), + spatial::knn::detail::utils::mapping{}); + + return vq_labels; +} + +template +auto train_pq(const raft::resources& res, + const compression_params& params, + const DatasetT& dataset, + const device_matrix_view& vq_centers) + -> device_matrix +{ + const ix_t n_rows = dataset.extent(0); + const ix_t dim = dataset.extent(1); + const ix_t pq_dim = params.pq_dim; + const ix_t pq_bits = params.pq_bits; + const ix_t pq_n_centers = ix_t{1} << pq_bits; + const ix_t pq_len = raft::div_rounding_up_safe(dim, pq_dim); + const ix_t n_rows_train = n_rows * params.pq_kmeans_trainset_fraction; + + // Subsample the dataset and transform into the required type if necessary + auto pq_trainset = transform_data(res, raft::util::subsample(res, dataset, n_rows_train)); + + // Subtract VQ centers + { + auto vq_labels = predict_vq(res, pq_trainset, vq_centers); + using index_type = typename DatasetT::index_type; + linalg::map_offset( + res, + pq_trainset.view(), + [labels = vq_labels.view(), centers = vq_centers, dim] __device__(index_type off, DataT x) { + index_type i = off / dim; + index_type j = off % dim; + return x - centers(labels(i), j); + }, + raft::make_const_mdspan(pq_trainset.view())); + } + + auto pq_centers = raft::make_device_matrix(res, pq_n_centers, pq_len); + + // Train PQ centers + { + raft::cluster::kmeans_balanced_params kmeans_params; + kmeans_params.n_iters = params.kmeans_n_iters; + kmeans_params.metric = raft::distance::DistanceType::L2Expanded; + + auto pq_centers_view = + raft::make_device_matrix_view(pq_centers.data_handle(), pq_n_centers, pq_len); + + auto pq_trainset_view = raft::make_device_matrix_view( + pq_trainset.data_handle(), n_rows_train * pq_dim, pq_len); + + raft::cluster::kmeans_balanced::fit( + res, kmeans_params, pq_trainset_view, pq_centers_view); + } + + return pq_centers; +} + +template +__device__ auto compute_code(device_matrix_view dataset, + device_matrix_view vq_centers, + device_matrix_view pq_centers, + IdxT i, + uint32_t j, + LabelT vq_label) -> uint8_t +{ + auto data_mapping = spatial::knn::detail::utils::mapping{}; + uint32_t lane_id = Pow2::mod(laneId()); + + const uint32_t pq_book_size = pq_centers.extent(0); + const uint32_t pq_len = pq_centers.extent(1); + float min_dist = std::numeric_limits::infinity(); + uint8_t code = 0; + // calculate the distance for each PQ cluster, find the minimum for each thread + for (uint32_t l = lane_id; l < pq_book_size; l += SubWarpSize) { + // NB: the L2 quantifiers on residuals are always trained on L2 metric. + float d = 0.0f; + for (uint32_t k = 0; k < pq_len; k++) { + auto jk = j * pq_len + k; + auto x = data_mapping(dataset(i, jk)) - vq_centers(vq_label, jk); + auto t = x - pq_centers(l, k); + d += t * t; + } + if (d < min_dist) { + min_dist = d; + code = uint8_t(l); + } + } + // reduce among threads +#pragma unroll + for (uint32_t stride = SubWarpSize >> 1; stride > 0; stride >>= 1) { + const auto other_dist = shfl_xor(min_dist, stride, SubWarpSize); + const auto other_code = shfl_xor(code, stride, SubWarpSize); + if (other_dist < min_dist) { + min_dist = other_dist; + code = other_code; + } + } + return code; +} + +template +__launch_bounds__(BlockSize) RAFT_KERNEL + process_and_fill_codes_kernel(device_matrix_view out_codes, + device_matrix_view dataset, + device_matrix_view vq_centers, + device_vector_view vq_labels, + device_matrix_view pq_centers) +{ + constexpr uint32_t kSubWarpSize = std::min(WarpSize, 1u << PqBits); + using subwarp_align = Pow2; + const IdxT row_ix = subwarp_align::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x}); + if (row_ix >= out_codes.extent(0)) { return; } + + const uint32_t pq_dim = raft::div_rounding_up_unsafe(vq_centers.extent(1), pq_centers.extent(1)); + + const uint32_t lane_id = Pow2::mod(threadIdx.x); + const LabelT vq_label = vq_labels(row_ix); + + // write label + auto* out_label_ptr = reinterpret_cast(&out_codes(row_ix, 0)); + if (lane_id == 0) { *out_label_ptr = vq_label; } + + auto* out_codes_ptr = reinterpret_cast(out_label_ptr + 1); + ivf_pq::detail::bitfield_view_t code_view{out_codes_ptr}; + for (uint32_t j = 0; j < pq_dim; j++) { + // find PQ label + uint8_t code = compute_code(dataset, vq_centers, pq_centers, row_ix, j, vq_label); + // TODO: this writes in global memory bytewise, which is very slow. + // It's better to keep the codes in the shared memory or registers and dump them at once. + if (lane_id == 0) { code_view[j] = code; } + } +} + +template +auto process_and_fill_codes(const raft::resources& res, + const compression_params& params, + const DatasetT& dataset, + device_matrix_view vq_centers, + device_matrix_view pq_centers) + -> device_matrix::raw_codes_type, IdxT, row_major> +{ + using data_t = typename DatasetT::value_type; + using cdataset_t = compressed_dataset; + using codes_t = typename cdataset_t::raw_codes_type; + using label_t = uint32_t; + + const ix_t n_rows = dataset.extent(0); + const ix_t dim = dataset.extent(1); + const ix_t pq_dim = params.pq_dim; + const ix_t pq_bits = params.pq_bits; + const ix_t pq_n_centers = ix_t{1} << pq_bits; + const ix_t pq_len = raft::div_rounding_up_safe(dim, pq_dim); + // NB: codes must be aligned at least to sizeof(label_t) to be able to read labels. + const ix_t codes_rowlen = + sizeof(label_t) * (1 + raft::div_rounding_up_safe(pq_dim * pq_bits, 8 * sizeof(label_t))); + + auto codes = raft::make_device_matrix(res, n_rows, codes_rowlen); + + auto stream = raft::resource::get_cuda_stream(res); + + // TODO: with scaling workspace we could choose the batch size dynamically + constexpr ix_t kReasonableMaxBatchSize = 65536; + constexpr ix_t kBlockSize = 256; + const ix_t threads_per_vec = std::min(WarpSize, pq_n_centers); + dim3 threads(kBlockSize, 1, 1); + ix_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); + auto kernel = [](uint32_t pq_bits) { + switch (pq_bits) { + case 4: return process_and_fill_codes_kernel; + case 5: return process_and_fill_codes_kernel; + case 6: return process_and_fill_codes_kernel; + case 7: return process_and_fill_codes_kernel; + case 8: return process_and_fill_codes_kernel; + default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); + } + }(pq_bits); + for (const auto& batch : + spatial::knn::detail::utils::batch_load_iterator(dataset.data_handle(), + n_rows, + dim, + max_batch_size, + stream, + rmm::mr::get_current_device_resource())) { + auto batch_view = raft::make_device_matrix_view(batch.data(), ix_t(batch.size()), dim); + auto labels = predict_vq(res, batch_view, vq_centers); + dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize / threads_per_vec), 1, 1); + kernel<<>>( + make_device_matrix_view( + codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen), + batch_view, + vq_centers, + make_const_mdspan(labels.view()), + pq_centers); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + return codes; +} + +template +auto compress(const raft::resources& res, const compression_params& params, const DatasetT& dataset) + -> compressed_dataset +{ + // Use a heuristic to impute missing parameters. + auto ps = fill_missing_params_heuristics(params, dataset); + + // Relevant constants + const ix_t n_rows = dataset.extent(0); + const ix_t dim = dataset.extent(1); + const ix_t pq_dim = ps.pq_dim; + + // Train codes + auto vq_code_book = train_vq(res, ps, dataset); + auto pq_code_book = + train_pq(res, ps, dataset, raft::make_const_mdspan(vq_code_book.view())); + + // Encode dataset + auto codes = + process_and_fill_codes(res, + ps, + dataset, + raft::make_const_mdspan(vq_code_book.view()), + raft::make_const_mdspan(pq_code_book.view())); + + return compressed_dataset{vq_code_book, pq_code_book, codes}; +} + +} // namespace raft::neighbors::cagra::detail From 8b9bee0859460c5538c69006ad91b2054085ba76 Mon Sep 17 00:00:00 2001 From: achirkin Date: Wed, 6 Mar 2024 08:53:48 +0100 Subject: [PATCH 02/18] Formatting and style refactoring --- cpp/include/raft/neighbors/cagra.cuh | 21 ++ cpp/include/raft/neighbors/cagra_types.hpp | 107 ++++++++++ .../neighbors/detail/cagra/cagra_build_q.cuh | 192 ++++-------------- 3 files changed, 168 insertions(+), 152 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index b8258297e6..19e010e08a 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -17,6 +17,7 @@ #pragma once #include "detail/cagra/cagra_build.cuh" +#include "detail/cagra/cagra_build_q.cuh" #include "detail/cagra/cagra_search.cuh" #include "detail/cagra/graph_core.cuh" @@ -279,6 +280,26 @@ index build(raft::resources const& res, return detail::build(res, params, dataset); } +/** + * @brief Compress a dataset for use in CAGRA-Q search in place of the original data. + * + * @tparam DatasetT a row-major mdspan or mdarray (device or host). + * @tparam MathT a type of the codebook elements and internal math ops. + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res + * @param[in] params VQ and PQ parameters for compressing the data + * @param[in] dataset a row-major mdspan or mdarray (device or host) [n_rows, dim]. + */ +template +auto compress(const raft::resources& res, const compression_params& params, const DatasetT& dataset) + -> compressed_dataset +{ + return detail::compress(res, params, dataset); +} + /** * @brief Search ANN using the constructed index. * diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 0f574ae5bb..aaf6db74fa 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -35,6 +35,11 @@ #include #include #include + +#ifdef __cpp_lib_bitops +#include +#endif + namespace raft::neighbors::cagra { /** * @addtogroup cagra @@ -63,6 +68,43 @@ struct index_params : ann::index_params { size_t nn_descent_niter = 20; }; +/** Parameters for CAGRA-Q compression. */ +struct compression_params { + /** + * The bit length of the vector element after compression by PQ. + * + * Possible values: [4, 5, 6, 7, 8]. + * + * Hint: the smaller the 'pq_bits', the smaller the index size and the better the search + * performance, but the lower the recall. + */ + uint32_t pq_bits = 8; + /** + * The dimensionality of the vector after compression by PQ. + * When zero, an optimal value is selected using a heuristic. + * + * TODO: at the moment `dim` must be a multiple `pq_dim`. + */ + uint32_t pq_dim = 0; + /** + * Vector Quantization (VQ) codebook size - number of "coarse cluster centers". + * When zero, an optimal value is selected using a heuristic. + */ + uint32_t vq_n_centers = 0; + /** The number of iterations searching for kmeans centers (both VQ & PQ phases). */ + uint32_t kmeans_n_iters = 25; + /** + * The fraction of data to use during iterative kmeans building (VQ phase). + * When zero, an optimal value is selected using a heuristic. + */ + double vq_kmeans_trainset_fraction = 0; + /** + * The fraction of data to use during iterative kmeans building (PQ phase). + * When zero, an optimal value is selected using a heuristic. + */ + double pq_kmeans_trainset_fraction = 0; +}; + enum class search_algo { /** For large batch sizes. */ SINGLE_CTA, @@ -356,6 +398,71 @@ struct index : ann::index { raft::device_matrix_view graph_view_; }; +/** + * @brief CAGRA-Q compressed dataset. + * + * @tparam MathT the type of elements in the codebooks + * @tparam IdxT type of the vector indices (represent dataset.extent(0)) + * + */ +template +struct compressed_dataset { + /** Vector Quantization codebook - "coarse cluster centers". */ + device_matrix vq_code_book; + /** Product Quantization codebook - "fine cluster centers". */ + device_matrix pq_code_book; + /** Compressed dataset. */ + device_matrix dataset; + + /** Total length of the index. */ + [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return dataset.extent(0); } + /** Row length of the encoded data in bytes. */ + [[nodiscard]] constexpr inline auto encoded_row_length() const noexcept -> uint32_t + { + return dataset.extent(1); + } + /** Dimensionality of the data. */ + [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t + { + return vq_code_book.extent(1); + } + /** The number of "coarse cluster centers" */ + [[nodiscard]] constexpr inline auto vq_n_centers() const noexcept -> uint32_t + { + return vq_code_book.extent(0); + } + /** The bit length of an encoded vector element after compression by PQ. */ + [[nodiscard]] constexpr inline auto pq_bits() const noexcept -> uint32_t + { + auto pq_width = pq_n_centers(); +#ifdef __cpp_lib_bitops + return std::countr_zero(pq_width); +#else + uint32_t pq_bits = 0; + while (pq_width > 1) { + pq_bits++; + pq_width >>= 1; + } + return pq_bits; +#endif + } + /** The dimensionality of an encoded vector after compression by PQ. */ + [[nodiscard]] constexpr inline auto pq_dim() const noexcept -> uint32_t + { + return raft::div_rounding_up_unsafe(dim(), pq_len()); + } + /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ + [[nodiscard]] constexpr inline auto pq_len() const noexcept -> uint32_t + { + return pq_code_book.extent(1); + } + /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ + [[nodiscard]] constexpr inline auto pq_n_centers() const noexcept -> uint32_t + { + return pq_code_book.extent(0); + } +}; + /** @} */ } // namespace raft::neighbors::cagra diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build_q.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build_q.cuh index 36345ec1d7..b1b501aaef 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build_q.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build_q.cuh @@ -17,10 +17,6 @@ #include "../../cagra_types.hpp" -// reuse helper code -#include -#include - #include #include #include @@ -28,13 +24,11 @@ #include #include #include +#include // pq_bits-bitfield +#include // utils::mapping etc #include #include -#ifdef __cpp_lib_bitops -#include -#endif - // A temporary stub till https://github.com/rapidsai/raft/pull/2077 is re-merged namespace raft::util { @@ -81,48 +75,6 @@ auto subsample(raft::resources const& res, namespace raft::neighbors::cagra::detail { -struct compression_params { - /** - * The bit length of the vector element after compression by PQ. - * - * Possible values: [4, 5, 6, 7, 8]. - * - * Hint: the smaller the 'pq_bits', the smaller the index size and the better the search - * performance, but the lower the recall. - */ - uint32_t pq_bits = 8; - /** - * The dimensionality of the vector after compression by PQ. - * When zero, an optimal value is selected using a heuristic. - * - * NB: `pq_dim * pq_bits` must be a multiple of 8. - * - * Hint: a smaller 'pq_dim' results in a smaller index size and better search performance, but - * lower recall. If 'pq_bits' is 8, 'pq_dim' can be set to any number, but multiple of 8 are - * desirable for good performance. If 'pq_bits' is not 8, 'pq_dim' should be a multiple of 8. - * For good performance, it is desirable that 'pq_dim' is a multiple of 32. Ideally, 'pq_dim' - * should be also a divisor of the dataset dim. - */ - uint32_t pq_dim = 0; - /** - * Vector Quantization (VQ) codebook size - number of "coarse cluster centers". - * When zero, an optimal value is selected using a heuristic. - */ - uint32_t vq_n_centers = 0; - /** The number of iterations searching for kmeans centers (both VQ & PQ phases). */ - uint32_t kmeans_n_iters = 25; - /** - * The fraction of data to use during iterative kmeans building (VQ phase). - * When zero, an optimal value is selected using a heuristic. - */ - double vq_kmeans_trainset_fraction = 0; - /** - * The fraction of data to use during iterative kmeans building (PQ phase). - * When zero, an optimal value is selected using a heuristic. - */ - double pq_kmeans_trainset_fraction = 0; -}; - template auto fill_missing_params_heuristics(const compression_params& params, const DatasetT& dataset) -> compression_params @@ -146,86 +98,32 @@ auto fill_missing_params_heuristics(const compression_params& params, const Data return r; } -template -struct compressed_dataset { - using raw_codes_type = uint8_t; - /** Vector Quantization codebook - "coarse cluster centers". */ - device_matrix vq_code_book; - /** Product Quantization codebook - "fine cluster centers". */ - device_matrix pq_code_book; - /** Compressed dataset. */ - device_matrix dataset; - - /** Total length of the index. */ - [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return dataset.extent(0); } - /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t - { - return vq_code_book.extent(1); - } - /** The number of "coarse cluster centers" */ - [[nodiscard]] constexpr inline auto vq_n_centers() const noexcept -> uint32_t - { - return vq_code_book.extent(0); - } - /** The bit length of an encoded vector element after compression by PQ. */ - [[nodiscard]] constexpr inline auto pq_bits() const noexcept -> uint32_t - { - auto pq_width = pq_n_centers(); -#ifdef __cpp_lib_bitops - return std::countr_zero(pq_width); -#else - uint32_t pq_bits = 0; - while (pq_width > 1) { - pq_bits++; - pq_width >>= 1; - } - return pq_bits; -#endif - } - /** The dimensionality of an encoded vector after compression by PQ. */ - [[nodiscard]] constexpr inline auto pq_dim() const noexcept -> uint32_t - { - return raft::div_rounding_up_unsafe(dim(), pq_len()); - } - /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ - [[nodiscard]] constexpr inline auto pq_len() const noexcept -> uint32_t - { - return pq_code_book.extent(1); - } - /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ - [[nodiscard]] constexpr inline auto pq_n_centers() const noexcept -> uint32_t - { - return pq_code_book.extent(0); - } -}; - -/** Fix the internal indexing type to avoid integer underflows/overflows */ -using ix_t = int64_t; - -template +template auto transform_data(const raft::resources& res, DatasetT dataset) - -> device_mdarray + -> device_mdarray { using index_type = typename DatasetT::index_type; using extents_type = typename DatasetT::extents_type; using layout_type = typename DatasetT::layout_type; - using out_mdarray_type = device_mdarray; + using out_mdarray_type = device_mdarray; if constexpr (std::is_same_v>) { return dataset; } - auto result = raft::make_device_mdarray(res, dataset.extents()); + auto result = raft::make_device_mdarray(res, dataset.extents()); linalg::map(res, result.view(), - spatial::knn::detail::utils::mapping{}, + spatial::knn::detail::utils::mapping{}, raft::make_const_mdspan(dataset.view())); return result; } -template +/** Fix the internal indexing type to avoid integer underflows/overflows */ +using ix_t = int64_t; + +template auto train_vq(const raft::resources& res, const compression_params& params, const DatasetT& dataset) - -> device_matrix + -> device_matrix { const ix_t n_rows = dataset.extent(0); const ix_t vq_n_centers = params.vq_n_centers; @@ -234,22 +132,22 @@ auto train_vq(const raft::resources& res, const compression_params& params, cons // Subsample the dataset and transform into the required type if necessary auto vq_trainset = raft::util::subsample(res, dataset, n_rows_train); - auto vq_centers = raft::make_device_matrix(res, vq_n_centers, dim); + auto vq_centers = raft::make_device_matrix(res, vq_n_centers, dim); using kmeans_in_type = typename DatasetT::value_type; raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; kmeans_params.metric = raft::distance::DistanceType::L2Expanded; auto vq_centers_view = - raft::make_device_matrix_view(vq_centers.data_handle(), vq_n_centers, dim); + raft::make_device_matrix_view(vq_centers.data_handle(), vq_n_centers, dim); auto vq_trainset_view = raft::make_device_matrix_view( vq_trainset.data_handle(), n_rows_train, dim); - raft::cluster::kmeans_balanced::fit( + raft::cluster::kmeans_balanced::fit( res, kmeans_params, vq_trainset_view, vq_centers_view, - spatial::knn::detail::utils::mapping{}); + spatial::knn::detail::utils::mapping{}); return vq_centers; } @@ -286,12 +184,12 @@ auto predict_vq(const raft::resources& res, const DatasetT& dataset, const VqCen return vq_labels; } -template +template auto train_pq(const raft::resources& res, const compression_params& params, const DatasetT& dataset, - const device_matrix_view& vq_centers) - -> device_matrix + const device_matrix_view& vq_centers) + -> device_matrix { const ix_t n_rows = dataset.extent(0); const ix_t dim = dataset.extent(1); @@ -302,7 +200,7 @@ auto train_pq(const raft::resources& res, const ix_t n_rows_train = n_rows * params.pq_kmeans_trainset_fraction; // Subsample the dataset and transform into the required type if necessary - auto pq_trainset = transform_data(res, raft::util::subsample(res, dataset, n_rows_train)); + auto pq_trainset = transform_data(res, raft::util::subsample(res, dataset, n_rows_train)); // Subtract VQ centers { @@ -311,7 +209,7 @@ auto train_pq(const raft::resources& res, linalg::map_offset( res, pq_trainset.view(), - [labels = vq_labels.view(), centers = vq_centers, dim] __device__(index_type off, DataT x) { + [labels = vq_labels.view(), centers = vq_centers, dim] __device__(index_type off, MathT x) { index_type i = off / dim; index_type j = off % dim; return x - centers(labels(i), j); @@ -319,7 +217,7 @@ auto train_pq(const raft::resources& res, raft::make_const_mdspan(pq_trainset.view())); } - auto pq_centers = raft::make_device_matrix(res, pq_n_centers, pq_len); + auto pq_centers = raft::make_device_matrix(res, pq_n_centers, pq_len); // Train PQ centers { @@ -328,12 +226,12 @@ auto train_pq(const raft::resources& res, kmeans_params.metric = raft::distance::DistanceType::L2Expanded; auto pq_centers_view = - raft::make_device_matrix_view(pq_centers.data_handle(), pq_n_centers, pq_len); + raft::make_device_matrix_view(pq_centers.data_handle(), pq_n_centers, pq_len); - auto pq_trainset_view = raft::make_device_matrix_view( + auto pq_trainset_view = raft::make_device_matrix_view( pq_trainset.data_handle(), n_rows_train * pq_dim, pq_len); - raft::cluster::kmeans_balanced::fit( + raft::cluster::kmeans_balanced::fit( res, kmeans_params, pq_trainset_view, pq_centers_view); } @@ -415,23 +313,22 @@ __launch_bounds__(BlockSize) RAFT_KERNEL for (uint32_t j = 0; j < pq_dim; j++) { // find PQ label uint8_t code = compute_code(dataset, vq_centers, pq_centers, row_ix, j, vq_label); - // TODO: this writes in global memory bytewise, which is very slow. + // TODO: this writes in global memory one byte per warp, which is very slow. // It's better to keep the codes in the shared memory or registers and dump them at once. if (lane_id == 0) { code_view[j] = code; } } } -template +template auto process_and_fill_codes(const raft::resources& res, const compression_params& params, const DatasetT& dataset, device_matrix_view vq_centers, device_matrix_view pq_centers) - -> device_matrix::raw_codes_type, IdxT, row_major> + -> device_matrix { using data_t = typename DatasetT::value_type; using cdataset_t = compressed_dataset; - using codes_t = typename cdataset_t::raw_codes_type; using label_t = uint32_t; const ix_t n_rows = dataset.extent(0); @@ -439,12 +336,11 @@ auto process_and_fill_codes(const raft::resources& res, const ix_t pq_dim = params.pq_dim; const ix_t pq_bits = params.pq_bits; const ix_t pq_n_centers = ix_t{1} << pq_bits; - const ix_t pq_len = raft::div_rounding_up_safe(dim, pq_dim); // NB: codes must be aligned at least to sizeof(label_t) to be able to read labels. const ix_t codes_rowlen = sizeof(label_t) * (1 + raft::div_rounding_up_safe(pq_dim * pq_bits, 8 * sizeof(label_t))); - auto codes = raft::make_device_matrix(res, n_rows, codes_rowlen); + auto codes = raft::make_device_matrix(res, n_rows, codes_rowlen); auto stream = raft::resource::get_cuda_stream(res); @@ -487,34 +383,26 @@ auto process_and_fill_codes(const raft::resources& res, return codes; } -template +template auto compress(const raft::resources& res, const compression_params& params, const DatasetT& dataset) - -> compressed_dataset + -> compressed_dataset { // Use a heuristic to impute missing parameters. auto ps = fill_missing_params_heuristics(params, dataset); - // Relevant constants - const ix_t n_rows = dataset.extent(0); - const ix_t dim = dataset.extent(1); - const ix_t pq_dim = ps.pq_dim; - // Train codes - auto vq_code_book = train_vq(res, ps, dataset); + auto vq_code_book = train_vq(res, ps, dataset); auto pq_code_book = - train_pq(res, ps, dataset, raft::make_const_mdspan(vq_code_book.view())); + train_pq(res, ps, dataset, raft::make_const_mdspan(vq_code_book.view())); // Encode dataset - auto codes = - process_and_fill_codes(res, - ps, - dataset, - raft::make_const_mdspan(vq_code_book.view()), - raft::make_const_mdspan(pq_code_book.view())); - - return compressed_dataset{vq_code_book, pq_code_book, codes}; + auto codes = process_and_fill_codes(res, + ps, + dataset, + raft::make_const_mdspan(vq_code_book.view()), + raft::make_const_mdspan(pq_code_book.view())); + + return compressed_dataset{vq_code_book, pq_code_book, codes}; } } // namespace raft::neighbors::cagra::detail From 1a720205dcdc1cb29a9488e112854b15d88a22e4 Mon Sep 17 00:00:00 2001 From: achirkin Date: Thu, 7 Mar 2024 16:53:15 +0100 Subject: [PATCH 03/18] Integrate vpq_dataset into cagra --- cpp/include/raft/neighbors/cagra.cuh | 21 -- cpp/include/raft/neighbors/cagra_types.hpp | 161 ++-------- cpp/include/raft/neighbors/dataset.hpp | 278 ++++++++++++++++++ .../cagra_build_q.cuh => vpq_dataset.cuh} | 30 +- cpp/include/raft/neighbors/vpq_dataset.cuh | 46 +++ 5 files changed, 356 insertions(+), 180 deletions(-) create mode 100644 cpp/include/raft/neighbors/dataset.hpp rename cpp/include/raft/neighbors/detail/{cagra/cagra_build_q.cuh => vpq_dataset.cuh} (95%) create mode 100644 cpp/include/raft/neighbors/vpq_dataset.cuh diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 19e010e08a..b8258297e6 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -17,7 +17,6 @@ #pragma once #include "detail/cagra/cagra_build.cuh" -#include "detail/cagra/cagra_build_q.cuh" #include "detail/cagra/cagra_search.cuh" #include "detail/cagra/graph_core.cuh" @@ -280,26 +279,6 @@ index build(raft::resources const& res, return detail::build(res, params, dataset); } -/** - * @brief Compress a dataset for use in CAGRA-Q search in place of the original data. - * - * @tparam DatasetT a row-major mdspan or mdarray (device or host). - * @tparam MathT a type of the codebook elements and internal math ops. - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] res - * @param[in] params VQ and PQ parameters for compressing the data - * @param[in] dataset a row-major mdspan or mdarray (device or host) [n_rows, dim]. - */ -template -auto compress(const raft::resources& res, const compression_params& params, const DatasetT& dataset) - -> compressed_dataset -{ - return detail::compress(res, params, dataset); -} - /** * @brief Search ANN using the constructed index. * diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index aaf6db74fa..7eb5e21f53 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -17,6 +17,7 @@ #pragma once #include "ann_types.hpp" +#include "dataset.hpp" #include #include @@ -36,10 +37,6 @@ #include #include -#ifdef __cpp_lib_bitops -#include -#endif - namespace raft::neighbors::cagra { /** * @addtogroup cagra @@ -66,43 +63,8 @@ struct index_params : ann::index_params { graph_build_algo build_algo = graph_build_algo::IVF_PQ; /** Number of Iterations to run if building with NN_DESCENT */ size_t nn_descent_niter = 20; -}; - -/** Parameters for CAGRA-Q compression. */ -struct compression_params { - /** - * The bit length of the vector element after compression by PQ. - * - * Possible values: [4, 5, 6, 7, 8]. - * - * Hint: the smaller the 'pq_bits', the smaller the index size and the better the search - * performance, but the lower the recall. - */ - uint32_t pq_bits = 8; - /** - * The dimensionality of the vector after compression by PQ. - * When zero, an optimal value is selected using a heuristic. - * - * TODO: at the moment `dim` must be a multiple `pq_dim`. - */ - uint32_t pq_dim = 0; - /** - * Vector Quantization (VQ) codebook size - number of "coarse cluster centers". - * When zero, an optimal value is selected using a heuristic. - */ - uint32_t vq_n_centers = 0; - /** The number of iterations searching for kmeans centers (both VQ & PQ phases). */ - uint32_t kmeans_n_iters = 25; - /** - * The fraction of data to use during iterative kmeans building (VQ phase). - * When zero, an optimal value is selected using a heuristic. - */ - double vq_kmeans_trainset_fraction = 0; - /** - * The fraction of data to use during iterative kmeans building (PQ phase). - * When zero, an optimal value is selected using a heuristic. - */ - double pq_kmeans_trainset_fraction = 0; + /** Specify compression params if compression is desired. */ + std::optional compression = std::nullopt; }; enum class search_algo { @@ -187,14 +149,12 @@ struct index : ann::index { /** Total length of the index (number of vectors). */ [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { - return dataset_view_.extent(0) ? dataset_view_.extent(0) : graph_view_.extent(0); + auto data_rows = dataset_->n_rows(); + return data_rows > 0 ? data_rows : graph_view_.extent(0); } /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t - { - return dataset_view_.extent(1); - } + [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { return dataset_->dim(); } /** Graph degree */ [[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t { @@ -205,7 +165,10 @@ struct index : ann::index { [[nodiscard]] inline auto dataset() const noexcept -> device_matrix_view { - return dataset_view_; + auto p = dynamic_cast*>(dataset_.get()); + if (p != nullptr) { return p->view(); } + auto d = dataset_->dim(); + return make_device_strided_matrix_view(nullptr, 0, d, d); } /** neighborhood graph [size, graph-degree] */ @@ -227,7 +190,7 @@ struct index : ann::index { raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded) : ann::index(), metric_(metric), - dataset_(make_device_matrix(res, 0, 0)), + dataset_(new neighbors::empty_dataset(0)), graph_(make_device_matrix(res, 0, 0)) { } @@ -293,12 +256,11 @@ struct index : ann::index { mdspan, row_major, graph_accessor> knn_graph) : ann::index(), metric_(metric), - dataset_(make_device_matrix(res, 0, 0)), + dataset_(construct_aligned_dataset(res, dataset, uint32_t{16})), graph_(make_device_matrix(res, 0, 0)) { RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), "Dataset and knn_graph must have equal number of rows"); - update_dataset(res, dataset); update_graph(res, knn_graph); resource::sync_stream(res); } @@ -313,21 +275,14 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - if (dataset.extent(1) * sizeof(T) % 16 != 0) { - RAFT_LOG_DEBUG("Creating a padded copy of CAGRA dataset in device memory"); - copy_padded(res, dataset); - } else { - dataset_view_ = make_device_strided_matrix_view( - dataset.data_handle(), dataset.extent(0), dataset.extent(1), dataset.extent(1)); - } + construct_aligned_dataset(res, dataset, 16).swap(dataset_); } /** Set the dataset reference explicitly to a device matrix view with padding. */ - void update_dataset(raft::resources const&, + void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - RAFT_EXPECTS(dataset.stride(0) * sizeof(T) % 16 == 0, "Incorrect data padding."); - dataset_view_ = dataset; + construct_aligned_dataset(res, dataset, 16).swap(dataset_); } /** @@ -338,8 +293,7 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::host_matrix_view dataset) { - RAFT_LOG_DEBUG("Copying CAGRA dataset from host to device"); - copy_padded(res, dataset); + construct_aligned_dataset(res, dataset, 16).swap(dataset_); } /** @@ -376,91 +330,10 @@ struct index : ann::index { } private: - /** Create a device copy of the dataset, and pad it if necessary. */ - template - void copy_padded(raft::resources const& res, - mdspan, row_major, data_accessor> dataset) - { - detail::copy_with_padding(res, dataset_, dataset); - - dataset_view_ = make_device_strided_matrix_view( - dataset_.data_handle(), dataset_.extent(0), dataset.extent(1), dataset_.extent(1)); - RAFT_LOG_DEBUG("CAGRA dataset strided matrix view %zux%zu, stride %zu", - static_cast(dataset_view_.extent(0)), - static_cast(dataset_view_.extent(1)), - static_cast(dataset_view_.stride(0))); - } - raft::distance::DistanceType metric_; - raft::device_matrix dataset_; raft::device_matrix graph_; - raft::device_matrix_view dataset_view_; raft::device_matrix_view graph_view_; -}; - -/** - * @brief CAGRA-Q compressed dataset. - * - * @tparam MathT the type of elements in the codebooks - * @tparam IdxT type of the vector indices (represent dataset.extent(0)) - * - */ -template -struct compressed_dataset { - /** Vector Quantization codebook - "coarse cluster centers". */ - device_matrix vq_code_book; - /** Product Quantization codebook - "fine cluster centers". */ - device_matrix pq_code_book; - /** Compressed dataset. */ - device_matrix dataset; - - /** Total length of the index. */ - [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return dataset.extent(0); } - /** Row length of the encoded data in bytes. */ - [[nodiscard]] constexpr inline auto encoded_row_length() const noexcept -> uint32_t - { - return dataset.extent(1); - } - /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t - { - return vq_code_book.extent(1); - } - /** The number of "coarse cluster centers" */ - [[nodiscard]] constexpr inline auto vq_n_centers() const noexcept -> uint32_t - { - return vq_code_book.extent(0); - } - /** The bit length of an encoded vector element after compression by PQ. */ - [[nodiscard]] constexpr inline auto pq_bits() const noexcept -> uint32_t - { - auto pq_width = pq_n_centers(); -#ifdef __cpp_lib_bitops - return std::countr_zero(pq_width); -#else - uint32_t pq_bits = 0; - while (pq_width > 1) { - pq_bits++; - pq_width >>= 1; - } - return pq_bits; -#endif - } - /** The dimensionality of an encoded vector after compression by PQ. */ - [[nodiscard]] constexpr inline auto pq_dim() const noexcept -> uint32_t - { - return raft::div_rounding_up_unsafe(dim(), pq_len()); - } - /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ - [[nodiscard]] constexpr inline auto pq_len() const noexcept -> uint32_t - { - return pq_code_book.extent(1); - } - /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ - [[nodiscard]] constexpr inline auto pq_n_centers() const noexcept -> uint32_t - { - return pq_code_book.extent(0); - } + std::unique_ptr> dataset_; }; /** @} */ diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp new file mode 100644 index 0000000000..cc655438bc --- /dev/null +++ b/cpp/include/raft/neighbors/dataset.hpp @@ -0,0 +1,278 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include // get_device_for_address +#include // rounding up + +#include +#include + +#ifdef __cpp_lib_bitops +#include +#endif + +namespace raft::neighbors { + +/** Two-dimensional dataset; maybe owning, maybe compressed, maybe strided. */ +template +struct dataset { + /** Size of the dataset. */ + [[nodiscard]] virtual auto n_rows() const noexcept -> IdxT; + /** Dimensionality of the dataset. */ + [[nodiscard]] virtual auto dim() const noexcept -> uint32_t; + /** Whether the object owns the data. */ + [[nodiscard]] virtual auto is_owning() const noexcept -> bool; + virtual ~dataset() noexcept = default; +}; + +template +struct empty_dataset : public dataset { + uint32_t suggested_dim; + explicit empty_dataset(uint32_t dim) noexcept : suggested_dim(0) {} + [[nodiscard]] auto n_rows() const noexcept -> IdxT final { return 0; } + [[nodiscard]] auto dim() const noexcept -> uint32_t final { return suggested_dim; } + [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } +}; + +template +struct strided_dataset : public dataset { + using view_type = device_matrix_view; + [[nodiscard]] auto n_rows() const noexcept -> IdxT final { return view().extent(0); } + [[nodiscard]] auto dim() const noexcept -> uint32_t final + { + return static_cast(view().extent(1)); + } + /** Leading dimension of the dataset. */ + [[nodiscard]] constexpr auto stride() const noexcept -> uint32_t + { + return static_cast(view().stride(0)); + } + /** Get the view of the data. */ + [[nodiscard]] virtual auto view() const noexcept -> view_type; +}; + +template +struct non_owning_dataset : public strided_dataset { + using typename strided_dataset::view_type; + view_type data; + explicit non_owning_dataset(view_type data) noexcept : data(data) {} + [[nodiscard]] auto is_owning() const noexcept -> bool final { return false; } + [[nodiscard]] auto view() const noexcept -> view_type final { return data; }; +}; + +template +struct owning_dataset : public strided_dataset { + using typename strided_dataset::view_type; + using storage_type = mdarray, LayoutPolicy, ContainerPolicy>; + using mapping_type = typename view_type::mapping_type; + storage_type data; + mapping_type view_mapping; + owning_dataset(storage_type&& data, mapping_type view_mapping) noexcept + : data{data}, view_mapping{view_mapping} + { + } + + [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } + [[nodiscard]] auto view() const noexcept -> view_type final + { + return view_type{data.data_handle(), view_mapping}; + }; +}; + +template +auto construct_strided_dataset(const raft::resources& res, + const SrcT& src, + uint32_t required_stride) + -> std::unique_ptr> +{ + using extents_type = typename SrcT::extents_type; + using value_type = typename SrcT::value_type; + using index_type = typename SrcT::index_type; + using layout_type = typename SrcT::layout_type; + using out_type = strided_dataset; + static_assert(extents_type::rank() == 2, "The input must be a matrix."); + static_assert(std::is_same_v || + std::is_same_v> || + std::is_same_v, + "The input must be row-major"); + RAFT_EXPECTS(src.extent(1) <= required_stride, + "The input row length must be not larger than the desired stride."); + const bool device_accessible = get_device_for_address(src.data_handle()) >= 0; + const bool row_major = src.stride(1) == 0; + const bool stride_matches = required_stride == src.stride(0); + + if (device_accessible && row_major && stride_matches) { + // Everything matches: make a non-owning dataset + return std::unique_ptr{new non_owning_dataset{ + make_device_strided_matrix_view( + src.data_handle(), src.extent(0), src.extent(1), required_stride)}}; + } + // Something is wrong: have to make a copy and produce an owning dataset + using out_mdarray_type = device_mdarray, layout_stride>; + using out_layout_type = typename out_mdarray_type::layout_type; + using out_container_policy_type = typename out_mdarray_type::container_policy_type; + using out_owning_type = + owning_dataset; + auto out_layout = + make_strided_layout(src.extents(), std::array{required_stride, 1}); + auto out_array = out_mdarray_type{res, out_layout, out_container_policy_type{}}; + + RAFT_CUDA_TRY(cudaMemsetAsync(out_array.data_handle(), + 0, + out_array.size() * sizeof(value_type), + resource::get_cuda_stream(res))); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(), + sizeof(value_type) * required_stride, + src.data_handle(), + sizeof(value_type) * src.extent(1), + sizeof(value_type) * src.extent(1), + src.extent(0), + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + + return std::unique_ptr{new out_owning_type{std::move(out_array), out_layout}}; +} + +template +auto construct_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t align_bytes) + -> std::unique_ptr> +{ + using value_type = typename SrcT::value_type; + using index_type = typename SrcT::index_type; + using out_type = dataset; + constexpr size_t kSize = sizeof(value_type); + uint32_t required_stride = + raft::round_up_safe(src.extent(1) * kSize, align_bytes) / kSize; + return std::unique_ptr{construct_strided_dataset(res, src, required_stride).release()}; +} + +/** Parameters for VPQ compression. */ +struct vpq_params { + /** + * The bit length of the vector element after compression by PQ. + * + * Possible values: [4, 5, 6, 7, 8]. + * + * Hint: the smaller the 'pq_bits', the smaller the index size and the better the search + * performance, but the lower the recall. + */ + uint32_t pq_bits = 8; + /** + * The dimensionality of the vector after compression by PQ. + * When zero, an optimal value is selected using a heuristic. + * + * TODO: at the moment `dim` must be a multiple `pq_dim`. + */ + uint32_t pq_dim = 0; + /** + * Vector Quantization (VQ) codebook size - number of "coarse cluster centers". + * When zero, an optimal value is selected using a heuristic. + */ + uint32_t vq_n_centers = 0; + /** The number of iterations searching for kmeans centers (both VQ & PQ phases). */ + uint32_t kmeans_n_iters = 25; + /** + * The fraction of data to use during iterative kmeans building (VQ phase). + * When zero, an optimal value is selected using a heuristic. + */ + double vq_kmeans_trainset_fraction = 0; + /** + * The fraction of data to use during iterative kmeans building (PQ phase). + * When zero, an optimal value is selected using a heuristic. + */ + double pq_kmeans_trainset_fraction = 0; +}; + +/** + * @brief VPQ compressed dataset. + * + * Twice quantized data: + * + * 1. Vector Quantization + * 2. Product Quantization of residuals + * + * @tparam MathT the type of elements in the codebooks + * @tparam IdxT type of the vector indices (represent dataset.extent(0)) + * + */ +template +struct vpq_dataset : public dataset { + /** Vector Quantization codebook - "coarse cluster centers". */ + device_matrix vq_code_book; + /** Product Quantization codebook - "fine cluster centers". */ + device_matrix pq_code_book; + /** Compressed dataset. */ + device_matrix data; + + vpq_dataset(device_matrix&& vq_code_book, + device_matrix&& pq_code_book, + device_matrix&& data) + : vq_code_book{vq_code_book}, pq_code_book{pq_code_book}, data{data} + { + } + + [[nodiscard]] auto n_rows() const noexcept -> IdxT final { return data.extent(0); } + [[nodiscard]] auto dim() const noexcept -> uint32_t final { return vq_code_book.extent(1); } + [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } + + /** Row length of the encoded data in bytes. */ + [[nodiscard]] constexpr inline auto encoded_row_length() const noexcept -> uint32_t + { + return data.extent(1); + } + /** The number of "coarse cluster centers" */ + [[nodiscard]] constexpr inline auto vq_n_centers() const noexcept -> uint32_t + { + return vq_code_book.extent(0); + } + /** The bit length of an encoded vector element after compression by PQ. */ + [[nodiscard]] constexpr inline auto pq_bits() const noexcept -> uint32_t + { + auto pq_width = pq_n_centers(); +#ifdef __cpp_lib_bitops + return std::countr_zero(pq_width); +#else + uint32_t pq_bits = 0; + while (pq_width > 1) { + pq_bits++; + pq_width >>= 1; + } + return pq_bits; +#endif + } + /** The dimensionality of an encoded vector after compression by PQ. */ + [[nodiscard]] constexpr inline auto pq_dim() const noexcept -> uint32_t + { + return raft::div_rounding_up_unsafe(dim(), pq_len()); + } + /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ + [[nodiscard]] constexpr inline auto pq_len() const noexcept -> uint32_t + { + return pq_code_book.extent(1); + } + /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ + [[nodiscard]] constexpr inline auto pq_n_centers() const noexcept -> uint32_t + { + return pq_code_book.extent(0); + } +}; + +} // namespace raft::neighbors diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build_q.cuh b/cpp/include/raft/neighbors/detail/vpq_dataset.cuh similarity index 95% rename from cpp/include/raft/neighbors/detail/cagra/cagra_build_q.cuh rename to cpp/include/raft/neighbors/detail/vpq_dataset.cuh index b1b501aaef..42e1e5db41 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build_q.cuh +++ b/cpp/include/raft/neighbors/detail/vpq_dataset.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include "../../cagra_types.hpp" +#include "../dataset.hpp" #include #include @@ -73,15 +73,14 @@ auto subsample(raft::resources const& res, } // namespace raft::util -namespace raft::neighbors::cagra::detail { +namespace raft::neighbors::detail { template -auto fill_missing_params_heuristics(const compression_params& params, const DatasetT& dataset) - -> compression_params +auto fill_missing_params_heuristics(const vpq_params& params, const DatasetT& dataset) -> vpq_params { - compression_params r = params; - double n_rows = dataset.extent(0); - size_t dim = dataset.extent(1); + vpq_params r = params; + double n_rows = dataset.extent(0); + size_t dim = dataset.extent(1); if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{4}); } if (r.pq_bits == 0) { r.pq_bits = 8; } if (r.vq_n_centers == 0) { r.vq_n_centers = raft::round_up_safe(std::sqrt(n_rows), 8); } @@ -122,7 +121,7 @@ auto transform_data(const raft::resources& res, DatasetT dataset) using ix_t = int64_t; template -auto train_vq(const raft::resources& res, const compression_params& params, const DatasetT& dataset) +auto train_vq(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) -> device_matrix { const ix_t n_rows = dataset.extent(0); @@ -186,7 +185,7 @@ auto predict_vq(const raft::resources& res, const DatasetT& dataset, const VqCen template auto train_pq(const raft::resources& res, - const compression_params& params, + const vpq_params& params, const DatasetT& dataset, const device_matrix_view& vq_centers) -> device_matrix @@ -321,14 +320,14 @@ __launch_bounds__(BlockSize) RAFT_KERNEL template auto process_and_fill_codes(const raft::resources& res, - const compression_params& params, + const vpq_params& params, const DatasetT& dataset, device_matrix_view vq_centers, device_matrix_view pq_centers) -> device_matrix { using data_t = typename DatasetT::value_type; - using cdataset_t = compressed_dataset; + using cdataset_t = vpq_dataset; using label_t = uint32_t; const ix_t n_rows = dataset.extent(0); @@ -384,8 +383,8 @@ auto process_and_fill_codes(const raft::resources& res, } template -auto compress(const raft::resources& res, const compression_params& params, const DatasetT& dataset) - -> compressed_dataset +auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) + -> vpq_dataset { // Use a heuristic to impute missing parameters. auto ps = fill_missing_params_heuristics(params, dataset); @@ -402,7 +401,8 @@ auto compress(const raft::resources& res, const compression_params& params, cons raft::make_const_mdspan(vq_code_book.view()), raft::make_const_mdspan(pq_code_book.view())); - return compressed_dataset{vq_code_book, pq_code_book, codes}; + return vpq_dataset{ + std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; } -} // namespace raft::neighbors::cagra::detail +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/vpq_dataset.cuh b/cpp/include/raft/neighbors/vpq_dataset.cuh new file mode 100644 index 0000000000..27dbaf1d94 --- /dev/null +++ b/cpp/include/raft/neighbors/vpq_dataset.cuh @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "dataset.hpp" +#include "detail/vpq_dataset.cuh" + +#include + +namespace raft::neighbors { + +/** + * @brief Compress a dataset for use in CAGRA-Q search in place of the original data. + * + * @tparam DatasetT a row-major mdspan or mdarray (device or host). + * @tparam MathT a type of the codebook elements and internal math ops. + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res + * @param[in] params VQ and PQ parameters for compressing the data + * @param[in] dataset a row-major mdspan or mdarray (device or host) [n_rows, dim]. + */ +template +auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) + -> vpq_dataset +{ + return detail::vpq_build(res, params, dataset); +} + +} // namespace raft::neighbors From 99fa02f85bf417e677df350c7dc8b4baa6afe184 Mon Sep 17 00:00:00 2001 From: achirkin Date: Thu, 7 Mar 2024 19:12:18 +0100 Subject: [PATCH 04/18] Add dataset compression as an optional step during build --- cpp/include/raft/neighbors/cagra_types.hpp | 16 +++++-- cpp/include/raft/neighbors/dataset.hpp | 42 +++++++++++++------ .../neighbors/detail/cagra/cagra_build.cuh | 11 +++++ 3 files changed, 53 insertions(+), 16 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 7eb5e21f53..34e79987ae 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -256,7 +256,7 @@ struct index : ann::index { mdspan, row_major, graph_accessor> knn_graph) : ann::index(), metric_(metric), - dataset_(construct_aligned_dataset(res, dataset, uint32_t{16})), + dataset_(upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16))), graph_(make_device_matrix(res, 0, 0)) { RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), @@ -275,14 +275,14 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - construct_aligned_dataset(res, dataset, 16).swap(dataset_); + upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_); } /** Set the dataset reference explicitly to a device matrix view with padding. */ void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - construct_aligned_dataset(res, dataset, 16).swap(dataset_); + upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_); } /** @@ -293,7 +293,15 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::host_matrix_view dataset) { - construct_aligned_dataset(res, dataset, 16).swap(dataset_); + upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_); + } + + /** Replace the dataset with a new dataset. */ + template + auto update_dataset(raft::resources const& res, DatasetT&& dataset) + -> std::enable_if_t, DatasetT>> + { + upcast_dataset_ptr(std::make_unique(std::move(dataset))).swap(dataset_); } /** diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp index cc655438bc..8586757679 100644 --- a/cpp/include/raft/neighbors/dataset.hpp +++ b/cpp/include/raft/neighbors/dataset.hpp @@ -34,28 +34,32 @@ namespace raft::neighbors { /** Two-dimensional dataset; maybe owning, maybe compressed, maybe strided. */ template struct dataset { + using index_type = IdxT; /** Size of the dataset. */ - [[nodiscard]] virtual auto n_rows() const noexcept -> IdxT; + [[nodiscard]] virtual auto n_rows() const noexcept -> index_type = 0; /** Dimensionality of the dataset. */ - [[nodiscard]] virtual auto dim() const noexcept -> uint32_t; + [[nodiscard]] virtual auto dim() const noexcept -> uint32_t = 0; /** Whether the object owns the data. */ - [[nodiscard]] virtual auto is_owning() const noexcept -> bool; - virtual ~dataset() noexcept = default; + [[nodiscard]] virtual auto is_owning() const noexcept -> bool = 0; + virtual ~dataset() noexcept = default; }; template struct empty_dataset : public dataset { + using index_type = IdxT; uint32_t suggested_dim; explicit empty_dataset(uint32_t dim) noexcept : suggested_dim(0) {} - [[nodiscard]] auto n_rows() const noexcept -> IdxT final { return 0; } + [[nodiscard]] auto n_rows() const noexcept -> index_type final { return 0; } [[nodiscard]] auto dim() const noexcept -> uint32_t final { return suggested_dim; } [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } }; template struct strided_dataset : public dataset { - using view_type = device_matrix_view; - [[nodiscard]] auto n_rows() const noexcept -> IdxT final { return view().extent(0); } + using index_type = IdxT; + using value_type = DataT; + using view_type = device_matrix_view; + [[nodiscard]] auto n_rows() const noexcept -> index_type final { return view().extent(0); } [[nodiscard]] auto dim() const noexcept -> uint32_t final { return static_cast(view().extent(1)); @@ -71,7 +75,9 @@ struct strided_dataset : public dataset { template struct non_owning_dataset : public strided_dataset { - using typename strided_dataset::view_type; + using index_type = IdxT; + using value_type = DataT; + using typename strided_dataset::view_type; view_type data; explicit non_owning_dataset(view_type data) noexcept : data(data) {} [[nodiscard]] auto is_owning() const noexcept -> bool final { return false; } @@ -80,8 +86,11 @@ struct non_owning_dataset : public strided_dataset { template struct owning_dataset : public strided_dataset { - using typename strided_dataset::view_type; - using storage_type = mdarray, LayoutPolicy, ContainerPolicy>; + using index_type = IdxT; + using value_type = DataT; + using typename strided_dataset::view_type; + using storage_type = + mdarray, LayoutPolicy, ContainerPolicy>; using mapping_type = typename view_type::mapping_type; storage_type data; mapping_type view_mapping; @@ -153,17 +162,26 @@ auto construct_strided_dataset(const raft::resources& res, template auto construct_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t align_bytes) - -> std::unique_ptr> + -> std::unique_ptr> { using value_type = typename SrcT::value_type; using index_type = typename SrcT::index_type; - using out_type = dataset; + using out_type = strided_dataset; constexpr size_t kSize = sizeof(value_type); uint32_t required_stride = raft::round_up_safe(src.extent(1) * kSize, align_bytes) / kSize; return std::unique_ptr{construct_strided_dataset(res, src, required_stride).release()}; } +template +auto upcast_dataset_ptr(std::unique_ptr&& src) + -> std::unique_ptr> +{ + using out_type = dataset; + static_assert(std::is_base_of_v, "The source must be a child of `dataset`"); + return std::unique_ptr{src.release()}; +} + /** Parameters for VPQ compression. */ struct vpq_params { /** diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 08cc2beaeb..743188dae3 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -16,6 +16,7 @@ #pragma once #include "../../cagra_types.hpp" +#include "../../vpq_dataset.cuh" #include "graph_core.cuh" #include @@ -344,6 +345,16 @@ index build( RAFT_LOG_INFO("Graph optimized, creating index"); // Construct an index from dataset and optimized knn graph. if (construct_index_with_dataset) { + if (params.compression.has_value()) { + index idx(res, params.metric); + idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view())); + idx.update_dataset( + res, + // TODO: ATM, only float math type is supported in kmeans training. + // Later, we can do runtime dispatching of the math type. + neighbors::vpq_build(res, *params.compression, dataset)); + return idx; + } return index(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view())); } else { // We just add the graph. User is expected to update dataset separately. This branch is used From 833b50fa9f15d205fa906dac86dfea0c0a8f49b0 Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Fri, 8 Mar 2024 06:34:59 +0100 Subject: [PATCH 05/18] Update cpp/include/raft/neighbors/dataset.hpp Co-authored-by: Tamas Bela Feher --- cpp/include/raft/neighbors/dataset.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp index 8586757679..0a3d770079 100644 --- a/cpp/include/raft/neighbors/dataset.hpp +++ b/cpp/include/raft/neighbors/dataset.hpp @@ -222,7 +222,7 @@ struct vpq_params { /** * @brief VPQ compressed dataset. * - * Twice quantized data: + * The dataset is compressed using two level quantization * * 1. Vector Quantization * 2. Product Quantization of residuals From 53a5c14ecb3335fc7def80bd396b9b83fb224a12 Mon Sep 17 00:00:00 2001 From: achirkin Date: Fri, 8 Mar 2024 11:36:52 +0100 Subject: [PATCH 06/18] Add dataset serialization --- .../core/detail/mdspan_numpy_serializer.hpp | 3 +- cpp/include/raft/neighbors/cagra_types.hpp | 14 +- .../neighbors/detail/cagra/cagra_search.cuh | 14 +- .../detail/cagra/cagra_serialize.cuh | 37 +-- .../neighbors/detail/dataset_serialize.hpp | 238 ++++++++++++++++++ 5 files changed, 271 insertions(+), 35 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/dataset_serialize.hpp diff --git a/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp b/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp index 176309c8ce..3fb7b3005b 100644 --- a/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp +++ b/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp @@ -126,7 +126,8 @@ inline dtype_t get_numpy_dtype() } #if defined(_RAFT_HAS_CUDA) -template , bool> = true> +template , half>, bool> = true> inline dtype_t get_numpy_dtype() { return {RAFT_NUMPY_HOST_ENDIAN_CHAR, 'e', sizeof(T)}; diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 34e79987ae..868f214320 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -162,7 +162,7 @@ struct index : ann::index { } /** Dataset [size, dim] */ - [[nodiscard]] inline auto dataset() const noexcept + [[nodiscard]] inline auto dataset_view() const noexcept -> device_matrix_view { auto p = dynamic_cast*>(dataset_.get()); @@ -171,6 +171,11 @@ struct index : ann::index { return make_device_strided_matrix_view(nullptr, 0, d, d); } + [[nodiscard]] inline auto dataset() const noexcept -> const neighbors::dataset& + { + return *dataset_; + } + /** neighborhood graph [size, graph-degree] */ [[nodiscard]] inline auto graph() const noexcept -> device_matrix_view @@ -304,6 +309,13 @@ struct index : ann::index { upcast_dataset_ptr(std::make_unique(std::move(dataset))).swap(dataset_); } + template + auto update_dataset(raft::resources const& res, std::unique_ptr&& dataset) + -> std::enable_if_t, DatasetT>> + { + upcast_dataset_ptr(std::move(dataset)).swap(dataset_); + } + /** * Replace the graph with a new graph. * diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 0832e75633..4b519eadcb 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -111,8 +111,8 @@ void search_main(raft::resources const& res, CagraSampleFilterT sample_filter = CagraSampleFilterT()) { RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n", - static_cast(index.dataset().extent(0)), - static_cast(index.dataset().extent(1))); + static_cast(index.dataset_view().extent(0)), + static_cast(index.dataset_view().extent(1))); RAFT_LOG_DEBUG("# query size = %lu, dim = %lu\n", static_cast(queries.extent(0)), static_cast(queries.extent(1))); @@ -151,11 +151,11 @@ void search_main(raft::resources const& res, : nullptr; uint32_t* _num_executed_iterations = nullptr; - auto dataset_internal = - make_device_strided_matrix_view(index.dataset().data_handle(), - index.dataset().extent(0), - index.dataset().extent(1), - index.dataset().stride(0)); + auto dataset_internal = make_device_strided_matrix_view( + index.dataset_view().data_handle(), + index.dataset_view().extent(0), + index.dataset_view().extent(1), + index.dataset_view().stride(0)); auto graph_internal = raft::make_device_matrix_view( reinterpret_cast(index.graph().data_handle()), index.graph().extent(0), diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index d7bd27222b..f5556b256a 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -65,26 +66,14 @@ void serialize(raft::resources const& res, serialize_scalar(res, os, index_.metric()); serialize_mdspan(res, os, index_.graph()); - include_dataset &= (index_.dataset().extent(0) > 0); + include_dataset &= (index_.dataset().n_rows() > 0); serialize_scalar(res, os, include_dataset); if (include_dataset) { RAFT_LOG_INFO("Saving CAGRA index with dataset"); - auto dataset = index_.dataset(); - // Remove padding before saving the dataset - auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), - sizeof(T) * host_dataset.extent(1), - dataset.data_handle(), - sizeof(T) * dataset.stride(0), - sizeof(T) * host_dataset.extent(1), - dataset.extent(0), - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - resource::sync_stream(res); - serialize_mdspan(res, os, host_dataset.view()); + neighbors::detail::serialize(res, os, index_.dataset()); } else { - RAFT_LOG_INFO("Saving CAGRA index WITHOUT dataset"); + RAFT_LOG_DEBUG("Saving CAGRA index WITHOUT dataset"); } } @@ -158,7 +147,7 @@ void serialize_to_hnswlib(raft::resources const& res, std::size_t efConstruction = 500; os.write(reinterpret_cast(&efConstruction), sizeof(std::size_t)); - auto dataset = index_.dataset(); + auto dataset = index_.dataset_view(); // Remove padding before saving the dataset auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), @@ -256,19 +245,15 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index auto graph = raft::make_host_matrix(n_rows, graph_degree); deserialize_mdspan(res, is, graph.view()); + index idx(res, metric); + idx.update_graph(res, raft::make_const_mdspan(graph.view())); bool has_dataset = deserialize_scalar(res, is); if (has_dataset) { - auto dataset = raft::make_host_matrix(n_rows, dim); - deserialize_mdspan(res, is, dataset.view()); - return index( - res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view())); - } else { - // create a new index with no dataset - the user must supply via update_dataset themselves - // later (this avoids allocating GPU memory in the meantime) - index idx(res, metric); - idx.update_graph(res, raft::make_const_mdspan(graph.view())); - return idx; + std::unique_ptr> dataset; + neighbors::detail::deserialize(res, is, dataset); + idx.update_dataset(res, std::move(dataset)); } + return idx; } template diff --git a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp new file mode 100644 index 0000000000..2864d260c1 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp @@ -0,0 +1,238 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../dataset.hpp" + +#include +#include +#include + +#include + +#include +#include + +namespace raft::neighbors::detail { + +using dataset_instance_tag = uint32_t; +constexpr dataset_instance_tag kSerializeEmptyDataset = 1; +constexpr dataset_instance_tag kSerializeStridedDataset = 2; +constexpr dataset_instance_tag kSerializeVPQDataset = 3; + +template +void serialize(const raft::resources& res, std::ostream& os, const empty_dataset& dataset) +{ + serialize_scalar(res, os, dataset.suggested_dim); +} + +template +void serialize(const raft::resources& res, + std::ostream& os, + const strided_dataset& dataset) +{ + serialize_scalar(res, os, dataset.n_rows()); + serialize_scalar(res, os, dataset.dim()); + serialize_scalar(res, os, dataset.stride()); + // Remove padding before saving the dataset + auto src = dataset.view(); + auto dst = make_host_mdarray(src.extents()); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(), + sizeof(DataT) * dst.extent(1), + src.data_handle(), + sizeof(DataT) * src.stride(0), + sizeof(DataT) * dst.extent(1), + src.extent(0), + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + resource::sync_stream(res); + serialize_mdspan(res, os, dst.view()); +} + +template +void serialize(const raft::resources& res, + std::ostream& os, + const vpq_dataset& dataset) +{ + serialize_scalar(res, os, dataset.n_rows()); + serialize_scalar(res, os, dataset.dim()); + serialize_scalar(res, os, dataset.vq_n_centers()); + serialize_scalar(res, os, dataset.pq_n_centers()); + serialize_scalar(res, os, dataset.pq_len()); + serialize_scalar(res, os, dataset.encoded_row_length()); + serialize_mdspan(res, os, make_const_mdspan(dataset.vq_code_book.view())); + serialize_mdspan(res, os, make_const_mdspan(dataset.pq_code_book.view())); + serialize_mdspan(res, os, make_const_mdspan(dataset.data.view())); +} + +template +void serialize(const raft::resources& res, std::ostream& os, const dataset& dataset) +{ + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeEmptyDataset); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_32F); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_16F); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_8I); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_8U); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeVPQDataset); + serialize_scalar(res, os, CUDA_R_32F); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeVPQDataset); + serialize_scalar(res, os, CUDA_R_16F); + return serialize(res, os, *x); + } + RAFT_FAIL("unsupported dataset type."); +} + +template +void deserialize(raft::resources const& res, + std::istream& is, + std::unique_ptr>& out) +{ + auto suggested_dim = deserialize_scalar(res, is); + return std::make_unique>(suggested_dim).swap(out); +} + +template +void deserialize(raft::resources const& res, + std::istream& is, + std::unique_ptr>& out) +{ + using out_mdarray_type = device_mdarray, layout_stride>; + using out_layout_type = typename out_mdarray_type::layout_type; + using out_container_policy_type = typename out_mdarray_type::container_policy_type; + using out_owning_type = owning_dataset; + + auto n_rows = deserialize_scalar(res, is); + auto dim = deserialize_scalar(res, is); + auto stride = deserialize_scalar(res, is); + auto out_extents = make_extents(n_rows, dim); + auto out_layout = make_strided_layout(out_extents, std::array{stride, 1}); + auto out_array = out_mdarray_type{res, out_layout, out_container_policy_type{}}; + auto host_arrray = make_host_mdarray(out_extents); + deserialize_mdspan(res, is, host_arrray.view()); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(), + sizeof(DataT) * stride, + host_arrray.data_handle(), + sizeof(DataT) * dim, + sizeof(DataT) * dim, + n_rows, + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + return std::unique_ptr>{ + new out_owning_type{std::move(out_array), out_layout}} + .swap(out); +} + +template +void deserialize(raft::resources const& res, + std::istream& is, + std::unique_ptr>& out) +{ + auto n_rows = deserialize_scalar(res, is); + auto dim = deserialize_scalar(res, is); + auto vq_n_centers = deserialize_scalar(res, is); + auto pq_n_centers = deserialize_scalar(res, is); + auto pq_len = deserialize_scalar(res, is); + auto encoded_row_length = deserialize_scalar(res, is); + + auto vq_code_book = make_device_matrix(res, vq_n_centers, dim); + auto pq_code_book = make_device_matrix(res, pq_n_centers, pq_len); + auto data = make_device_matrix(res, n_rows, encoded_row_length); + + deserialize_mdspan(res, is, vq_code_book.view()); + deserialize_mdspan(res, is, pq_code_book.view()); + deserialize_mdspan(res, is, data.view()); + + return std::unique_ptr>{ + new vpq_dataset{std::move(vq_code_book), std::move(pq_code_book), std::move(data)}} + .swap(out); +} + +template +void deserialize(raft::resources const& res, std::istream& is, std::unique_ptr>& out) +{ + switch (deserialize_scalar(res, is)) { + case kSerializeEmptyDataset: { + std::unique_ptr> p; + deserialize(res, is, p); + return upcast_dataset_ptr(std::move(p)).swap(out); + } + case kSerializeStridedDataset: + switch (deserialize_scalar(res, is)) { + case CUDA_R_32F: { + std::unique_ptr> p; + deserialize(res, is, p); + return upcast_dataset_ptr(std::move(p)).swap(out); + } + case CUDA_R_16F: { + std::unique_ptr> p; + deserialize(res, is, p); + return upcast_dataset_ptr(std::move(p)).swap(out); + } + case CUDA_R_8I: { + std::unique_ptr> p; + deserialize(res, is, p); + return upcast_dataset_ptr(std::move(p)).swap(out); + } + case CUDA_R_8U: { + std::unique_ptr> p; + deserialize(res, is, p); + return upcast_dataset_ptr(std::move(p)).swap(out); + } + default: break; + } + case kSerializeVPQDataset: + switch (deserialize_scalar(res, is)) { + case CUDA_R_32F: { + std::unique_ptr> p; + deserialize(res, is, p); + return upcast_dataset_ptr(std::move(p)).swap(out); + } + case CUDA_R_16F: { + std::unique_ptr> p; + deserialize(res, is, p); + return upcast_dataset_ptr(std::move(p)).swap(out); + } + default: break; + } + default: break; + } + RAFT_FAIL("Failed to deserialize dataset: unsupported combination of instance tags."); +} + +} // namespace raft::neighbors::detail From 02f2193fbef23aa81b8d46e115922444a0a5d4a6 Mon Sep 17 00:00:00 2001 From: achirkin Date: Fri, 8 Mar 2024 11:44:44 +0100 Subject: [PATCH 07/18] Add comments regarding the internals of pq_bits/pq_width --- cpp/include/raft/neighbors/dataset.hpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp index 0a3d770079..503225de20 100644 --- a/cpp/include/raft/neighbors/dataset.hpp +++ b/cpp/include/raft/neighbors/dataset.hpp @@ -264,6 +264,14 @@ struct vpq_dataset : public dataset { /** The bit length of an encoded vector element after compression by PQ. */ [[nodiscard]] constexpr inline auto pq_bits() const noexcept -> uint32_t { + /* + NOTE: pq_bits and the book size + + Normally, we'd store `pq_bits` as a part of the index. + However, we know there's an invariant `pq_n_centers = 1 << pq_bits`, i.e. the codebook size is + the same as the number of possible code values. Hence, we don't store the pq_bits and derive it + from the array dimensions instead. + */ auto pq_width = pq_n_centers(); #ifdef __cpp_lib_bitops return std::countr_zero(pq_width); From 34a764248775e5eeb0c2343399d2651c2f1c6147 Mon Sep 17 00:00:00 2001 From: achirkin Date: Fri, 8 Mar 2024 12:55:02 +0100 Subject: [PATCH 08/18] Fix incorrect stride assumption that prevented construct_strided_dataset from making a view --- cpp/include/raft/neighbors/dataset.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp index 503225de20..da20c8dbaf 100644 --- a/cpp/include/raft/neighbors/dataset.hpp +++ b/cpp/include/raft/neighbors/dataset.hpp @@ -125,7 +125,7 @@ auto construct_strided_dataset(const raft::resources& res, RAFT_EXPECTS(src.extent(1) <= required_stride, "The input row length must be not larger than the desired stride."); const bool device_accessible = get_device_for_address(src.data_handle()) >= 0; - const bool row_major = src.stride(1) == 0; + const bool row_major = src.stride(1) == 1; const bool stride_matches = required_stride == src.stride(0); if (device_accessible && row_major && stride_matches) { From 308870336072eca8bf447a249bfde196bfa7db50 Mon Sep 17 00:00:00 2001 From: achirkin Date: Mon, 11 Mar 2024 14:05:27 +0100 Subject: [PATCH 09/18] Various small changes to the dataset type to improve safety and be more explicit about arguments --- cpp/include/raft/neighbors/dataset.hpp | 36 +++++++++++++++----------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp index da20c8dbaf..dd346e6d9d 100644 --- a/cpp/include/raft/neighbors/dataset.hpp +++ b/cpp/include/raft/neighbors/dataset.hpp @@ -23,6 +23,7 @@ #include // rounding up #include +#include #include #ifdef __cpp_lib_bitops @@ -48,7 +49,7 @@ template struct empty_dataset : public dataset { using index_type = IdxT; uint32_t suggested_dim; - explicit empty_dataset(uint32_t dim) noexcept : suggested_dim(0) {} + explicit empty_dataset(uint32_t dim) noexcept : suggested_dim(dim) {} [[nodiscard]] auto n_rows() const noexcept -> index_type final { return 0; } [[nodiscard]] auto dim() const noexcept -> uint32_t final { return suggested_dim; } [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } @@ -67,7 +68,8 @@ struct strided_dataset : public dataset { /** Leading dimension of the dataset. */ [[nodiscard]] constexpr auto stride() const noexcept -> uint32_t { - return static_cast(view().stride(0)); + auto v = view(); + return static_cast(v.stride(0) > 0 ? v.stride(0) : v.extent(1)); } /** Get the view of the data. */ [[nodiscard]] virtual auto view() const noexcept -> view_type; @@ -79,7 +81,7 @@ struct non_owning_dataset : public strided_dataset { using value_type = DataT; using typename strided_dataset::view_type; view_type data; - explicit non_owning_dataset(view_type data) noexcept : data(data) {} + explicit non_owning_dataset(view_type v) noexcept : data(v) {} [[nodiscard]] auto is_owning() const noexcept -> bool final { return false; } [[nodiscard]] auto view() const noexcept -> view_type final { return data; }; }; @@ -94,8 +96,8 @@ struct owning_dataset : public strided_dataset { using mapping_type = typename view_type::mapping_type; storage_type data; mapping_type view_mapping; - owning_dataset(storage_type&& data, mapping_type view_mapping) noexcept - : data{data}, view_mapping{view_mapping} + owning_dataset(storage_type&& store, mapping_type view_mapping) noexcept + : data{std::move(store)}, view_mapping{view_mapping} { } @@ -124,9 +126,10 @@ auto construct_strided_dataset(const raft::resources& res, "The input must be row-major"); RAFT_EXPECTS(src.extent(1) <= required_stride, "The input row length must be not larger than the desired stride."); + const uint32_t src_stride = src.stride(0) > 0 ? src.stride(0) : src.extent(1); const bool device_accessible = get_device_for_address(src.data_handle()) >= 0; - const bool row_major = src.stride(1) == 1; - const bool stride_matches = required_stride == src.stride(0); + const bool row_major = src.stride(1) <= 1; + const bool stride_matches = required_stride == src_stride; if (device_accessible && row_major && stride_matches) { // Everything matches: make a non-owning dataset @@ -135,14 +138,15 @@ auto construct_strided_dataset(const raft::resources& res, src.data_handle(), src.extent(0), src.extent(1), required_stride)}}; } // Something is wrong: have to make a copy and produce an owning dataset - using out_mdarray_type = device_mdarray, layout_stride>; - using out_layout_type = typename out_mdarray_type::layout_type; + auto out_layout = + make_strided_layout(src.extents(), std::array{required_stride, 1}); + auto out_array = make_device_matrix(res, src.extent(0), required_stride); + + using out_mdarray_type = decltype(out_array); + using out_layout_type = typename out_mdarray_type::layout_type; using out_container_policy_type = typename out_mdarray_type::container_policy_type; using out_owning_type = owning_dataset; - auto out_layout = - make_strided_layout(src.extents(), std::array{required_stride, 1}); - auto out_array = out_mdarray_type{res, out_layout, out_container_policy_type{}}; RAFT_CUDA_TRY(cudaMemsetAsync(out_array.data_handle(), 0, @@ -151,7 +155,7 @@ auto construct_strided_dataset(const raft::resources& res, RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(), sizeof(value_type) * required_stride, src.data_handle(), - sizeof(value_type) * src.extent(1), + sizeof(value_type) * src_stride, sizeof(value_type) * src.extent(1), src.extent(0), cudaMemcpyDefault, @@ -169,7 +173,7 @@ auto construct_aligned_dataset(const raft::resources& res, const SrcT& src, uint using out_type = strided_dataset; constexpr size_t kSize = sizeof(value_type); uint32_t required_stride = - raft::round_up_safe(src.extent(1) * kSize, align_bytes) / kSize; + raft::round_up_safe(src.extent(1) * kSize, std::lcm(align_bytes, kSize)) / kSize; return std::unique_ptr{construct_strided_dataset(res, src, required_stride).release()}; } @@ -243,7 +247,9 @@ struct vpq_dataset : public dataset { vpq_dataset(device_matrix&& vq_code_book, device_matrix&& pq_code_book, device_matrix&& data) - : vq_code_book{vq_code_book}, pq_code_book{pq_code_book}, data{data} + : vq_code_book{std::move(vq_code_book)}, + pq_code_book{std::move(pq_code_book)}, + data{std::move(data)} { } From 4498a22eb5567818b921f4544d4a52a141838a32 Mon Sep 17 00:00:00 2001 From: achirkin Date: Mon, 11 Mar 2024 14:49:56 +0100 Subject: [PATCH 10/18] Add a stub for the search function --- cpp/include/raft/neighbors/cagra.cuh | 116 +++++++++++++-------------- 1 file changed, 58 insertions(+), 58 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index b8258297e6..2446a95ea7 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -279,62 +280,6 @@ index build(raft::resources const& res, return detail::build(res, params, dataset); } -/** - * @brief Search ANN using the constructed index. - * - * See the [cagra::build](#cagra::build) documentation for a usage example. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] res raft resources - * @param[in] params configure the search - * @param[in] idx cagra index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ -template -void search(raft::resources const& res, - const search_params& params, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must equal k"); - RAFT_EXPECTS(queries.extent(1) == idx.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - using internal_IdxT = typename std::make_unsigned::type; - auto queries_internal = raft::make_device_matrix_view( - queries.data_handle(), queries.extent(0), queries.extent(1)); - auto neighbors_internal = raft::make_device_matrix_view( - reinterpret_cast(neighbors.data_handle()), - neighbors.extent(0), - neighbors.extent(1)); - auto distances_internal = raft::make_device_matrix_view( - distances.data_handle(), distances.extent(0), distances.extent(1)); - - cagra::detail::search_main(res, - params, - idx, - queries_internal, - neighbors_internal, - distances_internal, - raft::neighbors::filtering::none_cagra_sample_filter()); -} - /** * @brief Search ANN using the constructed index with the given sample filter. * @@ -401,8 +346,63 @@ void search_with_filtering(raft::resources const& res, auto distances_internal = raft::make_device_matrix_view( distances.data_handle(), distances.extent(0), distances.extent(1)); - cagra::detail::search_main( - res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); + // n_rows has the same type as the dataset index (the array extents type) + using ds_idx_type = decltype(idx.dataset().n_rows()); + // Dispatch search parameters based on the dataset kind. + if (auto* strided_dset = dynamic_cast*>(&idx.dataset()); + strided_dset != nullptr) { + // Search using a plain (strided) row-major dataset + return cagra::detail::search_main( + res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); + } else if (auto* vpq_dset = + dynamic_cast*>(&idx.dataset()); + vpq_dset != nullptr) { + // Search using a compressed dataset + RAFT_FAIL("FP32 VPQ dataset support is coming soon"); + } else if (auto* vpq_dset = + dynamic_cast*>(&idx.dataset()); + vpq_dset != nullptr) { + // Search using a compressed dataset + RAFT_FAIL("FP16 VPQ dataset support is coming soon"); + } else if (auto* empty_dset = dynamic_cast*>(&idx.dataset()); + empty_dset != nullptr) { + // Forgot to add a dataset. + RAFT_FAIL( + "Attempted to search without a dataset. Please call index.update_dataset(...) first."); + } else { + // This is a logic error. + RAFT_FAIL("Unrecognized dataset format"); + } +} + +/** + * @brief Search ANN using the constructed index. + * + * See the [cagra::build](#cagra::build) documentation for a usage example. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] params configure the search + * @param[in] idx cagra index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search(raft::resources const& res, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + using none_filter_type = raft::neighbors::filtering::none_cagra_sample_filter; + return cagra::search_with_filtering( + res, params, idx, queries, neighbors, distances, none_filter_type{}); } /** @} */ // end group cagra From dd1cc9952886c009deeaa97dd7f23cbd2eabaf27 Mon Sep 17 00:00:00 2001 From: achirkin Date: Mon, 11 Mar 2024 15:55:46 +0100 Subject: [PATCH 11/18] Switch to half as the vpq codebook type --- .../neighbors/detail/cagra/cagra_build.cuh | 5 ++--- .../raft/neighbors/detail/vpq_dataset.cuh | 19 +++++++++++++++++++ cpp/include/raft/neighbors/vpq_dataset.cuh | 7 ++++++- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 743188dae3..d91e45257e 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -350,9 +350,8 @@ index build( idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view())); idx.update_dataset( res, - // TODO: ATM, only float math type is supported in kmeans training. - // Later, we can do runtime dispatching of the math type. - neighbors::vpq_build(res, *params.compression, dataset)); + // TODO: hardcoding codebook math to `half`, we can do runtime dispatching later + neighbors::vpq_build(res, *params.compression, dataset)); return idx; } return index(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view())); diff --git a/cpp/include/raft/neighbors/detail/vpq_dataset.cuh b/cpp/include/raft/neighbors/detail/vpq_dataset.cuh index 42e1e5db41..f6cd2a1ceb 100644 --- a/cpp/include/raft/neighbors/detail/vpq_dataset.cuh +++ b/cpp/include/raft/neighbors/detail/vpq_dataset.cuh @@ -382,6 +382,25 @@ auto process_and_fill_codes(const raft::resources& res, return codes; } +template +auto vpq_convert_math_type(const raft::resources& res, vpq_dataset&& src) + -> vpq_dataset +{ + auto vq_code_book = make_device_mdarray(res, src.vq_code_book.extents()); + auto pq_code_book = make_device_mdarray(res, src.pq_code_book.extents()); + + linalg::map(res, + vq_code_book.view(), + spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(src.vq_code_book.view())); + linalg::map(res, + pq_code_book.view(), + spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(src.pq_code_book.view())); + return vpq_dataset{ + std::move(vq_code_book), std::move(pq_code_book), std::move(src.data)}; +} + template auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) -> vpq_dataset diff --git a/cpp/include/raft/neighbors/vpq_dataset.cuh b/cpp/include/raft/neighbors/vpq_dataset.cuh index 27dbaf1d94..73ee6c52ed 100644 --- a/cpp/include/raft/neighbors/vpq_dataset.cuh +++ b/cpp/include/raft/neighbors/vpq_dataset.cuh @@ -40,7 +40,12 @@ template vpq_dataset { - return detail::vpq_build(res, params, dataset); + if constexpr (std::is_same_v) { + return detail::vpq_convert_math_type( + res, detail::vpq_build(res, params, dataset)); + } else { + return detail::vpq_build(res, params, dataset); + } } } // namespace raft::neighbors From 292406c390ff16ea2bdd2ba0e0df1bbae470c0d2 Mon Sep 17 00:00:00 2001 From: achirkin Date: Mon, 11 Mar 2024 18:21:36 +0100 Subject: [PATCH 12/18] Simplify unique_ptr arithmetics --- cpp/include/raft/neighbors/cagra_types.hpp | 12 ++--- cpp/include/raft/neighbors/dataset.hpp | 20 ++------ .../neighbors/detail/dataset_serialize.hpp | 46 +++++++++++-------- 3 files changed, 36 insertions(+), 42 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 868f214320..9b86ca29f2 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -261,7 +261,7 @@ struct index : ann::index { mdspan, row_major, graph_accessor> knn_graph) : ann::index(), metric_(metric), - dataset_(upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16))), + dataset_(construct_aligned_dataset(res, dataset, 16)), graph_(make_device_matrix(res, 0, 0)) { RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), @@ -280,14 +280,14 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_); + dataset_ = construct_aligned_dataset(res, dataset, 16); } /** Set the dataset reference explicitly to a device matrix view with padding. */ void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_); + dataset_ = construct_aligned_dataset(res, dataset, 16); } /** @@ -298,7 +298,7 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::host_matrix_view dataset) { - upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_); + dataset_ = construct_aligned_dataset(res, dataset, 16); } /** Replace the dataset with a new dataset. */ @@ -306,14 +306,14 @@ struct index : ann::index { auto update_dataset(raft::resources const& res, DatasetT&& dataset) -> std::enable_if_t, DatasetT>> { - upcast_dataset_ptr(std::make_unique(std::move(dataset))).swap(dataset_); + dataset_ = std::make_unique(std::move(dataset)); } template auto update_dataset(raft::resources const& res, std::unique_ptr&& dataset) -> std::enable_if_t, DatasetT>> { - upcast_dataset_ptr(std::move(dataset)).swap(dataset_); + dataset_ = std::move(dataset); } /** diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp index dd346e6d9d..23ea6054bd 100644 --- a/cpp/include/raft/neighbors/dataset.hpp +++ b/cpp/include/raft/neighbors/dataset.hpp @@ -118,7 +118,6 @@ auto construct_strided_dataset(const raft::resources& res, using value_type = typename SrcT::value_type; using index_type = typename SrcT::index_type; using layout_type = typename SrcT::layout_type; - using out_type = strided_dataset; static_assert(extents_type::rank() == 2, "The input must be a matrix."); static_assert(std::is_same_v || std::is_same_v> || @@ -133,9 +132,9 @@ auto construct_strided_dataset(const raft::resources& res, if (device_accessible && row_major && stride_matches) { // Everything matches: make a non-owning dataset - return std::unique_ptr{new non_owning_dataset{ + return std::make_unique>( make_device_strided_matrix_view( - src.data_handle(), src.extent(0), src.extent(1), required_stride)}}; + src.data_handle(), src.extent(0), src.extent(1), required_stride)); } // Something is wrong: have to make a copy and produce an owning dataset auto out_layout = @@ -161,7 +160,7 @@ auto construct_strided_dataset(const raft::resources& res, cudaMemcpyDefault, resource::get_cuda_stream(res))); - return std::unique_ptr{new out_owning_type{std::move(out_array), out_layout}}; + return std::make_unique(std::move(out_array), out_layout); } template @@ -169,21 +168,10 @@ auto construct_aligned_dataset(const raft::resources& res, const SrcT& src, uint -> std::unique_ptr> { using value_type = typename SrcT::value_type; - using index_type = typename SrcT::index_type; - using out_type = strided_dataset; constexpr size_t kSize = sizeof(value_type); uint32_t required_stride = raft::round_up_safe(src.extent(1) * kSize, std::lcm(align_bytes, kSize)) / kSize; - return std::unique_ptr{construct_strided_dataset(res, src, required_stride).release()}; -} - -template -auto upcast_dataset_ptr(std::unique_ptr&& src) - -> std::unique_ptr> -{ - using out_type = dataset; - static_assert(std::is_base_of_v, "The source must be a child of `dataset`"); - return std::unique_ptr{src.release()}; + return construct_strided_dataset(res, src, required_stride); } /** Parameters for VPQ compression. */ diff --git a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp index 2864d260c1..529569865b 100644 --- a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp +++ b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp @@ -124,7 +124,7 @@ void deserialize(raft::resources const& res, std::unique_ptr>& out) { auto suggested_dim = deserialize_scalar(res, is); - return std::make_unique>(suggested_dim).swap(out); + out = std::make_unique>(suggested_dim); } template @@ -132,17 +132,18 @@ void deserialize(raft::resources const& res, std::istream& is, std::unique_ptr>& out) { - using out_mdarray_type = device_mdarray, layout_stride>; - using out_layout_type = typename out_mdarray_type::layout_type; - using out_container_policy_type = typename out_mdarray_type::container_policy_type; - using out_owning_type = owning_dataset; - auto n_rows = deserialize_scalar(res, is); auto dim = deserialize_scalar(res, is); auto stride = deserialize_scalar(res, is); auto out_extents = make_extents(n_rows, dim); auto out_layout = make_strided_layout(out_extents, std::array{stride, 1}); - auto out_array = out_mdarray_type{res, out_layout, out_container_policy_type{}}; + auto out_array = make_device_matrix(res, n_rows, stride); + + using out_mdarray_type = decltype(out_array); + using out_layout_type = typename out_mdarray_type::layout_type; + using out_container_policy_type = typename out_mdarray_type::container_policy_type; + using out_owning_type = owning_dataset; + auto host_arrray = make_host_mdarray(out_extents); deserialize_mdspan(res, is, host_arrray.view()); RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(), @@ -153,9 +154,8 @@ void deserialize(raft::resources const& res, n_rows, cudaMemcpyDefault, resource::get_cuda_stream(res))); - return std::unique_ptr>{ - new out_owning_type{std::move(out_array), out_layout}} - .swap(out); + + out = std::make_unique(std::move(out_array), out_layout); } template @@ -178,9 +178,8 @@ void deserialize(raft::resources const& res, deserialize_mdspan(res, is, pq_code_book.view()); deserialize_mdspan(res, is, data.view()); - return std::unique_ptr>{ - new vpq_dataset{std::move(vq_code_book), std::move(pq_code_book), std::move(data)}} - .swap(out); + out = std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book), std::move(data)); } template @@ -190,29 +189,34 @@ void deserialize(raft::resources const& res, std::istream& is, std::unique_ptr> p; deserialize(res, is, p); - return upcast_dataset_ptr(std::move(p)).swap(out); + out = std::move(p); + return; } case kSerializeStridedDataset: switch (deserialize_scalar(res, is)) { case CUDA_R_32F: { std::unique_ptr> p; deserialize(res, is, p); - return upcast_dataset_ptr(std::move(p)).swap(out); + out = std::move(p); + return; } case CUDA_R_16F: { std::unique_ptr> p; deserialize(res, is, p); - return upcast_dataset_ptr(std::move(p)).swap(out); + out = std::move(p); + return; } case CUDA_R_8I: { std::unique_ptr> p; deserialize(res, is, p); - return upcast_dataset_ptr(std::move(p)).swap(out); + out = std::move(p); + return; } case CUDA_R_8U: { std::unique_ptr> p; deserialize(res, is, p); - return upcast_dataset_ptr(std::move(p)).swap(out); + out = std::move(p); + return; } default: break; } @@ -221,12 +225,14 @@ void deserialize(raft::resources const& res, std::istream& is, std::unique_ptr> p; deserialize(res, is, p); - return upcast_dataset_ptr(std::move(p)).swap(out); + out = std::move(p); + return; } case CUDA_R_16F: { std::unique_ptr> p; deserialize(res, is, p); - return upcast_dataset_ptr(std::move(p)).swap(out); + out = std::move(p); + return; } default: break; } From 24ebae2f7ec0b33c539b53731885a96d338b58e4 Mon Sep 17 00:00:00 2001 From: achirkin Date: Mon, 11 Mar 2024 20:16:42 +0100 Subject: [PATCH 13/18] Fix deserialization: set the padding bytes to zero in the strided dataset. --- .../neighbors/detail/dataset_serialize.hpp | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp index 529569865b..dc60a4782d 100644 --- a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp +++ b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp @@ -44,18 +44,21 @@ void serialize(const raft::resources& res, std::ostream& os, const strided_dataset& dataset) { - serialize_scalar(res, os, dataset.n_rows()); - serialize_scalar(res, os, dataset.dim()); - serialize_scalar(res, os, dataset.stride()); + auto n_rows = dataset.n_rows(); + auto dim = dataset.dim(); + auto stride = dataset.stride(); + serialize_scalar(res, os, n_rows); + serialize_scalar(res, os, dim); + serialize_scalar(res, os, stride); // Remove padding before saving the dataset auto src = dataset.view(); - auto dst = make_host_mdarray(src.extents()); + auto dst = make_host_matrix(n_rows, dim); RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(), - sizeof(DataT) * dst.extent(1), + sizeof(DataT) * dim, src.data_handle(), - sizeof(DataT) * src.stride(0), - sizeof(DataT) * dst.extent(1), - src.extent(0), + sizeof(DataT) * stride, + sizeof(DataT) * dim, + n_rows, cudaMemcpyDefault, resource::get_cuda_stream(res))); resource::sync_stream(res); @@ -144,8 +147,10 @@ void deserialize(raft::resources const& res, using out_container_policy_type = typename out_mdarray_type::container_policy_type; using out_owning_type = owning_dataset; - auto host_arrray = make_host_mdarray(out_extents); + auto host_arrray = make_host_matrix(n_rows, dim); deserialize_mdspan(res, is, host_arrray.view()); + RAFT_CUDA_TRY(cudaMemsetAsync( + out_array.data_handle(), 0, sizeof(DataT) * out_array.size(), resource::get_cuda_stream(res))); RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(), sizeof(DataT) * stride, host_arrray.data_handle(), From cb11327ab4b9f7d2735100f503af03e32791ed56 Mon Sep 17 00:00:00 2001 From: achirkin Date: Tue, 12 Mar 2024 11:33:11 +0100 Subject: [PATCH 14/18] Further simplify deserialization code --- .../detail/cagra/cagra_serialize.cuh | 4 +- .../neighbors/detail/dataset_serialize.hpp | 103 ++++-------------- 2 files changed, 24 insertions(+), 83 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index f5556b256a..cb8073d8f6 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -249,9 +249,7 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index idx.update_graph(res, raft::make_const_mdspan(graph.view())); bool has_dataset = deserialize_scalar(res, is); if (has_dataset) { - std::unique_ptr> dataset; - neighbors::detail::deserialize(res, is, dataset); - idx.update_dataset(res, std::move(dataset)); + idx.update_dataset(res, neighbors::detail::deserialize_dataset(res, is)); } return idx; } diff --git a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp index dc60a4782d..dc55d891be 100644 --- a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp +++ b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp @@ -122,51 +122,28 @@ void serialize(const raft::resources& res, std::ostream& os, const dataset } template -void deserialize(raft::resources const& res, - std::istream& is, - std::unique_ptr>& out) +auto deserialize_empty(raft::resources const& res, std::istream& is) + -> std::unique_ptr> { auto suggested_dim = deserialize_scalar(res, is); - out = std::make_unique>(suggested_dim); + return std::make_unique>(suggested_dim); } template -void deserialize(raft::resources const& res, - std::istream& is, - std::unique_ptr>& out) +auto deserialize_strided(raft::resources const& res, std::istream& is) + -> std::unique_ptr> { - auto n_rows = deserialize_scalar(res, is); - auto dim = deserialize_scalar(res, is); - auto stride = deserialize_scalar(res, is); - auto out_extents = make_extents(n_rows, dim); - auto out_layout = make_strided_layout(out_extents, std::array{stride, 1}); - auto out_array = make_device_matrix(res, n_rows, stride); - - using out_mdarray_type = decltype(out_array); - using out_layout_type = typename out_mdarray_type::layout_type; - using out_container_policy_type = typename out_mdarray_type::container_policy_type; - using out_owning_type = owning_dataset; - - auto host_arrray = make_host_matrix(n_rows, dim); - deserialize_mdspan(res, is, host_arrray.view()); - RAFT_CUDA_TRY(cudaMemsetAsync( - out_array.data_handle(), 0, sizeof(DataT) * out_array.size(), resource::get_cuda_stream(res))); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(), - sizeof(DataT) * stride, - host_arrray.data_handle(), - sizeof(DataT) * dim, - sizeof(DataT) * dim, - n_rows, - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - - out = std::make_unique(std::move(out_array), out_layout); + auto n_rows = deserialize_scalar(res, is); + auto dim = deserialize_scalar(res, is); + auto stride = deserialize_scalar(res, is); + auto host_array = make_host_matrix(n_rows, dim); + deserialize_mdspan(res, is, host_array.view()); + return construct_strided_dataset(res, host_array, stride); } template -void deserialize(raft::resources const& res, - std::istream& is, - std::unique_ptr>& out) +auto deserialize_vpq(raft::resources const& res, std::istream& is) + -> std::unique_ptr> { auto n_rows = deserialize_scalar(res, is); auto dim = deserialize_scalar(res, is); @@ -183,62 +160,28 @@ void deserialize(raft::resources const& res, deserialize_mdspan(res, is, pq_code_book.view()); deserialize_mdspan(res, is, data.view()); - out = std::make_unique>( + return std::make_unique>( std::move(vq_code_book), std::move(pq_code_book), std::move(data)); } template -void deserialize(raft::resources const& res, std::istream& is, std::unique_ptr>& out) +auto deserialize_dataset(raft::resources const& res, std::istream& is) + -> std::unique_ptr> { switch (deserialize_scalar(res, is)) { - case kSerializeEmptyDataset: { - std::unique_ptr> p; - deserialize(res, is, p); - out = std::move(p); - return; - } + case kSerializeEmptyDataset: return deserialize_empty(res, is); case kSerializeStridedDataset: switch (deserialize_scalar(res, is)) { - case CUDA_R_32F: { - std::unique_ptr> p; - deserialize(res, is, p); - out = std::move(p); - return; - } - case CUDA_R_16F: { - std::unique_ptr> p; - deserialize(res, is, p); - out = std::move(p); - return; - } - case CUDA_R_8I: { - std::unique_ptr> p; - deserialize(res, is, p); - out = std::move(p); - return; - } - case CUDA_R_8U: { - std::unique_ptr> p; - deserialize(res, is, p); - out = std::move(p); - return; - } + case CUDA_R_32F: return deserialize_strided(res, is); + case CUDA_R_16F: return deserialize_strided(res, is); + case CUDA_R_8I: return deserialize_strided(res, is); + case CUDA_R_8U: return deserialize_strided(res, is); default: break; } case kSerializeVPQDataset: switch (deserialize_scalar(res, is)) { - case CUDA_R_32F: { - std::unique_ptr> p; - deserialize(res, is, p); - out = std::move(p); - return; - } - case CUDA_R_16F: { - std::unique_ptr> p; - deserialize(res, is, p); - out = std::move(p); - return; - } + case CUDA_R_32F: return deserialize_vpq(res, is); + case CUDA_R_16F: return deserialize_vpq(res, is); default: break; } default: break; From 9a5587469f0120aa97257f88295e0f4483ef6987 Mon Sep 17 00:00:00 2001 From: achirkin Date: Wed, 13 Mar 2024 08:50:49 +0100 Subject: [PATCH 15/18] Remove the dynamic dispatch from public search function for it to be moved into detail namespace in #2206 --- cpp/include/raft/neighbors/cagra.cuh | 29 ++-------------------------- 1 file changed, 2 insertions(+), 27 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 2446a95ea7..b7e362f704 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -346,33 +346,8 @@ void search_with_filtering(raft::resources const& res, auto distances_internal = raft::make_device_matrix_view( distances.data_handle(), distances.extent(0), distances.extent(1)); - // n_rows has the same type as the dataset index (the array extents type) - using ds_idx_type = decltype(idx.dataset().n_rows()); - // Dispatch search parameters based on the dataset kind. - if (auto* strided_dset = dynamic_cast*>(&idx.dataset()); - strided_dset != nullptr) { - // Search using a plain (strided) row-major dataset - return cagra::detail::search_main( - res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); - } else if (auto* vpq_dset = - dynamic_cast*>(&idx.dataset()); - vpq_dset != nullptr) { - // Search using a compressed dataset - RAFT_FAIL("FP32 VPQ dataset support is coming soon"); - } else if (auto* vpq_dset = - dynamic_cast*>(&idx.dataset()); - vpq_dset != nullptr) { - // Search using a compressed dataset - RAFT_FAIL("FP16 VPQ dataset support is coming soon"); - } else if (auto* empty_dset = dynamic_cast*>(&idx.dataset()); - empty_dset != nullptr) { - // Forgot to add a dataset. - RAFT_FAIL( - "Attempted to search without a dataset. Please call index.update_dataset(...) first."); - } else { - // This is a logic error. - RAFT_FAIL("Unrecognized dataset format"); - } + return cagra::detail::search_main( + res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); } /** From 88566d6885c19de87a223a044a4d79e937a2e779 Mon Sep 17 00:00:00 2001 From: achirkin Date: Wed, 13 Mar 2024 11:28:43 +0100 Subject: [PATCH 16/18] Make the construct_strided_dataset only copy the data when it's not accessible by the current device and document the api --- cpp/include/raft/neighbors/dataset.hpp | 42 ++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp index 23ea6054bd..46b94de2bf 100644 --- a/cpp/include/raft/neighbors/dataset.hpp +++ b/cpp/include/raft/neighbors/dataset.hpp @@ -108,6 +108,24 @@ struct owning_dataset : public strided_dataset { }; }; +/** + * @brief Contstruct a strided matrix from any mdarray or mdspan. + * + * This function constructs a non-owning view if the input satisfied two conditions: + * + * 1) The data is accessible from the current device + * 2) The memory layout is the same as expected (row-major matrix with the required stride) + * + * Otherwise, this function constructs an owning device matrix and copies the data. + * When the data is copied, padding elements are filled with zeroes. + * + * @tparam SrcT the source mdarray or mdspan + * + * @param[in] res raft resources handle + * @param[in] src the source mdarray or mdspan + * @param[in] required_stride the leading dimension (in elements) + * @return maybe owning current-device-accessible strided matrix + */ template auto construct_strided_dataset(const raft::resources& res, const SrcT& src, @@ -125,8 +143,11 @@ auto construct_strided_dataset(const raft::resources& res, "The input must be row-major"); RAFT_EXPECTS(src.extent(1) <= required_stride, "The input row length must be not larger than the desired stride."); + cudaPointerAttributes ptr_attrs; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&ptr_attrs, src.data_handle())); + auto* device_ptr = reinterpret_cast(ptr_attrs.devicePointer); const uint32_t src_stride = src.stride(0) > 0 ? src.stride(0) : src.extent(1); - const bool device_accessible = get_device_for_address(src.data_handle()) >= 0; + const bool device_accessible = device_ptr != nullptr; const bool row_major = src.stride(1) <= 1; const bool stride_matches = required_stride == src_stride; @@ -134,7 +155,7 @@ auto construct_strided_dataset(const raft::resources& res, // Everything matches: make a non-owning dataset return std::make_unique>( make_device_strided_matrix_view( - src.data_handle(), src.extent(0), src.extent(1), required_stride)); + device_ptr, src.extent(0), src.extent(1), required_stride)); } // Something is wrong: have to make a copy and produce an owning dataset auto out_layout = @@ -163,8 +184,23 @@ auto construct_strided_dataset(const raft::resources& res, return std::make_unique(std::move(out_array), out_layout); } +/** + * @brief Contstruct a strided matrix from any mdarray or mdspan. + * + * A variant `construct_strided_dataset` that allows specifying the byte alignment instead of the + * explicit stride length. + * + * @tparam SrcT the source mdarray or mdspan + * + * @param[in] res raft resources handle + * @param[in] src the source mdarray or mdspan + * @param[in] align_bytes the required byte alignment for the dataset rows. + * @return maybe owning current-device-accessible strided matrix + */ template -auto construct_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t align_bytes) +auto construct_aligned_dataset(const raft::resources& res, + const SrcT& src, + uint32_t align_bytes = 16) -> std::unique_ptr> { using value_type = typename SrcT::value_type; From 890b29e128eec4c932f4d8a95d9f473574bc243f Mon Sep 17 00:00:00 2001 From: achirkin Date: Thu, 14 Mar 2024 10:44:24 +0100 Subject: [PATCH 17/18] Bump serialization version --- cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index cb8073d8f6..d15d74f2be 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -32,7 +32,7 @@ namespace raft::neighbors::cagra::detail { -constexpr int serialization_version = 3; +constexpr int serialization_version = 4; /** * Save the index to file. From 66ae8ae51f6fd3810588619a7a231e955ee3c24f Mon Sep 17 00:00:00 2001 From: achirkin Date: Thu, 14 Mar 2024 19:24:41 +0100 Subject: [PATCH 18/18] Address offline and online review comments --- cpp/include/raft/neighbors/cagra_types.hpp | 26 +++++++++++++------ cpp/include/raft/neighbors/dataset.hpp | 12 +++------ .../neighbors/detail/cagra/cagra_search.cuh | 14 +++++----- .../detail/cagra/cagra_serialize.cuh | 6 ++--- .../neighbors/detail/dataset_serialize.hpp | 2 +- 5 files changed, 33 insertions(+), 27 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 9b86ca29f2..807f89fd65 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -63,7 +63,11 @@ struct index_params : ann::index_params { graph_build_algo build_algo = graph_build_algo::IVF_PQ; /** Number of Iterations to run if building with NN_DESCENT */ size_t nn_descent_niter = 20; - /** Specify compression params if compression is desired. */ + /** + * Specify compression params if compression is desired. + * + * NOTE: this is experimental new API, consider it unsafe. + */ std::optional compression = std::nullopt; }; @@ -161,8 +165,13 @@ struct index : ann::index { return graph_view_.extent(1); } - /** Dataset [size, dim] */ - [[nodiscard]] inline auto dataset_view() const noexcept + /** + * DEPRECATED: please use data() instead. + * If you need to query dataset dimensions, use the dim() and size() of the cagra index. + * The data_handle() is not always available: you need to do a dynamic_cast to the expected + * dataset type at runtime. + */ + [[nodiscard]] [[deprecated("Use data()")]] inline auto dataset() const noexcept -> device_matrix_view { auto p = dynamic_cast*>(dataset_.get()); @@ -171,7 +180,8 @@ struct index : ann::index { return make_device_strided_matrix_view(nullptr, 0, d, d); } - [[nodiscard]] inline auto dataset() const noexcept -> const neighbors::dataset& + /** Dataset [size, dim] */ + [[nodiscard]] inline auto data() const noexcept -> const neighbors::dataset& { return *dataset_; } @@ -261,7 +271,7 @@ struct index : ann::index { mdspan, row_major, graph_accessor> knn_graph) : ann::index(), metric_(metric), - dataset_(construct_aligned_dataset(res, dataset, 16)), + dataset_(make_aligned_dataset(res, dataset, 16)), graph_(make_device_matrix(res, 0, 0)) { RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), @@ -280,14 +290,14 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - dataset_ = construct_aligned_dataset(res, dataset, 16); + dataset_ = make_aligned_dataset(res, dataset, 16); } /** Set the dataset reference explicitly to a device matrix view with padding. */ void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - dataset_ = construct_aligned_dataset(res, dataset, 16); + dataset_ = make_aligned_dataset(res, dataset, 16); } /** @@ -298,7 +308,7 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::host_matrix_view dataset) { - dataset_ = construct_aligned_dataset(res, dataset, 16); + dataset_ = make_aligned_dataset(res, dataset, 16); } /** Replace the dataset with a new dataset. */ diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp index 46b94de2bf..e7a3ba97a4 100644 --- a/cpp/include/raft/neighbors/dataset.hpp +++ b/cpp/include/raft/neighbors/dataset.hpp @@ -127,9 +127,7 @@ struct owning_dataset : public strided_dataset { * @return maybe owning current-device-accessible strided matrix */ template -auto construct_strided_dataset(const raft::resources& res, - const SrcT& src, - uint32_t required_stride) +auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t required_stride) -> std::unique_ptr> { using extents_type = typename SrcT::extents_type; @@ -187,7 +185,7 @@ auto construct_strided_dataset(const raft::resources& res, /** * @brief Contstruct a strided matrix from any mdarray or mdspan. * - * A variant `construct_strided_dataset` that allows specifying the byte alignment instead of the + * A variant `make_strided_dataset` that allows specifying the byte alignment instead of the * explicit stride length. * * @tparam SrcT the source mdarray or mdspan @@ -198,16 +196,14 @@ auto construct_strided_dataset(const raft::resources& res, * @return maybe owning current-device-accessible strided matrix */ template -auto construct_aligned_dataset(const raft::resources& res, - const SrcT& src, - uint32_t align_bytes = 16) +auto make_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t align_bytes = 16) -> std::unique_ptr> { using value_type = typename SrcT::value_type; constexpr size_t kSize = sizeof(value_type); uint32_t required_stride = raft::round_up_safe(src.extent(1) * kSize, std::lcm(align_bytes, kSize)) / kSize; - return construct_strided_dataset(res, src, required_stride); + return make_strided_dataset(res, src, required_stride); } /** Parameters for VPQ compression. */ diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 4b519eadcb..0832e75633 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -111,8 +111,8 @@ void search_main(raft::resources const& res, CagraSampleFilterT sample_filter = CagraSampleFilterT()) { RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n", - static_cast(index.dataset_view().extent(0)), - static_cast(index.dataset_view().extent(1))); + static_cast(index.dataset().extent(0)), + static_cast(index.dataset().extent(1))); RAFT_LOG_DEBUG("# query size = %lu, dim = %lu\n", static_cast(queries.extent(0)), static_cast(queries.extent(1))); @@ -151,11 +151,11 @@ void search_main(raft::resources const& res, : nullptr; uint32_t* _num_executed_iterations = nullptr; - auto dataset_internal = make_device_strided_matrix_view( - index.dataset_view().data_handle(), - index.dataset_view().extent(0), - index.dataset_view().extent(1), - index.dataset_view().stride(0)); + auto dataset_internal = + make_device_strided_matrix_view(index.dataset().data_handle(), + index.dataset().extent(0), + index.dataset().extent(1), + index.dataset().stride(0)); auto graph_internal = raft::make_device_matrix_view( reinterpret_cast(index.graph().data_handle()), index.graph().extent(0), diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index d15d74f2be..600c8785e0 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -66,12 +66,12 @@ void serialize(raft::resources const& res, serialize_scalar(res, os, index_.metric()); serialize_mdspan(res, os, index_.graph()); - include_dataset &= (index_.dataset().n_rows() > 0); + include_dataset &= (index_.data().n_rows() > 0); serialize_scalar(res, os, include_dataset); if (include_dataset) { RAFT_LOG_INFO("Saving CAGRA index with dataset"); - neighbors::detail::serialize(res, os, index_.dataset()); + neighbors::detail::serialize(res, os, index_.data()); } else { RAFT_LOG_DEBUG("Saving CAGRA index WITHOUT dataset"); } @@ -147,7 +147,7 @@ void serialize_to_hnswlib(raft::resources const& res, std::size_t efConstruction = 500; os.write(reinterpret_cast(&efConstruction), sizeof(std::size_t)); - auto dataset = index_.dataset_view(); + auto dataset = index_.dataset(); // Remove padding before saving the dataset auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), diff --git a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp index dc55d891be..a6a6ae59a5 100644 --- a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp +++ b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp @@ -138,7 +138,7 @@ auto deserialize_strided(raft::resources const& res, std::istream& is) auto stride = deserialize_scalar(res, is); auto host_array = make_host_matrix(n_rows, dim); deserialize_mdspan(res, is, host_array.view()); - return construct_strided_dataset(res, host_array, stride); + return make_strided_dataset(res, host_array, stride); } template