diff --git a/bench-vortex/src/bin/notimplemented.rs b/bench-vortex/src/bin/notimplemented.rs index b77d2427ce..33b18bb4e2 100644 --- a/bench-vortex/src/bin/notimplemented.rs +++ b/bench-vortex/src/bin/notimplemented.rs @@ -1,6 +1,7 @@ #![feature(float_next_up_down)] use std::process::ExitCode; +use std::sync::Arc; use prettytable::{Cell, Row, Table}; use vortex::array::builder::VarBinBuilder; @@ -92,13 +93,11 @@ fn enc_impls() -> Vec { .into_array(), ConstantArray::new(10, 1).into_array(), DateTimePartsArray::try_new( - DType::Extension( - ExtDType::new( - TIME_ID.clone(), - Some(TemporalMetadata::Time(TimeUnit::S).into()), - ), - Nullability::NonNullable, - ), + DType::Extension(Arc::new(ExtDType::new( + TIME_ID.clone(), + Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)), + Some(TemporalMetadata::Time(TimeUnit::S).into()), + ))), PrimitiveArray::from(vec![1]).into_array(), PrimitiveArray::from(vec![0]).into_array(), PrimitiveArray::from(vec![0]).into_array(), diff --git a/encodings/datetime-parts/src/compute.rs b/encodings/datetime-parts/src/compute.rs index 5e57800580..5bebfb6ca2 100644 --- a/encodings/datetime-parts/src/compute.rs +++ b/encodings/datetime-parts/src/compute.rs @@ -5,9 +5,9 @@ use vortex::compute::{slice, take, ArrayCompute, SliceFn, TakeFn}; use vortex::validity::ArrayValidity; use vortex::{Array, ArrayDType, IntoArray, IntoArrayVariant}; use vortex_datetime_dtype::{TemporalMetadata, TimeUnit}; -use vortex_dtype::{DType, PType}; +use vortex_dtype::DType; use vortex_error::{vortex_bail, VortexResult, VortexUnwrap as _}; -use vortex_scalar::Scalar; +use vortex_scalar::{Scalar, ScalarValue}; use crate::DateTimePartsArray; @@ -51,22 +51,20 @@ impl SliceFn for DateTimePartsArray { impl ScalarAtFn for DateTimePartsArray { fn scalar_at(&self, index: usize) -> VortexResult { - let DType::Extension(ext, nullability) = self.dtype().clone() else { + let DType::Extension(ext) = self.dtype().clone() else { vortex_bail!( "DateTimePartsArray must have extension dtype, found {}", self.dtype() ); }; - let TemporalMetadata::Timestamp(time_unit, _) = TemporalMetadata::try_from(&ext)? else { + let TemporalMetadata::Timestamp(time_unit, _) = TemporalMetadata::try_from(ext.as_ref())? + else { vortex_bail!("Metadata must be Timestamp, found {}", ext.id()); }; if !self.is_valid(index) { - return Ok(Scalar::extension( - ext, - Scalar::null(DType::Primitive(PType::I64, nullability)), - )); + return Ok(Scalar::extension(ext, ScalarValue::Null)); } let divisor = match time_unit { @@ -83,10 +81,7 @@ impl ScalarAtFn for DateTimePartsArray { let scalar = days * 86_400 * divisor + seconds * divisor + subseconds; - Ok(Scalar::extension( - ext, - Scalar::primitive(scalar, nullability), - )) + Ok(Scalar::extension(ext, scalar.into())) } fn scalar_at_unchecked(&self, index: usize) -> Scalar { @@ -98,11 +93,11 @@ impl ScalarAtFn for DateTimePartsArray { /// /// Enforces that the passed array is actually a [DateTimePartsArray] with proper metadata. 
pub fn decode_to_temporal(array: &DateTimePartsArray) -> VortexResult { - let DType::Extension(ext, _) = array.dtype().clone() else { + let DType::Extension(ext) = array.dtype().clone() else { vortex_bail!(ComputeError: "expected dtype to be DType::Extension variant") }; - let Ok(temporal_metadata) = TemporalMetadata::try_from(&ext) else { + let Ok(temporal_metadata) = TemporalMetadata::try_from(ext.as_ref()) else { vortex_bail!(ComputeError: "must decode TemporalMetadata from extension metadata"); }; @@ -187,7 +182,7 @@ mod test { assert_eq!(validity, raw_millis.validity()); let date_times = DateTimePartsArray::try_new( - DType::Extension(temporal_array.ext_dtype().clone(), validity.nullability()), + DType::Extension(temporal_array.ext_dtype().clone()), days, seconds, subseconds, diff --git a/pyvortex/src/python_repr.rs b/pyvortex/src/python_repr.rs index 0a8cdf099c..2ee7709590 100644 --- a/pyvortex/src/python_repr.rs +++ b/pyvortex/src/python_repr.rs @@ -46,13 +46,18 @@ impl Display for DTypePythonRepr<'_> { n.python_repr() ), DType::List(edt, n) => write!(f, "list({}, {})", edt.python_repr(), n.python_repr()), - DType::Extension(ext, n) => { - write!(f, "ext(\"{}\", ", ext.id().python_repr())?; + DType::Extension(ext) => { + write!( + f, + "ext(\"{}\", {}, ", + ext.id().python_repr(), + ext.storage_dtype().python_repr() + )?; match ext.metadata() { None => write!(f, "None")?, Some(metadata) => write!(f, "{}", metadata.python_repr())?, }; - write!(f, ", {})", n.python_repr()) + write!(f, ")") } } } diff --git a/pyvortex/src/scalar.rs b/pyvortex/src/scalar.rs index dbf9a5ed1e..2f5e4aa7a3 100644 --- a/pyvortex/src/scalar.rs +++ b/pyvortex/src/scalar.rs @@ -14,7 +14,7 @@ use vortex_error::vortex_panic; use vortex_scalar::{PValue, Scalar, ScalarValue}; pub fn scalar_into_py(py: Python, x: Scalar, copy_into_python: bool) -> PyResult { - let (value, dtype) = x.into_parts(); + let (dtype, value) = x.into_parts(); scalar_value_into_py(py, value, &dtype, copy_into_python) } diff --git a/vortex-array/src/array/chunked/canonical.rs b/vortex-array/src/array/chunked/canonical.rs index 741d8a86d5..b50c4e53bc 100644 --- a/vortex-array/src/array/chunked/canonical.rs +++ b/vortex-array/src/array/chunked/canonical.rs @@ -69,7 +69,7 @@ pub(crate) fn try_canonicalize_chunks( // / \ // storage storage // - DType::Extension(ext_dtype, _) => { + DType::Extension(ext_dtype) => { // Recursively apply canonicalization and packing to the storage array backing // each chunk of the extension array. let storage_chunks: Vec = chunks @@ -78,11 +78,7 @@ pub(crate) fn try_canonicalize_chunks( // ExtensionArray, so we should canonicalize each chunk into ExtensionArray first. .map(|chunk| chunk.clone().into_extension().map(|ext| ext.storage())) .collect::>>()?; - let storage_dtype = storage_chunks - .first() - .ok_or_else(|| vortex_err!("Expected at least one chunk in ChunkedArray"))? 
- .dtype() - .clone(); + let storage_dtype = ext_dtype.storage_dtype().clone(); let chunked_storage = ChunkedArray::try_new(storage_chunks, storage_dtype)?.into_array(); diff --git a/vortex-array/src/array/constant/variants.rs b/vortex-array/src/array/constant/variants.rs index 81e3a590f5..5ff4d756e5 100644 --- a/vortex-array/src/array/constant/variants.rs +++ b/vortex-array/src/array/constant/variants.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use vortex_dtype::field::Field; use vortex_dtype::{DType, PType}; -use vortex_error::{vortex_panic, VortexError, VortexExpect as _, VortexResult}; -use vortex_scalar::{ExtScalar, Scalar, ScalarValue, StructScalar}; +use vortex_error::{VortexError, VortexExpect as _, VortexResult}; +use vortex_scalar::{Scalar, ScalarValue, StructScalar}; use crate::array::constant::ConstantArray; use crate::iter::{Accessor, AccessorRef}; @@ -211,22 +211,9 @@ impl ListArrayTrait for ConstantArray {} impl ExtensionArrayTrait for ConstantArray { fn storage_array(&self) -> Array { - let scalar_ext = ExtScalar::try_new(self.dtype(), self.scalar_value()) - .vortex_expect("Expected an extension scalar"); - - // FIXME(ngates): there's not enough information to get the storage array. - let n = self.dtype().nullability(); - let storage_dtype = match scalar_ext.value() { - ScalarValue::Bool(_) => DType::Binary(n), - ScalarValue::Primitive(pvalue) => DType::Primitive(pvalue.ptype(), n), - ScalarValue::Buffer(_) => DType::Binary(n), - ScalarValue::BufferString(_) => DType::Utf8(n), - ScalarValue::List(_) => vortex_panic!("List not supported"), - ScalarValue::Null => DType::Null, - }; - + let storage_dtype = self.ext_dtype().storage_dtype().clone(); ConstantArray::new( - Scalar::new(storage_dtype, scalar_ext.value().clone()), + Scalar::new(storage_dtype, self.scalar_value().clone()), self.len(), ) .into_array() diff --git a/vortex-array/src/array/datetime/mod.rs b/vortex-array/src/array/datetime/mod.rs index 9f8a2cab1c..06accaa405 100644 --- a/vortex-array/src/array/datetime/mod.rs +++ b/vortex-array/src/array/datetime/mod.rs @@ -1,6 +1,8 @@ #[cfg(test)] mod test; +use std::sync::Arc; + use vortex_datetime_dtype::{TemporalMetadata, TimeUnit, DATE_ID, TIMESTAMP_ID, TIME_ID}; use vortex_dtype::{DType, ExtDType}; use vortex_error::{vortex_panic, VortexError}; @@ -68,28 +70,24 @@ impl TemporalArray { /// /// If any other time unit is provided, it panics. 
pub fn new_date(array: Array, time_unit: TimeUnit) -> Self { - let ext_dtype = match time_unit { + match time_unit { TimeUnit::D => { assert_width!(i32, array); - - ExtDType::new( - DATE_ID.clone(), - Some(TemporalMetadata::Date(time_unit).into()), - ) } TimeUnit::Ms => { assert_width!(i64, array); - - ExtDType::new( - DATE_ID.clone(), - Some(TemporalMetadata::Date(time_unit).into()), - ) } _ => vortex_panic!("invalid TimeUnit {time_unit} for vortex.date"), }; + let ext_dtype = ExtDType::new( + DATE_ID.clone(), + Arc::new(array.dtype().clone()), + Some(TemporalMetadata::Date(time_unit).into()), + ); + Self { - ext: ExtensionArray::new(ext_dtype, array), + ext: ExtensionArray::new(Arc::new(ext_dtype), array), temporal_metadata: TemporalMetadata::Date(time_unit), } } @@ -123,7 +121,11 @@ impl TemporalArray { let temporal_metadata = TemporalMetadata::Time(time_unit); Self { ext: ExtensionArray::new( - ExtDType::new(TIME_ID.clone(), Some(temporal_metadata.clone().into())), + Arc::new(ExtDType::new( + TIME_ID.clone(), + Arc::new(array.dtype().clone()), + Some(temporal_metadata.clone().into()), + )), array, ), temporal_metadata, @@ -145,7 +147,11 @@ impl TemporalArray { Self { ext: ExtensionArray::new( - ExtDType::new(TIMESTAMP_ID.clone(), Some(temporal_metadata.clone().into())), + Arc::new(ExtDType::new( + TIMESTAMP_ID.clone(), + Arc::new(array.dtype().clone()), + Some(temporal_metadata.clone().into()), + )), array, ), temporal_metadata, @@ -171,8 +177,8 @@ impl TemporalArray { } /// Retrieve the extension DType associated with the underlying array. - pub fn ext_dtype(&self) -> &ExtDType { - self.ext.ext_dtype() + pub fn ext_dtype(&self) -> Arc { + self.ext.ext_dtype().clone() } } @@ -195,7 +201,7 @@ impl TryFrom<&Array> for TemporalArray { /// `TemporalMetadata` variants, an error is returned. 
fn try_from(value: &Array) -> Result { let ext = ExtensionArray::try_from(value)?; - let temporal_metadata = TemporalMetadata::try_from(ext.ext_dtype())?; + let temporal_metadata = TemporalMetadata::try_from(ext.ext_dtype().as_ref())?; Ok(Self { ext, @@ -232,7 +238,7 @@ impl TryFrom for TemporalArray { type Error = VortexError; fn try_from(ext: ExtensionArray) -> Result { - let temporal_metadata = TemporalMetadata::try_from(ext.ext_dtype())?; + let temporal_metadata = TemporalMetadata::try_from(ext.ext_dtype().as_ref())?; Ok(Self { ext, temporal_metadata, diff --git a/vortex-array/src/array/extension/compute.rs b/vortex-array/src/array/extension/compute.rs index 97c971013e..f550dc7293 100644 --- a/vortex-array/src/array/extension/compute.rs +++ b/vortex-array/src/array/extension/compute.rs @@ -60,14 +60,14 @@ impl ScalarAtFn for ExtensionArray { fn scalar_at(&self, index: usize) -> VortexResult { Ok(Scalar::extension( self.ext_dtype().clone(), - scalar_at(self.storage(), index)?, + scalar_at(self.storage(), index)?.into_value(), )) } fn scalar_at_unchecked(&self, index: usize) -> Scalar { Scalar::extension( self.ext_dtype().clone(), - scalar_at_unchecked(self.storage(), index), + scalar_at_unchecked(self.storage(), index).into_value(), ) } } diff --git a/vortex-array/src/array/extension/mod.rs b/vortex-array/src/array/extension/mod.rs index b742feaa78..9ce1deebcc 100644 --- a/vortex-array/src/array/extension/mod.rs +++ b/vortex-array/src/array/extension/mod.rs @@ -1,4 +1,5 @@ use std::fmt::{Debug, Display}; +use std::sync::Arc; use serde::{Deserialize, Serialize}; use vortex_dtype::{DType, ExtDType, ExtID}; @@ -16,9 +17,7 @@ mod compute; impl_encoding!("vortex.ext", ids::EXTENSION, Extension); #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ExtensionMetadata { - storage_dtype: DType, -} +pub struct ExtensionMetadata; impl Display for ExtensionMetadata { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -27,13 +26,17 @@ impl Display for ExtensionMetadata { } impl ExtensionArray { - pub fn new(ext_dtype: ExtDType, storage: Array) -> Self { + pub fn new(ext_dtype: Arc, storage: Array) -> Self { + assert_eq!( + ext_dtype.storage_dtype(), + storage.dtype(), + "ExtensionArray: storage_dtype must match storage array DType", + ); + Self::try_from_parts( - DType::Extension(ext_dtype, storage.dtype().nullability()), + DType::Extension(ext_dtype), storage.len(), - ExtensionMetadata { - storage_dtype: storage.dtype().clone(), - }, + ExtensionMetadata, [storage].into(), Default::default(), ) @@ -42,7 +45,7 @@ impl ExtensionArray { pub fn storage(&self) -> Array { self.as_ref() - .child(0, &self.metadata().storage_dtype, self.len()) + .child(0, self.ext_dtype().storage_dtype(), self.len()) .vortex_expect("Missing storage array for ExtensionArray") } diff --git a/vortex-array/src/arrow/dtype.rs b/vortex-array/src/arrow/dtype.rs index 6d39c7f05e..3c2a884a4a 100644 --- a/vortex-array/src/arrow/dtype.rs +++ b/vortex-array/src/arrow/dtype.rs @@ -83,10 +83,9 @@ impl FromArrowType<&Field> for DType { | DataType::Date64 | DataType::Time32(_) | DataType::Time64(_) - | DataType::Timestamp(..) => Extension( - make_temporal_ext_dtype(field.data_type()), - field.is_nullable().into(), - ), + | DataType::Timestamp(..) 
=> Extension(Arc::new( + make_temporal_ext_dtype(field.data_type()).with_nullability(nullability), + )), DataType::List(e) | DataType::LargeList(e) => { List(Arc::new(Self::from_arrow(e.as_ref())), nullability) } @@ -171,7 +170,7 @@ pub fn infer_data_type(dtype: &DType) -> VortexResult { // (32-bit), Large List View (64-bit). We cannot both guarantee zero-copy and commit to an // Arrow dtype because we do not how large our offsets are. DType::List(..) => vortex_bail!("Unsupported dtype: {}", dtype), - DType::Extension(ext_dtype, _) => { + DType::Extension(ext_dtype) => { // Try and match against the known extension DTypes. if is_temporal_ext_type(ext_dtype.id()) { make_arrow_temporal_dtype(ext_dtype) @@ -234,10 +233,11 @@ mod test { #[test] #[should_panic] fn test_dtype_conversion_panics() { - let _ = infer_data_type(&DType::Extension( - ExtDType::new(ExtID::from("my-fake-ext-dtype"), None), - Nullability::NonNullable, - )) + let _ = infer_data_type(&DType::Extension(Arc::new(ExtDType::new( + ExtID::from("my-fake-ext-dtype"), + Arc::new(DType::Utf8(Nullability::NonNullable)), + None, + )))) .unwrap(); } diff --git a/vortex-array/src/canonical.rs b/vortex-array/src/canonical.rs index a434f04929..abc1307a2a 100644 --- a/vortex-array/src/canonical.rs +++ b/vortex-array/src/canonical.rs @@ -76,10 +76,14 @@ impl Canonical { Canonical::Struct(a) => struct_to_arrow(a)?, Canonical::VarBinView(a) => varbinview_as_arrow(&a), Canonical::Extension(a) => { - if !is_temporal_ext_type(a.id()) { - vortex_bail!("unsupported extension dtype with ID {}", a.id().as_ref()) + if is_temporal_ext_type(a.id()) { + temporal_to_arrow(TemporalArray::try_from(&a.into_array())?)? + } else { + // Convert storage array directly into arrow, losing type information + // that will let us round-trip. + // TODO(aduffy): https://github.com/spiraldb/vortex/issues/1167 + a.storage().into_canonical()?.into_arrow()? } - temporal_to_arrow(TemporalArray::try_from(&a.into_array())?)? } }) } diff --git a/vortex-array/src/variants.rs b/vortex-array/src/variants.rs index 8730dbd722..7249110e11 100644 --- a/vortex-array/src/variants.rs +++ b/vortex-array/src/variants.rs @@ -4,6 +4,7 @@ //! encoding, they can use these traits to write encoding-agnostic code. use std::ops::Not; +use std::sync::Arc; use vortex_dtype::field::Field; use vortex_dtype::{DType, ExtDType, FieldNames}; @@ -256,8 +257,8 @@ pub trait ListArrayTrait: ArrayTrait {} pub trait ExtensionArrayTrait: ArrayTrait { /// Returns the extension logical [`DType`]. - fn ext_dtype(&self) -> &ExtDType { - let DType::Extension(ext_dtype, _nullability) = self.dtype() else { + fn ext_dtype(&self) -> &Arc { + let DType::Extension(ext_dtype) = self.dtype() else { vortex_panic!("Expected ExtDType") }; ext_dtype diff --git a/vortex-datafusion/src/datatype.rs b/vortex-datafusion/src/datatype.rs new file mode 100644 index 0000000000..65db1e54e0 --- /dev/null +++ b/vortex-datafusion/src/datatype.rs @@ -0,0 +1,204 @@ +//! Convert between Vortex [vortex_dtype::DType] and Apache Arrow [arrow_schema::DataType]. +//! +//! Apache Arrow's type system includes physical information, which could lead to ambiguities as +//! Vortex treats encodings as separate from logical types. +//! +//! [`infer_schema`] and its sibling [`infer_data_type`] use a simple algorithm, where every +//! logical type is encoded in its simplest corresponding Arrow type. This reflects the reality that +//! most compute engines don't make use of the entire type range arrow-rs supports. +//! +//! 
For this reason, it's recommended to do as much computation as possible within Vortex, and then +//! materialize an Arrow ArrayRef at the very end of the processing chain. + +use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaBuilder}; +use vortex_datetime_dtype::arrow::make_arrow_temporal_dtype; +use vortex_datetime_dtype::is_temporal_ext_type; +use vortex_dtype::{DType, Nullability, PType}; +use vortex_error::vortex_panic; + +/// Convert a Vortex [struct DType][DType] to an Arrow [Schema]. +/// +/// # Panics +/// +/// This function will panic if the provided `dtype` is not a StructDType, or if the struct DType +/// has top-level nullability. +pub(crate) fn infer_schema(dtype: &DType) -> Schema { + let DType::Struct(struct_dtype, nullable) = dtype else { + vortex_panic!("only DType::Struct can be converted to arrow schema"); + }; + + if *nullable != Nullability::NonNullable { + vortex_panic!("top-level struct in Schema must be NonNullable"); + } + + let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len()); + for (field_name, field_dtype) in struct_dtype + .names() + .iter() + .zip(struct_dtype.dtypes().iter()) + { + builder.push(FieldRef::from(Field::new( + field_name.to_string(), + infer_data_type(field_dtype), + field_dtype.is_nullable(), + ))); + } + + builder.finish() +} + +pub(crate) fn infer_data_type(dtype: &DType) -> DataType { + match dtype { + DType::Null => DataType::Null, + DType::Bool(_) => DataType::Boolean, + DType::Primitive(ptype, _) => match ptype { + PType::U8 => DataType::UInt8, + PType::U16 => DataType::UInt16, + PType::U32 => DataType::UInt32, + PType::U64 => DataType::UInt64, + PType::I8 => DataType::Int8, + PType::I16 => DataType::Int16, + PType::I32 => DataType::Int32, + PType::I64 => DataType::Int64, + PType::F16 => DataType::Float16, + PType::F32 => DataType::Float32, + PType::F64 => DataType::Float64, + }, + DType::Utf8(_) => DataType::Utf8, + DType::Binary(_) => DataType::Binary, + DType::Struct(struct_dtype, _) => { + let mut fields = Vec::with_capacity(struct_dtype.names().len()); + for (field_name, field_dt) in struct_dtype + .names() + .iter() + .zip(struct_dtype.dtypes().iter()) + { + fields.push(FieldRef::from(Field::new( + field_name.to_string(), + infer_data_type(field_dt), + field_dt.is_nullable(), + ))); + } + + DataType::Struct(Fields::from(fields)) + } + DType::List(list_dt, _) => { + let dtype: &DType = list_dt; + DataType::List(FieldRef::from(Field::new( + "element", + infer_data_type(dtype), + dtype.is_nullable(), + ))) + } + DType::Extension(ext_dtype) => { + // Special case: the Vortex logical type system represents many temporal types from + // Arrow, and we want those to serialize properly. + if is_temporal_ext_type(ext_dtype.id()) { + make_arrow_temporal_dtype(ext_dtype) + } else { + // All other extension types, we rely on the scalar type to determine how it gets + // pushed to Arrow. 
+ infer_data_type(ext_dtype.storage_dtype()) + } + } + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow_schema::{DataType, Field, FieldRef, Fields, Schema}; + use vortex_dtype::{DType, FieldName, FieldNames, Nullability, PType, StructDType}; + + use super::*; + + #[test] + fn test_dtype_conversion_success() { + assert_eq!(infer_data_type(&DType::Null), DataType::Null); + + assert_eq!( + infer_data_type(&DType::Bool(Nullability::NonNullable)), + DataType::Boolean + ); + + assert_eq!( + infer_data_type(&DType::Primitive(PType::U64, Nullability::NonNullable)), + DataType::UInt64 + ); + + assert_eq!( + infer_data_type(&DType::Utf8(Nullability::NonNullable)), + DataType::Utf8 + ); + + assert_eq!( + infer_data_type(&DType::Binary(Nullability::NonNullable)), + DataType::Binary + ); + + assert_eq!( + infer_data_type(&DType::List( + Arc::new(DType::Bool(Nullability::NonNullable)), + Nullability::Nullable, + )), + DataType::List(FieldRef::from(Field::new( + "element".to_string(), + DataType::Boolean, + false, + ))) + ); + + assert_eq!( + infer_data_type(&DType::Struct( + StructDType::new( + FieldNames::from(vec![FieldName::from("field_a"), FieldName::from("field_b")]), + vec![DType::Bool(false.into()), DType::Utf8(true.into())], + ), + Nullability::NonNullable, + )), + DataType::Struct(Fields::from(vec![ + FieldRef::from(Field::new("field_a", DataType::Boolean, false)), + FieldRef::from(Field::new("field_b", DataType::Utf8, true)), + ])) + ); + } + + #[test] + fn test_schema_conversion() { + let struct_dtype = the_struct(); + let schema_nonnull = DType::Struct(struct_dtype.clone(), Nullability::NonNullable); + + assert_eq!( + infer_schema(&schema_nonnull), + Schema::new(Fields::from(vec![ + Field::new("field_a", DataType::Boolean, false), + Field::new("field_b", DataType::Utf8, false), + Field::new("field_c", DataType::Int32, true), + ])) + ); + } + + #[test] + #[should_panic] + fn test_schema_conversion_panics() { + let struct_dtype = the_struct(); + let schema_null = DType::Struct(struct_dtype.clone(), Nullability::Nullable); + let _ = infer_schema(&schema_null); + } + + fn the_struct() -> StructDType { + StructDType::new( + FieldNames::from([ + FieldName::from("field_a"), + FieldName::from("field_b"), + FieldName::from("field_c"), + ]), + vec![ + DType::Bool(Nullability::NonNullable), + DType::Utf8(Nullability::NonNullable), + DType::Primitive(PType::I32, Nullability::Nullable), + ], + ) + } +} diff --git a/vortex-datetime-dtype/src/arrow.rs b/vortex-datetime-dtype/src/arrow.rs index e334bf377a..335ed1f6d4 100644 --- a/vortex-datetime-dtype/src/arrow.rs +++ b/vortex-datetime-dtype/src/arrow.rs @@ -1,7 +1,9 @@ #![cfg(feature = "arrow")] +use std::sync::Arc; + use arrow_schema::{DataType, TimeUnit as ArrowTimeUnit}; -use vortex_dtype::ExtDType; +use vortex_dtype::{ExtDType, PType}; use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexExpect as _, VortexResult}; use crate::temporal::{TemporalMetadata, DATE_ID, TIMESTAMP_ID, TIME_ID}; @@ -17,9 +19,10 @@ pub fn make_temporal_ext_dtype(data_type: &DataType) -> ExtDType { DataType::Timestamp(time_unit, time_zone) => { let time_unit = TimeUnit::from(time_unit); let tz = time_zone.clone().map(|s| s.to_string()); - + // PType is inferred for arrow based on the time units. 
ExtDType::new( TIMESTAMP_ID.clone(), + Arc::new(PType::I64.into()), Some(TemporalMetadata::Timestamp(time_unit, tz).into()), ) } @@ -27,6 +30,7 @@ pub fn make_temporal_ext_dtype(data_type: &DataType) -> ExtDType { let time_unit = TimeUnit::from(time_unit); ExtDType::new( TIME_ID.clone(), + Arc::new(PType::I32.into()), Some(TemporalMetadata::Time(time_unit).into()), ) } @@ -34,15 +38,18 @@ pub fn make_temporal_ext_dtype(data_type: &DataType) -> ExtDType { let time_unit = TimeUnit::from(time_unit); ExtDType::new( TIME_ID.clone(), + Arc::new(PType::I64.into()), Some(TemporalMetadata::Time(time_unit).into()), ) } DataType::Date32 => ExtDType::new( DATE_ID.clone(), + Arc::new(PType::I32.into()), Some(TemporalMetadata::Date(TimeUnit::D).into()), ), DataType::Date64 => ExtDType::new( DATE_ID.clone(), + Arc::new(PType::I64.into()), Some(TemporalMetadata::Date(TimeUnit::Ms).into()), ), _ => unimplemented!("{data_type} conversion"), @@ -123,6 +130,7 @@ mod tests { fn test_make_arrow_timestamp() { let ext_dtype = ExtDType::new( TIMESTAMP_ID.clone(), + Arc::new(PType::I64.into()), Some(TemporalMetadata::Timestamp(TimeUnit::Ms, None).into()), ); let expected_arrow_type = DataType::Timestamp(ArrowTimeUnit::Millisecond, None); @@ -138,6 +146,7 @@ mod tests { fn test_make_arrow_time32() { let ext_dtype = ExtDType::new( TIME_ID.clone(), + Arc::new(PType::I32.into()), Some(TemporalMetadata::Time(TimeUnit::Ms).into()), ); let expected_arrow_type = DataType::Time32(ArrowTimeUnit::Millisecond); @@ -152,6 +161,7 @@ mod tests { fn test_make_arrow_time64() { let ext_dtype = ExtDType::new( TIME_ID.clone(), + Arc::new(PType::I64.into()), Some(TemporalMetadata::Time(TimeUnit::Us).into()), ); let expected_arrow_type = DataType::Time64(ArrowTimeUnit::Microsecond); @@ -166,6 +176,7 @@ mod tests { fn test_make_arrow_date32() { let ext_dtype = ExtDType::new( DATE_ID.clone(), + Arc::new(PType::I32.into()), Some(TemporalMetadata::Date(TimeUnit::D).into()), ); let expected_arrow_type = DataType::Date32; @@ -180,6 +191,7 @@ mod tests { fn test_make_arrow_date64() { let ext_dtype = ExtDType::new( DATE_ID.clone(), + Arc::new(PType::I64.into()), Some(TemporalMetadata::Date(TimeUnit::Ms).into()), ); let expected_arrow_type = DataType::Date64; diff --git a/vortex-datetime-dtype/src/temporal.rs b/vortex-datetime-dtype/src/temporal.rs index 696322bb5b..eb46dfc143 100644 --- a/vortex-datetime-dtype/src/temporal.rs +++ b/vortex-datetime-dtype/src/temporal.rs @@ -1,4 +1,5 @@ use std::fmt::Display; +use std::sync::Arc; use jiff::civil::{Date, Time}; use jiff::{Timestamp, Zoned}; @@ -100,24 +101,33 @@ impl TemporalMetadata { use vortex_dtype::{ExtDType, ExtMetadata}; use vortex_error::{vortex_bail, vortex_err, vortex_panic, VortexError, VortexResult}; -impl TryFrom<&ExtDType> for TemporalMetadata { - type Error = VortexError; - - fn try_from(ext_dtype: &ExtDType) -> Result { - let metadata = ext_dtype - .metadata() - .ok_or_else(|| vortex_err!("ExtDType is missing metadata"))?; - match ext_dtype.id().as_ref() { - x if x == TIME_ID.as_ref() => decode_time_metadata(metadata), - x if x == DATE_ID.as_ref() => decode_date_metadata(metadata), - x if x == TIMESTAMP_ID.as_ref() => decode_timestamp_metadata(metadata), - _ => { - vortex_bail!("ExtDType must be one of the known temporal types") +macro_rules! 
impl_temporal_metadata_try_from {
+    ($typ:ty) => {
+        impl TryFrom<$typ> for TemporalMetadata {
+            type Error = VortexError;
+
+            fn try_from(ext_dtype: $typ) -> Result<Self, Self::Error> {
+                let metadata = ext_dtype
+                    .metadata()
+                    .ok_or_else(|| vortex_err!("ExtDType is missing metadata"))?;
+                match ext_dtype.id().as_ref() {
+                    x if x == TIME_ID.as_ref() => decode_time_metadata(metadata),
+                    x if x == DATE_ID.as_ref() => decode_date_metadata(metadata),
+                    x if x == TIMESTAMP_ID.as_ref() => decode_timestamp_metadata(metadata),
+                    _ => {
+                        vortex_bail!("ExtDType must be one of the known temporal types")
+                    }
+                }
+            }
         }
-    }
+    };
 }
 
+impl_temporal_metadata_try_from!(ExtDType);
+impl_temporal_metadata_try_from!(&ExtDType);
+impl_temporal_metadata_try_from!(Arc<ExtDType>);
+impl_temporal_metadata_try_from!(Box<ExtDType>);
+
 fn decode_date_metadata(ext_meta: &ExtMetadata) -> VortexResult<TemporalMetadata> {
     let tag = ext_meta.as_ref()[0];
     let time_unit =
@@ -188,7 +198,9 @@ impl From<TemporalMetadata> for ExtMetadata {
 
 #[cfg(test)]
 mod tests {
-    use vortex_dtype::{ExtDType, ExtMetadata};
+    use std::sync::Arc;
+
+    use vortex_dtype::{ExtDType, ExtMetadata, PType};
 
     use crate::{TemporalMetadata, TimeUnit, TIMESTAMP_ID};
 
@@ -207,8 +219,12 @@ mod tests {
             .as_slice()
         );
 
-        let temporal_metadata =
-            TemporalMetadata::try_from(&ExtDType::new(TIMESTAMP_ID.clone(), Some(meta))).unwrap();
+        let temporal_metadata = TemporalMetadata::try_from(&ExtDType::new(
+            TIMESTAMP_ID.clone(),
+            Arc::new(PType::I64.into()),
+            Some(meta),
+        ))
+        .unwrap();
 
         assert_eq!(
             temporal_metadata,
diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs
index b97dbb4bee..59826cfeba 100644
--- a/vortex-dtype/src/dtype.rs
+++ b/vortex-dtype/src/dtype.rs
@@ -36,8 +36,8 @@ pub enum DType {
     Struct(StructDType, Nullability),
     /// A variable-length list type, parameterized by a single element DType
     List(Arc<DType>, Nullability),
-    /// Extension types are user-defined types
-    Extension(ExtDType, Nullability),
+    /// User-defined extension types
+    Extension(Arc<ExtDType>),
 }
 
 impl DType {
@@ -64,7 +64,7 @@ impl DType {
             Binary(n) => matches!(n, Nullable),
             Struct(_, n) => matches!(n, Nullable),
             List(_, n) => matches!(n, Nullable),
-            Extension(_, n) => matches!(n, Nullable),
+            Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
         }
     }
 
@@ -88,7 +88,7 @@ impl DType {
             Binary(_) => Binary(nullability),
             Struct(st, _) => Struct(st.clone(), nullability),
             List(c, _) => List(c.clone(), nullability),
-            Extension(ext, _) => Extension(ext.clone(), nullability),
+            Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
         }
     }
 
@@ -155,14 +155,16 @@ impl Display for DType {
                 n
             ),
             List(edt, n) => write!(f, "list({}){}", edt, n),
-            Extension(ext, n) => write!(
+            Extension(ext) => write!(
                 f,
-                "ext({}{}){}",
+                "ext({}, {}{}){}",
                 ext.id(),
+                ext.storage_dtype()
+                    .with_nullability(Nullability::NonNullable),
                 ext.metadata()
                     .map(|m| format!(", {:?}", m))
                     .unwrap_or_else(|| "".to_string()),
-                n
+                ext.storage_dtype().nullability(),
             ),
         }
     }
diff --git a/vortex-dtype/src/extension.rs b/vortex-dtype/src/extension.rs
index 8231450582..066c3df0af 100644
--- a/vortex-dtype/src/extension.rs
+++ b/vortex-dtype/src/extension.rs
@@ -1,6 +1,8 @@
 use std::fmt::{Display, Formatter};
 use std::sync::Arc;
 
+use crate::{DType, Nullability};
+
 /// A unique identifier for an extension type
 #[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
 #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
@@ -60,13 +62,52 @@ impl From<&[u8]> for ExtMetadata {
 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
 pub struct ExtDType {
     id: ExtID,
+    storage_dtype: Arc<DType>,
     metadata: Option<ExtMetadata>,
 }
 
 impl ExtDType {
-    /// Constructs a new `ExtDType` from an `ExtID` and optional `ExtMetadata`
-    pub fn new(id: ExtID, metadata: Option<ExtMetadata>) -> Self {
-        Self { id, metadata }
+    /// Creates a new `ExtDType`.
+    ///
+    /// Extension data types in Vortex allow library users to express additional semantic meaning
+    /// on top of a set of scalar values. Metadata can optionally be provided for the extension type
+    /// to allow for parameterized types.
+    ///
+    /// A simple example would be if one wanted to create a `vortex.temperature` extension type. The
+    /// canonical encoding for such values would be `f64`, and the metadata can contain an optional
+    /// temperature unit, allowing downstream users to be sure they properly account for Celsius
+    /// and Fahrenheit conversions.
+    ///
+    /// ```
+    /// use std::sync::Arc;
+    /// use vortex_dtype::{DType, ExtDType, ExtID, ExtMetadata, Nullability, PType};
+    ///
+    /// #[repr(u8)]
+    /// enum TemperatureUnit {
+    ///     C = 0u8,
+    ///     F = 1u8,
+    /// }
+    ///
+    /// // Make a new extension type that encodes the unit for a set of nullable `f64`.
+    /// pub fn create_temperature_type(unit: TemperatureUnit) -> ExtDType {
+    ///     ExtDType::new(
+    ///         ExtID::new("vortex.temperature".into()),
+    ///         Arc::new(DType::Primitive(PType::F64, Nullability::Nullable)),
+    ///         Some(ExtMetadata::new([unit as u8].into()))
+    ///     )
+    /// }
+    /// ```
+    pub fn new(id: ExtID, storage_dtype: Arc<DType>, metadata: Option<ExtMetadata>) -> Self {
+        assert!(
+            !matches!(storage_dtype.as_ref(), &DType::Extension(_)),
+            "ExtDType cannot have Extension storage_dtype"
+        );
+
+        Self {
+            id,
+            storage_dtype,
+            metadata,
+        }
     }
 
     /// Returns the `ExtID` for this extension type
@@ -75,6 +116,21 @@ impl ExtDType {
         &self.id
     }
 
+    /// Returns the storage `DType` for this extension type
+    #[inline]
+    pub fn storage_dtype(&self) -> &DType {
+        self.storage_dtype.as_ref()
+    }
+
+    /// Returns a new `ExtDType` with the given nullability
+    pub fn with_nullability(&self, nullability: Nullability) -> Self {
+        Self::new(
+            self.id.clone(),
+            Arc::new(self.storage_dtype.with_nullability(nullability)),
+            self.metadata.clone(),
+        )
+    }
+
     /// Returns the `ExtMetadata` for this extension type, if it exists
     #[inline]
     pub fn metadata(&self) -> Option<&ExtMetadata> {
diff --git a/vortex-dtype/src/serde/flatbuffers/mod.rs b/vortex-dtype/src/serde/flatbuffers/mod.rs
index f8571614f8..49f8fa9496 100644
--- a/vortex-dtype/src/serde/flatbuffers/mod.rs
+++ b/vortex-dtype/src/serde/flatbuffers/mod.rs
@@ -86,10 +86,17 @@ impl TryFrom> for DType {
                 vortex_err!("failed to parse extension id from flatbuffer")
             })?);
             let metadata = fb_ext.metadata().map(|m| ExtMetadata::from(m.bytes()));
-            Ok(Self::Extension(
-                ExtDType::new(id, metadata),
-                fb_ext.nullable().into(),
-            ))
+            Ok(Self::Extension(Arc::new(ExtDType::new(
+                id,
+                Arc::new(
+                    DType::try_from(fb_ext.storage_dtype().ok_or_else(|| {
+                        vortex_err!(
+                            InvalidSerde: "storage_dtype must be present on DType fbs message")
+                    })?)
+ .map_err(|e| vortex_err!("failed to create DType from fbs message: {e}"))?, + ), + metadata, + )))) } _ => Err(vortex_err!("Unknown DType variant")), } @@ -171,15 +178,16 @@ impl WriteFlatBuffer for DType { ) .as_union_value() } - Self::Extension(ext, n) => { + Self::Extension(ext) => { let id = Some(fbb.create_string(ext.id().as_ref())); + let storage_dtype = Some(ext.storage_dtype().write_flatbuffer(fbb)); let metadata = ext.metadata().map(|m| fbb.create_vector(m.as_ref())); fb::Extension::create( fbb, &fb::ExtensionArgs { id, + storage_dtype, metadata, - nullable: (*n).into(), }, ) .as_union_value() diff --git a/vortex-dtype/src/serde/proto.rs b/vortex-dtype/src/serde/proto.rs index 6ffa8ae89c..1b180e77a4 100644 --- a/vortex-dtype/src/serde/proto.rs +++ b/vortex-dtype/src/serde/proto.rs @@ -46,12 +46,16 @@ impl TryFrom<&pb::DType> for DType { )) } DtypeType::Extension(e) => Ok(Self::Extension( - ExtDType::new( + Arc::new(ExtDType::new( ExtID::from(e.id.as_str()), + Arc::new(DType::try_from(e.storage_dtype + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "storage_dtype must be provided in DType proto message"))? + .as_ref(), + ).map_err(|e| vortex_err!("failed converting DType from proto message: {}", e))?), e.metadata.as_ref().map(|m| ExtMetadata::from(m.as_ref())), ), - e.nullable.into(), - )), + ))), } } } @@ -83,11 +87,11 @@ impl From<&DType> for pb::DType { element_type: Some(Box::new(l.as_ref().into())), nullable: (*n).into(), })), - DType::Extension(e, n) => DtypeType::Extension(pb::Extension { + DType::Extension(e) => DtypeType::Extension(Box::new(pb::Extension { id: e.id().as_ref().into(), + storage_dtype: Some(Box::new(e.storage_dtype().into())), metadata: e.metadata().map(|m| m.as_ref().into()), - nullable: (*n).into(), - }), + })), }), } } diff --git a/vortex-flatbuffers/flatbuffers/vortex-dtype/dtype.fbs b/vortex-flatbuffers/flatbuffers/vortex-dtype/dtype.fbs index 8a31de56e0..1b97fc1453 100644 --- a/vortex-flatbuffers/flatbuffers/vortex-dtype/dtype.fbs +++ b/vortex-flatbuffers/flatbuffers/vortex-dtype/dtype.fbs @@ -52,8 +52,8 @@ table List { table Extension { id: string; + storage_dtype: DType; metadata: [ubyte]; - nullable: bool; } union Type { @@ -72,4 +72,4 @@ table DType { type: Type; } -root_type DType; \ No newline at end of file +root_type DType; diff --git a/vortex-flatbuffers/src/generated/array.rs b/vortex-flatbuffers/src/generated/array.rs index 5eed6a7c8a..58cad0616d 100644 --- a/vortex-flatbuffers/src/generated/array.rs +++ b/vortex-flatbuffers/src/generated/array.rs @@ -3,8 +3,8 @@ // @generated -use crate::scalar::*; use crate::dtype::*; +use crate::scalar::*; use core::mem; use core::cmp::Ordering; diff --git a/vortex-flatbuffers/src/generated/dtype.rs b/vortex-flatbuffers/src/generated/dtype.rs index b86459ec7a..695eee4133 100644 --- a/vortex-flatbuffers/src/generated/dtype.rs +++ b/vortex-flatbuffers/src/generated/dtype.rs @@ -1128,8 +1128,8 @@ impl<'a> flatbuffers::Follow<'a> for Extension<'a> { impl<'a> Extension<'a> { pub const VT_ID: flatbuffers::VOffsetT = 4; - pub const VT_METADATA: flatbuffers::VOffsetT = 6; - pub const VT_NULLABLE: flatbuffers::VOffsetT = 8; + pub const VT_STORAGE_DTYPE: flatbuffers::VOffsetT = 6; + pub const VT_METADATA: flatbuffers::VOffsetT = 8; #[inline] pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { @@ -1142,8 +1142,8 @@ impl<'a> Extension<'a> { ) -> flatbuffers::WIPOffset> { let mut builder = ExtensionBuilder::new(_fbb); if let Some(x) = args.metadata { builder.add_metadata(x); } + if let 
Some(x) = args.storage_dtype { builder.add_storage_dtype(x); } if let Some(x) = args.id { builder.add_id(x); } - builder.add_nullable(args.nullable); builder.finish() } @@ -1156,18 +1156,18 @@ impl<'a> Extension<'a> { unsafe { self._tab.get::>(Extension::VT_ID, None)} } #[inline] - pub fn metadata(&self) -> Option> { + pub fn storage_dtype(&self) -> Option> { // Safety: // Created from valid Table for this object // which contains a valid value in this slot - unsafe { self._tab.get::>>(Extension::VT_METADATA, None)} + unsafe { self._tab.get::>(Extension::VT_STORAGE_DTYPE, None)} } #[inline] - pub fn nullable(&self) -> bool { + pub fn metadata(&self) -> Option> { // Safety: // Created from valid Table for this object // which contains a valid value in this slot - unsafe { self._tab.get::(Extension::VT_NULLABLE, Some(false)).unwrap()} + unsafe { self._tab.get::>>(Extension::VT_METADATA, None)} } } @@ -1179,24 +1179,24 @@ impl flatbuffers::Verifiable for Extension<'_> { use self::flatbuffers::Verifiable; v.visit_table(pos)? .visit_field::>("id", Self::VT_ID, false)? + .visit_field::>("storage_dtype", Self::VT_STORAGE_DTYPE, false)? .visit_field::>>("metadata", Self::VT_METADATA, false)? - .visit_field::("nullable", Self::VT_NULLABLE, false)? .finish(); Ok(()) } } pub struct ExtensionArgs<'a> { pub id: Option>, + pub storage_dtype: Option>>, pub metadata: Option>>, - pub nullable: bool, } impl<'a> Default for ExtensionArgs<'a> { #[inline] fn default() -> Self { ExtensionArgs { id: None, + storage_dtype: None, metadata: None, - nullable: false, } } } @@ -1211,12 +1211,12 @@ impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> ExtensionBuilder<'a, 'b, A> { self.fbb_.push_slot_always::>(Extension::VT_ID, id); } #[inline] - pub fn add_metadata(&mut self, metadata: flatbuffers::WIPOffset>) { - self.fbb_.push_slot_always::>(Extension::VT_METADATA, metadata); + pub fn add_storage_dtype(&mut self, storage_dtype: flatbuffers::WIPOffset>) { + self.fbb_.push_slot_always::>(Extension::VT_STORAGE_DTYPE, storage_dtype); } #[inline] - pub fn add_nullable(&mut self, nullable: bool) { - self.fbb_.push_slot::(Extension::VT_NULLABLE, nullable, false); + pub fn add_metadata(&mut self, metadata: flatbuffers::WIPOffset>) { + self.fbb_.push_slot_always::>(Extension::VT_METADATA, metadata); } #[inline] pub fn new(_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>) -> ExtensionBuilder<'a, 'b, A> { @@ -1237,8 +1237,8 @@ impl core::fmt::Debug for Extension<'_> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let mut ds = f.debug_struct("Extension"); ds.field("id", &self.id()); + ds.field("storage_dtype", &self.storage_dtype()); ds.field("metadata", &self.metadata()); - ds.field("nullable", &self.nullable()); ds.finish() } } diff --git a/vortex-flatbuffers/src/generated/message.rs b/vortex-flatbuffers/src/generated/message.rs index 17c43298a4..e9a123c1b8 100644 --- a/vortex-flatbuffers/src/generated/message.rs +++ b/vortex-flatbuffers/src/generated/message.rs @@ -3,8 +3,8 @@ // @generated -use crate::scalar::*; use crate::dtype::*; +use crate::scalar::*; use crate::array::*; use core::mem; use core::cmp::Ordering; diff --git a/vortex-proto/proto/dtype.proto b/vortex-proto/proto/dtype.proto index d4fe588428..63efaca1ff 100644 --- a/vortex-proto/proto/dtype.proto +++ b/vortex-proto/proto/dtype.proto @@ -54,8 +54,8 @@ message List { message Extension { string id = 1; - optional bytes metadata = 2; - bool nullable = 3; + DType storage_dtype = 2; + optional bytes metadata = 3; } message DType { @@ -81,4 
+81,4 @@ message Field { message FieldPath { repeated Field path = 1; -} \ No newline at end of file +} diff --git a/vortex-proto/src/generated/vortex.dtype.rs b/vortex-proto/src/generated/vortex.dtype.rs index 9a3a956007..667854cbea 100644 --- a/vortex-proto/src/generated/vortex.dtype.rs +++ b/vortex-proto/src/generated/vortex.dtype.rs @@ -1,14 +1,11 @@ // This file is @generated by prost-build. -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Null {} -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Bool { #[prost(bool, tag = "1")] pub nullable: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Primitive { #[prost(enumeration = "PType", tag = "1")] @@ -16,7 +13,6 @@ pub struct Primitive { #[prost(bool, tag = "2")] pub nullable: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Decimal { #[prost(uint32, tag = "1")] @@ -26,19 +22,16 @@ pub struct Decimal { #[prost(bool, tag = "3")] pub nullable: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Utf8 { #[prost(bool, tag = "1")] pub nullable: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Binary { #[prost(bool, tag = "1")] pub nullable: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Struct { #[prost(string, repeated, tag = "1")] @@ -48,7 +41,6 @@ pub struct Struct { #[prost(bool, tag = "3")] pub nullable: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct List { #[prost(message, optional, boxed, tag = "1")] @@ -56,17 +48,15 @@ pub struct List { #[prost(bool, tag = "2")] pub nullable: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Extension { #[prost(string, tag = "1")] pub id: ::prost::alloc::string::String, - #[prost(bytes = "vec", optional, tag = "2")] + #[prost(message, optional, boxed, tag = "2")] + pub storage_dtype: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(bytes = "vec", optional, tag = "3")] pub metadata: ::core::option::Option<::prost::alloc::vec::Vec>, - #[prost(bool, tag = "3")] - pub nullable: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DType { #[prost(oneof = "d_type::DtypeType", tags = "1, 2, 3, 4, 5, 6, 7, 8, 9")] @@ -74,7 +64,6 @@ pub struct DType { } /// Nested message and enum types in `DType`. pub mod d_type { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum DtypeType { #[prost(message, tag = "1")] @@ -94,10 +83,9 @@ pub mod d_type { #[prost(message, tag = "8")] List(::prost::alloc::boxed::Box), #[prost(message, tag = "9")] - Extension(super::Extension), + Extension(::prost::alloc::boxed::Box), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Field { #[prost(oneof = "field::FieldType", tags = "1, 2")] @@ -105,7 +93,6 @@ pub struct Field { } /// Nested message and enum types in `Field`. 
pub mod field { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum FieldType { #[prost(string, tag = "1")] @@ -114,7 +101,6 @@ pub mod field { Index(u64), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FieldPath { #[prost(message, repeated, tag = "1")] @@ -142,17 +128,17 @@ impl PType { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - PType::U8 => "U8", - PType::U16 => "U16", - PType::U32 => "U32", - PType::U64 => "U64", - PType::I8 => "I8", - PType::I16 => "I16", - PType::I32 => "I32", - PType::I64 => "I64", - PType::F16 => "F16", - PType::F32 => "F32", - PType::F64 => "F64", + Self::U8 => "U8", + Self::U16 => "U16", + Self::U32 => "U32", + Self::U64 => "U64", + Self::I8 => "I8", + Self::I16 => "I16", + Self::I32 => "I32", + Self::I64 => "I64", + Self::F16 => "F16", + Self::F32 => "F32", + Self::F64 => "F64", } } /// Creates an enum from field names used in the ProtoBuf definition. diff --git a/vortex-proto/src/generated/vortex.scalar.rs b/vortex-proto/src/generated/vortex.scalar.rs index 1cae9384e1..f79441ff51 100644 --- a/vortex-proto/src/generated/vortex.scalar.rs +++ b/vortex-proto/src/generated/vortex.scalar.rs @@ -1,5 +1,4 @@ // This file is @generated by prost-build. -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Scalar { #[prost(message, optional, tag = "1")] @@ -7,7 +6,6 @@ pub struct Scalar { #[prost(message, optional, tag = "2")] pub value: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarValue { #[prost(oneof = "scalar_value::Kind", tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12")] @@ -15,7 +13,6 @@ pub struct ScalarValue { } /// Nested message and enum types in `ScalarValue`. pub mod scalar_value { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum Kind { #[prost(enumeration = "::prost_types::NullValue", tag = "1")] @@ -42,7 +39,6 @@ pub mod scalar_value { ListValue(super::ListValue), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ListValue { #[prost(message, repeated, tag = "1")] diff --git a/vortex-proto/src/lib.rs b/vortex-proto/src/lib.rs index 41a1a9d178..2660878ed2 100644 --- a/vortex-proto/src/lib.rs +++ b/vortex-proto/src/lib.rs @@ -1,3 +1,5 @@ +#![allow(clippy::all, clippy::nursery)] + #[cfg(feature = "dtype")] #[rustfmt::skip] #[path = "./generated/vortex.dtype.rs"] diff --git a/vortex-scalar/src/arrow.rs b/vortex-scalar/src/arrow.rs index 9a02462bbc..e79735368c 100644 --- a/vortex-scalar/src/arrow.rs +++ b/vortex-scalar/src/arrow.rs @@ -66,9 +66,9 @@ impl TryFrom<&Scalar> for Arc { DType::List(..) 
=> { todo!("list scalar conversion") } - DType::Extension(ext, _) => { + DType::Extension(ext) => { if is_temporal_ext_type(ext.id()) { - let metadata = TemporalMetadata::try_from(ext)?; + let metadata = TemporalMetadata::try_from(ext.as_ref())?; let pv = value.value.as_pvalue()?; return match metadata { TemporalMetadata::Time(u) => match u { diff --git a/vortex-scalar/src/datafusion.rs b/vortex-scalar/src/datafusion.rs index 806c19084d..6a9cbf7c98 100644 --- a/vortex-scalar/src/datafusion.rs +++ b/vortex-scalar/src/datafusion.rs @@ -1,4 +1,6 @@ #![cfg(feature = "datafusion")] +use std::sync::Arc; + use datafusion_common::ScalarValue; use vortex_buffer::Buffer; use vortex_datetime_dtype::arrow::make_temporal_ext_dtype; @@ -12,11 +14,12 @@ impl TryFrom for ScalarValue { type Error = VortexError; fn try_from(value: Scalar) -> Result { - Ok(match value.dtype { + let (dtype, value) = value.into_parts(); + Ok(match dtype { DType::Null => ScalarValue::Null, - DType::Bool(_) => ScalarValue::Boolean(value.value.as_bool()?), + DType::Bool(_) => ScalarValue::Boolean(value.as_bool()?), DType::Primitive(ptype, _) => { - let pvalue = value.value.as_pvalue()?; + let pvalue = value.as_pvalue()?; match pvalue { None => match ptype { PType::U8 => ScalarValue::UInt8(None), @@ -46,15 +49,11 @@ impl TryFrom for ScalarValue { }, } } - DType::Utf8(_) => ScalarValue::Utf8( - value - .value - .as_buffer_string()? - .map(|b| b.as_str().to_string()), - ), + DType::Utf8(_) => { + ScalarValue::Utf8(value.as_buffer_string()?.map(|b| b.as_str().to_string())) + } DType::Binary(_) => ScalarValue::Binary( value - .value .as_buffer()? .map(|b| b.into_vec().unwrap_or_else(|buf| buf.as_slice().to_vec())), ), @@ -64,10 +63,12 @@ impl TryFrom for ScalarValue { DType::List(..) => { todo!("list scalar conversion") } - DType::Extension(ext, _) => { + DType::Extension(ext) => { + // Special handling: temporal extension types in Vortex correspond to Arrow's + // temporal physical types. if is_temporal_ext_type(ext.id()) { - let metadata = TemporalMetadata::try_from(&ext)?; - let pv = value.value.as_pvalue()?; + let metadata = TemporalMetadata::try_from(ext.as_ref())?; + let pv = value.as_pvalue()?; return Ok(match metadata { TemporalMetadata::Time(u) => match u { TimeUnit::Ns => { @@ -111,9 +112,11 @@ impl TryFrom for ScalarValue { } }, }); + } else { + // Unknown extension type: perform scalar conversion using the canonical + // scalar DType. + ScalarValue::try_from(Scalar::new(ext.storage_dtype().clone(), value))? 
} - - todo!("Non temporal extension scalar conversion") } }) } @@ -147,9 +150,10 @@ impl From for Scalar { ScalarValue::Date32(v) | ScalarValue::Time32Second(v) | ScalarValue::Time32Millisecond(v) => v.map(|i| { - let ext_dtype = make_temporal_ext_dtype(&value.data_type()); + let ext_dtype = make_temporal_ext_dtype(&value.data_type()) + .with_nullability(Nullability::Nullable); Scalar::new( - DType::Extension(ext_dtype, Nullability::Nullable), + DType::Extension(Arc::new(ext_dtype)), crate::ScalarValue::Primitive(PValue::I32(i)), ) }), @@ -162,7 +166,7 @@ impl From for Scalar { | ScalarValue::TimestampNanosecond(v, _) => v.map(|i| { let ext_dtype = make_temporal_ext_dtype(&value.data_type()); Scalar::new( - DType::Extension(ext_dtype, Nullability::Nullable), + DType::Extension(Arc::new(ext_dtype.with_nullability(Nullability::Nullable))), crate::ScalarValue::Primitive(PValue::I64(i)), ) }), diff --git a/vortex-scalar/src/display.rs b/vortex-scalar/src/display.rs index b9dced5ecf..4a6e793a3f 100644 --- a/vortex-scalar/src/display.rs +++ b/vortex-scalar/src/display.rs @@ -59,8 +59,10 @@ impl Display for Scalar { } } DType::List(..) => todo!(), - DType::Extension(dtype, _) if is_temporal_ext_type(dtype.id()) => { - let metadata = TemporalMetadata::try_from(dtype).map_err(|_| std::fmt::Error)?; + // Specialized handling for date/time/timestamp builtin extension types. + DType::Extension(dtype) if is_temporal_ext_type(dtype.id()) => { + let metadata = + TemporalMetadata::try_from(dtype.as_ref()).map_err(|_| std::fmt::Error)?; match ExtScalar::try_from(self) .map_err(|_| std::fmt::Error)? .value() @@ -79,7 +81,18 @@ impl Display for Scalar { _ => Err(std::fmt::Error), } } - DType::Extension(..) => todo!(), + // Generic handling of unknown extension types. + // TODO(aduffy): Allow extension authors plugin their own Scalar display. + DType::Extension(..) => { + let scalar_value = ExtScalar::try_from(self) + .map_err(|_| std::fmt::Error)? 
+ .value(); + if scalar_value.is_null() { + write!(f, "null") + } else { + write!(f, "{}", scalar_value) + } + } } } } @@ -234,13 +247,11 @@ mod tests { #[test] fn display_time() { fn dtype() -> DType { - DType::Extension( - ExtDType::new( - TIME_ID.clone(), - Some(ExtMetadata::from(TemporalMetadata::Time(TimeUnit::S))), - ), - Nullable, - ) + DType::Extension(Arc::new(ExtDType::new( + TIME_ID.clone(), + Arc::new(DType::Primitive(PType::I32, Nullable)), + Some(ExtMetadata::from(TemporalMetadata::Time(TimeUnit::S))), + ))) } assert_eq!(format!("{}", Scalar::null(dtype())), "null"); @@ -260,13 +271,11 @@ mod tests { #[test] fn display_date() { fn dtype() -> DType { - DType::Extension( - ExtDType::new( - DATE_ID.clone(), - Some(ExtMetadata::from(TemporalMetadata::Date(TimeUnit::D))), - ), - Nullable, - ) + DType::Extension(Arc::new(ExtDType::new( + DATE_ID.clone(), + Arc::new(DType::Primitive(PType::I32, Nullable)), + Some(ExtMetadata::from(TemporalMetadata::Date(TimeUnit::D))), + ))) } assert_eq!(format!("{}", Scalar::null(dtype())), "null"); @@ -299,16 +308,14 @@ mod tests { #[test] fn display_local_timestamp() { fn dtype() -> DType { - DType::Extension( - ExtDType::new( - TIMESTAMP_ID.clone(), - Some(ExtMetadata::from(TemporalMetadata::Timestamp( - TimeUnit::S, - None, - ))), - ), - Nullable, - ) + DType::Extension(Arc::new(ExtDType::new( + TIMESTAMP_ID.clone(), + Arc::new(DType::Primitive(PType::I32, Nullable)), + Some(ExtMetadata::from(TemporalMetadata::Timestamp( + TimeUnit::S, + None, + ))), + ))) } assert_eq!(format!("{}", Scalar::null(dtype())), "null"); @@ -329,16 +336,14 @@ mod tests { #[test] fn display_zoned_timestamp() { fn dtype() -> DType { - DType::Extension( - ExtDType::new( - TIMESTAMP_ID.clone(), - Some(ExtMetadata::from(TemporalMetadata::Timestamp( - TimeUnit::S, - Some(String::from("Pacific/Guam")), - ))), - ), - Nullable, - ) + DType::Extension(Arc::new(ExtDType::new( + TIMESTAMP_ID.clone(), + Arc::new(DType::Primitive(PType::I64, Nullable)), + Some(ExtMetadata::from(TemporalMetadata::Timestamp( + TimeUnit::S, + Some(String::from("Pacific/Guam")), + ))), + ))) } assert_eq!(format!("{}", Scalar::null(dtype())), "null"); diff --git a/vortex-scalar/src/extension.rs b/vortex-scalar/src/extension.rs index e345693d3f..f7068c5a02 100644 --- a/vortex-scalar/src/extension.rs +++ b/vortex-scalar/src/extension.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use vortex_dtype::{DType, ExtDType}; use vortex_error::{vortex_bail, VortexError, VortexResult}; @@ -44,10 +46,10 @@ impl<'a> TryFrom<&'a Scalar> for ExtScalar<'a> { } impl Scalar { - pub fn extension(ext_dtype: ExtDType, storage: Self) -> Self { + pub fn extension(ext_dtype: Arc, value: ScalarValue) -> Self { Self { - dtype: DType::Extension(ext_dtype, storage.dtype().nullability()), - value: storage.value, + dtype: DType::Extension(ext_dtype), + value, } } } diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index cc63358d14..52031495bb 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -57,13 +57,13 @@ impl Scalar { } #[inline] - pub fn into_value(self) -> ScalarValue { - self.value + pub fn into_parts(self) -> (DType, ScalarValue) { + (self.dtype, self.value) } #[inline] - pub fn into_parts(self) -> (ScalarValue, DType) { - (self.value, self.dtype) + pub fn into_value(self) -> ScalarValue { + self.value } pub fn is_valid(&self) -> bool { diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 69e8049229..835374adbb 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -39,7 +39,6 @@ fn 
execute_generate_proto() -> anyhow::Result<()> { let proto_files = vec![ vortex_proto.join("proto").join("dtype.proto"), vortex_proto.join("proto").join("scalar.proto"), - vortex_proto.join("proto").join("expr.proto"), ]; for file in &proto_files {