Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request #1832 from senior-zero/fix-main/github/reduce_offsets
Browse files Browse the repository at this point in the history
Let CUB select reduce offsets
  • Loading branch information
gevtushenko authored Nov 25, 2022
2 parents d4f3fa9 + af899e3 commit 9f443fd
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
20 changes: 20 additions & 0 deletions testing/cuda/reduce.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <unittest/unittest.h>
#include <thrust/reduce.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/constant_iterator.h>


template<typename ExecutionPolicy, typename Iterator, typename T, typename Iterator2>
Expand Down Expand Up @@ -99,3 +100,22 @@ void TestReduceCudaStreamsNoSync()
}
DECLARE_UNITTEST(TestReduceCudaStreamsNoSync);

#if defined(THRUST_RDC_ENABLED)
void TestReduceLargeInput()
{
using T = unsigned long long;
using OffsetT = std::size_t;
const OffsetT num_items = 1ull << 32;

thrust::constant_iterator<T> d_data(T{1});
thrust::device_vector<T> d_result(1);

reduce_kernel<<<1,1>>>(thrust::device, d_data, d_data + num_items, T{}, d_result.begin());
cudaError_t const err = cudaDeviceSynchronize();
ASSERT_EQUAL(cudaSuccess, err);

ASSERT_EQUAL(num_items, d_result[0]);
}
DECLARE_UNITTEST(TestReduceLargeInput);
#endif

10 changes: 2 additions & 8 deletions thrust/system/cuda/detail/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -943,11 +943,8 @@ T reduce_n_impl(execution_policy<Derived>& policy,

size_t tmp_size = 0;

THRUST_INDEX_TYPE_DISPATCH2(status,
THRUST_INDEX_TYPE_DISPATCH(status,
cub::DeviceReduce::Reduce,
(cub::DispatchReduce<
InputIt, T*, Size, BinaryOp, T
>::Dispatch),
num_items,
(NULL, tmp_size, first, reinterpret_cast<T*>(NULL),
num_items_fixed, binary_op, init, stream));
Expand All @@ -970,11 +967,8 @@ T reduce_n_impl(execution_policy<Derived>& policy,
// make this guarantee.
T* ret_ptr = thrust::detail::aligned_reinterpret_cast<T*>(tmp.data().get());
void* tmp_ptr = static_cast<void*>((tmp.data() + sizeof(T)).get());
THRUST_INDEX_TYPE_DISPATCH2(status,
THRUST_INDEX_TYPE_DISPATCH(status,
cub::DeviceReduce::Reduce,
(cub::DispatchReduce<
InputIt, T*, Size, BinaryOp, T
>::Dispatch),
num_items,
(tmp_ptr, tmp_size, first, ret_ptr,
num_items_fixed, binary_op, init, stream));
Expand Down

0 comments on commit 9f443fd

Please sign in to comment.