/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/group_values/column.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use crate::aggregates::group_values::group_column::{ |
19 | | ByteGroupValueBuilder, ByteViewGroupValueBuilder, GroupColumn, |
20 | | PrimitiveGroupValueBuilder, |
21 | | }; |
22 | | use crate::aggregates::group_values::GroupValues; |
23 | | use ahash::RandomState; |
24 | | use arrow::compute::cast; |
25 | | use arrow::datatypes::{ |
26 | | BinaryViewType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, |
27 | | Int32Type, Int64Type, Int8Type, StringViewType, UInt16Type, UInt32Type, UInt64Type, |
28 | | UInt8Type, |
29 | | }; |
30 | | use arrow::record_batch::RecordBatch; |
31 | | use arrow_array::{Array, ArrayRef}; |
32 | | use arrow_schema::{DataType, Schema, SchemaRef}; |
33 | | use datafusion_common::hash_utils::create_hashes; |
34 | | use datafusion_common::{not_impl_err, DataFusionError, Result}; |
35 | | use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; |
36 | | use datafusion_expr::EmitTo; |
37 | | use datafusion_physical_expr::binary_map::OutputType; |
38 | | |
39 | | use hashbrown::raw::RawTable; |
40 | | |
41 | | /// A [`GroupValues`] that stores multiple columns of group values. |
42 | | /// |
43 | | /// |
44 | | pub struct GroupValuesColumn { |
45 | | /// The output schema |
46 | | schema: SchemaRef, |
47 | | |
48 | | /// Logically maps group values to a group_index in |
49 | | /// [`Self::group_values`] and in each accumulator |
50 | | /// |
51 | | /// Uses the raw API of hashbrown to avoid actually storing the |
52 | | /// keys (group values) in the table |
53 | | /// |
54 | | /// keys: u64 hashes of the GroupValue |
55 | | /// values: (hash, group_index) |
56 | | map: RawTable<(u64, usize)>, |
57 | | |
58 | | /// The size of `map` in bytes |
59 | | map_size: usize, |
60 | | |
61 | | /// The actual group by values, stored column-wise. Compare from |
62 | | /// the left to right, each column is stored as [`GroupColumn`]. |
63 | | /// |
64 | | /// Performance tests showed that this design is faster than using the |
65 | | /// more general purpose [`GroupValuesRows`]. See the ticket for details: |
66 | | /// <https://github.com/apache/datafusion/pull/12269> |
67 | | /// |
68 | | /// [`GroupValuesRows`]: crate::aggregates::group_values::row::GroupValuesRows |
69 | | group_values: Vec<Box<dyn GroupColumn>>, |
70 | | |
71 | | /// reused buffer to store hashes |
72 | | hashes_buffer: Vec<u64>, |
73 | | |
74 | | /// Random state for creating hashes |
75 | | random_state: RandomState, |
76 | | } |
77 | | |
78 | | impl GroupValuesColumn { |
79 | | /// Create a new instance of GroupValuesColumn if supported for the specified schema |
80 | 13 | pub fn try_new(schema: SchemaRef) -> Result<Self> { |
81 | 13 | let map = RawTable::with_capacity(0); |
82 | 13 | Ok(Self { |
83 | 13 | schema, |
84 | 13 | map, |
85 | 13 | map_size: 0, |
86 | 13 | group_values: vec![], |
87 | 13 | hashes_buffer: Default::default(), |
88 | 13 | random_state: Default::default(), |
89 | 13 | }) |
90 | 13 | } |
91 | | |
92 | | /// Returns true if [`GroupValuesColumn`] supported for the specified schema |
93 | 14 | pub fn supported_schema(schema: &Schema) -> bool { |
94 | 14 | schema |
95 | 14 | .fields() |
96 | 14 | .iter() |
97 | 28 | .map(|f| f.data_type()) |
98 | 14 | .all(Self::supported_type) |
99 | 14 | } |
100 | | |
101 | | /// Returns true if the specified data type is supported by [`GroupValuesColumn`] |
102 | | /// |
103 | | /// In order to be supported, there must be a specialized implementation of |
104 | | /// [`GroupColumn`] for the data type, instantiated in [`Self::intern`] |
105 | 28 | fn supported_type(data_type: &DataType) -> bool { |
106 | 1 | matches!( |
107 | 28 | *data_type, |
108 | | DataType::Int8 |
109 | | | DataType::Int16 |
110 | | | DataType::Int32 |
111 | | | DataType::Int64 |
112 | | | DataType::UInt8 |
113 | | | DataType::UInt16 |
114 | | | DataType::UInt32 |
115 | | | DataType::UInt64 |
116 | | | DataType::Float32 |
117 | | | DataType::Float64 |
118 | | | DataType::Utf8 |
119 | | | DataType::LargeUtf8 |
120 | | | DataType::Binary |
121 | | | DataType::LargeBinary |
122 | | | DataType::Date32 |
123 | | | DataType::Date64 |
124 | | | DataType::Utf8View |
125 | | | DataType::BinaryView |
126 | | ) |
127 | 28 | } |
128 | | } |
129 | | |
130 | | /// instantiates a [`PrimitiveGroupValueBuilder`] and pushes it into $v |
131 | | /// |
132 | | /// Arguments: |
133 | | /// `$v`: the vector to push the new builder into |
134 | | /// `$nullable`: whether the input can contains nulls |
135 | | /// `$t`: the primitive type of the builder |
136 | | /// |
137 | | macro_rules! instantiate_primitive { |
138 | | ($v:expr, $nullable:expr, $t:ty) => { |
139 | | if $nullable { |
140 | | let b = PrimitiveGroupValueBuilder::<$t, true>::new(); |
141 | | $v.push(Box::new(b) as _) |
142 | | } else { |
143 | | let b = PrimitiveGroupValueBuilder::<$t, false>::new(); |
144 | | $v.push(Box::new(b) as _) |
145 | | } |
146 | | }; |
147 | | } |
148 | | |
149 | | impl GroupValues for GroupValuesColumn { |
150 | 68 | fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> { |
151 | 68 | let n_rows = cols[0].len(); |
152 | 68 | |
153 | 68 | if self.group_values.is_empty() { |
154 | 13 | let mut v = Vec::with_capacity(cols.len()); |
155 | | |
156 | 27 | for f in self.schema.fields().iter()13 { |
157 | 27 | let nullable = f.is_nullable(); |
158 | 27 | match f.data_type() { |
159 | 0 | &DataType::Int8 => instantiate_primitive!(v, nullable, Int8Type), |
160 | 0 | &DataType::Int16 => instantiate_primitive!(v, nullable, Int16Type), |
161 | 1 | &DataType::Int32 => instantiate_primitive!(v0 , nullable, Int32Type), |
162 | 0 | &DataType::Int64 => instantiate_primitive!(v, nullable, Int64Type), |
163 | 0 | &DataType::UInt8 => instantiate_primitive!(v, nullable, UInt8Type), |
164 | 0 | &DataType::UInt16 => instantiate_primitive!(v, nullable, UInt16Type), |
165 | 12 | &DataType::UInt32 => instantiate_primitive!(v0 , nullable, UInt32Type), |
166 | 0 | &DataType::UInt64 => instantiate_primitive!(v, nullable, UInt64Type), |
167 | | &DataType::Float32 => { |
168 | 2 | instantiate_primitive!(v0 , nullable, Float32Type) |
169 | | } |
170 | | &DataType::Float64 => { |
171 | 12 | instantiate_primitive!(v0 , nullable, Float64Type) |
172 | | } |
173 | 0 | &DataType::Date32 => instantiate_primitive!(v, nullable, Date32Type), |
174 | 0 | &DataType::Date64 => instantiate_primitive!(v, nullable, Date64Type), |
175 | | &DataType::Utf8 => { |
176 | 0 | let b = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8); |
177 | 0 | v.push(Box::new(b) as _) |
178 | | } |
179 | | &DataType::LargeUtf8 => { |
180 | 0 | let b = ByteGroupValueBuilder::<i64>::new(OutputType::Utf8); |
181 | 0 | v.push(Box::new(b) as _) |
182 | | } |
183 | | &DataType::Binary => { |
184 | 0 | let b = ByteGroupValueBuilder::<i32>::new(OutputType::Binary); |
185 | 0 | v.push(Box::new(b) as _) |
186 | | } |
187 | | &DataType::LargeBinary => { |
188 | 0 | let b = ByteGroupValueBuilder::<i64>::new(OutputType::Binary); |
189 | 0 | v.push(Box::new(b) as _) |
190 | | } |
191 | | &DataType::Utf8View => { |
192 | 0 | let b = ByteViewGroupValueBuilder::<StringViewType>::new(); |
193 | 0 | v.push(Box::new(b) as _) |
194 | | } |
195 | | &DataType::BinaryView => { |
196 | 0 | let b = ByteViewGroupValueBuilder::<BinaryViewType>::new(); |
197 | 0 | v.push(Box::new(b) as _) |
198 | | } |
199 | 0 | dt => { |
200 | 0 | return not_impl_err!("{dt} not supported in GroupValuesColumn") |
201 | | } |
202 | | } |
203 | | } |
204 | 13 | self.group_values = v; |
205 | 55 | } |
206 | | |
207 | | // tracks to which group each of the input rows belongs |
208 | 68 | groups.clear(); |
209 | 68 | |
210 | 68 | // 1.1 Calculate the group keys for the group values |
211 | 68 | let batch_hashes = &mut self.hashes_buffer; |
212 | 68 | batch_hashes.clear(); |
213 | 68 | batch_hashes.resize(n_rows, 0); |
214 | 68 | create_hashes(cols, &self.random_state, batch_hashes)?0 ; |
215 | | |
216 | 98.5k | for (row, &target_hash) in batch_hashes.iter().enumerate()68 { |
217 | 98.5k | let entry = self.map.get_mut(target_hash, |(exist_hash, group_idx)| { |
218 | 98.3k | // Somewhat surprisingly, this closure can be called even if the |
219 | 98.3k | // hash doesn't match, so check the hash first with an integer |
220 | 98.3k | // comparison first avoid the more expensive comparison with |
221 | 98.3k | // group value. https://github.com/apache/datafusion/pull/11718 |
222 | 98.3k | if target_hash != *exist_hash { |
223 | 1 | return false; |
224 | 98.3k | } |
225 | | |
226 | 295k | fn check_row_equal( |
227 | 295k | array_row: &dyn GroupColumn, |
228 | 295k | lhs_row: usize, |
229 | 295k | array: &ArrayRef, |
230 | 295k | rhs_row: usize, |
231 | 295k | ) -> bool { |
232 | 295k | array_row.equal_to(lhs_row, array, rhs_row) |
233 | 295k | } |
234 | | |
235 | 295k | for (i, group_val) in self.group_values.iter().enumerate()98.3k { |
236 | 295k | if !check_row_equal(group_val.as_ref(), *group_idx, &cols[i], row) { |
237 | 0 | return false; |
238 | 295k | } |
239 | | } |
240 | | |
241 | 98.3k | true |
242 | 98.5k | }98.3k ); |
243 | | |
244 | 98.5k | let group_idx = match entry { |
245 | | // Existing group_index for this group value |
246 | 98.3k | Some((_hash, group_idx)) => *group_idx, |
247 | | // 1.2 Need to create new entry for the group |
248 | | None => { |
249 | | // Add new entry to aggr_state and save newly created index |
250 | | // let group_idx = group_values.num_rows(); |
251 | | // group_values.push(group_rows.row(row)); |
252 | | |
253 | 163 | let mut checklen = 0; |
254 | 163 | let group_idx = self.group_values[0].len(); |
255 | 329 | for (i, group_value) in self.group_values.iter_mut().enumerate()163 { |
256 | 329 | group_value.append_val(&cols[i], row); |
257 | 329 | let len = group_value.len(); |
258 | 329 | if i == 0 { |
259 | 163 | checklen = len; |
260 | 163 | } else { |
261 | 166 | debug_assert_eq!(checklen, len); |
262 | | } |
263 | | } |
264 | | |
265 | | // for hasher function, use precomputed hash value |
266 | 163 | self.map.insert_accounted( |
267 | 163 | (target_hash, group_idx), |
268 | 163 | |(hash, _group_index)| *hash, |
269 | 163 | &mut self.map_size, |
270 | 163 | ); |
271 | 163 | group_idx |
272 | | } |
273 | | }; |
274 | 98.5k | groups.push(group_idx); |
275 | | } |
276 | | |
277 | 68 | Ok(()) |
278 | 68 | } |
279 | | |
280 | 68 | fn size(&self) -> usize { |
281 | 92 | let group_values_size: usize = self.group_values.iter().map(|v| v.size()).sum(); |
282 | 68 | group_values_size + self.map_size + self.hashes_buffer.allocated_size() |
283 | 68 | } |
284 | | |
285 | 17 | fn is_empty(&self) -> bool { |
286 | 17 | self.len() == 0 |
287 | 17 | } |
288 | | |
289 | 189 | fn len(&self) -> usize { |
290 | 189 | if self.group_values.is_empty() { |
291 | 17 | return 0; |
292 | 172 | } |
293 | 172 | |
294 | 172 | self.group_values[0].len() |
295 | 189 | } |
296 | | |
297 | 15 | fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> { |
298 | 15 | let mut output = match emit_to { |
299 | | EmitTo::All => { |
300 | 11 | let group_values = std::mem::take(&mut self.group_values); |
301 | 11 | debug_assert!(self.group_values.is_empty()); |
302 | | |
303 | 11 | group_values |
304 | 11 | .into_iter() |
305 | 23 | .map(|v| v.build()) |
306 | 11 | .collect::<Vec<_>>() |
307 | | } |
308 | 4 | EmitTo::First(n) => { |
309 | 4 | let output = self |
310 | 4 | .group_values |
311 | 4 | .iter_mut() |
312 | 8 | .map(|v| v.take_n(n)) |
313 | 4 | .collect::<Vec<_>>(); |
314 | | |
315 | | // SAFETY: self.map outlives iterator and is not modified concurrently |
316 | | unsafe { |
317 | 46 | for bucket in self.map.iter()4 { |
318 | | // Decrement group index by n |
319 | 46 | match bucket.as_ref().1.checked_sub(n) { |
320 | | // Group index was >= n, shift value down |
321 | 6 | Some(sub) => bucket.as_mut().1 = sub, |
322 | | // Group index was < n, so remove from table |
323 | 40 | None => self.map.erase(bucket), |
324 | | } |
325 | | } |
326 | | } |
327 | | |
328 | 4 | output |
329 | | } |
330 | | }; |
331 | | |
332 | | // TODO: Materialize dictionaries in group keys (#7647) |
333 | 31 | for (field, array) in self.schema.fields.iter().zip(&mut output)15 { |
334 | 31 | let expected = field.data_type(); |
335 | 31 | if let DataType::Dictionary(_, v0 ) = expected { |
336 | 0 | let actual = array.data_type(); |
337 | 0 | if v.as_ref() != actual { |
338 | 0 | return Err(DataFusionError::Internal(format!( |
339 | 0 | "Converted group rows expected dictionary of {v} got {actual}" |
340 | 0 | ))); |
341 | 0 | } |
342 | 0 | *array = cast(array.as_ref(), expected)?; |
343 | 31 | } |
344 | | } |
345 | | |
346 | 15 | Ok(output) |
347 | 15 | } |
348 | | |
349 | 13 | fn clear_shrink(&mut self, batch: &RecordBatch) { |
350 | 13 | let count = batch.num_rows(); |
351 | 13 | self.group_values.clear(); |
352 | 13 | self.map.clear(); |
353 | 13 | self.map.shrink_to(count, |_| 00 ); // hasher does not matter since the map is cleared |
354 | 13 | self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>(); |
355 | 13 | self.hashes_buffer.clear(); |
356 | 13 | self.hashes_buffer.shrink_to(count); |
357 | 13 | } |
358 | | } |