/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/group_values/primitive.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use crate::aggregates::group_values::GroupValues; |
19 | | use ahash::RandomState; |
20 | | use arrow::array::BooleanBufferBuilder; |
21 | | use arrow::buffer::NullBuffer; |
22 | | use arrow::datatypes::i256; |
23 | | use arrow::record_batch::RecordBatch; |
24 | | use arrow_array::cast::AsArray; |
25 | | use arrow_array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray}; |
26 | | use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; |
27 | | use arrow_schema::DataType; |
28 | | use datafusion_common::Result; |
29 | | use datafusion_execution::memory_pool::proxy::VecAllocExt; |
30 | | use datafusion_expr::EmitTo; |
31 | | use half::f16; |
32 | | use hashbrown::raw::RawTable; |
33 | | use std::sync::Arc; |
34 | | |
35 | | /// A trait to allow hashing of floating point numbers |
36 | | pub(crate) trait HashValue { |
37 | | fn hash(&self, state: &RandomState) -> u64; |
38 | | } |
39 | | |
40 | | macro_rules! hash_integer { |
41 | | ($($t:ty),+) => { |
42 | | $(impl HashValue for $t { |
43 | | #[cfg(not(feature = "force_hash_collisions"))] |
44 | 343 | fn hash(&self, state: &RandomState) -> u64 { |
45 | 343 | state.hash_one(self) |
46 | 343 | } |
47 | | |
48 | | #[cfg(feature = "force_hash_collisions")] |
49 | | fn hash(&self, _state: &RandomState) -> u64 { |
50 | | 0 |
51 | | } |
52 | | })+ |
53 | | }; |
54 | | } |
55 | | hash_integer!(i8, i16, i32, i64, i128, i256); |
56 | | hash_integer!(u8, u16, u32, u64); |
57 | | hash_integer!(IntervalDayTime, IntervalMonthDayNano); |
58 | | |
59 | | macro_rules! hash_float { |
60 | | ($($t:ty),+) => { |
61 | | $(impl HashValue for $t { |
62 | | #[cfg(not(feature = "force_hash_collisions"))] |
63 | 0 | fn hash(&self, state: &RandomState) -> u64 { |
64 | 0 | state.hash_one(self.to_bits()) |
65 | 0 | } |
66 | | |
67 | | #[cfg(feature = "force_hash_collisions")] |
68 | | fn hash(&self, _state: &RandomState) -> u64 { |
69 | | 0 |
70 | | } |
71 | | })+ |
72 | | }; |
73 | | } |
74 | | |
75 | | hash_float!(f16, f32, f64); |
76 | | |
77 | | /// A [`GroupValues`] storing a single column of primitive values |
78 | | /// |
79 | | /// This specialization is significantly faster than using the more general |
80 | | /// purpose `Row`s format |
81 | | pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> { |
82 | | /// The data type of the output array |
83 | | data_type: DataType, |
84 | | /// Stores the group index based on the hash of its value |
85 | | /// |
86 | | /// We don't store the hashes as hashing fixed width primitives |
87 | | /// is fast enough for this not to benefit performance |
88 | | map: RawTable<usize>, |
89 | | /// The group index of the null value if any |
90 | | null_group: Option<usize>, |
91 | | /// The values for each group index |
92 | | values: Vec<T::Native>, |
93 | | /// The random state used to generate hashes |
94 | | random_state: RandomState, |
95 | | } |
96 | | |
97 | | impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> { |
98 | 56 | pub fn new(data_type: DataType) -> Self { |
99 | 56 | assert!(PrimitiveArray::<T>::is_compatible(&data_type)); |
100 | 56 | Self { |
101 | 56 | data_type, |
102 | 56 | map: RawTable::with_capacity(128), |
103 | 56 | values: Vec::with_capacity(128), |
104 | 56 | null_group: None, |
105 | 56 | random_state: Default::default(), |
106 | 56 | } |
107 | 56 | } |
108 | | } |
109 | | |
110 | | impl<T: ArrowPrimitiveType> GroupValues for GroupValuesPrimitive<T> |
111 | | where |
112 | | T::Native: HashValue, |
113 | | { |
114 | 100 | fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> { |
115 | 100 | assert_eq!(cols.len(), 1); |
116 | 100 | groups.clear(); |
117 | | |
118 | 337 | for v in cols[0].as_primitive::<T>()100 { |
119 | 337 | let group_id = match v { |
120 | 0 | None => *self.null_group.get_or_insert_with(|| { |
121 | 0 | let group_id = self.values.len(); |
122 | 0 | self.values.push(Default::default()); |
123 | 0 | group_id |
124 | 0 | }), |
125 | 337 | Some(key) => { |
126 | 337 | let state = &self.random_state; |
127 | 337 | let hash = key.hash(state); |
128 | 337 | let insert = self.map.find_or_find_insert_slot( |
129 | 337 | hash, |
130 | 337 | |g| unsafe { self.values.get_unchecked(*g).is_eq(key) }143 , |
131 | 337 | |g| unsafe { self.values.get_unchecked(*g).hash(state) }6 , |
132 | 337 | ); |
133 | 337 | |
134 | 337 | // SAFETY: No mutation occurred since find_or_find_insert_slot |
135 | 337 | unsafe { |
136 | 337 | match insert { |
137 | 143 | Ok(v) => *v.as_ref(), |
138 | 194 | Err(slot) => { |
139 | 194 | let g = self.values.len(); |
140 | 194 | self.map.insert_in_slot(hash, slot, g); |
141 | 194 | self.values.push(key); |
142 | 194 | g |
143 | | } |
144 | | } |
145 | | } |
146 | | } |
147 | | }; |
148 | 337 | groups.push(group_id) |
149 | | } |
150 | 100 | Ok(()) |
151 | 100 | } |
152 | | |
153 | 300 | fn size(&self) -> usize { |
154 | 300 | self.map.capacity() * std::mem::size_of::<usize>() + self.values.allocated_size() |
155 | 300 | } |
156 | | |
157 | 94 | fn is_empty(&self) -> bool { |
158 | 94 | self.values.is_empty() |
159 | 94 | } |
160 | | |
161 | 403 | fn len(&self) -> usize { |
162 | 403 | self.values.len() |
163 | 403 | } |
164 | | |
165 | 94 | fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> { |
166 | 94 | fn build_primitive<T: ArrowPrimitiveType>( |
167 | 94 | values: Vec<T::Native>, |
168 | 94 | null_idx: Option<usize>, |
169 | 94 | ) -> PrimitiveArray<T> { |
170 | 94 | let nulls = null_idx.map(|null_idx| { |
171 | 0 | let mut buffer = BooleanBufferBuilder::new(values.len()); |
172 | 0 | buffer.append_n(values.len(), true); |
173 | 0 | buffer.set_bit(null_idx, false); |
174 | 0 | unsafe { NullBuffer::new_unchecked(buffer.finish(), 1) } |
175 | 94 | }); |
176 | 94 | PrimitiveArray::<T>::new(values.into(), nulls) |
177 | 94 | } |
178 | | |
179 | 94 | let array: PrimitiveArray<T> = match emit_to { |
180 | | EmitTo::All => { |
181 | 62 | self.map.clear(); |
182 | 62 | build_primitive(std::mem::take(&mut self.values), self.null_group.take()) |
183 | | } |
184 | 32 | EmitTo::First(n) => { |
185 | | // SAFETY: self.map outlives iterator and is not modified concurrently |
186 | | unsafe { |
187 | 88 | for bucket in self.map.iter()32 { |
188 | | // Decrement group index by n |
189 | 88 | match bucket.as_ref().checked_sub(n) { |
190 | | // Group index was >= n, shift value down |
191 | 32 | Some(sub) => *bucket.as_mut() = sub, |
192 | | // Group index was < n, so remove from table |
193 | 56 | None => self.map.erase(bucket), |
194 | | } |
195 | | } |
196 | | } |
197 | 32 | let null_group = match &mut self.null_group { |
198 | 0 | Some(v) if *v >= n => { |
199 | 0 | *v -= n; |
200 | 0 | None |
201 | | } |
202 | 0 | Some(_) => self.null_group.take(), |
203 | 32 | None => None, |
204 | | }; |
205 | 32 | let mut split = self.values.split_off(n); |
206 | 32 | std::mem::swap(&mut self.values, &mut split); |
207 | 32 | build_primitive(split, null_group) |
208 | | } |
209 | | }; |
210 | 94 | Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))]) |
211 | 94 | } |
212 | | |
213 | 62 | fn clear_shrink(&mut self, batch: &RecordBatch) { |
214 | 62 | let count = batch.num_rows(); |
215 | 62 | self.values.clear(); |
216 | 62 | self.values.shrink_to(count); |
217 | 62 | self.map.clear(); |
218 | 62 | self.map.shrink_to(count, |_| 00 ); // hasher does not matter since the map is cleared |
219 | 62 | } |
220 | | } |