Skip to content

Commit

Permalink
Implement PartialOrd for ScalarValue (#838)
Browse files Browse the repository at this point in the history
* Implement PartialOrd for ScalarValue.

* Avoid catch all match.
  • Loading branch information
viirya authored Aug 8, 2021
1 parent ee27f6e commit 4ddd2f5
Showing 1 changed file with 146 additions and 0 deletions.
146 changes: 146 additions & 0 deletions datafusion/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use arrow::{
},
};
use ordered_float::OrderedFloat;
use std::cmp::Ordering;
use std::convert::{Infallible, TryInto};
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
Expand Down Expand Up @@ -156,6 +157,81 @@ impl PartialEq for ScalarValue {
}
}

// Hand-written `PartialOrd`: floating point payloads are wrapped in
// `OrderedFloat` so that comparing them has defined behavior.
impl PartialOrd for ScalarValue {
    /// Orders two scalars that are the same variant by comparing their
    /// payloads.
    ///
    /// Returns `None` when the two values are different variants, or when
    /// both are `List`s whose element data types are not equal.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        use ScalarValue::*;
        // There is intentionally no catch-all "(_, _)" arm: every variant is
        // spelled out (once in a same-variant arm, once in the mismatch arm
        // at the bottom) so that adding a new enum variant without updating
        // this match is a compile error.
        match (self, other) {
            // Same variant on both sides: delegate to the payload ordering.
            (Boolean(lhs), Boolean(rhs)) => lhs.partial_cmp(rhs),
            (Float32(lhs), Float32(rhs)) => {
                lhs.map(OrderedFloat).partial_cmp(&rhs.map(OrderedFloat))
            }
            (Float64(lhs), Float64(rhs)) => {
                lhs.map(OrderedFloat).partial_cmp(&rhs.map(OrderedFloat))
            }
            (Int8(lhs), Int8(rhs)) => lhs.partial_cmp(rhs),
            (Int16(lhs), Int16(rhs)) => lhs.partial_cmp(rhs),
            (Int32(lhs), Int32(rhs)) => lhs.partial_cmp(rhs),
            (Int64(lhs), Int64(rhs)) => lhs.partial_cmp(rhs),
            (UInt8(lhs), UInt8(rhs)) => lhs.partial_cmp(rhs),
            (UInt16(lhs), UInt16(rhs)) => lhs.partial_cmp(rhs),
            (UInt32(lhs), UInt32(rhs)) => lhs.partial_cmp(rhs),
            (UInt64(lhs), UInt64(rhs)) => lhs.partial_cmp(rhs),
            (Utf8(lhs), Utf8(rhs)) => lhs.partial_cmp(rhs),
            (LargeUtf8(lhs), LargeUtf8(rhs)) => lhs.partial_cmp(rhs),
            (Binary(lhs), Binary(rhs)) => lhs.partial_cmp(rhs),
            (LargeBinary(lhs), LargeBinary(rhs)) => lhs.partial_cmp(rhs),
            // Lists are only comparable when their element types agree; a
            // type mismatch falls through to the `None` arm below.
            (List(lhs, lhs_type), List(rhs, rhs_type)) if lhs_type == rhs_type => {
                lhs.partial_cmp(rhs)
            }
            (Date32(lhs), Date32(rhs)) => lhs.partial_cmp(rhs),
            (Date64(lhs), Date64(rhs)) => lhs.partial_cmp(rhs),
            (TimestampSecond(lhs), TimestampSecond(rhs)) => lhs.partial_cmp(rhs),
            (TimestampMillisecond(lhs), TimestampMillisecond(rhs)) => lhs.partial_cmp(rhs),
            (TimestampMicrosecond(lhs), TimestampMicrosecond(rhs)) => lhs.partial_cmp(rhs),
            (TimestampNanosecond(lhs), TimestampNanosecond(rhs)) => lhs.partial_cmp(rhs),
            (IntervalYearMonth(lhs), IntervalYearMonth(rhs)) => lhs.partial_cmp(rhs),
            (IntervalDayTime(lhs), IntervalDayTime(rhs)) => lhs.partial_cmp(rhs),

            // Anything that reaches this point pairs two different variants
            // (or two lists with different element types): no ordering.
            (Boolean(_), _)
            | (Float32(_), _)
            | (Float64(_), _)
            | (Int8(_), _)
            | (Int16(_), _)
            | (Int32(_), _)
            | (Int64(_), _)
            | (UInt8(_), _)
            | (UInt16(_), _)
            | (UInt32(_), _)
            | (UInt64(_), _)
            | (Utf8(_), _)
            | (LargeUtf8(_), _)
            | (Binary(_), _)
            | (LargeBinary(_), _)
            | (List(_, _), _)
            | (Date32(_), _)
            | (Date64(_), _)
            | (TimestampSecond(_), _)
            | (TimestampMillisecond(_), _)
            | (TimestampMicrosecond(_), _)
            | (TimestampNanosecond(_), _)
            | (IntervalYearMonth(_), _)
            | (IntervalDayTime(_), _) => None,
        }
    }
}

// Marker impl with no methods: `Eq` declares that the `==` provided by
// `PartialEq for ScalarValue` is a full equivalence relation.
// NOTE(review): presumably sound because the manual `PartialEq` compares
// floats through `OrderedFloat` (so NaN equals itself) — confirm against
// that impl.
impl Eq for ScalarValue {}

// manual implementation of `Hash` that uses OrderedFloat to
Expand Down Expand Up @@ -1577,4 +1653,74 @@ mod tests {
// per distinct value.
assert_eq!(std::mem::size_of::<ScalarValue>(), 32);
}

#[test]
fn scalar_partial_ordering() {
    use ScalarValue::*;

    // Plain scalars: same-variant pairs are ordered by value, while pairs of
    // different data types are not comparable and yield `None`.
    let scalar_cases = [
        (Int64(Some(33)), Int64(Some(0)), Some(Ordering::Greater)),
        (Int64(Some(0)), Int64(Some(33)), Some(Ordering::Less)),
        (Int64(Some(33)), Int64(Some(33)), Some(Ordering::Equal)),
        (Int64(Some(33)), Int32(Some(33)), None),
        (Int32(Some(33)), Int64(Some(33)), None),
    ];
    for (lhs, rhs, expected) in scalar_cases.iter() {
        assert_eq!(lhs.partial_cmp(rhs), *expected);
    }

    // Builds a two-element list scalar of `Int32` values.
    let int32_pair = |first, second| {
        List(
            Some(Box::new(vec![Int32(Some(first)), Int32(Some(second))])),
            Box::new(DataType::Int32),
        )
    };

    // Lists with the same element type compare by their contents.
    assert_eq!(
        int32_pair(1, 5).partial_cmp(&int32_pair(1, 5)),
        Some(Ordering::Equal)
    );
    assert_eq!(
        int32_pair(10, 5).partial_cmp(&int32_pair(1, 5)),
        Some(Ordering::Greater)
    );
    assert_eq!(
        int32_pair(1, 5).partial_cmp(&int32_pair(10, 5)),
        Some(Ordering::Less)
    );

    // Lists with different element data types are not comparable.
    let int64_pair = List(
        Some(Box::new(vec![Int64(Some(1)), Int64(Some(5))])),
        Box::new(DataType::Int64),
    );
    assert_eq!(int64_pair.partial_cmp(&int32_pair(1, 5)), None);
}
}

0 comments on commit 4ddd2f5

Please sign in to comment.