Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Let CUB select reduce offsets #1832

Merged
merged 1 commit into from
Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions testing/cuda/reduce.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <unittest/unittest.h>
#include <thrust/reduce.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/constant_iterator.h>


template<typename ExecutionPolicy, typename Iterator, typename T, typename Iterator2>
Expand Down Expand Up @@ -99,3 +100,22 @@ void TestReduceCudaStreamsNoSync()
}
DECLARE_UNITTEST(TestReduceCudaStreamsNoSync);

#if defined(THRUST_RDC_ENABLED)
void TestReduceLargeInput()
{
using T = unsigned long long;
using OffsetT = std::size_t;
const OffsetT num_items = 1ull << 32;

thrust::constant_iterator<T> d_data(T{1});
thrust::device_vector<T> d_result(1);

reduce_kernel<<<1,1>>>(thrust::device, d_data, d_data + num_items, T{}, d_result.begin());
cudaError_t const err = cudaDeviceSynchronize();
ASSERT_EQUAL(cudaSuccess, err);

ASSERT_EQUAL(num_items, d_result[0]);
}
DECLARE_UNITTEST(TestReduceLargeInput);
#endif

10 changes: 2 additions & 8 deletions thrust/system/cuda/detail/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -943,11 +943,8 @@ T reduce_n_impl(execution_policy<Derived>& policy,

size_t tmp_size = 0;

THRUST_INDEX_TYPE_DISPATCH2(status,
THRUST_INDEX_TYPE_DISPATCH(status,
cub::DeviceReduce::Reduce,
(cub::DispatchReduce<
InputIt, T*, Size, BinaryOp, T
>::Dispatch),
num_items,
(NULL, tmp_size, first, reinterpret_cast<T*>(NULL),
num_items_fixed, binary_op, init, stream));
Expand All @@ -970,11 +967,8 @@ T reduce_n_impl(execution_policy<Derived>& policy,
// make this guarantee.
T* ret_ptr = thrust::detail::aligned_reinterpret_cast<T*>(tmp.data().get());
void* tmp_ptr = static_cast<void*>((tmp.data() + sizeof(T)).get());
THRUST_INDEX_TYPE_DISPATCH2(status,
THRUST_INDEX_TYPE_DISPATCH(status,
cub::DeviceReduce::Reduce,
(cub::DispatchReduce<
InputIt, T*, Size, BinaryOp, T
>::Dispatch),
num_items,
(tmp_ptr, tmp_size, first, ret_ptr,
num_items_fixed, binary_op, init, stream));
Expand Down