diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index a4f724202475..991aa61fbf15 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -65,7 +65,7 @@ use std::str::FromStr; use std::sync::Arc; use substrait::proto::exchange_rel::ExchangeKind; use substrait::proto::expression::literal::user_defined::Val; -use substrait::proto::expression::literal::IntervalDayToSecond; +use substrait::proto::expression::literal::{IntervalDayToSecond, IntervalYearToMonth}; use substrait::proto::expression::subquery::SubqueryType; use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction}; use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; @@ -1414,7 +1414,7 @@ fn from_substrait_type( })?; let field = Arc::new(Field::new_list_field( from_substrait_type(inner_type, dfs_names, name_idx)?, - // We ignore Substrait's nullability here to match to_substrait_literal + // We ignore Substrait's nullability here to match to_substrait_literal // which always creates nullable lists true, )); @@ -1445,12 +1445,15 @@ fn from_substrait_type( )); match map.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => { - Ok(DataType::Map(Arc::new(Field::new_struct( - "entries", - [key_field, value_field], - false, // The inner map field is always non-nullable (Arrow #1697), - )), false)) - }, + Ok(DataType::Map( + Arc::new(Field::new_struct( + "entries", + [key_field, value_field], + false, // The inner map field is always non-nullable (Arrow #1697), + )), + false, + )) + } v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" )?, @@ -1467,14 +1470,33 @@ fn from_substrait_type( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), }, + r#type::Kind::IntervalYear(i) => match i.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => { + Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::IntervalDay(i) => match i.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => { + Ok(DataType::Interval(IntervalUnit::DayTime)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, r#type::Kind::UserDefined(u) => { match u.type_reference { + // Kept for backwards compatibility, use IntervalYear instead INTERVAL_YEAR_MONTH_TYPE_REF => { Ok(DataType::Interval(IntervalUnit::YearMonth)) } + // Kept for backwards compatibility, use IntervalDay instead INTERVAL_DAY_TIME_TYPE_REF => { Ok(DataType::Interval(IntervalUnit::DayTime)) } + // Not supported yet by Substrait INTERVAL_MONTH_DAY_NANO_TYPE_REF => { Ok(DataType::Interval(IntervalUnit::MonthDayNano)) } @@ -1484,7 +1506,7 @@ fn from_substrait_type( u.type_variation_reference ), } - }, + } r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type( s, dfs_names, name_idx, )?)), @@ -1753,11 +1775,16 @@ fn from_substrait_literal( seconds, microseconds, })) => { + // DF only supports millisecond precision, so we lose the micros here ScalarValue::new_interval_dt(*days, (seconds * 1000) + (microseconds / 1000)) } + Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => { + ScalarValue::new_interval_ym(*years, *months) + } Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), Some(LiteralType::UserDefined(user_defined)) => { match user_defined.type_reference { + // Kept for backwards compatibility, use IntervalYearToMonth instead INTERVAL_YEAR_MONTH_TYPE_REF => { let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { return substrait_err!("Interval year month value is empty"); @@ -1770,6 +1797,7 @@ fn from_substrait_literal( })?; ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes(value_slice))) } + // Kept for backwards compatibility, use IntervalDayToSecond instead INTERVAL_DAY_TIME_TYPE_REF => { let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { return substrait_err!("Interval day time value is empty"); diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 8d039a050249..7849d0bd431e 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -48,12 +48,11 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera use datafusion::prelude::Expr; use pbjson_types::Any as ProtoAny; use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; -use substrait::proto::expression::literal::user_defined::Val; -use substrait::proto::expression::literal::UserDefined; -use substrait::proto::expression::literal::{List, Struct}; +use substrait::proto::expression::literal::{ + user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Struct, UserDefined, +}; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; -use substrait::proto::r#type::{parameter, Parameter}; use substrait::proto::read_rel::VirtualTable; use substrait::proto::{CrossRel, ExchangeRel}; use substrait::{ @@ -95,9 +94,7 @@ use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_DAY_TIME_TYPE_URL, INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_URL, - INTERVAL_YEAR_MONTH_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_URL, LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, @@ -1534,47 +1531,31 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { - // define two type parameters for convenience - let i32_param = Parameter { - parameter: Some(parameter::Parameter::DataType(substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { + match interval_unit { + IntervalUnit::YearMonth => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalYear(r#type::IntervalYear { type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability: r#type::Nullability::Unspecified as i32, + nullability, })), - })), - }; - let i64_param = Parameter { - parameter: Some(parameter::Parameter::DataType(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { + }), + IntervalUnit::DayTime => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay { type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability: r#type::Nullability::Unspecified as i32, + nullability, })), - })), - }; - - let (type_parameters, type_reference) = match interval_unit { - IntervalUnit::YearMonth => { - let type_parameters = vec![i32_param]; - (type_parameters, INTERVAL_YEAR_MONTH_TYPE_REF) - } - IntervalUnit::DayTime => { - let type_parameters = vec![i64_param]; - (type_parameters, INTERVAL_DAY_TIME_TYPE_REF) - } + }), IntervalUnit::MonthDayNano => { - // use 2 `i64` as `i128` - let type_parameters = vec![i64_param.clone(), i64_param]; - (type_parameters, INTERVAL_MONTH_DAY_NANO_TYPE_REF) + // Substrait doesn't currently support this type, so we represent it as a UDT + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::UserDefined(r#type::UserDefined { + type_reference: INTERVAL_MONTH_DAY_NANO_TYPE_REF, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + type_parameters: vec![], + })), + }) } - }; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::UserDefined(r#type::UserDefined { - type_reference, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - type_parameters, - })), - }) + } } DataType::Binary => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Binary(r#type::Binary { @@ -1954,45 +1935,23 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { (LiteralType::Date(*d), DATE_32_TYPE_VARIATION_REF) } // Date64 literal is not supported in Substrait - ScalarValue::IntervalYearMonth(Some(i)) => { - let bytes = i.to_le_bytes(); - ( - LiteralType::UserDefined(UserDefined { - type_reference: INTERVAL_YEAR_MONTH_TYPE_REF, - type_parameters: vec![Parameter { - parameter: Some(parameter::Parameter::DataType( - substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability: r#type::Nullability::Required as i32, - })), - }, - )), - }], - val: Some(Val::Value(ProtoAny { - type_url: INTERVAL_YEAR_MONTH_TYPE_URL.to_string(), - value: bytes.to_vec().into(), - })), - }), - INTERVAL_YEAR_MONTH_TYPE_REF, - ) - } + ScalarValue::IntervalYearMonth(Some(i)) => ( + LiteralType::IntervalYearToMonth(IntervalYearToMonth { + // DF only tracks total months, but there should always be 12 months in a year + years: *i / 12, + months: *i % 12, + }), + DEFAULT_TYPE_VARIATION_REF, + ), ScalarValue::IntervalMonthDayNano(Some(i)) => { - // treat `i128` as two contiguous `i64` + // IntervalMonthDayNano is internally represented as a 128-bit integer, containing + // months (32bit), days (32bit), and nanoseconds (64bit) let bytes = i.to_byte_slice(); - let i64_param = Parameter { - parameter: Some(parameter::Parameter::DataType(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability: r#type::Nullability::Required as i32, - })), - })), - }; ( LiteralType::UserDefined(UserDefined { type_reference: INTERVAL_MONTH_DAY_NANO_TYPE_REF, - type_parameters: vec![i64_param.clone(), i64_param], - val: Some(Val::Value(ProtoAny { + type_parameters: vec![], + val: Some(user_defined::Val::Value(ProtoAny { type_url: INTERVAL_MONTH_DAY_NANO_TYPE_URL.to_string(), value: bytes.to_vec().into(), })), @@ -2000,29 +1959,14 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { INTERVAL_MONTH_DAY_NANO_TYPE_REF, ) } - ScalarValue::IntervalDayTime(Some(i)) => { - let bytes = i.to_byte_slice(); - ( - LiteralType::UserDefined(UserDefined { - type_reference: INTERVAL_DAY_TIME_TYPE_REF, - type_parameters: vec![Parameter { - parameter: Some(parameter::Parameter::DataType( - substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability: r#type::Nullability::Required as i32, - })), - }, - )), - }], - val: Some(Val::Value(ProtoAny { - type_url: INTERVAL_DAY_TIME_TYPE_URL.to_string(), - value: bytes.to_vec().into(), - })), - }), - INTERVAL_DAY_TIME_TYPE_REF, - ) - } + ScalarValue::IntervalDayTime(Some(i)) => ( + LiteralType::IntervalDayToSecond(IntervalDayToSecond { + days: i.days, + seconds: i.milliseconds / 1000, + microseconds: (i.milliseconds % 1000) * 1000, + }), + DEFAULT_TYPE_VARIATION_REF, + ), ScalarValue::Binary(Some(b)) => ( LiteralType::Binary(b.clone()), DEFAULT_CONTAINER_TYPE_VARIATION_REF,