Skip to content

Commit

Permalink
simplify Literal/ScalarValue null handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Blizzara committed May 24, 2024
1 parent b6ff771 commit cbd2264
Showing 1 changed file with 28 additions and 196 deletions.
224 changes: 28 additions & 196 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,7 @@ pub fn to_substrait_rex(
))),
})
}
Expr::Literal(value) => to_substrait_literal_expr(value, true),
Expr::Literal(value) => to_substrait_literal_expr(value),
Expr::Alias(Alias { expr, .. }) => {
to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)
}
Expand Down Expand Up @@ -1532,10 +1532,9 @@ fn make_substrait_like_expr(
};
let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?;
let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?;
let escape_char = to_substrait_literal_expr(
&ScalarValue::Utf8(escape_char.map(|c| c.to_string())),
true,
)?;
let escape_char = to_substrait_literal_expr(&ScalarValue::Utf8(
escape_char.map(|c| c.to_string()),
))?;
let arguments = vec![
FunctionArgument {
arg_type: Some(ArgType::Value(expr)),
Expand Down Expand Up @@ -1691,7 +1690,17 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> {
))
}

fn to_substrait_literal(value: &ScalarValue, nullable: bool) -> Result<Literal> {
fn to_substrait_literal(value: &ScalarValue) -> Result<Literal> {
if value.is_null() {
return Ok(Literal {
nullable: true,
type_variation_reference: DEFAULT_TYPE_REF,
literal_type: Some(LiteralType::Null(to_substrait_type(
&value.data_type(),
true,
)?)),
});
}
let (literal_type, type_variation_reference) = match value {
ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), DEFAULT_TYPE_REF),
ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), DEFAULT_TYPE_REF),
Expand Down Expand Up @@ -1749,34 +1758,34 @@ fn to_substrait_literal(value: &ScalarValue, nullable: bool) -> Result<Literal>
}),
DECIMAL_128_TYPE_REF,
),
ScalarValue::List(l) if !value.is_null() => (
ScalarValue::List(l) => (
convert_array_to_literal_list(l)?,
DEFAULT_CONTAINER_TYPE_REF,
),
ScalarValue::LargeList(l) if !value.is_null() => {
ScalarValue::LargeList(l) => {
(convert_array_to_literal_list(l)?, LARGE_CONTAINER_TYPE_REF)
}
ScalarValue::Struct(s) if !value.is_null() => (
ScalarValue::Struct(s) => (
LiteralType::Struct(Struct {
fields: s
.columns()
.iter()
.zip(s.fields())
.map(|(col, field)| {
to_substrait_literal(
&ScalarValue::try_from_array(col, 0)?,
field.is_nullable(),
)
to_substrait_literal(&ScalarValue::try_from_array(col, 0)?)
})
.collect::<Result<Vec<_>>>()?,
}),
DEFAULT_TYPE_REF,
),
_ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF),
_ => (
not_impl_err!("Unsupported literal: {value:?}")?,
DEFAULT_TYPE_REF,
),
};

Ok(Literal {
nullable,
nullable: false,
type_variation_reference,
literal_type: Some(literal_type),
})
Expand All @@ -1789,12 +1798,7 @@ fn convert_array_to_literal_list<T: OffsetSizeTrait>(
let nested_array = array.value(0);

let values = (0..nested_array.len())
.map(|i| {
to_substrait_literal(
&ScalarValue::try_from_array(&nested_array, i)?,
array.is_nullable(),
)
})
.map(|i| to_substrait_literal(&ScalarValue::try_from_array(&nested_array, i)?))
.collect::<Result<Vec<_>>>()?;

if values.is_empty() {
Expand All @@ -1810,8 +1814,8 @@ fn convert_array_to_literal_list<T: OffsetSizeTrait>(
}
}

fn to_substrait_literal_expr(value: &ScalarValue, nullable: bool) -> Result<Expression> {
let literal = to_substrait_literal(value, nullable)?;
fn to_substrait_literal_expr(value: &ScalarValue) -> Result<Expression> {
let literal = to_substrait_literal(value)?;
Ok(Expression {
rex_type: Some(RexType::Literal(literal)),
})
Expand Down Expand Up @@ -1846,175 +1850,6 @@ fn to_substrait_unary_scalar_fn(
})
}

fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
let default_nullability = r#type::Nullability::Nullable as i32;
match v {
ScalarValue::Boolean(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Bool(r#type::Boolean {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Int8(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I8(r#type::I8 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::UInt8(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I8(r#type::I8 {
type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Int16(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I16(r#type::I16 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::UInt16(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I16(r#type::I16 {
type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Int32(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::UInt32(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Int64(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::UInt64(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Float32(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Fp32(r#type::Fp32 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Float64(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Fp64(r#type::Fp64 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::TimestampSecond(None, _) => {
Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
type_variation_reference: TIMESTAMP_SECOND_TYPE_REF,
nullability: default_nullability,
})),
}))
}
ScalarValue::TimestampMillisecond(None, _) => {
Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
type_variation_reference: TIMESTAMP_MILLI_TYPE_REF,
nullability: default_nullability,
})),
}))
}
ScalarValue::TimestampMicrosecond(None, _) => {
Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
type_variation_reference: TIMESTAMP_MICRO_TYPE_REF,
nullability: default_nullability,
})),
}))
}
ScalarValue::TimestampNanosecond(None, _) => {
Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
type_variation_reference: TIMESTAMP_NANO_TYPE_REF,
nullability: default_nullability,
})),
}))
}
ScalarValue::Date32(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Date(r#type::Date {
type_variation_reference: DATE_32_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Date64(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Date(r#type::Date {
type_variation_reference: DATE_64_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Binary(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Binary(r#type::Binary {
type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::LargeBinary(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Binary(r#type::Binary {
type_variation_reference: LARGE_CONTAINER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::FixedSizeBinary(_, None) => {
Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Binary(r#type::Binary {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
}))
}
ScalarValue::Utf8(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::String(r#type::String {
type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::LargeUtf8(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::String(r#type::String {
type_variation_reference: LARGE_CONTAINER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Decimal128(None, p, s) => {
Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Decimal(r#type::Decimal {
scale: *s as i32,
precision: *p as i32,
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
}))
}
ScalarValue::List(l) => {
Ok(LiteralType::Null(to_substrait_type(l.data_type(), true)?))
}
ScalarValue::LargeList(l) => {
Ok(LiteralType::Null(to_substrait_type(l.data_type(), true)?))
}
ScalarValue::Struct(s) => {
Ok(LiteralType::Null(to_substrait_type(s.data_type(), true)?))
}
// TODO: Extend support for remaining data types
_ => not_impl_err!("Unsupported literal: {v:?}"),
}
}

/// Try to convert an [Expr] to a [FieldReference].
/// Returns `Err` if the [Expr] is not a [Expr::Column].
fn try_to_substrait_field_reference(
Expand Down Expand Up @@ -2174,11 +2009,8 @@ mod test {
fn round_trip_literal(scalar: ScalarValue) -> Result<()> {
println!("Checking round trip of {scalar:?}");

// As DataFusion doesn't consider nullability as a property of the scalar, but field,
// it doesn't matter if we set nullability to true or false here.
let substrait_literal = to_substrait_literal(&scalar, true)?;
let substrait_literal = to_substrait_literal(&scalar)?;
let roundtrip_scalar = from_substrait_literal(&substrait_literal)?;
println!("Substrait literal: {substrait_literal:?}");
assert_eq!(scalar, roundtrip_scalar);
Ok(())
}
Expand Down

0 comments on commit cbd2264

Please sign in to comment.