Skip to content

Commit

Permalink
Add String view helper functions (#11517)
Browse files Browse the repository at this point in the history
* add functions

* add tests for hash util
  • Loading branch information
XiangpengHao authored Jul 19, 2024
1 parent 8d8732c commit 8e0ca1a
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 12 deletions.
11 changes: 11 additions & 0 deletions datafusion/common/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use arrow::{
},
datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType},
};
use arrow_array::{BinaryViewArray, StringViewArray};

// Downcast ArrayRef to Date32Array
pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> {
Expand Down Expand Up @@ -87,6 +88,11 @@ pub fn as_string_array(array: &dyn Array) -> Result<&StringArray> {
Ok(downcast_value!(array, StringArray))
}

// Downcast a dynamically typed Arrow array to a `StringViewArray` reference.
//
// Accepts any `&dyn Array` and passes it to the `downcast_value!` macro with
// `StringViewArray` as the target type; the borrowed result is wrapped in `Ok`.
// NOTE(review): `downcast_value!` is defined outside this view; presumably it
// returns early with an error when the concrete type does not match. Confirm.
pub fn as_string_view_array(array: &dyn Array) -> Result<&StringViewArray> {
Ok(downcast_value!(array, StringViewArray))
}

// Downcast ArrayRef to UInt32Array
pub fn as_uint32_array(array: &dyn Array) -> Result<&UInt32Array> {
Ok(downcast_value!(array, UInt32Array))
Expand Down Expand Up @@ -221,6 +227,11 @@ pub fn as_binary_array(array: &dyn Array) -> Result<&BinaryArray> {
Ok(downcast_value!(array, BinaryArray))
}

// Downcast a dynamically typed Arrow array to a `BinaryViewArray` reference.
//
// Accepts any `&dyn Array` and passes it to the `downcast_value!` macro with
// `BinaryViewArray` as the target type; the borrowed result is wrapped in `Ok`.
// NOTE(review): `downcast_value!` is defined outside this view; presumably it
// returns early with an error when the concrete type does not match. Confirm.
pub fn as_binary_view_array(array: &dyn Array) -> Result<&BinaryViewArray> {
Ok(downcast_value!(array, BinaryViewArray))
}

// Downcast ArrayRef to FixedSizeListArray
pub fn as_fixed_size_list_array(array: &dyn Array) -> Result<&FixedSizeListArray> {
Ok(downcast_value!(array, FixedSizeListArray))
Expand Down
121 changes: 109 additions & 12 deletions datafusion/common/src/hash_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,8 @@ pub fn create_hashes<'a>(
random_state: &RandomState,
hashes_buffer: &'a mut Vec<u64>,
) -> Result<&'a mut Vec<u64>> {
use crate::cast::{as_binary_view_array, as_string_view_array};

for (i, col) in arrays.iter().enumerate() {
let array = col.as_ref();
// combine hashes with `combine_hashes` for all columns besides the first
Expand All @@ -370,8 +372,10 @@ pub fn create_hashes<'a>(
DataType::Boolean => hash_array(as_boolean_array(array)?, random_state, hashes_buffer, rehash),
DataType::Utf8 => hash_array(as_string_array(array)?, random_state, hashes_buffer, rehash),
DataType::LargeUtf8 => hash_array(as_largestring_array(array), random_state, hashes_buffer, rehash),
DataType::Utf8View => hash_array(as_string_view_array(array)?, random_state, hashes_buffer, rehash),
DataType::Binary => hash_array(as_generic_binary_array::<i32>(array)?, random_state, hashes_buffer, rehash),
DataType::LargeBinary => hash_array(as_generic_binary_array::<i64>(array)?, random_state, hashes_buffer, rehash),
DataType::BinaryView => hash_array(as_binary_view_array(array)?, random_state, hashes_buffer, rehash),
DataType::FixedSizeBinary(_) => {
let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap();
hash_array(array, random_state, hashes_buffer, rehash)
Expand Down Expand Up @@ -486,22 +490,57 @@ mod tests {
Ok(())
}

// Smoke test: hashing a three-row `BinaryArray` yields one hash per row.
#[test]
fn create_hashes_binary() -> Result<()> {
    let byte_array = Arc::new(BinaryArray::from_vec(vec![
        &[4, 3, 2],
        &[4, 3, 2],
        &[1, 2, 3],
    ]));

    let random_state = RandomState::with_seeds(0, 0, 0, 0);
    let hashes_buff = &mut vec![0; byte_array.len()];
    let hashes = create_hashes(&[byte_array], &random_state, hashes_buff)?;
    assert_eq!(hashes.len(), 3,);

    Ok(())
}

// Generates a test named `$NAME` that hashes the same binary values stored as
// `$ARRAY` and as a plain `BinaryArray`, then checks that:
//   * null entries hash to zero and non-null entries do not,
//   * both array layouts produce identical hashes,
//   * equal values share a hash and distinct values do not.
macro_rules! create_hash_binary {
    ($NAME:ident, $ARRAY:ty) => {
        #[cfg(not(feature = "force_hash_collisions"))]
        #[test]
        fn $NAME() {
            let binary = [
                Some(b"short".to_byte_slice()),
                None,
                Some(b"long but different 12 bytes string"),
                Some(b"short2"),
                Some(b"Longer than 12 bytes string"),
                Some(b"short"),
                Some(b"Longer than 12 bytes string"),
            ];

            let binary_array = Arc::new(binary.iter().cloned().collect::<$ARRAY>());
            let ref_array = Arc::new(binary.iter().cloned().collect::<BinaryArray>());

            let random_state = RandomState::with_seeds(0, 0, 0, 0);

            let mut binary_hashes = vec![0; binary.len()];
            create_hashes(&[binary_array], &random_state, &mut binary_hashes)
                .unwrap();

            let mut ref_hashes = vec![0; binary.len()];
            create_hashes(&[ref_array], &random_state, &mut ref_hashes).unwrap();

            // Null values result in a zero hash.
            for (val, hash) in binary.iter().zip(binary_hashes.iter()) {
                match val {
                    Some(_) => assert_ne!(*hash, 0),
                    None => assert_eq!(*hash, 0),
                }
            }

            // Same logical values should hash to the same hash value,
            // regardless of the physical array layout.
            assert_eq!(binary_hashes, ref_hashes);

            // Same values should map to same hash values. Compare the hashes,
            // not the input literals, so the assertion actually tests hashing.
            assert_eq!(binary_hashes[0], binary_hashes[5]);
            assert_eq!(binary_hashes[4], binary_hashes[6]);

            // Different binary values should map to different hash values.
            assert_ne!(binary_hashes[0], binary_hashes[2]);
        }
    };
}

create_hash_binary!(binary_array, BinaryArray);
create_hash_binary!(binary_view_array, BinaryViewArray);

#[test]
fn create_hashes_fixed_size_binary() -> Result<()> {
let input_arg = vec![vec![1, 2], vec![5, 6], vec![5, 6]];
Expand All @@ -517,6 +556,64 @@ mod tests {
Ok(())
}

// Generates a test named `$NAME` that hashes the same string values stored as
// `$ARRAY` and as a `DictionaryArray<Int8Type>`, then checks that:
//   * null entries hash to zero and non-null entries do not,
//   * both array layouts produce identical hashes,
//   * equal values share a hash and distinct values do not.
macro_rules! create_hash_string {
    ($NAME:ident, $ARRAY:ty) => {
        #[cfg(not(feature = "force_hash_collisions"))]
        #[test]
        fn $NAME() {
            let strings = [
                Some("short"),
                None,
                Some("long but different 12 bytes string"),
                Some("short2"),
                Some("Longer than 12 bytes string"),
                Some("short"),
                Some("Longer than 12 bytes string"),
            ];

            let string_array = Arc::new(strings.iter().cloned().collect::<$ARRAY>());
            let dict_array = Arc::new(
                strings
                    .iter()
                    .cloned()
                    .collect::<DictionaryArray<Int8Type>>(),
            );

            let random_state = RandomState::with_seeds(0, 0, 0, 0);

            let mut string_hashes = vec![0; strings.len()];
            create_hashes(&[string_array], &random_state, &mut string_hashes)
                .unwrap();

            let mut dict_hashes = vec![0; strings.len()];
            create_hashes(&[dict_array], &random_state, &mut dict_hashes).unwrap();

            // Null values result in a zero hash.
            for (val, hash) in strings.iter().zip(string_hashes.iter()) {
                match val {
                    Some(_) => assert_ne!(*hash, 0),
                    None => assert_eq!(*hash, 0),
                }
            }

            // Same logical values should hash to the same hash value,
            // regardless of the physical array layout.
            assert_eq!(string_hashes, dict_hashes);

            // Same values should map to same hash values. Compare the hashes,
            // not the input literals, so the assertion actually tests hashing.
            assert_eq!(string_hashes[0], string_hashes[5]);
            assert_eq!(string_hashes[4], string_hashes[6]);

            // Different strings should map to different hash values.
            assert_ne!(string_hashes[0], string_hashes[2]);
        }
    };
}

// Instantiate the string hashing test once per string-like array layout.
create_hash_string!(string_array, StringArray);
create_hash_string!(large_string_array, LargeStringArray);
// The view test must build a `StringViewArray`; using `StringArray` here would
// merely duplicate `string_array` and leave the Utf8View hash path untested.
create_hash_string!(string_view_array, StringViewArray);
create_hash_string!(dict_string_array, DictionaryArray<Int8Type>);

#[test]
// Tests actual values of hashes, which are different if forcing collisions
#[cfg(not(feature = "force_hash_collisions"))]
Expand Down
39 changes: 39 additions & 0 deletions datafusion/physical-expr/src/aggregate/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ use arrow_array::types::{
Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::{BinaryViewArray, StringViewArray};
use datafusion_common::internal_err;
use datafusion_common::ScalarValue;
use datafusion_common::{downcast_value, DataFusionError, Result};
Expand Down Expand Up @@ -453,6 +454,14 @@ fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
DataType::LargeUtf8 => {
typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string)
}
DataType::Utf8View => {
typed_min_max_batch_string!(
values,
StringViewArray,
Utf8View,
min_string_view
)
}
DataType::Boolean => {
typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean)
}
Expand All @@ -467,6 +476,14 @@ fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
min_binary
)
}
DataType::BinaryView => {
typed_min_max_batch_binary!(
&values,
BinaryViewArray,
BinaryView,
min_binary_view
)
}
_ => min_max_batch!(values, min),
})
}
Expand All @@ -480,12 +497,28 @@ fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
DataType::LargeUtf8 => {
typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string)
}
DataType::Utf8View => {
typed_min_max_batch_string!(
values,
StringViewArray,
Utf8View,
max_string_view
)
}
DataType::Boolean => {
typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean)
}
DataType::Binary => {
typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary)
}
DataType::BinaryView => {
typed_min_max_batch_binary!(
&values,
BinaryViewArray,
BinaryView,
max_binary_view
)
}
DataType::LargeBinary => {
typed_min_max_batch_binary!(
&values,
Expand Down Expand Up @@ -629,12 +662,18 @@ macro_rules! min_max {
(ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => {
typed_min_max_string!(lhs, rhs, LargeUtf8, $OP)
}
(ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => {
typed_min_max_string!(lhs, rhs, Utf8View, $OP)
}
(ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => {
typed_min_max_string!(lhs, rhs, Binary, $OP)
}
(ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => {
typed_min_max_string!(lhs, rhs, LargeBinary, $OP)
}
(ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => {
typed_min_max_string!(lhs, rhs, BinaryView, $OP)
}
(ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => {
typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz)
}
Expand Down

0 comments on commit 8e0ca1a

Please sign in to comment.