Skip to content

Commit

Permalink
Add support for Substrait Struct literals and type (#10622)
Browse files Browse the repository at this point in the history
* Add support for (un-named) Substrait Struct literal

Adds support for converting from DataFusion Struct ScalarValues into Substrait Struct Literals and back.
All structs are assumed to be unnamed - ie fields are renamed
into "c0", "c1", etc

* add converting from Substrait Struct type

* cargo fmt --all

* Unit test for NULL inside Struct

* retry ci
  • Loading branch information
Blizzara authored May 23, 2024
1 parent c75a957 commit 19d9174
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 1 deletion.
22 changes: 22 additions & 0 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ use substrait::proto::{FunctionArgument, SortField};

use datafusion::arrow::array::GenericListArray;
use datafusion::common::plan_err;
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::logical_expr::expr::{InList, InSubquery, Sort};
use std::collections::HashMap;
use std::str::FromStr;
Expand Down Expand Up @@ -1159,6 +1160,15 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataTyp
"Unsupported Substrait type variation {v} of type {s_kind:?}"
),
},
r#type::Kind::Struct(s) => {
let mut fields = vec![];
for (i, f) in s.types.iter().enumerate() {
let field =
Field::new(&format!("c{i}"), from_substrait_type(f)?, true);
fields.push(field);
}
Ok(DataType::Struct(fields.into()))
}
_ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"),
},
_ => not_impl_err!("`None` Substrait kind is not supported"),
Expand Down Expand Up @@ -1318,6 +1328,18 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
}
}
}
Some(LiteralType::Struct(s)) => {
let mut builder = ScalarStructBuilder::new();
for (i, field) in s.fields.iter().enumerate() {
let sv = from_substrait_literal(field)?;
// c0, c1, ... align with e.g. SqlToRel::create_named_struct
builder = builder.with_scalar(
Field::new(&format!("c{i}"), sv.data_type(), field.nullable),
sv,
);
}
builder.build()?
}
Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
_ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type),
};
Expand Down
36 changes: 35 additions & 1 deletion datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera
use datafusion::prelude::Expr;
use prost_types::Any as ProtoAny;
use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
use substrait::proto::expression::literal::List;
use substrait::proto::expression::literal::{List, Struct};
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::{CrossRel, ExchangeRel};
Expand Down Expand Up @@ -1751,6 +1751,18 @@ fn to_substrait_literal(value: &ScalarValue) -> Result<Literal> {
ScalarValue::LargeList(l) if !value.is_null() => {
(convert_array_to_literal_list(l)?, LARGE_CONTAINER_TYPE_REF)
}
ScalarValue::Struct(s) if !value.is_null() => (
LiteralType::Struct(Struct {
fields: s
.columns()
.iter()
.map(|col| {
to_substrait_literal(&ScalarValue::try_from_array(col, 0)?)
})
.collect::<Result<Vec<_>>>()?,
}),
DEFAULT_TYPE_REF,
),
_ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF),
};

Expand Down Expand Up @@ -1979,6 +1991,9 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
ScalarValue::LargeList(l) => {
Ok(LiteralType::Null(to_substrait_type(l.data_type())?))
}
ScalarValue::Struct(s) => {
Ok(LiteralType::Null(to_substrait_type(s.data_type())?))
}
// TODO: Extend support for remaining data types
_ => not_impl_err!("Unsupported literal: {v:?}"),
}
Expand Down Expand Up @@ -2061,6 +2076,7 @@ mod test {
use crate::logical_plan::consumer::{from_substrait_literal, from_substrait_type};
use datafusion::arrow::array::GenericListArray;
use datafusion::arrow::datatypes::Field;
use datafusion::common::scalar::ScalarStructBuilder;

use super::*;

Expand Down Expand Up @@ -2125,6 +2141,17 @@ mod test {
),
)))?;

let c0 = Field::new("c0", DataType::Boolean, true);
let c1 = Field::new("c1", DataType::Int32, true);
let c2 = Field::new("c2", DataType::Utf8, true);
round_trip_literal(
ScalarStructBuilder::new()
.with_scalar(c0, ScalarValue::Boolean(Some(true)))
.with_scalar(c1, ScalarValue::Int32(Some(1)))
.with_scalar(c2, ScalarValue::Utf8(None))
.build()?,
)?;

Ok(())
}

Expand Down Expand Up @@ -2169,6 +2196,13 @@ mod test {
round_trip_type(DataType::LargeList(
Field::new_list_field(DataType::Int32, true).into(),
))?;
round_trip_type(DataType::Struct(
vec![
Field::new("c0", DataType::Int32, true),
Field::new("c1", DataType::Utf8, true),
]
.into(),
))?;

Ok(())
}
Expand Down
10 changes: 10 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,16 @@ async fn roundtrip_literal_list() -> Result<()> {
.await
}

#[tokio::test]
async fn roundtrip_literal_struct() -> Result<()> {
assert_expected_plan(
"SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data",
"Projection: Struct({c0:1,c1:true,c2:})\
\n TableScan: data projection=[]",
)
.await
}

/// Construct a plan that cast columns. Only those SQL types are supported for now.
#[tokio::test]
async fn new_test_grammar() -> Result<()> {
Expand Down

0 comments on commit 19d9174

Please sign in to comment.