Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/expr/src/expr_rewriter/mod.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Expression rewriter
19
20
use std::collections::HashMap;
21
use std::collections::HashSet;
22
use std::fmt::Debug;
23
use std::sync::Arc;
24
25
use crate::expr::{Alias, Sort, Unnest};
26
use crate::logical_plan::Projection;
27
use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
28
29
use datafusion_common::config::ConfigOptions;
30
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
31
use datafusion_common::TableReference;
32
use datafusion_common::{Column, DFSchema, Result};
33
34
mod order_by;
35
pub use order_by::rewrite_sort_cols_by_aggs;
36
37
/// Trait for rewriting [`Expr`]s into function calls.
37
///
38
/// This trait is used with `FunctionRegistry::register_function_rewrite` to
39
/// evaluate `Expr`s using functions that may not be built in to DataFusion.
40
///
41
/// For example, concatenating arrays `a || b` is represented as
42
/// `Operator::ArrowAt`, but can be implemented by calling a function
43
/// `array_concat` from the `functions-nested` crate.
44
// This trait is not used inside DataFusion itself, but it is still helpful for downstream projects, so do not remove it.
45
pub trait FunctionRewrite: Debug {
46
    /// Return a human readable name for this rewrite
47
    fn name(&self) -> &str;
48
49
    /// Potentially rewrite `expr` to some other expression
50
    ///
51
    /// Note that recursion is handled by the caller -- this method should only
52
    /// handle `expr`, not recurse to its children.
    ///
    /// The result is wrapped in [`Transformed`]; `schema` and `config` are
    /// supplied by the caller for the implementation to consult.
53
    fn rewrite(
54
        &self,
55
        expr: Expr,
56
        schema: &DFSchema,
57
        config: &ConfigOptions,
58
    ) -> Result<Transformed<Expr>>;
59
}
61
62
/// Recursively call `LogicalPlanBuilder::normalize` on all [`Column`] expressions
63
/// in the `expr` expression tree.
///
/// Every `Expr::Column` node is replaced by the column returned from
/// `LogicalPlanBuilder::normalize(plan, c)`; all other nodes are left as they are.
///
/// # Errors
///
/// Propagates any error returned by `LogicalPlanBuilder::normalize`.
64
0
pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
65
0
    expr.transform(|expr| {
66
        Ok({
67
0
            if let Expr::Column(c) = expr {
68
0
                let col = LogicalPlanBuilder::normalize(plan, c)?;
69
0
                Transformed::yes(Expr::Column(col))
70
            } else {
71
0
                Transformed::no(expr)
72
            }
73
        })
74
0
    })
75
0
    .data()
76
0
}
77
78
/// See [`Column::normalize_with_schemas_and_ambiguity_check`] for usage
///
/// Rewrites every `Expr::Column` in `expr` with the result of
/// `Column::normalize_with_schemas_and_ambiguity_check(schemas, using_columns)`.
/// When `expr` itself is an `Expr::Unnest`, the wrapped expression is
/// normalized by a recursive call and re-wrapped in a new `Unnest`.
///
/// # Errors
///
/// Propagates any error returned while normalizing a column.
79
0
pub fn normalize_col_with_schemas_and_ambiguity_check(
80
0
    expr: Expr,
81
0
    schemas: &[&[&DFSchema]],
82
0
    using_columns: &[HashSet<Column>],
83
0
) -> Result<Expr> {
84
    // Normalize the column inside a top-level Unnest, then rebuild the Unnest around it
85
0
    if let Expr::Unnest(Unnest { expr }) = expr {
86
0
        let e = normalize_col_with_schemas_and_ambiguity_check(
87
0
            expr.as_ref().clone(),
88
0
            schemas,
89
0
            using_columns,
90
0
        )?;
91
0
        return Ok(Expr::Unnest(Unnest { expr: Box::new(e) }));
92
0
    }
93
0
94
0
    expr.transform(|expr| {
95
        Ok({
96
0
            if let Expr::Column(c) = expr {
97
0
                let col =
98
0
                    c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?;
99
0
                Transformed::yes(Expr::Column(col))
100
            } else {
101
0
                Transformed::no(expr)
102
            }
103
        })
104
0
    })
105
0
    .data()
106
0
}
107
108
/// Recursively normalize all [`Column`] expressions in a list of expression trees
///
/// Each item is converted into an [`Expr`] and passed to [`normalize_col`]
/// together with `plan`; the results are collected in input order.
///
/// # Errors
///
/// Returns the error of the first expression that fails to normalize.
109
0
pub fn normalize_cols(
110
0
    exprs: impl IntoIterator<Item = impl Into<Expr>>,
111
0
    plan: &LogicalPlan,
112
0
) -> Result<Vec<Expr>> {
113
0
    exprs
114
0
        .into_iter()
115
0
        .map(|e| normalize_col(e.into(), plan))
116
0
        .collect()
117
0
}
118
119
0
/// Normalize the [`Column`] expressions inside a list of [`Sort`]s.
///
/// Each item is converted into a [`Sort`]; its expression is passed through
/// [`normalize_col`] and a new `Sort` is built from the normalized expression
/// with the original `asc` and `nulls_first` values.
///
/// # Errors
///
/// Returns the error of the first sort expression that fails to normalize.
pub fn normalize_sorts(
120
0
    sorts: impl IntoIterator<Item = impl Into<Sort>>,
121
0
    plan: &LogicalPlan,
122
0
) -> Result<Vec<Sort>> {
123
0
    sorts
124
0
        .into_iter()
125
0
        .map(|e| {
126
0
            let sort = e.into();
127
0
            normalize_col(sort.expr, plan)
128
0
                .map(|expr| Sort::new(expr, sort.asc, sort.nulls_first))
129
0
        })
130
0
        .collect()
131
0
}
132
133
/// Recursively replace all [`Column`] expressions in a given expression tree with
134
/// `Column` expressions provided by the hash map argument.
///
/// A column that is a key of `replace_map` is replaced by an owned copy of the
/// mapped column; columns without an entry, and all non-column nodes, are
/// left unchanged.
135
0
pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
136
0
    expr.transform(|expr| {
137
        Ok({
138
0
            if let Expr::Column(c) = &expr {
139
0
                match replace_map.get(c) {
140
0
                    Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())),
141
0
                    None => Transformed::no(expr),
142
                }
143
            } else {
144
0
                Transformed::no(expr)
145
            }
146
        })
147
0
    })
148
0
    .data()
149
0
}
150
151
/// Recursively 'unnormalize' (remove all qualifiers) from an
152
/// expression tree.
153
///
154
/// For example, if there were expressions like `foo.bar` this would
155
/// rewrite it to just `bar`.
///
/// Every `Expr::Column` is rebuilt with `relation: None` and the same `name`.
/// The rewriting closure never returns an error, which is why the result is
/// unwrapped with `expect`.
156
0
pub fn unnormalize_col(expr: Expr) -> Expr {
157
0
    expr.transform(|expr| {
158
        Ok({
159
0
            if let Expr::Column(c) = expr {
160
0
                let col = Column {
161
0
                    relation: None,
162
0
                    name: c.name,
163
0
                };
164
0
                Transformed::yes(Expr::Column(col))
165
            } else {
166
0
                Transformed::no(expr)
167
            }
168
        })
169
0
    })
170
0
    .data()
171
0
    .expect("Unnormalize is infallible")
172
0
}
173
174
/// Create a Column from the Scalar Expr
///
/// The returned [`Column`] is always qualified with `subqry_alias`. Its name is:
/// - the alias name, when `scalar_expr` is an `Expr::Alias`;
/// - the column name (dropping any existing relation), when it is an `Expr::Column`;
/// - the string form of `scalar_expr.schema_name()` otherwise.
///
/// Every arm shown here returns `Ok`.
175
0
pub fn create_col_from_scalar_expr(
176
0
    scalar_expr: &Expr,
177
0
    subqry_alias: String,
178
0
) -> Result<Column> {
179
0
    match scalar_expr {
180
0
        Expr::Alias(Alias { name, .. }) => Ok(Column::new(
181
0
            Some::<TableReference>(subqry_alias.into()),
182
0
            name,
183
0
        )),
184
0
        Expr::Column(Column { relation: _, name }) => Ok(Column::new(
185
0
            Some::<TableReference>(subqry_alias.into()),
186
0
            name,
187
0
        )),
188
        _ => {
189
0
            let scalar_column = scalar_expr.schema_name().to_string();
190
0
            Ok(Column::new(
191
0
                Some::<TableReference>(subqry_alias.into()),
192
0
                scalar_column,
193
0
            ))
194
        }
195
    }
196
0
}
197
198
/// Recursively un-normalize all [`Column`] expressions in a list of expression trees
///
/// Applies [`unnormalize_col`] to every expression and collects the results
/// in input order.
199
#[inline]
200
0
pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
201
0
    exprs.into_iter().map(unnormalize_col).collect()
202
0
}
203
204
/// Recursively replace every [`Expr::OuterReferenceColumn`] in the expression
205
/// tree with an [`Expr::Column`] holding the column it wraps.
///
/// All other nodes are left unchanged. The rewriting closure never returns an
/// error, which is why the result is unwrapped with `expect`.
206
0
pub fn strip_outer_reference(expr: Expr) -> Expr {
207
0
    expr.transform(|expr| {
208
        Ok({
209
0
            if let Expr::OuterReferenceColumn(_, col) = expr {
210
0
                Transformed::yes(Expr::Column(col))
211
            } else {
212
0
                Transformed::no(expr)
213
            }
214
        })
215
0
    })
216
0
    .data()
217
0
    .expect("strip_outer_reference is infallible")
218
0
}
219
220
/// Returns plan with expressions coerced to types compatible with
221
/// schema types
///
/// - For a `Projection`, its own expressions are coerced against the input's
///   schema and a new `Projection` is built over the same input.
/// - For any other plan, one expression per field of the plan's schema is
///   built and coerced. A new `Projection` is placed on top of the plan only
///   if at least one coerced expression is no longer a plain column;
///   otherwise the plan is returned as is.
///
/// # Errors
///
/// Propagates errors from coercing the expressions or from `Projection::try_new`.
222
0
pub fn coerce_plan_expr_for_schema(
223
0
    plan: LogicalPlan,
224
0
    schema: &DFSchema,
225
0
) -> Result<LogicalPlan> {
226
0
    match plan {
227
        // special case Projection to avoid adding multiple projections
228
0
        LogicalPlan::Projection(Projection { expr, input, .. }) => {
229
0
            let new_exprs = coerce_exprs_for_schema(expr, input.schema(), schema)?;
230
0
            let projection = Projection::try_new(new_exprs, input)?;
231
0
            Ok(LogicalPlan::Projection(projection))
232
        }
233
        _ => {
234
0
            let exprs: Vec<Expr> = plan.schema().iter().map(Expr::from).collect();
235
0
            let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?;
            // A projection is only needed when some expression stopped being a bare column
236
0
            let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none());
237
0
            if add_project {
238
0
                let projection = Projection::try_new(new_exprs, Arc::new(plan))?;
239
0
                Ok(LogicalPlan::Projection(projection))
240
            } else {
241
0
                Ok(plan)
242
            }
243
        }
244
    }
245
0
}
246
247
0
/// Casts each expression to the data type of the `dst_schema` field at the
/// same position.
///
/// The expression's current type is computed against `src_schema`. When it
/// already equals the destination type the expression is returned untouched.
/// Otherwise:
/// - an `Expr::Alias` has its inner expression cast and is re-aliased with the
///   same `name` (only `name` is re-applied; the other `Alias` fields matched
///   by `..` are not carried over);
/// - an `Expr::Wildcard` is returned unchanged;
/// - any other expression is cast directly.
///
/// NOTE(review): expressions are matched to `dst_schema` fields purely by
/// index, so this presumably requires `dst_schema` to have at least as many
/// fields as `exprs` -- confirm what `DFSchema::field` does out of range.
fn coerce_exprs_for_schema(
248
0
    exprs: Vec<Expr>,
249
0
    src_schema: &DFSchema,
250
0
    dst_schema: &DFSchema,
251
0
) -> Result<Vec<Expr>> {
252
0
    exprs
253
0
        .into_iter()
254
0
        .enumerate()
255
0
        .map(|(idx, expr)| {
256
0
            let new_type = dst_schema.field(idx).data_type();
257
0
            if new_type != &expr.get_type(src_schema)? {
258
0
                match expr {
259
0
                    Expr::Alias(Alias { expr, name, .. }) => {
260
0
                        Ok(expr.cast_to(new_type, src_schema)?.alias(name))
261
                    }
262
0
                    Expr::Wildcard { .. } => Ok(expr),
263
0
                    _ => expr.cast_to(new_type, src_schema),
264
                }
265
            } else {
266
0
                Ok(expr)
267
            }
268
0
        })
269
0
        .collect::<Result<_>>()
270
0
}
271
272
/// Recursively un-alias an expression
///
/// Strips `Expr::Alias` wrappers from the top of `expr` until a non-alias
/// expression is reached, and returns that expression. Aliases nested below a
/// non-alias node are not visited.
273
#[inline]
274
0
pub fn unalias(expr: Expr) -> Expr {
275
0
    match expr {
276
0
        Expr::Alias(Alias { expr, .. }) => unalias(*expr),
277
0
        _ => expr,
278
    }
279
0
}
280
281
/// Handles ensuring the name of rewritten expressions is not changed.
282
///
283
/// This is important when optimizing plans to ensure the output
284
/// schema of plan nodes doesn't change after optimization.
285
/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
286
/// expression should be preserved: `3 as "1 + 2"`
287
///
288
/// See <https://github.com/apache/datafusion/issues/3555> for details
289
pub struct NamePreserver {
290
    // When false, `save` records nothing and `restore` returns expressions as is
    use_alias: bool,
291
}
292
293
/// If the qualified name of an expression is remembered, it will be preserved
294
/// when rewriting the expression
///
/// Values are produced by `NamePreserver::save` and consumed by
/// `SavedName::restore`.
295
pub enum SavedName {
296
    /// Saved qualified name to be preserved
297
    Saved {
298
        relation: Option<TableReference>,
299
        name: String,
300
    },
301
    /// Name is not preserved
302
    None,
303
}
304
305
impl NamePreserver {
306
    /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
    ///
    /// Aliases are used for every plan except `LogicalPlan::Filter` and
    /// `LogicalPlan::Join`.
307
0
    pub fn new(plan: &LogicalPlan) -> Self {
308
        Self {
309
            // The schema of Filter and Join nodes comes from their inputs rather than their output expressions,
310
            // so there is no need to use aliases to preserve expression names.
311
0
            use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)),
312
        }
313
0
    }
314
315
    /// Create a new NamePreserver for rewriting the `expr`s in `Projection`
316
    ///
317
    /// This will use aliases
318
0
    pub fn new_for_projection() -> Self {
319
0
        Self { use_alias: true }
320
0
    }
321
322
0
    /// Record the qualified name of `expr` so it can be re-applied after a rewrite.
    ///
    /// Returns `SavedName::Saved` holding `expr.qualified_name()` when aliases
    /// are in use, and `SavedName::None` otherwise.
    pub fn save(&self, expr: &Expr) -> SavedName {
323
0
        if self.use_alias {
324
0
            let (relation, name) = expr.qualified_name();
325
0
            SavedName::Saved { relation, name }
326
        } else {
327
0
            SavedName::None
328
        }
329
0
    }
330
}
331
332
impl SavedName {
333
    /// Ensures the qualified name of the rewritten expression is preserved
    ///
    /// For `SavedName::Saved`, the qualified name of `expr` is compared with
    /// the saved one; if either the relation or the name differs, `expr` is
    /// wrapped with `alias_qualified(relation, name)`. If both match, or for
    /// `SavedName::None`, `expr` is returned unchanged.
334
0
    pub fn restore(self, expr: Expr) -> Expr {
335
0
        match self {
336
0
            SavedName::Saved { relation, name } => {
337
0
                let (new_relation, new_name) = expr.qualified_name();
338
0
                if new_relation != relation || new_name != name {
339
0
                    expr.alias_qualified(relation, name)
340
                } else {
341
0
                    expr
342
                }
343
            }
344
0
            SavedName::None => expr,
345
        }
346
0
    }
347
}
348
349
#[cfg(test)]
350
mod test {
351
    use std::ops::Add;
352
353
    use super::*;
354
    use crate::{col, lit, Cast};
355
    use arrow::datatypes::{DataType, Field, Schema};
356
    use datafusion_common::tree_node::TreeNodeRewriter;
357
    use datafusion_common::ScalarValue;
358
359
    /// Test rewriter that logs each visit as a string instead of changing the tree.
    #[derive(Default)]
360
    struct RecordingRewriter {
361
        // One entry per `f_down` / `f_up` call, in call order
        v: Vec<String>,
362
    }
363
364
    impl TreeNodeRewriter for RecordingRewriter {
365
        type Node = Expr;
366
367
        /// Logs "Previsited <expr>" and returns the expression unchanged.
        fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
368
            self.v.push(format!("Previsited {expr}"));
369
            Ok(Transformed::no(expr))
370
        }
371
372
        /// Logs "Mutated <expr>" and returns the expression unchanged.
        fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
373
            self.v.push(format!("Mutated {expr}"));
374
            Ok(Transformed::no(expr))
375
        }
376
    }
377
378
    /// Checks that `transform` applies a closure that rewrites string literals.
    #[test]
379
    fn rewriter_rewrite() {
380
        // rewrites all "foo" string literals to "bar"
        // (any other Utf8 literal is rebuilt with its own value and still reported as transformed)
381
        let transformer = |expr: Expr| -> Result<Transformed<Expr>> {
382
            match expr {
383
                Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => {
384
                    let utf8_val = if utf8_val == "foo" {
385
                        "bar".to_string()
386
                    } else {
387
                        utf8_val
388
                    };
389
                    Ok(Transformed::yes(lit(utf8_val)))
390
                }
391
                // otherwise, leave the expression unchanged
392
                _ => Ok(Transformed::no(expr)),
393
            }
394
        };
395
396
        // rewrites "foo" --> "bar"
397
        let rewritten = col("state")
398
            .eq(lit("foo"))
399
            .transform(transformer)
400
            .data()
401
            .unwrap();
402
        assert_eq!(rewritten, col("state").eq(lit("bar")));
403
404
        // doesn't rewrite
405
        let rewritten = col("state")
406
            .eq(lit("baz"))
407
            .transform(transformer)
408
            .data()
409
            .unwrap();
410
        assert_eq!(rewritten, col("state").eq(lit("baz")));
411
    }
412
413
    /// Unqualified columns `a`, `b`, `c` are each qualified with the table
    /// whose schema contains a field of that name.
    #[test]
414
    fn normalize_cols() {
415
        let expr = col("a") + col("b") + col("c");
416
417
        // Schemas with some matching and some non matching cols
418
        let schema_a = make_schema_with_empty_metadata(
419
            vec![Some("tableA".into()), Some("tableA".into())],
420
            vec!["a", "aa"],
421
        );
422
        let schema_c = make_schema_with_empty_metadata(
423
            vec![Some("tableC".into()), Some("tableC".into())],
424
            vec!["cc", "c"],
425
        );
426
        let schema_b =
427
            make_schema_with_empty_metadata(vec![Some("tableB".into())], vec!["b"]);
428
        // non matching
429
        let schema_f = make_schema_with_empty_metadata(
430
            vec![Some("tableC".into()), Some("tableC".into())],
431
            vec!["f", "ff"],
432
        );
433
        let schemas = vec![schema_c, schema_f, schema_b, schema_a];
434
        let schemas = schemas.iter().collect::<Vec<_>>();
435
436
        let normalized_expr =
437
            normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
438
                .unwrap();
439
        assert_eq!(
440
            normalized_expr,
441
            col("tableA.a") + col("tableB.b") + col("tableC.c")
442
        );
443
    }
444
445
    /// Normalizing a column that is absent from every schema yields a schema
    /// error that lists the valid fields.
    #[test]
446
    fn normalize_cols_non_exist() {
447
        // test normalizing columns when the name doesn't exist
448
        let expr = col("a") + col("b");
449
        let schema_a =
450
            make_schema_with_empty_metadata(vec![Some("\"tableA\"".into())], vec!["a"]);
451
        let schemas = [schema_a];
452
        let schemas = schemas.iter().collect::<Vec<_>>();
453
454
        let error =
455
            normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
456
                .unwrap_err()
457
                .strip_backtrace();
458
        assert_eq!(
459
            error,
460
            r#"Schema error: No field named b. Valid fields are "tableA".a."#
461
        );
462
    }
463
464
    /// `unnormalize_col` drops the table qualifier from every column.
    #[test]
465
    fn unnormalize_cols() {
466
        let expr = col("tableA.a") + col("tableB.b");
467
        let unnormalized_expr = unnormalize_col(expr);
468
        assert_eq!(unnormalized_expr, col("a") + col("b"));
469
    }
470
471
    /// Test helper: builds a `DFSchema` with one non-nullable `Int8` field per
    /// entry of `fields`, using `qualifiers` as the per-field qualifiers.
    ///
    /// Panics (via `unwrap`) if the schema cannot be constructed.
    fn make_schema_with_empty_metadata(
472
        qualifiers: Vec<Option<TableReference>>,
473
        fields: Vec<&str>,
474
    ) -> DFSchema {
475
        let fields = fields
476
            .iter()
477
            .map(|f| Arc::new(Field::new(f.to_string(), DataType::Int8, false)))
478
            .collect::<Vec<_>>();
479
        let schema = Arc::new(Schema::new(fields));
480
        DFSchema::from_field_specific_qualified_schema(qualifiers, &schema).unwrap()
481
    }
482
483
    /// Pins the order in which `rewrite` calls `f_down` ("Previsited") and
    /// `f_up` ("Mutated") on `state = 'CO'`.
    #[test]
484
    fn rewriter_visit() {
485
        let mut rewriter = RecordingRewriter::default();
486
        col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap();
487
488
        assert_eq!(
489
            rewriter.v,
490
            vec![
491
                "Previsited state = Utf8(\"CO\")",
492
                "Previsited state",
493
                "Mutated state",
494
                "Previsited Utf8(\"CO\")",
495
                "Mutated Utf8(\"CO\")",
496
                "Mutated state = Utf8(\"CO\")"
497
            ]
498
        )
499
    }
500
501
    /// Runs `test_rewrite` over several (original, replacement) pairs to check
    /// that the original qualified name survives each rewrite.
    #[test]
502
    fn test_rewrite_preserving_name() {
        // identical expression: name already matches
503
        test_rewrite(col("a"), col("a"));
504
505
        // different column name
        test_rewrite(col("a"), col("b"));
506
507
        // cast data types
508
        test_rewrite(
509
            col("a"),
510
            Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
511
        );
512
513
        // change literal type from i32 to i64
514
        test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));
515
516
        // test preserve qualifier
517
        test_rewrite(
518
            Expr::Column(Column::new(Some("test"), "a")),
519
            Expr::Column(Column::new_unqualified("test.a")),
520
        );
521
        test_rewrite(
522
            Expr::Column(Column::new_unqualified("test.a")),
523
            Expr::Column(Column::new(Some("test"), "a")),
524
        );
525
    }
526
527
    /// rewrites `expr_from` to `rewrite_to` while preserving the original qualified name
528
    /// by using the `NamePreserver`
    ///
    /// Asserts that the qualified name of the restored expression equals the
    /// qualified name of `expr_from`.
529
    fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
530
        // Rewriter whose `f_up` replaces every node it is given with `rewrite_to`
        struct TestRewriter {
531
            rewrite_to: Expr,
532
        }
533
534
        impl TreeNodeRewriter for TestRewriter {
535
            type Node = Expr;
536
537
            fn f_up(&mut self, _: Expr) -> Result<Transformed<Expr>> {
538
                Ok(Transformed::yes(self.rewrite_to.clone()))
539
            }
540
        }
541
542
        let mut rewriter = TestRewriter {
543
            rewrite_to: rewrite_to.clone(),
544
        };
        // Save the name before rewriting, then re-apply it to the rewritten expression
545
        let saved_name = NamePreserver { use_alias: true }.save(&expr_from);
546
        let new_expr = expr_from.clone().rewrite(&mut rewriter).unwrap().data;
547
        let new_expr = saved_name.restore(new_expr);
548
549
        let original_name = expr_from.qualified_name();
550
        let new_name = new_expr.qualified_name();
551
        assert_eq!(
552
            original_name, new_name,
553
            "mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
554
        )
555
    }
556
}