Skip to content

Commit

Permalink
Support more time types to arrow vtab (#289)
Browse files Browse the repository at this point in the history
* add more time types to arrow vtab

* clippy

* properly support non-tz timestamps

* don't compare timezones
  • Loading branch information
Maxxen authored Apr 11, 2024
1 parent b82db39 commit f85893f
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 17 deletions.
221 changes: 204 additions & 17 deletions src/vtab/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use super::{
use crate::vtab::vector::Inserter;
use arrow::array::{
as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, Array, ArrayData,
BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
StructArray,
AsArray, BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray,
StringArray, StructArray,
};

use arrow::{
Expand Down Expand Up @@ -138,9 +138,15 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
DataType::UInt64 => UBigint,
DataType::Float32 => Float,
DataType::Float64 => Double,
DataType::Timestamp(_, _) => Timestamp,
DataType::Date32 => Time,
DataType::Date64 => Time,
DataType::Timestamp(unit, None) => match unit {
TimeUnit::Second => TimestampS,
TimeUnit::Millisecond => TimestampMs,
TimeUnit::Microsecond => Timestamp,
TimeUnit::Nanosecond => TimestampNs,
},
DataType::Timestamp(_, Some(_)) => TimestampTZ,
DataType::Date32 => Date,
DataType::Date64 => Date,
DataType::Time32(_) => Time,
DataType::Time64(_) => Time,
DataType::Duration(_) => Interval,
Expand Down Expand Up @@ -250,6 +256,16 @@ fn primitive_array_to_flat_vector<T: ArrowPrimitiveType>(array: &PrimitiveArray<
out_vector.copy::<T::Native>(array.values());
}

/// Casts `array` to `data_type` with Arrow's cast kernel and copies the
/// resulting primitive values into `out_vector`.
///
/// `T` must be the Arrow primitive type that corresponds to `data_type`, since
/// the cast result is viewed as a `PrimitiveArray<T>` before its values are
/// copied.
///
/// # Panics
///
/// Panics if the cast kernel rejects the conversion, if `out_vector` is not a
/// [`FlatVector`], or if the cast result is not a `PrimitiveArray<T>`.
fn primitive_array_to_flat_vector_cast<T: ArrowPrimitiveType>(
    data_type: DataType,
    array: &dyn Array,
    out_vector: &mut dyn Vector,
) {
    // Name both types in the panic so an unsupported conversion is diagnosable.
    let casted = arrow::compute::kernels::cast::cast(array, &data_type).unwrap_or_else(|err| {
        panic!(
            "failed to cast arrow array from {:?} to {:?}: {}",
            array.data_type(),
            data_type,
            err
        )
    });
    let flat_vector: &mut FlatVector = out_vector
        .as_mut_any()
        .downcast_mut()
        .expect("output vector must be a FlatVector");
    flat_vector.copy::<T::Native>(casted.as_primitive::<T>().values());
}

fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
match array.data_type() {
DataType::Boolean => {
Expand Down Expand Up @@ -303,6 +319,7 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
out.as_mut_any().downcast_mut().unwrap(),
);
}
DataType::Float16 => todo!("Float16 is not supported yet"),
DataType::Float32 => {
primitive_array_to_flat_vector::<Float32Type>(
as_primitive_array(array),
Expand All @@ -324,22 +341,55 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
out.as_mut_any().downcast_mut().unwrap(),
);
}
// DataType::Decimal256(_, _) => {
// primitive_array_to_flat_vector::<Decimal256Type>(
// as_primitive_array(array),
// out.as_mut_any().downcast_mut().unwrap(),
// );
// }
_ => {
todo!()
DataType::Decimal256(_, _) => todo!("Decimal256 is not supported yet"),

// DuckDB only supports timestamp_tz in microsecond precision
DataType::Timestamp(_, Some(tz)) => primitive_array_to_flat_vector_cast::<TimestampMicrosecondType>(
DataType::Timestamp(TimeUnit::Microsecond, Some(tz.clone())),
array,
out,
),
DataType::Timestamp(unit, None) => match unit {
TimeUnit::Second => primitive_array_to_flat_vector::<TimestampSecondType>(
as_primitive_array(array),
out.as_mut_any().downcast_mut().unwrap(),
),
TimeUnit::Millisecond => primitive_array_to_flat_vector::<TimestampMillisecondType>(
as_primitive_array(array),
out.as_mut_any().downcast_mut().unwrap(),
),
TimeUnit::Microsecond => primitive_array_to_flat_vector::<TimestampMicrosecondType>(
as_primitive_array(array),
out.as_mut_any().downcast_mut().unwrap(),
),
TimeUnit::Nanosecond => primitive_array_to_flat_vector::<TimestampNanosecondType>(
as_primitive_array(array),
out.as_mut_any().downcast_mut().unwrap(),
),
},
DataType::Date32 => {
primitive_array_to_flat_vector::<Date32Type>(
as_primitive_array(array),
out.as_mut_any().downcast_mut().unwrap(),
);
}
DataType::Date64 => primitive_array_to_flat_vector_cast::<Date32Type>(Date32Type::DATA_TYPE, array, out),
DataType::Time32(_) => {
primitive_array_to_flat_vector_cast::<Time64MicrosecondType>(Time64MicrosecondType::DATA_TYPE, array, out)
}
DataType::Time64(_) => {
primitive_array_to_flat_vector_cast::<Time64MicrosecondType>(Time64MicrosecondType::DATA_TYPE, array, out)
}
_ => todo!(
"Converting '{dtype:#?}' to primitive flat vector is not supported",
dtype = array.data_type()
),
}
}

/// Convert Arrow [BooleanArray] to a duckdb vector.
/// Convert Arrow [Decimal128Array] to a duckdb vector.
fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector) {
assert!(array.len() <= out.capacity());

for i in 0..array.len() {
out.as_mut_slice()[i] = array.value_as_string(i).parse::<f64>().unwrap();
}
Expand Down Expand Up @@ -488,8 +538,12 @@ mod test {
use super::{arrow_recordbatch_to_query_params, ArrowVTab};
use crate::{Connection, Result};
use arrow::{
array::{Float64Array, Int32Array},
datatypes::{DataType, Field, Schema},
array::{
Array, AsArray, Date32Array, Date64Array, Float64Array, Int32Array, PrimitiveArray, StringArray,
Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray,
},
datatypes::{ArrowPrimitiveType, DataType, Field, Schema},
record_batch::RecordBatch,
};
use std::{error::Error, sync::Arc};
Expand Down Expand Up @@ -534,4 +588,137 @@ mod test {
assert_eq!(column.value(0), 15);
Ok(())
}

/// Sends `input_array` through DuckDB via the `arrow` table function and
/// asserts that the single column read back matches `expected_array`.
///
/// The data types must be equal, except for timestamps where only the time
/// unit is compared. Validity is compared slot by slot, and values are
/// compared for every valid slot.
fn check_rust_primitive_array_roundtrip<T1, T2>(
    input_array: PrimitiveArray<T1>,
    expected_array: PrimitiveArray<T2>,
) -> Result<(), Box<dyn Error>>
where
    T1: ArrowPrimitiveType,
    T2: ArrowPrimitiveType,
{
    let conn = Connection::open_in_memory()?;
    conn.register_table_function::<ArrowVTab>("arrow")?;

    // Rust -> DuckDB -> Rust: wrap the input in a one-column record batch,
    // hand it to the table function and read the first batch back.
    let field = Field::new("a", input_array.data_type().clone(), false);
    let batch = RecordBatch::try_new(
        Arc::new(Schema::new(vec![field])),
        vec![Arc::new(input_array.clone())],
    )?;
    let params = arrow_recordbatch_to_query_params(batch);
    let mut stmt = conn.prepare("select a from arrow(?, ?)")?;
    let result = stmt.query_arrow(params)?.next().expect("no record batch");

    let column = result.column(0);
    match (column.data_type(), expected_array.data_type()) {
        // TODO: DuckDB doesn't return timestamp_tz properly yet, so only the units are compared
        (DataType::Timestamp(actual_unit, _), DataType::Timestamp(expected_unit, _)) => {
            assert_eq!(actual_unit, expected_unit)
        }
        (actual_type, expected_type) => assert_eq!(actual_type, expected_type),
    }

    let output_array = column.as_primitive_opt::<T2>().unwrap_or_else(|| {
        panic!("Output array is not a PrimitiveArray {:?}", result.column(0).data_type())
    });

    // The output must match the expected array in length, validity and values.
    assert_eq!(output_array.len(), expected_array.len());
    for (actual, expected) in output_array.iter().zip(expected_array.iter()) {
        assert_eq!(actual.is_some(), expected.is_some());
        if let (Some(actual_value), Some(expected_value)) = (actual, expected) {
            assert_eq!(actual_value, expected_value);
        }
    }

    Ok(())
}

/// Round-trips integer, timestamp, date and time arrays through DuckDB and
/// checks the arrays that come back.
#[test]
fn test_timestamp_roundtrip() -> Result<(), Box<dyn Error>> {
    // Plain integers come back unchanged.
    let ints = Int32Array::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(ints.clone(), ints)?;

    // Timestamps without a timezone keep their unit.
    let naive_micros = TimestampMicrosecondArray::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(naive_micros.clone(), naive_micros)?;

    let naive_nanos = TimestampNanosecondArray::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(naive_nanos.clone(), naive_nanos)?;

    let naive_seconds = TimestampSecondArray::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(naive_seconds.clone(), naive_seconds)?;

    let naive_millis = TimestampMillisecondArray::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(naive_millis.clone(), naive_millis)?;

    // DuckDB can only return timestamp_tz in microseconds.
    // Note: DuckDB by default returns timestamp_tz with UTC because the Rust
    // driver doesn't support timestamp_tz properly when reading. In the
    // future we should be able to roundtrip timestamp_tz with other timezones too.
    let utc_nanos = TimestampNanosecondArray::from(vec![1000, 2000, 3000]).with_timezone_utc();
    let utc_nanos_as_micros = TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone_utc();
    check_rust_primitive_array_roundtrip(utc_nanos, utc_nanos_as_micros)?;

    let utc_millis = TimestampMillisecondArray::from(vec![1, 2, 3]).with_timezone_utc();
    let utc_millis_as_micros = TimestampMicrosecondArray::from(vec![1000, 2000, 3000]).with_timezone_utc();
    check_rust_primitive_array_roundtrip(utc_millis, utc_millis_as_micros)?;

    let utc_seconds = TimestampSecondArray::from(vec![1, 2, 3]).with_timezone_utc();
    let utc_seconds_as_micros =
        TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]).with_timezone_utc();
    check_rust_primitive_array_roundtrip(utc_seconds, utc_seconds_as_micros)?;

    let utc_micros = TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone_utc();
    check_rust_primitive_array_roundtrip(utc_micros.clone(), utc_micros)?;

    // Dates: Date32 is returned unchanged, Date64 (milliseconds) comes back as Date32 (days).
    let days = Date32Array::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(days.clone(), days)?;

    let millis_per_day = arrow::temporal_conversions::MILLISECONDS_IN_DAY;
    let days_in_millis = Date64Array::from(vec![millis_per_day, 2 * millis_per_day, 3 * millis_per_day]);
    check_rust_primitive_array_roundtrip(days_in_millis, Date32Array::from(vec![1, 2, 3]))?;

    // Time32 in seconds comes back as Time64 in microseconds.
    let seconds_of_day = Time32SecondArray::from(vec![1, 2, 3]);
    let micros_of_day = Time64MicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]);
    check_rust_primitive_array_roundtrip(seconds_of_day, micros_of_day)?;

    Ok(())
}

/// Checks that a timezone-aware timestamp array passed to the `arrow` table
/// function is reported by DuckDB as `TIMESTAMP WITH TIME ZONE`.
#[test]
fn test_timestamp_tz_insert() -> Result<(), Box<dyn Error>> {
    // TODO: This test should be reworked once we support TIMESTAMP_TZ properly

    let conn = Connection::open_in_memory()?;
    conn.register_table_function::<ArrowVTab>("arrow")?;

    let tz_array = TimestampMicrosecondArray::from(vec![1]).with_timezone("+05:00");
    let field = Field::new("a", tz_array.data_type().clone(), false);

    // Since we can't read TIMESTAMP_TZ back through the Rust client yet, only
    // the type name DuckDB reports for the inserted column is checked.
    let batch = RecordBatch::try_new(Arc::new(Schema::new(vec![field])), vec![Arc::new(tz_array)])
        .expect("failed to create record batch");
    let params = arrow_recordbatch_to_query_params(batch);
    let mut stmt = conn.prepare("select typeof(a)::VARCHAR from arrow(?, ?)")?;
    let result = stmt.query_arrow(params)?.next().expect("no record batch");

    assert_eq!(result.num_columns(), 1);
    let type_names = result.column(0).as_any().downcast_ref::<StringArray>().unwrap();
    assert_eq!(type_names.value(0), "TIMESTAMP WITH TIME ZONE");
    Ok(())
}
}
3 changes: 3 additions & 0 deletions src/vtab/logical_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ pub enum LogicalTypeId {
Uuid = DUCKDB_TYPE_DUCKDB_TYPE_UUID,
/// Union
Union = DUCKDB_TYPE_DUCKDB_TYPE_UNION,
/// Timestamp TZ
TimestampTZ = DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ,
}

impl From<u32> for LogicalTypeId {
Expand Down Expand Up @@ -100,6 +102,7 @@ impl From<u32> for LogicalTypeId {
DUCKDB_TYPE_DUCKDB_TYPE_MAP => Self::Map,
DUCKDB_TYPE_DUCKDB_TYPE_UUID => Self::Uuid,
DUCKDB_TYPE_DUCKDB_TYPE_UNION => Self::Union,
DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ => Self::TimestampTZ,
_ => panic!(),
}
}
Expand Down

0 comments on commit f85893f

Please sign in to comment.