Skip to content

Commit

Permalink
feat: Add support for cardinality function on maps (#11801)
Browse files Browse the repository at this point in the history
* feat: Add support for cardinality function on maps

* chore: Fix prettier

* feat: Add specialized signature for MapArray in ArrayFunctionSignature
  • Loading branch information
Weijun-H authored Aug 5, 2024
1 parent fcd907d commit eb2b5fe
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 9 deletions.
6 changes: 6 additions & 0 deletions datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ pub enum ArrayFunctionSignature {
/// The function takes a single argument that must be a List/LargeList/FixedSizeList
/// or something that can be coerced to one of those types.
Array,
/// Specialized signature for map arguments.
/// The function takes a single argument whose data type must be `DataType::Map`.
MapArray,
}

impl std::fmt::Display for ArrayFunctionSignature {
Expand All @@ -165,6 +168,9 @@ impl std::fmt::Display for ArrayFunctionSignature {
ArrayFunctionSignature::Array => {
write!(f, "array")
}
ArrayFunctionSignature::MapArray => {
write!(f, "map_array")
}
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,16 @@ fn get_valid_types(
array(&current_types[0])
.map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
}
ArrayFunctionSignature::MapArray => {
if current_types.len() != 1 {
return Ok(vec![vec![]]);
}

match &current_types[0] {
DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
_ => vec![vec![]],
}
}
},
TypeSignature::Any(number) => {
if current_types.len() != *number {
Expand Down
39 changes: 31 additions & 8 deletions datafusion/functions-nested/src/cardinality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,39 @@
//! [`ScalarUDFImpl`] definitions for cardinality function.
use crate::utils::make_scalar_function;
use arrow_array::{ArrayRef, GenericListArray, OffsetSizeTrait, UInt64Array};
use arrow_array::{
Array, ArrayRef, GenericListArray, MapArray, OffsetSizeTrait, UInt64Array,
};
use arrow_schema::DataType;
use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64};
use datafusion_common::cast::{as_large_list_array, as_list_array};
use arrow_schema::DataType::{FixedSizeList, LargeList, List, Map, UInt64};
use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array};
use datafusion_common::Result;
use datafusion_common::{exec_err, plan_err};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use datafusion_expr::{
ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, Signature, TypeSignature,
Volatility,
};
use std::any::Any;
use std::sync::Arc;

make_udf_expr_and_func!(
Cardinality,
cardinality,
array,
"returns the total number of elements in the array.",
"returns the total number of elements in the array or map.",
cardinality_udf
);

impl Cardinality {
pub fn new() -> Self {
Self {
signature: Signature::array(Volatility::Immutable),
signature: Signature::one_of(
vec![
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray),
],
Volatility::Immutable,
),
aliases: vec![],
}
}
Expand All @@ -64,9 +75,9 @@ impl ScalarUDFImpl for Cardinality {

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(match arg_types[0] {
List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64,
List(_) | LargeList(_) | FixedSizeList(_, _) | Map(_, _) => UInt64,
_ => {
return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList.");
return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList/Map.");
}
})
}
Expand Down Expand Up @@ -95,12 +106,24 @@ pub fn cardinality_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_large_list_array(&args[0])?;
generic_list_cardinality::<i64>(list_array)
}
Map(_, _) => {
let map_array = as_map_array(&args[0])?;
generic_map_cardinality(map_array)
}
other => {
exec_err!("cardinality does not support type '{:?}'", other)
}
}
}

/// Computes the cardinality of every row of a [`MapArray`].
///
/// Each row yielded by the array's iterator is replaced by its `len()`
/// converted to `u64`; rows yielded as `None` remain `None` in the output.
/// The collected [`UInt64Array`] is returned behind an `Arc` as an [`ArrayRef`].
fn generic_map_cardinality(array: &MapArray) -> Result<ArrayRef> {
    let entry_counts = array.iter().map(|row| {
        // Propagate a missing row as a missing count.
        let entries = row?;
        Some(entries.len() as u64)
    });
    let counts: UInt64Array = entry_counts.collect();
    Ok(Arc::new(counts))
}

fn generic_list_cardinality<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
) -> Result<ArrayRef> {
Expand Down
9 changes: 9 additions & 0 deletions datafusion/sqllogictest/test_files/map.slt
Original file line number Diff line number Diff line change
Expand Up @@ -459,3 +459,12 @@ SELECT MAP { 'a': 1, 2: 3 };
# SELECT MAKE_MAP(1, null, 2, 33, 3, null)[2];
# ----
# 33

## cardinality

# cardinality scalar function
query IIII
select cardinality(map([1, 2, 3], ['a', 'b', 'c'])), cardinality(MAP {'a': 1, 'b': null}), cardinality(MAP([],[])),
cardinality(MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} });
----
3 2 0 2
2 changes: 1 addition & 1 deletion docs/source/user-guide/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ select log(-1), log(0), sqrt(-1);
| array_except(array1, array2) | Returns an array of the elements that appear in the first array but not in the second. `array_except([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2]` |
| array_resize(array, size, value) | Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set. `array_resize([1, 2, 3], 5, 0) -> [1, 2, 3, 0, 0]` |
| array_sort(array, desc, null_first) | Returns sorted array. `array_sort([3, 1, 2, 5, 4]) -> [1, 2, 3, 4, 5]` |
| cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` |
| cardinality(array/map) | Returns the total number of elements in the array or map. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` |
| make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` |
| range(start [, stop, step]) | Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` |
| string_to_array(array, delimiter, null_string) | Splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`. `string_to_array('abc#def#ghi', '#', ' ') -> ['abc', 'def', 'ghi']` |
Expand Down

0 comments on commit eb2b5fe

Please sign in to comment.