Skip to content

Commit

Permalink
[FEA] support of prefiltered brute force based on cuSparseSDDMM
Browse files Browse the repository at this point in the history
- This PR implements one part of the feature tracked in #1969.
- Add the 'search_with_filtering' API for brute-force search.
Authors:
  - James Rong (https://github.com/rhdong)
  • Loading branch information
rhdong committed May 7, 2024
1 parent ef28628 commit d85fcf0
Show file tree
Hide file tree
Showing 20 changed files with 976 additions and 46 deletions.
8 changes: 8 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,14 @@ if(RAFT_COMPILE_LIBRARY)
src/matrix/detail/select_k_float_int32.cu
src/matrix/detail/select_k_half_int64_t.cu
src/matrix/detail/select_k_half_uint32_t.cu
src/sparse/matrix/detail/select_k_half_uint32_t.cu
src/sparse/matrix/detail/select_k_double_int64_t.cu
src/sparse/matrix/detail/select_k_double_uint32_t.cu
src/sparse/matrix/detail/select_k_float_int64_t.cu
src/sparse/matrix/detail/select_k_float_uint32_t.cu
src/sparse/matrix/detail/select_k_float_int32.cu
src/sparse/matrix/detail/select_k_half_int64_t.cu
src/sparse/matrix/detail/select_k_half_uint32_t.cu
src/neighbors/ball_cover.cu
src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu
src/neighbors/brute_force_knn_int64_t_float_int64_t.cu
Expand Down
1 change: 1 addition & 0 deletions cpp/bench/prims/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ if(BUILD_PRIMS_BENCH)
NAME
NEIGHBORS_BENCH
PATH
bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu
bench/prims/neighbors/knn/brute_force_float_int64_t.cu
bench/prims/neighbors/knn/brute_force_float_uint32_t.cu
bench/prims/neighbors/knn/cagra_float_uint32_t.cu
Expand Down
111 changes: 107 additions & 4 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <raft/core/bitset.cuh>
#include <raft/core/resource/device_id.hpp>
#include <raft/neighbors/brute_force.cuh>
#include <raft/neighbors/ivf_flat.cuh>
#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/sample_filter.cuh>
Expand All @@ -36,7 +37,10 @@

#include <thrust/sequence.h>

#include <algorithm>
#include <optional>
#include <random>
#include <vector>

namespace raft::bench::spatial {

Expand All @@ -51,12 +55,19 @@ struct params {
size_t k;
/** Ratio of removed indices. */
double removed_ratio;
/** Distance Type. */
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded;
};

inline auto operator<<(std::ostream& os, const params& p) -> std::ostream&
{
os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k << "#"
<< p.removed_ratio;
switch (p.metric) {
case raft::distance::DistanceType::InnerProduct: os << "#InnerProduct"; break;
case raft::distance::DistanceType::L2Expanded: os << "#L2Expanded"; break;
default: os << "UNKNOWN DistanceType, please add one case here.";
}
return os;
}

Expand Down Expand Up @@ -149,7 +160,7 @@ struct ivf_flat_knn {
/**
 * Build an IVF-Flat index over `data` (ps.n_samples rows of ps.n_dims values) with
 * 4096 lists and the distance metric requested in `ps`.
 */
ivf_flat_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps)
{
  index_params.n_lists = 4096;
  // The metric comes from the benchmark parameters. The former hard-coded L2Expanded
  // assignment was overwritten by this line and has been dropped.
  index_params.metric = ps.metric;
  index.emplace(raft::neighbors::ivf_flat::build(
    handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims)));
}
Expand Down Expand Up @@ -184,7 +195,7 @@ struct ivf_pq_knn {
/**
 * Build an IVF-PQ index over `data` (viewed as a ps.n_samples x ps.n_dims matrix) with
 * 4096 lists and the distance metric requested in `ps`.
 */
ivf_pq_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps)
{
  index_params.n_lists = 4096;
  // The metric comes from the benchmark parameters. The former hard-coded L2Expanded
  // assignment was overwritten by this line and has been dropped.
  index_params.metric = ps.metric;
  auto data_view = raft::make_device_matrix_view<const ValT, IdxT>(data, ps.n_samples, ps.n_dims);
  index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view));
}
Expand Down Expand Up @@ -236,6 +247,88 @@ struct brute_force_knn {
}
};

/**
 * @brief Fill a bitmap with pseudo-random bits, one bitmap word per thread.
 *
 * Every thread whose global index `idx` is below `N` initializes its own curand state
 * from (`seed`, `idx`) and draws one random number per bit of its word. A bit is set
 * when the draw modulo 10000 is below `10000 * sparsity` and the bit's global position
 * (`idx * bits_per_word + i`) is below `total_bits`; bits at or past `total_bits` are
 * left at 0.
 *
 * Launch layout: 1D grid of 1D blocks providing at least `N` threads in total. No shared
 * memory is used.
 *
 * @tparam IdxT     integer type of the word count
 * @tparam bitmap_t unsigned integer type of one bitmap word
 *
 * @param[out] data       device pointer to at least `N` bitmap words
 * @param[in]  N          number of bitmap words to write
 * @param[in]  sparsity   fraction of bits to set, applied at a resolution of 1/10000
 * @param[in]  total_bits number of valid bits in the bitmap
 * @param[in]  seed       curand seed shared by all threads
 */
template <typename IdxT, typename bitmap_t = std::uint32_t>
RAFT_KERNEL initialize_random_bits(
  bitmap_t* data, IdxT N, float sparsity, size_t total_bits, unsigned long seed)
{
  // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be evaluated in
  // 32-bit unsigned arithmetic.
  const size_t idx = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
  if (N <= IdxT{0} || idx >= size_t(N)) return;

  curandState state;
  curand_init(seed, idx, 0, &state);

  constexpr int bits_per_word = int(sizeof(bitmap_t) * 8);
  // Loop-invariant values hoisted out of the per-bit loop.
  const int threshold    = int(10000 * sparsity);
  const size_t first_bit = idx * bits_per_word;

  bitmap_t value = 0;
  for (int i = 0; i < bits_per_word; i++) {
    // A number is drawn for every bit position, including positions past total_bits.
    const int rnd = curand(&state) % 10000;

    if (rnd < threshold && (first_bit + i < total_bits)) {
      // Shift a bitmap_t-wide one: `1u << i` is undefined for i >= 32 with 64-bit words.
      value |= static_cast<bitmap_t>(bitmap_t{1} << i);
    }
  }
  data[idx] = value;
}

/**
 * @brief Benchmark wrapper around brute-force k-NN search with an optional bitmap filter.
 *
 * Construction builds a brute-force index over the dataset and fills a bitmap of
 * `n_queries * n_samples` bits with random values; each bit is set with a probability of
 * roughly `1 - removed_ratio`. `search` passes that bitmap, viewed as an
 * `n_queries x n_samples` bitmap_view, to `search_with_filtering` when `removed_ratio > 0`
 * and falls back to the plain brute-force search otherwise.
 *
 * @tparam ValT element type of the dataset, the queries and the output distances
 * @tparam IdxT index type of the output neighbors and of the bitmap
 */
template <typename ValT, typename IdxT>
struct brute_force_filter_knn {
  using dist_t   = float;
  using bitmap_t = std::uint32_t;

  std::optional<raft::neighbors::brute_force::index<ValT>> index;
  raft::neighbors::brute_force::index_params index_params;
  raft::neighbors::brute_force::search_params search_params;
  raft::core::bitset<bitmap_t, IdxT> removed_indices_bitset_;
  params ps;

  /**
   * @param[in] handle raft resources providing the CUDA stream
   * @param[in] ps     benchmark parameters (sizes, removed ratio, metric)
   * @param[in] data   device pointer to `ps.n_samples * ps.n_dims` dataset values
   */
  brute_force_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data)
    // Listed in declaration order (the order in which members are actually initialized).
    // `ps` inside the initializers names the constructor parameter, not the member.
    : removed_indices_bitset_(handle, ps.n_samples * ps.n_queries), ps(ps)
  {
    auto stream         = resource::get_cuda_stream(handle);
    index_params.metric = ps.metric;

    auto data_view = raft::make_device_matrix_view<const ValT, IdxT>(data, ps.n_samples, ps.n_dims);
    index.emplace(raft::neighbors::brute_force::build(handle, index_params, data_view));

    // Number of bitmap words needed to hold one bit per (query, sample) pair.
    const IdxT n_words =
      raft::ceildiv(IdxT(ps.n_samples * ps.n_queries), IdxT(sizeof(bitmap_t) * 8));

    // A zero-block launch is an invalid configuration, so skip it for an empty problem.
    if (n_words > 0) {
      const size_t threadsPerBlock = 256;
      const size_t numBlocks       = (n_words + threadsPerBlock - 1) / threadsPerBlock;
      const unsigned long seed     = 1234;
      // The kernel guard must be the word count that the grid was sized from: the grid is
      // rounded up to a multiple of the block size, and the surplus threads must not write.
      // The bitset's size() was passed here before; NOTE(review): size() presumably reports
      // bits rather than words (bitset::count distinguishes n_elements() from bitset_len_),
      // which would leave the surplus threads unguarded - confirm against the bitset API.
      initialize_random_bits<<<numBlocks, threadsPerBlock, 0, stream>>>(
        removed_indices_bitset_.data(),
        n_words,
        float(1.0 - ps.removed_ratio),
        ps.n_samples * ps.n_queries,
        seed);
    }

    // Make sure the bitmap is fully written before the benchmark starts searching.
    resource::sync_stream(handle);
  }

  /**
   * @brief Run one k-NN search for `ps.n_queries` queries.
   *
   * @param[in]  handle       raft resources
   * @param[in]  search_items device pointer to `ps.n_queries * ps.n_dims` query values
   * @param[out] out_dists    device pointer to `ps.n_queries * ps.k` distances
   * @param[out] out_idxs     device pointer to `ps.n_queries * ps.k` neighbor indices
   */
  void search(const raft::device_resources& handle,
              const ValT* search_items,
              ValT* out_dists,
              IdxT* out_idxs)
  {
    auto queries_view =
      raft::make_device_matrix_view<const ValT, IdxT>(search_items, ps.n_queries, ps.n_dims);
    auto neighbors_view =
      raft::make_device_matrix_view<IdxT, IdxT, raft::row_major>(out_idxs, ps.n_queries, ps.k);
    auto distance_view =
      raft::make_device_matrix_view<ValT, IdxT, raft::row_major>(out_dists, ps.n_queries, ps.k);

    if (ps.removed_ratio > 0) {
      // Reinterpret the flat bitset storage as an n_queries x n_samples bitmap.
      auto filter = raft::core::bitmap_view(
        (const bitmap_t*)removed_indices_bitset_.data(), IdxT(ps.n_queries), IdxT(ps.n_samples));

      raft::neighbors::brute_force::search_with_filtering(
        handle, *index, queries_view, filter, neighbors_view, distance_view);
    } else {
      // Nothing is removed: use the unfiltered search path.
      raft::neighbors::brute_force::search(
        handle, search_params, *index, queries_view, neighbors_view, distance_view);
    }
  }
};

template <typename ValT, typename IdxT>
struct ivf_flat_filter_knn {
using dist_t = float;
Expand All @@ -250,7 +343,7 @@ struct ivf_flat_filter_knn {
: ps(ps), removed_indices_bitset_(handle, ps.n_samples)
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;
index_params.metric = ps.metric;
index.emplace(raft::neighbors::ivf_flat::build(
handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims)));
auto removed_indices =
Expand Down Expand Up @@ -298,7 +391,7 @@ struct ivf_pq_filter_knn {
: ps(ps), removed_indices_bitset_(handle, ps.n_samples)
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;
index_params.metric = ps.metric;
auto data_view = raft::make_device_matrix_view<const ValT, IdxT>(data, ps.n_samples, ps.n_dims);
index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view));
auto removed_indices =
Expand Down Expand Up @@ -500,10 +593,20 @@ const std::vector<params> kInputsFilter =
{size_t(255)}, // k
{0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio
);

// Parameter sets for the filtered brute-force benchmark. Each braced list holds the
// candidate values of one params field, in the order n_samples, n_dims, n_queries, k,
// removed_ratio, metric; presumably itertools::product expands them into every
// combination (confirm against raft::util::itertools).
// A removed_ratio of 0.0 makes brute_force_filter_knn::search take its unfiltered branch.
const std::vector<params> kInputsBruteForceFilter = raft::util::itertools::product<params>(
{size_t(1000000)}, // n_samples
{size_t(128)}, // n_dim
{size_t(1000)}, // n_queries
{size_t(255)}, // k
{0.0, 0.8, 0.9}, // removed_ratio
// metric
{raft::distance::DistanceType::InnerProduct, raft::distance::DistanceType::L2Expanded});

// Sets of transfer strategies that a benchmark registration can iterate over:
// all three listed strategies, or NO_COPY only.
inline const std::vector<TransferStrategy> kAllStrategies{
TransferStrategy::NO_COPY, TransferStrategy::MAP_PINNED, TransferStrategy::MANAGED};
inline const std::vector<TransferStrategy> kNoCopyOnly{TransferStrategy::NO_COPY};

// Sets of benchmark scopes that a registration can iterate over: search only,
// build followed by search, or every listed scope.
inline const std::vector<Scope> kScopeOnlySearch{Scope::SEARCH};
inline const std::vector<Scope> kScopeFull{Scope::BUILD_SEARCH};
inline const std::vector<Scope> kAllScopes{Scope::BUILD_SEARCH, Scope::SEARCH, Scope::BUILD};

Expand Down
25 changes: 25 additions & 0 deletions cpp/bench/prims/neighbors/knn/brute_force_filter_float_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter
#include "../knn.cuh"

namespace raft::bench::spatial {

// Register the filtered brute-force benchmark for float values and int64_t indices,
// using the kInputsBruteForceFilter parameter sets, the NO_COPY-only strategy list and
// the search-only scope list.
// NOTE(review): the exact expansion of KNN_REGISTER is defined in knn.cuh and is not
// shown here.
KNN_REGISTER(
float, int64_t, brute_force_filter_knn, kInputsBruteForceFilter, kNoCopyOnly, kScopeOnlySearch);

} // namespace raft::bench::spatial
6 changes: 4 additions & 2 deletions cpp/include/raft/core/bitmap.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resources.hpp>

#include <type_traits>

namespace raft::core {
/**
* @defgroup bitmap Bitmap
Expand All @@ -39,8 +41,8 @@ namespace raft::core {
*/
template <typename bitmap_t = uint32_t, typename index_t = uint32_t>
struct bitmap_view : public bitset_view<bitmap_t, index_t> {
static_assert((std::is_same<bitmap_t, uint32_t>::value ||
std::is_same<bitmap_t, uint64_t>::value),
static_assert((std::is_same<typename std::remove_const<bitmap_t>::type, uint32_t>::value ||
std::is_same<typename std::remove_const<bitmap_t>::type, uint64_t>::value),
"The bitmap_t must be uint32_t or uint64_t.");
/**
* @brief Create a bitmap view from a device raw pointer.
Expand Down
36 changes: 4 additions & 32 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#pragma once

#include <raft/core/detail/mdspan_util.cuh> // native_popc
#include <raft/core/detail/popc.cuh>
#include <raft/core/device_container_policy.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/thrust_policy.hpp>
Expand Down Expand Up @@ -326,37 +326,9 @@ struct bitset {
*/
void count(const raft::resources& res, raft::device_scalar_view<index_t> count_gpu_scalar)
{
  // Expose the bitset storage as a read-only vector of n_elements() words and let
  // raft::detail::popc do the counting. bitset_len_ is passed as the bit limit, which
  // popc uses to mask the unused high bits of the last word.
  // The reduction that used to be inlined here computed the same value into the same
  // output and was then overwritten by this call, so it has been removed.
  auto values =
    raft::make_device_vector_view<const bitset_t, index_t>(bitset_.data(), n_elements());
  raft::detail::popc(res, values, bitset_len_, count_gpu_scalar);
}
/**
* @brief Returns the number of bits set to true.
Expand Down
75 changes: 75 additions & 0 deletions cpp/include/raft/core/detail/popc.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/core/detail/mdspan_util.cuh>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/coalesced_reduction.cuh>

namespace raft::detail {

/**
 * @brief Count the number of bits that are set to 1 in a vector.
 *
 * The vector is reduced as a single column with raft::linalg::coalesced_reduction, each
 * element contributing its population count. Only the last element is masked: it keeps
 * its `max_len % (sizeof(value_t) * 8)` lowest bits, or all of its bits when that
 * remainder is 0.
 *
 * @tparam value_t the value type of the vector.
 * @tparam index_t the index type of vector and scalar.
 *
 * @param[in]  res     raft handle for managing expensive resources
 * @param[in]  values  Device vector whose set bits are counted.
 * @param[in]  max_len Maximum number of bits to count.
 * @param[out] counter Number of bits that are set to 1.
 */
template <typename value_t, typename index_t>
void popc(const raft::resources& res,
          device_vector_view<value_t, index_t> values,
          index_t max_len,
          raft::device_scalar_view<index_t> counter)
{
  static constexpr index_t len_per_item = sizeof(value_t) * 8;

  // Mask selecting the valid bits of the last element.
  value_t tail_len  = (max_len % len_per_item);
  value_t tail_mask = tail_len ? (value_t)((value_t{1} << tail_len) - value_t{1}) : ~value_t{0};

  auto values_size = values.size();

  // Per-element transform: population count, with the tail mask applied to the last element.
  auto count_bits = [tail_mask, values_size] __device__(value_t value, index_t index) {
    const bool is_tail = (index == values_size - 1);
    if constexpr (len_per_item == 64) {
      return index_t(raft::detail::popc(is_tail ? (value & tail_mask) : value));
    } else {
      // popc has no overload for 8- and 16-bit elements, so widen to 32 bits first.
      const uint32_t widened{value};
      return index_t(raft::detail::popc(is_tail ? (widened & tail_mask) : widened));
    }
  };

  // The reduction works on matrices: view the input as a values_size x 1 column and the
  // output scalar as a vector of length 1.
  auto values_matrix = raft::make_device_matrix_view<value_t, index_t, col_major>(
    values.data_handle(), values_size, 1);
  auto counter_vector = raft::make_device_vector_view<index_t, index_t>(counter.data_handle(), 1);

  raft::linalg::coalesced_reduction(
    res, values_matrix, counter_vector, index_t{0}, false, count_bits);
}

} // end namespace raft::detail
Loading

0 comments on commit d85fcf0

Please sign in to comment.