-
Notifications
You must be signed in to change notification settings - Fork 73
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FEA] support of prefiltered brute force #146
Changes from all commits
7357b31
9fbca1a
c3e14d7
dd85921
bdce2a3
9ab817a
4aeb3a4
1d3adb8
d505b6b
b69a27b
7c5030b
d62fe79
8788763
a290192
5addf70
8873fde
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <raft/core/bitmap.hpp> | ||
|
||
namespace cuvs::core { | ||
/* To use bitmap functions containing CUDA code, include <raft/core/bitmap.cuh> */ | ||
|
||
template <typename bitmap_t, typename index_t> | ||
using bitmap_view = raft::core::bitmap_view<bitmap_t, index_t>; | ||
|
||
} // end namespace cuvs::core |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
*/ | ||
|
||
#pragma once | ||
|
||
#include <cuvs/distance/distance.hpp> | ||
#include <cuvs/neighbors/brute_force.hpp> | ||
|
||
|
@@ -23,16 +24,26 @@ | |
#include "./fused_l2_knn.cuh" | ||
#include "./haversine_distance.cuh" | ||
#include "./knn_merge_parts.cuh" | ||
#include "./knn_utils.cuh" | ||
|
||
rhdong marked this conversation as resolved.
Show resolved
Hide resolved
|
||
#include <raft/core/bitmap.cuh> | ||
#include <raft/core/detail/popc.cuh> | ||
#include <raft/core/device_csr_matrix.hpp> | ||
#include <raft/core/resource/cuda_stream.hpp> | ||
#include <raft/core/resource/cuda_stream_pool.hpp> | ||
#include <raft/core/resource/device_memory_resource.hpp> | ||
#include <raft/core/resource/thrust_policy.hpp> | ||
#include <raft/core/resources.hpp> | ||
#include <raft/linalg/map.cuh> | ||
#include <raft/linalg/norm.cuh> | ||
#include <raft/linalg/transpose.cuh> | ||
#include <raft/matrix/init.cuh> | ||
#include <raft/matrix/select_k.cuh> | ||
#include <raft/sparse/convert/coo.cuh> | ||
#include <raft/sparse/convert/csr.cuh> | ||
#include <raft/sparse/distance/detail/utils.cuh> | ||
#include <raft/sparse/linalg/sddmm.hpp> | ||
#include <raft/sparse/matrix/select_k.cuh> | ||
#include <raft/util/cuda_utils.cuh> | ||
#include <raft/util/cudart_utils.hpp> | ||
|
||
|
@@ -65,7 +76,8 @@ void tiled_brute_force_knn(const raft::resources& handle, | |
size_t max_row_tile_size = 0, | ||
size_t max_col_tile_size = 0, | ||
const ElementType* precomputed_index_norms = nullptr, | ||
const ElementType* precomputed_search_norms = nullptr) | ||
const ElementType* precomputed_search_norms = nullptr, | ||
const uint32_t* filter_bitmap = nullptr) | ||
{ | ||
// Figure out the number of rows/cols to tile for | ||
size_t tile_rows = 0; | ||
|
@@ -214,6 +226,27 @@ void tiled_brute_force_knn(const raft::resources& handle, | |
}); | ||
} | ||
|
||
if (filter_bitmap != nullptr) { | ||
auto distances_ptr = temp_distances.data(); | ||
auto count = thrust::make_counting_iterator<IndexType>(0); | ||
ElementType masked_distance = select_min ? std::numeric_limits<ElementType>::infinity() | ||
: std::numeric_limits<ElementType>::lowest(); | ||
thrust::for_each(raft::resource::get_thrust_policy(handle), | ||
count, | ||
count + current_query_size * current_centroid_size, | ||
[=] __device__(IndexType idx) { | ||
IndexType row = i + (idx / current_centroid_size); | ||
IndexType col = j + (idx % current_centroid_size); | ||
IndexType g_idx = row * n + col; | ||
IndexType item_idx = (g_idx) >> 5; | ||
uint32_t bit_idx = (g_idx)&31; | ||
uint32_t filter = filter_bitmap[item_idx]; | ||
if ((filter & (uint32_t(1) << bit_idx)) == 0) { | ||
distances_ptr[idx] = masked_distance; | ||
} | ||
}); | ||
} | ||
|
||
raft::matrix::select_k<ElementType, IndexType>( | ||
handle, | ||
raft::make_device_matrix_view<const ElementType, int64_t, raft::row_major>( | ||
|
@@ -519,6 +552,173 @@ void brute_force_search( | |
query_norms ? query_norms->data_handle() : nullptr); | ||
} | ||
|
||
template <typename T, typename IdxT, typename BitmapT> | ||
void brute_force_search_filtered( | ||
raft::resources const& res, | ||
const cuvs::neighbors::brute_force::index<T>& idx, | ||
raft::device_matrix_view<const T, IdxT, raft::row_major> queries, | ||
cuvs::core::bitmap_view<const BitmapT, IdxT> filter, | ||
raft::device_matrix_view<IdxT, IdxT, raft::row_major> neighbors, | ||
raft::device_matrix_view<T, IdxT, raft::row_major> distances, | ||
std::optional<raft::device_vector_view<const T, IdxT>> query_norms = std::nullopt) | ||
{ | ||
auto metric = idx.metric(); | ||
|
||
RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); | ||
RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), | ||
"Number of columns in queries must match brute force index"); | ||
RAFT_EXPECTS(metric == cuvs::distance::DistanceType::InnerProduct || | ||
metric == cuvs::distance::DistanceType::L2Expanded || | ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded || | ||
metric == cuvs::distance::DistanceType::CosineExpanded, | ||
"Only Euclidean, IP, and Cosine are supported!"); | ||
|
||
RAFT_EXPECTS(idx.has_norms() || !(metric == cuvs::distance::DistanceType::L2Expanded || | ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded || | ||
metric == cuvs::distance::DistanceType::CosineExpanded), | ||
"Index must has norms when using Euclidean, IP, and Cosine!"); | ||
|
||
IdxT n_queries = queries.extent(0); | ||
IdxT n_dataset = idx.dataset().extent(0); | ||
IdxT dim = idx.dataset().extent(1); | ||
IdxT k = neighbors.extent(1); | ||
|
||
auto stream = raft::resource::get_cuda_stream(res); | ||
|
||
// calc nnz | ||
IdxT nnz_h = 0; | ||
rmm::device_scalar<IdxT> nnz(0, stream); | ||
auto nnz_view = raft::make_device_scalar_view<IdxT>(nnz.data()); | ||
auto filter_view = | ||
raft::make_device_vector_view<const BitmapT, IdxT>(filter.data(), filter.n_elements()); | ||
|
||
// TODO(rhdong): Need to switch to the public API, | ||
// with the issue: https://github.com/rapidsai/cuvs/issues/158 | ||
raft::detail::popc(res, filter_view, n_queries * n_dataset, nnz_view); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should never be calling into raft detail APIs in cuvs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Totally agree! I've noticed this, but as you know, just for so tight schedule to split the raft internal calling to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for creating an issue for this. Can you please reference a link to the github issue in a comment on this line of code so we don't lose sight of it? Then we can merge this in. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure! |
||
raft::copy(&nnz_h, nnz.data(), 1, stream); | ||
|
||
raft::resource::sync_stream(res, stream); | ||
float sparsity = (1.0f * nnz_h / (1.0f * n_queries * n_dataset)); | ||
|
||
if (sparsity > 0.01f) { | ||
raft::resources stream_pool_handle(res); | ||
raft::resource::set_cuda_stream(stream_pool_handle, stream); | ||
auto idx_norm = idx.has_norms() ? const_cast<T*>(idx.norms().data_handle()) : nullptr; | ||
|
||
tiled_brute_force_knn<T, IdxT>(stream_pool_handle, | ||
queries.data_handle(), | ||
idx.dataset().data_handle(), | ||
n_queries, | ||
n_dataset, | ||
dim, | ||
k, | ||
distances.data_handle(), | ||
neighbors.data_handle(), | ||
metric, | ||
2.0, | ||
0, | ||
0, | ||
idx_norm, | ||
nullptr, | ||
filter.data()); | ||
} else { | ||
auto csr = raft::make_device_csr_matrix<T, IdxT>(res, n_queries, n_dataset, nnz_h); | ||
|
||
// fill csr | ||
raft::sparse::convert::bitmap_to_csr(res, filter, csr); | ||
|
||
// create filter csr view | ||
auto compressed_csr_view = csr.structure_view(); | ||
rmm::device_uvector<IdxT> rows(compressed_csr_view.get_nnz(), stream); | ||
raft::sparse::convert::csr_to_coo(compressed_csr_view.get_indptr().data(), | ||
compressed_csr_view.get_n_rows(), | ||
rows.data(), | ||
compressed_csr_view.get_nnz(), | ||
stream); | ||
if (n_queries > 10) { | ||
auto csr_view = raft::make_device_csr_matrix_view<T, IdxT, IdxT, IdxT>( | ||
csr.get_elements().data(), compressed_csr_view); | ||
|
||
// create dataset view | ||
auto dataset_view = raft::make_device_matrix_view<const T, IdxT, raft::col_major>( | ||
idx.dataset().data_handle(), dim, n_dataset); | ||
|
||
// calc dot | ||
T alpha = static_cast<T>(1.0f); | ||
T beta = static_cast<T>(0.0f); | ||
raft::sparse::linalg::sddmm(res, | ||
queries, | ||
dataset_view, | ||
csr_view, | ||
raft::linalg::Operation::NON_TRANSPOSE, | ||
raft::linalg::Operation::NON_TRANSPOSE, | ||
raft::make_host_scalar_view<T>(&alpha), | ||
raft::make_host_scalar_view<T>(&beta)); | ||
} else { | ||
raft::sparse::distance::detail::faster_dot_on_csr(res, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should never be calling into RAFT detail APIs in cuvs. If you think about the reason why we separate public from details APIs, it's so that each library can put forth a contract that they make guarantees not to break across versions while still hosting internal implementation-specific code that makes no such guarantees. If we're invoking detail APIs from RAFT within cuVS then we're going to end up either 1) making breakig changes those detail APIs and not realizing we just broke cuVS downstream, or 2) never being able to make changes to detail APIs because they need to maintain the same guarantees as public APIs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This needs to be fixed, but RAFT is in code freeze. What I think you should do to fix this is to create a GIthub issue in cuVS to describe the issue, then add comments in all the places in this PR that you are calling into RAFT detail APIs and reference the Github issue there. After you do that and address the other PR review feedback, we can merge this PR into 24.06 but you'll need to open up a RAFT PR to immediately expose all of these APIs publicly in RAFT and then open up a PR to invoke the public APIs in cuVS (target both to 24.08 since it's too late to make raft changes for 24.06). Of course, if you've already exposed these APIs publicly in RAFT and calling into detail was just an oversight, then they could simply be fixed in this PR. The general rule of thumb for detail APIs even within a library is that you should expose any functions publicly if you are going to invoke them across namespace separate namespace boundaries. Obvously that's more simple to do for a header-only library like RAFT than it is for a library where all public APIs are strictly compiled like cuVS, but we should be striving to reduce the amount of inter-namespace calls into detail APIs as possible so we can maintain our public API contracts. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for creating an issue to fix this. Let's get this C++ side merged in. Can you please reference the github issue in a comment just abov ethis line of code so that we don't lose sight of it? Then we can get this merged in. |
||
csr.get_elements().data(), | ||
compressed_csr_view.get_nnz(), | ||
compressed_csr_view.get_indptr().data(), | ||
compressed_csr_view.get_indices().data(), | ||
queries.data_handle(), | ||
idx.dataset().data_handle(), | ||
compressed_csr_view.get_n_rows(), | ||
dim); | ||
} | ||
|
||
// post process | ||
std::optional<raft::device_vector<T, IdxT>> query_norms_; | ||
if (metric == cuvs::distance::DistanceType::L2Expanded || | ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded || | ||
metric == cuvs::distance::DistanceType::CosineExpanded) { | ||
if (metric == cuvs::distance::DistanceType::CosineExpanded) { | ||
if (!query_norms) { | ||
query_norms_ = raft::make_device_vector<T, IdxT>(res, n_queries); | ||
raft::linalg::rowNorm((T*)(query_norms_->data_handle()), | ||
queries.data_handle(), | ||
dim, | ||
n_queries, | ||
raft::linalg::L2Norm, | ||
true, | ||
stream, | ||
raft::sqrt_op{}); | ||
} | ||
} else { | ||
if (!query_norms) { | ||
query_norms_ = raft::make_device_vector<T, IdxT>(res, n_queries); | ||
raft::linalg::rowNorm((T*)(query_norms_->data_handle()), | ||
queries.data_handle(), | ||
dim, | ||
n_queries, | ||
raft::linalg::L2Norm, | ||
true, | ||
stream, | ||
raft::identity_op{}); | ||
} | ||
} | ||
cuvs::neighbors::detail::epilogue_on_csr( | ||
res, | ||
csr.get_elements().data(), | ||
compressed_csr_view.get_nnz(), | ||
rows.data(), | ||
compressed_csr_view.get_indices().data(), | ||
query_norms ? query_norms->data_handle() : query_norms_->data_handle(), | ||
idx.norms().data_handle(), | ||
metric); | ||
} | ||
|
||
// select k | ||
auto const_csr_view = raft::make_device_csr_matrix_view<const T, IdxT, IdxT, IdxT>( | ||
csr.get_elements().data(), compressed_csr_view); | ||
std::optional<raft::device_vector_view<const IdxT, IdxT>> no_opt = std::nullopt; | ||
bool select_min = cuvs::distance::is_min_close(metric); | ||
raft::sparse::matrix::select_k( | ||
res, const_csr_view, no_opt, distances, neighbors, select_min, true); | ||
} | ||
|
||
return; | ||
} | ||
|
||
template <typename T> | ||
cuvs::neighbors::brute_force::index<T> build( | ||
raft::resources const& res, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bitset header is included here because the bitset struct prototype is being defined here. I don't see anything being defined for the bitmap here, though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mainly for reducing the number of the include files in the
cuvs/neighbors/brute_force.hpp
and other files needcuvs::core::bitmap_view
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should be including what we use and not relying on transitivity here. Ultimately, this creates a lot of confusion for future developers. If bitmap is being explicitly required in parts of the public APIs, it should be getting included there and not here. Also, why aren't we including the bitmap here, defining it here, and compiling it in src/ like bitset is doing?