From 90be4848f38a4425771248f3eb67c4c62e7479bb Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 17 Jul 2024 06:43:19 -0700 Subject: [PATCH] Add col-major support for brute force knn (#217) Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Tarang Jain (https://github.com/tarang-jain) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuvs/pull/217 --- cpp/include/cuvs/neighbors/brute_force.hpp | 72 +++++- cpp/src/neighbors/brute_force.cu | 52 ++++ cpp/src/neighbors/detail/knn_brute_force.cuh | 27 +-- cpp/test/neighbors/brute_force.cu | 240 ++++++++++++++++++- cpp/test/neighbors/knn_utils.cuh | 15 +- 5 files changed, 364 insertions(+), 42 deletions(-) diff --git a/cpp/include/cuvs/neighbors/brute_force.hpp b/cpp/include/cuvs/neighbors/brute_force.hpp index 13a5ea0cb..1ec7e81f7 100644 --- a/cpp/include/cuvs/neighbors/brute_force.hpp +++ b/cpp/include/cuvs/neighbors/brute_force.hpp @@ -49,7 +49,8 @@ struct index : cuvs::neighbors::index { * * Constructs a brute force index from a dataset. This lets us precompute norms for * the dataset, providing a speed benefit over doing this at query time. - * This index will store a non-owning reference to the dataset. + * This index will copy the host dataset onto the device, and take ownership of any + * precalculated norms. */ index(raft::resources const& res, raft::host_matrix_view dataset_view, @@ -61,7 +62,8 @@ struct index : cuvs::neighbors::index { * * Constructs a brute force index from a dataset. This lets us precompute norms for * the dataset, providing a speed benefit over doing this at query time. - * The dataset will be copied to the device and the index will own the device memory. + * This index will store a non-owning reference to the dataset, but will move + * any norms supplied.
*/ index(raft::resources const& res, raft::device_matrix_view dataset_view, @@ -71,7 +73,7 @@ struct index : cuvs::neighbors::index { /** Construct a brute force index from dataset * - * This class stores a non-owning reference to the dataset and norms here. + * This class stores a non-owning reference to the dataset and norms. * Having precomputed norms gives us a performance advantage at query time. */ index(raft::resources const& res, @@ -80,6 +82,17 @@ struct index : cuvs::neighbors::index { cuvs::distance::DistanceType metric, T metric_arg = 0.0); + /** Construct a brute force index from dataset + * + * This class stores a non-owning reference to the dataset and norms, with + * the dataset being supplied on device in a col_major format + */ + index(raft::resources const& res, + raft::device_matrix_view dataset_view, + std::optional>&& norms, + cuvs::distance::DistanceType metric, + T metric_arg = 0.0); + /** * Replace the dataset with a new dataset. */ @@ -152,12 +165,34 @@ struct index : cuvs::neighbors::index { * @param[in] metric cuvs::distance::DistanceType * @param[in] metric_arg metric argument * - * @return the constructed ivf-flat index + * @return the constructed bruteforce index */ auto build(raft::resources const& handle, raft::device_matrix_view dataset, cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded, float metric_arg = 0) -> cuvs::neighbors::brute_force::index; + +/** + * @brief Build the index from the dataset for efficient search. 
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // create and fill the index from a [N, D] dataset + * auto index = brute_force::build(handle, dataset, metric); + * @endcode + * + * @param[in] handle + * @param[in] dataset a device pointer to a col-major matrix [n_rows, dim] + * @param[in] metric cuvs::distance::DistanceType + * @param[in] metric_arg metric argument + * + * @return the constructed bruteforce index + */ +auto build(raft::resources const& handle, + raft::device_matrix_view dataset, + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded, + float metric_arg = 0) -> cuvs::neighbors::brute_force::index; /** * @} */ @@ -169,7 +204,7 @@ auto build(raft::resources const& handle, /** * @brief Search ANN using the constructed index. * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * See the [brute_force::build](#brute_force::build) documentation for a usage example. * * Note, this function requires a temporary buffer to store intermediate results between cuda kernel * calls, which may lead to undesirable allocations and slowdown. 
To alleviate the problem, you can @@ -186,13 +221,13 @@ auto build(raft::resources const& handle, * @endcode * * @param[in] handle - * @param[in] index ivf-flat constructed index + * @param[in] index bruteforce constructed index * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] sample_filter a optional device bitmap filter function that greenlights samples for a - * given + * @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a + * given query */ void search(raft::resources const& handle, const cuvs::neighbors::brute_force::index& index, @@ -200,6 +235,27 @@ void search(raft::resources const& handle, raft::device_matrix_view neighbors, raft::device_matrix_view distances, std::optional> sample_filter); + +/** + * @brief Search ANN using the constructed index. + * + * See the [brute_force::build](#brute_force::build) documentation for a usage example. 
+ * + * @param[in] handle + * @param[in] index bruteforce constructed index + * @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a + * given query + */ +void search(raft::resources const& handle, + const cuvs::neighbors::brute_force::index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + std::optional> sample_filter); /** * @} */ diff --git a/cpp/src/neighbors/brute_force.cu b/cpp/src/neighbors/brute_force.cu index 13554c0b5..c0414079c 100644 --- a/cpp/src/neighbors/brute_force.cu +++ b/cpp/src/neighbors/brute_force.cu @@ -69,6 +69,36 @@ index::index(raft::resources const& res, { } +template +index::index(raft::resources const& res, + raft::device_matrix_view dataset_view, + std::optional>&& norms, + cuvs::distance::DistanceType metric, + T metric_arg) + : cuvs::neighbors::index(), + metric_(metric), + dataset_( + raft::make_device_matrix(res, dataset_view.extent(0), dataset_view.extent(1))), + norms_(std::move(norms)), + metric_arg_(metric_arg) +{ + // currently we don't support col_major inside tiled_brute_force_knn, because + // of limitations of the pairwise_distance API: + // 1) pairwise_distance takes a single 'isRowMajor' parameter - and we have + // multiple options here (both dataset and queries) + // 2) because of tiling, we need to be able to set a custom stride in the PW + // api, which isn't supported + // Instead, transpose the input matrices if they are passed as col-major.
+ // (note: we're doing the transpose here to avoid doing per query) + raft::linalg::transpose(res, + const_cast(dataset_view.data_handle()), + dataset_.data_handle(), + dataset_view.extent(0), + dataset_view.extent(1), + raft::resource::get_cuda_stream(res)); + dataset_view_ = raft::make_const_mdspan(dataset_.view()); +} + template void index::update_dataset(raft::resources const& res, raft::device_matrix_view dataset) @@ -93,6 +123,14 @@ void index::update_dataset(raft::resources const& res, ->cuvs::neighbors::brute_force::index \ { \ return detail::build(res, dataset, metric, metric_arg); \ + } \ + auto build(raft::resources const& res, \ + raft::device_matrix_view dataset, \ + cuvs::distance::DistanceType metric, \ + T metric_arg) \ + ->cuvs::neighbors::brute_force::index \ + { \ + return detail::build(res, dataset, metric, metric_arg); \ } \ \ void search( \ @@ -109,6 +147,20 @@ void index::update_dataset(raft::resources const& res, detail::brute_force_search_filtered( \ res, idx, queries, *sample_filter, neighbors, distances); \ } \ + } \ + void search( \ + raft::resources const& res, \ + const cuvs::neighbors::brute_force::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + std::optional> sample_filter = std::nullopt) \ + { \ + if (!sample_filter.has_value()) { \ + detail::brute_force_search(res, idx, queries, neighbors, distances); \ + } else { \ + RAFT_FAIL("filtered search isn't available with col_major queries yet"); \ + } \ } \ \ template struct cuvs::neighbors::brute_force::index; diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index 97f7fba75..fe425fe8f 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -513,11 +513,11 @@ void brute_force_knn_impl( if (translations == nullptr) delete id_ranges; }; -template +template void brute_force_search( 
raft::resources const& res, const cuvs::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, + raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, std::optional> query_norms = std::nullopt) @@ -544,7 +544,7 @@ void brute_force_search( distances.data_handle(), k, true, - true, + std::is_same_v, nullptr, idx.metric(), idx.metric_arg(), @@ -719,28 +719,17 @@ void brute_force_search_filtered( return; } -template +template cuvs::neighbors::brute_force::index build( raft::resources const& res, - raft::device_matrix_view dataset, + raft::device_matrix_view dataset, cuvs::distance::DistanceType metric, T metric_arg) { // certain distance metrics can benefit by pre-calculating the norms for the index dataset // which lets us avoid calculating these at query time std::optional> norms; - auto dataset_storage = std::optional>{}; - auto dataset_view = [&res, &dataset_storage, dataset]() { - if constexpr (std::is_same_v>) { - return dataset; - } else { - dataset_storage = - raft::make_device_matrix(res, dataset.extent(0), dataset.extent(1)); - raft::copy(res, dataset_storage->view(), dataset); - return raft::make_const_mdspan(dataset_storage->view()); - } - }(); + if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded || metric == cuvs::distance::DistanceType::CosineExpanded) { @@ -748,14 +737,14 @@ cuvs::neighbors::brute_force::index build( // cosine needs the l2norm, where as l2 distances needs the squared norm if (metric == cuvs::distance::DistanceType::CosineExpanded) { raft::linalg::norm(res, - dataset_view, + dataset, norms->view(), raft::linalg::NormType::L2Norm, raft::linalg::Apply::ALONG_ROWS, raft::sqrt_op{}); } else { raft::linalg::norm(res, - dataset_view, + dataset, norms->view(), raft::linalg::NormType::L2Norm, raft::linalg::Apply::ALONG_ROWS); diff --git a/cpp/test/neighbors/brute_force.cu b/cpp/test/neighbors/brute_force.cu index 
081a2966e..c97bb5531 100644 --- a/cpp/test/neighbors/brute_force.cu +++ b/cpp/test/neighbors/brute_force.cu @@ -15,8 +15,15 @@ */ #include "../test_utils.cuh" +#include "./knn_utils.cuh" #include #include +#include + +#include +#include +#include +#include namespace cuvs::neighbors::brute_force { struct KNNInputs { @@ -179,4 +186,235 @@ typedef KNNTest KNNTestFint64_t; TEST_P(KNNTestFint64_t, BruteForce) { this->testBruteForce(); } INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestFint64_t, ::testing::ValuesIn(inputs)); -} // namespace cuvs::neighbors::brute_force \ No newline at end of file + +// Also test with larger random inputs, including col-major inputs +struct RandomKNNInputs { + int num_queries; + int num_db_vecs; + int dim; + int k; + cuvs::distance::DistanceType metric; + bool row_major; +}; + +std::ostream& operator<<(std::ostream& os, const RandomKNNInputs& input) +{ + return os << "num_queries:" << input.num_queries << " num_vecs:" << input.num_db_vecs + << " dim:" << input.dim << " k:" << input.k + << " metric:" << cuvs::neighbors::print_metric{input.metric} + << " row_major:" << input.row_major; +} + +template +class RandomBruteForceKNNTest : public ::testing::TestWithParam { + public: + RandomBruteForceKNNTest() + : stream_(raft::resource::get_cuda_stream(handle_)), + params_(::testing::TestWithParam::GetParam()), + database(params_.num_db_vecs * params_.dim, stream_), + search_queries(params_.num_queries * params_.dim, stream_), + cuvs_indices_(params_.num_queries * params_.k, stream_), + cuvs_distances_(params_.num_queries * params_.k, stream_), + ref_indices_(params_.num_queries * params_.k, stream_), + ref_distances_(params_.num_queries * params_.k, stream_) + { + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(database.data(), params_.num_db_vecs, params_.dim), + T{0.0}); + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(search_queries.data(), params_.num_queries, params_.dim), + T{0.0}); + raft::matrix::fill( + handle_, + 
raft::make_device_matrix_view(cuvs_distances_.data(), params_.num_queries, params_.k), + T{0.0}); + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(ref_distances_.data(), params_.num_queries, params_.k), + T{0.0}); + } + + protected: + void testBruteForce() + { + float metric_arg = 3.0; + + // calculate the naive knn, by calculating the full pairwise distances and doing a k-select + rmm::device_uvector temp_distances(num_db_vecs * num_queries, stream_); + rmm::device_uvector workspace(0, stream_); + + auto temp_dist = temp_distances.data(); + rmm::device_uvector temp_row_major_dist(num_db_vecs * num_queries, stream_); + + if (params_.row_major) { + distance::pairwise_distance( + handle_, + raft::make_device_matrix_view( + search_queries.data(), params_.num_queries, params_.dim), + raft::make_device_matrix_view( + database.data(), params_.num_db_vecs, params_.dim), + raft::make_device_matrix_view(temp_distances.data(), num_queries, num_db_vecs), + metric, + metric_arg); + + } else { + distance::pairwise_distance(handle_, + raft::make_device_matrix_view( + search_queries.data(), params_.num_queries, params_.dim), + raft::make_device_matrix_view( + database.data(), params_.num_db_vecs, params_.dim), + raft::make_device_matrix_view( + temp_distances.data(), num_queries, num_db_vecs), + metric, + metric_arg); + + // the pairwise_distance call assumes that the inputs and outputs are all either row-major + // or col-major - meaning we have to transpose the output back for col-major queries + // for comparison + raft::linalg::transpose( + handle_, temp_dist, temp_row_major_dist.data(), num_queries, num_db_vecs, stream_); + temp_dist = temp_row_major_dist.data(); + } + + cuvs::selection::select_k( + handle_, + raft::make_device_matrix_view(temp_dist, num_queries, num_db_vecs), + std::nullopt, + raft::make_device_matrix_view(ref_distances_.data(), params_.num_queries, params_.k), + raft::make_device_matrix_view(ref_indices_.data(), params_.num_queries,
params_.k), + cuvs::distance::is_min_close(metric), + true); + + auto indices = raft::make_device_matrix_view( + cuvs_indices_.data(), params_.num_queries, params_.k); + auto distances = raft::make_device_matrix_view( + cuvs_distances_.data(), params_.num_queries, params_.k); + + if (params_.row_major) { + auto idx = + cuvs::neighbors::brute_force::build(handle_, + raft::make_device_matrix_view( + database.data(), params_.num_db_vecs, params_.dim), + metric, + metric_arg); + + cuvs::neighbors::brute_force::search( + handle_, + idx, + raft::make_device_matrix_view( + search_queries.data(), params_.num_queries, params_.dim), + indices, + distances, + std::nullopt); + } else { + auto idx = cuvs::neighbors::brute_force::build( + handle_, + raft::make_device_matrix_view( + database.data(), params_.num_db_vecs, params_.dim), + metric, + metric_arg); + + cuvs::neighbors::brute_force::search( + handle_, + idx, + raft::make_device_matrix_view( + search_queries.data(), params_.num_queries, params_.dim), + indices, + distances, + std::nullopt); + } + + ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(ref_indices_.data(), + cuvs_indices_.data(), + ref_distances_.data(), + cuvs_distances_.data(), + num_queries, + k_, + float(0.001), + stream_, + true)); + } + + void SetUp() override + { + num_queries = params_.num_queries; + num_db_vecs = params_.num_db_vecs; + dim = params_.dim; + k_ = params_.k; + metric = params_.metric; + + unsigned long long int seed = 1234ULL; + raft::random::RngState r(seed); + + // JensenShannon distance requires positive values + T min_val = metric == cuvs::distance::DistanceType::JensenShannon ? 
T(0.0) : T(-1.0); + uniform(handle_, r, database.data(), num_db_vecs * dim, min_val, T(1.0)); + uniform(handle_, r, search_queries.data(), num_queries * dim, min_val, T(1.0)); + } + + private: + raft::resources handle_; + cudaStream_t stream_ = 0; + RandomKNNInputs params_; + int num_queries; + int num_db_vecs; + int dim; + rmm::device_uvector database; + rmm::device_uvector search_queries; + rmm::device_uvector cuvs_indices_; + rmm::device_uvector cuvs_distances_; + rmm::device_uvector ref_indices_; + rmm::device_uvector ref_distances_; + int k_; + cuvs::distance::DistanceType metric; +}; + +const std::vector random_inputs = { + // test each distance metric on a small-ish input, with row-major inputs + {256, 512, 16, 8, cuvs::distance::DistanceType::L2Expanded, true}, + {256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, true}, + {256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true}, + {256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, true}, + {256, 512, 16, 8, cuvs::distance::DistanceType::L1, true}, + {256, 512, 16, 8, cuvs::distance::DistanceType::Linf, true}, + {256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, true}, + {256, 512, 16, 8, cuvs::distance::DistanceType::CorrelationExpanded, true}, + {256, 512, 16, 8, cuvs::distance::DistanceType::CosineExpanded, true}, + {256, 512, 16, 8, cuvs::distance::DistanceType::LpUnexpanded, true}, + {256, 512, 16, 8, cuvs::distance::DistanceType::JensenShannon, true}, + {256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true}, + {256, 512, 16, 8, cuvs::distance::DistanceType::Canberra, true}, + // test each distance metric with col-major inputs + {256, 512, 16, 8, cuvs::distance::DistanceType::L2Expanded, false}, + {256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, false}, + {256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + {256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, false}, + {256, 512, 16, 8, 
cuvs::distance::DistanceType::L1, false}, + {256, 512, 16, 8, cuvs::distance::DistanceType::Linf, false}, + {256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, false}, + {256, 512, 16, 8, cuvs::distance::DistanceType::CorrelationExpanded, false}, + {256, 512, 16, 8, cuvs::distance::DistanceType::CosineExpanded, false}, + {256, 512, 16, 8, cuvs::distance::DistanceType::LpUnexpanded, false}, + {256, 512, 16, 8, cuvs::distance::DistanceType::JensenShannon, false}, + {256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + {256, 512, 16, 8, cuvs::distance::DistanceType::Canberra, false}, + // larger tests on different sized data / k values + {10000, 40000, 32, 30, cuvs::distance::DistanceType::L2Expanded, false}, + {345, 1023, 16, 128, cuvs::distance::DistanceType::CosineExpanded, true}, + {789, 20516, 64, 256, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 500000, 128, 128, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 500000, 128, 128, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 5000, 128, 128, cuvs::distance::DistanceType::LpUnexpanded, true}, + {1000, 5000, 128, 128, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 5000, 128, 128, cuvs::distance::DistanceType::InnerProduct, false}}; + +typedef RandomBruteForceKNNTest RandomBruteForceKNNTestF; +TEST_P(RandomBruteForceKNNTestF, BruteForce) { this->testBruteForce(); } + +INSTANTIATE_TEST_CASE_P(RandomBruteForceKNNTest, + RandomBruteForceKNNTestF, + ::testing::ValuesIn(random_inputs)); + +} // namespace cuvs::neighbors::brute_force diff --git a/cpp/test/neighbors/knn_utils.cuh b/cpp/test/neighbors/knn_utils.cuh index d95174ef6..75cc90916 100644 --- a/cpp/test/neighbors/knn_utils.cuh +++ b/cpp/test/neighbors/knn_utils.cuh @@ -17,6 +17,7 @@ #pragma once #include "../test_utils.cuh" +#include "./ann_utils.cuh" #include @@ -25,20 +26,6 @@ #include namespace cuvs::neighbors { -template -struct idx_dist_pair { - IdxT idx; - DistT dist; - 
compareDist eq_compare; - bool operator==(const idx_dist_pair& a) const - { - if (idx == a.idx) return true; - if (eq_compare(dist, a.dist)) return true; - return false; - } - idx_dist_pair(IdxT x, DistT y, compareDist op) : idx(x), dist(y), eq_compare(op) {} -}; - template testing::AssertionResult devArrMatchKnnPair(const T* expected_idx, const T* actual_idx,