From 0cab805d67a86dc7fe68858936546e15f5c1fcdf Mon Sep 17 00:00:00 2001 From: Xuanwo Date: Thu, 17 Oct 2024 19:08:28 +0800 Subject: [PATCH] Provide field and schema metadata missing on distinct aggregations. (#12691) (#12975) * test(12687): reproducer of missing metadata bug * fix(12687): minimum change needed to fix the missing metadata Co-authored-by: wiedld Co-authored-by: Andrew Lamb --- .../physical-plan/src/aggregates/mod.rs | 25 ++++++++----- datafusion/physical-plan/src/projection.rs | 2 +- .../sqllogictest/test_files/metadata.slt | 36 +++++++++++++++++++ 3 files changed, 53 insertions(+), 10 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 617f1da3abdd..cd6db0964011 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -26,6 +26,7 @@ use crate::aggregates::{ topk_stream::GroupedTopKAggregateStream, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use crate::projection::get_field_metadata; use crate::windows::get_ordered_partition_by_indices; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, @@ -793,14 +794,17 @@ fn create_schema( ) -> Result { let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len()); for (index, (expr, name)) in group_expr.iter().enumerate() { - fields.push(Field::new( - name, - expr.data_type(input_schema)?, - // In cases where we have multiple grouping sets, we will use NULL expressions in - // order to align the grouping sets. So the field must be nullable even if the underlying - // schema field is not. - group_expr_nullable[index] || expr.nullable(input_schema)?, - )) + fields.push( + Field::new( + name, + expr.data_type(input_schema)?, + // In cases where we have multiple grouping sets, we will use NULL expressions in + // order to align the grouping sets. So the field must be nullable even if the underlying + // schema field is not. + group_expr_nullable[index] || expr.nullable(input_schema)?, + ) + .with_metadata(get_field_metadata(expr, input_schema).unwrap_or_default()), + ) } match mode { @@ -821,7 +825,10 @@ fn create_schema( } } - Ok(Schema::new(fields)) + Ok(Schema::new_with_metadata( + fields, + input_schema.metadata().clone(), + )) } fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index f1b9cdaf728f..4c889d1fc88c 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -237,7 +237,7 @@ impl ExecutionPlan for ProjectionExec { /// If e is a direct column reference, returns the field level /// metadata for that field, if any. Otherwise returns None -fn get_field_metadata( +pub(crate) fn get_field_metadata( e: &Arc, input_schema: &Schema, ) -> Option> { diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt index d6e3ad0c2062..92bc24489feb 100644 --- a/datafusion/sqllogictest/test_files/metadata.slt +++ b/datafusion/sqllogictest/test_files/metadata.slt @@ -59,6 +59,42 @@ WHERE "data"."id" = "samples"."id"; 3 + +# Regression test: prevent field metadata loss per https://github.com/apache/datafusion/issues/12687 +query I +select count(distinct name) from table_with_metadata; +---- +2 + +# Regression test: prevent field metadata loss per https://github.com/apache/datafusion/issues/12687 +query I +select approx_median(distinct id) from table_with_metadata; +---- +2 + +# Regression test: prevent field metadata loss per https://github.com/apache/datafusion/issues/12687 +statement ok +select array_agg(distinct id) from table_with_metadata; + +query I +select distinct id from table_with_metadata order by id; +---- +1 +3 +NULL + +query I +select count(id) from table_with_metadata; +---- +2 + +query I +select count(id) cnt from table_with_metadata group by name order by cnt; +---- +0 +1 +1 + # Regression test: missing schema metadata, when aggregate on cross join query I SELECT count("data"."id")