Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] support of prefiltered brute force #146

Merged
merged 16 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions cpp/include/cuvs/core/bitmap.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/bitmap.hpp>

namespace cuvs::core {
/* To use bitmap functions containing CUDA code, include <raft/core/bitmap.cuh> */

template <typename bitmap_t, typename index_t>
using bitmap_view = raft::core::bitmap_view<bitmap_t, index_t>;

} // end namespace cuvs::core
5 changes: 4 additions & 1 deletion cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,15 @@ auto build(raft::resources const& handle,
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] sample_filter a optional device bitmap filter function that greenlights samples for a
* given
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::index<float>& index,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances);
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
std::optional<cuvs::core::bitmap_view<const uint32_t, int64_t>> sample_filter);
/**
* @}
*/
Expand Down
1 change: 1 addition & 0 deletions cpp/include/cuvs/neighbors/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <raft/util/cudart_utils.hpp> // get_device_for_address
#include <raft/util/integer_utils.hpp> // rounding up

#include <cuvs/core/bitmap.hpp>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bitset header is included here because the bitset struct prototype is being defined here. I don't see anything being defined for the bitmap here, though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mainly for reducing the number of the include files in the cuvs/neighbors/brute_force.hpp and other files need cuvs::core::bitmap_view

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be including what we use and not relying on transitivity here. Ultimately, this creates a lot of confusion for future developers. If bitmap is being explicitly required in parts of the public APIs, it should be getting included there and not here. Also, why aren't we including the bitmap here, defining it here, and compiling it in src/ like bitset is doing?

#include <cuvs/core/bitset.hpp>
#include <raft/core/detail/macros.hpp>

Expand Down
46 changes: 27 additions & 19 deletions cpp/src/neighbors/brute_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include "./detail/knn_brute_force.cuh"

#include <cuvs/neighbors/brute_force.hpp>

#include <raft/core/copy.hpp>
Expand Down Expand Up @@ -84,25 +85,32 @@ void index<T>::update_dataset(raft::resources const& res,
dataset_view_ = raft::make_const_mdspan(dataset_.view());
}

#define CUVS_INST_BFKNN(T) \
auto build(raft::resources const& res, \
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset, \
cuvs::distance::DistanceType metric, \
T metric_arg) \
->cuvs::neighbors::brute_force::index<T> \
{ \
return detail::build<T>(res, dataset, metric, metric_arg); \
} \
\
void search(raft::resources const& res, \
const cuvs::neighbors::brute_force::index<T>& idx, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<T, int64_t, raft::row_major> distances) \
{ \
detail::brute_force_search<T, int64_t>(res, idx, queries, neighbors, distances); \
} \
\
#define CUVS_INST_BFKNN(T) \
auto build(raft::resources const& res, \
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset, \
cuvs::distance::DistanceType metric, \
T metric_arg) \
->cuvs::neighbors::brute_force::index<T> \
{ \
return detail::build<T>(res, dataset, metric, metric_arg); \
} \
\
void search( \
raft::resources const& res, \
const cuvs::neighbors::brute_force::index<T>& idx, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<T, int64_t, raft::row_major> distances, \
std::optional<cuvs::core::bitmap_view<const uint32_t, int64_t>> sample_filter = std::nullopt) \
{ \
if (!sample_filter.has_value()) { \
detail::brute_force_search<T, int64_t>(res, idx, queries, neighbors, distances); \
} else { \
detail::brute_force_search_filtered<T, int64_t>( \
res, idx, queries, *sample_filter, neighbors, distances); \
} \
} \
\
template struct cuvs::neighbors::brute_force::index<T>;

CUVS_INST_BFKNN(float);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/brute_force_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void _search(cuvsResources_t res,
auto distances_mds = cuvs::core::from_dlpack<distances_mdspan_type>(distances_tensor);

cuvs::neighbors::brute_force::search(
*res_ptr, *index_ptr, queries_mds, neighbors_mds, distances_mds);
*res_ptr, *index_ptr, queries_mds, neighbors_mds, distances_mds, std::nullopt);
}

} // namespace
Expand Down
202 changes: 201 additions & 1 deletion cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#pragma once

#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/brute_force.hpp>

Expand All @@ -23,16 +24,26 @@
#include "./fused_l2_knn.cuh"
#include "./haversine_distance.cuh"
#include "./knn_merge_parts.cuh"
#include "./knn_utils.cuh"

rhdong marked this conversation as resolved.
Show resolved Hide resolved
#include <raft/core/bitmap.cuh>
#include <raft/core/detail/popc.cuh>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/cuda_stream_pool.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/map.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/transpose.cuh>
#include <raft/matrix/init.cuh>
#include <raft/matrix/select_k.cuh>
#include <raft/sparse/convert/coo.cuh>
#include <raft/sparse/convert/csr.cuh>
#include <raft/sparse/distance/detail/utils.cuh>
#include <raft/sparse/linalg/sddmm.hpp>
#include <raft/sparse/matrix/select_k.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

Expand Down Expand Up @@ -65,7 +76,8 @@ void tiled_brute_force_knn(const raft::resources& handle,
size_t max_row_tile_size = 0,
size_t max_col_tile_size = 0,
const ElementType* precomputed_index_norms = nullptr,
const ElementType* precomputed_search_norms = nullptr)
const ElementType* precomputed_search_norms = nullptr,
const uint32_t* filter_bitmap = nullptr)
{
// Figure out the number of rows/cols to tile for
size_t tile_rows = 0;
Expand Down Expand Up @@ -214,6 +226,27 @@ void tiled_brute_force_knn(const raft::resources& handle,
});
}

if (filter_bitmap != nullptr) {
auto distances_ptr = temp_distances.data();
auto count = thrust::make_counting_iterator<IndexType>(0);
ElementType masked_distance = select_min ? std::numeric_limits<ElementType>::infinity()
: std::numeric_limits<ElementType>::lowest();
thrust::for_each(raft::resource::get_thrust_policy(handle),
count,
count + current_query_size * current_centroid_size,
[=] __device__(IndexType idx) {
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);
IndexType g_idx = row * n + col;
IndexType item_idx = (g_idx) >> 5;
uint32_t bit_idx = (g_idx)&31;
uint32_t filter = filter_bitmap[item_idx];
if ((filter & (uint32_t(1) << bit_idx)) == 0) {
distances_ptr[idx] = masked_distance;
}
});
}

raft::matrix::select_k<ElementType, IndexType>(
handle,
raft::make_device_matrix_view<const ElementType, int64_t, raft::row_major>(
Expand Down Expand Up @@ -519,6 +552,173 @@ void brute_force_search(
query_norms ? query_norms->data_handle() : nullptr);
}

template <typename T, typename IdxT, typename BitmapT>
void brute_force_search_filtered(
raft::resources const& res,
const cuvs::neighbors::brute_force::index<T>& idx,
raft::device_matrix_view<const T, IdxT, raft::row_major> queries,
cuvs::core::bitmap_view<const BitmapT, IdxT> filter,
raft::device_matrix_view<IdxT, IdxT, raft::row_major> neighbors,
raft::device_matrix_view<T, IdxT, raft::row_major> distances,
std::optional<raft::device_vector_view<const T, IdxT>> query_norms = std::nullopt)
{
auto metric = idx.metric();

RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs");
RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1),
"Number of columns in queries must match brute force index");
RAFT_EXPECTS(metric == cuvs::distance::DistanceType::InnerProduct ||
metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
metric == cuvs::distance::DistanceType::CosineExpanded,
"Only Euclidean, IP, and Cosine are supported!");

RAFT_EXPECTS(idx.has_norms() || !(metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
metric == cuvs::distance::DistanceType::CosineExpanded),
"Index must has norms when using Euclidean, IP, and Cosine!");

IdxT n_queries = queries.extent(0);
IdxT n_dataset = idx.dataset().extent(0);
IdxT dim = idx.dataset().extent(1);
IdxT k = neighbors.extent(1);

auto stream = raft::resource::get_cuda_stream(res);

// calc nnz
IdxT nnz_h = 0;
rmm::device_scalar<IdxT> nnz(0, stream);
auto nnz_view = raft::make_device_scalar_view<IdxT>(nnz.data());
auto filter_view =
raft::make_device_vector_view<const BitmapT, IdxT>(filter.data(), filter.n_elements());

// TODO(rhdong): Need to switch to the public API,
// with the issue: https://github.com/rapidsai/cuvs/issues/158
raft::detail::popc(res, filter_view, n_queries * n_dataset, nnz_view);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should never be calling into raft detail APIs in cuvs.

Copy link
Member Author

@rhdong rhdong May 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally agree! I've noticed this, but as you know, just for so tight schedule to split the raft internal calling to cuVS(I suppose public API needs more regular test/benchmark code, at that time, I had to make it in first). I will fix it ASAP: #158

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for creating an issue for this. Can you please reference a link to the github issue in a comment on this line of code so we don't lose sight of it? Then we can merge this in.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure!

raft::copy(&nnz_h, nnz.data(), 1, stream);

raft::resource::sync_stream(res, stream);
float sparsity = (1.0f * nnz_h / (1.0f * n_queries * n_dataset));

if (sparsity > 0.01f) {
raft::resources stream_pool_handle(res);
raft::resource::set_cuda_stream(stream_pool_handle, stream);
auto idx_norm = idx.has_norms() ? const_cast<T*>(idx.norms().data_handle()) : nullptr;

tiled_brute_force_knn<T, IdxT>(stream_pool_handle,
queries.data_handle(),
idx.dataset().data_handle(),
n_queries,
n_dataset,
dim,
k,
distances.data_handle(),
neighbors.data_handle(),
metric,
2.0,
0,
0,
idx_norm,
nullptr,
filter.data());
} else {
auto csr = raft::make_device_csr_matrix<T, IdxT>(res, n_queries, n_dataset, nnz_h);

// fill csr
raft::sparse::convert::bitmap_to_csr(res, filter, csr);

// create filter csr view
auto compressed_csr_view = csr.structure_view();
rmm::device_uvector<IdxT> rows(compressed_csr_view.get_nnz(), stream);
raft::sparse::convert::csr_to_coo(compressed_csr_view.get_indptr().data(),
compressed_csr_view.get_n_rows(),
rows.data(),
compressed_csr_view.get_nnz(),
stream);
if (n_queries > 10) {
auto csr_view = raft::make_device_csr_matrix_view<T, IdxT, IdxT, IdxT>(
csr.get_elements().data(), compressed_csr_view);

// create dataset view
auto dataset_view = raft::make_device_matrix_view<const T, IdxT, raft::col_major>(
idx.dataset().data_handle(), dim, n_dataset);

// calc dot
T alpha = static_cast<T>(1.0f);
T beta = static_cast<T>(0.0f);
raft::sparse::linalg::sddmm(res,
queries,
dataset_view,
csr_view,
raft::linalg::Operation::NON_TRANSPOSE,
raft::linalg::Operation::NON_TRANSPOSE,
raft::make_host_scalar_view<T>(&alpha),
raft::make_host_scalar_view<T>(&beta));
} else {
raft::sparse::distance::detail::faster_dot_on_csr(res,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should never be calling into RAFT detail APIs in cuvs. If you think about the reason why we separate public from details APIs, it's so that each library can put forth a contract that they make guarantees not to break across versions while still hosting internal implementation-specific code that makes no such guarantees. If we're invoking detail APIs from RAFT within cuVS then we're going to end up either 1) making breakig changes those detail APIs and not realizing we just broke cuVS downstream, or 2) never being able to make changes to detail APIs because they need to maintain the same guarantees as public APIs.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be fixed, but RAFT is in code freeze. What I think you should do to fix this is to create a GIthub issue in cuVS to describe the issue, then add comments in all the places in this PR that you are calling into RAFT detail APIs and reference the Github issue there.

After you do that and address the other PR review feedback, we can merge this PR into 24.06 but you'll need to open up a RAFT PR to immediately expose all of these APIs publicly in RAFT and then open up a PR to invoke the public APIs in cuVS (target both to 24.08 since it's too late to make raft changes for 24.06).

Of course, if you've already exposed these APIs publicly in RAFT and calling into detail was just an oversight, then they could simply be fixed in this PR. The general rule of thumb for detail APIs even within a library is that you should expose any functions publicly if you are going to invoke them across namespace separate namespace boundaries. Obvously that's more simple to do for a header-only library like RAFT than it is for a library where all public APIs are strictly compiled like cuVS, but we should be striving to reduce the amount of inter-namespace calls into detail APIs as possible so we can maintain our public API contracts.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for creating an issue to fix this. Let's get this C++ side merged in. Can you please reference the github issue in a comment just abov ethis line of code so that we don't lose sight of it? Then we can get this merged in.

csr.get_elements().data(),
compressed_csr_view.get_nnz(),
compressed_csr_view.get_indptr().data(),
compressed_csr_view.get_indices().data(),
queries.data_handle(),
idx.dataset().data_handle(),
compressed_csr_view.get_n_rows(),
dim);
}

// post process
std::optional<raft::device_vector<T, IdxT>> query_norms_;
if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
metric == cuvs::distance::DistanceType::CosineExpanded) {
if (metric == cuvs::distance::DistanceType::CosineExpanded) {
if (!query_norms) {
query_norms_ = raft::make_device_vector<T, IdxT>(res, n_queries);
raft::linalg::rowNorm((T*)(query_norms_->data_handle()),
queries.data_handle(),
dim,
n_queries,
raft::linalg::L2Norm,
true,
stream,
raft::sqrt_op{});
}
} else {
if (!query_norms) {
query_norms_ = raft::make_device_vector<T, IdxT>(res, n_queries);
raft::linalg::rowNorm((T*)(query_norms_->data_handle()),
queries.data_handle(),
dim,
n_queries,
raft::linalg::L2Norm,
true,
stream,
raft::identity_op{});
}
}
cuvs::neighbors::detail::epilogue_on_csr(
res,
csr.get_elements().data(),
compressed_csr_view.get_nnz(),
rows.data(),
compressed_csr_view.get_indices().data(),
query_norms ? query_norms->data_handle() : query_norms_->data_handle(),
idx.norms().data_handle(),
metric);
}

// select k
auto const_csr_view = raft::make_device_csr_matrix_view<const T, IdxT, IdxT, IdxT>(
csr.get_elements().data(), compressed_csr_view);
std::optional<raft::device_vector_view<const IdxT, IdxT>> no_opt = std::nullopt;
bool select_min = cuvs::distance::is_min_close(metric);
raft::sparse::matrix::select_k(
res, const_csr_view, no_opt, distances, neighbors, select_min, true);
}

return;
}

template <typename T>
cuvs::neighbors::brute_force::index<T> build(
raft::resources const& res,
Expand Down
Loading
Loading