refactor: change some hashbrown RawTable uses to HashTable (round 2) #13524

Merged · 3 commits · Dec 4, 2024
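This round converts several group-value maps from `hashbrown::raw::RawTable` to the safe `hashbrown::hash_table::HashTable` API. As a rough orientation, here is a minimal sketch of the method mapping (not code from this PR; the `(hash, group_index)` payload and the literal hash are placeholders): `get`/`get_mut` become `find`/`find_mut`, `find_or_find_insert_slot` plus `insert_in_slot` become `entry`, and the unsafe `iter()` + `erase()` loops become `retain` (illustrated separately next to the review discussion below).

```rust
use hashbrown::hash_table::{Entry, HashTable};

fn main() {
    // Values are (hash, group_index) pairs, as in the maps this PR touches;
    // the hash is kept alongside the value because HashTable, like RawTable,
    // does not remember it for us.
    let mut map: HashTable<(u64, usize)> = HashTable::with_capacity(0);
    let hash: u64 = 0xDEAD_BEEF; // placeholder; DataFusion computes these via create_hashes

    // RawTable::find_or_find_insert_slot + insert_in_slot  ->  HashTable::entry
    let entry = map.entry(hash, |(h, _)| *h == hash, |(h, _)| *h);
    let group_idx = match entry {
        Entry::Occupied(o) => o.get().1,
        Entry::Vacant(v) => {
            let g = 0; // index of the new group in the caller's side storage
            v.insert((hash, g));
            g
        }
    };

    // RawTable::get_mut  ->  HashTable::find_mut (and get -> find)
    if let Some((_, idx)) = map.find_mut(hash, |(h, _)| *h == hash) {
        assert_eq!(*idx, group_idx);
    }
}
```

Because `HashTable`, like `RawTable`, never stores or recomputes hashes on its own, the caller keeps supplying the hash and an equality closure on every lookup, which is why these tables carry `(hash, value)` pairs.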
@@ -42,11 +42,11 @@ use arrow_array::{Array, ArrayRef};
use arrow_schema::{DataType, Schema, SchemaRef, TimeUnit};
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::{not_impl_err, DataFusionError, Result};
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt};
use datafusion_expr::EmitTo;
use datafusion_physical_expr::binary_map::OutputType;

use hashbrown::raw::RawTable;
use hashbrown::hash_table::HashTable;

const NON_INLINED_FLAG: u64 = 0x8000000000000000;
const VALUE_MASK: u64 = 0x7FFFFFFFFFFFFFFF;
@@ -180,7 +180,7 @@ pub struct GroupValuesColumn<const STREAMING: bool> {
/// And we use [`GroupIndexView`] to represent such `group indices` in table.
///
///
map: RawTable<(u64, GroupIndexView)>,
map: HashTable<(u64, GroupIndexView)>,

/// The size of `map` in bytes
map_size: usize,
@@ -261,7 +261,7 @@ impl<const STREAMING: bool> GroupValuesColumn<STREAMING> {

/// Create a new instance of GroupValuesColumn if supported for the specified schema
pub fn try_new(schema: SchemaRef) -> Result<Self> {
let map = RawTable::with_capacity(0);
let map = HashTable::with_capacity(0);
Ok(Self {
schema,
map,
@@ -338,7 +338,7 @@ impl<const STREAMING: bool> GroupValuesColumn<STREAMING> {
for (row, &target_hash) in batch_hashes.iter().enumerate() {
let entry = self
.map
.get_mut(target_hash, |(exist_hash, group_idx_view)| {
.find_mut(target_hash, |(exist_hash, group_idx_view)| {
// It is ensured to be inlined in `scalarized_intern`
debug_assert!(!group_idx_view.is_non_inlined());

@@ -506,7 +506,7 @@ impl<const STREAMING: bool> GroupValuesColumn<STREAMING> {
for (row, &target_hash) in batch_hashes.iter().enumerate() {
let entry = self
.map
.get(target_hash, |(exist_hash, _)| target_hash == *exist_hash);
.find(target_hash, |(exist_hash, _)| target_hash == *exist_hash);

let Some((_, group_index_view)) = entry else {
// 1. Bucket not found case
@@ -733,7 +733,7 @@ impl<const STREAMING: bool> GroupValuesColumn<STREAMING> {

for &row in &self.vectorized_operation_buffers.remaining_row_indices {
let target_hash = batch_hashes[row];
let entry = map.get_mut(target_hash, |(exist_hash, _)| {
let entry = map.find_mut(target_hash, |(exist_hash, _)| {
// Somewhat surprisingly, this closure can be called even if the
// hash doesn't match, so check the hash first with an integer
// comparison to avoid the more expensive comparison with
@@ -852,7 +852,7 @@ impl<const STREAMING: bool> GroupValuesColumn<STREAMING> {
/// Return group indices of the hash, also if its `group_index_view` is non-inlined
#[cfg(test)]
fn get_indices_by_hash(&self, hash: u64) -> Option<(Vec<usize>, GroupIndexView)> {
let entry = self.map.get(hash, |(exist_hash, _)| hash == *exist_hash);
let entry = self.map.find(hash, |(exist_hash, _)| hash == *exist_hash);

match entry {
Some((_, group_index_view)) => {
@@ -1083,67 +1083,63 @@ impl<const STREAMING: bool> GroupValues for GroupValuesColumn<STREAMING> {
.collect::<Vec<_>>();
let mut next_new_list_offset = 0;

// SAFETY: self.map outlives iterator and is not modified concurrently
unsafe {
for bucket in self.map.iter() {
// In non-streaming case, we need to check if the `group index view`
// is `inlined` or `non-inlined`
if !STREAMING && bucket.as_ref().1.is_non_inlined() {
// Non-inlined case
// We take `group_index_list` from `old_group_index_lists`

// list_offset increases incrementally
self.emit_group_index_list_buffer.clear();
let list_offset = bucket.as_ref().1.value() as usize;
for group_index in self.group_index_lists[list_offset].iter()
{
if let Some(remaining) = group_index.checked_sub(n) {
self.emit_group_index_list_buffer.push(remaining);
}
self.map.retain(|(_exist_hash, group_idx_view)| {
Contributor: 💯 for removing the unsafe

// In non-streaming case, we need to check if the `group index view`
// is `inlined` or `non-inlined`
if !STREAMING && group_idx_view.is_non_inlined() {
// Non-inlined case
// We take `group_index_list` from `old_group_index_lists`

// list_offset increases incrementally
self.emit_group_index_list_buffer.clear();
let list_offset = group_idx_view.value() as usize;
for group_index in self.group_index_lists[list_offset].iter() {
if let Some(remaining) = group_index.checked_sub(n) {
self.emit_group_index_list_buffer.push(remaining);
}

// The possible results:
// - `new_group_index_list` is empty, we should erase this bucket
// - only one value in `new_group_index_list`, switch the `view` to `inlined`
// - still multiple values in `new_group_index_list`, build and set the new `non-inlined` view
if self.emit_group_index_list_buffer.is_empty() {
self.map.erase(bucket);
} else if self.emit_group_index_list_buffer.len() == 1 {
let group_index =
self.emit_group_index_list_buffer.first().unwrap();
bucket.as_mut().1 =
GroupIndexView::new_inlined(*group_index as u64);
} else {
let group_index_list =
&mut self.group_index_lists[next_new_list_offset];
group_index_list.clear();
group_index_list
.extend(self.emit_group_index_list_buffer.iter());
bucket.as_mut().1 = GroupIndexView::new_non_inlined(
next_new_list_offset as u64,
);
next_new_list_offset += 1;
}

continue;
}

// The possible results:
// - `new_group_index_list` is empty, we should erase this bucket
// - only one value in `new_group_index_list`, switch the `view` to `inlined`
// - still multiple values in `new_group_index_list`, build and set the new `non-inlined` view
if self.emit_group_index_list_buffer.is_empty() {
false
} else if self.emit_group_index_list_buffer.len() == 1 {
let group_index =
self.emit_group_index_list_buffer.first().unwrap();
*group_idx_view =
GroupIndexView::new_inlined(*group_index as u64);
true
} else {
let group_index_list =
&mut self.group_index_lists[next_new_list_offset];
group_index_list.clear();
group_index_list
.extend(self.emit_group_index_list_buffer.iter());
*group_idx_view = GroupIndexView::new_non_inlined(
next_new_list_offset as u64,
);
next_new_list_offset += 1;
true
}
} else {
// In `streaming case`, the `group index view` is ensured to be `inlined`
debug_assert!(!bucket.as_ref().1.is_non_inlined());
debug_assert!(!group_idx_view.is_non_inlined());

// Inlined case, we just decrement group index by n
let group_index = bucket.as_ref().1.value() as usize;
let group_index = group_idx_view.value() as usize;
match group_index.checked_sub(n) {
// Group index was >= n, shift value down
Some(sub) => {
bucket.as_mut().1 =
GroupIndexView::new_inlined(sub as u64)
*group_idx_view = GroupIndexView::new_inlined(sub as u64);
true
}
// Group index was < n, so remove from table
None => self.map.erase(bucket),
None => false,
}
}
}
});

if !STREAMING {
self.group_index_lists.truncate(next_new_list_offset);
@@ -1234,7 +1230,7 @@ mod tests {
use arrow::{compute::concat_batches, util::pretty::pretty_format_batches};
use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StringViewArray};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion_common::utils::proxy::RawTableAllocExt;
use datafusion_common::utils::proxy::HashTableAllocExt;
use datafusion_expr::EmitTo;

use crate::aggregates::group_values::{
30 changes: 15 additions & 15 deletions datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -24,9 +24,9 @@ use arrow_array::{Array, ArrayRef, ListArray, StructArray};
use arrow_schema::{DataType, SchemaRef};
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::Result;
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt};
use datafusion_expr::EmitTo;
use hashbrown::raw::RawTable;
use hashbrown::hash_table::HashTable;
use log::debug;
use std::mem::size_of;
use std::sync::Arc;
@@ -54,7 +54,7 @@ pub struct GroupValuesRows {
///
/// keys: u64 hashes of the GroupValue
/// values: (hash, group_index)
map: RawTable<(u64, usize)>,
map: HashTable<(u64, usize)>,

/// The size of `map` in bytes
map_size: usize,
@@ -92,7 +92,7 @@ impl GroupValuesRows {
.collect(),
)?;

let map = RawTable::with_capacity(0);
let map = HashTable::with_capacity(0);

let starting_rows_capacity = 1000;

@@ -135,7 +135,7 @@ impl GroupValues for GroupValuesRows {
create_hashes(cols, &self.random_state, batch_hashes)?;

for (row, &target_hash) in batch_hashes.iter().enumerate() {
let entry = self.map.get_mut(target_hash, |(exist_hash, group_idx)| {
let entry = self.map.find_mut(target_hash, |(exist_hash, group_idx)| {
// Somewhat surprisingly, this closure can be called even if the
// hash doesn't match, so check the hash first with an integer
// comparison to avoid the more expensive comparison with
Expand Down Expand Up @@ -216,18 +216,18 @@ impl GroupValues for GroupValuesRows {
}
std::mem::swap(&mut new_group_values, &mut group_values);

// SAFETY: self.map outlives iterator and is not modified concurrently
Contributor (Author):
Picking this one, but this pattern occurs multiple times:

That SAFETY statement is just plain wrong: while iterating over the map, we erase elements from it. That is THE prime example for "concurrent modification of containers" and why Rust's lifetimes/reference system prevents that. I'm kinda surprised that this hasn't exploded yet.

unsafe {
for bucket in self.map.iter() {
// Decrement group index by n
match bucket.as_ref().1.checked_sub(n) {
// Group index was >= n, shift value down
Some(sub) => bucket.as_mut().1 = sub,
// Group index was < n, so remove from table
None => self.map.erase(bucket),
self.map.retain(|(_exists_hash, group_idx)| {
// Decrement group index by n
match group_idx.checked_sub(n) {
// Group index was >= n, shift value down
Some(sub) => {
*group_idx = sub;
true
}
// Group index was < n, so remove from table
None => false,
}
}
});
output
}
};
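To show the reviewer's point in isolation: `HashTable::retain` gives the closure `&mut` access to each entry and drops the entry when the closure returns `false`, so the `EmitTo::First(n)` shift-or-drop logic needs no raw buckets and no `unsafe`. A self-contained sketch (a hypothetical `emit_first` helper over placeholder `(hash, group_index)` pairs, not the PR's code):

```rust
use hashbrown::hash_table::HashTable;

/// Shift every stored group index down by `n`; entries whose index was
/// below `n` (already emitted) are dropped from the table.
fn emit_first(map: &mut HashTable<(u64, usize)>, n: usize) {
    map.retain(|(_hash, group_idx)| match group_idx.checked_sub(n) {
        Some(sub) => {
            *group_idx = sub; // index survives, shifted down by n
            true
        }
        None => false, // index was < n, so remove the entry
    });
}

fn main() {
    let mut map: HashTable<(u64, usize)> = HashTable::with_capacity(4);
    for (hash, idx) in [(1u64, 0usize), (2, 1), (3, 2)] {
        map.insert_unique(hash, (hash, idx), |(h, _)| *h);
    }
    emit_first(&mut map, 2);
    // Only the entry with group index 2 survives, renumbered to 0.
    assert_eq!(map.len(), 1);
    assert_eq!(map.find(3, |(h, _)| *h == 3).unwrap().1, 0);
}
```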
@@ -29,7 +29,7 @@ use datafusion_common::Result;
use datafusion_execution::memory_pool::proxy::VecAllocExt;
use datafusion_expr::EmitTo;
use half::f16;
use hashbrown::raw::RawTable;
use hashbrown::hash_table::HashTable;
use std::mem::size_of;
use std::sync::Arc;

@@ -86,7 +86,7 @@ pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
///
/// We don't store the hashes as hashing fixed width primitives
/// is fast enough for this not to benefit performance
map: RawTable<usize>,
map: HashTable<usize>,
/// The group index of the null value if any
null_group: Option<usize>,
/// The values for each group index
@@ -100,7 +100,7 @@ impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> {
assert!(PrimitiveArray::<T>::is_compatible(&data_type));
Self {
data_type,
map: RawTable::with_capacity(128),
map: HashTable::with_capacity(128),
values: Vec::with_capacity(128),
null_group: None,
random_state: Default::default(),
Expand All @@ -126,22 +126,19 @@ where
Some(key) => {
let state = &self.random_state;
let hash = key.hash(state);
let insert = self.map.find_or_find_insert_slot(
let insert = self.map.entry(
hash,
|g| unsafe { self.values.get_unchecked(*g).is_eq(key) },
|g| unsafe { self.values.get_unchecked(*g).hash(state) },
);

// SAFETY: No mutation occurred since find_or_find_insert_slot
unsafe {
match insert {
Ok(v) => *v.as_ref(),
Err(slot) => {
let g = self.values.len();
self.map.insert_in_slot(hash, slot, g);
self.values.push(key);
g
}
match insert {
hashbrown::hash_table::Entry::Occupied(o) => *o.get(),
hashbrown::hash_table::Entry::Vacant(v) => {
let g = self.values.len();
v.insert(g);
self.values.push(key);
g
}
}
}
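The hunk above is the primitive-key variant, where the table stores only `usize` group indices and the key values live in a separate `Vec`, so both closures passed to `entry` reach into that `Vec` (no hashes are stored, per the comment earlier in this file). A self-contained sketch of that find-or-insert shape, with a hypothetical `intern` helper, `i64` keys, and `std`'s `RandomState` standing in for the real types:

```rust
use hashbrown::hash_table::{Entry, HashTable};
use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;

/// Find or create the group index for `key`. The table holds only group
/// indices; the key values themselves live in `values`.
fn intern(
    map: &mut HashTable<usize>,
    values: &mut Vec<i64>,
    state: &RandomState,
    key: i64,
) -> usize {
    let hash = state.hash_one(key);
    let entry = map.entry(
        hash,
        |g| values[*g] == key,          // equality against the stored value
        |g| state.hash_one(values[*g]), // used to re-hash entries when the table grows
    );
    match entry {
        Entry::Occupied(o) => *o.get(),
        Entry::Vacant(v) => {
            let g = values.len();
            v.insert(g);
            values.push(key);
            g
        }
    }
}

fn main() {
    let mut map = HashTable::with_capacity(128);
    let mut values = Vec::with_capacity(128);
    let state = RandomState::new();
    let a = intern(&mut map, &mut values, &state, 7);
    let b = intern(&mut map, &mut values, &state, 7);
    assert_eq!(a, b); // the same key maps to the same group index
}
```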
@@ -183,18 +180,18 @@ where
build_primitive(std::mem::take(&mut self.values), self.null_group.take())
}
EmitTo::First(n) => {
// SAFETY: self.map outlives iterator and is not modified concurrently
unsafe {
for bucket in self.map.iter() {
// Decrement group index by n
match bucket.as_ref().checked_sub(n) {
// Group index was >= n, shift value down
Some(sub) => *bucket.as_mut() = sub,
// Group index was < n, so remove from table
None => self.map.erase(bucket),
self.map.retain(|group_idx| {
// Decrement group index by n
match group_idx.checked_sub(n) {
// Group index was >= n, shift value down
Some(sub) => {
*group_idx = sub;
true
}
// Group index was < n, so remove from table
None => false,
}
}
});
let null_group = match &mut self.null_group {
Some(v) if *v >= n => {
*v -= n;