From 6a1c0cb27392447d01dcd69ce157ea65fcac69df Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Thu, 23 May 2024 12:25:35 +0200 Subject: [PATCH] More properly handle nullability of types/literals in Substrait This isn't perfect; some things are still assumed to just always be nullable (e.g. Literal list elements). --- .../substrait/src/logical_plan/consumer.rs | 60 +++++++- .../substrait/src/logical_plan/producer.rs | 137 +++++++++++------- 2 files changed, 134 insertions(+), 63 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index a08485fd35554..3eb4b6d9729c5 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1136,11 +1136,13 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result { - let inner_type = - from_substrait_type(list.r#type.as_ref().ok_or_else(|| { - substrait_datafusion_err!("List type must have inner type") - })?)?; - let field = Arc::new(Field::new_list_field(inner_type, true)); + let inner_type = list.r#type.as_ref().ok_or_else(|| { + substrait_datafusion_err!("List type must have inner type") + })?; + let field = Arc::new(Field::new_list_field( + from_substrait_type(inner_type)?, + is_substrait_type_nullable(inner_type)?, + )); match list.type_variation_reference { DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::List(field)), LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeList(field)), @@ -1163,8 +1165,11 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result { let mut fields = vec![]; for (i, f) in s.types.iter().enumerate() { - let field = - Field::new(&format!("c{i}"), from_substrait_type(f)?, true); + let field = Field::new( + &format!("c{i}"), + from_substrait_type(f)?, + is_substrait_type_nullable(f)?, + ); fields.push(field); } Ok(DataType::Struct(fields.into())) @@ -1175,6 +1180,47 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result Result { + fn is_nullable(nullability: i32) -> bool { + nullability != substrait::proto::r#type::Nullability::Required as i32 + } + + let nullable = match dtype + .clone() + .kind + .ok_or(substrait_datafusion_err!("Type must contain Kind"))? + { + r#type::Kind::Bool(val) => is_nullable(val.nullability), + r#type::Kind::I8(val) => is_nullable(val.nullability), + r#type::Kind::I16(val) => is_nullable(val.nullability), + r#type::Kind::I32(val) => is_nullable(val.nullability), + r#type::Kind::I64(val) => is_nullable(val.nullability), + r#type::Kind::Fp32(val) => is_nullable(val.nullability), + r#type::Kind::Fp64(val) => is_nullable(val.nullability), + r#type::Kind::String(val) => is_nullable(val.nullability), + r#type::Kind::Binary(val) => is_nullable(val.nullability), + r#type::Kind::Timestamp(val) => is_nullable(val.nullability), + r#type::Kind::Date(val) => is_nullable(val.nullability), + r#type::Kind::Time(val) => is_nullable(val.nullability), + r#type::Kind::IntervalYear(val) => is_nullable(val.nullability), + r#type::Kind::IntervalDay(val) => is_nullable(val.nullability), + r#type::Kind::TimestampTz(val) => is_nullable(val.nullability), + r#type::Kind::Uuid(val) => is_nullable(val.nullability), + r#type::Kind::FixedChar(val) => is_nullable(val.nullability), + r#type::Kind::Varchar(val) => is_nullable(val.nullability), + r#type::Kind::FixedBinary(val) => is_nullable(val.nullability), + r#type::Kind::Decimal(val) => is_nullable(val.nullability), + r#type::Kind::PrecisionTimestamp(val) => is_nullable(val.nullability), + r#type::Kind::PrecisionTimestampTz(val) => is_nullable(val.nullability), + r#type::Kind::Struct(val) => is_nullable(val.nullability), + r#type::Kind::List(val) => is_nullable(val.nullability), + r#type::Kind::Map(val) => is_nullable(val.nullability), + r#type::Kind::UserDefined(val) => is_nullable(val.nullability), + r#type::Kind::UserDefinedTypeReference(_) => true, // not implemented, assume nullable + }; + Ok(nullable) +} + fn from_substrait_bound( bound: &Option, is_lower: bool, diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index e216008c73dae..03c544034828b 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1089,7 +1089,7 @@ pub fn to_substrait_rex( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type)?), + r#type: Some(to_substrait_type(data_type, true)?), input: Some(Box::new(to_substrait_rex( ctx, expr, @@ -1102,7 +1102,7 @@ pub fn to_substrait_rex( ))), }) } - Expr::Literal(value) => to_substrait_literal_expr(value), + Expr::Literal(value) => to_substrait_literal_expr(value, true), Expr::Alias(Alias { expr, .. }) => { to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info) } @@ -1296,75 +1296,79 @@ pub fn to_substrait_rex( } } -fn to_substrait_type(dt: &DataType) -> Result { - let default_nullability = r#type::Nullability::Required as i32; +fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { + let nullability = if nullable { + r#type::Nullability::Nullable as i32 + } else { + r#type::Nullability::Required as i32 + }; match dt { DataType::Null => internal_err!("Null cast is not valid"), DataType::Boolean => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Bool(r#type::Boolean { type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Int8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I8(r#type::I8 { type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::UInt8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I8(r#type::I8 { type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Int16 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I16(r#type::I16 { type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::UInt16 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I16(r#type::I16 { type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Int32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I32(r#type::I32 { type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::UInt32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I32(r#type::I32 { type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Int64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I64(r#type::I64 { type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::UInt64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I64(r#type::I64 { type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), // Float16 is not supported in Substrait DataType::Float32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Fp32(r#type::Fp32 { type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Float64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Fp64(r#type::Fp64 { type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), // Timezone is ignored. @@ -1378,90 +1382,90 @@ fn to_substrait_type(dt: &DataType) -> Result { Ok(substrait::proto::Type { kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { type_variation_reference, - nullability: default_nullability, + nullability, })), }) } DataType::Date32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Date(r#type::Date { type_variation_reference: DATE_32_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Date64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Date(r#type::Date { type_variation_reference: DATE_64_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Binary => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Binary(r#type::Binary { type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::FixedSizeBinary(length) => Ok(substrait::proto::Type { kind: Some(r#type::Kind::FixedBinary(r#type::FixedBinary { length: *length, type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::LargeBinary => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Binary(r#type::Binary { type_variation_reference: LARGE_CONTAINER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Utf8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::String(r#type::String { type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::LargeUtf8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::String(r#type::String { type_variation_reference: LARGE_CONTAINER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::List(inner) => { - let inner_type = to_substrait_type(inner.data_type())?; + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, - nullability: default_nullability, + nullability, }))), }) } DataType::LargeList(inner) => { - let inner_type = to_substrait_type(inner.data_type())?; + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), type_variation_reference: LARGE_CONTAINER_TYPE_REF, - nullability: default_nullability, + nullability, }))), }) } DataType::Struct(fields) => { let field_types = fields .iter() - .map(|field| to_substrait_type(field.data_type())) + .map(|field| to_substrait_type(field.data_type(), field.is_nullable())) .collect::>>()?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Struct(r#type::Struct { types: field_types, type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }) } DataType::Decimal128(p, s) => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Decimal(r#type::Decimal { type_variation_reference: DECIMAL_128_TYPE_REF, - nullability: default_nullability, + nullability, scale: *s as i32, precision: *p as i32, })), @@ -1469,7 +1473,7 @@ fn to_substrait_type(dt: &DataType) -> Result { DataType::Decimal256(p, s) => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Decimal(r#type::Decimal { type_variation_reference: DECIMAL_256_TYPE_REF, - nullability: default_nullability, + nullability, scale: *s as i32, precision: *p as i32, })), @@ -1528,9 +1532,10 @@ fn make_substrait_like_expr( }; let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; - let escape_char = to_substrait_literal_expr(&ScalarValue::Utf8( - escape_char.map(|c| c.to_string()), - ))?; + let escape_char = to_substrait_literal_expr( + &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), + true, + )?; let arguments = vec![ FunctionArgument { arg_type: Some(ArgType::Value(expr)), @@ -1686,7 +1691,7 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { )) } -fn to_substrait_literal(value: &ScalarValue) -> Result { +fn to_substrait_literal(value: &ScalarValue, nullable: bool) -> Result { let (literal_type, type_variation_reference) = match value { ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), DEFAULT_TYPE_REF), ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), DEFAULT_TYPE_REF), @@ -1756,8 +1761,12 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { fields: s .columns() .iter() - .map(|col| { - to_substrait_literal(&ScalarValue::try_from_array(col, 0)?) + .zip(s.fields()) + .map(|(col, field)| { + to_substrait_literal( + &ScalarValue::try_from_array(col, 0)?, + field.is_nullable(), + ) }) .collect::>>()?, }), @@ -1767,7 +1776,7 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { }; Ok(Literal { - nullable: true, + nullable, type_variation_reference, literal_type: Some(literal_type), }) @@ -1780,11 +1789,16 @@ fn convert_array_to_literal_list( let nested_array = array.value(0); let values = (0..nested_array.len()) - .map(|i| to_substrait_literal(&ScalarValue::try_from_array(&nested_array, i)?)) + .map(|i| { + to_substrait_literal( + &ScalarValue::try_from_array(&nested_array, i)?, + array.is_nullable(), + ) + }) .collect::>>()?; if values.is_empty() { - let et = match to_substrait_type(array.data_type())? { + let et = match to_substrait_type(array.data_type(), array.is_nullable())? { substrait::proto::Type { kind: Some(r#type::Kind::List(lt)), } => lt.as_ref().to_owned(), @@ -1796,8 +1810,8 @@ fn convert_array_to_literal_list( } } -fn to_substrait_literal_expr(value: &ScalarValue) -> Result { - let literal = to_substrait_literal(value)?; +fn to_substrait_literal_expr(value: &ScalarValue, nullable: bool) -> Result { + let literal = to_substrait_literal(value, nullable)?; Ok(Expression { rex_type: Some(RexType::Literal(literal)), }) @@ -1987,12 +2001,14 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result { })), })) } - ScalarValue::List(l) => Ok(LiteralType::Null(to_substrait_type(l.data_type())?)), + ScalarValue::List(l) => { + Ok(LiteralType::Null(to_substrait_type(l.data_type(), true)?)) + } ScalarValue::LargeList(l) => { - Ok(LiteralType::Null(to_substrait_type(l.data_type())?)) + Ok(LiteralType::Null(to_substrait_type(l.data_type(), true)?)) } ScalarValue::Struct(s) => { - Ok(LiteralType::Null(to_substrait_type(s.data_type())?)) + Ok(LiteralType::Null(to_substrait_type(s.data_type(), true)?)) } // TODO: Extend support for remaining data types _ => not_impl_err!("Unsupported literal: {v:?}"), @@ -2141,8 +2157,8 @@ mod test { ), )))?; - let c0 = Field::new("c0", DataType::Boolean, true); - let c1 = Field::new("c1", DataType::Int32, true); + let c0 = Field::new("c0", DataType::Boolean, false); + let c1 = Field::new("c1", DataType::Int32, false); let c2 = Field::new("c2", DataType::Utf8, true); round_trip_literal( ScalarStructBuilder::new() @@ -2158,8 +2174,11 @@ mod test { fn round_trip_literal(scalar: ScalarValue) -> Result<()> { println!("Checking round trip of {scalar:?}"); - let substrait_literal = to_substrait_literal(&scalar)?; + // As DataFusion doesn't consider nullability as a property of the scalar, but field, + // it doesn't matter if we set nullability to true or false here. + let substrait_literal = to_substrait_literal(&scalar, true)?; let roundtrip_scalar = from_substrait_literal(&substrait_literal)?; + println!("Substrait literal: {substrait_literal:?}"); assert_eq!(scalar, roundtrip_scalar); Ok(()) } @@ -2190,16 +2209,20 @@ mod test { round_trip_type(DataType::LargeUtf8)?; round_trip_type(DataType::Decimal128(10, 2))?; round_trip_type(DataType::Decimal256(30, 2))?; - round_trip_type(DataType::List( - Field::new_list_field(DataType::Int32, true).into(), - ))?; - round_trip_type(DataType::LargeList( - Field::new_list_field(DataType::Int32, true).into(), - ))?; + + for nullable in [true, false] { + round_trip_type(DataType::List( + Field::new_list_field(DataType::Int32, nullable).into(), + ))?; + round_trip_type(DataType::LargeList( + Field::new_list_field(DataType::Int32, nullable).into(), + ))?; + } + round_trip_type(DataType::Struct( vec![ Field::new("c0", DataType::Int32, true), - Field::new("c1", DataType::Utf8, true), + Field::new("c1", DataType::Utf8, false), ] .into(), ))?; @@ -2210,7 +2233,9 @@ mod test { fn round_trip_type(dt: DataType) -> Result<()> { println!("Checking round trip of {dt:?}"); - let substrait = to_substrait_type(&dt)?; + // As DataFusion doesn't consider nullability as a property of the type, but field, + // it doesn't matter if we set nullability to true or false here. + let substrait = to_substrait_type(&dt, true)?; let roundtrip_dt = from_substrait_type(&substrait)?; assert_eq!(dt, roundtrip_dt); Ok(())