diff --git a/build.sh b/build.sh index 51e59cc259..200d6710e0 100755 --- a/build.sh +++ b/build.sh @@ -78,7 +78,7 @@ INSTALL_TARGET=install BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF -TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" +TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_BRUTE_FORCE_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" CACHE_ARGS="" @@ -323,6 +323,7 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then if [[ $CMAKE_TARGET == *"CLUSTER_TEST"* || \ $CMAKE_TARGET == *"DISTANCE_TEST"* || \ $CMAKE_TARGET == *"MATRIX_TEST"* || \ + $CMAKE_TARGET == *"NEIGHBORS_ANN_BRUTE_FORCE_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_NN_DESCENT_TEST"* || \ diff --git a/cpp/include/raft/linalg/norm.cuh b/cpp/include/raft/linalg/norm.cuh index 9dad96356b..0d472c5476 100644 --- a/cpp/include/raft/linalg/norm.cuh +++ b/cpp/include/raft/linalg/norm.cuh @@ -20,6 +20,7 @@ #include "detail/norm.cuh" #include "linalg_types.hpp" +#include #include #include @@ -154,4 +155,4 @@ void norm(raft::resources const& handle, }; // end namespace linalg }; // end namespace raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh index 4c1f7ea21e..ddce6d8fda 100644 --- a/cpp/include/raft/neighbors/brute_force-ext.cuh +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -45,8 +45,21 @@ index build(raft::resources const& res, raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded, T metric_arg = 0.0) RAFT_EXPLICIT; +template +index build(raft::resources const& res, + index_params const& params, + mdspan, row_major, Accessor> dataset) RAFT_EXPLICIT; + +template +void search(raft::resources const& res, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) RAFT_EXPLICIT; + template void search(raft::resources const& res, + search_params const& params, const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, @@ -116,6 +129,14 @@ extern template void search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); +extern template void search( + raft::resources const& res, + search_params const& params, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + extern template void search( raft::resources const& res, const raft::neighbors::brute_force::index& idx, @@ -123,11 +144,35 @@ extern template void search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); +extern template void search( + raft::resources const& res, + search_params const& params, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + extern template raft::neighbors::brute_force::index build( raft::resources const& res, raft::device_matrix_view dataset, raft::distance::DistanceType metric, float metric_arg); + +extern template raft::neighbors::brute_force::index build( + raft::resources const& res, + index_params const& params, + raft::device_matrix_view dataset); + +extern template raft::neighbors::brute_force::index build( + raft::resources const& res, + raft::host_matrix_view dataset, + raft::distance::DistanceType metric, + float metric_arg); + +extern template raft::neighbors::brute_force::index build( + raft::resources const& res, + index_params const& params, + raft::host_matrix_view dataset); } // namespace raft::neighbors::brute_force #define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh index 906371bd01..f955cc8518 100644 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -341,6 +342,18 @@ index build(raft::resources const& res, // certain distance metrics can benefit by pre-calculating the norms for the index dataset // which lets us avoid calculating these at query time std::optional> norms; + // TODO(wphicks): Replace once mdbuffer is available + auto dataset_storage = std::optional>{}; + auto dataset_view = [&res, &dataset_storage, dataset]() { + if constexpr (std::is_same_v>) { + return dataset; + } else { + dataset_storage = make_device_matrix(res, dataset.extent(0), dataset.extent(1)); + raft::copy(res, dataset_storage->view(), dataset); + return raft::make_const_mdspan(dataset_storage->view()); + } + }(); if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded || metric == raft::distance::DistanceType::CosineExpanded) { @@ -348,14 +361,14 @@ index build(raft::resources const& res, // cosine needs the l2norm, where as l2 distances needs the squared norm if (metric == raft::distance::DistanceType::CosineExpanded) { raft::linalg::norm(res, - dataset, + dataset_view, norms->view(), raft::linalg::NormType::L2Norm, raft::linalg::Apply::ALONG_ROWS, raft::sqrt_op{}); } else { raft::linalg::norm(res, - dataset, + dataset_view, norms->view(), raft::linalg::NormType::L2Norm, raft::linalg::Apply::ALONG_ROWS); @@ -365,6 +378,25 @@ index build(raft::resources const& res, return index(res, dataset, std::move(norms), metric, metric_arg); } +/** + * @brief Build the index from the dataset for efficient search. + * + * @tparam T data element type + * + * @param[in] res + * @param[in] params configure the index building + * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] + * + * @return the constructed brute force index + */ +template +index build(raft::resources const& res, + index_params const& params, + mdspan, row_major, Accessor> dataset) +{ + return build(res, dataset, params.metric, float(params.metric_arg)); +} + /** * @brief Brute Force search using the constructed index. * @@ -390,5 +422,32 @@ void search(raft::resources const& res, { raft::neighbors::detail::brute_force_search(res, idx, queries, neighbors, distances); } + +/** + * @brief Brute Force search using the constructed index. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] params configure the search + * @param[in] idx brute force index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search(raft::resources const& res, + search_params const& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + raft::neighbors::detail::brute_force_search(res, idx, queries, neighbors, distances); +} + /** @} */ // end group brute_force_knn } // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force_serialize.cuh b/cpp/include/raft/neighbors/brute_force_serialize.cuh new file mode 100644 index 0000000000..bed3bed9e1 --- /dev/null +++ b/cpp/include/raft/neighbors/brute_force_serialize.cuh @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::neighbors::brute_force { + +auto static constexpr serialization_version = 0; + +/** + * \defgroup brute_force_serialize Brute Force Serialize + * @{ + */ + +/** + * Write the index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = brute_force::build(...);` + * raft::neighbors::brute_force::serialize(handle, os, index); + * @endcode + * + * @tparam T data element type + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index brute force index + * @param[in] include_dataset whether to include the dataset in the serialized + * output + * + */ +template +void serialize(raft::resources const& handle, + std::ostream& os, + const index& index, + bool include_dataset = true) +{ + RAFT_LOG_DEBUG( + "Saving brute force index, size %zu, dim %u", static_cast(index.size()), index.dim()); + + auto dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); + dtype_string.resize(4); + os << dtype_string; + + serialize_scalar(handle, os, serialization_version); + serialize_scalar(handle, os, index.size()); + serialize_scalar(handle, os, index.dim()); + serialize_scalar(handle, os, index.metric()); + serialize_scalar(handle, os, index.metric_arg()); + serialize_scalar(handle, os, include_dataset); + if (include_dataset) { serialize_mdspan(handle, os, index.dataset()); } + auto has_norms = index.has_norms(); + serialize_scalar(handle, os, has_norms); + if (has_norms) { serialize_mdspan(handle, os, index.norms()); } + resource::sync_stream(handle); +} + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = brute_force::build(...);` + * raft::neighbors::brute_force::serialize(handle, filename, index); + * @endcode + * + * @tparam T data element type + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index brute force index + * @param[in] include_dataset whether to include the dataset in the serialized + * output + * + */ +template +void serialize(raft::resources const& handle, + const std::string& filename, + const index& index, + bool include_dataset = true) +{ + auto os = std::ofstream{filename, std::ios::out | std::ios::binary}; + RAFT_EXPECTS(os, "Cannot open file %s", filename.c_str()); + serialize(handle, os, index, include_dataset); +} + +/** + * Load index from input stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * using T = float; // data element type + * auto index = raft::neighbors::brute_force::deserialize(handle, is); + * @endcode + * + * @tparam T data element type + * + * @param[in] handle the raft handle + * @param[in] is input stream + * + * @return raft::neighbors::brute_force::index + */ +template +auto deserialize(raft::resources const& handle, std::istream& is) +{ + auto dtype_string = std::array{}; + is.read(dtype_string.data(), 4); + + auto ver = deserialize_scalar(handle, is); + if (ver != serialization_version) { + RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); + } + auto rows = deserialize_scalar(handle, is); + auto dim = deserialize_scalar(handle, is); + auto metric = deserialize_scalar(handle, is); + auto metric_arg = deserialize_scalar(handle, is); + + auto dataset_storage = raft::make_host_matrix(std::int64_t{}, std::int64_t{}); + auto include_dataset = deserialize_scalar(handle, is); + if (include_dataset) { + dataset_storage = raft::make_host_matrix(rows, dim); + deserialize_mdspan(handle, is, dataset_storage.view()); + } + + auto has_norms = deserialize_scalar(handle, is); + auto norms_storage = has_norms ? std::optional{raft::make_host_vector(rows)} + : std::optional>{}; + // TODO(wphicks): Use mdbuffer here when available + auto norms_storage_dev = + has_norms ? std::optional{raft::make_device_vector(handle, rows)} + : std::optional>{}; + if (has_norms) { + deserialize_mdspan(handle, is, norms_storage->view()); + raft::copy(handle, norms_storage_dev->view(), norms_storage->view()); + } + + auto result = index(handle, + raft::make_const_mdspan(dataset_storage.view()), + std::move(norms_storage_dev), + metric, + metric_arg); + resource::sync_stream(handle); + + return result; +} + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using T = float; // data element type + * auto index = raft::neighbors::brute_force::deserialize(handle, filename); + * @endcode + * + * @tparam T data element type + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * + * @return raft::neighbors::brute_force::index + */ +template +auto deserialize(raft::resources const& handle, const std::string& filename) +{ + auto is = std::ifstream{filename, std::ios::in | std::ios::binary}; + RAFT_EXPECTS(is, "Cannot open file %s", filename.c_str()); + + return deserialize(handle, is); +} + +/**@}*/ + +} // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force_types.hpp b/cpp/include/raft/neighbors/brute_force_types.hpp index 039599845e..176b32a866 100644 --- a/cpp/include/raft/neighbors/brute_force_types.hpp +++ b/cpp/include/raft/neighbors/brute_force_types.hpp @@ -19,6 +19,7 @@ #include "ann_types.hpp" #include +#include #include #include #include @@ -35,6 +36,9 @@ namespace raft::neighbors::brute_force { * @{ */ +using ann::index_params; +using ann::search_params; + /** * @brief Brute Force index. * @@ -52,10 +56,10 @@ struct index : ann::index { } /** Total length of the index (number of vectors). */ - [[nodiscard]] constexpr inline int64_t size() const noexcept { return dataset_view_.extent(0); } + [[nodiscard]] constexpr inline auto size() const noexcept { return dataset_view_.extent(0); } /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline uint32_t dim() const noexcept { return dataset_view_.extent(1); } + [[nodiscard]] constexpr inline auto dim() const noexcept { return dataset_view_.extent(1); } /** Dataset [size, dim] */ [[nodiscard]] inline auto dataset() const noexcept @@ -127,7 +131,22 @@ struct index : ann::index { { } - private: + template + index(raft::resources const& res, + index_params const& params, + mdspan, row_major, data_accessor> dataset, + std::optional>&& norms = std::nullopt) + : ann::index(), + metric_(params.metric), + dataset_(make_device_matrix(res, 0, 0)), + norms_(std::move(norms)), + metric_arg_(params.metric_arg) + { + if (norms_) { norms_view_ = make_const_mdspan(norms_.value().view()); } + update_dataset(res, dataset); + resource::sync_stream(res); + } + /** * Replace the dataset with a new dataset. */ @@ -145,14 +164,12 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::host_matrix_view dataset) { - dataset_ = make_device_matrix(dataset.extents(0), dataset.extents(1)); - raft::copy(dataset_.data_handle(), - dataset.data_handle(), - dataset.size(), - resource::get_cuda_stream(res)); + dataset_ = make_device_matrix(res, dataset.extent(0), dataset.extent(1)); + raft::copy(res, dataset_.view(), dataset); dataset_view_ = make_const_mdspan(dataset_.view()); } + private: raft::distance::DistanceType metric_; raft::device_matrix dataset_; std::optional> norms_; diff --git a/cpp/src/neighbors/brute_force_knn_index_float.cu b/cpp/src/neighbors/brute_force_knn_index_float.cu index d4f902c087..faf99ceb3c 100644 --- a/cpp/src/neighbors/brute_force_knn_index_float.cu +++ b/cpp/src/neighbors/brute_force_knn_index_float.cu @@ -16,6 +16,9 @@ */ #include +#include +#include +#include #include template void raft::neighbors::brute_force::search( @@ -25,6 +28,14 @@ template void raft::neighbors::brute_force::search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); +template void raft::neighbors::brute_force::search( + raft::resources const& res, + raft::neighbors::brute_force::search_params const& params, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + template void raft::neighbors::brute_force::search( raft::resources const& res, const raft::neighbors::brute_force::index& idx, @@ -32,8 +43,36 @@ template void raft::neighbors::brute_force::search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); -template raft::neighbors::brute_force::index raft::neighbors::brute_force::build( +template void raft::neighbors::brute_force::search( raft::resources const& res, - raft::device_matrix_view dataset, - raft::distance::DistanceType metric, - float metric_arg); + raft::neighbors::brute_force::search_params const& params, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + +template raft::neighbors::brute_force::index raft::neighbors::brute_force:: + build::accessor_type>( + raft::resources const& res, + raft::host_matrix_view dataset, + raft::distance::DistanceType metric, + float metric_arg); + +template raft::neighbors::brute_force::index raft::neighbors::brute_force:: + build::accessor_type>( + raft::resources const& res, + raft::device_matrix_view dataset, + raft::distance::DistanceType metric, + float metric_arg); + +template raft::neighbors::brute_force::index raft::neighbors::brute_force:: + build::accessor_type>( + raft::resources const& res, + raft::neighbors::brute_force::index_params const& params, + raft::host_matrix_view dataset); + +template raft::neighbors::brute_force::index raft::neighbors::brute_force:: + build::accessor_type>( + raft::resources const& res, + raft::neighbors::brute_force::index_params const& params, + raft::device_matrix_view dataset); diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 847dec8568..f043442840 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -350,6 +350,11 @@ if(BUILD_TESTS) EXPLICIT_INSTANTIATE_ONLY ) + ConfigureTest( + NAME NEIGHBORS_ANN_BRUTE_FORCE_TEST PATH test/neighbors/ann_brute_force/test_float.cu LIB + EXPLICIT_INSTANTIATE_ONLY GPUS 1 PERCENT 100 + ) + ConfigureTest( NAME NEIGHBORS_ANN_CAGRA_TEST diff --git a/cpp/test/neighbors/ann_brute_force.cuh b/cpp/test/neighbors/ann_brute_force.cuh new file mode 100644 index 0000000000..1cba6bfb39 --- /dev/null +++ b/cpp/test/neighbors/ann_brute_force.cuh @@ -0,0 +1,253 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../test_utils.cuh" +#include "ann_utils.cuh" +#include "knn_utils.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include +#include + +#include +#include +#include + +namespace raft::neighbors::brute_force { + +template +struct AnnBruteForceInputs { + IdxT num_queries; + IdxT num_db_vecs; + IdxT dim; + IdxT k; + raft::distance::DistanceType metric; + bool host_dataset; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const AnnBruteForceInputs& p) +{ + os << "{ " << p.num_queries << ", " << p.num_db_vecs << ", " << p.dim << ", " << p.k << ", " + << static_cast(p.metric) << ", " << p.host_dataset << '}' << std::endl; + return os; +} + +template +class AnnBruteForceTest : public ::testing::TestWithParam> { + public: + AnnBruteForceTest() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam>::GetParam()), + database(0, stream_), + search_queries(0, stream_) + { + } + + void testBruteForce() + { + size_t queries_size = ps.num_queries * ps.k; + + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database.data(), + ps.num_queries, + ps.num_db_vecs, + ps.dim, + ps.k, + ps.metric); + resource::sync_stream(handle_); + + { + // Require exact result for brute force + rmm::device_uvector distances_bruteforce_dev(queries_size, stream_); + rmm::device_uvector indices_bruteforce_dev(queries_size, stream_); + brute_force::index_params index_params{}; + brute_force::search_params search_params{}; + index_params.metric = ps.metric; + index_params.metric_arg = 0; + + auto device_dataset = std::optional>{}; + auto idx = [this, &index_params]() { + if (ps.host_dataset) { + auto host_database = raft::make_host_matrix(ps.num_db_vecs, ps.dim); + raft::copy( + host_database.data_handle(), database.data(), ps.num_db_vecs * ps.dim, stream_); + return brute_force::build( + handle_, index_params, raft::make_const_mdspan(host_database.view())); + } else { + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + return brute_force::build(handle_, index_params, database_view); + } + }(); + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.num_queries, ps.dim); + auto indices_out_view = raft::make_device_matrix_view( + indices_bruteforce_dev.data(), ps.num_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + distances_bruteforce_dev.data(), ps.num_queries, ps.k); + brute_force::serialize(handle_, std::string{"brute_force_index"}, idx); + + auto index_loaded = + brute_force::deserialize(handle_, std::string{"brute_force_index"}); + ASSERT_EQ(idx.size(), index_loaded.size()); + + brute_force::search(handle_, + search_params, + index_loaded, + search_queries_view, + indices_out_view, + dists_out_view); + + resource::sync_stream(handle_); + + ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices_naive_dev.data(), + indices_bruteforce_dev.data(), + distances_naive_dev.data(), + distances_bruteforce_dev.data(), + ps.num_queries, + ps.k, + 0.001f, + stream_, + true)); + brute_force::serialize(handle_, std::string{"brute_force_index"}, idx, false); + index_loaded = brute_force::deserialize(handle_, std::string{"brute_force_index"}); + index_loaded.update_dataset(handle_, idx.dataset()); + ASSERT_EQ(idx.size(), index_loaded.size()); + + brute_force::search(handle_, + search_params, + index_loaded, + search_queries_view, + indices_out_view, + dists_out_view); + + resource::sync_stream(handle_); + + ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices_naive_dev.data(), + indices_bruteforce_dev.data(), + distances_naive_dev.data(), + distances_bruteforce_dev.data(), + ps.num_queries, + ps.k, + 0.001f, + stream_, + true)); + } + } + + void SetUp() override + { + database.resize(ps.num_db_vecs * ps.dim, stream_); + search_queries.resize(ps.num_queries * ps.dim, stream_); + + raft::random::RngState r(1234ULL); + if constexpr (std::is_same{}) { + raft::random::uniform( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); + raft::random::uniform( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); + } else { + raft::random::uniformInt( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); + raft::random::uniformInt( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); + } + resource::sync_stream(handle_); + } + + void TearDown() override + { + resource::sync_stream(handle_); + database.resize(0, stream_); + search_queries.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnBruteForceInputs ps; + rmm::device_uvector database; + rmm::device_uvector search_queries; +}; + +const std::vector> inputs = { + // test various dims (aligned and not aligned to vector sizes) + {1000, 10000, 1, 16, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2, 16, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 3, 16, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 4, 16, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 5, 16, raft::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 8, 16, raft::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 5, 16, raft::distance::DistanceType::L2SqrtExpanded, true}, + {1000, 10000, 8, 16, raft::distance::DistanceType::L2SqrtExpanded, true}, + + // test dims that do not fit into kernel shared memory limits + {1000, 10000, 2048, 16, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2049, 16, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2050, 16, raft::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 2051, 16, raft::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 2052, 16, raft::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 2053, 16, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2056, 16, raft::distance::DistanceType::L2Expanded, true}, + + // host input data + {1000, 10000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, + {100, 10000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, + {20, 100000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, + {1000, 100000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, + {10000, 131072, 8, 10, raft::distance::DistanceType::L2Expanded, false}, + + {1000, 10000, 16, 10, raft::distance::DistanceType::InnerProduct, false}}; +} // namespace raft::neighbors::brute_force diff --git a/cpp/test/neighbors/ann_brute_force/test_float.cu b/cpp/test/neighbors/ann_brute_force/test_float.cu new file mode 100644 index 0000000000..f618f44b61 --- /dev/null +++ b/cpp/test/neighbors/ann_brute_force/test_float.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_brute_force.cuh" + +namespace raft::neighbors::brute_force { + +using AnnBruteForceTest_float = AnnBruteForceTest; +TEST_P(AnnBruteForceTest_float, AnnBruteForce) { this->testBruteForce(); } + +INSTANTIATE_TEST_CASE_P(AnnBruteForceTest, AnnBruteForceTest_float, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::brute_force