/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/group_values/row.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use crate::aggregates::group_values::GroupValues; |
19 | | use ahash::RandomState; |
20 | | use arrow::compute::cast; |
21 | | use arrow::record_batch::RecordBatch; |
22 | | use arrow::row::{RowConverter, Rows, SortField}; |
23 | | use arrow_array::{Array, ArrayRef, ListArray, StructArray}; |
24 | | use arrow_schema::{DataType, SchemaRef}; |
25 | | use datafusion_common::hash_utils::create_hashes; |
26 | | use datafusion_common::Result; |
27 | | use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; |
28 | | use datafusion_expr::EmitTo; |
29 | | use hashbrown::raw::RawTable; |
30 | | use std::sync::Arc; |
31 | | |
/// A [`GroupValues`] making use of [`Rows`]
///
/// This is a general implementation of [`GroupValues`] that works for any
/// combination of data types and number of columns, including nested types such as
/// structs and lists.
///
/// It uses the arrow-rs [`Rows`] to store the group values, which is a row-wise
/// representation.
pub struct GroupValuesRows {
    /// The output schema (data types of the group-by columns)
    schema: SchemaRef,

    /// Converter between the columnar group-by arrays and the [`Rows`]
    /// row-wise representation
    row_converter: RowConverter,

    /// Logically maps group values to a group_index in
    /// [`Self::group_values`] and in each accumulator
    ///
    /// Uses the raw API of hashbrown to avoid actually storing the
    /// keys (group values) in the table
    ///
    /// keys: u64 hashes of the GroupValue
    /// values: (hash, group_index)
    map: RawTable<(u64, usize)>,

    /// The size of `map` in bytes, tracked incrementally as the table
    /// grows (see `insert_accounted`)
    map_size: usize,

    /// The actual group by values, stored in arrow [`Row`] format.
    /// `group_values[i]` holds the group value for group_index `i`.
    ///
    /// The row format is used to compare group keys quickly and store
    /// them efficiently in memory. Quick comparison is especially
    /// important for multi-column group keys.
    ///
    /// `None` until the first call to `intern`, and temporarily `None`
    /// while ownership is taken during `intern` / `emit`.
    ///
    /// [`Row`]: arrow::row::Row
    group_values: Option<Rows>,

    /// Reused buffer to store hashes, avoiding a new allocation per batch
    hashes_buffer: Vec<u64>,

    /// Reused buffer to store the row-encoded form of each input batch
    rows_buffer: Rows,

    /// Random state for creating hashes
    random_state: RandomState,
}
79 | | |
80 | | impl GroupValuesRows { |
81 | 1 | pub fn try_new(schema: SchemaRef) -> Result<Self> { |
82 | 1 | let row_converter = RowConverter::new( |
83 | 1 | schema |
84 | 1 | .fields() |
85 | 1 | .iter() |
86 | 1 | .map(|f| SortField::new(f.data_type().clone())) |
87 | 1 | .collect(), |
88 | 1 | )?0 ; |
89 | | |
90 | 1 | let map = RawTable::with_capacity(0); |
91 | 1 | |
92 | 1 | let starting_rows_capacity = 1000; |
93 | 1 | |
94 | 1 | let starting_data_capacity = 64 * starting_rows_capacity; |
95 | 1 | let rows_buffer = |
96 | 1 | row_converter.empty_rows(starting_rows_capacity, starting_data_capacity); |
97 | 1 | Ok(Self { |
98 | 1 | schema, |
99 | 1 | row_converter, |
100 | 1 | map, |
101 | 1 | map_size: 0, |
102 | 1 | group_values: None, |
103 | 1 | hashes_buffer: Default::default(), |
104 | 1 | rows_buffer, |
105 | 1 | random_state: Default::default(), |
106 | 1 | }) |
107 | 1 | } |
108 | | } |
109 | | |
impl GroupValues for GroupValuesRows {
    /// Look up (or create) the group index for each input row of `cols`,
    /// appending one group index per input row to `groups`.
    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
        // Convert the group keys into the row format, reusing `rows_buffer`
        // to avoid allocating per batch
        let group_rows = &mut self.rows_buffer;
        group_rows.clear();
        self.row_converter.append(group_rows, cols)?;
        let n_rows = group_rows.num_rows();

        // Temporarily take ownership of the stored group values so they can
        // be mutated alongside `self.map`; restored before returning
        let mut group_values = match self.group_values.take() {
            Some(group_values) => group_values,
            None => self.row_converter.empty_rows(0, 0),
        };

        // tracks to which group each of the input rows belongs
        groups.clear();

        // 1.1 Calculate the group keys (hashes) for the group values,
        // reusing `hashes_buffer` across calls
        let batch_hashes = &mut self.hashes_buffer;
        batch_hashes.clear();
        batch_hashes.resize(n_rows, 0);
        create_hashes(cols, &self.random_state, batch_hashes)?;

        for (row, &target_hash) in batch_hashes.iter().enumerate() {
            let entry = self.map.get_mut(target_hash, |(exist_hash, group_idx)| {
                // Somewhat surprisingly, this closure can be called even if the
                // hash doesn't match, so check the hash first with a cheap
                // integer comparison to avoid the more expensive comparison
                // with the group value. https://github.com/apache/datafusion/pull/11718
                target_hash == *exist_hash
                // verify that the group that we are inserting with hash is
                // actually the same key value as the group in
                // existing_idx (aka group_values @ row)
                    && group_rows.row(row) == group_values.row(*group_idx)
            });

            let group_idx = match entry {
                // Existing group_index for this group value
                Some((_hash, group_idx)) => *group_idx,
                // 1.2 Need to create new entry for the group
                None => {
                    // Add new entry to aggr_state and save newly created index
                    let group_idx = group_values.num_rows();
                    group_values.push(group_rows.row(row));

                    // for hasher function, use precomputed hash value
                    self.map.insert_accounted(
                        (target_hash, group_idx),
                        |(hash, _group_index)| *hash,
                        &mut self.map_size,
                    );
                    group_idx
                }
            };
            groups.push(group_idx);
        }

        self.group_values = Some(group_values);

        Ok(())
    }

    /// Total bytes of memory held by this instance's buffers and map.
    fn size(&self) -> usize {
        let group_values_size = self.group_values.as_ref().map(|v| v.size()).unwrap_or(0);
        self.row_converter.size()
            + group_values_size
            + self.map_size
            + self.rows_buffer.size()
            + self.hashes_buffer.allocated_size()
    }

    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of distinct groups stored so far.
    fn len(&self) -> usize {
        self.group_values
            .as_ref()
            .map(|group_values| group_values.num_rows())
            .unwrap_or(0)
    }

    /// Convert stored group values back to columnar arrays, emitting
    /// either all groups or only the first `n` (which are then removed
    /// and the remaining group indexes shifted down).
    ///
    /// # Panics
    /// Panics if no group values have been stored (`group_values` is `None`).
    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let mut group_values = self
            .group_values
            .take()
            .expect("Can not emit from empty rows");

        let mut output = match emit_to {
            EmitTo::All => {
                let output = self.row_converter.convert_rows(&group_values)?;
                group_values.clear();
                output
            }
            EmitTo::First(n) => {
                let groups_rows = group_values.iter().take(n);
                let output = self.row_converter.convert_rows(groups_rows)?;
                // Clear out first n group keys by copying them to a new Rows.
                // TODO file some ticket in arrow-rs to make this more efficient?
                let mut new_group_values = self.row_converter.empty_rows(0, 0);
                for row in group_values.iter().skip(n) {
                    new_group_values.push(row);
                }
                std::mem::swap(&mut new_group_values, &mut group_values);

                // SAFETY: self.map outlives iterator and is not modified concurrently
                unsafe {
                    for bucket in self.map.iter() {
                        // Decrement group index by n
                        match bucket.as_ref().1.checked_sub(n) {
                            // Group index was >= n, shift value down
                            Some(sub) => bucket.as_mut().1 = sub,
                            // Group index was < n, so remove from table
                            None => self.map.erase(bucket),
                        }
                    }
                }
                output
            }
        };

        // Re-apply dictionary encoding where the output schema expects it
        // (the row format materializes dictionary values).
        // TODO: Materialize dictionaries in group keys
        // https://github.com/apache/datafusion/issues/7647
        for (field, array) in self.schema.fields.iter().zip(&mut output) {
            let expected = field.data_type();
            *array = dictionary_encode_if_necessary(
                Arc::<dyn arrow_array::Array>::clone(array),
                expected,
            )?;
        }

        self.group_values = Some(group_values);
        Ok(output)
    }

    /// Reset state for reuse, shrinking internal buffers toward the row
    /// count of `batch` to release excess memory.
    fn clear_shrink(&mut self, batch: &RecordBatch) {
        let count = batch.num_rows();
        self.group_values = self.group_values.take().map(|mut rows| {
            rows.clear();
            rows
        });
        self.map.clear();
        self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
        self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>();
        self.hashes_buffer.clear();
        self.hashes_buffer.shrink_to(count);
    }
}
257 | | |
258 | 3 | fn dictionary_encode_if_necessary( |
259 | 3 | array: ArrayRef, |
260 | 3 | expected: &DataType, |
261 | 3 | ) -> Result<ArrayRef> { |
262 | 3 | match (expected, array.data_type()) { |
263 | 1 | (DataType::Struct(expected_fields), _) => { |
264 | 1 | let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap(); |
265 | 1 | let arrays = expected_fields |
266 | 1 | .iter() |
267 | 1 | .zip(struct_array.columns()) |
268 | 2 | .map(|(expected_field, column)| { |
269 | 2 | dictionary_encode_if_necessary( |
270 | 2 | Arc::<dyn arrow_array::Array>::clone(column), |
271 | 2 | expected_field.data_type(), |
272 | 2 | ) |
273 | 2 | }) |
274 | 1 | .collect::<Result<Vec<_>>>()?0 ; |
275 | | |
276 | 1 | Ok(Arc::new(StructArray::try_new( |
277 | 1 | expected_fields.clone(), |
278 | 1 | arrays, |
279 | 1 | struct_array.nulls().cloned(), |
280 | 1 | )?0 )) |
281 | | } |
282 | 0 | (DataType::List(expected_field), &DataType::List(_)) => { |
283 | 0 | let list = array.as_any().downcast_ref::<ListArray>().unwrap(); |
284 | 0 |
|
285 | 0 | Ok(Arc::new(ListArray::try_new( |
286 | 0 | Arc::<arrow_schema::Field>::clone(expected_field), |
287 | 0 | list.offsets().clone(), |
288 | 0 | dictionary_encode_if_necessary( |
289 | 0 | Arc::<dyn arrow_array::Array>::clone(list.values()), |
290 | 0 | expected_field.data_type(), |
291 | 0 | )?, |
292 | 0 | list.nulls().cloned(), |
293 | 0 | )?)) |
294 | | } |
295 | 2 | (DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)?0 ), |
296 | 0 | (_, _) => Ok(Arc::<dyn arrow_array::Array>::clone(&array)), |
297 | | } |
298 | 3 | } |