Skip to content

Commit

Permalink
feat: implement Unary Expr in substrait (#8534)
Browse files Browse the repository at this point in the history
Signed-off-by: Ruihang Xia <[email protected]>
  • Loading branch information
waynexia authored Dec 15, 2023
1 parent b7fde3c commit b71bec0
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 86 deletions.
74 changes: 32 additions & 42 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1253,7 +1253,9 @@ struct BuiltinExprBuilder {
impl BuiltinExprBuilder {
pub fn try_from_name(name: &str) -> Option<Self> {
match name {
"not" | "like" | "ilike" | "is_null" | "is_not_null" => Some(Self {
"not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true"
| "is_false" | "is_not_true" | "is_not_false" | "is_unknown"
| "is_not_unknown" | "negative" => Some(Self {
expr_name: name.to_string(),
}),
_ => None,
Expand All @@ -1267,37 +1269,51 @@ impl BuiltinExprBuilder {
extensions: &HashMap<u32, &String>,
) -> Result<Arc<Expr>> {
match self.expr_name.as_str() {
"not" => Self::build_not_expr(f, input_schema, extensions).await,
"like" => Self::build_like_expr(false, f, input_schema, extensions).await,
"ilike" => Self::build_like_expr(true, f, input_schema, extensions).await,
"is_null" => {
Self::build_is_null_expr(false, f, input_schema, extensions).await
}
"is_not_null" => {
Self::build_is_null_expr(true, f, input_schema, extensions).await
"not" | "negative" | "is_null" | "is_not_null" | "is_true" | "is_false"
| "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => {
Self::build_unary_expr(&self.expr_name, f, input_schema, extensions).await
}
_ => {
not_impl_err!("Unsupported builtin expression: {}", self.expr_name)
}
}
}

async fn build_not_expr(
async fn build_unary_expr(
fn_name: &str,
f: &ScalarFunction,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
) -> Result<Arc<Expr>> {
if f.arguments.len() != 1 {
return not_impl_err!("Expect one argument for `NOT` expr");
return substrait_err!("Expect one argument for {fn_name} expr");
}
let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
return not_impl_err!("Invalid arguments type for `NOT` expr");
return substrait_err!("Invalid arguments type for {fn_name} expr");
};
let expr = from_substrait_rex(expr_substrait, input_schema, extensions)
let arg = from_substrait_rex(expr_substrait, input_schema, extensions)
.await?
.as_ref()
.clone();
Ok(Arc::new(Expr::Not(Box::new(expr))))
let arg = Box::new(arg);

let expr = match fn_name {
"not" => Expr::Not(arg),
"negative" => Expr::Negative(arg),
"is_null" => Expr::IsNull(arg),
"is_not_null" => Expr::IsNotNull(arg),
"is_true" => Expr::IsTrue(arg),
"is_false" => Expr::IsFalse(arg),
"is_not_true" => Expr::IsNotTrue(arg),
"is_not_false" => Expr::IsNotFalse(arg),
"is_unknown" => Expr::IsUnknown(arg),
"is_not_unknown" => Expr::IsNotUnknown(arg),
_ => return not_impl_err!("Unsupported builtin expression: {}", fn_name),
};

Ok(Arc::new(expr))
}

async fn build_like_expr(
Expand All @@ -1308,25 +1324,25 @@ impl BuiltinExprBuilder {
) -> Result<Arc<Expr>> {
let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" };
if f.arguments.len() != 3 {
return not_impl_err!("Expect three arguments for `{fn_name}` expr");
return substrait_err!("Expect three arguments for `{fn_name}` expr");
}

let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
return not_impl_err!("Invalid arguments type for `{fn_name}` expr");
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
};
let expr = from_substrait_rex(expr_substrait, input_schema, extensions)
.await?
.as_ref()
.clone();
let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else {
return not_impl_err!("Invalid arguments type for `{fn_name}` expr");
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
};
let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions)
.await?
.as_ref()
.clone();
let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else {
return not_impl_err!("Invalid arguments type for `{fn_name}` expr");
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
};
let escape_char_expr =
from_substrait_rex(escape_char_substrait, input_schema, extensions)
Expand All @@ -1347,30 +1363,4 @@ impl BuiltinExprBuilder {
case_insensitive,
})))
}

async fn build_is_null_expr(
is_not: bool,
f: &ScalarFunction,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
) -> Result<Arc<Expr>> {
let fn_name = if is_not { "IS NOT NULL" } else { "IS NULL" };
let arg = f.arguments.first().ok_or_else(|| {
substrait_datafusion_err!("expect one argument for `{fn_name}` expr")
})?;
match &arg.arg_type {
Some(ArgType::Value(e)) => {
let expr = from_substrait_rex(e, input_schema, extensions)
.await?
.as_ref()
.clone();
if is_not {
Ok(Arc::new(Expr::IsNotNull(Box::new(expr))))
} else {
Ok(Arc::new(Expr::IsNull(Box::new(expr))))
}
}
_ => substrait_err!("Invalid arguments for `{fn_name}` expression"),
}
}
}
141 changes: 97 additions & 44 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1083,50 +1083,76 @@ pub fn to_substrait_rex(
col_ref_offset,
extension_info,
),
Expr::IsNull(arg) => {
let arguments: Vec<FunctionArgument> = vec![FunctionArgument {
arg_type: Some(ArgType::Value(to_substrait_rex(
arg,
schema,
col_ref_offset,
extension_info,
)?)),
}];

let function_name = "is_null".to_string();
let function_anchor = _register_function(function_name, extension_info);
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
arguments,
output_type: None,
args: vec![],
options: vec![],
})),
})
}
Expr::IsNotNull(arg) => {
let arguments: Vec<FunctionArgument> = vec![FunctionArgument {
arg_type: Some(ArgType::Value(to_substrait_rex(
arg,
schema,
col_ref_offset,
extension_info,
)?)),
}];

let function_name = "is_not_null".to_string();
let function_anchor = _register_function(function_name, extension_info);
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
arguments,
output_type: None,
args: vec![],
options: vec![],
})),
})
}
Expr::Not(arg) => to_substrait_unary_scalar_fn(
"not",
arg,
schema,
col_ref_offset,
extension_info,
),
Expr::IsNull(arg) => to_substrait_unary_scalar_fn(
"is_null",
arg,
schema,
col_ref_offset,
extension_info,
),
Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn(
"is_not_null",
arg,
schema,
col_ref_offset,
extension_info,
),
Expr::IsTrue(arg) => to_substrait_unary_scalar_fn(
"is_true",
arg,
schema,
col_ref_offset,
extension_info,
),
Expr::IsFalse(arg) => to_substrait_unary_scalar_fn(
"is_false",
arg,
schema,
col_ref_offset,
extension_info,
),
Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn(
"is_unknown",
arg,
schema,
col_ref_offset,
extension_info,
),
Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn(
"is_not_true",
arg,
schema,
col_ref_offset,
extension_info,
),
Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn(
"is_not_false",
arg,
schema,
col_ref_offset,
extension_info,
),
Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn(
"is_not_unknown",
arg,
schema,
col_ref_offset,
extension_info,
),
Expr::Negative(arg) => to_substrait_unary_scalar_fn(
"negative",
arg,
schema,
col_ref_offset,
extension_info,
),
_ => {
not_impl_err!("Unsupported expression: {expr:?}")
}
Expand Down Expand Up @@ -1591,6 +1617,33 @@ fn to_substrait_literal(value: &ScalarValue) -> Result<Expression> {
})
}

/// Util to generate substrait [RexType::ScalarFunction] with one argument
fn to_substrait_unary_scalar_fn(
fn_name: &str,
arg: &Expr,
schema: &DFSchemaRef,
col_ref_offset: usize,
extension_info: &mut (
Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>,
),
) -> Result<Expression> {
let function_anchor = _register_function(fn_name.to_string(), extension_info);
let substrait_expr = to_substrait_rex(arg, schema, col_ref_offset, extension_info)?;

Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
arguments: vec![FunctionArgument {
arg_type: Some(ArgType::Value(substrait_expr)),
}],
output_type: None,
options: vec![],
..Default::default()
})),
})
}

fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
let default_nullability = r#type::Nullability::Nullable as i32;
match v {
Expand Down
40 changes: 40 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,46 @@ async fn roundtrip_ilike() -> Result<()> {
roundtrip("SELECT f FROM data WHERE f ILIKE 'a%b'").await
}

#[tokio::test]
async fn roundtrip_not() -> Result<()> {
roundtrip("SELECT * FROM data WHERE NOT d").await
}

#[tokio::test]
async fn roundtrip_negative() -> Result<()> {
roundtrip("SELECT * FROM data WHERE -a = 1").await
}

#[tokio::test]
async fn roundtrip_is_true() -> Result<()> {
roundtrip("SELECT * FROM data WHERE d IS TRUE").await
}

#[tokio::test]
async fn roundtrip_is_false() -> Result<()> {
roundtrip("SELECT * FROM data WHERE d IS FALSE").await
}

#[tokio::test]
async fn roundtrip_is_not_true() -> Result<()> {
roundtrip("SELECT * FROM data WHERE d IS NOT TRUE").await
}

#[tokio::test]
async fn roundtrip_is_not_false() -> Result<()> {
roundtrip("SELECT * FROM data WHERE d IS NOT FALSE").await
}

#[tokio::test]
async fn roundtrip_is_unknown() -> Result<()> {
roundtrip("SELECT * FROM data WHERE d IS UNKNOWN").await
}

#[tokio::test]
async fn roundtrip_is_not_unknown() -> Result<()> {
roundtrip("SELECT * FROM data WHERE d IS NOT UNKNOWN").await
}

#[tokio::test]
async fn roundtrip_union() -> Result<()> {
roundtrip("SELECT a, e FROM data UNION SELECT a, e FROM data").await
Expand Down

0 comments on commit b71bec0

Please sign in to comment.