Skip to content

Commit

Permalink
make array_intersect handle empty/null arrays rightly
Browse files Browse the repository at this point in the history
Signed-off-by: veeupup <[email protected]>
  • Loading branch information
Veeupup committed Nov 19, 2023
1 parent 9dbaf5f commit c08d6cb
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 49 deletions.
13 changes: 8 additions & 5 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,14 +599,19 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayToString => Ok(Utf8),
BuiltinScalarFunction::ArrayIntersect => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayIntersect => {
match (input_expr_types[0].clone(), input_expr_types[1].clone()) {
(DataType::Null, DataType::Null) => Ok(DataType::List(Arc::new(
Field::new("item", DataType::Null, true),
))),
(dt, _) => Ok(dt),
}
}
BuiltinScalarFunction::ArrayUnion => {
match (input_expr_types[0].clone(), input_expr_types[1].clone()) {
(DataType::Null, DataType::Null) => Ok(DataType::List(Arc::new(
Field::new("item", DataType::Null, true),
))),
(DataType::Null, dt) => Ok(dt),
(dt, DataType::Null) => Ok(dt),
(dt, _) => Ok(dt),
}
}
Expand All @@ -618,8 +623,6 @@ impl BuiltinScalarFunction {
(DataType::Null, DataType::Null) => Ok(DataType::List(Arc::new(
Field::new("item", DataType::Null, true),
))),
(DataType::Null, dt) => Ok(dt),
(dt, DataType::Null) => Ok(dt),
(dt, _) => Ok(dt),
}
}
Expand Down
112 changes: 68 additions & 44 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ pub fn array_except(args: &[ArrayRef]) -> Result<ArrayRef> {
match (array1.data_type(), array2.data_type()) {
(DataType::Null, DataType::Null) => {
// NullArray(1): means null, NullArray(0): means []
// except([], null) = [], except(null, []) = null, except(null, null) = null
// except([], []) = [], except([], null) = [], except(null, []) = null, except(null, null) = null
let nulls = match (array1.len(), array2.len()) {
(1, _) => Some(NullBuffer::new_null(1)),
_ => None,
Expand Down Expand Up @@ -1527,7 +1527,7 @@ pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> {
match (array1.data_type(), array2.data_type()) {
(DataType::Null, DataType::Null) => {
// NullArray(1): means null, NullArray(0): means []
// union([], null) = [], union(null, []) = [], union(null, null) = null
// union([], []) = [], union([], null) = [], union(null, []) = [], union(null, null) = null
let nulls = match (array1.len(), array2.len()) {
(1, 1) => Some(NullBuffer::new_null(1)),
_ => None,
Expand Down Expand Up @@ -2028,55 +2028,79 @@ pub fn string_to_array<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef
pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
assert_eq!(args.len(), 2);

let first_array = as_list_array(&args[0])?;
let second_array = as_list_array(&args[1])?;
let first_array = &args[0];
let second_array = &args[1];

if first_array.value_type() != second_array.value_type() {
return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'");
}
let dt = first_array.value_type();
match (first_array.data_type(), second_array.data_type()) {
(DataType::Null, DataType::Null) => {
// NullArray(1): means null, NullArray(0): means []
// intersect([], []) = [], intersect([], null) = [], intersect(null, []) = [], intersect(null, null) = null
let nulls = match (first_array.len(), second_array.len()) {
(1, 1) => Some(NullBuffer::new_null(1)),
_ => None,
};
let arr = Arc::new(ListArray::try_new(
Arc::new(Field::new("item", DataType::Null, true)),
OffsetBuffer::new(vec![0; 2].into()),
Arc::new(NullArray::new(0)),
nulls,
)?) as ArrayRef;
Ok(arr)
}
_ => {
let first_array = as_list_array(&first_array)?;
let second_array = as_list_array(&second_array)?;

let mut offsets = vec![0];
let mut new_arrays = vec![];

let converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) {
if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
let l_values = converter.convert_columns(&[first_arr])?;
let r_values = converter.convert_columns(&[second_arr])?;

let values_set: HashSet<_> = l_values.iter().collect();
let mut rows = Vec::with_capacity(r_values.num_rows());
for r_val in r_values.iter().sorted().dedup() {
if values_set.contains(&r_val) {
rows.push(r_val);
}
if first_array.value_type() != second_array.value_type() {
return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'");
}

let last_offset: i32 = match offsets.last().copied() {
Some(offset) => offset,
None => return internal_err!("offsets should not be empty"),
};
offsets.push(last_offset + rows.len() as i32);
let arrays = converter.convert_rows(rows)?;
let array = match arrays.get(0) {
Some(array) => array.clone(),
None => {
return internal_err!(
"array_intersect: failed to get array from rows"
)
let dt = first_array.value_type();

let mut offsets = vec![0];
let mut new_arrays = vec![];

let converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) {
if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
let l_values = converter.convert_columns(&[first_arr])?;
let r_values = converter.convert_columns(&[second_arr])?;

let values_set: HashSet<_> = l_values.iter().collect();
let mut rows = Vec::with_capacity(r_values.num_rows());
for r_val in r_values.iter().sorted().dedup() {
if values_set.contains(&r_val) {
rows.push(r_val);
}
}

let last_offset: i32 = match offsets.last().copied() {
Some(offset) => offset,
None => return internal_err!("offsets should not be empty"),
};
offsets.push(last_offset + rows.len() as i32);
let arrays = converter.convert_rows(rows)?;
let array = match arrays.get(0) {
Some(array) => array.clone(),
None => {
return internal_err!(
"array_intersect: failed to get array from rows"
)
}
};
new_arrays.push(array);
}
};
new_arrays.push(array);
}

let field = Arc::new(Field::new("item", dt, true));
let offsets = OffsetBuffer::new(offsets.into());
let new_arrays_ref =
new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
let values = compute::concat(&new_arrays_ref)?;
let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?);
Ok(arr)
}
}

let field = Arc::new(Field::new("item", dt, true));
let offsets = OffsetBuffer::new(offsets.into());
let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
let values = compute::concat(&new_arrays_ref)?;
let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?);
Ok(arr)
}

#[cfg(test)]
Expand Down
20 changes: 20 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2687,6 +2687,26 @@ SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)),
----
[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]]

query ?
select array_intersect([], []);
----
[]

query ?
select array_intersect([], null);
----
[]

query ?
select array_intersect(null, []);
----
[]

query ?
select array_intersect(null, null);
----
NULL

query ??????
SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)),
list_intersect(make_array(1,3,5), make_array(2,4,6)),
Expand Down

0 comments on commit c08d6cb

Please sign in to comment.