Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove dataset from CAGRA index #1435

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ template <typename T,
typename IdxT = uint32_t,
typename Accessor =
host_device_accessor<std::experimental::default_accessor<T>, memory_type::host>>
index<T, IdxT> build(raft::device_resources const& res,
const index_params& params,
mdspan<const T, matrix_extent<IdxT>, row_major, Accessor> dataset)
index<IdxT> build(raft::device_resources const& res,
const index_params& params,
mdspan<const T, matrix_extent<IdxT>, row_major, Accessor> dataset)
{
size_t degree = params.intermediate_graph_degree;
if (degree >= dataset.extent(0)) {
Expand All @@ -181,7 +181,7 @@ index<T, IdxT> build(raft::device_resources const& res,
prune<T, IdxT>(res, dataset, knn_graph.view(), cagra_graph.view());

// Construct an index from dataset and pruned knn graph.
return index<T, IdxT>(res, params.metric, dataset, cagra_graph.view());
return index<IdxT>(res, params.metric, cagra_graph.view());
}

/**
Expand All @@ -195,6 +195,7 @@ index<T, IdxT> build(raft::device_resources const& res,
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] idx cagra index
* @param[in] dataset a device matrix view to a row-major matrix [index->size(), index->dim()]
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
Expand All @@ -204,7 +205,8 @@ index<T, IdxT> build(raft::device_resources const& res,
template <typename T, typename IdxT>
void search(raft::device_resources const& res,
const search_params& params,
const index<T, IdxT>& idx,
const index<IdxT>& idx,
raft::device_matrix_view<const T, IdxT, row_major> dataset,
raft::device_matrix_view<const T, IdxT, row_major> queries,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<float, IdxT, row_major> distances)
Expand All @@ -216,10 +218,10 @@ void search(raft::device_resources const& res,
RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
"Number of columns in output neighbors and distances matrices must equal k");

RAFT_EXPECTS(queries.extent(1) == idx.dim(),
RAFT_EXPECTS(queries.extent(1) == dataset.extent(1),
"Number of query dimensions should equal number of dimensions in the index.");

detail::search_main(res, params, idx, queries, neighbors, distances);
detail::search_main(res, params, idx, dataset, queries, neighbors, distances);
}
/** @} */ // end group cagra

Expand Down
20 changes: 10 additions & 10 deletions cpp/include/raft/neighbors/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ namespace raft::neighbors::experimental::cagra {
* @param[in] index CAGRA index
*
*/
template <typename T, typename IdxT>
void serialize(raft::device_resources const& handle, std::ostream& os, const index<T, IdxT>& index)
template <typename IdxT>
void serialize(raft::device_resources const& handle, std::ostream& os, const index<IdxT>& index)
{
detail::serialize(handle, os, index);
}
Expand Down Expand Up @@ -79,10 +79,10 @@ void serialize(raft::device_resources const& handle, std::ostream& os, const ind
* @param[in] index CAGRA index
*
*/
template <typename T, typename IdxT>
template <typename IdxT>
void serialize(raft::device_resources const& handle,
const std::string& filename,
const index<T, IdxT>& index)
const index<IdxT>& index)
{
detail::serialize(handle, filename, index);
}
Expand Down Expand Up @@ -112,10 +112,10 @@ void serialize(raft::device_resources const& handle,
*
* @return raft::neighbors::cagra::index<T, IdxT>
*/
template <typename T, typename IdxT>
index<T, IdxT> deserialize(raft::device_resources const& handle, std::istream& is)
template <typename IdxT>
index<IdxT> deserialize(raft::device_resources const& handle, std::istream& is)
{
return detail::deserialize<T, IdxT>(handle, is);
return detail::deserialize<IdxT>(handle, is);
}

/**
Expand Down Expand Up @@ -143,10 +143,10 @@ index<T, IdxT> deserialize(raft::device_resources const& handle, std::istream& i
*
* @return raft::neighbors::cagra::index<T, IdxT>
*/
template <typename T, typename IdxT>
index<T, IdxT> deserialize(raft::device_resources const& handle, const std::string& filename)
template <typename IdxT>
index<IdxT> deserialize(raft::device_resources const& handle, const std::string& filename)
{
return detail::deserialize<T, IdxT>(handle, filename);
return detail::deserialize<IdxT>(handle, filename);
}

/**@}*/
Expand Down
25 changes: 3 additions & 22 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,10 @@ static_assert(std::is_aggregate_v<search_params>);
*
* The index stores the dataset and a kNN graph in device memory.
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
*/
template <typename T, typename IdxT>
template <typename IdxT>
struct index : ann::index {
static_assert(!raft::is_narrowing_v<uint32_t, IdxT>,
"IdxT must be able to represent all values of uint32_t");
Expand All @@ -123,25 +122,14 @@ struct index : ann::index {
}

// /** Total length of the index. */
[[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return dataset_.extent(0); }
[[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return graph_.extent(0); }

/** Dimensionality of the data. */
[[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t
{
return dataset_.extent(1);
}
/** Graph degree */
[[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t
{
return graph_.extent(1);
}

/** Dataset [size, dim] */
[[nodiscard]] inline auto dataset() const noexcept -> device_matrix_view<const T, IdxT, row_major>
{
return dataset_.view();
}

/** neighborhood graph [size, graph-degree] */
inline auto graph() noexcept -> device_matrix_view<IdxT, IdxT, row_major>
{
Expand All @@ -165,32 +153,25 @@ struct index : ann::index {
index(raft::device_resources const& res)
: ann::index(),
metric_(raft::distance::DistanceType::L2Expanded),
dataset_(make_device_matrix<T, IdxT>(res, 0, 0)),
graph_(make_device_matrix<IdxT, IdxT>(res, 0, 0))
{
}

/** Construct an index from dataset and knn_graph arrays */
template <typename data_accessor, typename graph_accessor>
template <typename graph_accessor>
index(raft::device_resources const& res,
raft::distance::DistanceType metric,
mdspan<const T, matrix_extent<IdxT>, row_major, data_accessor> dataset,
mdspan<IdxT, matrix_extent<IdxT>, row_major, graph_accessor> knn_graph)
: ann::index(),
metric_(metric),
dataset_(make_device_matrix<T, IdxT>(res, dataset.extent(0), dataset.extent(1))),
graph_(make_device_matrix<IdxT, IdxT>(res, knn_graph.extent(0), knn_graph.extent(1)))
{
RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
"Dataset and knn_graph must have equal number of rows");
raft::copy(dataset_.data_handle(), dataset.data_handle(), dataset.size(), res.get_stream());
raft::copy(graph_.data_handle(), knn_graph.data_handle(), knn_graph.size(), res.get_stream());
res.sync_stream();
}

private:
raft::distance::DistanceType metric_;
raft::device_matrix<T, IdxT, row_major> dataset_;
raft::device_matrix<IdxT, IdxT, row_major> graph_;
};

Expand Down
14 changes: 8 additions & 6 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ namespace raft::neighbors::experimental::cagra::detail {
* @param[in] handle
* @param[in] params configure the search
* @param[in] idx ivf-pq constructed index
* @param[in] dataset a device matrix view to a row-major matrix [index->size(), index->dim()]
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
Expand All @@ -50,22 +51,23 @@ namespace raft::neighbors::experimental::cagra::detail {
template <typename T, typename IdxT = uint32_t, typename DistanceT = float>
void search_main(raft::device_resources const& res,
search_params params,
const index<T, IdxT>& index,
const index<IdxT>& index,
raft::device_matrix_view<const T, IdxT, row_major> dataset,
raft::device_matrix_view<const T, IdxT, row_major> queries,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<DistanceT, IdxT, row_major> distances)
{
RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
static_cast<size_t>(index.dataset().extent(0)),
static_cast<size_t>(index.dataset().extent(1)));
static_cast<size_t>(dataset.extent(0)),
static_cast<size_t>(dataset.extent(1)));
RAFT_LOG_DEBUG("# query size = %lu, dim = %lu\n",
static_cast<size_t>(queries.extent(0)),
static_cast<size_t>(queries.extent(1)));
RAFT_EXPECTS(queries.extent(1) == index.dim(), "Querise and index dim must match");
RAFT_EXPECTS(queries.extent(1) == dataset.extent(1), "Querise and dataset dim must match");
uint32_t topk = neighbors.extent(1);

std::unique_ptr<search_plan_impl<T, IdxT, DistanceT>> plan =
factory<T, IdxT, DistanceT>::create(res, params, index.dim(), index.graph_degree(), topk);
factory<T, IdxT, DistanceT>::create(res, params, dataset.extent(1), index.graph_degree(), topk);

plan->check(neighbors.extent(1));

Expand All @@ -84,7 +86,7 @@ void search_main(raft::device_resources const& res,
uint32_t* _num_executed_iterations = nullptr;

(*plan)(res,
index.dataset(),
dataset,
index.graph(),
_topk_indices_ptr,
_topk_distances_ptr,
Expand Down
32 changes: 13 additions & 19 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct check_index_layout {
"paste in the new size and consider updating the serialization logic");
};

template struct check_index_layout<sizeof(index<double, std::uint64_t>), 136>;
template struct check_index_layout<sizeof(index<std::uint64_t>), 72>;

/**
* Save the index to file.
Expand All @@ -48,25 +48,22 @@ template struct check_index_layout<sizeof(index<double, std::uint64_t>), 136>;
* @param[in] index_ CAGRA index
*
*/
template <typename T, typename IdxT>
void serialize(raft::device_resources const& res, std::ostream& os, const index<T, IdxT>& index_)
template <typename IdxT>
void serialize(raft::device_resources const& res, std::ostream& os, const index<IdxT>& index_)
{
RAFT_LOG_DEBUG(
"Saving CAGRA index, size %zu, dim %u", static_cast<size_t>(index_.size()), index_.dim());
RAFT_LOG_DEBUG("Saving CAGRA index, size %zu", static_cast<size_t>(index_.size()));

serialize_scalar(res, os, serialization_version);
serialize_scalar(res, os, index_.size());
serialize_scalar(res, os, index_.dim());
serialize_scalar(res, os, index_.graph_degree());
serialize_scalar(res, os, index_.metric());
serialize_mdspan(res, os, index_.dataset());
serialize_mdspan(res, os, index_.graph());
}

template <typename T, typename IdxT>
template <typename IdxT>
void serialize(raft::device_resources const& res,
const std::string& filename,
const index<T, IdxT>& index_)
const index<IdxT>& index_)
{
std::ofstream of(filename, std::ios::out | std::ios::binary);
if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }
Expand All @@ -86,35 +83,32 @@ void serialize(raft::device_resources const& res,
* @param[in] index_ CAGRA index
*
*/
template <typename T, typename IdxT>
auto deserialize(raft::device_resources const& res, std::istream& is) -> index<T, IdxT>
template <typename IdxT>
auto deserialize(raft::device_resources const& res, std::istream& is) -> index<IdxT>
{
auto ver = deserialize_scalar<int>(res, is);
if (ver != serialization_version) {
RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver);
}
auto n_rows = deserialize_scalar<IdxT>(res, is);
auto dim = deserialize_scalar<std::uint32_t>(res, is);
auto graph_degree = deserialize_scalar<std::uint32_t>(res, is);
auto metric = deserialize_scalar<raft::distance::DistanceType>(res, is);

auto dataset = raft::make_host_matrix<T, IdxT>(n_rows, dim);
auto graph = raft::make_host_matrix<IdxT, IdxT>(n_rows, graph_degree);
auto graph = raft::make_host_matrix<IdxT, IdxT>(n_rows, graph_degree);

deserialize_mdspan(res, is, dataset.view());
deserialize_mdspan(res, is, graph.view());

return index<T, IdxT>(res, metric, raft::make_const_mdspan(dataset.view()), graph.view());
return index<IdxT>(res, metric, graph.view());
}

template <typename T, typename IdxT>
auto deserialize(raft::device_resources const& res, const std::string& filename) -> index<T, IdxT>
template <typename IdxT>
auto deserialize(raft::device_resources const& res, const std::string& filename) -> index<IdxT>
{
std::ifstream is(filename, std::ios::in | std::ios::binary);

if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }

auto index = detail::deserialize<T, IdxT>(res, is);
auto index = detail::deserialize<IdxT>(res, is);

is.close();

Expand Down
13 changes: 9 additions & 4 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
(const DataT*)database.data(), ps.n_rows, ps.dim);

{
cagra::index<DataT, IdxT> index(handle_);
cagra::index<IdxT> index(handle_);
if (ps.host_dataset) {
auto database_host = raft::make_host_matrix<DataT, IdxT>(ps.n_rows, ps.dim);
raft::copy(database_host.data_handle(), database.data(), database.size(), stream_);
Expand All @@ -135,7 +135,7 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
};
cagra::serialize(handle_, "cagra_index", index);
}
auto index = cagra::deserialize<DataT, IdxT>(handle_, "cagra_index");
auto index = cagra::deserialize<IdxT>(handle_, "cagra_index");

auto search_queries_view = raft::make_device_matrix_view<const DataT, IdxT>(
search_queries.data(), ps.n_queries, ps.dim);
Expand All @@ -144,8 +144,13 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
auto dists_out_view =
raft::make_device_matrix_view<DistanceT, IdxT>(distances_dev.data(), ps.n_queries, ps.k);

cagra::search(
handle_, search_params, index, search_queries_view, indices_out_view, dists_out_view);
cagra::search(handle_,
search_params,
index,
database_view,
search_queries_view,
indices_out_view,
dists_out_view);

update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_);
update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_);
Expand Down