Skip to content

Commit

Permalink
Add proper support for null literal by introducing `ScalarValue::Nu…
Browse files Browse the repository at this point in the history
…ll` (#2364)

* introduce null

* fix fmt
  • Loading branch information
WinkerDu authored May 6, 2022
1 parent b70da54 commit fcc35e8
Show file tree
Hide file tree
Showing 15 changed files with 217 additions and 54 deletions.
30 changes: 29 additions & 1 deletion datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
/// This is the single-valued counter-part of arrow’s `Array`.
#[derive(Clone)]
pub enum ScalarValue {
/// represents `DataType::Null` (castable to/from any other type)
Null,
/// true or false value
Boolean(Option<bool>),
/// 32bit float
Expand Down Expand Up @@ -170,6 +172,8 @@ impl PartialEq for ScalarValue {
(IntervalMonthDayNano(_), _) => false,
(Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2),
(Struct(_, _), _) => false,
(Null, Null) => true,
(Null, _) => false,
}
}
}
Expand Down Expand Up @@ -270,6 +274,8 @@ impl PartialOrd for ScalarValue {
}
}
(Struct(_, _), _) => None,
(Null, Null) => Some(Ordering::Equal),
(Null, _) => None,
}
}
}
Expand Down Expand Up @@ -325,6 +331,8 @@ impl std::hash::Hash for ScalarValue {
v.hash(state);
t.hash(state);
}
// stable hash for Null value
Null => 1.hash(state),
}
}
}
Expand Down Expand Up @@ -594,6 +602,7 @@ impl ScalarValue {
DataType::Interval(IntervalUnit::MonthDayNano)
}
ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()),
ScalarValue::Null => DataType::Null,
}
}

Expand Down Expand Up @@ -623,7 +632,8 @@ impl ScalarValue {
pub fn is_null(&self) -> bool {
matches!(
*self,
ScalarValue::Boolean(None)
ScalarValue::Null
| ScalarValue::Boolean(None)
| ScalarValue::UInt8(None)
| ScalarValue::UInt16(None)
| ScalarValue::UInt32(None)
Expand Down Expand Up @@ -836,6 +846,7 @@ impl ScalarValue {
ScalarValue::iter_to_decimal_array(scalars, precision, scale)?;
Arc::new(decimal_array)
}
DataType::Null => ScalarValue::iter_to_null_array(scalars),
DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
DataType::Float32 => build_array_primitive!(Float32Array, Float32),
DataType::Float64 => build_array_primitive!(Float64Array, Float64),
Expand Down Expand Up @@ -968,6 +979,17 @@ impl ScalarValue {
Ok(array)
}

/// Builds a `DataType::Null` array holding one slot per scalar yielded by `scalars`.
///
/// Every element is expected to be `ScalarValue::Null`; any other variant
/// hits `unreachable!()` and panics.
fn iter_to_null_array(scalars: impl IntoIterator<Item = ScalarValue>) -> ArrayRef {
    // Only the number of elements matters: a null array carries no values.
    let mut length = 0usize;
    for element in scalars {
        match element {
            ScalarValue::Null => length += 1,
            _ => unreachable!(),
        }
    }
    new_null_array(&DataType::Null, length)
}

fn iter_to_decimal_array(
scalars: impl IntoIterator<Item = ScalarValue>,
precision: &usize,
Expand Down Expand Up @@ -1241,6 +1263,7 @@ impl ScalarValue {
Arc::new(StructArray::from(field_values))
}
},
ScalarValue::Null => new_null_array(&DataType::Null, size),
}
}

Expand All @@ -1266,6 +1289,7 @@ impl ScalarValue {
}

Ok(match array.data_type() {
DataType::Null => ScalarValue::Null,
DataType::Decimal(precision, scale) => {
ScalarValue::get_decimal_value_from_array(array, index, precision, scale)
}
Expand Down Expand Up @@ -1522,6 +1546,7 @@ impl ScalarValue {
eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val)
}
ScalarValue::Struct(_, _) => unimplemented!(),
ScalarValue::Null => array.data().is_null(index),
}
}

Expand Down Expand Up @@ -1743,6 +1768,7 @@ impl TryFrom<&DataType> for ScalarValue {
DataType::Struct(fields) => {
ScalarValue::Struct(None, Box::new(fields.clone()))
}
DataType::Null => ScalarValue::Null,
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a scalar from data_type \"{:?}\"",
Expand Down Expand Up @@ -1835,6 +1861,7 @@ impl fmt::Display for ScalarValue {
)?,
None => write!(f, "NULL")?,
},
ScalarValue::Null => write!(f, "NULL")?,
};
Ok(())
}
Expand Down Expand Up @@ -1902,6 +1929,7 @@ impl fmt::Debug for ScalarValue {
None => write!(f, "Struct(NULL)"),
}
}
ScalarValue::Null => write!(f, "NULL"),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ impl LogicalPlanBuilder {
.iter()
.enumerate()
.map(|(j, expr)| {
if let Expr::Literal(ScalarValue::Utf8(None)) = expr {
if let Expr::Literal(ScalarValue::Null) = expr {
nulls.push((i, j));
Ok(field_types[j].clone())
} else {
Expand Down
6 changes: 5 additions & 1 deletion datafusion/core/src/physical_plan/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,11 @@ fn equal_rows(
.iter()
.zip(right_arrays)
.all(|(l, r)| match l.data_type() {
DataType::Null => true,
DataType::Null => {
// lhs and rhs are both `DataType::Null`, so the equality result
// depends on `null_equals_null`
null_equals_null
}
DataType::Boolean => {
equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null)
}
Expand Down
16 changes: 16 additions & 0 deletions datafusion/core/src/physical_plan/hash_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ fn combine_hashes(l: u64, r: u64) -> u64 {
hash.wrapping_mul(37).wrapping_add(r)
}

/// Writes the hash of a null value into every slot of `hashes_buffer`.
///
/// All slots receive the same null hash, derived from the constant `1` under
/// `random_state`.
///
/// * `mul_col == true`: the null hash is combined with the value already
///   stored in each slot via `combine_hashes`.
/// * `mul_col == false`: each slot is overwritten with the null hash.
fn hash_null(random_state: &RandomState, hashes_buffer: &mut [u64], mul_col: bool) {
    // stable hash for null value; it is identical for every row, so compute
    // it once instead of re-hashing inside the loop
    let null_hash = i128::get_hash(&1, random_state);
    if mul_col {
        for hash in hashes_buffer.iter_mut() {
            *hash = combine_hashes(null_hash, *hash);
        }
    } else {
        hashes_buffer.fill(null_hash);
    }
}

fn hash_decimal128<'a>(
array: &ArrayRef,
random_state: &RandomState,
Expand Down Expand Up @@ -284,6 +297,9 @@ pub fn create_hashes<'a>(

for col in arrays {
match col.data_type() {
DataType::Null => {
hash_null(random_state, hashes_buffer, multi_col);
}
DataType::Decimal(_, _) => {
hash_decimal128(col, random_state, hashes_buffer, multi_col);
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1531,7 +1531,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n),
SQLExpr::Value(Value::SingleQuotedString(s)) => Ok(lit(s)),
SQLExpr::Value(Value::Null) => {
Ok(Expr::Literal(ScalarValue::Utf8(None)))
Ok(Expr::Literal(ScalarValue::Null))
}
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
SQLExpr::UnaryOp { op, expr } => self.parse_sql_unary_op(
Expand Down Expand Up @@ -1569,7 +1569,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n),
SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())),
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))),
SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)),
SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction {
fun: BuiltinScalarFunction::DatePart,
args: vec![
Expand Down
46 changes: 23 additions & 23 deletions datafusion/core/tests/sql/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,25 +121,25 @@ async fn case_when_else_with_null_contant() -> Result<()> {
FROM t1";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----------------------------------------------------------------------------------------------+",
"| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN Utf8(NULL) THEN Int64(2) ELSE Int64(999) END |",
"+----------------------------------------------------------------------------------------------+",
"| 1 |",
"| 999 |",
"| 999 |",
"| 999 |",
"+----------------------------------------------------------------------------------------------+",
"+----------------------------------------------------------------------------------------+",
"| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN NULL THEN Int64(2) ELSE Int64(999) END |",
"+----------------------------------------------------------------------------------------+",
"| 1 |",
"| 999 |",
"| 999 |",
"| 999 |",
"+----------------------------------------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);

let sql = "SELECT CASE WHEN NULL THEN 'foo' ELSE 'bar' END";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+------------------------------------------------------------+",
"| CASE WHEN Utf8(NULL) THEN Utf8(\"foo\") ELSE Utf8(\"bar\") END |",
"+------------------------------------------------------------+",
"| bar |",
"+------------------------------------------------------------+",
"+------------------------------------------------------+",
"| CASE WHEN NULL THEN Utf8(\"foo\") ELSE Utf8(\"bar\") END |",
"+------------------------------------------------------+",
"| bar |",
"+------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
Expand Down Expand Up @@ -347,11 +347,11 @@ async fn test_string_concat_operator() -> Result<()> {
let sql = "SELECT 'aa' || NULL || 'd'";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+---------------------------------------+",
"| Utf8(\"aa\") || Utf8(NULL) || Utf8(\"d\") |",
"+---------------------------------------+",
"| |",
"+---------------------------------------+",
"+---------------------------------+",
"| Utf8(\"aa\") || NULL || Utf8(\"d\") |",
"+---------------------------------+",
"| |",
"+---------------------------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down Expand Up @@ -387,11 +387,11 @@ async fn test_not_expressions() -> Result<()> {
let sql = "SELECT null, not(null)";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+------------+----------------+",
"| Utf8(NULL) | NOT Utf8(NULL) |",
"+------------+----------------+",
"| | |",
"+------------+----------------+",
"+------+----------+",
"| NULL | NOT NULL |",
"+------+----------+",
"| | |",
"+------+----------+",
];
assert_batches_eq!(expected, &actual);

Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/tests/sql/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,11 @@ async fn coalesce_static_value_with_null() -> Result<()> {
let sql = "SELECT COALESCE(NULL, 'test')";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----------------------------------+",
"| coalesce(Utf8(NULL),Utf8(\"test\")) |",
"+-----------------------------------+",
"| test |",
"+-----------------------------------+",
"+-----------------------------+",
"| coalesce(NULL,Utf8(\"test\")) |",
"+-----------------------------+",
"| test |",
"+-----------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
Expand Down
6 changes: 5 additions & 1 deletion datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,11 @@ async fn inner_join_nulls() {
let sql = "SELECT * FROM (SELECT null AS id1) t1
INNER JOIN (SELECT null AS id2) t2 ON id1 = id2";

let expected = vec!["++", "++"];
#[rustfmt::skip]
let expected = vec![
"++",
"++",
];

let ctx = create_join_context_qualified().unwrap();
let actual = execute_to_batches(&ctx, sql).await;
Expand Down
34 changes: 28 additions & 6 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,15 +398,37 @@ async fn select_distinct_from() {
1 IS NOT DISTINCT FROM CAST(NULL as INT) as c,
1 IS NOT DISTINCT FROM 1 as d,
NULL IS DISTINCT FROM NULL as e,
NULL IS NOT DISTINCT FROM NULL as f
NULL IS NOT DISTINCT FROM NULL as f,
NULL is DISTINCT FROM 1 as g,
NULL is NOT DISTINCT FROM 1 as h
";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+------+-------+-------+------+-------+------+",
"| a | b | c | d | e | f |",
"+------+-------+-------+------+-------+------+",
"| true | false | false | true | false | true |",
"+------+-------+-------+------+-------+------+",
"+------+-------+-------+------+-------+------+------+-------+",
"| a | b | c | d | e | f | g | h |",
"+------+-------+-------+------+-------+------+------+-------+",
"| true | false | false | true | false | true | true | false |",
"+------+-------+-------+------+-------+------+------+-------+",
];
assert_batches_eq!(expected, &actual);

let sql = "select
NULL IS DISTINCT FROM NULL as a,
NULL IS NOT DISTINCT FROM NULL as b,
NULL is DISTINCT FROM 1 as c,
NULL is NOT DISTINCT FROM 1 as d,
1 IS DISTINCT FROM CAST(NULL as INT) as e,
1 IS DISTINCT FROM 1 as f,
1 IS NOT DISTINCT FROM CAST(NULL as INT) as g,
1 IS NOT DISTINCT FROM 1 as h
";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-------+------+------+-------+------+-------+-------+------+",
"| a | b | c | d | e | f | g | h |",
"+-------+------+------+-------+------+-------+-------+------+",
"| false | true | true | false | true | false | false | true |",
"+-------+------+------+-------+------+-------+-------+------+",
];
assert_batches_eq!(expected, &actual);
}
Expand Down
22 changes: 22 additions & 0 deletions datafusion/expr/src/binary_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,8 +590,30 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
numerical_coercion(lhs_type, rhs_type)
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
}

/// Coercion rule for comparisons where one side is `DataType::Null`.
///
/// If either `lhs_type` or `rhs_type` is `DataType::Null`, the result is the
/// type of the other side, provided `can_cast_types` reports that
/// `DataType::Null` can be cast to it. Returns `None` when neither side is
/// `DataType::Null` or when the cast is not supported.
fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    // Pick the side NULL has to be cast to. The left side is checked first,
    // so (Null, Null) resolves against the right-hand type.
    let target = if matches!(lhs_type, DataType::Null) {
        rhs_type
    } else if matches!(rhs_type, DataType::Null) {
        lhs_type
    } else {
        return None;
    };

    can_cast_types(&DataType::Null, target).then(|| target.clone())
}
#[cfg(test)]
mod tests {
use super::*;
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ macro_rules! make_utf8_to_return_type {
Ok(match arg_type {
DataType::LargeUtf8 => $largeUtf8Type,
DataType::Utf8 => $utf8Type,
DataType::Null => DataType::Null,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(format!(
Expand Down Expand Up @@ -209,6 +210,7 @@ pub fn return_type(
DataType::Utf8 => {
DataType::List(Box::new(Field::new("item", DataType::Utf8, true)))
}
DataType::Null => DataType::Null,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(
Expand Down
Loading

0 comments on commit fcc35e8

Please sign in to comment.