diff --git a/cpp/include/cuvs/core/bitmap.hpp b/cpp/include/cuvs/core/bitmap.hpp new file mode 100644 index 000000000..80ae25cd2 --- /dev/null +++ b/cpp/include/cuvs/core/bitmap.hpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace cuvs::core { +/* To use bitmap functions containing CUDA code, include */ + +template +using bitmap_view = raft::core::bitmap_view; + +} // end namespace cuvs::core diff --git a/cpp/include/cuvs/neighbors/brute_force.hpp b/cpp/include/cuvs/neighbors/brute_force.hpp index 755c8cfdb..13a5ea0cb 100644 --- a/cpp/include/cuvs/neighbors/brute_force.hpp +++ b/cpp/include/cuvs/neighbors/brute_force.hpp @@ -191,12 +191,15 @@ auto build(raft::resources const& handle, * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] sample_filter a optional device bitmap filter function that greenlights samples for a + * given */ void search(raft::resources const& handle, const cuvs::neighbors::brute_force::index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances); + raft::device_matrix_view distances, + std::optional> sample_filter); /** * @} */ diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 45fa1a107..72d35961f 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -25,6 +25,7 @@ #include // get_device_for_address #include // rounding up +#include #include #include diff --git a/cpp/src/neighbors/brute_force.cu b/cpp/src/neighbors/brute_force.cu index 45d4be4a7..13554c0b5 100644 --- a/cpp/src/neighbors/brute_force.cu +++ b/cpp/src/neighbors/brute_force.cu @@ -15,6 +15,7 @@ */ #include "./detail/knn_brute_force.cuh" + #include #include @@ -84,25 +85,32 @@ void index::update_dataset(raft::resources const& res, dataset_view_ = raft::make_const_mdspan(dataset_.view()); } -#define CUVS_INST_BFKNN(T) \ - auto build(raft::resources const& res, \ - raft::device_matrix_view dataset, \ - cuvs::distance::DistanceType metric, \ - T metric_arg) \ - ->cuvs::neighbors::brute_force::index \ - { \ - return detail::build(res, dataset, metric, metric_arg); \ - } \ - \ - void search(raft::resources const& res, \ - const cuvs::neighbors::brute_force::index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - detail::brute_force_search(res, idx, queries, neighbors, distances); \ - } \ - \ +#define CUVS_INST_BFKNN(T) \ + auto build(raft::resources const& res, \ + raft::device_matrix_view dataset, \ + cuvs::distance::DistanceType metric, \ + T metric_arg) \ + ->cuvs::neighbors::brute_force::index \ + { \ + return detail::build(res, dataset, metric, metric_arg); \ + } \ + \ + void search( \ + raft::resources const& res, \ + const cuvs::neighbors::brute_force::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + std::optional> sample_filter = std::nullopt) \ + { \ + if (!sample_filter.has_value()) { \ + detail::brute_force_search(res, idx, queries, neighbors, distances); \ + } else { \ + detail::brute_force_search_filtered( \ + res, idx, queries, *sample_filter, neighbors, distances); \ + } \ + } \ + \ template struct cuvs::neighbors::brute_force::index; CUVS_INST_BFKNN(float); diff --git a/cpp/src/neighbors/brute_force_c.cpp b/cpp/src/neighbors/brute_force_c.cpp index e988ac2f0..5f04ffa34 100644 --- a/cpp/src/neighbors/brute_force_c.cpp +++ b/cpp/src/neighbors/brute_force_c.cpp @@ -66,7 +66,7 @@ void _search(cuvsResources_t res, auto distances_mds = cuvs::core::from_dlpack(distances_tensor); cuvs::neighbors::brute_force::search( - *res_ptr, *index_ptr, queries_mds, neighbors_mds, distances_mds); + *res_ptr, *index_ptr, queries_mds, neighbors_mds, distances_mds, std::nullopt); } } // namespace diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index 29cd26d9f..4865ade77 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -15,6 +15,7 @@ */ #pragma once + #include #include @@ -23,16 +24,26 @@ #include "./fused_l2_knn.cuh" #include "./haversine_distance.cuh" #include "./knn_merge_parts.cuh" +#include "./knn_utils.cuh" +#include +#include +#include #include #include #include #include #include #include +#include #include #include #include +#include +#include +#include +#include +#include #include #include @@ -65,7 +76,8 @@ void tiled_brute_force_knn(const raft::resources& handle, size_t max_row_tile_size = 0, size_t max_col_tile_size = 0, const ElementType* precomputed_index_norms = nullptr, - const ElementType* precomputed_search_norms = nullptr) + const ElementType* precomputed_search_norms = nullptr, + const uint32_t* filter_bitmap = nullptr) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -214,6 +226,27 @@ void tiled_brute_force_knn(const raft::resources& handle, }); } + if (filter_bitmap != nullptr) { + auto distances_ptr = temp_distances.data(); + auto count = thrust::make_counting_iterator(0); + ElementType masked_distance = select_min ? std::numeric_limits::infinity() + : std::numeric_limits::lowest(); + thrust::for_each(raft::resource::get_thrust_policy(handle), + count, + count + current_query_size * current_centroid_size, + [=] __device__(IndexType idx) { + IndexType row = i + (idx / current_centroid_size); + IndexType col = j + (idx % current_centroid_size); + IndexType g_idx = row * n + col; + IndexType item_idx = (g_idx) >> 5; + uint32_t bit_idx = (g_idx)&31; + uint32_t filter = filter_bitmap[item_idx]; + if ((filter & (uint32_t(1) << bit_idx)) == 0) { + distances_ptr[idx] = masked_distance; + } + }); + } + raft::matrix::select_k( handle, raft::make_device_matrix_view( @@ -519,6 +552,173 @@ void brute_force_search( query_norms ? query_norms->data_handle() : nullptr); } +template +void brute_force_search_filtered( + raft::resources const& res, + const cuvs::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + cuvs::core::bitmap_view filter, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + std::optional> query_norms = std::nullopt) +{ + auto metric = idx.metric(); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); + RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), + "Number of columns in queries must match brute force index"); + RAFT_EXPECTS(metric == cuvs::distance::DistanceType::InnerProduct || + metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded, + "Only Euclidean, IP, and Cosine are supported!"); + + RAFT_EXPECTS(idx.has_norms() || !(metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded), + "Index must has norms when using Euclidean, IP, and Cosine!"); + + IdxT n_queries = queries.extent(0); + IdxT n_dataset = idx.dataset().extent(0); + IdxT dim = idx.dataset().extent(1); + IdxT k = neighbors.extent(1); + + auto stream = raft::resource::get_cuda_stream(res); + + // calc nnz + IdxT nnz_h = 0; + rmm::device_scalar nnz(0, stream); + auto nnz_view = raft::make_device_scalar_view(nnz.data()); + auto filter_view = + raft::make_device_vector_view(filter.data(), filter.n_elements()); + + // TODO(rhdong): Need to switch to the public API, + // with the issue: https://github.com/rapidsai/cuvs/issues/158 + raft::detail::popc(res, filter_view, n_queries * n_dataset, nnz_view); + raft::copy(&nnz_h, nnz.data(), 1, stream); + + raft::resource::sync_stream(res, stream); + float sparsity = (1.0f * nnz_h / (1.0f * n_queries * n_dataset)); + + if (sparsity > 0.01f) { + raft::resources stream_pool_handle(res); + raft::resource::set_cuda_stream(stream_pool_handle, stream); + auto idx_norm = idx.has_norms() ? const_cast(idx.norms().data_handle()) : nullptr; + + tiled_brute_force_knn(stream_pool_handle, + queries.data_handle(), + idx.dataset().data_handle(), + n_queries, + n_dataset, + dim, + k, + distances.data_handle(), + neighbors.data_handle(), + metric, + 2.0, + 0, + 0, + idx_norm, + nullptr, + filter.data()); + } else { + auto csr = raft::make_device_csr_matrix(res, n_queries, n_dataset, nnz_h); + + // fill csr + raft::sparse::convert::bitmap_to_csr(res, filter, csr); + + // create filter csr view + auto compressed_csr_view = csr.structure_view(); + rmm::device_uvector rows(compressed_csr_view.get_nnz(), stream); + raft::sparse::convert::csr_to_coo(compressed_csr_view.get_indptr().data(), + compressed_csr_view.get_n_rows(), + rows.data(), + compressed_csr_view.get_nnz(), + stream); + if (n_queries > 10) { + auto csr_view = raft::make_device_csr_matrix_view( + csr.get_elements().data(), compressed_csr_view); + + // create dataset view + auto dataset_view = raft::make_device_matrix_view( + idx.dataset().data_handle(), dim, n_dataset); + + // calc dot + T alpha = static_cast(1.0f); + T beta = static_cast(0.0f); + raft::sparse::linalg::sddmm(res, + queries, + dataset_view, + csr_view, + raft::linalg::Operation::NON_TRANSPOSE, + raft::linalg::Operation::NON_TRANSPOSE, + raft::make_host_scalar_view(&alpha), + raft::make_host_scalar_view(&beta)); + } else { + raft::sparse::distance::detail::faster_dot_on_csr(res, + csr.get_elements().data(), + compressed_csr_view.get_nnz(), + compressed_csr_view.get_indptr().data(), + compressed_csr_view.get_indices().data(), + queries.data_handle(), + idx.dataset().data_handle(), + compressed_csr_view.get_n_rows(), + dim); + } + + // post process + std::optional> query_norms_; + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded) { + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + if (!query_norms) { + query_norms_ = raft::make_device_vector(res, n_queries); + raft::linalg::rowNorm((T*)(query_norms_->data_handle()), + queries.data_handle(), + dim, + n_queries, + raft::linalg::L2Norm, + true, + stream, + raft::sqrt_op{}); + } + } else { + if (!query_norms) { + query_norms_ = raft::make_device_vector(res, n_queries); + raft::linalg::rowNorm((T*)(query_norms_->data_handle()), + queries.data_handle(), + dim, + n_queries, + raft::linalg::L2Norm, + true, + stream, + raft::identity_op{}); + } + } + cuvs::neighbors::detail::epilogue_on_csr( + res, + csr.get_elements().data(), + compressed_csr_view.get_nnz(), + rows.data(), + compressed_csr_view.get_indices().data(), + query_norms ? query_norms->data_handle() : query_norms_->data_handle(), + idx.norms().data_handle(), + metric); + } + + // select k + auto const_csr_view = raft::make_device_csr_matrix_view( + csr.get_elements().data(), compressed_csr_view); + std::optional> no_opt = std::nullopt; + bool select_min = cuvs::distance::is_min_close(metric); + raft::sparse::matrix::select_k( + res, const_csr_view, no_opt, distances, neighbors, select_min, true); + } + + return; +} + template cuvs::neighbors::brute_force::index build( raft::resources const& res, diff --git a/cpp/src/neighbors/detail/knn_utils.cuh b/cpp/src/neighbors/detail/knn_utils.cuh new file mode 100644 index 000000000..1cc709fa4 --- /dev/null +++ b/cpp/src/neighbors/detail/knn_utils.cuh @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +#include +#include + +namespace cuvs::neighbors::detail { + +template +RAFT_KERNEL epilogue_on_csr_kernel(value_t* __restrict__ compressed_C, + const value_idx* __restrict__ rows, + const value_idx* __restrict__ cols, + const value_t* __restrict__ Q_sq_norms, + const value_t* __restrict__ R_sq_norms, + value_idx nnz, + expansion_f expansion_func) +{ + auto tid = blockDim.x * blockIdx.x + threadIdx.x; + + if (tid >= nnz) return; + const value_idx i = rows[tid]; + const value_idx j = cols[tid]; + + compressed_C[tid] = expansion_func(compressed_C[tid], Q_sq_norms[i], R_sq_norms[j]); +} + +template +void epilogue_on_csr(raft::resources const& handle, + value_t* compressed_C, + const value_idx nnz, + const value_idx* rows, + const value_idx* cols, + const value_t* Q_sq_norms, + const value_t* R_sq_norms, + cuvs::distance::DistanceType metric) +{ + if (nnz == 0) return; + auto stream = raft::resource::get_cuda_stream(handle); + + int blocks = raft::ceildiv((size_t)nnz, tpb); + if (metric == cuvs::distance::DistanceType::L2Expanded) { + epilogue_on_csr_kernel<<>>( + compressed_C, + rows, + cols, + Q_sq_norms, + R_sq_norms, + nnz, + [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) -> value_t { + return value_t(-2.0) * dot + q_norm + r_norm; + }); + } else if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + epilogue_on_csr_kernel<<>>( + compressed_C, + rows, + cols, + Q_sq_norms, + R_sq_norms, + nnz, + [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) -> value_t { + return raft::sqrt(value_t(-2.0) * dot + q_norm + r_norm); + }); + } else if (metric == cuvs::distance::DistanceType::CosineExpanded) { + epilogue_on_csr_kernel<<>>( + compressed_C, + rows, + cols, + Q_sq_norms, + R_sq_norms, + nnz, + [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) -> value_t { + return value_t(1.0) - dot / (q_norm * r_norm); + }); + } + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} +} // namespace cuvs::neighbors::detail diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 3fbf62cdb..1fae2f70b 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -91,7 +91,15 @@ endfunction() if(BUILD_TESTS) ConfigureTest( - NAME NEIGHBORS_TEST PATH test/neighbors/brute_force.cu test/neighbors/refine.cu GPUS 1 PERCENT + NAME + NEIGHBORS_TEST + PATH + test/neighbors/brute_force.cu + test/neighbors/brute_force_prefiltered.cu + test/neighbors/refine.cu + GPUS + 1 + PERCENT 100 ) diff --git a/cpp/test/neighbors/brute_force.cu b/cpp/test/neighbors/brute_force.cu index fdb180186..081a2966e 100644 --- a/cpp/test/neighbors/brute_force.cu +++ b/cpp/test/neighbors/brute_force.cu @@ -82,7 +82,7 @@ class KNNTest : public ::testing::TestWithParam { auto metric = cuvs::distance::DistanceType::L2Unexpanded; auto idx = cuvs::neighbors::brute_force::build(handle, index, metric); - cuvs::neighbors::brute_force::search(handle, idx, search, indices, distances); + cuvs::neighbors::brute_force::search(handle, idx, search, indices, distances, std::nullopt); build_actual_output<<>>( actual_labels_.data(), rows_, k_, search_labels_.data(), indices_.data()); diff --git a/cpp/test/neighbors/brute_force_prefiltered.cu b/cpp/test/neighbors/brute_force_prefiltered.cu new file mode 100644 index 000000000..17166fd7a --- /dev/null +++ b/cpp/test/neighbors/brute_force_prefiltered.cu @@ -0,0 +1,524 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "knn_utils.cuh" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::brute_force { + +template +struct PrefilteredBruteForceInputs { + index_t n_queries; + index_t n_dataset; + index_t dim; + index_t top_k; + float sparsity; + cuvs::distance::DistanceType metric; + bool select_min = true; +}; + +template +struct CompareApproxWithInf { + CompareApproxWithInf(T eps_) : eps(eps_) {} + bool operator()(const T& a, const T& b) const + { + if (std::isinf(a) && std::isinf(b)) return true; + T diff = std::abs(a - b); + T m = std::max(std::abs(a), std::abs(b)); + T ratio = diff > eps ? diff / m : diff; + + return (ratio <= eps); + } + + private: + T eps; +}; + +template +RAFT_KERNEL normalize_kernel( + OutT* theta, const InT* in_vals, size_t max_scale, size_t r_scale, size_t c_scale) +{ + size_t idx = threadIdx.x; + if (idx < max_scale) { + auto a = OutT(in_vals[4 * idx]); + auto b = OutT(in_vals[4 * idx + 1]); + auto c = OutT(in_vals[4 * idx + 2]); + auto d = OutT(in_vals[4 * idx + 3]); + auto sum = a + b + c + d; + a /= sum; + b /= sum; + c /= sum; + d /= sum; + theta[4 * idx] = a; + theta[4 * idx + 1] = b; + theta[4 * idx + 2] = c; + theta[4 * idx + 3] = d; + } +} + +template +void normalize(OutT* theta, + const InT* in_vals, + size_t max_scale, + size_t r_scale, + size_t c_scale, + bool handle_rect, + bool theta_array, + cudaStream_t stream) +{ + normalize_kernel<<<1, 256, 0, stream>>>(theta, in_vals, max_scale, r_scale, c_scale); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +template +RAFT_KERNEL set_bitmap_kernel( + const index_t* src, const index_t* dst, bitmap_t* bitmap, index_t n_edges, index_t n_cols) +{ + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + if (idx < n_edges) { + index_t row = src[idx]; + index_t col = dst[idx]; + index_t g_idx = row * n_cols + col; + index_t item_idx = (g_idx) >> 5; + uint32_t bit_idx = (g_idx)&31; + atomicOr(bitmap + item_idx, (uint32_t(1) << bit_idx)); + } +} + +template +void set_bitmap(const index_t* src, + const index_t* dst, + bitmap_t* bitmap, + index_t n_edges, + index_t n_cols, + cudaStream_t stream) +{ + int block_size = 256; + int blocks = raft::ceildiv(n_edges, block_size); + set_bitmap_kernel + <<>>(src, dst, bitmap, n_edges, n_cols); + RAFT_CUDA_TRY(cudaGetLastError()); +} +template +class PrefilteredBruteForceTest + : public ::testing::TestWithParam> { + public: + PrefilteredBruteForceTest() + : stream(raft::resource::get_cuda_stream(handle)), + params(::testing::TestWithParam>::GetParam()), + filter_d(0, stream), + dataset_d(0, stream), + queries_d(0, stream), + out_val_d(0, stream), + out_val_expected_d(0, stream), + out_idx_d(0, stream), + out_idx_expected_d(0, stream) + { + } + + protected: + index_t create_sparse_matrix_with_rmat(index_t m, + index_t n, + value_t sparsity, + rmm::device_uvector& filter_d) + { + index_t r_scale = (index_t)std::log2(m); + index_t c_scale = (index_t)std::log2(n); + index_t n_edges = (index_t)(m * n * 1.0 * sparsity); + index_t max_scale = std::max(r_scale, c_scale); + + rmm::device_uvector out_src{(unsigned long)n_edges, stream}; + rmm::device_uvector out_dst{(unsigned long)n_edges, stream}; + rmm::device_uvector theta{(unsigned long)(4 * max_scale), stream}; + + raft::random::RngState state{2024ULL, raft::random::GeneratorType::GenPC}; + + raft::random::uniform(handle, state, theta.data(), theta.size(), 0.0f, 1.0f); + normalize( + theta.data(), theta.data(), max_scale, r_scale, c_scale, r_scale != c_scale, true, stream); + raft::random::rmat_rectangular_gen((index_t*)nullptr, + out_src.data(), + out_dst.data(), + theta.data(), + r_scale, + c_scale, + n_edges, + stream, + state); + + index_t nnz_h = 0; + { + auto src = out_src.data(); + auto dst = out_dst.data(); + auto bitmap = filter_d.data(); + rmm::device_scalar nnz(0, stream); + auto nnz_view = raft::make_device_scalar_view(nnz.data()); + auto filter_view = + raft::make_device_vector_view(filter_d.data(), filter_d.size()); + + set_bitmap(src, dst, bitmap, n_edges, n, stream); + + // TODO(rhdong): Need to switch to the public API, + // with the issue: https://github.com/rapidsai/cuvs/issues/158 + raft::detail::popc(handle, filter_view, m * n, nnz_view); + raft::copy(&nnz_h, nnz.data(), 1, stream); + + raft::resource::sync_stream(handle, stream); + } + + return nnz_h; + } + + void cpu_convert_to_csr(std::vector& bitmap, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + index_t index = 0; + bitmap_t element = 0; + index_t bit_position = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + index = i * cols + j; + element = bitmap[index / (8 * sizeof(bitmap_t))]; + bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1)) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + void cpu_sddmm(const std::vector& A, + const std::vector& B, + std::vector& vals, + const std::vector& cols, + const std::vector& row_ptrs, + bool is_row_major_A, + bool is_row_major_B, + value_t alpha = 1.0, + value_t beta = 0.0) + { + if (params.n_queries * params.dim != static_cast(A.size()) || + params.dim * params.n_dataset != static_cast(B.size())) { + std::cerr << "Matrix dimensions and vector size do not match!" << std::endl; + return; + } + + bool trans_a = is_row_major_A; + bool trans_b = is_row_major_B; + + for (index_t i = 0; i < params.n_queries; ++i) { + for (index_t j = row_ptrs[i]; j < row_ptrs[i + 1]; ++j) { + value_t sum = 0; + value_t norms_A = 0; + value_t norms_B = 0; + for (index_t l = 0; l < params.dim; ++l) { + index_t a_index = trans_a ? i * params.dim + l : l * params.n_queries + i; + index_t b_index = trans_b ? l * params.n_dataset + cols[j] : cols[j] * params.dim + l; + sum += A[a_index] * B[b_index]; + + norms_A += A[a_index] * A[a_index]; + norms_B += B[b_index] * B[b_index]; + } + vals[j] = alpha * sum + beta * vals[j]; + if (params.metric == cuvs::distance::DistanceType::L2Expanded) { + vals[j] = value_t(-2.0) * vals[j] + norms_A + norms_B; + } else if (params.metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + vals[j] = std::sqrt(value_t(-2.0) * vals[j] + norms_A + norms_B); + } else if (params.metric == cuvs::distance::DistanceType::CosineExpanded) { + vals[j] = value_t(1.0) - vals[j] / std::sqrt(norms_A * norms_B); + } + } + } + } + + void cpu_select_k(const std::vector& indptr_h, + const std::vector& indices_h, + const std::vector& values_h, + std::optional>& in_idx_h, + index_t n_queries, + index_t n_dataset, + index_t top_k, + std::vector& out_values_h, + std::vector& out_indices_h, + bool select_min = true) + { + auto comp = [select_min](const std::pair& a, + const std::pair& b) { + return select_min ? a.first < b.first : a.first >= b.first; + }; + + for (index_t row = 0; row < n_queries; ++row) { + std::priority_queue, + std::vector>, + decltype(comp)> + pq(comp); + + for (index_t idx = indptr_h[row]; idx < indptr_h[row + 1]; ++idx) { + pq.push({values_h[idx], (in_idx_h.has_value()) ? (*in_idx_h)[idx] : indices_h[idx]}); + if (pq.size() > size_t(top_k)) { pq.pop(); } + } + + std::vector> row_pairs; + while (!pq.empty()) { + row_pairs.push_back(pq.top()); + pq.pop(); + } + + if (select_min) { + std::sort(row_pairs.begin(), row_pairs.end(), [](const auto& a, const auto& b) { + return a.first <= b.first; + }); + } else { + std::sort(row_pairs.begin(), row_pairs.end(), [](const auto& a, const auto& b) { + return a.first >= b.first; + }); + } + for (index_t col = 0; col < top_k; col++) { + if (col < index_t(row_pairs.size())) { + out_values_h[row * top_k + col] = row_pairs[col].first; + out_indices_h[row * top_k + col] = row_pairs[col].second; + } + } + } + } + + void SetUp() override + { + index_t element = + raft::ceildiv(params.n_queries * params.n_dataset, index_t(sizeof(bitmap_t) * 8)); + std::vector filter_h(element); + filter_d.resize(element, stream); + + nnz = + create_sparse_matrix_with_rmat(params.n_queries, params.n_dataset, params.sparsity, filter_d); + + raft::update_host(filter_h.data(), filter_d.data(), filter_d.size(), stream); + raft::resource::sync_stream(handle, stream); + + index_t dataset_size = params.n_dataset * params.dim; + index_t queries_size = params.n_queries * params.dim; + + std::vector dataset_h(dataset_size); + std::vector queries_h(queries_size); + + dataset_d.resize(dataset_size, stream); + queries_d.resize(queries_size, stream); + + auto blobs_in_val = + raft::make_device_matrix(handle, 1, dataset_size + queries_size); + auto labels = raft::make_device_vector(handle, 1); + + raft::random::make_blobs(blobs_in_val.data_handle(), + labels.data_handle(), + 1, + dataset_size + queries_size, + 1, + stream, + false, + nullptr, + nullptr, + value_t(1.0), + false, + value_t(-1.0f), + value_t(1.0f), + uint64_t(2024)); + + raft::copy(dataset_h.data(), blobs_in_val.data_handle(), dataset_size, stream); + raft::copy(dataset_d.data(), blobs_in_val.data_handle(), dataset_size, stream); + + raft::copy(queries_h.data(), blobs_in_val.data_handle() + dataset_size, queries_size, stream); + raft::copy(queries_d.data(), blobs_in_val.data_handle() + dataset_size, queries_size, stream); + + raft::resource::sync_stream(handle); + + std::vector values_h(nnz); + std::vector indices_h(nnz); + std::vector indptr_h(params.n_queries + 1); + + cpu_convert_to_csr(filter_h, params.n_queries, params.n_dataset, indices_h, indptr_h); + + cpu_sddmm(queries_h, dataset_h, values_h, indices_h, indptr_h, true, false); + + bool select_min = cuvs::distance::is_min_close(params.metric); + + std::vector out_val_h(params.n_queries * params.top_k, + select_min ? std::numeric_limits::infinity() + : std::numeric_limits::lowest()); + std::vector out_idx_h(params.n_queries * params.top_k, static_cast(0)); + + out_val_d.resize(params.n_queries * params.top_k, stream); + out_idx_d.resize(params.n_queries * params.top_k, stream); + + raft::update_device(out_val_d.data(), out_val_h.data(), out_val_h.size(), stream); + raft::update_device(out_idx_d.data(), out_idx_h.data(), out_idx_h.size(), stream); + + raft::resource::sync_stream(handle); + + std::optional> optional_indices_h = std::nullopt; + + cpu_select_k(indptr_h, + indices_h, + values_h, + optional_indices_h, + params.n_queries, + params.n_dataset, + params.top_k, + out_val_h, + out_idx_h, + select_min); + + out_val_expected_d.resize(params.n_queries * params.top_k, stream); + out_idx_expected_d.resize(params.n_queries * params.top_k, stream); + + raft::update_device(out_val_expected_d.data(), out_val_h.data(), out_val_h.size(), stream); + raft::update_device(out_idx_expected_d.data(), out_idx_h.data(), out_idx_h.size(), stream); + + raft::resource::sync_stream(handle); + } + + void Run() + { + auto dataset_raw = raft::make_device_matrix_view( + (const value_t*)dataset_d.data(), params.n_dataset, params.dim); + + auto queries = raft::make_device_matrix_view( + (const value_t*)queries_d.data(), params.n_queries, params.dim); + + auto dataset = brute_force::build(handle, dataset_raw, params.metric); + + auto filter = cuvs::core::bitmap_view( + (const bitmap_t*)filter_d.data(), params.n_queries, params.n_dataset); + + auto out_val = raft::make_device_matrix_view( + out_val_d.data(), params.n_queries, params.top_k); + auto out_idx = raft::make_device_matrix_view( + out_idx_d.data(), params.n_queries, params.top_k); + + brute_force::search(handle, dataset, queries, out_idx, out_val, std::make_optional(filter)); + + ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(out_idx_expected_d.data(), + out_idx.data_handle(), + out_val_expected_d.data(), + out_val.data_handle(), + params.n_queries, + params.top_k, + 0.001f, + stream, + true)); + } + + protected: + raft::resources handle; + cudaStream_t stream; + + PrefilteredBruteForceInputs params; + + index_t nnz; + + rmm::device_uvector dataset_d; + rmm::device_uvector queries_d; + rmm::device_uvector filter_d; + + rmm::device_uvector out_val_d; + rmm::device_uvector out_val_expected_d; + + rmm::device_uvector out_idx_d; + rmm::device_uvector out_idx_expected_d; +}; + +using PrefilteredBruteForceTest_float_int64 = PrefilteredBruteForceTest; +TEST_P(PrefilteredBruteForceTest_float_int64, Result) { Run(); } + +template +const std::vector> selectk_inputs = { + {2, 131072, 255, 255, 0.4, cuvs::distance::DistanceType::L2Expanded}, + {8, 131072, 512, 16, 0.5, cuvs::distance::DistanceType::L2Expanded}, + {16, 131072, 2052, 16, 0.2, cuvs::distance::DistanceType::L2Expanded}, + {2, 8192, 255, 16, 0.4, cuvs::distance::DistanceType::InnerProduct}, + {16, 8192, 512, 16, 0.5, cuvs::distance::DistanceType::InnerProduct}, + {128, 8192, 2052, 16, 0.2, cuvs::distance::DistanceType::InnerProduct}, + {1024, 8192, 1, 0, 0.1, cuvs::distance::DistanceType::L2Expanded}, + {1024, 8192, 3, 0, 0.1, cuvs::distance::DistanceType::InnerProduct}, + {1024, 8192, 5, 0, 0.1, cuvs::distance::DistanceType::L2SqrtExpanded}, + {1024, 8192, 8, 0, 0.1, cuvs::distance::DistanceType::CosineExpanded}, + + {1024, 8192, 1, 1, 0.1, cuvs::distance::DistanceType::L2Expanded}, + {1024, 8192, 3, 1, 0.1, cuvs::distance::DistanceType::InnerProduct}, + {1024, 8192, 5, 1, 0.1, cuvs::distance::DistanceType::L2SqrtExpanded}, + {1024, 8192, 8, 1, 0.1, cuvs::distance::DistanceType::CosineExpanded}, + + {1024, 8192, 2050, 16, 0.4, cuvs::distance::DistanceType::L2Expanded}, + {1024, 8192, 2051, 16, 0.5, cuvs::distance::DistanceType::L2Expanded}, + {1024, 8192, 2052, 16, 0.2, cuvs::distance::DistanceType::L2Expanded}, + {1024, 8192, 2050, 16, 0.4, cuvs::distance::DistanceType::InnerProduct}, + {1024, 8192, 2051, 16, 0.5, cuvs::distance::DistanceType::InnerProduct}, + {1024, 8192, 2052, 16, 0.2, cuvs::distance::DistanceType::InnerProduct}, + {1024, 8192, 2050, 16, 0.4, cuvs::distance::DistanceType::L2SqrtExpanded}, + {1024, 8192, 2051, 16, 0.5, cuvs::distance::DistanceType::L2SqrtExpanded}, + {1024, 8192, 2052, 16, 0.2, cuvs::distance::DistanceType::L2SqrtExpanded}, + {1024, 8192, 2050, 16, 0.4, cuvs::distance::DistanceType::CosineExpanded}, + {1024, 8192, 2051, 16, 0.5, cuvs::distance::DistanceType::CosineExpanded}, + {1024, 8192, 2052, 16, 0.2, cuvs::distance::DistanceType::CosineExpanded}, + + {1024, 8192, 1, 16, 0.5, cuvs::distance::DistanceType::L2Expanded}, + {1024, 8192, 2, 16, 0.2, cuvs::distance::DistanceType::L2Expanded}, + {1024, 8192, 3, 16, 0.4, cuvs::distance::DistanceType::InnerProduct}, + {1024, 8192, 4, 16, 0.5, cuvs::distance::DistanceType::InnerProduct}, + {1024, 8192, 5, 16, 0.2, cuvs::distance::DistanceType::L2SqrtExpanded}, + {1024, 8192, 8, 16, 0.4, cuvs::distance::DistanceType::L2SqrtExpanded}, + {1024, 8192, 5, 16, 0.5, cuvs::distance::DistanceType::CosineExpanded}, + {1024, 8192, 8, 16, 0.2, cuvs::distance::DistanceType::CosineExpanded}}; + +INSTANTIATE_TEST_CASE_P(PrefilteredBruteForceTest, + PrefilteredBruteForceTest_float_int64, + ::testing::ValuesIn(selectk_inputs)); + +} // namespace cuvs::neighbors::brute_force diff --git a/cpp/test/neighbors/knn_utils.cuh b/cpp/test/neighbors/knn_utils.cuh new file mode 100644 index 000000000..d95174ef6 --- /dev/null +++ b/cpp/test/neighbors/knn_utils.cuh @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../test_utils.cuh" + +#include + +#include + +#include + +namespace cuvs::neighbors { +template +struct idx_dist_pair { + IdxT idx; + DistT dist; + compareDist eq_compare; + bool operator==(const idx_dist_pair& a) const + { + if (idx == a.idx) return true; + if (eq_compare(dist, a.dist)) return true; + return false; + } + idx_dist_pair(IdxT x, DistT y, compareDist op) : idx(x), dist(y), eq_compare(op) {} +}; + +template +testing::AssertionResult devArrMatchKnnPair(const T* expected_idx, + const T* actual_idx, + const DistT* expected_dist, + const DistT* actual_dist, + size_t rows, + size_t cols, + const DistT eps, + cudaStream_t stream = 0, + bool sort_inputs = false) +{ + size_t size = rows * cols; + std::unique_ptr exp_idx_h(new T[size]); + std::unique_ptr act_idx_h(new T[size]); + std::unique_ptr exp_dist_h(new DistT[size]); + std::unique_ptr act_dist_h(new DistT[size]); + raft::update_host(exp_idx_h.get(), expected_idx, size, stream); + raft::update_host(act_idx_h.get(), actual_idx, size, stream); + raft::update_host(exp_dist_h.get(), expected_dist, size, stream); + raft::update_host(act_dist_h.get(), actual_dist, size, stream); + + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (size_t i(0); i < rows; ++i) { + std::vector> actual; + std::vector> expected; + for (size_t j(0); j < cols; ++j) { + auto idx = i * cols + j; // row major assumption! + auto exp_idx = exp_idx_h.get()[idx]; + auto act_idx = act_idx_h.get()[idx]; + auto exp_dist = exp_dist_h.get()[idx]; + auto act_dist = act_dist_h.get()[idx]; + actual.push_back(std::make_pair(act_dist, act_idx)); + expected.push_back(std::make_pair(exp_dist, exp_idx)); + } + if (sort_inputs) { + // inputs could be unsorted here, sort for comparison + std::sort(actual.begin(), actual.end()); + std::sort(expected.begin(), expected.end()); + } + for (size_t j(0); j < cols; ++j) { + auto act = actual[j]; + auto exp = expected[j]; + idx_dist_pair exp_kvp(exp.second, exp.first, cuvs::CompareApprox(eps)); + idx_dist_pair act_kvp(act.second, act.first, cuvs::CompareApprox(eps)); + if (!(exp_kvp == act_kvp)) { + return testing::AssertionFailure() + << "actual=" << act_kvp.idx << "," << act_kvp.dist << "!=" + << "expected" << exp_kvp.idx << "," << exp_kvp.dist << " @" << i << "," << j; + } + } + } + return testing::AssertionSuccess(); +} +} // namespace cuvs::neighbors