From 56dfad0b2a03bc14f398a2998a68da2bc02fb7d2 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Wed, 18 Jan 2023 12:48:47 +0000 Subject: [PATCH] Improve concat kernel capacity estimation (#3546) * Improve concat kernel capacity estimation * Review feedback * Format --- arrow-select/src/concat.rs | 137 +++++++++++++++++++++++++------------ 1 file changed, 93 insertions(+), 44 deletions(-) diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index 7e28f1695509..cff8fd25b7f1 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -30,24 +30,28 @@ //! assert_eq!(arr.len(), 3); //! ``` +use arrow_array::types::*; use arrow_array::*; +use arrow_buffer::ArrowNativeType; use arrow_data::transform::{Capacities, MutableArrayData}; -use arrow_data::ArrayData; use arrow_schema::{ArrowError, DataType, SchemaRef}; -fn compute_str_values_length<Offset: OffsetSizeTrait>(arrays: &[&ArrayData]) -> usize { - arrays - .iter() - .map(|&data| { - // get the length of the value buffer - let buf_len = data.buffers()[1].len(); - // find the offset of the buffer - // this returns a slice of offsets, starting from the offset of the array - // so we can take the first value - let offset = data.buffer::<Offset>(0)[0]; - buf_len - offset.to_usize().unwrap() - }) - .sum() +fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities { + let mut item_capacity = 0; + let mut bytes_capacity = 0; + for array in arrays { + let a = array + .as_any() + .downcast_ref::<GenericByteArray<T>>() + .unwrap(); + + // Guaranteed to always have at least one element + let offsets = a.value_offsets(); + bytes_capacity += offsets[offsets.len() - 1].as_usize() - offsets[0].as_usize(); + item_capacity += a.len() + } + + Capacities::Binary(item_capacity, Some(bytes_capacity)) } /// Concatenate multiple [Array] of the same type into a single [ArrayRef]. 
@@ -61,43 +65,27 @@ pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> { return Ok(array.slice(0, array.len())); } - if arrays - .iter() - .any(|array| array.data_type() != arrays[0].data_type()) - { + let d = arrays[0].data_type(); + if arrays.iter().skip(1).any(|array| array.data_type() != d) { return Err(ArrowError::InvalidArgumentError( "It is not possible to concatenate arrays of different data types." .to_string(), )); } - let lengths = arrays.iter().map(|array| array.len()).collect::<Vec<_>>(); - let capacity = lengths.iter().sum(); - - let arrays = arrays.iter().map(|a| a.data()).collect::<Vec<_>>(); - - let mut mutable = match arrays[0].data_type() { - DataType::Utf8 => { - let str_values_size = compute_str_values_length::<i32>(&arrays); - MutableArrayData::with_capacities( - arrays, - false, - Capacities::Binary(capacity, Some(str_values_size)), - ) - } - DataType::LargeUtf8 => { - let str_values_size = compute_str_values_length::<i64>(&arrays); - MutableArrayData::with_capacities( - arrays, - false, - Capacities::Binary(capacity, Some(str_values_size)), - ) - } - _ => MutableArrayData::new(arrays, false, capacity), + let capacity = match d { + DataType::Utf8 => binary_capacity::<Utf8Type>(arrays), + DataType::LargeUtf8 => binary_capacity::<LargeUtf8Type>(arrays), + DataType::Binary => binary_capacity::<BinaryType>(arrays), + DataType::LargeBinary => binary_capacity::<LargeBinaryType>(arrays), + _ => Capacities::Array(arrays.iter().map(|a| a.len()).sum()), }; - for (i, len) in lengths.iter().enumerate() { - mutable.extend(i, 0, *len) + let array_data = arrays.iter().map(|a| a.data()).collect::<Vec<_>>(); + let mut mutable = MutableArrayData::with_capacities(array_data, false, capacity); + + for (i, a) in arrays.iter().enumerate() { + mutable.extend(i, 0, a.len()) } Ok(make_array(mutable.freeze())) @@ -139,7 +127,6 @@ pub fn concat_batches<'a>( #[cfg(test)] mod tests { use super::*; - use arrow_array::types::*; use arrow_schema::{Field, Schema}; use std::sync::Arc; @@ -665,4 +652,66 @@ mod tests { "Invalid argument error: batches[1] schema is 
different with argument schema.", ); } + + #[test] + fn concat_capacity() { + let a = Int32Array::from_iter_values(0..100); + let b = Int32Array::from_iter_values(10..20); + let a = concat(&[&a, &b]).unwrap(); + let data = a.data(); + assert_eq!(data.buffers()[0].len(), 440); + assert_eq!(data.buffers()[0].capacity(), 448); // Rounded up to a multiple of 64 + + let a = concat(&[&a.slice(10, 20), &b]).unwrap(); + let data = a.data(); + assert_eq!(data.buffers()[0].len(), 120); + assert_eq!(data.buffers()[0].capacity(), 128); // Rounded up to a multiple of 64 + + let a = StringArray::from_iter_values(std::iter::repeat("foo").take(100)); + let b = StringArray::from(vec!["bingo", "bongo", "lorem", ""]); + + let a = concat(&[&a, &b]).unwrap(); + let data = a.data(); + // (100 + 4 + 1) * size_of::<i32>() + assert_eq!(data.buffers()[0].len(), 420); + assert_eq!(data.buffers()[0].capacity(), 448); // Rounded up to a multiple of 64 + + // len("foo") * 100 + len("bingo") + len("bongo") + len("lorem") + assert_eq!(data.buffers()[1].len(), 315); + assert_eq!(data.buffers()[1].capacity(), 320); // Rounded up to a multiple of 64 + + let a = concat(&[&a.slice(10, 40), &b]).unwrap(); + let data = a.data(); + // (40 + 4 + 1) * size_of::<i32>() + assert_eq!(data.buffers()[0].len(), 180); + assert_eq!(data.buffers()[0].capacity(), 192); // Rounded up to a multiple of 64 + + // len("foo") * 40 + len("bingo") + len("bongo") + len("lorem") + assert_eq!(data.buffers()[1].len(), 135); + assert_eq!(data.buffers()[1].capacity(), 192); // Rounded up to a multiple of 64 + + let a = LargeBinaryArray::from_iter_values(std::iter::repeat(b"foo").take(100)); + let b = + LargeBinaryArray::from_iter_values(std::iter::repeat(b"cupcakes").take(10)); + + let a = concat(&[&a, &b]).unwrap(); + let data = a.data(); + // (100 + 10 + 1) * size_of::<i64>() + assert_eq!(data.buffers()[0].len(), 888); + assert_eq!(data.buffers()[0].capacity(), 896); // Rounded up to a multiple of 64 + + // len("foo") * 100 + len("cupcakes") * 10 + assert_eq!(data.buffers()[1].len(), 380); 
assert_eq!(data.buffers()[1].capacity(), 384); // Rounded up to a multiple of 64 + + let a = concat(&[&a.slice(10, 40), &b]).unwrap(); + let data = a.data(); + // (40 + 10 + 1) * size_of::<i64>() + assert_eq!(data.buffers()[0].len(), 408); + assert_eq!(data.buffers()[0].capacity(), 448); // Rounded up to a multiple of 64 + + // len("foo") * 40 + len("cupcakes") * 10 + assert_eq!(data.buffers()[1].len(), 200); + assert_eq!(data.buffers()[1].capacity(), 256); // Rounded up to a multiple of 64 + } }