forked from rapidsai/raft
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEA] support of prefiltered brute force (rapidsai#2294)
- This PR is one part of the feature of rapidsai#1969 - Add the API of 'search_with_filtering' for brute force. Authors: - James Rong (https://github.com/rhdong) ```shell ***WARNING*** CPU scaling is enabled, the benchmark real time measurements may be noisy and will incur extra overhead. ----------------------------------------------------------------------------------------------------- Benchmark Time CPU Iterations ----------------------------------------------------------------------------------------------------- KNN/float/int64_t/brute_force_filter_knn/0/0/0/manual_time 33.1 ms 69.9 ms 21 1000000#128#1000#255#0#InnerProduct#NO_COPY#SEARCH KNN/float/int64_t/brute_force_filter_knn/1/0/0/manual_time 38.0 ms 74.8 ms 18 1000000#128#1000#255#0#L2Expanded#NO_COPY#SEARCH KNN/float/int64_t/brute_force_filter_knn/2/0/0/manual_time 41.7 ms 78.5 ms 17 1000000#128#1000#255#0.8#InnerProduct#NO_COPY#SEARCH KNN/float/int64_t/brute_force_filter_knn/3/0/0/manual_time 57.5 ms 94.3 ms 12 1000000#128#1000#255#0.8#L2Expanded#NO_COPY#SEARCH KNN/float/int64_t/brute_force_filter_knn/4/0/0/manual_time 19.7 ms 56.4 ms 35 1000000#128#1000#255#0.9#InnerProduct#NO_COPY#SEARCH KNN/float/int64_t/brute_force_filter_knn/5/0/0/manual_time 26.1 ms 62.8 ms 27 1000000#128#1000#255#0.9#L2Expanded#NO_COPY#SEARCH``` Authors: - rhdong (https://github.com/rhdong) - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Robert Maynard (https://github.com/robertmaynard) - Corey J. Nolet (https://github.com/cjnolet) - Divye Gala (https://github.com/divyegala) URL: rapidsai#2294
- Loading branch information
Showing
18 changed files
with
388 additions
and
365 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <raft/core/bitset.hpp> | ||
#include <raft/core/detail/mdspan_util.cuh> | ||
#include <raft/core/device_container_policy.hpp> | ||
#include <raft/core/device_mdarray.hpp> | ||
#include <raft/core/resources.hpp> | ||
|
||
#include <type_traits> | ||
|
||
namespace raft::core { | ||
/** | ||
* @defgroup bitmap Bitmap | ||
* @{ | ||
*/ | ||
/** | ||
* @brief View of a RAFT Bitmap. | ||
* | ||
* This lightweight structure which represents and manipulates a two-dimensional bitmap matrix view | ||
* with row major order. This class provides functionality for handling a matrix where each element | ||
* is represented as a bit in a bitmap. | ||
* | ||
* @tparam bitmap_t Underlying type of the bitmap array. Default is uint32_t. | ||
* @tparam index_t Indexing type used. Default is uint32_t. | ||
*/ | ||
template <typename bitmap_t = uint32_t, typename index_t = uint32_t> | ||
struct bitmap_view : public bitset_view<bitmap_t, index_t> { | ||
static_assert((std::is_same<typename std::remove_const<bitmap_t>::type, uint32_t>::value || | ||
std::is_same<typename std::remove_const<bitmap_t>::type, uint64_t>::value), | ||
"The bitmap_t must be uint32_t or uint64_t."); | ||
/** | ||
* @brief Create a bitmap view from a device raw pointer. | ||
* | ||
* @param bitmap_ptr Device raw pointer | ||
* @param rows Number of row in the matrix. | ||
* @param cols Number of col in the matrix. | ||
*/ | ||
_RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols) | ||
: bitset_view<bitmap_t, index_t>(bitmap_ptr, rows * cols), rows_(rows), cols_(cols) | ||
{ | ||
} | ||
|
||
/** | ||
* @brief Create a bitmap view from a device vector view of the bitset. | ||
* | ||
* @param bitmap_span Device vector view of the bitmap | ||
* @param rows Number of row in the matrix. | ||
* @param cols Number of col in the matrix. | ||
*/ | ||
_RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view<bitmap_t, index_t> bitmap_span, | ||
index_t rows, | ||
index_t cols) | ||
: bitset_view<bitmap_t, index_t>(bitmap_span, rows * cols), rows_(rows), cols_(cols) | ||
{ | ||
} | ||
|
||
private: | ||
// Hide the constructors of bitset_view. | ||
_RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t bitmap_len) | ||
: bitset_view<bitmap_t, index_t>(bitmap_ptr, bitmap_len) | ||
{ | ||
} | ||
|
||
_RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view<bitmap_t, index_t> bitmap_span, | ||
index_t bitmap_len) | ||
: bitset_view<bitmap_t, index_t>(bitmap_span, bitmap_len) | ||
{ | ||
} | ||
|
||
public: | ||
/** | ||
* @brief Device function to test if a given row and col are set in the bitmap. | ||
* | ||
* @param row Row index of the bit to test | ||
* @param col Col index of the bit to test | ||
* @return bool True if index has not been unset in the bitset | ||
*/ | ||
inline _RAFT_HOST_DEVICE bool test(const index_t row, const index_t col) const; | ||
|
||
/** | ||
* @brief Device function to set a given row and col to set_value in the bitset. | ||
* | ||
* @param row Row index of the bit to set | ||
* @param col Col index of the bit to set | ||
* @param new_value Value to set the bit to (true or false) | ||
*/ | ||
inline _RAFT_HOST_DEVICE void set(const index_t row, const index_t col, bool new_value) const; | ||
|
||
/** | ||
* @brief Get the total number of rows | ||
* @return index_t The total number of rows | ||
*/ | ||
inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; } | ||
|
||
/** | ||
* @brief Get the total number of columns | ||
* @return index_t The total number of columns | ||
*/ | ||
inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; } | ||
|
||
private: | ||
index_t rows_; | ||
index_t cols_; | ||
}; | ||
|
||
/** @} */ | ||
} // end namespace raft::core |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* | ||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
#include <raft/core/detail/mdspan_util.cuh> | ||
#include <raft/core/device_mdarray.hpp> | ||
#include <raft/core/resources.hpp> | ||
#include <raft/linalg/coalesced_reduction.cuh> | ||
|
||
namespace raft::detail { | ||
|
||
/** | ||
* @brief Count the number of bits that are set to 1 in a vector. | ||
* | ||
* @tparam value_t the value type of the vector. | ||
* @tparam index_t the index type of vector and scalar. | ||
* | ||
* @param[in] res raft handle for managing expensive resources | ||
* @param[in] values Number of row in the matrix. | ||
* @param[in] max_len Maximum number of bits to count. | ||
* @param[out] counter Number of bits that are set to 1. | ||
*/ | ||
template <typename value_t, typename index_t> | ||
void popc(const raft::resources& res, | ||
device_vector_view<value_t, index_t> values, | ||
index_t max_len, | ||
raft::device_scalar_view<index_t> counter) | ||
{ | ||
auto values_size = values.size(); | ||
auto values_matrix = raft::make_device_matrix_view<value_t, index_t, col_major>( | ||
values.data_handle(), values_size, 1); | ||
auto counter_vector = raft::make_device_vector_view<index_t, index_t>(counter.data_handle(), 1); | ||
|
||
static constexpr index_t len_per_item = sizeof(value_t) * 8; | ||
|
||
value_t tail_len = (max_len % len_per_item); | ||
value_t tail_mask = tail_len ? (value_t)((value_t{1} << tail_len) - value_t{1}) : ~value_t{0}; | ||
raft::linalg::coalesced_reduction( | ||
res, | ||
values_matrix, | ||
counter_vector, | ||
index_t{0}, | ||
false, | ||
[tail_mask, values_size] __device__(value_t value, index_t index) { | ||
index_t result = 0; | ||
if constexpr (len_per_item == 64) { | ||
if (index == values_size - 1) | ||
result = index_t(raft::detail::popc(value & tail_mask)); | ||
else | ||
result = index_t(raft::detail::popc(value)); | ||
} else { // Needed because popc is not overloaded for 16 and 8 bit elements | ||
if (index == values_size - 1) | ||
result = index_t(raft::detail::popc(uint32_t{value} & tail_mask)); | ||
else | ||
result = index_t(raft::detail::popc(uint32_t{value})); | ||
} | ||
|
||
return result; | ||
}); | ||
} | ||
|
||
} // end namespace raft::detail |
Oops, something went wrong.