From 015795cf875b75f0088be30f5e38c11bcacb6363 Mon Sep 17 00:00:00 2001 From: Karthikeyan Natarajan Date: Tue, 31 Aug 2021 23:38:57 +0530 Subject: [PATCH] add unit test temporarily --- cpp/tests/groupby/mean_tests.cpp | 59 ++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/cpp/tests/groupby/mean_tests.cpp b/cpp/tests/groupby/mean_tests.cpp index 613e1555b79..9bceebfb241 100644 --- a/cpp/tests/groupby/mean_tests.cpp +++ b/cpp/tests/groupby/mean_tests.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include @@ -160,5 +161,63 @@ TEST_F(groupby_dictionary_mean_test, basic) keys, vals, expect_keys, expect_vals, cudf::make_mean_aggregation()); } +struct groupby_corr_test : public cudf::test::BaseFixture { +}; +template +using fwcw = fixed_width_column_wrapper; +using structs = structs_column_wrapper; + +TEST_F(groupby_corr_test, basic) +{ + using K = int32_t; + using M0 = uint8_t; + using M1 = int16_t; + using R = cudf::detail::target_type_t; + + // clang-format off + auto keys = fwcw { 1, 2, 3, 1, 2, 2, 1, 3, 3, 2 }; + auto member_0 = fwcw{{ 1, 1, 1, 2, 2, 3, 3, 1, 1, 4 }};//, null_at(1)}; + auto member_1 = fwcw{{ 1, 1, 1, 2, -2, 3, 3, 1, 1, -4 }};//, null_at(7)}; + auto values = structs{{member_0, member_1}};//, null_at(4)}; + // clang-format on + + fixed_width_column_wrapper expect_keys({1, 2, 3}); + fixed_width_column_wrapper expect_vals{ + {1.000000, -0.41522739926869984, std::numeric_limits::quiet_NaN()}}; //, null_at(2)}; + // clang-format on + + auto agg = cudf::make_corr_aggregation(); + std::vector requests; + requests.emplace_back(groupby::aggregation_request()); + requests[0].values = values; + + requests[0].aggregations.push_back(std::move(agg)); + requests.emplace_back(groupby::aggregation_request()); + // WAR to force groupby to use sort implementation + requests[0].aggregations.push_back(make_nth_element_aggregation(0)); + + requests[1].values = column_view(values).child(0); + requests[1].aggregations.push_back(cudf::make_mean_aggregation()); + requests[1].aggregations.push_back(cudf::make_std_aggregation()); + requests.emplace_back(groupby::aggregation_request()); + requests[2].values = column_view(values).child(1); + requests[2].aggregations.push_back(cudf::make_mean_aggregation()); + requests[2].aggregations.push_back(cudf::make_std_aggregation()); + + groupby::groupby gb_obj(table_view({keys})); + auto result = gb_obj.aggregate(requests); + + cudf::test::print(*result.second[0].results[0]); + cudf::test::print(*result.second[1].results[0]); + cudf::test::print(*result.second[1].results[1]); + cudf::test::print(*result.second[2].results[0]); + cudf::test::print(*result.second[2].results[1]); + + CUDF_TEST_EXPECT_TABLES_EQUAL(table_view({expect_keys}), result.first->view()); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + expect_vals, *result.second[0].results[0], debug_output_level::ALL_ERRORS); + // test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg)); +} + } // namespace test } // namespace cudf