From cbd22641d3673beec2cb045d658f8c585724b793 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Fri, 24 May 2024 12:22:30 +0200 Subject: [PATCH] simplify Literal/ScalarValue null handling --- .../substrait/src/logical_plan/producer.rs | 224 +++--------------- 1 file changed, 28 insertions(+), 196 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 03c544034828b..825351df39557 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1102,7 +1102,7 @@ pub fn to_substrait_rex( ))), }) } - Expr::Literal(value) => to_substrait_literal_expr(value, true), + Expr::Literal(value) => to_substrait_literal_expr(value), Expr::Alias(Alias { expr, .. }) => { to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info) } @@ -1532,10 +1532,9 @@ fn make_substrait_like_expr( }; let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; - let escape_char = to_substrait_literal_expr( - &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), - true, - )?; + let escape_char = to_substrait_literal_expr(&ScalarValue::Utf8( + escape_char.map(|c| c.to_string()), + ))?; let arguments = vec![ FunctionArgument { arg_type: Some(ArgType::Value(expr)), @@ -1691,7 +1690,17 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { )) } -fn to_substrait_literal(value: &ScalarValue, nullable: bool) -> Result { +fn to_substrait_literal(value: &ScalarValue) -> Result { + if value.is_null() { + return Ok(Literal { + nullable: true, + type_variation_reference: DEFAULT_TYPE_REF, + literal_type: Some(LiteralType::Null(to_substrait_type( + &value.data_type(), + true, + )?)), + }); + } let (literal_type, type_variation_reference) = match value { ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), DEFAULT_TYPE_REF), ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), DEFAULT_TYPE_REF), @@ -1749,34 +1758,34 @@ fn to_substrait_literal(value: &ScalarValue, nullable: bool) -> Result }), DECIMAL_128_TYPE_REF, ), - ScalarValue::List(l) if !value.is_null() => ( + ScalarValue::List(l) => ( convert_array_to_literal_list(l)?, DEFAULT_CONTAINER_TYPE_REF, ), - ScalarValue::LargeList(l) if !value.is_null() => { + ScalarValue::LargeList(l) => { (convert_array_to_literal_list(l)?, LARGE_CONTAINER_TYPE_REF) } - ScalarValue::Struct(s) if !value.is_null() => ( + ScalarValue::Struct(s) => ( LiteralType::Struct(Struct { fields: s .columns() .iter() .zip(s.fields()) .map(|(col, field)| { - to_substrait_literal( - &ScalarValue::try_from_array(col, 0)?, - field.is_nullable(), - ) + to_substrait_literal(&ScalarValue::try_from_array(col, 0)?) }) .collect::>>()?, }), DEFAULT_TYPE_REF, ), - _ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF), + _ => ( + not_impl_err!("Unsupported literal: {value:?}")?, + DEFAULT_TYPE_REF, + ), }; Ok(Literal { - nullable, + nullable: false, type_variation_reference, literal_type: Some(literal_type), }) @@ -1789,12 +1798,7 @@ fn convert_array_to_literal_list( let nested_array = array.value(0); let values = (0..nested_array.len()) - .map(|i| { - to_substrait_literal( - &ScalarValue::try_from_array(&nested_array, i)?, - array.is_nullable(), - ) - }) + .map(|i| to_substrait_literal(&ScalarValue::try_from_array(&nested_array, i)?)) .collect::>>()?; if values.is_empty() { @@ -1810,8 +1814,8 @@ fn convert_array_to_literal_list( } } -fn to_substrait_literal_expr(value: &ScalarValue, nullable: bool) -> Result { - let literal = to_substrait_literal(value, nullable)?; +fn to_substrait_literal_expr(value: &ScalarValue) -> Result { + let literal = to_substrait_literal(value)?; Ok(Expression { rex_type: Some(RexType::Literal(literal)), }) @@ -1846,175 +1850,6 @@ fn to_substrait_unary_scalar_fn( }) } -fn try_to_substrait_null(v: &ScalarValue) -> Result { - let default_nullability = r#type::Nullability::Nullable as i32; - match v { - ScalarValue::Boolean(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Bool(r#type::Boolean { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Int8(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::UInt8(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Int16(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::UInt16(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Int32(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::UInt32(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Int64(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::UInt64(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Float32(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Fp32(r#type::Fp32 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Float64(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Fp64(r#type::Fp64 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::TimestampSecond(None, _) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { - type_variation_reference: TIMESTAMP_SECOND_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::TimestampMillisecond(None, _) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { - type_variation_reference: TIMESTAMP_MILLI_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::TimestampMicrosecond(None, _) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { - type_variation_reference: TIMESTAMP_MICRO_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::TimestampNanosecond(None, _) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { - type_variation_reference: TIMESTAMP_NANO_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::Date32(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_32_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Date64(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_64_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Binary(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::LargeBinary(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: LARGE_CONTAINER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::FixedSizeBinary(_, None) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::Utf8(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::LargeUtf8(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: LARGE_CONTAINER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Decimal128(None, p, s) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Decimal(r#type::Decimal { - scale: *s as i32, - precision: *p as i32, - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::List(l) => { - Ok(LiteralType::Null(to_substrait_type(l.data_type(), true)?)) - } - ScalarValue::LargeList(l) => { - Ok(LiteralType::Null(to_substrait_type(l.data_type(), true)?)) - } - ScalarValue::Struct(s) => { - Ok(LiteralType::Null(to_substrait_type(s.data_type(), true)?)) - } - // TODO: Extend support for remaining data types - _ => not_impl_err!("Unsupported literal: {v:?}"), - } -} - /// Try to convert an [Expr] to a [FieldReference]. /// Returns `Err` if the [Expr] is not a [Expr::Column]. fn try_to_substrait_field_reference( @@ -2174,11 +2009,8 @@ mod test { fn round_trip_literal(scalar: ScalarValue) -> Result<()> { println!("Checking round trip of {scalar:?}"); - // As DataFusion doesn't consider nullability as a property of the scalar, but field, - // it doesn't matter if we set nullability to true or false here. - let substrait_literal = to_substrait_literal(&scalar, true)?; + let substrait_literal = to_substrait_literal(&scalar)?; let roundtrip_scalar = from_substrait_literal(&substrait_literal)?; - println!("Substrait literal: {substrait_literal:?}"); assert_eq!(scalar, roundtrip_scalar); Ok(()) }