Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/topk/hash_table.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! A wrapper around `hashbrown::RawTable` that allows entries to be tracked by index
19
20
use crate::aggregates::group_values::primitive::HashValue;
21
use crate::aggregates::topk::heap::Comparable;
22
use ahash::RandomState;
23
use arrow::datatypes::i256;
24
use arrow_array::builder::PrimitiveBuilder;
25
use arrow_array::cast::AsArray;
26
use arrow_array::{
27
    downcast_primitive, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray, StringArray,
28
};
29
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
30
use arrow_schema::DataType;
31
use datafusion_common::DataFusionError;
32
use datafusion_common::Result;
33
use half::f16;
34
use hashbrown::raw::RawTable;
35
use std::fmt::Debug;
36
use std::sync::Arc;
37
38
/// A "type alias" for Keys which are stored in our map
///
/// Keys must be cheaply cloneable, orderable via [`Comparable`], and
/// printable for diagnostics.
pub trait KeyType: Clone + Comparable + Debug {}

// Blanket impl: any type satisfying the bounds is automatically a `KeyType`.
impl<T> KeyType for T where T: Clone + Comparable + Debug {}
42
43
/// An entry in our hash table that:
/// 1. memoizes the hash
/// 2. contains the key (ID)
/// 3. contains the value (heap_idx - an index into the corresponding heap)
pub struct HashTableItem<ID: KeyType> {
    /// Memoized hash of `id`, so rehashing on table growth never recomputes it
    hash: u64,
    /// The grouping key
    pub id: ID,
    /// Index of this entry's node in the companion binary heap
    pub heap_idx: usize,
}
52
53
/// A custom wrapper around `hashbrown::RawTable` that:
/// 1. limits the number of entries to the top K
/// 2. Allocates a capacity greater than top K to maintain a low-fill factor and prevent resizing
/// 3. Tracks indexes to allow corresponding heap to refer to entries by index vs hash
/// 4. Catches resize events to allow the corresponding heap to update its indexes
struct TopKHashTable<ID: KeyType> {
    /// Backing table; entries are addressed from the heap by raw bucket index
    map: RawTable<HashTableItem<ID>>,
    /// Maximum number of entries to retain (the K in "top K")
    limit: usize,
}
62
63
/// An interface to hide the generic type signature of TopKHashTable behind arrow arrays
64
pub trait ArrowHashTable {
65
    fn set_batch(&mut self, ids: ArrayRef);
66
    fn len(&self) -> usize;
67
    // JUSTIFICATION
68
    //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
69
    //  Soundness: the caller must provide valid indexes
70
    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]);
71
    // JUSTIFICATION
72
    //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
73
    //  Soundness: the caller must provide a valid index
74
    unsafe fn heap_idx_at(&self, map_idx: usize) -> usize;
75
    unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef;
76
77
    // JUSTIFICATION
78
    //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
79
    //  Soundness: the caller must provide valid indexes
80
    unsafe fn find_or_insert(
81
        &mut self,
82
        row_idx: usize,
83
        replace_idx: usize,
84
        map: &mut Vec<(usize, usize)>,
85
    ) -> (usize, bool);
86
}
87
88
/// An implementation of `ArrowHashTable` for String keys
pub struct StringHashTable {
    /// The current input batch; `find_or_insert` reads strings out of it by row index
    owned: ArrayRef,
    /// Keys are `None` for NULL strings
    map: TopKHashTable<Option<String>>,
    /// Hasher shared by inserts and lookups so hashes stay consistent
    rnd: RandomState,
}
94
95
/// An implementation of `ArrowHashTable` for any `ArrowPrimitiveType` key
struct PrimitiveHashTable<VAL: ArrowPrimitiveType>
where
    Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
{
    /// The current input batch; `find_or_insert` reads values out of it by row index
    owned: ArrayRef,
    /// Keys are `None` for NULL values
    map: TopKHashTable<Option<VAL::Native>>,
    /// Hasher shared by inserts and lookups so hashes stay consistent
    rnd: RandomState,
}
104
105
impl StringHashTable {
106
10
    pub fn new(limit: usize) -> Self {
107
10
        let vals: Vec<&str> = Vec::new();
108
10
        let owned = Arc::new(StringArray::from(vals));
109
10
        Self {
110
10
            owned,
111
10
            map: TopKHashTable::new(limit, limit * 10),
112
10
            rnd: ahash::RandomState::default(),
113
10
        }
114
10
    }
115
}
116
117
impl ArrowHashTable for StringHashTable {
118
10
    fn set_batch(&mut self, ids: ArrayRef) {
119
10
        self.owned = ids;
120
10
    }
121
122
20
    fn len(&self) -> usize {
123
20
        self.map.len()
124
20
    }
125
126
18
    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
127
18
        self.map.update_heap_idx(mapper);
128
18
    }
129
130
5
    unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
131
5
        self.map.heap_idx_at(map_idx)
132
5
    }
133
134
10
    unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef {
135
10
        let ids = self.map.take_all(indexes);
136
10
        Arc::new(StringArray::from(ids))
137
10
    }
138
139
18
    unsafe fn find_or_insert(
140
18
        &mut self,
141
18
        row_idx: usize,
142
18
        replace_idx: usize,
143
18
        mapper: &mut Vec<(usize, usize)>,
144
18
    ) -> (usize, bool) {
145
18
        let ids = self
146
18
            .owned
147
18
            .as_any()
148
18
            .downcast_ref::<StringArray>()
149
18
            .expect("StringArray required");
150
18
        let id = if ids.is_null(row_idx) {
151
2
            None
152
        } else {
153
16
            Some(ids.value(row_idx))
154
        };
155
156
18
        let hash = self.rnd.hash_one(id);
157
18
        if let Some(
map_idx5
) = self
158
18
            .map
159
18
            .find(hash, |mi| 
id == mi.as_ref().map(5
|id|
id.as_str()4
)5
)
160
        {
161
5
            return (map_idx, false);
162
13
        }
163
13
164
13
        // we're full and this is a better value, so remove the worst
165
13
        let heap_idx = self.map.remove_if_full(replace_idx);
166
13
167
13
        // add the new group
168
13
        let id = id.map(|id| 
id.to_string()12
);
169
13
        let map_idx = self.map.insert(hash, id, heap_idx, mapper);
170
13
        (map_idx, true)
171
18
    }
172
}
173
174
impl<VAL: ArrowPrimitiveType> PrimitiveHashTable<VAL>
175
where
176
    Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
177
    Option<<VAL as ArrowPrimitiveType>::Native>: HashValue,
178
{
179
0
    pub fn new(limit: usize) -> Self {
180
0
        let owned = Arc::new(PrimitiveArray::<VAL>::builder(0).finish());
181
0
        Self {
182
0
            owned,
183
0
            map: TopKHashTable::new(limit, limit * 10),
184
0
            rnd: ahash::RandomState::default(),
185
0
        }
186
0
    }
187
}
188
189
impl<VAL: ArrowPrimitiveType> ArrowHashTable for PrimitiveHashTable<VAL>
190
where
191
    Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
192
    Option<<VAL as ArrowPrimitiveType>::Native>: HashValue,
193
{
194
0
    fn set_batch(&mut self, ids: ArrayRef) {
195
0
        self.owned = ids;
196
0
    }
197
198
0
    fn len(&self) -> usize {
199
0
        self.map.len()
200
0
    }
201
202
0
    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
203
0
        self.map.update_heap_idx(mapper);
204
0
    }
205
206
0
    unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
207
0
        self.map.heap_idx_at(map_idx)
208
0
    }
209
210
0
    unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef {
211
0
        let ids = self.map.take_all(indexes);
212
0
        let mut builder: PrimitiveBuilder<VAL> = PrimitiveArray::builder(ids.len());
213
0
        for id in ids.into_iter() {
214
0
            match id {
215
0
                None => builder.append_null(),
216
0
                Some(id) => builder.append_value(id),
217
            }
218
        }
219
0
        let ids = builder.finish();
220
0
        Arc::new(ids)
221
0
    }
222
223
0
    unsafe fn find_or_insert(
224
0
        &mut self,
225
0
        row_idx: usize,
226
0
        replace_idx: usize,
227
0
        mapper: &mut Vec<(usize, usize)>,
228
0
    ) -> (usize, bool) {
229
0
        let ids = self.owned.as_primitive::<VAL>();
230
0
        let id: Option<VAL::Native> = if ids.is_null(row_idx) {
231
0
            None
232
        } else {
233
0
            Some(ids.value(row_idx))
234
        };
235
236
0
        let hash: u64 = id.hash(&self.rnd);
237
0
        if let Some(map_idx) = self.map.find(hash, |mi| id == *mi) {
238
0
            return (map_idx, false);
239
0
        }
240
0
241
0
        // we're full and this is a better value, so remove the worst
242
0
        let heap_idx = self.map.remove_if_full(replace_idx);
243
0
244
0
        // add the new group
245
0
        let map_idx = self.map.insert(hash, id, heap_idx, mapper);
246
0
        (map_idx, true)
247
0
    }
248
}
249
250
impl<ID: KeyType> TopKHashTable<ID> {
    /// Creates a table that holds at most `limit` entries, with `capacity`
    /// buckets pre-allocated (callers over-allocate to avoid resizes).
    pub fn new(limit: usize, capacity: usize) -> Self {
        Self {
            map: RawTable::with_capacity(capacity),
            limit,
        }
    }

    /// Looks up an entry by its memoized `hash` and an equality closure over
    /// the key, returning the entry's bucket index if present.
    pub fn find(&self, hash: u64, mut eq: impl FnMut(&ID) -> bool) -> Option<usize> {
        let bucket = self.map.find(hash, |mi| eq(&mi.id))?;
        // JUSTIFICATION
        //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
        //  Soundness: getting the index of a bucket we just found
        let idx = unsafe { self.map.bucket_index(&bucket) };
        Some(idx)
    }

    /// Returns the heap index stored in the bucket at `map_idx`.
    ///
    /// # Safety
    /// `map_idx` must be the index of a live bucket in this table.
    pub unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
        let bucket = unsafe { self.map.bucket(map_idx) };
        bucket.as_ref().heap_idx
    }

    /// If the table is at its `limit`, evicts the entry at `replace_idx` and
    /// returns 0 (the heap's top slot); otherwise returns the next free heap
    /// slot (the current length).
    ///
    /// # Safety
    /// When the table is full, `replace_idx` must be the index of a live bucket.
    pub unsafe fn remove_if_full(&mut self, replace_idx: usize) -> usize {
        if self.map.len() >= self.limit {
            self.map.erase(self.map.bucket(replace_idx));
            0 // if full, always replace top node
        } else {
            self.map.len() // if we're not full, always append to end
        }
    }

    /// Applies `(map_idx, new_heap_idx)` pairs, storing each new heap index in
    /// the corresponding bucket (used after the companion heap reorders itself).
    ///
    /// # Safety
    /// Every first element of `mapper` must be the index of a live bucket.
    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
        for (m, h) in mapper {
            self.map.bucket(*m).as_mut().heap_idx = *h
        }
    }

    /// Inserts a new entry and returns its bucket index.
    ///
    /// If the insert forces the `RawTable` to grow, every entry's
    /// `(heap_idx, new_map_idx)` pair is pushed to `mapper` so the companion
    /// heap can re-point its nodes at the relocated buckets.
    pub fn insert(
        &mut self,
        hash: u64,
        id: ID,
        heap_idx: usize,
        mapper: &mut Vec<(usize, usize)>,
    ) -> usize {
        let mi = HashTableItem::new(hash, id, heap_idx);
        // Try a non-growing insert first so a resize is detected (and reported)
        let bucket = self.map.try_insert_no_grow(hash, mi);
        let bucket = match bucket {
            Ok(bucket) => bucket,
            Err(new_item) => {
                // Table was full: grow, insert, then report every entry's new index
                let bucket = self.map.insert(hash, new_item, |mi| mi.hash);
                // JUSTIFICATION
                //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
                //  Soundness: we're getting indexes of buckets, not dereferencing them
                unsafe {
                    for bucket in self.map.iter() {
                        let heap_idx = bucket.as_ref().heap_idx;
                        let map_idx = self.map.bucket_index(&bucket);
                        mapper.push((heap_idx, map_idx));
                    }
                }
                bucket
            }
        };
        // JUSTIFICATION
        //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
        //  Soundness: we're getting indexes of buckets, not dereferencing them
        unsafe { self.map.bucket_index(&bucket) }
    }

    /// Returns the number of entries currently stored.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Clones out the keys at `idxs` (in order) and clears the table.
    ///
    /// # Safety
    /// Every index in `idxs` must refer to a live bucket.
    pub unsafe fn take_all(&mut self, idxs: Vec<usize>) -> Vec<ID> {
        let ids = idxs
            .into_iter()
            .map(|idx| self.map.bucket(idx).as_ref().id.clone())
            .collect();
        self.map.clear();
        ids
    }
}
332
333
impl<ID: KeyType> HashTableItem<ID> {
    /// Creates an entry from a precomputed `hash`, the key `id`, and the
    /// index of this entry's node in the companion heap.
    pub fn new(hash: u64, id: ID, heap_idx: usize) -> Self {
        Self { heap_idx, id, hash }
    }
}
338
339
impl HashValue for Option<String> {
    // NULL-aware hash: `hash_one` on the `Option` covers the discriminant plus
    // the string bytes, so `None` (NULL) gets its own consistent hash.
    fn hash(&self, state: &RandomState) -> u64 {
        state.hash_one(self)
    }
}
344
345
/// Implements `HashValue` for nullable float key types.
/// `None` (NULL) hashes to the fixed value 0; `Some(v)` delegates to the
/// `HashValue` impl of the bare float type (presumably provided by
/// `group_values::primitive` — the imported `HashValue` trait).
macro_rules! hash_float {
    ($($t:ty),+) => {
        $(impl HashValue for Option<$t> {
            fn hash(&self, state: &RandomState) -> u64 {
                self.map(|me| me.hash(state)).unwrap_or(0)
            }
        })+
    };
}
354
355
/// Implements `HashValue` for nullable integer-like key types.
/// `None` (NULL) hashes to the fixed value 0; `Some(v)` delegates to the
/// value's own `HashValue` impl.
/// NOTE(review): the name looks like a typo for `hash_integer!`; renaming
/// would require updating the invocations below, so it is left unchanged.
macro_rules! has_integer {
    ($($t:ty),+) => {
        $(impl HashValue for Option<$t> {
            fn hash(&self, state: &RandomState) -> u64 {
                self.map(|me| me.hash(state)).unwrap_or(0)
            }
        })+
    };
}
364
365
// Provide `HashValue` for every nullable integer, interval, and float key
// type the top-k hash table supports.
has_integer!(i8, i16, i32, i64, i128, i256);
has_integer!(u8, u16, u32, u64);
has_integer!(IntervalDayTime, IntervalMonthDayNano);
hash_float!(f16, f32, f64);
369
370
10
/// Creates an [`ArrowHashTable`] for grouping keys of arrow type `kt`,
/// retaining at most `limit` (top K) entries.
///
/// # Errors
/// Returns `DataFusionError::Execution` if `kt` is neither a primitive type
/// (handled via `downcast_primitive!`) nor `Utf8`.
pub fn new_hash_table(
    limit: usize,
    kt: DataType,
) -> Result<Box<dyn ArrowHashTable + Send>> {
    // Expands to an early `return` for each primitive type matched below
    macro_rules! downcast_helper {
        ($kt:ty, $d:ident) => {
            return Ok(Box::new(PrimitiveHashTable::<$kt>::new(limit)))
        };
    }

    downcast_primitive! {
        kt => (downcast_helper, kt),
        DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit))),
        _ => {}
    }

    Err(DataFusionError::Execution(format!(
        "Can't create HashTable for type: {kt:?}"
    )))
}
390
391
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    /// Inserts 5 keys into a table with capacity 3 so the backing `RawTable`
    /// must grow mid-run, and verifies the resize is reported through `mapper`
    /// exactly once (on the 4th insert) with correct (heap_idx, map_idx) pairs.
    #[test]
    fn should_resize_properly() -> Result<()> {
        let mut heap_to_map = BTreeMap::<usize, usize>::new();
        // limit 5, but only 3 buckets pre-allocated => growth is forced
        let mut map = TopKHashTable::<Option<String>>::new(5, 3);
        for (heap_idx, id) in vec!["1", "2", "3", "4", "5"].into_iter().enumerate() {
            let mut mapper = vec![];
            let hash = heap_idx as u64;
            let map_idx = map.insert(hash, Some(id.to_string()), heap_idx, &mut mapper);
            let _ = heap_to_map.insert(heap_idx, map_idx);
            if heap_idx == 3 {
                // the 4th insert overflows capacity 3 and triggers the resize;
                // all pre-existing entries must be reported with their new indexes
                assert_eq!(
                    mapper,
                    vec![(0, 0), (1, 1), (2, 2), (3, 3)],
                    "Pass {heap_idx} resized incorrectly!"
                );
                for (heap_idx, map_idx) in mapper {
                    let _ = heap_to_map.insert(heap_idx, map_idx);
                }
            } else {
                // no other insert should report a resize
                assert_eq!(mapper, vec![], "Pass {heap_idx} should not have resized!");
            }
        }

        // drain the table in heap order and check keys survived the resize
        let (_heap_idxs, map_idxs): (Vec<_>, Vec<_>) = heap_to_map.into_iter().unzip();
        let ids = unsafe { map.take_all(map_idxs) };
        assert_eq!(
            format!("{:?}", ids),
            r#"[Some("1"), Some("2"), Some("3"), Some("4"), Some("5")]"#
        );
        assert_eq!(map.len(), 0, "Map should have been cleared!");

        Ok(())
    }
}