From 5f0dfeded3e8bc63832dc6ab37fda1e62910d423 Mon Sep 17 00:00:00 2001 From: rhdong Date: Thu, 23 May 2024 21:35:53 -0700 Subject: [PATCH] [FEA] support of prefiltered brute force (#2294) - This PR is one part of the feature of #1969 - Add the API of 'search_with_filtering' for brute force. Authors: - James Rong (https://github.com/rhdong) ```shell ***WARNING*** CPU scaling is enabled, the benchmark real time measurements may be noisy and will incur extra overhead. ----------------------------------------------------------------------------------------------------- Benchmark Time CPU Iterations ----------------------------------------------------------------------------------------------------- KNN/float/int64_t/brute_force_filter_knn/0/0/0/manual_time 33.1 ms 69.9 ms 21 1000000#128#1000#255#0#InnerProduct#NO_COPY#SEARCH KNN/float/int64_t/brute_force_filter_knn/1/0/0/manual_time 38.0 ms 74.8 ms 18 1000000#128#1000#255#0#L2Expanded#NO_COPY#SEARCH KNN/float/int64_t/brute_force_filter_knn/2/0/0/manual_time 41.7 ms 78.5 ms 17 1000000#128#1000#255#0.8#InnerProduct#NO_COPY#SEARCH KNN/float/int64_t/brute_force_filter_knn/3/0/0/manual_time 57.5 ms 94.3 ms 12 1000000#128#1000#255#0.8#L2Expanded#NO_COPY#SEARCH KNN/float/int64_t/brute_force_filter_knn/4/0/0/manual_time 19.7 ms 56.4 ms 35 1000000#128#1000#255#0.9#InnerProduct#NO_COPY#SEARCH KNN/float/int64_t/brute_force_filter_knn/5/0/0/manual_time 26.1 ms 62.8 ms 27 1000000#128#1000#255#0.9#L2Expanded#NO_COPY#SEARCH``` Authors: - rhdong (https://github.com/rhdong) - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Robert Maynard (https://github.com/robertmaynard) - Corey J. Nolet (https://github.com/cjnolet) - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/raft/pull/2294 --- cpp/include/raft/core/bitmap.cuh | 116 +++------------- cpp/include/raft/core/bitmap.hpp | 123 +++++++++++++++++ cpp/include/raft/core/bitset.cuh | 42 ++---- cpp/include/raft/core/detail/popc.cuh | 75 +++++++++++ .../sparse/convert/detail/bitmap_to_csr.cuh | 10 +- .../raft/sparse/distance/detail/utils.cuh | 127 +++++++++++++++++- .../sparse/matrix/detail/select_k-ext.cuh | 2 +- .../raft/sparse/matrix/detail/select_k.cuh | 3 +- .../matrix/detail/select_k_double_int64_t.cu | 32 ----- .../matrix/detail/select_k_double_uint32_t.cu | 34 ----- .../matrix/detail/select_k_float_int32.cu | 32 ----- .../matrix/detail/select_k_float_int64_t.cu | 32 ----- .../matrix/detail/select_k_float_uint32_t.cu | 32 ----- .../matrix/detail/select_k_half_int64_t.cu | 32 ----- .../matrix/detail/select_k_half_uint32_t.cu | 32 ----- cpp/test/CMakeLists.txt | 1 + cpp/test/ext_headers/00_generate.py | 1 + .../raft_sparse_matrix_detail_select_k.cu | 27 ++++ 18 files changed, 388 insertions(+), 365 deletions(-) create mode 100644 cpp/include/raft/core/bitmap.hpp create mode 100644 cpp/include/raft/core/detail/popc.cuh delete mode 100644 cpp/src/sparse/matrix/detail/select_k_double_int64_t.cu delete mode 100644 cpp/src/sparse/matrix/detail/select_k_double_uint32_t.cu delete mode 100644 cpp/src/sparse/matrix/detail/select_k_float_int32.cu delete mode 100644 cpp/src/sparse/matrix/detail/select_k_float_int64_t.cu delete mode 100644 cpp/src/sparse/matrix/detail/select_k_float_uint32_t.cu delete mode 100644 cpp/src/sparse/matrix/detail/select_k_half_int64_t.cu delete mode 100644 cpp/src/sparse/matrix/detail/select_k_half_uint32_t.cu create mode 100644 cpp/test/ext_headers/raft_sparse_matrix_detail_select_k.cu diff --git a/cpp/include/raft/core/bitmap.cuh b/cpp/include/raft/core/bitmap.cuh index 829c84ed25..2c23a77e47 100644 --- a/cpp/include/raft/core/bitmap.cuh +++ b/cpp/include/raft/core/bitmap.cuh @@ -16,112 +16,30 @@ #pragma once +#include #include #include #include #include #include -namespace raft::core { -/** - * @defgroup bitmap Bitmap - * @{ - */ -/** - * @brief View of a RAFT Bitmap. - * - * This lightweight structure which represents and manipulates a two-dimensional bitmap matrix view - * with row major order. This class provides functionality for handling a matrix where each element - * is represented as a bit in a bitmap. - * - * @tparam bitmap_t Underlying type of the bitmap array. Default is uint32_t. - * @tparam index_t Indexing type used. Default is uint32_t. - */ -template -struct bitmap_view : public bitset_view { - static_assert((std::is_same::value || - std::is_same::value), - "The bitmap_t must be uint32_t or uint64_t."); - /** - * @brief Create a bitmap view from a device raw pointer. - * - * @param bitmap_ptr Device raw pointer - * @param rows Number of row in the matrix. - * @param cols Number of col in the matrix. - */ - _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols) - : bitset_view(bitmap_ptr, rows * cols), rows_(rows), cols_(cols) - { - } - - /** - * @brief Create a bitmap view from a device vector view of the bitset. - * - * @param bitmap_span Device vector view of the bitmap - * @param rows Number of row in the matrix. - * @param cols Number of col in the matrix. - */ - _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, - index_t rows, - index_t cols) - : bitset_view(bitmap_span, rows * cols), rows_(rows), cols_(cols) - { - } +#include - private: - // Hide the constructors of bitset_view. - _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t bitmap_len) - : bitset_view(bitmap_ptr, bitmap_len) - { - } - - _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, - index_t bitmap_len) - : bitset_view(bitmap_span, bitmap_len) - { - } - - public: - /** - * @brief Device function to test if a given row and col are set in the bitmap. - * - * @param row Row index of the bit to test - * @param col Col index of the bit to test - * @return bool True if index has not been unset in the bitset - */ - inline _RAFT_DEVICE auto test(const index_t row, const index_t col) const -> bool - { - return test(row * cols_ + col); - } - - /** - * @brief Device function to set a given row and col to set_value in the bitset. - * - * @param row Row index of the bit to set - * @param col Col index of the bit to set - * @param new_value Value to set the bit to (true or false) - */ - inline _RAFT_DEVICE void set(const index_t row, const index_t col, bool new_value) const - { - set(row * cols_ + col, &new_value); - } - - /** - * @brief Get the total number of rows - * @return index_t The total number of rows - */ - inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; } - - /** - * @brief Get the total number of columns - * @return index_t The total number of columns - */ - inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; } +namespace raft::core { - private: - index_t rows_; - index_t cols_; -}; +template +_RAFT_HOST_DEVICE inline bool bitmap_view::test(const index_t row, + const index_t col) const +{ + return test(row * cols_ + col); +} + +template +_RAFT_HOST_DEVICE void bitmap_view::set(const index_t row, + const index_t col, + bool new_value) const +{ + set(row * cols_ + col, &new_value); +} -/** @} */ } // end namespace raft::core diff --git a/cpp/include/raft/core/bitmap.hpp b/cpp/include/raft/core/bitmap.hpp new file mode 100644 index 0000000000..5c77866164 --- /dev/null +++ b/cpp/include/raft/core/bitmap.hpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace raft::core { +/** + * @defgroup bitmap Bitmap + * @{ + */ +/** + * @brief View of a RAFT Bitmap. + * + * This lightweight structure which represents and manipulates a two-dimensional bitmap matrix view + * with row major order. This class provides functionality for handling a matrix where each element + * is represented as a bit in a bitmap. + * + * @tparam bitmap_t Underlying type of the bitmap array. Default is uint32_t. + * @tparam index_t Indexing type used. Default is uint32_t. + */ +template +struct bitmap_view : public bitset_view { + static_assert((std::is_same::type, uint32_t>::value || + std::is_same::type, uint64_t>::value), + "The bitmap_t must be uint32_t or uint64_t."); + /** + * @brief Create a bitmap view from a device raw pointer. + * + * @param bitmap_ptr Device raw pointer + * @param rows Number of row in the matrix. + * @param cols Number of col in the matrix. + */ + _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols) + : bitset_view(bitmap_ptr, rows * cols), rows_(rows), cols_(cols) + { + } + + /** + * @brief Create a bitmap view from a device vector view of the bitset. + * + * @param bitmap_span Device vector view of the bitmap + * @param rows Number of row in the matrix. + * @param cols Number of col in the matrix. + */ + _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, + index_t rows, + index_t cols) + : bitset_view(bitmap_span, rows * cols), rows_(rows), cols_(cols) + { + } + + private: + // Hide the constructors of bitset_view. + _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t bitmap_len) + : bitset_view(bitmap_ptr, bitmap_len) + { + } + + _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, + index_t bitmap_len) + : bitset_view(bitmap_span, bitmap_len) + { + } + + public: + /** + * @brief Device function to test if a given row and col are set in the bitmap. + * + * @param row Row index of the bit to test + * @param col Col index of the bit to test + * @return bool True if index has not been unset in the bitset + */ + inline _RAFT_HOST_DEVICE bool test(const index_t row, const index_t col) const; + + /** + * @brief Device function to set a given row and col to set_value in the bitset. + * + * @param row Row index of the bit to set + * @param col Col index of the bit to set + * @param new_value Value to set the bit to (true or false) + */ + inline _RAFT_HOST_DEVICE void set(const index_t row, const index_t col, bool new_value) const; + + /** + * @brief Get the total number of rows + * @return index_t The total number of rows + */ + inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; } + + /** + * @brief Get the total number of columns + * @return index_t The total number of columns + */ + inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; } + + private: + index_t rows_; + index_t cols_; +}; + +/** @} */ +} // end namespace raft::core diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index cdfbe0b8dd..d7eedee92e 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -17,7 +17,7 @@ #pragma once #include -#include // native_popc +#include #include #include #include @@ -60,6 +60,12 @@ _RAFT_HOST_DEVICE void bitset_view::set(const index_t sample_ } } +template +_RAFT_HOST_DEVICE inline index_t bitset_view::n_elements() const +{ + return raft::ceildiv(bitset_len_, bitset_element_size); +} + template bitset::bitset(const raft::resources& res, raft::device_vector_view mask_index, @@ -161,37 +167,9 @@ template void bitset::count(const raft::resources& res, raft::device_scalar_view count_gpu_scalar) { - auto n_elements_ = n_elements(); - auto count_gpu = - raft::make_device_vector_view(count_gpu_scalar.data_handle(), 1); - auto bitset_matrix_view = raft::make_device_matrix_view( - bitset_.data(), n_elements_, 1); - - bitset_t n_last_element = (bitset_len_ % bitset_element_size); - bitset_t last_element_mask = - n_last_element ? (bitset_t)((bitset_t{1} << n_last_element) - bitset_t{1}) : ~bitset_t{0}; - raft::linalg::coalesced_reduction( - res, - bitset_matrix_view, - count_gpu, - index_t{0}, - false, - [last_element_mask, n_elements_] __device__(bitset_t element, index_t index) { - index_t result = 0; - if constexpr (bitset_element_size == 64) { - if (index == n_elements_ - 1) - result = index_t(raft::detail::popc(element & last_element_mask)); - else - result = index_t(raft::detail::popc(element)); - } else { // Needed because popc is not overloaded for 16 and 8 bit elements - if (index == n_elements_ - 1) - result = index_t(raft::detail::popc(uint32_t{element} & last_element_mask)); - else - result = index_t(raft::detail::popc(uint32_t{element})); - } - - return result; - }); + auto values = + raft::make_device_vector_view(bitset_.data(), n_elements()); + raft::detail::popc(res, values, bitset_len_, count_gpu_scalar); } } // end namespace raft::core diff --git a/cpp/include/raft/core/detail/popc.cuh b/cpp/include/raft/core/detail/popc.cuh new file mode 100644 index 0000000000..d74b68b715 --- /dev/null +++ b/cpp/include/raft/core/detail/popc.cuh @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace raft::detail { + +/** + * @brief Count the number of bits that are set to 1 in a vector. + * + * @tparam value_t the value type of the vector. + * @tparam index_t the index type of vector and scalar. + * + * @param[in] res raft handle for managing expensive resources + * @param[in] values Number of row in the matrix. + * @param[in] max_len Maximum number of bits to count. + * @param[out] counter Number of bits that are set to 1. + */ +template +void popc(const raft::resources& res, + device_vector_view values, + index_t max_len, + raft::device_scalar_view counter) +{ + auto values_size = values.size(); + auto values_matrix = raft::make_device_matrix_view( + values.data_handle(), values_size, 1); + auto counter_vector = raft::make_device_vector_view(counter.data_handle(), 1); + + static constexpr index_t len_per_item = sizeof(value_t) * 8; + + value_t tail_len = (max_len % len_per_item); + value_t tail_mask = tail_len ? (value_t)((value_t{1} << tail_len) - value_t{1}) : ~value_t{0}; + raft::linalg::coalesced_reduction( + res, + values_matrix, + counter_vector, + index_t{0}, + false, + [tail_mask, values_size] __device__(value_t value, index_t index) { + index_t result = 0; + if constexpr (len_per_item == 64) { + if (index == values_size - 1) + result = index_t(raft::detail::popc(value & tail_mask)); + else + result = index_t(raft::detail::popc(value)); + } else { // Needed because popc is not overloaded for 16 and 8 bit elements + if (index == values_size - 1) + result = index_t(raft::detail::popc(uint32_t{value} & tail_mask)); + else + result = index_t(raft::detail::popc(uint32_t{value})); + } + + return result; + }); +} + +} // end namespace raft::detail \ No newline at end of file diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh index b0315486ff..b1b0291a85 100644 --- a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -67,8 +67,8 @@ RAFT_KERNEL __launch_bounds__(calc_nnz_by_rows_tpb) calc_nnz_by_rows_kernel(cons index_t l_sum = 0; while (offset < num_cols) { - index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; - bitmap_t l_bitmap = bitmap_t(0); + index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; + std::remove_const_t l_bitmap = 0; if (bitmap_idx * BITS_PER_BITMAP < e_bit) { l_bitmap = bitmap[bitmap_idx]; } @@ -176,9 +176,9 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) #pragma unroll for (index_t offset = 0; offset < num_cols; offset += BITS_PER_BITMAP * warpSize) { - index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; - bitmap_t l_bitmap = bitmap_t(0); - index_t l_offset = offset + lane_id * BITS_PER_BITMAP - (s_bit % BITS_PER_BITMAP); + index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; + std::remove_const_t l_bitmap = 0; + index_t l_offset = offset + lane_id * BITS_PER_BITMAP - (s_bit % BITS_PER_BITMAP); if (bitmap_idx * BITS_PER_BITMAP < e_bit) { l_bitmap = bitmap[bitmap_idx]; } diff --git a/cpp/include/raft/sparse/distance/detail/utils.cuh b/cpp/include/raft/sparse/distance/detail/utils.cuh index ed2b414c70..42b545180b 100644 --- a/cpp/include/raft/sparse/distance/detail/utils.cuh +++ b/cpp/include/raft/sparse/distance/detail/utils.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,11 @@ #pragma once +#include +#include + #include +#include namespace raft { namespace sparse { @@ -37,6 +41,127 @@ inline int max_cols_per_block() sizeof(value_t); } +template +RAFT_KERNEL faster_dot_on_csr_kernel(value_t* __restrict__ dot, + const value_idx* __restrict__ indptr, + const value_idx* __restrict__ cols, + const value_t* __restrict__ A, + const value_t* __restrict__ B, + const value_idx nnz, + const value_idx n_rows, + const value_idx dim) +{ + auto vec_id = threadIdx.x; + auto lane_id = threadIdx.x & 0x1f; + + extern __shared__ char smem[]; + value_t* s_A = (value_t*)smem; + value_idx cur_row = -1; + + for (int row = blockIdx.x; row < n_rows; row += gridDim.x) { + for (int dot_id = blockIdx.y + indptr[row]; dot_id < indptr[row + 1]; dot_id += gridDim.y) { + if (dot_id >= nnz) { return; } + const value_idx col = cols[dot_id] * dim; + const value_t* __restrict__ B_col = B + col; + + if (threadIdx.x == 0) { dot[dot_id] = 0.0; } + __syncthreads(); + + if (cur_row != row) { + for (value_idx k = vec_id; k < dim; k += blockDim.x) { + s_A[k] = A[row * dim + k]; + } + cur_row = row; + } + + value_t l_dot_ = 0.0; + for (value_idx k = vec_id; k < dim; k += blockDim.x) { + asm("prefetch.global.L2 [%0];" ::"l"(B_col + k + blockDim.x)); + l_dot_ += s_A[k] * __ldcg(B_col + k); + } + l_dot_ += __shfl_down_sync(0xffffffff, l_dot_, 16); + l_dot_ += __shfl_down_sync(0xffff, l_dot_, 8); + l_dot_ += __shfl_down_sync(0xff, l_dot_, 4); + l_dot_ += __shfl_down_sync(0xf, l_dot_, 2); + l_dot_ += __shfl_down_sync(0x3, l_dot_, 1); + + if (lane_id == 0) { atomicAdd_block(dot + dot_id, l_dot_); } + } + } +} + +template +void faster_dot_on_csr(raft::resources const& handle, + value_t* dot, + const value_idx nnz, + const value_idx* indptr, + const value_idx* cols, + const value_t* A, + const value_t* B, + const value_idx n_rows, + const value_idx dim) +{ + if (nnz == 0 || n_rows == 0) return; + + auto stream = resource::get_cuda_stream(handle); + + constexpr value_idx MAX_ROW_PER_ITER = 500; + int dev_id, sm_count, blocks_per_sm; + + const int smem_size = dim * sizeof(value_t); + cudaGetDevice(&dev_id); + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + if (dim < 128) { + constexpr int tpb = 64; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); + auto block_y = + (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; + dim3 blocks(block_x, block_y, 1); + + faster_dot_on_csr_kernel + <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); + + } else if (dim < 256) { + constexpr int tpb = 128; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); + auto block_y = + (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; + dim3 blocks(block_x, block_y, 1); + + faster_dot_on_csr_kernel + <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); + } else if (dim < 512) { + constexpr int tpb = 256; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); + auto block_y = + (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; + dim3 blocks(block_x, block_y, 1); + + faster_dot_on_csr_kernel + <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); + } else { + constexpr int tpb = 512; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); + auto block_y = + (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; + dim3 blocks(block_x, block_y, 1); + + faster_dot_on_csr_kernel + <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); + } + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + } // namespace detail } // namespace distance } // namespace sparse diff --git a/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh b/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh index 922356b040..01625a0ce8 100644 --- a/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh @@ -37,7 +37,7 @@ void select_k(raft::resources const& handle, raft::device_matrix_view out_idx, bool select_min, bool sorted = false, - raft::matrix::SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT; + raft::matrix::SelectAlgo algo = raft::matrix::SelectAlgo::kAuto) RAFT_EXPLICIT; } // namespace raft::sparse::matrix::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/include/raft/sparse/matrix/detail/select_k.cuh b/cpp/include/raft/sparse/matrix/detail/select_k.cuh index 711169984b..5d52b94b2f 100644 --- a/cpp/include/raft/sparse/matrix/detail/select_k.cuh +++ b/cpp/include/raft/sparse/matrix/detail/select_k.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ #ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY #include "select_k-inl.cuh" + #endif #ifdef RAFT_COMPILED diff --git a/cpp/src/sparse/matrix/detail/select_k_double_int64_t.cu b/cpp/src/sparse/matrix/detail/select_k_double_int64_t.cu deleted file mode 100644 index c784b50dad..0000000000 --- a/cpp/src/sparse/matrix/detail/select_k_double_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - template void raft::sparse::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_sparse_matrix_detail_select_k(double, int64_t); - -#undef instantiate_raft_sparse_matrix_detail_select_k \ No newline at end of file diff --git a/cpp/src/sparse/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/sparse/matrix/detail/select_k_double_uint32_t.cu deleted file mode 100644 index 98bab9a504..0000000000 --- a/cpp/src/sparse/matrix/detail/select_k_double_uint32_t.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include // uint32_t - -#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - template void raft::sparse::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_sparse_matrix_detail_select_k(double, uint32_t); - -#undef instantiate_raft_sparse_matrix_detail_select_k \ No newline at end of file diff --git a/cpp/src/sparse/matrix/detail/select_k_float_int32.cu b/cpp/src/sparse/matrix/detail/select_k_float_int32.cu deleted file mode 100644 index bff213ae69..0000000000 --- a/cpp/src/sparse/matrix/detail/select_k_float_int32.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_sparse_matrix_detail_select_k(float, int); - -#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/src/sparse/matrix/detail/select_k_float_int64_t.cu b/cpp/src/sparse/matrix/detail/select_k_float_int64_t.cu deleted file mode 100644 index 412b06e587..0000000000 --- a/cpp/src/sparse/matrix/detail/select_k_float_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - template void raft::sparse::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_sparse_matrix_detail_select_k(float, int64_t); - -#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/src/sparse/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/sparse/matrix/detail/select_k_float_uint32_t.cu deleted file mode 100644 index 8ba3f0e22b..0000000000 --- a/cpp/src/sparse/matrix/detail/select_k_float_uint32_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - template void raft::sparse::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_sparse_matrix_detail_select_k(float, uint32_t); - -#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/src/sparse/matrix/detail/select_k_half_int64_t.cu b/cpp/src/sparse/matrix/detail/select_k_half_int64_t.cu deleted file mode 100644 index 24c844f8c8..0000000000 --- a/cpp/src/sparse/matrix/detail/select_k_half_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - template void raft::sparse::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_sparse_matrix_detail_select_k(__half, int64_t); - -#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/src/sparse/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/sparse/matrix/detail/select_k_half_uint32_t.cu deleted file mode 100644 index d63dc64933..0000000000 --- a/cpp/src/sparse/matrix/detail/select_k_half_uint32_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - template void raft::sparse::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_sparse_matrix_detail_select_k(__half, uint32_t); - -#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index dac3418c8e..ff0518a4d0 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -193,6 +193,7 @@ if(BUILD_TESTS) ext_headers/raft_neighbors_refine.cu ext_headers/raft_neighbors_detail_ivf_flat_search.cu ext_headers/raft_linalg_detail_coalesced_reduction.cu + ext_headers/raft_sparse_matrix_detail_select_k.cu ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu ext_headers/raft_neighbors_detail_ivf_flat_interleaved_scan.cu ext_headers/raft_neighbors_detail_ivf_pq_compute_similarity.cu diff --git a/cpp/test/ext_headers/00_generate.py b/cpp/test/ext_headers/00_generate.py index d9c766979b..1e1106f8bf 100644 --- a/cpp/test/ext_headers/00_generate.py +++ b/cpp/test/ext_headers/00_generate.py @@ -54,6 +54,7 @@ "raft/neighbors/refine-ext.cuh", "raft/neighbors/detail/ivf_flat_search-ext.cuh", "raft/linalg/detail/coalesced_reduction-ext.cuh", + "raft/sparse/matrix/detail/select_k-ext.cuh", "raft/spatial/knn/detail/ball_cover/registers-ext.cuh", "raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh", "raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh", diff --git a/cpp/test/ext_headers/raft_sparse_matrix_detail_select_k.cu b/cpp/test/ext_headers/raft_sparse_matrix_detail_select_k.cu new file mode 100644 index 0000000000..b748a31a5b --- /dev/null +++ b/cpp/test/ext_headers/raft_sparse_matrix_detail_select_k.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include