Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/intervals/cp_solver.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Constraint propagator/solver for custom PhysicalExpr graphs.
19
20
use std::collections::HashSet;
21
use std::fmt::{Display, Formatter};
22
use std::sync::Arc;
23
24
use super::utils::{
25
    convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op,
26
};
27
use crate::expressions::Literal;
28
use crate::utils::{build_dag, ExprTreeNode};
29
use crate::PhysicalExpr;
30
31
use arrow_schema::{DataType, Schema};
32
use datafusion_common::{internal_err, Result};
33
use datafusion_expr::interval_arithmetic::{apply_operator, satisfy_greater, Interval};
34
use datafusion_expr::Operator;
35
36
use petgraph::graph::NodeIndex;
37
use petgraph::stable_graph::{DefaultIx, StableGraph};
38
use petgraph::visit::{Bfs, Dfs, DfsPostOrder, EdgeRef};
39
use petgraph::Outgoing;
40
41
// Interval arithmetic provides a way to perform mathematical operations on
42
// intervals, which represent a range of possible values rather than a single
43
// point value. This allows for the propagation of ranges through mathematical
44
// operations, and can be used to compute bounds for a complicated expression.
45
// The key idea is that by breaking down a complicated expression into simpler
46
// terms, and then combining the bounds for those simpler terms, one can
47
// obtain bounds for the overall expression.
48
//
49
// For example, consider a mathematical expression such as x^2 + y = 4. Since
50
// it would be a binary tree in [PhysicalExpr] notation, this type of
51
// hierarchical computation is well-suited for a graph based implementation.
52
// In such an implementation, an equation system f(x) = 0 is represented by a
53
// directed acyclic expression graph (DAEG).
54
//
55
// In order to use interval arithmetic to compute bounds for this expression,
56
// one would first determine intervals that represent the possible values of x
57
// and y. Let's say that the interval for x is [1, 2] and the interval for y
58
// is [-3, 1]. In the chart below, you can see how the computation takes place.
59
//
60
// This way of using interval arithmetic to compute bounds for a complex
61
// expression by combining the bounds for the constituent terms within the
62
// original expression allows us to reason about the range of possible values
63
// of the expression. This information later can be used in range pruning of
64
// the provably unnecessary parts of `RecordBatch`es.
65
//
66
// References
67
// 1 - Kabak, Mehmet Ozan. Analog Circuit Start-Up Behavior Analysis: An Interval
68
// Arithmetic Based Approach, Chapter 4. Stanford University, 2015.
69
// 2 - Moore, Ramon E. Interval analysis. Vol. 4. Englewood Cliffs: Prentice-Hall, 1966.
70
// 3 - F. Messine, "Deterministic global optimization using interval constraint
71
// propagation techniques," RAIRO-Operations Research, vol. 38, no. 04,
72
// pp. 277–293, 2004.
73
//
74
// ``` text
75
// Computing bounds for an expression using interval arithmetic.           Constraint propagation through a top-down evaluation of the expression
76
//                                                                         graph using inverse semantics.
77
//
78
//                                                                                 [-2, 5] ∩ [4, 4] = [4, 4]              [4, 4]
79
//             +-----+                        +-----+                                      +-----+                        +-----+
80
//        +----|  +  |----+              +----|  +  |----+                            +----|  +  |----+              +----|  +  |----+
81
//        |    |     |    |              |    |     |    |                            |    |     |    |              |    |     |    |
82
//        |    +-----+    |              |    +-----+    |                            |    +-----+    |              |    +-----+    |
83
//        |               |              |               |                            |               |              |               |
84
//    +-----+           +-----+      +-----+           +-----+                    +-----+           +-----+      +-----+           +-----+
85
//    |   2 |           |  y  |      |   2 | [1, 4]    |  y  |                    |   2 | [1, 4]    |  y  |      |   2 | [1, 4]    |  y  | [0, 1]*
86
//    |[.]  |           |     |      |[.]  |           |     |                    |[.]  |           |     |      |[.]  |           |     |
87
//    +-----+           +-----+      +-----+           +-----+                    +-----+           +-----+      +-----+           +-----+
88
//       |                              |                                            |              [-3, 1]         |
89
//       |                              |                                            |                              |
90
//     +---+                          +---+                                        +---+                          +---+
91
//     | x | [1, 2]                   | x | [1, 2]                                 | x | [1, 2]                   | x | [1, 2]
92
//     +---+                          +---+                                        +---+                          +---+
93
//
94
//  (a) Bottom-up evaluation: Step1 (b) Bottom-up evaluation: Step2             (a) Top-down propagation: Step1 (b) Top-down propagation: Step2
95
//
96
//                                        [1 - 3, 4 + 1] = [-2, 5]                                                    [1 - 3, 4 + 1] = [-2, 5]
97
//             +-----+                        +-----+                                      +-----+                        +-----+
98
//        +----|  +  |----+              +----|  +  |----+                            +----|  +  |----+              +----|  +  |----+
99
//        |    |     |    |              |    |     |    |                            |    |     |    |              |    |     |    |
100
//        |    +-----+    |              |    +-----+    |                            |    +-----+    |              |    +-----+    |
101
//        |               |              |               |                            |               |              |               |
102
//    +-----+           +-----+      +-----+           +-----+                    +-----+           +-----+      +-----+           +-----+
103
//    |   2 |[1, 4]     |  y  |      |   2 |[1, 4]     |  y  |                    |   2 |[3, 4]**   |  y  |      |   2 |[1, 4]     |  y  |
104
//    |[.]  |           |     |      |[.]  |           |     |                    |[.]  |           |     |      |[.]  |           |     |
105
//    +-----+           +-----+      +-----+           +-----+                    +-----+           +-----+      +-----+           +-----+
106
//       |              [-3, 1]         |              [-3, 1]                       |              [0, 1]          |              [-3, 1]
107
//       |                              |                                            |                              |
108
//     +---+                          +---+                                        +---+                          +---+
109
//     | x | [1, 2]                   | x | [1, 2]                                 | x | [1, 2]                   | x | [sqrt(3), 2]***
110
//     +---+                          +---+                                        +---+                          +---+
111
//
112
//  (c) Bottom-up evaluation: Step3 (d) Bottom-up evaluation: Step4             (c) Top-down propagation: Step3  (d) Top-down propagation: Step4
113
//
114
//                                                                             * [-3, 1] ∩ ([4, 4] - [1, 4]) = [0, 1]
115
//                                                                             ** [1, 4] ∩ ([4, 4] - [0, 1]) = [3, 4]
116
//                                                                             *** [1, 2] ∩ [sqrt(3), sqrt(4)] = [sqrt(3), 2]
117
// ```
118
119
/// This object implements a directed acyclic expression graph (DAEG) that
120
/// is used to compute ranges for expressions through interval arithmetic.
121
#[derive(Clone, Debug)]
122
pub struct ExprIntervalGraph {
123
    graph: StableGraph<ExprIntervalGraphNode, usize>,
124
    root: NodeIndex,
125
}
126
127
impl ExprIntervalGraph {
128
    /// Estimate size of bytes including `Self`.
129
7.88k
    pub fn size(&self) -> usize {
130
7.88k
        let node_memory_usage = self.graph.node_count()
131
7.88k
            * (std::mem::size_of::<ExprIntervalGraphNode>()
132
7.88k
                + std::mem::size_of::<NodeIndex>());
133
7.88k
        let edge_memory_usage = self.graph.edge_count()
134
7.88k
            * (std::mem::size_of::<usize>() + std::mem::size_of::<NodeIndex>() * 2);
135
7.88k
136
7.88k
        std::mem::size_of_val(self) + node_memory_usage + edge_memory_usage
137
7.88k
    }
138
}
139
140
/// Outcome of a constraint-propagation run over an expression interval graph.
#[derive(Debug, PartialEq)]
pub enum PropagationResult {
    /// The requested range already contains the computed bounds, so there is
    /// nothing to tighten.
    CannotPropagate,
    /// The requested range cannot be met by the current intervals.
    Infeasible,
    /// Propagation completed and intervals were updated.
    Success,
}
147
148
/// This is a node in the DAEG; it encapsulates a reference to the actual
149
/// [`PhysicalExpr`] as well as an interval containing expression bounds.
150
#[derive(Clone, Debug)]
151
pub struct ExprIntervalGraphNode {
152
    expr: Arc<dyn PhysicalExpr>,
153
    interval: Interval,
154
}
155
156
impl Display for ExprIntervalGraphNode {
157
0
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
158
0
        write!(f, "{}", self.expr)
159
0
    }
160
}
161
162
impl ExprIntervalGraphNode {
163
    /// Constructs a new DAEG node with an [-∞, ∞] range.
164
10.2k
    pub fn new_unbounded(expr: Arc<dyn PhysicalExpr>, dt: &DataType) -> Result<Self> {
165
10.2k
        Interval::make_unbounded(dt)
166
10.2k
            .map(|interval| ExprIntervalGraphNode { expr, interval })
167
10.2k
    }
168
169
    /// Constructs a new DAEG node with the given range.
170
4.20k
    pub fn new_with_interval(expr: Arc<dyn PhysicalExpr>, interval: Interval) -> Self {
171
4.20k
        ExprIntervalGraphNode { expr, interval }
172
4.20k
    }
173
174
    /// Get the interval object representing the range of the expression.
175
199k
    pub fn interval(&self) -> &Interval {
176
199k
        &self.interval
177
199k
    }
178
179
    /// This function creates a DAEG node from DataFusion's [`ExprTreeNode`]
180
    /// object. Literals are created with definite, singleton intervals while
181
    /// any other expression starts with an indefinite interval ([-∞, ∞]).
182
14.4k
    pub fn make_node(node: &ExprTreeNode<NodeIndex>, schema: &Schema) -> Result<Self> {
183
14.4k
        let expr = Arc::clone(&node.expr);
184
14.4k
        if let Some(
literal4.20k
) = expr.as_any().downcast_ref::<Literal>() {
185
4.20k
            let value = literal.value();
186
4.20k
            Interval::try_new(value.clone(), value.clone())
187
4.20k
                .map(|interval| Self::new_with_interval(expr, interval))
188
        } else {
189
10.2k
            expr.data_type(schema)
190
10.2k
                .and_then(|dt| Self::new_unbounded(expr, &dt))
191
        }
192
14.4k
    }
193
}
194
195
/// Two graph nodes are considered equal when their expressions are equal;
/// the attached intervals are not compared.
impl PartialEq for ExprIntervalGraphNode {
    fn eq(&self, other: &Self) -> bool {
        // Delegates to the expression's own equality check.
        self.expr.eq(&other.expr)
    }
}
200
201
/// This function refines intervals `left_child` and `right_child` by applying
202
/// constraint propagation through `parent` via operation. The main idea is
203
/// that we can shrink ranges of variables x and y using parent interval p.
204
///
205
/// Assuming that x,y and p has ranges [xL, xU], [yL, yU], and [pL, pU], we
206
/// apply the following operations:
207
/// - For plus operation, specifically, we would first do
208
///     - [xL, xU] <- ([pL, pU] - [yL, yU]) ∩ [xL, xU], and then
209
///     - [yL, yU] <- ([pL, pU] - [xL, xU]) ∩ [yL, yU].
210
/// - For minus operation, specifically, we would first do
211
///     - [xL, xU] <- ([yL, yU] + [pL, pU]) ∩ [xL, xU], and then
212
///     - [yL, yU] <- ([xL, xU] - [pL, pU]) ∩ [yL, yU].
213
/// - For multiplication operation, specifically, we would first do
214
///     - [xL, xU] <- ([pL, pU] / [yL, yU]) ∩ [xL, xU], and then
215
///     - [yL, yU] <- ([pL, pU] / [xL, xU]) ∩ [yL, yU].
216
/// - For division operation, specifically, we would first do
217
///     - [xL, xU] <- ([yL, yU] * [pL, pU]) ∩ [xL, xU], and then
218
///     - [yL, yU] <- ([xL, xU] / [pL, pU]) ∩ [yL, yU].
219
22.1k
pub fn propagate_arithmetic(
220
22.1k
    op: &Operator,
221
22.1k
    parent: &Interval,
222
22.1k
    left_child: &Interval,
223
22.1k
    right_child: &Interval,
224
22.1k
) -> Result<Option<(Interval, Interval)>> {
225
22.1k
    let inverse_op = get_inverse_op(*op)
?0
;
226
22.1k
    match (left_child.data_type(), right_child.data_type()) {
227
        // If we have a child whose type is a time interval (i.e. DataType::Interval),
228
        // we need special handling since timestamp differencing results in a
229
        // Duration type.
230
        (DataType::Timestamp(..), DataType::Interval(_)) => {
231
1.34k
            propagate_time_interval_at_right(
232
1.34k
                left_child,
233
1.34k
                right_child,
234
1.34k
                parent,
235
1.34k
                op,
236
1.34k
                &inverse_op,
237
1.34k
            )
238
        }
239
        (DataType::Interval(_), DataType::Timestamp(..)) => {
240
0
            propagate_time_interval_at_left(
241
0
                left_child,
242
0
                right_child,
243
0
                parent,
244
0
                op,
245
0
                &inverse_op,
246
0
            )
247
        }
248
        _ => {
249
            // First, propagate to the left:
250
20.8k
            match apply_operator(&inverse_op, parent, right_child)
?0
251
20.8k
                .intersect(left_child)
?0
252
            {
253
                // Left is feasible:
254
20.8k
                Some(value) => Ok(
255
20.8k
                    // Propagate to the right using the new left.
256
20.8k
                    propagate_right(&value, parent, right_child, op, &inverse_op)
?0
257
20.8k
                        .map(|right| (value, right)),
258
20.8k
                ),
259
                // If the left child is infeasible, short-circuit.
260
0
                None => Ok(None),
261
            }
262
        }
263
    }
264
22.1k
}
265
266
/// This function refines intervals `left_child` and `right_child` by applying
267
/// comparison propagation through `parent` via operation. The main idea is
268
/// that we can shrink ranges of variables x and y using parent interval p.
269
/// Two intervals can be ordered in 6 ways for a Gt `>` operator:
270
/// ```text
271
///                           (1): Infeasible, short-circuit
272
/// left:   |        ================                                               |
273
/// right:  |                           ========================                    |
274
///
275
///                             (2): Update both interval
276
/// left:   |              ======================                                   |
277
/// right:  |                             ======================                    |
278
///                                          |
279
///                                          V
280
/// left:   |                             =======                                   |
281
/// right:  |                             =======                                   |
282
///
283
///                             (3): Update left interval
284
/// left:   |                  ==============================                       |
285
/// right:  |                           ==========                                  |
286
///                                          |
287
///                                          V
288
/// left:   |                           =====================                       |
289
/// right:  |                           ==========                                  |
290
///
291
///                             (4): Update right interval
292
/// left:   |                           ==========                                  |
293
/// right:  |                   ===========================                         |
294
///                                          |
295
///                                          V
296
/// left:   |                           ==========                                  |
297
/// right   |                   ==================                                  |
298
///
299
///                                   (5): No change
300
/// left:   |                       ============================                    |
301
/// right:  |               ===================                                     |
302
///
303
///                                   (6): No change
304
/// left:   |                                    ====================               |
305
/// right:  |                ===============                                        |
306
///
307
///         -inf --------------------------------------------------------------- +inf
308
/// ```
309
11.4k
pub fn propagate_comparison(
310
11.4k
    op: &Operator,
311
11.4k
    parent: &Interval,
312
11.4k
    left_child: &Interval,
313
11.4k
    right_child: &Interval,
314
11.4k
) -> Result<Option<(Interval, Interval)>> {
315
11.4k
    if parent == &Interval::CERTAINLY_TRUE {
316
11.4k
        match op {
317
5
            Operator::Eq => left_child.intersect(right_child).map(|result| {
318
5
                result.map(|intersection| (intersection.clone(), intersection))
319
5
            }),
320
5.07k
            Operator::Gt => satisfy_greater(left_child, right_child, true),
321
662
            Operator::GtEq => satisfy_greater(left_child, right_child, false),
322
5.07k
            Operator::Lt => satisfy_greater(right_child, left_child, true)
323
5.07k
                .map(|t| t.map(reverse_tuple)),
324
677
            Operator::LtEq => satisfy_greater(right_child, left_child, false)
325
677
                .map(|t| t.map(reverse_tuple)),
326
0
            _ => internal_err!(
327
0
                "The operator must be a comparison operator to propagate intervals"
328
0
            ),
329
        }
330
0
    } else if parent == &Interval::CERTAINLY_FALSE {
331
0
        match op {
332
            Operator::Eq => {
333
                // TODO: Propagation is not possible until we support interval sets.
334
0
                Ok(None)
335
            }
336
0
            Operator::Gt => satisfy_greater(right_child, left_child, false),
337
0
            Operator::GtEq => satisfy_greater(right_child, left_child, true),
338
0
            Operator::Lt => satisfy_greater(left_child, right_child, false)
339
0
                .map(|t| t.map(reverse_tuple)),
340
0
            Operator::LtEq => satisfy_greater(left_child, right_child, true)
341
0
                .map(|t| t.map(reverse_tuple)),
342
0
            _ => internal_err!(
343
0
                "The operator must be a comparison operator to propagate intervals"
344
0
            ),
345
        }
346
    } else {
347
        // Uncertainty cannot change any end-point of the intervals.
348
0
        Ok(None)
349
    }
350
11.4k
}
351
352
impl ExprIntervalGraph {
353
1.14k
    pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &Schema) -> Result<Self> {
354
        // Build the full graph:
355
1.14k
        let (root, graph) =
356
14.4k
            
build_dag(expr, &1.14k
|node| ExprIntervalGraphNode::make_node(node, schema)
)1.14k
?0
;
357
1.14k
        Ok(Self { graph, root })
358
1.14k
    }
359
360
0
    pub fn node_count(&self) -> usize {
361
0
        self.graph.node_count()
362
0
    }
363
364
    // Sometimes, we do not want to calculate and/or propagate intervals all
365
    // way down to leaf expressions. For example, assume that we have a
366
    // `SymmetricHashJoin` which has a child with an output ordering like:
367
    //
368
    // PhysicalSortExpr {
369
    //     expr: BinaryExpr('a', +, 'b'),
370
    //     sort_option: ..
371
    // }
372
    //
373
    // i.e. its output order comes from a clause like "ORDER BY a + b". In such
374
    // a case, we must calculate the interval for the BinaryExpr('a', +, 'b')
375
    // instead of the columns inside this BinaryExpr, because this interval
376
    // decides whether we prune or not. Therefore, children `PhysicalExpr`s of
377
    // this `BinaryExpr` may be pruned for performance. The figure below
378
    // explains this example visually.
379
    //
380
    // Note that we just remove the nodes from the DAEG, do not make any change
381
    // to the plan itself.
382
    //
383
    // ```text
384
    //
385
    //                                  +-----+                                          +-----+
386
    //                                  | GT  |                                          | GT  |
387
    //                         +--------|     |-------+                         +--------|     |-------+
388
    //                         |        +-----+       |                         |        +-----+       |
389
    //                         |                      |                         |                      |
390
    //                      +-----+                   |                      +-----+                   |
391
    //                      |Cast |                   |                      |Cast |                   |
392
    //                      |     |                   |             --\      |     |                   |
393
    //                      +-----+                   |       ----------     +-----+                   |
394
    //                         |                      |             --/         |                      |
395
    //                         |                      |                         |                      |
396
    //                      +-----+                +-----+                   +-----+                +-----+
397
    //                   +--|Plus |--+          +--|Plus |--+                |Plus |             +--|Plus |--+
398
    //                   |  |     |  |          |  |     |  |                |     |             |  |     |  |
399
    //  Prune from here  |  +-----+  |          |  +-----+  |                +-----+             |  +-----+  |
400
    //  ------------------------------------    |           |                                    |           |
401
    //                   |           |          |           |                                    |           |
402
    //                +-----+     +-----+    +-----+     +-----+                              +-----+     +-----+
403
    //                | a   |     |  b  |    |  c  |     |  2  |                              |  c  |     |  2  |
404
    //                |     |     |     |    |     |     |     |                              |     |     |     |
405
    //                +-----+     +-----+    +-----+     +-----+                              +-----+     +-----+
406
    //
407
    // ```
408
409
    /// This function associates stable node indices with [`PhysicalExpr`]s so
410
    /// that we can match `Arc<dyn PhysicalExpr>` and NodeIndex objects during
411
    /// membership tests.
412
1.14k
    pub fn gather_node_indices(
413
1.14k
        &mut self,
414
1.14k
        exprs: &[Arc<dyn PhysicalExpr>],
415
1.14k
    ) -> Vec<(Arc<dyn PhysicalExpr>, usize)> {
416
1.14k
        let graph = &self.graph;
417
1.14k
        let mut bfs = Bfs::new(graph, self.root);
418
1.14k
        // We collect the node indices (usize) of [PhysicalExpr]s in the order
419
1.14k
        // given by argument `exprs`. To preserve this order, we initialize each
420
1.14k
        // expression's node index with usize::MAX, and then find the corresponding
421
1.14k
        // node indices by traversing the graph.
422
1.14k
        let mut removals = vec![];
423
1.14k
        let mut expr_node_indices = exprs
424
1.14k
            .iter()
425
2.25k
            .map(|e| (Arc::clone(e), usize::MAX))
426
1.14k
            .collect::<Vec<_>>();
427
15.5k
        while let Some(
node14.4k
) = bfs.next(graph) {
428
            // Get the plan corresponding to this node:
429
14.4k
            let expr = &graph[node].expr;
430
            // If the current expression is among `exprs`, slate its children
431
            // for removal:
432
27.6k
            if let Some(
value2.25k
) =
exprs.iter().position(14.4k
|e| expr.eq(e)
)14.4k
{
433
                // Update the node index of the associated `PhysicalExpr`:
434
2.25k
                expr_node_indices[value].1 = node.index();
435
2.25k
                for 
edge128
in graph.edges_directed(node, Outgoing) {
436
128
                    // Slate the child for removal, do not remove immediately.
437
128
                    removals.push(edge.id());
438
128
                }
439
12.1k
            }
440
        }
441
1.26k
        for 
edge_idx128
in removals {
442
128
            self.graph.remove_edge(edge_idx);
443
128
        }
444
        // Get the set of node indices reachable from the root node:
445
1.14k
        let connected_nodes = self.connected_nodes();
446
1.14k
        // Remove nodes not connected to the root node:
447
1.14k
        self.graph
448
14.4k
            .retain_nodes(|_, index| connected_nodes.contains(&index));
449
1.14k
        expr_node_indices
450
1.14k
    }
451
452
    /// Returns the set of node indices reachable from the root node via a
453
    /// simple depth-first search.
454
1.14k
    fn connected_nodes(&self) -> HashSet<NodeIndex> {
455
1.14k
        let mut nodes = HashSet::new();
456
1.14k
        let mut dfs = Dfs::new(&self.graph, self.root);
457
15.4k
        while let Some(
node14.2k
) = dfs.next(&self.graph) {
458
14.2k
            nodes.insert(node);
459
14.2k
        }
460
1.14k
        nodes
461
1.14k
    }
462
463
    /// Updates intervals for all expressions in the DAEG by successive
464
    /// bottom-up and top-down traversals.
465
5.76k
    pub fn update_ranges(
466
5.76k
        &mut self,
467
5.76k
        leaf_bounds: &mut [(usize, Interval)],
468
5.76k
        given_range: Interval,
469
5.76k
    ) -> Result<PropagationResult> {
470
5.76k
        self.assign_intervals(leaf_bounds);
471
5.76k
        let bounds = self.evaluate_bounds()
?0
;
472
        // There are three possible cases to consider:
473
        // (1) given_range ⊇ bounds => Nothing to propagate
474
        // (2) ∅ ⊂ (given_range ∩ bounds) ⊂ bounds => Can propagate
475
        // (3) Disjoint sets => Infeasible
476
5.76k
        if given_range.contains(bounds)
?0
== Interval::CERTAINLY_TRUE {
477
            // First case:
478
2
            Ok(PropagationResult::CannotPropagate)
479
5.75k
        } else if bounds.contains(&given_range)
?0
!= Interval::CERTAINLY_FALSE {
480
            // Second case:
481
5.75k
            let result = self.propagate_constraints(given_range);
482
5.75k
            self.update_intervals(leaf_bounds);
483
5.75k
            result
484
        } else {
485
            // Third case:
486
2
            Ok(PropagationResult::Infeasible)
487
        }
488
5.76k
    }
489
490
    /// This function assigns given ranges to expressions in the DAEG.
491
    /// The argument `assignments` associates indices of sought expressions
492
    /// with their corresponding new ranges.
493
5.76k
    pub fn assign_intervals(&mut self, assignments: &[(usize, Interval)]) {
494
17.2k
        for (
index, interval11.4k
) in assignments {
495
11.4k
            let node_index = NodeIndex::from(*index as DefaultIx);
496
11.4k
            self.graph[node_index].interval = interval.clone();
497
11.4k
        }
498
5.76k
    }
499
500
    /// This function fetches ranges of expressions from the DAEG. The argument
501
    /// `assignments` associates indices of sought expressions with their ranges,
502
    /// which this function modifies to reflect the intervals in the DAEG.
503
5.75k
    pub fn update_intervals(&self, assignments: &mut [(usize, Interval)]) {
504
11.4k
        for (index, interval) in 
assignments.iter_mut()5.75k
{
505
11.4k
            let node_index = NodeIndex::from(*index as DefaultIx);
506
11.4k
            *interval = self.graph[node_index].interval.clone();
507
11.4k
        }
508
5.75k
    }
509
510
    /// Computes bounds for an expression using interval arithmetic via a
511
    /// bottom-up traversal.
512
    ///
513
    /// # Arguments
514
    /// * `leaf_bounds` - &[(usize, Interval)]. Provide NodeIndex, Interval tuples for leaf variables.
515
    ///
516
    /// # Examples
517
    ///
518
    /// ```
519
    /// use arrow::datatypes::DataType;
520
    /// use arrow::datatypes::Field;
521
    /// use arrow::datatypes::Schema;
522
    /// use datafusion_common::ScalarValue;
523
    /// use datafusion_expr::interval_arithmetic::Interval;
524
    /// use datafusion_expr::Operator;
525
    /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
526
    /// use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
527
    /// use datafusion_physical_expr::PhysicalExpr;
528
    /// use std::sync::Arc;
529
    ///
530
    /// let expr = Arc::new(BinaryExpr::new(
531
    ///     Arc::new(Column::new("gnz", 0)),
532
    ///     Operator::Plus,
533
    ///     Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
534
    /// ));
535
    ///
536
    /// let schema = Schema::new(vec![Field::new("gnz".to_string(), DataType::Int32, true)]);
537
    ///
538
    /// let mut graph = ExprIntervalGraph::try_new(expr, &schema).unwrap();
539
    /// // Do it once, while constructing.
540
    /// let node_indices = graph
541
    ///     .gather_node_indices(&[Arc::new(Column::new("gnz", 0))]);
542
    /// let left_index = node_indices.get(0).unwrap().1;
543
    ///
544
    /// // Provide intervals for leaf variables (here, there is only one).
545
    /// let intervals = vec![(
546
    ///     left_index,
547
    ///     Interval::make(Some(10), Some(20)).unwrap(),
548
    /// )];
549
    ///
550
    /// // Evaluate bounds for the composite expression:
551
    /// graph.assign_intervals(&intervals);
552
    /// assert_eq!(
553
    ///     graph.evaluate_bounds().unwrap(),
554
    ///     &Interval::make(Some(20), Some(30)).unwrap(),
555
    /// )
556
    /// ```
557
5.76k
    pub fn evaluate_bounds(&mut self) -> Result<&Interval> {
558
5.76k
        let mut dfs = DfsPostOrder::new(&self.graph, self.root);
559
78.9k
        while let Some(
node73.2k
) = dfs.next(&self.graph) {
560
73.2k
            let neighbors = self.graph.neighbors_directed(node, Outgoing);
561
73.2k
            let mut children_intervals = neighbors
562
79.5k
                .map(|child| self.graph[child].interval())
563
73.2k
                .collect::<Vec<_>>();
564
73.2k
            // If the current expression is a leaf, its interval should already
565
73.2k
            // be set externally, just continue with the evaluation procedure:
566
73.2k
            if !children_intervals.is_empty() {
567
                // Reverse to align with `PhysicalExpr`'s children:
568
40.1k
                children_intervals.reverse();
569
40.1k
                self.graph[node].interval =
570
40.1k
                    self.graph[node].expr.evaluate_bounds(&children_intervals)
?0
;
571
33.0k
            }
572
        }
573
5.76k
        Ok(&self.graph[self.root].interval)
574
5.76k
    }
575
576
    /// Updates/shrinks bounds for leaf expressions using interval arithmetic
577
    /// via a top-down traversal.
578
5.75k
    fn propagate_constraints(
579
5.75k
        &mut self,
580
5.75k
        given_range: Interval,
581
5.75k
    ) -> Result<PropagationResult> {
582
5.75k
        let mut bfs = Bfs::new(&self.graph, self.root);
583
584
        // Adjust the root node with the given range:
585
5.75k
        if let Some(interval) = self.graph[self.root].interval.intersect(given_range)
?0
{
586
5.75k
            self.graph[self.root].interval = interval;
587
5.75k
        } else {
588
0
            return Ok(PropagationResult::Infeasible);
589
        }
590
591
78.9k
        while let Some(
node73.1k
) = bfs.next(&self.graph) {
592
73.1k
            let neighbors = self.graph.neighbors_directed(node, Outgoing);
593
73.1k
            let mut children = neighbors.collect::<Vec<_>>();
594
73.1k
            // If the current expression is a leaf, its range is now final.
595
73.1k
            // So, just continue with the propagation procedure:
596
73.1k
            if children.is_empty() {
597
33.0k
                continue;
598
40.1k
            }
599
40.1k
            // Reverse to align with `PhysicalExpr`'s children:
600
40.1k
            children.reverse();
601
40.1k
            let children_intervals = children
602
40.1k
                .iter()
603
79.5k
                .map(|child| self.graph[*child].interval())
604
40.1k
                .collect::<Vec<_>>();
605
40.1k
            let node_interval = self.graph[node].interval();
606
40.1k
            let propagated_intervals = self.graph[node]
607
40.1k
                .expr
608
40.1k
                .propagate_constraints(node_interval, &children_intervals)
?0
;
609
40.1k
            if let Some(propagated_intervals) = propagated_intervals {
610
79.5k
                for (child, interval) in 
children.into_iter().zip(propagated_intervals)40.1k
{
611
79.5k
                    self.graph[child].interval = interval;
612
79.5k
                }
613
            } else {
614
                // The constraint is infeasible, report:
615
0
                return Ok(PropagationResult::Infeasible);
616
            }
617
        }
618
5.75k
        Ok(PropagationResult::Success)
619
5.75k
    }
620
621
    /// Returns the interval associated with the node at the given `index`.
622
33
    pub fn get_interval(&self, index: usize) -> Interval {
623
33
        self.graph[NodeIndex::new(index)].interval.clone()
624
33
    }
625
}
626
627
/// This is a subfunction of the `propagate_arithmetic` function that propagates to the right child.
628
22.1k
fn propagate_right(
629
22.1k
    left: &Interval,
630
22.1k
    parent: &Interval,
631
22.1k
    right: &Interval,
632
22.1k
    op: &Operator,
633
22.1k
    inverse_op: &Operator,
634
22.1k
) -> Result<Option<Interval>> {
635
22.1k
    match op {
636
13.3k
        Operator::Minus => apply_operator(op, left, parent),
637
8.85k
        Operator::Plus => apply_operator(inverse_op, parent, left),
638
0
        Operator::Divide => apply_operator(op, left, parent),
639
0
        Operator::Multiply => apply_operator(inverse_op, parent, left),
640
0
        _ => internal_err!("Interval arithmetic does not support the operator {}", op),
641
0
    }?
642
22.1k
    .intersect(right)
643
22.1k
}
644
645
/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`],
646
/// if there exists a `timestamp - timestamp` operation, the result would be
647
/// of type `Duration`. However, we may encounter a situation where a time interval
648
/// is involved in an arithmetic operation with a `Duration` type. This function
649
/// offers special handling for such cases, where the time interval resides on
650
/// the left side of the operation.
651
0
fn propagate_time_interval_at_left(
652
0
    left_child: &Interval,
653
0
    right_child: &Interval,
654
0
    parent: &Interval,
655
0
    op: &Operator,
656
0
    inverse_op: &Operator,
657
0
) -> Result<Option<(Interval, Interval)>> {
658
    // We check if the child's time interval(s) has a non-zero month or day field(s).
659
    // If so, we return it as is without propagating. Otherwise, we first convert
660
    // the time intervals to the `Duration` type, then propagate, and then convert
661
    // the bounds to time intervals again.
662
0
    let result = if let Some(duration) = convert_interval_type_to_duration(left_child) {
663
0
        match apply_operator(inverse_op, parent, right_child)?.intersect(duration)? {
664
0
            Some(value) => {
665
0
                let left = convert_duration_type_to_interval(&value);
666
0
                let right = propagate_right(&value, parent, right_child, op, inverse_op)?;
667
0
                match (left, right) {
668
0
                    (Some(left), Some(right)) => Some((left, right)),
669
0
                    _ => None,
670
                }
671
            }
672
0
            None => None,
673
        }
674
    } else {
675
0
        propagate_right(left_child, parent, right_child, op, inverse_op)?
676
0
            .map(|right| (left_child.clone(), right))
677
    };
678
0
    Ok(result)
679
0
}
680
681
/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`],
682
/// if there exists a `timestamp - timestamp` operation, the result would be
683
/// of type `Duration`. However, we may encounter a situation where a time interval
684
/// is involved in an arithmetic operation with a `Duration` type. This function
685
/// offers special handling for such cases, where the time interval resides on
686
/// the right side of the operation.
687
1.34k
fn propagate_time_interval_at_right(
688
1.34k
    left_child: &Interval,
689
1.34k
    right_child: &Interval,
690
1.34k
    parent: &Interval,
691
1.34k
    op: &Operator,
692
1.34k
    inverse_op: &Operator,
693
1.34k
) -> Result<Option<(Interval, Interval)>> {
694
    // We check if the child's time interval(s) has a non-zero month or day field(s).
695
    // If so, we return it as is without propagating. Otherwise, we first convert
696
    // the time intervals to the `Duration` type, then propagate, and then convert
697
    // the bounds to time intervals again.
698
1.34k
    let result = if let Some(duration) = convert_interval_type_to_duration(right_child) {
699
1.34k
        match apply_operator(inverse_op, parent, &duration)
?0
.intersect(left_child)
?0
{
700
1.34k
            Some(value) => {
701
1.34k
                propagate_right(left_child, parent, &duration, op, inverse_op)
?0
702
1.34k
                    .and_then(|right| convert_duration_type_to_interval(&right))
703
1.34k
                    .map(|right| (value, right))
704
            }
705
0
            None => None,
706
        }
707
    } else {
708
0
        apply_operator(inverse_op, parent, right_child)?
709
0
            .intersect(left_child)?
710
0
            .map(|value| (value, right_child.clone()))
711
    };
712
1.34k
    Ok(result)
713
1.34k
}
714
715
5.75k
/// Swaps the components of a pair: `(first, second)` becomes `(second, first)`.
fn reverse_tuple<T, U>(pair: (T, U)) -> (U, T) {
    let (first, second) = pair;
    (second, first)
}
718
719
#[cfg(test)]
720
mod tests {
721
    use super::*;
722
    use crate::expressions::{BinaryExpr, Column};
723
    use crate::intervals::test_utils::gen_conjunctive_numerical_expr;
724
725
    use arrow::datatypes::TimeUnit;
726
    use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
727
    use arrow_schema::Field;
728
    use datafusion_common::ScalarValue;
729
730
    use itertools::Itertools;
731
    use rand::rngs::StdRng;
732
    use rand::{Rng, SeedableRng};
733
    use rstest::*;
734
735
    #[allow(clippy::too_many_arguments)]
736
    fn experiment(
737
        expr: Arc<dyn PhysicalExpr>,
738
        exprs_with_interval: (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>),
739
        left_interval: Interval,
740
        right_interval: Interval,
741
        left_expected: Interval,
742
        right_expected: Interval,
743
        result: PropagationResult,
744
        schema: &Schema,
745
    ) -> Result<()> {
746
        let col_stats = vec![
747
            (Arc::clone(&exprs_with_interval.0), left_interval),
748
            (Arc::clone(&exprs_with_interval.1), right_interval),
749
        ];
750
        let expected = vec![
751
            (Arc::clone(&exprs_with_interval.0), left_expected),
752
            (Arc::clone(&exprs_with_interval.1), right_expected),
753
        ];
754
        let mut graph = ExprIntervalGraph::try_new(expr, schema)?;
755
        let expr_indexes = graph.gather_node_indices(
756
            &col_stats.iter().map(|(e, _)| Arc::clone(e)).collect_vec(),
757
        );
758
759
        let mut col_stat_nodes = col_stats
760
            .iter()
761
            .zip(expr_indexes.iter())
762
            .map(|((_, interval), (_, index))| (*index, interval.clone()))
763
            .collect_vec();
764
        let expected_nodes = expected
765
            .iter()
766
            .zip(expr_indexes.iter())
767
            .map(|((_, interval), (_, index))| (*index, interval.clone()))
768
            .collect_vec();
769
770
        let exp_result =
771
            graph.update_ranges(&mut col_stat_nodes[..], Interval::CERTAINLY_TRUE)?;
772
        assert_eq!(exp_result, result);
773
        col_stat_nodes.iter().zip(expected_nodes.iter()).for_each(
774
            |((_, calculated_interval_node), (_, expected))| {
775
                // NOTE: These randomized tests only check for conservative containment,
776
                // not openness/closedness of endpoints.
777
778
                // Calculated bounds are relaxed by 1 to cover all strict and
779
                // and non-strict comparison cases since we have only closed bounds.
780
                let one = ScalarValue::new_one(&expected.data_type()).unwrap();
781
                assert!(
782
                    calculated_interval_node.lower()
783
                        <= &expected.lower().add(&one).unwrap(),
784
                    "{}",
785
                    format!(
786
                        "Calculated {} must be less than or equal {}",
787
                        calculated_interval_node.lower(),
788
                        expected.lower()
789
                    )
790
                );
791
                assert!(
792
                    calculated_interval_node.upper()
793
                        >= &expected.upper().sub(&one).unwrap(),
794
                    "{}",
795
                    format!(
796
                        "Calculated {} must be greater than or equal {}",
797
                        calculated_interval_node.upper(),
798
                        expected.upper()
799
                    )
800
                );
801
            },
802
        );
803
        Ok(())
804
    }
805
806
    macro_rules! generate_cases {
807
        ($FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
808
            fn $FUNC_NAME<const ASC: bool>(
809
                expr: Arc<dyn PhysicalExpr>,
810
                left_col: Arc<dyn PhysicalExpr>,
811
                right_col: Arc<dyn PhysicalExpr>,
812
                seed: u64,
813
                expr_left: $TYPE,
814
                expr_right: $TYPE,
815
            ) -> Result<()> {
816
                let mut r = StdRng::seed_from_u64(seed);
817
818
                let (left_given, right_given, left_expected, right_expected) = if ASC {
819
                    let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE));
820
                    let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE));
821
                    (
822
                        (Some(left), None),
823
                        (Some(right), None),
824
                        (Some(<$TYPE>::max(left, right + expr_left)), None),
825
                        (Some(<$TYPE>::max(right, left + expr_right)), None),
826
                    )
827
                } else {
828
                    let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE));
829
                    let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE));
830
                    (
831
                        (None, Some(left)),
832
                        (None, Some(right)),
833
                        (None, Some(<$TYPE>::min(left, right + expr_left))),
834
                        (None, Some(<$TYPE>::min(right, left + expr_right))),
835
                    )
836
                };
837
838
                experiment(
839
                    expr,
840
                    (left_col.clone(), right_col.clone()),
841
                    Interval::make(left_given.0, left_given.1).unwrap(),
842
                    Interval::make(right_given.0, right_given.1).unwrap(),
843
                    Interval::make(left_expected.0, left_expected.1).unwrap(),
844
                    Interval::make(right_expected.0, right_expected.1).unwrap(),
845
                    PropagationResult::Success,
846
                    &Schema::new(vec![
847
                        Field::new(
848
                            left_col.as_any().downcast_ref::<Column>().unwrap().name(),
849
                            DataType::$SCALAR,
850
                            true,
851
                        ),
852
                        Field::new(
853
                            right_col.as_any().downcast_ref::<Column>().unwrap().name(),
854
                            DataType::$SCALAR,
855
                            true,
856
                        ),
857
                    ]),
858
                )
859
            }
860
        };
861
    }
862
    generate_cases!(generate_case_i32, i32, Int32);
863
    generate_cases!(generate_case_i64, i64, Int64);
864
    generate_cases!(generate_case_f32, f32, Float32);
865
    generate_cases!(generate_case_f64, f64, Float64);
866
867
    /// `left_watermark + 5 > right_watermark` cannot hold when
    /// `left_watermark` lies in `[10, 20]` and `right_watermark` is at least
    /// `100`, so propagation must report `PropagationResult::Infeasible`.
    #[test]
    fn testing_not_possible() -> Result<()> {
        let left_col = Arc::new(Column::new("left_watermark", 0));
        let right_col = Arc::new(Column::new("right_watermark", 0));

        // left_watermark + 5 > right_watermark
        let left_plus_5 = Arc::new(BinaryExpr::new(
            Arc::clone(&left_col) as Arc<dyn PhysicalExpr>,
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
        ));
        let expr = Arc::new(BinaryExpr::new(
            left_plus_5,
            Operator::Gt,
            Arc::clone(&right_col) as Arc<dyn PhysicalExpr>,
        ));
        experiment(
            expr,
            (
                Arc::clone(&left_col) as Arc<dyn PhysicalExpr>,
                Arc::clone(&right_col) as Arc<dyn PhysicalExpr>,
            ),
            Interval::make(Some(10_i32), Some(20_i32))?,
            Interval::make(Some(100), None)?,
            Interval::make(Some(10), Some(20))?,
            Interval::make(Some(100), None)?,
            PropagationResult::Infeasible,
            &Schema::new(vec![
                Field::new(
                    left_col.as_any().downcast_ref::<Column>().unwrap().name(),
                    DataType::Int32,
                    true,
                ),
                Field::new(
                    right_col.as_any().downcast_ref::<Column>().unwrap().name(),
                    DataType::Int32,
                    true,
                ),
            ]),
        )
    }
908
909
    macro_rules! integer_float_case_1 {
910
        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
911
            #[rstest]
912
            #[test]
913
            fn $TEST_FUNC_NAME(
914
                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
915
                seed: u64,
916
                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
917
                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
918
            ) -> Result<()> {
919
                let left_col = Arc::new(Column::new("left_watermark", 0));
920
                let right_col = Arc::new(Column::new("right_watermark", 0));
921
922
                // left_watermark + 1 > right_watermark + 11 AND left_watermark + 3 < right_watermark + 33
923
                let expr = gen_conjunctive_numerical_expr(
924
                    left_col.clone(),
925
                    right_col.clone(),
926
                    (
927
                        Operator::Plus,
928
                        Operator::Plus,
929
                        Operator::Plus,
930
                        Operator::Plus,
931
                    ),
932
                    ScalarValue::$SCALAR(Some(1 as $TYPE)),
933
                    ScalarValue::$SCALAR(Some(11 as $TYPE)),
934
                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
935
                    ScalarValue::$SCALAR(Some(33 as $TYPE)),
936
                    (greater_op, less_op),
937
                );
938
                // l > r + 10 AND r > l - 30
939
                let l_gt_r = 10 as $TYPE;
940
                let r_gt_l = -30 as $TYPE;
941
                $GENERATE_CASE_FUNC_NAME::<true>(
942
                    expr.clone(),
943
                    left_col.clone(),
944
                    right_col.clone(),
945
                    seed,
946
                    l_gt_r,
947
                    r_gt_l,
948
                )?;
949
                // Descending tests
950
                // r < l - 10 AND l < r + 30
951
                let r_lt_l = -l_gt_r;
952
                let l_lt_r = -r_gt_l;
953
                $GENERATE_CASE_FUNC_NAME::<false>(
954
                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
955
                )
956
            }
957
        };
958
    }
959
960
    integer_float_case_1!(case_1_i32, generate_case_i32, i32, Int32);
961
    integer_float_case_1!(case_1_i64, generate_case_i64, i64, Int64);
962
    integer_float_case_1!(case_1_f64, generate_case_f64, f64, Float64);
963
    integer_float_case_1!(case_1_f32, generate_case_f32, f32, Float32);
964
965
    macro_rules! integer_float_case_2 {
966
        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
967
            #[rstest]
968
            #[test]
969
            fn $TEST_FUNC_NAME(
970
                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
971
                seed: u64,
972
                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
973
                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
974
            ) -> Result<()> {
975
                let left_col = Arc::new(Column::new("left_watermark", 0));
976
                let right_col = Arc::new(Column::new("right_watermark", 0));
977
978
                // left_watermark - 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10
979
                let expr = gen_conjunctive_numerical_expr(
980
                    left_col.clone(),
981
                    right_col.clone(),
982
                    (
983
                        Operator::Minus,
984
                        Operator::Plus,
985
                        Operator::Plus,
986
                        Operator::Plus,
987
                    ),
988
                    ScalarValue::$SCALAR(Some(1 as $TYPE)),
989
                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
990
                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
991
                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
992
                    (greater_op, less_op),
993
                );
994
                // l > r + 6 AND r > l - 7
995
                let l_gt_r = 6 as $TYPE;
996
                let r_gt_l = -7 as $TYPE;
997
                $GENERATE_CASE_FUNC_NAME::<true>(
998
                    expr.clone(),
999
                    left_col.clone(),
1000
                    right_col.clone(),
1001
                    seed,
1002
                    l_gt_r,
1003
                    r_gt_l,
1004
                )?;
1005
                // Descending tests
1006
                // r < l - 6 AND l < r + 7
1007
                let r_lt_l = -l_gt_r;
1008
                let l_lt_r = -r_gt_l;
1009
                $GENERATE_CASE_FUNC_NAME::<false>(
1010
                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
1011
                )
1012
            }
1013
        };
1014
    }
1015
1016
    integer_float_case_2!(case_2_i32, generate_case_i32, i32, Int32);
1017
    integer_float_case_2!(case_2_i64, generate_case_i64, i64, Int64);
1018
    integer_float_case_2!(case_2_f64, generate_case_f64, f64, Float64);
1019
    integer_float_case_2!(case_2_f32, generate_case_f32, f32, Float32);
1020
1021
    macro_rules! integer_float_case_3 {
1022
        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
1023
            #[rstest]
1024
            #[test]
1025
            fn $TEST_FUNC_NAME(
1026
                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
1027
                seed: u64,
1028
                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
1029
                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
1030
            ) -> Result<()> {
1031
                let left_col = Arc::new(Column::new("left_watermark", 0));
1032
                let right_col = Arc::new(Column::new("right_watermark", 0));
1033
1034
                // left_watermark - 1 > right_watermark + 5 AND left_watermark - 3 < right_watermark + 10
1035
                let expr = gen_conjunctive_numerical_expr(
1036
                    left_col.clone(),
1037
                    right_col.clone(),
1038
                    (
1039
                        Operator::Minus,
1040
                        Operator::Plus,
1041
                        Operator::Minus,
1042
                        Operator::Plus,
1043
                    ),
1044
                    ScalarValue::$SCALAR(Some(1 as $TYPE)),
1045
                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
1046
                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
1047
                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
1048
                    (greater_op, less_op),
1049
                );
1050
                // l > r + 6 AND r > l - 13
1051
                let l_gt_r = 6 as $TYPE;
1052
                let r_gt_l = -13 as $TYPE;
1053
                $GENERATE_CASE_FUNC_NAME::<true>(
1054
                    expr.clone(),
1055
                    left_col.clone(),
1056
                    right_col.clone(),
1057
                    seed,
1058
                    l_gt_r,
1059
                    r_gt_l,
1060
                )?;
1061
                // Descending tests
1062
                // r < l - 6 AND l < r + 13
1063
                let r_lt_l = -l_gt_r;
1064
                let l_lt_r = -r_gt_l;
1065
                $GENERATE_CASE_FUNC_NAME::<false>(
1066
                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
1067
                )
1068
            }
1069
        };
1070
    }
1071
1072
    integer_float_case_3!(case_3_i32, generate_case_i32, i32, Int32);
1073
    integer_float_case_3!(case_3_i64, generate_case_i64, i64, Int64);
1074
    integer_float_case_3!(case_3_f64, generate_case_f64, f64, Float64);
1075
    integer_float_case_3!(case_3_f32, generate_case_f32, f32, Float32);
1076
1077
    macro_rules! integer_float_case_4 {
1078
        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
1079
            #[rstest]
1080
            #[test]
1081
            fn $TEST_FUNC_NAME(
1082
                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
1083
                seed: u64,
1084
                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
1085
                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
1086
            ) -> Result<()> {
1087
                let left_col = Arc::new(Column::new("left_watermark", 0));
1088
                let right_col = Arc::new(Column::new("right_watermark", 0));
1089
1090
                // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3
1091
                let expr = gen_conjunctive_numerical_expr(
1092
                    left_col.clone(),
1093
                    right_col.clone(),
1094
                    (
1095
                        Operator::Minus,
1096
                        Operator::Minus,
1097
                        Operator::Minus,
1098
                        Operator::Plus,
1099
                    ),
1100
                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
1101
                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
1102
                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
1103
                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
1104
                    (greater_op, less_op),
1105
                );
1106
                // l > r + 5 AND r > l - 13
1107
                let l_gt_r = 5 as $TYPE;
1108
                let r_gt_l = -13 as $TYPE;
1109
                $GENERATE_CASE_FUNC_NAME::<true>(
1110
                    expr.clone(),
1111
                    left_col.clone(),
1112
                    right_col.clone(),
1113
                    seed,
1114
                    l_gt_r,
1115
                    r_gt_l,
1116
                )?;
1117
                // Descending tests
1118
                // r < l - 5 AND l < r + 13
1119
                let r_lt_l = -l_gt_r;
1120
                let l_lt_r = -r_gt_l;
1121
                $GENERATE_CASE_FUNC_NAME::<false>(
1122
                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
1123
                )
1124
            }
1125
        };
1126
    }
1127
1128
    integer_float_case_4!(case_4_i32, generate_case_i32, i32, Int32);
1129
    integer_float_case_4!(case_4_i64, generate_case_i64, i64, Int64);
1130
    integer_float_case_4!(case_4_f64, generate_case_f64, f64, Float64);
1131
    integer_float_case_4!(case_4_f32, generate_case_f32, f32, Float32);
1132
1133
    macro_rules! integer_float_case_5 {
1134
        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
1135
            #[rstest]
1136
            #[test]
1137
            fn $TEST_FUNC_NAME(
1138
                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
1139
                seed: u64,
1140
                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
1141
                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
1142
            ) -> Result<()> {
1143
                let left_col = Arc::new(Column::new("left_watermark", 0));
1144
                let right_col = Arc::new(Column::new("right_watermark", 0));
1145
1146
                // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3
1147
                let expr = gen_conjunctive_numerical_expr(
1148
                    left_col.clone(),
1149
                    right_col.clone(),
1150
                    (
1151
                        Operator::Minus,
1152
                        Operator::Minus,
1153
                        Operator::Minus,
1154
                        Operator::Minus,
1155
                    ),
1156
                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
1157
                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
1158
                    ScalarValue::$SCALAR(Some(30 as $TYPE)),
1159
                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
1160
                    (greater_op, less_op),
1161
                );
1162
                // l > r + 5 AND r > l - 27
1163
                let l_gt_r = 5 as $TYPE;
1164
                let r_gt_l = -27 as $TYPE;
1165
                $GENERATE_CASE_FUNC_NAME::<true>(
1166
                    expr.clone(),
1167
                    left_col.clone(),
1168
                    right_col.clone(),
1169
                    seed,
1170
                    l_gt_r,
1171
                    r_gt_l,
1172
                )?;
1173
                // Descending tests
1174
                // r < l - 5 AND l < r + 27
1175
                let r_lt_l = -l_gt_r;
1176
                let l_lt_r = -r_gt_l;
1177
                $GENERATE_CASE_FUNC_NAME::<false>(
1178
                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
1179
                )
1180
            }
1181
        };
1182
    }
1183
1184
    integer_float_case_5!(case_5_i32, generate_case_i32, i32, Int32);
1185
    integer_float_case_5!(case_5_i64, generate_case_i64, i64, Int64);
1186
    integer_float_case_5!(case_5_f64, generate_case_f64, f64, Float64);
1187
    integer_float_case_5!(case_5_f32, generate_case_f32, f32, Float32);
1188
1189
    #[test]
    fn test_gather_node_indices_dont_remove() -> Result<()> {
        // Expression: a@0 + b@1 + 1 > a@0 - b@1, given a@0 + b@1.
        // Gathering `a@0 + b@1` may only drop edges, not the leaves a@0 and
        // b@1, because `a@0 - b@1` still depends on both of them.
        let a_plus_b = || {
            Arc::new(BinaryExpr::new(
                Arc::new(Column::new("a", 0)),
                Operator::Plus,
                Arc::new(Column::new("b", 1)),
            ))
        };
        let lhs = Arc::new(BinaryExpr::new(
            a_plus_b(),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
        ));
        let rhs = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Minus,
            Arc::new(Column::new("b", 1)),
        ));
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ]);
        let mut graph = ExprIntervalGraph::try_new(
            Arc::new(BinaryExpr::new(lhs, Operator::Gt, rhs)),
            &schema,
        )
        .unwrap();

        let nodes_before = graph.node_count();
        // Provide the exact `a@0 + b@1` subexpression as a leaf.
        graph.gather_node_indices(&[a_plus_b()]);
        let nodes_after = graph.node_count();

        // No node may have been removed.
        assert_eq!(nodes_before, nodes_after);
        Ok(())
    }
1235
1236
    #[test]
    fn test_gather_node_indices_remove() -> Result<()> {
        // Expression: a@0 + b@1 + 1 > y@0 - z@1, given a@0 + b@1.
        // The leaves a@0 and b@1 are only reachable through `a@0 + b@1`, so
        // providing that node lets the graph drop both of them (two nodes).
        let a_plus_b = || {
            Arc::new(BinaryExpr::new(
                Arc::new(Column::new("a", 0)),
                Operator::Plus,
                Arc::new(Column::new("b", 1)),
            ))
        };
        let lhs = Arc::new(BinaryExpr::new(
            a_plus_b(),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
        ));
        let rhs = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("y", 0)),
            Operator::Minus,
            Arc::new(Column::new("z", 1)),
        ));
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("y", DataType::Int32, true),
            Field::new("z", DataType::Int32, true),
        ]);
        let mut graph = ExprIntervalGraph::try_new(
            Arc::new(BinaryExpr::new(lhs, Operator::Gt, rhs)),
            &schema,
        )
        .unwrap();

        let nodes_before = graph.node_count();
        // Provide the exact `a@0 + b@1` subexpression as a leaf.
        graph.gather_node_indices(&[a_plus_b()]);
        let nodes_after = graph.node_count();

        // Exactly two nodes (a@0 and b@1) must have been removed.
        assert_eq!(nodes_before, nodes_after + 2);
        Ok(())
    }
1283
1284
    #[test]
    fn test_gather_node_indices_remove_one() -> Result<()> {
        // Expression: a@0 + b@1 + 1 > a@0 - z@1, given a@0 + b@1.
        // We expect to remove one node since we still need a@0 but not b@1.
        let left_expr = Arc::new(BinaryExpr::new(
            Arc::new(BinaryExpr::new(
                Arc::new(Column::new("a", 0)),
                Operator::Plus,
                Arc::new(Column::new("b", 1)),
            )),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
        ));

        let right_expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Minus,
            Arc::new(Column::new("z", 1)),
        ));
        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
        let mut graph = ExprIntervalGraph::try_new(
            expr,
            &Schema::new(vec![
                Field::new("a", DataType::Int32, true),
                Field::new("b", DataType::Int32, true),
                Field::new("z", DataType::Int32, true),
            ]),
        )?;
        // Define a test leaf node.
        let leaf_node = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        ));
        // Store the current node count.
        let prev_node_count = graph.node_count();
        // Gather the indices of the graph nodes that match the test leaf node.
        graph.gather_node_indices(&[leaf_node]);
        // Store the final node count.
        let final_node_count = graph.node_count();
        // Assert that the final node count is one less than the previous node
        // count; i.e. that we removed exactly one node (b@1).
        assert_eq!(prev_node_count, final_node_count + 1);
        Ok(())
    }
1330
1331
    #[test]
    fn test_gather_node_indices_cannot_provide() -> Result<()> {
        // Expression: a@0 + 1 + b@1 > y@0 - z@1 -> provide a@0 + b@1
        // TODO: We expect nodes a@0 and b@1 to be pruned, and intervals to be provided from the a@0 + b@1 node.
        // However, we do not have an exact node for a@0 + b@1 due to the binary tree structure of the expressions.
        // Pruning and interval providing for BinaryExpr expressions are more challenging without exact matches.
        // Currently, we only support exact matches for BinaryExprs, but we plan to extend support beyond exact matches in the future.
        let left_expr = Arc::new(BinaryExpr::new(
            Arc::new(BinaryExpr::new(
                Arc::new(Column::new("a", 0)),
                Operator::Plus,
                Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
            )),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        ));

        let right_expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("y", 0)),
            Operator::Minus,
            Arc::new(Column::new("z", 1)),
        ));
        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
        let mut graph = ExprIntervalGraph::try_new(
            expr,
            &Schema::new(vec![
                Field::new("a", DataType::Int32, true),
                Field::new("b", DataType::Int32, true),
                Field::new("y", DataType::Int32, true),
                Field::new("z", DataType::Int32, true),
            ]),
        )?;
        // Define a test leaf node. It has no exact counterpart in the graph.
        let leaf_node = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        ));
        // Store the current node count.
        let prev_node_count = graph.node_count();
        // Gather the indices of the nodes in the expression graph that match the
        // test leaf node.
        graph.gather_node_indices(&[leaf_node]);
        // Store the final node count.
        let final_node_count = graph.node_count();
        // Assert that the final node count is equal to the previous node count
        // (i.e., no node was pruned).
        assert_eq!(prev_node_count, final_node_count);
        Ok(())
    }
1380
1381
    #[test]
    fn test_propagate_constraints_singleton_interval_at_right() -> Result<()> {
        // Expression: ts_column + INTERVAL '1 day 321 nanoseconds'.
        // Given the parent (sum) range, propagation should narrow the timestamp
        // child to `parent - 1 day 321 ns` and leave the singleton interval
        // child unchanged.
        let expression = BinaryExpr::new(
            Arc::new(Column::new("ts_column", 0)),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::new_interval_mdn(0, 1, 321))),
        );
        let parent = Interval::try_new(
            // 15.10.2020 - 10:11:12.000_000_321 AM
            ScalarValue::TimestampNanosecond(Some(1_602_756_672_000_000_321), None),
            // 16.10.2020 - 10:11:12.000_000_321 AM
            ScalarValue::TimestampNanosecond(Some(1_602_843_072_000_000_321), None),
        )?;
        let left_child = Interval::try_new(
            // 10.10.2020 - 10:11:12 AM
            ScalarValue::TimestampNanosecond(Some(1_602_324_672_000_000_000), None),
            // 20.10.2020 - 10:11:12 AM
            ScalarValue::TimestampNanosecond(Some(1_603_188_672_000_000_000), None),
        )?;
        let right_child = Interval::try_new(
            // 1 day 321 ns
            ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {
                months: 0,
                days: 1,
                nanoseconds: 321,
            })),
            // 1 day 321 ns
            ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {
                months: 0,
                days: 1,
                nanoseconds: 321,
            })),
        )?;
        let children = vec![&left_child, &right_child];
        let result = expression
            .propagate_constraints(&parent, &children)?
            .unwrap();

        assert_eq!(
            vec![
                // Narrowed timestamp child: parent minus 1 day 321 ns.
                Interval::try_new(
                    // 14.10.2020 - 10:11:12 AM
                    ScalarValue::TimestampNanosecond(
                        Some(1_602_670_272_000_000_000),
                        None
                    ),
                    // 15.10.2020 - 10:11:12 AM
                    ScalarValue::TimestampNanosecond(
                        Some(1_602_756_672_000_000_000),
                        None
                    ),
                )?,
                // The singleton interval child is unchanged and is still an
                // IntervalMonthDayNano value.
                Interval::try_new(
                    // 1 day 321 ns
                    ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {
                        months: 0,
                        days: 1,
                        nanoseconds: 321,
                    })),
                    // 1 day 321 ns
                    ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {
                        months: 0,
                        days: 1,
                        nanoseconds: 321,
                    })),
                )?
            ],
            result
        );

        Ok(())
    }
1453
1454
    #[test]
    fn test_propagate_constraints_column_interval_at_left() -> Result<()> {
        // Shorthands for timezone-less millisecond timestamps and day-time
        // intervals expressed purely in milliseconds.
        let ts_ms = |millis: i64| ScalarValue::TimestampMillisecond(Some(millis), None);
        let dt_ms = |milliseconds: i32| {
            ScalarValue::IntervalDayTime(Some(IntervalDayTime {
                days: 0,
                milliseconds,
            }))
        };

        // Expression: interval_column + ts_column.
        let expression = BinaryExpr::new(
            Arc::new(Column::new("interval_column", 1)),
            Operator::Plus,
            Arc::new(Column::new("ts_column", 0)),
        );

        // The sum lies in [15.10.2020 - 10:11:12 AM, 16.10.2020 - 10:11:12 AM].
        let sum_range = Interval::try_new(
            ts_ms(1_602_756_672_000),
            ts_ms(1_602_843_072_000),
        )?;
        // interval_column starts in [2 days, 10 days].
        let interval_range = Interval::try_new(dt_ms(172_800_000), dt_ms(864_000_000))?;
        // ts_column starts in [10.10.2020 - 10:11:12 AM, 20.10.2020 - 10:11:12 AM].
        let ts_range = Interval::try_new(
            ts_ms(1_602_324_672_000),
            ts_ms(1_603_188_672_000),
        )?;

        let narrowed = expression
            .propagate_constraints(&sum_range, &[&interval_range, &ts_range])?
            .unwrap();

        let expected = vec![
            // interval_column shrinks to [2 days, 6 days].
            Interval::try_new(dt_ms(172_800_000), dt_ms(518_400_000))?,
            // ts_column shrinks to [10.10.2020 - 10:11:12 AM, 14.10.2020 - 10:11:12 AM].
            Interval::try_new(ts_ms(1_602_324_672_000), ts_ms(1_602_670_272_000))?,
        ];
        assert_eq!(expected, narrowed);

        Ok(())
    }
1516
1517
    #[test]
    fn test_propagate_comparison() -> Result<()> {
        // In each example below:
        // - `left` is unbounded: [?, ?],
        // - `right` is known to be [1000, 1000],
        // so asserting `left < right` yields no new knowledge about `right`, but
        // tells us that `left` is now < 1000, i.e. [?, 999].
        // The same check is repeated for Int64, timezone-less nanosecond
        // timestamps, and nanosecond timestamps with a "+05:00" timezone.
        let left = Interval::make_unbounded(&DataType::Int64)?;
        let right = Interval::make(Some(1000_i64), Some(1000_i64))?;
        assert_eq!(
            Some((
                Interval::make(None, Some(999_i64))?,
                Interval::make(Some(1000_i64), Some(1000_i64))?,
            )),
            propagate_comparison(
                &Operator::Lt,
                &Interval::CERTAINLY_TRUE,
                &left,
                &right
            )?
        );

        let left =
            Interval::make_unbounded(&DataType::Timestamp(TimeUnit::Nanosecond, None))?;
        let right = Interval::try_new(
            ScalarValue::TimestampNanosecond(Some(1000), None),
            ScalarValue::TimestampNanosecond(Some(1000), None),
        )?;
        assert_eq!(
            Some((
                Interval::try_new(
                    // Unbounded (null) lower bound of the timestamp type.
                    ScalarValue::try_from(&DataType::Timestamp(
                        TimeUnit::Nanosecond,
                        None
                    ))?,
                    ScalarValue::TimestampNanosecond(Some(999), None),
                )?,
                Interval::try_new(
                    ScalarValue::TimestampNanosecond(Some(1000), None),
                    ScalarValue::TimestampNanosecond(Some(1000), None),
                )?
            )),
            propagate_comparison(
                &Operator::Lt,
                &Interval::CERTAINLY_TRUE,
                &left,
                &right
            )?
        );

        let left = Interval::make_unbounded(&DataType::Timestamp(
            TimeUnit::Nanosecond,
            Some("+05:00".into()),
        ))?;
        let right = Interval::try_new(
            ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())),
            ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())),
        )?;
        assert_eq!(
            Some((
                Interval::try_new(
                    // Unbounded (null) lower bound of the timezone-aware type.
                    ScalarValue::try_from(&DataType::Timestamp(
                        TimeUnit::Nanosecond,
                        Some("+05:00".into()),
                    ))?,
                    ScalarValue::TimestampNanosecond(Some(999), Some("+05:00".into())),
                )?,
                Interval::try_new(
                    ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())),
                    ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())),
                )?
            )),
            propagate_comparison(
                &Operator::Lt,
                &Interval::CERTAINLY_TRUE,
                &left,
                &right
            )?
        );

        Ok(())
    }
1600
1601
    #[test]
    fn test_propagate_or() -> Result<()> {
        let or_expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Or,
            Arc::new(Column::new("b", 1)),
        ));
        // Short aliases for the three boolean intervals.
        let (t, f, u) = (
            &Interval::CERTAINLY_TRUE,
            &Interval::CERTAINLY_FALSE,
            &Interval::UNCERTAIN,
        );

        // `a OR b` is false: each of these inputs narrows both operands to false.
        for children in [[f, u], [u, f], [f, f], [u, u]] {
            let narrowed = or_expr.propagate_constraints(f, &children)?.unwrap();
            assert_eq!(
                narrowed,
                vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE],
            );
        }

        // `a OR b` is false but one operand is already true: propagation
        // reports the contradiction as `None`.
        for children in [[t, u], [u, t]] {
            assert_eq!(or_expr.propagate_constraints(f, &children)?, None);
        }

        // `a OR b` is true and `a` is false, so `b` is narrowed to true.
        let narrowed = or_expr.propagate_constraints(t, &[f, u])?.unwrap();
        assert_eq!(
            narrowed,
            vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE]
        );

        // `a OR b` is true with both operands uncertain: an empty result means
        // the intervals are unchanged.
        let narrowed = or_expr.propagate_constraints(t, &[u, u])?.unwrap();
        assert_eq!(narrowed, Vec::<Interval>::new());

        Ok(())
    }
1648
1649
    #[test]
    fn test_propagate_certainly_false_and() -> Result<()> {
        let and_expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::And,
            Arc::new(Column::new("b", 1)),
        ));
        // Short aliases for the three boolean intervals.
        let (t, f, u) = (
            &Interval::CERTAINLY_TRUE,
            &Interval::CERTAINLY_FALSE,
            &Interval::UNCERTAIN,
        );

        // Each case pairs the operand intervals with the expected narrowed
        // intervals, given that `a AND b` is certainly false.
        let cases: Vec<([&Interval; 2], Vec<Interval>)> = vec![
            // `a` is true, so `b` is narrowed to false.
            ([t, u], vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_FALSE]),
            // `b` is true, so `a` is narrowed to false.
            ([u, t], vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE]),
            // Both uncertain: empty means unchanged intervals.
            ([u, u], vec![]),
            // `a` is already false: nothing needs to change.
            ([f, u], vec![]),
        ];

        for (children, expected) in cases {
            let narrowed = and_expr.propagate_constraints(f, &children)?.unwrap();
            assert_eq!(narrowed, expected);
        }

        Ok(())
    }
1685
}