diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs index 8081dab36ab5..b1db21fa90a1 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs @@ -44,7 +44,7 @@ use datafusion::arrow::ipc::reader::FileReader; use datafusion::arrow::ipc::writer::FileWriter; use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::{DataFusionError, Result}; -use datafusion::physical_plan::hash_join::create_hashes; +use datafusion::physical_plan::hash_utils::create_hashes; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::Partitioning::RoundRobinBatch; use datafusion::physical_plan::{ diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index 00ca1539d714..1a174bb11d10 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -18,18 +18,15 @@ //! Defines the join plan for executing partitions in parallel and then joining the results //! into a set of partitions. -use ahash::CallHasher; use ahash::RandomState; use arrow::{ array::{ - ArrayData, ArrayRef, BooleanArray, Date32Array, Date64Array, Float32Array, - Float64Array, LargeStringArray, PrimitiveArray, TimestampMicrosecondArray, - TimestampMillisecondArray, TimestampNanosecondArray, UInt32BufferBuilder, - UInt32Builder, UInt64BufferBuilder, UInt64Builder, + ArrayData, ArrayRef, BooleanArray, LargeStringArray, PrimitiveArray, + UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder, UInt64Builder, }, compute, - datatypes::{TimeUnit, UInt32Type, UInt64Type}, + datatypes::{UInt32Type, UInt64Type}, }; use smallvec::{smallvec, SmallVec}; use std::{any::Any, usize}; @@ -53,6 +50,7 @@ use arrow::array::{ }; use super::expressions::Column; +use super::hash_utils::create_hashes; use super::{ coalesce_partitions::CoalescePartitionsExec, hash_utils::{build_join_schema, check_join_is_valid, JoinOn}, @@ -790,13 +788,6 @@ impl BuildHasher for IdHashBuilder { } } -// Combines two hashes into one hash -#[inline] -fn combine_hashes(l: u64, r: u64) -> u64 { - let hash = (17 * 37u64).wrapping_add(l); - hash.wrapping_mul(37).wrapping_add(r) -} - macro_rules! equal_rows_elem { ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident) => {{ let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); @@ -848,338 +839,6 @@ fn equal_rows( err.unwrap_or(Ok(res)) } -macro_rules! hash_array { - ($array_type:ident, $column: ident, $ty: ident, $hashes: ident, $random_state: ident, $multi_col: ident) => { - let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); - if array.null_count() == 0 { - if $multi_col { - for (i, hash) in $hashes.iter_mut().enumerate() { - *hash = combine_hashes( - $ty::get_hash(&array.value(i), $random_state), - *hash, - ); - } - } else { - for (i, hash) in $hashes.iter_mut().enumerate() { - *hash = $ty::get_hash(&array.value(i), $random_state); - } - } - } else { - if $multi_col { - for (i, hash) in $hashes.iter_mut().enumerate() { - if !array.is_null(i) { - *hash = combine_hashes( - $ty::get_hash(&array.value(i), $random_state), - *hash, - ); - } - } - } else { - for (i, hash) in $hashes.iter_mut().enumerate() { - if !array.is_null(i) { - *hash = $ty::get_hash(&array.value(i), $random_state); - } - } - } - } - }; -} - -macro_rules! 
hash_array_primitive {
-    ($array_type:ident, $column: ident, $ty: ident, $hashes: ident, $random_state: ident, $multi_col: ident) => {
-        let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
-        let values = array.values();
-
-        if array.null_count() == 0 {
-            if $multi_col {
-                for (hash, value) in $hashes.iter_mut().zip(values.iter()) {
-                    *hash = combine_hashes($ty::get_hash(value, $random_state), *hash);
-                }
-            } else {
-                for (hash, value) in $hashes.iter_mut().zip(values.iter()) {
-                    *hash = $ty::get_hash(value, $random_state)
-                }
-            }
-        } else {
-            if $multi_col {
-                for (i, (hash, value)) in
-                    $hashes.iter_mut().zip(values.iter()).enumerate()
-                {
-                    if !array.is_null(i) {
-                        *hash =
-                            combine_hashes($ty::get_hash(value, $random_state), *hash);
-                    }
-                }
-            } else {
-                for (i, (hash, value)) in
-                    $hashes.iter_mut().zip(values.iter()).enumerate()
-                {
-                    if !array.is_null(i) {
-                        *hash = $ty::get_hash(value, $random_state);
-                    }
-                }
-            }
-        }
-    };
-}
-
-macro_rules! hash_array_float {
-    ($array_type:ident, $column: ident, $ty: ident, $hashes: ident, $random_state: ident, $multi_col: ident) => {
-        let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
-        let values = array.values();
-
-        if array.null_count() == 0 {
-            if $multi_col {
-                for (hash, value) in $hashes.iter_mut().zip(values.iter()) {
-                    *hash = combine_hashes(
-                        $ty::get_hash(
-                            &$ty::from_le_bytes(value.to_le_bytes()),
-                            $random_state,
-                        ),
-                        *hash,
-                    );
-                }
-            } else {
-                for (hash, value) in $hashes.iter_mut().zip(values.iter()) {
-                    *hash = $ty::get_hash(
-                        &$ty::from_le_bytes(value.to_le_bytes()),
-                        $random_state,
-                    )
-                }
-            }
-        } else {
-            if $multi_col {
-                for (i, (hash, value)) in
-                    $hashes.iter_mut().zip(values.iter()).enumerate()
-                {
-                    if !array.is_null(i) {
-                        *hash = combine_hashes(
-                            $ty::get_hash(
-                                &$ty::from_le_bytes(value.to_le_bytes()),
-                                $random_state,
-                            ),
-                            *hash,
-                        );
-                    }
-                }
-            } else {
-                for (i, (hash, value)) in
-                    $hashes.iter_mut().zip(values.iter()).enumerate()
-                {
-                    if !array.is_null(i) {
-                        *hash = $ty::get_hash(
-                            &$ty::from_le_bytes(value.to_le_bytes()),
-                            $random_state,
-                        );
-                    }
-                }
-            }
-        }
-    };
-}
-
-/// Creates hash values for every element in the row based on the values in the columns
-pub fn create_hashes<'a>(
-    arrays: &[ArrayRef],
-    random_state: &RandomState,
-    hashes_buffer: &'a mut Vec<u64>,
-) -> Result<&'a mut Vec<u64>> {
-    // combine hashes with `combine_hashes` if we have more than 1 column
-    let multi_col = arrays.len() > 1;
-
-    for col in arrays {
-        match col.data_type() {
-            DataType::UInt8 => {
-                hash_array_primitive!(
-                    UInt8Array,
-                    col,
-                    u8,
-                    hashes_buffer,
-                    random_state,
-                    multi_col
-                );
-            }
-            DataType::UInt16 => {
-                hash_array_primitive!(
-                    UInt16Array,
-                    col,
-                    u16,
-                    hashes_buffer,
-                    random_state,
-                    multi_col
-                );
-            }
-            DataType::UInt32 => {
-                hash_array_primitive!(
-                    UInt32Array,
-                    col,
-                    u32,
-                    hashes_buffer,
-                    random_state,
-                    multi_col
-                );
-            }
-            DataType::UInt64 => {
-                hash_array_primitive!(
-                    UInt64Array,
-                    col,
-                    u64,
-                    hashes_buffer,
-                    random_state,
-                    multi_col
-                );
-            }
-            DataType::Int8 => {
-                hash_array_primitive!(
-                    Int8Array,
-                    col,
-                    i8,
-                    hashes_buffer,
-                    random_state,
-                    multi_col
-                );
-            }
-            DataType::Int16 => {
-                hash_array_primitive!(
-                    Int16Array,
-                    col,
-                    i16,
-                    hashes_buffer,
-                    random_state,
-                    multi_col
-                );
-            }
-            DataType::Int32 => {
-                hash_array_primitive!(
-                    Int32Array,
-                    col,
-                    i32,
-                    hashes_buffer,
-                    random_state,
-                    multi_col
-                );
-            }
-            DataType::Int64 => {
-                hash_array_primitive!(
-                    Int64Array,
-                    col,
-                    i64,
-                    hashes_buffer,
-                    random_state,
-                    multi_col
-                );
-            
} - DataType::Float32 => { - hash_array_float!( - Float32Array, - col, - u32, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Float64 => { - hash_array_float!( - Float64Array, - col, - u64, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Timestamp(TimeUnit::Millisecond, None) => { - hash_array_primitive!( - TimestampMillisecondArray, - col, - i64, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Timestamp(TimeUnit::Microsecond, None) => { - hash_array_primitive!( - TimestampMicrosecondArray, - col, - i64, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Timestamp(TimeUnit::Nanosecond, None) => { - hash_array_primitive!( - TimestampNanosecondArray, - col, - i64, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Date32 => { - hash_array_primitive!( - Date32Array, - col, - i32, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Date64 => { - hash_array_primitive!( - Date64Array, - col, - i64, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Boolean => { - hash_array!( - BooleanArray, - col, - u8, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Utf8 => { - hash_array!( - StringArray, - col, - str, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::LargeUtf8 => { - hash_array!( - LargeStringArray, - col, - str, - hashes_buffer, - random_state, - multi_col - ); - } - _ => { - // This is internal because we should have caught this before. - return Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - )); - } - } - } - Ok(hashes_buffer) -} - // Produces a batch for left-side rows that have/have not been matched during the whole join fn produce_from_matched( visited_left_side: &[bool], @@ -2115,22 +1774,6 @@ mod tests { Ok(()) } - #[test] - fn create_hashes_for_float_arrays() -> Result<()> { - let f32_arr = Arc::new(Float32Array::from(vec![0.12, 0.5, 1f32, 444.7])); - let f64_arr = Arc::new(Float64Array::from(vec![0.12, 0.5, 1f64, 444.7])); - - let random_state = RandomState::with_seeds(0, 0, 0, 0); - let hashes_buff = &mut vec![0; f32_arr.len()]; - let hashes = create_hashes(&[f32_arr], &random_state, hashes_buff)?; - assert_eq!(hashes.len(), 4,); - - let hashes = create_hashes(&[f64_arr], &random_state, hashes_buff)?; - assert_eq!(hashes.len(), 4,); - - Ok(()) - } - #[test] fn join_with_hash_collision() -> Result<()> { let mut hashmap_left = HashMap::with_capacity_and_hasher(2, IdHashBuilder {}); diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index 9243affe9cfc..e937b4ea549c 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ -18,7 +18,14 @@ //! 
Functionality used both on logical and physical plans use crate::error::{DataFusionError, Result}; -use arrow::datatypes::{Field, Schema}; +use ahash::{CallHasher, RandomState}; +use arrow::array::{ + Array, ArrayRef, BooleanArray, Date32Array, Date64Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, StringArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, +}; +use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use std::collections::HashSet; use crate::logical_plan::JoinType; @@ -101,8 +108,351 @@ pub fn build_join_schema(left: &Schema, right: &Schema, join_type: &JoinType) -> Schema::new(fields) } +// Combines two hashes into one hash +#[inline] +fn combine_hashes(l: u64, r: u64) -> u64 { + let hash = (17 * 37u64).wrapping_add(l); + hash.wrapping_mul(37).wrapping_add(r) +} + +macro_rules! hash_array { + ($array_type:ident, $column: ident, $ty: ident, $hashes: ident, $random_state: ident, $multi_col: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + if array.null_count() == 0 { + if $multi_col { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = combine_hashes( + $ty::get_hash(&array.value(i), $random_state), + *hash, + ); + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = $ty::get_hash(&array.value(i), $random_state); + } + } + } else { + if $multi_col { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = combine_hashes( + $ty::get_hash(&array.value(i), $random_state), + *hash, + ); + } + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = $ty::get_hash(&array.value(i), $random_state); + } + } + } + } + }; +} + +macro_rules! hash_array_primitive { + ($array_type:ident, $column: ident, $ty: ident, $hashes: ident, $random_state: ident, $multi_col: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + let values = array.values(); + + if array.null_count() == 0 { + if $multi_col { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { + *hash = combine_hashes($ty::get_hash(value, $random_state), *hash); + } + } else { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { + *hash = $ty::get_hash(value, $random_state) + } + } + } else { + if $multi_col { + for (i, (hash, value)) in + $hashes.iter_mut().zip(values.iter()).enumerate() + { + if !array.is_null(i) { + *hash = + combine_hashes($ty::get_hash(value, $random_state), *hash); + } + } + } else { + for (i, (hash, value)) in + $hashes.iter_mut().zip(values.iter()).enumerate() + { + if !array.is_null(i) { + *hash = $ty::get_hash(value, $random_state); + } + } + } + } + }; +} + +macro_rules! 
hash_array_float {
+    ($array_type:ident, $column: ident, $ty: ident, $hashes: ident, $random_state: ident, $multi_col: ident) => {
+        let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
+        let values = array.values();
+
+        if array.null_count() == 0 {
+            if $multi_col {
+                for (hash, value) in $hashes.iter_mut().zip(values.iter()) {
+                    *hash = combine_hashes(
+                        $ty::get_hash(
+                            &$ty::from_le_bytes(value.to_le_bytes()),
+                            $random_state,
+                        ),
+                        *hash,
+                    );
+                }
+            } else {
+                for (hash, value) in $hashes.iter_mut().zip(values.iter()) {
+                    *hash = $ty::get_hash(
+                        &$ty::from_le_bytes(value.to_le_bytes()),
+                        $random_state,
+                    )
+                }
+            }
+        } else {
+            if $multi_col {
+                for (i, (hash, value)) in
+                    $hashes.iter_mut().zip(values.iter()).enumerate()
+                {
+                    if !array.is_null(i) {
+                        *hash = combine_hashes(
+                            $ty::get_hash(
+                                &$ty::from_le_bytes(value.to_le_bytes()),
+                                $random_state,
+                            ),
+                            *hash,
+                        );
+                    }
+                }
+            } else {
+                for (i, (hash, value)) in
+                    $hashes.iter_mut().zip(values.iter()).enumerate()
+                {
+                    if !array.is_null(i) {
+                        *hash = $ty::get_hash(
+                            &$ty::from_le_bytes(value.to_le_bytes()),
+                            $random_state,
+                        );
+                    }
+                }
+            }
+        }
+    };
+}
+
+/// Creates hash values for every row, based on the values in the columns
+///
+/// This implements so-called "vectorized hashing"
+pub fn create_hashes<'a>(
+    arrays: &[ArrayRef],
+    random_state: &RandomState,
+    hashes_buffer: &'a mut Vec<u64>,
+) -> Result<&'a mut Vec<u64>> {
+    // combine hashes with `combine_hashes` if we have more than 1 column
+    let multi_col = arrays.len() > 1;
+
+    for col in arrays {
+        match col.data_type() {
+            DataType::UInt8 => {
+                hash_array_primitive!(
+                    UInt8Array,
+                    col,
+                    u8,
+                    hashes_buffer,
+                    random_state,
+                    multi_col
+                );
+            }
+            DataType::UInt16 => {
+                hash_array_primitive!(
+                    UInt16Array,
+                    col,
+                    u16,
+                    hashes_buffer,
+                    random_state,
+                    multi_col
+                );
+            }
+            DataType::UInt32 => {
+                hash_array_primitive!(
+                    UInt32Array,
+                    col,
+                    u32,
+                    hashes_buffer,
+                    random_state,
+                    multi_col
+                );
+            }
+            DataType::UInt64 => {
+                hash_array_primitive!(
+                    UInt64Array,
+                    col,
+                    u64,
+                    hashes_buffer,
+                    random_state,
+                    multi_col
+                );
+            }
+            DataType::Int8 => {
+                hash_array_primitive!(
+                    Int8Array,
+                    col,
+                    i8,
+                    hashes_buffer,
+                    random_state,
+                    multi_col
+                );
+            }
+            DataType::Int16 => {
+                hash_array_primitive!(
+                    Int16Array,
+                    col,
+                    i16,
+                    hashes_buffer,
+                    random_state,
+                    multi_col
+                );
+            }
+            DataType::Int32 => {
+                hash_array_primitive!(
+                    Int32Array,
+                    col,
+                    i32,
+                    hashes_buffer,
+                    random_state,
+                    multi_col
+                );
+            }
+            DataType::Int64 => {
+                hash_array_primitive!(
+                    Int64Array,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    random_state,
+                    multi_col
+                );
+            }
+            DataType::Float32 => {
+                hash_array_float!(
+                    Float32Array,
+                    col,
+                    u32,
+                    hashes_buffer,
+                    random_state,
+                    multi_col
+                );
+            }
+            DataType::Float64 => {
+                hash_array_float!(
+                    Float64Array,
+                    col,
+                    u64,
+                    hashes_buffer,
+                    random_state,
+                    multi_col
+                );
+            }
+            DataType::Timestamp(TimeUnit::Millisecond, None) => {
+                hash_array_primitive!(
+                    TimestampMillisecondArray,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    random_state,
+                    multi_col
+                );
+            }
+            DataType::Timestamp(TimeUnit::Microsecond, None) => {
+                hash_array_primitive!(
+                    TimestampMicrosecondArray,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    random_state,
+                    multi_col
+                );
+            }
+            DataType::Timestamp(TimeUnit::Nanosecond, None) => {
+                hash_array_primitive!(
+                    TimestampNanosecondArray,
+                    col,
+                    i64,
+                    hashes_buffer,
+                    random_state,
+                    multi_col
+                );
+            }
+            DataType::Date32 => {
+                hash_array_primitive!(
+                    Date32Array,
+                    col,
+                    i32,
+                    hashes_buffer,
+                    random_state,
+                    multi_col
+                );
+            }
+            
DataType::Date64 => { + hash_array_primitive!( + Date64Array, + col, + i64, + hashes_buffer, + random_state, + multi_col + ); + } + DataType::Boolean => { + hash_array!( + BooleanArray, + col, + u8, + hashes_buffer, + random_state, + multi_col + ); + } + DataType::Utf8 => { + hash_array!( + StringArray, + col, + str, + hashes_buffer, + random_state, + multi_col + ); + } + DataType::LargeUtf8 => { + hash_array!( + LargeStringArray, + col, + str, + hashes_buffer, + random_state, + multi_col + ); + } + _ => { + // This is internal because we should have caught this before. + return Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + )); + } + } + } + Ok(hashes_buffer) +} + #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { @@ -163,4 +513,20 @@ mod tests { assert!(check(&left, &right, on).is_ok()); } + + #[test] + fn create_hashes_for_float_arrays() -> Result<()> { + let f32_arr = Arc::new(Float32Array::from(vec![0.12, 0.5, 1f32, 444.7])); + let f64_arr = Arc::new(Float64Array::from(vec![0.12, 0.5, 1f64, 444.7])); + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let hashes_buff = &mut vec![0; f32_arr.len()]; + let hashes = create_hashes(&[f32_arr], &random_state, hashes_buff)?; + assert_eq!(hashes.len(), 4,); + + let hashes = create_hashes(&[f64_arr], &random_state, hashes_buff)?; + assert_eq!(hashes.len(), 4,); + + Ok(()) + } } diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs index e67e4c2d4477..b59071adb3a1 100644 --- a/datafusion/src/physical_plan/repartition.rs +++ b/datafusion/src/physical_plan/repartition.rs @@ -25,13 +25,14 @@ use std::time::Instant; use std::{any::Any, vec}; use crate::error::{DataFusionError, Result}; +use crate::physical_plan::hash_utils::create_hashes; use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning, SQLMetric}; use arrow::record_batch::RecordBatch; use arrow::{array::Array, error::Result as ArrowResult}; use arrow::{compute::take, datatypes::SchemaRef}; use tokio_stream::wrappers::UnboundedReceiverStream; -use super::{hash_join::create_hashes, RecordBatchStream, SendableRecordBatchStream}; +use super::{RecordBatchStream, SendableRecordBatchStream}; use async_trait::async_trait; use futures::stream::Stream;
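
Note: with this change, `create_hashes` is exposed from `datafusion::physical_plan::hash_utils` and shared by the hash join, `RepartitionExec`, and Ballista's `ShuffleWriterExec`. The sketch below is a minimal, hypothetical illustration (not part of the diff) of how a caller might use the relocated function to bucket rows by key hash; the `partition_indices` helper and the modulo bucketing are illustrative only, and it assumes `ahash` and `arrow` are available as direct dependencies alongside `datafusion`.

use std::sync::Arc;

use ahash::RandomState;
use arrow::array::{Array, ArrayRef, Int32Array, StringArray};
use datafusion::error::Result;
use datafusion::physical_plan::hash_utils::create_hashes;

/// Hypothetical helper: map each row of the key columns to one of
/// `num_partitions` buckets using the vectorized hashes from `create_hashes`.
fn partition_indices(
    columns: &[ArrayRef],
    num_partitions: u64,
    random_state: &RandomState,
) -> Result<Vec<u64>> {
    let num_rows = columns[0].len();
    // one hash slot per row; hashes are combined across all key columns
    let mut hashes_buffer = vec![0u64; num_rows];
    let hashes = create_hashes(columns, random_state, &mut hashes_buffer)?;
    // simple modulo bucketing over the per-row hashes
    Ok(hashes.iter().map(|hash| hash % num_partitions).collect())
}

fn main() -> Result<()> {
    let ids: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
    let names: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d"]));

    // fixed seeds so the same rows hash to the same buckets everywhere
    let random_state = RandomState::with_seeds(0, 0, 0, 0);
    let buckets = partition_indices(&[ids, names], 4, &random_state)?;
    assert_eq!(buckets.len(), 4);
    println!("bucket per row: {:?}", buckets);
    Ok(())
}

Deterministic seeding of the `RandomState` (as in the `create_hashes_for_float_arrays` test above) matters whenever the same rows must hash identically in separate places, for example on both ends of a shuffle or on both sides of a hash join.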