Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/topk_stream.rs
Note: every executable line below reported a count of 0, i.e. no coverage was recorded for this file.
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! A memory-conscious aggregation implementation that limits group buckets to a fixed number
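//!
//! Instead of tracking every distinct group, this stream keeps at most `limit`
//! groups in a [`PriorityMap`], so memory use stays bounded regardless of input
//! cardinality. It serves single-key min/max queries with a limit, e.g.
//! (illustrative query, not taken from this file):
//!
//! `SELECT id, MAX(ts) FROM t GROUP BY id ORDER BY MAX(ts) DESC LIMIT 10`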

use crate::aggregates::topk::priority_map::PriorityMap;
use crate::aggregates::{
    aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
    PhysicalGroupBy,
};
use crate::{RecordBatchStream, SendableRecordBatchStream};
use arrow::util::pretty::print_batches;
use arrow_array::{Array, ArrayRef, RecordBatch};
use arrow_schema::SchemaRef;
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::PhysicalExpr;
use futures::stream::{Stream, StreamExt};
use log::{trace, Level};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

pub struct GroupedTopKAggregateStream {
    partition: usize,
    row_count: usize,
    started: bool,
    schema: SchemaRef,
    input: SendableRecordBatchStream,
    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
    group_by: PhysicalGroupBy,
    priority_map: PriorityMap,
}
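
// `new` executes one partition of the child plan, resolves the (single) group
// key expression and the min/max value field, and sizes the `PriorityMap` from
// their data types, the query's limit, and the sort direction.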

impl GroupedTopKAggregateStream {
    pub fn new(
        aggr: &AggregateExec,
        context: Arc<TaskContext>,
        partition: usize,
        limit: usize,
    ) -> Result<Self> {
        let agg_schema = Arc::clone(&aggr.schema);
        let group_by = aggr.group_by.clone();
        let input = aggr.input.execute(partition, Arc::clone(&context))?;
        let aggregate_arguments =
            aggregate_expressions(&aggr.aggr_expr, &aggr.mode, group_by.expr.len())?;
        let (val_field, desc) = aggr
            .get_minmax_desc()
            .ok_or_else(|| DataFusionError::Internal("Min/max required".to_string()))?;

        let (expr, _) = &aggr.group_expr().expr()[0];
        let kt = expr.data_type(&aggr.input().schema())?;
        let vt = val_field.data_type().clone();

        let priority_map = PriorityMap::new(kt, vt, limit, desc)?;

        Ok(GroupedTopKAggregateStream {
            partition,
            started: false,
            row_count: 0,
            schema: agg_schema,
            input,
            aggregate_arguments,
            group_by,
            priority_map,
        })
    }
}
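
// Illustrative construction (hypothetical call site; in DataFusion the
// aggregate operator's `execute` builds this stream when a top-k limit
// applies):
//
//     let stream = GroupedTopKAggregateStream::new(&aggr, task_ctx, 0, 10)?;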

impl RecordBatchStream for GroupedTopKAggregateStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
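
// `intern` feeds one batch of keys/values into the `PriorityMap`: `set_batch`
// hands the map the current arrays, then every non-null row index is offered
// to `insert`, which retains only the best `limit` groups.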

impl GroupedTopKAggregateStream {
    fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> {
        let len = ids.len();
        self.priority_map.set_batch(ids, Arc::clone(&vals));

        let has_nulls = vals.null_count() > 0;
        for row_idx in 0..len {
            if has_nulls && vals.is_null(row_idx) {
                continue;
            }
            self.priority_map.insert(row_idx)?;
        }
        Ok(())
    }
}
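
// The stream below is pipelined on input and blocking on output: it drains
// batches from the child as they become ready and emits a single result batch
// only once the input is exhausted (or returns the first error it sees).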

impl Stream for GroupedTopKAggregateStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        while let Poll::Ready(res) = self.input.poll_next_unpin(cx) {
            match res {
                // got a batch, extract its group and value columns and insert
                // them into our PriorityMap
                Some(Ok(batch)) => {
                    self.started = true;
                    trace!(
                        "partition {} has {} rows and got batch with {} rows",
                        self.partition,
                        self.row_count,
                        batch.num_rows()
                    );
                    if log::log_enabled!(Level::Trace) && batch.num_rows() < 20 {
                        print_batches(&[batch.clone()])?;
                    }
                    self.row_count += batch.num_rows();
                    let batches = &[batch];
                    let group_by_values =
                        evaluate_group_by(&self.group_by, batches.first().unwrap())?;
                    assert_eq!(
                        group_by_values.len(),
                        1,
                        "Exactly 1 group value required"
                    );
                    assert_eq!(
                        group_by_values[0].len(),
                        1,
                        "Exactly 1 group value required"
                    );
                    let group_by_values = Arc::clone(&group_by_values[0][0]);
                    let input_values = evaluate_many(
                        &self.aggregate_arguments,
                        batches.first().unwrap(),
                    )?;
                    assert_eq!(input_values.len(), 1, "Exactly 1 input required");
                    assert_eq!(input_values[0].len(), 1, "Exactly 1 input required");
                    let input_values = Arc::clone(&input_values[0][0]);

                    // insert this batch's single group column and value column
                    (*self).intern(group_by_values, input_values)?;
                }
                // inner is done, emit all rows and switch to producing output
                None => {
                    if self.priority_map.is_empty() {
                        trace!("partition {} emit None", self.partition);
                        return Poll::Ready(None);
                    }
                    let cols = self.priority_map.emit()?;
                    let batch = RecordBatch::try_new(Arc::clone(&self.schema), cols)?;
                    trace!(
                        "partition {} emit batch with {} rows",
                        self.partition,
                        batch.num_rows()
                    );
                    if log::log_enabled!(Level::Trace) {
                        print_batches(&[batch.clone()])?;
                    }
                    return Poll::Ready(Some(Ok(batch)));
                }
                // inner had an error, return it to the caller
                Some(Err(e)) => {
                    return Poll::Ready(Some(Err(e)));
                }
            }
        }
        Poll::Pending
    }
}
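
To make the PriorityMap contract used above concrete, here is a minimal sketch
of a bounded top-k map. It is a hypothetical stand-in, not DataFusion's
implementation: the real `PriorityMap` operates directly on Arrow arrays,
supports MIN as well as MAX orderings, and evicts via a heap rather than the
linear scan used here.

    use std::collections::HashMap;

    // Minimal sketch of a bounded "top k groups by max value" map.
    struct TopKMap {
        limit: usize,
        best: HashMap<String, i64>, // group key -> largest value seen for that key
    }

    impl TopKMap {
        fn new(limit: usize) -> Self {
            Self { limit, best: HashMap::new() }
        }

        // Offer one (key, value) pair, keeping at most `limit` groups.
        fn insert(&mut self, key: &str, val: i64) {
            if let Some(v) = self.best.get_mut(key) {
                // existing group: keep the larger value
                *v = (*v).max(val);
                return;
            }
            if self.best.len() < self.limit {
                // still under the cap: admit the new group
                self.best.insert(key.to_string(), val);
                return;
            }
            // at capacity: evict the current worst group only if this value beats it
            let (worst_key, worst_val) = self
                .best
                .iter()
                .min_by_key(|(_, v)| **v)
                .map(|(k, v)| (k.clone(), *v))
                .expect("limit > 0, so the map is non-empty here");
            if val > worst_val {
                self.best.remove(&worst_key);
                self.best.insert(key.to_string(), val);
            }
        }

        // Emit groups ordered by value descending, like `ORDER BY max(v) DESC LIMIT k`.
        fn emit(self) -> Vec<(String, i64)> {
            let mut rows: Vec<_> = self.best.into_iter().collect();
            rows.sort_by(|a, b| b.1.cmp(&a.1));
            rows
        }
    }

    fn main() {
        let mut m = TopKMap::new(2);
        for (k, v) in [("a", 1), ("b", 5), ("a", 7), ("c", 3)] {
            m.insert(k, v);
        }
        // keeps the two groups with the largest maxima: [("a", 7), ("b", 5)]
        println!("{:?}", m.emit());
    }

Because eviction only ever considers the current worst group, the map never
holds more than `limit` entries, which is the memory bound the module doc
comment refers to.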