Skip to content

Commit

Permalink
add unit test temporarily
Browse files Browse the repository at this point in the history
  • Loading branch information
karthikeyann committed Aug 31, 2021
1 parent 4c989a9 commit 015795c
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions cpp/tests/groupby/mean_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <cmath>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/utilities/traits.hpp>

Expand Down Expand Up @@ -160,5 +161,63 @@ TEST_F(groupby_dictionary_mean_test, basic)
keys, vals, expect_keys, expect_vals, cudf::make_mean_aggregation<groupby_aggregation>());
}

struct groupby_corr_test : public cudf::test::BaseFixture {
};
template <typename T>
using fwcw = fixed_width_column_wrapper<T>;
using structs = structs_column_wrapper;

TEST_F(groupby_corr_test, basic)
{
using K = int32_t;
using M0 = uint8_t;
using M1 = int16_t;
using R = cudf::detail::target_type_t<M0, aggregation::CORR>;

// clang-format off
auto keys = fwcw<K> { 1, 2, 3, 1, 2, 2, 1, 3, 3, 2 };
auto member_0 = fwcw<M0>{{ 1, 1, 1, 2, 2, 3, 3, 1, 1, 4 }};//, null_at(1)};
auto member_1 = fwcw<M1>{{ 1, 1, 1, 2, -2, 3, 3, 1, 1, -4 }};//, null_at(7)};
auto values = structs{{member_0, member_1}};//, null_at(4)};
// clang-format on

fixed_width_column_wrapper<K> expect_keys({1, 2, 3});
fixed_width_column_wrapper<R, double> expect_vals{
{1.000000, -0.41522739926869984, std::numeric_limits<double>::quiet_NaN()}}; //, null_at(2)};
// clang-format on

auto agg = cudf::make_corr_aggregation<groupby_aggregation>();
std::vector<groupby::aggregation_request> requests;
requests.emplace_back(groupby::aggregation_request());
requests[0].values = values;

requests[0].aggregations.push_back(std::move(agg));
requests.emplace_back(groupby::aggregation_request());
// WAR to force groupby to use sort implementation
requests[0].aggregations.push_back(make_nth_element_aggregation<groupby_aggregation>(0));

requests[1].values = column_view(values).child(0);
requests[1].aggregations.push_back(cudf::make_mean_aggregation<groupby_aggregation>());
requests[1].aggregations.push_back(cudf::make_std_aggregation<groupby_aggregation>());
requests.emplace_back(groupby::aggregation_request());
requests[2].values = column_view(values).child(1);
requests[2].aggregations.push_back(cudf::make_mean_aggregation<groupby_aggregation>());
requests[2].aggregations.push_back(cudf::make_std_aggregation<groupby_aggregation>());

groupby::groupby gb_obj(table_view({keys}));
auto result = gb_obj.aggregate(requests);

cudf::test::print(*result.second[0].results[0]);
cudf::test::print(*result.second[1].results[0]);
cudf::test::print(*result.second[1].results[1]);
cudf::test::print(*result.second[2].results[0]);
cudf::test::print(*result.second[2].results[1]);

CUDF_TEST_EXPECT_TABLES_EQUAL(table_view({expect_keys}), result.first->view());
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(
expect_vals, *result.second[0].results[0], debug_output_level::ALL_ERRORS);
// test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg));
}

} // namespace test
} // namespace cudf

0 comments on commit 015795c

Please sign in to comment.