From 13a1601a2f2ccab1662436e715055bdc2e8b0392 Mon Sep 17 00:00:00 2001 From: duripeng <453243496@qq.com> Date: Thu, 21 Apr 2022 01:25:12 +0800 Subject: [PATCH 1/2] introduce null --- datafusion/common/src/scalar.rs | 30 +++++++++++- datafusion/core/src/logical_plan/builder.rs | 2 +- .../core/src/physical_plan/hash_join.rs | 6 ++- .../core/src/physical_plan/hash_utils.rs | 20 ++++++++ datafusion/core/src/sql/planner.rs | 4 +- datafusion/core/tests/sql/expr.rs | 46 +++++++++--------- datafusion/core/tests/sql/functions.rs | 10 ++-- datafusion/core/tests/sql/joins.rs | 6 ++- datafusion/core/tests/sql/select.rs | 34 +++++++++++--- datafusion/expr/src/binary_rule.rs | 22 +++++++++ datafusion/expr/src/function.rs | 2 + datafusion/expr/src/type_coercion.rs | 40 +++++++++++----- .../physical-expr/src/expressions/binary.rs | 47 ++++++++++++++++++- .../physical-expr/src/expressions/in_list.rs | 4 ++ .../physical-expr/src/expressions/nullif.rs | 2 +- 15 files changed, 221 insertions(+), 54 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 03a59ff6d3db..4a7bc5337c5b 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -39,6 +39,8 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; /// This is the single-valued counter-part of arrow’s `Array`. 
#[derive(Clone)] pub enum ScalarValue { + /// represents `DataType::Null` (castable to/from any other type) + Null, /// true or false value Boolean(Option), /// 32bit float @@ -170,6 +172,8 @@ impl PartialEq for ScalarValue { (IntervalMonthDayNano(_), _) => false, (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), (Struct(_, _), _) => false, + (Null, Null) => true, + (Null, _) => false, } } } @@ -270,6 +274,8 @@ impl PartialOrd for ScalarValue { } } (Struct(_, _), _) => None, + (Null, Null) => Some(Ordering::Equal), + (Null, _) => None, } } } @@ -325,6 +331,8 @@ impl std::hash::Hash for ScalarValue { v.hash(state); t.hash(state); } + // stable hash for Null value + Null => 1.hash(state), } } } @@ -594,6 +602,7 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) } ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()), + ScalarValue::Null => DataType::Null, } } @@ -623,7 +632,8 @@ impl ScalarValue { pub fn is_null(&self) -> bool { matches!( *self, - ScalarValue::Boolean(None) + ScalarValue::Null + | ScalarValue::Boolean(None) | ScalarValue::UInt8(None) | ScalarValue::UInt16(None) | ScalarValue::UInt32(None) @@ -836,6 +846,7 @@ impl ScalarValue { ScalarValue::iter_to_decimal_array(scalars, precision, scale)?; Arc::new(decimal_array) } + DataType::Null => ScalarValue::iter_to_null_array(scalars), DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), DataType::Float32 => build_array_primitive!(Float32Array, Float32), DataType::Float64 => build_array_primitive!(Float64Array, Float64), @@ -968,6 +979,17 @@ impl ScalarValue { Ok(array) } + fn iter_to_null_array(scalars: impl IntoIterator) -> ArrayRef { + let length = + scalars + .into_iter() + .fold(0usize, |r, element: ScalarValue| match element { + ScalarValue::Null => r + 1, + _ => unreachable!(), + }); + new_null_array(&DataType::Null, length) + } + fn iter_to_decimal_array( scalars: impl IntoIterator, precision: &usize, @@ -1241,6 +1263,7 @@ impl 
ScalarValue { Arc::new(StructArray::from(field_values)) } }, + ScalarValue::Null => new_null_array(&DataType::Null, size), } } @@ -1266,6 +1289,7 @@ impl ScalarValue { } Ok(match array.data_type() { + DataType::Null => ScalarValue::Null, DataType::Decimal(precision, scale) => { ScalarValue::get_decimal_value_from_array(array, index, precision, scale) } @@ -1522,6 +1546,7 @@ impl ScalarValue { eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val) } ScalarValue::Struct(_, _) => unimplemented!(), + ScalarValue::Null => array.data().is_null(index), } } @@ -1743,6 +1768,7 @@ impl TryFrom<&DataType> for ScalarValue { DataType::Struct(fields) => { ScalarValue::Struct(None, Box::new(fields.clone())) } + DataType::Null => ScalarValue::Null, _ => { return Err(DataFusionError::NotImplemented(format!( "Can't create a scalar from data_type \"{:?}\"", @@ -1835,6 +1861,7 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, + ScalarValue::Null => write!(f, "NULL")?, }; Ok(()) } @@ -1902,6 +1929,7 @@ impl fmt::Debug for ScalarValue { None => write!(f, "Struct(NULL)"), } } + ScalarValue::Null => write!(f, "NULL"), } } } diff --git a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs index 1fbb1f5f9dfe..8a0ea6d6667f 100644 --- a/datafusion/core/src/logical_plan/builder.rs +++ b/datafusion/core/src/logical_plan/builder.rs @@ -155,7 +155,7 @@ impl LogicalPlanBuilder { .iter() .enumerate() .map(|(j, expr)| { - if let Expr::Literal(ScalarValue::Utf8(None)) = expr { + if let Expr::Literal(ScalarValue::Null) = expr { nulls.push((i, j)); Ok(field_types[j].clone()) } else { diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs index 8a4a342c11f6..ee763241aea7 100644 --- a/datafusion/core/src/physical_plan/hash_join.rs +++ b/datafusion/core/src/physical_plan/hash_join.rs @@ -817,7 +817,11 @@ fn equal_rows( .iter() .zip(right_arrays) .all(|(l, r)| match l.data_type() { 
- DataType::Null => true, + DataType::Null => { + // lhs and rhs are both `DataType::Null`, so the equality result + // is dependent on `null_equals_null` + null_equals_null + } DataType::Boolean => { equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null) } diff --git a/datafusion/core/src/physical_plan/hash_utils.rs b/datafusion/core/src/physical_plan/hash_utils.rs index 4e503b19e7bf..63c9ad522fab 100644 --- a/datafusion/core/src/physical_plan/hash_utils.rs +++ b/datafusion/core/src/physical_plan/hash_utils.rs @@ -39,6 +39,23 @@ fn combine_hashes(l: u64, r: u64) -> u64 { hash.wrapping_mul(37).wrapping_add(r) } +fn hash_null<'a>( + random_state: &RandomState, + hashes_buffer: &'a mut [u64], + mul_col: bool, +) { + if mul_col { + hashes_buffer.iter_mut().for_each(|hash| { + // stable hash for null value + *hash = combine_hashes(i128::get_hash(&1, random_state), *hash); + }) + } else { + hashes_buffer.iter_mut().for_each(|hash| { + *hash = i128::get_hash(&1, random_state); + }) + } +} + fn hash_decimal128<'a>( array: &ArrayRef, random_state: &RandomState, @@ -284,6 +301,9 @@ pub fn create_hashes<'a>( for col in arrays { match col.data_type() { + DataType::Null => { + hash_null(random_state, hashes_buffer, multi_col); + } DataType::Decimal(_, _) => { hash_decimal128(col, random_state, hashes_buffer, multi_col); } diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index 425c685f9351..5305bdf506bd 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -1491,7 +1491,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n), SQLExpr::Value(Value::SingleQuotedString(s)) => Ok(lit(s)), SQLExpr::Value(Value::Null) => { - Ok(Expr::Literal(ScalarValue::Utf8(None))) + Ok(Expr::Literal(ScalarValue::Null)) } SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)), SQLExpr::UnaryOp { op, expr } => self.parse_sql_unary_op( @@ -1529,7 +1529,7 @@ impl<'a, S: 
ContextProvider> SqlToRel<'a, S> { SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n), SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())), SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)), - SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))), + SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)), SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction { fun: BuiltinScalarFunction::DatePart, args: vec![ diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index e62acc502c55..1dffc2eb9366 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -121,25 +121,25 @@ async fn case_when_else_with_null_contant() -> Result<()> { FROM t1"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+----------------------------------------------------------------------------------------------+", - "| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN Utf8(NULL) THEN Int64(2) ELSE Int64(999) END |", - "+----------------------------------------------------------------------------------------------+", - "| 1 |", - "| 999 |", - "| 999 |", - "| 999 |", - "+----------------------------------------------------------------------------------------------+", + "+----------------------------------------------------------------------------------------+", + "| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN NULL THEN Int64(2) ELSE Int64(999) END |", + "+----------------------------------------------------------------------------------------+", + "| 1 |", + "| 999 |", + "| 999 |", + "| 999 |", + "+----------------------------------------------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); let sql = "SELECT CASE WHEN NULL THEN 'foo' ELSE 'bar' END"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+------------------------------------------------------------+", - "| CASE WHEN 
Utf8(NULL) THEN Utf8(\"foo\") ELSE Utf8(\"bar\") END |", - "+------------------------------------------------------------+", - "| bar |", - "+------------------------------------------------------------+", + "+------------------------------------------------------+", + "| CASE WHEN NULL THEN Utf8(\"foo\") ELSE Utf8(\"bar\") END |", + "+------------------------------------------------------+", + "| bar |", + "+------------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); Ok(()) @@ -347,11 +347,11 @@ async fn test_string_concat_operator() -> Result<()> { let sql = "SELECT 'aa' || NULL || 'd'"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+---------------------------------------+", - "| Utf8(\"aa\") || Utf8(NULL) || Utf8(\"d\") |", - "+---------------------------------------+", - "| |", - "+---------------------------------------+", + "+---------------------------------+", + "| Utf8(\"aa\") || NULL || Utf8(\"d\") |", + "+---------------------------------+", + "| |", + "+---------------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -387,11 +387,11 @@ async fn test_not_expressions() -> Result<()> { let sql = "SELECT null, not(null)"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+------------+----------------+", - "| Utf8(NULL) | NOT Utf8(NULL) |", - "+------------+----------------+", - "| | |", - "+------------+----------------+", + "+------+----------+", + "| NULL | NOT NULL |", + "+------+----------+", + "| | |", + "+------+----------+", ]; assert_batches_eq!(expected, &actual); diff --git a/datafusion/core/tests/sql/functions.rs b/datafusion/core/tests/sql/functions.rs index 857781aa35a3..396bd11940c1 100644 --- a/datafusion/core/tests/sql/functions.rs +++ b/datafusion/core/tests/sql/functions.rs @@ -176,11 +176,11 @@ async fn coalesce_static_value_with_null() -> Result<()> { let sql = "SELECT COALESCE(NULL, 'test')"; let actual = 
execute_to_batches(&ctx, sql).await; let expected = vec![ - "+-----------------------------------+", - "| coalesce(Utf8(NULL),Utf8(\"test\")) |", - "+-----------------------------------+", - "| test |", - "+-----------------------------------+", + "+-----------------------------+", + "| coalesce(NULL,Utf8(\"test\")) |", + "+-----------------------------+", + "| test |", + "+-----------------------------+", ]; assert_batches_eq!(expected, &actual); Ok(()) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index aaa8adac5061..312b687a60b3 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -829,7 +829,11 @@ async fn inner_join_nulls() { let sql = "SELECT * FROM (SELECT null AS id1) t1 INNER JOIN (SELECT null AS id2) t2 ON id1 = id2"; - let expected = vec!["++", "++"]; + #[rustfmt::skip] + let expected = vec![ + "++", + "++", + ]; let ctx = create_join_context_qualified().unwrap(); let actual = execute_to_batches(&ctx, sql).await; diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 747a9e05a7cf..4ab3a83be1ea 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -398,15 +398,37 @@ async fn select_distinct_from() { 1 IS NOT DISTINCT FROM CAST(NULL as INT) as c, 1 IS NOT DISTINCT FROM 1 as d, NULL IS DISTINCT FROM NULL as e, - NULL IS NOT DISTINCT FROM NULL as f + NULL IS NOT DISTINCT FROM NULL as f, + NULL is DISTINCT FROM 1 as g, + NULL is NOT DISTINCT FROM 1 as h "; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+------+-------+-------+------+-------+------+", - "| a | b | c | d | e | f |", - "+------+-------+-------+------+-------+------+", - "| true | false | false | true | false | true |", - "+------+-------+-------+------+-------+------+", + "+------+-------+-------+------+-------+------+------+-------+", + "| a | b | c | d | e | f | g | h |", + 
"+------+-------+-------+------+-------+------+------+-------+", + "| true | false | false | true | false | true | true | false |", + "+------+-------+-------+------+-------+------+------+-------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "select + NULL IS DISTINCT FROM NULL as a, + NULL IS NOT DISTINCT FROM NULL as b, + NULL is DISTINCT FROM 1 as c, + NULL is NOT DISTINCT FROM 1 as d, + 1 IS DISTINCT FROM CAST(NULL as INT) as e, + 1 IS DISTINCT FROM 1 as f, + 1 IS NOT DISTINCT FROM CAST(NULL as INT) as g, + 1 IS NOT DISTINCT FROM 1 as h + "; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-------+------+------+-------+------+-------+-------+------+", + "| a | b | c | d | e | f | g | h |", + "+-------+------+------+-------+------+-------+-------+------+", + "| false | true | true | false | true | false | false | true |", + "+-------+------+------+-------+------+-------+-------+------+", ]; assert_batches_eq!(expected, &actual); } diff --git a/datafusion/expr/src/binary_rule.rs b/datafusion/expr/src/binary_rule.rs index ad46770f107b..63a9712fd5ac 100644 --- a/datafusion/expr/src/binary_rule.rs +++ b/datafusion/expr/src/binary_rule.rs @@ -590,8 +590,30 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { numerical_coercion(lhs_type, rhs_type) .or_else(|| dictionary_coercion(lhs_type, rhs_type)) .or_else(|| temporal_coercion(lhs_type, rhs_type)) + .or_else(|| null_coercion(lhs_type, rhs_type)) } +/// coercion rules from NULL type. Since NULL can be casted to most of types in arrow, +/// either lhs or rhs is NULL, if NULL can be casted to type of the other side, the coecion is valid. 
+fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + match (lhs_type, rhs_type) { + (DataType::Null, _) => { + if can_cast_types(&DataType::Null, rhs_type) { + Some(rhs_type.clone()) + } else { + None + } + } + (_, DataType::Null) => { + if can_cast_types(&DataType::Null, lhs_type) { + Some(lhs_type.clone()) + } else { + None + } + } + _ => None, + } +} #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 385e247bd3a6..d631e0f83bb6 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -58,6 +58,7 @@ macro_rules! make_utf8_to_return_type { Ok(match arg_type { DataType::LargeUtf8 => $largeUtf8Type, DataType::Utf8 => $utf8Type, + DataType::Null => DataType::Null, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal(format!( @@ -209,6 +210,7 @@ pub fn return_type( DataType::Utf8 => { DataType::List(Box::new(Field::new("item", DataType::Utf8, true))) } + DataType::Null => DataType::Null, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( diff --git a/datafusion/expr/src/type_coercion.rs b/datafusion/expr/src/type_coercion.rs index 8cea256f1829..33a540d6f1ef 100644 --- a/datafusion/expr/src/type_coercion.rs +++ b/datafusion/expr/src/type_coercion.rs @@ -31,7 +31,10 @@ //! use crate::{Signature, TypeSignature}; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::{ + compute::can_cast_types, + datatypes::{DataType, TimeUnit}, +}; use datafusion_common::{DataFusionError, Result}; /// Returns the data types that each argument must be coerced to match @@ -142,25 +145,35 @@ fn maybe_data_types( /// See the module level documentation for more detail on coercion. 
pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { use self::DataType::*; + // Null can convert to most of types match type_into { - Int8 => matches!(type_from, Int8), - Int16 => matches!(type_from, Int8 | Int16 | UInt8), - Int32 => matches!(type_from, Int8 | Int16 | Int32 | UInt8 | UInt16), + Int8 => matches!(type_from, Null | Int8), + Int16 => matches!(type_from, Null | Int8 | Int16 | UInt8), + Int32 => matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16), Int64 => matches!( type_from, - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 + Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 ), - UInt8 => matches!(type_from, UInt8), - UInt16 => matches!(type_from, UInt8 | UInt16), - UInt32 => matches!(type_from, UInt8 | UInt16 | UInt32), - UInt64 => matches!(type_from, UInt8 | UInt16 | UInt32 | UInt64), + UInt8 => matches!(type_from, Null | UInt8), + UInt16 => matches!(type_from, Null | UInt8 | UInt16), + UInt32 => matches!(type_from, Null | UInt8 | UInt16 | UInt32), + UInt64 => matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64), Float32 => matches!( type_from, - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 + Null | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float32 ), Float64 => matches!( type_from, - Int8 | Int16 + Null | Int8 + | Int16 | Int32 | Int64 | UInt8 @@ -171,8 +184,11 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { | Float64 | Decimal(_, _) ), - Timestamp(TimeUnit::Nanosecond, None) => matches!(type_from, Timestamp(_, None)), + Timestamp(TimeUnit::Nanosecond, None) => { + matches!(type_from, Null | Timestamp(_, None)) + } Utf8 | LargeUtf8 => true, + Null => can_cast_types(type_from, type_into), _ => false, } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 6dafb43f92b8..060f30cb2d77 100644 --- 
a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -604,6 +604,20 @@ macro_rules! compute_decimal_op { }}; } +macro_rules! compute_null_op { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + let rr = $RIGHT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + Ok(Arc::new(paste::expr! {[<$OP _null>]}(&ll, &rr)?)) + }}; +} + /// Invoke a compute kernel on a pair of binary data arrays macro_rules! compute_utf8_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ @@ -909,6 +923,7 @@ macro_rules! binary_array_op_scalar { macro_rules! binary_array_op { ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ match $LEFT.data_type() { + DataType::Null => compute_null_op!($LEFT, $RIGHT, $OP, NullArray), DataType::Decimal(_,_) => compute_decimal_op!($LEFT, $RIGHT, $OP, DecimalArray), DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array), DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array), @@ -1261,7 +1276,16 @@ impl BinaryExpr { Operator::GtEq => gt_eq_dyn(&left, &right), Operator::Eq => eq_dyn(&left, &right), Operator::NotEq => neq_dyn(&left, &right), - Operator::IsDistinctFrom => binary_array_op!(left, right, is_distinct_from), + Operator::IsDistinctFrom => { + match (left_data_type, right_data_type) { + // exchange lhs and rhs when lhs is Null, since `binary_array_op` is + // always try to down cast array according to $LEFT expression. 
+ (DataType::Null, _) => { + binary_array_op!(right, left, is_distinct_from) + } + _ => binary_array_op!(left, right, is_distinct_from), + } + } Operator::IsNotDistinctFrom => { binary_array_op!(left, right, is_not_distinct_from) } @@ -1336,6 +1360,27 @@ fn is_distinct_from_utf8( .collect()) } +fn is_distinct_from_null(left: &NullArray, _right: &NullArray) -> Result { + let length = left.len(); + make_boolean_array(length, false) +} + +fn is_not_distinct_from_null( + left: &NullArray, + _right: &NullArray, +) -> Result { + let length = left.len(); + make_boolean_array(length, true) +} + +pub fn eq_null(left: &NullArray, _right: &NullArray) -> Result { + Ok((0..left.len()).into_iter().map(|_| None).collect()) +} + +fn make_boolean_array(length: usize, value: bool) -> Result { + Ok((0..length).into_iter().map(|_| Some(value)).collect()) +} + fn is_not_distinct_from( left: &PrimitiveArray, right: &PrimitiveArray, diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index a6894b938ff6..7094a718d000 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -629,6 +629,10 @@ impl PhysicalExpr for InListExpr { DataType::LargeUtf8 => { self.compare_utf8::(array, list_values, self.negated) } + DataType::Null => { + let null_array = new_null_array(&DataType::Boolean, array.len()); + Ok(ColumnarValue::Array(Arc::new(null_array))) + } datatype => Result::Err(DataFusionError::NotImplemented(format!( "InList does not support datatype {:?}.", datatype diff --git a/datafusion/physical-expr/src/expressions/nullif.rs b/datafusion/physical-expr/src/expressions/nullif.rs index 307e3a07f394..2d1f3654d241 100644 --- a/datafusion/physical-expr/src/expressions/nullif.rs +++ b/datafusion/physical-expr/src/expressions/nullif.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use crate::expressions::binary::{eq_decimal, eq_decimal_scalar}; +use 
crate::expressions::binary::{eq_decimal, eq_decimal_scalar, eq_null}; use arrow::array::Array; use arrow::array::*; use arrow::compute::kernels::boolean::nullif; From 109bc0e8e57f40db9e6b2578e1313d42cde8bf4c Mon Sep 17 00:00:00 2001 From: duripeng <453243496@qq.com> Date: Fri, 6 May 2022 02:57:45 +0800 Subject: [PATCH 2/2] fix fmt --- datafusion/core/src/physical_plan/hash_utils.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/datafusion/core/src/physical_plan/hash_utils.rs b/datafusion/core/src/physical_plan/hash_utils.rs index 63c9ad522fab..2ca1fa3df9d1 100644 --- a/datafusion/core/src/physical_plan/hash_utils.rs +++ b/datafusion/core/src/physical_plan/hash_utils.rs @@ -39,11 +39,7 @@ fn combine_hashes(l: u64, r: u64) -> u64 { hash.wrapping_mul(37).wrapping_add(r) } -fn hash_null<'a>( - random_state: &RandomState, - hashes_buffer: &'a mut [u64], - mul_col: bool, -) { +fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) { if mul_col { hashes_buffer.iter_mut().for_each(|hash| { // stable hash for null value