From b3bb21ae08161ed1fe94a8627b5c7ff482337578 Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Fri, 11 Aug 2023 03:22:54 +0200 Subject: [PATCH] Fix IVF-PQ fused kernel performance problems (#1726) Fix occasional slowdown of the compute similarity kernels: 1. Allow using the fused version for small work size cases 2. Switch to the warpsort implementation that uses less registers at the expense of shared memory. Addresses (at least partially) https://github.com/rapidsai/raft/issues/1621 Authors: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1726 --- .../detail/ivf_pq_compute_similarity-inl.cuh | 123 +++++++++++++----- 1 file changed, 89 insertions(+), 34 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh index 90d993abd5..2ab216b13b 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh @@ -43,21 +43,39 @@ static_assert((kMaxCapacity >= 32) && !(kMaxCapacity & (kMaxCapacity - 1)), auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) -> bool { - if (k > kMaxCapacity) { return false; } // warp_sort not possible - if (n_probes <= 16) { return false; } // too few clusters - if (n_queries * n_probes <= 256) { return false; } // overall amount of work is too small + if (k > kMaxCapacity) { return false; } // warp_sort not possible + if (n_queries * n_probes <= 16) { return false; } // overall amount of work is too small return true; } template struct pq_block_sort { - using type = matrix::detail::select::warpsort:: - block_sort; + using type = matrix::detail::select::warpsort::block_sort< + matrix::detail::select::warpsort::warp_sort_distributed_ext, + Capacity, + true, + T, + IdxT>; + + static auto get_mem_required(uint32_t k_max) + { + if (k_max == 0 || k_max > Capacity) { + return pq_block_sort<0, T, IdxT>::get_mem_required(k_max); + } + if constexpr (Capacity > 1) { + if (k_max * 2 <= Capacity) { + return pq_block_sort<(Capacity / 2), T, IdxT>::get_mem_required(k_max); + } + } + return type::queue_t::mem_required; + } }; template struct pq_block_sort<0, T, IdxT> : dummy_block_sort_t { using type = dummy_block_sort_t; + static auto mem_required(uint32_t) -> size_t { return 0; } + static auto get_mem_required(uint32_t) { return mem_required; } }; template @@ -212,7 +230,7 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim, * [n_clusters, dim]. * @param pq_centers * The device pointer to the cluster centers in the PQ space - * [pq_dim, pq_book_size, pq_len] or [n_clusters, pq_book_size, pq_len,]. + * [pq_dim, pq_book_size, pq_len] or [n_clusters, pq_book_size, pq_len]. * @param pq_dataset * The device pointer to the PQ index (data) [n_rows, ...]. * @param cluster_labels @@ -275,7 +293,9 @@ __global__ void compute_similarity_kernel(uint32_t dim, /* Shared memory: * lut_scores: lookup table (LUT) of size = `pq_dim << PqBits` (when EnableSMemLut) - * base_diff: size = dim (which is equal to `pq_dim * pq_len`) or dim*2 + * lut_end+: + * base_diff: size = dim (which is equal to `pq_dim * pq_len`) or dim*2 + * topk::warp_sort::mem_required - local topk temporary buffer (if necessary) * topk::block_sort: some amount of shared memory, but overlaps with the rest: block_sort only needs shared memory for `.done()` operation, which can come very last. */ @@ -294,13 +314,11 @@ __global__ void compute_similarity_kernel(uint32_t dim, lut_scores += lut_size * blockIdx.x; } - float* base_diff = nullptr; - if constexpr (PrecompBaseDiff) { - if constexpr (EnableSMemLut) { - base_diff = reinterpret_cast(lut_scores + lut_size); - } else { - base_diff = reinterpret_cast(smem_buf); - } + uint8_t* lut_end = nullptr; + if constexpr (EnableSMemLut) { + lut_end = reinterpret_cast(lut_scores + lut_size); + } else { + lut_end = smem_buf; } for (int ib = blockIdx.x; ib < n_queries * n_probes; ib += gridDim.x) { @@ -347,15 +365,15 @@ __global__ void compute_similarity_kernel(uint32_t dim, case distance::DistanceType::L2SqrtExpanded: case distance::DistanceType::L2Expanded: { for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { - base_diff[i] = query[i] - cluster_center[i]; + reinterpret_cast(lut_end)[i] = query[i] - cluster_center[i]; } } break; case distance::DistanceType::InnerProduct: { float2 pvals; for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { - pvals.x = query[i]; - pvals.y = cluster_center[i] * pvals.x; - reinterpret_cast(base_diff)[i] = pvals; + pvals.x = query[i]; + pvals.y = cluster_center[i] * pvals.x; + reinterpret_cast(lut_end)[i] = pvals; } } break; default: __builtin_unreachable(); @@ -382,7 +400,7 @@ __global__ void compute_similarity_kernel(uint32_t dim, case distance::DistanceType::L2Expanded: { float diff; if constexpr (PrecompBaseDiff) { - diff = base_diff[j]; + diff = reinterpret_cast(lut_end)[j]; } else { diff = query[j] - cluster_center[j]; } @@ -393,7 +411,7 @@ __global__ void compute_similarity_kernel(uint32_t dim, // NB: we negate the scores as we hardcoded select-topk to always compute the minimum float q; if constexpr (PrecompBaseDiff) { - float2 pvals = reinterpret_cast(base_diff)[j]; + float2 pvals = reinterpret_cast(lut_end)[j]; q = pvals.x; score -= pvals.y; } else { @@ -438,7 +456,6 @@ __global__ void compute_similarity_kernel(uint32_t dim, constexpr OutT kDummy = upper_bound(); OutT query_kth = kDummy; if constexpr (kManageLocalTopK) { query_kth = OutT(query_kths[query_ix]); } - local_topk_t block_topk(topk, nullptr, query_kth); OutT early_stop_limit = kDummy; switch (metric) { // If the metric is non-negative, we can use the query_kth approximation as an early stop @@ -453,6 +470,7 @@ __global__ void compute_similarity_kernel(uint32_t dim, // Ensure lut_scores is written by all threads before using it in ivfpq-compute-score __threadfence_block(); __syncthreads(); + local_topk_t block_topk(topk, lut_end, query_kth); // Compute a distance for each sample for (uint32_t i = threadIdx.x; i < n_samples_aligned; @@ -680,13 +698,31 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, // Shared memory for storing pre-computed pieces to speedup the lookup table construction // (e.g. the distance between a cluster center and the query for L2). size_t bdf_mem = sizeof(float) * precomp_data_count; - // Shared memory for the fused top-k component; it may overlap with the other uses of shared - // memory and depends on the number of threads. - struct ltk_mem_t { + + // Shared memory used by the fused top-k during cluster scanning; + // may overlap with the precomputed distance array + struct ltk_add_mem_t { + size_t (*mem_required)(uint32_t); + + ltk_add_mem_t(bool manage_local_topk, uint32_t topk) + : mem_required(pq_block_sort::get_mem_required( + manage_local_topk ? topk : 0)) + { + } + + [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t + { + return mem_required(n_threads); + } + } ltk_add_mem{manage_local_topk, topk}; + + // Shared memory for the fused top-k component; + // may overlap with all other uses of shared memory + struct ltk_reduce_mem_t { uint32_t subwarp_size; uint32_t topk; bool manage_local_topk; - ltk_mem_t(bool manage_local_topk, uint32_t topk) + ltk_reduce_mem_t(bool manage_local_topk, uint32_t topk) : manage_local_topk(manage_local_topk), topk(topk) { subwarp_size = WarpSize; @@ -703,7 +739,19 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, n_threads / subwarp_size, topk) : 0; } - } ltk_mem{manage_local_topk, topk}; + } ltk_reduce_mem{manage_local_topk, topk}; + + struct total_shared_mem_t { + ltk_add_mem_t& ltk_add_mem; + ltk_reduce_mem_t& ltk_reduce_mem; + size_t lut_mem; + size_t bdf_mem; + [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t + { + return std::max(ltk_reduce_mem(n_threads), + lut_mem + std::max(bdf_mem, ltk_add_mem(n_threads))); + } + }; // Total amount of work; should be enough to occupy the GPU. uint32_t n_blocks = n_queries * n_probes; @@ -749,17 +797,24 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, auto conf_no_basediff = get_compute_similarity_kernel; auto conf_no_smem_lut = get_compute_similarity_kernel; auto topk_or_zero = manage_local_topk ? topk : 0u; - std::array candidates{std::make_tuple(conf_fast(pq_bits, topk_or_zero), lut_mem + bdf_mem, true), - std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), lut_mem, true), - std::make_tuple(conf_no_smem_lut(pq_bits, topk_or_zero), bdf_mem, false)}; + std::array candidates{ + std::make_tuple(conf_fast(pq_bits, topk_or_zero), + total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, lut_mem, bdf_mem}, + true), + std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), + total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, lut_mem, 0}, + true), + std::make_tuple(conf_no_smem_lut(pq_bits, topk_or_zero), + total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, 0, bdf_mem}, + false)}; // we may allow slightly lower than 100% occupancy; constexpr double kTargetOccupancy = 0.75; // This struct is used to select the better candidate occupancy_t selected_perf{}; selected selected_config; - for (auto [kernel, smem_size_const, lut_is_in_shmem] : candidates) { - if (smem_size_const > dev_props.sharedMemPerBlockOptin) { + for (auto [kernel, smem_size_f, lut_is_in_shmem] : candidates) { + if (smem_size_f(WarpSize) > dev_props.sharedMemPerBlockOptin) { // Even a single block cannot fit into an SM due to shmem requirements. Skip the candidate. continue; } @@ -770,7 +825,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, // launch configuration, we will tighten the carveout once more, based on the final memory // usage and occupancy. const int max_carveout = - estimate_carveout(preferred_shmem_carveout, smem_size_const, dev_props); + estimate_carveout(preferred_shmem_carveout, smem_size_f(WarpSize), dev_props); RAFT_CUDA_TRY( cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, max_carveout)); @@ -780,7 +835,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, uint32_t n_threads = round_down_safe(kernel_attrs.maxThreadsPerBlock, n_threads_gty); // Actual required shmem depens on the number of threads - size_t smem_size = max(smem_size_const, ltk_mem(n_threads)); + size_t smem_size = smem_size_f(n_threads); // Make sure the kernel can get enough shmem. cudaError_t cuda_status = @@ -807,7 +862,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, } if (n_threads_tmp < n_threads) { while (n_threads_tmp >= n_threads_min) { - auto smem_size_tmp = max(smem_size_const, ltk_mem(n_threads_tmp)); + auto smem_size_tmp = smem_size_f(n_threads_tmp); occupancy_t tmp( smem_size_tmp, n_threads_tmp, kernel, dev_props); bool select_it = false;