diff --git a/cpp/include/raft/core/math.hpp b/cpp/include/raft/core/math.hpp index 9ce768bf40..e082aaf41a 100644 --- a/cpp/include/raft/core/math.hpp +++ b/cpp/include/raft/core/math.hpp @@ -56,7 +56,11 @@ constexpr RAFT_INLINE_FUNCTION auto abs(T x) !std::is_same_v, T> { - return x < T{0} ? -x : x; + if constexpr (std::is_unsigned_v) { + return x; + } else { + return x < T{0} ? -x : x; + } } #if defined(_RAFT_HAS_CUDA) template diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh index 9f3be7ce0e..9680cbc636 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh @@ -42,6 +42,18 @@ struct ReductionThinPolicy { static constexpr bool NoSequentialReduce = noLoop; }; +template +DI void KahanBabushkaNeumaierSum(Type& sum, Type& compensation, const Type& cur_value) +{ + const Type t = sum + cur_value; + if (abs(sum) >= abs(cur_value)) { + compensation += (sum - t) + cur_value; + } else { + compensation += (cur_value - t) + sum; + } + sum = t; +} + template +RAFT_KERNEL __launch_bounds__(Policy::ThreadsPerBlock) coalescedSumThinKernel(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + MainLambda main_op, + FinalLambda final_op, + bool inplace = false) +{ + /* The strategy to achieve near-SOL memory bandwidth differs based on D: + * - For small D, we need to process multiple rows per logical warp in order to have + * multiple loads per thread and increase bytes in flight and amortize latencies. + * - For large D, we start with a sequential reduction. The compiler partially unrolls + * that loop (e.g. first a loop of stride 16, then 8, 4, and 1). + */ + IdxType i0 = threadIdx.y + (Policy::RowsPerBlock * static_cast(blockIdx.x)); + if (i0 >= N) return; + + OutType acc[Policy::RowsPerLogicalWarp]; + OutType thread_c[Policy::RowsPerLogicalWarp]; + +#pragma unroll + for (int k = 0; k < Policy::RowsPerLogicalWarp; k++) { + acc[k] = init; + thread_c[k] = 0; + } + + if constexpr (Policy::NoSequentialReduce) { + IdxType j = threadIdx.x; + if (j < D) { +#pragma unroll + for (IdxType k = 0; k < Policy::RowsPerLogicalWarp; k++) { + // Only the first row is known to be within bounds. Clamp to avoid out-of-mem read. + const IdxType i = raft::min(i0 + k * Policy::NumLogicalWarps, N - 1); + // acc[k] = reduce_op(acc[k], main_op(data[j + (D * i)], j)); + KahanBabushkaNeumaierSum(acc[k], thread_c[k], main_op(data[j + (D * i)], j)); + } + } + } else { + for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) { +#pragma unroll + for (IdxType k = 0; k < Policy::RowsPerLogicalWarp; k++) { + const IdxType i = raft::min(i0 + k * Policy::NumLogicalWarps, N - 1); + // acc[k] = reduce_op(acc[k], main_op(data[j + (D * i)], j)); + KahanBabushkaNeumaierSum(acc[k], thread_c[k], main_op(data[j + (D * i)], j)); + } + } + } + + /* This vector reduction has two benefits compared to naive separate reductions: + * - It avoids the LSU bottleneck when the number of columns is around 32 (e.g. for 32, 5 shuffles + * are required and there is no initial sequential reduction to amortize that cost). + * - It distributes the outputs to multiple threads, enabling a coalesced store when the number of + * rows per logical warp and logical warp size are equal. + */ + raft::logicalWarpReduceVector( + acc, threadIdx.x, raft::add_op()); + + raft::logicalWarpReduceVector( + thread_c, threadIdx.x, raft::add_op()); + + constexpr int reducOutVecWidth = + std::max(1, Policy::RowsPerLogicalWarp / Policy::LogicalWarpSize); + constexpr int reducOutGroupSize = + std::max(1, Policy::LogicalWarpSize / Policy::RowsPerLogicalWarp); + constexpr int reducNumGroups = Policy::LogicalWarpSize / reducOutGroupSize; + + if (threadIdx.x % reducOutGroupSize == 0) { + const int groupId = threadIdx.x / reducOutGroupSize; + if (inplace) { +#pragma unroll + for (int k = 0; k < reducOutVecWidth; k++) { + const int reductionId = k * reducNumGroups + groupId; + const IdxType i = i0 + reductionId * Policy::NumLogicalWarps; + if (i < N) { dots[i] = final_op(dots[i] + acc[k] + thread_c[k]); } + } + } else { +#pragma unroll + for (int k = 0; k < reducOutVecWidth; k++) { + const int reductionId = k * reducNumGroups + groupId; + const IdxType i = i0 + reductionId * Policy::NumLogicalWarps; + if (i < N) { dots[i] = final_op(acc[k] + thread_c[k]); } + } + } + } +} + template (Policy::NoSequentialReduce)); dim3 threads(Policy::LogicalWarpSize, Policy::NumLogicalWarps, 1); dim3 blocks(ceildiv(N, Policy::RowsPerBlock), 1, 1); - coalescedReductionThinKernel - <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); + if constexpr (std::is_same_v) { + coalescedSumThinKernel + <<>>(dots, data, D, N, init, main_op, final_op, inplace); + } else { + coalescedReductionThinKernel<<>>( + dots, data, D, N, init, main_op, reduce_op, final_op, inplace); + } RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -240,6 +350,44 @@ RAFT_KERNEL __launch_bounds__(TPB) coalescedReductionMediumKernel(OutType* dots, } } +template +RAFT_KERNEL __launch_bounds__(TPB) coalescedSumMediumKernel(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + MainLambda main_op, + FinalLambda final_op, + bool inplace = false) +{ + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage1; + __shared__ typename BlockReduce::TempStorage temp_storage2; + OutType thread_data = init; + OutType thread_c = (OutType)0; + + IdxType rowStart = blockIdx.x * D; + for (IdxType i = threadIdx.x; i < D; i += TPB) { + IdxType idx = rowStart + i; + KahanBabushkaNeumaierSum(thread_data, thread_c, main_op(data[idx], i)); + } + OutType block_acc = BlockReduce(temp_storage1).Sum(thread_data); + OutType block_c = BlockReduce(temp_storage2).Sum(thread_c); + + if (threadIdx.x == 0) { + if (inplace) { + dots[blockIdx.x] = final_op(dots[blockIdx.x] + block_acc + block_c); + } else { + dots[blockIdx.x] = final_op(block_acc + block_c); + } + } +} + template fun_scope("coalescedReductionMedium<%d>", TPB); - coalescedReductionMediumKernel - <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); + if constexpr (std::is_same_v) { + coalescedSumMediumKernel + <<>>(dots, data, D, N, init, main_op, final_op, inplace); + } else { + coalescedReductionMediumKernel + <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); + } RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -322,6 +475,32 @@ RAFT_KERNEL __launch_bounds__(Policy::ThreadsPerBlock) if (threadIdx.x == 0) { buffer[Policy::BlocksPerRow * blockIdx.x + blockIdx.y] = acc; } } +template +RAFT_KERNEL __launch_bounds__(Policy::ThreadsPerBlock) coalescedSumThickKernel( + OutType* buffer, const InType* data, IdxType D, IdxType N, OutType init, MainLambda main_op) +{ + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage1; + __shared__ typename BlockReduce::TempStorage temp_storage2; + + OutType thread_data = init; + OutType thread_c = (OutType)0; + + IdxType rowStart = blockIdx.x * D; + for (IdxType i = blockIdx.y * Policy::ThreadsPerBlock + threadIdx.x; i < D; + i += Policy::BlockStride) { + IdxType idx = rowStart + i; + KahanBabushkaNeumaierSum(thread_data, thread_c, main_op(data[idx], i)); + } + + OutType block_acc = BlockReduce(temp_storage1).Sum(thread_data); + OutType block_c = BlockReduce(temp_storage2).Sum(thread_c); + + if (threadIdx.x == 0) { + buffer[Policy::BlocksPerRow * blockIdx.x + blockIdx.y] = block_acc + block_c; + } +} + template - <<>>(buffer.data(), data, D, N, init, main_op, reduce_op); + if constexpr (std::is_same_v) { + coalescedSumThickKernel + <<>>(buffer.data(), data, D, N, init, main_op); + } else { + coalescedReductionThickKernel + <<>>(buffer.data(), data, D, N, init, main_op, reduce_op); + } RAFT_CUDA_TRY(cudaPeekAtLastError()); coalescedReductionThin(dots, @@ -391,18 +574,16 @@ void coalescedReductionThickDispatcher(OutType* dots, { // Note: multiple elements per thread to take advantage of the sequential reduction and loop // unrolling - if (D < IdxType(32768)) { - coalescedReductionThick, ReductionThinPolicy<32, 128, 1>>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else { - coalescedReductionThick, ReductionThinPolicy<32, 128, 1>>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } + coalescedReductionThick, ReductionThinPolicy<32, 128, 1>>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } // Primitive to perform reductions along the coalesced dimension of the matrix, i.e. reduce along // rows for row major or reduce along columns for column major layout. Can do an inplace reduction // adding to original values of dots if requested. +// In case of an add-reduction, a compensated summation will be performed in order to reduce +// numerical error. Note that the compensation will only be performed 'per-thread' for performance +// reasons and therefore not be equivalent to a sequential compensation. template = IdxType(4) * numSMs) { + if (D <= IdxType(512) || (N >= IdxType(16) * numSMs && D < IdxType(2048))) { coalescedReductionThinDispatcher( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else if (N < numSMs && D >= IdxType(16384)) { + } else if (N < numSMs && D >= IdxType(1 << 17)) { coalescedReductionThickDispatcher( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else { diff --git a/cpp/include/raft/linalg/detail/strided_reduction.cuh b/cpp/include/raft/linalg/detail/strided_reduction.cuh index 617ac6d874..567dc6220e 100644 --- a/cpp/include/raft/linalg/detail/strided_reduction.cuh +++ b/cpp/include/raft/linalg/detail/strided_reduction.cuh @@ -28,38 +28,63 @@ namespace raft { namespace linalg { namespace detail { -// Kernel to perform reductions along the strided dimension +// Kernel to perform summation along the strided dimension // of the matrix, i.e. reduce along columns for row major or reduce along rows // for column major layout +// A compensated summation will be performed in order to reduce numerical error. +// Note that the compensation will only be performed 'per-block' for performance +// reasons and therefore not be equivalent to a sequential compensation. + template RAFT_KERNEL stridedSummationKernel( - Type* dots, const Type* data, int D, int N, Type init, MainLambda main_op) + Type* out, const Type* data, int D, int N, Type init, MainLambda main_op) { // Thread reduction - Type thread_data = Type(init); - int colStart = blockIdx.x * blockDim.x + threadIdx.x; + Type thread_sum = Type(init); + Type thread_c = Type(0); + int colStart = blockIdx.x * blockDim.x + threadIdx.x; if (colStart < D) { int rowStart = blockIdx.y * blockDim.y + threadIdx.y; int stride = blockDim.y * gridDim.y; for (int j = rowStart; j < N; j += stride) { int idx = colStart + j * D; - thread_data += main_op(data[idx], j); + + // KahanBabushkaNeumaierSum + const Type cur_value = main_op(data[idx], j); + const Type t = thread_sum + cur_value; + if (abs(thread_sum) >= abs(cur_value)) { + thread_c += (thread_sum - t) + cur_value; + } else { + thread_c += (cur_value - t) + thread_sum; + } + thread_sum = t; } } // Block reduction - extern __shared__ char tmp[]; // One element per thread in block - Type* temp = (Type*)tmp; // Cast to desired type - int myidx = threadIdx.x + blockDim.x * threadIdx.y; - temp[myidx] = thread_data; + extern __shared__ char tmp[]; + auto* block_sum = (Type*)tmp; + auto* block_c = block_sum + blockDim.x; + + if (threadIdx.y == 0) { + block_sum[threadIdx.x] = Type(0); + block_c[threadIdx.x] = Type(0); + } __syncthreads(); - for (int j = blockDim.y / 2; j > 0; j /= 2) { - if (threadIdx.y < j) temp[myidx] += temp[myidx + j * blockDim.x]; - __syncthreads(); + // also compute compensation for block-sum + const Type old_sum = atomicAdd(block_sum + threadIdx.x, thread_sum); + const Type t = old_sum + thread_sum; + if (abs(old_sum) >= abs(thread_sum)) { + thread_c += (old_sum - t) + thread_sum; + } else { + thread_c += (thread_sum - t) + old_sum; } + raft::myAtomicAdd(block_c + threadIdx.x, thread_c); + __syncthreads(); // Grid reduction - if ((colStart < D) && (threadIdx.y == 0)) raft::myAtomicAdd(dots + colStart, temp[myidx]); + if (colStart < D && (threadIdx.y == 0)) + raft::myAtomicAdd(out + colStart, block_sum[threadIdx.x] + block_c[threadIdx.x]); } // Kernel to perform reductions along the strided dimension @@ -127,23 +152,35 @@ void stridedReduction(OutType* dots, /// for atomics in stridedKernel (redesign for this is already underway) if (!inplace) raft::linalg::unaryOp(dots, dots, D, raft::const_op(init), stream); - // Arbitrary numbers for now, probably need to tune - const dim3 thrds(32, 16); - IdxType elemsPerThread = raft::ceildiv(N, (IdxType)thrds.y); - elemsPerThread = (elemsPerThread > 8) ? 8 : elemsPerThread; - const dim3 nblks(raft::ceildiv(D, (IdxType)thrds.x), - raft::ceildiv(N, (IdxType)thrds.y * elemsPerThread)); - const size_t shmemSize = sizeof(OutType) * thrds.x * thrds.y; - ///@todo: this complication should go away once we have eliminated the need /// for atomics in stridedKernel (redesign for this is already underway) if constexpr (std::is_same::value && - std::is_same::value) + std::is_same::value) { + constexpr int TPB = 256; + constexpr int ColsPerBlk = 8; + constexpr dim3 Block(ColsPerBlk, TPB / ColsPerBlk); + constexpr int MinRowsPerThread = 16; + constexpr int MinRowsPerBlk = Block.y * MinRowsPerThread; + constexpr int MaxBlocksDimY = 8192; + + const dim3 grid(raft::ceildiv(D, (IdxType)ColsPerBlk), + raft::min((IdxType)MaxBlocksDimY, raft::ceildiv(N, (IdxType)MinRowsPerBlk))); + const size_t shmemSize = sizeof(OutType) * Block.x * 2; + stridedSummationKernel - <<>>(dots, data, D, N, init, main_op); - else + <<>>(dots, data, D, N, init, main_op); + } else { + // Arbitrary numbers for now, probably need to tune + const dim3 thrds(32, 16); + IdxType elemsPerThread = raft::ceildiv(N, (IdxType)thrds.y); + elemsPerThread = (elemsPerThread > 8) ? 8 : elemsPerThread; + const dim3 nblks(raft::ceildiv(D, (IdxType)thrds.x), + raft::ceildiv(N, (IdxType)thrds.y * elemsPerThread)); + const size_t shmemSize = sizeof(OutType) * thrds.x * thrds.y; + stridedReductionKernel <<>>(dots, data, D, N, init, main_op, reduce_op); + } ///@todo: this complication should go away once we have eliminated the need /// for atomics in stridedKernel (redesign for this is already underway) diff --git a/cpp/include/raft/linalg/reduce.cuh b/cpp/include/raft/linalg/reduce.cuh index a4523c6926..8fd6e45d37 100644 --- a/cpp/include/raft/linalg/reduce.cuh +++ b/cpp/include/raft/linalg/reduce.cuh @@ -31,6 +31,9 @@ namespace linalg { /** * @brief Compute reduction of the input matrix along the requested dimension + * In case of an add-reduction, a compensated summation will be performed + * in order to reduce numerical error. Note that the compensation will not + * be equivalent to a sequential compensation to preserve parallel efficiency. * * @tparam InType the data type of the input * @tparam OutType the data type of the output (as well as the data type for @@ -92,7 +95,9 @@ void reduce(OutType* dots, * is either row-major or column-major, while allowing the choose the * dimension for reduction. Depending upon the dimension chosen for * reduction, the memory accesses may be coalesced or strided. - * + * In case of an add-reduction, a compensated summation will be performed + * in order to reduce numerical error. Note that the compensation will not + * be equivalent to a sequential compensation to preserve parallel efficiency. * @tparam InElementType the input data-type of underlying raft::matrix_view * @tparam LayoutPolicy The layout of Input/Output (row or col major) * @tparam OutElementType the output data-type of underlying raft::matrix_view and reduction diff --git a/cpp/include/raft/stats/detail/mean.cuh b/cpp/include/raft/stats/detail/mean.cuh index 6c330acb26..ee39c87a68 100644 --- a/cpp/include/raft/stats/detail/mean.cuh +++ b/cpp/include/raft/stats/detail/mean.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -25,61 +26,23 @@ namespace raft { namespace stats { namespace detail { -///@todo: ColsPerBlk has been tested only for 32! -template -RAFT_KERNEL meanKernelRowMajor(Type* mu, const Type* data, IdxType D, IdxType N) -{ - const int RowsPerBlkPerIter = TPB / ColsPerBlk; - IdxType thisColId = threadIdx.x % ColsPerBlk; - IdxType thisRowId = threadIdx.x / ColsPerBlk; - IdxType colId = thisColId + ((IdxType)blockIdx.y * ColsPerBlk); - IdxType rowId = thisRowId + ((IdxType)blockIdx.x * RowsPerBlkPerIter); - Type thread_data = Type(0); - const IdxType stride = RowsPerBlkPerIter * gridDim.x; - for (IdxType i = rowId; i < N; i += stride) - thread_data += (colId < D) ? data[i * D + colId] : Type(0); - __shared__ Type smu[ColsPerBlk]; - if (threadIdx.x < ColsPerBlk) smu[threadIdx.x] = Type(0); - __syncthreads(); - raft::myAtomicAdd(smu + thisColId, thread_data); - __syncthreads(); - if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(mu + colId, smu[thisColId]); -} - -template -RAFT_KERNEL meanKernelColMajor(Type* mu, const Type* data, IdxType D, IdxType N) -{ - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - Type thread_data = Type(0); - IdxType colStart = N * blockIdx.x; - for (IdxType i = threadIdx.x; i < N; i += TPB) { - IdxType idx = colStart + i; - thread_data += data[idx]; - } - Type acc = BlockReduce(temp_storage).Sum(thread_data); - if (threadIdx.x == 0) { mu[blockIdx.x] = acc / N; } -} - template void mean( Type* mu, const Type* data, IdxType D, IdxType N, bool sample, bool rowMajor, cudaStream_t stream) { - static const int TPB = 256; - if (rowMajor) { - static const int RowsPerThread = 4; - static const int ColsPerBlk = 32; - static const int RowsPerBlk = (TPB / ColsPerBlk) * RowsPerThread; - dim3 grid(raft::ceildiv(N, (IdxType)RowsPerBlk), raft::ceildiv(D, (IdxType)ColsPerBlk)); - RAFT_CUDA_TRY(cudaMemsetAsync(mu, 0, sizeof(Type) * D, stream)); - meanKernelRowMajor<<>>(mu, data, D, N); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - Type ratio = Type(1) / (sample ? Type(N - 1) : Type(N)); - raft::linalg::scalarMultiply(mu, mu, ratio, D, stream); - } else { - meanKernelColMajor<<>>(mu, data, D, N); - } - RAFT_CUDA_TRY(cudaPeekAtLastError()); + Type ratio = Type(1) / ((sample) ? Type(N - 1) : Type(N)); + raft::linalg::reduce(mu, + data, + D, + N, + Type(0), + rowMajor, + false, + stream, + false, + raft::identity_op(), + raft::add_op(), + raft::mul_const_op(ratio)); } } // namespace detail diff --git a/cpp/include/raft/stats/detail/stddev.cuh b/cpp/include/raft/stats/detail/stddev.cuh index bc2644a233..4c861b49fb 100644 --- a/cpp/include/raft/stats/detail/stddev.cuh +++ b/cpp/include/raft/stats/detail/stddev.cuh @@ -16,7 +16,9 @@ #pragma once +#include #include +#include #include #include @@ -25,63 +27,6 @@ namespace raft { namespace stats { namespace detail { -///@todo: ColPerBlk has been tested only for 32! -template -RAFT_KERNEL stddevKernelRowMajor(Type* std, const Type* data, IdxType D, IdxType N) -{ - const int RowsPerBlkPerIter = TPB / ColsPerBlk; - IdxType thisColId = threadIdx.x % ColsPerBlk; - IdxType thisRowId = threadIdx.x / ColsPerBlk; - IdxType colId = thisColId + ((IdxType)blockIdx.y * ColsPerBlk); - IdxType rowId = thisRowId + ((IdxType)blockIdx.x * RowsPerBlkPerIter); - Type thread_data = Type(0); - const IdxType stride = RowsPerBlkPerIter * gridDim.x; - for (IdxType i = rowId; i < N; i += stride) { - Type val = (colId < D) ? data[i * D + colId] : Type(0); - thread_data += val * val; - } - __shared__ Type sstd[ColsPerBlk]; - if (threadIdx.x < ColsPerBlk) sstd[threadIdx.x] = Type(0); - __syncthreads(); - raft::myAtomicAdd(sstd + thisColId, thread_data); - __syncthreads(); - if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(std + colId, sstd[thisColId]); -} - -template -RAFT_KERNEL stddevKernelColMajor(Type* std, const Type* data, const Type* mu, IdxType D, IdxType N) -{ - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - Type thread_data = Type(0); - IdxType colStart = N * blockIdx.x; - Type m = mu[blockIdx.x]; - for (IdxType i = threadIdx.x; i < N; i += TPB) { - IdxType idx = colStart + i; - Type diff = data[idx] - m; - thread_data += diff * diff; - } - Type acc = BlockReduce(temp_storage).Sum(thread_data); - if (threadIdx.x == 0) { std[blockIdx.x] = raft::sqrt(acc / N); } -} - -template -RAFT_KERNEL varsKernelColMajor(Type* var, const Type* data, const Type* mu, IdxType D, IdxType N) -{ - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - Type thread_data = Type(0); - IdxType colStart = N * blockIdx.x; - Type m = mu[blockIdx.x]; - for (IdxType i = threadIdx.x; i < N; i += TPB) { - IdxType idx = colStart + i; - Type diff = data[idx] - m; - thread_data += diff * diff; - } - Type acc = BlockReduce(temp_storage).Sum(thread_data); - if (threadIdx.x == 0) { var[blockIdx.x] = acc / N; } -} - /** * @brief Compute stddev of the input matrix * @@ -110,26 +55,22 @@ void stddev(Type* std, bool rowMajor, cudaStream_t stream) { - static const int TPB = 256; - if (rowMajor) { - static const int RowsPerThread = 4; - static const int ColsPerBlk = 32; - static const int RowsPerBlk = (TPB / ColsPerBlk) * RowsPerThread; - dim3 grid(raft::ceildiv(N, (IdxType)RowsPerBlk), raft::ceildiv(D, (IdxType)ColsPerBlk)); - RAFT_CUDA_TRY(cudaMemset(std, 0, sizeof(Type) * D)); - stddevKernelRowMajor<<>>(std, data, D, N); - Type ratio = Type(1) / (sample ? Type(N - 1) : Type(N)); - raft::linalg::binaryOp( - std, - std, - mu, - D, - [ratio] __device__(Type a, Type b) { return raft::sqrt(a * ratio - b * b); }, - stream); - } else { - stddevKernelColMajor<<>>(std, data, mu, D, N); - } - RAFT_CUDA_TRY(cudaPeekAtLastError()); + raft::linalg::reduce( + std, data, D, N, Type(0), rowMajor, false, stream, false, [mu] __device__(Type a, IdxType i) { + return a * a; + }); + Type ratio = Type(1) / ((sample) ? Type(N - 1) : Type(N)); + Type ratio_mean = sample ? ratio * Type(N) : Type(1); + raft::linalg::binaryOp(std, + std, + mu, + D, + raft::compose_op(raft::sqrt_op(), + raft::abs_op(), + [ratio, ratio_mean] __device__(Type a, Type b) { + return a * ratio - b * b * ratio_mean; + }), + stream); } /** @@ -160,21 +101,21 @@ void vars(Type* var, bool rowMajor, cudaStream_t stream) { - static const int TPB = 256; - if (rowMajor) { - static const int RowsPerThread = 4; - static const int ColsPerBlk = 32; - static const int RowsPerBlk = (TPB / ColsPerBlk) * RowsPerThread; - dim3 grid(raft::ceildiv(N, (IdxType)RowsPerBlk), raft::ceildiv(D, (IdxType)ColsPerBlk)); - RAFT_CUDA_TRY(cudaMemset(var, 0, sizeof(Type) * D)); - stddevKernelRowMajor<<>>(var, data, D, N); - Type ratio = Type(1) / (sample ? Type(N - 1) : Type(N)); - raft::linalg::binaryOp( - var, var, mu, D, [ratio] __device__(Type a, Type b) { return a * ratio - b * b; }, stream); - } else { - varsKernelColMajor<<>>(var, data, mu, D, N); - } - RAFT_CUDA_TRY(cudaPeekAtLastError()); + raft::linalg::reduce( + var, data, D, N, Type(0), rowMajor, false, stream, false, [mu] __device__(Type a, IdxType i) { + return a * a; + }); + Type ratio = Type(1) / ((sample) ? Type(N - 1) : Type(N)); + Type ratio_mean = sample ? ratio * Type(N) : Type(1); + raft::linalg::binaryOp(var, + var, + mu, + D, + raft::compose_op(raft::abs_op(), + [ratio, ratio_mean] __device__(Type a, Type b) { + return a * ratio - b * b * ratio_mean; + }), + stream); } } // namespace detail diff --git a/cpp/include/raft/stats/detail/sum.cuh b/cpp/include/raft/stats/detail/sum.cuh index 4f85536e6c..39bd2c3b6c 100644 --- a/cpp/include/raft/stats/detail/sum.cuh +++ b/cpp/include/raft/stats/detail/sum.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -25,106 +26,10 @@ namespace raft { namespace stats { namespace detail { -///@todo: ColsPerBlk has been tested only for 32! -template -RAFT_KERNEL sumKernelRowMajor(Type* mu, const Type* data, IdxType D, IdxType N) -{ - const int RowsPerBlkPerIter = TPB / ColsPerBlk; - IdxType thisColId = threadIdx.x % ColsPerBlk; - IdxType thisRowId = threadIdx.x / ColsPerBlk; - IdxType colId = thisColId + ((IdxType)blockIdx.y * ColsPerBlk); - IdxType rowId = thisRowId + ((IdxType)blockIdx.x * RowsPerBlkPerIter); - Type thread_sum = Type(0); - const IdxType stride = RowsPerBlkPerIter * gridDim.x; - for (IdxType i = rowId; i < N; i += stride) { - thread_sum += (colId < D) ? data[i * D + colId] : Type(0); - } - __shared__ Type smu[ColsPerBlk]; - if (threadIdx.x < ColsPerBlk) smu[threadIdx.x] = Type(0); - __syncthreads(); - raft::myAtomicAdd(smu + thisColId, thread_sum); - __syncthreads(); - if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(mu + colId, smu[thisColId]); -} - -template -RAFT_KERNEL sumKahanKernelRowMajor(Type* mu, const Type* data, IdxType D, IdxType N) -{ - constexpr int RowsPerBlkPerIter = TPB / ColsPerBlk; - IdxType thisColId = threadIdx.x % ColsPerBlk; - IdxType thisRowId = threadIdx.x / ColsPerBlk; - IdxType colId = thisColId + ((IdxType)blockIdx.y * ColsPerBlk); - IdxType rowId = thisRowId + ((IdxType)blockIdx.x * RowsPerBlkPerIter); - Type thread_sum = Type(0); - Type thread_c = Type(0); - const IdxType stride = RowsPerBlkPerIter * gridDim.x; - for (IdxType i = rowId; i < N; i += stride) { - // KahanBabushkaNeumaierSum - const Type cur_value = (colId < D) ? data[i * D + colId] : Type(0); - const Type t = thread_sum + cur_value; - if (abs(thread_sum) >= abs(cur_value)) { - thread_c += (thread_sum - t) + cur_value; - } else { - thread_c += (cur_value - t) + thread_sum; - } - thread_sum = t; - } - thread_sum += thread_c; - __shared__ Type smu[ColsPerBlk]; - if (threadIdx.x < ColsPerBlk) smu[threadIdx.x] = Type(0); - __syncthreads(); - raft::myAtomicAdd(smu + thisColId, thread_sum); - __syncthreads(); - if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(mu + colId, smu[thisColId]); -} - -template -RAFT_KERNEL sumKahanKernelColMajor(Type* mu, const Type* data, IdxType D, IdxType N) -{ - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - Type thread_sum = Type(0); - Type thread_c = Type(0); - IdxType colStart = N * blockIdx.x; - for (IdxType i = threadIdx.x; i < N; i += TPB) { - // KahanBabushkaNeumaierSum - IdxType idx = colStart + i; - const Type cur_value = data[idx]; - const Type t = thread_sum + cur_value; - if (abs(thread_sum) >= abs(cur_value)) { - thread_c += (thread_sum - t) + cur_value; - } else { - thread_c += (cur_value - t) + thread_sum; - } - thread_sum = t; - } - thread_sum += thread_c; - Type acc = BlockReduce(temp_storage).Sum(thread_sum); - if (threadIdx.x == 0) { mu[blockIdx.x] = acc; } -} - template void sum(Type* output, const Type* input, IdxType D, IdxType N, bool rowMajor, cudaStream_t stream) { - static const int TPB = 256; - if (rowMajor) { - static const int ColsPerBlk = 8; - static const int MinRowsPerThread = 16; - static const int MinRowsPerBlk = (TPB / ColsPerBlk) * MinRowsPerThread; - static const int MaxBlocksDimX = 8192; - - const IdxType grid_y = raft::ceildiv(D, (IdxType)ColsPerBlk); - const IdxType grid_x = - raft::min((IdxType)MaxBlocksDimX, raft::ceildiv(N, (IdxType)MinRowsPerBlk)); - - dim3 grid(grid_x, grid_y); - RAFT_CUDA_TRY(cudaMemset(output, 0, sizeof(Type) * D)); - sumKahanKernelRowMajor - <<>>(output, input, D, N); - } else { - sumKahanKernelColMajor<<>>(output, input, D, N); - } - RAFT_CUDA_TRY(cudaPeekAtLastError()); + raft::linalg::reduce(output, input, D, N, Type(0), rowMajor, false, stream); } } // namespace detail diff --git a/cpp/test/stats/cov.cu b/cpp/test/stats/cov.cu index 41812979b6..602f356b9f 100644 --- a/cpp/test/stats/cov.cu +++ b/cpp/test/stats/cov.cu @@ -40,7 +40,8 @@ struct CovInputs { template ::std::ostream& operator<<(::std::ostream& os, const CovInputs& dims) { - return os; + return os << "{ " << dims.tolerance << ", " << dims.rows << ", " << dims.cols << ", " + << dims.sample << ", " << dims.rowMajor << "}" << std::endl; } template @@ -71,8 +72,7 @@ class CovTest : public ::testing::TestWithParam> { cov_act.resize(cols * cols, stream); normal(handle, r, data.data(), len, params.mean, var); - raft::stats::mean( - mean_act.data(), data.data(), cols, rows, params.sample, params.rowMajor, stream); + raft::stats::mean(mean_act.data(), data.data(), cols, rows, false, params.rowMajor, stream); if (params.rowMajor) { using layout = raft::row_major; cov(handle, @@ -102,7 +102,7 @@ class CovTest : public ::testing::TestWithParam> { raft::update_device(data_cm.data(), data_h, 6, stream); raft::update_device(cov_cm_ref.data(), cov_cm_ref_h, 4, stream); - raft::stats::mean(mean_cm.data(), data_cm.data(), 2, 3, true, false, stream); + raft::stats::mean(mean_cm.data(), data_cm.data(), 2, 3, false, false, stream); cov(handle, cov_cm.data(), data_cm.data(), mean_cm.data(), 2, 3, true, false, true, stream); } diff --git a/cpp/test/stats/mean.cu b/cpp/test/stats/mean.cu index 61b57ce739..c5fe83d95b 100644 --- a/cpp/test/stats/mean.cu +++ b/cpp/test/stats/mean.cu @@ -35,12 +35,14 @@ struct MeanInputs { int rows, cols; bool sample, rowMajor; unsigned long long int seed; + T stddev = (T)1.0; }; template ::std::ostream& operator<<(::std::ostream& os, const MeanInputs& dims) { - return os; + return os << "{ " << dims.tolerance << ", " << dims.rows << ", " << dims.cols << ", " + << dims.sample << ", " << dims.rowMajor << ", " << dims.stddev << "}" << std::endl; } template @@ -61,7 +63,7 @@ class MeanTest : public ::testing::TestWithParam> { { raft::random::RngState r(params.seed); int len = rows * cols; - normal(handle, r, data.data(), len, params.mean, (T)1.0); + normal(handle, r, data.data(), len, params.mean, params.stddev); meanSGtest(data.data(), stream); } @@ -96,38 +98,72 @@ class MeanTest : public ::testing::TestWithParam> { // measured mean (of a normal distribution) will fall outside of an epsilon of // 0.15 only 4/10000 times. (epsilon of 0.1 will fail 30/100 times) const std::vector> inputsf = { - {0.15f, 1.f, 1024, 32, true, false, 1234ULL}, {0.15f, 1.f, 1024, 64, true, false, 1234ULL}, - {0.15f, 1.f, 1024, 128, true, false, 1234ULL}, {0.15f, 1.f, 1024, 256, true, false, 1234ULL}, - {0.15f, -1.f, 1024, 32, false, false, 1234ULL}, {0.15f, -1.f, 1024, 64, false, false, 1234ULL}, - {0.15f, -1.f, 1024, 128, false, false, 1234ULL}, {0.15f, -1.f, 1024, 256, false, false, 1234ULL}, - {0.15f, 1.f, 1024, 32, true, true, 1234ULL}, {0.15f, 1.f, 1024, 64, true, true, 1234ULL}, - {0.15f, 1.f, 1024, 128, true, true, 1234ULL}, {0.15f, 1.f, 1024, 256, true, true, 1234ULL}, - {0.15f, -1.f, 1024, 32, false, true, 1234ULL}, {0.15f, -1.f, 1024, 64, false, true, 1234ULL}, - {0.15f, -1.f, 1024, 128, false, true, 1234ULL}, {0.15f, -1.f, 1024, 256, false, true, 1234ULL}, - {0.15f, -1.f, 1030, 1, false, false, 1234ULL}, {0.15f, -1.f, 1030, 60, true, false, 1234ULL}, - {2.0f, -1.f, 31, 120, false, false, 1234ULL}, {2.0f, -1.f, 1, 130, true, false, 1234ULL}, - {0.15f, -1.f, 1030, 1, false, true, 1234ULL}, {0.15f, -1.f, 1030, 60, true, true, 1234ULL}, - {2.0f, -1.f, 31, 120, false, true, 1234ULL}, {2.0f, -1.f, 1, 130, false, true, 1234ULL}, - {2.0f, -1.f, 1, 1, false, false, 1234ULL}, {2.0f, -1.f, 1, 1, false, true, 1234ULL}, - {2.0f, -1.f, 7, 23, false, false, 1234ULL}, {2.0f, -1.f, 7, 23, false, true, 1234ULL}, - {2.0f, -1.f, 17, 5, false, false, 1234ULL}, {2.0f, -1.f, 17, 5, false, true, 1234ULL}}; + {0.15f, 1.f, 1024, 32, true, false, 1234ULL}, + {0.15f, 1.f, 1024, 64, true, false, 1234ULL}, + {0.15f, 1.f, 1024, 128, true, false, 1234ULL}, + {0.15f, 1.f, 1024, 256, true, false, 1234ULL}, + {0.15f, -1.f, 1024, 32, false, false, 1234ULL}, + {0.15f, -1.f, 1024, 64, false, false, 1234ULL}, + {0.15f, -1.f, 1024, 128, false, false, 1234ULL}, + {0.15f, -1.f, 1024, 256, false, false, 1234ULL}, + {0.15f, 1.f, 1024, 32, true, true, 1234ULL}, + {0.15f, 1.f, 1024, 64, true, true, 1234ULL}, + {0.15f, 1.f, 1024, 128, true, true, 1234ULL}, + {0.15f, 1.f, 1024, 256, true, true, 1234ULL}, + {0.15f, -1.f, 1024, 32, false, true, 1234ULL}, + {0.15f, -1.f, 1024, 64, false, true, 1234ULL}, + {0.15f, -1.f, 1024, 128, false, true, 1234ULL}, + {0.15f, -1.f, 1024, 256, false, true, 1234ULL}, + {0.15f, -1.f, 1030, 1, false, false, 1234ULL}, + {0.15f, -1.f, 1030, 60, true, false, 1234ULL}, + {2.0f, -1.f, 31, 120, false, false, 1234ULL}, + {2.0f, -1.f, 1, 130, false, false, 1234ULL}, + {0.15f, -1.f, 1030, 1, false, true, 1234ULL}, + {0.15f, -1.f, 1030, 60, true, true, 1234ULL}, + {2.0f, -1.f, 31, 120, false, true, 1234ULL}, + {2.0f, -1.f, 1, 130, false, true, 1234ULL}, + {2.0f, -1.f, 1, 1, false, false, 1234ULL}, + {2.0f, -1.f, 1, 1, false, true, 1234ULL}, + {2.0f, -1.f, 7, 23, false, false, 1234ULL}, + {2.0f, -1.f, 7, 23, false, true, 1234ULL}, + {2.0f, -1.f, 17, 5, false, false, 1234ULL}, + {2.0f, -1.f, 17, 5, false, true, 1234ULL}, + {0.0001f, 0.1f, 1 << 27, 2, false, false, 1234ULL, 0.0001f}, + {0.0001f, 0.1f, 1 << 27, 2, false, true, 1234ULL, 0.0001f}}; const std::vector> inputsd = { - {0.15, 1.0, 1024, 32, true, false, 1234ULL}, {0.15, 1.0, 1024, 64, true, false, 1234ULL}, - {0.15, 1.0, 1024, 128, true, false, 1234ULL}, {0.15, 1.0, 1024, 256, true, false, 1234ULL}, - {0.15, -1.0, 1024, 32, false, false, 1234ULL}, {0.15, -1.0, 1024, 64, false, false, 1234ULL}, - {0.15, -1.0, 1024, 128, false, false, 1234ULL}, {0.15, -1.0, 1024, 256, false, false, 1234ULL}, - {0.15, 1.0, 1024, 32, true, true, 1234ULL}, {0.15, 1.0, 1024, 64, true, true, 1234ULL}, - {0.15, 1.0, 1024, 128, true, true, 1234ULL}, {0.15, 1.0, 1024, 256, true, true, 1234ULL}, - {0.15, -1.0, 1024, 32, false, true, 1234ULL}, {0.15, -1.0, 1024, 64, false, true, 1234ULL}, - {0.15, -1.0, 1024, 128, false, true, 1234ULL}, {0.15, -1.0, 1024, 256, false, true, 1234ULL}, - {0.15, -1.0, 1030, 1, false, false, 1234ULL}, {0.15, -1.0, 1030, 60, true, false, 1234ULL}, - {2.0, -1.0, 31, 120, false, false, 1234ULL}, {2.0, -1.0, 1, 130, true, false, 1234ULL}, - {0.15, -1.0, 1030, 1, false, true, 1234ULL}, {0.15, -1.0, 1030, 60, true, true, 1234ULL}, - {2.0, -1.0, 31, 120, false, true, 1234ULL}, {2.0, -1.0, 1, 130, false, true, 1234ULL}, - {2.0, -1.0, 1, 1, false, false, 1234ULL}, {2.0, -1.0, 1, 1, false, true, 1234ULL}, - {2.0, -1.0, 7, 23, false, false, 1234ULL}, {2.0, -1.0, 7, 23, false, true, 1234ULL}, - {2.0, -1.0, 17, 5, false, false, 1234ULL}, {2.0, -1.0, 17, 5, false, true, 1234ULL}}; + {0.15, 1.0, 1024, 32, true, false, 1234ULL}, + {0.15, 1.0, 1024, 64, true, false, 1234ULL}, + {0.15, 1.0, 1024, 128, true, false, 1234ULL}, + {0.15, 1.0, 1024, 256, true, false, 1234ULL}, + {0.15, -1.0, 1024, 32, false, false, 1234ULL}, + {0.15, -1.0, 1024, 64, false, false, 1234ULL}, + {0.15, -1.0, 1024, 128, false, false, 1234ULL}, + {0.15, -1.0, 1024, 256, false, false, 1234ULL}, + {0.15, 1.0, 1024, 32, true, true, 1234ULL}, + {0.15, 1.0, 1024, 64, true, true, 1234ULL}, + {0.15, 1.0, 1024, 128, true, true, 1234ULL}, + {0.15, 1.0, 1024, 256, true, true, 1234ULL}, + {0.15, -1.0, 1024, 32, false, true, 1234ULL}, + {0.15, -1.0, 1024, 64, false, true, 1234ULL}, + {0.15, -1.0, 1024, 128, false, true, 1234ULL}, + {0.15, -1.0, 1024, 256, false, true, 1234ULL}, + {0.15, -1.0, 1030, 1, false, false, 1234ULL}, + {0.15, -1.0, 1030, 60, true, false, 1234ULL}, + {2.0, -1.0, 31, 120, false, false, 1234ULL}, + {2.0, -1.0, 1, 130, false, false, 1234ULL}, + {0.15, -1.0, 1030, 1, false, true, 1234ULL}, + {0.15, -1.0, 1030, 60, true, true, 1234ULL}, + {2.0, -1.0, 31, 120, false, true, 1234ULL}, + {2.0, -1.0, 1, 130, false, true, 1234ULL}, + {2.0, -1.0, 1, 1, false, false, 1234ULL}, + {2.0, -1.0, 1, 1, false, true, 1234ULL}, + {2.0, -1.0, 7, 23, false, false, 1234ULL}, + {2.0, -1.0, 7, 23, false, true, 1234ULL}, + {2.0, -1.0, 17, 5, false, false, 1234ULL}, + {2.0, -1.0, 17, 5, false, true, 1234ULL}, + {1e-8, 1e-1, 1 << 27, 2, false, false, 1234ULL, 0.0001}, + {1e-8, 1e-1, 1 << 27, 2, false, true, 1234ULL, 0.0001}}; typedef MeanTest MeanTestF; TEST_P(MeanTestF, Result) diff --git a/cpp/test/stats/stddev.cu b/cpp/test/stats/stddev.cu index 641621c1c6..f4c5f92f49 100644 --- a/cpp/test/stats/stddev.cu +++ b/cpp/test/stats/stddev.cu @@ -39,7 +39,8 @@ struct StdDevInputs { template ::std::ostream& operator<<(::std::ostream& os, const StdDevInputs& dims) { - return os; + return os << "{ " << dims.tolerance << ", " << dims.rows << ", " << dims.cols << ", " + << dims.sample << ", " << dims.rowMajor << "}" << std::endl; } template @@ -81,7 +82,7 @@ class StdDevTest : public ::testing::TestWithParam> { mean(handle, raft::make_device_matrix_view(data, rows, cols), raft::make_device_vector_view(mean_act.data(), cols), - params.sample); + false); stddev(handle, raft::make_device_matrix_view(data, rows, cols), @@ -99,7 +100,7 @@ class StdDevTest : public ::testing::TestWithParam> { mean(handle, raft::make_device_matrix_view(data, rows, cols), raft::make_device_vector_view(mean_act.data(), cols), - params.sample); + false); stddev(handle, raft::make_device_matrix_view(data, rows, cols), @@ -147,13 +148,15 @@ const std::vector> inputsf = { {0.5f, -1.f, 2.f, 31, 1, true, true, 1234ULL}, {1.f, -1.f, 2.f, 1, 257, false, true, 1234ULL}, {0.5f, -1.f, 2.f, 31, 1, false, false, 1234ULL}, - {1.f, -1.f, 2.f, 1, 257, true, false, 1234ULL}, + {1.f, -1.f, 2.f, 1, 257, false, false, 1234ULL}, {1.f, -1.f, 2.f, 1, 1, false, false, 1234ULL}, {1.f, -1.f, 2.f, 7, 23, false, false, 1234ULL}, {1.f, -1.f, 2.f, 17, 5, false, false, 1234ULL}, {1.f, -1.f, 2.f, 1, 1, false, true, 1234ULL}, {1.f, -1.f, 2.f, 7, 23, false, true, 1234ULL}, - {1.f, -1.f, 2.f, 17, 5, false, true, 1234ULL}}; + {1.f, -1.f, 2.f, 17, 5, false, true, 1234ULL}, + {0.00001f, 0.001f, 0.f, 1 << 27, 2, false, true, 1234ULL}, + {0.00001f, 0.001f, 0.f, 1 << 27, 2, false, false, 1234ULL}}; const std::vector> inputsd = { {0.1, 1.0, 2.0, 1024, 32, true, false, 1234ULL}, @@ -177,13 +180,15 @@ const std::vector> inputsd = { {0.5, -1.0, 2.0, 31, 1, true, true, 1234ULL}, {1.0, -1.0, 2.0, 1, 257, false, true, 1234ULL}, {0.5, -1.0, 2.0, 31, 1, false, false, 1234ULL}, - {1.0, -1.0, 2.0, 1, 257, true, false, 1234ULL}, + {1.0, -1.0, 2.0, 1, 257, false, false, 1234ULL}, {1.0, -1.0, 2.0, 1, 1, false, false, 1234ULL}, {1.0, -1.0, 2.0, 7, 23, false, false, 1234ULL}, {1.0, -1.0, 2.0, 17, 5, false, false, 1234ULL}, {1.0, -1.0, 2.0, 1, 1, false, true, 1234ULL}, {1.0, -1.0, 2.0, 7, 23, false, true, 1234ULL}, - {1.0, -1.0, 2.0, 17, 5, false, true, 1234ULL}}; + {1.0, -1.0, 2.0, 17, 5, false, true, 1234ULL}, + {1e-7, 0.001, 0.0, 1 << 27, 2, false, true, 1234ULL}, + {1e-7, 0.001, 0.0, 1 << 27, 2, false, false, 1234ULL}}; typedef StdDevTest StdDevTestF; TEST_P(StdDevTestF, Result) diff --git a/cpp/test/stats/sum.cu b/cpp/test/stats/sum.cu index bf2aa44a2c..fbb398cc5b 100644 --- a/cpp/test/stats/sum.cu +++ b/cpp/test/stats/sum.cu @@ -40,7 +40,8 @@ struct SumInputs { template ::std::ostream& operator<<(::std::ostream& os, const SumInputs& dims) { - return os; + return os << "{ " << dims.tolerance << ", " << dims.rows << ", " << dims.cols << ", " + << dims.rowMajor << ", " << dims.value << "}" << std::endl; } template @@ -57,13 +58,31 @@ class SumTest : public ::testing::TestWithParam> { } protected: - void runTest() + void runTest(bool checkErrorCompensation = false) { int len = rows * cols; + double large_factor = 1e7; + + if constexpr (std::is_same_v) large_factor = 1e12; + std::vector data_h(len); for (int i = 0; i < len; i++) { - data_h[i] = T(params.value); + data_h[i] = double(params.value); + int row = params.rowMajor ? i / cols : i % rows; + + // every 3 elements (in a column) contain 2 large dummy elements + // (one of them negative) and one element with 3x compensating + // for the 2 missing elements + if (checkErrorCompensation && row % 3 == 2) { + data_h[i] = double(params.value) * large_factor; + // compensate with opposite error 3 rows up + int idx2 = params.rowMajor ? (i - cols) : (i - 1); + data_h[idx2] = -1 * double(params.value) * large_factor; + // compensate 2 missing values + int idx3 = params.rowMajor ? (i - 2 * cols) : (i - 2); + data_h[idx3] = 3.0 * double(params.value); + } } raft::update_device(data.data(), data_h.data(), len, stream); @@ -83,8 +102,10 @@ class SumTest : public ::testing::TestWithParam> { double expected = double(params.rows) * params.value; + double tolerance = checkErrorCompensation ? 100 * params.tolerance : params.tolerance; + ASSERT_TRUE(raft::devArrMatch( - T(expected), sum_act.data(), params.cols, raft::CompareApprox(params.tolerance))); + T(expected), sum_act.data(), params.cols, raft::CompareApprox(tolerance))); } protected: @@ -96,43 +117,29 @@ class SumTest : public ::testing::TestWithParam> { rmm::device_uvector data, sum_act; }; -const std::vector> inputsf = {{0.0001f, 4, 5, true, 1}, - {0.0001f, 1024, 32, true, 1}, - {0.0001f, 1024, 256, true, 1}, - {0.0001f, 100000000, 1, true, 0.001}, - {0.0001f, 1, 30, true, 0.001}, - {0.0001f, 1, 1, true, 0.001}, - {0.0001f, 17, 5, true, 0.001}, - {0.0001f, 7, 23, true, 0.001}, - {0.0001f, 3, 97, true, 0.001}, - {0.0001f, 4, 5, false, 1}, - {0.0001f, 1024, 32, false, 1}, - {0.0001f, 1024, 256, false, 1}, - {0.0001f, 100000000, 1, false, 0.001}, - {0.0001f, 1, 30, false, 0.001}, - {0.0001f, 1, 1, false, 0.001}, - {0.0001f, 17, 5, false, 0.001}, - {0.0001f, 7, 23, false, 0.001}, - {0.0001f, 3, 97, false, 0.001}}; - -const std::vector> inputsd = {{0.000001, 1024, 32, true, 1}, - {0.000001, 1024, 256, true, 1}, - {0.000001, 1024, 256, true, 1}, - {0.000001, 100000000, 1, true, 0.001}, - {0.000001, 1, 30, true, 0.0001}, - {0.000001, 1, 1, true, 0.0001}, - {0.000001, 17, 5, true, 0.0001}, - {0.000001, 7, 23, true, 0.0001}, - {0.000001, 3, 97, true, 0.0001}, - {0.000001, 1024, 32, false, 1}, - {0.000001, 1024, 256, false, 1}, - {0.000001, 1024, 256, false, 1}, - {0.000001, 100000000, 1, false, 0.001}, - {0.000001, 1, 30, false, 0.0001}, - {0.000001, 1, 1, false, 0.0001}, - {0.000001, 17, 5, false, 0.0001}, - {0.000001, 7, 23, false, 0.0001}, - {0.000001, 3, 97, false, 0.0001}}; +const std::vector> inputsf = { + {0.0001f, 4, 5, true, 1}, {0.0001f, 1024, 32, true, 1}, + {0.0001f, 1024, 256, true, 1}, {0.0001f, 100000000, 1, true, 0.001}, + {0.0001f, 1 << 27, 2, true, 0.1}, {0.0001f, 1, 30, true, 0.001}, + {0.0001f, 1, 1, true, 0.001}, {0.0001f, 17, 5, true, 0.001}, + {0.0001f, 7, 23, true, 0.001}, {0.0001f, 3, 97, true, 0.001}, + {0.0001f, 4, 5, false, 1}, {0.0001f, 1024, 32, false, 1}, + {0.0001f, 1024, 256, false, 1}, {0.0001f, 100000000, 1, false, 0.001}, + {0.0001f, 1 << 27, 2, false, 0.1}, {0.0001f, 1, 30, false, 0.001}, + {0.0001f, 1, 1, false, 0.001}, {0.0001f, 17, 5, false, 0.001}, + {0.0001f, 7, 23, false, 0.001}, {0.0001f, 3, 97, false, 0.001}}; + +const std::vector> inputsd = { + {0.000001, 1024, 32, true, 1}, {0.000001, 1024, 256, true, 1}, + {0.000001, 1024, 256, true, 1}, {0.000001, 100000000, 1, true, 0.001}, + {1e-9, 1 << 27, 2, true, 0.1}, {0.000001, 1, 30, true, 0.0001}, + {0.000001, 1, 1, true, 0.0001}, {0.000001, 17, 5, true, 0.0001}, + {0.000001, 7, 23, true, 0.0001}, {0.000001, 3, 97, true, 0.0001}, + {0.000001, 1024, 32, false, 1}, {0.000001, 1024, 256, false, 1}, + {0.000001, 1024, 256, false, 1}, {0.000001, 100000000, 1, false, 0.001}, + {1e-9, 1 << 27, 2, false, 0.1}, {0.000001, 1, 30, false, 0.0001}, + {0.000001, 1, 1, false, 0.0001}, {0.000001, 17, 5, false, 0.0001}, + {0.000001, 7, 23, false, 0.0001}, {0.000001, 3, 97, false, 0.0001}}; typedef SumTest SumTestF; typedef SumTest SumTestD; @@ -140,6 +147,9 @@ typedef SumTest SumTestD; TEST_P(SumTestF, Result) { runTest(); } TEST_P(SumTestD, Result) { runTest(); } +TEST_P(SumTestF, Accuracy) { runTest(true); } +TEST_P(SumTestD, Accuracy) { runTest(true); } + INSTANTIATE_TEST_CASE_P(SumTests, SumTestF, ::testing::ValuesIn(inputsf)); INSTANTIATE_TEST_CASE_P(SumTests, SumTestD, ::testing::ValuesIn(inputsd));