Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/equivalence/projection.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use std::sync::Arc;
19
20
use crate::expressions::Column;
21
use crate::PhysicalExpr;
22
23
use arrow::datatypes::SchemaRef;
24
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
25
use datafusion_common::{internal_err, Result};
26
27
/// Stores the mapping between source expressions and target expressions for a
28
/// projection.
29
#[derive(Debug, Clone)]
30
pub struct ProjectionMapping {
31
    /// Mapping between source expressions and target expressions.
32
    /// Vector indices correspond to the indices after projection.
33
    pub map: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
34
}
35
36
impl ProjectionMapping {
37
    /// Constructs the mapping between a projection's input and output
38
    /// expressions.
39
    ///
40
    /// For example, given the input projection expressions (`a + b`, `c + d`)
41
    /// and an output schema with two columns `"c + d"` and `"a + b"`, the
42
    /// projection mapping would be:
43
    ///
44
    /// ```text
45
    ///  [0]: (c + d, col("c + d"))
46
    ///  [1]: (a + b, col("a + b"))
47
    /// ```
48
    ///
49
    /// where `col("c + d")` means the column named `"c + d"`.
50
52
    pub fn try_new(
51
52
        expr: &[(Arc<dyn PhysicalExpr>, String)],
52
52
        input_schema: &SchemaRef,
53
52
    ) -> Result<Self> {
54
52
        // Construct a map from the input expressions to the output expression of the projection:
55
52
        expr.iter()
56
52
            .enumerate()
57
61
            .map(|(expr_idx, (expression, name))| {
58
61
                let target_expr = Arc::new(Column::new(name, expr_idx)) as _;
59
61
                Arc::clone(expression)
60
61
                    .transform_down(|e| match e.as_any().downcast_ref::<Column>() {
61
60
                        Some(col) => {
62
60
                            // Sometimes, an expression and its name in the input_schema
63
60
                            // doesn't match. This can cause problems, so we make sure
64
60
                            // that the expression name matches with the name in `input_schema`.
65
60
                            // Conceptually, `source_expr` and `expression` should be the same.
66
60
                            let idx = col.index();
67
60
                            let matching_input_field = input_schema.field(idx);
68
60
                            if col.name() != matching_input_field.name() {
69
0
                                return internal_err!("Input field name {} does not match with the projection expression {}",
70
0
                                    matching_input_field.name(),col.name())
71
60
                                }
72
60
                            let matching_input_column =
73
60
                                Column::new(matching_input_field.name(), idx);
74
60
                            Ok(Transformed::yes(Arc::new(matching_input_column)))
75
                        }
76
1
                        None => Ok(Transformed::no(e)),
77
61
                    })
78
61
                    .data()
79
61
                    .map(|source_expr| (source_expr, target_expr))
80
61
            })
81
52
            .collect::<Result<Vec<_>>>()
82
52
            .map(|map| Self { map })
83
52
    }
84
85
    /// Constructs a subset mapping using the provided indices.
86
    ///
87
    /// This is used when the output is a subset of the input without any
88
    /// other transformations. The indices are for columns in the schema.
89
0
    pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Result<Self> {
90
0
        let projection_exprs = project_index_to_exprs(indices, schema);
91
0
        ProjectionMapping::try_new(&projection_exprs, schema)
92
0
    }
93
94
    /// Iterate over pairs of (source, target) expressions
95
209
    pub fn iter(
96
209
        &self,
97
209
    ) -> impl Iterator<Item = &(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> + '_ {
98
209
        self.map.iter()
99
209
    }
100
101
    /// This function returns the target expression for a given source expression.
102
    ///
103
    /// # Arguments
104
    ///
105
    /// * `expr` - Source physical expression.
106
    ///
107
    /// # Returns
108
    ///
109
    /// An `Option` containing the target for the given source expression,
110
    /// where a `None` value means that `expr` is not inside the mapping.
111
1
    pub fn target_expr(
112
1
        &self,
113
1
        expr: &Arc<dyn PhysicalExpr>,
114
1
    ) -> Option<Arc<dyn PhysicalExpr>> {
115
1
        self.map
116
1
            .iter()
117
1
            .find(|(source, _)| source.eq(expr))
118
1
            .map(|(_, target)| Arc::clone(target))
119
1
    }
120
}
121
122
0
fn project_index_to_exprs(
123
0
    projection_index: &[usize],
124
0
    schema: &SchemaRef,
125
0
) -> Vec<(Arc<dyn PhysicalExpr>, String)> {
126
0
    projection_index
127
0
        .iter()
128
0
        .map(|index| {
129
0
            let field = schema.field(*index);
130
0
            (
131
0
                Arc::new(Column::new(field.name(), *index)) as Arc<dyn PhysicalExpr>,
132
0
                field.name().to_owned(),
133
0
            )
134
0
        })
135
0
        .collect::<Vec<_>>()
136
0
}
137
138
#[cfg(test)]
139
mod tests {
140
    use super::*;
141
    use crate::equivalence::tests::{
142
        apply_projection, convert_to_orderings, convert_to_orderings_owned,
143
        create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort,
144
        output_schema,
145
    };
146
    use crate::equivalence::EquivalenceProperties;
147
    use crate::expressions::{col, BinaryExpr};
148
    use crate::udf::create_physical_expr;
149
    use crate::utils::tests::TestScalarUDF;
150
    use crate::PhysicalSortExpr;
151
152
    use arrow::datatypes::{DataType, Field, Schema};
153
    use arrow_schema::{SortOptions, TimeUnit};
154
    use datafusion_common::DFSchema;
155
    use datafusion_expr::{Operator, ScalarUDF};
156
157
    use itertools::Itertools;
158
159
    #[test]
160
    fn project_orderings() -> Result<()> {
161
        let schema = Arc::new(Schema::new(vec![
162
            Field::new("a", DataType::Int32, true),
163
            Field::new("b", DataType::Int32, true),
164
            Field::new("c", DataType::Int32, true),
165
            Field::new("d", DataType::Int32, true),
166
            Field::new("e", DataType::Int32, true),
167
            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
168
        ]));
169
        let col_a = &col("a", &schema)?;
170
        let col_b = &col("b", &schema)?;
171
        let col_c = &col("c", &schema)?;
172
        let col_d = &col("d", &schema)?;
173
        let col_e = &col("e", &schema)?;
174
        let col_ts = &col("ts", &schema)?;
175
        let a_plus_b = Arc::new(BinaryExpr::new(
176
            Arc::clone(col_a),
177
            Operator::Plus,
178
            Arc::clone(col_b),
179
        )) as Arc<dyn PhysicalExpr>;
180
        let b_plus_d = Arc::new(BinaryExpr::new(
181
            Arc::clone(col_b),
182
            Operator::Plus,
183
            Arc::clone(col_d),
184
        )) as Arc<dyn PhysicalExpr>;
185
        let b_plus_e = Arc::new(BinaryExpr::new(
186
            Arc::clone(col_b),
187
            Operator::Plus,
188
            Arc::clone(col_e),
189
        )) as Arc<dyn PhysicalExpr>;
190
        let c_plus_d = Arc::new(BinaryExpr::new(
191
            Arc::clone(col_c),
192
            Operator::Plus,
193
            Arc::clone(col_d),
194
        )) as Arc<dyn PhysicalExpr>;
195
196
        let option_asc = SortOptions {
197
            descending: false,
198
            nulls_first: false,
199
        };
200
        let option_desc = SortOptions {
201
            descending: true,
202
            nulls_first: true,
203
        };
204
205
        let test_cases = vec![
206
            // ---------- TEST CASE 1 ------------
207
            (
208
                // orderings
209
                vec![
210
                    // [b ASC]
211
                    vec![(col_b, option_asc)],
212
                ],
213
                // projection exprs
214
                vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())],
215
                // expected
216
                vec![
217
                    // [b_new ASC]
218
                    vec![("b_new", option_asc)],
219
                ],
220
            ),
221
            // ---------- TEST CASE 2 ------------
222
            (
223
                // orderings
224
                vec![
225
                    // empty ordering
226
                ],
227
                // projection exprs
228
                vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())],
229
                // expected
230
                vec![
231
                    // no ordering at the output
232
                ],
233
            ),
234
            // ---------- TEST CASE 3 ------------
235
            (
236
                // orderings
237
                vec![
238
                    // [ts ASC]
239
                    vec![(col_ts, option_asc)],
240
                ],
241
                // projection exprs
242
                vec![
243
                    (col_b, "b_new".to_string()),
244
                    (col_a, "a_new".to_string()),
245
                    (col_ts, "ts_new".to_string()),
246
                ],
247
                // expected
248
                vec![
249
                    // [ts_new ASC]
250
                    vec![("ts_new", option_asc)],
251
                ],
252
            ),
253
            // ---------- TEST CASE 4 ------------
254
            (
255
                // orderings
256
                vec![
257
                    // [a ASC, ts ASC]
258
                    vec![(col_a, option_asc), (col_ts, option_asc)],
259
                    // [b ASC, ts ASC]
260
                    vec![(col_b, option_asc), (col_ts, option_asc)],
261
                ],
262
                // projection exprs
263
                vec![
264
                    (col_b, "b_new".to_string()),
265
                    (col_a, "a_new".to_string()),
266
                    (col_ts, "ts_new".to_string()),
267
                ],
268
                // expected
269
                vec![
270
                    // [a_new ASC, ts_new ASC]
271
                    vec![("a_new", option_asc), ("ts_new", option_asc)],
272
                    // [b_new ASC, ts_new ASC]
273
                    vec![("b_new", option_asc), ("ts_new", option_asc)],
274
                ],
275
            ),
276
            // ---------- TEST CASE 5 ------------
277
            (
278
                // orderings
279
                vec![
280
                    // [a + b ASC]
281
                    vec![(&a_plus_b, option_asc)],
282
                ],
283
                // projection exprs
284
                vec![
285
                    (col_b, "b_new".to_string()),
286
                    (col_a, "a_new".to_string()),
287
                    (&a_plus_b, "a+b".to_string()),
288
                ],
289
                // expected
290
                vec![
291
                    // [a + b ASC]
292
                    vec![("a+b", option_asc)],
293
                ],
294
            ),
295
            // ---------- TEST CASE 6 ------------
296
            (
297
                // orderings
298
                vec![
299
                    // [a + b ASC, c ASC]
300
                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
301
                ],
302
                // projection exprs
303
                vec![
304
                    (col_b, "b_new".to_string()),
305
                    (col_a, "a_new".to_string()),
306
                    (col_c, "c_new".to_string()),
307
                    (&a_plus_b, "a+b".to_string()),
308
                ],
309
                // expected
310
                vec![
311
                    // [a + b ASC, c_new ASC]
312
                    vec![("a+b", option_asc), ("c_new", option_asc)],
313
                ],
314
            ),
315
            // ------- TEST CASE 7 ----------
316
            (
317
                vec![
318
                    // [a ASC, b ASC]
319
                    vec![(col_a, option_asc), (col_b, option_asc)],
320
                    // [a ASC, d ASC]
321
                    vec![(col_a, option_asc), (col_d, option_asc)],
322
                ],
323
                // b as b_new, a as a_new, d as d_new b+d
324
                vec![
325
                    (col_b, "b_new".to_string()),
326
                    (col_a, "a_new".to_string()),
327
                    (col_d, "d_new".to_string()),
328
                    (&b_plus_d, "b+d".to_string()),
329
                ],
330
                // expected
331
                vec![
332
                    // [a_new ASC, b_new ASC]
333
                    vec![("a_new", option_asc), ("b_new", option_asc)],
334
                    // [a_new ASC, d_new ASC]
335
                    vec![("a_new", option_asc), ("d_new", option_asc)],
336
                    // [a_new ASC, b+d ASC]
337
                    vec![("a_new", option_asc), ("b+d", option_asc)],
338
                ],
339
            ),
340
            // ------- TEST CASE 8 ----------
341
            (
342
                // orderings
343
                vec![
344
                    // [b+d ASC]
345
                    vec![(&b_plus_d, option_asc)],
346
                ],
347
                // proj exprs
348
                vec![
349
                    (col_b, "b_new".to_string()),
350
                    (col_a, "a_new".to_string()),
351
                    (col_d, "d_new".to_string()),
352
                    (&b_plus_d, "b+d".to_string()),
353
                ],
354
                // expected
355
                vec![
356
                    // [b+d ASC]
357
                    vec![("b+d", option_asc)],
358
                ],
359
            ),
360
            // ------- TEST CASE 9 ----------
361
            (
362
                // orderings
363
                vec![
364
                    // [a ASC, d ASC, b ASC]
365
                    vec![
366
                        (col_a, option_asc),
367
                        (col_d, option_asc),
368
                        (col_b, option_asc),
369
                    ],
370
                    // [c ASC]
371
                    vec![(col_c, option_asc)],
372
                ],
373
                // proj exprs
374
                vec![
375
                    (col_b, "b_new".to_string()),
376
                    (col_a, "a_new".to_string()),
377
                    (col_d, "d_new".to_string()),
378
                    (col_c, "c_new".to_string()),
379
                ],
380
                // expected
381
                vec![
382
                    // [a_new ASC, d_new ASC, b_new ASC]
383
                    vec![
384
                        ("a_new", option_asc),
385
                        ("d_new", option_asc),
386
                        ("b_new", option_asc),
387
                    ],
388
                    // [c_new ASC],
389
                    vec![("c_new", option_asc)],
390
                ],
391
            ),
392
            // ------- TEST CASE 10 ----------
393
            (
394
                vec![
395
                    // [a ASC, b ASC, c ASC]
396
                    vec![
397
                        (col_a, option_asc),
398
                        (col_b, option_asc),
399
                        (col_c, option_asc),
400
                    ],
401
                    // [a ASC, d ASC]
402
                    vec![(col_a, option_asc), (col_d, option_asc)],
403
                ],
404
                // proj exprs
405
                vec![
406
                    (col_b, "b_new".to_string()),
407
                    (col_a, "a_new".to_string()),
408
                    (col_c, "c_new".to_string()),
409
                    (&c_plus_d, "c+d".to_string()),
410
                ],
411
                // expected
412
                vec![
413
                    // [a_new ASC, b_new ASC, c_new ASC]
414
                    vec![
415
                        ("a_new", option_asc),
416
                        ("b_new", option_asc),
417
                        ("c_new", option_asc),
418
                    ],
419
                    // [a_new ASC, b_new ASC, c+d ASC]
420
                    vec![
421
                        ("a_new", option_asc),
422
                        ("b_new", option_asc),
423
                        ("c+d", option_asc),
424
                    ],
425
                ],
426
            ),
427
            // ------- TEST CASE 11 ----------
428
            (
429
                // orderings
430
                vec![
431
                    // [a ASC, b ASC]
432
                    vec![(col_a, option_asc), (col_b, option_asc)],
433
                    // [a ASC, d ASC]
434
                    vec![(col_a, option_asc), (col_d, option_asc)],
435
                ],
436
                // proj exprs
437
                vec![
438
                    (col_b, "b_new".to_string()),
439
                    (col_a, "a_new".to_string()),
440
                    (&b_plus_d, "b+d".to_string()),
441
                ],
442
                // expected
443
                vec![
444
                    // [a_new ASC, b_new ASC]
445
                    vec![("a_new", option_asc), ("b_new", option_asc)],
446
                    // [a_new ASC, b + d ASC]
447
                    vec![("a_new", option_asc), ("b+d", option_asc)],
448
                ],
449
            ),
450
            // ------- TEST CASE 12 ----------
451
            (
452
                // orderings
453
                vec![
454
                    // [a ASC, b ASC, c ASC]
455
                    vec![
456
                        (col_a, option_asc),
457
                        (col_b, option_asc),
458
                        (col_c, option_asc),
459
                    ],
460
                ],
461
                // proj exprs
462
                vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())],
463
                // expected
464
                vec![
465
                    // [a_new ASC]
466
                    vec![("a_new", option_asc)],
467
                ],
468
            ),
469
            // ------- TEST CASE 13 ----------
470
            (
471
                // orderings
472
                vec![
473
                    // [a ASC, b ASC, c ASC]
474
                    vec![
475
                        (col_a, option_asc),
476
                        (col_b, option_asc),
477
                        (col_c, option_asc),
478
                    ],
479
                    // [a ASC, a + b ASC, c ASC]
480
                    vec![
481
                        (col_a, option_asc),
482
                        (&a_plus_b, option_asc),
483
                        (col_c, option_asc),
484
                    ],
485
                ],
486
                // proj exprs
487
                vec![
488
                    (col_c, "c_new".to_string()),
489
                    (col_b, "b_new".to_string()),
490
                    (col_a, "a_new".to_string()),
491
                    (&a_plus_b, "a+b".to_string()),
492
                ],
493
                // expected
494
                vec![
495
                    // [a_new ASC, b_new ASC, c_new ASC]
496
                    vec![
497
                        ("a_new", option_asc),
498
                        ("b_new", option_asc),
499
                        ("c_new", option_asc),
500
                    ],
501
                    // [a_new ASC, a+b ASC, c_new ASC]
502
                    vec![
503
                        ("a_new", option_asc),
504
                        ("a+b", option_asc),
505
                        ("c_new", option_asc),
506
                    ],
507
                ],
508
            ),
509
            // ------- TEST CASE 14 ----------
510
            (
511
                // orderings
512
                vec![
513
                    // [a ASC, b ASC]
514
                    vec![(col_a, option_asc), (col_b, option_asc)],
515
                    // [c ASC, b ASC]
516
                    vec![(col_c, option_asc), (col_b, option_asc)],
517
                    // [d ASC, e ASC]
518
                    vec![(col_d, option_asc), (col_e, option_asc)],
519
                ],
520
                // proj exprs
521
                vec![
522
                    (col_c, "c_new".to_string()),
523
                    (col_d, "d_new".to_string()),
524
                    (col_a, "a_new".to_string()),
525
                    (&b_plus_e, "b+e".to_string()),
526
                ],
527
                // expected
528
                vec![
529
                    // [a_new ASC, d_new ASC, b+e ASC]
530
                    vec![
531
                        ("a_new", option_asc),
532
                        ("d_new", option_asc),
533
                        ("b+e", option_asc),
534
                    ],
535
                    // [d_new ASC, a_new ASC, b+e ASC]
536
                    vec![
537
                        ("d_new", option_asc),
538
                        ("a_new", option_asc),
539
                        ("b+e", option_asc),
540
                    ],
541
                    // [c_new ASC, d_new ASC, b+e ASC]
542
                    vec![
543
                        ("c_new", option_asc),
544
                        ("d_new", option_asc),
545
                        ("b+e", option_asc),
546
                    ],
547
                    // [d_new ASC, c_new ASC, b+e ASC]
548
                    vec![
549
                        ("d_new", option_asc),
550
                        ("c_new", option_asc),
551
                        ("b+e", option_asc),
552
                    ],
553
                ],
554
            ),
555
            // ------- TEST CASE 15 ----------
556
            (
557
                // orderings
558
                vec![
559
                    // [a ASC, c ASC, b ASC]
560
                    vec![
561
                        (col_a, option_asc),
562
                        (col_c, option_asc),
563
                        (col_b, option_asc),
564
                    ],
565
                ],
566
                // proj exprs
567
                vec![
568
                    (col_c, "c_new".to_string()),
569
                    (col_a, "a_new".to_string()),
570
                    (&a_plus_b, "a+b".to_string()),
571
                ],
572
                // expected
573
                vec![
574
                    // [a_new ASC, c_new ASC, a+b ASC]
575
                    vec![
576
                        ("a_new", option_asc),
577
                        ("c_new", option_asc),
578
                        ("a+b", option_asc),
579
                    ],
580
                ],
581
            ),
582
            // ------- TEST CASE 16 ----------
583
            (
584
                // orderings
585
                vec![
586
                    // [a ASC, b ASC]
587
                    vec![(col_a, option_asc), (col_b, option_asc)],
588
                    // [c ASC, b DESC]
589
                    vec![(col_c, option_asc), (col_b, option_desc)],
590
                    // [e ASC]
591
                    vec![(col_e, option_asc)],
592
                ],
593
                // proj exprs
594
                vec![
595
                    (col_c, "c_new".to_string()),
596
                    (col_a, "a_new".to_string()),
597
                    (col_b, "b_new".to_string()),
598
                    (&b_plus_e, "b+e".to_string()),
599
                ],
600
                // expected
601
                vec![
602
                    // [a_new ASC, b_new ASC]
603
                    vec![("a_new", option_asc), ("b_new", option_asc)],
604
                    // [a_new ASC, b+e ASC]
605
                    vec![("a_new", option_asc), ("b+e", option_asc)],
606
                    // [c_new ASC, b_new DESC]
607
                    vec![("c_new", option_asc), ("b_new", option_desc)],
608
                ],
609
            ),
610
        ];
611
612
        for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate()
613
        {
614
            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
615
616
            let orderings = convert_to_orderings(&orderings);
617
            eq_properties.add_new_orderings(orderings);
618
619
            let proj_exprs = proj_exprs
620
                .into_iter()
621
                .map(|(expr, name)| (Arc::clone(expr), name))
622
                .collect::<Vec<_>>();
623
            let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?;
624
            let output_schema = output_schema(&projection_mapping, &schema)?;
625
626
            let expected = expected
627
                .into_iter()
628
                .map(|ordering| {
629
                    ordering
630
                        .into_iter()
631
                        .map(|(name, options)| {
632
                            (col(name, &output_schema).unwrap(), options)
633
                        })
634
                        .collect::<Vec<_>>()
635
                })
636
                .collect::<Vec<_>>();
637
            let expected = convert_to_orderings_owned(&expected);
638
639
            let projected_eq = eq_properties.project(&projection_mapping, output_schema);
640
            let orderings = projected_eq.oeq_class();
641
642
            let err_msg = format!(
643
                "test_idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}",
644
                idx, orderings.orderings, expected, projection_mapping
645
            );
646
647
            assert_eq!(orderings.len(), expected.len(), "{}", err_msg);
648
            for expected_ordering in &expected {
649
                assert!(orderings.contains(expected_ordering), "{}", err_msg)
650
            }
651
        }
652
653
        Ok(())
654
    }
655
656
    #[test]
657
    fn project_orderings2() -> Result<()> {
658
        let schema = Arc::new(Schema::new(vec![
659
            Field::new("a", DataType::Int32, true),
660
            Field::new("b", DataType::Int32, true),
661
            Field::new("c", DataType::Int32, true),
662
            Field::new("d", DataType::Int32, true),
663
            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
664
        ]));
665
        let col_a = &col("a", &schema)?;
666
        let col_b = &col("b", &schema)?;
667
        let col_c = &col("c", &schema)?;
668
        let col_ts = &col("ts", &schema)?;
669
        let a_plus_b = Arc::new(BinaryExpr::new(
670
            Arc::clone(col_a),
671
            Operator::Plus,
672
            Arc::clone(col_b),
673
        )) as Arc<dyn PhysicalExpr>;
674
675
        let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
676
        let round_c = &create_physical_expr(
677
            &test_fun,
678
            &[Arc::clone(col_c)],
679
            &schema,
680
            &[],
681
            &DFSchema::empty(),
682
        )?;
683
684
        let option_asc = SortOptions {
685
            descending: false,
686
            nulls_first: false,
687
        };
688
689
        let proj_exprs = vec![
690
            (col_b, "b_new".to_string()),
691
            (col_a, "a_new".to_string()),
692
            (col_c, "c_new".to_string()),
693
            (round_c, "round_c_res".to_string()),
694
        ];
695
        let proj_exprs = proj_exprs
696
            .into_iter()
697
            .map(|(expr, name)| (Arc::clone(expr), name))
698
            .collect::<Vec<_>>();
699
        let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?;
700
        let output_schema = output_schema(&projection_mapping, &schema)?;
701
702
        let col_a_new = &col("a_new", &output_schema)?;
703
        let col_b_new = &col("b_new", &output_schema)?;
704
        let col_c_new = &col("c_new", &output_schema)?;
705
        let col_round_c_res = &col("round_c_res", &output_schema)?;
706
        let a_new_plus_b_new = Arc::new(BinaryExpr::new(
707
            Arc::clone(col_a_new),
708
            Operator::Plus,
709
            Arc::clone(col_b_new),
710
        )) as Arc<dyn PhysicalExpr>;
711
712
        let test_cases = vec![
713
            // ---------- TEST CASE 1 ------------
714
            (
715
                // orderings
716
                vec![
717
                    // [a ASC]
718
                    vec![(col_a, option_asc)],
719
                ],
720
                // expected
721
                vec![
722
                    // [a_new ASC]
723
                    vec![(col_a_new, option_asc)],
724
                ],
725
            ),
726
            // ---------- TEST CASE 2 ------------
727
            (
728
                // orderings
729
                vec![
730
                    // [a+b ASC]
731
                    vec![(&a_plus_b, option_asc)],
732
                ],
733
                // expected
734
                vec![
735
                    // [a_new + b_new ASC]
736
                    vec![(&a_new_plus_b_new, option_asc)],
737
                ],
738
            ),
739
            // ---------- TEST CASE 3 ------------
740
            (
741
                // orderings
742
                vec![
743
                    // [a ASC, ts ASC]
744
                    vec![(col_a, option_asc), (col_ts, option_asc)],
745
                ],
746
                // expected
747
                vec![
748
                    // [a_new ASC] (`ts` is not part of the projection)
749
                    vec![(col_a_new, option_asc)],
750
                ],
751
            ),
752
            // ---------- TEST CASE 4 ------------
753
            (
754
                // orderings
755
                vec![
756
                    // [a ASC, ts ASC, b ASC]
757
                    vec![
758
                        (col_a, option_asc),
759
                        (col_ts, option_asc),
760
                        (col_b, option_asc),
761
                    ],
762
                ],
763
                // expected
764
                vec![
765
                    // [a_new ASC] (`ts` is not part of the projection)
766
                    vec![(col_a_new, option_asc)],
767
                ],
768
            ),
769
            // ---------- TEST CASE 5 ------------
770
            (
771
                // orderings
772
                vec![
773
                    // [a ASC, c ASC]
774
                    vec![(col_a, option_asc), (col_c, option_asc)],
775
                ],
776
                // expected
777
                vec![
778
                    // [a_new ASC, round_c_res ASC]
779
                    vec![(col_a_new, option_asc), (col_round_c_res, option_asc)],
780
                    // [a_new ASC, c_new ASC]
781
                    vec![(col_a_new, option_asc), (col_c_new, option_asc)],
782
                ],
783
            ),
784
            // ---------- TEST CASE 6 ------------
785
            (
786
                // orderings
787
                vec![
788
                    // [c ASC, b ASC]
789
                    vec![(col_c, option_asc), (col_b, option_asc)],
790
                ],
791
                // expected
792
                vec![
793
                    // [round_c_res ASC]
794
                    vec![(col_round_c_res, option_asc)],
795
                    // [c_new ASC, b_new ASC]
796
                    vec![(col_c_new, option_asc), (col_b_new, option_asc)],
797
                ],
798
            ),
799
            // ---------- TEST CASE 7 ------------
800
            (
801
                // orderings
802
                vec![
803
                    // [a+b ASC, c ASC]
804
                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
805
                ],
806
                // expected
807
                vec![
808
                    // [a+b ASC, round(c) ASC, c_new ASC]
809
                    vec![
810
                        (&a_new_plus_b_new, option_asc),
811
                        (col_round_c_res, option_asc),
812
                    ],
813
                    // [a+b ASC, c_new ASC]
814
                    vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)],
815
                ],
816
            ),
817
        ];
818
819
        for (idx, (orderings, expected)) in test_cases.iter().enumerate() {
820
            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
821
822
            let orderings = convert_to_orderings(orderings);
823
            eq_properties.add_new_orderings(orderings);
824
825
            let expected = convert_to_orderings(expected);
826
827
            let projected_eq =
828
                eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
829
            let orderings = projected_eq.oeq_class();
830
831
            let err_msg = format!(
832
                "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}",
833
                idx, orderings.orderings, expected, projection_mapping
834
            );
835
836
            assert_eq!(orderings.len(), expected.len(), "{}", err_msg);
837
            for expected_ordering in &expected {
838
                assert!(orderings.contains(expected_ordering), "{}", err_msg)
839
            }
840
        }
841
        Ok(())
842
    }
843
844
    /// Verifies how orderings are propagated through the projection
    /// `c -> c_new`, `d -> d_new`, `a + b -> "a+b"`, both without and with
    /// additional column equalities registered on the input properties.
    ///
    /// Each test case is a triple of (input orderings, equal column pairs,
    /// expected orderings after projection). The failure message carries the
    /// index of the failing case so that it can be located in the table below.
    #[test]
    fn project_orderings3() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Int32, true),
            Field::new("f", DataType::Int32, true),
        ]));
        let col_a = &col("a", &schema)?;
        let col_b = &col("b", &schema)?;
        let col_c = &col("c", &schema)?;
        let col_d = &col("d", &schema)?;
        let col_e = &col("e", &schema)?;
        let col_f = &col("f", &schema)?;
        let a_plus_b = Arc::new(BinaryExpr::new(
            Arc::clone(col_a),
            Operator::Plus,
            Arc::clone(col_b),
        )) as Arc<dyn PhysicalExpr>;

        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        // Projection under test: (source expression, output column name).
        let proj_exprs = vec![
            (Arc::clone(col_c), "c_new".to_string()),
            (Arc::clone(col_d), "d_new".to_string()),
            (Arc::clone(&a_plus_b), "a+b".to_string()),
        ];
        let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?;
        let output_schema = output_schema(&projection_mapping, &schema)?;

        // Columns of the projected (output) schema.
        let col_a_plus_b_new = &col("a+b", &output_schema)?;
        let col_c_new = &col("c_new", &output_schema)?;
        let col_d_new = &col("d_new", &output_schema)?;

        let test_cases = vec![
            // ---------- TEST CASE 1 ------------
            (
                // orderings
                vec![
                    // [d ASC, b ASC]
                    vec![(col_d, option_asc), (col_b, option_asc)],
                    // [c ASC, a ASC]
                    vec![(col_c, option_asc), (col_a, option_asc)],
                ],
                // equal conditions
                vec![],
                // expected
                vec![
                    // [d_new ASC, c_new ASC, a+b ASC]
                    vec![
                        (col_d_new, option_asc),
                        (col_c_new, option_asc),
                        (col_a_plus_b_new, option_asc),
                    ],
                    // [c_new ASC, d_new ASC, a+b ASC]
                    vec![
                        (col_c_new, option_asc),
                        (col_d_new, option_asc),
                        (col_a_plus_b_new, option_asc),
                    ],
                ],
            ),
            // ---------- TEST CASE 2 ------------
            (
                // orderings
                vec![
                    // [d ASC, b ASC]
                    vec![(col_d, option_asc), (col_b, option_asc)],
                    // [c ASC, e ASC], Please note that a=e
                    vec![(col_c, option_asc), (col_e, option_asc)],
                ],
                // equal conditions
                vec![(col_e, col_a)],
                // expected
                vec![
                    // [d_new ASC, c_new ASC, a+b ASC]
                    vec![
                        (col_d_new, option_asc),
                        (col_c_new, option_asc),
                        (col_a_plus_b_new, option_asc),
                    ],
                    // [c_new ASC, d_new ASC, a+b ASC]
                    vec![
                        (col_c_new, option_asc),
                        (col_d_new, option_asc),
                        (col_a_plus_b_new, option_asc),
                    ],
                ],
            ),
            // ---------- TEST CASE 3 ------------
            (
                // orderings
                vec![
                    // [d ASC, b ASC]
                    vec![(col_d, option_asc), (col_b, option_asc)],
                    // [c ASC, e ASC], Please note that a=f
                    vec![(col_c, option_asc), (col_e, option_asc)],
                ],
                // equal conditions
                vec![(col_a, col_f)],
                // expected
                vec![
                    // [d_new ASC]
                    vec![(col_d_new, option_asc)],
                    // [c_new ASC]
                    vec![(col_c_new, option_asc)],
                ],
            ),
        ];

        for (idx, (orderings, equal_columns, expected)) in
            test_cases.into_iter().enumerate()
        {
            // Start from fresh properties for every case so that cases do not
            // influence each other.
            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
            for (lhs, rhs) in equal_columns {
                eq_properties.add_equal_conditions(lhs, rhs)?;
            }

            let orderings = convert_to_orderings(&orderings);
            eq_properties.add_new_orderings(orderings);

            let expected = convert_to_orderings(&expected);

            let projected_eq =
                eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
            let orderings = projected_eq.oeq_class();

            let err_msg = format!(
                "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}",
                idx, orderings.orderings, expected, projection_mapping
            );

            // The projected orderings must match the expected set exactly:
            // same cardinality, and every expected ordering must be present.
            assert_eq!(orderings.len(), expected.len(), "{}", err_msg);
            for expected_ordering in &expected {
                assert!(orderings.contains(expected_ordering), "{}", err_msg)
            }
        }

        Ok(())
    }
990
991
    /// Randomized check of ordering projection: for a number of randomly
    /// generated schemas/properties and for every subset of a fixed list of
    /// candidate projection expressions, each ordering reported by the
    /// projected equivalence properties must actually hold on the projected
    /// data (i.e. sorting the projected batch by it must be a no-op).
    #[test]
    fn project_orderings_random() -> Result<()> {
        const N_RANDOM_SCHEMA: usize = 20;
        const N_ELEMENTS: usize = 125;
        const N_DISTINCT: usize = 5;

        for seed in 0..N_RANDOM_SCHEMA {
            // Random schema together with random properties over it.
            let (schema, input_eq) = create_random_schema(seed as u64)?;
            // Input data generated to satisfy those properties.
            let input_batch =
                generate_table_for_eq_properties(&input_eq, N_ELEMENTS, N_DISTINCT)?;

            // Test UDF applied to column `a` (projected as "floor(a)").
            let udf = ScalarUDF::new_from_impl(TestScalarUDF::new());
            let floor_a = create_physical_expr(
                &udf,
                &[col("a", &schema)?],
                &schema,
                &[],
                &DFSchema::empty(),
            )?;
            // Binary expression `a + b`.
            let a_plus_b: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
                col("a", &schema)?,
                Operator::Plus,
                col("b", &schema)?,
            ));

            // All candidate (expression, output name) pairs.
            let candidates = vec![
                (col("a", &schema)?, "a_new"),
                (col("b", &schema)?, "b_new"),
                (col("c", &schema)?, "c_new"),
                (col("d", &schema)?, "d_new"),
                (col("e", &schema)?, "e_new"),
                (col("f", &schema)?, "f_new"),
                (floor_a, "floor(a)"),
                (a_plus_b, "a+b"),
            ];

            // Every subset of the candidates, in order of increasing size.
            let subsets =
                (0..=candidates.len()).flat_map(|size| candidates.iter().combinations(size));

            for subset in subsets {
                let proj_exprs: Vec<_> = subset
                    .into_iter()
                    .map(|(expr, name)| (Arc::clone(expr), name.to_string()))
                    .collect();
                let (projected_batch, projected_eq) =
                    apply_projection(proj_exprs.clone(), &input_batch, &input_eq)?;

                // Each ordering claimed after projection has to be valid: an
                // already-ordered table is unchanged by sorting on it.
                for ordering in projected_eq.oeq_class().iter() {
                    assert!(
                        is_table_same_after_sort(ordering.clone(), projected_batch.clone())?,
                        "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}",
                        ordering, input_eq.oeq_class, input_eq.eq_group, input_eq.constants, proj_exprs
                    );
                }
            }
        }

        Ok(())
    }
1064
1065
    /// Randomized cross-check of `ordering_satisfy` on projected properties:
    /// for random schemas/properties and every subset of a fixed list of
    /// candidate projection expressions, every combination of projected
    /// output expressions is used as an ascending, nulls-last sort
    /// requirement. The answer of `ordering_satisfy` must agree with the
    /// experimental result obtained by sorting the projected data.
    #[test]
    fn ordering_satisfy_after_projection_random() -> Result<()> {
        const N_RANDOM_SCHEMA: usize = 20;
        const N_ELEMENTS: usize = 125;
        const N_DISTINCT: usize = 5;
        const SORT_OPTIONS: SortOptions = SortOptions {
            descending: false,
            nulls_first: false,
        };

        for seed in 0..N_RANDOM_SCHEMA {
            // Random schema together with random properties over it.
            let (schema, input_eq) = create_random_schema(seed as u64)?;
            // Input data generated to satisfy those properties.
            let input_batch =
                generate_table_for_eq_properties(&input_eq, N_ELEMENTS, N_DISTINCT)?;

            // Test UDF applied to column `a` (projected as "floor(a)").
            let udf = ScalarUDF::new_from_impl(TestScalarUDF::new());
            let floor_a = create_physical_expr(
                &udf,
                &[col("a", &schema)?],
                &schema,
                &[],
                &DFSchema::empty(),
            )?;
            // Binary expression `a + b`.
            let a_plus_b: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
                col("a", &schema)?,
                Operator::Plus,
                col("b", &schema)?,
            ));

            // All candidate (expression, output name) pairs.
            let candidates = vec![
                (col("a", &schema)?, "a_new"),
                (col("b", &schema)?, "b_new"),
                (col("c", &schema)?, "c_new"),
                (col("d", &schema)?, "d_new"),
                (col("e", &schema)?, "e_new"),
                (col("f", &schema)?, "f_new"),
                (floor_a, "floor(a)"),
                (a_plus_b, "a+b"),
            ];

            // Every subset of the candidates, in order of increasing size.
            let subsets =
                (0..=candidates.len()).flat_map(|size| candidates.iter().combinations(size));

            for subset in subsets {
                let proj_exprs: Vec<_> = subset
                    .into_iter()
                    .map(|(expr, name)| (Arc::clone(expr), name.to_string()))
                    .collect();
                let (projected_batch, projected_eq) =
                    apply_projection(proj_exprs.clone(), &input_batch, &input_eq)?;

                let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?;

                // Target (output-side) expressions of the projection.
                let projected_exprs: Vec<_> = projection_mapping
                    .iter()
                    .map(|(_source, target)| Arc::clone(target))
                    .collect();

                // Every combination of output expressions forms a requirement.
                let requirement_exprs = (0..=projected_exprs.len())
                    .flat_map(|size| projected_exprs.iter().combinations(size));

                for exprs in requirement_exprs {
                    let requirement: Vec<_> = exprs
                        .into_iter()
                        .map(|expr| PhysicalSortExpr {
                            expr: Arc::clone(expr),
                            options: SORT_OPTIONS,
                        })
                        .collect();

                    // Ground truth: is the projected data already sorted
                    // according to the requirement?
                    let expected =
                        is_table_same_after_sort(requirement.clone(), projected_batch.clone())?;
                    let err_msg = format!(
                        "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, projected_eq.constants: {:?}, projection_mapping: {:?}",
                        requirement, expected, input_eq.oeq_class, input_eq.eq_group, input_eq.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping
                    );

                    // The `ordering_satisfy` API must agree with the
                    // experimental result.
                    let actual = projected_eq.ordering_satisfy(&requirement);
                    assert_eq!(actual, expected, "{}", err_msg);
                }
            }
        }

        Ok(())
    }
1160
}