/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/topk_stream.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! A memory-conscious aggregation implementation that limits group buckets to a fixed number
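//!
//! This stream is intended for plans that group on a single column, compute a
//! single MIN or MAX aggregate, and only need the `limit` most extreme groups
//! (e.g. `... GROUP BY trace_id ORDER BY MAX(ts) DESC LIMIT n`; the column
//! names here are illustrative). Each input batch is fed row-by-row into a
//! [`PriorityMap`] that retains at most `limit` groups, so memory use is
//! bounded by `limit` rather than by the number of distinct groups. A minimal
//! sketch of the flow, with names abbreviated and error handling elided:
//!
//! ```ignore
//! map.set_batch(group_keys, values); // borrow this batch's key/value columns
//! for row in 0..num_rows {
//!     map.insert(row)?;              // keeps at most `limit` groups
//! }
//! // once the input is exhausted:
//! let columns = map.emit()?;         // the surviving groups as Arrow arrays
//! ```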
use crate::aggregates::topk::priority_map::PriorityMap;
use crate::aggregates::{
    aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
    PhysicalGroupBy,
};
use crate::{RecordBatchStream, SendableRecordBatchStream};
use arrow::util::pretty::print_batches;
use arrow_array::{Array, ArrayRef, RecordBatch};
use arrow_schema::SchemaRef;
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::PhysicalExpr;
use futures::stream::{Stream, StreamExt};
use log::{trace, Level};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

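/// A stream that incrementally computes a grouped MIN/MAX over its input
/// while retaining at most `limit` groups in a bounded [`PriorityMap`],
/// emitting a single output batch once the input is exhausted.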
pub struct GroupedTopKAggregateStream {
    /// Partition of the plan this stream executes (used in trace logging)
    partition: usize,
    /// Number of input rows consumed so far (used in trace logging)
    row_count: usize,
    /// True once the first input batch has been received
    started: bool,
    /// Output schema of the parent [`AggregateExec`]
    schema: SchemaRef,
    /// Stream of input batches from the child operator
    input: SendableRecordBatchStream,
    /// Argument expressions for the (single) aggregate function
    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
    /// The (single-column) group by expression
    group_by: PhysicalGroupBy,
    /// Bounded map holding the current top `limit` groups
    priority_map: PriorityMap,
}

impl GroupedTopKAggregateStream {
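    /// Creates a stream for the given `partition` of `aggr`.
    ///
    /// Returns an internal error unless the aggregate is a single MIN or MAX;
    /// the map's key type comes from the first group expression and its value
    /// type from the aggregate's output field.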
    pub fn new(
        aggr: &AggregateExec,
        context: Arc<TaskContext>,
        partition: usize,
        limit: usize,
    ) -> Result<Self> {
        let agg_schema = Arc::clone(&aggr.schema);
        let group_by = aggr.group_by.clone();
        let input = aggr.input.execute(partition, Arc::clone(&context))?;
        let aggregate_arguments =
            aggregate_expressions(&aggr.aggr_expr, &aggr.mode, group_by.expr.len())?;
        // this operator only supports a single MIN/MAX aggregate
        let (val_field, desc) = aggr
            .get_minmax_desc()
            .ok_or_else(|| DataFusionError::Internal("Min/max required".to_string()))?;

        // key type from the first group expression, value type from the
        // aggregate's output field
        let (expr, _) = &aggr.group_expr().expr()[0];
        let kt = expr.data_type(&aggr.input().schema())?;
        let vt = val_field.data_type().clone();

        let priority_map = PriorityMap::new(kt, vt, limit, desc)?;

        Ok(GroupedTopKAggregateStream {
            partition,
            started: false,
            row_count: 0,
            schema: agg_schema,
            input,
            aggregate_arguments,
            group_by,
            priority_map,
        })
    }
}

impl RecordBatchStream for GroupedTopKAggregateStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

impl GroupedTopKAggregateStream {
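    /// Feeds one batch of group keys (`ids`) and aggregate values (`vals`)
    /// into the priority map, skipping rows whose value is null.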
    fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> {
        let len = ids.len();
        self.priority_map.set_batch(ids, Arc::clone(&vals));

        let has_nulls = vals.null_count() > 0;
        for row_idx in 0..len {
            if has_nulls && vals.is_null(row_idx) {
                continue;
            }
            self.priority_map.insert(row_idx)?;
        }
        Ok(())
    }
}

impl Stream for GroupedTopKAggregateStream {
    type Item = Result<RecordBatch>;

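    /// Greedily drains the input, folding each batch into the priority map;
    /// once the input is exhausted, emits the surviving groups as a single
    /// [`RecordBatch`] (or `None` if no rows were collected).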
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        while let Poll::Ready(res) = self.input.poll_next_unpin(cx) {
            match res {
                // got a batch, evaluate the group keys and aggregate values
                // and fold them into our priority map
                Some(Ok(batch)) => {
                    self.started = true;
                    trace!(
                        "partition {} has {} rows and got batch with {} rows",
                        self.partition,
                        self.row_count,
                        batch.num_rows()
                    );
                    if log::log_enabled!(Level::Trace) && batch.num_rows() < 20 {
                        print_batches(&[batch.clone()])?;
                    }
                    self.row_count += batch.num_rows();
                    let group_by_values = evaluate_group_by(&self.group_by, &batch)?;
                    assert_eq!(
                        group_by_values.len(),
                        1,
                        "Exactly 1 group value required"
                    );
                    assert_eq!(
                        group_by_values[0].len(),
                        1,
                        "Exactly 1 group value required"
                    );
                    let group_by_values = Arc::clone(&group_by_values[0][0]);
                    let input_values =
                        evaluate_many(&self.aggregate_arguments, &batch)?;
                    assert_eq!(input_values.len(), 1, "Exactly 1 input required");
                    assert_eq!(input_values[0].len(), 1, "Exactly 1 input required");
                    let input_values = Arc::clone(&input_values[0][0]);

                    // insert the single group column and value column into the map
                    (*self).intern(group_by_values, input_values)?;
                }
                // inner is done, emit all rows and switch to producing output
                None => {
                    if self.priority_map.is_empty() {
                        trace!("partition {} emit None", self.partition);
                        return Poll::Ready(None);
                    }
                    let cols = self.priority_map.emit()?;
                    let batch = RecordBatch::try_new(Arc::clone(&self.schema), cols)?;
                    trace!(
                        "partition {} emit batch with {} rows",
                        self.partition,
                        batch.num_rows()
                    );
                    if log::log_enabled!(Level::Trace) {
                        print_batches(&[batch.clone()])?;
                    }
                    return Poll::Ready(Some(Ok(batch)));
                }
                // inner had error, return to caller
                Some(Err(e)) => {
                    return Poll::Ready(Some(Err(e)));
                }
            }
        }
        Poll::Pending
    }
}