diff --git a/crates/polars-arrow/src/compute/take/fixed_size_list.rs b/crates/polars-arrow/src/compute/take/fixed_size_list.rs index 2a52a1ae3fd1..624a09db3368 100644 --- a/crates/polars-arrow/src/compute/take/fixed_size_list.rs +++ b/crates/polars-arrow/src/compute/take/fixed_size_list.rs @@ -18,62 +18,17 @@ use std::mem::ManuallyDrop; use polars_utils::itertools::Itertools; -use polars_utils::IdxSize; use super::Index; -use crate::array::growable::{Growable, GrowableFixedSizeList}; use crate::array::{Array, ArrayRef, FixedSizeListArray, PrimitiveArray, StaticArray}; use crate::bitmap::MutableBitmap; use crate::compute::take::bitmap::{take_bitmap_nulls_unchecked, take_bitmap_unchecked}; use crate::compute::utils::combine_validities_and; use crate::datatypes::reshape::{Dimension, ReshapeDimension}; -use crate::datatypes::{ArrowDataType, PhysicalType}; +use crate::datatypes::{ArrowDataType, IdxArr, PhysicalType}; use crate::legacy::prelude::FromData; use crate::with_match_primitive_type; -pub(super) unsafe fn take_unchecked_slow( - values: &FixedSizeListArray, - indices: &PrimitiveArray, -) -> FixedSizeListArray { - let take_len = std::cmp::min(values.len(), 1); - let mut capacity = 0; - let arrays = indices - .values() - .iter() - .map(|index| { - let index = index.to_usize(); - let slice = values.clone().sliced_unchecked(index, take_len); - capacity += slice.len(); - slice - }) - .collect::>(); - - let arrays = arrays.iter().collect(); - - if let Some(validity) = indices.validity() { - let mut growable: GrowableFixedSizeList = - GrowableFixedSizeList::new(arrays, true, capacity); - - for index in 0..indices.len() { - if validity.get_bit_unchecked(index) { - growable.extend(index, 0, 1); - } else { - growable.extend_validity(1) - } - } - - growable.into() - } else { - let mut growable: GrowableFixedSizeList = - GrowableFixedSizeList::new(arrays, false, capacity); - for index in 0..indices.len() { - growable.extend(index, 0, 1); - } - - growable.into() - } -} - fn get_stride_and_leaf_type(dtype: &ArrowDataType, size: usize) -> (usize, &ArrowDataType) { if let ArrowDataType::FixedSizeList(inner, size_inner) = dtype { get_stride_and_leaf_type(inner.dtype(), *size_inner * size) @@ -163,10 +118,7 @@ fn arr_no_validities_recursive(arr: &dyn Array) -> bool { } /// `take` implementation for FixedSizeListArrays -pub(super) unsafe fn take_unchecked( - values: &FixedSizeListArray, - indices: &PrimitiveArray, -) -> ArrayRef { +pub(super) unsafe fn take_unchecked(values: &FixedSizeListArray, indices: &IdxArr) -> ArrayRef { let (stride, leaf_type) = get_stride_and_leaf_type(values.dtype(), 1); if leaf_type.to_physical_type().is_primitive() && arr_no_validities_recursive(values.values().as_ref()) @@ -249,7 +201,7 @@ pub(super) unsafe fn take_unchecked( .unwrap() .with_validity(outer_validity) } else { - take_unchecked_slow(values, indices).boxed() + super::take_unchecked_impl_generic(values, indices, &FixedSizeListArray::new_null).boxed() } } diff --git a/crates/polars-arrow/src/compute/take/list.rs b/crates/polars-arrow/src/compute/take/list.rs index 36ca1f72131f..497f680a944e 100644 --- a/crates/polars-arrow/src/compute/take/list.rs +++ b/crates/polars-arrow/src/compute/take/list.rs @@ -15,9 +15,7 @@ // specific language governing permissions and limitations // under the License. -use super::Index; -use crate::array::growable::{Growable, GrowableList}; -use crate::array::{Array, ListArray}; +use crate::array::{self, ArrayFromIterDtype, ListArray, StaticArray}; use crate::datatypes::IdxArr; use crate::offset::Offset; @@ -25,45 +23,9 @@ use crate::offset::Offset; pub(super) unsafe fn take_unchecked( values: &ListArray, indices: &IdxArr, -) -> ListArray { - // fast-path: all values to take are none - if indices.null_count() == indices.len() { - return ListArray::::new_null(values.dtype().clone(), indices.len()); - } - - let mut capacity = 0; - let arrays = indices - .iter() - .flat_map(|opt_idx| { - opt_idx.map(|index| { - let index = index.to_usize(); - let slice = values.clone().sliced(index, 1); - capacity += slice.len(); - slice - }) - }) - .collect::>>(); - - let arrays = arrays.iter().collect(); - if let Some(validity) = indices.validity() { - let mut growable: GrowableList = GrowableList::new(arrays, true, capacity); - let mut not_null_index = 0; - for index in 0..indices.len() { - if validity.get_bit_unchecked(index) { - growable.extend(not_null_index, 0, 1); - not_null_index += 1; - } else { - growable.extend_validity(1) - } - } - - growable.into() - } else { - let mut growable: GrowableList = GrowableList::new(arrays, false, capacity); - for index in 0..indices.len() { - growable.extend(index, 0, 1); - } - - growable.into() - } +) -> ListArray +where + ListArray: StaticArray + ArrayFromIterDtype>>, +{ + super::take_unchecked_impl_generic(values, indices, &ListArray::new_null) } diff --git a/crates/polars-arrow/src/compute/take/mod.rs b/crates/polars-arrow/src/compute/take/mod.rs index bdd782a1d609..1c1ad3c51373 100644 --- a/crates/polars-arrow/src/compute/take/mod.rs +++ b/crates/polars-arrow/src/compute/take/mod.rs @@ -17,9 +17,12 @@ //! Defines take kernel for [`Array`] -use crate::array::{new_empty_array, Array, NullArray, Utf8ViewArray}; +use crate::array::{ + self, new_empty_array, Array, ArrayCollectIterExt, ArrayFromIterDtype, NullArray, StaticArray, + Utf8ViewArray, +}; use crate::compute::take::binview::take_binview_unchecked; -use crate::datatypes::IdxArr; +use crate::datatypes::{ArrowDataType, IdxArr}; use crate::types::Index; mod binary; @@ -82,3 +85,67 @@ pub unsafe fn take_unchecked(values: &dyn Array, indices: &IdxArr) -> Box unimplemented!("Take not supported for data type {:?}", t), } } + +/// Naive default implementation +unsafe fn take_unchecked_impl_generic( + values: &T, + indices: &IdxArr, + new_null_func: &dyn Fn(ArrowDataType, usize) -> T, +) -> T +where + T: StaticArray + ArrayFromIterDtype>>, +{ + if values.null_count() == values.len() || indices.null_count() == indices.len() { + return new_null_func(values.dtype().clone(), indices.len()); + } + + match (indices.has_nulls(), values.has_nulls()) { + (true, true) => { + let values_validity = values.validity().unwrap(); + + indices + .iter() + .map(|i| { + if let Some(i) = i { + let i = *i as usize; + if values_validity.get_bit_unchecked(i) { + return Some(values.value_unchecked(i)); + } + } + None + }) + .collect_arr_trusted_with_dtype(values.dtype().clone()) + }, + (true, false) => indices + .iter() + .map(|i| { + if let Some(i) = i { + let i = *i as usize; + return Some(values.value_unchecked(i)); + } + None + }) + .collect_arr_trusted_with_dtype(values.dtype().clone()), + (false, true) => { + let values_validity = values.validity().unwrap(); + + indices + .values_iter() + .map(|i| { + let i = *i as usize; + if values_validity.get_bit_unchecked(i) { + return Some(values.value_unchecked(i)); + } + None + }) + .collect_arr_trusted_with_dtype(values.dtype().clone()) + }, + (false, false) => indices + .values_iter() + .map(|i| { + let i = *i as usize; + Some(values.value_unchecked(i)) + }) + .collect_arr_trusted_with_dtype(values.dtype().clone()), + } +} diff --git a/crates/polars-core/src/chunked_array/ops/gather.rs b/crates/polars-core/src/chunked_array/ops/gather.rs index fc162626bc27..ca1411f3e653 100644 --- a/crates/polars-core/src/chunked_array/ops/gather.rs +++ b/crates/polars-core/src/chunked_array/ops/gather.rs @@ -316,14 +316,10 @@ impl IdxCa { #[cfg(feature = "dtype-array")] impl ChunkTakeUnchecked for ArrayChunked { unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { - let a = self.rechunk(); - let index = indices.rechunk(); - - let chunks = a - .downcast_iter() - .zip(index.downcast_iter()) - .map(|(arr, idx)| take_unchecked(arr, idx)) - .collect::>(); + let chunks = vec![take_unchecked( + &self.rechunk().downcast_into_array(), + &indices.rechunk().downcast_into_array(), + )]; self.copy_with_chunks(chunks) } } @@ -338,14 +334,10 @@ impl + ?Sized> ChunkTakeUnchecked for ArrayChunked { impl ChunkTakeUnchecked for ListChunked { unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { - let a = self.rechunk(); - let index = indices.rechunk(); - - let chunks = a - .downcast_iter() - .zip(index.downcast_iter()) - .map(|(arr, idx)| take_unchecked(arr, idx)) - .collect::>(); + let chunks = vec![take_unchecked( + &self.rechunk().downcast_into_array(), + &indices.rechunk().downcast_into_array(), + )]; self.copy_with_chunks(chunks) } }