Fix IVF-PQ fused kernel performance problems (#1726)
Fix occasional slowdowns of the compute-similarity kernels:

  1. Allow using the fused version for small-work-size cases.
  2. Switch to the warpsort implementation that uses fewer registers at the expense of shared memory.

Addresses (at least partially) #1621

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1726
achirkin authored Aug 11, 2023
1 parent bfe0fb5 commit b3bb21a
Showing 1 changed file with 89 additions and 34 deletions.

cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh (+89 −34)
@@ -43,21 +43,39 @@ static_assert((kMaxCapacity >= 32) && !(kMaxCapacity & (kMaxCapacity - 1)),
 auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries)
   -> bool
 {
-  if (k > kMaxCapacity) { return false; }            // warp_sort not possible
-  if (n_probes <= 16) { return false; }              // too few clusters
-  if (n_queries * n_probes <= 256) { return false; } // overall amount of work is too small
+  if (k > kMaxCapacity) { return false; }           // warp_sort not possible
+  if (n_queries * n_probes <= 16) { return false; } // overall amount of work is too small
   return true;
 }
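This hunk is the first fix from the commit message: previously the fused kernel was rejected whenever `n_probes <= 16` or the total work `n_queries * n_probes` was at most 256; now only a near-empty workload disqualifies it. A standalone sketch of the before/after logic (assuming `kMaxCapacity = 128` purely for illustration):

```cpp
#include <cstdint>
#include <cstdio>

constexpr uint32_t kMaxCapacity = 128;  // assumed value, for illustration only

// Old heuristic: rejected the fused kernel for most small workloads.
bool feasible_old(uint32_t k, uint32_t n_probes, uint32_t n_queries)
{
  if (k > kMaxCapacity) { return false; }
  if (n_probes <= 16) { return false; }
  if (n_queries * n_probes <= 256) { return false; }
  return true;
}

// New heuristic: only a truly tiny total amount of work disqualifies it.
bool feasible_new(uint32_t k, uint32_t n_probes, uint32_t n_queries)
{
  if (k > kMaxCapacity) { return false; }
  if (n_queries * n_probes <= 16) { return false; }
  return true;
}

int main()
{
  // One query probing 20 clusters: the old rule fell back to the slower
  // non-fused path, the new rule keeps the fused kernel.
  printf("old: %d, new: %d\n", feasible_old(10, 20, 1), feasible_new(10, 20, 1));
  return 0;
}
```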

 template <int Capacity, typename T, typename IdxT>
 struct pq_block_sort {
-  using type = matrix::detail::select::warpsort::
-    block_sort<matrix::detail::select::warpsort::warp_sort_distributed, Capacity, true, T, IdxT>;
+  using type = matrix::detail::select::warpsort::block_sort<
+    matrix::detail::select::warpsort::warp_sort_distributed_ext,
+    Capacity,
+    true,
+    T,
+    IdxT>;
+
+  static auto get_mem_required(uint32_t k_max)
+  {
+    if (k_max == 0 || k_max > Capacity) {
+      return pq_block_sort<0, T, IdxT>::get_mem_required(k_max);
+    }
+    if constexpr (Capacity > 1) {
+      if (k_max * 2 <= Capacity) {
+        return pq_block_sort<(Capacity / 2), T, IdxT>::get_mem_required(k_max);
+      }
+    }
+    return type::queue_t::mem_required;
+  }
 };
 
 template <typename T, typename IdxT>
 struct pq_block_sort<0, T, IdxT> : dummy_block_sort_t<T, IdxT> {
   using type = dummy_block_sort_t<T, IdxT>;
+  static auto mem_required(uint32_t) -> size_t { return 0; }
+  static auto get_mem_required(uint32_t) { return mem_required; }
 };
 
 template <int Capacity, typename T, typename IdxT>
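`get_mem_required` returns the `mem_required` function pointer of the queue that would actually be instantiated for a given `k_max`, walking down the power-of-two capacity ladder the same way the kernel's Capacity dispatch does. A standalone analog of that walk, not the library code, as a worked example:

```cpp
#include <cstdint>
#include <cstdio>

// Standalone analog of pq_block_sort<Capacity>::get_mem_required: walk down
// the power-of-two capacity ladder until Capacity is the smallest one that
// still holds k_max, mirroring the template recursion in the hunk above.
uint32_t effective_capacity(uint32_t capacity, uint32_t k_max)
{
  if (k_max == 0 || k_max > capacity) return 0;  // dummy (no-op) sort
  while (capacity > 1 && k_max * 2 <= capacity) capacity /= 2;
  return capacity;
}

int main()
{
  // For k_max = 20 and a 128-wide maximum queue, the memory estimate comes
  // from the Capacity = 32 queue: 20*2 <= 128 -> 64, 20*2 <= 64 -> 32, stop.
  printf("%u\n", effective_capacity(128, 20));  // prints 32
  return 0;
}
```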
@@ -212,7 +230,7 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim,
  *   [n_clusters, dim].
  * @param pq_centers
  *   The device pointer to the cluster centers in the PQ space
- *   [pq_dim, pq_book_size, pq_len] or [n_clusters, pq_book_size, pq_len,].
+ *   [pq_dim, pq_book_size, pq_len] or [n_clusters, pq_book_size, pq_len].
  * @param pq_dataset
  *   The device pointer to the PQ index (data) [n_rows, ...].
  * @param cluster_labels
@@ -275,7 +293,9 @@ __global__ void compute_similarity_kernel(uint32_t dim,
   /* Shared memory:
    * lut_scores: lookup table (LUT) of size = `pq_dim << PqBits` (when EnableSMemLut)
-   * base_diff: size = dim (which is equal to `pq_dim * pq_len`) or dim*2
+   * lut_end+:
+   *   base_diff: size = dim (which is equal to `pq_dim * pq_len`) or dim*2
+   *   topk::warp_sort::mem_required - local topk temporary buffer (if necessary)
    * topk::block_sort: some amount of shared memory, but overlaps with the rest:
        block_sort only needs shared memory for `.done()` operation, which can come very last.
    */
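The updated comment captures the space-saving idea behind the second fix: everything after the LUT (`lut_end`) is time-shared between the `base_diff` precomputation and the warpsort scratch buffer, so switching to the shared-memory-hungry queue costs a max() rather than a sum(). A minimal CUDA sketch of that overlap, with hypothetical parameters (not the kernel's real code):

```cuda
#include <cstdint>

__global__ void smem_overlap_sketch(uint32_t lut_size_bytes, uint32_t dim)
{
  extern __shared__ uint8_t smem_buf[];
  float* lut_scores = reinterpret_cast<float*>(smem_buf);
  uint8_t* lut_end  = smem_buf + lut_size_bytes;  // everything after the LUT

  // Phase 1: the bytes after the LUT hold base_diff while the LUT is built.
  float* base_diff = reinterpret_cast<float*>(lut_end);
  for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { base_diff[i] = 0.0f; }
  __syncthreads();

  // Phase 2: the same bytes become the local top-k scratch buffer; base_diff
  // is dead from here on, so the two uses cost only max(), not sum().
  uint8_t* topk_scratch = lut_end;
  (void)lut_scores;
  (void)topk_scratch;
}
```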
@@ -294,13 +314,11 @@ __global__ void compute_similarity_kernel(uint32_t dim,
     lut_scores += lut_size * blockIdx.x;
   }
 
-  float* base_diff = nullptr;
-  if constexpr (PrecompBaseDiff) {
-    if constexpr (EnableSMemLut) {
-      base_diff = reinterpret_cast<float*>(lut_scores + lut_size);
-    } else {
-      base_diff = reinterpret_cast<float*>(smem_buf);
-    }
-  }
+  uint8_t* lut_end = nullptr;
+  if constexpr (EnableSMemLut) {
+    lut_end = reinterpret_cast<uint8_t*>(lut_scores + lut_size);
+  } else {
+    lut_end = smem_buf;
+  }
 
   for (int ib = blockIdx.x; ib < n_queries * n_probes; ib += gridDim.x) {
@@ -347,15 +365,15 @@ __global__ void compute_similarity_kernel(uint32_t dim,
       case distance::DistanceType::L2SqrtExpanded:
       case distance::DistanceType::L2Expanded: {
         for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) {
-          base_diff[i] = query[i] - cluster_center[i];
+          reinterpret_cast<float*>(lut_end)[i] = query[i] - cluster_center[i];
         }
       } break;
       case distance::DistanceType::InnerProduct: {
         float2 pvals;
         for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) {
-          pvals.x = query[i];
-          pvals.y = cluster_center[i] * pvals.x;
-          reinterpret_cast<float2*>(base_diff)[i] = pvals;
+          pvals.x                               = query[i];
+          pvals.y                               = cluster_center[i] * pvals.x;
+          reinterpret_cast<float2*>(lut_end)[i] = pvals;
         }
       } break;
       default: __builtin_unreachable();
@@ -382,7 +400,7 @@ __global__ void compute_similarity_kernel(uint32_t dim,
         case distance::DistanceType::L2Expanded: {
           float diff;
           if constexpr (PrecompBaseDiff) {
-            diff = base_diff[j];
+            diff = reinterpret_cast<float*>(lut_end)[j];
           } else {
             diff = query[j] - cluster_center[j];
           }
@@ -393,7 +411,7 @@ __global__ void compute_similarity_kernel(uint32_t dim,
           // NB: we negate the scores as we hardcoded select-topk to always compute the minimum
           float q;
           if constexpr (PrecompBaseDiff) {
-            float2 pvals = reinterpret_cast<float2*>(base_diff)[j];
+            float2 pvals = reinterpret_cast<float2*>(lut_end)[j];
             q = pvals.x;
             score -= pvals.y;
           } else {
@@ -438,7 +456,6 @@ __global__ void compute_similarity_kernel(uint32_t dim,
     constexpr OutT kDummy = upper_bound<OutT>();
     OutT query_kth = kDummy;
     if constexpr (kManageLocalTopK) { query_kth = OutT(query_kths[query_ix]); }
-    local_topk_t block_topk(topk, nullptr, query_kth);
     OutT early_stop_limit = kDummy;
     switch (metric) {
       // If the metric is non-negative, we can use the query_kth approximation as an early stop
@@ -453,6 +470,7 @@ __global__ void compute_similarity_kernel(uint32_t dim,
     // Ensure lut_scores is written by all threads before using it in ivfpq-compute-score
     __threadfence_block();
     __syncthreads();
+    local_topk_t block_topk(topk, lut_end, query_kth);
 
     // Compute a distance for each sample
     for (uint32_t i = threadIdx.x; i < n_samples_aligned;
@@ -680,13 +698,31 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
   // Shared memory for storing pre-computed pieces to speedup the lookup table construction
   // (e.g. the distance between a cluster center and the query for L2).
   size_t bdf_mem = sizeof(float) * precomp_data_count;
-  // Shared memory for the fused top-k component; it may overlap with the other uses of shared
-  // memory and depends on the number of threads.
-  struct ltk_mem_t {
+
+  // Shared memory used by the fused top-k during cluster scanning;
+  // may overlap with the precomputed distance array
+  struct ltk_add_mem_t {
+    size_t (*mem_required)(uint32_t);
+
+    ltk_add_mem_t(bool manage_local_topk, uint32_t topk)
+      : mem_required(pq_block_sort<kMaxCapacity, OutT, uint32_t>::get_mem_required(
+          manage_local_topk ? topk : 0))
+    {
+    }
+
+    [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t
+    {
+      return mem_required(n_threads);
+    }
+  } ltk_add_mem{manage_local_topk, topk};
+
+  // Shared memory for the fused top-k component;
+  // may overlap with all other uses of shared memory
+  struct ltk_reduce_mem_t {
     uint32_t subwarp_size;
     uint32_t topk;
     bool manage_local_topk;
-    ltk_mem_t(bool manage_local_topk, uint32_t topk)
+    ltk_reduce_mem_t(bool manage_local_topk, uint32_t topk)
       : manage_local_topk(manage_local_topk), topk(topk)
     {
       subwarp_size = WarpSize;
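`ltk_add_mem_t` resolves the `get_mem_required` dispatch once, at construction time, storing a plain function pointer that can later be evaluated for any candidate block size. A self-contained sketch of the same pattern, with made-up sizes and a hypothetical 32-element cutoff:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Standalone sketch of the ltk_add_mem_t pattern: resolve the mem_required
// function once, based on k, then query it later for any thread count.
std::size_t mem_required_cap32(uint32_t n_threads) { return 8 * n_threads; }
std::size_t mem_required_dummy(uint32_t) { return 0; }

struct ltk_add_mem_t {
  std::size_t (*mem_required)(uint32_t);

  ltk_add_mem_t(bool manage_local_topk, uint32_t topk)
    : mem_required(manage_local_topk && topk <= 32 ? mem_required_cap32 : mem_required_dummy)
  {
  }

  std::size_t operator()(uint32_t n_threads) const { return mem_required(n_threads); }
};

int main()
{
  ltk_add_mem_t ltk_add_mem{true, 20};
  std::printf("%zu\n", ltk_add_mem(512));  // 4096 B with the made-up cost
  return 0;
}
```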
@@ -703,7 +739,19 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
                                     n_threads / subwarp_size, topk)
                : 0;
     }
-  } ltk_mem{manage_local_topk, topk};
+  } ltk_reduce_mem{manage_local_topk, topk};
+
+  struct total_shared_mem_t {
+    ltk_add_mem_t& ltk_add_mem;
+    ltk_reduce_mem_t& ltk_reduce_mem;
+    size_t lut_mem;
+    size_t bdf_mem;
+    [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t
+    {
+      return std::max(ltk_reduce_mem(n_threads),
+                      lut_mem + std::max(bdf_mem, ltk_add_mem(n_threads)));
+    }
+  };
 
   // Total amount of work; should be enough to occupy the GPU.
   uint32_t n_blocks = n_queries * n_probes;
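`total_shared_mem_t` encodes the overlap rules stated in the kernel's shared-memory comment: the cross-warp reduction may overlap everything, while the scan-time top-k buffer shares space only with the precomputed `base_diff` region. A standalone arithmetic check with hypothetical sizes:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

// Standalone arithmetic check of total_shared_mem_t, with hypothetical sizes:
// an 8 KiB LUT, 512 B of precomputed base_diff, a 2 KiB scan-time top-k
// buffer, and a 4 KiB cross-warp reduction buffer.
int main()
{
  std::size_t lut_mem = 8192, bdf_mem = 512;
  std::size_t ltk_add = 2048, ltk_reduce = 4096;
  // The reduction may overlap everything; the scan buffer overlaps only the
  // base_diff region, which sits right after the LUT.
  std::size_t total = std::max(ltk_reduce, lut_mem + std::max(bdf_mem, ltk_add));
  std::printf("%zu\n", total);  // 8192 + 2048 = 10240 bytes
  return 0;
}
```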
@@ -749,17 +797,24 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
   auto conf_no_basediff = get_compute_similarity_kernel<OutT, LutT, false, true, IvfSampleFilterT>;
   auto conf_no_smem_lut = get_compute_similarity_kernel<OutT, LutT, true, false, IvfSampleFilterT>;
   auto topk_or_zero = manage_local_topk ? topk : 0u;
-  std::array candidates{std::make_tuple(conf_fast(pq_bits, topk_or_zero), lut_mem + bdf_mem, true),
-                        std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), lut_mem, true),
-                        std::make_tuple(conf_no_smem_lut(pq_bits, topk_or_zero), bdf_mem, false)};
+  std::array candidates{
+    std::make_tuple(conf_fast(pq_bits, topk_or_zero),
+                    total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, lut_mem, bdf_mem},
+                    true),
+    std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero),
+                    total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, lut_mem, 0},
+                    true),
+    std::make_tuple(conf_no_smem_lut(pq_bits, topk_or_zero),
+                    total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, 0, bdf_mem},
+                    false)};
 
   // we may allow slightly lower than 100% occupancy;
   constexpr double kTargetOccupancy = 0.75;
   // This struct is used to select the better candidate
   occupancy_t<OutT, LutT, IvfSampleFilterT> selected_perf{};
   selected<OutT, LutT, IvfSampleFilterT> selected_config;
-  for (auto [kernel, smem_size_const, lut_is_in_shmem] : candidates) {
-    if (smem_size_const > dev_props.sharedMemPerBlockOptin) {
+  for (auto [kernel, smem_size_f, lut_is_in_shmem] : candidates) {
+    if (smem_size_f(WarpSize) > dev_props.sharedMemPerBlockOptin) {
       // Even a single block cannot fit into an SM due to shmem requirements. Skip the candidate.
       continue;
     }
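Each candidate configuration now carries its shared-memory requirement as a function of the block size rather than a precomputed constant, so the selection loop can re-evaluate the budget while tuning the thread count down. A hypothetical sketch of that selection pattern (the sizes and the 48 KiB limit are stand-ins):

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <tuple>

int main()
{
  constexpr uint32_t kWarpSize = 32;
  const std::size_t smem_per_block_optin = 48 * 1024;  // stand-in device limit

  using smem_fn = std::size_t (*)(uint32_t);
  std::array<std::tuple<const char*, smem_fn>, 2> candidates{
    std::make_tuple("fast", +[](uint32_t n_threads) -> std::size_t {
      return 40 * 1024 + 16 * n_threads;  // LUT + per-thread scratch (made up)
    }),
    std::make_tuple("no_smem_lut", +[](uint32_t n_threads) -> std::size_t {
      return 16 * n_threads;  // made up
    })};

  for (auto [name, smem_size_f] : candidates) {
    if (smem_size_f(kWarpSize) > smem_per_block_optin) { continue; }  // cannot fit even one warp
    for (uint32_t n_threads = 1024; n_threads >= kWarpSize; n_threads /= 2) {
      if (smem_size_f(n_threads) <= smem_per_block_optin) {
        std::printf("%s fits with %u threads (%zu B)\n", name, n_threads, smem_size_f(n_threads));
        break;
      }
    }
  }
  return 0;
}
```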
@@ -770,7 +825,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
     // launch configuration, we will tighten the carveout once more, based on the final memory
     // usage and occupancy.
     const int max_carveout =
-      estimate_carveout(preferred_shmem_carveout, smem_size_const, dev_props);
+      estimate_carveout(preferred_shmem_carveout, smem_size_f(WarpSize), dev_props);
     RAFT_CUDA_TRY(
       cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, max_carveout));
@@ -780,7 +835,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
     uint32_t n_threads = round_down_safe<uint32_t>(kernel_attrs.maxThreadsPerBlock, n_threads_gty);
 
     // Actual required shmem depends on the number of threads
-    size_t smem_size = max(smem_size_const, ltk_mem(n_threads));
+    size_t smem_size = smem_size_f(n_threads);
 
     // Make sure the kernel can get enough shmem.
     cudaError_t cuda_status =
@@ -807,7 +862,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
     }
     if (n_threads_tmp < n_threads) {
       while (n_threads_tmp >= n_threads_min) {
-        auto smem_size_tmp = max(smem_size_const, ltk_mem(n_threads_tmp));
+        auto smem_size_tmp = smem_size_f(n_threads_tmp);
         occupancy_t<OutT, LutT, IvfSampleFilterT> tmp(
           smem_size_tmp, n_threads_tmp, kernel, dev_props);
         bool select_it = false;
