Fix integer overflow in ANN kmeans (#835)
CUDA block and thread indices are uint32, so in the following statement the product `blockDim.x * blockIdx.x` is evaluated in 32 bits and can wrap around before the result is converted to uint64:

```cuda
uint64_t gid = threadIdx.x + (blockDim.x * blockIdx.x);
```
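
As a rough illustration of the wraparound, here is a host-side C++ sketch that mimics the index arithmetic with plain `uint32_t` values, so it can be compiled without a GPU. The helper names `gid_wrapping` and `gid_widened` and the sample grid values are hypothetical and do not come from the RAFT sources; the sketch also assumes a platform where `unsigned int` is 32 bits wide.

```cpp
#include <cstdint>

// Hypothetical helper: the product is evaluated in 32 bits (assuming a 32-bit
// unsigned int), so it wraps modulo 2^32 before being widened to 64 bits.
constexpr std::uint64_t gid_wrapping(std::uint32_t thread_idx,
                                     std::uint32_t block_dim,
                                     std::uint32_t block_idx)
{
  return thread_idx + (block_dim * block_idx);
}

// Hypothetical helper: one operand is widened first, so the multiplication
// and the addition are both carried out in 64 bits.
constexpr std::uint64_t gid_widened(std::uint32_t thread_idx,
                                    std::uint32_t block_dim,
                                    std::uint32_t block_idx)
{
  return thread_idx + (block_dim * static_cast<std::uint64_t>(block_idx));
}

// Sample values chosen for illustration: 256 threads per block and block
// index 2^24, so the true product is exactly 2^32.
static_assert(gid_wrapping(5u, 256u, 1u << 24) == 5u,
              "32-bit product wraps to 0, leaving only the thread index");
static_assert(gid_widened(5u, 256u, 1u << 24) == 4294967301ull,
              "64-bit product keeps the full value 2^32 + 5");
```

With the wrapped index, threads in high-numbered blocks alias the elements handled by low-numbered blocks, so some elements are processed more than once while others are never reached.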

This PR fixes the error that @tfeher and I encountered while benchmarking large datasets:

```
Incorrect mesocluster size at 0. 625618 vs 625858
```

cc @achirkin

Authors:
  - Louis Sugy (https://github.com/Nyrio)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #835
Nyrio authored Sep 21, 2022
1 parent d9c7aa9 commit e394ac2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion cpp/include/raft/spatial/knn/detail/ann_utils.cuh
@@ -261,7 +261,7 @@ __global__ void accumulate_into_selected_kernel(uint32_t n_rows,
                                                 const T* input,
                                                 const uint32_t* row_ids)
 {
-  uint64_t gid = threadIdx.x + (blockDim.x * blockIdx.x);
+  uint64_t gid = threadIdx.x + (blockDim.x * static_cast<uint64_t>(blockIdx.x));
   uint64_t j = gid % n_cols;
   uint64_t i = gid / n_cols;
   if (i >= n_rows) return;