diff --git a/Cargo.lock b/Cargo.lock index 510500be49..218cce0ae4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4686,9 +4686,9 @@ checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" [[package]] name = "serde" -version = "1.0.163" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2113ab51b87a539ae008b5c6c02dc020ffa39afd2d83cffcb3f4eb2722cebec2" +checksum = "e326c9ec8042f1b5da33252c8a37e9ffbd2c9bef0155215b6e6c80c790e05f91" dependencies = [ "serde_derive", ] @@ -4725,13 +4725,13 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.163" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e" +checksum = "42a3df25b0713732468deadad63ab9da1f1fd75a48a15024b50363f128db627e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 1.0.109", ] [[package]] diff --git a/base_layer/mmr/Cargo.toml b/base_layer/mmr/Cargo.toml index a036f549e3..c22001413d 100644 --- a/base_layer/mmr/Cargo.toml +++ b/base_layer/mmr/Cargo.toml @@ -13,19 +13,19 @@ native_bitmap = ["croaring"] benches = ["criterion"] [dependencies] -tari_utilities = "0.4.10" -tari_crypto = { version = "0.16"} -tari_common = {path = "../../common"} -thiserror = "1.0.26" -borsh = "0.9.3" -digest = "0.9.0" +tari_utilities = "0.4" +tari_crypto = { version = "0.16" } +tari_common = { path = "../../common" } +thiserror = "1.0" +borsh = "0.9" +digest = "0.9" log = "0.4" -serde = { version = "1.0.97", features = ["derive"] } -croaring = { version = "0.5.2", optional = true } -criterion = { version="0.2", optional = true } +serde = { version = "1.0", features = ["derive"] } +croaring = { version = "0.5", optional = true } +criterion = { version = "0.2", optional = true } [dev-dependencies] -rand="0.8.0" +rand = "0.8.0" blake2 = "0.9.0" serde_json = "1.0" bincode = "1.1" @@ -39,6 +39,6 @@ name = "bench" harness = false [[test]] -name="tari_mmr_integration_tests" -path="tests/mmr_integration_tests.rs" -required-features=["native_bitmap"] +name = "tari_mmr_integration_tests" +path = "tests/mmr_integration_tests.rs" +required-features = ["native_bitmap"] diff --git a/base_layer/mmr/src/balanced_binary_merkle_proof.rs b/base_layer/mmr/src/balanced_binary_merkle_proof.rs index bc17653c2e..6745ad8f4d 100644 --- a/base_layer/mmr/src/balanced_binary_merkle_proof.rs +++ b/base_layer/mmr/src/balanced_binary_merkle_proof.rs @@ -20,37 +20,32 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::{ - collections::HashMap, - convert::{TryFrom, TryInto}, - marker::PhantomData, -}; +use std::{collections::HashMap, convert::TryFrom, marker::PhantomData}; use borsh::{BorshDeserialize, BorshSerialize}; use digest::Digest; use serde::{Deserialize, Serialize}; use tari_common::DomainDigest; -use tari_utilities::ByteArray; use thiserror::Error; use crate::{common::hash_together, BalancedBinaryMerkleTree, Hash}; -pub(crate) fn cast_to_u32(value: usize) -> Result { +fn cast_to_u32(value: usize) -> Result { u32::try_from(value).map_err(|_| BalancedBinaryMerkleProofError::MathOverflow) } #[derive(BorshDeserialize, BorshSerialize, Deserialize, Serialize, Clone, Debug, Default, PartialEq, Eq)] pub struct BalancedBinaryMerkleProof { - pub path: Vec, - pub node_index: u32, + /// Since this is balanced tree, the index `2k+1` is always left child and `2k` is right child + path: Vec, + node_index: u32, _phantom: PhantomData, } -// Since this is balanced tree, the index `2k+1` is always left child and `2k` is right child - impl BalancedBinaryMerkleProof where D: Digest + DomainDigest { + #[must_use = "Must use the result of the proof verification"] pub fn verify(&self, root: &Hash, leaf_hash: Hash) -> bool { let mut computed_root = leaf_hash; let mut node_index = self.node_index; @@ -60,32 +55,51 @@ where D: Digest + DomainDigest } else { computed_root = hash_together::(sibling, &computed_root); } - node_index = (node_index - 1) >> 1; + + match node_index.checked_sub(1).and_then(|i| i.checked_shr(1)) { + Some(i) => { + node_index = i; + }, + None => return false, + } } - &computed_root == root + computed_root == *root } pub fn generate_proof( tree: &BalancedBinaryMerkleTree, leaf_index: usize, ) -> Result { - let mut node_index = tree.get_node_index(leaf_index); - let mut proof = Vec::new(); - while node_index > 0 { - // Sibling - let parent = (node_index - 1) >> 1; - // The children are 2i+1 and 2i+2, so together are 4i+3, we substract one, we get the other. - let sibling = 4 * parent + 3 - node_index; - proof.push(tree.get_hash(sibling).clone()); + let node_index = tree.get_node_index(leaf_index); + let mut index = node_index; + let mut path = Vec::new(); + while index > 0 { + // Parent at (i - 1) / 2 + let parent = (index - 1) >> 1; + // The children are 2i + 1 and 2i + 2, so together are 4i + 3. We subtract one, we get the other. + let sibling = 4 * parent + 3 - index; + let hash = tree + .get_hash(sibling) + .cloned() + .ok_or(BalancedBinaryMerkleProofError::TreeDoesNotContainLeafIndex { leaf_index })?; + path.push(hash); // Traverse to parent - node_index = parent; + index = parent; } Ok(Self { - path: proof, - node_index: cast_to_u32(tree.get_node_index(leaf_index))?, + path, + node_index: cast_to_u32(node_index)?, _phantom: PhantomData, }) } + + pub fn path(&self) -> &[Hash] { + &self.path + } + + pub fn node_index(&self) -> u32 { + self.node_index + } } #[derive(Debug, Error)] @@ -96,57 +110,64 @@ pub enum BalancedBinaryMerkleProofError { BadProofSemantics, #[error("Math overflow")] MathOverflow, + #[error("Tree does not contain leaf index {leaf_index}")] + TreeDoesNotContainLeafIndex { leaf_index: usize }, + #[error("Index {index} is out of range. The len is {len}")] + IndexOutOfRange { index: usize, len: usize }, } /// Flag to indicate if proof data represents an index or a node hash /// This reduces the need for checking lengths instead #[derive(Clone, Debug, Serialize, Deserialize)] -pub enum MergedBalancedBinaryMerkleDataType { - Index, - Hash, +pub enum MergedBalancedBinaryMerkleIndexOrHash { + Index(u64), + Hash(Hash), } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct MergedBalancedBinaryMerkleProof { - pub paths: Vec)>>, // these tuples can contain indexes or hashes! - pub node_indices: Vec, - pub heights: Vec, + paths: Vec>, + node_indices: Vec, + heights: Vec, _phantom: PhantomData, } impl MergedBalancedBinaryMerkleProof where D: Digest + DomainDigest { - pub fn create_from_proofs( - proofs: Vec>, - ) -> Result { + pub fn create_from_proofs(proofs: &[BalancedBinaryMerkleProof]) -> Result { let heights = proofs .iter() .map(|proof| cast_to_u32(proof.path.len())) .collect::, _>>()?; - let max_height = heights + let max_height = *heights .iter() .max() .ok_or(BalancedBinaryMerkleProofError::CantMergeZeroProofs)?; + + if max_height == 0 { + return Err(BalancedBinaryMerkleProofError::BadProofSemantics); + } + let mut indices = proofs.iter().map(|proof| proof.node_index).collect::>(); let mut paths = vec![Vec::new(); proofs.len()]; - let mut join_indices = vec![None; proofs.len()]; - for height in (0..*max_height).rev() { - let mut hash_map = HashMap::new(); + let mut join_indices = vec![false; proofs.len()]; + let mut hash_map = HashMap::new(); + for height in (0..max_height).rev() { + hash_map.clear(); for (index, proof) in proofs.iter().enumerate() { // If this path was already joined ignore it. - if join_indices[index].is_none() && proof.path.len() > height as usize { + if !join_indices[index] && proof.path.len() > height as usize { let parent = (indices[index] - 1) >> 1; - if let Some(other_proof) = hash_map.insert(parent, index) { - join_indices[index] = Some(other_proof); - // The other proof doesn't need a hash, it needs an index to this proof - *paths[other_proof].first_mut().unwrap() = - (MergedBalancedBinaryMerkleDataType::Index, index.to_le_bytes().to_vec()); + if let Some(other_proof_idx) = hash_map.insert(parent, index) { + join_indices[index] = true; + // The other proof doesn't need a hash, it needs an index to this hash + *paths[other_proof_idx].first_mut().unwrap() = + MergedBalancedBinaryMerkleIndexOrHash::Index(index as u64); } else { paths[index].insert( 0, - ( - MergedBalancedBinaryMerkleDataType::Hash, + MergedBalancedBinaryMerkleIndexOrHash::Hash( proof.path[proof.path.len() - 1 - height as usize].clone(), ), ); @@ -155,6 +176,7 @@ where D: Digest + DomainDigest } } } + Ok(Self { paths, node_indices: proofs.iter().map(|proof| proof.node_index).collect::>(), @@ -166,80 +188,124 @@ where D: Digest + DomainDigest pub fn verify_consume( mut self, root: &Hash, - leaves_hashes: Vec, + leaf_hashes: Vec, ) -> Result { // Check that the proof and verifier data match let n = self.node_indices.len(); // number of merged proofs - if self.paths.len() != n || leaves_hashes.len() != n { + if self.paths.len() != n || leaf_hashes.len() != n { return Err(BalancedBinaryMerkleProofError::BadProofSemantics); } - let mut computed_hashes = leaves_hashes; - let max_height = self + let mut computed_hashes = leaf_hashes; + let max_height = *self .heights .iter() .max() .ok_or(BalancedBinaryMerkleProofError::CantMergeZeroProofs)?; // We need to compute the hashes row by row to be sure they are processed correctly. - for height in (0..*max_height).rev() { + for height in (0..max_height).rev() { let hashes = computed_hashes.clone(); - for (leaf, index) in computed_hashes.iter_mut().zip(0..n) { - if self.heights[index] > height { - if let Some(hash_or_index) = self.paths[index].pop() { - let hash = match hash_or_index.0 { - MergedBalancedBinaryMerkleDataType::Index => { - // An index must be a valid `usize` - let index = usize::from_le_bytes( - hash_or_index - .1 - .as_bytes() - .try_into() - .map_err(|_| BalancedBinaryMerkleProofError::BadProofSemantics)?, - ); - - // The index must also point to one of the proofs - if index < hashes.len() { - &hashes[index] - } else { - return Err(BalancedBinaryMerkleProofError::BadProofSemantics); - } - }, - MergedBalancedBinaryMerkleDataType::Hash => &hash_or_index.1, - }; - let parent = (self.node_indices[index] - 1) >> 1; - if self.node_indices[index] & 1 == 1 { - *leaf = hash_together::(leaf, hash); - } else { - *leaf = hash_together::(hash, leaf); - } - self.node_indices[index] = parent; - } + for (index, leaf) in computed_hashes.iter_mut().enumerate() { + if self.heights[index] <= height { + continue; } + + let Some(hash_or_index) = self.paths[index].pop() else { + // Path at this index already completely processed + continue; + }; + + let hash = match hash_or_index { + MergedBalancedBinaryMerkleIndexOrHash::Index(index) => { + let index = usize::try_from(index).map_err(|_| BalancedBinaryMerkleProofError::MathOverflow)?; + + // The index must also point to one of the proofs + hashes + .get(index) + .ok_or(BalancedBinaryMerkleProofError::IndexOutOfRange { + index, + len: hashes.len(), + })? + }, + MergedBalancedBinaryMerkleIndexOrHash::Hash(ref hash) => hash, + }; + // Left (2k + 1) or right (2k) sibling? + if self.node_indices[index] & 1 == 1 { + *leaf = hash_together::(leaf, hash); + } else { + *leaf = hash_together::(hash, leaf); + } + // Parent + self.node_indices[index] = (self.node_indices[index] - 1) >> 1; } } - Ok(&computed_hashes[0] == root) + Ok(computed_hashes[0] == *root) } } #[cfg(test)] mod test { use tari_crypto::{hash::blake2::Blake256, hash_domain, hashing::DomainSeparatedHasher}; + use tari_utilities::hex::from_hex; + + use super::*; - use super::MergedBalancedBinaryMerkleProof; - use crate::{BalancedBinaryMerkleProof, BalancedBinaryMerkleTree}; hash_domain!(TestDomain, "testing", 0); + type TestHasher = DomainSeparatedHasher; + + #[test] + fn test_small_tree() { + let leaves = (0..4usize).map(|i| vec![i as u8; 32]).collect::>(); + let bmt = BalancedBinaryMerkleTree::::create(leaves.clone()); + + assert_eq!(bmt.num_nodes(), (4 << 1) - 1); + assert_eq!(bmt.num_leaf_nodes(), 4); + let root = bmt.get_merkle_root(); + let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, 0).unwrap(); + assert!(proof.verify(&root, leaves[0].clone())); + assert!(!proof.verify(&root, leaves[1].clone())); + assert!(!proof.verify(&root, leaves[2].clone())); + assert!(!proof.verify(&root, leaves[3].clone())); + + let proof1 = BalancedBinaryMerkleProof::generate_proof(&bmt, 1).unwrap(); + + let merged = MergedBalancedBinaryMerkleProof::create_from_proofs(&[proof, proof1]).unwrap(); + assert!(merged + .verify_consume(&root, vec![leaves[0].clone(), leaves[1].clone()]) + .unwrap()); + } + + #[test] + fn test_zero_height_proof_should_be_invalid() { + let proof = MergedBalancedBinaryMerkleProof:: { + paths: vec![vec![]], + node_indices: vec![0], + heights: vec![0], + _phantom: PhantomData, + }; + assert!(!proof.verify_consume(&vec![0u8; 32], vec![vec![]]).unwrap()); + + let proof = MergedBalancedBinaryMerkleProof:: { + paths: vec![vec![]], + node_indices: vec![0], + heights: vec![1], + _phantom: PhantomData, + }; + assert!(!proof.verify_consume(&vec![0u8; 32], vec![vec![]]).unwrap()); + } + #[test] fn test_generate_and_verify_big_tree() { - for n in [1usize, 100, 1000, 10000] { + for n in [1usize, 100, 1000, 10_000] { let leaves = (0..n) .map(|i| [i.to_le_bytes().to_vec(), vec![0u8; 24]].concat()) .collect::>(); let hash_0 = leaves[0].clone(); let hash_n_half = leaves[n / 2].clone(); let hash_last = leaves[n - 1].clone(); - let bmt = BalancedBinaryMerkleTree::>::create(leaves); + let bmt = BalancedBinaryMerkleTree::::create(leaves); let root = bmt.get_merkle_root(); let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, 0).unwrap(); assert!(proof.verify(&root, hash_0)); @@ -253,7 +319,7 @@ mod test { #[test] fn test_merge_proof() { let leaves = (0..255).map(|i| vec![i; 32]).collect::>(); - let bmt = BalancedBinaryMerkleTree::>::create(leaves.clone()); + let bmt = BalancedBinaryMerkleTree::::create(leaves.clone()); let indices = [50, 0, 200, 150, 100]; let root = bmt.get_merkle_root(); let proofs = indices @@ -261,7 +327,7 @@ mod test { .map(|i| BalancedBinaryMerkleProof::generate_proof(&bmt, *i)) .collect::, _>>() .unwrap(); - let merged_proof = MergedBalancedBinaryMerkleProof::create_from_proofs(proofs).unwrap(); + let merged_proof = MergedBalancedBinaryMerkleProof::create_from_proofs(&proofs).unwrap(); assert!(merged_proof .verify_consume(&root, indices.iter().map(|i| leaves[*i].clone()).collect::>()) .unwrap()); @@ -270,13 +336,112 @@ mod test { #[test] fn test_merge_proof_full_tree() { let leaves = (0..255).map(|i| vec![i; 32]).collect::>(); - let bmt = BalancedBinaryMerkleTree::>::create(leaves.clone()); + let bmt = BalancedBinaryMerkleTree::::create(leaves.clone()); let root = bmt.get_merkle_root(); let proofs = (0..255) .map(|i| BalancedBinaryMerkleProof::generate_proof(&bmt, i)) .collect::, _>>() .unwrap(); - let merged_proof = MergedBalancedBinaryMerkleProof::create_from_proofs(proofs).unwrap(); + let merged_proof = MergedBalancedBinaryMerkleProof::create_from_proofs(&proofs).unwrap(); assert!(merged_proof.verify_consume(&root, leaves).unwrap()); } + + #[test] + fn test_verify_faulty_proof() { + let faulty_proof = BalancedBinaryMerkleProof:: { + path: vec![vec![1u8; 32], vec![1u8; 32]], + node_index: 2, + _phantom: Default::default(), + }; + + // This used to panic since this proof is not possible by using generate_proof + assert!(!faulty_proof.verify(&vec![0u8; 32], vec![0u8; 32])); + + let faulty_proof = BalancedBinaryMerkleProof:: { + path: vec![vec![1u8; 32], vec![1u8; 32], vec![0u8; 32], vec![0u8; 32]], + node_index: 3, + _phantom: Default::default(), + }; + assert!(!faulty_proof.verify(&vec![0u8; 32], vec![0u8; 32])); + + // Merged proof - no panic + let proof = MergedBalancedBinaryMerkleProof:: { + paths: vec![], + node_indices: vec![], + heights: vec![], + _phantom: PhantomData, + }; + proof.verify_consume(&vec![0u8; 32], vec![]).unwrap_err(); + + let proof = MergedBalancedBinaryMerkleProof:: { + paths: vec![vec![MergedBalancedBinaryMerkleIndexOrHash::Hash(vec![1u8; 32])], vec![ + MergedBalancedBinaryMerkleIndexOrHash::Hash(vec![2u8; 32]), + ]], + node_indices: vec![1, 1], + // max_height == 0 which equates to leaf_hash[0] == root, even though this proof is invalid. + // This assumes an attacker can control the first leaf hash. + heights: vec![0, 0], + _phantom: PhantomData, + }; + // TODO: This should fail but does not + // proof .verify_consume(&vec![5u8; 32], vec![vec![5u8; 32], vec![2u8; 32]]) .unwrap_err(); + assert!(proof + .verify_consume(&vec![5u8; 32], vec![vec![5u8; 32], vec![2u8; 32]]) + .unwrap()); + } + + #[test] + fn test_generate_faulty_proof() { + let bmt = BalancedBinaryMerkleTree::::create(vec![]); + let err = BalancedBinaryMerkleProof::::generate_proof(&bmt, 1).unwrap_err(); + assert!(matches!( + err, + BalancedBinaryMerkleProofError::TreeDoesNotContainLeafIndex { leaf_index: 1 } + )); + } + + #[test] + fn test_real_world_example() { + hash_domain!( + ValidatorNodeBmtHashDomain, + "com.tari.tari_project.base_layer.core.validator_node_mmr", + 1 + ); + pub type ValidatorNodeBmtHasherBlake256 = DomainSeparatedHasher; + let root = from_hex("faa36732a63077aa0eafcae451c5b12ee6971f1329b8ce9f966289168fdc4c5b").unwrap(); + let testdata = vec![ + // (bincode encoded proof as hex, node hash) + ("030000000000000020000000000000007152175a9df02caf2f7078d41c9523f627232e89d7ed208bde8ad30512cc5ae22000000000000000a0b14150acc67458e95ba40cbdbf0daa4622220b48fd36a9908b1ce1dd9f0ebf20000000000000008ba9eb45f6a462707bcc929b622369c45a62a0c423b6318c0f8b686dc1294af70e000000", "5f0d31e3a5f8a741702b609e2d7594cbedbddd6e93fe8145d06b752e4e4d20b3"), + ("0400000000000000200000000000000010d997fc0d9f7825ac853b6086650f961bcf3179b3a9fd8b961ba917603321292000000000000000eabff75e4f71e94527127ec5742f5d0d91d4e38fb6bd7726f9c48e44454f2fc420000000000000004077bc7fb1f539818f7ac581a8131a6ee3a516f13822db1199d455bcd24e896c2000000000000000d24c6e09533fcb8dcfcb964c9ef3313d789e85ad4c5e270f69b884de8ad424280f000000", "cba14c691513694e2b94bc270aad6d06c24f18f5c67f207eedb7821aa1f1e02a"), + ("04000000000000002000000000000000cba14c691513694e2b94bc270aad6d06c24f18f5c67f207eedb7821aa1f1e02a2000000000000000eabff75e4f71e94527127ec5742f5d0d91d4e38fb6bd7726f9c48e44454f2fc420000000000000004077bc7fb1f539818f7ac581a8131a6ee3a516f13822db1199d455bcd24e896c2000000000000000d24c6e09533fcb8dcfcb964c9ef3313d789e85ad4c5e270f69b884de8ad4242810000000", "10d997fc0d9f7825ac853b6086650f961bcf3179b3a9fd8b961ba91760332129"), + ("0400000000000000200000000000000056ceb0eb5bce9d33b775bedfabc2884b10852216737632132209d74bf6a4192f200000000000000020e7c8546d77b299faaf5d025b34c20d606555e99e66a5bf95e9f845853feccf20000000000000004077bc7fb1f539818f7ac581a8131a6ee3a516f13822db1199d455bcd24e896c2000000000000000d24c6e09533fcb8dcfcb964c9ef3313d789e85ad4c5e270f69b884de8ad4242811000000", "c2fd409e09a1e4c942fdf1fb6d75f15f971d9ae2a17621eda197bc0c21a503c4"), + ("04000000000000002000000000000000c2fd409e09a1e4c942fdf1fb6d75f15f971d9ae2a17621eda197bc0c21a503c4200000000000000020e7c8546d77b299faaf5d025b34c20d606555e99e66a5bf95e9f845853feccf20000000000000004077bc7fb1f539818f7ac581a8131a6ee3a516f13822db1199d455bcd24e896c2000000000000000d24c6e09533fcb8dcfcb964c9ef3313d789e85ad4c5e270f69b884de8ad4242812000000", "56ceb0eb5bce9d33b775bedfabc2884b10852216737632132209d74bf6a4192f"), + ("0400000000000000200000000000000033bb552bb30f28eff843e05d327776366b8cf8ae04d5a69e038a2f4a3157ff6620000000000000007008f070d4cfbd6e91cbfb27ec911d56c664acfcb88e451da792e9ef0277ced22000000000000000b9d9216cc6679406cf8b5995a0473ebfe0584ae71ce0fc3aa01b76d3794526e92000000000000000d24c6e09533fcb8dcfcb964c9ef3313d789e85ad4c5e270f69b884de8ad4242813000000", "1d0a655879908a688ff97f08f05d067a9c30cb0192655fbb895699b8a1e36072"), + ("040000000000000020000000000000001d0a655879908a688ff97f08f05d067a9c30cb0192655fbb895699b8a1e3607220000000000000007008f070d4cfbd6e91cbfb27ec911d56c664acfcb88e451da792e9ef0277ced22000000000000000b9d9216cc6679406cf8b5995a0473ebfe0584ae71ce0fc3aa01b76d3794526e92000000000000000d24c6e09533fcb8dcfcb964c9ef3313d789e85ad4c5e270f69b884de8ad4242814000000", "33bb552bb30f28eff843e05d327776366b8cf8ae04d5a69e038a2f4a3157ff66"), + ("0400000000000000200000000000000097c5bc19efb43f536f078d401d8e8cb130c0329bbc2b2116608f14adc7a7cdd420000000000000000f1af01536d530734a08d904e4c4e4224d3c5b42df4f7235d0d439efe148b36a2000000000000000b9d9216cc6679406cf8b5995a0473ebfe0584ae71ce0fc3aa01b76d3794526e92000000000000000d24c6e09533fcb8dcfcb964c9ef3313d789e85ad4c5e270f69b884de8ad4242815000000", "939e7cd43ed3c31774ebcf53525963cb668de84b88ddfb2f2efc72814599f44a"), + ("04000000000000002000000000000000939e7cd43ed3c31774ebcf53525963cb668de84b88ddfb2f2efc72814599f44a20000000000000000f1af01536d530734a08d904e4c4e4224d3c5b42df4f7235d0d439efe148b36a2000000000000000b9d9216cc6679406cf8b5995a0473ebfe0584ae71ce0fc3aa01b76d3794526e92000000000000000d24c6e09533fcb8dcfcb964c9ef3313d789e85ad4c5e270f69b884de8ad4242816000000", "97c5bc19efb43f536f078d401d8e8cb130c0329bbc2b2116608f14adc7a7cdd4"), + ]; + let leaf_hashes = testdata + .iter() + .map(|(_, leaf_hash)| from_hex(leaf_hash).unwrap()) + .collect::>(); + + let proofs = testdata + .into_iter() + .enumerate() + .map(|(i, (data, leaf_hash))| { + let proof: BalancedBinaryMerkleProof = + bincode::deserialize(&from_hex(data).unwrap()).unwrap(); + assert!( + proof.verify(&root, from_hex(leaf_hash).unwrap()), + "proof {} is invalid", + i + ); + proof + }) + .collect::>(); + + let merged = MergedBalancedBinaryMerkleProof::create_from_proofs(&proofs).unwrap(); + assert!(merged.verify_consume(&root, leaf_hashes.clone()).unwrap()); + } } diff --git a/base_layer/mmr/src/balanced_binary_merkle_tree.rs b/base_layer/mmr/src/balanced_binary_merkle_tree.rs index 488ad22397..b11911418f 100644 --- a/base_layer/mmr/src/balanced_binary_merkle_tree.rs +++ b/base_layer/mmr/src/balanced_binary_merkle_tree.rs @@ -26,7 +26,7 @@ use digest::Digest; use tari_common::DomainDigest; use thiserror::Error; -use crate::{common::hash_together, ArrayLike, Hash}; +use crate::{common::hash_together, Hash}; pub(crate) fn cast_to_u32(value: usize) -> Result { u32::try_from(value).map_err(|_| BalancedBinaryMerkleTreeError::MathOverFlow) @@ -86,11 +86,25 @@ where D: Digest + DomainDigest } } - pub(crate) fn get_hash(&self, pos: usize) -> &Hash { - &self.hashes[pos] + /// Returns the number of _leaf_ nodes in the tree. That is, the number of hashes that are committed to by the + /// Merkle root. + pub fn num_leaf_nodes(&self) -> usize { + if self.hashes.is_empty() { + return 0; + } + ((self.hashes.len() - 1) >> 1) + 1 + } + + /// Returns the number of nodes in the tree. + pub fn num_nodes(&self) -> usize { + self.hashes.len() + } + + pub fn get_hash(&self, pos: usize) -> Option<&Hash> { + self.hashes.get(pos) } - pub fn get_leaf(&self, leaf_index: usize) -> &Hash { + pub fn get_leaf(&self, leaf_index: usize) -> Option<&Hash> { self.get_hash(self.get_node_index(leaf_index)) } @@ -101,8 +115,8 @@ where D: Digest + DomainDigest pub fn find_leaf_index_for_hash(&self, hash: &Hash) -> Result { let pos = self .hashes - .position(hash) - .expect("Unexpected Balanced Binary Merkle Tree error") + .iter() + .position(|h| h == hash) .ok_or(BalancedBinaryMerkleTreeError::LeafNotFound)?; if pos < (self.hashes.len() >> 1) { // The hash provided was not for leaf, but for node. @@ -124,6 +138,7 @@ mod test { fn test_empty_tree() { let leaves = vec![]; let bmt = BalancedBinaryMerkleTree::>::create(leaves); + assert_eq!(bmt.num_leaf_nodes(), 0); let root = bmt.get_merkle_root(); assert_eq!(root, vec![ 72, 54, 179, 2, 214, 45, 9, 89, 161, 132, 177, 251, 229, 46, 124, 233, 32, 186, 46, 87, 127, 247, 19, 36, @@ -135,6 +150,7 @@ mod test { fn test_single_node_tree() { let leaves = vec![vec![0; 32]]; let bmt = BalancedBinaryMerkleTree::>::create(leaves); + assert_eq!(bmt.num_leaf_nodes(), 1); let root = bmt.get_merkle_root(); assert_eq!(root, vec![0; 32]); } @@ -143,6 +159,8 @@ mod test { fn test_find_leaf() { let leaves = (0..100).map(|i| vec![i; 32]).collect::>(); let bmt = BalancedBinaryMerkleTree::>::create(leaves); + assert_eq!(bmt.num_leaf_nodes(), 100); + assert_eq!(bmt.num_nodes(), (100 << 1) - 1); assert_eq!(bmt.find_leaf_index_for_hash(&vec![42; 32]).unwrap(), 42); // Non existing hash assert_eq!(