Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add substrait support for Interval types and literals #10646

Merged
merged 4 commits into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 72 additions & 4 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use async_recursion::async_recursion;
use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
use datafusion::common::{
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
};
Expand All @@ -39,6 +39,7 @@ use datafusion::{
scalar::ScalarValue,
};
use substrait::proto::exchange_rel::ExchangeKind;
use substrait::proto::expression::literal::user_defined::Val;
use substrait::proto::expression::subquery::SubqueryType;
use substrait::proto::expression::{FieldReference, Literal, ScalarFunction};
use substrait::proto::{
Expand Down Expand Up @@ -71,9 +72,10 @@ use std::sync::Arc;

use crate::variation_const::{
DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF,
DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF,
TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF,
TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, INTERVAL_DAY_TIME_TYPE_REF,
INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF,
LARGE_CONTAINER_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF,
TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
};

enum ScalarFunctionType {
Expand Down Expand Up @@ -1160,6 +1162,24 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataTyp
"Unsupported Substrait type variation {v} of type {s_kind:?}"
),
},
r#type::Kind::UserDefined(u) => {
match u.type_reference {
INTERVAL_YEAR_MONTH_TYPE_REF => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may be totally wrong, but are you sure this is how type_reference is supposed to work? I'd kind of expect the type_reference to map to an extension / MappingType::ExtensionType, kinda like function_reference.

Copy link
Contributor

@alamb alamb May 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not at all sure how this is supposed to work -- when I was reviewing this PR I think I concluded it followed the existing patterns.

If you have additional information / knowledge that would help improve things I think we would welcome that help

Copy link
Member Author

@waynexia waynexia May 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two "references":

https://github.com/apache/datafusion/pull/10646/files#diff-d1c5f4c37ac8286d2045acb61bee17382179469557132eb02844413b260ae41bR1440-R1441

To my understanding, type variation (like these) are for different types from one base type. And type references (like these) are for different base types, those user-defined types.

I'll submit a patch to add document for the usage tomorrow.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, actually I think both references (type_variation_reference and type_reference) have the same problem - I hadn't realized it affects the type_variation_reference too. Now, if each system defines the type/variation references as consts, then plans will look compatible, but may have different meaning (nothing tells another Substrait producer/consumer that type_variation_reference = 1 on a list means a large list, for example).

I believe both should be defined as SimpleExtension files, following the schema (like here https://github.com/apache/arrow/blob/main/format/substrait/extension_types.yaml) rather than as constants (kinda what

//! we make use of the [simple extensions](https://substrait.io/extensions/#simple-extensions) of substrait
already says 😅). And then (I think) in the plan itself, there'd be instances of the SimpleExtensionDeclaration messages (https://github.com/substrait-io/substrait/blob/4f5b4ac4d473c9f03f30f86eca080073d99ef1c7/proto/substrait/extensions/extensions.proto#L39), and the type_reference and type_variation_reference would link to the type_anchor and type_variation_anchor, rather than to the hard-coded constants.

That way one can (in theory, at least) teach another Substrait producer/consumer about the DataFusion extensions and keep plans compatible, or at least the systems will recognize that the plans are incompatible as it refers to extension URIs that the other system doesn't know about.

Does that make sense? I'm not 100% sure of anything I'm saying here, so I may as be understanding something wrong.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe both should be defined as SimpleExtension files

Yes, that's right. The last (I'm aware of, at least) big piece we are missing in substrait is those *.yaml spec for all extended things and related URL settings. At present, all the things are defined in the document.

From the substrait website, we'll need a yaml parsing component to support extensions from other systems as well, if we are going to implement the ability to consume plans from external systems.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

going to file a tracking issue for tasks related to substrait support

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated at #8149

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah cool, so it was the plan all along :) Sg!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document PR: #10719

Ok(DataType::Interval(IntervalUnit::YearMonth))
}
INTERVAL_DAY_TIME_TYPE_REF => {
Ok(DataType::Interval(IntervalUnit::DayTime))
}
INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
Ok(DataType::Interval(IntervalUnit::MonthDayNano))
}
_ => not_impl_err!(
"Unsupported Substrait user defined type with ref {} and variation {}",
u.type_reference,
u.type_variation_reference
),
}
},
r#type::Kind::Struct(s) => {
let mut fields = vec![];
for (i, f) in s.types.iter().enumerate() {
Expand Down Expand Up @@ -1341,6 +1361,54 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
builder.build()?
}
Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
Some(LiteralType::UserDefined(user_defined)) => {
match user_defined.type_reference {
INTERVAL_YEAR_MONTH_TYPE_REF => {
let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else {
return substrait_err!("Interval year month value is empty");
};
let value_slice: [u8; 4] =
raw_val.value.clone().try_into().map_err(|_| {
substrait_datafusion_err!(
"Failed to parse interval year month value"
)
})?;
ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes(value_slice)))
}
INTERVAL_DAY_TIME_TYPE_REF => {
let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else {
return substrait_err!("Interval day time value is empty");
};
let value_slice: [u8; 8] =
raw_val.value.clone().try_into().map_err(|_| {
substrait_datafusion_err!(
"Failed to parse interval day time value"
)
})?;
ScalarValue::IntervalDayTime(Some(i64::from_le_bytes(value_slice)))
}
INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else {
return substrait_err!("Interval month day nano value is empty");
};
let value_slice: [u8; 16] =
raw_val.value.clone().try_into().map_err(|_| {
substrait_datafusion_err!(
"Failed to parse interval month day nano value"
)
})?;
ScalarValue::IntervalMonthDayNano(Some(i128::from_le_bytes(
value_slice,
)))
}
_ => {
return not_impl_err!(
"Unsupported Substrait user defined type with ref {}",
user_defined.type_reference
)
}
}
}
_ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type),
};

Expand Down
125 changes: 122 additions & 3 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use std::collections::HashMap;
use std::ops::Deref;
use std::sync::Arc;

use datafusion::arrow::datatypes::IntervalUnit;
use datafusion::logical_expr::{
CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits,
};
Expand All @@ -43,9 +44,12 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera
use datafusion::prelude::Expr;
use prost_types::Any as ProtoAny;
use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
use substrait::proto::expression::literal::user_defined::Val;
use substrait::proto::expression::literal::UserDefined;
use substrait::proto::expression::literal::{List, Struct};
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::r#type::{parameter, Parameter};
use substrait::proto::{CrossRel, ExchangeRel};
use substrait::{
proto::{
Expand Down Expand Up @@ -84,9 +88,12 @@ use substrait::{

use crate::variation_const::{
DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF,
DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF,
TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF,
TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, INTERVAL_DAY_TIME_TYPE_REF,
INTERVAL_DAY_TIME_TYPE_URL, INTERVAL_MONTH_DAY_NANO_TYPE_REF,
INTERVAL_MONTH_DAY_NANO_TYPE_URL, INTERVAL_YEAR_MONTH_TYPE_REF,
INTERVAL_YEAR_MONTH_TYPE_URL, LARGE_CONTAINER_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF,
TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF,
UNSIGNED_INTEGER_TYPE_REF,
};

/// Convert DataFusion LogicalPlan to Substrait Plan
Expand Down Expand Up @@ -1394,6 +1401,49 @@ fn to_substrait_type(dt: &DataType) -> Result<substrait::proto::Type> {
nullability: default_nullability,
})),
}),
DataType::Interval(interval_unit) => {
// define two type parameters for convenience
let i32_param = Parameter {
parameter: Some(parameter::Parameter::DataType(substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
};
let i64_param = Parameter {
parameter: Some(parameter::Parameter::DataType(substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
};

let (type_parameters, type_reference) = match interval_unit {
IntervalUnit::YearMonth => {
let type_parameters = vec![i32_param];
(type_parameters, INTERVAL_YEAR_MONTH_TYPE_REF)
}
IntervalUnit::DayTime => {
let type_parameters = vec![i64_param];
(type_parameters, INTERVAL_DAY_TIME_TYPE_REF)
}
IntervalUnit::MonthDayNano => {
// use 2 `i64` as `i128`
let type_parameters = vec![i64_param.clone(), i64_param];
(type_parameters, INTERVAL_MONTH_DAY_NANO_TYPE_REF)
}
};
Ok(substrait::proto::Type {
kind: Some(r#type::Kind::UserDefined(r#type::UserDefined {
type_reference,
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
type_parameters,
})),
})
}
DataType::Binary => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Binary(r#type::Binary {
type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
Expand Down Expand Up @@ -1721,6 +1771,75 @@ fn to_substrait_literal(value: &ScalarValue) -> Result<Literal> {
}
ScalarValue::Date32(Some(d)) => (LiteralType::Date(*d), DATE_32_TYPE_REF),
// Date64 literal is not supported in Substrait
ScalarValue::IntervalYearMonth(Some(i)) => {
let bytes = i.to_le_bytes();
(
LiteralType::UserDefined(UserDefined {
type_reference: INTERVAL_YEAR_MONTH_TYPE_REF,
type_parameters: vec![Parameter {
parameter: Some(parameter::Parameter::DataType(
substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: r#type::Nullability::Required as i32,
})),
},
)),
}],
val: Some(Val::Value(ProtoAny {
type_url: INTERVAL_YEAR_MONTH_TYPE_URL.to_string(),
value: bytes.to_vec(),
})),
}),
INTERVAL_YEAR_MONTH_TYPE_REF,
)
}
ScalarValue::IntervalMonthDayNano(Some(i)) => {
// treat `i128` as two contiguous `i64`
let bytes = i.to_le_bytes();
let i64_param = Parameter {
parameter: Some(parameter::Parameter::DataType(substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: r#type::Nullability::Required as i32,
})),
})),
};
(
LiteralType::UserDefined(UserDefined {
type_reference: INTERVAL_MONTH_DAY_NANO_TYPE_REF,
type_parameters: vec![i64_param.clone(), i64_param],
val: Some(Val::Value(ProtoAny {
type_url: INTERVAL_MONTH_DAY_NANO_TYPE_URL.to_string(),
value: bytes.to_vec(),
})),
}),
INTERVAL_MONTH_DAY_NANO_TYPE_REF,
)
}
ScalarValue::IntervalDayTime(Some(i)) => {
let bytes = i.to_le_bytes();
(
LiteralType::UserDefined(UserDefined {
type_reference: INTERVAL_DAY_TIME_TYPE_REF,
type_parameters: vec![Parameter {
parameter: Some(parameter::Parameter::DataType(
substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: r#type::Nullability::Required as i32,
})),
},
)),
}],
val: Some(Val::Value(ProtoAny {
type_url: INTERVAL_DAY_TIME_TYPE_URL.to_string(),
value: bytes.to_vec(),
})),
}),
INTERVAL_DAY_TIME_TYPE_REF,
)
}
ScalarValue::Binary(Some(b)) => {
(LiteralType::Binary(b.clone()), DEFAULT_CONTAINER_TYPE_REF)
}
Expand Down
56 changes: 56 additions & 0 deletions datafusion/substrait/src/variation_const.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
//! - Default type reference is 0. It is used when the actual type is the same with the original type.
//! - Extended variant type references start from 1, and ususlly increase by 1.

// For type variations
pub const DEFAULT_TYPE_REF: u32 = 0;
pub const UNSIGNED_INTEGER_TYPE_REF: u32 = 1;
pub const TIMESTAMP_SECOND_TYPE_REF: u32 = 0;
Expand All @@ -37,3 +38,58 @@ pub const DEFAULT_CONTAINER_TYPE_REF: u32 = 0;
pub const LARGE_CONTAINER_TYPE_REF: u32 = 1;
pub const DECIMAL_128_TYPE_REF: u32 = 0;
pub const DECIMAL_256_TYPE_REF: u32 = 1;

// For custom types
/// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`].
///
/// An `i32` for elapsed whole months. See also [`ScalarValue::IntervalYearMonth`]
/// for the literal definition in DataFusion.
///
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
/// [`IntervalUnit::YearMonth`]: datafusion::arrow::datatypes::IntervalUnit::YearMonth
/// [`ScalarValue::IntervalYearMonth`]: datafusion::common::ScalarValue::IntervalYearMonth
pub const INTERVAL_YEAR_MONTH_TYPE_REF: u32 = 1;

/// For [`DataType::Interval`] with [`IntervalUnit::DayTime`].
///
/// An `i64` as:
/// - days: `i32`
/// - milliseconds: `i32`
///
/// See also [`ScalarValue::IntervalDayTime`] for the literal definition in DataFusion.
///
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
/// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime
/// [`ScalarValue::IntervalDayTime`]: datafusion::common::ScalarValue::IntervalDayTime
pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2;

/// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`].
///
/// An `i128` as:
/// - months: `i32`
/// - days: `i32`
/// - nanoseconds: `i64`
///
/// See also [`ScalarValue::IntervalMonthDayNano`] for the literal definition in DataFusion.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI the interval implementation is changing in the next arrow I think: apache/arrow-rs#5769

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for informing, I'd like to help migrate to the new arrow version. BTW, is there any plan for next bump?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think ETA is next week sometime apache/arrow-rs#5688

Maybe you can make a "pre-release" of DataFusion against the un-released version of arrow-rs (which @tustvold often does to make sure as a way to sanity check the release)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good 👍 The "pre-update" is #10765 (still WIP

///
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
/// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano
/// [`ScalarValue::IntervalMonthDayNano`]: datafusion::common::ScalarValue::IntervalMonthDayNano
pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3;

// For User Defined URLs
/// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`].
///
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
/// [`IntervalUnit::YearMonth`]: datafusion::arrow::datatypes::IntervalUnit::YearMonth
pub const INTERVAL_YEAR_MONTH_TYPE_URL: &str = "interval-year-month";
/// For [`DataType::Interval`] with [`IntervalUnit::DayTime`].
///
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
/// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime
pub const INTERVAL_DAY_TIME_TYPE_URL: &str = "interval-day-time";
/// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`].
///
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
/// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano
pub const INTERVAL_MONTH_DAY_NANO_TYPE_URL: &str = "interval-month-day-nano";
26 changes: 23 additions & 3 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use datafusion_substrait::logical_plan::{
use std::hash::Hash;
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef};
use datafusion::error::Result;
use datafusion::execution::context::SessionState;
Expand Down Expand Up @@ -496,6 +496,24 @@ async fn roundtrip_arithmetic_ops() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn roundtrip_interval_literal() -> Result<()> {
roundtrip(
"SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(YearMonth)')",
)
.await?;
roundtrip(
"SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(DayTime)')",
)
.await?;
roundtrip(
"SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(MonthDayNano)')",
)
.await?;

Ok(())
}

#[tokio::test]
async fn roundtrip_like() -> Result<()> {
roundtrip("SELECT f FROM data WHERE f LIKE 'a%b'").await
Expand Down Expand Up @@ -1035,14 +1053,16 @@ async fn create_context() -> Result<SessionContext> {
.with_serializer_registry(Arc::new(MockSerializerRegistry));
let ctx = SessionContext::new_with_state(state);
let mut explicit_options = CsvReadOptions::new();
let schema = Schema::new(vec![
let fields = vec![
Field::new("a", DataType::Int64, true),
Field::new("b", DataType::Decimal128(5, 2), true),
Field::new("c", DataType::Date32, true),
Field::new("d", DataType::Boolean, true),
Field::new("e", DataType::UInt32, true),
Field::new("f", DataType::Utf8, true),
]);
Field::new("g", DataType::Interval(IntervalUnit::DayTime), true),
];
let schema = Schema::new(fields);
explicit_options.schema = Some(&schema);
ctx.register_csv("data", "tests/testdata/data.csv", explicit_options)
.await?;
Expand Down