Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/expr/src/tree_node.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Tree node implementation for logical expr
19
20
use crate::expr::{
21
    AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList,
22
    InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction,
23
};
24
use crate::{Expr, ExprFunctionExt};
25
26
use datafusion_common::tree_node::{
27
    Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion,
28
};
29
use datafusion_common::{map_until_stop_and_collect, Result};
30
31
impl TreeNode for Expr {
32
0
    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
33
0
        &'n self,
34
0
        f: F,
35
0
    ) -> Result<TreeNodeRecursion> {
36
0
        let children = match self {
37
0
            Expr::Alias(Alias{expr,..})
38
0
            | Expr::Unnest(Unnest{expr})
39
0
            | Expr::Not(expr)
40
0
            | Expr::IsNotNull(expr)
41
0
            | Expr::IsTrue(expr)
42
0
            | Expr::IsFalse(expr)
43
0
            | Expr::IsUnknown(expr)
44
0
            | Expr::IsNotTrue(expr)
45
0
            | Expr::IsNotFalse(expr)
46
0
            | Expr::IsNotUnknown(expr)
47
0
            | Expr::IsNull(expr)
48
0
            | Expr::Negative(expr)
49
0
            | Expr::Cast(Cast { expr, .. })
50
0
            | Expr::TryCast(TryCast { expr, .. })
51
0
            | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref()],
52
0
            Expr::GroupingSet(GroupingSet::Rollup(exprs))
53
0
            | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().collect(),
54
0
            Expr::ScalarFunction (ScalarFunction{ args, .. } )  => {
55
0
                args.iter().collect()
56
            }
57
0
            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
58
0
                lists_of_exprs.iter().flatten().collect()
59
            }
60
            Expr::Column(_)
61
            // Treat OuterReferenceColumn as a leaf expression
62
            | Expr::OuterReferenceColumn(_, _)
63
            | Expr::ScalarVariable(_, _)
64
            | Expr::Literal(_)
65
            | Expr::Exists {..}
66
            | Expr::ScalarSubquery(_)
67
            | Expr::Wildcard {..}
68
0
            | Expr::Placeholder (_) => vec![],
69
0
            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
70
0
                vec![left.as_ref(), right.as_ref()]
71
            }
72
0
            Expr::Like(Like { expr, pattern, .. })
73
0
            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
74
0
                vec![expr.as_ref(), pattern.as_ref()]
75
            }
76
            Expr::Between(Between {
77
0
                expr, low, high, ..
78
0
            }) => vec![expr.as_ref(), low.as_ref(), high.as_ref()],
79
0
            Expr::Case(case) => {
80
0
                let mut expr_vec = vec![];
81
0
                if let Some(expr) = case.expr.as_ref() {
82
0
                    expr_vec.push(expr.as_ref());
83
0
                };
84
0
                for (when, then) in case.when_then_expr.iter() {
85
0
                    expr_vec.push(when.as_ref());
86
0
                    expr_vec.push(then.as_ref());
87
0
                }
88
0
                if let Some(else_expr) = case.else_expr.as_ref() {
89
0
                    expr_vec.push(else_expr.as_ref());
90
0
                }
91
0
                expr_vec
92
            }
93
0
            Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. })
94
0
             => {
95
0
                let mut expr_vec = args.iter().collect::<Vec<_>>();
96
0
                if let Some(f) = filter {
97
0
                    expr_vec.push(f.as_ref());
98
0
                }
99
0
                if let Some(order_by) = order_by {
100
0
                    expr_vec.extend(order_by.iter().map(|sort| &sort.expr));
101
0
                }
102
0
                expr_vec
103
            }
104
            Expr::WindowFunction(WindowFunction {
105
0
                args,
106
0
                partition_by,
107
0
                order_by,
108
0
                ..
109
0
            }) => {
110
0
                let mut expr_vec = args.iter().collect::<Vec<_>>();
111
0
                expr_vec.extend(partition_by);
112
0
                expr_vec.extend(order_by.iter().map(|sort| &sort.expr));
113
0
                expr_vec
114
            }
115
0
            Expr::InList(InList { expr, list, .. }) => {
116
0
                let mut expr_vec = vec![expr.as_ref()];
117
0
                expr_vec.extend(list);
118
0
                expr_vec
119
            }
120
        };
121
122
0
        children.into_iter().apply_until_stop(f)
123
0
    }
124
125
0
    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
126
0
        self,
127
0
        mut f: F,
128
0
    ) -> Result<Transformed<Self>> {
129
0
        Ok(match self {
130
            Expr::Column(_)
131
            | Expr::Wildcard { .. }
132
            | Expr::Placeholder(Placeholder { .. })
133
            | Expr::OuterReferenceColumn(_, _)
134
            | Expr::Exists { .. }
135
            | Expr::ScalarSubquery(_)
136
            | Expr::ScalarVariable(_, _)
137
0
            | Expr::Literal(_) => Transformed::no(self),
138
0
            Expr::Unnest(Unnest { expr, .. }) => transform_box(expr, &mut f)?
139
0
                .update_data(|be| Expr::Unnest(Unnest::new_boxed(be))),
140
            Expr::Alias(Alias {
141
0
                expr,
142
0
                relation,
143
0
                name,
144
0
            }) => f(*expr)?.update_data(|e| Expr::Alias(Alias::new(e, relation, name))),
145
            Expr::InSubquery(InSubquery {
146
0
                expr,
147
0
                subquery,
148
0
                negated,
149
0
            }) => transform_box(expr, &mut f)?.update_data(|be| {
150
0
                Expr::InSubquery(InSubquery::new(be, subquery, negated))
151
0
            }),
152
0
            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
153
0
                map_until_stop_and_collect!(
154
0
                    transform_box(left, &mut f),
155
0
                    right,
156
0
                    transform_box(right, &mut f)
157
0
                )?
158
0
                .update_data(|(new_left, new_right)| {
159
0
                    Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
160
0
                })
161
            }
162
            Expr::Like(Like {
163
0
                negated,
164
0
                expr,
165
0
                pattern,
166
0
                escape_char,
167
0
                case_insensitive,
168
0
            }) => map_until_stop_and_collect!(
169
0
                transform_box(expr, &mut f),
170
0
                pattern,
171
0
                transform_box(pattern, &mut f)
172
0
            )?
173
0
            .update_data(|(new_expr, new_pattern)| {
174
0
                Expr::Like(Like::new(
175
0
                    negated,
176
0
                    new_expr,
177
0
                    new_pattern,
178
0
                    escape_char,
179
0
                    case_insensitive,
180
0
                ))
181
0
            }),
182
            Expr::SimilarTo(Like {
183
0
                negated,
184
0
                expr,
185
0
                pattern,
186
0
                escape_char,
187
0
                case_insensitive,
188
0
            }) => map_until_stop_and_collect!(
189
0
                transform_box(expr, &mut f),
190
0
                pattern,
191
0
                transform_box(pattern, &mut f)
192
0
            )?
193
0
            .update_data(|(new_expr, new_pattern)| {
194
0
                Expr::SimilarTo(Like::new(
195
0
                    negated,
196
0
                    new_expr,
197
0
                    new_pattern,
198
0
                    escape_char,
199
0
                    case_insensitive,
200
0
                ))
201
0
            }),
202
0
            Expr::Not(expr) => transform_box(expr, &mut f)?.update_data(Expr::Not),
203
0
            Expr::IsNotNull(expr) => {
204
0
                transform_box(expr, &mut f)?.update_data(Expr::IsNotNull)
205
            }
206
0
            Expr::IsNull(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsNull),
207
0
            Expr::IsTrue(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsTrue),
208
0
            Expr::IsFalse(expr) => {
209
0
                transform_box(expr, &mut f)?.update_data(Expr::IsFalse)
210
            }
211
0
            Expr::IsUnknown(expr) => {
212
0
                transform_box(expr, &mut f)?.update_data(Expr::IsUnknown)
213
            }
214
0
            Expr::IsNotTrue(expr) => {
215
0
                transform_box(expr, &mut f)?.update_data(Expr::IsNotTrue)
216
            }
217
0
            Expr::IsNotFalse(expr) => {
218
0
                transform_box(expr, &mut f)?.update_data(Expr::IsNotFalse)
219
            }
220
0
            Expr::IsNotUnknown(expr) => {
221
0
                transform_box(expr, &mut f)?.update_data(Expr::IsNotUnknown)
222
            }
223
0
            Expr::Negative(expr) => {
224
0
                transform_box(expr, &mut f)?.update_data(Expr::Negative)
225
            }
226
            Expr::Between(Between {
227
0
                expr,
228
0
                negated,
229
0
                low,
230
0
                high,
231
0
            }) => map_until_stop_and_collect!(
232
0
                transform_box(expr, &mut f),
233
0
                low,
234
0
                transform_box(low, &mut f),
235
0
                high,
236
0
                transform_box(high, &mut f)
237
0
            )?
238
0
            .update_data(|(new_expr, new_low, new_high)| {
239
0
                Expr::Between(Between::new(new_expr, negated, new_low, new_high))
240
0
            }),
241
            Expr::Case(Case {
242
0
                expr,
243
0
                when_then_expr,
244
0
                else_expr,
245
0
            }) => map_until_stop_and_collect!(
246
0
                transform_option_box(expr, &mut f),
247
0
                when_then_expr,
248
0
                when_then_expr
249
0
                    .into_iter()
250
0
                    .map_until_stop_and_collect(|(when, then)| {
251
0
                        map_until_stop_and_collect!(
252
0
                            transform_box(when, &mut f),
253
0
                            then,
254
0
                            transform_box(then, &mut f)
255
0
                        )
256
0
                    }),
257
0
                else_expr,
258
0
                transform_option_box(else_expr, &mut f)
259
0
            )?
260
0
            .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
261
0
                Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
262
0
            }),
263
0
            Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut f)?
264
0
                .update_data(|be| Expr::Cast(Cast::new(be, data_type))),
265
0
            Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)?
266
0
                .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
267
0
            Expr::ScalarFunction(ScalarFunction { func, args }) => {
268
0
                transform_vec(args, &mut f)?.map_data(|new_args| {
269
0
                    Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
270
0
                        func, new_args,
271
0
                    )))
272
0
                })?
273
            }
274
            Expr::WindowFunction(WindowFunction {
275
0
                args,
276
0
                fun,
277
0
                partition_by,
278
0
                order_by,
279
0
                window_frame,
280
0
                null_treatment,
281
0
            }) => map_until_stop_and_collect!(
282
0
                transform_vec(args, &mut f),
283
0
                partition_by,
284
0
                transform_vec(partition_by, &mut f),
285
0
                order_by,
286
0
                transform_sort_vec(order_by, &mut f)
287
0
            )?
288
0
            .update_data(|(new_args, new_partition_by, new_order_by)| {
289
0
                Expr::WindowFunction(WindowFunction::new(fun, new_args))
290
0
                    .partition_by(new_partition_by)
291
0
                    .order_by(new_order_by)
292
0
                    .window_frame(window_frame)
293
0
                    .null_treatment(null_treatment)
294
0
                    .build()
295
0
                    .unwrap()
296
0
            }),
297
            Expr::AggregateFunction(AggregateFunction {
298
0
                args,
299
0
                func,
300
0
                distinct,
301
0
                filter,
302
0
                order_by,
303
0
                null_treatment,
304
0
            }) => map_until_stop_and_collect!(
305
0
                transform_vec(args, &mut f),
306
0
                filter,
307
0
                transform_option_box(filter, &mut f),
308
0
                order_by,
309
0
                transform_sort_option_vec(order_by, &mut f)
310
0
            )?
311
0
            .map_data(|(new_args, new_filter, new_order_by)| {
312
0
                Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
313
0
                    func,
314
0
                    new_args,
315
0
                    distinct,
316
0
                    new_filter,
317
0
                    new_order_by,
318
0
                    null_treatment,
319
0
                )))
320
0
            })?,
321
0
            Expr::GroupingSet(grouping_set) => match grouping_set {
322
0
                GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)?
323
0
                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
324
0
                GroupingSet::Cube(exprs) => transform_vec(exprs, &mut f)?
325
0
                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))),
326
0
                GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs
327
0
                    .into_iter()
328
0
                    .map_until_stop_and_collect(|exprs| transform_vec(exprs, &mut f))?
329
0
                    .update_data(|new_lists_of_exprs| {
330
0
                        Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs))
331
0
                    }),
332
            },
333
            Expr::InList(InList {
334
0
                expr,
335
0
                list,
336
0
                negated,
337
0
            }) => map_until_stop_and_collect!(
338
0
                transform_box(expr, &mut f),
339
0
                list,
340
0
                transform_vec(list, &mut f)
341
0
            )?
342
0
            .update_data(|(new_expr, new_list)| {
343
0
                Expr::InList(InList::new(new_expr, new_list, negated))
344
0
            }),
345
        })
346
0
    }
347
}
348
349
0
/// Unboxes `be`, applies `f` to the expression, and boxes the result again.
///
/// The `Transformed` flags reported by `f` are kept as-is; any error from `f`
/// is returned unchanged.
fn transform_box<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
    be: Box<Expr>,
    f: &mut F,
) -> Result<Transformed<Box<Expr>>> {
    let inner = *be;
    f(inner).map(|outcome| outcome.update_data(Box::new))
}
355
356
0
fn transform_option_box<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
357
0
    obe: Option<Box<Expr>>,
358
0
    f: &mut F,
359
0
) -> Result<Transformed<Option<Box<Expr>>>> {
360
0
    obe.map_or(Ok(Transformed::no(None)), |be| {
361
0
        Ok(transform_box(be, f)?.update_data(Some))
362
0
    })
363
0
}
364
365
/// &mut transform a Option<`Vec` of `Expr`s>
366
0
pub fn transform_option_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
367
0
    ove: Option<Vec<Expr>>,
368
0
    f: &mut F,
369
0
) -> Result<Transformed<Option<Vec<Expr>>>> {
370
0
    ove.map_or(Ok(Transformed::no(None)), |ve| {
371
0
        Ok(transform_vec(ve, f)?.update_data(Some))
372
0
    })
373
0
}
374
375
/// &mut transform a `Vec` of `Expr`s
376
0
fn transform_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
377
0
    ve: Vec<Expr>,
378
0
    f: &mut F,
379
0
) -> Result<Transformed<Vec<Expr>>> {
380
0
    ve.into_iter().map_until_stop_and_collect(f)
381
0
}
382
383
0
pub fn transform_sort_option_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
384
0
    sorts_option: Option<Vec<Sort>>,
385
0
    f: &mut F,
386
0
) -> Result<Transformed<Option<Vec<Sort>>>> {
387
0
    sorts_option.map_or(Ok(Transformed::no(None)), |sorts| {
388
0
        Ok(transform_sort_vec(sorts, f)?.update_data(Some))
389
0
    })
390
0
}
391
392
0
pub fn transform_sort_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
393
0
    sorts: Vec<Sort>,
394
0
    mut f: &mut F,
395
0
) -> Result<Transformed<Vec<Sort>>> {
396
0
    Ok(sorts
397
0
        .iter()
398
0
        .map(|sort| sort.expr.clone())
399
0
        .map_until_stop_and_collect(&mut f)?
400
0
        .update_data(|transformed_exprs| {
401
0
            replace_sort_expressions(sorts, transformed_exprs)
402
0
        }))
403
0
}
404
405
0
pub fn replace_sort_expressions(sorts: Vec<Sort>, new_expr: Vec<Expr>) -> Vec<Sort> {
406
0
    assert_eq!(sorts.len(), new_expr.len());
407
0
    sorts
408
0
        .into_iter()
409
0
        .zip(new_expr)
410
0
        .map(|(sort, expr)| replace_sort_expression(sort, expr))
411
0
        .collect()
412
0
}
413
414
0
/// Returns `sort` with its expression replaced by `new_expr`; all other
/// fields are kept as they were.
pub fn replace_sort_expression(mut sort: Sort, new_expr: Expr) -> Sort {
    sort.expr = new_expr;
    sort
}