diff --git a/thrust/system/cuda/detail/reduce.h b/thrust/system/cuda/detail/reduce.h index 793f0624d..db84bf439 100644 --- a/thrust/system/cuda/detail/reduce.h +++ b/thrust/system/cuda/detail/reduce.h @@ -939,20 +939,19 @@ reduce_n(execution_policy &policy, if (__THRUST_HAS_CUDART__) { - detail::temporary_array ret(policy, 1); - // Determine temporary device storage requirements. + T* ret_ptr = NULL; size_t tmp_size = 0; cuda_cub::throw_on_error( cub::DeviceReduce::Reduce(NULL, tmp_size, - first, ret.begin(), num_items, binary_op, init, + first, ret_ptr, num_items, binary_op, init, stream, THRUST_DEBUG_SYNC_FLAG), "after reduction step 1"); // Allocate temporary storage. - detail::temporary_array tmp(policy, tmp_size); + detail::temporary_array tmp(policy, sizeof(T) + tmp_size); // Run reduction. @@ -960,21 +959,24 @@ reduce_n(execution_policy &policy, // `reference`, which has an `operator&` that returns a `pointer`, which // has a `.get` method that returns a raw pointer, which we can (finally) // `static_cast` to `void*`. - void* tmp_ptr = static_cast((&*tmp.begin()).get()); + ret_ptr = reinterpret_cast((&*tmp.begin()).get()); + void* tmp_ptr = static_cast((&*(tmp.begin() + sizeof(T))).get()); cuda_cub::throw_on_error( cub::DeviceReduce::Reduce(tmp_ptr, tmp_size, - first, ret.begin(), num_items, binary_op, init, + first, ret_ptr, num_items, binary_op, init, stream, THRUST_DEBUG_SYNC_FLAG), "after reduction step 2"); + // Synchronize the stream and get the value. + cuda_cub::throw_on_error(cuda_cub::synchronize(policy), "reduce failed to synchronize"); - // `ret.begin()` yields a `normal_iterator`, which dereferences to a + // `tmp.begin()` yields a `normal_iterator`, which dereferences to a // `reference`, which has an `operator&` that returns a `pointer`, which // has a `.get` method that returns a raw pointer, which we can (finally) // `static_cast` to `void*`. - return cuda_cub::get_value(policy, (&*ret.begin()).get()); + return cuda_cub::get_value(policy, reinterpret_cast((&*tmp.begin()).get())); } #if !__THRUST_HAS_CUDART__