Skip to content

Commit

Permalink
apply apache#13688: Improve substr() performance by avoiding using ow…
Browse files Browse the repository at this point in the history
…ned string
  • Loading branch information
zhangli20 committed Dec 8, 2024
1 parent 2bc42ea commit fd08520
Showing 1 changed file with 218 additions and 63 deletions.
281 changes: 218 additions & 63 deletions datafusion/functions/src/unicode/substr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,18 @@
// under the License.

use std::any::Any;
use std::cmp::max;
use std::sync::Arc;

use crate::string::common::StringArrayType;
use crate::utils::{make_scalar_function, utf8_to_str_type};
use arrow::array::{
make_view, Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, ByteView,
GenericStringArray, OffsetSizeTrait, StringViewArray,
};
use arrow::array::{Array, ArrayIter, ArrayRef, AsArray, ByteView, GenericStringBuilder, Int64Array, make_view, OffsetSizeTrait, StringViewArray};
use arrow::datatypes::DataType;
use arrow_buffer::{NullBufferBuilder, ScalarBuffer};
use datafusion_common::cast::as_int64_array;
use datafusion_common::{exec_datafusion_err, exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use datafusion_common::{exec_err, plan_err, Result};
use datafusion_expr::{
ColumnarValue, ScalarUDFImpl, Signature, Volatility,
};

#[derive(Debug)]
pub struct SubstrFunc {
Expand All @@ -45,19 +43,8 @@ impl Default for SubstrFunc {

impl SubstrFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Int64]),
Exact(vec![LargeUtf8, Int64]),
Exact(vec![Utf8, Int64, Int64]),
Exact(vec![LargeUtf8, Int64, Int64]),
Exact(vec![Utf8View, Int64]),
Exact(vec![Utf8View, Int64, Int64]),
],
Volatility::Immutable,
),
signature: Signature::user_defined(Volatility::Immutable),
aliases: vec![String::from("substring")],
}
}
Expand Down Expand Up @@ -91,6 +78,72 @@ impl ScalarUDFImpl for SubstrFunc {
fn aliases(&self) -> &[String] {
&self.aliases
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() < 2 || arg_types.len() > 3 {
return plan_err!(
"The {} function requires 2 or 3 arguments, but got {}.",
self.name(),
arg_types.len()
);
}
let first_data_type = match &arg_types[0] {
DataType::Null => Ok(DataType::Utf8),
DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 => Ok(arg_types[0].clone()),
DataType::Dictionary(key_type, value_type) => {
if key_type.is_integer() {
match value_type.as_ref() {
DataType::Null => Ok(DataType::Utf8),
DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 => Ok(*value_type.clone()),
_ => plan_err!(
"The first argument of the {} function can only be a string, but got {:?}.",
self.name(),
arg_types[0]
),
}
} else {
plan_err!(
"The first argument of the {} function can only be a string, but got {:?}.",
self.name(),
arg_types[0]
)
}
}
_ => plan_err!(
"The first argument of the {} function can only be a string, but got {:?}.",
self.name(),
arg_types[0]
)
}?;

if ![DataType::Int64, DataType::Int32, DataType::Null].contains(&arg_types[1]) {
return plan_err!(
"The second argument of the {} function can only be an integer, but got {:?}.",
self.name(),
arg_types[1]
);
}

if arg_types.len() == 3
&& ![DataType::Int64, DataType::Int32, DataType::Null].contains(&arg_types[2])
{
return plan_err!(
"The third argument of the {} function can only be an integer, but got {:?}.",
self.name(),
arg_types[2]
);
}

if arg_types.len() == 2 {
Ok(vec![first_data_type.to_owned(), DataType::Int64])
} else {
Ok(vec![
first_data_type.to_owned(),
DataType::Int64,
DataType::Int64,
])
}
}
}

/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
Expand Down Expand Up @@ -119,19 +172,27 @@ pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
}

// Convert the given `start` and `count` to valid byte indices within `input` string
//
// Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start, count)`
// `start` is 1-based, if `count` is not provided count to the end of the string
// Input indices are character-based, and return values are byte indices
// The input bounds can be outside string bounds, this function will return
// the intersection between input bounds and valid string bounds
// `input_ascii_only` is used to optimize this function if `input` is ASCII-only
//
// * Example
// 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx]
// `get_true_start_end('Hi🌏', 1, None) -> (0, 6)`
// `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)`
// `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)`
fn get_true_start_end(input: &str, start: i64, count: Option<u64>) -> (usize, usize) {
let start = start - 1;
fn get_true_start_end(
input: &str,
start: i64,
count: Option<u64>,
is_input_ascii_only: bool,
) -> (usize, usize) {
let start = start.checked_sub(1).unwrap_or(start);

let end = match count {
Some(count) => start + count as i64,
None => input.len() as i64,
Expand All @@ -142,6 +203,14 @@ fn get_true_start_end(input: &str, start: i64, count: Option<u64>) -> (usize, us
let end = end.clamp(0, input.len() as i64) as usize;
let count = end - start;

// If input is ASCII-only, byte-based indices equals to char-based indices
if is_input_ascii_only {
return (start, end);
}

// Otherwise, calculate byte indices from char indices
// Note this decoding is relatively expensive for this simple `substr` function,,
// so the implementation attempts to decode in one pass (and caused the complexity)
let (mut st, mut ed) = (input.len(), input.len());
let mut start_counting = false;
let mut cnt = 0;
Expand All @@ -165,6 +234,7 @@ fn get_true_start_end(input: &str, start: i64, count: Option<u64>) -> (usize, us
(st, ed)
}


/// Make a `u128` based on the given substr, start(offset to view.offset), and
/// push into to the given buffers
fn make_and_append_view(
Expand All @@ -186,6 +256,53 @@ fn make_and_append_view(
null_builder.append_non_null();
}

// String characters are variable length encoded in UTF-8, `substr()` function's
// arguments are character-based, converting them into byte-based indices
// requires expensive decoding.
// However, checking if a string is ASCII-only is relatively cheap.
// If strings are ASCII only, use byte-based indices instead.
//
// A common pattern to call `substr()` is taking a small prefix of a long
// string, such as `substr(long_str_with_1k_chars, 1, 32)`.
// In such case the overhead of ASCII-validation may not be worth it, so
// skip the validation for short prefix for now.
fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>(
string_array: &V,
start: &Int64Array,
count: Option<&Int64Array>,
) -> bool {
let is_short_prefix = match count {
Some(count) => {
let short_prefix_threshold = 32.0;
let n_sample = 10;

// HACK: can be simplified if function has specialized
// implementation for `ScalarValue` (implement without `make_scalar_function()`)
let avg_prefix_len = start
.iter()
.zip(count.iter())
.take(n_sample)
.map(|(start, count)| {
let start = start.unwrap_or(0);
let count = count.unwrap_or(0);
// To get substring, need to decode from 0 to start+count instead of start to start+count
start + count
})
.sum::<i64>();

avg_prefix_len as f64 / n_sample as f64 <= short_prefix_threshold
}
None => false,
};

if is_short_prefix {
// Skip ASCII validation for short prefix
false
} else {
string_array.is_ascii()
}
}

// The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44
// From<u128> for ByteView
fn string_view_substr(
Expand All @@ -196,6 +313,14 @@ fn string_view_substr(
let mut null_builder = NullBufferBuilder::new(string_view_array.len());

let start_array = as_int64_array(&args[0])?;
let count_array_opt = if args.len() == 2 {
Some(as_int64_array(&args[1])?)
} else {
None
};

let enable_ascii_fast_path =
enable_ascii_fast_path(&string_view_array, start_array, count_array_opt);

// In either case of `substr(s, i)` or `substr(s, i, cnt)`
// If any of input argument is `NULL`, the result is `NULL`
Expand All @@ -207,7 +332,8 @@ fn string_view_substr(
.zip(start_array.iter())
{
if let (Some(str), Some(start)) = (str_opt, start_opt) {
let (start, end) = get_true_start_end(str, start, None);
let (start, end) =
get_true_start_end(str, start, None, enable_ascii_fast_path);
let substr = &str[start..end];

make_and_append_view(
Expand All @@ -224,7 +350,7 @@ fn string_view_substr(
}
}
2 => {
let count_array = as_int64_array(&args[1])?;
let count_array = count_array_opt.unwrap();
for (((str_opt, raw_view), start_opt), count_opt) in string_view_array
.iter()
.zip(string_view_array.views().iter())
Expand All @@ -239,8 +365,17 @@ fn string_view_substr(
"negative substring length not allowed: substr(<str>, {start}, {count})"
);
} else {
let (start, end) =
get_true_start_end(str, start, Some(count as u64));
if start == i64::MIN {
return exec_err!(
"negative overflow when calculating skip value"
);
}
let (start, end) = get_true_start_end(
str,
start,
Some(count as u64),
enable_ascii_fast_path,
);
let substr = &str[start..end];

make_and_append_view(
Expand Down Expand Up @@ -283,58 +418,78 @@ fn string_view_substr(

fn string_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
where
V: ArrayAccessor<Item = &'a str>,
V: StringArrayType<'a>,
T: OffsetSizeTrait,
{
let start_array = as_int64_array(&args[0])?;
let count_array_opt = if args.len() == 2 {
Some(as_int64_array(&args[1])?)
} else {
None
};

let enable_ascii_fast_path =
enable_ascii_fast_path(&string_array, start_array, count_array_opt);

match args.len() {
1 => {
let iter = ArrayIter::new(string_array);
let start_array = as_int64_array(&args[0])?;

let result = iter
.zip(start_array.iter())
.map(|(string, start)| match (string, start) {
let mut result_builder = GenericStringBuilder::<T>::new();
for (string, start) in iter.zip(start_array.iter()) {
match (string, start) {
(Some(string), Some(start)) => {
if start <= 0 {
Some(string.to_string())
} else {
Some(string.chars().skip(start as usize - 1).collect())
}
let (start, end) = get_true_start_end(
string,
start,
None,
enable_ascii_fast_path,
); // start, end is byte-based
let substr = &string[start..end];
result_builder.append_value(substr);
}
_ => None,
})
.collect::<GenericStringArray<T>>();
Ok(Arc::new(result) as ArrayRef)
_ => {
result_builder.append_null();
}
}
}
Ok(Arc::new(result_builder.finish()) as ArrayRef)
}
2 => {
let iter = ArrayIter::new(string_array);
let start_array = as_int64_array(&args[0])?;
let count_array = as_int64_array(&args[1])?;
let count_array = count_array_opt.unwrap();
let mut result_builder = GenericStringBuilder::<T>::new();

let result = iter
.zip(start_array.iter())
.zip(count_array.iter())
.map(|((string, start), count)| {
match (string, start, count) {
(Some(string), Some(start), Some(count)) => {
if count < 0 {
exec_err!(
for ((string, start), count) in
iter.zip(start_array.iter()).zip(count_array.iter())
{
match (string, start, count) {
(Some(string), Some(start), Some(count)) => {
if count < 0 {
return exec_err!(
"negative substring length not allowed: substr(<str>, {start}, {count})"
)
} else {
let skip = max(0, start.checked_sub(1).ok_or_else(
|| exec_datafusion_err!("negative overflow when calculating skip value")
)?);
let count = max(0, count + (if start < 1 { start - 1 } else { 0 }));
Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::<String>()))
);
} else {
if start == i64::MIN {
return exec_err!(
"negative overflow when calculating skip value"
);
}
let (start, end) = get_true_start_end(
string,
start,
Some(count as u64),
enable_ascii_fast_path,
); // start, end is byte-based
let substr = &string[start..end];
result_builder.append_value(substr);
}
_ => Ok(None),
}
})
.collect::<Result<GenericStringArray<T>>>()?;

Ok(Arc::new(result) as ArrayRef)
_ => {
result_builder.append_null();
}
}
}
Ok(Arc::new(result_builder.finish()) as ArrayRef)
}
other => {
exec_err!("substr was called with {other} arguments. It requires 2 or 3.")
Expand Down

0 comments on commit fd08520

Please sign in to comment.