Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/expr/src/utils.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Expression utilities
19
20
use std::cmp::Ordering;
21
use std::collections::{HashMap, HashSet};
22
use std::ops::Deref;
23
use std::sync::Arc;
24
25
use crate::expr::{Alias, Sort, WildcardOptions, WindowFunction};
26
use crate::expr_rewriter::strip_outer_reference;
27
use crate::{
28
    and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator,
29
};
30
use datafusion_expr_common::signature::{Signature, TypeSignature};
31
32
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
33
use datafusion_common::tree_node::{
34
    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
35
};
36
use datafusion_common::utils::get_at_indices;
37
use datafusion_common::{
38
    internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef,
39
    DataFusionError, Result, TableReference,
40
};
41
42
use indexmap::IndexSet;
43
use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem};
44
45
pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
46
47
///  The value to which `COUNT(*)` is expanded to in
48
///  `COUNT(<constant>)` expressions
49
pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
50
51
/// Recursively walk a list of expression trees, collecting the unique set of columns
52
/// referenced in the expression
53
#[deprecated(since = "40.0.0", note = "Expr::add_column_refs instead")]
54
0
pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result<()> {
55
0
    for e in expr {
56
0
        expr_to_columns(e, accum)?;
57
    }
58
0
    Ok(())
59
0
}
60
61
/// Count the number of distinct exprs in a list of group by expressions. If the
62
/// first element is a `GroupingSet` expression then it must be the only expr.
63
0
pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
64
0
    grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
65
0
}
66
67
/// The [power set] (or powerset) of a set S is the set of all subsets of S, \
68
/// including the empty set and S itself.
69
///
70
/// Example:
71
///
72
/// If S is the set {x, y, z}, then all the subsets of S are \
73
///  {} \
74
///  {x} \
75
///  {y} \
76
///  {z} \
77
///  {x, y} \
78
///  {x, z} \
79
///  {y, z} \
80
///  {x, y, z} \
81
///  and hence the power set of S is {{}, {x}, {y}, {z}, {x, y}, {x, z}, {y, z}, {x, y, z}}.
82
///
83
/// [power set]: https://en.wikipedia.org/wiki/Power_set
84
0
fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>, String> {
85
0
    if slice.len() >= 64 {
86
0
        return Err("The size of the set must be less than 64.".into());
87
0
    }
88
0
89
0
    let mut v = Vec::new();
90
0
    for mask in 0..(1 << slice.len()) {
91
0
        let mut ss = vec![];
92
0
        let mut bitset = mask;
93
0
        while bitset > 0 {
94
0
            let rightmost: u64 = bitset & !(bitset - 1);
95
0
            let idx = rightmost.trailing_zeros();
96
0
            let item = slice.get(idx as usize).unwrap();
97
0
            ss.push(item);
98
0
            // zero the trailing bit
99
0
            bitset &= bitset - 1;
100
0
        }
101
0
        v.push(ss);
102
    }
103
0
    Ok(v)
104
0
}
105
106
/// check the number of expressions contained in the grouping_set
107
0
fn check_grouping_set_size_limit(size: usize) -> Result<()> {
108
0
    let max_grouping_set_size = 65535;
109
0
    if size > max_grouping_set_size {
110
0
        return plan_err!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}");
111
0
    }
112
0
113
0
    Ok(())
114
0
}
115
116
/// check the number of grouping_set contained in the grouping sets
117
0
fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
118
0
    let max_grouping_sets_size = 4096;
119
0
    if size > max_grouping_sets_size {
120
0
        return plan_err!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}");
121
0
    }
122
0
123
0
    Ok(())
124
0
}
125
126
/// Merge two grouping_set
127
///
128
/// # Example
129
/// ```text
130
/// (A, B), (C, D) -> (A, B, C, D)
131
/// ```
132
///
133
/// # Error
134
/// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit
135
///
136
/// [`DataFusionError`]: datafusion_common::DataFusionError
137
0
fn merge_grouping_set<T: Clone>(left: &[T], right: &[T]) -> Result<Vec<T>> {
138
0
    check_grouping_set_size_limit(left.len() + right.len())?;
139
0
    Ok(left.iter().chain(right.iter()).cloned().collect())
140
0
}
141
142
/// Compute the cross product of two grouping_sets
143
///
144
/// # Example
145
/// ```text
146
/// [(A, B), (C, D)], [(E), (F)] -> [(A, B, E), (A, B, F), (C, D, E), (C, D, F)]
147
/// ```
148
///
149
/// # Error
150
/// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit
151
/// - [`DataFusionError`]: The number of grouping_set in grouping_sets exceeds the maximum limit
152
///
153
/// [`DataFusionError`]: datafusion_common::DataFusionError
154
0
fn cross_join_grouping_sets<T: Clone>(
155
0
    left: &[Vec<T>],
156
0
    right: &[Vec<T>],
157
0
) -> Result<Vec<Vec<T>>> {
158
0
    let grouping_sets_size = left.len() * right.len();
159
0
160
0
    check_grouping_sets_size_limit(grouping_sets_size)?;
161
162
0
    let mut result = Vec::with_capacity(grouping_sets_size);
163
0
    for le in left {
164
0
        for re in right {
165
0
            result.push(merge_grouping_set(le, re)?);
166
        }
167
    }
168
0
    Ok(result)
169
0
}
170
171
/// Convert multiple grouping expressions into one [`GroupingSet::GroupingSets`],\
172
/// if the grouping expression does not contain [`Expr::GroupingSet`] or only has one expression,\
173
/// no conversion will be performed.
174
///
175
/// e.g.
176
///
177
/// person.id,\
178
/// GROUPING SETS ((person.age, person.salary),(person.age)),\
179
/// ROLLUP(person.state, person.birth_date)
180
///
181
/// =>
182
///
183
/// GROUPING SETS (\
184
///   (person.id, person.age, person.salary),\
185
///   (person.id, person.age, person.salary, person.state),\
186
///   (person.id, person.age, person.salary, person.state, person.birth_date),\
187
///   (person.id, person.age),\
188
///   (person.id, person.age, person.state),\
189
///   (person.id, person.age, person.state, person.birth_date)\
190
/// )
191
0
pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
192
0
    let has_grouping_set = group_expr
193
0
        .iter()
194
0
        .any(|expr| matches!(expr, Expr::GroupingSet(_)));
195
0
    if !has_grouping_set || group_expr.len() == 1 {
196
0
        return Ok(group_expr);
197
0
    }
198
    // only process mix grouping sets
199
0
    let partial_sets = group_expr
200
0
        .iter()
201
0
        .map(|expr| {
202
0
            let exprs = match expr {
203
0
                Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => {
204
0
                    check_grouping_sets_size_limit(grouping_sets.len())?;
205
0
                    grouping_sets.iter().map(|e| e.iter().collect()).collect()
206
                }
207
0
                Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => {
208
0
                    let grouping_sets = powerset(group_exprs)
209
0
                        .map_err(|e| plan_datafusion_err!("{}", e))?;
210
0
                    check_grouping_sets_size_limit(grouping_sets.len())?;
211
0
                    grouping_sets
212
                }
213
0
                Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => {
214
0
                    let size = group_exprs.len();
215
0
                    let slice = group_exprs.as_slice();
216
0
                    check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?;
217
0
                    (0..(size + 1))
218
0
                        .map(|i| slice[0..i].iter().collect())
219
0
                        .collect()
220
                }
221
0
                expr => vec![vec![expr]],
222
            };
223
0
            Ok(exprs)
224
0
        })
225
0
        .collect::<Result<Vec<_>>>()?;
226
227
    // cross join
228
0
    let grouping_sets = partial_sets
229
0
        .into_iter()
230
0
        .map(Ok)
231
0
        .reduce(|l, r| cross_join_grouping_sets(&l?, &r?))
232
0
        .transpose()?
233
0
        .map(|e| {
234
0
            e.into_iter()
235
0
                .map(|e| e.into_iter().cloned().collect())
236
0
                .collect()
237
0
        })
238
0
        .unwrap_or_default();
239
0
240
0
    Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets(
241
0
        grouping_sets,
242
0
    ))])
243
0
}
244
245
/// Find all distinct exprs in a list of group by expressions. If the
246
/// first element is a `GroupingSet` expression then it must be the only expr.
247
0
pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
248
0
    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
249
0
        if group_expr.len() > 1 {
250
0
            return plan_err!(
251
0
                "Invalid group by expressions, GroupingSet must be the only expression"
252
0
            );
253
0
        }
254
0
        Ok(grouping_set.distinct_expr())
255
    } else {
256
0
        Ok(group_expr
257
0
            .iter()
258
0
            .collect::<IndexSet<_>>()
259
0
            .into_iter()
260
0
            .collect())
261
    }
262
0
}
263
264
/// Recursively walk an expression tree, collecting the unique set of columns
265
/// referenced in the expression
266
0
pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
267
0
    expr.apply(|expr| {
268
0
        match expr {
269
0
            Expr::Column(qc) => {
270
0
                accum.insert(qc.clone());
271
0
            }
272
            // Use explicit pattern match instead of a default
273
            // implementation, so that in the future if someone adds
274
            // new Expr types, they will check here as well
275
            Expr::Unnest(_)
276
            | Expr::ScalarVariable(_, _)
277
            | Expr::Alias(_)
278
            | Expr::Literal(_)
279
            | Expr::BinaryExpr { .. }
280
            | Expr::Like { .. }
281
            | Expr::SimilarTo { .. }
282
            | Expr::Not(_)
283
            | Expr::IsNotNull(_)
284
            | Expr::IsNull(_)
285
            | Expr::IsTrue(_)
286
            | Expr::IsFalse(_)
287
            | Expr::IsUnknown(_)
288
            | Expr::IsNotTrue(_)
289
            | Expr::IsNotFalse(_)
290
            | Expr::IsNotUnknown(_)
291
            | Expr::Negative(_)
292
            | Expr::Between { .. }
293
            | Expr::Case { .. }
294
            | Expr::Cast { .. }
295
            | Expr::TryCast { .. }
296
            | Expr::ScalarFunction(..)
297
            | Expr::WindowFunction { .. }
298
            | Expr::AggregateFunction { .. }
299
            | Expr::GroupingSet(_)
300
            | Expr::InList { .. }
301
            | Expr::Exists { .. }
302
            | Expr::InSubquery(_)
303
            | Expr::ScalarSubquery(_)
304
            | Expr::Wildcard { .. }
305
            | Expr::Placeholder(_)
306
0
            | Expr::OuterReferenceColumn { .. } => {}
307
        }
308
0
        Ok(TreeNodeRecursion::Continue)
309
0
    })
310
0
    .map(|_| ())
311
0
}
312
313
/// Find excluded columns in the schema, if any
314
/// SELECT * EXCLUDE(col1, col2), would return `vec![col1, col2]`
315
0
fn get_excluded_columns(
316
0
    opt_exclude: Option<&ExcludeSelectItem>,
317
0
    opt_except: Option<&ExceptSelectItem>,
318
0
    schema: &DFSchema,
319
0
    qualifier: Option<&TableReference>,
320
0
) -> Result<Vec<Column>> {
321
0
    let mut idents = vec![];
322
0
    if let Some(excepts) = opt_except {
323
0
        idents.push(&excepts.first_element);
324
0
        idents.extend(&excepts.additional_elements);
325
0
    }
326
0
    if let Some(exclude) = opt_exclude {
327
0
        match exclude {
328
0
            ExcludeSelectItem::Single(ident) => idents.push(ident),
329
0
            ExcludeSelectItem::Multiple(idents_inner) => idents.extend(idents_inner),
330
        }
331
0
    }
332
    // Excluded columns should be unique
333
0
    let n_elem = idents.len();
334
0
    let unique_idents = idents.into_iter().collect::<HashSet<_>>();
335
0
    // if HashSet size, and vector length are different, this means that some of the excluded columns
336
0
    // are not unique. In this case return error.
337
0
    if n_elem != unique_idents.len() {
338
0
        return plan_err!("EXCLUDE or EXCEPT contains duplicate column names");
339
0
    }
340
0
341
0
    let mut result = vec![];
342
0
    for ident in unique_idents.into_iter() {
343
0
        let col_name = ident.value.as_str();
344
0
        let (qualifier, field) = schema.qualified_field_with_name(qualifier, col_name)?;
345
0
        result.push(Column::from((qualifier, field)));
346
    }
347
0
    Ok(result)
348
0
}
349
350
/// Returns all `Expr`s in the schema, except the `Column`s in the `columns_to_skip`
351
0
fn get_exprs_except_skipped(
352
0
    schema: &DFSchema,
353
0
    columns_to_skip: HashSet<Column>,
354
0
) -> Vec<Expr> {
355
0
    if columns_to_skip.is_empty() {
356
0
        schema.iter().map(Expr::from).collect::<Vec<Expr>>()
357
    } else {
358
0
        schema
359
0
            .columns()
360
0
            .iter()
361
0
            .filter_map(|c| {
362
0
                if !columns_to_skip.contains(c) {
363
0
                    Some(Expr::Column(c.clone()))
364
                } else {
365
0
                    None
366
                }
367
0
            })
368
0
            .collect::<Vec<Expr>>()
369
    }
370
0
}
371
372
/// Resolves an `Expr::Wildcard` to a collection of `Expr::Column`'s.
373
0
pub fn expand_wildcard(
374
0
    schema: &DFSchema,
375
0
    plan: &LogicalPlan,
376
0
    wildcard_options: Option<&WildcardOptions>,
377
0
) -> Result<Vec<Expr>> {
378
0
    let using_columns = plan.using_columns()?;
379
0
    let mut columns_to_skip = using_columns
380
0
        .into_iter()
381
0
        // For each USING JOIN condition, only expand to one of each join column in projection
382
0
        .flat_map(|cols| {
383
0
            let mut cols = cols.into_iter().collect::<Vec<_>>();
384
0
            // sort join columns to make sure we consistently keep the same
385
0
            // qualified column
386
0
            cols.sort();
387
0
            let mut out_column_names: HashSet<String> = HashSet::new();
388
0
            cols.into_iter()
389
0
                .filter_map(|c| {
390
0
                    if out_column_names.contains(&c.name) {
391
0
                        Some(c)
392
                    } else {
393
0
                        out_column_names.insert(c.name);
394
0
                        None
395
                    }
396
0
                })
397
0
                .collect::<Vec<_>>()
398
0
        })
399
0
        .collect::<HashSet<_>>();
400
0
    let excluded_columns = if let Some(WildcardOptions {
401
0
        exclude: opt_exclude,
402
0
        except: opt_except,
403
        ..
404
0
    }) = wildcard_options
405
    {
406
0
        get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, None)?
407
    } else {
408
0
        vec![]
409
    };
410
    // Add each excluded `Column` to columns_to_skip
411
0
    columns_to_skip.extend(excluded_columns);
412
0
    Ok(get_exprs_except_skipped(schema, columns_to_skip))
413
0
}
414
415
/// Resolves an `Expr::Wildcard` to a collection of qualified `Expr::Column`'s.
416
0
pub fn expand_qualified_wildcard(
417
0
    qualifier: &TableReference,
418
0
    schema: &DFSchema,
419
0
    wildcard_options: Option<&WildcardOptions>,
420
0
) -> Result<Vec<Expr>> {
421
0
    let qualified_indices = schema.fields_indices_with_qualified(qualifier);
422
0
    let projected_func_dependencies = schema
423
0
        .functional_dependencies()
424
0
        .project_functional_dependencies(&qualified_indices, qualified_indices.len());
425
0
    let fields_with_qualified = get_at_indices(schema.fields(), &qualified_indices)?;
426
0
    if fields_with_qualified.is_empty() {
427
0
        return plan_err!("Invalid qualifier {qualifier}");
428
0
    }
429
0
430
0
    let qualified_schema = Arc::new(Schema::new(fields_with_qualified));
431
0
    let qualified_dfschema =
432
0
        DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)?
433
0
            .with_functional_dependencies(projected_func_dependencies)?;
434
0
    let excluded_columns = if let Some(WildcardOptions {
435
0
        exclude: opt_exclude,
436
0
        except: opt_except,
437
        ..
438
0
    }) = wildcard_options
439
    {
440
0
        get_excluded_columns(
441
0
            opt_exclude.as_ref(),
442
0
            opt_except.as_ref(),
443
0
            schema,
444
0
            Some(qualifier),
445
0
        )?
446
    } else {
447
0
        vec![]
448
    };
449
    // Add each excluded `Column` to columns_to_skip
450
0
    let mut columns_to_skip = HashSet::new();
451
0
    columns_to_skip.extend(excluded_columns);
452
0
    Ok(get_exprs_except_skipped(
453
0
        &qualified_dfschema,
454
0
        columns_to_skip,
455
0
    ))
456
0
}
457
458
/// (expr, "is the SortExpr for window (either comes from PARTITION BY or ORDER BY columns)")
459
/// if bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column
460
type WindowSortKey = Vec<(Sort, bool)>;
461
462
/// Generate a sort key for a given window expr's partition_by and order_by expr
463
0
pub fn generate_sort_key(
464
0
    partition_by: &[Expr],
465
0
    order_by: &[Sort],
466
0
) -> Result<WindowSortKey> {
467
0
    let normalized_order_by_keys = order_by
468
0
        .iter()
469
0
        .map(|e| {
470
0
            let Sort { expr, .. } = e;
471
0
            Sort::new(expr.clone(), true, false)
472
0
        })
473
0
        .collect::<Vec<_>>();
474
0
475
0
    let mut final_sort_keys = vec![];
476
0
    let mut is_partition_flag = vec![];
477
0
    partition_by.iter().for_each(|e| {
478
0
        // By default, create sort key with ASC is true and NULLS LAST to be consistent with
479
0
        // PostgreSQL's rule: https://www.postgresql.org/docs/current/queries-order.html
480
0
        let e = e.clone().sort(true, false);
481
0
        if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) {
482
0
            let order_by_key = &order_by[pos];
483
0
            if !final_sort_keys.contains(order_by_key) {
484
0
                final_sort_keys.push(order_by_key.clone());
485
0
                is_partition_flag.push(true);
486
0
            }
487
0
        } else if !final_sort_keys.contains(&e) {
488
0
            final_sort_keys.push(e);
489
0
            is_partition_flag.push(true);
490
0
        }
491
0
    });
492
0
493
0
    order_by.iter().for_each(|e| {
494
0
        if !final_sort_keys.contains(e) {
495
0
            final_sort_keys.push(e.clone());
496
0
            is_partition_flag.push(false);
497
0
        }
498
0
    });
499
0
    let res = final_sort_keys
500
0
        .into_iter()
501
0
        .zip(is_partition_flag)
502
0
        .collect::<Vec<_>>();
503
0
    Ok(res)
504
0
}
505
506
/// Compare the sort expr as PostgreSQL's common_prefix_cmp():
507
/// <https://github.com/postgres/postgres/blob/master/src/backend/optimizer/plan/planner.c>
508
0
pub fn compare_sort_expr(
509
0
    sort_expr_a: &Sort,
510
0
    sort_expr_b: &Sort,
511
0
    schema: &DFSchemaRef,
512
0
) -> Ordering {
513
0
    let Sort {
514
0
        expr: expr_a,
515
0
        asc: asc_a,
516
0
        nulls_first: nulls_first_a,
517
0
    } = sort_expr_a;
518
0
519
0
    let Sort {
520
0
        expr: expr_b,
521
0
        asc: asc_b,
522
0
        nulls_first: nulls_first_b,
523
0
    } = sort_expr_b;
524
0
525
0
    let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema);
526
0
    let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema);
527
0
    for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) {
528
0
        match idx_a.cmp(idx_b) {
529
            Ordering::Less => {
530
0
                return Ordering::Less;
531
            }
532
            Ordering::Greater => {
533
0
                return Ordering::Greater;
534
            }
535
0
            Ordering::Equal => {}
536
        }
537
    }
538
0
    match ref_indexes_a.len().cmp(&ref_indexes_b.len()) {
539
0
        Ordering::Less => return Ordering::Greater,
540
        Ordering::Greater => {
541
0
            return Ordering::Less;
542
        }
543
0
        Ordering::Equal => {}
544
0
    }
545
0
    match (asc_a, asc_b) {
546
        (true, false) => {
547
0
            return Ordering::Greater;
548
        }
549
        (false, true) => {
550
0
            return Ordering::Less;
551
        }
552
0
        _ => {}
553
0
    }
554
0
    match (nulls_first_a, nulls_first_b) {
555
        (true, false) => {
556
0
            return Ordering::Less;
557
        }
558
        (false, true) => {
559
0
            return Ordering::Greater;
560
        }
561
0
        _ => {}
562
0
    }
563
0
    Ordering::Equal
564
0
}
565
566
/// group a slice of window expression expr by their order by expressions
567
0
pub fn group_window_expr_by_sort_keys(
568
0
    window_expr: Vec<Expr>,
569
0
) -> Result<Vec<(WindowSortKey, Vec<Expr>)>> {
570
0
    let mut result = vec![];
571
0
    window_expr.into_iter().try_for_each(|expr| match &expr {
572
0
        Expr::WindowFunction( WindowFunction{ partition_by, order_by, .. }) => {
573
0
            let sort_key = generate_sort_key(partition_by, order_by)?;
574
0
            if let Some((_, values)) = result.iter_mut().find(
575
0
                |group: &&mut (WindowSortKey, Vec<Expr>)| matches!(group, (key, _) if *key == sort_key),
576
0
            ) {
577
0
                values.push(expr);
578
0
            } else {
579
0
                result.push((sort_key, vec![expr]))
580
            }
581
0
            Ok(())
582
        }
583
0
        other => internal_err!(
584
0
            "Impossibly got non-window expr {other:?}"
585
0
        ),
586
0
    })?;
587
0
    Ok(result)
588
0
}
589
590
/// Collect all deeply nested `Expr::AggregateFunction`.
591
/// They are returned in order of occurrence (depth
592
/// first), with duplicates omitted.
593
0
pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec<Expr> {
594
0
    find_exprs_in_exprs(exprs, &|nested_expr| {
595
0
        matches!(nested_expr, Expr::AggregateFunction { .. })
596
0
    })
597
0
}
598
599
/// Collect all deeply nested `Expr::WindowFunction`. They are returned in order of occurrence
600
/// (depth first), with duplicates omitted.
601
0
pub fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
602
0
    find_exprs_in_exprs(exprs, &|nested_expr| {
603
0
        matches!(nested_expr, Expr::WindowFunction { .. })
604
0
    })
605
0
}
606
607
/// Collect all deeply nested `Expr::OuterReferenceColumn`. They are returned in order of occurrence
608
/// (depth first), with duplicates omitted.
609
0
pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> {
610
0
    find_exprs_in_expr(expr, &|nested_expr| {
611
0
        matches!(nested_expr, Expr::OuterReferenceColumn { .. })
612
0
    })
613
0
}
614
615
/// Search the provided `Expr`'s, and all of their nested `Expr`, for any that
616
/// pass the provided test. The returned `Expr`'s are deduplicated and returned
617
/// in order of appearance (depth first).
618
0
fn find_exprs_in_exprs<F>(exprs: &[Expr], test_fn: &F) -> Vec<Expr>
619
0
where
620
0
    F: Fn(&Expr) -> bool,
621
0
{
622
0
    exprs
623
0
        .iter()
624
0
        .flat_map(|expr| find_exprs_in_expr(expr, test_fn))
625
0
        .fold(vec![], |mut acc, expr| {
626
0
            if !acc.contains(&expr) {
627
0
                acc.push(expr)
628
0
            }
629
0
            acc
630
0
        })
631
0
}
632
633
/// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the
634
/// provided test. The returned `Expr`'s are deduplicated and returned in order
635
/// of appearance (depth first).
636
0
fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
637
0
where
638
0
    F: Fn(&Expr) -> bool,
639
0
{
640
0
    let mut exprs = vec![];
641
0
    expr.apply(|expr| {
642
0
        if test_fn(expr) {
643
0
            if !(exprs.contains(expr)) {
644
0
                exprs.push(expr.clone())
645
0
            }
646
            // stop recursing down this expr once we find a match
647
0
            return Ok(TreeNodeRecursion::Jump);
648
0
        }
649
0
650
0
        Ok(TreeNodeRecursion::Continue)
651
0
    })
652
0
    // pre_visit always returns OK, so this will always too
653
0
    .expect("no way to return error during recursion");
654
0
    exprs
655
0
}
656
657
/// Recursively inspect an [`Expr`] and all its children.
658
0
pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E>
659
0
where
660
0
    F: FnMut(&Expr) -> Result<(), E>,
661
0
{
662
0
    let mut err = Ok(());
663
0
    expr.apply(|expr| {
664
0
        if let Err(e) = f(expr) {
665
            // save the error for later (it may not be a DataFusionError
666
0
            err = Err(e);
667
0
            Ok(TreeNodeRecursion::Stop)
668
        } else {
669
            // keep going
670
0
            Ok(TreeNodeRecursion::Continue)
671
        }
672
0
    })
673
0
    // The closure always returns OK, so this will always too
674
0
    .expect("no way to return error during recursion");
675
0
676
0
    err
677
0
}
678
679
/// Create field meta-data from an expression, for use in a result set schema
680
0
pub fn exprlist_to_fields<'a>(
681
0
    exprs: impl IntoIterator<Item = &'a Expr>,
682
0
    plan: &LogicalPlan,
683
0
) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> {
684
0
    // look for exact match in plan's output schema
685
0
    let wildcard_schema = find_base_plan(plan).schema();
686
0
    let input_schema = plan.schema();
687
0
    let result = exprs
688
0
        .into_iter()
689
0
        .map(|e| match e {
690
0
            Expr::Wildcard { qualifier, options } => match qualifier {
691
                None => {
692
0
                    let excluded: Vec<String> = get_excluded_columns(
693
0
                        options.exclude.as_ref(),
694
0
                        options.except.as_ref(),
695
0
                        wildcard_schema,
696
0
                        None,
697
0
                    )?
698
0
                    .into_iter()
699
0
                    .map(|c| c.flat_name())
700
0
                    .collect();
701
0
                    Ok::<_, DataFusionError>(
702
0
                        wildcard_schema
703
0
                            .field_names()
704
0
                            .iter()
705
0
                            .enumerate()
706
0
                            .filter(|(_, s)| !excluded.contains(s))
707
0
                            .map(|(i, _)| wildcard_schema.qualified_field(i))
708
0
                            .map(|(qualifier, f)| {
709
0
                                (qualifier.cloned(), Arc::new(f.to_owned()))
710
0
                            })
711
0
                            .collect::<Vec<_>>(),
712
0
                    )
713
                }
714
0
                Some(qualifier) => {
715
0
                    let excluded: Vec<String> = get_excluded_columns(
716
0
                        options.exclude.as_ref(),
717
0
                        options.except.as_ref(),
718
0
                        wildcard_schema,
719
0
                        Some(qualifier),
720
0
                    )?
721
0
                    .into_iter()
722
0
                    .map(|c| c.flat_name())
723
0
                    .collect();
724
0
                    Ok(wildcard_schema
725
0
                        .fields_with_qualified(qualifier)
726
0
                        .into_iter()
727
0
                        .filter_map(|field| {
728
0
                            let flat_name = format!("{}.{}", qualifier, field.name());
729
0
                            if excluded.contains(&flat_name) {
730
0
                                None
731
                            } else {
732
0
                                Some((
733
0
                                    Some(qualifier.clone()),
734
0
                                    Arc::new(field.to_owned()),
735
0
                                ))
736
                            }
737
0
                        })
738
0
                        .collect::<Vec<_>>())
739
                }
740
            },
741
0
            _ => Ok(vec![e.to_field(input_schema)?]),
742
0
        })
743
0
        .collect::<Result<Vec<_>>>()?
744
0
        .into_iter()
745
0
        .flatten()
746
0
        .collect();
747
0
    Ok(result)
748
0
}
749
750
/// Find the suitable base plan to expand the wildcard expression recursively.
751
/// When planning [LogicalPlan::Window] and [LogicalPlan::Aggregate], we will generate
752
/// an intermediate plan based on the relation plan (e.g. [LogicalPlan::TableScan], [LogicalPlan::Subquery], ...).
753
/// If we expand a wildcard expression basing the intermediate plan, we could get some duplicate fields.
754
0
pub fn find_base_plan(input: &LogicalPlan) -> &LogicalPlan {
755
0
    match input {
756
0
        LogicalPlan::Window(window) => find_base_plan(&window.input),
757
0
        LogicalPlan::Aggregate(agg) => find_base_plan(&agg.input),
758
        // [SqlToRel::try_process_unnest] will convert Expr(Unnest(Expr)) to Projection/Unnest/Projection
759
        // We should expand the wildcard expression based on the input plan of the inner Projection.
760
0
        LogicalPlan::Unnest(unnest) => {
761
0
            if let LogicalPlan::Projection(projection) = unnest.input.deref() {
762
0
                find_base_plan(&projection.input)
763
            } else {
764
0
                input
765
            }
766
        }
767
0
        LogicalPlan::Filter(filter) => {
768
0
            if filter.having {
769
                // If a filter is used for a having clause, its input plan is an aggregation.
770
                // We should expand the wildcard expression based on the aggregation's input plan.
771
0
                find_base_plan(&filter.input)
772
            } else {
773
0
                input
774
            }
775
        }
776
0
        _ => input,
777
    }
778
0
}
779
780
/// Count the number of real fields. We should expand the wildcard expression to get the actual number.
///
/// `schema` is the default schema to expand wildcards against; `wildcard_schema`,
/// when provided, takes precedence (e.g. the base plan schema from [`find_base_plan`]).
/// Non-wildcard expressions each count as exactly one field.
pub fn exprlist_len(
    exprs: &[Expr],
    schema: &DFSchemaRef,
    wildcard_schema: Option<&DFSchemaRef>,
) -> Result<usize> {
    exprs
        .iter()
        .map(|e| match e {
            // Unqualified wildcard (`*`): count all fields of the wildcard
            // schema (falling back to `schema`), minus EXCLUDE/EXCEPT columns.
            Expr::Wildcard {
                qualifier: None,
                options,
            } => {
                let excluded = get_excluded_columns(
                    options.exclude.as_ref(),
                    options.except.as_ref(),
                    wildcard_schema.unwrap_or(schema),
                    None,
                )?
                .into_iter()
                .collect::<HashSet<Column>>();
                Ok(
                    get_exprs_except_skipped(wildcard_schema.unwrap_or(schema), excluded)
                        .len(),
                )
            }
            // Qualified wildcard (`t.*`): first narrow the schema down to the
            // fields belonging to `qualifier`, then apply EXCLUDE/EXCEPT.
            Expr::Wildcard {
                qualifier: Some(qualifier),
                options,
            } => {
                let related_wildcard_schema = wildcard_schema.as_ref().map_or_else(
                    || Ok(Arc::clone(schema)),
                    |schema| {
                        // Eliminate the fields coming from other tables.
                        let qualified_fields = schema
                            .fields()
                            .iter()
                            .enumerate()
                            .filter_map(|(idx, field)| {
                                let (maybe_table_ref, _) = schema.qualified_field(idx);
                                // Unqualified fields are kept as well — they may
                                // still resolve to this relation.
                                if maybe_table_ref.map_or(true, |q| q == qualifier) {
                                    Some((maybe_table_ref.cloned(), Arc::clone(field)))
                                } else {
                                    None
                                }
                            })
                            .collect::<Vec<_>>();
                        let metadata = schema.metadata().clone();
                        DFSchema::new_with_metadata(qualified_fields, metadata)
                            .map(Arc::new)
                    },
                )?;
                let excluded = get_excluded_columns(
                    options.exclude.as_ref(),
                    options.except.as_ref(),
                    related_wildcard_schema.as_ref(),
                    Some(qualifier),
                )?
                .into_iter()
                .collect::<HashSet<Column>>();
                Ok(
                    get_exprs_except_skipped(related_wildcard_schema.as_ref(), excluded)
                        .len(),
                )
            }
            // Any other expression contributes exactly one output field.
            _ => Ok(1),
        })
        .sum()
}
849
850
/// Convert an expression into Column expression if it's already provided as input plan.
851
///
852
/// For example, it rewrites:
853
///
854
/// ```text
855
/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
856
/// .project(vec![col("c1"), sum(col("c2"))?
857
/// ```
858
///
859
/// Into:
860
///
861
/// ```text
862
/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
863
/// .project(vec![col("c1"), col("SUM(c2)")?
864
/// ```
865
0
pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> {
866
0
    let output_exprs = match input.columnized_output_exprs() {
867
0
        Ok(exprs) if !exprs.is_empty() => exprs,
868
0
        _ => return Ok(e),
869
    };
870
0
    let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect();
871
0
    e.transform_down(|node: Expr| match exprs_map.get(&node) {
872
0
        Some(column) => Ok(Transformed::new(
873
0
            Expr::Column(column.clone()),
874
0
            true,
875
0
            TreeNodeRecursion::Jump,
876
0
        )),
877
0
        None => Ok(Transformed::no(node)),
878
0
    })
879
0
    .data()
880
0
}
881
882
/// Collect all deeply nested `Expr::Column`'s. They are returned in order of
883
/// appearance (depth first), and may contain duplicates.
884
0
pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
885
0
    exprs
886
0
        .iter()
887
0
        .flat_map(find_columns_referenced_by_expr)
888
0
        .map(Expr::Column)
889
0
        .collect()
890
0
}
891
892
0
pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
893
0
    let mut exprs = vec![];
894
0
    e.apply(|expr| {
895
0
        if let Expr::Column(c) = expr {
896
0
            exprs.push(c.clone())
897
0
        }
898
0
        Ok(TreeNodeRecursion::Continue)
899
0
    })
900
0
    // As the closure always returns Ok, this "can't" error
901
0
    .expect("Unexpected error");
902
0
    exprs
903
0
}
904
905
/// Convert any `Expr` to an `Expr::Column`.
906
0
pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
907
0
    match expr {
908
0
        Expr::Column(col) => {
909
0
            let (qualifier, field) = plan.schema().qualified_field_from_column(col)?;
910
0
            Ok(Expr::from(Column::from((qualifier, field))))
911
        }
912
0
        _ => Ok(Expr::Column(Column::from_name(
913
0
            expr.schema_name().to_string(),
914
0
        ))),
915
    }
916
0
}
917
918
/// Recursively walk an expression tree, collecting the column indexes
919
/// referenced in the expression
920
0
pub(crate) fn find_column_indexes_referenced_by_expr(
921
0
    e: &Expr,
922
0
    schema: &DFSchemaRef,
923
0
) -> Vec<usize> {
924
0
    let mut indexes = vec![];
925
0
    e.apply(|expr| {
926
0
        match expr {
927
0
            Expr::Column(qc) => {
928
0
                if let Ok(idx) = schema.index_of_column(qc) {
929
0
                    indexes.push(idx);
930
0
                }
931
            }
932
0
            Expr::Literal(_) => {
933
0
                indexes.push(usize::MAX);
934
0
            }
935
0
            _ => {}
936
        }
937
0
        Ok(TreeNodeRecursion::Continue)
938
0
    })
939
0
    .unwrap();
940
0
    indexes
941
0
}
942
943
/// can this data type be used in hash join equal conditions??
944
/// data types here come from function 'equal_rows', if more data types are supported
945
/// in equal_rows(hash join), add those data types here to generate join logical plan.
946
0
pub fn can_hash(data_type: &DataType) -> bool {
947
0
    match data_type {
948
0
        DataType::Null => true,
949
0
        DataType::Boolean => true,
950
0
        DataType::Int8 => true,
951
0
        DataType::Int16 => true,
952
0
        DataType::Int32 => true,
953
0
        DataType::Int64 => true,
954
0
        DataType::UInt8 => true,
955
0
        DataType::UInt16 => true,
956
0
        DataType::UInt32 => true,
957
0
        DataType::UInt64 => true,
958
0
        DataType::Float32 => true,
959
0
        DataType::Float64 => true,
960
0
        DataType::Timestamp(time_unit, _) => match time_unit {
961
0
            TimeUnit::Second => true,
962
0
            TimeUnit::Millisecond => true,
963
0
            TimeUnit::Microsecond => true,
964
0
            TimeUnit::Nanosecond => true,
965
        },
966
0
        DataType::Utf8 => true,
967
0
        DataType::LargeUtf8 => true,
968
0
        DataType::Decimal128(_, _) => true,
969
0
        DataType::Date32 => true,
970
0
        DataType::Date64 => true,
971
0
        DataType::FixedSizeBinary(_) => true,
972
0
        DataType::Dictionary(key_type, value_type)
973
0
            if *value_type.as_ref() == DataType::Utf8 =>
974
0
        {
975
0
            DataType::is_dictionary_key_type(key_type)
976
        }
977
0
        DataType::List(_) => true,
978
0
        DataType::LargeList(_) => true,
979
0
        DataType::FixedSizeList(_, _) => true,
980
0
        DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
981
0
        _ => false,
982
    }
983
0
}
984
985
/// Check whether all columns are from the schema.
986
0
pub fn check_all_columns_from_schema(
987
0
    columns: &HashSet<&Column>,
988
0
    schema: &DFSchema,
989
0
) -> Result<bool> {
990
0
    for col in columns.iter() {
991
0
        let exist = schema.is_column_from_schema(col);
992
0
        if !exist {
993
0
            return Ok(false);
994
0
        }
995
    }
996
997
0
    Ok(true)
998
0
}
999
1000
/// Give two sides of the equijoin predicate, return a valid join key pair.
1001
/// If there is no valid join key pair, return None.
1002
///
1003
/// A valid join means:
1004
/// 1. All referenced column of the left side is from the left schema, and
1005
///    all referenced column of the right side is from the right schema.
1006
/// 2. Or opposite. All referenced column of the left side is from the right schema,
1007
///    and the right side is from the left schema.
1008
///
1009
0
pub fn find_valid_equijoin_key_pair(
1010
0
    left_key: &Expr,
1011
0
    right_key: &Expr,
1012
0
    left_schema: &DFSchema,
1013
0
    right_schema: &DFSchema,
1014
0
) -> Result<Option<(Expr, Expr)>> {
1015
0
    let left_using_columns = left_key.column_refs();
1016
0
    let right_using_columns = right_key.column_refs();
1017
0
1018
0
    // Conditions like a = 10, will be added to non-equijoin.
1019
0
    if left_using_columns.is_empty() || right_using_columns.is_empty() {
1020
0
        return Ok(None);
1021
0
    }
1022
0
1023
0
    if check_all_columns_from_schema(&left_using_columns, left_schema)?
1024
0
        && check_all_columns_from_schema(&right_using_columns, right_schema)?
1025
    {
1026
0
        return Ok(Some((left_key.clone(), right_key.clone())));
1027
0
    } else if check_all_columns_from_schema(&right_using_columns, left_schema)?
1028
0
        && check_all_columns_from_schema(&left_using_columns, right_schema)?
1029
    {
1030
0
        return Ok(Some((right_key.clone(), left_key.clone())));
1031
0
    }
1032
0
1033
0
    Ok(None)
1034
0
}
1035
1036
/// Creates a detailed error message for a function with wrong signature.
1037
///
1038
/// For example, a query like `select round(3.14, 1.1);` would yield:
1039
/// ```text
1040
/// Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts.
1041
///     Candidate functions:
1042
///     round(Float64, Int64)
1043
///     round(Float32, Int64)
1044
///     round(Float64)
1045
///     round(Float32)
1046
/// ```
1047
0
pub fn generate_signature_error_msg(
1048
0
    func_name: &str,
1049
0
    func_signature: Signature,
1050
0
    input_expr_types: &[DataType],
1051
0
) -> String {
1052
0
    let candidate_signatures = func_signature
1053
0
        .type_signature
1054
0
        .to_string_repr()
1055
0
        .iter()
1056
0
        .map(|args_str| format!("\t{func_name}({args_str})"))
1057
0
        .collect::<Vec<String>>()
1058
0
        .join("\n");
1059
0
1060
0
    format!(
1061
0
            "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}",
1062
0
            func_name, TypeSignature::join_types(input_expr_types, ", "), candidate_signatures
1063
0
        )
1064
0
}
1065
1066
/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1067
///
1068
/// See [`split_conjunction_owned`] for more details and an example.
1069
0
pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
1070
0
    split_conjunction_impl(expr, vec![])
1071
0
}
1072
1073
0
fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
1074
0
    match expr {
1075
        Expr::BinaryExpr(BinaryExpr {
1076
0
            right,
1077
0
            op: Operator::And,
1078
0
            left,
1079
0
        }) => {
1080
0
            let exprs = split_conjunction_impl(left, exprs);
1081
0
            split_conjunction_impl(right, exprs)
1082
        }
1083
0
        Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs),
1084
0
        other => {
1085
0
            exprs.push(other);
1086
0
            exprs
1087
        }
1088
    }
1089
0
}
1090
1091
/// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1092
///
1093
/// This is often used to "split" filter expressions such as `col1 = 5
1094
/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
1095
///
1096
/// # Example
1097
/// ```
1098
/// # use datafusion_expr::{col, lit};
1099
/// # use datafusion_expr::utils::split_conjunction_owned;
1100
/// // a=1 AND b=2
1101
/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
1102
///
1103
/// // [a=1, b=2]
1104
/// let split = vec![
1105
///   col("a").eq(lit(1)),
1106
///   col("b").eq(lit(2)),
1107
/// ];
1108
///
1109
/// // use split_conjunction_owned to split them
1110
/// assert_eq!(split_conjunction_owned(expr), split);
1111
/// ```
1112
0
pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
1113
0
    split_binary_owned(expr, Operator::And)
1114
0
}
1115
1116
/// Splits an owned binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`
1117
///
1118
/// This is often used to "split" expressions such as `col1 = 5
1119
/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
1120
///
1121
/// # Example
1122
/// ```
1123
/// # use datafusion_expr::{col, lit, Operator};
1124
/// # use datafusion_expr::utils::split_binary_owned;
1125
/// # use std::ops::Add;
1126
/// // a=1 + b=2
1127
/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2)));
1128
///
1129
/// // [a=1, b=2]
1130
/// let split = vec![
1131
///   col("a").eq(lit(1)),
1132
///   col("b").eq(lit(2)),
1133
/// ];
1134
///
1135
/// // use split_binary_owned to split them
1136
/// assert_eq!(split_binary_owned(expr, Operator::Plus), split);
1137
/// ```
1138
0
pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
1139
0
    split_binary_owned_impl(expr, op, vec![])
1140
0
}
1141
1142
0
fn split_binary_owned_impl(
1143
0
    expr: Expr,
1144
0
    operator: Operator,
1145
0
    mut exprs: Vec<Expr>,
1146
0
) -> Vec<Expr> {
1147
0
    match expr {
1148
0
        Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => {
1149
0
            let exprs = split_binary_owned_impl(*left, operator, exprs);
1150
0
            split_binary_owned_impl(*right, operator, exprs)
1151
        }
1152
0
        Expr::Alias(Alias { expr, .. }) => {
1153
0
            split_binary_owned_impl(*expr, operator, exprs)
1154
        }
1155
0
        other => {
1156
0
            exprs.push(other);
1157
0
            exprs
1158
        }
1159
    }
1160
0
}
1161
1162
/// Splits an binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`
1163
///
1164
/// See [`split_binary_owned`] for more details and an example.
1165
0
pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
1166
0
    split_binary_impl(expr, op, vec![])
1167
0
}
1168
1169
0
fn split_binary_impl<'a>(
1170
0
    expr: &'a Expr,
1171
0
    operator: Operator,
1172
0
    mut exprs: Vec<&'a Expr>,
1173
0
) -> Vec<&'a Expr> {
1174
0
    match expr {
1175
0
        Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => {
1176
0
            let exprs = split_binary_impl(left, operator, exprs);
1177
0
            split_binary_impl(right, operator, exprs)
1178
        }
1179
0
        Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs),
1180
0
        other => {
1181
0
            exprs.push(other);
1182
0
            exprs
1183
        }
1184
    }
1185
0
}
1186
1187
/// Combines an array of filter expressions into a single filter
1188
/// expression consisting of the input filter expressions joined with
1189
/// logical AND.
1190
///
1191
/// Returns None if the filters array is empty.
1192
///
1193
/// # Example
1194
/// ```
1195
/// # use datafusion_expr::{col, lit};
1196
/// # use datafusion_expr::utils::conjunction;
1197
/// // a=1 AND b=2
1198
/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
1199
///
1200
/// // [a=1, b=2]
1201
/// let split = vec![
1202
///   col("a").eq(lit(1)),
1203
///   col("b").eq(lit(2)),
1204
/// ];
1205
///
1206
/// // use conjunction to join them together with `AND`
1207
/// assert_eq!(conjunction(split), Some(expr));
1208
/// ```
1209
0
pub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1210
0
    filters.into_iter().reduce(Expr::and)
1211
0
}
1212
1213
/// Combines an array of filter expressions into a single filter
1214
/// expression consisting of the input filter expressions joined with
1215
/// logical OR.
1216
///
1217
/// Returns None if the filters array is empty.
1218
///
1219
/// # Example
1220
/// ```
1221
/// # use datafusion_expr::{col, lit};
1222
/// # use datafusion_expr::utils::disjunction;
1223
/// // a=1 OR b=2
1224
/// let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2)));
1225
///
1226
/// // [a=1, b=2]
1227
/// let split = vec![
1228
///   col("a").eq(lit(1)),
1229
///   col("b").eq(lit(2)),
1230
/// ];
1231
///
1232
/// // use disjuncton to join them together with `OR`
1233
/// assert_eq!(disjunction(split), Some(expr));
1234
/// ```
1235
0
pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1236
0
    filters.into_iter().reduce(Expr::or)
1237
0
}
1238
1239
/// Returns a new [LogicalPlan] that filters the output of  `plan` with a
1240
/// [LogicalPlan::Filter] with all `predicates` ANDed.
1241
///
1242
/// # Example
1243
/// Before:
1244
/// ```text
1245
/// plan
1246
/// ```
1247
///
1248
/// After:
1249
/// ```text
1250
/// Filter(predicate)
1251
///   plan
1252
/// ```
1253
0
pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
1254
0
    // reduce filters to a single filter with an AND
1255
0
    let predicate = predicates
1256
0
        .iter()
1257
0
        .skip(1)
1258
0
        .fold(predicates[0].clone(), |acc, predicate| {
1259
0
            and(acc, (*predicate).to_owned())
1260
0
        });
1261
0
1262
0
    Ok(LogicalPlan::Filter(Filter::try_new(
1263
0
        predicate,
1264
0
        Arc::new(plan),
1265
0
    )?))
1266
0
}
1267
1268
/// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and
1269
/// one not in the subquery (closed upon from outer scope)
1270
///
1271
/// # Arguments
1272
///
1273
/// * `exprs` - List of expressions that may or may not be joins
1274
///
1275
/// # Return value
1276
///
1277
/// Tuple of (expressions containing joins, remaining non-join expressions)
1278
0
pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec<Expr>, Vec<Expr>)> {
1279
0
    let mut joins = vec![];
1280
0
    let mut others = vec![];
1281
0
    for filter in exprs.into_iter() {
1282
        // If the expression contains correlated predicates, add it to join filters
1283
0
        if filter.contains_outer() {
1284
0
            if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right))
1285
0
            {
1286
0
                joins.push(strip_outer_reference((*filter).clone()));
1287
0
            }
1288
0
        } else {
1289
0
            others.push((*filter).clone());
1290
0
        }
1291
    }
1292
1293
0
    Ok((joins, others))
1294
0
}
1295
1296
/// Returns the first (and only) element in a slice, or an error
1297
///
1298
/// # Arguments
1299
///
1300
/// * `slice` - The slice to extract from
1301
///
1302
/// # Return value
1303
///
1304
/// The first element, or an error
1305
0
pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
1306
0
    match slice {
1307
0
        [it] => Ok(it),
1308
0
        [] => plan_err!("No items found!"),
1309
0
        _ => plan_err!("More than one item found!"),
1310
    }
1311
0
}
1312
1313
/// merge inputs schema into a single schema.
1314
0
pub fn merge_schema(inputs: &[&LogicalPlan]) -> DFSchema {
1315
0
    if inputs.len() == 1 {
1316
0
        inputs[0].schema().as_ref().clone()
1317
    } else {
1318
0
        inputs.iter().map(|input| input.schema()).fold(
1319
0
            DFSchema::empty(),
1320
0
            |mut lhs, rhs| {
1321
0
                lhs.merge(rhs);
1322
0
                lhs
1323
0
            },
1324
0
        )
1325
    }
1326
0
}
1327
1328
/// Build state name. State is the intermediate state of the aggregate function.
///
/// Produces `"{name}[{state_name}]"`.
pub fn format_state_name(name: &str, state_name: &str) -> String {
    // Pre-size for the two parts plus the bracket pair.
    let mut out = String::with_capacity(name.len() + state_name.len() + 2);
    out.push_str(name);
    out.push('[');
    out.push_str(state_name);
    out.push(']');
    out
}
1332
1333
#[cfg(test)]
1334
mod tests {
1335
    use super::*;
1336
    use crate::{
1337
        col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup,
1338
        test::function_stub::max_udaf, test::function_stub::min_udaf,
1339
        test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition,
1340
    };
1341
1342
    #[test]
1343
    fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
1344
        let result = group_window_expr_by_sort_keys(vec![])?;
1345
        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![];
1346
        assert_eq!(expected, result);
1347
        Ok(())
1348
    }
1349
1350
    #[test]
1351
    fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
1352
        let max1 = Expr::WindowFunction(expr::WindowFunction::new(
1353
            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1354
            vec![col("name")],
1355
        ));
1356
        let max2 = Expr::WindowFunction(expr::WindowFunction::new(
1357
            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1358
            vec![col("name")],
1359
        ));
1360
        let min3 = Expr::WindowFunction(expr::WindowFunction::new(
1361
            WindowFunctionDefinition::AggregateUDF(min_udaf()),
1362
            vec![col("name")],
1363
        ));
1364
        let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
1365
            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1366
            vec![col("age")],
1367
        ));
1368
        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1369
        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1370
        let key = vec![];
1371
        let expected: Vec<(WindowSortKey, Vec<Expr>)> =
1372
            vec![(key, vec![max1, max2, min3, sum4])];
1373
        assert_eq!(expected, result);
1374
        Ok(())
1375
    }
1376
1377
    #[test]
1378
    fn test_group_window_expr_by_sort_keys() -> Result<()> {
1379
        let age_asc = expr::Sort::new(col("age"), true, true);
1380
        let name_desc = expr::Sort::new(col("name"), false, true);
1381
        let created_at_desc = expr::Sort::new(col("created_at"), false, true);
1382
        let max1 = Expr::WindowFunction(expr::WindowFunction::new(
1383
            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1384
            vec![col("name")],
1385
        ))
1386
        .order_by(vec![age_asc.clone(), name_desc.clone()])
1387
        .build()
1388
        .unwrap();
1389
        let max2 = Expr::WindowFunction(expr::WindowFunction::new(
1390
            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1391
            vec![col("name")],
1392
        ));
1393
        let min3 = Expr::WindowFunction(expr::WindowFunction::new(
1394
            WindowFunctionDefinition::AggregateUDF(min_udaf()),
1395
            vec![col("name")],
1396
        ))
1397
        .order_by(vec![age_asc.clone(), name_desc.clone()])
1398
        .build()
1399
        .unwrap();
1400
        let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
1401
            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1402
            vec![col("age")],
1403
        ))
1404
        .order_by(vec![
1405
            name_desc.clone(),
1406
            age_asc.clone(),
1407
            created_at_desc.clone(),
1408
        ])
1409
        .build()
1410
        .unwrap();
1411
        // FIXME use as_ref
1412
        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1413
        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1414
1415
        let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
1416
        let key2 = vec![];
1417
        let key3 = vec![
1418
            (name_desc, false),
1419
            (age_asc, false),
1420
            (created_at_desc, false),
1421
        ];
1422
1423
        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![
1424
            (key1, vec![max1, min3]),
1425
            (key2, vec![max2]),
1426
            (key3, vec![sum4]),
1427
        ];
1428
        assert_eq!(expected, result);
1429
        Ok(())
1430
    }
1431
1432
    #[test]
1433
    fn avoid_generate_duplicate_sort_keys() -> Result<()> {
1434
        let asc_or_desc = [true, false];
1435
        let nulls_first_or_last = [true, false];
1436
        let partition_by = &[col("age"), col("name"), col("created_at")];
1437
        for asc_ in asc_or_desc {
1438
            for nulls_first_ in nulls_first_or_last {
1439
                let order_by = &[
1440
                    Sort {
1441
                        expr: col("age"),
1442
                        asc: asc_,
1443
                        nulls_first: nulls_first_,
1444
                    },
1445
                    Sort {
1446
                        expr: col("name"),
1447
                        asc: asc_,
1448
                        nulls_first: nulls_first_,
1449
                    },
1450
                ];
1451
1452
                let expected = vec![
1453
                    (
1454
                        Sort {
1455
                            expr: col("age"),
1456
                            asc: asc_,
1457
                            nulls_first: nulls_first_,
1458
                        },
1459
                        true,
1460
                    ),
1461
                    (
1462
                        Sort {
1463
                            expr: col("name"),
1464
                            asc: asc_,
1465
                            nulls_first: nulls_first_,
1466
                        },
1467
                        true,
1468
                    ),
1469
                    (
1470
                        Sort {
1471
                            expr: col("created_at"),
1472
                            asc: true,
1473
                            nulls_first: false,
1474
                        },
1475
                        true,
1476
                    ),
1477
                ];
1478
                let result = generate_sort_key(partition_by, order_by)?;
1479
                assert_eq!(expected, result);
1480
            }
1481
        }
1482
        Ok(())
1483
    }
1484
1485
    #[test]
1486
    fn test_enumerate_grouping_sets() -> Result<()> {
1487
        let multi_cols = vec![col("col1"), col("col2"), col("col3")];
1488
        let simple_col = col("simple_col");
1489
        let cube = cube(multi_cols.clone());
1490
        let rollup = rollup(multi_cols.clone());
1491
        let grouping_set = grouping_set(vec![multi_cols]);
1492
1493
        // 1. col
1494
        let sets = enumerate_grouping_sets(vec![simple_col.clone()])?;
1495
        let result = format!("[{}]", expr_vec_fmt!(sets));
1496
        assert_eq!("[simple_col]", &result);
1497
1498
        // 2. cube
1499
        let sets = enumerate_grouping_sets(vec![cube.clone()])?;
1500
        let result = format!("[{}]", expr_vec_fmt!(sets));
1501
        assert_eq!("[CUBE (col1, col2, col3)]", &result);
1502
1503
        // 3. rollup
1504
        let sets = enumerate_grouping_sets(vec![rollup.clone()])?;
1505
        let result = format!("[{}]", expr_vec_fmt!(sets));
1506
        assert_eq!("[ROLLUP (col1, col2, col3)]", &result);
1507
1508
        // 4. col + cube
1509
        let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?;
1510
        let result = format!("[{}]", expr_vec_fmt!(sets));
1511
        assert_eq!(
1512
            "[GROUPING SETS (\
1513
            (simple_col), \
1514
            (simple_col, col1), \
1515
            (simple_col, col2), \
1516
            (simple_col, col1, col2), \
1517
            (simple_col, col3), \
1518
            (simple_col, col1, col3), \
1519
            (simple_col, col2, col3), \
1520
            (simple_col, col1, col2, col3))]",
1521
            &result
1522
        );
1523
1524
        // 5. col + rollup
1525
        let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?;
1526
        let result = format!("[{}]", expr_vec_fmt!(sets));
1527
        assert_eq!(
1528
            "[GROUPING SETS (\
1529
            (simple_col), \
1530
            (simple_col, col1), \
1531
            (simple_col, col1, col2), \
1532
            (simple_col, col1, col2, col3))]",
1533
            &result
1534
        );
1535
1536
        // 6. col + grouping_set
1537
        let sets =
1538
            enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?;
1539
        let result = format!("[{}]", expr_vec_fmt!(sets));
1540
        assert_eq!(
1541
            "[GROUPING SETS (\
1542
            (simple_col, col1, col2, col3))]",
1543
            &result
1544
        );
1545
1546
        // 7. col + grouping_set + rollup
1547
        let sets = enumerate_grouping_sets(vec![
1548
            simple_col.clone(),
1549
            grouping_set,
1550
            rollup.clone(),
1551
        ])?;
1552
        let result = format!("[{}]", expr_vec_fmt!(sets));
1553
        assert_eq!(
1554
            "[GROUPING SETS (\
1555
            (simple_col, col1, col2, col3), \
1556
            (simple_col, col1, col2, col3, col1), \
1557
            (simple_col, col1, col2, col3, col1, col2), \
1558
            (simple_col, col1, col2, col3, col1, col2, col3))]",
1559
            &result
1560
        );
1561
1562
        // 8. col + cube + rollup
1563
        let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?;
1564
        let result = format!("[{}]", expr_vec_fmt!(sets));
1565
        assert_eq!(
1566
            "[GROUPING SETS (\
1567
            (simple_col), \
1568
            (simple_col, col1), \
1569
            (simple_col, col1, col2), \
1570
            (simple_col, col1, col2, col3), \
1571
            (simple_col, col1), \
1572
            (simple_col, col1, col1), \
1573
            (simple_col, col1, col1, col2), \
1574
            (simple_col, col1, col1, col2, col3), \
1575
            (simple_col, col2), \
1576
            (simple_col, col2, col1), \
1577
            (simple_col, col2, col1, col2), \
1578
            (simple_col, col2, col1, col2, col3), \
1579
            (simple_col, col1, col2), \
1580
            (simple_col, col1, col2, col1), \
1581
            (simple_col, col1, col2, col1, col2), \
1582
            (simple_col, col1, col2, col1, col2, col3), \
1583
            (simple_col, col3), \
1584
            (simple_col, col3, col1), \
1585
            (simple_col, col3, col1, col2), \
1586
            (simple_col, col3, col1, col2, col3), \
1587
            (simple_col, col1, col3), \
1588
            (simple_col, col1, col3, col1), \
1589
            (simple_col, col1, col3, col1, col2), \
1590
            (simple_col, col1, col3, col1, col2, col3), \
1591
            (simple_col, col2, col3), \
1592
            (simple_col, col2, col3, col1), \
1593
            (simple_col, col2, col3, col1, col2), \
1594
            (simple_col, col2, col3, col1, col2, col3), \
1595
            (simple_col, col1, col2, col3), \
1596
            (simple_col, col1, col2, col3, col1), \
1597
            (simple_col, col1, col2, col3, col1, col2), \
1598
            (simple_col, col1, col2, col3, col1, col2, col3))]",
1599
            &result
1600
        );
1601
1602
        Ok(())
1603
    }
1604
    #[test]
    fn test_split_conjunction() {
        // A single (non-AND) expression splits into just itself.
        let single = col("a");
        assert_eq!(split_conjunction(&single), vec![&single]);
    }
1610
1611
    #[test]
    fn test_split_conjunction_two() {
        // `a = 5 AND b` splits into its two conjuncts, in order.
        let left = col("a").eq(lit(5));
        let right = col("b");
        let conjoined = left.clone().and(right.clone());

        assert_eq!(split_conjunction(&conjoined), vec![&left, &right]);
    }
1620
1621
    #[test]
    fn test_split_conjunction_alias() {
        // Aliases on individual conjuncts are stripped during the split.
        let input = col("a").eq(lit(5)).and(col("b").alias("the_alias"));
        let first = col("a").eq(lit(5));
        let second = col("b"); // alias removed

        assert_eq!(split_conjunction(&input), vec![&first, &second]);
    }
1630
1631
    #[test]
    fn test_split_conjunction_or() {
        // OR is not a conjunction, so the expression stays whole.
        let disjunct = col("a").eq(lit(5)).or(col("b"));
        let parts = split_conjunction(&disjunct);
        assert_eq!(parts, vec![&disjunct]);
    }
1637
1638
    #[test]
    fn test_split_binary_owned() {
        // A lone column expression is returned unchanged (owned).
        let single = col("a");
        let parts = split_binary_owned(single.clone(), Operator::And);
        assert_eq!(parts, vec![single]);
    }
1643
1644
    #[test]
    fn test_split_binary_owned_two() {
        // `a = 5 AND b` splits on AND into both operands, owned.
        let combined = col("a").eq(lit(5)).and(col("b"));
        let expected = vec![col("a").eq(lit(5)), col("b")];
        assert_eq!(split_binary_owned(combined, Operator::And), expected);
    }
1651
1652
    #[test]
    fn test_split_binary_owned_different_op() {
        // The expression is joined by OR; asking to split on AND
        // must leave it intact.
        let or_expr = col("a").eq(lit(5)).or(col("b"));
        let parts = split_binary_owned(or_expr.clone(), Operator::And);
        assert_eq!(parts, vec![or_expr]);
    }
1661
1662
    #[test]
    fn test_split_conjunction_owned() {
        // A single expression yields a one-element vector of itself.
        let only = col("a");
        assert_eq!(split_conjunction_owned(only.clone()), vec![only]);
    }
1667
1668
    #[test]
    fn test_split_conjunction_owned_two() {
        // Both conjuncts of `a = 5 AND b` come back, owned, in order.
        let input = col("a").eq(lit(5)).and(col("b"));
        let expected = vec![col("a").eq(lit(5)), col("b")];
        assert_eq!(split_conjunction_owned(input), expected);
    }
1675
1676
    #[test]
    fn test_split_conjunction_owned_alias() {
        // The alias on `b` is stripped when the conjunction is split.
        let input = col("a").eq(lit(5)).and(col("b").alias("the_alias"));
        let expected = vec![
            col("a").eq(lit(5)),
            // alias removed from b
            col("b"),
        ];
        assert_eq!(split_conjunction_owned(input), expected);
    }
1687
1688
    #[test]
    fn test_conjunction_empty() {
        // No input expressions -> no conjunction.
        assert!(conjunction(vec![]).is_none());
    }
1692
1693
    #[test]
    fn test_conjunction() {
        // Combining `[a, b, c]` folds left: `(a AND b) AND c`.
        let combined = conjunction(vec![col("a"), col("b"), col("c")]);

        let left_assoc = col("a").and(col("b")).and(col("c"));
        assert_eq!(combined, Some(left_assoc));

        // Right-associated nesting `a AND (b AND c)` is a different tree.
        let right_assoc = col("a").and(col("b").and(col("c")));
        assert_ne!(combined, Some(right_assoc));
    }
1704
1705
    #[test]
    fn test_disjunction_empty() {
        // No input expressions -> no disjunction.
        assert!(disjunction(vec![]).is_none());
    }
1709
1710
    #[test]
    fn test_disjunction() {
        // Combining `[a, b, c]` folds left: `(a OR b) OR c`.
        let combined = disjunction(vec![col("a"), col("b"), col("c")]);

        let left_assoc = col("a").or(col("b")).or(col("c"));
        assert_eq!(combined, Some(left_assoc));

        // Right-associated nesting `a OR (b OR c)` is a different tree.
        let right_assoc = col("a").or(col("b").or(col("c")));
        assert_ne!(combined, Some(right_assoc));
    }
1721
1722
    #[test]
    fn test_split_conjunction_owned_or() {
        // An OR expression is not split by the conjunction splitter.
        let or_expr = col("a").eq(lit(5)).or(col("b"));
        assert_eq!(split_conjunction_owned(or_expr.clone()), vec![or_expr]);
    }
1727
1728
    #[test]
    fn test_collect_expr() -> Result<()> {
        // Collecting columns from equal expressions twice must
        // deduplicate: the accumulator is a set keyed by column.
        let mut columns: HashSet<Column> = HashSet::new();
        let cast_expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64));

        expr_to_columns(&cast_expr, &mut columns)?;
        expr_to_columns(&cast_expr, &mut columns)?;

        assert_eq!(columns.len(), 1);
        assert!(columns.contains(&Column::from_name("a")));
        Ok(())
    }
1743
}