Skip to content

Commit

Permalink
specialize grouped rolling MEAN of timestamp
Browse files Browse the repository at this point in the history
  • Loading branch information
karthikeyann committed May 29, 2020
1 parent fe2e54c commit c80de81
Showing 1 changed file with 38 additions and 2 deletions.
40 changes: 38 additions & 2 deletions cpp/tests/grouped_rolling/grouped_rolling_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,23 @@ class GroupedRollingTest : public cudf::test::BaseFixture {
return col.release();
}

template <typename OutputType,
typename agg_op,
bool is_mean_of_timestamp,
std::enable_if_t<!is_mean_of_timestamp>* = nullptr>
auto run_op(OutputType& val, OutputType const& in)
{
val = agg_op{}(in, val);
}
template <typename OutputType,
typename agg_op,
bool is_mean_of_timestamp,
std::enable_if_t<is_mean_of_timestamp>* = nullptr>
auto run_op(OutputType& val, OutputType const& in)
{
val = static_cast<OutputType>(agg_op{}(in.time_since_epoch(), val.time_since_epoch()));
}

template <typename agg_op,
cudf::aggregation::Kind k,
typename OutputType,
Expand Down Expand Up @@ -374,7 +391,8 @@ class GroupedRollingTest : public cudf::test::BaseFixture {
size_type count = 0;
for (size_type j = start_index; j < end_index; j++) {
if (!input.nullable() || cudf::bit_is_set(valid_mask, j)) {
val = op(static_cast<OutputType>(in_col[j]), val);
run_op<OutputType, agg_op, (is_mean and cudf::is_timestamp<OutputType>())>(
val, static_cast<OutputType>(in_col[j]));
count++;
}
}
Expand Down Expand Up @@ -941,6 +959,23 @@ class GroupedTimeRangeRollingTest : public cudf::test::BaseFixture {
return col.release();
}

template <typename OutputType,
typename agg_op,
bool is_mean_of_timestamp,
std::enable_if_t<!is_mean_of_timestamp>* = nullptr>
auto run_op(OutputType& val, OutputType const& in)
{
val = agg_op{}(in, val);
}
template <typename OutputType,
typename agg_op,
bool is_mean_of_timestamp,
std::enable_if_t<is_mean_of_timestamp>* = nullptr>
auto run_op(OutputType& val, OutputType const& in)
{
val = static_cast<OutputType>(agg_op{}(in.time_since_epoch(), val.time_since_epoch()));
}

template <typename agg_op,
cudf::aggregation::Kind k,
typename OutputType,
Expand Down Expand Up @@ -1010,7 +1045,8 @@ class GroupedTimeRangeRollingTest : public cudf::test::BaseFixture {
size_type count = 0;
for (size_type j = start_index; j < end_index; j++) {
if (!input.nullable() || cudf::bit_is_set(valid_mask, j)) {
val = op(static_cast<OutputType>(in_col[j]), val);
run_op<OutputType, agg_op, (is_mean and cudf::is_timestamp<OutputType>())>(
val, static_cast<OutputType>(in_col[j]));
count++;
}
}
Expand Down

0 comments on commit c80de81

Please sign in to comment.