diff --git a/build.sh b/build.sh index e8dfa3e404..0caa823ca7 100755 --- a/build.sh +++ b/build.sh @@ -19,7 +19,7 @@ ARGS=$* REPODIR=$(cd $(dirname $0); pwd) VALIDARGS="clean libraft pylibraft raft-dask docs tests bench clean -v -g --install --compile-libs --compile-nn --compile-dist --allgpuarch --no-nvtx --show_depr_warn -h --buildfaiss --minimal-deps" -HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=] +HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=] [--limit-tests=] where is: clean - remove all existing build artifacts and configuration (start over) libraft - build the raft C++ code only. Also builds the C-wrapper library @@ -40,6 +40,7 @@ HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=\"] [--cache-tool= +/** + * Builds and populates a previously unbuilt BallCoverIndex + * @tparam idx_t knn index type + * @tparam value_t knn value type + * @tparam int_t integral type for knn params + * @tparam matrix_idx_t matrix indexing type + * @param[in] handle library resource management handle + * @param[inout] index an empty (and not previous built) instance of BallCoverIndex + */ +template void rbc_build_index(const raft::handle_t& handle, - BallCoverIndex& index) + BallCoverIndex& index) { ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); if (index.metric == raft::distance::DistanceType::Haversine) { - detail::rbc_build_index(handle, index, detail::HaversineFunc()); + detail::rbc_build_index(handle, index, detail::HaversineFunc()); } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { - detail::rbc_build_index(handle, index, detail::EuclideanFunc()); + detail::rbc_build_index(handle, index, detail::EuclideanFunc()); } else { RAFT_FAIL("Metric not support"); } @@ -55,18 +64,18 @@ void rbc_build_index(const raft::handle_t& handle, * the index and query are the same array. This function will * build the index and assumes rbc_build_index() has not already * been called. - * @tparam value_idx knn index type + * @tparam idx_t knn index type * @tparam value_t knn distance type - * @tparam value_int type for integers, such as number of rows/cols - * @param handle raft handle for resource management - * @param index ball cover index which has not yet been built - * @param k number of nearest neighbors to find - * @param perform_post_filtering if this is false, only the closest k landmarks + * @tparam int_t type for integers, such as number of rows/cols + * @param[in] handle raft handle for resource management + * @param[inout] index ball cover index which has not yet been built + * @param[in] k number of nearest neighbors to find + * @param[in] perform_post_filtering if this is false, only the closest k landmarks * are considered (which will return approximate * results). * @param[out] inds output knn indices * @param[out] dists output knn distances - * @param weight a weight for overlap between the closest landmark and + * @param[in] weight a weight for overlap between the closest landmark and * the radius of other landmarks when pruning distances. * Setting this value below 1 can effectively turn off * computing distances against many other balls, enabling @@ -75,11 +84,11 @@ void rbc_build_index(const raft::handle_t& handle, * many datasets can still have great recall even by only * looking in the closest landmark. */ -template +template void rbc_all_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, - value_int k, - value_idx* inds, + BallCoverIndex& index, + int_t k, + idx_t* inds, value_t* dists, bool perform_post_filtering = true, float weight = 1.0) @@ -91,7 +100,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, k, inds, dists, - detail::HaversineFunc(), + detail::HaversineFunc(), perform_post_filtering, weight); } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || @@ -101,7 +110,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, k, inds, dists, - detail::EuclideanFunc(), + detail::EuclideanFunc(), perform_post_filtering, weight); } else { @@ -111,6 +120,58 @@ void rbc_all_knn_query(const raft::handle_t& handle, index.set_index_trained(); } +/** + * Performs a faster exact knn in metric spaces using the triangle + * inequality with a number of landmark points to reduce the + * number of distance computations from O(n^2) to O(sqrt(n)). This + * performs an all neighbors knn, which can reuse memory when + * the index and query are the same array. This function will + * build the index and assumes rbc_build_index() has not already + * been called. + * @tparam idx_t knn index type + * @tparam value_t knn distance type + * @tparam int_t type for integers, such as number of rows/cols + * @tparam matrix_idx_t matrix indexing type + * @param[in] handle raft handle for resource management + * @param[in] index ball cover index which has not yet been built + * @param[out] inds output knn indices + * @param[out] dists output knn distances + * @param[in] k number of nearest neighbors to find + * @param[in] perform_post_filtering if this is false, only the closest k landmarks + * are considered (which will return approximate + * results). + * @param[in] weight a weight for overlap between the closest landmark and + * the radius of other landmarks when pruning distances. + * Setting this value below 1 can effectively turn off + * computing distances against many other balls, enabling + * approximate nearest neighbors. Recall can be adjusted + * based on how many relevant balls are ignored. Note that + * many datasets can still have great recall even by only + * looking in the closest landmark. + */ +template +void rbc_all_knn_query(const raft::handle_t& handle, + BallCoverIndex& index, + raft::device_matrix_view inds, + raft::device_matrix_view dists, + int_t k, + bool perform_post_filtering = true, + float weight = 1.0) +{ + RAFT_EXPECTS(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); + RAFT_EXPECTS(k <= index.m, + "k must be less than or equal to the number of data points in the index"); + RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), + "Number of columns in output indices and distances matrices must be equal to k"); + + RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == index.get_X().extent(0), + "Number of rows in output indices and distances matrices must equal number of rows " + "in index matrix."); + + rbc_all_knn_query( + handle, index, k, inds.data_handle(), dists.data_handle(), perform_post_filtering, weight); +} + /** * Performs a faster exact knn in metric spaces using the triangle * inequality with a number of landmark points to reduce the @@ -118,19 +179,19 @@ void rbc_all_knn_query(const raft::handle_t& handle, * function does not build the index and assumes rbc_build_index() has * already been called. Use this function when the index and * query arrays are different, otherwise use rbc_all_knn_query(). - * @tparam value_idx index type + * @tparam idx_t index type * @tparam value_t distances type - * @tparam value_int integer type for size info - * @param handle raft handle for resource management - * @param index ball cover index which has not yet been built - * @param k number of nearest neighbors to find - * @param query the - * @param perform_post_filtering if this is false, only the closest k landmarks + * @tparam int_t integer type for size info + * @param[in] handle raft handle for resource management + * @param[inout] index ball cover index which has not yet been built + * @param[in] k number of nearest neighbors to find + * @param[in] query the + * @param[in] perform_post_filtering if this is false, only the closest k landmarks * are considered (which will return approximate * results). * @param[out] inds output knn indices * @param[out] dists output knn distances - * @param weight a weight for overlap between the closest landmark and + * @param[in] weight a weight for overlap between the closest landmark and * the radius of other landmarks when pruning distances. * Setting this value below 1 can effectively turn off * computing distances against many other balls, enabling @@ -140,13 +201,13 @@ void rbc_all_knn_query(const raft::handle_t& handle, * looking in the closest landmark. * @param[in] n_query_pts number of query points */ -template +template void rbc_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, - value_int k, + BallCoverIndex& index, + int_t k, const value_t* query, - value_int n_query_pts, - value_idx* inds, + int_t n_query_pts, + idx_t* inds, value_t* dists, bool perform_post_filtering = true, float weight = 1.0) @@ -160,7 +221,7 @@ void rbc_knn_query(const raft::handle_t& handle, n_query_pts, inds, dists, - detail::HaversineFunc(), + detail::HaversineFunc(), perform_post_filtering, weight); } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || @@ -172,7 +233,7 @@ void rbc_knn_query(const raft::handle_t& handle, n_query_pts, inds, dists, - detail::EuclideanFunc(), + detail::EuclideanFunc(), perform_post_filtering, weight); } else { @@ -180,6 +241,68 @@ void rbc_knn_query(const raft::handle_t& handle, } } +/** + * Performs a faster exact knn in metric spaces using the triangle + * inequality with a number of landmark points to reduce the + * number of distance computations from O(n^2) to O(sqrt(n)). This + * function does not build the index and assumes rbc_build_index() has + * already been called. Use this function when the index and + * query arrays are different, otherwise use rbc_all_knn_query(). + * @tparam idx_t index type + * @tparam value_t distances type + * @tparam int_t integer type for size info + * @tparam matrix_idx_t + * @param[in] handle raft handle for resource management + * @param[in] index ball cover index which has not yet been built + * @param[in] query device matrix containing query data points + * @param[out] inds output knn indices + * @param[out] dists output knn distances + * @param[in] k number of nearest neighbors to find + * @param[in] perform_post_filtering if this is false, only the closest k landmarks + * are considered (which will return approximate + * results). + * @param[in] weight a weight for overlap between the closest landmark and + * the radius of other landmarks when pruning distances. + * Setting this value below 1 can effectively turn off + * computing distances against many other balls, enabling + * approximate nearest neighbors. Recall can be adjusted + * based on how many relevant balls are ignored. Note that + * many datasets can still have great recall even by only + * looking in the closest landmark. + */ +template +void rbc_knn_query(const raft::handle_t& handle, + BallCoverIndex& index, + raft::device_matrix_view query, + raft::device_matrix_view inds, + raft::device_matrix_view dists, + int_t k, + bool perform_post_filtering = true, + float weight = 1.0) +{ + RAFT_EXPECTS(k <= index.m, + "k must be less than or equal to the number of data points in the index"); + RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), + "Number of columns in output indices and distances matrices must be equal to k"); + + RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == query.extent(0), + "Number of rows in output indices and distances matrices must equal number of rows " + "in search matrix."); + + RAFT_EXPECTS(query.extent(1) == index.get_R().extent(1), + "Number of columns in query and index matrices must match."); + + rbc_knn_query(handle, + index, + k, + query.data_handle(), + query.extent(0), + inds.data_handle(), + dists.data_handle(), + perform_post_filtering, + weight); +} + // TODO: implement functions for: // 4. rbc_eps_neigh() - given a populated index, perform query against different query array // 5. rbc_all_eps_neigh() - populate a BallCoverIndex and query against training data diff --git a/cpp/include/raft/spatial/knn/ball_cover_types.hpp b/cpp/include/raft/spatial/knn/ball_cover_types.hpp index 9870217011..1dd45365b7 100644 --- a/cpp/include/raft/spatial/knn/ball_cover_types.hpp +++ b/cpp/include/raft/spatial/knn/ball_cover_types.hpp @@ -17,6 +17,8 @@ #pragma once #include +#include +#include #include #include #include @@ -34,7 +36,10 @@ namespace knn { * @tparam value_t * @tparam value_int */ -template +template class BallCoverIndex { public: explicit BallCoverIndex(const raft::handle_t& handle_, @@ -43,7 +48,7 @@ class BallCoverIndex { value_int n_, raft::distance::DistanceType metric_) : handle(handle_), - X(X_), + X(raft::make_device_matrix_view(X_, m_, n_)), m(m_), n(n_), metric(metric_), @@ -53,23 +58,55 @@ class BallCoverIndex { * Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m) */ n_landmarks(sqrt(m_)), - R_indptr(sqrt(m_) + 1, handle.get_stream()), - R_1nn_cols(m_, handle.get_stream()), - R_1nn_dists(m_, handle.get_stream()), - R_closest_landmark_dists(m_, handle.get_stream()), - R(sqrt(m_) * n_, handle.get_stream()), - R_radius(sqrt(m_), handle.get_stream()), + R_indptr(std::move(raft::make_device_vector(handle, sqrt(m_) + 1))), + R_1nn_cols(std::move(raft::make_device_vector(handle, m_))), + R_1nn_dists(std::move(raft::make_device_vector(handle, m_))), + R_closest_landmark_dists( + std::move(raft::make_device_vector(handle, m_))), + R(std::move(raft::make_device_matrix(handle, sqrt(m_), n_))), + R_radius(std::move(raft::make_device_vector(handle, sqrt(m_)))), index_trained(false) { } - value_idx* get_R_indptr() { return R_indptr.data(); } - value_idx* get_R_1nn_cols() { return R_1nn_cols.data(); } - value_t* get_R_1nn_dists() { return R_1nn_dists.data(); } - value_t* get_R_radius() { return R_radius.data(); } - value_t* get_R() { return R.data(); } - value_t* get_R_closest_landmark_dists() { return R_closest_landmark_dists.data(); } - const value_t* get_X() { return X; } + explicit BallCoverIndex(const raft::handle_t& handle_, + raft::device_matrix_view X_, + raft::distance::DistanceType metric_) + : handle(handle_), + X(X_), + m(X_.extent(0)), + n(X_.extent(1)), + metric(metric_), + /** + * the sqrt() here makes the sqrt(m)^2 a linear-time lower bound + * + * Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m) + */ + n_landmarks(sqrt(X_.extent(0))), + R_indptr( + std::move(raft::make_device_vector(handle, sqrt(X_.extent(0)) + 1))), + R_1nn_cols(std::move(raft::make_device_vector(handle, X_.extent(0)))), + R_1nn_dists(std::move(raft::make_device_vector(handle, X_.extent(0)))), + R_closest_landmark_dists( + std::move(raft::make_device_vector(handle, X_.extent(0)))), + R(std::move( + raft::make_device_matrix(handle, sqrt(X_.extent(0)), X_.extent(1)))), + R_radius( + std::move(raft::make_device_vector(handle, sqrt(X_.extent(0))))), + index_trained(false) + { + } + + raft::device_vector_view get_R_indptr() { return R_indptr.view(); } + raft::device_vector_view get_R_1nn_cols() { return R_1nn_cols.view(); } + raft::device_vector_view get_R_1nn_dists() { return R_1nn_dists.view(); } + raft::device_vector_view get_R_radius() { return R_radius.view(); } + raft::device_matrix_view get_R() { return R.view(); } + raft::device_vector_view get_R_closest_landmark_dists() + { + return R_closest_landmark_dists.view(); + } + raft::device_matrix_view get_X() { return X; } bool is_index_trained() const { return index_trained; }; @@ -82,20 +119,20 @@ class BallCoverIndex { const value_int n; const value_int n_landmarks; - const value_t* X; + raft::device_matrix_view X; raft::distance::DistanceType metric; private: // CSR storing the neighborhoods for each data point - rmm::device_uvector R_indptr; - rmm::device_uvector R_1nn_cols; - rmm::device_uvector R_1nn_dists; - rmm::device_uvector R_closest_landmark_dists; + raft::device_vector R_indptr; + raft::device_vector R_1nn_cols; + raft::device_vector R_1nn_dists; + raft::device_vector R_closest_landmark_dists; - rmm::device_uvector R_radius; + raft::device_vector R_radius; - rmm::device_uvector R; + raft::device_matrix R; protected: bool index_trained; diff --git a/cpp/include/raft/spatial/knn/brute_force.cuh b/cpp/include/raft/spatial/knn/brute_force.cuh new file mode 100644 index 0000000000..c32a33d2e2 --- /dev/null +++ b/cpp/include/raft/spatial/knn/brute_force.cuh @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "detail/knn_brute_force_faiss.cuh" +#include "detail/selection_faiss.cuh" +#include + +namespace raft::spatial::knn { + +/** + * @brief Performs a k-select across row partitioned index/distance + * matrices formatted like the following: + * row1: k0, k1, k2 + * row2: k0, k1, k2 + * row3: k0, k1, k2 + * row1: k0, k1, k2 + * row2: k0, k1, k2 + * row3: k0, k1, k2 + * + * etc... + * + * @tparam idx_t + * @tparam value_t + * @param[in] handle + * @param[in] in_keys matrix of input keys (size n_samples * n_parts * k) + * @param[in] in_values matrix of input values (size n_samples * n_parts * k) + * @param[out] out_keys matrix of output keys (size n_samples * k) + * @param[out] out_values matrix of output values (size n_samples * k) + * @param[in] n_samples number of rows in each part + * @param[in] k number of neighbors for each part + * @param[in] translations optional vector of starting index mappings for each partition + */ +template +inline void knn_merge_parts( + const raft::handle_t& handle, + raft::device_matrix_view in_keys, + raft::device_matrix_view in_values, + raft::device_matrix_view out_keys, + raft::device_matrix_view out_values, + size_t n_samples, + int k, + std::optional> translations = std::nullopt) +{ + RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0), + "in_keys and in_values must have the same shape."); + RAFT_EXPECTS( + out_keys.extent(0) == out_values.extent(0) == n_samples, + "Number of rows in output keys and val matrices must equal number of rows in search matrix."); + RAFT_EXPECTS(out_keys.extent(1) == out_values.extent(1) == k, + "Number of columns in output indices and distances matrices must be equal to k"); + + auto n_parts = in_keys.extent(0) / n_samples; + detail::knn_merge_parts(in_keys.data_handle(), + in_values.data_handle(), + out_keys.data_handle(), + out_values.data_handle(), + n_samples, + n_parts, + k, + handle.get_stream(), + translations.value_or(nullptr)); +} + +/** + * @brief Flat C++ API function to perform a brute force knn on + * a series of input arrays and combine the results into a single + * output array for indexes and distances. Inputs can be either + * row- or column-major but the output matrices will always be in + * row-major format. + * + * @param[in] handle the cuml handle to use + * @param[in] index vector of device matrices (each size m_i*d) to be used as the knn index + * @param[in] search matrix (size n*d) to be used for searching the index + * @param[out] indices matrix (size n*k) to store output knn indices + * @param[out] distances matrix (size n*k) to store the output knn distance + * @param[in] k the number of nearest neighbors to return + * @param[in] metric distance metric to use. Euclidean (L2) is used by default + * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This + * is ignored if the metric_type is not Minkowski. + * @param[in] translations starting offsets for partitions. should be the same size + * as input vector. + */ +template +void brute_force_knn( + raft::handle_t const& handle, + std::vector> index, + raft::device_matrix_view search, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + value_int k, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + std::optional metric_arg = std::make_optional(2.0f), + std::optional> translations = std::nullopt) +{ + RAFT_EXPECTS(index[0].extent(1) == search.extent(1), + "Number of dimensions for both index and search matrices must be equal"); + + RAFT_EXPECTS(indices.extent(0) == distances.extent(0) && distances.extent(0) == search.extent(0), + "Number of rows in output indices and distances matrices must equal number of rows " + "in search matrix."); + RAFT_EXPECTS( + indices.extent(1) == distances.extent(1) && distances.extent(1) == static_cast(k), + "Number of columns in output indices and distances matrices must be equal to k"); + + bool rowMajorIndex = std::is_same_v; + bool rowMajorQuery = std::is_same_v; + + std::vector inputs; + std::vector sizes; + for (std::size_t i = 0; i < index.size(); ++i) { + inputs.push_back(const_cast(index[i].data_handle())); + sizes.push_back(index[i].extent(0)); + } + + std::vector* trans = translations.has_value() ? &(*translations) : nullptr; + + detail::brute_force_knn_impl(handle, + inputs, + sizes, + static_cast(index[0].extent(1)), + // TODO: This is unfortunate. Need to fix. + const_cast(search.data_handle()), + static_cast(search.extent(0)), + indices.data_handle(), + distances.data_handle(), + k, + rowMajorIndex, + rowMajorQuery, + trans, + metric, + metric_arg.value_or(2.0f)); +} + +} // namespace raft::spatial::knn diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 457e1f495a..e65a895f60 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -75,8 +75,8 @@ void sample_landmarks(const raft::handle_t& handle, rmm::device_uvector R_indices(index.n_landmarks, handle.get_stream()); thrust::sequence(handle.get_thrust_policy(), - index.get_R_1nn_cols(), - index.get_R_1nn_cols() + index.m, + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_cols().data_handle() + index.m, (value_idx)0); thrust::fill( @@ -93,15 +93,15 @@ void sample_landmarks(const raft::handle_t& handle, rng_state, R_indices.data(), R_1nn_cols2.data(), - index.get_R_1nn_cols(), + index.get_R_1nn_cols().data_handle(), R_1nn_ones.data(), (value_idx)index.n_landmarks, (value_idx)index.m); - raft::matrix::copyRows(index.get_X(), + raft::matrix::copyRows(index.get_X().data_handle(), index.m, index.n, - index.get_R(), + index.get_R().data_handle(), R_1nn_cols2.data(), index.n_landmarks, handle.get_stream(), @@ -133,7 +133,7 @@ void construct_landmark_1nn(const raft::handle_t& handle, std::numeric_limits::max()); value_idx* R_1nn_inds_ptr = R_1nn_inds.data(); - value_t* R_1nn_dists_ptr = index.get_R_1nn_dists(); + value_t* R_1nn_dists_ptr = index.get_R_1nn_dists().data_handle(); auto idxs = thrust::make_counting_iterator(0); thrust::for_each(handle.get_thrust_policy(), idxs, idxs + index.m, [=] __device__(value_idx i) { @@ -141,16 +141,22 @@ void construct_landmark_1nn(const raft::handle_t& handle, R_1nn_dists_ptr[i] = R_knn_dists_ptr[i * k]; }); - auto keys = - thrust::make_zip_iterator(thrust::make_tuple(R_1nn_inds.data(), index.get_R_1nn_dists())); + auto keys = thrust::make_zip_iterator( + thrust::make_tuple(R_1nn_inds.data(), index.get_R_1nn_dists().data_handle())); // group neighborhoods for each reference landmark and sort each group by distance - thrust::sort_by_key( - handle.get_thrust_policy(), keys, keys + index.m, index.get_R_1nn_cols(), NNComp()); + thrust::sort_by_key(handle.get_thrust_policy(), + keys, + keys + index.m, + index.get_R_1nn_cols().data_handle(), + NNComp()); // convert to CSR for fast lookup - raft::sparse::convert::sorted_coo_to_csr( - R_1nn_inds.data(), index.m, index.get_R_indptr(), index.n_landmarks + 1, handle.get_stream()); + raft::sparse::convert::sorted_coo_to_csr(R_1nn_inds.data(), + index.m, + index.get_R_indptr().data_handle(), + index.n_landmarks + 1, + handle.get_stream()); } /** @@ -175,7 +181,7 @@ void k_closest_landmarks(const raft::handle_t& handle, value_idx* R_knn_inds, value_t* R_knn_dists) { - std::vector input = {index.get_R()}; + std::vector input = {index.get_R().data_handle()}; std::vector sizes = {index.n_landmarks}; brute_force_knn_impl(handle, @@ -207,9 +213,9 @@ void compute_landmark_radii(const raft::handle_t& handle, { auto entries = thrust::make_counting_iterator(0); - const value_idx* R_indptr_ptr = index.get_R_indptr(); - const value_t* R_1nn_dists_ptr = index.get_R_1nn_dists(); - value_t* R_radius_ptr = index.get_R_radius(); + const value_idx* R_indptr_ptr = index.get_R_indptr().data_handle(); + const value_t* R_1nn_dists_ptr = index.get_R_1nn_dists().data_handle(); + value_t* R_radius_ptr = index.get_R_radius().data_handle(); thrust::for_each(handle.get_thrust_policy(), entries, entries + index.n_landmarks, @@ -350,8 +356,8 @@ void rbc_build_index(const raft::handle_t& handle, R_knn_inds.end(), std::numeric_limits::max()); thrust::fill(handle.get_thrust_policy(), - index.get_R_closest_landmark_dists(), - index.get_R_closest_landmark_dists() + index.m, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_closest_landmark_dists().data_handle() + index.m, std::numeric_limits::max()); /** @@ -365,11 +371,11 @@ void rbc_build_index(const raft::handle_t& handle, value_int k = 1; k_closest_landmarks(handle, index, - index.get_X(), + index.get_X().data_handle(), index.m, k, R_knn_inds.data(), - index.get_R_closest_landmark_dists()); + index.get_R_closest_landmark_dists().data_handle()); /** * 3. Create L_r = knn[:,0].T (CSR) @@ -377,7 +383,8 @@ void rbc_build_index(const raft::handle_t& handle, * Slice closest neighboring R * Secondary sort by (R_knn_inds, R_knn_dists) */ - construct_landmark_1nn(handle, R_knn_inds.data(), index.get_R_closest_landmark_dists(), k, index); + construct_landmark_1nn( + handle, R_knn_inds.data(), index.get_R_closest_landmark_dists().data_handle(), k, index); /** * Compute radius of each R for filtering: p(q, r) <= p(q, q_r) + radius(r) @@ -432,7 +439,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, sample_landmarks(handle, index); k_closest_landmarks( - handle, index, index.get_X(), index.m, k, R_knn_inds.data(), R_knn_dists.data()); + handle, index, index.get_X().data_handle(), index.m, k, R_knn_inds.data(), R_knn_dists.data()); construct_landmark_1nn(handle, R_knn_inds.data(), R_knn_dists.data(), k, index); @@ -440,7 +447,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, perform_rbc_query(handle, index, - index.get_X(), + index.get_X().data_handle(), index.m, k, R_knn_inds.data(), diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index 88f5aa3460..c0056e7137 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -486,114 +486,114 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, { if (k <= 32) block_rbc_kernel_registers - <<>>(index.get_X(), + <<>>(index.get_X().data_handle(), query, index.n, R_knn_inds, R_knn_dists, index.m, k, - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), inds, dists, dists_counter, - index.get_R_radius(), + index.get_R_radius().data_handle(), dfunc, weight); else if (k <= 64) block_rbc_kernel_registers - <<>>(index.get_X(), + <<>>(index.get_X().data_handle(), query, index.n, R_knn_inds, R_knn_dists, index.m, k, - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), inds, dists, dists_counter, - index.get_R_radius(), + index.get_R_radius().data_handle(), dfunc, weight); else if (k <= 128) block_rbc_kernel_registers - <<>>(index.get_X(), + <<>>(index.get_X().data_handle(), query, index.n, R_knn_inds, R_knn_dists, index.m, k, - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), inds, dists, dists_counter, - index.get_R_radius(), + index.get_R_radius().data_handle(), dfunc, weight); else if (k <= 256) block_rbc_kernel_registers - <<>>(index.get_X(), + <<>>(index.get_X().data_handle(), query, index.n, R_knn_inds, R_knn_dists, index.m, k, - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), inds, dists, dists_counter, - index.get_R_radius(), + index.get_R_radius().data_handle(), dfunc, weight); else if (k <= 512) block_rbc_kernel_registers - <<>>(index.get_X(), + <<>>(index.get_X().data_handle(), query, index.n, R_knn_inds, R_knn_dists, index.m, k, - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), inds, dists, dists_counter, - index.get_R_radius(), + index.get_R_radius().data_handle(), dfunc, weight); else if (k <= 1024) block_rbc_kernel_registers - <<>>(index.get_X(), + <<>>(index.get_X().data_handle(), query, index.n, R_knn_inds, R_knn_dists, index.m, k, - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), inds, dists, dists_counter, - index.get_R_radius(), + index.get_R_radius().data_handle(), dfunc, weight); } @@ -627,8 +627,8 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, index.n, R_knn_inds, R_knn_dists, - index.get_R_radius(), - index.get_R(), + index.get_R_radius().data_handle(), + index.get_R().data_handle(), index.n_landmarks, bitset_size, k, @@ -645,22 +645,22 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, 32, 2, 128, - dims> - <<>>(index.get_X(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists(), - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); + dims><<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); else if (k <= 64) compute_final_dists_registers - <<>>(index.get_X(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists(), - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); + dims><<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); else if (k <= 128) compute_final_dists_registers - <<>>(index.get_X(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists(), - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); + dims><<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); else if (k <= 256) compute_final_dists_registers - <<>>(index.get_X(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists(), - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); + dims><<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); else if (k <= 512) compute_final_dists_registers - <<>>(index.get_X(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists(), - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); + dims><<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); else if (k <= 1024) compute_final_dists_registers - <<>>(index.get_X(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists(), - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); + dims><<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); } }; // namespace detail diff --git a/cpp/include/raft/spatial/knn/epsilon_neighborhood.cuh b/cpp/include/raft/spatial/knn/epsilon_neighborhood.cuh index 29ed51fb3d..53fe76fada 100644 --- a/cpp/include/raft/spatial/knn/epsilon_neighborhood.cuh +++ b/cpp/include/raft/spatial/knn/epsilon_neighborhood.cuh @@ -19,6 +19,7 @@ #pragma once +#include #include namespace raft { @@ -28,8 +29,8 @@ namespace knn { /** * @brief Computes epsilon neighborhood for the L2-Squared distance metric * - * @tparam DataT IO and math type - * @tparam IdxT Index type + * @tparam value_t IO and math type + * @tparam idx_t Index type * * @param[out] adj adjacency matrix [row-major] [on device] [dim = m x n] * @param[out] vd vertex degree array [on device] [len = m + 1] @@ -44,19 +45,56 @@ namespace knn { * squared as we compute L2-squared distance in this method) * @param[in] stream cuda stream */ -template +template void epsUnexpL2SqNeighborhood(bool* adj, - IdxT* vd, - const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - DataT eps, + idx_t* vd, + const value_t* x, + const value_t* y, + idx_t m, + idx_t n, + idx_t k, + value_t eps, cudaStream_t stream) { - detail::epsUnexpL2SqNeighborhood(adj, vd, x, y, m, n, k, eps, stream); + detail::epsUnexpL2SqNeighborhood(adj, vd, x, y, m, n, k, eps, stream); } + +/** + * @brief Computes epsilon neighborhood for the L2-Squared distance metric + * + * @tparam value_t IO and math type + * @tparam idx_t Index type + * @tparam matrix_idx_t matrix indexing type + * + * @param[in] handle raft handle to manage library resources + * @param[in] x first matrix [row-major] [on device] [dim = m x k] + * @param[in] y second matrix [row-major] [on device] [dim = n x k] + * @param[out] adj adjacency matrix [row-major] [on device] [dim = m x n] + * @param[out] vd vertex degree array [on device] [len = m + 1] + * `vd + m` stores the total number of edges in the adjacency + * matrix. Pass a nullptr if you don't need this info. + * @param[in] eps defines epsilon neighborhood radius (should be passed as + * squared as we compute L2-squared distance in this method) + */ +template +void eps_neighbors_l2sq(const raft::handle_t& handle, + raft::device_matrix_view x, + raft::device_matrix_view y, + raft::device_matrix_view adj, + raft::device_vector_view vd, + value_t eps) +{ + epsUnexpL2SqNeighborhood(adj.data_handle(), + vd.data_handle(), + x.data_handle(), + y.data_handle(), + x.extent(0), + y.extent(0), + x.extent(1), + eps, + handle.get_stream()); +} + } // namespace knn } // namespace spatial } // namespace raft diff --git a/cpp/include/raft/spatial/knn/ivf_flat.cuh b/cpp/include/raft/spatial/knn/ivf_flat.cuh index 09bd8edd85..58ca96d392 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat.cuh +++ b/cpp/include/raft/spatial/knn/ivf_flat.cuh @@ -22,6 +22,7 @@ #include +#include #include #include @@ -51,15 +52,15 @@ namespace raft::spatial::knn::ivf_flat { * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param params configure the index building + * @param[in] handle + * @param[in] params configure the index building * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] - * @param n_rows the number of samples - * @param dim the dimensionality of the data + * @param[in] n_rows the number of samples + * @param[in] dim the dimensionality of the data * * @return the constructed ivf-flat index */ -template +template inline auto build( const handle_t& handle, const index_params& params, const T* dataset, IdxT n_rows, uint32_t dim) -> index @@ -67,6 +68,50 @@ inline auto build( return raft::spatial::knn::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); } +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace raft::spatial::knn; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_flat::build(handle, index_params, dataset, N, D); + * // use default search parameters + * ivf_flat::search_params search_params; + * // search K nearest neighbours for each of the N queries + * ivf_flat::search(handle, search_params, index, queries, N, K, out_inds, out_dists); + * @endcode + * + * @tparam value_t data element type + * @tparam idx_t type of the indices in the source dataset + * @tparam int_t precision / type of integral arguments + * @tparam matrix_idx_t matrix indexing type + * + * @param[in] handle + * @param[in] params configure the index building + * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-flat index + */ +template +auto build_index(const handle_t& handle, + raft::device_matrix_view dataset, + const index_params& params) -> index +{ + return raft::spatial::knn::ivf_flat::detail::build(handle, + params, + dataset.data_handle(), + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); +} + /** * @brief Build a new index containing the data of the original plus new extra vectors. * @@ -89,13 +134,13 @@ inline auto build( * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param orig_index original index + * @param[in] handle + * @param[in] orig_index original index * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device pointer to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. - * @param n_rows the number of samples + * @param[in] n_rows number of rows in `new_vectors` * * @return the constructed extended ivf-flat index */ @@ -110,6 +155,54 @@ inline auto extend(const handle_t& handle, handle, orig_index, new_vectors, new_indices, n_rows); } +/** + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to existing kmeans clusters, then the cluster + * centers are adjusted to match the newly labeled data. + * + * Usage example: + * @code{.cpp} + * using namespace raft::spatial::knn; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); + * // fill the index with the data + * auto index = ivf_flat::extend(handle, index_empty, dataset, nullptr, N); + * @endcode + * + * @tparam value_t data element type + * @tparam idx_t type of the indices in the source dataset + * @tparam int_t precision / type of integral arguments + * @tparam matrix_idx_t matrix indexing type + * + * @param[in] handle + * @param[in] orig_index original index + * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device pointer to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * here to imply a continuous range `[0...n_rows)`. + * + * @return the constructed extended ivf-flat index + */ +template +auto extend(const handle_t& handle, + const index& orig_index, + raft::device_matrix_view new_vectors, + std::optional> new_indices = std::nullopt) + -> index +{ + return raft::spatial::knn::ivf_flat::detail::extend( + handle, + orig_index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); +} + /** * @brief Extend the index with the new data. * * @@ -122,7 +215,7 @@ inline auto extend(const handle_t& handle, * @param[in] new_indices a device pointer to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. - * @param n_rows the number of samples + * @param[in] n_rows the number of samples */ template inline void extend(const handle_t& handle, @@ -134,6 +227,34 @@ inline void extend(const handle_t& handle, *index = extend(handle, *index, new_vectors, new_indices, n_rows); } +/** + * @brief Extend the index with the new data. + * * + * @tparam value_t data element type + * @tparam idx_t type of the indices in the source dataset + * @tparam int_t precision / type of integral arguments + * @tparam matrix_idx_t matrix indexing type + * + * @param[in] handle + * @param[inout] index + * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device pointer to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + */ +template +void extend(const handle_t& handle, + index* index, + raft::device_matrix_view new_vectors, + std::optional> new_indices = std::nullopt) +{ + *index = extend(handle, + *index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + static_cast(new_vectors.extent(0))); +} + /** * @brief Search ANN using the constructed index. * @@ -164,17 +285,17 @@ inline void extend(const handle_t& handle, * @tparam T data element type * @tparam IdxT type of the indices * - * @param handle - * @param params configure the search - * @param index ivf-flat constructed index + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-flat constructed index * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param n_queries the batch size - * @param k the number of neighbors to find for each query. + * @param[in] n_queries the batch size + * @param[in] k the number of neighbors to find for each query. * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param mr an optional memory resource to use across the searches (you can provide a large enough - * memory pool here to avoid memory allocations within search). + * @param[in] mr an optional memory resource to use across the searches (you can provide a large + * enough memory pool here to avoid memory allocations within search). */ template inline void search(const handle_t& handle, @@ -191,4 +312,76 @@ inline void search(const handle_t& handle, handle, params, index, queries, n_queries, k, neighbors, distances, mr); } +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // Create a pooling memory resource with a pre-defined initial size. + * rmm::mr::pool_memory_resource mr( + * rmm::mr::get_current_device_resource(), 1024 * 1024); + * // use default search parameters + * ivf_flat::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); + * ivf_flat::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr); + * ivf_flat::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr); + * ... + * @endcode + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam value_t data element type + * @tparam idx_t type of the indices + * @tparam int_t precision / type of integral arguments + * @tparam matrix_idx_t matrix indexing type + * + * @param[in] handle + * @param[in] index ivf-flat constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] params configure the search + * @param[in] k the number of neighbors to find for each query. + */ +template +void search(const handle_t& handle, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const search_params& params, + int_t k) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS( + neighbors.extent(1) == distances.extent(1) && neighbors.extent(1) == static_cast(k), + "Number of columns in output neighbors and distances matrices must equal k"); + + RAFT_EXPECTS(queries.extent(1) == index.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + return raft::spatial::knn::ivf_flat::detail::search(handle, + params, + index, + queries.data_handle(), + queries.extent(0), + k, + neighbors.data_handle(), + distances.data_handle(), + nullptr); +} + } // namespace raft::spatial::knn::ivf_flat diff --git a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh index deed59195b..95f7aab9da 100644 --- a/cpp/include/raft/spatial/knn/knn.cuh +++ b/cpp/include/raft/spatial/knn/knn.cuh @@ -18,6 +18,7 @@ #include "detail/knn_brute_force_faiss.cuh" #include "detail/selection_faiss.cuh" +#include #include "detail/topk/radix_topk.cuh" #include "detail/topk/warpsort_topk.cuh" @@ -224,4 +225,5 @@ void brute_force_knn(raft::handle_t const& handle, metric, metric_arg); } + } // namespace raft::spatial::knn diff --git a/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh b/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh index 0c35bf4b9c..c859f2c5ec 100644 --- a/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh @@ -25,15 +25,16 @@ namespace raft { namespace spatial { namespace knn { -extern template class BallCoverIndex; -extern template class BallCoverIndex; +extern template class BallCoverIndex; +extern template class BallCoverIndex; -extern template void rbc_build_index( - const raft::handle_t& handle, BallCoverIndex& index); +extern template void rbc_build_index( + const raft::handle_t& handle, + BallCoverIndex& index); extern template void rbc_knn_query( const raft::handle_t& handle, - BallCoverIndex& index, + BallCoverIndex& index, std::uint32_t k, const float* query, std::uint32_t n_query_pts, @@ -42,9 +43,9 @@ extern template void rbc_knn_query( bool perform_post_filtering, float weight); -extern template void rbc_all_knn_query( +extern template void rbc_all_knn_query( const raft::handle_t& handle, - BallCoverIndex& index, + BallCoverIndex& index, std::uint32_t k, std::int64_t* inds, float* dists, diff --git a/cpp/src/nn/specializations/ball_cover.cu b/cpp/src/nn/specializations/ball_cover.cu index 87796752d9..7473b65d25 100644 --- a/cpp/src/nn/specializations/ball_cover.cu +++ b/cpp/src/nn/specializations/ball_cover.cu @@ -28,15 +28,16 @@ namespace raft { namespace spatial { namespace knn { -template class BallCoverIndex; -template class BallCoverIndex; +template class BallCoverIndex; +template class BallCoverIndex; -template void rbc_build_index( - const raft::handle_t& handle, BallCoverIndex& index); +template void rbc_build_index( + const raft::handle_t& handle, + BallCoverIndex& index); template void rbc_knn_query( const raft::handle_t& handle, - BallCoverIndex& index, + BallCoverIndex& index, std::uint32_t k, const float* query, std::uint32_t n_query_pts, @@ -47,7 +48,7 @@ template void rbc_knn_query( template void rbc_all_knn_query( const raft::handle_t& handle, - BallCoverIndex& index, + BallCoverIndex& index, std::uint32_t k, std::int64_t* inds, float* dists, diff --git a/cpp/test/spatial/ann_ivf_flat.cu b/cpp/test/spatial/ann_ivf_flat.cu index 7cc217f789..241c4f6547 100644 --- a/cpp/test/spatial/ann_ivf_flat.cu +++ b/cpp/test/spatial/ann_ivf_flat.cu @@ -17,6 +17,7 @@ #include "../test_utils.h" #include "ann_utils.cuh" +#include #include #include #include @@ -38,22 +39,24 @@ namespace raft { namespace spatial { namespace knn { + +template struct AnnIvfFlatInputs { - int num_queries; - int num_db_vecs; - int dim; - int k; - int nprobe; - int nlist; + IdxT num_queries; + IdxT num_db_vecs; + IdxT dim; + IdxT k; + IdxT nprobe; + IdxT nlist; raft::distance::DistanceType metric; }; -template -class AnnIVFFlatTest : public ::testing::TestWithParam { +template +class AnnIVFFlatTest : public ::testing::TestWithParam> { public: AnnIVFFlatTest() : stream_(handle_.get_stream()), - ps(::testing::TestWithParam::GetParam()), + ps(::testing::TestWithParam>::GetParam()), database(0, stream_), search_queries(0, stream_) { @@ -63,24 +66,24 @@ class AnnIVFFlatTest : public ::testing::TestWithParam { void testIVFFlat() { size_t queries_size = ps.num_queries * ps.k; - std::vector indices_ivfflat(queries_size); - std::vector indices_naive(queries_size); + std::vector indices_ivfflat(queries_size); + std::vector indices_naive(queries_size); std::vector distances_ivfflat(queries_size); std::vector distances_naive(queries_size); { rmm::device_uvector distances_naive_dev(queries_size, stream_); - rmm::device_uvector indices_naive_dev(queries_size, stream_); - naiveBfKnn(distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database.data(), - ps.num_queries, - ps.num_db_vecs, - ps.dim, - ps.k, - ps.metric, - stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + naiveBfKnn(distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database.data(), + ps.num_queries, + ps.num_db_vecs, + ps.dim, + ps.k, + ps.metric, + stream_); update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); handle_.sync_stream(stream_); @@ -92,7 +95,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam { double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); rmm::device_uvector distances_ivfflat_dev(queries_size, stream_); - rmm::device_uvector indices_ivfflat_dev(queries_size, stream_); + rmm::device_uvector indices_ivfflat_dev(queries_size, stream_); { // legacy interface @@ -143,25 +146,30 @@ class AnnIVFFlatTest : public ::testing::TestWithParam { index_params.add_data_on_build = false; index_params.kmeans_trainset_fraction = 0.5; - auto index = - ivf_flat::build(handle_, index_params, database.data(), int64_t(ps.num_db_vecs), ps.dim); - rmm::device_uvector vector_indices(ps.num_db_vecs, stream_); + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + + auto index = ivf_flat::build_index(handle_, database_view, index_params); + + rmm::device_uvector vector_indices(ps.num_db_vecs, stream_); thrust::sequence(handle_.get_thrust_policy(), thrust::device_pointer_cast(vector_indices.data()), thrust::device_pointer_cast(vector_indices.data() + ps.num_db_vecs)); handle_.sync_stream(stream_); - int64_t half_of_data = ps.num_db_vecs / 2; + IdxT half_of_data = ps.num_db_vecs / 2; + + auto half_of_data_view = raft::make_device_matrix_view( + (const DataT*)database.data(), half_of_data, ps.dim); - auto index_2 = - ivf_flat::extend(handle_, index, database.data(), nullptr, half_of_data); + auto index_2 = ivf_flat::extend(handle_, index, half_of_data_view); - ivf_flat::extend(handle_, - &index_2, - database.data() + half_of_data * ps.dim, - vector_indices.data() + half_of_data, - int64_t(ps.num_db_vecs) - half_of_data); + ivf_flat::extend(handle_, + &index_2, + database.data() + half_of_data * ps.dim, + vector_indices.data() + half_of_data, + IdxT(ps.num_db_vecs) - half_of_data); ivf_flat::search(handle_, search_params, @@ -213,12 +221,12 @@ class AnnIVFFlatTest : public ::testing::TestWithParam { private: raft::handle_t handle_; rmm::cuda_stream_view stream_; - AnnIvfFlatInputs ps; + AnnIvfFlatInputs ps; rmm::device_uvector database; rmm::device_uvector search_queries; }; -const std::vector inputs = { +const std::vector> inputs = { // test various dims (aligned and not aligned to vector sizes) {1000, 10000, 1, 16, 40, 1024, raft::distance::DistanceType::L2Expanded}, {1000, 10000, 2, 16, 40, 1024, raft::distance::DistanceType::L2Expanded}, @@ -275,17 +283,17 @@ const std::vector inputs = { raft::spatial::knn::detail::topk::kMaxCapacity * 4, raft::distance::DistanceType::InnerProduct}}; -typedef AnnIVFFlatTest AnnIVFFlatTestF; +typedef AnnIVFFlatTest AnnIVFFlatTestF; TEST_P(AnnIVFFlatTestF, AnnIVFFlat) { this->testIVFFlat(); } INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF, ::testing::ValuesIn(inputs)); -typedef AnnIVFFlatTest AnnIVFFlatTestF_uint8; +typedef AnnIVFFlatTest AnnIVFFlatTestF_uint8; TEST_P(AnnIVFFlatTestF_uint8, AnnIVFFlat) { this->testIVFFlat(); } INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_uint8, ::testing::ValuesIn(inputs)); -typedef AnnIVFFlatTest AnnIVFFlatTestF_int8; +typedef AnnIVFFlatTest AnnIVFFlatTestF_int8; TEST_P(AnnIVFFlatTestF_int8, AnnIVFFlat) { this->testIVFFlat(); } INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_int8, ::testing::ValuesIn(inputs)); diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index 46867f0fa7..d9ad9cc358 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -16,6 +16,7 @@ #include "../test_utils.h" #include "spatial_data.h" +#include #include #include #include @@ -138,21 +139,22 @@ struct ToRadians { __device__ __host__ float operator()(float a) { return a * (CUDART_PI_F / 180.0); } }; +template struct BallCoverInputs { - uint32_t k; - uint32_t n_rows; - uint32_t n_cols; + value_int k; + value_int n_rows; + value_int n_cols; float weight; - uint32_t n_query; + value_int n_query; raft::distance::DistanceType metric; }; -template -class BallCoverKNNQueryTest : public ::testing::TestWithParam { +template +class BallCoverKNNQueryTest : public ::testing::TestWithParam> { protected: void basicTest() { - params = ::testing::TestWithParam::GetParam(); + params = ::testing::TestWithParam>::GetParam(); raft::handle_t handle; uint32_t k = params.k; @@ -200,12 +202,21 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam { rmm::device_uvector d_pred_I(params.n_query * k, handle.get_stream()); rmm::device_uvector d_pred_D(params.n_query * k, handle.get_stream()); - BallCoverIndex index( - handle, X.data(), params.n_rows, params.n_cols, metric); + auto X_view = + raft::make_device_matrix_view(X.data(), params.n_rows, params.n_cols); + auto X2_view = raft::make_device_matrix_view( + (const value_t*)X2.data(), params.n_query, params.n_cols); + + auto d_pred_I_view = + raft::make_device_matrix_view(d_pred_I.data(), params.n_query, k); + auto d_pred_D_view = + raft::make_device_matrix_view(d_pred_D.data(), params.n_query, k); + + BallCoverIndex index(handle, X_view, metric); raft::spatial::knn::rbc_build_index(handle, index); raft::spatial::knn::rbc_knn_query( - handle, index, k, X2.data(), params.n_query, d_pred_I.data(), d_pred_D.data(), true, weight); + handle, index, X2_view, d_pred_I_view, d_pred_D_view, k, true); handle.sync_stream(); // What we really want are for the distances to match exactly. The @@ -236,15 +247,15 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam { protected: uint32_t d = 2; - BallCoverInputs params; + BallCoverInputs params; }; -template -class BallCoverAllKNNTest : public ::testing::TestWithParam { +template +class BallCoverAllKNNTest : public ::testing::TestWithParam> { protected: void basicTest() { - params = ::testing::TestWithParam::GetParam(); + params = ::testing::TestWithParam>::GetParam(); raft::handle_t handle; uint32_t k = params.k; @@ -261,6 +272,9 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { rmm::device_uvector d_ref_I(params.n_rows * k, handle.get_stream()); rmm::device_uvector d_ref_D(params.n_rows * k, handle.get_stream()); + auto X_view = raft::make_device_matrix_view( + (const value_t*)X.data(), params.n_rows, params.n_cols); + if (metric == raft::distance::DistanceType::Haversine) { thrust::transform( handle.get_thrust_policy(), X.data(), X.data() + X.size(), X.data(), ToRadians()); @@ -283,11 +297,14 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { rmm::device_uvector d_pred_I(params.n_rows * k, handle.get_stream()); rmm::device_uvector d_pred_D(params.n_rows * k, handle.get_stream()); - BallCoverIndex index( - handle, X.data(), params.n_rows, params.n_cols, metric); + auto d_pred_I_view = + raft::make_device_matrix_view(d_pred_I.data(), params.n_rows, k); + auto d_pred_D_view = + raft::make_device_matrix_view(d_pred_D.data(), params.n_rows, k); + + BallCoverIndex index(handle, X_view, metric); - raft::spatial::knn::rbc_all_knn_query( - handle, index, k, d_pred_I.data(), d_pred_D.data(), true, weight); + raft::spatial::knn::rbc_all_knn_query(handle, index, d_pred_I_view, d_pred_D_view, k, true); handle.sync_stream(); // What we really want are for the distances to match exactly. The @@ -321,13 +338,13 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { void TearDown() override {} protected: - BallCoverInputs params; + BallCoverInputs params; }; typedef BallCoverAllKNNTest BallCoverAllKNNTestF; typedef BallCoverKNNQueryTest BallCoverKNNQueryTestF; -const std::vector ballcover_inputs = { +const std::vector> ballcover_inputs = { {11, 5000, 2, 1.0, 10000, raft::distance::DistanceType::Haversine}, {25, 10000, 2, 1.0, 5000, raft::distance::DistanceType::Haversine}, {2, 10000, 2, 1.0, 5000, raft::distance::DistanceType::L2SqrtUnexpanded}, diff --git a/cpp/test/spatial/epsilon_neighborhood.cu b/cpp/test/spatial/epsilon_neighborhood.cu index 515636ad8c..c83817f6f8 100644 --- a/cpp/test/spatial/epsilon_neighborhood.cu +++ b/cpp/test/spatial/epsilon_neighborhood.cu @@ -17,6 +17,7 @@ #include "../test_utils.h" #include #include +#include #include #include #include @@ -40,12 +41,18 @@ template template class EpsNeighTest : public ::testing::TestWithParam> { protected: - EpsNeighTest() : data(0, stream), adj(0, stream), labels(0, stream), vd(0, stream) {} + EpsNeighTest() + : data(0, handle.get_stream()), + adj(0, handle.get_stream()), + labels(0, handle.get_stream()), + vd(0, handle.get_stream()) + { + } void SetUp() override { - param = ::testing::TestWithParam>::GetParam(); - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + auto stream = handle.get_stream(); + param = ::testing::TestWithParam>::GetParam(); data.resize(param.n_row * param.n_col, stream); labels.resize(param.n_row, stream); batchSize = param.n_row / param.n_batches; @@ -65,14 +72,13 @@ class EpsNeighTest : public ::testing::TestWithParam> { false); } - void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } - EpsInputs param; cudaStream_t stream = 0; rmm::device_uvector data; rmm::device_uvector adj; rmm::device_uvector labels, vd; IdxT batchSize; + const raft::handle_t handle; }; // class EpsNeighTest const std::vector> inputsfi = { @@ -93,15 +99,16 @@ TEST_P(EpsNeighTestFI, Result) for (int i = 0; i < param.n_batches; ++i) { RAFT_CUDA_TRY(cudaMemsetAsync(adj.data(), 0, sizeof(bool) * param.n_row * batchSize, stream)); RAFT_CUDA_TRY(cudaMemsetAsync(vd.data(), 0, sizeof(int) * (batchSize + 1), stream)); - epsUnexpL2SqNeighborhood(adj.data(), - vd.data(), - data.data(), - data.data() + (i * batchSize * param.n_col), - param.n_row, - batchSize, - param.n_col, - param.eps * param.eps, - stream); + + auto adj_view = make_device_matrix_view(adj.data(), param.n_row, batchSize); + auto vd_view = make_device_vector_view(vd.data(), batchSize + 1); + auto x_view = make_device_matrix_view(data.data(), param.n_row, param.n_col); + auto y_view = make_device_matrix_view( + data.data() + (i * batchSize * param.n_col), batchSize, param.n_col); + + eps_neighbors_l2sq( + handle, x_view, y_view, adj_view, vd_view, param.eps * param.eps); + ASSERT_TRUE(raft::devArrMatch( param.n_row / param.n_centers, vd.data(), batchSize, raft::Compare(), stream)); } diff --git a/cpp/test/spatial/knn.cu b/cpp/test/spatial/knn.cu index 3f91242930..5807705038 100644 --- a/cpp/test/spatial/knn.cu +++ b/cpp/test/spatial/knn.cu @@ -16,9 +16,10 @@ #include "../test_utils.h" +#include #include #include -#include +#include #if defined RAFT_NN_COMPILED #include #endif @@ -80,28 +81,22 @@ class KNNTest : public ::testing::TestWithParam { protected: void testBruteForce() { -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) + //#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) raft::print_device_vector("Input array: ", input_.data(), rows_ * cols_, std::cout); std::cout << "K: " << k_ << std::endl; raft::print_device_vector("Labels array: ", search_labels_.data(), rows_, std::cout); -#endif + //#endif + + std::vector> index = { + make_device_matrix_view((const T*)(input_.data()), rows_, cols_)}; + auto search = raft::make_device_matrix_view( + (const T*)(search_data_.data()), rows_, cols_); + + auto indices = raft::make_device_matrix_view(indices_.data(), rows_, k_); + auto distances = + raft::make_device_matrix_view(distances_.data(), rows_, k_); - std::vector input_vec; - std::vector sizes_vec; - input_vec.push_back(input_.data()); - sizes_vec.push_back(rows_); - - brute_force_knn(handle, - input_vec, - sizes_vec, - cols_, - search_data_.data(), - rows_, - indices_.data(), - distances_.data(), - k_, - true, - true); + brute_force_knn(handle, index, search, indices, distances, k_); build_actual_output<<>>( actual_labels_.data(), rows_, k_, search_labels_.data(), indices_.data());