/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/topk/hash_table.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! A wrapper around `hashbrown::RawTable` that allows entries to be tracked by index

use crate::aggregates::group_values::primitive::HashValue;
use crate::aggregates::topk::heap::Comparable;
use ahash::RandomState;
use arrow::datatypes::i256;
use arrow_array::builder::PrimitiveBuilder;
use arrow_array::cast::AsArray;
use arrow_array::{
    downcast_primitive, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray, StringArray,
};
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use arrow_schema::DataType;
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use half::f16;
use hashbrown::raw::RawTable;
use std::fmt::Debug;
use std::sync::Arc;

/// A "type alias" for the keys stored in our map
pub trait KeyType: Clone + Comparable + Debug {}

impl<T> KeyType for T where T: Clone + Comparable + Debug {}
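
// NOTE: the blanket impl above makes `KeyType` act like a trait alias:
// any `Clone + Comparable + Debug` type automatically qualifies as a key.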

/// An entry in our hash table that:
/// 1. memoizes the hash
/// 2. contains the key (ID)
/// 3. contains the value (heap_idx - an index into the corresponding heap)
pub struct HashTableItem<ID: KeyType> {
    hash: u64,
    pub id: ID,
    pub heap_idx: usize,
}

/// A custom wrapper around `hashbrown::RawTable` that:
/// 1. limits the number of entries to the top K
/// 2. allocates a capacity greater than top K to maintain a low fill factor and prevent resizing
/// 3. tracks indexes to allow the corresponding heap to refer to entries by index vs hash
/// 4. catches resize events to allow the corresponding heap to update its indexes
struct TopKHashTable<ID: KeyType> {
    map: RawTable<HashTableItem<ID>>,
    limit: usize,
}
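
// Design note: the companion binary heap refers to entries by their bucket
// index in the `RawTable`, and each entry stores a `heap_idx` back-pointer.
// If an insert forces the table to grow, every bucket index may change, so
// `insert` pushes `(heap_idx, new map_idx)` pairs into `mapper` to let the
// heap repair its side of the mapping (see `update_heap_idx`).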

/// An interface to hide the generic type signature of TopKHashTable behind arrow arrays
pub trait ArrowHashTable {
    fn set_batch(&mut self, ids: ArrayRef);
    fn len(&self) -> usize;
    // JUSTIFICATION
    // Benefit: ~15% speedup + required to index into RawTable from binary heap
    // Soundness: the caller must provide valid indexes
    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]);
    // JUSTIFICATION
    // Benefit: ~15% speedup + required to index into RawTable from binary heap
    // Soundness: the caller must provide a valid index
    unsafe fn heap_idx_at(&self, map_idx: usize) -> usize;
    unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef;

    // JUSTIFICATION
    // Benefit: ~15% speedup + required to index into RawTable from binary heap
    // Soundness: the caller must provide valid indexes
    unsafe fn find_or_insert(
        &mut self,
        row_idx: usize,
        replace_idx: usize,
        mapper: &mut Vec<(usize, usize)>,
    ) -> (usize, bool);
}
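
// Typical lifecycle, as driven by the TopK operator: `set_batch` points the
// table at the current batch's key column; `find_or_insert` is called per row
// and returns `(map_idx, inserted)`; any `(heap_idx, map_idx)` remaps it
// produced are applied to the heap via `update_heap_idx`; finally `take_all`
// drains the surviving keys back out as an arrow array, in the order of the
// supplied indexes, and clears the table.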

/// An implementation of ArrowHashTable for String keys
pub struct StringHashTable {
    owned: ArrayRef,
    map: TopKHashTable<Option<String>>,
    rnd: RandomState,
}

/// An implementation of ArrowHashTable for any `ArrowPrimitiveType` key
struct PrimitiveHashTable<VAL: ArrowPrimitiveType>
where
    Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
{
    owned: ArrayRef,
    map: TopKHashTable<Option<VAL::Native>>,
    rnd: RandomState,
}
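
// NOTE: in both tables, `owned` holds the key column of the batch currently
// being processed, and keys are `Option`s so that NULL keys can participate
// in the top-K computation like any other group value.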

impl StringHashTable {
    pub fn new(limit: usize) -> Self {
        let vals: Vec<&str> = Vec::new();
        let owned = Arc::new(StringArray::from(vals));
        Self {
            owned,
            map: TopKHashTable::new(limit, limit * 10),
            rnd: ahash::RandomState::default(),
        }
    }
}
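
// The capacity of `limit * 10` keeps the fill factor at or below ~10%, so
// `try_insert_no_grow` in `TopKHashTable::insert` almost never fails and the
// expensive rehash-and-remap path stays rare.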

impl ArrowHashTable for StringHashTable {
    fn set_batch(&mut self, ids: ArrayRef) {
        self.owned = ids;
    }

    fn len(&self) -> usize {
        self.map.len()
    }

    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
        self.map.update_heap_idx(mapper);
    }

    unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
        self.map.heap_idx_at(map_idx)
    }

    unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef {
        let ids = self.map.take_all(indexes);
        Arc::new(StringArray::from(ids))
    }

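    // Looks up the key at `row_idx` of the current batch; on a miss, evicts
    // the entry at `replace_idx` when the table is full and inserts the new
    // key, returning `(map_idx, inserted)` so the caller can tell a fresh
    // group from an existing one.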
    unsafe fn find_or_insert(
        &mut self,
        row_idx: usize,
        replace_idx: usize,
        mapper: &mut Vec<(usize, usize)>,
    ) -> (usize, bool) {
        let ids = self
            .owned
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("StringArray required");
        let id = if ids.is_null(row_idx) {
            None
        } else {
            Some(ids.value(row_idx))
        };

        let hash = self.rnd.hash_one(id);
        if let Some(map_idx) = self
            .map
            .find(hash, |mi| id == mi.as_ref().map(|id| id.as_str()))
        {
            return (map_idx, false);
        }

        // not found; if the map is full, evict the worst entry to make room
        let heap_idx = self.map.remove_if_full(replace_idx);

        // add the new group
        let id = id.map(|id| id.to_string());
        let map_idx = self.map.insert(hash, id, heap_idx, mapper);
        (map_idx, true)
    }
}

impl<VAL: ArrowPrimitiveType> PrimitiveHashTable<VAL>
where
    Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
    Option<<VAL as ArrowPrimitiveType>::Native>: HashValue,
{
    pub fn new(limit: usize) -> Self {
        let owned = Arc::new(PrimitiveArray::<VAL>::builder(0).finish());
        Self {
            owned,
            map: TopKHashTable::new(limit, limit * 10),
            rnd: ahash::RandomState::default(),
        }
    }
}

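// This impl mirrors `StringHashTable`, but stores native values directly and
// hashes them via the `HashValue` impls at the bottom of this file, so no
// per-row allocation is needed on insert.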
impl<VAL: ArrowPrimitiveType> ArrowHashTable for PrimitiveHashTable<VAL>
where
    Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
    Option<<VAL as ArrowPrimitiveType>::Native>: HashValue,
{
    fn set_batch(&mut self, ids: ArrayRef) {
        self.owned = ids;
    }

    fn len(&self) -> usize {
        self.map.len()
    }

    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
        self.map.update_heap_idx(mapper);
    }

    unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
        self.map.heap_idx_at(map_idx)
    }

    unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef {
        let ids = self.map.take_all(indexes);
        let mut builder: PrimitiveBuilder<VAL> = PrimitiveArray::builder(ids.len());
        for id in ids.into_iter() {
            match id {
                None => builder.append_null(),
                Some(id) => builder.append_value(id),
            }
        }
        let ids = builder.finish();
        Arc::new(ids)
    }

    unsafe fn find_or_insert(
        &mut self,
        row_idx: usize,
        replace_idx: usize,
        mapper: &mut Vec<(usize, usize)>,
    ) -> (usize, bool) {
        let ids = self.owned.as_primitive::<VAL>();
        let id: Option<VAL::Native> = if ids.is_null(row_idx) {
            None
        } else {
            Some(ids.value(row_idx))
        };

        let hash: u64 = id.hash(&self.rnd);
        if let Some(map_idx) = self.map.find(hash, |mi| id == *mi) {
            return (map_idx, false);
        }

        // not found; if the map is full, evict the worst entry to make room
        let heap_idx = self.map.remove_if_full(replace_idx);

        // add the new group
        let map_idx = self.map.insert(hash, id, heap_idx, mapper);
        (map_idx, true)
    }
}

impl<ID: KeyType> TopKHashTable<ID> {
    pub fn new(limit: usize, capacity: usize) -> Self {
        Self {
            map: RawTable::with_capacity(capacity),
            limit,
        }
    }

    pub fn find(&self, hash: u64, mut eq: impl FnMut(&ID) -> bool) -> Option<usize> {
        let bucket = self.map.find(hash, |mi| eq(&mi.id))?;
        // JUSTIFICATION
        // Benefit: ~15% speedup + required to index into RawTable from binary heap
        // Soundness: getting the index of a bucket we just found
        let idx = unsafe { self.map.bucket_index(&bucket) };
        Some(idx)
    }

    pub unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
        let bucket = unsafe { self.map.bucket(map_idx) };
        bucket.as_ref().heap_idx
    }

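    /// Evict the entry at `replace_idx` if the table is already at `limit`,
    /// and return the heap slot the incoming entry should occupy: slot 0
    /// (the root, which held the worst value) when full, or `len()` (append
    /// at the end) when there is still room.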
    pub unsafe fn remove_if_full(&mut self, replace_idx: usize) -> usize {
        if self.map.len() >= self.limit {
            self.map.erase(self.map.bucket(replace_idx));
            0 // if full, always replace top node
        } else {
            self.map.len() // if we're not full, always append to end
        }
    }

    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
        for (m, h) in mapper {
            self.map.bucket(*m).as_mut().heap_idx = *h
        }
    }

    pub fn insert(
        &mut self,
        hash: u64,
        id: ID,
        heap_idx: usize,
        mapper: &mut Vec<(usize, usize)>,
    ) -> usize {
        let mi = HashTableItem::new(hash, id, heap_idx);
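        // Fast path: insert without growing; this fails only when the table
        // has run out of spare capacity, in which case we fall back to a
        // growing insert and must report every moved bucket to the heap.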
        let bucket = self.map.try_insert_no_grow(hash, mi);
        let bucket = match bucket {
            Ok(bucket) => bucket,
            Err(new_item) => {
                let bucket = self.map.insert(hash, new_item, |mi| mi.hash);
                // JUSTIFICATION
                // Benefit: ~15% speedup + required to index into RawTable from binary heap
                // Soundness: we're getting indexes of buckets, not dereferencing them
                unsafe {
                    for bucket in self.map.iter() {
                        let heap_idx = bucket.as_ref().heap_idx;
                        let map_idx = self.map.bucket_index(&bucket);
                        mapper.push((heap_idx, map_idx));
                    }
                }
                bucket
            }
        };
        // JUSTIFICATION
        // Benefit: ~15% speedup + required to index into RawTable from binary heap
        // Soundness: we're getting indexes of buckets, not dereferencing them
        unsafe { self.map.bucket_index(&bucket) }
    }

    pub fn len(&self) -> usize {
        self.map.len()
    }

    pub unsafe fn take_all(&mut self, idxs: Vec<usize>) -> Vec<ID> {
        let ids = idxs
            .into_iter()
            .map(|idx| self.map.bucket(idx).as_ref().id.clone())
            .collect();
        self.map.clear();
        ids
    }
}

impl<ID: KeyType> HashTableItem<ID> {
    pub fn new(hash: u64, id: ID, heap_idx: usize) -> Self {
        Self { hash, id, heap_idx }
    }
}

impl HashValue for Option<String> {
    fn hash(&self, state: &RandomState) -> u64 {
        state.hash_one(self)
    }
}

macro_rules! hash_float {
    ($($t:ty),+) => {
        $(impl HashValue for Option<$t> {
            fn hash(&self, state: &RandomState) -> u64 {
                self.map(|me| me.hash(state)).unwrap_or(0)
            }
        })+
    };
}

macro_rules! hash_integer {
    ($($t:ty),+) => {
        $(impl HashValue for Option<$t> {
            fn hash(&self, state: &RandomState) -> u64 {
                self.map(|me| me.hash(state)).unwrap_or(0)
            }
        })+
    };
}

hash_integer!(i8, i16, i32, i64, i128, i256);
hash_integer!(u8, u16, u32, u64);
hash_integer!(IntervalDayTime, IntervalMonthDayNano);
hash_float!(f16, f32, f64);
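
// NOTE: for numeric keys, `None` (a NULL key) hashes to the sentinel 0 rather
// than hashing the `Option` itself; hashes only have to be consistent within
// a single table, and equality is still checked on the full `Option` during
// lookup, so NULL keys cannot collide with a genuine group.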

pub fn new_hash_table(
    limit: usize,
    kt: DataType,
) -> Result<Box<dyn ArrowHashTable + Send>> {
    macro_rules! downcast_helper {
        ($kt:ty, $d:ident) => {
            return Ok(Box::new(PrimitiveHashTable::<$kt>::new(limit)))
        };
    }

    downcast_primitive! {
        kt => (downcast_helper, kt),
        DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit))),
        _ => {}
    }

    Err(DataFusionError::Execution(format!(
        "Can't create HashTable for type: {kt:?}"
    )))
}
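
// Illustrative use (hypothetical sketch; this module is crate-internal, and
// the TopK operator owns the heap and the `mapper` bookkeeping):
//
//     let mut table = new_hash_table(10, DataType::Utf8)?;
//     table.set_batch(keys); // `keys: ArrayRef` is the batch's key column
//     // per row: unsafe { table.find_or_insert(row, replace_idx, &mut mapper) }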

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    #[test]
    fn should_resize_properly() -> Result<()> {
        let mut heap_to_map = BTreeMap::<usize, usize>::new();
        let mut map = TopKHashTable::<Option<String>>::new(5, 3);
        for (heap_idx, id) in vec!["1", "2", "3", "4", "5"].into_iter().enumerate() {
            let mut mapper = vec![];
            let hash = heap_idx as u64;
            let map_idx = map.insert(hash, Some(id.to_string()), heap_idx, &mut mapper);
            let _ = heap_to_map.insert(heap_idx, map_idx);
            if heap_idx == 3 {
                assert_eq!(
                    mapper,
                    vec![(0, 0), (1, 1), (2, 2), (3, 3)],
                    "Pass {heap_idx} resized incorrectly!"
                );
                for (heap_idx, map_idx) in mapper {
                    let _ = heap_to_map.insert(heap_idx, map_idx);
                }
            } else {
                assert_eq!(mapper, vec![], "Pass {heap_idx} should not have resized!");
            }
        }

        let (_heap_idxs, map_idxs): (Vec<_>, Vec<_>) = heap_to_map.into_iter().unzip();
        let ids = unsafe { map.take_all(map_idxs) };
        assert_eq!(
            format!("{:?}", ids),
            r#"[Some("1"), Some("2"), Some("3"), Some("4"), Some("5")]"#
        );
        assert_eq!(map.len(), 0, "Map should have been cleared!");

        Ok(())
    }
}