From a59405d65bcdf6458600a48a200cdb4b7c455306 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 16 Nov 2023 23:17:35 -0800 Subject: [PATCH] Add support for iterating over batches in bfknn (#1947) This adds support for iterating over batches of nearest neighbors in the brute force knn. This lets you query for the nearest neighbors, and then filter down the results - and if you have filtered out too many results, get the next batch of nearest neighbors. The challenge here is balancing memory consumption versus efficieny: we could store the results of the full gemm in the distance calculation - but for large indices this discards the benefits of using the tiling strategy and risks running OOM. Instead we exponentially grow the number of neighbors being returned, and also cache both the query norms and index norms between calls. Authors: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1947 --- cpp/bench/ann/CMakeLists.txt | 5 +- .../distance/detail/distance_ops/l2_exp.cuh | 2 +- .../raft/neighbors/brute_force-ext.cuh | 2 +- .../raft/neighbors/brute_force-inl.cuh | 31 +---- cpp/include/raft/neighbors/brute_force.cuh | 68 ++++++++++ .../raft/neighbors/brute_force_types.hpp | 119 +++++++++++++++++- .../raft/neighbors/detail/knn_brute_force.cuh | 93 ++++++++++---- .../detail/knn_brute_force_batch_k_query.cuh | 98 +++++++++++++++ .../raft/neighbors/neighbors_types.hpp | 63 ++++++++++ .../neighbors/brute_force_knn_index_float.cu | 2 +- cpp/test/neighbors/tiled_knn.cu | 74 ++++++++++- 11 files changed, 497 insertions(+), 60 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/knn_brute_force_batch_k_query.cuh create mode 100644 cpp/include/raft/neighbors/neighbors_types.hpp diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index eb44e58cb5..5919de07e7 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -90,7 +90,7 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OR RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE OR RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OR RAFT_ANN_BENCH_USE_RAFT_CAGRA - OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB + OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB ) set(RAFT_ANN_BENCH_USE_RAFT ON) endif() @@ -263,7 +263,8 @@ if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) ${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib LINKS raft::compiled - CXXFLAGS "${HNSW_CXX_FLAGS}" + CXXFLAGS + "${HNSW_CXX_FLAGS}" ) endif() diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index 5b4048c1c3..a218c85a0a 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -44,7 +44,7 @@ struct l2_exp_cutlass_op { __device__ l2_exp_cutlass_op() noexcept : sqrt(false) {} __device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {} - __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + inline __device__ AccT operator()(DataT aNorm, DataT bNorm, DataT accVal) const noexcept { AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh index b8c00616da..4c1f7ea21e 100644 --- a/cpp/include/raft/neighbors/brute_force-ext.cuh +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -50,7 +50,7 @@ void search(raft::resources const& res, const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances) RAFT_EXPLICIT; + raft::device_matrix_view distances) RAFT_EXPLICIT; template & idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances) + raft::device_matrix_view distances) { - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); - RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), - "Number of columns in queries must match brute force index"); - - auto k = neighbors.extent(1); - auto d = idx.dataset().extent(1); - - std::vector dataset = {const_cast(idx.dataset().data_handle())}; - std::vector sizes = {idx.dataset().extent(0)}; - std::vector norms; - if (idx.has_norms()) { norms.push_back(const_cast(idx.norms().data_handle())); } - - detail::brute_force_knn_impl(res, - dataset, - sizes, - d, - const_cast(queries.data_handle()), - queries.extent(0), - neighbors.data_handle(), - distances.data_handle(), - k, - true, - true, - nullptr, - idx.metric(), - idx.metric_arg(), - raft::identity_op(), - norms.size() ? &norms : nullptr); + raft::neighbors::detail::brute_force_search(res, idx, queries, neighbors, distances); } /** @} */ // end group brute_force_knn } // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 6cebf4b52a..4ba9159556 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -14,6 +14,7 @@ * limitations under the License. */ #pragma once +#include #ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY #include "brute_force-inl.cuh" @@ -22,3 +23,70 @@ #ifdef RAFT_COMPILED #include "brute_force-ext.cuh" #endif + +#include + +namespace raft::neighbors::brute_force { +/** + * @brief Make a brute force query over batches of k + * + * This lets you query for batches of k. For example, you can get + * the first 100 neighbors, then the next 100 neighbors etc. + * + * Example usage: + * @code{.cpp} + * #include + * #include + * #include + + * // create a random dataset + * int n_rows = 10000; + * int n_cols = 10000; + + * raft::device_resources res; + * auto dataset = raft::make_device_matrix(res, n_rows, n_cols); + * auto labels = raft::make_device_vector(res, n_rows); + + * raft::make_blobs(res, dataset.view(), labels.view()); + * + * // create a brute_force knn index from the dataset + * auto index = raft::neighbors::brute_force::build(res, + * raft::make_const_mdspan(dataset.view())); + * + * // search the index in batches of 128 nearest neighbors + * auto search = raft::make_const_mdspan(dataset.view()); + * auto query = make_batch_k_query(res, index, search, 128); + * for (auto & batch: *query) { + * // batch.indices() and batch.distances() contain the information on the current batch + * } + * + * // we can also support variable sized batches - loaded up a different number + * // of neighbors at each iteration through the ::advance method + * int64_t batch_size = 128; + * query = make_batch_k_query(res, index, search, batch_size); + * for (auto it = query->begin(); it != query->end(); it.advance(batch_size)) { + * // batch.indices() and batch.distances() contain the information on the current batch + * + * batch_size += 16; // load up an extra 16 items in the next batch + * } + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * @param[in] res + * @param[in] index The index to query + * @param[in] query A device matrix view to query for [n_queries, index->dim()] + * @param[in] batch_size The size of each batch + */ + +template +std::shared_ptr> make_batch_k_query( + const raft::resources& res, + const raft::neighbors::brute_force::index& index, + raft::device_matrix_view query, + int64_t batch_size) +{ + return std::shared_ptr>( + new detail::gpu_batch_k_query(res, index, query, batch_size)); +} +} // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force_types.hpp b/cpp/include/raft/neighbors/brute_force_types.hpp index f7030503f1..039599845e 100644 --- a/cpp/include/raft/neighbors/brute_force_types.hpp +++ b/cpp/include/raft/neighbors/brute_force_types.hpp @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -69,7 +70,7 @@ struct index : ann::index { return norms_view_.value(); } - /** Whether ot not this index has dataset norms */ + /** Whether or not this index has dataset norms */ [[nodiscard]] inline bool has_norms() const noexcept { return norms_view_.has_value(); } [[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; } @@ -160,6 +161,122 @@ struct index : ann::index { T metric_arg_; }; +/** + * @brief Interface for performing queries over values of k + * + * This interface lets you iterate over batches of k from a brute_force::index. + * This lets you do things like retrieve the first 100 neighbors for a query, + * apply post processing to remove any unwanted items and then if needed get the + * next 100 closest neighbors for the query. + * + * This query interface exposes C++ iterators through the ::begin and ::end, and + * is compatible with range based for loops. + * + * Note that this class is an abstract class without any cuda dependencies, meaning + * that it doesn't require a cuda compiler to use - but also means it can't be directly + * instantiated. See the raft::neighbors::brute_force::make_batch_k_query + * function for usage examples. + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + */ +template +class batch_k_query { + public: + batch_k_query(const raft::resources& res, + int64_t index_size, + int64_t query_size, + int64_t batch_size) + : res(res), index_size(index_size), query_size(query_size), batch_size(batch_size) + { + } + virtual ~batch_k_query() {} + + using value_type = raft::neighbors::batch; + + class iterator { + public: + using value_type = raft::neighbors::batch; + using reference = const value_type&; + using pointer = const value_type*; + + iterator(const batch_k_query* query, int64_t offset = 0) + : current(query->res, 0, 0), batches(query->res, 0, 0), query(query), offset(offset) + { + query->load_batch(offset, query->batch_size, &batches); + query->slice_batch(batches, offset, query->batch_size, ¤t); + } + + reference operator*() const { return current; } + + pointer operator->() const { return ¤t; } + + iterator& operator++() + { + advance(query->batch_size); + return *this; + } + + iterator operator++(int) + { + iterator previous(*this); + operator++(); + return previous; + } + + /** + * @brief Advance the iterator, using a custom size for the next batch + * + * Using operator++ means that we will load up the same batch_size for each + * batch. This method allows us to get around this restriction, and load up + * arbitrary batch sizes on each iteration. + * See raft::neighbors::brute_force::make_batch_k_query for a usage example. + * + * @param[in] next_batch_size: size of the next batch to load up + */ + void advance(int64_t next_batch_size) + { + offset = std::min(offset + current.batch_size(), query->index_size); + if (offset + next_batch_size > batches.batch_size()) { + query->load_batch(offset, next_batch_size, &batches); + } + query->slice_batch(batches, offset, next_batch_size, ¤t); + } + + friend bool operator==(const iterator& lhs, const iterator& rhs) + { + return (lhs.query == rhs.query) && (lhs.offset == rhs.offset); + }; + friend bool operator!=(const iterator& lhs, const iterator& rhs) { return !(lhs == rhs); }; + + protected: + // the current batch of data + value_type current; + + // the currently loaded group of data (containing multiple batches of data that we can iterate + // through) + value_type batches; + + const batch_k_query* query; + int64_t offset, current_batch_size; + }; + + iterator begin() const { return iterator(this); } + iterator end() const { return iterator(this, index_size); } + + protected: + // these two methods need cuda code, and are implemented in the subclass + virtual void load_batch(int64_t offset, + int64_t next_batch_size, + batch* output) const = 0; + virtual void slice_batch(const value_type& input, + int64_t offset, + int64_t batch_size, + value_type* output) const = 0; + + const raft::resources& res; + int64_t index_size, query_size, batch_size; +}; /** @} */ } // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 5da4e77874..27ef00e385 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -36,6 +36,7 @@ #include #include #include +#include #include #include #include @@ -65,11 +66,12 @@ void tiled_brute_force_knn(const raft::resources& handle, ElementType* distances, // size (m, k) IndexType* indices, // size (m, k) raft::distance::DistanceType metric, - float metric_arg = 2.0, - size_t max_row_tile_size = 0, - size_t max_col_tile_size = 0, - DistanceEpilogue distance_epilogue = raft::identity_op(), - const ElementType* precomputed_index_norms = nullptr) + float metric_arg = 2.0, + size_t max_row_tile_size = 0, + size_t max_col_tile_size = 0, + DistanceEpilogue distance_epilogue = raft::identity_op(), + const ElementType* precomputed_index_norms = nullptr, + const ElementType* precomputed_search_norms = nullptr) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -98,18 +100,20 @@ void tiled_brute_force_knn(const raft::resources& handle, if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded || metric == raft::distance::DistanceType::CosineExpanded) { - search_norms.resize(m, stream); + if (!precomputed_search_norms) { search_norms.resize(m, stream); } if (!precomputed_index_norms) { index_norms.resize(n, stream); } // cosine needs the l2norm, where as l2 distances needs the squared norm if (metric == raft::distance::DistanceType::CosineExpanded) { - raft::linalg::rowNorm(search_norms.data(), - search, - d, - m, - raft::linalg::NormType::L2Norm, - true, - stream, - raft::sqrt_op{}); + if (!precomputed_search_norms) { + raft::linalg::rowNorm(search_norms.data(), + search, + d, + m, + raft::linalg::NormType::L2Norm, + true, + stream, + raft::sqrt_op{}); + } if (!precomputed_index_norms) { raft::linalg::rowNorm(index_norms.data(), index, @@ -121,9 +125,10 @@ void tiled_brute_force_knn(const raft::resources& handle, raft::sqrt_op{}); } } else { - raft::linalg::rowNorm( - search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); - + if (!precomputed_search_norms) { + raft::linalg::rowNorm( + search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); + } if (!precomputed_index_norms) { raft::linalg::rowNorm( index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); @@ -184,7 +189,7 @@ void tiled_brute_force_knn(const raft::resources& handle, metric_arg); if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded) { - auto row_norms = search_norms.data(); + auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data(); auto col_norms = precomputed_index_norms ? precomputed_index_norms : index_norms.data(); auto dist = temp_distances.data(); bool sqrt = metric == raft::distance::DistanceType::L2SqrtExpanded; @@ -201,7 +206,7 @@ void tiled_brute_force_knn(const raft::resources& handle, return distance_epilogue(val, row, col); }); } else if (metric == raft::distance::DistanceType::CosineExpanded) { - auto row_norms = search_norms.data(); + auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data(); auto col_norms = precomputed_index_norms ? precomputed_index_norms : index_norms.data(); auto dist = temp_distances.data(); @@ -333,7 +338,8 @@ void brute_force_knn_impl( raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, float metricArg = 0, DistanceEpilogue distance_epilogue = raft::identity_op(), - std::vector* input_norms = nullptr) + std::vector* input_norms = nullptr, + const value_t* search_norms = nullptr) { auto userStream = resource::get_cuda_stream(handle); @@ -376,7 +382,7 @@ void brute_force_knn_impl( } // currently we don't support col_major inside tiled_brute_force_knn, because - // of limitattions of the pairwise_distance API: + // of limitations of the pairwise_distance API: // 1) paiwise_distance takes a single 'isRowMajor' parameter - and we have // multiple options here (like rowMajorQuery/rowMajorIndex) // 2) because of tiling, we need to be able to set a custom stride in the PW @@ -428,7 +434,8 @@ void brute_force_knn_impl( rowMajorQuery, stream, metric, - input_norms ? (*input_norms)[i] : nullptr); + input_norms ? (*input_norms)[i] : nullptr, + search_norms); // Perform necessary post-processing if (metric == raft::distance::DistanceType::L2SqrtExpanded || @@ -478,7 +485,8 @@ void brute_force_knn_impl( 0, 0, distance_epilogue, - input_norms ? (*input_norms)[i] : nullptr); + input_norms ? (*input_norms)[i] : nullptr, + search_norms); break; } } @@ -500,4 +508,43 @@ void brute_force_knn_impl( if (translations == nullptr) delete id_ranges; }; +template +void brute_force_search( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + std::optional> query_norms = std::nullopt) +{ + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); + RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), + "Number of columns in queries must match brute force index"); + + auto k = neighbors.extent(1); + auto d = idx.dataset().extent(1); + + std::vector dataset = {const_cast(idx.dataset().data_handle())}; + std::vector sizes = {idx.dataset().extent(0)}; + std::vector norms; + if (idx.has_norms()) { norms.push_back(const_cast(idx.norms().data_handle())); } + + brute_force_knn_impl(res, + dataset, + sizes, + d, + const_cast(queries.data_handle()), + queries.extent(0), + neighbors.data_handle(), + distances.data_handle(), + k, + true, + true, + nullptr, + idx.metric(), + idx.metric_arg(), + raft::identity_op(), + norms.size() ? &norms : nullptr, + query_norms ? query_norms->data_handle() : nullptr); +} } // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force_batch_k_query.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force_batch_k_query.cuh new file mode 100644 index 0000000000..384eacae79 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/knn_brute_force_batch_k_query.cuh @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace raft::neighbors::brute_force::detail { +template +class gpu_batch_k_query : public batch_k_query { + public: + gpu_batch_k_query(const raft::resources& res, + const raft::neighbors::brute_force::index& index, + raft::device_matrix_view query, + int64_t batch_size) + : batch_k_query(res, index.size(), query.extent(0), batch_size), + index(index), + query(query) + { + auto metric = index.metric(); + + // precompute query norms, and re-use across batches + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::CosineExpanded) { + query_norms = make_device_vector(res, query.extent(0)); + + if (metric == raft::distance::DistanceType::CosineExpanded) { + raft::linalg::norm(res, + query, + query_norms->view(), + raft::linalg::NormType::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::sqrt_op{}); + } else { + raft::linalg::norm(res, + query, + query_norms->view(), + raft::linalg::NormType::L2Norm, + raft::linalg::Apply::ALONG_ROWS); + } + } + } + + protected: + void load_batch(int64_t offset, int64_t next_batch_size, batch* output) const override + { + if (offset >= index.size()) { return; } + + // we're aiming to load multiple batches here - since we don't know the max iteration + // grow the size we're loading exponentially + int64_t batch_size = std::min(std::max(offset * 2, next_batch_size * 2), this->index_size); + output->resize(this->res, this->query_size, batch_size); + + std::optional> query_norms_view; + if (query_norms) { query_norms_view = query_norms->view(); } + + raft::neighbors::detail::brute_force_search( + this->res, index, query, output->indices(), output->distances(), query_norms_view); + }; + + void slice_batch(const batch& input, + int64_t offset, + int64_t batch_size, + batch* output) const override + { + auto num_queries = input.indices().extent(0); + batch_size = std::min(batch_size, index.size() - offset); + + output->resize(this->res, num_queries, batch_size); + + if (!num_queries || !batch_size) { return; } + + matrix::slice_coordinates coords{0, offset, num_queries, offset + batch_size}; + matrix::slice(this->res, input.indices(), output->indices(), coords); + matrix::slice(this->res, input.distances(), output->distances(), coords); + } + + const raft::neighbors::brute_force::index& index; + raft::device_matrix_view query; + std::optional> query_norms; +}; +} // namespace raft::neighbors::brute_force::detail diff --git a/cpp/include/raft/neighbors/neighbors_types.hpp b/cpp/include/raft/neighbors/neighbors_types.hpp new file mode 100644 index 0000000000..d503779741 --- /dev/null +++ b/cpp/include/raft/neighbors/neighbors_types.hpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::neighbors { + +/** A single batch of nearest neighbors in device memory */ +template +class batch { + public: + /** Create a new empty batch of data */ + batch(raft::resources const& res, int64_t rows, int64_t cols) + : indices_(make_device_matrix(res, rows, cols)), + distances_(make_device_matrix(res, rows, cols)) + { + } + + void resize(raft::resources const& res, int64_t rows, int64_t cols) + { + indices_ = make_device_matrix(res, rows, cols); + distances_ = make_device_matrix(res, rows, cols); + } + + /** Returns the indices for the batch */ + device_matrix_view indices() const + { + return raft::make_const_mdspan(indices_.view()); + } + device_matrix_view indices() { return indices_.view(); } + + /** Returns the distances for the batch */ + device_matrix_view distances() const + { + return raft::make_const_mdspan(distances_.view()); + } + device_matrix_view distances() { return distances_.view(); } + + /** Returns the size of the batch */ + int64_t batch_size() const { return indices().extent(1); } + + protected: + raft::device_matrix indices_; + raft::device_matrix distances_; +}; +} // namespace raft::neighbors diff --git a/cpp/src/neighbors/brute_force_knn_index_float.cu b/cpp/src/neighbors/brute_force_knn_index_float.cu index f2fda93a97..d4f902c087 100644 --- a/cpp/src/neighbors/brute_force_knn_index_float.cu +++ b/cpp/src/neighbors/brute_force_knn_index_float.cu @@ -36,4 +36,4 @@ template raft::neighbors::brute_force::index raft::neighbors::brute_force raft::resources const& res, raft::device_matrix_view dataset, raft::distance::DistanceType metric, - float metric_arg); \ No newline at end of file + float metric_arg); diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index ebde8e6d35..a84c9749d7 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -38,6 +38,7 @@ #include namespace raft::neighbors::brute_force { + struct TiledKNNInputs { int num_queries; int num_db_vecs; @@ -190,11 +191,13 @@ class TiledKNNTest : public ::testing::TestWithParam { metric, metric_arg); + auto query_view = raft::make_device_matrix_view( + search_queries.data(), params_.num_queries, params_.dim); + raft::neighbors::brute_force::search( handle_, idx, - raft::make_device_matrix_view( - search_queries.data(), params_.num_queries, params_.dim), + query_view, raft::make_device_matrix_view( raft_indices_.data(), params_.num_queries, params_.k), raft::make_device_matrix_view( @@ -209,6 +212,73 @@ class TiledKNNTest : public ::testing::TestWithParam { float(0.001), stream_, true)); + // also test out the batch api. First get new reference results (all k, up to a certain + // max size) + auto all_size = std::min(params_.num_db_vecs, 1024); + auto all_indices = raft::make_device_matrix(handle_, num_queries, all_size); + auto all_distances = raft::make_device_matrix(handle_, num_queries, all_size); + raft::neighbors::brute_force::search( + handle_, idx, query_view, all_indices.view(), all_distances.view()); + + int64_t offset = 0; + auto query = make_batch_k_query(handle_, idx, query_view, k_); + for (auto batch : *query) { + auto batch_size = batch.batch_size(); + auto indices = raft::make_device_matrix(handle_, num_queries, batch_size); + auto distances = raft::make_device_matrix(handle_, num_queries, batch_size); + + matrix::slice_coordinates coords{0, offset, num_queries, offset + batch_size}; + + matrix::slice(handle_, raft::make_const_mdspan(all_indices.view()), indices.view(), coords); + matrix::slice( + handle_, raft::make_const_mdspan(all_distances.view()), distances.view(), coords); + + ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices.data_handle(), + batch.indices().data_handle(), + distances.data_handle(), + batch.distances().data_handle(), + num_queries, + batch_size, + float(0.001), + stream_, + true)); + + offset += batch_size; + if (offset + batch_size > all_size) break; + } + + // also test out with variable batch sizes + offset = 0; + int64_t batch_size = k_; + query = make_batch_k_query(handle_, idx, query_view, batch_size); + for (auto it = query->begin(); it != query->end(); it.advance(batch_size)) { + // batch_size could be less than requested (in the case of final batch). handle. + ASSERT_TRUE(it->indices().extent(1) <= batch_size); + batch_size = it->indices().extent(1); + + auto indices = raft::make_device_matrix(handle_, num_queries, batch_size); + auto distances = raft::make_device_matrix(handle_, num_queries, batch_size); + + matrix::slice_coordinates coords{0, offset, num_queries, offset + batch_size}; + matrix::slice(handle_, raft::make_const_mdspan(all_indices.view()), indices.view(), coords); + matrix::slice( + handle_, raft::make_const_mdspan(all_distances.view()), distances.view(), coords); + + ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices.data_handle(), + it->indices().data_handle(), + distances.data_handle(), + it->distances().data_handle(), + num_queries, + batch_size, + float(0.001), + stream_, + true)); + + offset += batch_size; + if (offset + batch_size > all_size) break; + + batch_size += 2; + } } }