Skip to content

Commit

Permalink
[FEA] Add pre-filtering to CAGRA (#1811)
Browse files Browse the repository at this point in the history
This PR adds the pre-filtering feature to the CAGRA search implementations.

Rel: taken over from #1765

## Algorithm
The pre-filtering algorithm removes a node that should not appear in the final result after it has served as a parent node. This way, filtered-out nodes still take part in the graph traversal, which avoids potential performance degradation.

## Changes
- Add a filtering operation on parent nodes after the internal top-M buffer candidate calculation.
- Add a filtering operation on the result buffer before the results are stored in device memory.

Authors:
  - tsuki (https://github.com/enp1s0)
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1811
  • Loading branch information
enp1s0 authored Sep 25, 2023
1 parent 8522a14 commit cb24d99
Show file tree
Hide file tree
Showing 55 changed files with 2,142 additions and 1,280 deletions.
75 changes: 68 additions & 7 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ namespace raft::neighbors::cagra {
* // use default index parameters
* cagra::index_params build_params;
* cagra::search_params search_params;
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* // create knn graph
* cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params);
* auto optimized_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 64);
* auto optimized_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 64);
* cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view());
* // Construct an index from dataset and optimized knn_graph
* auto index = cagra::index<T, IdxT>(res, build_params.metric(), dataset,
* optimized_graph.view());
* optimized_graph.view());
* @endcode
*
* @tparam DataT data element type
Expand Down Expand Up @@ -106,7 +106,7 @@ void build_knn_graph(raft::resources const& res,
* @code{.cpp}
* using namespace raft::neighbors;
* cagra::index_params build_params;
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* // build KNN graph not using `cagra::build_knn_graph`
* // build(knn_graph, dataset, ...);
* // sort graph index
Expand All @@ -115,7 +115,7 @@ void build_knn_graph(raft::resources const& res,
* cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view());
* // Construct an index from dataset and optimized knn_graph
* auto index = cagra::index<T, IdxT>(res, build_params.metric(), dataset,
* optimized_graph.view());
* optimized_graph.view());
* @endcode
*
* @tparam DataT type of the data in the source dataset
Expand Down Expand Up @@ -316,9 +316,70 @@ void search(raft::resources const& res,
auto distances_internal = raft::make_device_matrix_view<float, int64_t, row_major>(
distances.data_handle(), distances.extent(0), distances.extent(1));

cagra::detail::search_main<T, internal_IdxT, IdxT>(
res, params, idx, queries_internal, neighbors_internal, distances_internal);
cagra::detail::search_main<T,
internal_IdxT,
decltype(raft::neighbors::filtering::none_cagra_sample_filter()),
IdxT>(res,
params,
idx,
queries_internal,
neighbors_internal,
distances_internal,
raft::neighbors::filtering::none_cagra_sample_filter());
}

/**
 * @brief Run an ANN search on a CAGRA index, restricting results with a sample filter.
 *
 * The function validates the shapes of the query and output views, reinterprets the
 * neighbor index buffer as the unsigned counterpart of `IdxT`, and forwards everything,
 * including the filter, to `cagra::detail::search_main`.
 *
 * See the [cagra::build](#cagra::build) documentation for a usage example.
 *
 * @tparam T data element type
 * @tparam IdxT type of the indices
 * @tparam CagraSampleFilterT Device filter function, with the signature
 *         `(uint32_t query_ix, uint32_t sample_ix) -> bool`
 *
 * @param[in] res raft resources
 * @param[in] params configure the search
 * @param[in] idx cagra index
 * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
 * @param[out] neighbors a device matrix view to the indices of the neighbors in the source
 *             dataset [n_queries, k]
 * @param[out] distances a device matrix view to the distances to the selected neighbors
 *             [n_queries, k]
 * @param[in] sample_filter a device filter function that greenlights samples for a given query
 */
template <typename T, typename IdxT, typename CagraSampleFilterT>
void search_with_filtering(raft::resources const& res,
                           const search_params& params,
                           const index<T, IdxT>& idx,
                           raft::device_matrix_view<const T, int64_t, row_major> queries,
                           raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
                           raft::device_matrix_view<float, int64_t, row_major> distances,
                           CagraSampleFilterT sample_filter = CagraSampleFilterT())
{
  // The kernels operate on unsigned indices; the output buffer is viewed through that type.
  using internal_IdxT = std::make_unsigned_t<IdxT>;

  const auto n_queries = queries.extent(0);

  // Every query needs exactly one row in both output matrices.
  RAFT_EXPECTS(
    n_queries == neighbors.extent(0) && n_queries == distances.extent(0),
    "Number of rows in output neighbors and distances matrices must equal the number of queries.");

  // Both outputs must hold the same number of results (k) per query.
  RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
               "Number of columns in output neighbors and distances matrices must equal k");

  // Queries must live in the same vector space as the indexed dataset.
  RAFT_EXPECTS(queries.extent(1) == idx.dim(),
               "Number of query dimensions should equal number of dimensions in the index.");

  auto* const neighbors_ptr = reinterpret_cast<internal_IdxT*>(neighbors.data_handle());

  auto distances_internal = raft::make_device_matrix_view<float, int64_t, row_major>(
    distances.data_handle(), distances.extent(0), distances.extent(1));
  auto neighbors_internal = raft::make_device_matrix_view<internal_IdxT, int64_t, row_major>(
    neighbors_ptr, neighbors.extent(0), neighbors.extent(1));
  auto queries_internal = raft::make_device_matrix_view<const T, int64_t, row_major>(
    queries.data_handle(), queries.extent(0), queries.extent(1));

  cagra::detail::search_main<T, internal_IdxT, CagraSampleFilterT, IdxT>(
    res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter);
}

/** @} */ // end group cagra

} // namespace raft::neighbors::cagra
Expand Down
60 changes: 55 additions & 5 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/neighbors/detail/ivf_pq_search.cuh>
#include <raft/neighbors/sample_filter_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>

#include <raft/core/device_mdspan.hpp>
Expand All @@ -34,6 +35,48 @@

namespace raft::neighbors::cagra::detail {

/**
 * @brief Adapter that shifts the query id by a fixed offset before consulting a filter.
 *
 * Calling the adapter with `(query_id, sample_id)` evaluates the wrapped filter with
 * `(query_id + offset, sample_id)` and returns whatever the wrapped filter returns.
 *
 * @tparam CagraSampleFilterT type of the wrapped filter; it must be callable with two
 *         `uint32_t` arguments (query id, sample id)
 */
template <class CagraSampleFilterT>
struct CagraSampleFilterWithQueryIdOffset {
  // Value added to every incoming query id.
  const uint32_t offset;
  // Wrapped filter, stored by value.
  CagraSampleFilterT filter;

  /**
   * @param query_id_offset value added to the query id on every call
   * @param wrapped_filter filter to delegate to (copied into the adapter)
   */
  CagraSampleFilterWithQueryIdOffset(const uint32_t query_id_offset,
                                     const CagraSampleFilterT wrapped_filter)
    : offset(query_id_offset), filter(wrapped_filter)
  {
  }

  /**
   * @param query_id query id before the offset is applied
   * @param sample_id sample id, forwarded unchanged
   * @return result of the wrapped filter for the shifted query id
   */
  _RAFT_DEVICE auto operator()(const uint32_t query_id, const uint32_t sample_id)
  {
    return filter(offset + query_id, sample_id);
  }
};

// Type trait mapping a user-supplied filter type to the filter type used for the search plan.
// By default the filter is wrapped in CagraSampleFilterWithQueryIdOffset so that a query id
// offset can be attached to it.
template <class CagraSampleFilterT>
struct CagraSampleFilterT_Selector {
using type = CagraSampleFilterWithQueryIdOffset<CagraSampleFilterT>;
};
// Specialization for `none_cagra_sample_filter`: the filter type is kept as-is, without the
// offset wrapper.
template <>
struct CagraSampleFilterT_Selector<raft::neighbors::filtering::none_cagra_sample_filter> {
using type = raft::neighbors::filtering::none_cagra_sample_filter;
};

/**
 * @brief Attach a query id offset to a sample filter.
 *
 * @tparam CagraSampleFilterT type of the filter to wrap
 * @param filter filter to wrap (copied into the result)
 * @param offset value that the resulting filter adds to each query id
 * @return the filter type chosen by `CagraSampleFilterT_Selector`, constructed from
 *         `(offset, filter)`
 */
template <class CagraSampleFilterT>
inline typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type set_offset(
  CagraSampleFilterT filter, const uint32_t offset)
{
  using offset_filter_t = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
  return offset_filter_t(offset, filter);
}

/**
 * @brief Specialization for `none_cagra_sample_filter`: the offset is ignored and the filter
 *        is returned unchanged.
 */
template <>
inline
  typename CagraSampleFilterT_Selector<raft::neighbors::filtering::none_cagra_sample_filter>::type
  set_offset<raft::neighbors::filtering::none_cagra_sample_filter>(
    raft::neighbors::filtering::none_cagra_sample_filter filter, const uint32_t)
{
  return filter;
}

/**
* @brief Search ANN using the constructed index.
*
Expand All @@ -54,13 +97,18 @@ namespace raft::neighbors::cagra::detail {
* k]
*/

template <typename T, typename internal_IdxT, typename IdxT = uint32_t, typename DistanceT = float>
template <typename T,
typename internal_IdxT,
typename CagraSampleFilterT,
typename IdxT = uint32_t,
typename DistanceT = float>
void search_main(raft::resources const& res,
search_params params,
const index<T, IdxT>& index,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<internal_IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<DistanceT, int64_t, row_major> distances)
raft::device_matrix_view<DistanceT, int64_t, row_major> distances,
CagraSampleFilterT sample_filter = CagraSampleFilterT())
{
resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::search");
RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
Expand All @@ -77,8 +125,9 @@ void search_main(raft::resources const& res,
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"cagra::search(max_queries = %u, k = %u, dim = %zu)", params.max_queries, topk, index.dim());

std::unique_ptr<search_plan_impl<T, internal_IdxT, DistanceT>> plan =
factory<T, internal_IdxT, DistanceT>::create(
using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<T, internal_IdxT, DistanceT, CagraSampleFilterT_s>> plan =
factory<T, internal_IdxT, DistanceT, CagraSampleFilterT_s>::create(
res, params, index.dim(), index.graph_degree(), topk);

plan->check(neighbors.extent(1));
Expand Down Expand Up @@ -119,7 +168,8 @@ void search_main(raft::resources const& res,
n_queries,
_seed_ptr,
_num_executed_iterations,
topk);
topk,
set_offset(sample_filter, qid));
}

static_assert(std::is_same_v<DistanceT, float>,
Expand Down
13 changes: 8 additions & 5 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,17 +155,20 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in
INDEX_T* const visited_hashmap_ptr,
const std::uint32_t hash_bitlen,
const INDEX_T* const parent_indices,
const INDEX_T* const internal_topk_list,
const std::uint32_t search_width)
{
const INDEX_T invalid_index = utils::get_max_value<INDEX_T>();
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
const INDEX_T invalid_index = utils::get_max_value<INDEX_T>();

// Read child indices of parents from knn graph and check if the distance
// computation is necessary.
for (uint32_t i = threadIdx.x; i < knn_k * search_width; i += BLOCK_SIZE) {
const INDEX_T parent_id = parent_indices[i / knn_k];
INDEX_T child_id = invalid_index;
if (parent_id != invalid_index) {
child_id = knn_graph[(i % knn_k) + ((uint64_t)knn_k * parent_id)];
const INDEX_T smem_parent_id = parent_indices[i / knn_k];
INDEX_T child_id = invalid_index;
if (smem_parent_id != invalid_index) {
const auto parent_id = internal_topk_list[smem_parent_id] & ~index_msb_1_mask;
child_id = knn_graph[(i % knn_k) + ((uint64_t)knn_k * parent_id)];
}
if (child_id != invalid_index) {
if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) {
Expand Down
42 changes: 25 additions & 17 deletions cpp/include/raft/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,25 @@
#include "search_multi_kernel.cuh"
#include "search_plan.cuh"
#include "search_single_cta.cuh"
#include <raft/neighbors/sample_filter_types.hpp>

namespace raft::neighbors::cagra::detail {

template <typename T, typename IdxT = uint32_t, typename DistanceT = float>
template <typename T,
typename IdxT = uint32_t,
typename DistanceT = float,
typename CagraSampleFilterT = raft::neighbors::filtering::none_cagra_sample_filter>
class factory {
public:
/**
* Create a search structure for dataset with dim features.
*/
static std::unique_ptr<search_plan_impl<T, IdxT, DistanceT>> create(raft::resources const& res,
search_params const& params,
int64_t dim,
int64_t graph_degree,
uint32_t topk)
static std::unique_ptr<search_plan_impl<T, IdxT, DistanceT, CagraSampleFilterT>> create(
raft::resources const& res,
search_params const& params,
int64_t dim,
int64_t graph_degree,
uint32_t topk)
{
search_plan_impl_base plan(params, dim, graph_degree, topk);
switch (plan.max_dim) {
Expand Down Expand Up @@ -63,26 +68,29 @@ class factory {
break;
default: RAFT_LOG_DEBUG("Incorrect max_dim (%lu)\n", plan.max_dim);
}
return std::unique_ptr<search_plan_impl<T, IdxT, DistanceT>>();
return std::unique_ptr<search_plan_impl<T, IdxT, DistanceT, CagraSampleFilterT>>();
}

private:
template <unsigned MAX_DATASET_DIM, unsigned TEAM_SIZE>
static std::unique_ptr<search_plan_impl<T, IdxT, DistanceT>> dispatch_kernel(
static std::unique_ptr<search_plan_impl<T, IdxT, DistanceT, CagraSampleFilterT>> dispatch_kernel(
raft::resources const& res, search_plan_impl_base& plan)
{
if (plan.algo == search_algo::SINGLE_CTA) {
return std::unique_ptr<search_plan_impl<T, IdxT, DistanceT>>(
new single_cta_search::search<TEAM_SIZE, MAX_DATASET_DIM, T, IdxT, DistanceT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
return std::unique_ptr<search_plan_impl<T, IdxT, DistanceT, CagraSampleFilterT>>(
new single_cta_search::
search<TEAM_SIZE, MAX_DATASET_DIM, T, IdxT, DistanceT, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
} else if (plan.algo == search_algo::MULTI_CTA) {
return std::unique_ptr<search_plan_impl<T, IdxT, DistanceT>>(
new multi_cta_search::search<TEAM_SIZE, MAX_DATASET_DIM, T, IdxT, DistanceT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
return std::unique_ptr<search_plan_impl<T, IdxT, DistanceT, CagraSampleFilterT>>(
new multi_cta_search::
search<TEAM_SIZE, MAX_DATASET_DIM, T, IdxT, DistanceT, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
} else {
return std::unique_ptr<search_plan_impl<T, IdxT, DistanceT>>(
new multi_kernel_search::search<TEAM_SIZE, MAX_DATASET_DIM, T, IdxT, DistanceT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
return std::unique_ptr<search_plan_impl<T, IdxT, DistanceT, CagraSampleFilterT>>(
new multi_kernel_search::
search<TEAM_SIZE, MAX_DATASET_DIM, T, IdxT, DistanceT, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
}
}
};
Expand Down
Loading

0 comments on commit cb24d99

Please sign in to comment.