From b71bec0fd7d17eeab5e8002842322082cd187a25 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sat, 16 Dec 2023 03:18:08 +0800 Subject: [PATCH] feat: implement Unary Expr in substrait (#8534) Signed-off-by: Ruihang Xia --- .../substrait/src/logical_plan/consumer.rs | 74 ++++----- .../substrait/src/logical_plan/producer.rs | 141 ++++++++++++------ .../tests/cases/roundtrip_logical_plan.rs | 40 +++++ 3 files changed, 169 insertions(+), 86 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f6b556fc6448..f64dc764a7ed 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1253,7 +1253,9 @@ struct BuiltinExprBuilder { impl BuiltinExprBuilder { pub fn try_from_name(name: &str) -> Option { match name { - "not" | "like" | "ilike" | "is_null" | "is_not_null" => Some(Self { + "not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true" + | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" + | "is_not_unknown" | "negative" => Some(Self { expr_name: name.to_string(), }), _ => None, @@ -1267,14 +1269,11 @@ impl BuiltinExprBuilder { extensions: &HashMap, ) -> Result> { match self.expr_name.as_str() { - "not" => Self::build_not_expr(f, input_schema, extensions).await, "like" => Self::build_like_expr(false, f, input_schema, extensions).await, "ilike" => Self::build_like_expr(true, f, input_schema, extensions).await, - "is_null" => { - Self::build_is_null_expr(false, f, input_schema, extensions).await - } - "is_not_null" => { - Self::build_is_null_expr(true, f, input_schema, extensions).await + "not" | "negative" | "is_null" | "is_not_null" | "is_true" | "is_false" + | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => { + Self::build_unary_expr(&self.expr_name, f, input_schema, extensions).await } _ => { not_impl_err!("Unsupported builtin expression: {}", self.expr_name) @@ -1282,22 +1281,39 @@ impl BuiltinExprBuilder { } } - async fn build_not_expr( + async fn build_unary_expr( + fn_name: &str, f: &ScalarFunction, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { if f.arguments.len() != 1 { - return not_impl_err!("Expect one argument for `NOT` expr"); + return substrait_err!("Expect one argument for {fn_name} expr"); } let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return not_impl_err!("Invalid arguments type for `NOT` expr"); + return substrait_err!("Invalid arguments type for {fn_name} expr"); }; - let expr = from_substrait_rex(expr_substrait, input_schema, extensions) + let arg = from_substrait_rex(expr_substrait, input_schema, extensions) .await? .as_ref() .clone(); - Ok(Arc::new(Expr::Not(Box::new(expr)))) + let arg = Box::new(arg); + + let expr = match fn_name { + "not" => Expr::Not(arg), + "negative" => Expr::Negative(arg), + "is_null" => Expr::IsNull(arg), + "is_not_null" => Expr::IsNotNull(arg), + "is_true" => Expr::IsTrue(arg), + "is_false" => Expr::IsFalse(arg), + "is_not_true" => Expr::IsNotTrue(arg), + "is_not_false" => Expr::IsNotFalse(arg), + "is_unknown" => Expr::IsUnknown(arg), + "is_not_unknown" => Expr::IsNotUnknown(arg), + _ => return not_impl_err!("Unsupported builtin expression: {}", fn_name), + }; + + Ok(Arc::new(expr)) } async fn build_like_expr( @@ -1308,25 +1324,25 @@ impl BuiltinExprBuilder { ) -> Result> { let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; if f.arguments.len() != 3 { - return not_impl_err!("Expect three arguments for `{fn_name}` expr"); + return substrait_err!("Expect three arguments for `{fn_name}` expr"); } let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let expr = from_substrait_rex(expr_substrait, input_schema, extensions) .await? .as_ref() .clone(); let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions) .await? .as_ref() .clone(); let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let escape_char_expr = from_substrait_rex(escape_char_substrait, input_schema, extensions) @@ -1347,30 +1363,4 @@ impl BuiltinExprBuilder { case_insensitive, }))) } - - async fn build_is_null_expr( - is_not: bool, - f: &ScalarFunction, - input_schema: &DFSchema, - extensions: &HashMap, - ) -> Result> { - let fn_name = if is_not { "IS NOT NULL" } else { "IS NULL" }; - let arg = f.arguments.first().ok_or_else(|| { - substrait_datafusion_err!("expect one argument for `{fn_name}` expr") - })?; - match &arg.arg_type { - Some(ArgType::Value(e)) => { - let expr = from_substrait_rex(e, input_schema, extensions) - .await? - .as_ref() - .clone(); - if is_not { - Ok(Arc::new(Expr::IsNotNull(Box::new(expr)))) - } else { - Ok(Arc::new(Expr::IsNull(Box::new(expr)))) - } - } - _ => substrait_err!("Invalid arguments for `{fn_name}` expression"), - } - } } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index c5f1278be6e0..81498964eb61 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1083,50 +1083,76 @@ pub fn to_substrait_rex( col_ref_offset, extension_info, ), - Expr::IsNull(arg) => { - let arguments: Vec = vec![FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - arg, - schema, - col_ref_offset, - extension_info, - )?)), - }]; - - let function_name = "is_null".to_string(); - let function_anchor = _register_function(function_name, extension_info); - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - args: vec![], - options: vec![], - })), - }) - } - Expr::IsNotNull(arg) => { - let arguments: Vec = vec![FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - arg, - schema, - col_ref_offset, - extension_info, - )?)), - }]; - - let function_name = "is_not_null".to_string(); - let function_anchor = _register_function(function_name, extension_info); - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - args: vec![], - options: vec![], - })), - }) - } + Expr::Not(arg) => to_substrait_unary_scalar_fn( + "not", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNull(arg) => to_substrait_unary_scalar_fn( + "is_null", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( + "is_not_null", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( + "is_true", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( + "is_false", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( + "is_unknown", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( + "is_not_true", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( + "is_not_false", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( + "is_not_unknown", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::Negative(arg) => to_substrait_unary_scalar_fn( + "negative", + arg, + schema, + col_ref_offset, + extension_info, + ), _ => { not_impl_err!("Unsupported expression: {expr:?}") } @@ -1591,6 +1617,33 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { }) } +/// Util to generate substrait [RexType::ScalarFunction] with one argument +fn to_substrait_unary_scalar_fn( + fn_name: &str, + arg: &Expr, + schema: &DFSchemaRef, + col_ref_offset: usize, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { + let function_anchor = _register_function(fn_name.to_string(), extension_info); + let substrait_expr = to_substrait_rex(arg, schema, col_ref_offset, extension_info)?; + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_expr)), + }], + output_type: None, + options: vec![], + ..Default::default() + })), + }) +} + fn try_to_substrait_null(v: &ScalarValue) -> Result { let default_nullability = r#type::Nullability::Nullable as i32; match v { diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 691fba864449..91d5a9469627 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -483,6 +483,46 @@ async fn roundtrip_ilike() -> Result<()> { roundtrip("SELECT f FROM data WHERE f ILIKE 'a%b'").await } +#[tokio::test] +async fn roundtrip_not() -> Result<()> { + roundtrip("SELECT * FROM data WHERE NOT d").await +} + +#[tokio::test] +async fn roundtrip_negative() -> Result<()> { + roundtrip("SELECT * FROM data WHERE -a = 1").await +} + +#[tokio::test] +async fn roundtrip_is_true() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS TRUE").await +} + +#[tokio::test] +async fn roundtrip_is_false() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS FALSE").await +} + +#[tokio::test] +async fn roundtrip_is_not_true() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS NOT TRUE").await +} + +#[tokio::test] +async fn roundtrip_is_not_false() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS NOT FALSE").await +} + +#[tokio::test] +async fn roundtrip_is_unknown() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS UNKNOWN").await +} + +#[tokio::test] +async fn roundtrip_is_not_unknown() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS NOT UNKNOWN").await +} + #[tokio::test] async fn roundtrip_union() -> Result<()> { roundtrip("SELECT a, e FROM data UNION SELECT a, e FROM data").await