Skip to content

Commit

Permalink
add MastNodeId::TryFrom<u32>
Browse files Browse the repository at this point in the history
  • Loading branch information
sergerad committed Jul 15, 2024
1 parent bbd3bb7 commit 489e93f
Showing 1 changed file with 42 additions and 6 deletions.
48 changes: 42 additions & 6 deletions core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,22 @@ pub trait MerkleTreeNode {
/// Note that the [`MastForest`] does *not* ensure that equal [`MastNode`]s have equal
/// [`MastNodeId`] handles. Hence, [`MastNodeId`] equality must not be used to test for equality of
/// the underlying [`MastNode`].
///
/// [`MastNodeId`] u32 value must be less than 2^30.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct MastNodeId(u32);

impl TryFrom<u32> for MastNodeId {
type Error = &'static str;

fn try_from(id: u32) -> Result<Self, Self::Error> {
if id > u32::MAX >> 2 {
return Err("MastNodeId must be less than 2^30");
}
Ok(Self(id))
}
}

impl fmt::Display for MastNodeId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "MastNodeId({})", self.0)
Expand Down Expand Up @@ -62,13 +75,13 @@ impl MastForest {
/// Adds a node to the forest, and returns the associated [`MastNodeId`].
///
/// Adding two duplicate nodes will result in two distinct returned [`MastNodeId`]s.
///
/// # Panics
///
/// This function will panic if the length of the provided node is greater than 2^30.
pub fn add_node(&mut self, node: MastNode) -> MastNodeId {
let new_node_id = MastNodeId(
self.nodes
.len()
.try_into()
.expect("invalid node id: exceeded maximum number of nodes in a single forest"),
);
let new_node_id = MastNodeId::try_from(self.nodes.len() as u32)
.expect("invalid node id: exceeded maximum number of nodes in a single forest");

self.nodes.push(node);

Expand Down Expand Up @@ -132,3 +145,26 @@ impl Index<MastNodeId> for MastForest {
&self.nodes[idx]
}
}

#[cfg(test)]
mod test {
#[test]
fn test_mast_node_id_try_from_u32() {
use super::MastNodeId;
use std::convert::TryFrom;

let tests = vec![
(u32::MAX, false),
(u32::MAX >> 1, false),
((u32::MAX >> 2) + 1, false),
(u32::MAX >> 2, true),
(1_073_741_824, false),
(1_073_741_823, true),
(0, true),
];
for (id, expected) in tests {
let result = MastNodeId::try_from(id).is_ok();
assert_eq!(result, expected, "id: {}", id);
}
}
}

0 comments on commit 489e93f

Please sign in to comment.