Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add necessary trait bounds to balanced merkle tree #5232

1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions base_layer/mmr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ tari_utilities = { git = "https://github.com/tari-project/tari_utilities.git", t
tari_crypto = { git = "https://github.com/tari-project/tari-crypto.git", tag = "v0.16.8" }
tari_common = {path = "../../common"}
thiserror = "1.0.26"
borsh = "0.9.3"
digest = "0.9.0"
log = "0.4"
serde = { version = "1.0.97", features = ["derive"] }
Expand Down
54 changes: 37 additions & 17 deletions base_layer/mmr/src/balanced_binary_merkle_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,29 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::{collections::HashMap, convert::TryInto, marker::PhantomData};
use std::{
collections::HashMap,
convert::{TryFrom, TryInto},
marker::PhantomData,
};

use borsh::{BorshDeserialize, BorshSerialize};
use digest::Digest;
use serde::{Deserialize, Serialize};
use tari_common::DomainDigest;
use tari_utilities::ByteArray;
use thiserror::Error;

use crate::{common::hash_together, BalancedBinaryMerkleTree, Hash};

#[derive(Debug)]
pub(crate) fn cast_to_u32(value: usize) -> Result<u32, MergedBalancedBinaryMerkleProofError> {
u32::try_from(value).map_err(|_| MergedBalancedBinaryMerkleProofError::MathOverflow)
}

#[derive(BorshDeserialize, BorshSerialize, Deserialize, Serialize, Clone, Debug, Default, PartialEq, Eq)]
pub struct BalancedBinaryMerkleProof<D> {
pub path: Vec<Hash>,
pub node_index: usize,
pub node_index: u32,
_phantom: PhantomData<D>,
}

Expand All @@ -55,7 +65,10 @@ where D: Digest + DomainDigest
&computed_root == root
}

pub fn generate_proof(tree: &BalancedBinaryMerkleTree<D>, leaf_index: usize) -> Self {
pub fn generate_proof(
tree: &BalancedBinaryMerkleTree<D>,
leaf_index: usize,
) -> Result<Self, MergedBalancedBinaryMerkleProofError> {
let mut node_index = tree.get_node_index(leaf_index);
let mut proof = Vec::new();
while node_index > 0 {
Expand All @@ -67,11 +80,11 @@ where D: Digest + DomainDigest
// Traverse to parent
node_index = parent;
}
Self {
Ok(Self {
path: proof,
node_index: tree.get_node_index(leaf_index),
node_index: cast_to_u32(tree.get_node_index(leaf_index))?,
_phantom: PhantomData,
}
})
}
}

Expand All @@ -81,6 +94,8 @@ pub enum MergedBalancedBinaryMerkleProofError {
CantMergeZeroProofs,
#[error("Bad proof semantics")]
BadProofSemantics,
#[error("Math overflow")]
MathOverflow,
jorgeantonio21 marked this conversation as resolved.
Show resolved Hide resolved
}

/// Flag to indicate if proof data represents an index or a node hash
Expand All @@ -94,8 +109,8 @@ pub enum MergedBalancedBinaryMerkleDataType {
#[derive(Debug)]
pub struct MergedBalancedBinaryMerkleProof<D> {
pub paths: Vec<Vec<(MergedBalancedBinaryMerkleDataType, Vec<u8>)>>, // these tuples can contain indexes or hashes!
pub node_indices: Vec<usize>,
pub heights: Vec<usize>,
pub node_indices: Vec<u32>,
pub heights: Vec<u32>,
_phantom: PhantomData<D>,
}

Expand All @@ -105,7 +120,10 @@ where D: Digest + DomainDigest
pub fn create_from_proofs(
proofs: Vec<BalancedBinaryMerkleProof<D>>,
) -> Result<Self, MergedBalancedBinaryMerkleProofError> {
let heights = proofs.iter().map(|proof| proof.path.len()).collect::<Vec<_>>();
let heights = proofs
.iter()
.map(|proof| cast_to_u32(proof.path.len()))
.collect::<Result<Vec<_>, _>>()?;
let max_height = heights
.iter()
.max()
Expand All @@ -117,7 +135,7 @@ where D: Digest + DomainDigest
let mut hash_map = HashMap::new();
for (index, proof) in proofs.iter().enumerate() {
// If this path was already joined ignore it.
if join_indices[index].is_none() && proof.path.len() > height {
if join_indices[index].is_none() && proof.path.len() > height as usize {
let parent = (indices[index] - 1) >> 1;
if let Some(other_proof) = hash_map.insert(parent, index) {
join_indices[index] = Some(other_proof);
Expand All @@ -129,7 +147,7 @@ where D: Digest + DomainDigest
0,
(
MergedBalancedBinaryMerkleDataType::Hash,
proof.path[proof.path.len() - 1 - height].clone(),
proof.path[proof.path.len() - 1 - height as usize].clone(),
),
);
}
Expand Down Expand Up @@ -223,11 +241,11 @@ mod test {
let hash_last = leaves[n - 1].clone();
let bmt = BalancedBinaryMerkleTree::<DomainSeparatedHasher<Blake256, TestDomain>>::create(leaves);
let root = bmt.get_merkle_root();
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, 0);
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, 0).unwrap();
assert!(proof.verify(&root, hash_0));
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, n / 2);
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, n / 2).unwrap();
assert!(proof.verify(&root, hash_n_half));
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, n - 1);
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, n - 1).unwrap();
assert!(proof.verify(&root, hash_last));
}
}
Expand All @@ -241,7 +259,8 @@ mod test {
let proofs = indices
.iter()
.map(|i| BalancedBinaryMerkleProof::generate_proof(&bmt, *i))
.collect::<Vec<_>>();
.collect::<Result<Vec<_>, _>>()
.unwrap();
let merged_proof = MergedBalancedBinaryMerkleProof::create_from_proofs(proofs).unwrap();
assert!(merged_proof
.verify_consume(&root, indices.iter().map(|i| leaves[*i].clone()).collect::<Vec<_>>())
Expand All @@ -255,7 +274,8 @@ mod test {
let root = bmt.get_merkle_root();
let proofs = (0..255)
.map(|i| BalancedBinaryMerkleProof::generate_proof(&bmt, i))
.collect::<Vec<_>>();
.collect::<Result<Vec<_>, _>>()
.unwrap();
let merged_proof = MergedBalancedBinaryMerkleProof::create_from_proofs(proofs).unwrap();
assert!(merged_proof.verify_consume(&root, leaves).unwrap());
}
Expand Down
12 changes: 9 additions & 3 deletions base_layer/mmr/src/balanced_binary_merkle_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,24 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::marker::PhantomData;
use std::{convert::TryFrom, marker::PhantomData};

use digest::Digest;
use tari_common::DomainDigest;
use thiserror::Error;

use crate::{common::hash_together, ArrayLike, Hash};

pub(crate) fn cast_to_u32(value: usize) -> Result<u32, BalancedBinaryMerkleTreeError> {
u32::try_from(value).map_err(|_| BalancedBinaryMerkleTreeError::MathOverFlow)
}

#[derive(Clone, Debug, PartialEq, Eq, Error)]
pub enum BalancedBinaryMerkleTreeError {
#[error("There is no leaf with the hash provided.")]
LeafNotFound,
#[error("Math overflow")]
MathOverFlow,
}

// The hashes are perfectly balanced binary tree, so parent at index `i` (0-based) has children at positions `2*i+1` and
Expand Down Expand Up @@ -92,7 +98,7 @@ where D: Digest + DomainDigest
leaf_index + (self.hashes.len() >> 1)
}

pub fn find_leaf_index_for_hash(&self, hash: &Hash) -> Result<usize, BalancedBinaryMerkleTreeError> {
pub fn find_leaf_index_for_hash(&self, hash: &Hash) -> Result<u32, BalancedBinaryMerkleTreeError> {
let pos = self
.hashes
.position(hash)
Expand All @@ -102,7 +108,7 @@ where D: Digest + DomainDigest
// The hash provided was not for leaf, but for node.
Err(BalancedBinaryMerkleTreeError::LeafNotFound)
} else {
Ok(pos - (self.hashes.len() >> 1))
Ok(cast_to_u32(pos - (self.hashes.len() >> 1))?)
}
}
}
Expand Down