Skip to content

Commit

Permalink
feat: specify the storage_dtype in ExtDType (#1007)
Browse files Browse the repository at this point in the history
Adds `storage_dtype` to `ExtDType`.

This PR adds `storage_dtype` to `ExtDType`.

This is desirable for a few reasons:

* Makes it possible to canonicalize an empty chunked ExtensionArray
* Makes it possible to determine the storage DType for a ConstantArray
without examining its value
* Makes it possible for Vortex to reason about externally authored
extension types. This is still not fully complete, as an ideal
experience would allow extension authors to override IntoCanonical,
IntoArrow, Display, etc.

To avoid duplicating the nullability, we remove the top-level `nullability`
from the `DType::Extension` variant; instead, nullability is accessed
through the inner `ExtDType`.


---------

Co-authored-by: Will Manning <[email protected]>
  • Loading branch information
a10y and lwwmanning authored Nov 1, 2024
1 parent ba09095 commit 2d439a2
Show file tree
Hide file tree
Showing 33 changed files with 545 additions and 253 deletions.
13 changes: 6 additions & 7 deletions bench-vortex/src/bin/notimplemented.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![feature(float_next_up_down)]

use std::process::ExitCode;
use std::sync::Arc;

use prettytable::{Cell, Row, Table};
use vortex::array::builder::VarBinBuilder;
Expand Down Expand Up @@ -92,13 +93,11 @@ fn enc_impls() -> Vec<Array> {
.into_array(),
ConstantArray::new(10, 1).into_array(),
DateTimePartsArray::try_new(
DType::Extension(
ExtDType::new(
TIME_ID.clone(),
Some(TemporalMetadata::Time(TimeUnit::S).into()),
),
Nullability::NonNullable,
),
DType::Extension(Arc::new(ExtDType::new(
TIME_ID.clone(),
Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
Some(TemporalMetadata::Time(TimeUnit::S).into()),
))),
PrimitiveArray::from(vec![1]).into_array(),
PrimitiveArray::from(vec![0]).into_array(),
PrimitiveArray::from(vec![0]).into_array(),
Expand Down
25 changes: 10 additions & 15 deletions encodings/datetime-parts/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use vortex::compute::{slice, take, ArrayCompute, SliceFn, TakeFn};
use vortex::validity::ArrayValidity;
use vortex::{Array, ArrayDType, IntoArray, IntoArrayVariant};
use vortex_datetime_dtype::{TemporalMetadata, TimeUnit};
use vortex_dtype::{DType, PType};
use vortex_dtype::DType;
use vortex_error::{vortex_bail, VortexResult, VortexUnwrap as _};
use vortex_scalar::Scalar;
use vortex_scalar::{Scalar, ScalarValue};

use crate::DateTimePartsArray;

Expand Down Expand Up @@ -51,22 +51,20 @@ impl SliceFn for DateTimePartsArray {

impl ScalarAtFn for DateTimePartsArray {
fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
let DType::Extension(ext, nullability) = self.dtype().clone() else {
let DType::Extension(ext) = self.dtype().clone() else {
vortex_bail!(
"DateTimePartsArray must have extension dtype, found {}",
self.dtype()
);
};

let TemporalMetadata::Timestamp(time_unit, _) = TemporalMetadata::try_from(&ext)? else {
let TemporalMetadata::Timestamp(time_unit, _) = TemporalMetadata::try_from(ext.as_ref())?
else {
vortex_bail!("Metadata must be Timestamp, found {}", ext.id());
};

if !self.is_valid(index) {
return Ok(Scalar::extension(
ext,
Scalar::null(DType::Primitive(PType::I64, nullability)),
));
return Ok(Scalar::extension(ext, ScalarValue::Null));
}

let divisor = match time_unit {
Expand All @@ -83,10 +81,7 @@ impl ScalarAtFn for DateTimePartsArray {

let scalar = days * 86_400 * divisor + seconds * divisor + subseconds;

Ok(Scalar::extension(
ext,
Scalar::primitive(scalar, nullability),
))
Ok(Scalar::extension(ext, scalar.into()))
}

fn scalar_at_unchecked(&self, index: usize) -> Scalar {
Expand All @@ -98,11 +93,11 @@ impl ScalarAtFn for DateTimePartsArray {
///
/// Enforces that the passed array is actually a [DateTimePartsArray] with proper metadata.
pub fn decode_to_temporal(array: &DateTimePartsArray) -> VortexResult<TemporalArray> {
let DType::Extension(ext, _) = array.dtype().clone() else {
let DType::Extension(ext) = array.dtype().clone() else {
vortex_bail!(ComputeError: "expected dtype to be DType::Extension variant")
};

let Ok(temporal_metadata) = TemporalMetadata::try_from(&ext) else {
let Ok(temporal_metadata) = TemporalMetadata::try_from(ext.as_ref()) else {
vortex_bail!(ComputeError: "must decode TemporalMetadata from extension metadata");
};

Expand Down Expand Up @@ -187,7 +182,7 @@ mod test {
assert_eq!(validity, raw_millis.validity());

let date_times = DateTimePartsArray::try_new(
DType::Extension(temporal_array.ext_dtype().clone(), validity.nullability()),
DType::Extension(temporal_array.ext_dtype().clone()),
days,
seconds,
subseconds,
Expand Down
11 changes: 8 additions & 3 deletions pyvortex/src/python_repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,18 @@ impl Display for DTypePythonRepr<'_> {
n.python_repr()
),
DType::List(edt, n) => write!(f, "list({}, {})", edt.python_repr(), n.python_repr()),
DType::Extension(ext, n) => {
write!(f, "ext(\"{}\", ", ext.id().python_repr())?;
DType::Extension(ext) => {
write!(
f,
"ext(\"{}\", {}, ",
ext.id().python_repr(),
ext.storage_dtype().python_repr()
)?;
match ext.metadata() {
None => write!(f, "None")?,
Some(metadata) => write!(f, "{}", metadata.python_repr())?,
};
write!(f, ", {})", n.python_repr())
write!(f, ")")
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion pyvortex/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use vortex_error::vortex_panic;
use vortex_scalar::{PValue, Scalar, ScalarValue};

pub fn scalar_into_py(py: Python, x: Scalar, copy_into_python: bool) -> PyResult<PyObject> {
let (value, dtype) = x.into_parts();
let (dtype, value) = x.into_parts();
scalar_value_into_py(py, value, &dtype, copy_into_python)
}

Expand Down
8 changes: 2 additions & 6 deletions vortex-array/src/array/chunked/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub(crate) fn try_canonicalize_chunks(
// / \
// storage storage
//
DType::Extension(ext_dtype, _) => {
DType::Extension(ext_dtype) => {
// Recursively apply canonicalization and packing to the storage array backing
// each chunk of the extension array.
let storage_chunks: Vec<Array> = chunks
Expand All @@ -78,11 +78,7 @@ pub(crate) fn try_canonicalize_chunks(
// ExtensionArray, so we should canonicalize each chunk into ExtensionArray first.
.map(|chunk| chunk.clone().into_extension().map(|ext| ext.storage()))
.collect::<VortexResult<Vec<Array>>>()?;
let storage_dtype = storage_chunks
.first()
.ok_or_else(|| vortex_err!("Expected at least one chunk in ChunkedArray"))?
.dtype()
.clone();
let storage_dtype = ext_dtype.storage_dtype().clone();
let chunked_storage =
ChunkedArray::try_new(storage_chunks, storage_dtype)?.into_array();

Expand Down
21 changes: 4 additions & 17 deletions vortex-array/src/array/constant/variants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::sync::Arc;

use vortex_dtype::field::Field;
use vortex_dtype::{DType, PType};
use vortex_error::{vortex_panic, VortexError, VortexExpect as _, VortexResult};
use vortex_scalar::{ExtScalar, Scalar, ScalarValue, StructScalar};
use vortex_error::{VortexError, VortexExpect as _, VortexResult};
use vortex_scalar::{Scalar, ScalarValue, StructScalar};

use crate::array::constant::ConstantArray;
use crate::iter::{Accessor, AccessorRef};
Expand Down Expand Up @@ -211,22 +211,9 @@ impl ListArrayTrait for ConstantArray {}

impl ExtensionArrayTrait for ConstantArray {
fn storage_array(&self) -> Array {
let scalar_ext = ExtScalar::try_new(self.dtype(), self.scalar_value())
.vortex_expect("Expected an extension scalar");

// FIXME(ngates): there's not enough information to get the storage array.
let n = self.dtype().nullability();
let storage_dtype = match scalar_ext.value() {
ScalarValue::Bool(_) => DType::Binary(n),
ScalarValue::Primitive(pvalue) => DType::Primitive(pvalue.ptype(), n),
ScalarValue::Buffer(_) => DType::Binary(n),
ScalarValue::BufferString(_) => DType::Utf8(n),
ScalarValue::List(_) => vortex_panic!("List not supported"),
ScalarValue::Null => DType::Null,
};

let storage_dtype = self.ext_dtype().storage_dtype().clone();
ConstantArray::new(
Scalar::new(storage_dtype, scalar_ext.value().clone()),
Scalar::new(storage_dtype, self.scalar_value().clone()),
self.len(),
)
.into_array()
Expand Down
42 changes: 24 additions & 18 deletions vortex-array/src/array/datetime/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#[cfg(test)]
mod test;

use std::sync::Arc;

use vortex_datetime_dtype::{TemporalMetadata, TimeUnit, DATE_ID, TIMESTAMP_ID, TIME_ID};
use vortex_dtype::{DType, ExtDType};
use vortex_error::{vortex_panic, VortexError};
Expand Down Expand Up @@ -68,28 +70,24 @@ impl TemporalArray {
///
/// If any other time unit is provided, it panics.
pub fn new_date(array: Array, time_unit: TimeUnit) -> Self {
let ext_dtype = match time_unit {
match time_unit {
TimeUnit::D => {
assert_width!(i32, array);

ExtDType::new(
DATE_ID.clone(),
Some(TemporalMetadata::Date(time_unit).into()),
)
}
TimeUnit::Ms => {
assert_width!(i64, array);

ExtDType::new(
DATE_ID.clone(),
Some(TemporalMetadata::Date(time_unit).into()),
)
}
_ => vortex_panic!("invalid TimeUnit {time_unit} for vortex.date"),
};

let ext_dtype = ExtDType::new(
DATE_ID.clone(),
Arc::new(array.dtype().clone()),
Some(TemporalMetadata::Date(time_unit).into()),
);

Self {
ext: ExtensionArray::new(ext_dtype, array),
ext: ExtensionArray::new(Arc::new(ext_dtype), array),
temporal_metadata: TemporalMetadata::Date(time_unit),
}
}
Expand Down Expand Up @@ -123,7 +121,11 @@ impl TemporalArray {
let temporal_metadata = TemporalMetadata::Time(time_unit);
Self {
ext: ExtensionArray::new(
ExtDType::new(TIME_ID.clone(), Some(temporal_metadata.clone().into())),
Arc::new(ExtDType::new(
TIME_ID.clone(),
Arc::new(array.dtype().clone()),
Some(temporal_metadata.clone().into()),
)),
array,
),
temporal_metadata,
Expand All @@ -145,7 +147,11 @@ impl TemporalArray {

Self {
ext: ExtensionArray::new(
ExtDType::new(TIMESTAMP_ID.clone(), Some(temporal_metadata.clone().into())),
Arc::new(ExtDType::new(
TIMESTAMP_ID.clone(),
Arc::new(array.dtype().clone()),
Some(temporal_metadata.clone().into()),
)),
array,
),
temporal_metadata,
Expand All @@ -171,8 +177,8 @@ impl TemporalArray {
}

/// Retrieve the extension DType associated with the underlying array.
pub fn ext_dtype(&self) -> &ExtDType {
self.ext.ext_dtype()
pub fn ext_dtype(&self) -> Arc<ExtDType> {
self.ext.ext_dtype().clone()
}
}

Expand All @@ -195,7 +201,7 @@ impl TryFrom<&Array> for TemporalArray {
/// `TemporalMetadata` variants, an error is returned.
fn try_from(value: &Array) -> Result<Self, Self::Error> {
let ext = ExtensionArray::try_from(value)?;
let temporal_metadata = TemporalMetadata::try_from(ext.ext_dtype())?;
let temporal_metadata = TemporalMetadata::try_from(ext.ext_dtype().as_ref())?;

Ok(Self {
ext,
Expand Down Expand Up @@ -232,7 +238,7 @@ impl TryFrom<ExtensionArray> for TemporalArray {
type Error = VortexError;

fn try_from(ext: ExtensionArray) -> Result<Self, Self::Error> {
let temporal_metadata = TemporalMetadata::try_from(ext.ext_dtype())?;
let temporal_metadata = TemporalMetadata::try_from(ext.ext_dtype().as_ref())?;
Ok(Self {
ext,
temporal_metadata,
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/array/extension/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ impl ScalarAtFn for ExtensionArray {
fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
Ok(Scalar::extension(
self.ext_dtype().clone(),
scalar_at(self.storage(), index)?,
scalar_at(self.storage(), index)?.into_value(),
))
}

fn scalar_at_unchecked(&self, index: usize) -> Scalar {
Scalar::extension(
self.ext_dtype().clone(),
scalar_at_unchecked(self.storage(), index),
scalar_at_unchecked(self.storage(), index).into_value(),
)
}
}
Expand Down
21 changes: 12 additions & 9 deletions vortex-array/src/array/extension/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::fmt::{Debug, Display};
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use vortex_dtype::{DType, ExtDType, ExtID};
Expand All @@ -16,9 +17,7 @@ mod compute;
impl_encoding!("vortex.ext", ids::EXTENSION, Extension);

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionMetadata {
storage_dtype: DType,
}
pub struct ExtensionMetadata;

impl Display for ExtensionMetadata {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand All @@ -27,13 +26,17 @@ impl Display for ExtensionMetadata {
}

impl ExtensionArray {
pub fn new(ext_dtype: ExtDType, storage: Array) -> Self {
pub fn new(ext_dtype: Arc<ExtDType>, storage: Array) -> Self {
assert_eq!(
ext_dtype.storage_dtype(),
storage.dtype(),
"ExtensionArray: storage_dtype must match storage array DType",
);

Self::try_from_parts(
DType::Extension(ext_dtype, storage.dtype().nullability()),
DType::Extension(ext_dtype),
storage.len(),
ExtensionMetadata {
storage_dtype: storage.dtype().clone(),
},
ExtensionMetadata,
[storage].into(),
Default::default(),
)
Expand All @@ -42,7 +45,7 @@ impl ExtensionArray {

pub fn storage(&self) -> Array {
self.as_ref()
.child(0, &self.metadata().storage_dtype, self.len())
.child(0, self.ext_dtype().storage_dtype(), self.len())
.vortex_expect("Missing storage array for ExtensionArray")
}

Expand Down
18 changes: 9 additions & 9 deletions vortex-array/src/arrow/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,9 @@ impl FromArrowType<&Field> for DType {
| DataType::Date64
| DataType::Time32(_)
| DataType::Time64(_)
| DataType::Timestamp(..) => Extension(
make_temporal_ext_dtype(field.data_type()),
field.is_nullable().into(),
),
| DataType::Timestamp(..) => Extension(Arc::new(
make_temporal_ext_dtype(field.data_type()).with_nullability(nullability),
)),
DataType::List(e) | DataType::LargeList(e) => {
List(Arc::new(Self::from_arrow(e.as_ref())), nullability)
}
Expand Down Expand Up @@ -171,7 +170,7 @@ pub fn infer_data_type(dtype: &DType) -> VortexResult<DataType> {
// (32-bit), Large List View (64-bit). We cannot both guarantee zero-copy and commit to an
// Arrow dtype because we do not know how large our offsets are.
DType::List(..) => vortex_bail!("Unsupported dtype: {}", dtype),
DType::Extension(ext_dtype, _) => {
DType::Extension(ext_dtype) => {
// Try and match against the known extension DTypes.
if is_temporal_ext_type(ext_dtype.id()) {
make_arrow_temporal_dtype(ext_dtype)
Expand Down Expand Up @@ -234,10 +233,11 @@ mod test {
#[test]
#[should_panic]
fn test_dtype_conversion_panics() {
let _ = infer_data_type(&DType::Extension(
ExtDType::new(ExtID::from("my-fake-ext-dtype"), None),
Nullability::NonNullable,
))
let _ = infer_data_type(&DType::Extension(Arc::new(ExtDType::new(
ExtID::from("my-fake-ext-dtype"),
Arc::new(DType::Utf8(Nullability::NonNullable)),
None,
))))
.unwrap();
}

Expand Down
Loading

0 comments on commit 2d439a2

Please sign in to comment.