Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/window/sliding_aggregate.rs
Line | Count | Source
   1 |       | // Licensed to the Apache Software Foundation (ASF) under one
   2 |       | // or more contributor license agreements.  See the NOTICE file
   3 |       | // distributed with this work for additional information
   4 |       | // regarding copyright ownership.  The ASF licenses this file
   5 |       | // to you under the Apache License, Version 2.0 (the
   6 |       | // "License"); you may not use this file except in compliance
   7 |       | // with the License.  You may obtain a copy of the License at
   8 |       | //
   9 |       | //   http://www.apache.org/licenses/LICENSE-2.0
  10 |       | //
  11 |       | // Unless required by applicable law or agreed to in writing,
  12 |       | // software distributed under the License is distributed on an
  13 |       | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14 |       | // KIND, either express or implied.  See the License for the
  15 |       | // specific language governing permissions and limitations
  16 |       | // under the License.
  17 |       |
  18 |       | //! Physical exec for aggregate window function expressions.
  19 |       |
  20 |       | use std::any::Any;
  21 |       | use std::ops::Range;
  22 |       | use std::sync::Arc;
  23 |       |
  24 |       | use arrow::array::{Array, ArrayRef};
  25 |       | use arrow::datatypes::Field;
  26 |       | use arrow::record_batch::RecordBatch;
  27 |       |
  28 |       | use datafusion_common::{Result, ScalarValue};
  29 |       | use datafusion_expr::{Accumulator, WindowFrame};
  30 |       |
  31 |       | use crate::aggregate::AggregateFunctionExpr;
  32 |       | use crate::window::window_expr::AggregateWindowExpr;
  33 |       | use crate::window::{
  34 |       |     PartitionBatches, PartitionWindowAggStates, PlainAggregateWindowExpr, WindowExpr,
  35 |       | };
  36 |       | use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr};
  37 |       |
  38 |       | /// A window expr that takes the form of an aggregate function that
  39 |       | /// can be incrementally computed over sliding windows.
  40 |       | ///
  41 |       | /// See comments on [`WindowExpr`] for more details.
  42 |       | #[derive(Debug)]
  43 |       | pub struct SlidingAggregateWindowExpr {
  44 |       |     aggregate: AggregateFunctionExpr,
  45 |       |     partition_by: Vec<Arc<dyn PhysicalExpr>>,
  46 |       |     order_by: Vec<PhysicalSortExpr>,
  47 |       |     window_frame: Arc<WindowFrame>,
  48 |       | }
  49 |       |
  50 |       | impl SlidingAggregateWindowExpr {
  51 |       |     /// Create a new (sliding) aggregate window function expression.
  52 |     1 |     pub fn new(
  53 |     1 |         aggregate: AggregateFunctionExpr,
  54 |     1 |         partition_by: &[Arc<dyn PhysicalExpr>],
  55 |     1 |         order_by: &[PhysicalSortExpr],
  56 |     1 |         window_frame: Arc<WindowFrame>,
  57 |     1 |     ) -> Self {
  58 |     1 |         Self {
  59 |     1 |             aggregate,
  60 |     1 |             partition_by: partition_by.to_vec(),
  61 |     1 |             order_by: order_by.to_vec(),
  62 |     1 |             window_frame,
  63 |     1 |         }
  64 |     1 |     }
  65 |       |
  66 |       |     /// Get the [AggregateFunctionExpr] of this object.
  67 |     0 |     pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr {
  68 |     0 |         &self.aggregate
  69 |     0 |     }
  70 |       | }
  71 |       |
  72 |       | /// Incrementally update window function using the fact that batch is
  73 |       | /// pre-sorted given the sort columns and then per partition point.
  74 |       | ///
  75 |       | /// Evaluates the peer group (e.g. `SUM` or `MAX` gives the same results
  76 |       | /// for peers) and concatenates the results.
  77 |       | impl WindowExpr for SlidingAggregateWindowExpr {
  78 |       |     /// Return a reference to Any that can be used for downcasting
  79 |     1 |     fn as_any(&self) -> &dyn Any {
  80 |     1 |         self
  81 |     1 |     }
  82 |       |
  83 |    12 |     fn field(&self) -> Result<Field> {
  84 |    12 |         Ok(self.aggregate.field())
  85 |    12 |     }
  86 |       |
  87 |     1 |     fn name(&self) -> &str {
  88 |     1 |         self.aggregate.name()
  89 |     1 |     }
  90 |       |
  91 |     9 |     fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
  92 |     9 |         self.aggregate.expressions()
  93 |     9 |     }
  94 |       |
  95 |     0 |     fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
  96 |     0 |         self.aggregate_evaluate(batch)
  97 |     0 |     }
  98 |       |
  99 |     5 |     fn evaluate_stateful(
 100 |     5 |         &self,
 101 |     5 |         partition_batches: &PartitionBatches,
 102 |     5 |         window_agg_state: &mut PartitionWindowAggStates,
 103 |     5 |     ) -> Result<()> {
 104 |     5 |         self.aggregate_evaluate_stateful(partition_batches, window_agg_state)
 105 |     5 |     }
 106 |       |
 107 |    12 |     fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
 108 |    12 |         &self.partition_by
 109 |    12 |     }
 110 |       |
 111 |    30 |     fn order_by(&self) -> &[PhysicalSortExpr] {
 112 |    30 |         &self.order_by
 113 |    30 |     }
 114 |       |
 115 |    13 |     fn get_window_frame(&self) -> &Arc<WindowFrame> {
 116 |    13 |         &self.window_frame
 117 |    13 |     }
 118 |       |
 119 |     0 |     fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
 120 |     0 |         self.aggregate.reverse_expr().map(|reverse_expr| {
 121 |     0 |             let reverse_window_frame = self.window_frame.reverse();
 122 |     0 |             if reverse_window_frame.start_bound.is_unbounded() {
 123 |     0 |                 Arc::new(PlainAggregateWindowExpr::new(
 124 |     0 |                     reverse_expr,
 125 |     0 |                     &self.partition_by.clone(),
 126 |     0 |                     &reverse_order_bys(&self.order_by),
 127 |     0 |                     Arc::new(self.window_frame.reverse()),
 128 |     0 |                 )) as _
 129 |       |             } else {
 130 |     0 |                 Arc::new(SlidingAggregateWindowExpr::new(
 131 |     0 |                     reverse_expr,
 132 |     0 |                     &self.partition_by.clone(),
 133 |     0 |                     &reverse_order_bys(&self.order_by),
 134 |     0 |                     Arc::new(self.window_frame.reverse()),
 135 |     0 |                 )) as _
 136 |       |             }
 137 |     0 |         })
 138 |     0 |     }
 139 |       |
 140 |     0 |     fn uses_bounded_memory(&self) -> bool {
 141 |     0 |         !self.window_frame.end_bound.is_unbounded()
 142 |     0 |     }
 143 |       |
 144 |     0 |     fn with_new_expressions(
 145 |     0 |         &self,
 146 |     0 |         args: Vec<Arc<dyn PhysicalExpr>>,
 147 |     0 |         partition_bys: Vec<Arc<dyn PhysicalExpr>>,
 148 |     0 |         order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
 149 |     0 |     ) -> Option<Arc<dyn WindowExpr>> {
 150 |     0 |         debug_assert_eq!(self.order_by.len(), order_by_exprs.len());
 151 |       |
 152 |     0 |         let new_order_by = self
 153 |     0 |             .order_by
 154 |     0 |             .iter()
 155 |     0 |             .zip(order_by_exprs)
 156 |     0 |             .map(|(req, new_expr)| PhysicalSortExpr {
 157 |     0 |                 expr: new_expr,
 158 |     0 |                 options: req.options,
 159 |     0 |             })
 160 |     0 |             .collect::<Vec<_>>();
 161 |     0 |         Some(Arc::new(SlidingAggregateWindowExpr {
 162 |     0 |             aggregate: self.aggregate.with_new_expressions(args, vec![])?,
 163 |     0 |             partition_by: partition_bys,
 164 |     0 |             order_by: new_order_by,
 165 |     0 |             window_frame: Arc::clone(&self.window_frame),
 166 |       |         }))
 167 |     0 |     }
 168 |       | }
 169 |       |
 170 |       | impl AggregateWindowExpr for SlidingAggregateWindowExpr {
 171 |     3 |     fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
 172 |     3 |         self.aggregate.create_sliding_accumulator()
 173 |     3 |     }
 174 |       |
 175 |       |     /// Given current range and the last range, calculates the accumulator
 176 |       |     /// result for the range of interest.
 177 |     8 |     fn get_aggregate_result_inside_range(
 178 |     8 |         &self,
 179 |     8 |         last_range: &Range<usize>,
 180 |     8 |         cur_range: &Range<usize>,
 181 |     8 |         value_slice: &[ArrayRef],
 182 |     8 |         accumulator: &mut Box<dyn Accumulator>,
 183 |     8 |     ) -> Result<ScalarValue> {
 184 |     8 |         if cur_range.start == cur_range.end {
 185 |     0 |             self.aggregate
 186 |     0 |                 .default_value(self.aggregate.field().data_type())
 187 |       |         } else {
 188 |       |             // Accumulate any new rows that have entered the window:
 189 |     8 |             let update_bound = cur_range.end - last_range.end;
 190 |     8 |             if update_bound > 0 {
 191 |     6 |                 let update: Vec<ArrayRef> = value_slice
 192 |     6 |                     .iter()
 193 |     6 |                     .map(|v| v.slice(last_range.end, update_bound))
 194 |     6 |                     .collect();
 195 |     6 |                 accumulator.update_batch(&update)?
 196 |     2 |             }
 197 |       |
 198 |       |             // Remove rows that have now left the window:
 199 |     8 |             let retract_bound = cur_range.start - last_range.start;
 200 |     8 |             if retract_bound > 0 {
 201 |     6 |                 let retract: Vec<ArrayRef> = value_slice
 202 |     6 |                     .iter()
 203 |     6 |                     .map(|v| v.slice(last_range.start, retract_bound))
 204 |     6 |                     .collect();
 205 |     6 |                 accumulator.retract_batch(&retract)?
 206 |     2 |             }
 207 |     8 |             accumulator.evaluate()
 208 |       |         }
 209 |     8 |     }
 210 |       | }
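
Note on the sliding computation: `get_aggregate_result_inside_range` (lines 177-209 above) is the heart of the incremental evaluation. Instead of re-aggregating the whole frame for every output row, it feeds the accumulator only the rows that entered the frame since the last evaluation (`update_batch`) and retracts the rows that left it (`retract_batch`). The following is a minimal standalone sketch of that accumulate/retract pattern, using a hypothetical plain-`i64` running sum in place of DataFusion's `Accumulator` trait and Arrow arrays; it illustrates the technique and is not DataFusion code.

use std::ops::Range;

/// Toy running-sum state standing in for DataFusion's `Accumulator`
/// (hypothetical; the real trait works on Arrow arrays and ScalarValue).
struct SumState {
    sum: i64,
}

impl SumState {
    fn update_batch(&mut self, values: &[i64]) {
        self.sum += values.iter().sum::<i64>();
    }
    fn retract_batch(&mut self, values: &[i64]) {
        self.sum -= values.iter().sum::<i64>();
    }
    fn evaluate(&self) -> i64 {
        self.sum
    }
}

/// Mirrors the accumulate/retract logic of `get_aggregate_result_inside_range`:
/// only rows that entered the frame since `last` are added, and only rows
/// that left it are retracted.
fn slide(state: &mut SumState, values: &[i64], last: &Range<usize>, cur: &Range<usize>) -> i64 {
    let update_bound = cur.end - last.end;
    if update_bound > 0 {
        state.update_batch(&values[last.end..cur.end]);
    }
    let retract_bound = cur.start - last.start;
    if retract_bound > 0 {
        state.retract_batch(&values[last.start..cur.start]);
    }
    state.evaluate()
}

fn main() {
    // Frame equivalent to ROWS BETWEEN 1 PRECEDING AND CURRENT ROW.
    let values = [1i64, 2, 3, 4];
    let mut state = SumState { sum: 0 };
    let mut last = 0..0;
    for i in 0..values.len() {
        let cur = i.saturating_sub(1)..i + 1;
        println!("row {i}: {}", slide(&mut state, &values, &last, &cur));
        last = cur;
    }
    // Prints 1, 3, 5, 7: each step adds the new row and retracts the expired one.
}

Each row is added once and retracted at most once, so the work per partition stays linear in the number of rows regardless of the frame width.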