diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 5a71ab91db1a3..a08485fd35554 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -63,6 +63,7 @@ use substrait::proto::{FunctionArgument, SortField}; use datafusion::arrow::array::GenericListArray; use datafusion::common::plan_err; +use datafusion::common::scalar::ScalarStructBuilder; use datafusion::logical_expr::expr::{InList, InSubquery, Sort}; use std::collections::HashMap; use std::str::FromStr; @@ -1159,6 +1160,15 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result { + let mut fields = vec![]; + for (i, f) in s.types.iter().enumerate() { + let field = + Field::new(&format!("c{i}"), from_substrait_type(f)?, true); + fields.push(field); + } + Ok(DataType::Struct(fields.into())) + } _ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"), }, _ => not_impl_err!("`None` Substrait kind is not supported"), @@ -1318,6 +1328,18 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { } } } + Some(LiteralType::Struct(s)) => { + let mut builder = ScalarStructBuilder::new(); + for (i, field) in s.fields.iter().enumerate() { + let sv = from_substrait_literal(field)?; + // c0, c1, ... align with e.g. SqlToRel::create_named_struct + builder = builder.with_scalar( + Field::new(&format!("c{i}"), sv.data_type(), field.nullable), + sv, + ); + } + builder.build()? + } Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?, _ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type), }; diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index bfdffdc3a260f..e216008c73dae 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -43,7 +43,7 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; -use substrait::proto::expression::literal::List; +use substrait::proto::expression::literal::{List, Struct}; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; use substrait::proto::{CrossRel, ExchangeRel}; @@ -1751,6 +1751,18 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { ScalarValue::LargeList(l) if !value.is_null() => { (convert_array_to_literal_list(l)?, LARGE_CONTAINER_TYPE_REF) } + ScalarValue::Struct(s) if !value.is_null() => ( + LiteralType::Struct(Struct { + fields: s + .columns() + .iter() + .map(|col| { + to_substrait_literal(&ScalarValue::try_from_array(col, 0)?) + }) + .collect::>>()?, + }), + DEFAULT_TYPE_REF, + ), _ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF), }; @@ -1979,6 +1991,9 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result { ScalarValue::LargeList(l) => { Ok(LiteralType::Null(to_substrait_type(l.data_type())?)) } + ScalarValue::Struct(s) => { + Ok(LiteralType::Null(to_substrait_type(s.data_type())?)) + } // TODO: Extend support for remaining data types _ => not_impl_err!("Unsupported literal: {v:?}"), } @@ -2061,6 +2076,7 @@ mod test { use crate::logical_plan::consumer::{from_substrait_literal, from_substrait_type}; use datafusion::arrow::array::GenericListArray; use datafusion::arrow::datatypes::Field; + use datafusion::common::scalar::ScalarStructBuilder; use super::*; @@ -2125,6 +2141,17 @@ mod test { ), )))?; + let c0 = Field::new("c0", DataType::Boolean, true); + let c1 = Field::new("c1", DataType::Int32, true); + let c2 = Field::new("c2", DataType::Utf8, true); + round_trip_literal( + ScalarStructBuilder::new() + .with_scalar(c0, ScalarValue::Boolean(Some(true))) + .with_scalar(c1, ScalarValue::Int32(Some(1))) + .with_scalar(c2, ScalarValue::Utf8(None)) + .build()?, + )?; + Ok(()) } @@ -2169,6 +2196,13 @@ mod test { round_trip_type(DataType::LargeList( Field::new_list_field(DataType::Int32, true).into(), ))?; + round_trip_type(DataType::Struct( + vec![ + Field::new("c0", DataType::Int32, true), + Field::new("c1", DataType::Utf8, true), + ] + .into(), + ))?; Ok(()) } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 02371063ef131..8d0e96cedd406 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -675,6 +675,16 @@ async fn roundtrip_literal_list() -> Result<()> { .await } +#[tokio::test] +async fn roundtrip_literal_struct() -> Result<()> { + assert_expected_plan( + "SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data", + "Projection: Struct({c0:1,c1:true,c2:})\ + \n TableScan: data projection=[]", + ) + .await +} + /// Construct a plan that cast columns. Only those SQL types are supported for now. #[tokio::test] async fn new_test_grammar() -> Result<()> {