From f313fa0f29153f0208eec8ad25838df3f6656d9b Mon Sep 17 00:00:00 2001 From: Georgy Evtushenko Date: Tue, 15 Nov 2022 23:25:50 +0400 Subject: [PATCH] Fix overflow in reduce --- cub/agent/agent_reduce.cuh | 9 ++++++--- test/test_device_reduce.cu | 14 +++++++------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/cub/agent/agent_reduce.cuh b/cub/agent/agent_reduce.cuh index aa18efa07e..62523d5c3f 100644 --- a/cub/agent/agent_reduce.cuh +++ b/cub/agent/agent_reduce.cuh @@ -355,7 +355,7 @@ struct AgentReduce { AccumT thread_aggregate{}; - if (even_share.block_offset + TILE_ITEMS > even_share.block_end) + if (even_share.block_end - even_share.block_offset < TILE_ITEMS) { // First tile isn't full (not all threads have valid items) int valid_items = even_share.block_end - even_share.block_offset; @@ -377,14 +377,17 @@ struct AgentReduce even_share.block_offset += even_share.block_stride; // Consume subsequent full tiles of input - while (even_share.block_offset + TILE_ITEMS <= even_share.block_end) + while (even_share.block_offset <= even_share.block_end - TILE_ITEMS) { ConsumeTile(thread_aggregate, even_share.block_offset, TILE_ITEMS, Int2Type(), can_vectorize); - even_share.block_offset += even_share.block_stride; + const OffsetT new_offset = even_share.block_offset + even_share.block_stride; + even_share.block_offset = (new_offset > even_share.block_offset) + ? new_offset + : even_share.block_end; } // Consume a partially-full tile diff --git a/test/test_device_reduce.cu b/test/test_device_reduce.cu index b3df906d12..113dbd2e77 100644 --- a/test/test_device_reduce.cu +++ b/test/test_device_reduce.cu @@ -1333,10 +1333,10 @@ __global__ void InitializeTestAccumulatorTypes(int num_items, } } -template -void TestBigIndicesHelper(int magnitude) +template +void TestBigIndicesHelper(OffsetT num_items) { - const std::size_t num_items = 1ll << magnitude; thrust::constant_iterator const_iter(T{1}); thrust::device_vector out(1); std::size_t* d_out = thrust::raw_pointer_cast(out.data()); @@ -1360,10 +1360,10 @@ void TestBigIndicesHelper(int magnitude) template void TestBigIndices() { - TestBigIndicesHelper(30); - TestBigIndicesHelper(31); - TestBigIndicesHelper(32); - TestBigIndicesHelper(33); + TestBigIndicesHelper(1ull << 30); + TestBigIndicesHelper(1ull << 31); + TestBigIndicesHelper((1ull << 32) - 1); + TestBigIndicesHelper(1ull << 33); } void TestAccumulatorTypes()