Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: new create_one ExpressionHandler API #662

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
284 changes: 283 additions & 1 deletion kernel/src/engine/arrow_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,87 @@ impl ExpressionHandler for ArrowExpressionHandler {
output_type,
})
}

fn create_one(&self, schema: SchemaRef, expr: &Expression) -> DeltaResult<Box<dyn EngineData>> {
if let Expression::Struct(child_exprs) = expr {
zachschuermann marked this conversation as resolved.
Show resolved Hide resolved
if schema.len() != child_exprs.len() {
return Err(Error::Generic(format!(
"Schema has {} top-level fields, but struct expr has {} children",
schema.len(),
child_exprs.len()
)));
}

let arrays: Vec<ArrayRef> = schema
.fields()
.zip(child_exprs.iter())
.map(|(field, child_expr)| create_single_row_array(child_expr, field))
.collect::<Result<Vec<_>, Error>>()?;
zachschuermann marked this conversation as resolved.
Show resolved Hide resolved

let record_batch = RecordBatch::try_new(Arc::new(schema.as_ref().try_into()?), arrays)?;
Ok(Box::new(ArrowEngineData::new(record_batch)))
} else {
Err(Error::generic(
"ArrowExpressionHandler::create_one() requires a top-level struct expression",
))
}
}
}

fn create_single_row_array(expr: &Expression, field: &StructField) -> DeltaResult<ArrayRef> {
match expr {
// simple case: for literals, just create a single-row array and ensure the data types match
Expression::Literal(scalar) => {
let array = scalar.to_array(1)?;
zachschuermann marked this conversation as resolved.
Show resolved Hide resolved
// TODO(zach): we could do better and cast here instead of always failing
zachschuermann marked this conversation as resolved.
Show resolved Hide resolved
ensure_data_types(field.data_type(), array.data_type(), true)?;
Ok(array)
}
// recursive case: for struct expressions, build a struct array by recursing into each child
Expression::Struct(child_exprs) => {
// co-traverse the expression and schema: we expect the data type to be struct, error
// otherwise.
match field.data_type() {
zachschuermann marked this conversation as resolved.
Show resolved Hide resolved
DataType::Struct(struct_type) => {
if struct_type.len() != child_exprs.len() {
return Err(Error::Generic(format!(
"Schema struct field has {} children, but expression has {} children",
struct_type.len(),
child_exprs.len()
)));
}

let child_arrays: Vec<ArrayRef> = struct_type
.fields()
.zip(child_exprs.iter())
.map(|(subfield, subexpr)| create_single_row_array(subexpr, subfield))
.collect::<Result<Vec<ArrayRef>, Error>>()?;

let arrow_fields = struct_type
.fields()
.map(|f| ArrowField::try_from(f))
.collect::<Result<Vec<ArrowField>, ArrowError>>()?;
zachschuermann marked this conversation as resolved.
Show resolved Hide resolved

let struct_array = StructArray::new(
arrow_fields.into(),
child_arrays,
None, // FIXME: null bitmap
zachschuermann marked this conversation as resolved.
Show resolved Hide resolved
);

Ok(Arc::new(struct_array))
}
other_type => Err(Error::Generic(format!(
"Expected struct type in schema, but got {:?}",
other_type
))),
}
}
// fail for any non-literal, non-struct expressions
non_literal_non_struct => Err(Error::Unsupported(format!(
"build_array_from_expr: unhandled expr variant: {:?}",
non_literal_non_struct
))),
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -568,7 +649,7 @@ impl ExpressionEvaluator for DefaultExpressionEvaluator {
mod tests {
use std::ops::{Add, Div, Mul, Sub};

use arrow_array::{GenericStringArray, Int32Array};
use arrow_array::{create_array, record_batch, GenericStringArray, Int32Array};
use arrow_buffer::ScalarBuffer;
use arrow_schema::{DataType, Field, Fields, Schema};

Expand Down Expand Up @@ -867,4 +948,205 @@ mod tests {
let expected = Arc::new(BooleanArray::from(vec![true, false]));
assert_eq!(results.as_ref(), expected.as_ref());
}

#[test]
fn test_create_one() {
    // Three nullable integer columns, each filled from one integer literal.
    let schema = Arc::new(crate::schema::StructType::new([
        StructField::new("a", DeltaDataTypes::INTEGER, true),
        StructField::new("b", DeltaDataTypes::INTEGER, true),
        StructField::new("c", DeltaDataTypes::INTEGER, true),
    ]));
    let literals = Expression::struct_from([
        Expression::literal(1),
        Expression::literal(2),
        Expression::literal(3),
    ]);

    let engine_data = ArrowExpressionHandler
        .create_one(schema, &literals)
        .unwrap();
    let produced: RecordBatch = engine_data
        .into_any()
        .downcast::<ArrowEngineData>()
        .unwrap()
        .into();

    let wanted =
        record_batch!(("a", Int32, [1]), ("b", Int32, [2]), ("c", Int32, [3])).unwrap();
    assert_eq!(produced, wanted);
}

#[test]
fn test_create_one_string() {
    // A single nullable string column filled from one string literal.
    let field = StructField::new("col_1", DeltaDataTypes::STRING, true);
    let schema = Arc::new(crate::schema::StructType::new([field]));
    let string_expr = Expression::struct_from([Expression::literal("a")]);

    let engine_data = ArrowExpressionHandler
        .create_one(schema, &string_expr)
        .unwrap();
    let produced: RecordBatch = engine_data
        .into_any()
        .downcast::<ArrowEngineData>()
        .unwrap()
        .into();

    let wanted = record_batch!(("col_1", Utf8, ["a"])).unwrap();
    assert_eq!(produced, wanted);
}

#[test]
fn test_create_one_null() {
    // A typed null literal must produce a single null row in a nullable integer column.
    let field = StructField::new("col_1", DeltaDataTypes::INTEGER, true);
    let schema = Arc::new(crate::schema::StructType::new([field]));
    let null_expr = Expression::struct_from([Expression::null_literal(DeltaDataTypes::INTEGER)]);

    let engine_data = ArrowExpressionHandler
        .create_one(schema, &null_expr)
        .unwrap();
    let produced: RecordBatch = engine_data
        .into_any()
        .downcast::<ArrowEngineData>()
        .unwrap()
        .into();

    let wanted = record_batch!(("col_1", Int32, [None])).unwrap();
    assert_eq!(produced, wanted);
}

#[test]
fn test_create_one_non_null() {
    // A non-nullable kernel field must come back as a non-nullable arrow field.
    let field = StructField::new("a", DeltaDataTypes::INTEGER, false);
    let schema = Arc::new(crate::schema::StructType::new([field]));
    let one = Expression::struct_from([Expression::literal(1)]);

    let engine_data = ArrowExpressionHandler.create_one(schema, &one).unwrap();
    let produced: RecordBatch = engine_data
        .into_any()
        .downcast::<ArrowEngineData>()
        .unwrap()
        .into();

    let wanted_field = Field::new("a", DataType::Int32, false);
    let wanted_schema = Arc::new(arrow_schema::Schema::new(vec![wanted_field]));
    let wanted =
        RecordBatch::try_new(wanted_schema, vec![create_array!(Int32, [1])]).unwrap();
    assert_eq!(produced, wanted);
}

#[test]
fn test_create_one_disallow_column_ref() {
    // Column references cannot be evaluated without input data, so they must be rejected.
    let field = StructField::new("a", DeltaDataTypes::INTEGER, true);
    let schema = Arc::new(crate::schema::StructType::new([field]));
    let column_ref = Expression::struct_from([column_expr!("a")]);

    let outcome = ArrowExpressionHandler.create_one(schema, &column_ref);
    assert!(outcome.is_err());
}

#[test]
fn test_create_one_disallow_operator() {
    // Operators are rejected even when both operands are literals.
    let field = StructField::new("a", DeltaDataTypes::INTEGER, true);
    let schema = Arc::new(crate::schema::StructType::new([field]));
    let sum = Expression::binary(
        BinaryOperator::Plus,
        Expression::literal(1),
        Expression::literal(2),
    );
    let wrapped = Expression::struct_from([sum]);

    let outcome = ArrowExpressionHandler.create_one(schema, &wrapped);
    assert!(outcome.is_err());
}

#[test]
fn test_create_one_nested() {
    // Kernel schema shared by both halves of the test: a non-nullable struct column `a` with a
    // nullable integer child `b` and a non-nullable integer child `c`.
    let nested_schema = Arc::new(crate::schema::StructType::new([StructField::new(
        "a",
        DeltaDataTypes::struct_type([
            StructField::new("b", DeltaDataTypes::INTEGER, true),
            StructField::new("c", DeltaDataTypes::INTEGER, false),
        ]),
        false,
    )]));

    // Expected arrow batch: the same shape, holding the single row {b: 1, c: 2}.
    let b_field = Field::new("b", DataType::Int32, true);
    let c_field = Field::new("c", DataType::Int32, false);
    let a_type = DataType::Struct(vec![b_field.clone(), c_field.clone()].into());
    let expected_schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
        "a", a_type, false,
    )]));
    let struct_column: ArrayRef = Arc::new(StructArray::from(vec![
        (Arc::new(b_field), create_array!(Int32, [1]) as ArrayRef),
        (Arc::new(c_field), create_array!(Int32, [2]) as ArrayRef),
    ]));
    let expected = RecordBatch::try_new(expected_schema, vec![struct_column]).unwrap();

    // First form: a struct expression whose children are literals.
    let struct_of_literals = Expression::struct_from([Expression::struct_from([
        Expression::literal(1),
        Expression::literal(2),
    ])]);
    let from_struct_expr: RecordBatch = ArrowExpressionHandler
        .create_one(nested_schema.clone(), &struct_of_literals)
        .unwrap()
        .into_any()
        .downcast::<ArrowEngineData>()
        .unwrap()
        .into();
    assert_eq!(from_struct_expr, expected);

    // Second form: the same row expressed as a single literal struct scalar.
    let struct_scalar = Scalar::Struct(
        StructData::try_new(
            vec![
                StructField::new("b", DeltaDataTypes::INTEGER, true),
                StructField::new("c", DeltaDataTypes::INTEGER, false),
            ],
            vec![Scalar::Integer(1), Scalar::Integer(2)],
        )
        .unwrap(),
    );
    let literal_struct = Expression::struct_from([Expression::literal(struct_scalar)]);
    let from_struct_literal: RecordBatch = ArrowExpressionHandler
        .create_one(nested_schema, &literal_struct)
        .unwrap()
        .into_any()
        .downcast::<ArrowEngineData>()
        .unwrap()
        .into();
    assert_eq!(from_struct_literal, expected);
}
}
9 changes: 9 additions & 0 deletions kernel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,15 @@ pub trait ExpressionHandler: AsAny {
expression: Expression,
output_type: DataType,
) -> Arc<dyn ExpressionEvaluator>;

/// Create a single-row [`EngineData`] by evaluating an [`Expression`] that contains no column
/// references.
///
/// `schema` is the schema of the returned data, and it must match the output of `expr`.
// Note: this takes a schema rather than a `DataType` for now. That is the more constrained
// choice, and it can be expanded in the future.
fn create_one(&self, schema: SchemaRef, expr: &Expression) -> DeltaResult<Box<dyn EngineData>>;
}

/// Provides file system related functionalities to Delta Kernel.
Expand Down
5 changes: 5 additions & 0 deletions kernel/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,11 @@ impl StructType {
self.fields.values()
}

/// Returns the number of entries in this struct type's `fields` collection.
pub fn len(&self) -> usize {
// O(1) for indexmap
// NOTE(review): the note above implies `fields` is an index map; its type is not visible here.
self.fields.len()
}

/// Extracts the name and type of all leaf columns, in schema order. Caller should pass Some
/// `own_name` if this schema is embedded in a larger struct (e.g. `add.*`) and None if the
/// schema is a top-level result (e.g. `*`).
Expand Down
Loading