Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add String view helper functions #11517

Merged
merged 2 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions datafusion/common/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use arrow::{
},
datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType},
};
use arrow_array::{BinaryViewArray, StringViewArray};

// Downcast ArrayRef to Date32Array
pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> {
Expand Down Expand Up @@ -87,6 +88,11 @@ pub fn as_string_array(array: &dyn Array) -> Result<&StringArray> {
Ok(downcast_value!(array, StringArray))
}

// Downcast ArrayRef to StringViewArray
pub fn as_string_view_array(array: &dyn Array) -> Result<&StringViewArray> {
Ok(downcast_value!(array, StringViewArray))
}

// Downcast ArrayRef to UInt32Array
pub fn as_uint32_array(array: &dyn Array) -> Result<&UInt32Array> {
Ok(downcast_value!(array, UInt32Array))
Expand Down Expand Up @@ -221,6 +227,11 @@ pub fn as_binary_array(array: &dyn Array) -> Result<&BinaryArray> {
Ok(downcast_value!(array, BinaryArray))
}

// Downcast ArrayRef to BinaryViewArray
pub fn as_binary_view_array(array: &dyn Array) -> Result<&BinaryViewArray> {
Ok(downcast_value!(array, BinaryViewArray))
}

// Downcast ArrayRef to FixedSizeListArray
pub fn as_fixed_size_list_array(array: &dyn Array) -> Result<&FixedSizeListArray> {
Ok(downcast_value!(array, FixedSizeListArray))
Expand Down
121 changes: 109 additions & 12 deletions datafusion/common/src/hash_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,8 @@ pub fn create_hashes<'a>(
random_state: &RandomState,
hashes_buffer: &'a mut Vec<u64>,
) -> Result<&'a mut Vec<u64>> {
use crate::cast::{as_binary_view_array, as_string_view_array};

for (i, col) in arrays.iter().enumerate() {
let array = col.as_ref();
// combine hashes with `combine_hashes` for all columns besides the first
Expand All @@ -370,8 +372,10 @@ pub fn create_hashes<'a>(
DataType::Boolean => hash_array(as_boolean_array(array)?, random_state, hashes_buffer, rehash),
DataType::Utf8 => hash_array(as_string_array(array)?, random_state, hashes_buffer, rehash),
DataType::LargeUtf8 => hash_array(as_largestring_array(array), random_state, hashes_buffer, rehash),
DataType::Utf8View => hash_array(as_string_view_array(array)?, random_state, hashes_buffer, rehash),
DataType::Binary => hash_array(as_generic_binary_array::<i32>(array)?, random_state, hashes_buffer, rehash),
DataType::LargeBinary => hash_array(as_generic_binary_array::<i64>(array)?, random_state, hashes_buffer, rehash),
DataType::BinaryView => hash_array(as_binary_view_array(array)?, random_state, hashes_buffer, rehash),
DataType::FixedSizeBinary(_) => {
let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap();
hash_array(array, random_state, hashes_buffer, rehash)
Expand Down Expand Up @@ -486,22 +490,57 @@ mod tests {
Ok(())
}

#[test]
fn create_hashes_binary() -> Result<()> {
let byte_array = Arc::new(BinaryArray::from_vec(vec![
&[4, 3, 2],
&[4, 3, 2],
&[1, 2, 3],
]));
macro_rules! create_hash_binary {
($NAME:ident, $ARRAY:ty) => {
#[cfg(not(feature = "force_hash_collisions"))]
#[test]
fn $NAME() {
let binary = [
Some(b"short".to_byte_slice()),
None,
Some(b"long but different 12 bytes string"),
Some(b"short2"),
Some(b"Longer than 12 bytes string"),
Some(b"short"),
Some(b"Longer than 12 bytes string"),
];

let binary_array = Arc::new(binary.iter().cloned().collect::<$ARRAY>());
let ref_array = Arc::new(binary.iter().cloned().collect::<BinaryArray>());

let random_state = RandomState::with_seeds(0, 0, 0, 0);

let mut binary_hashes = vec![0; binary.len()];
create_hashes(&[binary_array], &random_state, &mut binary_hashes)
.unwrap();

let mut ref_hashes = vec![0; binary.len()];
create_hashes(&[ref_array], &random_state, &mut ref_hashes).unwrap();

// Null values result in a zero hash,
for (val, hash) in binary.iter().zip(binary_hashes.iter()) {
match val {
Some(_) => assert_ne!(*hash, 0),
None => assert_eq!(*hash, 0),
}
}

let random_state = RandomState::with_seeds(0, 0, 0, 0);
let hashes_buff = &mut vec![0; byte_array.len()];
let hashes = create_hashes(&[byte_array], &random_state, hashes_buff)?;
assert_eq!(hashes.len(), 3,);
// same logical values should hash to the same hash value
assert_eq!(binary_hashes, ref_hashes);

Ok(())
// Same values should map to same hash values
assert_eq!(binary[0], binary[5]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

assert_eq!(binary[4], binary[6]);

// different binary should map to different hash values
assert_ne!(binary[0], binary[2]);
}
};
}

create_hash_binary!(binary_array, BinaryArray);
create_hash_binary!(binary_view_array, BinaryViewArray);

#[test]
fn create_hashes_fixed_size_binary() -> Result<()> {
let input_arg = vec![vec![1, 2], vec![5, 6], vec![5, 6]];
Expand All @@ -517,6 +556,64 @@ mod tests {
Ok(())
}

macro_rules! create_hash_string {
($NAME:ident, $ARRAY:ty) => {
#[cfg(not(feature = "force_hash_collisions"))]
#[test]
fn $NAME() {
let strings = [
Some("short"),
None,
Some("long but different 12 bytes string"),
Some("short2"),
Some("Longer than 12 bytes string"),
Some("short"),
Some("Longer than 12 bytes string"),
];

let string_array = Arc::new(strings.iter().cloned().collect::<$ARRAY>());
let dict_array = Arc::new(
strings
.iter()
.cloned()
.collect::<DictionaryArray<Int8Type>>(),
);

let random_state = RandomState::with_seeds(0, 0, 0, 0);

let mut string_hashes = vec![0; strings.len()];
create_hashes(&[string_array], &random_state, &mut string_hashes)
.unwrap();

let mut dict_hashes = vec![0; strings.len()];
create_hashes(&[dict_array], &random_state, &mut dict_hashes).unwrap();

// Null values result in a zero hash,
for (val, hash) in strings.iter().zip(string_hashes.iter()) {
match val {
Some(_) => assert_ne!(*hash, 0),
None => assert_eq!(*hash, 0),
}
}

// same logical values should hash to the same hash value
assert_eq!(string_hashes, dict_hashes);

// Same values should map to same hash values
assert_eq!(strings[0], strings[5]);
assert_eq!(strings[4], strings[6]);

// different strings should map to different hash values
assert_ne!(strings[0], strings[2]);
}
};
}

create_hash_string!(string_array, StringArray);
create_hash_string!(large_string_array, LargeStringArray);
create_hash_string!(string_view_array, StringArray);
create_hash_string!(dict_string_array, DictionaryArray<Int8Type>);

#[test]
// Tests actual values of hashes, which are different if forcing collisions
#[cfg(not(feature = "force_hash_collisions"))]
Expand Down
39 changes: 39 additions & 0 deletions datafusion/physical-expr/src/aggregate/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ use arrow_array::types::{
Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::{BinaryViewArray, StringViewArray};
use datafusion_common::internal_err;
use datafusion_common::ScalarValue;
use datafusion_common::{downcast_value, DataFusionError, Result};
Expand Down Expand Up @@ -453,6 +454,14 @@ fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
DataType::LargeUtf8 => {
typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string)
}
DataType::Utf8View => {
typed_min_max_batch_string!(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presumably this is using the fast kernel from apache/arrow-rs#6053

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

values,
StringViewArray,
Utf8View,
min_string_view
)
}
DataType::Boolean => {
typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean)
}
Expand All @@ -467,6 +476,14 @@ fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
min_binary
)
}
DataType::BinaryView => {
typed_min_max_batch_binary!(
&values,
BinaryViewArray,
BinaryView,
min_binary_view
)
}
_ => min_max_batch!(values, min),
})
}
Expand All @@ -480,12 +497,28 @@ fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
DataType::LargeUtf8 => {
typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string)
}
DataType::Utf8View => {
typed_min_max_batch_string!(
values,
StringViewArray,
Utf8View,
max_string_view
)
}
DataType::Boolean => {
typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean)
}
DataType::Binary => {
typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary)
}
DataType::BinaryView => {
typed_min_max_batch_binary!(
&values,
BinaryViewArray,
BinaryView,
max_binary_view
)
}
DataType::LargeBinary => {
typed_min_max_batch_binary!(
&values,
Expand Down Expand Up @@ -629,12 +662,18 @@ macro_rules! min_max {
(ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => {
typed_min_max_string!(lhs, rhs, LargeUtf8, $OP)
}
(ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => {
typed_min_max_string!(lhs, rhs, Utf8View, $OP)
}
(ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => {
typed_min_max_string!(lhs, rhs, Binary, $OP)
}
(ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => {
typed_min_max_string!(lhs, rhs, LargeBinary, $OP)
}
(ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => {
typed_min_max_string!(lhs, rhs, BinaryView, $OP)
}
(ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => {
typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz)
}
Expand Down