Skip to content

Commit

Permalink
Preserve DataType metadata in make_builder
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Jan 3, 2023
1 parent 17b3210 commit 2ec48c7
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 17 deletions.
55 changes: 53 additions & 2 deletions arrow-array/src/builder/primitive_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::types::*;
use crate::{ArrayRef, ArrowPrimitiveType, PrimitiveArray};
use arrow_buffer::{Buffer, MutableBuffer};
use arrow_data::ArrayData;
use arrow_schema::DataType;
use std::any::Any;
use std::sync::Arc;

Expand Down Expand Up @@ -94,6 +95,7 @@ pub type Decimal256Builder = PrimitiveBuilder<Decimal256Type>;
pub struct PrimitiveBuilder<T: ArrowPrimitiveType> {
values_builder: BufferBuilder<T::Native>,
null_buffer_builder: NullBufferBuilder,
data_type: DataType,
}

impl<T: ArrowPrimitiveType> ArrayBuilder for PrimitiveBuilder<T> {
Expand Down Expand Up @@ -150,6 +152,7 @@ impl<T: ArrowPrimitiveType> PrimitiveBuilder<T> {
Self {
values_builder: BufferBuilder::<T::Native>::new(capacity),
null_buffer_builder: NullBufferBuilder::new(capacity),
data_type: T::DATA_TYPE,
}
}

Expand All @@ -169,9 +172,30 @@ impl<T: ArrowPrimitiveType> PrimitiveBuilder<T> {
Self {
values_builder,
null_buffer_builder,
data_type: T::DATA_TYPE,
}
}

/// By default [`PrimitiveBuilder`] uses [`T::DATA_TYPE`] as the data type of
/// the generated array.
///
/// This method allows overriding the data type, to allow specifying timezones
/// for [`DataType::Timestamp`] or precision and scale for [`DataType::Decimal128`]
///
/// # Panics
///
/// This method panics if `data_type` is not the same variant as [`T::DATA_TYPE`]
pub fn with_data_type(self, data_type: DataType) -> Self {
assert_eq!(
std::mem::discriminant(&data_type),
std::mem::discriminant(&T::DATA_TYPE),
"incompatible data type for builder, expected {} got {}",
T::DATA_TYPE,
data_type
);
Self { data_type, ..self }
}

/// Returns the capacity of this builder measured in slots of type `T`
pub fn capacity(&self) -> usize {
self.values_builder.capacity()
Expand Down Expand Up @@ -250,7 +274,7 @@ impl<T: ArrowPrimitiveType> PrimitiveBuilder<T> {
pub fn finish(&mut self) -> PrimitiveArray<T> {
let len = self.len();
let null_bit_buffer = self.null_buffer_builder.finish();
let builder = ArrayData::builder(T::DATA_TYPE)
let builder = ArrayData::builder(self.data_type.clone())
.len(len)
.add_buffer(self.values_builder.finish())
.null_bit_buffer(null_bit_buffer);
Expand All @@ -267,7 +291,7 @@ impl<T: ArrowPrimitiveType> PrimitiveBuilder<T> {
.as_slice()
.map(Buffer::from_slice_ref);
let values_buffer = Buffer::from_slice_ref(self.values_builder.as_slice());
let builder = ArrayData::builder(T::DATA_TYPE)
let builder = ArrayData::builder(self.data_type.clone())
.len(len)
.add_buffer(values_buffer)
.null_bit_buffer(null_bit_buffer);
Expand Down Expand Up @@ -309,6 +333,7 @@ impl<T: ArrowPrimitiveType> PrimitiveBuilder<T> {
mod tests {
use super::*;
use arrow_buffer::Buffer;
use arrow_schema::TimeUnit;

use crate::array::Array;
use crate::array::BooleanArray;
Expand Down Expand Up @@ -528,4 +553,30 @@ mod tests {
assert_eq!(5, arr.len());
assert_eq!(0, builder.len());
}

#[test]
fn test_primitive_array_builder_with_data_type() {
let mut builder =
Decimal128Builder::new().with_data_type(DataType::Decimal128(1, 2));
builder.append_value(1);
let array = builder.finish();
assert_eq!(array.precision(), 1);
assert_eq!(array.scale(), 2);

let data_type =
DataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".to_string()));
let mut builder =
TimestampNanosecondBuilder::new().with_data_type(data_type.clone());
builder.append_value(1);
let array = builder.finish();
assert_eq!(array.data_type(), &data_type);
}

#[test]
#[should_panic(
expected = "incompatible data type for builder, expected Int32 got Int64"
)]
fn test_invalid_with_data_type() {
Int32Builder::new().with_data_type(DataType::Int64);
}
}
62 changes: 47 additions & 15 deletions arrow-array/src/builder/struct_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box<dyn ArrayBuilde
DataType::FixedSizeBinary(len) => {
Box::new(FixedSizeBinaryBuilder::with_capacity(capacity, *len))
}
DataType::Decimal128(_precision, _scale) => {
Box::new(Decimal128Builder::with_capacity(capacity))
}
DataType::Decimal128(p, s) => Box::new(
Decimal128Builder::with_capacity(capacity)
.with_data_type(DataType::Decimal128(*p, *s)),
),
DataType::Utf8 => Box::new(StringBuilder::with_capacity(capacity, 1024)),
DataType::Date32 => Box::new(Date32Builder::with_capacity(capacity)),
DataType::Date64 => Box::new(Date64Builder::with_capacity(capacity)),
Expand All @@ -133,18 +134,22 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box<dyn ArrayBuilde
DataType::Time64(TimeUnit::Nanosecond) => {
Box::new(Time64NanosecondBuilder::with_capacity(capacity))
}
DataType::Timestamp(TimeUnit::Second, _) => {
Box::new(TimestampSecondBuilder::with_capacity(capacity))
}
DataType::Timestamp(TimeUnit::Millisecond, _) => {
Box::new(TimestampMillisecondBuilder::with_capacity(capacity))
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
Box::new(TimestampMicrosecondBuilder::with_capacity(capacity))
}
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
Box::new(TimestampNanosecondBuilder::with_capacity(capacity))
}
DataType::Timestamp(TimeUnit::Second, tz) => Box::new(
TimestampSecondBuilder::with_capacity(capacity)
.with_data_type(DataType::Timestamp(TimeUnit::Second, tz.clone())),
),
DataType::Timestamp(TimeUnit::Millisecond, tz) => Box::new(
TimestampMillisecondBuilder::with_capacity(capacity)
.with_data_type(DataType::Timestamp(TimeUnit::Millisecond, tz.clone())),
),
DataType::Timestamp(TimeUnit::Microsecond, tz) => Box::new(
TimestampMicrosecondBuilder::with_capacity(capacity)
.with_data_type(DataType::Timestamp(TimeUnit::Microsecond, tz.clone())),
),
DataType::Timestamp(TimeUnit::Nanosecond, tz) => Box::new(
TimestampNanosecondBuilder::with_capacity(capacity)
.with_data_type(DataType::Timestamp(TimeUnit::Nanosecond, tz.clone())),
),
DataType::Interval(IntervalUnit::YearMonth) => {
Box::new(IntervalYearMonthBuilder::with_capacity(capacity))
}
Expand Down Expand Up @@ -484,6 +489,33 @@ mod tests {
assert!(builder.field_builder::<StructBuilder>(2).is_some());
}

#[test]
fn test_datatype_properties() {
let fields = vec![
Field::new("f1", DataType::Decimal128(1, 2), false),
Field::new(
"f2",
DataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".to_string())),
false,
),
];
let mut builder = StructBuilder::from_fields(fields.clone(), 1);
builder
.field_builder::<Decimal128Builder>(0)
.unwrap()
.append_value(1);
builder
.field_builder::<TimestampMillisecondBuilder>(1)
.unwrap()
.append_value(1);
builder.append(true);
let array = builder.finish();

assert_eq!(array.data_type(), &DataType::Struct(fields.clone()));
assert_eq!(array.column(0).data_type(), fields[0].data_type());
assert_eq!(array.column(1).data_type(), fields[1].data_type());
}

#[test]
#[should_panic(
expected = "Data type List(Field { name: \"item\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) is not currently supported"
Expand Down

0 comments on commit 2ec48c7

Please sign in to comment.