/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/equivalence/class.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use std::fmt::Display; |
19 | | use std::sync::Arc; |
20 | | |
21 | | use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping}; |
22 | | use crate::{ |
23 | | expressions::Column, physical_expr::deduplicate_physical_exprs, |
24 | | physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexOrderingRef, |
25 | | LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, |
26 | | PhysicalSortRequirement, |
27 | | }; |
28 | | |
29 | | use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; |
30 | | use datafusion_common::JoinType; |
31 | | use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; |
32 | | |
33 | | #[derive(Debug, Clone)] |
34 | | /// A structure representing an expression known to be constant in a physical execution plan. |
35 | | /// |
36 | | /// The `ConstExpr` struct encapsulates an expression that is constant during the execution |
37 | | /// of a query. For example, if a predicate like `A = 5` is applied earlier in the plan, `A` would |
38 | | /// be known to be constant. |
39 | | /// |
40 | | /// # Fields |
41 | | /// |
42 | | /// - `expr`: Constant expression for a node in the physical plan. |
43 | | /// |
44 | | /// - `across_partitions`: A boolean flag indicating whether the constant expression is |
45 | | /// valid across partitions. If set to `true`, the constant expression has the same value for all partitions. |
46 | | /// If set to `false`, the constant expression may have different values for different partitions. |
47 | | /// |
48 | | /// # Example |
49 | | /// |
50 | | /// ```rust |
51 | | /// # use datafusion_physical_expr::ConstExpr; |
52 | | /// # use datafusion_physical_expr::expressions::lit; |
53 | | /// let col = lit(5); |
54 | | /// // Create a constant expression from a physical expression ref |
55 | | /// let const_expr = ConstExpr::from(&col); |
56 | | /// // create a constant expression from a physical expression |
57 | | /// let const_expr = ConstExpr::from(col); |
58 | | /// ``` |
59 | | pub struct ConstExpr { |
60 | | expr: Arc<dyn PhysicalExpr>, |
61 | | across_partitions: bool, |
62 | | } |
63 | | |
64 | | impl ConstExpr { |
65 | | /// Create a new constant expression from a physical expression. |
66 | | /// |
67 | | /// Note you can also use `ConstExpr::from` to create a constant expression |
68 | | /// from a reference as well |
69 | 549 | pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self { |
70 | 549 | Self { |
71 | 549 | expr, |
72 | 549 | // By default, assume constant expressions are not same across partitions. |
73 | 549 | across_partitions: false, |
74 | 549 | } |
75 | 549 | } |
76 | | |
77 | 280 | pub fn with_across_partitions(mut self, across_partitions: bool) -> Self { |
78 | 280 | self.across_partitions = across_partitions; |
79 | 280 | self |
80 | 280 | } |
81 | | |
82 | 276 | pub fn across_partitions(&self) -> bool { |
83 | 276 | self.across_partitions |
84 | 276 | } |
85 | | |
86 | 1.36k | pub fn expr(&self) -> &Arc<dyn PhysicalExpr> { |
87 | 1.36k | &self.expr |
88 | 1.36k | } |
89 | | |
90 | 276 | pub fn owned_expr(self) -> Arc<dyn PhysicalExpr> { |
91 | 276 | self.expr |
92 | 276 | } |
93 | | |
94 | 0 | pub fn map<F>(&self, f: F) -> Option<Self> |
95 | 0 | where |
96 | 0 | F: Fn(&Arc<dyn PhysicalExpr>) -> Option<Arc<dyn PhysicalExpr>>, |
97 | 0 | { |
98 | 0 | let maybe_expr = f(&self.expr); |
99 | 0 | maybe_expr.map(|expr| Self { |
100 | 0 | expr, |
101 | 0 | across_partitions: self.across_partitions, |
102 | 0 | }) |
103 | 0 | } |
104 | | } |
105 | | |
106 | | /// Display implementation for `ConstExpr` |
107 | | /// |
108 | | /// Example `c` or `c(across_partitions)` |
109 | | impl Display for ConstExpr { |
110 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
111 | 0 | write!(f, "{}", self.expr)?; |
112 | 0 | if self.across_partitions { |
113 | 0 | write!(f, "(across_partitions)")?; |
114 | 0 | } |
115 | 0 | Ok(()) |
116 | 0 | } |
117 | | } |
118 | | |
119 | | impl From<Arc<dyn PhysicalExpr>> for ConstExpr { |
120 | 337 | fn from(expr: Arc<dyn PhysicalExpr>) -> Self { |
121 | 337 | Self::new(expr) |
122 | 337 | } |
123 | | } |
124 | | |
125 | | impl From<&Arc<dyn PhysicalExpr>> for ConstExpr { |
126 | 209 | fn from(expr: &Arc<dyn PhysicalExpr>) -> Self { |
127 | 209 | Self::new(Arc::clone(expr)) |
128 | 209 | } |
129 | | } |
130 | | |
131 | | /// Checks whether `expr` is among the `const_exprs`. |
132 | 2.46k | pub fn const_exprs_contains( |
133 | 2.46k | const_exprs: &[ConstExpr], |
134 | 2.46k | expr: &Arc<dyn PhysicalExpr>, |
135 | 2.46k | ) -> bool { |
136 | 2.46k | const_exprs |
137 | 2.46k | .iter() |
138 | 2.46k | .any(|const_expr| const_expr.expr.eq(expr)191 ) |
139 | 2.46k | } |
140 | | |
141 | | /// An `EquivalenceClass` is a set of [`Arc<dyn PhysicalExpr>`]s that are known |
142 | | /// to have the same value for all tuples in a relation. These are generated by |
143 | | /// equality predicates (e.g. `a = b`), typically equi-join conditions and |
144 | | /// equality conditions in filters. |
145 | | /// |
146 | | /// Two `EquivalenceClass`es are equal if they contain the same expressions, |
147 | | /// regardless of ordering. |
148 | | #[derive(Debug, Clone)] |
149 | | pub struct EquivalenceClass { |
150 | | /// The expressions in this equivalence class. The order doesn't |
151 | | /// matter for equivalence purposes |
152 | | /// |
153 | | /// TODO: use a HashSet for this instead of a Vec |
154 | | exprs: Vec<Arc<dyn PhysicalExpr>>, |
155 | | } |
156 | | |
157 | | impl PartialEq for EquivalenceClass { |
158 | | /// Returns true if other is equal in the sense |
159 | | /// of bags (multi-sets), disregarding their orderings. |
160 | 0 | fn eq(&self, other: &Self) -> bool { |
161 | 0 | physical_exprs_bag_equal(&self.exprs, &other.exprs) |
162 | 0 | } |
163 | | } |
164 | | |
165 | | impl EquivalenceClass { |
166 | | /// Create a new empty equivalence class |
167 | 0 | pub fn new_empty() -> Self { |
168 | 0 | Self { exprs: vec![] } |
169 | 0 | } |
170 | | |
171 | | /// Create a new equivalence class from a pre-existing `Vec` |
172 | 183 | pub fn new(mut exprs: Vec<Arc<dyn PhysicalExpr>>) -> Self { |
173 | 183 | deduplicate_physical_exprs(&mut exprs); |
174 | 183 | Self { exprs } |
175 | 183 | } |
176 | | |
177 | | /// Return the inner vector of expressions |
178 | 5 | pub fn into_vec(self) -> Vec<Arc<dyn PhysicalExpr>> { |
179 | 5 | self.exprs |
180 | 5 | } |
181 | | |
182 | | /// Return the "canonical" expression for this class (the first element) |
183 | | /// if any |
184 | 45 | fn canonical_expr(&self) -> Option<Arc<dyn PhysicalExpr>> { |
185 | 45 | self.exprs.first().cloned() |
186 | 45 | } |
187 | | |
188 | | /// Insert the expression into this class, meaning it is known to be equal to |
189 | | /// all other expressions in this class |
190 | 0 | pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) { |
191 | 0 | if !self.contains(&expr) { |
192 | 0 | self.exprs.push(expr); |
193 | 0 | } |
194 | 0 | } |
195 | | |
196 | | /// Inserts all the expressions from other into this class |
197 | 0 | pub fn extend(&mut self, other: Self) { |
198 | 0 | for expr in other.exprs { |
199 | 0 | // use push so entries are deduplicated |
200 | 0 | self.push(expr); |
201 | 0 | } |
202 | 0 | } |
203 | | |
204 | | /// Returns true if this equivalence class contains the expression |
205 | 118 | pub fn contains(&self, expr: &Arc<dyn PhysicalExpr>) -> bool { |
206 | 118 | physical_exprs_contains(&self.exprs, expr) |
207 | 118 | } |
208 | | |
209 | | /// Returns true if this equivalence class has any entries in common with `other` |
210 | 15 | pub fn contains_any(&self, other: &Self) -> bool { |
211 | 30 | self.exprs.iter().any(|e| other.contains(e)) |
212 | 15 | } |
213 | | |
214 | | /// return the number of items in this class |
215 | 534 | pub fn len(&self) -> usize { |
216 | 534 | self.exprs.len() |
217 | 534 | } |
218 | | |
219 | | /// return true if this class is empty |
220 | 0 | pub fn is_empty(&self) -> bool { |
221 | 0 | self.exprs.is_empty() |
222 | 0 | } |
223 | | |
224 | | /// Iterate over all elements in this class, in some arbitrary order |
225 | 0 | pub fn iter(&self) -> impl Iterator<Item = &Arc<dyn PhysicalExpr>> { |
226 | 0 | self.exprs.iter() |
227 | 0 | } |
228 | | |
229 | | /// Return a new equivalence class that has the specified offset added to |
230 | | /// each expression (used when schemas are appended such as in joins) |
231 | 0 | pub fn with_offset(&self, offset: usize) -> Self { |
232 | 0 | let new_exprs = self |
233 | 0 | .exprs |
234 | 0 | .iter() |
235 | 0 | .cloned() |
236 | 0 | .map(|e| add_offset_to_expr(e, offset)) |
237 | 0 | .collect(); |
238 | 0 | Self::new(new_exprs) |
239 | 0 | } |
240 | | } |
241 | | |
242 | | impl Display for EquivalenceClass { |
243 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
244 | 0 | write!(f, "[{}]", format_physical_expr_list(&self.exprs)) |
245 | 0 | } |
246 | | } |
247 | | |
248 | | /// An `EquivalenceGroup` is a collection of `EquivalenceClass`es where each |
249 | | /// class represents a distinct equivalence class in a relation. |
250 | | #[derive(Debug, Clone)] |
251 | | pub struct EquivalenceGroup { |
252 | | pub classes: Vec<EquivalenceClass>, |
253 | | } |
254 | | |
255 | | impl EquivalenceGroup { |
256 | | /// Creates an empty equivalence group. |
257 | 3.11k | pub fn empty() -> Self { |
258 | 3.11k | Self { classes: vec![] } |
259 | 3.11k | } |
260 | | |
261 | | /// Creates an equivalence group from the given equivalence classes. |
262 | 669 | pub fn new(classes: Vec<EquivalenceClass>) -> Self { |
263 | 669 | let mut result = Self { classes }; |
264 | 669 | result.remove_redundant_entries(); |
265 | 669 | result |
266 | 669 | } |
267 | | |
268 | | /// Returns how many equivalence classes there are in this group. |
269 | 0 | pub fn len(&self) -> usize { |
270 | 0 | self.classes.len() |
271 | 0 | } |
272 | | |
273 | | /// Checks whether this equivalence group is empty. |
274 | 0 | pub fn is_empty(&self) -> bool { |
275 | 0 | self.len() == 0 |
276 | 0 | } |
277 | | |
278 | | /// Returns an iterator over the equivalence classes in this group. |
279 | 7.73k | pub fn iter(&self) -> impl Iterator<Item = &EquivalenceClass> { |
280 | 7.73k | self.classes.iter() |
281 | 7.73k | } |
282 | | |
283 | | /// Adds the equality `left` = `right` to this equivalence group. |
284 | | /// New equality conditions often arise after steps like `Filter(a = b)`, |
285 | | /// `Alias(a, a as b)` etc. |
286 | 183 | pub fn add_equal_conditions( |
287 | 183 | &mut self, |
288 | 183 | left: &Arc<dyn PhysicalExpr>, |
289 | 183 | right: &Arc<dyn PhysicalExpr>, |
290 | 183 | ) { |
291 | 183 | let mut first_class = None; |
292 | 183 | let mut second_class = None; |
293 | 183 | for (idx, cls15 ) in self.classes.iter().enumerate() { |
294 | 15 | if cls.contains(left) { |
295 | 0 | first_class = Some(idx); |
296 | 15 | } |
297 | 15 | if cls.contains(right) { |
298 | 0 | second_class = Some(idx); |
299 | 15 | } |
300 | | } |
301 | 183 | match (first_class, second_class) { |
302 | 0 | (Some(mut first_idx), Some(mut second_idx)) => { |
303 | 0 | // If the given left and right sides belong to different classes, |
304 | 0 | // we should unify/bridge these classes. |
305 | 0 | if first_idx != second_idx { |
306 | | // By convention, make sure `second_idx` is larger than `first_idx`. |
307 | 0 | if first_idx > second_idx { |
308 | 0 | (first_idx, second_idx) = (second_idx, first_idx); |
309 | 0 | } |
310 | | // Remove the class at `second_idx` and merge its values with |
311 | | // the class at `first_idx`. The convention above makes sure |
312 | | // that `first_idx` is still valid after removing `second_idx`. |
313 | 0 | let other_class = self.classes.swap_remove(second_idx); |
314 | 0 | self.classes[first_idx].extend(other_class); |
315 | 0 | } |
316 | | } |
317 | 0 | (Some(group_idx), None) => { |
318 | 0 | // Right side is new, extend left side's class: |
319 | 0 | self.classes[group_idx].push(Arc::clone(right)); |
320 | 0 | } |
321 | 0 | (None, Some(group_idx)) => { |
322 | 0 | // Left side is new, extend right side's class: |
323 | 0 | self.classes[group_idx].push(Arc::clone(left)); |
324 | 0 | } |
325 | 183 | (None, None) => { |
326 | 183 | // None of the expressions is among existing classes. |
327 | 183 | // Create a new equivalence class and extend the group. |
328 | 183 | self.classes.push(EquivalenceClass::new(vec![ |
329 | 183 | Arc::clone(left), |
330 | 183 | Arc::clone(right), |
331 | 183 | ])); |
332 | 183 | } |
333 | | } |
334 | 183 | } |
335 | | |
336 | | /// Removes redundant entries from this group. |
337 | 1.82k | fn remove_redundant_entries(&mut self) { |
338 | 1.82k | // Remove classes with fewer than two entries from the group: |
339 | 1.82k | self.classes.retain_mut(|cls| { |
340 | 178 | // Keep groups that have at least two entries, as a singleton class is |
341 | 178 | // meaningless (i.e. it contains no non-trivial information): |
342 | 178 | cls.len() > 1 |
343 | 1.82k | }); |
344 | 1.82k | // Unify/bridge groups that have common expressions: |
345 | 1.82k | self.bridge_classes() |
346 | 1.82k | } |
347 | | |
348 | | /// This utility function unifies/bridges classes that have common expressions. |
349 | | /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`. |
350 | | /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all |
351 | | /// equal and belong to one class. This utility merges such classes. |
352 | 1.82k | fn bridge_classes(&mut self) { |
353 | 1.82k | let mut idx = 0; |
354 | 1.99k | while idx < self.classes.len() { |
355 | 178 | let mut next_idx = idx + 1; |
356 | 178 | let start_size = self.classes[idx].len(); |
357 | 193 | while next_idx < self.classes.len() { |
358 | 15 | if self.classes[idx].contains_any(&self.classes[next_idx]) { |
359 | 0 | let extension = self.classes.swap_remove(next_idx); |
360 | 0 | self.classes[idx].extend(extension); |
361 | 15 | } else { |
362 | 15 | next_idx += 1; |
363 | 15 | } |
364 | | } |
365 | 178 | if self.classes[idx].len() > start_size { |
366 | 0 | continue; |
367 | 178 | } |
368 | 178 | idx += 1; |
369 | | } |
370 | 1.82k | } |
371 | | |
372 | | /// Extends this equivalence group with the `other` equivalence group. |
373 | 1.15k | pub fn extend(&mut self, other: Self) { |
374 | 1.15k | self.classes.extend(other.classes); |
375 | 1.15k | self.remove_redundant_entries(); |
376 | 1.15k | } |
377 | | |
378 | | /// Normalizes the given physical expression according to this group. |
379 | | /// The expression is replaced with the first expression in the equivalence |
380 | | /// class it matches with (if any). |
381 | 6.44k | pub fn normalize_expr(&self, expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> { |
382 | 6.44k | Arc::clone(&expr) |
383 | 6.44k | .transform(|expr| { |
384 | 6.44k | for cls53 in self.iter() { |
385 | 53 | if cls.contains(&expr) { |
386 | 45 | return Ok(Transformed::yes(cls.canonical_expr().unwrap())); |
387 | 8 | } |
388 | | } |
389 | 6.40k | Ok(Transformed::no(expr)) |
390 | 6.44k | }) |
391 | 6.44k | .data() |
392 | 6.44k | .unwrap_or(expr) |
393 | 6.44k | } |
394 | | |
395 | | /// Normalizes the given sort expression according to this group. |
396 | | /// The underlying physical expression is replaced with the first expression |
397 | | /// in the equivalence class it matches with (if any). If the underlying |
398 | | /// expression does not belong to any equivalence class in this group, returns |
399 | | /// the sort expression as is. |
400 | 0 | pub fn normalize_sort_expr( |
401 | 0 | &self, |
402 | 0 | mut sort_expr: PhysicalSortExpr, |
403 | 0 | ) -> PhysicalSortExpr { |
404 | 0 | sort_expr.expr = self.normalize_expr(sort_expr.expr); |
405 | 0 | sort_expr |
406 | 0 | } |
407 | | |
408 | | /// Normalizes the given sort requirement according to this group. |
409 | | /// The underlying physical expression is replaced with the first expression |
410 | | /// in the equivalence class it matches with (if any). If the underlying |
411 | | /// expression does not belong to any equivalence class in this group, returns |
412 | | /// the given sort requirement as is. |
413 | 3.42k | pub fn normalize_sort_requirement( |
414 | 3.42k | &self, |
415 | 3.42k | mut sort_requirement: PhysicalSortRequirement, |
416 | 3.42k | ) -> PhysicalSortRequirement { |
417 | 3.42k | sort_requirement.expr = self.normalize_expr(sort_requirement.expr); |
418 | 3.42k | sort_requirement |
419 | 3.42k | } |
420 | | |
421 | | /// This function applies the `normalize_expr` function for all expressions |
422 | | /// in `exprs` and returns the corresponding normalized physical expressions. |
423 | 2.58k | pub fn normalize_exprs( |
424 | 2.58k | &self, |
425 | 2.58k | exprs: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>, |
426 | 2.58k | ) -> Vec<Arc<dyn PhysicalExpr>> { |
427 | 2.58k | exprs |
428 | 2.58k | .into_iter() |
429 | 2.58k | .map(|expr| self.normalize_expr(expr)1.63k ) |
430 | 2.58k | .collect() |
431 | 2.58k | } |
432 | | |
433 | | /// This function applies the `normalize_sort_expr` function for all sort |
434 | | /// expressions in `sort_exprs` and returns the corresponding normalized |
435 | | /// sort expressions. |
436 | 0 | pub fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { |
437 | 0 | // Convert sort expressions to sort requirements: |
438 | 0 | let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); |
439 | 0 | // Normalize the requirements: |
440 | 0 | let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); |
441 | 0 | // Convert sort requirements back to sort expressions: |
442 | 0 | PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs.inner) |
443 | 0 | } |
444 | | |
445 | | /// This function applies the `normalize_sort_requirement` function for all |
446 | | /// requirements in `sort_reqs` and returns the corresponding normalized |
447 | | /// sort requirements. |
448 | 1.25k | pub fn normalize_sort_requirements( |
449 | 1.25k | &self, |
450 | 1.25k | sort_reqs: LexRequirementRef, |
451 | 1.25k | ) -> LexRequirement { |
452 | 1.25k | collapse_lex_req(LexRequirement::new( |
453 | 1.25k | sort_reqs |
454 | 1.25k | .iter() |
455 | 3.42k | .map(|sort_req| self.normalize_sort_requirement(sort_req.clone())) |
456 | 1.25k | .collect(), |
457 | 1.25k | )) |
458 | 1.25k | } |
459 | | |
460 | | /// Projects `expr` according to the given projection mapping. |
461 | | /// If the resulting expression is invalid after projection, returns `None`. |
462 | 1 | pub fn project_expr( |
463 | 1 | &self, |
464 | 1 | mapping: &ProjectionMapping, |
465 | 1 | expr: &Arc<dyn PhysicalExpr>, |
466 | 1 | ) -> Option<Arc<dyn PhysicalExpr>> { |
467 | | // First, we try to project expressions with an exact match. If we are |
468 | | // unable to do this, we consult equivalence classes. |
469 | 1 | if let Some(target) = mapping.target_expr(expr) { |
470 | | // If we match the source, we can project directly: |
471 | 1 | return Some(target); |
472 | | } else { |
473 | | // If the given expression is not inside the mapping, try to project |
474 | | // expressions considering the equivalence classes. |
475 | 0 | for (source, target) in mapping.iter() { |
476 | | // If we match an equivalent expression to `source`, then we can |
477 | | // project. For example, if we have the mapping `(a as a1, a + c)` |
478 | | // and the equivalence class `(a, b)`, expression `b` projects to `a1`. |
479 | 0 | if self |
480 | 0 | .get_equivalence_class(source) |
481 | 0 | .map_or(false, |group| group.contains(expr)) |
482 | | { |
483 | 0 | return Some(Arc::clone(target)); |
484 | 0 | } |
485 | | } |
486 | | } |
487 | | // Project a non-leaf expression by projecting its children. |
488 | 0 | let children = expr.children(); |
489 | 0 | if children.is_empty() { |
490 | | // Leaf expression should be inside mapping. |
491 | 0 | return None; |
492 | 0 | } |
493 | 0 | children |
494 | 0 | .into_iter() |
495 | 0 | .map(|child| self.project_expr(mapping, child)) |
496 | 0 | .collect::<Option<Vec<_>>>() |
497 | 0 | .map(|children| Arc::clone(expr).with_new_children(children).unwrap()) |
498 | 1 | } |
499 | | |
500 | | /// Projects this equivalence group according to the given projection mapping. |
501 | 52 | pub fn project(&self, mapping: &ProjectionMapping) -> Self { |
502 | 52 | let projected_classes = self.iter().filter_map(|cls| { |
503 | 0 | let new_class = cls |
504 | 0 | .iter() |
505 | 0 | .filter_map(|expr| self.project_expr(mapping, expr)) |
506 | 0 | .collect::<Vec<_>>(); |
507 | 0 | (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) |
508 | 52 | }); |
509 | 52 | // TODO: Convert the algorithm below to a version that uses `HashMap`. |
510 | 52 | // once `Arc<dyn PhysicalExpr>` can be stored in `HashMap`. |
511 | 52 | // See issue: https://github.com/apache/datafusion/issues/8027 |
512 | 52 | let mut new_classes = vec![]; |
513 | 61 | for (source, target) in mapping.iter()52 { |
514 | 61 | if new_classes.is_empty() { |
515 | 49 | new_classes.push((source, vec![Arc::clone(target)])); |
516 | 49 | }12 |
517 | 49 | if let Some((_, values)) = |
518 | 61 | new_classes.iter_mut().find(|(key, _)| key.eq(source)) |
519 | | { |
520 | 49 | if !physical_exprs_contains(values, target) { |
521 | 0 | values.push(Arc::clone(target)); |
522 | 49 | } |
523 | 12 | } |
524 | | } |
525 | | // Only add equivalence classes with at least two members as singleton |
526 | | // equivalence classes are meaningless. |
527 | 52 | let new_classes = new_classes |
528 | 52 | .into_iter() |
529 | 52 | .filter_map(|(_, values)| (values.len() > 1).then_some(values)49 ) |
530 | 52 | .map(EquivalenceClass::new); |
531 | 52 | |
532 | 52 | let classes = projected_classes.chain(new_classes).collect(); |
533 | 52 | Self::new(classes) |
534 | 52 | } |
535 | | |
536 | | /// Returns the equivalence class containing `expr`. If no equivalence class |
537 | | /// contains `expr`, returns `None`. |
538 | 0 | fn get_equivalence_class( |
539 | 0 | &self, |
540 | 0 | expr: &Arc<dyn PhysicalExpr>, |
541 | 0 | ) -> Option<&EquivalenceClass> { |
542 | 0 | self.iter().find(|cls| cls.contains(expr)) |
543 | 0 | } |
544 | | |
545 | | /// Combine equivalence groups of the given join children. |
546 | 1.14k | pub fn join( |
547 | 1.14k | &self, |
548 | 1.14k | right_equivalences: &Self, |
549 | 1.14k | join_type: &JoinType, |
550 | 1.14k | left_size: usize, |
551 | 1.14k | on: &[(PhysicalExprRef, PhysicalExprRef)], |
552 | 1.14k | ) -> Self { |
553 | 1.14k | match join_type { |
554 | | JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { |
555 | 617 | let mut result = Self::new( |
556 | 617 | self.iter() |
557 | 617 | .cloned() |
558 | 617 | .chain( |
559 | 617 | right_equivalences |
560 | 617 | .iter() |
561 | 617 | .map(|cls| cls.with_offset(left_size)0 ), |
562 | 617 | ) |
563 | 617 | .collect(), |
564 | 617 | ); |
565 | 617 | // If we have an inner join, expressions in the "on" condition |
566 | 617 | // are equal in the resulting table. |
567 | 617 | if join_type == &JoinType::Inner { |
568 | 178 | for (lhs, rhs) in on.iter()174 { |
569 | 178 | let new_lhs = Arc::clone(lhs) as _; |
570 | 178 | // Rewrite rhs to point to the right side of the join: |
571 | 178 | let new_rhs = Arc::clone(rhs) |
572 | 178 | .transform(|expr| { |
573 | 178 | if let Some(column) = |
574 | 178 | expr.as_any().downcast_ref::<Column>() |
575 | | { |
576 | 178 | let new_column = Arc::new(Column::new( |
577 | 178 | column.name(), |
578 | 178 | column.index() + left_size, |
579 | 178 | )) |
580 | 178 | as _; |
581 | 178 | return Ok(Transformed::yes(new_column)); |
582 | 0 | } |
583 | 0 |
|
584 | 0 | Ok(Transformed::no(expr)) |
585 | 178 | }) |
586 | 178 | .data() |
587 | 178 | .unwrap(); |
588 | 178 | result.add_equal_conditions(&new_lhs, &new_rhs); |
589 | 178 | } |
590 | 443 | } |
591 | 617 | result |
592 | | } |
593 | 268 | JoinType::LeftSemi | JoinType::LeftAnti => self.clone(), |
594 | 264 | JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), |
595 | | } |
596 | 1.14k | } |
597 | | } |
598 | | |
599 | | impl Display for EquivalenceGroup { |
600 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
601 | 0 | write!(f, "[")?; |
602 | 0 | let mut iter = self.iter(); |
603 | 0 | if let Some(cls) = iter.next() { |
604 | 0 | write!(f, "{}", cls)?; |
605 | 0 | } |
606 | 0 | for cls in iter { |
607 | 0 | write!(f, ", {}", cls)?; |
608 | | } |
609 | 0 | write!(f, "]") |
610 | 0 | } |
611 | | } |
612 | | |
613 | | #[cfg(test)] |
614 | | mod tests { |
615 | | |
616 | | use super::*; |
617 | | use crate::equivalence::tests::create_test_params; |
618 | | use crate::expressions::{lit, Literal}; |
619 | | |
620 | | use datafusion_common::{Result, ScalarValue}; |
621 | | |
622 | | #[test] |
623 | | fn test_bridge_groups() -> Result<()> { |
624 | | // First entry in the tuple is argument, second entry is the bridged result |
625 | | let test_cases = vec![ |
626 | | // ------- TEST CASE 1 -----------// |
627 | | ( |
628 | | vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]], |
629 | | // Expected is compared with set equality. Order of the specific results may change. |
630 | | vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]], |
631 | | ), |
632 | | // ------- TEST CASE 2 -----------// |
633 | | ( |
634 | | vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]], |
635 | | // Expected |
636 | | vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]], |
637 | | ), |
638 | | ]; |
639 | | for (entries, expected) in test_cases { |
640 | | let entries = entries |
641 | | .into_iter() |
642 | | .map(|entry| entry.into_iter().map(lit).collect::<Vec<_>>()) |
643 | | .map(EquivalenceClass::new) |
644 | | .collect::<Vec<_>>(); |
645 | | let expected = expected |
646 | | .into_iter() |
647 | | .map(|entry| entry.into_iter().map(lit).collect::<Vec<_>>()) |
648 | | .map(EquivalenceClass::new) |
649 | | .collect::<Vec<_>>(); |
650 | | let mut eq_groups = EquivalenceGroup::new(entries.clone()); |
651 | | eq_groups.bridge_classes(); |
652 | | let eq_groups = eq_groups.classes; |
653 | | let err_msg = format!( |
654 | | "error in test entries: {:?}, expected: {:?}, actual:{:?}", |
655 | | entries, expected, eq_groups |
656 | | ); |
657 | | assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg); |
658 | | for idx in 0..eq_groups.len() { |
659 | | assert_eq!(&eq_groups[idx], &expected[idx], "{}", err_msg); |
660 | | } |
661 | | } |
662 | | Ok(()) |
663 | | } |
664 | | |
665 | | #[test] |
666 | | fn test_remove_redundant_entries_eq_group() -> Result<()> { |
667 | | let entries = [ |
668 | | EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]), |
669 | | // This group is meaningless and should be removed |
670 | | EquivalenceClass::new(vec![lit(3), lit(3)]), |
671 | | EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), |
672 | | ]; |
673 | | // The given equivalence classes are not in succinct form. |
674 | | // The expected form is the plainest representation that is functionally the same. |
675 | | let expected = [ |
676 | | EquivalenceClass::new(vec![lit(1), lit(2)]), |
677 | | EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), |
678 | | ]; |
679 | | let mut eq_groups = EquivalenceGroup::new(entries.to_vec()); |
680 | | eq_groups.remove_redundant_entries(); |
681 | | |
682 | | let eq_groups = eq_groups.classes; |
683 | | assert_eq!(eq_groups.len(), expected.len()); |
684 | | assert_eq!(eq_groups.len(), 2); |
685 | | |
686 | | assert_eq!(eq_groups[0], expected[0]); |
687 | | assert_eq!(eq_groups[1], expected[1]); |
688 | | Ok(()) |
689 | | } |
690 | | |
691 | | #[test] |
692 | | fn test_schema_normalize_expr_with_equivalence() -> Result<()> { |
693 | | let col_a = &Column::new("a", 0); |
694 | | let col_b = &Column::new("b", 1); |
695 | | let col_c = &Column::new("c", 2); |
696 | | // Assume that column a and c are aliases. |
697 | | let (_test_schema, eq_properties) = create_test_params()?; |
698 | | |
699 | | let col_a_expr = Arc::new(col_a.clone()) as Arc<dyn PhysicalExpr>; |
700 | | let col_b_expr = Arc::new(col_b.clone()) as Arc<dyn PhysicalExpr>; |
701 | | let col_c_expr = Arc::new(col_c.clone()) as Arc<dyn PhysicalExpr>; |
702 | | // Test cases for equivalence normalization, |
703 | | // First entry in the tuple is argument, second entry is expected result after normalization. |
704 | | let expressions = vec![ |
705 | | // Normalized version of the column a and c should go to a |
706 | | // (by convention all the expressions inside equivalence class are mapped to the first entry |
707 | | // in this case a is the first entry in the equivalence class.) |
708 | | (&col_a_expr, &col_a_expr), |
709 | | (&col_c_expr, &col_a_expr), |
710 | | // Cannot normalize column b |
711 | | (&col_b_expr, &col_b_expr), |
712 | | ]; |
713 | | let eq_group = eq_properties.eq_group(); |
714 | | for (expr, expected_eq) in expressions { |
715 | | assert!( |
716 | | expected_eq.eq(&eq_group.normalize_expr(Arc::clone(expr))), |
717 | | "error in test: expr: {expr:?}" |
718 | | ); |
719 | | } |
720 | | |
721 | | Ok(()) |
722 | | } |
723 | | |
724 | | #[test] |
725 | | fn test_contains_any() { |
726 | | let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) |
727 | | as Arc<dyn PhysicalExpr>; |
728 | | let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) |
729 | | as Arc<dyn PhysicalExpr>; |
730 | | let lit2 = |
731 | | Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc<dyn PhysicalExpr>; |
732 | | let lit1 = |
733 | | Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>; |
734 | | let col_b_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>; |
735 | | |
736 | | let cls1 = |
737 | | EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]); |
738 | | let cls2 = |
739 | | EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&col_b_expr)]); |
740 | | let cls3 = EquivalenceClass::new(vec![Arc::clone(&lit2), Arc::clone(&lit1)]); |
741 | | |
742 | | // lit_true is common |
743 | | assert!(cls1.contains_any(&cls2)); |
744 | | // there is no common entry |
745 | | assert!(!cls1.contains_any(&cls3)); |
746 | | assert!(!cls2.contains_any(&cls3)); |
747 | | } |
748 | | } |