From fd08520904f30aaf6562fb3147b0960c5863855a Mon Sep 17 00:00:00 2001 From: zhangli20 Date: Sun, 8 Dec 2024 22:11:08 +0800 Subject: [PATCH] apply #13688: Improve substr() performance by avoiding using owned string --- datafusion/functions/src/unicode/substr.rs | 281 ++++++++++++++++----- 1 file changed, 218 insertions(+), 63 deletions(-) diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 40d3a4d13e97..670020d5f90f 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -16,20 +16,18 @@ // under the License. use std::any::Any; -use std::cmp::max; use std::sync::Arc; +use crate::string::common::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_str_type}; -use arrow::array::{ - make_view, Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, ByteView, - GenericStringArray, OffsetSizeTrait, StringViewArray, -}; +use arrow::array::{Array, ArrayIter, ArrayRef, AsArray, ByteView, GenericStringBuilder, Int64Array, make_view, OffsetSizeTrait, StringViewArray}; use arrow::datatypes::DataType; use arrow_buffer::{NullBufferBuilder, ScalarBuffer}; use datafusion_common::cast::as_int64_array; -use datafusion_common::{exec_datafusion_err, exec_err, Result}; -use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{exec_err, plan_err, Result}; +use datafusion_expr::{ + ColumnarValue, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct SubstrFunc { @@ -45,19 +43,8 @@ impl Default for SubstrFunc { impl SubstrFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( - vec![ - Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), - Exact(vec![Utf8, Int64, Int64]), - Exact(vec![LargeUtf8, Int64, Int64]), - Exact(vec![Utf8View, Int64]), - Exact(vec![Utf8View, Int64, Int64]), - ], - Volatility::Immutable, - ), + signature: 
Signature::user_defined(Volatility::Immutable), aliases: vec![String::from("substring")], } } @@ -91,6 +78,72 @@ impl ScalarUDFImpl for SubstrFunc { fn aliases(&self) -> &[String] { &self.aliases } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() < 2 || arg_types.len() > 3 { + return plan_err!( + "The {} function requires 2 or 3 arguments, but got {}.", + self.name(), + arg_types.len() + ); + } + let first_data_type = match &arg_types[0] { + DataType::Null => Ok(DataType::Utf8), + DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 => Ok(arg_types[0].clone()), + DataType::Dictionary(key_type, value_type) => { + if key_type.is_integer() { + match value_type.as_ref() { + DataType::Null => Ok(DataType::Utf8), + DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 => Ok(*value_type.clone()), + _ => plan_err!( + "The first argument of the {} function can only be a string, but got {:?}.", + self.name(), + arg_types[0] + ), + } + } else { + plan_err!( + "The first argument of the {} function can only be a string, but got {:?}.", + self.name(), + arg_types[0] + ) + } + } + _ => plan_err!( + "The first argument of the {} function can only be a string, but got {:?}.", + self.name(), + arg_types[0] + ) + }?; + + if ![DataType::Int64, DataType::Int32, DataType::Null].contains(&arg_types[1]) { + return plan_err!( + "The second argument of the {} function can only be an integer, but got {:?}.", + self.name(), + arg_types[1] + ); + } + + if arg_types.len() == 3 + && ![DataType::Int64, DataType::Int32, DataType::Null].contains(&arg_types[2]) + { + return plan_err!( + "The third argument of the {} function can only be an integer, but got {:?}.", + self.name(), + arg_types[2] + ); + } + + if arg_types.len() == 2 { + Ok(vec![first_data_type.to_owned(), DataType::Int64]) + } else { + Ok(vec![ + first_data_type.to_owned(), + DataType::Int64, + DataType::Int64, + ]) + } + } } /// Extracts the substring of string starting at the start'th 
character, and extending for count characters if that is specified. (Same as substring(string from start for count).) @@ -119,19 +172,27 @@ pub fn substr(args: &[ArrayRef]) -> Result { } // Convert the given `start` and `count` to valid byte indices within `input` string +// // Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start, count)` // `start` is 1-based, if `count` is not provided count to the end of the string // Input indices are character-based, and return values are byte indices // The input bounds can be outside string bounds, this function will return // the intersection between input bounds and valid string bounds +// `is_input_ascii_only` is used to optimize this function if `input` is ASCII-only // // * Example // 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx] // `get_true_start_end('Hi🌏', 1, None) -> (0, 6)` // `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)` // `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)` -fn get_true_start_end(input: &str, start: i64, count: Option) -> (usize, usize) { - let start = start - 1; +fn get_true_start_end( + input: &str, + start: i64, + count: Option, + is_input_ascii_only: bool, +) -> (usize, usize) { + let start = start.checked_sub(1).unwrap_or(start); + let end = match count { Some(count) => start + count as i64, None => input.len() as i64, @@ -142,6 +203,14 @@ fn get_true_start_end(input: &str, start: i64, count: Option) -> (usize, us let end = end.clamp(0, input.len() as i64) as usize; let count = end - start; + // If input is ASCII-only, byte-based indices are equal to char-based indices + if is_input_ascii_only { + return (start, end); + } + + // Otherwise, calculate byte indices from char indices + // Note this decoding is relatively expensive for this simple `substr` function, + // so the implementation attempts to decode in one pass (which causes the complexity) let (mut st, mut ed) = (input.len(), input.len()); let mut start_counting = false; let mut cnt = 0; @@ -165,6 +234,7 @@
fn get_true_start_end(input: &str, start: i64, count: Option) -> (usize, us (st, ed) } + /// Make a `u128` based on the given substr, start(offset to view.offset), and /// push into to the given buffers fn make_and_append_view( @@ -186,6 +256,53 @@ fn make_and_append_view( null_builder.append_non_null(); } +// String characters are variable length encoded in UTF-8, `substr()` function's +// arguments are character-based, converting them into byte-based indices +// requires expensive decoding. +// However, checking if a string is ASCII-only is relatively cheap. +// If strings are ASCII only, use byte-based indices instead. +// +// A common pattern to call `substr()` is taking a small prefix of a long +// string, such as `substr(long_str_with_1k_chars, 1, 32)`. +// In such case the overhead of ASCII-validation may not be worth it, so +// skip the validation for short prefix for now. +fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( + string_array: &V, + start: &Int64Array, + count: Option<&Int64Array>, +) -> bool { + let is_short_prefix = match count { + Some(count) => { + let short_prefix_threshold = 32.0; + let n_sample = 10; + + // HACK: can be simplified if function has specialized + // implementation for `ScalarValue` (implement without `make_scalar_function()`) + let avg_prefix_len = start + .iter() + .zip(count.iter()) + .take(n_sample) + .map(|(start, count)| { + let start = start.unwrap_or(0); + let count = count.unwrap_or(0); + // To get substring, need to decode from 0 to start+count instead of start to start+count + start + count + }) + .sum::(); + + avg_prefix_len as f64 / n_sample as f64 <= short_prefix_threshold + } + None => false, + }; + + if is_short_prefix { + // Skip ASCII validation for short prefix + false + } else { + string_array.is_ascii() + } +} + // The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44 // From for ByteView fn string_view_substr( @@ -196,6 +313,14 @@ fn string_view_substr( let mut null_builder = 
NullBufferBuilder::new(string_view_array.len()); let start_array = as_int64_array(&args[0])?; + let count_array_opt = if args.len() == 2 { + Some(as_int64_array(&args[1])?) + } else { + None + }; + + let enable_ascii_fast_path = + enable_ascii_fast_path(&string_view_array, start_array, count_array_opt); // In either case of `substr(s, i)` or `substr(s, i, cnt)` // If any of input argument is `NULL`, the result is `NULL` @@ -207,7 +332,8 @@ fn string_view_substr( .zip(start_array.iter()) { if let (Some(str), Some(start)) = (str_opt, start_opt) { - let (start, end) = get_true_start_end(str, start, None); + let (start, end) = + get_true_start_end(str, start, None, enable_ascii_fast_path); let substr = &str[start..end]; make_and_append_view( @@ -224,7 +350,7 @@ fn string_view_substr( } } 2 => { - let count_array = as_int64_array(&args[1])?; + let count_array = count_array_opt.unwrap(); for (((str_opt, raw_view), start_opt), count_opt) in string_view_array .iter() .zip(string_view_array.views().iter()) @@ -239,8 +365,17 @@ fn string_view_substr( "negative substring length not allowed: substr(, {start}, {count})" ); } else { - let (start, end) = - get_true_start_end(str, start, Some(count as u64)); + if start == i64::MIN { + return exec_err!( + "negative overflow when calculating skip value" + ); + } + let (start, end) = get_true_start_end( + str, + start, + Some(count as u64), + enable_ascii_fast_path, + ); let substr = &str[start..end]; make_and_append_view( @@ -283,58 +418,78 @@ fn string_view_substr( fn string_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result where - V: ArrayAccessor, + V: StringArrayType<'a>, T: OffsetSizeTrait, { + let start_array = as_int64_array(&args[0])?; + let count_array_opt = if args.len() == 2 { + Some(as_int64_array(&args[1])?) 
+ } else { + None + }; + + let enable_ascii_fast_path = + enable_ascii_fast_path(&string_array, start_array, count_array_opt); + match args.len() { 1 => { let iter = ArrayIter::new(string_array); - let start_array = as_int64_array(&args[0])?; - - let result = iter - .zip(start_array.iter()) - .map(|(string, start)| match (string, start) { + let mut result_builder = GenericStringBuilder::::new(); + for (string, start) in iter.zip(start_array.iter()) { + match (string, start) { (Some(string), Some(start)) => { - if start <= 0 { - Some(string.to_string()) - } else { - Some(string.chars().skip(start as usize - 1).collect()) - } + let (start, end) = get_true_start_end( + string, + start, + None, + enable_ascii_fast_path, + ); // start, end is byte-based + let substr = &string[start..end]; + result_builder.append_value(substr); } - _ => None, - }) - .collect::>(); - Ok(Arc::new(result) as ArrayRef) + _ => { + result_builder.append_null(); + } + } + } + Ok(Arc::new(result_builder.finish()) as ArrayRef) } 2 => { let iter = ArrayIter::new(string_array); - let start_array = as_int64_array(&args[0])?; - let count_array = as_int64_array(&args[1])?; + let count_array = count_array_opt.unwrap(); + let mut result_builder = GenericStringBuilder::::new(); - let result = iter - .zip(start_array.iter()) - .zip(count_array.iter()) - .map(|((string, start), count)| { - match (string, start, count) { - (Some(string), Some(start), Some(count)) => { - if count < 0 { - exec_err!( + for ((string, start), count) in + iter.zip(start_array.iter()).zip(count_array.iter()) + { + match (string, start, count) { + (Some(string), Some(start), Some(count)) => { + if count < 0 { + return exec_err!( "negative substring length not allowed: substr(, {start}, {count})" - ) - } else { - let skip = max(0, start.checked_sub(1).ok_or_else( - || exec_datafusion_err!("negative overflow when calculating skip value") - )?); - let count = max(0, count + (if start < 1 { start - 1 } else { 0 })); - 
Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::())) + ); + } else { + if start == i64::MIN { + return exec_err!( + "negative overflow when calculating skip value" + ); } + let (start, end) = get_true_start_end( + string, + start, + Some(count as u64), + enable_ascii_fast_path, + ); // start, end is byte-based + let substr = &string[start..end]; + result_builder.append_value(substr); } - _ => Ok(None), } - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) + _ => { + result_builder.append_null(); + } + } + } + Ok(Arc::new(result_builder.finish()) as ArrayRef) } other => { exec_err!("substr was called with {other} arguments. It requires 2 or 3.")