diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index 4924e33df485..e17afff9b56b 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -19,6 +19,7 @@ //! expression of a \[Large\]StringArray use crate::like::StringArrayType; + use arrow_array::builder::{BooleanBufferBuilder, GenericStringBuilder, ListBuilder}; use arrow_array::cast::AsArray; use arrow_array::*; @@ -26,9 +27,93 @@ use arrow_buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Field}; use regex::Regex; + use std::collections::HashMap; use std::sync::Arc; +/// Perform SQL `array ~ regex_array` operation on [`StringArray`] / [`LargeStringArray`]. +/// If `regex_array` element has an empty value, the corresponding result value is always true. +/// +/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] flag, which allow +/// special search modes, such as case insensitive and multi-line mode. +/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) +/// for more information. +#[deprecated(since = "54.0.0", note = "please use `regex_is_match` instead")] +pub fn regexp_is_match_utf8( + array: &GenericStringArray, + regex_array: &GenericStringArray, + flags_array: Option<&GenericStringArray>, +) -> Result { + if array.len() != regex_array.len() { + return Err(ArrowError::ComputeError( + "Cannot perform comparison operation on arrays of different length".to_string(), + )); + } + let nulls = NullBuffer::union(array.nulls(), regex_array.nulls()); + + let mut patterns: HashMap = HashMap::new(); + let mut result = BooleanBufferBuilder::new(array.len()); + + let complete_pattern = match flags_array { + Some(flags) => Box::new( + regex_array + .iter() + .zip(flags.iter()) + .map(|(pattern, flags)| { + pattern.map(|pattern| match flags { + Some(flag) => format!("(?{flag}){pattern}"), + None => pattern.to_string(), + }) + }), + ) as Box>>, + None => Box::new( + regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; + + array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + // Required for Postgres compatibility: + // SELECT 'foobarbequebaz' ~ ''); = true + (Some(_), Some(pattern)) if pattern == *"" => { + result.append(true); + } + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re, + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + patterns.entry(pattern).or_insert(re) + } + }; + result.append(re.is_match(value)); + } + _ => result.append(false), + } + Ok(()) + }) + .collect::, ArrowError>>()?; + + let data = unsafe { + ArrayDataBuilder::new(DataType::Boolean) + .len(array.len()) + .buffers(vec![result.into()]) + .nulls(nulls) + .build_unchecked() + }; + Ok(BooleanArray::from(data)) +} + /// Perform SQL `array ~ regex_array` operation on /// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]. /// @@ -38,7 +123,7 @@ use std::sync::Arc; /// which allow special search modes, such as case-insensitive and multi-line mode. /// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) /// for more information. -pub fn regexp_is_match_utf8<'a, S1, S2, S3>( +pub fn regexp_is_match<'a, S1, S2, S3>( array: &'a S1, regex_array: &'a S2, flags_array: Option<&'a S3>, @@ -120,11 +205,56 @@ where Ok(BooleanArray::from(data)) } +/// Perform SQL `array ~ regex_array` operation on [`StringArray`] / +/// [`LargeStringArray`] and a scalar. +/// +/// See the documentation on [`regexp_is_match_utf8`] for more details. +#[deprecated(since = "54.0.0", note = "please use `regex_is_match_scalar` instead")] +pub fn regexp_is_match_utf8_scalar( + array: &GenericStringArray, + regex: &str, + flag: Option<&str>, +) -> Result { + let null_bit_buffer = array.nulls().map(|x| x.inner().sliced()); + let mut result = BooleanBufferBuilder::new(array.len()); + + let pattern = match flag { + Some(flag) => format!("(?{flag}){regex}"), + None => regex.to_string(), + }; + + if pattern.is_empty() { + result.append_n(array.len(), true); + } else { + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}")) + })?; + for i in 0..array.len() { + let value = array.value(i); + result.append(re.is_match(value)); + } + } + + let buffer = result.into(); + let data = unsafe { + ArrayData::new_unchecked( + DataType::Boolean, + array.len(), + None, + null_bit_buffer, + 0, + vec![buffer], + vec![], + ) + }; + Ok(BooleanArray::from(data)) +} + /// Perform SQL `array ~ regex_array` operation on /// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] and a scalar. /// -/// See the documentation on [`regexp_is_match_utf8`] for more details. -pub fn regexp_is_match_utf8_scalar<'a, S>( +/// See the documentation on [`regexp_is_match`] for more details. +pub fn regexp_is_match_scalar<'a, S>( array: &'a S, regex: &str, flag: Option<&str>, @@ -139,6 +269,7 @@ where Some(flag) => format!("(?{flag}){regex}"), None => regex.to_string(), }; + if pattern.is_empty() { result.append_n(array.len(), true); } else { @@ -163,6 +294,7 @@ where vec![], ) }; + Ok(BooleanArray::from(data)) } @@ -603,45 +735,60 @@ mod tests { } test_flag_utf8!( - test_utf8_array_regexp_is_match, + test_array_regexp_is_match_utf8, StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), - regexp_is_match_utf8::< - GenericStringArray, - GenericStringArray, - GenericStringArray, - >, + regexp_is_match_utf8, [true, false, true, false, false, true] ); test_flag_utf8!( - test_utf8_array_regexp_is_match_2, + test_array_regexp_is_match_utf8_insensitive, + StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), + StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), + StringArray::from(vec!["i"; 6]), + regexp_is_match_utf8, + [true, true, true, true, false, true] + ); + + test_flag_utf8_scalar!( + test_array_regexp_is_match_utf8_scalar, + StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), + "^ar", + regexp_is_match_utf8_scalar, + [true, false, false, false] + ); + test_flag_utf8_scalar!( + test_array_regexp_is_match_utf8_scalar_empty, + StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), + "", + regexp_is_match_utf8_scalar, + [true, true, true, true] + ); + test_flag_utf8_scalar!( + test_array_regexp_is_match_utf8_scalar_insensitive, + StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), + "^ar", + "i", + regexp_is_match_utf8_scalar, + [true, true, false, false] + ); + + test_flag_utf8!( + tes_array_regexp_is_match, StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), - regexp_is_match_utf8::, + regexp_is_match::, [true, false, true, false, false, true] ); test_flag_utf8!( - test_utf8_array_regexp_is_match_3, + test_array_regexp_is_match_2, StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), - regexp_is_match_utf8::, GenericStringArray>, + regexp_is_match::, GenericStringArray>, [true, false, true, false, false, true] ); - - test_flag_utf8!( - test_utf8_array_regexp_is_match_insensitive, - StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), - StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), - StringArray::from(vec!["i"; 6]), - regexp_is_match_utf8::< - GenericStringArray, - GenericStringArray, - GenericStringArray, - >, - [true, true, true, true, false, true] - ); test_flag_utf8!( - test_utf8_array_regexp_is_match_insensitive_2, + test_array_regexp_is_match_insensitive, StringViewArray::from(vec![ "Official Rust implementation of Apache Arrow", "apache/arrow-rs", @@ -661,27 +808,20 @@ mod tests { "" ]), StringViewArray::from(vec!["i"; 7]), - regexp_is_match_utf8::, + regexp_is_match::, [true, true, true, true, true, false, true] ); test_flag_utf8!( - test_utf8_array_regexp_is_match_insensitive_3, + test_array_regexp_is_match_insensitive_2, LargeStringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), StringArray::from(vec!["i"; 6]), - regexp_is_match_utf8::, StringViewArray, GenericStringArray>, + regexp_is_match::, StringViewArray, GenericStringArray>, [true, true, true, true, false, true] ); test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_scalar, - StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), - "^ar", - regexp_is_match_utf8_scalar::>, - [true, false, false, false] - ); - test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_scalar_2, + test_array_regexp_is_match_scalar, StringViewArray::from(vec![ "apache/arrow-rs", "APACHE/ARROW-RS", @@ -689,19 +829,11 @@ mod tests { "PARQUET", ]), "^ap", - regexp_is_match_utf8_scalar::, + regexp_is_match_scalar::, [true, false, false, false] ); - test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_empty_scalar, - StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), - "", - regexp_is_match_utf8_scalar::>, - [true, true, true, true] - ); - test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_empty_scalar_2, + test_array_regexp_is_match_scalar_empty, StringViewArray::from(vec![ "apache/arrow-rs", "APACHE/ARROW-RS", @@ -709,20 +841,11 @@ mod tests { "PARQUET", ]), "", - regexp_is_match_utf8_scalar::, + regexp_is_match_scalar::, [true, true, true, true] ); - - test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_insensitive_scalar, - StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), - "^ar", - "i", - regexp_is_match_utf8_scalar::>, - [true, true, false, false] - ); test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_insensitive_scalar_2, + test_array_regexp_is_match_scalar_insensitive, StringViewArray::from(vec![ "apache/arrow-rs", "APACHE/ARROW-RS", @@ -731,7 +854,7 @@ mod tests { ]), "^ap", "i", - regexp_is_match_utf8_scalar::, + regexp_is_match_scalar::, [true, true, false, false] ); } diff --git a/arrow/benches/comparison_kernels.rs b/arrow/benches/comparison_kernels.rs index 892c1dd40ccc..0e02a1d46163 100644 --- a/arrow/benches/comparison_kernels.rs +++ b/arrow/benches/comparison_kernels.rs @@ -15,11 +15,9 @@ // specific language governing permissions and limitations // under the License. +extern crate arrow; #[macro_use] extern crate criterion; -extern crate arrow; - -use std::time::Duration; use arrow::compute::kernels::cmp::*; use arrow::util::bench_util::*; @@ -27,8 +25,8 @@ use arrow::util::test_util::seedable_rng; use arrow::{array::*, datatypes::Float32Type, datatypes::Int32Type}; use arrow_buffer::IntervalMonthDayNano; use arrow_string::like::*; -use arrow_string::regexp::regexp_is_match_utf8_scalar; -use criterion::{Criterion, SamplingMode}; +use arrow_string::regexp::{regexp_is_match_scalar, regexp_is_match_utf8_scalar}; +use criterion::Criterion; use rand::rngs::StdRng; use rand::Rng; @@ -54,8 +52,8 @@ fn bench_nilike_utf8_scalar(arr_a: &StringArray, value_b: &str) { nilike(arr_a, &StringArray::new_scalar(value_b)).unwrap(); } -fn bench_regexp_is_match_utf8view_scalar(arr_a: &StringViewArray, value_b: &str) { - regexp_is_match_utf8_scalar( +fn bench_regexp_is_match_scalar(arr_a: &StringViewArray, value_b: &str) { + regexp_is_match_scalar( criterion::black_box(arr_a), criterion::black_box(value_b), None, @@ -357,25 +355,20 @@ fn add_benchmark(c: &mut Criterion) { // StringArray: regexp_matches_utf8 scalar benchmarks let mut group = c.benchmark_group("StringArray: regexp_matches_utf8 scalar benchmarks".to_string()); - group.sampling_mode(SamplingMode::Flat); - group.sample_size(60); - group.measurement_time(Duration::from_secs(8)); - - group.bench_function("regexp_matches_utf8 scalar starts with", |b| { - b.iter(|| bench_regexp_is_match_utf8_scalar(&arr_string, "^xx")) - }); - group.bench_function("regexp_matches_utf8 scalar contains", |b| { - b.iter(|| bench_regexp_is_match_utf8_scalar(&arr_string, ".*xx.*")) - }); - - group.bench_function("regexp_matches_utf8 scalar ends with", |b| { - b.iter(|| bench_regexp_is_match_utf8_scalar(&arr_string, "xx$")) - }); - - group.bench_function("regexp_matches_utf8 scalar complex", |b| { - b.iter(|| bench_regexp_is_match_utf8_scalar(&arr_string, ".*x{2}.xX.*xXX")) - }); + group + .bench_function("regexp_matches_utf8 scalar starts with", |b| { + b.iter(|| bench_regexp_is_match_utf8_scalar(&arr_string, "^xx")) + }) + .bench_function("regexp_matches_utf8 scalar contains", |b| { + b.iter(|| bench_regexp_is_match_utf8_scalar(&arr_string, ".*xxXX.*")) + }) + .bench_function("regexp_matches_utf8 scalar ends with", |b| { + b.iter(|| bench_regexp_is_match_utf8_scalar(&arr_string, "xx$")) + }) + .bench_function("regexp_matches_utf8 scalar complex", |b| { + b.iter(|| bench_regexp_is_match_utf8_scalar(&arr_string, ".*x{2}.xX.*xXX")) + }); group.finish(); @@ -383,21 +376,19 @@ fn add_benchmark(c: &mut Criterion) { group = c.benchmark_group("StringViewArray: regexp_matches_utf8view scalar benchmarks".to_string()); - group.bench_function("regexp_matches_utf8view scalar starts with", |b| { - b.iter(|| bench_regexp_is_match_utf8view_scalar(&arr_string_view, "^xx")) - }); - - group.bench_function("regexp_matches_utf8view scalar contains", |b| { - b.iter(|| bench_regexp_is_match_utf8view_scalar(&arr_string_view, ".*xx.*")) - }); - - group.bench_function("regexp_matches_utf8view scalar ends with", |b| { - b.iter(|| bench_regexp_is_match_utf8view_scalar(&arr_string_view, "xx$")) - }); - - group.bench_function("regexp_matches_utf8view scalar complex", |b| { - b.iter(|| bench_regexp_is_match_utf8view_scalar(&arr_string_view, ".*x{2}.xX.*xXX")) - }); + group + .bench_function("regexp_matches_utf8view scalar starts with", |b| { + b.iter(|| bench_regexp_is_match_scalar(&arr_string_view, "^xx")) + }) + .bench_function("regexp_matches_utf8view scalar contains", |b| { + b.iter(|| bench_regexp_is_match_scalar(&arr_string_view, ".*xxXX.*")) + }) + .bench_function("regexp_matches_utf8view scalar ends with", |b| { + b.iter(|| bench_regexp_is_match_scalar(&arr_string_view, "xx$")) + }) + .bench_function("regexp_matches_utf8view scalar complex", |b| { + b.iter(|| bench_regexp_is_match_scalar(&arr_string_view, ".*x{2}.xX.*xXX")) + }); group.finish();