From 3eab24b5704779290d3ac10fc22de85cad222807 Mon Sep 17 00:00:00 2001 From: achirkin Date: Fri, 11 Mar 2022 12:10:49 +0100 Subject: [PATCH] Allow different types in select-k functions (float/double, int/size_t) --- cpp/bench/spatial/selection.cu | 13 ++- cpp/include/raft/cudart_utils.h | 18 +++- .../knn/detail/ivf_flat/radix_topk.cuh | 13 +-- .../knn/detail/ivf_flat/warpsort_topk.cuh | 47 ++--------- .../spatial/knn/detail/selection_faiss.cuh | 16 ++-- cpp/test/spatial/selection.cu | 83 ++++++++++++------- 6 files changed, 106 insertions(+), 84 deletions(-) diff --git a/cpp/bench/spatial/selection.cu b/cpp/bench/spatial/selection.cu index fbe65c2dbc..314a6b007e 100644 --- a/cpp/bench/spatial/selection.cu +++ b/cpp/bench/spatial/selection.cu @@ -109,9 +109,14 @@ SELECTION_REGISTER(float, int, RADIX_8_BITS); SELECTION_REGISTER(float, int, RADIX_11_BITS); SELECTION_REGISTER(float, int, WARP_SORT); -// SELECTION_REGISTER(double, int, FAISS); -// SELECTION_REGISTER(double, int, RADIX_8_BITS); -// SELECTION_REGISTER(double, int, RADIX_11_BITS); -// SELECTION_REGISTER(double, int, WARP_SORT); +SELECTION_REGISTER(double, int, FAISS); +SELECTION_REGISTER(double, int, RADIX_8_BITS); +SELECTION_REGISTER(double, int, RADIX_11_BITS); +SELECTION_REGISTER(double, int, WARP_SORT); + +SELECTION_REGISTER(double, size_t, FAISS); +SELECTION_REGISTER(double, size_t, RADIX_8_BITS); +SELECTION_REGISTER(double, size_t, RADIX_11_BITS); +SELECTION_REGISTER(double, size_t, WARP_SORT); } // namespace raft::bench::spatial diff --git a/cpp/include/raft/cudart_utils.h b/cpp/include/raft/cudart_utils.h index 936065afba..1b278576ca 100644 --- a/cpp/include/raft/cudart_utils.h +++ b/cpp/include/raft/cudart_utils.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -430,4 +430,20 @@ IntType gcd(IntType a, IntType b) return a; } +template +constexpr T lower_bound() +{ + if constexpr (std::numeric_limits::has_infinity && std::numeric_limits::is_signed) { + return -std::numeric_limits::infinity(); + } + return std::numeric_limits::lowest(); +} + +template +constexpr T upper_bound() +{ + if constexpr (std::numeric_limits::has_infinity) { return std::numeric_limits::infinity(); } + return std::numeric_limits::max(); +} + } // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh index 31d4f7a009..e88ac602f7 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh @@ -22,6 +22,7 @@ #include #include +#include /* Two implementations: @@ -176,7 +177,7 @@ __device__ void filter_and_histogram(const T* in_buf, if (pass == 0) { auto f = [greater, start_bit, mask](T value, IdxT) { int bucket = calc_bucket(value, start_bit, mask, greater); - atomicAdd(histogram_smem + bucket, 1); + atomicAdd(histogram_smem + bucket, IdxT(1)); }; vectorized_process(in_buf, len, f); } else { @@ -208,11 +209,11 @@ __device__ void filter_and_histogram(const T* in_buf, int prev_bucket = calc_bucket(value, previous_start_bit, previous_mask, greater); if (prev_bucket == want_bucket) { - IdxT pos = atomicAdd(&filter_cnt, 1); + IdxT pos = atomicAdd(&filter_cnt, IdxT(1)); out_buf[pos] = value; if (out_idx_buf) { out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; } int bucket = calc_bucket(value, start_bit, mask, greater); - atomicAdd(histogram_smem + bucket, 1); + atomicAdd(histogram_smem + bucket, IdxT(1)); if (counter_len == 1) { if (out) { @@ -223,7 +224,7 @@ __device__ void filter_and_histogram(const T* in_buf, } } } else if (out && prev_bucket < want_bucket) { - IdxT pos = atomicAdd(&out_cnt, 1); + IdxT pos = atomicAdd(&out_cnt, IdxT(1)); out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } @@ -385,12 +386,12 @@ __global__ void radix_kernel(const T* in_buf, const T value = out_buf[i]; int bucket = calc_bucket(value, start_bit, mask, greater); if (bucket < want_bucket) { - IdxT pos = atomicAdd(&out_cnt, 1); + IdxT pos = atomicAdd(&out_cnt, IdxT(1)); out[pos] = value; out_idx[pos] = out_idx_buf[i]; } else if (bucket == want_bucket) { IdxT needed_num_of_kth = counter->k; - IdxT back_pos = atomicAdd(&(counter->out_back_cnt), 1); + IdxT back_pos = atomicAdd(&(counter->out_back_cnt), IdxT(1)); if (back_pos < needed_num_of_kth) { IdxT pos = k - 1 - back_pos; out[pos] = value; diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat/warpsort_topk.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat/warpsort_topk.cuh index 47221b92bb..3d8835c254 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat/warpsort_topk.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat/warpsort_topk.cuh @@ -121,32 +121,6 @@ static constexpr int kMaxCapacity = 512; namespace { -template -constexpr T get_lower_bound() -{ - if (std::numeric_limits::has_infinity && std::numeric_limits::is_signed) { - return -std::numeric_limits::infinity(); - } else { - return std::numeric_limits::lowest(); - } -} - -template -constexpr T get_upper_bound() -{ - if (std::numeric_limits::has_infinity) { - return std::numeric_limits::infinity(); - } else { - return std::numeric_limits::max(); - } -} - -template -constexpr T get_dummy(bool greater) -{ - return greater ? get_lower_bound() : get_upper_bound(); -} - template __device__ inline bool is_greater_than(T val, T baseline) { @@ -217,7 +191,7 @@ class WarpSort { template class WarpSelect : public WarpSort { public: - __device__ WarpSelect(IdxT k, T dummy) + __device__ WarpSelect(int k, T dummy) : WarpSort(k, dummy), buf_len_(0), k_th_(dummy), @@ -335,7 +309,7 @@ class WarpSelect : public WarpSort { template class WarpBitonic : public WarpSort { public: - __device__ WarpBitonic(IdxT k, T dummy) + __device__ WarpBitonic(int k, T dummy) : WarpSort(k, dummy), buf_len_(0) { for (int i = 0; i < max_arr_len_; ++i) { @@ -465,7 +439,7 @@ class WarpBitonic : public WarpSort { template class WarpMerge : public WarpSort { public: - __device__ WarpMerge(IdxT k, T dummy) : WarpSort(k, dummy) {} + __device__ WarpMerge(int k, T dummy) : WarpSort(k, dummy) {} __device__ void add(const T* in, const IdxT* in_idx, IdxT start, IdxT end) { @@ -507,7 +481,7 @@ template