-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add proper support for null
literal by introducing ScalarValue::Null
#2364
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,6 +39,8 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; | |
/// This is the single-valued counter-part of arrow’s `Array`. | ||
#[derive(Clone)] | ||
pub enum ScalarValue { | ||
/// represents `DataType::Null` (castable to/from any other type) | ||
Null, | ||
/// true or false value | ||
Boolean(Option<bool>), | ||
/// 32bit float | ||
|
@@ -170,6 +172,8 @@ impl PartialEq for ScalarValue { | |
(IntervalMonthDayNano(_), _) => false, | ||
(Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), | ||
(Struct(_, _), _) => false, | ||
(Null, Null) => true, | ||
(Null, _) => false, | ||
WinkerDu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
} | ||
|
@@ -270,6 +274,8 @@ impl PartialOrd for ScalarValue { | |
} | ||
} | ||
(Struct(_, _), _) => None, | ||
(Null, Null) => Some(Ordering::Equal), | ||
(Null, _) => None, | ||
WinkerDu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
} | ||
|
@@ -325,6 +331,8 @@ impl std::hash::Hash for ScalarValue { | |
v.hash(state); | ||
t.hash(state); | ||
} | ||
// stable hash for Null value | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
Null => 1.hash(state), | ||
} | ||
} | ||
} | ||
|
@@ -594,6 +602,7 @@ impl ScalarValue { | |
DataType::Interval(IntervalUnit::MonthDayNano) | ||
} | ||
ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()), | ||
ScalarValue::Null => DataType::Null, | ||
} | ||
} | ||
|
||
|
@@ -623,7 +632,8 @@ impl ScalarValue { | |
pub fn is_null(&self) -> bool { | ||
matches!( | ||
*self, | ||
ScalarValue::Boolean(None) | ||
ScalarValue::Null | ||
| ScalarValue::Boolean(None) | ||
| ScalarValue::UInt8(None) | ||
| ScalarValue::UInt16(None) | ||
| ScalarValue::UInt32(None) | ||
|
@@ -836,6 +846,7 @@ impl ScalarValue { | |
ScalarValue::iter_to_decimal_array(scalars, precision, scale)?; | ||
Arc::new(decimal_array) | ||
} | ||
DataType::Null => ScalarValue::iter_to_null_array(scalars), | ||
DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), | ||
DataType::Float32 => build_array_primitive!(Float32Array, Float32), | ||
DataType::Float64 => build_array_primitive!(Float64Array, Float64), | ||
|
@@ -968,6 +979,17 @@ impl ScalarValue { | |
Ok(array) | ||
} | ||
|
||
/// Builds a null array of `DataType::Null` whose length equals the number of | ||
/// scalars yielded by `scalars`. | ||
/// | ||
/// # Panics | ||
/// | ||
/// Reaches `unreachable!()` if any yielded element is not `ScalarValue::Null`. | ||
fn iter_to_null_array(scalars: impl IntoIterator<Item = ScalarValue>) -> ArrayRef { | ||
let length = | ||
scalars | ||
.into_iter() | ||
.fold(0usize, |r, element: ScalarValue| match element { | ||
ScalarValue::Null => r + 1, | ||
// Any other variant is treated as a broken invariant rather than a recoverable error. | ||
_ => unreachable!(), | ||
}); | ||
new_null_array(&DataType::Null, length) | ||
} | ||
|
||
fn iter_to_decimal_array( | ||
scalars: impl IntoIterator<Item = ScalarValue>, | ||
precision: &usize, | ||
|
@@ -1241,6 +1263,7 @@ impl ScalarValue { | |
Arc::new(StructArray::from(field_values)) | ||
} | ||
}, | ||
ScalarValue::Null => new_null_array(&DataType::Null, size), | ||
} | ||
} | ||
|
||
|
@@ -1266,6 +1289,7 @@ impl ScalarValue { | |
} | ||
|
||
Ok(match array.data_type() { | ||
DataType::Null => ScalarValue::Null, | ||
DataType::Decimal(precision, scale) => { | ||
ScalarValue::get_decimal_value_from_array(array, index, precision, scale) | ||
} | ||
|
@@ -1522,6 +1546,7 @@ impl ScalarValue { | |
eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val) | ||
} | ||
ScalarValue::Struct(_, _) => unimplemented!(), | ||
ScalarValue::Null => array.data().is_null(index), | ||
} | ||
} | ||
|
||
|
@@ -1743,6 +1768,7 @@ impl TryFrom<&DataType> for ScalarValue { | |
DataType::Struct(fields) => { | ||
ScalarValue::Struct(None, Box::new(fields.clone())) | ||
} | ||
DataType::Null => ScalarValue::Null, | ||
_ => { | ||
return Err(DataFusionError::NotImplemented(format!( | ||
"Can't create a scalar from data_type \"{:?}\"", | ||
|
@@ -1835,6 +1861,7 @@ impl fmt::Display for ScalarValue { | |
)?, | ||
None => write!(f, "NULL")?, | ||
}, | ||
ScalarValue::Null => write!(f, "NULL")?, | ||
}; | ||
Ok(()) | ||
} | ||
|
@@ -1902,6 +1929,7 @@ impl fmt::Debug for ScalarValue { | |
None => write!(f, "Struct(NULL)"), | ||
} | ||
} | ||
ScalarValue::Null => write!(f, "NULL"), | ||
} | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1491,7 +1491,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | |
SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n), | ||
SQLExpr::Value(Value::SingleQuotedString(s)) => Ok(lit(s)), | ||
SQLExpr::Value(Value::Null) => { | ||
Ok(Expr::Literal(ScalarValue::Utf8(None))) | ||
Ok(Expr::Literal(ScalarValue::Null)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🎉 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well this explains a lot of odd type coercion errors I have been seeing |
||
} | ||
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)), | ||
SQLExpr::UnaryOp { op, expr } => self.parse_sql_unary_op( | ||
|
@@ -1529,7 +1529,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | |
SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n), | ||
SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())), | ||
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)), | ||
SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))), | ||
SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)), | ||
SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction { | ||
fun: BuiltinScalarFunction::DatePart, | ||
args: vec![ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -121,25 +121,25 @@ async fn case_when_else_with_null_contant() -> Result<()> { | |
FROM t1"; | ||
let actual = execute_to_batches(&ctx, sql).await; | ||
let expected = vec![ | ||
"+----------------------------------------------------------------------------------------------+", | ||
"| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN Utf8(NULL) THEN Int64(2) ELSE Int64(999) END |", | ||
"+----------------------------------------------------------------------------------------------+", | ||
"| 1 |", | ||
"| 999 |", | ||
"| 999 |", | ||
"| 999 |", | ||
"+----------------------------------------------------------------------------------------------+", | ||
"+----------------------------------------------------------------------------------------+", | ||
"| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN NULL THEN Int64(2) ELSE Int64(999) END |", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the key difference here. Very nice 👍 |
||
"+----------------------------------------------------------------------------------------+", | ||
"| 1 |", | ||
"| 999 |", | ||
"| 999 |", | ||
"| 999 |", | ||
"+----------------------------------------------------------------------------------------+", | ||
]; | ||
assert_batches_eq!(expected, &actual); | ||
|
||
let sql = "SELECT CASE WHEN NULL THEN 'foo' ELSE 'bar' END"; | ||
let actual = execute_to_batches(&ctx, sql).await; | ||
let expected = vec![ | ||
"+------------------------------------------------------------+", | ||
"| CASE WHEN Utf8(NULL) THEN Utf8(\"foo\") ELSE Utf8(\"bar\") END |", | ||
"+------------------------------------------------------------+", | ||
"| bar |", | ||
"+------------------------------------------------------------+", | ||
"+------------------------------------------------------+", | ||
"| CASE WHEN NULL THEN Utf8(\"foo\") ELSE Utf8(\"bar\") END |", | ||
"+------------------------------------------------------+", | ||
"| bar |", | ||
"+------------------------------------------------------+", | ||
]; | ||
assert_batches_eq!(expected, &actual); | ||
Ok(()) | ||
|
@@ -347,11 +347,11 @@ async fn test_string_concat_operator() -> Result<()> { | |
let sql = "SELECT 'aa' || NULL || 'd'"; | ||
let actual = execute_to_batches(&ctx, sql).await; | ||
let expected = vec![ | ||
"+---------------------------------------+", | ||
"| Utf8(\"aa\") || Utf8(NULL) || Utf8(\"d\") |", | ||
"+---------------------------------------+", | ||
"| |", | ||
"+---------------------------------------+", | ||
"+---------------------------------+", | ||
"| Utf8(\"aa\") || NULL || Utf8(\"d\") |", | ||
"+---------------------------------+", | ||
"| |", | ||
"+---------------------------------+", | ||
]; | ||
assert_batches_eq!(expected, &actual); | ||
|
||
|
@@ -387,11 +387,11 @@ async fn test_not_expressions() -> Result<()> { | |
let sql = "SELECT null, not(null)"; | ||
let actual = execute_to_batches(&ctx, sql).await; | ||
let expected = vec![ | ||
"+------------+----------------+", | ||
"| Utf8(NULL) | NOT Utf8(NULL) |", | ||
"+------------+----------------+", | ||
"| | |", | ||
"+------------+----------------+", | ||
"+------+----------+", | ||
"| NULL | NOT NULL |", | ||
"+------+----------+", | ||
"| | |", | ||
"+------+----------+", | ||
]; | ||
assert_batches_eq!(expected, &actual); | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -829,7 +829,11 @@ async fn inner_join_nulls() { | |
let sql = "SELECT * FROM (SELECT null AS id1) t1 | ||
INNER JOIN (SELECT null AS id2) t2 ON id1 = id2"; | ||
|
||
let expected = vec!["++", "++"]; | ||
#[rustfmt::skip] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
let expected = vec![ | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This answer is not correct -- there should be no rows that match. This is because the join should produce rows where `id1 = id2` is true. However, `null = null` evaluates to null, not true. Here is the query in postgres: alamb=# SELECT * FROM (SELECT null AS id1) t1
INNER JOIN (SELECT null AS id2) t2 ON id1 = id2
alamb-# ;
id1 | id2
-----+-----
(0 rows)
|
||
"++", | ||
"++", | ||
]; | ||
|
||
let ctx = create_join_context_qualified().unwrap(); | ||
let actual = execute_to_batches(&ctx, sql).await; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -590,8 +590,30 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { | |
numerical_coercion(lhs_type, rhs_type) | ||
.or_else(|| dictionary_coercion(lhs_type, rhs_type)) | ||
.or_else(|| temporal_coercion(lhs_type, rhs_type)) | ||
.or_else(|| null_coercion(lhs_type, rhs_type)) | ||
} | ||
|
||
/// Coercion rules for the NULL type. Since NULL can be cast to most types in arrow, | ||
/// when either lhs or rhs is NULL and NULL can be cast to the type of the other side, | ||
/// the coercion is valid and the result is the other side's type. | ||
/// | ||
/// Returns `None` when neither side is `DataType::Null`, or when `can_cast_types` | ||
/// reports that NULL cannot be cast to the other side's type. | ||
fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. very cool |
||
match (lhs_type, rhs_type) { | ||
(DataType::Null, _) => { | ||
if can_cast_types(&DataType::Null, rhs_type) { | ||
Some(rhs_type.clone()) | ||
} else { | ||
None | ||
} | ||
} | ||
(_, DataType::Null) => { | ||
if can_cast_types(&DataType::Null, lhs_type) { | ||
Some(lhs_type.clone()) | ||
} else { | ||
None | ||
} | ||
} | ||
_ => None, | ||
} | ||
} | ||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍