From a9e060d354917114bb8baa5de9c55ef917f203af Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 14 Dec 2022 01:41:13 -0800 Subject: [PATCH] Cleanup faiss includes (#1098) This mainly removes unneeded faiss includes - but also removes the `KeyValueWarpSelect` struct which isn't being used anywhere, and replaces `faiss::gpu::KeyValuePair` with `raft::KeyValuePair` in warp_select_faiss.cuh Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1098 --- .../raft/spatial/knn/detail/ann_quantized.cuh | 6 - .../raft/spatial/knn/detail/ball_cover.cuh | 2 - .../knn/detail/ball_cover/registers.cuh | 1 - .../raft/spatial/knn/detail/common_faiss.h | 2 +- .../spatial/knn/detail/haversine_distance.cuh | 3 - .../knn/detail/knn_brute_force_faiss.cuh | 2 - .../spatial/knn/detail/selection_faiss.cuh | 5 - .../spatial/knn/detail/warp_select_faiss.cuh | 208 +----------------- 8 files changed, 4 insertions(+), 225 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index e5900ffd69..e6c0b83161 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -32,16 +32,10 @@ #include -#include #include #include #include #include -#include -#include -#include -#include -#include #include diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 32a8f0ed33..797dbaab50 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -38,9 +38,7 @@ #include #include -#include #include -#include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index 9c5307e683..a883a1eadd 100644 --- 
a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -30,7 +30,6 @@ #include #include -#include #include diff --git a/cpp/include/raft/spatial/knn/detail/common_faiss.h b/cpp/include/raft/spatial/knn/detail/common_faiss.h index b098d0991d..57076350f0 100644 --- a/cpp/include/raft/spatial/knn/detail/common_faiss.h +++ b/cpp/include/raft/spatial/knn/detail/common_faiss.h @@ -19,7 +19,7 @@ #include #include -#include +#include #include namespace raft { diff --git a/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh b/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh index 5c03f8f67c..333fc1c573 100644 --- a/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh +++ b/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh @@ -19,11 +19,8 @@ #include #include -#include -#include #include #include -#include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index 0c33c3f38f..086cae1089 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -23,10 +23,8 @@ #include #include -#include #include #include -#include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh index 239379aad5..27c7e006ca 100644 --- a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh @@ -20,12 +20,7 @@ #include #include -#include -#include -#include -#include #include -#include namespace raft { namespace spatial { diff --git a/cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh b/cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh index b6ffbd5122..2ce2d34cca 100644 --- a/cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh +++ 
b/cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh @@ -7,51 +7,17 @@ #pragma once -#include - #include #include #include #include #include +#include + namespace faiss { namespace gpu { - -template -struct KeyValuePair { - typedef _Key Key; ///< Key data type - typedef _Value Value; ///< Value data type - - Key key; ///< Item key - Value value; ///< Item value - - /// Constructor - __host__ __device__ __forceinline__ KeyValuePair() {} - - /// Copy Constructors - __host__ __device__ __forceinline__ KeyValuePair(cub::KeyValuePair<_Key, _Value>& kvp) - : key(kvp.key), value(kvp.value) - { - } - - __host__ __device__ __forceinline__ KeyValuePair(faiss::gpu::KeyValuePair<_Key, _Value>& kvp) - : key(kvp.key), value(kvp.value) - { - } - - /// Constructor - __host__ __device__ __forceinline__ KeyValuePair(Key const& key, Value const& value) - : key(key), value(value) - { - } - - /// Inequality operator - __host__ __device__ __forceinline__ bool operator!=(const KeyValuePair& b) - { - return (value != b.value) || (key != b.key); - } -}; +using raft::KeyValuePair; // // This file contains functions to: @@ -591,173 +557,5 @@ inline __device__ void warpSortAnyRegistersKVP(K k[N], KeyValuePair v[N]) { BitonicSortStepKVP::sort(k, v); } - -// `Dir` true, produce largest values. -// `Dir` false, produce smallest values. 
-template -struct KeyValueWarpSelect { - static constexpr int kNumWarpQRegisters = NumWarpQ / faiss::gpu::kWarpSize; - - __device__ inline KeyValueWarpSelect(K initKVal, faiss::gpu::KeyValuePair initVVal, int k) - : initK(initKVal), - initV(initVVal), - numVals(0), - warpKTop(initKVal), - warpKTopRDist(initKVal), - kLane((k - 1) % faiss::gpu::kWarpSize) - { - static_assert(faiss::gpu::utils::isPowerOf2(ThreadsPerBlock), "threads must be a power-of-2"); - static_assert(faiss::gpu::utils::isPowerOf2(NumWarpQ), "warp queue must be power-of-2"); - - // Fill the per-thread queue keys with the default value -#pragma unroll - for (int i = 0; i < NumThreadQ; ++i) { - threadK[i] = initK; - threadV[i].key = initV.key; - threadV[i].value = initV.value; - } - - // Fill the warp queue with the default value -#pragma unroll - for (int i = 0; i < kNumWarpQRegisters; ++i) { - warpK[i] = initK; - warpV[i].key = initV.key; - warpV[i].value = initV.value; - } - } - - __device__ inline void addThreadQ(K k, faiss::gpu::KeyValuePair& v) - { - if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) { - // Rotate right -#pragma unroll - for (int i = NumThreadQ - 1; i > 0; --i) { - threadK[i] = threadK[i - 1]; - threadV[i].key = threadV[i - 1].key; - threadV[i].value = threadV[i - 1].value; - } - - threadK[0] = k; - threadV[0].key = v.key; - threadV[0].value = v.value; - ++numVals; - } - } - /// This function handles sorting and merging together the - /// per-thread queues with the warp-wide queue, creating a sorted - /// list across both - - // TODO - __device__ inline void mergeWarpQ() - { - // Sort all of the per-thread queues - faiss::gpu::warpSortAnyRegistersKVP(threadK, threadV); - - // The warp queue is already sorted, and now that we've sorted the - // per-thread queue, merge both sorted lists together, producing - // one sorted list - faiss::gpu::warpMergeAnyRegistersKVP( - warpK, warpV, threadK, threadV); - } - - /// WARNING: all threads in a warp must participate in this. 
- /// Otherwise, you must call the constituent parts separately. - __device__ inline void add(K k, faiss::gpu::KeyValuePair& v) - { - addThreadQ(k, v); - checkThreadQ(); - } - - __device__ inline void reduce() - { - // Have all warps dump and merge their queues; this will produce - // the final per-warp results - mergeWarpQ(); - } - - __device__ inline void checkThreadQ() - { - bool needSort = (numVals == NumThreadQ); - -#if CUDA_VERSION >= 9000 - needSort = __any_sync(0xffffffff, needSort); -#else - needSort = __any(needSort); -#endif - - if (!needSort) { - // no lanes have triggered a sort - return; - } - - mergeWarpQ(); - - // Any top-k elements have been merged into the warp queue; we're - // free to reset the thread queues - numVals = 0; - -#pragma unroll - for (int i = 0; i < NumThreadQ; ++i) { - threadK[i] = initK; - threadV[i].key = initV.key; - threadV[i].value = initV.value; - } - - // We have to beat at least this element - warpKTopRDist = shfl(warpV[kNumWarpQRegisters - 1].key, kLane); - warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane); - } - - /// Dump final k selected values for this warp out - __device__ inline void writeOut(K* outK, V* outV, int k) - { - int laneId = faiss::gpu::getLaneId(); - -#pragma unroll - for (int i = 0; i < kNumWarpQRegisters; ++i) { - int idx = i * faiss::gpu::kWarpSize + laneId; - - if (idx < k) { - outK[idx] = warpK[i]; - outV[idx] = warpV[i].value; - } - } - } - - // Default element key - const K initK; - - // Default element value - const faiss::gpu::KeyValuePair initV; - - // Number of valid elements in our thread queue - int numVals; - - // The k-th highest (Dir) or lowest (!Dir) element - K warpKTop; - - // TopK's distance to closest landmark - K warpKTopRDist; - - // Thread queue values - K threadK[NumThreadQ]; - faiss::gpu::KeyValuePair threadV[NumThreadQ]; - - // warpK[0] is highest (Dir) or lowest (!Dir) - K warpK[kNumWarpQRegisters]; - faiss::gpu::KeyValuePair warpV[kNumWarpQRegisters]; - - // This is what 
lane we should load an approximation (>=k) to the - // kth element from the last register in the warp queue (i.e., - // warpK[kNumWarpQRegisters - 1]). - int kLane; -}; - } // namespace gpu } // namespace faiss