Skip to content

Commit

Permalink
Fix compilation error when _CLK_BREAKDOWN is defined in cagra.
Browse files Browse the repository at this point in the history
PR #1740 forgot to rename `BLOCK_SIZE` in `#ifdef _CLK_BREAKDOWN`
blocks.

also remove an unused function in search_single_cta_kernel-inl.cuh
  • Loading branch information
jiangyinzuo committed Jun 2, 2024
1 parent 8ef71de commit ae78ee9
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel(

#if 0
/* debug */
for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += BLOCK_SIZE) {
for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) {
result_indices_buffer[i] = utils::get_max_value<INDEX_T>();
result_distances_buffer[i] = utils::get_max_value<DISTANCE_T>();
}
Expand Down Expand Up @@ -351,7 +351,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel(
}

#ifdef _CLK_BREAKDOWN
if ((threadIdx.x == 0 || threadIdx.x == BLOCK_SIZE - 1) && (blockIdx.x == 0) &&
if ((threadIdx.x == 0 || threadIdx.x == blockDim.x - 1) && (blockIdx.x == 0) &&
((query_id * 3) % gridDim.y < 3)) {
RAFT_LOG_DEBUG(
"query, %d, thread, %d"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,14 +448,6 @@ __device__ inline void hashmap_restore(INDEX_T* const hashmap_ptr,
}
}

template <class T, unsigned BLOCK_SIZE>
__device__ inline void set_value_device(T* const ptr, const T fill, const std::uint32_t count)
{
for (std::uint32_t i = threadIdx.x; i < count; i += BLOCK_SIZE) {
ptr[i] = fill;
}
}

// One query one thread block
template <uint32_t TEAM_SIZE,
uint32_t DATASET_BLOCK_DIM,
Expand Down Expand Up @@ -791,7 +783,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel(
num_executed_iterations[query_id] = iter + 1;
}
#ifdef _CLK_BREAKDOWN
if ((threadIdx.x == 0 || threadIdx.x == BLOCK_SIZE - 1) && ((query_id * 3) % gridDim.y < 3)) {
if ((threadIdx.x == 0 || threadIdx.x == blockDim.x - 1) && ((query_id * 3) % gridDim.y < 3)) {
RAFT_LOG_DEBUG(
"query, %d, thread, %d"
", init, %d"
Expand Down

0 comments on commit ae78ee9

Please sign in to comment.