From 5cc4e9f53fab29e81ea7c98baac8ce277a0cb54a Mon Sep 17 00:00:00 2001 From: QP Hou Date: Tue, 5 Oct 2021 12:40:46 -0700 Subject: [PATCH] fix pattern handling in regexp_match function (#1065) --- .../src/physical_plan/regex_expressions.rs | 70 +++++++++++++++++-- 1 file changed, 66 insertions(+), 4 deletions(-) diff --git a/datafusion/src/physical_plan/regex_expressions.rs b/datafusion/src/physical_plan/regex_expressions.rs index 69b27ffb2662..4a10d0d95b26 100644 --- a/datafusion/src/physical_plan/regex_expressions.rs +++ b/datafusion/src/physical_plan/regex_expressions.rs @@ -47,10 +47,17 @@ macro_rules! downcast_string_arg { /// extract a specific group from a string column, using a regular expression pub fn regexp_match(args: &[ArrayRef]) -> Result { match args.len() { - 2 => compute::regexp_match(downcast_string_arg!(args[0], "string", T), downcast_string_arg!(args[1], "pattern", T), None) - .map_err(DataFusionError::ArrowError), - 3 => compute::regexp_match(downcast_string_arg!(args[0], "string", T), downcast_string_arg!(args[1], "pattern", T), Some(downcast_string_arg!(args[1], "flags", T))) - .map_err(DataFusionError::ArrowError), + 2 => { + let values = downcast_string_arg!(args[0], "string", T); + let regex = downcast_string_arg!(args[1], "pattern", T); + compute::regexp_match(values, regex, None).map_err(DataFusionError::ArrowError) + } + 3 => { + let values = downcast_string_arg!(args[0], "string", T); + let regex = downcast_string_arg!(args[1], "pattern", T); + let flags = Some(downcast_string_arg!(args[2], "flags", T)); + compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError) + } other => Err(DataFusionError::Internal(format!( "regexp_match was called with {} arguments. It requires at least 2 and at most 3.", other @@ -170,3 +177,58 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result = GenericStringBuilder::new(0); + let mut expected_builder = ListBuilder::new(elem_builder); + expected_builder.values().append_value("a").unwrap(); + expected_builder.append(true).unwrap(); + expected_builder.append(false).unwrap(); + expected_builder.values().append_value("b").unwrap(); + expected_builder.append(true).unwrap(); + expected_builder.append(false).unwrap(); + expected_builder.append(false).unwrap(); + let expected = expected_builder.finish(); + + let re = regexp_match::(&[Arc::new(values), Arc::new(patterns)]).unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + + #[test] + fn test_case_insensitive_regexp_match() { + let values = StringArray::from(vec!["abc"; 5]); + let patterns = + StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); + let flags = StringArray::from(vec!["i"; 5]); + + let elem_builder: GenericStringBuilder = GenericStringBuilder::new(0); + let mut expected_builder = ListBuilder::new(elem_builder); + expected_builder.values().append_value("a").unwrap(); + expected_builder.append(true).unwrap(); + expected_builder.values().append_value("a").unwrap(); + expected_builder.append(true).unwrap(); + expected_builder.values().append_value("b").unwrap(); + expected_builder.append(true).unwrap(); + expected_builder.values().append_value("b").unwrap(); + expected_builder.append(true).unwrap(); + expected_builder.append(false).unwrap(); + let expected = expected_builder.finish(); + + let re = + regexp_match::(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) + .unwrap(); + + assert_eq!(re.as_ref(), &expected); + } +}