Skip to content

Commit

Permalink
Cache IVF-PQ and select-warpsort kernel launch parameters to reduce l…
Browse files Browse the repository at this point in the history
…atency
  • Loading branch information
achirkin committed Aug 30, 2023
1 parent 6f58669 commit 7aefb3b
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 47 deletions.
2 changes: 1 addition & 1 deletion cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ void select_k(raft::resources const& handle,
case Algo::kWarpDistributedShm:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed_ext>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
case Algo::kFaissBlockSelect:
return neighbors::detail::select_k(
in_val, in_idx, batch_size, len, out_val, out_idx, select_min, k, stream);
Expand Down
91 changes: 62 additions & 29 deletions cpp/include/raft/matrix/detail/select_warpsort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

#include <raft/core/detail/macros.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resource/user_resource.hpp>
#include <raft/util/bitonic_sort.cuh>
#include <raft/util/cache.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/integer_utils.hpp>
#include <raft/util/pow2_utils.cuh>
Expand Down Expand Up @@ -773,6 +775,11 @@ __launch_bounds__(256) __global__
queue.store(out + block_id * k, out_idx + block_id * k);
}

struct launch_params {
int block_size = 0;
int min_grid_size = 0;
};

template <template <int, bool, typename, typename> class WarpSortClass,
typename T,
typename IdxT,
Expand All @@ -790,29 +797,28 @@ struct launch_setup {
* @param[in] block_size_limit
* Forcefully limit the block size (optional)
*/
static void calc_optimal_params(int k,
int* block_size,
int* min_grid_size,
int block_size_limit = 0)
static auto calc_optimal_params(int k, int block_size_limit) -> launch_params
{
const int capacity = bound_by_power_of_two(k);
if constexpr (Capacity > 1) {
if (capacity < Capacity) {
return launch_setup<WarpSortClass, T, IdxT, Capacity / 2>::calc_optimal_params(
capacity, block_size, min_grid_size, block_size_limit);
capacity, block_size_limit);
}
}
ASSERT(capacity <= Capacity, "Requested k is too big (%d)", k);
auto calc_smem = [k](int block_size) {
int num_of_warp = block_size / std::min<int>(WarpSize, Capacity);
return calc_smem_size_for_block_wide<T, IdxT>(num_of_warp, k);
};
launch_params ps;
RAFT_CUDA_TRY(cudaOccupancyMaxPotentialBlockSizeVariableSMem(
min_grid_size,
block_size,
&ps.min_grid_size,
&ps.block_size,
block_kernel<WarpSortClass, Capacity, true, T, IdxT>,
calc_smem,
block_size_limit));
return ps;
}

static void kernel(int k,
Expand Down Expand Up @@ -869,6 +875,29 @@ struct launch_setup {
}
};

template <template <int, bool, typename, typename> class WarpSortClass, typename T, typename IdxT>
struct warpsort_params_cache {
static constexpr size_t kDefaultSize = 100;
cache::lru<uint64_t, std::hash<uint64_t>, std::equal_to<>, launch_params> value{kDefaultSize};
};

template <template <int, bool, typename, typename> class WarpSortClass, typename T, typename IdxT>
static auto calc_optimal_params(raft::resources const& res, int k, int block_size_limit = 0)
-> launch_params
{
static thread_local std::unordered_map<uint64_t, launch_params> memo{};
uint64_t key = (static_cast<uint64_t>(k) << 32) | static_cast<uint64_t>(block_size_limit);
auto& cache =
resource::get_user_resource<warpsort_params_cache<WarpSortClass, T, IdxT>>(res)->value;
launch_params val;
if (!cache.get(key, &val)) {
val =
launch_setup<WarpSortClass, T, IdxT, kMaxCapacity>::calc_optimal_params(k, block_size_limit);
cache.set(key, val);
}
return val;
}

template <template <int, bool, typename, typename> class WarpSortClass>
struct LaunchThreshold {};

Expand Down Expand Up @@ -898,15 +927,19 @@ struct LaunchThreshold<warp_sort_immediate> {
};

template <template <int, bool, typename, typename> class WarpSortClass, typename T, typename IdxT>
void calc_launch_parameter(
size_t batch_size, size_t len, int k, int* p_num_of_block, int* p_num_of_warp)
void calc_launch_parameter(raft::resources const& res,
size_t batch_size,
size_t len,
int k,
int* p_num_of_block,
int* p_num_of_warp)
{
const int capacity = bound_by_power_of_two(k);
const int capacity_per_full_warp = std::max(capacity, WarpSize);
int block_size = 0;
int min_grid_size = 0;
launch_setup<WarpSortClass, T, IdxT>::calc_optimal_params(k, &block_size, &min_grid_size);
block_size = Pow2<WarpSize>::roundDown(block_size);
auto lps = calc_optimal_params<WarpSortClass, T, IdxT>(res, k);
int block_size = lps.block_size;
int min_grid_size = lps.min_grid_size;
block_size = Pow2<WarpSize>::roundDown(block_size);

int num_of_warp;
int num_of_block;
Expand Down Expand Up @@ -950,19 +983,16 @@ void calc_launch_parameter(
// to occupy a single block well.
block_size = adjust_block_size(block_size);
do {
num_of_warp = block_size / WarpSize;
int another_block_size = 0;
int another_min_grid_size = 0;
launch_setup<WarpSortClass, T, IdxT>::calc_optimal_params(
k, &another_block_size, &another_min_grid_size, block_size);
another_block_size = adjust_block_size(another_block_size);
if (batch_size >= size_t(another_min_grid_size) // still have enough work
&& another_block_size < block_size // protect against an infinite loop
&& another_min_grid_size * another_block_size >
num_of_warp = block_size / WarpSize;
auto another = calc_optimal_params<WarpSortClass, T, IdxT>(res, k, block_size);
another.block_size = adjust_block_size(another.block_size);
if (batch_size >= size_t(another.min_grid_size) // still have enough work
&& another.block_size < block_size // protect against an infinite loop
&& another.min_grid_size * another.block_size >
min_grid_size * block_size // improve occupancy
) {
block_size = another_block_size;
min_grid_size = another_min_grid_size;
block_size = another.block_size;
min_grid_size = another.min_grid_size;
} else {
break;
}
Expand Down Expand Up @@ -1036,7 +1066,8 @@ void select_k_(int num_of_block,
}

template <typename T, typename IdxT, template <int, bool, typename, typename> class WarpSortClass>
void select_k_impl(const T* in,
void select_k_impl(raft::resources const& res,
const T* in,
const IdxT* in_idx,
size_t batch_size,
size_t len,
Expand All @@ -1049,7 +1080,8 @@ void select_k_impl(const T* in,
{
int num_of_block = 0;
int num_of_warp = 0;
calc_launch_parameter<WarpSortClass, T, IdxT>(batch_size, len, k, &num_of_block, &num_of_warp);
calc_launch_parameter<WarpSortClass, T, IdxT>(
res, batch_size, len, k, &num_of_block, &num_of_warp);

select_k_<WarpSortClass, T, IdxT>(num_of_block,
num_of_warp,
Expand Down Expand Up @@ -1103,7 +1135,8 @@ void select_k_impl(const T* in,
* memory pool here to avoid memory allocations within the call).
*/
template <typename T, typename IdxT>
void select_k(const T* in,
void select_k(raft::resources const& res,
const T* in,
const IdxT* in_idx,
size_t batch_size,
size_t len,
Expand All @@ -1123,7 +1156,7 @@ void select_k(const T* in,
int num_of_block = 0;
int num_of_warp = 0;
calc_launch_parameter<warp_sort_immediate, T, IdxT>(
batch_size, len, k, &num_of_block, &num_of_warp);
res, batch_size, len, k, &num_of_block, &num_of_warp);
int len_per_thread = len / (num_of_block * num_of_warp * std::min(capacity, WarpSize));

if (len_per_thread <= LaunchThreshold<warp_sort_immediate>::len_factor_for_choosing) {
Expand All @@ -1141,7 +1174,7 @@ void select_k(const T* in,
mr);
} else {
calc_launch_parameter<warp_sort_filtered, T, IdxT>(
batch_size, len, k, &num_of_block, &num_of_warp);
res, batch_size, len, k, &num_of_block, &num_of_warp);
select_k_<warp_sort_filtered, T, IdxT>(num_of_block,
num_of_warp,
in,
Expand Down
95 changes: 83 additions & 12 deletions cpp/include/raft/neighbors/detail/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@
#include <raft/core/operators.hpp>
#include <raft/core/resource/detail/device_memory_resource.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resource/user_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/gemm.cuh>
#include <raft/linalg/map.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/detail/select_k.cuh>
#include <raft/matrix/detail/select_warpsort.cuh>
#include <raft/util/cache.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/device_atomics.cuh>
#include <raft/util/device_loads_stores.cuh>
Expand Down Expand Up @@ -80,6 +82,12 @@ void select_clusters(raft::resources const& handle,
const float* cluster_centers, // [n_lists, dim_ext]
rmm::mr::device_memory_resource* mr)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_pq::search::select_clusters(n_probes = %u, n_queries = %u, n_lists = %u, dim = %u)",
n_probes,
n_queries,
n_lists,
dim);
auto stream = resource::get_cuda_stream(handle);
/* NOTE[qc_distances]
Expand Down Expand Up @@ -409,6 +417,46 @@ constexpr inline auto expected_probe_coresidency(uint32_t n_clusters,
return 1 + (n_queries - 1) * n_probes / (2 * n_clusters);
}

struct search_kernel_key {
bool manage_local_topk;
uint32_t locality_hint;
double preferred_shmem_carveout;
uint32_t pq_bits;
uint32_t pq_dim;
uint32_t precomp_data_count;
uint32_t n_queries;
uint32_t n_probes;
uint32_t topk;
};

inline auto operator==(const search_kernel_key& a, const search_kernel_key& b) -> bool
{
return a.manage_local_topk == b.manage_local_topk && a.locality_hint == b.locality_hint &&
a.preferred_shmem_carveout == b.preferred_shmem_carveout && a.pq_bits == b.pq_bits &&
a.pq_dim == b.pq_dim && a.precomp_data_count == b.precomp_data_count &&
a.n_queries == b.n_queries && a.n_probes == b.n_probes && a.topk == b.topk;
}

struct search_kernel_key_hash {
inline auto operator()(const search_kernel_key& x) const noexcept -> std::size_t
{
return (size_t{x.manage_local_topk} << 63) +
size_t{x.topk} * size_t{x.n_probes} * size_t{x.n_queries} +
size_t{x.precomp_data_count} * size_t{x.pq_dim} * size_t{x.pq_bits};
}
};

template <typename OutT, typename LutT, typename IvfSampleFilterT>
struct search_kernel_cache {
/** Number of matmul invocations to cache. */
static constexpr size_t kDefaultSize = 100;
cache::lru<search_kernel_key,
search_kernel_key_hash,
std::equal_to<>,
selected<OutT, LutT, IvfSampleFilterT>>
value{kDefaultSize};
};

/**
* The "main part" of the search, which assumes that outer-level `search` has already:
*
Expand All @@ -433,6 +481,12 @@ void ivfpq_search_worker(raft::resources const& handle,
double preferred_shmem_carveout,
IvfSampleFilterT sample_filter)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_pq::search-worker(n_queries = %u, n_probes = %u, k = %u, dim = %zu)",
n_queries,
n_probes,
topK,
index.dim());
auto stream = resource::get_cuda_stream(handle);
auto mr = resource::get_workspace_resource(handle);

Expand Down Expand Up @@ -535,17 +589,32 @@ void ivfpq_search_worker(raft::resources const& handle,
} break;
}

auto search_instance = compute_similarity_select<ScoreT, LutT, IvfSampleFilterT>(
resource::get_device_properties(handle),
manage_local_topk,
coresidency,
preferred_shmem_carveout,
index.pq_bits(),
index.pq_dim(),
precomp_data_count,
n_queries,
n_probes,
topK);
selected<ScoreT, LutT, IvfSampleFilterT> search_instance;
search_kernel_key search_key{manage_local_topk,
coresidency,
preferred_shmem_carveout,
index.pq_bits(),
index.pq_dim(),
precomp_data_count,
n_queries,
n_probes,
topK};
auto& cache =
resource::get_user_resource<search_kernel_cache<ScoreT, LutT, IvfSampleFilterT>>(handle)->value;
if (!cache.get(search_key, &search_instance)) {
search_instance = compute_similarity_select<ScoreT, LutT, IvfSampleFilterT>(
resource::get_device_properties(handle),
manage_local_topk,
coresidency,
preferred_shmem_carveout,
index.pq_bits(),
index.pq_dim(),
precomp_data_count,
n_queries,
n_probes,
topK);
cache.set(search_key, search_instance);
}

rmm::device_uvector<LutT> device_lut(search_instance.device_lut_size, stream, mr);
std::optional<device_vector<float>> query_kths_buf{std::nullopt};
Expand Down Expand Up @@ -696,7 +765,7 @@ inline auto get_max_batch_size(raft::resources const& res,
uint32_t max_samples) -> uint32_t
{
uint32_t max_batch_size = n_queries;
uint32_t n_ctas_total = getMultiProcessorCount() * 2;
uint32_t n_ctas_total = resource::get_device_properties(res).multiProcessorCount * 2;
uint32_t n_ctas_total_per_batch = n_ctas_total / max_batch_size;
float utilization = float(n_ctas_total_per_batch * max_batch_size) / n_ctas_total;
if (n_ctas_total_per_batch > 1 || (n_ctas_total_per_batch == 1 && utilization < 0.6)) {
Expand Down Expand Up @@ -798,6 +867,8 @@ inline void search(raft::resources const& handle,

for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) {
uint32_t queries_batch = min(max_queries, n_queries - offset_q);
common::nvtx::range<common::nvtx::domain::raft> batch_scope(
"ivf_pq::search-batch(queries: %u - %u)", offset_q, offset_q + queries_batch);

select_clusters(handle,
clusters_to_probe.data(),
Expand Down
10 changes: 5 additions & 5 deletions cpp/internal/raft_internal/matrix/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -147,23 +147,23 @@ void select_k_impl(const resources& handle,
stream);
case Algo::kWarpAuto:
return detail::select::warpsort::select_k<T, IdxT>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, stream);
handle, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream);
case Algo::kWarpImmediate:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_immediate>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, stream);
handle, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream);
case Algo::kWarpFiltered:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_filtered>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, stream);
handle, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream);
case Algo::kWarpDistributed:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, stream);
handle, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream);
case Algo::kWarpDistributedShm:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed_ext>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, stream);
handle, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream);
case Algo::kFaissBlockSelect:
return neighbors::detail::select_k(
in, in_idx, batch_size, len, out, out_idx, select_min, k, stream);
Expand Down

0 comments on commit 7aefb3b

Please sign in to comment.