Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Boolean Coercion #8331

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
return Some(lhs_type.clone());
}
comparison_binary_numeric_coercion(lhs_type, rhs_type)
.or_else(|| bool_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type, true))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
.or_else(|| string_coercion(lhs_type, rhs_type))
Expand Down Expand Up @@ -353,6 +354,20 @@ fn string_temporal_coercion(
}
}

/// Coerce `Boolean` to other larger types, like Numeric as `1` or String as "true"
fn bool_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
match (lhs_type, rhs_type) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this also needs to handle the case where rhs_type is DataType::Boolean.

I would expect both of the following queries to work and return the same thing:

DataFusion CLI v33.0.0
❯ select 1 = true;
Error during planning: Cannot infer common argument type for comparison operation Int64 = Boolean
❯ select true = 1;
+--------------------------+
| Boolean(true) = Int64(1) |
+--------------------------+
| true                     |
+--------------------------+
1 row in set. Query took 0.006 seconds.

(DataType::Boolean, other_type) | (other_type, DataType::Boolean) => {
if can_cast_types(&DataType::Boolean, other_type) {
Some(other_type.to_owned())
} else {
None
}
}
_ => None,
}
}

/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
/// where both are numeric
fn comparison_binary_numeric_coercion(
Expand Down
10 changes: 4 additions & 6 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1100,9 +1100,8 @@ mod test {

let empty = empty_with_type(DataType::Int64);
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "");
let err = ret.unwrap_err().to_string();
assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}");
let expected = "Projection: CAST(a AS Boolean) IS TRUE\n EmptyRelation";
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?;

// is not true
let expr = col("a").is_not_true();
Expand Down Expand Up @@ -1202,9 +1201,8 @@ mod test {

let empty = empty_with_type(DataType::Utf8);
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected);
let err = ret.unwrap_err().to_string();
assert!(err.contains("Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"), "{err}");
let expected = "Projection: CAST(a AS Boolean) IS UNKNOWN\n EmptyRelation";
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?;

// is not unknown
let expr = col("a").is_not_unknown();
Expand Down
63 changes: 0 additions & 63 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -799,69 +799,6 @@ mod tests {
Ok(batch)
}

#[test]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move to slt

fn case_test_incompatible() -> Result<()> {
// 1 then is int64
// 2 then is boolean
let batch = case_test_batch()?;
let schema = batch.schema();

// CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
let when1 = binary(
col("a", &schema)?,
Operator::Eq,
lit("foo"),
&batch.schema(),
)?;
let then1 = lit(123i32);
let when2 = binary(
col("a", &schema)?,
Operator::Eq,
lit("bar"),
&batch.schema(),
)?;
let then2 = lit(true);

let expr = generate_case_when_with_type_coercion(
None,
vec![(when1, then1), (when2, then2)],
None,
schema.as_ref(),
);
assert!(expr.is_err());

// then 1 is int32
// then 2 is int64
// else is float
// CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
let when1 = binary(
col("a", &schema)?,
Operator::Eq,
lit("foo"),
&batch.schema(),
)?;
let then1 = lit(123i32);
let when2 = binary(
col("a", &schema)?,
Operator::Eq,
lit("bar"),
&batch.schema(),
)?;
let then2 = lit(456i64);
let else_expr = lit(1.23f64);

let expr = generate_case_when_with_type_coercion(
None,
vec![(when1, then1), (when2, then2)],
Some(else_expr),
schema.as_ref(),
);
assert!(expr.is_ok());
let result_type = expr.unwrap().data_type(schema.as_ref())?;
assert_eq!(DataType::Float64, result_type);
Ok(())
}

#[test]
fn case_eq() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
Expand Down
12 changes: 0 additions & 12 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2190,18 +2190,6 @@ fn union_with_aliases() {
quick_test(sql, expected);
}

#[test]
fn union_with_incompatible_data_types() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move to slt

let sql = "SELECT 'a' a UNION ALL SELECT true a";
let err = logical_plan(sql)
.expect_err("query should have failed")
.strip_backtrace();
assert_eq!(
"Error during planning: UNION Column a (type: Boolean) is not compatible with column a (type: Utf8)",
err
);
}

#[test]
fn empty_over() {
let sql = "SELECT order_id, MAX(order_id) OVER () from orders";
Expand Down
14 changes: 10 additions & 4 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -557,10 +557,16 @@ select column1[0:5], column2[0:3], column3[0:9] from arrays;
## make_array (aliases: `make_list`)

# make_array scalar function #1
query ???
select make_array(1, 2, 3), make_array(1.0, 2.0, 3.0), make_array('h', 'e', 'l', 'l', 'o');
----
[1, 2, 3] [1.0, 2.0, 3.0] [h, e, l, l, o]
query ??????
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add some tests for other uses of this logic (not just in make_array), such as comparisons?

I notice that Postgres doesn't handle boolean coercion:

postgres=# select true = 1;
ERROR:  operator does not exist: boolean = integer
LINE 1: select true = 1;
                    ^
HINT:  No operator matches the given name and argument types. You might need to add explicit type casts.

However, after this PR datafusion does:

DataFusion CLI v33.0.0
❯ select true = 1;
+--------------------------+
| Boolean(true) = Int64(1) |
+--------------------------+
| true                     |
+--------------------------+
1 row in set. Query took 0.006 seconds.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DuckDB supports boolean coercion, so we can consider following it. If it does not break the overall design, I think we can support it too.

select
make_array(1, 2, 3),
make_array(1.0, 2.0, 3.0),
make_array('h', 'e', 'l', 'l', 'o'),
make_array(true, 1, 2, false),
make_array(true, 1, 2.3, false),
make_array(true, 1, 2.3, false, '4');
----
[1, 2, 3] [1.0, 2.0, 3.0] [h, e, l, l, o] [1, 1, 2, 0] [1.0, 1.0, 2.3, 0.0] [true, 1, 2.3, false, 4]

# make_array scalar function #2
query ???
Expand Down
32 changes: 32 additions & 0 deletions datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,27 @@ NULL
123
NULL

# integer with float
query I
select case when a = 'rust' then arrow_cast(1, 'Int32') when a == 'c++' then arrow_cast(2, 'Int64') else 1.5 end from (values('python')) as t(a);
----
1.5

# integer with boolean
query I
select case when a = 'rust' then 1 when a == 'c++' then false end from (values('c++')) as t(a);
----
0

# type coercion not supported in case expr (boolean <-> timestamp)
#
# DataFusion error: type_coercion
# caused by
# Error during planning: Failed to coerce then ([Timestamp(Nanosecond, None), Boolean]) and else (None) to common types in CASE WHEN expression
query error
select case when a = 'foo' then arrow_cast(500, 'Timestamp(Nanosecond, None)') when a = 'bar' then true end from (values('foo')) as t(a);


# csv_query_sum_cast() {

statement ok
Expand Down Expand Up @@ -1926,3 +1947,14 @@ A true
B false
C false
D false

# bool_coercion
query BBBB
select 1 == true, false == 0, 1.0 == true, false == 'false';
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

----
true true true true

query BBBB
select 2 > true, false is not distinct from 0, 1.0 >= true, true is distinct from 'false';
----
true true true true
18 changes: 18 additions & 0 deletions datafusion/sqllogictest/test_files/union.slt
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,24 @@ SELECT 1 UNION SELECT 2
1
2

# union bool with string
query T
jayzhan211 marked this conversation as resolved.
Show resolved Hide resolved
SELECT 'a' a UNION ALL SELECT true a
----
true
a

# union bool with integer
query I
SELECT 123 a UNION ALL SELECT true a
----
1
123

# union incompatible types (bool with interval)
query error DataFusion error: Error during planning: UNION Column a \(type: Interval\(MonthDayNano\)\) is not compatible with column a \(type: Boolean\)
SELECT true a UNION SELECT interval '1 minute' a

# union_with_except_input
query T rowsort
(
Expand Down