diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 4ae30ae86cdd..56b52bd73f9b 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -517,6 +517,9 @@ make_error!(not_impl_err, not_impl_datafusion_err, NotImplemented); // Exposes a macro to create `DataFusionError::Execution` make_error!(exec_err, exec_datafusion_err, Execution); +// Exposes a macro to create `DataFusionError::Substrait` +make_error!(substrait_err, substrait_datafusion_err, Substrait); + // Exposes a macro to create `DataFusionError::SQL` #[macro_export] macro_rules! sql_err { diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index ffc9d094ab91..f6b556fc6448 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -17,7 +17,9 @@ use async_recursion::async_recursion; use datafusion::arrow::datatypes::{DataType, Field, TimeUnit}; -use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef}; +use datafusion::common::{ + not_impl_err, substrait_datafusion_err, substrait_err, DFField, DFSchema, DFSchemaRef, +}; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ @@ -73,16 +75,7 @@ use crate::variation_const::{ enum ScalarFunctionType { Builtin(BuiltinScalarFunction), Op(Operator), - /// [Expr::Not] - Not, - /// [Expr::Like] Used for filtering rows based on the given wildcard pattern. Case sensitive - Like, - /// [Expr::Like] Case insensitive operator counterpart of `Like` - ILike, - /// [Expr::IsNull] - IsNull, - /// [Expr::IsNotNull] - IsNotNull, + Expr(BuiltinExprBuilder), } pub fn name_to_op(name: &str) -> Result { @@ -127,14 +120,11 @@ fn scalar_function_type_from_str(name: &str) -> Result { return Ok(ScalarFunctionType::Builtin(fun)); } - match name { - "not" => Ok(ScalarFunctionType::Not), - "like" => Ok(ScalarFunctionType::Like), - "ilike" => Ok(ScalarFunctionType::ILike), - "is_null" => Ok(ScalarFunctionType::IsNull), - "is_not_null" => Ok(ScalarFunctionType::IsNotNull), - others => not_impl_err!("Unsupported function name: {others:?}"), + if let Some(builder) = BuiltinExprBuilder::try_from_name(name) { + return Ok(ScalarFunctionType::Expr(builder)); } + + not_impl_err!("Unsupported function name: {name:?}") } fn split_eq_and_noneq_join_predicate_with_nulls_equality( @@ -519,9 +509,7 @@ pub async fn from_substrait_rel( }, Some(RelType::ExtensionLeaf(extension)) => { let Some(ext_detail) = &extension.detail else { - return Err(DataFusionError::Substrait( - "Unexpected empty detail in ExtensionLeafRel".to_string(), - )); + return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); }; let plan = ctx .state() @@ -531,18 +519,16 @@ pub async fn from_substrait_rel( } Some(RelType::ExtensionSingle(extension)) => { let Some(ext_detail) = &extension.detail else { - return Err(DataFusionError::Substrait( - "Unexpected empty detail in ExtensionSingleRel".to_string(), - )); + return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); }; let plan = ctx .state() .serializer_registry() .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; let Some(input_rel) = &extension.input else { - return Err(DataFusionError::Substrait( - "ExtensionSingleRel doesn't contains input rel. Try use ExtensionLeafRel instead".to_string() - )); + return substrait_err!( + "ExtensionSingleRel doesn't contains input rel. Try use ExtensionLeafRel instead" + ); }; let input_plan = from_substrait_rel(ctx, input_rel, extensions).await?; let plan = plan.from_template(&plan.expressions(), &[input_plan]); @@ -550,9 +536,7 @@ pub async fn from_substrait_rel( } Some(RelType::ExtensionMulti(extension)) => { let Some(ext_detail) = &extension.detail else { - return Err(DataFusionError::Substrait( - "Unexpected empty detail in ExtensionSingleRel".to_string(), - )); + return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); }; let plan = ctx .state() @@ -881,64 +865,8 @@ pub async fn from_substrait_rex( ), } } - ScalarFunctionType::Not => { - let arg = f.arguments.first().ok_or_else(|| { - DataFusionError::Substrait( - "expect one argument for `NOT` expr".to_string(), - ) - })?; - match &arg.arg_type { - Some(ArgType::Value(e)) => { - let expr = from_substrait_rex(e, input_schema, extensions) - .await? - .as_ref() - .clone(); - Ok(Arc::new(Expr::Not(Box::new(expr)))) - } - _ => not_impl_err!("Invalid arguments for Not expression"), - } - } - ScalarFunctionType::Like => { - make_datafusion_like(false, f, input_schema, extensions).await - } - ScalarFunctionType::ILike => { - make_datafusion_like(true, f, input_schema, extensions).await - } - ScalarFunctionType::IsNull => { - let arg = f.arguments.first().ok_or_else(|| { - DataFusionError::Substrait( - "expect one argument for `IS NULL` expr".to_string(), - ) - })?; - match &arg.arg_type { - Some(ArgType::Value(e)) => { - let expr = from_substrait_rex(e, input_schema, extensions) - .await? - .as_ref() - .clone(); - Ok(Arc::new(Expr::IsNull(Box::new(expr)))) - } - _ => not_impl_err!("Invalid arguments for IS NULL expression"), - } - } - ScalarFunctionType::IsNotNull => { - let arg = f.arguments.first().ok_or_else(|| { - DataFusionError::Substrait( - "expect one argument for `IS NOT NULL` expr".to_string(), - ) - })?; - match &arg.arg_type { - Some(ArgType::Value(e)) => { - let expr = from_substrait_rex(e, input_schema, extensions) - .await? - .as_ref() - .clone(); - Ok(Arc::new(Expr::IsNotNull(Box::new(expr)))) - } - _ => { - not_impl_err!("Invalid arguments for IS NOT NULL expression") - } - } + ScalarFunctionType::Expr(builder) => { + builder.build(f, input_schema, extensions).await } } } @@ -960,9 +888,7 @@ pub async fn from_substrait_rex( ), from_substrait_type(output_type)?, )))), - None => Err(DataFusionError::Substrait( - "Cast experssion without output type is not allowed".to_string(), - )), + None => substrait_err!("Cast experssion without output type is not allowed"), }, Some(RexType::WindowFunction(window)) => { let fun = match extensions.get(&window.function_reference) { @@ -1087,9 +1013,7 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result { r#type::Kind::List(list) => { let inner_type = from_substrait_type(list.r#type.as_ref().ok_or_else(|| { - DataFusionError::Substrait( - "List type must have inner type".to_string(), - ) + substrait_datafusion_err!("List type must have inner type") })?)?; let field = Arc::new(Field::new("list_item", inner_type, true)); match list.type_variation_reference { @@ -1141,9 +1065,7 @@ fn from_substrait_bound( } } }, - None => Err(DataFusionError::Substrait( - "WindowFunction missing Substrait Bound kind".to_string(), - )), + None => substrait_err!("WindowFunction missing Substrait Bound kind"), }, None => { if is_lower { @@ -1162,36 +1084,28 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { DEFAULT_TYPE_REF => ScalarValue::Int8(Some(*n as i8)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt8(Some(*n as u8)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I16(n)) => match lit.type_variation_reference { DEFAULT_TYPE_REF => ScalarValue::Int16(Some(*n as i16)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt16(Some(*n as u16)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I32(n)) => match lit.type_variation_reference { DEFAULT_TYPE_REF => ScalarValue::Int32(Some(*n)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt32(Some(*n as u32)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I64(n)) => match lit.type_variation_reference { DEFAULT_TYPE_REF => ScalarValue::Int64(Some(*n)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt64(Some(*n as u64)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), @@ -1202,9 +1116,7 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { TIMESTAMP_MICRO_TYPE_REF => ScalarValue::TimestampMicrosecond(Some(*t), None), TIMESTAMP_NANO_TYPE_REF => ScalarValue::TimestampNanosecond(Some(*t), None), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), @@ -1212,38 +1124,30 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Utf8(Some(s.clone())), LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeUtf8(Some(s.clone())), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Binary(b)) => match lit.type_variation_reference { DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Binary(Some(b.clone())), LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeBinary(Some(b.clone())), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::FixedBinary(b)) => { ScalarValue::FixedSizeBinary(b.len() as _, Some(b.clone())) } Some(LiteralType::Decimal(d)) => { - let value: [u8; 16] = - d.value - .clone() - .try_into() - .or(Err(DataFusionError::Substrait( - "Failed to parse decimal value".to_string(), - )))?; + let value: [u8; 16] = d + .value + .clone() + .try_into() + .or(substrait_err!("Failed to parse decimal value"))?; let p = d.precision.try_into().map_err(|e| { - DataFusionError::Substrait(format!( - "Failed to parse decimal precision: {e}" - )) + substrait_datafusion_err!("Failed to parse decimal precision: {e}") })?; let s = d.scale.try_into().map_err(|e| { - DataFusionError::Substrait(format!("Failed to parse decimal scale: {e}")) + substrait_datafusion_err!("Failed to parse decimal scale: {e}") })?; ScalarValue::Decimal128( Some(std::primitive::i128::from_le_bytes(value)), @@ -1341,50 +1245,132 @@ fn from_substrait_null(null_type: &Type) -> Result { } } -async fn make_datafusion_like( - case_insensitive: bool, - f: &ScalarFunction, - input_schema: &DFSchema, - extensions: &HashMap, -) -> Result> { - let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; - if f.arguments.len() != 3 { - return not_impl_err!("Expect three arguments for `{fn_name}` expr"); +/// Build [`Expr`] from its name and required inputs. +struct BuiltinExprBuilder { + expr_name: String, +} + +impl BuiltinExprBuilder { + pub fn try_from_name(name: &str) -> Option { + match name { + "not" | "like" | "ilike" | "is_null" | "is_not_null" => Some(Self { + expr_name: name.to_string(), + }), + _ => None, + } } - let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let expr = from_substrait_rex(expr_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); - let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); - let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let escape_char_expr = - from_substrait_rex(escape_char_substrait, input_schema, extensions) + pub async fn build( + self, + f: &ScalarFunction, + input_schema: &DFSchema, + extensions: &HashMap, + ) -> Result> { + match self.expr_name.as_str() { + "not" => Self::build_not_expr(f, input_schema, extensions).await, + "like" => Self::build_like_expr(false, f, input_schema, extensions).await, + "ilike" => Self::build_like_expr(true, f, input_schema, extensions).await, + "is_null" => { + Self::build_is_null_expr(false, f, input_schema, extensions).await + } + "is_not_null" => { + Self::build_is_null_expr(true, f, input_schema, extensions).await + } + _ => { + not_impl_err!("Unsupported builtin expression: {}", self.expr_name) + } + } + } + + async fn build_not_expr( + f: &ScalarFunction, + input_schema: &DFSchema, + extensions: &HashMap, + ) -> Result> { + if f.arguments.len() != 1 { + return not_impl_err!("Expect one argument for `NOT` expr"); + } + let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { + return not_impl_err!("Invalid arguments type for `NOT` expr"); + }; + let expr = from_substrait_rex(expr_substrait, input_schema, extensions) .await? .as_ref() .clone(); - let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else { - return Err(DataFusionError::Substrait(format!( - "Expect Utf8 literal for escape char, but found {escape_char_expr:?}", - ))); - }; + Ok(Arc::new(Expr::Not(Box::new(expr)))) + } + + async fn build_like_expr( + case_insensitive: bool, + f: &ScalarFunction, + input_schema: &DFSchema, + extensions: &HashMap, + ) -> Result> { + let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; + if f.arguments.len() != 3 { + return not_impl_err!("Expect three arguments for `{fn_name}` expr"); + } + + let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { + return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let expr = from_substrait_rex(expr_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { + return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { + return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let escape_char_expr = + from_substrait_rex(escape_char_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else { + return substrait_err!( + "Expect Utf8 literal for escape char, but found {escape_char_expr:?}" + ); + }; - Ok(Arc::new(Expr::Like(Like { - negated: false, - expr: Box::new(expr), - pattern: Box::new(pattern), - escape_char: escape_char.map(|c| c.chars().next().unwrap()), - case_insensitive, - }))) + Ok(Arc::new(Expr::Like(Like { + negated: false, + expr: Box::new(expr), + pattern: Box::new(pattern), + escape_char: escape_char.map(|c| c.chars().next().unwrap()), + case_insensitive, + }))) + } + + async fn build_is_null_expr( + is_not: bool, + f: &ScalarFunction, + input_schema: &DFSchema, + extensions: &HashMap, + ) -> Result> { + let fn_name = if is_not { "IS NOT NULL" } else { "IS NULL" }; + let arg = f.arguments.first().ok_or_else(|| { + substrait_datafusion_err!("expect one argument for `{fn_name}` expr") + })?; + match &arg.arg_type { + Some(ArgType::Value(e)) => { + let expr = from_substrait_rex(e, input_schema, extensions) + .await? + .as_ref() + .clone(); + if is_not { + Ok(Arc::new(Expr::IsNotNull(Box::new(expr)))) + } else { + Ok(Arc::new(Expr::IsNull(Box::new(expr)))) + } + } + _ => substrait_err!("Invalid arguments for `{fn_name}` expression"), + } + } }