/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/order/partial.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use arrow::row::{OwnedRow, RowConverter, Rows, SortField}; |
19 | | use arrow_array::ArrayRef; |
20 | | use arrow_schema::Schema; |
21 | | use datafusion_common::Result; |
22 | | use datafusion_execution::memory_pool::proxy::VecAllocExt; |
23 | | use datafusion_expr::EmitTo; |
24 | | use datafusion_physical_expr::PhysicalSortExpr; |
25 | | use std::sync::Arc; |
26 | | |
27 | | /// Tracks grouping state when the data is ordered by some subset of |
28 | | /// the group keys. |
29 | | /// |
30 | | /// Once a new *sort key* value is seen, no more rows with the previous |
31 | | /// sort key can arrive, so all groups with the previous sort key and |
32 | | /// earlier can be emitted. |
33 | | /// |
34 | | /// For example, given `SUM(amt) GROUP BY id, state` if the input is |
35 | | /// sorted by `state`, when a new value of `state` is seen, all groups |
36 | | /// with prior values of `state` can be emitted. |
37 | | /// |
38 | | /// The state is tracked like this: |
39 | | /// |
40 | | /// ```text |
41 | | /// ┏━━━━━━━━━━━━━━━━━┓ ┏━━━━━━━┓ |
42 | | /// ┌─────┐ ┌───────────────────┐ ┌─────┃ 9 ┃ ┃ "MD" ┃ |
43 | | /// │┌───┐│ │ ┌──────────────┐ │ │ ┗━━━━━━━━━━━━━━━━━┛ ┗━━━━━━━┛ |
44 | | /// ││ 0 ││ │ │ 123, "MA" │ │ │ current_sort sort_key |
45 | | /// │└───┘│ │ └──────────────┘ │ │ |
46 | | /// │ ... │ │ ... │ │ current_sort tracks the |
47 | | /// │┌───┐│ │ ┌──────────────┐ │ │ smallest group index that had |
48 | | /// ││ 8 ││ │ │ 765, "MA" │ │ │ the same sort_key as current |
49 | | /// │├───┤│ │ ├──────────────┤ │ │ |
50 | | /// ││ 9 ││ │ │ 923, "MD" │◀─┼─┘ |
51 | | /// │├───┤│ │ ├──────────────┤ │ ┏━━━━━━━━━━━━━━┓ |
52 | | /// ││10 ││ │ │ 345, "MD" │ │ ┌─────┃ 11 ┃ |
53 | | /// │├───┤│ │ ├──────────────┤ │ │ ┗━━━━━━━━━━━━━━┛ |
54 | | /// ││11 ││ │ │ 124, "MD" │◀─┼──┘ current |
55 | | /// │└───┘│ │ └──────────────┘ │ |
56 | | /// └─────┘ └───────────────────┘ |
57 | | /// |
58 | | /// group indices |
59 | | /// (in group value group_values current tracks the most |
60 | | /// order) recent group index |
61 | | ///``` |
62 | | #[derive(Debug)] |
63 | | pub struct GroupOrderingPartial { |
64 | | /// State machine tracking how far through the input we are (see `State`) |
65 | | state: State, |
66 | | |
67 | | /// The indices of the group by columns that form the sort key. |
68 | | /// For example if grouping by `id, state` and ordered by `state` |
69 | | /// this would be `[1]`. |
70 | | order_indices: Vec<usize>, |
71 | | |
72 | | /// Converter for the sort key (used on the group columns |
73 | | /// specified in `order_indices`) |
74 | | row_converter: RowConverter, |
75 | | } |
76 | | |
77 | | #[derive(Debug, Default)] |
78 | | enum State { |
79 | | /// The ordering was temporarily taken. `Self::Taken` is the default |
80 | | /// variant, so it is what `std::mem::take` leaves behind when the |
81 | | /// state is moved out to satisfy the borrow checker. If an error |
82 | | /// happens before the state is restored, the ordering information is |
83 | | /// lost and execution can not proceed, but there is no undefined behavior. |
84 | | #[default] |
85 | | Taken, |
86 | | |
87 | | /// Seen no input yet |
88 | | Start, |
89 | | |
90 | | /// Data is in progress. |
91 | | InProgress { |
92 | | /// Smallest group index that has the current `sort_key` |
93 | | current_sort: usize, |
94 | | /// The sort key of group_index `current_sort` |
95 | | sort_key: OwnedRow, |
96 | | /// Index of the most recently created group; `new_groups` sets it |
97 | | /// to `total_num_groups - 1` |
98 | | current: usize, |
99 | | }, |
100 | | |
101 | | /// Seen end of input, all groups can be emitted |
102 | | Complete, |
103 | | } |
104 | | |
105 | | impl GroupOrderingPartial { |
106 | 0 | pub fn try_new( |
107 | 0 | input_schema: &Schema, |
108 | 0 | order_indices: &[usize], |
109 | 0 | ordering: &[PhysicalSortExpr], |
110 | 0 | ) -> Result<Self> { |
111 | 0 | assert!(!order_indices.is_empty()); |
112 | 0 | assert!(order_indices.len() <= ordering.len()); |
113 | | |
114 | | // Take only the leading section of `ordering` that consists of group by expressions. |
115 | 0 | let fields = ordering[0..order_indices.len()] |
116 | 0 | .iter() |
117 | 0 | .map(|sort_expr| { |
118 | 0 | Ok(SortField::new_with_options( |
119 | 0 | sort_expr.expr.data_type(input_schema)?, |
120 | 0 | sort_expr.options, |
121 | | )) |
122 | 0 | }) |
123 | 0 | .collect::<Result<Vec<_>>>()?; |
124 | | |
125 | | Ok(Self { |
126 | 0 | state: State::Start, |
127 | 0 | order_indices: order_indices.to_vec(), |
128 | 0 | row_converter: RowConverter::new(fields)?, |
129 | | }) |
130 | 0 | } |
131 | | |
132 | | /// Creates sort keys from the group values, using only the columns listed in `order_indices` |
133 | | /// |
134 | | /// For example, if group_values had `A, B, C` but the input was |
135 | | /// only sorted on `B` and `C` this should return rows for (`B`, |
136 | | /// `C`) |
137 | 0 | fn compute_sort_keys(&mut self, group_values: &[ArrayRef]) -> Result<Rows> { |
138 | 0 | // Take only the columns that are in the sort key |
139 | 0 | let sort_values: Vec<_> = self |
140 | 0 | .order_indices |
141 | 0 | .iter() |
142 | 0 | .map(|&idx| Arc::clone(&group_values[idx])) |
143 | 0 | .collect(); |
144 | 0 |
|
145 | 0 | Ok(self.row_converter.convert_columns(&sort_values)?) |
146 | 0 | } |
147 | | |
148 | | /// How many groups can be emitted, or `None` if no data can be emitted yet |
149 | 0 | pub fn emit_to(&self) -> Option<EmitTo> { |
150 | 0 | match &self.state { |
151 | 0 | State::Taken => unreachable!("State previously taken"), |
152 | 0 | State::Start => None, |
153 | 0 | State::InProgress { current_sort, .. } => { |
154 | 0 | // Can not emit if we are still on the first sort key; |
155 | 0 | // otherwise we can emit all groups that had earlier sort keys |
156 | 0 | // |
157 | 0 | if *current_sort == 0 { |
158 | 0 | None |
159 | | } else { |
160 | 0 | Some(EmitTo::First(*current_sort)) |
161 | | } |
162 | | } |
163 | 0 | State::Complete => Some(EmitTo::All), |
164 | | } |
165 | 0 | } |
166 | | |
167 | | /// Remove the first `n` groups from the internal state, shifting |
168 | | /// all existing indexes down by `n`. Panics unless the state is `InProgress` and both tracked indexes are `>= n`. |
169 | 0 | pub fn remove_groups(&mut self, n: usize) { |
170 | 0 | match &mut self.state { |
171 | 0 | State::Taken => unreachable!("State previously taken"), |
172 | 0 | State::Start => panic!("invalid state: start"), |
173 | | State::InProgress { |
174 | 0 | current_sort, |
175 | 0 | current, |
176 | 0 | sort_key: _, |
177 | 0 | } => { |
178 | 0 | // shift indexes down by n |
179 | 0 | assert!(*current >= n); |
180 | 0 | *current -= n; |
181 | 0 | assert!(*current_sort >= n); |
182 | 0 | *current_sort -= n; |
183 | | } |
184 | 0 | State::Complete { .. } => panic!("invalid state: complete"), |
185 | | } |
186 | 0 | } |
187 | | |
188 | | /// Note that the input is complete, so any outstanding groups are done as well |
189 | 0 | pub fn input_done(&mut self) { |
190 | 0 | self.state = match self.state { |
191 | 0 | State::Taken => unreachable!("State previously taken"), |
192 | 0 | _ => State::Complete, |
193 | 0 | }; |
194 | 0 | } |
195 | | |
196 | | /// Called when new groups are added in a batch. See documentation |
197 | | /// on [`super::GroupOrdering::new_groups`]. Panics if called after `input_done`. |
198 | 0 | pub fn new_groups( |
199 | 0 | &mut self, |
200 | 0 | batch_group_values: &[ArrayRef], |
201 | 0 | group_indices: &[usize], |
202 | 0 | total_num_groups: usize, |
203 | 0 | ) -> Result<()> { |
204 | 0 | assert!(total_num_groups > 0); |
205 | 0 | assert!(!batch_group_values.is_empty()); |
206 | | |
207 | 0 | let max_group_index = total_num_groups - 1; |
208 | | |
209 | | // compute the sort key values for this batch (done before the state is taken, so an error here leaves the state intact) |
210 | 0 | let sort_keys = self.compute_sort_keys(batch_group_values)?; |
211 | | |
212 | 0 | let old_state = std::mem::take(&mut self.state); |
213 | 0 | let (mut current_sort, mut sort_key) = match &old_state { |
214 | 0 | State::Taken => unreachable!("State previously taken"), |
215 | 0 | State::Start => (0, sort_keys.row(0)), |
216 | | State::InProgress { |
217 | 0 | current_sort, |
218 | 0 | sort_key, |
219 | 0 | .. |
220 | 0 | } => (*current_sort, sort_key.row()), |
221 | | State::Complete => { |
222 | 0 | panic!("Saw new group after the end of input"); |
223 | | } |
224 | | }; |
225 | | |
226 | | // Find the latest sort key and the group index at which it first appeared |
227 | 0 | let iter = group_indices.iter().zip(sort_keys.iter()); |
228 | 0 | for (&group_index, group_sort_key) in iter { |
229 | | // Has this group seen a new sort_key? |
230 | 0 | if sort_key != group_sort_key { |
231 | 0 | current_sort = group_index; |
232 | 0 | sort_key = group_sort_key; |
233 | 0 | } |
234 | | } |
235 | | |
236 | 0 | self.state = State::InProgress { |
237 | 0 | current_sort, |
238 | 0 | sort_key: sort_key.owned(), |
239 | 0 | current: max_group_index, |
240 | 0 | }; |
241 | 0 |
|
242 | 0 | Ok(()) |
243 | 0 | } |
244 | | |
245 | | /// Return the size of memory allocated by this structure |
246 | 0 | pub(crate) fn size(&self) -> usize { |
247 | 0 | std::mem::size_of::<Self>() |
248 | 0 | + self.order_indices.allocated_size() |
249 | 0 | + self.row_converter.size() |
250 | 0 | } |
251 | | } |