Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] support of prefiltered brute force #2294

Merged
merged 29 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d85fcf0
[FEA] support of prefiltered brute force based on cuSparseSDDMM
rhdong May 7, 2024
e3ef7bc
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 8, 2024
c7e4e7a
Improve the performance in classic scenarios by replace the cuSparseS…
rhdong May 8, 2024
ed07f60
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 13, 2024
5c5aa9b
optimize and remove used.
rhdong May 13, 2024
b4971c2
Update cpp/include/raft/sparse/distance/detail/utils.cuh
achirkin May 13, 2024
b1c1bb8
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 14, 2024
77ee4a6
Test cases adjustment
rhdong May 14, 2024
ea8420a
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 15, 2024
cc2b228
Merge SDDMM with customized kernel, optimize bitset count
rhdong May 15, 2024
68731b0
Merge branch 'branch-24.06' into rhdong/prefiltered-bf
cjnolet May 17, 2024
3a81f19
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 20, 2024
2684afe
Optimize by dense bfknn
rhdong May 20, 2024
57193a5
Merge branch 'rhdong/prefiltered-bf' of https://github.com/rhdong/raf…
rhdong May 20, 2024
8e1217c
Optimize the test cases
rhdong May 20, 2024
56f00cd
Merge branch 'rhdong/prefiltered-bf' of https://github.com/rhdong/raf…
rhdong May 20, 2024
71bd24b
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 21, 2024
96f4e83
Splitting(revert) the cuVS part
rhdong May 21, 2024
4f1aa17
Fix CI issue
rhdong May 21, 2024
b718673
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 22, 2024
9e24c5a
Move sparse distance API utils to cuvs and split the bitmap
rhdong May 22, 2024
18cb672
Optimize by review comments
rhdong May 23, 2024
e393af9
Merge branch 'branch-24.06' into rhdong/prefiltered-bf
cjnolet May 23, 2024
72c71f5
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 23, 2024
97a0e74
Remove the sparse select_k instantiations
rhdong May 23, 2024
de49e0c
Merge branch 'rhdong/prefiltered-bf' of https://github.com/rhdong/raf…
rhdong May 23, 2024
7d08443
Fix CI issue
rhdong May 23, 2024
18ba927
Merge branch 'branch-24.06' into rhdong/prefiltered-bf
cjnolet May 23, 2024
f38642f
Fix docs issue.
rhdong May 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,14 @@ if(RAFT_COMPILE_LIBRARY)
src/matrix/detail/select_k_float_int32.cu
src/matrix/detail/select_k_half_int64_t.cu
src/matrix/detail/select_k_half_uint32_t.cu
src/sparse/matrix/detail/select_k_half_uint32_t.cu
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're removing libraft.so in a future version, so I think we probably want to avoid adding these instantiations.

src/sparse/matrix/detail/select_k_double_int64_t.cu
src/sparse/matrix/detail/select_k_double_uint32_t.cu
src/sparse/matrix/detail/select_k_float_int64_t.cu
src/sparse/matrix/detail/select_k_float_uint32_t.cu
src/sparse/matrix/detail/select_k_float_int32.cu
src/sparse/matrix/detail/select_k_half_int64_t.cu
src/sparse/matrix/detail/select_k_half_uint32_t.cu
src/neighbors/ball_cover.cu
src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu
src/neighbors/brute_force_knn_int64_t_float_int64_t.cu
Expand Down
1 change: 1 addition & 0 deletions cpp/bench/prims/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ if(BUILD_PRIMS_BENCH)
NAME
NEIGHBORS_BENCH
PATH
bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu
bench/prims/neighbors/knn/brute_force_float_int64_t.cu
bench/prims/neighbors/knn/brute_force_float_uint32_t.cu
bench/prims/neighbors/knn/cagra_float_uint32_t.cu
Expand Down
134 changes: 128 additions & 6 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <raft/core/bitset.cuh>
#include <raft/core/resource/device_id.hpp>
#include <raft/neighbors/brute_force.cuh>
#include <raft/neighbors/ivf_flat.cuh>
#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/sample_filter.cuh>
Expand All @@ -36,7 +37,10 @@

#include <thrust/sequence.h>

#include <algorithm>
#include <optional>
#include <random>
#include <vector>

namespace raft::bench::spatial {

Expand All @@ -51,12 +55,24 @@ struct params {
size_t k;
/** Ratio of removed indices. */
double removed_ratio;
/** Distance Type. */
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded;
};

/**
 * @brief Write a benchmark parameter set to a stream as a '#'-separated label.
 *
 * The label consists of n_samples, n_dims, n_queries and k, followed by either the removed
 * ratio (when it is positive) or the literal "[No filtered]", followed by the metric name.
 * Each field is written exactly once.
 *
 * @param os stream to write to
 * @param p  parameter set to describe
 * @return   the stream passed in, so calls can be chained
 */
inline auto operator<<(std::ostream& os, const params& p) -> std::ostream&
{
  os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k;
  if (p.removed_ratio > 0.0) {
    os << "#" << p.removed_ratio;
  } else {
    os << "#"
       << "[No filtered]";
  }
  switch (p.metric) {
    case raft::distance::DistanceType::InnerProduct: os << "#InnerProduct"; break;
    case raft::distance::DistanceType::L2Expanded: os << "#L2Expanded"; break;
    // Only the two metrics above have a label; anything else prints a reminder instead.
    default: os << "UNKNOWN DistanceType, please add one case here."; break;
  }
  return os;
}

Expand Down Expand Up @@ -149,7 +165,7 @@ struct ivf_flat_knn {
/**
 * @brief Build the IVF-Flat index searched by this benchmark.
 *
 * @param handle RAFT resources forwarded to ivf_flat::build
 * @param ps     benchmark parameters; a copy is stored in the `ps` member
 * @param data   dataset pointer forwarded to ivf_flat::build together with
 *               ps.n_samples and ps.n_dims
 */
ivf_flat_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps)
{
  index_params.n_lists = 4096;
  // The metric comes from the benchmark parameters rather than being fixed here.
  index_params.metric = ps.metric;
  index.emplace(raft::neighbors::ivf_flat::build(
    handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims)));
}
Expand Down Expand Up @@ -184,7 +200,7 @@ struct ivf_pq_knn {
/**
 * @brief Build the IVF-PQ index searched by this benchmark.
 *
 * @param handle RAFT resources forwarded to ivf_pq::build
 * @param ps     benchmark parameters; a copy is stored in the `ps` member
 * @param data   dataset pointer, wrapped in a (ps.n_samples x ps.n_dims) device matrix view
 */
ivf_pq_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps)
{
  index_params.n_lists = 4096;
  // The metric comes from the benchmark parameters rather than being fixed here.
  index_params.metric = ps.metric;
  auto data_view = raft::make_device_matrix_view<const ValT, IdxT>(data, ps.n_samples, ps.n_dims);
  index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view));
}
Expand Down Expand Up @@ -236,6 +252,88 @@ struct brute_force_knn {
}
};

/**
 * @brief Fill a bitmap with pseudo-random bits, one bitmap word per thread.
 *
 * Every thread seeds its own curand state with (seed, word index) and draws one random
 * number per bit of its word. A bit is set when the draw, reduced modulo 10000, is below
 * 10000 * sparsity and the bit's global position is below total_bits.
 *
 * @tparam IdxT     integer type of the word count
 * @tparam bitmap_t word type of the bitmap
 * @param data       output words; data[idx] is written for every thread index idx < N,
 *                   so N must not exceed the number of words allocated behind data
 * @param N          number of words to write
 * @param sparsity   per-bit threshold fraction (resolution 1/10000) for setting a bit
 * @param total_bits bits whose global position is >= total_bits are left cleared
 * @param seed       curand seed shared by all threads
 */
template <typename IdxT, typename bitmap_t = std::uint32_t>
RAFT_KERNEL initialize_random_bits(
  bitmap_t* data, IdxT N, float sparsity, size_t total_bits, unsigned long seed)
{
  // Compute the word index in 64 bits: blockIdx.x * blockDim.x can exceed the range of int.
  const size_t idx = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
  if (N <= 0 || idx >= size_t(N)) return;

  curandState state;
  curand_init(seed, idx, 0, &state);

  constexpr int bits_per_word = int(sizeof(bitmap_t) * 8);
  // Loop-invariant threshold, hoisted out of the per-bit loop.
  const int threshold    = int(10000 * sparsity);
  const size_t first_bit = idx * bits_per_word;

  bitmap_t value = 0;
  for (int i = 0; i < bits_per_word; i++) {
    // One draw per bit, taken unconditionally so the random sequence does not depend on
    // which bits end up being set.
    const int rnd = curand(&state) % 10000;

    if (rnd < threshold && (first_bit + i < total_bits)) {
      // Shift a bitmap_t-wide one: `1u << i` is undefined for i >= 32 with 64-bit words.
      value |= static_cast<bitmap_t>(bitmap_t{1} << i);
    }
  }
  data[idx] = value;
}

/**
 * @brief Benchmark fixture for brute-force kNN with an optional per-query bitmap filter.
 *
 * Construction builds a brute-force index over the dataset and fills a bitmap of
 * n_queries * n_samples bits with random bits, each set with probability of roughly
 * (1 - removed_ratio). search() passes that bitmap as the filter when removed_ratio > 0
 * and otherwise runs an unfiltered search.
 *
 * NOTE(review): a set bit presumably marks a sample as kept for that query, despite the
 * member being named removed_indices_bitset_ - confirm against search_with_filtering.
 *
 * @tparam ValT element type of the dataset, queries and output distances
 * @tparam IdxT index type of the views and output neighbors
 */
template <typename ValT, typename IdxT>
struct brute_force_filter_knn {
  using dist_t   = float;
  using bitmap_t = std::uint32_t;

  std::optional<raft::neighbors::brute_force::index<ValT>> index;
  raft::neighbors::brute_force::index_params index_params;
  raft::neighbors::brute_force::search_params search_params;
  // Storage for the (n_queries x n_samples)-bit filter bitmap.
  raft::core::bitset<bitmap_t, IdxT> removed_indices_bitset_;
  params ps;

  /**
   * @brief Build the index and randomly initialise the filter bitmap.
   *
   * @param handle RAFT resources providing the CUDA stream
   * @param ps     benchmark parameters; a copy is stored in the `ps` member
   * @param data   dataset pointer, wrapped in a (ps.n_samples x ps.n_dims) device matrix view
   */
  brute_force_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data)
    // Initializers are listed in member declaration order; `ps` here is the parameter.
    : removed_indices_bitset_(handle, ps.n_samples * ps.n_queries), ps(ps)
  {
    auto stream         = resource::get_cuda_stream(handle);
    index_params.metric = ps.metric;

    auto data_view = raft::make_device_matrix_view<const ValT, IdxT>(data, ps.n_samples, ps.n_dims);
    index.emplace(raft::neighbors::brute_force::build(handle, index_params, data_view));

    // Number of bitmap_t words needed to hold n_samples * n_queries bits.
    IdxT element = raft::ceildiv(IdxT(ps.n_samples * ps.n_queries), IdxT(sizeof(bitmap_t) * 8));

    size_t threadsPerBlock = 256;
    size_t numBlocks       = (element + threadsPerBlock - 1) / threadsPerBlock;
    unsigned long seed     = 1234;
    // The kernel writes one word per thread, so its bound must be the word count. Using a
    // larger bound would let the padding threads of the last block write past the bitmap.
    initialize_random_bits<<<numBlocks, threadsPerBlock, 0, stream>>>(
      removed_indices_bitset_.data(),
      element,
      float(1.0 - ps.removed_ratio),
      ps.n_samples * ps.n_queries,
      seed);

    resource::sync_stream(handle);
  }

  /**
   * @brief Run one kNN search, filtered when ps.removed_ratio > 0.
   *
   * @param handle       RAFT resources
   * @param search_items query pointer, viewed as (ps.n_queries x ps.n_dims)
   * @param out_dists    output distances, viewed as row-major (ps.n_queries x ps.k)
   * @param out_idxs     output neighbor indices, viewed as row-major (ps.n_queries x ps.k)
   */
  void search(const raft::device_resources& handle,
              const ValT* search_items,
              ValT* out_dists,
              IdxT* out_idxs)
  {
    auto queries_view =
      raft::make_device_matrix_view<const ValT, IdxT>(search_items, ps.n_queries, ps.n_dims);
    auto neighbors_view =
      raft::make_device_matrix_view<IdxT, IdxT, raft::row_major>(out_idxs, ps.n_queries, ps.k);
    auto distance_view =
      raft::make_device_matrix_view<ValT, IdxT, raft::row_major>(out_dists, ps.n_queries, ps.k);

    if (ps.removed_ratio > 0) {
      // View the bitset storage as a bitmap with one row per query and one column per sample.
      auto filter =
        raft::core::bitmap_view(static_cast<const bitmap_t*>(removed_indices_bitset_.data()),
                                IdxT(ps.n_queries),
                                IdxT(ps.n_samples));

      raft::neighbors::brute_force::search_with_filtering(
        handle, *index, queries_view, filter, neighbors_view, distance_view);
    } else {
      raft::neighbors::brute_force::search(
        handle, search_params, *index, queries_view, neighbors_view, distance_view);
    }
  }
};

template <typename ValT, typename IdxT>
struct ivf_flat_filter_knn {
using dist_t = float;
Expand All @@ -250,7 +348,7 @@ struct ivf_flat_filter_knn {
: ps(ps), removed_indices_bitset_(handle, ps.n_samples)
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;
index_params.metric = ps.metric;
index.emplace(raft::neighbors::ivf_flat::build(
handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims)));
auto removed_indices =
Expand Down Expand Up @@ -298,7 +396,7 @@ struct ivf_pq_filter_knn {
: ps(ps), removed_indices_bitset_(handle, ps.n_samples)
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;
index_params.metric = ps.metric;
auto data_view = raft::make_device_matrix_view<const ValT, IdxT>(data, ps.n_samples, ps.n_dims);
index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view));
auto removed_indices =
Expand Down Expand Up @@ -500,10 +598,34 @@ const std::vector<params> kInputsFilter =
{size_t(255)}, // k
{0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio
);

// Parameter sets for the filtered brute-force benchmark, built by itertools::product from
// the per-field value lists below (presumably every combination of them - confirm against
// raft::util::itertools::product). The final list is the distance metric.
const std::vector<params> kInputsBruteForceFilter = raft::util::itertools::product<params>(
{size_t(10000000), size_t(1 * 1024 * 1024)}, // n_samples
{size_t(256), size_t(2048)}, // n_dim
{size_t(1), size_t(10), size_t(100), size_t(1000)}, // n_queries
{size_t(256)}, // k
{0.0, 0.8, 0.99}, // removed_ratio
{raft::distance::DistanceType::InnerProduct});

// Additional parameter sets for the filtered brute-force benchmark: a single dataset size
// with a finer sweep over n_queries and intermediate removed ratios. Built by
// itertools::product from the per-field value lists below (presumably every combination of
// them - confirm against raft::util::itertools::product). The final list is the metric.
const std::vector<params> kInputsBruteForceFilterExtra =
raft::util::itertools::product<params>({size_t(1024 * 1024)}, // n_samples
{size_t(256), size_t(768)}, // n_dim
{size_t(10),
size_t(20),
size_t(40),
size_t(60),
size_t(80),
size_t(100),
size_t(300)}, // n_queries
{size_t(255)}, // k
{0.3, 0.4, 0.9}, // removed_ratio
{raft::distance::DistanceType::InnerProduct});

// Lists of TransferStrategy values that a benchmark registration can sweep over.
inline const std::vector<TransferStrategy> kAllStrategies{
TransferStrategy::NO_COPY, TransferStrategy::MAP_PINNED, TransferStrategy::MANAGED};
inline const std::vector<TransferStrategy> kNoCopyOnly{TransferStrategy::NO_COPY};

// Lists of Scope values that a benchmark registration can sweep over.
inline const std::vector<Scope> kScopeOnlySearch{Scope::SEARCH};
inline const std::vector<Scope> kScopeFull{Scope::BUILD_SEARCH};
inline const std::vector<Scope> kAllScopes{Scope::BUILD_SEARCH, Scope::SEARCH, Scope::BUILD};

Expand Down
31 changes: 31 additions & 0 deletions cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter
#include "../knn.cuh"

namespace raft::bench::spatial {
// Both invocations pass float / int64_t and the brute_force_filter_knn fixture with the
// kNoCopyOnly strategy list and the kScopeOnlySearch scope list; they differ only in the
// parameter list. KNN_REGISTER is defined elsewhere - presumably it registers one benchmark
// per combination; confirm against its definition.

// Main sweep: kInputsBruteForceFilter.
KNN_REGISTER(
float, int64_t, brute_force_filter_knn, kInputsBruteForceFilter, kNoCopyOnly, kScopeOnlySearch);

// Extra sweep: kInputsBruteForceFilterExtra.
KNN_REGISTER(float,
int64_t,
brute_force_filter_knn,
kInputsBruteForceFilterExtra,
kNoCopyOnly,
kScopeOnlySearch);

} // namespace raft::bench::spatial
6 changes: 4 additions & 2 deletions cpp/include/raft/core/bitmap.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resources.hpp>

#include <type_traits>

namespace raft::core {
/**
* @defgroup bitmap Bitmap
Expand All @@ -39,8 +41,8 @@ namespace raft::core {
*/
template <typename bitmap_t = uint32_t, typename index_t = uint32_t>
struct bitmap_view : public bitset_view<bitmap_t, index_t> {
static_assert((std::is_same<bitmap_t, uint32_t>::value ||
std::is_same<bitmap_t, uint64_t>::value),
static_assert((std::is_same<typename std::remove_const<bitmap_t>::type, uint32_t>::value ||
std::is_same<typename std::remove_const<bitmap_t>::type, uint64_t>::value),
"The bitmap_t must be uint32_t or uint64_t.");
/**
* @brief Create a bitmap view from a device raw pointer.
Expand Down
42 changes: 10 additions & 32 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#pragma once

#include <raft/core/bitset.hpp>
#include <raft/core/detail/mdspan_util.cuh> // native_popc
#include <raft/core/detail/popc.cuh>
#include <raft/core/device_container_policy.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/thrust_policy.hpp>
Expand Down Expand Up @@ -60,6 +60,12 @@ _RAFT_HOST_DEVICE void bitset_view<bitset_t, index_t>::set(const index_t sample_
}
}

template <typename bitset_t, typename index_t>
_RAFT_HOST_DEVICE inline index_t bitset_view<bitset_t, index_t>::n_elements() const
{
  // Storage words needed for bitset_len_ bits: bitset_len_ / bitset_element_size, rounded up.
  const index_t word_count = raft::ceildiv(bitset_len_, bitset_element_size);
  return word_count;
}

template <typename bitset_t, typename index_t>
bitset<bitset_t, index_t>::bitset(const raft::resources& res,
raft::device_vector_view<const index_t, index_t> mask_index,
Expand Down Expand Up @@ -161,37 +167,9 @@ template <typename bitset_t, typename index_t>
void bitset<bitset_t, index_t>::count(const raft::resources& res,
                                      raft::device_scalar_view<index_t> count_gpu_scalar)
{
  // Count the set bits of this bitset into count_gpu_scalar.
  //
  // The work is delegated to raft::detail::popc, which receives the storage words, the
  // logical length in bits (bitset_len_) and the output scalar. Presumably popc ignores the
  // padding bits of the last word beyond bitset_len_ - confirm against raft/core/detail/popc.cuh.
  // The count is computed once, by popc only.
  auto values =
    raft::make_device_vector_view<const bitset_t, index_t>(bitset_.data(), n_elements());
  raft::detail::popc(res, values, bitset_len_, count_gpu_scalar);
}

} // end namespace raft::core
Loading
Loading