Skip to content

Commit

Permalink
feat(arrow-select): concat kernel will merge dictionary values for …
Browse files Browse the repository at this point in the history
…list of dictionaries (#6893)

* feat(arrow-select): make list of dictionary merge dictionary keys

TODO:
- [ ] Add support to nested lists
- [ ] Add more tests
- [ ] Fix failing test

* fix concat lists of dictionaries

* format

* remove unused import

* improve test helper

* feat: add merge offset buffers into one

* format

* add reproduction test

* recommit

* fix clippy

* fix clippy

* fix clippy

* improve offsets code according to code review

* use concat dictionaries

* add specialize code to concat lists to be able to use the concat dictionary logic

* remove the use of ArrayData
  • Loading branch information
rluvaton authored Jan 4, 2025
1 parent debc5bf commit 2799268
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 11 deletions.
52 changes: 52 additions & 0 deletions arrow-buffer/src/buffer/offset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,38 @@ impl<O: ArrowNativeType> OffsetBuffer<O> {
Self(out.into())
}

/// Returns an iterator over the length of each slot in this [`OffsetBuffer`]
///
/// ```
/// # use arrow_buffer::{OffsetBuffer, ScalarBuffer};
/// let offsets = OffsetBuffer::<_>::new(ScalarBuffer::<i32>::from(vec![0, 1, 4, 9]));
/// assert_eq!(offsets.lengths().collect::<Vec<usize>>(), vec![1, 3, 5]);
/// ```
///
/// An empty [`OffsetBuffer`] yields an empty iterator
/// ```
/// # use arrow_buffer::OffsetBuffer;
/// let offsets = OffsetBuffer::<i32>::new_empty();
/// assert_eq!(offsets.lengths().count(), 0);
/// ```
///
/// This can be used to merge multiple [`OffsetBuffer`]s to one
/// ```
/// # use arrow_buffer::{OffsetBuffer, ScalarBuffer};
///
/// let buffer1 = OffsetBuffer::<i32>::from_lengths([2, 6, 3, 7, 2]);
/// let buffer2 = OffsetBuffer::<i32>::from_lengths([1, 3, 5, 7, 9]);
///
/// let merged = OffsetBuffer::<i32>::from_lengths(
///     vec![buffer1, buffer2].iter().flat_map(|x| x.lengths())
/// );
///
/// assert_eq!(merged.lengths().collect::<Vec<_>>(), &[2, 6, 3, 7, 2, 1, 3, 5, 7, 9]);
/// ```
pub fn lengths(&self) -> impl ExactSizeIterator<Item = usize> + '_ {
    // Each slot length is the difference between adjacent offsets; pairing
    // the offsets with themselves shifted by one yields exactly those pairs.
    // A buffer with fewer than two offsets produces an empty iterator.
    let starts = self.0.iter();
    let ends = self.0.iter().skip(1);
    starts.zip(ends).map(|(s, e)| e.as_usize() - s.as_usize())
}

/// Free up unused memory.
pub fn shrink_to_fit(&mut self) {
self.0.shrink_to_fit();
Expand Down Expand Up @@ -244,4 +276,24 @@ mod tests {
fn from_lengths_usize_overflow() {
OffsetBuffer::<i32>::from_lengths([usize::MAX, 1]);
}

#[test]
fn get_lengths() {
    // Offsets [0, 1, 4, 9] describe three slots of lengths 1, 3 and 5
    let offsets = OffsetBuffer::<i32>::new(ScalarBuffer::<i32>::from(vec![0, 1, 4, 9]));
    let lengths: Vec<usize> = offsets.lengths().collect();
    assert_eq!(lengths, vec![1, 3, 5]);
}

#[test]
fn get_lengths_should_be_with_fixed_size() {
    // The iterator must report an exact size (ExactSizeIterator contract)
    let offsets = OffsetBuffer::<i32>::new(ScalarBuffer::<i32>::from(vec![0, 1, 4, 9]));
    let lengths = offsets.lengths();
    assert_eq!(lengths.len(), 3);
    assert_eq!(lengths.size_hint(), (3, Some(3)));
}

#[test]
fn get_lengths_from_empty_offset_buffer_should_be_empty_iterator() {
    // An empty offset buffer has no slots, so no lengths are produced
    let offsets = OffsetBuffer::<i32>::new_empty();
    assert_eq!(offsets.lengths().count(), 0);
}
}
191 changes: 180 additions & 11 deletions arrow-select/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values}
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer};
use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, OffsetBuffer};
use arrow_data::transform::{Capacities, MutableArrayData};
use arrow_schema::{ArrowError, DataType, SchemaRef};
use arrow_schema::{ArrowError, DataType, FieldRef, SchemaRef};
use std::sync::Arc;

fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities {
Expand Down Expand Up @@ -129,6 +129,54 @@ fn concat_dictionaries<K: ArrowDictionaryKeyType>(
Ok(Arc::new(array))
}

/// Concatenates a slice of list arrays (`List` / `LargeList`) into a single
/// list array, routing the child values through [`concat`] so that value-level
/// optimisations (such as dictionary handling) apply to the children.
fn concat_lists<OffsetSize: OffsetSizeTrait>(
    arrays: &[&dyn Array],
    field: &FieldRef,
) -> Result<ArrayRef, ArrowError> {
    // Downcast every input once, tracking the total row count and whether any
    // input carries nulls.
    let mut total_rows = 0;
    let mut any_nulls = false;
    let mut lists = Vec::with_capacity(arrays.len());
    for array in arrays {
        let list = array.as_list::<OffsetSize>();
        total_rows += list.len();
        any_nulls |= list.null_count() != 0;
        lists.push(list);
    }

    // Only materialize a validity buffer when at least one input has nulls;
    // inputs without a null buffer contribute all-valid bits.
    let nulls = any_nulls.then(|| {
        let mut builder = BooleanBufferBuilder::new(total_rows);
        for list in &lists {
            match list.nulls() {
                Some(n) => builder.append_buffer(n.inner()),
                None => builder.append_n(list.len(), true),
            }
        }
        NullBuffer::new(builder.finish())
    });

    // Concatenate all child value arrays into a single values array.
    let child_values: Vec<&dyn Array> = lists.iter().map(|l| l.values().as_ref()).collect();
    let concatenated_values = concat(&child_values)?;

    // Rebuild the offsets by chaining the per-row lengths of every input.
    let offsets = OffsetBuffer::<OffsetSize>::from_lengths(
        lists.iter().flat_map(|l| l.offsets().lengths()),
    );

    let result = GenericListArray::<OffsetSize>::try_new(
        Arc::clone(field),
        offsets,
        concatenated_values,
        nulls,
    )?;

    Ok(Arc::new(result))
}

macro_rules! dict_helper {
($t:ty, $arrays:expr) => {
return Ok(Arc::new(concat_dictionaries::<$t>($arrays)?) as _)
Expand Down Expand Up @@ -163,14 +211,20 @@ pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> {
"It is not possible to concatenate arrays of different data types.".to_string(),
));
}
if let DataType::Dictionary(k, _) = d {
downcast_integer! {
k.as_ref() => (dict_helper, arrays),
_ => unreachable!("illegal dictionary key type {k}")
};
} else {
let capacity = get_capacity(arrays, d);
concat_fallback(arrays, capacity)

match d {
DataType::Dictionary(k, _) => {
downcast_integer! {
k.as_ref() => (dict_helper, arrays),
_ => unreachable!("illegal dictionary key type {k}")
}
}
DataType::List(field) => concat_lists::<i32>(arrays, field),
DataType::LargeList(field) => concat_lists::<i64>(arrays, field),
_ => {
let capacity = get_capacity(arrays, d);
concat_fallback(arrays, capacity)
}
}
}

Expand Down Expand Up @@ -228,8 +282,9 @@ pub fn concat_batches<'a>(
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::builder::StringDictionaryBuilder;
use arrow_array::builder::{GenericListBuilder, StringDictionaryBuilder};
use arrow_schema::{Field, Schema};
use std::fmt::Debug;

#[test]
fn test_concat_empty_vec() {
Expand Down Expand Up @@ -851,4 +906,118 @@ mod tests {
assert_eq!(array.null_count(), 10);
assert_eq!(array.logical_null_count(), 10);
}

#[test]
fn concat_dictionary_list_array_simple() {
    // Three single-row lists of dictionary-encoded strings, with "a" repeated
    let scalars = [
        create_single_row_list_of_dict(vec![Some("a")]),
        create_single_row_list_of_dict(vec![Some("a")]),
        create_single_row_list_of_dict(vec![Some("b")]),
    ];

    let array_refs: Vec<&dyn Array> = scalars.iter().map(|a| a as &dyn Array).collect();
    let concatenated = concat(&array_refs).unwrap();

    let expected_list = create_list_of_dict(vec![
        // Row 1
        Some(vec![Some("a")]),
        Some(vec![Some("a")]),
        Some(vec![Some("b")]),
    ]);

    let list = concatenated.as_list::<i32>();

    // Compare the concatenated result row-by-row with the expected list
    for (actual, expected) in list.iter().zip(expected_list.iter()) {
        assert_eq!(actual, expected);
    }

    assert_dictionary_has_unique_values::<_, StringArray>(
        list.values().as_dictionary::<Int32Type>(),
    );
}

#[test]
fn concat_many_dictionary_list_arrays() {
    // Many rows but only a handful of distinct dictionary values: the merged
    // dictionary must not blow up with one entry per input row.
    let number_of_unique_values = 8;
    let row_count = 80000;

    let scalars: Vec<_> = (0..row_count)
        .map(|i| {
            create_single_row_list_of_dict(vec![Some(
                (i % number_of_unique_values).to_string(),
            )])
        })
        .collect();

    let array_refs: Vec<&dyn Array> = scalars.iter().map(|a| a as &dyn Array).collect();
    let concatenated = concat(&array_refs).unwrap();

    let expected_list = create_list_of_dict(
        (0..row_count)
            .map(|i| Some(vec![Some((i % number_of_unique_values).to_string())]))
            .collect(),
    );

    let list = concatenated.as_list::<i32>();

    // Compare the concatenated result row-by-row with the expected list
    for (actual, expected) in list.iter().zip(expected_list.iter()) {
        assert_eq!(actual, expected);
    }

    assert_dictionary_has_unique_values::<_, StringArray>(
        list.values().as_dictionary::<Int32Type>(),
    );
}

/// Builds a one-row list array whose single (non-null) row contains the
/// provided dictionary-encoded string items.
fn create_single_row_list_of_dict(
    list_items: Vec<Option<impl AsRef<str>>>,
) -> GenericListArray<i32> {
    create_list_of_dict(vec![Some(list_items)])
}

/// Builds a `List<Dictionary<Int32, Utf8>>` array from rows of optional
/// string items; a `None` row becomes a null list entry.
fn create_list_of_dict(
    rows: Vec<Option<Vec<Option<impl AsRef<str>>>>>,
) -> GenericListArray<i32> {
    let values_builder = StringDictionaryBuilder::<Int32Type>::new();
    let mut builder = GenericListBuilder::<i32, _>::new(values_builder);
    rows.into_iter().for_each(|row| builder.append_option(row));
    builder.finish()
}

/// Asserts that the dictionary's value list contains no duplicate entries,
/// i.e. that concatenation actually merged equal dictionary values.
///
/// # Panics
/// Panics if the dictionary values cannot be downcast to `V`, or if a
/// duplicate value is found.
//
// NOTE(review): as rendered, the where-clause item `&'a V: ArrayAccessor +
// IntoIterator` was missing the trailing comma separating it from the next
// bound, which does not compile — restored here.
fn assert_dictionary_has_unique_values<'a, K, V>(array: &'a DictionaryArray<K>)
where
    K: ArrowDictionaryKeyType,
    V: Sync + Send + 'static,
    &'a V: ArrayAccessor + IntoIterator,
    <&'a V as ArrayAccessor>::Item: Default + Clone + PartialEq + Debug + Ord,
    <&'a V as IntoIterator>::Item: Clone + PartialEq + Debug + Ord,
{
    let dict = array.downcast_dict::<V>().unwrap();
    let mut values = dict.values().into_iter().collect::<Vec<_>>();

    // remove duplicates must be sorted first so we can compare
    values.sort();

    let mut unique_values = values.clone();

    unique_values.dedup();

    assert_eq!(
        values, unique_values,
        "There are duplicates in the value list (the value list here is sorted which is only for the assertion)"
    );
}
}

0 comments on commit 2799268

Please sign in to comment.