Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CAGRA-Q build (compression) #2213

Merged
merged 30 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
ac6b088
Add CAGRA-Q build (compression)
achirkin Mar 5, 2024
fdbae63
Merge branch 'branch-24.04' into fea-cagra-q-build
achirkin Mar 5, 2024
aeb0daa
Merge branch 'branch-24.04' into fea-cagra-q-build
achirkin Mar 6, 2024
8b9bee0
Formatting and style refactoring
achirkin Mar 6, 2024
aa70b61
Merge branch 'branch-24.04' into fea-cagra-q-build
achirkin Mar 6, 2024
1a72020
Integrate vpq_dataset into cagra
achirkin Mar 7, 2024
99fa02f
Add dataset compression as an optional step during build
achirkin Mar 7, 2024
e1bd06b
Merge branch 'branch-24.04' into fea-cagra-q-build
achirkin Mar 7, 2024
833b50f
Update cpp/include/raft/neighbors/dataset.hpp
achirkin Mar 8, 2024
53a5c14
Add dataset serialization
achirkin Mar 8, 2024
02f2193
Add comments regarding the internals of pq_bits/pq_width
achirkin Mar 8, 2024
34a7642
Fix incorrect stride assumption that prevented construct_strided_data…
achirkin Mar 8, 2024
3088703
Various small changes to the dataset type to improve safety and be mo…
achirkin Mar 11, 2024
999d343
Merge branch 'branch-24.04' into fea-cagra-q-build
achirkin Mar 11, 2024
4498a22
Add a stub for the search function
achirkin Mar 11, 2024
dd1cc99
Switch to half as the vpq codebook type
achirkin Mar 11, 2024
292406c
Simplify unique_ptr arithmetics
achirkin Mar 11, 2024
24ebae2
Fix deserialization: set the padding bytes to zero in the strided dat…
achirkin Mar 11, 2024
cb11327
Further simplify deserialization code
achirkin Mar 12, 2024
44aabc4
Merge branch 'branch-24.04' into fea-cagra-q-build
achirkin Mar 13, 2024
9a55874
Remove the dynamic dispatch from public search function for it to be …
achirkin Mar 13, 2024
88566d6
Make the construct_strided_dataset only copy the data when it's not a…
achirkin Mar 13, 2024
8a3ae0d
Merge branch 'branch-24.04' into fea-cagra-q-build
achirkin Mar 13, 2024
d1e9e3d
Merge branch 'branch-24.04' into fea-cagra-q-build
achirkin Mar 14, 2024
890b29e
Bump serialization version
achirkin Mar 14, 2024
66ae8ae
Address offline and online review comments
achirkin Mar 14, 2024
82f638d
Merge branch 'branch-24.04' into fea-cagra-q-build
achirkin Mar 15, 2024
dc7d761
Merge branch 'branch-24.04' into fea-cagra-q-build
achirkin Mar 15, 2024
54b99b9
Merge branch 'branch-24.04' into fea-cagra-q-build
achirkin Mar 15, 2024
0f2f63b
Merge branch 'branch-24.04' into fea-cagra-q-build
achirkin Mar 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 26 additions & 38 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include "ann_types.hpp"
#include "dataset.hpp"

#include <raft/core/device_mdarray.hpp>
#include <raft/core/error.hpp>
Expand All @@ -35,6 +36,7 @@
#include <optional>
#include <string>
#include <type_traits>

namespace raft::neighbors::cagra {
/**
* @addtogroup cagra
Expand All @@ -61,6 +63,8 @@ struct index_params : ann::index_params {
graph_build_algo build_algo = graph_build_algo::IVF_PQ;
/** Number of Iterations to run if building with NN_DESCENT */
size_t nn_descent_niter = 20;
/** Specify compression params if compression is desired. */
std::optional<vpq_params> compression = std::nullopt;
};

enum class search_algo {
Expand Down Expand Up @@ -145,14 +149,12 @@ struct index : ann::index {
/** Total length of the index (number of vectors). */
[[nodiscard]] constexpr inline auto size() const noexcept -> IdxT
{
return dataset_view_.extent(0) ? dataset_view_.extent(0) : graph_view_.extent(0);
auto data_rows = dataset_->n_rows();
return data_rows > 0 ? data_rows : graph_view_.extent(0);
}

/** Dimensionality of the data. */
[[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t
{
return dataset_view_.extent(1);
}
[[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { return dataset_->dim(); }
/** Graph degree */
[[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t
{
Expand All @@ -163,7 +165,10 @@ struct index : ann::index {
[[nodiscard]] inline auto dataset() const noexcept
-> device_matrix_view<const T, int64_t, layout_stride>
{
return dataset_view_;
auto p = dynamic_cast<strided_dataset<T, int64_t>*>(dataset_.get());
if (p != nullptr) { return p->view(); }
auto d = dataset_->dim();
return make_device_strided_matrix_view<const T, int64_t>(nullptr, 0, d, d);
achirkin marked this conversation as resolved.
Show resolved Hide resolved
}

/** neighborhood graph [size, graph-degree] */
Expand All @@ -185,7 +190,7 @@ struct index : ann::index {
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded)
: ann::index(),
metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
dataset_(new neighbors::empty_dataset<int64_t>(0)),
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0))
{
}
Expand Down Expand Up @@ -251,12 +256,11 @@ struct index : ann::index {
mdspan<const IdxT, matrix_extent<int64_t>, row_major, graph_accessor> knn_graph)
: ann::index(),
metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
dataset_(upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16))),
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0))
{
RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
"Dataset and knn_graph must have equal number of rows");
update_dataset(res, dataset);
update_graph(res, knn_graph);
resource::sync_stream(res);
}
Expand All @@ -271,21 +275,14 @@ struct index : ann::index {
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, row_major> dataset)
{
if (dataset.extent(1) * sizeof(T) % 16 != 0) {
RAFT_LOG_DEBUG("Creating a padded copy of CAGRA dataset in device memory");
copy_padded(res, dataset);
} else {
dataset_view_ = make_device_strided_matrix_view<const T, int64_t>(
dataset.data_handle(), dataset.extent(0), dataset.extent(1), dataset.extent(1));
}
upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_);
}

/** Set the dataset reference explicitly to a device matrix view with padding. */
void update_dataset(raft::resources const&,
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, layout_stride> dataset)
{
RAFT_EXPECTS(dataset.stride(0) * sizeof(T) % 16 == 0, "Incorrect data padding.");
dataset_view_ = dataset;
upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_);
}

/**
Expand All @@ -296,8 +293,15 @@ struct index : ann::index {
void update_dataset(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, row_major> dataset)
{
RAFT_LOG_DEBUG("Copying CAGRA dataset from host to device");
copy_padded(res, dataset);
upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_);
}

/** Replace the dataset with a new dataset. */
template <typename DatasetT>
auto update_dataset(raft::resources const& res, DatasetT&& dataset)
-> std::enable_if_t<std::is_base_of_v<neighbors::dataset<int64_t>, DatasetT>>
{
upcast_dataset_ptr(std::make_unique<DatasetT>(std::move(dataset))).swap(dataset_);
}

/**
Expand Down Expand Up @@ -334,26 +338,10 @@ struct index : ann::index {
}

private:
/** Create a device copy of the dataset, and pad it if necessary. */
template <typename data_accessor>
void copy_padded(raft::resources const& res,
mdspan<const T, matrix_extent<int64_t>, row_major, data_accessor> dataset)
{
detail::copy_with_padding(res, dataset_, dataset);

dataset_view_ = make_device_strided_matrix_view<const T, int64_t>(
dataset_.data_handle(), dataset_.extent(0), dataset.extent(1), dataset_.extent(1));
RAFT_LOG_DEBUG("CAGRA dataset strided matrix view %zux%zu, stride %zu",
static_cast<size_t>(dataset_view_.extent(0)),
static_cast<size_t>(dataset_view_.extent(1)),
static_cast<size_t>(dataset_view_.stride(0)));
}

raft::distance::DistanceType metric_;
raft::device_matrix<T, int64_t, row_major> dataset_;
raft::device_matrix<IdxT, int64_t, row_major> graph_;
raft::device_matrix_view<const T, int64_t, layout_stride> dataset_view_;
raft::device_matrix_view<const IdxT, int64_t, row_major> graph_view_;
std::unique_ptr<neighbors::dataset<int64_t>> dataset_;
achirkin marked this conversation as resolved.
Show resolved Hide resolved
};

/** @} */
Expand Down
Loading
Loading