Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename CAGRA parameter num_parents to search_width #1676

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/raft/raft_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ void parse_search_param(const nlohmann::json& conf,
typename raft::bench::ann::RaftCagra<T, IdxT>::SearchParam& param)
{
if (conf.contains("itopk")) { param.p.itopk_size = conf.at("itopk"); }
if (conf.contains("search_width")) { param.p.num_parents = conf.at("search_width"); }
if (conf.contains("search_width")) { param.p.search_width = conf.at("search_width"); }
if (conf.contains("max_iterations")) { param.p.max_iterations = conf.at("max_iterations"); }
}
#endif
Expand Down
12 changes: 6 additions & 6 deletions cpp/bench/prims/neighbors/cagra_bench.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ struct params {
int degree;
int itopk_size;
int block_size;
int num_parents;
int search_width;
int max_iterations;
};

Expand Down Expand Up @@ -85,7 +85,7 @@ struct CagraBench : public fixture {
search_params.itopk_size = params_.itopk_size;
search_params.team_size = 0;
search_params.thread_block_size = params_.block_size;
search_params.num_parents = params_.num_parents;
search_params.search_width = params_.search_width;

auto indices = make_device_matrix<IdxT, int64_t>(handle, params_.n_queries, params_.k);
auto distances = make_device_matrix<float, int64_t>(handle, params_.n_queries, params_.k);
Expand All @@ -106,7 +106,7 @@ struct CagraBench : public fixture {
int iterations = params_.max_iterations;
if (iterations == 0) {
// see search_plan_impl::adjust_search_params()
double r = params_.itopk_size / static_cast<float>(params_.num_parents);
double r = params_.itopk_size / static_cast<float>(params_.search_width);
iterations = 1 + std::min(r * 1.1, r + 10);
}
state.counters["dataset (GiB)"] = data_size / (1 << 30);
Expand All @@ -118,7 +118,7 @@ struct CagraBench : public fixture {
state.counters["k"] = params_.k;
state.counters["itopk_size"] = params_.itopk_size;
state.counters["block_size"] = params_.block_size;
state.counters["num_parents"] = params_.num_parents;
state.counters["search_width"] = params_.search_width;
state.counters["iterations"] = iterations;
}

Expand All @@ -140,7 +140,7 @@ inline const std::vector<params> generate_inputs()
{64}, // knn graph degree
{64}, // itopk_size
{0}, // block_size
{1}, // num_parents
{1}, // search_width
{0} // max_iterations
);
auto inputs2 = raft::util::itertools::product<params>({2000000ull, 10000000ull}, // n_samples
Expand All @@ -150,7 +150,7 @@ inline const std::vector<params> generate_inputs()
{64}, // knn graph degree
{64}, // itopk_size
{64, 128, 256, 512, 1024}, // block_size
{1}, // num_parents
{1}, // search_width
{0} // max_iterations
);
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ struct search_params : ann::search_params {

/** Number of graph nodes to select as the starting point for the search in each iteration
 * (i.e. the search width). */
size_t num_parents = 1;
size_t search_width = 1;
/** Lower limit of search iterations. */
size_t min_iterations = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,13 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in
INDEX_T* const visited_hashmap_ptr,
const std::uint32_t hash_bitlen,
const INDEX_T* const parent_indices,
const std::uint32_t num_parents)
const std::uint32_t search_width)
{
const INDEX_T invalid_index = utils::get_max_value<INDEX_T>();

// Read child indices of parents from knn graph and check if the distance
// computation is necessary.
for (uint32_t i = threadIdx.x; i < knn_k * num_parents; i += BLOCK_SIZE) {
for (uint32_t i = threadIdx.x; i < knn_k * search_width; i += BLOCK_SIZE) {
const INDEX_T parent_id = parent_indices[i / knn_k];
INDEX_T child_id = invalid_index;
if (parent_id != invalid_index) {
Expand Down Expand Up @@ -203,10 +203,10 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in
__syncthreads();

// Compute the distance to child nodes
std::uint32_t max_i = knn_k * num_parents;
std::uint32_t max_i = knn_k * search_width;
if (max_i % (32 / TEAM_SIZE)) { max_i += (32 / TEAM_SIZE) - (max_i % (32 / TEAM_SIZE)); }
for (std::uint32_t i = threadIdx.x / TEAM_SIZE; i < max_i; i += BLOCK_SIZE / TEAM_SIZE) {
const bool valid_i = (i < (knn_k * num_parents));
const bool valid_i = (i < (knn_k * search_width));
INDEX_T child_id = invalid_index;
if (valid_i) { child_id = result_child_indices_ptr[i]; }

Expand Down
14 changes: 7 additions & 7 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct search : public search_plan_impl<DATA_T, INDEX_T, DISTANCE_T> {
using search_plan_impl<DATA_T, INDEX_T, DISTANCE_T>::itopk_size;
using search_plan_impl<DATA_T, INDEX_T, DISTANCE_T>::algo;
using search_plan_impl<DATA_T, INDEX_T, DISTANCE_T>::team_size;
using search_plan_impl<DATA_T, INDEX_T, DISTANCE_T>::num_parents;
using search_plan_impl<DATA_T, INDEX_T, DISTANCE_T>::search_width;
using search_plan_impl<DATA_T, INDEX_T, DISTANCE_T>::min_iterations;
using search_plan_impl<DATA_T, INDEX_T, DISTANCE_T>::max_iterations;
using search_plan_impl<DATA_T, INDEX_T, DISTANCE_T>::thread_block_size;
Expand Down Expand Up @@ -108,17 +108,17 @@ struct search : public search_plan_impl<DATA_T, INDEX_T, DISTANCE_T> {
void set_params(raft::resources const& res, const search_params& params)
{
this->itopk_size = 32;
num_parents = 1;
num_cta_per_query = max(params.num_parents, params.itopk_size / 32);
result_buffer_size = itopk_size + num_parents * graph_degree;
search_width = 1;
num_cta_per_query = max(params.search_width, params.itopk_size / 32);
result_buffer_size = itopk_size + search_width * graph_degree;
typedef raft::Pow2<32> AlignBytes;
unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size);
// constexpr unsigned max_result_buffer_size = 256;
RAFT_EXPECTS(result_buffer_size_32 <= 256, "Result buffer size cannot exceed 256");

smem_size = sizeof(float) * max_dim +
(sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 +
sizeof(uint32_t) * num_parents + sizeof(uint32_t);
sizeof(uint32_t) * search_width + sizeof(uint32_t);
RAFT_LOG_DEBUG("# smem_size: %u", smem_size);

//
Expand All @@ -143,7 +143,7 @@ struct search : public search_plan_impl<DATA_T, INDEX_T, DISTANCE_T> {
cudaDeviceProp deviceProp = resource::get_device_properties(res);
RAFT_LOG_DEBUG("# multiProcessorCount: %d", deviceProp.multiProcessorCount);
while ((block_size < max_block_size) &&
(graph_degree * num_parents * team_size >= block_size * 2) &&
(graph_degree * search_width * team_size >= block_size * 2) &&
(num_cta_per_query * max_queries <=
(1024 / (block_size * 2)) * deviceProp.multiProcessorCount)) {
block_size *= 2;
Expand Down Expand Up @@ -210,7 +210,7 @@ struct search : public search_plan_impl<DATA_T, INDEX_T, DISTANCE_T> {
rand_xor_mask,
num_seeds,
itopk_size,
num_parents,
search_width,
min_iterations,
max_iterations,
stream);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void select_and_run(raft::device_matrix_view<const DATA_T, int64_t, layout_strid
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t num_parents,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
cudaStream_t stream) RAFT_EXPLICIT;
Expand All @@ -73,7 +73,7 @@ void select_and_run(raft::device_matrix_view<const DATA_T, int64_t, layout_strid
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t num_parents, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
cudaStream_t stream);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ namespace multi_cta_search {
// #define _CLK_BREAKDOWN

template <class INDEX_T>
__device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num_parents]
const uint32_t num_parents,
__device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [search_width]
const uint32_t search_width,
INDEX_T* const itopk_indices, // [num_itopk]
const size_t num_itopk,
uint32_t* const terminate_flag)
Expand All @@ -56,7 +56,7 @@ __device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num
const unsigned warp_id = threadIdx.x / 32;
if (warp_id > 0) { return; }
const unsigned lane_id = threadIdx.x % 32;
for (uint32_t i = lane_id; i < num_parents; i += 32) {
for (uint32_t i = lane_id; i < search_width; i += 32) {
next_parent_indices[i] = utils::get_max_value<INDEX_T>();
}
uint32_t max_itopk = num_itopk;
Expand All @@ -74,13 +74,13 @@ __device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num
const uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent);
if (new_parent) {
const auto i = __popc(ballot_mask & ((1 << lane_id) - 1)) + num_new_parents;
if (i < num_parents) {
if (i < search_width) {
next_parent_indices[i] = index;
itopk_indices[j] |= index_msb_1_mask; // set most significant bit as used node
}
}
num_new_parents += __popc(ballot_mask);
if (num_new_parents >= num_parents) { break; }
if (num_new_parents >= search_width) { break; }
}
if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; }
}
Expand Down Expand Up @@ -149,7 +149,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(
INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen]
const uint32_t hash_bitlen,
const uint32_t itopk_size,
const uint32_t num_parents,
const uint32_t search_width,
const uint32_t min_iteration,
const uint32_t max_iteration,
uint32_t* const num_executed_iterations /* stats */
Expand Down Expand Up @@ -183,10 +183,10 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(
// Layout of result_buffer
// +----------------+------------------------------+---------+
// | internal_top_k | neighbors of parent nodes | padding |
// | <itopk_size> | <num_parents * graph_degree> | upto 32 |
// | <itopk_size> | <search_width * graph_degree> | upto 32 |
// +----------------+------------------------------+---------+
// |<--- result_buffer_size --->|
uint32_t result_buffer_size = itopk_size + (num_parents * graph_degree);
uint32_t result_buffer_size = itopk_size + (search_width * graph_degree);
uint32_t result_buffer_size_32 = result_buffer_size;
if (result_buffer_size % 32) { result_buffer_size_32 += 32 - (result_buffer_size % 32); }
assert(result_buffer_size_32 <= MAX_ELEMENTS);
Expand All @@ -197,7 +197,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(
reinterpret_cast<DISTANCE_T*>(result_indices_buffer + result_buffer_size_32);
auto parent_indices_buffer =
reinterpret_cast<INDEX_T*>(result_distances_buffer + result_buffer_size_32);
auto terminate_flag = reinterpret_cast<uint32_t*>(parent_indices_buffer + num_parents);
auto terminate_flag = reinterpret_cast<uint32_t*>(parent_indices_buffer + search_width);

#if 0
/* debug */
Expand Down Expand Up @@ -252,7 +252,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(
_CLK_START();
topk_by_bitonic_sort<MAX_ELEMENTS, INDEX_T>(result_distances_buffer,
result_indices_buffer,
itopk_size + (num_parents * graph_degree),
itopk_size + (search_width * graph_degree),
itopk_size);
_CLK_REC(clk_topk);

Expand All @@ -264,7 +264,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(
// pick up next parents
_CLK_START();
pickup_next_parents<INDEX_T>(
parent_indices_buffer, num_parents, result_indices_buffer, itopk_size, terminate_flag);
parent_indices_buffer, search_width, result_indices_buffer, itopk_size, terminate_flag);
_CLK_REC(clk_pickup_parents);

__syncthreads();
Expand All @@ -287,7 +287,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(
local_visited_hashmap_ptr,
hash_bitlen,
parent_indices_buffer,
num_parents);
search_width);
_CLK_REC(clk_compute_distance);
__syncthreads();

Expand Down Expand Up @@ -472,7 +472,7 @@ void select_and_run( // raft::resources const& res,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t num_parents,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
cudaStream_t stream)
Expand Down Expand Up @@ -510,7 +510,7 @@ void select_and_run( // raft::resources const& res,
hashmap_ptr,
hash_bitlen,
itopk_size,
num_parents,
search_width,
min_iterations,
max_iterations,
num_executed_iterations);
Expand Down
Loading