diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs index 0b1d44319493..8006afb373bb 100644 --- a/arrow-select/src/take.rs +++ b/arrow-select/src/take.rs @@ -121,30 +121,30 @@ where } downcast_primitive_array! { - values => Ok(Arc::new(take_primitive(values, indices)?)), + values => Ok(Arc::new(take_primitive(values, indices))), DataType::Boolean => { let values = values.as_any().downcast_ref::().unwrap(); - Ok(Arc::new(take_boolean(values, indices)?)) + Ok(Arc::new(take_boolean(values, indices))) } DataType::Decimal128(p, s) => { let decimal_values = values.as_any().downcast_ref::().unwrap(); - let array = take_primitive(decimal_values, indices)? + let array = take_primitive(decimal_values, indices) .with_precision_and_scale(*p, *s) .unwrap(); Ok(Arc::new(array)) } DataType::Decimal256(p, s) => { let decimal_values = values.as_any().downcast_ref::().unwrap(); - let array = take_primitive(decimal_values, indices)? + let array = take_primitive(decimal_values, indices) .with_precision_and_scale(*p, *s) .unwrap(); Ok(Arc::new(array)) } DataType::Utf8 => { - Ok(Arc::new(take_bytes(as_string_array(values), indices)?)) + Ok(Arc::new(take_bytes(as_string_array(values), indices))) } DataType::LargeUtf8 => { - Ok(Arc::new(take_bytes(as_largestring_array(values), indices)?)) + Ok(Arc::new(take_bytes(as_largestring_array(values), indices))) } DataType::List(_) => { let values = values @@ -198,14 +198,14 @@ where Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef) } DataType::Dictionary(_, _) => downcast_dictionary_array! { - values => Ok(Arc::new(take_dict(values, indices)?)), + values => Ok(Arc::new(take_dict(values, indices))), t => unimplemented!("Take not supported for dictionary type {:?}", t) } DataType::Binary => { - Ok(Arc::new(take_bytes(as_generic_binary_array::(values), indices)?)) + Ok(Arc::new(take_bytes(as_generic_binary_array::(values), indices))) } DataType::LargeBinary => { - Ok(Arc::new(take_bytes(as_generic_binary_array::(values), indices)?)) + Ok(Arc::new(take_bytes(as_generic_binary_array::(values), indices))) } DataType::FixedSizeBinary(size) => { let values = values @@ -246,28 +246,23 @@ fn maybe_usize(index: I) -> Result { } // take implementation when neither values nor indices contain nulls -fn take_no_nulls( - values: &[T], - indices: &[I], -) -> Result<(Buffer, Option), ArrowError> +fn take_no_nulls(values: &[T], indices: &[I]) -> (Buffer, Option) where T: ArrowNativeType, I: ArrowNativeType, { - let values = indices - .iter() - .map(|index| Result::<_, ArrowError>::Ok(values[maybe_usize::(*index)?])); + let values = indices.iter().map(|index| values[index.as_usize()]); // Soundness: `slice.map` is `TrustedLen`. - let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? }; + let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; - Ok((buffer, None)) + (buffer, None) } // take implementation when only values contain nulls fn take_values_nulls( values: &PrimitiveArray, indices: &[I], -) -> Result<(Buffer, Option), ArrowError> +) -> (Buffer, Option) where T: ArrowPrimitiveType, I: ArrowNativeType, @@ -279,7 +274,7 @@ fn take_values_nulls_inner( values_data: &ArrayData, values: &[T], indices: &[I], -) -> Result<(Buffer, Option), ArrowError> +) -> (Buffer, Option) where T: ArrowNativeType, I: ArrowNativeType, @@ -290,15 +285,15 @@ where let mut null_count = 0; let values = indices.iter().enumerate().map(|(i, index)| { - let index = maybe_usize::(*index)?; + let index = index.as_usize(); if values_data.is_null(index) { null_count += 1; bit_util::unset_bit(null_slice, i); } - Result::<_, ArrowError>::Ok(values[index]) + values[index] }); // Soundness: `slice.map` is `TrustedLen`. - let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? }; + let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; let nulls = if null_count == 0 { // if only non-null values were taken @@ -307,14 +302,14 @@ where Some(nulls.into()) }; - Ok((buffer, nulls)) + (buffer, nulls) } // take implementation when only indices contain nulls fn take_indices_nulls( values: &[T], indices: &PrimitiveArray, -) -> Result<(Buffer, Option), ArrowError> +) -> (Buffer, Option) where T: ArrowNativeType, I: ArrowPrimitiveType, @@ -327,14 +322,14 @@ fn take_indices_nulls_inner( values: &[T], indices: &[I], indices_data: &ArrayData, -) -> Result<(Buffer, Option), ArrowError> +) -> (Buffer, Option) where T: ArrowNativeType, I: ArrowNativeType, { let values = indices.iter().map(|index| { - let index = maybe_usize::(*index)?; - Result::<_, ArrowError>::Ok(match values.get(index) { + let index = index.as_usize(); + match values.get(index) { Some(value) => *value, None => { if indices_data.is_null(index) { @@ -343,25 +338,25 @@ where panic!("Out-of-bounds index {}", index) } } - }) + } }); // Soundness: `slice.map` is `TrustedLen`. - let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? }; + let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; - Ok(( + ( buffer, indices_data .null_buffer() .map(|b| b.bit_slice(indices_data.offset(), indices.len())), - )) + ) } // take implementation when both values and indices contain nulls fn take_values_indices_nulls( values: &PrimitiveArray, indices: &PrimitiveArray, -) -> Result<(Buffer, Option), ArrowError> +) -> (Buffer, Option) where T: ArrowPrimitiveType, I: ArrowPrimitiveType, @@ -380,7 +375,7 @@ fn take_values_indices_nulls_inner( values_data: &ArrayData, indices: &[I], indices_data: &ArrayData, -) -> Result<(Buffer, Option), ArrowError> +) -> (Buffer, Option) where T: ArrowNativeType, I: ArrowNativeType, @@ -394,19 +389,18 @@ where if indices_data.is_null(i) { null_count += 1; bit_util::unset_bit(null_slice, i); - Ok(T::default()) + T::default() } else { - let index = maybe_usize::(index)?; + let index = index.as_usize(); if values_data.is_null(index) { null_count += 1; bit_util::unset_bit(null_slice, i); } - Result::<_, ArrowError>::Ok(values[index]) + values[index] } }); // Soundness: `slice.map` is `TrustedLen`. - let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? }; - + let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; let nulls = if null_count == 0 { // if only non-null values were taken None @@ -414,10 +408,10 @@ where Some(nulls.into()) }; - Ok((buffer, nulls)) + (buffer, nulls) } -/// `take` implementation for all primitive arrays +/// `take` implementation for [`PrimitiveArray`] /// /// This checks if an `indices` slot is populated, and gets the value from `values` /// as the populated index. @@ -426,10 +420,10 @@ where /// values: [1, 2, 3, null, 5] /// indices: [0, null, 4, 3] /// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)] -fn take_primitive( +pub fn take_primitive( values: &PrimitiveArray, indices: &PrimitiveArray, -) -> Result, ArrowError> +) -> PrimitiveArray where T: ArrowPrimitiveType, I: ArrowPrimitiveType, @@ -444,22 +438,22 @@ where (false, false) => { // * no nulls // * all `indices.values()` are valid - take_no_nulls::(values.values(), indices.values())? + take_no_nulls::(values.values(), indices.values()) } (true, false) => { // * nulls come from `values` alone // * all `indices.values()` are valid - take_values_nulls::(values, indices.values())? + take_values_nulls::(values, indices.values()) } (false, true) => { // in this branch it is unsound to read and use `index.values()`, // as doing so is UB when they come from a null slot. - take_indices_nulls::(values.values(), indices)? + take_indices_nulls::(values.values(), indices) } (true, true) => { // in this branch it is unsound to read and use `index.values()`, // as doing so is UB when they come from a null slot. - take_values_indices_nulls::(values, indices)? + take_values_indices_nulls::(values, indices) } }; @@ -474,18 +468,14 @@ where vec![], ) }; - Ok(PrimitiveArray::::from(data)) + PrimitiveArray::::from(data) } -fn take_bits( +fn take_bits( values: &Buffer, values_offset: usize, indices: &PrimitiveArray, -) -> Result -where - IndexType: ArrowPrimitiveType, - IndexType::Native: ToPrimitive, -{ +) -> Buffer { let len = indices.len(); let values_slice = values.as_slice(); let mut output_buffer = MutableBuffer::new_null(len); @@ -494,54 +484,39 @@ where let indices_has_nulls = indices.null_count() > 0; if indices_has_nulls { - indices - .iter() - .enumerate() - .try_for_each::<_, Result<(), ArrowError>>(|(i, index)| { - if let Some(index) = index { - let index = ToPrimitive::to_usize(&index).ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - - if bit_util::get_bit(values_slice, values_offset + index) { - bit_util::set_bit(output_slice, i); - } - } - - Ok(()) - })?; - } else { - indices - .values() - .iter() - .enumerate() - .try_for_each::<_, Result<(), ArrowError>>(|(i, index)| { - let index = ToPrimitive::to_usize(index).ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - + indices.iter().enumerate().for_each(|(i, index)| { + if let Some(index) = index { + let index = index.as_usize(); if bit_util::get_bit(values_slice, values_offset + index) { bit_util::set_bit(output_slice, i); } - Ok(()) - })?; + } + }); + } else { + indices.values().iter().enumerate().for_each(|(i, index)| { + let index = index.as_usize(); + + if bit_util::get_bit(values_slice, values_offset + index) { + bit_util::set_bit(output_slice, i); + } + }); } - Ok(output_buffer.into()) + output_buffer.into() } -/// `take` implementation for boolean arrays -fn take_boolean( +/// `take` implementation for [`BooleanArray`] +/// +/// # Panic +/// +/// Panics on out of bounds index +pub fn take_boolean( values: &BooleanArray, indices: &PrimitiveArray, -) -> Result -where - IndexType: ArrowPrimitiveType, - IndexType::Native: ToPrimitive, -{ - let val_buf = take_bits(values.values(), values.offset(), indices)?; +) -> BooleanArray { + let val_buf = take_bits(values.values(), values.offset(), indices); let null_buf = match values.data().null_buffer() { Some(buf) if values.null_count() > 0 => { - Some(take_bits(buf, values.offset(), indices)?) + Some(take_bits(buf, values.offset(), indices)) } _ => indices .data() @@ -560,19 +535,14 @@ where vec![], ) }; - Ok(BooleanArray::from(data)) + BooleanArray::from(data) } -/// `take` implementation for string arrays -fn take_bytes( +/// `take` implementation for [`GenericByteArray`] +pub fn take_bytes( array: &GenericByteArray, indices: &PrimitiveArray, -) -> Result, ArrowError> -where - T: ByteArrayType, - IndexType: ArrowPrimitiveType, - IndexType::Native: ToPrimitive, -{ +) -> GenericByteArray { let data_len = indices.len(); let bytes_offset = (data_len + 1) * std::mem::size_of::(); @@ -586,10 +556,7 @@ where let nulls; if array.null_count() == 0 && indices.null_count() == 0 { for (i, offset) in offsets.iter_mut().skip(1).enumerate() { - let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - + let index = indices.value(i).as_usize(); let s = array.value(index); length_so_far += T::Offset::from_usize(s.as_ref().len()).unwrap(); @@ -604,9 +571,7 @@ where let null_slice = null_buf.as_slice_mut(); for (i, offset) in offsets.iter_mut().skip(1).enumerate() { - let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; + let index = indices.value(i).as_usize(); if array.is_valid(index) { let s = array.value(index).as_ref(); @@ -622,10 +587,7 @@ where } else if array.null_count() == 0 { for (i, offset) in offsets.iter_mut().skip(1).enumerate() { if indices.is_valid(i) { - let index = - ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; + let index = indices.value(i).as_usize(); let s = array.value(index).as_ref(); @@ -642,10 +604,7 @@ where let null_slice = null_buf.as_slice_mut(); for (i, offset) in offsets.iter_mut().skip(1).enumerate() { - let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - + let index = indices.value(i).as_usize(); if array.is_valid(index) && indices.is_valid(i) { let s = array.value(index).as_ref(); @@ -669,15 +628,15 @@ where let array_data = unsafe { array_data.build_unchecked() }; - Ok(GenericByteArray::from(array_data)) + GenericByteArray::from(array_data) } -/// `take` implementation for list arrays +/// `take` implementation for [`GenericListArray`] /// /// Calculates the index and indexed offset for the inner array, /// applying `take` on the inner array, then reconstructing a list array /// with the indexed offsets -fn take_list( +pub fn take_list( values: &GenericListArray, indices: &PrimitiveArray, ) -> Result, ArrowError> @@ -691,7 +650,7 @@ where // TODO: Some optimizations can be done here such as if it is // taking the whole list or a contiguous sublist let (list_indices, offsets) = - take_value_indices_from_list::(values, indices)?; + take_value_indices_from_list::(values, indices); let taken = take_impl::(values.values().as_ref(), &list_indices, None)?; // determine null count and null buffer, which are a function of `values` and `indices` @@ -724,12 +683,12 @@ where Ok(GenericListArray::::from(list_data)) } -/// `take` implementation for `FixedSizeListArray` +/// `take` implementation for [`FixedSizeListArray`] /// /// Calculates the index and indexed offset for the inner array, /// applying `take` on the inner array, then reconstructing a list array /// with the indexed offsets -fn take_fixed_size_list( +pub fn take_fixed_size_list( values: &FixedSizeListArray, indices: &PrimitiveArray, length: ::Native, @@ -738,7 +697,7 @@ where IndexType: ArrowPrimitiveType, IndexType::Native: ToPrimitive, { - let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?; + let list_indices = take_value_indices_from_fixed_size_list(values, indices, length); let taken = take_impl::(values.values().as_ref(), &list_indices, None)?; // determine null count and null buffer, which are a function of `values` and `indices` @@ -766,7 +725,8 @@ where Ok(FixedSizeListArray::from(list_data)) } -fn take_fixed_size_binary( +/// `take` implementation for [`FixedSizeBinaryArray`] +pub fn take_fixed_size_binary( values: &FixedSizeBinaryArray, indices: &PrimitiveArray, size: i32, @@ -793,21 +753,21 @@ where FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size) } -/// `take` implementation for dictionary arrays +/// `take` implementation for [`DictionaryArray`] /// /// applies `take` to the keys of the dictionary array and returns a new dictionary array /// with the same dictionary values and reordered keys -fn take_dict( +pub fn take_dict( values: &DictionaryArray, indices: &PrimitiveArray, -) -> Result, ArrowError> +) -> DictionaryArray where T: ArrowPrimitiveType, T::Native: num::Num, I: ArrowPrimitiveType, I::Native: ToPrimitive, { - let new_keys = take_primitive::(values.keys(), indices)?; + let new_keys = take_primitive::(values.keys(), indices); let new_keys_data = new_keys.data_ref(); let data = unsafe { @@ -822,7 +782,7 @@ where ) }; - Ok(DictionaryArray::::from(data)) + DictionaryArray::::from(data) } /// Takes/filters a list array's inner data using the offsets of the list array. @@ -833,7 +793,7 @@ where fn take_value_indices_from_list( list: &GenericListArray, indices: &PrimitiveArray, -) -> Result<(PrimitiveArray, Vec), ArrowError> +) -> (PrimitiveArray, Vec) where IndexType: ArrowPrimitiveType, IndexType::Native: ToPrimitive, @@ -852,9 +812,7 @@ where // compute the value indices, and set offsets accordingly for i in 0..indices.len() { if indices.is_valid(i) { - let ix = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; + let ix = indices.value(i).as_usize(); let start = offsets[ix]; let end = offsets[ix + 1]; current_offset += end - start; @@ -872,7 +830,7 @@ where } } - Ok((PrimitiveArray::::from(values), new_offsets)) + (PrimitiveArray::::from(values), new_offsets) } /// Takes/filters a fixed size list array's inner data using the offsets of the list array. @@ -880,7 +838,7 @@ fn take_value_indices_from_fixed_size_list( list: &FixedSizeListArray, indices: &PrimitiveArray, length: ::Native, -) -> Result, ArrowError> +) -> PrimitiveArray where IndexType: ArrowPrimitiveType, IndexType::Native: ToPrimitive, @@ -889,9 +847,7 @@ where for i in 0..indices.len() { if indices.is_valid(i) { - let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; + let index = indices.value(i).as_usize(); let start = list.value_offset(index) as ::Native; @@ -899,7 +855,7 @@ where } } - Ok(PrimitiveArray::::from(values)) + PrimitiveArray::::from(values) } #[cfg(test)] @@ -2056,7 +2012,7 @@ mod tests { ]); let indices = UInt32Array::from(vec![2, 0]); - let (indexed, offsets) = take_value_indices_from_list(&list, &indices).unwrap(); + let (indexed, offsets) = take_value_indices_from_list(&list, &indices); assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1])); assert_eq!(offsets, vec![0, 5, 7]); @@ -2072,7 +2028,7 @@ mod tests { let indices = UInt32Array::from(vec![2, 0]); let (indexed, offsets) = - take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap(); + take_value_indices_from_list::<_, Int64Type>(&list, &indices); assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1])); assert_eq!(offsets, vec![0, 5, 7]); @@ -2091,14 +2047,12 @@ mod tests { ); let indices = UInt32Array::from(vec![2, 1, 0]); - let indexed = - take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap(); + let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3); assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2])); let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]); - let indexed = - take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap(); + let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3); assert_eq!( indexed,