Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/equivalence/class.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use std::fmt::Display;
19
use std::sync::Arc;
20
21
use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping};
22
use crate::{
23
    expressions::Column, physical_expr::deduplicate_physical_exprs,
24
    physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexOrderingRef,
25
    LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr,
26
    PhysicalSortRequirement,
27
};
28
29
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
30
use datafusion_common::JoinType;
31
use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
32
33
#[derive(Debug, Clone)]
34
/// A structure representing a expression known to be constant in a physical execution plan.
35
///
36
/// The `ConstExpr` struct encapsulates an expression that is constant during the execution
37
/// of a query. For example if a predicate like `A = 5` applied earlier in the plan `A` would
38
/// be known constant
39
///
40
/// # Fields
41
///
42
/// - `expr`: Constant expression for a node in the physical plan.
43
///
44
/// - `across_partitions`: A boolean flag indicating whether the constant expression is
45
///   valid across partitions. If set to `true`, the constant expression has same value for all partitions.
46
///   If set to `false`, the constant expression may have different values for different partitions.
47
///
48
/// # Example
49
///
50
/// ```rust
51
/// # use datafusion_physical_expr::ConstExpr;
52
/// # use datafusion_physical_expr::expressions::lit;
53
/// let col = lit(5);
54
/// // Create a constant expression from a physical expression ref
55
/// let const_expr = ConstExpr::from(&col);
56
/// // create a constant expression from a physical expression
57
/// let const_expr = ConstExpr::from(col);
58
/// ```
59
pub struct ConstExpr {
60
    expr: Arc<dyn PhysicalExpr>,
61
    across_partitions: bool,
62
}
63
64
impl ConstExpr {
65
    /// Create a new constant expression from a physical expression.
66
    ///
67
    /// Note you can also use `ConstExpr::from` to create a constant expression
68
    /// from a reference as well
69
549
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
70
549
        Self {
71
549
            expr,
72
549
            // By default, assume constant expressions are not same across partitions.
73
549
            across_partitions: false,
74
549
        }
75
549
    }
76
77
280
    pub fn with_across_partitions(mut self, across_partitions: bool) -> Self {
78
280
        self.across_partitions = across_partitions;
79
280
        self
80
280
    }
81
82
276
    pub fn across_partitions(&self) -> bool {
83
276
        self.across_partitions
84
276
    }
85
86
1.36k
    pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
87
1.36k
        &self.expr
88
1.36k
    }
89
90
276
    pub fn owned_expr(self) -> Arc<dyn PhysicalExpr> {
91
276
        self.expr
92
276
    }
93
94
0
    pub fn map<F>(&self, f: F) -> Option<Self>
95
0
    where
96
0
        F: Fn(&Arc<dyn PhysicalExpr>) -> Option<Arc<dyn PhysicalExpr>>,
97
0
    {
98
0
        let maybe_expr = f(&self.expr);
99
0
        maybe_expr.map(|expr| Self {
100
0
            expr,
101
0
            across_partitions: self.across_partitions,
102
0
        })
103
0
    }
104
}
105
106
/// Display implementation for `ConstExpr`
107
///
108
/// Example `c` or `c(across_partitions)`
109
impl Display for ConstExpr {
110
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
111
0
        write!(f, "{}", self.expr)?;
112
0
        if self.across_partitions {
113
0
            write!(f, "(across_partitions)")?;
114
0
        }
115
0
        Ok(())
116
0
    }
117
}
118
119
impl From<Arc<dyn PhysicalExpr>> for ConstExpr {
120
337
    fn from(expr: Arc<dyn PhysicalExpr>) -> Self {
121
337
        Self::new(expr)
122
337
    }
123
}
124
125
impl From<&Arc<dyn PhysicalExpr>> for ConstExpr {
126
209
    fn from(expr: &Arc<dyn PhysicalExpr>) -> Self {
127
209
        Self::new(Arc::clone(expr))
128
209
    }
129
}
130
131
/// Checks whether `expr` is among in the `const_exprs`.
132
2.46k
pub fn const_exprs_contains(
133
2.46k
    const_exprs: &[ConstExpr],
134
2.46k
    expr: &Arc<dyn PhysicalExpr>,
135
2.46k
) -> bool {
136
2.46k
    const_exprs
137
2.46k
        .iter()
138
2.46k
        .any(|const_expr| 
const_expr.expr.eq(expr)191
)
139
2.46k
}
140
141
/// An `EquivalenceClass` is a set of [`Arc<dyn PhysicalExpr>`]s that are known
142
/// to have the same value for all tuples in a relation. These are generated by
143
/// equality predicates (e.g. `a = b`), typically equi-join conditions and
144
/// equality conditions in filters.
145
///
146
/// Two `EquivalenceClass`es are equal if they contains the same expressions in
147
/// without any ordering.
148
#[derive(Debug, Clone)]
149
pub struct EquivalenceClass {
150
    /// The expressions in this equivalence class. The order doesn't
151
    /// matter for equivalence purposes
152
    ///
153
    /// TODO: use a HashSet for this instead of a Vec
154
    exprs: Vec<Arc<dyn PhysicalExpr>>,
155
}
156
157
impl PartialEq for EquivalenceClass {
158
    /// Returns true if other is equal in the sense
159
    /// of bags (multi-sets), disregarding their orderings.
160
0
    fn eq(&self, other: &Self) -> bool {
161
0
        physical_exprs_bag_equal(&self.exprs, &other.exprs)
162
0
    }
163
}
164
165
impl EquivalenceClass {
166
    /// Create a new empty equivalence class
167
0
    pub fn new_empty() -> Self {
168
0
        Self { exprs: vec![] }
169
0
    }
170
171
    // Create a new equivalence class from a pre-existing `Vec`
172
183
    pub fn new(mut exprs: Vec<Arc<dyn PhysicalExpr>>) -> Self {
173
183
        deduplicate_physical_exprs(&mut exprs);
174
183
        Self { exprs }
175
183
    }
176
177
    /// Return the inner vector of expressions
178
5
    pub fn into_vec(self) -> Vec<Arc<dyn PhysicalExpr>> {
179
5
        self.exprs
180
5
    }
181
182
    /// Return the "canonical" expression for this class (the first element)
183
    /// if any
184
45
    fn canonical_expr(&self) -> Option<Arc<dyn PhysicalExpr>> {
185
45
        self.exprs.first().cloned()
186
45
    }
187
188
    /// Insert the expression into this class, meaning it is known to be equal to
189
    /// all other expressions in this class
190
0
    pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) {
191
0
        if !self.contains(&expr) {
192
0
            self.exprs.push(expr);
193
0
        }
194
0
    }
195
196
    /// Inserts all the expressions from other into this class
197
0
    pub fn extend(&mut self, other: Self) {
198
0
        for expr in other.exprs {
199
0
            // use push so entries are deduplicated
200
0
            self.push(expr);
201
0
        }
202
0
    }
203
204
    /// Returns true if this equivalence class contains t expression
205
118
    pub fn contains(&self, expr: &Arc<dyn PhysicalExpr>) -> bool {
206
118
        physical_exprs_contains(&self.exprs, expr)
207
118
    }
208
209
    /// Returns true if this equivalence class has any entries in common with `other`
210
15
    pub fn contains_any(&self, other: &Self) -> bool {
211
30
        self.exprs.iter().any(|e| other.contains(e))
212
15
    }
213
214
    /// return the number of items in this class
215
534
    pub fn len(&self) -> usize {
216
534
        self.exprs.len()
217
534
    }
218
219
    /// return true if this class is empty
220
0
    pub fn is_empty(&self) -> bool {
221
0
        self.exprs.is_empty()
222
0
    }
223
224
    /// Iterate over all elements in this class, in some arbitrary order
225
0
    pub fn iter(&self) -> impl Iterator<Item = &Arc<dyn PhysicalExpr>> {
226
0
        self.exprs.iter()
227
0
    }
228
229
    /// Return a new equivalence class that have the specified offset added to
230
    /// each expression (used when schemas are appended such as in joins)
231
0
    pub fn with_offset(&self, offset: usize) -> Self {
232
0
        let new_exprs = self
233
0
            .exprs
234
0
            .iter()
235
0
            .cloned()
236
0
            .map(|e| add_offset_to_expr(e, offset))
237
0
            .collect();
238
0
        Self::new(new_exprs)
239
0
    }
240
}
241
242
impl Display for EquivalenceClass {
243
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
244
0
        write!(f, "[{}]", format_physical_expr_list(&self.exprs))
245
0
    }
246
}
247
248
/// An `EquivalenceGroup` is a collection of `EquivalenceClass`es where each
249
/// class represents a distinct equivalence class in a relation.
250
#[derive(Debug, Clone)]
251
pub struct EquivalenceGroup {
252
    pub classes: Vec<EquivalenceClass>,
253
}
254
255
impl EquivalenceGroup {
256
    /// Creates an empty equivalence group.
257
3.11k
    pub fn empty() -> Self {
258
3.11k
        Self { classes: vec![] }
259
3.11k
    }
260
261
    /// Creates an equivalence group from the given equivalence classes.
262
669
    pub fn new(classes: Vec<EquivalenceClass>) -> Self {
263
669
        let mut result = Self { classes };
264
669
        result.remove_redundant_entries();
265
669
        result
266
669
    }
267
268
    /// Returns how many equivalence classes there are in this group.
269
0
    pub fn len(&self) -> usize {
270
0
        self.classes.len()
271
0
    }
272
273
    /// Checks whether this equivalence group is empty.
274
0
    pub fn is_empty(&self) -> bool {
275
0
        self.len() == 0
276
0
    }
277
278
    /// Returns an iterator over the equivalence classes in this group.
279
7.73k
    pub fn iter(&self) -> impl Iterator<Item = &EquivalenceClass> {
280
7.73k
        self.classes.iter()
281
7.73k
    }
282
283
    /// Adds the equality `left` = `right` to this equivalence group.
284
    /// New equality conditions often arise after steps like `Filter(a = b)`,
285
    /// `Alias(a, a as b)` etc.
286
183
    pub fn add_equal_conditions(
287
183
        &mut self,
288
183
        left: &Arc<dyn PhysicalExpr>,
289
183
        right: &Arc<dyn PhysicalExpr>,
290
183
    ) {
291
183
        let mut first_class = None;
292
183
        let mut second_class = None;
293
183
        for (
idx, cls15
) in self.classes.iter().enumerate() {
294
15
            if cls.contains(left) {
295
0
                first_class = Some(idx);
296
15
            }
297
15
            if cls.contains(right) {
298
0
                second_class = Some(idx);
299
15
            }
300
        }
301
183
        match (first_class, second_class) {
302
0
            (Some(mut first_idx), Some(mut second_idx)) => {
303
0
                // If the given left and right sides belong to different classes,
304
0
                // we should unify/bridge these classes.
305
0
                if first_idx != second_idx {
306
                    // By convention, make sure `second_idx` is larger than `first_idx`.
307
0
                    if first_idx > second_idx {
308
0
                        (first_idx, second_idx) = (second_idx, first_idx);
309
0
                    }
310
                    // Remove the class at `second_idx` and merge its values with
311
                    // the class at `first_idx`. The convention above makes sure
312
                    // that `first_idx` is still valid after removing `second_idx`.
313
0
                    let other_class = self.classes.swap_remove(second_idx);
314
0
                    self.classes[first_idx].extend(other_class);
315
0
                }
316
            }
317
0
            (Some(group_idx), None) => {
318
0
                // Right side is new, extend left side's class:
319
0
                self.classes[group_idx].push(Arc::clone(right));
320
0
            }
321
0
            (None, Some(group_idx)) => {
322
0
                // Left side is new, extend right side's class:
323
0
                self.classes[group_idx].push(Arc::clone(left));
324
0
            }
325
183
            (None, None) => {
326
183
                // None of the expressions is among existing classes.
327
183
                // Create a new equivalence class and extend the group.
328
183
                self.classes.push(EquivalenceClass::new(vec![
329
183
                    Arc::clone(left),
330
183
                    Arc::clone(right),
331
183
                ]));
332
183
            }
333
        }
334
183
    }
335
336
    /// Removes redundant entries from this group.
337
1.82k
    fn remove_redundant_entries(&mut self) {
338
1.82k
        // Remove duplicate entries from each equivalence class:
339
1.82k
        self.classes.retain_mut(|cls| {
340
178
            // Keep groups that have at least two entries as singleton class is
341
178
            // meaningless (i.e. it contains no non-trivial information):
342
178
            cls.len() > 1
343
1.82k
        });
344
1.82k
        // Unify/bridge groups that have common expressions:
345
1.82k
        self.bridge_classes()
346
1.82k
    }
347
348
    /// This utility function unifies/bridges classes that have common expressions.
349
    /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`.
350
    /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all
351
    /// equal and belong to one class. This utility converts merges such classes.
352
1.82k
    fn bridge_classes(&mut self) {
353
1.82k
        let mut idx = 0;
354
1.99k
        while idx < self.classes.len() {
355
178
            let mut next_idx = idx + 1;
356
178
            let start_size = self.classes[idx].len();
357
193
            while next_idx < self.classes.len() {
358
15
                if self.classes[idx].contains_any(&self.classes[next_idx]) {
359
0
                    let extension = self.classes.swap_remove(next_idx);
360
0
                    self.classes[idx].extend(extension);
361
15
                } else {
362
15
                    next_idx += 1;
363
15
                }
364
            }
365
178
            if self.classes[idx].len() > start_size {
366
0
                continue;
367
178
            }
368
178
            idx += 1;
369
        }
370
1.82k
    }
371
372
    /// Extends this equivalence group with the `other` equivalence group.
373
1.15k
    pub fn extend(&mut self, other: Self) {
374
1.15k
        self.classes.extend(other.classes);
375
1.15k
        self.remove_redundant_entries();
376
1.15k
    }
377
378
    /// Normalizes the given physical expression according to this group.
379
    /// The expression is replaced with the first expression in the equivalence
380
    /// class it matches with (if any).
381
6.44k
    pub fn normalize_expr(&self, expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
382
6.44k
        Arc::clone(&expr)
383
6.44k
            .transform(|expr| {
384
6.44k
                for 
cls53
in self.iter() {
385
53
                    if cls.contains(&expr) {
386
45
                        return Ok(Transformed::yes(cls.canonical_expr().unwrap()));
387
8
                    }
388
                }
389
6.40k
                Ok(Transformed::no(expr))
390
6.44k
            })
391
6.44k
            .data()
392
6.44k
            .unwrap_or(expr)
393
6.44k
    }
394
395
    /// Normalizes the given sort expression according to this group.
396
    /// The underlying physical expression is replaced with the first expression
397
    /// in the equivalence class it matches with (if any). If the underlying
398
    /// expression does not belong to any equivalence class in this group, returns
399
    /// the sort expression as is.
400
0
    pub fn normalize_sort_expr(
401
0
        &self,
402
0
        mut sort_expr: PhysicalSortExpr,
403
0
    ) -> PhysicalSortExpr {
404
0
        sort_expr.expr = self.normalize_expr(sort_expr.expr);
405
0
        sort_expr
406
0
    }
407
408
    /// Normalizes the given sort requirement according to this group.
409
    /// The underlying physical expression is replaced with the first expression
410
    /// in the equivalence class it matches with (if any). If the underlying
411
    /// expression does not belong to any equivalence class in this group, returns
412
    /// the given sort requirement as is.
413
3.42k
    pub fn normalize_sort_requirement(
414
3.42k
        &self,
415
3.42k
        mut sort_requirement: PhysicalSortRequirement,
416
3.42k
    ) -> PhysicalSortRequirement {
417
3.42k
        sort_requirement.expr = self.normalize_expr(sort_requirement.expr);
418
3.42k
        sort_requirement
419
3.42k
    }
420
421
    /// This function applies the `normalize_expr` function for all expressions
422
    /// in `exprs` and returns the corresponding normalized physical expressions.
423
2.58k
    pub fn normalize_exprs(
424
2.58k
        &self,
425
2.58k
        exprs: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
426
2.58k
    ) -> Vec<Arc<dyn PhysicalExpr>> {
427
2.58k
        exprs
428
2.58k
            .into_iter()
429
2.58k
            .map(|expr| 
self.normalize_expr(expr)1.63k
)
430
2.58k
            .collect()
431
2.58k
    }
432
433
    /// This function applies the `normalize_sort_expr` function for all sort
434
    /// expressions in `sort_exprs` and returns the corresponding normalized
435
    /// sort expressions.
436
0
    pub fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering {
437
0
        // Convert sort expressions to sort requirements:
438
0
        let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter());
439
0
        // Normalize the requirements:
440
0
        let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs);
441
0
        // Convert sort requirements back to sort expressions:
442
0
        PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs.inner)
443
0
    }
444
445
    /// This function applies the `normalize_sort_requirement` function for all
446
    /// requirements in `sort_reqs` and returns the corresponding normalized
447
    /// sort requirements.
448
1.25k
    pub fn normalize_sort_requirements(
449
1.25k
        &self,
450
1.25k
        sort_reqs: LexRequirementRef,
451
1.25k
    ) -> LexRequirement {
452
1.25k
        collapse_lex_req(LexRequirement::new(
453
1.25k
            sort_reqs
454
1.25k
                .iter()
455
3.42k
                .map(|sort_req| self.normalize_sort_requirement(sort_req.clone()))
456
1.25k
                .collect(),
457
1.25k
        ))
458
1.25k
    }
459
460
    /// Projects `expr` according to the given projection mapping.
461
    /// If the resulting expression is invalid after projection, returns `None`.
462
1
    pub fn project_expr(
463
1
        &self,
464
1
        mapping: &ProjectionMapping,
465
1
        expr: &Arc<dyn PhysicalExpr>,
466
1
    ) -> Option<Arc<dyn PhysicalExpr>> {
467
        // First, we try to project expressions with an exact match. If we are
468
        // unable to do this, we consult equivalence classes.
469
1
        if let Some(target) = mapping.target_expr(expr) {
470
            // If we match the source, we can project directly:
471
1
            return Some(target);
472
        } else {
473
            // If the given expression is not inside the mapping, try to project
474
            // expressions considering the equivalence classes.
475
0
            for (source, target) in mapping.iter() {
476
                // If we match an equivalent expression to `source`, then we can
477
                // project. For example, if we have the mapping `(a as a1, a + c)`
478
                // and the equivalence class `(a, b)`, expression `b` projects to `a1`.
479
0
                if self
480
0
                    .get_equivalence_class(source)
481
0
                    .map_or(false, |group| group.contains(expr))
482
                {
483
0
                    return Some(Arc::clone(target));
484
0
                }
485
            }
486
        }
487
        // Project a non-leaf expression by projecting its children.
488
0
        let children = expr.children();
489
0
        if children.is_empty() {
490
            // Leaf expression should be inside mapping.
491
0
            return None;
492
0
        }
493
0
        children
494
0
            .into_iter()
495
0
            .map(|child| self.project_expr(mapping, child))
496
0
            .collect::<Option<Vec<_>>>()
497
0
            .map(|children| Arc::clone(expr).with_new_children(children).unwrap())
498
1
    }
499
500
    /// Projects this equivalence group according to the given projection mapping.
501
52
    pub fn project(&self, mapping: &ProjectionMapping) -> Self {
502
52
        let projected_classes = self.iter().filter_map(|cls| {
503
0
            let new_class = cls
504
0
                .iter()
505
0
                .filter_map(|expr| self.project_expr(mapping, expr))
506
0
                .collect::<Vec<_>>();
507
0
            (new_class.len() > 1).then_some(EquivalenceClass::new(new_class))
508
52
        });
509
52
        // TODO: Convert the algorithm below to a version that uses `HashMap`.
510
52
        //       once `Arc<dyn PhysicalExpr>` can be stored in `HashMap`.
511
52
        // See issue: https://github.com/apache/datafusion/issues/8027
512
52
        let mut new_classes = vec![];
513
61
        for (source, target) in 
mapping.iter()52
{
514
61
            if new_classes.is_empty() {
515
49
                new_classes.push((source, vec![Arc::clone(target)]));
516
49
            }
12
517
49
            if let Some((_, values)) =
518
61
                new_classes.iter_mut().find(|(key, _)| key.eq(source))
519
            {
520
49
                if !physical_exprs_contains(values, target) {
521
0
                    values.push(Arc::clone(target));
522
49
                }
523
12
            }
524
        }
525
        // Only add equivalence classes with at least two members as singleton
526
        // equivalence classes are meaningless.
527
52
        let new_classes = new_classes
528
52
            .into_iter()
529
52
            .filter_map(|(_, values)| 
(values.len() > 1).then_some(values)49
)
530
52
            .map(EquivalenceClass::new);
531
52
532
52
        let classes = projected_classes.chain(new_classes).collect();
533
52
        Self::new(classes)
534
52
    }
535
536
    /// Returns the equivalence class containing `expr`. If no equivalence class
537
    /// contains `expr`, returns `None`.
538
0
    fn get_equivalence_class(
539
0
        &self,
540
0
        expr: &Arc<dyn PhysicalExpr>,
541
0
    ) -> Option<&EquivalenceClass> {
542
0
        self.iter().find(|cls| cls.contains(expr))
543
0
    }
544
545
    /// Combine equivalence groups of the given join children.
546
1.14k
    pub fn join(
547
1.14k
        &self,
548
1.14k
        right_equivalences: &Self,
549
1.14k
        join_type: &JoinType,
550
1.14k
        left_size: usize,
551
1.14k
        on: &[(PhysicalExprRef, PhysicalExprRef)],
552
1.14k
    ) -> Self {
553
1.14k
        match join_type {
554
            JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
555
617
                let mut result = Self::new(
556
617
                    self.iter()
557
617
                        .cloned()
558
617
                        .chain(
559
617
                            right_equivalences
560
617
                                .iter()
561
617
                                .map(|cls| 
cls.with_offset(left_size)0
),
562
617
                        )
563
617
                        .collect(),
564
617
                );
565
617
                // In we have an inner join, expressions in the "on" condition
566
617
                // are equal in the resulting table.
567
617
                if join_type == &JoinType::Inner {
568
178
                    for (lhs, rhs) in 
on.iter()174
{
569
178
                        let new_lhs = Arc::clone(lhs) as _;
570
178
                        // Rewrite rhs to point to the right side of the join:
571
178
                        let new_rhs = Arc::clone(rhs)
572
178
                            .transform(|expr| {
573
178
                                if let Some(column) =
574
178
                                    expr.as_any().downcast_ref::<Column>()
575
                                {
576
178
                                    let new_column = Arc::new(Column::new(
577
178
                                        column.name(),
578
178
                                        column.index() + left_size,
579
178
                                    ))
580
178
                                        as _;
581
178
                                    return Ok(Transformed::yes(new_column));
582
0
                                }
583
0
584
0
                                Ok(Transformed::no(expr))
585
178
                            })
586
178
                            .data()
587
178
                            .unwrap();
588
178
                        result.add_equal_conditions(&new_lhs, &new_rhs);
589
178
                    }
590
443
                }
591
617
                result
592
            }
593
268
            JoinType::LeftSemi | JoinType::LeftAnti => self.clone(),
594
264
            JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(),
595
        }
596
1.14k
    }
597
}
598
599
impl Display for EquivalenceGroup {
600
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
601
0
        write!(f, "[")?;
602
0
        let mut iter = self.iter();
603
0
        if let Some(cls) = iter.next() {
604
0
            write!(f, "{}", cls)?;
605
0
        }
606
0
        for cls in iter {
607
0
            write!(f, ", {}", cls)?;
608
        }
609
0
        write!(f, "]")
610
0
    }
611
}
612
613
#[cfg(test)]
mod tests {

    use super::*;
    use crate::equivalence::tests::create_test_params;
    use crate::expressions::{lit, Literal};

    use datafusion_common::{Result, ScalarValue};

    /// Turns each inner vector of integers into an `EquivalenceClass` of
    /// literal expressions.
    fn literal_classes(groups: Vec<Vec<i32>>) -> Vec<EquivalenceClass> {
        groups
            .into_iter()
            .map(|group| EquivalenceClass::new(group.into_iter().map(lit).collect()))
            .collect()
    }

    #[test]
    fn test_bridge_groups() -> Result<()> {
        // Each case is (input classes, expected classes after bridging).
        // Classes are compared position by position; each class comparison
        // ignores the order of its members.
        let test_cases = vec![
            // ------- TEST CASE 1 -----------//
            (
                vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]],
                vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]],
            ),
            // ------- TEST CASE 2 -----------//
            (
                vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]],
                vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]],
            ),
        ];
        for (input, output) in test_cases {
            let entries = literal_classes(input);
            let expected = literal_classes(output);
            let mut group = EquivalenceGroup::new(entries.clone());
            group.bridge_classes();
            let actual = group.classes;
            let err_msg = format!(
                "error in test entries: {:?}, expected: {:?}, actual:{:?}",
                entries, expected, actual
            );
            assert_eq!(actual.len(), expected.len(), "{}", err_msg);
            for (got, want) in actual.iter().zip(&expected) {
                assert_eq!(got, want, "{}", err_msg);
            }
        }
        Ok(())
    }

    #[test]
    fn test_remove_redundant_entries_eq_group() -> Result<()> {
        // The input classes are not in succinct form: the first one holds a
        // duplicate, and the second collapses to a singleton, which carries no
        // information and must be removed.
        let entries = vec![
            EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]),
            EquivalenceClass::new(vec![lit(3), lit(3)]),
            EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]),
        ];
        // The plainest representation that is functionally the same.
        let expected = [
            EquivalenceClass::new(vec![lit(1), lit(2)]),
            EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]),
        ];
        let mut group = EquivalenceGroup::new(entries);
        group.remove_redundant_entries();

        let classes = group.classes;
        assert_eq!(classes.len(), expected.len());
        assert_eq!(classes.len(), 2);

        assert_eq!(classes[0], expected[0]);
        assert_eq!(classes[1], expected[1]);
        Ok(())
    }

    #[test]
    fn test_schema_normalize_expr_with_equivalence() -> Result<()> {
        // In the test parameters, columns `a` and `c` are aliases.
        let (_test_schema, eq_properties) = create_test_params()?;

        let col_a_expr = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
        let col_b_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
        let col_c_expr = Arc::new(Column::new("c", 2)) as Arc<dyn PhysicalExpr>;
        // Pairs of (input, expected normalized form). By convention every
        // member of an equivalence class normalizes to the class's first entry,
        // which is `a` here; `b` is in no class and stays unchanged.
        let expressions = [
            (&col_a_expr, &col_a_expr),
            (&col_c_expr, &col_a_expr),
            (&col_b_expr, &col_b_expr),
        ];
        let eq_group = eq_properties.eq_group();
        for (expr, expected_eq) in expressions {
            let normalized = eq_group.normalize_expr(Arc::clone(expr));
            assert!(
                expected_eq.eq(&normalized),
                "error in test: expr: {expr:?}"
            );
        }

        Ok(())
    }

    #[test]
    fn test_contains_any() {
        let lit_true: Arc<dyn PhysicalExpr> =
            Arc::new(Literal::new(ScalarValue::Boolean(Some(true))));
        let lit_false: Arc<dyn PhysicalExpr> =
            Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));
        let lit2: Arc<dyn PhysicalExpr> =
            Arc::new(Literal::new(ScalarValue::Int32(Some(2))));
        let lit1: Arc<dyn PhysicalExpr> =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
        let col_b_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("b", 1));

        let cls1 =
            EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]);
        let cls2 =
            EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&col_b_expr)]);
        let cls3 = EquivalenceClass::new(vec![Arc::clone(&lit2), Arc::clone(&lit1)]);

        // `lit_true` is shared by cls1 and cls2.
        assert!(cls1.contains_any(&cls2));
        // cls3 has nothing in common with either of the others.
        assert!(!cls1.contains_any(&cls3));
        assert!(!cls2.contains_any(&cls3));
    }
}