Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/order/partial.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use arrow::row::{OwnedRow, RowConverter, Rows, SortField};
19
use arrow_array::ArrayRef;
20
use arrow_schema::Schema;
21
use datafusion_common::Result;
22
use datafusion_execution::memory_pool::proxy::VecAllocExt;
23
use datafusion_expr::EmitTo;
24
use datafusion_physical_expr::PhysicalSortExpr;
25
use std::sync::Arc;
26
27
/// Tracks grouping state when the data is ordered by some subset of
28
/// the group keys.
29
///
30
/// Once a new *sort key* value is seen, we will never see groups with the
31
/// previous sort key again, so we can emit all groups with the previous sort
32
/// key and earlier.
33
///
34
/// For example, given `SUM(amt) GROUP BY id, state` if the input is
35
/// sorted by `state`, when a new value of `state` is seen, all groups
36
/// with prior values of `state` can be emitted.
37
///
38
/// The state is tracked like this:
39
///
40
/// ```text
41
///                                            ┏━━━━━━━━━━━━━━━━━┓ ┏━━━━━━━┓
42
///     ┌─────┐    ┌───────────────────┐ ┌─────┃        9        ┃ ┃ "MD"  ┃
43
///     │┌───┐│    │ ┌──────────────┐  │ │     ┗━━━━━━━━━━━━━━━━━┛ ┗━━━━━━━┛
44
///     ││ 0 ││    │ │  123, "MA"   │  │ │        current_sort      sort_key
45
///     │└───┘│    │ └──────────────┘  │ │
46
///     │ ... │    │    ...            │ │      current_sort tracks the
47
///     │┌───┐│    │ ┌──────────────┐  │ │      smallest group index that had
48
///     ││ 8 ││    │ │  765, "MA"   │  │ │      the same sort_key as current
49
///     │├───┤│    │ ├──────────────┤  │ │
50
///     ││ 9 ││    │ │  923, "MD"   │◀─┼─┘
51
///     │├───┤│    │ ├──────────────┤  │        ┏━━━━━━━━━━━━━━┓
52
///     ││10 ││    │ │  345, "MD"   │  │  ┌─────┃      11      ┃
53
///     │├───┤│    │ ├──────────────┤  │  │     ┗━━━━━━━━━━━━━━┛
54
///     ││11 ││    │ │  124, "MD"   │◀─┼──┘         current
55
///     │└───┘│    │ └──────────────┘  │
56
///     └─────┘    └───────────────────┘
57
///
58
///  group indices
59
/// (in group value  group_values               current tracks the most
60
///      order)                                    recent group index
61
///```
62
#[derive(Debug)]
63
pub struct GroupOrderingPartial {
64
    /// State machine
65
    state: State,
66
67
    /// The indexes of the group by columns that form the sort key.
68
    /// For example if grouping by `id, state` and ordered by `state`
69
    /// this would be `[1]`.
70
    order_indices: Vec<usize>,
71
72
    /// Converter for the sort key (used on the group columns
73
    /// specified in `order_indices`)
74
    row_converter: RowConverter,
75
}
76
77
#[derive(Debug, Default)]
78
enum State {
79
    /// The ordering was temporarily taken.  `Self::Taken` is left
80
    /// when state must be temporarily taken to satisfy the borrow
81
    /// checker. If an error happens before the state can be restored,
82
    /// the ordering information is lost and execution can not
83
    /// proceed, but there is no undefined behavior.
84
    #[default]
85
    Taken,
86
87
    /// Seen no input yet
88
    Start,
89
90
    /// Data is in progress.
91
    InProgress {
92
        /// Smallest group index with the sort_key
93
        current_sort: usize,
94
        /// The sort key of group_index `current_sort`
95
        sort_key: OwnedRow,
96
        /// index of the current group for which values are being
97
        /// generated
98
        current: usize,
99
    },
100
101
    /// Seen end of input, all groups can be emitted
102
    Complete,
103
}
104
105
impl GroupOrderingPartial {
106
0
    pub fn try_new(
107
0
        input_schema: &Schema,
108
0
        order_indices: &[usize],
109
0
        ordering: &[PhysicalSortExpr],
110
0
    ) -> Result<Self> {
111
0
        assert!(!order_indices.is_empty());
112
0
        assert!(order_indices.len() <= ordering.len());
113
114
        // Get only the leading section of the ordering that consists of group by expressions.
115
0
        let fields = ordering[0..order_indices.len()]
116
0
            .iter()
117
0
            .map(|sort_expr| {
118
0
                Ok(SortField::new_with_options(
119
0
                    sort_expr.expr.data_type(input_schema)?,
120
0
                    sort_expr.options,
121
                ))
122
0
            })
123
0
            .collect::<Result<Vec<_>>>()?;
124
125
        Ok(Self {
126
0
            state: State::Start,
127
0
            order_indices: order_indices.to_vec(),
128
0
            row_converter: RowConverter::new(fields)?,
129
        })
130
0
    }
131
132
    /// Creates sort keys from the group values
133
    ///
134
    /// For example, if group_values had `A, B, C` but the input was
135
    /// only sorted on `B` and `C` this should return rows for (`B`,
136
    /// `C`)
137
0
    fn compute_sort_keys(&mut self, group_values: &[ArrayRef]) -> Result<Rows> {
138
0
        // Take only the columns that are in the sort key
139
0
        let sort_values: Vec<_> = self
140
0
            .order_indices
141
0
            .iter()
142
0
            .map(|&idx| Arc::clone(&group_values[idx]))
143
0
            .collect();
144
0
145
0
        Ok(self.row_converter.convert_columns(&sort_values)?)
146
0
    }
147
148
    /// How many groups can be emitted, or None if no data can be emitted
149
0
    pub fn emit_to(&self) -> Option<EmitTo> {
150
0
        match &self.state {
151
0
            State::Taken => unreachable!("State previously taken"),
152
0
            State::Start => None,
153
0
            State::InProgress { current_sort, .. } => {
154
0
                // Cannot emit if we are still on the first sort key;
155
0
                // otherwise we can emit all groups that had earlier sort keys
156
0
                //
157
0
                if *current_sort == 0 {
158
0
                    None
159
                } else {
160
0
                    Some(EmitTo::First(*current_sort))
161
                }
162
            }
163
0
            State::Complete => Some(EmitTo::All),
164
        }
165
0
    }
166
167
    /// remove the first n groups from the internal state, shifting
168
    /// all existing indexes down by `n`
169
0
    pub fn remove_groups(&mut self, n: usize) {
170
0
        match &mut self.state {
171
0
            State::Taken => unreachable!("State previously taken"),
172
0
            State::Start => panic!("invalid state: start"),
173
            State::InProgress {
174
0
                current_sort,
175
0
                current,
176
0
                sort_key: _,
177
0
            } => {
178
0
                // shift indexes down by n
179
0
                assert!(*current >= n);
180
0
                *current -= n;
181
0
                assert!(*current_sort >= n);
182
0
                *current_sort -= n;
183
            }
184
0
            State::Complete { .. } => panic!("invalid state: complete"),
185
        }
186
0
    }
187
188
    /// Note that the input is complete so any outstanding groups are done as well
189
0
    pub fn input_done(&mut self) {
190
0
        self.state = match self.state {
191
0
            State::Taken => unreachable!("State previously taken"),
192
0
            _ => State::Complete,
193
0
        };
194
0
    }
195
196
    /// Called when new groups are added in a batch. See documentation
197
    /// on [`super::GroupOrdering::new_groups`]
198
0
    pub fn new_groups(
199
0
        &mut self,
200
0
        batch_group_values: &[ArrayRef],
201
0
        group_indices: &[usize],
202
0
        total_num_groups: usize,
203
0
    ) -> Result<()> {
204
0
        assert!(total_num_groups > 0);
205
0
        assert!(!batch_group_values.is_empty());
206
207
0
        let max_group_index = total_num_groups - 1;
208
209
        // compute the sort key values for each group
210
0
        let sort_keys = self.compute_sort_keys(batch_group_values)?;
211
212
0
        let old_state = std::mem::take(&mut self.state);
213
0
        let (mut current_sort, mut sort_key) = match &old_state {
214
0
            State::Taken => unreachable!("State previously taken"),
215
0
            State::Start => (0, sort_keys.row(0)),
216
            State::InProgress {
217
0
                current_sort,
218
0
                sort_key,
219
0
                ..
220
0
            } => (*current_sort, sort_key.row()),
221
            State::Complete => {
222
0
                panic!("Saw new group after the end of input");
223
            }
224
        };
225
226
        // Find latest sort key
227
0
        let iter = group_indices.iter().zip(sort_keys.iter());
228
0
        for (&group_index, group_sort_key) in iter {
229
            // Has this group seen a new sort_key?
230
0
            if sort_key != group_sort_key {
231
0
                current_sort = group_index;
232
0
                sort_key = group_sort_key;
233
0
            }
234
        }
235
236
0
        self.state = State::InProgress {
237
0
            current_sort,
238
0
            sort_key: sort_key.owned(),
239
0
            current: max_group_index,
240
0
        };
241
0
242
0
        Ok(())
243
0
    }
244
245
    /// Return the size of memory allocated by this structure
246
0
    pub(crate) fn size(&self) -> usize {
247
0
        std::mem::size_of::<Self>()
248
0
            + self.order_indices.allocated_size()
249
0
            + self.row_converter.size()
250
0
    }
251
}