From 8af587e96fcfbb64d895fb918fc06e7a7b235517 Mon Sep 17 00:00:00 2001 From: Lordworms Date: Tue, 13 Aug 2024 21:05:29 -0700 Subject: [PATCH] Update SPLIT_PART scalar function to support Utf8View --- datafusion/functions/src/string/split_part.rs | 128 +++++++++++++----- .../sqllogictest/test_files/functions.slt | 32 +++++ .../sqllogictest/test_files/string_view.slt | 5 +- 3 files changed, 128 insertions(+), 37 deletions(-) diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index d6f7bb4a4d4a..19721f0fad28 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -21,7 +21,9 @@ use std::sync::Arc; use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::DataType; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::cast::{ + as_generic_string_array, as_int64_array, as_string_view_array, +}; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; @@ -46,7 +48,12 @@ impl SplitPartFunc { Self { signature: Signature::one_of( vec![ + Exact(vec![Utf8View, Utf8View, Int64]), + Exact(vec![Utf8View, Utf8, Int64]), + Exact(vec![Utf8View, LargeUtf8, Int64]), + Exact(vec![Utf8, Utf8View, Int64]), Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, Utf8View, Int64]), Exact(vec![LargeUtf8, Utf8, Int64]), Exact(vec![Utf8, LargeUtf8, Int64]), Exact(vec![LargeUtf8, LargeUtf8, Int64]), @@ -75,50 +82,101 @@ impl ScalarUDFImpl for SplitPartFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(split_part::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(split_part::, vec![])(args), - other => { - exec_err!("Unsupported data type {other:?} for function split_part") + match (args[0].data_type(), args[1].data_type()) { + ( + DataType::Utf8 | DataType::Utf8View, + DataType::Utf8 | DataType::Utf8View, + ) => make_scalar_function(split_part::, vec![])(args), + (DataType::LargeUtf8, DataType::LargeUtf8) => { + make_scalar_function(split_part::, vec![])(args) } + (_, DataType::LargeUtf8) => { + make_scalar_function(split_part::, vec![])(args) + } + (DataType::LargeUtf8, _) => { + make_scalar_function(split_part::, vec![])(args) + } + (first_type, second_type) => exec_err!( + "unsupported first type {} and second type {} for split_part function", + first_type, + second_type + ), } } } +macro_rules! process_split_part { + ($string_array: expr, $delimiter_array: expr, $n_array: expr) => {{ + let result = $string_array + .iter() + .zip($delimiter_array.iter()) + .zip($n_array.iter()) + .map(|((string, delimiter), n)| match (string, delimiter, n) { + (Some(string), Some(delimiter), Some(n)) => { + let split_string: Vec<&str> = string.split(delimiter).collect(); + let len = split_string.len(); + + let index = match n.cmp(&0) { + std::cmp::Ordering::Less => len as i64 + n, + std::cmp::Ordering::Equal => { + return exec_err!("field position must not be zero"); + } + std::cmp::Ordering::Greater => n - 1, + } as usize; + + if index < len { + Ok(Some(split_string[index])) + } else { + Ok(Some("")) + } + } + _ => Ok(None), + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + }}; +} + /// Splits string at occurrences of delimiter and returns the n'th field (counting from one). /// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def' -fn split_part(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; +fn split_part( + args: &[ArrayRef], +) -> Result { let n_array = as_int64_array(&args[2])?; - let result = string_array - .iter() - .zip(delimiter_array.iter()) - .zip(n_array.iter()) - .map(|((string, delimiter), n)| match (string, delimiter, n) { - (Some(string), Some(delimiter), Some(n)) => { - let split_string: Vec<&str> = string.split(delimiter).collect(); - let len = split_string.len(); - - let index = match n.cmp(&0) { - std::cmp::Ordering::Less => len as i64 + n, - std::cmp::Ordering::Equal => { - return exec_err!("field position must not be zero"); - } - std::cmp::Ordering::Greater => n - 1, - } as usize; - - if index < len { - Ok(Some(split_string[index])) - } else { - Ok(Some("")) + match (args[0].data_type(), args[1].data_type()) { + (DataType::Utf8View, _) => { + let string_array = as_string_view_array(&args[0])?; + match args[1].data_type() { + DataType::Utf8View => { + let delimiter_array = as_string_view_array(&args[1])?; + process_split_part!(string_array, delimiter_array, n_array) + } + _ => { + let delimiter_array = + as_generic_string_array::(&args[1])?; + process_split_part!(string_array, delimiter_array, n_array) } } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) + } + (_, DataType::Utf8View) => { + let delimiter_array = as_string_view_array(&args[1])?; + match args[0].data_type() { + DataType::Utf8View => { + let string_array = as_string_view_array(&args[0])?; + process_split_part!(string_array, delimiter_array, n_array) + } + _ => { + let string_array = as_generic_string_array::(&args[0])?; + process_split_part!(string_array, delimiter_array, n_array) + } + } + } + (_, _) => { + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + process_split_part!(string_array, delimiter_array, n_array) + } + } } #[cfg(test)] diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index bea3016a21d3..e135fb1aa7c2 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -788,6 +788,38 @@ SELECT split_part(arrow_cast('foo_bar', 'Dictionary(Int32, Utf8)'), '_', 2) ---- bar +# test largeutf8, utf8view for split_part +query T +SELECT split_part(arrow_cast('large_apple_large_orange_large_banana', 'LargeUtf8'), '_', 3) +---- +large + +query T +SELECT split_part(arrow_cast('view_apple_view_orange_view_banana', 'Utf8View'), '_', 3); +---- +view + +query T +SELECT split_part('test_large_split_large_case', arrow_cast('_large', 'LargeUtf8'), 2) +---- +_split + +query T +SELECT split_part(arrow_cast('huge_large_apple_large_orange_large_banana', 'LargeUtf8'), arrow_cast('_', 'Utf8View'), 2) +---- +large + +query T +SELECT split_part(arrow_cast('view_apple_view_large_banana', 'Utf8View'), arrow_cast('_large', 'LargeUtf8'), 2) +---- +_banana + +query T +SELECT split_part(NULL, '_', 2) +---- +NULL + + query B SELECT starts_with('foobar', 'foo') ---- diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 0a9b73babb96..817c014a0402 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -936,11 +936,12 @@ logical_plan ## TODO file ticket query TT EXPLAIN SELECT - SPLIT_PART(column1_utf8view, 'f', 1) as c + SPLIT_PART(column1_utf8view, 'f', 1) as c1, + SPLIT_PART('testtesttest',column1_utf8view, 1) as c2 FROM test; ---- logical_plan -01)Projection: split_part(CAST(test.column1_utf8view AS Utf8), Utf8("f"), Int64(1)) AS c +01)Projection: split_part(test.column1_utf8view, Utf8("f"), Int64(1)) AS c1, split_part(Utf8("testtesttest"), test.column1_utf8view, Int64(1)) AS c2 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for STRPOS