Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat:implement postgres style 'overlay' string function #8117

Merged
merged 13 commits into from
Nov 14, 2023
18 changes: 17 additions & 1 deletion datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ pub enum BuiltinScalarFunction {
RegexpMatch,
/// arrow_typeof
ArrowTypeof,
/// overlay
OverLay,
}

/// Maps the sql function name to `BuiltinScalarFunction`
Expand Down Expand Up @@ -455,6 +457,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Struct => Volatility::Immutable,
BuiltinScalarFunction::FromUnixtime => Volatility::Immutable,
BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable,
BuiltinScalarFunction::OverLay => Volatility::Immutable,

// Stable builtin functions
BuiltinScalarFunction::Now => Volatility::Stable,
Expand Down Expand Up @@ -812,6 +815,10 @@ impl BuiltinScalarFunction {

BuiltinScalarFunction::Abs => Ok(input_expr_types[0].clone()),

BuiltinScalarFunction::OverLay => {
utf8_to_str_type(&input_expr_types[0], "overlay")
}

BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
Expand Down Expand Up @@ -1258,7 +1265,15 @@ impl BuiltinScalarFunction {
}
BuiltinScalarFunction::ArrowTypeof => Signature::any(1, self.volatility()),
BuiltinScalarFunction::Abs => Signature::any(1, self.volatility()),

BuiltinScalarFunction::OverLay => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8, Int64, Int64]),
Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
Exact(vec![Utf8, Utf8, Int64]),
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
],
self.volatility(),
),
BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
Expand Down Expand Up @@ -1517,6 +1532,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
BuiltinScalarFunction::Cardinality => &["cardinality"],
BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_intersect"],
BuiltinScalarFunction::OverLay => &["overlay"],

// struct functions
BuiltinScalarFunction::Struct => &["struct"],
Expand Down
7 changes: 7 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,11 @@ nary_scalar_expr!(
"concatenates several strings, placing a seperator between each one"
);
nary_scalar_expr!(Concat, concat_expr, "concatenates several strings");
// Registers the `overlay` expression-builder for `BuiltinScalarFunction::OverLay`.
// The third macro argument is the human-readable description of the function.
// NOTE(review): presumably the macro emits `pub fn overlay(args: Vec<Expr>) -> Expr`
// with this text as its rustdoc — confirm against the macro definition.
// Callers pass either (string, characters, position) or
// (string, characters, position, len).
nary_scalar_expr!(
OverLay,
overlay,
"replace the substring of string that starts at the start'th character and extends for count characters with new substring"
);

// date functions
scalar_expr!(DatePart, date_part, part date, "extracts a subfield from the date");
Expand Down Expand Up @@ -1174,6 +1179,8 @@ mod test {
test_nary_scalar_expr!(MakeArray, array, input);

test_unary_scalar_expr!(ArrowTypeof, arrow_typeof);
test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len);
test_nary_scalar_expr!(OverLay, overlay, string, characters, position);
}

#[test]
Expand Down
11 changes: 11 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,17 @@ pub fn create_physical_fun(
"{input_data_type}"
)))))
}),
BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
make_scalar_function(string_expressions::overlay::<i32>)(args)
}
DataType::LargeUtf8 => {
make_scalar_function(string_expressions::overlay::<i64>)(args)
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function overlay",
))),
}),
})
}

Expand Down
108 changes: 108 additions & 0 deletions datafusion/physical-expr/src/string_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,11 +553,102 @@ pub fn uuid(args: &[ColumnarValue]) -> Result<ColumnarValue> {
Ok(ColumnarValue::Array(Arc::new(array)))
}

/// Maps a 0-based *character* offset to the matching *byte* offset in `s`.
///
/// Offsets at or past the end of the string map to `s.len()`, so the returned
/// value is always a valid char boundary that is safe to slice on.
fn char_offset_to_byte_offset(s: &str, char_offset: usize) -> usize {
    s.char_indices()
        .nth(char_offset)
        .map_or(s.len(), |(byte_idx, _)| byte_idx)
}

/// Computes `overlay` for a single row.
///
/// * `start_pos` - 1-based character position at which `characters` is placed.
/// * `len` - number of characters of `string` to replace (capped at the
///   character length of `string`); when `None`, the character length of
///   `characters` is used instead.
///
/// All positions are counted in characters and converted to byte offsets
/// before slicing, so multi-byte UTF-8 input never panics or gets split in the
/// middle of a code point.
fn overlay_value(
    string: &str,
    characters: &str,
    start_pos: i64,
    len: Option<i64>,
) -> String {
    let string_len = string.chars().count() as i64;
    let replace_len = match len {
        Some(len) => len.min(string_len),
        None => characters.chars().count() as i64,
    };
    // Upper bound on the result size in bytes: nothing is ever duplicated
    // beyond one copy of each input.
    let mut res = String::with_capacity(string.len() + characters.len());

    // SQL positions start at 1 while string offsets start at 0.
    // NOTE(review): when `start_pos` lies past the end of `string` the prefix is
    // dropped entirely; PostgreSQL appears to keep the whole string as prefix in
    // that case — confirm before changing, the existing unit test pins the
    // current result.
    if start_pos > 1 && start_pos - 1 < string_len {
        let start = char_offset_to_byte_offset(string, (start_pos - 1) as usize);
        res.push_str(&string[..start]);
    }
    res.push_str(characters);

    // 0-based character offset of the first character kept after the replaced
    // range. Saturating arithmetic avoids i64 overflow on extreme arguments.
    let end = start_pos.saturating_add(replace_len).saturating_sub(1);
    // If the replaced range reaches the end of `string` nothing is appended.
    if end < string_len {
        // A negative offset (start_pos <= 0 or negative length) is clamped to the
        // start of the string instead of wrapping around in a `usize` cast.
        let end = char_offset_to_byte_offset(string, usize::try_from(end).unwrap_or(0));
        res.push_str(&string[end..]);
    }
    res
}

/// OVERLAY(string1 PLACING string2 FROM integer [FOR integer2])
///
/// Replaces a substring of `string1` with `string2`, starting at the 1-based
/// character position `integer`. `integer2` is the number of characters of
/// `string1` that are replaced; without the FOR option the character length of
/// `string2` is used instead.
///
/// Examples (PostgreSQL style):
/// overlay('Txxxxas' placing 'hom' from 2 for 4) -> Thomas
/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas
///
/// `args` holds 3 or 4 arrays: string, replacement string, start position
/// (Int64) and optionally the replace length (Int64). A row with a NULL in any
/// argument produces NULL.
///
/// # Errors
/// Returns an internal error when called with anything other than 3 or 4
/// arguments, and propagates downcast errors for arrays of an unexpected type.
pub fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
    match args.len() {
        3 => {
            let string_array = as_generic_string_array::<T>(&args[0])?;
            let characters_array = as_generic_string_array::<T>(&args[1])?;
            let pos_num = as_int64_array(&args[2])?;

            let result = string_array
                .iter()
                .zip(characters_array.iter())
                .zip(pos_num.iter())
                .map(|((string, characters), start_pos)| {
                    match (string, characters, start_pos) {
                        (Some(string), Some(characters), Some(start_pos)) => {
                            Some(overlay_value(string, characters, start_pos, None))
                        }
                        _ => None,
                    }
                })
                .collect::<GenericStringArray<T>>();
            Ok(Arc::new(result) as ArrayRef)
        }
        4 => {
            let string_array = as_generic_string_array::<T>(&args[0])?;
            let characters_array = as_generic_string_array::<T>(&args[1])?;
            let pos_num = as_int64_array(&args[2])?;
            let len_num = as_int64_array(&args[3])?;

            let result = string_array
                .iter()
                .zip(characters_array.iter())
                .zip(pos_num.iter())
                .zip(len_num.iter())
                .map(|(((string, characters), start_pos), len)| {
                    match (string, characters, start_pos, len) {
                        (Some(string), Some(characters), Some(start_pos), Some(len)) => {
                            Some(overlay_value(string, characters, start_pos, Some(len)))
                        }
                        _ => None,
                    }
                })
                .collect::<GenericStringArray<T>>();
            Ok(Arc::new(result) as ArrayRef)
        }
        other => {
            internal_err!(
                "overlay was called with {other} arguments. It requires 3 or 4."
            )
        }
    }
}

#[cfg(test)]
mod tests {

use crate::string_expressions;
use arrow::{array::Int32Array, datatypes::Int32Type};
use arrow_array::Int64Array;

use super::*;

Expand Down Expand Up @@ -599,4 +690,21 @@ mod tests {

Ok(())
}

#[test]
fn to_overlay() -> Result<()> {
    // One row per scenario, all using the 4-argument (FROM ... FOR ...) form.
    let source = Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", "Txxxxas"]));
    let replacement =
        Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", "hom"]));
    // 1-based start position of the replacement.
    let start_pos = Arc::new(Int64Array::from(vec![4, 1, 1, 2]));
    // Number of characters of the source string to replace.
    let replace_len = Arc::new(Int64Array::from(vec![5, 7, 2, 4]));

    let args: [ArrayRef; 4] = [source, replacement, start_pos, replace_len];
    let output = overlay::<i32>(&args).unwrap();
    let actual = as_generic_string_array::<i32>(&output).unwrap();

    let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]);
    assert_eq!(&expected, actual);

    Ok(())
}
}
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,7 @@ enum ScalarFunction {
ToTimestampNanos = 118;
ArrayIntersect = 119;
ArrayUnion = 120;
OverLay = 121;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 11 additions & 4 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ use datafusion_expr::{
factorial, flatten, floor, from_unixtime, gcd, isnan, iszero, lcm, left, ln, log,
log10, log2,
logical_plan::{PlanType, StringifiedPlan},
lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, pi, power, radians,
random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad,
rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt,
starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh,
lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power,
radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right,
round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part,
sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh,
to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos,
to_timestamp_seconds, translate, trim, trunc, upper, uuid,
window_frame::regularize,
Expand Down Expand Up @@ -546,6 +546,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::Isnan => Self::Isnan,
ScalarFunction::Iszero => Self::Iszero,
ScalarFunction::ArrowTypeof => Self::ArrowTypeof,
ScalarFunction::OverLay => Self::OverLay,
}
}
}
Expand Down Expand Up @@ -1680,6 +1681,12 @@ pub fn parse_expr(
parse_expr(&args[1], registry)?,
parse_expr(&args[2], registry)?,
)),
ScalarFunction::OverLay => Ok(overlay(
args.to_owned()
.iter()
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, _>>()?,
)),
ScalarFunction::StructFun => {
Ok(struct_fun(parse_expr(&args[0], registry)?))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::Isnan => Self::Isnan,
BuiltinScalarFunction::Iszero => Self::Iszero,
BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof,
BuiltinScalarFunction::OverLay => Self::OverLay,
};

Ok(scalar_function)
Expand Down
40 changes: 39 additions & 1 deletion datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
schema,
planner_context,
),

SQLExpr::Overlay {
expr,
overlay_what,
overlay_from,
overlay_for,
} => self.sql_overlay_to_expr(
*expr,
*overlay_what,
*overlay_from,
overlay_for,
schema,
planner_context,
),
SQLExpr::Nested(e) => {
self.sql_expr_to_logical_expr(*e, schema, planner_context)
}
Expand Down Expand Up @@ -645,6 +657,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args)))
}

/// Plans a SQL `OVERLAY(expr PLACING overlay_what FROM overlay_from [FOR overlay_for])`
/// expression as a call to `BuiltinScalarFunction::OverLay`.
///
/// The sub-expressions are planned in source order and passed as three
/// arguments, or four when the optional `FOR` expression is present.
fn sql_overlay_to_expr(
    &self,
    expr: SQLExpr,
    overlay_what: SQLExpr,
    overlay_from: SQLExpr,
    overlay_for: Option<Box<SQLExpr>>,
    schema: &DFSchema,
    planner_context: &mut PlannerContext,
) -> Result<Expr> {
    // Mandatory arguments: the target string, the replacement and the start.
    let mut call_args = vec![
        self.sql_expr_to_logical_expr(expr, schema, planner_context)?,
        self.sql_expr_to_logical_expr(overlay_what, schema, planner_context)?,
        self.sql_expr_to_logical_expr(overlay_from, schema, planner_context)?,
    ];

    // Optional `FOR <len>` argument is appended last when given.
    if let Some(len_expr) = overlay_for {
        call_args.push(self.sql_expr_to_logical_expr(
            *len_expr,
            schema,
            planner_context,
        )?);
    }

    Ok(Expr::ScalarFunction(ScalarFunction::new(
        BuiltinScalarFunction::OverLay,
        call_args,
    )))
}

fn sql_agg_with_filter_to_expr(
&self,
expr: SQLExpr,
Expand Down
Loading