Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/expr/src/expr_rewriter/order_by.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Rewrite for order by expressions
19
20
use crate::expr::Alias;
21
use crate::expr_rewriter::normalize_col;
22
use crate::{expr::Sort, Cast, Expr, LogicalPlan, TryCast};
23
24
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
25
use datafusion_common::{Column, Result};
26
27
/// Rewrite sort on aggregate expressions to sort on the column of aggregate output
28
/// For example, `max(x)` is written to `col("max(x)")`
29
0
pub fn rewrite_sort_cols_by_aggs(
30
0
    sorts: impl IntoIterator<Item = impl Into<Sort>>,
31
0
    plan: &LogicalPlan,
32
0
) -> Result<Vec<Sort>> {
33
0
    sorts
34
0
        .into_iter()
35
0
        .map(|e| {
36
0
            let sort = e.into();
37
0
            Ok(Sort::new(
38
0
                rewrite_sort_col_by_aggs(sort.expr, plan)?,
39
0
                sort.asc,
40
0
                sort.nulls_first,
41
            ))
42
0
        })
43
0
        .collect()
44
0
}
45
46
0
fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
47
0
    let plan_inputs = plan.inputs();
48
0
49
0
    // Joins, and Unions are not yet handled (should have a projection
50
0
    // on top of them)
51
0
    if plan_inputs.len() == 1 {
52
0
        let proj_exprs = plan.expressions();
53
0
        rewrite_in_terms_of_projection(expr, proj_exprs, plan_inputs[0])
54
    } else {
55
0
        Ok(expr)
56
    }
57
0
}
58
59
/// Rewrites a sort expression in terms of the output of the previous [`LogicalPlan`]
60
///
61
/// Example:
62
///
63
/// Given an input expression such as `col(a) + col(b) + col(c)`
64
///
65
/// into `col(a) + col("b + c")`
66
///
67
/// Remember that:
68
/// 1. given a projection with exprs: [a, b + c]
69
/// 2. t produces an output schema with two columns "a", "b + c"
70
0
fn rewrite_in_terms_of_projection(
71
0
    expr: Expr,
72
0
    proj_exprs: Vec<Expr>,
73
0
    input: &LogicalPlan,
74
0
) -> Result<Expr> {
75
0
    // assumption is that each item in exprs, such as "b + c" is
76
0
    // available as an output column named "b + c"
77
0
    expr.transform(|expr| {
78
        // search for unnormalized names first such as "c1" (such as aliases)
79
0
        if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) {
80
0
            let (qualifier, field_name) = found.qualified_name();
81
0
            let col = Expr::Column(Column::new(qualifier, field_name));
82
0
            return Ok(Transformed::yes(col));
83
0
        }
84
85
        // if that doesn't work, try to match the expression as an
86
        // output column -- however first it must be "normalized"
87
        // (e.g. "c1" --> "t.c1") because that normalization is done
88
        // at the input of the aggregate.
89
90
0
        let normalized_expr = if let Ok(e) = normalize_col(expr.clone(), input) {
91
0
            e
92
        } else {
93
            // The expr is not based on Aggregate plan output. Skip it.
94
0
            return Ok(Transformed::no(expr));
95
        };
96
97
        // expr is an actual expr like min(t.c2), but we are looking
98
        // for a column with the same "MIN(C2)", so translate there
99
0
        let name = normalized_expr.schema_name().to_string();
100
0
101
0
        let search_col = Expr::Column(Column {
102
0
            relation: None,
103
0
            name,
104
0
        });
105
106
        // look for the column named the same as this expr
107
0
        if let Some(found) = proj_exprs.iter().find(|a| expr_match(&search_col, a)) {
108
0
            let found = found.clone();
109
0
            return Ok(Transformed::yes(match normalized_expr {
110
0
                Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast {
111
0
                    expr: Box::new(found),
112
0
                    data_type,
113
0
                }),
114
0
                Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast {
115
0
                    expr: Box::new(found),
116
0
                    data_type,
117
0
                }),
118
0
                _ => found,
119
            }));
120
0
        }
121
0
122
0
        Ok(Transformed::no(expr))
123
0
    })
124
0
    .data()
125
0
}
126
127
/// Does the underlying expr match e?
128
/// so avg(c) as average will match avgc
129
0
fn expr_match(needle: &Expr, expr: &Expr) -> bool {
130
    // check inside aliases
131
0
    if let Expr::Alias(Alias { expr, .. }) = &expr {
132
0
        expr.as_ref() == needle
133
    } else {
134
0
        expr == needle
135
    }
136
0
}
137
138
#[cfg(test)]
mod test {
    use std::ops::Add;
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};

    use crate::{
        cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast,
        LogicalPlanBuilder,
    };

    use super::*;
    use crate::test::function_stub::avg;
    use crate::test::function_stub::min;

    /// Sorting directly on top of an aggregate: every case here expects the
    /// sort expression to come back unchanged.
    #[test]
    fn rewrite_sort_cols_by_agg() {
        //  gby c1, agg: min(c2)
        let agg = make_input()
            .aggregate(
                // gby: c1
                vec![col("c1")],
                // agg: min(c2)
                vec![min(col("c2"))],
            )
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "c1 --> c1",
                input: sort(col("c1")),
                expected: sort(col("c1")),
            },
            TestCase {
                // description matches the expression actually exercised
                desc: "c1 + c1 --> c1 + c1",
                input: sort(col("c1") + col("c1")),
                expected: sort(col("c1") + col("c1")),
            },
            TestCase {
                desc: "min(c2) --> min(c2)",
                input: sort(min(col("c2"))),
                expected: sort(min(col("c2"))),
            },
            TestCase {
                desc: "c1 + min(c2) --> c1 + min(c2)",
                input: sort(col("c1") + min(col("c2"))),
                expected: sort(col("c1") + min(col("c2"))),
            },
        ];

        for case in cases {
            case.run(&agg)
        }
    }

    /// Sorting on top of a projection over an aggregate: aggregate
    /// expressions are expected to be replaced by references to the
    /// projection's output.
    #[test]
    fn rewrite_sort_cols_by_agg_alias() {
        let agg = make_input()
            .aggregate(
                // gby c1
                vec![col("c1")],
                // agg: min(c2), avg(c3)
                vec![min(col("c2")), avg(col("c3"))],
            )
            .unwrap()
            //  projects out an expression "c1" that is different than the column "c1"
            .project(vec![
                // c1 + 1 as c1,
                col("c1").add(lit(1)).alias("c1"),
                // min(c2)
                min(col("c2")),
                // avg("c3") as average
                avg(col("c3")).alias("average"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "c1 --> c1  -- column *named* c1 that came out of the projection, (not t.c1)",
                input: sort(col("c1")),
                // should be "c1" not t.c1
                expected: sort(col("c1")),
            },
            TestCase {
                desc: r#"min(c2) --> "min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(min(col("c2"))),
                expected: sort(col("min(t.c2)")),
            },
            TestCase {
                desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(col("c1") + min(col("c2"))),
                // should be "c1" not t.c1
                expected: sort(col("c1") + col("min(t.c2)")),
            },
            TestCase {
                desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#,
                input: sort(avg(col("c3"))),
                expected: sort(col("avg(t.c3)").alias("average")),
            },
        ];

        for case in cases {
            case.run(&agg)
        }
    }

    /// A `Cast` / `TryCast` wrapped around the sort expression must still be
    /// present after the rewrite.
    #[test]
    fn preserve_cast() {
        let plan = make_input()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "Cast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(cast(col("c2"), DataType::Int64)),
                expected: sort(cast(col("c2").alias("c2"), DataType::Int64)),
            },
            TestCase {
                desc: "TryCast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(try_cast(col("c2"), DataType::Int64)),
                expected: sort(try_cast(col("c2").alias("c2"), DataType::Int64)),
            },
        ];

        for case in cases {
            case.run(&plan)
        }
    }

    /// One rewrite scenario: a sort to rewrite and the sort expected back.
    struct TestCase {
        /// Human readable description, printed when the case runs
        desc: &'static str,
        /// Sort handed to `rewrite_sort_cols_by_aggs`
        input: Sort,
        /// Sort the rewrite is expected to produce
        expected: Sort,
    }

    impl TestCase {
        /// calls rewrite_sort_cols_by_aggs for `input` against `input_plan`
        /// and compares the single rewritten sort to `expected`
        fn run(self, input_plan: &LogicalPlan) {
            let Self {
                desc,
                input,
                expected,
            } = self;

            println!("running: '{desc}'");
            let mut exprs =
                rewrite_sort_cols_by_aggs(vec![input.clone()], input_plan).unwrap();

            assert_eq!(exprs.len(), 1);
            let rewritten = exprs.pop().unwrap();

            assert_eq!(
                rewritten, expected,
                "\n\ninput:{input:?}\nrewritten:{rewritten:?}\nexpected:{expected:?}\n"
            );
        }
    }

    /// Scan of a table: t(c1 int, c2 varchar, c3 float)
    fn make_input() -> LogicalPlanBuilder {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ]));
        let projection = None;
        LogicalPlanBuilder::scan(
            "t",
            Arc::new(LogicalTableSource::new(schema)),
            projection,
        )
        .unwrap()
    }

    /// Wraps `expr` in an ascending, nulls-first sort
    fn sort(expr: Expr) -> Sort {
        let asc = true;
        let nulls_first = true;
        expr.sort(asc, nulls_first)
    }
}