Skip to content

Commit

Permalink
Add Coalesce function (#1969)
Browse files Browse the repository at this point in the history
* feat: coalesce function. WIP

* fix: add license header

* fix: support all primitive data types

* fix: add test case and remove commented code

* fix: refactor code and support coalesce from ballista

* fix: support scalar values in coalesce

* fix: clippy

* fix: clippy by removing unneeded mut

* fix: clippy by removing unneeded mut

* fix: address review comments

Co-authored-by: Sathis Kumar <[email protected]>
  • Loading branch information
msathis and msathis authored Apr 6, 2022
1 parent 8b09a5c commit b890190
Show file tree
Hide file tree
Showing 12 changed files with 316 additions and 8 deletions.
1 change: 1 addition & 0 deletions ballista/rust/core/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ enum ScalarFunction {
Translate=60;
Trim=61;
Upper=62;
Coalesce=63;
}

message ScalarFunctionNode {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema};
pub use display::display_schema;
pub use expr::{
abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan,
avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, col,
avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, coalesce, col,
columnize_expr, combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos,
count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exp,
exprlist_to_fields, floor, in_list, initcap, left, length, lit, lit_timestamp_nano,
Expand Down
11 changes: 11 additions & 0 deletions datafusion/core/src/physical_plan/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ use arrow::{
use datafusion_expr::ScalarFunctionImplementation;
pub use datafusion_expr::{BuiltinScalarFunction, Signature, TypeSignature, Volatility};
use datafusion_physical_expr::array_expressions;
use datafusion_physical_expr::conditional_expressions;
use datafusion_physical_expr::datetime_expressions;
use datafusion_physical_expr::math_expressions;
use datafusion_physical_expr::string_expressions;
Expand Down Expand Up @@ -111,6 +112,11 @@ pub fn return_type(
utf8_to_int_type(&input_expr_types[0], "character_length")
}
BuiltinScalarFunction::Chr => Ok(DataType::Utf8),
BuiltinScalarFunction::Coalesce => {
// COALESCE has multiple args and they might get coerced, get a preview of this
let coerced_types = data_types(input_expr_types, &signature(fun));
coerced_types.map(|types| types[0].clone())
}
BuiltinScalarFunction::Concat => Ok(DataType::Utf8),
BuiltinScalarFunction::ConcatWithSeparator => Ok(DataType::Utf8),
BuiltinScalarFunction::DatePart => Ok(DataType::Int32),
Expand Down Expand Up @@ -366,6 +372,10 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => {
Signature::variadic(vec![DataType::Utf8], fun.volatility())
}
BuiltinScalarFunction::Coalesce => Signature::variadic(
conditional_expressions::SUPPORTED_COALESCE_TYPES.to_vec(),
fun.volatility(),
),
BuiltinScalarFunction::Ascii
| BuiltinScalarFunction::BitLength
| BuiltinScalarFunction::CharacterLength
Expand Down Expand Up @@ -817,6 +827,7 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Chr => {
Arc::new(|args| make_scalar_function(string_expressions::chr)(args))
}
BuiltinScalarFunction::Coalesce => Arc::new(conditional_expressions::coalesce),
BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat),
BuiltinScalarFunction::ConcatWithSeparator => {
Arc::new(|args| make_scalar_function(string_expressions::concat_ws)(args))
Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ pub use crate::execution::options::{
};
pub use crate::logical_plan::{
approx_percentile_cont, array, ascii, avg, bit_length, btrim, character_length, chr,
col, concat, concat_ws, count, create_udf, date_part, date_trunc, digest, in_list,
initcap, left, length, lit, lower, lpad, ltrim, max, md5, min, now, octet_length,
random, regexp_match, regexp_replace, repeat, replace, reverse, right, rpad, rtrim,
sha224, sha256, sha384, sha512, split_part, starts_with, strpos, substr, sum, to_hex,
translate, trim, upper, Column, JoinType, Partitioning,
coalesce, col, concat, concat_ws, count, create_udf, date_part, date_trunc, digest,
in_list, initcap, left, length, lit, lower, lpad, ltrim, max, md5, min, now,
octet_length, random, regexp_match, regexp_replace, repeat, replace, reverse, right,
rpad, rtrim, sha224, sha256, sha384, sha512, split_part, starts_with, strpos, substr,
sum, to_hex, translate, trim, upper, Column, JoinType, Partitioning,
};
183 changes: 183 additions & 0 deletions datafusion/core/tests/sql/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,186 @@ async fn query_count_distinct() -> Result<()> {
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn coalesce_static_empty_value() -> Result<()> {
let ctx = SessionContext::new();
let sql = "SELECT COALESCE('', 'test')";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+---------------------------------+",
"| coalesce(Utf8(\"\"),Utf8(\"test\")) |",
"+---------------------------------+",
"| |",
"+---------------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn coalesce_static_value_with_null() -> Result<()> {
let ctx = SessionContext::new();
let sql = "SELECT COALESCE(NULL, 'test')";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----------------------------------+",
"| coalesce(Utf8(NULL),Utf8(\"test\")) |",
"+-----------------------------------+",
"| test |",
"+-----------------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn coalesce_result() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Int32, true),
]));

let data = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![Some(0), None, Some(1), None, None])),
Arc::new(Int32Array::from(vec![
Some(1),
Some(1),
Some(0),
Some(1),
None,
])),
],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
let sql = "SELECT COALESCE(c1, c2) FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+---------------------------+",
"| coalesce(test.c1,test.c2) |",
"+---------------------------+",
"| 0 |",
"| 1 |",
"| 1 |",
"| 1 |",
"| |",
"+---------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn coalesce_result_with_default_value() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Int32, true),
]));

let data = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![Some(0), None, Some(1), None, None])),
Arc::new(Int32Array::from(vec![
Some(1),
Some(1),
Some(0),
Some(1),
None,
])),
],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
let sql = "SELECT COALESCE(c1, c2, '-1') FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+--------------------------------------+",
"| coalesce(test.c1,test.c2,Utf8(\"-1\")) |",
"+--------------------------------------+",
"| 0 |",
"| 1 |",
"| 1 |",
"| 1 |",
"| -1 |",
"+--------------------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn coalesce_sum_with_default_value() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Int32, true),
]));

let data = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![Some(1), None, Some(1), None])),
Arc::new(Int32Array::from(vec![Some(2), Some(2), None, None])),
],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
let sql = "SELECT SUM(COALESCE(c1, c2, 0)) FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----------------------------------------+",
"| SUM(coalesce(test.c1,test.c2,Int64(0))) |",
"+-----------------------------------------+",
"| 4 |",
"+-----------------------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn coalesce_mul_with_default_value() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Int32, true),
]));

let data = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![Some(1), None, Some(1), None])),
Arc::new(Int32Array::from(vec![Some(2), Some(2), None, None])),
],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let ctx = SessionContext::new();
ctx.register_table("test", Arc::new(table))?;
let sql = "SELECT COALESCE(c1 * c2, 0) FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+---------------------------------------------+",
"| coalesce(test.c1 Multiply test.c2,Int64(0)) |",
"+---------------------------------------------+",
"| 2 |",
"| 0 |",
"| 0 |",
"| 0 |",
"+---------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}
6 changes: 6 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ pub enum BuiltinScalarFunction {
Atan,
/// ceil
Ceil,
/// coalesce
Coalesce,
/// cos
Cos,
/// Digest
Expand Down Expand Up @@ -174,6 +176,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Asin => Volatility::Immutable,
BuiltinScalarFunction::Atan => Volatility::Immutable,
BuiltinScalarFunction::Ceil => Volatility::Immutable,
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
BuiltinScalarFunction::Cos => Volatility::Immutable,
BuiltinScalarFunction::Exp => Volatility::Immutable,
BuiltinScalarFunction::Floor => Volatility::Immutable,
Expand Down Expand Up @@ -271,6 +274,9 @@ impl FromStr for BuiltinScalarFunction {
"tan" => BuiltinScalarFunction::Tan,
"trunc" => BuiltinScalarFunction::Trunc,

// conditional functions
"coalesce" => BuiltinScalarFunction::Coalesce,

// string functions
"array" => BuiltinScalarFunction::Array,
"ascii" => BuiltinScalarFunction::Ascii,
Expand Down
9 changes: 9 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,15 @@ pub fn array(args: Vec<Expr>) -> Expr {
}
}

/// Returns `coalesce(args...)`, which evaluates to the value of the first [Expr]
/// which is not NULL
pub fn coalesce(args: Vec<Expr>) -> Expr {
Expr::ScalarFunction {
fun: built_in_function::BuiltinScalarFunction::Coalesce,
args,
}
}

#[cfg(test)]
mod test {
use super::*;
Expand Down
87 changes: 87 additions & 0 deletions datafusion/physical-expr/src/conditional_expressions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::array::{new_null_array, Array, BooleanArray};
use arrow::compute;
use arrow::compute::kernels::zip::zip;
use arrow::datatypes::DataType;

use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;

/// coalesce evaluates to the first value which is not NULL
pub fn coalesce(args: &[ColumnarValue]) -> Result<ColumnarValue> {
// do not accept 0 arguments.
if args.is_empty() {
return Err(DataFusionError::Internal(format!(
"coalesce was called with {} arguments. It requires at least 1.",
args.len()
)));
}

let size = match args[0] {
ColumnarValue::Array(ref a) => a.len(),
ColumnarValue::Scalar(ref _s) => 1,
};
let mut res = new_null_array(&args[0].data_type(), size);

for column_value in args {
for i in 0..size {
match column_value {
ColumnarValue::Array(array_ref) => {
let curr_null_mask = compute::is_null(res.as_ref())?;
let arr_not_null_mask = compute::is_not_null(array_ref)?;
let bool_mask = compute::and(&curr_null_mask, &arr_not_null_mask)?;
res = zip(&bool_mask, array_ref, &res)?;
}
ColumnarValue::Scalar(scalar) => {
if !scalar.is_null() && res.is_null(i) {
let vec: Vec<bool> =
(0..size).into_iter().map(|j| j == i).collect();
let bool_arr = BooleanArray::from(vec);
res =
zip(&bool_arr, scalar.to_array_of_size(size).as_ref(), &res)?;
continue;
}
}
}
}
}

Ok(ColumnarValue::Array(Arc::new(res)))
}

/// Currently supported types by the coalesce function.
/// The order of these types correspond to the order on which coercion applies
/// This should thus be from least informative to most informative
pub static SUPPORTED_COALESCE_TYPES: &[DataType] = &[
DataType::Boolean,
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Float64,
DataType::Utf8,
DataType::LargeUtf8,
];
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
mod aggregate_expr;
pub mod array_expressions;
pub mod coercion_rule;
pub mod conditional_expressions;
#[cfg(feature = "crypto_expressions")]
pub mod crypto_expressions;
pub mod datetime_expressions;
Expand Down
Loading

0 comments on commit b890190

Please sign in to comment.