Fix IVF-PQ fused kernel performance problems #1726

Merged 4 commits on Aug 11, 2023

Changes from all commits
123 changes: 89 additions & 34 deletions cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh
@@ -43,21 +43,39 @@ static_assert((kMaxCapacity >= 32) && !(kMaxCapacity & (kMaxCapacity - 1)),
auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries)
-> bool
{
-  if (k > kMaxCapacity) { return false; }            // warp_sort not possible
-  if (n_probes <= 16) { return false; }              // too few clusters
-  if (n_queries * n_probes <= 256) { return false; } // overall amount of work is too small
+  if (k > kMaxCapacity) { return false; }            // warp_sort not possible
+  if (n_queries * n_probes <= 16) { return false; }  // overall amount of work is too small
return true;
}
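
The relaxed heuristic above drops the old `n_probes <= 16` rejection and lowers the total-work threshold from 256 to 16, so small batches probing few clusters can still take the fused per-cluster top-k path. A standalone sketch of the new behavior (assuming kMaxCapacity == 128, as defined earlier in this file):

#include <cstdint>
#include <cstdio>

constexpr uint32_t kMaxCapacity = 128;  // assumption for this sketch

// Copy of the new heuristic, for illustration only.
bool is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries)
{
  if (k > kMaxCapacity) { return false; }            // warp_sort not possible
  if (n_queries * n_probes <= 16) { return false; }  // overall amount of work is too small
  return true;
}

int main()
{
  // One query over 20 probed clusters: rejected before (1 * 20 <= 256),
  // accepted now (1 * 20 > 16).
  std::printf("%d\n", is_local_topk_feasible(10, 20, 1));  // prints 1
}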

template <int Capacity, typename T, typename IdxT>
struct pq_block_sort {
-  using type = matrix::detail::select::warpsort::
-    block_sort<matrix::detail::select::warpsort::warp_sort_distributed, Capacity, true, T, IdxT>;
+  using type = matrix::detail::select::warpsort::block_sort<
+    matrix::detail::select::warpsort::warp_sort_distributed_ext,
+    Capacity,
+    true,
+    T,
+    IdxT>;
+
+  static auto get_mem_required(uint32_t k_max)
+  {
+    if (k_max == 0 || k_max > Capacity) {
+      return pq_block_sort<0, T, IdxT>::get_mem_required(k_max);
+    }
+    if constexpr (Capacity > 1) {
+      if (k_max * 2 <= Capacity) {
+        return pq_block_sort<(Capacity / 2), T, IdxT>::get_mem_required(k_max);
+      }
+    }
+    return type::queue_t::mem_required;
+  }
};

template <typename T, typename IdxT>
struct pq_block_sort<0, T, IdxT> : dummy_block_sort_t<T, IdxT> {
using type = dummy_block_sort_t<T, IdxT>;
+  static auto mem_required(uint32_t) -> size_t { return 0; }
+  static auto get_mem_required(uint32_t) { return mem_required; }
};
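
The new get_mem_required walks the Capacity template recursion at runtime: it finds the smallest power-of-two capacity that still holds k_max (the same capacity the kernel dispatch will instantiate) and returns that variant's mem_required function pointer, so the host can later query the exact scratch size for any block size. A self-contained sketch of this dispatch pattern, with invented byte counts standing in for type::queue_t::mem_required:

#include <cstddef>
#include <cstdint>
#include <cstdio>

template <int Capacity>
struct sort_mem {
  // Stand-in for type::queue_t::mem_required; the formula is invented.
  static size_t mem_required(uint32_t n_threads)
  {
    return static_cast<size_t>(Capacity) * (n_threads / 32) * sizeof(float);
  }
  static auto get_mem_required(uint32_t k_max) -> size_t (*)(uint32_t)
  {
    if (k_max == 0 || k_max > Capacity) { return sort_mem<0>::get_mem_required(k_max); }
    if constexpr (Capacity > 1) {
      if (k_max * 2 <= Capacity) { return sort_mem<Capacity / 2>::get_mem_required(k_max); }
    }
    return mem_required;
  }
};

template <>
struct sort_mem<0> {
  static size_t mem_required(uint32_t) { return 0; }
  static auto get_mem_required(uint32_t) -> size_t (*)(uint32_t) { return mem_required; }
};

int main()
{
  auto f = sort_mem<128>::get_mem_required(20);  // resolves to the Capacity == 32 variant
  std::printf("%zu\n", f(512));                  // scratch bytes for a 512-thread block
}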

template <int Capacity, typename T, typename IdxT>
@@ -212,7 +230,7 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim,
* [n_clusters, dim].
* @param pq_centers
*   The device pointer to the cluster centers in the PQ space
- *   [pq_dim, pq_book_size, pq_len] or [n_clusters, pq_book_size, pq_len,].
+ *   [pq_dim, pq_book_size, pq_len] or [n_clusters, pq_book_size, pq_len].
* @param pq_dataset
*   The device pointer to the PQ index (data) [n_rows, ...].
* @param cluster_labels
@@ -275,7 +293,9 @@ __global__ void compute_similarity_kernel(uint32_t dim,
/* Shared memory:

* lut_scores: lookup table (LUT) of size = `pq_dim << PqBits` (when EnableSMemLut)
-   * base_diff: size = dim (which is equal to `pq_dim * pq_len`) or dim*2
+   * lut_end+:
+   *   base_diff: size = dim (which is equal to `pq_dim * pq_len`) or dim*2
+   *   topk::warp_sort::mem_required - local topk temporary buffer (if necessary)
* topk::block_sort: some amount of shared memory, but overlaps with the rest:
block_sort only needs shared memory for `.done()` operation, which can come very last.
*/
@@ -294,13 +314,11 @@ __global__ void compute_similarity_kernel(uint32_t dim,
lut_scores += lut_size * blockIdx.x;
}

-  float* base_diff = nullptr;
-  if constexpr (PrecompBaseDiff) {
-    if constexpr (EnableSMemLut) {
-      base_diff = reinterpret_cast<float*>(lut_scores + lut_size);
-    } else {
-      base_diff = reinterpret_cast<float*>(smem_buf);
-    }
+  uint8_t* lut_end = nullptr;
+  if constexpr (EnableSMemLut) {
+    lut_end = reinterpret_cast<uint8_t*>(lut_scores + lut_size);
+  } else {
+    lut_end = smem_buf;
}

for (int ib = blockIdx.x; ib < n_queries * n_probes; ib += gridDim.x) {
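
With the restructuring above, the dedicated base_diff pointer is gone: the kernel keeps a single lut_end cursor just past the LUT, and both the precomputed base_diff values and the local top-k scratch are laid over those same bytes, since the two uses happen in different phases. A minimal CUDA sketch of this carve-and-reuse pattern (all names and sizes here are invented for illustration):

#include <cstdint>

__global__ void carve_smem(float* out, uint32_t lut_size, uint32_t dim)
{
  extern __shared__ uint8_t smem_buf[];
  uint8_t* lut_scores = smem_buf;               // the LUT (1-byte entries in this sketch)
  uint8_t* lut_end    = lut_scores + lut_size;  // everything past the LUT is scratch
  // Phase 1: view the scratch as float and fill the base_diff values.
  float* base_diff = reinterpret_cast<float*>(lut_end);
  for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) {
    base_diff[i] = static_cast<float>(i);
  }
  __syncthreads();  // make phase-1 writes visible before repurposing the region
  // Phase 2: the very same bytes now back the top-k queue scratch.
  uint8_t* topk_scratch = lut_end;
  if (threadIdx.x == 0) { out[blockIdx.x] = reinterpret_cast<float*>(topk_scratch)[0]; }
}

This lines up with the block_topk queue construction moving below __syncthreads() later in this diff: it now receives lut_end as its scratch buffer instead of nullptr, which is only safe to hand over once the LUT writes are fenced.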
@@ -347,15 +365,15 @@ __global__ void compute_similarity_kernel(uint32_t dim,
case distance::DistanceType::L2SqrtExpanded:
case distance::DistanceType::L2Expanded: {
for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) {
-          base_diff[i] = query[i] - cluster_center[i];
+          reinterpret_cast<float*>(lut_end)[i] = query[i] - cluster_center[i];
}
} break;
case distance::DistanceType::InnerProduct: {
float2 pvals;
for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) {
-          pvals.x = query[i];
-          pvals.y = cluster_center[i] * pvals.x;
-          reinterpret_cast<float2*>(base_diff)[i] = pvals;
+          pvals.x = query[i];
+          pvals.y = cluster_center[i] * pvals.x;
+          reinterpret_cast<float2*>(lut_end)[i] = pvals;
}
} break;
default: __builtin_unreachable();
@@ -382,7 +400,7 @@ __global__ void compute_similarity_kernel(uint32_t dim,
case distance::DistanceType::L2Expanded: {
float diff;
if constexpr (PrecompBaseDiff) {
-          diff = base_diff[j];
+          diff = reinterpret_cast<float*>(lut_end)[j];
} else {
diff = query[j] - cluster_center[j];
}
@@ -393,7 +411,7 @@ __global__ void compute_similarity_kernel(uint32_t dim,
// NB: we negate the scores as we hardcoded select-topk to always compute the minimum
float q;
if constexpr (PrecompBaseDiff) {
-          float2 pvals = reinterpret_cast<float2*>(base_diff)[j];
+          float2 pvals = reinterpret_cast<float2*>(lut_end)[j];
q = pvals.x;
score -= pvals.y;
} else {
@@ -438,7 +456,6 @@ __global__ void compute_similarity_kernel(uint32_t dim,
constexpr OutT kDummy = upper_bound<OutT>();
OutT query_kth = kDummy;
if constexpr (kManageLocalTopK) { query_kth = OutT(query_kths[query_ix]); }
-  local_topk_t block_topk(topk, nullptr, query_kth);
OutT early_stop_limit = kDummy;
switch (metric) {
// If the metric is non-negative, we can use the query_kth approximation as an early stop
@@ -453,6 +470,7 @@ __global__ void compute_similarity_kernel(uint32_t dim,
// Ensure lut_scores is written by all threads before using it in ivfpq-compute-score
__threadfence_block();
__syncthreads();
+  local_topk_t block_topk(topk, lut_end, query_kth);

// Compute a distance for each sample
for (uint32_t i = threadIdx.x; i < n_samples_aligned;
@@ -680,13 +698,31 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
// Shared memory for storing pre-computed pieces to speedup the lookup table construction
// (e.g. the distance between a cluster center and the query for L2).
size_t bdf_mem = sizeof(float) * precomp_data_count;
-  // Shared memory for the fused top-k component; it may overlap with the other uses of shared
-  // memory and depends on the number of threads.
-  struct ltk_mem_t {
+
+  // Shared memory used by the fused top-k during cluster scanning;
+  // may overlap with the precomputed distance array
+  struct ltk_add_mem_t {
+    size_t (*mem_required)(uint32_t);
+
+    ltk_add_mem_t(bool manage_local_topk, uint32_t topk)
+      : mem_required(pq_block_sort<kMaxCapacity, OutT, uint32_t>::get_mem_required(
+          manage_local_topk ? topk : 0))
+    {
+    }
+
+    [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t
+    {
+      return mem_required(n_threads);
+    }
+  } ltk_add_mem{manage_local_topk, topk};
+
+  // Shared memory for the fused top-k component;
+  // may overlap with all other uses of shared memory
+  struct ltk_reduce_mem_t {
uint32_t subwarp_size;
uint32_t topk;
bool manage_local_topk;
-    ltk_mem_t(bool manage_local_topk, uint32_t topk)
+    ltk_reduce_mem_t(bool manage_local_topk, uint32_t topk)
: manage_local_topk(manage_local_topk), topk(topk)
{
subwarp_size = WarpSize;
@@ -703,7 +739,19 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
n_threads / subwarp_size, topk)
: 0;
}
-  } ltk_mem{manage_local_topk, topk};
+  } ltk_reduce_mem{manage_local_topk, topk};
+
+  struct total_shared_mem_t {
+    ltk_add_mem_t& ltk_add_mem;
+    ltk_reduce_mem_t& ltk_reduce_mem;
+    size_t lut_mem;
+    size_t bdf_mem;
+    [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t
+    {
+      return std::max(ltk_reduce_mem(n_threads),
+                      lut_mem + std::max(bdf_mem, ltk_add_mem(n_threads)));
+    }
+  };
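
total_shared_mem_t gathers the overlap rules in one place: the reduction-phase top-k buffer (ltk_reduce_mem) may overlap all other shared memory, while the scanning-phase buffer (ltk_add_mem) only shares space with the precomputed base_diff array, so it sits on top of the LUT. A runnable sketch of the same composition, with invented byte counts:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main()
{
  size_t lut_mem = 8192;  // lookup table: (pq_dim << pq_bits) entries
  size_t bdf_mem = 512;   // precomputed base_diff values
  // Stand-ins for ltk_add_mem / ltk_reduce_mem; the real sizes come from
  // warp_sort's mem_required and grow with the block size.
  auto ltk_add_mem    = [](unsigned n_threads) -> size_t { return n_threads; };
  auto ltk_reduce_mem = [](unsigned n_threads) -> size_t { return size_t{4} * n_threads; };
  for (unsigned n_threads : {256u, 512u, 1024u}) {
    size_t total = std::max(ltk_reduce_mem(n_threads),
                            lut_mem + std::max(bdf_mem, ltk_add_mem(n_threads)));
    std::printf("n_threads=%u -> %zu bytes\n", n_threads, total);
  }
}

Each candidate configuration below then carries one of these callables instead of a constant byte count, because the footprint can only be finalized once the number of threads per block is known.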

// Total amount of work; should be enough to occupy the GPU.
uint32_t n_blocks = n_queries * n_probes;
@@ -749,17 +797,24 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
auto conf_no_basediff = get_compute_similarity_kernel<OutT, LutT, false, true, IvfSampleFilterT>;
auto conf_no_smem_lut = get_compute_similarity_kernel<OutT, LutT, true, false, IvfSampleFilterT>;
auto topk_or_zero = manage_local_topk ? topk : 0u;
-  std::array candidates{std::make_tuple(conf_fast(pq_bits, topk_or_zero), lut_mem + bdf_mem, true),
-                        std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), lut_mem, true),
-                        std::make_tuple(conf_no_smem_lut(pq_bits, topk_or_zero), bdf_mem, false)};
+  std::array candidates{
+    std::make_tuple(conf_fast(pq_bits, topk_or_zero),
+                    total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, lut_mem, bdf_mem},
+                    true),
+    std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero),
+                    total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, lut_mem, 0},
+                    true),
+    std::make_tuple(conf_no_smem_lut(pq_bits, topk_or_zero),
+                    total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, 0, bdf_mem},
+                    false)};

// we may allow slightly lower than 100% occupancy;
constexpr double kTargetOccupancy = 0.75;
// This struct is used to select the better candidate
occupancy_t<OutT, LutT, IvfSampleFilterT> selected_perf{};
selected<OutT, LutT, IvfSampleFilterT> selected_config;
-  for (auto [kernel, smem_size_const, lut_is_in_shmem] : candidates) {
-    if (smem_size_const > dev_props.sharedMemPerBlockOptin) {
+  for (auto [kernel, smem_size_f, lut_is_in_shmem] : candidates) {
+    if (smem_size_f(WarpSize) > dev_props.sharedMemPerBlockOptin) {
// Even a single block cannot fit into an SM due to shmem requirements. Skip the candidate.
continue;
}
@@ -770,7 +825,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
// launch configuration, we will tighten the carveout once more, based on the final memory
// usage and occupancy.
const int max_carveout =
-      estimate_carveout(preferred_shmem_carveout, smem_size_const, dev_props);
+      estimate_carveout(preferred_shmem_carveout, smem_size_f(WarpSize), dev_props);
RAFT_CUDA_TRY(
cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, max_carveout));

@@ -780,7 +835,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
uint32_t n_threads = round_down_safe<uint32_t>(kernel_attrs.maxThreadsPerBlock, n_threads_gty);

// Actual required shmem depends on the number of threads
-    size_t smem_size = max(smem_size_const, ltk_mem(n_threads));
+    size_t smem_size = smem_size_f(n_threads);

// Make sure the kernel can get enough shmem.
cudaError_t cuda_status =
@@ -807,7 +862,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
}
if (n_threads_tmp < n_threads) {
while (n_threads_tmp >= n_threads_min) {
-        auto smem_size_tmp = max(smem_size_const, ltk_mem(n_threads_tmp));
+        auto smem_size_tmp = smem_size_f(n_threads_tmp);
occupancy_t<OutT, LutT, IvfSampleFilterT> tmp(
smem_size_tmp, n_threads_tmp, kernel, dev_props);
bool select_it = false;