diff --git a/library/alloc/src/collections/btree/node.rs b/library/alloc/src/collections/btree/node.rs index 4fa97ff053e60..903faf3fa969a 100644 --- a/library/alloc/src/collections/btree/node.rs +++ b/library/alloc/src/collections/btree/node.rs @@ -257,8 +257,13 @@ impl Root { /// `NodeRef` points to an internal node, and when this is `LeafOrInternal` the /// `NodeRef` could be pointing to either type of node. pub struct NodeRef { - /// The number of levels below the node. + /// The number of levels below the node, a property of the node that cannot be + /// entirely described by `Type` and that the node does not store itself either. + /// Unconstrained if `Type` is `LeafOrInternal`, must be zero if `Type` is `Leaf`, + /// and must be non-zero if `Type` is `Internal`. height: usize, + /// The pointer to the leaf or internal node. The definition of `InternalNode` + /// ensures that the pointer is valid either way. node: NonNull>, _marker: PhantomData<(BorrowType, Type)>, } @@ -315,8 +320,8 @@ impl NodeRef { unsafe { usize::from((*self.as_leaf_ptr()).len) } } - /// Returns the height of this node in the whole tree. Zero height denotes the - /// leaf level. + /// Returns the height of this node with respect to the leaf level. Zero height means the + /// node is a leaf itself. pub fn height(&self) -> usize { self.height } @@ -584,9 +589,11 @@ impl<'a, K, V, Type> NodeRef, K, V, Type> { // to avoid aliasing with outstanding references to other elements, // in particular, those returned to the caller in earlier iterations. let leaf = self.node.as_ptr(); + let keys = unsafe { &raw const (*leaf).keys }; + let vals = unsafe { &raw mut (*leaf).vals }; // We must coerce to unsized array pointers because of Rust issue #74679. - let keys: *const [_] = unsafe { &raw const (*leaf).keys }; - let vals: *mut [_] = unsafe { &raw mut (*leaf).vals }; + let keys: *const [_] = keys; + let vals: *mut [_] = vals; // SAFETY: The keys and values of a node must always be initialized up to length. let key = unsafe { (&*keys.get_unchecked(idx)).assume_init_ref() }; let val = unsafe { (&mut *vals.get_unchecked_mut(idx)).assume_init_mut() }; @@ -817,11 +824,25 @@ impl Handle, mar } } +impl NodeRef { + /// Could be a public implementation of PartialEq, but only used in this module. + fn eq(&self, other: &Self) -> bool { + let Self { node, height, _marker: _ } = self; + if *node == other.node { + debug_assert_eq!(*height, other.height); + true + } else { + false + } + } +} + impl PartialEq for Handle, HandleType> { fn eq(&self, other: &Self) -> bool { - self.node.node == other.node.node && self.idx == other.idx + let Self { node, idx, _marker: _ } = self; + node.eq(&other.node) && *idx == other.idx } } @@ -829,7 +850,8 @@ impl PartialOrd for Handle, HandleType> { fn partial_cmp(&self, other: &Self) -> Option { - if self.node.node == other.node.node { Some(self.idx.cmp(&other.idx)) } else { None } + let Self { node, idx, _marker: _ } = self; + if node.eq(&other.node) { Some(idx.cmp(&other.idx)) } else { None } } } diff --git a/library/alloc/src/collections/btree/node/tests.rs b/library/alloc/src/collections/btree/node/tests.rs index 54c3709821acd..2ef9aad0ccdcf 100644 --- a/library/alloc/src/collections/btree/node/tests.rs +++ b/library/alloc/src/collections/btree/node/tests.rs @@ -1,4 +1,5 @@ use super::*; +use core::cmp::Ordering::*; #[test] fn test_splitpoint() { @@ -24,6 +25,38 @@ fn test_splitpoint() { } } +#[test] +fn test_partial_cmp_eq() { + let mut root1: Root = Root::new_leaf(); + let mut leaf1 = unsafe { root1.leaf_node_as_mut() }; + leaf1.push(1, ()); + root1.push_internal_level(); + let root2: Root = Root::new_leaf(); + + let leaf_edge_1a = root1.node_as_ref().first_leaf_edge().forget_node_type(); + let leaf_edge_1b = root1.node_as_ref().last_leaf_edge().forget_node_type(); + let top_edge_1 = root1.node_as_ref().first_edge(); + let top_edge_2 = root2.node_as_ref().first_edge(); + + assert!(leaf_edge_1a == leaf_edge_1a); + assert!(leaf_edge_1a != leaf_edge_1b); + assert!(leaf_edge_1a != top_edge_1); + assert!(leaf_edge_1a != top_edge_2); + assert!(top_edge_1 == top_edge_1); + assert!(top_edge_1 != top_edge_2); + + assert_eq!(leaf_edge_1a.partial_cmp(&leaf_edge_1a), Some(Equal)); + assert_eq!(leaf_edge_1a.partial_cmp(&leaf_edge_1b), Some(Less)); + assert_eq!(leaf_edge_1a.partial_cmp(&top_edge_1), None); + assert_eq!(leaf_edge_1a.partial_cmp(&top_edge_2), None); + assert_eq!(top_edge_1.partial_cmp(&top_edge_1), Some(Equal)); + assert_eq!(top_edge_1.partial_cmp(&top_edge_2), None); + + root1.pop_internal_level(); + unsafe { root1.into_ref().deallocate_and_ascend() }; + unsafe { root2.into_ref().deallocate_and_ascend() }; +} + #[test] #[cfg(target_arch = "x86_64")] fn test_sizes() {