/Users/andrewlamb/Software/datafusion/datafusion/expr/src/expr_rewriter/order_by.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Rewrite for order by expressions |
19 | | |
20 | | use crate::expr::Alias; |
21 | | use crate::expr_rewriter::normalize_col; |
22 | | use crate::{expr::Sort, Cast, Expr, LogicalPlan, TryCast}; |
23 | | |
24 | | use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; |
25 | | use datafusion_common::{Column, Result}; |
26 | | |
27 | | /// Rewrite sort on aggregate expressions to sort on the column of aggregate output |
28 | | /// For example, `max(x)` is written to `col("max(x)")` |
29 | 0 | pub fn rewrite_sort_cols_by_aggs( |
30 | 0 | sorts: impl IntoIterator<Item = impl Into<Sort>>, |
31 | 0 | plan: &LogicalPlan, |
32 | 0 | ) -> Result<Vec<Sort>> { |
33 | 0 | sorts |
34 | 0 | .into_iter() |
35 | 0 | .map(|e| { |
36 | 0 | let sort = e.into(); |
37 | 0 | Ok(Sort::new( |
38 | 0 | rewrite_sort_col_by_aggs(sort.expr, plan)?, |
39 | 0 | sort.asc, |
40 | 0 | sort.nulls_first, |
41 | | )) |
42 | 0 | }) |
43 | 0 | .collect() |
44 | 0 | } |
45 | | |
46 | 0 | fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> { |
47 | 0 | let plan_inputs = plan.inputs(); |
48 | 0 |
|
49 | 0 | // Joins, and Unions are not yet handled (should have a projection |
50 | 0 | // on top of them) |
51 | 0 | if plan_inputs.len() == 1 { |
52 | 0 | let proj_exprs = plan.expressions(); |
53 | 0 | rewrite_in_terms_of_projection(expr, proj_exprs, plan_inputs[0]) |
54 | | } else { |
55 | 0 | Ok(expr) |
56 | | } |
57 | 0 | } |
58 | | |
59 | | /// Rewrites a sort expression in terms of the output of the previous [`LogicalPlan`] |
60 | | /// |
61 | | /// Example: |
62 | | /// |
63 | | /// Given an input expression such as `col(a) + col(b) + col(c)` |
64 | | /// |
65 | | /// into `col(a) + col("b + c")` |
66 | | /// |
67 | | /// Remember that: |
68 | | /// 1. given a projection with exprs: [a, b + c] |
69 | | /// 2. t produces an output schema with two columns "a", "b + c" |
70 | 0 | fn rewrite_in_terms_of_projection( |
71 | 0 | expr: Expr, |
72 | 0 | proj_exprs: Vec<Expr>, |
73 | 0 | input: &LogicalPlan, |
74 | 0 | ) -> Result<Expr> { |
75 | 0 | // assumption is that each item in exprs, such as "b + c" is |
76 | 0 | // available as an output column named "b + c" |
77 | 0 | expr.transform(|expr| { |
78 | | // search for unnormalized names first such as "c1" (such as aliases) |
79 | 0 | if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { |
80 | 0 | let (qualifier, field_name) = found.qualified_name(); |
81 | 0 | let col = Expr::Column(Column::new(qualifier, field_name)); |
82 | 0 | return Ok(Transformed::yes(col)); |
83 | 0 | } |
84 | | |
85 | | // if that doesn't work, try to match the expression as an |
86 | | // output column -- however first it must be "normalized" |
87 | | // (e.g. "c1" --> "t.c1") because that normalization is done |
88 | | // at the input of the aggregate. |
89 | | |
90 | 0 | let normalized_expr = if let Ok(e) = normalize_col(expr.clone(), input) { |
91 | 0 | e |
92 | | } else { |
93 | | // The expr is not based on Aggregate plan output. Skip it. |
94 | 0 | return Ok(Transformed::no(expr)); |
95 | | }; |
96 | | |
97 | | // expr is an actual expr like min(t.c2), but we are looking |
98 | | // for a column with the same "MIN(C2)", so translate there |
99 | 0 | let name = normalized_expr.schema_name().to_string(); |
100 | 0 |
|
101 | 0 | let search_col = Expr::Column(Column { |
102 | 0 | relation: None, |
103 | 0 | name, |
104 | 0 | }); |
105 | | |
106 | | // look for the column named the same as this expr |
107 | 0 | if let Some(found) = proj_exprs.iter().find(|a| expr_match(&search_col, a)) { |
108 | 0 | let found = found.clone(); |
109 | 0 | return Ok(Transformed::yes(match normalized_expr { |
110 | 0 | Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast { |
111 | 0 | expr: Box::new(found), |
112 | 0 | data_type, |
113 | 0 | }), |
114 | 0 | Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast { |
115 | 0 | expr: Box::new(found), |
116 | 0 | data_type, |
117 | 0 | }), |
118 | 0 | _ => found, |
119 | | })); |
120 | 0 | } |
121 | 0 |
|
122 | 0 | Ok(Transformed::no(expr)) |
123 | 0 | }) |
124 | 0 | .data() |
125 | 0 | } |
126 | | |
127 | | /// Does the underlying expr match e? |
128 | | /// so avg(c) as average will match avgc |
129 | 0 | fn expr_match(needle: &Expr, expr: &Expr) -> bool { |
130 | | // check inside aliases |
131 | 0 | if let Expr::Alias(Alias { expr, .. }) = &expr { |
132 | 0 | expr.as_ref() == needle |
133 | | } else { |
134 | 0 | expr == needle |
135 | | } |
136 | 0 | } |
137 | | |
#[cfg(test)]
mod test {
    use std::ops::Add;
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};

    use crate::{
        cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast,
        LogicalPlanBuilder,
    };

    use super::*;
    use crate::test::function_stub::avg;
    use crate::test::function_stub::min;

    /// Sorting directly on top of an aggregate: none of these expressions
    /// are expected to be rewritten.
    #[test]
    fn rewrite_sort_cols_by_agg() {
        // gby c1, agg: min(c2)
        let agg = make_input()
            .aggregate(
                // gby: c1
                vec![col("c1")],
                // agg: min(c2)
                vec![min(col("c2"))],
            )
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "c1 --> c1",
                input: sort(col("c1")),
                expected: sort(col("c1")),
            },
            TestCase {
                desc: "c1 + c2 --> c1 + c2",
                // the operands must be c1 and c2 to match the description
                input: sort(col("c1") + col("c2")),
                expected: sort(col("c1") + col("c2")),
            },
            TestCase {
                desc: r#"min(c2) --> "min(c2)"#,
                input: sort(min(col("c2"))),
                expected: sort(min(col("c2"))),
            },
            TestCase {
                desc: r#"c1 + min(c2) --> "c1 + min(c2)"#,
                input: sort(col("c1") + min(col("c2"))),
                expected: sort(col("c1") + min(col("c2"))),
            },
        ];

        for case in cases {
            case.run(&agg)
        }
    }

    /// Sorting on top of a projection over an aggregate: aggregate
    /// expressions are expected to be rewritten to the projection's output
    /// columns.
    #[test]
    fn rewrite_sort_cols_by_agg_alias() {
        let agg = make_input()
            .aggregate(
                // gby c1
                vec![col("c1")],
                // agg: min(c2), avg(c3)
                vec![min(col("c2")), avg(col("c3"))],
            )
            .unwrap()
            // projects out an expression "c1" that is different than the column "c1"
            .project(vec![
                // c1 + 1 as c1,
                col("c1").add(lit(1)).alias("c1"),
                // min(c2)
                min(col("c2")),
                // avg("c3") as average
                avg(col("c3")).alias("average"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "c1 --> c1 -- column *named* c1 that came out of the projection, (not t.c1)",
                input: sort(col("c1")),
                // should be "c1" not t.c1
                expected: sort(col("c1")),
            },
            TestCase {
                desc: r#"min(c2) --> "min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(min(col("c2"))),
                expected: sort(col("min(t.c2)")),
            },
            TestCase {
                desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(col("c1") + min(col("c2"))),
                // should be "c1" not t.c1
                expected: sort(col("c1") + col("min(t.c2)")),
            },
            TestCase {
                desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#,
                input: sort(avg(col("c3"))),
                expected: sort(col("avg(t.c3)").alias("average")),
            },
        ];

        for case in cases {
            case.run(&agg)
        }
    }

    /// `Cast` / `TryCast` wrappers around a sort expression are expected to
    /// survive the rewrite.
    #[test]
    fn preserve_cast() {
        let plan = make_input()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "Cast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(cast(col("c2"), DataType::Int64)),
                expected: sort(cast(col("c2").alias("c2"), DataType::Int64)),
            },
            TestCase {
                desc: "TryCast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(try_cast(col("c2"), DataType::Int64)),
                expected: sort(try_cast(col("c2").alias("c2"), DataType::Int64)),
            },
        ];

        for case in cases {
            case.run(&plan)
        }
    }

    /// One rewrite scenario: `input` is rewritten against a plan and the
    /// result is compared to `expected`; `desc` is printed while running.
    struct TestCase {
        desc: &'static str,
        input: Sort,
        expected: Sort,
    }

    impl TestCase {
        /// calls rewrite_sort_cols_by_aggs for expr and compares it to expected_expr
        fn run(self, input_plan: &LogicalPlan) {
            let Self {
                desc,
                input,
                expected,
            } = self;

            println!("running: '{desc}'");
            let mut exprs =
                rewrite_sort_cols_by_aggs(vec![input.clone()], input_plan).unwrap();

            // exactly one sort went in, so exactly one must come out
            assert_eq!(exprs.len(), 1);
            let rewritten = exprs.pop().unwrap();

            assert_eq!(
                rewritten, expected,
                "\n\ninput:{input:?}\nrewritten:{rewritten:?}\nexpected:{expected:?}\n"
            );
        }
    }

    /// Scan of a table: t(c1 int, c2 varchar, c3 float)
    fn make_input() -> LogicalPlanBuilder {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ]));
        let projection = None;
        LogicalPlanBuilder::scan(
            "t",
            Arc::new(LogicalTableSource::new(schema)),
            projection,
        )
        .unwrap()
    }

    /// Wraps `expr` in an ascending, nulls-first [`Sort`].
    fn sort(expr: Expr) -> Sort {
        let asc = true;
        let nulls_first = true;
        expr.sort(asc, nulls_first)
    }
}