Skip to content

Commit

Permalink
feat: DataType::contains support nested type (#4042)
Browse files Browse the repository at this point in the history
* feat: DataType::contains support nested type

* support recurse

* check typeID for Union
  • Loading branch information
Weijun-H authored Apr 12, 2023
1 parent 6ce332a commit a35c6c5
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 1 deletion.
26 changes: 26 additions & 0 deletions arrow-schema/src/datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,32 @@ impl DataType {
}
}
}

/// Reports whether `self` is a superset of `other`.
///
/// Nested types are compared structurally by delegating to the `contains`
/// method of their children; every other combination of variants falls back
/// to plain equality (`self == other`).
///
/// * `List` / `LargeList`: the child of `self` must contain the child of `other`.
/// * `FixedSizeList` / `Map`: the second parameters must be equal and the
///   child of `self` must contain the child of `other`.
/// * `Struct`: the fields of `self` must contain the fields of `other`.
/// * `Union`: the modes must be equal, and every `(type id, field)` entry of
///   `self` needs an entry in `other` with the same type id whose field it
///   contains.
/// * `Dictionary`: key and value types are both checked recursively.
pub fn contains(&self, other: &DataType) -> bool {
    match (self, other) {
        (DataType::List(lhs), DataType::List(rhs)) => lhs.contains(rhs),
        (DataType::LargeList(lhs), DataType::LargeList(rhs)) => lhs.contains(rhs),
        (DataType::FixedSizeList(lhs, lhs_size), DataType::FixedSizeList(rhs, rhs_size)) => {
            if lhs_size != rhs_size {
                return false;
            }
            lhs.contains(rhs)
        }
        (DataType::Map(lhs, lhs_param), DataType::Map(rhs, rhs_param)) => {
            if lhs_param != rhs_param {
                return false;
            }
            lhs.contains(rhs)
        }
        (DataType::Struct(lhs), DataType::Struct(rhs)) => lhs.contains(rhs),
        (DataType::Union(lhs, lhs_mode), DataType::Union(rhs, rhs_mode)) => {
            if lhs_mode != rhs_mode {
                return false;
            }
            // NOTE(review): the outer quantifier runs over `self`'s entries, so an
            // `other` carrying additional type ids is not rejected here — confirm
            // this is the intended reading of "superset".
            lhs.iter().all(|(lhs_id, lhs_field)| {
                rhs.iter()
                    .any(|(rhs_id, rhs_field)| lhs_id == rhs_id && lhs_field.contains(rhs_field))
            })
        }
        (DataType::Dictionary(lhs_key, lhs_value), DataType::Dictionary(rhs_key, rhs_value)) => {
            if !lhs_key.contains(rhs_key) {
                return false;
            }
            lhs_value.contains(rhs_value)
        }
        // Non-nested types, or mismatched variants: require exact equality.
        _ => self == other,
    }
}
}

/// The maximum precision for [DataType::Decimal128] values
Expand Down
64 changes: 63 additions & 1 deletion arrow-schema/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ impl Field {
/// * all other fields are equal
pub fn contains(&self, other: &Field) -> bool {
self.name == other.name
&& self.data_type == other.data_type
&& self.data_type.contains(&other.data_type)
&& self.dict_id == other.dict_id
&& self.dict_is_ordered == other.dict_is_ordered
// self needs to be nullable, or both of them must be non-nullable
Expand Down Expand Up @@ -758,6 +758,68 @@ mod test {

assert!(!field1.contains(&field2));
assert!(!field2.contains(&field1));

// UnionFields with different type ID
let field1 = Field::new(
"field1",
DataType::Union(
UnionFields::new(
vec![1, 2],
vec![
Field::new("field1", DataType::UInt8, true),
Field::new("field3", DataType::Utf8, false),
],
),
UnionMode::Dense,
),
true,
);
let field2 = Field::new(
"field1",
DataType::Union(
UnionFields::new(
vec![1, 3],
vec![
Field::new("field1", DataType::UInt8, false),
Field::new("field3", DataType::Utf8, false),
],
),
UnionMode::Dense,
),
true,
);
assert!(!field1.contains(&field2));

// UnionFields with same type ID
let field1 = Field::new(
"field1",
DataType::Union(
UnionFields::new(
vec![1, 2],
vec![
Field::new("field1", DataType::UInt8, true),
Field::new("field3", DataType::Utf8, false),
],
),
UnionMode::Dense,
),
true,
);
let field2 = Field::new(
"field1",
DataType::Union(
UnionFields::new(
vec![1, 2],
vec![
Field::new("field1", DataType::UInt8, false),
Field::new("field3", DataType::Utf8, false),
],
),
UnionMode::Dense,
),
true,
);
assert!(field1.contains(&field2));
}

#[cfg(feature = "serde")]
Expand Down

0 comments on commit a35c6c5

Please sign in to comment.