Skip to content

Commit

Permalink
filter kernel should work with UnionArray (#1412)
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya authored Mar 16, 2022
1 parent f0646f8 commit bfccb5f
Show file tree
Hide file tree
Showing 4 changed files with 334 additions and 5 deletions.
19 changes: 16 additions & 3 deletions arrow/src/array/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,16 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff
MutableBuffer::new(capacity * mem::size_of::<u8>()),
empty_buffer,
],
DataType::Union(_, _) => unimplemented!(),
DataType::Union(_, mode) => {
let type_ids = MutableBuffer::new(capacity * mem::size_of::<i8>());
match mode {
UnionMode::Sparse => [type_ids, empty_buffer],
UnionMode::Dense => {
let offsets = MutableBuffer::new(capacity * mem::size_of::<i32>());
[type_ids, offsets]
}
}
}
}
}

Expand All @@ -210,7 +219,8 @@ pub(crate) fn into_buffers(
DataType::Utf8
| DataType::Binary
| DataType::LargeUtf8
| DataType::LargeBinary => vec![buffer1.into(), buffer2.into()],
| DataType::LargeBinary
| DataType::Union(_, _) => vec![buffer1.into(), buffer2.into()],
_ => vec![buffer1.into()],
}
}
Expand Down Expand Up @@ -559,7 +569,10 @@ impl ArrayData {
DataType::Map(field, _) => {
vec![Self::new_empty(field.data_type())]
}
DataType::Union(_, _) => unimplemented!(),
DataType::Union(fields, _) => fields
.iter()
.map(|field| Self::new_empty(field.data_type()))
.collect(),
DataType::Dictionary(_, data_type) => {
vec![Self::new_empty(data_type)]
}
Expand Down
20 changes: 18 additions & 2 deletions arrow/src/array/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ mod list;
mod null;
mod primitive;
mod structure;
mod union;
mod utils;
mod variable_size;

Expand Down Expand Up @@ -272,9 +273,12 @@ fn build_extend(array: &ArrayData) -> Extend {
DataType::Struct(_) => structure::build_extend(array),
DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
DataType::Float16 => primitive::build_extend::<f16>(array),
DataType::Union(_, mode) => match mode {
UnionMode::Sparse => union::build_extend_sparse(array),
UnionMode::Dense => union::build_extend_dense(array),
},
/*
DataType::FixedSizeList(_, _) => {}
DataType::Union(_) => {}
*/
ty => todo!(
"Take and filter operations still not supported for this datatype: `{:?}`",
Expand Down Expand Up @@ -326,9 +330,12 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls {
DataType::Struct(_) => structure::extend_nulls,
DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls,
DataType::Float16 => primitive::extend_nulls::<f16>,
DataType::Union(_, mode) => match mode {
UnionMode::Sparse => union::extend_nulls_sparse,
UnionMode::Dense => union::extend_nulls_dense,
},
/*
DataType::FixedSizeList(_, _) => {}
DataType::Union(_) => {}
*/
ty => todo!(
"Take and filter operations still not supported for this datatype: `{:?}`",
Expand Down Expand Up @@ -522,6 +529,15 @@ impl<'a> MutableArrayData<'a> {
})
.collect::<Vec<_>>(),
},
DataType::Union(fields, _) => (0..fields.len())
.map(|i| {
let child_arrays = arrays
.iter()
.map(|array| &array.child_data()[i])
.collect::<Vec<_>>();
MutableArrayData::new(child_arrays, use_nulls, array_capacity)
})
.collect::<Vec<_>>(),
ty => {
todo!("Take and filter operations still not supported for this datatype: `{:?}`", ty)
}
Expand Down
161 changes: 161 additions & 0 deletions arrow/src/array/transform/union.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::array::ArrayData;

use super::{Extend, _MutableArrayData};

/// Builds the [`Extend`] closure used to copy rows out of a sparse union `array`.
///
/// The returned closure appends the `len` type ids beginning at `start` to `buffer1`
/// and then extends every child by the same row range. When `array` reports nulls,
/// rows are handled one at a time: rows for which `array.is_valid` is false are added
/// to each child through `extend_nulls(1)` instead of being copied.
pub(super) fn build_extend_sparse(array: &ArrayData) -> Extend {
    let type_ids = array.buffer::<i8>(0);

    if array.null_count() != 0 {
        // Nulls present: validity has to be consulted row by row.
        Box::new(
            move |mutable: &mut _MutableArrayData,
                  index: usize,
                  start: usize,
                  len: usize| {
                let end = start + len;

                // Type ids are copied verbatim, null rows included.
                mutable.buffer1.extend_from_slice(&type_ids[start..end]);

                for row in start..end {
                    let row_is_valid = array.is_valid(row);
                    for child in mutable.child_data.iter_mut() {
                        if row_is_valid {
                            child.extend(index, row, row + 1);
                        } else {
                            child.extend_nulls(1);
                        }
                    }
                }
            },
        )
    } else {
        // No nulls: the whole row range can be forwarded to each child in one call.
        Box::new(
            move |mutable: &mut _MutableArrayData,
                  index: usize,
                  start: usize,
                  len: usize| {
                let end = start + len;

                mutable.buffer1.extend_from_slice(&type_ids[start..end]);

                for child in mutable.child_data.iter_mut() {
                    child.extend(index, start, end);
                }
            },
        )
    }
}

/// Builds the [`Extend`] closure used to copy rows out of a dense union `array`.
///
/// For every row in `start..start + len` the returned closure:
///
/// * appends the row's type id to `buffer1`,
/// * appends to `buffer2` the offset the row will have in the *output* child, i.e. the
///   current length of the selected child builder, and
/// * copies the single child value referenced by the row's source offset into that
///   child (or appends one null to it when `array.is_valid` is false for the row).
///
/// Source offsets cannot be copied verbatim: the output children are rebuilt compactly
/// from only the selected rows, so a source offset generally does not address the same
/// value in the output child.
///
/// # Panics
///
/// The closure panics if the requested range is outside the type-id/offset buffers or
/// if a type id does not index an existing child.
pub(super) fn build_extend_dense(array: &ArrayData) -> Extend {
    let type_ids = array.buffer::<i8>(0);
    let offsets = array.buffer::<i32>(1);
    // Only consult the validity bitmap when the array actually reports nulls.
    let has_nulls = array.null_count() != 0;

    Box::new(
        move |mutable: &mut _MutableArrayData,
              index: usize,
              start: usize,
              len: usize| {
            let end = start + len;

            // extends type_ids
            mutable.buffer1.extend_from_slice(&type_ids[start..end]);

            for i in start..end {
                // NOTE(review): the type id is used directly as the child index, as in
                // the original code — confirm this holds for all unions reaching here.
                let type_id = type_ids[i] as usize;

                // extends offsets: the value appended below lands at the child's
                // current end, so that is the offset the output row must carry.
                let dst_offset = mutable.child_data[type_id].len() as i32;
                mutable.buffer2.extend_from_slice(&[dst_offset]);

                if has_nulls && !array.is_valid(i) {
                    mutable.child_data[type_id].extend_nulls(1);
                } else {
                    // Read the offset of *this* row, not of the first row of the run.
                    let src_offset = offsets[i] as usize;
                    mutable.child_data[type_id].extend(
                        index,
                        src_offset,
                        src_offset + 1,
                    );
                }
            }
        },
    )
}

/// Appends `len` null slots to a dense union under construction.
///
/// The nulls are spread across the children as evenly as possible: every child receives
/// `len / num_children` nulls and the first `len % num_children` children receive one
/// more, so that exactly `len` type ids and offsets are written. For each child, its
/// index is written to `buffer1` as the type id and its length before the nulls are
/// appended is written to `buffer2` as the offset of every slot assigned to it.
///
/// # Panics
///
/// Panics if `len > 0` and the union has no children, because there is no child that
/// could hold the null slots.
pub(super) fn extend_nulls_dense(mutable: &mut _MutableArrayData, len: usize) {
    if len == 0 {
        return;
    }

    let num_children = mutable.child_data.len();
    assert!(
        num_children > 0,
        "cannot append nulls to a dense union without children"
    );

    let per_child = len / num_children;
    // Slots left over after the even split; handing them out one per child keeps the
    // total at exactly `len` instead of silently dropping them.
    let remainder = len % num_children;

    for (idx, child) in mutable.child_data.iter_mut().enumerate() {
        let n = per_child + usize::from(idx < remainder);
        if n == 0 {
            continue;
        }

        mutable
            .buffer1
            .extend_from_slice(vec![idx as i8; n].as_slice());
        // All `n` slots point at the first null appended to this child below.
        mutable
            .buffer2
            .extend_from_slice(vec![child.len() as i32; n].as_slice());
        child.extend_nulls(n);
    }
}

/// Appends `len` null slots to a sparse union under construction.
///
/// A sparse union stores one type id per row and every child has one entry per row, so
/// both the type-id buffer (`buffer1`) and every child are grown by `len`.
pub(super) fn extend_nulls_sparse(mutable: &mut _MutableArrayData, len: usize) {
    // The type-id buffer must stay as long as the array. Null slots carry type id 0.
    // NOTE(review): 0 is assumed to be an acceptable type id for null slots — confirm
    // against how the union builder encodes nulls.
    mutable
        .buffer1
        .extend_from_slice(vec![0_i8; len].as_slice());

    for child in mutable.child_data.iter_mut() {
        child.extend_nulls(len);
    }
}
139 changes: 139 additions & 0 deletions arrow/src/compute/kernels/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1521,4 +1521,143 @@ mod tests {

assert_eq!(&expected, &got);
}

/// Shared driver for the dense and sparse filter tests.
///
/// Applies three different three-element masks to `array` and checks each filtered
/// result against a union built by hand with `UnionBuilder`.
fn test_filter_union_array(array: UnionArray) {
    // Filters `array` with `mask` and compares the outcome slot by slot with `expected`.
    let check = |mask: Vec<bool>, expected: &UnionArray| {
        let predicate = BooleanArray::from(mask);
        let result = filter(&array, &predicate).unwrap();
        let actual = result.as_any().downcast_ref::<UnionArray>().unwrap();
        compare_union_arrays(actual, expected);
    };

    // Keep only the first row.
    let mut first_only = UnionBuilder::new_dense(1);
    first_only.append::<Int32Type>("A", 1).unwrap();
    check(vec![true, false, false], &first_only.build().unwrap());

    // Keep the first and last rows.
    let mut first_and_last = UnionBuilder::new_dense(2);
    first_and_last.append::<Int32Type>("A", 1).unwrap();
    first_and_last.append::<Int32Type>("A", 34).unwrap();
    check(vec![true, false, true], &first_and_last.build().unwrap());

    // Keep the first two rows.
    let mut first_two = UnionBuilder::new_dense(2);
    first_two.append::<Int32Type>("A", 1).unwrap();
    first_two.append::<Float64Type>("B", 3.2).unwrap();
    check(vec![true, true, false], &first_two.build().unwrap());
}

/// Runs the shared filter checks against a three-row dense union.
#[test]
fn test_filter_union_array_dense() {
    let mut union_builder = UnionBuilder::new_dense(3);
    union_builder.append::<Int32Type>("A", 1).unwrap();
    union_builder.append::<Float64Type>("B", 3.2).unwrap();
    union_builder.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(union_builder.build().unwrap());
}

/// Filters a dense union that contains a null row and checks that the null survives.
#[test]
fn test_filter_union_array_dense_with_nulls() {
    // Input rows: A(1), B(3.2), null, A(34).
    let mut input_builder = UnionBuilder::new_dense(4);
    input_builder.append::<Int32Type>("A", 1).unwrap();
    input_builder.append::<Float64Type>("B", 3.2).unwrap();
    input_builder.append_null().unwrap();
    input_builder.append::<Int32Type>("A", 34).unwrap();
    let input = input_builder.build().unwrap();

    // Keep rows 0 and 2, i.e. A(1) followed by the null.
    let mask = BooleanArray::from(vec![true, false, true, false]);
    let result = filter(&input, &mask).unwrap();
    let actual = result.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut expected_builder = UnionBuilder::new_dense(1);
    expected_builder.append::<Int32Type>("A", 1).unwrap();
    expected_builder.append_null().unwrap();
    let expected = expected_builder.build().unwrap();

    compare_union_arrays(actual, &expected);
}

/// Runs the shared filter checks against a three-row sparse union.
#[test]
fn test_filter_union_array_sparse() {
    let mut union_builder = UnionBuilder::new_sparse(3);
    union_builder.append::<Int32Type>("A", 1).unwrap();
    union_builder.append::<Float64Type>("B", 3.2).unwrap();
    union_builder.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(union_builder.build().unwrap());
}

/// Filters a sparse union that contains a null row and checks that the null survives.
#[test]
fn test_filter_union_array_sparse_with_nulls() {
    // Input rows: A(1), B(3.2), null, A(34).
    let mut input_builder = UnionBuilder::new_sparse(4);
    input_builder.append::<Int32Type>("A", 1).unwrap();
    input_builder.append::<Float64Type>("B", 3.2).unwrap();
    input_builder.append_null().unwrap();
    input_builder.append::<Int32Type>("A", 34).unwrap();
    let input = input_builder.build().unwrap();

    // Keep rows 0 and 2, i.e. A(1) followed by the null.
    let mask = BooleanArray::from(vec![true, false, true, false]);
    let result = filter(&input, &mask).unwrap();
    let actual = result.as_any().downcast_ref::<UnionArray>().unwrap();

    // The expectation is built with a dense builder; the comparison helper only looks
    // at lengths, null flags and slot values.
    let mut expected_builder = UnionBuilder::new_dense(1);
    expected_builder.append::<Int32Type>("A", 1).unwrap();
    expected_builder.append_null().unwrap();
    let expected = expected_builder.build().unwrap();

    compare_union_arrays(actual, &expected);
}

/// Asserts that two union arrays have the same length, the same null flags and, for
/// every non-null slot, the same single value.
///
/// The slot's child type is chosen from `union1`'s type id: `0` is read as an
/// `Int32Array` and `1` as a `Float64Array`. Any other type id is treated as
/// unreachable.
fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
    assert_eq!(union1.len(), union2.len());

    for i in 0..union1.len() {
        let type_id = union1.type_id(i);

        let slot1 = union1.value(i);
        let slot2 = union2.value(i);

        assert_eq!(union1.is_null(i), union2.is_null(i));

        // Values are only compared when both slots are non-null.
        if union1.is_null(i) || union2.is_null(i) {
            continue;
        }

        match type_id {
            0 => {
                let left = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
                assert_eq!(left.len(), 1);
                let right = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
                assert_eq!(right.len(), 1);

                assert_eq!(left.value(0), right.value(0));
            }
            1 => {
                let left = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
                assert_eq!(left.len(), 1);
                let right = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
                assert_eq!(right.len(), 1);

                assert_eq!(left.value(0), right.value(0));
            }
            _ => unreachable!(),
        }
    }
}
}

0 comments on commit bfccb5f

Please sign in to comment.