Skip to content

Commit

Permalink
Add CAGRA-Q build (compression) (rapidsai#2213)
Browse files Browse the repository at this point in the history
Add a `cagra::compress` function that implements CAGRA-Q (VQ + PQ) compression of a given dataset.
The result, `compressed_dataset`, is supposed to complement the CAGRA graph during `cagra::search` in place of a raw dataset.

### Current state:

  - The code runs and produces a meaningful output (tested internally by running the original prototype search with the generated compressed dataset); the recall levels are approximately the same as with the prototype implementation.
  - No test coverage yet (need to coordinate with the search PR rapidsai#2206)
  - Full `pq_bits` support ([4,5,6,7,8] - same as in IVF-PQ)
  - Any `pq_dim` values are accepted, but the dataset is not padded and thus `dim` must be a multiple of `pq_dim`.
  - The codebook math type is hardcoded to `half` to match the prototype implementation for now. This could be a runtime (build) parameter as well.
  - All common input data types should work (`uint8_t`, `int8_t`, `half`, and `float` compile), but I tested only `float`.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: rapidsai#2213
  • Loading branch information
achirkin authored Mar 18, 2024
1 parent 69fd971 commit 32f6f40
Show file tree
Hide file tree
Showing 9 changed files with 1,103 additions and 124 deletions.
3 changes: 2 additions & 1 deletion cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ inline dtype_t get_numpy_dtype()
}

#if defined(_RAFT_HAS_CUDA)
template <typename T, typename std::enable_if_t<std::is_same_v<T, half>, bool> = true>
template <typename T,
typename std::enable_if_t<std::is_same_v<std::remove_cv_t<T>, half>, bool> = true>
inline dtype_t get_numpy_dtype()
{
return {RAFT_NUMPY_HOST_ENDIAN_CHAR, 'e', sizeof(T)};
Expand Down
89 changes: 32 additions & 57 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <raft/core/mdspan.hpp>
#include <raft/core/resources.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/dataset.hpp>

#include <rmm/cuda_stream_view.hpp>

Expand Down Expand Up @@ -279,62 +280,6 @@ index<T, IdxT> build(raft::resources const& res,
return detail::build<T, IdxT, Accessor>(res, params, dataset);
}

/**
* @brief Search ANN using the constructed index.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] idx cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
*/
template <typename T, typename IdxT>
void search(raft::resources const& res,
const search_params& params,
const index<T, IdxT>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances)
{
RAFT_EXPECTS(
queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
"Number of rows in output neighbors and distances matrices must equal the number of queries.");

RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
"Number of columns in output neighbors and distances matrices must equal k");
RAFT_EXPECTS(queries.extent(1) == idx.dim(),
"Number of query dimensions should equal number of dimensions in the index.");

using internal_IdxT = typename std::make_unsigned<IdxT>::type;
auto queries_internal = raft::make_device_matrix_view<const T, int64_t, row_major>(
queries.data_handle(), queries.extent(0), queries.extent(1));
auto neighbors_internal = raft::make_device_matrix_view<internal_IdxT, int64_t, row_major>(
reinterpret_cast<internal_IdxT*>(neighbors.data_handle()),
neighbors.extent(0),
neighbors.extent(1));
auto distances_internal = raft::make_device_matrix_view<float, int64_t, row_major>(
distances.data_handle(), distances.extent(0), distances.extent(1));

cagra::detail::search_main<T,
internal_IdxT,
decltype(raft::neighbors::filtering::none_cagra_sample_filter()),
IdxT>(res,
params,
idx,
queries_internal,
neighbors_internal,
distances_internal,
raft::neighbors::filtering::none_cagra_sample_filter());
}

/**
* @brief Search ANN using the constructed index with the given sample filter.
*
Expand Down Expand Up @@ -401,10 +346,40 @@ void search_with_filtering(raft::resources const& res,
auto distances_internal = raft::make_device_matrix_view<float, int64_t, row_major>(
distances.data_handle(), distances.extent(0), distances.extent(1));

cagra::detail::search_main<T, internal_IdxT, CagraSampleFilterT, IdxT>(
return cagra::detail::search_main<T, internal_IdxT, CagraSampleFilterT, IdxT>(
res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter);
}

/**
* @brief Search ANN using the constructed index.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] idx cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
*/
template <typename T, typename IdxT>
void search(raft::resources const& res,
const search_params& params,
const index<T, IdxT>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances)
{
using none_filter_type = raft::neighbors::filtering::none_cagra_sample_filter;
return cagra::search_with_filtering<T, IdxT, none_filter_type>(
res, params, idx, queries, neighbors, distances, none_filter_type{});
}

/** @} */ // end group cagra

} // namespace raft::neighbors::cagra
Expand Down
90 changes: 50 additions & 40 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include "ann_types.hpp"
#include "dataset.hpp"

#include <raft/core/device_mdarray.hpp>
#include <raft/core/error.hpp>
Expand All @@ -35,6 +36,7 @@
#include <optional>
#include <string>
#include <type_traits>

namespace raft::neighbors::cagra {
/**
* @addtogroup cagra
Expand All @@ -61,6 +63,12 @@ struct index_params : ann::index_params {
graph_build_algo build_algo = graph_build_algo::IVF_PQ;
/** Number of Iterations to run if building with NN_DESCENT */
size_t nn_descent_niter = 20;
/**
* Specify compression params if compression is desired.
*
* NOTE: this is experimental new API, consider it unsafe.
*/
std::optional<vpq_params> compression = std::nullopt;
};

enum class search_algo {
Expand Down Expand Up @@ -145,25 +153,37 @@ struct index : ann::index {
/** Total length of the index (number of vectors). */
[[nodiscard]] constexpr inline auto size() const noexcept -> IdxT
{
return dataset_view_.extent(0) ? dataset_view_.extent(0) : graph_view_.extent(0);
auto data_rows = dataset_->n_rows();
return data_rows > 0 ? data_rows : graph_view_.extent(0);
}

/** Dimensionality of the data. */
[[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t
{
return dataset_view_.extent(1);
}
[[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { return dataset_->dim(); }
/** Graph degree */
[[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t
{
return graph_view_.extent(1);
}

/** Dataset [size, dim] */
[[nodiscard]] inline auto dataset() const noexcept
/**
* DEPRECATED: please use data() instead.
* If you need to query dataset dimensions, use the dim() and size() of the cagra index.
* The data_handle() is not always available: you need to do a dynamic_cast to the expected
* dataset type at runtime.
*/
[[nodiscard]] [[deprecated("Use data()")]] inline auto dataset() const noexcept
-> device_matrix_view<const T, int64_t, layout_stride>
{
return dataset_view_;
auto p = dynamic_cast<strided_dataset<T, int64_t>*>(dataset_.get());
if (p != nullptr) { return p->view(); }
auto d = dataset_->dim();
return make_device_strided_matrix_view<const T, int64_t>(nullptr, 0, d, d);
}

/** Dataset [size, dim] */
[[nodiscard]] inline auto data() const noexcept -> const neighbors::dataset<int64_t>&
{
return *dataset_;
}

/** neighborhood graph [size, graph-degree] */
Expand All @@ -185,7 +205,7 @@ struct index : ann::index {
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded)
: ann::index(),
metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
dataset_(new neighbors::empty_dataset<int64_t>(0)),
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0))
{
}
Expand Down Expand Up @@ -251,12 +271,11 @@ struct index : ann::index {
mdspan<const IdxT, matrix_extent<int64_t>, row_major, graph_accessor> knn_graph)
: ann::index(),
metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
dataset_(make_aligned_dataset(res, dataset, 16)),
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0))
{
RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
"Dataset and knn_graph must have equal number of rows");
update_dataset(res, dataset);
update_graph(res, knn_graph);
resource::sync_stream(res);
}
Expand All @@ -271,21 +290,14 @@ struct index : ann::index {
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, row_major> dataset)
{
if (dataset.extent(1) * sizeof(T) % 16 != 0) {
RAFT_LOG_DEBUG("Creating a padded copy of CAGRA dataset in device memory");
copy_padded(res, dataset);
} else {
dataset_view_ = make_device_strided_matrix_view<const T, int64_t>(
dataset.data_handle(), dataset.extent(0), dataset.extent(1), dataset.extent(1));
}
dataset_ = make_aligned_dataset(res, dataset, 16);
}

/** Set the dataset reference explicitly to a device matrix view with padding. */
void update_dataset(raft::resources const&,
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, layout_stride> dataset)
{
RAFT_EXPECTS(dataset.stride(0) * sizeof(T) % 16 == 0, "Incorrect data padding.");
dataset_view_ = dataset;
dataset_ = make_aligned_dataset(res, dataset, 16);
}

/**
Expand All @@ -296,8 +308,22 @@ struct index : ann::index {
void update_dataset(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, row_major> dataset)
{
RAFT_LOG_DEBUG("Copying CAGRA dataset from host to device");
copy_padded(res, dataset);
dataset_ = make_aligned_dataset(res, dataset, 16);
}

/** Replace the dataset with a new dataset. */
template <typename DatasetT>
auto update_dataset(raft::resources const& res, DatasetT&& dataset)
-> std::enable_if_t<std::is_base_of_v<neighbors::dataset<int64_t>, DatasetT>>
{
dataset_ = std::make_unique<DatasetT>(std::move(dataset));
}

template <typename DatasetT>
auto update_dataset(raft::resources const& res, std::unique_ptr<DatasetT>&& dataset)
-> std::enable_if_t<std::is_base_of_v<neighbors::dataset<int64_t>, DatasetT>>
{
dataset_ = std::move(dataset);
}

/**
Expand Down Expand Up @@ -334,26 +360,10 @@ struct index : ann::index {
}

private:
/** Create a device copy of the dataset, and pad it if necessary. */
template <typename data_accessor>
void copy_padded(raft::resources const& res,
mdspan<const T, matrix_extent<int64_t>, row_major, data_accessor> dataset)
{
detail::copy_with_padding(res, dataset_, dataset);

dataset_view_ = make_device_strided_matrix_view<const T, int64_t>(
dataset_.data_handle(), dataset_.extent(0), dataset.extent(1), dataset_.extent(1));
RAFT_LOG_DEBUG("CAGRA dataset strided matrix view %zux%zu, stride %zu",
static_cast<size_t>(dataset_view_.extent(0)),
static_cast<size_t>(dataset_view_.extent(1)),
static_cast<size_t>(dataset_view_.stride(0)));
}

raft::distance::DistanceType metric_;
raft::device_matrix<T, int64_t, row_major> dataset_;
raft::device_matrix<IdxT, int64_t, row_major> graph_;
raft::device_matrix_view<const T, int64_t, layout_stride> dataset_view_;
raft::device_matrix_view<const IdxT, int64_t, row_major> graph_view_;
std::unique_ptr<neighbors::dataset<int64_t>> dataset_;
};

/** @} */
Expand Down
Loading

0 comments on commit 32f6f40

Please sign in to comment.