diff --git a/cpp/include/raft/stats/detail/sum.cuh b/cpp/include/raft/stats/detail/sum.cuh index 1d727f8b77..bb45eb50f4 100644 --- a/cpp/include/raft/stats/detail/sum.cuh +++ b/cpp/include/raft/stats/detail/sum.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ RAFT_KERNEL sumKernelRowMajor(Type* mu, const Type* data, IdxType D, IdxType N) __syncthreads(); raft::myAtomicAdd(smu + thisColId, thread_data); __syncthreads(); - if (threadIdx.x < ColsPerBlk) raft::myAtomicAdd(mu + colId, smu[thisColId]); + if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(mu + colId, smu[thisColId]); } template diff --git a/cpp/test/stats/sum.cu b/cpp/test/stats/sum.cu index 040b662c42..290c53a563 100644 --- a/cpp/test/stats/sum.cu +++ b/cpp/test/stats/sum.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,8 +81,8 @@ class SumTest : public ::testing::TestWithParam> { rmm::device_uvector data, sum_act; }; -const std::vector> inputsf = {{0.05f, 1024, 32, 1234ULL}, - {0.05f, 1024, 256, 1234ULL}}; +const std::vector> inputsf = { + {0.05f, 4, 5, 1234ULL}, {0.05f, 1024, 32, 1234ULL}, {0.05f, 1024, 256, 1234ULL}}; const std::vector> inputsd = {{0.05, 1024, 32, 1234ULL}, {0.05, 1024, 256, 1234ULL}};