Skip to content

Commit

Permalink
Handle execute_on_stream in binary_search algorithms
Browse files Browse the repository at this point in the history
The stream information got lost when using the `thrust::reference` API
to pass inputs/outputs to/from the device.

Fixes NVIDIA#921
Bug 2173437
  • Loading branch information
alliepiper committed May 15, 2020
1 parent d09ba09 commit 9a3cfbf
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 8 deletions.
25 changes: 25 additions & 0 deletions testing/cuda/binary_search.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include <unittest/unittest.h>

#include <thrust/binary_search.h>
#include <thrust/device_vector.h>
#include <thrust/distance.h>
#include <thrust/pair.h>
#include <thrust/sequence.h>

void TestEqualRangeOnStream()
{ // Regression test for GH issue #921 (nvbug 2173437)
typedef typename thrust::device_vector<int> vector_t;
typedef typename vector_t::iterator iterator_t;
typedef thrust::pair<iterator_t, iterator_t> result_t;

vector_t input(10);
thrust::sequence(thrust::device, input.begin(), input.end(), 0);
cudaStream_t stream = 0;
result_t result = thrust::equal_range(thrust::cuda::par.on(stream),
input.begin(), input.end(),
5);

ASSERT_EQUAL(5, thrust::distance(input.begin(), result.first));
ASSERT_EQUAL(6, thrust::distance(input.begin(), result.second));
}
DECLARE_UNITTEST(TestEqualRangeOnStream);
1 change: 1 addition & 0 deletions testing/cuda/binary_search.mk
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CUDACC_FLAGS += -rdc=true
32 changes: 24 additions & 8 deletions thrust/system/detail/generic/binary_search.inl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <thrust/for_each.h>
#include <thrust/detail/function.h>
#include <thrust/system/detail/generic/scalar/binary_search.h>
#include <thrust/system/detail/generic/select_system.h>

#include <thrust/detail/temporary_array.h>
#include <thrust/detail/type_traits.h>
Expand Down Expand Up @@ -150,19 +151,34 @@ OutputType binary_search(thrust::execution_policy<DerivedPolicy> &exec,
BinarySearchFunction func)
{
// use the vectorized path to implement the scalar version

// allocate device buffers for value and output
thrust::detail::temporary_array<T,DerivedPolicy> d_value(exec,1);
thrust::detail::temporary_array<OutputType,DerivedPolicy> d_output(exec,1);

// copy value to device
d_value[0] = value;


{ // copy value to device
typedef typename thrust::iterator_system<const T*>::type value_in_system_t;
value_in_system_t value_in_system;
using thrust::system::detail::generic::select_system;
thrust::copy_n(select_system(thrust::detail::derived_cast(thrust::detail::strip_const(value_in_system)),
thrust::detail::derived_cast(thrust::detail::strip_const(exec))),
&value, 1, d_value.begin());
}

// perform the query
thrust::system::detail::generic::detail::binary_search(exec, begin, end, d_value.begin(), d_value.end(), d_output.begin(), comp, func);

// copy result to host and return
return d_output[0];

OutputType output;
{ // copy result to host and return
typedef typename thrust::iterator_system<OutputType*>::type result_out_system_t;
result_out_system_t result_out_system;
using thrust::system::detail::generic::select_system;
thrust::copy_n(select_system(thrust::detail::derived_cast(thrust::detail::strip_const(exec)),
thrust::detail::derived_cast(thrust::detail::strip_const(result_out_system))),
d_output.begin(), 1, &output);
}

return output;
}


Expand Down

0 comments on commit 9a3cfbf

Please sign in to comment.