From d85fcf0d227706e71480544792fa777cfa36e2e8 Mon Sep 17 00:00:00 2001 From: hrong Date: Mon, 6 May 2024 20:54:38 -0700 Subject: [PATCH 01/15] [FEA] support of prefiltered brute force based on cuSparseSDDMM - This PR is one part of the feature of #1969 - Add the API of 'search_with_filtering' for brute force. Authors: - James Rong (https://github.com/rhdong) --- cpp/CMakeLists.txt | 8 + cpp/bench/prims/CMakeLists.txt | 1 + cpp/bench/prims/neighbors/knn.cuh | 111 ++++- .../knn/brute_force_filter_float_int64_t.cu | 25 + cpp/include/raft/core/bitmap.cuh | 6 +- cpp/include/raft/core/bitset.cuh | 36 +- cpp/include/raft/core/detail/popc.cuh | 75 +++ .../raft/neighbors/brute_force-ext.cuh | 25 + .../raft/neighbors/brute_force-inl.cuh | 29 +- .../raft/neighbors/detail/knn_brute_force.cuh | 128 +++++ .../sparse/convert/detail/bitmap_to_csr.cuh | 6 +- .../raft/sparse/distance/detail/utils.cuh | 75 ++- .../sparse/matrix/detail/select_k-ext.cuh | 2 +- .../raft/sparse/matrix/detail/select_k.cuh | 3 +- .../neighbors/brute_force_knn_index_float.cu | 16 + .../matrix/detail/select_k_float_int32.cu | 2 +- cpp/test/CMakeLists.txt | 2 + cpp/test/ext_headers/00_generate.py | 1 + .../raft_sparse_matrix_detail_select_k.cu | 27 ++ .../neighbors/prefiltered_brute_force.cu | 444 ++++++++++++++++++ 20 files changed, 976 insertions(+), 46 deletions(-) create mode 100644 cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu create mode 100644 cpp/include/raft/core/detail/popc.cuh create mode 100644 cpp/test/ext_headers/raft_sparse_matrix_detail_select_k.cu create mode 100644 cpp/test/sparse/neighbors/prefiltered_brute_force.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 259d9fe428..7e6aa53ca6 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -334,6 +334,14 @@ if(RAFT_COMPILE_LIBRARY) src/matrix/detail/select_k_float_int32.cu src/matrix/detail/select_k_half_int64_t.cu src/matrix/detail/select_k_half_uint32_t.cu + src/sparse/matrix/detail/select_k_half_uint32_t.cu + src/sparse/matrix/detail/select_k_double_int64_t.cu + src/sparse/matrix/detail/select_k_double_uint32_t.cu + src/sparse/matrix/detail/select_k_float_int64_t.cu + src/sparse/matrix/detail/select_k_float_uint32_t.cu + src/sparse/matrix/detail/select_k_float_int32.cu + src/sparse/matrix/detail/select_k_half_int64_t.cu + src/sparse/matrix/detail/select_k_half_uint32_t.cu src/neighbors/ball_cover.cu src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu src/neighbors/brute_force_knn_int64_t_float_int64_t.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 0c5521d447..16987288db 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -145,6 +145,7 @@ if(BUILD_PRIMS_BENCH) NAME NEIGHBORS_BENCH PATH + bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu bench/prims/neighbors/knn/brute_force_float_int64_t.cu bench/prims/neighbors/knn/brute_force_float_uint32_t.cu bench/prims/neighbors/knn/cagra_float_uint32_t.cu diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index 6499078623..ec9f1a245b 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -36,7 +37,10 @@ #include +#include #include +#include +#include namespace raft::bench::spatial { @@ -51,12 +55,19 @@ struct params { size_t k; /** Ratio of removed indices. */ double removed_ratio; + /** Distance Type. */ + raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; }; inline auto operator<<(std::ostream& os, const params& p) -> std::ostream& { os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k << "#" << p.removed_ratio; + switch (p.metric) { + case raft::distance::DistanceType::InnerProduct: os << "#InnerProduct"; break; + case raft::distance::DistanceType::L2Expanded: os << "#L2Expanded"; break; + default: os << "UNKNOWN DistanceType, please add one case here."; + } return os; } @@ -149,7 +160,7 @@ struct ivf_flat_knn { ivf_flat_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps) { index_params.n_lists = 4096; - index_params.metric = raft::distance::DistanceType::L2Expanded; + index_params.metric = ps.metric; index.emplace(raft::neighbors::ivf_flat::build( handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); } @@ -184,7 +195,7 @@ struct ivf_pq_knn { ivf_pq_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps) { index_params.n_lists = 4096; - index_params.metric = raft::distance::DistanceType::L2Expanded; + index_params.metric = ps.metric; auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view)); } @@ -236,6 +247,88 @@ struct brute_force_knn { } }; +template +RAFT_KERNEL initialize_random_bits( + bitmap_t* data, IdxT N, float sparsity, size_t total_bits, unsigned long seed) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) return; + + curandState state; + curand_init(seed, idx, 0, &state); + + bitmap_t value = 0; + for (int i = 0; i < sizeof(bitmap_t) * 8; i++) { + int rnd = curand(&state) % 10000; + + if (rnd < int(10000 * sparsity) && (idx * sizeof(bitmap_t) * 8 + i < total_bits)) { + bitmap_t bit_mask = 1u << i; + value |= bit_mask; + } + } + data[idx] = value; +} + +template +struct brute_force_filter_knn { + using dist_t = float; + using bitmap_t = std::uint32_t; + + std::optional> index; + raft::neighbors::brute_force::index_params index_params; + raft::neighbors::brute_force::search_params search_params; + raft::core::bitset removed_indices_bitset_; + params ps; + + brute_force_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data) + : ps(ps), removed_indices_bitset_(handle, ps.n_samples * ps.n_queries) + { + auto stream = resource::get_cuda_stream(handle); + index_params.metric = ps.metric; + + auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); + index.emplace(raft::neighbors::brute_force::build(handle, index_params, data_view)); + + IdxT element = raft::ceildiv(IdxT(ps.n_samples * ps.n_queries), IdxT(sizeof(bitmap_t) * 8)); + + size_t threadsPerBlock = 256; + size_t numBlocks = (element + threadsPerBlock - 1) / threadsPerBlock; + unsigned long seed = 1234; + initialize_random_bits<<>>( + removed_indices_bitset_.data(), + removed_indices_bitset_.size(), + float(1.0 - ps.removed_ratio), + ps.n_samples * ps.n_queries, + seed); + + resource::sync_stream(handle); + } + + void search(const raft::device_resources& handle, + const ValT* search_items, + ValT* out_dists, + IdxT* out_idxs) + { + auto queries_view = + raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); + auto neighbors_view = + raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); + auto distance_view = + raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); + + if (ps.removed_ratio > 0) { + auto filter = raft::core::bitmap_view( + (const bitmap_t*)removed_indices_bitset_.data(), IdxT(ps.n_queries), IdxT(ps.n_samples)); + + raft::neighbors::brute_force::search_with_filtering( + handle, *index, queries_view, filter, neighbors_view, distance_view); + } else { + raft::neighbors::brute_force::search( + handle, search_params, *index, queries_view, neighbors_view, distance_view); + } + } +}; + template struct ivf_flat_filter_knn { using dist_t = float; @@ -250,7 +343,7 @@ struct ivf_flat_filter_knn { : ps(ps), removed_indices_bitset_(handle, ps.n_samples) { index_params.n_lists = 4096; - index_params.metric = raft::distance::DistanceType::L2Expanded; + index_params.metric = ps.metric; index.emplace(raft::neighbors::ivf_flat::build( handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); auto removed_indices = @@ -298,7 +391,7 @@ struct ivf_pq_filter_knn { : ps(ps), removed_indices_bitset_(handle, ps.n_samples) { index_params.n_lists = 4096; - index_params.metric = raft::distance::DistanceType::L2Expanded; + index_params.metric = ps.metric; auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view)); auto removed_indices = @@ -500,10 +593,20 @@ const std::vector kInputsFilter = {size_t(255)}, // k {0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio ); + +const std::vector kInputsBruteForceFilter = raft::util::itertools::product( + {size_t(1000000)}, // n_samples + {size_t(128)}, // n_dim + {size_t(1000)}, // n_queries + {size_t(255)}, // k + {0.0, 0.8, 0.9}, // removed_ratio + {raft::distance::DistanceType::InnerProduct, raft::distance::DistanceType::L2Expanded}); + inline const std::vector kAllStrategies{ TransferStrategy::NO_COPY, TransferStrategy::MAP_PINNED, TransferStrategy::MANAGED}; inline const std::vector kNoCopyOnly{TransferStrategy::NO_COPY}; +inline const std::vector kScopeOnlySearch{Scope::SEARCH}; inline const std::vector kScopeFull{Scope::BUILD_SEARCH}; inline const std::vector kAllScopes{Scope::BUILD_SEARCH, Scope::SEARCH, Scope::BUILD}; diff --git a/cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu b/cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu new file mode 100644 index 0000000000..13e2c1febd --- /dev/null +++ b/cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter +#include "../knn.cuh" + +namespace raft::bench::spatial { + +KNN_REGISTER( + float, int64_t, brute_force_filter_knn, kInputsBruteForceFilter, kNoCopyOnly, kScopeOnlySearch); + +} // namespace raft::bench::spatial diff --git a/cpp/include/raft/core/bitmap.cuh b/cpp/include/raft/core/bitmap.cuh index 829c84ed25..0056cfa5f4 100644 --- a/cpp/include/raft/core/bitmap.cuh +++ b/cpp/include/raft/core/bitmap.cuh @@ -22,6 +22,8 @@ #include #include +#include + namespace raft::core { /** * @defgroup bitmap Bitmap @@ -39,8 +41,8 @@ namespace raft::core { */ template struct bitmap_view : public bitset_view { - static_assert((std::is_same::value || - std::is_same::value), + static_assert((std::is_same::type, uint32_t>::value || + std::is_same::type, uint64_t>::value), "The bitmap_t must be uint32_t or uint64_t."); /** * @brief Create a bitmap view from a device raw pointer. diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index 53fd586ed2..13fc8bbcdb 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -16,7 +16,7 @@ #pragma once -#include // native_popc +#include #include #include #include @@ -326,37 +326,9 @@ struct bitset { */ void count(const raft::resources& res, raft::device_scalar_view count_gpu_scalar) { - auto n_elements_ = n_elements(); - auto count_gpu = - raft::make_device_vector_view(count_gpu_scalar.data_handle(), 1); - auto bitset_matrix_view = raft::make_device_matrix_view( - bitset_.data(), n_elements_, 1); - - bitset_t n_last_element = (bitset_len_ % bitset_element_size); - bitset_t last_element_mask = - n_last_element ? (bitset_t)((bitset_t{1} << n_last_element) - bitset_t{1}) : ~bitset_t{0}; - raft::linalg::coalesced_reduction( - res, - bitset_matrix_view, - count_gpu, - index_t{0}, - false, - [last_element_mask, n_elements_] __device__(bitset_t element, index_t index) { - index_t result = 0; - if constexpr (bitset_element_size == 64) { - if (index == n_elements_ - 1) - result = index_t(raft::detail::popc(element & last_element_mask)); - else - result = index_t(raft::detail::popc(element)); - } else { // Needed because popc is not overloaded for 16 and 8 bit elements - if (index == n_elements_ - 1) - result = index_t(raft::detail::popc(uint32_t{element} & last_element_mask)); - else - result = index_t(raft::detail::popc(uint32_t{element})); - } - - return result; - }); + auto values = + raft::make_device_vector_view(bitset_.data(), n_elements()); + raft::detail::popc(res, values, bitset_len_, count_gpu_scalar); } /** * @brief Returns the number of bits set to true. diff --git a/cpp/include/raft/core/detail/popc.cuh b/cpp/include/raft/core/detail/popc.cuh new file mode 100644 index 0000000000..d74b68b715 --- /dev/null +++ b/cpp/include/raft/core/detail/popc.cuh @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace raft::detail { + +/** + * @brief Count the number of bits that are set to 1 in a vector. + * + * @tparam value_t the value type of the vector. + * @tparam index_t the index type of vector and scalar. + * + * @param[in] res raft handle for managing expensive resources + * @param[in] values Number of row in the matrix. + * @param[in] max_len Maximum number of bits to count. + * @param[out] counter Number of bits that are set to 1. + */ +template +void popc(const raft::resources& res, + device_vector_view values, + index_t max_len, + raft::device_scalar_view counter) +{ + auto values_size = values.size(); + auto values_matrix = raft::make_device_matrix_view( + values.data_handle(), values_size, 1); + auto counter_vector = raft::make_device_vector_view(counter.data_handle(), 1); + + static constexpr index_t len_per_item = sizeof(value_t) * 8; + + value_t tail_len = (max_len % len_per_item); + value_t tail_mask = tail_len ? (value_t)((value_t{1} << tail_len) - value_t{1}) : ~value_t{0}; + raft::linalg::coalesced_reduction( + res, + values_matrix, + counter_vector, + index_t{0}, + false, + [tail_mask, values_size] __device__(value_t value, index_t index) { + index_t result = 0; + if constexpr (len_per_item == 64) { + if (index == values_size - 1) + result = index_t(raft::detail::popc(value & tail_mask)); + else + result = index_t(raft::detail::popc(value)); + } else { // Needed because popc is not overloaded for 16 and 8 bit elements + if (index == values_size - 1) + result = index_t(raft::detail::popc(uint32_t{value} & tail_mask)); + else + result = index_t(raft::detail::popc(uint32_t{value})); + } + + return result; + }); +} + +} // end namespace raft::detail \ No newline at end of file diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh index 4055c253c8..c033077e51 100644 --- a/cpp/include/raft/neighbors/brute_force-ext.cuh +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include // raft::device_matrix_view #include // raft::identity_op #include // raft::resources @@ -65,6 +66,14 @@ void search(raft::resources const& res, raft::device_matrix_view neighbors, raft::device_matrix_view distances) RAFT_EXPLICIT; +template +void search_with_filtering(raft::resources const& res, + const index& idx, + raft::device_matrix_view queries, + raft::core::bitmap_view filter, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) RAFT_EXPLICIT; + template ( raft::device_matrix_view neighbors, raft::device_matrix_view distances); +extern template void search_with_filtering( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::core::bitmap_view filter, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + extern template void search( raft::resources const& res, const raft::neighbors::brute_force::index& idx, @@ -152,6 +169,14 @@ extern template void search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); +extern template void search_with_filtering( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::core::bitmap_view filter, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + extern template raft::neighbors::brute_force::index build( raft::resources const& res, raft::device_matrix_view dataset, diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh index f955cc8518..c10a45d3b4 100644 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -449,5 +450,31 @@ void search(raft::resources const& res, raft::neighbors::detail::brute_force_search(res, idx, queries, neighbors, distances); } +/** + * @brief Brute Force search with filter using the constructed index. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] idx brute force index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[in] sample_filter a device filter function that green lights samples for a given query + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search_with_filtering(raft::resources const& res, + const index& idx, + raft::device_matrix_view queries, + raft::core::bitmap_view sample_filter, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + raft::neighbors::detail::brute_force_search( + res, idx, queries, sample_filter, neighbors, distances); +} /** @} */ // end group brute_force_knn } // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index daa2798b00..a85b7d3d84 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -16,6 +16,9 @@ #pragma once +#include +#include +#include #include #include #include @@ -25,12 +28,18 @@ #include #include #include +#include #include #include #include #include #include #include +#include +#include +#include +#include +#include #include #include #include @@ -548,4 +557,123 @@ void brute_force_search( norms.size() ? &norms : nullptr, query_norms ? query_norms->data_handle() : nullptr); } + +template +void brute_force_search( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::core::bitmap_view filter, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + std::optional> query_norms = std::nullopt) +{ + auto metric = idx.metric(); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); + RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), + "Number of columns in queries must match brute force index"); + RAFT_EXPECTS(metric == raft::distance::DistanceType::InnerProduct || + metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::CosineExpanded, + "Only Euclidean, IP, and Cosine are supported!"); + + RAFT_EXPECTS(idx.has_norms() || !(metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::CosineExpanded), + "Index must has norms when using Euclidean, IP, and Cosine!"); + + IdxT n_queries = queries.extent(0); + IdxT n_dataset = idx.dataset().extent(0); + IdxT dim = idx.dataset().extent(1); + + auto stream = resource::get_cuda_stream(res); + + // calc nnz + IdxT nnz_h = 0; + rmm::device_scalar nnz(0, stream); + auto nnz_view = make_device_scalar_view(nnz.data()); + auto filter_view = + raft::make_device_vector_view(filter.data(), filter.n_elements()); + + raft::detail::popc(res, filter_view, n_queries * n_dataset, nnz_view); + raft::copy(&nnz_h, nnz.data(), 1, stream); + auto csr = raft::make_device_csr_matrix(res, n_queries, n_dataset, nnz_h); + + // fill csr + raft::sparse::convert::bitmap_to_csr(res, filter, csr); + + // create filter csr view + auto compressed_csr_view = csr.structure_view(); + auto csr_view = make_device_csr_matrix_view(csr.get_elements().data(), + compressed_csr_view); + + // create dataset view + auto dataset_view = raft::make_device_matrix_view( + idx.dataset().data_handle(), dim, n_dataset); + + // calc dot + T alpha = static_cast(1.0f); + T beta = static_cast(0.0f); + raft::sparse::linalg::sddmm(res, + queries, + dataset_view, + csr_view, + raft::linalg::Operation::NON_TRANSPOSE, + raft::linalg::Operation::NON_TRANSPOSE, + raft::make_host_scalar_view(&alpha), + raft::make_host_scalar_view(&beta)); + + // post process + std::optional> query_norms_; + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::CosineExpanded) { + if (metric == raft::distance::DistanceType::CosineExpanded) { + if (!query_norms) { + query_norms_ = make_device_vector(res, n_queries); + raft::linalg::rowNorm((T*)(query_norms_->data_handle()), + queries.data_handle(), + dim, + n_queries, + raft::linalg::L2Norm, + true, + stream, + raft::sqrt_op{}); + } + } else { + if (!query_norms) { + query_norms_ = make_device_vector(res, n_queries); + raft::linalg::rowNorm((T*)(query_norms_->data_handle()), + queries.data_handle(), + dim, + n_queries, + raft::linalg::L2Norm, + true, + stream, + raft::identity_op{}); + } + } + raft::sparse::distance::detail::epilogue_on_csr( + res, + csr.get_elements().data(), + compressed_csr_view.get_indptr().data(), + compressed_csr_view.get_nnz(), + compressed_csr_view.get_n_rows(), + compressed_csr_view.get_indices().data(), + query_norms ? query_norms->data_handle() : query_norms_->data_handle(), + idx.norms().data_handle(), + metric); + } + + // select k + auto const_csr_view = make_device_csr_matrix_view( + csr.get_elements().data(), compressed_csr_view); + std::optional> no_opt = std::nullopt; + raft::sparse::matrix::select_k(res, const_csr_view, no_opt, distances, neighbors, true, true); + + return; +} + } // namespace raft::neighbors::detail diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh index b0315486ff..a9624d891a 100644 --- a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -68,7 +68,7 @@ RAFT_KERNEL __launch_bounds__(calc_nnz_by_rows_tpb) calc_nnz_by_rows_kernel(cons while (offset < num_cols) { index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; - bitmap_t l_bitmap = bitmap_t(0); + typename std::remove_const::type l_bitmap = 0; if (bitmap_idx * BITS_PER_BITMAP < e_bit) { l_bitmap = bitmap[bitmap_idx]; } @@ -177,8 +177,8 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) #pragma unroll for (index_t offset = 0; offset < num_cols; offset += BITS_PER_BITMAP * warpSize) { index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; - bitmap_t l_bitmap = bitmap_t(0); - index_t l_offset = offset + lane_id * BITS_PER_BITMAP - (s_bit % BITS_PER_BITMAP); + typename std::remove_const::type l_bitmap = 0; + index_t l_offset = offset + lane_id * BITS_PER_BITMAP - (s_bit % BITS_PER_BITMAP); if (bitmap_idx * BITS_PER_BITMAP < e_bit) { l_bitmap = bitmap[bitmap_idx]; } diff --git a/cpp/include/raft/sparse/distance/detail/utils.cuh b/cpp/include/raft/sparse/distance/detail/utils.cuh index ed2b414c70..a5dc6b02cc 100644 --- a/cpp/include/raft/sparse/distance/detail/utils.cuh +++ b/cpp/include/raft/sparse/distance/detail/utils.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ #pragma once +#include +#include + #include namespace raft { @@ -37,6 +40,76 @@ inline int max_cols_per_block() sizeof(value_t); } +template +RAFT_KERNEL epilogue_on_csr_kernel(value_t* __restrict__ compressed_C, + const value_idx* __restrict__ rows, + const value_idx* __restrict__ cols, + const value_t* __restrict__ Q_sq_norms, + const value_t* __restrict__ R_sq_norms, + value_idx nnz, + expansion_f expansion_func) +{ + auto tid = blockDim.x * blockIdx.x + threadIdx.x; + + if (tid >= nnz) return; + const value_idx i = rows[tid]; + const value_idx j = cols[tid]; + + compressed_C[tid] = expansion_func(compressed_C[tid], Q_sq_norms[i], R_sq_norms[j]); +} + +template +void epilogue_on_csr(raft::resources const& handle, + value_t* compressed_C, + const value_idx* indptr, + const value_idx nnz, + const value_idx n_rows, + const value_idx* cols, + const value_t* Q_sq_norms, + const value_t* R_sq_norms, + raft::distance::DistanceType metric) +{ + auto stream = resource::get_cuda_stream(handle); + + rmm::device_uvector rows(nnz, stream); + raft::sparse::convert::csr_to_coo(indptr, n_rows, rows.data(), nnz, stream); + + int blocks = raft::ceildiv((size_t)nnz, tpb); + if (metric == raft::distance::DistanceType::L2Expanded) { + epilogue_on_csr_kernel<<>>( + compressed_C, + rows.data(), + cols, + Q_sq_norms, + R_sq_norms, + nnz, + [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) -> value_t { + return value_t(-2.0) * dot + q_norm + r_norm; + }); + } else if (metric == raft::distance::DistanceType::L2SqrtExpanded) { + epilogue_on_csr_kernel<<>>( + compressed_C, + rows.data(), + cols, + Q_sq_norms, + R_sq_norms, + nnz, + [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) -> value_t { + return raft::sqrt(value_t(-2.0) * dot + q_norm + r_norm); + }); + } else if (metric == raft::distance::DistanceType::CosineExpanded) { + epilogue_on_csr_kernel<<>>( + compressed_C, + rows.data(), + cols, + Q_sq_norms, + R_sq_norms, + nnz, + [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) -> value_t { + return value_t(1.0) - dot / (q_norm * r_norm); + }); + } +} } // namespace detail } // namespace distance } // namespace sparse diff --git a/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh b/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh index 922356b040..01625a0ce8 100644 --- a/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh @@ -37,7 +37,7 @@ void select_k(raft::resources const& handle, raft::device_matrix_view out_idx, bool select_min, bool sorted = false, - raft::matrix::SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT; + raft::matrix::SelectAlgo algo = raft::matrix::SelectAlgo::kAuto) RAFT_EXPLICIT; } // namespace raft::sparse::matrix::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/include/raft/sparse/matrix/detail/select_k.cuh b/cpp/include/raft/sparse/matrix/detail/select_k.cuh index 711169984b..5d52b94b2f 100644 --- a/cpp/include/raft/sparse/matrix/detail/select_k.cuh +++ b/cpp/include/raft/sparse/matrix/detail/select_k.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ #ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY #include "select_k-inl.cuh" + #endif #ifdef RAFT_COMPILED diff --git a/cpp/src/neighbors/brute_force_knn_index_float.cu b/cpp/src/neighbors/brute_force_knn_index_float.cu index de94be4c09..4de66b163d 100644 --- a/cpp/src/neighbors/brute_force_knn_index_float.cu +++ b/cpp/src/neighbors/brute_force_knn_index_float.cu @@ -37,6 +37,14 @@ template void raft::neighbors::brute_force::search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); +template void raft::neighbors::brute_force::search_with_filtering( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::core::bitmap_view filter, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + template void raft::neighbors::brute_force::search( raft::resources const& res, const raft::neighbors::brute_force::index& idx, @@ -52,6 +60,14 @@ template void raft::neighbors::brute_force::search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); +template void raft::neighbors::brute_force::search_with_filtering( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::core::bitmap_view filter, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + template raft::neighbors::brute_force::index raft::neighbors::brute_force:: build::accessor_type>( raft::resources const& res, diff --git a/cpp/src/sparse/matrix/detail/select_k_float_int32.cu b/cpp/src/sparse/matrix/detail/select_k_float_int32.cu index bff213ae69..49bec86e6e 100644 --- a/cpp/src/sparse/matrix/detail/select_k_float_int32.cu +++ b/cpp/src/sparse/matrix/detail/select_k_float_int32.cu @@ -17,7 +17,7 @@ #include #define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k( \ + template void raft::sparse::matrix::detail::select_k( \ raft::resources const& handle, \ raft::device_csr_matrix_view in_val, \ std::optional> in_idx, \ diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 752dffdc16..f2d121ac96 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -193,6 +193,7 @@ if(BUILD_TESTS) test/ext_headers/raft_neighbors_refine.cu test/ext_headers/raft_neighbors_detail_ivf_flat_search.cu test/ext_headers/raft_linalg_detail_coalesced_reduction.cu + test/ext_headers/raft_sparse_matrix_detail_select_k.cu test/ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu test/ext_headers/raft_neighbors_detail_ivf_flat_interleaved_scan.cu test/ext_headers/raft_neighbors_detail_ivf_pq_compute_similarity.cu @@ -338,6 +339,7 @@ if(BUILD_TESTS) test/sparse/neighbors/cross_component_nn.cu test/sparse/neighbors/brute_force.cu test/sparse/neighbors/knn_graph.cu + test/sparse/neighbors/prefiltered_brute_force.cu LIB EXPLICIT_INSTANTIATE_ONLY ) diff --git a/cpp/test/ext_headers/00_generate.py b/cpp/test/ext_headers/00_generate.py index d9c766979b..1e1106f8bf 100644 --- a/cpp/test/ext_headers/00_generate.py +++ b/cpp/test/ext_headers/00_generate.py @@ -54,6 +54,7 @@ "raft/neighbors/refine-ext.cuh", "raft/neighbors/detail/ivf_flat_search-ext.cuh", "raft/linalg/detail/coalesced_reduction-ext.cuh", + "raft/sparse/matrix/detail/select_k-ext.cuh", "raft/spatial/knn/detail/ball_cover/registers-ext.cuh", "raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh", "raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh", diff --git a/cpp/test/ext_headers/raft_sparse_matrix_detail_select_k.cu b/cpp/test/ext_headers/raft_sparse_matrix_detail_select_k.cu new file mode 100644 index 0000000000..b748a31a5b --- /dev/null +++ b/cpp/test/ext_headers/raft_sparse_matrix_detail_select_k.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/sparse/neighbors/prefiltered_brute_force.cu b/cpp/test/sparse/neighbors/prefiltered_brute_force.cu new file mode 100644 index 0000000000..8aa54ef61f --- /dev/null +++ b/cpp/test/sparse/neighbors/prefiltered_brute_force.cu @@ -0,0 +1,444 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../../neighbors/knn_utils.cuh" +#include "../../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::neighbors::brute_force { + +template +struct PrefilteredBruteForceInputs { + index_t n_queries; + index_t n_dataset; + index_t dim; + index_t top_k; + float sparsity; + raft::distance::DistanceType metric; + bool select_min = true; +}; + +template +struct CompareApproxWithInf { + CompareApproxWithInf(T eps_) : eps(eps_) {} + bool operator()(const T& a, const T& b) const + { + if (std::isinf(a) && std::isinf(b)) return true; + T diff = std::abs(a - b); + T m = std::max(std::abs(a), std::abs(b)); + T ratio = diff > eps ? diff / m : diff; + + return (ratio <= eps); + } + + private: + T eps; +}; + +template +class PrefilteredBruteForceTest + : public ::testing::TestWithParam> { + public: + PrefilteredBruteForceTest() + : stream(resource::get_cuda_stream(handle)), + params(::testing::TestWithParam>::GetParam()), + filter_d(0, stream), + dataset_d(0, stream), + queries_d(0, stream), + out_val_d(0, stream), + out_val_expected_d(0, stream), + out_idx_d(0, stream), + out_idx_expected_d(0, stream) + { + } + + protected: + index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitmap) + { + index_t total = static_cast(m * n); + index_t num_ones = static_cast((total * 1.0f) * sparsity); + index_t res = num_ones; + + for (auto& item : bitmap) { + item = static_cast(0); + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, total - 1); + + while (num_ones > 0) { + index_t index = dis(gen); + + bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))]; + index_t bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1) == 0) { + element |= (static_cast(1) << bit_position); + num_ones--; + } + } + return res; + } + + void cpu_convert_to_csr(std::vector& bitmap, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + index_t index = 0; + bitmap_t element = 0; + index_t bit_position = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + index = i * cols + j; + element = bitmap[index / (8 * sizeof(bitmap_t))]; + bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1)) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + void cpu_sddmm(const std::vector& A, + const std::vector& B, + std::vector& vals, + const std::vector& cols, + const std::vector& row_ptrs, + bool is_row_major_A, + bool is_row_major_B, + value_t alpha = 1.0, + value_t beta = 0.0) + { + if (params.n_queries * params.dim != static_cast(A.size()) || + params.dim * params.n_dataset != static_cast(B.size())) { + std::cerr << "Matrix dimensions and vector size do not match!" << std::endl; + return; + } + + bool trans_a = is_row_major_A; + bool trans_b = is_row_major_B; + + for (index_t i = 0; i < params.n_queries; ++i) { + for (index_t j = row_ptrs[i]; j < row_ptrs[i + 1]; ++j) { + value_t sum = 0; + value_t norms_A = 0; + value_t norms_B = 0; + for (index_t l = 0; l < params.dim; ++l) { + index_t a_index = trans_a ? i * params.dim + l : l * params.n_queries + i; + index_t b_index = trans_b ? l * params.n_dataset + cols[j] : cols[j] * params.dim + l; + sum += A[a_index] * B[b_index]; + + norms_A += A[a_index] * A[a_index]; + norms_B += B[b_index] * B[b_index]; + } + vals[j] = alpha * sum + beta * vals[j]; + if (params.metric == raft::distance::DistanceType::L2Expanded) { + vals[j] = value_t(-2.0) * vals[j] + norms_A + norms_B; + } else if (params.metric == raft::distance::DistanceType::L2SqrtExpanded) { + vals[j] = std::sqrt(value_t(-2.0) * vals[j] + norms_A + norms_B); + } else if (params.metric == raft::distance::DistanceType::CosineExpanded) { + vals[j] = value_t(1.0) - vals[j] / std::sqrt(norms_A * norms_B); + } + } + } + } + + void cpu_select_k(const std::vector& indptr_h, + const std::vector& indices_h, + const std::vector& values_h, + std::optional>& in_idx_h, + index_t n_queries, + index_t n_dataset, + index_t top_k, + std::vector& out_values_h, + std::vector& out_indices_h, + bool select_min = true) + { + auto comp = [select_min](const std::pair& a, + const std::pair& b) { + return select_min ? a.first < b.first : a.first >= b.first; + }; + + for (index_t row = 0; row < n_queries; ++row) { + std::priority_queue, + std::vector>, + decltype(comp)> + pq(comp); + + for (index_t idx = indptr_h[row]; idx < indptr_h[row + 1]; ++idx) { + pq.push({values_h[idx], (in_idx_h.has_value()) ? (*in_idx_h)[idx] : indices_h[idx]}); + if (pq.size() > size_t(top_k)) { pq.pop(); } + } + + std::vector> row_pairs; + while (!pq.empty()) { + row_pairs.push_back(pq.top()); + pq.pop(); + } + + if (select_min) { + std::sort(row_pairs.begin(), row_pairs.end(), [](const auto& a, const auto& b) { + return a.first <= b.first; + }); + } else { + std::sort(row_pairs.begin(), row_pairs.end(), [](const auto& a, const auto& b) { + return a.first >= b.first; + }); + } + for (index_t col = 0; col < top_k; col++) { + if (col < index_t(row_pairs.size())) { + out_values_h[row * top_k + col] = row_pairs[col].first; + out_indices_h[row * top_k + col] = row_pairs[col].second; + } + } + } + } + + void random_array(value_t* array, size_t size) + { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dis(-10.0, 10.0); + std::unordered_set uset; + + while (uset.size() < size) { + uset.insert(dis(gen)); + } + typename std::unordered_set::iterator it = uset.begin(); + for (size_t i = 0; i < size; ++i) { + array[i] = *(it++); + } + } + + void SetUp() override + { + index_t element = + raft::ceildiv(params.n_queries * params.n_dataset, index_t(sizeof(bitmap_t) * 8)); + std::vector filter_h(element); + + nnz = create_sparse_matrix(params.n_queries, params.n_dataset, params.sparsity, filter_h); + + index_t dataset_size = params.n_dataset * params.dim; + index_t queries_size = params.n_queries * params.dim; + + std::vector dataset_h(dataset_size); + std::vector queries_h(queries_size); + + dataset_d.resize(dataset_size, stream); + queries_d.resize(queries_size, stream); + + auto blobs_in_val = + raft::make_device_matrix(handle, 1, dataset_size + queries_size); + auto labels = raft::make_device_vector(handle, 1); + + raft::random::make_blobs(blobs_in_val.data_handle(), + labels.data_handle(), + 1, + dataset_size + queries_size, + 1, + stream, + false, + nullptr, + nullptr, + value_t(1.0), + false, + value_t(-1.0f), + value_t(1.0f), + uint64_t(2024)); + + raft::copy(dataset_h.data(), blobs_in_val.data_handle(), dataset_size, stream); + raft::copy(dataset_d.data(), blobs_in_val.data_handle(), dataset_size, stream); + + raft::copy(queries_h.data(), blobs_in_val.data_handle() + dataset_size, queries_size, stream); + raft::copy(queries_d.data(), blobs_in_val.data_handle() + dataset_size, queries_size, stream); + + resource::sync_stream(handle); + + std::vector values_h(nnz); + std::vector indices_h(nnz); + std::vector indptr_h(params.n_queries + 1); + + filter_d.resize(filter_h.size(), stream); + cpu_convert_to_csr(filter_h, params.n_queries, params.n_dataset, indices_h, indptr_h); + + cpu_sddmm(queries_h, dataset_h, values_h, indices_h, indptr_h, true, false); + + std::vector out_val_h(params.n_queries * params.top_k, + std::numeric_limits::infinity()); + std::vector out_idx_h(params.n_queries * params.top_k, static_cast(0)); + + out_val_d.resize(params.n_queries * params.top_k, stream); + out_idx_d.resize(params.n_queries * params.top_k, stream); + + update_device(out_val_d.data(), out_val_h.data(), out_val_h.size(), stream); + update_device(out_idx_d.data(), out_idx_h.data(), out_idx_h.size(), stream); + update_device(filter_d.data(), filter_h.data(), filter_h.size(), stream); + + resource::sync_stream(handle); + + std::optional> optional_indices_h = std::nullopt; + + cpu_select_k(indptr_h, + indices_h, + values_h, + optional_indices_h, + params.n_queries, + params.n_dataset, + params.top_k, + out_val_h, + out_idx_h, + params.select_min); + + out_val_expected_d.resize(params.n_queries * params.top_k, stream); + out_idx_expected_d.resize(params.n_queries * params.top_k, stream); + + update_device(out_val_expected_d.data(), out_val_h.data(), out_val_h.size(), stream); + update_device(out_idx_expected_d.data(), out_idx_h.data(), out_idx_h.size(), stream); + + resource::sync_stream(handle); + } + + void Run() + { + auto dataset_raw = raft::make_device_matrix_view( + (const value_t*)dataset_d.data(), params.n_dataset, params.dim); + + auto queries = raft::make_device_matrix_view( + (const value_t*)queries_d.data(), params.n_queries, params.dim); + + brute_force::index_params index_params{}; + index_params.metric = params.metric; + index_params.metric_arg = 0; + + auto dataset = brute_force::build(handle, index_params, dataset_raw); + + auto filter = + raft::core::bitmap_view((const bitmap_t*)filter_d.data(), params.n_queries, params.n_dataset); + + auto out_val = raft::make_device_matrix_view( + out_val_d.data(), params.n_queries, params.top_k); + auto out_idx = raft::make_device_matrix_view( + out_idx_d.data(), params.n_queries, params.top_k); + + brute_force::search_with_filtering(handle, dataset, queries, filter, out_idx, out_val); + + ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(out_idx_expected_d.data(), + out_idx.data_handle(), + out_val_expected_d.data(), + out_val.data_handle(), + params.n_queries, + params.top_k, + 0.001f, + stream, + true)); + } + + protected: + raft::resources handle; + cudaStream_t stream; + + PrefilteredBruteForceInputs params; + + index_t nnz; + + rmm::device_uvector dataset_d; + rmm::device_uvector queries_d; + rmm::device_uvector filter_d; + + rmm::device_uvector out_val_d; + rmm::device_uvector out_val_expected_d; + + rmm::device_uvector out_idx_d; + rmm::device_uvector out_idx_expected_d; +}; + +using PrefilteredBruteForceTest_float_int64 = PrefilteredBruteForceTest; +TEST_P(PrefilteredBruteForceTest_float_int64, Result) { Run(); } + +template +const std::vector> selectk_inputs = { + {1000, 10000, 1, 0, 0.1, raft::distance::DistanceType::L2Expanded}, + {1000, 10000, 3, 0, 0.1, raft::distance::DistanceType::InnerProduct}, + {1000, 10000, 5, 0, 0.1, raft::distance::DistanceType::L2SqrtExpanded}, + {1000, 10000, 8, 0, 0.1, raft::distance::DistanceType::CosineExpanded}, + + {1000, 10000, 1, 1, 0.1, raft::distance::DistanceType::L2Expanded}, + {1000, 10000, 3, 1, 0.1, raft::distance::DistanceType::InnerProduct}, + {1000, 10000, 5, 1, 0.1, raft::distance::DistanceType::L2SqrtExpanded}, + {1000, 10000, 8, 1, 0.1, raft::distance::DistanceType::CosineExpanded}, + + {1000, 10000, 2050, 16, 0.4, raft::distance::DistanceType::L2Expanded}, + {1000, 10000, 2051, 16, 0.5, raft::distance::DistanceType::L2Expanded}, + {1000, 10000, 2052, 16, 0.2, raft::distance::DistanceType::L2Expanded}, + {1000, 10000, 2050, 16, 0.4, raft::distance::DistanceType::InnerProduct}, + {1000, 10000, 2051, 16, 0.5, raft::distance::DistanceType::InnerProduct}, + {1000, 10000, 2052, 16, 0.2, raft::distance::DistanceType::InnerProduct}, + {1000, 10000, 2050, 16, 0.4, raft::distance::DistanceType::L2SqrtExpanded}, + {1000, 10000, 2051, 16, 0.5, raft::distance::DistanceType::L2SqrtExpanded}, + {1000, 10000, 2052, 16, 0.2, raft::distance::DistanceType::L2SqrtExpanded}, + {1000, 10000, 2050, 16, 0.4, raft::distance::DistanceType::CosineExpanded}, + {1000, 10000, 2051, 16, 0.5, raft::distance::DistanceType::CosineExpanded}, + {1000, 10000, 2052, 16, 0.2, raft::distance::DistanceType::CosineExpanded}, + + {1000, 10000, 1, 16, 0.5, raft::distance::DistanceType::L2Expanded}, + {1000, 10000, 2, 16, 0.2, raft::distance::DistanceType::L2Expanded}, + {1000, 10000, 3, 16, 0.4, raft::distance::DistanceType::InnerProduct}, + {1000, 10000, 4, 16, 0.5, raft::distance::DistanceType::InnerProduct}, + {1000, 10000, 5, 16, 0.2, raft::distance::DistanceType::L2SqrtExpanded}, + {1000, 10000, 8, 16, 0.4, raft::distance::DistanceType::L2SqrtExpanded}, + {1000, 10000, 5, 16, 0.5, raft::distance::DistanceType::CosineExpanded}, + {1000, 10000, 8, 16, 0.2, raft::distance::DistanceType::CosineExpanded}}; + +INSTANTIATE_TEST_CASE_P(PrefilteredBruteForceTest, + PrefilteredBruteForceTest_float_int64, + ::testing::ValuesIn(selectk_inputs)); + +} // namespace raft::neighbors::brute_force From c7e4e7a4262466172ed17d5fef6b801781320436 Mon Sep 17 00:00:00 2001 From: rhdong Date: Wed, 8 May 2024 02:44:17 -0700 Subject: [PATCH 02/15] Improve the performance in classic scenarios by replace the cuSparseSDDMM with faster_dot_on_csr --- cpp/bench/prims/neighbors/knn.cuh | 19 ++-- .../raft/neighbors/detail/knn_brute_force.cuh | 36 ++++---- .../raft/sparse/distance/detail/utils.cuh | 88 +++++++++++++++++-- .../neighbors/prefiltered_brute_force.cu | 6 ++ 4 files changed, 114 insertions(+), 35 deletions(-) diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index ec9f1a245b..ea3d5acb0b 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -61,8 +61,13 @@ struct params { inline auto operator<<(std::ostream& os, const params& p) -> std::ostream& { - os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k << "#" - << p.removed_ratio; + os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k; + if (p.removed_ratio > 0.0) { + os << "#" << p.removed_ratio; + } else { + os << "#" + << "[No filtered]"; + } switch (p.metric) { case raft::distance::DistanceType::InnerProduct: os << "#InnerProduct"; break; case raft::distance::DistanceType::L2Expanded: os << "#L2Expanded"; break; @@ -595,11 +600,11 @@ const std::vector kInputsFilter = ); const std::vector kInputsBruteForceFilter = raft::util::itertools::product( - {size_t(1000000)}, // n_samples - {size_t(128)}, // n_dim - {size_t(1000)}, // n_queries - {size_t(255)}, // k - {0.0, 0.8, 0.9}, // removed_ratio + {size_t(1000000)}, // n_samples + {size_t(4096), size_t(512), size_t(128)}, // n_dim + {size_t(1), size_t(10), size_t(1000)}, // n_queries + {size_t(255)}, // k + {0.0, 0.8, 0.9, 0.99}, // removed_ratio {raft::distance::DistanceType::InnerProduct, raft::distance::DistanceType::L2Expanded}); inline const std::vector kAllStrategies{ diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index a85b7d3d84..0f3da3f34b 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -606,24 +606,21 @@ void brute_force_search( // create filter csr view auto compressed_csr_view = csr.structure_view(); - auto csr_view = make_device_csr_matrix_view(csr.get_elements().data(), - compressed_csr_view); - - // create dataset view - auto dataset_view = raft::make_device_matrix_view( - idx.dataset().data_handle(), dim, n_dataset); - - // calc dot - T alpha = static_cast(1.0f); - T beta = static_cast(0.0f); - raft::sparse::linalg::sddmm(res, - queries, - dataset_view, - csr_view, - raft::linalg::Operation::NON_TRANSPOSE, - raft::linalg::Operation::NON_TRANSPOSE, - raft::make_host_scalar_view(&alpha), - raft::make_host_scalar_view(&beta)); + rmm::device_uvector rows(compressed_csr_view.get_nnz(), stream); + raft::sparse::convert::csr_to_coo(compressed_csr_view.get_indptr().data(), + compressed_csr_view.get_n_rows(), + rows.data(), + compressed_csr_view.get_nnz(), + stream); + + raft::sparse::distance::detail::faster_dot_on_csr(res, + csr.get_elements().data(), + compressed_csr_view.get_nnz(), + rows.data(), + compressed_csr_view.get_indices().data(), + queries.data_handle(), + idx.dataset().data_handle(), + dim); // post process std::optional> query_norms_; @@ -658,9 +655,8 @@ void brute_force_search( raft::sparse::distance::detail::epilogue_on_csr( res, csr.get_elements().data(), - compressed_csr_view.get_indptr().data(), compressed_csr_view.get_nnz(), - compressed_csr_view.get_n_rows(), + rows.data(), compressed_csr_view.get_indices().data(), query_norms ? query_norms->data_handle() : query_norms_->data_handle(), idx.norms().data_handle(), diff --git a/cpp/include/raft/sparse/distance/detail/utils.cuh b/cpp/include/raft/sparse/distance/detail/utils.cuh index a5dc6b02cc..69220183c7 100644 --- a/cpp/include/raft/sparse/distance/detail/utils.cuh +++ b/cpp/include/raft/sparse/distance/detail/utils.cuh @@ -61,9 +61,8 @@ RAFT_KERNEL epilogue_on_csr_kernel(value_t* __restrict__ compressed_C, template void epilogue_on_csr(raft::resources const& handle, value_t* compressed_C, - const value_idx* indptr, const value_idx nnz, - const value_idx n_rows, + const value_idx* rows, const value_idx* cols, const value_t* Q_sq_norms, const value_t* R_sq_norms, @@ -71,14 +70,11 @@ void epilogue_on_csr(raft::resources const& handle, { auto stream = resource::get_cuda_stream(handle); - rmm::device_uvector rows(nnz, stream); - raft::sparse::convert::csr_to_coo(indptr, n_rows, rows.data(), nnz, stream); - int blocks = raft::ceildiv((size_t)nnz, tpb); if (metric == raft::distance::DistanceType::L2Expanded) { epilogue_on_csr_kernel<<>>( compressed_C, - rows.data(), + rows, cols, Q_sq_norms, R_sq_norms, @@ -89,7 +85,7 @@ void epilogue_on_csr(raft::resources const& handle, } else if (metric == raft::distance::DistanceType::L2SqrtExpanded) { epilogue_on_csr_kernel<<>>( compressed_C, - rows.data(), + rows, cols, Q_sq_norms, R_sq_norms, @@ -100,7 +96,7 @@ void epilogue_on_csr(raft::resources const& handle, } else if (metric == raft::distance::DistanceType::CosineExpanded) { epilogue_on_csr_kernel<<>>( compressed_C, - rows.data(), + rows, cols, Q_sq_norms, R_sq_norms, @@ -110,6 +106,82 @@ void epilogue_on_csr(raft::resources const& handle, }); } } + +template +__inline__ __device__ value_t warpReduceSum(value_t val) +{ + return val; +} + +template +RAFT_KERNEL faster_dot_on_csr_kernel(value_t* __restrict__ dot, + const value_idx* __restrict__ rows, + const value_idx* __restrict__ cols, + const value_t* __restrict__ A, + const value_t* __restrict__ B, + const value_idx nnz, + const value_idx dim) +{ + auto dot_id = blockIdx.x; + auto vec_id = threadIdx.x; + auto lane_id = threadIdx.x & 0x1f; + + const value_idx row = rows[dot_id] * dim; + const value_idx col = cols[dot_id] * dim; + __shared__ value_t g_dot_; + + if (threadIdx.x == 0) { g_dot_ = 0.0; } + __syncthreads(); + + value_t l_dot_ = 0.0; + +#pragma unroll + for (value_idx k = vec_id; k < dim; k += blockDim.x) { + l_dot_ += A[row + k] * B[col + k]; + } + +#pragma unroll + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + l_dot_ += __shfl_down_sync(0xffffffff, l_dot_, offset); + } + + if (lane_id == 0) { atomicAdd_block(&g_dot_, l_dot_); } + __syncthreads(); + + if (threadIdx.x == 0) { dot[dot_id] = g_dot_; } +} + +template +void faster_dot_on_csr(raft::resources const& handle, + value_t* dot, + const value_idx nnz, + const value_idx* rows, + const value_idx* cols, + const value_t* A, + const value_t* B, + const value_idx dim) +{ + auto stream = resource::get_cuda_stream(handle); + + int blocks = int(nnz); + if (dim < 128) { + constexpr int tpb = 64; + faster_dot_on_csr_kernel + <<>>(dot, rows, cols, A, B, nnz, dim); + } else if (dim < 256) { + constexpr int tpb = 128; + faster_dot_on_csr_kernel + <<>>(dot, rows, cols, A, B, nnz, dim); + } else if (dim < 512) { + constexpr int tpb = 256; + faster_dot_on_csr_kernel + <<>>(dot, rows, cols, A, B, nnz, dim); + } else { + constexpr int tpb = 512; + faster_dot_on_csr_kernel + <<>>(dot, rows, cols, A, B, nnz, dim); + } +} } // namespace detail } // namespace distance } // namespace sparse diff --git a/cpp/test/sparse/neighbors/prefiltered_brute_force.cu b/cpp/test/sparse/neighbors/prefiltered_brute_force.cu index 8aa54ef61f..990f012ec0 100644 --- a/cpp/test/sparse/neighbors/prefiltered_brute_force.cu +++ b/cpp/test/sparse/neighbors/prefiltered_brute_force.cu @@ -405,6 +405,12 @@ TEST_P(PrefilteredBruteForceTest_float_int64, Result) { Run(); } template const std::vector> selectk_inputs = { + {1, 100000, 255, 255, 0.4, raft::distance::DistanceType::L2Expanded}, + {10, 100000, 512, 16, 0.5, raft::distance::DistanceType::L2Expanded}, + {20, 100000, 2052, 16, 0.2, raft::distance::DistanceType::L2Expanded}, + {1, 10000, 255, 16, 0.4, raft::distance::DistanceType::InnerProduct}, + {20, 10000, 512, 16, 0.5, raft::distance::DistanceType::InnerProduct}, + {100, 10000, 2052, 16, 0.2, raft::distance::DistanceType::InnerProduct}, {1000, 10000, 1, 0, 0.1, raft::distance::DistanceType::L2Expanded}, {1000, 10000, 3, 0, 0.1, raft::distance::DistanceType::InnerProduct}, {1000, 10000, 5, 0, 0.1, raft::distance::DistanceType::L2SqrtExpanded}, From 5c5aa9b8ce2f4dd93175027aa7ac05caef3e3825 Mon Sep 17 00:00:00 2001 From: rhdong Date: Mon, 13 May 2024 05:00:03 -0700 Subject: [PATCH 03/15] optimize and remove used. --- cpp/bench/prims/neighbors/knn.cuh | 26 +++- .../knn/brute_force_filter_float_int64_t.cu | 8 +- cpp/include/raft/core/bitset.cuh | 6 + .../raft/neighbors/detail/knn_brute_force.cuh | 3 +- .../raft/sparse/distance/detail/utils.cuh | 134 ++++++++++++------ .../neighbors/prefiltered_brute_force.cu | 18 +++ 6 files changed, 145 insertions(+), 50 deletions(-) diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index ea3d5acb0b..4f8e62b1da 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -600,12 +600,26 @@ const std::vector kInputsFilter = ); const std::vector kInputsBruteForceFilter = raft::util::itertools::product( - {size_t(1000000)}, // n_samples - {size_t(4096), size_t(512), size_t(128)}, // n_dim - {size_t(1), size_t(10), size_t(1000)}, // n_queries - {size_t(255)}, // k - {0.0, 0.8, 0.9, 0.99}, // removed_ratio - {raft::distance::DistanceType::InnerProduct, raft::distance::DistanceType::L2Expanded}); + {size_t(10 * 1024 * 1024)}, // n_samples + {size_t(256), size_t(768), size_t(1024), size_t(2048), size_t(4096)}, // n_dim + {size_t(10), size_t(1000)}, // n_queries + {size_t(255)}, // k + {0.0, 0.8, 0.9, 0.99}, // removed_ratio + {raft::distance::DistanceType::InnerProduct}); + +const std::vector kInputsBruteForceFilterExtra = + raft::util::itertools::product({size_t(1024 * 1024)}, // n_samples + {size_t(256), size_t(768)}, // n_dim + {size_t(10), + size_t(20), + size_t(40), + size_t(60), + size_t(80), + size_t(100), + size_t(300)}, // n_queries + {size_t(255)}, // k + {0.3, 0.4, 0.9}, // removed_ratio + {raft::distance::DistanceType::InnerProduct}); inline const std::vector kAllStrategies{ TransferStrategy::NO_COPY, TransferStrategy::MAP_PINNED, TransferStrategy::MANAGED}; diff --git a/cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu b/cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu index 13e2c1febd..7cb0b45d30 100644 --- a/cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu +++ b/cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu @@ -19,7 +19,11 @@ namespace raft::bench::spatial { -KNN_REGISTER( - float, int64_t, brute_force_filter_knn, kInputsBruteForceFilter, kNoCopyOnly, kScopeOnlySearch); +KNN_REGISTER(float, + int64_t, + brute_force_filter_knn, + kInputsBruteForceFilterExtra, + kNoCopyOnly, + kScopeOnlySearch); } // namespace raft::bench::spatial diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index ea1ef07b61..e45781f357 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -60,6 +60,12 @@ _RAFT_HOST_DEVICE void bitset_view::set(const index_t sample_ } } +template +_RAFT_HOST_DEVICE inline index_t bitset_view::n_elements() const +{ + return raft::ceildiv(bitset_len_, bitset_element_size); +} + template bitset::bitset(const raft::resources& res, raft::device_vector_view mask_index, diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 0f3da3f34b..747fcb11db 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -616,10 +616,11 @@ void brute_force_search( raft::sparse::distance::detail::faster_dot_on_csr(res, csr.get_elements().data(), compressed_csr_view.get_nnz(), - rows.data(), + compressed_csr_view.get_indptr().data(), compressed_csr_view.get_indices().data(), queries.data_handle(), idx.dataset().data_handle(), + compressed_csr_view.get_n_rows(), dim); // post process diff --git a/cpp/include/raft/sparse/distance/detail/utils.cuh b/cpp/include/raft/sparse/distance/detail/utils.cuh index 69220183c7..021fef6aba 100644 --- a/cpp/include/raft/sparse/distance/detail/utils.cuh +++ b/cpp/include/raft/sparse/distance/detail/utils.cuh @@ -20,6 +20,7 @@ #include #include +#include namespace raft { namespace sparse { @@ -68,6 +69,7 @@ void epilogue_on_csr(raft::resources const& handle, const value_t* R_sq_norms, raft::distance::DistanceType metric) { + if (nnz == 0) return; auto stream = resource::get_cuda_stream(handle); int blocks = raft::ceildiv((size_t)nnz, tpb); @@ -105,82 +107,132 @@ void epilogue_on_csr(raft::resources const& handle, return value_t(1.0) - dot / (q_norm * r_norm); }); } + RAFT_CUDA_TRY(cudaPeekAtLastError()); } -template -__inline__ __device__ value_t warpReduceSum(value_t val) -{ - return val; -} - -template -RAFT_KERNEL faster_dot_on_csr_kernel(value_t* __restrict__ dot, - const value_idx* __restrict__ rows, - const value_idx* __restrict__ cols, - const value_t* __restrict__ A, - const value_t* __restrict__ B, - const value_idx nnz, - const value_idx dim) +template +__global__ void faster_dot_on_csr_kernel(value_t* __restrict__ dot, + const value_idx* __restrict__ indptr, + const value_idx* __restrict__ cols, + const value_t* __restrict__ A, + const value_t* __restrict__ B, + const value_idx nnz, + const value_idx n_rows, + const value_idx dim) { - auto dot_id = blockIdx.x; auto vec_id = threadIdx.x; auto lane_id = threadIdx.x & 0x1f; - const value_idx row = rows[dot_id] * dim; - const value_idx col = cols[dot_id] * dim; - __shared__ value_t g_dot_; + extern __shared__ char smem[]; + value_t* s_A = (value_t*)smem; + value_idx cur_row = -1; - if (threadIdx.x == 0) { g_dot_ = 0.0; } - __syncthreads(); +#pragma unroll + for (int row = blockIdx.x; row < n_rows; row += gridDim.x) { +#pragma unroll + for (int dot_id = blockIdx.y + indptr[row]; dot_id < indptr[row + 1]; dot_id += gridDim.y) { + if (dot_id >= nnz) { return; } + const value_idx col = cols[dot_id] * dim; + const value_t* __restrict__ B_col = B + col; - value_t l_dot_ = 0.0; + if (threadIdx.x == 0) { dot[dot_id] = 0.0; } + __syncthreads(); + if (cur_row != row) { #pragma unroll - for (value_idx k = vec_id; k < dim; k += blockDim.x) { - l_dot_ += A[row + k] * B[col + k]; - } + for (value_idx k = vec_id; k < dim; k += blockDim.x) { + s_A[k] = A[row * dim + k]; + } + cur_row = row; + } + value_t l_dot_ = 0.0; #pragma unroll - for (int offset = warpSize / 2; offset > 0; offset /= 2) { - l_dot_ += __shfl_down_sync(0xffffffff, l_dot_, offset); + for (value_idx k = vec_id; k < dim; k += blockDim.x) { + asm("prefetch.global.L2 [%0];" ::"l"(B_col + k + blockDim.x)); + l_dot_ += s_A[k] * __ldcg(B_col + k); + } + l_dot_ += __shfl_down_sync(0xffffffff, l_dot_, 16); + l_dot_ += __shfl_down_sync(0xffff, l_dot_, 8); + l_dot_ += __shfl_down_sync(0xff, l_dot_, 4); + l_dot_ += __shfl_down_sync(0xf, l_dot_, 2); + l_dot_ += __shfl_down_sync(0x3, l_dot_, 1); + + if (lane_id == 0) { atomicAdd_block(dot + dot_id, l_dot_); } + } } - - if (lane_id == 0) { atomicAdd_block(&g_dot_, l_dot_); } - __syncthreads(); - - if (threadIdx.x == 0) { dot[dot_id] = g_dot_; } } template void faster_dot_on_csr(raft::resources const& handle, value_t* dot, const value_idx nnz, - const value_idx* rows, + const value_idx* indptr, const value_idx* cols, const value_t* A, const value_t* B, + const value_idx n_rows, const value_idx dim) { + if (nnz == 0 || n_rows == 0) return; + auto stream = resource::get_cuda_stream(handle); - int blocks = int(nnz); + constexpr value_idx MAX_ROW_PER_ITER = 500; + int dev_id, sm_count, blocks_per_sm; + + const int smem_size = dim * sizeof(value_t); + cudaGetDevice(&dev_id); + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + if (dim < 128) { constexpr int tpb = 64; - faster_dot_on_csr_kernel - <<>>(dot, rows, cols, A, B, nnz, dim); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); + auto block_y = + (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; + dim3 blocks(block_x, block_y, 1); + + faster_dot_on_csr_kernel + <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); + } else if (dim < 256) { constexpr int tpb = 128; - faster_dot_on_csr_kernel - <<>>(dot, rows, cols, A, B, nnz, dim); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); + auto block_y = + (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; + dim3 blocks(block_x, block_y, 1); + + faster_dot_on_csr_kernel + <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); } else if (dim < 512) { constexpr int tpb = 256; - faster_dot_on_csr_kernel - <<>>(dot, rows, cols, A, B, nnz, dim); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); + auto block_y = + (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; + dim3 blocks(block_x, block_y, 1); + + faster_dot_on_csr_kernel + <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); } else { constexpr int tpb = 512; - faster_dot_on_csr_kernel - <<>>(dot, rows, cols, A, B, nnz, dim); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); + auto block_y = + (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; + dim3 blocks(block_x, block_y, 1); + + faster_dot_on_csr_kernel + <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); } + + RAFT_CUDA_TRY(cudaPeekAtLastError()); } } // namespace detail } // namespace distance diff --git a/cpp/test/sparse/neighbors/prefiltered_brute_force.cu b/cpp/test/sparse/neighbors/prefiltered_brute_force.cu index 990f012ec0..916bc791c5 100644 --- a/cpp/test/sparse/neighbors/prefiltered_brute_force.cu +++ b/cpp/test/sparse/neighbors/prefiltered_brute_force.cu @@ -29,6 +29,7 @@ #include #include #include +#include #include @@ -443,8 +444,25 @@ const std::vector> selectk_inputs = { {1000, 10000, 5, 16, 0.5, raft::distance::DistanceType::CosineExpanded}, {1000, 10000, 8, 16, 0.2, raft::distance::DistanceType::CosineExpanded}}; +template +const std::vector> selectk_inputs_extra = + raft::util::itertools::product>( + {index_t(1), index_t(10), index_t(1000)}, // n_queries + {index_t(10 * 1024), index_t(100 * 1024)}, // n_dataset + {index_t(128), index_t(256), index_t(768), index_t(4096)}, // n_dim + {index_t(1), index_t(255), index_t(1024)}, // k + {float(0.0), float(0.2), float(0.01)}, // sparsity + {raft::distance::DistanceType::InnerProduct, + raft::distance::DistanceType::L2Expanded, + raft::distance::DistanceType::L2SqrtExpanded, + raft::distance::DistanceType::CosineExpanded}); + INSTANTIATE_TEST_CASE_P(PrefilteredBruteForceTest, PrefilteredBruteForceTest_float_int64, ::testing::ValuesIn(selectk_inputs)); +INSTANTIATE_TEST_CASE_P(PrefilteredBruteForceExtraTest, + PrefilteredBruteForceTest_float_int64, + ::testing::ValuesIn(selectk_inputs_extra)); + } // namespace raft::neighbors::brute_force From b4971c26007e5b56eeb197ebf92226918b23884c Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Mon, 13 May 2024 18:44:29 +0200 Subject: [PATCH 04/15] Update cpp/include/raft/sparse/distance/detail/utils.cuh --- cpp/include/raft/sparse/distance/detail/utils.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/sparse/distance/detail/utils.cuh b/cpp/include/raft/sparse/distance/detail/utils.cuh index 021fef6aba..bbce80c7a4 100644 --- a/cpp/include/raft/sparse/distance/detail/utils.cuh +++ b/cpp/include/raft/sparse/distance/detail/utils.cuh @@ -111,7 +111,7 @@ void epilogue_on_csr(raft::resources const& handle, } template -__global__ void faster_dot_on_csr_kernel(value_t* __restrict__ dot, +RAFT_KERNEL void faster_dot_on_csr_kernel(value_t* __restrict__ dot, const value_idx* __restrict__ indptr, const value_idx* __restrict__ cols, const value_t* __restrict__ A, From 77ee4a6960f2f5ddd0a6f07086a833d656761cfc Mon Sep 17 00:00:00 2001 From: rhdong Date: Tue, 14 May 2024 05:04:07 -0700 Subject: [PATCH 05/15] Test cases adjustment --- cpp/bench/prims/neighbors/knn.cuh | 17 ++++++++++------- .../knn/brute_force_filter_float_int64_t.cu | 2 ++ .../raft/sparse/distance/detail/utils.cuh | 16 ++++++++-------- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index 4f8e62b1da..3b69288265 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -599,13 +599,16 @@ const std::vector kInputsFilter = {0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio ); -const std::vector kInputsBruteForceFilter = raft::util::itertools::product( - {size_t(10 * 1024 * 1024)}, // n_samples - {size_t(256), size_t(768), size_t(1024), size_t(2048), size_t(4096)}, // n_dim - {size_t(10), size_t(1000)}, // n_queries - {size_t(255)}, // k - {0.0, 0.8, 0.9, 0.99}, // removed_ratio - {raft::distance::DistanceType::InnerProduct}); +const std::vector kInputsBruteForceFilter = + raft::util::itertools::product({size_t(1 * 1024 * 1024)}, // n_samples + {size_t(256), size_t(2051)}, // n_dim + {size_t(1000)}, // n_queries + {size_t(1), size_t(255)}, // k + {0.0, 0.8, 0.99}, // removed_ratio + {raft::distance::DistanceType::InnerProduct, + raft::distance::DistanceType::L2Expanded, + raft::distance::DistanceType::L2SqrtExpanded, + raft::distance::DistanceType::CosineExpanded}); const std::vector kInputsBruteForceFilterExtra = raft::util::itertools::product({size_t(1024 * 1024)}, // n_samples diff --git a/cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu b/cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu index 7cb0b45d30..e4e39a8de9 100644 --- a/cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu +++ b/cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu @@ -18,6 +18,8 @@ #include "../knn.cuh" namespace raft::bench::spatial { +KNN_REGISTER( + float, int64_t, brute_force_filter_knn, kInputsBruteForceFilter, kNoCopyOnly, kScopeOnlySearch); KNN_REGISTER(float, int64_t, diff --git a/cpp/include/raft/sparse/distance/detail/utils.cuh b/cpp/include/raft/sparse/distance/detail/utils.cuh index bbce80c7a4..6122641f7d 100644 --- a/cpp/include/raft/sparse/distance/detail/utils.cuh +++ b/cpp/include/raft/sparse/distance/detail/utils.cuh @@ -111,14 +111,14 @@ void epilogue_on_csr(raft::resources const& handle, } template -RAFT_KERNEL void faster_dot_on_csr_kernel(value_t* __restrict__ dot, - const value_idx* __restrict__ indptr, - const value_idx* __restrict__ cols, - const value_t* __restrict__ A, - const value_t* __restrict__ B, - const value_idx nnz, - const value_idx n_rows, - const value_idx dim) +RAFT_KERNEL faster_dot_on_csr_kernel(value_t* __restrict__ dot, + const value_idx* __restrict__ indptr, + const value_idx* __restrict__ cols, + const value_t* __restrict__ A, + const value_t* __restrict__ B, + const value_idx nnz, + const value_idx n_rows, + const value_idx dim) { auto vec_id = threadIdx.x; auto lane_id = threadIdx.x & 0x1f; From cc2b228f225354ddb728c43d48a54bd30a59d232 Mon Sep 17 00:00:00 2001 From: hrong Date: Wed, 15 May 2024 09:28:28 -0700 Subject: [PATCH 06/15] Merge SDDMM with customized kernel, optimize bitset count --- cpp/include/raft/core/bitset.cuh | 34 ++-------------- .../raft/neighbors/detail/knn_brute_force.cuh | 40 ++++++++++++++----- 2 files changed, 33 insertions(+), 41 deletions(-) diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index e45781f357..d7eedee92e 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -167,37 +167,9 @@ template void bitset::count(const raft::resources& res, raft::device_scalar_view count_gpu_scalar) { - auto n_elements_ = n_elements(); - auto count_gpu = - raft::make_device_vector_view(count_gpu_scalar.data_handle(), 1); - auto bitset_matrix_view = raft::make_device_matrix_view( - bitset_.data(), n_elements_, 1); - - bitset_t n_last_element = (bitset_len_ % bitset_element_size); - bitset_t last_element_mask = - n_last_element ? (bitset_t)((bitset_t{1} << n_last_element) - bitset_t{1}) : ~bitset_t{0}; - raft::linalg::coalesced_reduction( - res, - bitset_matrix_view, - count_gpu, - index_t{0}, - false, - [last_element_mask, n_elements_] __device__(bitset_t element, index_t index) { - index_t result = 0; - if constexpr (bitset_element_size == 64) { - if (index == n_elements_ - 1) - result = index_t(raft::detail::popc(element & last_element_mask)); - else - result = index_t(raft::detail::popc(element)); - } else { // Needed because popc is not overloaded for 16 and 8 bit elements - if (index == n_elements_ - 1) - result = index_t(raft::detail::popc(uint32_t{element} & last_element_mask)); - else - result = index_t(raft::detail::popc(uint32_t{element})); - } - - return result; - }); + auto values = + raft::make_device_vector_view(bitset_.data(), n_elements()); + raft::detail::popc(res, values, bitset_len_, count_gpu_scalar); } } // end namespace raft::core diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 747fcb11db..1491f77e63 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -612,16 +612,36 @@ void brute_force_search( rows.data(), compressed_csr_view.get_nnz(), stream); - - raft::sparse::distance::detail::faster_dot_on_csr(res, - csr.get_elements().data(), - compressed_csr_view.get_nnz(), - compressed_csr_view.get_indptr().data(), - compressed_csr_view.get_indices().data(), - queries.data_handle(), - idx.dataset().data_handle(), - compressed_csr_view.get_n_rows(), - dim); + if (n_queries > 10 || (1.0f * nnz_h / (1.0f * n_queries * n_dataset)) > 0.01f) { + auto csr_view = make_device_csr_matrix_view(csr.get_elements().data(), + compressed_csr_view); + + // create dataset view + auto dataset_view = raft::make_device_matrix_view( + idx.dataset().data_handle(), dim, n_dataset); + + // calc dot + T alpha = static_cast(1.0f); + T beta = static_cast(0.0f); + raft::sparse::linalg::sddmm(res, + queries, + dataset_view, + csr_view, + raft::linalg::Operation::NON_TRANSPOSE, + raft::linalg::Operation::NON_TRANSPOSE, + raft::make_host_scalar_view(&alpha), + raft::make_host_scalar_view(&beta)); + } else { + raft::sparse::distance::detail::faster_dot_on_csr(res, + csr.get_elements().data(), + compressed_csr_view.get_nnz(), + compressed_csr_view.get_indptr().data(), + compressed_csr_view.get_indices().data(), + queries.data_handle(), + idx.dataset().data_handle(), + compressed_csr_view.get_n_rows(), + dim); + } // post process std::optional> query_norms_; From 2684afe3bc9663385b2f27e7f347af92fb8e10d2 Mon Sep 17 00:00:00 2001 From: hrong Date: Mon, 20 May 2024 11:57:57 -0700 Subject: [PATCH 07/15] Optimize by dense bfknn --- cpp/bench/prims/neighbors/knn.cuh | 17 +++--- .../raft/neighbors/detail/knn_brute_force.cuh | 60 ++++++++++++++++++- .../neighbors/prefiltered_brute_force.cu | 7 ++- 3 files changed, 69 insertions(+), 15 deletions(-) diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index 3b69288265..64bda8b36d 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -599,16 +599,13 @@ const std::vector kInputsFilter = {0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio ); -const std::vector kInputsBruteForceFilter = - raft::util::itertools::product({size_t(1 * 1024 * 1024)}, // n_samples - {size_t(256), size_t(2051)}, // n_dim - {size_t(1000)}, // n_queries - {size_t(1), size_t(255)}, // k - {0.0, 0.8, 0.99}, // removed_ratio - {raft::distance::DistanceType::InnerProduct, - raft::distance::DistanceType::L2Expanded, - raft::distance::DistanceType::L2SqrtExpanded, - raft::distance::DistanceType::CosineExpanded}); +const std::vector kInputsBruteForceFilter = raft::util::itertools::product( + {size_t(10000000), size_t(1 * 1024 * 1024)}, // n_samples + {size_t(256), size_t(2048)}, // n_dim + {size_t(1), size_t(10), size_t(100), size_t(1000)}, // n_queries + {size_t(256)}, // k + {0.0, 0.8, 0.99}, // removed_ratio + {raft::distance::DistanceType::InnerProduct}); const std::vector kInputsBruteForceFilterExtra = raft::util::itertools::product({size_t(1024 * 1024)}, // n_samples diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 1491f77e63..13eae9f2c2 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -81,7 +81,8 @@ void tiled_brute_force_knn(const raft::resources& handle, size_t max_col_tile_size = 0, DistanceEpilogue distance_epilogue = raft::identity_op(), const ElementType* precomputed_index_norms = nullptr, - const ElementType* precomputed_search_norms = nullptr) + const ElementType* precomputed_search_norms = nullptr, + const uint32_t* filter_bitmap = nullptr) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -246,6 +247,27 @@ void tiled_brute_force_knn(const raft::resources& handle, } } + if (filter_bitmap != nullptr) { + auto distances_ptr = temp_distances.data(); + auto count = thrust::make_counting_iterator(0); + ElementType masked_distance = select_min ? std::numeric_limits::infinity() + : std::numeric_limits::lowest(); + thrust::for_each(resource::get_thrust_policy(handle), + count, + count + current_query_size * current_centroid_size, + [=] __device__(IndexType idx) { + IndexType row = i + (idx / current_centroid_size); + IndexType col = j + (idx % current_centroid_size); + IndexType g_idx = row * n + col; + IndexType item_idx = (g_idx) >> 5; + uint32_t bit_idx = (g_idx)&31; + uint32_t filter = filter_bitmap[item_idx]; + if ((filter & (uint32_t(1) << bit_idx)) == 0) { + distances_ptr[idx] = masked_distance; + } + }); + } + matrix::select_k( handle, raft::make_device_matrix_view( @@ -587,6 +609,7 @@ void brute_force_search( IdxT n_queries = queries.extent(0); IdxT n_dataset = idx.dataset().extent(0); IdxT dim = idx.dataset().extent(1); + IdxT k = neighbors.extent(1); auto stream = resource::get_cuda_stream(res); @@ -599,6 +622,35 @@ void brute_force_search( raft::detail::popc(res, filter_view, n_queries * n_dataset, nnz_view); raft::copy(&nnz_h, nnz.data(), 1, stream); + + resource::sync_stream(res, stream); + float sparsity = (1.0f * nnz_h / (1.0f * n_queries * n_dataset)); + + if (sparsity > 0.01f) { + raft::resources stream_pool_handle(res); + raft::resource::set_cuda_stream(stream_pool_handle, stream); + auto idx_norm = idx.has_norms() ? const_cast(idx.norms().data_handle()) : nullptr; + + tiled_brute_force_knn(stream_pool_handle, + queries.data_handle(), + idx.dataset().data_handle(), + n_queries, + n_dataset, + dim, + k, + distances.data_handle(), + neighbors.data_handle(), + metric, + 2.0, + 0, + 0, + raft::identity_op(), + idx_norm, + nullptr, + filter.data()); + return; + } + auto csr = raft::make_device_csr_matrix(res, n_queries, n_dataset, nnz_h); // fill csr @@ -612,7 +664,7 @@ void brute_force_search( rows.data(), compressed_csr_view.get_nnz(), stream); - if (n_queries > 10 || (1.0f * nnz_h / (1.0f * n_queries * n_dataset)) > 0.01f) { + if (n_queries > 10) { auto csr_view = make_device_csr_matrix_view(csr.get_elements().data(), compressed_csr_view); @@ -688,7 +740,9 @@ void brute_force_search( auto const_csr_view = make_device_csr_matrix_view( csr.get_elements().data(), compressed_csr_view); std::optional> no_opt = std::nullopt; - raft::sparse::matrix::select_k(res, const_csr_view, no_opt, distances, neighbors, true, true); + bool select_min = raft::distance::is_min_close(metric); + raft::sparse::matrix::select_k( + res, const_csr_view, no_opt, distances, neighbors, select_min, true); return; } diff --git a/cpp/test/sparse/neighbors/prefiltered_brute_force.cu b/cpp/test/sparse/neighbors/prefiltered_brute_force.cu index 916bc791c5..d08b564c7c 100644 --- a/cpp/test/sparse/neighbors/prefiltered_brute_force.cu +++ b/cpp/test/sparse/neighbors/prefiltered_brute_force.cu @@ -312,8 +312,11 @@ class PrefilteredBruteForceTest cpu_sddmm(queries_h, dataset_h, values_h, indices_h, indptr_h, true, false); + bool select_min = raft::distance::is_min_close(params.metric); + std::vector out_val_h(params.n_queries * params.top_k, - std::numeric_limits::infinity()); + select_min ? std::numeric_limits::infinity() + : std::numeric_limits::lowest()); std::vector out_idx_h(params.n_queries * params.top_k, static_cast(0)); out_val_d.resize(params.n_queries * params.top_k, stream); @@ -336,7 +339,7 @@ class PrefilteredBruteForceTest params.top_k, out_val_h, out_idx_h, - params.select_min); + select_min); out_val_expected_d.resize(params.n_queries * params.top_k, stream); out_idx_expected_d.resize(params.n_queries * params.top_k, stream); From 8e1217ca94f8c620a3892e0efd696779e57a839d Mon Sep 17 00:00:00 2001 From: hrong Date: Mon, 20 May 2024 15:19:00 -0700 Subject: [PATCH 08/15] Optimize the test cases --- cpp/bench/prims/neighbors/knn.cuh | 2 +- cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index 64bda8b36d..026d74f74b 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -601,7 +601,7 @@ const std::vector kInputsFilter = const std::vector kInputsBruteForceFilter = raft::util::itertools::product( {size_t(10000000), size_t(1 * 1024 * 1024)}, // n_samples - {size_t(256), size_t(2048)}, // n_dim + {size_t(256)}, // n_dim {size_t(1), size_t(10), size_t(100), size_t(1000)}, // n_queries {size_t(256)}, // k {0.0, 0.8, 0.99}, // removed_ratio diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 13eae9f2c2..8d0e5e539d 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -626,7 +626,7 @@ void brute_force_search( resource::sync_stream(res, stream); float sparsity = (1.0f * nnz_h / (1.0f * n_queries * n_dataset)); - if (sparsity > 0.01f) { + if (sparsity > 0.011f) { raft::resources stream_pool_handle(res); raft::resource::set_cuda_stream(stream_pool_handle, stream); auto idx_norm = idx.has_norms() ? const_cast(idx.norms().data_handle()) : nullptr; From 96f4e83b85ca61a704d46b48d69dcea4e411b9d9 Mon Sep 17 00:00:00 2001 From: hrong Date: Tue, 21 May 2024 12:12:30 -0700 Subject: [PATCH 09/15] Splitting(revert) the cuVS part --- cpp/bench/prims/CMakeLists.txt | 1 - cpp/bench/prims/neighbors/knn.cuh | 134 +---- .../knn/brute_force_filter_float_int64_t.cu | 31 -- .../raft/neighbors/brute_force-ext.cuh | 25 - .../raft/neighbors/brute_force-inl.cuh | 29 +- .../raft/neighbors/detail/knn_brute_force.cuh | 201 +------- .../neighbors/brute_force_knn_index_float.cu | 16 - .../neighbors/prefiltered_brute_force.cu | 471 ------------------ 8 files changed, 8 insertions(+), 900 deletions(-) delete mode 100644 cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu delete mode 100644 cpp/test/sparse/neighbors/prefiltered_brute_force.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 471a849888..0771a60e58 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -145,7 +145,6 @@ if(BUILD_PRIMS_BENCH) NAME NEIGHBORS_BENCH PATH - neighbors/knn/brute_force_filter_float_int64_t.cu neighbors/knn/brute_force_float_int64_t.cu neighbors/knn/brute_force_float_uint32_t.cu neighbors/knn/cagra_float_uint32_t.cu diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index 026d74f74b..6499078623 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -20,7 +20,6 @@ #include #include -#include #include #include #include @@ -37,10 +36,7 @@ #include -#include #include -#include -#include namespace raft::bench::spatial { @@ -55,24 +51,12 @@ struct params { size_t k; /** Ratio of removed indices. */ double removed_ratio; - /** Distance Type. */ - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; }; inline auto operator<<(std::ostream& os, const params& p) -> std::ostream& { - os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k; - if (p.removed_ratio > 0.0) { - os << "#" << p.removed_ratio; - } else { - os << "#" - << "[No filtered]"; - } - switch (p.metric) { - case raft::distance::DistanceType::InnerProduct: os << "#InnerProduct"; break; - case raft::distance::DistanceType::L2Expanded: os << "#L2Expanded"; break; - default: os << "UNKNOWN DistanceType, please add one case here."; - } + os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k << "#" + << p.removed_ratio; return os; } @@ -165,7 +149,7 @@ struct ivf_flat_knn { ivf_flat_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps) { index_params.n_lists = 4096; - index_params.metric = ps.metric; + index_params.metric = raft::distance::DistanceType::L2Expanded; index.emplace(raft::neighbors::ivf_flat::build( handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); } @@ -200,7 +184,7 @@ struct ivf_pq_knn { ivf_pq_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps) { index_params.n_lists = 4096; - index_params.metric = ps.metric; + index_params.metric = raft::distance::DistanceType::L2Expanded; auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view)); } @@ -252,88 +236,6 @@ struct brute_force_knn { } }; -template -RAFT_KERNEL initialize_random_bits( - bitmap_t* data, IdxT N, float sparsity, size_t total_bits, unsigned long seed) -{ - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= N) return; - - curandState state; - curand_init(seed, idx, 0, &state); - - bitmap_t value = 0; - for (int i = 0; i < sizeof(bitmap_t) * 8; i++) { - int rnd = curand(&state) % 10000; - - if (rnd < int(10000 * sparsity) && (idx * sizeof(bitmap_t) * 8 + i < total_bits)) { - bitmap_t bit_mask = 1u << i; - value |= bit_mask; - } - } - data[idx] = value; -} - -template -struct brute_force_filter_knn { - using dist_t = float; - using bitmap_t = std::uint32_t; - - std::optional> index; - raft::neighbors::brute_force::index_params index_params; - raft::neighbors::brute_force::search_params search_params; - raft::core::bitset removed_indices_bitset_; - params ps; - - brute_force_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data) - : ps(ps), removed_indices_bitset_(handle, ps.n_samples * ps.n_queries) - { - auto stream = resource::get_cuda_stream(handle); - index_params.metric = ps.metric; - - auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); - index.emplace(raft::neighbors::brute_force::build(handle, index_params, data_view)); - - IdxT element = raft::ceildiv(IdxT(ps.n_samples * ps.n_queries), IdxT(sizeof(bitmap_t) * 8)); - - size_t threadsPerBlock = 256; - size_t numBlocks = (element + threadsPerBlock - 1) / threadsPerBlock; - unsigned long seed = 1234; - initialize_random_bits<<>>( - removed_indices_bitset_.data(), - removed_indices_bitset_.size(), - float(1.0 - ps.removed_ratio), - ps.n_samples * ps.n_queries, - seed); - - resource::sync_stream(handle); - } - - void search(const raft::device_resources& handle, - const ValT* search_items, - ValT* out_dists, - IdxT* out_idxs) - { - auto queries_view = - raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); - auto neighbors_view = - raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); - auto distance_view = - raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); - - if (ps.removed_ratio > 0) { - auto filter = raft::core::bitmap_view( - (const bitmap_t*)removed_indices_bitset_.data(), IdxT(ps.n_queries), IdxT(ps.n_samples)); - - raft::neighbors::brute_force::search_with_filtering( - handle, *index, queries_view, filter, neighbors_view, distance_view); - } else { - raft::neighbors::brute_force::search( - handle, search_params, *index, queries_view, neighbors_view, distance_view); - } - } -}; - template struct ivf_flat_filter_knn { using dist_t = float; @@ -348,7 +250,7 @@ struct ivf_flat_filter_knn { : ps(ps), removed_indices_bitset_(handle, ps.n_samples) { index_params.n_lists = 4096; - index_params.metric = ps.metric; + index_params.metric = raft::distance::DistanceType::L2Expanded; index.emplace(raft::neighbors::ivf_flat::build( handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); auto removed_indices = @@ -396,7 +298,7 @@ struct ivf_pq_filter_knn { : ps(ps), removed_indices_bitset_(handle, ps.n_samples) { index_params.n_lists = 4096; - index_params.metric = ps.metric; + index_params.metric = raft::distance::DistanceType::L2Expanded; auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view)); auto removed_indices = @@ -598,34 +500,10 @@ const std::vector kInputsFilter = {size_t(255)}, // k {0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio ); - -const std::vector kInputsBruteForceFilter = raft::util::itertools::product( - {size_t(10000000), size_t(1 * 1024 * 1024)}, // n_samples - {size_t(256)}, // n_dim - {size_t(1), size_t(10), size_t(100), size_t(1000)}, // n_queries - {size_t(256)}, // k - {0.0, 0.8, 0.99}, // removed_ratio - {raft::distance::DistanceType::InnerProduct}); - -const std::vector kInputsBruteForceFilterExtra = - raft::util::itertools::product({size_t(1024 * 1024)}, // n_samples - {size_t(256), size_t(768)}, // n_dim - {size_t(10), - size_t(20), - size_t(40), - size_t(60), - size_t(80), - size_t(100), - size_t(300)}, // n_queries - {size_t(255)}, // k - {0.3, 0.4, 0.9}, // removed_ratio - {raft::distance::DistanceType::InnerProduct}); - inline const std::vector kAllStrategies{ TransferStrategy::NO_COPY, TransferStrategy::MAP_PINNED, TransferStrategy::MANAGED}; inline const std::vector kNoCopyOnly{TransferStrategy::NO_COPY}; -inline const std::vector kScopeOnlySearch{Scope::SEARCH}; inline const std::vector kScopeFull{Scope::BUILD_SEARCH}; inline const std::vector kAllScopes{Scope::BUILD_SEARCH, Scope::SEARCH, Scope::BUILD}; diff --git a/cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu b/cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu deleted file mode 100644 index e4e39a8de9..0000000000 --- a/cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter -#include "../knn.cuh" - -namespace raft::bench::spatial { -KNN_REGISTER( - float, int64_t, brute_force_filter_knn, kInputsBruteForceFilter, kNoCopyOnly, kScopeOnlySearch); - -KNN_REGISTER(float, - int64_t, - brute_force_filter_knn, - kInputsBruteForceFilterExtra, - kNoCopyOnly, - kScopeOnlySearch); - -} // namespace raft::bench::spatial diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh index c033077e51..4055c253c8 100644 --- a/cpp/include/raft/neighbors/brute_force-ext.cuh +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -16,7 +16,6 @@ #pragma once -#include #include // raft::device_matrix_view #include // raft::identity_op #include // raft::resources @@ -66,14 +65,6 @@ void search(raft::resources const& res, raft::device_matrix_view neighbors, raft::device_matrix_view distances) RAFT_EXPLICIT; -template -void search_with_filtering(raft::resources const& res, - const index& idx, - raft::device_matrix_view queries, - raft::core::bitmap_view filter, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) RAFT_EXPLICIT; - template ( raft::device_matrix_view neighbors, raft::device_matrix_view distances); -extern template void search_with_filtering( - raft::resources const& res, - const raft::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, - raft::core::bitmap_view filter, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances); - extern template void search( raft::resources const& res, const raft::neighbors::brute_force::index& idx, @@ -169,14 +152,6 @@ extern template void search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); -extern template void search_with_filtering( - raft::resources const& res, - const raft::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, - raft::core::bitmap_view filter, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances); - extern template raft::neighbors::brute_force::index build( raft::resources const& res, raft::device_matrix_view dataset, diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh index c10a45d3b4..f955cc8518 100644 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ #pragma once -#include #include #include #include @@ -450,31 +449,5 @@ void search(raft::resources const& res, raft::neighbors::detail::brute_force_search(res, idx, queries, neighbors, distances); } -/** - * @brief Brute Force search with filter using the constructed index. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] res raft resources - * @param[in] idx brute force index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[in] sample_filter a device filter function that green lights samples for a given query - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ -template -void search_with_filtering(raft::resources const& res, - const index& idx, - raft::device_matrix_view queries, - raft::core::bitmap_view sample_filter, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - raft::neighbors::detail::brute_force_search( - res, idx, queries, sample_filter, neighbors, distances); -} /** @} */ // end group brute_force_knn } // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 8d0e5e539d..daa2798b00 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -16,9 +16,6 @@ #pragma once -#include -#include -#include #include #include #include @@ -28,18 +25,12 @@ #include #include #include -#include #include #include #include #include #include #include -#include -#include -#include -#include -#include #include #include #include @@ -81,8 +72,7 @@ void tiled_brute_force_knn(const raft::resources& handle, size_t max_col_tile_size = 0, DistanceEpilogue distance_epilogue = raft::identity_op(), const ElementType* precomputed_index_norms = nullptr, - const ElementType* precomputed_search_norms = nullptr, - const uint32_t* filter_bitmap = nullptr) + const ElementType* precomputed_search_norms = nullptr) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -247,27 +237,6 @@ void tiled_brute_force_knn(const raft::resources& handle, } } - if (filter_bitmap != nullptr) { - auto distances_ptr = temp_distances.data(); - auto count = thrust::make_counting_iterator(0); - ElementType masked_distance = select_min ? std::numeric_limits::infinity() - : std::numeric_limits::lowest(); - thrust::for_each(resource::get_thrust_policy(handle), - count, - count + current_query_size * current_centroid_size, - [=] __device__(IndexType idx) { - IndexType row = i + (idx / current_centroid_size); - IndexType col = j + (idx % current_centroid_size); - IndexType g_idx = row * n + col; - IndexType item_idx = (g_idx) >> 5; - uint32_t bit_idx = (g_idx)&31; - uint32_t filter = filter_bitmap[item_idx]; - if ((filter & (uint32_t(1) << bit_idx)) == 0) { - distances_ptr[idx] = masked_distance; - } - }); - } - matrix::select_k( handle, raft::make_device_matrix_view( @@ -579,172 +548,4 @@ void brute_force_search( norms.size() ? &norms : nullptr, query_norms ? query_norms->data_handle() : nullptr); } - -template -void brute_force_search( - raft::resources const& res, - const raft::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, - raft::core::bitmap_view filter, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - std::optional> query_norms = std::nullopt) -{ - auto metric = idx.metric(); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); - RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), - "Number of columns in queries must match brute force index"); - RAFT_EXPECTS(metric == raft::distance::DistanceType::InnerProduct || - metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded || - metric == raft::distance::DistanceType::CosineExpanded, - "Only Euclidean, IP, and Cosine are supported!"); - - RAFT_EXPECTS(idx.has_norms() || !(metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded || - metric == raft::distance::DistanceType::CosineExpanded), - "Index must has norms when using Euclidean, IP, and Cosine!"); - - IdxT n_queries = queries.extent(0); - IdxT n_dataset = idx.dataset().extent(0); - IdxT dim = idx.dataset().extent(1); - IdxT k = neighbors.extent(1); - - auto stream = resource::get_cuda_stream(res); - - // calc nnz - IdxT nnz_h = 0; - rmm::device_scalar nnz(0, stream); - auto nnz_view = make_device_scalar_view(nnz.data()); - auto filter_view = - raft::make_device_vector_view(filter.data(), filter.n_elements()); - - raft::detail::popc(res, filter_view, n_queries * n_dataset, nnz_view); - raft::copy(&nnz_h, nnz.data(), 1, stream); - - resource::sync_stream(res, stream); - float sparsity = (1.0f * nnz_h / (1.0f * n_queries * n_dataset)); - - if (sparsity > 0.011f) { - raft::resources stream_pool_handle(res); - raft::resource::set_cuda_stream(stream_pool_handle, stream); - auto idx_norm = idx.has_norms() ? const_cast(idx.norms().data_handle()) : nullptr; - - tiled_brute_force_knn(stream_pool_handle, - queries.data_handle(), - idx.dataset().data_handle(), - n_queries, - n_dataset, - dim, - k, - distances.data_handle(), - neighbors.data_handle(), - metric, - 2.0, - 0, - 0, - raft::identity_op(), - idx_norm, - nullptr, - filter.data()); - return; - } - - auto csr = raft::make_device_csr_matrix(res, n_queries, n_dataset, nnz_h); - - // fill csr - raft::sparse::convert::bitmap_to_csr(res, filter, csr); - - // create filter csr view - auto compressed_csr_view = csr.structure_view(); - rmm::device_uvector rows(compressed_csr_view.get_nnz(), stream); - raft::sparse::convert::csr_to_coo(compressed_csr_view.get_indptr().data(), - compressed_csr_view.get_n_rows(), - rows.data(), - compressed_csr_view.get_nnz(), - stream); - if (n_queries > 10) { - auto csr_view = make_device_csr_matrix_view(csr.get_elements().data(), - compressed_csr_view); - - // create dataset view - auto dataset_view = raft::make_device_matrix_view( - idx.dataset().data_handle(), dim, n_dataset); - - // calc dot - T alpha = static_cast(1.0f); - T beta = static_cast(0.0f); - raft::sparse::linalg::sddmm(res, - queries, - dataset_view, - csr_view, - raft::linalg::Operation::NON_TRANSPOSE, - raft::linalg::Operation::NON_TRANSPOSE, - raft::make_host_scalar_view(&alpha), - raft::make_host_scalar_view(&beta)); - } else { - raft::sparse::distance::detail::faster_dot_on_csr(res, - csr.get_elements().data(), - compressed_csr_view.get_nnz(), - compressed_csr_view.get_indptr().data(), - compressed_csr_view.get_indices().data(), - queries.data_handle(), - idx.dataset().data_handle(), - compressed_csr_view.get_n_rows(), - dim); - } - - // post process - std::optional> query_norms_; - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded || - metric == raft::distance::DistanceType::CosineExpanded) { - if (metric == raft::distance::DistanceType::CosineExpanded) { - if (!query_norms) { - query_norms_ = make_device_vector(res, n_queries); - raft::linalg::rowNorm((T*)(query_norms_->data_handle()), - queries.data_handle(), - dim, - n_queries, - raft::linalg::L2Norm, - true, - stream, - raft::sqrt_op{}); - } - } else { - if (!query_norms) { - query_norms_ = make_device_vector(res, n_queries); - raft::linalg::rowNorm((T*)(query_norms_->data_handle()), - queries.data_handle(), - dim, - n_queries, - raft::linalg::L2Norm, - true, - stream, - raft::identity_op{}); - } - } - raft::sparse::distance::detail::epilogue_on_csr( - res, - csr.get_elements().data(), - compressed_csr_view.get_nnz(), - rows.data(), - compressed_csr_view.get_indices().data(), - query_norms ? query_norms->data_handle() : query_norms_->data_handle(), - idx.norms().data_handle(), - metric); - } - - // select k - auto const_csr_view = make_device_csr_matrix_view( - csr.get_elements().data(), compressed_csr_view); - std::optional> no_opt = std::nullopt; - bool select_min = raft::distance::is_min_close(metric); - raft::sparse::matrix::select_k( - res, const_csr_view, no_opt, distances, neighbors, select_min, true); - - return; -} - } // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/brute_force_knn_index_float.cu b/cpp/src/neighbors/brute_force_knn_index_float.cu index 4de66b163d..de94be4c09 100644 --- a/cpp/src/neighbors/brute_force_knn_index_float.cu +++ b/cpp/src/neighbors/brute_force_knn_index_float.cu @@ -37,14 +37,6 @@ template void raft::neighbors::brute_force::search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); -template void raft::neighbors::brute_force::search_with_filtering( - raft::resources const& res, - const raft::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, - raft::core::bitmap_view filter, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances); - template void raft::neighbors::brute_force::search( raft::resources const& res, const raft::neighbors::brute_force::index& idx, @@ -60,14 +52,6 @@ template void raft::neighbors::brute_force::search( raft::device_matrix_view neighbors, raft::device_matrix_view distances); -template void raft::neighbors::brute_force::search_with_filtering( - raft::resources const& res, - const raft::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, - raft::core::bitmap_view filter, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances); - template raft::neighbors::brute_force::index raft::neighbors::brute_force:: build::accessor_type>( raft::resources const& res, diff --git a/cpp/test/sparse/neighbors/prefiltered_brute_force.cu b/cpp/test/sparse/neighbors/prefiltered_brute_force.cu deleted file mode 100644 index d08b564c7c..0000000000 --- a/cpp/test/sparse/neighbors/prefiltered_brute_force.cu +++ /dev/null @@ -1,471 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../../neighbors/knn_utils.cuh" -#include "../../test_utils.cuh" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft::neighbors::brute_force { - -template -struct PrefilteredBruteForceInputs { - index_t n_queries; - index_t n_dataset; - index_t dim; - index_t top_k; - float sparsity; - raft::distance::DistanceType metric; - bool select_min = true; -}; - -template -struct CompareApproxWithInf { - CompareApproxWithInf(T eps_) : eps(eps_) {} - bool operator()(const T& a, const T& b) const - { - if (std::isinf(a) && std::isinf(b)) return true; - T diff = std::abs(a - b); - T m = std::max(std::abs(a), std::abs(b)); - T ratio = diff > eps ? diff / m : diff; - - return (ratio <= eps); - } - - private: - T eps; -}; - -template -class PrefilteredBruteForceTest - : public ::testing::TestWithParam> { - public: - PrefilteredBruteForceTest() - : stream(resource::get_cuda_stream(handle)), - params(::testing::TestWithParam>::GetParam()), - filter_d(0, stream), - dataset_d(0, stream), - queries_d(0, stream), - out_val_d(0, stream), - out_val_expected_d(0, stream), - out_idx_d(0, stream), - out_idx_expected_d(0, stream) - { - } - - protected: - index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitmap) - { - index_t total = static_cast(m * n); - index_t num_ones = static_cast((total * 1.0f) * sparsity); - index_t res = num_ones; - - for (auto& item : bitmap) { - item = static_cast(0); - } - - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution dis(0, total - 1); - - while (num_ones > 0) { - index_t index = dis(gen); - - bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))]; - index_t bit_position = index % (8 * sizeof(bitmap_t)); - - if (((element >> bit_position) & 1) == 0) { - element |= (static_cast(1) << bit_position); - num_ones--; - } - } - return res; - } - - void cpu_convert_to_csr(std::vector& bitmap, - index_t rows, - index_t cols, - std::vector& indices, - std::vector& indptr) - { - index_t offset_indptr = 0; - index_t offset_values = 0; - indptr[offset_indptr++] = 0; - - index_t index = 0; - bitmap_t element = 0; - index_t bit_position = 0; - - for (index_t i = 0; i < rows; ++i) { - for (index_t j = 0; j < cols; ++j) { - index = i * cols + j; - element = bitmap[index / (8 * sizeof(bitmap_t))]; - bit_position = index % (8 * sizeof(bitmap_t)); - - if (((element >> bit_position) & 1)) { - indices[offset_values] = static_cast(j); - offset_values++; - } - } - indptr[offset_indptr++] = static_cast(offset_values); - } - } - - void cpu_sddmm(const std::vector& A, - const std::vector& B, - std::vector& vals, - const std::vector& cols, - const std::vector& row_ptrs, - bool is_row_major_A, - bool is_row_major_B, - value_t alpha = 1.0, - value_t beta = 0.0) - { - if (params.n_queries * params.dim != static_cast(A.size()) || - params.dim * params.n_dataset != static_cast(B.size())) { - std::cerr << "Matrix dimensions and vector size do not match!" << std::endl; - return; - } - - bool trans_a = is_row_major_A; - bool trans_b = is_row_major_B; - - for (index_t i = 0; i < params.n_queries; ++i) { - for (index_t j = row_ptrs[i]; j < row_ptrs[i + 1]; ++j) { - value_t sum = 0; - value_t norms_A = 0; - value_t norms_B = 0; - for (index_t l = 0; l < params.dim; ++l) { - index_t a_index = trans_a ? i * params.dim + l : l * params.n_queries + i; - index_t b_index = trans_b ? l * params.n_dataset + cols[j] : cols[j] * params.dim + l; - sum += A[a_index] * B[b_index]; - - norms_A += A[a_index] * A[a_index]; - norms_B += B[b_index] * B[b_index]; - } - vals[j] = alpha * sum + beta * vals[j]; - if (params.metric == raft::distance::DistanceType::L2Expanded) { - vals[j] = value_t(-2.0) * vals[j] + norms_A + norms_B; - } else if (params.metric == raft::distance::DistanceType::L2SqrtExpanded) { - vals[j] = std::sqrt(value_t(-2.0) * vals[j] + norms_A + norms_B); - } else if (params.metric == raft::distance::DistanceType::CosineExpanded) { - vals[j] = value_t(1.0) - vals[j] / std::sqrt(norms_A * norms_B); - } - } - } - } - - void cpu_select_k(const std::vector& indptr_h, - const std::vector& indices_h, - const std::vector& values_h, - std::optional>& in_idx_h, - index_t n_queries, - index_t n_dataset, - index_t top_k, - std::vector& out_values_h, - std::vector& out_indices_h, - bool select_min = true) - { - auto comp = [select_min](const std::pair& a, - const std::pair& b) { - return select_min ? a.first < b.first : a.first >= b.first; - }; - - for (index_t row = 0; row < n_queries; ++row) { - std::priority_queue, - std::vector>, - decltype(comp)> - pq(comp); - - for (index_t idx = indptr_h[row]; idx < indptr_h[row + 1]; ++idx) { - pq.push({values_h[idx], (in_idx_h.has_value()) ? (*in_idx_h)[idx] : indices_h[idx]}); - if (pq.size() > size_t(top_k)) { pq.pop(); } - } - - std::vector> row_pairs; - while (!pq.empty()) { - row_pairs.push_back(pq.top()); - pq.pop(); - } - - if (select_min) { - std::sort(row_pairs.begin(), row_pairs.end(), [](const auto& a, const auto& b) { - return a.first <= b.first; - }); - } else { - std::sort(row_pairs.begin(), row_pairs.end(), [](const auto& a, const auto& b) { - return a.first >= b.first; - }); - } - for (index_t col = 0; col < top_k; col++) { - if (col < index_t(row_pairs.size())) { - out_values_h[row * top_k + col] = row_pairs[col].first; - out_indices_h[row * top_k + col] = row_pairs[col].second; - } - } - } - } - - void random_array(value_t* array, size_t size) - { - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_real_distribution dis(-10.0, 10.0); - std::unordered_set uset; - - while (uset.size() < size) { - uset.insert(dis(gen)); - } - typename std::unordered_set::iterator it = uset.begin(); - for (size_t i = 0; i < size; ++i) { - array[i] = *(it++); - } - } - - void SetUp() override - { - index_t element = - raft::ceildiv(params.n_queries * params.n_dataset, index_t(sizeof(bitmap_t) * 8)); - std::vector filter_h(element); - - nnz = create_sparse_matrix(params.n_queries, params.n_dataset, params.sparsity, filter_h); - - index_t dataset_size = params.n_dataset * params.dim; - index_t queries_size = params.n_queries * params.dim; - - std::vector dataset_h(dataset_size); - std::vector queries_h(queries_size); - - dataset_d.resize(dataset_size, stream); - queries_d.resize(queries_size, stream); - - auto blobs_in_val = - raft::make_device_matrix(handle, 1, dataset_size + queries_size); - auto labels = raft::make_device_vector(handle, 1); - - raft::random::make_blobs(blobs_in_val.data_handle(), - labels.data_handle(), - 1, - dataset_size + queries_size, - 1, - stream, - false, - nullptr, - nullptr, - value_t(1.0), - false, - value_t(-1.0f), - value_t(1.0f), - uint64_t(2024)); - - raft::copy(dataset_h.data(), blobs_in_val.data_handle(), dataset_size, stream); - raft::copy(dataset_d.data(), blobs_in_val.data_handle(), dataset_size, stream); - - raft::copy(queries_h.data(), blobs_in_val.data_handle() + dataset_size, queries_size, stream); - raft::copy(queries_d.data(), blobs_in_val.data_handle() + dataset_size, queries_size, stream); - - resource::sync_stream(handle); - - std::vector values_h(nnz); - std::vector indices_h(nnz); - std::vector indptr_h(params.n_queries + 1); - - filter_d.resize(filter_h.size(), stream); - cpu_convert_to_csr(filter_h, params.n_queries, params.n_dataset, indices_h, indptr_h); - - cpu_sddmm(queries_h, dataset_h, values_h, indices_h, indptr_h, true, false); - - bool select_min = raft::distance::is_min_close(params.metric); - - std::vector out_val_h(params.n_queries * params.top_k, - select_min ? std::numeric_limits::infinity() - : std::numeric_limits::lowest()); - std::vector out_idx_h(params.n_queries * params.top_k, static_cast(0)); - - out_val_d.resize(params.n_queries * params.top_k, stream); - out_idx_d.resize(params.n_queries * params.top_k, stream); - - update_device(out_val_d.data(), out_val_h.data(), out_val_h.size(), stream); - update_device(out_idx_d.data(), out_idx_h.data(), out_idx_h.size(), stream); - update_device(filter_d.data(), filter_h.data(), filter_h.size(), stream); - - resource::sync_stream(handle); - - std::optional> optional_indices_h = std::nullopt; - - cpu_select_k(indptr_h, - indices_h, - values_h, - optional_indices_h, - params.n_queries, - params.n_dataset, - params.top_k, - out_val_h, - out_idx_h, - select_min); - - out_val_expected_d.resize(params.n_queries * params.top_k, stream); - out_idx_expected_d.resize(params.n_queries * params.top_k, stream); - - update_device(out_val_expected_d.data(), out_val_h.data(), out_val_h.size(), stream); - update_device(out_idx_expected_d.data(), out_idx_h.data(), out_idx_h.size(), stream); - - resource::sync_stream(handle); - } - - void Run() - { - auto dataset_raw = raft::make_device_matrix_view( - (const value_t*)dataset_d.data(), params.n_dataset, params.dim); - - auto queries = raft::make_device_matrix_view( - (const value_t*)queries_d.data(), params.n_queries, params.dim); - - brute_force::index_params index_params{}; - index_params.metric = params.metric; - index_params.metric_arg = 0; - - auto dataset = brute_force::build(handle, index_params, dataset_raw); - - auto filter = - raft::core::bitmap_view((const bitmap_t*)filter_d.data(), params.n_queries, params.n_dataset); - - auto out_val = raft::make_device_matrix_view( - out_val_d.data(), params.n_queries, params.top_k); - auto out_idx = raft::make_device_matrix_view( - out_idx_d.data(), params.n_queries, params.top_k); - - brute_force::search_with_filtering(handle, dataset, queries, filter, out_idx, out_val); - - ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(out_idx_expected_d.data(), - out_idx.data_handle(), - out_val_expected_d.data(), - out_val.data_handle(), - params.n_queries, - params.top_k, - 0.001f, - stream, - true)); - } - - protected: - raft::resources handle; - cudaStream_t stream; - - PrefilteredBruteForceInputs params; - - index_t nnz; - - rmm::device_uvector dataset_d; - rmm::device_uvector queries_d; - rmm::device_uvector filter_d; - - rmm::device_uvector out_val_d; - rmm::device_uvector out_val_expected_d; - - rmm::device_uvector out_idx_d; - rmm::device_uvector out_idx_expected_d; -}; - -using PrefilteredBruteForceTest_float_int64 = PrefilteredBruteForceTest; -TEST_P(PrefilteredBruteForceTest_float_int64, Result) { Run(); } - -template -const std::vector> selectk_inputs = { - {1, 100000, 255, 255, 0.4, raft::distance::DistanceType::L2Expanded}, - {10, 100000, 512, 16, 0.5, raft::distance::DistanceType::L2Expanded}, - {20, 100000, 2052, 16, 0.2, raft::distance::DistanceType::L2Expanded}, - {1, 10000, 255, 16, 0.4, raft::distance::DistanceType::InnerProduct}, - {20, 10000, 512, 16, 0.5, raft::distance::DistanceType::InnerProduct}, - {100, 10000, 2052, 16, 0.2, raft::distance::DistanceType::InnerProduct}, - {1000, 10000, 1, 0, 0.1, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 3, 0, 0.1, raft::distance::DistanceType::InnerProduct}, - {1000, 10000, 5, 0, 0.1, raft::distance::DistanceType::L2SqrtExpanded}, - {1000, 10000, 8, 0, 0.1, raft::distance::DistanceType::CosineExpanded}, - - {1000, 10000, 1, 1, 0.1, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 3, 1, 0.1, raft::distance::DistanceType::InnerProduct}, - {1000, 10000, 5, 1, 0.1, raft::distance::DistanceType::L2SqrtExpanded}, - {1000, 10000, 8, 1, 0.1, raft::distance::DistanceType::CosineExpanded}, - - {1000, 10000, 2050, 16, 0.4, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 2051, 16, 0.5, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 2052, 16, 0.2, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 2050, 16, 0.4, raft::distance::DistanceType::InnerProduct}, - {1000, 10000, 2051, 16, 0.5, raft::distance::DistanceType::InnerProduct}, - {1000, 10000, 2052, 16, 0.2, raft::distance::DistanceType::InnerProduct}, - {1000, 10000, 2050, 16, 0.4, raft::distance::DistanceType::L2SqrtExpanded}, - {1000, 10000, 2051, 16, 0.5, raft::distance::DistanceType::L2SqrtExpanded}, - {1000, 10000, 2052, 16, 0.2, raft::distance::DistanceType::L2SqrtExpanded}, - {1000, 10000, 2050, 16, 0.4, raft::distance::DistanceType::CosineExpanded}, - {1000, 10000, 2051, 16, 0.5, raft::distance::DistanceType::CosineExpanded}, - {1000, 10000, 2052, 16, 0.2, raft::distance::DistanceType::CosineExpanded}, - - {1000, 10000, 1, 16, 0.5, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 2, 16, 0.2, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 3, 16, 0.4, raft::distance::DistanceType::InnerProduct}, - {1000, 10000, 4, 16, 0.5, raft::distance::DistanceType::InnerProduct}, - {1000, 10000, 5, 16, 0.2, raft::distance::DistanceType::L2SqrtExpanded}, - {1000, 10000, 8, 16, 0.4, raft::distance::DistanceType::L2SqrtExpanded}, - {1000, 10000, 5, 16, 0.5, raft::distance::DistanceType::CosineExpanded}, - {1000, 10000, 8, 16, 0.2, raft::distance::DistanceType::CosineExpanded}}; - -template -const std::vector> selectk_inputs_extra = - raft::util::itertools::product>( - {index_t(1), index_t(10), index_t(1000)}, // n_queries - {index_t(10 * 1024), index_t(100 * 1024)}, // n_dataset - {index_t(128), index_t(256), index_t(768), index_t(4096)}, // n_dim - {index_t(1), index_t(255), index_t(1024)}, // k - {float(0.0), float(0.2), float(0.01)}, // sparsity - {raft::distance::DistanceType::InnerProduct, - raft::distance::DistanceType::L2Expanded, - raft::distance::DistanceType::L2SqrtExpanded, - raft::distance::DistanceType::CosineExpanded}); - -INSTANTIATE_TEST_CASE_P(PrefilteredBruteForceTest, - PrefilteredBruteForceTest_float_int64, - ::testing::ValuesIn(selectk_inputs)); - -INSTANTIATE_TEST_CASE_P(PrefilteredBruteForceExtraTest, - PrefilteredBruteForceTest_float_int64, - ::testing::ValuesIn(selectk_inputs_extra)); - -} // namespace raft::neighbors::brute_force From 4f1aa171db5ac92133f06238886ec605af87e964 Mon Sep 17 00:00:00 2001 From: hrong Date: Tue, 21 May 2024 12:25:44 -0700 Subject: [PATCH 10/15] Fix CI issue --- cpp/test/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index d97e4c3580..ff0518a4d0 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -339,7 +339,6 @@ if(BUILD_TESTS) sparse/neighbors/cross_component_nn.cu sparse/neighbors/brute_force.cu sparse/neighbors/knn_graph.cu - sparse/neighbors/prefiltered_brute_force.cu LIB EXPLICIT_INSTANTIATE_ONLY ) From 9e24c5a0ac51623c7b774829d291da9f4580ed6a Mon Sep 17 00:00:00 2001 From: hrong Date: Wed, 22 May 2024 13:38:23 -0700 Subject: [PATCH 11/15] Move sparse distance API utils to cuvs and split the bitmap --- cpp/include/raft/core/bitmap.cuh | 114 +++------------- cpp/include/raft/core/bitmap.hpp | 123 ++++++++++++++++++ .../raft/sparse/distance/detail/utils.cuh | 70 +--------- 3 files changed, 139 insertions(+), 168 deletions(-) create mode 100644 cpp/include/raft/core/bitmap.hpp diff --git a/cpp/include/raft/core/bitmap.cuh b/cpp/include/raft/core/bitmap.cuh index 0056cfa5f4..d5f3617c12 100644 --- a/cpp/include/raft/core/bitmap.cuh +++ b/cpp/include/raft/core/bitmap.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include #include @@ -25,105 +25,21 @@ #include namespace raft::core { -/** - * @defgroup bitmap Bitmap - * @{ - */ -/** - * @brief View of a RAFT Bitmap. - * - * This lightweight structure which represents and manipulates a two-dimensional bitmap matrix view - * with row major order. This class provides functionality for handling a matrix where each element - * is represented as a bit in a bitmap. - * - * @tparam bitmap_t Underlying type of the bitmap array. Default is uint32_t. - * @tparam index_t Indexing type used. Default is uint32_t. - */ -template -struct bitmap_view : public bitset_view { - static_assert((std::is_same::type, uint32_t>::value || - std::is_same::type, uint64_t>::value), - "The bitmap_t must be uint32_t or uint64_t."); - /** - * @brief Create a bitmap view from a device raw pointer. - * - * @param bitmap_ptr Device raw pointer - * @param rows Number of row in the matrix. - * @param cols Number of col in the matrix. - */ - _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols) - : bitset_view(bitmap_ptr, rows * cols), rows_(rows), cols_(cols) - { - } - - /** - * @brief Create a bitmap view from a device vector view of the bitset. - * - * @param bitmap_span Device vector view of the bitmap - * @param rows Number of row in the matrix. - * @param cols Number of col in the matrix. - */ - _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, - index_t rows, - index_t cols) - : bitset_view(bitmap_span, rows * cols), rows_(rows), cols_(cols) - { - } - - private: - // Hide the constructors of bitset_view. - _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t bitmap_len) - : bitset_view(bitmap_ptr, bitmap_len) - { - } - - _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, - index_t bitmap_len) - : bitset_view(bitmap_span, bitmap_len) - { - } - - public: - /** - * @brief Device function to test if a given row and col are set in the bitmap. - * - * @param row Row index of the bit to test - * @param col Col index of the bit to test - * @return bool True if index has not been unset in the bitset - */ - inline _RAFT_DEVICE auto test(const index_t row, const index_t col) const -> bool - { - return test(row * cols_ + col); - } - - /** - * @brief Device function to set a given row and col to set_value in the bitset. - * - * @param row Row index of the bit to set - * @param col Col index of the bit to set - * @param new_value Value to set the bit to (true or false) - */ - inline _RAFT_DEVICE void set(const index_t row, const index_t col, bool new_value) const - { - set(row * cols_ + col, &new_value); - } - - /** - * @brief Get the total number of rows - * @return index_t The total number of rows - */ - inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; } - - /** - * @brief Get the total number of columns - * @return index_t The total number of columns - */ - inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; } - private: - index_t rows_; - index_t cols_; -}; +template +_RAFT_HOST_DEVICE inline bool bitmap_view::test(const index_t row, + const index_t col) const +{ + return test(row * cols_ + col); +} + +template +_RAFT_HOST_DEVICE void bitmap_view::set(const index_t row, + const index_t col, + bool new_value) const +{ + set(row * cols_ + col, &new_value); +} /** @} */ } // end namespace raft::core diff --git a/cpp/include/raft/core/bitmap.hpp b/cpp/include/raft/core/bitmap.hpp new file mode 100644 index 0000000000..5c77866164 --- /dev/null +++ b/cpp/include/raft/core/bitmap.hpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace raft::core { +/** + * @defgroup bitmap Bitmap + * @{ + */ +/** + * @brief View of a RAFT Bitmap. + * + * This lightweight structure which represents and manipulates a two-dimensional bitmap matrix view + * with row major order. This class provides functionality for handling a matrix where each element + * is represented as a bit in a bitmap. + * + * @tparam bitmap_t Underlying type of the bitmap array. Default is uint32_t. + * @tparam index_t Indexing type used. Default is uint32_t. + */ +template +struct bitmap_view : public bitset_view { + static_assert((std::is_same::type, uint32_t>::value || + std::is_same::type, uint64_t>::value), + "The bitmap_t must be uint32_t or uint64_t."); + /** + * @brief Create a bitmap view from a device raw pointer. + * + * @param bitmap_ptr Device raw pointer + * @param rows Number of row in the matrix. + * @param cols Number of col in the matrix. + */ + _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols) + : bitset_view(bitmap_ptr, rows * cols), rows_(rows), cols_(cols) + { + } + + /** + * @brief Create a bitmap view from a device vector view of the bitset. + * + * @param bitmap_span Device vector view of the bitmap + * @param rows Number of row in the matrix. + * @param cols Number of col in the matrix. + */ + _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, + index_t rows, + index_t cols) + : bitset_view(bitmap_span, rows * cols), rows_(rows), cols_(cols) + { + } + + private: + // Hide the constructors of bitset_view. + _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t bitmap_len) + : bitset_view(bitmap_ptr, bitmap_len) + { + } + + _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, + index_t bitmap_len) + : bitset_view(bitmap_span, bitmap_len) + { + } + + public: + /** + * @brief Device function to test if a given row and col are set in the bitmap. + * + * @param row Row index of the bit to test + * @param col Col index of the bit to test + * @return bool True if index has not been unset in the bitset + */ + inline _RAFT_HOST_DEVICE bool test(const index_t row, const index_t col) const; + + /** + * @brief Device function to set a given row and col to set_value in the bitset. + * + * @param row Row index of the bit to set + * @param col Col index of the bit to set + * @param new_value Value to set the bit to (true or false) + */ + inline _RAFT_HOST_DEVICE void set(const index_t row, const index_t col, bool new_value) const; + + /** + * @brief Get the total number of rows + * @return index_t The total number of rows + */ + inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; } + + /** + * @brief Get the total number of columns + * @return index_t The total number of columns + */ + inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; } + + private: + index_t rows_; + index_t cols_; +}; + +/** @} */ +} // end namespace raft::core diff --git a/cpp/include/raft/sparse/distance/detail/utils.cuh b/cpp/include/raft/sparse/distance/detail/utils.cuh index 6122641f7d..9799dec0b4 100644 --- a/cpp/include/raft/sparse/distance/detail/utils.cuh +++ b/cpp/include/raft/sparse/distance/detail/utils.cuh @@ -41,75 +41,6 @@ inline int max_cols_per_block() sizeof(value_t); } -template -RAFT_KERNEL epilogue_on_csr_kernel(value_t* __restrict__ compressed_C, - const value_idx* __restrict__ rows, - const value_idx* __restrict__ cols, - const value_t* __restrict__ Q_sq_norms, - const value_t* __restrict__ R_sq_norms, - value_idx nnz, - expansion_f expansion_func) -{ - auto tid = blockDim.x * blockIdx.x + threadIdx.x; - - if (tid >= nnz) return; - const value_idx i = rows[tid]; - const value_idx j = cols[tid]; - - compressed_C[tid] = expansion_func(compressed_C[tid], Q_sq_norms[i], R_sq_norms[j]); -} - -template -void epilogue_on_csr(raft::resources const& handle, - value_t* compressed_C, - const value_idx nnz, - const value_idx* rows, - const value_idx* cols, - const value_t* Q_sq_norms, - const value_t* R_sq_norms, - raft::distance::DistanceType metric) -{ - if (nnz == 0) return; - auto stream = resource::get_cuda_stream(handle); - - int blocks = raft::ceildiv((size_t)nnz, tpb); - if (metric == raft::distance::DistanceType::L2Expanded) { - epilogue_on_csr_kernel<<>>( - compressed_C, - rows, - cols, - Q_sq_norms, - R_sq_norms, - nnz, - [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) -> value_t { - return value_t(-2.0) * dot + q_norm + r_norm; - }); - } else if (metric == raft::distance::DistanceType::L2SqrtExpanded) { - epilogue_on_csr_kernel<<>>( - compressed_C, - rows, - cols, - Q_sq_norms, - R_sq_norms, - nnz, - [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) -> value_t { - return raft::sqrt(value_t(-2.0) * dot + q_norm + r_norm); - }); - } else if (metric == raft::distance::DistanceType::CosineExpanded) { - epilogue_on_csr_kernel<<>>( - compressed_C, - rows, - cols, - Q_sq_norms, - R_sq_norms, - nnz, - [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) -> value_t { - return value_t(1.0) - dot / (q_norm * r_norm); - }); - } - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - template RAFT_KERNEL faster_dot_on_csr_kernel(value_t* __restrict__ dot, const value_idx* __restrict__ indptr, @@ -234,6 +165,7 @@ void faster_dot_on_csr(raft::resources const& handle, RAFT_CUDA_TRY(cudaPeekAtLastError()); } + } // namespace detail } // namespace distance } // namespace sparse From 18cb672eb81a54ba20d54ab50d563b2a252f4e34 Mon Sep 17 00:00:00 2001 From: hrong Date: Wed, 22 May 2024 17:33:39 -0700 Subject: [PATCH 12/15] Optimize by review comments --- cpp/include/raft/core/bitmap.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/include/raft/core/bitmap.cuh b/cpp/include/raft/core/bitmap.cuh index d5f3617c12..30fb9fcfcd 100644 --- a/cpp/include/raft/core/bitmap.cuh +++ b/cpp/include/raft/core/bitmap.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include From 97a0e74aafcbc49014f09cb4fce4f5b1247a3dea Mon Sep 17 00:00:00 2001 From: hrong Date: Thu, 23 May 2024 13:41:06 -0700 Subject: [PATCH 13/15] Remove the sparse select_k instantiations - plus Improve C++ maintainability --- cpp/CMakeLists.txt | 8 ----- .../sparse/convert/detail/bitmap_to_csr.cuh | 4 +-- .../raft/sparse/distance/detail/utils.cuh | 4 --- .../matrix/detail/select_k_double_int64_t.cu | 32 ----------------- .../matrix/detail/select_k_double_uint32_t.cu | 34 ------------------- .../matrix/detail/select_k_float_int32.cu | 32 ----------------- .../matrix/detail/select_k_float_int64_t.cu | 32 ----------------- .../matrix/detail/select_k_float_uint32_t.cu | 32 ----------------- .../matrix/detail/select_k_half_int64_t.cu | 32 ----------------- .../matrix/detail/select_k_half_uint32_t.cu | 32 ----------------- 10 files changed, 2 insertions(+), 240 deletions(-) delete mode 100644 cpp/src/sparse/matrix/detail/select_k_double_int64_t.cu delete mode 100644 cpp/src/sparse/matrix/detail/select_k_double_uint32_t.cu delete mode 100644 cpp/src/sparse/matrix/detail/select_k_float_int32.cu delete mode 100644 cpp/src/sparse/matrix/detail/select_k_float_int64_t.cu delete mode 100644 cpp/src/sparse/matrix/detail/select_k_float_uint32_t.cu delete mode 100644 cpp/src/sparse/matrix/detail/select_k_half_int64_t.cu delete mode 100644 cpp/src/sparse/matrix/detail/select_k_half_uint32_t.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 9a910bda52..39472cae67 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -334,14 +334,6 @@ if(RAFT_COMPILE_LIBRARY) src/matrix/detail/select_k_float_int32.cu src/matrix/detail/select_k_half_int64_t.cu src/matrix/detail/select_k_half_uint32_t.cu - src/sparse/matrix/detail/select_k_half_uint32_t.cu - src/sparse/matrix/detail/select_k_double_int64_t.cu - src/sparse/matrix/detail/select_k_double_uint32_t.cu - src/sparse/matrix/detail/select_k_float_int64_t.cu - src/sparse/matrix/detail/select_k_float_uint32_t.cu - src/sparse/matrix/detail/select_k_float_int32.cu - src/sparse/matrix/detail/select_k_half_int64_t.cu - src/sparse/matrix/detail/select_k_half_uint32_t.cu src/neighbors/ball_cover.cu src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu src/neighbors/brute_force_knn_int64_t_float_int64_t.cu diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh index a9624d891a..8c299954a8 100644 --- a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -68,7 +68,7 @@ RAFT_KERNEL __launch_bounds__(calc_nnz_by_rows_tpb) calc_nnz_by_rows_kernel(cons while (offset < num_cols) { index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; - typename std::remove_const::type l_bitmap = 0; + std::remove_const_t l_bitmap = 0; if (bitmap_idx * BITS_PER_BITMAP < e_bit) { l_bitmap = bitmap[bitmap_idx]; } @@ -177,7 +177,7 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) #pragma unroll for (index_t offset = 0; offset < num_cols; offset += BITS_PER_BITMAP * warpSize) { index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; - typename std::remove_const::type l_bitmap = 0; + std::remove_const_t l_bitmap = 0; index_t l_offset = offset + lane_id * BITS_PER_BITMAP - (s_bit % BITS_PER_BITMAP); if (bitmap_idx * BITS_PER_BITMAP < e_bit) { l_bitmap = bitmap[bitmap_idx]; } diff --git a/cpp/include/raft/sparse/distance/detail/utils.cuh b/cpp/include/raft/sparse/distance/detail/utils.cuh index 9799dec0b4..42b545180b 100644 --- a/cpp/include/raft/sparse/distance/detail/utils.cuh +++ b/cpp/include/raft/sparse/distance/detail/utils.cuh @@ -58,9 +58,7 @@ RAFT_KERNEL faster_dot_on_csr_kernel(value_t* __restrict__ dot, value_t* s_A = (value_t*)smem; value_idx cur_row = -1; -#pragma unroll for (int row = blockIdx.x; row < n_rows; row += gridDim.x) { -#pragma unroll for (int dot_id = blockIdx.y + indptr[row]; dot_id < indptr[row + 1]; dot_id += gridDim.y) { if (dot_id >= nnz) { return; } const value_idx col = cols[dot_id] * dim; @@ -70,7 +68,6 @@ RAFT_KERNEL faster_dot_on_csr_kernel(value_t* __restrict__ dot, __syncthreads(); if (cur_row != row) { -#pragma unroll for (value_idx k = vec_id; k < dim; k += blockDim.x) { s_A[k] = A[row * dim + k]; } @@ -78,7 +75,6 @@ RAFT_KERNEL faster_dot_on_csr_kernel(value_t* __restrict__ dot, } value_t l_dot_ = 0.0; -#pragma unroll for (value_idx k = vec_id; k < dim; k += blockDim.x) { asm("prefetch.global.L2 [%0];" ::"l"(B_col + k + blockDim.x)); l_dot_ += s_A[k] * __ldcg(B_col + k); diff --git a/cpp/src/sparse/matrix/detail/select_k_double_int64_t.cu b/cpp/src/sparse/matrix/detail/select_k_double_int64_t.cu deleted file mode 100644 index c784b50dad..0000000000 --- a/cpp/src/sparse/matrix/detail/select_k_double_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - template void raft::sparse::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_sparse_matrix_detail_select_k(double, int64_t); - -#undef instantiate_raft_sparse_matrix_detail_select_k \ No newline at end of file diff --git a/cpp/src/sparse/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/sparse/matrix/detail/select_k_double_uint32_t.cu deleted file mode 100644 index 98bab9a504..0000000000 --- a/cpp/src/sparse/matrix/detail/select_k_double_uint32_t.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include // uint32_t - -#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - template void raft::sparse::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_sparse_matrix_detail_select_k(double, uint32_t); - -#undef instantiate_raft_sparse_matrix_detail_select_k \ No newline at end of file diff --git a/cpp/src/sparse/matrix/detail/select_k_float_int32.cu b/cpp/src/sparse/matrix/detail/select_k_float_int32.cu deleted file mode 100644 index 49bec86e6e..0000000000 --- a/cpp/src/sparse/matrix/detail/select_k_float_int32.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - template void raft::sparse::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_sparse_matrix_detail_select_k(float, int); - -#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/src/sparse/matrix/detail/select_k_float_int64_t.cu b/cpp/src/sparse/matrix/detail/select_k_float_int64_t.cu deleted file mode 100644 index 412b06e587..0000000000 --- a/cpp/src/sparse/matrix/detail/select_k_float_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - template void raft::sparse::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_sparse_matrix_detail_select_k(float, int64_t); - -#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/src/sparse/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/sparse/matrix/detail/select_k_float_uint32_t.cu deleted file mode 100644 index 8ba3f0e22b..0000000000 --- a/cpp/src/sparse/matrix/detail/select_k_float_uint32_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - template void raft::sparse::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_sparse_matrix_detail_select_k(float, uint32_t); - -#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/src/sparse/matrix/detail/select_k_half_int64_t.cu b/cpp/src/sparse/matrix/detail/select_k_half_int64_t.cu deleted file mode 100644 index 24c844f8c8..0000000000 --- a/cpp/src/sparse/matrix/detail/select_k_half_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - template void raft::sparse::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_sparse_matrix_detail_select_k(__half, int64_t); - -#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/src/sparse/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/sparse/matrix/detail/select_k_half_uint32_t.cu deleted file mode 100644 index d63dc64933..0000000000 --- a/cpp/src/sparse/matrix/detail/select_k_half_uint32_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - template void raft::sparse::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_sparse_matrix_detail_select_k(__half, uint32_t); - -#undef instantiate_raft_sparse_matrix_detail_select_k From 7d08443a6717f48d08df4567e3863b212f04a635 Mon Sep 17 00:00:00 2001 From: hrong Date: Thu, 23 May 2024 14:23:10 -0700 Subject: [PATCH 14/15] Fix CI issue --- cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh index 8c299954a8..b1b0291a85 100644 --- a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -67,8 +67,8 @@ RAFT_KERNEL __launch_bounds__(calc_nnz_by_rows_tpb) calc_nnz_by_rows_kernel(cons index_t l_sum = 0; while (offset < num_cols) { - index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; - std::remove_const_t l_bitmap = 0; + index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; + std::remove_const_t l_bitmap = 0; if (bitmap_idx * BITS_PER_BITMAP < e_bit) { l_bitmap = bitmap[bitmap_idx]; } @@ -176,8 +176,8 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) #pragma unroll for (index_t offset = 0; offset < num_cols; offset += BITS_PER_BITMAP * warpSize) { - index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; - std::remove_const_t l_bitmap = 0; + index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; + std::remove_const_t l_bitmap = 0; index_t l_offset = offset + lane_id * BITS_PER_BITMAP - (s_bit % BITS_PER_BITMAP); if (bitmap_idx * BITS_PER_BITMAP < e_bit) { l_bitmap = bitmap[bitmap_idx]; } From f38642f725f5cfc0a74622fbef5eb9563efd21a3 Mon Sep 17 00:00:00 2001 From: hrong Date: Thu, 23 May 2024 18:26:47 -0700 Subject: [PATCH 15/15] Fix docs issue. --- cpp/include/raft/core/bitmap.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/raft/core/bitmap.cuh b/cpp/include/raft/core/bitmap.cuh index 30fb9fcfcd..2c23a77e47 100644 --- a/cpp/include/raft/core/bitmap.cuh +++ b/cpp/include/raft/core/bitmap.cuh @@ -42,5 +42,4 @@ _RAFT_HOST_DEVICE void bitmap_view::set(const index_t row, set(row * cols_ + col, &new_value); } -/** @} */ } // end namespace raft::core