From c67a8a4a81f10626a7937837126a9ad3a2d428b0 Mon Sep 17 00:00:00 2001 From: Andrew Werner Date: Tue, 19 Nov 2024 23:30:29 -0500 Subject: [PATCH] vtab::arrow: support structs under list and array --- crates/duckdb/src/core/vector.rs | 10 +++ crates/duckdb/src/vtab/arrow.rs | 138 ++++++++++++++++++++++++++++++- 2 files changed, 144 insertions(+), 4 deletions(-) diff --git a/crates/duckdb/src/core/vector.rs b/crates/duckdb/src/core/vector.rs index 92e5622a..7ad6cbb7 100644 --- a/crates/duckdb/src/core/vector.rs +++ b/crates/duckdb/src/core/vector.rs @@ -161,6 +161,11 @@ impl ListVector { FlatVector::with_capacity(unsafe { duckdb_list_vector_get_child(self.entries.ptr) }, capacity) } + /// Returns the struct vector child. + pub fn struct_vector_child(&self) -> StructVector { + StructVector::from(unsafe { duckdb_list_vector_get_child(self.entries.ptr) }) + } + /// Set primitive data to the child node. pub fn set_child<T: Copy>(&self, data: &[T]) { self.child(data.len()).copy(data); @@ -227,6 +232,11 @@ impl ArrayVector { FlatVector::with_capacity(unsafe { duckdb_array_vector_get_child(self.ptr) }, capacity) } + /// Returns the struct vector child. + pub fn struct_vector_child(&self) -> StructVector { + StructVector::from(unsafe { duckdb_array_vector_get_child(self.ptr) }) + } + /// Set primitive data to the child node.
pub fn set_child<T: Copy>(&self, data: &[T]) { self.child(data.len()).copy(data); diff --git a/crates/duckdb/src/vtab/arrow.rs b/crates/duckdb/src/vtab/arrow.rs index 219f6f71..875dbf7f 100644 --- a/crates/duckdb/src/vtab/arrow.rs +++ b/crates/duckdb/src/vtab/arrow.rs @@ -523,17 +523,23 @@ fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>( out: &mut ListVector, ) -> Result<(), Box<dyn std::error::Error>> { let value_array = array.values(); - let mut child = out.child(value_array.len()); match value_array.data_type() { dt if dt.is_primitive() => { + let mut child = out.child(value_array.len()); primitive_array_to_vector(value_array.as_ref(), &mut child)?; } DataType::Utf8 => { + let mut child = out.child(value_array.len()); string_array_to_vector(as_string_array(value_array.as_ref()), &mut child); } DataType::Binary => { + let mut child = out.child(value_array.len()); binary_array_to_vector(as_generic_binary_array(value_array.as_ref()), &mut child); } + DataType::Struct(_) => { + let mut child = out.struct_vector_child(); + struct_array_to_vector(as_struct_array(value_array.as_ref()), &mut child)?; + } _ => { return Err("Nested list is not supported yet.".into()); } @@ -554,17 +560,23 @@ fn fixed_size_list_array_to_vector( out: &mut ArrayVector, ) -> Result<(), Box<dyn std::error::Error>> { let value_array = array.values(); - let mut child = out.child(value_array.len()); match value_array.data_type() { dt if dt.is_primitive() => { + let mut child = out.child(value_array.len()); primitive_array_to_vector(value_array.as_ref(), &mut child)?; } DataType::Utf8 => { + let mut child = out.child(value_array.len()); string_array_to_vector(as_string_array(value_array.as_ref()), &mut child); } DataType::Binary => { + let mut child = out.child(value_array.len()); binary_array_to_vector(as_generic_binary_array(value_array.as_ref()), &mut child); } + DataType::Struct(_) => { + let mut child = out.struct_vector_child(); + struct_array_to_vector(as_struct_array(value_array.as_ref()), &mut child)?; + } _ => { return Err("Nested array is not supported
yet.".into()); } @@ -703,17 +715,18 @@ mod test { use arrow::{ array::{ Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array, - DurationSecondArray, FixedSizeListArray, GenericByteArray, GenericListArray, Int32Array, + DurationSecondArray, FixedSizeListArray, GenericByteArray, GenericListArray, Int32Array, Int64Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeStringArray, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }, - buffer::{OffsetBuffer, ScalarBuffer}, + buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}, datatypes::{ i256, ArrowPrimitiveType, ByteArrayType, DataType, DurationSecondType, Field, Fields, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, Schema, }, record_batch::RecordBatch, + util::pretty, }; use std::{error::Error, sync::Arc}; @@ -1264,4 +1277,121 @@ mod test { assert_eq!(column.len(), 1); assert_eq!(column.value(0), b"test"); } + + #[test] + fn test_array_of_structs() -> Result<(), Box<dyn Error>> { + let db = Connection::open_in_memory()?; + db.register_table_function::<ArrowVTab>("arrow")?; + + // Create the inner struct field + let struct_field = Field::new("foo", DataType::Int64, true); + let struct_type = DataType::Struct(Fields::from(vec![struct_field])); + + // Build the Int64 values (including one null) for the struct's `foo` column + let mut int_builder = Int64Array::builder(4); + int_builder.append_value(1); + int_builder.append_value(2); + int_builder.append_value(3); + int_builder.append_value(4); + int_builder.append_null(); + int_builder.append_value(5); + + let struct_array = StructArray::new( + Fields::from(vec![Field::new("foo", DataType::Int64, true)]), + vec![Arc::new(int_builder.finish())], + None, + ); + + // Create fixed size list array of structs + let array = FixedSizeListArray::new( + Arc::new(Field::new("item",
struct_type, true)), + 2, + Arc::new(struct_array), + Some(NullBuffer::from([true, false, true].as_slice())), + ); + + // Create record batch + let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), true)]); + let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)])?; + + let param = arrow_recordbatch_to_query_params(rb); + let mut stmt = db.prepare("SELECT a[1].foo, a[2].foo FROM arrow(?, ?)")?; + let mut arr = stmt.query_arrow(param)?; + let rb = arr.next().expect("no record batch"); + let rb = [rb]; + let printed = pretty::pretty_format_batches(&rb).unwrap(); + assert_eq!( + "\ ++------------+------------+ +| (a[1]).foo | (a[2]).foo | ++------------+------------+ +| 1 | 2 | +| | | +| | 5 | ++------------+------------+", + printed.to_string(), + "{printed}" + ); + Ok(()) + } + + #[test] + fn test_array_of_lists() -> Result<(), Box<dyn Error>> { + let db = Connection::open_in_memory()?; + db.register_table_function::<ArrowVTab>("arrow")?; + + // Create the inner struct field + let struct_field = Field::new("foo", DataType::Int64, true); + let struct_type = DataType::Struct(Fields::from(vec![struct_field])); + + // Build the Int64 values (including one null) for the struct's `foo` column + let mut int_builder = Int64Array::builder(4); + int_builder.append_value(1); + int_builder.append_value(2); + int_builder.append_value(3); + int_builder.append_value(4); + int_builder.append_null(); + int_builder.append_value(5); + + let struct_array = StructArray::new( + Fields::from(vec![Field::new("foo", DataType::Int64, true)]), + vec![Arc::new(int_builder.finish())], + None, + ); + + // Create variable-length list array of structs (third entry is null) + let array = ListArray::new( + Arc::new(Field::new("item", struct_type, true)), + OffsetBuffer::from_lengths([1, 2, 0, 1, 0, 2]), + Arc::new(struct_array), + Some(NullBuffer::from([true, true, false, true, true, true].as_slice())), + ); + + // Create record batch + let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), true)]); + let rb =
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)])?; + + let param = arrow_recordbatch_to_query_params(rb); + let mut stmt = db.prepare("SELECT a FROM arrow(?, ?)")?; + let mut arr = stmt.query_arrow(param)?; + let rb = arr.next().expect("no record batch"); + let rb = [rb]; + let printed = pretty::pretty_format_batches(&rb).unwrap(); + assert_eq!( + "\ ++----------------------+ +| a | ++----------------------+ +| [{foo: 1}] | +| [{foo: 2}, {foo: 3}] | +| | +| [{foo: 4}] | +| [] | +| [{foo: }, {foo: 5}] | ++----------------------+", + printed.to_string(), + "{printed}" + ); + Ok(()) + } }