feat: [comet-parquet-exec] Schema adapter fixes (#1139)
* support more timestamp conversions

* improve error handling

* rename projected_table_schema to required_schema

* Save

* save

* save

* code cleanup
andygrove authored Dec 6, 2024
1 parent bf5a2c6 commit bd797f5
Showing 8 changed files with 177 additions and 171 deletions.
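
The main API change in this commit is that Cast now takes a single SparkCastOptions value instead of separate eval_mode, timezone, and allow_incompat arguments, and the schema adapter reuses those options when adapting Parquet batches. Before the per-file diffs, here is a minimal sketch of the new call shape, assuming a child expression built elsewhere; the helper names below are illustrative, not part of the commit.

use std::sync::Arc;
use arrow_schema::DataType;
use datafusion::physical_plan::PhysicalExpr;
use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};

// Cast `child` to Int64 with Spark-compatible semantics. The options struct
// bundles eval_mode, timezone, and allow_incompat, which were previously
// separate arguments to Cast::new.
fn cast_to_int64(child: Arc<dyn PhysicalExpr>) -> Arc<Cast> {
    let options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
    Arc::new(Cast::new(child, DataType::Int64, options))
}

// For casts that can never involve timestamps (such as the Decimal256
// widening in planner.rs below), the timezone can be omitted entirely.
fn decimal_widening_options() -> SparkCastOptions {
    SparkCastOptions::new_without_timezone(EvalMode::Legacy, false)
}
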
31 changes: 15 additions & 16 deletions native/core/src/execution/datafusion/planner.rs
@@ -19,7 +19,7 @@
 use super::expressions::EvalMode;
 use crate::execution::datafusion::expressions::comet_scalar_funcs::create_comet_physical_fun;
-use crate::execution::operators::{CopyMode, FilterExec};
+use crate::execution::operators::{CopyMode, FilterExec as CometFilterExec};
 use crate::{
     errors::ExpressionError,
     execution::{
@@ -55,6 +55,7 @@ use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf,
 use datafusion::functions_aggregate::min_max::max_udaf;
 use datafusion::functions_aggregate::min_max::min_udaf;
 use datafusion::functions_aggregate::sum::sum_udaf;
+use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::windows::BoundedWindowAggExec;
 use datafusion::physical_plan::InputOrderMode;
 use datafusion::{
@@ -102,7 +103,7 @@ use datafusion_comet_proto::{
 };
 use datafusion_comet_spark_expr::{
     Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields, GetStructField, HourExpr, IfExpr,
-    ListExtract, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
+    ListExtract, MinuteExpr, RLike, SecondExpr, SparkCastOptions, TimestampTruncExpr, ToJson,
 };
 use datafusion_common::config::TableParquetOptions;
 use datafusion_common::scalar::ScalarStructBuilder;
@@ -392,14 +393,11 @@ impl PhysicalPlanner {
             ExprStruct::Cast(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
                 let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
-                let timezone = expr.timezone.clone();
                 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 Ok(Arc::new(Cast::new(
                     child,
                     datatype,
-                    eval_mode,
-                    timezone,
-                    expr.allow_incompat,
+                    SparkCastOptions::new(eval_mode, &expr.timezone, expr.allow_incompat),
                 )))
             }
             ExprStruct::Hour(expr) => {
@@ -767,24 +765,21 @@ impl PhysicalPlanner {
                 let data_type = return_type.map(to_arrow_datatype).unwrap();
                 // For some Decimal128 operations, we need wider internal digits.
                 // Cast left and right to Decimal256 and cast the result back to Decimal128
-                let left = Arc::new(Cast::new_without_timezone(
+                let left = Arc::new(Cast::new(
                     left,
                     DataType::Decimal256(p1, s1),
-                    EvalMode::Legacy,
-                    false,
+                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
                 ));
-                let right = Arc::new(Cast::new_without_timezone(
+                let right = Arc::new(Cast::new(
                     right,
                     DataType::Decimal256(p2, s2),
-                    EvalMode::Legacy,
-                    false,
+                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
                 ));
                 let child = Arc::new(BinaryExpr::new(left, op, right));
-                Ok(Arc::new(Cast::new_without_timezone(
+                Ok(Arc::new(Cast::new(
                     child,
                     data_type,
-                    EvalMode::Legacy,
-                    false,
+                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
                 )))
             }
             (
@@ -851,7 +846,11 @@ impl PhysicalPlanner {
                 let predicate =
                     self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?;

-                Ok((scans, Arc::new(FilterExec::try_new(predicate, child)?)))
+                if can_reuse_input_batch(&child) {
+                    Ok((scans, Arc::new(CometFilterExec::try_new(predicate, child)?)))
+                } else {
+                    Ok((scans, Arc::new(FilterExec::try_new(predicate, child)?)))
+                }
             }
             OpStruct::HashAgg(agg) => {
                 assert!(children.len() == 1);
83 changes: 42 additions & 41 deletions native/core/src/execution/datafusion/schema_adapter.rs
@@ -19,9 +19,9 @@
 use arrow::compute::can_cast_types;
 use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchOptions};
-use arrow_schema::{DataType, Schema, SchemaRef};
+use arrow_schema::{DataType, Schema, SchemaRef, TimeUnit};
 use datafusion::datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper};
-use datafusion_comet_spark_expr::{spark_cast, EvalMode};
+use datafusion_comet_spark_expr::{spark_cast, EvalMode, SparkCastOptions};
 use datafusion_common::plan_err;
 use datafusion_expr::ColumnarValue;
 use std::sync::Arc;
@@ -38,11 +38,11 @@ impl SchemaAdapterFactory for CometSchemaAdapterFactory {
     /// schema.
     fn create(
         &self,
-        projected_table_schema: SchemaRef,
+        required_schema: SchemaRef,
         table_schema: SchemaRef,
     ) -> Box<dyn SchemaAdapter> {
         Box::new(CometSchemaAdapter {
-            projected_table_schema,
+            required_schema,
             table_schema,
         })
     }
@@ -54,7 +54,7 @@ impl SchemaAdapterFactory for CometSchemaAdapterFactory {
 pub struct CometSchemaAdapter {
     /// The schema for the table, projected to include only the fields being output (projected) by the
     /// associated ParquetExec
-    projected_table_schema: SchemaRef,
+    required_schema: SchemaRef,
     /// The entire table schema for the table we're using this to adapt.
     ///
     /// This is used to evaluate any filters pushed down into the scan
@@ -69,7 +69,7 @@ impl SchemaAdapter for CometSchemaAdapter {
     ///
     /// Panics if index is not in range for the table schema
     fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> {
-        let field = self.projected_table_schema.field(index);
+        let field = self.required_schema.field(index);
         Some(file_schema.fields.find(field.name())?.0)
     }

@@ -87,42 +87,34 @@ impl SchemaAdapter for CometSchemaAdapter {
         file_schema: &Schema,
     ) -> datafusion_common::Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
         let mut projection = Vec::with_capacity(file_schema.fields().len());
-        let mut field_mappings = vec![None; self.projected_table_schema.fields().len()];
+        let mut field_mappings = vec![None; self.required_schema.fields().len()];

         for (file_idx, file_field) in file_schema.fields.iter().enumerate() {
             if let Some((table_idx, table_field)) =
-                self.projected_table_schema.fields().find(file_field.name())
+                self.required_schema.fields().find(file_field.name())
             {
-                // workaround for struct casting
-                match (file_field.data_type(), table_field.data_type()) {
-                    // TODO need to use Comet cast logic to determine which casts are supported,
-                    // but for now just add a hack to support casting between struct types
-                    (DataType::Struct(_), DataType::Struct(_)) => {
-                        field_mappings[table_idx] = Some(projection.len());
-                        projection.push(file_idx);
-                    }
-                    _ => {
-                        if can_cast_types(file_field.data_type(), table_field.data_type()) {
-                            field_mappings[table_idx] = Some(projection.len());
-                            projection.push(file_idx);
-                        } else {
-                            return plan_err!(
-                                "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}",
-                                file_field.name(),
-                                file_field.data_type(),
-                                table_field.data_type()
-                            );
-                        }
-                    }
-                }
+                if comet_can_cast_types(file_field.data_type(), table_field.data_type()) {
+                    field_mappings[table_idx] = Some(projection.len());
+                    projection.push(file_idx);
+                } else {
+                    return plan_err!(
+                        "Cannot cast file schema field {} of type {:?} to required schema field of type {:?}",
+                        file_field.name(),
+                        file_field.data_type(),
+                        table_field.data_type()
+                    );
+                }
             }
         }

+        let mut cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+        cast_options.is_adapting_schema = true;
         Ok((
             Arc::new(SchemaMapping {
-                projected_table_schema: Arc::<Schema>::clone(&self.projected_table_schema),
+                required_schema: Arc::<Schema>::clone(&self.required_schema),
                 field_mappings,
                 table_schema: Arc::<Schema>::clone(&self.table_schema),
+                cast_options
             }),
             projection,
         ))
@@ -161,7 +153,7 @@
 pub struct SchemaMapping {
     /// The schema of the table. This is the expected schema after conversion
     /// and it should match the schema of the query result.
-    projected_table_schema: SchemaRef,
+    required_schema: SchemaRef,
     /// Mapping from field index in `projected_table_schema` to index in
     /// projected file_schema.
     ///
@@ -173,6 +165,8 @@ pub struct SchemaMapping {
     /// This contains all fields in the table, regardless of if they will be
     /// projected out or not.
     table_schema: SchemaRef,
+
+    cast_options: SparkCastOptions,
 }

 impl SchemaMapper for SchemaMapping {
@@ -185,7 +179,7 @@
         let batch_cols = batch.columns().to_vec();

         let cols = self
-            .projected_table_schema
+            .required_schema
             // go through each field in the projected schema
             .fields()
             .iter()
@@ -204,10 +198,7 @@ impl SchemaMapper for SchemaMapping {
                     spark_cast(
                         ColumnarValue::Array(Arc::clone(&batch_cols[batch_idx])),
                         field.data_type(),
-                        // TODO need to pass in configs here
-                        EvalMode::Legacy,
-                        "UTC",
-                        false,
+                        &self.cast_options,
                     )?
                     .into_array(batch_rows)
                 },
@@ -218,7 +209,7 @@
         // Necessary to handle empty batches
         let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));

-        let schema = Arc::<Schema>::clone(&self.projected_table_schema);
+        let schema = Arc::<Schema>::clone(&self.required_schema);
         let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?;
         Ok(record_batch)
     }
@@ -255,10 +246,7 @@ impl SchemaMapper for SchemaMapping {
                 spark_cast(
                     ColumnarValue::Array(Arc::clone(batch_col)),
                     table_field.data_type(),
-                    // TODO need to pass in configs here
-                    EvalMode::Legacy,
-                    "UTC",
-                    false,
+                    &self.cast_options,
                 )?
                 .into_array(batch_col.len())
                 // and if that works, return the field and column.
Expand All @@ -277,3 +265,16 @@ impl SchemaMapper for SchemaMapping {
Ok(record_batch)
}
}

fn comet_can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
// TODO this is just a quick hack to get tests passing
match (from_type, to_type) {
(DataType::Struct(_), DataType::Struct(_)) => {
// workaround for struct casting
true
}
// TODO this is maybe no longer needed
(_, DataType::Timestamp(TimeUnit::Nanosecond, _)) => false,
_ => can_cast_types(from_type, to_type),
}
}
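
To make the schema adapter changes above easier to follow, here is a rough usage sketch, not part of the commit: the factory is given the required (projected) schema plus the full table schema, map_schema returns a projection into the file schema together with a SchemaMapper, and the mapper casts each projected column to the required type through spark_cast. The unit-struct construction of CometSchemaAdapterFactory and the map_batch call are assumptions based on DataFusion's schema adapter traits rather than lines from this diff.

use std::sync::Arc;
use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::datasource::schema_adapter::SchemaAdapterFactory;

fn adapt_example() -> datafusion_common::Result<()> {
    // The Parquet file stores column "a" as Int32, but the query requires Int64.
    let file_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    let required_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));

    // Assumes CometSchemaAdapterFactory is a unit struct and is in scope.
    let adapter = CometSchemaAdapterFactory {}
        .create(Arc::clone(&required_schema), Arc::clone(&required_schema));

    // `projection` lists which file columns to read; `mapper` casts them afterwards.
    let (mapper, projection) = adapter.map_schema(&file_schema)?;

    let batch = RecordBatch::try_new(
        Arc::new(file_schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;
    let adapted = mapper.map_batch(batch.project(&projection)?)?;
    assert_eq!(adapted.schema(), required_schema);
    Ok(())
}
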
2 changes: 1 addition & 1 deletion native/core/src/execution/operators/filter.rs
@@ -227,7 +227,7 @@ impl DisplayAs for FilterExec {

 impl ExecutionPlan for FilterExec {
     fn name(&self) -> &'static str {
-        "FilterExec"
+        "CometFilterExec"
     }

     /// Return a reference to Any that can be used for downcasting