Skip to content

Commit

Permalink
Add sample filter conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Oct 15, 2023
1 parent cbcd4a0 commit 81dad56
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 2 deletions.
4 changes: 3 additions & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ INSTALL_TARGET=install
BUILD_REPORT_METRICS=""
BUILD_REPORT_INCL_CACHE_STATS=OFF

TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH"

CACHE_ARGS=""
Expand Down Expand Up @@ -324,6 +324,8 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then
$CMAKE_TARGET == *"DISTANCE_TEST"* || \
$CMAKE_TARGET == *"MATRIX_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_NN_DESCENT_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_TEST"* || \
$CMAKE_TARGET == *"SPARSE_DIST_TEST" || \
$CMAKE_TARGET == *"SPARSE_NEIGHBORS_TEST"* || \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
const bool valid = vec_id < list_length;

// Process first shm_assisted_dim dimensions (always using shared memory)
if (valid && sample_filter(queries_offset + blockIdx.y, probe_id, vec_id)) {
if (valid && sample_filter(queries_offset + blockIdx.y, list_id, vec_id)) {
loadAndComputeDist<kUnroll, decltype(compute_dist), Veclen, T, AccT> lc(dist,
compute_dist);
for (int pos = 0; pos < shm_assisted_dim;
Expand Down
30 changes: 30 additions & 0 deletions cpp/include/raft/neighbors/sample_filter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,34 @@ struct bitset_filter {
return bitset_view_.test(sample_ix);
}
};

/**
* @brief Filter used to convert the cluster index and sample index
* of an IVF search into a sample index. This can be used as an
* intermediate filter.
*
* @tparam index_t Indexing type
* @tparam filter_t
*/
template <typename index_t, typename filter_t>
struct ivf_to_sample_filter {
index_t** const inds_ptrs_;
const filter_t next_filter_;

ivf_to_sample_filter(index_t** const inds_ptrs, const filter_t next_filter)
: inds_ptrs_{inds_ptrs}, next_filter_{next_filter}
{
}

inline _RAFT_HOST_DEVICE bool operator()(
// query index
const uint32_t query_ix,
// the current inverted list index
const uint32_t cluster_ix,
// the index of the current sample inside the current inverted list
const uint32_t sample_ix) const
{
return next_filter_(query_ix, inds_ptrs_[cluster_ix][sample_ix]);
}
};
} // namespace raft::neighbors::filtering
108 changes: 108 additions & 0 deletions cpp/test/neighbors/ann_ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
#pragma once

#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter

#include "../test_utils.cuh"
#include "ann_utils.cuh"
#include <raft/core/device_mdarray.hpp>
Expand All @@ -26,6 +28,7 @@
#include <raft/linalg/map.cuh>
#include <raft/neighbors/ivf_flat_types.hpp>
#include <raft/neighbors/ivf_list.hpp>
#include <raft/neighbors/sample_filter.cuh>
#include <raft/util/cudart_utils.hpp>
#include <raft/util/fast_int_div.cuh>
#include <thrust/functional.h>
Expand Down Expand Up @@ -57,6 +60,10 @@

namespace raft::neighbors::ivf_flat {

struct test_ivf_sample_filter {
static constexpr unsigned offset = 300;
};

template <typename IdxT>
struct AnnIvfFlatInputs {
IdxT num_queries;
Expand Down Expand Up @@ -406,6 +413,107 @@ class AnnIVFFlatTest : public ::testing::TestWithParam<AnnIvfFlatInputs<IdxT>> {
}
}

void testFilter()
{
size_t queries_size = ps.num_queries * ps.k;
std::vector<IdxT> indices_ivfflat(queries_size);
std::vector<IdxT> indices_naive(queries_size);
std::vector<T> distances_ivfflat(queries_size);
std::vector<T> distances_naive(queries_size);

{
rmm::device_uvector<T> distances_naive_dev(queries_size, stream_);
rmm::device_uvector<IdxT> indices_naive_dev(queries_size, stream_);
auto* database_filtered_ptr = database.data() + test_ivf_sample_filter::offset * ps.dim;
naive_knn<T, DataT, IdxT>(handle_,
distances_naive_dev.data(),
indices_naive_dev.data(),
search_queries.data(),
database_filtered_ptr,
ps.num_queries,
ps.num_db_vecs - test_ivf_sample_filter::offset,
ps.dim,
ps.k,
ps.metric);
raft::linalg::addScalar(indices_naive_dev.data(),
indices_naive_dev.data(),
IdxT(test_ivf_sample_filter::offset),
queries_size,
stream_);
update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_);
update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_);
resource::sync_stream(handle_);
}

{
// unless something is really wrong with clustering, this could serve as a lower bound on
// recall
double min_recall = static_cast<double>(ps.nprobe) / static_cast<double>(ps.nlist);

auto distances_ivfflat_dev = raft::make_device_matrix<T, IdxT>(handle_, ps.num_queries, ps.k);
auto indices_ivfflat_dev =
raft::make_device_matrix<IdxT, IdxT>(handle_, ps.num_queries, ps.k);

{
ivf_flat::index_params index_params;
ivf_flat::search_params search_params;
index_params.n_lists = ps.nlist;
index_params.metric = ps.metric;
index_params.adaptive_centers = ps.adaptive_centers;
search_params.n_probes = ps.nprobe;

index_params.add_data_on_build = true;
index_params.kmeans_trainset_fraction = 0.5;
index_params.metric_arg = 0;

// Create IVF Flat index
auto database_view = raft::make_device_matrix_view<const DataT, IdxT>(
(const DataT*)database.data(), ps.num_db_vecs, ps.dim);
auto index = ivf_flat::build(handle_, index_params, database_view);

// Create Bitset filter
auto removed_indices =
raft::make_device_vector<IdxT, int64_t>(handle_, test_ivf_sample_filter::offset);
thrust::sequence(resource::get_thrust_policy(handle_),
thrust::device_pointer_cast(removed_indices.data_handle()),
thrust::device_pointer_cast(removed_indices.data_handle() +
test_ivf_sample_filter::offset));
resource::sync_stream(handle_);

raft::core::bitset<std::uint32_t, IdxT> removed_indices_bitset(
handle_, removed_indices.view(), ps.num_db_vecs);

// Search with the filter
auto search_queries_view = raft::make_device_matrix_view<const DataT, IdxT>(
search_queries.data(), ps.num_queries, ps.dim);
ivf_flat::search_with_filtering(
handle_,
search_params,
index,
search_queries_view,
indices_ivfflat_dev.view(),
distances_ivfflat_dev.view(),
raft::neighbors::filtering::ivf_to_sample_filter(
index.inds_ptrs().data_handle(),
raft::neighbors::filtering::bitset_filter(removed_indices_bitset.view())));

update_host(
distances_ivfflat.data(), distances_ivfflat_dev.data_handle(), queries_size, stream_);
update_host(
indices_ivfflat.data(), indices_ivfflat_dev.data_handle(), queries_size, stream_);
resource::sync_stream(handle_);
}
ASSERT_TRUE(eval_neighbours(indices_naive,
indices_ivfflat,
distances_naive,
distances_ivfflat,
ps.num_queries,
ps.k,
0.001,
min_recall));
}
}

void SetUp() override
{
database.resize(ps.num_db_vecs * ps.dim, stream_);
Expand Down
1 change: 1 addition & 0 deletions cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ TEST_P(AnnIVFFlatTestF, AnnIVFFlat)
{
this->testIVFFlat();
this->testPacker();
this->testFilter();
}

INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF, ::testing::ValuesIn(inputs));
Expand Down

0 comments on commit 81dad56

Please sign in to comment.