This repository has been archived by the owner on Mar 21, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 757
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
945cd09
commit 45c8380
Showing
9 changed files
with
1,223 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
#include <thrust/detail/config.h> | ||
|
||
#if THRUST_CPP_DIALECT >= 2014 | ||
|
||
#include <unittest/unittest.h> | ||
#include <unittest/util_async.h> | ||
|
||
#include <thrust/async/scan.h> | ||
|
||
#include <thrust/device_vector.h> | ||
#include <thrust/host_vector.h> | ||
|
||
// TODO Finish implementing tests. Draw from other async algos, as well as | ||
// the older scan tests. | ||
|
||
namespace | ||
{ | ||
|
||
template <typename value_type> | ||
struct async_exclusive_scan_def | ||
{ | ||
public: | ||
// Input and output types for the algorithms: | ||
using input_type = thrust::device_vector<value_type>; | ||
using output_type = thrust::device_vector<value_type>; | ||
|
||
using postfix_args_type = std::tuple< // List any extra arg overloads: | ||
std::tuple<>, // - no extra args | ||
std::tuple<value_type>, // - initial_value | ||
std::tuple<value_type, thrust::maximum<>> // - initial_value, binary_op | ||
>; | ||
|
||
// Create instances of the extra arguments to use when invoking the | ||
// algorithm: | ||
static postfix_args_type generate_postfix_args() | ||
{ | ||
return { | ||
{}, // no extra args | ||
{42}, // initial_value | ||
{57, thrust::maximum<>{}} // initial_value, binary_op | ||
}; | ||
} | ||
|
||
// Generate an instance of the input: | ||
static input_type generate_input() | ||
{ | ||
input_type input(1024); | ||
thrust::sequence(input.begin(), input.end(), 25, 3); | ||
return input; | ||
} | ||
|
||
// Invoke a reference implementation for a single overload as described by | ||
// postfix_tuple. This tuple contains instances of any additional arguments | ||
// to pass to the algorithm. The tuple/index_sequence pattern is used to | ||
// support the "no extra args" overload, since the parameter pack expansion | ||
// will do exactly what we want in all cases. | ||
template <typename PostfixArgTuple, std::size_t... PostfixArgIndices> | ||
static void invoke_reference(PostfixArgTuple &&postfix_tuple, | ||
std::index_sequence<PostfixArgIndices...>, | ||
input_type const &input, | ||
output_type &output) | ||
{ | ||
// Create host versions of the input/output: | ||
thrust::host_vector<value_type> host_input(input); | ||
thrust::host_vector<value_type> host_output(input.size()); | ||
|
||
// Run host synchronous algorithm to generate reference. | ||
thrust::exclusive_scan(host_input.cbegin(), | ||
host_input.cend(), | ||
host_output.begin(), | ||
std::get<PostfixArgIndices>( | ||
THRUST_FWD(postfix_tuple))...); | ||
|
||
// Copy back to device. | ||
output = host_output; | ||
} | ||
|
||
// Invoke the async algorithm for a single overload as described by | ||
// the prefix and postfix tuples. These tuple contains instances of any | ||
// additional arguments to pass to the algorithm. The tuple/index_sequence | ||
// pattern is used to support the "no extra args" overload, since the | ||
// parameter pack expansion will do exactly what we want in all cases. | ||
// Prefix args are included here (but not for invoke_reference) to allow the | ||
// test framework to change the execution policy. | ||
// This method must return an event or future. | ||
template <typename PrefixArgTuple, | ||
std::size_t... PrefixArgIndices, | ||
typename PostfixArgTuple, | ||
std::size_t... PostfixArgIndices> | ||
static auto invoke_async(PrefixArgTuple &&prefix_tuple, | ||
std::index_sequence<PrefixArgIndices...>, | ||
PostfixArgTuple &&postfix_tuple, | ||
std::index_sequence<PostfixArgIndices...>, | ||
input_type const &input, | ||
output_type &output) | ||
{ | ||
output.resize(input.size()); | ||
auto e = thrust::async::exclusive_scan( | ||
std::get<PrefixArgIndices>(THRUST_FWD(prefix_tuple))..., | ||
input.cbegin(), | ||
input.cend(), | ||
output.begin(), | ||
std::get<PostfixArgIndices>(THRUST_FWD(postfix_tuple))...); | ||
return e; | ||
} | ||
|
||
// Wait on and validate the event/future (usually with TEST_EVENT_WAIT / | ||
// TEST_FUTURE_VALUE_RETRIEVAL), then check that the reference output matches | ||
// the testing output. | ||
template <typename EventType> | ||
static void compare_outputs(EventType &e, | ||
output_type const &ref, | ||
output_type const &test) | ||
{ | ||
TEST_EVENT_WAIT(e); | ||
ASSERT_EQUAL_QUIET(ref, test); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void TestPolicyOverloads() | ||
{ | ||
// Only ints are tested here because we just want to check that the policies | ||
// are propagated correctly, so keep codegen to a minimum. | ||
unittest::test_async_policy_overloads<async_exclusive_scan_def<int>>::run(); | ||
} | ||
DECLARE_UNITTEST(TestPolicyOverloads); | ||
|
||
#endif // C++14 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
#include <thrust/detail/config.h> | ||
|
||
#if THRUST_CPP_DIALECT >= 2014 | ||
|
||
#include <unittest/unittest.h> | ||
#include <unittest/util_async.h> | ||
|
||
#include <thrust/async/scan.h> | ||
|
||
#include <thrust/device_vector.h> | ||
#include <thrust/host_vector.h> | ||
|
||
// TODO Finish implementing tests. Draw from other async algos, as well as | ||
// the older scan tests. | ||
|
||
namespace | ||
{ | ||
|
||
template <typename value_type> | ||
struct async_inclusive_scan_def | ||
{ | ||
public: | ||
using input_type = thrust::device_vector<value_type>; | ||
using output_type = thrust::device_vector<value_type>; | ||
|
||
using postfix_args_type = std::tuple< // List any extra arg overloads: | ||
std::tuple<>, // - no extra args | ||
std::tuple<thrust::maximum<>> // - Non-default binary-op | ||
>; | ||
|
||
static postfix_args_type generate_postfix_args() | ||
{ | ||
return { | ||
{}, // - no extra args | ||
{thrust::maximum<>{}} // - non-default binary_op | ||
}; | ||
} | ||
|
||
static input_type generate_input() | ||
{ | ||
input_type input(1024); | ||
thrust::sequence(input.begin(), input.end(), 25, 3); | ||
return input; | ||
} | ||
|
||
template <typename PostfixArgTuple, std::size_t... PostfixArgIndices> | ||
static void invoke_reference(PostfixArgTuple &&postfix_tuple, | ||
std::index_sequence<PostfixArgIndices...>, | ||
input_type const &input, | ||
output_type &output) | ||
{ | ||
// Create host versions of the input/output: | ||
thrust::host_vector<value_type> host_input(input); | ||
thrust::host_vector<value_type> host_output(input.size()); | ||
|
||
// Run host synchronous algorithm to generate reference. | ||
thrust::inclusive_scan(host_input.cbegin(), | ||
host_input.cend(), | ||
host_output.begin(), | ||
std::get<PostfixArgIndices>( | ||
THRUST_FWD(postfix_tuple))...); | ||
|
||
// Copy back to device. | ||
output = host_output; | ||
} | ||
|
||
template <typename PrefixArgTuple, | ||
std::size_t... PrefixArgIndices, | ||
typename PostfixArgTuple, | ||
std::size_t... PostfixArgIndices> | ||
static auto invoke_async(PrefixArgTuple &&prefix_tuple, | ||
std::index_sequence<PrefixArgIndices...>, | ||
PostfixArgTuple &&postfix_tuple, | ||
std::index_sequence<PostfixArgIndices...>, | ||
input_type const &input, | ||
output_type &output) | ||
{ | ||
output.resize(input.size()); | ||
auto e = thrust::async::inclusive_scan( | ||
std::get<PrefixArgIndices>(THRUST_FWD(prefix_tuple))..., | ||
input.cbegin(), | ||
input.cend(), | ||
output.begin(), | ||
std::get<PostfixArgIndices>(THRUST_FWD(postfix_tuple))...); | ||
return e; | ||
} | ||
|
||
template <typename EventType> | ||
static void compare_outputs(EventType &e, | ||
output_type const &ref, | ||
output_type const &test) | ||
{ | ||
TEST_EVENT_WAIT(e); | ||
ASSERT_EQUAL_QUIET(ref, test); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void TestPolicyOverloads() | ||
{ | ||
// Only ints are tested here because we just want to check that the policies | ||
// are propagated correctly, so keep codegen to a minimum. | ||
unittest::test_async_policy_overloads<async_inclusive_scan_def<int>>::run(); | ||
} | ||
DECLARE_UNITTEST(TestPolicyOverloads); | ||
|
||
#endif // C++14 |
Oops, something went wrong.