Skip to content

Commit

Permalink
feat:implement postgres style 'overlay' string function (#8117)
Browse files Browse the repository at this point in the history
* feat: implement postgres style 'overlay' string function

* code format

* code format

* code format

* code format

* add sql slt test

* fix modify other case issue

* add test expr

* add annotation

* add overlay function sql reference doc

* add sql case and format doc
  • Loading branch information
Syleechan authored Nov 14, 2023
1 parent a38ac20 commit 4535551
Show file tree
Hide file tree
Showing 12 changed files with 260 additions and 6 deletions.
18 changes: 17 additions & 1 deletion datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ pub enum BuiltinScalarFunction {
RegexpMatch,
/// arrow_typeof
ArrowTypeof,
/// overlay
OverLay,
}

/// Maps the sql function name to `BuiltinScalarFunction`
Expand Down Expand Up @@ -455,6 +457,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Struct => Volatility::Immutable,
BuiltinScalarFunction::FromUnixtime => Volatility::Immutable,
BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable,
BuiltinScalarFunction::OverLay => Volatility::Immutable,

// Stable builtin functions
BuiltinScalarFunction::Now => Volatility::Stable,
Expand Down Expand Up @@ -812,6 +815,10 @@ impl BuiltinScalarFunction {

BuiltinScalarFunction::Abs => Ok(input_expr_types[0].clone()),

BuiltinScalarFunction::OverLay => {
utf8_to_str_type(&input_expr_types[0], "overlay")
}

BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
Expand Down Expand Up @@ -1258,7 +1265,15 @@ impl BuiltinScalarFunction {
}
BuiltinScalarFunction::ArrowTypeof => Signature::any(1, self.volatility()),
BuiltinScalarFunction::Abs => Signature::any(1, self.volatility()),

BuiltinScalarFunction::OverLay => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8, Int64, Int64]),
Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
Exact(vec![Utf8, Utf8, Int64]),
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
],
self.volatility(),
),
BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
Expand Down Expand Up @@ -1517,6 +1532,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
BuiltinScalarFunction::Cardinality => &["cardinality"],
BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_intersect"],
BuiltinScalarFunction::OverLay => &["overlay"],

// struct functions
BuiltinScalarFunction::Struct => &["struct"],
Expand Down
7 changes: 7 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,11 @@ nary_scalar_expr!(
"concatenates several strings, placing a seperator between each one"
);
nary_scalar_expr!(Concat, concat_expr, "concatenates several strings");
nary_scalar_expr!(
OverLay,
overlay,
"replace the substring of string that starts at the start'th character and extends for count characters with new substring"
);

// date functions
scalar_expr!(DatePart, date_part, part date, "extracts a subfield from the date");
Expand Down Expand Up @@ -1174,6 +1179,8 @@ mod test {
test_nary_scalar_expr!(MakeArray, array, input);

test_unary_scalar_expr!(ArrowTypeof, arrow_typeof);
test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len);
test_nary_scalar_expr!(OverLay, overlay, string, characters, position);
}

#[test]
Expand Down
11 changes: 11 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,17 @@ pub fn create_physical_fun(
"{input_data_type}"
)))))
}),
BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
make_scalar_function(string_expressions::overlay::<i32>)(args)
}
DataType::LargeUtf8 => {
make_scalar_function(string_expressions::overlay::<i64>)(args)
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function overlay",
))),
}),
})
}

Expand Down
108 changes: 108 additions & 0 deletions datafusion/physical-expr/src/string_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,11 +553,102 @@ pub fn uuid(args: &[ColumnarValue]) -> Result<ColumnarValue> {
Ok(ColumnarValue::Array(Arc::new(array)))
}

/// Returns the byte offset at which the `char_idx`-th character (0-based) of
/// `s` starts, or `s.len()` when `s` has `char_idx` or fewer characters.
///
/// SQL positions count characters, while `str` slicing works on bytes, so
/// every position must be translated before slicing.
fn overlay_char_to_byte_offset(s: &str, char_idx: usize) -> usize {
    s.char_indices()
        .nth(char_idx)
        .map_or(s.len(), |(offset, _)| offset)
}

/// Computes `overlay` for a single, non-null row.
///
/// * `start_pos` - 1-based character position at which `characters` is placed.
/// * `len` - number of characters of `string` to replace (capped at the
///   character length of `string`); when `None`, the character length of
///   `characters` is used instead.
///
/// The prefix of `string` is kept only when `start_pos` lies inside `string`;
/// if `start_pos` points past the end, the result starts with `characters`.
/// A replaced range that would end before the first character is clamped to
/// the beginning of `string`, so the whole of `string` is appended.
fn overlay_str(
    string: &str,
    characters: &str,
    start_pos: i64,
    len: Option<i64>,
) -> String {
    let string_len = string.chars().count() as i64;
    let replace_len = match len {
        Some(len) => len.min(string_len),
        None => characters.chars().count() as i64,
    };
    let mut res = String::with_capacity(string.len() + characters.len());

    // SQL positions start from 1 while string indexes start from 0
    if start_pos > 1 && start_pos - 1 < string_len {
        let start = overlay_char_to_byte_offset(string, (start_pos - 1) as usize);
        res.push_str(&string[..start]);
    }
    res.push_str(characters);

    // Character index of the first character after the replaced range.
    // Saturating arithmetic keeps extreme arguments from overflowing.
    let end = start_pos.saturating_add(replace_len).saturating_sub(1);
    // if end >= string_len the replacement runs to the end of the string
    if end < string_len {
        // A negative `end` must not be cast to `usize` directly: clamp to 0.
        let end = overlay_char_to_byte_offset(string, end.max(0) as usize);
        res.push_str(&string[end..]);
    }
    res
}

/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2)
///
/// Replaces a substring of string1 with string2, starting at the (1-based)
/// character position `integer` and covering `integer2` characters.
/// Positions and lengths are counted in characters, not bytes.
///
/// pgsql: overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas
/// overlay('Txxxxas' placing 'hom' from 2) → Thomxas; without the FOR option
/// the length of string2 is used as the replace length.
///
/// `args` must hold 3 or 4 arrays: the string array, the replacement string
/// array (both with offset type `T`), an `Int64` start position array and,
/// optionally, an `Int64` length array. A row is null in the output if any of
/// its inputs is null.
///
/// # Errors
/// Returns an error if an argument array cannot be downcast to the expected
/// type, or an internal error if `args` does not contain 3 or 4 arrays.
pub fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
    match args.len() {
        3 => {
            let string_array = as_generic_string_array::<T>(&args[0])?;
            let characters_array = as_generic_string_array::<T>(&args[1])?;
            let pos_num = as_int64_array(&args[2])?;

            let result = string_array
                .iter()
                .zip(characters_array.iter())
                .zip(pos_num.iter())
                .map(|((string, characters), start_pos)| {
                    match (string, characters, start_pos) {
                        (Some(string), Some(characters), Some(start_pos)) => {
                            Some(overlay_str(string, characters, start_pos, None))
                        }
                        _ => None,
                    }
                })
                .collect::<GenericStringArray<T>>();
            Ok(Arc::new(result) as ArrayRef)
        }
        4 => {
            let string_array = as_generic_string_array::<T>(&args[0])?;
            let characters_array = as_generic_string_array::<T>(&args[1])?;
            let pos_num = as_int64_array(&args[2])?;
            let len_num = as_int64_array(&args[3])?;

            let result = string_array
                .iter()
                .zip(characters_array.iter())
                .zip(pos_num.iter())
                .zip(len_num.iter())
                .map(|(((string, characters), start_pos), len)| {
                    match (string, characters, start_pos, len) {
                        (Some(string), Some(characters), Some(start_pos), Some(len)) => {
                            Some(overlay_str(string, characters, start_pos, Some(len)))
                        }
                        _ => None,
                    }
                })
                .collect::<GenericStringArray<T>>();
            Ok(Arc::new(result) as ArrayRef)
        }
        other => {
            internal_err!(
                "overlay was called with {other} arguments. It requires 3 or 4."
            )
        }
    }
}

#[cfg(test)]
mod tests {

use crate::string_expressions;
use arrow::{array::Int32Array, datatypes::Int32Type};
use arrow_array::Int64Array;

use super::*;

Expand Down Expand Up @@ -599,4 +690,21 @@ mod tests {

Ok(())
}

#[test]
fn to_overlay() -> Result<()> {
    // Row `i` of the four input arrays forms one call:
    // overlay(source[i] placing replacement[i] from from_pos[i] for for_len[i]).
    let source = Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", "Txxxxas"]));
    let replacement =
        Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", "hom"]));
    // 1-based start positions
    let from_pos = Arc::new(Int64Array::from(vec![4, 1, 1, 2]));
    // number of characters to replace
    let for_len = Arc::new(Int64Array::from(vec![5, 7, 2, 4]));

    let output = overlay::<i32>(&[source, replacement, from_pos, for_len]).unwrap();
    let actual = as_generic_string_array::<i32>(&output).unwrap();

    let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]);
    assert_eq!(&expected, actual);

    Ok(())
}
}
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,7 @@ enum ScalarFunction {
ToTimestampNanos = 118;
ArrayIntersect = 119;
ArrayUnion = 120;
OverLay = 121;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 11 additions & 4 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ use datafusion_expr::{
factorial, flatten, floor, from_unixtime, gcd, isnan, iszero, lcm, left, ln, log,
log10, log2,
logical_plan::{PlanType, StringifiedPlan},
lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, pi, power, radians,
random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad,
rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt,
starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh,
lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power,
radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right,
round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part,
sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh,
to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos,
to_timestamp_seconds, translate, trim, trunc, upper, uuid,
window_frame::regularize,
Expand Down Expand Up @@ -546,6 +546,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::Isnan => Self::Isnan,
ScalarFunction::Iszero => Self::Iszero,
ScalarFunction::ArrowTypeof => Self::ArrowTypeof,
ScalarFunction::OverLay => Self::OverLay,
}
}
}
Expand Down Expand Up @@ -1680,6 +1681,12 @@ pub fn parse_expr(
parse_expr(&args[1], registry)?,
parse_expr(&args[2], registry)?,
)),
ScalarFunction::OverLay => Ok(overlay(
args.to_owned()
.iter()
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, _>>()?,
)),
ScalarFunction::StructFun => {
Ok(struct_fun(parse_expr(&args[0], registry)?))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::Isnan => Self::Isnan,
BuiltinScalarFunction::Iszero => Self::Iszero,
BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof,
BuiltinScalarFunction::OverLay => Self::OverLay,
};

Ok(scalar_function)
Expand Down
40 changes: 39 additions & 1 deletion datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
schema,
planner_context,
),

SQLExpr::Overlay {
expr,
overlay_what,
overlay_from,
overlay_for,
} => self.sql_overlay_to_expr(
*expr,
*overlay_what,
*overlay_from,
overlay_for,
schema,
planner_context,
),
SQLExpr::Nested(e) => {
self.sql_expr_to_logical_expr(*e, schema, planner_context)
}
Expand Down Expand Up @@ -645,6 +657,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args)))
}

/// Plans a SQL `OVERLAY(expr PLACING overlay_what FROM overlay_from [FOR overlay_for])`
/// expression as a call to the built-in `OverLay` scalar function.
///
/// The arguments are converted to logical expressions in the order
/// `expr`, `overlay_what`, `overlay_from` and, when present, `overlay_for`;
/// the resulting call therefore has three or four arguments.
fn sql_overlay_to_expr(
    &self,
    expr: SQLExpr,
    overlay_what: SQLExpr,
    overlay_from: SQLExpr,
    overlay_for: Option<Box<SQLExpr>>,
    schema: &DFSchema,
    planner_context: &mut PlannerContext,
) -> Result<Expr> {
    // The three mandatory arguments, converted left to right.
    let mut call_args = vec![
        self.sql_expr_to_logical_expr(expr, schema, planner_context)?,
        self.sql_expr_to_logical_expr(overlay_what, schema, planner_context)?,
        self.sql_expr_to_logical_expr(overlay_from, schema, planner_context)?,
    ];

    // The optional FOR clause becomes the trailing fourth argument.
    if let Some(length_expr) = overlay_for {
        call_args.push(self.sql_expr_to_logical_expr(
            *length_expr,
            schema,
            planner_context,
        )?);
    }

    Ok(Expr::ScalarFunction(ScalarFunction::new(
        BuiltinScalarFunction::OverLay,
        call_args,
    )))
}

fn sql_agg_with_filter_to_expr(
&self,
expr: SQLExpr,
Expand Down
Loading

0 comments on commit 4535551

Please sign in to comment.