From a8a46efc37b129ff9da52e4617b1ae93dafbb601 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 27 Oct 2021 12:23:11 -0400 Subject: [PATCH] Fixing overflow in expanded distances --- cpp/include/raft/sparse/distance/detail/bin_distance.cuh | 2 +- cpp/include/raft/sparse/distance/detail/l2_distance.cuh | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/sparse/distance/detail/bin_distance.cuh b/cpp/include/raft/sparse/distance/detail/bin_distance.cuh index e6dd1331ae..3f8c32a20b 100644 --- a/cpp/include/raft/sparse/distance/detail/bin_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/bin_distance.cuh @@ -54,7 +54,7 @@ __global__ void compute_binary_warp_kernel(value_t *__restrict__ C, const value_t *__restrict__ R_norms, value_idx n_rows, value_idx n_cols, expansion_f expansion_func) { - value_idx tid = blockDim.x * blockIdx.x + threadIdx.x; + std::size_t tid = blockDim.x * blockIdx.x + threadIdx.x; value_idx i = tid / n_cols; value_idx j = tid % n_cols; diff --git a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh index e7ac78b80a..f06a15215c 100644 --- a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh @@ -64,7 +64,7 @@ __global__ void compute_euclidean_warp_kernel( value_t *__restrict__ C, const value_t *__restrict__ Q_sq_norms, const value_t *__restrict__ R_sq_norms, value_idx n_rows, value_idx n_cols, expansion_f expansion_func) { - value_idx tid = blockDim.x * blockIdx.x + threadIdx.x; + std::size_t tid = blockDim.x * blockIdx.x + threadIdx.x; value_idx i = tid / n_cols; value_idx j = tid % n_cols; @@ -85,7 +85,7 @@ __global__ void compute_correlation_warp_kernel( const value_t *__restrict__ R_sq_norms, const value_t *__restrict__ Q_norms, const value_t *__restrict__ R_norms, value_idx n_rows, value_idx n_cols, value_idx n) { - value_idx tid = blockDim.x * blockIdx.x + threadIdx.x; + std::size_t tid = blockDim.x * blockIdx.x + threadIdx.x; value_idx i = tid / n_cols; value_idx j = tid % n_cols;