Skip to content

Commit

Permalink
feat: support LargeList in make_array and array_length (#8121)
Browse files Browse the repository at this point in the history
* feat: support  LargeList in make_array and
array_length

* chore: add tests

* fix: update tests for nested array

* use usise_as

* add new_large_list

* refactor array_length

* add comment

* update test in sqllogictest

* fix ci

* fix macro

* use usize_as

* update comment

* return based on data_type in make_array
Weijun-H authored Dec 3, 2023
1 parent 075ff3d commit f6af014
Showing 2 changed files with 83 additions and 13 deletions.
47 changes: 34 additions & 13 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
@@ -171,6 +171,10 @@ fn compute_array_length(
value = downcast_arg!(value, ListArray).value(0);
current_dimension += 1;
}
DataType::LargeList(..) => {
value = downcast_arg!(value, LargeListArray).value(0);
current_dimension += 1;
}
_ => return Ok(None),
}
}
@@ -252,7 +256,7 @@ macro_rules! call_array_function {
}

/// Convert one or more [`ArrayRef`] of the same type into a
/// `ListArray`
/// `ListArray` or 'LargeListArray' depending on the offset size.
///
/// # Example (non nested)
///
@@ -291,7 +295,10 @@ macro_rules! call_array_function {
/// └──────────────┘ └──────────────┘ └─────────────────────────────┘
/// col1 col2 output
/// ```
fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
fn array_array<O: OffsetSizeTrait>(
args: &[ArrayRef],
data_type: DataType,
) -> Result<ArrayRef> {
// do not accept 0 arguments.
if args.is_empty() {
return plan_err!("Array requires at least one argument");
@@ -308,8 +315,9 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
total_len += arg_data.len();
data.push(arg_data);
}
let mut offsets = Vec::with_capacity(total_len);
offsets.push(0);

let mut offsets: Vec<O> = Vec::with_capacity(total_len);
offsets.push(O::usize_as(0));

let capacity = Capacities::Array(total_len);
let data_ref = data.iter().collect::<Vec<_>>();
@@ -327,11 +335,11 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
mutable.extend_nulls(1);
}
}
offsets.push(mutable.len() as i32);
offsets.push(O::usize_as(mutable.len()));
}

let data = mutable.freeze();
Ok(Arc::new(ListArray::try_new(

Ok(Arc::new(GenericListArray::<O>::try_new(
Arc::new(Field::new("item", data_type, true)),
OffsetBuffer::new(offsets.into()),
arrow_array::make_array(data),
@@ -356,7 +364,8 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result<ArrayRef> {
let array = new_null_array(&DataType::Null, arrays.len());
Ok(Arc::new(array_into_list_array(array)))
}
data_type => array_array(arrays, data_type),
DataType::LargeList(..) => array_array::<i64>(arrays, data_type),
_ => array_array::<i32>(arrays, data_type),
}
}

@@ -1693,11 +1702,11 @@ pub fn flatten(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(flattened_array) as ArrayRef)
}

/// Array_length SQL function
pub fn array_length(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_list_array(&args[0])?;
let dimension = if args.len() == 2 {
as_int64_array(&args[1])?.clone()
/// Dispatch array length computation based on the offset type.
fn array_length_dispatch<O: OffsetSizeTrait>(array: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_generic_list_array::<O>(&array[0])?;
let dimension = if array.len() == 2 {
as_int64_array(&array[1])?.clone()
} else {
Int64Array::from_value(1, list_array.len())
};
@@ -1711,6 +1720,18 @@ pub fn array_length(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}

/// Array_length SQL function
pub fn array_length(args: &[ArrayRef]) -> Result<ArrayRef> {
match &args[0].data_type() {
DataType::List(_) => array_length_dispatch::<i32>(args),
DataType::LargeList(_) => array_length_dispatch::<i64>(args),
_ => internal_err!(
"array_length does not support type '{:?}'",
args[0].data_type()
),
}
}

/// Array_dims SQL function
pub fn array_dims(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_list_array(&args[0])?;
49 changes: 49 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
@@ -2371,24 +2371,44 @@ select array_length(make_array(1, 2, 3, 4, 5)), array_length(make_array(1, 2, 3)
----
5 3 3

query III
select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'));
----
5 3 3

# array_length scalar function #2
query III
select array_length(make_array(1, 2, 3, 4, 5), 1), array_length(make_array(1, 2, 3), 1), array_length(make_array([1, 2], [3, 4], [5, 6]), 1);
----
5 3 3

query III
select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 1), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 1), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'), 1);
----
5 3 3

# array_length scalar function #3
query III
select array_length(make_array(1, 2, 3, 4, 5), 2), array_length(make_array(1, 2, 3), 2), array_length(make_array([1, 2], [3, 4], [5, 6]), 2);
----
NULL NULL 2

query III
select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 2), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'), 2);
----
NULL NULL 2

# array_length scalar function #4
query II
select array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 1), array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 2);
----
3 2

query II
select array_length(arrow_cast(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 'LargeList(List(List(Int64)))'), 1), array_length(arrow_cast(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 'LargeList(List(List(Int64)))'), 2);
----
3 2

# array_length scalar function #5
query III
select array_length(make_array()), array_length(make_array(), 1), array_length(make_array(), 2)
@@ -2407,6 +2427,11 @@ select list_length(make_array(1, 2, 3, 4, 5)), list_length(make_array(1, 2, 3)),
----
5 3 3 NULL

query III
select list_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), list_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), list_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'));
----
5 3 3

# array_length with columns
query I
select array_length(column1, column3) from arrays_values;
@@ -2420,6 +2445,18 @@ NULL
NULL
NULL

query I
select array_length(arrow_cast(column1, 'LargeList(Int64)'), column3) from arrays_values;
----
10
NULL
NULL
NULL
NULL
NULL
NULL
NULL

# array_length with columns and scalars
query II
select array_length(array[array[1, 2], array[3, 4]], column3), array_length(column1, 1) from arrays_values;
@@ -2433,6 +2470,18 @@ NULL 10
NULL 10
NULL 10

query II
select array_length(arrow_cast(array[array[1, 2], array[3, 4]], 'LargeList(List(Int64))'), column3), array_length(arrow_cast(column1, 'LargeList(Int64)'), 1) from arrays_values;
----
2 10
2 10
NULL 10
NULL 10
NULL NULL
NULL 10
NULL 10
NULL 10

## array_dims (aliases: `list_dims`)

# array dims error

0 comments on commit f6af014

Please sign in to comment.