Skip to content

Commit

Permalink
Add proper support for null literal by introducing `ScalarValue::Nu…
Browse files Browse the repository at this point in the history
…ll` (apache#2364)

* introduce null

* fix fmt
  • Loading branch information
WinkerDu authored and MazterQyou committed Jun 9, 2023
1 parent 36f471d commit 99fdae5
Show file tree
Hide file tree
Showing 15 changed files with 180 additions and 38 deletions.
30 changes: 29 additions & 1 deletion datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
/// This is the single-valued counter-part of arrow’s `Array`.
#[derive(Clone)]
pub enum ScalarValue {
/// represents `DataType::Null` (castable to/from any other type)
Null,
/// true or false value
Boolean(Option<bool>),
/// 32bit float
Expand Down Expand Up @@ -170,6 +172,8 @@ impl PartialEq for ScalarValue {
(IntervalMonthDayNano(_), _) => false,
(Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2),
(Struct(_, _), _) => false,
(Null, Null) => true,
(Null, _) => false,
}
}
}
Expand Down Expand Up @@ -270,6 +274,8 @@ impl PartialOrd for ScalarValue {
}
}
(Struct(_, _), _) => None,
(Null, Null) => Some(Ordering::Equal),
(Null, _) => None,
}
}
}
Expand Down Expand Up @@ -325,6 +331,8 @@ impl std::hash::Hash for ScalarValue {
v.hash(state);
t.hash(state);
}
// stable hash for Null value
Null => 1.hash(state),
}
}
}
Expand Down Expand Up @@ -594,6 +602,7 @@ impl ScalarValue {
DataType::Interval(IntervalUnit::MonthDayNano)
}
ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()),
ScalarValue::Null => DataType::Null,
}
}

Expand Down Expand Up @@ -623,7 +632,8 @@ impl ScalarValue {
pub fn is_null(&self) -> bool {
matches!(
*self,
ScalarValue::Boolean(None)
ScalarValue::Null
| ScalarValue::Boolean(None)
| ScalarValue::UInt8(None)
| ScalarValue::UInt16(None)
| ScalarValue::UInt32(None)
Expand Down Expand Up @@ -844,6 +854,7 @@ impl ScalarValue {
ScalarValue::iter_to_decimal_array(scalars, precision, scale)?;
Arc::new(decimal_array)
}
DataType::Null => ScalarValue::iter_to_null_array(scalars),
DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
DataType::Float32 => build_array_primitive!(Float32Array, Float32),
DataType::Float64 => build_array_primitive!(Float64Array, Float64),
Expand Down Expand Up @@ -976,6 +987,17 @@ impl ScalarValue {
Ok(array)
}

fn iter_to_null_array(scalars: impl IntoIterator<Item = ScalarValue>) -> ArrayRef {
let length =
scalars
.into_iter()
.fold(0usize, |r, element: ScalarValue| match element {
ScalarValue::Null => r + 1,
_ => unreachable!(),
});
new_null_array(&DataType::Null, length)
}

fn iter_to_decimal_array(
scalars: impl IntoIterator<Item = ScalarValue>,
precision: &usize,
Expand Down Expand Up @@ -1249,6 +1271,7 @@ impl ScalarValue {
Arc::new(StructArray::from(field_values))
}
},
ScalarValue::Null => new_null_array(&DataType::Null, size),
}
}

Expand All @@ -1274,6 +1297,7 @@ impl ScalarValue {
}

Ok(match array.data_type() {
DataType::Null => ScalarValue::Null,
DataType::Decimal(precision, scale) => {
ScalarValue::get_decimal_value_from_array(array, index, precision, scale)
}
Expand Down Expand Up @@ -1519,6 +1543,7 @@ impl ScalarValue {
eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val)
}
ScalarValue::Struct(_, _) => unimplemented!(),
ScalarValue::Null => array.data().is_null(index),
}
}

Expand Down Expand Up @@ -1740,6 +1765,7 @@ impl TryFrom<&DataType> for ScalarValue {
DataType::Struct(fields) => {
ScalarValue::Struct(None, Box::new(fields.clone()))
}
DataType::Null => ScalarValue::Null,
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a scalar from data_type \"{:?}\"",
Expand Down Expand Up @@ -1832,6 +1858,7 @@ impl fmt::Display for ScalarValue {
)?,
None => write!(f, "NULL")?,
},
ScalarValue::Null => write!(f, "NULL")?,
};
Ok(())
}
Expand Down Expand Up @@ -1899,6 +1926,7 @@ impl fmt::Debug for ScalarValue {
None => write!(f, "Struct(NULL)"),
}
}
ScalarValue::Null => write!(f, "NULL"),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ impl LogicalPlanBuilder {
.iter()
.enumerate()
.map(|(j, expr)| {
if let Expr::Literal(ScalarValue::Utf8(None)) = expr {
if let Expr::Literal(ScalarValue::Null) = expr {
nulls.push((i, j));
Ok(field_types[j].clone())
} else {
Expand Down
6 changes: 5 additions & 1 deletion datafusion/core/src/physical_plan/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,11 @@ fn equal_rows(
.iter()
.zip(right_arrays)
.all(|(l, r)| match l.data_type() {
DataType::Null => true,
DataType::Null => {
// lhs and rhs are both `DataType::Null`, so the euqal result
// is dependent on `null_equals_null`
null_equals_null
}
DataType::Boolean => {
equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null)
}
Expand Down
16 changes: 16 additions & 0 deletions datafusion/core/src/physical_plan/hash_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ fn combine_hashes(l: u64, r: u64) -> u64 {
hash.wrapping_mul(37).wrapping_add(r)
}

fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) {
if mul_col {
hashes_buffer.iter_mut().for_each(|hash| {
// stable hash for null value
*hash = combine_hashes(i128::get_hash(&1, random_state), *hash);
})
} else {
hashes_buffer.iter_mut().for_each(|hash| {
*hash = i128::get_hash(&1, random_state);
})
}
}

fn hash_decimal128<'a>(
array: &ArrayRef,
random_state: &RandomState,
Expand Down Expand Up @@ -310,6 +323,9 @@ pub fn create_hashes<'a>(

for col in arrays {
match col.data_type() {
DataType::Null => {
hash_null(random_state, hashes_buffer, multi_col);
}
DataType::Decimal(_, _) => {
hash_decimal128(col, random_state, hashes_buffer, multi_col);
}
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1681,7 +1681,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n),
SQLExpr::Value(Value::SingleQuotedString(s)) => Ok(lit(s)),
SQLExpr::Value(Value::Null) => {
Ok(Expr::Literal(ScalarValue::Utf8(None)))
Ok(Expr::Literal(ScalarValue::Null))
}
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
SQLExpr::UnaryOp { op, expr } => {
Expand All @@ -1707,7 +1707,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())),
SQLExpr::Value(Value::EscapedStringLiteral(ref s)) => Ok(lit(s.clone())),
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))),
SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)),
SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction {
fun: BuiltinScalarFunction::DatePart,
args: vec![
Expand Down Expand Up @@ -4258,9 +4258,9 @@ mod tests {
fn union_with_null() {
let sql = "SELECT NULL a UNION ALL SELECT 1.1 a";
let expected = "Union\
\n Projection: Utf8(NULL) AS a\
\n Projection: CAST(NULL AS Float64) AS a\
\n EmptyRelation\
\n Projection: CAST(Float64(1.1) AS Utf8) AS a\
\n Projection: Float64(1.1) AS a\
\n EmptyRelation";
quick_test(sql, expected);
}
Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/tests/sql/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,11 +362,11 @@ async fn test_string_concat_operator() -> Result<()> {
let sql = "SELECT 'aa' || NULL || 'd'";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+---------------------------------------+",
"| Utf8(\"aa\") || Utf8(NULL) || Utf8(\"d\") |",
"+---------------------------------------+",
"| |",
"+---------------------------------------+",
"+---------------------------------+",
"| Utf8(\"aa\") || NULL || Utf8(\"d\") |",
"+---------------------------------+",
"| |",
"+---------------------------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/tests/sql/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,11 @@ async fn coalesce_static_value_with_null() -> Result<()> {
let sql = "SELECT COALESCE(NULL, 'test')";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----------------------------------+",
"| coalesce(Utf8(NULL),Utf8(\"test\")) |",
"+-----------------------------------+",
"| test |",
"+-----------------------------------+",
"+-----------------------------+",
"| coalesce(NULL,Utf8(\"test\")) |",
"+-----------------------------+",
"| test |",
"+-----------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
Expand Down
6 changes: 5 additions & 1 deletion datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,11 @@ async fn inner_join_nulls() {
let sql = "SELECT * FROM (SELECT null AS id1) t1
INNER JOIN (SELECT null AS id2) t2 ON id1 = id2";

let expected = vec!["++", "++"];
#[rustfmt::skip]
let expected = vec![
"++",
"++",
];

let ctx = create_join_context_qualified("t1", "t2").unwrap();
let actual = execute_to_batches(&ctx, sql).await;
Expand Down
34 changes: 28 additions & 6 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,15 +398,37 @@ async fn select_distinct_from() {
1 IS NOT DISTINCT FROM CAST(NULL as INT) as c,
1 IS NOT DISTINCT FROM 1 as d,
NULL IS DISTINCT FROM NULL as e,
NULL IS NOT DISTINCT FROM NULL as f
NULL IS NOT DISTINCT FROM NULL as f,
NULL is DISTINCT FROM 1 as g,
NULL is NOT DISTINCT FROM 1 as h
";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+------+-------+-------+------+-------+------+",
"| a | b | c | d | e | f |",
"+------+-------+-------+------+-------+------+",
"| true | false | false | true | false | true |",
"+------+-------+-------+------+-------+------+",
"+------+-------+-------+------+-------+------+------+-------+",
"| a | b | c | d | e | f | g | h |",
"+------+-------+-------+------+-------+------+------+-------+",
"| true | false | false | true | false | true | true | false |",
"+------+-------+-------+------+-------+------+------+-------+",
];
assert_batches_eq!(expected, &actual);

let sql = "select
NULL IS DISTINCT FROM NULL as a,
NULL IS NOT DISTINCT FROM NULL as b,
NULL is DISTINCT FROM 1 as c,
NULL is NOT DISTINCT FROM 1 as d,
1 IS DISTINCT FROM CAST(NULL as INT) as e,
1 IS DISTINCT FROM 1 as f,
1 IS NOT DISTINCT FROM CAST(NULL as INT) as g,
1 IS NOT DISTINCT FROM 1 as h
";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-------+------+------+-------+------+-------+-------+------+",
"| a | b | c | d | e | f | g | h |",
"+-------+------+------+-------+------+-------+-------+------+",
"| false | true | true | false | true | false | false | true |",
"+-------+------+------+-------+------+-------+-------+------+",
];
assert_batches_eq!(expected, &actual);
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/binary_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
numerical_coercion(lhs_type, rhs_type)
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
}

/// Coercion rule for interval
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ macro_rules! make_utf8_to_return_type {
Ok(match arg_type {
DataType::LargeUtf8 => $largeUtf8Type,
DataType::Utf8 => $utf8Type,
DataType::Null => DataType::Null,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(format!(
Expand Down Expand Up @@ -226,6 +227,7 @@ pub fn return_type(
DataType::Utf8 => {
DataType::List(Box::new(Field::new("item", DataType::Utf8, true)))
}
DataType::Null => DataType::Null,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(
Expand Down
40 changes: 28 additions & 12 deletions datafusion/expr/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
//!
use crate::{Signature, TypeSignature};
use arrow::datatypes::{DataType, TimeUnit};
use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::{DataFusionError, Result};

/// Returns the data types that each argument must be coerced to match
Expand Down Expand Up @@ -142,25 +145,35 @@ fn maybe_data_types(
/// See the module level documentation for more detail on coercion.
pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
use self::DataType::*;
// Null can convert to most of types
match type_into {
Int8 => matches!(type_from, Int8),
Int16 => matches!(type_from, Int8 | Int16 | UInt8),
Int32 => matches!(type_from, Int8 | Int16 | Int32 | UInt8 | UInt16),
Int8 => matches!(type_from, Null | Int8),
Int16 => matches!(type_from, Null | Int8 | Int16 | UInt8),
Int32 => matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16),
Int64 => matches!(
type_from,
Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32
Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32
),
UInt8 => matches!(type_from, UInt8),
UInt16 => matches!(type_from, UInt8 | UInt16),
UInt32 => matches!(type_from, UInt8 | UInt16 | UInt32),
UInt64 => matches!(type_from, UInt8 | UInt16 | UInt32 | UInt64),
UInt8 => matches!(type_from, Null | UInt8),
UInt16 => matches!(type_from, Null | UInt8 | UInt16),
UInt32 => matches!(type_from, Null | UInt8 | UInt16 | UInt32),
UInt64 => matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64),
Float32 => matches!(
type_from,
Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float32
Null | Int8
| Int16
| Int32
| Int64
| UInt8
| UInt16
| UInt32
| UInt64
| Float32
),
Float64 => matches!(
type_from,
Int8 | Int16
Null | Int8
| Int16
| Int32
| Int64
| UInt8
Expand All @@ -170,8 +183,11 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
| Float32
| Float64
),
Timestamp(TimeUnit::Nanosecond, None) => matches!(type_from, Timestamp(_, None)),
Timestamp(TimeUnit::Nanosecond, None) => {
matches!(type_from, Null | Timestamp(_, None))
}
Utf8 | LargeUtf8 => true,
Null => can_cast_types(type_from, type_into),
_ => false,
}
}
Expand Down
Loading

0 comments on commit 99fdae5

Please sign in to comment.