diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index ea48007211f5..a132963c8dbe 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -32,7 +32,7 @@ use crate::datatypes::{ }; use crate::error::{ArrowError, Result}; use crate::util::bit_util; -use regex::Regex; +use regex::{escape, Regex}; use std::any::type_name; use std::collections::HashMap; @@ -263,7 +263,7 @@ where let re = if let Some(ref regex) = map.get(pat) { regex } else { - let re_pattern = pat.replace("%", ".*").replace("_", "."); + let re_pattern = escape(pat).replace("%", ".*").replace("_", "."); let re = op(&re_pattern)?; map.insert(pat, re); map.get(pat).unwrap() @@ -303,7 +303,7 @@ where /// use arrow::compute::like_utf8; /// /// let strings = StringArray::from(vec!["Arrow", "Arrow", "Arrow", "Ar"]); -/// let patterns = StringArray::from(vec!["A%", "B%", "A.", "A."]); +/// let patterns = StringArray::from(vec!["A%", "B%", "A.", "A_"]); /// /// let result = like_utf8(&strings, &patterns).unwrap(); /// assert_eq!(result, BooleanArray::from(vec![true, false, false, true])); @@ -360,11 +360,7 @@ pub fn like_utf8_scalar( } } } else { - let re_pattern = right - .replace("%", ".*") - .replace("_", ".") - .replace("(", r#"\("#) - .replace(")", r#"\)"#); + let re_pattern = escape(right).replace("%", ".*").replace("_", "."); let re = Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from LIKE pattern: {}", @@ -440,11 +436,7 @@ pub fn nlike_utf8_scalar( result.append(!left.value(i).ends_with(&right[1..])); } } else { - let re_pattern = right - .replace("%", ".*") - .replace("_", ".") - .replace("(", r#"\("#) - .replace(")", r#"\)"#); + let re_pattern = escape(right).replace("%", ".*").replace("_", "."); let re = Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from LIKE pattern: {}", @@ -525,11 +517,7 @@ pub fn ilike_utf8_scalar( ); } } else { - let re_pattern = right - .replace("%", ".*") - .replace("_", ".") - .replace("(", r#"\("#) - .replace(")", r#"\)"#); + let re_pattern = escape(right).replace("%", ".*").replace("_", "."); let re = Regex::new(&format!("(?i)^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from ILIKE pattern: {}", @@ -2235,10 +2223,10 @@ mod tests { test_utf8!( test_utf8_array_like, - vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrows", "arrow"], - vec!["arrow", "ar%", "%ro%", "foo", "arr", "arrow_", "arrow_"], + vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrows", "arrow", "arrow"], + vec!["arrow", "ar%", "%ro%", "foo", "arr", "arrow_", "arrow_", ".*"], like_utf8, - vec![true, true, true, false, false, true, false] + vec![true, true, true, false, false, true, false, false] ); test_utf8_scalar!( @@ -2248,6 +2236,23 @@ mod tests { like_utf8_scalar, vec![true, true, false, false] ); + + test_utf8_scalar!( + test_utf8_array_like_scalar_escape_regex, + vec![".*", "a", "*"], + ".*", + like_utf8_scalar, + vec![true, false, false] + ); + + test_utf8_scalar!( + test_utf8_array_like_scalar_escape_regex_dot, + vec![".", "a", "*"], + ".", + like_utf8_scalar, + vec![true, false, false] + ); + test_utf8_scalar!( test_utf8_array_like_scalar, vec!["arrow", "parquet", "datafusion", "flight"], @@ -2316,6 +2321,22 @@ mod tests { nlike_utf8_scalar, vec![false, false, true, true] ); + + test_utf8_scalar!( + test_utf8_array_nlike_scalar_escape_regex, + vec![".*", "a", "*"], + ".*", + nlike_utf8_scalar, + vec![false, true, true] + ); + + test_utf8_scalar!( + test_utf8_array_nlike_scalar_escape_regex_dot, + vec![".", "a", "*"], + ".", + nlike_utf8_scalar, + vec![false, true, true] + ); test_utf8_scalar!( test_utf8_array_nlike_scalar, vec!["arrow", "parquet", "datafusion", "flight"],