Skip to content

Commit

Permalink
fix pattern handling in regexp_match function (#1065)
Browse files (browse the repository at this point in the history)
  • Loading branch information
QP Hou authored Oct 5, 2021
1 parent a8dedc8 commit 5cc4e9f
Showing 1 changed file with 66 additions and 4 deletions.
70 changes: 66 additions & 4 deletions datafusion/src/physical_plan/regex_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,17 @@ macro_rules! downcast_string_arg {
/// extract a specific group from a string column, using a regular expression
pub fn regexp_match<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
2 => compute::regexp_match(downcast_string_arg!(args[0], "string", T), downcast_string_arg!(args[1], "pattern", T), None)
.map_err(DataFusionError::ArrowError),
3 => compute::regexp_match(downcast_string_arg!(args[0], "string", T), downcast_string_arg!(args[1], "pattern", T), Some(downcast_string_arg!(args[1], "flags", T)))
.map_err(DataFusionError::ArrowError),
2 => {
let values = downcast_string_arg!(args[0], "string", T);
let regex = downcast_string_arg!(args[1], "pattern", T);
compute::regexp_match(values, regex, None).map_err(DataFusionError::ArrowError)
}
3 => {
let values = downcast_string_arg!(args[0], "string", T);
let regex = downcast_string_arg!(args[1], "pattern", T);
let flags = Some(downcast_string_arg!(args[2], "flags", T));
compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError)
}
other => Err(DataFusionError::Internal(format!(
"regexp_match was called with {} arguments. It requires at least 2 and at most 3.",
other
Expand Down Expand Up @@ -170,3 +177,58 @@ pub fn regexp_replace<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<Arr
))),
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;

    /// Regular expressions exercised by both tests, one per input row.
    const PATTERNS: [&str; 5] = ["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"];

    /// Builds the list array the kernel is expected to return.
    ///
    /// Every `Some(text)` row becomes a valid one-element list holding `text`;
    /// every `None` row becomes a null list entry.
    fn expected_matches(rows: &[Option<&str>]) -> ListArray {
        let mut list_builder = ListBuilder::new(GenericStringBuilder::<i32>::new(0));
        for &row in rows {
            match row {
                Some(captured) => {
                    list_builder.values().append_value(captured).unwrap();
                    list_builder.append(true).unwrap();
                }
                None => list_builder.append(false).unwrap(),
            }
        }
        list_builder.finish()
    }

    /// Without flags, matching is case sensitive: the upper-case patterns
    /// yield null rows for the lower-case input "abc".
    #[test]
    fn test_case_sensitive_regexp_match() {
        let input = StringArray::from(vec!["abc"; 5]);
        let regexes = StringArray::from(PATTERNS.to_vec());

        let expected = expected_matches(&[Some("a"), None, Some("b"), None, None]);

        let actual = regexp_match::<i32>(&[Arc::new(input), Arc::new(regexes)]).unwrap();

        assert_eq!(actual.as_ref(), &expected);
    }

    /// With the "i" flag supplied as the third argument, the upper-case
    /// patterns match the lower-case input as well.
    #[test]
    fn test_case_insensitive_regexp_match() {
        let input = StringArray::from(vec!["abc"; 5]);
        let regexes = StringArray::from(PATTERNS.to_vec());
        let ignore_case = StringArray::from(vec!["i"; 5]);

        let expected =
            expected_matches(&[Some("a"), Some("a"), Some("b"), Some("b"), None]);

        let actual = regexp_match::<i32>(&[
            Arc::new(input),
            Arc::new(regexes),
            Arc::new(ignore_case),
        ])
        .unwrap();

        assert_eq!(actual.as_ref(), &expected);
    }
}

0 comments on commit 5cc4e9f

Please sign in to comment.