Skip to content

Commit

Permalink
replace TypeSignature::String with TypeSignature::Coercible (apache#1…
Browse files Browse the repository at this point in the history
…4917)

* deprecated use of TypeSignature::String

* make kernel functions private
  • Loading branch information
zjregee authored Mar 5, 2025
1 parent 5d08325 commit 7597769
Show file tree
Hide file tree
Showing 14 changed files with 378 additions and 139 deletions.
17 changes: 14 additions & 3 deletions datafusion/functions-nested/src/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use arrow::array::{
UInt8Array,
};
use arrow::datatypes::{DataType, Field};
use datafusion_expr::TypeSignature;

use datafusion_common::{
internal_datafusion_err, not_impl_err, plan_err, DataFusionError, Result,
Expand All @@ -44,8 +43,10 @@ use arrow::datatypes::DataType::{
};
use datafusion_common::cast::{as_large_list_array, as_list_array};
use datafusion_common::exec_err;
use datafusion_common::types::logical_string;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature,
TypeSignatureClass, Volatility,
};
use datafusion_functions::{downcast_arg, downcast_named_arg};
use datafusion_macros::user_doc;
Expand Down Expand Up @@ -251,7 +252,17 @@ impl StringToArray {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![TypeSignature::String(2), TypeSignature::String(3)],
vec![
TypeSignature::Coercible(vec![
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
]),
TypeSignature::Coercible(vec![
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
]),
],
Volatility::Immutable,
),
aliases: vec![String::from("string_to_list")],
Expand Down
27 changes: 20 additions & 7 deletions datafusion/functions/src/regex/regexplike.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray};
use arrow::compute::kernels::regexp;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
use datafusion_common::exec_err;
use datafusion_common::ScalarValue;
use datafusion_common::{arrow_datafusion_err, plan_err};
use datafusion_common::{internal_err, DataFusionError, Result};
use datafusion_expr::{ColumnarValue, Documentation, TypeSignature};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_common::types::logical_string;
use datafusion_common::{
arrow_datafusion_err, exec_err, internal_err, plan_err, DataFusionError, Result,
ScalarValue,
};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature,
TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;

use std::any::Any;
Expand Down Expand Up @@ -79,7 +82,17 @@ impl RegexpLikeFunc {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![TypeSignature::String(2), TypeSignature::String(3)],
vec![
TypeSignature::Coercible(vec![
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
]),
TypeSignature::Coercible(vec![
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
]),
],
Volatility::Immutable,
),
}
Expand Down
17 changes: 13 additions & 4 deletions datafusion/functions/src/string/bit_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,13 @@ use arrow::datatypes::DataType;
use std::any::Any;

use crate::utils::utf8_to_int_type;
use datafusion_common::{utils::take_function_args, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
use datafusion_common::types::logical_string;
use datafusion_common::utils::take_function_args;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;

#[user_doc(
Expand Down Expand Up @@ -55,7 +59,12 @@ impl Default for BitLengthFunc {
impl BitLengthFunc {
/// Creates the function with an immutable, single-argument signature whose
/// argument is declared as the native logical string type.
pub fn new() -> Self {
Self {
// NOTE(review): this diff hunk shows two `signature` initializers. The
// `Signature::string` line is the one removed by the commit; the
// `Signature::coercible` form below is its replacement.
signature: Signature::string(1, Volatility::Immutable),
signature: Signature::coercible(
vec![Coercion::new_exact(TypeSignatureClass::Native(
logical_string(),
))],
Volatility::Immutable,
),
}
}
}
Expand Down
85 changes: 57 additions & 28 deletions datafusion/functions/src/string/contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ use arrow::array::{Array, ArrayRef, AsArray};
use arrow::compute::contains as arrow_contains;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View};
use datafusion_common::exec_err;
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_common::types::logical_string;
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
use std::any::Any;
Expand Down Expand Up @@ -60,7 +60,13 @@ impl Default for ContainsFunc {
impl ContainsFunc {
/// Creates the function with an immutable, two-argument signature in which
/// each argument is declared as the native logical string type.
pub fn new() -> Self {
Self {
// NOTE(review): this diff hunk shows two `signature` initializers. The
// `Signature::string` line is the one removed by the commit; the
// `Signature::coercible` form below is its replacement.
signature: Signature::string(2, Volatility::Immutable),
signature: Signature::coercible(
vec![
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
],
Volatility::Immutable,
),
}
}
}
Expand Down Expand Up @@ -92,29 +98,52 @@ impl ScalarUDFImpl for ContainsFunc {
}

/// use `arrow::compute::contains` to do the calculation for contains
pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
match (args[0].data_type(), args[1].data_type()) {
(Utf8View, Utf8View) => {
let mod_str = args[0].as_string_view();
let match_str = args[1].as_string_view();
let res = arrow_contains(mod_str, match_str)?;
Ok(Arc::new(res) as ArrayRef)
}
(Utf8, Utf8) => {
let mod_str = args[0].as_string::<i32>();
let match_str = args[1].as_string::<i32>();
let res = arrow_contains(mod_str, match_str)?;
Ok(Arc::new(res) as ArrayRef)
}
(LargeUtf8, LargeUtf8) => {
let mod_str = args[0].as_string::<i64>();
let match_str = args[1].as_string::<i64>();
let res = arrow_contains(mod_str, match_str)?;
Ok(Arc::new(res) as ArrayRef)
}
other => {
exec_err!("Unsupported data type {other:?} for function `contains`.")
/// Evaluates `contains` element-wise through `arrow::compute::contains`.
///
/// The two input arrays are first brought to a common string type, chosen by
/// `string_coercion` with `binary_to_string_coercion` as the fallback. An
/// input that already has that type is shared via `Arc::clone`; otherwise it
/// is cast with the arrow cast kernel.
///
/// Indexes `args[0]` and `args[1]`, so at least two arrays must be supplied.
///
/// # Errors
/// Returns an execution error when no common type can be determined, or when
/// the common type is not `Utf8View`, `Utf8` or `LargeUtf8`. Failures from
/// the cast or the `contains` kernel are propagated.
fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
    let (lhs_type, rhs_type) = (args[0].data_type(), args[1].data_type());

    // Bail out early when the two inputs share no common string type.
    let common_type = match string_coercion(lhs_type, rhs_type)
        .or_else(|| binary_to_string_coercion(lhs_type, rhs_type))
    {
        Some(data_type) => data_type,
        None => {
            return exec_err!(
                "Unsupported data type {:?}, {:?} for function `contains`.",
                lhs_type,
                rhs_type
            )
        }
    };

    // Cast only when the input is not already of the common type.
    let to_common = |array: &ArrayRef| -> Result<ArrayRef, DataFusionError> {
        if array.data_type() == &common_type {
            Ok(Arc::clone(array))
        } else {
            Ok(arrow::compute::kernels::cast::cast(array, &common_type)?)
        }
    };
    let haystack = to_common(&args[0])?;
    let needle = to_common(&args[1])?;

    let matched = match common_type {
        Utf8View => arrow_contains(haystack.as_string_view(), needle.as_string_view())?,
        Utf8 => arrow_contains(haystack.as_string::<i32>(), needle.as_string::<i32>())?,
        LargeUtf8 => {
            arrow_contains(haystack.as_string::<i64>(), needle.as_string::<i64>())?
        }
        other => {
            return exec_err!("Unsupported data type {other:?} for function `contains`.")
        }
    };
    Ok(Arc::new(matched) as ArrayRef)
}

Expand Down
43 changes: 36 additions & 7 deletions datafusion/functions/src/string/ends_with.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ use arrow::array::ArrayRef;
use arrow::datatypes::DataType;

use crate::utils::make_scalar_function;
use datafusion_common::types::logical_string;
use datafusion_common::{internal_err, Result};
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;

#[user_doc(
Expand Down Expand Up @@ -62,7 +66,13 @@ impl Default for EndsWithFunc {
impl EndsWithFunc {
/// Creates the function with an immutable, two-argument signature in which
/// each argument is declared as the native logical string type.
pub fn new() -> Self {
Self {
// NOTE(review): this diff hunk shows two `signature` initializers. The
// `Signature::string` line is the one removed by the commit; the
// `Signature::coercible` form below is its replacement.
signature: Signature::string(2, Volatility::Immutable),
signature: Signature::coercible(
vec![
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
],
Volatility::Immutable,
),
}
}
}
Expand Down Expand Up @@ -102,10 +112,29 @@ impl ScalarUDFImpl for EndsWithFunc {

/// Returns true if string ends with suffix.
/// ends_with('alphabet', 'abet') = 't'
pub fn ends_with(args: &[ArrayRef]) -> Result<ArrayRef> {
let result = arrow::compute::kernels::comparison::ends_with(&args[0], &args[1])?;

Ok(Arc::new(result) as ArrayRef)
/// Tests element-wise whether each value of the first array ends with the
/// corresponding value of the second array.
///
/// The two input arrays are first brought to a common string type, chosen by
/// `string_coercion` with `binary_to_string_coercion` as the fallback. An
/// input that already has that type is shared via `Arc::clone`; otherwise it
/// is cast with the arrow cast kernel. The comparison itself is delegated to
/// `arrow::compute::kernels::comparison::ends_with`.
///
/// Indexes `args[0]` and `args[1]`, so at least two arrays must be supplied.
///
/// # Errors
/// Returns an internal error when no common type can be determined. Failures
/// from the cast or the comparison kernel are propagated.
fn ends_with(args: &[ArrayRef]) -> Result<ArrayRef> {
    let (lhs_type, rhs_type) = (args[0].data_type(), args[1].data_type());

    // Bail out early when the two inputs share no common string type.
    let common_type = match string_coercion(lhs_type, rhs_type)
        .or_else(|| binary_to_string_coercion(lhs_type, rhs_type))
    {
        Some(data_type) => data_type,
        None => {
            return internal_err!(
                "Unsupported data types for ends_with. Expected Utf8, LargeUtf8 or Utf8View"
            )
        }
    };

    // Cast only when the input is not already of the common type.
    let to_common = |array: &ArrayRef| -> Result<ArrayRef> {
        if array.data_type() == &common_type {
            Ok(Arc::clone(array))
        } else {
            Ok(arrow::compute::kernels::cast::cast(array, &common_type)?)
        }
    };
    let value = to_common(&args[0])?;
    let suffix = to_common(&args[1])?;

    let matched = arrow::compute::kernels::comparison::ends_with(&value, &suffix)?;
    Ok(Arc::new(matched) as ArrayRef)
}

#[cfg(test)]
Expand Down
Loading

0 comments on commit 7597769

Please sign in to comment.