From 0d2ce17f1fd2b3ab0eb268886355cae1bb956726 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Thu, 13 Apr 2023 16:55:03 +0100 Subject: [PATCH] Add ByteArray constructors (#3879) --- arrow-array/src/array/byte_array.rs | 65 ++++++++++++++++++++++++++- arrow-array/src/array/string_array.rs | 28 ++---------- arrow-array/src/types.rs | 46 ++++++++++++++++++- 3 files changed, 112 insertions(+), 27 deletions(-) diff --git a/arrow-array/src/array/byte_array.rs b/arrow-array/src/array/byte_array.rs index e23079ef9be9..ffabb6e174f5 100644 --- a/arrow-array/src/array/byte_array.rs +++ b/arrow-array/src/array/byte_array.rs @@ -21,10 +21,10 @@ use crate::iterator::ArrayIter; use crate::types::bytes::ByteArrayNativeType; use crate::types::ByteArrayType; use crate::{Array, ArrayAccessor, ArrayRef, OffsetSizeTrait}; -use arrow_buffer::{ArrowNativeType, Buffer}; +use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; use arrow_buffer::{NullBuffer, OffsetBuffer}; use arrow_data::{ArrayData, ArrayDataBuilder}; -use arrow_schema::DataType; +use arrow_schema::{ArrowError, DataType}; use std::any::Any; use std::sync::Arc; @@ -60,6 +60,67 @@ impl GenericByteArray { /// Data type of the array. 
pub const DATA_TYPE: DataType = T::DATA_TYPE;
 
+    /// Create a new [`GenericByteArray`] from the provided parts, panicking on failure
+    ///
+    /// # Panics
+    ///
+    /// Panics if [`GenericByteArray::try_new`] returns an error
+    pub fn new(
+        offsets: OffsetBuffer<T::Offset>,
+        values: Buffer,
+        nulls: Option<NullBuffer>,
+    ) -> Self {
+        Self::try_new(offsets, values, nulls).unwrap()
+    }
+
+    /// Create a new [`GenericByteArray`] from the provided parts, returning an error on failure
+    ///
+    /// # Errors
+    ///
+    /// * `offsets.len() - 1 != nulls.len()`
+    /// * Any consecutive pair of `offsets` does not denote a valid slice of `values`
+    pub fn try_new(
+        offsets: OffsetBuffer<T::Offset>,
+        values: Buffer,
+        nulls: Option<NullBuffer>,
+    ) -> Result<Self, ArrowError> {
+        let len = offsets.len() - 1;
+        T::validate(&offsets, &values)?;
+
+        if let Some(n) = nulls.as_ref() {
+            if n.len() != len {
+                return Err(ArrowError::InvalidArgumentError(format!(
+                    "Incorrect number of nulls for {}{}Array, expected {len} got {}",
+                    T::Offset::PREFIX,
+                    T::PREFIX,
+                    n.len(),
+                )));
+            }
+        }
+
+        Ok(Self {
+            data_type: T::DATA_TYPE,
+            value_offsets: offsets,
+            value_data: values,
+            nulls,
+        })
+    }
+
+    /// Create a new [`GenericByteArray`] of length `len` where all values are null
+    pub fn new_null(len: usize) -> Self {
+        Self {
+            data_type: T::DATA_TYPE,
+            value_offsets: OffsetBuffer::new_zeroed(len),
+            value_data: MutableBuffer::new(0).into(),
+            nulls: Some(NullBuffer::new_null(len)),
+        }
+    }
+
+    /// Deconstruct this array into its constituent parts
+    pub fn into_parts(self) -> (OffsetBuffer<T::Offset>, Buffer, Option<NullBuffer>) {
+        (self.value_offsets, self.value_data, self.nulls)
+    }
+
     /// Returns the length for value at index `i`.
     /// # Panics
     /// Panics if index `i` is out of bounds.
diff --git a/arrow-array/src/array/string_array.rs b/arrow-array/src/array/string_array.rs
index e042f29c22d1..7c4a375299db 100644
--- a/arrow-array/src/array/string_array.rs
+++ b/arrow-array/src/array/string_array.rs
@@ -16,9 +16,7 @@
 // under the License.
 
use crate::types::GenericStringType;
-use crate::{
-    Array, GenericBinaryArray, GenericByteArray, GenericListArray, OffsetSizeTrait,
-};
+use crate::{GenericBinaryArray, GenericByteArray, GenericListArray, OffsetSizeTrait};
 use arrow_buffer::{bit_util, MutableBuffer};
 use arrow_data::ArrayData;
 use arrow_schema::{ArrowError, DataType};
@@ -105,27 +103,8 @@ impl GenericStringArray {
     pub fn try_from_binary(
         v: GenericBinaryArray<OffsetSize>,
     ) -> Result<Self, ArrowError> {
-        let offsets = v.value_offsets();
-        let values = v.value_data();
-
-        // We only need to validate that all values are valid UTF-8
-        let validated = std::str::from_utf8(values).map_err(|e| {
-            ArrowError::CastError(format!("Encountered non UTF-8 data: {e}"))
-        })?;
-
-        for offset in offsets.iter() {
-            let o = offset.as_usize();
-            if !validated.is_char_boundary(o) {
-                return Err(ArrowError::CastError(format!(
-                    "Split UTF-8 codepoint at offset {o}"
-                )));
-            }
-        }
-
-        let builder = v.into_data().into_builder().data_type(Self::DATA_TYPE);
-        // SAFETY:
-        // Validated UTF-8 above
-        Ok(Self::from(unsafe { builder.build_unchecked() }))
+        let (offsets, values, nulls) = v.into_parts();
+        Self::try_new(offsets, values, nulls)
     }
 }
 
@@ -261,6 +240,7 @@ mod tests {
     use super::*;
     use crate::builder::{ListBuilder, PrimitiveBuilder, StringBuilder};
     use crate::types::UInt8Type;
+    use crate::Array;
     use arrow_buffer::Buffer;
     use arrow_schema::Field;
     use std::sync::Arc;
diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs
index e2d7a2492227..402365f88ec7 100644
--- a/arrow-array/src/types.rs
+++ b/arrow-array/src/types.rs
@@ -19,7 +19,7 @@
 
 use crate::delta::shift_months;
 use crate::{ArrowNativeTypeOp, OffsetSizeTrait};
-use arrow_buffer::i256;
+use arrow_buffer::{i256, Buffer, OffsetBuffer};
 use arrow_data::decimal::{validate_decimal256_precision, validate_decimal_precision};
 use arrow_schema::{
     ArrowError, DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION,
@@ -882,10 +882,18 @@ pub trait ByteArrayType: 'static + Send + Sync +
bytes::ByteArrayTypeSealed {
     /// Utf8Array will have native type has &str
     /// BinaryArray will have type as [u8]
     type Native: bytes::ByteArrayNativeType + AsRef<Self::Native> + AsRef<[u8]> + ?Sized;
+
     /// "Binary" or "String", for use in error messages
     const PREFIX: &'static str;
+
     /// Datatype of array elements
     const DATA_TYPE: DataType;
+
+    /// Verifies that every consecutive pair of `offsets` denotes a valid slice of `values`
+    fn validate(
+        offsets: &OffsetBuffer<Self::Offset>,
+        values: &Buffer,
+    ) -> Result<(), ArrowError>;
 }
 
 /// [`ByteArrayType`] for string arrays
@@ -903,6 +911,35 @@ impl ByteArrayType for GenericStringType {
     } else {
         DataType::Utf8
     };
+
+    fn validate(
+        offsets: &OffsetBuffer<Self::Offset>,
+        values: &Buffer,
+    ) -> Result<(), ArrowError> {
+        // Verify that the slice as a whole is valid UTF-8
+        let validated = std::str::from_utf8(values).map_err(|e| {
+            ArrowError::InvalidArgumentError(format!("Encountered non UTF-8 data: {e}"))
+        })?;
+
+        // Verify each offset is at a valid character boundary in this UTF-8 array
+        for offset in offsets.iter() {
+            let o = offset.as_usize();
+            if !validated.is_char_boundary(o) {
+                // `is_char_boundary` is also false past the end of `values`, so
+                // report an out-of-range offset as such, not as a split codepoint
+                if o > validated.len() {
+                    return Err(ArrowError::InvalidArgumentError(format!(
+                        "Offset of {o} exceeds length of values {}",
+                        validated.len()
+                    )));
+                }
+                return Err(ArrowError::InvalidArgumentError(format!(
+                    "Split UTF-8 codepoint at offset {o}"
+                )));
+            }
+        }
+        Ok(())
+    }
 }
 
 /// An arrow utf8 array with i32 offsets
@@ -925,6 +962,21 @@ impl ByteArrayType for GenericBinaryType {
     } else {
         DataType::Binary
     };
+
+    fn validate(
+        offsets: &OffsetBuffer<Self::Offset>,
+        values: &Buffer,
+    ) -> Result<(), ArrowError> {
+        // offsets are guaranteed to be monotonically increasing and non-empty
+        let max_offset = offsets.last().unwrap().as_usize();
+        if values.len() < max_offset {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "Maximum offset of {max_offset} is larger than values of length {}",
+                values.len()
+            )));
+        }
+        Ok(())
+    }
 }
 
 /// An arrow binary array with i32 offsets