Skip to content

Commit

Permalink
Support <bool col> = <bool col> and <bool col> != <bool col>
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Oct 21, 2021
1 parent 2b002e4 commit fda35ac
Show file tree
Hide file tree
Showing 2 changed files with 243 additions and 9 deletions.
212 changes: 205 additions & 7 deletions datafusion/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,29 @@ macro_rules! boolean_op {
}};
}

/// Invoke a boolean kernel with a scalar on an array
macro_rules! boolean_op_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
let ll = $LEFT
.as_any()
.downcast_ref::<BooleanArray>()
.expect("boolean_op_scalar failed to downcast array");

let result = if let ScalarValue::Boolean(scalar) = $RIGHT {
Ok(
Arc::new(paste::expr! {[<$OP _bool_scalar>]}(&ll, scalar.as_ref())?)
as ArrayRef,
)
} else {
Err(DataFusionError::Internal(format!(
"boolean_op_scalar failed to cast literal value {}",
$RIGHT
)))
};
Some(result)
}};
}

macro_rules! binary_string_array_flag_op {
($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
match $LEFT.data_type() {
Expand Down Expand Up @@ -592,9 +615,19 @@ impl BinaryExpr {
Operator::GtEq => {
binary_array_op_scalar!(array, scalar.clone(), gt_eq)
}
Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq),
Operator::Eq => {
if array.data_type() == &DataType::Boolean {
boolean_op_scalar!(array, scalar.clone(), eq)
} else {
binary_array_op_scalar!(array, scalar.clone(), eq)
}
}
Operator::NotEq => {
binary_array_op_scalar!(array, scalar.clone(), neq)
if array.data_type() == &DataType::Boolean {
boolean_op_scalar!(array, scalar.clone(), neq)
} else {
binary_array_op_scalar!(array, scalar.clone(), neq)
}
}
Operator::Like => {
binary_string_array_op_scalar!(array, scalar.clone(), like)
Expand Down Expand Up @@ -659,9 +692,19 @@ impl BinaryExpr {
Operator::GtEq => {
binary_array_op_scalar!(array, scalar.clone(), lt_eq)
}
Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq),
Operator::Eq => {
if array.data_type() == &DataType::Boolean {
boolean_op_scalar!(array, scalar.clone(), eq)
} else {
binary_array_op_scalar!(array, scalar.clone(), eq)
}
}
Operator::NotEq => {
binary_array_op_scalar!(array, scalar.clone(), neq)
if array.data_type() == &DataType::Boolean {
boolean_op_scalar!(array, scalar.clone(), neq)
} else {
binary_array_op_scalar!(array, scalar.clone(), neq)
}
}
// if scalar operation is not supported - fallback to array implementation
_ => None,
Expand All @@ -683,8 +726,21 @@ impl BinaryExpr {
Operator::LtEq => binary_array_op!(left, right, lt_eq),
Operator::Gt => binary_array_op!(left, right, gt),
Operator::GtEq => binary_array_op!(left, right, gt_eq),
Operator::Eq => binary_array_op!(left, right, eq),
Operator::NotEq => binary_array_op!(left, right, neq),
Operator::Eq => {
if left_data_type == &DataType::Boolean {
boolean_op!(left, right, eq_bool)
} else {
binary_array_op!(left, right, eq)
}
}
Operator::NotEq => {
if left_data_type == &DataType::Boolean {
boolean_op!(left, right, neq_bool)
} else {
binary_array_op!(left, right, neq)
}
}

Operator::IsDistinctFrom => binary_array_op!(left, right, is_distinct_from),
Operator::IsNotDistinctFrom => {
binary_array_op!(left, right, is_not_distinct_from)
Expand Down Expand Up @@ -814,14 +870,68 @@ pub fn binary(
Ok(Arc::new(BinaryExpr::new(l, op, r)))
}

// TODO file a ticket with arrow-rs to include these kernels

fn eq_bool(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {
let arr: BooleanArray = lhs
.iter()
.zip(rhs.iter())
.map(|v| match v {
// both lhs and rhs were non null
(Some(lhs), Some(rhs)) => Some(lhs == rhs),
_ => None,
})
.collect();

Ok(arr)
}

fn eq_bool_scalar(lhs: &BooleanArray, rhs: Option<&bool>) -> Result<BooleanArray> {
let arr: BooleanArray = lhs
.iter()
.map(|v| match (v, rhs) {
// both lhs and rhs were non null
(Some(lhs), Some(rhs)) => Some(lhs == *rhs),
_ => None,
})
.collect();
Ok(arr)
}

fn neq_bool(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {
let arr: BooleanArray = lhs
.iter()
.zip(rhs.iter())
.map(|v| match v {
// both lhs and rhs were non null
(Some(lhs), Some(rhs)) => Some(lhs != rhs),
_ => None,
})
.collect();

Ok(arr)
}

fn neq_bool_scalar(lhs: &BooleanArray, rhs: Option<&bool>) -> Result<BooleanArray> {
let arr: BooleanArray = lhs
.iter()
.map(|v| match (v, rhs) {
// both lhs and rhs were non null
(Some(lhs), Some(rhs)) => Some(lhs != *rhs),
_ => None,
})
.collect();
Ok(arr)
}

#[cfg(test)]
mod tests {
use arrow::datatypes::{ArrowNumericType, Field, Int32Type, SchemaRef};
use arrow::util::display::array_value_to_string;

use super::*;
use crate::error::Result;
use crate::physical_plan::expressions::col;
use crate::physical_plan::expressions::{col, lit};

// Create a binary expression without coercion. Used here when we do not want to coerce the expressions
// to valid types. Usage can result in an execution (after plan) error.
Expand Down Expand Up @@ -1371,6 +1481,42 @@ mod tests {
Ok(())
}

// Test `scalar <op> arr` produces expected
fn apply_logic_op_scalar_arr(
schema: &SchemaRef,
scalar: bool,
arr: &ArrayRef,
op: Operator,
expected: &BooleanArray,
) -> Result<()> {
let scalar = lit(scalar.into());

let arithmetic_op = binary_simple(scalar, op, col("a", schema)?);
let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
assert_eq!(result.as_ref(), expected);

Ok(())
}

// Test `arr <op> scalar` produces expected
fn apply_logic_op_arr_scalar(
schema: &SchemaRef,
arr: &ArrayRef,
scalar: bool,
op: Operator,
expected: &BooleanArray,
) -> Result<()> {
let scalar = lit(scalar.into());

let arithmetic_op = binary_simple(col("a", schema)?, op, scalar);
let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
assert_eq!(result.as_ref(), expected);

Ok(())
}

#[test]
fn and_with_nulls_op() -> Result<()> {
let schema = Schema::new(vec![
Expand Down Expand Up @@ -1461,6 +1607,58 @@ mod tests {
Ok(())
}

#[test]
fn eq_op_bool() {
let schema = Schema::new(vec![
Field::new("a", DataType::Boolean, false),
Field::new("b", DataType::Boolean, false),
]);
let a = BooleanArray::from(vec![Some(true), None, Some(false), None]);
let b =
BooleanArray::from(vec![Some(true), Some(false), Some(true), Some(false)]);

let expected = BooleanArray::from(vec![Some(true), None, Some(false), None]);
apply_logic_op(Arc::new(schema), a, b, Operator::Eq, expected).unwrap();
}

#[test]
fn eq_op_bool_scalar() {
let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
let schema = Arc::new(schema);
let a: ArrayRef =
Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)]));

let expected = BooleanArray::from(vec![Some(true), None, Some(false)]);
apply_logic_op_scalar_arr(&schema, true, &a, Operator::Eq, &expected).unwrap();
apply_logic_op_arr_scalar(&schema, &a, true, Operator::Eq, &expected).unwrap();
}

#[test]
fn neq_op_bool() {
let schema = Schema::new(vec![
Field::new("a", DataType::Boolean, false),
Field::new("b", DataType::Boolean, false),
]);
let a = BooleanArray::from(vec![Some(true), None, Some(false), None]);
let b =
BooleanArray::from(vec![Some(true), Some(false), Some(true), Some(false)]);

let expected = BooleanArray::from(vec![Some(false), None, Some(true), None]);
apply_logic_op(Arc::new(schema), a, b, Operator::NotEq, expected).unwrap();
}

#[test]
fn neq_op_bool_scalar() {
let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
let schema = Arc::new(schema);
let a: ArrayRef =
Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)]));

let expected = BooleanArray::from(vec![Some(false), None, Some(true)]);
apply_logic_op_scalar_arr(&schema, true, &a, Operator::NotEq, &expected).unwrap();
apply_logic_op_arr_scalar(&schema, &a, true, Operator::NotEq, &expected).unwrap();
}

#[test]
fn test_coersion_error() -> Result<()> {
let expr =
Expand Down
40 changes: 38 additions & 2 deletions datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ async fn select_distinct_simple_4() {
async fn select_distinct_from() {
let mut ctx = ExecutionContext::new();

let sql = "select
let sql = "select
1 IS DISTINCT FROM CAST(NULL as INT) as a,
1 IS DISTINCT FROM 1 as b,
1 IS NOT DISTINCT FROM CAST(NULL as INT) as c,
Expand All @@ -621,7 +621,7 @@ async fn select_distinct_from() {
async fn select_distinct_from_utf8() {
let mut ctx = ExecutionContext::new();

let sql = "select
let sql = "select
'x' IS DISTINCT FROM NULL as a,
'x' IS DISTINCT FROM 'x' as b,
'x' IS NOT DISTINCT FROM NULL as c,
Expand Down Expand Up @@ -812,6 +812,40 @@ async fn csv_query_having_without_group_by() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn csv_query_boolean_eq() -> Result<()> {
let mut ctx = ExecutionContext::new();
register_aggregate_simple_csv(&mut ctx).await?;

let sql = "SELECT c3, c3 = c3 as eq, c3 != c3 as neq FROM aggregate_simple";
let actual = execute_to_batches(&mut ctx, sql).await;

let expected = vec![
"+-------+------+-------+",
"| c3 | eq | neq |",
"+-------+------+-------+",
"| true | true | false |",
"| false | true | false |",
"| false | true | false |",
"| true | true | false |",
"| true | true | false |",
"| true | true | false |",
"| false | true | false |",
"| false | true | false |",
"| false | true | false |",
"| false | true | false |",
"| true | true | false |",
"| true | true | false |",
"| true | true | false |",
"| true | true | false |",
"| true | true | false |",
"+-------+------+-------+",
];
assert_batches_eq!(expected, &actual);

Ok(())
}

#[tokio::test]
async fn csv_query_avg_sqrt() -> Result<()> {
let mut ctx = create_ctx()?;
Expand Down Expand Up @@ -4054,6 +4088,8 @@ macro_rules! test_expression {
async fn test_boolean_expressions() -> Result<()> {
test_expression!("true", "true");
test_expression!("false", "false");
test_expression!("false = false", "true");
test_expression!("true = false", "false");
Ok(())
}

Expand Down

0 comments on commit fda35ac

Please sign in to comment.