diff --git a/arrow-array/src/builder/primitive_builder.rs b/arrow-array/src/builder/primitive_builder.rs index f3f3f3728db9..a969e121808b 100644 --- a/arrow-array/src/builder/primitive_builder.rs +++ b/arrow-array/src/builder/primitive_builder.rs @@ -21,6 +21,7 @@ use crate::types::*; use crate::{ArrayRef, ArrowPrimitiveType, PrimitiveArray}; use arrow_buffer::{Buffer, MutableBuffer}; use arrow_data::ArrayData; +use arrow_schema::DataType; use std::any::Any; use std::sync::Arc; @@ -94,6 +95,7 @@ pub type Decimal256Builder = PrimitiveBuilder; pub struct PrimitiveBuilder { values_builder: BufferBuilder, null_buffer_builder: NullBufferBuilder, + data_type: DataType, } impl ArrayBuilder for PrimitiveBuilder { @@ -150,6 +152,7 @@ impl PrimitiveBuilder { Self { values_builder: BufferBuilder::::new(capacity), null_buffer_builder: NullBufferBuilder::new(capacity), + data_type: T::DATA_TYPE, } } @@ -169,9 +172,29 @@ impl PrimitiveBuilder { Self { values_builder, null_buffer_builder, + data_type: T::DATA_TYPE, } } + /// By default [`PrimitiveBuilder`] uses [`ArrowPrimitiveType::DATA_TYPE`] as the + /// data type of the generated array. + /// + /// This method allows overriding the data type, to allow specifying timezones + /// for [`DataType::Timestamp`] or precision and scale for [`DataType::Decimal128`] + /// + /// # Panics + /// + /// This method panics if `data_type` is not [PrimitiveArray::is_compatible] + pub fn with_data_type(self, data_type: DataType) -> Self { + assert!( + PrimitiveArray::::is_compatible(&data_type), + "incompatible data type for builder, expected {} got {}", + T::DATA_TYPE, + data_type + ); + Self { data_type, ..self } + } + /// Returns the capacity of this builder measured in slots of type `T` pub fn capacity(&self) -> usize { self.values_builder.capacity() @@ -250,7 +273,7 @@ impl PrimitiveBuilder { pub fn finish(&mut self) -> PrimitiveArray { let len = self.len(); let null_bit_buffer = self.null_buffer_builder.finish(); - let builder = ArrayData::builder(T::DATA_TYPE) + let builder = ArrayData::builder(self.data_type.clone()) .len(len) .add_buffer(self.values_builder.finish()) .null_bit_buffer(null_bit_buffer); @@ -267,7 +290,7 @@ impl PrimitiveBuilder { .as_slice() .map(Buffer::from_slice_ref); let values_buffer = Buffer::from_slice_ref(self.values_builder.as_slice()); - let builder = ArrayData::builder(T::DATA_TYPE) + let builder = ArrayData::builder(self.data_type.clone()) .len(len) .add_buffer(values_buffer) .null_bit_buffer(null_bit_buffer); @@ -309,6 +332,7 @@ impl PrimitiveBuilder { mod tests { use super::*; use arrow_buffer::Buffer; + use arrow_schema::TimeUnit; use crate::array::Array; use crate::array::BooleanArray; @@ -528,4 +552,30 @@ mod tests { assert_eq!(5, arr.len()); assert_eq!(0, builder.len()); } + + #[test] + fn test_primitive_array_builder_with_data_type() { + let mut builder = + Decimal128Builder::new().with_data_type(DataType::Decimal128(1, 2)); + builder.append_value(1); + let array = builder.finish(); + assert_eq!(array.precision(), 1); + assert_eq!(array.scale(), 2); + + let data_type = + DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".to_string())); + let mut builder = + TimestampNanosecondBuilder::new().with_data_type(data_type.clone()); + builder.append_value(1); + let array = builder.finish(); + assert_eq!(array.data_type(), &data_type); + } + + #[test] + #[should_panic( + expected = "incompatible data type for builder, expected Int32 got Int64" + )] + fn test_invalid_with_data_type() { + Int32Builder::new().with_data_type(DataType::Int64); + } } diff --git a/arrow-array/src/builder/struct_builder.rs b/arrow-array/src/builder/struct_builder.rs index 12bcaf0944ef..ecf9ca4ffea7 100644 --- a/arrow-array/src/builder/struct_builder.rs +++ b/arrow-array/src/builder/struct_builder.rs @@ -115,9 +115,10 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box { Box::new(FixedSizeBinaryBuilder::with_capacity(capacity, *len)) } - DataType::Decimal128(_precision, _scale) => { - Box::new(Decimal128Builder::with_capacity(capacity)) - } + DataType::Decimal128(p, s) => Box::new( + Decimal128Builder::with_capacity(capacity) + .with_data_type(DataType::Decimal128(*p, *s)), + ), DataType::Utf8 => Box::new(StringBuilder::with_capacity(capacity, 1024)), DataType::Date32 => Box::new(Date32Builder::with_capacity(capacity)), DataType::Date64 => Box::new(Date64Builder::with_capacity(capacity)), @@ -133,18 +134,22 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box { Box::new(Time64NanosecondBuilder::with_capacity(capacity)) } - DataType::Timestamp(TimeUnit::Second, _) => { - Box::new(TimestampSecondBuilder::with_capacity(capacity)) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - Box::new(TimestampMillisecondBuilder::with_capacity(capacity)) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - Box::new(TimestampMicrosecondBuilder::with_capacity(capacity)) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - Box::new(TimestampNanosecondBuilder::with_capacity(capacity)) - } + DataType::Timestamp(TimeUnit::Second, tz) => Box::new( + TimestampSecondBuilder::with_capacity(capacity) + .with_data_type(DataType::Timestamp(TimeUnit::Second, tz.clone())), + ), + DataType::Timestamp(TimeUnit::Millisecond, tz) => Box::new( + TimestampMillisecondBuilder::with_capacity(capacity) + .with_data_type(DataType::Timestamp(TimeUnit::Millisecond, tz.clone())), + ), + DataType::Timestamp(TimeUnit::Microsecond, tz) => Box::new( + TimestampMicrosecondBuilder::with_capacity(capacity) + .with_data_type(DataType::Timestamp(TimeUnit::Microsecond, tz.clone())), + ), + DataType::Timestamp(TimeUnit::Nanosecond, tz) => Box::new( + TimestampNanosecondBuilder::with_capacity(capacity) + .with_data_type(DataType::Timestamp(TimeUnit::Nanosecond, tz.clone())), + ), DataType::Interval(IntervalUnit::YearMonth) => { Box::new(IntervalYearMonthBuilder::with_capacity(capacity)) } @@ -484,6 +489,33 @@ mod tests { assert!(builder.field_builder::(2).is_some()); } + #[test] + fn test_datatype_properties() { + let fields = vec![ + Field::new("f1", DataType::Decimal128(1, 2), false), + Field::new( + "f2", + DataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".to_string())), + false, + ), + ]; + let mut builder = StructBuilder::from_fields(fields.clone(), 1); + builder + .field_builder::(0) + .unwrap() + .append_value(1); + builder + .field_builder::(1) + .unwrap() + .append_value(1); + builder.append(true); + let array = builder.finish(); + + assert_eq!(array.data_type(), &DataType::Struct(fields.clone())); + assert_eq!(array.column(0).data_type(), fields[0].data_type()); + assert_eq!(array.column(1).data_type(), fields[1].data_type()); + } + #[test] #[should_panic( expected = "Data type List(Field { name: \"item\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) is not currently supported"