Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/utils/mod.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
mod guarantee;
19
pub use guarantee::{Guarantee, LiteralGuarantee};
20
use hashbrown::HashSet;
21
22
use std::borrow::Borrow;
23
use std::collections::HashMap;
24
use std::sync::Arc;
25
26
use crate::expressions::{BinaryExpr, Column};
27
use crate::tree_node::ExprContext;
28
use crate::PhysicalExpr;
29
use crate::PhysicalSortExpr;
30
31
use arrow::datatypes::SchemaRef;
32
use datafusion_common::tree_node::{
33
    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
34
};
35
use datafusion_common::Result;
36
use datafusion_expr::Operator;
37
38
use itertools::Itertools;
39
use petgraph::graph::NodeIndex;
40
use petgraph::stable_graph::StableGraph;
41
42
/// Assume the predicate is in the form of CNF, split the predicate to a Vec of PhysicalExprs.
43
///
44
/// For example, split "a1 = a2 AND b1 <= b2 AND c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"]
45
33
pub fn split_conjunction(
46
33
    predicate: &Arc<dyn PhysicalExpr>,
47
33
) -> Vec<&Arc<dyn PhysicalExpr>> {
48
33
    split_impl(Operator::And, predicate, vec![])
49
33
}
50
51
/// Assume the predicate is in the form of DNF, split the predicate to a Vec of PhysicalExprs.
52
///
53
/// For example, split "a1 = a2 OR b1 <= b2 OR c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"]
54
0
pub fn split_disjunction(
55
0
    predicate: &Arc<dyn PhysicalExpr>,
56
0
) -> Vec<&Arc<dyn PhysicalExpr>> {
57
0
    split_impl(Operator::Or, predicate, vec![])
58
0
}
59
60
67
fn split_impl<'a>(
61
67
    operator: Operator,
62
67
    predicate: &'a Arc<dyn PhysicalExpr>,
63
67
    mut exprs: Vec<&'a Arc<dyn PhysicalExpr>>,
64
67
) -> Vec<&'a Arc<dyn PhysicalExpr>> {
65
67
    match predicate.as_any().downcast_ref::<BinaryExpr>() {
66
67
        Some(binary) if binary.op() == &operator => {  [guard-arm region count: 17]
67
17
            let exprs = split_impl(operator, binary.left(), exprs);
68
17
            split_impl(operator, binary.right(), exprs)
69
        }
70
        Some(_) | None => {
71
50
            exprs.push(predicate);
72
50
            exprs
73
        }
74
    }
75
67
}
76
77
/// This function maps back requirement after ProjectionExec
78
/// to the Executor for its input.
79
// Specifically, `ProjectionExec` changes index of `Column`s in the schema of its input executor.
80
// This function changes requirement given according to ProjectionExec schema to the requirement
81
// according to schema of input executor to the ProjectionExec.
82
// For instance, Column{"a", 0} would turn to Column{"a", 1}. Please note that this function assumes that
83
// name of the Column is unique. If we have a requirement such that Column{"a", 0}, Column{"a", 1}.
84
// This function will produce incorrect result (It will only emit single Column as a result).
85
0
pub fn map_columns_before_projection(
86
0
    parent_required: &[Arc<dyn PhysicalExpr>],
87
0
    proj_exprs: &[(Arc<dyn PhysicalExpr>, String)],
88
0
) -> Vec<Arc<dyn PhysicalExpr>> {
89
0
    let column_mapping = proj_exprs
90
0
        .iter()
91
0
        .filter_map(|(expr, name)| {
92
0
            expr.as_any()
93
0
                .downcast_ref::<Column>()
94
0
                .map(|column| (name.clone(), column.clone()))
95
0
        })
96
0
        .collect::<HashMap<_, _>>();
97
0
    parent_required
98
0
        .iter()
99
0
        .filter_map(|r| {
100
0
            r.as_any()
101
0
                .downcast_ref::<Column>()
102
0
                .and_then(|c| column_mapping.get(c.name()))
103
0
        })
104
0
        .map(|e| Arc::new(e.clone()) as _)
105
0
        .collect()
106
0
}
107
108
/// This function returns all `Arc<dyn PhysicalExpr>`s inside the given
109
/// `PhysicalSortExpr` sequence.
110
0
pub fn convert_to_expr<T: Borrow<PhysicalSortExpr>>(
111
0
    sequence: impl IntoIterator<Item = T>,
112
0
) -> Vec<Arc<dyn PhysicalExpr>> {
113
0
    sequence
114
0
        .into_iter()
115
0
        .map(|elem| Arc::clone(&elem.borrow().expr))
116
0
        .collect()
117
0
}
118
119
/// This function finds the indices of `targets` within `items` using strict
120
/// equality.
121
0
pub fn get_indices_of_exprs_strict<T: Borrow<Arc<dyn PhysicalExpr>>>(
122
0
    targets: impl IntoIterator<Item = T>,
123
0
    items: &[Arc<dyn PhysicalExpr>],
124
0
) -> Vec<usize> {
125
0
    targets
126
0
        .into_iter()
127
0
        .filter_map(|target| items.iter().position(|e| e.eq(target.borrow())))
128
0
        .collect()
129
0
}
130
131
pub type ExprTreeNode<T> = ExprContext<Option<T>>;
132
133
/// This struct is used to convert a [`PhysicalExpr`] tree into a DAEG (i.e. an expression
134
/// DAG) by collecting identical expressions in one node. Caller specifies the node type
135
/// in the DAEG via the `constructor` argument, which constructs nodes in the DAEG from
136
/// the [`ExprTreeNode`] ancillary object.
137
struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> {
138
    // The resulting DAEG (expression DAG).
139
    graph: StableGraph<T, usize>,
140
    // A vector of visited expression nodes and their corresponding node indices.
141
    visited_plans: Vec<(Arc<dyn PhysicalExpr>, NodeIndex)>,
142
    // A function to convert an input expression node to T.
143
    constructor: &'a F,
144
}
145
146
impl<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>>
147
    PhysicalExprDAEGBuilder<'a, T, F>
148
{
149
    // This method mutates an expression node by transforming it to a physical expression
150
    // and adding it to the graph. The method returns the mutated expression node.
151
17.0k
    fn mutate(
152
17.0k
        &mut self,
153
17.0k
        mut node: ExprTreeNode<NodeIndex>,
154
17.0k
    ) -> Result<Transformed<ExprTreeNode<NodeIndex>>> {
155
17.0k
        // Get the expression associated with the input expression node.
156
17.0k
        let expr = &node.expr;
157
158
        // Check if the expression has already been visited.
159
92.2k
        let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) {  [region counts: 17.0k on binding/find, 92.2k inside closure]
160
            // If the expression has been visited, return the corresponding node index.
161
2.64k
            Some((_, idx)) => *idx,
162
            // If the expression has not been visited, add a new node to the graph and
163
            // add edges to its child nodes. Add the visited expression to the vector
164
            // of visited expressions and return the newly created node index.
165
            None => {
166
14.4k
                let node_idx = self.graph.add_node((self.constructor)(&node)?);  [`?` error-branch region count: 0]
167
15.6k
                for expr_node in node.children.iter() {  [loop-entry region count: 14.4k]
168
15.6k
                    self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0);
169
15.6k
                }
170
14.4k
                self.visited_plans.push((Arc::clone(expr), node_idx));
171
14.4k
                node_idx
172
            }
173
        };
174
        // Set the data field of the input expression node to the corresponding node index.
175
17.0k
        node.data = Some(node_idx);
176
17.0k
        // Return the mutated expression node.
177
17.0k
        Ok(Transformed::yes(node))
178
17.0k
    }
179
}
180
181
// A function that builds a directed acyclic graph of physical expression trees.
182
1.14k
pub fn build_dag<T, F>(
183
1.14k
    expr: Arc<dyn PhysicalExpr>,
184
1.14k
    constructor: &F,
185
1.14k
) -> Result<(NodeIndex, StableGraph<T, usize>)>
186
1.14k
where
187
1.14k
    F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>,
188
1.14k
{
189
1.14k
    // Create a new expression tree node from the input expression.
190
1.14k
    let init = ExprTreeNode::new_default(expr);
191
1.14k
    // Create a new `PhysicalExprDAEGBuilder` instance.
192
1.14k
    let mut builder = PhysicalExprDAEGBuilder {
193
1.14k
        graph: StableGraph::<T, usize>::new(),
194
1.14k
        visited_plans: Vec::<(Arc<dyn PhysicalExpr>, NodeIndex)>::new(),
195
1.14k
        constructor,
196
1.14k
    };
197
    // Use the builder to transform the expression tree node into a DAG.
198
17.0k
    let root = init.transform_up(|node| builder.mutate(node)).data()?;  [region counts: 1.14k outer, 17.0k inside closure, 0 on `?` error branch]
199
    // Return a tuple containing the root node index and the DAG.
200
1.14k
    Ok((root.data.unwrap(), builder.graph))
201
1.14k
}
202
203
/// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`].
204
4.50k
pub fn collect_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<Column> {
205
4.50k
    let mut columns = HashSet::<Column>::new();
206
4.90k
    expr.apply(|expr| {
207
4.90k
        if let Some(column) = expr.as_any().downcast_ref::<Column>() {  [Some-branch region count: 4.59k]
208
4.59k
            columns.get_or_insert_owned(column);
209
4.59k
        }
307
210
4.90k
        Ok(TreeNodeRecursion::Continue)
211
4.90k
    })
212
4.50k
    // pre_visit always returns OK, so this will always too
213
4.50k
    .expect("no way to return error during recursion");
214
4.50k
    columns
215
4.50k
}
216
217
/// Re-assign column indices referenced in predicate according to given schema.
218
/// This may be helpful when dealing with projections.
219
0
pub fn reassign_predicate_columns(
220
0
    pred: Arc<dyn PhysicalExpr>,
221
0
    schema: &SchemaRef,
222
0
    ignore_not_found: bool,
223
0
) -> Result<Arc<dyn PhysicalExpr>> {
224
0
    pred.transform_down(|expr| {
225
0
        let expr_any = expr.as_any();
226
227
0
        if let Some(column) = expr_any.downcast_ref::<Column>() {
228
0
            let index = match schema.index_of(column.name()) {
229
0
                Ok(idx) => idx,
230
0
                Err(_) if ignore_not_found => usize::MAX,
231
0
                Err(e) => return Err(e.into()),
232
            };
233
0
            return Ok(Transformed::yes(Arc::new(Column::new(
234
0
                column.name(),
235
0
                index,
236
0
            ))));
237
0
        }
238
0
        Ok(Transformed::no(expr))
239
0
    })
240
0
    .data()
241
0
}
242
243
/// Merge left and right sort expressions, checking for duplicates.
244
2
pub fn merge_vectors(
245
2
    left: &[PhysicalSortExpr],
246
2
    right: &[PhysicalSortExpr],
247
2
) -> Vec<PhysicalSortExpr> {
248
2
    left.iter()
249
2
        .cloned()
250
2
        .chain(right.iter().cloned())
251
2
        .unique()
252
2
        .collect()
253
2
}
254
255
#[cfg(test)]
256
pub(crate) mod tests {
257
    use std::any::Any;
258
    use std::fmt::{Display, Formatter};
259
260
    use super::*;
261
    use crate::expressions::{binary, cast, col, in_list, lit, Literal};
262
263
    use arrow_array::{ArrayRef, Float32Array, Float64Array};
264
    use arrow_schema::{DataType, Field, Schema};
265
    use datafusion_common::{exec_err, DataFusionError, ScalarValue};
266
    use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
267
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
268
269
    use petgraph::visit::Bfs;
270
271
    #[derive(Debug, Clone)]
272
    pub struct TestScalarUDF {
273
        pub(crate) signature: Signature,
274
    }
275
276
    impl TestScalarUDF {
277
        pub fn new() -> Self {
278
            use DataType::*;
279
            Self {
280
                signature: Signature::uniform(
281
                    1,
282
                    vec![Float64, Float32],
283
                    Volatility::Immutable,
284
                ),
285
            }
286
        }
287
    }
288
289
    impl ScalarUDFImpl for TestScalarUDF {
290
        fn as_any(&self) -> &dyn Any {
291
            self
292
        }
293
        fn name(&self) -> &str {
294
            "test-scalar-udf"
295
        }
296
297
        fn signature(&self) -> &Signature {
298
            &self.signature
299
        }
300
301
        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
302
            let arg_type = &arg_types[0];
303
304
            match arg_type {
305
                DataType::Float32 => Ok(DataType::Float32),
306
                _ => Ok(DataType::Float64),
307
            }
308
        }
309
310
        fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
311
            Ok(input[0].sort_properties)
312
        }
313
314
        fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
315
            let args = ColumnarValue::values_to_arrays(args)?;
316
317
            let arr: ArrayRef = match args[0].data_type() {
318
                DataType::Float64 => Arc::new({
319
                    let arg = &args[0]
320
                        .as_any()
321
                        .downcast_ref::<Float64Array>()
322
                        .ok_or_else(|| {
323
                            DataFusionError::Internal(format!(
324
                                "could not cast {} to {}",
325
                                self.name(),
326
                                std::any::type_name::<Float64Array>()
327
                            ))
328
                        })?;
329
330
                    arg.iter()
331
                        .map(|a| a.map(f64::floor))
332
                        .collect::<Float64Array>()
333
                }),
334
                DataType::Float32 => Arc::new({
335
                    let arg = &args[0]
336
                        .as_any()
337
                        .downcast_ref::<Float32Array>()
338
                        .ok_or_else(|| {
339
                            DataFusionError::Internal(format!(
340
                                "could not cast {} to {}",
341
                                self.name(),
342
                                std::any::type_name::<Float32Array>()
343
                            ))
344
                        })?;
345
346
                    arg.iter()
347
                        .map(|a| a.map(f32::floor))
348
                        .collect::<Float32Array>()
349
                }),
350
                other => {
351
                    return exec_err!(
352
                        "Unsupported data type {other:?} for function {}",
353
                        self.name()
354
                    );
355
                }
356
            };
357
            Ok(ColumnarValue::Array(arr))
358
        }
359
    }
360
361
    #[derive(Clone)]
362
    struct DummyProperty {
363
        expr_type: String,
364
    }
365
366
    /// This is a dummy node in the DAEG; it stores a reference to the actual
367
    /// [PhysicalExpr] as well as a dummy property.
368
    #[derive(Clone)]
369
    struct PhysicalExprDummyNode {
370
        pub expr: Arc<dyn PhysicalExpr>,
371
        pub property: DummyProperty,
372
    }
373
374
    impl Display for PhysicalExprDummyNode {
375
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
376
            write!(f, "{}", self.expr)
377
        }
378
    }
379
380
    fn make_dummy_node(node: &ExprTreeNode<NodeIndex>) -> Result<PhysicalExprDummyNode> {
381
        let expr = Arc::clone(&node.expr);
382
        let dummy_property = if expr.as_any().is::<BinaryExpr>() {
383
            "Binary"
384
        } else if expr.as_any().is::<Column>() {
385
            "Column"
386
        } else if expr.as_any().is::<Literal>() {
387
            "Literal"
388
        } else {
389
            "Other"
390
        }
391
        .to_owned();
392
        Ok(PhysicalExprDummyNode {
393
            expr,
394
            property: DummyProperty {
395
                expr_type: dummy_property,
396
            },
397
        })
398
    }
399
400
    #[test]
401
    fn test_build_dag() -> Result<()> {
402
        let schema = Schema::new(vec![
403
            Field::new("0", DataType::Int32, true),
404
            Field::new("1", DataType::Int32, true),
405
            Field::new("2", DataType::Int32, true),
406
        ]);
407
        let expr = binary(
408
            cast(
409
                binary(
410
                    col("0", &schema)?,
411
                    Operator::Plus,
412
                    col("1", &schema)?,
413
                    &schema,
414
                )?,
415
                &schema,
416
                DataType::Int64,
417
            )?,
418
            Operator::Gt,
419
            binary(
420
                cast(col("2", &schema)?, &schema, DataType::Int64)?,
421
                Operator::Plus,
422
                lit(ScalarValue::Int64(Some(10))),
423
                &schema,
424
            )?,
425
            &schema,
426
        )?;
427
        let mut vector_dummy_props = vec![];
428
        let (root, graph) = build_dag(expr, &make_dummy_node)?;
429
        let mut bfs = Bfs::new(&graph, root);
430
        while let Some(node_index) = bfs.next(&graph) {
431
            let node = &graph[node_index];
432
            vector_dummy_props.push(node.property.clone());
433
        }
434
435
        assert_eq!(
436
            vector_dummy_props
437
                .iter()
438
                .filter(|property| property.expr_type == "Binary")
439
                .count(),
440
            3
441
        );
442
        assert_eq!(
443
            vector_dummy_props
444
                .iter()
445
                .filter(|property| property.expr_type == "Column")
446
                .count(),
447
            3
448
        );
449
        assert_eq!(
450
            vector_dummy_props
451
                .iter()
452
                .filter(|property| property.expr_type == "Literal")
453
                .count(),
454
            1
455
        );
456
        assert_eq!(
457
            vector_dummy_props
458
                .iter()
459
                .filter(|property| property.expr_type == "Other")
460
                .count(),
461
            2
462
        );
463
        Ok(())
464
    }
465
466
    #[test]
467
    fn test_convert_to_expr() -> Result<()> {
468
        let schema = Schema::new(vec![Field::new("a", DataType::UInt64, false)]);
469
        let sort_expr = vec![PhysicalSortExpr {
470
            expr: col("a", &schema)?,
471
            options: Default::default(),
472
        }];
473
        assert!(convert_to_expr(&sort_expr)[0].eq(&sort_expr[0].expr));
474
        Ok(())
475
    }
476
477
    #[test]
478
    fn test_get_indices_of_exprs_strict() {
479
        let list1: Vec<Arc<dyn PhysicalExpr>> = vec![
480
            Arc::new(Column::new("a", 0)),
481
            Arc::new(Column::new("b", 1)),
482
            Arc::new(Column::new("c", 2)),
483
            Arc::new(Column::new("d", 3)),
484
        ];
485
        let list2: Vec<Arc<dyn PhysicalExpr>> = vec![
486
            Arc::new(Column::new("b", 1)),
487
            Arc::new(Column::new("c", 2)),
488
            Arc::new(Column::new("a", 0)),
489
        ];
490
        assert_eq!(get_indices_of_exprs_strict(&list1, &list2), vec![2, 0, 1]);
491
        assert_eq!(get_indices_of_exprs_strict(&list2, &list1), vec![1, 2, 0]);
492
    }
493
494
    #[test]
495
    fn test_reassign_predicate_columns_in_list() {
496
        let int_field = Field::new("should_not_matter", DataType::Int64, true);
497
        let dict_field = Field::new(
498
            "id",
499
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
500
            true,
501
        );
502
        let schema_small = Arc::new(Schema::new(vec![dict_field.clone()]));
503
        let schema_big = Arc::new(Schema::new(vec![int_field, dict_field]));
504
        let pred = in_list(
505
            Arc::new(Column::new_with_schema("id", &schema_big).unwrap()),
506
            vec![lit(ScalarValue::Dictionary(
507
                Box::new(DataType::Int32),
508
                Box::new(ScalarValue::from("2")),
509
            ))],
510
            &false,
511
            &schema_big,
512
        )
513
        .unwrap();
514
515
        let actual = reassign_predicate_columns(pred, &schema_small, false).unwrap();
516
517
        let expected = in_list(
518
            Arc::new(Column::new_with_schema("id", &schema_small).unwrap()),
519
            vec![lit(ScalarValue::Dictionary(
520
                Box::new(DataType::Int32),
521
                Box::new(ScalarValue::from("2")),
522
            ))],
523
            &false,
524
            &schema_small,
525
        )
526
        .unwrap();
527
528
        assert_eq!(actual.as_ref(), expected.as_any());
529
    }
530
531
    #[test]
532
    fn test_collect_columns() -> Result<()> {
533
        let expr1 = Arc::new(Column::new("col1", 2)) as _;
534
        let mut expected = HashSet::new();
535
        expected.insert(Column::new("col1", 2));
536
        assert_eq!(collect_columns(&expr1), expected);
537
538
        let expr2 = Arc::new(Column::new("col2", 5)) as _;
539
        let mut expected = HashSet::new();
540
        expected.insert(Column::new("col2", 5));
541
        assert_eq!(collect_columns(&expr2), expected);
542
543
        let expr3 = Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2)) as _;
544
        let mut expected = HashSet::new();
545
        expected.insert(Column::new("col1", 2));
546
        expected.insert(Column::new("col2", 5));
547
        assert_eq!(collect_columns(&expr3), expected);
548
        Ok(())
549
    }
550
}