From b56910a3f75f8841116482dc83a609ebff5fb1de Mon Sep 17 00:00:00 2001 From: Christoph Schulze Date: Fri, 3 Jan 2025 12:19:26 +0100 Subject: [PATCH] support IS_DISTINCT_FROM IS_NOT_DISTINCT_FROM in interval arithmetic --- datafusion/expr/src/interval_arithmetic.rs | 156 ++++++++++++++---- .../physical-expr/src/intervals/cp_solver.rs | 7 +- 2 files changed, 133 insertions(+), 30 deletions(-) diff --git a/datafusion/expr/src/interval_arithmetic.rs b/datafusion/expr/src/interval_arithmetic.rs index 97c2103f525f..81a90dd4ff12 100644 --- a/datafusion/expr/src/interval_arithmetic.rs +++ b/datafusion/expr/src/interval_arithmetic.rs @@ -288,23 +288,11 @@ impl Interval { // Standardize floating-point endpoints: DataType::Float32 => handle_float_intervals!(Float32, f32, lower, upper), DataType::Float64 => handle_float_intervals!(Float64, f64, lower, upper), - // Unsigned null values for lower bounds are set to zero: - DataType::UInt8 if lower.is_null() => Self { - lower: ScalarValue::UInt8(Some(0)), - upper, - }, - DataType::UInt16 if lower.is_null() => Self { - lower: ScalarValue::UInt16(Some(0)), - upper, - }, - DataType::UInt32 if lower.is_null() => Self { - lower: ScalarValue::UInt32(Some(0)), - upper, - }, - DataType::UInt64 if lower.is_null() => Self { - lower: ScalarValue::UInt64(Some(0)), - upper, - }, + // Lower bounds of unsigned integer null values are set to zero: + // DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 if lower.is_null() => Self { + // lower: ScalarValue::new_zero(&lower.data_type()).unwrap(), + // upper, + // }, // Other data types do not require standardization: _ => Self { lower, upper }, } @@ -355,8 +343,8 @@ impl Interval { // There must be no way to create an interval whose endpoints have // different types. - assert!( - lower_type == upper_type, + assert_eq!( + lower_type, upper_type, "Interval bounds have different types: {lower_type} != {upper_type}" ); lower_type @@ -374,6 +362,10 @@ impl Interval { ) } + pub fn is_null(&self) -> bool { + self.lower.is_null() && self.upper.is_null() + } + pub const CERTAINLY_FALSE: Self = Self { lower: ScalarValue::Boolean(Some(false)), upper: ScalarValue::Boolean(Some(false)), @@ -564,6 +556,36 @@ impl Interval { } } + pub fn union>(&self, other: T) -> Result> { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + return internal_err!( + "Only intervals with the same data type are intersectable, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + // If it is evident that the result is an empty interval, short-circuit + // and directly return `None`. + if (!(self.lower.is_null() || rhs.upper.is_null()) && self.lower > rhs.upper) + || (!(self.upper.is_null() || rhs.lower.is_null()) && self.upper < rhs.lower) + { + return Ok(None); + } + + let lower = min_of_bounds(&self.lower, &rhs.lower); + let upper = max_of_bounds(&self.upper, &rhs.upper); + + // New lower and upper bounds must always construct a valid interval. + assert!( + lower.is_null() || upper.is_null() || (lower <= upper), + "The intersection of two intervals can not be an invalid interval" + ); + + Ok(Some(Self { lower, upper })) + } + /// Compute the intersection of this interval with the given interval. /// If the intersection is empty, return `None`. /// @@ -593,7 +615,7 @@ impl Interval { // New lower and upper bounds must always construct a valid interval. assert!( - (lower.is_null() || upper.is_null() || (lower <= upper)), + lower.is_null() || upper.is_null() || (lower <= upper), "The intersection of two intervals can not be an invalid interval" ); @@ -846,9 +868,9 @@ pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result lhs.mul(rhs), Operator::Divide => lhs.div(rhs), Operator::IsDistinctFrom | Operator::IsNotDistinctFrom => { - let nullable_interval = NullableInterval::from(lhs) - .apply_operator(op, &NullableInterval::from(rhs)); - nullable_interval.and_then(|x| { + NullableInterval::from(lhs) + .apply_operator(op, &NullableInterval::from(rhs)) + .and_then(|x| { x.values().cloned().ok_or(DataFusionError::Internal( "Unexpected null value interval".to_string(), )) @@ -1185,17 +1207,17 @@ pub fn satisfy_greater( } if !left.upper.is_null() && left.upper <= right.lower { - if !strict && left.upper == right.lower { + return if !strict && left.upper == right.lower { // Singleton intervals: - return Ok(Some(( + Ok(Some(( Interval::new(left.upper.clone(), left.upper.clone()), Interval::new(left.upper.clone(), left.upper.clone()), - ))); + ))) } else { // Left-hand side: <--======----0------------> // Right-hand side: <------------0--======----> // No intersection, infeasible to propagate: - return Ok(None); + Ok(None) } } @@ -1579,7 +1601,7 @@ impl Display for NullableInterval { impl From<&Interval> for NullableInterval { fn from(value: &Interval) -> Self { - if value.lower().is_null() && value.upper().is_null() { + if value.is_null() { Self::Null { datatype: value.data_type(), } @@ -2493,6 +2515,84 @@ mod tests { Ok(()) } + #[test] + fn test_union() -> Result<()> { + let possible_cases: Vec<(Interval, Interval, Interval)> = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make::(None, None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), None)?, + Interval::make(Some(1000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + Interval::make(Some(1000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(500_i64), Some(2000_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(None, Some(2000_u64))?, + Interval::make(Some(500_u64), None)?, + Interval::make(Some(500_u64), Some(2000_u64))?, + ), + ( + Interval::make(Some(0_u64), Some(0_u64))?, + Interval::make(Some(0_u64), None)?, + Interval::make(Some(0_u64), Some(0_u64))?, + ), + ( + Interval::make(Some(1000.0_f32), None)?, + Interval::make(None, Some(1000.0_f32))?, + Interval::make(Some(1000.0_f32), Some(1000.0_f32))?, + ), + ( + Interval::make(Some(1000.0_f32), Some(1500.0_f32))?, + Interval::make(Some(0.0_f32), Some(1500.0_f32))?, + Interval::make(Some(0.0_f32), Some(1500.0_f32))?, + ), + ( + Interval::make(Some(-1000.0_f64), Some(1500.0_f64))?, + Interval::make(Some(-1500.0_f64), Some(2000.0_f64))?, + Interval::make(Some(-1500.0_f64), Some(2000.0_f64))?, + ), + ( + Interval::make(Some(16.0_f64), Some(32.0_f64))?, + Interval::make(Some(32.0_f64), Some(64.0_f64))?, + Interval::make(Some(16.0_f64), Some(64.0_f64))?, + ), + ]; + for (first, second, expected) in possible_cases { + let union = first.union(second.clone())?.unwrap(); + println!("\nleft:{:?} right:{:?} \nunion: {} expected:{:?}", first, second, &union, expected); + assert_eq!(union, expected) + } + + Ok(()) + } + #[test] fn test_contains() -> Result<()> { let possible_cases = vec![ diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 5ba628e7ce40..e0bc2d4a580e 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -314,9 +314,12 @@ pub fn propagate_comparison( ) -> Result> { if parent == &Interval::CERTAINLY_TRUE { match op { - Operator::Eq => left_child.intersect(right_child).map(|result| { + Operator::Eq | Operator::IsNotDistinctFrom => left_child.intersect(right_child).map(|result| { result.map(|intersection| (intersection.clone(), intersection)) }), + Operator::NotEq | Operator::IsDistinctFrom => left_child.union(right_child).map(|result| { + result.map(|unin| (unin.clone(), unin)) + }), Operator::Gt => satisfy_greater(left_child, right_child, true), Operator::GtEq => satisfy_greater(left_child, right_child, false), Operator::Lt => satisfy_greater(right_child, left_child, true) @@ -329,7 +332,7 @@ pub fn propagate_comparison( } } else if parent == &Interval::CERTAINLY_FALSE { match op { - Operator::Eq => { + Operator::Eq | Operator::IsNotDistinctFrom | Operator::NotEq | Operator::IsDistinctFrom => { // TODO: Propagation is not possible until we support interval sets. Ok(None) }