Introduce sample filtering to IVFPQ index search (#1513)
A prototype that introduces per-sample filtering for IVFPQ search. Please feel free to use it as a foundation for a future change, if appropriate: the code is functional, but not especially clean.

The diff introduces a template parameter called `SampleFilterT`. An instance is expected:
* to be `SampleFilterT()` constructible (mostly needed to define a default behavior in the form of `SampleFilterT sample_filter=SampleFilterT()`, see below)
* to provide an `inline __device__ bool operator()(...)` that returns `true` if a given sample is valid for being used against a given query in IVFPQ search

The default filter (which I set as the default in certain facilities in the form of `typename SampleFilterT = NoneSampleFilter`, in order not to modify way too many files) allows all samples to be used:
```
struct NoneSampleFilter {
  inline __device__ __host__ bool operator()(
    // query index
    const uint32_t query_ix,
    // the current inverted list index
    const uint32_t cluster_ix,
    // the index of the current sample inside the current inverted list
    const uint32_t sample_ix
  ) const {
    return true;
  }
};
```
Here `__host__` is needed for CPU-based testing only.
Also, I've provided an implementation of `BitMaskSampleFilter`, which filters samples based on a bit mask, as an example. The implementation was tested in a semi-production environment.

All the filter-related code was added to `cpp/include/raft/neighbors/detail/sample_filter.cuh`.

The default `ivf_pq_search()` method remains unchanged; instead, one more `ivf_pq_search_with_filtering()` method was introduced, with an additional template argument `SampleFilterT` and one more input parameter.
```
template <typename T, typename IdxT, typename SampleFilterT>
void search_with_filtering(raft::device_resources const& handle,
            const raft::neighbors::ivf_pq::search_params& params,
            const index<IdxT>& idx,
            const T* queries,
            uint32_t n_queries,
            uint32_t k,
            IdxT* neighbors,
            float* distances,
            rmm::mr::device_memory_resource* mr = nullptr,
            SampleFilterT sample_filter = SampleFilterT());
```
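
The default-argument pattern can be tried out on the host with a small stand-in. The sketch below is not RAFT code: `AllowAll`, `EvenSamplesOnly`, and `count_accepted` are hypothetical names. It shows that a default function argument of the form `SampleFilterT sample_filter = SampleFilterT()` does not by itself let the compiler deduce `SampleFilterT`; a default template argument is needed as well if callers are to omit the filter entirely.

```cpp
#include <cstddef>
#include <cstdint>

// Stand-in for a pass-through filter (hypothetical name).
struct AllowAll {
  bool operator()(uint32_t, uint32_t, uint32_t) const { return true; }
};

// Hypothetical filter that rejects samples at odd positions in a list.
struct EvenSamplesOnly {
  bool operator()(uint32_t, uint32_t, uint32_t sample_ix) const
  {
    return sample_ix % 2u == 0u;
  }
};

// Host-only stand-in for scanning one inverted list: counts accepted samples.
// The default template argument makes the call without a filter well-formed;
// the default function argument supplies the actual filter instance.
template <typename SampleFilterT = AllowAll>
std::size_t count_accepted(uint32_t query_ix,
                           uint32_t cluster_ix,
                           uint32_t list_size,
                           SampleFilterT sample_filter = SampleFilterT())
{
  std::size_t accepted = 0;
  for (uint32_t sample_ix = 0; sample_ix < list_size; ++sample_ix) {
    if (sample_filter(query_ix, cluster_ix, sample_ix)) { ++accepted; }
  }
  return accepted;
}
```

Calling `count_accepted(0, 0, 5)` uses `AllowAll`, while `count_accepted(0, 0, 5, EvenSamplesOnly())` deduces the filter type from the argument.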

All the current instantiations use `NoneSampleFilter` only.

I've passed the filter as `SampleFilterT sample_filter` rather than `const SampleFilterT sample_filter` in the function calls so that debugging facilities can be added to a filter, in the hope that the compiler is smart enough to recognize the de facto constness where needed.

The filter does not take a computed distance score into account by design, thus the current implementation cannot have a distance threshold. This can be easily changed, if appropriate.

It is still questionable to me whether this filtering needs to be injected right inside the search kernel instead of being done as post-processing. Please let me know if you have any thoughts on the topic.
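
One consideration in that trade-off can be shown with a toy, host-only sketch. Nothing in it comes from this commit: `Candidate`, `allow`, and both helper functions are hypothetical. Filtering before the top-k selection still returns `k` results whenever enough valid samples exist, whereas filtering an already selected top-k can return fewer than `k`.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

struct Candidate {
  uint32_t id;
  float distance;
};

// Toy filter (assumption for illustration): only even ids are valid.
inline bool allow(uint32_t id) { return id % 2u == 0u; }

inline bool closer(const Candidate& a, const Candidate& b) { return a.distance < b.distance; }

inline void drop_rejected(std::vector<Candidate>& c)
{
  c.erase(std::remove_if(c.begin(), c.end(), [](const Candidate& x) { return !allow(x.id); }),
          c.end());
}

// Filter first, then keep the k closest (mimics filtering inside the scan).
inline std::vector<Candidate> topk_filter_first(std::vector<Candidate> c, std::size_t k)
{
  drop_rejected(c);
  std::sort(c.begin(), c.end(), closer);
  if (c.size() > k) { c.resize(k); }
  return c;
}

// Keep the k closest first, then filter (mimics post-processing).
inline std::vector<Candidate> topk_filter_last(std::vector<Candidate> c, std::size_t k)
{
  std::sort(c.begin(), c.end(), closer);
  if (c.size() > k) { c.resize(k); }
  drop_rejected(c);
  return c;
}
```

With the five candidates `{1, 0.1}, {3, 0.2}, {2, 0.3}, {4, 0.4}, {6, 0.5}` and `k = 2`, the first variant returns ids 2 and 4, while the second returns nothing because both of the closest candidates are rejected.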

I'm happy to address the comments.

Thanks.

Authors:
  - Alexander Guzhva (https://github.com/alexanderguzhva)
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1513
alexanderguzhva authored May 19, 2023
1 parent 1f61b47 commit cdf107b
Showing 14 changed files with 687 additions and 378 deletions.
124 changes: 73 additions & 51 deletions cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh
@@ -20,6 +20,7 @@
 #include <raft/core/detail/macros.hpp> // RAFT_WEAK_FUNCTION
 #include <raft/distance/distance_types.hpp> // raft::distance::DistanceType
 #include <raft/neighbors/detail/ivf_pq_fp_8bit.cuh> // raft::neighbors::ivf_pq::detail::fp_8bit
+#include <raft/neighbors/detail/sample_filter.cuh> // NoneSampleFilter
 #include <raft/neighbors/ivf_pq_types.hpp> // raft::neighbors::ivf_pq::codebook_gen
 #include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
 #include <rmm/cuda_stream_view.hpp> // rmm::cuda_stream_view
@@ -36,6 +37,7 @@ auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, ui

 template <typename OutT,
 typename LutT,
+typename SampleFilterT,
 uint32_t PqBits,
 int Capacity,
 bool PrecompBaseDiff,
@@ -45,6 +47,7 @@ __global__ void compute_similarity_kernel(uint32_t n_rows,
 uint32_t n_probes,
 uint32_t pq_dim,
 uint32_t n_queries,
+uint32_t queries_offset,
 distance::DistanceType metric,
 codebook_gen codebook_kind,
 uint32_t topk,
@@ -57,32 +60,34 @@ __global__ void compute_similarity_kernel(uint32_t n_rows,
 const float* queries,
 const uint32_t* index_list,
 float* query_kths,
+SampleFilterT sample_filter,
 LutT* lut_scores,
 OutT* _out_scores,
 uint32_t* _out_indices) RAFT_EXPLICIT;

 // The signature of the kernel defined by a minimal set of template parameters
-template <typename OutT, typename LutT>
+template <typename OutT, typename LutT, typename SampleFilterT>
 using compute_similarity_kernel_t =
-decltype(&compute_similarity_kernel<OutT, LutT, 8, 0, true, true>);
+decltype(&compute_similarity_kernel<OutT, LutT, SampleFilterT, 8, 0, true, true>);

-template <typename OutT, typename LutT>
+template <typename OutT, typename LutT, typename SampleFilterT>
 struct selected {
-compute_similarity_kernel_t<OutT, LutT> kernel;
+compute_similarity_kernel_t<OutT, LutT, SampleFilterT> kernel;
 dim3 grid_dim;
 dim3 block_dim;
 size_t smem_size;
 size_t device_lut_size;
 };

-template <typename OutT, typename LutT>
-void compute_similarity_run(selected<OutT, LutT> s,
+template <typename OutT, typename LutT, typename SampleFilterT>
+void compute_similarity_run(selected<OutT, LutT, SampleFilterT> s,
 rmm::cuda_stream_view stream,
 uint32_t n_rows,
 uint32_t dim,
 uint32_t n_probes,
 uint32_t pq_dim,
 uint32_t n_queries,
+uint32_t queries_offset,
 distance::DistanceType metric,
 codebook_gen codebook_kind,
 uint32_t topk,
@@ -95,6 +100,7 @@ void compute_similarity_run(selected<OutT, LutT> s,
 const float* queries,
 const uint32_t* index_list,
 float* query_kths,
+SampleFilterT sample_filter,
 LutT* lut_scores,
 OutT* _out_scores,
 uint32_t* _out_indices) RAFT_EXPLICIT;
@@ -113,7 +119,7 @@ void compute_similarity_run(selected<OutT, LutT> s,
 * beyond this limit do not consider increasing the number of active blocks per SM
 * would improve locality anymore.
 */
-template <typename OutT, typename LutT>
+template <typename OutT, typename LutT, typename SampleFilterT>
 auto compute_similarity_select(const cudaDeviceProp& dev_props,
 bool manage_local_topk,
 int locality_hint,
@@ -123,62 +129,78 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
 uint32_t precomp_data_count,
 uint32_t n_queries,
 uint32_t n_probes,
-uint32_t topk) -> selected<OutT, LutT> RAFT_EXPLICIT;
+uint32_t topk) -> selected<OutT, LutT, SampleFilterT> RAFT_EXPLICIT;

 } // namespace raft::neighbors::ivf_pq::detail

 #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

-#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \
-extern template auto raft::neighbors::ivf_pq::detail::compute_similarity_select<OutT, LutT>( \
-const cudaDeviceProp& dev_props, \
-bool manage_local_topk, \
-int locality_hint, \
-double preferred_shmem_carveout, \
-uint32_t pq_bits, \
-uint32_t pq_dim, \
-uint32_t precomp_data_count, \
-uint32_t n_queries, \
-uint32_t n_probes, \
-uint32_t topk) \
-->raft::neighbors::ivf_pq::detail::selected<OutT, LutT>; \
-\
-extern template void raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT>( \
-raft::neighbors::ivf_pq::detail::selected<OutT, LutT> s, \
-rmm::cuda_stream_view stream, \
-uint32_t n_rows, \
-uint32_t dim, \
-uint32_t n_probes, \
-uint32_t pq_dim, \
-uint32_t n_queries, \
-raft::distance::DistanceType metric, \
-raft::neighbors::ivf_pq::codebook_gen codebook_kind, \
-uint32_t topk, \
-uint32_t max_samples, \
-const float* cluster_centers, \
-const float* pq_centers, \
-const uint8_t* const* pq_dataset, \
-const uint32_t* cluster_labels, \
-const uint32_t* _chunk_indices, \
-const float* queries, \
-const uint32_t* index_list, \
-float* query_kths, \
-LutT* lut_scores, \
-OutT* _out_scores, \
+#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \
+OutT, LutT, SampleFilterT) \
+extern template auto \
+raft::neighbors::ivf_pq::detail::compute_similarity_select<OutT, LutT, SampleFilterT>( \
+const cudaDeviceProp& dev_props, \
+bool manage_local_topk, \
+int locality_hint, \
+double preferred_shmem_carveout, \
+uint32_t pq_bits, \
+uint32_t pq_dim, \
+uint32_t precomp_data_count, \
+uint32_t n_queries, \
+uint32_t n_probes, \
+uint32_t topk) \
+->raft::neighbors::ivf_pq::detail::selected<OutT, LutT, SampleFilterT>; \
+\
+extern template void \
+raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, SampleFilterT>( \
+raft::neighbors::ivf_pq::detail::selected<OutT, LutT, SampleFilterT> s, \
+rmm::cuda_stream_view stream, \
+uint32_t n_rows, \
+uint32_t dim, \
+uint32_t n_probes, \
+uint32_t pq_dim, \
+uint32_t n_queries, \
+uint32_t queries_offset, \
+raft::distance::DistanceType metric, \
+raft::neighbors::ivf_pq::codebook_gen codebook_kind, \
+uint32_t topk, \
+uint32_t max_samples, \
+const float* cluster_centers, \
+const float* pq_centers, \
+const uint8_t* const* pq_dataset, \
+const uint32_t* cluster_labels, \
+const uint32_t* _chunk_indices, \
+const float* queries, \
+const uint32_t* index_list, \
+float* query_kths, \
+SampleFilterT sample_filter, \
+LutT* lut_scores, \
+OutT* _out_scores, \
 uint32_t* _out_indices);

 #define COMMA ,
 instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
-half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>);
+half,
+raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>,
+raft::neighbors::ivf_pq::detail::NoneSampleFilter);
 instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
-half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>);
-instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(half, half);
-instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, half);
-instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, float);
+half,
+raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>,
+raft::neighbors::ivf_pq::detail::NoneSampleFilter);
 instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
-float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>);
+half, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter);
 instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
-float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>);
+float, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter);
+instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
+float, float, raft::neighbors::ivf_pq::detail::NoneSampleFilter);
+instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
+float,
+raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>,
+raft::neighbors::ivf_pq::detail::NoneSampleFilter);
+instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
+float,
+raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>,
+raft::neighbors::ivf_pq::detail::NoneSampleFilter);

 #undef COMMA
