Skip to content

Commit

Permalink
Merge branch 'branch-23.10' into fix-nn_desccent-constructor-bug
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 authored Oct 4, 2023
2 parents e04ffea + bf3f85a commit 9afb119
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 11 deletions.
52 changes: 45 additions & 7 deletions cpp/bench/prims/neighbors/cagra_bench.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

#include <common/benchmark.hpp>
#include <raft/neighbors/cagra.cuh>
#include <raft/neighbors/sample_filter.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/itertools.hpp>
#include <thrust/sequence.h>

#include <optional>

Expand All @@ -40,6 +42,8 @@ struct params {
int block_size;
int search_width;
int max_iterations;
/** Ratio of removed indices. */
double removed_ratio;
};

template <typename T, typename IdxT>
Expand All @@ -49,7 +53,8 @@ struct CagraBench : public fixture {
params_(ps),
queries_(make_device_matrix<T, int64_t>(handle, ps.n_queries, ps.n_dims)),
dataset_(make_device_matrix<T, int64_t>(handle, ps.n_samples, ps.n_dims)),
knn_graph_(make_device_matrix<IdxT, int64_t>(handle, ps.n_samples, ps.degree))
knn_graph_(make_device_matrix<IdxT, int64_t>(handle, ps.n_samples, ps.degree)),
removed_indices_bitset_(handle, ps.n_samples)
{
// Generate random dataset and queriees
raft::random::RngState state{42};
Expand All @@ -74,6 +79,13 @@ struct CagraBench : public fixture {

auto metric = raft::distance::DistanceType::L2Expanded;

auto removed_indices =
raft::make_device_vector<IdxT, int64_t>(handle, ps.removed_ratio * ps.n_samples);
thrust::sequence(
resource::get_thrust_policy(handle),
thrust::device_pointer_cast(removed_indices.data_handle()),
thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0)));
removed_indices_bitset_.set(handle, removed_indices.view());
index_.emplace(raft::neighbors::cagra::index<T, IdxT>(
handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view())));
}
Expand All @@ -95,10 +107,18 @@ struct CagraBench : public fixture {
distances.data_handle(), params_.n_queries, params_.k);

auto queries_v = make_const_mdspan(queries_.view());
loop_on_state(state, [&]() {
raft::neighbors::cagra::search(
this->handle, search_params, *this->index_, queries_v, ind_v, dist_v);
});
if (params_.removed_ratio > 0) {
auto filter = raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view());
loop_on_state(state, [&]() {
raft::neighbors::cagra::search_with_filtering(
this->handle, search_params, *this->index_, queries_v, ind_v, dist_v, filter);
});
} else {
loop_on_state(state, [&]() {
raft::neighbors::cagra::search(
this->handle, search_params, *this->index_, queries_v, ind_v, dist_v);
});
}

double data_size = params_.n_samples * params_.n_dims * sizeof(T);
double graph_size = params_.n_samples * params_.degree * sizeof(IdxT);
Expand All @@ -120,6 +140,7 @@ struct CagraBench : public fixture {
state.counters["block_size"] = params_.block_size;
state.counters["search_width"] = params_.search_width;
state.counters["iterations"] = iterations;
state.counters["removed_ratio"] = params_.removed_ratio;
}

private:
Expand All @@ -128,6 +149,7 @@ struct CagraBench : public fixture {
raft::device_matrix<T, int64_t, row_major> queries_;
raft::device_matrix<T, int64_t, row_major> dataset_;
raft::device_matrix<IdxT, int64_t, row_major> knn_graph_;
raft::core::bitset<std::uint32_t, IdxT> removed_indices_bitset_;
};

inline const std::vector<params> generate_inputs()
Expand All @@ -141,7 +163,8 @@ inline const std::vector<params> generate_inputs()
{64}, // itopk_size
{0}, // block_size
{1}, // search_width
{0} // max_iterations
{0}, // max_iterations
{0.0} // removed_ratio
);
auto inputs2 = raft::util::itertools::product<params>({2000000ull, 10000000ull}, // n_samples
{128}, // dataset dim
Expand All @@ -151,7 +174,22 @@ inline const std::vector<params> generate_inputs()
{64}, // itopk_size
{64, 128, 256, 512, 1024}, // block_size
{1}, // search_width
{0} // max_iterations
{0}, // max_iterations
{0.0} // removed_ratio
);
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

inputs2 = raft::util::itertools::product<params>(
{2000000ull, 10000000ull}, // n_samples
{128}, // dataset dim
{1, 10, 10000}, // n_queries
{255}, // k
{64}, // knn graph degree
{300}, // itopk_size
{256}, // block_size
{2}, // search_width
{0}, // max_iterations
{0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio
);
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());
return inputs;
Expand Down
20 changes: 19 additions & 1 deletion cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,25 @@ void search(raft::resources const& res,
/**
* @brief Search ANN using the constructed index with the given sample filter.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
* Usage example:
* @code{.cpp}
* using namespace raft::neighbors;
* // use default index parameters
* cagra::index_params index_params;
* // create and fill the index from a [N, D] dataset
* auto index = cagra::build(res, index_params, dataset);
* // use default search parameters
* cagra::search_params search_params;
* // create a bitset to filter the search
* auto removed_indices = raft::make_device_vector<IdxT>(res, n_removed_indices);
* raft::core::bitset<std::uint32_t, IdxT> removed_indices_bitset(
* res, removed_indices.view(), dataset.extent(0));
* // search K nearest neighbours according to a bitset
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
* auto distances = raft::make_device_matrix<float>(res, n_queries, k);
* cagra::search_with_filtering(res, search_params, index, queries, neighbors, distances,
* filtering::bitset_filter(removed_indices_bitset.view()));
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices
Expand Down
48 changes: 48 additions & 0 deletions cpp/include/raft/neighbors/sample_filter.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cstddef>
#include <cstdint>

#include <raft/core/bitset.cuh>

namespace raft::neighbors::filtering {
/**
* @brief Filter an index with a bitset
*
* @tparam index_t Indexing type
*/
template <typename bitset_t, typename index_t>
struct bitset_filter {
// View of the bitset to use as a filter
const raft::core::bitset_view<bitset_t, index_t> bitset_view_;

bitset_filter(const raft::core::bitset_view<bitset_t, index_t> bitset_for_filtering)
: bitset_view_{bitset_for_filtering}
{
}
inline _RAFT_HOST_DEVICE bool operator()(
// query index
const uint32_t query_ix,
// the index of the current sample
const uint32_t sample_ix) const
{
return bitset_view_.test(sample_ix);
}
};
} // namespace raft::neighbors::filtering
114 changes: 114 additions & 0 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <raft/linalg/add.cuh>
#include <raft/neighbors/cagra.cuh>
#include <raft/neighbors/cagra_serialize.cuh>
#include <raft/neighbors/sample_filter.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/itertools.hpp>

Expand Down Expand Up @@ -525,6 +526,119 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
}
}

void testCagraRemoved()
{
size_t queries_size = ps.n_queries * ps.k;
std::vector<IdxT> indices_Cagra(queries_size);
std::vector<IdxT> indices_naive(queries_size);
std::vector<DistanceT> distances_Cagra(queries_size);
std::vector<DistanceT> distances_naive(queries_size);

{
rmm::device_uvector<DistanceT> distances_naive_dev(queries_size, stream_);
rmm::device_uvector<IdxT> indices_naive_dev(queries_size, stream_);
auto* database_filtered_ptr = database.data() + test_cagra_sample_filter::offset * ps.dim;
naive_knn<DistanceT, DataT, IdxT>(handle_,
distances_naive_dev.data(),
indices_naive_dev.data(),
search_queries.data(),
database_filtered_ptr,
ps.n_queries,
ps.n_rows - test_cagra_sample_filter::offset,
ps.dim,
ps.k,
ps.metric);
raft::linalg::addScalar(indices_naive_dev.data(),
indices_naive_dev.data(),
IdxT(test_cagra_sample_filter::offset),
queries_size,
stream_);
update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_);
update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_);
resource::sync_stream(handle_);
}

{
rmm::device_uvector<DistanceT> distances_dev(queries_size, stream_);
rmm::device_uvector<IdxT> indices_dev(queries_size, stream_);

{
cagra::index_params index_params;
index_params.metric = ps.metric; // Note: currently ony the cagra::index_params metric is
// not used for knn_graph building.
cagra::search_params search_params;
search_params.algo = ps.algo;
search_params.max_queries = ps.max_queries;
search_params.team_size = ps.team_size;
search_params.hashmap_mode = cagra::hash_mode::HASH;

auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
(const DataT*)database.data(), ps.n_rows, ps.dim);

cagra::index<DataT, IdxT> index(handle_);
if (ps.host_dataset) {
auto database_host = raft::make_host_matrix<DataT, int64_t>(ps.n_rows, ps.dim);
raft::copy(database_host.data_handle(), database.data(), database.size(), stream_);
auto database_host_view = raft::make_host_matrix_view<const DataT, int64_t>(
(const DataT*)database_host.data_handle(), ps.n_rows, ps.dim);
index = cagra::build<DataT, IdxT>(handle_, index_params, database_host_view);
} else {
index = cagra::build<DataT, IdxT>(handle_, index_params, database_view);
}

if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); }

auto search_queries_view = raft::make_device_matrix_view<const DataT, int64_t>(
search_queries.data(), ps.n_queries, ps.dim);
auto indices_out_view =
raft::make_device_matrix_view<IdxT, int64_t>(indices_dev.data(), ps.n_queries, ps.k);
auto dists_out_view = raft::make_device_matrix_view<DistanceT, int64_t>(
distances_dev.data(), ps.n_queries, ps.k);
auto removed_indices =
raft::make_device_vector<IdxT, int64_t>(handle_, test_cagra_sample_filter::offset);
thrust::sequence(
resource::get_thrust_policy(handle_),
thrust::device_pointer_cast(removed_indices.data_handle()),
thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0)));
resource::sync_stream(handle_);
raft::core::bitset<std::uint32_t, IdxT> removed_indices_bitset(
handle_, removed_indices.view(), ps.n_rows);
cagra::search_with_filtering(
handle_,
search_params,
index,
search_queries_view,
indices_out_view,
dists_out_view,
raft::neighbors::filtering::bitset_filter(removed_indices_bitset.view()));
update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_);
update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_);
resource::sync_stream(handle_);
}

double min_recall = ps.min_recall;
EXPECT_TRUE(eval_neighbours(indices_naive,
indices_Cagra,
distances_naive,
distances_Cagra,
ps.n_queries,
ps.k,
0.001,
min_recall));
EXPECT_TRUE(eval_distances(handle_,
database.data(),
search_queries.data(),
indices_dev.data(),
distances_dev.data(),
ps.n_rows,
ps.dim,
ps.n_queries,
ps.k,
ps.metric,
1.0e-4));
}
}

void SetUp() override
{
database.resize(((size_t)ps.n_rows) * ps.dim, stream_);
Expand Down
6 changes: 5 additions & 1 deletion cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ typedef AnnCagraSortTest<float, float, std::uint32_t> AnnCagraSortTestF_U32;
TEST_P(AnnCagraSortTestF_U32, AnnCagraSort) { this->testCagraSort(); }

typedef AnnCagraFilterTest<float, float, std::uint32_t> AnnCagraFilterTestF_U32;
TEST_P(AnnCagraFilterTestF_U32, AnnCagraFilter) { this->testCagraFilter(); }
TEST_P(AnnCagraFilterTestF_U32, AnnCagraFilter)
{
this->testCagraFilter();
this->testCagraRemoved();
}

INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_U32, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestF_U32, ::testing::ValuesIn(inputs));
Expand Down
6 changes: 5 additions & 1 deletion cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ TEST_P(AnnCagraTestI8_U32, AnnCagra) { this->testCagra(); }
typedef AnnCagraSortTest<float, std::int8_t, std::uint32_t> AnnCagraSortTestI8_U32;
TEST_P(AnnCagraSortTestI8_U32, AnnCagraSort) { this->testCagraSort(); }
typedef AnnCagraFilterTest<float, std::int8_t, std::uint32_t> AnnCagraFilterTestI8_U32;
TEST_P(AnnCagraFilterTestI8_U32, AnnCagraFilter) { this->testCagraFilter(); }
TEST_P(AnnCagraFilterTestI8_U32, AnnCagraFilter)
{
this->testCagraFilter();
this->testCagraRemoved();
}

INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8_U32, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestI8_U32, ::testing::ValuesIn(inputs));
Expand Down
6 changes: 5 additions & 1 deletion cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ typedef AnnCagraSortTest<float, std::uint8_t, std::uint32_t> AnnCagraSortTestU8_
TEST_P(AnnCagraSortTestU8_U32, AnnCagraSort) { this->testCagraSort(); }

typedef AnnCagraFilterTest<float, std::uint8_t, std::uint32_t> AnnCagraFilterTestU8_U32;
TEST_P(AnnCagraFilterTestU8_U32, AnnCagraSort) { this->testCagraFilter(); }
TEST_P(AnnCagraFilterTestU8_U32, AnnCagraSort)
{
this->testCagraFilter();
this->testCagraRemoved();
}

INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestU8_U32, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestU8_U32, ::testing::ValuesIn(inputs));
Expand Down

0 comments on commit 9afb119

Please sign in to comment.