From 589b8aec9679556fc176ab27f8c5a6277d435e98 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 30 Nov 2023 11:09:49 -0500 Subject: [PATCH 01/17] Update brute-force API to accept params struct --- .../raft/neighbors/brute_force-ext.cuh | 57 +++++++++++++++++++ .../raft/neighbors/brute_force-inl.cuh | 46 +++++++++++++++ .../raft/neighbors/brute_force_types.hpp | 3 + .../neighbors/brute_force_knn_index_float.cu | 21 +++++++ 4 files changed, 127 insertions(+) diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh index 4c1f7ea21e..157c7f291a 100644 --- a/cpp/include/raft/neighbors/brute_force-ext.cuh +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -116,6 +116,14 @@ extern template void search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); +extern template void search( + raft::resources const& res, + search_params const& params, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + extern template void search( raft::resources const& res, const raft::neighbors::brute_force::index& idx, @@ -123,11 +131,60 @@ extern template void search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); +extern template void search( + raft::resources const& res, + search_params const& params, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + extern template raft::neighbors::brute_force::index build( raft::resources const& res, raft::device_matrix_view dataset, raft::distance::DistanceType metric, float metric_arg); + +extern template raft::neighbors::brute_force::index build( + raft::resources const& res, + index_params const& params, + raft::device_matrix_view dataset) +} // namespace raft::neighbors::brute_force + +#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ + value_t, idx_t, idx_layout, query_layout) \ + extern template void raft::neighbors::brute_force::fused_l2_knn( \ + raft::resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view query, \ + raft::device_matrix_view out_inds, \ + raft::device_matrix_view out_dists, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_brute_force_fused_l2_knn(float, + int64_t, + raft::row_major, + raft::row_major) + +#undef instantiate_raft_neighbors_brute_force_fused_l2_knn + + extern template void search( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + +extern template raft::neighbors::brute_force::index build( + raft::resources const& res, + raft::device_matrix_view dataset, + raft::distance::DistanceType metric, + float metric_arg); + +extern template raft::neighbors::brute_force::index build( + raft::resources const& res, + index_params const& params, + raft::device_matrix_view dataset) } // namespace raft::neighbors::brute_force #define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh index 6b86f2463f..2a0a503864 100644 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -327,6 +327,49 @@ index build(raft::resources const& res, return index(res, dataset, std::move(norms), metric, metric_arg); } +/** + * @brief Build the index from the dataset for efficient search. + * + * @tparam T data element type + * + * @param[in] res + * @param[in] params configure the index building + * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] + * + * @return the constructed brute force index + */ +template +index build(raft::resources const& res, + index_params const& params, + mdspan, row_major, Accessor> dataset) +{ + return build(res, dataset, params.metric, float(params.metric_arg)); +} + +/** + * @brief Brute Force search using the constructed index. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] idx brute force index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search(raft::resources const& res, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + raft::neighbors::detail::brute_force_search(res, idx, queries, neighbors, distances); +} + /** * @brief Brute Force search using the constructed index. * @@ -334,6 +377,7 @@ index build(raft::resources const& res, * @tparam IdxT type of the indices * * @param[in] res raft resources + * @param[in] params configure the search * @param[in] idx brute force index * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset @@ -343,6 +387,7 @@ index build(raft::resources const& res, */ template void search(raft::resources const& res, + search_params const& params, const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, @@ -350,5 +395,6 @@ void search(raft::resources const& res, { raft::neighbors::detail::brute_force_search(res, idx, queries, neighbors, distances); } + /** @} */ // end group brute_force_knn } // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force_types.hpp b/cpp/include/raft/neighbors/brute_force_types.hpp index 039599845e..9c8327fb4a 100644 --- a/cpp/include/raft/neighbors/brute_force_types.hpp +++ b/cpp/include/raft/neighbors/brute_force_types.hpp @@ -35,6 +35,9 @@ namespace raft::neighbors::brute_force { * @{ */ +using ann::index_params; +using ann::search_params; + /** * @brief Brute Force index. * diff --git a/cpp/src/neighbors/brute_force_knn_index_float.cu b/cpp/src/neighbors/brute_force_knn_index_float.cu index d4f902c087..899e029df7 100644 --- a/cpp/src/neighbors/brute_force_knn_index_float.cu +++ b/cpp/src/neighbors/brute_force_knn_index_float.cu @@ -25,6 +25,14 @@ template void raft::neighbors::brute_force::search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); +template void raft::neighbors::brute_force::search( + raft::resources const& res, + raft::neighbors::brute_force::search_params const& params, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + template void raft::neighbors::brute_force::search( raft::resources const& res, const raft::neighbors::brute_force::index& idx, @@ -32,8 +40,21 @@ template void raft::neighbors::brute_force::search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); +template void raft::neighbors::brute_force::search( + raft::resources const& res, + raft::neighbors::brute_force::search_params const& params, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + template raft::neighbors::brute_force::index raft::neighbors::brute_force::build( raft::resources const& res, raft::device_matrix_view dataset, raft::distance::DistanceType metric, float metric_arg); + +template raft::neighbors::brute_force::index raft::neighbors::brute_force::build( + raft::resources const& res, + raft::neighbors::brute_force::index_params const& params, + raft::device_matrix_view dataset); From 4050682d51a3043e90bac5c7d60dc36309180b10 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Fri, 1 Dec 2023 12:10:11 -0500 Subject: [PATCH 02/17] Update brute force extern template declarations --- .../raft/neighbors/brute_force-ext.cuh | 59 +++++-------------- 1 file changed, 14 insertions(+), 45 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh index 157c7f291a..b8571e7c0e 100644 --- a/cpp/include/raft/neighbors/brute_force-ext.cuh +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -45,8 +45,21 @@ index build(raft::resources const& res, raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded, T metric_arg = 0.0) RAFT_EXPLICIT; +template +index build(raft::resources const& res, + index_params const& params, + mdspan, row_major, Accessor> dataset) RAFT_EXPLICIT; + +template +void search(raft::resources const& res, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) RAFT_EXPLICIT; + template void search(raft::resources const& res, + search_params const& params, const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, @@ -131,50 +144,6 @@ extern template void search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); -extern template void search( - raft::resources const& res, - search_params const& params, - const raft::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances); - -extern template raft::neighbors::brute_force::index build( - raft::resources const& res, - raft::device_matrix_view dataset, - raft::distance::DistanceType metric, - float metric_arg); - -extern template raft::neighbors::brute_force::index build( - raft::resources const& res, - index_params const& params, - raft::device_matrix_view dataset) -} // namespace raft::neighbors::brute_force - -#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ - value_t, idx_t, idx_layout, query_layout) \ - extern template void raft::neighbors::brute_force::fused_l2_knn( \ - raft::resources const& handle, \ - raft::device_matrix_view index, \ - raft::device_matrix_view query, \ - raft::device_matrix_view out_inds, \ - raft::device_matrix_view out_dists, \ - raft::distance::DistanceType metric); - -instantiate_raft_neighbors_brute_force_fused_l2_knn(float, - int64_t, - raft::row_major, - raft::row_major) - -#undef instantiate_raft_neighbors_brute_force_fused_l2_knn - - extern template void search( - raft::resources const& res, - const raft::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances); - extern template raft::neighbors::brute_force::index build( raft::resources const& res, raft::device_matrix_view dataset, @@ -184,7 +153,7 @@ extern template raft::neighbors::brute_force::index build( extern template raft::neighbors::brute_force::index build( raft::resources const& res, index_params const& params, - raft::device_matrix_view dataset) + raft::device_matrix_view dataset); } // namespace raft::neighbors::brute_force #define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ From 3844863d3f4a3b0929d424b34d754f1f130e17f8 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Fri, 1 Dec 2023 23:39:10 -0500 Subject: [PATCH 03/17] Begin adding brute-force index serialization --- .../raft/neighbors/brute_force_serialize.cuh | 198 ++++++++++++++++++ 1 file changed, 198 insertions(+) create mode 100644 cpp/include/raft/neighbors/brute_force_serialize.cuh diff --git a/cpp/include/raft/neighbors/brute_force_serialize.cuh b/cpp/include/raft/neighbors/brute_force_serialize.cuh new file mode 100644 index 0000000000..0642385cb4 --- /dev/null +++ b/cpp/include/raft/neighbors/brute_force_serialize.cuh @@ -0,0 +1,198 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::neighbors::brute_force { + +/** + * \defgroup brute_force_serialize Brute Force Serialize + * @{ + */ + +/** + * Write the index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = brute_force::build(...);` + * raft::serialize(handle, os, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index brute force index + * + */ +template +void serialize(raft::resources const& handle, std::ostream& os, const index& index) +{ + RAFT_LOG_DEBUG( + "Saving brute force index, size %zu, dim %u", static_cast(index.size()), index.dim()); + + auto dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); + dtype_string.resize(4); + os << dtype_string; + + serialize_scalar(handle, os, serialization_version); + serialize_scalar(handle, os, index.size()); + serialize_scalar(handle, os, index.dim()); + serialize_scalar(handle, os, index.metric()); + serialize_scalar(handle, os, index.metric_arg()); + serialize_mdspan(handle, os, index.dataset()); + auto has_norms = index.has_norms(); + serialize_scalar(handle, os, has_norms); + if (has_norms) { serialize_mdspan(handle, os, index.norms()); } + resource::sync_stream(handle); +} + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = brute_force::build(...);` + * raft::serialize(handle, filename, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index brute force index + * + */ +template +void serialize(raft::resources const& handle, + const std::string& filename, + const index& index) +{ + detail::serialize(handle, filename, index); +} + +/** + * Load index from input stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * using T = float; // data element type + * using IdxT = int; // type of the index + * auto index = raft::deserialize(handle, is); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] is input stream + * + * @return raft::neighbors::brute_force::index + */ +template +auto deserialize(raft::resources const& handle, std::istream& is) +{ + char dtype_string[4]; + is.read(dtype_string, 4); + + auto ver = deserialize_scalar(handle, is); + if (ver != serialization_version) { + RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); + } + auto rows = deserialize_scalar(handle, is); + auto dim = deserialize_scalar(handle, is); + auto metric = deserialize_scalar(handle, is); + auto metric_arg = deserialize_scalar(handle, is); + + auto dataset_storage = raft::make_host_matrix(rows, dim); + deserialize_mdspan(handle, is, dataset_storage.view()); + + auto has_norms = deserialize_scalar(handle, is); + auto norms_storage = has_norms ? std::optional{raft::make_host_vector(rows)} + : std::optional>{}; + if (has_norms) { deserialize_mdspan(handle, is, norms_storage->view()); } + auto result = + build(handle, + dataset_storage.view(), + norms_storage ? norms_storage->view() : std::optional>{}, + metric, + metric_arg); + resource::sync_stream(handle); + + return result; +} + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using T = float; // data element type + * using IdxT = int; // type of the index + * auto index = raft::deserialize(handle, filename); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * + * @return raft::neighbors::brute_force::index + */ +template +auto deserialize(raft::resources const& handle, const std::string& filename) +{ + std::ifstream is(filename, std::ios::in | std::ios::binary); + RAFT_EXPECTS(is, "Cannot open file %s", filename.c_str()); + + return deserialize(handle, is); +} + +/**@}*/ + +} // namespace raft::neighbors::brute_force From 606f5e3bcdb56334de7df9cb215799572856840a Mon Sep 17 00:00:00 2001 From: William Hicks Date: Sat, 2 Dec 2023 19:03:46 -0500 Subject: [PATCH 04/17] Add test header for brute force index --- cpp/test/neighbors/ann_brute_force.cuh | 252 +++++++++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 cpp/test/neighbors/ann_brute_force.cuh diff --git a/cpp/test/neighbors/ann_brute_force.cuh b/cpp/test/neighbors/ann_brute_force.cuh new file mode 100644 index 0000000000..cbdc05fc83 --- /dev/null +++ b/cpp/test/neighbors/ann_brute_force.cuh @@ -0,0 +1,252 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../test_utils.cuh" +#include "ann_utils.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include +#include + +#include +#include +#include + +namespace raft::neighbors::brute_force { + +template +struct AnnBruteForceInputs { + IdxT num_queries; + IdxT num_db_vecs; + IdxT dim; + IdxT k; + raft::distance::DistanceType metric; + bool host_dataset; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const AnnBruteForceInputs& p) +{ + os << "{ " << p.num_queries << ", " << p.num_db_vecs << ", " << p.dim << ", " << p.k << ", " + << static_cast(p.metric) << ", " << p.host_dataset << '}' << std::endl; + return os; +} + +template +class AnnBruteForceTest : public ::testing::TestWithParam> { + public: + AnnBruteForceTest() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam>::GetParam()), + database(0, stream_), + search_queries(0, stream_) + { + } + + void testBruteForce() + { + size_t queries_size = ps.num_queries * ps.k; + std::vector indices_bruteforce(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_bruteforce(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database.data(), + ps.num_queries, + ps.num_db_vecs, + ps.dim, + ps.k, + ps.metric); + update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + { + // Require exact result for brute force + auto min_recall = double{1}; + + rmm::device_uvector distances_bruteforce_dev(queries_size, stream_); + rmm::device_uvector indices_bruteforce_dev(queries_size, stream_); + { + brute_force::index_params index_params; + brute_force::search_params search_params; + index_params.metric = ps.metric; + index_params.metric_arg = 0; + + brute_force::index idx(handle_, index_params, ps.dim); + brute_force::index index_2(handle_, index_params, ps.dim); + + if (!ps.host_dataset) { + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + idx = brute_force::build(handle_, index_params, database_view); + rmm::device_uvector vector_indices(ps.num_db_vecs, stream_); + thrust::sequence(resource::get_thrust_policy(handle_), + thrust::device_pointer_cast(vector_indices.data()), + thrust::device_pointer_cast(vector_indices.data() + ps.num_db_vecs)); + resource::sync_stream(handle_); + + } else { + auto host_database = raft::make_host_matrix(ps.num_db_vecs, ps.dim); + raft::copy( + host_database.data_handle(), database.data(), ps.num_db_vecs * ps.dim, stream_); + idx = brute_force::build( + handle_, index_params, raft::make_const_mdspan(host_database.view())); + + auto vector_indices = raft::make_host_vector(handle_, ps.num_db_vecs); + std::iota(vector_indices.data_handle(), vector_indices.data_handle() + ps.num_db_vecs, 0); + } + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.num_queries, ps.dim); + auto indices_out_view = raft::make_device_matrix_view( + indices_bruteforce_dev.data(), ps.num_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + distances_bruteforce_dev.data(), ps.num_queries, ps.k); + brute_force::serialize(handle_, "brute_force_index", index_2); + + auto index_loaded = brute_force::deserialize(handle_, "brute_force_index"); + ASSERT_EQ(index_2.size(), index_loaded.size()); + + brute_force::search(handle_, + search_params, + index_loaded, + search_queries_view, + indices_out_view, + dists_out_view); + + update_host( + distances_bruteforce.data(), distances_bruteforce_dev.data(), queries_size, stream_); + update_host( + indices_bruteforce.data(), indices_bruteforce_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_bruteforce, + distances_naive, + distances_bruteforce, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + } + } + + void SetUp() override + { + database.resize(ps.num_db_vecs * ps.dim, stream_); + search_queries.resize(ps.num_queries * ps.dim, stream_); + + raft::random::RngState r(1234ULL); + if constexpr (std::is_same{}) { + raft::random::uniform( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); + raft::random::uniform( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); + } else { + raft::random::uniformInt( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); + raft::random::uniformInt( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); + } + resource::sync_stream(handle_); + } + + void TearDown() override + { + resource::sync_stream(handle_); + database.resize(0, stream_); + search_queries.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnBruteForceInputs ps; + rmm::device_uvector database; + rmm::device_uvector search_queries; +}; + +const std::vector> inputs = { + // test various dims (aligned and not aligned to vector sizes) + {1000, 10000, 1, 16, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2, 16, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 3, 16, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 4, 16, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 5, 16, raft::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 8, 16, raft::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 5, 16, raft::distance::DistanceType::L2SqrtExpanded, true}, + {1000, 10000, 8, 16, raft::distance::DistanceType::L2SqrtExpanded, true}, + + // test dims that do not fit into kernel shared memory limits + {1000, 10000, 2048, 16, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2049, 16, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2050, 16, raft::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 2051, 16, raft::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 2052, 16, raft::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 2053, 16, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2056, 16, raft::distance::DistanceType::L2Expanded, true}, + + // host input data + {1000, 10000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, + {100, 10000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, + {20, 100000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, + {1000, 100000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, + {10000, 131072, 8, 10, raft::distance::DistanceType::L2Expanded, false}, + + {1000, 10000, 16, 10, raft::distance::DistanceType::InnerProduct, false}}; +} // namespace raft::neighbors::brute_force From e47117e04e9c83d6c0b7a6cfd575e19be958b458 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Sat, 2 Dec 2023 19:04:43 -0500 Subject: [PATCH 05/17] Add filename serialization for brute force --- .../raft/neighbors/brute_force_serialize.cuh | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force_serialize.cuh b/cpp/include/raft/neighbors/brute_force_serialize.cuh index 0642385cb4..6b23a519d6 100644 --- a/cpp/include/raft/neighbors/brute_force_serialize.cuh +++ b/cpp/include/raft/neighbors/brute_force_serialize.cuh @@ -16,8 +16,14 @@ #pragma once +#include +#include +#include + namespace raft::neighbors::brute_force { +auto static constexpr serialization_version = 0; + /** * \defgroup brute_force_serialize Brute Force Serialize * @{ @@ -94,11 +100,11 @@ void serialize(raft::resources const& handle, std::ostream& os, const index& * */ template -void serialize(raft::resources const& handle, - const std::string& filename, - const index& index) +void serialize(raft::resources const& handle, const std::string& filename, const index& index) { - detail::serialize(handle, filename, index); + auto os = std::ofstream{filename, std::ios::out | std::ios::binary}; + RAFT_EXPECTS(os, "Cannot open file %s", filename.c_str()); + serialize(handle, os, index); } /** @@ -129,8 +135,8 @@ void serialize(raft::resources const& handle, template auto deserialize(raft::resources const& handle, std::istream& is) { - char dtype_string[4]; - is.read(dtype_string, 4); + auto dtype_string = std::array{}; + is.read(dtype_string.data(), 4); auto ver = deserialize_scalar(handle, is); if (ver != serialization_version) { @@ -187,7 +193,7 @@ auto deserialize(raft::resources const& handle, std::istream& is) template auto deserialize(raft::resources const& handle, const std::string& filename) { - std::ifstream is(filename, std::ios::in | std::ios::binary); + auto is = std::ifstream{filename, std::ios::in | std::ios::binary}; RAFT_EXPECTS(is, "Cannot open file %s", filename.c_str()); return deserialize(handle, is); From 26da060502a224455badd082f523a0cb1e8d57f7 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Sun, 3 Dec 2023 10:53:46 -0500 Subject: [PATCH 06/17] Add brute force index tests --- build.sh | 3 +- cpp/test/CMakeLists.txt | 5 ++++ .../neighbors/ann_brute_force/test_float.cu | 28 +++++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 cpp/test/neighbors/ann_brute_force/test_float.cu diff --git a/build.sh b/build.sh index 51e59cc259..200d6710e0 100755 --- a/build.sh +++ b/build.sh @@ -78,7 +78,7 @@ INSTALL_TARGET=install BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF -TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" +TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_BRUTE_FORCE_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" CACHE_ARGS="" @@ -323,6 +323,7 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then if [[ $CMAKE_TARGET == *"CLUSTER_TEST"* || \ $CMAKE_TARGET == *"DISTANCE_TEST"* || \ $CMAKE_TARGET == *"MATRIX_TEST"* || \ + $CMAKE_TARGET == *"NEIGHBORS_ANN_BRUTE_FORCE_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_NN_DESCENT_TEST"* || \ diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 6c03da8d7f..26dc626402 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -351,6 +351,11 @@ if(BUILD_TESTS) EXPLICIT_INSTANTIATE_ONLY ) + ConfigureTest( + NAME NEIGHBORS_ANN_BRUTE_FORCE_TEST PATH test/neighbors/ann_brute_force/test_float.cu LIB + EXPLICIT_INSTANTIATE_ONLY GPUS 1 PERCENT 100 + ) + ConfigureTest( NAME NEIGHBORS_ANN_CAGRA_TEST diff --git a/cpp/test/neighbors/ann_brute_force/test_float.cu b/cpp/test/neighbors/ann_brute_force/test_float.cu new file mode 100644 index 0000000000..f618f44b61 --- /dev/null +++ b/cpp/test/neighbors/ann_brute_force/test_float.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_brute_force.cuh" + +namespace raft::neighbors::brute_force { + +using AnnBruteForceTest_float = AnnBruteForceTest; +TEST_P(AnnBruteForceTest_float, AnnBruteForce) { this->testBruteForce(); } + +INSTANTIATE_TEST_CASE_P(AnnBruteForceTest, AnnBruteForceTest_float, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::brute_force From e0f63c8f8b3a54075741e130c553fdd4d6dd259a Mon Sep 17 00:00:00 2001 From: William Hicks Date: Sun, 3 Dec 2023 16:34:08 -0500 Subject: [PATCH 07/17] Fix brute force explicit instantiations --- cpp/include/raft/linalg/norm.cuh | 3 +- .../raft/neighbors/brute_force-ext.cuh | 19 ++++++++ .../raft/neighbors/brute_force-inl.cuh | 17 ++++++- .../raft/neighbors/brute_force_serialize.cuh | 41 ++++++++-------- .../raft/neighbors/brute_force_types.hpp | 24 ++++++++-- .../neighbors/brute_force_knn_index_float.cu | 36 ++++++++++---- cpp/test/neighbors/ann_brute_force.cuh | 48 ++++++++----------- 7 files changed, 123 insertions(+), 65 deletions(-) diff --git a/cpp/include/raft/linalg/norm.cuh b/cpp/include/raft/linalg/norm.cuh index 9dad96356b..0d472c5476 100644 --- a/cpp/include/raft/linalg/norm.cuh +++ b/cpp/include/raft/linalg/norm.cuh @@ -20,6 +20,7 @@ #include "detail/norm.cuh" #include "linalg_types.hpp" +#include #include #include @@ -154,4 +155,4 @@ void norm(raft::resources const& handle, }; // end namespace linalg }; // end namespace raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh index b8571e7c0e..ddce6d8fda 100644 --- a/cpp/include/raft/neighbors/brute_force-ext.cuh +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -144,6 +144,14 @@ extern template void search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); +extern template void search( + raft::resources const& res, + search_params const& params, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + extern template raft::neighbors::brute_force::index build( raft::resources const& res, raft::device_matrix_view dataset, @@ -154,6 +162,17 @@ extern template raft::neighbors::brute_force::index build( raft::resources const& res, index_params const& params, raft::device_matrix_view dataset); + +extern template raft::neighbors::brute_force::index build( + raft::resources const& res, + raft::host_matrix_view dataset, + raft::distance::DistanceType metric, + float metric_arg); + +extern template raft::neighbors::brute_force::index build( + raft::resources const& res, + index_params const& params, + raft::host_matrix_view dataset); } // namespace raft::neighbors::brute_force #define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh index 2a0a503864..d96370bcd1 100644 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -303,6 +304,18 @@ index build(raft::resources const& res, // certain distance metrics can benefit by pre-calculating the norms for the index dataset // which lets us avoid calculating these at query time std::optional> norms; + // TODO(wphicks): Replace once mdbuffer is available + auto dataset_storage = std::optional>{}; + auto dataset_view = [&res, &dataset_storage, dataset]() { + if constexpr (std::is_same_v>) { + return dataset; + } else { + dataset_storage = make_device_matrix(res, dataset.extent(0), dataset.extent(1)); + raft::copy(res, dataset_storage->view(), dataset); + return raft::make_const_mdspan(dataset_storage->view()); + } + }(); if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded || metric == raft::distance::DistanceType::CosineExpanded) { @@ -310,14 +323,14 @@ index build(raft::resources const& res, // cosine needs the l2norm, where as l2 distances needs the squared norm if (metric == raft::distance::DistanceType::CosineExpanded) { raft::linalg::norm(res, - dataset, + dataset_view, norms->view(), raft::linalg::NormType::L2Norm, raft::linalg::Apply::ALONG_ROWS, raft::sqrt_op{}); } else { raft::linalg::norm(res, - dataset, + dataset_view, norms->view(), raft::linalg::NormType::L2Norm, raft::linalg::Apply::ALONG_ROWS); diff --git a/cpp/include/raft/neighbors/brute_force_serialize.cuh b/cpp/include/raft/neighbors/brute_force_serialize.cuh index 6b23a519d6..759b7efe39 100644 --- a/cpp/include/raft/neighbors/brute_force_serialize.cuh +++ b/cpp/include/raft/neighbors/brute_force_serialize.cuh @@ -46,7 +46,6 @@ auto static constexpr serialization_version = 0; * @endcode * * @tparam T data element type - * @tparam IdxT type of the indices * * @param[in] handle the raft handle * @param[in] os output stream @@ -92,14 +91,13 @@ void serialize(raft::resources const& handle, std::ostream& os, const index& * @endcode * * @tparam T data element type - * @tparam IdxT type of the indices * * @param[in] handle the raft handle * @param[in] filename the file name for saving the index * @param[in] index brute force index * */ -template +template void serialize(raft::resources const& handle, const std::string& filename, const index& index) { auto os = std::ofstream{filename, std::ios::out | std::ios::binary}; @@ -120,17 +118,15 @@ void serialize(raft::resources const& handle, const std::string& filename, const * // create an input stream * std::istream is(std::cin.rdbuf()); * using T = float; // data element type - * using IdxT = int; // type of the index - * auto index = raft::deserialize(handle, is); + * auto index = raft::deserialize(handle, is); * @endcode * * @tparam T data element type - * @tparam IdxT type of the indices * * @param[in] handle the raft handle * @param[in] is input stream * - * @return raft::neighbors::brute_force::index + * @return raft::neighbors::brute_force::index */ template auto deserialize(raft::resources const& handle, std::istream& is) @@ -151,15 +147,22 @@ auto deserialize(raft::resources const& handle, std::istream& is) deserialize_mdspan(handle, is, dataset_storage.view()); auto has_norms = deserialize_scalar(handle, is); - auto norms_storage = has_norms ? std::optional{raft::make_host_vector(rows)} - : std::optional>{}; - if (has_norms) { deserialize_mdspan(handle, is, norms_storage->view()); } - auto result = - build(handle, - dataset_storage.view(), - norms_storage ? norms_storage->view() : std::optional>{}, - metric, - metric_arg); + auto norms_storage = has_norms ? std::optional{raft::make_host_vector(rows)} + : std::optional>{}; + // TODO(wphicks): Use mdbuffer here when available + auto norms_storage_dev = + has_norms ? std::optional{raft::make_device_vector(handle, rows)} + : std::optional>{}; + if (has_norms) { + deserialize_mdspan(handle, is, norms_storage->view()); + raft::copy(handle, norms_storage_dev->view(), norms_storage->view()); + } + + auto result = index(handle, + raft::make_const_mdspan(dataset_storage.view()), + std::move(norms_storage_dev), + metric, + metric_arg); resource::sync_stream(handle); return result; @@ -178,17 +181,15 @@ auto deserialize(raft::resources const& handle, std::istream& is) * // create a string with a filepath * std::string filename("/path/to/index"); * using T = float; // data element type - * using IdxT = int; // type of the index - * auto index = raft::deserialize(handle, filename); + * auto index = raft::deserialize(handle, filename); * @endcode * * @tparam T data element type - * @tparam IdxT type of the indices * * @param[in] handle the raft handle * @param[in] filename the name of the file that stores the index * - * @return raft::neighbors::brute_force::index + * @return raft::neighbors::brute_force::index */ template auto deserialize(raft::resources const& handle, const std::string& filename) diff --git a/cpp/include/raft/neighbors/brute_force_types.hpp b/cpp/include/raft/neighbors/brute_force_types.hpp index 9c8327fb4a..4857685798 100644 --- a/cpp/include/raft/neighbors/brute_force_types.hpp +++ b/cpp/include/raft/neighbors/brute_force_types.hpp @@ -19,6 +19,7 @@ #include "ann_types.hpp" #include +#include #include #include #include @@ -130,6 +131,22 @@ struct index : ann::index { { } + template + index(raft::resources const& res, + index_params const& params, + mdspan, row_major, data_accessor> dataset, + std::optional>&& norms = std::nullopt) + : ann::index(), + metric_(params.metric), + dataset_(make_device_matrix(res, 0, 0)), + norms_(std::move(norms)), + metric_arg_(params.metric_arg) + { + if (norms_) { norms_view_ = make_const_mdspan(norms_.value().view()); } + update_dataset(res, dataset); + resource::sync_stream(res); + } + private: /** * Replace the dataset with a new dataset. @@ -148,11 +165,8 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::host_matrix_view dataset) { - dataset_ = make_device_matrix(dataset.extents(0), dataset.extents(1)); - raft::copy(dataset_.data_handle(), - dataset.data_handle(), - dataset.size(), - resource::get_cuda_stream(res)); + dataset_ = make_device_matrix(res, dataset.extent(0), dataset.extent(1)); + raft::copy(res, dataset_.view(), dataset); dataset_view_ = make_const_mdspan(dataset_.view()); } diff --git a/cpp/src/neighbors/brute_force_knn_index_float.cu b/cpp/src/neighbors/brute_force_knn_index_float.cu index 899e029df7..faf99ceb3c 100644 --- a/cpp/src/neighbors/brute_force_knn_index_float.cu +++ b/cpp/src/neighbors/brute_force_knn_index_float.cu @@ -16,6 +16,9 @@ */ #include +#include +#include +#include #include template void raft::neighbors::brute_force::search( @@ -48,13 +51,28 @@ template void raft::neighbors::brute_force::search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); -template raft::neighbors::brute_force::index raft::neighbors::brute_force::build( - raft::resources const& res, - raft::device_matrix_view dataset, - raft::distance::DistanceType metric, - float metric_arg); +template raft::neighbors::brute_force::index raft::neighbors::brute_force:: + build::accessor_type>( + raft::resources const& res, + raft::host_matrix_view dataset, + raft::distance::DistanceType metric, + float metric_arg); -template raft::neighbors::brute_force::index raft::neighbors::brute_force::build( - raft::resources const& res, - raft::neighbors::brute_force::index_params const& params, - raft::device_matrix_view dataset); +template raft::neighbors::brute_force::index raft::neighbors::brute_force:: + build::accessor_type>( + raft::resources const& res, + raft::device_matrix_view dataset, + raft::distance::DistanceType metric, + float metric_arg); + +template raft::neighbors::brute_force::index raft::neighbors::brute_force:: + build::accessor_type>( + raft::resources const& res, + raft::neighbors::brute_force::index_params const& params, + raft::host_matrix_view dataset); + +template raft::neighbors::brute_force::index raft::neighbors::brute_force:: + build::accessor_type>( + raft::resources const& res, + raft::neighbors::brute_force::index_params const& params, + raft::device_matrix_view dataset); diff --git a/cpp/test/neighbors/ann_brute_force.cuh b/cpp/test/neighbors/ann_brute_force.cuh index cbdc05fc83..43452134e3 100644 --- a/cpp/test/neighbors/ann_brute_force.cuh +++ b/cpp/test/neighbors/ann_brute_force.cuh @@ -119,34 +119,25 @@ class AnnBruteForceTest : public ::testing::TestWithParam distances_bruteforce_dev(queries_size, stream_); rmm::device_uvector indices_bruteforce_dev(queries_size, stream_); { - brute_force::index_params index_params; - brute_force::search_params search_params; + brute_force::index_params index_params{}; + brute_force::search_params search_params{}; index_params.metric = ps.metric; index_params.metric_arg = 0; - brute_force::index idx(handle_, index_params, ps.dim); - brute_force::index index_2(handle_, index_params, ps.dim); - - if (!ps.host_dataset) { - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - idx = brute_force::build(handle_, index_params, database_view); - rmm::device_uvector vector_indices(ps.num_db_vecs, stream_); - thrust::sequence(resource::get_thrust_policy(handle_), - thrust::device_pointer_cast(vector_indices.data()), - thrust::device_pointer_cast(vector_indices.data() + ps.num_db_vecs)); - resource::sync_stream(handle_); - - } else { - auto host_database = raft::make_host_matrix(ps.num_db_vecs, ps.dim); - raft::copy( - host_database.data_handle(), database.data(), ps.num_db_vecs * ps.dim, stream_); - idx = brute_force::build( - handle_, index_params, raft::make_const_mdspan(host_database.view())); - - auto vector_indices = raft::make_host_vector(handle_, ps.num_db_vecs); - std::iota(vector_indices.data_handle(), vector_indices.data_handle() + ps.num_db_vecs, 0); - } + auto device_dataset = std::optional>{}; + auto idx = [this, &index_params]() { + if (ps.host_dataset) { + auto host_database = raft::make_host_matrix(ps.num_db_vecs, ps.dim); + raft::copy( + host_database.data_handle(), database.data(), ps.num_db_vecs * ps.dim, stream_); + return brute_force::build( + handle_, index_params, raft::make_const_mdspan(host_database.view())); + } else { + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + return brute_force::build(handle_, index_params, database_view); + } + }(); auto search_queries_view = raft::make_device_matrix_view( search_queries.data(), ps.num_queries, ps.dim); @@ -154,10 +145,11 @@ class AnnBruteForceTest : public ::testing::TestWithParam( distances_bruteforce_dev.data(), ps.num_queries, ps.k); - brute_force::serialize(handle_, "brute_force_index", index_2); + brute_force::serialize(handle_, std::string{"brute_force_index"}, idx); - auto index_loaded = brute_force::deserialize(handle_, "brute_force_index"); - ASSERT_EQ(index_2.size(), index_loaded.size()); + auto index_loaded = + brute_force::deserialize(handle_, std::string{"brute_force_index"}); + ASSERT_EQ(idx.size(), index_loaded.size()); brute_force::search(handle_, search_params, From 07bd3d38d7952e7c478de8df91415753cab8788b Mon Sep 17 00:00:00 2001 From: William Hicks Date: Sun, 3 Dec 2023 16:58:13 -0500 Subject: [PATCH 08/17] Fix dim serialization type --- cpp/include/raft/neighbors/brute_force_types.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force_types.hpp b/cpp/include/raft/neighbors/brute_force_types.hpp index 4857685798..b53a9cacd9 100644 --- a/cpp/include/raft/neighbors/brute_force_types.hpp +++ b/cpp/include/raft/neighbors/brute_force_types.hpp @@ -56,10 +56,10 @@ struct index : ann::index { } /** Total length of the index (number of vectors). */ - [[nodiscard]] constexpr inline int64_t size() const noexcept { return dataset_view_.extent(0); } + [[nodiscard]] constexpr inline auto size() const noexcept { return dataset_view_.extent(0); } /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline uint32_t dim() const noexcept { return dataset_view_.extent(1); } + [[nodiscard]] constexpr inline auto dim() const noexcept { return dataset_view_.extent(1); } /** Dataset [size, dim] */ [[nodiscard]] inline auto dataset() const noexcept From 1deb340ef8e77927adbebe7057deeb9f6536324c Mon Sep 17 00:00:00 2001 From: William Hicks Date: Mon, 4 Dec 2023 15:04:25 -0500 Subject: [PATCH 09/17] Update cpp/include/raft/neighbors/brute_force_serialize.cuh Co-authored-by: Ben Frederickson --- cpp/include/raft/neighbors/brute_force_serialize.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/brute_force_serialize.cuh b/cpp/include/raft/neighbors/brute_force_serialize.cuh index 759b7efe39..30efc106d5 100644 --- a/cpp/include/raft/neighbors/brute_force_serialize.cuh +++ b/cpp/include/raft/neighbors/brute_force_serialize.cuh @@ -42,7 +42,7 @@ auto static constexpr serialization_version = 0; * // create an output stream * std::ostream os(std::cout.rdbuf()); * // create an index with `auto index = brute_force::build(...);` - * raft::serialize(handle, os, index); + * raft::neighbors::brute_force::serialize(handle, os, index); * @endcode * * @tparam T data element type From 35e2ba6f53707d0e42e2e79b043b1484f0d52413 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Mon, 4 Dec 2023 15:04:33 -0500 Subject: [PATCH 10/17] Update cpp/include/raft/neighbors/brute_force_serialize.cuh Co-authored-by: Ben Frederickson --- cpp/include/raft/neighbors/brute_force_serialize.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/include/raft/neighbors/brute_force_serialize.cuh b/cpp/include/raft/neighbors/brute_force_serialize.cuh index 30efc106d5..06809c3468 100644 --- a/cpp/include/raft/neighbors/brute_force_serialize.cuh +++ b/cpp/include/raft/neighbors/brute_force_serialize.cuh @@ -36,6 +36,7 @@ auto static constexpr serialization_version = 0; * * @code{.cpp} * #include + * #include * * raft::resources handle; * From aee5e38f59d7339a7a5c3cc2e70d36757992a61a Mon Sep 17 00:00:00 2001 From: William Hicks Date: Mon, 4 Dec 2023 15:34:27 -0500 Subject: [PATCH 11/17] Use devArrMatchKnnPair for brute force index evaluation --- cpp/test/neighbors/ann_brute_force.cuh | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/cpp/test/neighbors/ann_brute_force.cuh b/cpp/test/neighbors/ann_brute_force.cuh index 43452134e3..0906ccd760 100644 --- a/cpp/test/neighbors/ann_brute_force.cuh +++ b/cpp/test/neighbors/ann_brute_force.cuh @@ -17,6 +17,7 @@ #include "../test_utils.cuh" #include "ann_utils.cuh" +#include "knn_utils.cuh" #include #include #include @@ -114,8 +115,6 @@ class AnnBruteForceTest : public ::testing::TestWithParam distances_bruteforce_dev(queries_size, stream_); rmm::device_uvector indices_bruteforce_dev(queries_size, stream_); { @@ -165,14 +164,16 @@ class AnnBruteForceTest : public ::testing::TestWithParam Date: Mon, 4 Dec 2023 16:52:11 -0500 Subject: [PATCH 12/17] Correct handling of exact test match --- cpp/test/neighbors/ann_brute_force.cuh | 142 ++++++++++++------------- 1 file changed, 68 insertions(+), 74 deletions(-) diff --git a/cpp/test/neighbors/ann_brute_force.cuh b/cpp/test/neighbors/ann_brute_force.cuh index 0906ccd760..1c6e4ad2f6 100644 --- a/cpp/test/neighbors/ann_brute_force.cuh +++ b/cpp/test/neighbors/ann_brute_force.cuh @@ -95,85 +95,79 @@ class AnnBruteForceTest : public ::testing::TestWithParam distances_bruteforce(queries_size); std::vector distances_naive(queries_size); - { - rmm::device_uvector distances_naive_dev(queries_size, stream_); - rmm::device_uvector indices_naive_dev(queries_size, stream_); - naive_knn(handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database.data(), - ps.num_queries, - ps.num_db_vecs, - ps.dim, - ps.k, - ps.metric); - update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); - update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database.data(), + ps.num_queries, + ps.num_db_vecs, + ps.dim, + ps.k, + ps.metric); + update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); { // Require exact result for brute force rmm::device_uvector distances_bruteforce_dev(queries_size, stream_); rmm::device_uvector indices_bruteforce_dev(queries_size, stream_); - { - brute_force::index_params index_params{}; - brute_force::search_params search_params{}; - index_params.metric = ps.metric; - index_params.metric_arg = 0; - - auto device_dataset = std::optional>{}; - auto idx = [this, &index_params]() { - if (ps.host_dataset) { - auto host_database = raft::make_host_matrix(ps.num_db_vecs, ps.dim); - raft::copy( - host_database.data_handle(), database.data(), ps.num_db_vecs * ps.dim, stream_); - return brute_force::build( - handle_, index_params, raft::make_const_mdspan(host_database.view())); - } else { - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - return brute_force::build(handle_, index_params, database_view); - } - }(); - - auto search_queries_view = raft::make_device_matrix_view( - search_queries.data(), ps.num_queries, ps.dim); - auto indices_out_view = raft::make_device_matrix_view( - indices_bruteforce_dev.data(), ps.num_queries, ps.k); - auto dists_out_view = raft::make_device_matrix_view( - distances_bruteforce_dev.data(), ps.num_queries, ps.k); - brute_force::serialize(handle_, std::string{"brute_force_index"}, idx); - - auto index_loaded = - brute_force::deserialize(handle_, std::string{"brute_force_index"}); - ASSERT_EQ(idx.size(), index_loaded.size()); - - brute_force::search(handle_, - search_params, - index_loaded, - search_queries_view, - indices_out_view, - dists_out_view); - - update_host( - distances_bruteforce.data(), distances_bruteforce_dev.data(), queries_size, stream_); - update_host( - indices_bruteforce.data(), indices_bruteforce_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } - - ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair( - indices_naive.data_handle(), - indices_bruteforce.data_handle(), - distances_naive.data_handle(), - distances_bruteforce.data_handle(), - ps.num_queries, - ps.k, - 0.001f, - stream_, - true)); + brute_force::index_params index_params{}; + brute_force::search_params search_params{}; + index_params.metric = ps.metric; + index_params.metric_arg = 0; + + auto device_dataset = std::optional>{}; + auto idx = [this, &index_params]() { + if (ps.host_dataset) { + auto host_database = raft::make_host_matrix(ps.num_db_vecs, ps.dim); + raft::copy( + host_database.data_handle(), database.data(), ps.num_db_vecs * ps.dim, stream_); + return brute_force::build( + handle_, index_params, raft::make_const_mdspan(host_database.view())); + } else { + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + return brute_force::build(handle_, index_params, database_view); + } + }(); + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.num_queries, ps.dim); + auto indices_out_view = raft::make_device_matrix_view( + indices_bruteforce_dev.data(), ps.num_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + distances_bruteforce_dev.data(), ps.num_queries, ps.k); + brute_force::serialize(handle_, std::string{"brute_force_index"}, idx); + + auto index_loaded = + brute_force::deserialize(handle_, std::string{"brute_force_index"}); + ASSERT_EQ(idx.size(), index_loaded.size()); + + brute_force::search(handle_, + search_params, + index_loaded, + search_queries_view, + indices_out_view, + dists_out_view); + + update_host( + distances_bruteforce.data(), distances_bruteforce_dev.data(), queries_size, stream_); + update_host(indices_bruteforce.data(), indices_bruteforce_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + + ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices_naive_dev.data(), + indices_bruteforce_dev.data(), + distances_naive_dev.data(), + distances_bruteforce_dev.data(), + ps.num_queries, + ps.k, + 0.001f, + stream_, + true)); } } From b65226fc3db885a9e87279d1dffd2d20df1bf82e Mon Sep 17 00:00:00 2001 From: William Hicks Date: Mon, 4 Dec 2023 16:57:09 -0500 Subject: [PATCH 13/17] Apply suggestions from code review Co-authored-by: Ben Frederickson --- cpp/include/raft/neighbors/brute_force_serialize.cuh | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force_serialize.cuh b/cpp/include/raft/neighbors/brute_force_serialize.cuh index 06809c3468..f6aa67f134 100644 --- a/cpp/include/raft/neighbors/brute_force_serialize.cuh +++ b/cpp/include/raft/neighbors/brute_force_serialize.cuh @@ -82,13 +82,14 @@ void serialize(raft::resources const& handle, std::ostream& os, const index& * * @code{.cpp} * #include + * #include * * raft::resources handle; * * // create a string with a filepath * std::string filename("/path/to/index"); * // create an index with `auto index = brute_force::build(...);` - * raft::serialize(handle, filename, index); + * raft::neighbors::brute_force::serialize(handle, filename, index); * @endcode * * @tparam T data element type @@ -113,13 +114,14 @@ void serialize(raft::resources const& handle, const std::string& filename, const * * @code{.cpp} * #include + * #include * * raft::resources handle; * * // create an input stream * std::istream is(std::cin.rdbuf()); * using T = float; // data element type - * auto index = raft::deserialize(handle, is); + * auto index = raft::neighbors::brute_force::deserialize(handle, is); * @endcode * * @tparam T data element type @@ -176,13 +178,14 @@ auto deserialize(raft::resources const& handle, std::istream& is) * * @code{.cpp} * #include + * #include * * raft::resources handle; * * // create a string with a filepath * std::string filename("/path/to/index"); * using T = float; // data element type - * auto index = raft::deserialize(handle, filename); + * auto index = raft::neighbors::brute_force::deserialize(handle, filename); * @endcode * * @tparam T data element type From ce12715ee73d0e7efeefc09a9f425adae2625c73 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Mon, 4 Dec 2023 17:03:44 -0500 Subject: [PATCH 14/17] Remove unused host variables --- cpp/test/neighbors/ann_brute_force.cuh | 9 --------- 1 file changed, 9 deletions(-) diff --git a/cpp/test/neighbors/ann_brute_force.cuh b/cpp/test/neighbors/ann_brute_force.cuh index 1c6e4ad2f6..6404fa7719 100644 --- a/cpp/test/neighbors/ann_brute_force.cuh +++ b/cpp/test/neighbors/ann_brute_force.cuh @@ -90,10 +90,6 @@ class AnnBruteForceTest : public ::testing::TestWithParam indices_bruteforce(queries_size); - std::vector indices_naive(queries_size); - std::vector distances_bruteforce(queries_size); - std::vector distances_naive(queries_size); rmm::device_uvector distances_naive_dev(queries_size, stream_); rmm::device_uvector indices_naive_dev(queries_size, stream_); @@ -107,8 +103,6 @@ class AnnBruteForceTest : public ::testing::TestWithParam Date: Tue, 5 Dec 2023 12:48:26 -0500 Subject: [PATCH 15/17] Make dataset serialization optional --- .../raft/neighbors/brute_force_serialize.cuh | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force_serialize.cuh b/cpp/include/raft/neighbors/brute_force_serialize.cuh index f6aa67f134..1c3430bd2f 100644 --- a/cpp/include/raft/neighbors/brute_force_serialize.cuh +++ b/cpp/include/raft/neighbors/brute_force_serialize.cuh @@ -54,7 +54,10 @@ auto static constexpr serialization_version = 0; * */ template -void serialize(raft::resources const& handle, std::ostream& os, const index& index) +void serialize(raft::resources const& handle, + std::ostream& os, + const index& index, + bool include_dataset = true) { RAFT_LOG_DEBUG( "Saving brute force index, size %zu, dim %u", static_cast(index.size()), index.dim()); @@ -68,7 +71,8 @@ void serialize(raft::resources const& handle, std::ostream& os, const index& serialize_scalar(handle, os, index.dim()); serialize_scalar(handle, os, index.metric()); serialize_scalar(handle, os, index.metric_arg()); - serialize_mdspan(handle, os, index.dataset()); + serialize_scalar(handle, os, include_dataset); + if (include_dataset) { serialize_mdspan(handle, os, index.dataset()); } auto has_norms = index.has_norms(); serialize_scalar(handle, os, has_norms); if (has_norms) { serialize_mdspan(handle, os, index.norms()); } @@ -100,11 +104,14 @@ void serialize(raft::resources const& handle, std::ostream& os, const index& * */ template -void serialize(raft::resources const& handle, const std::string& filename, const index& index) +void serialize(raft::resources const& handle, + const std::string& filename, + const index& index, + bool include_dataset = true) { auto os = std::ofstream{filename, std::ios::out | std::ios::binary}; RAFT_EXPECTS(os, "Cannot open file %s", filename.c_str()); - serialize(handle, os, index); + serialize(handle, os, index, include_dataset); } /** @@ -147,7 +154,8 @@ auto deserialize(raft::resources const& handle, std::istream& is) auto metric_arg = deserialize_scalar(handle, is); auto dataset_storage = raft::make_host_matrix(rows, dim); - deserialize_mdspan(handle, is, dataset_storage.view()); + auto include_dataset = deserialize_scalar(handle, is); + if (include_dataset) { deserialize_mdspan(handle, is, dataset_storage.view()); } auto has_norms = deserialize_scalar(handle, is); auto norms_storage = has_norms ? std::optional{raft::make_host_vector(rows)} From e975ce2dcc6e8d6da246f103128fbd539b41f5ab Mon Sep 17 00:00:00 2001 From: William Hicks Date: Tue, 5 Dec 2023 13:22:51 -0500 Subject: [PATCH 16/17] Add tests for no-dataset serialization --- .../raft/neighbors/brute_force_serialize.cuh | 7 ++++-- .../raft/neighbors/brute_force_types.hpp | 2 +- cpp/test/neighbors/ann_brute_force.cuh | 23 +++++++++++++++++++ 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force_serialize.cuh b/cpp/include/raft/neighbors/brute_force_serialize.cuh index 1c3430bd2f..1541730cb0 100644 --- a/cpp/include/raft/neighbors/brute_force_serialize.cuh +++ b/cpp/include/raft/neighbors/brute_force_serialize.cuh @@ -153,9 +153,12 @@ auto deserialize(raft::resources const& handle, std::istream& is) auto metric = deserialize_scalar(handle, is); auto metric_arg = deserialize_scalar(handle, is); - auto dataset_storage = raft::make_host_matrix(rows, dim); + auto dataset_storage = raft::make_host_matrix(std::int64_t{}, std::int64_t{}); auto include_dataset = deserialize_scalar(handle, is); - if (include_dataset) { deserialize_mdspan(handle, is, dataset_storage.view()); } + if (include_dataset) { + dataset_storage = raft::make_host_matrix(rows, dim); + deserialize_mdspan(handle, is, dataset_storage.view()); + } auto has_norms = deserialize_scalar(handle, is); auto norms_storage = has_norms ? std::optional{raft::make_host_vector(rows)} diff --git a/cpp/include/raft/neighbors/brute_force_types.hpp b/cpp/include/raft/neighbors/brute_force_types.hpp index b53a9cacd9..176b32a866 100644 --- a/cpp/include/raft/neighbors/brute_force_types.hpp +++ b/cpp/include/raft/neighbors/brute_force_types.hpp @@ -147,7 +147,6 @@ struct index : ann::index { resource::sync_stream(res); } - private: /** * Replace the dataset with a new dataset. */ @@ -170,6 +169,7 @@ struct index : ann::index { dataset_view_ = make_const_mdspan(dataset_.view()); } + private: raft::distance::DistanceType metric_; raft::device_matrix dataset_; std::optional> norms_; diff --git a/cpp/test/neighbors/ann_brute_force.cuh b/cpp/test/neighbors/ann_brute_force.cuh index 6404fa7719..1cba6bfb39 100644 --- a/cpp/test/neighbors/ann_brute_force.cuh +++ b/cpp/test/neighbors/ann_brute_force.cuh @@ -150,6 +150,29 @@ class AnnBruteForceTest : public ::testing::TestWithParam(handle_, std::string{"brute_force_index"}); + index_loaded.update_dataset(handle_, idx.dataset()); + ASSERT_EQ(idx.size(), index_loaded.size()); + + brute_force::search(handle_, + search_params, + index_loaded, + search_queries_view, + indices_out_view, + dists_out_view); + + resource::sync_stream(handle_); + ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices_naive_dev.data(), indices_bruteforce_dev.data(), distances_naive_dev.data(), From 21cd1daace54c368282017cbdaf444c74c80b57d Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 7 Dec 2023 13:24:59 -0500 Subject: [PATCH 17/17] Document include_dataset parameter --- cpp/include/raft/neighbors/brute_force_serialize.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/include/raft/neighbors/brute_force_serialize.cuh b/cpp/include/raft/neighbors/brute_force_serialize.cuh index 1541730cb0..bed3bed9e1 100644 --- a/cpp/include/raft/neighbors/brute_force_serialize.cuh +++ b/cpp/include/raft/neighbors/brute_force_serialize.cuh @@ -51,6 +51,8 @@ auto static constexpr serialization_version = 0; * @param[in] handle the raft handle * @param[in] os output stream * @param[in] index brute force index + * @param[in] include_dataset whether to include the dataset in the serialized + * output * */ template @@ -101,6 +103,8 @@ void serialize(raft::resources const& handle, * @param[in] handle the raft handle * @param[in] filename the file name for saving the index * @param[in] index brute force index + * @param[in] include_dataset whether to include the dataset in the serialized + * output * */ template