Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sample filtering for ivf_flat. Filtering code refactoring and cleanup #1541

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,25 @@

#include <cstdint> // uintX_t
#include <raft/neighbors/ivf_flat_types.hpp> // raft::neighbors::ivf_flat::index
#include <raft/neighbors/sample_filter.cuh> // none_ivf_sample_filter
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
#include <rmm/cuda_stream_view.hpp> // rmm:cuda_stream_view

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

namespace raft::neighbors::ivf_flat::detail {

template <typename T, typename AccT, typename IdxT>
template <typename T, typename AccT, typename IdxT, typename IvfSampleFilterT>
void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& index,
const T* queries,
const uint32_t* coarse_query_results,
const uint32_t n_queries,
const uint32_t queries_offset,
const raft::distance::DistanceType metric,
const uint32_t n_probes,
const uint32_t k,
const bool select_min,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
float* distances,
uint32_t& grid_dim_x,
Expand All @@ -43,23 +46,30 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& i

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \
extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan<T, AccT, IdxT>( \
const raft::neighbors::ivf_flat::index<T, IdxT>& index, \
const T* queries, \
const uint32_t* coarse_query_results, \
const uint32_t n_queries, \
const raft::distance::DistanceType metric, \
const uint32_t n_probes, \
const uint32_t k, \
const bool select_min, \
IdxT* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \
T, AccT, IdxT, IvfSampleFilterT) \
extern template void \
raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan<T, AccT, IdxT, IvfSampleFilterT>( \
const raft::neighbors::ivf_flat::index<T, IdxT>& index, \
const T* queries, \
const uint32_t* coarse_query_results, \
const uint32_t n_queries, \
const uint32_t queries_offset, \
const raft::distance::DistanceType metric, \
const uint32_t n_probes, \
const uint32_t k, \
const bool select_min, \
IvfSampleFilterT sample_filter, \
IdxT* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream)

instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(
float, float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(
int8_t, int32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(
uint8_t, uint32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);

#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ struct loadAndComputeDist<kUnroll, Lambda, 1, int8_t, int32_t> {
* @param n_probes
* @param k
* @param dim
* @param sample_filter
* @param[out] neighbors
* @param[out] distances
*/
Expand All @@ -655,6 +656,7 @@ template <int Capacity,
typename T,
typename AccT,
typename IdxT,
typename IvfSampleFilterT,
typename Lambda,
typename PostLambda>
__global__ void __launch_bounds__(kThreadsPerBlock)
Expand All @@ -666,9 +668,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
const IdxT* const* list_indices_ptrs,
const T* const* list_data_ptrs,
const uint32_t* list_sizes,
const uint32_t queries_offset,
const uint32_t n_probes,
const uint32_t k,
const uint32_t dim,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
float* distances)
{
Expand Down Expand Up @@ -736,7 +740,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
const bool valid = vec_id < list_length;

// Process first shm_assisted_dim dimensions (always using shared memory)
if (valid) {
if (valid && sample_filter(queries_offset + blockIdx.y, probe_id, vec_id)) {
loadAndComputeDist<kUnroll, decltype(compute_dist), Veclen, T, AccT> lc(dist,
compute_dist);
for (int pos = 0; pos < shm_assisted_dim;
Expand Down Expand Up @@ -803,6 +807,7 @@ template <int Capacity,
typename T,
typename AccT,
typename IdxT,
typename IvfSampleFilterT,
typename Lambda,
typename PostLambda>
void launch_kernel(Lambda lambda,
Expand All @@ -811,17 +816,26 @@ void launch_kernel(Lambda lambda,
const T* queries,
const uint32_t* coarse_index,
const uint32_t num_queries,
const uint32_t queries_offset,
const uint32_t n_probes,
const uint32_t k,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
{
RAFT_EXPECTS(Veclen == index.veclen(),
"Configured Veclen does not match the index interleaving pattern.");
constexpr auto kKernel =
interleaved_scan_kernel<Capacity, Veclen, Ascending, T, AccT, IdxT, Lambda, PostLambda>;
constexpr auto kKernel = interleaved_scan_kernel<Capacity,
Veclen,
Ascending,
T,
AccT,
IdxT,
IvfSampleFilterT,
Lambda,
PostLambda>;
const int max_query_smem = 16384;
int query_smem_elems =
std::min<int>(max_query_smem / sizeof(T), Pow2<Veclen * WarpSize>::roundUp(index.dim()));
Expand Down Expand Up @@ -860,9 +874,11 @@ void launch_kernel(Lambda lambda,
index.inds_ptrs().data_handle(),
index.data_ptrs().data_handle(),
index.list_sizes().data_handle(),
queries_offset + query_offset,
n_probes,
k,
index.dim(),
sample_filter,
neighbors,
distances);
queries += grid_dim_y * index.dim();
Expand Down Expand Up @@ -931,6 +947,7 @@ template <int Capacity,
typename T,
typename AccT,
typename IdxT,
typename IvfSampleFilterT,
typename... Args>
void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... args)
{
Expand All @@ -943,6 +960,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg
T,
AccT,
IdxT,
IvfSampleFilterT,
euclidean_dist<Veclen, T, AccT>,
raft::identity_op>({}, {}, std::forward<Args>(args)...);
case raft::distance::DistanceType::L2SqrtExpanded:
Expand All @@ -953,6 +971,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg
T,
AccT,
IdxT,
IvfSampleFilterT,
euclidean_dist<Veclen, T, AccT>,
raft::sqrt_op>({}, {}, std::forward<Args>(args)...);
case raft::distance::DistanceType::InnerProduct:
Expand All @@ -962,6 +981,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg
T,
AccT,
IdxT,
IvfSampleFilterT,
inner_prod_dist<Veclen, T, AccT>,
raft::identity_op>({}, {}, std::forward<Args>(args)...);
// NB: update the description of `knn::ivf_flat::build` when adding here a new metric.
Expand All @@ -976,6 +996,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg
template <typename T,
typename AccT,
typename IdxT,
typename IvfSampleFilterT,
int Capacity = matrix::detail::select::warpsort::kMaxCapacity,
int Veclen = std::max<int>(1, 16 / sizeof(T))>
struct select_interleaved_scan_kernel {
Expand All @@ -990,13 +1011,20 @@ struct select_interleaved_scan_kernel {
{
if constexpr (Capacity > 1) {
if (capacity * 2 <= Capacity) {
return select_interleaved_scan_kernel<T, AccT, IdxT, Capacity / 2, Veclen>::run(
capacity, veclen, select_min, std::forward<Args>(args)...);
return select_interleaved_scan_kernel<T,
AccT,
IdxT,
IvfSampleFilterT,
Capacity / 2,
Veclen>::run(capacity,
veclen,
select_min,
std::forward<Args>(args)...);
}
}
if constexpr (Veclen > 1) {
if (veclen % Veclen != 0) {
return select_interleaved_scan_kernel<T, AccT, IdxT, Capacity, 1>::run(
return select_interleaved_scan_kernel<T, AccT, IdxT, IvfSampleFilterT, Capacity, 1>::run(
capacity, 1, select_min, std::forward<Args>(args)...);
}
}
Expand All @@ -1010,9 +1038,11 @@ struct select_interleaved_scan_kernel {
veclen == Veclen,
"Veclen must be power-of-two not bigger than the maximum allowed size for this data type.");
if (select_min) {
launch_with_fixed_consts<Capacity, Veclen, true, T, AccT, IdxT>(std::forward<Args>(args)...);
launch_with_fixed_consts<Capacity, Veclen, true, T, AccT, IdxT, IvfSampleFilterT>(
std::forward<Args>(args)...);
} else {
launch_with_fixed_consts<Capacity, Veclen, false, T, AccT, IdxT>(std::forward<Args>(args)...);
launch_with_fixed_consts<Capacity, Veclen, false, T, AccT, IdxT, IvfSampleFilterT>(
std::forward<Args>(args)...);
}
}
};
Expand All @@ -1028,6 +1058,9 @@ struct select_interleaved_scan_kernel {
* @param[in] queries device pointer to the query vectors [batch_size, dim]
* @param[in] coarse_query_results device pointer to the cluster (list) ids [batch_size, n_probes]
* @param n_queries batch size
* @param[in] queries_offset
* An offset of the current query batch. It is used for feeding sample_filter with the
* correct query index.
* @param metric type of the measured distance
* @param n_probes number of nearest clusters to query
* @param k number of nearest neighbors.
Expand All @@ -1041,36 +1074,43 @@ struct select_interleaved_scan_kernel {
* @param[inout] grid_dim_x number of blocks launched across all n_probes clusters;
* (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes)
* @param stream
* @param sample_filter
* A filter that selects samples for a given query. Use an instance of none_ivf_sample_filter to
* provide a green light for every sample.
*/
template <typename T, typename AccT, typename IdxT>
template <typename T, typename AccT, typename IdxT, typename IvfSampleFilterT>
void ivfflat_interleaved_scan(const index<T, IdxT>& index,
const T* queries,
const uint32_t* coarse_query_results,
const uint32_t n_queries,
const uint32_t queries_offset,
const raft::distance::DistanceType metric,
const uint32_t n_probes,
const uint32_t k,
const bool select_min,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
{
const int capacity = bound_by_power_of_two(k);
select_interleaved_scan_kernel<T, AccT, IdxT>::run(capacity,
index.veclen(),
select_min,
metric,
index,
queries,
coarse_query_results,
n_queries,
n_probes,
k,
neighbors,
distances,
grid_dim_x,
stream);
select_interleaved_scan_kernel<T, AccT, IdxT, IvfSampleFilterT>::run(capacity,
index.veclen(),
select_min,
metric,
index,
queries,
coarse_query_results,
n_queries,
queries_offset,
n_probes,
k,
sample_filter,
neighbors,
distances,
grid_dim_x,
stream);
}

} // namespace raft::neighbors::ivf_flat::detail
Loading