Skip to content

Commit

Permalink
fix: Fix performance regression for sort/gather on list/array columns (
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Nov 1, 2024
1 parent ebeeea7 commit 802e692
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 113 deletions.
54 changes: 3 additions & 51 deletions crates/polars-arrow/src/compute/take/fixed_size_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,62 +18,17 @@
use std::mem::ManuallyDrop;

use polars_utils::itertools::Itertools;
use polars_utils::IdxSize;

use super::Index;
use crate::array::growable::{Growable, GrowableFixedSizeList};
use crate::array::{Array, ArrayRef, FixedSizeListArray, PrimitiveArray, StaticArray};
use crate::bitmap::MutableBitmap;
use crate::compute::take::bitmap::{take_bitmap_nulls_unchecked, take_bitmap_unchecked};
use crate::compute::utils::combine_validities_and;
use crate::datatypes::reshape::{Dimension, ReshapeDimension};
use crate::datatypes::{ArrowDataType, PhysicalType};
use crate::datatypes::{ArrowDataType, IdxArr, PhysicalType};
use crate::legacy::prelude::FromData;
use crate::with_match_primitive_type;

pub(super) unsafe fn take_unchecked_slow<O: Index>(
values: &FixedSizeListArray,
indices: &PrimitiveArray<O>,
) -> FixedSizeListArray {
let take_len = std::cmp::min(values.len(), 1);
let mut capacity = 0;
let arrays = indices
.values()
.iter()
.map(|index| {
let index = index.to_usize();
let slice = values.clone().sliced_unchecked(index, take_len);
capacity += slice.len();
slice
})
.collect::<Vec<FixedSizeListArray>>();

let arrays = arrays.iter().collect();

if let Some(validity) = indices.validity() {
let mut growable: GrowableFixedSizeList =
GrowableFixedSizeList::new(arrays, true, capacity);

for index in 0..indices.len() {
if validity.get_bit_unchecked(index) {
growable.extend(index, 0, 1);
} else {
growable.extend_validity(1)
}
}

growable.into()
} else {
let mut growable: GrowableFixedSizeList =
GrowableFixedSizeList::new(arrays, false, capacity);
for index in 0..indices.len() {
growable.extend(index, 0, 1);
}

growable.into()
}
}

fn get_stride_and_leaf_type(dtype: &ArrowDataType, size: usize) -> (usize, &ArrowDataType) {
if let ArrowDataType::FixedSizeList(inner, size_inner) = dtype {
get_stride_and_leaf_type(inner.dtype(), *size_inner * size)
Expand Down Expand Up @@ -163,10 +118,7 @@ fn arr_no_validities_recursive(arr: &dyn Array) -> bool {
}

/// `take` implementation for FixedSizeListArrays
pub(super) unsafe fn take_unchecked(
values: &FixedSizeListArray,
indices: &PrimitiveArray<IdxSize>,
) -> ArrayRef {
pub(super) unsafe fn take_unchecked(values: &FixedSizeListArray, indices: &IdxArr) -> ArrayRef {
let (stride, leaf_type) = get_stride_and_leaf_type(values.dtype(), 1);
if leaf_type.to_physical_type().is_primitive()
&& arr_no_validities_recursive(values.values().as_ref())
Expand Down Expand Up @@ -249,7 +201,7 @@ pub(super) unsafe fn take_unchecked(
.unwrap()
.with_validity(outer_validity)
} else {
take_unchecked_slow(values, indices).boxed()
super::take_unchecked_impl_generic(values, indices, &FixedSizeListArray::new_null).boxed()
}
}

Expand Down
50 changes: 6 additions & 44 deletions crates/polars-arrow/src/compute/take/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,55 +15,17 @@
// specific language governing permissions and limitations
// under the License.

use super::Index;
use crate::array::growable::{Growable, GrowableList};
use crate::array::{Array, ListArray};
use crate::array::{self, ArrayFromIterDtype, ListArray, StaticArray};
use crate::datatypes::IdxArr;
use crate::offset::Offset;

/// `take` implementation for ListArrays
pub(super) unsafe fn take_unchecked<I: Offset>(
values: &ListArray<I>,
indices: &IdxArr,
) -> ListArray<I> {
// fast-path: all values to take are none
if indices.null_count() == indices.len() {
return ListArray::<I>::new_null(values.dtype().clone(), indices.len());
}

let mut capacity = 0;
let arrays = indices
.iter()
.flat_map(|opt_idx| {
opt_idx.map(|index| {
let index = index.to_usize();
let slice = values.clone().sliced(index, 1);
capacity += slice.len();
slice
})
})
.collect::<Vec<ListArray<I>>>();

let arrays = arrays.iter().collect();
if let Some(validity) = indices.validity() {
let mut growable: GrowableList<I> = GrowableList::new(arrays, true, capacity);
let mut not_null_index = 0;
for index in 0..indices.len() {
if validity.get_bit_unchecked(index) {
growable.extend(not_null_index, 0, 1);
not_null_index += 1;
} else {
growable.extend_validity(1)
}
}

growable.into()
} else {
let mut growable: GrowableList<I> = GrowableList::new(arrays, false, capacity);
for index in 0..indices.len() {
growable.extend(index, 0, 1);
}

growable.into()
}
) -> ListArray<I>
where
ListArray<I>: StaticArray + ArrayFromIterDtype<std::option::Option<Box<dyn array::Array>>>,
{
super::take_unchecked_impl_generic(values, indices, &ListArray::new_null)
}
71 changes: 69 additions & 2 deletions crates/polars-arrow/src/compute/take/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@

//! Defines take kernel for [`Array`]
use crate::array::{new_empty_array, Array, NullArray, Utf8ViewArray};
use crate::array::{
self, new_empty_array, Array, ArrayCollectIterExt, ArrayFromIterDtype, NullArray, StaticArray,
Utf8ViewArray,
};
use crate::compute::take::binview::take_binview_unchecked;
use crate::datatypes::IdxArr;
use crate::datatypes::{ArrowDataType, IdxArr};
use crate::types::Index;

mod binary;
Expand Down Expand Up @@ -82,3 +85,67 @@ pub unsafe fn take_unchecked(values: &dyn Array, indices: &IdxArr) -> Box<dyn Ar
t => unimplemented!("Take not supported for data type {:?}", t),
}
}

/// Naive fallback gather shared by array types without a specialised `take` kernel.
///
/// The output is built one element at a time: a position of `indices` becomes
/// `values.value_unchecked(i)` when the index is non-null and the referenced slot of
/// `values` is valid, and a null otherwise. When `values` or `indices` consist solely
/// of nulls the result is produced directly by `new_null_func(dtype, indices.len())`.
///
/// # Safety
/// Lookups go through `get_bit_unchecked` / `value_unchecked`, so nothing is
/// bounds-checked here. NOTE(review): callers presumably must guarantee that every
/// non-null index is `< values.len()` — confirm.
unsafe fn take_unchecked_impl_generic<T>(
    values: &T,
    indices: &IdxArr,
    new_null_func: &dyn Fn(ArrowDataType, usize) -> T,
) -> T
where
    T: StaticArray + ArrayFromIterDtype<std::option::Option<Box<dyn array::Array>>>,
{
    let dtype = values.dtype().clone();

    // Nothing valid can come out: skip the per-element work entirely.
    if values.null_count() == values.len() || indices.null_count() == indices.len() {
        return new_null_func(dtype, indices.len());
    }

    // `Some(bitmap)` only when `values` actually carries nulls.
    let values_validity = values.has_nulls().then(|| values.validity().unwrap());

    // Each arm keeps a plain `map` over the index iterator so the closure only does
    // the checks that are needed for that combination of null-ness.
    match (indices.has_nulls(), values_validity) {
        // Nullable indices into nullable values.
        (true, Some(validity)) => indices
            .iter()
            .map(|opt_idx| {
                opt_idx.and_then(|idx| {
                    let pos = *idx as usize;
                    validity
                        .get_bit_unchecked(pos)
                        .then(|| values.value_unchecked(pos))
                })
            })
            .collect_arr_trusted_with_dtype(dtype),
        // Nullable indices into fully valid values.
        (true, None) => indices
            .iter()
            .map(|opt_idx| opt_idx.map(|idx| values.value_unchecked(*idx as usize)))
            .collect_arr_trusted_with_dtype(dtype),
        // Fully valid indices into nullable values.
        (false, Some(validity)) => indices
            .values_iter()
            .map(|idx| {
                let pos = *idx as usize;
                validity
                    .get_bit_unchecked(pos)
                    .then(|| values.value_unchecked(pos))
            })
            .collect_arr_trusted_with_dtype(dtype),
        // No nulls on either side.
        (false, None) => indices
            .values_iter()
            .map(|idx| Some(values.value_unchecked(*idx as usize)))
            .collect_arr_trusted_with_dtype(dtype),
    }
}
24 changes: 8 additions & 16 deletions crates/polars-core/src/chunked_array/ops/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,14 +316,10 @@ impl IdxCa {
#[cfg(feature = "dtype-array")]
impl ChunkTakeUnchecked<IdxCa> for ArrayChunked {
    /// Gathers the rows of this fixed-size-list column selected by `indices`.
    ///
    /// Both the column and the indices are rechunked and unwrapped into a single array
    /// each, so the kernel is invoked exactly once and the result is a single chunk.
    ///
    /// # Safety
    /// NOTE(review): the call is forwarded to an unchecked kernel; callers presumably
    /// must guarantee that every non-null index is in bounds — confirm.
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let values = self.rechunk().downcast_into_array();
        let idx = indices.rechunk().downcast_into_array();

        let chunks = vec![take_unchecked(&values, &idx)];
        self.copy_with_chunks(chunks)
    }
}
Expand All @@ -338,14 +334,10 @@ impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ArrayChunked {

impl ChunkTakeUnchecked<IdxCa> for ListChunked {
    /// Gathers the rows of this list column selected by `indices`.
    ///
    /// Both the column and the indices are rechunked and unwrapped into a single array
    /// each, so the kernel is invoked exactly once and the result is a single chunk.
    ///
    /// # Safety
    /// NOTE(review): the call is forwarded to an unchecked kernel; callers presumably
    /// must guarantee that every non-null index is in bounds — confirm.
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let values = self.rechunk().downcast_into_array();
        let idx = indices.rechunk().downcast_into_array();

        let chunks = vec![take_unchecked(&values, &idx)];
        self.copy_with_chunks(chunks)
    }
}
Expand Down

0 comments on commit 802e692

Please sign in to comment.