/Users/andrewlamb/Software/datafusion/datafusion/expr/src/utils.rs
Line | Count | Source
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Expression utilities |
19 | | |
20 | | use std::cmp::Ordering; |
21 | | use std::collections::{HashMap, HashSet}; |
22 | | use std::ops::Deref; |
23 | | use std::sync::Arc; |
24 | | |
25 | | use crate::expr::{Alias, Sort, WildcardOptions, WindowFunction}; |
26 | | use crate::expr_rewriter::strip_outer_reference; |
27 | | use crate::{ |
28 | | and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, |
29 | | }; |
30 | | use datafusion_expr_common::signature::{Signature, TypeSignature}; |
31 | | |
32 | | use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; |
33 | | use datafusion_common::tree_node::{ |
34 | | Transformed, TransformedResult, TreeNode, TreeNodeRecursion, |
35 | | }; |
36 | | use datafusion_common::utils::get_at_indices; |
37 | | use datafusion_common::{ |
38 | | internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, |
39 | | DataFusionError, Result, TableReference, |
40 | | }; |
41 | | |
42 | | use indexmap::IndexSet; |
43 | | use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem}; |
44 | | |
45 | | pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; |
46 | | |
47 | | /// The value to which `COUNT(*)` is expanded in
48 | | /// `COUNT(<constant>)` expressions |
49 | | pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; |
50 | | |
51 | | /// Recursively walk a list of expression trees, collecting the unique set of columns |
52 | | /// referenced in the expression |
53 | | #[deprecated(since = "40.0.0", note = "use Expr::add_column_refs instead")]
54 | 0 | pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result<()> { |
55 | 0 | for e in expr { |
56 | 0 | expr_to_columns(e, accum)?; |
57 | | } |
58 | 0 | Ok(()) |
59 | 0 | } |
60 | | |
61 | | /// Count the number of distinct exprs in a list of group by expressions. If the |
62 | | /// first element is a `GroupingSet` expression then it must be the only expr. |
63 | 0 | pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> { |
64 | 0 | grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len()) |
65 | 0 | } |
66 | | |
67 | | /// The [power set] (or powerset) of a set S is the set of all subsets of S, \ |
68 | | /// including the empty set and S itself. |
69 | | /// |
70 | | /// Example: |
71 | | /// |
72 | | /// If S is the set {x, y, z}, then all the subsets of S are \ |
73 | | /// {} \ |
74 | | /// {x} \ |
75 | | /// {y} \ |
76 | | /// {z} \ |
77 | | /// {x, y} \ |
78 | | /// {x, z} \ |
79 | | /// {y, z} \ |
80 | | /// {x, y, z} \ |
81 | | /// and hence the power set of S is {{}, {x}, {y}, {z}, {x, y}, {x, z}, {y, z}, {x, y, z}}. |
82 | | /// |
83 | | /// [power set]: https://en.wikipedia.org/wiki/Power_set |
84 | 0 | fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>, String> { |
85 | 0 | if slice.len() >= 64 { |
86 | 0 | return Err("The size of the set must be less than 64.".into()); |
87 | 0 | } |
88 | 0 |
89 | 0 | let mut v = Vec::new(); |
90 | 0 | for mask in 0..(1 << slice.len()) { |
91 | 0 | let mut ss = vec![]; |
92 | 0 | let mut bitset = mask; |
93 | 0 | while bitset > 0 { |
94 | 0 | let rightmost: u64 = bitset & !(bitset - 1); |
95 | 0 | let idx = rightmost.trailing_zeros(); |
96 | 0 | let item = slice.get(idx as usize).unwrap(); |
97 | 0 | ss.push(item); |
98 | 0 | // zero the trailing bit |
99 | 0 | bitset &= bitset - 1; |
100 | 0 | } |
101 | 0 | v.push(ss); |
102 | | } |
103 | 0 | Ok(v) |
104 | 0 | } |
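The loop above enumerates subsets by bitmask: each integer in `0..2^n` selects one subset, and `bitset &= bitset - 1` clears the lowest set bit once its index has been consumed. A minimal standalone sketch of the same technique (`powerset_demo` is an illustrative name, not a DataFusion API):

```rust
// Bitmask power-set enumeration, mirroring the function above
// (the error handling for oversized inputs is reduced to an assert).
fn powerset_demo<T>(slice: &[T]) -> Vec<Vec<&T>> {
    assert!(slice.len() < 64, "the mask must fit in a u64");
    let mut result = Vec::with_capacity(1 << slice.len());
    for mask in 0u64..(1 << slice.len()) {
        let mut subset = Vec::new();
        let mut bits = mask;
        while bits > 0 {
            // the lowest set bit selects the index of the next element
            subset.push(&slice[bits.trailing_zeros() as usize]);
            bits &= bits - 1; // clear the lowest set bit
        }
        result.push(subset);
    }
    result
}

fn main() {
    let sets = powerset_demo(&["x", "y", "z"]);
    assert_eq!(sets.len(), 8); // 2^3 subsets, including {} and {x, y, z}
}
```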
105 | | |
106 | | /// Check the number of expressions contained in the grouping set
107 | 0 | fn check_grouping_set_size_limit(size: usize) -> Result<()> { |
108 | 0 | let max_grouping_set_size = 65535; |
109 | 0 | if size > max_grouping_set_size { |
110 | 0 | return plan_err!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}"); |
111 | 0 | } |
112 | 0 |
113 | 0 | Ok(()) |
114 | 0 | } |
115 | | |
116 | | /// Check the number of grouping sets contained in the grouping sets
117 | 0 | fn check_grouping_sets_size_limit(size: usize) -> Result<()> { |
118 | 0 | let max_grouping_sets_size = 4096; |
119 | 0 | if size > max_grouping_sets_size { |
120 | 0 | return plan_err!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}"); |
121 | 0 | } |
122 | 0 |
123 | 0 | Ok(()) |
124 | 0 | } |
125 | | |
126 | | /// Merge two grouping_set |
127 | | /// |
128 | | /// # Example |
129 | | /// ```text |
130 | | /// (A, B), (C, D) -> (A, B, C, D) |
131 | | /// ``` |
132 | | /// |
133 | | /// # Error |
134 | | /// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit |
135 | | /// |
136 | | /// [`DataFusionError`]: datafusion_common::DataFusionError |
137 | 0 | fn merge_grouping_set<T: Clone>(left: &[T], right: &[T]) -> Result<Vec<T>> { |
138 | 0 | check_grouping_set_size_limit(left.len() + right.len())?; |
139 | 0 | Ok(left.iter().chain(right.iter()).cloned().collect()) |
140 | 0 | } |
141 | | |
142 | | /// Compute the cross product of two grouping_sets |
143 | | /// |
144 | | /// # Example |
145 | | /// ```text |
146 | | /// [(A, B), (C, D)], [(E), (F)] -> [(A, B, E), (A, B, F), (C, D, E), (C, D, F)] |
147 | | /// ``` |
148 | | /// |
149 | | /// # Error |
150 | | /// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit |
151 | | /// - [`DataFusionError`]: The number of grouping_set in grouping_sets exceeds the maximum limit |
152 | | /// |
153 | | /// [`DataFusionError`]: datafusion_common::DataFusionError |
154 | 0 | fn cross_join_grouping_sets<T: Clone>( |
155 | 0 | left: &[Vec<T>], |
156 | 0 | right: &[Vec<T>], |
157 | 0 | ) -> Result<Vec<Vec<T>>> { |
158 | 0 | let grouping_sets_size = left.len() * right.len(); |
159 | 0 |
160 | 0 | check_grouping_sets_size_limit(grouping_sets_size)?; |
161 | | |
162 | 0 | let mut result = Vec::with_capacity(grouping_sets_size); |
163 | 0 | for le in left { |
164 | 0 | for re in right { |
165 | 0 | result.push(merge_grouping_set(le, re)?); |
166 | | } |
167 | | } |
168 | 0 | Ok(result) |
169 | 0 | } |
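A small standalone sketch of the cross product above, with the size-limit checks omitted (`cross_join_demo` is a hypothetical name):

```rust
// Cross product of two lists of grouping sets: every left set is
// concatenated with every right set, as in the function above.
fn cross_join_demo<T: Clone>(left: &[Vec<T>], right: &[Vec<T>]) -> Vec<Vec<T>> {
    let mut result = Vec::with_capacity(left.len() * right.len());
    for le in left {
        for re in right {
            // merge step: simple concatenation of the two sets
            result.push(le.iter().chain(re.iter()).cloned().collect());
        }
    }
    result
}

fn main() {
    let left = vec![vec!["A", "B"], vec!["C", "D"]];
    let right = vec![vec!["E"], vec!["F"]];
    // [(A, B, E), (A, B, F), (C, D, E), (C, D, F)]
    assert_eq!(cross_join_demo(&left, &right).len(), 4);
}
```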
170 | | |
171 | | /// Convert multiple grouping expressions into one [`GroupingSet::GroupingSets`].\
172 | | /// If the grouping expressions do not contain [`Expr::GroupingSet`] or contain only one expression,\
173 | | /// no conversion is performed.
174 | | /// |
175 | | /// e.g. |
176 | | /// |
177 | | /// person.id,\ |
178 | | /// GROUPING SETS ((person.age, person.salary),(person.age)),\ |
179 | | /// ROLLUP(person.state, person.birth_date) |
180 | | /// |
181 | | /// => |
182 | | /// |
183 | | /// GROUPING SETS (\ |
184 | | /// (person.id, person.age, person.salary),\ |
185 | | /// (person.id, person.age, person.salary, person.state),\ |
186 | | /// (person.id, person.age, person.salary, person.state, person.birth_date),\ |
187 | | /// (person.id, person.age),\ |
188 | | /// (person.id, person.age, person.state),\ |
189 | | /// (person.id, person.age, person.state, person.birth_date)\ |
190 | | /// ) |
191 | 0 | pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> { |
192 | 0 | let has_grouping_set = group_expr |
193 | 0 | .iter() |
194 | 0 | .any(|expr| matches!(expr, Expr::GroupingSet(_))); |
195 | 0 | if !has_grouping_set || group_expr.len() == 1 { |
196 | 0 | return Ok(group_expr); |
197 | 0 | } |
198 | // only process mixed grouping sets
199 | 0 | let partial_sets = group_expr |
200 | 0 | .iter() |
201 | 0 | .map(|expr| { |
202 | 0 | let exprs = match expr { |
203 | 0 | Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => { |
204 | 0 | check_grouping_sets_size_limit(grouping_sets.len())?; |
205 | 0 | grouping_sets.iter().map(|e| e.iter().collect()).collect() |
206 | | } |
207 | 0 | Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => { |
208 | 0 | let grouping_sets = powerset(group_exprs) |
209 | 0 | .map_err(|e| plan_datafusion_err!("{}", e))?; |
210 | 0 | check_grouping_sets_size_limit(grouping_sets.len())?; |
211 | 0 | grouping_sets |
212 | | } |
213 | 0 | Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => { |
214 | 0 | let size = group_exprs.len(); |
215 | 0 | let slice = group_exprs.as_slice(); |
216 | 0 | check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?; |
217 | 0 | (0..(size + 1)) |
218 | 0 | .map(|i| slice[0..i].iter().collect()) |
219 | 0 | .collect() |
220 | | } |
221 | 0 | expr => vec![vec![expr]], |
222 | | }; |
223 | 0 | Ok(exprs) |
224 | 0 | }) |
225 | 0 | .collect::<Result<Vec<_>>>()?; |
226 | | |
227 | | // cross join |
228 | 0 | let grouping_sets = partial_sets |
229 | 0 | .into_iter() |
230 | 0 | .map(Ok) |
231 | 0 | .reduce(|l, r| cross_join_grouping_sets(&l?, &r?)) |
232 | 0 | .transpose()? |
233 | 0 | .map(|e| { |
234 | 0 | e.into_iter() |
235 | 0 | .map(|e| e.into_iter().cloned().collect()) |
236 | 0 | .collect() |
237 | 0 | }) |
238 | 0 | .unwrap_or_default(); |
239 | 0 |
240 | 0 | Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets( |
241 | 0 | grouping_sets, |
242 | 0 | ))]) |
243 | 0 | } |
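The ROLLUP arm above expands a column list into all of its prefixes, which is the essence of `ROLLUP(a, b, c)`. A standalone sketch of just that step (`rollup_prefixes` is an illustrative name):

```rust
// ROLLUP expansion: the empty prefix, then each successively longer prefix,
// matching the `(0..(size + 1)).map(|i| slice[0..i].iter().collect())` arm above.
fn rollup_prefixes<T>(exprs: &[T]) -> Vec<&[T]> {
    (0..=exprs.len()).map(|i| &exprs[0..i]).collect()
}

fn main() {
    let prefixes = rollup_prefixes(&["state", "birth_date"]);
    // [], ["state"], ["state", "birth_date"]
    assert_eq!(prefixes.len(), 3);
}
```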
244 | | |
245 | | /// Find all distinct exprs in a list of group by expressions. If the |
246 | | /// first element is a `GroupingSet` expression then it must be the only expr. |
247 | 0 | pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> { |
248 | 0 | if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() { |
249 | 0 | if group_expr.len() > 1 { |
250 | 0 | return plan_err!( |
251 | 0 | "Invalid group by expressions, GroupingSet must be the only expression" |
252 | 0 | ); |
253 | 0 | } |
254 | 0 | Ok(grouping_set.distinct_expr()) |
255 | | } else { |
256 | 0 | Ok(group_expr |
257 | 0 | .iter() |
258 | 0 | .collect::<IndexSet<_>>() |
259 | 0 | .into_iter() |
260 | 0 | .collect()) |
261 | | } |
262 | 0 | } |
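A brief usage sketch for the plain (non-grouping-set) path, assuming the public exports of this crate (`datafusion_expr`):

```rust
use datafusion_common::Result;
use datafusion_expr::col;
use datafusion_expr::utils::grouping_set_to_exprlist;

fn main() -> Result<()> {
    // duplicates are removed while the original order is preserved
    let group_expr = vec![col("a"), col("b"), col("a")];
    let distinct = grouping_set_to_exprlist(&group_expr)?;
    assert_eq!(distinct.len(), 2);
    Ok(())
}
```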
263 | | |
264 | | /// Recursively walk an expression tree, collecting the unique set of columns |
265 | | /// referenced in the expression |
266 | 0 | pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> { |
267 | 0 | expr.apply(|expr| { |
268 | 0 | match expr { |
269 | 0 | Expr::Column(qc) => { |
270 | 0 | accum.insert(qc.clone()); |
271 | 0 | } |
272 | | // Use explicit pattern match instead of a default |
273 | | // implementation, so that in the future if someone adds |
274 | | // new Expr types, they will check here as well |
275 | | Expr::Unnest(_) |
276 | | | Expr::ScalarVariable(_, _) |
277 | | | Expr::Alias(_) |
278 | | | Expr::Literal(_) |
279 | | | Expr::BinaryExpr { .. } |
280 | | | Expr::Like { .. } |
281 | | | Expr::SimilarTo { .. } |
282 | | | Expr::Not(_) |
283 | | | Expr::IsNotNull(_) |
284 | | | Expr::IsNull(_) |
285 | | | Expr::IsTrue(_) |
286 | | | Expr::IsFalse(_) |
287 | | | Expr::IsUnknown(_) |
288 | | | Expr::IsNotTrue(_) |
289 | | | Expr::IsNotFalse(_) |
290 | | | Expr::IsNotUnknown(_) |
291 | | | Expr::Negative(_) |
292 | | | Expr::Between { .. } |
293 | | | Expr::Case { .. } |
294 | | | Expr::Cast { .. } |
295 | | | Expr::TryCast { .. } |
296 | | | Expr::ScalarFunction(..) |
297 | | | Expr::WindowFunction { .. } |
298 | | | Expr::AggregateFunction { .. } |
299 | | | Expr::GroupingSet(_) |
300 | | | Expr::InList { .. } |
301 | | | Expr::Exists { .. } |
302 | | | Expr::InSubquery(_) |
303 | | | Expr::ScalarSubquery(_) |
304 | | | Expr::Wildcard { .. } |
305 | | | Expr::Placeholder(_) |
306 | 0 | | Expr::OuterReferenceColumn { .. } => {} |
307 | | } |
308 | 0 | Ok(TreeNodeRecursion::Continue) |
309 | 0 | }) |
310 | 0 | .map(|_| ()) |
311 | 0 | } |
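A short usage sketch, again assuming this crate's public exports:

```rust
use std::collections::HashSet;

use datafusion_common::{Column, Result};
use datafusion_expr::utils::expr_to_columns;
use datafusion_expr::{col, lit};

fn main() -> Result<()> {
    // a = 1 AND b = a
    let expr = col("a").eq(lit(1)).and(col("b").eq(col("a")));
    let mut accum: HashSet<Column> = HashSet::new();
    expr_to_columns(&expr, &mut accum)?;
    // `a` appears twice but is collected only once
    assert_eq!(accum.len(), 2);
    Ok(())
}
```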
312 | | |
313 | | /// Find excluded columns in the schema, if any.
314 | | /// For example, `SELECT * EXCLUDE(col1, col2)` would return `vec![col1, col2]`
315 | 0 | fn get_excluded_columns( |
316 | 0 | opt_exclude: Option<&ExcludeSelectItem>, |
317 | 0 | opt_except: Option<&ExceptSelectItem>, |
318 | 0 | schema: &DFSchema, |
319 | 0 | qualifier: Option<&TableReference>, |
320 | 0 | ) -> Result<Vec<Column>> { |
321 | 0 | let mut idents = vec![]; |
322 | 0 | if let Some(excepts) = opt_except { |
323 | 0 | idents.push(&excepts.first_element); |
324 | 0 | idents.extend(&excepts.additional_elements); |
325 | 0 | } |
326 | 0 | if let Some(exclude) = opt_exclude { |
327 | 0 | match exclude { |
328 | 0 | ExcludeSelectItem::Single(ident) => idents.push(ident), |
329 | 0 | ExcludeSelectItem::Multiple(idents_inner) => idents.extend(idents_inner), |
330 | | } |
331 | 0 | } |
332 | | // Excluded columns should be unique |
333 | 0 | let n_elem = idents.len(); |
334 | 0 | let unique_idents = idents.into_iter().collect::<HashSet<_>>(); |
335 | 0 | // If the HashSet size and the vector length differ, some of the excluded
336 | 0 | // columns are not unique. In this case return an error.
337 | 0 | if n_elem != unique_idents.len() { |
338 | 0 | return plan_err!("EXCLUDE or EXCEPT contains duplicate column names"); |
339 | 0 | } |
340 | 0 |
341 | 0 | let mut result = vec![]; |
342 | 0 | for ident in unique_idents.into_iter() { |
343 | 0 | let col_name = ident.value.as_str(); |
344 | 0 | let (qualifier, field) = schema.qualified_field_with_name(qualifier, col_name)?; |
345 | 0 | result.push(Column::from((qualifier, field))); |
346 | | } |
347 | 0 | Ok(result) |
348 | 0 | } |
349 | | |
350 | | /// Returns all `Expr`s in the schema, except the `Column`s in the `columns_to_skip` |
351 | 0 | fn get_exprs_except_skipped( |
352 | 0 | schema: &DFSchema, |
353 | 0 | columns_to_skip: HashSet<Column>, |
354 | 0 | ) -> Vec<Expr> { |
355 | 0 | if columns_to_skip.is_empty() { |
356 | 0 | schema.iter().map(Expr::from).collect::<Vec<Expr>>() |
357 | | } else { |
358 | 0 | schema |
359 | 0 | .columns() |
360 | 0 | .iter() |
361 | 0 | .filter_map(|c| { |
362 | 0 | if !columns_to_skip.contains(c) { |
363 | 0 | Some(Expr::Column(c.clone())) |
364 | | } else { |
365 | 0 | None |
366 | | } |
367 | 0 | }) |
368 | 0 | .collect::<Vec<Expr>>() |
369 | | } |
370 | 0 | } |
371 | | |
372 | | /// Resolves an `Expr::Wildcard` to a collection of `Expr::Column`'s. |
373 | 0 | pub fn expand_wildcard( |
374 | 0 | schema: &DFSchema, |
375 | 0 | plan: &LogicalPlan, |
376 | 0 | wildcard_options: Option<&WildcardOptions>, |
377 | 0 | ) -> Result<Vec<Expr>> { |
378 | 0 | let using_columns = plan.using_columns()?; |
379 | 0 | let mut columns_to_skip = using_columns |
380 | 0 | .into_iter() |
381 | 0 | // For each USING JOIN condition, only expand to one of each join column in projection |
382 | 0 | .flat_map(|cols| { |
383 | 0 | let mut cols = cols.into_iter().collect::<Vec<_>>(); |
384 | 0 | // sort join columns to make sure we consistently keep the same |
385 | 0 | // qualified column |
386 | 0 | cols.sort(); |
387 | 0 | let mut out_column_names: HashSet<String> = HashSet::new(); |
388 | 0 | cols.into_iter() |
389 | 0 | .filter_map(|c| { |
390 | 0 | if out_column_names.contains(&c.name) { |
391 | 0 | Some(c) |
392 | | } else { |
393 | 0 | out_column_names.insert(c.name); |
394 | 0 | None |
395 | | } |
396 | 0 | }) |
397 | 0 | .collect::<Vec<_>>() |
398 | 0 | }) |
399 | 0 | .collect::<HashSet<_>>(); |
400 | 0 | let excluded_columns = if let Some(WildcardOptions { |
401 | 0 | exclude: opt_exclude, |
402 | 0 | except: opt_except, |
403 | | .. |
404 | 0 | }) = wildcard_options |
405 | | { |
406 | 0 | get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, None)? |
407 | | } else { |
408 | 0 | vec![] |
409 | | }; |
410 | | // Add each excluded `Column` to columns_to_skip |
411 | 0 | columns_to_skip.extend(excluded_columns); |
412 | 0 | Ok(get_exprs_except_skipped(schema, columns_to_skip)) |
413 | 0 | } |
414 | | |
415 | | /// Resolves an `Expr::Wildcard` to a collection of qualified `Expr::Column`'s. |
416 | 0 | pub fn expand_qualified_wildcard( |
417 | 0 | qualifier: &TableReference, |
418 | 0 | schema: &DFSchema, |
419 | 0 | wildcard_options: Option<&WildcardOptions>, |
420 | 0 | ) -> Result<Vec<Expr>> { |
421 | 0 | let qualified_indices = schema.fields_indices_with_qualified(qualifier); |
422 | 0 | let projected_func_dependencies = schema |
423 | 0 | .functional_dependencies() |
424 | 0 | .project_functional_dependencies(&qualified_indices, qualified_indices.len()); |
425 | 0 | let fields_with_qualified = get_at_indices(schema.fields(), &qualified_indices)?; |
426 | 0 | if fields_with_qualified.is_empty() { |
427 | 0 | return plan_err!("Invalid qualifier {qualifier}"); |
428 | 0 | } |
429 | 0 |
430 | 0 | let qualified_schema = Arc::new(Schema::new(fields_with_qualified)); |
431 | 0 | let qualified_dfschema = |
432 | 0 | DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)? |
433 | 0 | .with_functional_dependencies(projected_func_dependencies)?; |
434 | 0 | let excluded_columns = if let Some(WildcardOptions { |
435 | 0 | exclude: opt_exclude, |
436 | 0 | except: opt_except, |
437 | | .. |
438 | 0 | }) = wildcard_options |
439 | | { |
440 | 0 | get_excluded_columns( |
441 | 0 | opt_exclude.as_ref(), |
442 | 0 | opt_except.as_ref(), |
443 | 0 | schema, |
444 | 0 | Some(qualifier), |
445 | 0 | )? |
446 | | } else { |
447 | 0 | vec![] |
448 | | }; |
449 | | // Add each excluded `Column` to columns_to_skip |
450 | 0 | let mut columns_to_skip = HashSet::new(); |
451 | 0 | columns_to_skip.extend(excluded_columns); |
452 | 0 | Ok(get_exprs_except_skipped( |
453 | 0 | &qualified_dfschema, |
454 | 0 | columns_to_skip, |
455 | 0 | )) |
456 | 0 | } |
457 | | |
458 | | /// A `(Sort, bool)` pair for each sort expression of a window (coming from either PARTITION BY or ORDER BY columns);
459 | | /// if the bool is true the `Sort` comes from a `PARTITION BY` column, if false from an `ORDER BY` column
460 | | type WindowSortKey = Vec<(Sort, bool)>; |
461 | | |
462 | | /// Generate a sort key for a given window expr's partition_by and order_by expr |
463 | 0 | pub fn generate_sort_key( |
464 | 0 | partition_by: &[Expr], |
465 | 0 | order_by: &[Sort], |
466 | 0 | ) -> Result<WindowSortKey> { |
467 | 0 | let normalized_order_by_keys = order_by |
468 | 0 | .iter() |
469 | 0 | .map(|e| { |
470 | 0 | let Sort { expr, .. } = e; |
471 | 0 | Sort::new(expr.clone(), true, false) |
472 | 0 | }) |
473 | 0 | .collect::<Vec<_>>(); |
474 | 0 |
475 | 0 | let mut final_sort_keys = vec![]; |
476 | 0 | let mut is_partition_flag = vec![]; |
477 | 0 | partition_by.iter().for_each(|e| { |
478 | 0 | // By default, create sort key with ASC is true and NULLS LAST to be consistent with |
479 | 0 | // PostgreSQL's rule: https://www.postgresql.org/docs/current/queries-order.html |
480 | 0 | let e = e.clone().sort(true, false); |
481 | 0 | if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) { |
482 | 0 | let order_by_key = &order_by[pos]; |
483 | 0 | if !final_sort_keys.contains(order_by_key) { |
484 | 0 | final_sort_keys.push(order_by_key.clone()); |
485 | 0 | is_partition_flag.push(true); |
486 | 0 | } |
487 | 0 | } else if !final_sort_keys.contains(&e) { |
488 | 0 | final_sort_keys.push(e); |
489 | 0 | is_partition_flag.push(true); |
490 | 0 | } |
491 | 0 | }); |
492 | 0 |
493 | 0 | order_by.iter().for_each(|e| { |
494 | 0 | if !final_sort_keys.contains(e) { |
495 | 0 | final_sort_keys.push(e.clone()); |
496 | 0 | is_partition_flag.push(false); |
497 | 0 | } |
498 | 0 | }); |
499 | 0 | let res = final_sort_keys |
500 | 0 | .into_iter() |
501 | 0 | .zip(is_partition_flag) |
502 | 0 | .collect::<Vec<_>>(); |
503 | 0 | Ok(res) |
504 | 0 | } |
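A simplified sketch of the merge performed above, with plain strings standing in for `Sort` expressions; it ignores the ASC/NULLS normalization, which in the real function lets a PARTITION BY key adopt a matching ORDER BY key's options (`merge_sort_keys` is a hypothetical name):

```rust
// PARTITION BY keys come first and are flagged `true`; ORDER BY keys not
// already present follow, flagged `false`. Duplicates are kept only once.
fn merge_sort_keys(partition_by: &[&str], order_by: &[&str]) -> Vec<(String, bool)> {
    let mut keys: Vec<(String, bool)> = Vec::new();
    for p in partition_by {
        if !keys.iter().any(|(k, _)| k == p) {
            keys.push((p.to_string(), true));
        }
    }
    for o in order_by {
        if !keys.iter().any(|(k, _)| k == o) {
            keys.push((o.to_string(), false));
        }
    }
    keys
}

fn main() {
    let keys = merge_sort_keys(&["age", "name"], &["name", "created_at"]);
    // "name" is kept once (from PARTITION BY); "created_at" comes from ORDER BY
    assert_eq!(keys.len(), 3);
    assert_eq!(keys[2], ("created_at".to_string(), false));
}
```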
505 | | |
506 | | /// Compare the sort expr as PostgreSQL's common_prefix_cmp(): |
507 | | /// <https://github.com/postgres/postgres/blob/master/src/backend/optimizer/plan/planner.c> |
508 | 0 | pub fn compare_sort_expr( |
509 | 0 | sort_expr_a: &Sort, |
510 | 0 | sort_expr_b: &Sort, |
511 | 0 | schema: &DFSchemaRef, |
512 | 0 | ) -> Ordering { |
513 | 0 | let Sort { |
514 | 0 | expr: expr_a, |
515 | 0 | asc: asc_a, |
516 | 0 | nulls_first: nulls_first_a, |
517 | 0 | } = sort_expr_a; |
518 | 0 |
519 | 0 | let Sort { |
520 | 0 | expr: expr_b, |
521 | 0 | asc: asc_b, |
522 | 0 | nulls_first: nulls_first_b, |
523 | 0 | } = sort_expr_b; |
524 | 0 |
525 | 0 | let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema); |
526 | 0 | let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema); |
527 | 0 | for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) { |
528 | 0 | match idx_a.cmp(idx_b) { |
529 | | Ordering::Less => { |
530 | 0 | return Ordering::Less; |
531 | | } |
532 | | Ordering::Greater => { |
533 | 0 | return Ordering::Greater; |
534 | | } |
535 | 0 | Ordering::Equal => {} |
536 | | } |
537 | | } |
538 | 0 | match ref_indexes_a.len().cmp(&ref_indexes_b.len()) { |
539 | 0 | Ordering::Less => return Ordering::Greater, |
540 | | Ordering::Greater => { |
541 | 0 | return Ordering::Less; |
542 | | } |
543 | 0 | Ordering::Equal => {} |
544 | 0 | } |
545 | 0 | match (asc_a, asc_b) { |
546 | | (true, false) => { |
547 | 0 | return Ordering::Greater; |
548 | | } |
549 | | (false, true) => { |
550 | 0 | return Ordering::Less; |
551 | | } |
552 | 0 | _ => {} |
553 | 0 | } |
554 | 0 | match (nulls_first_a, nulls_first_b) { |
555 | | (true, false) => { |
556 | 0 | return Ordering::Less; |
557 | | } |
558 | | (false, true) => { |
559 | 0 | return Ordering::Greater; |
560 | | } |
561 | 0 | _ => {} |
562 | 0 | } |
563 | 0 | Ordering::Equal |
564 | 0 | } |
565 | | |
566 | | /// Group a slice of window expressions by their sort keys (derived from PARTITION BY and ORDER BY expressions)
567 | 0 | pub fn group_window_expr_by_sort_keys( |
568 | 0 | window_expr: Vec<Expr>, |
569 | 0 | ) -> Result<Vec<(WindowSortKey, Vec<Expr>)>> { |
570 | 0 | let mut result = vec![]; |
571 | 0 | window_expr.into_iter().try_for_each(|expr| match &expr { |
572 | 0 |         Expr::WindowFunction(WindowFunction { partition_by, order_by, .. }) => {
573 | 0 | let sort_key = generate_sort_key(partition_by, order_by)?; |
574 | 0 | if let Some((_, values)) = result.iter_mut().find( |
575 | 0 | |group: &&mut (WindowSortKey, Vec<Expr>)| matches!(group, (key, _) if *key == sort_key), |
576 | 0 | ) { |
577 | 0 | values.push(expr); |
578 | 0 | } else { |
579 | 0 | result.push((sort_key, vec![expr])) |
580 | | } |
581 | 0 | Ok(()) |
582 | | } |
583 | 0 | other => internal_err!( |
584 | 0 |             "Unexpectedly got non-window expr {other:?}"
585 | 0 | ), |
586 | 0 | })?; |
587 | 0 | Ok(result) |
588 | 0 | } |
589 | | |
590 | | /// Collect all deeply nested `Expr::AggregateFunction`. |
591 | | /// They are returned in order of occurrence (depth |
592 | | /// first), with duplicates omitted. |
593 | 0 | pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec<Expr> { |
594 | 0 | find_exprs_in_exprs(exprs, &|nested_expr| { |
595 | 0 | matches!(nested_expr, Expr::AggregateFunction { .. }) |
596 | 0 | }) |
597 | 0 | } |
598 | | |
599 | | /// Collect all deeply nested `Expr::WindowFunction`. They are returned in order of occurrence |
600 | | /// (depth first), with duplicates omitted. |
601 | 0 | pub fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> { |
602 | 0 | find_exprs_in_exprs(exprs, &|nested_expr| { |
603 | 0 | matches!(nested_expr, Expr::WindowFunction { .. }) |
604 | 0 | }) |
605 | 0 | } |
606 | | |
607 | | /// Collect all deeply nested `Expr::OuterReferenceColumn`. They are returned in order of occurrence |
608 | | /// (depth first), with duplicates omitted. |
609 | 0 | pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> { |
610 | 0 | find_exprs_in_expr(expr, &|nested_expr| { |
611 | 0 | matches!(nested_expr, Expr::OuterReferenceColumn { .. }) |
612 | 0 | }) |
613 | 0 | } |
614 | | |
615 | | /// Search the provided `Expr`'s, and all of their nested `Expr`, for any that |
616 | | /// pass the provided test. The returned `Expr`'s are deduplicated and returned |
617 | | /// in order of appearance (depth first). |
618 | 0 | fn find_exprs_in_exprs<F>(exprs: &[Expr], test_fn: &F) -> Vec<Expr> |
619 | 0 | where |
620 | 0 | F: Fn(&Expr) -> bool, |
621 | 0 | { |
622 | 0 | exprs |
623 | 0 | .iter() |
624 | 0 | .flat_map(|expr| find_exprs_in_expr(expr, test_fn)) |
625 | 0 | .fold(vec![], |mut acc, expr| { |
626 | 0 | if !acc.contains(&expr) { |
627 | 0 | acc.push(expr) |
628 | 0 | } |
629 | 0 | acc |
630 | 0 | }) |
631 | 0 | } |
632 | | |
633 | | /// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the |
634 | | /// provided test. The returned `Expr`'s are deduplicated and returned in order |
635 | | /// of appearance (depth first). |
636 | 0 | fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr> |
637 | 0 | where |
638 | 0 | F: Fn(&Expr) -> bool, |
639 | 0 | { |
640 | 0 | let mut exprs = vec![]; |
641 | 0 | expr.apply(|expr| { |
642 | 0 | if test_fn(expr) { |
643 | 0 | if !(exprs.contains(expr)) { |
644 | 0 | exprs.push(expr.clone()) |
645 | 0 | } |
646 | | // stop recursing down this expr once we find a match |
647 | 0 | return Ok(TreeNodeRecursion::Jump); |
648 | 0 | } |
649 | 0 |
650 | 0 | Ok(TreeNodeRecursion::Continue) |
651 | 0 | }) |
652 | 0 | // pre_visit always returns OK, so this will always too |
653 | 0 | .expect("no way to return error during recursion"); |
654 | 0 | exprs |
655 | 0 | } |
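The `TreeNodeRecursion::Jump` above stops recursion below a match while continuing with its siblings. A toy depth-first search over a made-up tree type shows the same collect-and-skip pattern (nothing here is a DataFusion type):

```rust
// Collect matching subtrees depth first, deduplicated, without descending
// into a matched subtree (the equivalent of `TreeNodeRecursion::Jump`).
#[derive(PartialEq)]
enum Tree {
    Leaf(i32),
    Node(Vec<Tree>),
}

fn find_matches<'a>(tree: &'a Tree, test: &dyn Fn(&Tree) -> bool, out: &mut Vec<&'a Tree>) {
    if test(tree) {
        if !out.contains(&tree) {
            out.push(tree);
        }
        return; // "Jump": skip the children of a match, keep visiting siblings
    }
    if let Tree::Node(children) = tree {
        for child in children {
            find_matches(child, test, out);
        }
    }
}

fn main() {
    let tree = Tree::Node(vec![Tree::Leaf(1), Tree::Node(vec![Tree::Leaf(2)])]);
    let mut matches = Vec::new();
    find_matches(&tree, &|t| matches!(t, Tree::Leaf(_)), &mut matches);
    assert_eq!(matches.len(), 2);
}
```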
656 | | |
657 | | /// Recursively inspect an [`Expr`] and all its children. |
658 | 0 | pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E> |
659 | 0 | where |
660 | 0 | F: FnMut(&Expr) -> Result<(), E>, |
661 | 0 | { |
662 | 0 | let mut err = Ok(()); |
663 | 0 | expr.apply(|expr| { |
664 | 0 | if let Err(e) = f(expr) { |
665 | | // save the error for later (it may not be a DataFusionError)
666 | 0 | err = Err(e); |
667 | 0 | Ok(TreeNodeRecursion::Stop) |
668 | | } else { |
669 | | // keep going |
670 | 0 | Ok(TreeNodeRecursion::Continue) |
671 | | } |
672 | 0 | }) |
673 | 0 | // The closure always returns OK, so this will always too |
674 | 0 | .expect("no way to return error during recursion"); |
675 | 0 |
676 | 0 | err |
677 | 0 | } |
678 | | |
679 | | /// Create field meta-data from an expression, for use in a result set schema |
680 | 0 | pub fn exprlist_to_fields<'a>( |
681 | 0 | exprs: impl IntoIterator<Item = &'a Expr>, |
682 | 0 | plan: &LogicalPlan, |
683 | 0 | ) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> { |
684 | 0 | // look for exact match in plan's output schema |
685 | 0 | let wildcard_schema = find_base_plan(plan).schema(); |
686 | 0 | let input_schema = plan.schema(); |
687 | 0 | let result = exprs |
688 | 0 | .into_iter() |
689 | 0 | .map(|e| match e { |
690 | 0 | Expr::Wildcard { qualifier, options } => match qualifier { |
691 | | None => { |
692 | 0 | let excluded: Vec<String> = get_excluded_columns( |
693 | 0 | options.exclude.as_ref(), |
694 | 0 | options.except.as_ref(), |
695 | 0 | wildcard_schema, |
696 | 0 | None, |
697 | 0 | )? |
698 | 0 | .into_iter() |
699 | 0 | .map(|c| c.flat_name()) |
700 | 0 | .collect(); |
701 | 0 | Ok::<_, DataFusionError>( |
702 | 0 | wildcard_schema |
703 | 0 | .field_names() |
704 | 0 | .iter() |
705 | 0 | .enumerate() |
706 | 0 | .filter(|(_, s)| !excluded.contains(s)) |
707 | 0 | .map(|(i, _)| wildcard_schema.qualified_field(i)) |
708 | 0 | .map(|(qualifier, f)| { |
709 | 0 | (qualifier.cloned(), Arc::new(f.to_owned())) |
710 | 0 | }) |
711 | 0 | .collect::<Vec<_>>(), |
712 | 0 | ) |
713 | | } |
714 | 0 | Some(qualifier) => { |
715 | 0 | let excluded: Vec<String> = get_excluded_columns( |
716 | 0 | options.exclude.as_ref(), |
717 | 0 | options.except.as_ref(), |
718 | 0 | wildcard_schema, |
719 | 0 | Some(qualifier), |
720 | 0 | )? |
721 | 0 | .into_iter() |
722 | 0 | .map(|c| c.flat_name()) |
723 | 0 | .collect(); |
724 | 0 | Ok(wildcard_schema |
725 | 0 | .fields_with_qualified(qualifier) |
726 | 0 | .into_iter() |
727 | 0 | .filter_map(|field| { |
728 | 0 | let flat_name = format!("{}.{}", qualifier, field.name()); |
729 | 0 | if excluded.contains(&flat_name) { |
730 | 0 | None |
731 | | } else { |
732 | 0 | Some(( |
733 | 0 | Some(qualifier.clone()), |
734 | 0 | Arc::new(field.to_owned()), |
735 | 0 | )) |
736 | | } |
737 | 0 | }) |
738 | 0 | .collect::<Vec<_>>()) |
739 | | } |
740 | | }, |
741 | 0 | _ => Ok(vec![e.to_field(input_schema)?]), |
742 | 0 | }) |
743 | 0 | .collect::<Result<Vec<_>>>()? |
744 | 0 | .into_iter() |
745 | 0 | .flatten() |
746 | 0 | .collect(); |
747 | 0 | Ok(result) |
748 | 0 | } |
749 | | |
750 | | /// Find the suitable base plan to expand the wildcard expression recursively. |
751 | | /// When planning [LogicalPlan::Window] and [LogicalPlan::Aggregate], we will generate |
752 | | /// an intermediate plan based on the relation plan (e.g. [LogicalPlan::TableScan], [LogicalPlan::Subquery], ...). |
753 | | /// If we expand a wildcard expression based on the intermediate plan, we could get duplicate fields.
754 | 0 | pub fn find_base_plan(input: &LogicalPlan) -> &LogicalPlan { |
755 | 0 | match input { |
756 | 0 | LogicalPlan::Window(window) => find_base_plan(&window.input), |
757 | 0 | LogicalPlan::Aggregate(agg) => find_base_plan(&agg.input), |
758 | | // [SqlToRel::try_process_unnest] will convert Expr(Unnest(Expr)) to Projection/Unnest/Projection |
759 | | // We should expand the wildcard expression based on the input plan of the inner Projection. |
760 | 0 | LogicalPlan::Unnest(unnest) => { |
761 | 0 | if let LogicalPlan::Projection(projection) = unnest.input.deref() { |
762 | 0 | find_base_plan(&projection.input) |
763 | | } else { |
764 | 0 | input |
765 | | } |
766 | | } |
767 | 0 | LogicalPlan::Filter(filter) => { |
768 | 0 | if filter.having { |
769 | | // If a filter is used for a having clause, its input plan is an aggregation. |
770 | | // We should expand the wildcard expression based on the aggregation's input plan. |
771 | 0 | find_base_plan(&filter.input) |
772 | | } else { |
773 | 0 | input |
774 | | } |
775 | | } |
776 | 0 | _ => input, |
777 | | } |
778 | 0 | } |
779 | | |
780 | | /// Count the number of real fields, expanding any wildcard expressions to get the actual number.
781 | 0 | pub fn exprlist_len( |
782 | 0 | exprs: &[Expr], |
783 | 0 | schema: &DFSchemaRef, |
784 | 0 | wildcard_schema: Option<&DFSchemaRef>, |
785 | 0 | ) -> Result<usize> { |
786 | 0 | exprs |
787 | 0 | .iter() |
788 | 0 | .map(|e| match e { |
789 | | Expr::Wildcard { |
790 | | qualifier: None, |
791 | 0 | options, |
792 | | } => { |
793 | 0 | let excluded = get_excluded_columns( |
794 | 0 | options.exclude.as_ref(), |
795 | 0 | options.except.as_ref(), |
796 | 0 | wildcard_schema.unwrap_or(schema), |
797 | 0 | None, |
798 | 0 | )? |
799 | 0 | .into_iter() |
800 | 0 | .collect::<HashSet<Column>>(); |
801 | 0 | Ok( |
802 | 0 | get_exprs_except_skipped(wildcard_schema.unwrap_or(schema), excluded) |
803 | 0 | .len(), |
804 | 0 | ) |
805 | | } |
806 | | Expr::Wildcard { |
807 | 0 | qualifier: Some(qualifier), |
808 | 0 | options, |
809 | | } => { |
810 | 0 | let related_wildcard_schema = wildcard_schema.as_ref().map_or_else( |
811 | 0 | || Ok(Arc::clone(schema)), |
812 | 0 | |schema| { |
813 | 0 | // Eliminate the fields coming from other tables. |
814 | 0 | let qualified_fields = schema |
815 | 0 | .fields() |
816 | 0 | .iter() |
817 | 0 | .enumerate() |
818 | 0 | .filter_map(|(idx, field)| { |
819 | 0 | let (maybe_table_ref, _) = schema.qualified_field(idx); |
820 | 0 | if maybe_table_ref.map_or(true, |q| q == qualifier) { |
821 | 0 | Some((maybe_table_ref.cloned(), Arc::clone(field))) |
822 | | } else { |
823 | 0 | None |
824 | | } |
825 | 0 | }) |
826 | 0 | .collect::<Vec<_>>(); |
827 | 0 | let metadata = schema.metadata().clone(); |
828 | 0 | DFSchema::new_with_metadata(qualified_fields, metadata) |
829 | 0 | .map(Arc::new) |
830 | 0 | }, |
831 | 0 | )?; |
832 | 0 | let excluded = get_excluded_columns( |
833 | 0 | options.exclude.as_ref(), |
834 | 0 | options.except.as_ref(), |
835 | 0 | related_wildcard_schema.as_ref(), |
836 | 0 | Some(qualifier), |
837 | 0 | )? |
838 | 0 | .into_iter() |
839 | 0 | .collect::<HashSet<Column>>(); |
840 | 0 | Ok( |
841 | 0 | get_exprs_except_skipped(related_wildcard_schema.as_ref(), excluded) |
842 | 0 | .len(), |
843 | 0 | ) |
844 | | } |
845 | 0 | _ => Ok(1), |
846 | 0 | }) |
847 | 0 | .sum() |
848 | 0 | } |
849 | | |
850 | | /// Convert an expression into a Column expression if it is already provided by the input plan.
851 | | /// |
852 | | /// For example, it rewrites: |
853 | | /// |
854 | | /// ```text |
855 | | /// .aggregate(vec![col("c1")], vec![sum(col("c2"))])? |
856 | | /// .project(vec![col("c1"), sum(col("c2"))])?
857 | | /// ``` |
858 | | /// |
859 | | /// Into: |
860 | | /// |
861 | | /// ```text |
862 | | /// .aggregate(vec![col("c1")], vec![sum(col("c2"))])? |
863 | | /// .project(vec![col("c1"), col("SUM(c2)")])?
864 | | /// ``` |
865 | 0 | pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> { |
866 | 0 | let output_exprs = match input.columnized_output_exprs() { |
867 | 0 | Ok(exprs) if !exprs.is_empty() => exprs, |
868 | 0 | _ => return Ok(e), |
869 | | }; |
870 | 0 | let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect(); |
871 | 0 | e.transform_down(|node: Expr| match exprs_map.get(&node) { |
872 | 0 | Some(column) => Ok(Transformed::new( |
873 | 0 | Expr::Column(column.clone()), |
874 | 0 | true, |
875 | 0 | TreeNodeRecursion::Jump, |
876 | 0 | )), |
877 | 0 | None => Ok(Transformed::no(node)), |
878 | 0 | }) |
879 | 0 | .data() |
880 | 0 | } |
881 | | |
882 | | /// Collect all deeply nested `Expr::Column`'s. They are returned in order of |
883 | | /// appearance (depth first), and may contain duplicates. |
884 | 0 | pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> { |
885 | 0 | exprs |
886 | 0 | .iter() |
887 | 0 | .flat_map(find_columns_referenced_by_expr) |
888 | 0 | .map(Expr::Column) |
889 | 0 | .collect() |
890 | 0 | } |
891 | | |
892 | 0 | pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> { |
893 | 0 | let mut exprs = vec![]; |
894 | 0 | e.apply(|expr| { |
895 | 0 | if let Expr::Column(c) = expr { |
896 | 0 | exprs.push(c.clone()) |
897 | 0 | } |
898 | 0 | Ok(TreeNodeRecursion::Continue) |
899 | 0 | }) |
900 | 0 | // As the closure always returns Ok, this "can't" error |
901 | 0 | .expect("Unexpected error"); |
902 | 0 | exprs |
903 | 0 | } |
904 | | |
905 | | /// Convert any `Expr` to an `Expr::Column`. |
906 | 0 | pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> { |
907 | 0 | match expr { |
908 | 0 | Expr::Column(col) => { |
909 | 0 | let (qualifier, field) = plan.schema().qualified_field_from_column(col)?; |
910 | 0 | Ok(Expr::from(Column::from((qualifier, field)))) |
911 | | } |
912 | 0 | _ => Ok(Expr::Column(Column::from_name( |
913 | 0 | expr.schema_name().to_string(), |
914 | 0 | ))), |
915 | | } |
916 | 0 | } |
917 | | |
918 | | /// Recursively walk an expression tree, collecting the column indexes |
919 | | /// referenced in the expression |
920 | 0 | pub(crate) fn find_column_indexes_referenced_by_expr( |
921 | 0 | e: &Expr, |
922 | 0 | schema: &DFSchemaRef, |
923 | 0 | ) -> Vec<usize> { |
924 | 0 | let mut indexes = vec![]; |
925 | 0 | e.apply(|expr| { |
926 | 0 | match expr { |
927 | 0 | Expr::Column(qc) => { |
928 | 0 | if let Ok(idx) = schema.index_of_column(qc) { |
929 | 0 | indexes.push(idx); |
930 | 0 | } |
931 | | } |
932 | 0 | Expr::Literal(_) => { |
933 | 0 | indexes.push(usize::MAX); |
934 | 0 | } |
935 | 0 | _ => {} |
936 | | } |
937 | 0 | Ok(TreeNodeRecursion::Continue) |
938 | 0 | }) |
939 | 0 | .unwrap(); |
940 | 0 | indexes |
941 | 0 | } |
942 | | |
943 | | /// Can this data type be used in hash join equal conditions?
944 | | /// The data types here come from the function `equal_rows`; if more data types are supported
945 | | /// in `equal_rows` (hash join), add them here so the join logical plan can be generated.
946 | 0 | pub fn can_hash(data_type: &DataType) -> bool { |
947 | 0 | match data_type { |
948 | 0 | DataType::Null => true, |
949 | 0 | DataType::Boolean => true, |
950 | 0 | DataType::Int8 => true, |
951 | 0 | DataType::Int16 => true, |
952 | 0 | DataType::Int32 => true, |
953 | 0 | DataType::Int64 => true, |
954 | 0 | DataType::UInt8 => true, |
955 | 0 | DataType::UInt16 => true, |
956 | 0 | DataType::UInt32 => true, |
957 | 0 | DataType::UInt64 => true, |
958 | 0 | DataType::Float32 => true, |
959 | 0 | DataType::Float64 => true, |
960 | 0 | DataType::Timestamp(time_unit, _) => match time_unit { |
961 | 0 | TimeUnit::Second => true, |
962 | 0 | TimeUnit::Millisecond => true, |
963 | 0 | TimeUnit::Microsecond => true, |
964 | 0 | TimeUnit::Nanosecond => true, |
965 | | }, |
966 | 0 | DataType::Utf8 => true, |
967 | 0 | DataType::LargeUtf8 => true, |
968 | 0 | DataType::Decimal128(_, _) => true, |
969 | 0 | DataType::Date32 => true, |
970 | 0 | DataType::Date64 => true, |
971 | 0 | DataType::FixedSizeBinary(_) => true, |
972 | 0 | DataType::Dictionary(key_type, value_type) |
973 | 0 | if *value_type.as_ref() == DataType::Utf8 => |
974 | 0 | { |
975 | 0 | DataType::is_dictionary_key_type(key_type) |
976 | | } |
977 | 0 | DataType::List(_) => true, |
978 | 0 | DataType::LargeList(_) => true, |
979 | 0 | DataType::FixedSizeList(_, _) => true, |
980 | 0 | DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())), |
981 | 0 | _ => false, |
982 | | } |
983 | 0 | } |
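A brief usage sketch, assuming the `arrow` and `datafusion_expr` crates:

```rust
use arrow::datatypes::{DataType, Field, Fields};
use datafusion_expr::utils::can_hash;

fn main() {
    assert!(can_hash(&DataType::Int64));
    assert!(can_hash(&DataType::Utf8));
    // a struct is hashable only if all of its fields are
    let fields = Fields::from(vec![Field::new("a", DataType::Int32, true)]);
    assert!(can_hash(&DataType::Struct(fields)));
    // `Binary` is not in the supported list above
    assert!(!can_hash(&DataType::Binary));
}
```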
984 | | |
985 | | /// Check whether all columns are from the schema. |
986 | 0 | pub fn check_all_columns_from_schema( |
987 | 0 | columns: &HashSet<&Column>, |
988 | 0 | schema: &DFSchema, |
989 | 0 | ) -> Result<bool> { |
990 | 0 | for col in columns.iter() { |
991 | 0 | let exist = schema.is_column_from_schema(col); |
992 | 0 | if !exist { |
993 | 0 | return Ok(false); |
994 | 0 | } |
995 | | } |
996 | | |
997 | 0 | Ok(true) |
998 | 0 | } |
999 | | |
1000 | | /// Given two sides of an equijoin predicate, return a valid join key pair.
1001 | | /// If there is no valid join key pair, return None. |
1002 | | /// |
1003 | | /// A valid join means: |
1004 | | /// 1. All referenced columns of the left side are from the left schema, and
1005 | | ///    all referenced columns of the right side are from the right schema.
1006 | | /// 2. Or the opposite: all referenced columns of the left side are from the right schema,
1007 | | ///    and those of the right side are from the left schema.
1008 | | /// |
1009 | 0 | pub fn find_valid_equijoin_key_pair( |
1010 | 0 | left_key: &Expr, |
1011 | 0 | right_key: &Expr, |
1012 | 0 | left_schema: &DFSchema, |
1013 | 0 | right_schema: &DFSchema, |
1014 | 0 | ) -> Result<Option<(Expr, Expr)>> { |
1015 | 0 | let left_using_columns = left_key.column_refs(); |
1016 | 0 | let right_using_columns = right_key.column_refs(); |
1017 | 0 |
1018 | 0 | // Conditions like a = 10, will be added to non-equijoin. |
1019 | 0 | if left_using_columns.is_empty() || right_using_columns.is_empty() { |
1020 | 0 | return Ok(None); |
1021 | 0 | } |
1022 | 0 |
1023 | 0 | if check_all_columns_from_schema(&left_using_columns, left_schema)? |
1024 | 0 | && check_all_columns_from_schema(&right_using_columns, right_schema)? |
1025 | | { |
1026 | 0 | return Ok(Some((left_key.clone(), right_key.clone()))); |
1027 | 0 | } else if check_all_columns_from_schema(&right_using_columns, left_schema)? |
1028 | 0 | && check_all_columns_from_schema(&left_using_columns, right_schema)? |
1029 | | { |
1030 | 0 | return Ok(Some((right_key.clone(), left_key.clone()))); |
1031 | 0 | } |
1032 | 0 |
1033 | 0 | Ok(None) |
1034 | 0 | } |
1035 | | |
1036 | | /// Creates a detailed error message for a function with a wrong signature.
1037 | | /// |
1038 | | /// For example, a query like `select round(3.14, 1.1);` would yield: |
1039 | | /// ```text |
1040 | | /// Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts. |
1041 | | /// Candidate functions: |
1042 | | /// round(Float64, Int64) |
1043 | | /// round(Float32, Int64) |
1044 | | /// round(Float64) |
1045 | | /// round(Float32) |
1046 | | /// ``` |
1047 | 0 | pub fn generate_signature_error_msg( |
1048 | 0 | func_name: &str, |
1049 | 0 | func_signature: Signature, |
1050 | 0 | input_expr_types: &[DataType], |
1051 | 0 | ) -> String { |
1052 | 0 | let candidate_signatures = func_signature |
1053 | 0 | .type_signature |
1054 | 0 | .to_string_repr() |
1055 | 0 | .iter() |
1056 | 0 | .map(|args_str| format!("\t{func_name}({args_str})")) |
1057 | 0 | .collect::<Vec<String>>() |
1058 | 0 | .join("\n"); |
1059 | 0 |
1060 | 0 | format!( |
1061 | 0 | "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}", |
1062 | 0 | func_name, TypeSignature::join_types(input_expr_types, ", "), candidate_signatures |
1063 | 0 | ) |
1064 | 0 | } |
1065 | | |
1066 | | /// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` |
1067 | | /// |
1068 | | /// See [`split_conjunction_owned`] for more details and an example. |
1069 | 0 | pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { |
1070 | 0 | split_conjunction_impl(expr, vec![]) |
1071 | 0 | } |
1072 | | |
1073 | 0 | fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { |
1074 | 0 | match expr { |
1075 | | Expr::BinaryExpr(BinaryExpr { |
1076 | 0 | right, |
1077 | 0 | op: Operator::And, |
1078 | 0 | left, |
1079 | 0 | }) => { |
1080 | 0 | let exprs = split_conjunction_impl(left, exprs); |
1081 | 0 | split_conjunction_impl(right, exprs) |
1082 | | } |
1083 | 0 | Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs), |
1084 | 0 | other => { |
1085 | 0 | exprs.push(other); |
1086 | 0 | exprs |
1087 | | } |
1088 | | } |
1089 | 0 | } |
1090 | | |
1091 | | /// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` |
1092 | | /// |
1093 | | /// This is often used to "split" filter expressions such as `col1 = 5 |
1094 | | /// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; |
1095 | | /// |
1096 | | /// # Example |
1097 | | /// ``` |
1098 | | /// # use datafusion_expr::{col, lit}; |
1099 | | /// # use datafusion_expr::utils::split_conjunction_owned; |
1100 | | /// // a=1 AND b=2 |
1101 | | /// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); |
1102 | | /// |
1103 | | /// // [a=1, b=2] |
1104 | | /// let split = vec![ |
1105 | | /// col("a").eq(lit(1)), |
1106 | | /// col("b").eq(lit(2)), |
1107 | | /// ]; |
1108 | | /// |
1109 | | /// // use split_conjunction_owned to split them |
1110 | | /// assert_eq!(split_conjunction_owned(expr), split); |
1111 | | /// ``` |
1112 | 0 | pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> { |
1113 | 0 | split_binary_owned(expr, Operator::And) |
1114 | 0 | } |
1115 | | |
1116 | | /// Splits an owned binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]` |
1117 | | /// |
1118 | | /// This is often used to "split" expressions such as `col1 = 5 |
1119 | | /// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; |
1120 | | /// |
1121 | | /// # Example |
1122 | | /// ``` |
1123 | | /// # use datafusion_expr::{col, lit, Operator}; |
1124 | | /// # use datafusion_expr::utils::split_binary_owned; |
1125 | | /// # use std::ops::Add; |
1126 | | /// // a=1 + b=2 |
1127 | | /// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2))); |
1128 | | /// |
1129 | | /// // [a=1, b=2] |
1130 | | /// let split = vec![ |
1131 | | /// col("a").eq(lit(1)), |
1132 | | /// col("b").eq(lit(2)), |
1133 | | /// ]; |
1134 | | /// |
1135 | | /// // use split_binary_owned to split them |
1136 | | /// assert_eq!(split_binary_owned(expr, Operator::Plus), split); |
1137 | | /// ``` |
1138 | 0 | pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> { |
1139 | 0 | split_binary_owned_impl(expr, op, vec![]) |
1140 | 0 | } |
1141 | | |
1142 | 0 | fn split_binary_owned_impl( |
1143 | 0 | expr: Expr, |
1144 | 0 | operator: Operator, |
1145 | 0 | mut exprs: Vec<Expr>, |
1146 | 0 | ) -> Vec<Expr> { |
1147 | 0 | match expr { |
1148 | 0 | Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { |
1149 | 0 | let exprs = split_binary_owned_impl(*left, operator, exprs); |
1150 | 0 | split_binary_owned_impl(*right, operator, exprs) |
1151 | | } |
1152 | 0 | Expr::Alias(Alias { expr, .. }) => { |
1153 | 0 | split_binary_owned_impl(*expr, operator, exprs) |
1154 | | } |
1155 | 0 | other => { |
1156 | 0 | exprs.push(other); |
1157 | 0 | exprs |
1158 | | } |
1159 | | } |
1160 | 0 | } |
1161 | | |
1162 | | /// Splits a binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`
1163 | | /// |
1164 | | /// See [`split_binary_owned`] for more details and an example. |
1165 | 0 | pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { |
1166 | 0 | split_binary_impl(expr, op, vec![]) |
1167 | 0 | } |
1168 | | |
1169 | 0 | fn split_binary_impl<'a>( |
1170 | 0 | expr: &'a Expr, |
1171 | 0 | operator: Operator, |
1172 | 0 | mut exprs: Vec<&'a Expr>, |
1173 | 0 | ) -> Vec<&'a Expr> { |
1174 | 0 | match expr { |
1175 | 0 | Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { |
1176 | 0 | let exprs = split_binary_impl(left, operator, exprs); |
1177 | 0 | split_binary_impl(right, operator, exprs) |
1178 | | } |
1179 | 0 | Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs), |
1180 | 0 | other => { |
1181 | 0 | exprs.push(other); |
1182 | 0 | exprs |
1183 | | } |
1184 | | } |
1185 | 0 | } |
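By analogy with the `split_binary_owned` doctest above, a sketch of the borrowed variant (same crate exports assumed):

```rust
use std::ops::Add;

use datafusion_expr::utils::split_binary;
use datafusion_expr::{col, lit, Operator};

fn main() {
    // a=1 + b=2, as in the `split_binary_owned` example
    let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2)));
    // the borrowed variant returns references into `expr`
    let split = split_binary(&expr, Operator::Plus);
    assert_eq!(split.len(), 2);
}
```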
1186 | | |
1187 | | /// Combines an array of filter expressions into a single filter |
1188 | | /// expression consisting of the input filter expressions joined with |
1189 | | /// logical AND. |
1190 | | /// |
1191 | | /// Returns None if the filters array is empty. |
1192 | | /// |
1193 | | /// # Example |
1194 | | /// ``` |
1195 | | /// # use datafusion_expr::{col, lit}; |
1196 | | /// # use datafusion_expr::utils::conjunction; |
1197 | | /// // a=1 AND b=2 |
1198 | | /// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); |
1199 | | /// |
1200 | | /// // [a=1, b=2] |
1201 | | /// let split = vec![ |
1202 | | /// col("a").eq(lit(1)), |
1203 | | /// col("b").eq(lit(2)), |
1204 | | /// ]; |
1205 | | /// |
1206 | | /// // use conjunction to join them together with `AND` |
1207 | | /// assert_eq!(conjunction(split), Some(expr)); |
1208 | | /// ``` |
1209 | 0 | pub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> { |
1210 | 0 | filters.into_iter().reduce(Expr::and) |
1211 | 0 | } |
1212 | | |
1213 | | /// Combines an array of filter expressions into a single filter |
1214 | | /// expression consisting of the input filter expressions joined with |
1215 | | /// logical OR. |
1216 | | /// |
1217 | | /// Returns None if the filters array is empty. |
1218 | | /// |
1219 | | /// # Example |
1220 | | /// ``` |
1221 | | /// # use datafusion_expr::{col, lit}; |
1222 | | /// # use datafusion_expr::utils::disjunction; |
1223 | | /// // a=1 OR b=2 |
1224 | | /// let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2))); |
1225 | | /// |
1226 | | /// // [a=1, b=2] |
1227 | | /// let split = vec![ |
1228 | | /// col("a").eq(lit(1)), |
1229 | | /// col("b").eq(lit(2)), |
1230 | | /// ]; |
1231 | | /// |
1232 | | /// // use disjunction to join them together with `OR`
1233 | | /// assert_eq!(disjunction(split), Some(expr)); |
1234 | | /// ``` |
1235 | 0 | pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> { |
1236 | 0 | filters.into_iter().reduce(Expr::or) |
1237 | 0 | } |
1238 | | |
1239 | | /// Returns a new [LogicalPlan] that filters the output of `plan` with a |
1240 | | /// [LogicalPlan::Filter] with all `predicates` ANDed. |
1241 | | /// |
1242 | | /// # Example |
1243 | | /// Before: |
1244 | | /// ```text |
1245 | | /// plan |
1246 | | /// ``` |
1247 | | /// |
1248 | | /// After: |
1249 | | /// ```text |
1250 | | /// Filter(predicate) |
1251 | | /// plan |
1252 | | /// ``` |
1253 | 0 | pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> { |
1254 | 0 | // reduce filters to a single filter with an AND |
1255 | 0 | let predicate = predicates |
1256 | 0 | .iter() |
1257 | 0 | .skip(1) |
1258 | 0 | .fold(predicates[0].clone(), |acc, predicate| { |
1259 | 0 | and(acc, (*predicate).to_owned()) |
1260 | 0 | }); |
1261 | 0 |
1262 | 0 | Ok(LogicalPlan::Filter(Filter::try_new( |
1263 | 0 | predicate, |
1264 | 0 | Arc::new(plan), |
1265 | 0 | )?)) |
1266 | 0 | } |
1267 | | |
1268 | | /// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and |
1269 | | /// one not in the subquery (closed over from the outer scope)
1270 | | /// |
1271 | | /// # Arguments |
1272 | | /// |
1273 | | /// * `exprs` - List of expressions that may or may not be joins |
1274 | | /// |
1275 | | /// # Return value |
1276 | | /// |
1277 | | /// Tuple of (expressions containing joins, remaining non-join expressions) |
1278 | 0 | pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec<Expr>, Vec<Expr>)> { |
1279 | 0 | let mut joins = vec![]; |
1280 | 0 | let mut others = vec![]; |
1281 | 0 | for filter in exprs.into_iter() { |
1282 | | // If the expression contains correlated predicates, add it to join filters |
1283 | 0 | if filter.contains_outer() { |
1284 | 0 | if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right)) |
1285 | 0 | { |
1286 | 0 | joins.push(strip_outer_reference((*filter).clone())); |
1287 | 0 | } |
1288 | 0 | } else { |
1289 | 0 | others.push((*filter).clone()); |
1290 | 0 | } |
1291 | | } |
1292 | | |
1293 | 0 | Ok((joins, others)) |
1294 | 0 | } |
1295 | | |
1296 | | /// Returns the first (and only) element in a slice, or an error |
1297 | | /// |
1298 | | /// # Arguments |
1299 | | /// |
1300 | | /// * `slice` - The slice to extract from |
1301 | | /// |
1302 | | /// # Return value |
1303 | | /// |
1304 | | /// The first element, or an error |
1305 | 0 | pub fn only_or_err<T>(slice: &[T]) -> Result<&T> { |
1306 | 0 | match slice { |
1307 | 0 | [it] => Ok(it), |
1308 | 0 | [] => plan_err!("No items found!"), |
1309 | 0 | _ => plan_err!("More than one item found!"), |
1310 | | } |
1311 | 0 | } |
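A short usage sketch (crate exports assumed):

```rust
use datafusion_common::Result;
use datafusion_expr::utils::only_or_err;

fn main() -> Result<()> {
    let one = [42];
    assert_eq!(*only_or_err(&one)?, 42);
    // empty or multi-element slices produce a plan error
    assert!(only_or_err::<i32>(&[]).is_err());
    Ok(())
}
```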
1312 | | |
1313 | | /// Merge the inputs' schemas into a single schema.
1314 | 0 | pub fn merge_schema(inputs: &[&LogicalPlan]) -> DFSchema { |
1315 | 0 | if inputs.len() == 1 { |
1316 | 0 | inputs[0].schema().as_ref().clone() |
1317 | | } else { |
1318 | 0 | inputs.iter().map(|input| input.schema()).fold( |
1319 | 0 | DFSchema::empty(), |
1320 | 0 | |mut lhs, rhs| { |
1321 | 0 | lhs.merge(rhs); |
1322 | 0 | lhs |
1323 | 0 | }, |
1324 | 0 | ) |
1325 | | } |
1326 | 0 | } |
1327 | | |
1328 | | /// Build state name. State is the intermediate state of the aggregate function. |
1329 | 137 | pub fn format_state_name(name: &str, state_name: &str) -> String { |
1330 | 137 | format!("{name}[{state_name}]") |
1331 | 137 | } |
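For example (a hypothetical aggregate/state pair; the crate export is assumed):

```rust
use datafusion_expr::utils::format_state_name;

fn main() {
    // intermediate state field of an aggregate named "avg" with a "count" state
    assert_eq!(format_state_name("avg", "count"), "avg[count]");
}
```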
1332 | | |
1333 | | #[cfg(test)] |
1334 | | mod tests { |
1335 | | use super::*; |
1336 | | use crate::{ |
1337 | | col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, |
1338 | | test::function_stub::max_udaf, test::function_stub::min_udaf, |
1339 | | test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition, |
1340 | | }; |
1341 | | |
1342 | | #[test] |
1343 | | fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> { |
1344 | | let result = group_window_expr_by_sort_keys(vec![])?; |
1345 | | let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![]; |
1346 | | assert_eq!(expected, result); |
1347 | | Ok(()) |
1348 | | } |
1349 | | |
1350 | | #[test] |
1351 | | fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { |
1352 | | let max1 = Expr::WindowFunction(expr::WindowFunction::new( |
1353 | | WindowFunctionDefinition::AggregateUDF(max_udaf()), |
1354 | | vec![col("name")], |
1355 | | )); |
1356 | | let max2 = Expr::WindowFunction(expr::WindowFunction::new( |
1357 | | WindowFunctionDefinition::AggregateUDF(max_udaf()), |
1358 | | vec![col("name")], |
1359 | | )); |
1360 | | let min3 = Expr::WindowFunction(expr::WindowFunction::new( |
1361 | | WindowFunctionDefinition::AggregateUDF(min_udaf()), |
1362 | | vec![col("name")], |
1363 | | )); |
1364 | | let sum4 = Expr::WindowFunction(expr::WindowFunction::new( |
1365 | | WindowFunctionDefinition::AggregateUDF(sum_udaf()), |
1366 | | vec![col("age")], |
1367 | | )); |
1368 | | let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; |
1369 | | let result = group_window_expr_by_sort_keys(exprs.to_vec())?; |
1370 | | let key = vec![]; |
1371 | | let expected: Vec<(WindowSortKey, Vec<Expr>)> = |
1372 | | vec![(key, vec![max1, max2, min3, sum4])]; |
1373 | | assert_eq!(expected, result); |
1374 | | Ok(()) |
1375 | | } |
1376 | | |
1377 | | #[test] |
1378 | | fn test_group_window_expr_by_sort_keys() -> Result<()> { |
1379 | | let age_asc = expr::Sort::new(col("age"), true, true); |
1380 | | let name_desc = expr::Sort::new(col("name"), false, true); |
1381 | | let created_at_desc = expr::Sort::new(col("created_at"), false, true); |
1382 | | let max1 = Expr::WindowFunction(expr::WindowFunction::new( |
1383 | | WindowFunctionDefinition::AggregateUDF(max_udaf()), |
1384 | | vec![col("name")], |
1385 | | )) |
1386 | | .order_by(vec![age_asc.clone(), name_desc.clone()]) |
1387 | | .build() |
1388 | | .unwrap(); |
1389 | | let max2 = Expr::WindowFunction(expr::WindowFunction::new( |
1390 | | WindowFunctionDefinition::AggregateUDF(max_udaf()), |
1391 | | vec![col("name")], |
1392 | | )); |
1393 | | let min3 = Expr::WindowFunction(expr::WindowFunction::new( |
1394 | | WindowFunctionDefinition::AggregateUDF(min_udaf()), |
1395 | | vec![col("name")], |
1396 | | )) |
1397 | | .order_by(vec![age_asc.clone(), name_desc.clone()]) |
1398 | | .build() |
1399 | | .unwrap(); |
1400 | | let sum4 = Expr::WindowFunction(expr::WindowFunction::new( |
1401 | | WindowFunctionDefinition::AggregateUDF(sum_udaf()), |
1402 | | vec![col("age")], |
1403 | | )) |
1404 | | .order_by(vec![ |
1405 | | name_desc.clone(), |
1406 | | age_asc.clone(), |
1407 | | created_at_desc.clone(), |
1408 | | ]) |
1409 | | .build() |
1410 | | .unwrap(); |
1411 | | // FIXME use as_ref |
1412 | | let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; |
1413 | | let result = group_window_expr_by_sort_keys(exprs.to_vec())?; |
1414 | | |
1415 | | let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)]; |
1416 | | let key2 = vec![]; |
1417 | | let key3 = vec![ |
1418 | | (name_desc, false), |
1419 | | (age_asc, false), |
1420 | | (created_at_desc, false), |
1421 | | ]; |
1422 | | |
1423 | | let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![ |
1424 | | (key1, vec![max1, min3]), |
1425 | | (key2, vec![max2]), |
1426 | | (key3, vec![sum4]), |
1427 | | ]; |
1428 | | assert_eq!(expected, result); |
1429 | | Ok(()) |
1430 | | } |
1431 | | |
1432 | | #[test] |
1433 | | fn avoid_generate_duplicate_sort_keys() -> Result<()> { |
1434 | | let asc_or_desc = [true, false]; |
1435 | | let nulls_first_or_last = [true, false]; |
1436 | | let partition_by = &[col("age"), col("name"), col("created_at")]; |
1437 | | for asc_ in asc_or_desc { |
1438 | | for nulls_first_ in nulls_first_or_last { |
1439 | | let order_by = &[ |
1440 | | Sort { |
1441 | | expr: col("age"), |
1442 | | asc: asc_, |
1443 | | nulls_first: nulls_first_, |
1444 | | }, |
1445 | | Sort { |
1446 | | expr: col("name"), |
1447 | | asc: asc_, |
1448 | | nulls_first: nulls_first_, |
1449 | | }, |
1450 | | ]; |
1451 | | |
1452 | | let expected = vec![ |
1453 | | ( |
1454 | | Sort { |
1455 | | expr: col("age"), |
1456 | | asc: asc_, |
1457 | | nulls_first: nulls_first_, |
1458 | | }, |
1459 | | true, |
1460 | | ), |
1461 | | ( |
1462 | | Sort { |
1463 | | expr: col("name"), |
1464 | | asc: asc_, |
1465 | | nulls_first: nulls_first_, |
1466 | | }, |
1467 | | true, |
1468 | | ), |
1469 | | ( |
1470 | | Sort { |
1471 | | expr: col("created_at"), |
1472 | | asc: true, |
1473 | | nulls_first: false, |
1474 | | }, |
1475 | | true, |
1476 | | ), |
1477 | | ]; |
1478 | | let result = generate_sort_key(partition_by, order_by)?; |
1479 | | assert_eq!(expected, result); |
1480 | | } |
1481 | | } |
1482 | | Ok(()) |
1483 | | } |
1484 | | |
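 | | // `enumerate_grouping_sets` leaves a lone CUBE or ROLLUP untouched, but once
 | | // they are mixed with other group-by expressions everything is expanded into
 | | // explicit GROUPING SETS: a CUBE over n columns contributes all 2^n subsets,
 | | // a ROLLUP the n + 1 prefixes, and plain columns join every resulting set.
 | | // Several constructs together produce the cross product of their expansions.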
1485 | | #[test] |
1486 | | fn test_enumerate_grouping_sets() -> Result<()> { |
1487 | | let multi_cols = vec![col("col1"), col("col2"), col("col3")]; |
1488 | | let simple_col = col("simple_col"); |
1489 | | let cube = cube(multi_cols.clone()); |
1490 | | let rollup = rollup(multi_cols.clone()); |
1491 | | let grouping_set = grouping_set(vec![multi_cols]); |
1492 | | |
1493 | | // 1. col |
1494 | | let sets = enumerate_grouping_sets(vec![simple_col.clone()])?; |
1495 | | let result = format!("[{}]", expr_vec_fmt!(sets)); |
1496 | | assert_eq!("[simple_col]", &result); |
1497 | | |
1498 | | // 2. cube |
1499 | | let sets = enumerate_grouping_sets(vec![cube.clone()])?; |
1500 | | let result = format!("[{}]", expr_vec_fmt!(sets)); |
1501 | | assert_eq!("[CUBE (col1, col2, col3)]", &result); |
1502 | | |
1503 | | // 3. rollup |
1504 | | let sets = enumerate_grouping_sets(vec![rollup.clone()])?; |
1505 | | let result = format!("[{}]", expr_vec_fmt!(sets)); |
1506 | | assert_eq!("[ROLLUP (col1, col2, col3)]", &result); |
1507 | | |
1508 | | // 4. col + cube |
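 | | //    simple_col is prepended to each of the cube's 2^3 = 8 subsets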
1509 | | let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?; |
1510 | | let result = format!("[{}]", expr_vec_fmt!(sets)); |
1511 | | assert_eq!( |
1512 | | "[GROUPING SETS (\ |
1513 | | (simple_col), \ |
1514 | | (simple_col, col1), \ |
1515 | | (simple_col, col2), \ |
1516 | | (simple_col, col1, col2), \ |
1517 | | (simple_col, col3), \ |
1518 | | (simple_col, col1, col3), \ |
1519 | | (simple_col, col2, col3), \ |
1520 | | (simple_col, col1, col2, col3))]", |
1521 | | &result |
1522 | | ); |
1523 | | |
1524 | | // 5. col + rollup |
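 | | //    rollup keeps only the n + 1 = 4 prefix sets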
1525 | | let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?; |
1526 | | let result = format!("[{}]", expr_vec_fmt!(sets)); |
1527 | | assert_eq!( |
1528 | | "[GROUPING SETS (\ |
1529 | | (simple_col), \ |
1530 | | (simple_col, col1), \ |
1531 | | (simple_col, col1, col2), \ |
1532 | | (simple_col, col1, col2, col3))]", |
1533 | | &result |
1534 | | ); |
1535 | | |
1536 | | // 6. col + grouping_set |
1537 | | let sets = |
1538 | | enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?; |
1539 | | let result = format!("[{}]", expr_vec_fmt!(sets)); |
1540 | | assert_eq!( |
1541 | | "[GROUPING SETS (\ |
1542 | | (simple_col, col1, col2, col3))]", |
1543 | | &result |
1544 | | ); |
1545 | | |
1546 | | // 7. col + grouping_set + rollup |
1547 | | let sets = enumerate_grouping_sets(vec![ |
1548 | | simple_col.clone(), |
1549 | | grouping_set, |
1550 | | rollup.clone(), |
1551 | | ])?; |
1552 | | let result = format!("[{}]", expr_vec_fmt!(sets)); |
1553 | | assert_eq!( |
1554 | | "[GROUPING SETS (\ |
1555 | | (simple_col, col1, col2, col3), \ |
1556 | | (simple_col, col1, col2, col3, col1), \ |
1557 | | (simple_col, col1, col2, col3, col1, col2), \ |
1558 | | (simple_col, col1, col2, col3, col1, col2, col3))]", |
1559 | | &result |
1560 | | ); |
1561 | | |
1562 | | // 8. col + cube + rollup |
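 | | //    cross product: 8 cube subsets x 4 rollup prefixes = 32 grouping sets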
1563 | | let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?; |
1564 | | let result = format!("[{}]", expr_vec_fmt!(sets)); |
1565 | | assert_eq!( |
1566 | | "[GROUPING SETS (\ |
1567 | | (simple_col), \ |
1568 | | (simple_col, col1), \ |
1569 | | (simple_col, col1, col2), \ |
1570 | | (simple_col, col1, col2, col3), \ |
1571 | | (simple_col, col1), \ |
1572 | | (simple_col, col1, col1), \ |
1573 | | (simple_col, col1, col1, col2), \ |
1574 | | (simple_col, col1, col1, col2, col3), \ |
1575 | | (simple_col, col2), \ |
1576 | | (simple_col, col2, col1), \ |
1577 | | (simple_col, col2, col1, col2), \ |
1578 | | (simple_col, col2, col1, col2, col3), \ |
1579 | | (simple_col, col1, col2), \ |
1580 | | (simple_col, col1, col2, col1), \ |
1581 | | (simple_col, col1, col2, col1, col2), \ |
1582 | | (simple_col, col1, col2, col1, col2, col3), \ |
1583 | | (simple_col, col3), \ |
1584 | | (simple_col, col3, col1), \ |
1585 | | (simple_col, col3, col1, col2), \ |
1586 | | (simple_col, col3, col1, col2, col3), \ |
1587 | | (simple_col, col1, col3), \ |
1588 | | (simple_col, col1, col3, col1), \ |
1589 | | (simple_col, col1, col3, col1, col2), \ |
1590 | | (simple_col, col1, col3, col1, col2, col3), \ |
1591 | | (simple_col, col2, col3), \ |
1592 | | (simple_col, col2, col3, col1), \ |
1593 | | (simple_col, col2, col3, col1, col2), \ |
1594 | | (simple_col, col2, col3, col1, col2, col3), \ |
1595 | | (simple_col, col1, col2, col3), \ |
1596 | | (simple_col, col1, col2, col3, col1), \ |
1597 | | (simple_col, col1, col2, col3, col1, col2), \ |
1598 | | (simple_col, col1, col2, col3, col1, col2, col3))]", |
1599 | | &result |
1600 | | ); |
1601 | | |
1602 | | Ok(()) |
1603 | | }
 | | 
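 | | // `split_conjunction` flattens nested ANDs into a list of conjuncts,
 | | // stripping aliases along the way; anything that is not an AND (a bare
 | | // column, an OR) is returned as a single element.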
1604 | | #[test] |
1605 | | fn test_split_conjunction() { |
1606 | | let expr = col("a"); |
1607 | | let result = split_conjunction(&expr); |
1608 | | assert_eq!(result, vec![&expr]); |
1609 | | } |
1610 | | |
1611 | | #[test] |
1612 | | fn test_split_conjunction_two() { |
1613 | | let expr = col("a").eq(lit(5)).and(col("b")); |
1614 | | let expr1 = col("a").eq(lit(5)); |
1615 | | let expr2 = col("b"); |
1616 | | |
1617 | | let result = split_conjunction(&expr); |
1618 | | assert_eq!(result, vec![&expr1, &expr2]); |
1619 | | } |
1620 | | |
1621 | | #[test] |
1622 | | fn test_split_conjunction_alias() { |
1623 | | let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias")); |
1624 | | let expr1 = col("a").eq(lit(5)); |
1625 | | let expr2 = col("b"); // has no alias |
1626 | | |
1627 | | let result = split_conjunction(&expr); |
1628 | | assert_eq!(result, vec![&expr1, &expr2]); |
1629 | | } |
1630 | | |
1631 | | #[test] |
1632 | | fn test_split_conjunction_or() { |
1633 | | let expr = col("a").eq(lit(5)).or(col("b")); |
1634 | | let result = split_conjunction(&expr); |
1635 | | assert_eq!(result, vec![&expr]); |
1636 | | } |
1637 | | |
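 | | // `split_binary_owned` is the generalized, owning variant: it splits on
 | | // whatever operator is passed in and leaves expressions built from any
 | | // other operator untouched.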
1638 | | #[test] |
1639 | | fn test_split_binary_owned() { |
1640 | | let expr = col("a"); |
1641 | | assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]); |
1642 | | } |
1643 | | |
1644 | | #[test] |
1645 | | fn test_split_binary_owned_two() { |
1646 | | assert_eq!( |
1647 | | split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), |
1648 | | vec![col("a").eq(lit(5)), col("b")] |
1649 | | ); |
1650 | | } |
1651 | | |
1652 | | #[test] |
1653 | | fn test_split_binary_owned_different_op() { |
1654 | | let expr = col("a").eq(lit(5)).or(col("b")); |
1655 | | assert_eq!( |
1656 | | // expr is joined by OR, but we split on AND, so it comes back unsplit
1657 | | split_binary_owned(expr.clone(), Operator::And), |
1658 | | vec![expr] |
1659 | | ); |
1660 | | } |
1661 | | |
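 | | // `split_conjunction_owned` mirrors `split_conjunction` but consumes the
 | | // expression, returning owned `Expr`s instead of references.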
1662 | | #[test] |
1663 | | fn test_split_conjunction_owned() { |
1664 | | let expr = col("a"); |
1665 | | assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); |
1666 | | } |
1667 | | |
1668 | | #[test] |
1669 | | fn test_split_conjunction_owned_two() { |
1670 | | assert_eq!( |
1671 | | split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))), |
1672 | | vec![col("a").eq(lit(5)), col("b")] |
1673 | | ); |
1674 | | } |
1675 | | |
1676 | | #[test] |
1677 | | fn test_split_conjunction_owned_alias() { |
1678 | | assert_eq!( |
1679 | | split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), |
1680 | | vec![ |
1681 | | col("a").eq(lit(5)), |
1682 | | // no alias on b |
1683 | | col("b"), |
1684 | | ] |
1685 | | ); |
1686 | | } |
1687 | | |
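 | | // `conjunction` / `disjunction` fold a list of predicates back together
 | | // with AND / OR; an empty input yields None rather than a trivially
 | | // true/false literal.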
1688 | | #[test] |
1689 | | fn test_conjunction_empty() { |
1690 | | assert_eq!(conjunction(vec![]), None); |
1691 | | } |
1692 | | |
1693 | | #[test] |
1694 | | fn test_conjunction() { |
1695 | | // `[A, B, C]` |
1696 | | let expr = conjunction(vec![col("a"), col("b"), col("c")]); |
1697 | | |
1698 | | // --> `(A AND B) AND C` |
1699 | | assert_eq!(expr, Some(col("a").and(col("b")).and(col("c")))); |
1700 | | |
1701 | | // which is different from `A AND (B AND C)`, since the fold is left-associative
1702 | | assert_ne!(expr, Some(col("a").and(col("b").and(col("c"))))); |
1703 | | } |
1704 | | |
1705 | | #[test] |
1706 | | fn test_disjunction_empty() { |
1707 | | assert_eq!(disjunction(vec![]), None); |
1708 | | } |
1709 | | |
1710 | | #[test] |
1711 | | fn test_disjunction() { |
1712 | | // `[A, B, C]` |
1713 | | let expr = disjunction(vec![col("a"), col("b"), col("c")]); |
1714 | | |
1715 | | // --> `(A OR B) OR C` |
1716 | | assert_eq!(expr, Some(col("a").or(col("b")).or(col("c")))); |
1717 | | |
1718 | | // which is different from `A OR (B OR C)`
1719 | | assert_ne!(expr, Some(col("a").or(col("b").or(col("c"))))); |
1720 | | } |
1721 | | |
1722 | | #[test] |
1723 | | fn test_split_conjunction_owned_or() { |
1724 | | let expr = col("a").eq(lit(5)).or(col("b")); |
1725 | | assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); |
1726 | | } |
1727 | | |
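 | | // `expr_to_columns` accumulates the columns referenced by an expression
 | | // into a HashSet, so visiting the same column twice (here via two casts
 | | // of `a`) still yields a single entry.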
1728 | | #[test] |
1729 | | fn test_collect_expr() -> Result<()> { |
1730 | | let mut accum: HashSet<Column> = HashSet::new(); |
1731 | | expr_to_columns( |
1732 | | &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), |
1733 | | &mut accum, |
1734 | | )?; |
1735 | | expr_to_columns( |
1736 | | &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), |
1737 | | &mut accum, |
1738 | | )?; |
1739 | | assert_eq!(1, accum.len()); |
1740 | | assert!(accum.contains(&Column::from_name("a"))); |
1741 | | Ok(()) |
1742 | | } |
1743 | | } |