Skip to content

Commit

Permalink
refactor: migrate GroupValuesPrimitive to HashTable
Browse files (browse the repository at this point in the history)
  • Loading branch information
crepererum committed Nov 22, 2024
1 parent 5227895 commit a80a55a
Showing 1 changed file with 21 additions and 24 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -29,7 +29,7 @@ use datafusion_common::Result;
use datafusion_execution::memory_pool::proxy::VecAllocExt;
use datafusion_expr::EmitTo;
use half::f16;
use hashbrown::raw::RawTable;
use hashbrown::hash_table::HashTable;
use std::mem::size_of;
use std::sync::Arc;

Expand Down Expand Up @@ -86,7 +86,7 @@ pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
///
/// We don't store the hashes, because hashing fixed-width primitives
/// is fast enough that caching them would not improve performance
map: RawTable<usize>,
map: HashTable<usize>,
/// The group index of the null value if any
null_group: Option<usize>,
/// The values for each group index
Expand All @@ -100,7 +100,7 @@ impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> {
assert!(PrimitiveArray::<T>::is_compatible(&data_type));
Self {
data_type,
map: RawTable::with_capacity(128),
map: HashTable::with_capacity(128),
values: Vec::with_capacity(128),
null_group: None,
random_state: Default::default(),
Expand All @@ -126,22 +126,19 @@ where
Some(key) => {
let state = &self.random_state;
let hash = key.hash(state);
let insert = self.map.find_or_find_insert_slot(
let insert = self.map.entry(
hash,
|g| unsafe { self.values.get_unchecked(*g).is_eq(key) },
|g| unsafe { self.values.get_unchecked(*g).hash(state) },
);

// SAFETY: No mutation occurred since find_or_find_insert_slot
unsafe {
match insert {
Ok(v) => *v.as_ref(),
Err(slot) => {
let g = self.values.len();
self.map.insert_in_slot(hash, slot, g);
self.values.push(key);
g
}
match insert {
hashbrown::hash_table::Entry::Occupied(o) => *o.get(),
hashbrown::hash_table::Entry::Vacant(v) => {
let g = self.values.len();
v.insert(g);
self.values.push(key);
g
}
}
}
Expand Down Expand Up @@ -183,18 +180,18 @@ where
build_primitive(std::mem::take(&mut self.values), self.null_group.take())
}
EmitTo::First(n) => {
// SAFETY: self.map outlives iterator and is not modified concurrently
unsafe {
for bucket in self.map.iter() {
// Decrement group index by n
match bucket.as_ref().checked_sub(n) {
// Group index was >= n, shift value down
Some(sub) => *bucket.as_mut() = sub,
// Group index was < n, so remove from table
None => self.map.erase(bucket),
self.map.retain(|group_idx| {
// Decrement group index by n
match group_idx.checked_sub(n) {
// Group index was >= n, shift value down
Some(sub) => {
*group_idx = sub;
true
}
// Group index was < n, so remove from table
None => false,
}
}
});
let null_group = match &mut self.null_group {
Some(v) if *v >= n => {
*v -= n;
Expand Down

0 comments on commit a80a55a

Please sign in to comment.