From b564d1e6ce7358b6889abd7acd1bcf603ab45c9f Mon Sep 17 00:00:00 2001
From: achirkin
Date: Tue, 8 Aug 2023 18:10:18 +0200
Subject: [PATCH] Fix ivf-pq fused kernel perf problems

---
 .../detail/ivf_pq_compute_similarity-inl.cuh  | 123 +++++++++++++-----
 1 file changed, 89 insertions(+), 34 deletions(-)

diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh
index 90d993abd5..2ab216b13b 100644
--- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh
+++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh
@@ -43,21 +43,39 @@ static_assert((kMaxCapacity >= 32) && !(kMaxCapacity & (kMaxCapacity - 1)),
 auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries)
   -> bool
 {
-  if (k > kMaxCapacity) { return false; }             // warp_sort not possible
-  if (n_probes <= 16) { return false; }               // too few clusters
-  if (n_queries * n_probes <= 256) { return false; }  // overall amount of work is too small
+  if (k > kMaxCapacity) { return false; }            // warp_sort not possible
+  if (n_queries * n_probes <= 16) { return false; }  // overall amount of work is too small
   return true;
 }
 
 template <int Capacity, typename T, typename IdxT>
 struct pq_block_sort {
-  using type = matrix::detail::select::warpsort::
-    block_sort<matrix::detail::select::warpsort::warp_sort_distributed, Capacity, true, T, IdxT>;
+  using type = matrix::detail::select::warpsort::block_sort<
+    matrix::detail::select::warpsort::warp_sort_distributed_ext,
+    Capacity,
+    true,
+    T,
+    IdxT>;
+
+  static auto get_mem_required(uint32_t k_max)
+  {
+    if (k_max == 0 || k_max > Capacity) {
+      return pq_block_sort<0, T, IdxT>::get_mem_required(k_max);
+    }
+    if constexpr (Capacity > 1) {
+      if (k_max * 2 <= Capacity) {
+        return pq_block_sort<(Capacity / 2), T, IdxT>::get_mem_required(k_max);
+      }
+    }
+    return type::queue_t::mem_required;
+  }
 };
 
 template <typename T, typename IdxT>
 struct pq_block_sort<0, T, IdxT> : dummy_block_sort_t<T, IdxT> {
   using type = dummy_block_sort_t<T, IdxT>;
+  static auto mem_required(uint32_t) -> size_t { return 0; }
+  static auto get_mem_required(uint32_t) { return mem_required; }
 };
 
 template <int Capacity, typename T, typename IdxT>
@@ -212,7 +230,7 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim,
  *   [n_clusters, dim].
  * @param pq_centers
  *   The device pointer to the cluster centers in the PQ space
- *   [pq_dim, pq_book_size, pq_len] or [n_clusters, pq_book_size, pq_len,].
+ *   [pq_dim, pq_book_size, pq_len] or [n_clusters, pq_book_size, pq_len].
  * @param pq_dataset
  *   The device pointer to the PQ index (data) [n_rows, ...].
  * @param cluster_labels
@@ -275,7 +293,9 @@ __global__ void compute_similarity_kernel(uint32_t dim,
   /* Shared memory:
 
     * lut_scores: lookup table (LUT) of size = `pq_dim << PqBits`  (when EnableSMemLut)
-    * base_diff: size = dim (which is equal to `pq_dim * pq_len`)  or dim*2
+    * lut_end+:
+         * base_diff: size = dim (which is equal to `pq_dim * pq_len`)  or dim*2
+         * topk::warp_sort::mem_required - local topk temporary buffer (if necessary)
     * topk::block_sort: some amount of shared memory, but overlaps with the rest:
         block_sort only needs shared memory for `.done()` operation, which can come very last.
   */
@@ -294,13 +314,11 @@ __global__ void compute_similarity_kernel(uint32_t dim,
     lut_scores += lut_size * blockIdx.x;
   }
 
-  float* base_diff = nullptr;
-  if constexpr (PrecompBaseDiff) {
-    if constexpr (EnableSMemLut) {
-      base_diff = reinterpret_cast<float*>(lut_scores + lut_size);
-    } else {
-      base_diff = reinterpret_cast<float*>(smem_buf);
-    }
+  uint8_t* lut_end = nullptr;
+  if constexpr (EnableSMemLut) {
+    lut_end = reinterpret_cast<uint8_t*>(lut_scores + lut_size);
+  } else {
+    lut_end = smem_buf;
   }
 
   for (int ib = blockIdx.x; ib < n_queries * n_probes; ib += gridDim.x) {
@@ -347,15 +365,15 @@ __global__ void compute_similarity_kernel(uint32_t dim,
       case distance::DistanceType::L2SqrtExpanded:
       case distance::DistanceType::L2Expanded: {
         for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) {
-          base_diff[i] = query[i] - cluster_center[i];
+          reinterpret_cast<float*>(lut_end)[i] = query[i] - cluster_center[i];
         }
       } break;
       case distance::DistanceType::InnerProduct: {
         float2 pvals;
         for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) {
-          pvals.x                                 = query[i];
-          pvals.y                                 = cluster_center[i] * pvals.x;
-          reinterpret_cast<float2*>(base_diff)[i] = pvals;
+          pvals.x                               = query[i];
+          pvals.y                               = cluster_center[i] * pvals.x;
+          reinterpret_cast<float2*>(lut_end)[i] = pvals;
         }
       } break;
       default: __builtin_unreachable();
@@ -382,7 +400,7 @@ __global__ void compute_similarity_kernel(uint32_t dim,
         case distance::DistanceType::L2Expanded: {
           float diff;
           if constexpr (PrecompBaseDiff) {
-            diff = base_diff[j];
+            diff = reinterpret_cast<float*>(lut_end)[j];
           } else {
             diff = query[j] - cluster_center[j];
           }
@@ -393,7 +411,7 @@ __global__ void compute_similarity_kernel(uint32_t dim,
           // NB: we negate the scores as we hardcoded select-topk to always compute the minimum
           float q;
           if constexpr (PrecompBaseDiff) {
-            float2 pvals = reinterpret_cast<float2*>(base_diff)[j];
+            float2 pvals = reinterpret_cast<float2*>(lut_end)[j];
             q            = pvals.x;
             score -= pvals.y;
           } else {
@@ -438,7 +456,6 @@ __global__ void compute_similarity_kernel(uint32_t dim,
   constexpr OutT kDummy = upper_bound<OutT>();
   OutT query_kth        = kDummy;
   if constexpr (kManageLocalTopK) { query_kth = OutT(query_kths[query_ix]); }
-  local_topk_t block_topk(topk, nullptr, query_kth);
   OutT early_stop_limit = kDummy;
   switch (metric) {
     // If the metric is non-negative, we can use the query_kth approximation as an early stop
@@ -453,6 +470,7 @@ __global__ void compute_similarity_kernel(uint32_t dim,
   // Ensure lut_scores is written by all threads before using it in ivfpq-compute-score
   __threadfence_block();
   __syncthreads();
+  local_topk_t block_topk(topk, lut_end, query_kth);
 
   // Compute a distance for each sample
   for (uint32_t i = threadIdx.x; i < n_samples_aligned;
@@ -680,13 +698,31 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
   // Shared memory for storing pre-computed pieces to speedup the lookup table construction
   // (e.g. the distance between a cluster center and the query for L2).
   size_t bdf_mem = sizeof(float) * precomp_data_count;
-  // Shared memory for the fused top-k component; it may overlap with the other uses of shared
-  // memory and depends on the number of threads.
-  struct ltk_mem_t {
+
+  // Shared memory used by the fused top-k during cluster scanning;
+  // may overlap with the precomputed distance array
+  struct ltk_add_mem_t {
+    size_t (*mem_required)(uint32_t);
+
+    ltk_add_mem_t(bool manage_local_topk, uint32_t topk)
+      : mem_required(pq_block_sort<kMaxCapacity, OutT, IdxT>::get_mem_required(
+          manage_local_topk ? topk : 0))
+    {
+    }
+
+    [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t
+    {
+      return mem_required(n_threads);
+    }
+  } ltk_add_mem{manage_local_topk, topk};
+
+  // Shared memory for the fused top-k component;
+  // may overlap with all other uses of shared memory
+  struct ltk_reduce_mem_t {
     uint32_t subwarp_size;
     uint32_t topk;
     bool manage_local_topk;
-    ltk_mem_t(bool manage_local_topk, uint32_t topk)
+    ltk_reduce_mem_t(bool manage_local_topk, uint32_t topk)
       : manage_local_topk(manage_local_topk), topk(topk)
     {
       subwarp_size = WarpSize;
@@ -703,7 +739,19 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
                    n_threads / subwarp_size, topk)
                : 0;
     }
-  } ltk_mem{manage_local_topk, topk};
+  } ltk_reduce_mem{manage_local_topk, topk};
+
+  struct total_shared_mem_t {
+    ltk_add_mem_t& ltk_add_mem;
+    ltk_reduce_mem_t& ltk_reduce_mem;
+    size_t lut_mem;
+    size_t bdf_mem;
+    [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t
+    {
+      return std::max(ltk_reduce_mem(n_threads),
+                      lut_mem + std::max(bdf_mem, ltk_add_mem(n_threads)));
+    }
+  };
 
   // Total amount of work; should be enough to occupy the GPU.
   uint32_t n_blocks = n_queries * n_probes;
@@ -749,17 +797,24 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
   auto conf_no_basediff = get_compute_similarity_kernel<OutT, LutT, IvfSampleFilterT, false, true>;
   auto conf_no_smem_lut = get_compute_similarity_kernel<OutT, LutT, IvfSampleFilterT, true, false>;
   auto topk_or_zero     = manage_local_topk ? topk : 0u;
-  std::array candidates{std::make_tuple(conf_fast(pq_bits, topk_or_zero), lut_mem + bdf_mem, true),
-                        std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), lut_mem, true),
-                        std::make_tuple(conf_no_smem_lut(pq_bits, topk_or_zero), bdf_mem, false)};
+  std::array candidates{
+    std::make_tuple(conf_fast(pq_bits, topk_or_zero),
+                    total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, lut_mem, bdf_mem},
+                    true),
+    std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero),
+                    total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, lut_mem, 0},
+                    true),
+    std::make_tuple(conf_no_smem_lut(pq_bits, topk_or_zero),
+                    total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, 0, bdf_mem},
+                    false)};
 
   // we may allow slightly lower than 100% occupancy;
   constexpr double kTargetOccupancy = 0.75;
   // This struct is used to select the better candidate
   occupancy_t<OutT, LutT, IvfSampleFilterT> selected_perf{};
   selected<OutT, LutT, IvfSampleFilterT> selected_config;
-  for (auto [kernel, smem_size_const, lut_is_in_shmem] : candidates) {
-    if (smem_size_const > dev_props.sharedMemPerBlockOptin) {
+  for (auto [kernel, smem_size_f, lut_is_in_shmem] : candidates) {
+    if (smem_size_f(WarpSize) > dev_props.sharedMemPerBlockOptin) {
       // Even a single block cannot fit into an SM due to shmem requirements. Skip the candidate.
       continue;
     }
@@ -770,7 +825,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
     // launch configuration, we will tighten the carveout once more, based on the final memory
    // usage and occupancy.
     const int max_carveout =
-      estimate_carveout(preferred_shmem_carveout, smem_size_const, dev_props);
+      estimate_carveout(preferred_shmem_carveout, smem_size_f(WarpSize), dev_props);
     RAFT_CUDA_TRY(
       cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, max_carveout));
@@ -780,7 +835,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
     uint32_t n_threads = round_down_safe(kernel_attrs.maxThreadsPerBlock, n_threads_gty);
 
     // Actual required shmem depends on the number of threads
-    size_t smem_size = max(smem_size_const, ltk_mem(n_threads));
+    size_t smem_size = smem_size_f(n_threads);
 
     // Make sure the kernel can get enough shmem.
     cudaError_t cuda_status =
@@ -807,7 +862,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
     }
     if (n_threads_tmp < n_threads) {
       while (n_threads_tmp >= n_threads_min) {
-        auto smem_size_tmp = max(smem_size_const, ltk_mem(n_threads_tmp));
+        auto smem_size_tmp = smem_size_f(n_threads_tmp);
         occupancy_t<OutT, LutT, IvfSampleFilterT> tmp(
           smem_size_tmp, n_threads_tmp, kernel, dev_props);
         bool select_it = false;
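
The `get_mem_required` recursion in the first hunk is worth spelling out: it descends the power-of-two capacities at compile time until it reaches the smallest one that still holds `k_max`, and it returns a pointer to that specialization's memory-requirement function rather than a number, so the cost can be evaluated later for any block size. Below is a minimal standalone sketch of the same pattern; the `queue_mem_required` cost model in it is a made-up stand-in for what `warp_sort_distributed_ext` reports in RAFT, not the real implementation.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Made-up cost model: assume each sub-warp queue of the given Capacity needs
// Capacity * 8 bytes per warp (the real number comes from warp_sort's queue_t).
template <int Capacity>
auto queue_mem_required(uint32_t n_threads) -> size_t
{
  return size_t(n_threads / 32) * Capacity * 8;
}

template <int Capacity>
struct block_sort_mem {
  // Halve the capacity at compile time until it is the smallest power of two
  // still holding k_max; return the matching cost function, not a value.
  static auto get_mem_required(uint32_t k_max) -> size_t (*)(uint32_t)
  {
    if (k_max == 0 || k_max > Capacity) {
      return block_sort_mem<0>::get_mem_required(k_max);
    }
    if constexpr (Capacity > 1) {
      if (k_max * 2 <= Capacity) {
        return block_sort_mem<(Capacity / 2)>::get_mem_required(k_max);
      }
    }
    return &queue_mem_required<Capacity>;
  }
};

// Capacity 0 means the fused top-k is disabled: it costs nothing.
template <>
struct block_sort_mem<0> {
  static auto mem_required(uint32_t) -> size_t { return 0; }
  static auto get_mem_required(uint32_t) -> size_t (*)(uint32_t) { return &mem_required; }
};

int main()
{
  // k = 10 resolves to Capacity 16 at compile time; the returned function is
  // then evaluated for a concrete block size, as compute_similarity_select does.
  auto mem_required = block_sort_mem<128>::get_mem_required(10);
  std::printf("bytes for a 512-thread block: %zu\n", mem_required(512));
  return 0;
}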
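
The new `total_shared_mem_t` encodes the overlap argument from the kernel comment: the block-wide top-k reduction runs last and may therefore overlap everything, while the per-warp top-k queues only become live after the LUT is built, at which point `base_diff` is dead, so those two share the region past `lut_end`. Hence the budget `max(ltk_reduce(n), lut + max(bdf, ltk_add(n)))`. A host-side sketch with hypothetical byte counts (the real numbers come from `lut_mem`, `bdf_mem`, and the two functors in the patch):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins for the four shared-memory consumers; in the patch
// the fixed sizes are lut_mem/bdf_mem and the functions are the two functors.
static const size_t lut_mem = 4096;  // lookup table (when EnableSMemLut)
static const size_t bdf_mem = 1024;  // precomputed base_diff (when PrecompBaseDiff)

static auto ltk_add_mem(uint32_t n_threads) -> size_t
{  // per-warp top-k queues, live while scanning (after the LUT is built)
  return size_t(n_threads / 32) * 256;
}

static auto ltk_reduce_mem(uint32_t n_threads) -> size_t
{  // block-wide top-k merge, live only at the very end
  return size_t(n_threads / 32) * 512;
}

// The composition rule of total_shared_mem_t: the final reduction overlaps
// everything; the queues share the space after the LUT with base_diff.
static auto total_smem(uint32_t n_threads) -> size_t
{
  return std::max(ltk_reduce_mem(n_threads),
                  lut_mem + std::max(bdf_mem, ltk_add_mem(n_threads)));
}

int main()
{
  for (uint32_t n_threads : {256u, 512u, 1024u}) {
    std::printf("%4u threads -> %5zu bytes of dynamic shared memory\n",
                n_threads,
                total_smem(n_threads));
  }
  return 0;
}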
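
Since the budget now depends on the block size, the selection loop has to re-evaluate `smem_size_f(n_threads)` for every occupancy query instead of reusing one precomputed constant. The sketch below mimics that loop with the standard CUDA occupancy API and a toy kernel in place of `compute_similarity_kernel`; the ranking here is reduced to raw occupancy, whereas the real code also weighs `kTargetOccupancy`, the number of blocks, and whether the LUT stays in shared memory.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel standing in for compute_similarity_kernel: the selection logic
// below only needs a kernel symbol and a dynamic shared-memory size.
__global__ void toy_kernel(float* out)
{
  extern __shared__ uint8_t smem_buf[];
  smem_buf[threadIdx.x] = uint8_t(threadIdx.x);
  __syncthreads();
  if (threadIdx.x == 0) { out[blockIdx.x] = float(smem_buf[0]); }
}

// Hypothetical per-candidate budget as a function of the block size,
// playing the role of total_shared_mem_t from the patch.
static auto smem_size_f(uint32_t n_threads) -> size_t
{
  return 4096 + size_t(n_threads / 32) * 256;
}

int main()
{
  cudaDeviceProp dev_props;
  if (cudaGetDeviceProperties(&dev_props, 0) != cudaSuccess) { return 1; }

  // Let the kernel opt into more dynamic shared memory than the default limit.
  cudaFuncSetAttribute(toy_kernel,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       int(dev_props.sharedMemPerBlockOptin));

  uint32_t best_threads = 0;
  double best_occupancy = 0.0;
  for (uint32_t n_threads = 1024; n_threads >= 128; n_threads /= 2) {
    // Re-evaluate the shared-memory requirement for this block size.
    size_t smem_size = smem_size_f(n_threads);
    int max_blocks   = 0;
    if (cudaOccupancyMaxActiveBlocksPerMultiprocessor(
          &max_blocks, toy_kernel, int(n_threads), smem_size) != cudaSuccess) {
      continue;
    }
    double occupancy =
      double(max_blocks) * n_threads / dev_props.maxThreadsPerMultiProcessor;
    std::printf(
      "%4u threads: %2d blocks/SM, occupancy %.2f\n", n_threads, max_blocks, occupancy);
    if (occupancy > best_occupancy) {
      best_occupancy = occupancy;
      best_threads   = n_threads;
    }
  }
  std::printf("selected block size: %u threads\n", best_threads);
  return 0;
}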