Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Update scan accum / binary_op edgecase handling.
Browse files Browse the repository at this point in the history
TBB's scan was implemented differently than the other backends, leading
to some failing unit tests.

This patch fixes these inconsistencies by making the following changes:

- Follow P0571's guidance regarding accumulator variable type.
  - https://wg21.link/P0571
  - The accumulator's type is now:
    - The type of the user-supplied initial value (if provided), or
    - The input iterator's value type if no initial value.
- Follow C++ standard guidance for default binary operator type.
  - https://eel.is/c++draft/exclusive.scan#1
  - Thrust binary/unary functors now specialize a default void template
    parameter. Types are deduced and forwarded transparently.
  - Updated the scan's default binary operator to the new
    `thrust::plus<>` specialization.
- The `intermediate_type_from_function_and_iterators` helper is no
  longer needed and has been removed.

Closes #1170.
  • Loading branch information
alliepiper committed Jun 8, 2020
1 parent e478243 commit b528451
Show file tree
Hide file tree
Showing 11 changed files with 335 additions and 363 deletions.
73 changes: 37 additions & 36 deletions testing/scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -250,48 +250,49 @@ void TestScanMixedTypes(void)

IntVector int_output(4);
FloatVector float_output(4);
// float -> int should use using plus<int> operator by default

// float -> int should use plus<void> operator and float accumulator by default
thrust::inclusive_scan(float_input.begin(), float_input.end(), int_output.begin());
ASSERT_EQUAL(int_output[0], 1);
ASSERT_EQUAL(int_output[1], 3);
ASSERT_EQUAL(int_output[2], 6);
ASSERT_EQUAL(int_output[3], 10);
// float -> float with plus<int> operator (int accumulator)
ASSERT_EQUAL(int_output[0], 1); // in: 1.5 accum: 1.5f out: 1
ASSERT_EQUAL(int_output[1], 4); // in: 2.5 accum: 4.0f out: 4
ASSERT_EQUAL(int_output[2], 7); // in: 3.5 accum: 7.5f out: 7
ASSERT_EQUAL(int_output[3], 12); // in: 4.5 accum: 12.f out: 12

// float -> float with plus<int> operator (float accumulator)
thrust::inclusive_scan(float_input.begin(), float_input.end(), float_output.begin(), thrust::plus<int>());
ASSERT_EQUAL(float_output[0], 1.5);
ASSERT_EQUAL(float_output[1], 3.0);
ASSERT_EQUAL(float_output[2], 6.0);
ASSERT_EQUAL(float_output[3], 10.0);
// float -> int should use using plus<int> operator by default
ASSERT_EQUAL(float_output[0], 1.5f); // in: 1.5 accum: 1.5f out: 1.5f
ASSERT_EQUAL(float_output[1], 3.0f); // in: 2.5 accum: 3.0f out: 3.0f
ASSERT_EQUAL(float_output[2], 6.0f); // in: 3.5 accum: 6.0f out: 6.0f
ASSERT_EQUAL(float_output[3], 10.0f); // in: 4.5 accum: 10.f out: 10.f

// float -> int should use plus<void> operator and float accumulator by default
thrust::exclusive_scan(float_input.begin(), float_input.end(), int_output.begin());
ASSERT_EQUAL(int_output[0], 0);
ASSERT_EQUAL(int_output[1], 1);
ASSERT_EQUAL(int_output[2], 3);
ASSERT_EQUAL(int_output[3], 6);
// float -> int should use using plus<int> operator by default
ASSERT_EQUAL(int_output[0], 0); // out: 0.0f in: 1.5 accum: 1.5f
ASSERT_EQUAL(int_output[1], 1); // out: 1.5f in: 2.5 accum: 4.0f
ASSERT_EQUAL(int_output[2], 4); // out: 4.0f in: 3.5 accum: 7.5f
ASSERT_EQUAL(int_output[3], 7); // out: 7.5f in: 4.5 accum: 12.f

// float -> int should use plus<> operator and float accumulator by default
thrust::exclusive_scan(float_input.begin(), float_input.end(), int_output.begin(), (float) 5.5);
ASSERT_EQUAL(int_output[0], 5);
ASSERT_EQUAL(int_output[1], 7);
ASSERT_EQUAL(int_output[2], 9);
ASSERT_EQUAL(int_output[3], 13);
// int -> float should use using plus<float> operator by default
ASSERT_EQUAL(int_output[0], 5); // out: 5.5f in: 1.5 accum: 7.0f
ASSERT_EQUAL(int_output[1], 7); // out: 7.0f in: 2.5 accum: 9.5f
ASSERT_EQUAL(int_output[2], 9); // out: 9.5f in: 3.5 accum: 13.0f
ASSERT_EQUAL(int_output[3], 13); // out: 13.f in: 4.5 accum: 17.4f

// int -> float should use using plus<> operator and int accumulator by default
thrust::inclusive_scan(int_input.begin(), int_input.end(), float_output.begin());
ASSERT_EQUAL(float_output[0], 1.0);
ASSERT_EQUAL(float_output[1], 3.0);
ASSERT_EQUAL(float_output[2], 6.0);
ASSERT_EQUAL(float_output[3], 10.0);

// int -> float should use using plus<float> operator by default
ASSERT_EQUAL(float_output[0], 1.f); // in: 1 accum: 1 out: 1
ASSERT_EQUAL(float_output[1], 3.f); // in: 2 accum: 3 out: 3
ASSERT_EQUAL(float_output[2], 6.f); // in: 3 accum: 6 out: 6
ASSERT_EQUAL(float_output[3], 10.f); // in: 4 accum: 10 out: 10

// int -> float + float init_value should use using plus<> operator and
// float accumulator by default
thrust::exclusive_scan(int_input.begin(), int_input.end(), float_output.begin(), (float) 5.5);
ASSERT_EQUAL(float_output[0], 5.5);
ASSERT_EQUAL(float_output[1], 6.5);
ASSERT_EQUAL(float_output[2], 8.5);
ASSERT_EQUAL(float_output[3], 11.5);
ASSERT_EQUAL(float_output[0], 5.5f); // out: 5.5f in: 1 accum: 6.5f
ASSERT_EQUAL(float_output[1], 6.5f); // out: 6.0f in: 2 accum: 8.5f
ASSERT_EQUAL(float_output[2], 8.5f); // out: 8.0f in: 3 accum: 11.5f
ASSERT_EQUAL(float_output[3], 11.5f); // out: 11.f in: 4 accum: 15.5f
}
void TestScanMixedTypesHost(void)
{
Expand Down

This file was deleted.

Loading

0 comments on commit b528451

Please sign in to comment.