Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/window/aggregate.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Physical exec for aggregate window function expressions.
19
20
use std::any::Any;
21
use std::ops::Range;
22
use std::sync::Arc;
23
24
use arrow::array::Array;
25
use arrow::record_batch::RecordBatch;
26
use arrow::{array::ArrayRef, datatypes::Field};
27
28
use datafusion_common::ScalarValue;
29
use datafusion_common::{DataFusionError, Result};
30
use datafusion_expr::{Accumulator, WindowFrame};
31
32
use crate::aggregate::AggregateFunctionExpr;
33
use crate::window::window_expr::AggregateWindowExpr;
34
use crate::window::{
35
    PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr,
36
};
37
use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr};
38
39
/// A window expr that takes the form of an aggregate function.
40
///
41
/// See comments on [`WindowExpr`] for more details.
42
#[derive(Debug)]
43
pub struct PlainAggregateWindowExpr {
44
    aggregate: AggregateFunctionExpr,
45
    partition_by: Vec<Arc<dyn PhysicalExpr>>,
46
    order_by: Vec<PhysicalSortExpr>,
47
    window_frame: Arc<WindowFrame>,
48
}
49
50
impl PlainAggregateWindowExpr {
51
    /// Create a new aggregate window function expression
52
1
    pub fn new(
53
1
        aggregate: AggregateFunctionExpr,
54
1
        partition_by: &[Arc<dyn PhysicalExpr>],
55
1
        order_by: &[PhysicalSortExpr],
56
1
        window_frame: Arc<WindowFrame>,
57
1
    ) -> Self {
58
1
        Self {
59
1
            aggregate,
60
1
            partition_by: partition_by.to_vec(),
61
1
            order_by: order_by.to_vec(),
62
1
            window_frame,
63
1
        }
64
1
    }
65
66
    /// Get aggregate expr of AggregateWindowExpr
67
0
    pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr {
68
0
        &self.aggregate
69
0
    }
70
}
71
72
/// peer based evaluation based on the fact that batch is pre-sorted given the sort columns
73
/// and then per partition point we'll evaluate the peer group (e.g. SUM or MAX gives the same
74
/// results for peers) and concatenate the results.
75
impl WindowExpr for PlainAggregateWindowExpr {
76
    /// Return a reference to Any that can be used for downcasting
77
1
    fn as_any(&self) -> &dyn Any {
78
1
        self
79
1
    }
80
81
1
    fn field(&self) -> Result<Field> {
82
1
        Ok(self.aggregate.field())
83
1
    }
84
85
0
    fn name(&self) -> &str {
86
0
        self.aggregate.name()
87
0
    }
88
89
0
    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
90
0
        self.aggregate.expressions()
91
0
    }
92
93
0
    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
94
0
        self.aggregate_evaluate(batch)
95
0
    }
96
97
0
    fn evaluate_stateful(
98
0
        &self,
99
0
        partition_batches: &PartitionBatches,
100
0
        window_agg_state: &mut PartitionWindowAggStates,
101
0
    ) -> Result<()> {
102
0
        self.aggregate_evaluate_stateful(partition_batches, window_agg_state)?;
103
104
        // Update window frame range for each partition. As we know that
105
        // non-sliding aggregations will never call `retract_batch`, this value
106
        // can safely increase, and we can remove "old" parts of the state.
107
        // This enables us to run queries involving UNBOUNDED PRECEDING frames
108
        // using bounded memory for suitable aggregations.
109
0
        for partition_row in partition_batches.keys() {
110
0
            let window_state =
111
0
                window_agg_state.get_mut(partition_row).ok_or_else(|| {
112
0
                    DataFusionError::Execution("Cannot find state".to_string())
113
0
                })?;
114
0
            let state = &mut window_state.state;
115
0
            if self.window_frame.start_bound.is_unbounded() {
116
0
                state.window_frame_range.start =
117
0
                    state.window_frame_range.end.saturating_sub(1);
118
0
            }
119
        }
120
0
        Ok(())
121
0
    }
122
123
3
    fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
124
3
        &self.partition_by
125
3
    }
126
127
0
    fn order_by(&self) -> &[PhysicalSortExpr] {
128
0
        &self.order_by
129
0
    }
130
131
0
    fn get_window_frame(&self) -> &Arc<WindowFrame> {
132
0
        &self.window_frame
133
0
    }
134
135
0
    fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
136
0
        self.aggregate.reverse_expr().map(|reverse_expr| {
137
0
            let reverse_window_frame = self.window_frame.reverse();
138
0
            if reverse_window_frame.start_bound.is_unbounded() {
139
0
                Arc::new(PlainAggregateWindowExpr::new(
140
0
                    reverse_expr,
141
0
                    &self.partition_by.clone(),
142
0
                    &reverse_order_bys(&self.order_by),
143
0
                    Arc::new(self.window_frame.reverse()),
144
0
                )) as _
145
            } else {
146
0
                Arc::new(SlidingAggregateWindowExpr::new(
147
0
                    reverse_expr,
148
0
                    &self.partition_by.clone(),
149
0
                    &reverse_order_bys(&self.order_by),
150
0
                    Arc::new(self.window_frame.reverse()),
151
0
                )) as _
152
            }
153
0
        })
154
0
    }
155
156
0
    fn uses_bounded_memory(&self) -> bool {
157
0
        !self.window_frame.end_bound.is_unbounded()
158
0
    }
159
}
160
161
impl AggregateWindowExpr for PlainAggregateWindowExpr {
162
0
    fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
163
0
        self.aggregate.create_accumulator()
164
0
    }
165
166
    /// For a given range, calculate accumulation result inside the range on
167
    /// `value_slice` and update accumulator state.
168
    // We assume that `cur_range` contains `last_range` and their start points
169
    // are same. In summary if `last_range` is `Range{start: a,end: b}` and
170
    // `cur_range` is `Range{start: a1, end: b1}`, it is guaranteed that a1=a and b1>=b.
171
0
    fn get_aggregate_result_inside_range(
172
0
        &self,
173
0
        last_range: &Range<usize>,
174
0
        cur_range: &Range<usize>,
175
0
        value_slice: &[ArrayRef],
176
0
        accumulator: &mut Box<dyn Accumulator>,
177
0
    ) -> Result<ScalarValue> {
178
0
        if cur_range.start == cur_range.end {
179
0
            self.aggregate
180
0
                .default_value(self.aggregate.field().data_type())
181
        } else {
182
            // Accumulate any new rows that have entered the window:
183
0
            let update_bound = cur_range.end - last_range.end;
184
0
            // A non-sliding aggregation only processes new data, it never
185
0
            // deals with expiring data as its starting point is always the
186
0
            // same point (i.e. the beginning of the table/frame). Hence, we
187
0
            // do not call `retract_batch`.
188
0
            if update_bound > 0 {
189
0
                let update: Vec<ArrayRef> = value_slice
190
0
                    .iter()
191
0
                    .map(|v| v.slice(last_range.end, update_bound))
192
0
                    .collect();
193
0
                accumulator.update_batch(&update)?
194
0
            }
195
0
            accumulator.evaluate()
196
        }
197
0
    }
198
}