Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add proper support for null literal by introducing ScalarValue::Null #2364

Merged
merged 2 commits into from
May 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
/// This is the single-valued counter-part of arrow’s `Array`.
#[derive(Clone)]
pub enum ScalarValue {
/// represents `DataType::Null` (castable to/from any other type)
Null,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

/// true or false value
Boolean(Option<bool>),
/// 32bit float
Expand Down Expand Up @@ -170,6 +172,8 @@ impl PartialEq for ScalarValue {
(IntervalMonthDayNano(_), _) => false,
(Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2),
(Struct(_, _), _) => false,
(Null, Null) => true,
(Null, _) => false,
WinkerDu marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand Down Expand Up @@ -270,6 +274,8 @@ impl PartialOrd for ScalarValue {
}
}
(Struct(_, _), _) => None,
(Null, Null) => Some(Ordering::Equal),
(Null, _) => None,
WinkerDu marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand Down Expand Up @@ -325,6 +331,8 @@ impl std::hash::Hash for ScalarValue {
v.hash(state);
t.hash(state);
}
// stable hash for Null value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Null => 1.hash(state),
}
}
}
Expand Down Expand Up @@ -594,6 +602,7 @@ impl ScalarValue {
DataType::Interval(IntervalUnit::MonthDayNano)
}
ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()),
ScalarValue::Null => DataType::Null,
}
}

Expand Down Expand Up @@ -623,7 +632,8 @@ impl ScalarValue {
pub fn is_null(&self) -> bool {
matches!(
*self,
ScalarValue::Boolean(None)
ScalarValue::Null
| ScalarValue::Boolean(None)
| ScalarValue::UInt8(None)
| ScalarValue::UInt16(None)
| ScalarValue::UInt32(None)
Expand Down Expand Up @@ -836,6 +846,7 @@ impl ScalarValue {
ScalarValue::iter_to_decimal_array(scalars, precision, scale)?;
Arc::new(decimal_array)
}
DataType::Null => ScalarValue::iter_to_null_array(scalars),
DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
DataType::Float32 => build_array_primitive!(Float32Array, Float32),
DataType::Float64 => build_array_primitive!(Float64Array, Float64),
Expand Down Expand Up @@ -968,6 +979,17 @@ impl ScalarValue {
Ok(array)
}

fn iter_to_null_array(scalars: impl IntoIterator<Item = ScalarValue>) -> ArrayRef {
let length =
scalars
.into_iter()
.fold(0usize, |r, element: ScalarValue| match element {
ScalarValue::Null => r + 1,
_ => unreachable!(),
});
new_null_array(&DataType::Null, length)
}

fn iter_to_decimal_array(
scalars: impl IntoIterator<Item = ScalarValue>,
precision: &usize,
Expand Down Expand Up @@ -1241,6 +1263,7 @@ impl ScalarValue {
Arc::new(StructArray::from(field_values))
}
},
ScalarValue::Null => new_null_array(&DataType::Null, size),
}
}

Expand All @@ -1266,6 +1289,7 @@ impl ScalarValue {
}

Ok(match array.data_type() {
DataType::Null => ScalarValue::Null,
DataType::Decimal(precision, scale) => {
ScalarValue::get_decimal_value_from_array(array, index, precision, scale)
}
Expand Down Expand Up @@ -1522,6 +1546,7 @@ impl ScalarValue {
eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val)
}
ScalarValue::Struct(_, _) => unimplemented!(),
ScalarValue::Null => array.data().is_null(index),
}
}

Expand Down Expand Up @@ -1743,6 +1768,7 @@ impl TryFrom<&DataType> for ScalarValue {
DataType::Struct(fields) => {
ScalarValue::Struct(None, Box::new(fields.clone()))
}
DataType::Null => ScalarValue::Null,
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a scalar from data_type \"{:?}\"",
Expand Down Expand Up @@ -1835,6 +1861,7 @@ impl fmt::Display for ScalarValue {
)?,
None => write!(f, "NULL")?,
},
ScalarValue::Null => write!(f, "NULL")?,
};
Ok(())
}
Expand Down Expand Up @@ -1902,6 +1929,7 @@ impl fmt::Debug for ScalarValue {
None => write!(f, "Struct(NULL)"),
}
}
ScalarValue::Null => write!(f, "NULL"),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ impl LogicalPlanBuilder {
.iter()
.enumerate()
.map(|(j, expr)| {
if let Expr::Literal(ScalarValue::Utf8(None)) = expr {
if let Expr::Literal(ScalarValue::Null) = expr {
nulls.push((i, j));
Ok(field_types[j].clone())
} else {
Expand Down
6 changes: 5 additions & 1 deletion datafusion/core/src/physical_plan/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,11 @@ fn equal_rows(
.iter()
.zip(right_arrays)
.all(|(l, r)| match l.data_type() {
DataType::Null => true,
DataType::Null => {
// lhs and rhs are both `DataType::Null`, so the euqal result
// is dependent on `null_equals_null`
null_equals_null
}
DataType::Boolean => {
equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null)
}
Expand Down
16 changes: 16 additions & 0 deletions datafusion/core/src/physical_plan/hash_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ fn combine_hashes(l: u64, r: u64) -> u64 {
hash.wrapping_mul(37).wrapping_add(r)
}

fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) {
if mul_col {
hashes_buffer.iter_mut().for_each(|hash| {
// stable hash for null value
*hash = combine_hashes(i128::get_hash(&1, random_state), *hash);
})
} else {
hashes_buffer.iter_mut().for_each(|hash| {
*hash = i128::get_hash(&1, random_state);
})
}
}

fn hash_decimal128<'a>(
array: &ArrayRef,
random_state: &RandomState,
Expand Down Expand Up @@ -284,6 +297,9 @@ pub fn create_hashes<'a>(

for col in arrays {
match col.data_type() {
DataType::Null => {
hash_null(random_state, hashes_buffer, multi_col);
}
DataType::Decimal(_, _) => {
hash_decimal128(col, random_state, hashes_buffer, multi_col);
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1491,7 +1491,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n),
SQLExpr::Value(Value::SingleQuotedString(s)) => Ok(lit(s)),
SQLExpr::Value(Value::Null) => {
Ok(Expr::Literal(ScalarValue::Utf8(None)))
Ok(Expr::Literal(ScalarValue::Null))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well this explains a lot of odd type coercion errors I have been seeing

}
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
SQLExpr::UnaryOp { op, expr } => self.parse_sql_unary_op(
Expand Down Expand Up @@ -1529,7 +1529,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n),
SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())),
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))),
SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)),
SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction {
fun: BuiltinScalarFunction::DatePart,
args: vec![
Expand Down
46 changes: 23 additions & 23 deletions datafusion/core/tests/sql/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,25 +121,25 @@ async fn case_when_else_with_null_contant() -> Result<()> {
FROM t1";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----------------------------------------------------------------------------------------------+",
"| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN Utf8(NULL) THEN Int64(2) ELSE Int64(999) END |",
"+----------------------------------------------------------------------------------------------+",
"| 1 |",
"| 999 |",
"| 999 |",
"| 999 |",
"+----------------------------------------------------------------------------------------------+",
"+----------------------------------------------------------------------------------------+",
"| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN NULL THEN Int64(2) ELSE Int64(999) END |",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the key difference here. Very nice 👍

"+----------------------------------------------------------------------------------------+",
"| 1 |",
"| 999 |",
"| 999 |",
"| 999 |",
"+----------------------------------------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);

let sql = "SELECT CASE WHEN NULL THEN 'foo' ELSE 'bar' END";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+------------------------------------------------------------+",
"| CASE WHEN Utf8(NULL) THEN Utf8(\"foo\") ELSE Utf8(\"bar\") END |",
"+------------------------------------------------------------+",
"| bar |",
"+------------------------------------------------------------+",
"+------------------------------------------------------+",
"| CASE WHEN NULL THEN Utf8(\"foo\") ELSE Utf8(\"bar\") END |",
"+------------------------------------------------------+",
"| bar |",
"+------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
Expand Down Expand Up @@ -347,11 +347,11 @@ async fn test_string_concat_operator() -> Result<()> {
let sql = "SELECT 'aa' || NULL || 'd'";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+---------------------------------------+",
"| Utf8(\"aa\") || Utf8(NULL) || Utf8(\"d\") |",
"+---------------------------------------+",
"| |",
"+---------------------------------------+",
"+---------------------------------+",
"| Utf8(\"aa\") || NULL || Utf8(\"d\") |",
"+---------------------------------+",
"| |",
"+---------------------------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down Expand Up @@ -387,11 +387,11 @@ async fn test_not_expressions() -> Result<()> {
let sql = "SELECT null, not(null)";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+------------+----------------+",
"| Utf8(NULL) | NOT Utf8(NULL) |",
"+------------+----------------+",
"| | |",
"+------------+----------------+",
"+------+----------+",
"| NULL | NOT NULL |",
"+------+----------+",
"| | |",
"+------+----------+",
];
assert_batches_eq!(expected, &actual);

Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/tests/sql/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,11 @@ async fn coalesce_static_value_with_null() -> Result<()> {
let sql = "SELECT COALESCE(NULL, 'test')";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----------------------------------+",
"| coalesce(Utf8(NULL),Utf8(\"test\")) |",
"+-----------------------------------+",
"| test |",
"+-----------------------------------+",
"+-----------------------------+",
"| coalesce(NULL,Utf8(\"test\")) |",
"+-----------------------------+",
"| test |",
"+-----------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
Expand Down
6 changes: 5 additions & 1 deletion datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,11 @@ async fn inner_join_nulls() {
let sql = "SELECT * FROM (SELECT null AS id1) t1
INNER JOIN (SELECT null AS id2) t2 ON id1 = id2";

let expected = vec!["++", "++"];
#[rustfmt::skip]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

let expected = vec![
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This answer is not correct -- there should be no rows that match.

This is because the join should produce rows where id1 = id2 evaluates to true

However, null = null evaluates to null 🤯

Here is the query in postgres:

alamb=# SELECT * FROM (SELECT null AS id1) t1
            INNER JOIN (SELECT null AS id2) t2 ON id1 = id2
alamb-# ;
 id1 | id2 
-----+-----
(0 rows)

"++",
"++",
];

let ctx = create_join_context_qualified().unwrap();
let actual = execute_to_batches(&ctx, sql).await;
Expand Down
34 changes: 28 additions & 6 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,15 +398,37 @@ async fn select_distinct_from() {
1 IS NOT DISTINCT FROM CAST(NULL as INT) as c,
1 IS NOT DISTINCT FROM 1 as d,
NULL IS DISTINCT FROM NULL as e,
NULL IS NOT DISTINCT FROM NULL as f
NULL IS NOT DISTINCT FROM NULL as f,
NULL is DISTINCT FROM 1 as g,
NULL is NOT DISTINCT FROM 1 as h
";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+------+-------+-------+------+-------+------+",
"| a | b | c | d | e | f |",
"+------+-------+-------+------+-------+------+",
"| true | false | false | true | false | true |",
"+------+-------+-------+------+-------+------+",
"+------+-------+-------+------+-------+------+------+-------+",
"| a | b | c | d | e | f | g | h |",
"+------+-------+-------+------+-------+------+------+-------+",
"| true | false | false | true | false | true | true | false |",
"+------+-------+-------+------+-------+------+------+-------+",
];
assert_batches_eq!(expected, &actual);

let sql = "select
NULL IS DISTINCT FROM NULL as a,
NULL IS NOT DISTINCT FROM NULL as b,
NULL is DISTINCT FROM 1 as c,
NULL is NOT DISTINCT FROM 1 as d,
1 IS DISTINCT FROM CAST(NULL as INT) as e,
1 IS DISTINCT FROM 1 as f,
1 IS NOT DISTINCT FROM CAST(NULL as INT) as g,
1 IS NOT DISTINCT FROM 1 as h
";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-------+------+------+-------+------+-------+-------+------+",
"| a | b | c | d | e | f | g | h |",
"+-------+------+------+-------+------+-------+-------+------+",
"| false | true | true | false | true | false | false | true |",
"+-------+------+------+-------+------+-------+-------+------+",
];
assert_batches_eq!(expected, &actual);
}
Expand Down
22 changes: 22 additions & 0 deletions datafusion/expr/src/binary_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,8 +590,30 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
numerical_coercion(lhs_type, rhs_type)
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
}

/// coercion rules from NULL type. Since NULL can be casted to most of types in arrow,
/// either lhs or rhs is NULL, if NULL can be casted to type of the other side, the coecion is valid.
fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very cool

match (lhs_type, rhs_type) {
(DataType::Null, _) => {
if can_cast_types(&DataType::Null, rhs_type) {
Some(rhs_type.clone())
} else {
None
}
}
(_, DataType::Null) => {
if can_cast_types(&DataType::Null, lhs_type) {
Some(lhs_type.clone())
} else {
None
}
}
_ => None,
}
}
#[cfg(test)]
mod tests {
use super::*;
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ macro_rules! make_utf8_to_return_type {
Ok(match arg_type {
DataType::LargeUtf8 => $largeUtf8Type,
DataType::Utf8 => $utf8Type,
DataType::Null => DataType::Null,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(format!(
Expand Down Expand Up @@ -209,6 +210,7 @@ pub fn return_type(
DataType::Utf8 => {
DataType::List(Box::new(Field::new("item", DataType::Utf8, true)))
}
DataType::Null => DataType::Null,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(
Expand Down
Loading