Skip to content

Commit

Permalink
Cleanup faiss includes (#1098)
Browse files Browse the repository at this point in the history
This mainly removes un-needed faiss includes - but also removes the `KeyValueWarpSelect` function which isn't being used anywhere, and replaces `faiss::gpu::KeyValuePair` with `cub::KeyValuePair` in warp_select_faiss.cuh

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1098
  • Loading branch information
benfred authored Dec 14, 2022
1 parent 6d83c91 commit a9e060d
Show file tree
Hide file tree
Showing 8 changed files with 4 additions and 225 deletions.
6 changes: 0 additions & 6 deletions cpp/include/raft/spatial/knn/detail/ann_quantized.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,10 @@

#include <rmm/cuda_stream_view.hpp>

#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/GpuIndexFlat.h>
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/gpu/GpuIndexIVFPQ.h>
#include <faiss/gpu/GpuIndexIVFScalarQuantizer.h>
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/utils/Limits.cuh>
#include <faiss/gpu/utils/Select.cuh>
#include <faiss/gpu/utils/Tensor.cuh>
#include <faiss/utils/Heap.h>

#include <thrust/iterator/transform_iterator.h>

Expand Down
2 changes: 0 additions & 2 deletions cpp/include/raft/spatial/knn/detail/ball_cover.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@
#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>

#include <faiss/gpu/utils/Limits.cuh>
#include <faiss/gpu/utils/Select.cuh>
#include <faiss/utils/Heap.h>

#include <thrust/fill.h>
#include <thrust/for_each.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

#include <faiss/gpu/utils/Limits.cuh>
#include <faiss/gpu/utils/Select.cuh>
#include <faiss/utils/Heap.h>

#include <thrust/fill.h>

Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/spatial/knn/detail/common_faiss.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <faiss/gpu/GpuDistance.h>
#include <faiss/MetricType.h>
#include <raft/distance/distance_types.hpp>

namespace raft {
Expand Down
3 changes: 0 additions & 3 deletions cpp/include/raft/spatial/knn/detail/haversine_distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/utils/Limits.cuh>
#include <faiss/gpu/utils/Select.cuh>
#include <faiss/utils/Heap.h>

#include <raft/core/handle.hpp>
#include <raft/distance/distance_types.hpp>
Expand Down
2 changes: 0 additions & 2 deletions cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@
#include <rmm/device_uvector.hpp>

#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/utils/Limits.cuh>
#include <faiss/gpu/utils/Select.cuh>
#include <faiss/utils/Heap.h>

#include <cstdint>
#include <iostream>
Expand Down
5 changes: 0 additions & 5 deletions cpp/include/raft/spatial/knn/detail/selection_faiss.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,7 @@
#include <raft/util/cudart_utils.hpp>
#include <raft/util/pow2_utils.cuh>

#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/GpuIndexFlat.h>
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/utils/Limits.cuh>
#include <faiss/gpu/utils/Select.cuh>
#include <faiss/utils/Heap.h>

namespace raft {
namespace spatial {
Expand Down
208 changes: 3 additions & 205 deletions cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -7,51 +7,17 @@

#pragma once

#include <cub/cub.cuh>

#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/MergeNetworkUtils.cuh>
#include <faiss/gpu/utils/PtxUtils.cuh>
#include <faiss/gpu/utils/StaticUtils.h>
#include <faiss/gpu/utils/WarpShuffles.cuh>

#include <raft/core/kvp.hpp>

namespace faiss {
namespace gpu {

template <typename _Key, typename _Value>
struct KeyValuePair {
typedef _Key Key; ///< Key data type
typedef _Value Value; ///< Value data type

Key key; ///< Item key
Value value; ///< Item value

/// Constructor
__host__ __device__ __forceinline__ KeyValuePair() {}

/// Copy Constructors
__host__ __device__ __forceinline__ KeyValuePair(cub::KeyValuePair<_Key, _Value>& kvp)
: key(kvp.key), value(kvp.value)
{
}

__host__ __device__ __forceinline__ KeyValuePair(faiss::gpu::KeyValuePair<_Key, _Value>& kvp)
: key(kvp.key), value(kvp.value)
{
}

/// Constructor
__host__ __device__ __forceinline__ KeyValuePair(Key const& key, Value const& value)
: key(key), value(value)
{
}

/// Inequality operator
__host__ __device__ __forceinline__ bool operator!=(const KeyValuePair& b)
{
return (value != b.value) || (key != b.key);
}
};
using raft::KeyValuePair;

//
// This file contains functions to:
Expand Down Expand Up @@ -591,173 +557,5 @@ inline __device__ void warpSortAnyRegistersKVP(K k[N], KeyValuePair<K, V> v[N])
{
BitonicSortStepKVP<K, V, N, Dir, Comp>::sort(k, v);
}

// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
template <typename K,
typename V,
bool Dir,
typename Comp,
int NumWarpQ,
int NumThreadQ,
int ThreadsPerBlock>
struct KeyValueWarpSelect {
static constexpr int kNumWarpQRegisters = NumWarpQ / faiss::gpu::kWarpSize;

__device__ inline KeyValueWarpSelect(K initKVal, faiss::gpu::KeyValuePair<K, V> initVVal, int k)
: initK(initKVal),
initV(initVVal),
numVals(0),
warpKTop(initKVal),
warpKTopRDist(initKVal),
kLane((k - 1) % faiss::gpu::kWarpSize)
{
static_assert(faiss::gpu::utils::isPowerOf2(ThreadsPerBlock), "threads must be a power-of-2");
static_assert(faiss::gpu::utils::isPowerOf2(NumWarpQ), "warp queue must be power-of-2");

// Fill the per-thread queue keys with the default value
#pragma unroll
for (int i = 0; i < NumThreadQ; ++i) {
threadK[i] = initK;
threadV[i].key = initV.key;
threadV[i].value = initV.value;
}

// Fill the warp queue with the default value
#pragma unroll
for (int i = 0; i < kNumWarpQRegisters; ++i) {
warpK[i] = initK;
warpV[i].key = initV.key;
warpV[i].value = initV.value;
}
}

__device__ inline void addThreadQ(K k, faiss::gpu::KeyValuePair<K, V>& v)
{
if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
// Rotate right
#pragma unroll
for (int i = NumThreadQ - 1; i > 0; --i) {
threadK[i] = threadK[i - 1];
threadV[i].key = threadV[i - 1].key;
threadV[i].value = threadV[i - 1].value;
}

threadK[0] = k;
threadV[0].key = v.key;
threadV[0].value = v.value;
++numVals;
}
}
/// This function handles sorting and merging together the
/// per-thread queues with the warp-wide queue, creating a sorted
/// list across both

// TODO
__device__ inline void mergeWarpQ()
{
// Sort all of the per-thread queues
faiss::gpu::warpSortAnyRegistersKVP<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

// The warp queue is already sorted, and now that we've sorted the
// per-thread queue, merge both sorted lists together, producing
// one sorted list
faiss::gpu::warpMergeAnyRegistersKVP<K, V, kNumWarpQRegisters, NumThreadQ, !Dir, Comp, false>(
warpK, warpV, threadK, threadV);
}

/// WARNING: all threads in a warp must participate in this.
/// Otherwise, you must call the constituent parts separately.
__device__ inline void add(K k, faiss::gpu::KeyValuePair<K, V>& v)
{
addThreadQ(k, v);
checkThreadQ();
}

__device__ inline void reduce()
{
// Have all warps dump and merge their queues; this will produce
// the final per-warp results
mergeWarpQ();
}

__device__ inline void checkThreadQ()
{
bool needSort = (numVals == NumThreadQ);

#if CUDA_VERSION >= 9000
needSort = __any_sync(0xffffffff, needSort);
#else
needSort = __any(needSort);
#endif

if (!needSort) {
// no lanes have triggered a sort
return;
}

mergeWarpQ();

// Any top-k elements have been merged into the warp queue; we're
// free to reset the thread queues
numVals = 0;

#pragma unroll
for (int i = 0; i < NumThreadQ; ++i) {
threadK[i] = initK;
threadV[i].key = initV.key;
threadV[i].value = initV.value;
}

// We have to beat at least this element
warpKTopRDist = shfl(warpV[kNumWarpQRegisters - 1].key, kLane);
warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane);
}

/// Dump final k selected values for this warp out
__device__ inline void writeOut(K* outK, V* outV, int k)
{
int laneId = faiss::gpu::getLaneId();

#pragma unroll
for (int i = 0; i < kNumWarpQRegisters; ++i) {
int idx = i * faiss::gpu::kWarpSize + laneId;

if (idx < k) {
outK[idx] = warpK[i];
outV[idx] = warpV[i].value;
}
}
}

// Default element key
const K initK;

// Default element value
const faiss::gpu::KeyValuePair<K, V> initV;

// Number of valid elements in our thread queue
int numVals;

// The k-th highest (Dir) or lowest (!Dir) element
K warpKTop;

// TopK's distance to closest landmark
K warpKTopRDist;

// Thread queue values
K threadK[NumThreadQ];
faiss::gpu::KeyValuePair<K, V> threadV[NumThreadQ];

// warpK[0] is highest (Dir) or lowest (!Dir)
K warpK[kNumWarpQRegisters];
faiss::gpu::KeyValuePair<K, V> warpV[kNumWarpQRegisters];

// This is what lane we should load an approximation (>=k) to the
// kth element from the last register in the warp queue (i.e.,
// warpK[kNumWarpQRegisters - 1]).
int kLane;
};

} // namespace gpu
} // namespace faiss

0 comments on commit a9e060d

Please sign in to comment.