From 583fc77fa022ff6f48522c69f67b56e01ada2f1f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 26 Dec 2023 10:01:27 -0800 Subject: [PATCH 1/8] Make regexp_match take Datum pattern input --- arrow-string/src/regexp.rs | 231 +++++++++++++++++++++++++++++++++---- 1 file changed, 209 insertions(+), 22 deletions(-) diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index 25c712d20f08..05f3f3d60b39 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -152,28 +152,7 @@ pub fn regexp_is_match_utf8_scalar( Ok(BooleanArray::from(data)) } -/// Extract all groups matched by a regular expression for a given String array. -/// -/// Modelled after the Postgres [regexp_match]. -/// -/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first -/// match of the corresponding index in `regex_array` to string in `array` -/// -/// If there is no match, the list element is NULL. -/// -/// If a match is found, and the pattern contains no capturing parenthesized subexpressions, -/// then the list element is a single-element [`GenericStringArray`] containing the substring -/// matching the whole pattern. -/// -/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the -/// list element is a [`GenericStringArray`] whose n'th element is the substring matching -/// the n'th capturing parenthesized subexpression of the pattern. -/// -/// The flags parameter is an optional text string containing zero or more single-letter flags -/// that change the function's behavior. -/// -/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP -pub fn regexp_match( +fn regexp_array_match( array: &GenericStringArray, regex_array: &GenericStringArray, flags_array: Option<&GenericStringArray>, @@ -248,6 +227,214 @@ pub fn regexp_match( Ok(Arc::new(list_builder.finish())) } +fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>( + regex_array: &'a dyn Array, + flag_array: Option<&'a dyn Array>, +) -> (&'a str, Option<&'a str>) { + let regex = regex_array + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray"); + let regex = regex.value(0); + + if flag_array.is_some() { + let flag = flag_array + .unwrap() + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray"); + + if flag.is_valid(0) { + let flag = flag.value(0); + (regex, Some(flag)) + } else { + (regex, None) + } + } else { + (regex, None) + } +} + +fn regexp_scalar_match( + array: &dyn Array, + regex: Option<&Regex>, +) -> std::result::Result { + if regex.is_none() {} + + let builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); + let mut list_builder = ListBuilder::new(builder); + + let array = array + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray"); + + let regex = regex.unwrap(); + + array + .iter() + .map(|value| { + match value { + // Required for Postgres compatibility: + // SELECT regexp_match('foobarbequebaz', ''); = {""} + Some(_) if regex.as_str() == "" => { + list_builder.values().append_value(""); + list_builder.append(true); + } + Some(value) => match regex.captures(value) { + Some(caps) => { + let mut iter = caps.iter(); + if caps.len() > 1 { + iter.next(); + } + for m in iter.flatten() { + list_builder.values().append_value(m.as_str()); + } + + list_builder.append(true); + } + None => list_builder.append(false), + }, + _ => list_builder.append(false), + } + Ok(()) + }) + .collect::, ArrowError>>()?; + + Ok(Arc::new(list_builder.finish())) +} + +/// Extract all groups matched by a regular expression for a given String array. +/// +/// Modelled after the Postgres [regexp_match]. +/// +/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first +/// match of the corresponding index in `regex_array` to string in `array` +/// +/// If there is no match, the list element is NULL. +/// +/// If a match is found, and the pattern contains no capturing parenthesized subexpressions, +/// then the list element is a single-element [`GenericStringArray`] containing the substring +/// matching the whole pattern. +/// +/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the +/// list element is a [`GenericStringArray`] whose n'th element is the substring matching +/// the n'th capturing parenthesized subexpression of the pattern. +/// +/// The flags parameter is an optional text string containing zero or more single-letter flags +/// that change the function's behavior. +/// +/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP +pub fn regexp_match( + array: &dyn Datum, + regex_array: &dyn Datum, + flags_array: Option<&dyn Datum>, +) -> std::result::Result { + let (lhs, is_lhs_scalar) = array.get(); + let (rhs, is_rhs_scalar) = regex_array.get(); + + let (flags, is_flags_scalar) = match flags_array { + Some(flags) => { + let (flags, is_flags_scalar) = flags.get(); + (Some(flags), Some(is_flags_scalar)) + } + None => (None, None), + }; + + if is_lhs_scalar { + return Err(ArrowError::ComputeError(format!( + "regexp_match() requires array to be either Utf8 or LargeUtf8 array instead of scalar" + ))); + } + + if is_flags_scalar.is_some() && is_rhs_scalar != is_flags_scalar.unwrap() { + return Err(ArrowError::ComputeError(format!( + "regexp_match() requires both pattern and flags to be either scalar or array" + ))); + } + + if flags_array.is_some() && rhs.data_type() != flags.unwrap().data_type() { + return Err(ArrowError::ComputeError(format!( + "regexp_match() requires both pattern and flags to be either string or largestring" + ))); + } + + if is_rhs_scalar { + // Regex and flag is scalars + let (regex, flag) = match rhs.data_type() { + DataType::Utf8 => get_scalar_pattern_flag::(rhs, flags), + DataType::LargeUtf8 => get_scalar_pattern_flag::(rhs, flags), + _ => { + return Err(ArrowError::ComputeError(format!( + "regexp_match() requires pattern to be either Utf8 or LargeUtf8" + ))); + } + }; + + let pattern = if let Some(flag) = flag { + format!("(?{regex}){flag}") + } else { + regex.to_string() + }; + + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}")) + })?; + + match lhs.data_type() { + DataType::Utf8 => regexp_scalar_match::(lhs, Some(&re)), + DataType::LargeUtf8 => regexp_scalar_match::(lhs, Some(&re)), + _ => { + return Err(ArrowError::ComputeError(format!( + "regexp_match() requires array to be either Utf8 or LargeUtf8" + ))); + } + } + } else { + match rhs.data_type() { + DataType::Utf8 => { + let array = lhs + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray"); + let regex_array = rhs + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray"); + let flags_array = flags.map(|flags| { + flags + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray") + }); + regexp_array_match(array, regex_array, flags_array) + } + DataType::LargeUtf8 => { + let array = lhs + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray"); + let regex_array = rhs + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray"); + let flags_array = flags.map(|flags| { + flags + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray") + }); + regexp_array_match(array, regex_array, flags_array) + } + _ => { + return Err(ArrowError::ComputeError(format!( + "regexp_match() requires pattern to be either Utf8 or LargeUtf8" + ))); + } + } + } +} + #[cfg(test)] mod tests { use super::*; From a456e2c93912658a9f0b444ac4775ea6094a8ef8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 26 Dec 2023 10:31:50 -0800 Subject: [PATCH 2/8] Add more tests --- arrow-string/src/regexp.rs | 74 ++++++++++++++++++++++++++++++++------ 1 file changed, 63 insertions(+), 11 deletions(-) diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index 05f3f3d60b39..0fe7e898e609 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -22,7 +22,7 @@ use arrow_array::builder::{BooleanBufferBuilder, GenericStringBuilder, ListBuild use arrow_array::*; use arrow_buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; -use arrow_schema::{ArrowError, DataType}; +use arrow_schema::{ArrowError, DataType, Field}; use regex::Regex; use std::collections::HashMap; use std::sync::Arc; @@ -230,12 +230,16 @@ fn regexp_array_match( fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>( regex_array: &'a dyn Array, flag_array: Option<&'a dyn Array>, -) -> (&'a str, Option<&'a str>) { +) -> (Option<&'a str>, Option<&'a str>) { let regex = regex_array .as_any() .downcast_ref::>() .expect("Unable to downcast to StringArray/LargeStringArray"); - let regex = regex.value(0); + let regex = if regex.is_valid(0) { + Some(regex.value(0)) + } else { + None + }; if flag_array.is_some() { let flag = flag_array @@ -257,10 +261,8 @@ fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>( fn regexp_scalar_match( array: &dyn Array, - regex: Option<&Regex>, + regex: &Regex, ) -> std::result::Result { - if regex.is_none() {} - let builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); let mut list_builder = ListBuilder::new(builder); @@ -269,8 +271,6 @@ fn regexp_scalar_match( .downcast_ref::>() .expect("Unable to downcast to StringArray/LargeStringArray"); - let regex = regex.unwrap(); - array .iter() .map(|value| { @@ -371,8 +371,17 @@ pub fn regexp_match( } }; + if regex.is_none() { + return Ok(new_null_array( + &DataType::List(Arc::new(Field::new("item", lhs.data_type().clone(), true))), + lhs.len(), + )); + } + + let regex = regex.unwrap(); + let pattern = if let Some(flag) = flag { - format!("(?{regex}){flag}") + format!("(?{flag}){regex}") } else { regex.to_string() }; @@ -382,8 +391,8 @@ pub fn regexp_match( })?; match lhs.data_type() { - DataType::Utf8 => regexp_scalar_match::(lhs, Some(&re)), - DataType::LargeUtf8 => regexp_scalar_match::(lhs, Some(&re)), + DataType::Utf8 => regexp_scalar_match::(lhs, &re), + DataType::LargeUtf8 => regexp_scalar_match::(lhs, &re), _ => { return Err(ArrowError::ComputeError(format!( "regexp_match() requires array to be either Utf8 or LargeUtf8" @@ -491,6 +500,49 @@ mod tests { assert_eq!(&expected, result); } + #[test] + fn match_scalar_pattern() { + let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; + let array = StringArray::from(values); + let pattern = Scalar::new(StringArray::from(vec![r"x.*-(\d*)-.*"; 1])); + let flags = Scalar::new(StringArray::from(vec!["i"; 1])); + let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap(); + let elem_builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); + let mut expected_builder = ListBuilder::new(elem_builder); + expected_builder.append(false); + expected_builder.values().append_value("7"); + expected_builder.append(true); + expected_builder.append(false); + expected_builder.append(false); + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + + // No flag + let values = vec![Some("abc-005-def"), Some("x-7-5"), Some("X545"), None]; + let array = StringArray::from(values); + let actual = regexp_match(&array, &pattern, None).unwrap(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + } + + #[test] + fn match_scalar_no_pattern() { + let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; + let array = StringArray::from(values); + let pattern = Scalar::new(new_null_array(&DataType::Utf8, 1)); + let actual = regexp_match(&array, &pattern, None).unwrap(); + let elem_builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); + let mut expected_builder = ListBuilder::new(elem_builder); + expected_builder.append(false); + expected_builder.append(false); + expected_builder.append(false); + expected_builder.append(false); + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + } + #[test] fn test_single_group_not_skip_match() { let array = StringArray::from(vec![Some("foo"), Some("bar")]); From 1bed0b46c3615d48e949dde517548e84e003fad2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 26 Dec 2023 10:46:04 -0800 Subject: [PATCH 3/8] More --- arrow-string/src/regexp.rs | 94 +++++++++++--------------------------- 1 file changed, 27 insertions(+), 67 deletions(-) diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index 0fe7e898e609..e29f3eaf3951 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -260,17 +260,12 @@ fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>( } fn regexp_scalar_match( - array: &dyn Array, + array: &GenericStringArray, regex: &Regex, ) -> std::result::Result { let builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); let mut list_builder = ListBuilder::new(builder); - let array = array - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to StringArray/LargeStringArray"); - array .iter() .map(|value| { @@ -325,14 +320,19 @@ fn regexp_scalar_match( /// that change the function's behavior. /// /// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP -pub fn regexp_match( - array: &dyn Datum, +pub fn regexp_match( + array: &GenericStringArray, regex_array: &dyn Datum, flags_array: Option<&dyn Datum>, ) -> std::result::Result { - let (lhs, is_lhs_scalar) = array.get(); let (rhs, is_rhs_scalar) = regex_array.get(); + if array.data_type() != rhs.data_type() { + return Err(ArrowError::ComputeError(format!( + "regexp_match() requires both array and pattern to be either Utf8 or LargeUtf8" + ))); + } + let (flags, is_flags_scalar) = match flags_array { Some(flags) => { let (flags, is_flags_scalar) = flags.get(); @@ -341,12 +341,6 @@ pub fn regexp_match( None => (None, None), }; - if is_lhs_scalar { - return Err(ArrowError::ComputeError(format!( - "regexp_match() requires array to be either Utf8 or LargeUtf8 array instead of scalar" - ))); - } - if is_flags_scalar.is_some() && is_rhs_scalar != is_flags_scalar.unwrap() { return Err(ArrowError::ComputeError(format!( "regexp_match() requires both pattern and flags to be either scalar or array" @@ -373,8 +367,12 @@ pub fn regexp_match( if regex.is_none() { return Ok(new_null_array( - &DataType::List(Arc::new(Field::new("item", lhs.data_type().clone(), true))), - lhs.len(), + &DataType::List(Arc::new(Field::new( + "item", + array.data_type().clone(), + true, + ))), + array.len(), )); } @@ -390,57 +388,19 @@ pub fn regexp_match( ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}")) })?; - match lhs.data_type() { - DataType::Utf8 => regexp_scalar_match::(lhs, &re), - DataType::LargeUtf8 => regexp_scalar_match::(lhs, &re), - _ => { - return Err(ArrowError::ComputeError(format!( - "regexp_match() requires array to be either Utf8 or LargeUtf8" - ))); - } - } + regexp_scalar_match(array, &re) } else { - match rhs.data_type() { - DataType::Utf8 => { - let array = lhs - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to StringArray/LargeStringArray"); - let regex_array = rhs - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to StringArray/LargeStringArray"); - let flags_array = flags.map(|flags| { - flags - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to StringArray/LargeStringArray") - }); - regexp_array_match(array, regex_array, flags_array) - } - DataType::LargeUtf8 => { - let array = lhs - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to StringArray/LargeStringArray"); - let regex_array = rhs - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to StringArray/LargeStringArray"); - let flags_array = flags.map(|flags| { - flags - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to StringArray/LargeStringArray") - }); - regexp_array_match(array, regex_array, flags_array) - } - _ => { - return Err(ArrowError::ComputeError(format!( - "regexp_match() requires pattern to be either Utf8 or LargeUtf8" - ))); - } - } + let regex_array = rhs + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray"); + let flags_array = flags.map(|flags| { + flags + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray") + }); + regexp_array_match(array, regex_array, flags_array) } } From c96a6dc2eb797a7d17fd89da9e125bfc949f6e18 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 26 Dec 2023 10:59:50 -0800 Subject: [PATCH 4/8] Update benchmark --- arrow-string/src/regexp.rs | 2 +- arrow/benches/regexp_kernels.rs | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index e29f3eaf3951..7ff42421a9f9 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -152,7 +152,7 @@ pub fn regexp_is_match_utf8_scalar( Ok(BooleanArray::from(data)) } -fn regexp_array_match( +pub fn regexp_array_match( array: &GenericStringArray, regex_array: &GenericStringArray, flags_array: Option<&GenericStringArray>, diff --git a/arrow/benches/regexp_kernels.rs b/arrow/benches/regexp_kernels.rs index eb38ba6783bc..d5ffbcb997ff 100644 --- a/arrow/benches/regexp_kernels.rs +++ b/arrow/benches/regexp_kernels.rs @@ -25,7 +25,7 @@ use arrow::array::*; use arrow::compute::kernels::regexp::*; use arrow::util::bench_util::*; -fn bench_regexp(arr: &GenericStringArray, regex_array: &GenericStringArray) { +fn bench_regexp(arr: &GenericStringArray, regex_array: &dyn Datum) { regexp_match(criterion::black_box(arr), regex_array, None).unwrap(); } @@ -38,6 +38,13 @@ fn add_benchmark(c: &mut Criterion) { let pattern = GenericStringArray::::from(pattern_values); c.bench_function("regexp", |b| b.iter(|| bench_regexp(&arr_string, &pattern))); + + let pattern_values = vec![r".*-(\d*)-.*"]; + let pattern = Scalar::new(GenericStringArray::::from(pattern_values)); + + c.bench_function("regexp scalar", |b| { + b.iter(|| bench_regexp(&arr_string, &pattern)) + }); } criterion_group!(benches, add_benchmark); From 91a93f0a6ec56aea87569b24905bb0a8872ea600 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 26 Dec 2023 11:04:17 -0800 Subject: [PATCH 5/8] Fix clippy --- arrow-string/src/regexp.rs | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index 7ff42421a9f9..765adf507b2c 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -328,9 +328,10 @@ pub fn regexp_match( let (rhs, is_rhs_scalar) = regex_array.get(); if array.data_type() != rhs.data_type() { - return Err(ArrowError::ComputeError(format!( + return Err(ArrowError::ComputeError( "regexp_match() requires both array and pattern to be either Utf8 or LargeUtf8" - ))); + .to_string(), + )); } let (flags, is_flags_scalar) = match flags_array { @@ -342,15 +343,17 @@ pub fn regexp_match( }; if is_flags_scalar.is_some() && is_rhs_scalar != is_flags_scalar.unwrap() { - return Err(ArrowError::ComputeError(format!( + return Err(ArrowError::ComputeError( "regexp_match() requires both pattern and flags to be either scalar or array" - ))); + .to_string(), + )); } if flags_array.is_some() && rhs.data_type() != flags.unwrap().data_type() { - return Err(ArrowError::ComputeError(format!( + return Err(ArrowError::ComputeError( "regexp_match() requires both pattern and flags to be either string or largestring" - ))); + .to_string(), + )); } if is_rhs_scalar { @@ -359,9 +362,9 @@ pub fn regexp_match( DataType::Utf8 => get_scalar_pattern_flag::(rhs, flags), DataType::LargeUtf8 => get_scalar_pattern_flag::(rhs, flags), _ => { - return Err(ArrowError::ComputeError(format!( - "regexp_match() requires pattern to be either Utf8 or LargeUtf8" - ))); + return Err(ArrowError::ComputeError( + "regexp_match() requires pattern to be either Utf8 or LargeUtf8".to_string(), + )); } }; From cab02dd151b62a9b9be3e7b63a29101b8cefea76 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 31 Dec 2023 19:12:38 -0800 Subject: [PATCH 6/8] For review --- arrow-string/src/regexp.rs | 74 +++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 38 deletions(-) diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index 765adf507b2c..40753756c742 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -19,6 +19,7 @@ //! expression of a \[Large\]StringArray use arrow_array::builder::{BooleanBufferBuilder, GenericStringBuilder, ListBuilder}; +use arrow_array::cast::AsArray; use arrow_array::*; use arrow_buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; @@ -231,29 +232,12 @@ fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>( regex_array: &'a dyn Array, flag_array: Option<&'a dyn Array>, ) -> (Option<&'a str>, Option<&'a str>) { - let regex = regex_array - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to StringArray/LargeStringArray"); - let regex = if regex.is_valid(0) { - Some(regex.value(0)) - } else { - None - }; - - if flag_array.is_some() { - let flag = flag_array - .unwrap() - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to StringArray/LargeStringArray"); + let regex = regex_array.as_string::(); + let regex = regex.is_valid(0).then(|| regex.value(0)); - if flag.is_valid(0) { - let flag = flag.value(0); - (regex, Some(flag)) - } else { - (regex, None) - } + if let Some(flag_array) = flag_array { + let flag = flag_array.as_string::(); + (regex, flag.is_valid(0).then(|| flag.value(0))) } else { (regex, None) } @@ -262,7 +246,7 @@ fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>( fn regexp_scalar_match( array: &GenericStringArray, regex: &Regex, -) -> std::result::Result { +) -> Result { let builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); let mut list_builder = ListBuilder::new(builder); @@ -320,11 +304,11 @@ fn regexp_scalar_match( /// that change the function's behavior. /// /// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP -pub fn regexp_match( - array: &GenericStringArray, +pub fn regexp_match( + array: &dyn Array, regex_array: &dyn Datum, flags_array: Option<&dyn Datum>, -) -> std::result::Result { +) -> Result { let (rhs, is_rhs_scalar) = regex_array.get(); if array.data_type() != rhs.data_type() { @@ -391,19 +375,33 @@ pub fn regexp_match( ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}")) })?; - regexp_scalar_match(array, &re) + match array.data_type() { + DataType::Utf8 => regexp_scalar_match(array.as_string::(), &re), + DataType::LargeUtf8 => regexp_scalar_match(array.as_string::(), &re), + _ => { + return Err(ArrowError::ComputeError( + "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(), + )); + } + } } else { - let regex_array = rhs - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to StringArray/LargeStringArray"); - let flags_array = flags.map(|flags| { - flags - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to StringArray/LargeStringArray") - }); - regexp_array_match(array, regex_array, flags_array) + match array.data_type() { + DataType::Utf8 => { + let regex_array = rhs.as_string(); + let flags_array = flags.map(|flags| flags.as_string()); + regexp_array_match(array.as_string::(), regex_array, flags_array) + } + DataType::LargeUtf8 => { + let regex_array = rhs.as_string(); + let flags_array = flags.map(|flags| flags.as_string()); + regexp_array_match(array.as_string::(), regex_array, flags_array) + } + _ => { + return Err(ArrowError::ComputeError( + "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(), + )); + } + } } } From e93ee10e9e25b8183f72c9d6fde6be6868f11ebe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 31 Dec 2023 19:21:42 -0800 Subject: [PATCH 7/8] Fix clippy --- arrow-string/src/regexp.rs | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index 40753756c742..13b06826cc8d 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -378,11 +378,9 @@ pub fn regexp_match( match array.data_type() { DataType::Utf8 => regexp_scalar_match(array.as_string::(), &re), DataType::LargeUtf8 => regexp_scalar_match(array.as_string::(), &re), - _ => { - return Err(ArrowError::ComputeError( - "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(), - )); - } + _ => Err(ArrowError::ComputeError( + "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(), + )), } } else { match array.data_type() { @@ -396,11 +394,9 @@ pub fn regexp_match( let flags_array = flags.map(|flags| flags.as_string()); regexp_array_match(array.as_string::(), regex_array, flags_array) } - _ => { - return Err(ArrowError::ComputeError( - "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(), - )); - } + _ => Err(ArrowError::ComputeError( + "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(), + )), } } } From ad46720c5a86f9c10a2d9ab85a7e913a56103195 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 31 Dec 2023 23:43:31 -0800 Subject: [PATCH 8/8] Don't expose utility function --- arrow-string/src/regexp.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index 13b06826cc8d..5e539b91b492 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -153,7 +153,7 @@ pub fn regexp_is_match_utf8_scalar( Ok(BooleanArray::from(data)) } -pub fn regexp_array_match( +fn regexp_array_match( array: &GenericStringArray, regex_array: &GenericStringArray, flags_array: Option<&GenericStringArray>,