Use partial aggregation schema for spilling to avoid column mismatch in GroupedHashAggregateStream #13995
Conversation
Thank you. I found the fix easy to follow 😄, and the change makes sense to me.
I have a suggestion to improve test coverage:
Since min/max only have one intermediate aggregate state (partial min/max), we should also test aggregate functions that produce more than one intermediate state, like avg (partial sum and count).
Duplicating the existing test and modifying one of the aggregate functions to avg should be sufficient.
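For illustration, a minimal sketch of such a test, assuming a memory-limited runtime to encourage spilling -- the table name, column names, pool size, and the function name below are made up for this example and are not the PR's actual helpers:

```rust
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Float64Array, Int32Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::execution::memory_pool::FairSpillPool;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::prelude::{SessionConfig, SessionContext};

// Sketch: group by `g` and compute `avg(v)` under a deliberately small memory
// pool, so the aggregation is likely forced to spill its multi-column partial
// state (partial sum + partial count).
async fn avg_aggregate_with_small_spill_pool() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("g", DataType::Int32, false),
        Field::new("v", DataType::Float64, false),
    ]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from_iter_values((0..100_000).map(|i| i % 100))) as ArrayRef,
            Arc::new(Float64Array::from_iter_values((0..100_000).map(|i| i as f64))) as ArrayRef,
        ],
    )?;

    let runtime = RuntimeEnvBuilder::new()
        // small pool to encourage spilling during aggregation
        .with_memory_pool(Arc::new(FairSpillPool::new(2 * 1024 * 1024)))
        .build_arc()?;
    let ctx = SessionContext::new_with_config_rt(SessionConfig::new(), runtime);
    let _ = ctx.register_table("t", Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))?;

    // `avg` has two intermediate states, so this exercises the multi-state case.
    let results = ctx
        .sql("SELECT g, avg(v) FROM t GROUP BY g")
        .await?
        .collect()
        .await?;
    assert!(!results.is_empty());
    Ok(())
}
```

The helper in the PR differs; this only illustrates exercising an aggregate with more than one intermediate state under memory pressure.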
datafusion/core/src/dataframe/mod.rs
Outdated
let result =
    common::collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
I suggest adding an assertion here to make sure spilling actually happened for certain test cases. For example:
let metrics = single_aggregate.metrics();
// ...and assert some metrics inside like 'spill count' is > 0
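A possible sketch of that assertion, assuming the plan node exposes the standard MetricsSet with a spill count:

```rust
// Sketch: assert that spilling actually happened for this test case.
let metrics = single_aggregate
    .metrics()
    .expect("aggregate exec should expose metrics");
assert!(
    metrics.spill_count().unwrap_or(0) > 0,
    "expected the aggregation to spill at least once"
);
```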
Thanks @2010YOUY01 for the review and suggestions.
I have implemented both.
datafusion/core/src/dataframe/mod.rs
Outdated
@@ -2743,6 +2754,143 @@ mod tests {
        Ok(())
    }

    // test for https://github.com/apache/datafusion/issues/13949
    async fn run_test_with_spill_pool_if_necessary(
I suppose it would be better to move this test next to the other aggregate tests in datafusion/physical-plan/src/mod.rs
Hi @korowa, i.e. move it to datafusion/physical-plan/src/aggregates/mod.rs, am I correct?
My bad, yes, I meant aggregates/mod.rs
@@ -522,7 +527,7 @@ impl GroupedHashAggregateStream {
         let spill_state = SpillState {
             spills: vec![],
             spill_expr,
-            spill_schema: Arc::clone(&agg_schema),
+            spill_schema: partial_agg_schema,
It seems like the issue was related only to the AggregateMode::Single[Partitioned] cases, since for both Final and FinalPartitioned there is a reassignment right before spilling (the new value is the schema of the Partial output, which is exactly group_by + state fields). Perhaps we can remove this reassignment now and rely on the original spill_schema value set on stream creation (before removing it, we need to ensure that the spill schema will equal the intermediate result schema for any aggregation mode that supports spilling)?
Hi @korowa, regarding "remove this reassignment now": in other words, remove these lines, am I correct?
datafusion/physical-plan/src/aggregates/row_hash.rs (lines 967 to 969 in 487b952):

// Use input batch (Partial mode) schema for spilling because
// the spilled data will be merged and re-evaluated later.
self.spill_state.spill_schema = batch.schema();
Yes, this line seems to be redundant now -- I'd expect all aggregation modes to have the same spill schema (which is set by this PR), so it shouldn't depend on stream input anymore.
Thanks for confirming.
The lines are removed.
/// This helper function constructs such a schema:
/// `[group_col_1, group_col_2, ..., state_col_1, state_col_2, ...]`
/// so that partial aggregation data can be handled consistently.
fn build_partial_agg_schema(
Perhaps instead of the new helper we could reuse aggregates::create_schema?
I checked create_schema and it handles aggregates like MIN and MAX well, but it does not handle AVG, which has multiple intermediate states (partial sum, partial count).
If I'm not mistaken, it should for mode = AggregateMode::Partial -- in this case it also returns state_fields instead of the result field.
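For illustration only (the exact state field names DataFusion generates may differ), the Partial-mode schema for something like SELECT c1, avg(c2) ... GROUP BY c1 has the shape "group columns first, then one column per intermediate state":

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};

// Conceptual shape only: group column first, then avg's two state columns.
fn example_partial_avg_schema() -> SchemaRef {
    Arc::new(Schema::new(vec![
        Field::new("c1", DataType::Utf8, true),               // group column
        Field::new("avg(c2)[count]", DataType::UInt64, true), // avg state: count
        Field::new("avg(c2)[sum]", DataType::Float64, true),  // avg state: sum
    ]))
}
```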
Aaa..... 🤔
Thanks for the pointer. It does work.
LGTM, thank you @kosiew @2010YOUY01
Going to merge it tomorrow, in case anyone else would like to review it.
@@ -43,6 +38,10 @@ use crate::physical_plan::{
    ExecutionPlan, SendableRecordBatchStream,
};
use crate::prelude::SessionContext;
use std::any::Any;
minor: this import reordering can be reverted to leave the file unmodified
❤️
Thanks for the rapid fix, @kosiew!
Which issue does this PR close?
Closes #13949.
Rationale for this change
When an aggregation operator spills intermediate (partial) state to disk, it needs a schema that includes both the group-by columns and the partial-aggregator columns (e.g., partial sums, counts, etc.). Previously, the code used the original input schema for spilling, which does not match the additional columns representing aggregator states. As a result, reading back the spilled data caused a schema mismatch error.
This PR addresses that by introducing a partial aggregation schema that combines group columns and aggregator state columns, ensuring consistency when spilling and later reading the spilled data.
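As a rough sketch of the idea (not the PR's actual code; the helper name is illustrative), the spill schema is simply the group-by fields followed by the aggregate state fields:

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::{Field, Schema, SchemaRef};

// Illustrative only: concatenate group-by fields and per-aggregate
// intermediate state fields, mirroring the Partial-mode output schema.
fn partial_spill_schema(group_fields: Vec<Field>, state_fields: Vec<Field>) -> SchemaRef {
    let mut fields = group_fields;
    fields.extend(state_fields);
    Arc::new(Schema::new(fields))
}
```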
What changes are included in this PR?
Are these changes tested?
Yes
Are there any user-facing changes?
No