diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index aec665aa3013..9f20dceb980a 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -156,8 +156,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Utf8, LargeUtf8) => true, (LargeUtf8, Utf8) => true, - (Binary, LargeBinary) => true, - (LargeBinary, Binary) => true, + (Binary, LargeBinary | Utf8 | LargeUtf8) => true, + (LargeBinary, Binary | Utf8 | LargeUtf8) => true, (Utf8, Binary | LargeBinary @@ -185,7 +185,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Timestamp(_, _), Utf8) | (Timestamp(_, _), LargeUtf8) => true, (Date32, Utf8) | (Date32, LargeUtf8) => true, (Date64, Utf8) | (Date64, LargeUtf8) => true, - (_, Utf8 | LargeUtf8) => (DataType::is_numeric(from_type) && from_type != &Float16) || from_type == &Binary, + (_, Utf8 | LargeUtf8) => DataType::is_numeric(from_type) && from_type != &Float16, // start numeric casts ( @@ -1180,30 +1180,8 @@ pub fn cast_with_options( } Date32 => cast_date32_to_string::(array), Date64 => cast_date64_to_string::(array), - Binary => { - let array = array.as_any().downcast_ref::().unwrap(); - Ok(Arc::new( - array - .iter() - .map(|maybe_value| match maybe_value { - Some(value) => { - let result = std::str::from_utf8(value); - if cast_options.safe { - Ok(result.ok()) - } else { - Some(result.map_err(|_| { - ArrowError::CastError( - "Cannot cast binary to string".to_string(), - ) - })) - .transpose() - } - } - None => Ok(None), - }) - .collect::>()?, - )) - } + Binary => cast_binary_to_generic_string::(array, cast_options), + LargeBinary => cast_binary_to_generic_string::(array, cast_options), _ => Err(ArrowError::CastError(format!( "Casting from {from_type:?} to {to_type:?} not supported", ))), @@ -1236,30 +1214,8 @@ pub fn cast_with_options( } Date32 => cast_date32_to_string::(array), Date64 => cast_date64_to_string::(array), - Binary => { - let array = array.as_any().downcast_ref::().unwrap(); - Ok(Arc::new( - array - .iter() - .map(|maybe_value| match maybe_value { - Some(value) => { - let result = std::str::from_utf8(value); - if cast_options.safe { - Ok(result.ok()) - } else { - Some(result.map_err(|_| { - ArrowError::CastError( - "Cannot cast binary to string".to_string(), - ) - })) - .transpose() - } - } - None => Ok(None), - }) - .collect::>()?, - )) - } + Binary => cast_binary_to_generic_string::(array, cast_options), + LargeBinary => cast_binary_to_generic_string::(array, cast_options), _ => Err(ArrowError::CastError(format!( "Casting from {from_type:?} to {to_type:?} not supported", ))), @@ -3436,6 +3392,77 @@ fn cast_list_inner( Ok(Arc::new(list) as ArrayRef) } +/// Helper function to cast from `GenericBinaryArray` to `GenericStringArray`. This function performs +/// UTF8 validation during casting. For invalid UTF8 value, it could be Null or returning `Err` depending +/// `CastOptions`. +fn cast_binary_to_generic_string( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result +where + I: OffsetSizeTrait + ToPrimitive, + O: OffsetSizeTrait + NumCast, +{ + let array = array + .as_any() + .downcast_ref::>>() + .unwrap(); + + if !cast_options.safe { + let offsets = array.value_offsets(); + let values = array.value_data(); + + // We only need to validate that all values are valid UTF-8 + let validated = std::str::from_utf8(values) + .map_err(|_| ArrowError::CastError("Invalid UTF-8 sequence".to_string()))?; + + let mut offset_builder = BufferBuilder::::new(offsets.len()); + offsets + .iter() + .try_for_each::<_, Result<_, ArrowError>>(|offset| { + if !validated.is_char_boundary(offset.as_usize()) { + return Err(ArrowError::CastError( + "Invalid UTF-8 sequence".to_string(), + )); + } + + let offset = ::from(*offset).ok_or_else(|| { + ArrowError::ComputeError(format!( + "{}Binary array too large to cast to {}String array", + I::PREFIX, + O::PREFIX + )) + })?; + offset_builder.append(offset); + Ok(()) + })?; + + let offset_buffer = offset_builder.finish(); + + let builder = ArrayData::builder(GenericStringArray::::DATA_TYPE) + .len(array.len()) + .add_buffer(offset_buffer) + .add_buffer(array.data().buffers()[1].clone()) + .null_count(array.null_count()) + .null_bit_buffer(array.data().null_buffer().cloned()); + + // SAFETY: + // Validated UTF-8 above + Ok(Arc::new(GenericStringArray::::from(unsafe { + builder.build_unchecked() + }))) + } else { + Ok(Arc::new( + array + .iter() + .map(|maybe_value| { + maybe_value.and_then(|value| std::str::from_utf8(value).ok()) + }) + .collect::>>(), + )) + } +} + /// Helper function to cast from one `ByteArrayType` to another and vice versa. /// If the target one (e.g., `LargeUtf8`) is too large for the source array it will return an Error. fn cast_byte_container(