Skip to content

Commit

Permalink
specialize rolling MEAN for timestamp
Browse files Browse the repository at this point in the history
  • Loading branch information
karthikeyann committed May 29, 2020
1 parent 3e72ab0 commit afb4f34
Showing 1 changed file with 41 additions and 1 deletion.
42 changes: 41 additions & 1 deletion cpp/src/rolling/rolling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ template <typename InputType,
bool has_nulls>
std::enable_if_t<!std::is_same<InputType, cudf::string_view>::value and
!(op == aggregation::COUNT_VALID || op == aggregation::COUNT_ALL ||
op == aggregation::ROW_NUMBER),
op == aggregation::ROW_NUMBER) and
!(is_timestamp<InputType>() and op == aggregation::MEAN),
bool>
__device__ process_rolling_window(column_device_view input,
mutable_column_device_view output,
Expand Down Expand Up @@ -192,6 +193,45 @@ std::enable_if_t<!std::is_same<InputType, cudf::string_view>::value and
return output_is_valid;
}

/**
* @brief Mean on only timestamp types and returns true if the
* operation was valid, else false.
*/
template <typename InputType,
typename OutputType,
typename agg_op,
aggregation::Kind op,
bool has_nulls>
std::enable_if_t<(is_timestamp<InputType>() and op == aggregation::MEAN), bool> __device__
process_rolling_window(column_device_view input,
mutable_column_device_view output,
size_type start_index,
size_type end_index,
size_type current_index,
size_type min_periods)
{
// declare this as volatile to avoid some compiler optimizations that lead to incorrect results
// for CUDA 10.0 and below (fixed in CUDA 10.1)
volatile cudf::size_type count = 0;
OutputType val = agg_op::template identity<OutputType>();

for (size_type j = start_index; j < end_index; j++) {
if (!has_nulls || input.is_valid(j)) {
OutputType element = input.element<InputType>(j);
val = agg_op{}(element.time_since_epoch(), val.time_since_epoch());
count++;
}
}

bool output_is_valid = (count >= min_periods);

// store the output value, one per thread
cudf::detail::rolling_store_output_functor<OutputType, op == aggregation::MEAN>{}(
output.element<OutputType>(current_index), val, count);

return output_is_valid;
}

/**
* @brief Computes the rolling window function
*
Expand Down

0 comments on commit afb4f34

Please sign in to comment.