diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 9ccddbfce068..ae3e32c3107d 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -291,6 +291,7 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { + match (lhs_type, rhs_type) { + (DataType::Boolean, other_type) | (other_type, DataType::Boolean) => { + if can_cast_types(&DataType::Boolean, other_type) { + Some(other_type.to_owned()) + } else { + None + } + } + _ => None, + } +} + /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation /// where one both are numeric fn comparison_binary_numeric_coercion( diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 6628e8961e26..28358c2ff633 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1100,9 +1100,8 @@ mod test { let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, ""); - let err = ret.unwrap_err().to_string(); - assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}"); + let expected = "Projection: CAST(a AS Boolean) IS TRUE\n EmptyRelation"; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; // is not true let expr = col("a").is_not_true(); @@ -1202,9 +1201,8 @@ mod test { let empty = empty_with_type(DataType::Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected); - let err = ret.unwrap_err().to_string(); - assert!(err.contains("Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"), "{err}"); + let expected = "Projection: CAST(a AS Boolean) IS UNKNOWN\n EmptyRelation"; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; // is not unknown let expr = col("a").is_not_unknown(); diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 5fcfd61d90e4..15b2b81499c2 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -799,69 +799,6 @@ mod tests { Ok(batch) } - #[test] - fn case_test_incompatible() -> Result<()> { - // 1 then is int64 - // 2 then is boolean - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END - let when1 = binary( - col("a", &schema)?, - Operator::Eq, - lit("foo"), - &batch.schema(), - )?; - let then1 = lit(123i32); - let when2 = binary( - col("a", &schema)?, - Operator::Eq, - lit("bar"), - &batch.schema(), - )?; - let then2 = lit(true); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when1, then1), (when2, then2)], - None, - schema.as_ref(), - ); - assert!(expr.is_err()); - - // then 1 is int32 - // then 2 is int64 - // else is float - // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END - let when1 = binary( - col("a", &schema)?, - Operator::Eq, - lit("foo"), - &batch.schema(), - )?; - let then1 = lit(123i32); - let when2 = binary( - col("a", &schema)?, - Operator::Eq, - lit("bar"), - &batch.schema(), - )?; - let then2 = lit(456i64); - let else_expr = lit(1.23f64); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when1, then1), (when2, then2)], - Some(else_expr), - schema.as_ref(), - ); - assert!(expr.is_ok()); - let result_type = expr.unwrap().data_type(schema.as_ref())?; - assert_eq!(DataType::Float64, result_type); - Ok(()) - } - #[test] fn case_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index d5b06bcf815f..304228070bed 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2190,18 +2190,6 @@ fn union_with_aliases() { quick_test(sql, expected); } -#[test] -fn union_with_incompatible_data_types() { - let sql = "SELECT 'a' a UNION ALL SELECT true a"; - let err = logical_plan(sql) - .expect_err("query should have failed") - .strip_backtrace(); - assert_eq!( - "Error during planning: UNION Column a (type: Boolean) is not compatible with column a (type: Utf8)", - err - ); -} - #[test] fn empty_over() { let sql = "SELECT order_id, MAX(order_id) OVER () from orders"; diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 9e3ac3bf08f6..a41d9f5ee737 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -557,10 +557,16 @@ select column1[0:5], column2[0:3], column3[0:9] from arrays; ## make_array (aliases: `make_list`) # make_array scalar function #1 -query ??? -select make_array(1, 2, 3), make_array(1.0, 2.0, 3.0), make_array('h', 'e', 'l', 'l', 'o'); ----- -[1, 2, 3] [1.0, 2.0, 3.0] [h, e, l, l, o] +query ?????? +select + make_array(1, 2, 3), + make_array(1.0, 2.0, 3.0), + make_array('h', 'e', 'l', 'l', 'o'), + make_array(true, 1, 2, false), + make_array(true, 1, 2.3, false), + make_array(true, 1, 2.3, false, '4'); +---- +[1, 2, 3] [1.0, 2.0, 3.0] [h, e, l, l, o] [1, 1, 2, 0] [1.0, 1.0, 2.3, 0.0] [true, 1, 2.3, false, 4] # make_array scalar function #2 query ??? diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index ecb7fe13fcf4..6e81161d62c9 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1296,6 +1296,27 @@ NULL 123 NULL +# integer with float +query I +select case when a = 'rust' then arrow_cast(1, 'Int32') when a == 'c++' then arrow_cast(2, 'Int64') else 1.5 end from (values('python')) as t(a); +---- +1.5 + +# integer with boolean +query I +select case when a = 'rust' then 1 when a == 'c++' then false end from (values('c++')) as t(a); +---- +0 + +# type coercion not supported in case expr (boolean <-> timestamp) +# +# DataFusion error: type_coercion +# caused by +# Error during planning: Failed to coerce then ([Timestamp(Nanosecond, None), Boolean]) and else (None) to common types in CASE WHEN expression +query error +select case when a = 'foo' then arrow_cast(500, 'Timestamp(Nanosecond, None)') when a = 'bar' then true end from (values('foo')) as t(a); + + # csv_query_sum_cast() { statement ok @@ -1926,3 +1947,14 @@ A true B false C false D false + +# bool_coercion +query BBBB +select 1 == true, false == 0, 1.0 == true, false == 'false'; +---- +true true true true + +query BBBB +select 2 > true, false is not distinct from 0, 1.0 >= true, true is distinct from 'false'; +---- +true true true true diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 0f255cdb9fb9..f81873ed8570 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -157,6 +157,24 @@ SELECT 1 UNION SELECT 2 1 2 +# union bool with string +query T +SELECT 'a' a UNION ALL SELECT true a +---- +true +a + +# union bool with integer +query I +SELECT 123 a UNION ALL SELECT true a +---- +1 +123 + +# union incompatible types (bool with interval) +query error DataFusion error: Error during planning: UNION Column a \(type: Interval\(MonthDayNano\)\) is not compatible with column a \(type: Boolean\) +SELECT true a UNION SELECT interval '1 minute' a + # union_with_except_input query T rowsort (