Skip to content

Commit

Permalink
Improve concat kernel capacity estimation (#3546)
Browse files (browse the repository at this point in the history)
* Improve concat kernel capacity estimation

* Review feedback

* Format
  • Loading branch information
tustvold authored Jan 18, 2023
1 parent 96831de commit 56dfad0
Showing 1 changed file with 93 additions and 44 deletions.
137 changes: 93 additions & 44 deletions arrow-select/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,28 @@
//! assert_eq!(arr.len(), 3);
//! ```
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::ArrowNativeType;
use arrow_data::transform::{Capacities, MutableArrayData};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, SchemaRef};

fn compute_str_values_length<Offset: OffsetSizeTrait>(arrays: &[&ArrayData]) -> usize {
arrays
.iter()
.map(|&data| {
// get the length of the value buffer
let buf_len = data.buffers()[1].len();
// find the offset of the buffer
// this returns a slice of offsets, starting from the offset of the array
// so we can take the first value
let offset = data.buffer::<Offset>(0)[0];
buf_len - offset.to_usize().unwrap()
})
.sum()
/// Computes the [`Capacities`] to pre-allocate when concatenating `arrays`.
///
/// Every element of `arrays` is downcast to [`GenericByteArray<T>`]. The result is
/// `Capacities::Binary(items, Some(bytes))` where `items` is the sum of the array
/// lengths and `bytes` is the sum, over all arrays, of the distance between the
/// first and last value offset (so a sliced array only contributes the bytes it
/// actually references).
///
/// # Panics
///
/// Panics if any array in `arrays` is not a `GenericByteArray<T>`.
fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities {
    let (item_capacity, bytes_capacity) =
        arrays.iter().fold((0, 0), |(items, bytes), array| {
            let byte_array = array
                .as_any()
                .downcast_ref::<GenericByteArray<T>>()
                .unwrap();

            // The offsets slice is guaranteed to hold at least one element,
            // so indexing its first and last entries cannot go out of bounds.
            let offsets = byte_array.value_offsets();
            let first = offsets[0].as_usize();
            let last = offsets[offsets.len() - 1].as_usize();

            (items + byte_array.len(), bytes + (last - first))
        });

    Capacities::Binary(item_capacity, Some(bytes_capacity))
}

/// Concatenate multiple [Array] of the same type into a single [ArrayRef].
Expand All @@ -61,43 +65,27 @@ pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> {
return Ok(array.slice(0, array.len()));
}

if arrays
.iter()
.any(|array| array.data_type() != arrays[0].data_type())
{
let d = arrays[0].data_type();
if arrays.iter().skip(1).any(|array| array.data_type() != d) {
return Err(ArrowError::InvalidArgumentError(
"It is not possible to concatenate arrays of different data types."
.to_string(),
));
}

let lengths = arrays.iter().map(|array| array.len()).collect::<Vec<_>>();
let capacity = lengths.iter().sum();

let arrays = arrays.iter().map(|a| a.data()).collect::<Vec<_>>();

let mut mutable = match arrays[0].data_type() {
DataType::Utf8 => {
let str_values_size = compute_str_values_length::<i32>(&arrays);
MutableArrayData::with_capacities(
arrays,
false,
Capacities::Binary(capacity, Some(str_values_size)),
)
}
DataType::LargeUtf8 => {
let str_values_size = compute_str_values_length::<i64>(&arrays);
MutableArrayData::with_capacities(
arrays,
false,
Capacities::Binary(capacity, Some(str_values_size)),
)
}
_ => MutableArrayData::new(arrays, false, capacity),
let capacity = match d {
DataType::Utf8 => binary_capacity::<Utf8Type>(arrays),
DataType::LargeUtf8 => binary_capacity::<LargeUtf8Type>(arrays),
DataType::Binary => binary_capacity::<BinaryType>(arrays),
DataType::LargeBinary => binary_capacity::<LargeBinaryType>(arrays),
_ => Capacities::Array(arrays.iter().map(|a| a.len()).sum()),
};

for (i, len) in lengths.iter().enumerate() {
mutable.extend(i, 0, *len)
let array_data = arrays.iter().map(|a| a.data()).collect::<Vec<_>>();
let mut mutable = MutableArrayData::with_capacities(array_data, false, capacity);

for (i, a) in arrays.iter().enumerate() {
mutable.extend(i, 0, a.len())
}

Ok(make_array(mutable.freeze()))
Expand Down Expand Up @@ -139,7 +127,6 @@ pub fn concat_batches<'a>(
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::types::*;
use arrow_schema::{Field, Schema};
use std::sync::Arc;

Expand Down Expand Up @@ -665,4 +652,66 @@ mod tests {
"Invalid argument error: batches[1] schema is different with argument schema.",
);
}

// Checks the length and capacity of the buffers produced by `concat` for
// primitive (Int32), string (Utf8) and binary (LargeBinary) inputs, both for
// whole arrays and for sliced arrays. Every asserted capacity is the buffer
// length rounded up to a multiple of 64.
#[test]
fn concat_capacity() {
// --- Int32: a single values buffer of 4-byte elements ---
let a = Int32Array::from_iter_values(0..100);
let b = Int32Array::from_iter_values(10..20);
let a = concat(&[&a, &b]).unwrap();
let data = a.data();
// (100 + 10) * size_of::<i32>()
assert_eq!(data.buffers()[0].len(), 440);
assert_eq!(data.buffers()[0].capacity(), 448); // Rounded up to a multiple of 64

// Sliced input: only the 20 sliced elements plus the 10 from `b` are counted
let a = concat(&[&a.slice(10, 20), &b]).unwrap();
let data = a.data();
// (20 + 10) * size_of::<i32>()
assert_eq!(data.buffers()[0].len(), 120);
assert_eq!(data.buffers()[0].capacity(), 128); // Rounded up to a multiple of 64

// --- Utf8: buffer 0 holds i32 offsets, buffer 1 holds the string bytes ---
let a = StringArray::from_iter_values(std::iter::repeat("foo").take(100));
let b = StringArray::from(vec!["bingo", "bongo", "lorem", ""]);

let a = concat(&[&a, &b]).unwrap();
let data = a.data();
// (100 + 4 + 1) * size_of::<i32>()
assert_eq!(data.buffers()[0].len(), 420);
assert_eq!(data.buffers()[0].capacity(), 448); // Rounded up to a multiple of 64

// len("foo") * 100 + len("bingo") + len("bongo") + len("lorem")
assert_eq!(data.buffers()[1].len(), 315);
assert_eq!(data.buffers()[1].capacity(), 320); // Rounded up to a multiple of 64

// Sliced input: the slice keeps 40 of the "foo" values
let a = concat(&[&a.slice(10, 40), &b]).unwrap();
let data = a.data();
// (40 + 4 + 1) * size_of::<i32>()
assert_eq!(data.buffers()[0].len(), 180);
assert_eq!(data.buffers()[0].capacity(), 192); // Rounded up to a multiple of 64

// len("foo") * 40 + len("bingo") + len("bongo") + len("lorem")
assert_eq!(data.buffers()[1].len(), 135);
assert_eq!(data.buffers()[1].capacity(), 192); // Rounded up to a multiple of 64

// --- LargeBinary: buffer 0 holds i64 offsets, buffer 1 holds the value bytes ---
let a = LargeBinaryArray::from_iter_values(std::iter::repeat(b"foo").take(100));
let b =
LargeBinaryArray::from_iter_values(std::iter::repeat(b"cupcakes").take(10));

let a = concat(&[&a, &b]).unwrap();
let data = a.data();
// (100 + 10 + 1) * size_of::<i64>()
assert_eq!(data.buffers()[0].len(), 888);
assert_eq!(data.buffers()[0].capacity(), 896); // Rounded up to a multiple of 64

// len("foo") * 100 + len("cupcakes") * 10
assert_eq!(data.buffers()[1].len(), 380);
assert_eq!(data.buffers()[1].capacity(), 384); // Rounded up to a multiple of 64

// Sliced input: the slice keeps 40 of the "foo" values
let a = concat(&[&a.slice(10, 40), &b]).unwrap();
let data = a.data();
// (40 + 10 + 1) * size_of::<i64>()
assert_eq!(data.buffers()[0].len(), 408);
assert_eq!(data.buffers()[0].capacity(), 448); // Rounded up to a multiple of 64

// len("foo") * 40 + len("cupcakes") * 10
assert_eq!(data.buffers()[1].len(), 200);
assert_eq!(data.buffers()[1].capacity(), 256); // Rounded up to a multiple of 64
}
}

0 comments on commit 56dfad0

Please sign in to comment.