diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 58579bcf9cae..28f796ba651c 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -39,6 +39,8 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; /// This is the single-valued counter-part of arrow’s `Array`. #[derive(Clone)] pub enum ScalarValue { + /// represents `DataType::Null` (castable to/from any other type) + Null, /// true or false value Boolean(Option), /// 32bit float @@ -170,6 +172,8 @@ impl PartialEq for ScalarValue { (IntervalMonthDayNano(_), _) => false, (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), (Struct(_, _), _) => false, + (Null, Null) => true, + (Null, _) => false, } } } @@ -270,6 +274,8 @@ impl PartialOrd for ScalarValue { } } (Struct(_, _), _) => None, + (Null, Null) => Some(Ordering::Equal), + (Null, _) => None, } } } @@ -325,6 +331,8 @@ impl std::hash::Hash for ScalarValue { v.hash(state); t.hash(state); } + // stable hash for Null value + Null => 1.hash(state), } } } @@ -594,6 +602,7 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) } ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()), + ScalarValue::Null => DataType::Null, } } @@ -623,7 +632,8 @@ impl ScalarValue { pub fn is_null(&self) -> bool { matches!( *self, - ScalarValue::Boolean(None) + ScalarValue::Null + | ScalarValue::Boolean(None) | ScalarValue::UInt8(None) | ScalarValue::UInt16(None) | ScalarValue::UInt32(None) @@ -847,6 +857,7 @@ impl ScalarValue { ScalarValue::iter_to_decimal_array(scalars, precision, scale)?; Arc::new(decimal_array) } + DataType::Null => ScalarValue::iter_to_null_array(scalars), DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), DataType::Float32 => build_array_primitive!(Float32Array, Float32), DataType::Float64 => build_array_primitive!(Float64Array, Float64), @@ -979,6 +990,17 @@ impl ScalarValue { Ok(array) } + fn iter_to_null_array(scalars: impl 
IntoIterator) -> ArrayRef { + let length = + scalars + .into_iter() + .fold(0usize, |r, element: ScalarValue| match element { + ScalarValue::Null => r + 1, + _ => unreachable!(), + }); + new_null_array(&DataType::Null, length) + } + fn iter_to_decimal_array( scalars: impl IntoIterator, precision: &usize, @@ -1252,6 +1274,7 @@ impl ScalarValue { Arc::new(StructArray::from(field_values)) } }, + ScalarValue::Null => new_null_array(&DataType::Null, size), } } @@ -1277,6 +1300,7 @@ impl ScalarValue { } Ok(match array.data_type() { + DataType::Null => ScalarValue::Null, DataType::Decimal(precision, scale) => { ScalarValue::get_decimal_value_from_array(array, index, precision, scale) } @@ -1530,6 +1554,7 @@ impl ScalarValue { eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val) } ScalarValue::Struct(_, _) => unimplemented!(), + ScalarValue::Null => array.data().is_null(index), } } @@ -1760,6 +1785,7 @@ impl TryFrom<&DataType> for ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) => { ScalarValue::IntervalMonthDayNano(None) } + DataType::Null => ScalarValue::Null, _ => { return Err(DataFusionError::NotImplemented(format!( "Can't create a scalar from data_type \"{:?}\"", @@ -1852,6 +1878,7 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, + ScalarValue::Null => write!(f, "NULL")?, }; Ok(()) } @@ -1919,6 +1946,7 @@ impl fmt::Debug for ScalarValue { None => write!(f, "Struct(NULL)"), } } + ScalarValue::Null => write!(f, "NULL"), } } } diff --git a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs index 808af572efbb..10bbed971cc2 100644 --- a/datafusion/core/src/logical_plan/builder.rs +++ b/datafusion/core/src/logical_plan/builder.rs @@ -154,7 +154,7 @@ impl LogicalPlanBuilder { .iter() .enumerate() .map(|(j, expr)| { - if let Expr::Literal(ScalarValue::Utf8(None)) = expr { + if let Expr::Literal(ScalarValue::Null) = expr { nulls.push((i, j)); Ok(field_types[j].clone()) } else { diff --git 
a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs index ce371fecf070..4fa92c3cb98a 100644 --- a/datafusion/core/src/physical_plan/hash_join.rs +++ b/datafusion/core/src/physical_plan/hash_join.rs @@ -845,7 +845,11 @@ fn equal_rows( .iter() .zip(right_arrays) .all(|(l, r)| match l.data_type() { - DataType::Null => true, + DataType::Null => { + // lhs and rhs are both `DataType::Null`, so the equality result + // is dependent on `null_equals_null` + null_equals_null + } DataType::Boolean => { equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null) } diff --git a/datafusion/core/src/physical_plan/hash_utils.rs b/datafusion/core/src/physical_plan/hash_utils.rs index 4acc84fcdc1e..9562e900298a 100644 --- a/datafusion/core/src/physical_plan/hash_utils.rs +++ b/datafusion/core/src/physical_plan/hash_utils.rs @@ -39,6 +39,19 @@ fn combine_hashes(l: u64, r: u64) -> u64 { hash.wrapping_mul(37).wrapping_add(r) } +fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) { + if mul_col { + hashes_buffer.iter_mut().for_each(|hash| { + // stable hash for null value + *hash = combine_hashes(i128::get_hash(&1, random_state), *hash); + }) + } else { + hashes_buffer.iter_mut().for_each(|hash| { + *hash = i128::get_hash(&1, random_state); + }) + } +} + fn hash_decimal128<'a>( array: &ArrayRef, random_state: &RandomState, @@ -310,6 +323,9 @@ pub fn create_hashes<'a>( for col in arrays { match col.data_type() { + DataType::Null => { + hash_null(random_state, hashes_buffer, multi_col); + } DataType::Decimal(_, _) => { hash_decimal128(col, random_state, hashes_buffer, multi_col); } diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index dcd4b7d25787..59938e656f7d 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -1681,7 +1681,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Value(Value::Number(n, _)) => 
parse_sql_number(&n), SQLExpr::Value(Value::SingleQuotedString(s)) => Ok(lit(s)), SQLExpr::Value(Value::Null) => { - Ok(Expr::Literal(ScalarValue::Utf8(None))) + Ok(Expr::Literal(ScalarValue::Null)) } SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)), SQLExpr::UnaryOp { op, expr } => { @@ -1707,7 +1707,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())), SQLExpr::Value(Value::EscapedStringLiteral(ref s)) => Ok(lit(s.clone())), SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)), - SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))), + SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)), SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction { fun: BuiltinScalarFunction::DatePart, args: vec![ @@ -4259,9 +4259,9 @@ mod tests { fn union_with_null() { let sql = "SELECT NULL a UNION ALL SELECT 1.1 a"; let expected = "Union\ - \n Projection: Utf8(NULL) AS a\ + \n Projection: CAST(NULL AS Float64) AS a\ \n EmptyRelation\ - \n Projection: CAST(Float64(1.1) AS Utf8) AS a\ + \n Projection: Float64(1.1) AS a\ \n EmptyRelation"; quick_test(sql, expected); } diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index d9ee3b062c44..ac24940e46ee 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -362,11 +362,11 @@ async fn test_string_concat_operator() -> Result<()> { let sql = "SELECT 'aa' || NULL || 'd'"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+---------------------------------------+", - "| Utf8(\"aa\") || Utf8(NULL) || Utf8(\"d\") |", - "+---------------------------------------+", - "| |", - "+---------------------------------------+", + "+---------------------------------+", + "| Utf8(\"aa\") || NULL || Utf8(\"d\") |", + "+---------------------------------+", + "| |", + "+---------------------------------+", ]; assert_batches_eq!(expected, &actual); diff --git 
a/datafusion/core/tests/sql/functions.rs b/datafusion/core/tests/sql/functions.rs index e5285d049905..847f53832b18 100644 --- a/datafusion/core/tests/sql/functions.rs +++ b/datafusion/core/tests/sql/functions.rs @@ -197,11 +197,11 @@ async fn coalesce_static_value_with_null() -> Result<()> { let sql = "SELECT COALESCE(NULL, 'test')"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+-----------------------------------+", - "| coalesce(Utf8(NULL),Utf8(\"test\")) |", - "+-----------------------------------+", - "| test |", - "+-----------------------------------+", + "+-----------------------------+", + "| coalesce(NULL,Utf8(\"test\")) |", + "+-----------------------------+", + "| test |", + "+-----------------------------+", ]; assert_batches_eq!(expected, &actual); Ok(()) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 735f9472d325..352215f12338 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -840,7 +840,11 @@ async fn inner_join_nulls() { let sql = "SELECT * FROM (SELECT null AS id1) t1 INNER JOIN (SELECT null AS id2) t2 ON id1 = id2"; - let expected = vec!["++", "++"]; + #[rustfmt::skip] + let expected = vec![ + "++", + "++", + ]; let ctx = create_join_context_qualified("t1", "t2").unwrap(); let actual = execute_to_batches(&ctx, sql).await; diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 45ea0206dc2a..1e5307b4491c 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -398,15 +398,37 @@ async fn select_distinct_from() { 1 IS NOT DISTINCT FROM CAST(NULL as INT) as c, 1 IS NOT DISTINCT FROM 1 as d, NULL IS DISTINCT FROM NULL as e, - NULL IS NOT DISTINCT FROM NULL as f + NULL IS NOT DISTINCT FROM NULL as f, + NULL is DISTINCT FROM 1 as g, + NULL is NOT DISTINCT FROM 1 as h "; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - 
"+------+-------+-------+------+-------+------+", - "| a | b | c | d | e | f |", - "+------+-------+-------+------+-------+------+", - "| true | false | false | true | false | true |", - "+------+-------+-------+------+-------+------+", + "+------+-------+-------+------+-------+------+------+-------+", + "| a | b | c | d | e | f | g | h |", + "+------+-------+-------+------+-------+------+------+-------+", + "| true | false | false | true | false | true | true | false |", + "+------+-------+-------+------+-------+------+------+-------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "select + NULL IS DISTINCT FROM NULL as a, + NULL IS NOT DISTINCT FROM NULL as b, + NULL is DISTINCT FROM 1 as c, + NULL is NOT DISTINCT FROM 1 as d, + 1 IS DISTINCT FROM CAST(NULL as INT) as e, + 1 IS DISTINCT FROM 1 as f, + 1 IS NOT DISTINCT FROM CAST(NULL as INT) as g, + 1 IS NOT DISTINCT FROM 1 as h + "; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-------+------+------+-------+------+-------+-------+------+", + "| a | b | c | d | e | f | g | h |", + "+-------+------+------+-------+------+-------+-------+------+", + "| false | true | true | false | true | false | false | true |", + "+-------+------+------+-------+------+-------+-------+------+", ]; assert_batches_eq!(expected, &actual); } diff --git a/datafusion/expr/src/binary_rule.rs b/datafusion/expr/src/binary_rule.rs index af7d69d619bc..b098266abc81 100644 --- a/datafusion/expr/src/binary_rule.rs +++ b/datafusion/expr/src/binary_rule.rs @@ -625,6 +625,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { numerical_coercion(lhs_type, rhs_type) .or_else(|| dictionary_coercion(lhs_type, rhs_type)) .or_else(|| temporal_coercion(lhs_type, rhs_type)) + .or_else(|| null_coercion(lhs_type, rhs_type)) } /// Coercion rule for interval diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 4dcb3d9d61b4..e6cdfa428f7b 100644 --- 
a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -65,6 +65,7 @@ macro_rules! make_utf8_to_return_type { Ok(match arg_type { DataType::LargeUtf8 => $largeUtf8Type, DataType::Utf8 => $utf8Type, + DataType::Null => DataType::Null, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal(format!( @@ -234,6 +235,7 @@ pub fn return_type( DataType::Utf8 => { DataType::List(Box::new(Field::new("item", DataType::Utf8, true))) } + DataType::Null => DataType::Null, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( diff --git a/datafusion/expr/src/type_coercion.rs b/datafusion/expr/src/type_coercion.rs index e57de663390b..1290e858ecda 100644 --- a/datafusion/expr/src/type_coercion.rs +++ b/datafusion/expr/src/type_coercion.rs @@ -31,7 +31,10 @@ //! use crate::{Signature, TypeSignature}; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::{ + compute::can_cast_types, + datatypes::{DataType, TimeUnit}, +}; use datafusion_common::{DataFusionError, Result}; /// Returns the data types that each argument must be coerced to match @@ -142,25 +145,35 @@ fn maybe_data_types( /// See the module level documentation for more detail on coercion. 
pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { use self::DataType::*; + // Null can be converted to most types match type_into { - Int8 => matches!(type_from, Int8), - Int16 => matches!(type_from, Int8 | Int16 | UInt8), - Int32 => matches!(type_from, Int8 | Int16 | Int32 | UInt8 | UInt16), + Int8 => matches!(type_from, Null | Int8), + Int16 => matches!(type_from, Null | Int8 | Int16 | UInt8), + Int32 => matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16), Int64 => matches!( type_from, - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 + Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 ), - UInt8 => matches!(type_from, UInt8), - UInt16 => matches!(type_from, UInt8 | UInt16), - UInt32 => matches!(type_from, UInt8 | UInt16 | UInt32), - UInt64 => matches!(type_from, UInt8 | UInt16 | UInt32 | UInt64), + UInt8 => matches!(type_from, Null | UInt8), + UInt16 => matches!(type_from, Null | UInt8 | UInt16), + UInt32 => matches!(type_from, Null | UInt8 | UInt16 | UInt32), + UInt64 => matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64), Float32 => matches!( type_from, - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 + Null | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float32 ), Float64 => matches!( type_from, - Int8 | Int16 + Null | Int8 + | Int16 | Int32 | Int64 | UInt8 @@ -171,9 +184,10 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { | Float64 ), Timestamp(TimeUnit::Nanosecond, None) => { - matches!(type_from, Timestamp(_, None) | Date32) + matches!(type_from, Null | Timestamp(_, None) | Date32) } Utf8 | LargeUtf8 => true, + Null => can_cast_types(type_from, type_into), _ => false, } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 81aee81e2823..6f1f7e629ec7 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ 
b/datafusion/physical-expr/src/expressions/binary.rs @@ -804,6 +804,20 @@ macro_rules! compute_decimal_op { }}; } +macro_rules! compute_null_op { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + let rr = $RIGHT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + Ok(Arc::new(paste::expr! {[<$OP _null>]}(&ll, &rr)?)) + }}; +} + /// Invoke a compute kernel on a pair of binary data arrays macro_rules! compute_utf8_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ @@ -1103,6 +1117,7 @@ macro_rules! binary_array_op_scalar { macro_rules! binary_array_op { ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ match $LEFT.data_type() { + DataType::Null => compute_null_op!($LEFT, $RIGHT, $OP, NullArray), DataType::Decimal(_,_) => compute_decimal_op!($LEFT, $RIGHT, $OP, DecimalArray), DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array), DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array), @@ -1484,7 +1499,16 @@ impl BinaryExpr { Operator::GtEq => gt_eq_dyn(&left, &right), Operator::Eq => eq_dyn(&left, &right), Operator::NotEq => neq_dyn(&left, &right), - Operator::IsDistinctFrom => binary_array_op!(left, right, is_distinct_from), + Operator::IsDistinctFrom => { + match (left_data_type, right_data_type) { + // exchange lhs and rhs when lhs is Null, since `binary_array_op` + // always tries to downcast arrays according to the $LEFT expression. 
+ (DataType::Null, _) => { + binary_array_op!(right, left, is_distinct_from) + } + _ => binary_array_op!(left, right, is_distinct_from), + } + } Operator::IsNotDistinctFrom => { binary_array_op!(left, right, is_not_distinct_from) } @@ -1561,6 +1585,27 @@ fn is_distinct_from_utf8( .collect()) } +fn is_distinct_from_null(left: &NullArray, _right: &NullArray) -> Result { + let length = left.len(); + make_boolean_array(length, false) +} + +fn is_not_distinct_from_null( + left: &NullArray, + _right: &NullArray, +) -> Result { + let length = left.len(); + make_boolean_array(length, true) +} + +pub fn eq_null(left: &NullArray, _right: &NullArray) -> Result { + Ok((0..left.len()).into_iter().map(|_| None).collect()) +} + +fn make_boolean_array(length: usize, value: bool) -> Result { + Ok((0..length).into_iter().map(|_| Some(value)).collect()) +} + fn is_not_distinct_from( left: &PrimitiveArray, right: &PrimitiveArray, diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index bfac11a5c1ac..93c141453076 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -453,6 +453,10 @@ impl PhysicalExpr for InListExpr { DataType::LargeUtf8 => { self.compare_utf8::(array, list_values, self.negated) } + DataType::Null => { + let null_array = new_null_array(&DataType::Boolean, array.len()); + Ok(ColumnarValue::Array(Arc::new(null_array))) + } DataType::Timestamp(unit, _) => match unit { TimeUnit::Second => make_contains_primitive!( array, diff --git a/datafusion/physical-expr/src/expressions/nullif.rs b/datafusion/physical-expr/src/expressions/nullif.rs index c2f3dcea0dd2..2999fe4135df 100644 --- a/datafusion/physical-expr/src/expressions/nullif.rs +++ b/datafusion/physical-expr/src/expressions/nullif.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use crate::expressions::binary::{eq_decimal, eq_decimal_scalar}; +use crate::expressions::binary::{eq_decimal, 
eq_decimal_scalar, eq_null}; use arrow::array::Array; use arrow::array::*; use arrow::compute::kernels::boolean::nullif;