feat: add necessary trait bounds to balanced merkle tree (#5232)
Description
---
Adds the necessary trait derives to the `BalancedBinaryMerkleProof` struct.

Motivation and Context
---
In order to update the Tari DAN layer repo to make use of BMTs, we need the current `BalancedBinaryMerkleProof` struct to derive certain traits, namely `#[derive(Clone, PartialEq, Eq, Deserialize, Serialize)]`.
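
As a rough illustration of what the new derives make possible on the DAN-layer side (not part of this PR), a generic helper along the lines of the sketch below becomes writable. The `tari_mmr` crate path, the `serde_json` dependency and the `round_trip` name are assumptions for illustration only; because the struct carries a `PhantomData<D>`, the derived impls come with bounds involving `D`, so the sketch states its requirements directly on the proof type.

```rust
use std::fmt::Debug;

use serde::{de::DeserializeOwned, Serialize};
use tari_mmr::BalancedBinaryMerkleProof;

/// Hypothetical helper: round-trip a proof through JSON and check the restored
/// value matches the original. This only compiles because the proof now derives
/// `Clone`, `PartialEq`, `Eq`, `Serialize` and `Deserialize`.
fn round_trip<D>(proof: &BalancedBinaryMerkleProof<D>) -> Result<(), serde_json::Error>
where
    BalancedBinaryMerkleProof<D>: Clone + PartialEq + Debug + Serialize + DeserializeOwned,
{
    let json = serde_json::to_string(proof)?;
    let restored: BalancedBinaryMerkleProof<D> = serde_json::from_str(&json)?;
    assert_eq!(&restored, proof); // PartialEq + Debug make the comparison (and its failure message) possible
    let _copy = restored.clone(); // Clone lets a caller keep an owned copy
    Ok(())
}
```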

How Has This Been Tested?
---
Run `cargo build`.
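
Beyond the build check, here is a hedged sketch of the kind of round-trip check the new Borsh derives allow (this is not a test that was run for this PR; the `tari_mmr` path and the borsh 0.9 `try_to_vec`/`try_from_slice` calls are the assumed API):

```rust
use std::fmt::Debug;

use borsh::{BorshDeserialize, BorshSerialize};
use tari_mmr::BalancedBinaryMerkleProof;

/// Hypothetical check: serialise a proof with Borsh and read it back unchanged.
fn borsh_round_trip<D>(proof: &BalancedBinaryMerkleProof<D>) -> std::io::Result<()>
where
    BalancedBinaryMerkleProof<D>: BorshSerialize + BorshDeserialize + PartialEq + Debug,
{
    let bytes = proof.try_to_vec()?; // BorshSerialize (borsh 0.9 API)
    let restored = BalancedBinaryMerkleProof::<D>::try_from_slice(&bytes)?; // BorshDeserialize
    assert_eq!(&restored, proof);
    Ok(())
}
```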


jorgeantonio21 authored Mar 10, 2023
1 parent 78838bf commit 3b971a3
Showing 5 changed files with 57 additions and 29 deletions.
1 change: 1 addition & 0 deletions Cargo.lock


1 change: 1 addition & 0 deletions base_layer/mmr/Cargo.toml
@@ -17,6 +17,7 @@ tari_utilities = { git = "https://github.com/tari-project/tari_utilities.git", t
tari_crypto = { git = "https://github.com/tari-project/tari-crypto.git", tag = "v0.16.8" }
tari_common = {path = "../../common"}
thiserror = "1.0.26"
borsh = "0.9.3"
digest = "0.9.0"
log = "0.4"
serde = { version = "1.0.97", features = ["derive"] }
70 changes: 45 additions & 25 deletions base_layer/mmr/src/balanced_binary_merkle_proof.rs
@@ -20,19 +20,29 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::{collections::HashMap, convert::TryInto, marker::PhantomData};
use std::{
collections::HashMap,
convert::{TryFrom, TryInto},
marker::PhantomData,
};

use borsh::{BorshDeserialize, BorshSerialize};
use digest::Digest;
use serde::{Deserialize, Serialize};
use tari_common::DomainDigest;
use tari_utilities::ByteArray;
use thiserror::Error;

use crate::{common::hash_together, BalancedBinaryMerkleTree, Hash};

#[derive(Debug)]
pub(crate) fn cast_to_u32(value: usize) -> Result<u32, BalancedBinaryMerkleProofError> {
u32::try_from(value).map_err(|_| BalancedBinaryMerkleProofError::MathOverflow)
}

#[derive(BorshDeserialize, BorshSerialize, Deserialize, Serialize, Clone, Debug, Default, PartialEq, Eq)]
pub struct BalancedBinaryMerkleProof<D> {
pub path: Vec<Hash>,
pub node_index: usize,
pub node_index: u32,
_phantom: PhantomData<D>,
}

@@ -55,7 +65,10 @@ where D: Digest + DomainDigest
&computed_root == root
}

pub fn generate_proof(tree: &BalancedBinaryMerkleTree<D>, leaf_index: usize) -> Self {
pub fn generate_proof(
tree: &BalancedBinaryMerkleTree<D>,
leaf_index: usize,
) -> Result<Self, BalancedBinaryMerkleProofError> {
let mut node_index = tree.get_node_index(leaf_index);
let mut proof = Vec::new();
while node_index > 0 {
@@ -67,20 +80,22 @@ where D: Digest + DomainDigest
// Traverse to parent
node_index = parent;
}
Self {
Ok(Self {
path: proof,
node_index: tree.get_node_index(leaf_index),
node_index: cast_to_u32(tree.get_node_index(leaf_index))?,
_phantom: PhantomData,
}
})
}
}

#[derive(Debug, Error)]
pub enum MergedBalancedBinaryMerkleProofError {
pub enum BalancedBinaryMerkleProofError {
#[error("Can't merge zero proofs.")]
CantMergeZeroProofs,
#[error("Bad proof semantics")]
BadProofSemantics,
#[error("Math overflow")]
MathOverflow,
}

/// Flag to indicate if proof data represents an index or a node hash
@@ -94,8 +109,8 @@ pub enum MergedBalancedBinaryMerkleDataType {
#[derive(Debug)]
pub struct MergedBalancedBinaryMerkleProof<D> {
pub paths: Vec<Vec<(MergedBalancedBinaryMerkleDataType, Vec<u8>)>>, // these tuples can contain indexes or hashes!
pub node_indices: Vec<usize>,
pub heights: Vec<usize>,
pub node_indices: Vec<u32>,
pub heights: Vec<u32>,
_phantom: PhantomData<D>,
}

@@ -104,20 +119,23 @@ where D: Digest + DomainDigest
{
pub fn create_from_proofs(
proofs: Vec<BalancedBinaryMerkleProof<D>>,
) -> Result<Self, MergedBalancedBinaryMerkleProofError> {
let heights = proofs.iter().map(|proof| proof.path.len()).collect::<Vec<_>>();
) -> Result<Self, BalancedBinaryMerkleProofError> {
let heights = proofs
.iter()
.map(|proof| cast_to_u32(proof.path.len()))
.collect::<Result<Vec<_>, _>>()?;
let max_height = heights
.iter()
.max()
.ok_or(MergedBalancedBinaryMerkleProofError::CantMergeZeroProofs)?;
.ok_or(BalancedBinaryMerkleProofError::CantMergeZeroProofs)?;
let mut indices = proofs.iter().map(|proof| proof.node_index).collect::<Vec<_>>();
let mut paths = vec![Vec::new(); proofs.len()];
let mut join_indices = vec![None; proofs.len()];
for height in (0..*max_height).rev() {
let mut hash_map = HashMap::new();
for (index, proof) in proofs.iter().enumerate() {
// If this path was already joined ignore it.
if join_indices[index].is_none() && proof.path.len() > height {
if join_indices[index].is_none() && proof.path.len() > height as usize {
let parent = (indices[index] - 1) >> 1;
if let Some(other_proof) = hash_map.insert(parent, index) {
join_indices[index] = Some(other_proof);
@@ -129,7 +147,7 @@ where D: Digest + DomainDigest
0,
(
MergedBalancedBinaryMerkleDataType::Hash,
proof.path[proof.path.len() - 1 - height].clone(),
proof.path[proof.path.len() - 1 - height as usize].clone(),
),
);
}
@@ -149,19 +167,19 @@ where D: Digest + DomainDigest
mut self,
root: &Hash,
leaves_hashes: Vec<Hash>,
) -> Result<bool, MergedBalancedBinaryMerkleProofError> {
) -> Result<bool, BalancedBinaryMerkleProofError> {
// Check that the proof and verifier data match
let n = self.node_indices.len(); // number of merged proofs
if self.paths.len() != n || leaves_hashes.len() != n {
return Err(MergedBalancedBinaryMerkleProofError::BadProofSemantics);
return Err(BalancedBinaryMerkleProofError::BadProofSemantics);
}

let mut computed_hashes = leaves_hashes;
let max_height = self
.heights
.iter()
.max()
.ok_or(MergedBalancedBinaryMerkleProofError::CantMergeZeroProofs)?;
.ok_or(BalancedBinaryMerkleProofError::CantMergeZeroProofs)?;

// We need to compute the hashes row by row to be sure they are processed correctly.
for height in (0..*max_height).rev() {
Expand All @@ -177,14 +195,14 @@ where D: Digest + DomainDigest
.1
.as_bytes()
.try_into()
.map_err(|_| MergedBalancedBinaryMerkleProofError::BadProofSemantics)?,
.map_err(|_| BalancedBinaryMerkleProofError::BadProofSemantics)?,
);

// The index must also point to one of the proofs
if index < hashes.len() {
&hashes[index]
} else {
return Err(MergedBalancedBinaryMerkleProofError::BadProofSemantics);
return Err(BalancedBinaryMerkleProofError::BadProofSemantics);
}
},
MergedBalancedBinaryMerkleDataType::Hash => &hash_or_index.1,
@@ -223,11 +241,11 @@ mod test {
let hash_last = leaves[n - 1].clone();
let bmt = BalancedBinaryMerkleTree::<DomainSeparatedHasher<Blake256, TestDomain>>::create(leaves);
let root = bmt.get_merkle_root();
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, 0);
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, 0).unwrap();
assert!(proof.verify(&root, hash_0));
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, n / 2);
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, n / 2).unwrap();
assert!(proof.verify(&root, hash_n_half));
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, n - 1);
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, n - 1).unwrap();
assert!(proof.verify(&root, hash_last));
}
}
@@ -241,7 +259,8 @@ mod test {
let proofs = indices
.iter()
.map(|i| BalancedBinaryMerkleProof::generate_proof(&bmt, *i))
.collect::<Vec<_>>();
.collect::<Result<Vec<_>, _>>()
.unwrap();
let merged_proof = MergedBalancedBinaryMerkleProof::create_from_proofs(proofs).unwrap();
assert!(merged_proof
.verify_consume(&root, indices.iter().map(|i| leaves[*i].clone()).collect::<Vec<_>>())
@@ -255,7 +274,8 @@ mod test {
let root = bmt.get_merkle_root();
let proofs = (0..255)
.map(|i| BalancedBinaryMerkleProof::generate_proof(&bmt, i))
.collect::<Vec<_>>();
.collect::<Result<Vec<_>, _>>()
.unwrap();
let merged_proof = MergedBalancedBinaryMerkleProof::create_from_proofs(proofs).unwrap();
assert!(merged_proof.verify_consume(&root, leaves).unwrap());
}
12 changes: 9 additions & 3 deletions base_layer/mmr/src/balanced_binary_merkle_tree.rs
@@ -20,18 +20,24 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::marker::PhantomData;
use std::{convert::TryFrom, marker::PhantomData};

use digest::Digest;
use tari_common::DomainDigest;
use thiserror::Error;

use crate::{common::hash_together, ArrayLike, Hash};

pub(crate) fn cast_to_u32(value: usize) -> Result<u32, BalancedBinaryMerkleTreeError> {
u32::try_from(value).map_err(|_| BalancedBinaryMerkleTreeError::MathOverFlow)
}

#[derive(Clone, Debug, PartialEq, Eq, Error)]
pub enum BalancedBinaryMerkleTreeError {
#[error("There is no leaf with the hash provided.")]
LeafNotFound,
#[error("Math overflow")]
MathOverFlow,
}

// The hashes are perfectly balanced binary tree, so parent at index `i` (0-based) has children at positions `2*i+1` and `2*i+2`.
@@ -92,7 +98,7 @@ where D: Digest + DomainDigest
leaf_index + (self.hashes.len() >> 1)
}

pub fn find_leaf_index_for_hash(&self, hash: &Hash) -> Result<usize, BalancedBinaryMerkleTreeError> {
pub fn find_leaf_index_for_hash(&self, hash: &Hash) -> Result<u32, BalancedBinaryMerkleTreeError> {
let pos = self
.hashes
.position(hash)
@@ -102,7 +108,7 @@
// The hash provided was not for leaf, but for node.
Err(BalancedBinaryMerkleTreeError::LeafNotFound)
} else {
Ok(pos - (self.hashes.len() >> 1))
Ok(cast_to_u32(pos - (self.hashes.len() >> 1))?)
}
}
}
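
For readers unfamiliar with the flat-array layout described in the comment in `balanced_binary_merkle_tree.rs`, here is a small standalone sketch (illustration only, not part of the diff) of the parent/child index arithmetic the tree and the proof-merging code rely on:

```rust
// Illustration only: 0-based, heap-style flat-array indexing of a perfectly balanced tree.
fn parent(i: usize) -> usize {
    (i - 1) >> 1 // same expression the merge logic uses; not defined for the root (i == 0)
}

fn children(i: usize) -> (usize, usize) {
    (2 * i + 1, 2 * i + 2)
}

fn main() {
    // A tree with 4 leaves stores 7 hashes: indices 0..=2 are internal nodes, 3..=6 are leaves.
    assert_eq!(children(0), (1, 2));
    assert_eq!(children(1), (3, 4));
    assert_eq!(parent(3), 1);
    assert_eq!(parent(6), 2);
}
```
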
2 changes: 1 addition & 1 deletion base_layer/mmr/src/lib.rs
@@ -154,8 +154,8 @@ pub mod pruned_hashset;
pub use backend::{ArrayLike, ArrayLikeExt};
pub use balanced_binary_merkle_proof::{
BalancedBinaryMerkleProof,
BalancedBinaryMerkleProofError,
MergedBalancedBinaryMerkleProof,
MergedBalancedBinaryMerkleProofError,
};
pub use balanced_binary_merkle_tree::{BalancedBinaryMerkleTree, BalancedBinaryMerkleTreeError};
/// MemBackendVec is a shareable, memory only, vector that can be be used with MmrCache to store checkpoints.
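
Since `MergedBalancedBinaryMerkleProofError` is no longer exported, downstream code has to switch to the shared error type. A hypothetical consumer-side sketch (the `tari_mmr` path and the `describe` helper are illustrative assumptions, not existing code):

```rust
use tari_mmr::BalancedBinaryMerkleProofError;

/// Hypothetical mapping of the renamed error type to user-facing messages.
fn describe(err: &BalancedBinaryMerkleProofError) -> &'static str {
    match err {
        BalancedBinaryMerkleProofError::CantMergeZeroProofs => "no proofs were supplied",
        BalancedBinaryMerkleProofError::BadProofSemantics => "the proof data is malformed",
        BalancedBinaryMerkleProofError::MathOverflow => "a node index does not fit in u32",
    }
}
```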
