diff --git a/cpp/tests/iterator/iterator_tests.cuh b/cpp/tests/iterator/iterator_tests.cuh index 4ec347c4bc1..07eb595449c 100644 --- a/cpp/tests/iterator/iterator_tests.cuh +++ b/cpp/tests/iterator/iterator_tests.cuh @@ -18,8 +18,8 @@ #include #include -#include // include iterator header -#include //for meanvar +#include +#include // for meanvar #include #include @@ -28,6 +28,7 @@ #include #include +#include #include #include @@ -83,7 +84,17 @@ struct IteratorTest : public cudf::test::BaseFixture { EXPECT_EQ(thrust::distance(d_in, d_in_last), num_items); auto dev_expected = cudf::detail::make_device_uvector_sync(expected); - bool result = thrust::equal(thrust::device, d_in, d_in_last, dev_expected.begin()); + // using a temporary vector and calling transform and all_of separately is + // equivalent to thrust::equal but compiles ~3x faster + auto dev_results = rmm::device_uvector(num_items, rmm::cuda_stream_default); + thrust::transform(thrust::device, + d_in, + d_in_last, + dev_expected.begin(), + dev_results.begin(), + thrust::equal_to{}); + auto result = thrust::all_of( + thrust::device, dev_results.begin(), dev_results.end(), thrust::identity{}); EXPECT_TRUE(result) << "thrust test"; } diff --git a/cpp/tests/iterator/optional_iterator_test_numeric.cu b/cpp/tests/iterator/optional_iterator_test_numeric.cu index 6d51f4a5c14..a8c135a726f 100644 --- a/cpp/tests/iterator/optional_iterator_test_numeric.cu +++ b/cpp/tests/iterator/optional_iterator_test_numeric.cu @@ -50,21 +50,15 @@ struct transformer_optional_meanvar { } }; -struct sum_if_not_null { - template - CUDA_HOST_DEVICE_CALLABLE thrust::optional operator()(const thrust::optional& lhs, - const thrust::optional& rhs) - { - return lhs.value_or(T{0}) + rhs.value_or(T{0}); - } +template +struct optional_to_meanvar { + CUDA_HOST_DEVICE_CALLABLE T operator()(const thrust::optional& v) { return v.value_or(T{0}); } }; // TODO: enable this test also at __CUDACC_DEBUG__ // This test causes fatal compilation error only at device debug mode. // Workaround: exclude this test only at device debug mode. #if !defined(__CUDACC_DEBUG__) -// This test computes `count`, `sum`, `sum_of_squares` at a single reduction call. -// It would be useful for `var`, `std` operation TYPED_TEST(NumericOptionalIteratorTest, mean_var_output) { using T = TypeParam; @@ -104,22 +98,27 @@ TYPED_TEST(NumericOptionalIteratorTest, mean_var_output) expected_value.value_squared = std::accumulate( replaced_array.begin(), replaced_array.end(), T{0}, [](T acc, T i) { return acc + i * i; }); - // std::cout << "expected = " << expected_value << std::endl; - // GPU test auto it_dev = d_col->optional_begin(cudf::contains_nulls::YES{}); auto it_dev_squared = thrust::make_transform_iterator(it_dev, transformer); - auto result = thrust::reduce(it_dev_squared, - it_dev_squared + d_col->size(), - thrust::optional{T_output{}}, - sum_if_not_null{}); + + // this can be computed with a single reduce and without a temporary output vector + // but the approach increases the compile time by ~2x + auto results = rmm::device_uvector(d_col->size(), rmm::cuda_stream_default); + thrust::transform(thrust::device, + it_dev_squared, + it_dev_squared + d_col->size(), + results.begin(), + optional_to_meanvar{}); + auto result = thrust::reduce(thrust::device, results.begin(), results.end(), T_output{}); + if (not std::is_floating_point()) { - EXPECT_EQ(expected_value, *result) << "optional iterator reduction sum"; + EXPECT_EQ(expected_value, result) << "optional iterator reduction sum"; } else { - EXPECT_NEAR(expected_value.value, result->value, 1e-3) << "optional iterator reduction sum"; - EXPECT_NEAR(expected_value.value_squared, result->value_squared, 1e-3) + EXPECT_NEAR(expected_value.value, result.value, 1e-3) << "optional iterator reduction sum"; + EXPECT_NEAR(expected_value.value_squared, result.value_squared, 1e-3) << "optional iterator reduction sum squared"; - EXPECT_EQ(expected_value.count, result->count) << "optional iterator reduction count"; + EXPECT_EQ(expected_value.count, result.count) << "optional iterator reduction count"; } } #endif