From 26b3a4383abe11fd9bdbec4609ee3555520d8e37 Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Thu, 20 Apr 2023 19:02:12 +0900 Subject: [PATCH 1/2] Remove dataset from cagra index --- cpp/include/raft/neighbors/cagra.cuh | 12 ++++---- .../raft/neighbors/cagra_serialize.cuh | 20 ++++++------- cpp/include/raft/neighbors/cagra_types.hpp | 25 ++-------------- .../neighbors/detail/cagra/cagra_search.cuh | 14 +++++---- .../detail/cagra/cagra_serialize.cuh | 29 ++++++++----------- cpp/test/neighbors/ann_cagra.cuh | 6 ++-- 6 files changed, 43 insertions(+), 63 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 90728efd70..fb9e2e79c2 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -158,7 +158,7 @@ template , memory_type::host>> -index build(raft::device_resources const& res, +index build(raft::device_resources const& res, const index_params& params, mdspan, row_major, Accessor> dataset) { @@ -181,7 +181,7 @@ index build(raft::device_resources const& res, prune(res, dataset, knn_graph.view(), cagra_graph.view()); // Construct an index from dataset and pruned knn graph. - return index(res, params.metric, dataset, cagra_graph.view()); + return index(res, params.metric, cagra_graph.view()); } /** @@ -195,6 +195,7 @@ index build(raft::device_resources const& res, * @param[in] res raft resources * @param[in] params configure the search * @param[in] idx cagra index + * @param[in] dataset a device matrix view to a row-major matrix [index->size(), index->dim()] * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset * [n_queries, k] @@ -204,7 +205,8 @@ index build(raft::device_resources const& res, template void search(raft::device_resources const& res, const search_params& params, - const index& idx, + const index& idx, + raft::device_matrix_view dataset, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances) @@ -216,10 +218,10 @@ void search(raft::device_resources const& res, RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Number of columns in output neighbors and distances matrices must equal k"); - RAFT_EXPECTS(queries.extent(1) == idx.dim(), + RAFT_EXPECTS(queries.extent(1) == dataset.extent(1), "Number of query dimensions should equal number of dimensions in the index."); - detail::search_main(res, params, idx, queries, neighbors, distances); + detail::search_main(res, params, idx, dataset, queries, neighbors, distances); } /** @} */ // end group cagra diff --git a/cpp/include/raft/neighbors/cagra_serialize.cuh b/cpp/include/raft/neighbors/cagra_serialize.cuh index befd5e9c07..5eda2fb06b 100644 --- a/cpp/include/raft/neighbors/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/cagra_serialize.cuh @@ -49,8 +49,8 @@ namespace raft::neighbors::experimental::cagra { * @param[in] index CAGRA index * */ -template -void serialize(raft::device_resources const& handle, std::ostream& os, const index& index) +template +void serialize(raft::device_resources const& handle, std::ostream& os, const index& index) { detail::serialize(handle, os, index); } @@ -79,10 +79,10 @@ void serialize(raft::device_resources const& handle, std::ostream& os, const ind * @param[in] index CAGRA index * */ -template +template void serialize(raft::device_resources const& handle, const std::string& filename, - const index& index) + const index& index) { detail::serialize(handle, filename, index); } @@ -112,10 +112,10 @@ void serialize(raft::device_resources const& handle, * * @return raft::neighbors::cagra::index */ -template -index deserialize(raft::device_resources const& handle, std::istream& is) +template +index deserialize(raft::device_resources const& handle, std::istream& is) { - return detail::deserialize(handle, is); + return detail::deserialize(handle, is); } /** @@ -143,10 +143,10 @@ index deserialize(raft::device_resources const& handle, std::istream& i * * @return raft::neighbors::cagra::index */ -template -index deserialize(raft::device_resources const& handle, const std::string& filename) +template +index deserialize(raft::device_resources const& handle, const std::string& filename) { - return detail::deserialize(handle, filename); + return detail::deserialize(handle, filename); } /**@}*/ diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index bd9b3b586b..7b88fc15b5 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -106,11 +106,10 @@ static_assert(std::is_aggregate_v); * * The index stores the dataset and a kNN graph in device memory. * - * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * */ -template +template struct index : ann::index { static_assert(!raft::is_narrowing_v, "IdxT must be able to represent all values of uint32_t"); @@ -123,25 +122,14 @@ struct index : ann::index { } // /** Total length of the index. */ - [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return dataset_.extent(0); } + [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return graph_.extent(0); } - /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t - { - return dataset_.extent(1); - } /** Graph degree */ [[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t { return graph_.extent(1); } - /** Dataset [size, dim] */ - [[nodiscard]] inline auto dataset() const noexcept -> device_matrix_view - { - return dataset_.view(); - } - /** neighborhood graph [size, graph-degree] */ inline auto graph() noexcept -> device_matrix_view { @@ -165,32 +153,25 @@ struct index : ann::index { index(raft::device_resources const& res) : ann::index(), metric_(raft::distance::DistanceType::L2Expanded), - dataset_(make_device_matrix(res, 0, 0)), graph_(make_device_matrix(res, 0, 0)) { } /** Construct an index from dataset and knn_graph arrays */ - template + template index(raft::device_resources const& res, raft::distance::DistanceType metric, - mdspan, row_major, data_accessor> dataset, mdspan, row_major, graph_accessor> knn_graph) : ann::index(), metric_(metric), - dataset_(make_device_matrix(res, dataset.extent(0), dataset.extent(1))), graph_(make_device_matrix(res, knn_graph.extent(0), knn_graph.extent(1))) { - RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), - "Dataset and knn_graph must have equal number of rows"); - raft::copy(dataset_.data_handle(), dataset.data_handle(), dataset.size(), res.get_stream()); raft::copy(graph_.data_handle(), knn_graph.data_handle(), knn_graph.size(), res.get_stream()); res.sync_stream(); } private: raft::distance::DistanceType metric_; - raft::device_matrix dataset_; raft::device_matrix graph_; }; diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 79cbb6198f..842859c0f8 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -40,6 +40,7 @@ namespace raft::neighbors::experimental::cagra::detail { * @param[in] handle * @param[in] params configure the search * @param[in] idx ivf-pq constructed index + * @param[in] dataset a device matrix view to a row-major matrix [index->size(), index->dim()] * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset * [n_queries, k] @@ -50,22 +51,23 @@ namespace raft::neighbors::experimental::cagra::detail { template void search_main(raft::device_resources const& res, search_params params, - const index& index, + const index& index, + raft::device_matrix_view dataset, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances) { RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n", - static_cast(index.dataset().extent(0)), - static_cast(index.dataset().extent(1))); + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); RAFT_LOG_DEBUG("# query size = %lu, dim = %lu\n", static_cast(queries.extent(0)), static_cast(queries.extent(1))); - RAFT_EXPECTS(queries.extent(1) == index.dim(), "Querise and index dim must match"); + RAFT_EXPECTS(queries.extent(1) == dataset.extent(1), "Querise and dataset dim must match"); uint32_t topk = neighbors.extent(1); std::unique_ptr> plan = - factory::create(res, params, index.dim(), index.graph_degree(), topk); + factory::create(res, params, dataset.extent(1), index.graph_degree(), topk); plan->check(neighbors.extent(1)); @@ -84,7 +86,7 @@ void search_main(raft::device_resources const& res, uint32_t* _num_executed_iterations = nullptr; (*plan)(res, - index.dataset(), + dataset, index.graph(), _topk_indices_ptr, _topk_distances_ptr, diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 171f261cf3..b753f24d14 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -36,7 +36,7 @@ struct check_index_layout { "paste in the new size and consider updating the serialization logic"); }; -template struct check_index_layout), 136>; +template struct check_index_layout), 72>; /** * Save the index to file. @@ -48,25 +48,23 @@ template struct check_index_layout), 136>; * @param[in] index_ CAGRA index * */ -template -void serialize(raft::device_resources const& res, std::ostream& os, const index& index_) +template +void serialize(raft::device_resources const& res, std::ostream& os, const index& index_) { RAFT_LOG_DEBUG( - "Saving CAGRA index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); + "Saving CAGRA index, size %zu", static_cast(index_.size())); serialize_scalar(res, os, serialization_version); serialize_scalar(res, os, index_.size()); - serialize_scalar(res, os, index_.dim()); serialize_scalar(res, os, index_.graph_degree()); serialize_scalar(res, os, index_.metric()); - serialize_mdspan(res, os, index_.dataset()); serialize_mdspan(res, os, index_.graph()); } -template +template void serialize(raft::device_resources const& res, const std::string& filename, - const index& index_) + const index& index_) { std::ofstream of(filename, std::ios::out | std::ios::binary); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } @@ -86,35 +84,32 @@ void serialize(raft::device_resources const& res, * @param[in] index_ CAGRA index * */ -template -auto deserialize(raft::device_resources const& res, std::istream& is) -> index +template +auto deserialize(raft::device_resources const& res, std::istream& is) -> index { auto ver = deserialize_scalar(res, is); if (ver != serialization_version) { RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); } auto n_rows = deserialize_scalar(res, is); - auto dim = deserialize_scalar(res, is); auto graph_degree = deserialize_scalar(res, is); auto metric = deserialize_scalar(res, is); - auto dataset = raft::make_host_matrix(n_rows, dim); auto graph = raft::make_host_matrix(n_rows, graph_degree); - deserialize_mdspan(res, is, dataset.view()); deserialize_mdspan(res, is, graph.view()); - return index(res, metric, raft::make_const_mdspan(dataset.view()), graph.view()); + return index(res, metric, graph.view()); } -template -auto deserialize(raft::device_resources const& res, const std::string& filename) -> index +template +auto deserialize(raft::device_resources const& res, const std::string& filename) -> index { std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - auto index = detail::deserialize(res, is); + auto index = detail::deserialize(res, is); is.close(); diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 385e9a80c0..80add6b905 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -123,7 +123,7 @@ class AnnCagraTest : public ::testing::TestWithParam { (const DataT*)database.data(), ps.n_rows, ps.dim); { - cagra::index index(handle_); + cagra::index index(handle_); if (ps.host_dataset) { auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); @@ -135,7 +135,7 @@ class AnnCagraTest : public ::testing::TestWithParam { }; cagra::serialize(handle_, "cagra_index", index); } - auto index = cagra::deserialize(handle_, "cagra_index"); + auto index = cagra::deserialize(handle_, "cagra_index"); auto search_queries_view = raft::make_device_matrix_view( search_queries.data(), ps.n_queries, ps.dim); @@ -145,7 +145,7 @@ class AnnCagraTest : public ::testing::TestWithParam { raft::make_device_matrix_view(distances_dev.data(), ps.n_queries, ps.k); cagra::search( - handle_, search_params, index, search_queries_view, indices_out_view, dists_out_view); + handle_, search_params, index, database_view, search_queries_view, indices_out_view, dists_out_view); update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_); update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); From 48e109b990002b0550ed160528508dfc8d54a753 Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Fri, 21 Apr 2023 10:47:45 +0900 Subject: [PATCH 2/2] Fix style --- cpp/include/raft/neighbors/cagra.cuh | 4 ++-- .../raft/neighbors/detail/cagra/cagra_serialize.cuh | 5 ++--- cpp/test/neighbors/ann_cagra.cuh | 9 +++++++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index fb9e2e79c2..3e6fb8c8c4 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -159,8 +159,8 @@ template , memory_type::host>> index build(raft::device_resources const& res, - const index_params& params, - mdspan, row_major, Accessor> dataset) + const index_params& params, + mdspan, row_major, Accessor> dataset) { size_t degree = params.intermediate_graph_degree; if (degree >= dataset.extent(0)) { diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index b753f24d14..ada426dff3 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -51,8 +51,7 @@ template struct check_index_layout), 72>; template void serialize(raft::device_resources const& res, std::ostream& os, const index& index_) { - RAFT_LOG_DEBUG( - "Saving CAGRA index, size %zu", static_cast(index_.size())); + RAFT_LOG_DEBUG("Saving CAGRA index, size %zu", static_cast(index_.size())); serialize_scalar(res, os, serialization_version); serialize_scalar(res, os, index_.size()); @@ -95,7 +94,7 @@ auto deserialize(raft::device_resources const& res, std::istream& is) -> index(res, is); auto metric = deserialize_scalar(res, is); - auto graph = raft::make_host_matrix(n_rows, graph_degree); + auto graph = raft::make_host_matrix(n_rows, graph_degree); deserialize_mdspan(res, is, graph.view()); diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 80add6b905..8dfde2c95c 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -144,8 +144,13 @@ class AnnCagraTest : public ::testing::TestWithParam { auto dists_out_view = raft::make_device_matrix_view(distances_dev.data(), ps.n_queries, ps.k); - cagra::search( - handle_, search_params, index, database_view, search_queries_view, indices_out_view, dists_out_view); + cagra::search(handle_, + search_params, + index, + database_view, + search_queries_view, + indices_out_view, + dists_out_view); update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_); update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_);