Skip to content

Commit

Permalink
ARROW-9902: [Rust] [DataFusion] Add array() built-in function
Browse files Browse the repository at this point in the history
This adds `array()` built-in function to most primitive types. For composite types, this is more challenging and I decided to scope out of this PR.

Closes #8102 from jorgecarleitao/array

Authored-by: Jorge C. Leitao <[email protected]>
Signed-off-by: Andy Grove <[email protected]>
  • Loading branch information
jorgecarleitao authored and andygrove committed Sep 20, 2020
1 parent 13ab29d commit 98b710a
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 74 deletions.
2 changes: 2 additions & 0 deletions rust/datafusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI
- [ ] Basic date functions
- [ ] Basic time functions
- [x] Basic timestamp functions
- nested functions
- [x] Array of columns
- [x] Sorting
- [ ] Nested types
- [ ] Lists
Expand Down
8 changes: 8 additions & 0 deletions rust/datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,14 @@ pub fn concat(args: Vec<Expr>) -> Expr {
}
}

/// returns an array of fixed size with each argument on it.
pub fn array(args: Vec<Expr>) -> Expr {
Expr::ScalarFunction {
fun: functions::BuiltinScalarFunction::Array,
args,
}
}

/// Creates a new UDF with a specific signature and specific return type.
/// This is a helper function to create a new UDF.
/// The function `create_udf` returns a subset of all possible `ScalarFunction`:
Expand Down
108 changes: 108 additions & 0 deletions rust/datafusion/src/physical_plan/array_expressions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Array expressions
use crate::error::{ExecutionError, Result};
use arrow::array::*;
use arrow::datatypes::DataType;
use std::sync::Arc;

macro_rules! downcast_vec {
($ARGS:expr, $ARRAY_TYPE:ident) => {{
$ARGS
.iter()
.map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() {
Some(array) => Ok(array),
_ => Err(ExecutionError::General("failed to downcast".to_string())),
})
}};
}

macro_rules! array {
($ARGS:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{
// downcast all arguments to their common format
let args =
downcast_vec!($ARGS, $ARRAY_TYPE).collect::<Result<Vec<&$ARRAY_TYPE>>>()?;

let mut builder = FixedSizeListBuilder::<$BUILDER_TYPE>::new(
<$BUILDER_TYPE>::new(args[0].len()),
args.len() as i32,
);
// for each entry in the array
for index in 0..args[0].len() {
for arg in &args {
if arg.is_null(index) {
builder.values().append_null()?;
} else {
builder.values().append_value(arg.value(index))?;
}
}
builder.append(true)?;
}
Ok(Arc::new(builder.finish()))
}};
}

/// put values in an array.
pub fn array(args: &[ArrayRef]) -> Result<ArrayRef> {
// do not accept 0 arguments.
if args.len() == 0 {
return Err(ExecutionError::InternalError(
"array requires at least one argument".to_string(),
));
}

match args[0].data_type() {
DataType::Utf8 => array!(args, StringArray, StringBuilder),
DataType::LargeUtf8 => array!(args, LargeStringArray, LargeStringBuilder),
DataType::Boolean => array!(args, BooleanArray, BooleanBuilder),
DataType::Float32 => array!(args, Float32Array, Float32Builder),
DataType::Float64 => array!(args, Float64Array, Float64Builder),
DataType::Int8 => array!(args, Int8Array, Int8Builder),
DataType::Int16 => array!(args, Int16Array, Int16Builder),
DataType::Int32 => array!(args, Int32Array, Int32Builder),
DataType::Int64 => array!(args, Int64Array, Int64Builder),
DataType::UInt8 => array!(args, UInt8Array, UInt8Builder),
DataType::UInt16 => array!(args, UInt16Array, UInt16Builder),
DataType::UInt32 => array!(args, UInt32Array, UInt32Builder),
DataType::UInt64 => array!(args, UInt64Array, UInt64Builder),
data_type => Err(ExecutionError::NotImplemented(format!(
"Array is not implemented for type '{:?}'.",
data_type
))),
}
}

/// Currently supported types by the array function.
/// The order of these types correspond to the order on which coercion applies
/// This should thus be from least informative to most informative
pub static SUPPORTED_ARRAY_TYPES: &'static [DataType] = &[
DataType::Boolean,
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Float64,
DataType::Utf8,
DataType::LargeUtf8,
];
81 changes: 79 additions & 2 deletions rust/datafusion/src/physical_plan/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use super::{
PhysicalExpr,
};
use crate::error::{ExecutionError, Result};
use crate::physical_plan::array_expressions;
use crate::physical_plan::datetime_expressions;
use crate::physical_plan::math_expressions;
use crate::physical_plan::string_expressions;
Expand Down Expand Up @@ -118,6 +119,8 @@ pub enum BuiltinScalarFunction {
Concat,
/// to_timestamp
ToTimestamp,
/// construct an array from columns
Array,
}

impl fmt::Display for BuiltinScalarFunction {
Expand Down Expand Up @@ -151,6 +154,7 @@ impl FromStr for BuiltinScalarFunction {
"length" => BuiltinScalarFunction::Length,
"concat" => BuiltinScalarFunction::Concat,
"to_timestamp" => BuiltinScalarFunction::ToTimestamp,
"array" => BuiltinScalarFunction::Array,
_ => {
return Err(ExecutionError::General(format!(
"There is no built-in function named {}",
Expand Down Expand Up @@ -189,6 +193,10 @@ pub fn return_type(
BuiltinScalarFunction::ToTimestamp => {
Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
}
BuiltinScalarFunction::Array => Ok(DataType::FixedSizeList(
Box::new(arg_types[0].clone()),
arg_types.len() as i32,
)),
_ => Ok(DataType::Float64),
}
}
Expand Down Expand Up @@ -225,6 +233,7 @@ pub fn create_physical_expr(
BuiltinScalarFunction::ToTimestamp => {
|args| Ok(Arc::new(datetime_expressions::to_timestamp(args)?))
}
BuiltinScalarFunction::Array => |args| Ok(array_expressions::array(args)?),
});
// coerce
let args = coerce(args, input_schema, &signature(fun))?;
Expand All @@ -251,6 +260,9 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
BuiltinScalarFunction::Length => Signature::Uniform(1, vec![DataType::Utf8]),
BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]),
BuiltinScalarFunction::ToTimestamp => Signature::Uniform(1, vec![DataType::Utf8]),
BuiltinScalarFunction::Array => {
Signature::Variadic(array_expressions::SUPPORTED_ARRAY_TYPES.to_vec())
}
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
// return the best approximation for it (in f64).
Expand Down Expand Up @@ -342,8 +354,8 @@ mod tests {
};
use arrow::{
array::{
ArrayRef, Float64Array, Int32Array, PrimitiveArrayOps, StringArray,
StringArrayOps,
ArrayRef, FixedSizeListArray, Float64Array, Int32Array, PrimitiveArrayOps,
StringArray, StringArrayOps,
},
datatypes::Field,
record_batch::RecordBatch,
Expand Down Expand Up @@ -432,4 +444,69 @@ mod tests {
Ok(())
}
}

fn generic_test_array(
value1: ScalarValue,
value2: ScalarValue,
expected_type: DataType,
expected: &str,
) -> Result<()> {
// any type works here: we evaluate against a literal of `value`
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1]))];

let expr = create_physical_expr(
&BuiltinScalarFunction::Array,
&vec![lit(value1.clone()), lit(value2.clone())],
&schema,
)?;

// type is correct
assert_eq!(
expr.data_type(&schema)?,
// type equals to a common coercion
DataType::FixedSizeList(Box::new(expected_type), 2)
);

// evaluate works
let result =
expr.evaluate(&RecordBatch::try_new(Arc::new(schema.clone()), columns)?)?;

// downcast works
let result = result
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap();

// value is correct
assert_eq!(format!("{:?}", result.value(0)), expected);

Ok(())
}

#[test]
fn test_array() -> Result<()> {
generic_test_array(
ScalarValue::Utf8("aa".to_string()),
ScalarValue::Utf8("aa".to_string()),
DataType::Utf8,
"StringArray\n[\n \"aa\",\n \"aa\",\n]",
)?;

// different types, to validate that casting happens
generic_test_array(
ScalarValue::UInt32(1),
ScalarValue::UInt64(1),
DataType::UInt64,
"PrimitiveArray<UInt64>\n[\n 1,\n 1,\n]",
)?;

// different types (another order), to validate that casting happens
generic_test_array(
ScalarValue::UInt64(1),
ScalarValue::UInt32(1),
DataType::UInt64,
"PrimitiveArray<UInt64>\n[\n 1,\n 1,\n]",
)
}
}
1 change: 1 addition & 0 deletions rust/datafusion/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ pub trait Accumulator: Debug {
}

pub mod aggregates;
pub mod array_expressions;
pub mod common;
pub mod csv;
pub mod datetime_expressions;
Expand Down
2 changes: 1 addition & 1 deletion rust/datafusion/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@
pub use crate::dataframe::DataFrame;
pub use crate::execution::context::{ExecutionConfig, ExecutionContext};
pub use crate::logical_plan::{
avg, col, concat, count, create_udf, length, lit, max, min, sum,
array, avg, col, concat, count, create_udf, length, lit, max, min, sum,
};
pub use crate::physical_plan::csv::CsvReadOptions;
Loading

0 comments on commit 98b710a

Please sign in to comment.