/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/intervals/cp_solver.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | //http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Constraint propagator/solver for custom PhysicalExpr graphs. |
19 | | |
20 | | use std::collections::HashSet; |
21 | | use std::fmt::{Display, Formatter}; |
22 | | use std::sync::Arc; |
23 | | |
24 | | use super::utils::{ |
25 | | convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op, |
26 | | }; |
27 | | use crate::expressions::Literal; |
28 | | use crate::utils::{build_dag, ExprTreeNode}; |
29 | | use crate::PhysicalExpr; |
30 | | |
31 | | use arrow_schema::{DataType, Schema}; |
32 | | use datafusion_common::{internal_err, Result}; |
33 | | use datafusion_expr::interval_arithmetic::{apply_operator, satisfy_greater, Interval}; |
34 | | use datafusion_expr::Operator; |
35 | | |
36 | | use petgraph::graph::NodeIndex; |
37 | | use petgraph::stable_graph::{DefaultIx, StableGraph}; |
38 | | use petgraph::visit::{Bfs, Dfs, DfsPostOrder, EdgeRef}; |
39 | | use petgraph::Outgoing; |
40 | | |
41 | | // Interval arithmetic provides a way to perform mathematical operations on |
42 | | // intervals, which represent a range of possible values rather than a single |
43 | | // point value. This allows for the propagation of ranges through mathematical |
44 | | // operations, and can be used to compute bounds for a complicated expression. |
45 | | // The key idea is that by breaking down a complicated expression into simpler |
46 | | // terms, and then combining the bounds for those simpler terms, one can |
47 | | // obtain bounds for the overall expression. |
48 | | // |
49 | | // For example, consider a mathematical expression such as x^2 + y = 4. Since |
50 | | // it would be a binary tree in [PhysicalExpr] notation, this type of |
51 | | // hierarchical computation is well-suited for a graph-based implementation. |
52 | | // In such an implementation, an equation system f(x) = 0 is represented by a |
53 | | // directed acyclic expression graph (DAEG). |
54 | | // |
55 | | // In order to use interval arithmetic to compute bounds for this expression, |
56 | | // one would first determine intervals that represent the possible values of x |
57 | | // and y. Let's say that the interval for x is [1, 2] and the interval for y |
58 | | // is [-3, 1]. In the chart below, you can see how the computation takes place. |
59 | | // |
60 | | // This way of using interval arithmetic to compute bounds for a complex |
61 | | // expression by combining the bounds for the constituent terms within the |
62 | | // original expression allows us to reason about the range of possible values |
63 | | // of the expression. This information later can be used in range pruning of |
64 | | // the provably unnecessary parts of `RecordBatch`es. |
65 | | // |
66 | | // References |
67 | | // 1 - Kabak, Mehmet Ozan. Analog Circuit Start-Up Behavior Analysis: An Interval |
68 | | // Arithmetic Based Approach, Chapter 4. Stanford University, 2015. |
69 | | // 2 - Moore, Ramon E. Interval analysis. Vol. 4. Englewood Cliffs: Prentice-Hall, 1966. |
70 | | // 3 - F. Messine, "Deterministic global optimization using interval constraint |
71 | | // propagation techniques," RAIRO-Operations Research, vol. 38, no. 04, |
72 | | // pp. 277–293, 2004. |
73 | | // |
74 | | // ``` text |
75 | | // Computing bounds for an expression using interval arithmetic. Constraint propagation through a top-down evaluation of the expression |
76 | | // graph using inverse semantics. |
77 | | // |
78 | | // [-2, 5] ∩ [4, 4] = [4, 4] [4, 4] |
79 | | // +-----+ +-----+ +-----+ +-----+ |
80 | | // +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ |
81 | | // | | | | | | | | | | | | | | | | |
82 | | // | +-----+ | | +-----+ | | +-----+ | | +-----+ | |
83 | | // | | | | | | | | |
84 | | // +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ |
85 | | // | 2 | | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | [0, 1]* |
86 | | // |[.] | | | |[.] | | | |[.] | | | |[.] | | | |
87 | | // +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ |
88 | | // | | | [-3, 1] | |
89 | | // | | | | |
90 | | // +---+ +---+ +---+ +---+ |
91 | | // | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [1, 2] |
92 | | // +---+ +---+ +---+ +---+ |
93 | | // |
94 | | // (a) Bottom-up evaluation: Step1 (b) Bottom-up evaluation: Step2 (a) Top-down propagation: Step1 (b) Top-down propagation: Step2 |
95 | | // |
96 | | // [1 - 3, 4 + 1] = [-2, 5] [1 - 3, 4 + 1] = [-2, 5] |
97 | | // +-----+ +-----+ +-----+ +-----+ |
98 | | // +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ |
99 | | // | | | | | | | | | | | | | | | | |
100 | | // | +-----+ | | +-----+ | | +-----+ | | +-----+ | |
101 | | // | | | | | | | | |
102 | | // +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ |
103 | | // | 2 |[1, 4] | y | | 2 |[1, 4] | y | | 2 |[3, 4]** | y | | 2 |[1, 4] | y | |
104 | | // |[.] | | | |[.] | | | |[.] | | | |[.] | | | |
105 | | // +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ |
106 | | // | [-3, 1] | [-3, 1] | [0, 1] | [-3, 1] |
107 | | // | | | | |
108 | | // +---+ +---+ +---+ +---+ |
109 | | // | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [sqrt(3), 2]*** |
110 | | // +---+ +---+ +---+ +---+ |
111 | | // |
112 | | // (c) Bottom-up evaluation: Step3 (d) Bottom-up evaluation: Step4 (c) Top-down propagation: Step3 (d) Top-down propagation: Step4 |
113 | | // |
114 | | // * [-3, 1] ∩ ([4, 4] - [1, 4]) = [0, 1] |
115 | | // ** [1, 4] ∩ ([4, 4] - [0, 1]) = [3, 4] |
116 | | // *** [1, 2] ∩ [sqrt(3), sqrt(4)] = [sqrt(3), 2] |
117 | | // ``` |
118 | | |
119 | | /// This object implements a directed acyclic expression graph (DAEG) that |
120 | | /// is used to compute ranges for expressions through interval arithmetic. |
121 | | #[derive(Clone, Debug)] |
122 | | pub struct ExprIntervalGraph { |
123 | | graph: StableGraph<ExprIntervalGraphNode, usize>, |
124 | | root: NodeIndex, |
125 | | } |
126 | | |
127 | | impl ExprIntervalGraph { |
128 | | /// Estimates the size in bytes, including `Self`. |
129 | 7.88k | pub fn size(&self) -> usize { |
130 | 7.88k | let node_memory_usage = self.graph.node_count() |
131 | 7.88k | * (std::mem::size_of::<ExprIntervalGraphNode>() |
132 | 7.88k | + std::mem::size_of::<NodeIndex>()); |
133 | 7.88k | let edge_memory_usage = self.graph.edge_count() |
134 | 7.88k | * (std::mem::size_of::<usize>() + std::mem::size_of::<NodeIndex>() * 2); |
135 | 7.88k | |
136 | 7.88k | std::mem::size_of_val(self) + node_memory_usage + edge_memory_usage |
137 | 7.88k | } |
138 | | } |
139 | | |
140 | | /// This object encapsulates all possible constraint propagation results. |
141 | | #[derive(PartialEq, Debug)] |
142 | | pub enum PropagationResult { |
143 | | CannotPropagate, |
144 | | Infeasible, |
145 | | Success, |
146 | | } |
147 | | |
148 | | /// This is a node in the DAEG; it encapsulates a reference to the actual |
149 | | /// [`PhysicalExpr`] as well as an interval containing expression bounds. |
150 | | #[derive(Clone, Debug)] |
151 | | pub struct ExprIntervalGraphNode { |
152 | | expr: Arc<dyn PhysicalExpr>, |
153 | | interval: Interval, |
154 | | } |
155 | | |
156 | | impl Display for ExprIntervalGraphNode { |
157 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
158 | 0 | write!(f, "{}", self.expr) |
159 | 0 | } |
160 | | } |
161 | | |
162 | | impl ExprIntervalGraphNode { |
163 | | /// Constructs a new DAEG node with an [-∞, ∞] range. |
164 | 10.2k | pub fn new_unbounded(expr: Arc<dyn PhysicalExpr>, dt: &DataType) -> Result<Self> { |
165 | 10.2k | Interval::make_unbounded(dt) |
166 | 10.2k | .map(|interval| ExprIntervalGraphNode { expr, interval }) |
167 | 10.2k | } |
168 | | |
169 | | /// Constructs a new DAEG node with the given range. |
170 | 4.20k | pub fn new_with_interval(expr: Arc<dyn PhysicalExpr>, interval: Interval) -> Self { |
171 | 4.20k | ExprIntervalGraphNode { expr, interval } |
172 | 4.20k | } |
173 | | |
174 | | /// Get the interval object representing the range of the expression. |
175 | 199k | pub fn interval(&self) -> &Interval { |
176 | 199k | &self.interval |
177 | 199k | } |
178 | | |
179 | | /// This function creates a DAEG node from DataFusion's [`ExprTreeNode`] |
180 | | /// object. Literals are created with definite, singleton intervals while |
181 | | /// any other expression starts with an indefinite interval ([-∞, ∞]). |
182 | 14.4k | pub fn make_node(node: &ExprTreeNode<NodeIndex>, schema: &Schema) -> Result<Self> { |
183 | 14.4k | let expr = Arc::clone(&node.expr); |
184 | 14.4k | if let Some(literal4.20k ) = expr.as_any().downcast_ref::<Literal>() { |
185 | 4.20k | let value = literal.value(); |
186 | 4.20k | Interval::try_new(value.clone(), value.clone()) |
187 | 4.20k | .map(|interval| Self::new_with_interval(expr, interval)) |
188 | | } else { |
189 | 10.2k | expr.data_type(schema) |
190 | 10.2k | .and_then(|dt| Self::new_unbounded(expr, &dt)) |
191 | | } |
192 | 14.4k | } |
193 | | } |
194 | | |
195 | | impl PartialEq for ExprIntervalGraphNode { |
196 | 0 | fn eq(&self, other: &Self) -> bool { |
197 | 0 | self.expr.eq(&other.expr) |
198 | 0 | } |
199 | | } |
200 | | |
201 | | /// This function refines intervals `left_child` and `right_child` by applying |
202 | | /// constraint propagation through `parent` via the operator `op`. The main idea is |
203 | | /// that we can shrink ranges of variables x and y using parent interval p. |
204 | | /// |
205 | | /// Assuming that x, y, and p have ranges [xL, xU], [yL, yU], and [pL, pU], we |
206 | | /// apply the following operations: |
207 | | /// - For plus operation, specifically, we would first do |
208 | | /// - [xL, xU] <- ([pL, pU] - [yL, yU]) ∩ [xL, xU], and then |
209 | | /// - [yL, yU] <- ([pL, pU] - [xL, xU]) ∩ [yL, yU]. |
210 | | /// - For minus operation, specifically, we would first do |
211 | | /// - [xL, xU] <- ([yL, yU] + [pL, pU]) ∩ [xL, xU], and then |
212 | | /// - [yL, yU] <- ([xL, xU] - [pL, pU]) ∩ [yL, yU]. |
213 | | /// - For multiplication operation, specifically, we would first do |
214 | | /// - [xL, xU] <- ([pL, pU] / [yL, yU]) ∩ [xL, xU], and then |
215 | | /// - [yL, yU] <- ([pL, pU] / [xL, xU]) ∩ [yL, yU]. |
216 | | /// - For division operation, specifically, we would first do |
217 | | /// - [xL, xU] <- ([yL, yU] * [pL, pU]) ∩ [xL, xU], and then |
218 | | /// - [yL, yU] <- ([xL, xU] / [pL, pU]) ∩ [yL, yU]. |
219 | 22.1k | pub fn propagate_arithmetic( |
220 | 22.1k | op: &Operator, |
221 | 22.1k | parent: &Interval, |
222 | 22.1k | left_child: &Interval, |
223 | 22.1k | right_child: &Interval, |
224 | 22.1k | ) -> Result<Option<(Interval, Interval)>> { |
225 | 22.1k | let inverse_op = get_inverse_op(*op)?0 ; |
226 | 22.1k | match (left_child.data_type(), right_child.data_type()) { |
227 | | // If we have a child whose type is a time interval (i.e. DataType::Interval), |
228 | | // we need special handling since timestamp differencing results in a |
229 | | // Duration type. |
230 | | (DataType::Timestamp(..), DataType::Interval(_)) => { |
231 | 1.34k | propagate_time_interval_at_right( |
232 | 1.34k | left_child, |
233 | 1.34k | right_child, |
234 | 1.34k | parent, |
235 | 1.34k | op, |
236 | 1.34k | &inverse_op, |
237 | 1.34k | ) |
238 | | } |
239 | | (DataType::Interval(_), DataType::Timestamp(..)) => { |
240 | 0 | propagate_time_interval_at_left( |
241 | 0 | left_child, |
242 | 0 | right_child, |
243 | 0 | parent, |
244 | 0 | op, |
245 | 0 | &inverse_op, |
246 | 0 | ) |
247 | | } |
248 | | _ => { |
249 | | // First, propagate to the left: |
250 | 20.8k | match apply_operator(&inverse_op, parent, right_child)?0 |
251 | 20.8k | .intersect(left_child)?0 |
252 | | { |
253 | | // Left is feasible: |
254 | 20.8k | Some(value) => Ok( |
255 | 20.8k | // Propagate to the right using the new left. |
256 | 20.8k | propagate_right(&value, parent, right_child, op, &inverse_op)?0 |
257 | 20.8k | .map(|right| (value, right)), |
258 | 20.8k | ), |
259 | | // If the left child is infeasible, short-circuit. |
260 | 0 | None => Ok(None), |
261 | | } |
262 | | } |
263 | | } |
264 | 22.1k | } |
265 | | |
266 | | /// This function refines intervals `left_child` and `right_child` by applying |
267 | | /// comparison propagation through `parent` via the operator `op`. The main idea is |
268 | | /// that we can shrink ranges of variables x and y using parent interval p. |
269 | | /// Two intervals can be ordered in 6 ways for a Gt `>` operator: |
270 | | /// ```text |
271 | | /// (1): Infeasible, short-circuit |
272 | | /// left: | ================ | |
273 | | /// right: | ======================== | |
274 | | /// |
275 | | /// (2): Update both intervals |
276 | | /// left: | ====================== | |
277 | | /// right: | ====================== | |
278 | | /// | |
279 | | /// V |
280 | | /// left: | ======= | |
281 | | /// right: | ======= | |
282 | | /// |
283 | | /// (3): Update left interval |
284 | | /// left: | ============================== | |
285 | | /// right: | ========== | |
286 | | /// | |
287 | | /// V |
288 | | /// left: | ===================== | |
289 | | /// right: | ========== | |
290 | | /// |
291 | | /// (4): Update right interval |
292 | | /// left: | ========== | |
293 | | /// right: | =========================== | |
294 | | /// | |
295 | | /// V |
296 | | /// left: | ========== | |
297 | | /// right: | ================== | |
298 | | /// |
299 | | /// (5): No change |
300 | | /// left: | ============================ | |
301 | | /// right: | =================== | |
302 | | /// |
303 | | /// (6): No change |
304 | | /// left: | ==================== | |
305 | | /// right: | =============== | |
306 | | /// |
307 | | /// -inf --------------------------------------------------------------- +inf |
308 | | /// ``` |
309 | 11.4k | pub fn propagate_comparison( |
310 | 11.4k | op: &Operator, |
311 | 11.4k | parent: &Interval, |
312 | 11.4k | left_child: &Interval, |
313 | 11.4k | right_child: &Interval, |
314 | 11.4k | ) -> Result<Option<(Interval, Interval)>> { |
315 | 11.4k | if parent == &Interval::CERTAINLY_TRUE { |
316 | 11.4k | match op { |
317 | 5 | Operator::Eq => left_child.intersect(right_child).map(|result| { |
318 | 5 | result.map(|intersection| (intersection.clone(), intersection)) |
319 | 5 | }), |
320 | 5.07k | Operator::Gt => satisfy_greater(left_child, right_child, true), |
321 | 662 | Operator::GtEq => satisfy_greater(left_child, right_child, false), |
322 | 5.07k | Operator::Lt => satisfy_greater(right_child, left_child, true) |
323 | 5.07k | .map(|t| t.map(reverse_tuple)), |
324 | 677 | Operator::LtEq => satisfy_greater(right_child, left_child, false) |
325 | 677 | .map(|t| t.map(reverse_tuple)), |
326 | 0 | _ => internal_err!( |
327 | 0 | "The operator must be a comparison operator to propagate intervals" |
328 | 0 | ), |
329 | | } |
330 | 0 | } else if parent == &Interval::CERTAINLY_FALSE { |
331 | 0 | match op { |
332 | | Operator::Eq => { |
333 | | // TODO: Propagation is not possible until we support interval sets. |
334 | 0 | Ok(None) |
335 | | } |
336 | 0 | Operator::Gt => satisfy_greater(right_child, left_child, false), |
337 | 0 | Operator::GtEq => satisfy_greater(right_child, left_child, true), |
338 | 0 | Operator::Lt => satisfy_greater(left_child, right_child, false) |
339 | 0 | .map(|t| t.map(reverse_tuple)), |
340 | 0 | Operator::LtEq => satisfy_greater(left_child, right_child, true) |
341 | 0 | .map(|t| t.map(reverse_tuple)), |
342 | 0 | _ => internal_err!( |
343 | 0 | "The operator must be a comparison operator to propagate intervals" |
344 | 0 | ), |
345 | | } |
346 | | } else { |
347 | | // Uncertainty cannot change any end-point of the intervals. |
348 | 0 | Ok(None) |
349 | | } |
350 | 11.4k | } |
351 | | |
352 | | impl ExprIntervalGraph { |
353 | 1.14k | pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &Schema) -> Result<Self> { |
354 | | // Build the full graph: |
355 | 1.14k | let (root, graph) = |
356 | 14.4k | build_dag(expr, &1.14k |node| ExprIntervalGraphNode::make_node(node, schema))1.14k ?0 ; |
357 | 1.14k | Ok(Self { graph, root }) |
358 | 1.14k | } |
359 | | |
360 | 0 | pub fn node_count(&self) -> usize { |
361 | 0 | self.graph.node_count() |
362 | 0 | } |
363 | | |
364 | | // Sometimes, we do not want to calculate and/or propagate intervals all |
365 | | // the way down to leaf expressions. For example, assume that we have a |
366 | | // `SymmetricHashJoin` which has a child with an output ordering like: |
367 | | // |
368 | | // PhysicalSortExpr { |
369 | | // expr: BinaryExpr('a', +, 'b'), |
370 | | // sort_option: .. |
371 | | // } |
372 | | // |
373 | | // i.e. its output order comes from a clause like "ORDER BY a + b". In such |
374 | | // a case, we must calculate the interval for the BinaryExpr('a', +, 'b') |
375 | | // instead of the columns inside this BinaryExpr, because this interval |
376 | | // decides whether we prune or not. Therefore, children `PhysicalExpr`s of |
377 | | // this `BinaryExpr` may be pruned for performance. The figure below |
378 | | // explains this example visually. |
379 | | // |
380 | | // Note that we only remove the nodes from the DAEG; we do not make any change |
381 | | // to the plan itself. |
382 | | // |
383 | | // ```text |
384 | | // |
385 | | // +-----+ +-----+ |
386 | | // | GT | | GT | |
387 | | // +--------| |-------+ +--------| |-------+ |
388 | | // | +-----+ | | +-----+ | |
389 | | // | | | | |
390 | | // +-----+ | +-----+ | |
391 | | // |Cast | | |Cast | | |
392 | | // | | | --\ | | | |
393 | | // +-----+ | ---------- +-----+ | |
394 | | // | | --/ | | |
395 | | // | | | | |
396 | | // +-----+ +-----+ +-----+ +-----+ |
397 | | // +--|Plus |--+ +--|Plus |--+ |Plus | +--|Plus |--+ |
398 | | // | | | | | | | | | | | | | | |
399 | | // Prune from here | +-----+ | | +-----+ | +-----+ | +-----+ | |
400 | | // ------------------------------------ | | | | |
401 | | // | | | | | | |
402 | | // +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ |
403 | | // | a | | b | | c | | 2 | | c | | 2 | |
404 | | // | | | | | | | | | | | | |
405 | | // +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ |
406 | | // |
407 | | // ``` |
408 | | |
409 | | /// This function associates stable node indices with [`PhysicalExpr`]s so |
410 | | /// that we can match `Arc<dyn PhysicalExpr>` and NodeIndex objects during |
411 | | /// membership tests. |
412 | 1.14k | pub fn gather_node_indices( |
413 | 1.14k | &mut self, |
414 | 1.14k | exprs: &[Arc<dyn PhysicalExpr>], |
415 | 1.14k | ) -> Vec<(Arc<dyn PhysicalExpr>, usize)> { |
416 | 1.14k | let graph = &self.graph; |
417 | 1.14k | let mut bfs = Bfs::new(graph, self.root); |
418 | 1.14k | // We collect the node indices (usize) of [PhysicalExpr]s in the order |
419 | 1.14k | // given by argument `exprs`. To preserve this order, we initialize each |
420 | 1.14k | // expression's node index with usize::MAX, and then find the corresponding |
421 | 1.14k | // node indices by traversing the graph. |
422 | 1.14k | let mut removals = vec![]; |
423 | 1.14k | let mut expr_node_indices = exprs |
424 | 1.14k | .iter() |
425 | 2.25k | .map(|e| (Arc::clone(e), usize::MAX)) |
426 | 1.14k | .collect::<Vec<_>>(); |
427 | 15.5k | while let Some(node14.4k ) = bfs.next(graph) { |
428 | | // Get the expression corresponding to this node: |
429 | 14.4k | let expr = &graph[node].expr; |
430 | | // If the current expression is among `exprs`, slate its children |
431 | | // for removal: |
432 | 27.6k | if let Some(value2.25k ) = exprs.iter().position(14.4k |e| expr.eq(e))14.4k { |
433 | | // Update the node index of the associated `PhysicalExpr`: |
434 | 2.25k | expr_node_indices[value].1 = node.index(); |
435 | 2.25k | for edge128 in graph.edges_directed(node, Outgoing) { |
436 | 128 | // Slate the child for removal, do not remove immediately. |
437 | 128 | removals.push(edge.id()); |
438 | 128 | } |
439 | 12.1k | } |
440 | | } |
441 | 1.26k | for edge_idx128 in removals { |
442 | 128 | self.graph.remove_edge(edge_idx); |
443 | 128 | } |
444 | | // Get the set of node indices reachable from the root node: |
445 | 1.14k | let connected_nodes = self.connected_nodes(); |
446 | 1.14k | // Remove nodes not connected to the root node: |
447 | 1.14k | self.graph |
448 | 14.4k | .retain_nodes(|_, index| connected_nodes.contains(&index)); |
449 | 1.14k | expr_node_indices |
450 | 1.14k | } |
451 | | |
452 | | /// Returns the set of node indices reachable from the root node via a |
453 | | /// simple depth-first search. |
454 | 1.14k | fn connected_nodes(&self) -> HashSet<NodeIndex> { |
455 | 1.14k | let mut nodes = HashSet::new(); |
456 | 1.14k | let mut dfs = Dfs::new(&self.graph, self.root); |
457 | 15.4k | while let Some(node14.2k ) = dfs.next(&self.graph) { |
458 | 14.2k | nodes.insert(node); |
459 | 14.2k | } |
460 | 1.14k | nodes |
461 | 1.14k | } |
462 | | |
463 | | /// Updates intervals for all expressions in the DAEG by successive |
464 | | /// bottom-up and top-down traversals. |
465 | 5.76k | pub fn update_ranges( |
466 | 5.76k | &mut self, |
467 | 5.76k | leaf_bounds: &mut [(usize, Interval)], |
468 | 5.76k | given_range: Interval, |
469 | 5.76k | ) -> Result<PropagationResult> { |
470 | 5.76k | self.assign_intervals(leaf_bounds); |
471 | 5.76k | let bounds = self.evaluate_bounds()?0 ; |
472 | | // There are three possible cases to consider: |
473 | | // (1) given_range ⊇ bounds => Nothing to propagate |
474 | | // (2) ∅ ⊂ (given_range ∩ bounds) ⊂ bounds => Can propagate |
475 | | // (3) Disjoint sets => Infeasible |
476 | 5.76k | if given_range.contains(bounds)?0 == Interval::CERTAINLY_TRUE { |
477 | | // First case: |
478 | 2 | Ok(PropagationResult::CannotPropagate) |
479 | 5.75k | } else if bounds.contains(&given_range)?0 != Interval::CERTAINLY_FALSE { |
480 | | // Second case: |
481 | 5.75k | let result = self.propagate_constraints(given_range); |
482 | 5.75k | self.update_intervals(leaf_bounds); |
483 | 5.75k | result |
484 | | } else { |
485 | | // Third case: |
486 | 2 | Ok(PropagationResult::Infeasible) |
487 | | } |
488 | 5.76k | } |
489 | | |
490 | | /// This function assigns given ranges to expressions in the DAEG. |
491 | | /// The argument `assignments` associates indices of sought expressions |
492 | | /// with their corresponding new ranges. |
493 | 5.76k | pub fn assign_intervals(&mut self, assignments: &[(usize, Interval)]) { |
494 | 17.2k | for (index, interval11.4k ) in assignments { |
495 | 11.4k | let node_index = NodeIndex::from(*index as DefaultIx); |
496 | 11.4k | self.graph[node_index].interval = interval.clone(); |
497 | 11.4k | } |
498 | 5.76k | } |
499 | | |
500 | | /// This function fetches ranges of expressions from the DAEG. The argument |
501 | | /// `assignments` associates indices of sought expressions with their ranges, |
502 | | /// which this function modifies to reflect the intervals in the DAEG. |
503 | 5.75k | pub fn update_intervals(&self, assignments: &mut [(usize, Interval)]) { |
504 | 11.4k | for (index, interval) in assignments.iter_mut()5.75k { |
505 | 11.4k | let node_index = NodeIndex::from(*index as DefaultIx); |
506 | 11.4k | *interval = self.graph[node_index].interval.clone(); |
507 | 11.4k | } |
508 | 5.75k | } |
509 | | |
510 | | /// Computes bounds for an expression using interval arithmetic via a |
511 | | /// bottom-up traversal. |
512 | | /// |
513 | | /// # Arguments |
514 | | /// This function takes no explicit arguments. Bounds for leaf variables, given as (node index, Interval) tuples, should be assigned beforehand via `assign_intervals` (see the example below). |
515 | | /// |
516 | | /// # Examples |
517 | | /// |
518 | | /// ``` |
519 | | /// use arrow::datatypes::DataType; |
520 | | /// use arrow::datatypes::Field; |
521 | | /// use arrow::datatypes::Schema; |
522 | | /// use datafusion_common::ScalarValue; |
523 | | /// use datafusion_expr::interval_arithmetic::Interval; |
524 | | /// use datafusion_expr::Operator; |
525 | | /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; |
526 | | /// use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; |
527 | | /// use datafusion_physical_expr::PhysicalExpr; |
528 | | /// use std::sync::Arc; |
529 | | /// |
530 | | /// let expr = Arc::new(BinaryExpr::new( |
531 | | /// Arc::new(Column::new("gnz", 0)), |
532 | | /// Operator::Plus, |
533 | | /// Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), |
534 | | /// )); |
535 | | /// |
536 | | /// let schema = Schema::new(vec![Field::new("gnz".to_string(), DataType::Int32, true)]); |
537 | | /// |
538 | | /// let mut graph = ExprIntervalGraph::try_new(expr, &schema).unwrap(); |
539 | | /// // Do it once, while constructing. |
540 | | /// let node_indices = graph |
541 | | /// .gather_node_indices(&[Arc::new(Column::new("gnz", 0))]); |
542 | | /// let left_index = node_indices.get(0).unwrap().1; |
543 | | /// |
544 | | /// // Provide intervals for leaf variables (here, there is only one). |
545 | | /// let intervals = vec![( |
546 | | /// left_index, |
547 | | /// Interval::make(Some(10), Some(20)).unwrap(), |
548 | | /// )]; |
549 | | /// |
550 | | /// // Evaluate bounds for the composite expression: |
551 | | /// graph.assign_intervals(&intervals); |
552 | | /// assert_eq!( |
553 | | /// graph.evaluate_bounds().unwrap(), |
554 | | /// &Interval::make(Some(20), Some(30)).unwrap(), |
555 | | /// ) |
556 | | /// ``` |
557 | 5.76k | pub fn evaluate_bounds(&mut self) -> Result<&Interval> { |
558 | 5.76k | let mut dfs = DfsPostOrder::new(&self.graph, self.root); |
559 | 78.9k | while let Some(node73.2k ) = dfs.next(&self.graph) { |
560 | 73.2k | let neighbors = self.graph.neighbors_directed(node, Outgoing); |
561 | 73.2k | let mut children_intervals = neighbors |
562 | 79.5k | .map(|child| self.graph[child].interval()) |
563 | 73.2k | .collect::<Vec<_>>(); |
564 | 73.2k | // If the current expression is a leaf, its interval should already |
565 | 73.2k | // be set externally, just continue with the evaluation procedure: |
566 | 73.2k | if !children_intervals.is_empty() { |
567 | | // Reverse to align with `PhysicalExpr`'s children: |
568 | 40.1k | children_intervals.reverse(); |
569 | 40.1k | self.graph[node].interval = |
570 | 40.1k | self.graph[node].expr.evaluate_bounds(&children_intervals)?0 ; |
571 | 33.0k | } |
572 | | } |
573 | 5.76k | Ok(&self.graph[self.root].interval) |
574 | 5.76k | } |
575 | | |
576 | | /// Updates/shrinks bounds for leaf expressions using interval arithmetic |
577 | | /// via a top-down traversal. |
578 | 5.75k | fn propagate_constraints( |
579 | 5.75k | &mut self, |
580 | 5.75k | given_range: Interval, |
581 | 5.75k | ) -> Result<PropagationResult> { |
582 | 5.75k | let mut bfs = Bfs::new(&self.graph, self.root); |
583 | | |
584 | | // Adjust the root node with the given range: |
585 | 5.75k | if let Some(interval) = self.graph[self.root].interval.intersect(given_range)?0 { |
586 | 5.75k | self.graph[self.root].interval = interval; |
587 | 5.75k | } else { |
588 | 0 | return Ok(PropagationResult::Infeasible); |
589 | | } |
590 | | |
591 | 78.9k | while let Some(node73.1k ) = bfs.next(&self.graph) { |
592 | 73.1k | let neighbors = self.graph.neighbors_directed(node, Outgoing); |
593 | 73.1k | let mut children = neighbors.collect::<Vec<_>>(); |
594 | 73.1k | // If the current expression is a leaf, its range is now final. |
595 | 73.1k | // So, just continue with the propagation procedure: |
596 | 73.1k | if children.is_empty() { |
597 | 33.0k | continue; |
598 | 40.1k | } |
599 | 40.1k | // Reverse to align with `PhysicalExpr`'s children: |
600 | 40.1k | children.reverse(); |
601 | 40.1k | let children_intervals = children |
602 | 40.1k | .iter() |
603 | 79.5k | .map(|child| self.graph[*child].interval()) |
604 | 40.1k | .collect::<Vec<_>>(); |
605 | 40.1k | let node_interval = self.graph[node].interval(); |
606 | 40.1k | let propagated_intervals = self.graph[node] |
607 | 40.1k | .expr |
608 | 40.1k | .propagate_constraints(node_interval, &children_intervals)?0 ; |
609 | 40.1k | if let Some(propagated_intervals) = propagated_intervals { |
610 | 79.5k | for (child, interval) in children.into_iter().zip(propagated_intervals)40.1k { |
611 | 79.5k | self.graph[child].interval = interval; |
612 | 79.5k | } |
613 | | } else { |
614 | | // The constraint is infeasible, report: |
615 | 0 | return Ok(PropagationResult::Infeasible); |
616 | | } |
617 | | } |
618 | 5.75k | Ok(PropagationResult::Success) |
619 | 5.75k | } |
620 | | |
621 | | /// Returns the interval associated with the node at the given `index`. |
622 | 33 | pub fn get_interval(&self, index: usize) -> Interval { |
623 | 33 | self.graph[NodeIndex::new(index)].interval.clone() |
624 | 33 | } |
625 | | } |
626 | | |
627 | | /// This is a subfunction of the `propagate_arithmetic` function that propagates to the right child. |
628 | 22.1k | fn propagate_right( |
629 | 22.1k | left: &Interval, |
630 | 22.1k | parent: &Interval, |
631 | 22.1k | right: &Interval, |
632 | 22.1k | op: &Operator, |
633 | 22.1k | inverse_op: &Operator, |
634 | 22.1k | ) -> Result<Option<Interval>> { |
635 | 22.1k | match op { |
636 | 13.3k | Operator::Minus => apply_operator(op, left, parent), |
637 | 8.85k | Operator::Plus => apply_operator(inverse_op, parent, left), |
638 | 0 | Operator::Divide => apply_operator(op, left, parent), |
639 | 0 | Operator::Multiply => apply_operator(inverse_op, parent, left), |
640 | 0 | _ => internal_err!("Interval arithmetic does not support the operator {}", op), |
641 | 0 | }? |
642 | 22.1k | .intersect(right) |
643 | 22.1k | } |
644 | | |
645 | | /// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], |
646 | | /// if there exists a `timestamp - timestamp` operation, the result would be |
647 | | /// of type `Duration`. However, we may encounter a situation where a time interval |
648 | | /// is involved in an arithmetic operation with a `Duration` type. This function |
649 | | /// offers special handling for such cases, where the time interval resides on |
650 | | /// the left side of the operation. |
651 | 0 | fn propagate_time_interval_at_left( |
652 | 0 | left_child: &Interval, |
653 | 0 | right_child: &Interval, |
654 | 0 | parent: &Interval, |
655 | 0 | op: &Operator, |
656 | 0 | inverse_op: &Operator, |
657 | 0 | ) -> Result<Option<(Interval, Interval)>> { |
658 | | // We check if the child's time interval(s) have non-zero month or day field(s). |
659 | | // If so, we return it as is without propagating. Otherwise, we first convert |
660 | | // the time intervals to the `Duration` type, then propagate, and then convert |
661 | | // the bounds to time intervals again. |
662 | 0 | let result = if let Some(duration) = convert_interval_type_to_duration(left_child) { |
663 | 0 | match apply_operator(inverse_op, parent, right_child)?.intersect(duration)? { |
664 | 0 | Some(value) => { |
665 | 0 | let left = convert_duration_type_to_interval(&value); |
666 | 0 | let right = propagate_right(&value, parent, right_child, op, inverse_op)?; |
667 | 0 | match (left, right) { |
668 | 0 | (Some(left), Some(right)) => Some((left, right)), |
669 | 0 | _ => None, |
670 | | } |
671 | | } |
672 | 0 | None => None, |
673 | | } |
674 | | } else { |
675 | 0 | propagate_right(left_child, parent, right_child, op, inverse_op)? |
676 | 0 | .map(|right| (left_child.clone(), right)) |
677 | | }; |
678 | 0 | Ok(result) |
679 | 0 | } |
680 | | |
681 | | /// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], |
682 | | /// if there exists a `timestamp - timestamp` operation, the result would be |
683 | | /// of type `Duration`. However, we may encounter a situation where a time interval |
684 | | /// is involved in an arithmetic operation with a `Duration` type. This function |
685 | | /// offers special handling for such cases, where the time interval resides on |
686 | | /// the right side of the operation. |
687 | 1.34k | fn propagate_time_interval_at_right( |
688 | 1.34k | left_child: &Interval, |
689 | 1.34k | right_child: &Interval, |
690 | 1.34k | parent: &Interval, |
691 | 1.34k | op: &Operator, |
692 | 1.34k | inverse_op: &Operator, |
693 | 1.34k | ) -> Result<Option<(Interval, Interval)>> { |
694 | | // We check if the child's time interval(s) has a non-zero month or day field(s). |
695 | | // If so, we return it as is without propagating. Otherwise, we first convert |
696 | | // the time intervals to the `Duration` type, then propagate, and then convert |
697 | | // the bounds to time intervals again. |
698 | 1.34k | let result = if let Some(duration) = convert_interval_type_to_duration(right_child) { |
699 | 1.34k | match apply_operator(inverse_op, parent, &duration)?0 .intersect(left_child)?0 { |
700 | 1.34k | Some(value) => { |
701 | 1.34k | propagate_right(left_child, parent, &duration, op, inverse_op)?0 |
702 | 1.34k | .and_then(|right| convert_duration_type_to_interval(&right)) |
703 | 1.34k | .map(|right| (value, right)) |
704 | | } |
705 | 0 | None => None, |
706 | | } |
707 | | } else { |
708 | 0 | apply_operator(inverse_op, parent, right_child)? |
709 | 0 | .intersect(left_child)? |
710 | 0 | .map(|value| (value, right_child.clone())) |
711 | | }; |
712 | 1.34k | Ok(result) |
713 | 1.34k | } |
714 | | |
715 | 5.75k | fn reverse_tuple<T, U>((first, second): (T, U)) -> (U, T) { |
716 | 5.75k | (second, first) |
717 | 5.75k | } |
718 | | |
719 | | #[cfg(test)] |
720 | | mod tests { |
721 | | use super::*; |
722 | | use crate::expressions::{BinaryExpr, Column}; |
723 | | use crate::intervals::test_utils::gen_conjunctive_numerical_expr; |
724 | | |
725 | | use arrow::datatypes::TimeUnit; |
726 | | use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; |
727 | | use arrow_schema::Field; |
728 | | use datafusion_common::ScalarValue; |
729 | | |
730 | | use itertools::Itertools; |
731 | | use rand::rngs::StdRng; |
732 | | use rand::{Rng, SeedableRng}; |
733 | | use rstest::*; |
734 | | |
735 | | #[allow(clippy::too_many_arguments)] |
736 | | fn experiment( |
737 | | expr: Arc<dyn PhysicalExpr>, |
738 | | exprs_with_interval: (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>), |
739 | | left_interval: Interval, |
740 | | right_interval: Interval, |
741 | | left_expected: Interval, |
742 | | right_expected: Interval, |
743 | | result: PropagationResult, |
744 | | schema: &Schema, |
745 | | ) -> Result<()> { |
746 | | let col_stats = vec![ |
747 | | (Arc::clone(&exprs_with_interval.0), left_interval), |
748 | | (Arc::clone(&exprs_with_interval.1), right_interval), |
749 | | ]; |
750 | | let expected = vec![ |
751 | | (Arc::clone(&exprs_with_interval.0), left_expected), |
752 | | (Arc::clone(&exprs_with_interval.1), right_expected), |
753 | | ]; |
754 | | let mut graph = ExprIntervalGraph::try_new(expr, schema)?; |
755 | | let expr_indexes = graph.gather_node_indices( |
756 | | &col_stats.iter().map(|(e, _)| Arc::clone(e)).collect_vec(), |
757 | | ); |
758 | | |
759 | | let mut col_stat_nodes = col_stats |
760 | | .iter() |
761 | | .zip(expr_indexes.iter()) |
762 | | .map(|((_, interval), (_, index))| (*index, interval.clone())) |
763 | | .collect_vec(); |
764 | | let expected_nodes = expected |
765 | | .iter() |
766 | | .zip(expr_indexes.iter()) |
767 | | .map(|((_, interval), (_, index))| (*index, interval.clone())) |
768 | | .collect_vec(); |
769 | | |
770 | | let exp_result = |
771 | | graph.update_ranges(&mut col_stat_nodes[..], Interval::CERTAINLY_TRUE)?; |
772 | | assert_eq!(exp_result, result); |
773 | | col_stat_nodes.iter().zip(expected_nodes.iter()).for_each( |
774 | | |((_, calculated_interval_node), (_, expected))| { |
775 | | // NOTE: These randomized tests only check for conservative containment, |
776 | | // not openness/closedness of endpoints. |
777 | | |
778 | | // Calculated bounds are relaxed by 1 to cover all strict |
779 | | // and non-strict comparison cases since we have only closed bounds. |
780 | | let one = ScalarValue::new_one(&expected.data_type()).unwrap(); |
781 | | assert!( |
782 | | calculated_interval_node.lower() |
783 | | <= &expected.lower().add(&one).unwrap(), |
784 | | "{}", |
785 | | format!( |
786 | | "Calculated {} must be less than or equal {}", |
787 | | calculated_interval_node.lower(), |
788 | | expected.lower() |
789 | | ) |
790 | | ); |
791 | | assert!( |
792 | | calculated_interval_node.upper() |
793 | | >= &expected.upper().sub(&one).unwrap(), |
794 | | "{}", |
795 | | format!( |
796 | | "Calculated {} must be greater than or equal {}", |
797 | | calculated_interval_node.upper(), |
798 | | expected.upper() |
799 | | ) |
800 | | ); |
801 | | }, |
802 | | ); |
803 | | Ok(()) |
804 | | } |
805 | | |
806 | | macro_rules! generate_cases { |
807 | | ($FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => { |
808 | | fn $FUNC_NAME<const ASC: bool>( |
809 | | expr: Arc<dyn PhysicalExpr>, |
810 | | left_col: Arc<dyn PhysicalExpr>, |
811 | | right_col: Arc<dyn PhysicalExpr>, |
812 | | seed: u64, |
813 | | expr_left: $TYPE, |
814 | | expr_right: $TYPE, |
815 | | ) -> Result<()> { |
816 | | let mut r = StdRng::seed_from_u64(seed); |
817 | | |
818 | | let (left_given, right_given, left_expected, right_expected) = if ASC { |
819 | | let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); |
820 | | let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); |
821 | | ( |
822 | | (Some(left), None), |
823 | | (Some(right), None), |
824 | | (Some(<$TYPE>::max(left, right + expr_left)), None), |
825 | | (Some(<$TYPE>::max(right, left + expr_right)), None), |
826 | | ) |
827 | | } else { |
828 | | let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); |
829 | | let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); |
830 | | ( |
831 | | (None, Some(left)), |
832 | | (None, Some(right)), |
833 | | (None, Some(<$TYPE>::min(left, right + expr_left))), |
834 | | (None, Some(<$TYPE>::min(right, left + expr_right))), |
835 | | ) |
836 | | }; |
837 | | |
838 | | experiment( |
839 | | expr, |
840 | | (left_col.clone(), right_col.clone()), |
841 | | Interval::make(left_given.0, left_given.1).unwrap(), |
842 | | Interval::make(right_given.0, right_given.1).unwrap(), |
843 | | Interval::make(left_expected.0, left_expected.1).unwrap(), |
844 | | Interval::make(right_expected.0, right_expected.1).unwrap(), |
845 | | PropagationResult::Success, |
846 | | &Schema::new(vec![ |
847 | | Field::new( |
848 | | left_col.as_any().downcast_ref::<Column>().unwrap().name(), |
849 | | DataType::$SCALAR, |
850 | | true, |
851 | | ), |
852 | | Field::new( |
853 | | right_col.as_any().downcast_ref::<Column>().unwrap().name(), |
854 | | DataType::$SCALAR, |
855 | | true, |
856 | | ), |
857 | | ]), |
858 | | ) |
859 | | } |
860 | | }; |
861 | | } |
862 | | generate_cases!(generate_case_i32, i32, Int32); |
863 | | generate_cases!(generate_case_i64, i64, Int64); |
864 | | generate_cases!(generate_case_f32, f32, Float32); |
865 | | generate_cases!(generate_case_f64, f64, Float64); |
866 | | |
867 | | #[test] |
868 | | fn testing_not_possible() -> Result<()> { |
869 | | let left_col = Arc::new(Column::new("left_watermark", 0)); |
870 | | let right_col = Arc::new(Column::new("right_watermark", 0)); |
871 | | |
872 | | // left_watermark + 5 > right_watermark |
873 | | let left_and_1 = Arc::new(BinaryExpr::new( |
874 | | Arc::clone(&left_col) as Arc<dyn PhysicalExpr>, |
875 | | Operator::Plus, |
876 | | Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), |
877 | | )); |
878 | | let expr = Arc::new(BinaryExpr::new( |
879 | | left_and_1, |
880 | | Operator::Gt, |
881 | | Arc::clone(&right_col) as Arc<dyn PhysicalExpr>, |
882 | | )); |
883 | | experiment( |
884 | | expr, |
885 | | ( |
886 | | Arc::clone(&left_col) as Arc<dyn PhysicalExpr>, |
887 | | Arc::clone(&right_col) as Arc<dyn PhysicalExpr>, |
888 | | ), |
889 | | Interval::make(Some(10_i32), Some(20_i32))?, |
890 | | Interval::make(Some(100), None)?, |
891 | | Interval::make(Some(10), Some(20))?, |
892 | | Interval::make(Some(100), None)?, |
893 | | PropagationResult::Infeasible, |
894 | | &Schema::new(vec![ |
895 | | Field::new( |
896 | | left_col.as_any().downcast_ref::<Column>().unwrap().name(), |
897 | | DataType::Int32, |
898 | | true, |
899 | | ), |
900 | | Field::new( |
901 | | right_col.as_any().downcast_ref::<Column>().unwrap().name(), |
902 | | DataType::Int32, |
903 | | true, |
904 | | ), |
905 | | ]), |
906 | | ) |
907 | | } |
908 | | |
909 | | macro_rules! integer_float_case_1 { |
910 | | ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => { |
911 | | #[rstest] |
912 | | #[test] |
913 | | fn $TEST_FUNC_NAME( |
914 | | #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)] |
915 | | seed: u64, |
916 | | #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator, |
917 | | #[values(Operator::Lt, Operator::LtEq)] less_op: Operator, |
918 | | ) -> Result<()> { |
919 | | let left_col = Arc::new(Column::new("left_watermark", 0)); |
920 | | let right_col = Arc::new(Column::new("right_watermark", 0)); |
921 | | |
922 | | // left_watermark + 1 > right_watermark + 11 AND left_watermark + 3 < right_watermark + 33 |
923 | | let expr = gen_conjunctive_numerical_expr( |
924 | | left_col.clone(), |
925 | | right_col.clone(), |
926 | | ( |
927 | | Operator::Plus, |
928 | | Operator::Plus, |
929 | | Operator::Plus, |
930 | | Operator::Plus, |
931 | | ), |
932 | | ScalarValue::$SCALAR(Some(1 as $TYPE)), |
933 | | ScalarValue::$SCALAR(Some(11 as $TYPE)), |
934 | | ScalarValue::$SCALAR(Some(3 as $TYPE)), |
935 | | ScalarValue::$SCALAR(Some(33 as $TYPE)), |
936 | | (greater_op, less_op), |
937 | | ); |
938 | | // l > r + 10 AND r > l - 30 |
939 | | let l_gt_r = 10 as $TYPE; |
940 | | let r_gt_l = -30 as $TYPE; |
941 | | $GENERATE_CASE_FUNC_NAME::<true>( |
942 | | expr.clone(), |
943 | | left_col.clone(), |
944 | | right_col.clone(), |
945 | | seed, |
946 | | l_gt_r, |
947 | | r_gt_l, |
948 | | )?; |
949 | | // Descending tests |
950 | | // r < l - 10 AND l < r + 30 |
951 | | let r_lt_l = -l_gt_r; |
952 | | let l_lt_r = -r_gt_l; |
953 | | $GENERATE_CASE_FUNC_NAME::<false>( |
954 | | expr, left_col, right_col, seed, l_lt_r, r_lt_l, |
955 | | ) |
956 | | } |
957 | | }; |
958 | | } |
959 | | |
960 | | integer_float_case_1!(case_1_i32, generate_case_i32, i32, Int32); |
961 | | integer_float_case_1!(case_1_i64, generate_case_i64, i64, Int64); |
962 | | integer_float_case_1!(case_1_f64, generate_case_f64, f64, Float64); |
963 | | integer_float_case_1!(case_1_f32, generate_case_f32, f32, Float32); |
964 | | |
965 | | macro_rules! integer_float_case_2 { |
966 | | ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => { |
967 | | #[rstest] |
968 | | #[test] |
969 | | fn $TEST_FUNC_NAME( |
970 | | #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)] |
971 | | seed: u64, |
972 | | #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator, |
973 | | #[values(Operator::Lt, Operator::LtEq)] less_op: Operator, |
974 | | ) -> Result<()> { |
975 | | let left_col = Arc::new(Column::new("left_watermark", 0)); |
976 | | let right_col = Arc::new(Column::new("right_watermark", 0)); |
977 | | |
978 | | // left_watermark - 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10 |
979 | | let expr = gen_conjunctive_numerical_expr( |
980 | | left_col.clone(), |
981 | | right_col.clone(), |
982 | | ( |
983 | | Operator::Minus, |
984 | | Operator::Plus, |
985 | | Operator::Plus, |
986 | | Operator::Plus, |
987 | | ), |
988 | | ScalarValue::$SCALAR(Some(1 as $TYPE)), |
989 | | ScalarValue::$SCALAR(Some(5 as $TYPE)), |
990 | | ScalarValue::$SCALAR(Some(3 as $TYPE)), |
991 | | ScalarValue::$SCALAR(Some(10 as $TYPE)), |
992 | | (greater_op, less_op), |
993 | | ); |
994 | | // l > r + 6 AND r > l - 7 |
995 | | let l_gt_r = 6 as $TYPE; |
996 | | let r_gt_l = -7 as $TYPE; |
997 | | $GENERATE_CASE_FUNC_NAME::<true>( |
998 | | expr.clone(), |
999 | | left_col.clone(), |
1000 | | right_col.clone(), |
1001 | | seed, |
1002 | | l_gt_r, |
1003 | | r_gt_l, |
1004 | | )?; |
1005 | | // Descending tests |
1006 | | // r < l - 6 AND l < r + 7 |
1007 | | let r_lt_l = -l_gt_r; |
1008 | | let l_lt_r = -r_gt_l; |
1009 | | $GENERATE_CASE_FUNC_NAME::<false>( |
1010 | | expr, left_col, right_col, seed, l_lt_r, r_lt_l, |
1011 | | ) |
1012 | | } |
1013 | | }; |
1014 | | } |
1015 | | |
1016 | | integer_float_case_2!(case_2_i32, generate_case_i32, i32, Int32); |
1017 | | integer_float_case_2!(case_2_i64, generate_case_i64, i64, Int64); |
1018 | | integer_float_case_2!(case_2_f64, generate_case_f64, f64, Float64); |
1019 | | integer_float_case_2!(case_2_f32, generate_case_f32, f32, Float32); |
1020 | | |
1021 | | macro_rules! integer_float_case_3 { |
1022 | | ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => { |
1023 | | #[rstest] |
1024 | | #[test] |
1025 | | fn $TEST_FUNC_NAME( |
1026 | | #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)] |
1027 | | seed: u64, |
1028 | | #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator, |
1029 | | #[values(Operator::Lt, Operator::LtEq)] less_op: Operator, |
1030 | | ) -> Result<()> { |
1031 | | let left_col = Arc::new(Column::new("left_watermark", 0)); |
1032 | | let right_col = Arc::new(Column::new("right_watermark", 0)); |
1033 | | |
1034 | | // left_watermark - 1 > right_watermark + 5 AND left_watermark - 3 < right_watermark + 10 |
1035 | | let expr = gen_conjunctive_numerical_expr( |
1036 | | left_col.clone(), |
1037 | | right_col.clone(), |
1038 | | ( |
1039 | | Operator::Minus, |
1040 | | Operator::Plus, |
1041 | | Operator::Minus, |
1042 | | Operator::Plus, |
1043 | | ), |
1044 | | ScalarValue::$SCALAR(Some(1 as $TYPE)), |
1045 | | ScalarValue::$SCALAR(Some(5 as $TYPE)), |
1046 | | ScalarValue::$SCALAR(Some(3 as $TYPE)), |
1047 | | ScalarValue::$SCALAR(Some(10 as $TYPE)), |
1048 | | (greater_op, less_op), |
1049 | | ); |
1050 | | // l > r + 6 AND r > l - 13 |
1051 | | let l_gt_r = 6 as $TYPE; |
1052 | | let r_gt_l = -13 as $TYPE; |
1053 | | $GENERATE_CASE_FUNC_NAME::<true>( |
1054 | | expr.clone(), |
1055 | | left_col.clone(), |
1056 | | right_col.clone(), |
1057 | | seed, |
1058 | | l_gt_r, |
1059 | | r_gt_l, |
1060 | | )?; |
1061 | | // Descending tests |
1062 | | // r < l - 6 AND l < r + 13 |
1063 | | let r_lt_l = -l_gt_r; |
1064 | | let l_lt_r = -r_gt_l; |
1065 | | $GENERATE_CASE_FUNC_NAME::<false>( |
1066 | | expr, left_col, right_col, seed, l_lt_r, r_lt_l, |
1067 | | ) |
1068 | | } |
1069 | | }; |
1070 | | } |
1071 | | |
1072 | | integer_float_case_3!(case_3_i32, generate_case_i32, i32, Int32); |
1073 | | integer_float_case_3!(case_3_i64, generate_case_i64, i64, Int64); |
1074 | | integer_float_case_3!(case_3_f64, generate_case_f64, f64, Float64); |
1075 | | integer_float_case_3!(case_3_f32, generate_case_f32, f32, Float32); |
1076 | | |
1077 | | macro_rules! integer_float_case_4 { |
1078 | | ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => { |
1079 | | #[rstest] |
1080 | | #[test] |
1081 | | fn $TEST_FUNC_NAME( |
1082 | | #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)] |
1083 | | seed: u64, |
1084 | | #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator, |
1085 | | #[values(Operator::Lt, Operator::LtEq)] less_op: Operator, |
1086 | | ) -> Result<()> { |
1087 | | let left_col = Arc::new(Column::new("left_watermark", 0)); |
1088 | | let right_col = Arc::new(Column::new("right_watermark", 0)); |
1089 | | |
1090 | | // left_watermark - 10 > right_watermark - 5 AND left_watermark - 3 < right_watermark + 10 |
1091 | | let expr = gen_conjunctive_numerical_expr( |
1092 | | left_col.clone(), |
1093 | | right_col.clone(), |
1094 | | ( |
1095 | | Operator::Minus, |
1096 | | Operator::Minus, |
1097 | | Operator::Minus, |
1098 | | Operator::Plus, |
1099 | | ), |
1100 | | ScalarValue::$SCALAR(Some(10 as $TYPE)), |
1101 | | ScalarValue::$SCALAR(Some(5 as $TYPE)), |
1102 | | ScalarValue::$SCALAR(Some(3 as $TYPE)), |
1103 | | ScalarValue::$SCALAR(Some(10 as $TYPE)), |
1104 | | (greater_op, less_op), |
1105 | | ); |
1106 | | // l > r + 5 AND r > l - 13 |
1107 | | let l_gt_r = 5 as $TYPE; |
1108 | | let r_gt_l = -13 as $TYPE; |
1109 | | $GENERATE_CASE_FUNC_NAME::<true>( |
1110 | | expr.clone(), |
1111 | | left_col.clone(), |
1112 | | right_col.clone(), |
1113 | | seed, |
1114 | | l_gt_r, |
1115 | | r_gt_l, |
1116 | | )?; |
1117 | | // Descending tests |
1118 | | // r < l - 5 AND l < r + 13 |
1119 | | let r_lt_l = -l_gt_r; |
1120 | | let l_lt_r = -r_gt_l; |
1121 | | $GENERATE_CASE_FUNC_NAME::<false>( |
1122 | | expr, left_col, right_col, seed, l_lt_r, r_lt_l, |
1123 | | ) |
1124 | | } |
1125 | | }; |
1126 | | } |
1127 | | |
1128 | | integer_float_case_4!(case_4_i32, generate_case_i32, i32, Int32); |
1129 | | integer_float_case_4!(case_4_i64, generate_case_i64, i64, Int64); |
1130 | | integer_float_case_4!(case_4_f64, generate_case_f64, f64, Float64); |
1131 | | integer_float_case_4!(case_4_f32, generate_case_f32, f32, Float32); |
1132 | | |
1133 | | macro_rules! integer_float_case_5 { |
1134 | | ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => { |
1135 | | #[rstest] |
1136 | | #[test] |
1137 | | fn $TEST_FUNC_NAME( |
1138 | | #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)] |
1139 | | seed: u64, |
1140 | | #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator, |
1141 | | #[values(Operator::Lt, Operator::LtEq)] less_op: Operator, |
1142 | | ) -> Result<()> { |
1143 | | let left_col = Arc::new(Column::new("left_watermark", 0)); |
1144 | | let right_col = Arc::new(Column::new("right_watermark", 0)); |
1145 | | |
1146 | | // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3 |
1147 | | let expr = gen_conjunctive_numerical_expr( |
1148 | | left_col.clone(), |
1149 | | right_col.clone(), |
1150 | | ( |
1151 | | Operator::Minus, |
1152 | | Operator::Minus, |
1153 | | Operator::Minus, |
1154 | | Operator::Minus, |
1155 | | ), |
1156 | | ScalarValue::$SCALAR(Some(10 as $TYPE)), |
1157 | | ScalarValue::$SCALAR(Some(5 as $TYPE)), |
1158 | | ScalarValue::$SCALAR(Some(30 as $TYPE)), |
1159 | | ScalarValue::$SCALAR(Some(3 as $TYPE)), |
1160 | | (greater_op, less_op), |
1161 | | ); |
1162 | | // l > r + 5 AND r > l - 27 |
1163 | | let l_gt_r = 5 as $TYPE; |
1164 | | let r_gt_l = -27 as $TYPE; |
1165 | | $GENERATE_CASE_FUNC_NAME::<true>( |
1166 | | expr.clone(), |
1167 | | left_col.clone(), |
1168 | | right_col.clone(), |
1169 | | seed, |
1170 | | l_gt_r, |
1171 | | r_gt_l, |
1172 | | )?; |
1173 | | // Descending tests |
1174 | | // r < l - 5 AND l < r + 27 |
1175 | | let r_lt_l = -l_gt_r; |
1176 | | let l_lt_r = -r_gt_l; |
1177 | | $GENERATE_CASE_FUNC_NAME::<false>( |
1178 | | expr, left_col, right_col, seed, l_lt_r, r_lt_l, |
1179 | | ) |
1180 | | } |
1181 | | }; |
1182 | | } |
1183 | | |
1184 | | integer_float_case_5!(case_5_i32, generate_case_i32, i32, Int32); |
1185 | | integer_float_case_5!(case_5_i64, generate_case_i64, i64, Int64); |
1186 | | integer_float_case_5!(case_5_f64, generate_case_f64, f64, Float64); |
1187 | | integer_float_case_5!(case_5_f32, generate_case_f32, f32, Float32); |
1188 | | |
1189 | | #[test] |
1190 | | fn test_gather_node_indices_dont_remove() -> Result<()> { |
1191 | | // Expression: a@0 + b@1 + 1 > a@0 - b@1, given a@0 + b@1. |
1192 | | // Do not remove a@0 or b@1, only remove edges since a@0 - b@1 also |
1193 | | // depends on leaf nodes a@0 and b@1. |
1194 | | let left_expr = Arc::new(BinaryExpr::new( |
1195 | | Arc::new(BinaryExpr::new( |
1196 | | Arc::new(Column::new("a", 0)), |
1197 | | Operator::Plus, |
1198 | | Arc::new(Column::new("b", 1)), |
1199 | | )), |
1200 | | Operator::Plus, |
1201 | | Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), |
1202 | | )); |
1203 | | |
1204 | | let right_expr = Arc::new(BinaryExpr::new( |
1205 | | Arc::new(Column::new("a", 0)), |
1206 | | Operator::Minus, |
1207 | | Arc::new(Column::new("b", 1)), |
1208 | | )); |
1209 | | let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); |
1210 | | let mut graph = ExprIntervalGraph::try_new( |
1211 | | expr, |
1212 | | &Schema::new(vec![ |
1213 | | Field::new("a", DataType::Int32, true), |
1214 | | Field::new("b", DataType::Int32, true), |
1215 | | ]), |
1216 | | ) |
1217 | | .unwrap(); |
1218 | | // Define a test leaf node. |
1219 | | let leaf_node = Arc::new(BinaryExpr::new( |
1220 | | Arc::new(Column::new("a", 0)), |
1221 | | Operator::Plus, |
1222 | | Arc::new(Column::new("b", 1)), |
1223 | | )); |
1224 | | // Store the current node count. |
1225 | | let prev_node_count = graph.node_count(); |
1226 | | // Gather the index of node in the expression graph that match the test leaf node. |
1227 | | graph.gather_node_indices(&[leaf_node]); |
1228 | | // Store the final node count. |
1229 | | let final_node_count = graph.node_count(); |
1230 | | // Assert that the final node count is equal to the previous node count. |
1231 | | // This means we did not remove any node. |
1232 | | assert_eq!(prev_node_count, final_node_count); |
1233 | | Ok(()) |
1234 | | } |
1235 | | |
1236 | | #[test] |
1237 | | fn test_gather_node_indices_remove() -> Result<()> { |
1238 | | // Expression: a@0 + b@1 + 1 > y@0 - z@1, given a@0 + b@1. |
1239 | | // We expect to remove two nodes since we do not need a@0 and b@1. |
1240 | | let left_expr = Arc::new(BinaryExpr::new( |
1241 | | Arc::new(BinaryExpr::new( |
1242 | | Arc::new(Column::new("a", 0)), |
1243 | | Operator::Plus, |
1244 | | Arc::new(Column::new("b", 1)), |
1245 | | )), |
1246 | | Operator::Plus, |
1247 | | Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), |
1248 | | )); |
1249 | | |
1250 | | let right_expr = Arc::new(BinaryExpr::new( |
1251 | | Arc::new(Column::new("y", 0)), |
1252 | | Operator::Minus, |
1253 | | Arc::new(Column::new("z", 1)), |
1254 | | )); |
1255 | | let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); |
1256 | | let mut graph = ExprIntervalGraph::try_new( |
1257 | | expr, |
1258 | | &Schema::new(vec![ |
1259 | | Field::new("a", DataType::Int32, true), |
1260 | | Field::new("b", DataType::Int32, true), |
1261 | | Field::new("y", DataType::Int32, true), |
1262 | | Field::new("z", DataType::Int32, true), |
1263 | | ]), |
1264 | | ) |
1265 | | .unwrap(); |
1266 | | // Define a test leaf node. |
1267 | | let leaf_node = Arc::new(BinaryExpr::new( |
1268 | | Arc::new(Column::new("a", 0)), |
1269 | | Operator::Plus, |
1270 | | Arc::new(Column::new("b", 1)), |
1271 | | )); |
1272 | | // Store the current node count. |
1273 | | let prev_node_count = graph.node_count(); |
1274 | | // Gather the index of node in the expression graph that match the test leaf node. |
1275 | | graph.gather_node_indices(&[leaf_node]); |
1276 | | // Store the final node count. |
1277 | | let final_node_count = graph.node_count(); |
1278 | | // Assert that the final node count is two less than the previous node |
1279 | | // count; i.e. that we did remove two nodes. |
1280 | | assert_eq!(prev_node_count, final_node_count + 2); |
1281 | | Ok(()) |
1282 | | } |
1283 | | |
1284 | | #[test] |
1285 | | fn test_gather_node_indices_remove_one() -> Result<()> { |
1286 | | // Expression: a@0 + b@1 + 1 > a@0 - z@1, given a@0 + b@1. |
1287 | | // We expect to remove one node since we still need a@0 but not b@1. |
1288 | | let left_expr = Arc::new(BinaryExpr::new( |
1289 | | Arc::new(BinaryExpr::new( |
1290 | | Arc::new(Column::new("a", 0)), |
1291 | | Operator::Plus, |
1292 | | Arc::new(Column::new("b", 1)), |
1293 | | )), |
1294 | | Operator::Plus, |
1295 | | Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), |
1296 | | )); |
1297 | | |
1298 | | let right_expr = Arc::new(BinaryExpr::new( |
1299 | | Arc::new(Column::new("a", 0)), |
1300 | | Operator::Minus, |
1301 | | Arc::new(Column::new("z", 1)), |
1302 | | )); |
1303 | | let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); |
1304 | | let mut graph = ExprIntervalGraph::try_new( |
1305 | | expr, |
1306 | | &Schema::new(vec![ |
1307 | | Field::new("a", DataType::Int32, true), |
1308 | | Field::new("b", DataType::Int32, true), |
1309 | | Field::new("z", DataType::Int32, true), |
1310 | | ]), |
1311 | | ) |
1312 | | .unwrap(); |
1313 | | // Define a test leaf node. |
1314 | | let leaf_node = Arc::new(BinaryExpr::new( |
1315 | | Arc::new(Column::new("a", 0)), |
1316 | | Operator::Plus, |
1317 | | Arc::new(Column::new("b", 1)), |
1318 | | )); |
1319 | | // Store the current node count. |
1320 | | let prev_node_count = graph.node_count(); |
1321 | | // Gather the index of node in the expression graph that match the test leaf node. |
1322 | | graph.gather_node_indices(&[leaf_node]); |
1323 | | // Store the final node count. |
1324 | | let final_node_count = graph.node_count(); |
1325 | | // Assert that the final node count is one less than the previous node |
1326 | | // count; i.e. that we did remove one node. |
1327 | | assert_eq!(prev_node_count, final_node_count + 1); |
1328 | | Ok(()) |
1329 | | } |
1330 | | |
1331 | | #[test] |
1332 | | fn test_gather_node_indices_cannot_provide() -> Result<()> { |
1333 | | // Expression: a@0 + 1 + b@1 > y@0 - z@1 -> provide a@0 + b@1 |
1334 | | // TODO: We expect nodes a@0 and b@1 to be pruned, and intervals to be provided from the a@0 + b@1 node. |
1335 | | // However, we do not have an exact node for a@0 + b@1 due to the binary tree structure of the expressions. |
1336 | | // Pruning and interval providing for BinaryExpr expressions are more challenging without exact matches. |
1337 | | // Currently, we only support exact matches for BinaryExprs, but we plan to extend support beyond exact matches in the future. |
1338 | | let left_expr = Arc::new(BinaryExpr::new( |
1339 | | Arc::new(BinaryExpr::new( |
1340 | | Arc::new(Column::new("a", 0)), |
1341 | | Operator::Plus, |
1342 | | Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), |
1343 | | )), |
1344 | | Operator::Plus, |
1345 | | Arc::new(Column::new("b", 1)), |
1346 | | )); |
1347 | | |
1348 | | let right_expr = Arc::new(BinaryExpr::new( |
1349 | | Arc::new(Column::new("y", 0)), |
1350 | | Operator::Minus, |
1351 | | Arc::new(Column::new("z", 1)), |
1352 | | )); |
1353 | | let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); |
1354 | | let mut graph = ExprIntervalGraph::try_new( |
1355 | | expr, |
1356 | | &Schema::new(vec![ |
1357 | | Field::new("a", DataType::Int32, true), |
1358 | | Field::new("b", DataType::Int32, true), |
1359 | | Field::new("y", DataType::Int32, true), |
1360 | | Field::new("z", DataType::Int32, true), |
1361 | | ]), |
1362 | | ) |
1363 | | .unwrap(); |
1364 | | // Define a test leaf node. |
1365 | | let leaf_node = Arc::new(BinaryExpr::new( |
1366 | | Arc::new(Column::new("a", 0)), |
1367 | | Operator::Plus, |
1368 | | Arc::new(Column::new("b", 1)), |
1369 | | )); |
1370 | | // Store the current node count. |
1371 | | let prev_node_count = graph.node_count(); |
1372 | | // Gather the index of node in the expression graph that match the test leaf node. |
1373 | | graph.gather_node_indices(&[leaf_node]); |
1374 | | // Store the final node count. |
1375 | | let final_node_count = graph.node_count(); |
1376 | | // Assert that the final node count is equal to the previous node count (i.e., no node was pruned). |
1377 | | assert_eq!(prev_node_count, final_node_count); |
1378 | | Ok(()) |
1379 | | } |
1380 | | |
1381 | | #[test] |
1382 | | fn test_propagate_constraints_singleton_interval_at_right() -> Result<()> { |
1383 | | let expression = BinaryExpr::new( |
1384 | | Arc::new(Column::new("ts_column", 0)), |
1385 | | Operator::Plus, |
1386 | | Arc::new(Literal::new(ScalarValue::new_interval_mdn(0, 1, 321))), |
1387 | | ); |
1388 | | let parent = Interval::try_new( |
1389 | | // 15.10.2020 - 10:11:12.000_000_321 AM |
1390 | | ScalarValue::TimestampNanosecond(Some(1_602_756_672_000_000_321), None), |
1391 | | // 16.10.2020 - 10:11:12.000_000_321 AM |
1392 | | ScalarValue::TimestampNanosecond(Some(1_602_843_072_000_000_321), None), |
1393 | | )?; |
1394 | | let left_child = Interval::try_new( |
1395 | | // 10.10.2020 - 10:11:12 AM |
1396 | | ScalarValue::TimestampNanosecond(Some(1_602_324_672_000_000_000), None), |
1397 | | // 20.10.2020 - 10:11:12 AM |
1398 | | ScalarValue::TimestampNanosecond(Some(1_603_188_672_000_000_000), None), |
1399 | | )?; |
1400 | | let right_child = Interval::try_new( |
1401 | | // 1 day 321 ns |
1402 | | ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { |
1403 | | months: 0, |
1404 | | days: 1, |
1405 | | nanoseconds: 321, |
1406 | | })), |
1407 | | // 1 day 321 ns |
1408 | | ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { |
1409 | | months: 0, |
1410 | | days: 1, |
1411 | | nanoseconds: 321, |
1412 | | })), |
1413 | | )?; |
1414 | | let children = vec![&left_child, &right_child]; |
1415 | | let result = expression |
1416 | | .propagate_constraints(&parent, &children)? |
1417 | | .unwrap(); |
1418 | | |
1419 | | assert_eq!( |
1420 | | vec![ |
1421 | | Interval::try_new( |
1422 | | // 14.10.2020 - 10:11:12 AM |
1423 | | ScalarValue::TimestampNanosecond( |
1424 | | Some(1_602_670_272_000_000_000), |
1425 | | None |
1426 | | ), |
1427 | | // 15.10.2020 - 10:11:12 AM |
1428 | | ScalarValue::TimestampNanosecond( |
1429 | | Some(1_602_756_672_000_000_000), |
1430 | | None |
1431 | | ), |
1432 | | )?, |
1433 | | Interval::try_new( |
1434 | | // 1 day 321 ns |
1435 | | ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { |
1436 | | months: 0, |
1437 | | days: 1, |
1438 | | nanoseconds: 321, |
1439 | | })), |
1440 | | // 1 day 321 ns |
1441 | | ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { |
1442 | | months: 0, |
1443 | | days: 1, |
1444 | | nanoseconds: 321, |
1445 | | })), |
1446 | | )? |
1447 | | ], |
1448 | | result |
1449 | | ); |
1450 | | |
1451 | | Ok(()) |
1452 | | } |
1453 | | |
1454 | | #[test] |
1455 | | fn test_propagate_constraints_column_interval_at_left() -> Result<()> { |
1456 | | let expression = BinaryExpr::new( |
1457 | | Arc::new(Column::new("interval_column", 1)), |
1458 | | Operator::Plus, |
1459 | | Arc::new(Column::new("ts_column", 0)), |
1460 | | ); |
1461 | | let parent = Interval::try_new( |
1462 | | // 15.10.2020 - 10:11:12 AM |
1463 | | ScalarValue::TimestampMillisecond(Some(1_602_756_672_000), None), |
1464 | | // 16.10.2020 - 10:11:12 AM |
1465 | | ScalarValue::TimestampMillisecond(Some(1_602_843_072_000), None), |
1466 | | )?; |
1467 | | let right_child = Interval::try_new( |
1468 | | // 10.10.2020 - 10:11:12 AM |
1469 | | ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), |
1470 | | // 20.10.2020 - 10:11:12 AM |
1471 | | ScalarValue::TimestampMillisecond(Some(1_603_188_672_000), None), |
1472 | | )?; |
1473 | | let left_child = Interval::try_new( |
1474 | | // 2 days in milliseconds |
1475 | | ScalarValue::IntervalDayTime(Some(IntervalDayTime { |
1476 | | days: 0, |
1477 | | milliseconds: 172_800_000, |
1478 | | })), |
1479 | | // 10 days in milliseconds |
1480 | | ScalarValue::IntervalDayTime(Some(IntervalDayTime { |
1481 | | days: 0, |
1482 | | milliseconds: 864_000_000, |
1483 | | })), |
1484 | | )?; |
1485 | | let children = vec![&left_child, &right_child]; |
1486 | | let result = expression |
1487 | | .propagate_constraints(&parent, &children)? |
1488 | | .unwrap(); |
1489 | | |
1490 | | assert_eq!( |
1491 | | vec![ |
1492 | | Interval::try_new( |
1493 | | // 2 days in millisecond |
1494 | | ScalarValue::IntervalDayTime(Some(IntervalDayTime { |
1495 | | days: 0, |
1496 | | milliseconds: 172_800_000, |
1497 | | })), |
1498 | | // 6 days |
1499 | | ScalarValue::IntervalDayTime(Some(IntervalDayTime { |
1500 | | days: 0, |
1501 | | milliseconds: 518_400_000, |
1502 | | })), |
1503 | | )?, |
1504 | | Interval::try_new( |
1505 | | // 10.10.2020 - 10:11:12 AM |
1506 | | ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), |
1507 | | // 14.10.2020 - 10:11:12 AM |
1508 | | ScalarValue::TimestampMillisecond(Some(1_602_670_272_000), None), |
1509 | | )? |
1510 | | ], |
1511 | | result |
1512 | | ); |
1513 | | |
1514 | | Ok(()) |
1515 | | } |
1516 | | |
1517 | | #[test] |
1518 | | fn test_propagate_comparison() -> Result<()> { |
1519 | | // Each example below (Int64, then nanosecond timestamps without and with a time zone) follows the same pattern: |
1520 | | // `left` starts out unbounded: [?, ?], |
1521 | | // `right` is known to be exactly [1000, 1000], |
1522 | | // so requiring `left` < `right` to be certainly true adds no new knowledge about `right`, but narrows `left` to values below 1000: [?, 999]. |
1523 | | let left = Interval::make_unbounded(&DataType::Int64)?; |
1524 | | let right = Interval::make(Some(1000_i64), Some(1000_i64))?; |
1525 | | assert_eq!( |
1526 | | (Some(( |
1527 | | Interval::make(None, Some(999_i64))?, |
1528 | | Interval::make(Some(1000_i64), Some(1000_i64))?, |
1529 | | ))), |
1530 | | propagate_comparison( |
1531 | | &Operator::Lt, |
1532 | | &Interval::CERTAINLY_TRUE, |
1533 | | &left, |
1534 | | &right |
1535 | | )? |
1536 | | ); |
1537 | | |
1538 | | let left = |
1539 | | Interval::make_unbounded(&DataType::Timestamp(TimeUnit::Nanosecond, None))?; |
1540 | | let right = Interval::try_new( |
1541 | | ScalarValue::TimestampNanosecond(Some(1000), None), |
1542 | | ScalarValue::TimestampNanosecond(Some(1000), None), |
1543 | | )?; |
1544 | | assert_eq!( |
1545 | | (Some(( |
1546 | | Interval::try_new( |
1547 | | ScalarValue::try_from(&DataType::Timestamp( |
1548 | | TimeUnit::Nanosecond, |
1549 | | None |
1550 | | )) |
1551 | | .unwrap(), |
1552 | | ScalarValue::TimestampNanosecond(Some(999), None), |
1553 | | )?, |
1554 | | Interval::try_new( |
1555 | | ScalarValue::TimestampNanosecond(Some(1000), None), |
1556 | | ScalarValue::TimestampNanosecond(Some(1000), None), |
1557 | | )? |
1558 | | ))), |
1559 | | propagate_comparison( |
1560 | | &Operator::Lt, |
1561 | | &Interval::CERTAINLY_TRUE, |
1562 | | &left, |
1563 | | &right |
1564 | | )? |
1565 | | ); |
1566 | | |
1567 | | let left = Interval::make_unbounded(&DataType::Timestamp( |
1568 | | TimeUnit::Nanosecond, |
1569 | | Some("+05:00".into()), |
1570 | | ))?; |
1571 | | let right = Interval::try_new( |
1572 | | ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), |
1573 | | ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), |
1574 | | )?; |
1575 | | assert_eq!( |
1576 | | (Some(( |
1577 | | Interval::try_new( |
1578 | | ScalarValue::try_from(&DataType::Timestamp( |
1579 | | TimeUnit::Nanosecond, |
1580 | | Some("+05:00".into()), |
1581 | | )) |
1582 | | .unwrap(), |
1583 | | ScalarValue::TimestampNanosecond(Some(999), Some("+05:00".into())), |
1584 | | )?, |
1585 | | Interval::try_new( |
1586 | | ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), |
1587 | | ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), |
1588 | | )? |
1589 | | ))), |
1590 | | propagate_comparison( |
1591 | | &Operator::Lt, |
1592 | | &Interval::CERTAINLY_TRUE, |
1593 | | &left, |
1594 | | &right |
1595 | | )? |
1596 | | ); |
1597 | | |
1598 | | Ok(()) |
1599 | | } |
1600 | | |
1601 | | #[test] |
1602 | | fn test_propagate_or() -> Result<()> { |
1603 | | let expr = Arc::new(BinaryExpr::new( |
1604 | | Arc::new(Column::new("a", 0)), |
1605 | | Operator::Or, |
1606 | | Arc::new(Column::new("b", 1)), |
1607 | | )); |
1608 | | let parent = Interval::CERTAINLY_FALSE; |
1609 | | let children_set = vec![ |
1610 | | vec![&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN], |
1611 | | vec![&Interval::UNCERTAIN, &Interval::CERTAINLY_FALSE], |
1612 | | vec![&Interval::CERTAINLY_FALSE, &Interval::CERTAINLY_FALSE], |
1613 | | vec![&Interval::UNCERTAIN, &Interval::UNCERTAIN], |
1614 | | ]; |
1615 | | for children in children_set { |
1616 | | assert_eq!( |
1617 | | expr.propagate_constraints(&parent, &children)?.unwrap(), |
1618 | | vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE], |
1619 | | ); |
1620 | | } |
1621 | | |
1622 | | let parent = Interval::CERTAINLY_FALSE; |
1623 | | let children_set = vec![ |
1624 | | vec![&Interval::CERTAINLY_TRUE, &Interval::UNCERTAIN], |
1625 | | vec![&Interval::UNCERTAIN, &Interval::CERTAINLY_TRUE], |
1626 | | ]; |
1627 | | for children in children_set { |
1628 | | assert_eq!(expr.propagate_constraints(&parent, &children)?, None,); |
1629 | | } |
1630 | | |
1631 | | let parent = Interval::CERTAINLY_TRUE; |
1632 | | let children = vec![&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN]; |
1633 | | assert_eq!( |
1634 | | expr.propagate_constraints(&parent, &children)?.unwrap(), |
1635 | | vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE] |
1636 | | ); |
1637 | | |
1638 | | let parent = Interval::CERTAINLY_TRUE; |
1639 | | let children = vec![&Interval::UNCERTAIN, &Interval::UNCERTAIN]; |
1640 | | assert_eq!( |
1641 | | expr.propagate_constraints(&parent, &children)?.unwrap(), |
1642 | | // An empty vector means the child intervals are left unchanged (nothing could be narrowed). |
1643 | | vec![] |
1644 | | ); |
1645 | | |
1646 | | Ok(()) |
1647 | | } |
1648 | | |
1649 | | #[test] |
1650 | | fn test_propagate_certainly_false_and() -> Result<()> { |
1651 | | let expr = Arc::new(BinaryExpr::new( |
1652 | | Arc::new(Column::new("a", 0)), |
1653 | | Operator::And, |
1654 | | Arc::new(Column::new("b", 1)), |
1655 | | )); |
1656 | | let parent = Interval::CERTAINLY_FALSE; |
1657 | | let children_and_results_set = vec![ |
1658 | | ( |
1659 | | vec![&Interval::CERTAINLY_TRUE, &Interval::UNCERTAIN], |
1660 | | vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_FALSE], |
1661 | | ), |
1662 | | ( |
1663 | | vec![&Interval::UNCERTAIN, &Interval::CERTAINLY_TRUE], |
1664 | | vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE], |
1665 | | ), |
1666 | | ( |
1667 | | vec![&Interval::UNCERTAIN, &Interval::UNCERTAIN], |
1668 | | // An empty vector means the child intervals are left unchanged (nothing could be narrowed). |
1669 | | vec![], |
1670 | | ), |
1671 | | ( |
1672 | | vec![&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN], |
1673 | | vec![], |
1674 | | ), |
1675 | | ]; |
1676 | | for (children, result) in children_and_results_set { |
1677 | | assert_eq!( |
1678 | | expr.propagate_constraints(&parent, &children)?.unwrap(), |
1679 | | result |
1680 | | ); |
1681 | | } |
1682 | | |
1683 | | Ok(()) |
1684 | | } |
1685 | | } |