From 5b5f4522f592633c5d08738cf1766f507e138876 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Tue, 25 May 2021 20:21:58 +0800 Subject: [PATCH] fix Signed-off-by: sperlingxx --- cpp/src/rolling/rolling_collect_list.cuh | 3 +-- cpp/src/rolling/rolling_detail.cuh | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/cpp/src/rolling/rolling_collect_list.cuh b/cpp/src/rolling/rolling_collect_list.cuh index f5a2e59fd2a..0ffafe349b9 100644 --- a/cpp/src/rolling/rolling_collect_list.cuh +++ b/cpp/src/rolling/rolling_collect_list.cuh @@ -283,7 +283,7 @@ std::unique_ptr rolling_collect_list(column_view const& input, PrecedingIter preceding_begin_raw, FollowingIter following_begin_raw, size_type min_periods, - rolling_aggregation const& agg, + null_policy null_handling, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -321,7 +321,6 @@ std::unique_ptr rolling_collect_list(column_view const& input, // If gather_map collects null elements, and null_policy == EXCLUDE, // those elements must be filtered out, and offsets recomputed. - auto null_handling = dynamic_cast(agg)._null_handling; if (null_handling == null_policy::EXCLUDE && input.has_nulls()) { auto num_child_nulls = count_child_nulls(input, gather_map, stream); if (num_child_nulls != 0) { diff --git a/cpp/src/rolling/rolling_detail.cuh b/cpp/src/rolling/rolling_detail.cuh index 580024f6dba..2eae5669eeb 100644 --- a/cpp/src/rolling/rolling_detail.cuh +++ b/cpp/src/rolling/rolling_detail.cuh @@ -687,7 +687,7 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation preceding_window_begin, following_window_begin, min_periods, - agg, + agg._null_handling, stream, mr); } @@ -700,7 +700,7 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation preceding_window_begin, following_window_begin, min_periods, - agg, + agg._null_handling, stream, mr); @@ -953,7 +953,7 @@ struct dispatch_rolling { { // do any preprocessing of aggregations (eg, MIN -> ARGMIN, COLLECT_LIST -> nothing) rolling_aggregation_preprocessor preprocessor; - auto preprocessed_aggs = agg.get_simple_aggregations(input.type(), preprocessor); + auto preprocessed_aggs = agg.get_simple_aggregations( input.type(), preprocessor); CUDF_EXPECTS(preprocessed_aggs.size() <= 1, "Encountered a non-trivial rolling aggregation result");