From a80a55a1c2d408770720c12bc10cb09f93709a0d Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Fri, 22 Nov 2024 12:13:12 +0100 Subject: [PATCH] refactor: migrate `GroupValuesPrimitive` to `HashTable` For #13433. --- .../group_values/single_group_by/primitive.rs | 45 +++++++++---------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index 05214ec10d68b..c36ee9c5599b9 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -29,7 +29,7 @@ use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; use half::f16; -use hashbrown::raw::RawTable; +use hashbrown::hash_table::HashTable; use std::mem::size_of; use std::sync::Arc; @@ -86,7 +86,7 @@ pub struct GroupValuesPrimitive { /// /// We don't store the hashes as hashing fixed width primitives /// is fast enough for this not to benefit performance - map: RawTable, + map: HashTable, /// The group index of the null value if any null_group: Option, /// The values for each group index @@ -100,7 +100,7 @@ impl GroupValuesPrimitive { assert!(PrimitiveArray::::is_compatible(&data_type)); Self { data_type, - map: RawTable::with_capacity(128), + map: HashTable::with_capacity(128), values: Vec::with_capacity(128), null_group: None, random_state: Default::default(), @@ -126,22 +126,19 @@ where Some(key) => { let state = &self.random_state; let hash = key.hash(state); - let insert = self.map.find_or_find_insert_slot( + let insert = self.map.entry( hash, |g| unsafe { self.values.get_unchecked(*g).is_eq(key) }, |g| unsafe { self.values.get_unchecked(*g).hash(state) }, ); - // SAFETY: No mutation occurred since find_or_find_insert_slot - unsafe { - match insert { - Ok(v) => *v.as_ref(), - Err(slot) => { - let g = self.values.len(); - self.map.insert_in_slot(hash, slot, g); - self.values.push(key); - g - } + match insert { + hashbrown::hash_table::Entry::Occupied(o) => *o.get(), + hashbrown::hash_table::Entry::Vacant(v) => { + let g = self.values.len(); + v.insert(g); + self.values.push(key); + g } } } @@ -183,18 +180,18 @@ where build_primitive(std::mem::take(&mut self.values), self.null_group.take()) } EmitTo::First(n) => { - // SAFETY: self.map outlives iterator and is not modified concurrently - unsafe { - for bucket in self.map.iter() { - // Decrement group index by n - match bucket.as_ref().checked_sub(n) { - // Group index was >= n, shift value down - Some(sub) => *bucket.as_mut() = sub, - // Group index was < n, so remove from table - None => self.map.erase(bucket), + self.map.retain(|group_idx| { + // Decrement group index by n + match group_idx.checked_sub(n) { + // Group index was >= n, shift value down + Some(sub) => { + *group_idx = sub; + true } + // Group index was < n, so remove from table + None => false, } - } + }); let null_group = match &mut self.null_group { Some(v) if *v >= n => { *v -= n;