fix: merge pushdown handling (#2326)
# Description
Fix a broken merge test case on partitioned tables: the early (partition) filter derived from the join predicate is now pushed into the target scan so that files are pruned rather than records filtered; a sketch of the idea follows below.

- fixes #2158
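
The gist of the change, visible in the diff below, is that the early filter is attached to the target scan itself via `LogicalPlanBuilder::scan_with_filters`, so the table provider can prune whole files, instead of wrapping the scan in a `Filter` node that only drops rows after they are read. A minimal sketch of the pattern in plain DataFusion (using a `MemTable` instead of this crate's `DeltaTableProvider`; the table name, schema, and predicate are illustrative):

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::datasource::{provider_as_source, MemTable};
use datafusion::error::Result;
use datafusion::logical_expr::{col, lit, LogicalPlan, LogicalPlanBuilder};

fn pruned_target_scan() -> Result<LogicalPlan> {
    // Stand-in for the partitioned target table (schema only, no data).
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Utf8, true),
        Field::new("month", DataType::Utf8, true),
    ]));
    let source = provider_as_source(Arc::new(MemTable::try_new(schema, vec![vec![]])?));

    // Early filter derived from the merge predicate, e.g. a partition value.
    let early_filter = col("month").eq(lit("2023-07-04"));

    // The filter travels with the scan, so the TableProvider sees it and can
    // prune files; a Filter node above the scan would only remove records.
    LogicalPlanBuilder::scan_with_filters("target", source, None, vec![early_filter])?.build()
}
```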

---------

Co-authored-by: ion-elgreco <[email protected]>
Blajda and ion-elgreco authored Mar 23, 2024
1 parent 7928e95 commit 00c919f
154 changes: 126 additions & 28 deletions crates/core/src/operations/merge/mod.rs
@@ -51,7 +51,7 @@ use datafusion_common::{Column, DFSchema, ScalarValue, TableReference};
 use datafusion_expr::expr::Placeholder;
 use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType};
 use datafusion_expr::{
-    BinaryExpr, Distinct, Extension, Filter, LogicalPlan, LogicalPlanBuilder, Operator, Projection,
+    BinaryExpr, Distinct, Extension, LogicalPlan, LogicalPlanBuilder, Operator, Projection,
     UserDefinedLogicalNode, UNNAMED_TABLE,
 };
 use futures::future::BoxFuture;
@@ -866,10 +866,6 @@ async fn try_construct_early_filter(
     let table_metadata = table_snapshot.metadata();
     let partition_columns = &table_metadata.partition_columns;
 
-    if partition_columns.is_empty() {
-        return Ok(None);
-    }
-
     let mut placeholders = HashMap::default();
 
     match generalize_filter(
@@ -897,20 +893,16 @@
                     )?)
                     .into(),
                 ));
-
                 let execution_plan = session_state
                     .create_physical_plan(&distinct_partitions)
                     .await?;
-
                 let items = execute_plan_to_batch(session_state, execution_plan).await?;
-
                 let placeholder_names = items
                     .schema()
                     .fields()
                     .iter()
                     .map(|f| f.name().to_owned())
                     .collect_vec();
-
                 let expr = (0..items.num_rows())
                     .map(|i| {
                         let replacements = placeholder_names
@@ -926,7 +918,6 @@
                     .collect::<DeltaResult<Vec<_>>>()?
                     .into_iter()
                     .reduce(Expr::or);
-
                 Ok(expr)
             }
         }
@@ -953,6 +944,7 @@ async fn execute(
     let exec_start = Instant::now();
 
     let current_metadata = snapshot.metadata();
+    let state = state.with_query_planner(Arc::new(MergePlanner {}));
 
     // TODO: Given the join predicate, remove any expression that involve the
     // source table and keep expressions that only involve the target table.
@@ -993,28 +985,24 @@ async fn execute(
         .with_file_column(true)
         .build(snapshot)?;
 
-    let file_column = Arc::new(scan_config.file_column_name.clone().unwrap());
-
     let target_provider = Arc::new(DeltaTableProvider::try_new(
         snapshot.clone(),
         log_store.clone(),
-        scan_config,
+        scan_config.clone(),
     )?);
 
     let target_provider = provider_as_source(target_provider);
-
-    let target = LogicalPlanBuilder::scan(target_name.clone(), target_provider, None)?.build()?;
+    let target =
+        LogicalPlanBuilder::scan(target_name.clone(), target_provider.clone(), None)?.build()?;
 
     let source_schema = source.schema();
     let target_schema = target.schema();
-    let join_schema_df = build_join_schema(source_schema, target_schema, &JoinType::Full)?;
+    let join_schema_df = build_join_schema(source_schema, &target_schema, &JoinType::Full)?;
     let predicate = match predicate {
         Expression::DataFusion(expr) => expr,
         Expression::String(s) => parse_predicate_expression(&join_schema_df, s, &state)?,
     };
 
-    let state = state.with_query_planner(Arc::new(MergePlanner {}));
-
     // Attempt to construct an early filter that we can apply to the Add action list and the delta scan.
     // In the case where there are partition columns in the join predicate, we can scan the source table
     // to get the distinct list of partitions affected and constrain the search to those.
@@ -1033,11 +1021,24 @@
         )
         .await?
     };
-    let target = match target_subset_filter.as_ref() {
-        None => target,
-        Some(subset_filter) => {
-            LogicalPlan::Filter(Filter::try_new(subset_filter.clone(), target.into())?)
+
+    let file_column = Arc::new(scan_config.file_column_name.clone().unwrap());
+    // Need to manually push this filter into the scan... We want to PRUNE files not FILTER RECORDS
+    let target = match target_subset_filter.clone() {
+        Some(filter) => {
+            let filter = match &target_alias {
+                Some(alias) => remove_table_alias(filter, alias),
+                None => filter,
+            };
+            LogicalPlanBuilder::scan_with_filters(
+                target_name.clone(),
+                target_provider,
+                None,
+                vec![filter],
+            )?
+            .build()?
         }
+        None => LogicalPlanBuilder::scan(target_name.clone(), target_provider, None)?.build()?,
     };
 
     let source = DataFrame::new(state.clone(), source);
@@ -1428,7 +1429,7 @@ async fn execute(
     let commit_predicate = match target_subset_filter {
         None => None, // No predicate means it's a full table merge
         Some(some_filter) => {
-            let predict_expr = match target_alias {
+            let predict_expr = match &target_alias {
                 None => some_filter,
                 Some(alias) => remove_table_alias(some_filter, alias),
@@ -1463,7 +1464,7 @@ async fn execute(
     ))
 }
 
-fn remove_table_alias(expr: Expr, table_alias: String) -> Expr {
+fn remove_table_alias(expr: Expr, table_alias: &str) -> Expr {
     expr.transform(&|expr| match expr {
         Expr::Column(c) => match c.relation {
             Some(rel) if rel.table() == table_alias => Ok(Transformed::Yes(Expr::Column(
@@ -2017,7 +2018,6 @@ mod tests {
         let table = setup_table(Some(vec!["modified"])).await;
         let table = write_data(table, &schema).await;
         assert_eq!(table.version(), 1);
-
         let ctx = SessionContext::new();
         let batch = RecordBatch::try_new(
             Arc::clone(&schema),
@@ -2032,7 +2032,6 @@
         )
         .unwrap();
         let source = ctx.read_batch(batch).unwrap();
-
         let (table, _metrics) = DeltaOps(table)
             .merge(
                 source,
@@ -2057,9 +2056,7 @@
             .unwrap()
             .await
             .unwrap();
-
         assert_eq!(table.version(), 2);
-
         let commit_info = table.history(None).await.unwrap();
         let last_commit = &commit_info[0];
         let parameters = last_commit.operation_parameters.clone().unwrap();
@@ -2485,6 +2482,7 @@ mod tests {
         let commit_info = table.history(None).await.unwrap();
         let last_commit = &commit_info[0];
         let parameters = last_commit.operation_parameters.clone().unwrap();
+
         assert_eq!(parameters["predicate"], json!("modified = '2021-02-02'"));
 
         let expected = vec![
@@ -2974,4 +2972,104 @@ mod tests {
         let actual = get_data(&table).await;
         assert_batches_sorted_eq!(&expected, &actual);
     }
+
+    #[tokio::test]
+    async fn test_merge_pushdowns_partitioned() {
+        //See #2158
+        let schema = vec![
+            StructField::new(
+                "id".to_string(),
+                DataType::Primitive(PrimitiveType::String),
+                true,
+            ),
+            StructField::new(
+                "cost".to_string(),
+                DataType::Primitive(PrimitiveType::Float),
+                true,
+            ),
+            StructField::new(
+                "month".to_string(),
+                DataType::Primitive(PrimitiveType::String),
+                true,
+            ),
+        ];
+
+        let arrow_schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("id", ArrowDataType::Utf8, true),
+            Field::new("cost", ArrowDataType::Float32, true),
+            Field::new("month", ArrowDataType::Utf8, true),
+        ]));
+
+        let part_cols = vec!["month"];
+        let table = DeltaOps::new_in_memory()
+            .create()
+            .with_columns(schema)
+            .with_partition_columns(part_cols)
+            .await
+            .unwrap();
+
+        let ctx = SessionContext::new();
+        let batch = RecordBatch::try_new(
+            Arc::clone(&arrow_schema.clone()),
+            vec![
+                Arc::new(arrow::array::StringArray::from(vec!["A", "B"])),
+                Arc::new(arrow::array::Float32Array::from(vec![Some(10.15), None])),
+                Arc::new(arrow::array::StringArray::from(vec![
+                    "2023-07-04",
+                    "2023-07-04",
+                ])),
+            ],
+        )
+        .unwrap();
+
+        let table = DeltaOps(table)
+            .write(vec![batch.clone()])
+            .with_save_mode(SaveMode::Append)
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 1);
+        assert_eq!(table.get_files_count(), 1);
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&arrow_schema.clone()),
+            vec![
+                Arc::new(arrow::array::StringArray::from(vec!["A", "B"])),
+                Arc::new(arrow::array::Float32Array::from(vec![
+                    Some(12.15),
+                    Some(11.15),
+                ])),
+                Arc::new(arrow::array::StringArray::from(vec![
+                    "2023-07-04",
+                    "2023-07-04",
+                ])),
+            ],
+        )
+        .unwrap();
+        let source = ctx.read_batch(batch).unwrap();
+
+        let (table, _metrics) = DeltaOps(table)
+            .merge(source, "target.id = source.id and target.cost is null")
+            .with_source_alias("source")
+            .with_target_alias("target")
+            .when_matched_update(|insert| {
+                insert
+                    .update("id", "target.id")
+                    .update("cost", "source.cost")
+                    .update("month", "target.month")
+            })
+            .unwrap()
+            .await
+            .unwrap();
+
+        let expected = vec![
+            "+----+-------+------------+",
+            "| id | cost  | month      |",
+            "+----+-------+------------+",
+            "| A  | 10.15 | 2023-07-04 |",
+            "| B  | 11.15 | 2023-07-04 |",
+            "+----+-------+------------+",
+        ];
+        let actual = get_data(&table).await;
+        assert_batches_sorted_eq!(&expected, &actual);
+    }
 }
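
For reference, the early-filter idea described in the comments above (collect the distinct partition values seen in the source, then OR together one predicate per value to constrain the target scan) can be sketched outside delta-rs roughly as follows; the helper name, column, and values are illustrative:

```rust
use datafusion::logical_expr::{col, lit, Expr};

/// Hypothetical helper: build `month = v1 OR month = v2 OR ...` from the
/// distinct partition values found in the merge source.
fn early_partition_filter(partition_col: &str, distinct_values: &[&str]) -> Option<Expr> {
    distinct_values
        .iter()
        .map(|v| col(partition_col).eq(lit(*v)))
        .reduce(Expr::or) // None when the source contained no rows
}

fn main() {
    if let Some(filter) = early_partition_filter("month", &["2023-07-04", "2023-08-01"]) {
        // Prints something like: month = Utf8("2023-07-04") OR month = Utf8("2023-08-01")
        println!("{filter}");
    }
}
```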
